Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backend/apps/chat/curd/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def format_json_list_data(origin_data: list[dict]):
if len(decimal_str) > 15:
value = str(value)
_row[key] = value
data.append(_row)
data.append(DataFormat.normalize_qualified_sql_column_keys(_row))

return data

Expand Down Expand Up @@ -253,6 +253,7 @@ def get_chart_data_ds(session: SessionDep,ds_id,sql):
else:
result = exec_sql(ds=datasource,sql=sql, origin_column=False)
_data = DataFormat.convert_large_numbers_in_object_array(result.get('data'))
_data = DataFormat.normalize_qualified_sql_column_keys_in_object_array(_data)
json_result['data'] = _data
return json_result
except Exception as e:
Expand Down
3 changes: 2 additions & 1 deletion backend/apps/chat/models/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ class AiModelQuestion(BaseModel):
custom_prompt: str = ""
error_msg: str = ""
regenerate_record_id: Optional[int] = None
sample_data: str = ""

def sql_sys_question(self, db_type: Union[str, DB], enable_query_limit: bool = True):
templates: dict[str, str] = {}
Expand Down Expand Up @@ -256,7 +257,7 @@ def sql_sys_question(self, db_type: Union[str, DB], enable_query_limit: bool = T
example_answer_1=_example_answer_1,
example_answer_2=_example_answer_2,
example_answer_3=_example_answer_3)
templates['schema'] = _base_template['generate_basic_info'].format(engine=self.engine, schema=self.db_schema)
templates['schema'] = _base_template['generate_basic_info'].format(engine=self.engine, schema=self.db_schema, sample_data=self.sample_data)

if self.terminologies:
templates['terminologies'] = _base_template['generate_terminologies_info'].format(
Expand Down
17 changes: 16 additions & 1 deletion backend/apps/chat/task/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat, ChatLog, OperationEnum, \
ChatFinishStep, AxisObj, SystemPromptMessage, HumanPromptMessage, AIPromptMessage
from apps.data_training.curd.data_training import get_training_template
from apps.datasource.crud.datasource import get_table_schema
from apps.datasource.crud.datasource import get_table_schema, get_tables_sample_data
from apps.datasource.crud.permission import get_row_permission_filters, is_normal_user
from apps.datasource.embedding.ds_embedding import get_ds_embedding
from apps.datasource.models.datasource import CoreDatasource
Expand Down Expand Up @@ -384,6 +384,13 @@ def choose_table_schema(self, _session: Session):
ds=self.ds,
question=self.chat_question.question)

# Get sample data for all tables
if not self.out_ds_instance:
self.chat_question.sample_data = get_tables_sample_data(
session=_session,
current_user=self.current_user,
ds=self.ds)

self.current_logs[OperationEnum.CHOOSE_TABLE] = end_log(session=_session,
log=self.current_logs[OperationEnum.CHOOSE_TABLE],
full_message=self.chat_question.db_schema)
Expand Down Expand Up @@ -505,6 +512,13 @@ def generate_recommend_questions_task(self, _session: Session):
question=self.chat_question.question,
embedding=False)

# Get sample data for all tables
if not self.out_ds_instance:
self.chat_question.sample_data = get_tables_sample_data(
session=_session,
current_user=self.current_user,
ds=self.ds)

guess_msg: List[Union[BaseMessage, dict[str, Any]]] = []
guess_msg.append(SystemPromptMessage(content=self.chat_question.guess_sys_question(self.articles_number)))

Expand Down Expand Up @@ -1304,6 +1318,7 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
'count': len(result.get('data'))})

_data = DataFormat.convert_large_numbers_in_object_array(result.get('data'))
_data = DataFormat.normalize_qualified_sql_column_keys_in_object_array(_data)
result["data"] = _data

