123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514 |
- from __future__ import annotations
- import datetime as dt
- import functools
- import inspect
- import os
- import sys
- import time as time_module
- import uuid
- from collections.abc import Awaitable
- from collections.abc import Generator
- from collections.abc import Generator as TypingGenerator
- from time import gmtime as orig_gmtime
- from time import struct_time
- from types import TracebackType
- from typing import Any
- from typing import Callable
- from typing import TypeVar
- from typing import Union
- from typing import cast
- from typing import overload
- from unittest import TestCase
- from unittest import mock
- from zoneinfo import ZoneInfo
- import _time_machine
- from dateutil.parser import parse as parse_datetime
- # time.clock_gettime and time.CLOCK_REALTIME not always available
- # e.g. on builds against old macOS = official Python.org installer
- try:
- from time import CLOCK_REALTIME
- except ImportError:
- # Dummy value that won't compare equal to any value
- CLOCK_REALTIME = sys.maxsize
- try:
- from time import tzset
- HAVE_TZSET = True
- except ImportError: # pragma: no cover
- # Windows
- HAVE_TZSET = False
- try:
- import pytest
- except ImportError: # pragma: no cover
- HAVE_PYTEST = False
- else:
- HAVE_PYTEST = True
NANOSECONDS_PER_SECOND = 1_000_000_000

# Offset, in nanoseconds, of the epoch reported by time.gmtime(0) from the Unix
# epoch. It guards against platforms whose native epoch differs (Windows'
# native epoch is in 1601). Where gmtime(0) is 1970-01-01 UTC, it is 0.
_system_epoch = orig_gmtime(0)
SYSTEM_EPOCH_TIMESTAMP_NS = int(
    # struct_time indices 0-5 are year, month, day, hour, minute, second.
    dt.datetime(*_system_epoch[:6], tzinfo=dt.timezone.utc).timestamp()
    * NANOSECONDS_PER_SECOND
)
# Plain values accepted as a travel destination.
DestinationBaseType = Union[int, float, dt.datetime, dt.timedelta, dt.date, str]

# A destination may also be supplied lazily, as a zero-argument callable or a
# generator; extract_timestamp_tzname() takes one value from it per call.
DestinationType = Union[
    DestinationBaseType,
    Callable[[], DestinationBaseType],
    TypingGenerator[DestinationBaseType, None, None],
]

# Decorator targets: any callable, coroutine functions, and TestCase classes.
_F = TypeVar("_F", bound=Callable[..., Any])
_AF = TypeVar("_AF", bound=Callable[..., Awaitable[Any]])
TestCaseType = TypeVar("TestCaseType", bound=type[TestCase])

# Nine-int tuple accepted in place of a struct_time (copied from typeshed).
_TimeTuple = tuple[int, int, int, int, int, int, int, int, int]
def extract_timestamp_tzname(
    destination: DestinationType,
) -> tuple[float, str | None]:
    """Resolve a travel destination into ``(timestamp, tzname)``.

    ``destination`` may be a generator (one value is taken with ``next()``),
    a zero-argument callable (called once), or a plain value. The resolved
    value is converted as follows:

    * ``int`` / ``float``: a Unix timestamp in seconds.
    * ``datetime``: naive values are treated as UTC. If ``tzinfo`` is a
      ``ZoneInfo``, its key is returned as ``tzname``.
    * ``timedelta``: an offset from ``time.time()`` at call time.
    * ``date``: midnight UTC on that day.
    * ``str``: parsed with dateutil, then converted with the parsed datetime's
      own ``timestamp()``, so a naive result is interpreted as local time.

    Returns:
        The timestamp in seconds, and a zone key or ``None``.

    Raises:
        TypeError: if the resolved value is none of the supported types.
    """
    resolved: DestinationBaseType
    if isinstance(destination, Generator):
        resolved = next(destination)
    elif callable(destination):
        resolved = destination()
    else:
        resolved = destination

    if isinstance(resolved, int):
        return float(resolved), None
    if isinstance(resolved, float):
        return resolved, None
    # datetime is a date subclass, so it must be tested before date.
    if isinstance(resolved, dt.datetime):
        zone = resolved.tzinfo
        tzname = zone.key if isinstance(zone, ZoneInfo) else None
        if zone is None:
            resolved = resolved.replace(tzinfo=dt.timezone.utc)
        return resolved.timestamp(), tzname
    if isinstance(resolved, dt.timedelta):
        return time_module.time() + resolved.total_seconds(), None
    if isinstance(resolved, dt.date):
        midnight_utc = dt.datetime.combine(
            resolved, dt.time(0, 0), tzinfo=dt.timezone.utc
        )
        return midnight_utc.timestamp(), None
    if isinstance(resolved, str):
        return parse_datetime(resolved).timestamp(), None
    raise TypeError(f"Unsupported destination {resolved!r}")
