# pgvector.py

import ast
import json
import logging
import uuid

import pandas as pd
from langchain_core.documents import Document
from langchain_postgres.vectorstores import PGVector
from sqlalchemy import create_engine, text

from core.logging import get_vanna_logger
from vanna.exceptions import ValidationError
from vanna.base import VannaBase
from vanna.types import TrainingPlan, TrainingPlanItem

# Import the embedding cache manager
from common.embedding_cache_manager import get_embedding_cache_manager


class PG_VectorStore(VannaBase):
    def __init__(self, config=None):
        if not config or "connection_string" not in config:
            raise ValueError(
                "A valid 'config' dictionary with a 'connection_string' is required."
            )

        VannaBase.__init__(self, config=config)

        # Initialize the logger
        self.logger = get_vanna_logger("PGVector")

        # The guard above ensures "connection_string" is present.
        self.connection_string = config.get("connection_string")
        self.n_results = config.get("n_results", 10)

        if "embedding_function" in config:
            self.embedding_function = config.get("embedding_function")
        else:
            raise ValueError("No embedding_function was found.")
            # from langchain_huggingface import HuggingFaceEmbeddings
            # self.embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

        self.sql_collection = PGVector(
            embeddings=self.embedding_function,
            collection_name="sql",
            connection=self.connection_string,
        )
        self.ddl_collection = PGVector(
            embeddings=self.embedding_function,
            collection_name="ddl",
            connection=self.connection_string,
        )
        self.documentation_collection = PGVector(
            embeddings=self.embedding_function,
            collection_name="documentation",
            connection=self.connection_string,
        )
        self.error_sql_collection = PGVector(
            embeddings=self.embedding_function,
            collection_name="error_sql",
            connection=self.connection_string,
        )
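
    # A minimal construction sketch (hypothetical DSN and embeddings object,
    # not values from this repo; `my_embeddings` must expose the LangChain
    # embeddings interface that PGVector expects):
    #
    #   store = PG_VectorStore(config={
    #       "connection_string": "postgresql+psycopg://user:pass@localhost:5432/vanna",
    #       "embedding_function": my_embeddings,
    #       "n_results": 10,
    #   })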

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        question_sql_json = json.dumps(
            {
                "question": question,
                "sql": sql,
            },
            ensure_ascii=False,
        )
        id = str(uuid.uuid4()) + "-sql"
        createdat = kwargs.get("createdat")
        doc = Document(
            page_content=question_sql_json,
            metadata={"id": id, "createdat": createdat},
        )
        self.sql_collection.add_documents([doc], ids=[doc.metadata["id"]])
        return id

    def add_ddl(self, ddl: str, **kwargs) -> str:
        _id = str(uuid.uuid4()) + "-ddl"
        doc = Document(
            page_content=ddl,
            metadata={"id": _id},
        )
        self.ddl_collection.add_documents([doc], ids=[doc.metadata["id"]])
        return _id

    def add_documentation(self, documentation: str, **kwargs) -> str:
        _id = str(uuid.uuid4()) + "-doc"
        doc = Document(
            page_content=documentation,
            metadata={"id": _id},
        )
        self.documentation_collection.add_documents([doc], ids=[doc.metadata["id"]])
        return _id

    def get_collection(self, collection_name):
        match collection_name:
            case "sql":
                return self.sql_collection
            case "ddl":
                return self.ddl_collection
            case "documentation":
                return self.documentation_collection
            case "error_sql":
                return self.error_sql_collection
            case _:
                raise ValueError("Specified collection does not exist.")

    # def get_similar_question_sql(self, question: str) -> list:
    #     documents = self.sql_collection.similarity_search(query=question, k=self.n_results)
    #     return [ast.literal_eval(document.page_content) for document in documents]

    # Same as the original implementation above, but with a similarity value added to each result.
    def get_similar_question_sql(self, question: str) -> list:
        # Try the embedding cache first
        embedding_cache = get_embedding_cache_manager()
        cached_embedding = embedding_cache.get_cached_embedding(question)

        if cached_embedding is not None:
            # Vector search using the cached embedding
            docs_with_scores = self.sql_collection.similarity_search_with_score_by_vector(
                embedding=cached_embedding,
                k=self.n_results
            )
        else:
            # Regular vector search (generates the embedding automatically)
            docs_with_scores = self.sql_collection.similarity_search_with_score(
                query=question,
                k=self.n_results
            )
            # Generate the embedding again and cache it
            try:
                # Generate the vector via embedding_function, then cache it
                generated_embedding = self.embedding_function.generate_embedding(question)
                if generated_embedding:
                    embedding_cache.cache_embedding(question, generated_embedding)
            except Exception as e:
                self.logger.warning(f"Failed to cache embedding: {e}")

        results = []
        for doc, score in docs_with_scores:
            # Parse the document content into a dict
            base = ast.literal_eval(doc.page_content)
            # Convert distance to similarity
            similarity = round(1 - score, 4)
            # Log each match individually
            self.logger.debug(f"SQL Match: {base.get('question', '')} | similarity: {similarity}")
            # Attach the similarity field
            base["similarity"] = similarity
            results.append(base)

        # Warn if the raw query returned nothing
        if not results:
            self.logger.warning(f"Vector search found no similar SQL question-SQL pairs; question: {question}")

        # Apply threshold filtering
        filtered_results = self._apply_score_threshold_filter(
            results,
            "RESULT_VECTOR_SQL_SCORE_THRESHOLD",
            "SQL"
        )

        # Warn if filtering removed everything
        if results and not filtered_results:
            self.logger.warning(f"Vector search found {len(results)} SQL question-SQL pairs, but all were filtered out by the threshold.")
            # self.logger.warning(f"Question: {question}")

        return filtered_results
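
    # Illustrative shape of the return value (hypothetical entries):
    #
    #   [{"question": "How many active users?", "sql": "SELECT ...", "similarity": 0.8123}, ...]
    #
    # Note that similarity = 1 - score assumes `score` is a cosine distance
    # (PGVector's default strategy); with another distance strategy the value
    # would not be a cosine similarity.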

    def get_related_ddl(self, question: str, **kwargs) -> list:
        # Try the embedding cache first
        embedding_cache = get_embedding_cache_manager()
        cached_embedding = embedding_cache.get_cached_embedding(question)

        if cached_embedding is not None:
            # Vector search using the cached embedding
            docs_with_scores = self.ddl_collection.similarity_search_with_score_by_vector(
                embedding=cached_embedding,
                k=self.n_results
            )
        else:
            # Regular vector search (generates the embedding automatically)
            docs_with_scores = self.ddl_collection.similarity_search_with_score(
                query=question,
                k=self.n_results
            )
            # Generate the embedding again and cache it
            try:
                # Generate the vector via embedding_function, then cache it
                generated_embedding = self.embedding_function.generate_embedding(question)
                if generated_embedding:
                    embedding_cache.cache_embedding(question, generated_embedding)
            except Exception as e:
                self.logger.warning(f"Failed to cache embedding: {e}")

        results = []
        for doc, score in docs_with_scores:
            # Convert distance to similarity
            similarity = round(1 - score, 4)
            # Log each match individually
            self.logger.debug(f"DDL Match: {doc.page_content[:50]}... | similarity: {similarity}")
            # Attach the similarity field
            result = {
                "content": doc.page_content,
                "similarity": similarity
            }
            results.append(result)

        # Warn if the raw query returned nothing
        if not results:
            self.logger.warning(f"Vector search found no related DDL schemas; question: {question}")

        # Apply threshold filtering
        filtered_results = self._apply_score_threshold_filter(
            results,
            "RESULT_VECTOR_DDL_SCORE_THRESHOLD",
            "DDL"
        )

        # Warn if filtering removed everything
        if results and not filtered_results:
            self.logger.warning(f"Vector search found {len(results)} DDL schemas, but all were filtered out by the threshold; question: {question}")

        return filtered_results

    def get_related_documentation(self, question: str, **kwargs) -> list:
        # Try the embedding cache first
        embedding_cache = get_embedding_cache_manager()
        cached_embedding = embedding_cache.get_cached_embedding(question)

        if cached_embedding is not None:
            # Vector search using the cached embedding
            docs_with_scores = self.documentation_collection.similarity_search_with_score_by_vector(
                embedding=cached_embedding,
                k=self.n_results
            )
        else:
            # Regular vector search (generates the embedding automatically)
            docs_with_scores = self.documentation_collection.similarity_search_with_score(
                query=question,
                k=self.n_results
            )
            # Generate the embedding again and cache it
            try:
                # Generate the vector via embedding_function, then cache it
                generated_embedding = self.embedding_function.generate_embedding(question)
                if generated_embedding:
                    embedding_cache.cache_embedding(question, generated_embedding)
            except Exception as e:
                self.logger.warning(f"Failed to cache embedding: {e}")

        results = []
        for doc, score in docs_with_scores:
            # Convert distance to similarity
            similarity = round(1 - score, 4)
            # Log each match individually
            self.logger.debug(f"Doc Match: {doc.page_content[:50]}... | similarity: {similarity}")
            # Attach the similarity field
            result = {
                "content": doc.page_content,
                "similarity": similarity
            }
            results.append(result)

        # Warn if the raw query returned nothing
        if not results:
            self.logger.warning(f"Vector search found no related documentation; question: {question}")

        # Apply threshold filtering
        filtered_results = self._apply_score_threshold_filter(
            results,
            "RESULT_VECTOR_DOC_SCORE_THRESHOLD",
            "DOC"
        )

        # Warn if filtering removed everything
        if results and not filtered_results:
            self.logger.warning(f"Vector search found {len(results)} documentation entries, but all were filtered out by the threshold; question: {question}")

        return filtered_results

    def _apply_score_threshold_filter(self, results: list, threshold_config_key: str, result_type: str) -> list:
        """
        Apply similarity-threshold filtering.

        Args:
            results: Raw result list; each element carries a `similarity` field.
            threshold_config_key: Name of the threshold parameter in app_config.
            result_type: Result-type label (used in log messages).

        Returns:
            The filtered result list.
        """
        if not results:
            return results

        # Load configuration
        try:
            import app_config
            enable_threshold = getattr(app_config, 'ENABLE_RESULT_VECTOR_SCORE_THRESHOLD', False)
            threshold = getattr(app_config, threshold_config_key, 0.65)
        except (ImportError, AttributeError) as e:
            self.logger.warning(f"Could not load threshold config: {e}; using defaults")
            enable_threshold = False
            threshold = 0.65

        # If threshold filtering is disabled, return the results unchanged
        if not enable_threshold:
            self.logger.debug(f"{result_type} threshold filtering disabled; returning all {len(results)} results")
            return results

        total_count = len(results)
        min_required = max((total_count + 1) // 2, 1)

        self.logger.debug(f"{result_type} threshold filtering: total={total_count}, threshold={threshold}, min_kept={min_required}")

        # Sort by similarity, descending (most similar first)
        sorted_results = sorted(results, key=lambda x: x.get('similarity', 0), reverse=True)

        # Collect the results that meet the threshold
        above_threshold = [r for r in sorted_results if r.get('similarity', 0) >= threshold]

        # Apply the filtering logic
        if len(above_threshold) >= min_required:
            # Case 1: enough results meet the threshold; return only those
            filtered_results = above_threshold
            filtered_count = len(above_threshold)
            self.logger.debug(f"{result_type} filter result: kept {filtered_count}, dropped {total_count - filtered_count} (all meet the threshold)")
        else:
            # Case 2: too few meet the threshold; force-keep the top min_required
            filtered_results = sorted_results[:min_required]
            above_count = len(above_threshold)
            below_count = min_required - above_count
            filtered_count = min_required
            self.logger.debug(f"{result_type} filter result: kept {filtered_count}, dropped {total_count - filtered_count} (meet threshold: {above_count}, force-kept: {below_count})")

        # Log details of what was kept
        for i, result in enumerate(filtered_results):
            similarity = result.get('similarity', 0)
            status = "✓" if similarity >= threshold else "✗"
            self.logger.debug(f"{result_type} kept {i+1}: similarity={similarity} {status}")

        return filtered_results
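
    # Worked example of the rule above (hypothetical numbers): with 7 results
    # and threshold 0.65, min_required = max((7 + 1) // 2, 1) = 4. If five
    # results score >= 0.65, exactly those five are returned; if only two do,
    # the top four by similarity are kept anyway, two of them below threshold.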

    def _apply_error_sql_threshold_filter(self, results: list) -> list:
        """
        Apply the error-SQL-specific similarity-threshold filtering.

        Unlike the other retrieval methods, the error-SQL filter:
        - returns only results whose similarity meets the threshold,
        - enforces no minimum number of results,
        - returns an empty list when everything falls below the threshold.

        Args:
            results: Raw result list; each element carries a `similarity` field.

        Returns:
            The filtered result list.
        """
        if not results:
            return results

        # Load configuration
        try:
            import app_config
            enable_threshold = getattr(app_config, 'ENABLE_RESULT_VECTOR_SCORE_THRESHOLD', False)
            threshold = getattr(app_config, 'RESULT_VECTOR_ERROR_SQL_SCORE_THRESHOLD', 0.5)
        except (ImportError, AttributeError) as e:
            self.logger.warning(f"Could not load error-SQL threshold config: {e}; using defaults")
            enable_threshold = False
            threshold = 0.5

        # If threshold filtering is disabled, return the results unchanged
        if not enable_threshold:
            self.logger.debug(f"Error SQL threshold filtering disabled; returning all {len(results)} results")
            return results

        total_count = len(results)
        self.logger.debug(f"Error SQL threshold filtering: total={total_count}, threshold={threshold}")

        # Sort by similarity, descending (most similar first)
        sorted_results = sorted(results, key=lambda x: x.get('similarity', 0), reverse=True)

        # Keep only results that meet the threshold; no minimum count
        filtered_results = [r for r in sorted_results if r.get('similarity', 0) >= threshold]

        filtered_count = len(filtered_results)
        filtered_out_count = total_count - filtered_count

        if filtered_count > 0:
            self.logger.debug(f"Error SQL filter result: kept {filtered_count}, dropped {filtered_out_count}")
            # Log details of what was kept
            for i, result in enumerate(filtered_results):
                similarity = result.get('similarity', 0)
                self.logger.debug(f"Error SQL kept {i+1}: similarity={similarity} ✓")
        else:
            self.logger.debug(f"Error SQL filter result: all {total_count} results fall below threshold {threshold}; returning an empty list")

        return filtered_results
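
    # Contrast with _apply_score_threshold_filter (hypothetical numbers): with
    # threshold 0.5 and similarities [0.9, 0.4, 0.3], only the 0.9 entry
    # survives here, and [0.4, 0.3] alone would yield an empty list rather
    # than a force-kept minimum.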

    def train(
        self,
        question: str | None = None,
        sql: str | None = None,
        ddl: str | None = None,
        documentation: str | None = None,
        plan: TrainingPlan | None = None,
        createdat: str | None = None,
    ):
        if question and not sql:
            raise ValidationError("Please provide a SQL query.")

        if documentation:
            logging.info(f"Adding documentation: {documentation}")
            return self.add_documentation(documentation)

        if sql and question:
            return self.add_question_sql(question=question, sql=sql, createdat=createdat)

        if ddl:
            logging.info(f"Adding ddl: {ddl}")
            return self.add_ddl(ddl)

        if plan:
            for item in plan._plan:
                if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
                    self.add_ddl(item.item_value)
                elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
                    self.add_documentation(item.item_value)
                elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL and item.item_name:
                    self.add_question_sql(question=item.item_name, sql=item.item_value)

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        # Establish the connection
        engine = create_engine(self.connection_string)

        # Query the 'langchain_pg_embedding' table
        query_embedding = "SELECT cmetadata, document FROM langchain_pg_embedding"
        df_embedding = pd.read_sql(query_embedding, engine)

        # List to accumulate the processed rows
        processed_rows = []

        # Process each row in the DataFrame
        for _, row in df_embedding.iterrows():
            custom_id = row["cmetadata"]["id"]
            document = row["document"]

            # Dispatch on the ID suffix
            if custom_id.endswith("-doc"):
                training_data_type = "documentation"
            elif custom_id.endswith("-error_sql"):
                training_data_type = "error_sql"
            elif custom_id.endswith("-sql"):
                training_data_type = "sql"
            elif custom_id.endswith("-ddl"):
                training_data_type = "ddl"
            else:
                # If the suffix is not recognized, skip this row
                logging.info(f"Skipping row with custom_id {custom_id} due to unrecognized training data type.")
                continue

            if training_data_type in ["sql", "error_sql"]:
                # Convert the document string to a dictionary
                try:
                    doc_dict = ast.literal_eval(document)
                    question = doc_dict.get("question")
                    content = doc_dict.get("sql")
                except (ValueError, SyntaxError):
                    logging.info(f"Skipping row with custom_id {custom_id} due to parsing error.")
                    continue
            elif training_data_type in ["documentation", "ddl"]:
                question = None  # Default value for question
                content = document

            # Append the processed data to the list
            processed_rows.append(
                {"id": custom_id, "question": question, "content": content, "training_data_type": training_data_type}
            )

        # Create a DataFrame from the list of processed rows
        df_processed = pd.DataFrame(processed_rows)
        return df_processed
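
    # Resulting DataFrame columns (one row per stored document):
    #   id | question | content | training_data_type
    # `question` is None for ddl/documentation rows.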

    def remove_training_data(self, id: str, **kwargs) -> bool:
        # Create the database engine
        engine = create_engine(self.connection_string)

        # SQL DELETE statement matching on the JSONB 'id' key
        delete_statement = text(
            """
            DELETE FROM langchain_pg_embedding
            WHERE cmetadata ->> 'id' = :id
            """
        )

        # Connect to the database and execute the delete statement
        with engine.connect() as connection:
            # Start a transaction
            with connection.begin() as transaction:
                try:
                    result = connection.execute(delete_statement, {"id": id})
                    # Commit the transaction if the delete was successful
                    transaction.commit()
                    # Return True if any row was deleted, False otherwise
                    return result.rowcount > 0
                except Exception as e:
                    # Roll back the transaction in case of error
                    logging.error(f"An error occurred: {e}")
                    transaction.rollback()
                    return False

    def remove_collection(self, collection_name: str) -> bool:
        engine = create_engine(self.connection_string)

        # Determine the ID suffix to look for based on the collection name
        suffix_map = {"ddl": "ddl", "sql": "sql", "documentation": "doc", "error_sql": "error_sql"}
        suffix = suffix_map.get(collection_name)

        if not suffix:
            logging.info("Invalid collection name. Choose from 'ddl', 'sql', 'documentation', or 'error_sql'.")
            return False

        # SQL query to delete rows whose ID carries the suffix
        # (the suffix comes from the fixed map above, so the interpolation is safe)
        query = text(
            f"""
            DELETE FROM langchain_pg_embedding
            WHERE cmetadata->>'id' LIKE '%{suffix}'
            """
        )

        # Execute the deletion within a transaction block
        with engine.connect() as connection:
            with connection.begin() as transaction:
                try:
                    result = connection.execute(query)
                    transaction.commit()  # Explicitly commit the transaction
                    if result.rowcount > 0:
                        logging.info(
                            f"Deleted {result.rowcount} rows from "
                            f"langchain_pg_embedding where collection is {collection_name}."
                        )
                        return True
                    else:
                        logging.info(f"No rows deleted for collection {collection_name}.")
                        return False
                except Exception as e:
                    logging.error(f"An error occurred: {e}")
                    transaction.rollback()  # Rollback in case of error
                    return False

    def generate_embedding(self, *args, **kwargs):
        pass

    # Training and retrieval support for error SQL

    # 1. Ensure the error_sql collection exists
    def _ensure_error_sql_collection(self):
        """Ensure the error_sql collection exists."""
        # The collection is already initialized in __init__; this method exists
        # only to keep the error-SQL API consistent.
        pass

    # 2. Store failed question-SQL pairs in the error_sql collection
    def train_error_sql(self, question: str, sql: str, **kwargs) -> str:
        """
        Store a failed question-SQL pair in the error_sql collection.
        """
        # Ensure the collection exists
        self._ensure_error_sql_collection()

        # Build the document content, matching the format of the existing SQL training data
        question_sql_json = json.dumps(
            {
                "question": question,
                "sql": sql,
                "type": "error_sql"
            },
            ensure_ascii=False,
        )

        # Generate an ID in the same format as the existing methods
        id = str(uuid.uuid4()) + "-error_sql"
        createdat = kwargs.get("createdat")

        # Build the Document object
        doc = Document(
            page_content=question_sql_json,
            metadata={"id": id, "createdat": createdat},
        )

        # Add to the error_sql collection
        self.error_sql_collection.add_documents([doc], ids=[doc.metadata["id"]])

        return id

    # 3. Retrieve related error-SQL examples
    def get_related_error_sql(self, question: str, **kwargs) -> list:
        """
        Retrieve related error-SQL examples.
        """
        # Ensure the collection exists
        self._ensure_error_sql_collection()

        try:
            # Try the embedding cache first
            embedding_cache = get_embedding_cache_manager()
            cached_embedding = embedding_cache.get_cached_embedding(question)

            if cached_embedding is not None:
                # Vector search using the cached embedding
                docs_with_scores = self.error_sql_collection.similarity_search_with_score_by_vector(
                    embedding=cached_embedding,
                    k=self.n_results
                )
            else:
                # Regular vector search (generates the embedding automatically)
                docs_with_scores = self.error_sql_collection.similarity_search_with_score(
                    query=question,
                    k=self.n_results
                )
                # Generate the embedding again and cache it
                try:
                    # Generate the vector via embedding_function, then cache it
                    generated_embedding = self.embedding_function.generate_embedding(question)
                    if generated_embedding:
                        embedding_cache.cache_embedding(question, generated_embedding)
                except Exception as e:
                    self.logger.warning(f"Failed to cache embedding: {e}")

            results = []
            for doc, score in docs_with_scores:
                try:
                    # Parse the document content into a dict, consistent with the other methods
                    base = ast.literal_eval(doc.page_content)
                    # Convert distance to similarity
                    similarity = round(1 - score, 4)
                    # Log each match individually
                    self.logger.debug(f"Error SQL Match: {base.get('question', '')} | similarity: {similarity}")
                    # Attach the similarity field
                    base["similarity"] = similarity
                    results.append(base)
                except (ValueError, SyntaxError) as e:
                    self.logger.error(f"Error parsing error SQL document: {e}")
                    continue

            # Warn if the raw query returned nothing
            if not results:
                self.logger.warning(f"Vector search found no related error-SQL examples; question: {question}")

            # Apply the error-SQL-specific threshold filtering
            filtered_results = self._apply_error_sql_threshold_filter(results)

            # Warn if filtering removed everything
            if results and not filtered_results:
                self.logger.warning(f"Vector search found {len(results)} error-SQL examples, but all were filtered out by the threshold; question: {question}")

            return filtered_results

        except Exception as e:
            self.logger.error(f"Error retrieving error SQL examples: {e}")
            return []
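

# End-to-end usage sketch (a minimal illustration; the DSN is a placeholder and
# `my_embeddings` is assumed to be a LangChain-compatible embeddings object):
#
#   store = PG_VectorStore(config={
#       "connection_string": "postgresql+psycopg://user:pass@localhost:5432/vanna",
#       "embedding_function": my_embeddings,
#   })
#   store.train(ddl="CREATE TABLE users (id serial PRIMARY KEY, created_at timestamptz)")
#   store.train(question="How many users signed up last week?",
#               sql="SELECT count(*) FROM users WHERE created_at >= now() - interval '7 days'")
#   matches = store.get_similar_question_sql("weekly signups")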