from_thread.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527
  1. from __future__ import annotations
  2. import sys
  3. from collections.abc import Awaitable, Callable, Generator
  4. from concurrent.futures import Future
  5. from contextlib import (
  6. AbstractAsyncContextManager,
  7. AbstractContextManager,
  8. contextmanager,
  9. )
  10. from dataclasses import dataclass, field
  11. from inspect import isawaitable
  12. from threading import Lock, Thread, get_ident
  13. from types import TracebackType
  14. from typing import (
  15. Any,
  16. Generic,
  17. TypeVar,
  18. cast,
  19. overload,
  20. )
  21. from ._core import _eventloop
  22. from ._core._eventloop import get_async_backend, get_cancelled_exc_class, threadlocals
  23. from ._core._synchronization import Event
  24. from ._core._tasks import CancelScope, create_task_group
  25. from .abc import AsyncBackend
  26. from .abc._tasks import TaskStatus
  27. if sys.version_info >= (3, 11):
  28. from typing import TypeVarTuple, Unpack
  29. else:
  30. from typing_extensions import TypeVarTuple, Unpack
  31. T_Retval = TypeVar("T_Retval")
  32. T_co = TypeVar("T_co", covariant=True)
  33. PosArgsT = TypeVarTuple("PosArgsT")
  34. def run(
  35. func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], *args: Unpack[PosArgsT]
  36. ) -> T_Retval:
  37. """
  38. Call a coroutine function from a worker thread.
  39. :param func: a coroutine function
  40. :param args: positional arguments for the callable
  41. :return: the return value of the coroutine function
  42. """
  43. try:
  44. async_backend = threadlocals.current_async_backend
  45. token = threadlocals.current_token
  46. except AttributeError:
  47. raise RuntimeError(
  48. "This function can only be run from an AnyIO worker thread"
  49. ) from None
  50. return async_backend.run_async_from_thread(func, args, token=token)
  51. def run_sync(
  52. func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
  53. ) -> T_Retval:
  54. """
  55. Call a function in the event loop thread from a worker thread.
  56. :param func: a callable
  57. :param args: positional arguments for the callable
  58. :return: the return value of the callable
  59. """
  60. try:
  61. async_backend = threadlocals.current_async_backend
  62. token = threadlocals.current_token
  63. except AttributeError:
  64. raise RuntimeError(
  65. "This function can only be run from an AnyIO worker thread"
  66. ) from None
  67. return async_backend.run_sync_from_thread(func, args, token=token)
class _BlockingAsyncContextManager(Generic[T_co], AbstractContextManager):
    """
    Synchronous context manager that drives an asynchronous one through a portal.

    ``__enter__`` starts :meth:`run_async_cm` as a portal task and blocks until the
    wrapped ``__aenter__()`` has produced a value (or raised). ``__exit__`` lets that
    task continue into ``__aexit__()`` and blocks until the task returns.
    """

    # Resolved by run_async_cm() with the __aenter__() result or its exception.
    _enter_future: Future[T_co]
    # Future of the run_async_cm() portal task; resolves to its return value.
    _exit_future: Future[bool | None]
    # Created by run_async_cm(); set through the portal by __exit__.
    _exit_event: Event
    # Exception info forwarded to __aexit__(); stays all-None if run_async_cm()
    # reaches __aexit__() before __exit__ has recorded anything.
    _exit_exc_info: tuple[
        type[BaseException] | None, BaseException | None, TracebackType | None
    ] = (None, None, None)

    def __init__(
        self, async_cm: AbstractAsyncContextManager[T_co], portal: BlockingPortal
    ):
        self._async_cm = async_cm
        self._portal = portal

    async def run_async_cm(self) -> bool | None:
        """
        Enter the wrapped async context manager, wait for the sync exit, then exit it.

        Runs as a task inside the portal's event loop.

        :return: the return value of the wrapped ``__aexit__()``
        """
        try:
            # The event is created before __aenter__() so that it already exists
            # once _enter_future resolves and __exit__ can be reached.
            self._exit_event = Event()
            value = await self._async_cm.__aenter__()
        except BaseException as exc:
            self._enter_future.set_exception(exc)
            raise
        else:
            self._enter_future.set_result(value)

        try:
            # Wait for the sync context manager to exit.
            # This next statement can raise `get_cancelled_exc_class()` if
            # something went wrong in a task group in this async context
            # manager.
            await self._exit_event.wait()
        finally:
            # In case of cancellation, it could be that we end up here before
            # `_BlockingAsyncContextManager.__exit__` is called, and an
            # `_exit_exc_info` has been set.
            result = await self._async_cm.__aexit__(*self._exit_exc_info)
            # NOTE(review): a `return` inside `finally` discards any exception in
            # flight, including a cancellation raised by the wait above, so the
            # __aexit__() result is returned instead. The original layout is
            # ambiguous; confirm that this suppression is intended.
            return result

    def __enter__(self) -> T_co:
        """Start the portal task and block until ``__aenter__()`` completes."""
        self._enter_future = Future()
        self._exit_future = self._portal.start_task_soon(self.run_async_cm)
        # Re-raises here if __aenter__() failed.
        return self._enter_future.result()

    def __exit__(
        self,
        __exc_type: type[BaseException] | None,
        __exc_value: BaseException | None,
        __traceback: TracebackType | None,
    ) -> bool | None:
        """Forward the exception info to ``__aexit__()`` and block until it returns."""
        # Record the exception info before waking the task so __aexit__() sees it.
        self._exit_exc_info = __exc_type, __exc_value, __traceback
        self._portal.call(self._exit_event.set)
        return self._exit_future.result()
  114. class _BlockingPortalTaskStatus(TaskStatus):
  115. def __init__(self, future: Future):
  116. self._future = future
  117. def started(self, value: object = None) -> None:
  118. self._future.set_result(value)