class Coordinates:
    """State of one active time travel: where "now" is and how it moves.

    Instances are created by ``travel.start()`` and pushed onto
    ``coordinates_stack``. The top of that stack answers the patched time
    functions.

    Args:
        destination_timestamp: destination as a Unix timestamp in seconds.
        destination_tzname: a ``ZoneInfo`` key to put in ``TZ`` while
            travelling, or ``None`` to leave the process timezone alone.
        tick: whether time advances in real time after the first read.
    """

    def __init__(
        self,
        destination_timestamp: float,
        destination_tzname: str | None,
        tick: bool,
    ) -> None:
        self._destination_timestamp_ns = int(
            destination_timestamp * NANOSECONDS_PER_SECOND
        )
        self._destination_tzname = destination_tzname
        self._tick = tick
        # Set to True by the first time_ns() call while ticking. That call
        # records the real clock reading that elapsed time is measured from.
        self._requested = False
        # self._orig_tz is assigned by _start() when a timezone is applied.

    def time(self) -> float:
        """Return the travelled time in seconds (see ``time_ns``)."""
        return self.time_ns() / NANOSECONDS_PER_SECOND

    def time_ns(self) -> int:
        """Return the travelled time in nanoseconds.

        When not ticking, this is always the destination. When ticking, the
        first call returns the destination and records the real clock. Later
        calls add the real nanoseconds elapsed since that first call.
        """
        if not self._tick:
            return self._destination_timestamp_ns

        # NOTE(review): SYSTEM_EPOCH_TIMESTAMP_NS is added only on the ticking
        # path. It is 0 wherever gmtime(0) is the Unix epoch; confirm the
        # asymmetry with the non-ticking branch is intended.
        base = SYSTEM_EPOCH_TIMESTAMP_NS + self._destination_timestamp_ns
        now_ns: int = _time_machine.original_time_ns()

        if not self._requested:
            self._requested = True
            self._real_start_timestamp_ns = now_ns
            return base

        return base + (now_ns - self._real_start_timestamp_ns)

    def shift(self, delta: dt.timedelta | int | float) -> None:
        """Move the destination by ``delta`` (a timedelta, or seconds).

        Raises:
            TypeError: if ``delta`` is not a timedelta, int, or float.
        """
        if isinstance(delta, dt.timedelta):
            total_seconds = delta.total_seconds()
        elif isinstance(delta, (int, float)):
            total_seconds = delta
        else:
            raise TypeError(f"Unsupported type for delta argument: {delta!r}")
        self._destination_timestamp_ns += int(total_seconds * NANOSECONDS_PER_SECOND)

    def move_to(
        self,
        destination: DestinationType,
        tick: bool | None = None,
    ) -> None:
        """Reposition these coordinates at a new destination.

        Resets the tick reference point and swaps the ``TZ`` override for the
        new destination's zone. ``tick=None`` keeps the current ticking mode.
        If ``destination`` cannot be resolved, the exception propagates and
        the coordinates, including ``TZ``, are left as they were.
        """
        self._stop()
        try:
            timestamp, tzname = extract_timestamp_tzname(destination)
        except Exception:
            # Re-apply the current zone. Otherwise a bad destination would leave
            # TZ restored while these coordinates still claim the old zone, and
            # the eventual _stop() would restore it a second time.
            self._start()
            raise
        self._destination_tzname = tzname
        self._destination_timestamp_ns = int(timestamp * NANOSECONDS_PER_SECOND)
        self._requested = False
        self._start()
        if tick is not None:
            self._tick = tick

    def _start(self) -> None:
        """Point ``TZ`` at the destination zone, remembering the old value."""
        if HAVE_TZSET and self._destination_tzname is not None:
            self._orig_tz = os.environ.get("TZ")
            os.environ["TZ"] = self._destination_tzname
            tzset()

    def _stop(self) -> None:
        """Undo ``_start()``: restore the previous ``TZ`` value, or its absence."""
        if HAVE_TZSET and self._destination_tzname is not None:
            if self._orig_tz is None:
                # pop() rather than del: code running during the travel may
                # already have removed TZ, and restoring must not raise.
                os.environ.pop("TZ", None)
            else:
                os.environ["TZ"] = self._orig_tz
            tzset()
# Coordinates of every active travel, innermost last. The patched time
# functions report the last entry.
coordinates_stack: list[Coordinates] = []

