123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394 |
- """
- QA反馈数据管理器 - 复用Vanna连接版本
- 用于处理用户对问答结果的点赞/点踩反馈,并将反馈转化为训练数据
- """
- import app_config
- from sqlalchemy import create_engine, text, MetaData, Table, Column, Integer, String, Boolean, DateTime, func
- from sqlalchemy.exc import OperationalError, ProgrammingError
- from datetime import datetime
- from typing import List, Dict, Any, Optional, Tuple
- import logging
class QAFeedbackManager:
    """QA feedback data manager that can reuse a Vanna database connection.

    Persists users' thumbs-up / thumbs-down feedback on generated SQL in the
    ``qa_feedback`` table and tracks which records have been marked as
    included in training data.
    """

    # Column list shared by every SELECT on qa_feedback.
    _SELECT_COLUMNS = (
        "id, question, sql, is_thumb_up, user_id, create_time, "
        "is_in_training_data, update_time"
    )

    # ORDER BY cannot be a bind parameter, so the sort column is interpolated
    # into the SQL text; only these whitelisted names are ever let through.
    _VALID_SORT_FIELDS = ('id', 'create_time', 'update_time', 'user_id')
    _DEFAULT_SORT_FIELD = 'create_time'

    # Columns update_feedback() may modify. Their names are interpolated into
    # the SET clause, so this whitelist is what keeps that SQL safe.
    _UPDATABLE_FIELDS = ('question', 'sql', 'is_thumb_up', 'user_id', 'is_in_training_data')

    def __init__(self, vanna_instance=None):
        """Set up the database connection and make sure the table exists.

        Args:
            vanna_instance: Optional Vanna instance. If it exposes a non-None
                ``engine`` attribute, that engine is reused instead of
                creating a new one.

        Raises:
            Exception: Any connection or DDL failure is printed and re-raised.
        """
        self.engine = None
        self.vanna_instance = vanna_instance
        self._init_database_connection()
        self._ensure_table_exists()

    # ------------------------------------------------------------------ #
    # Pure helpers (no database access)
    # ------------------------------------------------------------------ #

    @staticmethod
    def _build_connection_url(db_config: Dict[str, Any]) -> str:
        """Build a PostgreSQL URL from an ``APP_DB_CONFIG``-style dict.

        The user name and password are percent-encoded so characters such as
        ``@``, ``:``, ``/`` or ``#`` in credentials cannot corrupt the URL.

        Args:
            db_config: Mapping with ``user``, ``password``, ``host``, ``port``
                and ``dbname`` keys.

        Returns:
            A ``postgresql://`` connection URL.
        """
        from urllib.parse import quote

        user = quote(str(db_config['user']), safe='')
        password = quote(str(db_config['password']), safe='')
        return (
            f"postgresql://{user}:{password}"
            f"@{db_config['host']}:{db_config['port']}/{db_config['dbname']}"
        )

    @staticmethod
    def _build_feedback_filters(is_thumb_up: Optional[bool],
                                create_time_start: Optional[str],
                                create_time_end: Optional[str],
                                is_in_training_data: Optional[bool]) -> Tuple[str, Dict[str, Any]]:
        """Build the WHERE clause and bind params for feedback queries.

        Boolean filters apply when not None; time filters apply when truthy.

        Returns:
            ``(where_clause, params)``. ``where_clause`` is ``""`` when no
            filter applies.
        """
        conditions = []
        params = {}

        if is_thumb_up is not None:
            conditions.append("is_thumb_up = :is_thumb_up")
            params['is_thumb_up'] = is_thumb_up

        if create_time_start:
            conditions.append("create_time >= :create_time_start")
            params['create_time_start'] = create_time_start

        if create_time_end:
            conditions.append("create_time <= :create_time_end")
            params['create_time_end'] = create_time_end

        if is_in_training_data is not None:
            conditions.append("is_in_training_data = :is_in_training_data")
            params['is_in_training_data'] = is_in_training_data

        where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
        return where_clause, params

    @classmethod
    def _normalize_sort(cls, sort_by: str, sort_order: Optional[str]) -> Tuple[str, str]:
        """Return a whitelisted ``(column, direction)`` pair for ORDER BY.

        Unknown columns fall back to ``create_time``. Unknown or missing
        directions (including ``None``) fall back to ``DESC``.
        """
        if sort_by not in cls._VALID_SORT_FIELDS:
            sort_by = cls._DEFAULT_SORT_FIELD
        direction = str(sort_order or '').lower()
        if direction not in ('asc', 'desc'):
            direction = 'desc'
        return sort_by, direction.upper()

    @staticmethod
    def _paging(page: int, page_size: int) -> Tuple[int, int]:
        """Convert a 1-based page number and page size into ``(limit, offset)``.

        ``page`` is clamped to >= 1 and ``page_size`` to >= 0. PostgreSQL
        rejects a negative LIMIT or OFFSET, so clamping avoids a database error.
        """
        page = max(page, 1)
        page_size = max(page_size, 0)
        return page_size, (page - 1) * page_size

    @staticmethod
    def _build_id_in_clause(feedback_ids: List[int]) -> Tuple[str, Dict[str, Any]]:
        """Build placeholders and bind params for an ``id IN (...)`` clause.

        Returns:
            ``(placeholders, params)``, e.g. ``(":id_0,:id_1", {"id_0": 5, "id_1": 7})``.
        """
        params = {f'id_{i}': feedback_id for i, feedback_id in enumerate(feedback_ids)}
        placeholders = ','.join(f':{name}' for name in params)
        return placeholders, params

    @staticmethod
    def _row_to_dict(row, iso_timestamps: bool = False) -> Dict[str, Any]:
        """Convert a qa_feedback result row into a plain dict.

        Args:
            row: A result row exposing the qa_feedback columns as attributes.
            iso_timestamps: If True, ``create_time``/``update_time`` are
                rendered with ``isoformat()`` (falsy values become None).
                Otherwise they are returned unchanged.
        """
        create_time = row.create_time
        update_time = row.update_time
        if iso_timestamps:
            create_time = create_time.isoformat() if create_time else None
            update_time = update_time.isoformat() if update_time else None
        return {
            'id': row.id,
            'question': row.question,
            'sql': row.sql,
            'is_thumb_up': row.is_thumb_up,
            'user_id': row.user_id,
            'create_time': create_time,
            'is_in_training_data': row.is_in_training_data,
            'update_time': update_time,
        }

    # ------------------------------------------------------------------ #
    # Connection / schema setup
    # ------------------------------------------------------------------ #

    def _init_database_connection(self):
        """Initialize ``self.engine``.

        Reuses the Vanna instance's engine when one is available. Otherwise
        creates a pooled engine from ``app_config.APP_DB_CONFIG`` and checks
        it with ``SELECT 1``.

        Raises:
            Exception: Any failure is printed and re-raised.
        """
        try:
            # Option 1: reuse the Vanna engine. getattr(..., None) also rejects
            # an ``engine`` attribute that exists but is None, which would
            # otherwise leave self.engine unusable.
            vanna_engine = getattr(self.vanna_instance, 'engine', None)
            if vanna_engine is not None:
                self.engine = vanna_engine
                print("[QAFeedbackManager] 复用Vanna数据库连接")
                return

            # Option 2: create a dedicated engine with its own connection pool.
            db_config = app_config.APP_DB_CONFIG
            connection_string = self._build_connection_url(db_config)
            self.engine = create_engine(
                connection_string,
                echo=False,
                pool_size=5,        # connections kept in the pool
                max_overflow=10,    # extra connections allowed beyond pool_size
                pool_timeout=30,    # seconds to wait for a free connection
                pool_recycle=3600,  # recycle connections after 1 hour
            )

            # Fail fast if the database is unreachable.
            with self.engine.connect() as conn:
                conn.execute(text("SELECT 1"))

            print(f"[QAFeedbackManager] 数据库连接成功: {db_config['host']}:{db_config['port']}/{db_config['dbname']}")

        except Exception as e:
            print(f"[ERROR] QAFeedbackManager数据库连接失败: {e}")
            raise

    def _ensure_table_exists(self):
        """Create the qa_feedback table and its indexes if they are missing.

        The table and all indexes are created inside one transaction.

        Raises:
            Exception: Any DDL failure is printed and re-raised.
        """
        create_table_sql = """
        CREATE TABLE IF NOT EXISTS qa_feedback (
            id SERIAL PRIMARY KEY,
            question TEXT NOT NULL,
            sql TEXT NOT NULL,
            is_thumb_up BOOLEAN NOT NULL,
            user_id VARCHAR(64) NOT NULL,
            create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            is_in_training_data BOOLEAN DEFAULT FALSE,
            update_time TIMESTAMP
        );
        """

        create_indexes_sql = [
            "CREATE INDEX IF NOT EXISTS idx_qa_feedback_user_id ON qa_feedback(user_id);",
            "CREATE INDEX IF NOT EXISTS idx_qa_feedback_create_time ON qa_feedback(create_time);",
            "CREATE INDEX IF NOT EXISTS idx_qa_feedback_is_thumb_up ON qa_feedback(is_thumb_up);",
            "CREATE INDEX IF NOT EXISTS idx_qa_feedback_is_in_training ON qa_feedback(is_in_training_data);"
        ]

        try:
            with self.engine.connect() as conn:
                with conn.begin():
                    conn.execute(text(create_table_sql))
                    for index_sql in create_indexes_sql:
                        conn.execute(text(index_sql))

            print("[QAFeedbackManager] qa_feedback表检查/创建成功")

        except Exception as e:
            print(f"[ERROR] qa_feedback表创建失败: {e}")
            raise

    # ------------------------------------------------------------------ #
    # CRUD operations
    # ------------------------------------------------------------------ #

    def add_feedback(self, question: str, sql: str, is_thumb_up: bool, user_id: str = "guest") -> int:
        """Insert a feedback record.

        Args:
            question: The user's question.
            sql: The generated SQL the feedback refers to.
            is_thumb_up: True for thumbs-up, False for thumbs-down.
            user_id: ID of the user giving feedback.

        Returns:
            The ID of the newly created record.

        Raises:
            Exception: Any database failure is printed and re-raised.
        """
        insert_sql = """
        INSERT INTO qa_feedback (question, sql, is_thumb_up, user_id, create_time)
        VALUES (:question, :sql, :is_thumb_up, :user_id, :create_time)
        RETURNING id
        """

        try:
            with self.engine.connect() as conn:
                with conn.begin():
                    result = conn.execute(text(insert_sql), {
                        'question': question,
                        'sql': sql,
                        'is_thumb_up': is_thumb_up,
                        'user_id': user_id,
                        'create_time': datetime.now()
                    })
                    feedback_id = result.fetchone()[0]

            print(f"[QAFeedbackManager] 反馈记录创建成功, ID: {feedback_id}")
            return feedback_id

        except Exception as e:
            print(f"[ERROR] 添加反馈记录失败: {e}")
            raise

    def query_feedback(self, page: int = 1, page_size: int = 20,
                       is_thumb_up: Optional[bool] = None,
                       create_time_start: Optional[str] = None,
                       create_time_end: Optional[str] = None,
                       is_in_training_data: Optional[bool] = None,
                       sort_by: str = "create_time",
                       sort_order: str = "desc") -> Tuple[List[Dict], int]:
        """Query feedback records with filtering, sorting and pagination.

        Args:
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Records per page; negative values are treated as 0.
            is_thumb_up: Filter on thumbs-up/down when not None.
            create_time_start: Inclusive lower bound on create_time, passed
                to the database as-is (presumably an ISO-8601 string — confirm
                with callers).
            create_time_end: Inclusive upper bound on create_time, same format.
            is_in_training_data: Filter on training status when not None.
            sort_by: Sort column. Values outside the whitelist fall back to
                ``create_time``.
            sort_order: ``"asc"`` or ``"desc"`` (case-insensitive). Anything
                else falls back to ``"desc"``.

        Returns:
            ``(records, total)``. ``records`` is the current page with
            ISO-formatted timestamps. ``total`` counts all matching rows.

        Raises:
            Exception: Any database failure is printed and re-raised.
        """
        where_clause, filter_params = self._build_feedback_filters(
            is_thumb_up, create_time_start, create_time_end, is_in_training_data
        )
        sort_column, sort_direction = self._normalize_sort(sort_by, sort_order)
        limit, offset = self._paging(page, page_size)

        # sort_column/sort_direction are whitelisted in _normalize_sort, so
        # interpolating them here cannot inject SQL.
        query_sql = f"""
        SELECT {self._SELECT_COLUMNS}
        FROM qa_feedback
        {where_clause}
        ORDER BY {sort_column} {sort_direction}
        LIMIT :limit OFFSET :offset
        """

        count_sql = f"""
        SELECT COUNT(*) as total
        FROM qa_feedback
        {where_clause}
        """

        try:
            with self.engine.connect() as conn:
                page_params = {**filter_params, 'limit': limit, 'offset': offset}
                result = conn.execute(text(query_sql), page_params)
                records = [self._row_to_dict(row, iso_timestamps=True) for row in result]

                total = conn.execute(text(count_sql), filter_params).fetchone().total

            return records, total

        except Exception as e:
            print(f"[ERROR] 查询反馈记录失败: {e}")
            raise

    def delete_feedback(self, feedback_id: int) -> bool:
        """Delete one feedback record.

        Args:
            feedback_id: ID of the record to delete.

        Returns:
            True if a row was deleted, False if no row had that ID.

        Raises:
            Exception: Any database failure is printed and re-raised.
        """
        delete_sql = "DELETE FROM qa_feedback WHERE id = :id"

        try:
            with self.engine.connect() as conn:
                with conn.begin():
                    result = conn.execute(text(delete_sql), {'id': feedback_id})

            if result.rowcount > 0:
                print(f"[QAFeedbackManager] 反馈记录删除成功, ID: {feedback_id}")
                return True

            print(f"[WARNING] 反馈记录不存在, ID: {feedback_id}")
            return False

        except Exception as e:
            print(f"[ERROR] 删除反馈记录失败: {e}")
            raise

    def update_feedback(self, feedback_id: int, **kwargs) -> bool:
        """Update selected columns of one feedback record.

        Only keys in ``_UPDATABLE_FIELDS`` are applied; other keyword
        arguments are ignored. ``update_time`` is always set to now.

        Args:
            feedback_id: ID of the record to update.
            **kwargs: Column/value pairs to write.

        Returns:
            True if a row matched ``feedback_id``. False if no valid field was
            supplied or no row matched.

        Raises:
            Exception: Any database failure is printed and re-raised.
        """
        updates = {field: value for field, value in kwargs.items()
                   if field in self._UPDATABLE_FIELDS}

        if not updates:
            print("[WARNING] 没有有效的更新字段")
            return False

        # Field names come from the whitelist above, so interpolation is safe.
        set_clauses = [f"{field} = :{field}" for field in updates]
        set_clauses.append("update_time = :update_time")
        params = {**updates, 'id': feedback_id, 'update_time': datetime.now()}

        update_sql = f"""
        UPDATE qa_feedback
        SET {', '.join(set_clauses)}
        WHERE id = :id
        """

        try:
            with self.engine.connect() as conn:
                with conn.begin():
                    result = conn.execute(text(update_sql), params)

            if result.rowcount > 0:
                print(f"[QAFeedbackManager] 反馈记录更新成功, ID: {feedback_id}")
                return True

            print(f"[WARNING] 反馈记录不存在或无变化, ID: {feedback_id}")
            return False

        except Exception as e:
            print(f"[ERROR] 更新反馈记录失败: {e}")
            raise

    def get_feedback_by_ids(self, feedback_ids: List[int]) -> List[Dict]:
        """Fetch feedback records by ID.

        Args:
            feedback_ids: IDs to fetch. An empty list returns ``[]`` without
                touching the database.

        Returns:
            Matching records. Timestamps are returned exactly as the driver
            provides them (not ISO strings). Row order follows the database,
            not ``feedback_ids``.

        Raises:
            Exception: Any database failure is printed and re-raised.
        """
        if not feedback_ids:
            return []

        placeholders, params = self._build_id_in_clause(feedback_ids)

        query_sql = f"""
        SELECT {self._SELECT_COLUMNS}
        FROM qa_feedback
        WHERE id IN ({placeholders})
        """

        try:
            with self.engine.connect() as conn:
                result = conn.execute(text(query_sql), params)
                return [self._row_to_dict(row) for row in result]

        except Exception as e:
            print(f"[ERROR] 根据ID查询反馈记录失败: {e}")
            raise

    def mark_training_status(self, feedback_ids: List[int], status: bool = True) -> int:
        """Set ``is_in_training_data`` on several records in one statement.

        Args:
            feedback_ids: IDs to update. An empty list returns 0 without
                touching the database.
            status: The training status to write.

        Returns:
            Number of rows updated.

        Raises:
            Exception: Any database failure is printed and re-raised.
        """
        if not feedback_ids:
            return 0

        placeholders, params = self._build_id_in_clause(feedback_ids)
        params['status'] = status
        params['update_time'] = datetime.now()

        update_sql = f"""
        UPDATE qa_feedback
        SET is_in_training_data = :status, update_time = :update_time
        WHERE id IN ({placeholders})
        """

        try:
            with self.engine.connect() as conn:
                with conn.begin():
                    result = conn.execute(text(update_sql), params)

            print(f"[QAFeedbackManager] 批量更新训练状态成功, 影响行数: {result.rowcount}")
            return result.rowcount

        except Exception as e:
            print(f"[ERROR] 批量更新训练状态失败: {e}")
            raise
|