table_inspector_api.py

  1. """
  2. 表检查API模块
  3. 复用data_pipeline中的数据库连接和查询功能,提供独立的表信息查询API
  4. """
  5. import asyncio
  6. import asyncpg
  7. import logging
  8. from typing import List, Optional, Dict, Any
  9. from data_pipeline.tools.database_inspector import DatabaseInspectorTool
  10. class TableInspectorAPI:
  11. """表检查API类,复用现有的数据库功能"""
  12. def __init__(self):
  13. self.logger = logging.getLogger("TableInspectorAPI")
  14. self.db_inspector = None
    async def get_tables_list(self, db_connection: str, schema: Optional[str] = None, table_name_pattern: Optional[str] = None) -> List[str]:
        """
        Get the list of tables in a database.

        Args:
            db_connection: Full PostgreSQL connection string
            schema: Optional schema parameter; multiple schemas may be passed
                as a comma-separated list. If None or an empty string, only
                tables in the public schema are returned.
            table_name_pattern: Optional table-name pattern with wildcard support:
                - ods_*  : tables starting with "ods_"
                - *_dim  : tables ending with "_dim"
                - *fact* : tables containing "fact"
                - ods_%  : raw SQL LIKE syntax (used as-is)

        Returns:
            List of table names in the form schema.tablename
        """
        try:
            # Create a database inspector instance
            self.db_inspector = DatabaseInspectorTool(db_connection=db_connection)

            # Create the connection pool
            await self.db_inspector._create_connection_pool()

            # Parse the schema parameter
            target_schemas = self._parse_schemas(schema)

            # Query the table list
            tables = await self._query_tables(target_schemas, table_name_pattern)

            return tables
        except Exception as e:
            self.logger.error(f"Failed to get table list: {e}")
            raise
        finally:
            # Clean up the connection pool
            if self.db_inspector and self.db_inspector.connection_pool:
                await self.db_inspector.connection_pool.close()
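
    # A minimal usage sketch for get_tables_list (assumptions: a reachable
    # PostgreSQL instance; the connection string below is a placeholder):
    #
    #     api = TableInspectorAPI()
    #     tables = asyncio.run(api.get_tables_list(
    #         "postgresql://user:password@localhost:5432/mydb",
    #         schema="public,ods",
    #         table_name_pattern="ods_*",
    #     ))
    #     # e.g. ["ods.ods_orders", "public.ods_sales"]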

    def _parse_schemas(self, schema: Optional[str]) -> List[str]:
        """
        Parse the schema parameter.

        Args:
            schema: Either a single schema or a comma-separated list of schemas

        Returns:
            List of schema names
        """
        if not schema or schema.strip() == "":
            # No schema specified: default to the public schema only
            return ["public"]

        # Split the comma-separated schema list
        schemas = [s.strip() for s in schema.split(",") if s.strip()]

        # Fall back to public if parsing yielded nothing
        if not schemas:
            return ["public"]

        return schemas
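
    # Illustrative (assumed inputs): _parse_schemas("public, ods ,") returns
    # ["public", "ods"]; _parse_schemas(None) returns ["public"].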

    async def _query_tables(self, schemas: List[str], table_name_pattern: Optional[str] = None) -> List[str]:
        """
        Query the tables in the given schemas.

        Args:
            schemas: List of schema names
            table_name_pattern: Optional table-name pattern with wildcard support:
                - ods_*  : tables starting with "ods_"
                - *_dim  : tables ending with "_dim"
                - *fact* : tables containing "fact"
                - ods_%  : raw SQL LIKE syntax (used as-is)

        Returns:
            List of table names in the form schema.tablename
        """
        tables = []

        async with self.db_inspector.connection_pool.acquire() as conn:
            for schema in schemas:
                # Build the query
                if table_name_pattern:
                    # Convert the wildcard pattern to SQL LIKE syntax
                    sql_pattern = self._convert_wildcard_to_sql_like(table_name_pattern)
                    query = """
                        SELECT schemaname, tablename
                        FROM pg_tables
                        WHERE schemaname = $1 AND tablename LIKE $2
                        ORDER BY tablename
                    """
                    rows = await conn.fetch(query, schema, sql_pattern)
                else:
                    # No table-name pattern: query all tables in the schema
                    query = """
                        SELECT schemaname, tablename
                        FROM pg_tables
                        WHERE schemaname = $1
                        ORDER BY tablename
                    """
                    rows = await conn.fetch(query, schema)

                # Format table names as schema.tablename
                for row in rows:
                    schema_name = row['schemaname']
                    table_name = row['tablename']
                    full_table_name = f"{schema_name}.{table_name}"
                    tables.append(full_table_name)

        # Sort by name
        tables.sort()

        pattern_info = f", table name pattern: {table_name_pattern}" if table_name_pattern else ""
        self.logger.info(f"Found {len(tables)} tables, schemas: {schemas}{pattern_info}")

        return tables
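
    # Illustrative: with schemas=["public"] and table_name_pattern="ods_*",
    # the effective query is equivalent to:
    #
    #     SELECT schemaname, tablename FROM pg_tables
    #     WHERE schemaname = 'public' AND tablename LIKE 'ods_%'
    #     ORDER BY tablename;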

    async def get_table_ddl(self, db_connection: str, table: str, business_context: Optional[str] = None, output_type: str = "ddl") -> Dict[str, Any]:
        """
        Get the DDL statement and/or Markdown documentation for a table.

        Args:
            db_connection: Database connection string
            table: Table name in the form schema.tablename
            business_context: Business context description
            output_type: Output type; one of "ddl", "md", "both"

        Returns:
            Dictionary containing the DDL/MD content
        """
        try:
            # Parse the table name
            schema_name, table_name = self._parse_table_name(table)

            # Import the required modules
            from data_pipeline.tools.database_inspector import DatabaseInspectorTool
            from data_pipeline.tools.comment_generator import CommentGeneratorTool
            from data_pipeline.tools.ddl_generator import DDLGeneratorTool
            from data_pipeline.tools.doc_generator import DocGeneratorTool
            from data_pipeline.tools.data_sampler import DataSamplerTool
            from data_pipeline.utils.data_structures import TableMetadata, TableProcessingContext
            from core.vanna_llm_factory import create_vanna_instance

            # Create a database inspector instance
            db_inspector = DatabaseInspectorTool(db_connection=db_connection)
            await db_inspector._create_connection_pool()

            # Create the table metadata object
            table_metadata = TableMetadata(
                table_name=table_name,
                schema_name=schema_name,
                full_name=f"{schema_name}.{table_name}",
                fields=[],
                comment=None,
                sample_data=[]
            )

            # Get the global Vanna instance (used only for LLM calls; its
            # database connection is left untouched)
            from common.vanna_instance import get_vanna_instance
            vn = get_vanna_instance()
            self.logger.info("Using the global Vanna singleton for LLM calls (its database connection is not modified)")

            # Create the processing context
            context = TableProcessingContext(
                table_metadata=table_metadata,
                business_context=business_context or "数据库管理系统",  # default context: "database management system"
                output_dir="/tmp",  # temporary directory; the API never actually writes files
                pipeline="api_direct",  # marks a direct API call
                vn=vn,
                file_manager=None,  # no file manager needed
                step_results={}
            )

            # Step 1: fetch the table structure
            self.logger.info(f"Fetching table structure: {table}")
            inspect_result = await db_inspector.execute(context)
            if not inspect_result.success:
                raise Exception(f"Failed to fetch table structure: {inspect_result.error_message}")

            # Step 2: fetch sample data (used to generate better comments)
            self.logger.info("Fetching sample data")
            try:
                data_sampler = DataSamplerTool(vn=vn, db_connection=db_connection)
                sample_result = await data_sampler.execute(context)
                if sample_result.success:
                    self.logger.info("Sample data fetched successfully")
                else:
                    self.logger.warning(f"Failed to fetch sample data: {sample_result.error_message}")
            except Exception as e:
                self.logger.warning(f"Exception while fetching sample data: {e}")

            # Step 3: generate comments (calls the LLM)
            if business_context:
                self.logger.info("Generating LLM comments")
                try:
                    comment_generator = CommentGeneratorTool(
                        vn=vn,
                        business_context=business_context,
                        db_connection=db_connection
                    )
                    comment_result = await comment_generator.execute(context)
                    if comment_result.success:
                        self.logger.info("LLM comments generated successfully")
                    else:
                        self.logger.warning(f"LLM comment generation failed: {comment_result.error_message}")
                except Exception as e:
                    self.logger.warning(f"Exception during LLM comment generation: {e}")

            # Step 4: generate the requested output
            result = {}

            if output_type in ["ddl", "both"]:
                self.logger.info("Generating DDL")
                ddl_generator = DDLGeneratorTool()
                ddl_result = await ddl_generator.execute(context)
                if ddl_result.success:
                    result["ddl"] = ddl_result.data.get("ddl_content", "")
                    # Keep the DDL result for the MD generator
                    context.step_results["ddl_generator"] = ddl_result
                else:
                    raise Exception(f"DDL generation failed: {ddl_result.error_message}")

            if output_type in ["md", "both"]:
                self.logger.info("Generating MD documentation")
                doc_generator = DocGeneratorTool()
                # Call the MD generation method directly, without touching the filesystem
                md_content = doc_generator._generate_md_content(
                    table_metadata,
                    result.get("ddl", "")
                )
                result["md"] = md_content

            # Add a table information summary
            result["table_info"] = {
                "table_name": table_metadata.table_name,
                "schema_name": table_metadata.schema_name,
                "full_name": table_metadata.full_name,
                "comment": table_metadata.comment,
                "field_count": len(table_metadata.fields),
                "row_count": table_metadata.row_count,
                "table_size": table_metadata.table_size
            }

            # Add field information
            result["fields"] = [
                {
                    "name": field.name,
                    "type": field.type,
                    "nullable": field.nullable,
                    "comment": field.comment,
                    "is_primary_key": field.is_primary_key,
                    "is_foreign_key": field.is_foreign_key,
                    "default_value": field.default_value,
                    "is_enum": getattr(field, 'is_enum', False),
                    "enum_values": getattr(field, 'enum_values', [])
                }
                for field in table_metadata.fields
            ]

            self.logger.info(f"Table DDL generation finished: {table}, output type: {output_type}")
            return result

        except Exception as e:
            self.logger.error(f"Failed to get table DDL: {e}")
            raise
        finally:
            # Clean up the connection pool
            if 'db_inspector' in locals() and db_inspector.connection_pool:
                await db_inspector.connection_pool.close()
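
    # A minimal usage sketch for get_table_ddl (assumptions: the data_pipeline
    # tools and the global Vanna instance are configured; the connection string
    # and table name are placeholders):
    #
    #     api = TableInspectorAPI()
    #     result = asyncio.run(api.get_table_ddl(
    #         db_connection="postgresql://user:password@localhost:5432/mydb",
    #         table="public.orders",
    #         business_context="e-commerce order system",
    #         output_type="both",
    #     ))
    #     # result contains "ddl", "md", "table_info" and "fields"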

    def _parse_table_name(self, table: str) -> tuple[str, str]:
        """
        Parse a table name.

        Args:
            table: Table name in the form schema.tablename or tablename

        Returns:
            (schema_name, table_name) tuple
        """
        if "." in table:
            parts = table.split(".", 1)
            return parts[0], parts[1]
        else:
            # No schema specified: default to public
            return "public", table
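
    # Illustrative: _parse_table_name("ods.ods_orders") returns ("ods", "ods_orders");
    # _parse_table_name("orders") returns ("public", "orders").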

    def _parse_db_connection(self, db_connection: str) -> Dict[str, Any]:
        """
        Parse a PostgreSQL connection string.

        Args:
            db_connection: PostgreSQL connection string in the form
                postgresql://user:password@host:port/dbname

        Returns:
            Dictionary of database connection parameters
        """
        import re

        # Regular expression for parsing the connection string; special
        # characters in the user or password must be percent-encoded
        pattern = r'postgresql://([^:]+):([^@]+)@([^:]+):(\d+)/(.+)'
        match = re.match(pattern, db_connection)

        if not match:
            raise ValueError(f"Invalid PostgreSQL connection string format: {db_connection}")

        user, password, host, port, dbname = match.groups()

        return {
            'user': user,
            'password': password,
            'host': host,
            'port': int(port),
            'dbname': dbname
        }
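
    # Illustrative: _parse_db_connection("postgresql://u:p@db.local:5432/sales")
    # returns {'user': 'u', 'password': 'p', 'host': 'db.local', 'port': 5432,
    # 'dbname': 'sales'}; non-matching strings raise ValueError.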

    def _convert_wildcard_to_sql_like(self, pattern: str) -> str:
        """
        Convert a wildcard pattern to SQL LIKE syntax.

        Args:
            pattern: Wildcard pattern
                - ods_*  : tables starting with "ods_"
                - *_dim  : tables ending with "_dim"
                - *fact* : tables containing "fact"
                - ods_%  : raw SQL LIKE syntax (returned unchanged)

        Returns:
            Pattern string in SQL LIKE syntax
        """
        if not pattern:
            return "%"

        # If the pattern already uses SQL LIKE syntax (contains %), return it as-is
        if "%" in pattern:
            return pattern

        # Convert the * wildcard to %
        sql_pattern = pattern.replace("*", "%")

        # Log the conversion
        if pattern != sql_pattern:
            self.logger.debug(f"Wildcard pattern converted: {pattern} -> {sql_pattern}")

        return sql_pattern
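

# A minimal self-test sketch. The wildcard conversion below runs offline; the
# commented-out table listing requires a real PostgreSQL instance, so its
# connection string is a placeholder assumption.
if __name__ == "__main__":
    api = TableInspectorAPI()

    # Offline check of the wildcard-to-LIKE conversion
    for p in ["ods_*", "*_dim", "*fact*", "ods_%"]:
        print(f"{p!r} -> {api._convert_wildcard_to_sql_like(p)!r}")

    # Uncomment and adjust to list tables from a real database:
    # tables = asyncio.run(api.get_tables_list(
    #     "postgresql://user:password@localhost:5432/mydb",
    #     schema="public",
    # ))
    # print(tables)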