ソースを参照

修复数据可视化不能展示下级BD节点数据的问题。

maxiaolong 2 ヶ月 前
コミット
7543f745a8
2 ファイル変更 455 行追加 18 行削除
  1. 114 18
      app/core/data_service/data_product_service.py
  2. 341 0
      test_lineage_visualization.py

+ 114 - 18
app/core/data_service/data_product_service.py

@@ -926,6 +926,10 @@ class DataProductService:
             从目标节点的字段中提取有"键值"标签的字段及其对应的值
             同时考虑 ALIAS 别名关系,获取主元数据和所有别名的名称
 
+            改进:除了精确匹配元数据名称外,还会:
+            1. 直接将 sample_data 中的所有键值对加入(供上游节点匹配使用)
+            2. 通过别名关系扩展键值映射
+
             Args:
                 fields: 目标节点的字段列表
 
@@ -934,6 +938,13 @@ class DataProductService:
                 包含主元数据和所有别名元数据的名称,都映射到同一个值
             """
             key_values: dict[str, Any] = {}
+
+            # 首先,将 sample_data 中的所有键值对加入(用于上游节点匹配)
+            for key, value in sample_data.items():
+                if value is not None:
+                    key_values[key] = value
+
+            # 然后,处理有"键值"标签的字段,扩展别名映射
             for field in fields:
                 tags = field.get("tags", [])
                 # 检查该字段是否有"键值"标签
@@ -946,12 +957,37 @@ class DataProductService:
                     meta_id = field.get("meta_id")
 
                     # 从 sample_data 中获取键值字段的值
+                    # 支持多种方式匹配:精确匹配、包含匹配
                     key_value = None
+
+                    # 方式1:精确匹配元数据名称
                     if name_zh and name_zh in sample_data:
                         key_value = sample_data[name_zh]
                     elif name_en and name_en in sample_data:
                         key_value = sample_data[name_en]
 
+                    # 方式2:如果元数据名称不匹配,尝试模糊匹配
+                    # 例如 "仓库名称_统计2" 匹配 sample_data 中的 "warehouse_name"
+                    if key_value is None:
+                        for sample_key, sample_val in sample_data.items():
+                            # 检查是否有相似的字段名(去除后缀如 _统计、_stat 等)
+                            base_name_zh = name_zh.split("_")[0] if name_zh else ""
+                            base_name_en = name_en.split("_")[0] if name_en else ""
+                            sample_key_base = sample_key.split("_")[0]
+
+                            if (
+                                (base_name_zh and base_name_zh in sample_key)
+                                or (base_name_en and base_name_en in sample_key)
+                                or (sample_key_base and sample_key_base in name_en)
+                            ):
+                                key_value = sample_val
+                                logger.debug(
+                                    f"键值字段模糊匹配: "
+                                    f"meta_field='{name_zh or name_en}' -> "
+                                    f"sample_key='{sample_key}'"
+                                )
+                                break
+
                     if key_value is not None:
                         # 添加当前字段的名称映射
                         if name_zh:
@@ -977,6 +1013,10 @@ class DataProductService:
                                 f"all_names={[a.get('name_zh') or a.get('name_en') for a in alias_names]}"
                             )
 
+            logger.info(
+                f"提取的键值字段: keys={list(key_values.keys())}, "
+                f"values={list(key_values.values())}"
+            )
             return key_values
 
         def query_matched_data_by_keys(
@@ -988,6 +1028,8 @@ class DataProductService:
             """
             根据键值从 BusinessDomain 对应的数据表中检索匹配数据
 
+            改进:支持更灵活的字段名匹配,优先使用有"键值"标签的字段
+
             Args:
                 bd_id: BusinessDomain 节点 ID
                 bd_name_en: BusinessDomain 英文名(对应表名)
@@ -998,6 +1040,11 @@ class DataProductService:
                 匹配的数据列表,格式为 [{field_name: value, ...}, ...]
             """
             if not key_values or not bd_name_en:
+                logger.debug(
+                    f"跳过数据检索: bd_id={bd_id}, "
+                    f"key_values_empty={not key_values}, "
+                    f"bd_name_en_empty={not bd_name_en}"
+                )
                 return []
 
             try:
@@ -1008,11 +1055,11 @@ class DataProductService:
                 RETURN ds.schema as schema
                 """
                 ds_result = session.run(ds_query, {"bd_id": bd_id}).single()
