浏览代码

项目统一的日志管理重构完成.

wangxq 1 周之前
父节点
当前提交
04e48d41f1
共有 43 个文件被更改,包括 740 次插入和 1605 次删除
  1. 2 1
      .claude/settings.local.json
  2. 90 86
      agent/citu_agent.py
  3. 24 18
      agent/classifier.py
  4. 7 3
      agent/tools/general_chat.py
  5. 10 6
      agent/tools/sql_execution.py
  6. 7 3
      agent/tools/sql_generation.py
  7. 8 4
      agent/tools/summary_generation.py
  8. 12 8
      agent/tools/utils.py
  9. 1 1
      app_config.py
  10. 76 69
      citu_app.py
  11. 15 13
      common/embedding_cache_manager.py
  12. 22 18
      common/qa_feedback_manager.py
  13. 37 35
      common/redis_conversation_manager.py
  14. 12 8
      common/utils.py
  15. 11 7
      common/vanna_combinations.py
  16. 8 4
      common/vanna_instance.py
  17. 17 12
      core/embedding_function.py
  18. 11 7
      core/vanna_llm_factory.py
  19. 17 13
      customembedding/ollama_embedding.py
  20. 72 65
      customllm/base_llm_chat.py
  21. 15 15
      customllm/deepseek_chat.py
  22. 31 31
      customllm/ollama_chat.py
  23. 8 8
      customllm/qianwen_chat.py
  24. 58 22
      custompgvector/pgvector.py
  25. 2 2
      data_pipeline/analyzers/md_analyzer.py
  26. 2 2
      data_pipeline/analyzers/theme_extractor.py
  27. 3 1
      data_pipeline/config.py
  28. 2 3
      data_pipeline/ddl_generation/training_data_agent.py
  29. 49 42
      data_pipeline/metadata_only_generator.py
  30. 2 2
      data_pipeline/qa_generation/qs_agent.py
  31. 20 18
      data_pipeline/schema_workflow.py
  32. 7 5
      data_pipeline/tools/base.py
  33. 35 31
      data_pipeline/trainer/vanna_trainer.py
  34. 2 2
      data_pipeline/utils/file_manager.py
  35. 2 2
      data_pipeline/utils/large_table_handler.py
  36. 31 132
      data_pipeline/utils/logger.py
  37. 2 2
      data_pipeline/utils/permission_checker.py
  38. 2 2
      data_pipeline/utils/system_filter.py
  39. 2 2
      data_pipeline/utils/table_parser.py
  40. 2 2
      data_pipeline/validators/file_count_validator.py
  41. 2 2
      data_pipeline/validators/sql_validation_agent.py
  42. 2 2
      data_pipeline/validators/sql_validator.py
  43. 0 894
      docs/全局log服务改造方案.md

+ 2 - 1
.claude/settings.local.json

@@ -16,7 +16,8 @@
       "Bash(.venv/Scripts/python.exe:*)",
       "Bash(mv:*)",
       "Bash(rm:*)",
-      "Bash(.venv/bin/python:*)"
+      "Bash(.venv/bin/python:*)",
+      "Bash(./.venv/Scripts/python.exe:*)"
     ],
     "deny": []
   }

+ 90 - 86
agent/citu_agent.py

@@ -4,6 +4,7 @@ from langgraph.graph import StateGraph, END
 from langchain.agents import AgentExecutor, create_openai_tools_agent
 from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
 from langchain_core.messages import SystemMessage, HumanMessage
+from core.logging import get_agent_logger
 
 from agent.state import AgentState
 from agent.classifier import QuestionClassifier
@@ -15,39 +16,42 @@ class CituLangGraphAgent:
     """Citu LangGraph智能助手主类 - 使用@tool装饰器 + Agent工具调用"""
     
     def __init__(self):
+        # 初始化日志
+        self.logger = get_agent_logger("CituAgent")
+        
         # 加载配置
         try:
             from agent.config import get_current_config, get_nested_config
             self.config = get_current_config()
-            print("[CITU_AGENT] 加载Agent配置完成")
+            self.logger.info("加载Agent配置完成")
         except ImportError:
             self.config = {}
-            print("[CITU_AGENT] 配置文件不可用,使用默认配置")
+            self.logger.warning("配置文件不可用,使用默认配置")
         
         self.classifier = QuestionClassifier()
         self.tools = TOOLS
         self.llm = get_compatible_llm()
         
         # 注意:现在使用直接工具调用模式,不再需要预创建Agent执行器
-        print("[CITU_AGENT] 使用直接工具调用模式")
+        self.logger.info("使用直接工具调用模式")
         
         # 不在构造时创建workflow,改为动态创建以支持路由模式参数
         # self.workflow = self._create_workflow()
-        print("[CITU_AGENT] LangGraph Agent with Direct Tools初始化完成")
+        self.logger.info("LangGraph Agent with Direct Tools初始化完成")
     
     def _create_workflow(self, routing_mode: str = None) -> StateGraph:
         """根据路由模式创建不同的工作流"""
         # 确定使用的路由模式
         if routing_mode:
             QUESTION_ROUTING_MODE = routing_mode
-            print(f"[CITU_AGENT] 创建工作流,使用传入的路由模式: {QUESTION_ROUTING_MODE}")
+            self.logger.info(f"创建工作流,使用传入的路由模式: {QUESTION_ROUTING_MODE}")
         else:
             try:
                 from app_config import QUESTION_ROUTING_MODE
-                print(f"[CITU_AGENT] 创建工作流,使用配置文件路由模式: {QUESTION_ROUTING_MODE}")
+                self.logger.info(f"创建工作流,使用配置文件路由模式: {QUESTION_ROUTING_MODE}")
             except ImportError:
                 QUESTION_ROUTING_MODE = "hybrid"
-                print(f"[CITU_AGENT] 配置导入失败,使用默认路由模式: {QUESTION_ROUTING_MODE}")
+                self.logger.warning(f"配置导入失败,使用默认路由模式: {QUESTION_ROUTING_MODE}")
         
         workflow = StateGraph(AgentState)
         
@@ -137,12 +141,12 @@ class CituLangGraphAgent:
             state["current_step"] = "direct_database_init"
             state["execution_path"].append("init_direct_database")
             
-            print(f"[DIRECT_DATABASE] 直接数据库模式初始化完成")
+            self.logger.info("直接数据库模式初始化完成")
             
             return state
             
         except Exception as e:
-            print(f"[ERROR] 直接数据库模式初始化异常: {str(e)}")
+            self.logger.error(f"直接数据库模式初始化异常: {str(e)}")
             state["error"] = f"直接数据库模式初始化失败: {str(e)}"
             state["error_code"] = 500
             state["execution_path"].append("init_direct_database_error")
@@ -163,12 +167,12 @@ class CituLangGraphAgent:
             state["current_step"] = "direct_chat_init"
             state["execution_path"].append("init_direct_chat")
             
-            print(f"[DIRECT_CHAT] 直接聊天模式初始化完成")
+            self.logger.info("直接聊天模式初始化完成")
             
             return state
             
         except Exception as e:
-            print(f"[ERROR] 直接聊天模式初始化异常: {str(e)}")
+            self.logger.error(f"直接聊天模式初始化异常: {str(e)}")
             state["error"] = f"直接聊天模式初始化失败: {str(e)}"
             state["error_code"] = 500
             state["execution_path"].append("init_direct_chat_error")
@@ -180,12 +184,12 @@ class CituLangGraphAgent:
             # 从state中获取路由模式,而不是从配置文件读取
             routing_mode = state.get("routing_mode", "hybrid")
             
-            print(f"[CLASSIFY_NODE] 开始分类问题: {state['question']}")
+            self.logger.info(f"开始分类问题: {state['question']}")
             
             # 获取上下文类型(如果有的话)
             context_type = state.get("context_type")
             if context_type:
-                print(f"[CLASSIFY_NODE] 检测到上下文类型: {context_type}")
+                self.logger.info(f"检测到上下文类型: {context_type}")
             
             # 使用渐进式分类策略,传递路由模式
             classification_result = self.classifier.classify(state["question"], context_type, routing_mode)
@@ -199,13 +203,13 @@ class CituLangGraphAgent:
             state["current_step"] = "classified"
             state["execution_path"].append("classify")
             
-            print(f"[CLASSIFY_NODE] 分类结果: {classification_result.question_type}, 置信度: {classification_result.confidence}")
-            print(f"[CLASSIFY_NODE] 路由模式: {routing_mode}, 分类方法: {classification_result.method}")
+            self.logger.info(f"分类结果: {classification_result.question_type}, 置信度: {classification_result.confidence}")
+            self.logger.info(f"路由模式: {routing_mode}, 分类方法: {classification_result.method}")
             
             return state
             
         except Exception as e:
-            print(f"[ERROR] 问题分类异常: {str(e)}")
+            self.logger.error(f"问题分类异常: {str(e)}")
             state["error"] = f"问题分类失败: {str(e)}"
             state["error_code"] = 500
             state["execution_path"].append("classify_error")
@@ -214,12 +218,12 @@ class CituLangGraphAgent:
     async def _agent_sql_generation_node(self, state: AgentState) -> AgentState:
         """SQL生成验证节点 - 负责生成SQL、验证SQL和决定路由"""
         try:
-            print(f"[SQL_GENERATION] 开始处理SQL生成和验证: {state['question']}")
+            self.logger.info(f"开始处理SQL生成和验证: {state['question']}")
             
             question = state["question"]
             
             # 步骤1:生成SQL
-            print(f"[SQL_GENERATION] 步骤1:生成SQL")
+            self.logger.info("步骤1:生成SQL")
             sql_result = generate_sql.invoke({"question": question, "allow_llm_to_see_data": True})
             
             if not sql_result.get("success"):
@@ -228,7 +232,7 @@ class CituLangGraphAgent:
                 error_type = sql_result.get("error_type", "")
                 
                 #print(f"[SQL_GENERATION] SQL生成失败: {error_message}")
-                print(f"[DEBUG] error_type = '{error_type}'")
+                self.logger.debug(f"error_type = '{error_type}'")
                 
                 # 根据错误类型生成用户提示
                 if "no relevant tables" in error_message.lower() or "table not found" in error_message.lower():
@@ -244,7 +248,7 @@ class CituLangGraphAgent:
                     state["validation_error_type"] = "llm_explanation"
                     state["current_step"] = "sql_generation_completed"
                     state["execution_path"].append("agent_sql_generation")
-                    print(f"[SQL_GENERATION] 返回LLM解释性答案: {error_message}")
+                    self.logger.info(f"返回LLM解释性答案: {error_message}")
                     return state
                 else:
                     user_prompt = "无法生成有效的SQL查询,请尝试重新描述您的问题。"
@@ -257,7 +261,7 @@ class CituLangGraphAgent:
                 state["current_step"] = "sql_generation_failed"
                 state["execution_path"].append("agent_sql_generation_failed")
                 
-                print(f"[SQL_GENERATION] 生成失败: {failure_reason} - {user_prompt}")
+                self.logger.warning(f"生成失败: {failure_reason} - {user_prompt}")
                 return state
             
             sql = sql_result.get("sql")
@@ -273,13 +277,13 @@ class CituLangGraphAgent:
                 state["validation_error_type"] = "llm_explanation"
                 state["current_step"] = "sql_generation_completed"
                 state["execution_path"].append("agent_sql_generation")
-                print(f"[SQL_GENERATION] 返回LLM解释性答案: {explanation}")
+                self.logger.info(f"返回LLM解释性答案: {explanation}")
                 return state
             
             if sql:
-                print(f"[SQL_GENERATION] SQL生成成功: {sql}")
+                self.logger.info(f"SQL生成成功: {sql}")
             else:
-                print(f"[SQL_GENERATION] SQL为空,但不是解释性响应")
+                self.logger.warning("SQL为空,但不是解释性响应")
                 # 这种情况应该很少见,但为了安全起见保留原有的错误处理
                 return state
             
@@ -292,12 +296,12 @@ class CituLangGraphAgent:
                 state["validation_error_type"] = "invalid_sql_format"
                 state["current_step"] = "sql_generation_completed"  
                 state["execution_path"].append("agent_sql_generation")
-                print(f"[SQL_GENERATION] 内容不是有效SQL,当作解释返回: {sql}")
+                self.logger.info(f"内容不是有效SQL,当作解释返回: {sql}")
                 return state
             
             # 步骤2:SQL验证(如果启用)
             if self._is_sql_validation_enabled():
-                print(f"[SQL_GENERATION] 步骤2:验证SQL")
+                self.logger.info("步骤2:验证SQL")
                 validation_result = await self._validate_sql_with_custom_priority(sql)
                 
                 if not validation_result.get("valid"):
@@ -306,7 +310,7 @@ class CituLangGraphAgent:
                     error_message = validation_result.get("error_message")
                     can_repair = validation_result.get("can_repair", False)
                     
-                    print(f"[SQL_GENERATION] SQL验证失败: {error_type} - {error_message}")
+                    self.logger.warning(f"SQL验证失败: {error_type} - {error_message}")
                     
                     if error_type == "forbidden_keywords":
                         # 禁止词错误,直接失败,不尝试修复
@@ -316,12 +320,12 @@ class CituLangGraphAgent:
                         state["validation_error_type"] = "forbidden_keywords"
                         state["current_step"] = "sql_validation_failed"
                         state["execution_path"].append("forbidden_keywords_failed")
-                        print(f"[SQL_GENERATION] 禁止词验证失败,直接结束")
+                        self.logger.warning("禁止词验证失败,直接结束")
                         return state
                     
                     elif error_type == "syntax_error" and can_repair and self._is_auto_repair_enabled():
                         # 语法错误,尝试修复(仅一次)
-                        print(f"[SQL_GENERATION] 尝试修复SQL语法错误(仅一次): {error_message}")
+                        self.logger.info(f"尝试修复SQL语法错误(仅一次): {error_message}")
                         state["sql_repair_attempted"] = True
                         
                         repair_result = await self._attempt_sql_repair_once(sql, error_message)
@@ -335,12 +339,12 @@ class CituLangGraphAgent:
                             state["sql_repair_success"] = True
                             state["current_step"] = "sql_generation_completed"
                             state["execution_path"].append("sql_repair_success")
-                            print(f"[SQL_GENERATION] SQL修复成功: {repaired_sql}")
+                            self.logger.info(f"SQL修复成功: {repaired_sql}")
                             return state
                         else:
                             # 修复失败,直接结束
                             repair_error = repair_result.get("error", "修复失败")
-                            print(f"[SQL_GENERATION] SQL修复失败: {repair_error}")
+                            self.logger.warning(f"SQL修复失败: {repair_error}")
                             state["sql_generation_success"] = False
                             state["sql_validation_success"] = False
                             state["sql_repair_success"] = False
@@ -357,13 +361,13 @@ class CituLangGraphAgent:
                         state["validation_error_type"] = error_type
                         state["current_step"] = "sql_validation_failed"
                         state["execution_path"].append("sql_validation_failed")
-                        print(f"[SQL_GENERATION] SQL验证失败,不尝试修复")
+                        self.logger.warning("SQL验证失败,不尝试修复")
                         return state
                 else:
-                    print(f"[SQL_GENERATION] SQL验证通过")
+                    self.logger.info("SQL验证通过")
                     state["sql_validation_success"] = True
             else:
-                print(f"[SQL_GENERATION] 跳过SQL验证(未启用)")
+                self.logger.info("跳过SQL验证(未启用)")
                 state["sql_validation_success"] = True
             
             # 生成和验证都成功
@@ -371,13 +375,13 @@ class CituLangGraphAgent:
             state["current_step"] = "sql_generation_completed"
             state["execution_path"].append("agent_sql_generation")
             
-            print(f"[SQL_GENERATION] SQL生成验证完成,准备执行")
+            self.logger.info("SQL生成验证完成,准备执行")
             return state
             
         except Exception as e:
-            print(f"[ERROR] SQL生成验证节点异常: {str(e)}")
+            self.logger.error(f"SQL生成验证节点异常: {str(e)}")
             import traceback
-            print(f"[ERROR] 详细错误信息: {traceback.format_exc()}")
+            self.logger.error(f"详细错误信息: {traceback.format_exc()}")
             state["sql_generation_success"] = False
             state["sql_validation_success"] = False
             state["user_prompt"] = f"SQL生成验证异常: {str(e)}"
@@ -389,13 +393,13 @@ class CituLangGraphAgent:
     def _agent_sql_execution_node(self, state: AgentState) -> AgentState:
         """SQL执行节点 - 负责执行已验证的SQL和生成摘要"""
         try:
-            print(f"[SQL_EXECUTION] 开始执行SQL: {state.get('sql', 'N/A')}")
+            self.logger.info(f"开始执行SQL: {state.get('sql', 'N/A')}")
             
             sql = state.get("sql")
             question = state["question"]
             
             if not sql:
-                print(f"[SQL_EXECUTION] 没有可执行的SQL")
+                self.logger.warning("没有可执行的SQL")
                 state["error"] = "没有可执行的SQL语句"
                 state["error_code"] = 500
                 state["current_step"] = "sql_execution_error"
@@ -403,11 +407,11 @@ class CituLangGraphAgent:
                 return state
             
             # 步骤1:执行SQL
-            print(f"[SQL_EXECUTION] 步骤1:执行SQL")
+            self.logger.info("步骤1:执行SQL")
             execute_result = execute_sql.invoke({"sql": sql})
             
             if not execute_result.get("success"):
-                print(f"[SQL_EXECUTION] SQL执行失败: {execute_result.get('error')}")
+                self.logger.error(f"SQL执行失败: {execute_result.get('error')}")
                 state["error"] = execute_result.get("error", "SQL执行失败")
                 state["error_code"] = 500
                 state["current_step"] = "sql_execution_error"
@@ -416,15 +420,15 @@ class CituLangGraphAgent:
             
             query_result = execute_result.get("data_result")
             state["query_result"] = query_result
-            print(f"[SQL_EXECUTION] SQL执行成功,返回 {query_result.get('row_count', 0)} 行数据")
+            self.logger.info(f"SQL执行成功,返回 {query_result.get('row_count', 0)} 行数据")
             
             # 步骤2:生成摘要(根据配置和数据情况)
             if ENABLE_RESULT_SUMMARY and query_result.get('row_count', 0) > 0:
-                print(f"[SQL_EXECUTION] 步骤2:生成摘要")
+                self.logger.info("步骤2:生成摘要")
                 
                 # 重要:提取原始问题用于摘要生成,避免历史记录循环嵌套
                 original_question = self._extract_original_question(question)
-                print(f"[SQL_EXECUTION] 原始问题: {original_question}")
+                self.logger.debug(f"原始问题: {original_question}")
                 
                 summary_result = generate_summary.invoke({
                     "question": original_question,  # 使用原始问题而不是enhanced_question
@@ -433,26 +437,26 @@ class CituLangGraphAgent:
                 })
                 
                 if not summary_result.get("success"):
-                    print(f"[SQL_EXECUTION] 摘要生成失败: {summary_result.get('message')}")
+                    self.logger.warning(f"摘要生成失败: {summary_result.get('message')}")
                     # 摘要生成失败不是致命错误,使用默认摘要
                     state["summary"] = f"查询执行完成,共返回 {query_result.get('row_count', 0)} 条记录。"
                 else:
                     state["summary"] = summary_result.get("summary")
-                    print(f"[SQL_EXECUTION] 摘要生成成功")
+                    self.logger.info("摘要生成成功")
             else:
-                print(f"[SQL_EXECUTION] 跳过摘要生成(ENABLE_RESULT_SUMMARY={ENABLE_RESULT_SUMMARY},数据行数={query_result.get('row_count', 0)})")
+                self.logger.info(f"跳过摘要生成(ENABLE_RESULT_SUMMARY={ENABLE_RESULT_SUMMARY},数据行数={query_result.get('row_count', 0)})")
                 # 不生成摘要时,不设置summary字段,让格式化响应节点决定如何处理
             
             state["current_step"] = "sql_execution_completed"
             state["execution_path"].append("agent_sql_execution")
             
-            print(f"[SQL_EXECUTION] SQL执行完成")
+            self.logger.info("SQL执行完成")
             return state
             
         except Exception as e:
-            print(f"[ERROR] SQL执行节点异常: {str(e)}")
+            self.logger.error(f"SQL执行节点异常: {str(e)}")
             import traceback
-            print(f"[ERROR] 详细错误信息: {traceback.format_exc()}")
+            self.logger.error(f"详细错误信息: {traceback.format_exc()}")
             state["error"] = f"SQL执行失败: {str(e)}"
             state["error_code"] = 500
             state["current_step"] = "sql_execution_error"
@@ -467,17 +471,17 @@ class CituLangGraphAgent:
         保留此方法仅为向后兼容,新的工作流使用拆分后的节点
         """
         try:
-            print(f"[DATABASE_AGENT] ⚠️  使用已废弃的database节点,建议使用新的拆分节点")
-            print(f"[DATABASE_AGENT] 开始处理数据库查询: {state['question']}")
+            self.logger.warning("使用已废弃的database节点,建议使用新的拆分节点")
+            self.logger.info(f"开始处理数据库查询: {state['question']}")
             
             question = state["question"]
             
             # 步骤1:生成SQL
-            print(f"[DATABASE_AGENT] 步骤1:生成SQL")
+            self.logger.info("步骤1:生成SQL")
             sql_result = generate_sql.invoke({"question": question, "allow_llm_to_see_data": True})
             
             if not sql_result.get("success"):
-                print(f"[DATABASE_AGENT] SQL生成失败: {sql_result.get('error')}")
+                self.logger.error(f"SQL生成失败: {sql_result.get('error')}")
                 state["error"] = sql_result.get("error", "SQL生成失败")
                 state["error_code"] = 500
                 state["current_step"] = "database_error"
@@ -486,7 +490,7 @@ class CituLangGraphAgent:
             
             sql = sql_result.get("sql")
             state["sql"] = sql
-            print(f"[DATABASE_AGENT] SQL生成成功: {sql}")
+            self.logger.info(f"SQL生成成功: {sql}")
             
             # 步骤1.5:检查是否为解释性响应而非SQL
             error_type = sql_result.get("error_type")
@@ -496,7 +500,7 @@ class CituLangGraphAgent:
                 state["chat_response"] = explanation + " 请尝试提问其它问题。"
                 state["current_step"] = "database_completed"
                 state["execution_path"].append("agent_database")
-                print(f"[DATABASE_AGENT] 返回LLM解释性答案: {explanation}")
+                self.logger.info(f"返回LLM解释性答案: {explanation}")
                 return state
             
             # 额外验证:检查SQL格式(防止工具误判)
@@ -506,15 +510,15 @@ class CituLangGraphAgent:
                 state["chat_response"] = sql + " 请尝试提问其它问题。"
                 state["current_step"] = "database_completed"  
                 state["execution_path"].append("agent_database")
-                print(f"[DATABASE_AGENT] 内容不是有效SQL,当作解释返回: {sql}")
+                self.logger.info(f"内容不是有效SQL,当作解释返回: {sql}")
                 return state
             
             # 步骤2:执行SQL
-            print(f"[DATABASE_AGENT] 步骤2:执行SQL")
+            self.logger.info("步骤2:执行SQL")
             execute_result = execute_sql.invoke({"sql": sql})
             
             if not execute_result.get("success"):
-                print(f"[DATABASE_AGENT] SQL执行失败: {execute_result.get('error')}")
+                self.logger.error(f"SQL执行失败: {execute_result.get('error')}")
                 state["error"] = execute_result.get("error", "SQL执行失败")
                 state["error_code"] = 500
                 state["current_step"] = "database_error"
@@ -523,15 +527,15 @@ class CituLangGraphAgent:
             
             query_result = execute_result.get("data_result")
             state["query_result"] = query_result
-            print(f"[DATABASE_AGENT] SQL执行成功,返回 {query_result.get('row_count', 0)} 行数据")
+            self.logger.info(f"SQL执行成功,返回 {query_result.get('row_count', 0)} 行数据")
             
             # 步骤3:生成摘要(可通过配置控制,仅在有数据时生成)
             if ENABLE_RESULT_SUMMARY and query_result.get('row_count', 0) > 0:
-                print(f"[DATABASE_AGENT] 步骤3:生成摘要")
+                self.logger.info("步骤3:生成摘要")
                 
                 # 重要:提取原始问题用于摘要生成,避免历史记录循环嵌套
                 original_question = self._extract_original_question(question)
-                print(f"[DATABASE_AGENT] 原始问题: {original_question}")
+                self.logger.debug(f"原始问题: {original_question}")
                 
                 summary_result = generate_summary.invoke({
                     "question": original_question,  # 使用原始问题而不是enhanced_question
@@ -540,26 +544,26 @@ class CituLangGraphAgent:
                 })
                 
                 if not summary_result.get("success"):
-                    print(f"[DATABASE_AGENT] 摘要生成失败: {summary_result.get('message')}")
+                    self.logger.warning(f"摘要生成失败: {summary_result.get('message')}")
                     # 摘要生成失败不是致命错误,使用默认摘要
                     state["summary"] = f"查询执行完成,共返回 {query_result.get('row_count', 0)} 条记录。"
                 else:
                     state["summary"] = summary_result.get("summary")
-                    print(f"[DATABASE_AGENT] 摘要生成成功")
+                    self.logger.info("摘要生成成功")
             else:
-                print(f"[DATABASE_AGENT] 跳过摘要生成(ENABLE_RESULT_SUMMARY={ENABLE_RESULT_SUMMARY},数据行数={query_result.get('row_count', 0)})")
+                self.logger.info(f"跳过摘要生成(ENABLE_RESULT_SUMMARY={ENABLE_RESULT_SUMMARY},数据行数={query_result.get('row_count', 0)})")
                 # 不生成摘要时,不设置summary字段,让格式化响应节点决定如何处理
             
             state["current_step"] = "database_completed"
             state["execution_path"].append("agent_database")
             
-            print(f"[DATABASE_AGENT] 数据库查询完成")
+            self.logger.info("数据库查询完成")
             return state
             
         except Exception as e:
-            print(f"[ERROR] 数据库Agent异常: {str(e)}")
+            self.logger.error(f"数据库Agent异常: {str(e)}")
             import traceback
-            print(f"[ERROR] 详细错误信息: {traceback.format_exc()}")
+            self.logger.error(f"详细错误信息: {traceback.format_exc()}")
             state["error"] = f"数据库查询失败: {str(e)}"
             state["error_code"] = 500
             state["current_step"] = "database_error"
@@ -569,7 +573,7 @@ class CituLangGraphAgent:
     def _agent_chat_node(self, state: AgentState) -> AgentState:
         """聊天Agent节点 - 直接工具调用模式"""
         try:
-            print(f"[CHAT_AGENT] 开始处理聊天: {state['question']}")
+            self.logger.info(f"开始处理聊天: {state['question']}")
             
             question = state["question"]
             
@@ -584,7 +588,7 @@ class CituLangGraphAgent:
                 pass
             
             # 直接调用general_chat工具
-            print(f"[CHAT_AGENT] 调用general_chat工具")
+            self.logger.info("调用general_chat工具")
             chat_result = general_chat.invoke({
                 "question": question,
                 "context": context
@@ -592,22 +596,22 @@ class CituLangGraphAgent:
             
             if chat_result.get("success"):
                 state["chat_response"] = chat_result.get("response", "")
-                print(f"[CHAT_AGENT] 聊天处理成功")
+                self.logger.info("聊天处理成功")
             else:
                 # 处理失败,使用备用响应
                 state["chat_response"] = chat_result.get("response", "抱歉,我暂时无法处理您的问题。请稍后再试。")
-                print(f"[CHAT_AGENT] 聊天处理失败,使用备用响应: {chat_result.get('error')}")
+                self.logger.warning(f"聊天处理失败,使用备用响应: {chat_result.get('error')}")
             
             state["current_step"] = "chat_completed"
             state["execution_path"].append("agent_chat")
             
-            print(f"[CHAT_AGENT] 聊天处理完成")
+            self.logger.info("聊天处理完成")
             return state
             
         except Exception as e:
-            print(f"[ERROR] 聊天Agent异常: {str(e)}")
+            self.logger.error(f"聊天Agent异常: {str(e)}")
             import traceback
-            print(f"[ERROR] 详细错误信息: {traceback.format_exc()}")
+            self.logger.error(f"详细错误信息: {traceback.format_exc()}")
             state["chat_response"] = "抱歉,我暂时无法处理您的问题。请稍后再试,或者尝试询问数据相关的问题。"
             state["current_step"] = "chat_error"
             state["execution_path"].append("agent_chat_error")
@@ -616,7 +620,7 @@ class CituLangGraphAgent:
     def _format_response_node(self, state: AgentState) -> AgentState:
         """格式化最终响应节点"""
         try:
-            print(f"[FORMAT_NODE] 开始格式化响应,问题类型: {state['question_type']}")
+            self.logger.info(f"开始格式化响应,问题类型: {state['question_type']}")
             
             state["current_step"] = "completed"
             state["execution_path"].append("format_response")
@@ -737,11 +741,11 @@ class CituLangGraphAgent:
                     }
                 }
             
-            print(f"[FORMAT_NODE] 响应格式化完成")
+            self.logger.info("响应格式化完成")
             return state
             
         except Exception as e:
-            print(f"[ERROR] 响应格式化异常: {str(e)}")
+            self.logger.error(f"响应格式化异常: {str(e)}")
             state["final_response"] = {
                 "success": False,
                 "error": f"响应格式化异常: {str(e)}",
@@ -760,7 +764,7 @@ class CituLangGraphAgent:
         """
         sql_generation_success = state.get("sql_generation_success", False)
         
