# vector_table_manager.py
import asyncio
import contextlib
import csv
import logging
import os
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

import psycopg2
  9. class VectorTableManager:
  10. """Vector表管理器,负责备份和清空操作"""
  11. def __init__(self, task_output_dir: str, task_id: str = None):
  12. """
  13. Args:
  14. task_output_dir: 任务输出目录(用于存放备份文件)
  15. task_id: 任务ID(用于日志记录)
  16. Note:
  17. 数据库连接将从data_pipeline.config.SCHEMA_TOOLS_CONFIG自动获取
  18. """
  19. self.task_output_dir = task_output_dir
  20. self.task_id = task_id
  21. # 从data_pipeline.config获取配置
  22. from data_pipeline.config import SCHEMA_TOOLS_CONFIG
  23. self.config = SCHEMA_TOOLS_CONFIG.get("vector_table_management", {})
  24. # 初始化日志
  25. if task_id:
  26. from data_pipeline.dp_logging import get_logger
  27. self.logger = get_logger("VectorTableManager", task_id)
  28. else:
  29. import logging
  30. self.logger = logging.getLogger("VectorTableManager")
  31. def execute_vector_management(self, backup: bool, truncate: bool) -> Dict[str, Any]:
  32. """执行vector表管理操作的主流程"""
  33. start_time = time.time()
  34. # 1. 参数验证和自动启用逻辑
  35. if truncate and not backup:
  36. backup = True
  37. self.logger.info("🔄 启用truncate时自动启用backup")
  38. if not backup and not truncate:
  39. self.logger.info("⏭️ 未启用vector表管理,跳过操作")
  40. return {"backup_performed": False, "truncate_performed": False}
  41. # 2. 初始化结果统计
  42. result = {
  43. "backup_performed": backup,
  44. "truncate_performed": truncate,
  45. "tables_backed_up": {},
  46. "truncate_results": {},
  47. "errors": [],
  48. "backup_directory": None,
  49. "duration": 0
  50. }
  51. try:
  52. # 3. 创建备份目录
  53. backup_dir = Path(self.task_output_dir) / self.config.get("backup_directory", "vector_bak")
  54. if backup:
  55. backup_dir.mkdir(parents=True, exist_ok=True)
  56. result["backup_directory"] = str(backup_dir)
  57. self.logger.info(f"📁 备份目录: {backup_dir}")
  58. # 4. 执行备份操作
  59. if backup:
  60. self.logger.info("🗂️ 开始备份vector表...")
  61. backup_results = self.backup_vector_tables()
  62. result["tables_backed_up"] = backup_results
  63. # 检查备份是否全部成功
  64. backup_failed = any(not r.get("success", False) for r in backup_results.values())
  65. if backup_failed:
  66. result["errors"].append("部分表备份失败")
  67. if truncate:
  68. self.logger.error("❌ 备份失败,取消清空操作")
  69. result["truncate_performed"] = False
  70. truncate = False
  71. # 5. 执行清空操作(仅在备份成功时)
  72. if truncate:
  73. self.logger.info("🗑️ 开始清空vector表...")
  74. truncate_results = self.truncate_vector_tables()
  75. result["truncate_results"] = truncate_results
  76. # 检查清空是否成功
  77. truncate_failed = any(not r.get("success", False) for r in truncate_results.values())
  78. if truncate_failed:
  79. result["errors"].append("部分表清空失败")
  80. # 6. 生成备份日志文件
  81. if backup and backup_dir.exists():
  82. self._write_backup_log(backup_dir, result)
  83. # 7. 计算总耗时
  84. result["duration"] = time.time() - start_time
  85. # 8. 记录最终状态
  86. if result["errors"]:
  87. self.logger.warning(f"⚠️ Vector表管理完成,但有错误: {'; '.join(result['errors'])}")
  88. else:
  89. self.logger.info(f"✅ Vector表管理完成,耗时: {result['duration']:.2f}秒")
  90. return result
  91. except Exception as e:
  92. result["duration"] = time.time() - start_time
  93. result["errors"].append(f"执行失败: {str(e)}")
  94. self.logger.error(f"❌ Vector表管理失败: {e}")
  95. raise
  96. def backup_vector_tables(self) -> Dict[str, Any]:
  97. """备份vector表数据"""
  98. # 1. 创建备份目录
  99. backup_dir = Path(self.task_output_dir) / self.config.get("backup_directory", "vector_bak")
  100. backup_dir.mkdir(parents=True, exist_ok=True)
  101. # 2. 生成时间戳
  102. timestamp = datetime.now().strftime(self.config.get("timestamp_format", "%Y%m%d_%H%M%S"))
  103. # 3. 执行备份(每个表分别处理)
  104. results = {}
  105. supported_tables = self.config.get("supported_tables", ["langchain_pg_collection", "langchain_pg_embedding"])
  106. for table_name in supported_tables:
  107. try:
  108. # 3.1 定义文件路径(.tmp临时文件)
  109. temp_file = backup_dir / f"{table_name}_{timestamp}.csv.tmp"
  110. final_file = backup_dir / f"{table_name}_{timestamp}.csv"
  111. # 确保使用绝对路径(PostgreSQL COPY命令要求)
  112. temp_file_abs = temp_file.resolve()
  113. # 3.2 通过psycopg2使用流式客户端导出(支持大数据量)
  114. start_time = time.time()
  115. row_count = 0
  116. batch_size = 10000 # 每批处理1万条记录
  117. with self.get_connection() as conn:
  118. # 临时关闭autocommit以支持流式处理
  119. old_autocommit = conn.autocommit
  120. conn.autocommit = False
  121. try:
  122. with conn.cursor() as cursor:
  123. # 设置游标为流式模式
  124. cursor.itersize = batch_size
  125. # 执行编码设置
  126. cursor.execute("SET client_encoding TO 'UTF8'")
  127. # 执行查询
  128. cursor.execute(f"SELECT * FROM {table_name}")
  129. # 获取列名
  130. colnames = [desc[0] for desc in cursor.description]
  131. # 使用流式方式写入CSV文件
  132. import csv
  133. with open(temp_file_abs, 'w', newline='', encoding='utf-8') as csvfile:
  134. writer = csv.writer(csvfile)
  135. # 写入表头
  136. writer.writerow(colnames)
  137. # 流式读取和写入数据
  138. while True:
  139. rows = cursor.fetchmany(batch_size)
  140. if not rows:
  141. break
  142. # 批量写入当前批次的数据
  143. for row in rows:
  144. writer.writerow(row)
  145. row_count += 1
  146. # 记录进度(大数据量时有用)
  147. if row_count % (batch_size * 5) == 0: # 每5万条记录记录一次
  148. self.logger.info(f"📊 {table_name} 已导出 {row_count} 行数据...")
  149. # 提交事务
  150. conn.commit()
  151. finally:
  152. # 恢复原来的autocommit设置
  153. conn.autocommit = old_autocommit
  154. self.logger.info(f"📊 {table_name} 流式导出完成,总计 {row_count} 行")
  155. # 3.3 导出完成后,重命名文件 (.tmp -> .csv)
  156. if temp_file.exists():
  157. temp_file.rename(final_file)
  158. # 3.4 获取文件信息
  159. file_stat = final_file.stat()
  160. duration = time.time() - start_time
  161. results[table_name] = {
  162. "success": True,
  163. "row_count": row_count,
  164. "file_size": self._format_file_size(file_stat.st_size),
  165. "backup_file": final_file.name,
  166. "duration": duration
  167. }
  168. self.logger.info(f"✅ {table_name} 备份成功: {row_count}行 -> {final_file.name}")
  169. else:
  170. raise Exception(f"临时文件 {temp_file} 未生成")
  171. except Exception as e:
  172. results[table_name] = {
  173. "success": False,
  174. "error": str(e)
  175. }
  176. self.logger.error(f"❌ {table_name} 备份失败: {e}")
  177. # 清理可能的临时文件
  178. if temp_file.exists():
  179. temp_file.unlink()
  180. return results
  181. def truncate_vector_tables(self) -> Dict[str, Any]:
  182. """清空vector表数据(只清空langchain_pg_embedding)"""
  183. results = {}
  184. # 只清空配置中指定的表(通常只有langchain_pg_embedding)
  185. truncate_tables = self.config.get("truncate_tables", ["langchain_pg_embedding"])
  186. for table_name in truncate_tables:
  187. try:
  188. # 记录清空前的行数(用于统计)
  189. count_sql = f"SELECT COUNT(*) FROM {table_name}"
  190. start_time = time.time()
  191. with self.get_connection() as conn:
  192. with conn.cursor() as cursor:
  193. # 1. 获取清空前的行数
  194. cursor.execute(count_sql)
  195. rows_before = cursor.fetchone()[0]
  196. # 2. 执行TRUNCATE
  197. cursor.execute(f"TRUNCATE TABLE {table_name}")
  198. # 3. 验证清空结果
  199. cursor.execute(count_sql)
  200. rows_after = cursor.fetchone()[0]
  201. duration = time.time() - start_time
  202. if rows_after == 0:
  203. results[table_name] = {
  204. "success": True,
  205. "rows_before": rows_before,
  206. "rows_after": rows_after,
  207. "duration": duration
  208. }
  209. self.logger.info(f"✅ {table_name} 清空成功: {rows_before}行 -> 0行")
  210. else:
  211. raise Exception(f"清空失败,表中仍有 {rows_after} 行数据")
  212. except Exception as e:
  213. results[table_name] = {
  214. "success": False,
  215. "error": str(e)
  216. }
  217. self.logger.error(f"❌ {table_name} 清空失败: {e}")
  218. return results
  219. def get_connection(self):
  220. """获取pgvector数据库连接(从data_pipeline.config获取配置)"""
  221. import psycopg2
  222. try:
  223. # 方法1:如果SCHEMA_TOOLS_CONFIG中有连接字符串,直接使用
  224. from data_pipeline.config import SCHEMA_TOOLS_CONFIG
  225. connection_string = SCHEMA_TOOLS_CONFIG.get("default_db_connection")
  226. if connection_string:
  227. conn = psycopg2.connect(connection_string)
  228. else:
  229. # 方法2:从app_config获取pgvector数据库配置
  230. import app_config
  231. pgvector_config = app_config.PGVECTOR_CONFIG
  232. conn = psycopg2.connect(
  233. host=pgvector_config.get('host'),
  234. port=pgvector_config.get('port'),
  235. database=pgvector_config.get('dbname'),
  236. user=pgvector_config.get('user'),
  237. password=pgvector_config.get('password')
  238. )
  239. # 设置自动提交,避免事务问题
  240. conn.autocommit = True
  241. return conn
  242. except Exception as e:
  243. self.logger.error(f"pgvector数据库连接失败: {e}")
  244. raise
  245. def _write_backup_log(self, backup_dir: Path, result: Dict[str, Any]):
  246. """写入详细的备份日志"""
  247. log_file = backup_dir / "vector_backup_log.txt"
  248. try:
  249. with open(log_file, 'w', encoding='utf-8') as f:
  250. f.write("=== Vector Table Backup Log ===\n")
  251. f.write(f"Backup Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
  252. f.write(f"Task ID: {self.task_id or 'Unknown'}\n")
  253. f.write(f"Duration: {result.get('duration', 0):.2f}s\n\n")
  254. # 备份状态
  255. f.write("Tables Backup Status:\n")
  256. for table_name, info in result.get("tables_backed_up", {}).items():
  257. if info.get("success", False):
  258. f.write(f"✓ {table_name}: {info['row_count']} rows -> {info['backup_file']} ({info['file_size']})\n")
  259. else:
  260. f.write(f"✗ {table_name}: FAILED - {info.get('error', 'Unknown error')}\n")
  261. # 清空状态
  262. if result.get("truncate_performed", False):
  263. f.write("\nTruncate Status:\n")
  264. for table_name, info in result.get("truncate_results", {}).items():
  265. if info.get("success", False):
  266. f.write(f"✓ {table_name}: TRUNCATED ({info['rows_before']} rows removed)\n")
  267. else:
  268. f.write(f"✗ {table_name}: FAILED - {info.get('error', 'Unknown error')}\n")
  269. else:
  270. f.write("\nTruncate Status:\n- Not performed\n")
  271. # 错误汇总
  272. if result.get("errors"):
  273. f.write(f"\nErrors: {'; '.join(result['errors'])}\n")
  274. except Exception as e:
  275. self.logger.warning(f"写入备份日志失败: {e}")
  276. def _format_file_size(self, size_bytes: int) -> str:
  277. """格式化文件大小显示"""
  278. if size_bytes == 0:
  279. return "0 B"
  280. size_names = ["B", "KB", "MB", "GB"]
  281. i = 0
  282. size = float(size_bytes)
  283. while size >= 1024.0 and i < len(size_names) - 1:
  284. size /= 1024.0
  285. i += 1
  286. return f"{size:.1f} {size_names[i]}"