# pgvector.py

import ast
import json
import logging
import uuid

import pandas as pd
from langchain_core.documents import Document
from langchain_postgres.vectorstores import PGVector
from sqlalchemy import create_engine, text

from vanna.exceptions import ValidationError
from vanna.base import VannaBase
from vanna.types import TrainingPlan, TrainingPlanItem

# Import the embedding cache manager
from common.embedding_cache_manager import get_embedding_cache_manager


class PG_VectorStore(VannaBase):
    def __init__(self, config=None):
        if not config or "connection_string" not in config:
            raise ValueError(
                "A valid 'config' dictionary with a 'connection_string' is required.")

        VannaBase.__init__(self, config=config)

        if config and "connection_string" in config:
            self.connection_string = config.get("connection_string")
            self.n_results = config.get("n_results", 10)

        if config and "embedding_function" in config:
            self.embedding_function = config.get("embedding_function")
        else:
            raise ValueError("No embedding_function was found.")
        # from langchain_huggingface import HuggingFaceEmbeddings
        # self.embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

        self.sql_collection = PGVector(
            embeddings=self.embedding_function,
            collection_name="sql",
            connection=self.connection_string,
        )
        self.ddl_collection = PGVector(
            embeddings=self.embedding_function,
            collection_name="ddl",
            connection=self.connection_string,
        )
        self.documentation_collection = PGVector(
            embeddings=self.embedding_function,
            collection_name="documentation",
            connection=self.connection_string,
        )
        self.error_sql_collection = PGVector(
            embeddings=self.embedding_function,
            collection_name="error_sql",
            connection=self.connection_string,
        )
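
    # A minimal construction sketch (not part of the original module). The connection
    # string below is an illustrative assumption; "embedding_function" must be an
    # embeddings object accepted by PGVector, and the caching code further down also
    # expects it to expose a generate_embedding() method.
    #
    #   store = PG_VectorStore(config={
    #       "connection_string": "postgresql+psycopg://user:pass@localhost:5432/vanna",
    #       "embedding_function": my_embedding_function,  # hypothetical, supplied by the caller
    #       "n_results": 10,
    #   })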

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        question_sql_json = json.dumps(
            {
                "question": question,
                "sql": sql,
            },
            ensure_ascii=False,
        )
        id = str(uuid.uuid4()) + "-sql"
        createdat = kwargs.get("createdat")
        doc = Document(
            page_content=question_sql_json,
            metadata={"id": id, "createdat": createdat},
        )
        self.sql_collection.add_documents([doc], ids=[doc.metadata["id"]])
        return id

    def add_ddl(self, ddl: str, **kwargs) -> str:
        _id = str(uuid.uuid4()) + "-ddl"
        doc = Document(
            page_content=ddl,
            metadata={"id": _id},
        )
        self.ddl_collection.add_documents([doc], ids=[doc.metadata["id"]])
        return _id

    def add_documentation(self, documentation: str, **kwargs) -> str:
        _id = str(uuid.uuid4()) + "-doc"
        doc = Document(
            page_content=documentation,
            metadata={"id": _id},
        )
        self.documentation_collection.add_documents([doc], ids=[doc.metadata["id"]])
        return _id

    def get_collection(self, collection_name):
        match collection_name:
            case "sql":
                return self.sql_collection
            case "ddl":
                return self.ddl_collection
            case "documentation":
                return self.documentation_collection
            case "error_sql":
                return self.error_sql_collection
            case _:
                raise ValueError("Specified collection does not exist.")

    # def get_similar_question_sql(self, question: str) -> list:
    #     documents = self.sql_collection.similarity_search(query=question, k=self.n_results)
    #     return [ast.literal_eval(document.page_content) for document in documents]

    # Same as the original implementation above, but with a similarity value added to each result.
    def get_similar_question_sql(self, question: str) -> list:
        # Try the embedding cache first
        embedding_cache = get_embedding_cache_manager()
        cached_embedding = embedding_cache.get_cached_embedding(question)

        if cached_embedding is not None:
            # Search with the cached embedding vector
            docs_with_scores = self.sql_collection.similarity_search_with_score_by_vector(
                embedding=cached_embedding,
                k=self.n_results
            )
        else:
            # Fall back to a regular vector search (the embedding is generated automatically)
            docs_with_scores = self.sql_collection.similarity_search_with_score(
                query=question,
                k=self.n_results
            )
            # Generate the embedding again and cache it for next time
            try:
                generated_embedding = self.embedding_function.generate_embedding(question)
                if generated_embedding:
                    embedding_cache.cache_embedding(question, generated_embedding)
            except Exception as e:
                print(f"[WARNING] Failed to cache embedding: {e}")

        results = []
        for doc, score in docs_with_scores:
            # Convert the document content to a dict
            base = ast.literal_eval(doc.page_content)
            # Convert the distance score to a similarity value
            similarity = round(1 - score, 4)
            # Print each match individually
            print(f"[DEBUG] SQL Match: {base.get('question', '')} | similarity: {similarity}")
            # Add the similarity field
            base["similarity"] = similarity
            results.append(base)

        # Apply the similarity threshold filter
        filtered_results = self._apply_score_threshold_filter(
            results,
            "RESULT_VECTOR_SQL_SCORE_THRESHOLD",
            "SQL"
        )

        return filtered_results
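
    # For reference (not part of the original source): each entry returned by
    # get_similar_question_sql is the stored question/sql dict plus the computed
    # similarity, roughly of the form
    #   {"question": "How many users signed up last month?",
    #    "sql": "SELECT count(*) FROM users WHERE ...",
    #    "similarity": 0.8123}
    # The example values here are illustrative assumptions, not stored training data.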

    def get_related_ddl(self, question: str, **kwargs) -> list:
        # Try the embedding cache first
        embedding_cache = get_embedding_cache_manager()
        cached_embedding = embedding_cache.get_cached_embedding(question)

        if cached_embedding is not None:
            # Search with the cached embedding vector
            docs_with_scores = self.ddl_collection.similarity_search_with_score_by_vector(
                embedding=cached_embedding,
                k=self.n_results
            )
        else:
            # Fall back to a regular vector search (the embedding is generated automatically)
            docs_with_scores = self.ddl_collection.similarity_search_with_score(
                query=question,
                k=self.n_results
            )
            # Generate the embedding again and cache it for next time
            try:
                generated_embedding = self.embedding_function.generate_embedding(question)
                if generated_embedding:
                    embedding_cache.cache_embedding(question, generated_embedding)
            except Exception as e:
                print(f"[WARNING] Failed to cache embedding: {e}")

        results = []
        for doc, score in docs_with_scores:
            # Convert the distance score to a similarity value
            similarity = round(1 - score, 4)
            # Print each match individually
            print(f"[DEBUG] DDL Match: {doc.page_content[:50]}... | similarity: {similarity}")
            # Add the similarity field
            result = {
                "content": doc.page_content,
                "similarity": similarity
            }
            results.append(result)

        # Apply the similarity threshold filter
        filtered_results = self._apply_score_threshold_filter(
            results,
            "RESULT_VECTOR_DDL_SCORE_THRESHOLD",
            "DDL"
        )

        return filtered_results

    def get_related_documentation(self, question: str, **kwargs) -> list:
        # Try the embedding cache first
        embedding_cache = get_embedding_cache_manager()
        cached_embedding = embedding_cache.get_cached_embedding(question)

        if cached_embedding is not None:
            # Search with the cached embedding vector
            docs_with_scores = self.documentation_collection.similarity_search_with_score_by_vector(
                embedding=cached_embedding,
                k=self.n_results
            )
        else:
            # Fall back to a regular vector search (the embedding is generated automatically)
            docs_with_scores = self.documentation_collection.similarity_search_with_score(
                query=question,
                k=self.n_results
            )
            # Generate the embedding again and cache it for next time
            try:
                generated_embedding = self.embedding_function.generate_embedding(question)
                if generated_embedding:
                    embedding_cache.cache_embedding(question, generated_embedding)
            except Exception as e:
                print(f"[WARNING] Failed to cache embedding: {e}")

        results = []
        for doc, score in docs_with_scores:
            # Convert the distance score to a similarity value
            similarity = round(1 - score, 4)
            # Print each match individually
            print(f"[DEBUG] Doc Match: {doc.page_content[:50]}... | similarity: {similarity}")
            # Add the similarity field
            result = {
                "content": doc.page_content,
                "similarity": similarity
            }
            results.append(result)

        # Apply the similarity threshold filter
        filtered_results = self._apply_score_threshold_filter(
            results,
            "RESULT_VECTOR_DOC_SCORE_THRESHOLD",
            "DOC"
        )

        return filtered_results

    def _apply_score_threshold_filter(self, results: list, threshold_config_key: str, result_type: str) -> list:
        """
        Apply the similarity threshold filter.

        Args:
            results: Raw result list; each element contains a similarity field
            threshold_config_key: Name of the threshold parameter in the config
            result_type: Result type (used for logging)

        Returns:
            The filtered result list
        """
        if not results:
            return results

        # Load the configuration
        try:
            import app_config
            enable_threshold = getattr(app_config, 'ENABLE_RESULT_VECTOR_SCORE_THRESHOLD', False)
            threshold = getattr(app_config, threshold_config_key, 0.65)
        except (ImportError, AttributeError) as e:
            print(f"[WARNING] Could not load threshold config: {e}; falling back to defaults")
            enable_threshold = False
            threshold = 0.65

        # If threshold filtering is disabled, return the results unchanged
        if not enable_threshold:
            print(f"[DEBUG] {result_type} threshold filtering disabled, returning all {len(results)} results")
            return results

        total_count = len(results)
        min_required = max((total_count + 1) // 2, 1)

        print(f"[DEBUG] {result_type} threshold filter: total={total_count}, threshold={threshold}, min_required={min_required}")

        # Sort by similarity in descending order (most similar first)
        sorted_results = sorted(results, key=lambda x: x.get('similarity', 0), reverse=True)

        # Collect the results that meet the threshold
        above_threshold = [r for r in sorted_results if r.get('similarity', 0) >= threshold]

        # Apply the filtering logic
        if len(above_threshold) >= min_required:
            # Case 1: enough results meet the threshold, so return only those
            filtered_results = above_threshold
            filtered_count = len(above_threshold)
            print(f"[DEBUG] {result_type} filter result: kept {filtered_count}, dropped {total_count - filtered_count} (all above threshold)")
        else:
            # Case 2: fewer results meet the threshold than min_required, so force-keep the top min_required
            filtered_results = sorted_results[:min_required]
            above_count = len(above_threshold)
            below_count = min_required - above_count
            filtered_count = min_required
            print(f"[DEBUG] {result_type} filter result: kept {filtered_count}, dropped {total_count - filtered_count} (above threshold: {above_count}, force-kept: {below_count})")

        # Print details of the kept results
        for i, result in enumerate(filtered_results):
            similarity = result.get('similarity', 0)
            status = "✓" if similarity >= threshold else "✗"
            print(f"[DEBUG] {result_type} kept {i+1}: similarity={similarity} {status}")

        return filtered_results
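
    # A small worked example of the rule above (illustrative, not from the original
    # source): with 5 results and threshold 0.65, min_required = max((5 + 1) // 2, 1) = 3.
    # If the similarities are [0.82, 0.71, 0.60, 0.55, 0.40], only two pass the threshold,
    # which is fewer than min_required, so the top three (0.82, 0.71, 0.60) are kept.
    # If the similarities were [0.82, 0.71, 0.70, 0.55, 0.40], three pass the threshold
    # and exactly those three are returned.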

    def _apply_error_sql_threshold_filter(self, results: list) -> list:
        """
        Apply the similarity threshold filter specific to error SQL.

        Unlike the other methods, the error SQL filter:
        - only returns results whose similarity is above the threshold
        - does not enforce a minimum number of results
        - returns an empty list if nothing meets the threshold

        Args:
            results: Raw result list; each element contains a similarity field

        Returns:
            The filtered result list
        """
        if not results:
            return results

        # Load the configuration
        try:
            import app_config
            enable_threshold = getattr(app_config, 'ENABLE_RESULT_VECTOR_SCORE_THRESHOLD', False)
            threshold = getattr(app_config, 'RESULT_VECTOR_ERROR_SQL_SCORE_THRESHOLD', 0.5)
        except (ImportError, AttributeError) as e:
            print(f"[WARNING] Could not load error SQL threshold config: {e}; falling back to defaults")
            enable_threshold = False
            threshold = 0.5

        # If threshold filtering is disabled, return the results unchanged
        if not enable_threshold:
            print(f"[DEBUG] Error SQL threshold filtering disabled, returning all {len(results)} results")
            return results

        total_count = len(results)
        print(f"[DEBUG] Error SQL threshold filter: total={total_count}, threshold={threshold}")

        # Sort by similarity in descending order (most similar first)
        sorted_results = sorted(results, key=lambda x: x.get('similarity', 0), reverse=True)

        # Keep only the results that meet the threshold; no minimum count is enforced
        filtered_results = [r for r in sorted_results if r.get('similarity', 0) >= threshold]

        filtered_count = len(filtered_results)
        filtered_out_count = total_count - filtered_count

        if filtered_count > 0:
            print(f"[DEBUG] Error SQL filter result: kept {filtered_count}, dropped {filtered_out_count}")
            # Print details of the kept results
            for i, result in enumerate(filtered_results):
                similarity = result.get('similarity', 0)
                print(f"[DEBUG] Error SQL kept {i+1}: similarity={similarity} ✓")
        else:
            print(f"[DEBUG] Error SQL filter result: all {total_count} results fell below threshold {threshold}, returning an empty list")

        return filtered_results

    def train(
        self,
        question: str | None = None,
        sql: str | None = None,
        ddl: str | None = None,
        documentation: str | None = None,
        plan: TrainingPlan | None = None,
        createdat: str | None = None,
    ):
        if question and not sql:
            raise ValidationError("Please provide a SQL query.")

        if documentation:
            logging.info(f"Adding documentation: {documentation}")
            return self.add_documentation(documentation)

        if sql and question:
            return self.add_question_sql(question=question, sql=sql, createdat=createdat)

        if ddl:
            logging.info(f"Adding ddl: {ddl}")
            return self.add_ddl(ddl)

        if plan:
            for item in plan._plan:
                if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
                    self.add_ddl(item.item_value)
                elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
                    self.add_documentation(item.item_value)
                elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL and item.item_name:
                    self.add_question_sql(question=item.item_name, sql=item.item_value)
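
    # Typical call forms for train() (illustrative only; the literal values below are
    # assumptions, not training data from this project):
    #   store.train(documentation="Orders are stored in the orders table.")
    #   store.train(ddl="CREATE TABLE orders (id int, total numeric)")
    #   store.train(question="Total revenue?", sql="SELECT sum(total) FROM orders")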

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        # Establishing the connection
        engine = create_engine(self.connection_string)

        # Querying the 'langchain_pg_embedding' table
        query_embedding = "SELECT cmetadata, document FROM langchain_pg_embedding"
        df_embedding = pd.read_sql(query_embedding, engine)

        # List to accumulate the processed rows
        processed_rows = []

        # Process each row in the DataFrame
        for _, row in df_embedding.iterrows():
            custom_id = row["cmetadata"]["id"]
            document = row["document"]

            # Map the ID suffix to a training data type
            if custom_id.endswith("-doc"):
                training_data_type = "documentation"
            elif custom_id.endswith("-error_sql"):
                training_data_type = "error_sql"
            elif custom_id.endswith("-sql"):
                training_data_type = "sql"
            elif custom_id.endswith("-ddl"):
                training_data_type = "ddl"
            else:
                # If the suffix is not recognized, skip this row
                logging.info(f"Skipping row with custom_id {custom_id} due to unrecognized training data type.")
                continue

            if training_data_type in ["sql", "error_sql"]:
                # Convert the document string to a dictionary
                try:
                    doc_dict = ast.literal_eval(document)
                    question = doc_dict.get("question")
                    content = doc_dict.get("sql")
                except (ValueError, SyntaxError):
                    logging.info(f"Skipping row with custom_id {custom_id} due to parsing error.")
                    continue
            elif training_data_type in ["documentation", "ddl"]:
                question = None  # Default value for question
                content = document

            # Append the processed data to the list
            processed_rows.append(
                {"id": custom_id, "question": question, "content": content, "training_data_type": training_data_type}
            )

        # Create a DataFrame from the list of processed rows
        df_processed = pd.DataFrame(processed_rows)

        return df_processed
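
    # Note (added for reference): the resulting DataFrame has the columns id, question,
    # content and training_data_type; question is None for documentation and ddl rows.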

    def remove_training_data(self, id: str, **kwargs) -> bool:
        # Create the database engine
        engine = create_engine(self.connection_string)

        # SQL DELETE statement
        delete_statement = text(
            """
            DELETE FROM langchain_pg_embedding
            WHERE cmetadata ->> 'id' = :id
            """
        )

        # Connect to the database and execute the delete statement
        with engine.connect() as connection:
            # Start a transaction
            with connection.begin() as transaction:
                try:
                    result = connection.execute(delete_statement, {"id": id})
                    # Commit the transaction if the delete was successful
                    transaction.commit()
                    # Check if any row was deleted and return True or False accordingly
                    return result.rowcount > 0
                except Exception as e:
                    # Rollback the transaction in case of error
                    logging.error(f"An error occurred: {e}")
                    transaction.rollback()
                    return False

    def remove_collection(self, collection_name: str) -> bool:
        engine = create_engine(self.connection_string)

        # Determine the suffix to look for based on the collection name
        suffix_map = {"ddl": "ddl", "sql": "sql", "documentation": "doc", "error_sql": "error_sql"}
        suffix = suffix_map.get(collection_name)

        if not suffix:
            logging.info("Invalid collection name. Choose from 'ddl', 'sql', 'documentation', or 'error_sql'.")
            return False

        # SQL query to delete rows based on the condition
        query = text(
            f"""
            DELETE FROM langchain_pg_embedding
            WHERE cmetadata->>'id' LIKE '%{suffix}'
            """
        )

        # Execute the deletion within a transaction block
        with engine.connect() as connection:
            with connection.begin() as transaction:
                try:
                    result = connection.execute(query)
                    transaction.commit()  # Explicitly commit the transaction
                    if result.rowcount > 0:
                        logging.info(
                            f"Deleted {result.rowcount} rows from "
                            f"langchain_pg_embedding where collection is {collection_name}."
                        )
                        return True
                    else:
                        logging.info(f"No rows deleted for collection {collection_name}.")
                        return False
                except Exception as e:
                    logging.error(f"An error occurred: {e}")
                    transaction.rollback()  # Rollback in case of error
                    return False

    def generate_embedding(self, *args, **kwargs):
        pass

    # Error SQL training and retrieval support
    # 1. Make sure the error_sql collection exists
    def _ensure_error_sql_collection(self):
        """Ensure the error_sql collection exists."""
        # The collection is already initialized in __init__; this method exists only for consistency
        pass

    # 2. Store an incorrect question-sql pair in the error_sql collection
    def train_error_sql(self, question: str, sql: str, **kwargs) -> str:
        """
        Store an incorrect question-sql pair in the error_sql collection.
        """
        # Make sure the collection exists
        self._ensure_error_sql_collection()

        # Build the document content in the same format as the existing SQL training data
        question_sql_json = json.dumps(
            {
                "question": question,
                "sql": sql,
                "type": "error_sql"
            },
            ensure_ascii=False,
        )

        # Generate an ID in the same format as the existing methods
        id = str(uuid.uuid4()) + "-error_sql"
        createdat = kwargs.get("createdat")

        # Create the Document object
        doc = Document(
            page_content=question_sql_json,
            metadata={"id": id, "createdat": createdat},
        )

        # Add to the error_sql collection
        self.error_sql_collection.add_documents([doc], ids=[doc.metadata["id"]])

        return id

    # 3. Retrieve related error SQL examples
    def get_related_error_sql(self, question: str, **kwargs) -> list:
        """
        Retrieve related error SQL examples.
        """
        # Make sure the collection exists
        self._ensure_error_sql_collection()

        try:
            # Try the embedding cache first
            embedding_cache = get_embedding_cache_manager()
            cached_embedding = embedding_cache.get_cached_embedding(question)

            if cached_embedding is not None:
                # Search with the cached embedding vector
                docs_with_scores = self.error_sql_collection.similarity_search_with_score_by_vector(
                    embedding=cached_embedding,
                    k=self.n_results
                )
            else:
                # Fall back to a regular vector search (the embedding is generated automatically)
                docs_with_scores = self.error_sql_collection.similarity_search_with_score(
                    query=question,
                    k=self.n_results
                )
                # Generate the embedding again and cache it for next time
                try:
                    generated_embedding = self.embedding_function.generate_embedding(question)
                    if generated_embedding:
                        embedding_cache.cache_embedding(question, generated_embedding)
                except Exception as e:
                    print(f"[WARNING] Failed to cache embedding: {e}")

            results = []
            for doc, score in docs_with_scores:
                try:
                    # Convert the document content to a dict, consistent with the existing methods
                    base = ast.literal_eval(doc.page_content)
                    # Convert the distance score to a similarity value
                    similarity = round(1 - score, 4)
                    # Print each match individually
                    print(f"[DEBUG] Error SQL Match: {base.get('question', '')} | similarity: {similarity}")
                    # Add the similarity field
                    base["similarity"] = similarity
                    results.append(base)
                except (ValueError, SyntaxError) as e:
                    print(f"Error parsing error SQL document: {e}")
                    continue

            # Apply the error-SQL-specific threshold filter
            filtered_results = self._apply_error_sql_threshold_filter(results)

            return filtered_results

        except Exception as e:
            print(f"Error retrieving error SQL examples: {e}")
            return []
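
    # A minimal end-to-end sketch of the error SQL workflow (illustrative; the question
    # and SQL below are assumptions, not records from this project):
    #   store.train_error_sql(
    #       question="Average order value last month?",
    #       sql="SELECT avg(total) FROM order",  # known-bad query recorded for reference
    #   )
    #   examples = store.get_related_error_sql("What was the average order value in May?")
    #   # -> [{"question": ..., "sql": ..., "type": "error_sql", "similarity": 0.73}, ...]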