test_data_lineage.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. """
  2. 数据血缘可视化功能测试
  3. 测试 DataFlow 的 INPUT/OUTPUT 关系创建以及血缘追溯功能
  4. """
  5. from __future__ import annotations
  6. import json
  7. from typing import Any
  8. from unittest.mock import MagicMock, patch
  9. import pytest
  10. class TestHandleScriptRelationships:
  11. """测试 DataFlowService._handle_script_relationships 方法"""
  12. @patch("app.core.data_flow.dataflows.connect_graph")
  13. @patch("app.core.data_flow.dataflows.get_formatted_time")
  14. def test_creates_input_output_relationships(
  15. self,
  16. mock_get_time: MagicMock,
  17. mock_connect_graph: MagicMock,
  18. ) -> None:
  19. """测试正确创建 INPUT 和 OUTPUT 关系"""
  20. from app.core.data_flow.dataflows import DataFlowService
  21. # Mock 时间
  22. mock_get_time.return_value = "2024-01-01 00:00:00"
  23. # Mock Neo4j session
  24. mock_session = MagicMock()
  25. mock_connect_graph.return_value.session.return_value.__enter__ = MagicMock(
  26. return_value=mock_session
  27. )
  28. mock_connect_graph.return_value.session.return_value.__exit__ = MagicMock(
  29. return_value=False
  30. )
  31. # Mock DataFlow 查询结果
  32. mock_session.run.return_value.single.side_effect = [
  33. {"dataflow_id": 100}, # DataFlow ID
  34. {"source_id": 200}, # Source BD ID
  35. {"target_id": 300}, # Target BD ID
  36. {"r": {}}, # INPUT relationship
  37. {"r": {}}, # OUTPUT relationship
  38. ]
  39. data = {
  40. "source_table": "BusinessDomain:user_info",
  41. "target_table": "BusinessDomain:user_profile",
  42. "script_type": "sql",
  43. "status": "active",
  44. "update_mode": "append",
  45. }
  46. # 调用方法
  47. DataFlowService._handle_script_relationships(
  48. data=data,
  49. dataflow_name="用户数据加工",
  50. name_en="user_data_process",
  51. )
  52. # 验证调用次数 (至少调用了 5 次 run)
  53. assert mock_session.run.call_count >= 3
  54. @patch("app.core.data_flow.dataflows.connect_graph")
  55. @patch("app.core.data_flow.dataflows.get_formatted_time")
  56. def test_skips_when_source_or_target_empty(
  57. self,
  58. mock_get_time: MagicMock,
  59. mock_connect_graph: MagicMock,
  60. ) -> None:
  61. """测试当 source 或 target 为空时跳过关系创建"""
  62. from app.core.data_flow.dataflows import DataFlowService
  63. mock_session = MagicMock()
  64. mock_connect_graph.return_value.session.return_value.__enter__ = MagicMock(
  65. return_value=mock_session
  66. )
  67. data = {
  68. "source_table": "",
  69. "target_table": "BusinessDomain:user_profile",
  70. }
  71. # 应该不抛出异常,但也不创建关系
  72. DataFlowService._handle_script_relationships(
  73. data=data,
  74. dataflow_name="测试",
  75. name_en="test",
  76. )
  77. # 验证没有调用 session.run
  78. mock_session.run.assert_not_called()
  79. class TestGetDataLineageVisualization:
  80. """测试 DataProductService.get_data_lineage_visualization 方法"""
  81. @patch("app.core.data_service.data_product_service.DataProduct")
  82. @patch("app.core.data_service.data_product_service.neo4j_driver")
  83. def test_returns_lineage_graph(
  84. self,
  85. mock_neo4j_driver: MagicMock,
  86. mock_data_product: MagicMock,
  87. ) -> None:
  88. """测试正确返回血缘图谱数据"""
  89. from app.core.data_service.data_product_service import DataProductService
  90. # Mock 数据产品
  91. mock_product = MagicMock()
  92. mock_product.source_dataflow_id = 100
  93. mock_product.target_table = "user_profile"
  94. mock_data_product.query.get.return_value = mock_product
  95. # Mock Neo4j session
  96. mock_session = MagicMock()
  97. mock_neo4j_driver.get_session.return_value.__enter__ = MagicMock(
  98. return_value=mock_session
  99. )
  100. mock_neo4j_driver.get_session.return_value.__exit__ = MagicMock(
  101. return_value=False
  102. )
  103. # Mock 查询结果 - 找到起始 BD
  104. mock_session.run.return_value.single.return_value = {
  105. "bd_id": 300,
  106. "name_zh": "用户画像",
  107. }
  108. # Mock _trace_lineage_upstream
  109. with patch.object(
  110. DataProductService,
  111. "_trace_lineage_upstream",
  112. return_value=(
  113. [
  114. {
  115. "id": 300,
  116. "name_zh": "用户画像",
  117. "node_type": "BusinessDomain",
  118. "is_target": True,
  119. "matched_fields": [],
  120. }
  121. ],
  122. [],
  123. 0,
  124. ),
  125. ):
  126. result = DataProductService.get_data_lineage_visualization(
  127. product_id=1,
  128. sample_data={"用户ID": 123, "姓名": "张三"},
  129. )
  130. assert "nodes" in result
  131. assert "lines" in result
  132. assert "lineage_depth" in result
  133. assert len(result["nodes"]) == 1
  134. @patch("app.core.data_service.data_product_service.DataProduct")
  135. def test_raises_error_when_product_not_found(
  136. self,
  137. mock_data_product: MagicMock,
  138. ) -> None:
  139. """测试数据产品不存在时抛出异常"""
  140. from app.core.data_service.data_product_service import DataProductService
  141. mock_data_product.query.get.return_value = None
  142. with pytest.raises(ValueError, match="数据产品不存在"):
  143. DataProductService.get_data_lineage_visualization(
  144. product_id=999,
  145. sample_data={"test": "value"},
  146. )
  147. class TestMatchFieldsWithSample:
  148. """测试 DataProductService._match_fields_with_sample 方法"""
  149. def test_matches_fields_by_name_zh(self) -> None:
  150. """测试通过中文名匹配字段"""
  151. from app.core.data_service.data_product_service import DataProductService
  152. mock_session = MagicMock()
  153. # Mock DataMeta 查询结果
  154. mock_session.run.return_value.data.return_value = [
  155. {
  156. "name_zh": "用户ID",
  157. "name_en": "user_id",
  158. "data_type": "integer",
  159. "meta_id": 1001,
  160. },
  161. {
  162. "name_zh": "姓名",
  163. "name_en": "name",
  164. "data_type": "string",
  165. "meta_id": 1002,
  166. },
  167. {
  168. "name_zh": "年龄",
  169. "name_en": "age",
  170. "data_type": "integer",
  171. "meta_id": 1003,
  172. },
  173. ]
  174. sample_data = {"用户ID": 123, "姓名": "张三"}
  175. result = DataProductService._match_fields_with_sample(
  176. session=mock_session,
  177. bd_id=100,
  178. sample_data=sample_data,
  179. )
  180. # 应该匹配到 2 个字段
  181. assert len(result) == 2
  182. # 验证匹配结果
  183. matched_names = {field["field_name"] for field in result}
  184. assert "用户ID" in matched_names
  185. assert "姓名" in matched_names
  186. # 验证值
  187. for field in result:
  188. if field["field_name"] == "用户ID":
  189. assert field["value"] == 123
  190. elif field["field_name"] == "姓名":
  191. assert field["value"] == "张三"
  192. def test_returns_empty_when_no_match(self) -> None:
  193. """测试无匹配时返回空列表"""
  194. from app.core.data_service.data_product_service import DataProductService
  195. mock_session = MagicMock()
  196. mock_session.run.return_value.data.return_value = [
  197. {
  198. "name_zh": "订单号",
  199. "name_en": "order_id",
  200. "data_type": "string",
  201. "meta_id": 2001,
  202. },
  203. ]
  204. sample_data = {"用户ID": 123} # 不匹配
  205. result = DataProductService._match_fields_with_sample(
  206. session=mock_session,
  207. bd_id=100,
  208. sample_data=sample_data,
  209. )
  210. assert len(result) == 0
  211. class TestLineageVisualizationAPI:
  212. """测试血缘可视化 API 端点"""
  213. @pytest.fixture
  214. def app(self) -> Any:
  215. """创建测试应用"""
  216. from app import create_app
  217. app = create_app()
  218. app.config["TESTING"] = True
  219. return app
  220. @pytest.fixture
  221. def client(self, app: Any) -> Any:
  222. """创建测试客户端"""
  223. return app.test_client()
  224. def test_returns_400_when_no_data(self, client: Any) -> None:
  225. """测试无请求数据时返回 400"""
  226. response = client.post("/api/data-service/products/1/lineage-visualization")
  227. # 检查状态码或响应体
  228. data = json.loads(response.data)
  229. assert data.get("code") in [400, 500] # 可能是 400 或 500
  230. def test_returns_400_when_sample_data_invalid(self, client: Any) -> None:
  231. """测试 sample_data 格式无效时返回 400"""
  232. response = client.post(
  233. "/api/data-service/products/1/lineage-visualization",
  234. data=json.dumps({"sample_data": "not_a_dict"}),
  235. content_type="application/json",
  236. )
  237. data = json.loads(response.data)
  238. assert data.get("code") in [400, 500]