""" 数据血缘可视化功能测试 测试 DataFlow 的 INPUT/OUTPUT 关系创建以及血缘追溯功能 """ from __future__ import annotations import json from typing import Any from unittest.mock import MagicMock, patch import pytest class TestHandleScriptRelationships: """测试 DataFlowService._handle_script_relationships 方法""" @patch("app.core.data_flow.dataflows.connect_graph") @patch("app.core.data_flow.dataflows.get_formatted_time") def test_creates_input_output_relationships( self, mock_get_time: MagicMock, mock_connect_graph: MagicMock, ) -> None: """测试正确创建 INPUT 和 OUTPUT 关系""" from app.core.data_flow.dataflows import DataFlowService # Mock 时间 mock_get_time.return_value = "2024-01-01 00:00:00" # Mock Neo4j session mock_session = MagicMock() mock_connect_graph.return_value.session.return_value.__enter__ = MagicMock( return_value=mock_session ) mock_connect_graph.return_value.session.return_value.__exit__ = MagicMock( return_value=False ) # Mock DataFlow 查询结果 mock_session.run.return_value.single.side_effect = [ {"dataflow_id": 100}, # DataFlow ID {"source_id": 200}, # Source BD ID {"target_id": 300}, # Target BD ID {"r": {}}, # INPUT relationship {"r": {}}, # OUTPUT relationship ] data = { "source_table": "BusinessDomain:user_info", "target_table": "BusinessDomain:user_profile", "script_type": "sql", "status": "active", "update_mode": "append", } # 调用方法 DataFlowService._handle_script_relationships( data=data, dataflow_name="用户数据加工", name_en="user_data_process", ) # 验证调用次数 (至少调用了 5 次 run) assert mock_session.run.call_count >= 3 @patch("app.core.data_flow.dataflows.connect_graph") @patch("app.core.data_flow.dataflows.get_formatted_time") def test_skips_when_source_or_target_empty( self, mock_get_time: MagicMock, mock_connect_graph: MagicMock, ) -> None: """测试当 source 或 target 为空时跳过关系创建""" from app.core.data_flow.dataflows import DataFlowService mock_session = MagicMock() mock_connect_graph.return_value.session.return_value.__enter__ = MagicMock( return_value=mock_session ) data = { "source_table": "", "target_table": "BusinessDomain:user_profile", } # 应该不抛出异常,但也不创建关系 DataFlowService._handle_script_relationships( data=data, dataflow_name="测试", name_en="test", ) # 验证没有调用 session.run mock_session.run.assert_not_called() class TestGetDataLineageVisualization: """测试 DataProductService.get_data_lineage_visualization 方法""" @patch("app.core.data_service.data_product_service.DataProduct") @patch("app.core.data_service.data_product_service.neo4j_driver") def test_returns_lineage_graph( self, mock_neo4j_driver: MagicMock, mock_data_product: MagicMock, ) -> None: """测试正确返回血缘图谱数据""" from app.core.data_service.data_product_service import DataProductService # Mock 数据产品 mock_product = MagicMock() mock_product.source_dataflow_id = 100 mock_product.target_table = "user_profile" mock_data_product.query.get.return_value = mock_product # Mock Neo4j session mock_session = MagicMock() mock_neo4j_driver.get_session.return_value.__enter__ = MagicMock( return_value=mock_session ) mock_neo4j_driver.get_session.return_value.__exit__ = MagicMock( return_value=False ) # Mock 查询结果 - 找到起始 BD mock_session.run.return_value.single.return_value = { "bd_id": 300, "name_zh": "用户画像", } # Mock _trace_lineage_upstream with patch.object( DataProductService, "_trace_lineage_upstream", return_value=( [ { "id": 300, "name_zh": "用户画像", "node_type": "BusinessDomain", "is_target": True, "matched_fields": [], } ], [], 0, ), ): result = DataProductService.get_data_lineage_visualization( product_id=1, sample_data={"用户ID": 123, "姓名": "张三"}, ) assert "nodes" in result assert "lines" in result assert "lineage_depth" in result assert len(result["nodes"]) == 1 @patch("app.core.data_service.data_product_service.DataProduct") def test_raises_error_when_product_not_found( self, mock_data_product: MagicMock, ) -> None: """测试数据产品不存在时抛出异常""" from app.core.data_service.data_product_service import DataProductService mock_data_product.query.get.return_value = None with pytest.raises(ValueError, match="数据产品不存在"): DataProductService.get_data_lineage_visualization( product_id=999, sample_data={"test": "value"}, ) class TestMatchFieldsWithSample: """测试 DataProductService._match_fields_with_sample 方法""" def test_matches_fields_by_name_zh(self) -> None: """测试通过中文名匹配字段""" from app.core.data_service.data_product_service import DataProductService mock_session = MagicMock() # Mock DataMeta 查询结果 mock_session.run.return_value.data.return_value = [ { "name_zh": "用户ID", "name_en": "user_id", "data_type": "integer", "meta_id": 1001, }, { "name_zh": "姓名", "name_en": "name", "data_type": "string", "meta_id": 1002, }, { "name_zh": "年龄", "name_en": "age", "data_type": "integer", "meta_id": 1003, }, ] sample_data = {"用户ID": 123, "姓名": "张三"} result = DataProductService._match_fields_with_sample( session=mock_session, bd_id=100, sample_data=sample_data, ) # 应该匹配到 2 个字段 assert len(result) == 2 # 验证匹配结果 matched_names = {field["field_name"] for field in result} assert "用户ID" in matched_names assert "姓名" in matched_names # 验证值 for field in result: if field["field_name"] == "用户ID": assert field["value"] == 123 elif field["field_name"] == "姓名": assert field["value"] == "张三" def test_returns_empty_when_no_match(self) -> None: """测试无匹配时返回空列表""" from app.core.data_service.data_product_service import DataProductService mock_session = MagicMock() mock_session.run.return_value.data.return_value = [ { "name_zh": "订单号", "name_en": "order_id", "data_type": "string", "meta_id": 2001, }, ] sample_data = {"用户ID": 123} # 不匹配 result = DataProductService._match_fields_with_sample( session=mock_session, bd_id=100, sample_data=sample_data, ) assert len(result) == 0 class TestLineageVisualizationAPI: """测试血缘可视化 API 端点""" @pytest.fixture def app(self) -> Any: """创建测试应用""" from app import create_app app = create_app() app.config["TESTING"] = True return app @pytest.fixture def client(self, app: Any) -> Any: """创建测试客户端""" return app.test_client() def test_returns_400_when_no_data(self, client: Any) -> None: """测试无请求数据时返回 400""" response = client.post("/api/data-service/products/1/lineage-visualization") # 检查状态码或响应体 data = json.loads(response.data) assert data.get("code") in [400, 500] # 可能是 400 或 500 def test_returns_400_when_sample_data_invalid(self, client: Any) -> None: """测试 sample_data 格式无效时返回 400""" response = client.post( "/api/data-service/products/1/lineage-visualization", data=json.dumps({"sample_data": "not_a_dict"}), content_type="application/json", ) data = json.loads(response.data) assert data.get("code") in [400, 500]