qa_feedback_manager.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398
  1. """
  2. QA反馈数据管理器 - 复用Vanna连接版本
  3. 用于处理用户对问答结果的点赞/点踩反馈,并将反馈转化为训练数据
  4. """
  5. import app_config
  6. from sqlalchemy import create_engine, text, MetaData, Table, Column, Integer, String, Boolean, DateTime, func
  7. from sqlalchemy.exc import OperationalError, ProgrammingError
  8. from datetime import datetime
  9. from typing import List, Dict, Any, Optional, Tuple
  10. import logging
  11. from core.logging import get_app_logger
  12. class QAFeedbackManager:
  13. """QA反馈数据管理器 - 复用Vanna连接版本"""
  14. def __init__(self, vanna_instance=None):
  15. """初始化数据库连接
  16. Args:
  17. vanna_instance: 可选的vanna实例,用于复用其数据库连接
  18. """
  19. # 初始化日志
  20. self.logger = get_app_logger("QAFeedbackManager")
  21. self.engine = None
  22. self.vanna_instance = vanna_instance
  23. self._init_database_connection()
  24. self._ensure_table_exists()
  25. def _init_database_connection(self):
  26. """初始化数据库连接"""
  27. try:
  28. # 方案1: 优先尝试复用vanna连接
  29. if self.vanna_instance and hasattr(self.vanna_instance, 'engine'):
  30. self.engine = self.vanna_instance.engine
  31. self.logger.info("复用Vanna数据库连接")
  32. return
  33. # 方案2: 创建新的连接(原有方式)
  34. db_config = app_config.APP_DB_CONFIG
  35. connection_string = (
  36. f"postgresql://{db_config['user']}:{db_config['password']}"
  37. f"@{db_config['host']}:{db_config['port']}/{db_config['dbname']}"
  38. )
  39. # 使用连接池配置
  40. self.engine = create_engine(
  41. connection_string,
  42. echo=False,
  43. pool_size=5, # 连接池大小
  44. max_overflow=10, # 最大溢出连接数
  45. pool_timeout=30, # 获取连接超时
  46. pool_recycle=3600 # 连接回收时间(1小时)
  47. )
  48. # 测试连接
  49. with self.engine.connect() as conn:
  50. conn.execute(text("SELECT 1"))
  51. self.logger.info(f"数据库连接成功: {db_config['host']}:{db_config['port']}/{db_config['dbname']}")
  52. except Exception as e:
  53. self.logger.error(f"QAFeedbackManager数据库连接失败: {e}")
  54. raise
  55. def _ensure_table_exists(self):
  56. """检查并创建qa_feedback表"""
  57. create_table_sql = """
  58. CREATE TABLE IF NOT EXISTS qa_feedback (
  59. id SERIAL PRIMARY KEY,
  60. question TEXT NOT NULL,
  61. sql TEXT NOT NULL,
  62. is_thumb_up BOOLEAN NOT NULL,
  63. user_id VARCHAR(64) NOT NULL,
  64. create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  65. is_in_training_data BOOLEAN DEFAULT FALSE,
  66. update_time TIMESTAMP
  67. );
  68. """
  69. # 创建索引SQL
  70. create_indexes_sql = [
  71. "CREATE INDEX IF NOT EXISTS idx_qa_feedback_user_id ON qa_feedback(user_id);",
  72. "CREATE INDEX IF NOT EXISTS idx_qa_feedback_create_time ON qa_feedback(create_time);",
  73. "CREATE INDEX IF NOT EXISTS idx_qa_feedback_is_thumb_up ON qa_feedback(is_thumb_up);",
  74. "CREATE INDEX IF NOT EXISTS idx_qa_feedback_is_in_training ON qa_feedback(is_in_training_data);"
  75. ]
  76. try:
  77. with self.engine.connect() as conn:
  78. with conn.begin():
  79. # 创建表
  80. conn.execute(text(create_table_sql))
  81. # 创建索引
  82. for index_sql in create_indexes_sql:
  83. conn.execute(text(index_sql))
  84. self.logger.info("qa_feedback表检查/创建成功")
  85. except Exception as e:
  86. self.logger.error(f"qa_feedback表创建失败: {e}")
  87. raise
  88. def add_feedback(self, question: str, sql: str, is_thumb_up: bool, user_id: str = "guest") -> int:
  89. """添加反馈记录
  90. Args:
  91. question: 用户问题
  92. sql: 生成的SQL
  93. is_thumb_up: 是否点赞
  94. user_id: 用户ID
  95. Returns:
  96. 新创建记录的ID
  97. """
  98. insert_sql = """
  99. INSERT INTO qa_feedback (question, sql, is_thumb_up, user_id, create_time)
  100. VALUES (:question, :sql, :is_thumb_up, :user_id, :create_time)
  101. RETURNING id
  102. """
  103. try:
  104. with self.engine.connect() as conn:
  105. with conn.begin():
  106. result = conn.execute(text(insert_sql), {
  107. 'question': question,
  108. 'sql': sql,
  109. 'is_thumb_up': is_thumb_up,
  110. 'user_id': user_id,
  111. 'create_time': datetime.now()
  112. })
  113. feedback_id = result.fetchone()[0]
  114. self.logger.info(f"反馈记录创建成功, ID: {feedback_id}")
  115. return feedback_id
  116. except Exception as e:
  117. self.logger.error(f"添加反馈记录失败: {e}")
  118. raise
  119. def query_feedback(self, page: int = 1, page_size: int = 20,
  120. is_thumb_up: Optional[bool] = None,
  121. create_time_start: Optional[str] = None,
  122. create_time_end: Optional[str] = None,
  123. is_in_training_data: Optional[bool] = None,
  124. sort_by: str = "create_time",
  125. sort_order: str = "desc") -> Tuple[List[Dict], int]:
  126. """查询反馈记录
  127. Args:
  128. page: 页码 (从1开始)
  129. page_size: 每页大小
  130. is_thumb_up: 是否点赞筛选
  131. create_time_start: 创建时间开始
  132. create_time_end: 创建时间结束
  133. is_in_training_data: 是否已加入训练数据
  134. sort_by: 排序字段
  135. sort_order: 排序方向 (asc/desc)
  136. Returns:
  137. (记录列表, 总数)
  138. """
  139. # 构建WHERE条件
  140. where_conditions = []
  141. params = {}
  142. if is_thumb_up is not None:
  143. where_conditions.append("is_thumb_up = :is_thumb_up")
  144. params['is_thumb_up'] = is_thumb_up
  145. if create_time_start:
  146. where_conditions.append("create_time >= :create_time_start")
  147. params['create_time_start'] = create_time_start
  148. if create_time_end:
  149. where_conditions.append("create_time <= :create_time_end")
  150. params['create_time_end'] = create_time_end
  151. if is_in_training_data is not None:
  152. where_conditions.append("is_in_training_data = :is_in_training_data")
  153. params['is_in_training_data'] = is_in_training_data
  154. where_clause = "WHERE " + " AND ".join(where_conditions) if where_conditions else ""
  155. # 验证排序参数
  156. valid_sort_fields = ['id', 'create_time', 'update_time', 'user_id']
  157. if sort_by not in valid_sort_fields:
  158. sort_by = 'create_time'
  159. if sort_order.lower() not in ['asc', 'desc']:
  160. sort_order = 'desc'
  161. # 计算OFFSET
  162. offset = (page - 1) * page_size
  163. # 查询数据
  164. query_sql = f"""
  165. SELECT id, question, sql, is_thumb_up, user_id, create_time,
  166. is_in_training_data, update_time
  167. FROM qa_feedback
  168. {where_clause}
  169. ORDER BY {sort_by} {sort_order.upper()}
  170. LIMIT :limit OFFSET :offset
  171. """
  172. # 查询总数
  173. count_sql = f"""
  174. SELECT COUNT(*) as total
  175. FROM qa_feedback
  176. {where_clause}
  177. """
  178. try:
  179. with self.engine.connect() as conn:
  180. # 查询数据
  181. params.update({'limit': page_size, 'offset': offset})
  182. result = conn.execute(text(query_sql), params)
  183. records = []
  184. for row in result:
  185. records.append({
  186. 'id': row.id,
  187. 'question': row.question,
  188. 'sql': row.sql,
  189. 'is_thumb_up': row.is_thumb_up,
  190. 'user_id': row.user_id,
  191. 'create_time': row.create_time.isoformat() if row.create_time else None,
  192. 'is_in_training_data': row.is_in_training_data,
  193. 'update_time': row.update_time.isoformat() if row.update_time else None
  194. })
  195. # 查询总数
  196. count_result = conn.execute(text(count_sql), {k: v for k, v in params.items() if k not in ['limit', 'offset']})
  197. total = count_result.fetchone().total
  198. return records, total
  199. except Exception as e:
  200. self.logger.error(f"查询反馈记录失败: {e}")
  201. raise
  202. def delete_feedback(self, feedback_id: int) -> bool:
  203. """删除反馈记录
  204. Args:
  205. feedback_id: 反馈记录ID
  206. Returns:
  207. 删除是否成功
  208. """
  209. delete_sql = "DELETE FROM qa_feedback WHERE id = :id"
  210. try:
  211. with self.engine.connect() as conn:
  212. with conn.begin():
  213. result = conn.execute(text(delete_sql), {'id': feedback_id})
  214. if result.rowcount > 0:
  215. self.logger.info(f"反馈记录删除成功, ID: {feedback_id}")
  216. return True
  217. else:
  218. self.logger.warning(f"反馈记录不存在, ID: {feedback_id}")
  219. return False
  220. except Exception as e:
  221. self.logger.error(f"删除反馈记录失败: {e}")
  222. raise
  223. def update_feedback(self, feedback_id: int, **kwargs) -> bool:
  224. """更新反馈记录
  225. Args:
  226. feedback_id: 反馈记录ID
  227. **kwargs: 要更新的字段
  228. Returns:
  229. 更新是否成功
  230. """
  231. # 允许更新的字段
  232. allowed_fields = ['question', 'sql', 'is_thumb_up', 'user_id', 'is_in_training_data']
  233. update_fields = []
  234. params = {'id': feedback_id, 'update_time': datetime.now()}
  235. for field, value in kwargs.items():
  236. if field in allowed_fields:
  237. update_fields.append(f"{field} = :{field}")
  238. params[field] = value
  239. if not update_fields:
  240. self.logger.warning("没有有效的更新字段")
  241. return False
  242. update_fields.append("update_time = :update_time")
  243. update_sql = f"""
  244. UPDATE qa_feedback
  245. SET {', '.join(update_fields)}
  246. WHERE id = :id
  247. """
  248. try:
  249. with self.engine.connect() as conn:
  250. with conn.begin():
  251. result = conn.execute(text(update_sql), params)
  252. if result.rowcount > 0:
  253. self.logger.info(f"反馈记录更新成功, ID: {feedback_id}")
  254. return True
  255. else:
  256. self.logger.warning(f"反馈记录不存在或无变化, ID: {feedback_id}")
  257. return False
  258. except Exception as e:
  259. self.logger.error(f"更新反馈记录失败: {e}")
  260. raise
  261. def get_feedback_by_ids(self, feedback_ids: List[int]) -> List[Dict]:
  262. """根据ID列表获取反馈记录
  263. Args:
  264. feedback_ids: 反馈记录ID列表
  265. Returns:
  266. 反馈记录列表
  267. """
  268. if not feedback_ids:
  269. return []
  270. # 构建IN查询
  271. placeholders = ','.join([f':id_{i}' for i in range(len(feedback_ids))])
  272. params = {f'id_{i}': feedback_id for i, feedback_id in enumerate(feedback_ids)}
  273. query_sql = f"""
  274. SELECT id, question, sql, is_thumb_up, user_id, create_time,
  275. is_in_training_data, update_time
  276. FROM qa_feedback
  277. WHERE id IN ({placeholders})
  278. """
  279. try:
  280. with self.engine.connect() as conn:
  281. result = conn.execute(text(query_sql), params)
  282. records = []
  283. for row in result:
  284. records.append({
  285. 'id': row.id,
  286. 'question': row.question,
  287. 'sql': row.sql,
  288. 'is_thumb_up': row.is_thumb_up,
  289. 'user_id': row.user_id,
  290. 'create_time': row.create_time,
  291. 'is_in_training_data': row.is_in_training_data,
  292. 'update_time': row.update_time
  293. })
  294. return records
  295. except Exception as e:
  296. self.logger.error(f"根据ID查询反馈记录失败: {e}")
  297. raise
  298. def mark_training_status(self, feedback_ids: List[int], status: bool = True) -> int:
  299. """批量标记训练状态
  300. Args:
  301. feedback_ids: 反馈记录ID列表
  302. status: 训练状态
  303. Returns:
  304. 更新的记录数
  305. """
  306. if not feedback_ids:
  307. return 0
  308. placeholders = ','.join([f':id_{i}' for i in range(len(feedback_ids))])
  309. params = {f'id_{i}': feedback_id for i, feedback_id in enumerate(feedback_ids)}
  310. params['status'] = status
  311. params['update_time'] = datetime.now()
  312. update_sql = f"""
  313. UPDATE qa_feedback
  314. SET is_in_training_data = :status, update_time = :update_time
  315. WHERE id IN ({placeholders})
  316. """
  317. try:
  318. with self.engine.connect() as conn:
  319. with conn.begin():
  320. result = conn.execute(text(update_sql), params)
  321. self.logger.info(f"批量更新训练状态成功, 影响行数: {result.rowcount}")
  322. return result.rowcount
  323. except Exception as e:
  324. self.logger.error(f"批量更新训练状态失败: {e}")
  325. raise