self.save_sql_data(session=_session, data_obj=result)
Expand Down
90 changes: 84 additions & 6 deletions backend/apps/datasource/crud/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from common.core.config import settings
from common.core.deps import SessionDep, CurrentUser, Trans
from common.utils.embedding_threads import run_save_table_embeddings, run_save_ds_embeddings
from common.utils.utils import SQLBotLogUtil, deepcopy_ignore_extra
from common.utils.utils import SQLBotLogUtil, deepcopy_ignore_extra, equals_ignore_case
from common.core.sqlbot_cache import cache, clear_cache
from .table import get_tables_by_ds_id
from ..crud.field import delete_field_by_ds_id, update_field
Expand Down Expand Up @@ -357,12 +357,16 @@ def preview(session: SessionDep, current_user: CurrentUser, id: int, data: Table
{where}
LIMIT 100"""
elif ds.type == "dm":
sql = f"""SELECT "{'", "'.join(fields)}" FROM "{conf.dbSchema}"."{data.table.table_name}"
{where}
sql = f"""SELECT "{'", "'.join(fields)}" FROM "{conf.dbSchema}"."{data.table.table_name}"
{where}
LIMIT 100"""
elif ds.type == "es":
sql = f"""SELECT "{'", "'.join(fields)}" FROM "{data.table.table_name}"
{where}
sql = f"""SELECT "{'", "'.join(fields)}" FROM "{data.table.table_name}"
{where}
LIMIT 100"""
elif ds.type == "sqlite":
sql = f"""SELECT "{'", "'.join(fields)}" FROM "{data.table.table_name}"
{where}
LIMIT 100"""
return exec_sql(ds, sql, True)

Expand Down Expand Up @@ -430,6 +434,79 @@ def get_table_obj_by_ds(session: SessionDep, current_user: CurrentUser, ds: Core
return _list


def _compact_sample_value(value):
    """Shorten a single cell value so sample rows stay readable in a prompt."""
    if not isinstance(value, str):
        # None, numbers and other non-string values are passed through untouched.
        return value
    if len(value) > 100:
        value = value[:100] + '...'
    # Flatten line breaks so each value stays on one line.
    return value.replace('\n', ' ').replace('\r', ' ')


def get_table_sample_data(ds: CoreDatasource, table_name: str, fields: list) -> str:
    """Get up to 3 sample rows from a table as a JSON string to help the AI understand the data.

    Only the first 10 entries of ``fields`` are selected. String values longer
    than 100 characters are truncated and line breaks are replaced by spaces.

    This is a best-effort helper: if the query or the serialization fails, the
    failure is logged and an empty string is returned instead of raising.

    :param ds: datasource the table belongs to; ``ds.type`` selects the SQL dialect.
    :param table_name: name of the table to sample.
    :param fields: field objects exposing a ``field_name`` attribute.
    :return: pretty-printed JSON array of sample rows, or ``""`` when there are
        no fields, no rows, or the sampling failed.
    """
    # Local imports: the file-level import block is not changed by this function.
    import json
    import logging

    if not fields:
        return ""

    db = DB.get_db(ds.type)
    # Identifier quoting characters; fall back to double quotes when the DB entry has none.
    prefix = getattr(db, 'prefix', '"')
    suffix = getattr(db, 'suffix', '"')

    # Limit to the first 10 fields to avoid overly wide results.
    columns = ','.join(f"{prefix}{field.field_name}{suffix}" for field in fields[:10])

    # Row limiting syntax differs per database type.
    if equals_ignore_case(ds.type, "sqlServer"):
        query = f"SELECT TOP 3 {columns} FROM {prefix}{table_name}{suffix}"
    elif equals_ignore_case(ds.type, "ck") or equals_ignore_case(ds.type, "hive"):
        # Table name is intentionally left unquoted for these types (as before).
        query = f"SELECT {columns} FROM {table_name} LIMIT 3"
    elif equals_ignore_case(ds.type, "oracle") or equals_ignore_case(ds.type, "dm"):
        query = f"SELECT {columns} FROM \"{table_name}\" WHERE ROWNUM <= 3"
    else:
        query = f"SELECT {columns} FROM {prefix}{table_name}{suffix} LIMIT 3"

    try:
        result = exec_sql(ds=ds, sql=query, origin_column=True)
        rows = result.get('data') if result else None
        if not rows:
            return ""

        sample_rows = [
            {key: _compact_sample_value(value) for key, value in row.items()}
            for row in rows[:3]
        ]
        # default=str is deliberate: values that JSON cannot encode natively
        # (dates, decimals, bytes, ...) are stringified rather than causing
        # the whole sample to be discarded.
        return json.dumps(sample_rows, ensure_ascii=False, indent=2, default=str)
    except Exception:
        # Sample data is optional context; never let it break the caller,
        # but leave a trace so failures can be diagnosed.
        logging.getLogger(__name__).warning(
            "Failed to get sample data for table %s", table_name, exc_info=True)
        return ""


def get_tables_sample_data(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource,
                           table_list: list[str] = None) -> str:
    """Collect sample rows for the datasource's tables as one text block for the AI prompt.

    Tables are obtained via ``get_table_obj_by_ds``. When ``table_list`` is
    given, only tables whose name appears in it are sampled. Tables without
    fields, or for which no sample could be produced, are left out.

    :return: one ``# Table: <name>`` header followed by the JSON sample per
        table, joined by newlines; ``""`` when nothing was collected.
    """
    tables = get_table_obj_by_ds(session=session, current_user=current_user, ds=ds)
    if len(tables) == 0:
        return ""

    sections = []
    for entry in tables:
        filtered_out = table_list is not None and entry.table.table_name not in table_list
        if filtered_out or not entry.fields:
            continue

        rows_json = get_table_sample_data(ds, entry.table.table_name, entry.fields)
        if not rows_json:
            continue

        sections.append(f"# Table: {entry.table.table_name}\n{rows_json}")

    return "\n".join(sections)


def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource, question: str,
embedding: bool = True, table_list: list[str] = None) -> str:
schema_str = ""
Expand All @@ -446,7 +523,8 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
continue

schema_table = ''
schema_table += f"# Table: {db_name}.{obj.table.table_name}" if ds.type != "mysql" and ds.type != "es" else f"# Table: {obj.table.table_name}"
no_schema_types = ["mysql", "es", "sqlite", "hive", "doris", "starrocks"]
schema_table += f"# Table: {db_name}.{obj.table.table_name}" if ds.type not in no_schema_types and db_name else f"# Table: {obj.table.table_name}"
table_comment = ''
if obj.table.custom_comment:
table_comment = obj.table.custom_comment.strip()
Expand Down
2 changes: 1 addition & 1 deletion backend/apps/datasource/models/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def to_dict(self):


class TableSchema:
def __init__(self, attr1, attr2):
def __init__(self, attr1, attr2=None):
self.tableName = attr1
self.tableComment = attr2 if attr2 is None or isinstance(attr2, str) else attr2.decode("utf-8")

Expand Down
2 changes: 2 additions & 0 deletions backend/apps/db/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class DB(Enum):
oracle = ('oracle', 'Oracle', '"', '"', ConnectType.sqlalchemy, 'Oracle', [])
pg = ('pg', 'PostgreSQL', '"', '"', ConnectType.sqlalchemy, 'PostgreSQL', [])
starrocks = ('starrocks', 'StarRocks', '`', '`', ConnectType.py_driver, 'StarRocks', [])
sqlite = ('sqlite', 'SQLite', '"', '"', ConnectType.sqlalchemy, 'SQLite', [])
hive = ('hive', 'Apache Hive', '`', '`', ConnectType.py_driver, 'Hive', [])

def __init__(self, type, db_name, prefix, suffix, connect_type: ConnectType, template_name: str,
illegalParams: List[str]):
Expand Down
Loading