-        print(f"[ROUTE] SQL生成路由: success={sql_generation_success}")
+        self.logger.debug(f"SQL生成路由: success={sql_generation_success}")
         
         if sql_generation_success:
             return "continue_execution"  # 路由到SQL执行节点
@@ -780,7 +784,7 @@ class CituLangGraphAgent:
         question_type = state["question_type"]
         confidence = state["classification_confidence"]
         
-        print(f"[ROUTE] 分类路由: {question_type}, 置信度: {confidence} (完全信任分类器决策)")
+        self.logger.debug(f"分类路由: {question_type}, 置信度: {confidence} (完全信任分类器决策)")
         
         if question_type == "DATABASE":
             return "DATABASE"
@@ -803,11 +807,11 @@ class CituLangGraphAgent:
             Dict包含完整的处理结果
         """
         try:
-            print(f"[CITU_AGENT] 开始处理问题: {question}")
+            self.logger.info(f"开始处理问题: {question}")
             if context_type:
-                print(f"[CITU_AGENT] 上下文类型: {context_type}")
+                self.logger.info(f"上下文类型: {context_type}")
             if routing_mode:
-                print(f"[CITU_AGENT] 使用指定路由模式: {routing_mode}")
+                self.logger.info(f"使用指定路由模式: {routing_mode}")
             
             # 动态创建workflow(基于路由模式)
             workflow = self._create_workflow(routing_mode)
@@ -826,12 +830,12 @@ class CituLangGraphAgent:
             # 提取最终结果
             result = final_state["final_response"]
             
-            print(f"[CITU_AGENT] 问题处理完成: {result.get('success', False)}")
+            self.logger.info(f"问题处理完成: {result.get('success', False)}")
             
             return result
             
         except Exception as e:
-            print(f"[ERROR] Agent执行异常: {str(e)}")
+            self.logger.error(f"Agent执行异常: {str(e)}")
             return {
                 "success": False,
                 "error": f"Agent系统异常: {str(e)}",
@@ -1127,7 +1131,7 @@ class CituLangGraphAgent:
             return question.strip()
             
         except Exception as e:
-            print(f"[WARNING] 提取原始问题失败: {str(e)}")
+            self.logger.warning(f"提取原始问题失败: {str(e)}")
             return question.strip()
 
     async def health_check(self) -> Dict[str, Any]:

+ 24 - 18
agent/classifier.py

@@ -2,6 +2,7 @@
 import re
 from typing import Dict, Any, List, Optional
 from dataclasses import dataclass
+from core.logging import get_agent_logger
 
 @dataclass
 class ClassificationResult:
@@ -16,6 +17,9 @@ class QuestionClassifier:
     """
     
     def __init__(self):
+        # 初始化日志
+        self.logger = get_agent_logger("Classifier")
+        
         # 从配置文件加载阈值参数
         try:
             from agent.config import get_current_config, get_nested_config
@@ -27,7 +31,8 @@ class QuestionClassifier:
             self.confidence_increment = get_nested_config(config, "classification.confidence_increment", 0.08)
             self.llm_fallback_confidence = get_nested_config(config, "classification.llm_fallback_confidence", 0.5)
             self.uncertain_confidence = get_nested_config(config, "classification.uncertain_confidence", 0.2)
-            print("[CLASSIFIER] 从配置文件加载分类器参数完成")
+            self.medium_confidence_threshold = get_nested_config(config, "classification.medium_confidence_threshold", 0.6)
+            self.logger.info("从配置文件加载分类器参数完成")
         except ImportError:
             self.high_confidence_threshold = 0.7
             self.low_confidence_threshold = 0.4
@@ -36,7 +41,8 @@ class QuestionClassifier:
             self.confidence_increment = 0.08
             self.llm_fallback_confidence = 0.5
             self.uncertain_confidence = 0.2
