_asyncio.py 91 KB


  1. from __future__ import annotations
  2. import array
  3. import asyncio
  4. import concurrent.futures
  5. import contextvars
  6. import math
  7. import os
  8. import socket
  9. import sys
  10. import threading
  11. import weakref
  12. from asyncio import (
  13. AbstractEventLoop,
  14. CancelledError,
  15. all_tasks,
  16. create_task,
  17. current_task,
  18. get_running_loop,
  19. sleep,
  20. )
  21. from asyncio.base_events import _run_until_complete_cb # type: ignore[attr-defined]
  22. from collections import OrderedDict, deque
  23. from collections.abc import (
  24. AsyncGenerator,
  25. AsyncIterator,
  26. Awaitable,
  27. Callable,
  28. Collection,
  29. Coroutine,
  30. Iterable,
  31. Sequence,
  32. )
  33. from concurrent.futures import Future
  34. from contextlib import AbstractContextManager, suppress
  35. from contextvars import Context, copy_context
  36. from dataclasses import dataclass
  37. from functools import partial, wraps
  38. from inspect import (
  39. CORO_RUNNING,
  40. CORO_SUSPENDED,
  41. getcoroutinestate,
  42. iscoroutine,
  43. )
  44. from io import IOBase
  45. from os import PathLike
  46. from queue import Queue
  47. from signal import Signals
  48. from socket import AddressFamily, SocketKind
  49. from threading import Thread
  50. from types import CodeType, TracebackType
  51. from typing import (
  52. IO,
  53. TYPE_CHECKING,
  54. Any,
  55. Optional,
  56. TypeVar,
  57. cast,
  58. )
  59. from weakref import WeakKeyDictionary
  60. import sniffio
  61. from .. import (
  62. CapacityLimiterStatistics,
  63. EventStatistics,
  64. LockStatistics,
  65. TaskInfo,
  66. abc,
  67. )
  68. from .._core._eventloop import claim_worker_thread, threadlocals
  69. from .._core._exceptions import (
  70. BrokenResourceError,
  71. BusyResourceError,
  72. ClosedResourceError,
  73. EndOfStream,
  74. WouldBlock,
  75. iterate_exceptions,
  76. )
  77. from .._core._sockets import convert_ipv6_sockaddr
  78. from .._core._streams import create_memory_object_stream
  79. from .._core._synchronization import (
  80. CapacityLimiter as BaseCapacityLimiter,
  81. )
  82. from .._core._synchronization import Event as BaseEvent
  83. from .._core._synchronization import Lock as BaseLock
  84. from .._core._synchronization import (
  85. ResourceGuard,
  86. SemaphoreStatistics,
  87. )
  88. from .._core._synchronization import Semaphore as BaseSemaphore
  89. from .._core._tasks import CancelScope as BaseCancelScope
  90. from ..abc import (
  91. AsyncBackend,
  92. IPSockAddrType,
  93. SocketListener,
  94. UDPPacketType,
  95. UNIXDatagramPacketType,
  96. )
  97. from ..abc._eventloop import StrOrBytesPath
  98. from ..lowlevel import RunVar
  99. from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
  100. if TYPE_CHECKING:
  101. from _typeshed import FileDescriptorLike
  102. else:
  103. FileDescriptorLike = object
  104. if sys.version_info >= (3, 10):
  105. from typing import ParamSpec
  106. else:
  107. from typing_extensions import ParamSpec
# On Python 3.11+ asyncio.Runner and typing's TypeVarTuple/Unpack are used directly.
# On older versions a copy of CPython 3.11's Runner (plus the helpers it depends on)
# is defined here, and the missing names come from the exceptiongroup and
# typing_extensions backport packages.
if sys.version_info >= (3, 11):
    from asyncio import Runner
    from typing import TypeVarTuple, Unpack
else:
    import contextvars
    import enum
    import signal
    from asyncio import coroutines, events, exceptions, tasks

    from exceptiongroup import BaseExceptionGroup
    from typing_extensions import TypeVarTuple, Unpack

    class _State(enum.Enum):
        """Lifecycle states of a :class:`Runner`."""

        CREATED = "created"
        INITIALIZED = "initialized"
        CLOSED = "closed"

    class Runner:
        """
        Backport of :class:`asyncio.Runner` for Python versions before 3.11.

        The event loop and a copied :class:`contextvars.Context` are created lazily
        on first use. :meth:`run` executes coroutines in them, and :meth:`close`
        (also called when a ``with`` block exits) cancels leftover tasks, shuts
        down async generators and the default executor, and closes the loop.
        """

        # Copied from CPython 3.11
        def __init__(
            self,
            *,
            debug: bool | None = None,
            loop_factory: Callable[[], AbstractEventLoop] | None = None,
        ):
            self._state = _State.CREATED
            self._debug = debug
            self._loop_factory = loop_factory
            self._loop: AbstractEventLoop | None = None
            self._context = None
            # Number of SIGINTs received during the current run() call
            self._interrupt_count = 0
            # Whether this runner installed its loop as the current event loop
            self._set_event_loop = False

        def __enter__(self) -> Runner:
            self._lazy_init()
            return self

        def __exit__(
            self,
            exc_type: type[BaseException],
            exc_val: BaseException,
            exc_tb: TracebackType,
        ) -> None:
            self.close()

        def close(self) -> None:
            """Shutdown and close event loop."""
            # No-op if the loop was never created or has already been closed
            if self._state is not _State.INITIALIZED:
                return
            try:
                loop = self._loop
                _cancel_all_tasks(loop)
                loop.run_until_complete(loop.shutdown_asyncgens())
                # Fall back to the local helper for loops that do not provide
                # shutdown_default_executor()
                if hasattr(loop, "shutdown_default_executor"):
                    loop.run_until_complete(loop.shutdown_default_executor())
                else:
                    loop.run_until_complete(_shutdown_default_executor(loop))
            finally:
                if self._set_event_loop:
                    events.set_event_loop(None)
                loop.close()
                self._loop = None
                self._state = _State.CLOSED

        def get_loop(self) -> AbstractEventLoop:
            """Return embedded event loop."""
            self._lazy_init()
            return self._loop

        def run(self, coro: Coroutine[T_Retval], *, context=None) -> T_Retval:
            """Run a coroutine inside the embedded event loop."""
            if not coroutines.iscoroutine(coro):
                raise ValueError(f"a coroutine was expected, got {coro!r}")

            if events._get_running_loop() is not None:
                # fail fast with short traceback
                raise RuntimeError(
                    "Runner.run() cannot be called from a running event loop"
                )

            self._lazy_init()

            if context is None:
                context = self._context
            task = context.run(self._loop.create_task, coro)

            # Route SIGINT to _on_sigint() (which cancels the main task), but only
            # from the main thread and only if the default handler is installed
            if (
                threading.current_thread() is threading.main_thread()
                and signal.getsignal(signal.SIGINT) is signal.default_int_handler
            ):
                sigint_handler = partial(self._on_sigint, main_task=task)
                try:
                    signal.signal(signal.SIGINT, sigint_handler)
                except ValueError:
                    # `signal.signal` may throw if `threading.main_thread` does
                    # not support signals (e.g. embedded interpreter with signals
                    # not registered - see gh-91880)
                    sigint_handler = None
            else:
                sigint_handler = None

            self._interrupt_count = 0
            try:
                return self._loop.run_until_complete(task)
            except exceptions.CancelledError:
                # Turn a SIGINT-triggered cancellation into KeyboardInterrupt when
                # no other cancellation requests remain on the task.
                # NOTE(review): Task.uncancel() only exists on Python 3.11+, so on
                # the versions this backport targets `uncancel` is None and the
                # CancelledError is re-raised instead — confirm this is intended.
                if self._interrupt_count > 0:
                    uncancel = getattr(task, "uncancel", None)
                    if uncancel is not None and uncancel() == 0:
                        raise KeyboardInterrupt()
                raise  # CancelledError
            finally:
                # Restore the default handler, but only if ours is still the
                # installed one
                if (
                    sigint_handler is not None
                    and signal.getsignal(signal.SIGINT) is sigint_handler
                ):
                    signal.signal(signal.SIGINT, signal.default_int_handler)

        def _lazy_init(self) -> None:
            """Create the event loop and context on first use.

            :raises RuntimeError: if the runner has already been closed
            """
            if self._state is _State.CLOSED:
                raise RuntimeError("Runner is closed")
            if self._state is _State.INITIALIZED:
                return
            if self._loop_factory is None:
                self._loop = events.new_event_loop()
                if not self._set_event_loop:
                    # Call set_event_loop only once to avoid calling
                    # attach_loop multiple times on child watchers
                    events.set_event_loop(self._loop)
                    self._set_event_loop = True
            else:
                self._loop = self._loop_factory()
            if self._debug is not None:
                self._loop.set_debug(self._debug)
            self._context = contextvars.copy_context()
            self._state = _State.INITIALIZED

        def _on_sigint(self, signum, frame, main_task: asyncio.Task) -> None:
            """SIGINT handler: cancel the main task on the first interrupt, and
            raise KeyboardInterrupt on later ones (or if the task is done)."""
            self._interrupt_count += 1
            if self._interrupt_count == 1 and not main_task.done():
                main_task.cancel()
                # wakeup loop if it is blocked by select() with long timeout
                self._loop.call_soon_threadsafe(lambda: None)
                return
            raise KeyboardInterrupt()

    def _cancel_all_tasks(loop: AbstractEventLoop) -> None:
        """Cancel all tasks of *loop*, wait for them to finish, and report any
        non-cancellation exceptions they ended with to the loop's exception
        handler."""
        to_cancel = tasks.all_tasks(loop)
        if not to_cancel:
            return

        for task in to_cancel:
            task.cancel()

        loop.run_until_complete(tasks.gather(*to_cancel, return_exceptions=True))

        for task in to_cancel:
            if task.cancelled():
                continue
            if task.exception() is not None:
                loop.call_exception_handler(
                    {
                        "message": "unhandled exception during asyncio.run() shutdown",
                        "exception": task.exception(),
                        "task": task,
                    }
                )

    async def _shutdown_default_executor(loop: AbstractEventLoop) -> None:
        """Schedule the shutdown of the default executor."""

        def _do_shutdown(future: asyncio.futures.Future) -> None:
            # Runs in a helper thread so that the blocking shutdown(wait=True)
            # does not block the event loop; the outcome is reported back
            # thread-safely
            try:
                loop._default_executor.shutdown(wait=True)  # type: ignore[attr-defined]
                loop.call_soon_threadsafe(future.set_result, None)
            except Exception as ex:
                loop.call_soon_threadsafe(future.set_exception, ex)

        loop._executor_shutdown_called = True
        if loop._default_executor is None:
            return
        future = loop.create_future()
        thread = threading.Thread(target=_do_shutdown, args=(future,))
        thread.start()
        try:
            await future
        finally:
            thread.join()
  273. T_Retval = TypeVar("T_Retval")
  274. T_contra = TypeVar("T_contra", contravariant=True)
  275. PosArgsT = TypeVarTuple("PosArgsT")
  276. P = ParamSpec("P")
  277. _root_task: RunVar[asyncio.Task | None] = RunVar("_root_task")
  278. def find_root_task() -> asyncio.Task:
  279. root_task = _root_task.get(None)
  280. if root_task is not None and not root_task.done():
  281. return root_task
  282. # Look for a task that has been started via run_until_complete()
  283. for task in all_tasks():
  284. if task._callbacks and not task.done():
  285. callbacks = [cb for cb, context in task._callbacks]
  286. for cb in callbacks:
  287. if (
  288. cb is _run_until_complete_cb
  289. or getattr(cb, "__module__", None) == "uvloop.loop"
  290. ):
  291. _root_task.set(task)
  292. return task
  293. # Look up the topmost task in the AnyIO task tree, if possible
  294. task = cast(asyncio.Task, current_task())
  295. state = _task_states.get(task)
  296. if state:
  297. cancel_scope = state.cancel_scope
  298. while cancel_scope and cancel_scope._parent_scope is not None:
  299. cancel_scope = cancel_scope._parent_scope
  300. if cancel_scope is not None:
  301. return cast(asyncio.Task, cancel_scope._host_task)
  302. return task
  303. def get_callable_name(func: Callable) -> str:
  304. module = getattr(func, "__module__", None)
  305. qualname = getattr(func, "__qualname__", None)
  306. return ".".join([x for x in (module, qualname) if x])
  307. #
  308. # Event loop
  309. #
  310. _run_vars: WeakKeyDictionary[asyncio.AbstractEventLoop, Any] = WeakKeyDictionary()
  311. def _task_started(task: asyncio.Task) -> bool:
  312. """Return ``True`` if the task has been started and has not finished."""
  313. # The task coro should never be None here, as we never add finished tasks to the
  314. # task list
  315. coro = task.get_coro()
  316. assert coro is not None
  317. try:
  318. return getcoroutinestate(coro) in (CORO_RUNNING, CORO_SUSPENDED)
  319. except AttributeError:
  320. # task coro is async_genenerator_asend https://bugs.python.org/issue37771
  321. raise Exception(f"Cannot determine if task {task} has started or not") from None
  322. #
  323. # Timeouts and cancellation
  324. #
  325. def is_anyio_cancellation(exc: CancelledError) -> bool:
  326. # Sometimes third party frameworks catch a CancelledError and raise a new one, so as
  327. # a workaround we have to look at the previous ones in __context__ too for a
  328. # matching cancel message
  329. while True:
  330. if (
  331. exc.args
  332. and isinstance(exc.args[0], str)
  333. and exc.args[0].startswith("Cancelled by cancel scope ")
  334. ):
  335. return True
  336. if isinstance(exc.__context__, CancelledError):
  337. exc = exc.__context__
  338. continue
  339. return False
