simple_db_manager.py

  1. """
  2. Data Pipeline API 简化数据库管理器
  3. 复用现有的pgvector数据库连接机制,提供Data Pipeline任务的数据库操作功能
  4. """
  5. import json
  6. from datetime import datetime
  7. from typing import Dict, Any, List, Optional, Tuple
  8. import psycopg2
  9. from psycopg2.extras import RealDictCursor, Json
  10. from app_config import PGVECTOR_CONFIG
  11. import logging
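
# Assumed layout of the two tables this manager reads and writes, inferred
# from the SQL statements below; the real DDL in the deployment may differ
# (column types and constraints are not visible from this module):
#
#   data_pipeline_tasks(task_id, task_type, status, parameters, created_type,
#                       by_user, db_name, output_directory, error_message,
#                       created_at, started_at, completed_at)
#   data_pipeline_task_steps(task_id, step_name, step_status, execution_id,
#                            error_message, started_at, completed_at)
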
class SimpleTaskManager:
    """Simplified task manager that reuses the existing pgvector connection."""

    def __init__(self):
        """Initialize the task manager."""
        # Use a plain console logger; file logging is intentionally not used.
        self.logger = logging.getLogger("SimpleTaskManager")
        self.logger.setLevel(logging.INFO)
        self._connection = None
    def _get_connection(self):
        """Get the pgvector database connection, reconnecting if it was closed."""
        if self._connection is None or self._connection.closed:
            try:
                self._connection = psycopg2.connect(
                    host=PGVECTOR_CONFIG.get('host'),
                    port=PGVECTOR_CONFIG.get('port'),
                    database=PGVECTOR_CONFIG.get('dbname'),
                    user=PGVECTOR_CONFIG.get('user'),
                    password=PGVECTOR_CONFIG.get('password')
                )
                self._connection.autocommit = True
            except Exception as e:
                self.logger.error(f"Failed to connect to the pgvector database: {e}")
                raise
        return self._connection
    def close_connection(self):
        """Close the database connection."""
        if self._connection and not self._connection.closed:
            self._connection.close()
        self._connection = None
    def generate_task_id(self) -> str:
        """Generate a task ID in the format task_YYYYMMDD_HHMMSS."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        return f"task_{timestamp}"
    def create_task(self,
                    table_list_file: Optional[str] = None,
                    business_context: Optional[str] = None,
                    db_name: Optional[str] = None,
                    **kwargs) -> str:
        """Create a new task."""
        task_id = self.generate_task_id()

        # Get the business database connection info from app_config.
        from app_config import APP_DB_CONFIG

        # Build the business database connection string (recorded in the task parameters).
        business_db_connection = self._build_db_connection_string(APP_DB_CONFIG)

        # Use the caller-supplied db_name, or fall back to APP_DB_CONFIG.
        if not db_name:
            db_name = APP_DB_CONFIG.get('dbname', 'business_db')

        # Handle the table_list_file parameter.
        # If it is not provided, the table_list.txt file in the task directory
        # is looked up at execution time.
        task_table_list_file = table_list_file
        if not task_table_list_file:
            from data_pipeline.config import SCHEMA_TOOLS_CONFIG
            upload_config = SCHEMA_TOOLS_CONFIG.get("file_upload", {})
            target_filename = upload_config.get("target_filename", "table_list.txt")
            # Use a path relative to the task directory.
            task_table_list_file = f"{{task_directory}}/{target_filename}"

        # Build the task parameters.
        parameters = {
            "db_connection": business_db_connection,  # business DB connection (used when schema_workflow runs)
            "table_list_file": task_table_list_file,
            "business_context": business_context or "数据库管理系统",
            "file_upload_mode": table_list_file is None,  # marks whether file-upload mode is used
            **kwargs
        }

        try:
            conn = self._get_connection()
            with conn.cursor() as cursor:
                # Create the task record.
                cursor.execute("""
                    INSERT INTO data_pipeline_tasks (
                        task_id, task_type, status, parameters, created_type,
                        by_user, db_name, output_directory
                    ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                """, (
                    task_id,
                    'data_workflow',
                    'pending',
                    Json(parameters),
                    'api',
                    'guest',
                    db_name,
                    f"data_pipeline/training_data/{task_id}"
                ))

                # Pre-create all step records (strategy A).
                step_names = ['ddl_generation', 'qa_generation', 'sql_validation', 'training_load']
                for step_name in step_names:
                    cursor.execute("""
                        INSERT INTO data_pipeline_task_steps (
                            task_id, step_name, step_status
                        ) VALUES (%s, %s, %s)
                    """, (task_id, step_name, 'pending'))

            # Create the task directory.
            try:
                from data_pipeline.api.simple_file_manager import SimpleFileManager
                file_manager = SimpleFileManager()
                success = file_manager.create_task_directory(task_id)
                if success:
                    self.logger.info(f"Task directory created: {task_id}")
                else:
                    self.logger.warning(f"Task directory creation failed, but the task record was saved: {task_id}")
            except Exception as dir_error:
                self.logger.warning(f"Error creating the task directory: {dir_error}; the task record was saved: {task_id}")

            self.logger.info(f"Task created: {task_id}")
            return task_id
        except Exception as e:
            self.logger.error(f"Task creation failed: {e}")
            raise
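    # Illustrative shape of the parameters document stored by create_task()
    # (values are examples, not real output; note that "db_connection" embeds
    # the business database password, so treat the stored row as sensitive):
    #
    #   {
    #       "db_connection": "postgresql://postgres:***@localhost:5432/business_db",
    #       "table_list_file": "{task_directory}/table_list.txt",
    #       "business_context": "数据库管理系统",
    #       "file_upload_mode": true
    #   }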
    def get_task(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Get task information."""
        try:
            conn = self._get_connection()
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute("SELECT * FROM data_pipeline_tasks WHERE task_id = %s", (task_id,))
                result = cursor.fetchone()
                return dict(result) if result else None
        except Exception as e:
            self.logger.error(f"Failed to get task information: {e}")
            raise
    def update_task_status(self, task_id: str, status: str, error_message: Optional[str] = None):
        """Update the task status."""
        try:
            conn = self._get_connection()
            with conn.cursor() as cursor:
                update_fields = ["status = %s"]
                values = [status]

                # Record the start time the first time the task enters in_progress.
                if status == 'in_progress' and not self._get_task_started_at(task_id):
                    update_fields.append("started_at = CURRENT_TIMESTAMP")

                # Record the completion time for terminal states.
                if status in ['completed', 'failed']:
                    update_fields.append("completed_at = CURRENT_TIMESTAMP")

                if error_message:
                    update_fields.append("error_message = %s")
                    values.append(error_message)

                values.append(task_id)

                cursor.execute(f"""
                    UPDATE data_pipeline_tasks
                    SET {', '.join(update_fields)}
                    WHERE task_id = %s
                """, values)
            self.logger.info(f"Task status updated: {task_id} -> {status}")
        except Exception as e:
            self.logger.error(f"Task status update failed: {e}")
            raise
    def update_step_status(self, task_id: str, step_name: str, step_status: str, error_message: Optional[str] = None):
        """Update a step's status."""
        try:
            conn = self._get_connection()
            with conn.cursor() as cursor:
                update_fields = ["step_status = %s"]
                values = [step_status]

                # Record the start time when the step enters the running state.
                if step_status == 'running':
                    update_fields.append("started_at = CURRENT_TIMESTAMP")

                # Record the completion time for completed or failed steps.
                if step_status in ['completed', 'failed']:
                    update_fields.append("completed_at = CURRENT_TIMESTAMP")

                # Record the error message if one was given.
                if error_message:
                    update_fields.append("error_message = %s")
                    values.append(error_message)

                values.extend([task_id, step_name])

                cursor.execute(f"""
                    UPDATE data_pipeline_task_steps
                    SET {', '.join(update_fields)}
                    WHERE task_id = %s AND step_name = %s
                """, values)
            self.logger.debug(f"Step status updated: {task_id}.{step_name} -> {step_status}")
        except Exception as e:
            self.logger.error(f"Step status update failed: {e}")
            raise
    def update_step_execution_id(self, task_id: str, step_name: str, execution_id: str):
        """Update a step's execution_id."""
        try:
            conn = self._get_connection()
            with conn.cursor() as cursor:
                cursor.execute("""
                    UPDATE data_pipeline_task_steps
                    SET execution_id = %s
                    WHERE task_id = %s AND step_name = %s
                """, (execution_id, task_id, step_name))
            self.logger.debug(f"Step execution_id updated: {task_id}.{step_name} -> {execution_id}")
        except Exception as e:
            self.logger.error(f"Step execution_id update failed: {e}")
            raise
    def start_step(self, task_id: str, step_name: str) -> str:
        """Start executing a step and return its execution_id."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        execution_id = f"{task_id}_step_{step_name}_exec_{timestamp}"
        try:
            # Mark the step as running and record its execution_id.
            self.update_step_status(task_id, step_name, 'running')
            self.update_step_execution_id(task_id, step_name, execution_id)
            self.logger.info(f"Step started: {task_id}.{step_name} -> {execution_id}")
            return execution_id
        except Exception as e:
            self.logger.error(f"Failed to start step: {e}")
            raise
    def complete_step(self, task_id: str, step_name: str, status: str, error_message: Optional[str] = None):
        """Finish executing a step."""
        try:
            self.update_step_status(task_id, step_name, status, error_message)
            self.logger.info(f"Step finished: {task_id}.{step_name} -> {status}")
        except Exception as e:
            self.logger.error(f"Failed to finish step: {e}")
            raise
    def get_task_steps(self, task_id: str) -> List[Dict[str, Any]]:
        """Get the statuses of all steps of a task, in pipeline order."""
        try:
            conn = self._get_connection()
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute("""
                    SELECT * FROM data_pipeline_task_steps
                    WHERE task_id = %s
                    ORDER BY
                        CASE step_name
                            WHEN 'ddl_generation' THEN 1
                            WHEN 'qa_generation' THEN 2
                            WHEN 'sql_validation' THEN 3
                            WHEN 'training_load' THEN 4
                            ELSE 5
                        END
                """, (task_id,))
                return [dict(row) for row in cursor.fetchall()]
        except Exception as e:
            self.logger.error(f"Failed to get task step statuses: {e}")
            raise
    def get_step_status(self, task_id: str, step_name: str) -> Optional[Dict[str, Any]]:
        """Get the status of a specific step."""
        try:
            conn = self._get_connection()
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute("""
                    SELECT * FROM data_pipeline_task_steps
                    WHERE task_id = %s AND step_name = %s
                """, (task_id, step_name))
                result = cursor.fetchone()
                return dict(result) if result else None
        except Exception as e:
            self.logger.error(f"Failed to get step status: {e}")
            raise
    def get_tasks_list(self, limit: int = 50, offset: int = 0, status_filter: Optional[str] = None) -> List[Dict[str, Any]]:
        """Get the task list, newest first."""
        try:
            conn = self._get_connection()
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                where_clause = ""
                params = []
                if status_filter:
                    where_clause = "WHERE status = %s"
                    params.append(status_filter)
                params.extend([limit, offset])

                cursor.execute(f"""
                    SELECT * FROM data_pipeline_tasks
                    {where_clause}
                    ORDER BY created_at DESC
                    LIMIT %s OFFSET %s
                """, params)
                return [dict(row) for row in cursor.fetchall()]
        except Exception as e:
            self.logger.error(f"Failed to get task list: {e}")
            raise
    def _get_task_started_at(self, task_id: str) -> Optional[datetime]:
        """Get a task's start time, or None if it has not started."""
        try:
            conn = self._get_connection()
            with conn.cursor() as cursor:
                cursor.execute("SELECT started_at FROM data_pipeline_tasks WHERE task_id = %s", (task_id,))
                result = cursor.fetchone()
                return result[0] if result and result[0] else None
        except Exception:
            return None
    def _build_db_connection_string(self, db_config: dict) -> str:
        """Build a PostgreSQL connection string from a config dict."""
        try:
            host = db_config.get('host', 'localhost')
            port = db_config.get('port', 5432)
            dbname = db_config.get('dbname', 'database')
            user = db_config.get('user', 'postgres')
            password = db_config.get('password', '')
            return f"postgresql://{user}:{password}@{host}:{port}/{dbname}"
        except Exception:
            return "postgresql://localhost:5432/database"
    def _extract_db_name(self, connection_string: str) -> str:
        """Extract the database name from a connection string."""
        try:
            if '/' in connection_string:
                db_name = connection_string.split('/')[-1]
                # Strip any query parameters, e.g. "...dbname?sslmode=require".
                if '?' in db_name:
                    db_name = db_name.split('?')[0]
                return db_name if db_name else "database"
            else:
                return "database"
        except Exception:
            return "database"
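

# Minimal usage sketch (an assumption, not part of the original module): it
# presumes a reachable pgvector instance configured in app_config and the two
# tables sketched above. The step run here stands in for the real
# ddl_generation work, which this manager does not perform itself.
if __name__ == "__main__":
    manager = SimpleTaskManager()
    task_id = manager.create_task(business_context="demo context")
    manager.update_task_status(task_id, 'in_progress')

    execution_id = manager.start_step(task_id, 'ddl_generation')
    # ... the actual DDL-generation work would run here ...
    manager.complete_step(task_id, 'ddl_generation', 'completed')

    manager.update_task_status(task_id, 'completed')
    print(manager.get_task(task_id))
    manager.close_connection()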