import ast
import json
import logging
import uuid

import pandas as pd
from langchain_core.documents import Document
from langchain_postgres.vectorstores import PGVector
from sqlalchemy import create_engine, text

from vanna.base import VannaBase
from vanna.exceptions import ValidationError
from vanna.types import TrainingPlan, TrainingPlanItem


class PG_VectorStore(VannaBase):
    """Vanna vector store backed by Postgres, using langchain-postgres PGVector collections."""

    def __init__(self, config=None):
        if not config or "connection_string" not in config:
            raise ValueError(
                "A valid 'config' dictionary with a 'connection_string' is required."
            )

        VannaBase.__init__(self, config=config)

        self.connection_string = config.get("connection_string")
        self.n_results = config.get("n_results", 10)

        if "embedding_function" in config:
            self.embedding_function = config.get("embedding_function")
        else:
            raise ValueError("No embedding_function was found.")

        # Example embedding function:
        # from langchain_huggingface import HuggingFaceEmbeddings
        # self.embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

        self.sql_collection = PGVector(
            embeddings=self.embedding_function,
            collection_name="sql",
            connection=self.connection_string,
        )
        self.ddl_collection = PGVector(
            embeddings=self.embedding_function,
            collection_name="ddl",
            connection=self.connection_string,
        )
        self.documentation_collection = PGVector(
            embeddings=self.embedding_function,
            collection_name="documentation",
            connection=self.connection_string,
        )

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        question_sql_json = json.dumps(
            {
                "question": question,
                "sql": sql,
            },
            ensure_ascii=False,
        )
        _id = str(uuid.uuid4()) + "-sql"
        createdat = kwargs.get("createdat")
        doc = Document(
            page_content=question_sql_json,
            metadata={"id": _id, "createdat": createdat},
        )
        self.sql_collection.add_documents([doc], ids=[doc.metadata["id"]])
        return _id

    def add_ddl(self, ddl: str, **kwargs) -> str:
        _id = str(uuid.uuid4()) + "-ddl"
        doc = Document(
            page_content=ddl,
            metadata={"id": _id},
        )
        self.ddl_collection.add_documents([doc], ids=[doc.metadata["id"]])
        return _id

    def add_documentation(self, documentation: str, **kwargs) -> str:
        _id = str(uuid.uuid4()) + "-doc"
        doc = Document(
            page_content=documentation,
            metadata={"id": _id},
        )
        self.documentation_collection.add_documents([doc], ids=[doc.metadata["id"]])
        return _id
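
    # Note: the "-sql" / "-ddl" / "-doc" suffixes attached to the IDs above are
    # what get_training_data() and remove_collection() later rely on to tell the
    # three kinds of training data apart when reading the raw
    # langchain_pg_embedding table directly.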

    def get_collection(self, collection_name):
        match collection_name:
            case "sql":
                return self.sql_collection
            case "ddl":
                return self.ddl_collection
            case "documentation":
                return self.documentation_collection
            case _:
                raise ValueError("Specified collection does not exist.")

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        documents = self.sql_collection.similarity_search(query=question, k=self.n_results)
        return [ast.literal_eval(document.page_content) for document in documents]

    def get_related_ddl(self, question: str, **kwargs) -> list:
        documents = self.ddl_collection.similarity_search(query=question, k=self.n_results)
        return [document.page_content for document in documents]

    def get_related_documentation(self, question: str, **kwargs) -> list:
        documents = self.documentation_collection.similarity_search(query=question, k=self.n_results)
        return [document.page_content for document in documents]

    def train(
        self,
        question: str | None = None,
        sql: str | None = None,
        ddl: str | None = None,
        documentation: str | None = None,
        plan: TrainingPlan | None = None,
        createdat: str | None = None,
    ):
        """Route training input to the matching collection: documentation, a question/SQL pair, DDL, or a TrainingPlan."""
        if question and not sql:
            raise ValidationError("Please provide a SQL query.")

        if documentation:
            logging.info(f"Adding documentation: {documentation}")
            return self.add_documentation(documentation)

        if sql and question:
            return self.add_question_sql(question=question, sql=sql, createdat=createdat)

        if ddl:
            logging.info(f"Adding ddl: {ddl}")
            return self.add_ddl(ddl)

        if plan:
            for item in plan._plan:
                if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
                    self.add_ddl(item.item_value)
                elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
                    self.add_documentation(item.item_value)
                elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL and item.item_name:
                    self.add_question_sql(question=item.item_name, sql=item.item_value)
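
    # The methods below bypass the PGVector API and work on the
    # langchain_pg_embedding table that langchain-postgres maintains: the stored
    # text lives in its "document" column and per-document metadata (including
    # the custom "id" assigned by the add_* methods) lives in the JSONB
    # "cmetadata" column.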

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        # Establish the connection
        engine = create_engine(self.connection_string)

        # Query the 'langchain_pg_embedding' table
        query_embedding = "SELECT cmetadata, document FROM langchain_pg_embedding"
        df_embedding = pd.read_sql(query_embedding, engine)

        # List to accumulate the processed rows
        processed_rows = []

        # Process each row in the DataFrame
        for _, row in df_embedding.iterrows():
            custom_id = row["cmetadata"]["id"]
            document = row["document"]
            training_data_type = "documentation" if custom_id[-3:] == "doc" else custom_id[-3:]

            if training_data_type == "sql":
                # Convert the document string to a dictionary
                try:
                    doc_dict = ast.literal_eval(document)
                    question = doc_dict.get("question")
                    content = doc_dict.get("sql")
                except (ValueError, SyntaxError):
                    logging.info(f"Skipping row with custom_id {custom_id} due to parsing error.")
                    continue
            elif training_data_type in ["documentation", "ddl"]:
                question = None  # Default value for question
                content = document
            else:
                # If the suffix is not recognized, skip this row
                logging.info(f"Skipping row with custom_id {custom_id} due to unrecognized training data type.")
                continue

            # Append the processed data to the list
            processed_rows.append(
                {"id": custom_id, "question": question, "content": content, "training_data_type": training_data_type}
            )

        # Create a DataFrame from the list of processed rows
        df_processed = pd.DataFrame(processed_rows)
        return df_processed

    def remove_training_data(self, id: str, **kwargs) -> bool:
        # Create the database engine
        engine = create_engine(self.connection_string)

        # SQL DELETE statement
        delete_statement = text(
            """
            DELETE FROM langchain_pg_embedding
            WHERE cmetadata ->> 'id' = :id
            """
        )

        # Connect to the database and execute the delete statement
        with engine.connect() as connection:
            # Start a transaction
            with connection.begin() as transaction:
                try:
                    result = connection.execute(delete_statement, {"id": id})
                    # Commit the transaction if the delete was successful
                    transaction.commit()
                    # Check if any row was deleted and return True or False accordingly
                    return result.rowcount > 0
                except Exception as e:
                    # Roll back the transaction in case of error
                    logging.error(f"An error occurred: {e}")
                    transaction.rollback()
                    return False

    def remove_collection(self, collection_name: str) -> bool:
        engine = create_engine(self.connection_string)

        # Determine the suffix to look for based on the collection name
        suffix_map = {"ddl": "ddl", "sql": "sql", "documentation": "doc"}
        suffix = suffix_map.get(collection_name)
        if not suffix:
            logging.info("Invalid collection name. Choose from 'ddl', 'sql', or 'documentation'.")
            return False

        # SQL query to delete rows based on the condition
        query = text(
            f"""
            DELETE FROM langchain_pg_embedding
            WHERE cmetadata->>'id' LIKE '%{suffix}'
            """
        )

        # Execute the deletion within a transaction block
        with engine.connect() as connection:
            with connection.begin() as transaction:
                try:
                    result = connection.execute(query)
                    transaction.commit()  # Explicitly commit the transaction
                    if result.rowcount > 0:
                        logging.info(
                            f"Deleted {result.rowcount} rows from "
                            f"langchain_pg_embedding where collection is {collection_name}."
                        )
                        return True
                    else:
                        logging.info(f"No rows deleted for collection {collection_name}.")
                        return False
                except Exception as e:
                    logging.error(f"An error occurred: {e}")
                    transaction.rollback()  # Roll back in case of error
                    return False

    def generate_embedding(self, *args, **kwargs):
        # Embeddings are computed by the configured embedding_function inside the
        # PGVector collections, so this VannaBase hook is left as a no-op here.
        pass
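

# --- Usage sketch (illustrative; not part of the class above) ---
# A minimal, hedged example of how this store is typically wired up. VannaBase
# is abstract, so PG_VectorStore is normally combined with an LLM class via
# multiple inheritance. The LLM mixin, connection string and example table below
# are placeholders and assumptions, not requirements of this file; the embedding
# function mirrors the commented-out example in __init__.
#
# from langchain_huggingface import HuggingFaceEmbeddings
# from vanna.openai import OpenAI_Chat  # assumed mixin; any Vanna LLM class works
#
# class MyVanna(PG_VectorStore, OpenAI_Chat):
#     def __init__(self, config=None):
#         PG_VectorStore.__init__(self, config=config)
#         OpenAI_Chat.__init__(self, config=config)
#
# vn = MyVanna(config={
#     "connection_string": "postgresql+psycopg://user:password@localhost:5432/vanna",
#     "embedding_function": HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2"),
#     "n_results": 10,
#     # ... plus whatever settings the chosen LLM mixin expects (e.g. API key, model)
# })
#
# vn.train(ddl="CREATE TABLE customers (id INT PRIMARY KEY, name TEXT)")
# vn.train(question="How many customers are there?",
#          sql="SELECT COUNT(*) FROM customers")
# print(vn.get_training_data())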