| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236 |
- """
- 测试指标公式检查功能
- """
- import unittest
- from unittest.mock import Mock, patch, MagicMock
- from app.core.data_metric.metric_interface import metric_check
class TestMetricCheck(unittest.TestCase):
    """Unit tests for ``metric_check`` with the Neo4j connection mocked out."""

    @staticmethod
    def _mock_graph_session(mock_connect_graph):
        """Make ``mock_connect_graph`` return a fake driver and return its session mock.

        ``driver.session()`` is wired as a context manager whose ``__enter__``
        yields the returned session mock, so each test only has to configure
        ``session.run``.
        """
        mock_driver = MagicMock()
        mock_session = MagicMock()
        mock_connect_graph.return_value = mock_driver
        mock_driver.session.return_value.__enter__ = Mock(return_value=mock_session)
        mock_driver.session.return_value.__exit__ = Mock(return_value=None)
        return mock_session

    @staticmethod
    def _query_result(record):
        """Wrap ``record`` in a fake ``session.run`` result whose ``single()`` returns it.

        Pass ``None`` to simulate a variable that is not found.
        """
        return Mock(single=Mock(return_value=record))

    @staticmethod
    def _node_record(name, en_name, create_time, node_id):
        """Build a fake record of the form ``{'n': {...}, 'node_id': node_id}``."""
        return {
            'n': {
                'name': name,
                'en_name': en_name,
                'createTime': create_time,
            },
            'node_id': node_id,
        }

    @patch('app.core.data_metric.metric_interface.connect_graph')
    def test_simple_formula(self, mock_connect_graph):
        """A simple addition formula whose two variables both exist."""
        mock_session = self._mock_graph_session(mock_connect_graph)
        mock_session.run.side_effect = [
            self._query_result(
                self._node_record('单价', 'unit_price', '2024-01-15 10:00:00', 101)),
            self._query_result(
                self._node_record('数量', 'quantity', '2024-01-15 10:01:00', 102)),
        ]

        result = metric_check("销售额 = 单价 + 数量")

        self.assertEqual(len(result), 2)
        self.assertTrue(any(item['variable'] == '单价' for item in result))
        self.assertTrue(any(item['variable'] == '数量' for item in result))
        self.assertTrue(all(item['findit'] == 1 for item in result))

    @patch('app.core.data_metric.metric_interface.connect_graph')
    def test_complex_formula(self, mock_connect_graph):
        """A formula mixing several operators, where one variable is missing."""
        mock_session = self._mock_graph_session(mock_connect_graph)
        # First lookup finds 销售收入; second lookup (成本) finds nothing.
        mock_session.run.side_effect = [
            self._query_result(
                self._node_record('销售收入', 'sales_revenue', '2024-01-15 10:00:00', 201)),
            self._query_result(None),
        ]

        result = metric_check("利润率 = (销售收入 - 成本) / 销售收入 * 100")

        self.assertEqual(len(result), 2)
        found_items = [item for item in result if item['findit'] == 1]
        not_found_items = [item for item in result if item['findit'] == 0]
        self.assertEqual(len(found_items), 1)
        self.assertEqual(len(not_found_items), 1)

    @patch('app.core.data_metric.metric_interface.connect_graph')
    def test_formula_without_equals(self, mock_connect_graph):
        """A formula with no '=' yields an empty list.

        ``connect_graph`` is patched so this test never reaches a real
        database, whatever code path ``metric_check`` takes.
        """
        self._mock_graph_session(mock_connect_graph)

        result = metric_check("销售额")

        self.assertEqual(result, [])

    @patch('app.core.data_metric.metric_interface.connect_graph')
    def test_formula_with_numbers_only(self, mock_connect_graph):
        """A right-hand side made only of numbers yields an empty list."""
        self._mock_graph_session(mock_connect_graph)

        result = metric_check("结果 = 100 + 200 * 3")

        # No Chinese variable names, so nothing should be reported.
        self.assertEqual(result, [])

    @patch('app.core.data_metric.metric_interface.connect_graph')
    def test_formula_with_chinese_brackets(self, mock_connect_graph):
        """A formula using fullwidth (Chinese) brackets, with a repeated variable."""
        mock_session = self._mock_graph_session(mock_connect_graph)
        record = self._node_record('收入', 'revenue', '2024-01-15 10:00:00', 301)
        mock_session.run.side_effect = [
            self._query_result(record),
            self._query_result(record),
        ]

        # Fullwidth brackets （ ） rather than ASCII ( ): the ASCII case is
        # already covered by test_complex_formula.
        result = metric_check("总额 = （收入 + 收入）")

        # 收入 appears twice but should be reported once.
        self.assertEqual(len(result), 1)
        self.assertEqual(result[0]['variable'], '收入')
        self.assertEqual(result[0]['findit'], 1)

    @patch('app.core.data_metric.metric_interface.connect_graph')
    def test_database_connection_failure(self, mock_connect_graph):
        """When ``connect_graph`` returns no driver, the result is an empty list."""
        mock_connect_graph.return_value = None

        result = metric_check("销售额 = 单价 * 数量")

        self.assertEqual(result, [])
class TestMetricCheckAPI(unittest.TestCase):
    """Exercise the ``/api/data/metric/check`` endpoint through the Flask test client."""

    def setUp(self):
        """Create a 'testing' app, its test client, and push an app context."""
        from app import create_app
        application = create_app('testing')
        self.app = application
        self.client = application.test_client()
        self.ctx = application.app_context()
        self.ctx.push()

    def tearDown(self):
        """Pop the application context pushed in ``setUp``."""
        self.ctx.pop()

    def _post_check(self, payload):
        """POST ``payload`` as JSON to the metric-check endpoint and return the response."""
        return self.client.post('/api/data/metric/check', json=payload)

    def _json_from_ok(self, response):
        """Assert the HTTP status is 200, then return the decoded JSON body."""
        self.assertEqual(response.status_code, 200)
        return response.get_json()

    @patch('app.api.data_metric.routes.metric_check')
    def test_api_success(self, mock_metric_check):
        """A valid formula returns code 200, message 'success' and a list payload."""
        canned = [
            {
                "variable": "单价",
                "name_zh": "单价",
                "name_en": "unit_price",
                "id": 101,
                "create_time": "2024-01-15 10:00:00",
                "findit": 1
            }
        ]
        mock_metric_check.return_value = canned

        body = self._json_from_ok(
            self._post_check({'formula': '销售额 = 单价 * 数量'})
        )

        self.assertEqual(body['code'], 200)
        self.assertEqual(body['message'], 'success')
        self.assertIsInstance(body['data'], list)

    def test_api_empty_formula(self):
        """An empty formula string produces a body containing an 'error' key."""
        body = self._json_from_ok(self._post_check({'formula': ''}))
        self.assertIn('error', body)

    def test_api_missing_formula(self):
        """A request without a 'formula' field produces a body containing an 'error' key."""
        body = self._json_from_ok(self._post_check({}))
        self.assertIn('error', body)
if __name__ == '__main__':
    # Running this file directly discovers and runs every TestCase defined here.
    unittest.main(module='__main__')
|