Explorar o código

启用test目录的git跟踪:包含所有测试代码文件和Jupyter notebook

- 修改.gitignore文件,移除对test目录的完全忽略
- 添加13个测试文件到git跟踪中:
  - 12个Python测试文件
  - 1个Jupyter notebook文件
- 保留对临时文件和缓存的忽略规则(__pycache__、.pytest_cache等)
wangxq hai 1 mes
pai
achega
6243095d72

+ 6 - 2
.gitignore

@@ -32,5 +32,9 @@ node_modules/
 # 忽略所有一级UUID目录
 /[0-9a-fA-F]*-[0-9a-fA-F]*-[0-9a-fA-F]*-[0-9a-fA-F]*-[0-9a-fA-F]*/
 
-
-test/
+# 忽略 test 目录中的临时文件和缓存,但跟踪所有代码文件
+test/__pycache__/
+test/.pytest_cache/
+test/.ipynb_checkpoints/
+test/*.pyc
+test/*.pyo

+ 248 - 0
test/redis_conversation_demo.py

@@ -0,0 +1,248 @@
+"""
+Redis对话管理功能演示脚本
+
+这个脚本演示了如何使用Redis对话管理系统的各种功能:
+1. 创建对话
+2. 多轮对话(带上下文)
+3. 缓存命中
+4. 对话历史查询
+5. 统计信息查看
+"""
+
+import requests
+import json
+import time
+import sys
+import os
+
+class ConversationDemo:
+    """Interactive walkthrough of the Redis conversation HTTP API.
+
+    Every ``demo_*`` method sends requests to the service at ``base_url`` and
+    prints selected fields of the JSON reply. ``demo_basic_conversation``
+    records the conversation id and user id that the later demos reuse.
+    """
+
+    def __init__(self, base_url="http://localhost:8084/api/v0", timeout=30):
+        """Remember the API root and derive a session id from the clock.
+
+        Args:
+            base_url: Root URL of the API, without a trailing slash.
+            timeout: Seconds to wait for each HTTP request. Without a
+                timeout ``requests`` waits indefinitely on a stalled server.
+        """
+        self.base_url = base_url
+        self.timeout = timeout
+        self.session_id = f"demo_session_{int(time.time())}"
+        self.conversation_id = None
+        self.user_id = None
+
+    def _ask(self, payload):
+        """POST ``payload`` as JSON to the ``ask_agent`` endpoint."""
+        return requests.post(
+            f"{self.base_url}/ask_agent",
+            json=payload,
+            timeout=self.timeout,
+        )
+
+    def _get(self, path):
+        """GET ``path``, which is appended to ``base_url`` verbatim."""
+        return requests.get(f"{self.base_url}{path}", timeout=self.timeout)
+
+    def print_section(self, title):
+        """Print ``title`` framed by two separator lines."""
+        print("\n" + "="*60)
+        print(f" {title} ")
+        print("="*60)
+
+    def demo_basic_conversation(self):
+        """Ask a first question and remember the ids the service returns."""
+        self.print_section("1. 基本对话功能")
+
+        print("\n[DEMO] 发送第一个问题...")
+        response = self._ask({
+            "question": "高速公路服务区有多少个?",
+            "session_id": self.session_id,
+        })
+
+        if response.status_code != 200:
+            print(f"[错误] 响应码: {response.status_code}")
+            return
+
+        payload = response.json()['data']
+        self.conversation_id = payload['conversation_id']
+        self.user_id = payload['user_id']
+
+        print(f"[结果] 对话ID: {self.conversation_id}")
+        print(f"[结果] 用户ID: {self.user_id}")
+        print(f"[结果] 是否为Guest用户: {payload.get('is_guest_user')}")
+        print(f"[结果] 回答: {payload.get('response', '')[:100]}...")
+
+    def demo_context_awareness(self):
+        """Ask a follow-up question inside the conversation created earlier."""
+        self.print_section("2. 上下文感知功能")
+
+        if not self.conversation_id:
+            print("[警告] 需要先运行基本对话演示")
+            return
+
+        # The follow-up question only makes sense with the earlier context.
+        print("\n[DEMO] 发送依赖上下文的问题...")
+        response = self._ask({
+            "question": "这些服务区的经理都是谁?",
+            "session_id": self.session_id,
+            "conversation_id": self.conversation_id,
+        })
+
+        if response.status_code != 200:
+            print(f"[错误] 响应码: {response.status_code}")
+            return
+
+        payload = response.json()['data']
+        print(f"[结果] 使用了上下文: {payload.get('context_used')}")
+        print(f"[结果] 对话状态: {payload.get('conversation_status')}")
+        print(f"[结果] 回答: {payload.get('response', '')[:100]}...")
+
+    def demo_cache_functionality(self):
+        """Ask the same question twice and print the ``from_cache`` flag."""
+        self.print_section("3. 缓存功能")
+
+        question = "高速公路服务区的总数是多少?"
+        cache_session = self.session_id + "_cache"
+
+        print(f"\n[DEMO] 第一次询问: {question}")
+        first = self._ask({"question": question, "session_id": cache_session})
+        if first.status_code != 200:
+            return
+
+        first_payload = first.json()['data']
+        print(f"[结果] 来自缓存: {first_payload.get('from_cache')}")
+        conv_id = first_payload['conversation_id']
+
+        # Repeat immediately within the same conversation.
+        print(f"\n[DEMO] 第二次询问相同问题...")
+        second = self._ask({
+            "question": question,
+            "session_id": cache_session,
+            "conversation_id": conv_id,
+        })
+        if second.status_code == 200:
+            print(f"[结果] 来自缓存: {second.json()['data'].get('from_cache')}")
+
+    def demo_conversation_history(self):
+        """List the user's conversations and the messages of the current one."""
+        self.print_section("4. 对话历史查询")
+
+        if not self.user_id:
+            print("[警告] 需要先运行基本对话演示")
+            return
+
+        print(f"\n[DEMO] 获取用户 {self.user_id} 的对话列表...")
+        response = self._get(f"/user/{self.user_id}/conversations")
+
+        if response.status_code == 200:
+            conversations = response.json()['data']['conversations']
+            print(f"[结果] 找到 {len(conversations)} 个对话")
+
+            for number, conv in enumerate(conversations, start=1):
+                print(f"\n  对话 {number}:")
+                print(f"    ID: {conv['conversation_id']}")
+                print(f"    创建时间: {conv['created_at']}")
+                print(f"    消息数: {conv['message_count']}")
+
+        if not self.conversation_id:
+            return
+
+        print(f"\n[DEMO] 获取对话 {self.conversation_id} 的消息...")
+        response = self._get(f"/conversation/{self.conversation_id}/messages")
+
+        if response.status_code == 200:
+            messages = response.json()['data']['messages']
+            print(f"[结果] 找到 {len(messages)} 条消息")
+
+            for msg in messages:
+                role = "用户" if msg['role'] == 'user' else "助手"
+                content = msg['content']
+                # Keep the console readable: show at most 50 characters.
+                if len(content) > 50:
+                    content = content[:50] + "..."
+                print(f"\n  [{role}]: {content}")
+
+    def demo_statistics(self):
+        """Print the figures returned by the ``conversation_stats`` endpoint."""
+        self.print_section("5. 统计信息")
+
+        print("\n[DEMO] 获取对话系统统计信息...")
+        response = self._get("/conversation_stats")
+
+        if response.status_code != 200:
+            return
+
+        stats = response.json()['data']
+
+        print(f"\n[统计信息]")
+        print(f"  Redis可用: {stats.get('available')}")
+        print(f"  总用户数: {stats.get('total_users')}")
+        print(f"  总对话数: {stats.get('total_conversations')}")
+        print(f"  缓存的问答数: {stats.get('cached_qa_count')}")
+
+        redis_info = stats.get('redis_info')
+        if redis_info:
+            print(f"\n[Redis信息]")
+            print(f"  内存使用: {redis_info.get('used_memory')}")
+            print(f"  连接客户端数: {redis_info.get('connected_clients')}")
+
+    def demo_invalid_conversation_id(self):
+        """Send an unknown conversation id and print how the service reacts."""
+        self.print_section("6. 无效对话ID处理")
+
+        print("\n[DEMO] 使用无效的对话ID...")
+        response = self._ask({
+            "question": "测试无效ID",
+            "session_id": self.session_id,
+            "conversation_id": "invalid_conversation_xyz",
+        })
+
+        if response.status_code != 200:
+            return
+
+        payload = response.json()['data']
+        print(f"[结果] 对话状态: {payload.get('conversation_status')}")
+        print(f"[结果] 状态消息: {payload.get('conversation_message')}")
+        print(f"[结果] 请求的ID: {payload.get('requested_conversation_id')}")
+        print(f"[结果] 新创建的ID: {payload.get('conversation_id')}")
+
+    def run_all_demos(self):
+        """Check that the service is up, then run every demo in order.
+
+        Any exception (connection failure, timeout, unexpected payload) ends
+        the run with a printed message instead of a traceback.
+        """
+        demos = (
+            self.demo_basic_conversation,
+            self.demo_context_awareness,
+            self.demo_cache_functionality,
+            self.demo_conversation_history,
+            self.demo_statistics,
+            self.demo_invalid_conversation_id,
+        )
+        try:
+            print("[DEMO] 检查服务可用性...")
+            response = requests.get(f"{self.base_url}/agent_health", timeout=5)
+            if response.status_code != 200:
+                print("[错误] 服务不可用,请先启动Flask应用")
+                return
+
+            for position, demo in enumerate(demos):
+                if position:
+                    # Short pause between demos, none after the last one.
+                    time.sleep(1)
+                demo()
+
+            print("\n" + "="*60)
+            print(" 演示完成 ")
+            print("="*60)
+
+        except Exception as e:
+            print(f"\n[错误] 演示过程中出错: {str(e)}")
+            print("请确保Flask应用正在运行 (python citu_app.py)")
+
+
+if __name__ == "__main__":
+    # Banner shown before the demo contacts the service.
+    for banner_line in (
+        "Redis对话管理功能演示",
+        "确保已经启动了Flask应用和Redis服务",
+        "-" * 60,
+    ):
+        print(banner_line)
+
+    ConversationDemo().run_all_demos()

+ 293 - 0
test/test_ask_agent_redis_integration.py

@@ -0,0 +1,293 @@
+import unittest
+import requests
+import json
+import sys
+import os
+import time
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from common.redis_conversation_manager import RedisConversationManager
+
+class TestAskAgentRedisIntegration(unittest.TestCase):
+    """Integration tests for the ask_agent API and its Redis conversations.
+
+    The tests need a running service at ``base_url``. Only transport-level
+    problems (``requests.exceptions.RequestException``) skip a test.
+    Assertion failures and malformed payloads are reported as failures or
+    errors instead of being converted into skips.
+    """
+
+    def setUp(self):
+        """Create a session id that is unique per second and a manager."""
+        self.base_url = "http://localhost:8084/api/v0"
+        self.test_session_id = "test_session_" + str(int(time.time()))
+        self.manager = RedisConversationManager()
+
+    def tearDown(self):
+        """Placeholder: no test data is cleaned up yet."""
+        pass
+
+    def _ask(self, payload, skip_message):
+        """POST ``payload`` to ask_agent; skip the test if the call cannot be made."""
+        try:
+            return requests.post(
+                f"{self.base_url}/ask_agent",
+                json=payload,
+                timeout=30
+            )
+        except requests.exceptions.RequestException as e:
+            self.skipTest(f"{skip_message}: {str(e)}")
+
+    def _get(self, path, skip_message, timeout=10):
+        """GET ``path`` below ``base_url``; skip the test if the call cannot be made."""
+        try:
+            return requests.get(f"{self.base_url}{path}", timeout=timeout)
+        except requests.exceptions.RequestException as e:
+            self.skipTest(f"{skip_message}: {str(e)}")
+
+    def _create_conversation(self):
+        """Ask a first question and validate the reply.
+
+        Returns:
+            ``(conversation_id, user_id)`` when the service answers 200,
+            otherwise ``None``.
+        """
+        response = self._ask(
+            {
+                "question": "测试问题:高速公路服务区有多少个?",
+                "session_id": self.test_session_id
+            },
+            "API调用失败",
+        )
+
+        print(f"[TEST] 第一次调用响应码: {response.status_code}")
+        if response.status_code != 200:
+            return None
+
+        data = response.json()
+        print(f"[TEST] 响应数据: {json.dumps(data, indent=2, ensure_ascii=False)}")
+
+        self.assertIn('data', data)
+        for field in ('conversation_id', 'user_id', 'conversation_status'):
+            self.assertIn(field, data['data'])
+
+        conversation_id = data['data']['conversation_id']
+        user_id = data['data']['user_id']
+
+        print(f"[TEST] 创建的对话ID: {conversation_id}")
+        print(f"[TEST] 用户ID: {user_id}")
+
+        return conversation_id, user_id
+
+    def test_api_availability(self):
+        """The health endpoint can be reached."""
+        response = self._get("/agent_health", "API服务不可用", timeout=5)
+        print(f"[TEST] Agent健康检查响应码: {response.status_code}")
+
+    def test_basic_ask_agent(self):
+        """A first ask_agent call creates a conversation with the expected fields."""
+        # Test methods must not return a value, so the work lives in a helper.
+        self._create_conversation()
+
+    def test_conversation_context(self):
+        """A follow-up question in the same conversation uses the context."""
+        response1 = self._ask(
+            {
+                "question": "高速公路服务区有多少个?",
+                "session_id": self.test_session_id
+            },
+            "上下文测试失败",
+        )
+        if response1.status_code != 200:
+            self.skipTest("第一次API调用失败")
+
+        conversation_id = response1.json()['data']['conversation_id']
+
+        response2 = self._ask(
+            {
+                # This question depends on the previous one for its meaning.
+                "question": "这些服务区的经理都是谁?",
+                "session_id": self.test_session_id,
+                "conversation_id": conversation_id
+            },
+            "上下文测试失败",
+        )
+
+        print(f"[TEST] 第二次调用响应码: {response2.status_code}")
+
+        if response2.status_code == 200:
+            context_used = response2.json()['data'].get('context_used', False)
+            print(f"[TEST] 使用了上下文: {context_used}")
+            self.assertTrue(context_used)
+
+    def test_cache_hit(self):
+        """The first answer is not served from cache; the second is only logged."""
+        payload = {
+            "question": "高速公路服务区的数量是多少?",
+            "session_id": self.test_session_id + "_cache_test"
+        }
+
+        response1 = self._ask(payload, "缓存测试失败")
+        if response1.status_code != 200:
+            self.skipTest("第一次API调用失败")
+
+        from_cache1 = response1.json()['data'].get('from_cache', False)
+        print(f"[TEST] 第一次调用from_cache: {from_cache1}")
+        self.assertFalse(from_cache1)
+
+        # Same payload again. It starts a new conversation, so a cache hit is
+        # not guaranteed and therefore not asserted.
+        response2 = self._ask(payload, "缓存测试失败")
+        if response2.status_code == 200:
+            from_cache2 = response2.json()['data'].get('from_cache', False)
+            print(f"[TEST] 第二次调用from_cache: {from_cache2}")
+
+    def test_invalid_conversation_id(self):
+        """An unknown conversation id yields status ``invalid_id_new``."""
+        response = self._ask(
+            {
+                "question": "测试无效对话ID",
+                "session_id": self.test_session_id,
+                "conversation_id": "invalid_conv_id_xyz"
+            },
+            "无效ID测试失败",
+        )
+
+        if response.status_code == 200:
+            payload = response.json()['data']
+            status = payload.get('conversation_status')
+            print(f"[TEST] 无效对话ID的状态: {status}")
+            self.assertEqual(status, 'invalid_id_new')
+            self.assertEqual(
+                payload.get('requested_conversation_id'),
+                'invalid_conv_id_xyz'
+            )
+
+    def test_conversation_api_endpoints(self):
+        """The conversation list, message list and statistics endpoints respond."""
+        result = self._create_conversation()
+        if not result:
+            self.skipTest("无法创建测试对话")
+
+        conversation_id, user_id = result
+
+        response = self._get(f"/user/{user_id}/conversations", "管理API测试失败")
+        print(f"[TEST] 获取对话列表响应码: {response.status_code}")
+        if response.status_code == 200:
+            data = response.json()
+            self.assertIn('data', data)
+            self.assertIn('conversations', data['data'])
+            print(f"[TEST] 用户对话数: {len(data['data']['conversations'])}")
+
+        response = self._get(
+            f"/conversation/{conversation_id}/messages", "管理API测试失败"
+        )
+        print(f"[TEST] 获取对话消息响应码: {response.status_code}")
+        if response.status_code == 200:
+            data = response.json()
+            self.assertIn('data', data)
+            self.assertIn('messages', data['data'])
+            print(f"[TEST] 对话消息数: {len(data['data']['messages'])}")
+
+        response = self._get("/conversation_stats", "管理API测试失败")
+        print(f"[TEST] 获取统计信息响应码: {response.status_code}")
+        if response.status_code == 200:
+            data = response.json()
+            self.assertIn('data', data)
+            stats = data['data']
+            print(f"[TEST] Redis可用: {stats.get('available')}")
+            print(f"[TEST] 总用户数: {stats.get('total_users')}")
+            print(f"[TEST] 总对话数: {stats.get('total_conversations')}")
+
+    def test_guest_user_generation(self):
+        """A request without a user id is answered with a ``guest_`` user."""
+        response = self._ask(
+            {
+                "question": "测试guest用户",
+                "session_id": self.test_session_id + "_guest"
+            },
+            "Guest用户测试失败",
+        )
+
+        if response.status_code == 200:
+            payload = response.json()['data']
+            user_id = payload['user_id']
+            is_guest = payload.get('is_guest_user', False)
+
+            print(f"[TEST] 生成的用户ID: {user_id}")
+            print(f"[TEST] 是否为guest用户: {is_guest}")
+
+            self.assertTrue(user_id.startswith('guest_'))
+            self.assertTrue(is_guest)
+
+
+def run_selected_tests():
+    """Run a hand-picked subset of the integration tests with verbose output."""
+    selected = (
+        'test_api_availability',
+        'test_basic_ask_agent',
+        'test_conversation_context',
+        'test_invalid_conversation_id',
+        'test_conversation_api_endpoints',
+    )
+
+    suite = unittest.TestSuite()
+    for test_name in selected:
+        suite.addTest(TestAskAgentRedisIntegration(test_name))
+
+    unittest.TextTestRunner(verbosity=2).run(suite)
+
+
+if __name__ == '__main__':
+    divider = "=" * 60
+    print(divider)
+    print("ask_agent Redis集成测试")
+    print("注意: 需要先启动Flask应用 (python citu_app.py)")
+    print(divider)
+
+    # Runs every test in this module. Call run_selected_tests() instead to
+    # run only the hand-picked subset.
+    unittest.main()

+ 106 - 0
test/test_config_refactor.py

@@ -0,0 +1,106 @@
+#!/usr/bin/env python3
+"""
+测试配置重构是否成功
+"""
+
+def test_config_refactor():
+    """Check that ``app_config`` exposes the renamed settings only.
+
+    Prints one line per check and stops at the first failure.
+
+    Returns:
+        True when every check passes, False otherwise.
+    """
+    print("=== 配置重构测试 ===")
+
+    try:
+        import app_config
+    except ImportError as e:
+        print(f"✗ app_config 导入失败: {e}")
+        return False
+    else:
+        print("✓ app_config 导入成功")
+
+    # Names that must exist after the refactor.
+    print("\n--- 新配置检查 ---")
+    for config_name in (
+        'API_DEEPSEEK_CONFIG',
+        'API_QWEN_CONFIG',
+        'OLLAMA_EMBEDDING_CONFIG',
+        'API_LLM_MODEL',
+        'VECTOR_DB_TYPE',
+    ):
+        if not hasattr(app_config, config_name):
+            print(f"✗ {config_name} 不存在")
+            return False
+        print(f"✓ {config_name} 存在")
+
+    # Names that must be gone after the refactor.
+    print("\n--- 旧配置检查 ---")
+    for config_name in (
+        'DEEPSEEK_CONFIG',
+        'QWEN_CONFIG',
+        'EMBEDDING_OLLAMA_CONFIG',
+        'LLM_MODEL_NAME',
+        'VECTOR_DB_NAME',
+    ):
+        if hasattr(app_config, config_name):
+            print(f"✗ {config_name} 仍然存在(应该已删除)")
+            return False
+        print(f"✓ {config_name} 已删除")
+
+    # The helper functions must import and run.
+    print("\n--- Utils函数测试 ---")
+    try:
+        from common.utils import get_current_llm_config, get_current_embedding_config
+
+        llm_config = get_current_llm_config()
+        print(f"✓ get_current_llm_config() 成功,返回类型: {type(llm_config)}")
+
+        embedding_config = get_current_embedding_config()
+        print(f"✓ get_current_embedding_config() 成功,返回类型: {type(embedding_config)}")
+
+    except Exception as e:
+        print(f"✗ Utils函数测试失败: {e}")
+        return False
+
+    # Each config must contain its required keys.
+    print("\n--- 配置内容验证 ---")
+    required_keys = (
+        ('API_QWEN_CONFIG', ('model', 'api_key')),
+        ('API_DEEPSEEK_CONFIG', ('model', 'api_key')),
+        ('OLLAMA_EMBEDDING_CONFIG', ('model_name', 'base_url')),
+    )
+    try:
+        for config_name, keys in required_keys:
+            config = getattr(app_config, config_name)
+            if not all(key in config for key in keys):
+                print(f"✗ {config_name} 结构不正确")
+                return False
+            print(f"✓ {config_name} 结构正确")
+
+    except Exception as e:
+        print(f"✗ 配置内容验证失败: {e}")
+        return False
+
+    print("\n=== 配置重构测试完成 ===")
+    print("✓ 所有测试通过!配置重构成功!")
+    return True
+
+if __name__ == "__main__":
+    # Report failure to the calling shell with a non-zero exit status.
+    # SystemExit is raised directly because the ``exit`` helper is injected
+    # by the ``site`` module and is not guaranteed to exist (for example
+    # under ``python -S``).
+    if not test_config_refactor():
+        raise SystemExit(1)

+ 128 - 0
test/test_config_utils.py

@@ -0,0 +1,128 @@
+#!/usr/bin/env python3
+"""
+测试配置工具函数的脚本
+用于验证common/utils.py中的函数是否正常工作
+"""
+
+def test_config_utils():
+    """Call every helper in ``common.utils`` and print what it returns.
+
+    Secrets are masked before printing. Any exception is caught, reported
+    with a traceback, and not re-raised.
+    """
+    def _print_items(config, secret_key, mask):
+        # Print key/value pairs, masking the non-empty value of ``secret_key``.
+        for key, value in config.items():
+            shown = mask(value) if key == secret_key and value else value
+            print(f"{key}: {shown}")
+        print()
+
+    def _mask_api_key(value):
+        # Keep only the last four characters visible.
+        return f"{'*' * 8}...{value[-4:]}"
+
+    def _mask_password(value):
+        return f"{'*' * 8}"
+
+    try:
+        from common.utils import (
+            get_current_embedding_config,
+            get_current_llm_config,
+            get_current_vector_db_config,
+            get_current_model_info,
+            is_using_ollama_llm,
+            is_using_ollama_embedding,
+            is_using_api_llm,
+            is_using_api_embedding,
+            print_current_config
+        )
+
+        print("=== 测试配置工具函数 ===")
+
+        # Model-type predicates.
+        print(f"使用Ollama LLM: {is_using_ollama_llm()}")
+        print(f"使用Ollama Embedding: {is_using_ollama_embedding()}")
+        print(f"使用API LLM: {is_using_api_llm()}")
+        print(f"使用API Embedding: {is_using_api_embedding()}")
+        print()
+
+        # Configuration getters.
+        print("=== LLM配置 ===")
+        _print_items(get_current_llm_config(), "api_key", _mask_api_key)
+
+        print("=== Embedding配置 ===")
+        _print_items(get_current_embedding_config(), "api_key", _mask_api_key)
+
+        print("=== 向量数据库配置 ===")
+        _print_items(get_current_vector_db_config(), "password", _mask_password)
+
+        # Model summary.
+        print("=== 模型信息摘要 ===")
+        for key, value in get_current_model_info().items():
+            print(f"{key}: {value}")
+        print()
+
+        print_current_config()
+
+        print("✅ 所有配置工具函数测试通过!")
+
+    except Exception as e:
+        print(f"❌ 测试失败: {e}")
+        import traceback
+        traceback.print_exc()
+
+def test_different_configurations():
+    """Print the effective configuration for each LLM/embedding combination.
+
+    The ``app_config`` module attributes are changed in place for each
+    combination and restored afterwards, even if printing fails.
+    """
+    import app_config
+
+    print("\n=== 测试不同配置组合 ===")
+
+    # Remember the current values so they can be restored at the end.
+    original_llm_type = app_config.LLM_MODEL_TYPE
+    original_embedding_type = app_config.EMBEDDING_MODEL_TYPE
+    original_llm_name = app_config.LLM_MODEL_NAME
+
+    # Import before entering the try/finally below: the ``finally`` clause
+    # calls print_current_config, which would be unbound (UnboundLocalError)
+    # if this import failed inside that ``try``.
+    try:
+        from common.utils import print_current_config
+    except Exception as e:
+        print(f"❌ 配置测试失败: {e}")
+        return
+
+    # (label, LLM type, embedding type, LLM name or None to leave unchanged)
+    combinations = (
+        ("配置1:API LLM + API Embedding", "api", "api", "qwen"),
+        ("配置2:API LLM + Ollama Embedding", "api", "ollama", "deepseek"),
+        ("配置3:Ollama LLM + API Embedding", "ollama", "api", None),
+        ("配置4:Ollama LLM + Ollama Embedding", "ollama", "ollama", None),
+    )
+
+    try:
+        for label, llm_type, embedding_type, llm_name in combinations:
+            print(f"\n--- {label} ---")
+            app_config.LLM_MODEL_TYPE = llm_type
+            app_config.EMBEDDING_MODEL_TYPE = embedding_type
+            if llm_name is not None:
+                app_config.LLM_MODEL_NAME = llm_name
+            print_current_config()
+
+    except Exception as e:
+        print(f"❌ 配置测试失败: {e}")
+    finally:
+        # Restore the values captured above.
+        app_config.LLM_MODEL_TYPE = original_llm_type
+        app_config.EMBEDDING_MODEL_TYPE = original_embedding_type
+        app_config.LLM_MODEL_NAME = original_llm_name
+        print("\n--- 恢复原始配置 ---")
+        print_current_config()
+
+if __name__ == "__main__":
+    # Run both checks in order when executed as a script.
+    for run_check in (test_config_utils, test_different_configurations):
+        run_check()

+ 225 - 0
test/test_ollama_integration.py

@@ -0,0 +1,225 @@
+#!/usr/bin/env python3
+"""
+测试Ollama集成功能的脚本
+用于验证Ollama LLM和Embedding是否正常工作
+"""
+
+def test_ollama_llm():
+    """Connect to a local Ollama chat model and request one short reply.
+
+    Returns:
+        True when the connection test and the chat call both complete,
+        False when the connection test reports failure or anything raises.
+    """
+    print("=== 测试Ollama LLM ===")
+
+    try:
+        from customollama.ollama_chat import OllamaChat
+
+        ollama_chat = OllamaChat(config={
+            "base_url": "http://localhost:11434",
+            "model": "qwen2.5:7b",
+            "temperature": 0.7,
+            "timeout": 60
+        })
+
+        print("测试Ollama连接...")
+        test_result = ollama_chat.test_connection()
+
+        if not test_result["success"]:
+            print(f"❌ Ollama LLM连接失败: {test_result['message']}")
+            return False
+        print(f"✅ Ollama LLM连接成功: {test_result['message']}")
+
+        print("\n测试简单对话...")
+        response = ollama_chat.chat_with_llm("你好,请简单介绍一下你自己")
+        print(f"LLM响应: {response}")
+
+        return True
+
+    except Exception as e:
+        print(f"❌ Ollama LLM测试失败: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+def test_ollama_embedding():
+    """Connect to a local Ollama embedding model and embed two sample texts.
+
+    Returns:
+        True when the connection test and the embedding call both complete,
+        False when the connection test reports failure or anything raises.
+    """
+    print("\n=== 测试Ollama Embedding ===")
+
+    try:
+        from customollama.ollama_embedding import OllamaEmbeddingFunction
+
+        embedding_func = OllamaEmbeddingFunction(
+            model_name="nomic-embed-text",
+            base_url="http://localhost:11434",
+            embedding_dimension=768
+        )
+
+        print("测试Ollama Embedding连接...")
+        test_result = embedding_func.test_connection()
+
+        if not test_result["success"]:
+            print(f"❌ Ollama Embedding连接失败: {test_result['message']}")
+            return False
+        print(f"✅ Ollama Embedding连接成功: {test_result['message']}")
+
+        print("\n测试生成embedding...")
+        embeddings = embedding_func(["这是一个测试文本", "另一个测试文本"])
+
+        print(f"生成了 {len(embeddings)} 个embedding向量")
+        for number, emb in enumerate(embeddings, start=1):
+            print(f"文本 {number} 的embedding维度: {len(emb)}")
+
+        return True
+
+    except Exception as e:
+        print(f"❌ Ollama Embedding测试失败: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+def test_ollama_with_config():
+    """Switch ``app_config`` to Ollama mode and build the dependent objects.
+
+    The two model-type settings are restored afterwards. A failure of the
+    Vanna factory is only reported as a warning and still counts as a pass.
+
+    Returns:
+        True unless an exception is raised outside the factory step.
+    """
+    print("\n=== 测试配置文件中的Ollama设置 ===")
+
+    try:
+        import app_config
+        from common.utils import print_current_config, is_using_ollama_llm, is_using_ollama_embedding
+
+        saved_llm_type = app_config.LLM_MODEL_TYPE
+        saved_embedding_type = app_config.EMBEDDING_MODEL_TYPE
+
+        try:
+            app_config.LLM_MODEL_TYPE = "ollama"
+            app_config.EMBEDDING_MODEL_TYPE = "ollama"
+
+            print("当前配置:")
+            print_current_config()
+
+            print(f"\n使用Ollama LLM: {is_using_ollama_llm()}")
+            print(f"使用Ollama Embedding: {is_using_ollama_embedding()}")
+
+            print("\n测试通过配置获取embedding函数...")
+            from embedding_function import get_embedding_function
+
+            embedding_func = get_embedding_function()
+            print(f"成功创建embedding函数: {type(embedding_func).__name__}")
+
+            print("\n测试工厂函数...")
+            try:
+                from vanna_llm_factory import create_vanna_instance
+                vn = create_vanna_instance()
+                print(f"✅ 成功创建Vanna实例: {type(vn).__name__}")
+            except Exception as e:
+                # Tolerated: the Ollama service may simply not be running.
+                print(f"⚠️  工厂函数测试失败(可能是Ollama服务未启动): {e}")
+            return True
+
+        finally:
+            app_config.LLM_MODEL_TYPE = saved_llm_type
+            app_config.EMBEDDING_MODEL_TYPE = saved_embedding_type
+
+    except Exception as e:
+        print(f"❌ 配置测试失败: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+def test_mixed_configurations():
+    """Build the embedding function under both mixed API/Ollama set-ups.
+
+    The two model-type settings in ``app_config`` are restored afterwards.
+
+    Returns:
+        True when both combinations complete, False when anything raises.
+    """
+    print("\n=== 测试混合配置 ===")
+
+    try:
+        import app_config
+        from common.utils import print_current_config
+
+        saved_llm_type = app_config.LLM_MODEL_TYPE
+        saved_embedding_type = app_config.EMBEDDING_MODEL_TYPE
+
+        # (label, LLM type, embedding type)
+        combinations = (
+            ("API LLM + Ollama Embedding", "api", "ollama"),
+            ("Ollama LLM + API Embedding", "ollama", "api"),
+        )
+
+        try:
+            for label, llm_type, embedding_type in combinations:
+                print(f"\n--- 测试: {label} ---")
+                app_config.LLM_MODEL_TYPE = llm_type
+                app_config.EMBEDDING_MODEL_TYPE = embedding_type
+                print_current_config()
+
+                # Imported only after the first combination has been printed;
+                # on the second pass this is a cached no-op.
+                from embedding_function import get_embedding_function
+                embedding_func = get_embedding_function()
+                print(f"Embedding函数类型: {type(embedding_func).__name__}")
+
+            print("✅ 混合配置测试通过")
+            return True
+
+        finally:
+            app_config.LLM_MODEL_TYPE = saved_llm_type
+            app_config.EMBEDDING_MODEL_TYPE = saved_embedding_type
+
+    except Exception as e:
+        print(f"❌ 混合配置测试失败: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+def main():
+    """Run the four Ollama checks in order and print a pass/fail summary."""
+    divider = "=" * 60
+
+    print("开始测试Ollama集成功能...")
+    print("注意: 这些测试需要Ollama服务运行在 http://localhost:11434")
+    print(divider)
+
+    # Configuration-level checks come first.
+    results = [
+        ("配置文件测试", test_ollama_with_config()),
+        ("混合配置测试", test_mixed_configurations()),
+    ]
+
+    # The remaining checks talk to the Ollama service itself.
+    print(f"\n{divider}")
+    print("以下测试需要Ollama服务运行,如果失败可能是服务未启动")
+    print(divider)
+
+    results.append(("Ollama LLM", test_ollama_llm()))
+    results.append(("Ollama Embedding", test_ollama_embedding()))
+
+    print(f"\n{divider}")
+    print("测试结果总结:")
+    print(divider)
+
+    total_passed = 0
+    for test_name, success in results:
+        if success:
+            total_passed += 1
+            status = "✅ 通过"
+        else:
+            status = "❌ 失败"
+        print(f"{test_name}: {status}")
+
+    print(f"\n总计: {total_passed}/{len(results)} 个测试通过")
+
+    if total_passed != len(results):
+        print("⚠️  部分测试失败,请检查Ollama服务是否正常运行。")
+    else:
+        print("🎉 所有测试都通过了!Ollama集成功能正常。")
+
+# Run the full Ollama check sequence when executed as a script.
+if __name__ == "__main__":
+    main() 

+ 283 - 0
test/test_redis_conversation_manager.py

@@ -0,0 +1,283 @@
+import unittest
+import sys
+import os
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from common.redis_conversation_manager import RedisConversationManager
+from datetime import datetime
+import time
+
class TestRedisConversationManager(unittest.TestCase):
    """Unit tests for ``RedisConversationManager``.

    Tests that need a live Redis server skip themselves when
    ``manager.is_available()`` reports false.
    """

    def setUp(self):
        """Create a manager and the user ids shared by the tests."""
        self.manager = RedisConversationManager()
        self.test_user_id = "test_user_123"
        self.test_guest_id = "guest_test_456"

    def _purge_user(self, user_id):
        """Delete the conversation keys the tests created for *user_id*."""
        for conv in self.manager.get_conversations(user_id):
            conv_id = conv.get('conversation_id')
            if conv_id:
                self.manager.redis_client.delete(f"conversation:{conv_id}:meta")
                self.manager.redis_client.delete(f"conversation:{conv_id}:messages")
        self.manager.redis_client.delete(f"user:{user_id}:conversations")

    def tearDown(self):
        """Best-effort removal of the data created for both test users.

        Each user is cleaned independently so a failure on one does not
        skip the other. Errors are reported rather than silently dropped,
        and only ``Exception`` is caught so Ctrl-C still interrupts.
        """
        if not self.manager.is_available():
            return
        for user_id in (self.test_user_id, self.test_guest_id):
            try:
                self._purge_user(user_id)
            except Exception as exc:
                print(f"[TEST] cleanup failed for {user_id}: {exc}")

    def test_redis_connection(self):
        """Report Redis availability; skip when it is unavailable."""
        is_available = self.manager.is_available()
        print(f"[TEST] Redis可用状态: {is_available}")
        if not is_available:
            self.skipTest("Redis不可用,跳过测试")

    def test_user_id_resolution(self):
        """Check the precedence used by ``resolve_user_id``."""
        # The logged-in user id wins over everything else.
        user_id = self.manager.resolve_user_id(
            "request_user", "session_123", "127.0.0.1", "login_user"
        )
        self.assertEqual(user_id, "login_user")

        # Without a login user, the id from the request is used.
        user_id = self.manager.resolve_user_id(
            "request_user", "session_123", "127.0.0.1", None
        )
        self.assertEqual(user_id, "request_user")

        # With only a session id, a guest id is generated.
        user_id = self.manager.resolve_user_id(
            None, "session_123", "127.0.0.1", None
        )
        self.assertTrue(user_id.startswith("guest_"))

        # With only an IP address, a temporary guest id is generated.
        user_id = self.manager.resolve_user_id(
            None, None, "127.0.0.1", None
        )
        self.assertTrue(user_id.startswith("guest_temp_"))

    def test_conversation_creation(self):
        """Create a conversation and verify its id format and metadata."""
        if not self.manager.is_available():
            self.skipTest("Redis不可用")

        conv_id = self.manager.create_conversation(self.test_user_id)
        print(f"[TEST] 创建的对话ID: {conv_id}")

        # Conversation id format.
        self.assertTrue(conv_id.startswith("conv_"))
        self.assertIn("_", conv_id)

        # Conversation metadata.
        meta = self.manager.get_conversation_meta(conv_id)
        self.assertEqual(meta.get('user_id'), self.test_user_id)
        self.assertEqual(meta.get('conversation_id'), conv_id)
        self.assertIn('created_at', meta)

    def test_message_saving_and_retrieval(self):
        """Save two messages and read them back in chronological order."""
        if not self.manager.is_available():
            self.skipTest("Redis不可用")

        conv_id = self.manager.create_conversation(self.test_user_id)

        self.manager.save_message(conv_id, "user", "测试问题")
        self.manager.save_message(conv_id, "assistant", "测试回答")

        messages = self.manager.get_conversation_messages(conv_id)
        self.assertEqual(len(messages), 2)

        # Oldest message first.
        self.assertEqual(messages[0]['role'], 'user')
        self.assertEqual(messages[0]['content'], '测试问题')
        self.assertEqual(messages[1]['role'], 'assistant')
        self.assertEqual(messages[1]['content'], '测试回答')

    def test_context_generation(self):
        """Build a context string from a four-message conversation."""
        if not self.manager.is_available():
            self.skipTest("Redis不可用")

        conv_id = self.manager.create_conversation(self.test_user_id)

        self.manager.save_message(conv_id, "user", "问题1")
        self.manager.save_message(conv_id, "assistant", "回答1")
        self.manager.save_message(conv_id, "user", "问题2")
        self.manager.save_message(conv_id, "assistant", "回答2")

        context = self.manager.get_context(conv_id, count=2)
        print(f"[TEST] 生成的上下文:\n{context}")

        # Every turn must appear with its role prefix.
        self.assertIn("用户: 问题1", context)
        self.assertIn("助手: 回答1", context)
        self.assertIn("用户: 问题2", context)
        self.assertIn("助手: 回答2", context)

    def test_conversation_list(self):
        """List a user's conversations and expect newest-first order."""
        if not self.manager.is_available():
            self.skipTest("Redis不可用")

        conv_ids = []
        for _ in range(3):
            conv_ids.append(self.manager.create_conversation(self.test_user_id))
            time.sleep(0.1)  # make sure the timestamps differ

        conversations = self.manager.get_conversations(self.test_user_id)
        self.assertEqual(len(conversations), 3)

        # Newest conversation comes first.
        self.assertEqual(conversations[0]['conversation_id'], conv_ids[2])
        self.assertEqual(conversations[1]['conversation_id'], conv_ids[1])
        self.assertEqual(conversations[2]['conversation_id'], conv_ids[0])

    def test_cache_functionality(self):
        """Exercise cache miss, cache hit, and context-sensitive lookup."""
        if not self.manager.is_available():
            self.skipTest("Redis不可用")

        question = "测试缓存问题"
        context = "用户: 之前的问题\n助手: 之前的回答"

        # Nothing cached yet.
        cached = self.manager.get_cached_answer(question, context)
        self.assertIsNone(cached)

        # Store an answer.
        answer = {
            "success": True,
            "data": {
                "response": "测试答案",
                "type": "CHAT"
            }
        }
        self.manager.cache_answer(question, answer, context)

        # Same question and context now hits the cache.
        cached = self.manager.get_cached_answer(question, context)
        self.assertIsNotNone(cached)
        self.assertEqual(cached['data']['response'], '测试答案')

        # A different context must miss.
        different_context = "用户: 不同的问题\n助手: 不同的回答"
        cached = self.manager.get_cached_answer(question, different_context)
        self.assertIsNone(cached)

    def test_conversation_id_resolution(self):
        """Check the new / existing / invalid-id paths of the resolver."""
        if not self.manager.is_available():
            self.skipTest("Redis不可用")

        # No id supplied: a new conversation is created.
        conv_id, status = self.manager.resolve_conversation_id(
            self.test_user_id, None, False
        )
        self.assertTrue(conv_id.startswith("conv_"))
        self.assertEqual(status['status'], 'new')

        # Existing id: it is reused.
        conv_id2, status2 = self.manager.resolve_conversation_id(
            self.test_user_id, conv_id, False
        )
        self.assertEqual(conv_id2, conv_id)
        self.assertEqual(status2['status'], 'existing')

        # Unknown id: a new conversation replaces it and the request is echoed.
        conv_id3, status3 = self.manager.resolve_conversation_id(
            self.test_user_id, "invalid_conv_id", False
        )
        self.assertNotEqual(conv_id3, "invalid_conv_id")
        self.assertEqual(status3['status'], 'invalid_id_new')
        self.assertEqual(status3['requested_id'], 'invalid_conv_id')

    def test_statistics(self):
        """Verify the keys reported by ``get_stats``."""
        if not self.manager.is_available():
            self.skipTest("Redis不可用")

        conv_id = self.manager.create_conversation(self.test_user_id)
        self.manager.save_message(conv_id, "user", "统计测试")

        stats = self.manager.get_stats()
        print(f"[TEST] 统计信息: {stats}")

        self.assertTrue(stats['available'])
        self.assertIn('total_users', stats)
        self.assertIn('total_conversations', stats)
        self.assertIn('cached_qa_count', stats)

    def test_guest_user_limit(self):
        """Guest users keep only the newest MAX_GUEST_CONVERSATIONS."""
        if not self.manager.is_available():
            self.skipTest("Redis不可用")

        from app_config import MAX_GUEST_CONVERSATIONS

        # Create more conversations than a guest is allowed to keep.
        conv_ids = []
        for _ in range(MAX_GUEST_CONVERSATIONS + 2):
            conv_ids.append(self.manager.create_conversation(self.test_guest_id))
            time.sleep(0.05)

        conversations = self.manager.get_conversations(self.test_guest_id)
        self.assertEqual(len(conversations), MAX_GUEST_CONVERSATIONS)

        # The retained conversations are the most recently created ones.
        retained_ids = [conv['conversation_id'] for conv in conversations]
        for i in range(MAX_GUEST_CONVERSATIONS):
            self.assertIn(conv_ids[-(i + 1)], retained_ids)

    def test_cleanup_functionality(self):
        """Expired conversations are dropped from the user's list."""
        if not self.manager.is_available():
            self.skipTest("Redis不可用")

        conv_id = self.manager.create_conversation(self.test_user_id)

        # Delete the metadata key by hand to simulate expiry.
        self.manager.redis_client.delete(f"conversation:{conv_id}:meta")

        self.manager.cleanup_expired_conversations()

        conversations = self.manager.get_conversations(self.test_user_id)
        conv_ids = [conv['conversation_id'] for conv in conversations]
        self.assertNotIn(conv_id, conv_ids)


if __name__ == '__main__':
    unittest.main()

+ 111 - 0
test/test_redis_fix_validation.py

@@ -0,0 +1,111 @@
+"""
+Redis集成修复验证测试
+
+这个脚本用于快速验证Redis集成的修复是否有效
+"""
+
+import requests
+import json
+import time
+
def test_ask_agent_basic():
    """Smoke-test the ``ask_agent`` endpoint and its answer cache.

    Posts the same question twice to ``http://localhost:8084/api/v0``,
    reports whether the second answer is flagged ``from_cache``, and then
    queries the ``conversation_stats`` endpoint.

    Returns:
        dict with the entries ``first_request_success``,
        ``second_request_success`` and ``cache_working``.
    """
    base_url = "http://localhost:8084/api/v0"
    ask_url = f"{base_url}/ask_agent"

    print("=== Redis集成修复验证测试 ===\n")

    # Step 1: first request (expected to succeed).
    print("1. 测试第一次请求...")
    print("   (注意:第一次请求可能需要较长时间,请耐心等待...)")
    first_resp = requests.post(
        ask_url,
        json={"question": "服务区有多少个?"},
        timeout=120,  # generous timeout for a slow first response
    )

    print(f"   状态码: {first_resp.status_code}")
    first = first_resp.json()
    print(f"   成功: {first.get('success')}")
    print(f"   消息: {first.get('message')}")

    if not first.get('success'):
        print(f"   错误: {json.dumps(first, indent=2, ensure_ascii=False)}")
    else:
        payload = first.get('data', {})
        print(f"   响应类型: {payload.get('type')}")
        print(f"   响应文本: {payload.get('response_text', '')[:50]}...")
        print(f"   是否缓存: {payload.get('from_cache', False)}")
        print(f"   对话ID: {payload.get('conversation_id')}")

    # Short pause between the two requests.
    time.sleep(1)

    # Step 2: identical request (expected to be served from the cache).
    print("\n2. 测试第二次请求(相同问题,应该使用缓存)...")
    second_resp = requests.post(
        ask_url,
        json={"question": "服务区有多少个?"},
        timeout=60,  # still generous, although a cache hit should be fast
    )

    print(f"   状态码: {second_resp.status_code}")
    second = second_resp.json()
    print(f"   成功: {second.get('success')}")

    if not second.get('success'):
        print(f"   错误: {json.dumps(second, indent=2, ensure_ascii=False)}")
        print("\n❌ 第二次请求失败,可能是缓存格式问题")
    else:
        payload = second.get('data', {})
        print(f"   是否缓存: {payload.get('from_cache', False)}")
        print(f"   响应文本: {payload.get('response_text', '')[:50]}...")

        # Report whether the cache was used.
        if payload.get('from_cache'):
            print("\n✅ 缓存功能正常工作!")
        else:
            print("\n⚠️ 缓存功能可能有问题,第二次请求没有使用缓存")

    # Step 3: conversation management API.
    print("\n3. 测试对话管理API...")
    try:
        stats_resp = requests.get(f"{base_url}/conversation_stats", timeout=5)
        if stats_resp.status_code != 200:
            print(f"   ❌ 对话统计API错误: {stats_resp.status_code}")
        else:
            stats = stats_resp.json()
            if not stats.get('success'):
                print("   ⚠️ 对话统计API返回失败")
            else:
                print("   ✅ 对话统计API正常")
                print(f"   总对话数: {stats.get('data', {}).get('total_conversations', 0)}")
                print(f"   总用户数: {stats.get('data', {}).get('total_users', 0)}")
    except Exception as e:
        print(f"   ❌ 对话统计API异常: {str(e)}")

    print("\n=== 测试完成 ===")

    # The cache only counts as working when the second request succeeded.
    cache_working = False
    if second.get('success'):
        cache_working = second.get('data', {}).get('from_cache', False)

    return {
        "first_request_success": first.get('success', False),
        "second_request_success": second.get('success', False),
        "cache_working": cache_working,
    }
+
if __name__ == "__main__":
    # Script entry point: run the check and print a three-line summary.
    try:
        results = test_ask_agent_basic()

        first_mark = '✅ 成功' if results['first_request_success'] else '❌ 失败'
        second_mark = '✅ 成功' if results['second_request_success'] else '❌ 失败'
        cache_mark = '✅ 正常' if results['cache_working'] else '❌ 异常'

        print("\n测试结果汇总:")
        print(f"- 第一次请求: {first_mark}")
        print(f"- 第二次请求: {second_mark}")
        print(f"- 缓存功能: {cache_mark}")

        if not all(results.values()):
            print("\n❗ 部分测试失败,请检查日志")
        else:
            print("\n🎉 所有测试通过!Redis集成修复成功!")

    except Exception as e:
        # Typically raised when the Flask service is not reachable.
        print(f"\n❌ 测试异常: {str(e)}")
        print("请确保Flask服务正在运行 (python citu_app.py)")

+ 94 - 0
test/test_routing_modes.py

@@ -0,0 +1,94 @@
+# test_routing_modes.py - 测试不同路由模式的功能
+
+import sys
+import os
+# 添加项目根目录到sys.path,以便导入app_config.py
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
def test_routing_modes():
    """Check routing-mode configuration and classifier behaviour.

    Steps:
      1. import ``app_config`` and print the current routing mode;
      2. classify a few sample questions under every routing mode;
      3. import the agent state definition;
      4. import the agent class (import only, no instance is created).

    ``app_config.QUESTION_ROUTING_MODE`` is changed temporarily in step 2
    and is always restored, even if the loop is interrupted.

    Returns:
        bool: ``True`` when every step completed, ``False`` when a required
        import failed or the classifier step raised.
    """
    print("=== 路由模式测试 ===")

    # 1. Configuration.
    try:
        import app_config
        print("✓ 配置导入成功")
        print(f"当前路由模式: {getattr(app_config, 'QUESTION_ROUTING_MODE', '未找到')}")
    except ImportError as e:
        print(f"✗ 配置导入失败: {e}")
        return False

    # 2. Classifier.
    try:
        from agent.classifier import QuestionClassifier, ClassificationResult
        classifier = QuestionClassifier()
        print("✓ 分类器创建成功")

        test_questions = [
            "查询本月服务区营业额",
            "你好,请介绍一下平台功能",
            "请问负责每个服务区的经理的名字是什么?"
        ]

        # Temporarily override the routing mode for the duration of the loop.
        original_mode = getattr(app_config, 'QUESTION_ROUTING_MODE', 'hybrid')
        try:
            for mode in ["hybrid", "llm_only", "database_direct", "chat_direct"]:
                print(f"\n--- 测试路由模式: {mode} ---")
                app_config.QUESTION_ROUTING_MODE = mode

                for question in test_questions:
                    try:
                        result = classifier.classify(question)
                        print(f"问题: {question}")
                        print(f"  分类: {result.question_type}")
                        print(f"  置信度: {result.confidence}")
                        print(f"  方法: {result.method}")
                        print(f"  理由: {result.reason[:50]}...")
                    except Exception as e:
                        print(f"  分类异常: {e}")
        finally:
            # Restore the shared configuration no matter how the loop ended,
            # so later code in the same process sees the original mode.
            app_config.QUESTION_ROUTING_MODE = original_mode
        print("\n✓ 分类器测试完成")

    except ImportError as e:
        print(f"✗ 分类器导入失败: {e}")
        return False
    except Exception as e:
        print(f"✗ 分类器测试异常: {e}")
        return False

    # 3. Agent state.
    try:
        from agent.state import AgentState
        print("✓ Agent状态定义正确")
    except ImportError as e:
        print(f"✗ Agent状态导入失败: {e}")
        return False

    # 4. Agent workflow class (import only; creating an instance may need
    #    an LLM connection and other heavy dependencies).
    try:
        from agent.citu_agent import CituLangGraphAgent
        print("✓ Agent类导入成功")
    except ImportError as e:
        print(f"✗ Agent类导入失败: {e}")
        return False
    except Exception as e:
        print(f"警告: Agent相关模块可能有依赖问题: {e}")

    print("\n=== 路由模式测试完成 ===")
    return True


if __name__ == "__main__":
    success = test_routing_modes()
    if success:
        print("✓ 所有测试通过!路由模式功能实现成功!")
    else:
        print("✗ 测试失败,请检查实现。")

+ 146 - 0
test/test_thinking_control.py

@@ -0,0 +1,146 @@
+#!/usr/bin/env python3
+"""
+测试thinking内容控制功能
+验证DISPLAY_RESULT_THINKING参数是否正确控制thinking内容的显示/隐藏
+"""
+
+import sys
+import os
+
# Add the project root to the Python path. This file lives in the test/
# directory, so the root is the parent of this file's directory.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
def test_thinking_removal():
    """Run ``_remove_thinking_content`` over a table of sample inputs.

    Prints one pass/fail line per sample; for a failure the input, the
    expected value and the actual value are printed as well.
    """
    from customllm.base_llm_chat import BaseLLMChat

    # Minimal concrete subclass so the helper method can be called.
    class TestLLM(BaseLLMChat):
        def submit_prompt(self, prompt, **kwargs):
            return "测试响应"

    llm = TestLLM(config={})

    # (input, expected) pairs.
    samples = [
        # Basic thinking tag
        ("<think>这是思考内容</think>这是最终答案", "这是最终答案"),
        # Multi-line thinking tag
        ("<think>\n这是多行\n思考内容\n</think>\n\n这是最终答案", "这是最终答案"),
        # Upper-case tag
        ("<THINK>大写思考</THINK>最终答案", "最终答案"),
        # Several thinking tags
        ("<think>第一段思考</think>中间内容<think>第二段思考</think>最终答案", "中间内容最终答案"),
        # No thinking tag at all
        ("这是没有thinking标签的普通文本", "这是没有thinking标签的普通文本"),
        # Empty text
        ("", ""),
        # None input
        (None, None),
    ]

    print("=== 测试thinking内容移除功能 ===")

    for number, (raw, wanted) in enumerate(samples, 1):
        actual = llm._remove_thinking_content(raw)

        if actual != wanted:
            print(f"❌ 测试用例 {number}: 失败")
            print(f"   输入: {raw!r}")
            print(f"   期望: {wanted!r}")
            print(f"   实际: {actual!r}")
        else:
            print(f"✅ 测试用例 {number}: 通过")

    print()
+
def test_config_integration():
    """Check that ``DISPLAY_RESULT_THINKING`` is importable and visible.

    Imports the setting from ``app_config``, imports ``BaseLLMChat``, and
    reports whether the ``customllm.base_llm_chat`` module exposes the same
    setting name. Import failures are printed, not raised.
    """
    print("=== 测试配置集成 ===")

    try:
        from app_config import DISPLAY_RESULT_THINKING
        print(f"✅ 成功导入配置: DISPLAY_RESULT_THINKING = {DISPLAY_RESULT_THINKING}")

        from customllm.base_llm_chat import BaseLLMChat
        print("✅ 成功导入BaseLLMChat类")

        # Does the module itself carry the setting?
        import customllm.base_llm_chat as base_module
        absent = object()
        module_value = getattr(base_module, 'DISPLAY_RESULT_THINKING', absent)
        if module_value is absent:
            print("❌ BaseLLMChat模块中未找到DISPLAY_RESULT_THINKING配置")
        else:
            print(f"✅ BaseLLMChat模块中的配置: DISPLAY_RESULT_THINKING = {module_value}")

    except ImportError as e:
        print(f"❌ 导入失败: {e}")

    print()
+
def test_vanna_instance():
    """Check thinking-content removal on the shared Vanna instance.

    Obtains the instance via ``get_vanna_instance`` and, when it has a
    ``_remove_thinking_content`` method, runs it on one sample string.
    Any exception is printed, not raised.
    """
    print("=== 测试Vanna实例thinking处理 ===")

    try:
        from common.vanna_instance import get_vanna_instance
        vn = get_vanna_instance()

        print(f"✅ 成功获取Vanna实例: {type(vn).__name__}")

        if not hasattr(vn, '_remove_thinking_content'):
            print("❌ Vanna实例缺少_remove_thinking_content方法")
        else:
            print("✅ Vanna实例具有_remove_thinking_content方法")

            # Exercise the method on a small sample.
            sample = "<think>测试思考</think>测试结果"
            cleaned = vn._remove_thinking_content(sample)
            if cleaned != "测试结果":
                print(f"❌ thinking内容移除异常: {cleaned!r}")
            else:
                print("✅ thinking内容移除功能正常工作")

    except Exception as e:
        print(f"❌ 测试Vanna实例失败: {e}")

    print()
+
def main():
    """Run the three thinking-control checks in order."""
    print("开始测试thinking内容控制功能...\n")

    checks = (
        test_thinking_removal,
        test_config_integration,
        test_vanna_instance,
    )
    for check in checks:
        check()

    print("测试完成!")


if __name__ == "__main__":
    main()

+ 294 - 0
test/test_training_integration.py

@@ -0,0 +1,294 @@
+#!/usr/bin/env python3
+"""
+测试training目录的代码集成
+验证训练相关的模块是否能正常工作
+"""
+
def test_training_imports():
    """Verify that the training package and ``BatchProcessor`` import.

    Returns:
        bool: ``True`` when both imports succeed, ``False`` on ImportError
        (the traceback is printed).
    """
    print("=== 测试训练模块导入 ===")

    try:
        # Names re-exported by the training package.
        from training import (
            train_ddl,
            train_documentation,
            train_sql_example,
            train_question_sql_pair,
            flush_training,
            shutdown_trainer
        )
        print("✅ 成功从training包导入所有函数")

        # Direct import from the trainer module.
        from training.vanna_trainer import BatchProcessor
        print("✅ 成功导入BatchProcessor类")
    except ImportError as e:
        import traceback
        print(f"❌ 导入失败: {e}")
        traceback.print_exc()
        return False
    else:
        return True
+
def test_config_access():
    """Print the training-related settings found in ``app_config``.

    Reports the batch-processing settings, the vector database type (and
    the PgVector configuration when that type is selected), and tries the
    embedding/model helper functions from ``common.utils``.

    Returns:
        bool: ``True`` unless an unexpected exception occurred; a failure
        of the helper functions alone only produces a warning.
    """
    print("\n=== 测试配置访问 ===")

    try:
        import app_config

        # Batch-processing settings.
        batch_enabled = getattr(app_config, 'TRAINING_BATCH_PROCESSING_ENABLED', None)
        batch_size = getattr(app_config, 'TRAINING_BATCH_SIZE', None)
        max_workers = getattr(app_config, 'TRAINING_MAX_WORKERS', None)

        print(f"批处理启用: {batch_enabled}")
        print(f"批处理大小: {batch_size}")
        print(f"最大工作线程: {max_workers}")

        batch_settings = (batch_enabled, batch_size, max_workers)
        if all(value is not None for value in batch_settings):
            print("✅ 训练批处理配置正常")
        else:
            print("⚠️  部分训练批处理配置缺失")

        # Vector database settings.
        vector_db_name = getattr(app_config, 'VECTOR_DB_NAME', None)
        print(f"向量数据库类型: {vector_db_name}")

        if vector_db_name == "pgvector":
            if getattr(app_config, 'PGVECTOR_CONFIG', None):
                print("✅ PgVector配置存在")
            else:
                print("❌ PgVector配置缺失")

        # Configuration helper functions.
        try:
            from common.utils import get_current_embedding_config, get_current_model_info

            # The result is unused, but the call itself is part of the check.
            embedding_config = get_current_embedding_config()
            model_info = get_current_model_info()

            print(f"当前embedding类型: {model_info['embedding_type']}")
            print(f"当前embedding模型: {model_info['embedding_model']}")
            print("✅ 新配置工具函数正常工作")

        except Exception as e:
            print(f"⚠️  新配置工具函数测试失败: {e}")

    except Exception as e:
        import traceback
        print(f"❌ 配置访问测试失败: {e}")
        traceback.print_exc()
        return False

    return True
+
def test_vanna_instance_creation():
    """Create a Vanna instance through the factory and probe its methods.

    Returns:
        bool: ``True`` when the instance was created (missing methods only
        produce warnings), ``False`` when any exception was raised.
    """
    print("\n=== 测试Vanna实例创建 ===")

    try:
        from vanna_llm_factory import create_vanna_instance

        print("尝试创建Vanna实例...")
        vn = create_vanna_instance()

        print(f"✅ 成功创建Vanna实例: {type(vn).__name__}")

        # Report which of the expected methods are present.
        for method in ('train', 'generate_question', 'get_training_data'):
            if not hasattr(vn, method):
                print(f"⚠️  方法 {method} 不存在")
            else:
                print(f"✅ 方法 {method} 存在")

    except Exception as e:
        import traceback
        print(f"❌ Vanna实例创建失败: {e}")
        traceback.print_exc()
        return False

    return True
+
def test_batch_processor():
    """Create a ``BatchProcessor``, print its settings and shut it down.

    Batch size and worker count come from ``app_config`` with fallbacks of
    5 and 2.

    Returns:
        bool: ``True`` on success, ``False`` when any exception was raised.
    """
    print("\n=== 测试批处理器 ===")

    try:
        from training.vanna_trainer import BatchProcessor
        import app_config

        # Build the processor from configured (or fallback) values.
        processor = BatchProcessor(
            batch_size=getattr(app_config, 'TRAINING_BATCH_SIZE', 5),
            max_workers=getattr(app_config, 'TRAINING_MAX_WORKERS', 2),
        )
        print("✅ 成功创建BatchProcessor实例")
        print(f"   批处理大小: {processor.batch_size}")
        print(f"   最大工作线程: {processor.max_workers}")
        print(f"   批处理启用: {processor.batch_enabled}")

        # Shut the processor down again.
        processor.shutdown()
        print("✅ 批处理器关闭成功")

    except Exception as e:
        import traceback
        print(f"❌ 批处理器测试失败: {e}")
        traceback.print_exc()
        return False

    return True
+
def test_training_functions():
    """Import the training functions and check each one is callable.

    No training is performed; the functions are only inspected.

    Returns:
        bool: ``True`` when the imports succeeded, ``False`` when any
        exception was raised.
    """
    print("\n=== 测试训练函数 ===")

    try:
        from training import (
            train_ddl,
            train_documentation,
            train_sql_example,
            train_question_sql_pair,
            flush_training,
            shutdown_trainer
        )

        print("✅ 所有训练函数导入成功")

        # Name -> object, in the order they are reported.
        candidates = {
            'train_ddl': train_ddl,
            'train_documentation': train_documentation,
            'train_sql_example': train_sql_example,
            'train_question_sql_pair': train_question_sql_pair,
            'flush_training': flush_training,
            'shutdown_trainer': shutdown_trainer,
        }

        for func_name, func in candidates.items():
            if not callable(func):
                print(f"❌ {func_name} 不可调用")
            else:
                print(f"✅ {func_name} 是可调用的")

    except Exception as e:
        import traceback
        print(f"❌ 训练函数测试失败: {e}")
        traceback.print_exc()
        return False

    return True
+
def test_embedding_connection():
    """Call the embedding module's connection check and report its result.

    A failed connection is only reported as a warning.

    Returns:
        bool: ``True`` when the check ran, ``False`` when any exception
        was raised.
    """
    print("\n=== 测试Embedding连接 ===")

    try:
        # Aliased so the imported helper does not share this function's name.
        from embedding_function import test_embedding_connection as probe_connection

        print("测试embedding模型连接...")
        result = probe_connection()

        if not result["success"]:
            print(f"⚠️  Embedding连接失败: {result['message']}")
            print("   这可能是因为API服务未启动或配置不正确")
        else:
            print(f"✅ Embedding连接成功: {result['message']}")

    except Exception as e:
        import traceback
        print(f"❌ Embedding连接测试失败: {e}")
        traceback.print_exc()
        return False

    return True
+
def test_run_training_script():
    """Check that the run_training helpers import and split a file.

    Imports three helpers from ``training.run_training`` and runs
    ``read_file_by_delimiter`` on a small temporary file that contains
    three ``---``-separated sections.

    The sample file is created with ``tempfile.mkstemp`` so the check never
    overwrites or deletes a file in the current working directory.

    Returns:
        bool: ``True`` when the helpers imported and the file was read (a
        wrong section count only produces a warning), ``False`` when any
        exception was raised.
    """
    print("\n=== 测试run_training.py脚本 ===")

    try:
        import sys
        import os
        import tempfile

        # Add the project's training directory to the path. This file lives
        # in test/, so the project root is one level above its directory.
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        training_dir = os.path.join(project_root, 'training')
        if training_dir not in sys.path:
            sys.path.insert(0, training_dir)

        from training.run_training import (
            read_file_by_delimiter,
            read_markdown_file_by_sections,
            check_pgvector_connection
        )

        print("✅ 成功导入run_training模块的函数")

        # Exercise the delimiter-based reader on a private temporary file.
        test_content = "section1---section2---section3"
        fd, temp_path = tempfile.mkstemp(suffix=".txt", text=True)
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                f.write(test_content)

            sections = read_file_by_delimiter(temp_path, "---")
            if len(sections) == 3:
                print("✅ read_file_by_delimiter 函数正常工作")
            else:
                print(f"⚠️  read_file_by_delimiter 返回了 {len(sections)} 个部分,期望 3 个")
        finally:
            if os.path.exists(temp_path):
                os.remove(temp_path)

        return True

    except Exception as e:
        print(f"❌ run_training.py脚本测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False
+
def main():
    """Run every training-integration check and print a summary.

    The closing message depends on how many checks passed: all of them,
    all but one, or fewer.
    """
    divider = "=" * 60

    print("开始测试training目录的代码集成...")
    print(divider)

    # (label, check) pairs, executed in this order.
    checks = [
        ("训练模块导入", test_training_imports),
        ("配置访问", test_config_access),
        ("Vanna实例创建", test_vanna_instance_creation),
        ("批处理器", test_batch_processor),
        ("训练函数", test_training_functions),
        ("Embedding连接", test_embedding_connection),
        ("run_training脚本", test_run_training_script),
    ]
    outcomes = [(label, check()) for label, check in checks]

    # Summary.
    print(f"\n{divider}")
    print("测试结果总结:")
    print(divider)

    passed = 0
    for label, ok in outcomes:
        if ok:
            passed += 1
            verdict = "✅ 通过"
        else:
            verdict = "❌ 失败"
        print(f"{label}: {verdict}")

    total = len(outcomes)
    print(f"\n总计: {passed}/{total} 个测试通过")

    if passed == total:
        print("🎉 所有测试都通过了!training目录的代码可以正常工作。")
    elif passed >= total - 1:
        print("✅ 大部分测试通过,training目录的代码基本可以正常工作。")
        print("   部分失败可能是由于服务未启动或配置问题。")
    else:
        print("⚠️  多个测试失败,请检查相关依赖和配置。")


if __name__ == "__main__":
    main()

+ 235 - 0
test/test_vanna_combinations.py

@@ -0,0 +1,235 @@
+#!/usr/bin/env python3
+"""
+测试统一的Vanna组合类文件
+验证common/vanna_combinations.py中的功能
+"""
+
+def test_import_combinations():
+    """测试导入组合类"""
+    print("=== 测试导入组合类 ===")
+    
+    try:
+        from common.vanna_combinations import (
+            Vanna_Qwen_ChromaDB,
+            Vanna_DeepSeek_ChromaDB,
+            Vanna_Qwen_PGVector,
+            Vanna_DeepSeek_PGVector,
+            Vanna_Ollama_ChromaDB,
+            Vanna_Ollama_PGVector,
+            get_vanna_class,
+            list_available_combinations,
+            print_available_combinations
+        )
+        print("✅ 成功导入所有组合类和工具函数")
+        return True
+    except ImportError as e:
+        print(f"❌ 导入失败: {e}")
+        return False
+
+def test_get_vanna_class():
+    """测试get_vanna_class函数"""
+    print("\n=== 测试get_vanna_class函数 ===")
+    
+    try:
+        from common.vanna_combinations import get_vanna_class
+        
+        # 测试有效组合
+        test_cases = [
+            ("qwen", "chromadb"),
+            ("deepseek", "chromadb"),
+            ("qwen", "pgvector"),
+            ("deepseek", "pgvector"),
+            ("ollama", "chromadb"),
+            ("ollama", "pgvector"),
+        ]
+        
+        for llm_type, vector_db_type in test_cases:
+            try:
+                cls = get_vanna_class(llm_type, vector_db_type)
+                print(f"✅ {llm_type} + {vector_db_type} -> {cls.__name__}")
+            except Exception as e:
+                print(f"⚠️  {llm_type} + {vector_db_type} -> 错误: {e}")
+        
+        # 测试无效组合
+        print("\n测试无效组合:")
+        try:
+            get_vanna_class("invalid_llm", "chromadb")
+            print("❌ 应该抛出异常但没有")
+            return False
+        except ValueError:
+            print("✅ 正确处理无效LLM类型")
+        
+        try:
+            get_vanna_class("qwen", "invalid_db")
+            print("❌ 应该抛出异常但没有")
+            return False
+        except ValueError:
+            print("✅ 正确处理无效向量数据库类型")
+        
+        return True
+        
+    except Exception as e:
+        print(f"❌ 测试失败: {e}")
+        return False
+
+def test_list_available_combinations():
+    """测试列出可用组合"""
+    print("\n=== 测试列出可用组合 ===")
+    
+    try:
+        from common.vanna_combinations import list_available_combinations, print_available_combinations
+        
+        # 获取可用组合
+        combinations = list_available_combinations()
+        print(f"可用组合数据结构: {combinations}")
+        
+        # 打印可用组合
+        print("\n打印可用组合:")
+        print_available_combinations()
+        
+        return True
+        
+    except Exception as e:
+        print(f"❌ 测试失败: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+def test_class_instantiation():
+    """测试类实例化(不需要实际服务)"""
+    print("\n=== 测试类实例化 ===")
+    
+    try:
+        from common.vanna_combinations import get_vanna_class
+        
+        # 测试ChromaDB组合(通常可用)
+        test_cases = [
+            ("qwen", "chromadb"),
+            ("deepseek", "chromadb"),
+        ]
+        
+        for llm_type, vector_db_type in test_cases:
+            try:
+                cls = get_vanna_class(llm_type, vector_db_type)
+                
+                # 尝试创建实例(使用空配置)
+                instance = cls(config={})
+                print(f"✅ 成功创建 {cls.__name__} 实例")
+                
+                # 检查实例类型
+                print(f"   实例类型: {type(instance)}")
+                print(f"   MRO: {[c.__name__ for c in type(instance).__mro__[:3]]}")
+                
+            except Exception as e:
+                print(f"⚠️  创建 {llm_type}+{vector_db_type} 实例失败: {e}")
+        
+        return True
+        
+    except Exception as e:
+        print(f"❌ 测试失败: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+def test_factory_integration():
+    """测试与工厂函数的集成"""
+    print("\n=== 测试与工厂函数的集成 ===")
+    
+    try:
+        import app_config
+        from common.utils import print_current_config
+        
+        # 保存原始配置
+        original_llm_type = app_config.LLM_MODEL_TYPE
+        original_embedding_type = app_config.EMBEDDING_MODEL_TYPE
+        original_vector_db = app_config.VECTOR_DB_NAME
+        
+        try:
+            # 测试不同配置
+            test_configs = [
+                ("api", "api", "qwen", "chromadb"),
+                ("api", "api", "deepseek", "chromadb"),
+                ("ollama", "ollama", None, "chromadb"),
+            ]
+            
+            for llm_type, emb_type, llm_name, vector_db in test_configs:
+                print(f"\n--- 测试配置: LLM={llm_type}, EMB={emb_type}, MODEL={llm_name}, DB={vector_db} ---")
+                
+                # 设置配置
+                app_config.LLM_MODEL_TYPE = llm_type
+                app_config.EMBEDDING_MODEL_TYPE = emb_type
+                if llm_name:
+                    app_config.LLM_MODEL_NAME = llm_name
+                app_config.VECTOR_DB_NAME = vector_db
+                
+                # 打印当前配置
+                print_current_config()
+                
+                # 测试工厂函数(不实际创建实例,只测试类选择)
+                try:
+                    from vanna_llm_factory import create_vanna_instance
+                    from common.utils import get_current_model_info, is_using_ollama_llm
+                    from common.vanna_combinations import get_vanna_class
+                    
+                    model_info = get_current_model_info()
+                    
+                    if is_using_ollama_llm():
+                        selected_llm_type = "ollama"
+                    else:
+                        selected_llm_type = model_info["llm_model"].lower()
+                    
+                    selected_vector_db = model_info["vector_db"].lower()
+                    
+                    cls = get_vanna_class(selected_llm_type, selected_vector_db)
+                    print(f"✅ 工厂函数会选择: {cls.__name__}")
+                    
+                except Exception as e:
+                    print(f"⚠️  工厂函数测试失败: {e}")
+            
+            return True
+            
+        finally:
+            # 恢复原始配置
+            app_config.LLM_MODEL_TYPE = original_llm_type
+            app_config.EMBEDDING_MODEL_TYPE = original_embedding_type
+            app_config.VECTOR_DB_NAME = original_vector_db
+            
+    except Exception as e:
+        print(f"❌ 测试失败: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+def main():
+    """主测试函数"""
+    print("开始测试统一的Vanna组合类...")
+    print("=" * 60)
+    
+    results = []
+    
+    # 运行所有测试
+    results.append(("导入组合类", test_import_combinations()))
+    results.append(("get_vanna_class函数", test_get_vanna_class()))
+    results.append(("列出可用组合", test_list_available_combinations()))
+    results.append(("类实例化", test_class_instantiation()))
+    results.append(("工厂函数集成", test_factory_integration()))
+    
+    # 总结
+    print(f"\n{'='*60}")
+    print("测试结果总结:")
+    print("=" * 60)
+    
+    for test_name, success in results:
+        status = "✅ 通过" if success else "❌ 失败"
+        print(f"{test_name}: {status}")
+    
+    total_passed = sum(1 for _, success in results if success)
+    print(f"\n总计: {total_passed}/{len(results)} 个测试通过")
+    
+    if total_passed == len(results):
+        print("🎉 所有测试都通过了!统一组合类文件工作正常。")
+    else:
+        print("⚠️  部分测试失败,请检查相关依赖和配置。")
+
+if __name__ == "__main__":
+    main() 

+ 103 - 0
test/test_vanna_singleton.py

@@ -0,0 +1,103 @@
+"""
+测试 Vanna 单例模式是否正常工作
+"""
+import sys
+import os
+
+# 添加项目根目录到路径
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+def test_vanna_singleton():
+    """测试 Vanna 单例模式"""
+    from common.vanna_instance import get_vanna_instance, get_instance_status
+    
+    print("=" * 50)
+    print("测试 Vanna 单例模式")
+    print("=" * 50)
+    
+    # 检查初始状态
+    status = get_instance_status()
+    print(f"初始状态: {status}")
+    
+    # 第一次获取实例
+    print("\n第一次获取实例...")
+    instance1 = get_vanna_instance()
+    print(f"实例1 ID: {id(instance1)}")
+    print(f"实例1 类型: {type(instance1)}")
+    
+    # 第二次获取实例(应该是同一个)
+    print("\n第二次获取实例...")
+    instance2 = get_vanna_instance()
+    print(f"实例2 ID: {id(instance2)}")
+    print(f"实例2 类型: {type(instance2)}")
+    
+    # 验证是否为同一个实例
+    is_same = instance1 is instance2
+    print(f"\n实例是否相同: {is_same}")
+    
+    # 检查最终状态
+    final_status = get_instance_status()
+    print(f"最终状态: {final_status}")
+    
+    if is_same:
+        print("\n✅ 单例模式测试通过!")
+    else:
+        print("\n❌ 单例模式测试失败!")
+    
+    return is_same
+
+def test_import_from_tools():
+    """测试从工具文件导入是否正常"""
+    print("\n" + "=" * 50)
+    print("测试从工具文件导入")
+    print("=" * 50)
+    
+    try:
+        # 导入工具模块
+        from agent.tools.sql_generation import get_vanna_instance as gen_instance
+        from agent.tools.sql_execution import get_vanna_instance as exec_instance
+        from agent.tools.summary_generation import get_vanna_instance as sum_instance
+        
+        # 获取实例
+        instance_gen = gen_instance()
+        instance_exec = exec_instance()
+        instance_sum = sum_instance()
+        
+        print(f"SQL生成工具实例 ID: {id(instance_gen)}")
+        print(f"SQL执行工具实例 ID: {id(instance_exec)}")
+        print(f"摘要生成工具实例 ID: {id(instance_sum)}")
+        
+        # 验证是否都是同一个实例
+        all_same = (instance_gen is instance_exec) and (instance_exec is instance_sum)
+        
+        if all_same:
+            print("\n✅ 工具导入测试通过!所有工具使用同一个实例")
+        else:
+            print("\n❌ 工具导入测试失败!工具使用不同的实例")
+        
+        return all_same
+        
+    except Exception as e:
+        print(f"\n❌ 导入测试异常: {str(e)}")
+        return False
+
+if __name__ == "__main__":
+    try:
+        singleton_test = test_vanna_singleton()
+        import_test = test_import_from_tools()
+        
+        print("\n" + "=" * 50)
+        print("测试总结")
+        print("=" * 50)
+        print(f"单例模式测试: {'通过' if singleton_test else '失败'}")
+        print(f"工具导入测试: {'通过' if import_test else '失败'}")
+        
+        if singleton_test and import_test:
+            print("\n🎉 所有测试通过!Vanna 单例模式工作正常")
+        else:
+            print("\n⚠️  存在测试失败,请检查实现")
+            
+    except Exception as e:
+        print(f"测试执行异常: {str(e)}")
+        import traceback
+        traceback.print_exc() 

+ 136 - 0
test/vanna_test.ipynb

@@ -0,0 +1,136 @@
+{
+  "cells": [
+    {
+      "cell_type": "raw",
+      "metadata": {
+        "vscode": {
+          "languageId": "raw"
+        }
+      },
+      "source": [
+        "# Vanna Chainlit ChromaDB 测试 Notebook\n",
+        "\n",
+        "这个 Notebook 用于测试项目的各种功能和 API。\n"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "# 导入必要的模块\n",
+        "import sys\n",
+        "import os\n",
+        "\n",
+        "# 添加项目根目录到 Python 路径\n",
+        "sys.path.append(os.path.join(os.path.dirname(os.getcwd())))\n",
+        "\n",
+        "print(\"项目路径已添加到 Python 路径\")\n"
+      ]
+    },
+    {
+      "cell_type": "raw",
+      "metadata": {
+        "vscode": {
+          "languageId": "raw"
+        }
+      },
+      "source": [
+        "## 1. 测试配置加载\n",
+        "\n",
+        "测试项目的各种配置是否能正常加载。\n"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "# 测试配置加载\n",
+        "try:\n",
+        "    import app_config\n",
+        "    print(\"配置加载成功!\")\n",
+        "    print(f\"LLM模型类型: {app_config.LLM_MODEL_TYPE}\")\n",
+        "    print(f\"API LLM模型: {app_config.API_LLM_MODEL}\")\n",
+        "    print(f\"向量数据库类型: {app_config.VECTOR_DB_TYPE}\")\n",
+        "except Exception as e:\n",
+        "    print(f\"配置加载失败: {e}\")\n"
+      ]
+    },
+    {
+      "cell_type": "raw",
+      "metadata": {
+        "vscode": {
+          "languageId": "raw"
+        }
+      },
+      "source": [
+        "## 2. 测试数据管道工具\n",
+        "\n",
+        "测试数据管道模块的配置和功能。\n"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "# 测试数据管道模块\n",
+        "try:\n",
+        "    from data_pipeline.config import SCHEMA_TOOLS_CONFIG\n",
+        "    print(\"数据管道配置加载成功!\")\n",
+        "    print(f\"输出目录: {SCHEMA_TOOLS_CONFIG['output_directory']}\")\n",
+        "    print(f\"最大表数量: {SCHEMA_TOOLS_CONFIG['qs_generation']['max_tables']}\")\n",
+        "except Exception as e:\n",
+        "    print(f\"数据管道配置加载失败: {e}\")\n"
+      ]
+    },
+    {
+      "cell_type": "raw",
+      "metadata": {
+        "vscode": {
+          "languageId": "raw"
+        }
+      },
+      "source": [
+        "## 总结\n",
+        "\n",
+        "这个 Notebook 计划用于测试项目的各个组件(目前仅包含前两项的测试单元),包括:\n",
+        "- 配置加载\n",
+        "- 数据管道工具\n",
+        "- Vanna 实例创建\n",
+        "- 工具函数\n",
+        "- 日志系统\n",
+        "\n",
+        "可以根据需要添加更多的测试用例。\n",
+        "\n",
+        "### 使用说明\n",
+        "\n",
+        "1. 确保已激活项目的虚拟环境:\n",
+        "   ```powershell\n",
+        "   .\\.venv\\Scripts\\Activate.ps1\n",
+        "   ```\n",
+        "\n",
+        "2. 安装 Jupyter(如果尚未安装):\n",
+        "   ```bash\n",
+        "   pip install jupyter\n",
+        "   ```\n",
+        "\n",
+        "3. 启动 Jupyter:\n",
+        "   ```bash\n",
+        "   jupyter notebook\n",
+        "   ```\n"
+      ]
+    }
+  ],
+  "metadata": {
+    "language_info": {
+      "name": "python"
+    }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+}