interface.py 38 KB


  1. # -*- coding: utf-8 -*-
  2. from contextlib import suppress
  3. import logging
  4. from typing import Any, Dict, List, Optional, Tuple, Type, Union
  5. from flask_appbuilder._compat import as_unicode
  6. from flask_appbuilder.const import (
  7. LOGMSG_ERR_DBI_DEL_GENERIC,
  8. LOGMSG_WAR_DBI_ADD_INTEGRITY,
  9. LOGMSG_WAR_DBI_DEL_INTEGRITY,
  10. LOGMSG_WAR_DBI_EDIT_INTEGRITY,
  11. )
  12. from flask_appbuilder.exceptions import InterfaceQueryWithoutSession
  13. from flask_appbuilder.filemanager import FileManager, ImageManager
  14. from flask_appbuilder.models.base import BaseInterface
  15. from flask_appbuilder.models.filters import Filters
  16. from flask_appbuilder.models.group import GroupByCol, GroupByDateMonth, GroupByDateYear
  17. from flask_appbuilder.models.mixins import FileColumn, ImageColumn
  18. from flask_appbuilder.models.sqla import filters, Model
  19. from flask_appbuilder.utils.base import (
  20. get_column_leaf,
  21. get_column_root_relation,
  22. is_column_dotted,
  23. )
  24. from sqlalchemy import asc, desc
  25. from sqlalchemy import types as sa_types
  26. from sqlalchemy.exc import IntegrityError
  27. from sqlalchemy.orm import aliased, class_mapper, ColumnProperty, contains_eager, Load
  28. from sqlalchemy.orm.descriptor_props import SynonymProperty
  29. from sqlalchemy.orm.properties import RelationshipProperty
  30. from sqlalchemy.orm.query import Query
  31. from sqlalchemy.orm.session import Session as SessionBase
  32. from sqlalchemy.orm.util import AliasedClass
  33. from sqlalchemy.sql import visitors
  34. from sqlalchemy.sql.elements import BinaryExpression
  35. from sqlalchemy.sql.sqltypes import TypeEngine
  36. from sqlalchemy_utils.types.uuid import UUIDType
  37. log = logging.getLogger(__name__)
  38. def _is_sqla_type(model: Model, sa_type: Type[TypeEngine]) -> bool:
  39. return (
  40. isinstance(model, sa_type)
  41. or isinstance(model, sa_types.TypeDecorator)
  42. and isinstance(model.impl, sa_type)
  43. )
  44. class SQLAInterface(BaseInterface):
  45. """
  46. SQLAModel
  47. Implements SQLA support methods for views
  48. """
  49. filter_converter_class = filters.SQLAFilterConverter
  50. def __init__(self, obj: Type[Model], session: Optional[SessionBase] = None) -> None:
  51. _include_filters(self)
  52. self.list_columns = dict()
  53. self.list_properties = dict()
  54. self.session = session
  55. # Collect all SQLA columns and properties
  56. for prop in class_mapper(obj).iterate_properties:
  57. if type(prop) != SynonymProperty:
  58. self.list_properties[prop.key] = prop
  59. for col_name in obj.__mapper__.columns.keys():
  60. if col_name in self.list_properties:
  61. self.list_columns[col_name] = obj.__mapper__.columns[col_name]
  62. super(SQLAInterface, self).__init__(obj)
  63. @property
  64. def model_name(self):
  65. """
  66. Returns the models class name
  67. useful for auto title on views
  68. """
  69. return self.obj.__name__
  70. @staticmethod
  71. def is_model_already_joined(query: Query, model: Type[Model]) -> bool:
  72. if hasattr(query, "_join_entities"): # For SQLAlchemy < 1.3
  73. return model in [mapper.class_ for mapper in query._join_entities]
  74. # Solution for SQLAlchemy >= 1.4
  75. model_table_name = model.__table__.fullname
  76. for visitor in visitors.iterate(query.statement):
  77. # Checking for `.join(Parent.child)` clauses
  78. if visitor.__visit_name__ == "alias":
  79. _visitor = visitor.element
  80. else:
  81. _visitor = visitor
  82. if _visitor.__visit_name__ == "select":
  83. continue
  84. if _visitor.__visit_name__ == "binary":
  85. for vis in visitors.iterate(_visitor):
  86. # Visitor might not have table attribute
  87. with suppress(AttributeError):
  88. # Verify if already present based on table name
  89. if model_table_name == vis.table.fullname:
  90. return True
  91. # Checking for `.join(Child)` clauses
  92. if _visitor.__visit_name__ == "table":
  93. # Visitor might be of ColumnCollection or so,
  94. # which cannot be compared to model
  95. if model_table_name == _visitor.fullname:
  96. return True
  97. # Checking for `Model.column` clauses
  98. if _visitor.__visit_name__ == "column":
  99. with suppress(AttributeError):
  100. if model_table_name == _visitor.table.fullname:
  101. return True
  102. return False
  103. def _get_base_query(
  104. self, query=None, filters=None, order_column="", order_direction=""
  105. ):
  106. if filters:
  107. query = filters.apply_all(query)
  108. return self.apply_order_by(query, order_column, order_direction)
    def _query_join_relation(
        self,
        query: Query,
        root_relation: str,
        aliases_mapping: Optional[Dict[str, AliasedClass]] = None,
    ) -> Query:
        """
        Helper function that applies necessary joins for dotted columns on a
        SQLAlchemy query object

        :param query: SQLAlchemy query object
        :param root_relation: The root part of a dotted column, so the root relation
        :param aliases_mapping: Optional dict that is updated in place with the
            alias created for ``root_relation``; a throwaway dict is used when
            omitted
        :return: Transformed SQLAlchemy Query
        """
        if aliases_mapping is None:
            aliases_mapping = {}
        # One (model, join condition) pair per join needed for this relation
        relations = self.get_related_model_and_join(root_relation)
        for relation in relations:
            model_relation, relation_join = relation
            # Use alias if it's not a custom relation
            if not hasattr(relation_join, "clauses"):
                model_relation = aliased(model_relation, name=root_relation)
                # Record the alias so callers can reference the joined entity
                aliases_mapping[root_relation] = model_relation
                relation_pk = self.get_pk(model_relation)
                # Rebuild the join condition against the alias: keep the side
                # that holds the foreign keys and compare it with the alias pk
                if relation_join.left.foreign_keys:
                    relation_join = BinaryExpression(
                        relation_join.left, relation_pk, relation_join.operator
                    )
                else:
                    relation_join = BinaryExpression(
                        relation_join.right, relation_pk, relation_join.operator
                    )
            # Always an outer join
            query = query.join(model_relation, relation_join, isouter=True)
        return query
  142. def apply_engine_specific_hack(
  143. self,
  144. query: Query,
  145. page: Optional[int],
  146. page_size: Optional[int],
  147. order_column: Optional[str],
  148. ) -> Query:
  149. # MSSQL exception page/limit must have an order by
  150. if (
  151. page
  152. and page_size
  153. and not order_column
  154. and self.session.bind.dialect.name == "mssql"
  155. ):
  156. pk_name = self.get_pk_name()
  157. return query.order_by(pk_name)
  158. return query
    def apply_order_by(
        self,
        query: Query,
        order_column: str,
        order_direction: str,
        aliases_mapping: Optional[Dict[str, AliasedClass]] = None,
    ) -> Query:
        """
        Apply an ORDER BY clause to ``query``.

        :param query: The query to order
        :param order_column: Column name to order by, dotted notation allowed;
            an empty string leaves the query untouched
        :param order_direction: 'asc' orders ascending, any other value descending
        :param aliases_mapping: Optional mapping of relation name to alias,
            used to resolve (and record) joined relations
        :return: The ordered query
        """
        if order_column != "":
            # if Model has custom decorator **renders('<COL_NAME>')**
            # this decorator will add a property to the method named *_col_name*
            if hasattr(self.obj, order_column):
                if hasattr(getattr(self.obj, order_column), "_col_name"):
                    order_column = getattr(self._get_attr(order_column), "_col_name")
            # Fall back to the plain column name when no attribute is resolved
            _order_column = self._get_attr(order_column) or order_column
            if is_column_dotted(order_column):
                root_relation = get_column_root_relation(order_column)
                # On MVC we still allow for joins to happen here
                if not self.is_model_already_joined(
                    query, self.get_related_model(root_relation)
                ):
                    query = self._query_join_relation(
                        query, root_relation, aliases_mapping=aliases_mapping
                    )
                column_leaf = get_column_leaf(order_column)
                # Order on the joined entity (alias when one was recorded)
                _alias = self.get_alias_mapping(root_relation, aliases_mapping)
                _order_column = getattr(_alias, column_leaf)
            if order_direction == "asc":
                query = query.order_by(asc(_order_column))
            else:
                query = query.order_by(desc(_order_column))
        return query
  190. def apply_pagination(
  191. self, query: Query, page: Optional[int], page_size: Optional[int]
  192. ) -> Query:
  193. if page and page_size:
  194. query = query.offset(page * page_size)
  195. if page_size:
  196. query = query.limit(page_size)
  197. return query
  198. def apply_filters(self, query: Query, filters: Optional[Filters]) -> Query:
  199. if filters:
  200. return filters.apply_all(query)
  201. return query
  202. def _apply_normal_col_select_option(self, query: Query, column: str) -> Query:
  203. if not self.is_relation(column) and not self.is_property_or_function(column):
  204. return query.options(Load(self.obj).load_only(column))
  205. return query
  206. def _apply_relation_fks_select_options(self, query: Query, relation_name) -> Query:
  207. relation = getattr(self.obj, relation_name)
  208. if hasattr(relation, "property"):
  209. local_cols = getattr(self.obj, relation_name).property.local_columns
  210. for local_fk in local_cols:
  211. query = query.options(Load(self.obj).load_only(local_fk.name))
  212. return query
  213. return query
    def apply_inner_select_joins(
        self,
        query: Query,
        select_columns: Optional[List[str]] = None,
        aliases_mapping: Optional[Dict[str, AliasedClass]] = None,
    ) -> Query:
        """
        Add select load options to query. The goal
        is to only SQL select what is requested and join all the necessary
        models when dotted notation is used. Inner implies non dotted columns
        and many to one and one to one

        :param query: The SQLAlchemy query to transform
        :param select_columns: Column names to select, dotted notation allowed;
            when empty the query is returned unchanged
        :param aliases_mapping: Optional mapping of relation name to alias,
            passed on to the join helper
        :return: The transformed query
        """
        if not select_columns:
            return query
        # Relations joined so far, so each one is only joined once
        joined_models = []
        for column in select_columns:
            if not is_column_dotted(column):
                query = self._apply_normal_col_select_option(query, column)
                continue
            # Dotted column
            root_relation = get_column_root_relation(column)
            leaf_column = get_column_leaf(column)
            # NOTE(review): both values are looked up again below before they
            # are used, so this early lookup looks redundant - confirm
            related_model = self.get_alias_mapping(root_relation, aliases_mapping)
            relation = getattr(self.obj, root_relation)
            if self.is_relation_many_to_one(
                root_relation
            ) or self.is_relation_many_to_many_special(root_relation):
                if root_relation not in joined_models:
                    query = self._query_join_relation(
                        query, root_relation, aliases_mapping=aliases_mapping
                    )
                    query = query.add_entity(
                        self.get_alias_mapping(root_relation, aliases_mapping)
                    )
                    # Add relation FK to avoid N+1 performance issue
                    query = self._apply_relation_fks_select_options(
                        query, root_relation
                    )
                    joined_models.append(root_relation)
                # Look up again: the join above may have recorded a new alias
                related_model = self.get_alias_mapping(root_relation, aliases_mapping)
                relation = getattr(self.obj, root_relation)
                # The Zen of eager loading :(
                # https://docs.sqlalchemy.org/en/13/orm/loading_relationships.html
                query = query.options(
                    contains_eager(relation.of_type(related_model)).load_only(
                        leaf_column
                    )
                )
                query = query.options(Load(related_model).load_only(leaf_column))
        return query
  267. def apply_outer_select_joins(
  268. self,
  269. query: Query,
  270. select_columns: List[str] = None,
  271. outer_default_load: bool = False,
  272. ) -> Query:
  273. if not select_columns:
  274. return query
  275. for column in select_columns:
  276. if not is_column_dotted(column):
  277. query = self._apply_normal_col_select_option(query, column)
  278. continue
  279. root_relation = get_column_root_relation(column)
  280. leaf_column = get_column_leaf(column)
  281. if self.is_relation_many_to_many(
  282. root_relation
  283. ) or self.is_relation_one_to_many(root_relation):
  284. if outer_default_load:
  285. query = query.options(
  286. Load(self.obj).defaultload(root_relation).load_only(leaf_column)
  287. )
  288. else:
  289. query = query.options(
  290. Load(self.obj).joinedload(root_relation).load_only(leaf_column)
  291. )
  292. else:
  293. related_model = self.get_related_model(root_relation)
  294. query = query.options(Load(related_model).load_only(leaf_column))
  295. return query
  296. def get_inner_filters(self, filters: Optional[Filters]) -> Filters:
  297. """
  298. Inner filters are non dotted columns and
  299. one to many or one to one relations
  300. :param filters: All filters
  301. :return: New filtered filters to apply to an inner query
  302. """
  303. inner_filters = Filters(self.filter_converter_class, self)
  304. _filters = []
  305. if filters:
  306. for flt, value in zip(filters.filters, filters.values):
  307. if not is_column_dotted(flt.column_name):
  308. _filters.append((flt.column_name, flt.__class__, value))
  309. elif self.is_relation_many_to_one(
  310. get_column_root_relation(flt.column_name)
  311. ) or self.is_relation_one_to_one(
  312. get_column_root_relation(flt.column_name)
  313. ):
  314. _filters.append((flt.column_name, flt.__class__, value))
  315. inner_filters.add_filter_list(_filters)
  316. return inner_filters
  317. def exists_col_to_many(self, select_columns: List[str]) -> bool:
  318. for column in select_columns:
  319. if is_column_dotted(column):
  320. root_relation = get_column_root_relation(column)
  321. if self.is_relation_many_to_many(
  322. root_relation
  323. ) or self.is_relation_one_to_many(root_relation):
  324. return True
  325. return False
  326. def get_alias_mapping(
  327. self, model_name: str, aliases_mapping: Dict[str, AliasedClass]
  328. ) -> Union[AliasedClass, Type[Model]]:
  329. if aliases_mapping is None:
  330. return self.get_related_model(model_name)
  331. return aliases_mapping.get(model_name, self.get_related_model(model_name))
  332. def _apply_inner_all(
  333. self,
  334. query: Query,
  335. filters: Optional[Filters] = None,
  336. order_column: str = "",
  337. order_direction: str = "",
  338. page: Optional[int] = None,
  339. page_size: Optional[int] = None,
  340. select_columns: Optional[List[str]] = None,
  341. aliases_mapping: Dict[str, AliasedClass] = None,
  342. ) -> Query:
  343. inner_filters = self.get_inner_filters(filters)
  344. query = self.apply_inner_select_joins(query, select_columns, aliases_mapping)
  345. query = self.apply_filters(query, inner_filters)
  346. query = self.apply_engine_specific_hack(query, page, page_size, order_column)
  347. query = self.apply_order_by(
  348. query, order_column, order_direction, aliases_mapping=aliases_mapping
  349. )
  350. query = self.apply_pagination(query, page, page_size)
  351. return query
  352. def query_count(
  353. self,
  354. query: Query,
  355. filters: Optional[Filters] = None,
  356. select_columns: Optional[List[str]] = None,
  357. ) -> int:
  358. return self._apply_inner_all(
  359. query, filters, select_columns=select_columns, aliases_mapping={}
  360. ).count()
  361. def apply_all(
  362. self,
  363. query: Query,
  364. filters: Optional[Filters] = None,
  365. order_column: str = "",
  366. order_direction: str = "",
  367. page: Optional[int] = None,
  368. page_size: Optional[int] = None,
  369. select_columns: Optional[List[str]] = None,
  370. outer_default_load: bool = False,
  371. ) -> Query:
  372. """
  373. Accepts a SQLAlchemy Query and applies all filtering logic, order by and
  374. pagination.
  375. :param query: The query to apply all
  376. :param filters:
  377. dict with filters {<col_name>:<value,...}
  378. :param order_column:
  379. name of the column to order
  380. :param order_direction:
  381. the direction to order <'asc'|'desc'>
  382. :param page:
  383. the current page
  384. :param page_size:
  385. the current page size
  386. :param select_columns:
  387. A List of columns to be specifically selected on the query
  388. :param outer_default_load: If True, the default load for outer joins will be
  389. applied. This is useful for when you want to control
  390. the load of the many-to-many relationships at the model level.
  391. we will apply:
  392. https://docs.sqlalchemy.org/en/14/orm/loading_relationships.html#sqlalchemy.orm.Load.defaultload
  393. :return: A SQLAlchemy Query with all the applied logic
  394. """
  395. aliases_mapping = {}
  396. inner_query = self._apply_inner_all(
  397. query,
  398. filters,
  399. order_column,
  400. order_direction,
  401. page,
  402. page_size,
  403. select_columns,
  404. aliases_mapping=aliases_mapping,
  405. )
  406. # Only use a from_self if we need to select a join one to many or many to many
  407. if select_columns and self.exists_col_to_many(select_columns):
  408. if select_columns and order_column:
  409. select_columns = select_columns + [order_column]
  410. outer_query = inner_query.from_self()
  411. outer_query = self.apply_outer_select_joins(
  412. outer_query, select_columns, outer_default_load=outer_default_load
  413. )
  414. return self.apply_order_by(outer_query, order_column, order_direction)
  415. else:
  416. return inner_query
  417. def query(
  418. self,
  419. filters: Optional[Filters] = None,
  420. order_column: str = "",
  421. order_direction: str = "",
  422. page: Optional[int] = None,
  423. page_size: Optional[int] = None,
  424. select_columns: Optional[List[str]] = None,
  425. outer_default_load: bool = False,
  426. ) -> Tuple[int, List[Model]]:
  427. """
  428. Returns the results for a model query, applies filters, sorting and pagination
  429. :param filters: A Filter class that contains all filters to apply
  430. :param order_column: name of the column to order
  431. :param order_direction: the direction to order <'asc'|'desc'>
  432. :param page: the current page
  433. :param page_size: the current page size
  434. :param select_columns: A List of columns to be specifically selected
  435. on the query. Supports dotted notation.
  436. :param outer_default_load: If True, the default load for outer joins will be
  437. applied. This is useful for when you want to control
  438. the load of the many-to-many relationships at the model level.
  439. we will apply:
  440. https://docs.sqlalchemy.org/en/14/orm/loading_relationships.html#sqlalchemy.orm.Load.defaultload
  441. :return: A tuple with the query count (non paginated) and the results
  442. """
  443. if not self.session:
  444. raise InterfaceQueryWithoutSession()
  445. query = self.session.query(self.obj)
  446. count = self.query_count(query, filters, select_columns)
  447. query = self.apply_all(
  448. query,
  449. filters,
  450. order_column,
  451. order_direction,
  452. page,
  453. page_size,
  454. select_columns,
  455. )
  456. query_results = query.all()
  457. result = []
  458. for item in query_results:
  459. if hasattr(item, self.obj.__name__):
  460. result.append(getattr(item, self.obj.__name__))
  461. else:
  462. return count, query_results
  463. return count, result
  464. def query_simple_group(
  465. self, group_by="", aggregate_func=None, aggregate_col=None, filters=None
  466. ):
  467. query = self.session.query(self.obj)
  468. query = self._get_base_query(query=query, filters=filters)
  469. query_result = query.all()
  470. group = GroupByCol(group_by, "Group by")
  471. return group.apply(query_result)
  472. def query_month_group(self, group_by="", filters=None):
  473. query = self.session.query(self.obj)
  474. query = self._get_base_query(query=query, filters=filters)
  475. query_result = query.all()
  476. group = GroupByDateMonth(group_by, "Group by Month")
  477. return group.apply(query_result)
  478. def query_year_group(self, group_by="", filters=None):
  479. query = self.session.query(self.obj)
  480. query = self._get_base_query(query=query, filters=filters)
  481. query_result = query.all()
  482. group_year = GroupByDateYear(group_by, "Group by Year")
  483. return group_year.apply(query_result)
  484. """
  485. -----------------------------------------
  486. FUNCTIONS for Testing TYPES
  487. -----------------------------------------
  488. """
  489. def is_image(self, col_name: str) -> bool:
  490. try:
  491. return isinstance(self.list_columns[col_name].type, ImageColumn)
  492. except KeyError:
  493. return False
  494. def is_file(self, col_name: str) -> bool:
  495. try:
  496. return isinstance(self.list_columns[col_name].type, FileColumn)
  497. except KeyError:
  498. return False
  499. def is_string(self, col_name: str) -> bool:
  500. try:
  501. return (
  502. _is_sqla_type(self.list_columns[col_name].type, sa_types.String)
  503. or self.list_columns[col_name].type.__class__ == UUIDType
  504. )
  505. except KeyError:
  506. return False
  507. def is_text(self, col_name: str) -> bool:
  508. try:
  509. return _is_sqla_type(self.list_columns[col_name].type, sa_types.Text)
  510. except KeyError:
  511. return False
  512. def is_binary(self, col_name: str) -> bool:
  513. try:
  514. return _is_sqla_type(self.list_columns[col_name].type, sa_types.LargeBinary)
  515. except KeyError:
  516. return False
  517. def is_integer(self, col_name: str) -> bool:
  518. try:
  519. return _is_sqla_type(self.list_columns[col_name].type, sa_types.Integer)
  520. except KeyError:
  521. return False
  522. def is_numeric(self, col_name: str) -> bool:
  523. try:
  524. return _is_sqla_type(self.list_columns[col_name].type, sa_types.Numeric)
  525. except KeyError:
  526. return False
  527. def is_float(self, col_name: str) -> bool:
  528. try:
  529. return _is_sqla_type(self.list_columns[col_name].type, sa_types.Float)
  530. except KeyError:
  531. return False
  532. def is_boolean(self, col_name: str) -> bool:
  533. try:
  534. return _is_sqla_type(self.list_columns[col_name].type, sa_types.Boolean)
  535. except KeyError:
  536. return False
  537. def is_date(self, col_name: str) -> bool:
  538. try:
  539. return _is_sqla_type(self.list_columns[col_name].type, sa_types.Date)
  540. except KeyError:
  541. return False
  542. def is_datetime(self, col_name: str) -> bool:
  543. try:
  544. return _is_sqla_type(self.list_columns[col_name].type, sa_types.DateTime)
  545. except KeyError:
  546. return False
  547. def is_enum(self, col_name: str) -> bool:
  548. try:
  549. return _is_sqla_type(self.list_columns[col_name].type, sa_types.Enum)
  550. except KeyError:
  551. return False
  552. def is_relation(self, col_name: str) -> bool:
  553. try:
  554. return isinstance(self.list_properties[col_name], RelationshipProperty)
  555. except KeyError:
  556. return False
  557. def is_relation_many_to_one(self, col_name: str) -> bool:
  558. try:
  559. if self.is_relation(col_name):
  560. return self.list_properties[col_name].direction.name == "MANYTOONE"
  561. return False
  562. except KeyError:
  563. return False
  564. def is_relation_many_to_many(self, col_name: str) -> bool:
  565. try:
  566. if self.is_relation(col_name):
  567. relation = self.list_properties[col_name]
  568. return relation.direction.name == "MANYTOMANY"
  569. return False
  570. except KeyError:
  571. return False
  572. def is_relation_many_to_many_special(self, col_name: str) -> bool:
  573. try:
  574. if self.is_relation(col_name):
  575. relation = self.list_properties[col_name]
  576. return relation.direction.name == "ONETOONE" and relation.uselist
  577. return False
  578. except KeyError:
  579. return False
  580. def is_relation_one_to_one(self, col_name: str) -> bool:
  581. try:
  582. if self.is_relation(col_name):
  583. relation = self.list_properties[col_name]
  584. return self.list_properties[col_name].direction.name == "ONETOONE" or (
  585. relation.direction.name == "ONETOMANY" and relation.uselist is False
  586. )
  587. return False
  588. except KeyError:
  589. return False
  590. def is_relation_one_to_many(self, col_name: str) -> bool:
  591. try:
  592. if self.is_relation(col_name):
  593. relation = self.list_properties[col_name]
  594. return relation.direction.name == "ONETOMANY" and relation.uselist
  595. return False
  596. except KeyError:
  597. return False
  598. def is_nullable(self, col_name: str) -> bool:
  599. if self.is_relation_many_to_one(col_name):
  600. col = self.get_relation_fk(col_name)
  601. return col.nullable
  602. try:
  603. return self.list_columns[col_name].nullable
  604. except KeyError:
  605. return False
  606. def is_unique(self, col_name: str) -> bool:
  607. try:
  608. return self.list_columns[col_name].unique is True
  609. except KeyError:
  610. return False
  611. def is_pk(self, col_name: str) -> bool:
  612. try:
  613. return self.list_columns[col_name].primary_key
  614. except KeyError:
  615. return False
  616. def is_pk_composite(self) -> bool:
  617. return len(self.obj.__mapper__.primary_key) > 1
  618. def is_fk(self, col_name: str) -> bool:
  619. try:
  620. return self.list_columns[col_name].foreign_keys
  621. except KeyError:
  622. return False
  623. def is_property(self, col_name: str) -> bool:
  624. return hasattr(getattr(self.obj, col_name), "fget")
  625. def is_function(self, col_name: str) -> bool:
  626. return hasattr(getattr(self.obj, col_name), "__call__")
  627. def is_property_or_function(self, col_name: str) -> bool:
  628. return self.is_property(col_name) or self.is_function(col_name)
  629. def get_max_length(self, col_name: str) -> int:
  630. try:
  631. if self.is_enum(col_name):
  632. return -1
  633. col = self.list_columns[col_name]
  634. if col.type.length:
  635. return col.type.length
  636. else:
  637. return -1
  638. except Exception:
  639. return -1
  640. """
  641. -------------------------------
  642. FUNCTIONS FOR CRUD OPERATIONS
  643. -------------------------------
  644. """
  645. def add(self, item: Model, raise_exception: bool = False) -> bool:
  646. try:
  647. self.session.add(item)
  648. self.session.commit()
  649. self.message = (as_unicode(self.add_row_message), "success")
  650. return True
  651. except IntegrityError as e:
  652. self.message = (as_unicode(self.add_integrity_error_message), "warning")
  653. log.warning(LOGMSG_WAR_DBI_ADD_INTEGRITY, e)
  654. self.session.rollback()
  655. if raise_exception:
  656. raise e
  657. return False
  658. except Exception as e:
  659. self.message = (as_unicode(self.database_error_message), "danger")
  660. log.exception("Database error")
  661. self.session.rollback()
  662. if raise_exception:
  663. raise e
  664. return False
  665. def edit(self, item: Model, raise_exception: bool = False) -> bool:
  666. try:
  667. self.session.merge(item)
  668. self.session.commit()
  669. self.message = (as_unicode(self.edit_row_message), "success")
  670. return True
  671. except IntegrityError as e:
  672. self.message = (as_unicode(self.edit_integrity_error_message), "warning")
  673. log.warning(LOGMSG_WAR_DBI_EDIT_INTEGRITY, e)
  674. self.session.rollback()
  675. if raise_exception:
  676. raise e
  677. return False
  678. except Exception as e:
  679. self.message = (as_unicode(self.database_error_message), "danger")
  680. log.exception("Database error")
  681. self.session.rollback()
  682. if raise_exception:
  683. raise e
  684. return False
  685. def delete(self, item: Model, raise_exception: bool = False) -> bool:
  686. try:
  687. self._delete_files(item)
  688. self.session.delete(item)
  689. self.session.commit()
  690. self.message = (as_unicode(self.delete_row_message), "success")
  691. return True
  692. except IntegrityError as e:
  693. self.message = (as_unicode(self.delete_integrity_error_message), "warning")
  694. log.warning(LOGMSG_WAR_DBI_DEL_INTEGRITY, e)
  695. self.session.rollback()
  696. if raise_exception:
  697. raise e
  698. return False
  699. except Exception as e:
  700. self.message = (as_unicode(self.database_error_message), "danger")
  701. log.exception("Database error")
  702. self.session.rollback()
  703. if raise_exception:
  704. raise e
  705. return False
  706. def delete_all(self, items: List[Model]) -> bool:
  707. try:
  708. for item in items:
  709. self._delete_files(item)
  710. self.session.delete(item)
  711. self.session.commit()
  712. self.message = (as_unicode(self.delete_row_message), "success")
  713. return True
  714. except IntegrityError as e:
  715. self.message = (as_unicode(self.delete_integrity_error_message), "warning")
  716. log.warning(LOGMSG_WAR_DBI_DEL_INTEGRITY, e)
  717. self.session.rollback()
  718. return False
  719. except Exception as e:
  720. self.message = (as_unicode(self.database_error_message), "danger")
  721. log.exception(LOGMSG_ERR_DBI_DEL_GENERIC, e)
  722. self.session.rollback()
  723. return False
  724. """
  725. -----------------------
  726. FILE HANDLING METHODS
  727. -----------------------
  728. """
  729. def _add_files(self, this_request, item: Model):
  730. fm = FileManager()
  731. im = ImageManager()
  732. for file_col in this_request.files:
  733. if self.is_file(file_col):
  734. fm.save_file(this_request.files[file_col], getattr(item, file_col))
  735. for file_col in this_request.files:
  736. if self.is_image(file_col):
  737. im.save_file(this_request.files[file_col], getattr(item, file_col))
  738. def _delete_files(self, item: Model):
  739. for file_col in self.get_file_column_list():
  740. if self.is_file(file_col) and getattr(item, file_col):
  741. fm = FileManager()
  742. fm.delete_file(getattr(item, file_col))
  743. for file_col in self.get_image_column_list():
  744. if self.is_image(file_col) and getattr(item, file_col):
  745. im = ImageManager()
  746. im.delete_file(getattr(item, file_col))
  747. """
  748. ------------------------------
  749. FUNCTIONS FOR RELATED MODELS
  750. ------------------------------
  751. """
  752. def get_col_default(self, col_name: str) -> Any:
  753. default = getattr(self.list_columns[col_name], "default", None)
  754. if default is None:
  755. return None
  756. value = getattr(default, "arg", None)
  757. if value is None:
  758. return None
  759. if getattr(default, "is_callable", False):
  760. return lambda: default.arg(None)
  761. if not getattr(default, "is_scalar", True):
  762. return None
  763. return value
  764. def get_related_model(self, col_name: str) -> Type[Model]:
  765. return self.list_properties[col_name].mapper.class_
  766. def get_related_model_and_join(
  767. self, col_name: str
  768. ) -> List[Tuple[Type[Model], object]]:
  769. relation = self.list_properties[col_name]
  770. if relation.direction.name == "MANYTOMANY":
  771. return [
  772. (relation.secondary, relation.primaryjoin),
  773. (relation.mapper.class_, relation.secondaryjoin),
  774. ]
  775. return [(relation.mapper.class_, relation.primaryjoin)]
  776. def get_related_interface(self, col_name: str):
  777. return self.__class__(self.get_related_model(col_name), self.session)
  778. def get_related_obj(self, col_name: str, value: Any) -> Optional[Type[Model]]:
  779. rel_model = self.get_related_model(col_name)
  780. if self.session:
  781. return self.session.query(rel_model).get(value)
  782. return None
  783. def get_related_fks(self, related_views) -> List[str]:
  784. return [view.datamodel.get_related_fk(self.obj) for view in related_views]
  785. def get_related_fk(self, model: Type[Model]) -> Optional[str]:
  786. for col_name in self.list_properties.keys():
  787. if self.is_relation(col_name):
  788. if model == self.get_related_model(col_name):
  789. return col_name
  790. return None
  791. def get_info(self, col_name: str):
  792. if col_name in self.list_properties:
  793. return self.list_properties[col_name].info
  794. return {}
  795. """
  796. -------------
  797. GET METHODS
  798. -------------
  799. """
  800. def get_columns_list(self) -> List[str]:
  801. """
  802. Returns all model's columns on SQLA properties
  803. """
  804. return list(self.list_properties.keys())
  805. def get_user_columns_list(self) -> List[str]:
  806. """
  807. Returns all model's columns except pk or fk
  808. """
  809. return [
  810. col_name
  811. for col_name in self.get_columns_list()
  812. if (not self.is_pk(col_name)) and (not self.is_fk(col_name))
  813. ]
  814. # TODO get different solution, more integrated with filters
  def get_search_columns_list(self) -> List[str]:
      """
      Returns the column names that can be searched on.

      Relation columns are always included. Any other column is dropped
      when its first underlying column is a primary or foreign key, or
      when the column itself is an image or file column.
      """
      searchable = []
      for name in self.get_columns_list():
          if not self.is_relation(name):
              # pk/fk are tested on the underlying column's name, while
              # image/file are tested on the property name itself.
              first_col_name = self.get_property_first_col(name).name
              excluded = (
                  self.is_pk(first_col_name)
                  or self.is_fk(first_col_name)
                  or self.is_image(name)
                  or self.is_file(name)
              )
              if excluded:
                  continue
          searchable.append(name)
      return searchable
  def get_order_columns_list(self, list_columns: List[str] = None) -> List[str]:
      """
      Returns the columns that can be ordered.

      Relation columns are never included. A name that is an attribute of
      the model is kept unless that attribute is callable and lacks a
      ``_col_name`` attribute; a name the model does not have is kept.

      :param list_columns: optional list of column names; when given (and
          not empty) only these are considered, otherwise all the model's
          columns are.
      """
      candidates = list_columns or self.get_columns_list()
      orderable = []
      for name in candidates:
          if self.is_relation(name):
              continue
          if not hasattr(self.obj, name):
              orderable.append(name)
              continue
          attr = getattr(self.obj, name)
          if callable(attr) and not hasattr(attr, "_col_name"):
              continue
          orderable.append(name)
      return orderable
  def get_file_column_list(self) -> List[str]:
      """
      Returns the names of the mapper columns whose type is a
      ``FileColumn`` instance.
      """
      file_columns = []
      for column in self.obj.__mapper__.columns:
          if isinstance(column.type, FileColumn):
              file_columns.append(column.name)
      return file_columns
  def get_image_column_list(self) -> List[str]:
      """
      Returns the names of the mapper columns whose type is an
      ``ImageColumn`` instance.
      """
      image_columns = []
      for column in self.obj.__mapper__.columns:
          if isinstance(column.type, ImageColumn):
              image_columns.append(column.name)
      return image_columns
  def get_property_first_col(self, col_name: str) -> ColumnProperty:
      """
      Returns the first entry of the ``columns`` of the property stored
      under ``col_name``.

      Only that first column is considered, so pk and fk handling built
      on this method supports a single column.
      """
      prop = self.list_properties[col_name]
      return prop.columns[0]
  def get_relation_fk(self, col_name: str) -> str:
      """
      Returns the first of the ``local_columns`` of the relation stored
      under ``col_name``.

      Only that first column is considered, so pk and fk handling built
      on this method supports a single column.
      """
      relation = self.list_properties[col_name]
      local_cols = list(relation.local_columns)
      return local_cols[0]
  def get(
      self,
      id,
      filters: Optional[Filters] = None,
      select_columns: Optional[List[str]] = None,
      outer_default_load: bool = False,
  ) -> Optional[Model]:
      """
      Returns the result for a model get, applies filters and supports dotted
      notation for joins and granular selecting query columns.

      :param id: The model id (pk). For a composite primary key this is
          paired element by element with the pk column names.
      :param filters: A Filter class that contains all filters to apply.
          A copy is extended with the pk filter(s); the given object's
          ``copy()`` is what gets modified.
      :param select_columns: A List of columns to be specifically selected
          on the query. Supports dotted notation.
      :param outer_default_load: forwarded unchanged to ``apply_all``.
      :return: the single matching result, or None when there is none.
          When the result has an attribute named after the model class,
          that attribute is returned instead of the result itself.
      """
      pk = self.get_pk_name()
      _filters = (
          filters.copy() if filters else Filters(self.filter_converter_class, self)
      )
      if not self.is_pk_composite():
          _filters.add_filter(pk, self.FilterEqual, id)
      else:
          # One equality filter per (pk column, id component) pair.
          for pk_part, id_part in zip(pk, id):
              _filters.add_filter(pk_part, self.FilterEqual, id_part)
      base_query = self.session.query(self.obj)
      filtered_query = self.apply_all(
          base_query,
          _filters,
          select_columns=select_columns,
          outer_default_load=outer_default_load,
      )
      item = filtered_query.one_or_none()
      if item and hasattr(item, self.obj.__name__):
          return getattr(item, self.obj.__name__)
      return item
  def get_pk_name(self) -> Optional[Union[List[str], str]]:
      """
      Get the primary key column name(s) of this interface's model, as
      computed by ``_get_pk_name``.
      """
      own_model = self.obj
      return self._get_pk_name(own_model)
  def get_pk(self, model: Optional[Type[Model]] = None):
      """
      Get the model primary key SQLAlchemy column.
      Will not support composite keys

      :param model: model to inspect; falls back to this interface's own
          model when omitted (or falsy).
      :return: the attribute of the model named after its primary key, or
          None when the pk name is missing, empty or not a single string.
      """
      target = model if model else self.obj
      pk_name = self._get_pk_name(target)
      if not pk_name or not isinstance(pk_name, str):
          return None
      return getattr(target, pk_name)
  def _get_pk_name(self, model: Type[Model]) -> Optional[Union[List[str], str]]:
      """
      Get the primary key column name(s) of ``model``.

      The single/composite decision is taken from the primary key columns
      found on ``model`` itself. It used to be delegated to
      ``self.is_pk_composite()``, which receives no model and therefore
      cannot describe a ``model`` other than the interface's own; a
      related model whose key shape differed got a list where a name was
      expected, or the reverse.

      :param model: mapped class exposing ``__mapper__.primary_key``.
      :return: a list of names when ``model`` has more than one primary
          key column, the single name when it has exactly one, or None
          when it has none.
      """
      pk_names = [column.name for column in model.__mapper__.primary_key]
      if not pk_names:
          return None
      return pk_names if len(pk_names) > 1 else pk_names[0]
  def _include_filters(interface: SQLAInterface) -> None:
      """
      Injects all filters on the interface class itself

      Every name listed in ``filters.__all__`` is copied onto
      ``interface`` unless it already has an attribute with that name.

      :param interface: the object that receives the filter attributes.
      """
      # Lazy generator: each name is checked just before it is set.
      missing_names = (
          name for name in filters.__all__ if not hasattr(interface, name)
      )
      for name in missing_names:
          setattr(interface, name, getattr(filters, name))
  """
  Alias kept for backward compatibility.
  """
  # SQLModel is bound to the very same class object as SQLAInterface.
  SQLModel = SQLAInterface