table_inspector_api.py

  1. """
  2. 表检查API模块
  3. 复用data_pipeline中的数据库连接和查询功能,提供独立的表信息查询API
  4. """
  5. import asyncio
  6. import asyncpg
  7. import logging
  8. from typing import List, Optional, Dict, Any
  9. from data_pipeline.tools.database_inspector import DatabaseInspectorTool
  10. class TableInspectorAPI:
  11. """表检查API类,复用现有的数据库功能"""
  12. def __init__(self):
  13. self.logger = logging.getLogger("TableInspectorAPI")
  14. self.db_inspector = None

    async def get_tables_list(self, db_connection: str, schema: Optional[str] = None) -> List[str]:
        """
        Get the list of tables in the database.

        Args:
            db_connection: Full PostgreSQL connection string
            schema: Optional schema filter; multiple schemas may be comma-separated.
                If None or an empty string, only tables in the public schema are returned.

        Returns:
            List of table names in the format schema.tablename
        """
        try:
            # Create the database inspector instance
            self.db_inspector = DatabaseInspectorTool(db_connection=db_connection)

            # Create the connection pool
            await self.db_inspector._create_connection_pool()

            # Parse the schema parameter
            target_schemas = self._parse_schemas(schema)

            # Query the table list
            tables = await self._query_tables(target_schemas)
            return tables

        except Exception as e:
            self.logger.error(f"Failed to get table list: {e}")
            raise
        finally:
            # Clean up the connection pool
            if self.db_inspector and self.db_inspector.connection_pool:
                await self.db_inspector.connection_pool.close()

    def _parse_schemas(self, schema: Optional[str]) -> List[str]:
        """
        Parse the schema parameter.

        Args:
            schema: A single schema name or a comma-separated list of schema names

        Returns:
            List of schema names
        """
        if not schema or schema.strip() == "":
            # No schema given: default to the public schema only
            return ["public"]

        # Split the comma-separated list and drop empty entries
        schemas = [s.strip() for s in schema.split(",") if s.strip()]

        # If nothing survives parsing, fall back to public
        if not schemas:
            return ["public"]

        return schemas
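
    # Examples (illustrative):
    #   _parse_schemas(None)        -> ["public"]
    #   _parse_schemas("sales,hr")  -> ["sales", "hr"]
    #   _parse_schemas(" , ")       -> ["public"]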

    async def _query_tables(self, schemas: List[str]) -> List[str]:
        """
        Query the tables in the given schemas.

        Args:
            schemas: List of schema names

        Returns:
            List of table names in the format schema.tablename
        """
        tables = []

        async with self.db_inspector.connection_pool.acquire() as conn:
            for schema in schemas:
                # Query the tables in this schema
                query = """
                SELECT schemaname, tablename
                FROM pg_tables
                WHERE schemaname = $1
                ORDER BY tablename
                """
                rows = await conn.fetch(query, schema)

                # Format each table name as schema.tablename
                for row in rows:
                    tables.append(f"{row['schemaname']}.{row['tablename']}")

        # Sort by full name
        tables.sort()

        self.logger.info(f"Found {len(tables)} tables in schemas: {schemas}")
        return tables
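
    # Note: the loop above issues one query per schema. asyncpg can bind a
    # Python list to a text[] parameter, so a single round trip would also
    # work (illustrative sketch, not wired into this class):
    #
    #     rows = await conn.fetch(
    #         """
    #         SELECT schemaname, tablename
    #         FROM pg_tables
    #         WHERE schemaname = ANY($1::text[])
    #         ORDER BY schemaname, tablename
    #         """,
    #         schemas,
    #     )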

    async def get_table_ddl(self, db_connection: str, table: str, business_context: Optional[str] = None, output_type: str = "ddl") -> Dict[str, Any]:
        """
        Get the DDL statement and/or Markdown documentation for a table.

        Args:
            db_connection: Database connection string
            table: Table name in the format schema.tablename
            business_context: Business-context description
            output_type: Output type; one of "ddl", "md", "both"

        Returns:
            Dict containing the DDL/MD content
        """
        try:
            # Parse the table name
            schema_name, table_name = self._parse_table_name(table)

            # Import the pipeline tools lazily to avoid import cycles
            from data_pipeline.tools.comment_generator import CommentGeneratorTool
            from data_pipeline.tools.ddl_generator import DDLGeneratorTool
            from data_pipeline.tools.doc_generator import DocGeneratorTool
            from data_pipeline.tools.data_sampler import DataSamplerTool
            from data_pipeline.utils.data_structures import TableMetadata, TableProcessingContext

            # Create the database inspector instance
            db_inspector = DatabaseInspectorTool(db_connection=db_connection)
            await db_inspector._create_connection_pool()

            # Build the table metadata object
            table_metadata = TableMetadata(
                table_name=table_name,
                schema_name=schema_name,
                full_name=f"{schema_name}.{table_name}",
                fields=[],
                comment=None,
                sample_data=[]
            )

            # Get the global Vanna instance (used only for LLM calls;
            # its database connection is left untouched)
            from common.vanna_instance import get_vanna_instance
            vn = get_vanna_instance()
            self.logger.info("Using the global Vanna singleton for LLM calls (its database connection is not modified)")

            # Create the processing context
            context = TableProcessingContext(
                table_metadata=table_metadata,
                business_context=business_context or "Database management system",
                output_dir="/tmp",  # Temporary directory; the API never actually writes files
                pipeline="api_direct",  # Marks a direct API invocation
                vn=vn,
                file_manager=None,  # No file manager needed
                step_results={}
            )

            # Step 1: fetch the table structure
            self.logger.info(f"Fetching table structure: {table}")
            inspect_result = await db_inspector.execute(context)
            if not inspect_result.success:
                raise Exception(f"Failed to fetch table structure: {inspect_result.error_message}")

            # Step 2: fetch sample data (used to generate better comments)
            self.logger.info("Fetching sample data")
            try:
                data_sampler = DataSamplerTool(vn=vn, db_connection=db_connection)
                sample_result = await data_sampler.execute(context)
                if sample_result.success:
                    self.logger.info("Sample data fetched successfully")
                else:
                    self.logger.warning(f"Failed to fetch sample data: {sample_result.error_message}")
            except Exception as e:
                self.logger.warning(f"Exception while fetching sample data: {e}")

            # Step 3: generate comments (calls the LLM)
            if business_context:
                self.logger.info("Generating LLM comments")
                try:
                    comment_generator = CommentGeneratorTool(
                        vn=vn,
                        business_context=business_context,
                        db_connection=db_connection
                    )
                    comment_result = await comment_generator.execute(context)
                    if comment_result.success:
                        self.logger.info("LLM comments generated successfully")
                    else:
                        self.logger.warning(f"Failed to generate LLM comments: {comment_result.error_message}")
                except Exception as e:
                    self.logger.warning(f"Exception while generating LLM comments: {e}")

            # Step 4: produce the requested output
            result = {}

            if output_type in ["ddl", "both"]:
                self.logger.info("Generating DDL")
                ddl_generator = DDLGeneratorTool()
                ddl_result = await ddl_generator.execute(context)
                if ddl_result.success:
                    result["ddl"] = ddl_result.data.get("ddl_content", "")
                    # Keep the DDL result around for the MD generator
                    context.step_results["ddl_generator"] = ddl_result
                else:
                    raise Exception(f"DDL generation failed: {ddl_result.error_message}")

            if output_type in ["md", "both"]:
                self.logger.info("Generating MD documentation")
                doc_generator = DocGeneratorTool()
                # Call the MD generation method directly; no file system involved
                md_content = doc_generator._generate_md_content(
                    table_metadata,
                    result.get("ddl", "")
                )
                result["md"] = md_content

            # Add a summary of the table
            result["table_info"] = {
                "table_name": table_metadata.table_name,
                "schema_name": table_metadata.schema_name,
                "full_name": table_metadata.full_name,
                "comment": table_metadata.comment,
                "field_count": len(table_metadata.fields),
                "row_count": table_metadata.row_count,
                "table_size": table_metadata.table_size
            }

            # Add per-field information
            result["fields"] = [
                {
                    "name": field.name,
                    "type": field.type,
                    "nullable": field.nullable,
                    "comment": field.comment,
                    "is_primary_key": field.is_primary_key,
                    "is_foreign_key": field.is_foreign_key,
                    "default_value": field.default_value,
                    "is_enum": getattr(field, 'is_enum', False),
                    "enum_values": getattr(field, 'enum_values', [])
                }
                for field in table_metadata.fields
            ]

            self.logger.info(f"Table DDL generation finished: {table}, output type: {output_type}")
            return result

        except Exception as e:
            self.logger.error(f"Failed to get table DDL: {e}")
            raise
        finally:
            # Clean up the connection pool
            if 'db_inspector' in locals() and db_inspector.connection_pool:
                await db_inspector.connection_pool.close()
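
    # Usage sketch (illustrative; the DSN, table, and business context are
    # placeholders):
    #
    #     api = TableInspectorAPI()
    #     result = await api.get_table_ddl(
    #         db_connection="postgresql://user:password@localhost:5432/mydb",
    #         table="public.orders",
    #         business_context="E-commerce order management",
    #         output_type="both",
    #     )
    #     # result carries "ddl", "md", "table_info", and "fields" keys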

    def _parse_table_name(self, table: str) -> tuple[str, str]:
        """
        Parse a table name.

        Args:
            table: Table name in the format schema.tablename or tablename

        Returns:
            (schema_name, table_name) tuple
        """
        if "." in table:
            parts = table.split(".", 1)
            return parts[0], parts[1]
        else:
            # No schema given: default to public
            return "public", table
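
    # Examples (illustrative):
    #   _parse_table_name("sales.orders") -> ("sales", "orders")
    #   _parse_table_name("orders")       -> ("public", "orders")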

    def _parse_db_connection(self, db_connection: str) -> Dict[str, Any]:
        """
        Parse a PostgreSQL connection string.

        Args:
            db_connection: PostgreSQL connection string in the format
                postgresql://user:password@host:port/dbname

        Returns:
            Dict of database connection parameters
        """
        # Regular expression for the connection string
        pattern = r'postgresql://([^:]+):([^@]+)@([^:]+):(\d+)/(.+)'
        match = re.match(pattern, db_connection)

        if not match:
            raise ValueError(f"Invalid PostgreSQL connection string format: {db_connection}")

        user, password, host, port, dbname = match.groups()

        return {
            'user': user,
            'password': password,
            'host': host,
            'port': int(port),
            'dbname': dbname
        }
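
    # Note: the regex above assumes every component (user, password, host,
    # port, dbname) is present and that the password contains no ':' or '@'.
    # A more forgiving variant could lean on the standard library instead
    # (illustrative sketch, not used by this class; percent-encoded
    # credentials would still need unquoting):
    #
    #     from urllib.parse import urlparse
    #
    #     parsed = urlparse(db_connection)
    #     params = {
    #         "user": parsed.username,
    #         "password": parsed.password,
    #         "host": parsed.hostname,
    #         "port": parsed.port or 5432,
    #         "dbname": parsed.path.lstrip("/"),
    #     }


if __name__ == "__main__":
    # Minimal smoke test (illustrative: the DSN below is a placeholder and
    # assumes a reachable PostgreSQL instance plus the data_pipeline
    # dependencies).
    async def _demo():
        api = TableInspectorAPI()
        tables = await api.get_tables_list(
            "postgresql://user:password@localhost:5432/mydb",
            schema="public",
        )
        print(tables)

    asyncio.run(_demo())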