diff --git a/backend/apps/chat/curd/chat.py b/backend/apps/chat/curd/chat.py index 1f8befdf5..eaddc671e 100644 --- a/backend/apps/chat/curd/chat.py +++ b/backend/apps/chat/curd/chat.py @@ -211,7 +211,7 @@ def format_json_list_data(origin_data: list[dict]): if len(decimal_str) > 15: value = str(value) _row[key] = value - data.append(_row) + data.append(DataFormat.normalize_qualified_sql_column_keys(_row)) return data @@ -253,6 +253,7 @@ def get_chart_data_ds(session: SessionDep,ds_id,sql): else: result = exec_sql(ds=datasource,sql=sql, origin_column=False) _data = DataFormat.convert_large_numbers_in_object_array(result.get('data')) + _data = DataFormat.normalize_qualified_sql_column_keys_in_object_array(_data) json_result['data'] = _data return json_result except Exception as e: diff --git a/backend/apps/chat/models/chat_model.py b/backend/apps/chat/models/chat_model.py index 66ef2060d..b0a6c2c73 100644 --- a/backend/apps/chat/models/chat_model.py +++ b/backend/apps/chat/models/chat_model.py @@ -229,6 +229,7 @@ class AiModelQuestion(BaseModel): custom_prompt: str = "" error_msg: str = "" regenerate_record_id: Optional[int] = None + sample_data: str = "" def sql_sys_question(self, db_type: Union[str, DB], enable_query_limit: bool = True): templates: dict[str, str] = {} @@ -256,7 +257,7 @@ def sql_sys_question(self, db_type: Union[str, DB], enable_query_limit: bool = T example_answer_1=_example_answer_1, example_answer_2=_example_answer_2, example_answer_3=_example_answer_3) - templates['schema'] = _base_template['generate_basic_info'].format(engine=self.engine, schema=self.db_schema) + templates['schema'] = _base_template['generate_basic_info'].format(engine=self.engine, schema=self.db_schema, sample_data=self.sample_data) if self.terminologies: templates['terminologies'] = _base_template['generate_terminologies_info'].format( diff --git a/backend/apps/chat/task/llm.py b/backend/apps/chat/task/llm.py index 006b8bb9a..cf469b371 100644 --- a/backend/apps/chat/task/llm.py +++ b/backend/apps/chat/task/llm.py @@ -36,7 +36,7 @@ from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat, ChatLog, OperationEnum, \ ChatFinishStep, AxisObj, SystemPromptMessage, HumanPromptMessage, AIPromptMessage from apps.data_training.curd.data_training import get_training_template -from apps.datasource.crud.datasource import get_table_schema +from apps.datasource.crud.datasource import get_table_schema, get_tables_sample_data from apps.datasource.crud.permission import get_row_permission_filters, is_normal_user from apps.datasource.embedding.ds_embedding import get_ds_embedding from apps.datasource.models.datasource import CoreDatasource @@ -384,6 +384,13 @@ def choose_table_schema(self, _session: Session): ds=self.ds, question=self.chat_question.question) + # Get sample data for all tables + if not self.out_ds_instance: + self.chat_question.sample_data = get_tables_sample_data( + session=_session, + current_user=self.current_user, + ds=self.ds) + self.current_logs[OperationEnum.CHOOSE_TABLE] = end_log(session=_session, log=self.current_logs[OperationEnum.CHOOSE_TABLE], full_message=self.chat_question.db_schema) @@ -505,6 +512,13 @@ def generate_recommend_questions_task(self, _session: Session): question=self.chat_question.question, embedding=False) + # Get sample data for all tables + if not self.out_ds_instance: + self.chat_question.sample_data = get_tables_sample_data( + session=_session, + current_user=self.current_user, + ds=self.ds) + guess_msg: List[Union[BaseMessage, dict[str, Any]]] = [] guess_msg.append(SystemPromptMessage(content=self.chat_question.guess_sys_question(self.articles_number))) @@ -1304,6 +1318,7 @@ def run_task(self, in_chat: bool = True, stream: bool = True, 'count': len(result.get('data'))}) _data = DataFormat.convert_large_numbers_in_object_array(result.get('data')) + _data = DataFormat.normalize_qualified_sql_column_keys_in_object_array(_data) result["data"] = _data self.save_sql_data(session=_session, data_obj=result) diff --git a/backend/apps/datasource/crud/datasource.py b/backend/apps/datasource/crud/datasource.py index 5a4cc6223..fc7afb7cf 100644 --- a/backend/apps/datasource/crud/datasource.py +++ b/backend/apps/datasource/crud/datasource.py @@ -17,7 +17,7 @@ from common.core.config import settings from common.core.deps import SessionDep, CurrentUser, Trans from common.utils.embedding_threads import run_save_table_embeddings, run_save_ds_embeddings -from common.utils.utils import SQLBotLogUtil, deepcopy_ignore_extra +from common.utils.utils import SQLBotLogUtil, deepcopy_ignore_extra, equals_ignore_case from common.core.sqlbot_cache import cache, clear_cache from .table import get_tables_by_ds_id from ..crud.field import delete_field_by_ds_id, update_field @@ -357,12 +357,16 @@ def preview(session: SessionDep, current_user: CurrentUser, id: int, data: Table {where} LIMIT 100""" elif ds.type == "dm": - sql = f"""SELECT "{'", "'.join(fields)}" FROM "{conf.dbSchema}"."{data.table.table_name}" - {where} + sql = f"""SELECT "{'", "'.join(fields)}" FROM "{conf.dbSchema}"."{data.table.table_name}" + {where} LIMIT 100""" elif ds.type == "es": - sql = f"""SELECT "{'", "'.join(fields)}" FROM "{data.table.table_name}" - {where} + sql = f"""SELECT "{'", "'.join(fields)}" FROM "{data.table.table_name}" + {where} + LIMIT 100""" + elif ds.type == "sqlite": + sql = f"""SELECT "{'", "'.join(fields)}" FROM "{data.table.table_name}" + {where} LIMIT 100""" return exec_sql(ds, sql, True) @@ -430,6 +434,79 @@ def get_table_obj_by_ds(session: SessionDep, current_user: CurrentUser, ds: Core return _list +def get_table_sample_data(ds: CoreDatasource, table_name: str, fields: list) -> str: + """Get 3 sample rows from a table in JSON format to help AI understand the data""" + if not fields: + return "" + + db = DB.get_db(ds.type) + # Get prefix/suffix for identifier quoting + prefix = db.prefix if hasattr(db, 'prefix') else '"' + suffix = db.suffix if hasattr(db, 'suffix') else '"' + + # Build field list with proper quoting + field_names = [] + for field in fields[:10]: # Limit to first 10 fields to avoid too wide results + field_name = f"{prefix}{field.field_name}{suffix}" + field_names.append(field_name) + + # Build LIMIT query based on database type + if equals_ignore_case(ds.type, "sqlServer"): + query = f"SELECT TOP 3 {','.join(field_names)} FROM {prefix}{table_name}{suffix}" + elif equals_ignore_case(ds.type, "ck"): + query = f"SELECT {','.join(field_names)} FROM {table_name} LIMIT 3" + elif equals_ignore_case(ds.type, "hive"): + query = f"SELECT {','.join(field_names)} FROM {table_name} LIMIT 3" + elif equals_ignore_case(ds.type, "oracle"): + query = f"SELECT {','.join(field_names)} FROM \"{table_name}\" WHERE ROWNUM <= 3" + elif equals_ignore_case(ds.type, "dm"): + query = f"SELECT {','.join(field_names)} FROM \"{table_name}\" WHERE ROWNUM <= 3" + else: + query = f"SELECT {','.join(field_names)} FROM {prefix}{table_name}{suffix} LIMIT 3" + + try: + result = exec_sql(ds=ds, sql=query, origin_column=True) + if result and result.get('data') and len(result['data']) > 0: + import json + # Truncate long string values for readability + json_rows = [] + for row in result['data'][:3]: + truncated_row = {} + for key, value in row.items(): + if value is None: + truncated_row[key] = None + elif isinstance(value, str): + # Truncate long strings + if len(value) > 100: + value = value[:100] + '...' + truncated_row[key] = value.replace('\n', ' ').replace('\r', ' ') + else: + truncated_row[key] = value + json_rows.append(truncated_row) + return json.dumps(json_rows, ensure_ascii=False, indent=2) + except Exception: + pass + return "" + + +def get_tables_sample_data(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource, + table_list: list[str] = None) -> str: + """Get sample data (3 rows) for all tables to help AI understand the data""" + table_objs = get_table_obj_by_ds(session=session, current_user=current_user, ds=ds) + if len(table_objs) == 0: + return "" + + sample_data_parts = [] + for obj in table_objs: + if table_list is not None and obj.table.table_name not in table_list: + continue + if obj.fields: + sample = get_table_sample_data(ds, obj.table.table_name, obj.fields) + if sample: + sample_data_parts.append(f"# Table: {obj.table.table_name}\n{sample}") + return "\n".join(sample_data_parts) + + def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource, question: str, embedding: bool = True, table_list: list[str] = None) -> str: schema_str = "" @@ -446,7 +523,8 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat continue schema_table = '' - schema_table += f"# Table: {db_name}.{obj.table.table_name}" if ds.type != "mysql" and ds.type != "es" else f"# Table: {obj.table.table_name}" + no_schema_types = ["mysql", "es", "sqlite", "hive", "doris", "starrocks"] + schema_table += f"# Table: {db_name}.{obj.table.table_name}" if ds.type not in no_schema_types and db_name else f"# Table: {obj.table.table_name}" table_comment = '' if obj.table.custom_comment: table_comment = obj.table.custom_comment.strip() diff --git a/backend/apps/datasource/models/datasource.py b/backend/apps/datasource/models/datasource.py index 6a23e0b7f..3971318cf 100644 --- a/backend/apps/datasource/models/datasource.py +++ b/backend/apps/datasource/models/datasource.py @@ -143,7 +143,7 @@ def to_dict(self): class TableSchema: - def __init__(self, attr1, attr2): + def __init__(self, attr1, attr2=None): self.tableName = attr1 self.tableComment = attr2 if attr2 is None or isinstance(attr2, str) else attr2.decode("utf-8") diff --git a/backend/apps/db/constant.py b/backend/apps/db/constant.py index 1509fcf47..6ee33f02f 100644 --- a/backend/apps/db/constant.py +++ b/backend/apps/db/constant.py @@ -28,6 +28,8 @@ class DB(Enum): oracle = ('oracle', 'Oracle', '"', '"', ConnectType.sqlalchemy, 'Oracle', []) pg = ('pg', 'PostgreSQL', '"', '"', ConnectType.sqlalchemy, 'PostgreSQL', []) starrocks = ('starrocks', 'StarRocks', '`', '`', ConnectType.py_driver, 'StarRocks', []) + sqlite = ('sqlite', 'SQLite', '"', '"', ConnectType.sqlalchemy, 'SQLite', []) + hive = ('hive', 'Apache Hive', '`', '`', ConnectType.py_driver, 'Hive', []) def __init__(self, type, db_name, prefix, suffix, connect_type: ConnectType, template_name: str, illegalParams: List[str]): diff --git a/backend/apps/db/db.py b/backend/apps/db/db.py index bbf9e97e1..99e5911b7 100644 --- a/backend/apps/db/db.py +++ b/backend/apps/db/db.py @@ -2,6 +2,7 @@ import json import os import platform +import re import urllib.parse from datetime import datetime, date, time, timedelta from decimal import Decimal @@ -35,6 +36,8 @@ import sqlglot from sqlglot import expressions as exp from sqlalchemy.pool import NullPool +from pyhive import hive + try: if os.path.exists(settings.ORACLE_CLIENT_PATH): @@ -88,6 +91,8 @@ def get_uri_from_config(type: str, conf: DatasourceConf) -> str: db_url = f"clickhouse+http://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}?{conf.extraJdbc}" else: db_url = f"clickhouse+http://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}" + elif equals_ignore_case(type, "sqlite"): + db_url = f"sqlite:///{conf.filename}" else: raise 'The datasource type not support.' return db_url @@ -157,6 +162,8 @@ def get_engine(ds: CoreDatasource, timeout: int = 0) -> Engine: elif equals_ignore_case(ds.type, 'mysql'): # mysql ssl_mode = {"require": True} if conf.ssl else None engine = create_engine(get_uri(ds), connect_args={"connect_timeout": conf.timeout, "ssl": ssl_mode}, poolclass=NullPool) + elif equals_ignore_case(ds.type, 'sqlite'): + engine = create_engine(get_uri(ds), connect_args={"check_same_thread": False}, poolclass=NullPool) else: # ck engine = create_engine(get_uri(ds), connect_args={"connect_timeout": conf.timeout}, poolclass=NullPool) return engine @@ -207,9 +214,10 @@ def check_connection(trans: Optional[Trans], ds: CoreDatasource | AssistantOutDs raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') return False elif equals_ignore_case(ds.type, 'doris', 'starrocks'): + ssl_args = {'ssl': {'ssl_mode': 'REQUIRE'}} if conf.ssl else {} with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, port=conf.port, db=conf.database, connect_timeout=10, - read_timeout=10, **extra_config_dict) as conn, conn.cursor() as cursor: + read_timeout=10, **extra_config_dict, **ssl_args) as conn, conn.cursor() as cursor: try: cursor.execute('select 1') SQLBotLogUtil.info("success") @@ -247,6 +255,23 @@ def check_connection(trans: Optional[Trans], ds: CoreDatasource | AssistantOutDs if is_raise: raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') return False + elif equals_ignore_case(ds.type, 'hive'): + try: + conn = hive.connect(host=conf.host, port=conf.port, username=conf.username, + database=conf.database, **extra_config_dict) + cursor = conn.cursor() + cursor.execute('select 1') + cursor.fetchall() + cursor.close() + conn.close() + SQLBotLogUtil.info("success") + return True + except Exception as e: + SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}") + if is_raise: + raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') + return False + elif equals_ignore_case(ds.type, 'es'): es_conn = get_es_connect(conf) if es_conn.ping(): @@ -289,6 +314,8 @@ def get_version(ds: CoreDatasource | AssistantOutDsSchema): # conf.timeout = 10 db = DB.get_db(ds.type) sql = get_version_sql(ds, conf) + if equals_ignore_case(ds.type, 'sqlite'): + return '' try: if db.connect_type == ConnectType.sqlalchemy: with get_session(ds) as session: @@ -304,13 +331,14 @@ def get_version(ds: CoreDatasource | AssistantOutDsSchema): res = cursor.fetchall() version = res[0][0] elif equals_ignore_case(ds.type, 'doris', 'starrocks'): + ssl_args = {'ssl': {'ssl_mode': 'REQUIRE'}} if conf.ssl else {} with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, port=conf.port, db=conf.database, connect_timeout=10, - read_timeout=10, **extra_config_dict) as conn, conn.cursor() as cursor: + read_timeout=10, **extra_config_dict, **ssl_args) as conn, conn.cursor() as cursor: cursor.execute(sql) res = cursor.fetchall() version = res[0][0] - elif equals_ignore_case(ds.type, 'redshift', 'es'): + elif equals_ignore_case(ds.type, 'redshift', 'es', 'hive'): version = '' except Exception as e: print(e) @@ -333,6 +361,8 @@ def get_schema(ds: CoreDatasource): elif equals_ignore_case(ds.type, "oracle"): sql = """select * from all_users""" + elif equals_ignore_case(ds.type, "sqlite"): + return ['main'] with session.execute(text(sql)) as result: res = result.fetchall() res_list = [item[0] for item in res] @@ -367,6 +397,30 @@ def get_schema(ds: CoreDatasource): res = cursor.fetchall() res_list = [item[0] for item in res] return res_list + elif equals_ignore_case(ds.type, 'hive'): + conn = hive.connect(host=conf.host, port=conf.port, username=conf.username, + database=conf.database, **extra_config_dict) + cursor = conn.cursor() + cursor.execute('SHOW DATABASES') + res = cursor.fetchall() + res_list = [item[0] for item in res] + cursor.close() + conn.close() + return res_list + elif equals_ignore_case(ds.type, 'doris', 'starrocks'): + with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, + port=conf.port, db=conf.database, connect_timeout=10, + read_timeout=10, **extra_config_dict) as conn, conn.cursor() as cursor: + cursor.execute('SHOW DATABASES') + res = cursor.fetchall() + res_list = [item[0] for item in res] + return res_list + elif equals_ignore_case(ds.type, 'ck'): + with get_session(ds) as session: + with session.execute(text('SHOW DATABASES')) as result: + res = result.fetchall() + res_list = [item[0] for item in res] + return res_list def get_tables(ds: CoreDatasource): @@ -374,7 +428,16 @@ def get_tables(ds: CoreDatasource): "excel") else get_engine_config() db = DB.get_db(ds.type) sql, sql_param = get_table_sql(ds, conf, get_version(ds)) - if db.connect_type == ConnectType.sqlalchemy: + if equals_ignore_case(ds.type, "sqlite"): + engine = get_engine(ds) + with engine.raw_connection() as conn: + cursor = conn.cursor() + cursor.execute(sql) + res = cursor.fetchall() + cursor.close() + res_list = [TableSchema(*item) for item in res] + return res_list + elif db.connect_type == ConnectType.sqlalchemy: with get_session(ds) as session: with session.execute(text(sql), {"param": sql_param}) as result: res = result.fetchall() @@ -390,9 +453,10 @@ def get_tables(ds: CoreDatasource): res_list = [TableSchema(*item) for item in res] return res_list elif equals_ignore_case(ds.type, 'doris', 'starrocks'): + ssl_args = {'ssl': {'ssl_mode': 'REQUIRE'}} if conf.ssl else {} with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, port=conf.port, db=conf.database, connect_timeout=conf.timeout, - read_timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor: + read_timeout=conf.timeout, **extra_config_dict, **ssl_args) as conn, conn.cursor() as cursor: cursor.execute(sql, (sql_param,)) res = cursor.fetchall() res_list = [TableSchema(*item) for item in res] @@ -418,6 +482,16 @@ def get_tables(ds: CoreDatasource): res = get_es_index(conf) res_list = [TableSchema(*item) for item in res] return res_list + elif equals_ignore_case(ds.type, 'hive'): + conn = hive.connect(host=conf.host, port=conf.port, username=conf.username, + database=conf.database, **extra_config_dict) + cursor = conn.cursor() + cursor.execute(sql) + res = cursor.fetchall() + res_list = [TableSchema(*item) for item in res] + cursor.close() + conn.close() + return res_list def get_fields(ds: CoreDatasource, table_name: str = None): @@ -425,7 +499,16 @@ def get_fields(ds: CoreDatasource, table_name: str = None): "excel") else get_engine_config() db = DB.get_db(ds.type) sql, p1, p2 = get_field_sql(ds, conf, table_name) - if db.connect_type == ConnectType.sqlalchemy: + if equals_ignore_case(ds.type, "sqlite"): + engine = get_engine(ds) + with engine.raw_connection() as conn: + cursor = conn.cursor() + cursor.execute(sql) + res = cursor.fetchall() + cursor.close() + res_list = [ColumnSchema(item[1], item[2], '') for item in res] + return res_list + elif db.connect_type == ConnectType.sqlalchemy: with get_session(ds) as session: with session.execute(text(sql), {"param1": p1, "param2": p2}) as result: res = result.fetchall() @@ -441,9 +524,10 @@ def get_fields(ds: CoreDatasource, table_name: str = None): res_list = [ColumnSchema(*item) for item in res] return res_list elif equals_ignore_case(ds.type, 'doris', 'starrocks'): + ssl_args = {'ssl': {'ssl_mode': 'REQUIRE'}} if conf.ssl else {} with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, port=conf.port, db=conf.database, connect_timeout=conf.timeout, - read_timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor: + read_timeout=conf.timeout, **extra_config_dict, **ssl_args) as conn, conn.cursor() as cursor: cursor.execute(sql, (p1, p2)) res = cursor.fetchall() res_list = [ColumnSchema(*item) for item in res] @@ -469,6 +553,16 @@ def get_fields(ds: CoreDatasource, table_name: str = None): res = get_es_fields(conf, table_name) res_list = [ColumnSchema(*item) for item in res] return res_list + elif equals_ignore_case(ds.type, 'hive'): + conn = hive.connect(host=conf.host, port=conf.port, username=conf.username, + database=conf.database, **extra_config_dict) + cursor = conn.cursor() + cursor.execute(sql) + res = cursor.fetchall() + res_list = [ColumnSchema(*item) for item in res] + cursor.close() + conn.close() + return res_list def convert_value(value, datetime_format='space'): @@ -587,9 +681,10 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column= except Exception as ex: raise ParseSQLResultError(str(ex)) elif equals_ignore_case(ds.type, 'doris', 'starrocks'): + ssl_args = {'ssl': {'ssl_mode': 'REQUIRE'}} if conf.ssl else {} with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, port=conf.port, db=conf.database, connect_timeout=conf.timeout, - read_timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor: + read_timeout=conf.timeout, **extra_config_dict, **ssl_args) as conn, conn.cursor() as cursor: try: cursor.execute(sql) res = cursor.fetchall() @@ -655,15 +750,54 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column= "sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))} except Exception as ex: raise Exception(str(ex)) + elif equals_ignore_case(ds.type, 'hive'): + conn = hive.connect(host=conf.host, port=conf.port, username=conf.username, + database=conf.database, **extra_config_dict) + cursor = conn.cursor() + try: + # Hive uses backticks for identifiers; normalize quoted identifiers as a compatibility fallback. + hive_sql = re.sub(r'"([A-Za-z_][A-Za-z0-9_]*)"', r'`\1`', sql) + cursor.execute(hive_sql) + res = cursor.fetchall() + columns = [field[0] for field in cursor.description] if origin_column else [field[0].lower() for + field in + cursor.description] + result_list = [ + {str(columns[i]): convert_value(value) for i, value in enumerate(tuple_item)} for tuple_item in + res + ] + return {"fields": columns, "data": result_list, + "sql": bytes.decode(base64.b64encode(bytes(hive_sql, 'utf-8')))} + except Exception as ex: + raise ParseSQLResultError(str(ex)) + finally: + cursor.close() + conn.close() def check_sql_read(sql: str, ds: CoreDatasource | AssistantOutDsSchema): try: + normalized_sql = sql.strip().lstrip("(").strip() + first_keyword = normalized_sql.split(None, 1)[0].upper() if normalized_sql else "" + allowed_read_commands = {"SELECT", "WITH", "SHOW", "DESCRIBE", "DESC", "EXPLAIN"} + denied_write_commands = { + "INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER", + "TRUNCATE", "MERGE", "COPY", "REPLACE", "GRANT", "REVOKE", + "USE", "SET", "CALL" + } + + if not first_keyword: + raise ValueError("Parse SQL Error") + if first_keyword in denied_write_commands: + return False + dialect = None if equals_ignore_case(ds.type, 'mysql', 'doris', 'starrocks'): dialect = 'mysql' elif equals_ignore_case(ds.type, 'sqlServer'): dialect = 'tsql' + elif equals_ignore_case(ds.type, 'hive'): + dialect = 'hive' statements = sqlglot.parse(sql, dialect=dialect) @@ -673,7 +807,7 @@ def check_sql_read(sql: str, ds: CoreDatasource | AssistantOutDsSchema): write_types = ( exp.Insert, exp.Update, exp.Delete, exp.Create, exp.Drop, exp.Alter, - exp.Merge, exp.Command, exp.Copy + exp.Merge, exp.Copy ) for stmt in statements: @@ -682,7 +816,7 @@ def check_sql_read(sql: str, ds: CoreDatasource | AssistantOutDsSchema): if isinstance(stmt, write_types): return False - return True + return first_keyword in allowed_read_commands except Exception as e: raise ValueError(f"Parse SQL Error: {e}") diff --git a/backend/apps/db/db_sql.py b/backend/apps/db/db_sql.py index 566fbeb03..d4c3c74f8 100644 --- a/backend/apps/db/db_sql.py +++ b/backend/apps/db/db_sql.py @@ -31,6 +31,8 @@ def get_version_sql(ds: CoreDatasource, conf: DatasourceConf): """ elif equals_ignore_case(ds.type, "redshift"): return '' + elif equals_ignore_case(ds.type, "sqlite"): + return '' def get_table_sql(ds: CoreDatasource, conf: DatasourceConf, db_version: str = ''): @@ -162,6 +164,17 @@ def get_table_sql(ds: CoreDatasource, conf: DatasourceConf, db_version: str = '' """, conf.dbSchema elif equals_ignore_case(ds.type, "es"): return "", None + elif equals_ignore_case(ds.type, "sqlite"): + return """ + SELECT name AS TABLE_NAME, '' + FROM sqlite_master + WHERE type='table' + ORDER BY name + """, None + elif equals_ignore_case(ds.type, "hive"): + return """ + SHOW TABLES + """, None def get_field_sql(ds: CoreDatasource, conf: DatasourceConf, table_name: str = None): @@ -312,3 +325,9 @@ def get_field_sql(ds: CoreDatasource, conf: DatasourceConf, table_name: str = No return sql1 + sql2, conf.dbSchema, table_name elif equals_ignore_case(ds.type, "es"): return "", None, None + elif equals_ignore_case(ds.type, "sqlite"): + sql1 = f"PRAGMA table_info({table_name})" + return sql1, None, None + elif equals_ignore_case(ds.type, "hive"): + sql1 = f"DESCRIBE {table_name}" + return sql1, None, None diff --git a/backend/common/utils/data_format.py b/backend/common/utils/data_format.py index 1991fb3e4..bfa9e88b5 100644 --- a/backend/common/utils/data_format.py +++ b/backend/common/utils/data_format.py @@ -17,6 +17,34 @@ def safe_convert_to_string(df): return df_copy + @staticmethod + def normalize_qualified_sql_column_keys(row: dict) -> dict: + """Add unqualified keys for names like ``alias.column`` (Hive/MySQL return shape). + + Chart bindings use the bare column name (``table_name``) while drivers may return + ``_u2.table_name``. Only adds ``short`` when absent to avoid clobbering real duplicates. + """ + if not row: + return row + out = dict(row) + for k, v in row.items(): + ks = str(k) + if "." not in ks: + continue + short = ks.rsplit(".", 1)[-1] + if short not in out: + out[short] = v + return out + + @staticmethod + def normalize_qualified_sql_column_keys_in_object_array(obj_array: list) -> list: + if not obj_array: + return obj_array + return [ + DataFormat.normalize_qualified_sql_column_keys(obj) if isinstance(obj, dict) else obj + for obj in obj_array + ] + @staticmethod def convert_large_numbers_in_object_array(obj_array, int_threshold=1e15, float_threshold=1e10): """处理对象数组,将每个对象中的大数字转换为字符串""" diff --git a/backend/pyproject.toml b/backend/pyproject.toml index f112f7eca..727efde3b 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -53,7 +53,9 @@ dependencies = [ "elasticsearch[requests] (>=7.10,<8.0)", "ldap3>=2.9.1", "sqlglot>=28.6.0", - "numpy==2.3.5" + "numpy==2.3.5", + "pyhive[hive]>=0.7.0", + "thrift-sasl" ] [project.optional-dependencies] diff --git a/backend/templates/sql_examples/Hive.yaml b/backend/templates/sql_examples/Hive.yaml new file mode 100644 index 000000000..813f6ab50 --- /dev/null +++ b/backend/templates/sql_examples/Hive.yaml @@ -0,0 +1,87 @@ +template: + quot_rule: | + + 必须对数据库名、表名、字段名、别名外层加反引号(`)。 + + 1. 点号(.)不能包含在引号内,必须写成 `database`.`table` + 2. 即使标识符不含特殊字符或非关键字,也需强制加反引号 + 3. 在多表关联(JOIN)/ 多子查询的 SQL 中,只要多个表 / 子查询存在 同名字段,所有引用该字段的位置,必须显式指定 表别名 + + + + limit_rule: | + + 当需要限制行数时,必须使用标准的LIMIT语法 + + + other_rule: | + 必须为每个表生成别名(不加AS) + {multi_table_condition} + 禁止使用星号(*),必须明确字段名 + 中文/特殊字符字段需保留原名并添加英文别名 + 不能用 + 拼接字符串,字符串必须使用单引号 + 分组非常严格:SELECT 里的字段必须出现在 GROUP BY 里,或者是聚合函数 + 函数字段必须加别名 + 百分比字段保留两位小数并以%结尾 + WHERE 条件中不能使用 >、<、>=、<= 等比较运算符,必须使用 = + HIVE 中没有 NOT IN 操作符,必须使用 LEFT JOIN 或 EXISTS 替代 + 判空使用 NVL()函数 + 避免与数据库关键字冲突 + + basic_example: | + + + 📌 以下示例严格遵循中的 Hive 规范,展示符合要求的 SQL 写法与典型错误案例。 + ⚠️ 注意:示例中的表名、字段名均为演示虚构,实际使用时需替换为用户提供的真实标识符。 + 🔍 重点观察: + 1. 反引号包裹所有数据库对象的规范用法 + 2. 中英别名/百分比/函数等特殊字段的处理 + 3. 关键字冲突的规避方式 + + + 查询 ods.orders 表的前100条订单(含中文字段和百分比) + + SELECT * FROM ods.orders LIMIT 100 -- 错误:未加引号、使用星号 + SELECT `订单ID`, `金额` FROM `ods`.`orders` `t1` LIMIT 100 -- 错误:缺少英文别名 + SELECT COUNT(`订单ID`) FROM `ods`.`orders` `t1` -- 错误:函数未加别名 + + + SELECT + `t1`.`订单ID` AS `order_id`, + `t1`.`金额` AS `amount`, + COUNT(`t1`.`订单ID`) AS `total_orders`, + CONCAT(CAST(ROUND(`t1`.`折扣率` * 100, 2) AS STRING), '%') AS `discount_percent` + FROM `ods`.`orders` `t1` + LIMIT 100 + + + + + 统计 dim.users(含关键字字段user)的活跃占比 + + SELECT user, status FROM dim.users -- 错误:未处理关键字和引号 + SELECT `user`, ROUND(active_ratio) FROM `dim`.`users` -- 错误:百分比格式错误 + + + SELECT + `u`.`user` AS `username`, + CONCAT(CAST(ROUND(`u`.`active_ratio` * 100, 2) AS STRING), '%') AS `active_percent` + FROM `dim`.`users` `u` + WHERE `u`.`status` = 1 + + + + + example_engine: Apache Hive 2.X + example_answer_1: | + {"success":true,"sql":"SELECT `country` AS `country_name`, `continent` AS `continent_name`, `year` AS `year`, `gdp` AS `gdp` FROM `Sample_Database`.`sample_country_gdp` ORDER BY `country`, `year`","tables":["sample_country_gdp"],"chart-type":"line"} + example_answer_1_with_limit: | + {"success":true,"sql":"SELECT `country` AS `country_name`, `continent` AS `continent_name`, `year` AS `year`, `gdp` AS `gdp` FROM `Sample_Database`.`sample_country_gdp` ORDER BY `country`, `year` LIMIT 1000","tables":["sample_country_gdp"],"chart-type":"line"} + example_answer_2: | + {"success":true,"sql":"SELECT `country` AS `country_name`, `gdp` AS `gdp` FROM `Sample_Database`.`sample_country_gdp` WHERE `year` = '2024' ORDER BY `gdp` DESC","tables":["sample_country_gdp"],"chart-type":"pie"} + example_answer_2_with_limit: | + {"success":true,"sql":"SELECT `country` AS `country_name`, `gdp` AS `gdp` FROM `Sample_Database`.`sample_country_gdp` WHERE `year` = '2024' ORDER BY `gdp` DESC LIMIT 1000","tables":["sample_country_gdp"],"chart-type":"pie"} + example_answer_3: | + {"success":true,"sql":"SELECT `country` AS `country_name`, `gdp` AS `gdp` FROM `Sample_Database`.`sample_country_gdp` WHERE `year` = '2025' AND `country` = '中国'","tables":["sample_country_gdp"],"chart-type":"table"} + example_answer_3_with_limit: | + {"success":true,"sql":"SELECT `country` AS `country_name`, `gdp` AS `gdp` FROM `Sample_Database`.`sample_country_gdp` WHERE `year` = '2025' AND `country` = '中国' LIMIT 1000","tables":["sample_country_gdp"],"chart-type":"table"} diff --git a/backend/templates/sql_examples/SQLite.yaml b/backend/templates/sql_examples/SQLite.yaml new file mode 100644 index 000000000..bfcaacc94 --- /dev/null +++ b/backend/templates/sql_examples/SQLite.yaml @@ -0,0 +1,81 @@ +template: + quot_rule: | + + 必须对数据库名、表名、字段名、别名外层加双引号(")。 + + 1. 点号(.)不能包含在引号内,必须写成 "table" + 2. 即使标识符不含特殊字符或非关键字,也需强制加双引号 + + + + limit_rule: | + + 当需要限制行数时,必须使用标准的LIMIT语法 + + + other_rule: | + 必须为每个表生成别名(不加AS) + {multi_table_condition} + 禁止使用星号(*),必须明确字段名 + 中文/特殊字符字段需保留原名并添加英文别名 + 函数字段必须加别名 + 百分比字段保留两位小数并以%结尾 + 避免与数据库关键字冲突 + + basic_example: | + + + 📌 以下示例严格遵循中的 SQLite 规范,展示符合要求的 SQL 写法与典型错误案例。 + ⚠️ 注意:示例中的表名、字段名均为演示虚构,实际使用时需替换为用户提供的真实标识符。 + 🔍 重点观察: + 1. 双引号包裹所有数据库对象的规范用法 + 2. 中英别名/百分比/函数等特殊字段的处理 + 3. 关键字冲突的规避方式 + + + 查询 ORDERS 表的前100条订单(含中文字段和百分比) + + SELECT * FROM ORDERS LIMIT 100 -- 错误:未加引号、使用星号 + SELECT "订单ID", "金额" FROM "ORDERS" "t1" LIMIT 100 -- 错误:缺少英文别名 + SELECT COUNT("订单ID") FROM "ORDERS" "t1" -- 错误:函数未加别名 + + + SELECT + "t1"."订单ID" AS "order_id", + "t1"."金额" AS "amount", + COUNT("t1"."订单ID") AS "total_orders", + ROUND("t1"."折扣率" * 100, 2) || '%' AS "discount_percent" + FROM "ORDERS" "t1" + LIMIT 100 + + + + + 统计用户表 USERS(含关键字字段user)的活跃占比 + + SELECT user, status FROM USERS -- 错误:未处理关键字和引号 + SELECT "user", ROUND(active_ratio) FROM "USERS" -- 错误:百分比格式错误 + + + SELECT + "u"."user" AS "username", + ROUND("u"."active_ratio" * 100, 2) || '%' AS "active_percent" + FROM "USERS" "u" + WHERE "u"."status" = 1 + + + + + example_engine: SQLite 3.x + example_answer_1: | + {"success":true,"sql":"SELECT \"country_name\", \"continent_name\", \"year\", \"gdp\" FROM \"sample_country_gdp\" ORDER BY \"country_name\", \"year\"","tables":["sample_country_gdp"],"chart-type":"line"} + example_answer_1_with_limit: | + {"success":true,"sql":"SELECT \"country_name\", \"continent_name\", \"year\", \"gdp\" FROM \"sample_country_gdp\" ORDER BY \"country_name\", \"year\" LIMIT 1000","tables":["sample_country_gdp"],"chart-type":"line"} + example_answer_2: | + {"success":true,"sql":"SELECT \"country_name\", \"gdp\" FROM \"sample_country_gdp\" WHERE \"year\" = '2024' ORDER BY \"gdp\" DESC","tables":["sample_country_gdp"],"chart-type":"pie"} + example_answer_2_with_limit: | + {"success":true,"sql":"SELECT \"country_name\", \"gdp\" FROM \"sample_country_gdp\" WHERE \"year\" = '2024' ORDER BY \"gdp\" DESC LIMIT 1000","tables":["sample_country_gdp"],"chart-type":"pie"} + example_answer_3: | + {"success":true,"sql":"SELECT \"country_name\", \"gdp\" FROM \"sample_country_gdp\" WHERE \"year\" = '2025' AND \"country_name\" = '中国'","tables":["sample_country_gdp"],"chart-type":"table"} + example_answer_3_with_limit: | + {"success":true,"sql":"SELECT \"country_name\", \"gdp\" FROM \"sample_country_gdp\" WHERE \"year\" = '2025' AND \"country_name\" = '中国' LIMIT 1000","tables":["sample_country_gdp"],"chart-type":"table"} diff --git a/backend/templates/template.yaml b/backend/templates/template.yaml index a447c8287..0179c4605 100644 --- a/backend/templates/template.yaml +++ b/backend/templates/template.yaml @@ -348,6 +348,9 @@ template: {schema} + + {sample_data} + user: | diff --git a/frontend/src/assets/datasource/icon_hive.png b/frontend/src/assets/datasource/icon_hive.png new file mode 100644 index 000000000..51b996f5e Binary files /dev/null and b/frontend/src/assets/datasource/icon_hive.png differ diff --git a/frontend/src/assets/datasource/icon_sqlite.png b/frontend/src/assets/datasource/icon_sqlite.png new file mode 100644 index 000000000..1a198d0b4 Binary files /dev/null and b/frontend/src/assets/datasource/icon_sqlite.png differ diff --git a/frontend/src/i18n/en.json b/frontend/src/i18n/en.json index df8d29c06..4670a92eb 100644 --- a/frontend/src/i18n/en.json +++ b/frontend/src/i18n/en.json @@ -395,7 +395,8 @@ "timeout": "Connection Timeout(second)", "address": "Address", "low_version": "Compatible with lower versions", - "ssl": "Enable SSL" + "ssl": "Enable SSL", + "file_path": "File Path" }, "sync_fields": "Sync Fields" }, diff --git a/frontend/src/i18n/ko-KR.json b/frontend/src/i18n/ko-KR.json index bb269ca1f..133c2813e 100644 --- a/frontend/src/i18n/ko-KR.json +++ b/frontend/src/i18n/ko-KR.json @@ -395,7 +395,8 @@ "timeout": "연결 시간 초과(초)", "address": "주소", "low_version": "낮은 버전 호환", - "ssl": "SSL 활성화" + "ssl": "SSL 활성화", + "file_path": "파일 경로" }, "sync_fields": "동기화된 테이블 구조" }, diff --git a/frontend/src/i18n/zh-CN.json b/frontend/src/i18n/zh-CN.json index 902b06421..8040db821 100644 --- a/frontend/src/i18n/zh-CN.json +++ b/frontend/src/i18n/zh-CN.json @@ -395,7 +395,8 @@ "timeout": "连接超时(秒)", "address": "地址", "low_version": "兼容低版本", - "ssl": "启用 SSL" + "ssl": "启用 SSL", + "file_path": "文件路径" }, "sync_fields": "同步表结构" }, diff --git a/frontend/src/i18n/zh-TW.json b/frontend/src/i18n/zh-TW.json index bd2aa57cb..d8973c4b8 100644 --- a/frontend/src/i18n/zh-TW.json +++ b/frontend/src/i18n/zh-TW.json @@ -395,7 +395,8 @@ "timeout": "連線逾時(秒)", "address": "位址", "low_version": "相容低版本", - "ssl": "啟用 SSL" + "ssl": "啟用 SSL", + "file_path": "文件路徑" }, "sync_fields": "同步表結構" }, diff --git a/frontend/src/views/ds/DatasourceForm.vue b/frontend/src/views/ds/DatasourceForm.vue index 240692d13..892824006 100644 --- a/frontend/src/views/ds/DatasourceForm.vue +++ b/frontend/src/views/ds/DatasourceForm.vue @@ -99,6 +99,13 @@ const rules = reactive({ trigger: 'blur', }, ], + filename: [ + { + required: true, + message: t('datasource.please_enter') + t('common.empty') + t('ds.form.file_path'), + trigger: 'blur', + }, + ], }) const dialogVisible = ref(false) @@ -647,7 +654,16 @@ defineExpose({ {{ $t('common.not_exceed_50mb') }} -
+
+ + + +
+
- + diff --git a/frontend/src/views/ds/js/ds-type.ts b/frontend/src/views/ds/js/ds-type.ts index 800fdefd4..b2d578b74 100644 --- a/frontend/src/views/ds/js/ds-type.ts +++ b/frontend/src/views/ds/js/ds-type.ts @@ -10,6 +10,8 @@ import redshift from '@/assets/datasource/icon_redshift.png' import es from '@/assets/datasource/icon_es.png' import kingbase from '@/assets/datasource/icon_kingbase.png' import starrocks from '@/assets/datasource/icon_starrocks.png' +import sqlite_icon from '@/assets/datasource/icon_sqlite.png' +import hive_icon from '@/assets/datasource/icon_hive.png' import { i18n } from '@/i18n' const t = i18n.global.t @@ -26,6 +28,8 @@ export const dsType = [ { label: 'Elasticsearch', value: 'es' }, { label: 'Kingbase', value: 'kingbase' }, { label: 'StarRocks', value: 'starrocks' }, + { label: 'SQLite', value: 'sqlite' }, + { label: 'Apache Hive', value: 'hive' }, ] export const dsTypeWithImg = [ @@ -41,6 +45,8 @@ export const dsTypeWithImg = [ { name: 'Elasticsearch', type: 'es', img: es }, { name: 'Kingbase', type: 'kingbase', img: kingbase }, { name: 'StarRocks', type: 'starrocks', img: starrocks }, + { name: 'SQLite', type: 'sqlite', img: sqlite_icon }, + { name: 'Apache Hive', type: 'hive', img: hive_icon }, ] export const haveSchema = ['sqlServer', 'pg', 'oracle', 'dm', 'redshift', 'kingbase']