- """
- Data Pipeline API 简化数据库管理器
- 复用现有的pgvector数据库连接机制,提供Data Pipeline任务的数据库操作功能
- """
- import json
- from datetime import datetime
- from typing import Dict, Any, List, Optional, Tuple
- import psycopg2
- from psycopg2.extras import RealDictCursor, Json
- from app_config import PGVECTOR_CONFIG
- from core.logging import get_data_pipeline_logger


class SimpleTaskManager:
    """Simplified task manager that reuses the existing pgvector connection."""

    def __init__(self):
        """Initialize the task manager."""
        self.logger = get_data_pipeline_logger("SimpleTaskManager")
        self._connection = None

    def _get_connection(self):
        """Get (or lazily create) the pgvector database connection."""
        if self._connection is None or self._connection.closed:
            try:
                self._connection = psycopg2.connect(
                    host=PGVECTOR_CONFIG.get('host'),
                    port=PGVECTOR_CONFIG.get('port'),
                    database=PGVECTOR_CONFIG.get('dbname'),
                    user=PGVECTOR_CONFIG.get('user'),
                    password=PGVECTOR_CONFIG.get('password')
                )
                self._connection.autocommit = True
            except Exception as e:
                self.logger.error(f"Failed to connect to the pgvector database: {e}")
                raise
        return self._connection

    def close_connection(self):
        """Close the database connection."""
        if self._connection and not self._connection.closed:
            self._connection.close()
            self._connection = None

    def generate_task_id(self) -> str:
        """Generate a task ID in the format task_YYYYMMDD_HHMMSS.

        Note: IDs have one-second resolution, so two tasks created within
        the same second would collide.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        return f"task_{timestamp}"

    def create_task(self,
                    table_list_file: str,
                    business_context: str,
                    db_name: Optional[str] = None,
                    **kwargs) -> str:
        """Create a new task and return its ID."""
        task_id = self.generate_task_id()

        # Read the business database connection settings from app_config
        from app_config import APP_DB_CONFIG

        # Build the business database connection string (recorded in the
        # task parameters and later used by schema_workflow)
        business_db_connection = self._build_db_connection_string(APP_DB_CONFIG)

        # Use the caller-supplied db_name, falling back to APP_DB_CONFIG
        if not db_name:
            db_name = APP_DB_CONFIG.get('dbname', 'business_db')

        # Assemble the task parameters
        parameters = {
            "db_connection": business_db_connection,  # business DB connection used by schema_workflow
            "table_list_file": table_list_file,
            "business_context": business_context,
            **kwargs
        }

        try:
            conn = self._get_connection()
            with conn.cursor() as cursor:
                cursor.execute("""
                    INSERT INTO data_pipeline_tasks (
                        id, task_type, status, parameters, created_by,
                        db_name, business_context, output_directory
                    ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                """, (
                    task_id,
                    'data_workflow',
                    'pending',
                    Json(parameters),
                    'api',
                    db_name,
                    business_context,
                    f"./data_pipeline/training_data/{task_id}"
                ))

            self.logger.info(f"Task created: {task_id}")
            return task_id

        except Exception as e:
            self.logger.error(f"Task creation failed: {e}")
            raise
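
    # For illustration, the stored parameters field ends up shaped like
    # this (all values here are hypothetical):
    #   {
    #       "db_connection": "postgresql://user:pass@localhost:5432/business_db",
    #       "table_list_file": "./tables.txt",
    #       "business_context": "e-commerce order analysis",
    #       ... any extra **kwargs ...
    #   }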

    def get_task(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Fetch a task record by ID."""
        try:
            conn = self._get_connection()
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute("SELECT * FROM data_pipeline_tasks WHERE id = %s", (task_id,))
                result = cursor.fetchone()
                return dict(result) if result else None
        except Exception as e:
            self.logger.error(f"Failed to fetch task info: {e}")
            raise

    def update_task_status(self, task_id: str, status: str, error_message: Optional[str] = None):
        """Update a task's status."""
        try:
            conn = self._get_connection()
            with conn.cursor() as cursor:
                update_fields = ["status = %s"]
                values = [status]

                # Record started_at only the first time the task enters in_progress
                if status == 'in_progress' and not self._get_task_started_at(task_id):
                    update_fields.append("started_at = CURRENT_TIMESTAMP")

                if status in ['completed', 'failed']:
                    update_fields.append("completed_at = CURRENT_TIMESTAMP")

                if error_message:
                    update_fields.append("error_message = %s")
                    values.append(error_message)

                values.append(task_id)

                cursor.execute(f"""
                    UPDATE data_pipeline_tasks
                    SET {', '.join(update_fields)}
                    WHERE id = %s
                """, values)

                self.logger.info(f"Task status updated: {task_id} -> {status}")
        except Exception as e:
            self.logger.error(f"Task status update failed: {e}")
            raise

    def update_step_status(self, task_id: str, step_name: str, step_status: str):
        """Update one step's status inside the step_status JSONB column."""
        try:
            conn = self._get_connection()
            with conn.cursor() as cursor:
                # COALESCE guards against a NULL step_status column, since
                # jsonb_set returns NULL when its target is NULL; the ::jsonb
                # cast makes the type of the new value explicit
                cursor.execute("""
                    UPDATE data_pipeline_tasks
                    SET step_status = jsonb_set(COALESCE(step_status, '{}'::jsonb), %s, %s::jsonb)
                    WHERE id = %s
                """, ([step_name], json.dumps(step_status), task_id))

                self.logger.debug(f"Step status updated: {task_id}.{step_name} -> {step_status}")
        except Exception as e:
            self.logger.error(f"Step status update failed: {e}")
            raise
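
    # For illustration, after update_step_status(task_id, "ddl_generation",
    # "completed") the column might hold (step names are hypothetical):
    #   {"ddl_generation": "completed", "qa_generation": "pending"}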

    def create_execution(self, task_id: str, execution_step: str) -> str:
        """Create an execution record and return its ID."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        execution_id = f"{task_id}_step_{execution_step}_exec_{timestamp}"

        try:
            conn = self._get_connection()
            with conn.cursor() as cursor:
                cursor.execute("""
                    INSERT INTO data_pipeline_task_executions (
                        task_id, execution_step, status, execution_id
                    ) VALUES (%s, %s, %s, %s)
                """, (task_id, execution_step, 'running', execution_id))

            self.logger.info(f"Execution record created: {execution_id}")
            return execution_id
        except Exception as e:
            self.logger.error(f"Failed to create execution record: {e}")
            raise

    def complete_execution(self, execution_id: str, status: str, error_message: Optional[str] = None):
        """Mark an execution record as finished."""
        try:
            conn = self._get_connection()
            with conn.cursor() as cursor:
                # Compute the duration on the database side so the result is
                # correct regardless of timezone settings or clock skew
                # between the application and the database
                cursor.execute("""
                    SELECT EXTRACT(EPOCH FROM (CURRENT_TIMESTAMP - started_at))
                    FROM data_pipeline_task_executions
                    WHERE execution_id = %s
                """, (execution_id,))
                result = cursor.fetchone()

                duration_seconds = None
                if result and result[0] is not None:
                    duration_seconds = int(result[0])

                # Update the execution record
                update_fields = ["status = %s", "completed_at = CURRENT_TIMESTAMP"]
                values = [status]

                if duration_seconds is not None:
                    update_fields.append("duration_seconds = %s")
                    values.append(duration_seconds)

                if error_message:
                    update_fields.append("error_message = %s")
                    values.append(error_message)

                values.append(execution_id)

                cursor.execute(f"""
                    UPDATE data_pipeline_task_executions
                    SET {', '.join(update_fields)}
                    WHERE execution_id = %s
                """, values)

                self.logger.info(f"Execution record completed: {execution_id} -> {status}")
        except Exception as e:
            self.logger.error(f"Failed to complete execution record: {e}")
            raise

    def record_log(self, task_id: str, log_level: str, message: str,
                   execution_id: Optional[str] = None, step_name: Optional[str] = None):
        """Write a log entry to the database.

        Deliberately swallows errors: a logging failure should never abort
        the pipeline itself.
        """
        try:
            conn = self._get_connection()
            with conn.cursor() as cursor:
                cursor.execute("""
                    INSERT INTO data_pipeline_task_logs (
                        task_id, execution_id, log_level, message, step_name
                    ) VALUES (%s, %s, %s, %s, %s)
                """, (task_id, execution_id, log_level, message, step_name))
        except Exception as e:
            self.logger.error(f"Failed to record log entry: {e}")

    def get_task_logs(self, task_id: str, limit: int = 100) -> List[Dict[str, Any]]:
        """Fetch the most recent log entries for a task."""
        try:
            conn = self._get_connection()
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute("""
                    SELECT * FROM data_pipeline_task_logs
                    WHERE task_id = %s
                    ORDER BY timestamp DESC
                    LIMIT %s
                """, (task_id, limit))

                return [dict(row) for row in cursor.fetchall()]
        except Exception as e:
            self.logger.error(f"Failed to fetch task logs: {e}")
            raise

    def get_task_executions(self, task_id: str) -> List[Dict[str, Any]]:
        """Fetch all execution records for a task, newest first."""
        try:
            conn = self._get_connection()
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute("""
                    SELECT * FROM data_pipeline_task_executions
                    WHERE task_id = %s
                    ORDER BY started_at DESC
                """, (task_id,))

                return [dict(row) for row in cursor.fetchall()]
        except Exception as e:
            self.logger.error(f"Failed to fetch execution records: {e}")
            raise

    def get_tasks_list(self, limit: int = 50, offset: int = 0, status_filter: Optional[str] = None) -> List[Dict[str, Any]]:
        """Fetch a page of tasks, optionally filtered by status."""
        try:
            conn = self._get_connection()
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                where_clause = ""
                params = []

                if status_filter:
                    where_clause = "WHERE status = %s"
                    params.append(status_filter)

                params.extend([limit, offset])

                cursor.execute(f"""
                    SELECT * FROM data_pipeline_tasks
                    {where_clause}
                    ORDER BY created_at DESC
                    LIMIT %s OFFSET %s
                """, params)

                return [dict(row) for row in cursor.fetchall()]
        except Exception as e:
            self.logger.error(f"Failed to fetch task list: {e}")
            raise

    def _get_task_started_at(self, task_id: str) -> Optional[datetime]:
        """Get a task's start time, or None if unavailable."""
        try:
            conn = self._get_connection()
            with conn.cursor() as cursor:
                cursor.execute("SELECT started_at FROM data_pipeline_tasks WHERE id = %s", (task_id,))
                result = cursor.fetchone()
                return result[0] if result and result[0] else None
        except Exception:
            return None

    def _build_db_connection_string(self, db_config: dict) -> str:
        """Build a PostgreSQL connection string from a config dict."""
        try:
            host = db_config.get('host', 'localhost')
            port = db_config.get('port', 5432)
            dbname = db_config.get('dbname', 'database')
            user = db_config.get('user', 'postgres')
            password = db_config.get('password', '')

            return f"postgresql://{user}:{password}@{host}:{port}/{dbname}"
        except Exception:
            return "postgresql://localhost:5432/database"

    def _extract_db_name(self, connection_string: str) -> str:
        """Extract the database name from a connection string."""
        try:
            if '/' in connection_string:
                db_name = connection_string.split('/')[-1]
                # Strip any query-string suffix, e.g. "...?sslmode=require"
                if '?' in db_name:
                    db_name = db_name.split('?')[0]
                return db_name if db_name else "database"
            else:
                return "database"
        except Exception:
            return "database"