Explorar o código

Initial commit.

wangxq hai 1 mes
achega
dd31177d44

+ 36 - 0
.gitignore

@@ -0,0 +1,36 @@
+# 忽略 Python 缓存和虚拟环境
+__pycache__/
+*.py[cod]
+venv/
+.venv/
+.env
+.venv
+
+# 忽略日志文件
+*.log
+
+# 忽略操作系统文件
+.DS_Store
+Thumbs.db
+
+# 忽略编译生成文件
+*.class
+*.o
+*.exe
+
+# 忽略数据库文件
+*.db
+*.sqlite
+*.sqlite3
+
+# 忽略 node_modules(前端项目)
+node_modules/
+
+/training/data/
+
+
+# 忽略所有一级UUID目录
+/[0-9a-fA-F]*-[0-9a-fA-F]*-[0-9a-fA-F]*-[0-9a-fA-F]*-[0-9a-fA-F]*/
+
+
+test/

+ 0 - 0
README.md


+ 62 - 0
app.py

@@ -0,0 +1,62 @@
+import chainlit as cl
+from chainlit.input_widget import Select
+import vanna as vn
+import os
+
+vn.set_api_key(os.environ['VANNA_API_KEY'])
+vn.set_model('chinook')
+vn.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')
+
@cl.step(root=True, language="sql", name="Vanna")
async def gen_query(human_query: str):
    """Translate a natural-language question into a SQL query via Vanna."""
    return vn.generate_sql(human_query)
+
@cl.step(root=True, name="Vanna")
async def execute_query(query):
    """Run *query* against the connected DB; preview the first rows in the step output."""
    step = cl.context.current_step
    result_df = vn.run_sql(query)
    step.output = result_df.head().to_markdown(index=False)
    return result_df
+
@cl.step(name="Plot", language="python")
async def plot(human_query, sql, df):
    """Generate Plotly code for the result set and build the figure from it."""
    step = cl.context.current_step
    code = vn.generate_plotly_code(question=human_query, sql=sql, df=df)
    figure = vn.get_plotly_figure(plotly_code=code, df=df)
    step.output = code
    return figure
+
@cl.step(type="run", root=True, name="Vanna")
async def chain(human_query: str):
    """End-to-end pipeline: question -> SQL -> dataframe -> inline chart message."""
    sql_query = await gen_query(human_query)
    df = await execute_query(sql_query)
    fig = await plot(human_query, sql_query, df)
    chart = cl.Plotly(name="chart", figure=fig, display="inline")
    await cl.Message(content=human_query, elements=[chart], author="Vanna").send()
+
@cl.on_message
async def main(message: cl.Message):
    """Chainlit entry point: forward every user message through the pipeline."""
    await chain(message.content)
+
@cl.on_chat_start
async def setup():
    """Initialise a chat session: show the Vanna avatar and a model selector.

    The selected model is delivered through Chainlit's settings mechanism
    (``on_settings_update`` / user session); the original stored it in an
    unused local variable, which has been removed.
    """
    await cl.Avatar(
        name="Vanna",
        url="https://app.vanna.ai/vanna.svg",
    ).send()

    await cl.ChatSettings(
        [
            Select(
                id="Model",
                label="OpenAI - Model",
                values=["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-4-32k"],
                initial_index=0,
            )
        ]
    ).send()

+ 59 - 0
app_config.py

@@ -0,0 +1,59 @@
from dotenv import load_dotenv
import os

# Load environment variables from the project's .env file.
load_dotenv()

# Which LLM family to use ("qwen" or "deepseek").
MODEL_TYPE = "qwen"

# DeepSeek model configuration.
DEEPSEEK_CONFIG = {
    "api_key": os.getenv("DEEPSEEK_API_KEY"),  # API key read from the environment
    "model": "deepseek-reasoner",  # deepseek-chat, deepseek-reasoner
    "allow_llm_to_see_data": True,
    "temperature": 0.7,
    "n_results": 6,
    "language": "Chinese",
    "enable_thinking": False  # custom flag: whether streaming "thinking" mode is enabled
}


# Qwen model configuration.
QWEN_CONFIG = {
    "model": "qwen-plus",
    "allow_llm_to_see_data": True,
    "temperature": 0.7,
    "n_results": 6,
    "language": "Chinese",
    "enable_thinking": False  # custom flag: streaming "thinking" mode, qwen3 models only
}
# Other candidate models:
# qwen3-30b-a3b
# qwen3-235b-a22b
# qwen-plus-latest
# qwen-plus

# Remote embedding service configuration.
EMBEDDING_CONFIG = {
    "model_name": "BAAI/bge-m3",
    "api_key": os.getenv("EMBEDDING_API_KEY"),
    "base_url": os.getenv("EMBEDDING_BASE_URL"),
    "embedding_dimension": 1024
}


# Application (business) database connection settings.
APP_DB_CONFIG = {
    "host": "192.168.67.1",
    "port": 5432,
    "dbname": "bank_db",
    "user": os.getenv("APP_DB_USER"),
    "password": os.getenv("APP_DB_PASSWORD")
}

# ChromaDB configuration.
# CHROMADB_PATH = "."

# Batch-processing configuration.
BATCH_PROCESSING_ENABLED = True
BATCH_SIZE = 10
MAX_WORKERS = 4

+ 1 - 0
customdeepseek/__init__.py

@@ -0,0 +1 @@
+from .custom_deepseek_chat import DeepSeekChat

+ 114 - 0
customdeepseek/custom_deepseek_chat.py

