ddl_parser.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. import json
  2. import logging
  3. import re
  4. import time
  5. import requests
  6. from flask import current_app
  7. logger = logging.getLogger(__name__)
  8. class DDLParser:
  9. def __init__(self, api_key=None, timeout=60, max_retries=3):
  10. """
  11. 初始化DDL解析器
  12. 参数:
  13. api_key: LLM API密钥,如果未提供,将从应用配置或环境变量中获取
  14. timeout: API请求超时时间(秒),默认60秒
  15. max_retries: 最大重试次数,默认3次
  16. """
  17. # 如果在Flask应用上下文中,则从应用配置获取参数
  18. self.api_key = api_key or current_app.config.get("LLM_API_KEY")
  19. self.base_url = current_app.config.get("LLM_BASE_URL")
  20. self.model_name = current_app.config.get("LLM_MODEL_NAME")
  21. self.timeout = timeout
  22. self.max_retries = max_retries
  23. self.headers = {
  24. "Authorization": f"Bearer {self.api_key}",
  25. "Content-Type": "application/json",
  26. }
  27. def _make_llm_request(self, payload, operation_name="LLM请求"):
  28. """
  29. 发送LLM请求,支持自动重试
  30. 参数:
  31. payload: 请求payload
  32. operation_name: 操作名称,用于日志
  33. 返回:
  34. API响应结果
  35. """
  36. last_error = None
  37. for attempt in range(self.max_retries):
  38. try:
  39. if attempt > 0:
  40. wait_time = 2**attempt # 指数退避: 2, 4, 8秒
  41. logger.info(
  42. f"{operation_name} 第{attempt + 1}次重试,等待{wait_time}秒..."
  43. )
  44. time.sleep(wait_time)
  45. logger.info(
  46. f"{operation_name} 尝试 {attempt + 1}/{self.max_retries},超时时间: {self.timeout}秒"
  47. )
  48. response = requests.post(
  49. f"{self.base_url}/chat/completions",
  50. headers=self.headers,
  51. json=payload,
  52. timeout=self.timeout,
  53. )
  54. response.raise_for_status()
  55. result = response.json()
  56. logger.info(f"{operation_name} 成功")
  57. return result
  58. except requests.Timeout as e:
  59. last_error = f"请求超时(超过{self.timeout}秒): {str(e)}"
  60. logger.warning(f"{operation_name} 超时: {str(e)}")
  61. except requests.RequestException as e:
  62. last_error = f"API请求失败: {str(e)}"
  63. logger.warning(f"{operation_name} 失败: {str(e)}")
  64. except Exception as e:
  65. last_error = f"未知错误: {str(e)}"
  66. logger.error(f"{operation_name} 异常: {str(e)}")
  67. break # 对于非网络错误,不重试
  68. # 所有重试都失败
  69. logger.error(f"{operation_name} 在{self.max_retries}次尝试后失败: {last_error}")
  70. return None
  71. def parse_ddl(self, sql_content):
  72. """
  73. 解析DDL语句,返回标准化的结构
  74. 参数:
  75. sql_content: 要解析的DDL语句
  76. 返回:
  77. 解析结果的JSON对象
  78. """
  79. prompt = self._optimize_ddl_prompt()
  80. payload = {
  81. "model": self.model_name,
  82. "messages": [
  83. {
  84. "role": "system",
  85. "content": "你是一个专业的SQL DDL语句解析专家,擅长从DDL语句中提取表结构信息并转换为结构化的JSON格式。",
  86. },
  87. {"role": "user", "content": f"{prompt}\n\n{sql_content}"},
  88. ],
  89. }
  90. try:
  91. result = self._make_llm_request(payload, "DDL解析")
  92. if not result:
  93. return {
  94. "code": 500,
  95. "message": f"API请求失败: 在{self.max_retries}次尝试后仍然失败",
  96. }
  97. if "choices" in result and len(result["choices"]) > 0:
  98. content = result["choices"][0]["message"]["content"]
  99. try:
  100. json_match = re.search(r"```json\s*([\s\S]*?)\s*```", content)
  101. if json_match:
  102. json_content = json_match.group(1)
  103. else:
  104. json_content = content
  105. parsed_result = json.loads(json_content)
  106. return parsed_result
  107. except json.JSONDecodeError as e:
  108. return {
  109. "code": 500,
  110. "message": f"无法解析返回的JSON: {str(e)}",
  111. "original_response": content,
  112. }
  113. return {
  114. "code": 500,
  115. "message": "无法获取有效响应",
  116. "original_response": result,
  117. }
  118. except Exception as e:
  119. logger.error(f"DDL解析异常: {str(e)}")
  120. return {"code": 500, "message": f"解析失败: {str(e)}"}
  121. def parse_db_conn_str(self, conn_str):
  122. """
  123. 解析数据库连接字符串
  124. 参数:
  125. conn_str: 要解析的数据库连接字符串
  126. 返回:
  127. 解析结果的JSON对象
  128. """
  129. prompt = self._optimize_connstr_parse_prompt()
  130. payload = {
  131. "model": self.model_name,
  132. "messages": [
  133. {
  134. "role": "system",
  135. "content": "你是一个专业的数据库连接字符串解析专家,擅长解析各种数据库的连接字符串并提取关键信息。",
  136. },
  137. {"role": "user", "content": f"{prompt}\n\n{conn_str}"},
  138. ],
  139. }
  140. try:
  141. result = self._make_llm_request(payload, "连接字符串解析")
  142. if not result:
  143. return {
  144. "code": 500,
  145. "message": f"API请求失败: 在{self.max_retries}次尝试后仍然失败",
  146. }
  147. if "choices" in result and len(result["choices"]) > 0:
  148. content = result["choices"][0]["message"]["content"]
  149. try:
  150. json_match = re.search(r"```json\s*([\s\S]*?)\s*```", content)
  151. if json_match:
  152. json_content = json_match.group(1)
  153. else:
  154. json_content = content
  155. parsed_result = json.loads(json_content)
  156. return parsed_result
  157. except json.JSONDecodeError as e:
  158. return {
  159. "code": 500,
  160. "message": f"无法解析返回的JSON: {str(e)}",
  161. "original_response": content,
  162. }
  163. return {
  164. "code": 500,
  165. "message": "无法获取有效响应",
  166. "original_response": result,
  167. }
  168. except Exception as e:
  169. logger.error(f"连接字符串解析异常: {str(e)}")
  170. return {"code": 500, "message": f"解析失败: {str(e)}"}
  171. def _optimize_ddl_prompt(self):
  172. """返回优化后的提示词模板"""
  173. return """
  174. 请解析以下DDL建表语句,并按照指定的JSON格式返回结果:
  175. 规则说明:
  176. 1. 从DDL语句中识别所有表,可能会有多个表。将所有表放在一个数组中返回。
  177. 2. 表的英文名称(name_en)使用原始大小写,不要转换为小写。
  178. 3. 表的中文名称(name_zh)提取规则:
  179. - 优先从COMMENT ON TABLE语句中提取
  180. - 如果没有注释,则name_zh为空字符串
  181. - 中文名称中不要出现标点符号、"主键"、"外键"、"索引"等字样
  182. 4. 对于每个表,提取所有字段信息到columns数组中,每个字段包含:
  183. - name_zh: 字段中文名称(从COMMENT ON COLUMN提取,如果没有注释则翻译英文名,如果是无意义缩写则为空)
  184. - name_en: 字段英文名称(保持原始大小写)
  185. - data_type: 数据类型(包含长度信息,如VARCHAR(22))
  186. - is_primary: 是否主键("是"或"否",从PRIMARY KEY约束判断)
  187. - comment: 注释内容(从COMMENT ON COLUMN提取完整注释,如果没有则为空字符串)
  188. - nullable: 是否可为空("是"或"否",从NOT NULL约束判断,默认为"是")
  189. 5. 中文字段名不要出现逗号、"主键"、"外键"、"索引"等字样。
  190. 6. 返回格式(使用数组支持多表):
  191. [
  192. {
  193. "table_info": {
  194. "name_zh": "科室对照表",
  195. "name_en": "TB_JC_KSDZB"
  196. },
  197. "columns": [
  198. {
  199. "name_zh": "医疗机构代码",
  200. "name_en": "YLJGDM",
  201. "data_type": "VARCHAR(22)",
  202. "is_primary": "是",
  203. "comment": "医疗机构代码,复合主键",
  204. "nullable": "否"
  205. },
  206. {
  207. "name_zh": "HIS科室代码",
  208. "name_en": "HISKSDM",
  209. "data_type": "CHAR(20)",
  210. "is_primary": "是",
  211. "comment": "HIS科室代码,主键、唯一",
  212. "nullable": "否"
  213. },
  214. {
  215. "name_zh": "HIS科室名称",
  216. "name_en": "HISKSMC",
  217. "data_type": "CHAR(20)",
  218. "is_primary": "否",
  219. "comment": "HIS科室名称",
  220. "nullable": "否"
  221. }
  222. ]
  223. }
  224. ]
  225. 注意:
  226. - 如果只有一个表,也要返回数组格式:[{table_info: {...}, columns: [...]}]
  227. - 如果有多个表,数组中包含多个元素:[{表1}, {表2}, {表3}]
  228. 请仅返回JSON格式结果,不要包含任何其他解释文字。
  229. """
  230. def _optimize_ddl_source_prompt(self):
  231. """返回优化后的提示词模板"""
  232. return """
  233. 请解析以下DDL建表语句,并按照指定的JSON格式返回结果:
  234. 规则说明:
  235. 1. 从DDL语句中识别所有表名,并在data对象中为每个表创建条目,表名请使用小写,可能会有多个表。
  236. 2. 对于每个表,提取所有字段信息,包括名称、数据类型和注释。
  237. - 中文表名中不要出现标点符号
  238. 3. 字段中文名称(name_zh)的确定规则:
  239. - 如有COMMENT注释,直接使用注释内容
  240. - 如无注释但字段名有明确含义,将英文名翻译为中文
  241. - 如字段名是无意义的拼音缩写,则name_zh为空字符串
  242. - 字段名中不要出现逗号,以及"主键"、"外键"、"索引"等字样
  243. 4. 所有的表的定义信息,请放在tables对象中, tables对象的key为表名,value为表的定义信息。这里可能会有多个表,请一一识别。
  244. 5. data_source对象,请放在data_source标签中,它与tables对象同级。
  245. 6. 数据库连接串处理:
  246. - 将连接串识别后并拆解为:主机名/IP地址、端口、数据库名称、用户名、密码。
  247. - 根据连接串格式识别数据库类型,数据库类型请使用小写,参考例子,如 mysql/postgresql/sqlserver/oracle/db2/sybase
  248. - data_source.name_en格式为: "{数据库名称}_{hostname或ip地址}_{端口}_{数据库用户名}",如某个元素无法识别,则跳过不添加.
  249. - data_source.name_zh留空.
  250. - 无法确定数据库类型时,type设为"unknown"
  251. - 如果从ddl中没有识别到数据库连接串,则json不返回"data_source"标签
  252. - 除了database,password,username,name_en,host,port,type,name_zh 之外,连接串的其它字段放在param属性中。
  253. 7. 参考格式如下:
  254. {
  255. "tables": {
  256. "users": { //表名
  257. "name_zh": "用户表", //表的中文名,来自于COMMENT注释或LLM翻译,如果无法确定,则name_zh为空字符串
  258. "schema": "public",
  259. "meta": [{
  260. "name_en": "id",
  261. "data_type": "integer",
  262. "name_zh": "用户ID"
  263. },
  264. {
  265. "name_en": "username",
  266. "data_type": "varchar",
  267. "name_zh": "用户名"
  268. }
  269. ]
  270. }
  271. },
  272. "data_source": [{
  273. "name_en": "mydatabase_10.52.31.104_5432_myuser", //{数据库名称}_{hostname或ip地址}_{端口}_{数据库用户名}
  274. "name_zh": "", //如果没有注释,这里留空
  275. "type": "postgresql",
  276. "host": "10.52.31.104",
  277. "port": 5432,
  278. "database": "mydatabase",
  279. "username": "myuser",
  280. "password": "mypassword",
  281. "param": "useUnicode=true&characterEncoding=utf8&serverTimezone=UTC"
  282. }]
  283. }
  284. 请仅返回JSON格式结果,不要包含任何其他解释文字。
  285. """
  286. def _optimize_connstr_parse_prompt(self):
  287. """返回优化后的连接字符串解析提示词模板"""
  288. return """
  289. 请解析以下数据库连接字符串,并按照指定的JSON格式返回结果:
  290. 规则说明:
  291. 1. 将连接串识别后并拆解为:主机名/IP地址、端口、数据库名称、用户名、密码。
  292. 2. 根据连接串格式识别数据库类型,数据库类型请使用小写,如 mysql/postgresql/sqlserver/oracle/db2/sybase
  293. 3. data_source.name_en格式为: "{数据库名称}_{hostname或ip地址}_{端口}_{数据库用户名}",如某个元素无法识别,则跳过不添加
  294. 4. data_source.name_zh留空
  295. 5. 无法确定数据库类型时,type设为"unknown"
  296. 6. 除了database,password,username,name_en,host,port,type,name_zh 之外,连接串的其它字段放在param属性中
  297. 返回格式示例:
  298. {
  299. "data_source": {
  300. "name_en": "mydatabase_10.52.31.104_5432_myuser",
  301. "name_zh": "",
  302. "type": "postgresql",
  303. "host": "10.52.31.104",
  304. "port": 5432,
  305. "database": "mydatabase",
  306. "username": "myuser",
  307. "password": "mypassword",
  308. "param": "useUnicode=true&characterEncoding=utf8&serverTimezone=UTC"
  309. }
  310. }
  311. 请仅返回JSON格式结果,不要包含任何其他解释文字。
  312. """
  313. def _optimize_connstr_valid_prompt(self):
  314. """返回优化后的连接字符串验证提示词模板"""
  315. return """
  316. 请验证以下数据库连接信息是否符合规则:
  317. 规则说明:
  318. 1. 必填字段检查:
  319. - database: 数据库名称,不能为空,符合数据库名称的命名规范。
  320. - name_en: 格式必须为 "{数据库名称}_{hostname或ip地址}_{端口}_{数据库用户名}"
  321. - host: 主机名或IP地址,不能为空
  322. - port: 端口号,必须为数字
  323. - type: 数据库类型,必须为以下之一:mysql/postgresql/sqlserver/oracle/db2/sybase
  324. - username: 用户名,不能为空,名称中间不能有空格。
  325. 2. 字段格式检查:
  326. - en_name中的各个部分必须与对应的字段值匹配
  327. - port必须是有效的端口号(1-65535)
  328. - type必须是小写的数据库类型名称
  329. - param中的参数格式必须正确(key=value格式)
  330. 3. 可选字段:
  331. - password: 密码(可选)
  332. - name: 中文名称(可选)
  333. - desc: 描述(可选)
  334. 请检查提供的连接信息是否符合以上规则,如果符合则返回"success",否则返回"failure"。
  335. 请仅返回"success"或"failure",不要包含任何其他解释文字。
  336. """
  337. def valid_db_conn_str(self, conn_str):
  338. """
  339. 验证数据库连接字符串是否符合规则
  340. 参数:
  341. conn_str: 要验证的数据库连接信息(JSON格式)
  342. 返回:
  343. "success" 或 "failure"
  344. """
  345. prompt = self._optimize_connstr_valid_prompt()
  346. payload = {
  347. "model": self.model_name,
  348. "messages": [
  349. {
  350. "role": "system",
  351. "content": "你是一个专业的数据库连接信息验证专家,擅长验证数据库连接信息的完整性和正确性。",
  352. },
  353. {
  354. "role": "user",
  355. "content": f"{prompt}\n\n{json.dumps(conn_str, ensure_ascii=False)}",
  356. },
  357. ],
  358. }
  359. try:
  360. result = self._make_llm_request(payload, "连接字符串验证")
  361. if not result:
  362. logger.error(
  363. f"连接字符串验证失败: 在{self.max_retries}次尝试后仍然失败"
  364. )
  365. return "failure"
  366. if "choices" in result and len(result["choices"]) > 0:
  367. content = result["choices"][0]["message"]["content"].strip().lower()
  368. return "success" if content == "success" else "failure"
  369. return "failure"
  370. except Exception as e:
  371. logger.error(f"LLM 验证数据库连接字符串失败: {str(e)}")
  372. return "failure"