# pgvector.py

import ast
import json
import logging
import uuid

import pandas as pd
from langchain_core.documents import Document
from langchain_postgres.vectorstores import PGVector
from sqlalchemy import create_engine, text

from vanna.exceptions import ValidationError
from vanna.base import VannaBase
from vanna.types import TrainingPlan, TrainingPlanItem


class PG_VectorStore(VannaBase):
    def __init__(self, config=None):
        if not config or "connection_string" not in config:
            raise ValueError(
                "A valid 'config' dictionary with a 'connection_string' is required.")

        VannaBase.__init__(self, config=config)

        if config and "connection_string" in config:
            self.connection_string = config.get("connection_string")
            self.n_results = config.get("n_results", 10)

        if config and "embedding_function" in config:
            self.embedding_function = config.get("embedding_function")
        else:
            raise ValueError("No embedding_function was found.")
            # from langchain_huggingface import HuggingFaceEmbeddings
            # self.embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

        self.sql_collection = PGVector(
            embeddings=self.embedding_function,
            collection_name="sql",
            connection=self.connection_string,
        )
        self.ddl_collection = PGVector(
            embeddings=self.embedding_function,
            collection_name="ddl",
            connection=self.connection_string,
        )
        self.documentation_collection = PGVector(
            embeddings=self.embedding_function,
            collection_name="documentation",
            connection=self.connection_string,
        )
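    # Example instantiation (a sketch, not from this file): assumes a concrete
    # subclass ("MyVanna" is hypothetical) that combines this store with an LLM,
    # and the standard langchain_postgres connection-string format.
    #
    #     from langchain_huggingface import HuggingFaceEmbeddings
    #
    #     vn = MyVanna(config={
    #         "connection_string": "postgresql+psycopg://user:pass@localhost:5432/vanna",
    #         "embedding_function": HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2"),
    #         "n_results": 10,
    #     })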
    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        question_sql_json = json.dumps(
            {
                "question": question,
                "sql": sql,
            },
            ensure_ascii=False,
        )
        id = str(uuid.uuid4()) + "-sql"
        createdat = kwargs.get("createdat")
        doc = Document(
            page_content=question_sql_json,
            metadata={"id": id, "createdat": createdat},
        )
        self.sql_collection.add_documents([doc], ids=[doc.metadata["id"]])

        return id
    def add_ddl(self, ddl: str, **kwargs) -> str:
        _id = str(uuid.uuid4()) + "-ddl"
        doc = Document(
            page_content=ddl,
            metadata={"id": _id},
        )
        self.ddl_collection.add_documents([doc], ids=[doc.metadata["id"]])
        return _id

    def add_documentation(self, documentation: str, **kwargs) -> str:
        _id = str(uuid.uuid4()) + "-doc"
        doc = Document(
            page_content=documentation,
            metadata={"id": _id},
        )
        self.documentation_collection.add_documents([doc], ids=[doc.metadata["id"]])
        return _id
    def get_collection(self, collection_name):
        match collection_name:
            case "sql":
                return self.sql_collection
            case "ddl":
                return self.ddl_collection
            case "documentation":
                return self.documentation_collection
            case _:
                raise ValueError("Specified collection does not exist.")
    # def get_similar_question_sql(self, question: str) -> list:
    #     documents = self.sql_collection.similarity_search(query=question, k=self.n_results)
    #     return [ast.literal_eval(document.page_content) for document in documents]

    # Extends the original implementation above by also returning the similarity score.
    def get_similar_question_sql(self, question: str) -> list:
        docs_with_scores = self.sql_collection.similarity_search_with_score(
            query=question,
            k=self.n_results
        )

        results = []
        for doc, score in docs_with_scores:
            # Convert the document content into a dict
            base = ast.literal_eval(doc.page_content)
            # Convert the distance score into a similarity value
            similarity = round(1 - score, 4)
            # Print each match individually
            print(f"[DEBUG] SQL Match: {base.get('question', '')} | similarity: {similarity}")
            # Attach the similarity field
            base["similarity"] = similarity
            results.append(base)

        # Apply threshold filtering
        filtered_results = self._apply_score_threshold_filter(
            results,
            "RESULT_VECTOR_SQL_SCORE_THRESHOLD",
            "SQL"
        )

        return filtered_results
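    # Each returned entry is the stored question/SQL pair plus its similarity,
    # e.g. (values purely illustrative):
    #     {"question": "...", "sql": "...", "similarity": 0.8123}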
    def get_related_ddl(self, question: str, **kwargs) -> list:
        docs_with_scores = self.ddl_collection.similarity_search_with_score(
            query=question,
            k=self.n_results
        )

        results = []
        for doc, score in docs_with_scores:
            # Convert the distance score into a similarity value
            similarity = round(1 - score, 4)
            # Print each match individually
            print(f"[DEBUG] DDL Match: {doc.page_content[:50]}... | similarity: {similarity}")
            # Attach the similarity field
            result = {
                "content": doc.page_content,
                "similarity": similarity
            }
            results.append(result)

        # Apply threshold filtering
        filtered_results = self._apply_score_threshold_filter(
            results,
            "RESULT_VECTOR_DDL_SCORE_THRESHOLD",
            "DDL"
        )

        return filtered_results
    def get_related_documentation(self, question: str, **kwargs) -> list:
        docs_with_scores = self.documentation_collection.similarity_search_with_score(
            query=question,
            k=self.n_results
        )

        results = []
        for doc, score in docs_with_scores:
            # Convert the distance score into a similarity value
            similarity = round(1 - score, 4)
            # Print each match individually
            print(f"[DEBUG] Doc Match: {doc.page_content[:50]}... | similarity: {similarity}")
            # Attach the similarity field
            result = {
                "content": doc.page_content,
                "similarity": similarity
            }
            results.append(result)

        # Apply threshold filtering
        filtered_results = self._apply_score_threshold_filter(
            results,
            "RESULT_VECTOR_DOC_SCORE_THRESHOLD",
            "DOC"
        )

        return filtered_results
    def _apply_score_threshold_filter(self, results: list, threshold_config_key: str, result_type: str) -> list:
        """
        Apply similarity-threshold filtering.

        Args:
            results: Raw result list; each element carries a 'similarity' field.
            threshold_config_key: Name of the threshold setting in the configuration.
            result_type: Result type (used for logging).

        Returns:
            The filtered result list.
        """
        if not results:
            return results

        # Load the configuration
        try:
            import app_config
            enable_threshold = getattr(app_config, 'ENABLE_RESULT_VECTOR_SCORE_THRESHOLD', False)
            threshold = getattr(app_config, threshold_config_key, 0.65)
        except (ImportError, AttributeError) as e:
            print(f"[WARNING] Could not load threshold configuration: {e}; using defaults")
            enable_threshold = False
            threshold = 0.65

        # If threshold filtering is disabled, return the results unchanged
        if not enable_threshold:
            print(f"[DEBUG] {result_type} threshold filtering disabled; returning all {len(results)} results")
            return results

        total_count = len(results)
        min_required = max((total_count + 1) // 2, 1)
        print(f"[DEBUG] {result_type} threshold filtering: total={total_count}, threshold={threshold}, min_kept={min_required}")

        # Sort by similarity in descending order (most similar first)
        sorted_results = sorted(results, key=lambda x: x.get('similarity', 0), reverse=True)

        # Collect the results that meet the threshold
        above_threshold = [r for r in sorted_results if r.get('similarity', 0) >= threshold]

        # Apply the filtering logic
        if len(above_threshold) >= min_required:
            # Case 1: enough results meet the threshold; keep only those
            filtered_results = above_threshold
            filtered_count = len(above_threshold)
            print(f"[DEBUG] {result_type} filtering: kept {filtered_count}, dropped {total_count - filtered_count} (all meet the threshold)")
        else:
            # Case 2: fewer than min_required meet the threshold; force-keep the top min_required
            filtered_results = sorted_results[:min_required]
            above_count = len(above_threshold)
            below_count = min_required - above_count
            filtered_count = min_required
            print(f"[DEBUG] {result_type} filtering: kept {filtered_count}, dropped {total_count - filtered_count} (meet threshold: {above_count}, force-kept: {below_count})")

        # Print details for the kept results
        for i, result in enumerate(filtered_results):
            similarity = result.get('similarity', 0)
            status = "✓" if similarity >= threshold else "✗"
            print(f"[DEBUG] {result_type} kept {i+1}: similarity={similarity} {status}")

        return filtered_results
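    # The filter above imports an app_config module that is not part of this file.
    # A minimal sketch of the settings it expects, using the defaults shown above
    # (an assumption, not the project's actual configuration):
    #
    #     # app_config.py
    #     ENABLE_RESULT_VECTOR_SCORE_THRESHOLD = True
    #     RESULT_VECTOR_SQL_SCORE_THRESHOLD = 0.65
    #     RESULT_VECTOR_DDL_SCORE_THRESHOLD = 0.65
    #     RESULT_VECTOR_DOC_SCORE_THRESHOLD = 0.65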
    def train(
        self,
        question: str | None = None,
        sql: str | None = None,
        ddl: str | None = None,
        documentation: str | None = None,
        plan: TrainingPlan | None = None,
        createdat: str | None = None,
    ):
        if question and not sql:
            raise ValidationError("Please provide a SQL query.")

        if documentation:
            logging.info(f"Adding documentation: {documentation}")
            return self.add_documentation(documentation)

        if sql and question:
            return self.add_question_sql(question=question, sql=sql, createdat=createdat)

        if ddl:
            logging.info(f"Adding ddl: {ddl}")
            return self.add_ddl(ddl)

        if plan:
            for item in plan._plan:
                if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
                    self.add_ddl(item.item_value)
                elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
                    self.add_documentation(item.item_value)
                elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL and item.item_name:
                    self.add_question_sql(question=item.item_name, sql=item.item_value)
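    # Example training calls (a sketch; "vn" and the table/column names are illustrative):
    #
    #     vn.train(ddl="CREATE TABLE orders (id INT, amount NUMERIC)")
    #     vn.train(documentation="The orders table stores one row per purchase.")
    #     vn.train(question="What is the total revenue?",
    #              sql="SELECT SUM(amount) FROM orders")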
    def get_training_data(self, **kwargs) -> pd.DataFrame:
        # Establishing the connection
        engine = create_engine(self.connection_string)

        # Querying the 'langchain_pg_embedding' table
        query_embedding = "SELECT cmetadata, document FROM langchain_pg_embedding"
        df_embedding = pd.read_sql(query_embedding, engine)

        # List to accumulate the processed rows
        processed_rows = []

        # Process each row in the DataFrame
        for _, row in df_embedding.iterrows():
            custom_id = row["cmetadata"]["id"]
            document = row["document"]
            training_data_type = "documentation" if custom_id[-3:] == "doc" else custom_id[-3:]

            if training_data_type == "sql":
                # Convert the document string to a dictionary
                try:
                    doc_dict = ast.literal_eval(document)
                    question = doc_dict.get("question")
                    content = doc_dict.get("sql")
                except (ValueError, SyntaxError):
                    logging.info(f"Skipping row with custom_id {custom_id} due to parsing error.")
                    continue
            elif training_data_type in ["documentation", "ddl"]:
                question = None  # Default value for question
                content = document
            else:
                # If the suffix is not recognized, skip this row
                logging.info(f"Skipping row with custom_id {custom_id} due to unrecognized training data type.")
                continue

            # Append the processed data to the list
            processed_rows.append(
                {"id": custom_id, "question": question, "content": content, "training_data_type": training_data_type}
            )

        # Create a DataFrame from the list of processed rows
        df_processed = pd.DataFrame(processed_rows)

        return df_processed
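    # The resulting DataFrame has one row per stored item, with the columns
    # id, question, content, and training_data_type ("sql", "ddl", or "documentation").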
    def remove_training_data(self, id: str, **kwargs) -> bool:
        # Create the database engine
        engine = create_engine(self.connection_string)

        # SQL DELETE statement
        delete_statement = text(
            """
            DELETE FROM langchain_pg_embedding
            WHERE cmetadata ->> 'id' = :id
            """
        )

        # Connect to the database and execute the delete statement
        with engine.connect() as connection:
            # Start a transaction
            with connection.begin() as transaction:
                try:
                    result = connection.execute(delete_statement, {"id": id})
                    # Commit the transaction if the delete was successful
                    transaction.commit()
                    # Check if any row was deleted and return True or False accordingly
                    return result.rowcount > 0
                except Exception as e:
                    # Rollback the transaction in case of error
                    logging.error(f"An error occurred: {e}")
                    transaction.rollback()
                    return False
    def remove_collection(self, collection_name: str) -> bool:
        engine = create_engine(self.connection_string)

        # Determine the suffix to look for based on the collection name
        suffix_map = {"ddl": "ddl", "sql": "sql", "documentation": "doc"}
        suffix = suffix_map.get(collection_name)

        if not suffix:
            logging.info("Invalid collection name. Choose from 'ddl', 'sql', or 'documentation'.")
            return False

        # SQL query to delete rows whose id carries the matching suffix; the suffix is
        # bound as a parameter rather than interpolated into the statement
        query = text(
            """
            DELETE FROM langchain_pg_embedding
            WHERE cmetadata->>'id' LIKE :suffix
            """
        )

        # Execute the deletion within a transaction block
        with engine.connect() as connection:
            with connection.begin() as transaction:
                try:
                    result = connection.execute(query, {"suffix": f"%{suffix}"})
                    transaction.commit()  # Explicitly commit the transaction
                    if result.rowcount > 0:
                        logging.info(
                            f"Deleted {result.rowcount} rows from "
                            f"langchain_pg_embedding where collection is {collection_name}."
                        )
                        return True
                    else:
                        logging.info(f"No rows deleted for collection {collection_name}.")
                        return False
                except Exception as e:
                    logging.error(f"An error occurred: {e}")
                    transaction.rollback()  # Rollback in case of error
                    return False
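    # Example (sketch; "vn" is an instance of a concrete subclass):
    #
    #     vn.remove_collection("sql")   # deletes every stored question/SQL pair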
    def generate_embedding(self, *args, **kwargs):
        pass