class CancelScope(BaseCancelScope):
    """
    asyncio implementation of AnyIO's cancel scope.

    A scope is entered with ``with`` in one host task and nests under that task's
    previous scope. Cancelling it cancels the tasks in ``_tasks`` and, recursively,
    the tasks of its unshielded child scopes. ``_tasks`` holds the host task plus
    any tasks whose innermost scope this is, such as task group children.
    Delivery is rescheduled with ``call_soon()`` as long as the scope still
    contains tasks.

    On exit, the exception is swallowed only if the scope was cancelled, no
    cancelled parent scope is visible, and every exception in it is an AnyIO
    cancel scope ``CancelledError``.
    """

    def __new__(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        # Plain object allocation; NOTE(review): presumably this bypasses a
        # backend-dispatching __new__ on the base class — confirm
        return object.__new__(cls)

    def __init__(self, deadline: float = math.inf, shield: bool = False):
        self._deadline = deadline
        self._shield = shield
        self._parent_scope: CancelScope | None = None
        self._child_scopes: set[CancelScope] = set()
        self._cancel_called = False
        self._cancelled_caught = False
        self._active = False
        # Pending call_at() handle that fires when the deadline is reached
        self._timeout_handle: asyncio.TimerHandle | None = None
        # Pending call_soon() handle for the next cancellation delivery attempt
        self._cancel_handle: asyncio.Handle | None = None
        self._tasks: set[asyncio.Task] = set()
        self._host_task: asyncio.Task | None = None
        # Number of Task.cancel() calls made on the host task on behalf of this
        # scope, to be undone with Task.uncancel() on exit. None before Python 3.11,
        # where Task.uncancel() does not exist.
        if sys.version_info >= (3, 11):
            self._pending_uncancellations: int | None = 0
        else:
            self._pending_uncancellations = None

    def __enter__(self) -> CancelScope:
        """Activate the scope in the current task, nesting it under that task's
        current cancel scope (if any).

        :raises RuntimeError: if the scope has already been entered
        """
        if self._active:
            raise RuntimeError(
                "Each CancelScope may only be used for a single 'with' block"
            )

        self._host_task = host_task = cast(asyncio.Task, current_task())
        self._tasks.add(host_task)
        try:
            task_state = _task_states[host_task]
        except KeyError:
            # First AnyIO scope in this task: start tracking the task
            task_state = TaskState(None, self)
            _task_states[host_task] = task_state
        else:
            self._parent_scope = task_state.cancel_scope
            task_state.cancel_scope = self
            if self._parent_scope is not None:
                # If using an eager task factory, the parent scope may not even contain
                # the host task
                self._parent_scope._child_scopes.add(self)
                self._parent_scope._tasks.discard(host_task)

        self._timeout()
        self._active = True

        # Start cancelling the host task if the scope was cancelled before entering
        if self._cancel_called:
            self._deliver_cancellation(self)

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool:
        """Deactivate the scope and hand the host task back to the parent scope.

        :return: ``True`` (swallow the exception) only if this scope was cancelled,
            no cancelled parent scope is visible, and every exception in
            ``exc_val`` is an AnyIO cancel scope ``CancelledError``
        :raises RuntimeError: if the scope is not active, is exited from a
            different task, or is not the task's current cancel scope
        """
        del exc_tb

        if not self._active:
            raise RuntimeError("This cancel scope is not active")
        if current_task() is not self._host_task:
            raise RuntimeError(
                "Attempted to exit cancel scope in a different task than it was "
                "entered in"
            )

        assert self._host_task is not None
        host_task_state = _task_states.get(self._host_task)
        if host_task_state is None or host_task_state.cancel_scope is not self:
            raise RuntimeError(
                "Attempted to exit a cancel scope that isn't the current tasks's "
                "current cancel scope"
            )

        try:
            self._active = False
            if self._timeout_handle:
                self._timeout_handle.cancel()
                self._timeout_handle = None

            # Move the host task back into the parent scope
            self._tasks.remove(self._host_task)
            if self._parent_scope is not None:
                self._parent_scope._child_scopes.remove(self)
                self._parent_scope._tasks.add(self._host_task)

            host_task_state.cancel_scope = self._parent_scope

            # Restart the cancellation effort in the closest visible, cancelled parent
            # scope if necessary
            self._restart_cancellation_in_parent()

            # We only swallow the exception iff it was an AnyIO CancelledError, either
            # directly as exc_val or inside an exception group and there are no cancelled
            # parent cancel scopes visible to us here
            if self._cancel_called and not self._parent_cancellation_is_visible_to_us:
                # For each level-cancel() call made on the host task, call uncancel()
                while self._pending_uncancellations:
                    self._host_task.uncancel()
                    self._pending_uncancellations -= 1

                # Update cancelled_caught and check for exceptions we must not swallow
                cannot_swallow_exc_val = False
                if exc_val is not None:
                    for exc in iterate_exceptions(exc_val):
                        if isinstance(exc, CancelledError) and is_anyio_cancellation(
                            exc
                        ):
                            self._cancelled_caught = True
                        else:
                            cannot_swallow_exc_val = True

                return self._cancelled_caught and not cannot_swallow_exc_val
            else:
                # The cancellation propagates to the parent, so hand over the
                # uncancel() bookkeeping for the host task to the parent as well
                if self._pending_uncancellations:
                    assert self._parent_scope is not None
                    assert self._parent_scope._pending_uncancellations is not None
                    self._parent_scope._pending_uncancellations += (
                        self._pending_uncancellations
                    )
                    self._pending_uncancellations = 0

                return False
        finally:
            self._host_task = None
            del exc_val

    @property
    def _effectively_cancelled(self) -> bool:
        """``True`` if this scope, or an ancestor reached before passing the first
        shielded scope (inclusive), has had :meth:`cancel` called."""
        cancel_scope: CancelScope | None = self
        while cancel_scope is not None:
            if cancel_scope._cancel_called:
                return True

            if cancel_scope.shield:
                return False

            cancel_scope = cancel_scope._parent_scope

        return False

    @property
    def _parent_cancellation_is_visible_to_us(self) -> bool:
        """``True`` if this scope is unshielded and its parent is effectively
        cancelled."""
        return (
            self._parent_scope is not None
            and not self.shield
            and self._parent_scope._effectively_cancelled
        )

    def _timeout(self) -> None:
        """Cancel now if the deadline has passed, otherwise schedule this method to
        run again at the deadline (no-op for an infinite deadline)."""
        if self._deadline != math.inf:
            loop = get_running_loop()
            if loop.time() >= self._deadline:
                self.cancel()
            else:
                self._timeout_handle = loop.call_at(self._deadline, self._timeout)

    def _deliver_cancellation(self, origin: CancelScope) -> bool:
        """
        Deliver cancellation to directly contained tasks and nested cancel scopes.

        Schedule another run at the end if we still have tasks eligible for
        cancellation.

        :param origin: the cancel scope that originated the cancellation
        :return: ``True`` if the delivery needs to be retried on the next cycle
        """
        should_retry = False
        current = current_task()
        for task in self._tasks:
            # Any task still in the scope means another attempt is needed
            should_retry = True
            # A cancellation request is already pending on this task
            if task._must_cancel:  # type: ignore[attr-defined]
                continue

            # The task is eligible for cancellation if it has started
            if task is not current and (task is self._host_task or _task_started(task)):
                waiter = task._fut_waiter  # type: ignore[attr-defined]
                # Skip tasks whose awaited future is already done.
                # NOTE(review): presumably this avoids discarding a result that has
                # already arrived; the retry catches the task at its next await.
                if not isinstance(waiter, asyncio.Future) or not waiter.done():
                    task.cancel(f"Cancelled by cancel scope {id(origin):x}")
                    if (
                        task is origin._host_task
                        and origin._pending_uncancellations is not None
                    ):
                        origin._pending_uncancellations += 1

        # Deliver cancellation to child scopes that aren't shielded or running their own
        # cancellation callbacks
        for scope in self._child_scopes:
            if not scope._shield and not scope.cancel_called:
                should_retry = scope._deliver_cancellation(origin) or should_retry

        # Schedule another callback if there are still tasks left
        if origin is self:
            if should_retry:
                self._cancel_handle = get_running_loop().call_soon(
                    self._deliver_cancellation, origin
                )
            else:
                self._cancel_handle = None

        return should_retry

    def _restart_cancellation_in_parent(self) -> None:
        """
        Restart the cancellation effort in the closest directly cancelled parent scope.
        """
        scope = self._parent_scope
        while scope is not None:
            if scope._cancel_called:
                # Only restart if no delivery attempt is already scheduled
                if scope._cancel_handle is None:
                    scope._deliver_cancellation(scope)

                break

            # No point in looking beyond any shielded scope
            if scope._shield:
                break

            scope = scope._parent_scope

    def cancel(self) -> None:
        """Cancel this scope (idempotent). Delivery starts immediately if the scope
        is active; otherwise it starts when the scope is entered."""
        if not self._cancel_called:
            if self._timeout_handle:
                self._timeout_handle.cancel()
                self._timeout_handle = None

            self._cancel_called = True
            if self._host_task is not None:
                self._deliver_cancellation(self)

    @property
    def deadline(self) -> float:
        return self._deadline

    @deadline.setter
    def deadline(self, value: float) -> None:
        self._deadline = float(value)
        # Replace any pending deadline timer with one for the new deadline
        if self._timeout_handle is not None:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        if self._active and not self._cancel_called:
            self._timeout()

    @property
    def cancel_called(self) -> bool:
        return self._cancel_called

    @property
    def cancelled_caught(self) -> bool:
        return self._cancelled_caught

    @property
    def shield(self) -> bool:
        return self._shield

    @shield.setter
    def shield(self, value: bool) -> None:
        if self._shield != value:
            self._shield = value
            # Unshielding may expose this scope to a pending parent cancellation
            if not value:
                self._restart_cancellation_in_parent()
  562. #
  563. # Task states
  564. #
  565. class TaskState:
  566. """
  567. Encapsulates auxiliary task information that cannot be added to the Task instance
  568. itself because there are no guarantees about its implementation.
  569. """
  570. __slots__ = "parent_id", "cancel_scope", "__weakref__"
  571. def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
  572. self.parent_id = parent_id
  573. self.cancel_scope = cancel_scope
  574. _task_states: WeakKeyDictionary[asyncio.Task, TaskState] = WeakKeyDictionary()
  575. #
  576. # Task groups
  577. #
  578. class _AsyncioTaskStatus(abc.TaskStatus):
  579. def __init__(self, future: asyncio.Future, parent_id: int):
  580. self._future = future
  581. self._parent_id = parent_id
  582. def started(self, value: T_contra | None = None) -> None:
  583. try:
  584. self._future.set_result(value)
  585. except asyncio.InvalidStateError:
  586. if not self._future.cancelled():
  587. raise RuntimeError(
  588. "called 'started' twice on the same task status"
  589. ) from None
  590. task = cast(asyncio.Task, current_task())
  591. _task_states[task].parent_id = self._parent_id
# Code object of asyncio.eager_task_factory (Python 3.12+), or None on older versions
# where it does not exist. TaskGroup._spawn compares a loop's task factory's __code__
# against it to recognize the eager task factory.
if sys.version_info >= (3, 12):
    _eager_task_factory_code: CodeType | None = asyncio.eager_task_factory.__code__
else:
    _eager_task_factory_code = None
class TaskGroup(abc.TaskGroup):
    """
    asyncio implementation of AnyIO's task group.

    Child tasks are tied to the group's ``cancel_scope``. The first child to fail
    with a non-cancellation exception cancels the whole group. ``__aexit__`` waits
    for all children and raises the collected exceptions as a
    ``BaseExceptionGroup``.
    """

    def __init__(self) -> None:
        self.cancel_scope: CancelScope = CancelScope()
        self._active = False
        self._exceptions: list[BaseException] = []
        self._tasks: set[asyncio.Task] = set()
        # Future that task_done() resolves when the last child finishes while
        # __aexit__ is waiting
        self._on_completed_fut: asyncio.Future[None] | None = None

    async def __aenter__(self) -> TaskGroup:
        self.cancel_scope.__enter__()
        self._active = True
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """Wait for all child tasks, then raise the collected exceptions (if any)
        through the group's cancel scope, which decides whether to swallow them."""
        try:
            # An exception in the host task cancels the group; cancellations are not
            # recorded as errors
            if exc_val is not None:
                self.cancel_scope.cancel()
                if not isinstance(exc_val, CancelledError):
                    self._exceptions.append(exc_val)

            loop = get_running_loop()
            try:
                if self._tasks:
                    with CancelScope() as wait_scope:
                        while self._tasks:
                            self._on_completed_fut = loop.create_future()

                            try:
                                await self._on_completed_fut
                            except CancelledError as exc:
                                # Shield the scope against further cancellation attempts,
                                # as they're not productive (#695)
                                wait_scope.shield = True
                                self.cancel_scope.cancel()

                                # Set exc_val from the cancellation exception if it was
                                # previously unset. However, we should not replace a native
                                # cancellation exception with one raise by a cancel scope.
                                if exc_val is None or (
                                    isinstance(exc_val, CancelledError)
                                    and not is_anyio_cancellation(exc)
                                ):
                                    exc_val = exc

                            self._on_completed_fut = None
                else:
                    # If there are no child tasks to wait on, run at least one checkpoint
                    # anyway
                    await AsyncIOBackend.cancel_shielded_checkpoint()

                self._active = False
                if self._exceptions:
                    # The exception that got us here should already have been
                    # added to self._exceptions so it's ok to break exception
                    # chaining and avoid adding a "During handling of above..."
                    # for each nesting level.
                    raise BaseExceptionGroup(
                        "unhandled errors in a TaskGroup", self._exceptions
                    ) from None
                elif exc_val:
                    raise exc_val
            except BaseException as exc:
                if self.cancel_scope.__exit__(type(exc), exc, exc.__traceback__):
                    return True

                raise

            return self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
        finally:
            # Drop references to exceptions (and their tracebacks) held by this
            # frame and the group; NOTE(review): presumably to avoid reference
            # cycles — this also leaves the group unusable afterwards
            del exc_val, exc_tb, self._exceptions

    def _spawn(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
        args: tuple[Unpack[PosArgsT]],
        name: object,
        task_status_future: asyncio.Future | None = None,
    ) -> asyncio.Task:
        """Create a child task running ``func(*args)`` inside the group's cancel
        scope.

        :param task_status_future: if given (from :meth:`start`), the child also
            receives a ``task_status`` keyword argument that resolves this future
        :raises RuntimeError: if the task group is not active
        :raises TypeError: if *func* does not return a coroutine
        """

        def task_done(_task: asyncio.Task) -> None:
            # Done callback: unregister the child, wake up __aexit__ when the last
            # child finishes, and route the child's outcome
            task_state = _task_states[_task]
            assert task_state.cancel_scope is not None
            assert _task in task_state.cancel_scope._tasks
            task_state.cancel_scope._tasks.remove(_task)
            # NOTE(review): uses the enclosing `task` variable rather than `_task`;
            # both refer to the same task because this callback is only registered
            # on `task` below
            self._tasks.remove(task)
            del _task_states[_task]

            if self._on_completed_fut is not None and not self._tasks:
                try:
                    self._on_completed_fut.set_result(None)
                except asyncio.InvalidStateError:
                    pass

            try:
                exc = _task.exception()
            except CancelledError as e:
                # Use the innermost CancelledError of the __context__ chain
                while isinstance(e.__context__, CancelledError):
                    e = e.__context__

                exc = e

            if exc is not None:
                # The future can only be in the cancelled state if the host task was
                # cancelled, so return immediately instead of adding one more
                # CancelledError to the exceptions list
                if task_status_future is not None and task_status_future.cancelled():
                    return

                # After started() (or for start_soon() children) the error belongs
                # to the group; before started() it is reported to start()'s caller
                if task_status_future is None or task_status_future.done():
                    if not isinstance(exc, CancelledError):
                        self._exceptions.append(exc)

                    if not self.cancel_scope._effectively_cancelled:
                        self.cancel_scope.cancel()
                else:
                    task_status_future.set_exception(exc)
            elif task_status_future is not None and not task_status_future.done():
                task_status_future.set_exception(
                    RuntimeError("Child exited without calling task_status.started()")
                )

        if not self._active:
            raise RuntimeError(
                "This task group is not active; no new tasks can be started."
            )

        kwargs = {}
        if task_status_future:
            # Until started() is called the child reports the calling task as its
            # parent; _AsyncioTaskStatus.started() then switches it to the group's
            # host task
            parent_id = id(current_task())
            kwargs["task_status"] = _AsyncioTaskStatus(
                task_status_future, id(self.cancel_scope._host_task)
            )
        else:
            parent_id = id(self.cancel_scope._host_task)

        coro = func(*args, **kwargs)
        if not iscoroutine(coro):
            prefix = f"{func.__module__}." if hasattr(func, "__module__") else ""
            raise TypeError(
                f"Expected {prefix}{func.__qualname__}() to return a coroutine, but "
                f"the return value ({coro!r}) is not a coroutine object"
            )

        name = get_callable_name(func) if name is None else str(name)
        loop = asyncio.get_running_loop()
        # If the loop's task factory is the eager task factory (recognized by its
        # code object), call the task constructor captured in its closure directly
        # instead of going through the factory.
        # NOTE(review): presumably this keeps the task from starting eagerly before
        # the TaskState below is registered — confirm.
        if (
            (factory := loop.get_task_factory())
            and getattr(factory, "__code__", None) is _eager_task_factory_code
            and (closure := getattr(factory, "__closure__", None))
        ):
            custom_task_constructor = closure[0].cell_contents
            task = custom_task_constructor(coro, loop=loop, name=name)
        else:
            task = create_task(coro, name=name)

        # Make the spawned task inherit the task group's cancel scope
        _task_states[task] = TaskState(
            parent_id=parent_id, cancel_scope=self.cancel_scope
        )
        self.cancel_scope._tasks.add(task)
        self._tasks.add(task)
        task.add_done_callback(task_done)
        return task

    def start_soon(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> None:
        self._spawn(func, args, name)

    async def start(
        self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
    ) -> Any:
        """Start a child task and wait until it calls ``task_status.started()``,
        returning the value it passed there."""
        future: asyncio.Future = asyncio.Future()
        task = self._spawn(func, args, name, future)

        # If the task raises an exception after sending a start value without a switch
        # point between, the task group is cancelled and this method never proceeds to
        # process the completed future. That's why we have to have a shielded cancel
        # scope here.
        try:
            return await future
        except CancelledError:
            # Cancel the task and wait for it to exit before returning
            task.cancel()
            with CancelScope(shield=True), suppress(CancelledError):
                await task

            raise
  766. #
  767. # Threads
  768. #
# Alias for a (result, exception) pair in which either element may be None.
# NOTE(review): not referenced in the visible part of this module. Confirm it is
# still used elsewhere.
_Retval_Queue_Type = tuple[Optional[T_Retval], Optional[BaseException]]
class WorkerThread(Thread):
    """
    Reusable thread that runs synchronous callables on behalf of an event loop.

    Work items arrive through ``queue`` as
    ``(context, func, args, future, cancel_scope)`` tuples. Each result is handed
    back to the event loop thread with ``call_soon_threadsafe()``. A ``None`` item
    makes the thread exit.
    """

    MAX_IDLE_TIME = 10  # seconds

    def __init__(
        self,
        root_task: asyncio.Task,
        workers: set[WorkerThread],
        idle_workers: deque[WorkerThread],
    ):
        super().__init__(name="AnyIO worker thread")
        self.root_task = root_task
        # Shared pool bookkeeping: all workers, and the idle ones ready for reuse
        self.workers = workers
        self.idle_workers = idle_workers
        self.loop = root_task._loop
        # NOTE(review): capacity 2 presumably leaves room for the stop sentinel
        # (put_nowait(None) in stop()) next to one pending work item — confirm
        self.queue: Queue[
            tuple[Context, Callable, tuple, asyncio.Future, CancelScope] | None
        ] = Queue(2)
        self.idle_since = AsyncIOBackend.current_time()
        self.stopping = False

    def _report_result(
        self, future: asyncio.Future, result: Any, exc: BaseException | None
    ) -> None:
        """Runs in the event loop thread. Mark this worker idle again and resolve
        *future* with the outcome of the work item, unless it was cancelled."""
        self.idle_since = AsyncIOBackend.current_time()
        if not self.stopping:
            self.idle_workers.append(self)

        if not future.cancelled():
            if exc is not None:
                # asyncio futures refuse StopIteration in set_exception(), so wrap
                # it in a RuntimeError
                if isinstance(exc, StopIteration):
                    new_exc = RuntimeError("coroutine raised StopIteration")
                    new_exc.__cause__ = exc
                    exc = new_exc

                future.set_exception(exc)
            else:
                future.set_result(result)

    def run(self) -> None:
        """Thread body: process work items until the ``None`` sentinel arrives."""
        with claim_worker_thread(AsyncIOBackend, self.loop):
            while True:
                item = self.queue.get()
                if item is None:
                    # Shutdown command received
                    return

                context, func, args, future, cancel_scope = item
                # Skip work whose caller has already given up
                if not future.cancelled():
                    result = None
                    exception: BaseException | None = None
                    # Expose the caller's cancel scope to code running in this
                    # thread for the duration of the call
                    threadlocals.current_cancel_scope = cancel_scope
                    try:
                        result = context.run(func, *args)
                    except BaseException as exc:
                        exception = exc
                    finally:
                        del threadlocals.current_cancel_scope

                    if not self.loop.is_closed():
                        self.loop.call_soon_threadsafe(
                            self._report_result, future, result, exception
                        )

                    del result, exception

                self.queue.task_done()
                # Release references to the finished work item before blocking on
                # the next queue.get()
                del item, context, func, args, future, cancel_scope

    def stop(self, f: asyncio.Task | None = None) -> None:
        """Ask the thread to exit and remove it from the pool bookkeeping.

        *f* is ignored. NOTE(review): presumably the parameter exists so that
        this method can serve as a task done callback — confirm.
        """
        self.stopping = True
        self.queue.put_nowait(None)
        self.workers.discard(self)
        try:
            self.idle_workers.remove(self)
        except ValueError:
            pass
  836. _threadpool_idle_workers: RunVar[deque[WorkerThread]] = RunVar(
  837. "_threadpool_idle_workers"
  838. )
  839. _threadpool_workers: RunVar[set[WorkerThread]] = RunVar("_threadpool_workers")
  840. class BlockingPortal(abc.BlockingPortal):
  841. def __new__(cls) -> BlockingPortal:
  842. return object.__new__(cls)
  843. def __init__(self) -> None:
  844. super().__init__()
  845. self._loop = get_running_loop()
  846. def _spawn_task_from_thread(
  847. self,
  848. func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
  849. args: tuple[Unpack[PosArgsT]],
  850. kwargs: dict[str, Any],
  851. name: object,
  852. future: Future[T_Retval],
  853. ) -> None:
  854. AsyncIOBackend.run_sync_from_thread(
  855. partial(self._task_group.start_soon, name=name),
  856. (self._call_func, func, args, kwargs, future),
  857. self._loop,
  858. )
  859. #
  860. # Subprocesses
  861. #
  862. @dataclass(eq=False)
  863. class StreamReaderWrapper(abc.ByteReceiveStream):
  864. _stream: asyncio.StreamReader
  865. async def receive(self, max_bytes: int = 65536) -> bytes:
  866. data = await self._stream.read(max_bytes)
  867. if data:
  868. return data
  869. else:
  870. raise EndOfStream
  871. async def aclose(self) -> None:
  872. self._stream.set_exception(ClosedResourceError())
  873. await AsyncIOBackend.checkpoint()
  874. @dataclass(eq=False)
  875. class StreamWriterWrapper(abc.ByteSendStream):
  876. _stream: asyncio.StreamWriter
  877. async def send(self, item: bytes) -> None:
  878. self._stream.write(item)
  879. await self._stream.drain()
  880. async def aclose(self) -> None:
  881. self._stream.close()
  882. await AsyncIOBackend.checkpoint()
  883. @dataclass(eq=False)
  884. class Process(abc.Process):
  885. _process: asyncio.subprocess.Process
  886. _stdin: StreamWriterWrapper | None
  887. _stdout: StreamReaderWrapper | None
  888. _stderr: StreamReaderWrapper | None
  889. async def aclose(self) -> None:
  890. with CancelScope(shield=True) as scope:
  891. if self._stdin:
  892. await self._stdin.aclose()
  893. if self._stdout:
  894. await self._stdout.aclose()
  895. if self._stderr:
  896. await self._stderr.aclose()
  897. scope.shield = False
  898. try:
  899. await self.wait()
  900. except BaseException:
  901. scope.shield = True
  902. self.kill()
  903. await self.wait()
  904. raise
  905. async def wait(self) -> int:
  906. return await self._process.wait()
  907. def terminate(self) -> None:
  908. self._process.terminate()
  909. def kill(self) -> None:
  910. self._process.kill()
  911. def send_signal(self, signal: int) -> None:
  912. self._process.send_signal(signal)
  913. @property
  914. def pid(self) -> int:
  915. return self._process.pid
  916. @property
  917. def returncode(self) -> int | None:
  918. return self._process.returncode
  919. @property
  920. def stdin(self) -> abc.ByteSendStream | None:
  921. return self._stdin
  922. @property
  923. def stdout(self) -> abc.ByteReceiveStream | None:
  924. return self._stdout
  925. @property
  926. def stderr(self) -> abc.ByteReceiveStream | None:
  927. return self._stderr
def _forcibly_shutdown_process_pool_on_exit(
    workers: set[Process], _task: object
) -> None:
    """
    Forcibly shuts down worker processes belonging to this event loop.

    Synchronous (no await) cleanup. The ignored ``_task`` argument presumably lets
    this be registered as a task done-callback (TODO confirm at the call site).

    For every worker whose ``returncode`` is already set, it:

    - closes the stdin/stdout/stderr pipe transports
    - calls ``kill()``
    - on Python < 3.12, if the event loop policy provides a child watcher,
      unregisters the child handler for the worker's pid
    """
    child_watcher: asyncio.AbstractChildWatcher | None = None
    if sys.version_info < (3, 12):
        # The policy may not implement child watchers; then none is used
        try:
            child_watcher = asyncio.get_event_loop_policy().get_child_watcher()
        except NotImplementedError:
            pass

    # Close as much as possible (w/o async/await) to avoid warnings
    for process in workers:
        # NOTE(review): this skips workers that are still running (returncode is
        # None) and only acts on ones that have already exited, which looks
        # inverted relative to "forcibly shuts down" -- confirm intent.
        if process.returncode is None:
            continue

        # Assumes every worker was spawned with all three pipes (the
        # union-attr ignores cover the Optional stream attributes)
        process._stdin._stream._transport.close()  # type: ignore[union-attr]
        process._stdout._stream._transport.close()  # type: ignore[union-attr]
        process._stderr._stream._transport.close()  # type: ignore[union-attr]
        process.kill()
        if child_watcher:
            child_watcher.remove_child_handler(process.pid)
  949. async def _shutdown_process_pool_on_exit(workers: set[abc.Process]) -> None:
  950. """
  951. Shuts down worker processes belonging to this event loop.
  952. NOTE: this only works when the event loop was started using asyncio.run() or
  953. anyio.run().
  954. """
  955. process: abc.Process
  956. try:
  957. await sleep(math.inf)
  958. except asyncio.CancelledError:
  959. for process in workers:
  960. if process.returncode is None:
  961. process.kill()
  962. for process in workers:
  963. await process.aclose()
  964. #
  965. # Sockets and networking
  966. #
  967. class StreamProtocol(asyncio.Protocol):
  968. read_queue: deque[bytes]
  969. read_event: asyncio.Event
  970. write_event: asyncio.Event
  971. exception: Exception | None = None
  972. is_at_eof: bool = False
  973. def connection_made(self, transport: asyncio.BaseTransport) -> None:
  974. self.read_queue = deque()
  975. self.read_event = asyncio.Event()
  976. self.write_event = asyncio.Event()
  977. self.write_event.set()
  978. cast(asyncio.Transport, transport).set_write_buffer_limits(0)
  979. def connection_lost(self, exc: Exception | None) -> None:
  980. if exc:
  981. self.exception = BrokenResourceError()
  982. self.exception.__cause__ = exc
  983. self.read_event.set()
  984. self.write_event.set()
  985. def data_received(self, data: bytes) -> None:
  986. # ProactorEventloop sometimes sends bytearray instead of bytes
  987. self.read_queue.append(bytes(data))
  988. self.read_event.set()
  989. def eof_received(self) -> bool | None:
  990. self.is_at_eof = True
  991. self.read_event.set()
  992. return True
  993. def pause_writing(self) -> None:
  994. self.write_event = asyncio.Event()
  995. def resume_writing(self) -> None:
  996. self.write_event.set()
  997. class DatagramProtocol(asyncio.DatagramProtocol):
  998. read_queue: deque[tuple[bytes, IPSockAddrType]]
  999. read_event: asyncio.Event
  1000. write_event: asyncio.Event
  1001. exception: Exception | None = None
  1002. def connection_made(self, transport: asyncio.BaseTransport) -> None:
  1003. self.read_queue = deque(maxlen=100) # arbitrary value
  1004. self.read_event = asyncio.Event()
  1005. self.write_event = asyncio.Event()
  1006. self.write_event.set()
  1007. def connection_lost(self, exc: Exception | None) -> None:
  1008. self.read_event.set()
  1009. self.write_event.set()
  1010. def datagram_received(self, data: bytes, addr: IPSockAddrType) -> None:
  1011. addr = convert_ipv6_sockaddr(addr)
  1012. self.read_queue.append((data, addr))
  1013. self.read_event.set()
  1014. def error_received(self, exc: Exception) -> None:
  1015. self.exception = exc
  1016. def pause_writing(self) -> None:
  1017. self.write_event.clear()
  1018. def resume_writing(self) -> None:
  1019. self.write_event.set()
  1020. class SocketStream(abc.SocketStream):
  1021. def __init__(self, transport: asyncio.Transport, protocol: StreamProtocol):
  1022. self._transport = transport
  1023. self._protocol = protocol
  1024. self._receive_guard = ResourceGuard("reading from")
  1025. self._send_guard = ResourceGuard("writing to")
  1026. self._closed = False
  1027. @property
  1028. def _raw_socket(self) -> socket.socket:
  1029. return self._transport.get_extra_info("socket")
  1030. async def receive(self, max_bytes: int = 65536) -> bytes:
  1031. with self._receive_guard:
  1032. if (
  1033. not self._protocol.read_event.is_set()
  1034. and not self._transport.is_closing()
  1035. and not self._protocol.is_at_eof
  1036. ):
  1037. self._transport.resume_reading()
  1038. await self._protocol.read_event.wait()
  1039. self._transport.pause_reading()
  1040. else:
  1041. await AsyncIOBackend.checkpoint()
  1042. try:
  1043. chunk = self._protocol.read_queue.popleft()
  1044. except IndexError:
  1045. if self._closed:
  1046. raise ClosedResourceError from None
  1047. elif self._protocol.exception:
  1048. raise self._protocol.exception from None
  1049. else:
  1050. raise EndOfStream from None
  1051. if len(chunk) > max_bytes:
  1052. # Split the oversized chunk
  1053. chunk, leftover = chunk[:max_bytes], chunk[max_bytes:]
  1054. self._protocol.read_queue.appendleft(leftover)
  1055. # If the read queue is empty, clear the flag so that the next call will
  1056. # block until data is available
  1057. if not self._protocol.read_queue:
  1058. self._protocol.read_event.clear()
  1059. return chunk
  1060. async def send(self, item: bytes) -> None:
  1061. with self._send_guard:
  1062. await AsyncIOBackend.checkpoint()
  1063. if self._closed:
  1064. raise ClosedResourceError
  1065. elif self._protocol.exception is not None:
  1066. raise self._protocol.exception
  1067. try:
  1068. self._transport.write(item)
  1069. except RuntimeError as exc:
  1070. if self._transport.is_closing():
  1071. raise BrokenResourceError from exc
  1072. else:
  1073. raise
  1074. await self._protocol.write_event.wait()
  1075. async def send_eof(self) -> None:
  1076. try:
  1077. self._transport.write_eof()
  1078. except OSError:
  1079. pass
  1080. async def aclose(self) -> None:
  1081. if not self._transport.is_closing():
  1082. self._closed = True
  1083. try:
  1084. self._transport.write_eof()
  1085. except OSError:
  1086. pass
  1087. self._transport.close()
  1088. await sleep(0)
  1089. self._transport.abort()
  1090. class _RawSocketMixin:
  1091. _receive_future: asyncio.Future | None = None
  1092. _send_future: asyncio.Future | None = None
  1093. _closing = False
  1094. def __init__(self, raw_socket: socket.socket):
  1095. self.__raw_socket = raw_socket
  1096. self._receive_guard = ResourceGuard("reading from")
  1097. self._send_guard = ResourceGuard("writing to")
  1098. @property
  1099. def _raw_socket(self) -> socket.socket:
  1100. return self.__raw_socket
  1101. def _wait_until_readable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
  1102. def callback(f: object) -> None:
  1103. del self._receive_future
  1104. loop.remove_reader(self.__raw_socket)
  1105. f = self._receive_future = asyncio.Future()
  1106. loop.add_reader(self.__raw_socket, f.set_result, None)
  1107. f.add_done_callback(callback)
  1108. return f
  1109. def _wait_until_writable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
  1110. def callback(f: object) -> None:
  1111. del self._send_future
  1112. loop.remove_writer(self.__raw_socket)
  1113. f = self._send_future = asyncio.Future()
  1114. loop.add_writer(self.__raw_socket, f.set_result, None)
  1115. f.add_done_callback(callback)
  1116. return f
  1117. async def aclose(self) -> None:
  1118. if not self._closing:
  1119. self._closing = True
  1120. if self.__raw_socket.fileno() != -1:
  1121. self.__raw_socket.close()
  1122. if self._receive_future:
  1123. self._receive_future.set_result(None)
  1124. if self._send_future:
  1125. self._send_future.set_result(None)
  1126. class UNIXSocketStream(_RawSocketMixin, abc.UNIXSocketStream):
  1127. async def send_eof(self) -> None:
  1128. with self._send_guard:
  1129. self._raw_socket.shutdown(socket.SHUT_WR)
  1130. async def receive(self, max_bytes: int = 65536) -> bytes:
  1131. loop = get_running_loop()
  1132. await AsyncIOBackend.checkpoint()
  1133. with self._receive_guard:
  1134. while True:
  1135. try:
  1136. data = self._raw_socket.recv(max_bytes)
  1137. except BlockingIOError:
  1138. await self._wait_until_readable(loop)
  1139. except OSError as exc:
  1140. if self._closing:
  1141. raise ClosedResourceError from None
  1142. else:
  1143. raise BrokenResourceError from exc
  1144. else:
  1145. if not data:
  1146. raise EndOfStream
  1147. return data
  1148. async def send(self, item: bytes) -> None:
  1149. loop = get_running_loop()
  1150. await AsyncIOBackend.checkpoint()
  1151. with self._send_guard:
  1152. view = memoryview(item)
  1153. while view:
  1154. try:
  1155. bytes_sent = self._raw_socket.send(view)
  1156. except BlockingIOError:
  1157. await self._wait_until_writable(loop)
  1158. except OSError as exc:
  1159. if self._closing:
  1160. raise ClosedResourceError from None
  1161. else:
  1162. raise BrokenResourceError from exc
  1163. else:
  1164. view = view[bytes_sent:]
  1165. async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
  1166. if not isinstance(msglen, int) or msglen < 0:
  1167. raise ValueError("msglen must be a non-negative integer")
  1168. if not isinstance(maxfds, int) or maxfds < 1:
  1169. raise ValueError("maxfds must be a positive integer")
  1170. loop = get_running_loop()
  1171. fds = array.array("i")
  1172. await AsyncIOBackend.checkpoint()
  1173. with self._receive_guard:
  1174. while True:
  1175. try:
  1176. message, ancdata, flags, addr = self._raw_socket.recvmsg(
  1177. msglen, socket.CMSG_LEN(maxfds * fds.itemsize)
  1178. )
  1179. except BlockingIOError:
  1180. await self._wait_until_readable(loop)
  1181. except OSError as exc:
  1182. if self._closing:
  1183. raise ClosedResourceError from None
  1184. else:
  1185. raise BrokenResourceError from exc
  1186. else:
  1187. if not message and not ancdata:
  1188. raise EndOfStream
  1189. break
  1190. for cmsg_level, cmsg_type, cmsg_data in ancdata:
  1191. if cmsg_level != socket.SOL_SOCKET or cmsg_type != socket.SCM_RIGHTS:
  1192. raise RuntimeError(
  1193. f"Received unexpected ancillary data; message = {message!r}, "
  1194. f"cmsg_level = {cmsg_level}, cmsg_type = {cmsg_type}"
  1195. )
  1196. fds.frombytes(cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
  1197. return message, list(fds)
  1198. async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
  1199. if not message:
  1200. raise ValueError("message must not be empty")
  1201. if not fds:
  1202. raise ValueError("fds must not be empty")
  1203. loop = get_running_loop()
  1204. filenos: list[int] = []
  1205. for fd in fds:
  1206. if isinstance(fd, int):
  1207. filenos.append(fd)
  1208. elif isinstance(fd, IOBase):
  1209. filenos.append(fd.fileno())
  1210. fdarray = array.array("i", filenos)
  1211. await AsyncIOBackend.checkpoint()
  1212. with self._send_guard:
  1213. while True:
  1214. try:
  1215. # The ignore can be removed after mypy picks up
  1216. # https://github.com/python/typeshed/pull/5545
  1217. self._raw_socket.sendmsg(
  1218. [message], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fdarray)]
  1219. )
  1220. break
  1221. except BlockingIOError:
  1222. await self._wait_until_writable(loop)
  1223. except OSError as exc:
  1224. if self._closing:
  1225. raise ClosedResourceError from None
  1226. else:
  1227. raise BrokenResourceError from exc
class TCPSocketListener(abc.SocketListener):
    """Accepts TCP connections on a listening socket via ``loop.sock_accept()``.

    Each accepted socket gets ``TCP_NODELAY`` and is wrapped in a
    :class:`SocketStream` driven by a :class:`StreamProtocol`.
    """

    # Cancel scope of the accept() call in progress, if any; aclose() cancels it
    _accept_scope: CancelScope | None = None
    _closed = False

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = cast(asyncio.BaseEventLoop, get_running_loop())
        self._accept_guard = ResourceGuard("accepting connections from")

    @property
    def _raw_socket(self) -> socket.socket:
        return self.__raw_socket

    async def accept(self) -> abc.SocketStream:
        """Wait for the next incoming connection and return it as a stream.

        :raises ClosedResourceError: if the listener is closed before or while
            waiting
        """
        if self._closed:
            raise ClosedResourceError

        with self._accept_guard:
            await AsyncIOBackend.checkpoint()
            with CancelScope() as self._accept_scope:
                try:
                    client_sock, _addr = await self._loop.sock_accept(self._raw_socket)
                except asyncio.CancelledError:
                    # Workaround for https://bugs.python.org/issue41317
                    try:
                        self._loop.remove_reader(self._raw_socket)
                    except (ValueError, NotImplementedError):
                        pass

                    # aclose() sets _closed and then cancels this scope; report
                    # that case as a closed resource instead of a cancellation
                    if self._closed:
                        raise ClosedResourceError from None

                    raise
                finally:
                    self._accept_scope = None

        client_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        transport, protocol = await self._loop.connect_accepted_socket(
            StreamProtocol, client_sock
        )
        return SocketStream(transport, protocol)

    async def aclose(self) -> None:
        """Close the listener, cancelling any accept() in progress. Idempotent."""
        if self._closed:
            return

        self._closed = True
        if self._accept_scope:
            # Workaround for https://bugs.python.org/issue41317
            try:
                self._loop.remove_reader(self._raw_socket)
            except (ValueError, NotImplementedError):
                pass

            self._accept_scope.cancel()
            # Presumably lets the cancelled accept() unwind before the socket
            # is closed below
            await sleep(0)

        self._raw_socket.close()
  1275. class UNIXSocketListener(abc.SocketListener):
  1276. def __init__(self, raw_socket: socket.socket):
  1277. self.__raw_socket = raw_socket
  1278. self._loop = get_running_loop()
  1279. self._accept_guard = ResourceGuard("accepting connections from")
  1280. self._closed = False
  1281. async def accept(self) -> abc.SocketStream:
  1282. await AsyncIOBackend.checkpoint()
  1283. with self._accept_guard:
  1284. while True:
  1285. try:
  1286. client_sock, _ = self.__raw_socket.accept()
  1287. client_sock.setblocking(False)
  1288. return UNIXSocketStream(client_sock)
  1289. except BlockingIOError:
  1290. f: asyncio.Future = asyncio.Future()
  1291. self._loop.add_reader(self.__raw_socket, f.set_result, None)
  1292. f.add_done_callback(
  1293. lambda _: self._loop.remove_reader(self.__raw_socket)
  1294. )
  1295. await f
  1296. except OSError as exc:
  1297. if self._closed:
  1298. raise ClosedResourceError from None
  1299. else:
  1300. raise BrokenResourceError from exc
  1301. async def aclose(self) -> None:
  1302. self._closed = True
  1303. self.__raw_socket.close()
  1304. @property
  1305. def _raw_socket(self) -> socket.socket:
  1306. return self.__raw_socket
  1307. class UDPSocket(abc.UDPSocket):
  1308. def __init__(
  1309. self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
  1310. ):
  1311. self._transport = transport
  1312. self._protocol = protocol
  1313. self._receive_guard = ResourceGuard("reading from")
  1314. self._send_guard = ResourceGuard("writing to")
  1315. self._closed = False
  1316. @property
  1317. def _raw_socket(self) -> socket.socket:
  1318. return self._transport.get_extra_info("socket")
  1319. async def aclose(self) -> None:
  1320. if not self._transport.is_closing():
  1321. self._closed = True
  1322. self._transport.close()
  1323. async def receive(self) -> tuple[bytes, IPSockAddrType]:
  1324. with self._receive_guard:
  1325. await AsyncIOBackend.checkpoint()
  1326. # If the buffer is empty, ask for more data
  1327. if not self._protocol.read_queue and not self._transport.is_closing():
  1328. self._protocol.read_event.clear()
  1329. await self._protocol.read_event.wait()
  1330. try:
  1331. return self._protocol.read_queue.popleft()
  1332. except IndexError:
  1333. if self._closed:
  1334. raise ClosedResourceError from None
  1335. else:
  1336. raise BrokenResourceError from None
  1337. async def send(self, item: UDPPacketType) -> None:
  1338. with self._send_guard:
  1339. await AsyncIOBackend.checkpoint()
  1340. await self._protocol.write_event.wait()
  1341. if self._closed:
  1342. raise ClosedResourceError
  1343. elif self._transport.is_closing():
  1344. raise BrokenResourceError
  1345. else:
  1346. self._transport.sendto(*item)
  1347. class ConnectedUDPSocket(abc.ConnectedUDPSocket):
  1348. def __init__(
  1349. self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
  1350. ):
  1351. self._transport = transport
  1352. self._protocol = protocol
  1353. self._receive_guard = ResourceGuard("reading from")
  1354. self._send_guard = ResourceGuard("writing to")
  1355. self._closed = False
  1356. @property
  1357. def _raw_socket(self) -> socket.socket:
  1358. return self._transport.get_extra_info("socket")
  1359. async def aclose(self) -> None:
  1360. if not self._transport.is_closing():
  1361. self._closed = True
  1362. self._transport.close()
  1363. async def receive(self) -> bytes:
  1364. with self._receive_guard:
  1365. await AsyncIOBackend.checkpoint()
  1366. # If the buffer is empty, ask for more data
  1367. if not self._protocol.read_queue and not self._transport.is_closing():
  1368. self._protocol.read_event.clear()
  1369. await self._protocol.read_event.wait()
  1370. try:
  1371. packet = self._protocol.read_queue.popleft()
  1372. except IndexError:
  1373. if self._closed:
  1374. raise ClosedResourceError from None
  1375. else:
  1376. raise BrokenResourceError from None
  1377. return packet[0]
  1378. async def send(self, item: bytes) -> None:
  1379. with self._send_guard:
  1380. await AsyncIOBackend.checkpoint()
  1381. await self._protocol.write_event.wait()
  1382. if self._closed:
  1383. raise ClosedResourceError
  1384. elif self._transport.is_closing():
  1385. raise BrokenResourceError
  1386. else:
  1387. self._transport.sendto(item)
  1388. class UNIXDatagramSocket(_RawSocketMixin, abc.UNIXDatagramSocket):
  1389. async def receive(self) -> UNIXDatagramPacketType:
  1390. loop = get_running_loop()
  1391. await AsyncIOBackend.checkpoint()
  1392. with self._receive_guard:
  1393. while True:
  1394. try:
  1395. data = self._raw_socket.recvfrom(65536)
  1396. except BlockingIOError:
  1397. await self._wait_until_readable(loop)
  1398. except OSError as exc:
  1399. if self._closing:
  1400. raise ClosedResourceError from None
  1401. else:
  1402. raise BrokenResourceError from exc
  1403. else:
  1404. return data
  1405. async def send(self, item: UNIXDatagramPacketType) -> None:
  1406. loop = get_running_loop()
  1407. await AsyncIOBackend.checkpoint()
  1408. with self._send_guard:
  1409. while True:
  1410. try:
  1411. self._raw_socket.sendto(*item)
  1412. except BlockingIOError:
  1413. await self._wait_until_writable(loop)
  1414. except OSError as exc:
  1415. if self._closing:
  1416. raise ClosedResourceError from None
  1417. else:
  1418. raise BrokenResourceError from exc
  1419. else:
  1420. return
  1421. class ConnectedUNIXDatagramSocket(_RawSocketMixin, abc.ConnectedUNIXDatagramSocket):
  1422. async def receive(self) -> bytes:
  1423. loop = get_running_loop()
  1424. await AsyncIOBackend.checkpoint()
  1425. with self._receive_guard:
  1426. while True:
  1427. try:
  1428. data = self._raw_socket.recv(65536)
  1429. except BlockingIOError:
  1430. await self._wait_until_readable(loop)
  1431. except OSError as exc:
  1432. if self._closing:
  1433. raise ClosedResourceError from None
  1434. else:
  1435. raise BrokenResourceError from exc
  1436. else:
  1437. return data
  1438. async def send(self, item: bytes) -> None:
  1439. loop = get_running_loop()
  1440. await AsyncIOBackend.checkpoint()
  1441. with self._send_guard:
  1442. while True:
  1443. try:
  1444. self._raw_socket.send(item)
  1445. except BlockingIOError:
  1446. await self._wait_until_writable(loop)
  1447. except OSError as exc:
  1448. if self._closing:
  1449. raise ClosedResourceError from None
  1450. else:
  1451. raise BrokenResourceError from exc
  1452. else:
  1453. return
  1454. _read_events: RunVar[dict[int, asyncio.Event]] = RunVar("read_events")
  1455. _write_events: RunVar[dict[int, asyncio.Event]] = RunVar("write_events")
  1456. #
  1457. # Synchronization
  1458. #
  1459. class Event(BaseEvent):
  1460. def __new__(cls) -> Event:
  1461. return object.__new__(cls)
  1462. def __init__(self) -> None:
  1463. self._event = asyncio.Event()
  1464. def set(self) -> None:
  1465. self._event.set()
  1466. def is_set(self) -> bool:
  1467. return self._event.is_set()
  1468. async def wait(self) -> None:
  1469. if self.is_set():
  1470. await AsyncIOBackend.checkpoint()
  1471. else:
  1472. await self._event.wait()
  1473. def statistics(self) -> EventStatistics:
  1474. return EventStatistics(len(self._event._waiters))
  1475. class Lock(BaseLock):
  1476. def __new__(cls, *, fast_acquire: bool = False) -> Lock:
  1477. return object.__new__(cls)
  1478. def __init__(self, *, fast_acquire: bool = False) -> None:
  1479. self._fast_acquire = fast_acquire
  1480. self._owner_task: asyncio.Task | None = None
  1481. self._waiters: deque[tuple[asyncio.Task, asyncio.Future]] = deque()
  1482. async def acquire(self) -> None:
  1483. task = cast(asyncio.Task, current_task())
  1484. if self._owner_task is None and not self._waiters:
  1485. await AsyncIOBackend.checkpoint_if_cancelled()
  1486. self._owner_task = task
  1487. # Unless on the "fast path", yield control of the event loop so that other
  1488. # tasks can run too
  1489. if not self._fast_acquire:
  1490. try:
  1491. await AsyncIOBackend.cancel_shielded_checkpoint()
  1492. except CancelledError:
  1493. self.release()
  1494. raise
  1495. return
  1496. if self._owner_task == task:
  1497. raise RuntimeError("Attempted to acquire an already held Lock")
  1498. fut: asyncio.Future[None] = asyncio.Future()
  1499. item = task, fut
  1500. self._waiters.append(item)
  1501. try:
  1502. await fut
  1503. except CancelledError:
  1504. self._waiters.remove(item)
  1505. if self._owner_task is task:
  1506. self.release()
  1507. raise
  1508. self._waiters.remove(item)
  1509. def acquire_nowait(self) -> None:
  1510. task = cast(asyncio.Task, current_task())
  1511. if self._owner_task is None and not self._waiters:
  1512. self._owner_task = task
  1513. return
  1514. if self._owner_task is task:
  1515. raise RuntimeError("Attempted to acquire an already held Lock")
  1516. raise WouldBlock
  1517. def locked(self) -> bool:
  1518. return self._owner_task is not None
  1519. def release(self) -> None:
  1520. if self._owner_task != current_task():
  1521. raise RuntimeError("The current task is not holding this lock")
  1522. for task, fut in self._waiters:
  1523. if not fut.cancelled():
  1524. self._owner_task = task
  1525. fut.set_result(None)
  1526. return
  1527. self._owner_task = None
  1528. def statistics(self) -> LockStatistics:
  1529. task_info = AsyncIOTaskInfo(self._owner_task) if self._owner_task else None
  1530. return LockStatistics(self.locked(), task_info, len(self._waiters))
  1531. class Semaphore(BaseSemaphore):
  1532. def __new__(
  1533. cls,
  1534. initial_value: int,
  1535. *,
  1536. max_value: int | None = None,
  1537. fast_acquire: bool = False,
  1538. ) -> Semaphore:
  1539. return object.__new__(cls)
  1540. def __init__(
  1541. self,
  1542. initial_value: int,
  1543. *,
  1544. max_value: int | None = None,
  1545. fast_acquire: bool = False,
  1546. ):
  1547. super().__init__(initial_value, max_value=max_value)
  1548. self._value = initial_value
  1549. self._max_value = max_value
  1550. self._fast_acquire = fast_acquire
  1551. self._waiters: deque[asyncio.Future[None]] = deque()
  1552. async def acquire(self) -> None:
  1553. if self._value > 0 and not self._waiters:
  1554. await AsyncIOBackend.checkpoint_if_cancelled()
  1555. self._value -= 1
  1556. # Unless on the "fast path", yield control of the event loop so that other
  1557. # tasks can run too
  1558. if not self._fast_acquire:
  1559. try:
  1560. await AsyncIOBackend.cancel_shielded_checkpoint()
  1561. except CancelledError:
  1562. self.release()
  1563. raise
  1564. return
  1565. fut: asyncio.Future[None] = asyncio.Future()
  1566. self._waiters.append(fut)
  1567. try:
  1568. await fut
  1569. except CancelledError:
  1570. try:
  1571. self._waiters.remove(fut)
  1572. except ValueError:
  1573. self.release()
  1574. raise
  1575. def acquire_nowait(self) -> None:
  1576. if self._value == 0:
  1577. raise WouldBlock
  1578. self._value -= 1
  1579. def release(self) -> None:
  1580. if self._max_value is not None and self._value == self._max_value:
  1581. raise ValueError("semaphore released too many times")
  1582. for fut in self._waiters:
  1583. if not fut.cancelled():
  1584. fut.set_result(None)
  1585. self._waiters.remove(fut)
  1586. return
  1587. self._value += 1
  1588. @property
  1589. def value(self) -> int:
  1590. return self._value
  1591. @property
  1592. def max_value(self) -> int | None:
  1593. return self._max_value
  1594. def statistics(self) -> SemaphoreStatistics:
  1595. return SemaphoreStatistics(len(self._waiters))
  1596. class CapacityLimiter(BaseCapacityLimiter):
  1597. _total_tokens: float = 0
  1598. def __new__(cls, total_tokens: float) -> CapacityLimiter:
  1599. return object.__new__(cls)
  1600. def __init__(self, total_tokens: float):
  1601. self._borrowers: set[Any] = set()
  1602. self._wait_queue: OrderedDict[Any, asyncio.Event] = OrderedDict()
  1603. self.total_tokens = total_tokens
  1604. async def __aenter__(self) -> None:
  1605. await self.acquire()
  1606. async def __aexit__(
  1607. self,
  1608. exc_type: type[BaseException] | None,
  1609. exc_val: BaseException | None,
  1610. exc_tb: TracebackType | None,
  1611. ) -> None:
  1612. self.release()
  1613. @property
  1614. def total_tokens(self) -> float:
  1615. return self._total_tokens
  1616. @total_tokens.setter
  1617. def total_tokens(self, value: float) -> None:
  1618. if not isinstance(value, int) and not math.isinf(value):
  1619. raise TypeError("total_tokens must be an int or math.inf")
  1620. if value < 1:
  1621. raise ValueError("total_tokens must be >= 1")
  1622. waiters_to_notify = max(value - self._total_tokens, 0)
  1623. self._total_tokens = value
  1624. # Notify waiting tasks that they have acquired the limiter
  1625. while self._wait_queue and waiters_to_notify:
  1626. event = self._wait_queue.popitem(last=False)[1]
  1627. event.set()
  1628. waiters_to_notify -= 1
  1629. @property
  1630. def borrowed_tokens(self) -> int:
  1631. return len(self._borrowers)
  1632. @property
  1633. def available_tokens(self) -> float:
  1634. return self._total_tokens - len(self._borrowers)
  1635. def acquire_nowait(self) -> None:
  1636. self.acquire_on_behalf_of_nowait(current_task())
  1637. def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
  1638. if borrower in self._borrowers:
  1639. raise RuntimeError(
  1640. "this borrower is already holding one of this CapacityLimiter's tokens"
  1641. )
  1642. if self._wait_queue or len(self._borrowers) >= self._total_tokens:
  1643. raise WouldBlock
  1644. self._borrowers.add(borrower)
  1645. async def acquire(self) -> None:
  1646. return await self.acquire_on_behalf_of(current_task())
  1647. async def acquire_on_behalf_of(self, borrower: object) -> None:
  1648. await AsyncIOBackend.checkpoint_if_cancelled()
  1649. try:
  1650. self.acquire_on_behalf_of_nowait(borrower)
  1651. except WouldBlock:
  1652. event = asyncio.Event()
  1653. self._wait_queue[borrower] = event
  1654. try:
  1655. await event.wait()
  1656. except BaseException:
  1657. self._wait_queue.pop(borrower, None)
  1658. raise
  1659. self._borrowers.add(borrower)
  1660. else:
  1661. try:
  1662. await AsyncIOBackend.cancel_shielded_checkpoint()
  1663. except BaseException:
  1664. self.release()
  1665. raise
  1666. def release(self) -> None:
  1667. self.release_on_behalf_of(current_task())
  1668. def release_on_behalf_of(self, borrower: object) -> None:
  1669. try:
  1670. self._borrowers.remove(borrower)
  1671. except KeyError:
  1672. raise RuntimeError(
  1673. "this borrower isn't holding any of this CapacityLimiter's tokens"
  1674. ) from None
  1675. # Notify the next task in line if this limiter has free capacity now
  1676. if self._wait_queue and len(self._borrowers) < self._total_tokens:
  1677. event = self._wait_queue.popitem(last=False)[1]
  1678. event.set()
  1679. def statistics(self) -> CapacityLimiterStatistics:
  1680. return CapacityLimiterStatistics(
  1681. self.borrowed_tokens,
  1682. self.total_tokens,
  1683. tuple(self._borrowers),
  1684. len(self._wait_queue),
  1685. )
  1686. _default_thread_limiter: RunVar[CapacityLimiter] = RunVar("_default_thread_limiter")
  1687. #
  1688. # Operating system signals
  1689. #
  1690. class _SignalReceiver:
  1691. def __init__(self, signals: tuple[Signals, ...]):
  1692. self._signals = signals
  1693. self._loop = get_running_loop()
  1694. self._signal_queue: deque[Signals] = deque()
  1695. self._future: asyncio.Future = asyncio.Future()
  1696. self._handled_signals: set[Signals] = set()
  1697. def _deliver(self, signum: Signals) -> None:
  1698. self._signal_queue.append(signum)
  1699. if not self._future.done():
  1700. self._future.set_result(None)
  1701. def __enter__(self) -> _SignalReceiver:
  1702. for sig in set(self._signals):
  1703. self._loop.add_signal_handler(sig, self._deliver, sig)
  1704. self._handled_signals.add(sig)
  1705. return self
  1706. def __exit__(
  1707. self,
  1708. exc_type: type[BaseException] | None,
  1709. exc_val: BaseException | None,
  1710. exc_tb: TracebackType | None,
  1711. ) -> None:
  1712. for sig in self._handled_signals:
  1713. self._loop.remove_signal_handler(sig)
  1714. def __aiter__(self) -> _SignalReceiver:
  1715. return self
  1716. async def __anext__(self) -> Signals:
  1717. await AsyncIOBackend.checkpoint()
  1718. if not self._signal_queue:
  1719. self._future = asyncio.Future()
  1720. await self._future
  1721. return self._signal_queue.popleft()
  1722. #
  1723. # Testing and debugging
  1724. #
  1725. class AsyncIOTaskInfo(TaskInfo):
  1726. def __init__(self, task: asyncio.Task):
  1727. task_state = _task_states.get(task)
  1728. if task_state is None:
  1729. parent_id = None
  1730. else:
  1731. parent_id = task_state.parent_id
  1732. coro = task.get_coro()
  1733. assert coro is not None, "created TaskInfo from a completed Task"
  1734. super().__init__(id(task), parent_id, task.get_name(), coro)
  1735. self._task = weakref.ref(task)
  1736. def has_pending_cancellation(self) -> bool:
  1737. if not (task := self._task()):
  1738. # If the task isn't around anymore, it won't have a pending cancellation
  1739. return False
  1740. if task._must_cancel: # type: ignore[attr-defined]
  1741. return True
  1742. elif (
  1743. isinstance(task._fut_waiter, asyncio.Future) # type: ignore[attr-defined]
  1744. and task._fut_waiter.cancelled() # type: ignore[attr-defined]
  1745. ):
  1746. return True
  1747. if task_state := _task_states.get(task):
  1748. if cancel_scope := task_state.cancel_scope:
  1749. return cancel_scope._effectively_cancelled
  1750. return False
  1751. class TestRunner(abc.TestRunner):
  1752. _send_stream: MemoryObjectSendStream[tuple[Awaitable[Any], asyncio.Future[Any]]]
  1753. def __init__(
  1754. self,
  1755. *,
  1756. debug: bool | None = None,
  1757. use_uvloop: bool = False,
  1758. loop_factory: Callable[[], AbstractEventLoop] | None = None,
  1759. ) -> None:
  1760. if use_uvloop and loop_factory is None:
  1761. import uvloop
  1762. loop_factory = uvloop.new_event_loop
  1763. self._runner = Runner(debug=debug, loop_factory=loop_factory)
  1764. self._exceptions: list[BaseException] = []
  1765. self._runner_task: asyncio.Task | None = None
  1766. def __enter__(self) -> TestRunner:
  1767. self._runner.__enter__()
  1768. self.get_loop().set_exception_handler(self._exception_handler)
  1769. return self
  1770. def __exit__(
  1771. self,
  1772. exc_type: type[BaseException] | None,
  1773. exc_val: BaseException | None,
  1774. exc_tb: TracebackType | None,
  1775. ) -> None:
  1776. self._runner.__exit__(exc_type, exc_val, exc_tb)
  1777. def get_loop(self) -> AbstractEventLoop:
  1778. return self._runner.get_loop()
  1779. def _exception_handler(
  1780. self, loop: asyncio.AbstractEventLoop, context: dict[str, Any]
  1781. ) -> None:
  1782. if isinstance(context.get("exception"), Exception):
  1783. self._exceptions.append(context["exception"])
  1784. else:
  1785. loop.default_exception_handler(context)
  1786. def _raise_async_exceptions(self) -> None:
  1787. # Re-raise any exceptions raised in asynchronous callbacks
  1788. if self._exceptions:
  1789. exceptions, self._exceptions = self._exceptions, []
  1790. if len(exceptions) == 1:
  1791. raise exceptions[0]
  1792. elif exceptions:
  1793. raise BaseExceptionGroup(
  1794. "Multiple exceptions occurred in asynchronous callbacks", exceptions
  1795. )
  1796. async def _run_tests_and_fixtures(
  1797. self,
  1798. receive_stream: MemoryObjectReceiveStream[
  1799. tuple[Awaitable[T_Retval], asyncio.Future[T_Retval]]
  1800. ],
  1801. ) -> None:
  1802. from _pytest.outcomes import OutcomeException
  1803. with receive_stream, self._send_stream:
  1804. async for coro, future in receive_stream:
  1805. try:
  1806. retval = await coro
  1807. except CancelledError as exc:
  1808. if not future.cancelled():
  1809. future.cancel(*exc.args)
  1810. raise
  1811. except BaseException as exc:
  1812. if not future.cancelled():
  1813. future.set_exception(exc)
  1814. if not isinstance(exc, (Exception, OutcomeException)):
  1815. raise
  1816. else:
  1817. if not future.cancelled():
  1818. future.set_result(retval)
  1819. async def _call_in_runner_task(
  1820. self,
  1821. func: Callable[P, Awaitable[T_Retval]],
  1822. *args: P.args,
  1823. **kwargs: P.kwargs,
  1824. ) -> T_Retval:
  1825. if not self._runner_task:
  1826. self._send_stream, receive_stream = create_memory_object_stream[
  1827. tuple[Awaitable[Any], asyncio.Future]
  1828. ](1)
  1829. self._runner_task = self.get_loop().create_task(
  1830. self._run_tests_and_fixtures(receive_stream)
  1831. )
  1832. coro = func(*args, **kwargs)
  1833. future: asyncio.Future[T_Retval] = self.get_loop().create_future()
  1834. self._send_stream.send_nowait((coro, future))
  1835. return await future
  1836. def run_asyncgen_fixture(
  1837. self,
  1838. fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]],
  1839. kwargs: dict[str, Any],
  1840. ) -> Iterable[T_Retval]:
  1841. asyncgen = fixture_func(**kwargs)
  1842. fixturevalue: T_Retval = self.get_loop().run_until_complete(
  1843. self._call_in_runner_task(asyncgen.asend, None)
  1844. )
  1845. self._raise_async_exceptions()
  1846. yield fixturevalue
  1847. try:
  1848. self.get_loop().run_until_complete(
  1849. self._call_in_runner_task(asyncgen.asend, None)
  1850. )
  1851. except StopAsyncIteration:
  1852. self._raise_async_exceptions()
  1853. else:
  1854. self.get_loop().run_until_complete(asyncgen.aclose())
  1855. raise RuntimeError("Async generator fixture did not stop")
  1856. def run_fixture(
  1857. self,
  1858. fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]],
  1859. kwargs: dict[str, Any],
  1860. ) -> T_Retval:
  1861. retval = self.get_loop().run_until_complete(
  1862. self._call_in_runner_task(fixture_func, **kwargs)
  1863. )
  1864. self._raise_async_exceptions()
  1865. return retval
  1866. def run_test(
  1867. self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any]
  1868. ) -> None:
  1869. try:
  1870. self.get_loop().run_until_complete(
  1871. self._call_in_runner_task(test_func, **kwargs)
  1872. )
  1873. except Exception as exc:
  1874. self._exceptions.append(exc)
  1875. self._raise_async_exceptions()
  1876. class AsyncIOBackend(AsyncBackend):
  1877. @classmethod
  1878. def run(
  1879. cls,
  1880. func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
  1881. args: tuple[Unpack[PosArgsT]],
  1882. kwargs: dict[str, Any],
  1883. options: dict[str, Any],
  1884. ) -> T_Retval:
  1885. @wraps(func)
  1886. async def wrapper() -> T_Retval:
  1887. task = cast(asyncio.Task, current_task())
  1888. task.set_name(get_callable_name(func))
  1889. _task_states[task] = TaskState(None, None)
  1890. try:
  1891. return await func(*args)
  1892. finally:
  1893. del _task_states[task]
  1894. debug = options.get("debug", None)
  1895. loop_factory = options.get("loop_factory", None)
  1896. if loop_factory is None and options.get("use_uvloop", False):
  1897. import uvloop
  1898. loop_factory = uvloop.new_event_loop
  1899. with Runner(debug=debug, loop_factory=loop_factory) as runner:
  1900. return runner.run(wrapper())
  1901. @classmethod
  1902. def current_token(cls) -> object:
  1903. return get_running_loop()
  1904. @classmethod
  1905. def current_time(cls) -> float:
  1906. return get_running_loop().time()
  1907. @classmethod
  1908. def cancelled_exception_class(cls) -> type[BaseException]:
  1909. return CancelledError
  1910. @classmethod
  1911. async def checkpoint(cls) -> None:
  1912. await sleep(0)
  1913. @classmethod
  1914. async def checkpoint_if_cancelled(cls) -> None:
  1915. task = current_task()
  1916. if task is None:
  1917. return
  1918. try:
  1919. cancel_scope = _task_states[task].cancel_scope
  1920. except KeyError:
  1921. return
  1922. while cancel_scope:
  1923. if cancel_scope.cancel_called:
  1924. await sleep(0)
  1925. elif cancel_scope.shield:
  1926. break
  1927. else:
  1928. cancel_scope = cancel_scope._parent_scope
  1929. @classmethod
  1930. async def cancel_shielded_checkpoint(cls) -> None:
  1931. with CancelScope(shield=True):
  1932. await sleep(0)
  1933. @classmethod
  1934. async def sleep(cls, delay: float) -> None:
  1935. await sleep(delay)
  1936. @classmethod
  1937. def create_cancel_scope(
  1938. cls, *, deadline: float = math.inf, shield: bool = False
  1939. ) -> CancelScope:
  1940. return CancelScope(deadline=deadline, shield=shield)
  1941. @classmethod
  1942. def current_effective_deadline(cls) -> float:
  1943. if (task := current_task()) is None:
  1944. return math.inf
  1945. try:
  1946. cancel_scope = _task_states[task].cancel_scope
  1947. except KeyError:
  1948. return math.inf
  1949. deadline = math.inf
  1950. while cancel_scope:
  1951. deadline = min(deadline, cancel_scope.deadline)
  1952. if cancel_scope._cancel_called:
  1953. deadline = -math.inf
  1954. break
  1955. elif cancel_scope.shield:
  1956. break
  1957. else:
  1958. cancel_scope = cancel_scope._parent_scope
  1959. return deadline
  1960. @classmethod
  1961. def create_task_group(cls) -> abc.TaskGroup:
  1962. return TaskGroup()
  1963. @classmethod
  1964. def create_event(cls) -> abc.Event:
  1965. return Event()
  1966. @classmethod
  1967. def create_lock(cls, *, fast_acquire: bool) -> abc.Lock:
  1968. return Lock(fast_acquire=fast_acquire)
  1969. @classmethod
  1970. def create_semaphore(
  1971. cls,
  1972. initial_value: int,
  1973. *,
  1974. max_value: int | None = None,
  1975. fast_acquire: bool = False,
  1976. ) -> abc.Semaphore:
  1977. return Semaphore(initial_value, max_value=max_value, fast_acquire=fast_acquire)
  1978. @classmethod
  1979. def create_capacity_limiter(cls, total_tokens: float) -> abc.CapacityLimiter:
  1980. return CapacityLimiter(total_tokens)
  1981. @classmethod
  1982. async def run_sync_in_worker_thread( # type: ignore[return]
  1983. cls,
  1984. func: Callable[[Unpack[PosArgsT]], T_Retval],
  1985. args: tuple[Unpack[PosArgsT]],
  1986. abandon_on_cancel: bool = False,
  1987. limiter: abc.CapacityLimiter | None = None,
  1988. ) -> T_Retval:
  1989. await cls.checkpoint()
  1990. # If this is the first run in this event loop thread, set up the necessary
  1991. # variables
  1992. try:
  1993. idle_workers = _threadpool_idle_workers.get()
  1994. workers = _threadpool_workers.get()
  1995. except LookupError:
  1996. idle_workers = deque()
  1997. workers = set()
  1998. _threadpool_idle_workers.set(idle_workers)
  1999. _threadpool_workers.set(workers)
  2000. async with limiter or cls.current_default_thread_limiter():
  2001. with CancelScope(shield=not abandon_on_cancel) as scope:
  2002. future = asyncio.Future[T_Retval]()
  2003. root_task = find_root_task()
  2004. if not idle_workers:
  2005. worker = WorkerThread(root_task, workers, idle_workers)
  2006. worker.start()
  2007. workers.add(worker)
  2008. root_task.add_done_callback(
  2009. worker.stop, context=contextvars.Context()
  2010. )
  2011. else:
  2012. worker = idle_workers.pop()
  2013. # Prune any other workers that have been idle for MAX_IDLE_TIME
  2014. # seconds or longer
  2015. now = cls.current_time()
  2016. while idle_workers:
  2017. if (
  2018. now - idle_workers[0].idle_since
  2019. < WorkerThread.MAX_IDLE_TIME
  2020. ):
  2021. break
  2022. expired_worker = idle_workers.popleft()
  2023. expired_worker.root_task.remove_done_callback(
  2024. expired_worker.stop
  2025. )
  2026. expired_worker.stop()
  2027. context = copy_context()
  2028. context.run(sniffio.current_async_library_cvar.set, None)
  2029. if abandon_on_cancel or scope._parent_scope is None:
  2030. worker_scope = scope
  2031. else:
  2032. worker_scope = scope._parent_scope
  2033. worker.queue.put_nowait((context, func, args, future, worker_scope))
  2034. return await future
  2035. @classmethod
  2036. def check_cancelled(cls) -> None:
  2037. scope: CancelScope | None = threadlocals.current_cancel_scope
  2038. while scope is not None:
  2039. if scope.cancel_called:
  2040. raise CancelledError(f"Cancelled by cancel scope {id(scope):x}")
  2041. if scope.shield:
  2042. return
  2043. scope = scope._parent_scope
  2044. @classmethod
  2045. def run_async_from_thread(
  2046. cls,
  2047. func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
  2048. args: tuple[Unpack[PosArgsT]],
  2049. token: object,
  2050. ) -> T_Retval:
  2051. async def task_wrapper(scope: CancelScope) -> T_Retval:
  2052. __tracebackhide__ = True
  2053. task = cast(asyncio.Task, current_task())
  2054. _task_states[task] = TaskState(None, scope)
  2055. scope._tasks.add(task)
  2056. try:
  2057. return await func(*args)
  2058. except CancelledError as exc:
  2059. raise concurrent.futures.CancelledError(str(exc)) from None
  2060. finally:
  2061. scope._tasks.discard(task)
  2062. loop = cast(AbstractEventLoop, token)
  2063. context = copy_context()
  2064. context.run(sniffio.current_async_library_cvar.set, "asyncio")
  2065. wrapper = task_wrapper(threadlocals.current_cancel_scope)
  2066. f: concurrent.futures.Future[T_Retval] = context.run(
  2067. asyncio.run_coroutine_threadsafe, wrapper, loop
  2068. )
  2069. return f.result()
  2070. @classmethod
  2071. def run_sync_from_thread(
  2072. cls,
  2073. func: Callable[[Unpack[PosArgsT]], T_Retval],
  2074. args: tuple[Unpack[PosArgsT]],
  2075. token: object,
  2076. ) -> T_Retval:
  2077. @wraps(func)
  2078. def wrapper() -> None:
  2079. try:
  2080. sniffio.current_async_library_cvar.set("asyncio")
  2081. f.set_result(func(*args))
  2082. except BaseException as exc:
  2083. f.set_exception(exc)
  2084. if not isinstance(exc, Exception):
  2085. raise
  2086. f: concurrent.futures.Future[T_Retval] = Future()
  2087. loop = cast(AbstractEventLoop, token)
  2088. loop.call_soon_threadsafe(wrapper)
  2089. return f.result()
  2090. @classmethod
  2091. def create_blocking_portal(cls) -> abc.BlockingPortal:
  2092. return BlockingPortal()
  2093. @classmethod
  2094. async def open_process(
  2095. cls,
  2096. command: StrOrBytesPath | Sequence[StrOrBytesPath],
  2097. *,
  2098. stdin: int | IO[Any] | None,
  2099. stdout: int | IO[Any] | None,
  2100. stderr: int | IO[Any] | None,
  2101. **kwargs: Any,
  2102. ) -> Process:
  2103. await cls.checkpoint()
  2104. if isinstance(command, PathLike):
  2105. command = os.fspath(command)
  2106. if isinstance(command, (str, bytes)):
  2107. process = await asyncio.create_subprocess_shell(
  2108. command,
  2109. stdin=stdin,
  2110. stdout=stdout,
  2111. stderr=stderr,
  2112. **kwargs,
  2113. )
  2114. else:
  2115. process = await asyncio.create_subprocess_exec(
  2116. *command,
  2117. stdin=stdin,
  2118. stdout=stdout,
  2119. stderr=stderr,
  2120. **kwargs,
  2121. )
  2122. stdin_stream = StreamWriterWrapper(process.stdin) if process.stdin else None
  2123. stdout_stream = StreamReaderWrapper(process.stdout) if process.stdout else None
  2124. stderr_stream = StreamReaderWrapper(process.stderr) if process.stderr else None
  2125. return Process(process, stdin_stream, stdout_stream, stderr_stream)
  2126. @classmethod
  2127. def setup_process_pool_exit_at_shutdown(cls, workers: set[abc.Process]) -> None:
  2128. create_task(
  2129. _shutdown_process_pool_on_exit(workers),
  2130. name="AnyIO process pool shutdown task",
  2131. )
  2132. find_root_task().add_done_callback(
  2133. partial(_forcibly_shutdown_process_pool_on_exit, workers) # type:ignore[arg-type]
  2134. )
  2135. @classmethod
  2136. async def connect_tcp(
  2137. cls, host: str, port: int, local_address: IPSockAddrType | None = None
  2138. ) -> abc.SocketStream:
  2139. transport, protocol = cast(
  2140. tuple[asyncio.Transport, StreamProtocol],
  2141. await get_running_loop().create_connection(
  2142. StreamProtocol, host, port, local_addr=local_address
  2143. ),
  2144. )
  2145. transport.pause_reading()
  2146. return SocketStream(transport, protocol)
  2147. @classmethod
  2148. async def connect_unix(cls, path: str | bytes) -> abc.UNIXSocketStream:
  2149. await cls.checkpoint()
  2150. loop = get_running_loop()
  2151. raw_socket = socket.socket(socket.AF_UNIX)
  2152. raw_socket.setblocking(False)
  2153. while True:
  2154. try:
  2155. raw_socket.connect(path)
  2156. except BlockingIOError:
  2157. f: asyncio.Future = asyncio.Future()
  2158. loop.add_writer(raw_socket, f.set_result, None)
  2159. f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
  2160. await f
  2161. except BaseException:
  2162. raw_socket.close()
  2163. raise
  2164. else:
  2165. return UNIXSocketStream(raw_socket)
  2166. @classmethod
  2167. def create_tcp_listener(cls, sock: socket.socket) -> SocketListener:
  2168. return TCPSocketListener(sock)
  2169. @classmethod
  2170. def create_unix_listener(cls, sock: socket.socket) -> SocketListener:
  2171. return UNIXSocketListener(sock)
  2172. @classmethod
  2173. async def create_udp_socket(
  2174. cls,
  2175. family: AddressFamily,
  2176. local_address: IPSockAddrType | None,
  2177. remote_address: IPSockAddrType | None,
  2178. reuse_port: bool,
  2179. ) -> UDPSocket | ConnectedUDPSocket:
  2180. transport, protocol = await get_running_loop().create_datagram_endpoint(
  2181. DatagramProtocol,
  2182. local_addr=local_address,
  2183. remote_addr=remote_address,
  2184. family=family,
  2185. reuse_port=reuse_port,
  2186. )
  2187. if protocol.exception:
  2188. transport.close()
  2189. raise protocol.exception
  2190. if not remote_address:
  2191. return UDPSocket(transport, protocol)
  2192. else:
  2193. return ConnectedUDPSocket(transport, protocol)
  2194. @classmethod
  2195. async def create_unix_datagram_socket( # type: ignore[override]
  2196. cls, raw_socket: socket.socket, remote_path: str | bytes | None
  2197. ) -> abc.UNIXDatagramSocket | abc.ConnectedUNIXDatagramSocket:
  2198. await cls.checkpoint()
  2199. loop = get_running_loop()
  2200. if remote_path:
  2201. while True:
  2202. try:
  2203. raw_socket.connect(remote_path)
  2204. except BlockingIOError:
  2205. f: asyncio.Future = asyncio.Future()
  2206. loop.add_writer(raw_socket, f.set_result, None)
  2207. f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
  2208. await f
  2209. except BaseException:
  2210. raw_socket.close()
  2211. raise
  2212. else:
  2213. return ConnectedUNIXDatagramSocket(raw_socket)
  2214. else:
  2215. return UNIXDatagramSocket(raw_socket)
  2216. @classmethod
  2217. async def getaddrinfo(
  2218. cls,
  2219. host: bytes | str | None,
  2220. port: str | int | None,
  2221. *,
  2222. family: int | AddressFamily = 0,
  2223. type: int | SocketKind = 0,
  2224. proto: int = 0,
  2225. flags: int = 0,
  2226. ) -> Sequence[
  2227. tuple[
  2228. AddressFamily,
  2229. SocketKind,
  2230. int,
  2231. str,
  2232. tuple[str, int] | tuple[str, int, int, int] | tuple[int, bytes],
  2233. ]
  2234. ]:
  2235. return await get_running_loop().getaddrinfo(
  2236. host, port, family=family, type=type, proto=proto, flags=flags
  2237. )
  2238. @classmethod
  2239. async def getnameinfo(
  2240. cls, sockaddr: IPSockAddrType, flags: int = 0
  2241. ) -> tuple[str, str]:
  2242. return await get_running_loop().getnameinfo(sockaddr, flags)
  2243. @classmethod
  2244. async def wait_readable(cls, obj: FileDescriptorLike) -> None:
  2245. await cls.checkpoint()
  2246. try:
  2247. read_events = _read_events.get()
  2248. except LookupError:
  2249. read_events = {}
  2250. _read_events.set(read_events)
  2251. if not isinstance(obj, int):
  2252. obj = obj.fileno()
  2253. if read_events.get(obj):
  2254. raise BusyResourceError("reading from")
  2255. loop = get_running_loop()
  2256. event = asyncio.Event()
  2257. try:
  2258. loop.add_reader(obj, event.set)
  2259. except NotImplementedError:
  2260. from anyio._core._asyncio_selector_thread import get_selector
  2261. selector = get_selector()
  2262. selector.add_reader(obj, event.set)
  2263. remove_reader = selector.remove_reader
  2264. else:
  2265. remove_reader = loop.remove_reader
  2266. read_events[obj] = event
  2267. try:
  2268. await event.wait()
  2269. finally:
  2270. remove_reader(obj)
  2271. del read_events[obj]
  2272. @classmethod
  2273. async def wait_writable(cls, obj: FileDescriptorLike) -> None:
  2274. await cls.checkpoint()
  2275. try:
  2276. write_events = _write_events.get()
  2277. except LookupError:
  2278. write_events = {}
  2279. _write_events.set(write_events)
  2280. if not isinstance(obj, int):
  2281. obj = obj.fileno()
  2282. if write_events.get(obj):
  2283. raise BusyResourceError("writing to")
  2284. loop = get_running_loop()
  2285. event = asyncio.Event()
  2286. try:
  2287. loop.add_writer(obj, event.set)
  2288. except NotImplementedError:
  2289. from anyio._core._asyncio_selector_thread import get_selector
  2290. selector = get_selector()
  2291. selector.add_writer(obj, event.set)
  2292. remove_writer = selector.remove_writer
  2293. else:
  2294. remove_writer = loop.remove_writer
  2295. write_events[obj] = event
  2296. try:
  2297. await event.wait()
  2298. finally:
  2299. del write_events[obj]
  2300. remove_writer(obj)
  2301. @classmethod
  2302. def current_default_thread_limiter(cls) -> CapacityLimiter:
  2303. try:
  2304. return _default_thread_limiter.get()
  2305. except LookupError:
  2306. limiter = CapacityLimiter(40)
  2307. _default_thread_limiter.set(limiter)
  2308. return limiter
  2309. @classmethod
  2310. def open_signal_receiver(
  2311. cls, *signals: Signals
  2312. ) -> AbstractContextManager[AsyncIterator[Signals]]:
  2313. return _SignalReceiver(signals)
  2314. @classmethod
  2315. def get_current_task(cls) -> TaskInfo:
  2316. return AsyncIOTaskInfo(current_task()) # type: ignore[arg-type]
  2317. @classmethod
  2318. def get_running_tasks(cls) -> Sequence[TaskInfo]:
  2319. return [AsyncIOTaskInfo(task) for task in all_tasks() if not task.done()]
  2320. @classmethod
  2321. async def wait_all_tasks_blocked(cls) -> None:
  2322. await cls.checkpoint()
  2323. this_task = current_task()
  2324. while True:
  2325. for task in all_tasks():
  2326. if task is this_task:
  2327. continue
  2328. waiter = task._fut_waiter # type: ignore[attr-defined]
  2329. if waiter is None or waiter.done():
  2330. await sleep(0.1)
  2331. break
  2332. else:
  2333. return
  2334. @classmethod
  2335. def create_test_runner(cls, options: dict[str, Any]) -> TestRunner:
  2336. return TestRunner(**options)
  2337. backend_class = AsyncIOBackend