_trio.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334
  1. from __future__ import annotations
  2. import array
  3. import math
  4. import os
  5. import socket
  6. import sys
  7. import types
  8. import weakref
  9. from collections.abc import (
  10. AsyncGenerator,
  11. AsyncIterator,
  12. Awaitable,
  13. Callable,
  14. Collection,
  15. Coroutine,
  16. Iterable,
  17. Sequence,
  18. )
  19. from concurrent.futures import Future
  20. from contextlib import AbstractContextManager
  21. from dataclasses import dataclass
  22. from functools import partial
  23. from io import IOBase
  24. from os import PathLike
  25. from signal import Signals
  26. from socket import AddressFamily, SocketKind
  27. from types import TracebackType
  28. from typing import (
  29. IO,
  30. TYPE_CHECKING,
  31. Any,
  32. Generic,
  33. NoReturn,
  34. TypeVar,
  35. cast,
  36. overload,
  37. )
  38. import trio.from_thread
  39. import trio.lowlevel
  40. from outcome import Error, Outcome, Value
  41. from trio.lowlevel import (
  42. current_root_task,
  43. current_task,
  44. wait_readable,
  45. wait_writable,
  46. )
  47. from trio.socket import SocketType as TrioSocketType
  48. from trio.to_thread import run_sync
  49. from .. import (
  50. CapacityLimiterStatistics,
  51. EventStatistics,
  52. LockStatistics,
  53. TaskInfo,
  54. WouldBlock,
  55. abc,
  56. )
  57. from .._core._eventloop import claim_worker_thread
  58. from .._core._exceptions import (
  59. BrokenResourceError,
  60. BusyResourceError,
  61. ClosedResourceError,
  62. EndOfStream,
  63. )
  64. from .._core._sockets import convert_ipv6_sockaddr
  65. from .._core._streams import create_memory_object_stream
  66. from .._core._synchronization import (
  67. CapacityLimiter as BaseCapacityLimiter,
  68. )
  69. from .._core._synchronization import Event as BaseEvent
  70. from .._core._synchronization import Lock as BaseLock
  71. from .._core._synchronization import (
  72. ResourceGuard,
  73. SemaphoreStatistics,
  74. )
  75. from .._core._synchronization import Semaphore as BaseSemaphore
  76. from .._core._tasks import CancelScope as BaseCancelScope
  77. from ..abc import IPSockAddrType, UDPPacketType, UNIXDatagramPacketType
  78. from ..abc._eventloop import AsyncBackend, StrOrBytesPath
  79. from ..streams.memory import MemoryObjectSendStream
  80. if TYPE_CHECKING:
  81. from _typeshed import HasFileno
  82. if sys.version_info >= (3, 10):
  83. from typing import ParamSpec
  84. else:
  85. from typing_extensions import ParamSpec
  86. if sys.version_info >= (3, 11):
  87. from typing import TypeVarTuple, Unpack
  88. else:
  89. from exceptiongroup import BaseExceptionGroup
  90. from typing_extensions import TypeVarTuple, Unpack
  91. T = TypeVar("T")
  92. T_Retval = TypeVar("T_Retval")
  93. T_SockAddr = TypeVar("T_SockAddr", str, IPSockAddrType)
  94. PosArgsT = TypeVarTuple("PosArgsT")
  95. P = ParamSpec("P")
  96. #
  97. # Event loop
  98. #
  99. RunVar = trio.lowlevel.RunVar
  100. #
  101. # Timeouts and cancellation
  102. #
  103. class CancelScope(BaseCancelScope):
  104. def __new__(
  105. cls, original: trio.CancelScope | None = None, **kwargs: object
  106. ) -> CancelScope:
  107. return object.__new__(cls)
  108. def __init__(self, original: trio.CancelScope | None = None, **kwargs: Any) -> None:
  109. self.__original = original or trio.CancelScope(**kwargs)
  110. def __enter__(self) -> CancelScope:
  111. self.__original.__enter__()
  112. return self
  113. def __exit__(
  114. self,
  115. exc_type: type[BaseException] | None,
  116. exc_val: BaseException | None,
  117. exc_tb: TracebackType | None,
  118. ) -> bool:
  119. return self.__original.__exit__(exc_type, exc_val, exc_tb)
  120. def cancel(self) -> None:
  121. self.__original.cancel()
  122. @property
  123. def deadline(self) -> float:
  124. return self.__original.deadline
  125. @deadline.setter
  126. def deadline(self, value: float) -> None:
  127. self.__original.deadline = value
  128. @property
  129. def cancel_called(self) -> bool:
  130. return self.__original.cancel_called
  131. @property
  132. def cancelled_caught(self) -> bool:
  133. return self.__original.cancelled_caught
  134. @property
  135. def shield(self) -> bool:
  136. return self.__original.shield
  137. @shield.setter
  138. def shield(self, value: bool) -> None:
  139. self.__original.shield = value
  140. #
  141. # Task groups
  142. #
  143. class TaskGroup(abc.TaskGroup):
  144. def __init__(self) -> None:
  145. self._active = False
  146. self._nursery_manager = trio.open_nursery(strict_exception_groups=True)
  147. self.cancel_scope = None # type: ignore[assignment]
  148. async def __aenter__(self) -> TaskGroup:
  149. self._active = True
  150. self._nursery = await self._nursery_manager.__aenter__()
  151. self.cancel_scope = CancelScope(self._nursery.cancel_scope)
  152. return self
  153. async def __aexit__(
  154. self,
  155. exc_type: type[BaseException] | None,
  156. exc_val: BaseException | None,
  157. exc_tb: TracebackType | None,
  158. ) -> bool:
  159. try:
  160. # trio.Nursery.__exit__ returns bool; .open_nursery has wrong type
  161. return await self._nursery_manager.__aexit__(exc_type, exc_val, exc_tb) # type: ignore[return-value]
  162. except BaseExceptionGroup as exc:
  163. if not exc.split(trio.Cancelled)[1]:
  164. raise trio.Cancelled._create() from exc
  165. raise
  166. finally:
  167. del exc_val, exc_tb
  168. self._active = False
  169. def start_soon(
  170. self,
  171. func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
  172. *args: Unpack[PosArgsT],
  173. name: object = None,
  174. ) -> None:
  175. if not self._active:
  176. raise RuntimeError(
  177. "This task group is not active; no new tasks can be started."
  178. )
  179. self._nursery.start_soon(func, *args, name=name)
  180. async def start(
  181. self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
  182. ) -> Any:
  183. if not self._active:
  184. raise RuntimeError(
  185. "This task group is not active; no new tasks can be started."
  186. )
  187. return await self._nursery.start(func, *args, name=name)
  188. #
  189. # Threads
  190. #
  191. class BlockingPortal(abc.BlockingPortal):
  192. def __new__(cls) -> BlockingPortal:
  193. return object.__new__(cls)
  194. def __init__(self) -> None:
  195. super().__init__()
  196. self._token = trio.lowlevel.current_trio_token()
  197. def _spawn_task_from_thread(
  198. self,
  199. func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
  200. args: tuple[Unpack[PosArgsT]],
  201. kwargs: dict[str, Any],
  202. name: object,
  203. future: Future[T_Retval],
  204. ) -> None:
  205. trio.from_thread.run_sync(
  206. partial(self._task_group.start_soon, name=name),
  207. self._call_func,
  208. func,
  209. args,
  210. kwargs,
  211. future,
  212. trio_token=self._token,
  213. )
  214. #
  215. # Subprocesses
  216. #
  217. @dataclass(eq=False)
  218. class ReceiveStreamWrapper(abc.ByteReceiveStream):
  219. _stream: trio.abc.ReceiveStream
  220. async def receive(self, max_bytes: int | None = None) -> bytes:
  221. try:
  222. data = await self._stream.receive_some(max_bytes)
  223. except trio.ClosedResourceError as exc:
  224. raise ClosedResourceError from exc.__cause__
  225. except trio.BrokenResourceError as exc:
  226. raise BrokenResourceError from exc.__cause__
  227. if data:
  228. return data
  229. else:
  230. raise EndOfStream
  231. async def aclose(self) -> None:
  232. await self._stream.aclose()
  233. @dataclass(eq=False)
  234. class SendStreamWrapper(abc.ByteSendStream):
  235. _stream: trio.abc.SendStream
  236. async def send(self, item: bytes) -> None:
  237. try:
  238. await self._stream.send_all(item)
  239. except trio.ClosedResourceError as exc:
  240. raise ClosedResourceError from exc.__cause__
  241. except trio.BrokenResourceError as exc:
  242. raise BrokenResourceError from exc.__cause__
  243. async def aclose(self) -> None:
  244. await self._stream.aclose()
  245. @dataclass(eq=False)
  246. class Process(abc.Process):
  247. _process: trio.Process
  248. _stdin: abc.ByteSendStream | None
  249. _stdout: abc.ByteReceiveStream | None
  250. _stderr: abc.ByteReceiveStream | None
  251. async def aclose(self) -> None:
  252. with CancelScope(shield=True):
  253. if self._stdin:
  254. await self._stdin.aclose()
  255. if self._stdout:
  256. await self._stdout.aclose()
  257. if self._stderr:
  258. await self._stderr.aclose()
  259. try:
  260. await self.wait()
  261. except BaseException:
  262. self.kill()
  263. with CancelScope(shield=True):
  264. await self.wait()
  265. raise
  266. async def wait(self) -> int:
  267. return await self._process.wait()
  268. def terminate(self) -> None:
  269. self._process.terminate()
  270. def kill(self) -> None:
  271. self._process.kill()
  272. def send_signal(self, signal: Signals) -> None:
  273. self._process.send_signal(signal)
  274. @property
  275. def pid(self) -> int:
  276. return self._process.pid
  277. @property
  278. def returncode(self) -> int | None:
  279. return self._process.returncode
  280. @property
  281. def stdin(self) -> abc.ByteSendStream | None:
  282. return self._stdin
  283. @property
  284. def stdout(self) -> abc.ByteReceiveStream | None:
  285. return self._stdout
  286. @property
  287. def stderr(self) -> abc.ByteReceiveStream | None:
  288. return self._stderr
  289. class _ProcessPoolShutdownInstrument(trio.abc.Instrument):
  290. def after_run(self) -> None:
  291. super().after_run()
  292. current_default_worker_process_limiter: trio.lowlevel.RunVar = RunVar(
  293. "current_default_worker_process_limiter"
  294. )
  295. async def _shutdown_process_pool(workers: set[abc.Process]) -> None:
  296. try:
  297. await trio.sleep(math.inf)
  298. except trio.Cancelled:
  299. for process in workers:
  300. if process.returncode is None:
  301. process.kill()
  302. with CancelScope(shield=True):
  303. for process in workers:
  304. await process.aclose()
  305. #
  306. # Sockets and networking
  307. #
  308. class _TrioSocketMixin(Generic[T_SockAddr]):
  309. def __init__(self, trio_socket: TrioSocketType) -> None:
  310. self._trio_socket = trio_socket
  311. self._closed = False
  312. def _check_closed(self) -> None:
  313. if self._closed:
  314. raise ClosedResourceError
  315. if self._trio_socket.fileno() < 0:
  316. raise BrokenResourceError
  317. @property
  318. def _raw_socket(self) -> socket.socket:
  319. return self._trio_socket._sock # type: ignore[attr-defined]
  320. async def aclose(self) -> None:
  321. if self._trio_socket.fileno() >= 0:
  322. self._closed = True
  323. self._trio_socket.close()
  324. def _convert_socket_error(self, exc: BaseException) -> NoReturn:
  325. if isinstance(exc, trio.ClosedResourceError):
  326. raise ClosedResourceError from exc
  327. elif self._trio_socket.fileno() < 0 and self._closed:
  328. raise ClosedResourceError from None
  329. elif isinstance(exc, OSError):
  330. raise BrokenResourceError from exc
  331. else:
  332. raise exc
  333. class SocketStream(_TrioSocketMixin, abc.SocketStream):
  334. def __init__(self, trio_socket: TrioSocketType) -> None:
  335. super().__init__(trio_socket)
  336. self._receive_guard = ResourceGuard("reading from")
  337. self._send_guard = ResourceGuard("writing to")
  338. async def receive(self, max_bytes: int = 65536) -> bytes:
  339. with self._receive_guard:
  340. try:
  341. data = await self._trio_socket.recv(max_bytes)
  342. except BaseException as exc:
  343. self._convert_socket_error(exc)
  344. if data:
  345. return data
  346. else:
  347. raise EndOfStream
  348. async def send(self, item: bytes) -> None:
  349. with self._send_guard:
  350. view = memoryview(item)
  351. while view:
  352. try:
  353. bytes_sent = await self._trio_socket.send(view)
  354. except BaseException as exc:
  355. self._convert_socket_error(exc)
  356. view = view[bytes_sent:]
  357. async def send_eof(self) -> None:
  358. self._trio_socket.shutdown(socket.SHUT_WR)
  359. class UNIXSocketStream(SocketStream, abc.UNIXSocketStream):
  360. async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
  361. if not isinstance(msglen, int) or msglen < 0:
  362. raise ValueError("msglen must be a non-negative integer")
  363. if not isinstance(maxfds, int) or maxfds < 1:
  364. raise ValueError("maxfds must be a positive integer")
  365. fds = array.array("i")
  366. await trio.lowlevel.checkpoint()
  367. with self._receive_guard:
  368. while True:
  369. try:
  370. message, ancdata, flags, addr = await self._trio_socket.recvmsg(
  371. msglen, socket.CMSG_LEN(maxfds * fds.itemsize)
  372. )
  373. except BaseException as exc:
  374. self._convert_socket_error(exc)
  375. else:
  376. if not message and not ancdata:
  377. raise EndOfStream
  378. break
  379. for cmsg_level, cmsg_type, cmsg_data in ancdata:
  380. if cmsg_level != socket.SOL_SOCKET or cmsg_type != socket.SCM_RIGHTS:
  381. raise RuntimeError(
  382. f"Received unexpected ancillary data; message = {message!r}, "
  383. f"cmsg_level = {cmsg_level}, cmsg_type = {cmsg_type}"
  384. )
  385. fds.frombytes(cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
  386. return message, list(fds)
  387. async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
  388. if not message:
  389. raise ValueError("message must not be empty")
  390. if not fds:
  391. raise ValueError("fds must not be empty")
  392. filenos: list[int] = []
  393. for fd in fds:
  394. if isinstance(fd, int):
  395. filenos.append(fd)
  396. elif isinstance(fd, IOBase):
  397. filenos.append(fd.fileno())
  398. fdarray = array.array("i", filenos)
  399. await trio.lowlevel.checkpoint()
  400. with self._send_guard:
  401. while True:
  402. try:
  403. await self._trio_socket.sendmsg(
  404. [message],
  405. [
  406. (
  407. socket.SOL_SOCKET,
  408. socket.SCM_RIGHTS,
  409. fdarray,
  410. )
  411. ],
  412. )
  413. break
  414. except BaseException as exc:
  415. self._convert_socket_error(exc)
  416. class TCPSocketListener(_TrioSocketMixin, abc.SocketListener):
  417. def __init__(self, raw_socket: socket.socket):
  418. super().__init__(trio.socket.from_stdlib_socket(raw_socket))
  419. self._accept_guard = ResourceGuard("accepting connections from")
  420. async def accept(self) -> SocketStream:
  421. with self._accept_guard:
  422. try:
  423. trio_socket, _addr = await self._trio_socket.accept()
  424. except BaseException as exc:
  425. self._convert_socket_error(exc)
  426. trio_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
  427. return SocketStream(trio_socket)
  428. class UNIXSocketListener(_TrioSocketMixin, abc.SocketListener):
  429. def __init__(self, raw_socket: socket.socket):
  430. super().__init__(trio.socket.from_stdlib_socket(raw_socket))
  431. self._accept_guard = ResourceGuard("accepting connections from")
  432. async def accept(self) -> UNIXSocketStream:
  433. with self._accept_guard:
  434. try:
  435. trio_socket, _addr = await self._trio_socket.accept()
  436. except BaseException as exc:
  437. self._convert_socket_error(exc)
  438. return UNIXSocketStream(trio_socket)
  439. class UDPSocket(_TrioSocketMixin[IPSockAddrType], abc.UDPSocket):
  440. def __init__(self, trio_socket: TrioSocketType) -> None:
  441. super().__init__(trio_socket)
  442. self._receive_guard = ResourceGuard("reading from")
  443. self._send_guard = ResourceGuard("writing to")
  444. async def receive(self) -> tuple[bytes, IPSockAddrType]:
  445. with self._receive_guard:
  446. try:
  447. data, addr = await self._trio_socket.recvfrom(65536)
  448. return data, convert_ipv6_sockaddr(addr)
  449. except BaseException as exc:
  450. self._convert_socket_error(exc)
  451. async def send(self, item: UDPPacketType) -> None:
  452. with self._send_guard:
  453. try:
  454. await self._trio_socket.sendto(*item)
  455. except BaseException as exc:
  456. self._convert_socket_error(exc)
  457. class ConnectedUDPSocket(_TrioSocketMixin[IPSockAddrType], abc.ConnectedUDPSocket):
  458. def __init__(self, trio_socket: TrioSocketType) -> None:
  459. super().__init__(trio_socket)
  460. self._receive_guard = ResourceGuard("reading from")
  461. self._send_guard = ResourceGuard("writing to")
  462. async def receive(self) -> bytes:
  463. with self._receive_guard:
  464. try:
  465. return await self._trio_socket.recv(65536)
  466. except BaseException as exc:
  467. self._convert_socket_error(exc)
  468. async def send(self, item: bytes) -> None:
  469. with self._send_guard:
  470. try:
  471. await self._trio_socket.send(item)
  472. except BaseException as exc:
  473. self._convert_socket_error(exc)
  474. class UNIXDatagramSocket(_TrioSocketMixin[str], abc.UNIXDatagramSocket):
  475. def __init__(self, trio_socket: TrioSocketType) -> None:
  476. super().__init__(trio_socket)
  477. self._receive_guard = ResourceGuard("reading from")
  478. self._send_guard = ResourceGuard("writing to")
  479. async def receive(self) -> UNIXDatagramPacketType:
  480. with self._receive_guard:
  481. try:
  482. data, addr = await self._trio_socket.recvfrom(65536)
  483. return data, addr
  484. except BaseException as exc:
  485. self._convert_socket_error(exc)
  486. async def send(self, item: UNIXDatagramPacketType) -> None:
  487. with self._send_guard:
  488. try:
  489. await self._trio_socket.sendto(*item)
  490. except BaseException as exc:
  491. self._convert_socket_error(exc)
  492. class ConnectedUNIXDatagramSocket(
  493. _TrioSocketMixin[str], abc.ConnectedUNIXDatagramSocket
  494. ):
  495. def __init__(self, trio_socket: TrioSocketType) -> None:
  496. super().__init__(trio_socket)
  497. self._receive_guard = ResourceGuard("reading from")
  498. self._send_guard = ResourceGuard("writing to")
  499. async def receive(self) -> bytes:
  500. with self._receive_guard:
  501. try:
  502. return await self._trio_socket.recv(65536)
  503. except BaseException as exc:
  504. self._convert_socket_error(exc)
  505. async def send(self, item: bytes) -> None:
  506. with self._send_guard:
  507. try:
  508. await self._trio_socket.send(item)
  509. except BaseException as exc:
  510. self._convert_socket_error(exc)
  511. #
  512. # Synchronization
  513. #
  514. class Event(BaseEvent):
  515. def __new__(cls) -> Event:
  516. return object.__new__(cls)
  517. def __init__(self) -> None:
  518. self.__original = trio.Event()
  519. def is_set(self) -> bool:
  520. return self.__original.is_set()
  521. async def wait(self) -> None:
  522. return await self.__original.wait()
  523. def statistics(self) -> EventStatistics:
  524. orig_statistics = self.__original.statistics()
  525. return EventStatistics(tasks_waiting=orig_statistics.tasks_waiting)
  526. def set(self) -> None:
  527. self.__original.set()
  528. class Lock(BaseLock):
  529. def __new__(cls, *, fast_acquire: bool = False) -> Lock:
  530. return object.__new__(cls)
  531. def __init__(self, *, fast_acquire: bool = False) -> None:
  532. self._fast_acquire = fast_acquire
  533. self.__original = trio.Lock()
  534. @staticmethod
  535. def _convert_runtime_error_msg(exc: RuntimeError) -> None:
  536. if exc.args == ("attempt to re-acquire an already held Lock",):
  537. exc.args = ("Attempted to acquire an already held Lock",)
  538. async def acquire(self) -> None:
  539. if not self._fast_acquire:
  540. try:
  541. await self.__original.acquire()
  542. except RuntimeError as exc:
  543. self._convert_runtime_error_msg(exc)
  544. raise
  545. return
  546. # This is the "fast path" where we don't let other tasks run
  547. await trio.lowlevel.checkpoint_if_cancelled()
  548. try:
  549. self.__original.acquire_nowait()
  550. except trio.WouldBlock:
  551. await self.__original._lot.park()
  552. except RuntimeError as exc:
  553. self._convert_runtime_error_msg(exc)
  554. raise
  555. def acquire_nowait(self) -> None:
  556. try:
  557. self.__original.acquire_nowait()
  558. except trio.WouldBlock:
  559. raise WouldBlock from None
  560. except RuntimeError as exc:
  561. self._convert_runtime_error_msg(exc)
  562. raise
  563. def locked(self) -> bool:
  564. return self.__original.locked()
  565. def release(self) -> None:
  566. self.__original.release()
  567. def statistics(self) -> LockStatistics:
  568. orig_statistics = self.__original.statistics()
  569. owner = TrioTaskInfo(orig_statistics.owner) if orig_statistics.owner else None
  570. return LockStatistics(
  571. orig_statistics.locked, owner, orig_statistics.tasks_waiting
  572. )
  573. class Semaphore(BaseSemaphore):
  574. def __new__(
  575. cls,
  576. initial_value: int,
  577. *,
  578. max_value: int | None = None,
  579. fast_acquire: bool = False,
  580. ) -> Semaphore:
  581. return object.__new__(cls)
  582. def __init__(
  583. self,
  584. initial_value: int,
  585. *,
  586. max_value: int | None = None,
  587. fast_acquire: bool = False,
  588. ) -> None:
  589. super().__init__(initial_value, max_value=max_value, fast_acquire=fast_acquire)
  590. self.__original = trio.Semaphore(initial_value, max_value=max_value)
  591. async def acquire(self) -> None:
  592. if not self._fast_acquire:
  593. await self.__original.acquire()
  594. return
  595. # This is the "fast path" where we don't let other tasks run
  596. await trio.lowlevel.checkpoint_if_cancelled()
  597. try:
  598. self.__original.acquire_nowait()
  599. except trio.WouldBlock:
  600. await self.__original._lot.park()
  601. def acquire_nowait(self) -> None:
  602. try:
  603. self.__original.acquire_nowait()
  604. except trio.WouldBlock:
  605. raise WouldBlock from None
  606. @property
  607. def max_value(self) -> int | None:
  608. return self.__original.max_value
  609. @property
  610. def value(self) -> int:
  611. return self.__original.value
  612. def release(self) -> None:
  613. self.__original.release()
  614. def statistics(self) -> SemaphoreStatistics:
  615. orig_statistics = self.__original.statistics()
  616. return SemaphoreStatistics(orig_statistics.tasks_waiting)
  617. class CapacityLimiter(BaseCapacityLimiter):
  618. def __new__(
  619. cls,
  620. total_tokens: float | None = None,
  621. *,
  622. original: trio.CapacityLimiter | None = None,
  623. ) -> CapacityLimiter:
  624. return object.__new__(cls)
  625. def __init__(
  626. self,
  627. total_tokens: float | None = None,
  628. *,
  629. original: trio.CapacityLimiter | None = None,
  630. ) -> None:
  631. if original is not None:
  632. self.__original = original
  633. else:
  634. assert total_tokens is not None
  635. self.__original = trio.CapacityLimiter(total_tokens)
  636. async def __aenter__(self) -> None:
  637. return await self.__original.__aenter__()
  638. async def __aexit__(
  639. self,
  640. exc_type: type[BaseException] | None,
  641. exc_val: BaseException | None,
  642. exc_tb: TracebackType | None,
  643. ) -> None:
  644. await self.__original.__aexit__(exc_type, exc_val, exc_tb)
  645. @property
  646. def total_tokens(self) -> float:
  647. return self.__original.total_tokens
  648. @total_tokens.setter
  649. def total_tokens(self, value: float) -> None:
  650. self.__original.total_tokens = value
  651. @property
  652. def borrowed_tokens(self) -> int:
  653. return self.__original.borrowed_tokens
  654. @property
  655. def available_tokens(self) -> float:
  656. return self.__original.available_tokens
  657. def acquire_nowait(self) -> None:
  658. self.__original.acquire_nowait()
  659. def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
  660. self.__original.acquire_on_behalf_of_nowait(borrower)
  661. async def acquire(self) -> None:
  662. await self.__original.acquire()
  663. async def acquire_on_behalf_of(self, borrower: object) -> None:
  664. await self.__original.acquire_on_behalf_of(borrower)
  665. def release(self) -> None:
  666. return self.__original.release()
  667. def release_on_behalf_of(self, borrower: object) -> None:
  668. return self.__original.release_on_behalf_of(borrower)
  669. def statistics(self) -> CapacityLimiterStatistics:
  670. orig = self.__original.statistics()
  671. return CapacityLimiterStatistics(
  672. borrowed_tokens=orig.borrowed_tokens,
  673. total_tokens=orig.total_tokens,
  674. borrowers=tuple(orig.borrowers),
  675. tasks_waiting=orig.tasks_waiting,
  676. )
  677. _capacity_limiter_wrapper: trio.lowlevel.RunVar = RunVar("_capacity_limiter_wrapper")
  678. #
  679. # Signal handling
  680. #
  681. class _SignalReceiver:
  682. _iterator: AsyncIterator[int]
  683. def __init__(self, signals: tuple[Signals, ...]):
  684. self._signals = signals
  685. def __enter__(self) -> _SignalReceiver:
  686. self._cm = trio.open_signal_receiver(*self._signals)
  687. self._iterator = self._cm.__enter__()
  688. return self
  689. def __exit__(
  690. self,
  691. exc_type: type[BaseException] | None,
  692. exc_val: BaseException | None,
  693. exc_tb: TracebackType | None,
  694. ) -> bool | None:
  695. return self._cm.__exit__(exc_type, exc_val, exc_tb)
  696. def __aiter__(self) -> _SignalReceiver:
  697. return self
  698. async def __anext__(self) -> Signals:
  699. signum = await self._iterator.__anext__()
  700. return Signals(signum)
  701. #
  702. # Testing and debugging
  703. #
  704. class TestRunner(abc.TestRunner):
  705. def __init__(self, **options: Any) -> None:
  706. from queue import Queue
  707. self._call_queue: Queue[Callable[[], object]] = Queue()
  708. self._send_stream: MemoryObjectSendStream | None = None
  709. self._options = options
  710. def __exit__(
  711. self,
  712. exc_type: type[BaseException] | None,
  713. exc_val: BaseException | None,
  714. exc_tb: types.TracebackType | None,
  715. ) -> None:
  716. if self._send_stream:
  717. self._send_stream.close()
  718. while self._send_stream is not None:
  719. self._call_queue.get()()
  720. async def _run_tests_and_fixtures(self) -> None:
  721. self._send_stream, receive_stream = create_memory_object_stream(1)
  722. with receive_stream:
  723. async for coro, outcome_holder in receive_stream:
  724. try:
  725. retval = await coro
  726. except BaseException as exc:
  727. outcome_holder.append(Error(exc))
  728. else:
  729. outcome_holder.append(Value(retval))
  730. def _main_task_finished(self, outcome: object) -> None:
  731. self._send_stream = None
  732. def _call_in_runner_task(
  733. self,
  734. func: Callable[P, Awaitable[T_Retval]],
  735. *args: P.args,
  736. **kwargs: P.kwargs,
  737. ) -> T_Retval:
  738. if self._send_stream is None:
  739. trio.lowlevel.start_guest_run(
  740. self._run_tests_and_fixtures,
  741. run_sync_soon_threadsafe=self._call_queue.put,
  742. done_callback=self._main_task_finished,
  743. **self._options,
  744. )
  745. while self._send_stream is None:
  746. self._call_queue.get()()
  747. outcome_holder: list[Outcome] = []
  748. self._send_stream.send_nowait((func(*args, **kwargs), outcome_holder))
  749. while not outcome_holder:
  750. self._call_queue.get()()
  751. return outcome_holder[0].unwrap()
  752. def run_asyncgen_fixture(
  753. self,
  754. fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]],
  755. kwargs: dict[str, Any],
  756. ) -> Iterable[T_Retval]:
  757. asyncgen = fixture_func(**kwargs)
  758. fixturevalue: T_Retval = self._call_in_runner_task(asyncgen.asend, None)
  759. yield fixturevalue
  760. try:
  761. self._call_in_runner_task(asyncgen.asend, None)
  762. except StopAsyncIteration:
  763. pass
  764. else:
  765. self._call_in_runner_task(asyncgen.aclose)
  766. raise RuntimeError("Async generator fixture did not stop")
  767. def run_fixture(
  768. self,
  769. fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]],
  770. kwargs: dict[str, Any],
  771. ) -> T_Retval:
  772. return self._call_in_runner_task(fixture_func, **kwargs)
  773. def run_test(
  774. self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any]
  775. ) -> None:
  776. self._call_in_runner_task(test_func, **kwargs)
  777. class TrioTaskInfo(TaskInfo):
  778. def __init__(self, task: trio.lowlevel.Task):
  779. parent_id = None
  780. if task.parent_nursery and task.parent_nursery.parent_task:
  781. parent_id = id(task.parent_nursery.parent_task)
  782. super().__init__(id(task), parent_id, task.name, task.coro)
  783. self._task = weakref.proxy(task)
  784. def has_pending_cancellation(self) -> bool:
  785. try:
  786. return self._task._cancel_status.effectively_cancelled
  787. except ReferenceError:
  788. # If the task is no longer around, it surely doesn't have a cancellation
  789. # pending
  790. return False
  791. class TrioBackend(AsyncBackend):
  792. @classmethod
  793. def run(
  794. cls,
  795. func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
  796. args: tuple[Unpack[PosArgsT]],
  797. kwargs: dict[str, Any],
  798. options: dict[str, Any],
  799. ) -> T_Retval:
  800. return trio.run(func, *args)
  801. @classmethod
  802. def current_token(cls) -> object:
  803. return trio.lowlevel.current_trio_token()
  804. @classmethod
  805. def current_time(cls) -> float:
  806. return trio.current_time()
  807. @classmethod
  808. def cancelled_exception_class(cls) -> type[BaseException]:
  809. return trio.Cancelled
  810. @classmethod
  811. async def checkpoint(cls) -> None:
  812. await trio.lowlevel.checkpoint()
  813. @classmethod
  814. async def checkpoint_if_cancelled(cls) -> None:
  815. await trio.lowlevel.checkpoint_if_cancelled()
  816. @classmethod
  817. async def cancel_shielded_checkpoint(cls) -> None:
  818. await trio.lowlevel.cancel_shielded_checkpoint()
  819. @classmethod
  820. async def sleep(cls, delay: float) -> None:
  821. await trio.sleep(delay)
  822. @classmethod
  823. def create_cancel_scope(
  824. cls, *, deadline: float = math.inf, shield: bool = False
  825. ) -> abc.CancelScope:
  826. return CancelScope(deadline=deadline, shield=shield)
  827. @classmethod
  828. def current_effective_deadline(cls) -> float:
  829. return trio.current_effective_deadline()
  830. @classmethod
  831. def create_task_group(cls) -> abc.TaskGroup:
  832. return TaskGroup()
  833. @classmethod
  834. def create_event(cls) -> abc.Event:
  835. return Event()
  836. @classmethod
  837. def create_lock(cls, *, fast_acquire: bool) -> Lock:
  838. return Lock(fast_acquire=fast_acquire)
  839. @classmethod
  840. def create_semaphore(
  841. cls,
  842. initial_value: int,
  843. *,
  844. max_value: int | None = None,
  845. fast_acquire: bool = False,
  846. ) -> abc.Semaphore:
  847. return Semaphore(initial_value, max_value=max_value, fast_acquire=fast_acquire)
  848. @classmethod
  849. def create_capacity_limiter(cls, total_tokens: float) -> CapacityLimiter:
  850. return CapacityLimiter(total_tokens)
  851. @classmethod
  852. async def run_sync_in_worker_thread(
  853. cls,
  854. func: Callable[[Unpack[PosArgsT]], T_Retval],
  855. args: tuple[Unpack[PosArgsT]],
  856. abandon_on_cancel: bool = False,
  857. limiter: abc.CapacityLimiter | None = None,
  858. ) -> T_Retval:
  859. def wrapper() -> T_Retval:
  860. with claim_worker_thread(TrioBackend, token):
  861. return func(*args)
  862. token = TrioBackend.current_token()
  863. return await run_sync(
  864. wrapper,
  865. abandon_on_cancel=abandon_on_cancel,
  866. limiter=cast(trio.CapacityLimiter, limiter),
  867. )
  868. @classmethod
  869. def check_cancelled(cls) -> None:
  870. trio.from_thread.check_cancelled()
  871. @classmethod
  872. def run_async_from_thread(
  873. cls,
  874. func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
  875. args: tuple[Unpack[PosArgsT]],
  876. token: object,
  877. ) -> T_Retval:
  878. return trio.from_thread.run(func, *args)
  879. @classmethod
  880. def run_sync_from_thread(
  881. cls,
  882. func: Callable[[Unpack[PosArgsT]], T_Retval],
  883. args: tuple[Unpack[PosArgsT]],
  884. token: object,
  885. ) -> T_Retval:
  886. return trio.from_thread.run_sync(func, *args)
  887. @classmethod
  888. def create_blocking_portal(cls) -> abc.BlockingPortal:
  889. return BlockingPortal()
  890. @classmethod
  891. async def open_process(
  892. cls,
  893. command: StrOrBytesPath | Sequence[StrOrBytesPath],
  894. *,
  895. stdin: int | IO[Any] | None,
  896. stdout: int | IO[Any] | None,
  897. stderr: int | IO[Any] | None,
  898. **kwargs: Any,
  899. ) -> Process:
  900. def convert_item(item: StrOrBytesPath) -> str:
  901. str_or_bytes = os.fspath(item)
  902. if isinstance(str_or_bytes, str):
  903. return str_or_bytes
  904. else:
  905. return os.fsdecode(str_or_bytes)
  906. if isinstance(command, (str, bytes, PathLike)):
  907. process = await trio.lowlevel.open_process(
  908. convert_item(command),
  909. stdin=stdin,
  910. stdout=stdout,
  911. stderr=stderr,
  912. shell=True,
  913. **kwargs,
  914. )
  915. else:
  916. process = await trio.lowlevel.open_process(
  917. [convert_item(item) for item in command],
  918. stdin=stdin,
  919. stdout=stdout,
  920. stderr=stderr,
  921. shell=False,
  922. **kwargs,
  923. )
  924. stdin_stream = SendStreamWrapper(process.stdin) if process.stdin else None
  925. stdout_stream = ReceiveStreamWrapper(process.stdout) if process.stdout else None
  926. stderr_stream = ReceiveStreamWrapper(process.stderr) if process.stderr else None
  927. return Process(process, stdin_stream, stdout_stream, stderr_stream)
  928. @classmethod
  929. def setup_process_pool_exit_at_shutdown(cls, workers: set[abc.Process]) -> None:
  930. trio.lowlevel.spawn_system_task(_shutdown_process_pool, workers)
  931. @classmethod
  932. async def connect_tcp(
  933. cls, host: str, port: int, local_address: IPSockAddrType | None = None
  934. ) -> SocketStream:
  935. family = socket.AF_INET6 if ":" in host else socket.AF_INET
  936. trio_socket = trio.socket.socket(family)
  937. trio_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
  938. if local_address:
  939. await trio_socket.bind(local_address)
  940. try:
  941. await trio_socket.connect((host, port))
  942. except BaseException:
  943. trio_socket.close()
  944. raise
  945. return SocketStream(trio_socket)
  946. @classmethod
  947. async def connect_unix(cls, path: str | bytes) -> abc.UNIXSocketStream:
  948. trio_socket = trio.socket.socket(socket.AF_UNIX)
  949. try:
  950. await trio_socket.connect(path)
  951. except BaseException:
  952. trio_socket.close()
  953. raise
  954. return UNIXSocketStream(trio_socket)
  955. @classmethod
  956. def create_tcp_listener(cls, sock: socket.socket) -> abc.SocketListener:
  957. return TCPSocketListener(sock)
  958. @classmethod
  959. def create_unix_listener(cls, sock: socket.socket) -> abc.SocketListener:
  960. return UNIXSocketListener(sock)
  961. @classmethod
  962. async def create_udp_socket(
  963. cls,
  964. family: socket.AddressFamily,
  965. local_address: IPSockAddrType | None,
  966. remote_address: IPSockAddrType | None,
  967. reuse_port: bool,
  968. ) -> UDPSocket | ConnectedUDPSocket:
  969. trio_socket = trio.socket.socket(family=family, type=socket.SOCK_DGRAM)
  970. if reuse_port:
  971. trio_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
  972. if local_address:
  973. await trio_socket.bind(local_address)
  974. if remote_address:
  975. await trio_socket.connect(remote_address)
  976. return ConnectedUDPSocket(trio_socket)
  977. else:
  978. return UDPSocket(trio_socket)
  979. @classmethod
  980. @overload
  981. async def create_unix_datagram_socket(
  982. cls, raw_socket: socket.socket, remote_path: None
  983. ) -> abc.UNIXDatagramSocket: ...
  984. @classmethod
  985. @overload
  986. async def create_unix_datagram_socket(
  987. cls, raw_socket: socket.socket, remote_path: str | bytes
  988. ) -> abc.ConnectedUNIXDatagramSocket: ...
  989. @classmethod
  990. async def create_unix_datagram_socket(
  991. cls, raw_socket: socket.socket, remote_path: str | bytes | None
  992. ) -> abc.UNIXDatagramSocket | abc.ConnectedUNIXDatagramSocket:
  993. trio_socket = trio.socket.from_stdlib_socket(raw_socket)
  994. if remote_path:
  995. await trio_socket.connect(remote_path)
  996. return ConnectedUNIXDatagramSocket(trio_socket)
  997. else:
  998. return UNIXDatagramSocket(trio_socket)
  999. @classmethod
  1000. async def getaddrinfo(
  1001. cls,
  1002. host: bytes | str | None,
  1003. port: str | int | None,
  1004. *,
  1005. family: int | AddressFamily = 0,
  1006. type: int | SocketKind = 0,
  1007. proto: int = 0,
  1008. flags: int = 0,
  1009. ) -> Sequence[
  1010. tuple[
  1011. AddressFamily,
  1012. SocketKind,
  1013. int,
  1014. str,
  1015. tuple[str, int] | tuple[str, int, int, int] | tuple[int, bytes],
  1016. ]
  1017. ]:
  1018. return await trio.socket.getaddrinfo(host, port, family, type, proto, flags)
  1019. @classmethod
  1020. async def getnameinfo(
  1021. cls, sockaddr: IPSockAddrType, flags: int = 0
  1022. ) -> tuple[str, str]:
  1023. return await trio.socket.getnameinfo(sockaddr, flags)
  1024. @classmethod
  1025. async def wait_readable(cls, obj: HasFileno | int) -> None:
  1026. try:
  1027. await wait_readable(obj)
  1028. except trio.ClosedResourceError as exc:
  1029. raise ClosedResourceError().with_traceback(exc.__traceback__) from None
  1030. except trio.BusyResourceError:
  1031. raise BusyResourceError("reading from") from None
  1032. @classmethod
  1033. async def wait_writable(cls, obj: HasFileno | int) -> None:
  1034. try:
  1035. await wait_writable(obj)
  1036. except trio.ClosedResourceError as exc:
  1037. raise ClosedResourceError().with_traceback(exc.__traceback__) from None
  1038. except trio.BusyResourceError:
  1039. raise BusyResourceError("writing to") from None
  1040. @classmethod
  1041. def current_default_thread_limiter(cls) -> CapacityLimiter:
  1042. try:
  1043. return _capacity_limiter_wrapper.get()
  1044. except LookupError:
  1045. limiter = CapacityLimiter(
  1046. original=trio.to_thread.current_default_thread_limiter()
  1047. )
  1048. _capacity_limiter_wrapper.set(limiter)
  1049. return limiter
  1050. @classmethod
  1051. def open_signal_receiver(
  1052. cls, *signals: Signals
  1053. ) -> AbstractContextManager[AsyncIterator[Signals]]:
  1054. return _SignalReceiver(signals)
  1055. @classmethod
  1056. def get_current_task(cls) -> TaskInfo:
  1057. task = current_task()
  1058. return TrioTaskInfo(task)
  1059. @classmethod
  1060. def get_running_tasks(cls) -> Sequence[TaskInfo]:
  1061. root_task = current_root_task()
  1062. assert root_task
  1063. task_infos = [TrioTaskInfo(root_task)]
  1064. nurseries = root_task.child_nurseries
  1065. while nurseries:
  1066. new_nurseries: list[trio.Nursery] = []
  1067. for nursery in nurseries:
  1068. for task in nursery.child_tasks:
  1069. task_infos.append(TrioTaskInfo(task))
  1070. new_nurseries.extend(task.child_nurseries)
  1071. nurseries = new_nurseries
  1072. return task_infos
  1073. @classmethod
  1074. async def wait_all_tasks_blocked(cls) -> None:
  1075. from trio.testing import wait_all_tasks_blocked
  1076. await wait_all_tasks_blocked()
  1077. @classmethod
  1078. def create_test_runner(cls, options: dict[str, Any]) -> TestRunner:
  1079. return TestRunner(**options)
# Module-level alias for the backend implementation class defined above.
# NOTE(review): presumably this is the name AnyIO's backend loader looks up
# on this module — verify against the loader.
backend_class = TrioBackend