
Add support for pgvector; add a data-directory configuration option for the training data.

wangxq, 4 weeks ago
Commit a7a36b0e9a

+ 31 - 5
app_config.py

@@ -5,7 +5,9 @@ import os
 load_dotenv()
 
 # LLM model type ("qwen" or "deepseek")
-MODEL_TYPE = "qwen"
+LLM_MODEL_NAME = "qwen"
+# Vector database type: chromadb or pgvector
+VECTOR_DB_NAME = "pgvector"
 
 # DeepSeek model configuration
 DEEPSEEK_CONFIG = {
@@ -54,7 +56,31 @@ APP_DB_CONFIG = {
 # ChromaDB configuration
 # CHROMADB_PATH = "."  
 
-# Batch-processing configuration
-BATCH_PROCESSING_ENABLED = True
-BATCH_SIZE = 10
-MAX_WORKERS = 4
+# PgVector connection settings (the vector database, independent of the business database)
+PGVECTOR_CONFIG = {
+    "host": "192.168.67.1",
+    "port": 5432,
+    "dbname": "pgvector_db",
+    "user": os.getenv("PGVECTOR_DB_USER"),
+    "password": os.getenv("PGVECTOR_DB_PASSWORD")
+}
+
+# Batch-processing settings for the training script
+# These settings are only used by training/run_training.py to batch-optimize training
+TRAINING_BATCH_PROCESSING_ENABLED = True    # Whether to batch-process training data
+TRAINING_BATCH_SIZE = 10                    # Number of training items per batch
+TRAINING_MAX_WORKERS = 4                    # Maximum number of worker threads for training batches
+
+# Training data path
+# The following forms are supported:
+# 1. Relative paths (starting with .):
+#    "./training/data"     - training/data under the project root
+#    "../data"             - the data directory one level above the project root
+# 2. Absolute paths:
+#    "/home/user/data"     - Linux absolute path
+#    "C:/data"             - Windows absolute path
+#    "D:\\training\\data"  - Windows absolute path (escaped backslashes)
+# 3. Relative paths (not starting with .):
+#    "training/data"       - relative to the project root
+#    "my_data"             - the my_data folder under the project root
+TRAINING_DATA_PATH = "./training/data"
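For reference, both vanna_llm_factory.py and training/run_training.py (later in this commit) assemble a plain PostgreSQL URL from PGVECTOR_CONFIG. A minimal sketch of that assembly, reading the credentials from .env via load_dotenv() exactly as the config above does (host, port, and dbname are the placeholders from this commit):

```python
import os
from dotenv import load_dotenv

load_dotenv()  # PGVECTOR_DB_USER / PGVECTOR_DB_PASSWORD come from .env

# Same shape as PGVECTOR_CONFIG above.
pg_config = {
    "host": "192.168.67.1",
    "port": 5432,
    "dbname": "pgvector_db",
    "user": os.getenv("PGVECTOR_DB_USER"),
    "password": os.getenv("PGVECTOR_DB_PASSWORD"),
}

# The same URL format is used by both consumers of this config.
connection_string = (
    f"postgresql://{pg_config['user']}:{pg_config['password']}"
    f"@{pg_config['host']}:{pg_config['port']}/{pg_config['dbname']}"
)
print(connection_string)
```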

+ 1 - 0
custompgvector/__init__.py

@@ -0,0 +1 @@
+from .pgvector import PG_VectorStore

+ 254 - 0
custompgvector/pgvector.py

@@ -0,0 +1,254 @@
+import ast
+import json
+import logging
+import uuid
+
+import pandas as pd
+from langchain_core.documents import Document
+from langchain_postgres.vectorstores import PGVector
+from sqlalchemy import create_engine, text
+
+from vanna.exceptions import ValidationError
+from vanna.base import VannaBase
+from vanna.types import TrainingPlan, TrainingPlanItem
+
+
+class PG_VectorStore(VannaBase):
+    def __init__(self, config=None):
+        if not config or "connection_string" not in config:
+            raise ValueError(
+                "A valid 'config' dictionary with a 'connection_string' is required.")
+
+        VannaBase.__init__(self, config=config)
+
+        if config and "connection_string" in config:
+            self.connection_string = config.get("connection_string")
+            self.n_results = config.get("n_results", 10)
+
+        if config and "embedding_function" in config:
+            self.embedding_function = config.get("embedding_function")
+        else:
+            raise ValueError("No embedding_function was found.")
+            # from langchain_huggingface import HuggingFaceEmbeddings
+            # self.embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
+
+        self.sql_collection = PGVector(
+            embeddings=self.embedding_function,
+            collection_name="sql",
+            connection=self.connection_string,
+        )
+        self.ddl_collection = PGVector(
+            embeddings=self.embedding_function,
+            collection_name="ddl",
+            connection=self.connection_string,
+        )
+        self.documentation_collection = PGVector(
+            embeddings=self.embedding_function,
+            collection_name="documentation",
+            connection=self.connection_string,
+        )
+
+    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
+        question_sql_json = json.dumps(
+            {
+                "question": question,
+                "sql": sql,
+            },
+            ensure_ascii=False,
+        )
+        id = str(uuid.uuid4()) + "-sql"
+        createdat = kwargs.get("createdat")
+        doc = Document(
+            page_content=question_sql_json,
+            metadata={"id": id, "createdat": createdat},
+        )
+        self.sql_collection.add_documents([doc], ids=[doc.metadata["id"]])
+
+        return id
+
+    def add_ddl(self, ddl: str, **kwargs) -> str:
+        _id = str(uuid.uuid4()) + "-ddl"
+        doc = Document(
+            page_content=ddl,
+            metadata={"id": _id},
+        )
+        self.ddl_collection.add_documents([doc], ids=[doc.metadata["id"]])
+        return _id
+
+    def add_documentation(self, documentation: str, **kwargs) -> str:
+        _id = str(uuid.uuid4()) + "-doc"
+        doc = Document(
+            page_content=documentation,
+            metadata={"id": _id},
+        )
+        self.documentation_collection.add_documents([doc], ids=[doc.metadata["id"]])
+        return _id
+
+    def get_collection(self, collection_name):
+        match collection_name:
+            case "sql":
+                return self.sql_collection
+            case "ddl":
+                return self.ddl_collection
+            case "documentation":
+                return self.documentation_collection
+            case _:
+                raise ValueError("Specified collection does not exist.")
+
+    def get_similar_question_sql(self, question: str) -> list:
+        documents = self.sql_collection.similarity_search(query=question, k=self.n_results)
+        return [ast.literal_eval(document.page_content) for document in documents]
+
+    def get_related_ddl(self, question: str, **kwargs) -> list:
+        documents = self.ddl_collection.similarity_search(query=question, k=self.n_results)
+        return [document.page_content for document in documents]
+
+    def get_related_documentation(self, question: str, **kwargs) -> list:
+        documents = self.documentation_collection.similarity_search(query=question, k=self.n_results)
+        return [document.page_content for document in documents]
+
+    def train(
+        self,
+        question: str | None = None,
+        sql: str | None = None,
+        ddl: str | None = None,
+        documentation: str | None = None,
+        plan: TrainingPlan | None = None,
+        createdat: str | None = None,
+    ):
+        if question and not sql:
+            raise ValidationError("Please provide a SQL query.")
+
+        if documentation:
+            logging.info(f"Adding documentation: {documentation}")
+            return self.add_documentation(documentation)
+
+        if sql and question:
+            return self.add_question_sql(question=question, sql=sql, createdat=createdat)
+
+        if ddl:
+            logging.info(f"Adding ddl: {ddl}")
+            return self.add_ddl(ddl)
+
+        if plan:
+            for item in plan._plan:
+                if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
+                    self.add_ddl(item.item_value)
+                elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
+                    self.add_documentation(item.item_value)
+                elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL and item.item_name:
+                    self.add_question_sql(question=item.item_name, sql=item.item_value)
+
+    def get_training_data(self, **kwargs) -> pd.DataFrame:
+        # Establishing the connection
+        engine = create_engine(self.connection_string)
+
+        # Querying the 'langchain_pg_embedding' table
+        query_embedding = "SELECT cmetadata, document FROM langchain_pg_embedding"
+        df_embedding = pd.read_sql(query_embedding, engine)
+
+        # List to accumulate the processed rows
+        processed_rows = []
+
+        # Process each row in the DataFrame
+        for _, row in df_embedding.iterrows():
+            custom_id = row["cmetadata"]["id"]
+            document = row["document"]
+            training_data_type = "documentation" if custom_id[-3:] == "doc" else custom_id[-3:]
+
+            if training_data_type == "sql":
+                # Convert the document string to a dictionary
+                try:
+                    doc_dict = ast.literal_eval(document)
+                    question = doc_dict.get("question")
+                    content = doc_dict.get("sql")
+                except (ValueError, SyntaxError):
+                    logging.info(f"Skipping row with custom_id {custom_id} due to parsing error.")
+                    continue
+            elif training_data_type in ["documentation", "ddl"]:
+                question = None  # Default value for question
+                content = document
+            else:
+                # If the suffix is not recognized, skip this row
+                logging.info(f"Skipping row with custom_id {custom_id} due to unrecognized training data type.")
+                continue
+
+            # Append the processed data to the list
+            processed_rows.append(
+                {"id": custom_id, "question": question, "content": content, "training_data_type": training_data_type}
+            )
+
+        # Create a DataFrame from the list of processed rows
+        df_processed = pd.DataFrame(processed_rows)
+
+        return df_processed
+
+    def remove_training_data(self, id: str, **kwargs) -> bool:
+        # Create the database engine
+        engine = create_engine(self.connection_string)
+
+        # SQL DELETE statement
+        delete_statement = text(
+            """
+            DELETE FROM langchain_pg_embedding
+            WHERE cmetadata ->> 'id' = :id
+        """
+        )
+
+        # Connect to the database and execute the delete statement
+        with engine.connect() as connection:
+            # Start a transaction
+            with connection.begin() as transaction:
+                try:
+                    result = connection.execute(delete_statement, {"id": id})
+                    # Commit the transaction if the delete was successful
+                    transaction.commit()
+                    # Check if any row was deleted and return True or False accordingly
+                    return result.rowcount > 0
+                except Exception as e:
+                    # Rollback the transaction in case of error
+                    logging.error(f"An error occurred: {e}")
+                    transaction.rollback()
+                    return False
+
+    def remove_collection(self, collection_name: str) -> bool:
+        engine = create_engine(self.connection_string)
+
+        # Determine the suffix to look for based on the collection name
+        suffix_map = {"ddl": "ddl", "sql": "sql", "documentation": "doc"}
+        suffix = suffix_map.get(collection_name)
+
+        if not suffix:
+            logging.info("Invalid collection name. Choose from 'ddl', 'sql', or 'documentation'.")
+            return False
+
+        # SQL query to delete rows based on the condition
+        query = text(
+            f"""
+            DELETE FROM langchain_pg_embedding
+            WHERE cmetadata->>'id' LIKE '%{suffix}'
+        """
+        )
+
+        # Execute the deletion within a transaction block
+        with engine.connect() as connection:
+            with connection.begin() as transaction:
+                try:
+                    result = connection.execute(query)
+                    transaction.commit()  # Explicitly commit the transaction
+                    if result.rowcount > 0:
+                        logging.info(
+                            f"Deleted {result.rowcount} rows from "
+                            f"langchain_pg_embedding where collection is {collection_name}."
+                        )
+                        return True
+                    else:
+                        logging.info(f"No rows deleted for collection {collection_name}.")
+                        return False
+                except Exception as e:
+                    logging.error(f"An error occurred: {e}")
+                    transaction.rollback()  # Rollback in case of error
+                    return False
+
+    def generate_embedding(self, *args, **kwargs):
+        pass
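A minimal sketch of the id convention the class relies on: every stored document gets a UUID plus a `-sql`, `-ddl`, or `-doc` suffix, and `get_training_data()` / `remove_collection()` later map that suffix back to a training data type.

```python
import uuid

def classify(custom_id: str) -> str:
    # Mirrors the suffix check in get_training_data(): "doc" maps to
    # "documentation", everything else is used as-is ("sql", "ddl").
    suffix = custom_id[-3:]
    return "documentation" if suffix == "doc" else suffix

for kind in ("sql", "ddl", "doc"):
    doc_id = f"{uuid.uuid4()}-{kind}"
    print(doc_id, "->", classify(doc_id))
```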

+ 53 - 0
docs/pgvector.md

@@ -0,0 +1,53 @@
+## Using PgVector as the vector store
+
+### 1. The tables below are created automatically by langchain; you can use them as a reference when designing your own schema.
+
+```sql
+-- Create the embedding table
+create table public.langchain_pg_embedding
+(
+    id            varchar not null
+        primary key,
+    collection_id uuid
+        references public.langchain_pg_collection
+            on delete cascade,
+    embedding     vector,
+    document      varchar,
+    cmetadata     jsonb
+);
+
+alter table public.langchain_pg_embedding
+    owner to postgres;
+
+create index ix_cmetadata_gin
+    on public.langchain_pg_embedding using gin (cmetadata jsonb_path_ops);
+
+-- Create the collection table
+create table public.langchain_pg_collection
+(
+    uuid      uuid    not null
+        primary key,
+    name      varchar not null
+        unique,
+    cmetadata json
+);
+
+alter table public.langchain_pg_collection
+    owner to postgres;
+
+
+```
+
+### 2. To make testing easier, I drop the foreign key on the embedding table.
+
+```sql  
+alter table public.langchain_pg_embedding
+    drop constraint langchain_pg_embedding_collection_id_fkey;
+```
+
+
+
+
+
+
+

+ 147 - 0
docs/run_training说明.md

@@ -0,0 +1,147 @@
+## File extensions and their handler functions
+
+### File-type detection order
+The code checks the file type in the following order:
+
+1. **`.ddl`** → DDL file
+2. **`.md` or `.markdown`** → documentation file  
+3. **`_pair.json` or `_pairs.json`** → JSON question-SQL pair file
+4. **`_pair.sql` or `_pairs.sql`** → formatted question-SQL pair file
+5. **`.sql` (not ending in `_pair.sql` or `_pairs.sql`)** → SQL example file
+6. **anything else** → skipped
+
+### 1. **DDL files** (`.ddl`)
+- **Handler**: `train_ddl_statements()`
+- **Training function called**: `train_ddl()`
+- **File format**: 
+  - Semicolons (`;`) are used as separators
+  - Individual DDL statements are separated by semicolons
+  - Example:
+    ```sql
+    CREATE TABLE users (
+        id INT PRIMARY KEY,
+        name VARCHAR(100)
+    );
+    CREATE TABLE orders (
+        id INT PRIMARY KEY,
+        user_id INT REFERENCES users(id)
+    );
+    ```
+
+### 2. **Documentation files** (`.md`, `.markdown`)
+- **Handler**: `train_documentation_blocks()`
+- **Training function called**: `train_documentation()`
+- **File format**:
+  - **Markdown files**: split automatically by heading level (`#`, `##`, `###`)
+  - **Non-Markdown files**: `---` is used as the separator
+  - Example:
+    ```markdown
+    # Users table
+    The users table stores basic information about every user in the system...
+    
+    ## Fields
+    - id: unique user identifier
+    - name: user name
+    
+    ### Notes
+    User names must be unique...
+    ```
+
+### 3. **SQL example files** (`.sql`, excluding `_pair.sql` and `_pairs.sql`)
+- **Handler**: `train_sql_examples()`
+- **Training function called**: `train_sql_example()`
+- **File format**:
+  - Semicolons (`;`) are used as separators
+  - Individual SQL examples are separated by semicolons
+  - Example:
+    ```sql
+    SELECT * FROM users WHERE age > 18;
+    SELECT COUNT(*) FROM orders WHERE status = 'completed';
+    SELECT u.name, COUNT(o.id) FROM users u LEFT JOIN orders o ON u.id = o.user_id GROUP BY u.id;
+    ```
+
+### 4. **Formatted question-SQL pair files** (`_pair.sql`, `_pairs.sql`)
+- **Handler**: `train_formatted_question_sql_pairs()`
+- **Training function called**: `train_question_sql_pair()`
+- **File format**:
+  - Uses `Question:` and `SQL:` markers
+  - Pairs are separated by two blank lines
+  - Both single-line and multi-line SQL are supported
+  - Example:
+    ```
+    Question: Find all adult users
+    SQL: SELECT * FROM users WHERE age >= 18;
+
+    Question: Count the number of orders per user
+    SQL: 
+    SELECT u.name, COUNT(o.id) as order_count
+    FROM users u 
+    LEFT JOIN orders o ON u.id = o.user_id 
+    GROUP BY u.id, u.name;
+    ```
+
+### 5. **JSON question-SQL pair files** (`_pair.json`, `_pairs.json`)
+- **Handler**: `train_json_question_sql_pairs()`
+- **Training function called**: `train_question_sql_pair()`
+- **File format**:
+  - A standard JSON array
+  - Each object contains a `question` and a `sql` field
+  - Example:
+    ```json
+    [
+        {
+            "question": "Find all adult users",
+            "sql": "SELECT * FROM users WHERE age >= 18"
+        },
+        {
+            "question": "Count the number of orders per user",
+            "sql": "SELECT u.name, COUNT(o.id) as order_count FROM users u LEFT JOIN orders o ON u.id = o.user_id GROUP BY u.id, u.name"
+        }
+    ]
+    ```
+
+### 6. **Legacy question-SQL pair files** (other formats, handled by `train_question_sql_pairs()`)
+- **Handler**: `train_question_sql_pairs()`
+- **Training function called**: `train_question_sql_pair()`
+- **File format**:
+  - One question-SQL pair per line
+  - `::` separates the question from the SQL
+  - Example:
+    ```
+    Find all adult users::SELECT * FROM users WHERE age >= 18
+    Count all orders::SELECT COUNT(*) FROM orders
+    ```
+
+
+
+## Statistics
+
+After training finishes, the following counts are reported:
+- Number of DDL files
+- Number of documentation files  
+- Number of SQL example files
+- Number of formatted question-SQL pair files
+- Number of JSON question-SQL pair files
+
+This design lets the training system flexibly handle training data in several different formats, covering the data-preparation needs of different scenarios.
+
+
+# Batch-processing settings for the training script
+# These settings are only used by training/run_training.py to batch-optimize training
+# Batching improves training throughput but increases memory use and complexity
+# 
+# TRAINING_BATCH_PROCESSING_ENABLED: 
+#   - True: enable batching; several training items are packed and processed together
+#   - False: process items one at a time (more stable, but slower)
+# 
+# TRAINING_BATCH_SIZE: number of training items per batch
+#   - larger values: faster, but use more memory
+#   - smaller values: less memory, but slower
+#   - recommended range: 5-20
+# 
+# TRAINING_MAX_WORKERS: maximum number of worker threads for training batches
+#   - recommended: 1-2x the number of CPU cores
+#   - too many threads can cause resource contention
+TRAINING_BATCH_PROCESSING_ENABLED = True    # Whether to batch-process training data
+TRAINING_BATCH_SIZE = 10                    # Number of training items per batch
+TRAINING_MAX_WORKERS = 4                    # Maximum number of worker threads for training batches
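The detection order listed at the top of this document can be condensed into a small dispatcher sketch. This is an illustration only, assuming simple filename tests; the real routing lives in training/run_training.py:

```python
def classify_training_file(filename: str) -> str:
    """Return the handler a file would be routed to, following the
    priority order described above (a sketch, not the actual code)."""
    name = filename.lower()
    if name.endswith(".ddl"):
        return "train_ddl_statements"
    if name.endswith((".md", ".markdown")):
        return "train_documentation_blocks"
    if name.endswith(("_pair.json", "_pairs.json")):
        return "train_json_question_sql_pairs"
    if name.endswith(("_pair.sql", "_pairs.sql")):
        return "train_formatted_question_sql_pairs"
    if name.endswith(".sql"):
        return "train_sql_examples"
    return "skipped"

# Hypothetical filenames, used only to show the routing.
for f in ["tables.ddl", "schema.md", "qa_pairs.json", "qa_pair.sql", "examples.sql", "notes.txt"]:
    print(f, "->", classify_training_file(f))
```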

+ 99 - 0
docs/training_path_examples.md

@@ -0,0 +1,99 @@
+# Training data path configuration examples
+
+In `app_config.py`, you can point the training script at your training data by changing `TRAINING_DATA_PATH`.
+
+## Supported forms
+
+### 1. Relative paths (starting with .)
+```python
+# The training/data folder under the project root
+TRAINING_DATA_PATH = "./training/data"
+
+# The my_training_data folder under the project root
+TRAINING_DATA_PATH = "./my_training_data"
+
+# The data folder one level above the project root
+TRAINING_DATA_PATH = "../data"
+
+# The training_files folder one level above the project root
+TRAINING_DATA_PATH = "../training_files"
+```
+
+### 2. Absolute paths
+
+#### Linux/Mac
+```python
+# Linux absolute path
+TRAINING_DATA_PATH = "/home/username/training_data"
+
+# Mac absolute path
+TRAINING_DATA_PATH = "/Users/username/Documents/training_data"
+```
+
+#### Windows
+```python
+# Windows absolute path (forward slashes)
+TRAINING_DATA_PATH = "C:/training_data"
+TRAINING_DATA_PATH = "D:/Projects/my_training_data"
+
+# Windows absolute path (backslashes must be escaped)
+TRAINING_DATA_PATH = "C:\\training_data"
+TRAINING_DATA_PATH = "D:\\Projects\\my_training_data"
+```
+
+### 3. Relative paths (not starting with .)
+```python
+# Relative to the project root
+TRAINING_DATA_PATH = "training/data"      # same as "./training/data"
+TRAINING_DATA_PATH = "my_data"            # same as "./my_data"
+TRAINING_DATA_PATH = "data/training"      # same as "./data/training"
+```
+
+## Usage examples
+
+### Default
+```python
+# Use the project's default training data directory
+TRAINING_DATA_PATH = "./training/data"
+```
+
+### Custom local directory
+```python
+# Use a custom folder under the project root
+TRAINING_DATA_PATH = "./my_training_files"
+```
+
+### External directory
+```python
+# Linux/Mac
+TRAINING_DATA_PATH = "/home/user/Documents/sql_training_data"
+
+# Windows
+TRAINING_DATA_PATH = "D:/SQL_Training_Data"
+```
+
+## Command-line override
+
+Even when a path is set in the config file, you can still override it temporarily on the command line:
+
+```bash
+# Use the path from the config file
+python training/run_training.py
+
+# Temporarily use a different path
+python training/run_training.py --data_path "./custom_data"
+python training/run_training.py --data_path "/absolute/path/to/data"
+python training/run_training.py --data_path "C:/Windows/Path/To/Data"
+```
+
+## Path verification
+
+When the training script runs, it prints how the path was resolved:
+```
+===== Training data path configuration =====
+Path in config file: ./training/data
+Resolved absolute path: /full/path/to/project/training/data
+==============================
+```
+
+This lets you confirm that the path was resolved correctly.
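The resolution rules above mirror the resolve_training_data_path() helper that this commit adds to training/run_training.py; condensed into a sketch:

```python
import os

def resolve_training_data_path(config_path: str, project_root: str) -> str:
    """Condensed sketch of the logic in training/run_training.py."""
    if os.path.isabs(config_path):
        return config_path  # absolute paths are used as-is
    # "./data"-style and bare relative paths both resolve against the project root
    return os.path.join(project_root, config_path)

print(resolve_training_data_path("./training/data", "/opt/project"))
print(resolve_training_data_path("/home/user/data", "/opt/project"))
```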

+ 23 - 0
embedding_function.py

@@ -81,6 +81,29 @@ class EmbeddingFunction:
                 
         return embeddings
     
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """
+        Generate embedding vectors for a list of documents (LangChain interface)
+        
+        Args:
+            texts: the documents to embed
+            
+        Returns:
+            List[List[float]]: the list of embedding vectors
+        """
+        return self.__call__(texts)
+    
+    def embed_query(self, text: str) -> List[float]:
+        """
+        Generate an embedding vector for a query string (LangChain interface)
+        
+        Args:
+            text: the query text to embed
+            
+        Returns:
+            List[float]: the embedding vector
+        """
+        return self.generate_embedding(text)
     
     def generate_embedding(self, text: str) -> List[float]:
         """

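The two methods added above are what langchain_postgres.PGVector calls on the embeddings object; since the project's EmbeddingFunction does not subclass LangChain's Embeddings, the commit appears to rely on duck typing, so any object exposing embed_documents() and embed_query() would work. A self-contained illustration of that shape (FakeEmbeddings is hypothetical, used only to show the interface):

```python
from typing import List

class FakeEmbeddings:
    """Hypothetical stand-in; the project passes its EmbeddingFunction instead."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        # One vector per input document.
        return [[float(len(t)), 0.0] for t in texts]

    def embed_query(self, text: str) -> List[float]:
        # A single vector for the query.
        return [float(len(text)), 0.0]

emb = FakeEmbeddings()
print(emb.embed_documents(["hello", "world!"]))
print(emb.embed_query("hello"))
```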
+ 2 - 0
requirements.txt

@@ -1,3 +1,5 @@
 vanna[chromadb,openai,postgres]==0.7.9
 flask==3.1.1
 plotly==5.22.0
+langchain-core==0.3.64
+langchain-postgres==0.0.14

+ 153 - 33
training/run_training.py

@@ -393,6 +393,67 @@ def process_training_files(data_path):
         
     return True
 
+def check_pgvector_connection():
+    """Check whether the PgVector database is reachable
+    
+    Returns:
+        bool: True if the connection succeeds, otherwise False
+    """
+    import app_config
+    from sqlalchemy import create_engine, text
+    
+    try:
+        # Build the connection string
+        pg_config = app_config.PGVECTOR_CONFIG
+        connection_string = f"postgresql://{pg_config['user']}:{pg_config['password']}@{pg_config['host']}:{pg_config['port']}/{pg_config['dbname']}"
+        
+        print(f"Testing the PgVector database connection...")
+        print(f"Connection target: {pg_config['host']}:{pg_config['port']}/{pg_config['dbname']}")
+        
+        # Create the engine and test the connection
+        engine = create_engine(connection_string)
+        
+        with engine.connect() as connection:
+            # Basic connectivity test
+            result = connection.execute(text("SELECT 1"))
+            result.fetchone()
+            
+            # Check whether the pgvector extension is installed
+            try:
+                result = connection.execute(text("SELECT extname FROM pg_extension WHERE extname = 'vector'"))
+                extension_exists = result.fetchone() is not None
+                
+                if extension_exists:
+                    print("✓ pgvector extension is installed")
+                else:
+                    print("⚠ Warning: the pgvector extension is not installed; please install it first")
+                    
+            except Exception as ext_e:
+                print(f"⚠ Could not check the pgvector extension status: {ext_e}")
+            
+            # Check whether the training data table exists
+            try:
+                result = connection.execute(text("SELECT tablename FROM pg_tables WHERE tablename = 'langchain_pg_embedding'"))
+                table_exists = result.fetchone() is not None
+                
+                if table_exists:
+                    # Count the rows in the table
+                    result = connection.execute(text("SELECT COUNT(*) FROM langchain_pg_embedding"))
+                    count = result.fetchone()[0]
+                    print(f"✓ Training data table exists and currently holds {count} rows")
+                else:
+                    print("ℹ Training data table has not been created yet (it is created automatically on the first training run)")
+                    
+            except Exception as table_e:
+                print(f"⚠ Could not check the training data table status: {table_e}")
+        
+        print("✓ PgVector database connection test succeeded")
+        return True
+        
+    except Exception as e:
+        print(f"✗ PgVector database connection failed: {e}")
+        return False
+
 def main():
     """Main entry point: configure and run the training pipeline"""
     
@@ -402,43 +463,90 @@ def main():
     
     # Parse command-line arguments
 parser = argparse.ArgumentParser(description='Train the Vanna NL2SQL model')
-    parser.add_argument('--data_path', type=str, default=os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data'),
-                        help='Training data directory (default: training/data)')
+    
+    # Get the default path and resolve it
+    def resolve_training_data_path():
+        """Resolve the training data path from configuration"""
+        config_path = getattr(app_config, 'TRAINING_DATA_PATH', './training/data')
+        
+        # Absolute paths are returned unchanged
+        if os.path.isabs(config_path):
+            return config_path
+        
+        # Paths starting with . are resolved against the project root
+        if config_path.startswith('./') or config_path.startswith('../'):
+            project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+            return os.path.join(project_root, config_path)
+        
+        # Anything else is also resolved against the project root
+        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+        return os.path.join(project_root, config_path)
+    
+    default_path = resolve_training_data_path()
+    
+    parser.add_argument('--data_path', type=str, default=default_path,
+                        help='Training data directory (default: from app_config.TRAINING_DATA_PATH)')
     args = parser.parse_args()
     
     # Use a Path object for cross-platform path handling
     data_path = Path(args.data_path)
     
+    # Show how the path was resolved
+    print(f"\n===== Training data path configuration =====")
+    print(f"Path in config file: {getattr(app_config, 'TRAINING_DATA_PATH', 'not configured')}")
+    print(f"Resolved absolute path: {os.path.abspath(data_path)}")
+    print("==============================")
+    
     # Set the project root directory
     project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 
     # Check the embedding model connection
     check_embedding_model_connection()
     
-    # Print ChromaDB information
-    try:
+    # Show information for the configured vector database type
+    vector_db_type = app_config.VECTOR_DB_NAME.lower()
+    
+    if vector_db_type == "chromadb":
+        # Print ChromaDB information
         try:
-            import chromadb
-            chroma_version = chromadb.__version__
-        except ImportError:
-            chroma_version = "unknown"
-        
-        # Try to locate the ChromaDB file currently in use
-        chroma_file = "chroma.sqlite3"  # default file name
-        
-        # Use the project root as the ChromaDB file path
-        db_file_path = os.path.join(project_root, chroma_file)
+            try:
+                import chromadb
+                chroma_version = chromadb.__version__
+            except ImportError:
+                chroma_version = "unknown"
+            
+            # Try to locate the ChromaDB file currently in use
+            chroma_file = "chroma.sqlite3"  # default file name
+            
+            # Use the project root as the ChromaDB file path
+            db_file_path = os.path.join(project_root, chroma_file)
 
-        if os.path.exists(db_file_path):
-            file_size = os.path.getsize(db_file_path) / 1024  # KB
-            print(f"\n===== ChromaDB database: {os.path.abspath(db_file_path)} (size: {file_size:.2f} KB) =====")
-        else:
-            print(f"\n===== ChromaDB database file not found at: {os.path.abspath(db_file_path)} =====")
+            if os.path.exists(db_file_path):
+                file_size = os.path.getsize(db_file_path) / 1024  # KB
+                print(f"\n===== ChromaDB database: {os.path.abspath(db_file_path)} (size: {file_size:.2f} KB) =====")
+            else:
+                print(f"\n===== ChromaDB database file not found at: {os.path.abspath(db_file_path)} =====")
+                
+            # Print the ChromaDB version
+            print(f"===== ChromaDB client library version: {chroma_version} =====\n")
+        except Exception as e:
+            print(f"\n===== Could not obtain ChromaDB information: {e} =====\n")
             
-        # Print the ChromaDB version
-        print(f"===== ChromaDB client library version: {chroma_version} =====\n")
-    except Exception as e:
-        print(f"\n===== Could not obtain ChromaDB information: {e} =====\n")
+    elif vector_db_type == "pgvector":
+        # Print PgVector information and test the connection
+        print(f"\n===== PgVector database configuration =====")
+        pg_config = app_config.PGVECTOR_CONFIG
+        print(f"Database address: {pg_config['host']}:{pg_config['port']}")
+        print(f"Database name: {pg_config['dbname']}")
+        print(f"User: {pg_config['user']}")
+        print("==============================\n")
+        
+        # Test the PgVector connection
+        if not check_pgvector_connection():
+            print("PgVector database connection failed; training aborted.")
+            sys.exit(1)
+    else:
+        print(f"\n===== Unknown vector database type: {vector_db_type} =====\n")
     
     # Process the training files
     process_successful = process_training_files(data_path)
@@ -455,20 +563,26 @@ def main():
         vn = create_vanna_instance()
         
         # Run different verification logic depending on the vector database type
-        # Simplified since only ChromaDB is in use
         try:
             training_data = vn.get_training_data()
             if training_data is not None and not training_data.empty:
-                # get_training_data usually prints the count itself; add a summary here
-                print(f"Retrieved {len(training_data)} training records from ChromaDB for verification.")
+                print(f"✓ Retrieved {len(training_data)} training records from {vector_db_type.upper()} for verification.")
+                
+                # Show counts per training data type
+                if 'training_data_type' in training_data.columns:
+                    type_counts = training_data['training_data_type'].value_counts()
+                    print("Training data type counts:")
+                    for data_type, count in type_counts.items():
+                        print(f"  {data_type}: {count} records")
+                        
             elif training_data is not None and training_data.empty:
-                 print("No training data was found in ChromaDB.")
+                print(f"⚠ No training data was found in {vector_db_type.upper()}.")
             else: # training_data is None
-                print("Could not get training data from Vanna (it may have returned None). Check the connection and the Vanna implementation.")
+                print(f"Could not get training data from Vanna (it may have returned None). Check the {vector_db_type.upper()} connection and the Vanna implementation.")
 
         except Exception as e:
-            print(f"Training data verification failed: {e}")
-            print("Check the ChromaDB connection and table structure.")
+            print(f"Training data verification failed: {e}")
+            print(f"Check the {vector_db_type.upper()} connection and table structure.")
     else:
         print("\n===== No training files were found or processed; training aborted =====")
     
@@ -477,9 +591,15 @@ def main():
     print(f"Model name: {app_config.EMBEDDING_CONFIG.get('model_name')}")
     print(f"Embedding dimension: {app_config.EMBEDDING_CONFIG.get('embedding_dimension')}")
     print(f"API service: {app_config.EMBEDDING_CONFIG.get('base_url')}")
-    # Print the ChromaDB path
-    chroma_display_path = os.path.abspath(project_root)
-    print(f"Vector database: ChromaDB ({chroma_display_path})")
+    
+    # Show vector database info based on the configuration
+    if vector_db_type == "chromadb":
+        chroma_display_path = os.path.abspath(project_root)
+        print(f"Vector database: ChromaDB ({chroma_display_path})")
+    elif vector_db_type == "pgvector":
+        pg_config = app_config.PGVECTOR_CONFIG
+        print(f"Vector database: PgVector ({pg_config['host']}:{pg_config['port']}/{pg_config['dbname']})")
+    
     print("===== Training pipeline complete =====\n")
 
 if __name__ == "__main__":

+ 11 - 9
training/vanna_trainer.py

@@ -31,13 +31,14 @@ if hasattr(app_config, 'EMBEDDING_CONFIG'):
         print(f"API service: {app_config.EMBEDDING_CONFIG['base_url']}")
 print("==============================")
 
-# Get other settings from app_config
-BATCH_PROCESSING_ENABLED = app_config.BATCH_PROCESSING_ENABLED
-BATCH_SIZE = app_config.BATCH_SIZE
-MAX_WORKERS = app_config.MAX_WORKERS
+# Get the training batch-processing settings from app_config
+BATCH_PROCESSING_ENABLED = app_config.TRAINING_BATCH_PROCESSING_ENABLED
+BATCH_SIZE = app_config.TRAINING_BATCH_SIZE
+MAX_WORKERS = app_config.TRAINING_MAX_WORKERS
 
 
-# Data batch processor
+# Training data batch processor
+# A batch processor dedicated to the training flow; it groups training items together to improve throughput
 class BatchProcessor:
     def __init__(self, batch_size=BATCH_SIZE, max_workers=MAX_WORKERS):
         self.batch_size = batch_size
@@ -51,7 +52,7 @@ class BatchProcessor:
         # Whether batching is enabled
         self.batch_enabled = BATCH_PROCESSING_ENABLED       
 
-        print(f"[DEBUG] Batch processor initialized: enabled={self.batch_enabled}, batch size={self.batch_size}, max workers={self.max_workers}")
+        print(f"[DEBUG] Training batch processor initialized: enabled={self.batch_enabled}, batch size={self.batch_size}, max workers={self.max_workers}")
     
     def add_item(self, batch_type: str, item: Dict[str, Any]):
         """Add an item to the batch queue"""
@@ -152,15 +153,16 @@ class BatchProcessor:
             # Clear the queues
             self.batches = defaultdict(list)
         
-        print("[INFO] All batch items completed")
+        print("[INFO] All training batch items completed")
     
     def shutdown(self):
         """Shut down the processor and the thread pool"""
         self.flush_all()
         self.executor.shutdown(wait=True)
-        print("[INFO] Batch processor shut down")
+        print("[INFO] Training batch processor shut down")
 
-# Create the global batch processor instance
+# Create the global training batch processor instance
+# Used by all training functions for batch optimization
 batch_processor = BatchProcessor()
 
 # Batch-enhanced versions of the original training functions

+ 48 - 37
vanna_llm_factory.py

@@ -1,12 +1,12 @@
 """
 Vanna LLM factory, focused on ChromaDB with a simplified configuration.
 """
+import app_config, os
 from vanna.chromadb import ChromaDB_VectorStore  # 从 Vanna 系统获取
 from customqianwen.Custom_QianwenAI_chat import QianWenAI_Chat
 from customdeepseek.custom_deepseek_chat import DeepSeekChat
-import app_config 
 from embedding_function import get_embedding_function
-import os
+from custompgvector import PG_VectorStore
 
 class Vanna_Qwen_ChromaDB(ChromaDB_VectorStore, QianWenAI_Chat):
     def __init__(self, config=None):
@@ -18,6 +18,24 @@ class Vanna_DeepSeek_ChromaDB(ChromaDB_VectorStore, DeepSeekChat):
         ChromaDB_VectorStore.__init__(self, config=config)
         DeepSeekChat.__init__(self, config=config)
 
+class Vanna_Qwen_PGVector(PG_VectorStore, QianWenAI_Chat):
+    def __init__(self, config=None):
+        PG_VectorStore.__init__(self, config=config)
+        QianWenAI_Chat.__init__(self, config=config)
+
+class Vanna_DeepSeek_PGVector(PG_VectorStore, DeepSeekChat):
+    def __init__(self, config=None):
+        PG_VectorStore.__init__(self, config=config)
+        DeepSeekChat.__init__(self, config=config)
+
+# LLM / vector-store combination map
+LLM_VECTOR_DB_MAP = {
+    ('deepseek', 'chromadb'): Vanna_DeepSeek_ChromaDB,
+    ('deepseek', 'pgvector'): Vanna_DeepSeek_PGVector,
+    ('qwen', 'chromadb'): Vanna_Qwen_ChromaDB,
+    ('qwen', 'pgvector'): Vanna_Qwen_PGVector,
+}
+
 def create_vanna_instance(config_module=None):
     """
     Factory function: create and initialize a Vanna instance (LLM- and ChromaDB-specific version)
@@ -31,55 +49,48 @@ def create_vanna_instance(config_module=None):
     if config_module is None:
         config_module = app_config
 
-    model_type = config_module.MODEL_TYPE.lower()
+    llm_model_name  = config_module.LLM_MODEL_NAME.lower()
+    vector_db_name = config_module.VECTOR_DB_NAME.lower()   
+
+    if (llm_model_name, vector_db_name) not in LLM_VECTOR_DB_MAP:
+        raise ValueError(f"Unsupported model type: {llm_model_name} or vector database type: {vector_db_name}")
     
     config = {}
-    if model_type == "deepseek":
+    if llm_model_name == "deepseek":
         config = config_module.DEEPSEEK_CONFIG.copy()
-        print(f"Creating DeepSeek model instance, model: {config['model']}")
-        # Check the API key
-        if not config.get("api_key"):
-            print(f"\nError: the DeepSeek API key is not set or is empty")
-            print(f"Please set the DEEPSEEK_API_KEY environment variable in the .env file")
-            print(f"Cannot continue, exiting\n")
-            import sys
-            sys.exit(1)
-    elif model_type == "qwen":
+        print(f"Creating DeepSeek model instance, model: {config.get('model', 'deepseek-chat')}")
+    elif llm_model_name == "qwen":
         config = config_module.QWEN_CONFIG.copy()
-        print(f"Creating Qwen model instance, model: {config['model']}")
-        # Check the API key
-        if not config.get("api_key"):
-            print(f"\nError: the Qwen API key is not set or is empty")
-            print(f"Please set the QWEN_API_KEY environment variable in the .env file")
-            print(f"Cannot continue, exiting\n")
-            import sys
-            sys.exit(1)
+        print(f"Creating Qwen model instance, model: {config.get('model', 'qwen-plus-latest')}")
     else:
-        raise ValueError(f"Unsupported model type: {model_type}") 
+        raise ValueError(f"Unsupported model type: {llm_model_name}") 
+    
+    if vector_db_name == "chromadb":
+        config["path"] = os.path.dirname(os.path.abspath(__file__))
+        print(f"Configured ChromaDB as the vector database, path: {config['path']}")
+    elif vector_db_name == "pgvector":
+        # Build the PostgreSQL connection string
+        pg_config = config_module.PGVECTOR_CONFIG
+        connection_string = f"postgresql://{pg_config['user']}:{pg_config['password']}@{pg_config['host']}:{pg_config['port']}/{pg_config['dbname']}"
+        config["connection_string"] = connection_string
+        print(f"Configured PgVector as the vector database, connection string: {connection_string}")
+    else:
+        raise ValueError(f"Unsupported vector database type: {vector_db_name}")    
     
     embedding_function = get_embedding_function()
 
     config["embedding_function"] = embedding_function
     print(f"Configured the embedding model from EMBEDDING_CONFIG: {config_module.EMBEDDING_CONFIG['model_name']}, dimension: {config_module.EMBEDDING_CONFIG['embedding_dimension']}")
     
-    # Set the ChromaDB path to the project root
-    project_root = os.path.dirname(os.path.abspath(__file__))
-    config["path"] = project_root
-    print(f"Configured ChromaDB as the vector database, path: {project_root}")
-    
-    vn = None
-    if model_type == "deepseek":
-        vn = Vanna_DeepSeek_ChromaDB(config=config)
-        print("Created DeepSeek+ChromaDB instance")
-    elif model_type == "qwen":
-        vn = Vanna_Qwen_ChromaDB(config=config)
-        print("Created Qwen+ChromaDB instance")
+    key = (llm_model_name, vector_db_name)
+    cls = LLM_VECTOR_DB_MAP.get(key)
+    if cls is None:
+        raise ValueError(f"Unsupported combination: model type={llm_model_name}, vector database type={vector_db_name}")
     
-    if vn is None:
-        raise ValueError(f"Could not create a Vanna instance; unsupported model type: {model_type}")
+    vn = cls(config=config)
 
     vn.connect_to_postgres(**config_module.APP_DB_CONFIG)           
-    print(f"Connected to the business database: "
+    print(f"Connected to the PostgreSQL business database: "
           f"{config_module.APP_DB_CONFIG['host']}:"
           f"{config_module.APP_DB_CONFIG['port']}/"
           f"{config_module.APP_DB_CONFIG['dbname']}")
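A usage sketch of the updated factory, assuming app_config.py sets LLM_MODEL_NAME to "qwen" and VECTOR_DB_NAME to "pgvector", and that both the PgVector database and the business database are reachable:

```python
from vanna_llm_factory import create_vanna_instance

# Returns a Vanna_Qwen_PGVector instance for the qwen + pgvector combination.
vn = create_vanna_instance()

# Training data is written to the PgVector collections managed by PG_VectorStore.
vn.train(ddl="CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100));")

# Questions are answered against the business database configured in APP_DB_CONFIG
# (ask() is the standard Vanna entry point inherited from VannaBase).
print(vn.ask("How many users are there?"))
```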