@@ -0,0 +1,114 @@
+import os
+
+from openai import OpenAI
+from vanna.base import VannaBase
+#from base import VannaBase
+
+
+# from vanna.chromadb import ChromaDB_VectorStore
+
+# class DeepSeekVanna(ChromaDB_VectorStore, DeepSeekChat):
+#     def __init__(self, config=None):
+#         ChromaDB_VectorStore.__init__(self, config=config)
+#         DeepSeekChat.__init__(self, config=config)
+
+# vn = DeepSeekVanna(config={"api_key": "sk-************", "model": "deepseek-chat"})
+
+
class DeepSeekChat(VannaBase):
    """Vanna LLM backend that talks to the DeepSeek API via the OpenAI client.

    Requires a ``config`` dict containing at least ``api_key``; ``model``
    defaults to ``deepseek-chat`` and ``temperature`` to 0.7.
    """

    def __init__(self, config=None):
        VannaBase.__init__(self, config=config)

        print("...DeepSeekChat init...")
        if config is None:
            raise ValueError(
                "For DeepSeek, config must be provided with an api_key and model"
            )
        if "api_key" not in config:
            raise ValueError("config must contain a DeepSeek api_key")

        if "model" not in config:
            config["model"] = "deepseek-chat"  # fall back to the default model
            print(f"未指定模型,使用默认模型: {config['model']}")

        # Defaults, overridable through config.
        self.temperature = config.get("temperature", 0.7)
        self.model = config["model"]

        print("传入的 config 参数如下:")
        for key, value in config.items():
            if key != "api_key":  # never print the API key
                print(f"  {key}: {value}")

        # Standard OpenAI client pointed at the DeepSeek-compatible endpoint.
        self.client = OpenAI(
            api_key=config["api_key"],
            base_url="https://api.deepseek.com/v1"
        )

    # NOTE: the original annotated these helpers "-> any", which is the
    # builtin function, not a type; they return OpenAI-style message dicts.
    def system_message(self, message: str) -> dict:
        """Wrap *message* as a system-role chat message."""
        print(f"system_content: {message}")
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> dict:
        """Wrap *message* as a user-role chat message."""
        print(f"\nuser_content: {message}")
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        """Wrap *message* as an assistant-role chat message."""
        print(f"assistant_content: {message}")
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send a list of chat messages to DeepSeek and return the reply text.

        Raises:
            Exception: if *prompt* is None or empty.
        """
        if prompt is None:
            raise Exception("Prompt is None")

        if len(prompt) == 0:
            raise Exception("Prompt is empty")

        # Rough token estimate: ~4 characters per token (logging only).
        num_tokens = 0
        for message in prompt:
            num_tokens += len(message["content"]) / 4

        # kwargs override the configured model.
        model = kwargs.get("model", self.model)

        print(f"\nUsing model {model} for {num_tokens} tokens (approx)")

        chat_params = {
            "model": model,
            "messages": prompt,
            "temperature": kwargs.get("temperature", self.temperature),
        }

        try:
            chat_response = self.client.chat.completions.create(**chat_params)
            # Return only the generated text.
            return chat_response.choices[0].message.content
        except Exception as e:
            print(f"DeepSeek API调用失败: {e}")
            raise

    def generate_sql(self, question: str, **kwargs) -> str:
        """Generate SQL via the parent class, then unescape "\\_" -> "_".

        DeepSeek sometimes escapes underscores in identifiers; the
        replacement fixes that without touching anything else.
        """
        sql = super().generate_sql(question, **kwargs)
        sql = sql.replace("\\_", "_")
        return sql

    def generate_question(self, sql: str, **kwargs) -> str:
        """Infer a Chinese natural-language question from *sql*.

        Overridden because the base implementation produced English questions.
        """
        prompt = [
            self.system_message(
                "请你根据下方SQL语句推测用户的业务提问,只返回清晰的自然语言问题,问题要使用中文,不要包含任何解释或SQL内容,也不要出现表名。"
            ),
            self.user_message(sql)
        ]
        response = self.submit_prompt(prompt, **kwargs)
        return response

+ 169 - 0
customqianwen/Custom_QianwenAI_chat.py

@@ -0,0 +1,169 @@
+import os
+from openai import OpenAI
+from vanna.base import VannaBase
+
+
class QianWenAI_Chat(VannaBase):
    """Vanna LLM backend for Alibaba QianWen models (DashScope endpoint).

    Accepts either a ready-made OpenAI-compatible ``client`` or a ``config``
    dict with ``api_key`` and optionally ``base_url``, ``model``/``engine``,
    ``temperature`` and ``enable_thinking``.
    """

    def __init__(self, client=None, config=None):
        print("...QianWenAI_Chat init...")
        VannaBase.__init__(self, config=config)

        # Guard against config=None: the original iterated self.config.items()
        # and tested membership on config unconditionally, which crashed
        # before the documented "config is None and client is None" fallback
        # below could ever run.
        print("传入的 config 参数如下:")
        for key, value in (self.config or {}).items():
            print(f"  {key}: {value}")

        # Default parameters - can be overridden via config.
        self.temperature = 0.7

        if config and "temperature" in config:
            print(f"temperature is changed to: {config['temperature']}")
            self.temperature = config["temperature"]

        if config and "api_type" in config:
            raise Exception(
                "Passing api_type is now deprecated. Please pass an OpenAI client instead."
            )

        if config and "api_base" in config:
            raise Exception(
                "Passing api_base is now deprecated. Please pass an OpenAI client instead."
            )

        if config and "api_version" in config:
            raise Exception(
                "Passing api_version is now deprecated. Please pass an OpenAI client instead."
            )

        if client is not None:
            self.client = client
            return

        if config is None and client is None:
            # Plain OpenAI client configured from the environment.
            self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
            return

        if "api_key" in config:
            if "base_url" not in config:
                # Default to the DashScope OpenAI-compatible endpoint.
                self.client = OpenAI(api_key=config["api_key"],
                                     base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
            else:
                self.client = OpenAI(api_key=config["api_key"],
                                     base_url=config["base_url"])

    # NOTE: the original annotated these helpers "-> any" (the builtin
    # function, not a type); they return OpenAI-style message dicts.
    def system_message(self, message: str) -> dict:
        """Wrap *message* as a system-role chat message."""
        print(f"system_content: {message}")
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> dict:
        """Wrap *message* as a user-role chat message."""
        print(f"\nuser_content: {message}")
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        """Wrap *message* as an assistant-role chat message."""
        print(f"assistant_content: {message}")
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send the chat prompt to QianWen and return the reply text.

        When ``enable_thinking`` is set (kwargs take precedence over config)
        the request is streamed and any "thinking" chunks are printed
        separately before the concatenated content is returned.

        Raises:
            Exception: if *prompt* is None or empty.
        """
        if prompt is None:
            raise Exception("Prompt is None")

        if len(prompt) == 0:
            raise Exception("Prompt is empty")

        # Rough token estimate: ~4 characters per token.
        num_tokens = 0
        for message in prompt:
            num_tokens += len(message["content"]) / 4

        # kwargs override config; default False.  Guarded for config=None.
        enable_thinking = kwargs.get(
            "enable_thinking", (self.config or {}).get("enable_thinking", False)
        )

        common_params = {
            "messages": prompt,
            "stop": None,
            "temperature": self.temperature,
        }

        # Thinking mode requires streaming; enable_thinking itself is not a
        # parameter the QianWen API accepts directly.
        if enable_thinking:
            common_params["stream"] = True

        model = None
        # Resolve model/engine: kwargs first, then config, then a size heuristic.
        if kwargs.get("model", None) is not None:
            model = kwargs.get("model", None)
            common_params["model"] = model
        elif kwargs.get("engine", None) is not None:
            engine = kwargs.get("engine", None)
            common_params["engine"] = engine
            model = engine
        elif self.config is not None and "engine" in self.config:
            common_params["engine"] = self.config["engine"]
            model = self.config["engine"]
        elif self.config is not None and "model" in self.config:
            common_params["model"] = self.config["model"]
            model = self.config["model"]
        else:
            if num_tokens > 3500:
                model = "qwen-long"
            else:
                model = "qwen-plus"
            common_params["model"] = model

        print(f"\nUsing model {model} for {num_tokens} tokens (approx)")

        if enable_thinking:
            # Streaming mode with thinking enabled.
            print("使用流式处理模式,启用thinking功能")

            response_stream = self.client.chat.completions.create(**common_params)

            collected_thinking = []
            collected_content = []

            for chunk in response_stream:
                # Thinking chunks, if the SDK exposes them on the chunk object.
                if hasattr(chunk, 'thinking') and chunk.thinking:
                    collected_thinking.append(chunk.thinking)

                # Regular content deltas.
                if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content:
                    collected_content.append(chunk.choices[0].delta.content)

            if collected_thinking:
                print("Model thinking process:", "".join(collected_thinking))

            return "".join(collected_content)
        else:
            # Non-streaming mode.
            print("使用非流式处理模式")
            response = self.client.chat.completions.create(**common_params)

            # NOTE(review): membership test on a Choice object is a legacy
            # no-op with the modern OpenAI client; kept for behavior parity.
            for choice in response.choices:
                if "text" in choice:
                    return choice.text

            # Fall back to the first choice's message content (may be empty).
            return response.choices[0].message.content

    def generate_question(self, sql: str, **kwargs) -> str:
        """Infer a Chinese natural-language question from *sql*.

        Overridden because the base implementation produced English questions.
        """
        prompt = [
            self.system_message(
                "请你根据下方SQL语句推测用户的业务提问,只返回清晰的自然语言问题,不要包含任何解释或SQL内容,也不要出现表名,问题要使用中文,并以问号结尾。"
            ),
            self.user_message(sql)
        ]
        response = self.submit_prompt(prompt, **kwargs)
        return response

+ 400 - 0
customqianwen/Custom_QiawenAI_chat_cn.py

