# database_inspector.py (8.5 KB)
  1. import asyncio
  2. import asyncpg
  3. from typing import List, Dict, Any, Optional
  4. from schema_tools.tools.base import BaseTool, ToolRegistry
  5. from schema_tools.utils.data_structures import ProcessingResult, TableProcessingContext, FieldInfo, TableMetadata
  6. @ToolRegistry.register("database_inspector")
  7. class DatabaseInspectorTool(BaseTool):
  8. """数据库元数据检查工具"""
  9. needs_llm = False
  10. tool_name = "数据库检查器"
  def __init__(self, **kwargs):
      """Set up the tool without touching the database.

      All keyword arguments are forwarded to the base-class initialiser.
      The optional ``db_connection`` keyword is kept on the instance
      (``None`` when it is not supplied); it is later handed to
      ``asyncpg.create_pool`` as its first positional argument.
      """
      super().__init__(**kwargs)
      # No pool yet: execute() opens one on first use.
      self.connection_pool = None
      self.db_connection = kwargs.get('db_connection')
  async def execute(self, context: TableProcessingContext) -> ProcessingResult:
      """Inspect the table described by ``context`` and record its metadata.

      The connection pool is created on first use. The table is looked up
      first; when the lookup yields nothing, a failed result is returned.
      Otherwise the columns, table comment and statistics are fetched and
      written onto ``context.table_metadata``.

      Args:
          context: Processing context whose ``table_metadata`` supplies the
              schema and table name and receives the collected values.

      Returns:
          A successful ``ProcessingResult`` carrying the column count,
          comment, row count and table size, or a failed one with an error
          message. Exceptions are logged and converted to a failed result
          rather than propagated.
      """
      try:
          # Open the pool lazily, the first time the tool actually runs.
          if not self.connection_pool:
              await self._create_connection_pool()

          table_name = context.table_metadata.table_name
          schema_name = context.table_metadata.schema_name

          # Guard clause: nothing more to do when the table lookup is empty.
          if not await self._get_table_info(schema_name, table_name):
              return ProcessingResult(
                  success=False,
                  error_message=f"表 {schema_name}.{table_name} 不存在或无权限访问"
              )

          columns = await self._get_table_fields(schema_name, table_name)
          comment_text = await self._get_table_comment(schema_name, table_name)
          table_stats = await self._get_table_statistics(schema_name, table_name)

          # Either statistic may be missing, because the statistics lookup
          # is best effort; .get() then yields None.
          total_rows = table_stats.get('row_count')
          size_text = table_stats.get('table_size')

          # Write the collected values back onto the shared metadata object.
          meta = context.table_metadata
          meta.original_comment = comment_text
          meta.comment = comment_text
          meta.fields = columns
          meta.row_count = total_rows
          meta.table_size = size_text

          return ProcessingResult(
              success=True,
              data={
                  'fields_count': len(columns),
                  'table_comment': comment_text,
                  'row_count': total_rows,
                  'table_size': size_text,
              },
              metadata={'tool': self.tool_name},
          )
      except Exception as exc:
          # Boundary handler: log with traceback, report failure to the caller.
          self.logger.exception("数据库检查失败")
          return ProcessingResult(
              success=False,
              error_message=f"数据库检查失败: {str(exc)}"
          )
  async def _create_connection_pool(self):
      """Open the asyncpg pool and store it on ``self.connection_pool``.

      The pool is created from ``self.db_connection`` with between one and
      five connections and a 30-second command timeout.

      Raises:
          Exception: Whatever pool creation raised; it is logged first and
              then re-raised unchanged.
      """
      pool_options = {
          'min_size': 1,
          'max_size': 5,
          'command_timeout': 30,
      }
      try:
          pool = await asyncpg.create_pool(self.db_connection, **pool_options)
          self.connection_pool = pool
          self.logger.info("数据库连接池创建成功")
      except Exception as exc:
          self.logger.error(f"创建数据库连接池失败: {exc}")
          raise
  71. async def _get_table_info(self, schema_name: str, table_name: str) -> Optional[Dict]:
  72. """获取表基本信息"""
  73. query = """
  74. SELECT schemaname, tablename, tableowner, tablespace, hasindexes, hasrules, hastriggers
  75. FROM pg_tables
  76. WHERE schemaname = $1 AND tablename = $2
  77. """
  78. async with self.connection_pool.acquire() as conn:
  79. result = await conn.fetchrow(query, schema_name, table_name)
  80. return dict(result) if result else None
  async def _get_table_fields(self, schema_name: str, table_name: str) -> List[FieldInfo]:
      """Fetch the column definitions of a table.

      Columns are read from ``information_schema.columns`` in ordinal
      order. Each column is joined with its comment from ``pg_description``
      and flagged as a primary-key or foreign-key member.

      Args:
          schema_name: Schema that owns the table.
          table_name: Name of the table.

      Returns:
          One ``FieldInfo`` per column, ordered by ordinal position. The
          list is empty when the query returns no rows.
      """
      # Notes on the joins:
      # * pg_description is keyed by (objoid, classoid, objsubid). Filtering
      #   on classoid stops an OID that collides with an object from another
      #   catalog from contributing the wrong comment.
      # * Constraint names are not unique across tables or schemas, so the
      #   constraint subqueries join on schema and table as well as name.
      # * DISTINCT keeps a column that takes part in several constraints
      #   from multiplying the outer rows.
      query = """
          SELECT
              c.column_name,
              c.data_type,
              c.is_nullable,
              c.column_default,
              c.character_maximum_length,
              c.numeric_precision,
              c.numeric_scale,
              pd.description as column_comment,
              CASE WHEN pk.column_name IS NOT NULL THEN true ELSE false END as is_primary_key,
              CASE WHEN fk.column_name IS NOT NULL THEN true ELSE false END as is_foreign_key
          FROM information_schema.columns c
          LEFT JOIN pg_description pd ON pd.objsubid = c.ordinal_position
              AND pd.classoid = 'pg_class'::regclass
              AND pd.objoid = (
                  SELECT oid FROM pg_class
                  WHERE relname = c.table_name
                  AND relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = c.table_schema)
              )
          LEFT JOIN (
              SELECT DISTINCT ku.column_name
              FROM information_schema.table_constraints tc
              JOIN information_schema.key_column_usage ku
                  ON tc.constraint_schema = ku.constraint_schema
                  AND tc.constraint_name = ku.constraint_name
                  AND tc.table_schema = ku.table_schema
                  AND tc.table_name = ku.table_name
              WHERE tc.table_schema = $1 AND tc.table_name = $2 AND tc.constraint_type = 'PRIMARY KEY'
          ) pk ON pk.column_name = c.column_name
          LEFT JOIN (
              SELECT DISTINCT ku.column_name
              FROM information_schema.table_constraints tc
              JOIN information_schema.key_column_usage ku
                  ON tc.constraint_schema = ku.constraint_schema
                  AND tc.constraint_name = ku.constraint_name
                  AND tc.table_schema = ku.table_schema
                  AND tc.table_name = ku.table_name
              WHERE tc.table_schema = $1 AND tc.table_name = $2 AND tc.constraint_type = 'FOREIGN KEY'
          ) fk ON fk.column_name = c.column_name
          WHERE c.table_schema = $1 AND c.table_name = $2
          ORDER BY c.ordinal_position
      """
      async with self.connection_pool.acquire() as conn:
          rows = await conn.fetch(query, schema_name, table_name)

      return [
          FieldInfo(
              name=row['column_name'],
              type=row['data_type'],
              # information_schema reports nullability as 'YES' / 'NO'.
              nullable=row['is_nullable'] == 'YES',
              default_value=row['column_default'],
              original_comment=row['column_comment'],
              comment=row['column_comment'],
              is_primary_key=row['is_primary_key'],
              is_foreign_key=row['is_foreign_key'],
              max_length=row['character_maximum_length'],
              precision=row['numeric_precision'],
              scale=row['numeric_scale'],
          )
          for row in rows
      ]
  136. async def _get_table_comment(self, schema_name: str, table_name: str) -> Optional[str]:
  137. """获取表注释"""
  138. query = """
  139. SELECT obj_description(oid) as table_comment
  140. FROM pg_class
  141. WHERE relname = $2
  142. AND relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = $1)
  143. """
  144. async with self.connection_pool.acquire() as conn:
  145. result = await conn.fetchval(query, schema_name, table_name)
  146. return result
  147. async def _get_table_statistics(self, schema_name: str, table_name: str) -> Dict[str, Any]:
  148. """获取表统计信息"""
  149. stats_query = """
  150. SELECT
  151. schemaname,
  152. tablename,
  153. attname,
  154. n_distinct,
  155. most_common_vals,
  156. most_common_freqs,
  157. histogram_bounds
  158. FROM pg_stats
  159. WHERE schemaname = $1 AND tablename = $2
  160. """
  161. size_query = """
  162. SELECT pg_size_pretty(pg_total_relation_size($1::oid)) as table_size,
  163. pg_relation_size($1::oid) as table_size_bytes
  164. """
  165. count_query = f"SELECT COUNT(*) as row_count FROM {schema_name}.{table_name}"
  166. stats = {}
  167. async with self.connection_pool.acquire() as conn:
  168. try:
  169. # 获取行数
  170. row_count = await conn.fetchval(count_query)
  171. stats['row_count'] = row_count
  172. # 获取表大小
  173. table_oid = await conn.fetchval(
  174. "SELECT oid FROM pg_class WHERE relname = $1 AND relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = $2)",
  175. table_name, schema_name
  176. )
  177. if table_oid:
  178. # 确保 table_oid 作为整数传递
  179. size_result = await conn.fetchrow(size_query, int(table_oid))
  180. stats['table_size'] = size_result['table_size']
  181. stats['table_size_bytes'] = size_result['table_size_bytes']
  182. except Exception as e:
  183. self.logger.warning(f"获取表统计信息失败: {e}")
  184. return stats