to_process.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. from __future__ import annotations
  2. import os
  3. import pickle
  4. import subprocess
  5. import sys
  6. from collections import deque
  7. from collections.abc import Callable
  8. from importlib.util import module_from_spec, spec_from_file_location
  9. from typing import TypeVar, cast
  10. from ._core._eventloop import current_time, get_async_backend, get_cancelled_exc_class
  11. from ._core._exceptions import BrokenWorkerProcess
  12. from ._core._subprocesses import open_process
  13. from ._core._synchronization import CapacityLimiter
  14. from ._core._tasks import CancelScope, fail_after
  15. from .abc import ByteReceiveStream, ByteSendStream, Process
  16. from .lowlevel import RunVar, checkpoint_if_cancelled
  17. from .streams.buffered import BufferedByteReceiveStream
  18. if sys.version_info >= (3, 11):
  19. from typing import TypeVarTuple, Unpack
  20. else:
  21. from typing_extensions import TypeVarTuple, Unpack
  22. WORKER_MAX_IDLE_TIME = 300 # 5 minutes
  23. T_Retval = TypeVar("T_Retval")
  24. PosArgsT = TypeVarTuple("PosArgsT")
  25. _process_pool_workers: RunVar[set[Process]] = RunVar("_process_pool_workers")
  26. _process_pool_idle_workers: RunVar[deque[tuple[Process, float]]] = RunVar(
  27. "_process_pool_idle_workers"
  28. )
  29. _default_process_limiter: RunVar[CapacityLimiter] = RunVar("_default_process_limiter")
  30. async def run_sync( # type: ignore[return]
  31. func: Callable[[Unpack[PosArgsT]], T_Retval],
  32. *args: Unpack[PosArgsT],
  33. cancellable: bool = False,
  34. limiter: CapacityLimiter | None = None,
  35. ) -> T_Retval:
  36. """
  37. Call the given function with the given arguments in a worker process.
  38. If the ``cancellable`` option is enabled and the task waiting for its completion is
  39. cancelled, the worker process running it will be abruptly terminated using SIGKILL
  40. (or ``terminateProcess()`` on Windows).
  41. :param func: a callable
  42. :param args: positional arguments for the callable
  43. :param cancellable: ``True`` to allow cancellation of the operation while it's
  44. running
  45. :param limiter: capacity limiter to use to limit the total amount of processes
  46. running (if omitted, the default limiter is used)
  47. :return: an awaitable that yields the return value of the function.
  48. """
  49. async def send_raw_command(pickled_cmd: bytes) -> object:
  50. try:
  51. await stdin.send(pickled_cmd)
  52. response = await buffered.receive_until(b"\n", 50)
  53. status, length = response.split(b" ")
  54. if status not in (b"RETURN", b"EXCEPTION"):
  55. raise RuntimeError(
  56. f"Worker process returned unexpected response: {response!r}"
  57. )
  58. pickled_response = await buffered.receive_exactly(int(length))
  59. except BaseException as exc:
  60. workers.discard(process)
  61. try:
  62. process.kill()
  63. with CancelScope(shield=True):
  64. await process.aclose()
  65. except ProcessLookupError:
  66. pass
  67. if isinstance(exc, get_cancelled_exc_class()):
  68. raise
  69. else:
  70. raise BrokenWorkerProcess from exc
  71. retval = pickle.loads(pickled_response)
  72. if status == b"EXCEPTION":
  73. assert isinstance(retval, BaseException)
  74. raise retval
  75. else:
  76. return retval
  77. # First pickle the request before trying to reserve a worker process
  78. await checkpoint_if_cancelled()
  79. request = pickle.dumps(("run", func, args), protocol=pickle.HIGHEST_PROTOCOL)
  80. # If this is the first run in this event loop thread, set up the necessary variables
  81. try:
  82. workers = _process_pool_workers.get()
  83. idle_workers = _process_pool_idle_workers.get()
  84. except LookupError:
  85. workers = set()
  86. idle_workers = deque()
  87. _process_pool_workers.set(workers)
  88. _process_pool_idle_workers.set(idle_workers)
  89. get_async_backend().setup_process_pool_exit_at_shutdown(workers)
  90. async with limiter or current_default_process_limiter():
  91. # Pop processes from the pool (starting from the most recently used) until we
  92. # find one that hasn't exited yet
  93. process: Process
  94. while idle_workers:
  95. process, idle_since = idle_workers.pop()
  96. if process.returncode is None:
  97. stdin = cast(ByteSendStream, process.stdin)
  98. buffered = BufferedByteReceiveStream(
  99. cast(ByteReceiveStream, process.stdout)
  100. )
  101. # Prune any other workers that have been idle for WORKER_MAX_IDLE_TIME
  102. # seconds or longer
  103. now = current_time()
  104. killed_processes: list[Process] = []
  105. while idle_workers:
  106. if now - idle_workers[0][1] < WORKER_MAX_IDLE_TIME:
  107. break
  108. process_to_kill, idle_since = idle_workers.popleft()
  109. process_to_kill.kill()
  110. workers.remove(process_to_kill)
  111. killed_processes.append(process_to_kill)
  112. with CancelScope(shield=True):
  113. for killed_process in killed_processes:
  114. await killed_process.aclose()
  115. break
  116. workers.remove(process)
  117. else:
  118. command = [sys.executable, "-u", "-m", __name__]
  119. process = await open_process(
  120. command, stdin=subprocess.PIPE, stdout=subprocess.PIPE
  121. )
  122. try:
  123. stdin = cast(ByteSendStream, process.stdin)
  124. buffered = BufferedByteReceiveStream(
  125. cast(ByteReceiveStream, process.stdout)
  126. )
  127. with fail_after(20):
  128. message = await buffered.receive(6)
  129. if message != b"READY\n":
  130. raise BrokenWorkerProcess(
  131. f"Worker process returned unexpected response: {message!r}"
  132. )
  133. main_module_path = getattr(sys.modules["__main__"], "__file__", None)
  134. pickled = pickle.dumps(
  135. ("init", sys.path, main_module_path),
  136. protocol=pickle.HIGHEST_PROTOCOL,
  137. )
  138. await send_raw_command(pickled)
  139. except (BrokenWorkerProcess, get_cancelled_exc_class()):
  140. raise
  141. except BaseException as exc:
  142. process.kill()
  143. raise BrokenWorkerProcess(
  144. "Error during worker process initialization"
  145. ) from exc
  146. workers.add(process)
  147. with CancelScope(shield=not cancellable):
  148. try:
  149. return cast(T_Retval, await send_raw_command(request))
  150. finally:
  151. if process in workers:
  152. idle_workers.append((process, current_time()))
  153. def current_default_process_limiter() -> CapacityLimiter:
  154. """
  155. Return the capacity limiter that is used by default to limit the number of worker
  156. processes.
  157. :return: a capacity limiter object
  158. """
  159. try:
  160. return _default_process_limiter.get()
  161. except LookupError:
  162. limiter = CapacityLimiter(os.cpu_count() or 2)
  163. _default_process_limiter.set(limiter)
  164. return limiter
  165. def process_worker() -> None:
  166. # Redirect standard streams to os.devnull so that user code won't interfere with the
  167. # parent-worker communication
  168. stdin = sys.stdin
  169. stdout = sys.stdout
  170. sys.stdin = open(os.devnull)
  171. sys.stdout = open(os.devnull, "w")
  172. stdout.buffer.write(b"READY\n")
  173. while True:
  174. retval = exception = None
  175. try:
  176. command, *args = pickle.load(stdin.buffer)
  177. except EOFError:
  178. return
  179. except BaseException as exc:
  180. exception = exc
  181. else:
  182. if command == "run":
  183. func, args = args
  184. try:
  185. retval = func(*args)
  186. except BaseException as exc:
  187. exception = exc
  188. elif command == "init":
  189. main_module_path: str | None
  190. sys.path, main_module_path = args
  191. del sys.modules["__main__"]
  192. if main_module_path and os.path.isfile(main_module_path):
  193. # Load the parent's main module but as __mp_main__ instead of
  194. # __main__ (like multiprocessing does) to avoid infinite recursion
  195. try:
  196. spec = spec_from_file_location("__mp_main__", main_module_path)
  197. if spec and spec.loader:
  198. main = module_from_spec(spec)
  199. spec.loader.exec_module(main)
  200. sys.modules["__main__"] = main
  201. except BaseException as exc:
  202. exception = exc
  203. try:
  204. if exception is not None:
  205. status = b"EXCEPTION"
  206. pickled = pickle.dumps(exception, pickle.HIGHEST_PROTOCOL)
  207. else:
  208. status = b"RETURN"
  209. pickled = pickle.dumps(retval, pickle.HIGHEST_PROTOCOL)
  210. except BaseException as exc:
  211. exception = exc
  212. status = b"EXCEPTION"
  213. pickled = pickle.dumps(exc, pickle.HIGHEST_PROTOCOL)
  214. stdout.buffer.write(b"%s %d\n" % (status, len(pickled)))
  215. stdout.buffer.write(pickled)
  216. # Respect SIGTERM
  217. if isinstance(exception, SystemExit):
  218. raise exception
  219. if __name__ == "__main__":
  220. process_worker()