- """
- Data Pipeline API 简化数据库管理器
- 复用现有的pgvector数据库连接机制,提供Data Pipeline任务的数据库操作功能
- """
- import json
- from datetime import datetime
- from typing import Dict, Any, List, Optional, Tuple
- import psycopg2
- from psycopg2.extras import RealDictCursor, Json
- from app_config import PGVECTOR_CONFIG
- import logging
class SimpleTaskManager:
    """Simplified task manager that reuses the existing pgvector connection."""

    def __init__(self):
        """Initialize the task manager."""
        # Use plain console logging; no file handler is configured.
        self.logger = logging.getLogger("SimpleTaskManager")
        self.logger.setLevel(logging.INFO)
        self._connection = None

    def _get_connection(self):
        """Get a pgvector database connection, reconnecting lazily if needed."""
        if self._connection is None or self._connection.closed:
            try:
                self._connection = psycopg2.connect(
                    host=PGVECTOR_CONFIG.get('host'),
                    port=PGVECTOR_CONFIG.get('port'),
                    database=PGVECTOR_CONFIG.get('dbname'),
                    user=PGVECTOR_CONFIG.get('user'),
                    password=PGVECTOR_CONFIG.get('password')
                )
                self._connection.autocommit = True
            except Exception as e:
                self.logger.error(f"Failed to connect to the pgvector database: {e}")
                raise
        return self._connection

    def close_connection(self):
        """Close the database connection."""
        if self._connection and not self._connection.closed:
            self._connection.close()
        self._connection = None

    def generate_task_id(self) -> str:
        """Generate a task ID in the format task_YYYYMMDD_HHMMSS."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        return f"task_{timestamp}"

    def create_task(self,
                    table_list_file: Optional[str] = None,
                    business_context: Optional[str] = None,
                    db_name: Optional[str] = None,
                    **kwargs) -> str:
        """Create a new task."""
        task_id = self.generate_task_id()

        # Fetch the business database connection info from app_config.
        from app_config import APP_DB_CONFIG

        # Build the business database connection string (recorded in the task parameters).
        business_db_connection = self._build_db_connection_string(APP_DB_CONFIG)

        # Use the supplied db_name, or fall back to APP_DB_CONFIG.
        if not db_name:
            db_name = APP_DB_CONFIG.get('dbname', 'business_db')

        # Handle the table_list_file parameter.
        # If none is provided, the executor checks the task directory for
        # table_list.txt at run time.
        task_table_list_file = table_list_file
        if not task_table_list_file:
            from data_pipeline.config import SCHEMA_TOOLS_CONFIG
            upload_config = SCHEMA_TOOLS_CONFIG.get("file_upload", {})
            target_filename = upload_config.get("target_filename", "table_list.txt")
            # Path relative to the task directory; "{task_directory}" is a
            # placeholder resolved at execution time.
            task_table_list_file = f"{{task_directory}}/{target_filename}"

        # Assemble the task parameters.
        parameters = {
            "db_connection": business_db_connection,  # business DB connection (used when schema_workflow runs)
            "table_list_file": task_table_list_file,
            "business_context": business_context or "数据库管理系统",  # default: "database management system"
            "file_upload_mode": table_list_file is None,  # flags whether file-upload mode is used
            **kwargs
        }
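        # For illustration only, a stored parameters record might look like
        # this (all values are hypothetical):
        # {
        #     "db_connection": "postgresql://postgres:***@localhost:5432/business_db",
        #     "table_list_file": "{task_directory}/table_list.txt",
        #     "business_context": "数据库管理系统",
        #     "file_upload_mode": true
        # }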

        try:
            conn = self._get_connection()
            with conn.cursor() as cursor:
                # Create the task record.
                cursor.execute("""
                    INSERT INTO data_pipeline_tasks (
                        task_id, task_type, status, parameters, created_type,
                        by_user, db_name, output_directory
                    ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                """, (
                    task_id,
                    'data_workflow',
                    'pending',
                    Json(parameters),
                    'api',
                    'guest',
                    db_name,
                    f"data_pipeline/training_data/{task_id}"
                ))

                # Pre-create all step records (strategy A).
                step_names = ['ddl_generation', 'qa_generation', 'sql_validation', 'training_load']
                for step_name in step_names:
                    cursor.execute("""
                        INSERT INTO data_pipeline_task_steps (
                            task_id, step_name, step_status
                        ) VALUES (%s, %s, %s)
                    """, (task_id, step_name, 'pending'))

            # Create the task directory.
            try:
                from data_pipeline.api.simple_file_manager import SimpleFileManager
                file_manager = SimpleFileManager()
                success = file_manager.create_task_directory(task_id)
                if success:
                    self.logger.info(f"Task directory created: {task_id}")
                else:
                    self.logger.warning(f"Task directory creation failed, but the task record was saved: {task_id}")
            except Exception as dir_error:
                self.logger.warning(f"Error creating task directory: {dir_error}; the task record was saved: {task_id}")

            self.logger.info(f"Task created: {task_id}")
            return task_id

        except Exception as e:
            self.logger.error(f"Task creation failed: {e}")
            raise

    def get_task(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Fetch task information."""
        try:
            conn = self._get_connection()
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute("SELECT * FROM data_pipeline_tasks WHERE task_id = %s", (task_id,))
                result = cursor.fetchone()
                return dict(result) if result else None
        except Exception as e:
            self.logger.error(f"Failed to fetch task info: {e}")
            raise

    def update_task_status(self, task_id: str, status: str, error_message: Optional[str] = None):
        """Update the task status."""
        try:
            conn = self._get_connection()
            with conn.cursor() as cursor:
                update_fields = ["status = %s"]
                values = [status]

                if status == 'in_progress' and not self._get_task_started_at(task_id):
                    update_fields.append("started_at = CURRENT_TIMESTAMP")

                if status in ['completed', 'failed']:
                    update_fields.append("completed_at = CURRENT_TIMESTAMP")

                if error_message:
                    update_fields.append("error_message = %s")
                    values.append(error_message)

                values.append(task_id)

                cursor.execute(f"""
                    UPDATE data_pipeline_tasks
                    SET {', '.join(update_fields)}
                    WHERE task_id = %s
                """, values)

                self.logger.info(f"Task status updated: {task_id} -> {status}")
        except Exception as e:
            self.logger.error(f"Failed to update task status: {e}")
            raise
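
    # For illustration, a hypothetical call update_task_status(task_id, 'failed', 'boom')
    # builds the statement below (the field list is assembled from literals above,
    # so the f-string interpolation cannot inject user input):
    #   UPDATE data_pipeline_tasks
    #   SET status = %s, completed_at = CURRENT_TIMESTAMP, error_message = %s
    #   WHERE task_id = %s
    # with values ['failed', 'boom', task_id].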

    def update_step_status(self, task_id: str, step_name: str, step_status: str, error_message: Optional[str] = None):
        """Update a step's status."""
        try:
            conn = self._get_connection()
            with conn.cursor() as cursor:
                update_fields = ["step_status = %s"]
                values = [step_status]

                # Record the start time when the step enters the running state.
                if step_status == 'running':
                    update_fields.append("started_at = CURRENT_TIMESTAMP")

                # Record the completion time when the step completes or fails.
                if step_status in ['completed', 'failed']:
                    update_fields.append("completed_at = CURRENT_TIMESTAMP")

                # Record the error message, if any.
                if error_message:
                    update_fields.append("error_message = %s")
                    values.append(error_message)

                values.extend([task_id, step_name])

                cursor.execute(f"""
                    UPDATE data_pipeline_task_steps
                    SET {', '.join(update_fields)}
                    WHERE task_id = %s AND step_name = %s
                """, values)

                self.logger.debug(f"Step status updated: {task_id}.{step_name} -> {step_status}")
        except Exception as e:
            self.logger.error(f"Failed to update step status: {e}")
            raise

    def update_step_execution_id(self, task_id: str, step_name: str, execution_id: str):
        """Update a step's execution_id."""
        try:
            conn = self._get_connection()
            with conn.cursor() as cursor:
                cursor.execute("""
                    UPDATE data_pipeline_task_steps
                    SET execution_id = %s
                    WHERE task_id = %s AND step_name = %s
                """, (execution_id, task_id, step_name))

                self.logger.debug(f"Step execution_id updated: {task_id}.{step_name} -> {execution_id}")
        except Exception as e:
            self.logger.error(f"Failed to update step execution_id: {e}")
            raise

    def start_step(self, task_id: str, step_name: str) -> str:
        """Start executing a step."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        execution_id = f"{task_id}_step_{step_name}_exec_{timestamp}"

        try:
            # Mark the step as running and set its execution_id.
            self.update_step_status(task_id, step_name, 'running')
            self.update_step_execution_id(task_id, step_name, execution_id)

            self.logger.info(f"Step started: {task_id}.{step_name} -> {execution_id}")
            return execution_id
        except Exception as e:
            self.logger.error(f"Failed to start step: {e}")
            raise

    def complete_step(self, task_id: str, step_name: str, status: str, error_message: Optional[str] = None):
        """Finish executing a step."""
        try:
            self.update_step_status(task_id, step_name, status, error_message)
            self.logger.info(f"Step finished: {task_id}.{step_name} -> {status}")
        except Exception as e:
            self.logger.error(f"Failed to finish step: {e}")
            raise

    def get_task_steps(self, task_id: str) -> List[Dict[str, Any]]:
        """Fetch the status of all steps for a task."""
        try:
            conn = self._get_connection()
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute("""
                    SELECT * FROM data_pipeline_task_steps
                    WHERE task_id = %s
                    ORDER BY
                        CASE step_name
                            WHEN 'ddl_generation' THEN 1
                            WHEN 'qa_generation' THEN 2
                            WHEN 'sql_validation' THEN 3
                            WHEN 'training_load' THEN 4
                            ELSE 5
                        END
                """, (task_id,))

                return [dict(row) for row in cursor.fetchall()]
        except Exception as e:
            self.logger.error(f"Failed to fetch task step statuses: {e}")
            raise

    def get_step_status(self, task_id: str, step_name: str) -> Optional[Dict[str, Any]]:
        """Fetch the status of a specific step."""
        try:
            conn = self._get_connection()
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute("""
                    SELECT * FROM data_pipeline_task_steps
                    WHERE task_id = %s AND step_name = %s
                """, (task_id, step_name))

                result = cursor.fetchone()
                return dict(result) if result else None
        except Exception as e:
            self.logger.error(f"Failed to fetch step status: {e}")
            raise

    def get_tasks_list(self, limit: int = 50, offset: int = 0, status_filter: Optional[str] = None) -> List[Dict[str, Any]]:
        """Fetch the task list."""
        try:
            conn = self._get_connection()
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                where_clause = ""
                params = []

                if status_filter:
                    where_clause = "WHERE status = %s"
                    params.append(status_filter)

                params.extend([limit, offset])

                cursor.execute(f"""
                    SELECT * FROM data_pipeline_tasks
                    {where_clause}
                    ORDER BY created_at DESC
                    LIMIT %s OFFSET %s
                """, params)

                return [dict(row) for row in cursor.fetchall()]
        except Exception as e:
            self.logger.error(f"Failed to fetch task list: {e}")
            raise

    def _get_task_started_at(self, task_id: str) -> Optional[datetime]:
        """Fetch the task start time."""
        try:
            conn = self._get_connection()
            with conn.cursor() as cursor:
                cursor.execute("SELECT started_at FROM data_pipeline_tasks WHERE task_id = %s", (task_id,))
                result = cursor.fetchone()
                return result[0] if result and result[0] else None
        except Exception:
            return None

    def _build_db_connection_string(self, db_config: dict) -> str:
        """Build a database connection string."""
        try:
            host = db_config.get('host', 'localhost')
            port = db_config.get('port', 5432)
            dbname = db_config.get('dbname', 'database')
            user = db_config.get('user', 'postgres')
            password = db_config.get('password', '')

            return f"postgresql://{user}:{password}@{host}:{port}/{dbname}"
        except Exception:
            return "postgresql://localhost:5432/database"

    def _extract_db_name(self, connection_string: str) -> str:
        """Extract the database name from a connection string."""
        try:
            if '/' in connection_string:
                db_name = connection_string.split('/')[-1]
                if '?' in db_name:
                    db_name = db_name.split('?')[0]
                return db_name if db_name else "database"
            else:
                return "database"
        except Exception:
            return "database"