# simple_db_manager.py

  1. """
  2. Data Pipeline API 简化数据库管理器
  3. 复用现有的pgvector数据库连接机制,提供Data Pipeline任务的数据库操作功能
  4. """
  5. import json
  6. from datetime import datetime
  7. from typing import Dict, Any, List, Optional, Tuple
  8. import psycopg2
  9. from psycopg2.extras import RealDictCursor, Json
  10. from app_config import PGVECTOR_CONFIG
  11. from core.logging import get_data_pipeline_logger
  12. class SimpleTaskManager:
  13. """简化的任务管理器,复用现有pgvector连接"""
  14. def __init__(self):
  15. """初始化任务管理器"""
  16. self.logger = get_data_pipeline_logger("SimpleTaskManager")
  17. self._connection = None
    def _get_connection(self):
        """Return the pgvector database connection, creating it if needed."""
        if self._connection is None or self._connection.closed:
            try:
                self._connection = psycopg2.connect(
                    host=PGVECTOR_CONFIG.get('host'),
                    port=PGVECTOR_CONFIG.get('port'),
                    database=PGVECTOR_CONFIG.get('dbname'),
                    user=PGVECTOR_CONFIG.get('user'),
                    password=PGVECTOR_CONFIG.get('password')
                )
                self._connection.autocommit = True
            except Exception as e:
                self.logger.error(f"Failed to connect to the pgvector database: {e}")
                raise
        return self._connection

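    # Note on the design: autocommit=True makes every statement commit as soon
    # as it executes, so no explicit transaction management is needed below.
    # The trade-off (an observation, not stated in the original) is that
    # multi-statement operations are not atomic.
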
    def close_connection(self):
        """Close the database connection."""
        if self._connection and not self._connection.closed:
            self._connection.close()
        self._connection = None

    def generate_task_id(self) -> str:
        """Generate a task ID in the format task_YYYYMMDD_HHMMSS."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        return f"task_{timestamp}"

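    # Caveat (an observation, not from the original code): task IDs have only
    # second resolution, so two tasks created within the same second would
    # collide. Appending a short random suffix, e.g. uuid.uuid4().hex[:6],
    # would be one way to disambiguate if that ever matters.
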
    def create_task(self,
                    table_list_file: str,
                    business_context: str,
                    db_name: Optional[str] = None,
                    **kwargs) -> str:
        """Create a new task and return its ID."""
        task_id = self.generate_task_id()

        # Fetch the business database connection info from app_config
        from app_config import APP_DB_CONFIG

        # Build the business database connection string (recorded in the task parameters)
        business_db_connection = self._build_db_connection_string(APP_DB_CONFIG)

        # Use the provided db_name, or fall back to APP_DB_CONFIG
        if not db_name:
            db_name = APP_DB_CONFIG.get('dbname', 'business_db')

        # Build the task parameters
        parameters = {
            "db_connection": business_db_connection,  # business database connection (used when executing schema_workflow)
            "table_list_file": table_list_file,
            "business_context": business_context,
            **kwargs
        }

        try:
            conn = self._get_connection()
            with conn.cursor() as cursor:
                cursor.execute("""
                    INSERT INTO data_pipeline_tasks (
                        id, task_type, status, parameters, created_by,
                        db_name, business_context, output_directory
                    ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                """, (
                    task_id,
                    'data_workflow',
                    'pending',
                    Json(parameters),
                    'api',
                    db_name,
                    business_context,
                    f"./data_pipeline/training_data/{task_id}"
                ))
            self.logger.info(f"Task created: {task_id}")
            return task_id
        except Exception as e:
            self.logger.error(f"Failed to create task: {e}")
            raise

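    # Illustrative example (all values hypothetical): a call such as
    #   create_task(table_list_file="./tables.txt",
    #               business_context="e-commerce orders")
    # would store a parameters JSON along the lines of
    #   {"db_connection": "postgresql://user:pass@host:5432/business_db",
    #    "table_list_file": "./tables.txt",
    #    "business_context": "e-commerce orders"}
    # with any extra **kwargs merged in alongside these keys.
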
    def get_task(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Fetch a task record by ID, or None if it does not exist."""
        try:
            conn = self._get_connection()
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute("SELECT * FROM data_pipeline_tasks WHERE id = %s", (task_id,))
                result = cursor.fetchone()
                return dict(result) if result else None
        except Exception as e:
            self.logger.error(f"Failed to fetch task: {e}")
            raise

    def update_task_status(self, task_id: str, status: str, error_message: Optional[str] = None):
        """Update the task status."""
        try:
            conn = self._get_connection()
            with conn.cursor() as cursor:
                update_fields = ["status = %s"]
                values = [status]

                # Stamp started_at on the first transition to in_progress
                if status == 'in_progress' and not self._get_task_started_at(task_id):
                    update_fields.append("started_at = CURRENT_TIMESTAMP")

                if status in ['completed', 'failed']:
                    update_fields.append("completed_at = CURRENT_TIMESTAMP")

                if error_message:
                    update_fields.append("error_message = %s")
                    values.append(error_message)

                values.append(task_id)

                # The field names above are hard-coded, so building the SET
                # clause with an f-string is safe from SQL injection here.
                cursor.execute(f"""
                    UPDATE data_pipeline_tasks
                    SET {', '.join(update_fields)}
                    WHERE id = %s
                """, values)
            self.logger.info(f"Task status updated: {task_id} -> {status}")
        except Exception as e:
            self.logger.error(f"Failed to update task status: {e}")
            raise

    def update_step_status(self, task_id: str, step_name: str, step_status: str):
        """Update the status of a single step inside the step_status JSONB column."""
        try:
            conn = self._get_connection()
            with conn.cursor() as cursor:
                # COALESCE guards against a NULL step_status column (jsonb_set
                # returns NULL on NULL input), and the explicit ::jsonb cast
                # keeps the new value's type unambiguous.
                cursor.execute("""
                    UPDATE data_pipeline_tasks
                    SET step_status = jsonb_set(COALESCE(step_status, '{}'::jsonb), %s, %s::jsonb)
                    WHERE id = %s
                """, ([step_name], json.dumps(step_status), task_id))
            self.logger.debug(f"Step status updated: {task_id}.{step_name} -> {step_status}")
        except Exception as e:
            self.logger.error(f"Failed to update step status: {e}")
            raise

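    # Illustrative effect (hypothetical values): after
    # update_step_status("task_20250101_120000", "ddl_generation", "completed"),
    # the row's step_status JSONB might read {"ddl_generation": "completed", ...},
    # with entries for any other steps left untouched.
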
    def create_execution(self, task_id: str, execution_step: str) -> str:
        """Create an execution record and return its ID."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        execution_id = f"{task_id}_step_{execution_step}_exec_{timestamp}"

        try:
            conn = self._get_connection()
            with conn.cursor() as cursor:
                cursor.execute("""
                    INSERT INTO data_pipeline_task_executions (
                        task_id, execution_step, status, execution_id
                    ) VALUES (%s, %s, %s, %s)
                """, (task_id, execution_step, 'running', execution_id))
            self.logger.info(f"Execution record created: {execution_id}")
            return execution_id
        except Exception as e:
            self.logger.error(f"Failed to create execution record: {e}")
            raise

    def complete_execution(self, execution_id: str, status: str, error_message: Optional[str] = None):
        """Finalize an execution record."""
        try:
            conn = self._get_connection()
            with conn.cursor() as cursor:
                # Compute the execution duration. This assumes the application
                # clock and the stored started_at timestamps are comparable
                # (same time zone / both naive datetimes).
                cursor.execute("""
                    SELECT started_at FROM data_pipeline_task_executions
                    WHERE execution_id = %s
                """, (execution_id,))
                result = cursor.fetchone()

                duration_seconds = None
                if result and result[0]:
                    duration_seconds = int((datetime.now() - result[0]).total_seconds())

                # Update the execution record
                update_fields = ["status = %s", "completed_at = CURRENT_TIMESTAMP"]
                values = [status]

                if duration_seconds is not None:
                    update_fields.append("duration_seconds = %s")
                    values.append(duration_seconds)

                if error_message:
                    update_fields.append("error_message = %s")
                    values.append(error_message)

                values.append(execution_id)

                cursor.execute(f"""
                    UPDATE data_pipeline_task_executions
                    SET {', '.join(update_fields)}
                    WHERE execution_id = %s
                """, values)
            self.logger.info(f"Execution record finalized: {execution_id} -> {status}")
        except Exception as e:
            self.logger.error(f"Failed to finalize execution record: {e}")
            raise

    def record_log(self, task_id: str, log_level: str, message: str,
                   execution_id: Optional[str] = None, step_name: Optional[str] = None):
        """Write a log entry to the database."""
        try:
            conn = self._get_connection()
            with conn.cursor() as cursor:
                cursor.execute("""
                    INSERT INTO data_pipeline_task_logs (
                        task_id, execution_id, log_level, message, step_name
                    ) VALUES (%s, %s, %s, %s, %s)
                """, (task_id, execution_id, log_level, message, step_name))
        except Exception as e:
            # Deliberately swallowed: a failed log write should not abort the task.
            self.logger.error(f"Failed to write log entry: {e}")

    def get_task_logs(self, task_id: str, limit: int = 100) -> List[Dict[str, Any]]:
        """Fetch the most recent log entries for a task."""
        try:
            conn = self._get_connection()
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute("""
                    SELECT * FROM data_pipeline_task_logs
                    WHERE task_id = %s
                    ORDER BY timestamp DESC
                    LIMIT %s
                """, (task_id, limit))
                return [dict(row) for row in cursor.fetchall()]
        except Exception as e:
            self.logger.error(f"Failed to fetch task logs: {e}")
            raise

    def get_task_executions(self, task_id: str) -> List[Dict[str, Any]]:
        """Fetch the execution records for a task."""
        try:
            conn = self._get_connection()
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute("""
                    SELECT * FROM data_pipeline_task_executions
                    WHERE task_id = %s
                    ORDER BY started_at DESC
                """, (task_id,))
                return [dict(row) for row in cursor.fetchall()]
        except Exception as e:
            self.logger.error(f"Failed to fetch execution records: {e}")
            raise

    def get_tasks_list(self, limit: int = 50, offset: int = 0, status_filter: Optional[str] = None) -> List[Dict[str, Any]]:
        """List tasks, optionally filtered by status."""
        try:
            conn = self._get_connection()
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                where_clause = ""
                params = []

                if status_filter:
                    where_clause = "WHERE status = %s"
                    params.append(status_filter)

                params.extend([limit, offset])

                cursor.execute(f"""
                    SELECT * FROM data_pipeline_tasks
                    {where_clause}
                    ORDER BY created_at DESC
                    LIMIT %s OFFSET %s
                """, params)
                return [dict(row) for row in cursor.fetchall()]
        except Exception as e:
            self.logger.error(f"Failed to list tasks: {e}")
            raise

    def _get_task_started_at(self, task_id: str) -> Optional[datetime]:
        """Fetch the task's started_at timestamp, or None if unset or on error."""
        try:
            conn = self._get_connection()
            with conn.cursor() as cursor:
                cursor.execute("SELECT started_at FROM data_pipeline_tasks WHERE id = %s", (task_id,))
                result = cursor.fetchone()
                return result[0] if result and result[0] else None
        except Exception:
            return None

    def _build_db_connection_string(self, db_config: dict) -> str:
        """Build a PostgreSQL connection string from a config dict."""
        try:
            host = db_config.get('host', 'localhost')
            port = db_config.get('port', 5432)
            dbname = db_config.get('dbname', 'database')
            user = db_config.get('user', 'postgres')
            password = db_config.get('password', '')
            # Percent-encode the credentials so special characters (e.g. '@'
            # or ':' in the password) do not corrupt the URI.
            return f"postgresql://{quote_plus(user)}:{quote_plus(password)}@{host}:{port}/{dbname}"
        except Exception:
            return "postgresql://localhost:5432/database"

    def _extract_db_name(self, connection_string: str) -> str:
        """Extract the database name from a connection string."""
        try:
            if '/' in connection_string:
                db_name = connection_string.split('/')[-1]
                # Drop any query-string suffix such as "?sslmode=require"
                if '?' in db_name:
                    db_name = db_name.split('?')[0]
                return db_name if db_name else "database"
            return "database"
        except Exception:
            return "database"
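

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes app_config and the data_pipeline_* tables already exist;
# "./tables.txt" is a hypothetical table-list file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    manager = SimpleTaskManager()
    try:
        task_id = manager.create_task(
            table_list_file="./tables.txt",  # hypothetical path
            business_context="demo business context",
        )
        manager.update_task_status(task_id, "in_progress")

        execution_id = manager.create_execution(task_id, "ddl_generation")
        manager.record_log(task_id, "INFO", "DDL generation started",
                           execution_id=execution_id, step_name="ddl_generation")
        manager.complete_execution(execution_id, "completed")
        manager.update_step_status(task_id, "ddl_generation", "completed")

        manager.update_task_status(task_id, "completed")
        print(manager.get_task(task_id))
    finally:
        manager.close_connection()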