# During time travel, patch the uuid module's time-based generation functions
# to None, which makes it use time.time(). Otherwise it makes a system call to
# find the current datetime, and that real time would be stored in generated
# UUID1 values. travel.start() starts these patchers for the outermost travel,
# and travel.stop() stops them once no travel remains.
uuid_uuid_create_patcher = mock.patch.object(uuid, "_UuidCreate", new=None)
uuid_generate_time_attr = "_generate_time_safe"
uuid_generate_time_patcher = mock.patch.object(
    uuid,
    uuid_generate_time_attr,
    new=None,
)
class travel:
    """Travel to a fixed destination in time.

    Usable as a context manager, through explicit ``start()``/``stop()``
    calls, or as a decorator for functions, coroutine functions, and
    ``unittest.TestCase`` subclasses.

    Args:
        destination: where to travel (see ``extract_timestamp_tzname`` for the
            accepted forms). It is resolved once, in ``__init__``.
        tick: if true, time advances from the destination after the first
            read; if false, time stays frozen at the destination.
    """

    def __init__(self, destination: DestinationType, *, tick: bool = True) -> None:
        self.destination_timestamp, self.destination_tzname = extract_timestamp_tzname(
            destination
        )
        self.tick = tick

    def start(self) -> Coordinates:
        """Begin travelling and return the newly active coordinates.

        Installs the ``_time_machine`` patches if needed. The outermost travel
        also starts the uuid patchers.
        """
        _time_machine.patch_if_needed()

        if not coordinates_stack:
            uuid_generate_time_patcher.start()
            uuid_uuid_create_patcher.start()

        coordinates = Coordinates(
            destination_timestamp=self.destination_timestamp,
            destination_tzname=self.destination_tzname,
            tick=self.tick,
        )
        coordinates_stack.append(coordinates)
        coordinates._start()
        return coordinates

    def stop(self) -> None:
        """Stop the most recently started travel.

        This pops the top of the global stack, which is not necessarily the
        coordinates this instance created, so travels must be stopped in LIFO
        order. Once the stack is empty the uuid patchers are removed, even if
        restoring the timezone fails.
        """
        coordinates = coordinates_stack.pop()
        try:
            coordinates._stop()
        finally:
            if not coordinates_stack:
                uuid_generate_time_patcher.stop()
                uuid_uuid_create_patcher.stop()

    def __enter__(self) -> Coordinates:
        return self.start()

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.stop()

    @overload
    def __call__(self, wrapped: TestCaseType) -> TestCaseType:  # pragma: no cover
        ...

    @overload
    def __call__(self, wrapped: _AF) -> _AF:  # pragma: no cover
        ...

    @overload
    def __call__(self, wrapped: _F) -> _F:  # pragma: no cover
        ...

    # 'Any' below is workaround for Mypy error:
    # Overloaded function implementation does not accept all possible arguments
    # of signature
    def __call__(
        self, wrapped: TestCaseType | _AF | _F | Any
    ) -> TestCaseType | _AF | _F | Any:
        """Apply this travel to ``wrapped`` for the duration of its execution.

        A ``TestCase`` subclass travels from ``setUpClass`` through
        ``tearDownClass``. A coroutine function or plain callable travels for
        each call.

        Raises:
            TypeError: if ``wrapped`` is a class that is not a TestCase subclass.
        """
        if isinstance(wrapped, type):
            # Class decorator
            if not issubclass(wrapped, TestCase):
                raise TypeError("Can only decorate unittest.TestCase subclasses.")

            # Modify the setUpClass method: start travelling before the
            # original runs.
            orig_setUpClass = wrapped.setUpClass.__func__  # type: ignore[attr-defined]

            @functools.wraps(orig_setUpClass)
            def setUpClass(cls: type[TestCase]) -> None:
                self.__enter__()
                try:
                    orig_setUpClass(cls)
                except Exception:
                    # unittest does not run tearDownClass when setUpClass
                    # raises, so the travel has to be stopped here.
                    self.__exit__(*sys.exc_info())
                    raise

            wrapped.setUpClass = classmethod(setUpClass)  # type: ignore[assignment]

            orig_tearDownClass = (
                wrapped.tearDownClass.__func__  # type: ignore[attr-defined]
            )

            @functools.wraps(orig_tearDownClass)
            def tearDownClass(cls: type[TestCase]) -> None:
                try:
                    orig_tearDownClass(cls)
                finally:
                    # Stop even if the original teardown raises; otherwise the
                    # travel would stay active for every test that runs later.
                    self.__exit__(None, None, None)

            wrapped.tearDownClass = classmethod(  # type: ignore[assignment]
                tearDownClass
            )
            # cast() is a runtime no-op; the string form avoids a runtime lookup.
            return cast("TestCaseType", wrapped)
        elif inspect.iscoroutinefunction(wrapped):

            @functools.wraps(wrapped)
            async def wrapper(*args: Any, **kwargs: Any) -> Any:
                with self:
                    return await wrapped(*args, **kwargs)

            return cast("_AF", wrapper)
        else:
            assert callable(wrapped)

            @functools.wraps(wrapped)
            def wrapper(*args: Any, **kwargs: Any) -> Any:
                with self:
                    return wrapped(*args, **kwargs)

            return cast("_F", wrapper)
