123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358 |
- import asyncio
- import time
- import os
- from datetime import datetime
- from pathlib import Path
- from typing import Dict, Any, List
- import psycopg2
- import logging
class VectorTableManager:
    """Back up (to CSV) and truncate the pgvector tables of a pipeline task.

    Settings come from the ``vector_table_management`` section of
    ``data_pipeline.config.SCHEMA_TOOLS_CONFIG``. Keys read by this class, with
    the defaults used when they are absent:

    * ``backup_directory``: sub-directory of ``task_output_dir`` for backups
      (``"vector_bak"``);
    * ``timestamp_format``: ``strftime`` format embedded in backup file names
      (``"%Y%m%d_%H%M%S"``);
    * ``supported_tables``: tables to back up
      (``["langchain_pg_collection", "langchain_pg_embedding"]``);
    * ``truncate_tables``: tables to truncate (``["langchain_pg_embedding"]``).

    Table names are interpolated directly into SQL text, so they must only
    ever come from trusted configuration.
    """

    # Rows fetched per server round trip when exporting a table to CSV.
    _BACKUP_BATCH_SIZE = 10000
    # Name of the server-side cursor used for the export (one per connection).
    _BACKUP_CURSOR_NAME = "vector_table_backup"

    def __init__(self, task_output_dir: str, task_id: str = None):
        """
        Args:
            task_output_dir: Task output directory; backup files are written to
                a sub-directory of it.
            task_id: Task ID used for logging. When given, the project logger
                from ``data_pipeline.dp_logging.get_logger`` is used; otherwise a
                standard ``logging`` logger named "VectorTableManager".

        Note:
            The database connection settings are resolved lazily in
            ``get_connection`` from ``data_pipeline.config.SCHEMA_TOOLS_CONFIG``.
        """
        self.task_output_dir = task_output_dir
        self.task_id = task_id

        # Project-local import kept inside the method (as before) so importing
        # this module does not require the pipeline configuration.
        from data_pipeline.config import SCHEMA_TOOLS_CONFIG
        self.config = SCHEMA_TOOLS_CONFIG.get("vector_table_management", {})

        if task_id:
            from data_pipeline.dp_logging import get_logger
            self.logger = get_logger("VectorTableManager", task_id)
        else:
            self.logger = logging.getLogger("VectorTableManager")

    def _backup_dir(self) -> Path:
        """Return the backup directory path (it is not created here)."""
        return Path(self.task_output_dir) / self.config.get("backup_directory", "vector_bak")

    def _supported_table_names(self) -> List[str]:
        """Return the tables to back up (config key ``supported_tables``)."""
        return self.config.get("supported_tables", ["langchain_pg_collection", "langchain_pg_embedding"])

    def _truncate_table_names(self) -> List[str]:
        """Return the tables to truncate (config key ``truncate_tables``)."""
        return self.config.get("truncate_tables", ["langchain_pg_embedding"])

    def execute_vector_management(self, backup: bool, truncate: bool) -> Dict[str, Any]:
        """Run the vector-table management workflow (backup, then optional truncate).

        Args:
            backup: Back up every table in ``supported_tables``.
            truncate: Truncate every table in ``truncate_tables``. Implies
                ``backup``. Truncation is skipped if any backup failed or if a
                table to truncate has no successful backup in this run.

        Returns:
            ``{"backup_performed": False, "truncate_performed": False}`` when
            nothing is requested. Otherwise, a dict with ``backup_performed``,
            ``truncate_performed``, ``tables_backed_up``, ``truncate_results``,
            ``errors`` (list of messages), ``backup_directory`` and
            ``duration`` (seconds).

        Raises:
            Any unexpected exception is logged and re-raised. Per-table
            failures are reported in the result instead.
        """
        start_time = time.time()

        # 1. Truncation is destructive, so it always implies a backup.
        if truncate and not backup:
            backup = True
            self.logger.info("🔄 启用truncate时自动启用backup")

        if not backup and not truncate:
            self.logger.info("⏭️ 未启用vector表管理,跳过操作")
            return {"backup_performed": False, "truncate_performed": False}

        # 2. Result skeleton. Past the early return above, ``backup`` is always True.
        result = {
            "backup_performed": backup,
            "truncate_performed": truncate,
            "tables_backed_up": {},
            "truncate_results": {},
            "errors": [],
            "backup_directory": None,
            "duration": 0,
        }

        try:
            # 3. Create the backup directory.
            backup_dir = self._backup_dir()
            backup_dir.mkdir(parents=True, exist_ok=True)
            result["backup_directory"] = str(backup_dir)
            self.logger.info(f"📁 备份目录: {backup_dir}")

            # 4. Back up every configured table.
            self.logger.info("🗂️ 开始备份vector表...")
            backup_results = self.backup_vector_tables()
            result["tables_backed_up"] = backup_results

            if any(not r.get("success", False) for r in backup_results.values()):
                result["errors"].append("部分表备份失败")
                if truncate:
                    self.logger.error("❌ 备份失败,取消清空操作")
                    result["truncate_performed"] = False
                    truncate = False

            # Never truncate a table that was not successfully backed up in this
            # run, e.g. one listed in truncate_tables but not in supported_tables.
            if truncate:
                not_backed_up = [
                    name for name in self._truncate_table_names()
                    if not backup_results.get(name, {}).get("success", False)
                ]
                if not_backed_up:
                    message = f"以下表未备份,取消清空操作: {', '.join(not_backed_up)}"
                    result["errors"].append(message)
                    self.logger.error(f"❌ {message}")
                    result["truncate_performed"] = False
                    truncate = False

            # 5. Truncate, only when every table to truncate has a backup.
            if truncate:
                self.logger.info("🗑️ 开始清空vector表...")
                truncate_results = self.truncate_vector_tables()
                result["truncate_results"] = truncate_results

                if any(not r.get("success", False) for r in truncate_results.values()):
                    result["errors"].append("部分表清空失败")

            # 6. Compute the duration before writing the log, so the log file
            #    reports the real elapsed time instead of the initial 0.
            result["duration"] = time.time() - start_time

            # 7. Write the human-readable backup log next to the CSV files.
            if backup_dir.exists():
                self._write_backup_log(backup_dir, result)

            # 8. Final status.
            if result["errors"]:
                self.logger.warning(f"⚠️ Vector表管理完成,但有错误: {'; '.join(result['errors'])}")
            else:
                self.logger.info(f"✅ Vector表管理完成,耗时: {result['duration']:.2f}秒")

            return result

        except Exception as e:
            result["duration"] = time.time() - start_time
            result["errors"].append(f"执行失败: {str(e)}")
            self.logger.error(f"❌ Vector表管理失败: {e}")
            raise

    def backup_vector_tables(self) -> Dict[str, Any]:
        """Export every table in ``supported_tables`` to a timestamped CSV file.

        Each table is written first to ``<table>_<timestamp>.csv.tmp`` and
        renamed to ``.csv`` only after the export finished. A half-written file
        is therefore never left under the final name.

        Returns:
            Per table, either ``{"success": True, "row_count", "file_size",
            "backup_file", "duration"}`` or ``{"success": False, "error"}``.
            A failure of one table does not stop the others.
        """
        backup_dir = self._backup_dir()
        backup_dir.mkdir(parents=True, exist_ok=True)

        timestamp = datetime.now().strftime(self.config.get("timestamp_format", "%Y%m%d_%H%M%S"))

        results = {}
        for table_name in self._supported_table_names():
            temp_file = backup_dir / f"{table_name}_{timestamp}.csv.tmp"
            final_file = backup_dir / f"{table_name}_{timestamp}.csv"
            try:
                start_time = time.time()
                row_count = self._export_table_to_csv(table_name, temp_file.resolve())
                self.logger.info(f"📊 {table_name} 流式导出完成,总计 {row_count} 行")

                if not temp_file.exists():
                    raise FileNotFoundError(f"临时文件 {temp_file} 未生成")
                # Path.replace overwrites an existing target on every platform.
                temp_file.replace(final_file)

                file_stat = final_file.stat()
                results[table_name] = {
                    "success": True,
                    "row_count": row_count,
                    "file_size": self._format_file_size(file_stat.st_size),
                    "backup_file": final_file.name,
                    "duration": time.time() - start_time,
                }
                self.logger.info(f"✅ {table_name} 备份成功: {row_count}行 -> {final_file.name}")

            except Exception as e:
                results[table_name] = {
                    "success": False,
                    "error": str(e),
                }
                self.logger.error(f"❌ {table_name} 备份失败: {e}")

                # Remove any partial output. A cleanup failure must not abort
                # the backup of the remaining tables.
                try:
                    if temp_file.exists():
                        temp_file.unlink()
                except OSError as cleanup_error:
                    self.logger.warning(f"清理临时文件失败 {temp_file}: {cleanup_error}")

        return results

    def _export_table_to_csv(self, table_name: str, csv_path: Path) -> int:
        """Stream ``SELECT * FROM table_name`` into ``csv_path`` (header row first).

        A named (server-side) cursor is required for real streaming. With a
        plain psycopg2 cursor, ``execute()`` transfers the whole result set to
        the client, and ``itersize`` is ignored. Named cursors only work inside
        a transaction, so autocommit is turned off on this dedicated connection.
        The connection is always closed on exit.

        Returns:
            The number of data rows written.
        """
        import csv

        batch_size = self._BACKUP_BATCH_SIZE
        row_count = 0

        conn = self.get_connection()
        try:
            conn.set_client_encoding("UTF8")
            conn.autocommit = False
            # ``with conn`` commits on success and rolls back on error. It does
            # not close the connection, which the ``finally`` below does.
            with conn, conn.cursor(name=self._BACKUP_CURSOR_NAME) as cursor:
                cursor.itersize = batch_size
                cursor.execute(f"SELECT * FROM {table_name}")

                with open(csv_path, "w", newline="", encoding="utf-8") as csvfile:
                    writer = csv.writer(csvfile)
                    header_written = False
                    while True:
                        rows = cursor.fetchmany(batch_size)
                        if not header_written:
                            # For a named cursor, column metadata is only
                            # guaranteed after the first fetch.
                            writer.writerow([desc[0] for desc in cursor.description])
                            header_written = True
                        if not rows:
                            break

                        writer.writerows(rows)
                        row_count += len(rows)

                        # Full batches keep row_count a multiple of batch_size,
                        # so this logs every 5 batches (50k rows by default).
                        if row_count % (batch_size * 5) == 0:
                            self.logger.info(f"📊 {table_name} 已导出 {row_count} 行数据...")
        finally:
            conn.close()

        return row_count

    def truncate_vector_tables(self) -> Dict[str, Any]:
        """Truncate every table in ``truncate_tables`` (default: langchain_pg_embedding).

        The row count is taken before and after ``TRUNCATE`` to report how many
        rows were removed and to confirm the table is empty. Each table uses
        its own connection, which is always closed.

        Returns:
            Per table, either ``{"success": True, "rows_before", "rows_after",
            "duration"}`` or ``{"success": False, "error"}``.
        """
        results = {}

        for table_name in self._truncate_table_names():
            try:
                count_sql = f"SELECT COUNT(*) FROM {table_name}"
                start_time = time.time()

                conn = self.get_connection()
                try:
                    # ``with conn`` handles commit/rollback only; it does not
                    # close the connection, hence the explicit close below.
                    with conn, conn.cursor() as cursor:
                        cursor.execute(count_sql)
                        rows_before = cursor.fetchone()[0]

                        cursor.execute(f"TRUNCATE TABLE {table_name}")

                        cursor.execute(count_sql)
                        rows_after = cursor.fetchone()[0]
                finally:
                    conn.close()

                duration = time.time() - start_time

                if rows_after != 0:
                    raise RuntimeError(f"清空失败,表中仍有 {rows_after} 行数据")

                results[table_name] = {
                    "success": True,
                    "rows_before": rows_before,
                    "rows_after": rows_after,
                    "duration": duration,
                }
                self.logger.info(f"✅ {table_name} 清空成功: {rows_before}行 -> 0行")

            except Exception as e:
                results[table_name] = {
                    "success": False,
                    "error": str(e),
                }
                self.logger.error(f"❌ {table_name} 清空失败: {e}")

        return results

    def get_connection(self):
        """Open a new psycopg2 connection to the pgvector database.

        ``SCHEMA_TOOLS_CONFIG["default_db_connection"]`` is used as the DSN
        when set. Otherwise host/port/dbname/user/password are read from
        ``app_config.PGVECTOR_CONFIG``. The connection is returned with
        ``autocommit = True``. The caller owns it and must close it.

        Raises:
            Any error from the config imports or ``psycopg2.connect`` is
            logged and re-raised.
        """
        import psycopg2

        try:
            # Option 1: an explicit connection string in SCHEMA_TOOLS_CONFIG.
            from data_pipeline.config import SCHEMA_TOOLS_CONFIG
            connection_string = SCHEMA_TOOLS_CONFIG.get("default_db_connection")
            if connection_string:
                conn = psycopg2.connect(connection_string)
            else:
                # Option 2: discrete settings from app_config.PGVECTOR_CONFIG.
                import app_config
                pgvector_config = app_config.PGVECTOR_CONFIG
                conn = psycopg2.connect(
                    host=pgvector_config.get('host'),
                    port=pgvector_config.get('port'),
                    database=pgvector_config.get('dbname'),
                    user=pgvector_config.get('user'),
                    password=pgvector_config.get('password')
                )

            # Autocommit by default; callers that need a transaction switch it off.
            conn.autocommit = True
            return conn

        except Exception as e:
            self.logger.error(f"pgvector数据库连接失败: {e}")
            raise

    def _write_backup_log(self, backup_dir: Path, result: Dict[str, Any]):
        """Write ``vector_backup_log.txt`` summarising backup/truncate results.

        The file is overwritten on each call. Write errors are logged as a
        warning and otherwise ignored, so the log is best-effort only.
        """
        log_file = backup_dir / "vector_backup_log.txt"

        try:
            with open(log_file, 'w', encoding='utf-8') as f:
                f.write("=== Vector Table Backup Log ===\n")
                f.write(f"Backup Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
                f.write(f"Task ID: {self.task_id or 'Unknown'}\n")
                f.write(f"Duration: {result.get('duration', 0):.2f}s\n\n")

                # Per-table backup status.
                f.write("Tables Backup Status:\n")
                for table_name, info in result.get("tables_backed_up", {}).items():
                    if info.get("success", False):
                        f.write(f"✓ {table_name}: {info['row_count']} rows -> {info['backup_file']} ({info['file_size']})\n")
                    else:
                        f.write(f"✗ {table_name}: FAILED - {info.get('error', 'Unknown error')}\n")

                # Per-table truncate status.
                if result.get("truncate_performed", False):
                    f.write("\nTruncate Status:\n")
                    for table_name, info in result.get("truncate_results", {}).items():
                        if info.get("success", False):
                            f.write(f"✓ {table_name}: TRUNCATED ({info['rows_before']} rows removed)\n")
                        else:
                            f.write(f"✗ {table_name}: FAILED - {info.get('error', 'Unknown error')}\n")
                else:
                    f.write("\nTruncate Status:\n- Not performed\n")

                # Error summary.
                if result.get("errors"):
                    f.write(f"\nErrors: {'; '.join(result['errors'])}\n")

        except Exception as e:
            self.logger.warning(f"写入备份日志失败: {e}")

    def _format_file_size(self, size_bytes: int) -> str:
        """Format a byte count for display, e.g. ``1536`` -> ``"1.5 KB"``.

        ``0`` is rendered as ``"0 B"``. Larger values use one decimal place,
        with units up to TB.
        """
        if size_bytes == 0:
            return "0 B"

        size_names = ["B", "KB", "MB", "GB", "TB"]
        i = 0
        size = float(size_bytes)

        while size >= 1024.0 and i < len(size_names) - 1:
            size /= 1024.0
            i += 1

        return f"{size:.1f} {size_names[i]}"
|