# qa_feedback_manager.py
"""
QA feedback data manager - version that reuses the Vanna connection.

Handles users' thumbs-up / thumbs-down feedback on question-answering results
and converts that feedback into training data.
"""
import logging
from datetime import datetime
from typing import List, Dict, Any, Optional, Tuple
from urllib.parse import quote

from sqlalchemy import create_engine, text, MetaData, Table, Column, Integer, String, Boolean, DateTime, func
from sqlalchemy.exc import OperationalError, ProgrammingError

import app_config
  11. class QAFeedbackManager:
  12. """QA反馈数据管理器 - 复用Vanna连接版本"""
  13. def __init__(self, vanna_instance=None):
  14. """初始化数据库连接
  15. Args:
  16. vanna_instance: 可选的vanna实例,用于复用其数据库连接
  17. """
  18. self.engine = None
  19. self.vanna_instance = vanna_instance
  20. self._init_database_connection()
  21. self._ensure_table_exists()
  22. def _init_database_connection(self):
  23. """初始化数据库连接"""
  24. try:
  25. # 方案1: 优先尝试复用vanna连接
  26. if self.vanna_instance and hasattr(self.vanna_instance, 'engine'):
  27. self.engine = self.vanna_instance.engine
  28. print(f"[QAFeedbackManager] 复用Vanna数据库连接")
  29. return
  30. # 方案2: 创建新的连接(原有方式)
  31. db_config = app_config.APP_DB_CONFIG
  32. connection_string = (
  33. f"postgresql://{db_config['user']}:{db_config['password']}"
  34. f"@{db_config['host']}:{db_config['port']}/{db_config['dbname']}"
  35. )
  36. # 使用连接池配置
  37. self.engine = create_engine(
  38. connection_string,
  39. echo=False,
  40. pool_size=5, # 连接池大小
  41. max_overflow=10, # 最大溢出连接数
  42. pool_timeout=30, # 获取连接超时
  43. pool_recycle=3600 # 连接回收时间(1小时)
  44. )
  45. # 测试连接
  46. with self.engine.connect() as conn:
  47. conn.execute(text("SELECT 1"))
  48. print(f"[QAFeedbackManager] 数据库连接成功: {db_config['host']}:{db_config['port']}/{db_config['dbname']}")
  49. except Exception as e:
  50. print(f"[ERROR] QAFeedbackManager数据库连接失败: {e}")
  51. raise
  52. def _ensure_table_exists(self):
  53. """检查并创建qa_feedback表"""
  54. create_table_sql = """
  55. CREATE TABLE IF NOT EXISTS qa_feedback (
  56. id SERIAL PRIMARY KEY,
  57. question TEXT NOT NULL,
  58. sql TEXT NOT NULL,
  59. is_thumb_up BOOLEAN NOT NULL,
  60. user_id VARCHAR(64) NOT NULL,
  61. create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  62. is_in_training_data BOOLEAN DEFAULT FALSE,
  63. update_time TIMESTAMP
  64. );
  65. """
  66. # 创建索引SQL
  67. create_indexes_sql = [
  68. "CREATE INDEX IF NOT EXISTS idx_qa_feedback_user_id ON qa_feedback(user_id);",
  69. "CREATE INDEX IF NOT EXISTS idx_qa_feedback_create_time ON qa_feedback(create_time);",
  70. "CREATE INDEX IF NOT EXISTS idx_qa_feedback_is_thumb_up ON qa_feedback(is_thumb_up);",
  71. "CREATE INDEX IF NOT EXISTS idx_qa_feedback_is_in_training ON qa_feedback(is_in_training_data);"
  72. ]
  73. try:
  74. with self.engine.connect() as conn:
  75. with conn.begin():
  76. # 创建表
  77. conn.execute(text(create_table_sql))
  78. # 创建索引
  79. for index_sql in create_indexes_sql:
  80. conn.execute(text(index_sql))
  81. print("[QAFeedbackManager] qa_feedback表检查/创建成功")
  82. except Exception as e:
  83. print(f"[ERROR] qa_feedback表创建失败: {e}")
  84. raise
  85. def add_feedback(self, question: str, sql: str, is_thumb_up: bool, user_id: str = "guest") -> int:
  86. """添加反馈记录
  87. Args:
  88. question: 用户问题
  89. sql: 生成的SQL
  90. is_thumb_up: 是否点赞
  91. user_id: 用户ID
  92. Returns:
  93. 新创建记录的ID
  94. """
  95. insert_sql = """
  96. INSERT INTO qa_feedback (question, sql, is_thumb_up, user_id, create_time)
  97. VALUES (:question, :sql, :is_thumb_up, :user_id, :create_time)
  98. RETURNING id
  99. """
  100. try:
  101. with self.engine.connect() as conn:
  102. with conn.begin():
  103. result = conn.execute(text(insert_sql), {
  104. 'question': question,
  105. 'sql': sql,
  106. 'is_thumb_up': is_thumb_up,
  107. 'user_id': user_id,
  108. 'create_time': datetime.now()
  109. })
  110. feedback_id = result.fetchone()[0]
  111. print(f"[QAFeedbackManager] 反馈记录创建成功, ID: {feedback_id}")
  112. return feedback_id
  113. except Exception as e:
  114. print(f"[ERROR] 添加反馈记录失败: {e}")
  115. raise
  116. def query_feedback(self, page: int = 1, page_size: int = 20,
  117. is_thumb_up: Optional[bool] = None,
  118. create_time_start: Optional[str] = None,
  119. create_time_end: Optional[str] = None,
  120. is_in_training_data: Optional[bool] = None,
  121. sort_by: str = "create_time",
  122. sort_order: str = "desc") -> Tuple[List[Dict], int]:
  123. """查询反馈记录
  124. Args:
  125. page: 页码 (从1开始)
  126. page_size: 每页大小
  127. is_thumb_up: 是否点赞筛选
  128. create_time_start: 创建时间开始
  129. create_time_end: 创建时间结束
  130. is_in_training_data: 是否已加入训练数据
  131. sort_by: 排序字段
  132. sort_order: 排序方向 (asc/desc)
  133. Returns:
  134. (记录列表, 总数)
  135. """
  136. # 构建WHERE条件
  137. where_conditions = []
  138. params = {}
  139. if is_thumb_up is not None:
  140. where_conditions.append("is_thumb_up = :is_thumb_up")
  141. params['is_thumb_up'] = is_thumb_up
  142. if create_time_start:
  143. where_conditions.append("create_time >= :create_time_start")
  144. params['create_time_start'] = create_time_start
  145. if create_time_end:
  146. where_conditions.append("create_time <= :create_time_end")
  147. params['create_time_end'] = create_time_end
  148. if is_in_training_data is not None:
  149. where_conditions.append("is_in_training_data = :is_in_training_data")
  150. params['is_in_training_data'] = is_in_training_data
  151. where_clause = "WHERE " + " AND ".join(where_conditions) if where_conditions else ""
  152. # 验证排序参数
  153. valid_sort_fields = ['id', 'create_time', 'update_time', 'user_id']
  154. if sort_by not in valid_sort_fields:
  155. sort_by = 'create_time'
  156. if sort_order.lower() not in ['asc', 'desc']:
  157. sort_order = 'desc'
  158. # 计算OFFSET
  159. offset = (page - 1) * page_size
  160. # 查询数据
  161. query_sql = f"""
  162. SELECT id, question, sql, is_thumb_up, user_id, create_time,
  163. is_in_training_data, update_time
  164. FROM qa_feedback
  165. {where_clause}
  166. ORDER BY {sort_by} {sort_order.upper()}
  167. LIMIT :limit OFFSET :offset
  168. """
  169. # 查询总数
  170. count_sql = f"""
  171. SELECT COUNT(*) as total
  172. FROM qa_feedback
  173. {where_clause}
  174. """
  175. try:
  176. with self.engine.connect() as conn:
  177. # 查询数据
  178. params.update({'limit': page_size, 'offset': offset})
  179. result = conn.execute(text(query_sql), params)
  180. records = []
  181. for row in result:
  182. records.append({
  183. 'id': row.id,
  184. 'question': row.question,
  185. 'sql': row.sql,
  186. 'is_thumb_up': row.is_thumb_up,
  187. 'user_id': row.user_id,
  188. 'create_time': row.create_time.isoformat() if row.create_time else None,
  189. 'is_in_training_data': row.is_in_training_data,
  190. 'update_time': row.update_time.isoformat() if row.update_time else None
  191. })
  192. # 查询总数
  193. count_result = conn.execute(text(count_sql), {k: v for k, v in params.items() if k not in ['limit', 'offset']})
  194. total = count_result.fetchone().total
  195. return records, total
  196. except Exception as e:
  197. print(f"[ERROR] 查询反馈记录失败: {e}")
  198. raise
  199. def delete_feedback(self, feedback_id: int) -> bool:
  200. """删除反馈记录
  201. Args:
  202. feedback_id: 反馈记录ID
  203. Returns:
  204. 删除是否成功
  205. """
  206. delete_sql = "DELETE FROM qa_feedback WHERE id = :id"
  207. try:
  208. with self.engine.connect() as conn:
  209. with conn.begin():
  210. result = conn.execute(text(delete_sql), {'id': feedback_id})
  211. if result.rowcount > 0:
  212. print(f"[QAFeedbackManager] 反馈记录删除成功, ID: {feedback_id}")
  213. return True
  214. else:
  215. print(f"[WARNING] 反馈记录不存在, ID: {feedback_id}")
  216. return False
  217. except Exception as e:
  218. print(f"[ERROR] 删除反馈记录失败: {e}")
  219. raise
  220. def update_feedback(self, feedback_id: int, **kwargs) -> bool:
  221. """更新反馈记录
  222. Args:
  223. feedback_id: 反馈记录ID
  224. **kwargs: 要更新的字段
  225. Returns:
  226. 更新是否成功
  227. """
  228. # 允许更新的字段
  229. allowed_fields = ['question', 'sql', 'is_thumb_up', 'user_id', 'is_in_training_data']
  230. update_fields = []
  231. params = {'id': feedback_id, 'update_time': datetime.now()}
  232. for field, value in kwargs.items():
  233. if field in allowed_fields:
  234. update_fields.append(f"{field} = :{field}")
  235. params[field] = value
  236. if not update_fields:
  237. print("[WARNING] 没有有效的更新字段")
  238. return False
  239. update_fields.append("update_time = :update_time")
  240. update_sql = f"""
  241. UPDATE qa_feedback
  242. SET {', '.join(update_fields)}
  243. WHERE id = :id
  244. """
  245. try:
  246. with self.engine.connect() as conn:
  247. with conn.begin():
  248. result = conn.execute(text(update_sql), params)
  249. if result.rowcount > 0:
  250. print(f"[QAFeedbackManager] 反馈记录更新成功, ID: {feedback_id}")
  251. return True
  252. else:
  253. print(f"[WARNING] 反馈记录不存在或无变化, ID: {feedback_id}")
  254. return False
  255. except Exception as e:
  256. print(f"[ERROR] 更新反馈记录失败: {e}")
  257. raise
  258. def get_feedback_by_ids(self, feedback_ids: List[int]) -> List[Dict]:
  259. """根据ID列表获取反馈记录
  260. Args:
  261. feedback_ids: 反馈记录ID列表
  262. Returns:
  263. 反馈记录列表
  264. """
  265. if not feedback_ids:
  266. return []
  267. # 构建IN查询
  268. placeholders = ','.join([f':id_{i}' for i in range(len(feedback_ids))])
  269. params = {f'id_{i}': feedback_id for i, feedback_id in enumerate(feedback_ids)}
  270. query_sql = f"""
  271. SELECT id, question, sql, is_thumb_up, user_id, create_time,
  272. is_in_training_data, update_time
  273. FROM qa_feedback
  274. WHERE id IN ({placeholders})
  275. """
  276. try:
  277. with self.engine.connect() as conn:
  278. result = conn.execute(text(query_sql), params)
  279. records = []
  280. for row in result:
  281. records.append({
  282. 'id': row.id,
  283. 'question': row.question,
  284. 'sql': row.sql,
  285. 'is_thumb_up': row.is_thumb_up,
  286. 'user_id': row.user_id,
  287. 'create_time': row.create_time,
  288. 'is_in_training_data': row.is_in_training_data,
  289. 'update_time': row.update_time
  290. })
  291. return records
  292. except Exception as e:
  293. print(f"[ERROR] 根据ID查询反馈记录失败: {e}")
  294. raise
  295. def mark_training_status(self, feedback_ids: List[int], status: bool = True) -> int:
  296. """批量标记训练状态
  297. Args:
  298. feedback_ids: 反馈记录ID列表
  299. status: 训练状态
  300. Returns:
  301. 更新的记录数
  302. """
  303. if not feedback_ids:
  304. return 0
  305. placeholders = ','.join([f':id_{i}' for i in range(len(feedback_ids))])
  306. params = {f'id_{i}': feedback_id for i, feedback_id in enumerate(feedback_ids)}
  307. params['status'] = status
  308. params['update_time'] = datetime.now()
  309. update_sql = f"""
  310. UPDATE qa_feedback
  311. SET is_in_training_data = :status, update_time = :update_time
  312. WHERE id IN ({placeholders})
  313. """
  314. try:
  315. with self.engine.connect() as conn:
  316. with conn.begin():
  317. result = conn.execute(text(update_sql), params)
  318. print(f"[QAFeedbackManager] 批量更新训练状态成功, 影响行数: {result.rowcount}")
  319. return result.rowcount
  320. except Exception as e:
  321. print(f"[ERROR] 批量更新训练状态失败: {e}")
  322. raise