# test_metric_check.py
"""
测试指标公式检查功能

Tests for the metric formula check feature (``metric_check``).
"""
import unittest
from unittest.mock import Mock, patch, MagicMock

from app.core.data_metric.metric_interface import metric_check
  7. class TestMetricCheck(unittest.TestCase):
  8. """测试 metric_check 函数"""
  9. @patch('app.core.data_metric.metric_interface.connect_graph')
  10. def test_simple_formula(self, mock_connect_graph):
  11. """测试简单的加法公式"""
  12. # Mock Neo4j session
  13. mock_driver = MagicMock()
  14. mock_session = MagicMock()
  15. mock_connect_graph.return_value = mock_driver
  16. mock_driver.session.return_value.__enter__ = Mock(return_value=mock_session)
  17. mock_driver.session.return_value.__exit__ = Mock(return_value=None)
  18. # Mock query results
  19. mock_record1 = {
  20. 'n': {
  21. 'name': '单价',
  22. 'en_name': 'unit_price',
  23. 'createTime': '2024-01-15 10:00:00'
  24. },
  25. 'node_id': 101
  26. }
  27. mock_record2 = {
  28. 'n': {
  29. 'name': '数量',
  30. 'en_name': 'quantity',
  31. 'createTime': '2024-01-15 10:01:00'
  32. },
  33. 'node_id': 102
  34. }
  35. mock_session.run.side_effect = [
  36. Mock(single=Mock(return_value=mock_record1)),
  37. Mock(single=Mock(return_value=mock_record2))
  38. ]
  39. # 测试
  40. formula = "销售额 = 单价 + 数量"
  41. result = metric_check(formula)
  42. # 验证
  43. self.assertEqual(len(result), 2)
  44. self.assertTrue(any(item['variable'] == '单价' for item in result))
  45. self.assertTrue(any(item['variable'] == '数量' for item in result))
  46. self.assertTrue(all(item['findit'] == 1 for item in result))
  47. @patch('app.core.data_metric.metric_interface.connect_graph')
  48. def test_complex_formula(self, mock_connect_graph):
  49. """测试复杂的公式(包含多种运算符)"""
  50. # Mock Neo4j session
  51. mock_driver = MagicMock()
  52. mock_session = MagicMock()
  53. mock_connect_graph.return_value = mock_driver
  54. mock_driver.session.return_value.__enter__ = Mock(return_value=mock_session)
  55. mock_driver.session.return_value.__exit__ = Mock(return_value=None)
  56. # Mock query results - 一个找到,一个未找到
  57. mock_record1 = {
  58. 'n': {
  59. 'name': '销售收入',
  60. 'en_name': 'sales_revenue',
  61. 'createTime': '2024-01-15 10:00:00'
  62. },
  63. 'node_id': 201
  64. }
  65. mock_session.run.side_effect = [
  66. Mock(single=Mock(return_value=mock_record1)),
  67. Mock(single=Mock(return_value=None)) # 成本未找到
  68. ]
  69. # 测试
  70. formula = "利润率 = (销售收入 - 成本) / 销售收入 * 100"
  71. result = metric_check(formula)
  72. # 验证
  73. self.assertEqual(len(result), 2)
  74. found_items = [item for item in result if item['findit'] == 1]
  75. not_found_items = [item for item in result if item['findit'] == 0]
  76. self.assertEqual(len(found_items), 1)
  77. self.assertEqual(len(not_found_items), 1)
  78. def test_formula_without_equals(self):
  79. """测试没有等号的公式"""
  80. formula = "销售额"
  81. result = metric_check(formula)
  82. # 验证返回空列表
  83. self.assertEqual(result, [])
  84. @patch('app.core.data_metric.metric_interface.connect_graph')
  85. def test_formula_with_numbers_only(self, mock_connect_graph):
  86. """测试只包含数字的公式"""
  87. # Mock Neo4j session
  88. mock_driver = MagicMock()
  89. mock_session = MagicMock()
  90. mock_connect_graph.return_value = mock_driver
  91. mock_driver.session.return_value.__enter__ = Mock(return_value=mock_session)
  92. mock_driver.session.return_value.__exit__ = Mock(return_value=None)
  93. # 测试
  94. formula = "结果 = 100 + 200 * 3"
  95. result = metric_check(formula)
  96. # 验证返回空列表(因为没有中文变量)
  97. self.assertEqual(result, [])
  98. @patch('app.core.data_metric.metric_interface.connect_graph')
  99. def test_formula_with_chinese_brackets(self, mock_connect_graph):
  100. """测试包含中文括号的公式"""
  101. # Mock Neo4j session
  102. mock_driver = MagicMock()
  103. mock_session = MagicMock()
  104. mock_connect_graph.return_value = mock_driver
  105. mock_driver.session.return_value.__enter__ = Mock(return_value=mock_session)
  106. mock_driver.session.return_value.__exit__ = Mock(return_value=None)
  107. # Mock query result
  108. mock_record = {
  109. 'n': {
  110. 'name': '收入',
  111. 'en_name': 'revenue',
  112. 'createTime': '2024-01-15 10:00:00'
  113. },
  114. 'node_id': 301
  115. }
  116. mock_session.run.side_effect = [
  117. Mock(single=Mock(return_value=mock_record)),
  118. Mock(single=Mock(return_value=mock_record))
  119. ]
  120. # 测试
  121. formula = "总额 = (收入 + 收入)"
  122. result = metric_check(formula)
  123. # 验证(收入会被去重)
  124. self.assertEqual(len(result), 1)
  125. self.assertEqual(result[0]['variable'], '收入')
  126. self.assertEqual(result[0]['findit'], 1)
  127. @patch('app.core.data_metric.metric_interface.connect_graph')
  128. def test_database_connection_failure(self, mock_connect_graph):
  129. """测试数据库连接失败"""
  130. # Mock connection failure
  131. mock_connect_graph.return_value = None
  132. # 测试
  133. formula = "销售额 = 单价 * 数量"
  134. result = metric_check(formula)
  135. # 验证返回空列表
  136. self.assertEqual(result, [])
  137. class TestMetricCheckAPI(unittest.TestCase):
  138. """测试 metric_check API 接口"""
  139. def setUp(self):
  140. """设置测试环境"""
  141. from app import create_app
  142. self.app = create_app('testing')
  143. self.client = self.app.test_client()
  144. self.ctx = self.app.app_context()
  145. self.ctx.push()
  146. def tearDown(self):
  147. """清理测试环境"""
  148. self.ctx.pop()
  149. @patch('app.api.data_metric.routes.metric_check')
  150. def test_api_success(self, mock_metric_check):
  151. """测试API成功调用"""
  152. # Mock返回值
  153. mock_metric_check.return_value = [
  154. {
  155. "variable": "单价",
  156. "name_zh": "单价",
  157. "name_en": "unit_price",
  158. "id": 101,
  159. "create_time": "2024-01-15 10:00:00",
  160. "findit": 1
  161. }
  162. ]
  163. # 发送请求
  164. response = self.client.post(
  165. '/api/data/metric/check',
  166. json={'formula': '销售额 = 单价 * 数量'}
  167. )
  168. # 验证响应
  169. self.assertEqual(response.status_code, 200)
  170. data = response.get_json()
  171. self.assertEqual(data['code'], 200)
  172. self.assertEqual(data['message'], 'success')
  173. self.assertTrue(isinstance(data['data'], list))
  174. def test_api_empty_formula(self):
  175. """测试空公式"""
  176. response = self.client.post(
  177. '/api/data/metric/check',
  178. json={'formula': ''}
  179. )
  180. # 验证响应
  181. self.assertEqual(response.status_code, 200)
  182. data = response.get_json()
  183. self.assertIn('error', data)
  184. def test_api_missing_formula(self):
  185. """测试缺少formula参数"""
  186. response = self.client.post(
  187. '/api/data/metric/check',
  188. json={}
  189. )
  190. # 验证响应
  191. self.assertEqual(response.status_code, 200)
  192. data = response.get_json()
  193. self.assertIn('error', data)
  194. if __name__ == '__main__':
  195. unittest.main()