-                schema = ds_result["schema"] if ds_result else "public"
+                schema = ds_result["schema"] if ds_result else "dags"
 
                 table_name = bd_name_en
 
-                # 检查表是否存在
+                # 检查表是否存在(先检查原 schema,再检查 dags schema)
                 check_sql = text(
                     """
                     SELECT EXISTS (
@@ -1026,6 +1073,14 @@ class DataProductService:
                     check_sql, {"schema": schema, "table": table_name}
                 ).scalar()
 
+                # 如果原 schema 不存在,尝试 dags schema
+                if not exists and schema != "dags":
+                    exists = db.session.execute(
+                        check_sql, {"schema": "dags", "table": table_name}
+                    ).scalar()
+                    if exists:
+                        schema = "dags"
+
                 if not exists:
                     logger.debug(f"表 {schema}.{table_name} 不存在,跳过数据检索")
                     return []
@@ -1043,38 +1098,81 @@ class DataProductService:
                 )
                 actual_columns = {row[0] for row in columns_result}
 
+                logger.debug(
+                    f"表 {schema}.{table_name} 的列: {actual_columns}, "
+                    f"可用键值: {list(key_values.keys())}"
+                )
+
                 # 构建 WHERE 条件:使用键值字段进行匹配
-                # 只使用表中实际存在的列
+                # 优先使用有"键值"标签的字段,其次尝试模糊匹配
                 where_conditions = []
                 params: dict[str, Any] = {}
 
+                # 首先,处理有"键值"标签的字段
                 for field in fields:
+                    tags = field.get("tags", [])
+                    is_key_field = any(
+                        tag.get("name_zh") == "键值" for tag in tags if tag.get("id")
+                    )
+                    if not is_key_field:
+                        continue
+
                     name_en = field.get("name_en", "")
                     name_zh = field.get("name_zh", "")
 
-                    # 检查该字段是否是键值字段(在 key_values 中有值)
-                    key_value = None
+                    # 确定表中的实际列名
                     field_name_in_table = None
-
                     if name_en and name_en in actual_columns:
                         field_name_in_table = name_en
-                        if name_en in key_values:
-                            key_value = key_values[name_en]
-                        elif name_zh in key_values:
-                            key_value = key_values[name_zh]
                     elif name_zh and name_zh in actual_columns:
                         field_name_in_table = name_zh
-                        if name_zh in key_values:
-                            key_value = key_values[name_zh]
-                        elif name_en in key_values:
-                            key_value = key_values[name_en]
 
-                    if field_name_in_table and key_value is not None:
+                    if not field_name_in_table:
+                        continue
+
+                    # 尝试从 key_values 中获取匹配的值
+                    key_value = None
+
+                    # 方式1:精确匹配
+                    if name_en in key_values:
+                        key_value = key_values[name_en]
+                    elif name_zh in key_values:
+                        key_value = key_values[name_zh]
+
+                    # 方式2:模糊匹配(例如 warehouse 匹配 warehouse_name)
+                    if key_value is None:
+                        for kv_key, kv_val in key_values.items():
+                            # 检查键值名称是否包含字段名,或字段名包含键值名称
+                            if (
+                                (name_en and name_en in kv_key)
+                                or (name_en and kv_key in name_en)
+                                or (name_zh and name_zh in kv_key)
+                                or (name_zh and kv_key in name_zh)
+                            ):
+                                key_value = kv_val
+                                logger.debug(
+                                    f"键值模糊匹配成功: "
+                                    f"field='{name_en or name_zh}' -> "
+                                    f"key='{kv_key}', value='{kv_val}'"
+                                )
+                                break
+
+                    if key_value is not None:
                         param_name = f"key_{len(where_conditions)}"
                         where_conditions.append(
                             f'"{field_name_in_table}" = :{param_name}'
                         )
                         params[param_name] = key_value
+                        logger.debug(f"添加键值条件: {field_name_in_table}={key_value}")
+
+                # 如果没有通过键值字段匹配到,尝试直接用 key_values 中的键匹配表列
+                if not where_conditions:
+                    for kv_key, kv_val in key_values.items():
+                        if kv_key in actual_columns and kv_val is not None:
+                            param_name = f"key_{len(where_conditions)}"
+                            where_conditions.append(f'"{kv_key}" = :{param_name}')
+                            params[param_name] = kv_val
+                            logger.debug(f"直接列名匹配: {kv_key}={kv_val}")
 
                 if not where_conditions:
                     logger.debug(
@@ -1968,9 +2066,7 @@ class DataOrderService:
                         }
 
                     # 合并 common_fields
-                    pair_dict[pair_key]["common_fields"].extend(
-                        record["common_fields"]
-                    )
+                    pair_dict[pair_key]["common_fields"].extend(record["common_fields"])
 
                 connection_pairs = list(pair_dict.values())
 

+ 341 - 0
test_lineage_visualization.py

@@ -0,0 +1,341 @@
+"""
+测试数据加工可视化 API
+
+用于测试 /api/dataservice/products/23/lineage-visualization 接口
+通过图谱和数据库表查询 warehouse_inventory_summary 和 test_product_inventory 表中的数据
+"""
+
+import json
+import os
+import sys
+from datetime import datetime
+
+# Switch to the production environment in order to connect to the remote database
+os.environ["FLASK_ENV"] = "production"
+
+# Put the current directory at the front of the module search path
+sys.path.insert(0, ".")
+
+from sqlalchemy import text
+
+from app import create_app, db
+from app.models.data_product import DataProduct
+
+
+def get_product_info(app, product_id: int):
+    """Print a summary of one data product and return it as a dict.
+
+    Args:
+        app: Application object; its ``app_context()`` wraps the lookup.
+        product_id: Identifier passed to ``DataProduct.query.get``.
+
+    Returns:
+        ``product.to_dict()`` when the product is found, otherwise ``None``.
+    """
+    with app.app_context():
+        product = DataProduct.query.get(product_id)
+
+        # Guard clause: report the miss and leave early.
+        if not product:
+            print(f"数据产品 ID={product_id} 不存在")
+            return None
+
+        # Banner followed by one line per attribute of interest.
+        print(f"\n{'=' * 60}")
+        print(f"数据产品信息 (ID: {product_id})")
+        print(f"{'=' * 60}")
+        print(f"  产品名称: {product.product_name}")
+        print(f"  英文名称: {product.product_name_en}")
+        print(f"  目标表名: {product.target_table}")
+        print(f"  目标Schema: {product.target_schema}")
+        print(f"  关联DataFlow ID: {product.source_dataflow_id}")
+        print(f"  记录数: {product.record_count}")
+        print(f"  列数: {product.column_count}")
+        return product.to_dict()
+
+
+def get_sample_data_from_table(app, schema: str, table_name: str, limit: int = 1):
+    """Print a table's column list and return its first row as a dict.
+
+    Args:
+        app: Application object whose ``app_context()`` wraps the queries.
+        schema: Schema name matched against ``information_schema``.
+        table_name: Table name matched against ``information_schema``.
+        limit: Bound to the ``LIMIT`` clause of the sample query; only the
+            first fetched row is returned regardless of this value.
+
+    Returns:
+        A ``{column_name: value}`` dict for the first row, or ``None`` when
+        the table is missing, holds no rows, or any query raises.
+    """
+    with app.app_context():
+        try:
+            # Both catalog lookups use the same bind parameters.
+            bind_params = {"schema": schema, "table": table_name}
+
+            # Check that the table exists before touching it.
+            exists_stmt = text("""
+                SELECT EXISTS (
+                    SELECT FROM information_schema.tables
+                    WHERE table_schema = :schema
+                    AND table_name = :table
+                )
+            """)
+            if not db.session.execute(exists_stmt, bind_params).scalar():
+                print(f"表 {schema}.{table_name} 不存在")
+                return None
+
+            # Column names and types in declaration order.
+            columns_stmt = text("""
+                SELECT column_name, data_type 
+                FROM information_schema.columns
+                WHERE table_schema = :schema AND table_name = :table
+                ORDER BY ordinal_position
+            """)
+            column_rows = db.session.execute(columns_stmt, bind_params).fetchall()
+
+            print(f"\n{'=' * 60}")
+            print(f"表结构: {schema}.{table_name}")
+            print(f"{'=' * 60}")
+            for col_name, col_type in column_rows:
+                print(f"  {col_name}: {col_type}")
+
+            # Fetch the sample rows.
+            sample_stmt = text(f'SELECT * FROM "{schema}"."{table_name}" LIMIT :limit')
+            cursor = db.session.execute(sample_stmt, {"limit": limit})
+            fetched = cursor.fetchall()
+            headers = list(cursor.keys())
+
+            if not fetched:
+                print(f"表 {schema}.{table_name} 中没有数据")
+                return None
+
+            # Only the first row is reported and returned.
+            first_row = dict(zip(headers, fetched[0]))
+            print("\n样例数据:")
+            print(f"{'-' * 60}")
+            for key, value in first_row.items():
+                print(f"  {key}: {value}")
+            return first_row
+
+        except Exception as e:
+            print(f"查询表 {schema}.{table_name} 失败: {str(e)}")
+            return None
+
+
+def query_table_data(app, schema: str, table_name: str):
+    """Print the row count and up to five rows of a table.
+
+    Args:
+        app: Application object whose ``app_context()`` wraps the queries.
+        schema: Schema name matched against ``information_schema``.
+        table_name: Table name matched against ``information_schema``.
+
+    Returns:
+        A list of ``{column_name: value}`` dicts (at most five, possibly
+        empty), or ``None`` when the table is missing or a query raises.
+    """
+    with app.app_context():
+        try:
+            exists_stmt = text("""
+                SELECT EXISTS (
+                    SELECT FROM information_schema.tables
+                    WHERE table_schema = :schema
+                    AND table_name = :table
+                )
+            """)
+            found = db.session.execute(
+                exists_stmt, {"schema": schema, "table": table_name}
+            ).scalar()
+            if not found:
+                print(f"表 {schema}.{table_name} 不存在")
+                return None
+
+            # Total number of rows in the table.
+            count = db.session.execute(
+                text(f'SELECT COUNT(*) FROM "{schema}"."{table_name}"')
+            ).scalar()
+
+            # A handful of rows to display.
+            cursor = db.session.execute(
+                text(f'SELECT * FROM "{schema}"."{table_name}" LIMIT 5')
+            )
+            fetched = cursor.fetchall()
+            column_names = list(cursor.keys())
+
+            print(f"\n{'=' * 60}")
+            print(f"表数据: {schema}.{table_name} (总计 {count} 条记录)")
+            print(f"{'=' * 60}")
+
+            if not fetched:
+                print("表中没有数据")
+                return []
+
+            # Column header line, then one numbered line per row.
+            print(f"列名: {column_names}")
+            print(f"{'-' * 60}")
+            for i, row in enumerate(fetched):
+                print(f"[{i + 1}] {dict(zip(column_names, row))}")
+            return [dict(zip(column_names, row)) for row in fetched]
+
+        except Exception as e:
+            print(f"查询表 {schema}.{table_name} 失败: {str(e)}")
+            return None
+
+
+def test_lineage_visualization_api(app, product_id: int, sample_data: dict):
+    """Call the lineage-visualization service directly and print its result.
+
+    Args:
+        app: Application object whose ``app_context()`` wraps the call.
+        product_id: Passed to ``get_data_lineage_visualization``.
+        sample_data: Passed to ``get_data_lineage_visualization`` and echoed
+            as the request payload.
+
+    Returns:
+        The service result, or ``None`` when the call or the printing raises
+        (the traceback is printed in that case).
+    """
+    import traceback
+
+    from app.core.data_service.data_product_service import DataProductService
+
+    with app.app_context():
+        print(f"\n{'=' * 60}")
+        print(f"测试血缘可视化: product_id={product_id}")
+        print(f"{'=' * 60}")
+
+        print("请求体 sample_data:")
+        request_dump = json.dumps(
+            sample_data, ensure_ascii=False, indent=2, default=str
+        )
+        print(request_dump)
+
+        try:
+            result = DataProductService.get_data_lineage_visualization(
+                product_id=product_id,
+                sample_data=sample_data,
+            )
+
+            print("\n响应内容:")
+            response_dump = json.dumps(
+                result, ensure_ascii=False, indent=2, default=str
+            )
+            print(response_dump)
+
+            # Node details
+            if result.get("nodes"):
+                print(f"\n{'=' * 60}")
+                print(f"节点详情 (共 {len(result['nodes'])} 个)")
+                print(f"{'=' * 60}")
+                for node in result["nodes"]:
+                    print(f"\n  节点 ID: {node.get('id')}")
+                    print(f"    中文名: {node.get('name_zh')}")
+                    print(f"    英文名: {node.get('name_en')}")
+                    print(f"    类型: {node.get('node_type')}")
+                    print(f"    标签: {node.get('labels')}")
+                    print(f"    是目标节点: {node.get('is_target')}")
+                    print(f"    是源节点: {node.get('is_source')}")
+
+                    matched_fields = node.get("matched_fields")
+                    if matched_fields:
+                        print(f"    匹配字段数: {len(node['matched_fields'])}")
+                        for field in matched_fields:
+                            print(
+                                f"      - {field.get('name_zh')} ({field.get('name_en')}): {field.get('value')}"
+                            )
+
+                    matched_data = node.get("matched_data")
+                    if matched_data:
+                        print("    匹配数据:")
+                        # Only the first two records are shown.
+                        for data in matched_data[:2]:
+                            print(f"      {data}")
+
+            # Relationship details
+            if result.get("lines"):
+                print(f"\n{'=' * 60}")
+                print(f"关系详情 (共 {len(result['lines'])} 条)")
+                print(f"{'=' * 60}")
+                for line in result["lines"]:
+                    # Prefer the display text; fall back to the type.
+                    rel_type = line.get("text") or line.get("type", "UNKNOWN")
+                    print(f"  {line['from']} --[{rel_type}]--> {line['to']}")
+
+            return result
+
+        except Exception as e:
+            print(f"调用失败: {str(e)}")
+            traceback.print_exc()
+            return None
+
+
+def query_neo4j_business_domain(app, table_name: str):
+    """Look up BusinessDomain nodes in Neo4j whose name equals ``table_name``.
+
+    Exact matches are printed and returned. When there are none, a
+    ``CONTAINS`` search on the part of ``table_name`` before the first
+    underscore lists up to ten candidate nodes for information only.
+
+    Args:
+        app: Application object whose ``app_context()`` wraps the lookup.
+        table_name: Compared against ``bd.name_en`` and ``bd.name``.
+
+    Returns:
+        The list of exact-match records (dicts with ``bd_id``, ``name_zh``,
+        ``name_en``, ``labels``, ``description``), or ``None`` when there is
+        no exact match.
+    """
+    with app.app_context():
+        from app.services.neo4j_driver import neo4j_driver
+
+        print(f"\n{'=' * 60}")
+        print(f"查询 Neo4j BusinessDomain: {table_name}")
+        print(f"{'=' * 60}")
+
+        with neo4j_driver.get_session() as session:
+            # Find BusinessDomain nodes by exact table name.
+            query = """
+            MATCH (bd:BusinessDomain)
+            WHERE bd.name_en = $table_name OR bd.name = $table_name
+            RETURN id(bd) as bd_id, bd.name_zh as name_zh, bd.name_en as name_en,
+                   labels(bd) as labels, bd.description as description
+            """
+            result = session.run(query, {"table_name": table_name}).data()
+
+            if result:
+                for bd in result:
+                    print(f"  ID: {bd['bd_id']}")
+                    print(f"  中文名: {bd['name_zh']}")
+                    print(f"  英文名: {bd['name_en']}")
+                    print(f"  标签: {bd['labels']}")
+                    # The RETURN clause always yields a "description" key, so
+                    # a .get() default never applies; fall back on a falsy
+                    # value instead so a missing property prints "N/A".
+                    print(f"  描述: {bd.get('description') or 'N/A'}")
+                    print("-" * 40)
+                return result
+
+            print(f"未找到与 {table_name} 对应的 BusinessDomain 节点")
+
+            # Fall back to a substring search on the leading name segment.
+            fuzzy_query = """
+                MATCH (bd:BusinessDomain)
+                WHERE bd.name_en CONTAINS $keyword OR bd.name CONTAINS $keyword
+                    OR bd.name_zh CONTAINS $keyword
+                RETURN id(bd) as bd_id, bd.name_zh as name_zh, bd.name_en as name_en,
+                       labels(bd) as labels
+                LIMIT 10
+                """
+            fuzzy_result = session.run(
+                fuzzy_query, {"keyword": table_name.split("_")[0]}
+            ).data()
+
+            if fuzzy_result:
+                print("\n可能相关的 BusinessDomain 节点:")
+                for bd in fuzzy_result:
+                    print(
+                        f"  - {bd['name_zh']} ({bd['name_en']}) [ID: {bd['bd_id']}]"
+                    )
+
+            # Candidates are informational only; no exact match was found.
+            return None
+
+
+def main():
+    """Run the lineage-visualization checks end to end for product 23.
+
+    Steps: print the product summary, sample the product's target table,
+    dump two tables from the ``dags`` schema, look up the matching
+    BusinessDomain nodes in Neo4j, and finally call the visualization
+    service with the sampled row or, failing that, a hand-built payload.
+    """
+    print(f"\n{'#' * 60}")
+    print("# 数据加工可视化 API 测试脚本")
+    print(f"# 时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
+    print(f"{'#' * 60}")
+
+    # Create the application
+    app = create_app()
+
+    product_id = 23
+
+    # 1. Fetch the data product summary
+    product_info = get_product_info(app, product_id)
+
+    if not product_info:
+        print("无法获取数据产品信息,退出测试")
+        return
+
+    target_table = product_info.get("target_table")
+    # A .get() default only applies when the key is absent; use `or` so a
+    # key that is present with a null value also falls back to "public".
+    target_schema = product_info.get("target_schema") or "public"
+
+    # 2. Sample the target table
+    print(f"\n\n{'#' * 60}")
+    print("# 查询目标表数据")
+    print(f"{'#' * 60}")
+    if target_table:
+        sample_data = get_sample_data_from_table(app, target_schema, target_table)
+    else:
+        # Without a table name there is nothing to sample.
+        print("数据产品未配置目标表,跳过目标表查询")
+        sample_data = None
+
+    # 3. Dump the two tables of interest (in the dags schema)
+    print(f"\n\n{'#' * 60}")
+    print("# 查询 warehouse_inventory_summary 表 (dags schema)")
+    print(f"{'#' * 60}")
+    query_table_data(app, "dags", "warehouse_inventory_summary")
+
+    print(f"\n\n{'#' * 60}")
+    print("# 查询 test_product_inventory 表 (dags schema)")
+    print(f"{'#' * 60}")
+    query_table_data(app, "dags", "test_product_inventory")
+
+    # 4. Look up the related nodes in Neo4j
+    print(f"\n\n{'#' * 60}")
+    print("# 查询 Neo4j 图谱节点")
+    print(f"{'#' * 60}")
+
+    # The lookup splits the name on "_" when nothing matches exactly, so a
+    # null table name must not be passed through.
+    if target_table:
+        query_neo4j_business_domain(app, target_table)
+    query_neo4j_business_domain(app, "warehouse_inventory_summary")
+    query_neo4j_business_domain(app, "test_product_inventory")
+
+    # 5. Exercise the API
+    if sample_data:
+        print(f"\n\n{'#' * 60}")
+        print("# 调用血缘可视化 API")
+        print(f"{'#' * 60}")
+        test_lineage_visualization_api(app, product_id, sample_data)
+    else:
+        print("\n没有样例数据,跳过 API 测试")
+
+        # Fall back to a hand-built payload so the service is still called
+        print("\n尝试使用手动构造的测试数据...")
+        manual_sample_data = {
+            "产品名称": "测试产品",
+            "仓库名称": "测试仓库",
+            "库存数量": 100,
+        }
+        test_lineage_visualization_api(app, product_id, manual_sample_data)
+
+
+if __name__ == "__main__":
+    main()