timeout.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. #
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing,
  13. # software distributed under the License is distributed on an
  14. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. # KIND, either express or implied. See the License for the
  16. # specific language governing permissions and limitations
  17. # under the License.
from __future__ import annotations

import os
import signal
import time
from threading import Timer
from typing import ContextManager

from airflow.exceptions import AirflowTaskTimeout
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.platform import IS_WINDOWS
# Common base class for both timeout implementations: a typed context manager whose
# ``__enter__`` produces None (neither implementation returns a value from ``__enter__``).
_timeout = ContextManager[None]
  27. class TimeoutWindows(_timeout, LoggingMixin):
  28. """Windows timeout version: To be used in a ``with`` block and timeout its content."""
  29. def __init__(self, seconds=1, error_message="Timeout"):
  30. super().__init__()
  31. self._timer: Timer | None = None
  32. self.seconds = seconds
  33. self.error_message = error_message + ", PID: " + str(os.getpid())
  34. def handle_timeout(self, *args):
  35. """Log information and raises AirflowTaskTimeout."""
  36. self.log.error("Process timed out, PID: %s", str(os.getpid()))
  37. raise AirflowTaskTimeout(self.error_message)
  38. def __enter__(self):
  39. if self._timer:
  40. self._timer.cancel()
  41. self._timer = Timer(self.seconds, self.handle_timeout)
  42. self._timer.start()
  43. def __exit__(self, type_, value, traceback):
  44. if self._timer:
  45. self._timer.cancel()
  46. self._timer = None
  47. class TimeoutPosix(_timeout, LoggingMixin):
  48. """POSIX Timeout version: To be used in a ``with`` block and timeout its content."""
  49. def __init__(self, seconds=1, error_message="Timeout"):
  50. super().__init__()
  51. self.seconds = seconds
  52. self.error_message = error_message + ", PID: " + str(os.getpid())
  53. def handle_timeout(self, signum, frame):
  54. """Log information and raises AirflowTaskTimeout."""
  55. self.log.error("Process timed out, PID: %s", str(os.getpid()))
  56. raise AirflowTaskTimeout(self.error_message)
  57. def __enter__(self):
  58. try:
  59. signal.signal(signal.SIGALRM, self.handle_timeout)
  60. signal.setitimer(signal.ITIMER_REAL, self.seconds)
  61. except ValueError:
  62. self.log.warning("timeout can't be used in the current context", exc_info=True)
  63. def __exit__(self, type_, value, traceback):
  64. try:
  65. signal.setitimer(signal.ITIMER_REAL, 0)
  66. except ValueError:
  67. self.log.warning("timeout can't be used in the current context", exc_info=True)
  68. if IS_WINDOWS:
  69. timeout: type[TimeoutWindows | TimeoutPosix] = TimeoutWindows
  70. else:
  71. timeout = TimeoutPosix