@@ -0,0 +1,400 @@
+"""
+中文千问AI实现
+基于对源码的正确理解,实现正确的方法
+"""
+import os
+from openai import OpenAI
+from vanna.base import VannaBase
+from typing import List, Dict, Any, Optional
+
+
class QianWenAI_Chat_CN(VannaBase):
    """Chinese-language QianWen chat backend, inheriting directly from VannaBase.

    Implements the correctly named prompt hooks (``get_sql_prompt`` etc.)
    and produces all prompts/answers in Chinese.
    """

    def __init__(self, client=None, config=None):
        """Initialise the Chinese QianWen instance.

        Args:
            client: optional OpenAI-compatible client.
            config: configuration dict (api_key, base_url, model, temperature,
                enable_thinking, ...).
        """
        print("初始化QianWenAI_Chat_CN...")
        VannaBase.__init__(self, config=config)

        # Guard against config=None: the original iterated self.config.items()
        # and tested membership on config unconditionally, which crashed
        # before the "config is None and client is None" fallback could run.
        print("传入的 config 参数如下:")
        for key, value in (self.config or {}).items():
            print(f"  {key}: {value}")

        # All responses are produced in Chinese.
        self.language = "Chinese"

        # Default parameters - may be overridden via config.
        self.temperature = 0.7

        if config and "temperature" in config:
            print(f"temperature is changed to: {config['temperature']}")
            self.temperature = config["temperature"]

        if config and "api_type" in config:
            raise Exception(
                "Passing api_type is now deprecated. Please pass an OpenAI client instead."
            )

        if config and "api_base" in config:
            raise Exception(
                "Passing api_base is now deprecated. Please pass an OpenAI client instead."
            )

        if config and "api_version" in config:
            raise Exception(
                "Passing api_version is now deprecated. Please pass an OpenAI client instead."
            )

        if client is not None:
            self.client = client
            return

        if config is None and client is None:
            self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
            return

        if "api_key" in config:
            if "base_url" not in config:
                # Default to the DashScope OpenAI-compatible endpoint.
                self.client = OpenAI(api_key=config["api_key"],
                                    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
            else:
                self.client = OpenAI(api_key=config["api_key"],
                                    base_url=config["base_url"])

        print("中文千问AI初始化完成")

    def _response_language(self) -> str:
        """Return the instruction asking the model to answer in Chinese."""
        return "请用中文回答。"

    def system_message(self, message: str) -> dict:
        """Wrap *message* as a system-role chat message."""
        print(f"[DEBUG] 系统消息: {message}")
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> dict:
        """Wrap *message* as a user-role chat message."""
        print(f"[DEBUG] 用户消息: {message}")
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        """Wrap *message* as an assistant-role chat message."""
        print(f"[DEBUG] 助手消息: {message}")
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send the chat prompt to the LLM and return the reply text.

        When ``enable_thinking`` is set (kwargs take precedence over config)
        the request is streamed and "thinking" chunks are printed separately.

        Raises:
            Exception: if *prompt* is None or empty.
        """
        if prompt is None:
            raise Exception("Prompt is None")

        if len(prompt) == 0:
            raise Exception("Prompt is empty")

        # Rough token estimate: ~4 characters per token.
        num_tokens = 0
        for message in prompt:
            num_tokens += len(message["content"]) / 4

        # kwargs override config; default False.  Guarded for config=None.
        enable_thinking = kwargs.get(
            "enable_thinking", (self.config or {}).get("enable_thinking", False)
        )

        common_params = {
            "messages": prompt,
            "stop": None,
            "temperature": self.temperature,
        }

        # Thinking mode requires streaming; enable_thinking itself is not a
        # parameter the QianWen API accepts directly.
        if enable_thinking:
            common_params["stream"] = True

        model = None
        # Resolve model/engine: kwargs first, then config, then a size heuristic.
        if kwargs.get("model", None) is not None:
            model = kwargs.get("model", None)
            common_params["model"] = model
        elif kwargs.get("engine", None) is not None:
            engine = kwargs.get("engine", None)
            common_params["engine"] = engine
            model = engine
        elif self.config is not None and "engine" in self.config:
            common_params["engine"] = self.config["engine"]
            model = self.config["engine"]
        elif self.config is not None and "model" in self.config:
            common_params["model"] = self.config["model"]
            model = self.config["model"]
        else:
            if num_tokens > 3500:
                model = "qwen-long"
            else:
                model = "qwen-plus"
            common_params["model"] = model

        print(f"\nUsing model {model} for {num_tokens} tokens (approx)")

        if enable_thinking:
            # Streaming mode with thinking enabled.
            print("使用流式处理模式,启用thinking功能")

            response_stream = self.client.chat.completions.create(**common_params)

            collected_thinking = []
            collected_content = []

            for chunk in response_stream:
                # Thinking chunks, if the SDK exposes them on the chunk object.
                if hasattr(chunk, 'thinking') and chunk.thinking:
                    collected_thinking.append(chunk.thinking)

                # Regular content deltas.
                if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content:
                    collected_content.append(chunk.choices[0].delta.content)

            if collected_thinking:
                print("Model thinking process:", "".join(collected_thinking))

            return "".join(collected_content)
        else:
            # Non-streaming mode.
            print("使用非流式处理模式")
            response = self.client.chat.completions.create(**common_params)

            # NOTE(review): membership test on a Choice object is a legacy
            # no-op with the modern OpenAI client; kept for behavior parity.
            for choice in response.choices:
                if "text" in choice:
                    return choice.text

            # Fall back to the first choice's message content (may be empty).
            return response.choices[0].message.content

    def get_sql_prompt(self, question: str,
                      question_sql_list: list,
                      ddl_list: list,
                      doc_list: list,
                      **kwargs) -> List[Dict[str, str]]:
        """Build the Chinese prompt used to generate a SQL query.

        Args:
            question: the user's natural-language question.
            question_sql_list: similar question/SQL example pairs.
            ddl_list: related table DDL snippets.
            doc_list: related business documentation snippets.

        Returns:
            A list of chat messages for ``submit_prompt``.
        """
        print("[DEBUG] 正在生成中文SQL提示词...")
        print(f"[DEBUG] 问题: {question}")
        print(f"[DEBUG] 相关SQL数量: {len(question_sql_list) if question_sql_list else 0}")
        print(f"[DEBUG] 相关DDL数量: {len(ddl_list) if ddl_list else 0}")
        print(f"[DEBUG] 相关文档数量: {len(doc_list) if doc_list else 0}")

        # SQL dialect, when the vector-store mixin provides one.
        dialect = getattr(self, 'dialect', 'SQL')

        # Base system prompt.
        messages = [
            self.system_message(
                f"""你是一个专业的SQL助手,根据用户的问题生成正确的{dialect}查询语句。
                你只需生成SQL语句,不需要任何解释或评论。
                用户问题: {question}
                """
            )
        ]

        # Related DDL, if any.
        if ddl_list and len(ddl_list) > 0:
            ddl_text = "\n\n".join([f"-- DDL项 {i+1}:\n{ddl}" for i, ddl in enumerate(ddl_list)])
            messages.append(
                self.user_message(
                    f"""
                    以下是可能相关的数据库表结构定义,请基于这些信息生成SQL:
                    
                    {ddl_text}
                    
                    记住,这些只是参考信息,可能并不包含所有需要的表和字段。
                    """
                )
            )

        # Related documentation, if any.
        if doc_list and len(doc_list) > 0:
            doc_text = "\n\n".join([f"-- 文档项 {i+1}:\n{doc}" for i, doc in enumerate(doc_list)])
            messages.append(
                self.user_message(
                    f"""
                    以下是可能有用的业务逻辑文档:
                    
                    {doc_text}
                    """
                )
            )

        # Similar question/SQL pairs, if any.
        if question_sql_list and len(question_sql_list) > 0:
            qs_text = ""
            for i, qs_item in enumerate(question_sql_list):
                qs_text += f"问题 {i+1}: {qs_item.get('question', '')}\n"
                qs_text += f"SQL:\n```sql\n{qs_item.get('sql', '')}\n```\n\n"

            messages.append(
                self.user_message(
                    f"""
                    以下是与当前问题相似的问题及其对应的SQL查询:
                    
                    {qs_text}
                    
                    请参考这些样例来生成当前问题的SQL查询。
                    """
                )
            )

        # Final user request and constraints.
        messages.append(
            self.user_message(
                f"""
                根据以上信息,为以下问题生成一个{dialect}查询语句:
                
                问题: {question}
                
                要求:
                1. 仅输出SQL语句,不要有任何解释或说明
                2. 确保语法正确,符合{dialect}标准
                3. 不要使用不存在的表或字段
                4. 查询应尽可能高效
                """
            )
        )

        return messages

    def get_followup_questions_prompt(self,
                                     question: str,
                                     sql: str,
                                     df_metadata: str,
                                     **kwargs) -> List[Dict[str, str]]:
        """Build the Chinese prompt that asks for follow-up questions."""
        print("[DEBUG] 正在生成中文后续问题提示词...")

        messages = [
            self.system_message(
                f"""你是一个专业的数据分析师,能够根据已有问题提出相关的后续问题。
                {self._response_language()}
                """
            ),
            self.user_message(
                f"""
                原始问题: {question}
                
                已执行的SQL查询:
                ```sql
                {sql}
                ```
                
                数据结构:
                {df_metadata}
                
                请基于上述信息,生成3-5个相关的后续问题,这些问题应该:
                1. 与原始问题和数据相关,是自然的延续
                2. 提供更深入的分析视角或维度拓展
                3. 探索可能的业务洞见和价值发现
                4. 简洁明了,便于用户理解
                5. 确保问题可以通过SQL查询解答,与现有数据结构相关
                
                只需列出问题,不要提供任何解释或SQL。每个问题应该是完整的句子,以问号结尾。
                """
            )
        ]

        return messages

    def get_summary_prompt(self, question: str, df_markdown: str, **kwargs) -> List[Dict[str, str]]:
        """Build the Chinese prompt that asks for a result summary."""
        print("[DEBUG] 正在生成中文摘要提示词...")

        messages = [
            self.system_message(
                f"""你是一个专业的数据分析师,能够清晰解释SQL查询的含义和结果。
                {self._response_language()}
                """
            ),
            self.user_message(
                f"""
                你是一个有帮助的数据助手。用户问了这个问题: '{question}'

                以下是一个pandas DataFrame,包含查询的结果: 
                {df_markdown}
                
                请用中文简明扼要地总结这些数据,回答用户的问题。不要提供任何额外的解释,只需提供摘要。
                """
            )
        ]

        return messages

    def get_plotly_prompt(self, question: str, sql: str, df_metadata: str,
                        chart_instructions: Optional[str] = None, **kwargs) -> List[Dict[str, str]]:
        """Build the Chinese prompt that asks for Plotly visualisation code."""
        print("[DEBUG] 正在生成中文Python可视化提示词...")

        instructions = chart_instructions if chart_instructions else "生成一个适合展示数据的图表"

        messages = [
            self.system_message(
                f"""你是一个专业的Python数据可视化专家,擅长使用Plotly创建数据可视化图表。
                {self._response_language()}
                """
            ),
            self.user_message(
                f"""
                问题: {question}
                
                SQL查询:
                ```sql
                {sql}
                ```
                
                数据结构:
                {df_metadata}
                
                请生成一个Python函数,使用Plotly库为上述数据创建一个可视化图表。要求:
                1. {instructions}
                2. 确保代码语法正确,可直接运行
                3. 图表应直观展示数据中的关键信息和关系
                4. 只需提供Python代码,不要有任何解释
                5. 使用中文作为图表标题、轴标签和图例
                6. 添加合适的颜色方案,保证图表美观
                7. 针对数据类型选择最合适的图表类型
                
                输出格式必须是可以直接运行的Python代码。
                """
            )
        ]

        return messages

+ 2 - 0
customqianwen/__init__.py

@@ -0,0 +1,2 @@
+from .Custom_QianwenAI_chat import QianWenAI_Chat
+from .Custom_QiawenAI_chat_cn import QianWenAI_Chat_CN

+ 324 - 0
embedding_function.py

@@ -0,0 +1,324 @@
+import requests
+import time
+import numpy as np
+from typing import List, Callable
+
class EmbeddingFunction:
    """Callable embedding provider backed by an OpenAI-compatible
    ``/embeddings`` HTTP API.

    Instances serve two roles:
      * a ChromaDB-style embedding function (``__call__`` over a list of texts);
      * single-text embedding with retries and optional L2 normalization
        (``generate_embedding``), plus a connectivity probe (``test_connection``).
    """

    def __init__(self, model_name: str, api_key: str, base_url: str, embedding_dimension: int):
        # Endpoint/model configuration. embedding_dimension is required and is
        # used to build zero vectors when a request fails.
        self.model_name = model_name
        self.api_key = api_key
        self.base_url = base_url
        self.embedding_dimension = embedding_dimension
        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        }
        self.max_retries = 2  # default maximum retry count
        self.retry_interval = 2  # base seconds between retries (exponential backoff)
        self.normalize_embeddings = True  # L2-normalize returned vectors by default

    def _endpoint_url(self) -> str:
        """Return the full ``/embeddings`` endpoint URL derived from base_url.

        Consolidates the URL-joining logic previously duplicated in
        ``__call__`` and ``generate_embedding``. The original inner check
        tested only ``/v1/embeddings``, so a base URL like
        ``http://host/embeddings/`` produced ``.../embeddings/embeddings``;
        re-checking the generic suffix after stripping the slash fixes that.
        """
        url = self.base_url
        if not url.endswith("/embeddings"):
            url = url.rstrip("/")  # drop trailing slash to avoid '//'
            if not url.endswith("/embeddings"):
                url = f"{url}/embeddings"
        return url

    def _normalize_vector(self, vector: List[float]) -> List[float]:
        """Return the L2-normalized copy of *vector*.

        Empty vectors return ``[]``; zero-norm vectors are returned unchanged
        to avoid division by zero.
        """
        if not vector:
            return []
        norm = np.linalg.norm(vector)
        if norm == 0:
            return vector
        return (np.array(vector) / norm).tolist()

    def __call__(self, input) -> List[List[float]]:
        """Embed a text or list of texts.

        Args:
            input: a single string or a list of strings.

        Returns:
            List[List[float]]: one vector per input text; a zero vector of
            ``embedding_dimension`` is substituted for any failed request.
        """
        if not isinstance(input, list):
            input = [input]

        embeddings = []
        url = self._endpoint_url()  # hoisted: identical for every text
        for text in input:
            payload = {
                "model": self.model_name,
                "input": text,
                "encoding_format": "float"
            }
            try:
                response = requests.post(url, json=payload, headers=self.headers)
                response.raise_for_status()

                result = response.json()
                if "data" in result and len(result["data"]) > 0:
                    embeddings.append(result["data"][0]["embedding"])
                else:
                    raise ValueError(f"API返回无效: {result}")
            except Exception as e:
                print(f"获取embedding时出错: {e}")
                # Degrade gracefully with a zero vector of the configured size.
                embeddings.append([0.0] * self.embedding_dimension)

        return embeddings

    def generate_embedding(self, text: str) -> List[float]:
        """Generate an embedding for a single text, with retry/backoff.

        Args:
            text (str): text to embed; empty input yields a zero vector.

        Returns:
            List[float]: the embedding (normalized when
            ``normalize_embeddings`` is True).

        Raises:
            RuntimeError / ValueError: when all retries are exhausted and no
            ``embedding_dimension`` fallback is available.
        """
        print(f"生成嵌入向量,文本长度: {len(text)} 字符")

        # Empty text: return a zero vector instead of calling the API.
        if not text or len(text.strip()) == 0:
            print("输入文本为空,返回零向量")
            if self.embedding_dimension is None:
                # Defensive: the factory guarantees embedding_dimension is set.
                raise ValueError("Embedding dimension (self.embedding_dimension) 未被正确初始化。")
            return [0.0] * self.embedding_dimension

        payload = {
            "model": self.model_name,
            "input": text,
            "encoding_format": "float"
        }

        retries = 0
        while retries <= self.max_retries:
            try:
                url = self._endpoint_url()
                print(f"请求URL: {url}")

                response = requests.post(
                    url,
                    json=payload,
                    headers=self.headers,
                    timeout=30  # request timeout in seconds
                )

                if response.status_code != 200:
                    error_msg = f"API请求错误: {response.status_code}, {response.text}"
                    print(error_msg)

                    # Retry only on throttling / transient server errors.
                    if response.status_code in (429, 500, 502, 503, 504):
                        retries += 1
                        if retries <= self.max_retries:
                            wait_time = self.retry_interval * (2 ** (retries - 1))  # exponential backoff
                            print(f"等待 {wait_time} 秒后重试 ({retries}/{self.max_retries})")
                            time.sleep(wait_time)
                            continue

                    raise ValueError(error_msg)

                result = response.json()

                if "data" in result and len(result["data"]) > 0 and "embedding" in result["data"][0]:
                    vector = result["data"][0]["embedding"]

                    if self.embedding_dimension is None:
                        # Defensive auto-detection; normally dimension is preset.
                        self.embedding_dimension = len(vector)
                        print(f"自动设置embedding维度为: {self.embedding_dimension}")
                    else:
                        actual_dim = len(vector)
                        if actual_dim != self.embedding_dimension:
                            print(f"向量维度不匹配: 期望 {self.embedding_dimension}, 实际 {actual_dim}")

                    if self.normalize_embeddings:
                        vector = self._normalize_vector(vector)

                    print(f"成功生成embedding向量,维度: {len(vector)}")
                    return vector
                else:
                    error_msg = f"API返回格式异常: {result}"
                    print(error_msg)
                    raise ValueError(error_msg)

            except Exception as e:
                print(f"生成embedding时出错: {str(e)}")
                retries += 1

                if retries <= self.max_retries:
                    wait_time = self.retry_interval * (2 ** (retries - 1))  # exponential backoff
                    print(f"等待 {wait_time} 秒后重试 ({retries}/{self.max_retries})")
                    time.sleep(wait_time)
                else:
                    print(f"已达到最大重试次数 ({self.max_retries}),生成embedding失败")
                    # Fall back to a zero vector when a dimension is known.
                    if self.embedding_dimension:
                        print(f"返回零向量 (维度: {self.embedding_dimension})")
                        return [0.0] * self.embedding_dimension
                    raise

        # Unreachable in practice; kept for completeness.
        raise RuntimeError("生成embedding失败")

    def test_connection(self, test_text="测试文本") -> dict:
        """Probe the embedding API once and report the outcome.

        Restored from previously commented-out code: module-level
        ``test_embedding_connection()`` calls this method, so leaving it
        commented out made every connection test fail with AttributeError.

        Args:
            test_text (str): text used for the probe.

        Returns:
            dict: keys ``success``, ``model``, ``base_url``, ``message``,
            ``actual_dimension``, ``expected_dimension``. The mismatch message
            intentionally starts with "警告" — callers key off that substring.
        """
        result = {
            "success": False,
            "model": self.model_name,
            "base_url": self.base_url,
            "message": "",
            "actual_dimension": None,
            "expected_dimension": self.embedding_dimension
        }

        try:
            print(f"测试嵌入模型连接 - 模型: {self.model_name}")
            print(f"API服务地址: {self.base_url}")

            # Validate configuration before issuing a request.
            if not self.api_key:
                result["message"] = "API密钥未设置或为空"
                return result

            if not self.base_url:
                result["message"] = "API服务地址未设置或为空"
                return result

            # Generate one vector as the actual connectivity test.
            vector = self.generate_embedding(test_text)
            actual_dimension = len(vector)

            result["success"] = True
            result["actual_dimension"] = actual_dimension

            if actual_dimension != self.embedding_dimension:
                result["message"] = f"警告: 模型实际生成的向量维度({actual_dimension})与配置维度({self.embedding_dimension})不一致"
            else:
                result["message"] = f"连接测试成功,向量维度: {actual_dimension}"

            return result

        except Exception as e:
            result["message"] = f"连接测试失败: {str(e)}"
            return result
+
+
def get_embedding_function() -> EmbeddingFunction:
    """Create an EmbeddingFunction from ``app_config.EMBEDDING_CONFIG``.

    Every required key must be present; otherwise an exception is raised.

    Returns:
        EmbeddingFunction: configured instance.

    Raises:
        ImportError: when app_config.py cannot be imported.
        AttributeError: when EMBEDDING_CONFIG is missing from app_config.
        KeyError: when a required key is missing or api_key is unset.
    """
    try:
        import app_config
    except ImportError:
        raise ImportError("无法导入 app_config.py。请确保该文件存在且在PYTHONPATH中。")

    try:
        cfg = app_config.EMBEDDING_CONFIG
    except AttributeError:
        raise AttributeError("app_config.py 中缺少 EMBEDDING_CONFIG 配置字典。")

    try:
        # Pull the required keys in a fixed order so a missing key surfaces
        # the same KeyError the original lookup sequence produced.
        settings = {key: cfg[key] for key in ("api_key", "model_name", "base_url", "embedding_dimension")}

        if settings["api_key"] is None:
            # api_key typically comes from an environment variable; call that out.
            raise KeyError("EMBEDDING_CONFIG 中的 'api_key' 未设置 (可能环境变量 EMBEDDING_API_KEY 未定义)。")

    except KeyError as e:
        # Re-wrap so the message names the offending key.
        raise KeyError(f"app_config.py 的 EMBEDDING_CONFIG 字典中缺少必要的键或值无效:{e}")

    return EmbeddingFunction(**settings)
+
def test_embedding_connection() -> dict:
    """Test that the embedding model is reachable and correctly configured.

    Fix: the shipped ``EmbeddingFunction`` may not expose ``test_connection``
    (it was committed commented out), which previously made this function
    always fail with AttributeError. When the method is absent we now probe
    the API directly via ``generate_embedding`` and build an equivalent
    result dict (keeping the "警告" message convention used below).

    Returns:
        dict: at least ``success`` and ``message``; ``actual_dimension`` when
        a vector was produced.
    """
    try:
        embedding_function = get_embedding_function()

        if hasattr(embedding_function, "test_connection"):
            test_result = embedding_function.test_connection()
        else:
            # Fallback probe: generate one embedding and compare dimensions.
            vector = embedding_function.generate_embedding("测试文本")
            actual_dimension = len(vector)
            expected_dimension = embedding_function.embedding_dimension
            if actual_dimension != expected_dimension:
                message = f"警告: 模型实际生成的向量维度({actual_dimension})与配置维度({expected_dimension})不一致"
            else:
                message = f"连接测试成功,向量维度: {actual_dimension}"
            test_result = {
                "success": True,
                "model": embedding_function.model_name,
                "base_url": embedding_function.base_url,
                "message": message,
                "actual_dimension": actual_dimension,
                "expected_dimension": expected_dimension
            }

        if test_result["success"]:
            print(f"嵌入模型连接测试成功!")
            # A "警告" message means the configured dimension is wrong.
            if "警告" in test_result["message"]:
                print(test_result["message"])
                print(f"建议将app_config.py中的EMBEDDING_CONFIG['embedding_dimension']修改为{test_result['actual_dimension']}")
        else:
            print(f"嵌入模型连接测试失败: {test_result['message']}")

        return test_result

    except Exception as e:
        error_message = f"无法测试嵌入模型连接: {str(e)}"
        print(error_message)
        return {
            "success": False,
            "message": error_message
        }
+

+ 2 - 0
requirements.txt

@@ -0,0 +1,2 @@
+vanna[chromadb,openai,postgres]==0.7.9
+flask==3.1.1

+ 14 - 0
training/__init__.py

@@ -0,0 +1,14 @@
+# training_tools 模块
+# 包含用于训练Vanna模型的工具和实用程序
+
+__version__ = '0.1.0'
+
+# 导出关键的训练函数
+from .vanna_trainer import (
+    train_ddl,
+    train_documentation,
+    train_sql_example,
+    train_question_sql_pair,
+    flush_training,
+    shutdown_trainer
+) 

+ 473 - 0
training/run_training.py

@@ -0,0 +1,473 @@
+# run_training.py
+import os
+import time
+import re
+import json
+import sys
+import requests
+import pandas as pd
+import argparse
+from pathlib import Path
+from sqlalchemy import create_engine
+
+
+from vanna_trainer import (
+    train_ddl,
+    train_documentation,
+    train_sql_example,
+    train_question_sql_pair,
+    flush_training,
+    shutdown_trainer
+)
+
def check_embedding_model_connection():
    """Verify the embedding model is reachable; exit the process if not.

    Returns:
        bool: True when the connection test succeeds (otherwise the process
        terminates via ``sys.exit(1)``).
    """
    from embedding_function import test_embedding_connection

    print("正在检查嵌入模型连接...")

    result = test_embedding_connection()

    # Guard clause: abort the whole training run on failure.
    if not result["success"]:
        print(f"\n错误: 无法连接到嵌入模型: {result['message']}")
        print("训练过程终止。请检查配置和API服务可用性。")
        sys.exit(1)

    print(f"可以继续训练过程。")
    return True
+
def read_file_by_delimiter(filepath, delimiter="---"):
    """Read a UTF-8 file and split its content on *delimiter*.

    Args:
        filepath: path of the file to read.
        delimiter: separator string between segments.

    Returns:
        list[str]: stripped, non-empty segments in file order.
    """
    with open(filepath, "r", encoding="utf-8") as fh:
        raw = fh.read()
    stripped_segments = (segment.strip() for segment in raw.split(delimiter))
    return [segment for segment in stripped_segments if segment]
+
def read_markdown_file_by_sections(filepath):
    """Split a Markdown file into sections by its #, ## and ### headers.

    Non-Markdown files fall back to the default '---' delimiter split.

    Args:
        filepath (str): path to the file.

    Returns:
        list: section strings; a header-less Markdown file yields its whole
        (stripped) content as a single section.
    """
    lowered = filepath.lower()
    if not (lowered.endswith('.md') or lowered.endswith('.markdown')):
        # Not Markdown: delegate to the generic delimiter-based reader.
        return read_file_by_delimiter(filepath, "---")

    with open(filepath, "r", encoding="utf-8") as fh:
        text = fh.read()

    # Each match is one header line plus its body, up to the next header
    # (of any of the three levels) or end of input.
    header_pattern = r'(?:^|\n)((?:#|##|###)[^#].*?)(?=\n(?:#|##|###)[^#]|\Z)'
    matches = re.findall(header_pattern, text, re.DOTALL)
    sections = [m.strip() for m in matches if m.strip()]

    # No headers at all: treat the whole document as one section.
    if not sections and text.strip():
        return [text.strip()]

    return sections
+
def train_ddl_statements(ddl_file):
    """Train on every DDL statement in *ddl_file* (';' separated).

    Args:
        ddl_file (str): path to the DDL file.
    """
    print(f"开始训练 DDL: {ddl_file}")
    if not os.path.exists(ddl_file):
        print(f"DDL 文件不存在: {ddl_file}")
        return
    statements = read_file_by_delimiter(ddl_file, ";")
    for idx, statement in enumerate(statements, start=1):
        try:
            print(f"\n DDL 训练 {idx}")
            train_ddl(statement)
        except Exception as e:
            print(f"错误:DDL #{idx} - {e}")
+
def train_documentation_blocks(doc_file):
    """Train on a documentation file.

    Markdown files are split into sections by their #/##/### headers; any
    other file type is split on '---' separators.

    Args:
        doc_file (str): path to the documentation file.
    """
    print(f"开始训练 文档: {doc_file}")
    if not os.path.exists(doc_file):
        print(f"文档文件不存在: {doc_file}")
        return

    if doc_file.lower().endswith(('.md', '.markdown')):
        # Markdown: use the header-aware splitter.
        sections = read_markdown_file_by_sections(doc_file)
        print(f" Markdown文档已分割为 {len(sections)} 个章节")

        for idx, section in enumerate(sections, start=1):
            try:
                section_title = section.split('\n', 1)[0].strip()
                print(f"\n Markdown章节训练 {idx}: {section_title}")

                # Warn when a section approaches the embedding API's limit.
                if len(section) > 2000:
                    print(f" 章节 {idx} 长度为 {len(section)} 字符,接近API限制(2048)")

                train_documentation(section)
            except Exception as e:
                print(f" 错误:章节 #{idx} - {e}")
    else:
        # Plain text: traditional '---' delimited blocks.
        for idx, block in enumerate(read_file_by_delimiter(doc_file, "---"), start=1):
            try:
                print(f"\n 文档训练 {idx}")
                train_documentation(block)
            except Exception as e:
                print(f" 错误:文档 #{idx} - {e}")
+
def train_sql_examples(sql_file):
    """Train on SQL examples separated by ';' in *sql_file*.

    Args:
        sql_file (str): path to the SQL examples file.
    """
    print(f" 开始训练 SQL 示例: {sql_file}")
    if not os.path.exists(sql_file):
        print(f" SQL 示例文件不存在: {sql_file}")
        return
    examples = read_file_by_delimiter(sql_file, ";")
    for idx, example in enumerate(examples, start=1):
        try:
            print(f"\n SQL 示例训练 {idx}")
            train_sql_example(example)
        except Exception as e:
            print(f" 错误:SQL #{idx} - {e}")
+
def train_question_sql_pairs(qs_file):
    """Train on question::SQL pairs, one per line.

    Args:
        qs_file (str): path to the pair file; each line is 'question::sql'.
    """
    print(f" 开始训练 问答对: {qs_file}")
    if not os.path.exists(qs_file):
        print(f" 问答文件不存在: {qs_file}")
        return
    try:
        with open(qs_file, "r", encoding="utf-8") as f:
            for idx, line in enumerate(f.readlines(), start=1):
                # Lines without the '::' separator are silently skipped.
                if "::" not in line:
                    continue
                question, _, sql = line.strip().partition("::")
                print(f"\n 问答训练 {idx}")
                train_question_sql_pair(question.strip(), sql.strip())
    except Exception as e:
        print(f" 错误:问答训练 - {e}")
+
def train_formatted_question_sql_pairs(formatted_file):
    """Train on a formatted question/SQL pair file.

    Two layouts are supported:
    1. ``Question: xxx`` then ``SQL: xxx`` (single-line SQL)
    2. ``Question: xxx`` then ``SQL:`` followed by multiple SQL lines

    Pairs are separated by a blank line before the next ``Question:``.

    Args:
        formatted_file (str): path to the formatted pair file.
    """
    print(f" 开始训练 格式化问答对: {formatted_file}")
    if not os.path.exists(formatted_file):
        print(f" 格式化问答文件不存在: {formatted_file}")
        return
    
    # Read the whole file at once; parsing is done on the full text.
    with open(formatted_file, "r", encoding="utf-8") as f:
        content = f.read()
    
    # Split into pairs on the blank-line + "Question:" boundary; the longer
    # separator avoids false positives on "Question:" inside SQL text.
    pairs = []
    blocks = content.split("\n\nQuestion:")
    
    # The first block may or may not start with "Question:" (split removed
    # the prefix only for subsequent blocks).
    first_block = blocks[0]
    if first_block.strip().startswith("Question:"):
        pairs.append(first_block.strip())
    elif "Question:" in first_block:
        # File begins with leading text before the first "Question:".
        question_start = first_block.find("Question:")
        pairs.append(first_block[question_start:].strip())
    
    # Re-attach the "Question:" prefix that split() consumed.
    for block in blocks[1:]:
        pairs.append("Question:" + block.strip())
    
    # Parse and train each pair.
    successfully_processed = 0
    for idx, pair in enumerate(pairs, start=1):
        try:
            if "Question:" not in pair or "SQL:" not in pair:
                print(f" 跳过不符合格式的对 #{idx}")
                continue
                
            # Question text lies between "Question:" and the next "SQL:".
            question_start = pair.find("Question:") + len("Question:")
            sql_start = pair.find("SQL:", question_start)
            
            if sql_start == -1:
                print(f" SQL部分未找到,跳过对 #{idx}")
                continue
                
            question = pair[question_start:sql_start].strip()
            
            # SQL part may span multiple lines; take everything after "SQL:".
            sql_part = pair[sql_start + len("SQL:"):].strip()
            
            # Guard: if another "Question:" follows inside this pair,
            # truncate the SQL there to avoid swallowing the next pair.
            next_question = pair.find("Question:", sql_start)
            if next_question != -1:
                sql_part = pair[sql_start + len("SQL:"):next_question].strip()
            
            if not question or not sql_part:
                print(f" 问题或SQL为空,跳过对 #{idx}")
                continue
            
            # Train this question/SQL pair.
            print(f"\n格式化问答训练 {idx}")
            print(f"问题: {question}")
            print(f"SQL: {sql_part}")
            train_question_sql_pair(question, sql_part)
            successfully_processed += 1
            
        except Exception as e:
            print(f" 错误:格式化问答训练对 #{idx} - {e}")
    
    print(f"格式化问答训练完成,共成功处理 {successfully_processed} 对问答(总计 {len(pairs)} 对)")
+
def train_json_question_sql_pairs(json_file):
    """Train on question/SQL pairs stored as a JSON list.

    Expected format: ``[{"question": ..., "sql": ...}, ...]``.

    Args:
        json_file (str): path to the JSON pair file.
    """
    print(f" 开始训练 JSON格式问答对: {json_file}")
    if not os.path.exists(json_file):
        print(f" JSON问答文件不存在: {json_file}")
        return

    try:
        with open(json_file, "r", encoding="utf-8") as f:
            data = json.load(f)

        # The top-level value must be a list of pair objects.
        if not isinstance(data, list):
            print(f" 错误: JSON文件格式不正确,应为问答对列表")
            return

        processed = 0
        for idx, pair in enumerate(data, start=1):
            try:
                # Guard clauses: shape check, then emptiness check.
                if not isinstance(pair, dict) or "question" not in pair or "sql" not in pair:
                    print(f" 跳过不符合格式的对 #{idx}")
                    continue

                question = pair["question"].strip()
                sql = pair["sql"].strip()

                if not question or not sql:
                    print(f" 问题或SQL为空,跳过对 #{idx}")
                    continue

                print(f"\n JSON格式问答训练 {idx}")
                print(f"问题: {question}")
                print(f"SQL: {sql}")
                train_question_sql_pair(question, sql)
                processed += 1

            except Exception as e:
                print(f" 错误:JSON问答训练对 #{idx} - {e}")

        print(f"JSON格式问答训练完成,共成功处理 {processed} 对问答(总计 {len(data)} 对)")

    except json.JSONDecodeError as e:
        print(f" 错误:JSON解析失败 - {e}")
    except Exception as e:
        print(f" 错误:处理JSON问答训练 - {e}")
+
def process_training_files(data_path):
    """Walk *data_path* recursively and dispatch every recognized training file.

    Recognized suffixes: ``.ddl``, ``.md``/``.markdown``,
    ``*_pair.json``/``*_pairs.json``, ``*_sql_pair.sql``/``*_sql_pairs.sql``,
    and any remaining ``.sql``.

    Args:
        data_path (str): training data directory.

    Returns:
        bool: True when at least one trainable file was processed.
    """
    print(f"\n===== 扫描训练数据目录: {os.path.abspath(data_path)} =====")

    if not os.path.exists(data_path):
        print(f"错误: 训练数据目录不存在: {data_path}")
        return False

    # Per-type counters for the summary printed at the end.
    counters = {
        "ddl": 0,
        "documentation": 0,
        "sql_example": 0,
        "question_sql_formatted": 0,
        "question_sql_json": 0
    }

    for root, _, files in os.walk(data_path):
        for file in files:
            file_path = os.path.join(root, file)
            name = file.lower()

            try:
                # elif ordering guarantees pair files are matched before
                # the generic .sql case.
                if name.endswith(".ddl"):
                    print(f"\n处理DDL文件: {file_path}")
                    train_ddl_statements(file_path)
                    counters["ddl"] += 1

                elif name.endswith((".md", ".markdown")):
                    print(f"\n处理文档文件: {file_path}")
                    train_documentation_blocks(file_path)
                    counters["documentation"] += 1

                elif name.endswith(("_pair.json", "_pairs.json")):
                    print(f"\n处理JSON问答对文件: {file_path}")
                    train_json_question_sql_pairs(file_path)
                    counters["question_sql_json"] += 1

                elif name.endswith(("_sql_pair.sql", "_sql_pairs.sql")):
                    print(f"\n处理格式化问答对文件: {file_path}")
                    train_formatted_question_sql_pairs(file_path)
                    counters["question_sql_formatted"] += 1

                elif name.endswith(".sql"):
                    print(f"\n处理SQL示例文件: {file_path}")
                    train_sql_examples(file_path)
                    counters["sql_example"] += 1
            except Exception as e:
                print(f"处理文件 {file_path} 时出错: {e}")

    print("\n===== 训练文件处理统计 =====")
    print(f"DDL文件: {counters['ddl']}个")
    print(f"文档文件: {counters['documentation']}个")
    print(f"SQL示例文件: {counters['sql_example']}个")
    print(f"格式化问答对文件: {counters['question_sql_formatted']}个")
    print(f"JSON问答对文件: {counters['question_sql_json']}个")

    total_files = sum(counters.values())
    if total_files == 0:
        print(f"警告: 在目录 {data_path} 中未找到任何可训练的文件")
        return False

    return True
+
def main():
    """Entry point: configure and run the full training pipeline."""
    
    # Local imports needed by this entry point.
    import os
    import app_config
    
    # Parse command-line arguments.
    parser = argparse.ArgumentParser(description='训练Vanna NL2SQL模型')
    parser.add_argument('--data_path', type=str, default=os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data'),
                        help='训练数据目录路径 (默认: training/data)')
    args = parser.parse_args()
    
    # Use a Path object for cross-platform path handling.
    data_path = Path(args.data_path)
    
    # Project root is the parent of this training/ directory.
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    # Abort early (sys.exit) if the embedding model is unreachable.
    check_embedding_model_connection()
    
    # Print ChromaDB diagnostics: client version and database file location/size.
    try:
        try:
            import chromadb
            chroma_version = chromadb.__version__
        except ImportError:
            chroma_version = "未知"
        
        # Default ChromaDB database file name.
        chroma_file = "chroma.sqlite3"
        
        # The ChromaDB file is expected at the project root.
        db_file_path = os.path.join(project_root, chroma_file)

        if os.path.exists(db_file_path):
            file_size = os.path.getsize(db_file_path) / 1024  # KB
            print(f"\n===== ChromaDB数据库: {os.path.abspath(db_file_path)} (大小: {file_size:.2f} KB) =====")
        else:
            print(f"\n===== 未找到ChromaDB数据库文件于: {os.path.abspath(db_file_path)} =====")
            
        # Report the ChromaDB client library version.
        print(f"===== ChromaDB客户端库版本: {chroma_version} =====\n")
    except Exception as e:
        print(f"\n===== 无法获取ChromaDB信息: {e} =====\n")
    
    # Dispatch every training file found under data_path.
    process_successful = process_training_files(data_path)
    
    if process_successful:
        # Flush any partially filled batches and shut the trainer down.
        print("\n===== 训练完成,处理剩余批次 =====")
        flush_training()
        shutdown_trainer()
        
        # Verify that training data was actually written to the store.
        print("\n===== 验证训练数据 =====")
        from vanna_llm_factory import create_vanna_instance
        vn = create_vanna_instance()
        
        # Only ChromaDB is supported, so verification is a single retrieval.
        try:
            training_data = vn.get_training_data()
            if training_data is not None and not training_data.empty:
                # get_training_data usually prints counts itself; add a summary.
                print(f"已从ChromaDB中检索到 {len(training_data)} 条训练数据进行验证。")
            elif training_data is not None and training_data.empty:
                 print("在ChromaDB中未找到任何训练数据。")
            else: # training_data is None
                print("无法从Vanna获取训练数据 (可能返回了None)。请检查连接和Vanna实现。")

        except Exception as e:
            print(f"验证训练数据失败: {e}")
            print("请检查ChromaDB连接和表结构。")
    else:
        print("\n===== 未能找到或处理任何训练文件,训练过程终止 =====")
    
    # Report the embedding model configuration in use.
    print("\n===== Embedding模型信息 =====")
    print(f"模型名称: {app_config.EMBEDDING_CONFIG.get('model_name')}")
    print(f"向量维度: {app_config.EMBEDDING_CONFIG.get('embedding_dimension')}")
    print(f"API服务: {app_config.EMBEDDING_CONFIG.get('base_url')}")
    # ChromaDB lives at the project root.
    chroma_display_path = os.path.abspath(project_root)
    print(f"向量数据库: ChromaDB ({chroma_display_path})")
    print("===== 训练流程完成 =====\n")

if __name__ == "__main__":
    main() 

+ 207 - 0
training/vanna_trainer.py

@@ -0,0 +1,207 @@
# vanna_trainer.py
# Module initialization: shared Vanna instance and batch-processing config.
import os
import time
import threading
import queue
import concurrent.futures
from functools import lru_cache
from collections import defaultdict
from typing import List, Dict, Any, Tuple, Optional, Union, Callable
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import app_config

# Project root: parent directory of this training/ package.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# Create the single Vanna instance shared by every training helper below.
from vanna_llm_factory import create_vanna_instance

vn = create_vanna_instance()

# Report the embedding model configured in app_config.
embedding_model = app_config.EMBEDDING_CONFIG.get('model_name')
print(f"\n===== Embedding模型信息 =====")
print(f"模型名称: {embedding_model}")
if hasattr(app_config, 'EMBEDDING_CONFIG'):
    if 'embedding_dimension' in app_config.EMBEDDING_CONFIG:
        print(f"向量维度: {app_config.EMBEDDING_CONFIG['embedding_dimension']}")
    if 'base_url' in app_config.EMBEDDING_CONFIG:
        print(f"API服务: {app_config.EMBEDDING_CONFIG['base_url']}")
print("==============================")

# Batch-processing knobs from app_config.
BATCH_PROCESSING_ENABLED = app_config.BATCH_PROCESSING_ENABLED
BATCH_SIZE = app_config.BATCH_SIZE
MAX_WORKERS = app_config.MAX_WORKERS
+
+
# Batching processor for training items
class BatchProcessor:
    """Thread-safe accumulator that groups training items into batches.

    Items are queued per batch_type ('ddl' / 'documentation' /
    'question_sql'); once a queue reaches ``batch_size`` the batch is handed
    to a thread pool for asynchronous processing. When batching is disabled,
    each item is processed inline.
    """

    def __init__(self, batch_size=BATCH_SIZE, max_workers=MAX_WORKERS):
        self.batch_size = batch_size
        self.max_workers = max_workers
        self.batches = defaultdict(list)
        self.lock = threading.Lock()  # guards self.batches across threads
        
        # Worker pool that runs _process_batch asynchronously.
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
        
        # Whether batching is enabled (from app_config).
        self.batch_enabled = BATCH_PROCESSING_ENABLED       

        print(f"[DEBUG] 批处理器初始化: 启用={self.batch_enabled}, 批大小={self.batch_size}, 最大工作线程={self.max_workers}")
    
    def add_item(self, batch_type: str, item: Dict[str, Any]):
        """Queue one item; a full queue triggers an asynchronous batch flush."""
        if not self.batch_enabled:
            # Batching disabled: process synchronously right away.
            self._process_single_item(batch_type, item)
            return
        
        with self.lock:
            self.batches[batch_type].append(item)
            
            if len(self.batches[batch_type]) >= self.batch_size:
                batch_items = self.batches[batch_type]
                self.batches[batch_type] = []
                # Hand the full batch to the worker pool.
                self.executor.submit(self._process_batch, batch_type, batch_items)
    
    def _process_single_item(self, batch_type: str, item: Dict[str, Any]):
        """Train a single item directly through the global vn instance."""
        try:
            if batch_type == 'ddl':
                vn.train(ddl=item['ddl'])
            elif batch_type == 'documentation':
                vn.train(documentation=item['documentation'])
            elif batch_type == 'question_sql':
                vn.train(question=item['question'], sql=item['sql'])
            
            print(f"[DEBUG] 单项处理成功: {batch_type}")
                
        except Exception as e:
            print(f"[ERROR] 处理 {batch_type} 项目失败: {e}")
    
    def _process_batch(self, batch_type: str, items: List[Dict[str, Any]]):
        """Process one batch; falls back to per-item processing on failure."""
        print(f"[INFO] 开始批量处理 {len(items)} 个 {batch_type} 项")
        start_time = time.time()
        
        try:
            # Reshape the queued items into the generic add_batch payload.
            batch_data = []
            
            if batch_type == 'ddl':
                for item in items:
                    batch_data.append({
                        'type': 'ddl',
                        'content': item['ddl']
                    })
            
            elif batch_type == 'documentation':
                for item in items:
                    batch_data.append({
                        'type': 'documentation',
                        'content': item['documentation']
                    })
            
            elif batch_type == 'question_sql':
                for item in items:
                    batch_data.append({
                        'type': 'question_sql',
                        'question': item['question'],
                        'sql': item['sql']
                    })
            
            # Prefer the bulk API when the vn instance provides one.
            if hasattr(vn, 'add_batch') and callable(getattr(vn, 'add_batch')):
                success = vn.add_batch(batch_data)
                if success:
                    print(f"[INFO] 批量处理成功: {len(items)} 个 {batch_type} 项")
                else:
                    print(f"[WARNING] 批量处理部分失败: {batch_type}")
            else:
                # No bulk API: degrade to per-item processing.
                print(f"[WARNING] 批处理不可用,使用逐条处理: {batch_type}")
                for item in items:
                    self._process_single_item(batch_type, item)
                
        except Exception as e:
            print(f"[ERROR] 批处理 {batch_type} 失败: {e}")
            # On batch failure, retry the items one by one.
            print(f"[INFO] 尝试逐条处理...")
            for item in items:
                try:
                    self._process_single_item(batch_type, item)
                except Exception as item_e:
                    print(f"[ERROR] 处理项目失败: {item_e}")
        
        elapsed = time.time() - start_time
        print(f"[INFO] 批处理完成 {len(items)} 个 {batch_type} 项,耗时 {elapsed:.2f} 秒")
    
    def flush_all(self):
        """Synchronously process every partially filled batch, then reset."""
        with self.lock:
            for batch_type, items in self.batches.items():
                if items:
                    print(f"[INFO] 正在处理剩余的 {len(items)} 个 {batch_type} 项")
                    self._process_batch(batch_type, items)
            
            # Reset all queues.
            self.batches = defaultdict(list)
        
        print("[INFO] 所有批处理项目已完成")
    
    def shutdown(self):
        """Flush pending work and shut the thread pool down (blocking)."""
        self.flush_all()
        self.executor.shutdown(wait=True)
        print("[INFO] 批处理器已关闭")
+
# Global batch processor shared by the train_* helpers below.
batch_processor = BatchProcessor()
+
# Batch-enhanced versions of the original training functions.
def train_ddl(ddl_sql: str):
    """Queue one DDL statement for (batched) training."""
    print(f"[DDL] Training on DDL:\n{ddl_sql}")
    batch_processor.add_item('ddl', {'ddl': ddl_sql})
+
def train_documentation(doc: str):
    """Queue one documentation block for (batched) training."""
    print(f"[DOC] Training on documentation:\n{doc}")
    batch_processor.add_item('documentation', {'documentation': doc})
+
def train_sql_example(sql: str):
    """Train on a single SQL example by asking the LLM to derive its question."""
    print(f"[SQL] Training on SQL:\n{sql}")

    try:
        # Ask the LLM for a natural-language question matching this SQL,
        # then make sure it ends with a question mark (ASCII or fullwidth).
        generated_question = vn.generate_question(sql=sql).strip()
        if not generated_question.endswith(("?", "?")):
            generated_question += "?"

    except Exception as e:
        print(f"[ERROR] 生成问题时出错: {e}")
        raise Exception(f"无法为SQL生成问题: {e}")

    print(f"[SQL] 生成问题: {generated_question}")
    # Store the derived question together with the SQL as a standard pair.
    batch_processor.add_item('question_sql', {'question': generated_question, 'sql': sql})
+
def train_question_sql_pair(question: str, sql: str):
    """Queue one question/SQL pair for (batched) training."""
    print(f"[Q-S] Training on:\nquestion: {question}\nsql: {sql}")
    batch_processor.add_item('question_sql', {'question': question, 'sql': sql})
+
# Flush all pending items once training is done.
def flush_training():
    """Force-process every training item still queued in the batch processor."""
    batch_processor.flush_all()
+
# Shut the trainer down.
def shutdown_trainer():
    """Shut down the batch processor and release its thread pool."""
    batch_processor.shutdown() 

+ 86 - 0
vanna_llm_factory.py

@@ -0,0 +1,86 @@
+"""
+Vanna LLM 工厂文件,专注于 ChromaDB 并简化配置。
+"""
+from vanna.chromadb import ChromaDB_VectorStore  # 从 Vanna 系统获取
+from customqianwen.Custom_QianwenAI_chat import QianWenAI_Chat
+from customdeepseek.custom_deepseek_chat import DeepSeekChat
+import app_config 
+from embedding_function import get_embedding_function
+import os
+
class Vanna_Qwen_ChromaDB(ChromaDB_VectorStore, QianWenAI_Chat):
    """Vanna variant combining ChromaDB vector storage with the Qwen chat LLM."""

    def __init__(self, config=None):
        # Initialize both bases explicitly with the same config dict
        # (each base consumes the keys it understands).
        ChromaDB_VectorStore.__init__(self, config=config)
        QianWenAI_Chat.__init__(self, config=config)
+
class Vanna_DeepSeek_ChromaDB(ChromaDB_VectorStore, DeepSeekChat):
    """Vanna variant combining ChromaDB vector storage with the DeepSeek chat LLM."""

    def __init__(self, config=None):
        # Initialize both bases explicitly with the same config dict
        # (each base consumes the keys it understands).
        ChromaDB_VectorStore.__init__(self, config=config)
        DeepSeekChat.__init__(self, config=config)
+
def create_vanna_instance(config_module=None):
    """
    Factory: create and initialize a Vanna instance (LLM + ChromaDB).

    Args:
        config_module: Configuration module; defaults to ``app_config``
            when None.

    Returns:
        The initialized Vanna instance, already connected to the business
        (PostgreSQL) database.

    Raises:
        ValueError: If ``config_module.MODEL_TYPE`` is not a supported
            model type.
        SystemExit: If the selected provider's API key is missing or empty.
    """
    if config_module is None:
        config_module = app_config

    model_type = config_module.MODEL_TYPE.lower()

    # Per-provider wiring: (base LLM config, display label, API-key env var,
    # concrete Vanna class). Supporting a new provider means one new entry
    # instead of another copy of the branch below.
    providers = {
        "deepseek": (config_module.DEEPSEEK_CONFIG, "DeepSeek",
                     "DEEPSEEK_API_KEY", Vanna_DeepSeek_ChromaDB),
        "qwen": (config_module.QWEN_CONFIG, "Qwen",
                 "QWEN_API_KEY", Vanna_Qwen_ChromaDB),
    }
    if model_type not in providers:
        raise ValueError(f"不支持的模型类型: {model_type}")

    base_config, label, env_var, vanna_cls = providers[model_type]
    config = base_config.copy()
    print(f"创建{label}模型实例,使用模型: {config['model']}")

    # Fail fast with a clear message when the API key is not configured.
    if not config.get("api_key"):
        print(f"\n错误: {label} API密钥未设置或为空")
        print(f"请在.env文件中设置{env_var}环境变量")
        print(f"无法继续执行,程序退出\n")
        import sys
        sys.exit(1)

    embedding_function = get_embedding_function()

    config["embedding_function"] = embedding_function
    print(f"已配置使用 EMBEDDING_CONFIG 中的嵌入模型: {config_module.EMBEDDING_CONFIG['model_name']}, 维度: {config_module.EMBEDDING_CONFIG['embedding_dimension']}")

    # Point ChromaDB's persistence path at the project root.
    project_root = os.path.dirname(os.path.abspath(__file__))
    config["path"] = project_root
    print(f"已配置使用ChromaDB作为向量数据库,路径:{project_root}")

    vn = vanna_cls(config=config)
    print(f"创建{label}+ChromaDB实例")

    # Connect to the business database the generated SQL will run against.
    vn.connect_to_postgres(**config_module.APP_DB_CONFIG)
    print(f"已连接到业务数据库: "
          f"{config_module.APP_DB_CONFIG['host']}:"
          f"{config_module.APP_DB_CONFIG['port']}/"
          f"{config_module.APP_DB_CONFIG['dbname']}")
    return vn