# datetime module


def now(tz: dt.tzinfo | None = None) -> dt.datetime:
    """Time-travel-aware stand-in for ``datetime.datetime.now``.

    While travelling, builds the result from the travelled ``time()`` in
    ``tz``. Otherwise defers to ``_time_machine.original_now``.
    """
    if coordinates_stack:
        return dt.datetime.fromtimestamp(time(), tz)
    untravelled: dt.datetime = _time_machine.original_now(tz)
    return untravelled
def utcnow() -> dt.datetime:
    """Time-travel-aware stand-in for ``datetime.datetime.utcnow``.

    While travelling, returns the travelled UTC wall-clock time as a naive
    datetime. Otherwise defers to ``_time_machine.original_utcnow``.
    """
    if coordinates_stack:
        aware_utc = dt.datetime.fromtimestamp(time(), dt.timezone.utc)
        return aware_utc.replace(tzinfo=None)
    untravelled: dt.datetime = _time_machine.original_utcnow()
    return untravelled
# time module


def clock_gettime(clk_id: int) -> float:
    """Time-travel-aware stand-in for ``time.clock_gettime``.

    Only ``CLOCK_REALTIME`` is travelled; other clocks pass through. Where
    ``CLOCK_REALTIME`` is unavailable, this module substitutes a sentinel that
    matches no clock id, so every call passes through.
    """
    if coordinates_stack and clk_id == CLOCK_REALTIME:
        return time()
    untravelled: float = _time_machine.original_clock_gettime(clk_id)
    return untravelled
def clock_gettime_ns(clk_id: int) -> int:
    """Nanosecond variant of ``clock_gettime``; only ``CLOCK_REALTIME`` travels."""
    if coordinates_stack and clk_id == CLOCK_REALTIME:
        return time_ns()
    untravelled: int = _time_machine.original_clock_gettime_ns(clk_id)
    return untravelled
def gmtime(secs: float | None = None) -> struct_time:
    """Time-travel-aware stand-in for ``time.gmtime``.

    An explicit ``secs`` is always honoured. When ``secs`` is omitted during
    travel, the active coordinates' time is used in its place.
    """
    if secs is None and coordinates_stack:
        secs = coordinates_stack[-1].time()
    converted: struct_time = _time_machine.original_gmtime(secs)
    return converted
def localtime(secs: float | None = None) -> struct_time:
    """Time-travel-aware stand-in for ``time.localtime``.

    An explicit ``secs`` is always honoured. When ``secs`` is omitted during
    travel, the active coordinates' time is used in its place.
    """
    if secs is None and coordinates_stack:
        secs = coordinates_stack[-1].time()
    converted: struct_time = _time_machine.original_localtime(secs)
    return converted
def strftime(format: str, t: _TimeTuple | struct_time | None = None) -> str:
    """Time-travel-aware stand-in for ``time.strftime``.

    An explicit ``t`` is always honoured. When ``t`` is omitted during travel,
    the travelled ``localtime()`` is formatted instead of the real clock.
    """
    if t is None and coordinates_stack:
        t = localtime()
    formatted: str
    if t is None:
        formatted = _time_machine.original_strftime(format)
    else:
        formatted = _time_machine.original_strftime(format, t)
    return formatted
def time() -> float:
    """Time-travel-aware stand-in for ``time.time``: travelled seconds if active."""
    if coordinates_stack:
        return coordinates_stack[-1].time()
    untravelled: float = _time_machine.original_time()
    return untravelled
def time_ns() -> int:
    """Time-travel-aware stand-in for ``time.time_ns``: travelled ns if active."""
    if coordinates_stack:
        return coordinates_stack[-1].time_ns()
    untravelled: int = _time_machine.original_time_ns()
    return untravelled
# pytest plugin