-            print("[CLASSIFIER] 配置文件不可用,使用默认分类器参数")
+            self.medium_confidence_threshold = 0.6
+            self.logger.warning("配置文件不可用,使用默认分类器参数")
         
         # 基于高速公路服务区业务的精准关键词
         self.strong_business_keywords = {
@@ -159,14 +165,14 @@ class QuestionClassifier:
         # 确定使用的路由模式
         if routing_mode:
             QUESTION_ROUTING_MODE = routing_mode
-            print(f"[CLASSIFIER] 使用传入的路由模式: {QUESTION_ROUTING_MODE}")
+            self.logger.info(f"使用传入的路由模式: {QUESTION_ROUTING_MODE}")
         else:
             try:
                 from app_config import QUESTION_ROUTING_MODE
-                print(f"[CLASSIFIER] 使用配置文件路由模式: {QUESTION_ROUTING_MODE}")
+                self.logger.info(f"使用配置文件路由模式: {QUESTION_ROUTING_MODE}")
             except ImportError:
                 QUESTION_ROUTING_MODE = "hybrid"
-                print(f"[CLASSIFIER] 配置导入失败,使用默认路由模式: {QUESTION_ROUTING_MODE}")
+                self.logger.info(f"配置导入失败,使用默认路由模式: {QUESTION_ROUTING_MODE}")
         
         # 根据路由模式选择分类策略
         if QUESTION_ROUTING_MODE == "database_direct":
@@ -196,36 +202,36 @@ class QuestionClassifier:
         2. 如果置信度不够且有上下文,考虑上下文辅助
         3. 检测话题切换,避免错误继承
         """
-        print(f"[CLASSIFIER] 渐进式分类 - 问题: {question}")
+        self.logger.info(f"渐进式分类 - 问题: {question}")
         if context_type:
-            print(f"[CLASSIFIER] 上下文类型: {context_type}")
+            self.logger.info(f"上下文类型: {context_type}")
         
         # 第一步:只基于问题本身分类
         primary_result = self._hybrid_classify(question)
-        print(f"[CLASSIFIER] 主分类结果: {primary_result.question_type}, 置信度: {primary_result.confidence}")
+        self.logger.info(f"主分类结果: {primary_result.question_type}, 置信度: {primary_result.confidence}")
         
         # 如果没有上下文,直接返回主分类结果
         if not context_type:
-            print(f"[CLASSIFIER] 无上下文,使用主分类结果")
+            self.logger.debug("无上下文,使用主分类结果")
             return primary_result
         
         # 如果置信度足够高,直接使用主分类结果
         if primary_result.confidence >= self.high_confidence_threshold:
-            print(f"[CLASSIFIER] 高置信度({primary_result.confidence}≥{self.high_confidence_threshold}),使用主分类结果")
+            self.logger.info(f"高置信度({primary_result.confidence}≥{self.high_confidence_threshold}),使用主分类结果")
             return primary_result
         
         # 检测明显的话题切换
         if self._is_topic_switch(question):
-            print(f"[CLASSIFIER] 检测到话题切换,忽略上下文")
+            self.logger.info("检测到话题切换,忽略上下文")
             return primary_result
         
         # 如果置信度较低,考虑上下文辅助
         if primary_result.confidence < self.medium_confidence_threshold:
-            print(f"[CLASSIFIER] 低置信度({primary_result.confidence}<{self.medium_confidence_threshold}),考虑上下文辅助")
+            self.logger.info(f"低置信度({primary_result.confidence}<{self.medium_confidence_threshold}),考虑上下文辅助")
             
             # 检测是否为追问型问题
             if self._is_follow_up_question(question):
-                print(f"[CLASSIFIER] 检测到追问型问题,继承上下文类型: {context_type}")
+                self.logger.info(f"检测到追问型问题,继承上下文类型: {context_type}")
                 return ClassificationResult(
                     question_type=context_type,
                     confidence=0.75,  # 给予中等置信度
@@ -234,7 +240,7 @@ class QuestionClassifier:
                 )
         
         # 中等置信度或其他情况,保持主分类结果
-        print(f"[CLASSIFIER] 保持主分类结果")
+        self.logger.debug("保持主分类结果")
         return primary_result
 
     def _is_follow_up_question(self, question: str) -> bool:
@@ -426,11 +432,11 @@ class QuestionClassifier:
             
         except FileNotFoundError:
             error_msg = f"无法找到业务上下文文件: {prompt_file}"
-            print(f"[ERROR] {error_msg}")
+            self.logger.error(error_msg)
             raise FileNotFoundError(error_msg)
         except Exception as e:
             error_msg = f"读取业务上下文文件失败: {str(e)}"
-            print(f"[ERROR] {error_msg}")
+            self.logger.error(error_msg)
             raise RuntimeError(error_msg)
 
     def _enhanced_llm_classify(self, question: str) -> ClassificationResult:
@@ -506,7 +512,7 @@ class QuestionClassifier:
             
         except (FileNotFoundError, RuntimeError) as e:
             # 业务上下文加载失败,返回错误状态
-            print(f"[ERROR] LLM分类失败,业务上下文不可用: {str(e)}")
+            self.logger.error(f"LLM分类失败,业务上下文不可用: {str(e)}")
             return ClassificationResult(
                 question_type="CHAT",  # 失败时默认为CHAT,更安全
                 confidence=0.1,  # 很低的置信度表示分类不可靠
@@ -514,7 +520,7 @@ class QuestionClassifier:
                 method="llm_context_error"
             )
         except Exception as e:
-            print(f"[WARNING] 增强LLM分类失败: {str(e)}")
+            self.logger.warning(f"增强LLM分类失败: {str(e)}")
             return ClassificationResult(
                 question_type="CHAT",  # 失败时默认为CHAT,更安全
                 confidence=self.llm_fallback_confidence,

+ 7 - 3
agent/tools/general_chat.py

@@ -2,6 +2,10 @@
 from langchain.tools import tool
 from typing import Dict, Any, Optional
 from common.vanna_instance import get_vanna_instance
+from core.logging import get_agent_logger
+
+# Initialize logger
+logger = get_agent_logger("GeneralChat")
 
 @tool
 def general_chat(question: str, context: Optional[str] = None) -> Dict[str, Any]:
@@ -21,7 +25,7 @@ def general_chat(question: str, context: Optional[str] = None) -> Dict[str, Any]
         }
     """
     try:
-        print(f"[TOOL:general_chat] 处理聊天问题: {question}")
+        logger.info(f"处理聊天问题: {question}")
         
         system_prompt = """
 你是Citu智能数据问答平台的AI助手,为用户提供全面的帮助和支持。
@@ -58,7 +62,7 @@ def general_chat(question: str, context: Optional[str] = None) -> Dict[str, Any]
         )
         
         if response:
-            print(f"[TOOL:general_chat] 聊天响应生成成功: {response[:100]}...")
+            logger.info(f"聊天响应生成成功: {response[:100]}...")
             return {
                 "success": True,
                 "response": response.strip(),
@@ -72,7 +76,7 @@ def general_chat(question: str, context: Optional[str] = None) -> Dict[str, Any]
             }
             
     except Exception as e:
-        print(f"[ERROR] 通用聊天异常: {str(e)}")
+        logger.error(f"通用聊天异常: {str(e)}")
         return {
             "success": False,
             "response": _get_fallback_response(question),

+ 10 - 6
agent/tools/sql_execution.py

@@ -6,6 +6,10 @@ import time
 import functools
 from common.vanna_instance import get_vanna_instance
 from app_config import API_MAX_RETURN_ROWS
+from core.logging import get_agent_logger
+
+# Initialize logger
+logger = get_agent_logger("SQLExecution")
 
 def retry_on_failure(max_retries: int = 2, delay: float = 1.0, backoff_factor: float = 2.0):
     """
@@ -29,7 +33,7 @@ def retry_on_failure(max_retries: int = 2, delay: float = 1.0, backoff_factor: f
                         if retries < max_retries:
                             retries += 1
                             wait_time = delay * (backoff_factor ** (retries - 1))
-                            print(f"[RETRY] {func.__name__} 执行失败,等待 {wait_time:.1f} 秒后重试 ({retries}/{max_retries})")
+                            logger.warning(f"{func.__name__} 执行失败,等待 {wait_time:.1f} 秒后重试 ({retries}/{max_retries})")
                             time.sleep(wait_time)
                             continue
                     
@@ -39,10 +43,10 @@ def retry_on_failure(max_retries: int = 2, delay: float = 1.0, backoff_factor: f
                     retries += 1
                     if retries <= max_retries:
                         wait_time = delay * (backoff_factor ** (retries - 1))
-                        print(f"[RETRY] {func.__name__} 异常: {str(e)}, 等待 {wait_time:.1f} 秒后重试 ({retries}/{max_retries})")
+                        logger.warning(f"{func.__name__} 异常: {str(e)}, 等待 {wait_time:.1f} 秒后重试 ({retries}/{max_retries})")
                         time.sleep(wait_time)
                     else:
-                        print(f"[RETRY] {func.__name__} 达到最大重试次数 ({max_retries}),抛出异常")
+                        logger.error(f"{func.__name__} 达到最大重试次数 ({max_retries}),抛出异常")
                         raise
             
             # 不应该到达这里,但为了安全性
@@ -75,7 +79,7 @@ def execute_sql(sql: str, max_rows: int = None) -> Dict[str, Any]:
     if max_rows is None:
         max_rows = API_MAX_RETURN_ROWS if API_MAX_RETURN_ROWS is not None else DEFAULT_MAX_RETURN_ROWS
     try:
-        print(f"[TOOL:execute_sql] 开始执行SQL: {sql[:100]}...")
+        logger.info(f"开始执行SQL: {sql[:100]}...")
         
         vn = get_vanna_instance()
         df = vn.run_sql(sql)
@@ -118,7 +122,7 @@ def execute_sql(sql: str, max_rows: int = None) -> Dict[str, Any]:
         rows = _process_dataframe_rows(limited_df.to_dict(orient="records"))
         columns = list(df.columns)
         
-        print(f"[TOOL:execute_sql] 查询成功,返回 {len(rows)} 行数据")
+        logger.info(f"查询成功,返回 {len(rows)} 行数据")
         
         result = {
             "success": True,
@@ -139,7 +143,7 @@ def execute_sql(sql: str, max_rows: int = None) -> Dict[str, Any]:
         
     except Exception as e:
         error_msg = str(e)
-        print(f"[ERROR] SQL执行异常: {error_msg}")
+        logger.error(f"SQL执行异常: {error_msg}")
         
         return {
             "success": False,

+ 7 - 3
agent/tools/sql_generation.py

@@ -2,6 +2,10 @@
 from langchain.tools import tool
 from typing import Dict, Any
 from common.vanna_instance import get_vanna_instance
+from core.logging import get_agent_logger
+
+# Initialize logger
+logger = get_agent_logger("SQLGeneration")
 
 @tool
 def generate_sql(question: str, allow_llm_to_see_data: bool = True) -> Dict[str, Any]:
@@ -22,7 +26,7 @@ def generate_sql(question: str, allow_llm_to_see_data: bool = True) -> Dict[str,
         }
     """
     try:
-        print(f"[TOOL:generate_sql] 开始生成SQL: {question}")
+        logger.info(f"开始生成SQL: {question}")
         
         vn = get_vanna_instance()
         sql = vn.generate_sql(question=question, allow_llm_to_see_data=allow_llm_to_see_data)
@@ -58,7 +62,7 @@ def generate_sql(question: str, allow_llm_to_see_data: bool = True) -> Dict[str,
                 "can_retry": True
             }
         
-        print(f"[TOOL:generate_sql] 成功生成SQL: {sql}")
+        logger.info(f"成功生成SQL: {sql}")
         return {
             "success": True,
             "sql": sql,
@@ -67,7 +71,7 @@ def generate_sql(question: str, allow_llm_to_see_data: bool = True) -> Dict[str,
         }
         
     except Exception as e:
-        print(f"[ERROR] SQL生成异常: {str(e)}")
+        logger.error(f"SQL生成异常: {str(e)}")
         return {
             "success": False,
             "sql": None,

+ 8 - 4
agent/tools/summary_generation.py

@@ -3,6 +3,10 @@ from langchain.tools import tool
 from typing import Dict, Any
 import pandas as pd
 from common.vanna_instance import get_vanna_instance
+from core.logging import get_agent_logger
+
+# Initialize logger
+logger = get_agent_logger("SummaryGeneration")
 
 @tool
 def generate_summary(question: str, query_result: Dict[str, Any], sql: str) -> Dict[str, Any]:
@@ -23,7 +27,7 @@ def generate_summary(question: str, query_result: Dict[str, Any], sql: str) -> D
         }
     """
     try:
-        print(f"[TOOL:generate_summary] 开始生成摘要,问题: {question}")
+        logger.info(f"开始生成摘要,问题: {question}")
         
         if not query_result or not query_result.get("rows"):
             return {
@@ -50,7 +54,7 @@ def generate_summary(question: str, query_result: Dict[str, Any], sql: str) -> D
             # 生成默认摘要
             summary = _generate_default_summary(question, query_result, sql)
         
-        print(f"[TOOL:generate_summary] 摘要生成成功: {summary[:100]}...")
+        logger.info(f"摘要生成成功: {summary[:100]}...")
         
         return {
             "success": True,
@@ -59,7 +63,7 @@ def generate_summary(question: str, query_result: Dict[str, Any], sql: str) -> D
         }
         
     except Exception as e:
-        print(f"[ERROR] 摘要生成异常: {str(e)}")
+        logger.error(f"摘要生成异常: {str(e)}")
         
         # 生成备用摘要
         fallback_summary = _generate_fallback_summary(question, query_result, sql)
@@ -82,7 +86,7 @@ def _reconstruct_dataframe(query_result: Dict[str, Any]) -> pd.DataFrame:
         return pd.DataFrame(rows, columns=columns)
         
     except Exception as e:
-        print(f"[WARNING] DataFrame重构失败: {str(e)}")
+        logger.warning(f"DataFrame重构失败: {str(e)}")
         return pd.DataFrame()
 
 def _generate_default_summary(question: str, query_result: Dict[str, Any], sql: str) -> str:

+ 12 - 8
agent/tools/utils.py

@@ -7,6 +7,10 @@ import json
 from typing import Dict, Any, Callable, List, Optional
 from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage, ToolMessage
 from langchain_core.tools import BaseTool
+from core.logging import get_agent_logger
+
+# Initialize logger
+logger = get_agent_logger("AgentUtils")
 
 def handle_tool_errors(func: Callable) -> Callable:
     """
@@ -17,7 +21,7 @@ def handle_tool_errors(func: Callable) -> Callable:
         try:
             return func(*args, **kwargs)
         except Exception as e:
-            print(f"[ERROR] 工具 {func.__name__} 执行失败: {str(e)}")
+            logger.error(f"工具 {func.__name__} 执行失败: {str(e)}")
             return {
                 "success": False,
                 "error": f"工具执行异常: {str(e)}",
@@ -50,7 +54,7 @@ class LLMWrapper:
                 return self._invoke_without_tools(messages, **kwargs)
                 
         except Exception as e:
-            print(f"[ERROR] LLM包装器调用失败: {str(e)}")
+            logger.error(f"LLM包装器调用失败: {str(e)}")
             return AIMessage(content=f"LLM调用失败: {str(e)}")
     
     def _should_use_tools(self, messages: List[BaseMessage]) -> bool:
@@ -88,7 +92,7 @@ class LLMWrapper:
                 return AIMessage(content=response)
                 
         except Exception as e:
-            print(f"[ERROR] 工具调用失败: {str(e)}")
+            logger.error(f"工具调用失败: {str(e)}")
             return self._invoke_without_tools(messages, **kwargs)
     
     def _invoke_without_tools(self, messages: List[BaseMessage], **kwargs):
@@ -206,26 +210,26 @@ def get_compatible_llm():
                     model=llm_config.get("model"),
                     temperature=llm_config.get("temperature", 0.7)
                 )
-                print("[INFO] 使用标准OpenAI兼容API")
+                logger.info("使用标准OpenAI兼容API")
                 return llm
             except ImportError:
-                print("[WARNING] langchain_openai 未安装,使用 Vanna 实例包装器")
+                logger.warning("langchain_openai 未安装,使用 Vanna 实例包装器")
         
         # 优先使用统一的 Vanna 实例
         from common.vanna_instance import get_vanna_instance
         vn = get_vanna_instance()
-        print("[INFO] 使用Vanna实例包装器")
+        logger.info("使用Vanna实例包装器")
         return LLMWrapper(vn)
         
     except Exception as e:
-        print(f"[ERROR] 获取 Vanna 实例失败: {str(e)}")
+        logger.error(f"获取 Vanna 实例失败: {str(e)}")
         # 回退到原有逻辑
         from common.utils import get_current_llm_config
         from customllm.qianwen_chat import QianWenChat
         
         llm_config = get_current_llm_config()
         custom_llm = QianWenChat(config=llm_config)
-        print("[INFO] 使用QianWen包装器")
+        logger.info("使用QianWen包装器")
         return LLMWrapper(custom_llm)
 
 def _is_valid_sql_format(sql_text: str) -> bool:

+ 1 - 1
app_config.py

@@ -169,7 +169,7 @@ REDIS_PASSWORD = None
 
 # 缓存开关配置
 ENABLE_CONVERSATION_CONTEXT = True      # 是否启用对话上下文
-ENABLE_QUESTION_ANSWER_CACHE = True     # 是否启用问答结果缓存
+ENABLE_QUESTION_ANSWER_CACHE = False     # 是否启用问答结果缓存
 ENABLE_EMBEDDING_CACHE = True           # 是否启用embedding向量缓存
 
 # TTL配置(单位:秒)

+ 76 - 69
citu_app.py

@@ -1,4 +1,8 @@
 # 给dataops 对话助手返回结果
+# 初始化日志系统 - 必须在最前面
+from core.logging import initialize_logging, get_app_logger, set_log_context, clear_log_context
+initialize_logging()
+
 from vanna.flask import VannaFlaskApp
 from core.vanna_llm_factory import create_vanna_instance
 from flask import request, jsonify
@@ -31,6 +35,9 @@ from app_config import (  # 添加Redis相关配置导入
     ENABLE_QUESTION_ANSWER_CACHE
 )
 
+# 创建app logger
+logger = get_app_logger("CituApp")
+
 # 设置默认的最大返回行数
 DEFAULT_MAX_RETURN_ROWS = 200
 MAX_RETURN_ROWS = API_MAX_RETURN_ROWS if API_MAX_RETURN_ROWS is not None else DEFAULT_MAX_RETURN_ROWS
@@ -131,9 +138,9 @@ def ask_full():
                 if ENABLE_RESULT_SUMMARY:
                     try:
                         summary = vn.generate_summary(question=question, df=df)
-                        print(f"[INFO] 成功生成摘要: {summary}")
+                        logger.info(f"成功生成摘要: {summary}")
                     except Exception as e:
-                        print(f"[WARNING] 生成摘要失败: {str(e)}")
+                        logger.warning(f"生成摘要失败: {str(e)}")
                         summary = None
 
         # 构建返回数据
@@ -156,7 +163,7 @@ def ask_full():
         ))
         
     except Exception as e:
-        print(f"[ERROR] ask_full执行失败: {str(e)}")
+        logger.error(f"ask_full执行失败: {str(e)}")
         
         # 即使发生异常,也检查是否有业务层面的解释
         if hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
@@ -219,7 +226,7 @@ def citu_run_sql():
         ))
         
     except Exception as e:
-        print(f"[ERROR] citu_run_sql执行失败: {str(e)}")
+        logger.error(f"citu_run_sql执行失败: {str(e)}")
         from common.result import internal_error_response
         return jsonify(internal_error_response(
             response_text=f"SQL执行失败,请检查SQL语句是否正确"
@@ -245,27 +252,27 @@ def ask_cached():
     try:
         # 生成conversation_id
         # 调试:查看generate_id的实际行为
-        print(f"[DEBUG] 输入问题: '{question}'")
+        logger.debug(f"输入问题: '{question}'")
         conversation_id = app.cache.generate_id(question=question)
-        print(f"[DEBUG] 生成的conversation_id: {conversation_id}")
+        logger.debug(f"生成的conversation_id: {conversation_id}")
         
         # 再次用相同问题测试
         conversation_id2 = app.cache.generate_id(question=question)
-        print(f"[DEBUG] 再次生成的conversation_id: {conversation_id2}")
-        print(f"[DEBUG] 两次ID是否相同: {conversation_id == conversation_id2}")
+        logger.debug(f"再次生成的conversation_id: {conversation_id2}")
+        logger.debug(f"两次ID是否相同: {conversation_id == conversation_id2}")
         
         # 检查缓存
         cached_sql = app.cache.get(id=conversation_id, field="sql")
         
         if cached_sql is not None:
             # 缓存命中
-            print(f"[CACHE HIT] 使用缓存结果: {conversation_id}")
+            logger.info(f"[CACHE HIT] 使用缓存结果: {conversation_id}")
             sql = cached_sql
             df = app.cache.get(id=conversation_id, field="df")
             summary = app.cache.get(id=conversation_id, field="summary")
         else:
             # 缓存未命中,执行新查询
-            print(f"[CACHE MISS] 执行新查询: {conversation_id}")
+            logger.info(f"[CACHE MISS] 执行新查询: {conversation_id}")
             
             sql, df, _ = vn.ask(
                 question=question,
@@ -301,9 +308,9 @@ def ask_cached():
             if ENABLE_RESULT_SUMMARY and isinstance(df, pd.DataFrame) and not df.empty:
                 try:
                     summary = vn.generate_summary(question=question, df=df)
-                    print(f"[INFO] 成功生成摘要: {summary}")
+                    logger.info(f"成功生成摘要: {summary}")
                 except Exception as e:
-                    print(f"[WARNING] 生成摘要失败: {str(e)}")
+                    logger.warning(f"生成摘要失败: {str(e)}")
                     summary = None
             
             app.cache.set(id=conversation_id, field="summary", value=summary)
@@ -348,7 +355,7 @@ def ask_cached():
         ))
         
     except Exception as e:
-        print(f"[ERROR] ask_cached执行失败: {str(e)}")
+        logger.error(f"ask_cached执行失败: {str(e)}")
         from common.result import internal_error_response
         return jsonify(internal_error_response(
             response_text="查询处理失败,请稍后重试"
@@ -386,10 +393,10 @@ def citu_train_question_sql():
         # 正确的调用方式:同时传递question和sql
         if question:
             training_id = vn.train(question=question, sql=sql)
-            print(f"训练成功,训练ID为:{training_id},问题:{question},SQL:{sql}")
+            logger.info(f"训练成功,训练ID为:{training_id},问题:{question},SQL:{sql}")
         else:
             training_id = vn.train(sql=sql)
-            print(f"训练成功,训练ID为:{training_id},SQL:{sql}")
+            logger.info(f"训练成功,训练ID为:{training_id},SQL:{sql}")
 
         from common.result import success_response
         return jsonify(success_response(
@@ -418,23 +425,23 @@ def get_citu_langraph_agent():
     if citu_langraph_agent is None:
         try:
             from agent.citu_agent import CituLangGraphAgent
-            print("[CITU_APP] 开始创建LangGraph Agent实例...")
+            logger.info("开始创建LangGraph Agent实例...")
             citu_langraph_agent = CituLangGraphAgent()
-            print("[CITU_APP] LangGraph Agent实例创建成功")
+            logger.info("LangGraph Agent实例创建成功")
         except ImportError as e:
-            print(f"[CRITICAL] Agent模块导入失败: {str(e)}")
-            print("[CRITICAL] 请检查agent模块是否存在以及依赖是否正确安装")
+            logger.critical(f"Agent模块导入失败: {str(e)}")
+            logger.critical("请检查agent模块是否存在以及依赖是否正确安装")
             raise Exception(f"Agent模块导入失败: {str(e)}")
         except Exception as e:
-            print(f"[CRITICAL] LangGraph Agent实例创建失败: {str(e)}")
-            print(f"[CRITICAL] 错误类型: {type(e).__name__}")
+            logger.critical(f"LangGraph Agent实例创建失败: {str(e)}")
+            logger.critical(f"错误类型: {type(e).__name__}")
             # 提供更有用的错误信息
             if "config" in str(e).lower():
-                print("[CRITICAL] 可能是配置文件问题,请检查配置")
+                logger.critical("可能是配置文件问题,请检查配置")
             elif "llm" in str(e).lower():
-                print("[CRITICAL] 可能是LLM连接问题,请检查LLM配置")
+                logger.critical("可能是LLM连接问题,请检查LLM配置")
             elif "tool" in str(e).lower():
-                print("[CRITICAL] 可能是工具加载问题,请检查工具模块")
+                logger.critical("可能是工具加载问题,请检查工具模块")
             raise Exception(f"Agent初始化失败: {str(e)}")
     return citu_langraph_agent
 
@@ -495,15 +502,15 @@ def ask_agent():
                         metadata = message.get("metadata", {})
                         context_type = metadata.get("type")
                         if context_type:
-                            print(f"[AGENT_API] 检测到上下文类型: {context_type}")
+                            logger.info(f"[AGENT_API] 检测到上下文类型: {context_type}")
                             break
             except Exception as e:
-                print(f"[WARNING] 获取上下文类型失败: {str(e)}")
+                logger.warning(f"获取上下文类型失败: {str(e)}")
         
         # 4. 检查缓存(新逻辑:放宽使用条件,严控存储条件)
         cached_answer = redis_conversation_manager.get_cached_answer(question, context)
         if cached_answer:
-            print(f"[AGENT_API] 使用缓存答案")
+            logger.info(f"[AGENT_API] 使用缓存答案")
             
             # 确定缓存答案的助手回复内容(使用与非缓存相同的优先级逻辑)
             cached_response_type = cached_answer.get("type", "UNKNOWN")
@@ -567,31 +574,31 @@ def ask_agent():
         # 6. 构建带上下文的问题
         if context:
             enhanced_question = f"\n[CONTEXT]\n{context}\n\n[CURRENT]\n{question}"
-            print(f"[AGENT_API] 使用上下文,长度: {len(context)}字符")
+            logger.info(f"[AGENT_API] 使用上下文,长度: {len(context)}字符")
         else:
             enhanced_question = question
-            print(f"[AGENT_API] 新对话,无上下文")
+            logger.info(f"[AGENT_API] 新对话,无上下文")
         
         # 7. 确定最终使用的路由模式(优先级逻辑)
         if api_routing_mode:
             # API传了参数,优先使用
             effective_routing_mode = api_routing_mode
-            print(f"[AGENT_API] 使用API指定的路由模式: {effective_routing_mode}")
+            logger.info(f"[AGENT_API] 使用API指定的路由模式: {effective_routing_mode}")
         else:
             # API没传参数,使用配置文件
             try:
                 from app_config import QUESTION_ROUTING_MODE
                 effective_routing_mode = QUESTION_ROUTING_MODE
-                print(f"[AGENT_API] 使用配置文件路由模式: {effective_routing_mode}")
+                logger.info(f"[AGENT_API] 使用配置文件路由模式: {effective_routing_mode}")
             except ImportError:
                 effective_routing_mode = "hybrid"
-                print(f"[AGENT_API] 配置文件读取失败,使用默认路由模式: {effective_routing_mode}")
+                logger.info(f"[AGENT_API] 配置文件读取失败,使用默认路由模式: {effective_routing_mode}")
         
         # 8. 现有Agent处理逻辑(修改为传递路由模式)
         try:
             agent = get_citu_langraph_agent()
         except Exception as e:
-            print(f"[CRITICAL] Agent初始化失败: {str(e)}")
+            logger.critical(f"Agent初始化失败: {str(e)}")
             return jsonify(service_unavailable_response(
                 response_text="AI服务暂时不可用,请稍后重试",
                 can_retry=True
@@ -687,7 +694,7 @@ def ask_agent():
             )), error_code
         
     except Exception as e:
-        print(f"[ERROR] ask_agent执行失败: {str(e)}")
+        logger.error(f"ask_agent执行失败: {str(e)}")
         return jsonify(internal_error_response(
             response_text="查询处理失败,请稍后重试"
         )), 500
@@ -784,9 +791,9 @@ def agent_health():
                 health_data["status"] = "degraded"
                 health_data["message"] = "部分组件异常"
         except Exception as e:
-            print(f"[ERROR] 健康检查异常: {str(e)}")
+            logger.error(f"健康检查异常: {str(e)}")
             import traceback
-            print(f"[ERROR] 详细健康检查错误: {traceback.format_exc()}")
+            logger.error(f"详细健康检查错误: {traceback.format_exc()}")
             health_data["status"] = "degraded"
             health_data["message"] = f"完整测试失败: {str(e)}"
         
@@ -803,9 +810,9 @@ def agent_health():
             return jsonify(health_error_response(**health_data)), 503
             
     except Exception as e:
-        print(f"[ERROR] 顶层健康检查异常: {str(e)}")
+        logger.error(f"顶层健康检查异常: {str(e)}")
         import traceback
-        print(f"[ERROR] 详细错误信息: {traceback.format_exc()}")
+        logger.error(f"详细错误信息: {traceback.format_exc()}")
         from common.result import internal_error_response
         return jsonify(internal_error_response(
             response_text="健康检查失败,请稍后重试"
@@ -1517,7 +1524,7 @@ def training_error_question_sql():
         question = data.get('question')
         sql = data.get('sql')
         
-        print(f"[DEBUG] 接收到错误SQL训练请求: question={question}, sql={sql}")
+        logger.debug(f"接收到错误SQL训练请求: question={question}, sql={sql}")
         
         if not question or not sql:
             from common.result import bad_request_response
@@ -1535,7 +1542,7 @@ def training_error_question_sql():
         # 使用vn实例的train_error_sql方法存储错误SQL
         id = vn.train_error_sql(question=question, sql=sql)
         
-        print(f"[INFO] 成功存储错误SQL,ID: {id}")
+        logger.info(f"成功存储错误SQL,ID: {id}")
         
         from common.result import success_response
         return jsonify(success_response(
@@ -1547,7 +1554,7 @@ def training_error_question_sql():
         ))
         
     except Exception as e:
-        print(f"[ERROR] 存储错误SQL失败: {str(e)}")
+        logger.error(f"存储错误SQL失败: {str(e)}")
         from common.result import internal_error_response
         return jsonify(internal_error_response(
             response_text="存储错误SQL失败,请稍后重试"
@@ -1593,7 +1600,7 @@ def get_user_conversations(user_id: str):
                     conversation['conversation_title'] = "空对话"
                     
             except Exception as e:
-                print(f"[WARNING] 获取对话标题失败 {conversation_id}: {str(e)}")
+                logger.warning(f"获取对话标题失败 {conversation_id}: {str(e)}")
                 conversation['conversation_title'] = "对话"
         
         return jsonify(success_response(
@@ -1747,7 +1754,7 @@ def get_user_conversations_with_messages(user_id: str):
         ))
         
     except Exception as e:
-        print(f"[ERROR] 获取用户完整对话数据失败: {str(e)}")
+        logger.error(f"获取用户完整对话数据失败: {str(e)}")
         return jsonify(internal_error_response(
             response_text="获取用户对话数据失败,请稍后重试"
         )), 500
@@ -1770,7 +1777,7 @@ def embedding_cache_stats():
         ))
         
     except Exception as e:
-        print(f"[ERROR] 获取embedding缓存统计失败: {str(e)}")
+        logger.error(f"获取embedding缓存统计失败: {str(e)}")
         return jsonify(internal_error_response(
             response_text="获取embedding缓存统计失败,请稍后重试"
         )), 500
@@ -1801,7 +1808,7 @@ def embedding_cache_cleanup():
             )), 500
         
     except Exception as e:
-        print(f"[ERROR] 清空embedding缓存失败: {str(e)}")
+        logger.error(f"清空embedding缓存失败: {str(e)}")
         return jsonify(internal_error_response(
             response_text="清空embedding缓存失败,请稍后重试"
         )), 500
@@ -1827,15 +1834,15 @@ def get_qa_feedback_manager():
                 elif 'vn' in globals():
                     vanna_instance = vn
                 else:
-                    print("[INFO] 未找到可用的vanna实例,将创建新的数据库连接")
+                    logger.info("未找到可用的vanna实例,将创建新的数据库连接")
             except Exception as e:
-                print(f"[INFO] 获取vanna实例失败: {e},将创建新的数据库连接")
+                logger.info(f"获取vanna实例失败: {e},将创建新的数据库连接")
                 vanna_instance = None
             
             qa_feedback_manager = QAFeedbackManager(vanna_instance=vanna_instance)
-            print("[CITU_APP] QA反馈管理器实例创建成功")
+            logger.info("QA反馈管理器实例创建成功")
         except Exception as e:
-            print(f"[CRITICAL] QA反馈管理器创建失败: {str(e)}")
+            logger.critical(f"QA反馈管理器创建失败: {str(e)}")
             raise Exception(f"QA反馈管理器初始化失败: {str(e)}")
     return qa_feedback_manager
 
@@ -1904,7 +1911,7 @@ def qa_feedback_query():
         ))
         
     except Exception as e:
-        print(f"[ERROR] qa_feedback_query执行失败: {str(e)}")
+        logger.error(f"qa_feedback_query执行失败: {str(e)}")
         return jsonify(internal_error_response(
             response_text="查询反馈记录失败,请稍后重试"
         )), 500
@@ -1929,7 +1936,7 @@ def qa_feedback_delete(feedback_id):
             )), 404
             
     except Exception as e:
-        print(f"[ERROR] qa_feedback_delete执行失败: {str(e)}")
+        logger.error(f"qa_feedback_delete执行失败: {str(e)}")
         return jsonify(internal_error_response(
             response_text="删除反馈记录失败,请稍后重试"
         )), 500
@@ -1973,7 +1980,7 @@ def qa_feedback_update(feedback_id):
             )), 404
             
     except Exception as e:
-        print(f"[ERROR] qa_feedback_update执行失败: {str(e)}")
+        logger.error(f"qa_feedback_update执行失败: {str(e)}")
         return jsonify(internal_error_response(
             response_text="更新反馈记录失败,请稍后重试"
         )), 500
@@ -2026,7 +2033,7 @@ def qa_feedback_add_to_training():
                         sql=record['sql']
                     )
                     positive_count += 1
-                    print(f"[TRAINING] 正向训练成功 - ID: {record['id']}, TrainingID: {training_id}")
+                    logger.info(f"正向训练成功 - ID: {record['id']}, TrainingID: {training_id}")
                 else:
                     # 负向反馈 - 加入错误SQL训练集
                     training_id = vn.train_error_sql(
@@ -2034,18 +2041,18 @@ def qa_feedback_add_to_training():
                         sql=record['sql']
                     )
                     negative_count += 1
-                    print(f"[TRAINING] 负向训练成功 - ID: {record['id']}, TrainingID: {training_id}")
+                    logger.info(f"负向训练成功 - ID: {record['id']}, TrainingID: {training_id}")
                 
                 successfully_trained_ids.append(record['id'])
                 
             except Exception as e:
-                print(f"[ERROR] 训练失败 - 反馈ID: {record['id']}, 错误: {e}")
+                logger.error(f"训练失败 - 反馈ID: {record['id']}, 错误: {e}")
                 error_count += 1
         
         # 更新训练状态
         if successfully_trained_ids:
             updated_count = manager.mark_training_status(successfully_trained_ids, True)
-            print(f"[TRAINING] 批量更新训练状态完成,影响 {updated_count} 条记录")
+            logger.info(f"批量更新训练状态完成,影响 {updated_count} 条记录")
         
         # 构建响应
         total_processed = positive_count + negative_count + already_trained_count + error_count
@@ -2070,7 +2077,7 @@ def qa_feedback_add_to_training():
         ))
         
     except Exception as e:
-        print(f"[ERROR] qa_feedback_add_to_training执行失败: {str(e)}")
+        logger.error(f"qa_feedback_add_to_training执行失败: {str(e)}")
         return jsonify(internal_error_response(
             response_text="添加训练数据失败,请稍后重试"
         )), 500
@@ -2123,7 +2130,7 @@ def qa_feedback_add():
         ))
         
     except Exception as e:
-        print(f"[ERROR] qa_feedback_add执行失败: {str(e)}")
+        logger.error(f"qa_feedback_add执行失败: {str(e)}")
         return jsonify(internal_error_response(
             response_text="创建反馈记录失败,请稍后重试"
         )), 500
@@ -2158,7 +2165,7 @@ def qa_feedback_stats():
         ))
         
     except Exception as e:
-        print(f"[ERROR] qa_feedback_stats执行失败: {str(e)}")
+        logger.error(f"qa_feedback_stats执行失败: {str(e)}")
         return jsonify(internal_error_response(
             response_text="获取统计信息失败,请稍后重试"
         )), 500
@@ -2178,7 +2185,7 @@ def qa_cache_stats():
         ))
         
     except Exception as e:
-        print(f"[ERROR] 获取问答缓存统计失败: {str(e)}")
+        logger.error(f"获取问答缓存统计失败: {str(e)}")
         return jsonify(internal_error_response(
             response_text="获取问答缓存统计失败,请稍后重试"
         )), 500
@@ -2209,7 +2216,7 @@ def qa_cache_list():
         ))
         
     except Exception as e:
-        print(f"[ERROR] 获取问答缓存列表失败: {str(e)}")
+        logger.error(f"获取问答缓存列表失败: {str(e)}")
         return jsonify(internal_error_response(
             response_text="获取问答缓存列表失败,请稍后重试"
         )), 500
@@ -2235,7 +2242,7 @@ def qa_cache_cleanup():
         ))
         
     except Exception as e:
-        print(f"[ERROR] 清空问答缓存失败: {str(e)}")
+        logger.error(f"清空问答缓存失败: {str(e)}")
         return jsonify(internal_error_response(
             response_text="清空问答缓存失败,请稍后重试"
         )), 500
@@ -2367,7 +2374,7 @@ def get_total_training_count():
             return len(training_data)
         return 0
     except Exception as e:
-        print(f"[WARNING] 获取训练数据总数失败: {e}")
+        logger.warning(f"获取训练数据总数失败: {e}")
         return 0
 
 @app.flask_app.route('/api/v0/training_data/query', methods=['POST'])
@@ -2460,7 +2467,7 @@ def training_data_query():
         ))
         
     except Exception as e:
-        print(f"[ERROR] training_data_query执行失败: {str(e)}")
+        logger.error(f"training_data_query执行失败: {str(e)}")
         return jsonify(internal_error_response(
             response_text="查询训练数据失败,请稍后重试"
         )), 500
@@ -2533,7 +2540,7 @@ def training_data_create():
         ))
         
     except Exception as e:
-        print(f"[ERROR] training_data_create执行失败: {str(e)}")
+        logger.error(f"training_data_create执行失败: {str(e)}")
         return jsonify(internal_error_response(
             response_text="创建训练数据失败,请稍后重试"
         )), 500
@@ -2605,7 +2612,7 @@ def training_data_delete():
         ))
         
     except Exception as e:
-        print(f"[ERROR] training_data_delete执行失败: {str(e)}")
+        logger.error(f"training_data_delete执行失败: {str(e)}")
         return jsonify(internal_error_response(
             response_text="删除训练数据失败,请稍后重试"
         )), 500
@@ -2666,7 +2673,7 @@ def training_data_stats():
         ))
         
     except Exception as e:
-        print(f"[ERROR] training_data_stats执行失败: {str(e)}")
+        logger.error(f"training_data_stats执行失败: {str(e)}")
         return jsonify(internal_error_response(
             response_text="获取统计信息失败,请稍后重试"
         )), 500
@@ -2702,7 +2709,7 @@ def cache_overview_full():
         ))
         
     except Exception as e:
-        print(f"[ERROR] 获取综合缓存概览失败: {str(e)}")
+        logger.error(f"获取综合缓存概览失败: {str(e)}")
         return jsonify(internal_error_response(
             response_text="获取缓存概览失败,请稍后重试"
         )), 500
@@ -2748,5 +2755,5 @@ const chatSession = new ChatSession();
 chatSession.askQuestion("各年龄段客户的流失率如何?");
 """
 
-print("正在启动Flask应用: http://localhost:8084")
+logger.info("正在启动Flask应用: http://localhost:8084")
 app.run(host="0.0.0.0", port=8084, debug=True)

+ 15 - 13
common/embedding_cache_manager.py

@@ -5,6 +5,7 @@ import time
 from typing import List, Optional, Dict, Any
 from datetime import datetime
 import app_config
+from core.logging import get_app_logger
 
 
 class EmbeddingCacheManager:
@@ -12,6 +13,7 @@ class EmbeddingCacheManager:
     
     def __init__(self):
         """初始化缓存管理器"""
+        self.logger = get_app_logger("EmbeddingCacheManager")
         self.redis_client = None
         self.cache_enabled = app_config.ENABLE_EMBEDDING_CACHE
         
@@ -28,9 +30,9 @@ class EmbeddingCacheManager:
                 )
                 # 测试连接
                 self.redis_client.ping()
-                print(f"[DEBUG] Embedding缓存管理器初始化成功")
+                self.logger.debug("Embedding缓存管理器初始化成功")
             except Exception as e:
-                print(f"[WARNING] Redis连接失败,embedding缓存将被禁用: {e}")
+                self.logger.warning(f"Redis连接失败,embedding缓存将被禁用: {e}")
                 self.cache_enabled = False
                 self.redis_client = None
     
@@ -72,7 +74,7 @@ class EmbeddingCacheManager:
                 'embedding_dimension': str(embedding_config.get('embedding_dimension', 'unknown'))
             }
         except Exception as e:
-            print(f"[WARNING] 获取模型信息失败: {e}")
+            self.logger.warning(f"获取模型信息失败: {e}")
             return {'model_name': 'unknown', 'embedding_dimension': 'unknown'}
     
     def get_cached_embedding(self, question: str) -> Optional[List[float]]:
@@ -97,13 +99,13 @@ class EmbeddingCacheManager:
                 data = json.loads(cached_data)
                 vector = data.get('vector')
                 if vector:
-                    print(f"[DEBUG] ✓ Embedding缓存命中: {question[:50]}...")
+                    self.logger.debug(f"✓ Embedding缓存命中: {question[:50]}...")
                     return vector
             
             return None
             
         except Exception as e:
-            print(f"[WARNING] 获取embedding缓存失败: {e}")
+            self.logger.warning(f"获取embedding缓存失败: {e}")
             return None
     
     def cache_embedding(self, question: str, vector: List[float]) -> bool:
@@ -141,7 +143,7 @@ class EmbeddingCacheManager:
                 json.dumps(cache_data, ensure_ascii=False)
             )
             
-            print(f"[DEBUG] ✓ Embedding向量已缓存: {question[:50]}... (维度: {len(vector)})")
+            self.logger.debug(f"✓ Embedding向量已缓存: {question[:50]}... (维度: {len(vector)})")
             
             # 检查缓存大小并清理
             self._cleanup_if_needed()
@@ -149,7 +151,7 @@ class EmbeddingCacheManager:
             return True
             
         except Exception as e:
-            print(f"[WARNING] 缓存embedding失败: {e}")
+            self.logger.warning(f"缓存embedding失败: {e}")
             return False
     
     def _cleanup_if_needed(self):
@@ -180,10 +182,10 @@ class EmbeddingCacheManager:
                 
                 if keys_to_delete:
                     self.redis_client.delete(*keys_to_delete)
-                    print(f"[DEBUG] 清理了 {len(keys_to_delete)} 个旧的embedding缓存")
+                    self.logger.debug(f"清理了 {len(keys_to_delete)} 个旧的embedding缓存")
                     
         except Exception as e:
-            print(f"[WARNING] 清理embedding缓存失败: {e}")
+            self.logger.warning(f"清理embedding缓存失败: {e}")
     
     def get_cache_stats(self) -> Dict[str, Any]:
         """
@@ -217,7 +219,7 @@ class EmbeddingCacheManager:
                     stats["memory_usage_mb"] = round(total_size_bytes / (1024 * 1024), 2)
             
         except Exception as e:
-            print(f"[WARNING] 获取缓存统计失败: {e}")
+            self.logger.warning(f"获取缓存统计失败: {e}")
         
         return stats
     
@@ -237,14 +239,14 @@ class EmbeddingCacheManager:
             
             if keys:
                 self.redis_client.delete(*keys)
-                print(f"[DEBUG] 已清空所有embedding缓存 ({len(keys)} 条)")
+                self.logger.debug(f"已清空所有embedding缓存 ({len(keys)} 条)")
                 return True
             else:
-                print(f"[DEBUG] 没有embedding缓存需要清空")
+                self.logger.debug("没有embedding缓存需要清空")
                 return True
                 
         except Exception as e:
-            print(f"[WARNING] 清空embedding缓存失败: {e}")
+            self.logger.warning(f"清空embedding缓存失败: {e}")
             return False
 
 

+ 22 - 18
common/qa_feedback_manager.py

@@ -8,6 +8,7 @@ from sqlalchemy.exc import OperationalError, ProgrammingError
 from datetime import datetime
 from typing import List, Dict, Any, Optional, Tuple
 import logging
+from core.logging import get_app_logger
 
 class QAFeedbackManager:
     """QA反馈数据管理器 - 复用Vanna连接版本"""
@@ -18,6 +19,9 @@ class QAFeedbackManager:
         Args:
             vanna_instance: 可选的vanna实例,用于复用其数据库连接
         """
+        # 初始化日志
+        self.logger = get_app_logger("QAFeedbackManager")
+        
         self.engine = None
         self.vanna_instance = vanna_instance
         self._init_database_connection()
@@ -29,7 +33,7 @@ class QAFeedbackManager:
             # 方案1: 优先尝试复用vanna连接
             if self.vanna_instance and hasattr(self.vanna_instance, 'engine'):
                 self.engine = self.vanna_instance.engine
-                print(f"[QAFeedbackManager] 复用Vanna数据库连接")
+                self.logger.info("复用Vanna数据库连接")
                 return
             
             # 方案2: 创建新的连接(原有方式)
@@ -52,10 +56,10 @@ class QAFeedbackManager:
             with self.engine.connect() as conn:
                 conn.execute(text("SELECT 1"))
             
-            print(f"[QAFeedbackManager] 数据库连接成功: {db_config['host']}:{db_config['port']}/{db_config['dbname']}")
+            self.logger.info(f"数据库连接成功: {db_config['host']}:{db_config['port']}/{db_config['dbname']}")
             
         except Exception as e:
-            print(f"[ERROR] QAFeedbackManager数据库连接失败: {e}")
+            self.logger.error(f"QAFeedbackManager数据库连接失败: {e}")
             raise
     
     def _ensure_table_exists(self):
@@ -91,10 +95,10 @@ class QAFeedbackManager:
                     for index_sql in create_indexes_sql:
                         conn.execute(text(index_sql))
                     
-            print("[QAFeedbackManager] qa_feedback表检查/创建成功")
+            self.logger.info("qa_feedback表检查/创建成功")
             
         except Exception as e:
-            print(f"[ERROR] qa_feedback表创建失败: {e}")
+            self.logger.error(f"qa_feedback表创建失败: {e}")
             raise
     
     def add_feedback(self, question: str, sql: str, is_thumb_up: bool, user_id: str = "guest") -> int:
@@ -127,11 +131,11 @@ class QAFeedbackManager:
                     })
                     feedback_id = result.fetchone()[0]
                 
-            print(f"[QAFeedbackManager] 反馈记录创建成功, ID: {feedback_id}")
+            self.logger.info(f"反馈记录创建成功, ID: {feedback_id}")
             return feedback_id
             
         except Exception as e:
-            print(f"[ERROR] 添加反馈记录失败: {e}")
+            self.logger.error(f"添加反馈记录失败: {e}")
             raise
     
     def query_feedback(self, page: int = 1, page_size: int = 20, 
@@ -232,7 +236,7 @@ class QAFeedbackManager:
             return records, total
             
         except Exception as e:
-            print(f"[ERROR] 查询反馈记录失败: {e}")
+            self.logger.error(f"查询反馈记录失败: {e}")
             raise
     
     def delete_feedback(self, feedback_id: int) -> bool:
@@ -252,14 +256,14 @@ class QAFeedbackManager:
                     result = conn.execute(text(delete_sql), {'id': feedback_id})
                 
                 if result.rowcount > 0:
-                    print(f"[QAFeedbackManager] 反馈记录删除成功, ID: {feedback_id}")
+                    self.logger.info(f"反馈记录删除成功, ID: {feedback_id}")
                     return True
                 else:
-                    print(f"[WARNING] 反馈记录不存在, ID: {feedback_id}")
+                    self.logger.warning(f"反馈记录不存在, ID: {feedback_id}")
                     return False
                     
         except Exception as e:
-            print(f"[ERROR] 删除反馈记录失败: {e}")
+            self.logger.error(f"删除反馈记录失败: {e}")
             raise
     
     def update_feedback(self, feedback_id: int, **kwargs) -> bool:
@@ -284,7 +288,7 @@ class QAFeedbackManager:
                 params[field] = value
         
         if not update_fields:
-            print("[WARNING] 没有有效的更新字段")
+            self.logger.warning("没有有效的更新字段")
             return False
         
         update_fields.append("update_time = :update_time")
@@ -301,14 +305,14 @@ class QAFeedbackManager:
                     result = conn.execute(text(update_sql), params)
                 
                 if result.rowcount > 0:
-                    print(f"[QAFeedbackManager] 反馈记录更新成功, ID: {feedback_id}")
+                    self.logger.info(f"反馈记录更新成功, ID: {feedback_id}")
                     return True
                 else:
-                    print(f"[WARNING] 反馈记录不存在或无变化, ID: {feedback_id}")
+                    self.logger.warning(f"反馈记录不存在或无变化, ID: {feedback_id}")
                     return False
                     
         except Exception as e:
-            print(f"[ERROR] 更新反馈记录失败: {e}")
+            self.logger.error(f"更新反馈记录失败: {e}")
             raise
     
     def get_feedback_by_ids(self, feedback_ids: List[int]) -> List[Dict]:
@@ -354,7 +358,7 @@ class QAFeedbackManager:
                 return records
                 
         except Exception as e:
-            print(f"[ERROR] 根据ID查询反馈记录失败: {e}")
+            self.logger.error(f"根据ID查询反馈记录失败: {e}")
             raise
     
     def mark_training_status(self, feedback_ids: List[int], status: bool = True) -> int:
@@ -386,9 +390,9 @@ class QAFeedbackManager:
                 with conn.begin():
                     result = conn.execute(text(update_sql), params)
                 
-                print(f"[QAFeedbackManager] 批量更新训练状态成功, 影响行数: {result.rowcount}")
+                self.logger.info(f"批量更新训练状态成功, 影响行数: {result.rowcount}")
                 return result.rowcount
                 
         except Exception as e:
-            print(f"[ERROR] 批量更新训练状态失败: {e}")
+            self.logger.error(f"批量更新训练状态失败: {e}")
             raise

+ 37 - 35
common/redis_conversation_manager.py

@@ -12,12 +12,14 @@ from app_config import (
     ENABLE_CONVERSATION_CONTEXT, ENABLE_QUESTION_ANSWER_CACHE,
     DEFAULT_ANONYMOUS_USER
 )
+from core.logging import get_app_logger
 
 class RedisConversationManager:
     """Redis对话管理器 - 修正版"""
     
     def __init__(self):
         """初始化Redis连接"""
+        self.logger = get_app_logger("RedisConversationManager")
         try:
             self.redis_client = redis.Redis(
                 host=REDIS_HOST,
@@ -30,9 +32,9 @@ class RedisConversationManager:
             )
             # 测试连接
             self.redis_client.ping()
-            print(f"[REDIS_CONV] Redis连接成功: {REDIS_HOST}:{REDIS_PORT}")
+            self.logger.info(f"Redis连接成功: {REDIS_HOST}:{REDIS_PORT}")
         except Exception as e:
-            print(f"[ERROR] Redis连接失败: {str(e)}")
+            self.logger.error(f"Redis连接失败: {str(e)}")
             self.redis_client = None
     
     def is_available(self) -> bool:
@@ -59,16 +61,16 @@ class RedisConversationManager:
         
         # 1. 优先使用登录用户ID
         if login_user_id:
-            print(f"[REDIS_CONV] 使用登录用户ID: {login_user_id}")
+            self.logger.debug(f"使用登录用户ID: {login_user_id}")
             return login_user_id
         
         # 2. 如果没有登录,尝试从请求参数获取user_id
         if user_id_from_request:
-            print(f"[REDIS_CONV] 使用请求参数user_id: {user_id_from_request}")
+            self.logger.debug(f"使用请求参数user_id: {user_id_from_request}")
             return user_id_from_request
         
         # 3. 都没有则为匿名用户(统一为guest)
-        print(f"[REDIS_CONV] 使用匿名用户: {DEFAULT_ANONYMOUS_USER}")
+        self.logger.debug(f"使用匿名用户: {DEFAULT_ANONYMOUS_USER}")
         return DEFAULT_ANONYMOUS_USER
     
     def resolve_conversation_id(self, user_id: str, conversation_id_input: Optional[str], 
@@ -87,13 +89,13 @@ class RedisConversationManager:
         # 1. 如果指定了conversation_id,验证后使用
         if conversation_id_input:
             if self._is_valid_conversation(conversation_id_input, user_id):
-                print(f"[REDIS_CONV] 使用指定对话: {conversation_id_input}")
+                self.logger.debug(f"使用指定对话: {conversation_id_input}")
                 return conversation_id_input, {
                     "status": "existing",
                     "message": "继续已有对话"
                 }
             else:
-                print(f"[WARN] 无效的conversation_id: {conversation_id_input},创建新对话")
+                self.logger.warning(f"无效的conversation_id: {conversation_id_input},创建新对话")
                 new_conversation_id = self.create_conversation(user_id)
                 return new_conversation_id, {
                     "status": "invalid_id_new",
@@ -105,7 +107,7 @@ class RedisConversationManager:
         if continue_conversation:
             recent_conversation = self._get_recent_conversation(user_id)
             if recent_conversation:
-                print(f"[REDIS_CONV] 继续最近对话: {recent_conversation}")
+                self.logger.debug(f"继续最近对话: {recent_conversation}")
                 return recent_conversation, {
                     "status": "existing",
                     "message": "继续最近对话"
@@ -113,7 +115,7 @@ class RedisConversationManager:
         
         # 3. 创建新对话
         new_conversation_id = self.create_conversation(user_id)
-        print(f"[REDIS_CONV] 创建新对话: {new_conversation_id}")
+        self.logger.debug(f"创建新对话: {new_conversation_id}")
         return new_conversation_id, {
             "status": "new",
             "message": "创建新对话"
@@ -180,11 +182,11 @@ class RedisConversationManager:
             # 添加到用户的对话列表
             self._add_conversation_to_user(user_id, conversation_id)
             
-            print(f"[REDIS_CONV] 创建对话成功: {conversation_id}")
+            self.logger.info(f"创建对话成功: {conversation_id}")
             return conversation_id
             
         except Exception as e:
-            print(f"[ERROR] 创建对话失败: {str(e)}")
+            self.logger.error(f"创建对话失败: {str(e)}")
             return conversation_id  # 返回ID但可能未存储
     
     def save_message(self, conversation_id: str, role: str, content: str, 
@@ -223,7 +225,7 @@ class RedisConversationManager:
             return True
             
         except Exception as e:
-            print(f"[ERROR] 保存消息失败: {str(e)}")
+            self.logger.error(f"保存消息失败: {str(e)}")
             return False
     
     def get_context(self, conversation_id: str, count: Optional[int] = None) -> str:
@@ -262,11 +264,11 @@ class RedisConversationManager:
                     continue
             
             context = "\n".join(context_parts)
-            print(f"[REDIS_CONV] 获取上下文成功: {len(context_parts)}条消息")
+            self.logger.debug(f"获取上下文成功: {len(context_parts)}条消息")
             return context
             
         except Exception as e:
-            print(f"[ERROR] 获取上下文失败: {str(e)}")
+            self.logger.error(f"获取上下文失败: {str(e)}")
             return ""
         
     def get_context_for_display(self, conversation_id: str, count: Optional[int] = None) -> str:
@@ -307,11 +309,11 @@ class RedisConversationManager:
                     continue
             
             context = "\n".join(context_parts)
-            print(f"[REDIS_CONV] 获取显示上下文成功: {len(context_parts)}条消息")
+            self.logger.debug(f"获取显示上下文成功: {len(context_parts)}条消息")
             return context
             
         except Exception as e:
-            print(f"[ERROR] 获取显示上下文失败: {str(e)}")
+            self.logger.error(f"获取显示上下文失败: {str(e)}")
             return ""
     
     
@@ -341,7 +343,7 @@ class RedisConversationManager:
             return parsed_messages
             
         except Exception as e:
-            print(f"[ERROR] 获取对话消息失败: {str(e)}")
+            self.logger.error(f"获取对话消息失败: {str(e)}")
             return []
     
     def get_conversation_meta(self, conversation_id: str) -> Dict:
@@ -353,7 +355,7 @@ class RedisConversationManager:
             meta_data = self.redis_client.hgetall(f"conversation:{conversation_id}:meta")
             return meta_data if meta_data else {}
         except Exception as e:
-            print(f"[ERROR] 获取对话元信息失败: {str(e)}")
+            self.logger.error(f"获取对话元信息失败: {str(e)}")
             return {}
     
     def get_conversations(self, user_id: str, limit: int = None) -> List[Dict]:
@@ -379,7 +381,7 @@ class RedisConversationManager:
             return conversations
             
         except Exception as e:
-            print(f"[ERROR] 获取用户对话列表失败: {str(e)}")
+            self.logger.error(f"获取用户对话列表失败: {str(e)}")
             return []
     
     # ==================== 智能缓存(修正版)====================
@@ -396,13 +398,13 @@ class RedisConversationManager:
             
             if cached_answer:
                 context_info = "有上下文" if context else "无上下文"
-                print(f"[REDIS_CONV] 缓存命中: {cache_key} ({context_info})")
+                self.logger.debug(f"缓存命中: {cache_key} ({context_info})")
                 return json.loads(cached_answer)
             
             return None
             
         except Exception as e:
-            print(f"[ERROR] 获取缓存答案失败: {str(e)}")
+            self.logger.error(f"获取缓存答案失败: {str(e)}")
             return None
     
     def cache_answer(self, question: str, answer: Dict, context: str = ""):
@@ -412,7 +414,7 @@ class RedisConversationManager:
         
         # 新增:如果有上下文,不缓存
         if context:
-            print(f"[REDIS_CONV] 跳过缓存存储:存在上下文")
+            self.logger.debug("跳过缓存存储:存在上下文")
             return
         
         try:
@@ -432,10 +434,10 @@ class RedisConversationManager:
                 json.dumps(answer_with_meta)
             )
             
-            print(f"[REDIS_CONV] 缓存答案成功: {cache_key}")
+            self.logger.debug(f"缓存答案成功: {cache_key}")
             
         except Exception as e:
-            print(f"[ERROR] 缓存答案失败: {str(e)}")
+            self.logger.error(f"缓存答案失败: {str(e)}")
     
     def _get_cache_key(self, question: str) -> str:
         """生成缓存键 - 简化版,只基于问题本身"""
@@ -464,7 +466,7 @@ class RedisConversationManager:
             )
             
         except Exception as e:
-            print(f"[ERROR] 添加对话到用户列表失败: {str(e)}")
+            self.logger.error(f"添加对话到用户列表失败: {str(e)}")
     
     def _update_conversation_meta(self, conversation_id: str):
         """更新对话元信息"""
@@ -482,7 +484,7 @@ class RedisConversationManager:
             )
             
         except Exception as e:
-            print(f"[ERROR] 更新对话元信息失败: {str(e)}")
+            self.logger.error(f"更新对话元信息失败: {str(e)}")
     
     # ==================== 管理方法 ====================
     
@@ -510,7 +512,7 @@ class RedisConversationManager:
             return stats
             
         except Exception as e:
-            print(f"[ERROR] 获取统计信息失败: {str(e)}")
+            self.logger.error(f"获取统计信息失败: {str(e)}")
             return {"available": False, "error": str(e)}
     
     def cleanup_expired_conversations(self):
@@ -542,10 +544,10 @@ class RedisConversationManager:
                         # 重新设置TTL
                         self.redis_client.expire(user_key, USER_CONVERSATIONS_TTL)
             
-            print(f"[REDIS_CONV] 清理完成,移除了 {cleaned_count} 个无效对话引用")
+            self.logger.info(f"清理完成,移除了 {cleaned_count} 个无效对话引用")
             
         except Exception as e:
-            print(f"[ERROR] 清理失败: {str(e)}")
+            self.logger.error(f"清理失败: {str(e)}")
     
     # ==================== 问答缓存管理方法 ====================
     
@@ -579,7 +581,7 @@ class RedisConversationManager:
             return stats
             
         except Exception as e:
-            print(f"[ERROR] 获取问答缓存统计失败: {str(e)}")
+            self.logger.error(f"获取问答缓存统计失败: {str(e)}")
             return {"available": False, "error": str(e)}
     
     def get_qa_cache_list(self, limit: int = 50) -> List[Dict]:
@@ -621,7 +623,7 @@ class RedisConversationManager:
                     # 跳过无效的JSON数据
                     continue
                 except Exception as e:
-                    print(f"[WARNING] 处理缓存项 {key} 失败: {e}")
+                    self.logger.warning(f"处理缓存项 {key} 失败: {e}")
                     continue
             
             # 按缓存时间倒序排列
@@ -630,7 +632,7 @@ class RedisConversationManager:
             return cache_list
             
         except Exception as e:
-            print(f"[ERROR] 获取问答缓存列表失败: {str(e)}")
+            self.logger.error(f"获取问答缓存列表失败: {str(e)}")
             return []
     
     def clear_all_qa_cache(self) -> int:
@@ -644,12 +646,12 @@ class RedisConversationManager:
             
             if keys:
                 deleted_count = self.redis_client.delete(*keys)
-                print(f"[REDIS_CONV] 清空问答缓存成功,删除了 {deleted_count} 个缓存项")
+                self.logger.info(f"清空问答缓存成功,删除了 {deleted_count} 个缓存项")
                 return deleted_count
             else:
-                print(f"[REDIS_CONV] 没有找到问答缓存项")
+                self.logger.info("没有找到问答缓存项")
                 return 0
                 
         except Exception as e:
-            print(f"[ERROR] 清空问答缓存失败: {str(e)}")
+            self.logger.error(f"清空问答缓存失败: {str(e)}")
             return 0 

+ 12 - 8
common/utils.py

@@ -2,6 +2,10 @@
 配置相关的工具函数
 用于处理不同模型类型的配置选择逻辑
 """
+from core.logging import get_app_logger
+
+# 初始化logger
+_logger = get_app_logger("ConfigUtils")
 
 def get_current_embedding_config():
     """
@@ -180,12 +184,12 @@ def print_current_config():
     """
     try:
         model_info = get_current_model_info()
-        print("=== 当前模型配置 ===")
-        print(f"LLM提供商: {model_info['llm_type']}")
-        print(f"LLM模型: {model_info['llm_model']}")
-        print(f"Embedding提供商: {model_info['embedding_type']}")
-        print(f"Embedding模型: {model_info['embedding_model']}")
-        print(f"向量数据库: {model_info['vector_db']}")
-        print("==================")
+        _logger.info("=== 当前模型配置 ===")
+        _logger.info(f"LLM提供商: {model_info['llm_type']}")
+        _logger.info(f"LLM模型: {model_info['llm_model']}")
+        _logger.info(f"Embedding提供商: {model_info['embedding_type']}")
+        _logger.info(f"Embedding模型: {model_info['embedding_model']}")
+        _logger.info(f"向量数据库: {model_info['vector_db']}")
+        _logger.info("==================")
     except Exception as e:
-        print(f"无法获取配置信息: {e}") 
+        _logger.error(f"无法获取配置信息: {e}") 

+ 11 - 7
common/vanna_combinations.py

@@ -2,13 +2,17 @@
 Vanna LLM与向量数据库的组合类
 统一管理所有LLM提供商与向量数据库的组合
 """
+from core.logging import get_app_logger
+
+# 初始化logger
+_logger = get_app_logger("VannaCombinations")
 
 # 向量数据库导入
 from vanna.chromadb import ChromaDB_VectorStore
 try:
     from custompgvector import PG_VectorStore
 except ImportError:
-    print("警告: 无法导入 PG_VectorStore,PGVector相关组合类将不可用")
+    _logger.warning("无法导入 PG_VectorStore,PGVector相关组合类将不可用")
     PG_VectorStore = None
 
 # LLM提供商导入 - 使用新的重构后的实现
@@ -17,7 +21,7 @@ from customllm.deepseek_chat import DeepSeekChat
 try:
     from customllm.ollama_chat import OllamaChat
 except ImportError:
-    print("警告: 无法导入 OllamaChat,Ollama相关组合类将不可用")
+    _logger.warning("无法导入 OllamaChat,Ollama相关组合类将不可用")
     OllamaChat = None
 
 
@@ -168,19 +172,19 @@ def list_available_combinations():
 
 def print_available_combinations():
     """打印所有可用的组合"""
-    print("可用的LLM与向量数据库组合:")
-    print("=" * 40)
+    _logger.info("可用的LLM与向量数据库组合:")
+    _logger.info("=" * 40)
     
     combinations = list_available_combinations()
     
     for llm_type, vector_dbs in combinations.items():
-        print(f"\n{llm_type.upper()} LLM:")
+        _logger.info(f"\n{llm_type.upper()} LLM:")
         for vector_db in vector_dbs:
             class_name = LLM_CLASS_MAP[llm_type][vector_db].__name__
-            print(f"  + {vector_db} -> {class_name}")
+            _logger.info(f"  + {vector_db} -> {class_name}")
     
     if not any(combinations.values()):
-        print("没有可用的组合,请检查依赖是否正确安装")
+        _logger.warning("没有可用的组合,请检查依赖是否正确安装")
 
 
 # ===== 向后兼容性支持 =====

+ 8 - 4
common/vanna_instance.py

@@ -4,6 +4,10 @@ Vanna实例单例管理器
 """
 import threading
 from typing import Optional
+from core.logging import get_app_logger
+
+# 初始化logger
+_logger = get_app_logger("VannaSingleton")
 
 # 全局变量
 _vanna_instance: Optional[object] = None
@@ -22,14 +26,14 @@ def get_vanna_instance():
     if _vanna_instance is None:
         with _instance_lock:
             if _vanna_instance is None:
-                print("[VANNA_SINGLETON] 创建 Vanna 实例...")
+                _logger.info("创建 Vanna 实例...")
                 try:
                     # 延迟导入,避免循环导入
                     from core.vanna_llm_factory import create_vanna_instance
                     _vanna_instance = create_vanna_instance()
-                    print("[VANNA_SINGLETON] Vanna 实例创建成功")
+                    _logger.info("Vanna 实例创建成功")
                 except Exception as e:
-                    print(f"[ERROR] Vanna 实例创建失败: {str(e)}")
+                    _logger.error(f"Vanna 实例创建失败: {str(e)}")
                     raise
     
     return _vanna_instance
@@ -41,7 +45,7 @@ def reset_vanna_instance():
     global _vanna_instance
     with _instance_lock:
         if _vanna_instance is not None:
-            print("[VANNA_SINGLETON] 重置 Vanna 实例")
+            _logger.info("重置 Vanna 实例")
             _vanna_instance = None
 
 def get_instance_status() -> dict:

+ 17 - 12
core/embedding_function.py

@@ -2,6 +2,7 @@ import requests
 import time
 import numpy as np
 from typing import List, Callable
+from core.logging import get_vanna_logger
 
 class EmbeddingFunction:
     def __init__(self, model_name: str, api_key: str, base_url: str, embedding_dimension: int):
@@ -16,6 +17,9 @@ class EmbeddingFunction:
         self.max_retries = 3  # 设置默认的最大重试次数
         self.retry_interval = 2  # 设置默认的重试间隔秒数
         self.normalize_embeddings = True # 设置默认是否归一化
+        
+        # 初始化日志
+        self.logger = get_vanna_logger("EmbeddingFunction")
 
     def _normalize_vector(self, vector: List[float]) -> List[float]:
         """
@@ -54,7 +58,7 @@ class EmbeddingFunction:
                 vector = self.generate_embedding(text)
                 embeddings.append(vector)
             except Exception as e:
-                print(f"为文本 '{text}' 生成embedding失败: {e}")
+                self.logger.error(f"为文本 '{text}' 生成embedding失败: {e}")
                 # 重新抛出异常,不返回零向量
                 raise e
                 
@@ -135,7 +139,7 @@ class EmbeddingFunction:
                         retries += 1
                         if retries <= self.max_retries:
                             wait_time = self.retry_interval * (2 ** (retries - 1))  # 指数退避
-                            print(f"API请求失败,等待 {wait_time} 秒后重试 ({retries}/{self.max_retries})")
+                            self.logger.warning(f"API请求失败,等待 {wait_time} 秒后重试 ({retries}/{self.max_retries})")
                             time.sleep(wait_time)
                             continue
                     
@@ -155,14 +159,14 @@ class EmbeddingFunction:
                         # 验证向量维度
                         actual_dim = len(vector)
                         if actual_dim != self.embedding_dimension:
-                            print(f"警告: 向量维度不匹配: 期望 {self.embedding_dimension}, 实际 {actual_dim}")
+                            self.logger.warning(f"向量维度不匹配: 期望 {self.embedding_dimension}, 实际 {actual_dim}")
                     
                     # 如果需要归一化
                     if self.normalize_embeddings:
                         vector = self._normalize_vector(vector)
                     
                     # 添加成功生成embedding的debug日志
-                    print(f"[DEBUG] ✓ 成功生成embedding向量,维度: {len(vector)}")
+                    self.logger.debug(f"成功生成embedding向量,维度: {len(vector)}")
                     
                     return vector
                 else:
@@ -174,7 +178,7 @@ class EmbeddingFunction:
                 
                 if retries <= self.max_retries:
                     wait_time = self.retry_interval * (2 ** (retries - 1))  # 指数退避
-                    print(f"生成embedding时出错: {str(e)}, 等待 {wait_time} 秒后重试 ({retries}/{self.max_retries})")
+                    self.logger.warning(f"生成embedding时出错: {str(e)}, 等待 {wait_time} 秒后重试 ({retries}/{self.max_retries})")
                     time.sleep(wait_time)
                 else:
                     # 抛出异常而不是返回零向量,确保问题不被掩盖
@@ -203,8 +207,8 @@ class EmbeddingFunction:
         }
         
         try:
-            print(f"测试嵌入模型连接 - 模型: {self.model_name}")
-            print(f"API服务地址: {self.base_url}")
+            self.logger.info(f"测试嵌入模型连接 - 模型: {self.model_name}")
+            self.logger.info(f"API服务地址: {self.base_url}")
             
             # 验证配置
             if not self.api_key:
@@ -241,6 +245,7 @@ def test_embedding_connection() -> dict:
     Returns:
         dict: 测试结果,包括成功/失败状态、错误消息等
     """
+    logger = get_vanna_logger("EmbeddingTest")
     try:
         # 获取嵌入函数实例
         embedding_function = get_embedding_function()
@@ -249,18 +254,18 @@ def test_embedding_connection() -> dict:
         test_result = embedding_function.test_connection()
         
         if test_result["success"]:
-            print(f"嵌入模型连接测试成功!")
+            logger.info(f"嵌入模型连接测试成功!")
             if "警告" in test_result["message"]:
-                print(test_result["message"])
-                print(f"建议将app_config.py中的EMBEDDING_CONFIG['embedding_dimension']修改为{test_result['actual_dimension']}")
+                logger.warning(test_result["message"])
+                logger.warning(f"建议将app_config.py中的EMBEDDING_CONFIG['embedding_dimension']修改为{test_result['actual_dimension']}")
         else:
-            print(f"嵌入模型连接测试失败: {test_result['message']}")
+            logger.error(f"嵌入模型连接测试失败: {test_result['message']}")
             
         return test_result
         
     except Exception as e:
         error_message = f"无法测试嵌入模型连接: {str(e)}"
-        print(error_message)
+        logger.error(error_message)
         return {
             "success": False,
             "message": error_message

+ 11 - 7
core/vanna_llm_factory.py

@@ -4,6 +4,10 @@ Vanna LLM 工厂文件,支持多种LLM提供商和向量数据库
 import app_config, os
 from core.embedding_function import get_embedding_function
 from common.vanna_combinations import get_vanna_class, print_available_combinations
+from core.logging import get_vanna_logger
+
+# 初始化日志
+logger = get_vanna_logger("VannaFactory")
 
 def create_vanna_instance(config_module=None):
     """
@@ -48,11 +52,11 @@ def create_vanna_instance(config_module=None):
         vector_db_type = model_info["vector_db"].lower()
         
         cls = get_vanna_class(llm_type, vector_db_type)
-        print(f"创建{llm_type.upper()}+{vector_db_type.upper()}实例")
+        logger.info(f"创建{llm_type.upper()}+{vector_db_type.upper()}实例")
         
     except ValueError as e:
-        print(f"错误: {e}")
-        print("\n可用的组合:")
+        logger.error(f"{e}")
+        logger.info("可用的组合:")
         print_available_combinations()
         raise
     
@@ -62,24 +66,24 @@ def create_vanna_instance(config_module=None):
     # 配置向量数据库
     if model_info["vector_db"] == "chromadb":
         config["path"] = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))  # 返回项目根目录
-        print(f"已配置使用ChromaDB,路径:{config['path']}")
+        logger.info(f"已配置使用ChromaDB,路径:{config['path']}")
     elif model_info["vector_db"] == "pgvector":
         # 构建PostgreSQL连接字符串
         connection_string = f"postgresql://{vector_db_config['user']}:{vector_db_config['password']}@{vector_db_config['host']}:{vector_db_config['port']}/{vector_db_config['dbname']}"
         config["connection_string"] = connection_string
-        print(f"已配置使用PgVector,连接字符串: {connection_string}")
+        logger.info(f"已配置使用PgVector,连接字符串: {connection_string}")
     
     # 配置embedding函数
     embedding_function = get_embedding_function()
     config["embedding_function"] = embedding_function
-    print(f"已配置使用{model_info['embedding_type'].upper()}嵌入模型: {model_info['embedding_model']}")
+    logger.info(f"已配置使用{model_info['embedding_type'].upper()}嵌入模型: {model_info['embedding_model']}")
     
     # 创建实例
     vn = cls(config=config)
 
     # 连接到业务数据库
     vn.connect_to_postgres(**config_module.APP_DB_CONFIG)           
-    print(f"已连接到业务数据库: "
+    logger.info(f"已连接到业务数据库: "
           f"{config_module.APP_DB_CONFIG['host']}:"
           f"{config_module.APP_DB_CONFIG['port']}/"
           f"{config_module.APP_DB_CONFIG['dbname']}")

+ 17 - 13
customembedding/ollama_embedding.py

@@ -2,6 +2,7 @@ import requests
 import time
 import numpy as np
 from typing import List, Callable
+from core.logging import get_vanna_logger
 
 class OllamaEmbeddingFunction:
     def __init__(self, model_name: str, base_url: str, embedding_dimension: int):
@@ -10,6 +11,9 @@ class OllamaEmbeddingFunction:
         self.embedding_dimension = embedding_dimension
         self.max_retries = 3
         self.retry_interval = 2
+        
+        # 初始化日志
+        self.logger = get_vanna_logger("OllamaEmbedding")
 
     def __call__(self, input) -> List[List[float]]:
         """为文本列表生成嵌入向量"""
@@ -22,7 +26,7 @@ class OllamaEmbeddingFunction:
                 embedding = self.generate_embedding(text)
                 embeddings.append(embedding)
             except Exception as e:
-                print(f"获取embedding时出错: {e}")
+                self.logger.error(f"获取embedding时出错: {e}")
                 embeddings.append([0.0] * self.embedding_dimension)
                 
         return embeddings
@@ -37,10 +41,10 @@ class OllamaEmbeddingFunction:
     
     def generate_embedding(self, text: str) -> List[float]:
         """为单个文本生成嵌入向量"""
-        print(f"生成Ollama嵌入向量,文本长度: {len(text)} 字符")
+        self.logger.debug(f"生成Ollama嵌入向量,文本长度: {len(text)} 字符")
         
         if not text or len(text.strip()) == 0:
-            print("输入文本为空,返回零向量")
+            self.logger.debug("输入文本为空,返回零向量")
             return [0.0] * self.embedding_dimension
 
         url = f"{self.base_url}/api/embeddings"
@@ -60,13 +64,13 @@ class OllamaEmbeddingFunction:
                 
                 if response.status_code != 200:
                     error_msg = f"Ollama API请求错误: {response.status_code}, {response.text}"
-                    print(error_msg)
+                    self.logger.error(error_msg)
                     
                     if response.status_code in (429, 500, 502, 503, 504):
                         retries += 1
                         if retries <= self.max_retries:
                             wait_time = self.retry_interval * (2 ** (retries - 1))
-                            print(f"等待 {wait_time} 秒后重试 ({retries}/{self.max_retries})")
+                            self.logger.info(f"等待 {wait_time} 秒后重试 ({retries}/{self.max_retries})")
                             time.sleep(wait_time)
                             continue
                     
@@ -80,7 +84,7 @@ class OllamaEmbeddingFunction:
                     # 验证向量维度
                     actual_dim = len(vector)
                     if actual_dim != self.embedding_dimension:
-                        print(f"向量维度不匹配: 期望 {self.embedding_dimension}, 实际 {actual_dim}")
+                        self.logger.debug(f"向量维度不匹配: 期望 {self.embedding_dimension}, 实际 {actual_dim}")
                         # 如果维度不匹配,可以选择截断或填充
                         if actual_dim > self.embedding_dimension:
                             vector = vector[:self.embedding_dimension]
@@ -88,23 +92,23 @@ class OllamaEmbeddingFunction:
                             vector.extend([0.0] * (self.embedding_dimension - actual_dim))
                     
                     # 添加成功生成embedding的debug日志
-                    print(f"[DEBUG] ✓ 成功生成Ollama embedding向量,维度: {len(vector)}")
+                    self.logger.debug(f"✓ 成功生成Ollama embedding向量,维度: {len(vector)}")
                     return vector
                 else:
                     error_msg = f"Ollama API返回格式异常: {result}"
-                    print(error_msg)
+                    self.logger.error(error_msg)
                     raise ValueError(error_msg)
                 
             except Exception as e:
-                print(f"生成Ollama embedding时出错: {str(e)}")
+                self.logger.error(f"生成Ollama embedding时出错: {str(e)}")
                 retries += 1
                 
                 if retries <= self.max_retries:
                     wait_time = self.retry_interval * (2 ** (retries - 1))
-                    print(f"等待 {wait_time} 秒后重试 ({retries}/{self.max_retries})")
+                    self.logger.info(f"等待 {wait_time} 秒后重试 ({retries}/{self.max_retries})")
                     time.sleep(wait_time)
                 else:
-                    print(f"已达到最大重试次数 ({self.max_retries}),生成embedding失败")
+                    self.logger.error(f"已达到最大重试次数 ({self.max_retries}),生成embedding失败")
                     return [0.0] * self.embedding_dimension
         
         raise RuntimeError("生成Ollama embedding失败")
@@ -121,8 +125,8 @@ class OllamaEmbeddingFunction:
         }
         
         try:
-            print(f"测试Ollama嵌入模型连接 - 模型: {self.model_name}")
-            print(f"Ollama服务地址: {self.base_url}")
+            self.logger.info(f"测试Ollama嵌入模型连接 - 模型: {self.model_name}")
+            self.logger.info(f"Ollama服务地址: {self.base_url}")
             
             vector = self.generate_embedding(test_text)
             actual_dimension = len(vector)

+ 72 - 65
customllm/base_llm_chat.py

@@ -4,6 +4,7 @@ from typing import List, Dict, Any, Optional, Union, Tuple
 import pandas as pd
 import plotly.graph_objs
 from vanna.base import VannaBase
+from core.logging import get_vanna_logger
 # 导入配置参数
 from app_config import REWRITE_QUESTION_ENABLED, DISPLAY_RESULT_THINKING
 
@@ -14,18 +15,21 @@ class BaseLLMChat(VannaBase, ABC):
     def __init__(self, config=None):
         VannaBase.__init__(self, config=config)
 
+        # 初始化日志
+        self.logger = get_vanna_logger("BaseLLMChat")
+
         # 存储LLM解释性文本
         self.last_llm_explanation = None
         
-        print("传入的 config 参数如下:")
+        self.logger.info("传入的 config 参数如下:")
         for key, value in self.config.items():
-            print(f"  {key}: {value}")
+            self.logger.info(f"  {key}: {value}")
         
         # 默认参数
         self.temperature = 0.7
         
         if "temperature" in config:
-            print(f"temperature is changed to: {config['temperature']}")
+            self.logger.info(f"temperature is changed to: {config['temperature']}")
             self.temperature = config["temperature"]
         
         # 加载错误SQL提示配置
@@ -36,32 +40,32 @@ class BaseLLMChat(VannaBase, ABC):
         try:
             import app_config
             enable_error_sql = getattr(app_config, 'ENABLE_ERROR_SQL_PROMPT', False)
-            print(f"[DEBUG] 错误SQL提示配置: ENABLE_ERROR_SQL_PROMPT = {enable_error_sql}")
+            self.logger.debug(f"错误SQL提示配置: ENABLE_ERROR_SQL_PROMPT = {enable_error_sql}")
             return enable_error_sql
         except (ImportError, AttributeError) as e:
-            print(f"[WARNING] 无法加载错误SQL提示配置: {e},使用默认值 False")
+            self.logger.warning(f"无法加载错误SQL提示配置: {e},使用默认值 False")
             return False
 
     def system_message(self, message: str) -> dict:
         """创建系统消息格式"""
-        print(f"system_content: {message}")
+        self.logger.debug(f"system_content: {message}")
         return {"role": "system", "content": message}
 
     def user_message(self, message: str) -> dict:
         """创建用户消息格式"""
-        print(f"\nuser_content: {message}")
+        self.logger.debug(f"\nuser_content: {message}")
         return {"role": "user", "content": message}
 
     def assistant_message(self, message: str) -> dict:
         """创建助手消息格式"""
-        print(f"assistant_content: {message}")
+        self.logger.debug(f"assistant_content: {message}")
         return {"role": "assistant", "content": message}
 
     def get_sql_prompt(self, initial_prompt: str, question: str, question_sql_list: list, ddl_list: list, doc_list: list, **kwargs):
         """
         基于VannaBase源码实现,在第7点添加中文别名指令
         """
-        print(f"[DEBUG] 开始生成SQL提示词,问题: {question}")
+        self.logger.debug(f"开始生成SQL提示词,问题: {question}")
         
         if initial_prompt is None:
             initial_prompt = f"You are a {self.dialect} expert. " + \
@@ -101,7 +105,7 @@ class BaseLLMChat(VannaBase, ABC):
             try:
                 error_sql_list = self.get_related_error_sql(question, **kwargs)
                 if error_sql_list:
-                    print(f"[DEBUG] 找到 {len(error_sql_list)} 个相关的错误SQL示例")
+                    self.logger.debug(f"找到 {len(error_sql_list)} 个相关的错误SQL示例")
                     
                     # 构建格式化的负面提示内容
                     negative_prompt_content = "===Negative Examples\n"
@@ -110,33 +114,36 @@ class BaseLLMChat(VannaBase, ABC):
                     for i, error_example in enumerate(error_sql_list, 1):
                         if "question" in error_example and "sql" in error_example:
                             similarity = error_example.get('similarity', 'N/A')
-                            print(f"[DEBUG] 错误SQL示例 {i}: 相似度={similarity}")
+                            self.logger.debug(f"错误SQL示例 {i}: 相似度={similarity}")
                             negative_prompt_content += f"问题: {error_example['question']}\n"
                             negative_prompt_content += f"错误的SQL: {error_example['sql']}\n\n"
                     
                     # 将负面提示添加到初始提示中
                     initial_prompt += negative_prompt_content
                 else:
-                    print("[DEBUG] 未找到相关的错误SQL示例")
+                    self.logger.debug("未找到相关的错误SQL示例")
             except Exception as e:
-                print(f"[WARNING] 获取错误SQL示例失败: {e}")
+                self.logger.warning(f"获取错误SQL示例失败: {e}")
 
         initial_prompt += (
             "===Response Guidelines \n"
             "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
             "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
             "3. If the provided context is insufficient, please explain why it can't be generated. \n"
-            "4. Please use the most relevant table(s). \n"
-            "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
-            f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
-            "7. 在生成 SQL 查询时,如果出现 ORDER BY 子句,请遵循以下规则:\n"
+            "4. **Context Understanding**: If the question follows [CONTEXT]...[CURRENT] format, replace pronouns in [CURRENT] with specific entities from [CONTEXT].\n"
+            "   - Example: If context mentions 'Nancheng Service Area has the most stalls', and current question is 'How many dining stalls does this service area have?', \n"
+            "     interpret it as 'How many dining stalls does Nancheng Service Area have?'\n"
+            "5. Please use the most relevant table(s). \n"
+            "6. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
+            f"7. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
+            "8. 在生成 SQL 查询时,如果出现 ORDER BY 子句,请遵循以下规则:\n"
             "   - 对所有的排序字段(如聚合字段 SUM()、普通列等),请在 ORDER BY 中显式添加 NULLS LAST。\n"
             "   - 不论是否使用 LIMIT,只要排序字段存在,都必须添加 NULLS LAST,以防止 NULL 排在结果顶部。\n"
             "   - 示例参考:\n"
             "     - ORDER BY total DESC NULLS LAST\n"
             "     - ORDER BY zf_order DESC NULLS LAST\n"
             "     - ORDER BY SUM(c.customer_count) DESC NULLS LAST \n"
-            "8. 【重要】请在SQL查询中为所有SELECT的列都使用中文别名:\n"
+            "9. 【重要】请在SQL查询中为所有SELECT的列都使用中文别名:\n"
             "   - 每个列都必须使用 AS 中文别名 的格式,没有例外\n"
             "   - 包括原始字段名也要添加中文别名,例如:SELECT gender AS 性别, card_category AS 卡片类型\n"
             "   - 计算字段也要有中文别名,例如:SELECT COUNT(*) AS 持卡人数\n"
@@ -147,7 +154,7 @@ class BaseLLMChat(VannaBase, ABC):
 
         for example in question_sql_list:
             if example is None:
-                print("example is None")
+                self.logger.warning("example is None")
             else:
                 if example is not None and "question" in example and "sql" in example:
                     message_log.append(self.user_message(example["question"]))
@@ -225,7 +232,7 @@ class BaseLLMChat(VannaBase, ABC):
         if not DISPLAY_RESULT_THINKING:
             original_code = plotly_code
             plotly_code = self._remove_thinking_content(plotly_code)
-            print(f"[DEBUG] generate_plotly_code隐藏thinking内容 - 原始长度: {len(original_code)}, 处理后长度: {len(plotly_code)}")
+            self.logger.debug(f"generate_plotly_code隐藏thinking内容 - 原始长度: {len(original_code)}, 处理后长度: {len(plotly_code)}")
 
         return self._sanitize_plotly_code(self._extract_python_code(plotly_code))
 
@@ -270,12 +277,12 @@ class BaseLLMChat(VannaBase, ABC):
         对于Flask应用,这个方法决定了前端是否显示图表生成按钮
         """
         if df is None or df.empty:
-            print(f"[DEBUG] should_generate_chart: df为空,返回False")
+            self.logger.debug("should_generate_chart: df为空,返回False")
             return False
         
         # 如果数据有多行或多列,通常适合生成图表
         result = len(df) > 1 or len(df.columns) > 1
-        print(f"[DEBUG] should_generate_chart: df.shape={df.shape}, 返回{result}")
+        self.logger.debug(f"should_generate_chart: df.shape={df.shape}, 返回{result}")
         
         if result:
             return True
@@ -290,12 +297,12 @@ class BaseLLMChat(VannaBase, ABC):
             # 清空上次的解释性文本
             self.last_llm_explanation = None
             
-            print(f"[DEBUG] 尝试为问题生成SQL: {question}")
+            self.logger.debug(f"尝试为问题生成SQL: {question}")
             # 调用父类的 generate_sql
             sql = super().generate_sql(question, **kwargs)
             
             if not sql or sql.strip() == "":
-                print(f"[WARNING] 生成的SQL为空")
+                self.logger.warning("生成的SQL为空")
                 explanation = "无法生成SQL查询,可能是问题描述不够清晰或缺少必要的数据表信息。"
                 # 根据 DISPLAY_RESULT_THINKING 参数处理thinking内容
                 if not DISPLAY_RESULT_THINKING:
@@ -319,38 +326,38 @@ class BaseLLMChat(VannaBase, ABC):
             
             for indicator in error_indicators:
                 if indicator in sql_lower:
-                    print(f"[WARNING] LLM返回错误信息而非SQL: {sql}")
+                    self.logger.warning(f"LLM返回错误信息而非SQL: {sql}")
                     # 保存LLM的解释性文本,并根据配置处理thinking内容
                     explanation = sql
                     if not DISPLAY_RESULT_THINKING:
                         explanation = self._remove_thinking_content(explanation)
-                        print(f"[DEBUG] 隐藏thinking内容 - SQL生成解释性文本")
+                        self.logger.debug("隐藏thinking内容 - SQL生成解释性文本")
                     self.last_llm_explanation = explanation
                     return None
             
             # 简单检查是否像SQL语句(至少包含一些SQL关键词)
             sql_keywords = ["select", "insert", "update", "delete", "with", "from", "where"]
             if not any(keyword in sql_lower for keyword in sql_keywords):
-                print(f"[WARNING] 返回内容不像有效SQL: {sql}")
+                self.logger.warning(f"返回内容不像有效SQL: {sql}")
                 # 保存LLM的解释性文本,并根据配置处理thinking内容
                 explanation = sql
                 if not DISPLAY_RESULT_THINKING:
                     explanation = self._remove_thinking_content(explanation)
-                    print(f"[DEBUG] 隐藏thinking内容 - SQL生成非有效SQL内容")
+                    self.logger.debug("隐藏thinking内容 - SQL生成非有效SQL内容")
                 self.last_llm_explanation = explanation
                 return None
                 
-            print(f"[SUCCESS] 成功生成SQL:\n {sql}")
+            self.logger.info(f"成功生成SQL:\n {sql}")
             # 清空解释性文本
             self.last_llm_explanation = None
             return sql
             
         except Exception as e:
-            print(f"[ERROR] SQL生成过程中出现异常: {str(e)}")
-            print(f"[ERROR] 异常类型: {type(e).__name__}")
+            self.logger.error(f"SQL生成过程中出现异常: {str(e)}")
+            self.logger.error(f"异常类型: {type(e).__name__}")
             # 导入traceback以获取详细错误信息
             import traceback
-            print(f"[ERROR] 详细错误信息: {traceback.format_exc()}")
+            self.logger.error(f"详细错误信息: {traceback.format_exc()}")
             explanation = f"SQL生成过程中出现异常: {str(e)}"
             # 根据 DISPLAY_RESULT_THINKING 参数处理thinking内容
             if not DISPLAY_RESULT_THINKING:
@@ -372,7 +379,7 @@ class BaseLLMChat(VannaBase, ABC):
         if not DISPLAY_RESULT_THINKING:
             original_response = response
             response = self._remove_thinking_content(response)
-            print(f"[DEBUG] generate_question隐藏thinking内容 - 原始长度: {len(original_response)}, 处理后长度: {len(response)}")
+            self.logger.debug(f"generate_question隐藏thinking内容 - 原始长度: {len(original_response)}, 处理后长度: {len(response)}")
         
         return response
 
@@ -390,7 +397,7 @@ class BaseLLMChat(VannaBase, ABC):
     #         response = self.submit_prompt(prompt, **kwargs)
     #         return response
     #     except Exception as e:
-    #         print(f"[ERROR] LLM对话失败: {str(e)}")
+    #         self.logger.error(f"LLM对话失败: {str(e)}")
     #         return f"抱歉,我暂时无法回答您的问题。请稍后再试。"
 
     def chat_with_llm(self, question: str, system_prompt: str = None, **kwargs) -> str:
@@ -421,12 +428,12 @@ class BaseLLMChat(VannaBase, ABC):
             if not DISPLAY_RESULT_THINKING:
                 original_response = response
                 response = self._remove_thinking_content(response)
-                print(f"[DEBUG] chat_with_llm隐藏thinking内容 - 原始长度: {len(original_response)}, 处理后长度: {len(response)}")
+                self.logger.debug(f"chat_with_llm隐藏thinking内容 - 原始长度: {len(original_response)}, 处理后长度: {len(response)}")
             
             return response
             
         except Exception as e:
-            print(f"[ERROR] LLM对话失败: {str(e)}")
+            self.logger.error(f"LLM对话失败: {str(e)}")
             return f"抱歉,我暂时无法回答您的问题。请稍后再试。"
 
     def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
@@ -443,12 +450,12 @@ class BaseLLMChat(VannaBase, ABC):
         """
         # 如果未启用合并功能或没有上一个问题,直接返回新问题
         if not REWRITE_QUESTION_ENABLED or last_question is None:
-            print(f"[DEBUG] 问题合并功能{'未启用' if not REWRITE_QUESTION_ENABLED else '上一个问题为空'},直接返回新问题")
+            self.logger.debug(f"问题合并功能{'未启用' if not REWRITE_QUESTION_ENABLED else '上一个问题为空'},直接返回新问题")
             return new_question
         
-        print(f"[DEBUG] 启用问题合并功能,尝试合并问题")
-        print(f"[DEBUG] 上一个问题: {last_question}")
-        print(f"[DEBUG] 新问题: {new_question}")
+        self.logger.debug("启用问题合并功能,尝试合并问题")
+        self.logger.debug(f"上一个问题: {last_question}")
+        self.logger.debug(f"新问题: {new_question}")
         
         try:
             prompt = [
@@ -466,13 +473,13 @@ class BaseLLMChat(VannaBase, ABC):
             if not DISPLAY_RESULT_THINKING:
                 original_question = rewritten_question
                 rewritten_question = self._remove_thinking_content(rewritten_question)
-                print(f"[DEBUG] generate_rewritten_question隐藏thinking内容 - 原始长度: {len(original_question)}, 处理后长度: {len(rewritten_question)}")
+                self.logger.debug(f"generate_rewritten_question隐藏thinking内容 - 原始长度: {len(original_question)}, 处理后长度: {len(rewritten_question)}")
             
-            print(f"[DEBUG] 合并后的问题: {rewritten_question}")
+            self.logger.debug(f"合并后的问题: {rewritten_question}")
             return rewritten_question
             
         except Exception as e:
-            print(f"[ERROR] 问题合并失败: {str(e)}")
+            self.logger.error(f"问题合并失败: {str(e)}")
             # 如果合并失败,返回新问题
             return new_question
 
@@ -494,14 +501,14 @@ class BaseLLMChat(VannaBase, ABC):
             
             # 确保 df 是 pandas DataFrame
             if not isinstance(df, pd.DataFrame):
-                print(f"[WARNING] df 不是 pandas DataFrame,类型: {type(df)}")
+                self.logger.warning(f"df 不是 pandas DataFrame,类型: {type(df)}")
                 return "无法生成摘要:数据格式不正确"
             
             if df.empty:
                 return "查询结果为空,无数据可供摘要。"
             
-            print(f"[DEBUG] 生成摘要 - 问题: {question}")
-            print(f"[DEBUG] DataFrame 形状: {df.shape}")
+            self.logger.debug(f"生成摘要 - 问题: {question}")
+            self.logger.debug(f"DataFrame 形状: {df.shape}")
             
             # 构建包含中文指令的系统消息
             system_content = (
@@ -531,15 +538,15 @@ class BaseLLMChat(VannaBase, ABC):
                 # 移除 <think></think> 标签及其内容
                 original_summary = summary
                 summary = self._remove_thinking_content(summary)
-                print(f"[DEBUG] 隐藏thinking内容 - 原始长度: {len(original_summary)}, 处理后长度: {len(summary)}")
+                self.logger.debug(f"隐藏thinking内容 - 原始长度: {len(original_summary)}, 处理后长度: {len(summary)}")
             
-            print(f"[DEBUG] 生成的摘要: {summary[:100]}...")
+            self.logger.debug(f"生成的摘要: {summary[:100]}...")
             return summary
             
         except Exception as e:
-            print(f"[ERROR] 生成摘要失败: {str(e)}")
+            self.logger.error(f"生成摘要失败: {str(e)}")
             import traceback
-            print(f"[ERROR] 详细错误信息: {traceback.format_exc()}")
+            self.logger.error(f"详细错误信息: {traceback.format_exc()}")
             return f"生成摘要时出现错误:{str(e)}"
 
     def _remove_thinking_content(self, text: str) -> str:
@@ -598,7 +605,7 @@ class BaseLLMChat(VannaBase, ABC):
         try:
             sql = self.generate_sql(question=question, allow_llm_to_see_data=allow_llm_to_see_data)
         except Exception as e:
-            print(e)
+            self.logger.error(f"SQL generation error: {e}")
             self.last_llm_explanation = str(e)
             if print_results:
                 return None
@@ -608,7 +615,7 @@ class BaseLLMChat(VannaBase, ABC):
         # 如果SQL为空,说明有解释性文本,按照正常流程返回None
         # API层会检查 last_llm_explanation 来获取解释
         if sql is None:
-            print(f"[INFO] 无法生成SQL,解释: {self.last_llm_explanation}")
+            self.logger.info(f"无法生成SQL,解释: {self.last_llm_explanation}")
             if print_results:
                 return None
             else:
@@ -616,10 +623,10 @@ class BaseLLMChat(VannaBase, ABC):
 
         # 以下是正常的SQL执行流程(保持VannaBase原有逻辑)
         if print_results:
-            print(sql)
+            self.logger.info(f"Generated SQL: {sql}")
 
         if self.run_sql_is_set is False:
-            print("If you want to run the SQL query, connect to a database first.")
+            self.logger.info("If you want to run the SQL query, connect to a database first.")
             if print_results:
                 return None
             else:
@@ -629,7 +636,7 @@ class BaseLLMChat(VannaBase, ABC):
             df = self.run_sql(sql)
             
             if df is None:
-                print("The SQL query returned no results.")
+                self.logger.info("The SQL query returned no results.")
                 if print_results:
                     return None
                 else:
@@ -638,17 +645,17 @@ class BaseLLMChat(VannaBase, ABC):
             if print_results:
                 # 显示结果表格
                 if len(df) > 10:
-                    print(df.head(10).to_string())
-                    print(f"... ({len(df)} rows)")
+                    self.logger.info(f"Query results (first 10 rows):\n{df.head(10).to_string()}")
+                    self.logger.info(f"... ({len(df)} rows)")
                 else:
-                    print(df.to_string())
+                    self.logger.info(f"Query results:\n{df.to_string()}")
 
             # 如果启用了自动训练,添加问题-SQL对到训练集
             if auto_train:
                 try:
                     self.add_question_sql(question=question, sql=sql)
                 except Exception as e:
-                    print(f"Could not add question and sql to training data: {e}")
+                    self.logger.warning(f"Could not add question and sql to training data: {e}")
 
             if visualize:
                 try:
@@ -668,25 +675,25 @@ class BaseLLMChat(VannaBase, ABC):
                             )
                             if fig is not None:
                                 if print_results:
-                                    print("Chart generated (use fig.show() to display)")
+                                    self.logger.info("Chart generated (use fig.show() to display)")
                                 return sql, df, fig
                             else:
-                                print("Could not generate chart")
+                                self.logger.warning("Could not generate chart")
                                 return sql, df, None
                         else:
-                            print("No chart generated")
+                            self.logger.info("No chart generated")
                             return sql, df, None
                     else:
-                        print("Not generating chart for this data")
+                        self.logger.info("Not generating chart for this data")
                         return sql, df, None
                 except Exception as e:
-                    print(f"Couldn't generate chart: {e}")
+                    self.logger.error(f"Couldn't generate chart: {e}")
                     return sql, df, None
             else:
                 return sql, df, None
 
         except Exception as e:
-            print("Couldn't run sql: ", e)
+            self.logger.error(f"Couldn't run sql: {e}")
             if print_results:
                 return None
             else:

+ 15 - 15
customllm/deepseek_chat.py

@@ -7,8 +7,8 @@ class DeepSeekChat(BaseLLMChat):
     """DeepSeek AI聊天实现"""
     
     def __init__(self, config=None):
-        print("...DeepSeekChat init...")
         super().__init__(config=config)
+        self.logger.info("DeepSeekChat init")
 
         if config is None:
             self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
@@ -43,7 +43,7 @@ class DeepSeekChat(BaseLLMChat):
         # DeepSeek API约束:enable_thinking=True时建议使用stream=True
         # 如果stream=False但enable_thinking=True,则忽略enable_thinking
         if enable_thinking and not stream_mode:
-            print("WARNING: enable_thinking=True 不生效,因为它需要 stream=True")
+            self.logger.warning("enable_thinking=True 不生效,因为它需要 stream=True")
             enable_thinking = False
 
         # 确定使用的模型
@@ -68,18 +68,18 @@ class DeepSeekChat(BaseLLMChat):
 
         # 模型兼容性提示(但不强制切换)
         if enable_thinking and model not in ["deepseek-reasoner"]:
-            print(f"提示:模型 {model} 可能不支持推理功能,推理相关参数将被忽略")
+            self.logger.warning(f"提示:模型 {model} 可能不支持推理功能,推理相关参数将被忽略")
 
-        print(f"\nUsing model {model} for {num_tokens} tokens (approx)")
-        print(f"Enable thinking: {enable_thinking}, Stream mode: {stream_mode}")
+        self.logger.info(f"\nUsing model {model} for {num_tokens} tokens (approx)")
+        self.logger.info(f"Enable thinking: {enable_thinking}, Stream mode: {stream_mode}")
 
         # 方案1:通过 system prompt 控制中文输出(DeepSeek 不支持 language 参数)
         # 检查配置中的语言设置,并在 system prompt 中添加中文指令
         # language_setting = self.config.get("language", "").lower() if self.config else ""
-        # print(f"DEBUG: language_setting='{language_setting}', model='{model}', enable_thinking={enable_thinking}")
+        # self.logger.debug(f"language_setting='{language_setting}', model='{model}', enable_thinking={enable_thinking}")
         
         # if language_setting == "chinese" and enable_thinking:
-        #     print("DEBUG: ✅ 触发中文指令添加")
+        #     self.logger.debug("触发中文指令添加")
         #     # 为推理模型添加中文思考指令
         #     chinese_instruction = {"role": "system", "content": "请用中文进行思考和回答。在推理过程中,请使用中文进行分析和思考。<think></think>之间也请使用中文"}
         #     # 如果第一条消息不是 system 消息,则添加中文指令
@@ -90,7 +90,7 @@ class DeepSeekChat(BaseLLMChat):
         #         existing_content = prompt[0]["content"]
         #         prompt[0]["content"] = f"{existing_content}\n\n请用中文进行思考和回答。在推理过程中,请使用中文进行分析和思考。<think></think>之间也请使用中文"
         # else:
-        #     print(f"DEBUG: ❌ 未触发中文指令 - language_setting==chinese: {language_setting == 'chinese'}, model==deepseek-reasoner: {model == 'deepseek-reasoner'}, enable_thinking: {enable_thinking}")
+        #     self.logger.debug(f"未触发中文指令 - language_setting==chinese: {language_setting == 'chinese'}, model==deepseek-reasoner: {model == 'deepseek-reasoner'}, enable_thinking: {enable_thinking}")
 
         # 构建 API 调用参数
         api_params = {
@@ -112,7 +112,7 @@ class DeepSeekChat(BaseLLMChat):
             unsupported_params = ['top_p', 'presence_penalty', 'frequency_penalty', 'logprobs', 'top_logprobs']
             for param in unsupported_params:
                 if param in filtered_kwargs:
-                    print(f"警告:deepseek-reasoner 不支持参数 {param},已忽略")
+                    self.logger.warning(f"deepseek-reasoner 不支持参数 {param},已忽略")
                     filtered_kwargs.pop(param, None)
         else:
             # deepseek-chat 等其他模型,只过滤明确会导致错误的参数
@@ -125,9 +125,9 @@ class DeepSeekChat(BaseLLMChat):
         if stream_mode:
             # 流式处理模式
             if model == "deepseek-reasoner" and enable_thinking:
-                print("使用流式处理模式,启用推理功能")
+                self.logger.info("使用流式处理模式,启用推理功能")
             else:
-                print("使用流式处理模式,常规聊天")
+                self.logger.info("使用流式处理模式,常规聊天")
             
             response_stream = self.client.chat.completions.create(**api_params)
             
@@ -151,7 +151,7 @@ class DeepSeekChat(BaseLLMChat):
                 # 可选:打印推理过程
                 if collected_reasoning:
                     reasoning_text = "".join(collected_reasoning)
-                    print("Model reasoning process:\n", reasoning_text)
+                    self.logger.debug("Model reasoning process:\n" + reasoning_text)
                 
                 # 方案2:返回包含 <think></think> 标签的完整内容,与 QianWen 保持一致
                 final_content = "".join(collected_content)
@@ -173,9 +173,9 @@ class DeepSeekChat(BaseLLMChat):
         else:
             # 非流式处理模式
             if model == "deepseek-reasoner" and enable_thinking:
-                print("使用非流式处理模式,启用推理功能")
+                self.logger.info("使用非流式处理模式,启用推理功能")
             else:
-                print("使用非流式处理模式,常规聊天")
+                self.logger.info("使用非流式处理模式,常规聊天")
             
             response = self.client.chat.completions.create(**api_params)
             
@@ -187,7 +187,7 @@ class DeepSeekChat(BaseLLMChat):
                 reasoning_content = ""
                 if hasattr(message, 'reasoning_content') and message.reasoning_content:
                     reasoning_content = message.reasoning_content
-                    print("Model reasoning process:\n", reasoning_content)
+                    self.logger.debug("Model reasoning process:\n" + reasoning_content)
                 
                 # 方案2:返回包含 <think></think> 标签的完整内容,与 QianWen 保持一致
                 final_content = message.content

+ 31 - 31
customllm/ollama_chat.py

@@ -9,8 +9,8 @@ class OllamaChat(BaseLLMChat):
     """Ollama AI聊天实现"""
     
     def __init__(self, config=None):
-        print("...OllamaChat init...")
         super().__init__(config=config)
+        self.logger.info("OllamaChat init")
 
         # Ollama特定的配置参数
         self.base_url = config.get("base_url", "http://localhost:11434") if config else "http://localhost:11434"
@@ -31,13 +31,13 @@ class OllamaChat(BaseLLMChat):
         try:
             response = requests.get(f"{self.base_url}/api/tags", timeout=5)
             if response.status_code == 200:
-                print(f"✅ Ollama 服务连接正常: {self.base_url}")
+                self.logger.info(f"Ollama 服务连接正常: {self.base_url}")
                 return True
             else:
-                print(f"⚠️ Ollama 服务响应异常: {response.status_code}")
+                self.logger.warning(f"Ollama 服务响应异常: {response.status_code}")
                 return False
         except requests.exceptions.RequestException as e:
-            print(f"❌ Ollama 服务连接失败: {e}")
+            self.logger.error(f"Ollama 服务连接失败: {e}")
             return False
 
     def submit_prompt(self, prompt, **kwargs) -> str:
@@ -61,7 +61,7 @@ class OllamaChat(BaseLLMChat):
         # Ollama 约束:enable_thinking=True时建议使用stream=True
         # 如果stream=False但enable_thinking=True,则忽略enable_thinking
         if enable_thinking and not stream_mode:
-            print("WARNING: enable_thinking=True 不生效,因为它需要 stream=True")
+            self.logger.warning("enable_thinking=True 不生效,因为它需要 stream=True")
             enable_thinking = False
 
         # 智能模型选择
@@ -72,10 +72,10 @@ class OllamaChat(BaseLLMChat):
         
         # 模型兼容性提示(但不强制切换)
         if enable_thinking and not is_reasoning_model:
-            print(f"提示:模型 {model} 不是专门的推理模型,但仍会尝试启用推理功能")
+            self.logger.warning(f"提示:模型 {model} 不是专门的推理模型,但仍会尝试启用推理功能")
 
-        print(f"\nUsing Ollama model {model} for {num_tokens} tokens (approx)")
-        print(f"Enable thinking: {enable_thinking}, Stream mode: {stream_mode}")
+        self.logger.info(f"\nUsing Ollama model {model} for {num_tokens} tokens (approx)")
+        self.logger.info(f"Enable thinking: {enable_thinking}, Stream mode: {stream_mode}")
 
         # 准备Ollama API请求
         url = f"{self.base_url}/api/chat"
@@ -91,22 +91,22 @@ class OllamaChat(BaseLLMChat):
             if stream_mode:
                 # 流式处理模式
                 if enable_thinking:
-                    print("使用流式处理模式,启用推理功能")
+                    self.logger.info("使用流式处理模式,启用推理功能")
                 else:
-                    print("使用流式处理模式,常规聊天")
+                    self.logger.info("使用流式处理模式,常规聊天")
                 
                 return self._handle_stream_response(url, payload, enable_thinking)
             else:
                 # 非流式处理模式
                 if enable_thinking:
-                    print("使用非流式处理模式,启用推理功能")
+                    self.logger.info("使用非流式处理模式,启用推理功能")
                 else:
-                    print("使用非流式处理模式,常规聊天")
+                    self.logger.info("使用非流式处理模式,常规聊天")
                 
                 return self._handle_non_stream_response(url, payload, enable_thinking)
                 
         except requests.exceptions.RequestException as e:
-            print(f"Ollama API请求失败: {e}")
+            self.logger.error(f"Ollama API请求失败: {e}")
             raise Exception(f"Ollama API调用失败: {str(e)}")
 
     def _handle_stream_response(self, url: str, payload: dict, enable_reasoning: bool) -> str:
@@ -146,7 +146,7 @@ class OllamaChat(BaseLLMChat):
             reasoning_content, final_content = self._extract_reasoning(full_content)
             
             if reasoning_content:
-                print("Model reasoning process:\n", reasoning_content)
+                self.logger.debug("Model reasoning process:\n" + reasoning_content)
                 return final_content
         
         return full_content
@@ -169,7 +169,7 @@ class OllamaChat(BaseLLMChat):
             reasoning_content, final_content = self._extract_reasoning(content)
             
             if reasoning_content:
-                print("Model reasoning process:\n", reasoning_content)
+                self.logger.debug("Model reasoning process:\n" + reasoning_content)
                 return final_content
         
         return content
@@ -197,17 +197,17 @@ class OllamaChat(BaseLLMChat):
                 
                 # 检查目标模型是否存在
                 if self.model not in result["available_models"]:
-                    print(f"警告:模型 {self.model} 不存在,尝试拉取...")
+                    self.logger.warning(f"模型 {self.model} 不存在,尝试拉取...")
                     if not self.pull_model(self.model):
                         result["message"] = f"模型 {self.model} 不存在且拉取失败"
                         return result
             except Exception as e:
-                print(f"获取模型列表失败: {e}")
+                self.logger.error(f"获取模型列表失败: {e}")
                 result["available_models"] = [self.model]
             
-            print(f"测试Ollama连接 - 模型: {self.model}")
-            print(f"Ollama服务地址: {self.base_url}")
-            print(f"可用模型: {', '.join(result['available_models'])}")
+            self.logger.info(f"测试Ollama连接 - 模型: {self.model}")
+            self.logger.info(f"Ollama服务地址: {self.base_url}")
+            self.logger.info(f"可用模型: {', '.join(result['available_models'])}")
             
             # 测试简单对话
             prompt = [self.user_message(test_prompt)]
@@ -243,10 +243,10 @@ class OllamaChat(BaseLLMChat):
                     if reasoning_models:
                         return reasoning_models[0]  # 选择第一个推理模型
                     else:
-                        print("警告:未找到推理模型,使用默认模型")
+                        self.logger.warning("未找到推理模型,使用默认模型")
                         return self.model
                 except Exception as e:
-                    print(f"获取模型列表时出错: {e},使用默认模型")
+                    self.logger.error(f"获取模型列表时出错: {e},使用默认模型")
                     return self.model
             else:
                 # 根据 token 数量选择模型
@@ -258,7 +258,7 @@ class OllamaChat(BaseLLMChat):
                         if long_context_models:
                             return long_context_models[0]
                     except Exception as e:
-                        print(f"获取模型列表时出错: {e},使用默认模型")
+                        self.logger.error(f"获取模型列表时出错: {e},使用默认模型")
                 
                 return self.model
 
@@ -357,26 +357,26 @@ class OllamaChat(BaseLLMChat):
             models = [model["name"] for model in data.get("models", [])]
             return models if models else [self.model]  # 如果没有模型,返回默认模型
         except requests.exceptions.RequestException as e:
-            print(f"获取模型列表失败: {e}")
+            self.logger.error(f"获取模型列表失败: {e}")
             return [self.model]  # 返回默认模型
         except Exception as e:
-            print(f"解析模型列表失败: {e}")
+            self.logger.error(f"解析模型列表失败: {e}")
             return [self.model]  # 返回默认模型
 
     def pull_model(self, model_name: str) -> bool:
         """拉取模型"""
         try:
-            print(f"正在拉取模型: {model_name}")
+            self.logger.info(f"正在拉取模型: {model_name}")
             response = requests.post(
                 f"{self.base_url}/api/pull",
                 json={"name": model_name},
                 timeout=300  # 拉取模型可能需要较长时间
             )
             response.raise_for_status()
-            print(f"✅ 模型 {model_name} 拉取成功")
+            self.logger.info(f"模型 {model_name} 拉取成功")
             return True
         except requests.exceptions.RequestException as e:
-            print(f"❌ 模型 {model_name} 拉取失败: {e}")
+            self.logger.error(f"模型 {model_name} 拉取失败: {e}")
             return False
 
     def delete_model(self, model_name: str) -> bool:
@@ -388,10 +388,10 @@ class OllamaChat(BaseLLMChat):
                 timeout=self.timeout
             )
             response.raise_for_status()
-            print(f"✅ 模型 {model_name} 删除成功")
+            self.logger.info(f"模型 {model_name} 删除成功")
             return True
         except requests.exceptions.RequestException as e:
-            print(f"❌ 模型 {model_name} 删除失败: {e}")
+            self.logger.error(f"模型 {model_name} 删除失败: {e}")
             return False
 
     def get_model_info(self, model_name: str) -> Optional[Dict]:
@@ -405,7 +405,7 @@ class OllamaChat(BaseLLMChat):
             response.raise_for_status()
             return response.json()
         except requests.exceptions.RequestException as e:
-            print(f"获取模型信息失败: {e}")
+            self.logger.error(f"获取模型信息失败: {e}")
             return None
 
     def get_system_info(self) -> Dict:

+ 8 - 8
customllm/qianwen_chat.py

@@ -7,8 +7,8 @@ class QianWenChat(BaseLLMChat):
     """千问AI聊天实现"""
     
     def __init__(self, client=None, config=None):
-        print("...QianWenChat init...")
         super().__init__(config=config)
+        self.logger.info("QianWenChat init")
 
         if "api_type" in config:
             raise Exception(
@@ -65,7 +65,7 @@ class QianWenChat(BaseLLMChat):
         # 千问API约束:enable_thinking=True时必须stream=True
         # 如果stream=False但enable_thinking=True,则忽略enable_thinking
         if enable_thinking and not stream_mode:
-            print("WARNING: enable_thinking=True 不生效,因为它需要 stream=True")
+            self.logger.warning("enable_thinking=True 不生效,因为它需要 stream=True")
             enable_thinking = False
         
         # 创建一个干净的kwargs副本,移除可能导致API错误的自定义参数
@@ -112,15 +112,15 @@ class QianWenChat(BaseLLMChat):
                 model = "qwen-plus"
             common_params["model"] = model
         
-        print(f"\nUsing model {model} for {num_tokens} tokens (approx)")
-        print(f"Enable thinking: {enable_thinking}, Stream mode: {stream_mode}")
+        self.logger.info(f"\nUsing model {model} for {num_tokens} tokens (approx)")
+        self.logger.info(f"Enable thinking: {enable_thinking}, Stream mode: {stream_mode}")
         
         if stream_mode:
             # 流式处理模式
             if enable_thinking:
-                print("使用流式处理模式,启用thinking功能")
+                self.logger.info("使用流式处理模式,启用thinking功能")
             else:
-                print("使用流式处理模式,不启用thinking功能")
+                self.logger.info("使用流式处理模式,不启用thinking功能")
             
             response_stream = self.client.chat.completions.create(**common_params)
             
@@ -144,7 +144,7 @@ class QianWenChat(BaseLLMChat):
             # 可以在这里处理thinking的展示逻辑,如保存到日志等
             if enable_thinking and collected_thinking:
                 thinking_text = "".join(collected_thinking)
-                print("Model thinking process:\n", thinking_text)
+                self.logger.debug("Model thinking process:\n" + thinking_text)
             
             # 返回包含 <think></think> 标签的完整内容,与界面显示需求保持一致
             final_content = "".join(collected_content)
@@ -155,7 +155,7 @@ class QianWenChat(BaseLLMChat):
                 return final_content
         else:
             # 非流式处理模式
-            print("使用非流式处理模式")
+            self.logger.info("使用非流式处理模式")
             response = self.client.chat.completions.create(**common_params)
             
             # Find the first response from the chatbot that has text in it (some responses may not have text)

+ 58 - 22
custompgvector/pgvector.py

@@ -7,6 +7,7 @@ import pandas as pd
 from langchain_core.documents import Document
 from langchain_postgres.vectorstores import PGVector
 from sqlalchemy import create_engine, text
+from core.logging import get_vanna_logger
 
 from vanna.exceptions import ValidationError
 from vanna.base import VannaBase
@@ -23,6 +24,9 @@ class PG_VectorStore(VannaBase):
                 "A valid 'config' dictionary with a 'connection_string' is required.")
 
         VannaBase.__init__(self, config=config)
+        
+        # 初始化日志
+        self.logger = get_vanna_logger("PGVector")
 
         if config and "connection_string" in config:
             self.connection_string = config.get("connection_string")
@@ -135,7 +139,7 @@ class PG_VectorStore(VannaBase):
                 if generated_embedding:
                     embedding_cache.cache_embedding(question, generated_embedding)
             except Exception as e:
-                print(f"[WARNING] 缓存embedding失败: {e}")
+                self.logger.warning(f"缓存embedding失败: {e}")
 
         results = []
         for doc, score in docs_with_scores:
@@ -146,12 +150,16 @@ class PG_VectorStore(VannaBase):
             similarity = round(1 - score, 4)
 
             # 每条记录单独打印
-            print(f"[DEBUG] SQL Match: {base.get('question', '')} | similarity: {similarity}")
+            self.logger.debug(f"SQL Match: {base.get('question', '')} | similarity: {similarity}")
 
             # 添加 similarity 字段
             base["similarity"] = similarity
             results.append(base)
 
+        # 检查原始查询结果是否为空
+        if not results:
+            self.logger.warning(f"向量查询未找到任何相似的SQL问答对,问题: {question}")
+
         # 应用阈值过滤
         filtered_results = self._apply_score_threshold_filter(
             results, 
@@ -159,6 +167,10 @@ class PG_VectorStore(VannaBase):
             "SQL"
         )
 
+        # 检查过滤后结果是否为空
+        if results and not filtered_results:
+            self.logger.warning(f"向量查询找到了 {len(results)} 条SQL问答对,但全部被阈值过滤掉,问题: {question}")
+
         return filtered_results
 
     def get_related_ddl(self, question: str, **kwargs) -> list:
@@ -186,7 +198,7 @@ class PG_VectorStore(VannaBase):
                 if generated_embedding:
                     embedding_cache.cache_embedding(question, generated_embedding)
             except Exception as e:
-                print(f"[WARNING] 缓存embedding失败: {e}")
+                self.logger.warning(f"缓存embedding失败: {e}")
 
         results = []
         for doc, score in docs_with_scores:
@@ -194,7 +206,7 @@ class PG_VectorStore(VannaBase):
             similarity = round(1 - score, 4)
 
             # 每条记录单独打印
-            print(f"[DEBUG] DDL Match: {doc.page_content[:50]}... | similarity: {similarity}")
+            self.logger.debug(f"DDL Match: {doc.page_content[:50]}... | similarity: {similarity}")
 
             # 添加 similarity 字段
             result = {
@@ -203,6 +215,10 @@ class PG_VectorStore(VannaBase):
             }
             results.append(result)
 
+        # 检查原始查询结果是否为空
+        if not results:
+            self.logger.warning(f"向量查询未找到任何相关的DDL表结构,问题: {question}")
+
         # 应用阈值过滤
         filtered_results = self._apply_score_threshold_filter(
             results, 
@@ -210,6 +226,10 @@ class PG_VectorStore(VannaBase):
             "DDL"
         )
 
+        # 检查过滤后结果是否为空
+        if results and not filtered_results:
+            self.logger.warning(f"向量查询找到了 {len(results)} 条DDL表结构,但全部被阈值过滤掉,问题: {question}")
+
         return filtered_results
 
     def get_related_documentation(self, question: str, **kwargs) -> list:
@@ -237,7 +257,7 @@ class PG_VectorStore(VannaBase):
                 if generated_embedding:
                     embedding_cache.cache_embedding(question, generated_embedding)
             except Exception as e:
-                print(f"[WARNING] 缓存embedding失败: {e}")
+                self.logger.warning(f"缓存embedding失败: {e}")
 
         results = []
         for doc, score in docs_with_scores:
@@ -245,7 +265,7 @@ class PG_VectorStore(VannaBase):
             similarity = round(1 - score, 4)
 
             # 每条记录单独打印
-            print(f"[DEBUG] Doc Match: {doc.page_content[:50]}... | similarity: {similarity}")
+            self.logger.debug(f"Doc Match: {doc.page_content[:50]}... | similarity: {similarity}")
 
             # 添加 similarity 字段
             result = {
@@ -254,6 +274,10 @@ class PG_VectorStore(VannaBase):
             }
             results.append(result)
 
+        # 检查原始查询结果是否为空
+        if not results:
+            self.logger.warning(f"向量查询未找到任何相关的文档,问题: {question}")
+
         # 应用阈值过滤
         filtered_results = self._apply_score_threshold_filter(
             results, 
@@ -261,6 +285,10 @@ class PG_VectorStore(VannaBase):
             "DOC"
         )
 
+        # 检查过滤后结果是否为空
+        if results and not filtered_results:
+            self.logger.warning(f"向量查询找到了 {len(results)} 条文档,但全部被阈值过滤掉,问题: {question}")
+
         return filtered_results
 
     def _apply_score_threshold_filter(self, results: list, threshold_config_key: str, result_type: str) -> list:
@@ -284,19 +312,19 @@ class PG_VectorStore(VannaBase):
             enable_threshold = getattr(app_config, 'ENABLE_RESULT_VECTOR_SCORE_THRESHOLD', False)
             threshold = getattr(app_config, threshold_config_key, 0.65)
         except (ImportError, AttributeError) as e:
-            print(f"[WARNING] 无法加载阈值配置: {e},使用默认值")
+            self.logger.warning(f"无法加载阈值配置: {e},使用默认值")
             enable_threshold = False
             threshold = 0.65
         
         # 如果未启用阈值过滤,直接返回原结果
         if not enable_threshold:
-            print(f"[DEBUG] {result_type} 阈值过滤未启用,返回全部 {len(results)} 条结果")
+            self.logger.debug(f"{result_type} 阈值过滤未启用,返回全部 {len(results)} 条结果")
             return results
         
         total_count = len(results)
         min_required = max((total_count + 1) // 2, 1)
         
-        print(f"[DEBUG] {result_type} 阈值过滤: 总数={total_count}, 阈值={threshold}, 最少保留={min_required}")
+        self.logger.debug(f"{result_type} 阈值过滤: 总数={total_count}, 阈值={threshold}, 最少保留={min_required}")
         
         # 按相似度降序排序(确保最相似的在前面)
         sorted_results = sorted(results, key=lambda x: x.get('similarity', 0), reverse=True)
@@ -309,20 +337,20 @@ class PG_VectorStore(VannaBase):
             # 情况1: 满足阈值的结果数量 >= 最少保留数量,返回满足阈值的结果
             filtered_results = above_threshold
             filtered_count = len(above_threshold)
-            print(f"[DEBUG] {result_type} 过滤结果: 保留 {filtered_count} 条, 过滤掉 {total_count - filtered_count} 条 (全部满足阈值)")
+            self.logger.debug(f"{result_type} 过滤结果: 保留 {filtered_count} 条, 过滤掉 {total_count - filtered_count} 条 (全部满足阈值)")
         else:
             # 情况2: 满足阈值的结果数量 < 最少保留数量,强制保留前 min_required 条
             filtered_results = sorted_results[:min_required]
             above_count = len(above_threshold)
             below_count = min_required - above_count
             filtered_count = min_required
-            print(f"[DEBUG] {result_type} 过滤结果: 保留 {filtered_count} 条, 过滤掉 {total_count - filtered_count} 条 (满足阈值: {above_count}, 强制保留: {below_count})")
+            self.logger.debug(f"{result_type} 过滤结果: 保留 {filtered_count} 条, 过滤掉 {total_count - filtered_count} 条 (满足阈值: {above_count}, 强制保留: {below_count})")
         
         # 打印过滤详情
         for i, result in enumerate(filtered_results):
             similarity = result.get('similarity', 0)
             status = "✓" if similarity >= threshold else "✗"
-            print(f"[DEBUG] {result_type} 保留 {i+1}: similarity={similarity} {status}")
+            self.logger.debug(f"{result_type} 保留 {i+1}: similarity={similarity} {status}")
         
         return filtered_results
 
@@ -350,17 +378,17 @@ class PG_VectorStore(VannaBase):
             enable_threshold = getattr(app_config, 'ENABLE_RESULT_VECTOR_SCORE_THRESHOLD', False)
             threshold = getattr(app_config, 'RESULT_VECTOR_ERROR_SQL_SCORE_THRESHOLD', 0.5)
         except (ImportError, AttributeError) as e:
-            print(f"[WARNING] 无法加载错误SQL阈值配置: {e},使用默认值")
+            self.logger.warning(f"无法加载错误SQL阈值配置: {e},使用默认值")
             enable_threshold = False
             threshold = 0.5
         
         # 如果未启用阈值过滤,直接返回原结果
         if not enable_threshold:
-            print(f"[DEBUG] Error SQL 阈值过滤未启用,返回全部 {len(results)} 条结果")
+            self.logger.debug(f"Error SQL 阈值过滤未启用,返回全部 {len(results)} 条结果")
             return results
         
         total_count = len(results)
-        print(f"[DEBUG] Error SQL 阈值过滤: 总数={total_count}, 阈值={threshold}")
+        self.logger.debug(f"Error SQL 阈值过滤: 总数={total_count}, 阈值={threshold}")
         
         # 按相似度降序排序(确保最相似的在前面)
         sorted_results = sorted(results, key=lambda x: x.get('similarity', 0), reverse=True)
@@ -372,13 +400,13 @@ class PG_VectorStore(VannaBase):
         filtered_out_count = total_count - filtered_count
         
         if filtered_count > 0:
-            print(f"[DEBUG] Error SQL 过滤结果: 保留 {filtered_count} 条, 过滤掉 {filtered_out_count} 条")
+            self.logger.debug(f"Error SQL 过滤结果: 保留 {filtered_count} 条, 过滤掉 {filtered_out_count} 条")
             # 打印保留的结果详情
             for i, result in enumerate(filtered_results):
                 similarity = result.get('similarity', 0)
-                print(f"[DEBUG] Error SQL 保留 {i+1}: similarity={similarity} ✓")
+                self.logger.debug(f"Error SQL 保留 {i+1}: similarity={similarity} ✓")
         else:
-            print(f"[DEBUG] Error SQL 过滤结果: 所有 {total_count} 条结果都低于阈值 {threshold},返回空列表")
+            self.logger.debug(f"Error SQL 过滤结果: 所有 {total_count} 条结果都低于阈值 {threshold},返回空列表")
         
         return filtered_results
 
@@ -610,7 +638,7 @@ class PG_VectorStore(VannaBase):
                     if generated_embedding:
                         embedding_cache.cache_embedding(question, generated_embedding)
                 except Exception as e:
-                    print(f"[WARNING] 缓存embedding失败: {e}")
+                    self.logger.warning(f"缓存embedding失败: {e}")
             
             results = []
             for doc, score in docs_with_scores:
@@ -622,21 +650,29 @@ class PG_VectorStore(VannaBase):
                     similarity = round(1 - score, 4)
                     
                     # 每条记录单独打印
-                    print(f"[DEBUG] Error SQL Match: {base.get('question', '')} | similarity: {similarity}")
+                    self.logger.debug(f"Error SQL Match: {base.get('question', '')} | similarity: {similarity}")
                     
                     # 添加 similarity 字段
                     base["similarity"] = similarity
                     results.append(base)
                     
                 except (ValueError, SyntaxError) as e:
-                    print(f"Error parsing error SQL document: {e}")
+                    self.logger.error(f"Error parsing error SQL document: {e}")
                     continue
             
+            # 检查原始查询结果是否为空
+            if not results:
+                self.logger.warning(f"向量查询未找到任何相关的错误SQL示例,问题: {question}")
+
             # 应用错误SQL特有的阈值过滤逻辑
             filtered_results = self._apply_error_sql_threshold_filter(results)
             
+            # 检查过滤后结果是否为空
+            if results and not filtered_results:
+                self.logger.warning(f"向量查询找到了 {len(results)} 条错误SQL示例,但全部被阈值过滤掉,问题: {question}")
+
             return filtered_results
             
         except Exception as e:
-            print(f"Error retrieving error SQL examples: {e}")
+            self.logger.error(f"Error retrieving error SQL examples: {e}")
             return []

+ 2 - 2
data_pipeline/analyzers/md_analyzer.py

@@ -1,6 +1,6 @@
-import logging
 from pathlib import Path
 from typing import List, Dict, Any
+from core.logging import get_data_pipeline_logger
 
 
 class MDFileAnalyzer:
@@ -8,7 +8,7 @@ class MDFileAnalyzer:
     
     def __init__(self, output_dir: str):
         self.output_dir = Path(output_dir)
-        self.logger = logging.getLogger("schema_tools.MDFileAnalyzer")
+        self.logger = get_data_pipeline_logger("MDFileAnalyzer")
         
     async def read_all_md_files(self) -> str:
         """

+ 2 - 2
data_pipeline/analyzers/theme_extractor.py

@@ -1,9 +1,9 @@
 import asyncio
 import json
-import logging
 from typing import List, Dict, Any
 
 from data_pipeline.config import SCHEMA_TOOLS_CONFIG
+from core.logging import get_data_pipeline_logger
 
 
 class ThemeExtractor:
@@ -19,7 +19,7 @@ class ThemeExtractor:
         """
         self.vn = vn
         self.business_context = business_context
-        self.logger = logging.getLogger("schema_tools.ThemeExtractor")
+        self.logger = get_data_pipeline_logger("ThemeExtractor")
         self.config = SCHEMA_TOOLS_CONFIG
         
     async def extract_themes(self, md_contents: str) -> List[Dict[str, Any]]:

+ 3 - 1
data_pipeline/config.py

@@ -169,4 +169,6 @@ def validate_config():
 try:
     validate_config()
 except ValueError as e:
-    print(f"警告: {e}")
+    # 在配置文件中使用stderr输出警告,避免依赖logging
+    import sys
+    print(f"警告: {e}", file=sys.stderr)

+ 2 - 3
data_pipeline/ddl_generation/training_data_agent.py

@@ -1,6 +1,5 @@
 import asyncio
 import time
-import logging
 import os
 from typing import List, Dict, Any, Optional
 from pathlib import Path
@@ -11,8 +10,8 @@ from data_pipeline.utils.file_manager import FileNameManager
 from data_pipeline.utils.system_filter import SystemTableFilter
 from data_pipeline.utils.permission_checker import DatabasePermissionChecker
 from data_pipeline.utils.table_parser import TableListParser
-from data_pipeline.utils.logger import setup_logging
 from data_pipeline.config import SCHEMA_TOOLS_CONFIG
+from core.logging import get_data_pipeline_logger
 
 class SchemaTrainingDataAgent:
     """Schema训练数据生成AI Agent"""
@@ -50,7 +49,7 @@ class SchemaTrainingDataAgent:
         }
         
         self.failed_tables = []
-        self.logger = logging.getLogger("schema_tools.Agent")
+        self.logger = get_data_pipeline_logger("SchemaTrainingDataAgent")
     
     async def generate_training_data(self) -> Dict[str, Any]:
         """主入口:生成训练数据"""

+ 49 - 42
data_pipeline/metadata_only_generator.py

@@ -15,6 +15,7 @@ from data_pipeline.analyzers import MDFileAnalyzer, ThemeExtractor
 from data_pipeline.validators import FileCountValidator
 from data_pipeline.utils.logger import setup_logging
 from core.vanna_llm_factory import create_vanna_instance
+from core.logging import get_data_pipeline_logger
 
 
 class MetadataOnlyGenerator:
@@ -45,10 +46,13 @@ class MetadataOnlyGenerator:
         self.vn = None
         self.theme_extractor = None
         
-        print(f"🎯 元数据生成器初始化完成")
-        print(f"📁 输出目录: {output_dir}")
-        print(f"🏢 业务背景: {business_context}")
-        print(f"💾 数据库: {self.db_name}")
+        # 初始化logger
+        self.logger = get_data_pipeline_logger("MetadataOnlyGenerator")
+        
+        self.logger.info(f"🎯 元数据生成器初始化完成")
+        self.logger.info(f"📁 输出目录: {output_dir}")
+        self.logger.info(f"🏢 业务背景: {business_context}")
+        self.logger.info(f"💾 数据库: {self.db_name}")
     
     async def generate_metadata_only(self) -> Dict[str, Any]:
         """
@@ -58,50 +62,50 @@ class MetadataOnlyGenerator:
             生成结果报告
         """
         try:
-            print("🚀 开始生成元数据文件...")
+            self.logger.info("🚀 开始生成元数据文件...")
             
             # 1. 验证文件数量
-            print("📋 验证文件数量...")
+            self.logger.info("📋 验证文件数量...")
             validation_result = self.validator.validate(self.table_list_file, str(self.output_dir))
             
             if not validation_result.is_valid:
-                print(f"❌ 文件验证失败: {validation_result.error}")
+                self.logger.error(f"❌ 文件验证失败: {validation_result.error}")
                 if validation_result.missing_ddl:
-                    print(f"缺失DDL文件: {validation_result.missing_ddl}")
+                    self.logger.error(f"缺失DDL文件: {validation_result.missing_ddl}")
                 if validation_result.missing_md:
-                    print(f"缺失MD文件: {validation_result.missing_md}")
+                    self.logger.error(f"缺失MD文件: {validation_result.missing_md}")
                 raise ValueError(f"文件验证失败: {validation_result.error}")
             
-            print(f"✅ 文件验证通过: {validation_result.table_count}个表")
+            self.logger.info(f"✅ 文件验证通过: {validation_result.table_count}个表")
             
             # 2. 读取所有MD文件内容
-            print("📖 读取MD文件...")
+            self.logger.info("📖 读取MD文件...")
             md_contents = await self.md_analyzer.read_all_md_files()
             
             # 3. 初始化LLM相关组件
             self._initialize_llm_components()
             
             # 4. 提取分析主题
-            print("🎯 提取分析主题...")
+            self.logger.info("🎯 提取分析主题...")
             themes = await self.theme_extractor.extract_themes(md_contents)
-            print(f"✅ 成功提取 {len(themes)} 个分析主题")
+            self.logger.info(f"✅ 成功提取 {len(themes)} 个分析主题")
             
 
             for i, theme in enumerate(themes):
                 topic_name = theme.get('topic_name', theme.get('name', ''))
                 description = theme.get('description', '')
-                print(f"  {i+1}. {topic_name}: {description}")
+                self.logger.info(f"  {i+1}. {topic_name}: {description}")
             
             # 5. 生成metadata.txt文件
-            print("📝 生成metadata.txt...")
+            self.logger.info("📝 生成metadata.txt...")
             metadata_file = await self._generate_metadata_file(themes)
             
             # 6. 生成metadata_detail.md文件
-            print("📝 生成metadata_detail.md...")
+            self.logger.info("📝 生成metadata_detail.md...")
             metadata_md_file = await self._generate_metadata_md_file(themes)
             
             # 7. 生成db_query_decision_prompt.txt文件
-            print("📝 生成db_query_decision_prompt.txt...")
+            self.logger.info("📝 生成db_query_decision_prompt.txt...")
             decision_prompt_file = await self._generate_decision_prompt_file(themes, md_contents)
             
             # 8. 生成报告
@@ -119,13 +123,13 @@ class MetadataOnlyGenerator:
             return report
             
         except Exception as e:
-            print(f"❌ 元数据生成失败: {e}")
+            self.logger.error(f"❌ 元数据生成失败: {e}")
             raise
     
     def _initialize_llm_components(self):
         """初始化LLM相关组件"""
         if not self.vn:
-            print("🤖 初始化LLM组件...")
+            self.logger.info("🤖 初始化LLM组件...")
             self.vn = create_vanna_instance()
             self.theme_extractor = ThemeExtractor(self.vn, self.business_context)
     
@@ -188,11 +192,11 @@ class MetadataOnlyGenerator:
                     f.write(f"  '{metrics_str}'\n")
                     f.write(");\n\n")
             
-            print(f"✅ metadata.txt文件已生成: {metadata_file}")
+            self.logger.info(f"✅ metadata.txt文件已生成: {metadata_file}")
             return metadata_file
             
         except Exception as e:
-            print(f"❌ 生成metadata.txt文件失败: {e}")
+            self.logger.error(f"❌ 生成metadata.txt文件失败: {e}")
             return None
     
     async def _generate_metadata_md_file(self, themes: List[Dict]):
@@ -240,11 +244,11 @@ class MetadataOnlyGenerator:
                 f.write("- `biz_entities` 表示主题关注的核心对象,例如服务区、车辆、公司;\n")
                 f.write("- `biz_metrics` 表示该主题关注的业务分析指标,例如营收对比、趋势变化、占比结构等。\n")
             
-            print(f"✅ metadata_detail.md文件已生成: {metadata_md_file}")
+            self.logger.info(f"✅ metadata_detail.md文件已生成: {metadata_md_file}")
             return metadata_md_file
             
         except Exception as e:
-            print(f"❌ 生成metadata_detail.md文件失败: {e}")
+            self.logger.error(f"❌ 生成metadata_detail.md文件失败: {e}")
             return None
     
     async def _generate_decision_prompt_file(self, themes: List[Dict], md_contents: str):
@@ -259,20 +263,20 @@ class MetadataOnlyGenerator:
             with open(decision_prompt_file, 'w', encoding='utf-8') as f:
                 f.write(decision_content)
             
-            print(f"✅ db_query_decision_prompt.txt文件已生成: {decision_prompt_file}")
+            self.logger.info(f"✅ db_query_decision_prompt.txt文件已生成: {decision_prompt_file}")
             return decision_prompt_file
             
         except Exception as e:
-            print(f"❌ 生成db_query_decision_prompt.txt文件失败: {e}")
+            self.logger.error(f"❌ 生成db_query_decision_prompt.txt文件失败: {e}")
             # 如果LLM调用失败,使用回退方案
             try:
                 fallback_content = await self._generate_fallback_decision_content(themes)
                 with open(decision_prompt_file, 'w', encoding='utf-8') as f:
                     f.write(fallback_content)
-                print(f"⚠️ 使用回退方案生成了 {decision_prompt_file}")
+                self.logger.warning(f"⚠️ 使用回退方案生成了 {decision_prompt_file}")
                 return decision_prompt_file
             except Exception as fallback_error:
-                print(f"❌ 回退方案也失败: {fallback_error}")
+                self.logger.error(f"❌ 回退方案也失败: {fallback_error}")
                 return None
     
     async def _generate_decision_prompt_with_llm(self, themes: List[Dict], md_contents: str) -> str:
@@ -326,7 +330,7 @@ class MetadataOnlyGenerator:
             return response.strip()
             
         except Exception as e:
-            print(f"❌ LLM生成决策提示内容失败: {e}")
+            self.logger.error(f"❌ LLM生成决策提示内容失败: {e}")
             # 回退方案:生成基础内容
             return await self._generate_fallback_decision_content(themes)
     
@@ -370,7 +374,7 @@ class MetadataOnlyGenerator:
                 raise Exception("LLM返回内容不合理")
                 
         except Exception as e:
-            print(f"⚠️ 简化LLM调用也失败,使用完全兜底方案: {e}")
+            self.logger.warning(f"⚠️ 简化LLM调用也失败,使用完全兜底方案: {e}")
             # 真正的最后兜底
             content += f"当前数据库存储的是{self.business_context}的相关数据,主要涉及相关业务数据,包含以下业务数据:\n"
         
@@ -409,13 +413,13 @@ class MetadataOnlyGenerator:
     
     def _print_summary(self, report: Dict):
         """打印总结信息"""
-        print("=" * 60)
-        print("📊 元数据生成总结")
-        print(f"  ✅ 分析主题数: {report['total_themes']}")
-        print(f"  📄 metadata.txt: {'✅ 已生成' if report['metadata_file'] else '❌ 生成失败'}")
-        print(f"  📄 metadata_detail.md: {'✅ 已生成' if report['metadata_md_file'] else '❌ 生成失败'}")
-        print(f"  📄 db_query_decision_prompt.txt: {'✅ 已生成' if report['decision_prompt_file'] else '❌ 生成失败'}")
-        print("=" * 60)
+        self.logger.info("=" * 60)
+        self.logger.info("📊 元数据生成总结")
+        self.logger.info(f"  ✅ 分析主题数: {report['total_themes']}")
+        self.logger.info(f"  📄 metadata.txt: {'✅ 已生成' if report['metadata_file'] else '❌ 生成失败'}")
+        self.logger.info(f"  📄 metadata_detail.md: {'✅ 已生成' if report['metadata_md_file'] else '❌ 生成失败'}")
+        self.logger.info(f"  📄 db_query_decision_prompt.txt: {'✅ 已生成' if report['decision_prompt_file'] else '❌ 生成失败'}")
+        self.logger.info("=" * 60)
 
 
 def setup_argument_parser():
@@ -488,12 +492,15 @@ async def main():
     
     # 验证参数
     output_path = Path(args.output_dir)
+    # 初始化logger用于参数验证
+    logger = get_data_pipeline_logger("MetadataGeneratorMain")
+    
     if not output_path.exists():
-        print(f"错误: 输出目录不存在: {args.output_dir}")
+        logger.error(f"错误: 输出目录不存在: {args.output_dir}")
         sys.exit(1)
     
     if not os.path.exists(args.table_list):
-        print(f"错误: 表清单文件不存在: {args.table_list}")
+        logger.error(f"错误: 表清单文件不存在: {args.table_list}")
         sys.exit(1)
     
     try:
@@ -510,19 +517,19 @@ async def main():
         
         # 输出结果
         if report['success']:
-            print("\n🎉 元数据文件生成成功!")
+            logger.info("\n🎉 元数据文件生成成功!")
             exit_code = 0
         else:
-            print("\n❌ 元数据文件生成失败")
+            logger.error("\n❌ 元数据文件生成失败")
             exit_code = 1
         
         sys.exit(exit_code)
         
     except KeyboardInterrupt:
-        print("\n\n⏹️  用户中断,程序退出")
+        logger.info("\n\n⏹️  用户中断,程序退出")
         sys.exit(130)
     except Exception as e:
-        print(f"\n❌ 程序执行失败: {e}")
+        logger.error(f"\n❌ 程序执行失败: {e}")
         if args.verbose:
             import traceback
             traceback.print_exc()

+ 2 - 2
data_pipeline/qa_generation/qs_agent.py

@@ -9,7 +9,7 @@ from typing import List, Dict, Any, Optional
 from data_pipeline.config import SCHEMA_TOOLS_CONFIG
 from data_pipeline.validators import FileCountValidator
 from data_pipeline.analyzers import MDFileAnalyzer, ThemeExtractor
-from data_pipeline.utils.logger import setup_logging
+from core.logging import get_data_pipeline_logger
 from core.vanna_llm_factory import create_vanna_instance
 
 
@@ -36,7 +36,7 @@ class QuestionSQLGenerationAgent:
         self.db_name = db_name or "db"
         
         self.config = SCHEMA_TOOLS_CONFIG
-        self.logger = logging.getLogger("schema_tools.QSAgent")
+        self.logger = get_data_pipeline_logger("QSAgent")
         
         # 初始化组件
         self.validator = FileCountValidator()

+ 20 - 18
data_pipeline/schema_workflow.py

@@ -14,7 +14,7 @@ from data_pipeline.ddl_generation.training_data_agent import SchemaTrainingDataA
 from data_pipeline.qa_generation.qs_agent import QuestionSQLGenerationAgent
 from data_pipeline.validators.sql_validation_agent import SQLValidationAgent
 from data_pipeline.config import SCHEMA_TOOLS_CONFIG
-from data_pipeline.utils.logger import setup_logging
+from core.logging import get_data_pipeline_logger
 
 
 class SchemaWorkflowOrchestrator:
@@ -56,7 +56,7 @@ class SchemaWorkflowOrchestrator:
         self.output_dir.mkdir(parents=True, exist_ok=True)
         
         # 初始化日志
-        self.logger = logging.getLogger("schema_tools.SchemaWorkflowOrchestrator")
+        self.logger = get_data_pipeline_logger("SchemaWorkflow")
         
         # 工作流程状态
         self.workflow_state = {
@@ -645,7 +645,8 @@ async def main():
     
     # 验证输入文件
     if not os.path.exists(args.table_list):
-        print(f"错误: 表清单文件不存在: {args.table_list}")
+        logger = get_data_pipeline_logger("SchemaWorkflow")
+        logger.error(f"错误: 表清单文件不存在: {args.table_list}")
         sys.exit(1)
     
     try:
@@ -661,15 +662,16 @@ async def main():
             enable_training_data_load=not args.skip_training_load
         )
         
-        # 显示启动信息
-        print(f"🚀 开始执行Schema工作流编排...")
-        print(f"📁 输出目录: {args.output_dir}")
-        print(f"📋 表清单: {args.table_list}")
-        print(f"🏢 业务背景: {args.business_context}")
-        print(f"💾 数据库: {orchestrator.db_name}")
-        print(f"🔍 SQL验证: {'启用' if not args.skip_validation else '禁用'}")
-        print(f"🔧 LLM修复: {'启用' if not args.disable_llm_repair else '禁用'}")
-        print(f"🎯 训练数据加载: {'启用' if not args.skip_training_load else '禁用'}")
+        # 获取logger用于启动信息
+        logger = get_data_pipeline_logger("SchemaWorkflow")
+        logger.info(f"🚀 开始执行Schema工作流编排...")
+        logger.info(f"📁 输出目录: {args.output_dir}")
+        logger.info(f"📋 表清单: {args.table_list}")
+        logger.info(f"🏢 业务背景: {args.business_context}")
+        logger.info(f"💾 数据库: {orchestrator.db_name}")
+        logger.info(f"🔍 SQL验证: {'启用' if not args.skip_validation else '禁用'}")
+        logger.info(f"🔧 LLM修复: {'启用' if not args.disable_llm_repair else '禁用'}")
+        logger.info(f"🎯 训练数据加载: {'启用' if not args.skip_training_load else '禁用'}")
         
         # 执行完整工作流程
         report = await orchestrator.execute_complete_workflow()
@@ -680,23 +682,23 @@ async def main():
         # 输出结果并设置退出码
         if report["success"]:
             if report["processing_results"].get("sql_validation", {}).get("success_rate", 1.0) >= 0.8:
-                print(f"\n🎉 工作流程执行成功!")
+                logger.info(f"\n🎉 工作流程执行成功!")
                 exit_code = 0  # 完全成功
             else:
-                print(f"\n⚠️  工作流程执行完成,但SQL验证成功率较低")
+                logger.warning(f"\n⚠️  工作流程执行完成,但SQL验证成功率较低")
                 exit_code = 1  # 部分成功
         else:
-            print(f"\n❌ 工作流程执行失败")
+            logger.error(f"\n❌ 工作流程执行失败")
             exit_code = 2  # 失败
         
-        print(f"📄 主要输出文件: {report['final_outputs']['primary_output_file']}")
+        logger.info(f"📄 主要输出文件: {report['final_outputs']['primary_output_file']}")
         sys.exit(exit_code)
         
     except KeyboardInterrupt:
-        print("\n\n⏹️  用户中断,程序退出")
+        logger.info("\n\n⏹️  用户中断,程序退出")
         sys.exit(130)
     except Exception as e:
-        print(f"\n❌ 程序执行失败: {e}")
+        logger.error(f"\n❌ 程序执行失败: {e}")
         if args.verbose:
             import traceback
             traceback.print_exc()

+ 7 - 5
data_pipeline/tools/base.py

@@ -1,7 +1,7 @@
 import asyncio
 import time
-import logging
 from abc import ABC, abstractmethod
+from core.logging import get_data_pipeline_logger
 from typing import Dict, Any, Optional, Type, List
 from data_pipeline.utils.data_structures import ProcessingResult, TableProcessingContext
 
@@ -15,7 +15,8 @@ class ToolRegistry:
         """装饰器:注册工具"""
         def decorator(tool_class: Type['BaseTool']):
             cls._tools[name] = tool_class
-            logging.debug(f"注册工具: {name} -> {tool_class.__name__}")
+            logger = get_data_pipeline_logger("ToolRegistry")
+            logger.debug(f"注册工具: {name} -> {tool_class.__name__}")
             return tool_class
         return decorator
     
@@ -32,7 +33,8 @@ class ToolRegistry:
             if hasattr(tool_class, 'needs_llm') and tool_class.needs_llm:
                 from core.vanna_llm_factory import create_vanna_instance
                 kwargs['vn'] = create_vanna_instance()
-                logging.debug(f"为工具 {name} 注入LLM实例")
+                logger = get_data_pipeline_logger("ToolRegistry")
+                logger.debug(f"为工具 {name} 注入LLM实例")
             
             cls._instances[name] = tool_class(**kwargs)
         
@@ -55,7 +57,7 @@ class BaseTool(ABC):
     tool_name: str = ""      # 工具名称
     
     def __init__(self, **kwargs):
-        self.logger = logging.getLogger(f"schema_tools.{self.__class__.__name__}")
+        self.logger = get_data_pipeline_logger(f"tools.{self.__class__.__name__}")
         
         # 如果工具需要LLM,检查是否已注入
         if self.needs_llm and 'vn' not in kwargs:
@@ -113,7 +115,7 @@ class PipelineExecutor:
     
     def __init__(self, pipeline_config: Dict[str, List[str]]):
         self.pipeline_config = pipeline_config
-        self.logger = logging.getLogger("schema_tools.PipelineExecutor")
+        self.logger = get_data_pipeline_logger("tools.PipelineExecutor")
     
     async def execute_pipeline(self, pipeline_name: str, context: TableProcessingContext) -> Dict[str, ProcessingResult]:
         """执行指定的处理链"""

+ 35 - 31
data_pipeline/trainer/vanna_trainer.py

@@ -11,6 +11,10 @@ import sys
 import os
 sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
 import app_config
+from core.logging import get_data_pipeline_logger
+
+# 初始化日志
+logger = get_data_pipeline_logger("VannaTrainer")
 
 # 设置正确的项目根目录路径
 project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -27,20 +31,20 @@ try:
     embedding_config = get_current_embedding_config()
     model_info = get_current_model_info()
     
-    print(f"\n===== Embedding模型信息 =====")
-    print(f"模型类型: {model_info['embedding_type']}")
-    print(f"模型名称: {model_info['embedding_model']}")
-    print(f"向量维度: {embedding_config.get('embedding_dimension', '未知')}")
+    logger.info("===== Embedding模型信息 =====")
+    logger.info(f"模型类型: {model_info['embedding_type']}")
+    logger.info(f"模型名称: {model_info['embedding_model']}")
+    logger.info(f"向量维度: {embedding_config.get('embedding_dimension', '未知')}")
     if 'base_url' in embedding_config:
-        print(f"API服务: {embedding_config['base_url']}")
-    print("==============================")
+        logger.info(f"API服务: {embedding_config['base_url']}")
+    logger.info("==============================")
 except ImportError as e:
-    print(f"警告: 无法导入配置工具函数: {e}")
-    print("使用默认配置...")
+    logger.warning(f"无法导入配置工具函数: {e}")
+    logger.info("使用默认配置...")
     embedding_config = getattr(app_config, 'API_EMBEDDING_CONFIG', {})
-    print(f"\n===== Embedding模型信息 (默认) =====")
-    print(f"模型名称: {embedding_config.get('model_name', '未知')}")
-    print("==============================")
+    logger.info("===== Embedding模型信息 (默认) =====")
+    logger.info(f"模型名称: {embedding_config.get('model_name', '未知')}")
+    logger.info("==============================")
 
 # 从app_config获取训练批处理配置
 BATCH_PROCESSING_ENABLED = app_config.TRAINING_BATCH_PROCESSING_ENABLED
@@ -63,7 +67,7 @@ class BatchProcessor:
         # 是否启用批处理
         self.batch_enabled = BATCH_PROCESSING_ENABLED       
 
-        print(f"[DEBUG] 训练批处理器初始化: 启用={self.batch_enabled}, 批大小={self.batch_size}, 最大工作线程={self.max_workers}")
+        logger.debug(f"训练批处理器初始化: 启用={self.batch_enabled}, 批大小={self.batch_size}, 最大工作线程={self.max_workers}")
     
     def add_item(self, batch_type: str, item: Dict[str, Any]):
         """添加一个项目到批处理队列"""
@@ -91,14 +95,14 @@ class BatchProcessor:
             elif batch_type == 'question_sql':
                 vn.train(question=item['question'], sql=item['sql'])
             
-            print(f"[DEBUG] 单项处理成功: {batch_type}")
+            logger.debug(f"单项处理成功: {batch_type}")
                 
         except Exception as e:
-            print(f"[ERROR] 处理 {batch_type} 项目失败: {e}")
+            logger.error(f"处理 {batch_type} 项目失败: {e}")
     
     def _process_batch(self, batch_type: str, items: List[Dict[str, Any]]):
         """处理一批项目"""
-        print(f"[INFO] 开始批量处理 {len(items)} 个 {batch_type} 项")
+        logger.info(f"开始批量处理 {len(items)} 个 {batch_type} 项")
         start_time = time.time()
         
         try:
@@ -131,46 +135,46 @@ class BatchProcessor:
             if hasattr(vn, 'add_batch') and callable(getattr(vn, 'add_batch')):
                 success = vn.add_batch(batch_data)
                 if success:
-                    print(f"[INFO] 批量处理成功: {len(items)} 个 {batch_type} 项")
+                    logger.info(f"批量处理成功: {len(items)} 个 {batch_type} 项")
                 else:
-                    print(f"[WARNING] 批量处理部分失败: {batch_type}")
+                    logger.warning(f"批量处理部分失败: {batch_type}")
             else:
                 # 如果没有批处理方法,退回到逐条处理
-                print(f"[WARNING] 批处理不可用,使用逐条处理: {batch_type}")
+                logger.warning(f"批处理不可用,使用逐条处理: {batch_type}")
                 for item in items:
                     self._process_single_item(batch_type, item)
                 
         except Exception as e:
-            print(f"[ERROR] 批处理 {batch_type} 失败: {e}")
+            logger.error(f"批处理 {batch_type} 失败: {e}")
             # 如果批处理失败,尝试逐条处理
-            print(f"[INFO] 尝试逐条处理...")
+            logger.info(f"尝试逐条处理...")
             for item in items:
                 try:
                     self._process_single_item(batch_type, item)
                 except Exception as item_e:
-                    print(f"[ERROR] 处理项目失败: {item_e}")
+                    logger.error(f"处理项目失败: {item_e}")
         
         elapsed = time.time() - start_time
-        print(f"[INFO] 批处理完成 {len(items)} 个 {batch_type} 项,耗时 {elapsed:.2f} 秒")
+        logger.info(f"批处理完成 {len(items)} 个 {batch_type} 项,耗时 {elapsed:.2f} 秒")
     
     def flush_all(self):
         """强制处理所有剩余项目"""
         with self.lock:
             for batch_type, items in self.batches.items():
                 if items:
-                    print(f"[INFO] 正在处理剩余的 {len(items)} 个 {batch_type} 项")
+                    logger.info(f"正在处理剩余的 {len(items)} 个 {batch_type} 项")
                     self._process_batch(batch_type, items)
             
             # 清空队列
             self.batches = defaultdict(list)
         
-        print("[INFO] 所有训练批处理项目已完成")
+        logger.info("所有训练批处理项目已完成")
     
     def shutdown(self):
         """关闭处理器和线程池"""
         self.flush_all()
         self.executor.shutdown(wait=True)
-        print("[INFO] 训练批处理器已关闭")
+        logger.info("训练批处理器已关闭")
 
 # 创建全局训练批处理器实例
 # 用于所有训练函数的批处理优化
@@ -178,16 +182,16 @@ batch_processor = BatchProcessor()
 
 # 原始训练函数的批处理增强版本
 def train_ddl(ddl_sql: str):
-    print(f"[DDL] Training on DDL:\n{ddl_sql}")
+    logger.debug(f"Training on DDL:\n{ddl_sql}")
     batch_processor.add_item('ddl', {'ddl': ddl_sql})
 
 def train_documentation(doc: str):
-    print(f"[DOC] Training on documentation:\n{doc}")
+    logger.debug(f"Training on documentation:\n{doc}")
     batch_processor.add_item('documentation', {'documentation': doc})
 
 def train_sql_example(sql: str):
     """训练单个SQL示例,通过SQL生成相应的问题"""
-    print(f"[SQL] Training on SQL:\n{sql}")
+    logger.debug(f"Training on SQL:\n{sql}")
     
     try:
         # 直接调用generate_question方法
@@ -198,15 +202,15 @@ def train_sql_example(sql: str):
             question += "?"
             
     except Exception as e:
-        print(f"[ERROR] 生成问题时出错: {e}")
+        logger.error(f"生成问题时出错: {e}")
         raise Exception(f"无法为SQL生成问题: {e}")
         
-    print(f"[SQL] 生成问题: {question}")
+    logger.debug(f"生成问题: {question}")
     # 使用标准方式存储问题-SQL对
     batch_processor.add_item('question_sql', {'question': question, 'sql': sql})
 
 def train_question_sql_pair(question: str, sql: str):
-    print(f"[Q-S] Training on:\nquestion: {question}\nsql: {sql}")
+    logger.debug(f"Training on question-sql pair:\nquestion: {question}\nsql: {sql}")
     batch_processor.add_item('question_sql', {'question': question, 'sql': sql})
 
 # 完成训练后刷新所有待处理项

+ 2 - 2
data_pipeline/utils/file_manager.py

@@ -1,7 +1,7 @@
 import os
-import logging
 from typing import Dict, Set, Optional
 from pathlib import Path
+from core.logging import get_data_pipeline_logger
 
 class FileNameManager:
     """文件名管理器,处理文件命名和冲突"""
@@ -10,7 +10,7 @@ class FileNameManager:
         self.output_dir = output_dir
         self.used_names: Set[str] = set()
         self.name_mapping: Dict[str, str] = {}  # 原始名 -> 实际文件名
-        self.logger = logging.getLogger("schema_tools.FileNameManager")
+        self.logger = get_data_pipeline_logger("FileNameManager")
         
         # 扫描已存在的文件
         self._scan_existing_files()

+ 2 - 2
data_pipeline/utils/large_table_handler.py

@@ -1,13 +1,13 @@
-import logging
 import random
 from typing import List, Dict, Any, Optional
 from data_pipeline.config import SCHEMA_TOOLS_CONFIG
+from core.logging import get_data_pipeline_logger
 
 class LargeTableHandler:
     """大表处理策略"""
     
     def __init__(self):
-        self.logger = logging.getLogger("schema_tools.LargeTableHandler")
+        self.logger = get_data_pipeline_logger("LargeTableHandler")
         self.large_table_threshold = SCHEMA_TOOLS_CONFIG.get("large_table_threshold", 1000000)
         self.skip_large_tables = SCHEMA_TOOLS_CONFIG.get("skip_large_tables", False)
         self.max_table_size = SCHEMA_TOOLS_CONFIG.get("max_table_size", 10000000)

+ 31 - 132
data_pipeline/utils/logger.py

@@ -1,160 +1,50 @@
-import logging
-import os
-import sys
-from datetime import datetime
+"""
+原有日志系统已被新的统一日志系统替代
+保留此文件仅为避免导入错误
+"""
+from core.logging import get_data_pipeline_logger
 from typing import Optional
+import logging
 
 def setup_logging(verbose: bool = False, log_file: Optional[str] = None, log_dir: Optional[str] = None):
     """
-    设置日志系统
-    
-    Args:
-        verbose: 是否启用详细日志
-        log_file: 日志文件名
-        log_dir: 日志目录
+    此函数为空操作(no-op),仅为保持向后兼容而保留,不再配置任何日志
+    日志配置现统一由 core.logging 的日志系统负责
     """
-    # 确定日志级别
-    log_level = logging.DEBUG if verbose else logging.INFO
-    
-    # 创建根logger
-    root_logger = logging.getLogger()
-    root_logger.setLevel(log_level)
-    
-    # 清除已有的处理器
-    root_logger.handlers.clear()
-    
-    # 设置日志格式
-    console_format = "%(asctime)s [%(levelname)s] %(message)s"
-    file_format = "%(asctime)s [%(levelname)s] [%(name)s] %(message)s"
-    date_format = "%Y-%m-%d %H:%M:%S"
-    
-    # 控制台处理器
-    console_handler = logging.StreamHandler(sys.stdout)
-    console_handler.setLevel(log_level)
-    console_formatter = logging.Formatter(console_format, datefmt=date_format)
-    console_handler.setFormatter(console_formatter)
-    root_logger.addHandler(console_handler)
-    
-    # 文件处理器(如果指定)
-    if log_file:
-        # 确定日志文件路径
-        if log_dir:
-            os.makedirs(log_dir, exist_ok=True)
-            log_path = os.path.join(log_dir, log_file)
-        else:
-            log_path = log_file
-        
-        # 添加时间戳到日志文件名
-        base_name, ext = os.path.splitext(log_path)
-        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-        log_path = f"{base_name}_{timestamp}{ext}"
-        
-        file_handler = logging.FileHandler(log_path, encoding='utf-8')
-        file_handler.setLevel(log_level)
-        file_formatter = logging.Formatter(file_format, datefmt=date_format)
-        file_handler.setFormatter(file_formatter)
-        root_logger.addHandler(file_handler)
-        
-        # 记录日志文件位置
-        root_logger.info(f"日志文件: {os.path.abspath(log_path)}")
-    
-    # 设置schema_tools模块的日志级别
-    schema_tools_logger = logging.getLogger("schema_tools")
-    schema_tools_logger.setLevel(log_level)
-    
-    # 设置第三方库的日志级别(避免过多输出)
-    logging.getLogger("asyncio").setLevel(logging.WARNING)
-    logging.getLogger("asyncpg").setLevel(logging.WARNING)
-    logging.getLogger("openai").setLevel(logging.WARNING)
-    logging.getLogger("httpx").setLevel(logging.WARNING)
-    logging.getLogger("urllib3").setLevel(logging.WARNING)
-    
-    # 返回schema_tools的logger
-    return schema_tools_logger
+    pass
 
-class ColoredFormatter(logging.Formatter):
-    """带颜色的日志格式化器(用于控制台)"""
-    
-    # ANSI颜色代码
-    COLORS = {
-        'DEBUG': '\033[36m',     # 青色
-        'INFO': '\033[32m',      # 绿色
-        'WARNING': '\033[33m',   # 黄色
-        'ERROR': '\033[31m',     # 红色
-        'CRITICAL': '\033[35m',  # 紫色
-    }
-    RESET = '\033[0m'
-    
-    def format(self, record):
-        # 保存原始级别名
-        levelname = record.levelname
-        
-        # 添加颜色
-        if levelname in self.COLORS:
-            record.levelname = f"{self.COLORS[levelname]}{levelname}{self.RESET}"
-        
-        # 格式化消息
-        formatted = super().format(record)
-        
-        # 恢复原始级别名
-        record.levelname = levelname
-        
-        return formatted
+def get_logger(name: str = "DataPipeline"):
+    """兼容性入口:直接转发到统一日志系统,返回 get_data_pipeline_logger(name)"""
+    return get_data_pipeline_logger(name)
 
 def get_colored_console_handler(level=logging.INFO):
-    """获取带颜色的控制台处理器"""
-    handler = logging.StreamHandler(sys.stdout)
-    handler.setLevel(level)
-    
-    # 检查是否支持颜色(Windows需要特殊处理)
-    if sys.platform == "win32":
-        try:
-            import colorama
-            colorama.init()
-            use_color = True
-        except ImportError:
-            use_color = False
-    else:
-        # Unix/Linux/Mac通常支持ANSI颜色
-        use_color = hasattr(sys.stdout, 'isatty') and sys.stdout.isatty()
-    
-    if use_color:
-        formatter = ColoredFormatter(
-            "%(asctime)s [%(levelname)s] %(message)s",
-            datefmt="%Y-%m-%d %H:%M:%S"
-        )
-    else:
-        formatter = logging.Formatter(
-            "%(asctime)s [%(levelname)s] %(message)s",
-            datefmt="%Y-%m-%d %H:%M:%S"
-        )
-    
-    handler.setFormatter(formatter)
-    return handler
+    """兼容性函数,返回None"""
+    return None
 
 class TableProcessingLogger:
-    """表处理专用日志器"""
+    """兼容性类,实际使用新的日志系统"""
     
     def __init__(self, logger_name: str = "schema_tools.TableProcessor"):
-        self.logger = logging.getLogger(logger_name)
+        self.logger = get_data_pipeline_logger("TableProcessor")
         self.current_table = None
         self.start_time = None
     
     def start_table(self, table_name: str):
         """开始处理表"""
+        import time
         self.current_table = table_name
-        self.start_time = datetime.now()
+        self.start_time = time.time()
         self.logger.info(f"{'='*60}")
         self.logger.info(f"开始处理表: {table_name}")
-        self.logger.info(f"开始时间: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")
     
     def end_table(self, success: bool = True):
         """结束处理表"""
         if self.start_time:
-            duration = (datetime.now() - self.start_time).total_seconds()
+            import time
+            duration = time.time() - self.start_time
             status = "成功" if success else "失败"
             self.logger.info(f"处理{status},耗时: {duration:.2f}秒")
-        self.logger.info(f"{'='*60}\n")
+        self.logger.info(f"{'='*60}")
         self.current_table = None
         self.start_time = None
     
@@ -171,4 +61,13 @@ class TableProcessingLogger:
     
     def log_error(self, message: str):
         """记录错误"""
-        self.logger.error(f"  ✗ {message}")
+        self.logger.error(f"  ✗ {message}")
+
+# 兼容性类
+class ColoredFormatter:
+    """兼容性类,不再使用"""
+    def __init__(self, *args, **kwargs):
+        pass
+    
+    def format(self, record):
+        return str(record)

+ 2 - 2
data_pipeline/utils/permission_checker.py

@@ -1,13 +1,13 @@
-import logging
 from typing import Dict, Optional
 import asyncio
+from core.logging import get_data_pipeline_logger
 
 class DatabasePermissionChecker:
     """数据库权限检查器"""
     
     def __init__(self, db_inspector):
         self.db_inspector = db_inspector
-        self.logger = logging.getLogger("schema_tools.DatabasePermissionChecker")
+        self.logger = get_data_pipeline_logger("DatabasePermissionChecker")
         self._permission_cache: Optional[Dict[str, bool]] = None
     
     async def check_permissions(self) -> Dict[str, bool]:

+ 2 - 2
data_pipeline/utils/system_filter.py

@@ -1,6 +1,6 @@
-import logging
 from typing import List, Set
 from data_pipeline.config import SCHEMA_TOOLS_CONFIG
+from core.logging import get_data_pipeline_logger
 
 class SystemTableFilter:
     """系统表过滤器"""
@@ -18,7 +18,7 @@ class SystemTableFilter:
     ]
     
     def __init__(self):
-        self.logger = logging.getLogger("schema_tools.SystemTableFilter")
+        self.logger = get_data_pipeline_logger("SystemTableFilter")
         
         # 加载自定义配置
         self.custom_prefixes = SCHEMA_TOOLS_CONFIG.get("custom_system_prefixes", [])

+ 2 - 2
data_pipeline/utils/table_parser.py

@@ -1,12 +1,12 @@
 import os
-import logging
 from typing import List, Tuple
+from core.logging import get_data_pipeline_logger
 
 class TableListParser:
     """表清单解析器"""
     
     def __init__(self):
-        self.logger = logging.getLogger("schema_tools.TableListParser")
+        self.logger = get_data_pipeline_logger("TableListParser")
     
     def parse_file(self, file_path: str) -> List[str]:
         """

+ 2 - 2
data_pipeline/validators/file_count_validator.py

@@ -1,10 +1,10 @@
-import logging
 from pathlib import Path
 from typing import Dict, List, Tuple, Set
 from dataclasses import dataclass, field
 
 from data_pipeline.utils.table_parser import TableListParser
 from data_pipeline.config import SCHEMA_TOOLS_CONFIG
+from core.logging import get_data_pipeline_logger
 
 
 @dataclass
@@ -24,7 +24,7 @@ class FileCountValidator:
     """文件数量验证器"""
     
     def __init__(self):
-        self.logger = logging.getLogger("schema_tools.FileCountValidator")
+        self.logger = get_data_pipeline_logger("FileCountValidator")
         self.config = SCHEMA_TOOLS_CONFIG
         
     def validate(self, table_list_file: str, output_dir: str) -> ValidationResult:

+ 2 - 2
data_pipeline/validators/sql_validation_agent.py

@@ -8,7 +8,7 @@ from typing import List, Dict, Any, Optional
 
 from data_pipeline.config import SCHEMA_TOOLS_CONFIG
 from data_pipeline.validators import SQLValidator, SQLValidationResult, ValidationStats
-from data_pipeline.utils.logger import setup_logging
+from core.logging import get_data_pipeline_logger
 
 
 class SQLValidationAgent:
@@ -40,7 +40,7 @@ class SQLValidationAgent:
             self.config['enable_sql_repair'] = enable_sql_repair
         if modify_original_file is not None:
             self.config['modify_original_file'] = modify_original_file
-        self.logger = logging.getLogger("schema_tools.SQLValidationAgent")
+        self.logger = get_data_pipeline_logger("SQLValidationAgent")
         
         # 初始化验证器
         self.validator = SQLValidator(db_connection)

+ 2 - 2
data_pipeline/validators/sql_validator.py

@@ -1,10 +1,10 @@
 import asyncio
-import logging
 import time
 from typing import Dict, Any, List, Optional
 from dataclasses import dataclass, field
 
 from data_pipeline.config import SCHEMA_TOOLS_CONFIG
+from core.logging import get_data_pipeline_logger
 
 
 @dataclass
@@ -52,7 +52,7 @@ class SQLValidator:
         self.db_connection = db_connection
         self.connection_pool = None
         self.config = SCHEMA_TOOLS_CONFIG['sql_validation']
-        self.logger = logging.getLogger("schema_tools.SQLValidator")
+        self.logger = get_data_pipeline_logger("SQLValidator")
         
     async def _get_connection_pool(self):
         """获取或复用现有连接池"""

+ 0 - 894
docs/全局log服务改造方案.md

@@ -1,894 +0,0 @@
-# 项目日志系统改造设计方案(精简实用版)
-
-## 1. 整体设计理念
-
-基于您的需求,设计一套类似Log4j的统一日志服务,专注核心功能:
-- 统一的日志级别管理(info/error/debug/warning)
-- 可配置的日志输出路径
-- 支持控制台和文件输出
-- 不同模块独立日志文件(data_pipeline、agent、vanna等)
-- 自动日志轮转和清理
-- 与现有vanna/langchain/langgraph技术栈兼容
-
-## 2. 核心架构设计
-
-### 2.1 精简的日志服务层次结构
-
-```
-项目根目录/
-├── core/
-│   └── logging/
-│       ├── __init__.py           # 日志服务入口
-│       └── log_manager.py        # 核心日志管理器
-├── logs/                         # 日志文件目录
-│   ├── data_pipeline.log        # data_pipeline模块日志
-│   ├── agent.log                # agent模块日志
-│   ├── vanna.log                # vanna模块日志
-│   ├── langchain.log            # langchain模块日志
-│   ├── langgraph.log            # langgraph模块日志
-│   └── app.log                  # 主应用日志
-└── config/
-    └── logging_config.yaml       # 日志配置文件
-```
-
-### 2.2 核心日志管理器设计(增强版)
-
-基于用户反馈,增强版包含以下特性:
-- **异步日志支持**
-- **灵活的上下文管理**(user_id可选)
-- **错误降级策略**
-- **重点支持citu_app.py**
-
-```python
-# core/logging/log_manager.py
-import logging
-import logging.handlers
-import os
-from typing import Dict, Optional
-from pathlib import Path
-import yaml
-import asyncio
-from concurrent.futures import ThreadPoolExecutor
-import contextvars
-
-# 上下文变量,存储可选的上下文信息
-log_context = contextvars.ContextVar('log_context', default={})
-
-class ContextFilter(logging.Filter):
-    """添加上下文信息到日志记录"""
-    def filter(self, record):
-        ctx = log_context.get()
-        # 设置默认值,避免格式化错误
-        record.session_id = ctx.get('session_id', 'N/A')
-        record.user_id = ctx.get('user_id', 'anonymous')
-        record.request_id = ctx.get('request_id', 'N/A')
-        return True
-
-class LogManager:
-    """统一日志管理器 - 类似Log4j的功能"""
-    
-    _instance = None
-    _loggers: Dict[str, logging.Logger] = {}
-    _initialized = False
-    _executor = None
-    _fallback_to_console = False  # 标记是否降级到控制台
-    
-    def __new__(cls):
-        if cls._instance is None:
-            cls._instance = super().__new__(cls)
-        return cls._instance
-    
-    def __init__(self):
-        if not self._initialized:
-            self.config = None
-            self.base_log_dir = Path("logs")
-            self._executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="log")
-            self._setup_base_directory()
-            LogManager._initialized = True
-    
-    def initialize(self, config_path: str = "config/logging_config.yaml"):
-        """初始化日志系统"""
-        self.config = self._load_config(config_path)
-        self._setup_base_directory()
-        self._configure_root_logger()
-    
-    def get_logger(self, name: str, module: str = "default") -> logging.Logger:
-        """获取指定模块的logger"""
-        logger_key = f"{module}.{name}"
-        
-        if logger_key not in self._loggers:
-            logger = logging.getLogger(logger_key)
-            self._configure_logger(logger, module)
-            self._loggers[logger_key] = logger
-        
-        return self._loggers[logger_key]
-    
-    async def alog(self, logger: logging.Logger, level: str, message: str, **kwargs):
-        """异步日志方法"""
-        loop = asyncio.get_event_loop()
-        await loop.run_in_executor(
-            self._executor,
-            lambda: getattr(logger, level)(message, **kwargs)
-        )
-    
-    def set_context(self, **kwargs):
-        """设置日志上下文(可选)"""
-        ctx = log_context.get()
-        ctx.update(kwargs)
-        log_context.set(ctx)
-    
-    def clear_context(self):
-        """清除日志上下文"""
-        log_context.set({})
-    
-    def _load_config(self, config_path: str) -> dict:
-        """加载配置文件(带错误处理)"""
-        try:
-            with open(config_path, 'r', encoding='utf-8') as f:
-                return yaml.safe_load(f)
-        except FileNotFoundError:
-            print(f"[WARNING] 配置文件 {config_path} 未找到,使用默认配置")
-            return self._get_default_config()
-        except Exception as e:
-            print(f"[ERROR] 加载配置文件失败: {e},使用默认配置")
-            return self._get_default_config()
-    
-    def _setup_base_directory(self):
-        """创建日志目录(带降级策略)"""
-        try:
-            self.base_log_dir.mkdir(parents=True, exist_ok=True)
-            self._fallback_to_console = False
-        except Exception as e:
-            print(f"[WARNING] 无法创建日志目录 {self.base_log_dir},将只使用控制台输出: {e}")
-            self._fallback_to_console = True
-    
-    def _configure_logger(self, logger: logging.Logger, module: str):
-        """配置具体的logger(支持降级)"""
-        module_config = self.config.get('modules', {}).get(module, self.config['default'])
-        
-        # 设置日志级别
-        level = getattr(logging, module_config['level'].upper())
-        logger.setLevel(level)
-        
-        # 清除已有处理器
-        logger.handlers.clear()
-        logger.propagate = False
-        
-        # 添加控制台处理器
-        if module_config.get('console', {}).get('enabled', True):
-            console_handler = self._create_console_handler(module_config['console'])
-            console_handler.addFilter(ContextFilter())
-            logger.addHandler(console_handler)
-        
-        # 添加文件处理器(如果没有降级到控制台)
-        if not self._fallback_to_console and module_config.get('file', {}).get('enabled', True):
-            try:
-                file_handler = self._create_file_handler(module_config['file'], module)
-                file_handler.addFilter(ContextFilter())
-                logger.addHandler(file_handler)
-            except Exception as e:
-                print(f"[WARNING] 无法创建文件处理器: {e}")
-    
-    def _get_default_config(self) -> dict:
-        """获取默认配置"""
-        return {
-            'global': {'base_level': 'INFO'},
-            'default': {
-                'level': 'INFO',
-                'console': {
-                    'enabled': True,
-                    'level': 'INFO',
-                    'format': '%(asctime)s [%(levelname)s] %(name)s: %(message)s'
-                },
-                'file': {
-                    'enabled': True,
-                    'level': 'DEBUG',
-                    'filename': 'app.log',
-                    'format': '%(asctime)s [%(levelname)s] [%(name)s] [user:%(user_id)s] [session:%(session_id)s] %(filename)s:%(lineno)d - %(message)s',
-                    'rotation': {
-                        'enabled': True,
-                        'max_size': '50MB',
-                        'backup_count': 10
-                    }
-                }
-            },
-            'modules': {}
-        }
-    
-    def _create_console_handler(self, console_config: dict) -> logging.StreamHandler:
-        """创建控制台处理器"""
-        handler = logging.StreamHandler()
-        handler.setLevel(getattr(logging, console_config.get('level', 'INFO').upper()))
-        
-        formatter = logging.Formatter(
-            console_config.get('format', '%(asctime)s [%(levelname)s] %(name)s: %(message)s'),
-            datefmt='%Y-%m-%d %H:%M:%S'
-        )
-        handler.setFormatter(formatter)
-        return handler
-    
-    def _create_file_handler(self, file_config: dict, module: str) -> logging.Handler:
-        """创建文件处理器(支持自动轮转)"""
-        log_file = self.base_log_dir / file_config.get('filename', f'{module}.log')
-        
-        # 使用RotatingFileHandler实现自动轮转和清理
-        rotation_config = file_config.get('rotation', {})
-        if rotation_config.get('enabled', False):
-            handler = logging.handlers.RotatingFileHandler(
-                log_file,
-                maxBytes=self._parse_size(rotation_config.get('max_size', '50MB')),
-                backupCount=rotation_config.get('backup_count', 10),
-                encoding='utf-8'
-            )
-        else:
-            handler = logging.FileHandler(log_file, encoding='utf-8')
-        
-        handler.setLevel(getattr(logging, file_config.get('level', 'DEBUG').upper()))
-        
-        formatter = logging.Formatter(
-            file_config.get('format', '%(asctime)s [%(levelname)s] [%(name)s] %(filename)s:%(lineno)d - %(message)s'),
-            datefmt='%Y-%m-%d %H:%M:%S'
-        )
-        handler.setFormatter(formatter)
-        return handler
-    
-    def _parse_size(self, size_str: str) -> int:
-        """解析大小字符串,如 '50MB' -> 字节数"""
-        size_str = size_str.upper()
-        if size_str.endswith('KB'):
-            return int(size_str[:-2]) * 1024
-        elif size_str.endswith('MB'):
-            return int(size_str[:-2]) * 1024 * 1024
-        elif size_str.endswith('GB'):
-            return int(size_str[:-2]) * 1024 * 1024 * 1024
-        else:
-            return int(size_str)
-    
-    def __del__(self):
-        """清理资源"""
-        if self._executor:
-            self._executor.shutdown(wait=False)
-```
-
-### 2.3 统一日志接口(增强版)
-
-```python
-# core/logging/__init__.py
-from .log_manager import LogManager
-import logging
-
-# 全局日志管理器实例
-_log_manager = LogManager()
-
-def initialize_logging(config_path: str = "config/logging_config.yaml"):
-    """初始化项目日志系统"""
-    _log_manager.initialize(config_path)
-
-def get_logger(name: str, module: str = "default") -> logging.Logger:
-    """获取logger实例 - 主要API"""
-    return _log_manager.get_logger(name, module)
-
-# 便捷方法
-def get_data_pipeline_logger(name: str) -> logging.Logger:
-    """获取data_pipeline模块logger"""
-    return get_logger(name, "data_pipeline")
-
-def get_agent_logger(name: str) -> logging.Logger:
-    """获取agent模块logger"""
-    return get_logger(name, "agent")
-
-def get_vanna_logger(name: str) -> logging.Logger:
-    """获取vanna模块logger"""
-    return get_logger(name, "vanna")
-
-# 上下文管理便捷方法
-def set_log_context(**kwargs):
-    """设置日志上下文(可选)
-    示例: set_log_context(user_id='user123', session_id='sess456')
-    """
-    _log_manager.set_context(**kwargs)
-
-def clear_log_context():
-    """清除日志上下文"""
-    _log_manager.clear_context()
-
-# 异步日志便捷方法
-async def alog_info(logger: logging.Logger, message: str, **kwargs):
-    """异步记录INFO日志"""
-    await _log_manager.alog(logger, 'info', message, **kwargs)
-
-async def alog_error(logger: logging.Logger, message: str, **kwargs):
-    """异步记录ERROR日志"""
-    await _log_manager.alog(logger, 'error', message, **kwargs)
-
-async def alog_debug(logger: logging.Logger, message: str, **kwargs):
-    """异步记录DEBUG日志"""
-    await _log_manager.alog(logger, 'debug', message, **kwargs)
-
-async def alog_warning(logger: logging.Logger, message: str, **kwargs):
-    """异步记录WARNING日志"""
-    await _log_manager.alog(logger, 'warning', message, **kwargs)
-```
-
-### 2.4 日志配置文件(支持上下文信息)
-
-```yaml
-# config/logging_config.yaml
-version: 1
-
-# 全局配置
-global:
-  base_level: INFO
-  
-# 默认配置
-default:
-  level: INFO
-  console:
-    enabled: true
-    level: INFO
-    format: "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
-  file:
-    enabled: true
-    level: DEBUG
-    filename: "app.log"
-    # 支持上下文信息,但有默认值避免错误
-    format: "%(asctime)s [%(levelname)s] [%(name)s] [user:%(user_id)s] [session:%(session_id)s] %(filename)s:%(lineno)d - %(message)s"
-    rotation:
-      enabled: true
-      max_size: "50MB"
-      backup_count: 10
-
-# 模块特定配置
-modules:
-  data_pipeline:
-    level: DEBUG
-    console:
-      enabled: true
-      level: INFO
-      format: "🔄 %(asctime)s [%(levelname)s] Pipeline: %(message)s"
-    file:
-      enabled: true
-      level: DEBUG
-      filename: "data_pipeline.log"
-      format: "%(asctime)s [%(levelname)s] [%(name)s] %(filename)s:%(lineno)d - %(message)s"
-      rotation:
-        enabled: true
-        max_size: "30MB"
-        backup_count: 8
-  
-  agent:
-    level: DEBUG
-    console:
-      enabled: true
-      level: INFO
-      format: "🤖 %(asctime)s [%(levelname)s] Agent: %(message)s"
-    file:
-      enabled: true
-      level: DEBUG
-      filename: "agent.log"
-      # Agent模块支持user_id和session_id
-      format: "%(asctime)s [%(levelname)s] [%(name)s] [user:%(user_id)s] [session:%(session_id)s] %(filename)s:%(lineno)d - %(message)s"
-      rotation:
-        enabled: true
-        max_size: "30MB"
-        backup_count: 8
-  
-  vanna:
-    level: INFO
-    console:
-      enabled: true
-      level: INFO
-      format: "🧠 %(asctime)s [%(levelname)s] Vanna: %(message)s"
-    file:
-      enabled: true
-      level: DEBUG
-      filename: "vanna.log"
-      format: "%(asctime)s [%(levelname)s] [%(name)s] %(filename)s:%(lineno)d - %(message)s"
-      rotation:
-        enabled: true
-        max_size: "20MB"
-        backup_count: 5
-```
-
-## 3. 改造实施步骤
-
-### 3.1 第一阶段:基础架构搭建
-
-1. **创建日志服务目录结构**
-   ```bash
-   mkdir -p core/logging
-   mkdir -p config
-   mkdir -p logs
-   ```
-
-2. **实现核心组件**
-   - 创建 `core/logging/log_manager.py`
-   - 创建 `core/logging/__init__.py`
-   - 创建 `config/logging_config.yaml`
-
-3. **集成到citu_app.py(主要应用)**
-   ```python
-   # 在citu_app.py的开头添加
-   from core.logging import initialize_logging, get_logger, set_log_context, clear_log_context
-   import uuid
-   
-   # 初始化日志系统
-   initialize_logging("config/logging_config.yaml")
-   app_logger = get_logger("CituApp", "default")
-   
-   # 在Flask应用配置后集成请求级别的日志上下文
-   @app.flask_app.before_request
-   def before_request():
-       # 为每个请求设置上下文(如果有的话)
-       request_id = str(uuid.uuid4())[:8]
-       user_id = request.headers.get('X-User-ID', 'anonymous')
-       set_log_context(request_id=request_id, user_id=user_id)
-   
-   @app.flask_app.after_request
-   def after_request(response):
-       # 清理上下文
-       clear_log_context()
-       return response
-   ```
-
-### 3.2 第二阶段:模块改造
-
-#### 3.2.1 改造data_pipeline模块
-
-```python
-# 替换 data_pipeline/utils/logger.py 中的使用方式
-from core.logging import get_data_pipeline_logger
-
-def setup_logging(verbose: bool = False, log_file: str = None, log_dir: str = None):
-    """
-    保持原有接口,内部使用新的日志系统
-    """
-    # 不再需要复杂的设置,直接使用统一日志系统
-    pass
-
-# 在各个文件中使用
-# data_pipeline/qa_generation/qs_agent.py
-class QuestionSQLGenerationAgent:
-    def __init__(self, ...):
-        # 替换原有的 logging.getLogger("schema_tools.QSAgent")
-        self.logger = get_data_pipeline_logger("QSAgent")
-        
-    async def generate(self):
-        self.logger.info("🚀 开始生成Question-SQL训练数据")
-        # ... 其他代码
-        
-        # 手动记录关键节点的时间
-        start_time = time.time()
-        self.logger.info("开始初始化LLM组件")
-        
-        self._initialize_llm_components()
-        
-        init_time = time.time() - start_time
-        self.logger.info(f"LLM组件初始化完成,耗时: {init_time:.2f}秒")
-```
-
-#### 3.2.2 改造Agent模块(支持可选的用户上下文)
-
-```python
-# 在ask_agent接口中使用
-@app.flask_app.route('/api/v0/ask_agent', methods=['POST'])
-def ask_agent():
-    logger = get_agent_logger("AskAgent")
-    
-    try:
-        data = request.json
-        question = data.get('question', '')
-        user_id = data.get('user_id')  # 可选
-        session_id = data.get('session_id')  # 可选
-        
-        # 设置上下文(如果有的话)
-        if user_id or session_id:
-            set_log_context(user_id=user_id or 'anonymous', session_id=session_id or 'N/A')
-        
-        logger.info(f"收到问题: {question[:50]}...")
-        
-        # 异步记录示例(在async函数中)
-        # await alog_info(logger, f"开始处理问题: {question}")
-        
-        # ... 其他处理逻辑
-        
-    except Exception as e:
-        logger.error(f"处理失败: {str(e)}", exc_info=True)
-        # ...
-```
-
-#### 3.2.3 改造vanna相关代码
-
-由于vanna使用print方式,创建简单的适配器:
-
-```python
-# core/logging/vanna_adapter.py
-from core.logging import get_vanna_logger
-
-class VannaLogAdapter:
-    """Vanna日志适配器 - 将print转换为logger调用"""
-    
-    def __init__(self, logger_name: str = "VannaBase"):
-        self.logger = get_vanna_logger(logger_name)
-    
-    def log(self, message: str):
-        """替换vanna的log方法"""
-        # 根据内容判断日志级别
-        message_lower = message.lower()
-        if any(keyword in message_lower for keyword in ['error', 'exception', 'fail']):
-            self.logger.error(message)
-        elif any(keyword in message_lower for keyword in ['warning', 'warn']):
-            self.logger.warning(message)
-        else:
-            self.logger.info(message)
-
-# 使用装饰器改造vanna实例
-def enhance_vanna_logging(vanna_instance):
-    """增强vanna实例的日志功能"""
-    adapter = VannaLogAdapter(vanna_instance.__class__.__name__)
-    
-    # 替换log方法
-    vanna_instance.log = adapter.log
-    return vanna_instance
-
-# 在vanna实例创建时使用
-# core/vanna_llm_factory.py
-from core.logging.vanna_adapter import enhance_vanna_logging
-
-def create_vanna_instance():
-    # 原有创建逻辑
-    vn = VannaDefault(...)
-    
-    # 增强日志功能
-    vn = enhance_vanna_logging(vn)
-    
-    return vn
-```
-
-### 3.3 第三阶段:workflow级别的时间统计
-
-对于跨多个函数的执行时间统计,在关键业务节点手动记录:
-
-```python
-# data_pipeline/schema_workflow.py
-import time
-from core.logging import get_data_pipeline_logger
-
-class SchemaWorkflowOrchestrator:
-    def __init__(self, ...):
-        self.logger = get_data_pipeline_logger("SchemaWorkflow")
-    
-    async def run_full_workflow(self):
-        """执行完整工作流"""
-        workflow_start = time.time()
-        self.logger.info("🚀 开始执行完整的Schema工作流")
-        
-        try:
-            # 步骤1:生成DDL和MD文件
-            step1_start = time.time()
-            self.logger.info("📝 步骤1: 开始生成DDL和MD文件")
-            
-            result1 = await self.generate_ddl_md()
-            
-            step1_time = time.time() - step1_start
-            self.logger.info(f"✅ 步骤1完成,生成了{result1['ddl_count']}个DDL文件和{result1['md_count']}个MD文件,耗时: {step1_time:.2f}秒")
-            
-            # 步骤2:生成Question-SQL对
-            step2_start = time.time()
-            self.logger.info("❓ 步骤2: 开始生成Question-SQL对")
-            
-            result2 = await self.generate_qa_pairs()
-            
-            step2_time = time.time() - step2_start
-            self.logger.info(f"✅ 步骤2完成,生成了{result2['qa_count']}个问答对,耗时: {step2_time:.2f}秒")
-            
-            # 步骤3:验证SQL
-            step3_start = time.time()
-            self.logger.info("🔍 步骤3: 开始验证SQL")
-            
-            result3 = await self.validate_sql()
-            
-            step3_time = time.time() - step3_start
-            self.logger.info(f"✅ 步骤3完成,验证了{result3['validated_count']}个SQL,修复了{result3['fixed_count']}个,耗时: {step3_time:.2f}秒")
-            
-            # 步骤4:加载训练数据
-            step4_start = time.time()
-            self.logger.info("📚 步骤4: 开始加载训练数据")
-            
-            result4 = await self.load_training_data()
-            
-            step4_time = time.time() - step4_start
-            self.logger.info(f"✅ 步骤4完成,加载了{result4['loaded_count']}个训练文件,耗时: {step4_time:.2f}秒")
-            
-            # 总结
-            total_time = time.time() - workflow_start
-            self.logger.info(f"🎉 完整工作流执行成功!总耗时: {total_time:.2f}秒")
-            self.logger.info(f"   - DDL/MD生成: {step1_time:.2f}秒")
-            self.logger.info(f"   - QA生成: {step2_time:.2f}秒")  
-            self.logger.info(f"   - SQL验证: {step3_time:.2f}秒")
-            self.logger.info(f"   - 数据加载: {step4_time:.2f}秒")
-            
-            return {
-                "success": True,
-                "total_time": total_time,
-                "steps": {
-                    "ddl_md": {"time": step1_time, "result": result1},
-                    "qa_generation": {"time": step2_time, "result": result2},
-                    "sql_validation": {"time": step3_time, "result": result3},
-                    "data_loading": {"time": step4_time, "result": result4}
-                }
-            }
-            
-        except Exception as e:
-            total_time = time.time() - workflow_start
-            self.logger.error(f"❌ 工作流执行失败,耗时: {total_time:.2f}秒,错误: {str(e)}")
-            raise
-```
-
-## 4. 实际使用示例
-
-### 4.1 在citu_app.py中的使用(主要应用)
-
-```python
-# citu_app.py
-from core.logging import initialize_logging, get_logger, set_log_context, clear_log_context
-import uuid
-
-# 应用启动时初始化
-initialize_logging("config/logging_config.yaml")
-app_logger = get_logger("CituApp", "default")
-
-# API端点示例
-@app.flask_app.route('/api/v0/ask_agent', methods=['POST'])
-def ask_agent():
-    logger = get_agent_logger("AskAgent")
-    request_id = str(uuid.uuid4())[:8]
-    
-    try:
-        data = request.json
-        user_id = data.get('user_id')
-        
-        # 设置上下文(安全的,即使没有user_id)
-        set_log_context(
-            request_id=request_id,
-            user_id=user_id or 'anonymous'
-        )
-        
-        logger.info(f"开始处理请求")
-        # ... 业务逻辑
-        
-        logger.info(f"请求处理成功")
-        return success_response(...)
-        
-    except Exception as e:
-        logger.error(f"请求处理失败: {str(e)}", exc_info=True)
-        return error_response(...)
-    finally:
-        clear_log_context()
-```
-
-### 4.2 在data_pipeline中的使用
-
-```python
-# data_pipeline/ddl_generation/training_data_agent.py
-from core.logging import get_data_pipeline_logger
-import time
-
-class SchemaTrainingDataAgent:
-    def __init__(self, db_config, output_dir):
-        self.logger = get_data_pipeline_logger("TrainingDataAgent")
-        self.db_config = db_config
-        self.output_dir = output_dir
-        
-    async def process_tables(self, table_list):
-        """处理表列表"""
-        start_time = time.time()
-        self.logger.info(f"开始处理{len(table_list)}个表的训练数据生成")
-        
-        success_count = 0
-        failed_tables = []
-        
-        for table in table_list:
-            try:
-                table_start = time.time()
-                self.logger.debug(f"开始处理表: {table}")
-                
-                await self._process_single_table(table)
-                
-                table_time = time.time() - table_start
-                self.logger.info(f"表 {table} 处理完成,耗时: {table_time:.2f}秒")
-                success_count += 1
-                
-            except Exception as e:
-                self.logger.error(f"表 {table} 处理失败: {str(e)}")
-                failed_tables.append(table)
-        
-        total_time = time.time() - start_time
-        self.logger.info(f"批量处理完成,成功: {success_count}个,失败: {len(failed_tables)}个,总耗时: {total_time:.2f}秒")
-        
-        if failed_tables:
-            self.logger.warning(f"处理失败的表: {failed_tables}")
-            
-        return {
-            "success_count": success_count,
-            "failed_count": len(failed_tables),
-            "failed_tables": failed_tables,
-            "total_time": total_time
-        }
-```
-
-### 4.3 在Agent中的使用(支持异步)
-
-```python
-# agent/citu_agent.py
-from core.logging import get_agent_logger, alog_info, alog_error
-
-class CituLangGraphAgent:
-    def __init__(self):
-        self.logger = get_agent_logger("CituAgent")
-    
-    async def process_question(self, question: str, session_id: str = None, user_id: str = None):
-        """异步处理问题"""
-        # 设置上下文(如果有的话)
-        if user_id or session_id:
-            set_log_context(user_id=user_id or 'anonymous', session_id=session_id or 'N/A')
-        
-        # 同步日志
-        self.logger.info(f"开始处理问题: {question[:50]}...")
-        
-        try:
-            # 异步日志
-            await alog_info(self.logger, f"开始分类问题")
-            
-            # 业务逻辑
-            result = await self._classify_question(question)
-            
-            await alog_info(self.logger, f"分类完成: {result.question_type}")
-            
-            return result
-            
-        except Exception as e:
-            await alog_error(self.logger, f"处理失败: {str(e)}")
-            raise
-```
-
-### 4.4 增强vanna日志
-
-```python
-# core/vanna_llm_factory.py
-from core.logging.vanna_adapter import enhance_vanna_logging
-from core.logging import get_vanna_logger
-
-def create_vanna_instance():
-    """创建增强了日志功能的vanna实例"""
-    logger = get_vanna_logger("VannaFactory")
-    logger.info("🧠 开始创建Vanna实例")
-    
-    try:
-        # 原有创建逻辑
-        vn = VannaDefault(
-            config={
-                'api_key': os.getenv('OPENAI_API_KEY'),
-                'model': 'gpt-4'
-            }
-        )
-        
-        # 增强日志功能
-        vn = enhance_vanna_logging(vn)
-        
-        logger.info("✅ Vanna实例创建成功")
-        return vn
-        
-    except Exception as e:
-        logger.error(f"❌ Vanna实例创建失败: {str(e)}")
-        raise
-```
-
-## 5. 配置调优建议
-
-### 5.1 开发环境配置
-
-```yaml
-# config/logging_config_dev.yaml
-version: 1
-
-global:
-  base_level: DEBUG
-
-default:
-  level: DEBUG
-  console:
-    enabled: true
-    level: DEBUG
-  file:
-    enabled: false  # 开发环境可以只用控制台
-
-modules:
-  data_pipeline:
-    level: DEBUG
-    console:
-      enabled: true
-      level: DEBUG
-      format: "🔄 %(asctime)s [%(levelname)s] Pipeline: %(message)s"
-    file:
-      enabled: true
-      level: DEBUG
-      filename: "data_pipeline.log"
-      
-  agent:
-    level: DEBUG
-    console:
-      enabled: true
-      level: DEBUG
-      format: "🤖 %(asctime)s [%(levelname)s] Agent: %(message)s"
-```
-
-### 5.2 生产环境配置
-
-```yaml
-# config/logging_config_prod.yaml
-version: 1
-
-global:
-  base_level: INFO
-
-default:
-  level: INFO
-  console:
-    enabled: false  # 生产环境不输出到控制台
-  file:
-    enabled: true
-    level: INFO
-    rotation:
-      enabled: true
-      max_size: "100MB"
-      backup_count: 20
-
-modules:
-  data_pipeline:
-    level: INFO
-    console:
-      enabled: false
-    file:
-      enabled: true
-      level: INFO
-      filename: "data_pipeline.log"
-      rotation:
-        enabled: true
-        max_size: "50MB"
-        backup_count: 15
-        
-  langchain:
-    level: ERROR  # 生产环境只记录错误
-    console:
-      enabled: false
-    file:
-      enabled: true
-      level: ERROR
-```
-
-## 6. 注意事项
-
-基于用户反馈,特别注意以下几点:
-
-1. **上下文安全性**:即使没有用户信息,日志系统也能正常工作(使用默认值)
-2. **降级策略**:当文件系统不可用时,自动降级到控制台输出
-3. **异步支持**:在async函数中使用异步日志方法,避免阻塞
-4. **主应用聚焦**:重点关注citu_app.py的集成,忽略flask_app.py和chainlit_app.py
-5. **性能考虑**:保持原有的跨函数时间统计方式,不强制使用装饰器
-
-## 7. 总结
-
-这个精简实用的日志改造方案提供了:
-
-1. **统一的日志管理**:类似Log4j的架构,单一配置文件管理所有日志
-2. **模块化日志文件**:每个模块独立的日志文件,便于问题定位
-3. **自动日志轮转**:使用RotatingFileHandler自动管理日志文件大小和数量
-4. **灵活的配置**:支持不同环境的配置,控制台和文件输出可独立配置
-5. **简单易用**:提供便捷的API,一行代码获取对应模块的logger
-6. **性能友好**:手动记录关键节点时间,不影响整体性能
-7. **技术栈兼容**:专门为vanna/langchain/langgraph设计适配器
-8. **异步支持**:适配项目大量使用async/await的特点
-9. **安全容错**:上下文信息可选,文件系统可降级
-
-该方案专注核心功能,去掉了不必要的复杂性,是一个可以直接落地实施的实用设计。