aiohttp_api.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447
  1. """
  2. This module defines an AioHttp Connexion API which implements translations between AioHttp and
  3. Connexion requests / responses.
  4. """
  5. import asyncio
  6. import logging
  7. import re
  8. import traceback
  9. from contextlib import suppress
  10. from http import HTTPStatus
  11. from urllib.parse import parse_qs
  12. import aiohttp_jinja2
  13. import jinja2
  14. from aiohttp import web
  15. from aiohttp.web_exceptions import HTTPNotFound, HTTPPermanentRedirect
  16. from aiohttp.web_middlewares import normalize_path_middleware
  17. from werkzeug.exceptions import HTTPException as werkzeug_HTTPException
  18. from connexion.apis.abstract import AbstractAPI
  19. from connexion.exceptions import ProblemException
  20. from connexion.handlers import AuthErrorHandler
  21. from connexion.jsonifier import JSONEncoder, Jsonifier
  22. from connexion.lifecycle import ConnexionRequest, ConnexionResponse
  23. from connexion.problem import problem
  24. from connexion.security import AioHttpSecurityHandlerFactory
  25. from connexion.utils import yamldumper
  26. logger = logging.getLogger('connexion.apis.aiohttp_api')
  27. def _generic_problem(http_status: HTTPStatus, exc: Exception = None):
  28. extra = None
  29. if exc is not None:
  30. loop = asyncio.get_event_loop()
  31. if loop.get_debug():
  32. tb = None
  33. with suppress(Exception):
  34. tb = traceback.format_exc()
  35. if tb:
  36. extra = {"traceback": tb}
  37. return problem(
  38. status=http_status.value,
  39. title=http_status.phrase,
  40. detail=http_status.description,
  41. ext=extra,
  42. )
  43. @web.middleware
  44. async def problems_middleware(request, handler):
  45. try:
  46. response = await handler(request)
  47. except ProblemException as exc:
  48. response = problem(status=exc.status, detail=exc.detail, title=exc.title,
  49. type=exc.type, instance=exc.instance, headers=exc.headers, ext=exc.ext)
  50. except (werkzeug_HTTPException, _HttpNotFoundError) as exc:
  51. response = problem(status=exc.code, title=exc.name, detail=exc.description)
  52. except web.HTTPError as exc:
  53. if exc.text == f"{exc.status}: {exc.reason}":
  54. detail = HTTPStatus(exc.status).description
  55. else:
  56. detail = exc.text
  57. response = problem(status=exc.status, title=exc.reason, detail=detail)
  58. except (
  59. web.HTTPException, # eg raised HTTPRedirection or HTTPSuccessful
  60. asyncio.CancelledError, # skipped in default web_protocol
  61. ):
  62. # leave this to default handling in aiohttp.web_protocol.RequestHandler.start()
  63. raise
  64. except asyncio.TimeoutError as exc:
  65. # overrides 504 from aiohttp.web_protocol.RequestHandler.start()
  66. logger.debug('Request handler timed out.', exc_info=exc)
  67. response = _generic_problem(HTTPStatus.GATEWAY_TIMEOUT, exc)
  68. except Exception as exc:
  69. # overrides 500 from aiohttp.web_protocol.RequestHandler.start()
  70. logger.exception('Error handling request', exc_info=exc)
  71. response = _generic_problem(HTTPStatus.INTERNAL_SERVER_ERROR, exc)
  72. if isinstance(response, ConnexionResponse):
  73. response = await AioHttpApi.get_response(response)
  74. return response
  75. class AioHttpApi(AbstractAPI):
  76. def __init__(self, *args, **kwargs):
  77. # NOTE we use HTTPPermanentRedirect (308) because
  78. # clients sometimes turn POST requests into GET requests
  79. # on 301, 302, or 303
  80. # see https://tools.ietf.org/html/rfc7538
  81. trailing_slash_redirect = normalize_path_middleware(
  82. append_slash=True,
  83. redirect_class=HTTPPermanentRedirect
  84. )
  85. self.subapp = web.Application(
  86. middlewares=[
  87. problems_middleware,
  88. trailing_slash_redirect
  89. ]
  90. )
  91. AbstractAPI.__init__(self, *args, **kwargs)
  92. aiohttp_jinja2.setup(
  93. self.subapp,
  94. loader=jinja2.FileSystemLoader(
  95. str(self.options.openapi_console_ui_from_dir)
  96. )
  97. )
  98. middlewares = self.options.as_dict().get('middlewares', [])
  99. self.subapp.middlewares.extend(middlewares)
  100. @staticmethod
  101. def make_security_handler_factory(pass_context_arg_name):
  102. """ Create default SecurityHandlerFactory to create all security check handlers """
  103. return AioHttpSecurityHandlerFactory(pass_context_arg_name)
  104. def _set_base_path(self, base_path):
  105. AbstractAPI._set_base_path(self, base_path)
  106. self._api_name = AioHttpApi.normalize_string(self.base_path)
  107. @staticmethod
  108. def normalize_string(string):
  109. return re.sub(r'[^a-zA-Z0-9]', '_', string.strip('/'))
  110. def _base_path_for_prefix(self, request):
  111. """
  112. returns a modified basePath which includes the incoming request's
  113. path prefix.
  114. """
  115. base_path = self.base_path
  116. if not request.path.startswith(self.base_path):
  117. prefix = request.path.split(self.base_path)[0]
  118. base_path = prefix + base_path
  119. return base_path
  120. def _spec_for_prefix(self, request):
  121. """
  122. returns a spec with a modified basePath / servers block
  123. which corresponds to the incoming request path.
  124. This is needed when behind a path-altering reverse proxy.
  125. """
  126. base_path = self._base_path_for_prefix(request)
  127. return self.specification.with_base_path(base_path).raw
  128. def add_openapi_json(self):
  129. """
  130. Adds openapi json to {base_path}/openapi.json
  131. (or {base_path}/swagger.json for swagger2)
  132. """
  133. logger.debug('Adding spec json: %s/%s', self.base_path,
  134. self.options.openapi_spec_path)
  135. self.subapp.router.add_route(
  136. 'GET',
  137. self.options.openapi_spec_path,
  138. self._get_openapi_json
  139. )
  140. def add_openapi_yaml(self):
  141. """
  142. Adds openapi json to {base_path}/openapi.json
  143. (or {base_path}/swagger.json for swagger2)
  144. """
  145. if not self.options.openapi_spec_path.endswith("json"):
  146. return
  147. openapi_spec_path_yaml = \
  148. self.options.openapi_spec_path[:-len("json")] + "yaml"
  149. logger.debug('Adding spec yaml: %s/%s', self.base_path,
  150. openapi_spec_path_yaml)
  151. self.subapp.router.add_route(
  152. 'GET',
  153. openapi_spec_path_yaml,
  154. self._get_openapi_yaml
  155. )
  156. async def _get_openapi_json(self, request):
  157. return web.Response(
  158. status=200,
  159. content_type='application/json',
  160. body=self.jsonifier.dumps(self._spec_for_prefix(request))
  161. )
  162. async def _get_openapi_yaml(self, request):
  163. return web.Response(
  164. status=200,
  165. content_type='text/yaml',
  166. body=yamldumper(self._spec_for_prefix(request))
  167. )
  168. def add_swagger_ui(self):
  169. """
  170. Adds swagger ui to {base_path}/ui/
  171. """
  172. console_ui_path = self.options.openapi_console_ui_path.strip().rstrip('/')
  173. logger.debug('Adding swagger-ui: %s%s/',
  174. self.base_path,
  175. console_ui_path)
  176. for path in (
  177. console_ui_path + '/',
  178. console_ui_path + '/index.html',
  179. ):
  180. self.subapp.router.add_route(
  181. 'GET',
  182. path,
  183. self._get_swagger_ui_home
  184. )
  185. if self.options.openapi_console_ui_config is not None:
  186. self.subapp.router.add_route(
  187. 'GET',
  188. console_ui_path + '/swagger-ui-config.json',
  189. self._get_swagger_ui_config
  190. )
  191. # we have to add an explicit redirect instead of relying on the
  192. # normalize_path_middleware because we also serve static files
  193. # from this dir (below)
  194. async def redirect(request):
  195. raise web.HTTPMovedPermanently(
  196. location=self.base_path + console_ui_path + '/'
  197. )
  198. self.subapp.router.add_route(
  199. 'GET',
  200. console_ui_path,
  201. redirect
  202. )
  203. # this route will match and get a permission error when trying to
  204. # serve index.html, so we add the redirect above.
  205. self.subapp.router.add_static(
  206. console_ui_path,
  207. path=str(self.options.openapi_console_ui_from_dir),
  208. name='swagger_ui_static'
  209. )
  210. @aiohttp_jinja2.template('index.j2')
  211. async def _get_swagger_ui_home(self, req):
  212. base_path = self._base_path_for_prefix(req)
  213. template_variables = {
  214. 'openapi_spec_url': (base_path + self.options.openapi_spec_path),
  215. **self.options.openapi_console_ui_index_template_variables,
  216. }
  217. if self.options.openapi_console_ui_config is not None:
  218. template_variables['configUrl'] = 'swagger-ui-config.json'
  219. return template_variables
  220. async def _get_swagger_ui_config(self, req):
  221. return web.Response(
  222. status=200,
  223. content_type='text/json',
  224. body=self.jsonifier.dumps(self.options.openapi_console_ui_config)
  225. )
  226. def add_auth_on_not_found(self, security, security_definitions):
  227. """
  228. Adds a 404 error handler to authenticate and only expose the 404 status if the security validation pass.
  229. """
  230. logger.debug('Adding path not found authentication')
  231. not_found_error = AuthErrorHandler(
  232. self, _HttpNotFoundError(),
  233. security=security,
  234. security_definitions=security_definitions
  235. )
  236. endpoint_name = f"{self._api_name}_not_found"
  237. self.subapp.router.add_route(
  238. '*',
  239. '/{not_found_path}',
  240. not_found_error.function,
  241. name=endpoint_name
  242. )
  243. def _add_operation_internal(self, method, path, operation):
  244. method = method.upper()
  245. operation_id = operation.operation_id or path
  246. logger.debug('... Adding %s -> %s', method, operation_id,
  247. extra=vars(operation))
  248. handler = operation.function
  249. endpoint_name = '{}_{}_{}'.format(
  250. self._api_name,
  251. AioHttpApi.normalize_string(path),
  252. method.lower()
  253. )
  254. self.subapp.router.add_route(
  255. method, path, handler, name=endpoint_name
  256. )
  257. if not path.endswith('/'):
  258. self.subapp.router.add_route(
  259. method, path + '/', handler, name=endpoint_name + '_'
  260. )
  261. @classmethod
  262. async def get_request(cls, req):
  263. """Convert aiohttp request to connexion
  264. :param req: instance of aiohttp.web.Request
  265. :return: connexion request instance
  266. :rtype: ConnexionRequest
  267. """
  268. url = str(req.url)
  269. logger.debug(
  270. 'Getting data and status code',
  271. extra={
  272. # has_body | can_read_body report if
  273. # body has been read or not
  274. # body_exists refers to underlying stream of data
  275. 'body_exists': req.body_exists,
  276. 'can_read_body': req.can_read_body,
  277. 'content_type': req.content_type,
  278. 'url': url,
  279. },
  280. )
  281. query = parse_qs(req.rel_url.query_string)
  282. headers = req.headers
  283. body = None
  284. # Note: if request is not 'application/x-www-form-urlencoded' nor 'multipart/form-data',
  285. # then `post_data` will be left an empty dict and the stream will not be consumed.
  286. post_data = await req.post()
  287. files = {}
  288. form = {}
  289. if post_data:
  290. logger.debug('Reading multipart data from request')
  291. for k, v in post_data.items():
  292. if isinstance(v, web.FileField):
  293. if k in files:
  294. # if multiple files arrive under the same name in the
  295. # request, downstream requires that we put them all into
  296. # a list under the same key in the files dict.
  297. if isinstance(files[k], list):
  298. files[k].append(v)
  299. else:
  300. files[k] = [files[k], v]
  301. else:
  302. files[k] = v
  303. else:
  304. # put normal fields as an array, that's how werkzeug does that for Flask
  305. # and that's what Connexion expects in its processing functions
  306. form[k] = [v]
  307. body = b''
  308. else:
  309. logger.debug('Reading data from request')
  310. body = await req.read()
  311. return ConnexionRequest(url=url,
  312. method=req.method.lower(),
  313. path_params=dict(req.match_info),
  314. query=query,
  315. headers=headers,
  316. body=body,
  317. json_getter=lambda: cls.jsonifier.loads(body),
  318. form=form,
  319. files=files,
  320. context=req,
  321. cookies=req.cookies)
  322. @classmethod
  323. async def get_response(cls, response, mimetype=None, request=None):
  324. """Get response.
  325. This method is used in the lifecycle decorators
  326. :type response: aiohttp.web.StreamResponse | (Any,) | (Any, int) | (Any, dict) | (Any, int, dict)
  327. :rtype: aiohttp.web.Response
  328. """
  329. while asyncio.iscoroutine(response):
  330. response = await response
  331. url = str(request.url) if request else ''
  332. return cls._get_response(response, mimetype=mimetype, extra_context={"url": url})
  333. @classmethod
  334. def _is_framework_response(cls, response):
  335. """ Return True if `response` is a framework response class """
  336. return isinstance(response, web.StreamResponse)
  337. @classmethod
  338. def _framework_to_connexion_response(cls, response, mimetype):
  339. """ Cast framework response class to ConnexionResponse used for schema validation """
  340. body = None
  341. if hasattr(response, "body"): # StreamResponse and FileResponse don't have body
  342. body = response.body
  343. return ConnexionResponse(
  344. status_code=response.status,
  345. mimetype=mimetype,
  346. content_type=response.content_type,
  347. headers=response.headers,
  348. body=body
  349. )
  350. @classmethod
  351. def _connexion_to_framework_response(cls, response, mimetype, extra_context=None):
  352. """ Cast ConnexionResponse to framework response class """
  353. return cls._build_response(
  354. mimetype=response.mimetype or mimetype,
  355. status_code=response.status_code,
  356. content_type=response.content_type,
  357. headers=response.headers,
  358. data=response.body,
  359. extra_context=extra_context,
  360. )
  361. @classmethod
  362. def _build_response(cls, data, mimetype, content_type=None, headers=None, status_code=None, extra_context=None):
  363. if cls._is_framework_response(data):
  364. raise TypeError("Cannot return web.StreamResponse in tuple. Only raw data can be returned in tuple.")
  365. data, status_code, serialized_mimetype = cls._prepare_body_and_status_code(data=data, mimetype=mimetype, status_code=status_code, extra_context=extra_context)
  366. if isinstance(data, str):
  367. text = data
  368. body = None
  369. else:
  370. text = None
  371. body = data
  372. content_type = content_type or mimetype or serialized_mimetype
  373. return web.Response(body=body, text=text, headers=headers, status=status_code, content_type=content_type)
  374. @classmethod
  375. def _set_jsonifier(cls):
  376. cls.jsonifier = Jsonifier(cls=JSONEncoder)
  377. class _HttpNotFoundError(HTTPNotFound):
  378. def __init__(self):
  379. self.name = 'Not Found'
  380. self.description = (
  381. 'The requested URL was not found on the server. '
  382. 'If you entered the URL manually please check your spelling and '
  383. 'try again.'
  384. )
  385. self.code = type(self).status_code
  386. self.empty_body = True
  387. HTTPNotFound.__init__(self, reason=self.name)