# simple_file_manager.py
"""
Simplified file manager for the Data Pipeline API.

Provides basic file listing, download, and upload features, without
compression or other advanced functionality.
"""
import logging
import os
import re
import shutil
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Any, BinaryIO, Dict, List, Optional, Union
  12. class SimpleFileManager:
  13. """简化的文件管理器"""
  14. def __init__(self, base_output_dir: str = None):
  15. if base_output_dir is None:
  16. # 获取项目根目录的绝对路径
  17. from pathlib import Path
  18. project_root = Path(__file__).parent.parent.parent
  19. base_output_dir = str(project_root / "data_pipeline" / "training_data")
  20. """
  21. 初始化文件管理器
  22. Args:
  23. base_output_dir: 基础输出目录
  24. """
  25. self.base_output_dir = Path(base_output_dir)
  26. # 使用简单的控制台日志,不使用文件日志
  27. self.logger = logging.getLogger("SimpleFileManager")
  28. self.logger.setLevel(logging.INFO)
  29. # 确保基础目录存在
  30. self.base_output_dir.mkdir(parents=True, exist_ok=True)
  31. def get_task_directory(self, task_id: str) -> Path:
  32. """获取任务目录路径"""
  33. return self.base_output_dir / task_id
  34. def create_task_directory(self, task_id: str) -> bool:
  35. """创建任务目录"""
  36. try:
  37. task_dir = self.get_task_directory(task_id)
  38. task_dir.mkdir(parents=True, exist_ok=True)
  39. self.logger.info(f"任务目录已创建: {task_dir}")
  40. return True
  41. except Exception as e:
  42. self.logger.error(f"创建任务目录失败: {e}")
  43. return False
  44. def get_task_files(self, task_id: str) -> List[Dict[str, Any]]:
  45. """获取任务目录下的所有文件信息"""
  46. try:
  47. task_dir = self.get_task_directory(task_id)
  48. if not task_dir.exists():
  49. return []
  50. files_info = []
  51. for file_path in task_dir.iterdir():
  52. if file_path.is_file():
  53. file_info = self._get_file_info(file_path)
  54. files_info.append(file_info)
  55. # 按修改时间排序(最新的在前)
  56. files_info.sort(key=lambda x: x['modified_at'], reverse=True)
  57. return files_info
  58. except Exception as e:
  59. self.logger.error(f"获取任务文件失败: {e}")
  60. return []
  61. def _get_file_info(self, file_path: Path) -> Dict[str, Any]:
  62. """获取单个文件的基本信息"""
  63. try:
  64. stat = file_path.stat()
  65. return {
  66. "file_name": file_path.name,
  67. "file_path": str(file_path),
  68. "file_type": self._determine_file_type(file_path),
  69. "file_size": stat.st_size,
  70. "file_size_formatted": self._format_file_size(stat.st_size),
  71. "created_at": datetime.fromtimestamp(stat.st_ctime),
  72. "modified_at": datetime.fromtimestamp(stat.st_mtime),
  73. "is_readable": os.access(file_path, os.R_OK)
  74. }
  75. except Exception as e:
  76. self.logger.error(f"获取文件信息失败: {e}")
  77. return {
  78. "file_name": file_path.name,
  79. "file_path": str(file_path),
  80. "file_type": "unknown",
  81. "file_size": 0,
  82. "file_size_formatted": "0 B",
  83. "created_at": datetime.now(),
  84. "modified_at": datetime.now(),
  85. "is_readable": False
  86. }
  87. def _determine_file_type(self, file_path: Path) -> str:
  88. """根据文件扩展名确定文件类型"""
  89. suffix = file_path.suffix.lower()
  90. type_mapping = {
  91. '.ddl': 'ddl',
  92. '.sql': 'sql',
  93. '.md': 'markdown',
  94. '.markdown': 'markdown',
  95. '.json': 'json',
  96. '.txt': 'text',
  97. '.log': 'log'
  98. }
  99. return type_mapping.get(suffix, 'other')
  100. def _format_file_size(self, size_bytes: int) -> str:
  101. """格式化文件大小显示"""
  102. if size_bytes == 0:
  103. return "0 B"
  104. size_names = ["B", "KB", "MB", "GB"]
  105. i = 0
  106. size = float(size_bytes)
  107. while size >= 1024.0 and i < len(size_names) - 1:
  108. size /= 1024.0
  109. i += 1
  110. return f"{size:.1f} {size_names[i]}"
  111. def get_file_path(self, task_id: str, file_name: str) -> Path:
  112. """获取文件的完整路径"""
  113. task_dir = self.get_task_directory(task_id)
  114. return task_dir / file_name
  115. def file_exists(self, task_id: str, file_name: str) -> bool:
  116. """检查文件是否存在"""
  117. file_path = self.get_file_path(task_id, file_name)
  118. return file_path.exists() and file_path.is_file()
  119. def is_file_safe(self, task_id: str, file_name: str) -> bool:
  120. """检查文件路径是否安全(防止路径遍历攻击)"""
  121. try:
  122. task_dir = self.get_task_directory(task_id)
  123. file_path = task_dir / file_name
  124. # 确保文件在任务目录内
  125. file_path.resolve().relative_to(task_dir.resolve())
  126. return True
  127. except ValueError:
  128. return False
  129. def get_directory_info(self, task_id: str) -> Dict[str, Any]:
  130. """获取任务目录信息"""
  131. try:
  132. task_dir = self.get_task_directory(task_id)
  133. if not task_dir.exists():
  134. return {
  135. "exists": False,
  136. "directory_path": str(task_dir),
  137. "total_files": 0,
  138. "total_size": 0,
  139. "total_size_formatted": "0 B"
  140. }
  141. files = self.get_task_files(task_id)
  142. total_size = sum(file_info['file_size'] for file_info in files)
  143. return {
  144. "exists": True,
  145. "directory_path": str(task_dir),
  146. "total_files": len(files),
  147. "total_size": total_size,
  148. "total_size_formatted": self._format_file_size(total_size)
  149. }
  150. except Exception as e:
  151. self.logger.error(f"获取目录信息失败: {e}")
  152. return {
  153. "exists": False,
  154. "directory_path": str(self.get_task_directory(task_id)),
  155. "total_files": 0,
  156. "total_size": 0,
  157. "total_size_formatted": "0 B"
  158. }
  159. def upload_table_list_file(self, task_id: str, file_obj: Union[BinaryIO, bytes], filename: str = None) -> Dict[str, Any]:
  160. """
  161. 上传表清单文件到指定任务目录
  162. Args:
  163. task_id: 任务ID
  164. file_obj: 文件对象(Flask的FileStorage)或文件内容(字节流)
  165. filename: 原始文件名(可选,仅用于日志记录)
  166. Returns:
  167. Dict: 上传结果,包含filename、file_size、file_size_formatted、upload_time等
  168. Raises:
  169. ValueError: 文件验证失败(文件太大、空文件、格式错误等)
  170. FileNotFoundError: 任务目录不存在且无法创建
  171. IOError: 文件操作失败
  172. """
  173. try:
  174. # 获取配置
  175. from data_pipeline.config import SCHEMA_TOOLS_CONFIG
  176. upload_config = SCHEMA_TOOLS_CONFIG.get("file_upload", {})
  177. max_file_size_mb = upload_config.get("max_file_size_mb", 2)
  178. max_size = max_file_size_mb * 1024 * 1024 # 转换为字节
  179. target_filename = upload_config.get("target_filename", "table_list.txt")
  180. allowed_extensions = upload_config.get("allowed_extensions", ["txt"])
  181. # 处理文件对象或字节流
  182. if isinstance(file_obj, bytes):
  183. file_content = file_obj
  184. original_filename = filename or "uploaded_file.txt"
  185. else:
  186. # Flask FileStorage对象
  187. if hasattr(file_obj, 'filename') and file_obj.filename:
  188. original_filename = file_obj.filename
  189. else:
  190. original_filename = filename or "uploaded_file.txt"
  191. # 验证文件扩展名 - 修复:统一格式进行比较
  192. file_ext = Path(original_filename).suffix.lower().lstrip('.')
  193. if file_ext not in allowed_extensions:
  194. raise ValueError(f"不支持的文件类型,仅支持: {', '.join(['.' + ext for ext in allowed_extensions])}")
  195. # 读取文件内容并验证大小
  196. file_content = b''
  197. chunk_size = 8192
  198. total_size = 0
  199. while True:
  200. chunk = file_obj.read(chunk_size)
  201. if not chunk:
  202. break
  203. total_size += len(chunk)
  204. if total_size > max_size:
  205. raise ValueError(f"文件大小超过限制: {max_file_size_mb}MB")
  206. file_content += chunk
  207. # 验证文件内容为空
  208. if len(file_content) == 0:
  209. raise ValueError("文件为空,请选择有效的表清单文件")
  210. # 验证文件内容(简单检查是否为文本文件)
  211. self._validate_table_list_content_simple(file_content)
  212. # 确保任务目录存在
  213. task_dir = self.get_task_directory(task_id)
  214. if not task_dir.exists():
  215. task_dir.mkdir(parents=True, exist_ok=True)
  216. self.logger.info(f"创建任务目录: {task_dir}")
  217. # 确定目标文件路径
  218. target_file_path = task_dir / target_filename
  219. # 保存文件
  220. with open(target_file_path, 'wb') as f:
  221. f.write(file_content)
  222. # 验证文件是否成功写入
  223. if not target_file_path.exists():
  224. raise IOError("文件保存失败")
  225. # 获取文件信息
  226. file_stat = target_file_path.stat()
  227. upload_time = datetime.fromtimestamp(file_stat.st_mtime)
  228. self.logger.info(f"成功上传表清单文件到任务 {task_id}: {target_file_path}")
  229. return {
  230. "filename": target_filename,
  231. "original_filename": original_filename,
  232. "file_size": file_stat.st_size,
  233. "file_size_formatted": self._format_file_size(file_stat.st_size),
  234. "upload_time": upload_time,
  235. "target_path": str(target_file_path)
  236. }
  237. except Exception as e:
  238. self.logger.error(f"上传表清单文件失败: {e}")
  239. raise
  240. def _validate_table_list_content_simple(self, file_content: bytes) -> None:
  241. """
  242. 简单验证表清单文件内容
  243. Args:
  244. file_content: 文件内容(字节流)
  245. Raises:
  246. ValueError: 文件内容验证失败
  247. """
  248. try:
  249. # 尝试解码文件内容
  250. try:
  251. content = file_content.decode('utf-8')
  252. except UnicodeDecodeError:
  253. try:
  254. content = file_content.decode('gbk')
  255. except UnicodeDecodeError:
  256. raise ValueError("文件编码错误,请确保文件为UTF-8或GBK格式")
  257. # 检查文件是否为空
  258. if not content.strip():
  259. raise ValueError("表清单文件为空")
  260. # 简单验证:检查是否包含至少一个非空行
  261. lines = [line.strip() for line in content.split('\n') if line.strip()]
  262. if not lines:
  263. raise ValueError("表清单文件不包含有效的表名")
  264. # 可选:验证表名格式(避免SQL注入等安全问题)
  265. import re
  266. table_name_pattern = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)?$')
  267. invalid_tables = []
  268. for line in lines[:10]: # 只检查前10行以避免过度验证
  269. # 忽略注释行
  270. if line.startswith('#') or line.startswith('--'):
  271. continue
  272. # 检查表名格式
  273. if not table_name_pattern.match(line):
  274. invalid_tables.append(line)
  275. if invalid_tables:
  276. raise ValueError(f"表清单文件包含无效的表名格式: {', '.join(invalid_tables[:3])}")
  277. except ValueError:
  278. raise
  279. except Exception as e:
  280. raise ValueError(f"文件内容验证失败: {str(e)}")
  281. def _validate_table_list_content(self, file_content: bytes, config: Dict[str, Any]) -> Dict[str, Any]:
  282. """
  283. 验证表清单文件内容
  284. Args:
  285. file_content: 文件内容(字节流)
  286. config: 文件上传配置
  287. Returns:
  288. Dict: 验证结果
  289. """
  290. try:
  291. # 解码文件内容
  292. encoding = config.get("encoding", "utf-8")
  293. try:
  294. content = file_content.decode(encoding)
  295. except UnicodeDecodeError:
  296. # 尝试其他编码
  297. for fallback_encoding in ["gbk", "latin1"]:
  298. try:
  299. content = file_content.decode(fallback_encoding)
  300. self.logger.warning(f"文件编码检测为 {fallback_encoding},建议使用 UTF-8")
  301. break
  302. except UnicodeDecodeError:
  303. continue
  304. else:
  305. return {
  306. "valid": False,
  307. "error": f"无法解码文件内容,请确保文件编码为 {encoding}"
  308. }
  309. # 分析文件内容
  310. lines = content.splitlines()
  311. total_lines = len(lines)
  312. # 过滤空行和注释行
  313. valid_lines = []
  314. comment_lines = 0
  315. empty_lines = 0
  316. for line_num, line in enumerate(lines, 1):
  317. stripped = line.strip()
  318. if not stripped:
  319. empty_lines += 1
  320. elif stripped.startswith('#'):
  321. comment_lines += 1
  322. else:
  323. # 简单验证表名格式
  324. if self._is_valid_table_name(stripped):
  325. valid_lines.append(stripped)
  326. else:
  327. return {
  328. "valid": False,
  329. "error": f"第 {line_num} 行包含无效的表名: {stripped}",
  330. "details": {
  331. "line_number": line_num,
  332. "invalid_content": stripped
  333. }
  334. }
  335. # 检查有效行数
  336. min_lines = config.get("min_lines", 1)
  337. max_lines = config.get("max_lines", 1000)
  338. if len(valid_lines) < min_lines:
  339. return {
  340. "valid": False,
  341. "error": f"文件至少需要包含 {min_lines} 个有效表名,当前只有 {len(valid_lines)} 个",
  342. "details": {
  343. "valid_tables": len(valid_lines),
  344. "min_required": min_lines
  345. }
  346. }
  347. if len(valid_lines) > max_lines:
  348. return {
  349. "valid": False,
  350. "error": f"文件包含的表名数量超过限制,最多允许 {max_lines} 个,当前有 {len(valid_lines)} 个",
  351. "details": {
  352. "valid_tables": len(valid_lines),
  353. "max_allowed": max_lines
  354. }
  355. }
  356. return {
  357. "valid": True,
  358. "details": {
  359. "total_lines": total_lines,
  360. "empty_lines": empty_lines,
  361. "comment_lines": comment_lines,
  362. "valid_tables": len(valid_lines),
  363. "table_names": valid_lines[:10] # 只返回前10个作为预览
  364. }
  365. }
  366. except Exception as e:
  367. return {
  368. "valid": False,
  369. "error": f"文件内容验证失败: {str(e)}"
  370. }
  371. def _is_valid_table_name(self, table_name: str) -> bool:
  372. """
  373. 验证表名格式是否有效
  374. Args:
  375. table_name: 表名
  376. Returns:
  377. bool: 是否有效
  378. """
  379. import re
  380. # 基本的表名格式检查
  381. # 支持: table_name, schema.table_name
  382. pattern = r'^[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)?$'
  383. return bool(re.match(pattern, table_name))
  384. def get_table_list_file_info(self, task_id: str) -> Dict[str, Any]:
  385. """
  386. 获取任务的表清单文件信息
  387. Args:
  388. task_id: 任务ID
  389. Returns:
  390. Dict: 文件信息或None
  391. """
  392. try:
  393. from data_pipeline.config import SCHEMA_TOOLS_CONFIG
  394. upload_config = SCHEMA_TOOLS_CONFIG.get("file_upload", {})
  395. target_filename = upload_config.get("target_filename", "table_list.txt")
  396. file_path = self.get_file_path(task_id, target_filename)
  397. if not file_path.exists():
  398. return {
  399. "exists": False,
  400. "file_name": target_filename,
  401. "expected_path": str(file_path)
  402. }
  403. file_stat = file_path.stat()
  404. # 尝试读取文件内容进行分析
  405. try:
  406. with open(file_path, 'r', encoding='utf-8') as f:
  407. content = f.read()
  408. lines = content.splitlines()
  409. valid_tables = [line.strip() for line in lines
  410. if line.strip() and not line.strip().startswith('#')]
  411. except Exception:
  412. valid_tables = []
  413. return {
  414. "exists": True,
  415. "file_name": target_filename,
  416. "file_path": str(file_path),
  417. "file_size": file_stat.st_size,
  418. "file_size_formatted": self._format_file_size(file_stat.st_size),
  419. "uploaded_at": datetime.fromtimestamp(file_stat.st_mtime).isoformat(),
  420. "created_at": datetime.fromtimestamp(file_stat.st_ctime).isoformat(),
  421. "table_count": len(valid_tables),
  422. "is_readable": os.access(file_path, os.R_OK)
  423. }
  424. except Exception as e:
  425. self.logger.error(f"获取表清单文件信息失败: {e}")
  426. return {
  427. "exists": False,
  428. "error": str(e)
  429. }