class BlockingPortal:
    """An object that lets external threads run code in an asynchronous event loop."""

    def __new__(cls) -> BlockingPortal:
        # Construction is delegated to the current async backend's
        # create_blocking_portal(). NOTE(review): _spawn_task_from_thread() below
        # raises NotImplementedError, so the backend presumably returns a subclass
        # that implements it; confirm.
        return get_async_backend().create_blocking_portal()

    def __init__(self) -> None:
        # The constructing thread is recorded as the event loop thread. stop()
        # resets this to None, and _check_running() checks both states.
        self._event_loop_thread_id: int | None = get_ident()
        # Set by stop(); awaited by sleep_until_stopped().
        self._stop_event = Event()
        # Task group that hosts the tasks started through this portal.
        self._task_group = create_task_group()
        # Backend-specific cancellation exception, caught in _call_func().
        self._cancelled_exc_class = get_cancelled_exc_class()

    async def __aenter__(self) -> BlockingPortal:
        await self._task_group.__aenter__()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        # Stop accepting new calls first. stop() is called with the default
        # cancel_remaining=False, so remaining tasks are not cancelled. Then exit
        # the task group.
        await self.stop()
        return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)

    def _check_running(self) -> None:
        """
        Ensure the portal accepts calls from the current thread.

        :raises RuntimeError: if the portal has been stopped, or if called from the
            event loop thread itself
        """
        if self._event_loop_thread_id is None:
            raise RuntimeError("This portal is not running")
        if self._event_loop_thread_id == get_ident():
            raise RuntimeError(
                "This method cannot be called from the event loop thread"
            )

    async def sleep_until_stopped(self) -> None:
        """Sleep until :meth:`stop` is called."""
        await self._stop_event.wait()

    async def stop(self, cancel_remaining: bool = False) -> None:
        """
        Signal the portal to shut down.

        This marks the portal as no longer accepting new calls and exits from
        :meth:`sleep_until_stopped`.

        :param cancel_remaining: ``True`` to cancel all the remaining tasks, ``False``
            to let them finish before returning

        """
        self._event_loop_thread_id = None
        self._stop_event.set()
        if cancel_remaining:
            self._task_group.cancel_scope.cancel()

    async def _call_func(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        args: tuple[Unpack[PosArgsT]],
        kwargs: dict[str, Any],
        future: Future[T_Retval],
    ) -> None:
        """
        Run ``func`` in the event loop and report its outcome through ``future``.

        If ``func`` returns an awaitable, it is awaited inside a cancel scope that is
        cancelled when ``future`` gets cancelled. Outcomes are reported as follows:

        * A backend cancellation cancels ``future``.
        * Any other exception is stored in ``future``, unless ``future`` was already
          cancelled. Exceptions that are not :class:`Exception` subclasses are then
          re-raised.
        * Otherwise the result is stored in ``future``, unless ``future`` was already
          cancelled.
        """

        def callback(f: Future[T_Retval]) -> None:
            # Forward cancellation of the future into the cancel scope. This only
            # happens while the portal is running and the callback runs outside the
            # event loop thread, because self.call() refuses both other cases.
            if f.cancelled() and self._event_loop_thread_id not in (
                None,
                get_ident(),
            ):
                self.call(scope.cancel)

        try:
            retval_or_awaitable = func(*args, **kwargs)
            if isawaitable(retval_or_awaitable):
                with CancelScope() as scope:
                    # The future may already have been cancelled before we got here.
                    if future.cancelled():
                        scope.cancel()
                    else:
                        future.add_done_callback(callback)

                    retval = await retval_or_awaitable
            else:
                retval = retval_or_awaitable
        except self._cancelled_exc_class:
            future.cancel()
            # Per the concurrent.futures docs, this marks a cancelled future as
            # notified for wait()/as_completed() waiters.
            future.set_running_or_notify_cancel()
        except BaseException as exc:
            if not future.cancelled():
                future.set_exception(exc)

            # Let base exceptions fall through
            if not isinstance(exc, Exception):
                raise
        else:
            if not future.cancelled():
                future.set_result(retval)
        finally:
            # Drop the closure's reference to the cancel scope.
            # NOTE(review): presumably done to break a reference cycle; confirm.
            scope = None  # type: ignore[assignment]

    def _spawn_task_from_thread(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        args: tuple[Unpack[PosArgsT]],
        kwargs: dict[str, Any],
        name: object,
        future: Future[T_Retval],
    ) -> None:
        """
        Spawn a new task using the given callable.

        Implementers must ensure that the future is resolved when the task finishes.

        :param func: a callable
        :param args: positional arguments to be passed to the callable
        :param kwargs: keyword arguments to be passed to the callable
        :param name: name of the task (will be coerced to a string if not ``None``)
        :param future: a future that will resolve to the return value of the callable,
            or the exception raised during its execution

        """
        raise NotImplementedError

    @overload
    def call(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
        *args: Unpack[PosArgsT],
    ) -> T_Retval: ...

    @overload
    def call(
        self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
    ) -> T_Retval: ...

    def call(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        *args: Unpack[PosArgsT],
    ) -> T_Retval:
        """
        Call the given function in the event loop thread.

        If the callable returns a coroutine object, it is awaited on. The calling
        thread blocks until the result is available.

        :param func: any callable
        :param args: positional arguments passed to ``func``
        :return: the return value of ``func`` (or of the awaitable it returned)
        :raises RuntimeError: if the portal is not running or if this method is called
            from within the event loop thread

        """
        return cast(T_Retval, self.start_task_soon(func, *args).result())

    @overload
    def start_task_soon(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> Future[T_Retval]: ...

    @overload
    def start_task_soon(
        self,
        func: Callable[[Unpack[PosArgsT]], T_Retval],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> Future[T_Retval]: ...

    def start_task_soon(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> Future[T_Retval]:
        """
        Start a task in the portal's task group.

        The task will be run inside a cancel scope which can be cancelled by cancelling
        the returned future.

        :param func: the target function
        :param args: positional arguments passed to ``func``
        :param name: name of the task (will be coerced to a string if not ``None``)
        :return: a future that resolves with the return value of the callable if the
            task completes successfully, or with the exception raised in the task
        :raises RuntimeError: if the portal is not running or if this method is called
            from within the event loop thread
        :rtype: concurrent.futures.Future[T_Retval]

        .. versionadded:: 3.0

        """
        self._check_running()
        f: Future[T_Retval] = Future()
        self._spawn_task_from_thread(func, args, {}, name, f)
        return f

    def start_task(
        self,
        func: Callable[..., Awaitable[T_Retval]],
        *args: object,
        name: object = None,
    ) -> tuple[Future[T_Retval], Any]:
        """
        Start a task in the portal's task group and wait until it signals for readiness.

        This method works the same way as :meth:`.abc.TaskGroup.start`.

        :param func: the target function
        :param args: positional arguments passed to ``func``
        :param name: name of the task (will be coerced to a string if not ``None``)
        :return: a tuple of (future, task_status_value) where the ``task_status_value``
            is the value passed to ``task_status.started()`` from within the target
            function
        :raises RuntimeError: if the portal is not running or if this method is called
            from within the event loop thread
        :rtype: tuple[concurrent.futures.Future[T_Retval], Any]

        .. versionadded:: 3.0

        """

        def task_done(future: Future[T_Retval]) -> None:
            # If the task finishes before calling task_status.started(), resolve
            # task_status_future so the caller below does not wait forever. It gets
            # the task's cancellation, the task's exception, or a RuntimeError if the
            # task simply returned.
            if not task_status_future.done():
                if future.cancelled():
                    task_status_future.cancel()
                elif future.exception():
                    task_status_future.set_exception(future.exception())
                else:
                    exc = RuntimeError(
                        "Task exited without calling task_status.started()"
                    )
                    task_status_future.set_exception(exc)

        self._check_running()
        task_status_future: Future = Future()
        task_status = _BlockingPortalTaskStatus(task_status_future)
        f: Future = Future()
        f.add_done_callback(task_done)
        self._spawn_task_from_thread(func, args, {"task_status": task_status}, name, f)
        # Blocks until started() is called or task_done() resolves the future.
        return f, task_status_future.result()

    def wrap_async_context_manager(
        self, cm: AbstractAsyncContextManager[T_co]
    ) -> AbstractContextManager[T_co]:
        """
        Wrap an async context manager as a synchronous context manager via this portal.

        Spawns a task that will call both ``__aenter__()`` and ``__aexit__()``, stopping
        in the middle until the synchronous context manager exits.

        :param cm: an asynchronous context manager
        :return: a synchronous context manager

        .. versionadded:: 2.1

        """
        return _BlockingAsyncContextManager(cm, self)
  327. @dataclass
  328. class BlockingPortalProvider:
  329. """
  330. A manager for a blocking portal. Used as a context manager. The first thread to
  331. enter this context manager causes a blocking portal to be started with the specific
  332. parameters, and the last thread to exit causes the portal to be shut down. Thus,
  333. there will be exactly one blocking portal running in this context as long as at
  334. least one thread has entered this context manager.
  335. The parameters are the same as for :func:`~anyio.run`.
  336. :param backend: name of the backend
  337. :param backend_options: backend options
  338. .. versionadded:: 4.4
  339. """
  340. backend: str = "asyncio"
  341. backend_options: dict[str, Any] | None = None
  342. _lock: Lock = field(init=False, default_factory=Lock)
  343. _leases: int = field(init=False, default=0)
  344. _portal: BlockingPortal = field(init=False)
  345. _portal_cm: AbstractContextManager[BlockingPortal] | None = field(
  346. init=False, default=None
  347. )
  348. def __enter__(self) -> BlockingPortal:
  349. with self._lock:
  350. if self._portal_cm is None:
  351. self._portal_cm = start_blocking_portal(
  352. self.backend, self.backend_options
  353. )
  354. self._portal = self._portal_cm.__enter__()
  355. self._leases += 1
  356. return self._portal
  357. def __exit__(
  358. self,
  359. exc_type: type[BaseException] | None,
  360. exc_val: BaseException | None,
  361. exc_tb: TracebackType | None,
  362. ) -> None:
  363. portal_cm: AbstractContextManager[BlockingPortal] | None = None
  364. with self._lock:
  365. assert self._portal_cm
  366. assert self._leases > 0
  367. self._leases -= 1
  368. if not self._leases:
  369. portal_cm = self._portal_cm
  370. self._portal_cm = None
  371. del self._portal
  372. if portal_cm:
  373. portal_cm.__exit__(None, None, None)
  374. @contextmanager
  375. def start_blocking_portal(
  376. backend: str = "asyncio", backend_options: dict[str, Any] | None = None
  377. ) -> Generator[BlockingPortal, Any, None]:
  378. """
  379. Start a new event loop in a new thread and run a blocking portal in its main task.
  380. The parameters are the same as for :func:`~anyio.run`.
  381. :param backend: name of the backend
  382. :param backend_options: backend options
  383. :return: a context manager that yields a blocking portal
  384. .. versionchanged:: 3.0
  385. Usage as a context manager is now required.
  386. """
  387. async def run_portal() -> None:
  388. async with BlockingPortal() as portal_:
  389. future.set_result(portal_)
  390. await portal_.sleep_until_stopped()
  391. def run_blocking_portal() -> None:
  392. if future.set_running_or_notify_cancel():
  393. try:
  394. _eventloop.run(
  395. run_portal, backend=backend, backend_options=backend_options
  396. )
  397. except BaseException as exc:
  398. if not future.done():
  399. future.set_exception(exc)
  400. future: Future[BlockingPortal] = Future()
  401. thread = Thread(target=run_blocking_portal, daemon=True)
  402. thread.start()
  403. try:
  404. cancel_remaining_tasks = False
  405. portal = future.result()
  406. try:
  407. yield portal
  408. except BaseException:
  409. cancel_remaining_tasks = True
  410. raise
  411. finally:
  412. try:
  413. portal.call(portal.stop, cancel_remaining_tasks)
  414. except RuntimeError:
  415. pass
  416. finally:
  417. thread.join()
  418. def check_cancelled() -> None:
  419. """
  420. Check if the cancel scope of the host task's running the current worker thread has
  421. been cancelled.
  422. If the host task's current cancel scope has indeed been cancelled, the
  423. backend-specific cancellation exception will be raised.
  424. :raises RuntimeError: if the current thread was not spawned by
  425. :func:`.to_thread.run_sync`
  426. """
  427. try:
  428. async_backend: AsyncBackend = threadlocals.current_async_backend
  429. except AttributeError:
  430. raise RuntimeError(
  431. "This function can only be run from an AnyIO worker thread"
  432. ) from None
  433. async_backend.check_cancelled()