evaluator.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. # orm/evaluator.py
  2. # Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. import operator
  8. from .. import inspect
  9. from .. import util
  10. from ..sql import and_
  11. from ..sql import operators
  12. from ..sql.sqltypes import Integer
  13. from ..sql.sqltypes import Numeric
  14. class UnevaluatableError(Exception):
  15. pass
  16. class _NoObject(operators.ColumnOperators):
  17. def operate(self, *arg, **kw):
  18. return None
  19. def reverse_operate(self, *arg, **kw):
  20. return None
  21. _NO_OBJECT = _NoObject()
  22. _straight_ops = set(
  23. getattr(operators, op)
  24. for op in (
  25. "lt",
  26. "le",
  27. "ne",
  28. "gt",
  29. "ge",
  30. "eq",
  31. )
  32. )
  33. _math_only_straight_ops = set(
  34. getattr(operators, op)
  35. for op in (
  36. "add",
  37. "mul",
  38. "sub",
  39. "div",
  40. "mod",
  41. "truediv",
  42. )
  43. )
  44. _extended_ops = {
  45. operators.in_op: (lambda a, b: a in b if a is not _NO_OBJECT else None),
  46. operators.not_in_op: (
  47. lambda a, b: a not in b if a is not _NO_OBJECT else None
  48. ),
  49. }
  50. _notimplemented_ops = set(
  51. getattr(operators, op)
  52. for op in (
  53. "like_op",
  54. "not_like_op",
  55. "ilike_op",
  56. "not_ilike_op",
  57. "startswith_op",
  58. "between_op",
  59. "endswith_op",
  60. )
  61. )
  62. class EvaluatorCompiler(object):
  63. def __init__(self, target_cls=None):
  64. self.target_cls = target_cls
  65. def process(self, *clauses):
  66. if len(clauses) > 1:
  67. clause = and_(*clauses)
  68. elif clauses:
  69. clause = clauses[0]
  70. meth = getattr(self, "visit_%s" % clause.__visit_name__, None)
  71. if not meth:
  72. raise UnevaluatableError(
  73. "Cannot evaluate %s" % type(clause).__name__
  74. )
  75. return meth(clause)
  76. def visit_grouping(self, clause):
  77. return self.process(clause.element)
  78. def visit_null(self, clause):
  79. return lambda obj: None
  80. def visit_false(self, clause):
  81. return lambda obj: False
  82. def visit_true(self, clause):
  83. return lambda obj: True
  84. def visit_column(self, clause):
  85. if "parentmapper" in clause._annotations:
  86. parentmapper = clause._annotations["parentmapper"]
  87. if self.target_cls and not issubclass(
  88. self.target_cls, parentmapper.class_
  89. ):
  90. raise UnevaluatableError(
  91. "Can't evaluate criteria against alternate class %s"
  92. % parentmapper.class_
  93. )
  94. key = parentmapper._columntoproperty[clause].key
  95. else:
  96. key = clause.key
  97. if (
  98. self.target_cls
  99. and key in inspect(self.target_cls).column_attrs
  100. ):
  101. util.warn(
  102. "Evaluating non-mapped column expression '%s' onto "
  103. "ORM instances; this is a deprecated use case. Please "
  104. "make use of the actual mapped columns in ORM-evaluated "
  105. "UPDATE / DELETE expressions." % clause
  106. )
  107. else:
  108. raise UnevaluatableError("Cannot evaluate column: %s" % clause)
  109. get_corresponding_attr = operator.attrgetter(key)
  110. return (
  111. lambda obj: get_corresponding_attr(obj)
  112. if obj is not None
  113. else _NO_OBJECT
  114. )
  115. def visit_tuple(self, clause):
  116. return self.visit_clauselist(clause)
  117. def visit_clauselist(self, clause):
  118. evaluators = list(map(self.process, clause.clauses))
  119. if clause.operator is operators.or_:
  120. def evaluate(obj):
  121. has_null = False
  122. for sub_evaluate in evaluators:
  123. value = sub_evaluate(obj)
  124. if value:
  125. return True
  126. has_null = has_null or value is None
  127. if has_null:
  128. return None
  129. return False
  130. elif clause.operator is operators.and_:
  131. def evaluate(obj):
  132. for sub_evaluate in evaluators:
  133. value = sub_evaluate(obj)
  134. if not value:
  135. if value is None or value is _NO_OBJECT:
  136. return None
  137. return False
  138. return True
  139. elif clause.operator is operators.comma_op:
  140. def evaluate(obj):
  141. values = []
  142. for sub_evaluate in evaluators:
  143. value = sub_evaluate(obj)
  144. if value is None or value is _NO_OBJECT:
  145. return None
  146. values.append(value)
  147. return tuple(values)
  148. else:
  149. raise UnevaluatableError(
  150. "Cannot evaluate clauselist with operator %s" % clause.operator
  151. )
  152. return evaluate
  153. def visit_binary(self, clause):
  154. eval_left, eval_right = list(
  155. map(self.process, [clause.left, clause.right])
  156. )
  157. operator = clause.operator
  158. if operator is operators.is_:
  159. def evaluate(obj):
  160. return eval_left(obj) == eval_right(obj)
  161. elif operator is operators.is_not:
  162. def evaluate(obj):
  163. return eval_left(obj) != eval_right(obj)
  164. elif operator is operators.concat_op:
  165. def evaluate(obj):
  166. return eval_left(obj) + eval_right(obj)
  167. elif operator in _extended_ops:
  168. def evaluate(obj):
  169. left_val = eval_left(obj)
  170. right_val = eval_right(obj)
  171. if left_val is None or right_val is None:
  172. return None
  173. return _extended_ops[operator](left_val, right_val)
  174. elif operator in _math_only_straight_ops:
  175. if (
  176. clause.left.type._type_affinity
  177. not in (
  178. Numeric,
  179. Integer,
  180. )
  181. or clause.right.type._type_affinity not in (Numeric, Integer)
  182. ):
  183. raise UnevaluatableError(
  184. 'Cannot evaluate math operator "%s" for '
  185. "datatypes %s, %s"
  186. % (operator.__name__, clause.left.type, clause.right.type)
  187. )
  188. def evaluate(obj):
  189. left_val = eval_left(obj)
  190. right_val = eval_right(obj)
  191. if left_val is None or right_val is None:
  192. return None
  193. return operator(eval_left(obj), eval_right(obj))
  194. elif operator in _straight_ops:
  195. def evaluate(obj):
  196. left_val = eval_left(obj)
  197. right_val = eval_right(obj)
  198. if left_val is None or right_val is None:
  199. return None
  200. return operator(eval_left(obj), eval_right(obj))
  201. else:
  202. raise UnevaluatableError(
  203. "Cannot evaluate %s with operator %s"
  204. % (type(clause).__name__, clause.operator)
  205. )
  206. return evaluate
  207. def visit_unary(self, clause):
  208. eval_inner = self.process(clause.element)
  209. if clause.operator is operators.inv:
  210. def evaluate(obj):
  211. value = eval_inner(obj)
  212. if value is None:
  213. return None
  214. return not value
  215. return evaluate
  216. raise UnevaluatableError(
  217. "Cannot evaluate %s with operator %s"
  218. % (type(clause).__name__, clause.operator)
  219. )
  220. def visit_bindparam(self, clause):
  221. if clause.callable:
  222. val = clause.callable()
  223. else:
  224. val = clause.value
  225. return lambda obj: val