if HAVE_PYTEST:  # pragma: no branch

    class TimeMachineFixture:
        """Helper object provided by the ``time_machine`` pytest fixture.

        The first ``move_to()`` starts a single ``travel``. Later calls
        reposition that same travel.
        """

        traveller: travel | None
        coordinates: Coordinates | None

        def __init__(self) -> None:
            self.traveller = None
            self.coordinates = None

        def move_to(
            self,
            destination: DestinationType,
            tick: bool | None = None,
        ) -> None:
            """Travel to ``destination``, starting a travel if none is active.

            ``tick=None`` means "ticking" when a new travel is started and
            "unchanged" when an existing travel is repositioned.
            """
            if self.traveller is None:
                if tick is None:
                    tick = True
                self.traveller = travel(destination, tick=tick)
                self.coordinates = self.traveller.start()
            else:
                assert self.coordinates is not None
                self.coordinates.move_to(destination, tick=tick)

        def shift(self, delta: dt.timedelta | int | float) -> None:
            """Move the active travel by ``delta`` (a timedelta, or seconds).

            Raises:
                RuntimeError: if no travel has been started with ``move_to()``.
            """
            if self.traveller is None:
                raise RuntimeError(
                    "Initialize time_machine with move_to() before using shift()."
                )

            assert self.coordinates is not None
            self.coordinates.shift(delta=delta)

        def stop(self) -> None:
            """Stop the active travel, if any. Safe to call more than once."""
            traveller = self.traveller
            if traveller is None:
                return
            # Forget the travel before stopping it. A repeated stop(), such as a
            # manual call followed by fixture teardown, must not pop another
            # travel off the global coordinates stack.
            self.traveller = None
            self.coordinates = None
            traveller.stop()

    @pytest.fixture(name="time_machine")
    def time_machine_fixture() -> TypingGenerator[TimeMachineFixture, None, None]:
        """Yield a ``TimeMachineFixture`` and stop its travel at teardown."""
        fixture = TimeMachineFixture()
        yield fixture
        fixture.stop()
# escape hatch


class _EscapeHatchDatetimeDatetime:
    """Travel-ignoring counterparts of ``datetime.datetime.now``/``utcnow``.

    Both methods delegate directly to the ``_time_machine.original_*``
    functions, which the patched module-level versions use only when no
    travel is active.
    """

    def utcnow(self) -> dt.datetime:
        real_utcnow: dt.datetime = _time_machine.original_utcnow()
        return real_utcnow

    def now(self, tz: dt.tzinfo | None = None) -> dt.datetime:
        real_now: dt.datetime = _time_machine.original_now(tz)
        return real_now
class _EscapeHatchDatetime:
    """Escape-hatch mirror of the ``datetime`` module namespace."""

    def __init__(self) -> None:
        # Exposed as ``.datetime`` so calls read like ``datetime.datetime.now()``.
        self.datetime = _EscapeHatchDatetimeDatetime()
class _EscapeHatchTime:
    """Travel-ignoring counterparts of the ``time`` functions.

    Each method delegates directly to the matching
    ``_time_machine.original_*`` function, whether or not a travel is active.
    """

    def time(self) -> float:
        real: float = _time_machine.original_time()
        return real

    def time_ns(self) -> int:
        real: int = _time_machine.original_time_ns()
        return real

    def monotonic(self) -> float:
        real: float = _time_machine.original_monotonic()
        return real

    def monotonic_ns(self) -> int:
        real: int = _time_machine.original_monotonic_ns()
        return real

    def clock_gettime(self, clk_id: int) -> float:
        real: float = _time_machine.original_clock_gettime(clk_id)
        return real

    def clock_gettime_ns(self, clk_id: int) -> int:
        real: int = _time_machine.original_clock_gettime_ns(clk_id)
        return real

    def gmtime(self, secs: float | None = None) -> struct_time:
        real: struct_time = _time_machine.original_gmtime(secs)
        return real

    def localtime(self, secs: float | None = None) -> struct_time:
        real: struct_time = _time_machine.original_localtime(secs)
        return real

    def strftime(self, format: str, t: _TimeTuple | struct_time | None = None) -> str:
        # Only forward ``t`` when given, so the original reads the real clock.
        call_args = (format,) if t is None else (format, t)
        formatted: str = _time_machine.original_strftime(*call_args)
        return formatted
class _EscapeHatch:
    """Access to travel-ignoring ``time``/``datetime`` calls and travel status."""

    def __init__(self) -> None:
        self.time = _EscapeHatchTime()
        self.datetime = _EscapeHatchDatetime()

    def is_travelling(self) -> bool:
        """Return True while at least one travel is active."""
        return len(coordinates_stack) > 0


escape_hatch = _EscapeHatch()
|