pgvector.py

import ast
import json
import logging
import uuid

import pandas as pd
from langchain_core.documents import Document
from langchain_postgres.vectorstores import PGVector
from sqlalchemy import create_engine, text

from vanna.exceptions import ValidationError
from vanna.base import VannaBase
from vanna.types import TrainingPlan, TrainingPlanItem


class PG_VectorStore(VannaBase):
    def __init__(self, config=None):
        if not config or "connection_string" not in config:
            raise ValueError(
                "A valid 'config' dictionary with a 'connection_string' is required.")

        VannaBase.__init__(self, config=config)

        if config and "connection_string" in config:
            self.connection_string = config.get("connection_string")
            self.n_results = config.get("n_results", 10)

        if config and "embedding_function" in config:
            self.embedding_function = config.get("embedding_function")
        else:
            raise ValueError("No embedding_function was found.")
        # from langchain_huggingface import HuggingFaceEmbeddings
        # self.embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

        self.sql_collection = PGVector(
            embeddings=self.embedding_function,
            collection_name="sql",
            connection=self.connection_string,
        )
        self.ddl_collection = PGVector(
            embeddings=self.embedding_function,
            collection_name="ddl",
            connection=self.connection_string,
        )
        self.documentation_collection = PGVector(
            embeddings=self.embedding_function,
            collection_name="documentation",
            connection=self.connection_string,
        )
        self.error_sql_collection = PGVector(
            embeddings=self.embedding_function,
            collection_name="error_sql",
            connection=self.connection_string,
        )
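
    # A minimal sketch of the expected config (illustrative values; any LangChain-compatible
    # embedding object works, the HuggingFaceEmbeddings model below is just an example):
    #
    #     from langchain_huggingface import HuggingFaceEmbeddings
    #
    #     config = {
    #         "connection_string": "postgresql+psycopg://user:password@localhost:5432/vanna",
    #         "n_results": 10,
    #         "embedding_function": HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2"),
    #     }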

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        question_sql_json = json.dumps(
            {
                "question": question,
                "sql": sql,
            },
            ensure_ascii=False,
        )
        id = str(uuid.uuid4()) + "-sql"
        createdat = kwargs.get("createdat")
        doc = Document(
            page_content=question_sql_json,
            metadata={"id": id, "createdat": createdat},
        )
        self.sql_collection.add_documents([doc], ids=[doc.metadata["id"]])
        return id

    def add_ddl(self, ddl: str, **kwargs) -> str:
        _id = str(uuid.uuid4()) + "-ddl"
        doc = Document(
            page_content=ddl,
            metadata={"id": _id},
        )
        self.ddl_collection.add_documents([doc], ids=[doc.metadata["id"]])
        return _id

    def add_documentation(self, documentation: str, **kwargs) -> str:
        _id = str(uuid.uuid4()) + "-doc"
        doc = Document(
            page_content=documentation,
            metadata={"id": _id},
        )
        self.documentation_collection.add_documents([doc], ids=[doc.metadata["id"]])
        return _id

    def get_collection(self, collection_name):
        match collection_name:
            case "sql":
                return self.sql_collection
            case "ddl":
                return self.ddl_collection
            case "documentation":
                return self.documentation_collection
            case "error_sql":
                return self.error_sql_collection
            case _:
                raise ValueError("Specified collection does not exist.")

    # def get_similar_question_sql(self, question: str) -> list:
    #     documents = self.sql_collection.similarity_search(query=question, k=self.n_results)
    #     return [ast.literal_eval(document.page_content) for document in documents]

    # Extends the original implementation above by also returning a similarity value.
    def get_similar_question_sql(self, question: str) -> list:
        docs_with_scores = self.sql_collection.similarity_search_with_score(
            query=question,
            k=self.n_results
        )
        results = []
        for doc, score in docs_with_scores:
            # Convert the document content to a dict
            base = ast.literal_eval(doc.page_content)
            # Convert the distance score into a similarity value
            similarity = round(1 - score, 4)
            # Print each match individually
            print(f"[DEBUG] SQL Match: {base.get('question', '')} | similarity: {similarity}")
            # Add the similarity field
            base["similarity"] = similarity
            results.append(base)

        # Apply the threshold filter
        filtered_results = self._apply_score_threshold_filter(
            results,
            "RESULT_VECTOR_SQL_SCORE_THRESHOLD",
            "SQL"
        )
        return filtered_results
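
    # Each entry returned by get_similar_question_sql is the stored question/SQL dict plus a
    # "similarity" key, e.g. (values illustrative):
    #     [{"question": "...", "sql": "SELECT ...", "similarity": 0.8312}, ...]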

    def get_related_ddl(self, question: str, **kwargs) -> list:
        docs_with_scores = self.ddl_collection.similarity_search_with_score(
            query=question,
            k=self.n_results
        )
        results = []
        for doc, score in docs_with_scores:
            # Convert the distance score into a similarity value
            similarity = round(1 - score, 4)
            # Print each match individually
            print(f"[DEBUG] DDL Match: {doc.page_content[:50]}... | similarity: {similarity}")
            # Include the similarity field
            result = {
                "content": doc.page_content,
                "similarity": similarity
            }
            results.append(result)

        # Apply the threshold filter
        filtered_results = self._apply_score_threshold_filter(
            results,
            "RESULT_VECTOR_DDL_SCORE_THRESHOLD",
            "DDL"
        )
        return filtered_results

    def get_related_documentation(self, question: str, **kwargs) -> list:
        docs_with_scores = self.documentation_collection.similarity_search_with_score(
            query=question,
            k=self.n_results
        )
        results = []
        for doc, score in docs_with_scores:
            # Convert the distance score into a similarity value
            similarity = round(1 - score, 4)
            # Print each match individually
            print(f"[DEBUG] Doc Match: {doc.page_content[:50]}... | similarity: {similarity}")
            # Include the similarity field
            result = {
                "content": doc.page_content,
                "similarity": similarity
            }
            results.append(result)

        # Apply the threshold filter
        filtered_results = self._apply_score_threshold_filter(
            results,
            "RESULT_VECTOR_DOC_SCORE_THRESHOLD",
            "DOC"
        )
        return filtered_results

    def _apply_score_threshold_filter(self, results: list, threshold_config_key: str, result_type: str) -> list:
        """
        Apply the similarity-threshold filtering logic.

        Args:
            results: Raw result list; each element contains a "similarity" field.
            threshold_config_key: Name of the threshold parameter in the config.
            result_type: Result type (used for logging).

        Returns:
            The filtered result list.
        """
        if not results:
            return results

        # Load the configuration
        try:
            import app_config
            enable_threshold = getattr(app_config, 'ENABLE_RESULT_VECTOR_SCORE_THRESHOLD', False)
            threshold = getattr(app_config, threshold_config_key, 0.65)
        except (ImportError, AttributeError) as e:
            print(f"[WARNING] Unable to load threshold configuration: {e}, using defaults")
            enable_threshold = False
            threshold = 0.65

        # If threshold filtering is disabled, return the original results unchanged
        if not enable_threshold:
            print(f"[DEBUG] {result_type} threshold filtering disabled, returning all {len(results)} results")
            return results

        total_count = len(results)
        min_required = max((total_count + 1) // 2, 1)
        print(f"[DEBUG] {result_type} threshold filtering: total={total_count}, threshold={threshold}, min_required={min_required}")

        # Sort by similarity in descending order (most similar first)
        sorted_results = sorted(results, key=lambda x: x.get('similarity', 0), reverse=True)

        # Collect the results that meet the threshold
        above_threshold = [r for r in sorted_results if r.get('similarity', 0) >= threshold]

        # Apply the filtering logic
        if len(above_threshold) >= min_required:
            # Case 1: enough results meet the threshold, so return exactly those
            filtered_results = above_threshold
            filtered_count = len(above_threshold)
            print(f"[DEBUG] {result_type} filter result: kept {filtered_count}, dropped {total_count - filtered_count} (all kept results meet the threshold)")
        else:
            # Case 2: fewer results meet the threshold than min_required, so force-keep the top min_required
            filtered_results = sorted_results[:min_required]
            above_count = len(above_threshold)
            below_count = min_required - above_count
            filtered_count = min_required
            print(f"[DEBUG] {result_type} filter result: kept {filtered_count}, dropped {total_count - filtered_count} (meeting threshold: {above_count}, force-kept: {below_count})")

        # Print details of the kept results
        for i, result in enumerate(filtered_results):
            similarity = result.get('similarity', 0)
            status = "✓" if similarity >= threshold else "✗"
            print(f"[DEBUG] {result_type} kept {i+1}: similarity={similarity} {status}")

        return filtered_results
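
    # Worked example of the filter above (hypothetical numbers): with 7 candidates and a
    # threshold of 0.65, min_required = max((7 + 1) // 2, 1) = 4. If 5 candidates score
    # >= 0.65, exactly those 5 are returned; if only 2 do, the top 4 by similarity are
    # force-kept, so callers always receive at least half of the candidates.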

    def _apply_error_sql_threshold_filter(self, results: list) -> list:
        """
        Apply the similarity-threshold filtering logic specific to error SQL.

        Unlike the other filter, the error-SQL filter:
        - only returns results whose similarity meets the threshold
        - does not enforce a minimum number of results
        - returns an empty list if every result falls below the threshold

        Args:
            results: Raw result list; each element contains a "similarity" field.

        Returns:
            The filtered result list.
        """
        if not results:
            return results

        # Load the configuration
        try:
            import app_config
            enable_threshold = getattr(app_config, 'ENABLE_RESULT_VECTOR_SCORE_THRESHOLD', False)
            threshold = getattr(app_config, 'RESULT_VECTOR_ERROR_SQL_SCORE_THRESHOLD', 0.5)
        except (ImportError, AttributeError) as e:
            print(f"[WARNING] Unable to load error-SQL threshold configuration: {e}, using defaults")
            enable_threshold = False
            threshold = 0.5

        # If threshold filtering is disabled, return the original results unchanged
        if not enable_threshold:
            print(f"[DEBUG] Error SQL threshold filtering disabled, returning all {len(results)} results")
            return results

        total_count = len(results)
        print(f"[DEBUG] Error SQL threshold filtering: total={total_count}, threshold={threshold}")

        # Sort by similarity in descending order (most similar first)
        sorted_results = sorted(results, key=lambda x: x.get('similarity', 0), reverse=True)

        # Keep only the results that meet the threshold; no minimum count is enforced
        filtered_results = [r for r in sorted_results if r.get('similarity', 0) >= threshold]

        filtered_count = len(filtered_results)
        filtered_out_count = total_count - filtered_count
        if filtered_count > 0:
            print(f"[DEBUG] Error SQL filter result: kept {filtered_count}, dropped {filtered_out_count}")
            # Print details of the kept results
            for i, result in enumerate(filtered_results):
                similarity = result.get('similarity', 0)
                print(f"[DEBUG] Error SQL kept {i+1}: similarity={similarity} ✓")
        else:
            print(f"[DEBUG] Error SQL filter result: all {total_count} results fell below the threshold {threshold}, returning an empty list")

        return filtered_results
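
    # Worked example (hypothetical numbers): with a threshold of 0.5 and similarities
    # [0.72, 0.48, 0.31], only the 0.72 result is returned; if every similarity were below
    # 0.5, this filter would return [] rather than force-keeping anything.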

    def train(
        self,
        question: str | None = None,
        sql: str | None = None,
        ddl: str | None = None,
        documentation: str | None = None,
        plan: TrainingPlan | None = None,
        createdat: str | None = None,
    ):
        if question and not sql:
            raise ValidationError("Please provide a SQL query.")

        if documentation:
            logging.info(f"Adding documentation: {documentation}")
            return self.add_documentation(documentation)

        if sql and question:
            return self.add_question_sql(question=question, sql=sql, createdat=createdat)

        if ddl:
            logging.info(f"Adding ddl: {ddl}")
            return self.add_ddl(ddl)

        if plan:
            for item in plan._plan:
                if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
                    self.add_ddl(item.item_value)
                elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
                    self.add_documentation(item.item_value)
                elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL and item.item_name:
                    self.add_question_sql(question=item.item_name, sql=item.item_value)

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        # Establish the connection
        engine = create_engine(self.connection_string)

        # Query the 'langchain_pg_embedding' table
        query_embedding = "SELECT cmetadata, document FROM langchain_pg_embedding"
        df_embedding = pd.read_sql(query_embedding, engine)

        # List to accumulate the processed rows
        processed_rows = []

        # Process each row in the DataFrame
        for _, row in df_embedding.iterrows():
            custom_id = row["cmetadata"]["id"]
            document = row["document"]

            # Map the ID suffix to a training data type (check "-error_sql" before "-sql")
            if custom_id.endswith("-doc"):
                training_data_type = "documentation"
            elif custom_id.endswith("-error_sql"):
                training_data_type = "error_sql"
            elif custom_id.endswith("-sql"):
                training_data_type = "sql"
            elif custom_id.endswith("-ddl"):
                training_data_type = "ddl"
            else:
                # If the suffix is not recognized, skip this row
                logging.info(f"Skipping row with custom_id {custom_id} due to unrecognized training data type.")
                continue

            if training_data_type in ["sql", "error_sql"]:
                # Convert the document string to a dictionary
                try:
                    doc_dict = ast.literal_eval(document)
                    question = doc_dict.get("question")
                    content = doc_dict.get("sql")
                except (ValueError, SyntaxError):
                    logging.info(f"Skipping row with custom_id {custom_id} due to parsing error.")
                    continue
            elif training_data_type in ["documentation", "ddl"]:
                question = None  # Default value for question
                content = document

            # Append the processed data to the list
            processed_rows.append(
                {"id": custom_id, "question": question, "content": content, "training_data_type": training_data_type}
            )

        # Create a DataFrame from the list of processed rows
        df_processed = pd.DataFrame(processed_rows)
        return df_processed

    def remove_training_data(self, id: str, **kwargs) -> bool:
        # Create the database engine
        engine = create_engine(self.connection_string)

        # SQL DELETE statement
        delete_statement = text(
            """
            DELETE FROM langchain_pg_embedding
            WHERE cmetadata ->> 'id' = :id
            """
        )

        # Connect to the database and execute the delete statement
        with engine.connect() as connection:
            # Start a transaction
            with connection.begin() as transaction:
                try:
                    result = connection.execute(delete_statement, {"id": id})
                    # Commit the transaction if the delete was successful
                    transaction.commit()
                    # Check if any row was deleted and return True or False accordingly
                    return result.rowcount > 0
                except Exception as e:
                    # Roll back the transaction in case of error
                    logging.error(f"An error occurred: {e}")
                    transaction.rollback()
                    return False

    def remove_collection(self, collection_name: str) -> bool:
        engine = create_engine(self.connection_string)

        # Determine the ID suffix to look for based on the collection name
        suffix_map = {"ddl": "ddl", "sql": "sql", "documentation": "doc", "error_sql": "error_sql"}
        suffix = suffix_map.get(collection_name)

        if not suffix:
            logging.info("Invalid collection name. Choose from 'ddl', 'sql', 'documentation', or 'error_sql'.")
            return False

        # SQL query to delete rows based on the suffix; the leading '-' keeps 'sql'
        # from also matching the '-error_sql' IDs
        query = text(
            f"""
            DELETE FROM langchain_pg_embedding
            WHERE cmetadata->>'id' LIKE '%-{suffix}'
            """
        )

        # Execute the deletion within a transaction block
        with engine.connect() as connection:
            with connection.begin() as transaction:
                try:
                    result = connection.execute(query)
                    transaction.commit()  # Explicitly commit the transaction
                    if result.rowcount > 0:
                        logging.info(
                            f"Deleted {result.rowcount} rows from "
                            f"langchain_pg_embedding where collection is {collection_name}."
                        )
                        return True
                    else:
                        logging.info(f"No rows deleted for collection {collection_name}.")
                        return False
                except Exception as e:
                    logging.error(f"An error occurred: {e}")
                    transaction.rollback()  # Roll back in case of error
                    return False

    def generate_embedding(self, *args, **kwargs):
        pass

    # Additional methods for training on and retrieving error SQL examples

    # 1. Ensure the error_sql collection exists
    def _ensure_error_sql_collection(self):
        """Ensure the error_sql collection exists."""
        # The collection is already initialized in __init__; this method exists only for consistency.
        pass

    # 2. Store an incorrect question-SQL pair in the error_sql collection
    def train_error_sql(self, question: str, sql: str, **kwargs) -> str:
        """
        Store an incorrect question-SQL pair in the error_sql collection.
        """
        # Ensure the collection exists
        self._ensure_error_sql_collection()

        # Build the document content in the same format as the existing SQL training data
        question_sql_json = json.dumps(
            {
                "question": question,
                "sql": sql,
                "type": "error_sql"
            },
            ensure_ascii=False,
        )

        # Generate an ID using the same format as the existing methods
        id = str(uuid.uuid4()) + "-error_sql"
        createdat = kwargs.get("createdat")

        # Create the Document object
        doc = Document(
            page_content=question_sql_json,
            metadata={"id": id, "createdat": createdat},
        )

        # Add it to the error_sql collection
        self.error_sql_collection.add_documents([doc], ids=[doc.metadata["id"]])
        return id

    # 3. Retrieve related error SQL examples
    def get_related_error_sql(self, question: str, **kwargs) -> list:
        """
        Retrieve error SQL examples related to the question.
        """
        # Ensure the collection exists
        self._ensure_error_sql_collection()

        try:
            docs_with_scores = self.error_sql_collection.similarity_search_with_score(
                query=question,
                k=self.n_results
            )
            results = []
            for doc, score in docs_with_scores:
                try:
                    # Convert the document content to a dict, consistent with the existing methods
                    base = ast.literal_eval(doc.page_content)
                    # Convert the distance score into a similarity value
                    similarity = round(1 - score, 4)
                    # Print each match individually
                    print(f"[DEBUG] Error SQL Match: {base.get('question', '')} | similarity: {similarity}")
                    # Add the similarity field
                    base["similarity"] = similarity
                    results.append(base)
                except (ValueError, SyntaxError) as e:
                    print(f"Error parsing error SQL document: {e}")
                    continue

            # Apply the error-SQL-specific threshold filtering logic
            filtered_results = self._apply_error_sql_threshold_filter(results)
            return filtered_results
        except Exception as e:
            print(f"Error retrieving error SQL examples: {e}")
            return []
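

# A minimal usage sketch (not part of the class): it assumes a running PostgreSQL database
# with the pgvector extension, and follows the usual Vanna pattern of combining this vector
# store with an LLM chat class. The OpenAI_Chat pairing, connection string, model names,
# and API key below are illustrative placeholders, not requirements of PG_VectorStore.
if __name__ == "__main__":
    from langchain_huggingface import HuggingFaceEmbeddings
    from vanna.openai import OpenAI_Chat

    class MyVanna(PG_VectorStore, OpenAI_Chat):
        def __init__(self, config=None):
            PG_VectorStore.__init__(self, config=config)
            OpenAI_Chat.__init__(self, config=config)

    vn = MyVanna(
        config={
            "connection_string": "postgresql+psycopg://user:password@localhost:5432/vanna",
            "embedding_function": HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2"),
            "n_results": 10,
            "api_key": "sk-...",
            "model": "gpt-4o-mini",
        }
    )

    # Store one DDL statement and one question-SQL pair.
    vn.train(ddl="CREATE TABLE orders (id SERIAL PRIMARY KEY, amount NUMERIC)")
    vn.train(question="What is the total order amount?", sql="SELECT SUM(amount) FROM orders")

    # Retrieve training data related to a new question; each hit carries a similarity score.
    print(vn.get_similar_question_sql("Total amount across all orders?"))
    print(vn.get_related_ddl("orders table definition"))

    # Record a known-bad query so it can later be surfaced as a negative example.
    vn.train_error_sql(question="Total orders?", sql="SELECT SUM(amont) FROM orders")
    print(vn.get_related_error_sql("Total orders?"))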