vector_restore_manager.py

  1. """
  2. Vector表恢复管理器
  3. 提供pgvector表备份文件扫描和数据恢复功能,与VectorTableManager形成完整的备份恢复解决方案
  4. """
  5. import os
  6. import re
  7. import time
  8. import glob
  9. from datetime import datetime
  10. from pathlib import Path
  11. from typing import Dict, Any, List, Optional
  12. import psycopg2
  13. import logging
class VectorRestoreManager:
    """Vector table restore manager, modeled after VectorTableManager."""

    def __init__(self, base_output_dir: str = None):
        """
        Initialize the restore manager, reusing the existing configuration mechanism.

        Args:
            base_output_dir: Base output directory; defaults to the value from data_pipeline.config.
        """
        from data_pipeline.config import SCHEMA_TOOLS_CONFIG

        if base_output_dir is None:
            # Fall back to the default directory from the config file
            base_output_dir = SCHEMA_TOOLS_CONFIG.get("output_directory", "./data_pipeline/training_data/")
        self.base_output_dir = Path(base_output_dir)

        # Vector-table-management settings from data_pipeline.config
        self.config = SCHEMA_TOOLS_CONFIG.get("vector_table_management", {})

        # Set up logging
        self.logger = logging.getLogger("VectorRestoreManager")

        # Tables supported for backup/restore
        self.supported_tables = self.config.get("supported_tables", [
            "langchain_pg_collection",
            "langchain_pg_embedding"
        ])

    def scan_backup_files(self, global_only: bool = False, task_id: str = None) -> Dict[str, Any]:
        """
        Scan for available backup files.

        Args:
            global_only: Only scan the global backup directory (training_data/vector_bak/)
            task_id: If given, only scan backup files under that task

        Returns:
            Dict describing the discovered backup files
        """
        scan_start_time = datetime.now()
        backup_locations = []

        try:
            # Determine the scan scope
            if task_id:
                # Scan only the specified task
                directories_to_scan = [self.base_output_dir / task_id / "vector_bak"]
            elif global_only:
                # Scan only the global directory
                directories_to_scan = [self.base_output_dir / "vector_bak"]
            else:
                # Scan all directories
                directories_to_scan = self._get_all_vector_bak_directories()

            # Scan each directory
            for backup_dir in directories_to_scan:
                if not backup_dir.exists():
                    continue

                # Look for valid backup sets
                backup_sets = self._find_backup_sets(backup_dir)
                if not backup_sets:
                    continue

                # Build the location info for this directory
                location_info = self._build_location_info(backup_dir, backup_sets)
                if location_info:
                    backup_locations.append(location_info)

            # Build the summary
            summary = self._build_summary(backup_locations, scan_start_time)

            return {
                "backup_locations": backup_locations,
                "summary": summary
            }
        except Exception as e:
            self.logger.error(f"Backup file scan failed: {e}")
            raise
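
    # Shape of the dict returned by scan_backup_files, as assembled by
    # _build_location_info and _build_summary below (values illustrative):
    #
    # {
    #     "backup_locations": [
    #         {
    #             "type": "task",            # or "global"
    #             "relative_path": "./data_pipeline/training_data/task_xxx/vector_bak",
    #             "task_id": "task_xxx",     # present only for task-level backups
    #             "backups": [
    #                 {
    #                     "timestamp": "20250722_010318",
    #                     "collection_file": "langchain_pg_collection_20250722_010318.csv",
    #                     "embedding_file": "langchain_pg_embedding_20250722_010318.csv",
    #                     "collection_size": "1.2 KB",
    #                     "embedding_size": "12.3 MB",
    #                     "backup_date": "2025-07-22 01:03:18",
    #                     "has_log": True,
    #                     "log_file": "vector_backup_log.txt"
    #                 }
    #             ]
    #         }
    #     ],
    #     "summary": {
    #         "total_locations": 1,
    #         "total_backup_sets": 1,
    #         "global_backups": 0,
    #         "task_backups": 1,
    #         "scan_time": "2025-07-22T01:05:00"
    #     }
    # }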

    def restore_from_backup(self, backup_path: str, timestamp: str,
                            tables: List[str] = None, db_connection: str = None,
                            truncate_before_restore: bool = False) -> Dict[str, Any]:
        """
        Restore data from backup files.

        Args:
            backup_path: Directory containing the backup files (relative path)
            timestamp: Timestamp of the backup set
            tables: Tables to restore; None restores all supported tables
            db_connection: PostgreSQL connection string; None falls back to config
            truncate_before_restore: Whether to truncate the target tables first

        Returns:
            Detailed result of the restore operation
        """
        start_time = time.time()

        # Default to all supported tables
        if tables is None:
            tables = self.supported_tables.copy()

        # Validate table names
        invalid_tables = [t for t in tables if t not in self.supported_tables]
        if invalid_tables:
            raise ValueError(f"Unsupported tables: {invalid_tables}")

        # Resolve the backup path
        backup_dir = Path(backup_path)
        if not backup_dir.is_absolute():
            # Relative paths are resolved against the project root
            project_root = Path(__file__).parent.parent.parent
            backup_dir = project_root / backup_path

        if not backup_dir.exists():
            raise FileNotFoundError(f"Backup directory does not exist: {backup_path}")

        # Verify that all backup files exist
        missing_files = []
        backup_files = {}
        for table_name in tables:
            csv_file = backup_dir / f"{table_name}_{timestamp}.csv"
            if not csv_file.exists():
                missing_files.append(csv_file.name)
            else:
                backup_files[table_name] = csv_file

        if missing_files:
            raise FileNotFoundError(f"Backup files not found: {', '.join(missing_files)}")

        # Initialize the result payload
        result = {
            "restore_performed": True,
            "truncate_performed": truncate_before_restore,
            "backup_info": {
                "backup_path": backup_path,
                "timestamp": timestamp,
                "backup_date": self._parse_timestamp_to_date(timestamp)
            },
            "truncate_results": {},
            "restore_results": {},
            "errors": [],
            "duration": 0
        }

        # Temporarily override the database connection config. Track the
        # override explicitly so the original value is restored even when
        # the key was previously unset.
        config_overridden = False
        original_config = None
        if db_connection:
            from data_pipeline.config import SCHEMA_TOOLS_CONFIG
            original_config = SCHEMA_TOOLS_CONFIG.get("default_db_connection")
            SCHEMA_TOOLS_CONFIG["default_db_connection"] = db_connection
            config_overridden = True

        try:
            # Truncate the target tables first, if requested
            if truncate_before_restore:
                self.logger.info("🗑️ Truncating target tables...")
                for table_name in tables:
                    truncate_result = self._truncate_table(table_name)
                    result["truncate_results"][table_name] = truncate_result
                    if not truncate_result.get("success", False):
                        result["errors"].append(f"Failed to truncate table {table_name}")

            # Restore table data
            self.logger.info("📥 Restoring table data...")
            for table_name in tables:
                csv_file = backup_files[table_name]
                restore_result = self._restore_table_from_csv(table_name, csv_file)
                result["restore_results"][table_name] = restore_result
                if not restore_result.get("success", False):
                    result["errors"].append(f"Failed to restore table {table_name}")

            # Total duration
            result["duration"] = time.time() - start_time

            # Log the final status
            if result["errors"]:
                self.logger.warning(f"⚠️ Vector table restore finished with errors: {'; '.join(result['errors'])}")
            else:
                self.logger.info(f"✅ Vector table restore finished in {result['duration']:.2f}s")

            return result
        finally:
            # Restore the original config
            if config_overridden:
                SCHEMA_TOOLS_CONFIG["default_db_connection"] = original_config

    def get_connection(self):
        """Get a database connection - fully reuses VectorTableManager's connection logic."""
        try:
            # Option 1: use the connection string from SCHEMA_TOOLS_CONFIG if present
            from data_pipeline.config import SCHEMA_TOOLS_CONFIG
            connection_string = SCHEMA_TOOLS_CONFIG.get("default_db_connection")
            if connection_string:
                conn = psycopg2.connect(connection_string)
            else:
                # Option 2: fall back to the pgvector database config in app_config
                import app_config
                pgvector_config = app_config.PGVECTOR_CONFIG
                conn = psycopg2.connect(
                    host=pgvector_config.get('host'),
                    port=pgvector_config.get('port'),
                    database=pgvector_config.get('dbname'),
                    user=pgvector_config.get('user'),
                    password=pgvector_config.get('password')
                )

            # Enable autocommit
            conn.autocommit = True
            return conn
        except Exception as e:
            self.logger.error(f"pgvector database connection failed: {e}")
            raise
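
    # The connection string is passed straight to psycopg2.connect(), which
    # accepts either a libpq key/value DSN or a URI, so a "default_db_connection"
    # value like the following (illustrative) works:
    #   postgresql://user:password@localhost:5432/vector_db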

    def _get_all_vector_bak_directories(self) -> List[Path]:
        """Collect all vector_bak directories."""
        directories = []

        # Global backup directory
        global_backup_dir = self.base_output_dir / "vector_bak"
        if global_backup_dir.exists():
            directories.append(global_backup_dir)

        # Task backup directories (task_* and manual_*)
        for pattern in ["task_*", "manual_*"]:
            for task_dir in self.base_output_dir.glob(pattern):
                if task_dir.is_dir():
                    vector_bak_dir = task_dir / "vector_bak"
                    if vector_bak_dir.exists():
                        directories.append(vector_bak_dir)

        return directories

    def _find_backup_sets(self, backup_dir: Path) -> List[str]:
        """Find valid backup sets in a backup directory."""
        # Collect all CSV files
        collection_files = list(backup_dir.glob("langchain_pg_collection_*.csv"))
        embedding_files = list(backup_dir.glob("langchain_pg_embedding_*.csv"))

        # Extract timestamps
        collection_timestamps = set()
        embedding_timestamps = set()

        for file in collection_files:
            timestamp = self._extract_timestamp_from_filename(file.name)
            if timestamp:
                collection_timestamps.add(timestamp)

        for file in embedding_files:
            timestamp = self._extract_timestamp_from_filename(file.name)
            if timestamp:
                embedding_timestamps.add(timestamp)

        # A valid backup set has both files for the same timestamp
        valid_timestamps = collection_timestamps & embedding_timestamps

        # Sort timestamps in descending order (newest first)
        return sorted(valid_timestamps, reverse=True)

    def _extract_timestamp_from_filename(self, filename: str) -> Optional[str]:
        """Extract the timestamp from a backup file name."""
        # Matches names like: langchain_pg_collection_20250722_010318.csv
        pattern = r'langchain_pg_(?:collection|embedding)_(\d{8}_\d{6})\.csv'
        match = re.search(pattern, filename)
        return match.group(1) if match else None
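
    # _extract_timestamp_from_filename examples (per the pattern above):
    #   "langchain_pg_embedding_20250722_010318.csv" -> "20250722_010318"
    #   "langchain_pg_embedding.csv"                 -> None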

    def _build_location_info(self, backup_dir: Path, backup_sets: List[str]) -> Optional[Dict[str, Any]]:
        """Build the location info for a backup directory."""
        if not backup_sets:
            return None

        # Determine the location type and related info
        relative_path = self._get_relative_path(backup_dir)
        location_type, task_id = self._determine_location_type(backup_dir)

        # Build the list of backup entries
        backups = []
        for timestamp in backup_sets:
            backup_info = self._build_backup_info(backup_dir, timestamp)
            if backup_info:
                backups.append(backup_info)

        location_info = {
            "type": location_type,
            "relative_path": relative_path,
            "backups": backups
        }

        if task_id:
            location_info["task_id"] = task_id

        return location_info

    def _get_relative_path(self, backup_dir: Path) -> str:
        """Get the relative path (Unix style)."""
        try:
            # Compute the path relative to the project root
            project_root = Path(__file__).parent.parent.parent
            relative_path = backup_dir.relative_to(project_root)
            # Convert to a Unix-style path
            return "./" + str(relative_path).replace("\\", "/")
        except ValueError:
            # Fall back to the converted absolute path if it is not under the project root
            return str(backup_dir).replace("\\", "/")

    def _determine_location_type(self, backup_dir: Path) -> tuple:
        """Determine the location type and task_id."""
        backup_dir_str = str(backup_dir)

        if "/vector_bak" in backup_dir_str.replace("\\", "/"):
            parent = backup_dir.parent.name
            if parent.startswith(("task_", "manual_")):
                return "task", parent
            else:
                return "global", None

        return "unknown", None

    def _build_backup_info(self, backup_dir: Path, timestamp: str) -> Optional[Dict[str, Any]]:
        """Build the info dict for a single backup set."""
        try:
            collection_file = backup_dir / f"langchain_pg_collection_{timestamp}.csv"
            embedding_file = backup_dir / f"langchain_pg_embedding_{timestamp}.csv"
            log_file = backup_dir / "vector_backup_log.txt"

            # Both files must exist
            if not (collection_file.exists() and embedding_file.exists()):
                return None

            # File sizes
            collection_size = self._format_file_size(collection_file.stat().st_size)
            embedding_size = self._format_file_size(embedding_file.stat().st_size)

            # Human-readable backup date
            backup_date = self._parse_timestamp_to_date(timestamp)

            return {
                "timestamp": timestamp,
                "collection_file": collection_file.name,
                "embedding_file": embedding_file.name,
                "collection_size": collection_size,
                "embedding_size": embedding_size,
                "backup_date": backup_date,
                "has_log": log_file.exists(),
                "log_file": log_file.name if log_file.exists() else None
            }
        except Exception as e:
            self.logger.warning(f"Failed to build backup info: {e}")
            return None

    def _parse_timestamp_to_date(self, timestamp: str) -> str:
        """Convert a timestamp into a human-readable date string."""
        try:
            # Parses the format: 20250722_010318
            dt = datetime.strptime(timestamp, "%Y%m%d_%H%M%S")
            return dt.strftime("%Y-%m-%d %H:%M:%S")
        except Exception:
            return timestamp
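
    # _parse_timestamp_to_date example: "20250722_010318" -> "2025-07-22 01:03:18";
    # unparseable inputs are returned unchanged.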

    def _build_summary(self, backup_locations: List[Dict], scan_start_time: datetime) -> Dict[str, Any]:
        """Build the scan summary."""
        total_backup_sets = sum(len(loc["backups"]) for loc in backup_locations)
        global_backups = sum(len(loc["backups"]) for loc in backup_locations if loc["type"] == "global")
        task_backups = total_backup_sets - global_backups

        return {
            "total_locations": len(backup_locations),
            "total_backup_sets": total_backup_sets,
            "global_backups": global_backups,
            "task_backups": task_backups,
            "scan_time": scan_start_time.isoformat()
        }

    def _restore_table_from_csv(self, table_name: str, csv_file: Path) -> Dict[str, Any]:
        """Restore a single table from a CSV file using COPY FROM STDIN."""
        try:
            start_time = time.time()

            with self.get_connection() as conn:
                with conn.cursor() as cursor:
                    # The embedding table needs special handling for its JSON column
                    if table_name == "langchain_pg_embedding":
                        self._restore_embedding_table_with_json_fix(cursor, csv_file)
                    else:
                        # Other tables can be copied directly
                        with open(csv_file, 'r', encoding='utf-8') as f:
                            # The CSV HEADER option skips the header row
                            # automatically; no manual next(f) is needed
                            cursor.copy_expert(
                                f"COPY {table_name} FROM STDIN WITH (FORMAT CSV, HEADER)",
                                f
                            )

                    # Verify the import
                    cursor.execute(f"SELECT COUNT(*) FROM {table_name}")
                    rows_restored = cursor.fetchone()[0]

            duration = time.time() - start_time
            file_size = csv_file.stat().st_size

            return {
                "success": True,
                "source_file": csv_file.name,
                "rows_restored": rows_restored,
                "file_size": self._format_file_size(file_size),
                "duration": duration
            }
        except Exception as e:
            return {
                "success": False,
                "source_file": csv_file.name,
                "error": str(e)
            }

    def _truncate_table(self, table_name: str) -> Dict[str, Any]:
        """Truncate the given table."""
        try:
            start_time = time.time()

            with self.get_connection() as conn:
                with conn.cursor() as cursor:
                    # Row count before truncation
                    cursor.execute(f"SELECT COUNT(*) FROM {table_name}")
                    rows_before = cursor.fetchone()[0]

                    # Run TRUNCATE
                    cursor.execute(f"TRUNCATE TABLE {table_name}")

                    # Verify the table is empty
                    cursor.execute(f"SELECT COUNT(*) FROM {table_name}")
                    rows_after = cursor.fetchone()[0]

            duration = time.time() - start_time

            if rows_after == 0:
                return {
                    "success": True,
                    "rows_before": rows_before,
                    "rows_after": rows_after,
                    "duration": duration
                }
            else:
                raise Exception(f"Truncate failed; table still contains {rows_after} rows")
        except Exception as e:
            return {
                "success": False,
                "error": str(e)
            }

    def _format_file_size(self, size_bytes: int) -> str:
        """Format a file size for display."""
        if size_bytes == 0:
            return "0 B"

        size_names = ["B", "KB", "MB", "GB"]
        i = 0
        size = float(size_bytes)

        while size >= 1024.0 and i < len(size_names) - 1:
            size /= 1024.0
            i += 1

        return f"{size:.1f} {size_names[i]}"
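
    # _format_file_size examples: 0 -> "0 B", 1536 -> "1.5 KB",
    # 3221225472 -> "3.0 GB".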

    def _restore_embedding_table_with_json_fix(self, cursor, csv_file: Path):
        """Restore the embedding table, fixing the JSON format of the cmetadata column."""
        import csv
        import json
        import ast
        import io

        # Read the CSV and fix the JSON formatting
        corrected_data = io.StringIO()

        with open(csv_file, 'r', encoding='utf-8') as f:
            reader = csv.reader(f)
            writer = csv.writer(corrected_data)

            # Copy the header row
            header = next(reader)
            writer.writerow(header)

            # Locate the cmetadata column
            try:
                cmetadata_index = header.index('cmetadata')
            except ValueError:
                # No cmetadata column: copy the original CSV through unchanged
                corrected_data.seek(0)
                corrected_data.truncate(0)
                f.seek(0)
                corrected_data.write(f.read())
                corrected_data.seek(0)
                cursor.copy_expert(
                    "COPY langchain_pg_embedding FROM STDIN WITH (FORMAT CSV, HEADER)",
                    corrected_data
                )
                return

            # Process the data rows
            for row in reader:
                if len(row) > cmetadata_index and row[cmetadata_index]:
                    try:
                        # Convert Python dict literals to JSON; values that are
                        # already valid JSON parse cleanly and are left as-is
                        if row[cmetadata_index].startswith('{') and row[cmetadata_index].endswith('}'):
                            try:
                                # Try parsing as JSON first
                                json.loads(row[cmetadata_index])
                                # Already valid JSON; no conversion needed
                            except json.JSONDecodeError:
                                # Not valid JSON; try parsing as a Python dict literal
                                try:
                                    python_dict = ast.literal_eval(row[cmetadata_index])
                                    row[cmetadata_index] = json.dumps(python_dict, ensure_ascii=False)
                                except (ValueError, SyntaxError):
                                    # Give up on this value but keep processing
                                    self.logger.warning(f"Could not parse cmetadata: {row[cmetadata_index]}")
                    except Exception as e:
                        self.logger.warning(f"Error while processing cmetadata: {e}")

                writer.writerow(row)

        # Import the corrected data
        corrected_data.seek(0)
        cursor.copy_expert(
            "COPY langchain_pg_embedding FROM STDIN WITH (FORMAT CSV, HEADER)",
            corrected_data
        )
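

# Minimal usage sketch (assumes data_pipeline.config and the backup layout
# described above are in place; the path and timestamp below are illustrative):
if __name__ == "__main__":
    manager = VectorRestoreManager()

    # List every backup set found in the global and task directories
    scan_result = manager.scan_backup_files()
    for location in scan_result["backup_locations"]:
        for backup in location["backups"]:
            print(location["relative_path"], backup["timestamp"])

    # Restore one backup set, truncating the target tables first
    # (commented out because it modifies the database):
    # result = manager.restore_from_backup(
    #     backup_path="./data_pipeline/training_data/vector_bak",
    #     timestamp="20250722_010318",
    #     truncate_before_restore=True,
    # )
    # print(result["restore_results"])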