__init__.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514
  1. from __future__ import annotations
  2. import datetime as dt
  3. import functools
  4. import inspect
  5. import os
  6. import sys
  7. import time as time_module
  8. import uuid
  9. from collections.abc import Awaitable
  10. from collections.abc import Generator
  11. from collections.abc import Generator as TypingGenerator
  12. from time import gmtime as orig_gmtime
  13. from time import struct_time
  14. from types import TracebackType
  15. from typing import Any
  16. from typing import Callable
  17. from typing import TypeVar
  18. from typing import Union
  19. from typing import cast
  20. from typing import overload
  21. from unittest import TestCase
  22. from unittest import mock
  23. from zoneinfo import ZoneInfo
  24. import _time_machine
  25. from dateutil.parser import parse as parse_datetime
  26. # time.clock_gettime and time.CLOCK_REALTIME not always available
  27. # e.g. on builds against old macOS = official Python.org installer
  28. try:
  29. from time import CLOCK_REALTIME
  30. except ImportError:
  31. # Dummy value that won't compare equal to any value
  32. CLOCK_REALTIME = sys.maxsize
  33. try:
  34. from time import tzset
  35. HAVE_TZSET = True
  36. except ImportError: # pragma: no cover
  37. # Windows
  38. HAVE_TZSET = False
  39. try:
  40. import pytest
  41. except ImportError: # pragma: no cover
  42. HAVE_PYTEST = False
  43. else:
  44. HAVE_PYTEST = True
  45. NANOSECONDS_PER_SECOND = 1_000_000_000
  46. # Windows' time epoch is not unix epoch but in 1601. This constant helps us
  47. # translate to it.
  48. _system_epoch = orig_gmtime(0)
  49. SYSTEM_EPOCH_TIMESTAMP_NS = int(
  50. dt.datetime(
  51. _system_epoch.tm_year,
  52. _system_epoch.tm_mon,
  53. _system_epoch.tm_mday,
  54. _system_epoch.tm_hour,
  55. _system_epoch.tm_min,
  56. _system_epoch.tm_sec,
  57. tzinfo=dt.timezone.utc,
  58. ).timestamp()
  59. * NANOSECONDS_PER_SECOND
  60. )
  61. DestinationBaseType = Union[
  62. int,
  63. float,
  64. dt.datetime,
  65. dt.timedelta,
  66. dt.date,
  67. str,
  68. ]
  69. DestinationType = Union[
  70. DestinationBaseType,
  71. Callable[[], DestinationBaseType],
  72. TypingGenerator[DestinationBaseType, None, None],
  73. ]
  74. _F = TypeVar("_F", bound=Callable[..., Any])
  75. _AF = TypeVar("_AF", bound=Callable[..., Awaitable[Any]])
  76. TestCaseType = TypeVar("TestCaseType", bound=type[TestCase])
  77. # copied from typeshed:
  78. _TimeTuple = tuple[int, int, int, int, int, int, int, int, int]
  79. def extract_timestamp_tzname(
  80. destination: DestinationType,
  81. ) -> tuple[float, str | None]:
  82. dest: DestinationBaseType
  83. if isinstance(destination, Generator):
  84. dest = next(destination)
  85. elif callable(destination):
  86. dest = destination()
  87. else:
  88. dest = destination
  89. timestamp: float
  90. tzname: str | None = None
  91. if isinstance(dest, int):
  92. timestamp = float(dest)
  93. elif isinstance(dest, float):
  94. timestamp = dest
  95. elif isinstance(dest, dt.datetime):
  96. if isinstance(dest.tzinfo, ZoneInfo):
  97. tzname = dest.tzinfo.key
  98. if dest.tzinfo is None:
  99. dest = dest.replace(tzinfo=dt.timezone.utc)
  100. timestamp = dest.timestamp()
  101. elif isinstance(dest, dt.timedelta):
  102. timestamp = time_module.time() + dest.total_seconds()
  103. elif isinstance(dest, dt.date):
  104. timestamp = dt.datetime.combine(
  105. dest, dt.time(0, 0), tzinfo=dt.timezone.utc
  106. ).timestamp()
  107. elif isinstance(dest, str):
  108. timestamp = parse_datetime(dest).timestamp()
  109. else:
  110. raise TypeError(f"Unsupported destination {dest!r}")
  111. return timestamp, tzname
  112. class Coordinates:
  113. def __init__(
  114. self,
  115. destination_timestamp: float,
  116. destination_tzname: str | None,
  117. tick: bool,
  118. ) -> None:
  119. self._destination_timestamp_ns = int(
  120. destination_timestamp * NANOSECONDS_PER_SECOND
  121. )
  122. self._destination_tzname = destination_tzname
  123. self._tick = tick
  124. self._requested = False
  125. def time(self) -> float:
  126. return self.time_ns() / NANOSECONDS_PER_SECOND
  127. def time_ns(self) -> int:
  128. if not self._tick:
  129. return self._destination_timestamp_ns
  130. base = SYSTEM_EPOCH_TIMESTAMP_NS + self._destination_timestamp_ns
  131. now_ns: int = _time_machine.original_time_ns()
  132. if not self._requested:
  133. self._requested = True
  134. self._real_start_timestamp_ns = now_ns
  135. return base
  136. return base + (now_ns - self._real_start_timestamp_ns)
  137. def shift(self, delta: dt.timedelta | int | float) -> None:
  138. if isinstance(delta, dt.timedelta):
  139. total_seconds = delta.total_seconds()
  140. elif isinstance(delta, (int, float)):
  141. total_seconds = delta
  142. else:
  143. raise TypeError(f"Unsupported type for delta argument: {delta!r}")
  144. self._destination_timestamp_ns += int(total_seconds * NANOSECONDS_PER_SECOND)
  145. def move_to(
  146. self,
  147. destination: DestinationType,
  148. tick: bool | None = None,
  149. ) -> None:
  150. self._stop()
  151. timestamp, self._destination_tzname = extract_timestamp_tzname(destination)
  152. self._destination_timestamp_ns = int(timestamp * NANOSECONDS_PER_SECOND)
  153. self._requested = False
  154. self._start()
  155. if tick is not None:
  156. self._tick = tick
  157. def _start(self) -> None:
  158. if HAVE_TZSET and self._destination_tzname is not None:
  159. self._orig_tz = os.environ.get("TZ")
  160. os.environ["TZ"] = self._destination_tzname
  161. tzset()
  162. def _stop(self) -> None:
  163. if HAVE_TZSET and self._destination_tzname is not None:
  164. if self._orig_tz is None:
  165. del os.environ["TZ"]
  166. else:
  167. os.environ["TZ"] = self._orig_tz
  168. tzset()
  169. coordinates_stack: list[Coordinates] = []
  170. # During time travel, patch the uuid module's time-based generation function to
  171. # None, which makes it use time.time(). Otherwise it makes a system call to
  172. # find the current datetime. The time it finds is stored in generated UUID1
  173. # values.
  174. uuid_generate_time_attr = "_generate_time_safe"
  175. uuid_generate_time_patcher = mock.patch.object(uuid, uuid_generate_time_attr, new=None)
  176. uuid_uuid_create_patcher = mock.patch.object(uuid, "_UuidCreate", new=None)
  177. class travel:
  178. def __init__(self, destination: DestinationType, *, tick: bool = True) -> None:
  179. self.destination_timestamp, self.destination_tzname = extract_timestamp_tzname(
  180. destination
  181. )
  182. self.tick = tick
  183. def start(self) -> Coordinates:
  184. global coordinates_stack
  185. _time_machine.patch_if_needed()
  186. if not coordinates_stack:
  187. uuid_generate_time_patcher.start()
  188. uuid_uuid_create_patcher.start()
  189. coordinates = Coordinates(
  190. destination_timestamp=self.destination_timestamp,
  191. destination_tzname=self.destination_tzname,
  192. tick=self.tick,
  193. )
  194. coordinates_stack.append(coordinates)
  195. coordinates._start()
  196. return coordinates
  197. def stop(self) -> None:
  198. global coordinates_stack
  199. coordinates_stack.pop()._stop()
  200. if not coordinates_stack:
  201. uuid_generate_time_patcher.stop()
  202. uuid_uuid_create_patcher.stop()
  203. def __enter__(self) -> Coordinates:
  204. return self.start()
  205. def __exit__(
  206. self,
  207. exc_type: type[BaseException] | None,
  208. exc_val: BaseException | None,
  209. exc_tb: TracebackType | None,
  210. ) -> None:
  211. self.stop()
  212. @overload
  213. def __call__(self, wrapped: TestCaseType) -> TestCaseType: # pragma: no cover
  214. ...
  215. @overload
  216. def __call__(self, wrapped: _AF) -> _AF: # pragma: no cover
  217. ...
  218. @overload
  219. def __call__(self, wrapped: _F) -> _F: # pragma: no cover
  220. ...
  221. # 'Any' below is workaround for Mypy error:
  222. # Overloaded function implementation does not accept all possible arguments
  223. # of signature
  224. def __call__(
  225. self, wrapped: TestCaseType | _AF | _F | Any
  226. ) -> TestCaseType | _AF | _F | Any:
  227. if isinstance(wrapped, type):
  228. # Class decorator
  229. if not issubclass(wrapped, TestCase):
  230. raise TypeError("Can only decorate unittest.TestCase subclasses.")
  231. # Modify the setUpClass method
  232. orig_setUpClass = wrapped.setUpClass.__func__ # type: ignore[attr-defined]
  233. @functools.wraps(orig_setUpClass)
  234. def setUpClass(cls: type[TestCase]) -> None:
  235. self.__enter__()
  236. try:
  237. orig_setUpClass(cls)
  238. except Exception:
  239. self.__exit__(*sys.exc_info())
  240. raise
  241. wrapped.setUpClass = classmethod(setUpClass) # type: ignore[assignment]
  242. orig_tearDownClass = (
  243. wrapped.tearDownClass.__func__ # type: ignore[attr-defined]
  244. )
  245. @functools.wraps(orig_tearDownClass)
  246. def tearDownClass(cls: type[TestCase]) -> None:
  247. orig_tearDownClass(cls)
  248. self.__exit__(None, None, None)
  249. wrapped.tearDownClass = classmethod( # type: ignore[assignment]
  250. tearDownClass
  251. )
  252. return cast(TestCaseType, wrapped)
  253. elif inspect.iscoroutinefunction(wrapped):
  254. @functools.wraps(wrapped)
  255. async def wrapper(*args: Any, **kwargs: Any) -> Any:
  256. with self:
  257. return await wrapped(*args, **kwargs)
  258. return cast(_AF, wrapper)
  259. else:
  260. assert callable(wrapped)
  261. @functools.wraps(wrapped)
  262. def wrapper(*args: Any, **kwargs: Any) -> Any:
  263. with self:
  264. return wrapped(*args, **kwargs)
  265. return cast(_F, wrapper)
  266. # datetime module
  267. def now(tz: dt.tzinfo | None = None) -> dt.datetime:
  268. if not coordinates_stack:
  269. result: dt.datetime = _time_machine.original_now(tz)
  270. return result
  271. return dt.datetime.fromtimestamp(time(), tz)
  272. def utcnow() -> dt.datetime:
  273. if not coordinates_stack:
  274. result: dt.datetime = _time_machine.original_utcnow()
  275. return result
  276. return dt.datetime.fromtimestamp(time(), dt.timezone.utc).replace(tzinfo=None)
  277. # time module
  278. def clock_gettime(clk_id: int) -> float:
  279. if not coordinates_stack or clk_id != CLOCK_REALTIME:
  280. result: float = _time_machine.original_clock_gettime(clk_id)
  281. return result
  282. return time()
  283. def clock_gettime_ns(clk_id: int) -> int:
  284. if not coordinates_stack or clk_id != CLOCK_REALTIME:
  285. result: int = _time_machine.original_clock_gettime_ns(clk_id)
  286. return result
  287. return time_ns()
  288. def gmtime(secs: float | None = None) -> struct_time:
  289. result: struct_time
  290. if not coordinates_stack or secs is not None:
  291. result = _time_machine.original_gmtime(secs)
  292. else:
  293. result = _time_machine.original_gmtime(coordinates_stack[-1].time())
  294. return result
  295. def localtime(secs: float | None = None) -> struct_time:
  296. result: struct_time
  297. if not coordinates_stack or secs is not None:
  298. result = _time_machine.original_localtime(secs)
  299. else:
  300. result = _time_machine.original_localtime(coordinates_stack[-1].time())
  301. return result
  302. def strftime(format: str, t: _TimeTuple | struct_time | None = None) -> str:
  303. result: str
  304. if t is not None:
  305. result = _time_machine.original_strftime(format, t)
  306. elif not coordinates_stack:
  307. result = _time_machine.original_strftime(format)
  308. else:
  309. result = _time_machine.original_strftime(format, localtime())
  310. return result
  311. def time() -> float:
  312. if not coordinates_stack:
  313. result: float = _time_machine.original_time()
  314. return result
  315. return coordinates_stack[-1].time()
  316. def time_ns() -> int:
  317. if not coordinates_stack:
  318. result: int = _time_machine.original_time_ns()
  319. return result
  320. return coordinates_stack[-1].time_ns()
  321. # pytest plugin
  322. if HAVE_PYTEST: # pragma: no branch
  323. class TimeMachineFixture:
  324. traveller: travel | None
  325. coordinates: Coordinates | None
  326. def __init__(self) -> None:
  327. self.traveller = None
  328. self.coordinates = None
  329. def move_to(
  330. self,
  331. destination: DestinationType,
  332. tick: bool | None = None,
  333. ) -> None:
  334. if self.traveller is None:
  335. if tick is None:
  336. tick = True
  337. self.traveller = travel(destination, tick=tick)
  338. self.coordinates = self.traveller.start()
  339. else:
  340. assert self.coordinates is not None
  341. self.coordinates.move_to(destination, tick=tick)
  342. def shift(self, delta: dt.timedelta | int | float) -> None:
  343. if self.traveller is None:
  344. raise RuntimeError(
  345. "Initialize time_machine with move_to() before using shift()."
  346. )
  347. assert self.coordinates is not None
  348. self.coordinates.shift(delta=delta)
  349. def stop(self) -> None:
  350. if self.traveller is not None:
  351. self.traveller.stop()
  352. @pytest.fixture(name="time_machine")
  353. def time_machine_fixture() -> TypingGenerator[TimeMachineFixture, None, None]:
  354. fixture = TimeMachineFixture()
  355. yield fixture
  356. fixture.stop()
  357. # escape hatch
  358. class _EscapeHatchDatetimeDatetime:
  359. def now(self, tz: dt.tzinfo | None = None) -> dt.datetime:
  360. result: dt.datetime = _time_machine.original_now(tz)
  361. return result
  362. def utcnow(self) -> dt.datetime:
  363. result: dt.datetime = _time_machine.original_utcnow()
  364. return result
  365. class _EscapeHatchDatetime:
  366. def __init__(self) -> None:
  367. self.datetime = _EscapeHatchDatetimeDatetime()
  368. class _EscapeHatchTime:
  369. def clock_gettime(self, clk_id: int) -> float:
  370. result: float = _time_machine.original_clock_gettime(clk_id)
  371. return result
  372. def clock_gettime_ns(self, clk_id: int) -> int:
  373. result: int = _time_machine.original_clock_gettime_ns(clk_id)
  374. return result
  375. def gmtime(self, secs: float | None = None) -> struct_time:
  376. result: struct_time = _time_machine.original_gmtime(secs)
  377. return result
  378. def localtime(self, secs: float | None = None) -> struct_time:
  379. result: struct_time = _time_machine.original_localtime(secs)
  380. return result
  381. def monotonic(self) -> float:
  382. result: float = _time_machine.original_monotonic()
  383. return result
  384. def monotonic_ns(self) -> int:
  385. result: int = _time_machine.original_monotonic_ns()
  386. return result
  387. def strftime(self, format: str, t: _TimeTuple | struct_time | None = None) -> str:
  388. result: str
  389. if t is not None:
  390. result = _time_machine.original_strftime(format, t)
  391. else:
  392. result = _time_machine.original_strftime(format)
  393. return result
  394. def time(self) -> float:
  395. result: float = _time_machine.original_time()
  396. return result
  397. def time_ns(self) -> int:
  398. result: int = _time_machine.original_time_ns()
  399. return result
  400. class _EscapeHatch:
  401. def __init__(self) -> None:
  402. self.datetime = _EscapeHatchDatetime()
  403. self.time = _EscapeHatchTime()
  404. def is_travelling(self) -> bool:
  405. return bool(coordinates_stack)
  406. escape_hatch = _EscapeHatch()