_server.py 50 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528
  1. # Copyright 2016 gRPC authors.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Service-side implementation of gRPC Python."""
  15. from __future__ import annotations
  16. import abc
  17. import collections
  18. from concurrent import futures
  19. import contextvars
  20. import enum
  21. import logging
  22. import threading
  23. import time
  24. import traceback
  25. from typing import (
  26. Any,
  27. Callable,
  28. Dict,
  29. Iterable,
  30. Iterator,
  31. List,
  32. Mapping,
  33. Optional,
  34. Sequence,
  35. Set,
  36. Tuple,
  37. Union,
  38. )
  39. import grpc # pytype: disable=pyi-error
  40. from grpc import _common # pytype: disable=pyi-error
  41. from grpc import _compression # pytype: disable=pyi-error
  42. from grpc import _interceptor # pytype: disable=pyi-error
  43. from grpc import _observability # pytype: disable=pyi-error
  44. from grpc._cython import cygrpc
  45. from grpc._typing import ArityAgnosticMethodHandler
  46. from grpc._typing import ChannelArgumentType
  47. from grpc._typing import DeserializingFunction
  48. from grpc._typing import MetadataType
  49. from grpc._typing import NullaryCallbackType
  50. from grpc._typing import ResponseType
  51. from grpc._typing import SerializingFunction
  52. from grpc._typing import ServerCallbackTag
  53. from grpc._typing import ServerTagCallbackType
  54. _LOGGER = logging.getLogger(__name__)
  55. _SHUTDOWN_TAG = "shutdown"
  56. _REQUEST_CALL_TAG = "request_call"
  57. _RECEIVE_CLOSE_ON_SERVER_TOKEN = "receive_close_on_server"
  58. _SEND_INITIAL_METADATA_TOKEN = "send_initial_metadata"
  59. _RECEIVE_MESSAGE_TOKEN = "receive_message"
  60. _SEND_MESSAGE_TOKEN = "send_message"
  61. _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN = (
  62. "send_initial_metadata * send_message"
  63. )
  64. _SEND_STATUS_FROM_SERVER_TOKEN = "send_status_from_server"
  65. _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN = (
  66. "send_initial_metadata * send_status_from_server"
  67. )
  68. _OPEN = "open"
  69. _CLOSED = "closed"
  70. _CANCELLED = "cancelled"
  71. _EMPTY_FLAGS = 0
  72. _DEALLOCATED_SERVER_CHECK_PERIOD_S = 1.0
  73. _INF_TIMEOUT = 1e9
  74. def _serialized_request(request_event: cygrpc.BaseEvent) -> bytes:
  75. return request_event.batch_operations[0].message()
  76. def _application_code(code: grpc.StatusCode) -> cygrpc.StatusCode:
  77. cygrpc_code = _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE.get(code)
  78. return cygrpc.StatusCode.unknown if cygrpc_code is None else cygrpc_code
  79. def _completion_code(state: _RPCState) -> cygrpc.StatusCode:
  80. if state.code is None:
  81. return cygrpc.StatusCode.ok
  82. else:
  83. return _application_code(state.code)
  84. def _abortion_code(
  85. state: _RPCState, code: cygrpc.StatusCode
  86. ) -> cygrpc.StatusCode:
  87. if state.code is None:
  88. return code
  89. else:
  90. return _application_code(state.code)
  91. def _details(state: _RPCState) -> bytes:
  92. return b"" if state.details is None else state.details
  93. class _HandlerCallDetails(
  94. collections.namedtuple(
  95. "_HandlerCallDetails",
  96. (
  97. "method",
  98. "invocation_metadata",
  99. ),
  100. ),
  101. grpc.HandlerCallDetails,
  102. ):
  103. pass
  104. class _Method(abc.ABC):
  105. @abc.abstractmethod
  106. def name(self) -> Optional[str]:
  107. raise NotImplementedError()
  108. @abc.abstractmethod
  109. def handler(
  110. self, handler_call_details: _HandlerCallDetails
  111. ) -> Optional[grpc.RpcMethodHandler]:
  112. raise NotImplementedError()
  113. class _RegisteredMethod(_Method):
  114. def __init__(
  115. self,
  116. name: str,
  117. registered_handler: Optional[grpc.RpcMethodHandler],
  118. ):
  119. self._name = name
  120. self._registered_handler = registered_handler
  121. def name(self) -> Optional[str]:
  122. return self._name
  123. def handler(
  124. self, handler_call_details: _HandlerCallDetails
  125. ) -> Optional[grpc.RpcMethodHandler]:
  126. return self._registered_handler
  127. class _GenericMethod(_Method):
  128. def __init__(
  129. self,
  130. generic_handlers: List[grpc.GenericRpcHandler],
  131. ):
  132. self._generic_handlers = generic_handlers
  133. def name(self) -> Optional[str]:
  134. return None
  135. def handler(
  136. self, handler_call_details: _HandlerCallDetails
  137. ) -> Optional[grpc.RpcMethodHandler]:
  138. # If the same method have both generic and registered handler,
  139. # registered handler will take precedence.
  140. for generic_handler in self._generic_handlers:
  141. method_handler = generic_handler.service(handler_call_details)
  142. if method_handler is not None:
  143. return method_handler
  144. return None
  145. class _RPCState(object):
  146. context: contextvars.Context
  147. condition: threading.Condition
  148. due = Set[str]
  149. request: Any
  150. client: str
  151. initial_metadata_allowed: bool
  152. compression_algorithm: Optional[grpc.Compression]
  153. disable_next_compression: bool
  154. trailing_metadata: Optional[MetadataType]
  155. code: Optional[grpc.StatusCode]
  156. details: Optional[bytes]
  157. statused: bool
  158. rpc_errors: List[Exception]
  159. callbacks: Optional[List[NullaryCallbackType]]
  160. aborted: bool
  161. def __init__(self):
  162. self.context = contextvars.Context()
  163. self.condition = threading.Condition()
  164. self.due = set()
  165. self.request = None
  166. self.client = _OPEN
  167. self.initial_metadata_allowed = True
  168. self.compression_algorithm = None
  169. self.disable_next_compression = False
  170. self.trailing_metadata = None
  171. self.code = None
  172. self.details = None
  173. self.statused = False
  174. self.rpc_errors = []
  175. self.callbacks = []
  176. self.aborted = False
  177. def _raise_rpc_error(state: _RPCState) -> None:
  178. rpc_error = grpc.RpcError()
  179. state.rpc_errors.append(rpc_error)
  180. raise rpc_error
  181. def _possibly_finish_call(
  182. state: _RPCState, token: str
  183. ) -> ServerTagCallbackType:
  184. state.due.remove(token)
  185. if not _is_rpc_state_active(state) and not state.due:
  186. callbacks = state.callbacks
  187. state.callbacks = None
  188. return state, callbacks
  189. else:
  190. return None, ()
  191. def _send_status_from_server(state: _RPCState, token: str) -> ServerCallbackTag:
  192. def send_status_from_server(unused_send_status_from_server_event):
  193. with state.condition:
  194. return _possibly_finish_call(state, token)
  195. return send_status_from_server
  196. def _get_initial_metadata(
  197. state: _RPCState, metadata: Optional[MetadataType]
  198. ) -> Optional[MetadataType]:
  199. with state.condition:
  200. if state.compression_algorithm:
  201. compression_metadata = (
  202. _compression.compression_algorithm_to_metadata(
  203. state.compression_algorithm
  204. ),
  205. )
  206. if metadata is None:
  207. return compression_metadata
  208. else:
  209. return compression_metadata + tuple(metadata)
  210. else:
  211. return metadata
  212. def _get_initial_metadata_operation(
  213. state: _RPCState, metadata: Optional[MetadataType]
  214. ) -> cygrpc.Operation:
  215. operation = cygrpc.SendInitialMetadataOperation(
  216. _get_initial_metadata(state, metadata), _EMPTY_FLAGS
  217. )
  218. return operation
  219. def _abort(
  220. state: _RPCState, call: cygrpc.Call, code: cygrpc.StatusCode, details: bytes
  221. ) -> None:
  222. if state.client is not _CANCELLED:
  223. effective_code = _abortion_code(state, code)
  224. effective_details = details if state.details is None else state.details
  225. if state.initial_metadata_allowed:
  226. operations = (
  227. _get_initial_metadata_operation(state, None),
  228. cygrpc.SendStatusFromServerOperation(
  229. state.trailing_metadata,
  230. effective_code,
  231. effective_details,
  232. _EMPTY_FLAGS,
  233. ),
  234. )
  235. token = _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN
  236. else:
  237. operations = (
  238. cygrpc.SendStatusFromServerOperation(
  239. state.trailing_metadata,
  240. effective_code,
  241. effective_details,
  242. _EMPTY_FLAGS,
  243. ),
  244. )
  245. token = _SEND_STATUS_FROM_SERVER_TOKEN
  246. call.start_server_batch(
  247. operations, _send_status_from_server(state, token)
  248. )
  249. state.statused = True
  250. state.due.add(token)
  251. def _receive_close_on_server(state: _RPCState) -> ServerCallbackTag:
  252. def receive_close_on_server(receive_close_on_server_event):
  253. with state.condition:
  254. if receive_close_on_server_event.batch_operations[0].cancelled():
  255. state.client = _CANCELLED
  256. elif state.client is _OPEN:
  257. state.client = _CLOSED
  258. state.condition.notify_all()
  259. return _possibly_finish_call(state, _RECEIVE_CLOSE_ON_SERVER_TOKEN)
  260. return receive_close_on_server
  261. def _receive_message(
  262. state: _RPCState,
  263. call: cygrpc.Call,
  264. request_deserializer: Optional[DeserializingFunction],
  265. ) -> ServerCallbackTag:
  266. def receive_message(receive_message_event):
  267. serialized_request = _serialized_request(receive_message_event)
  268. if serialized_request is None:
  269. with state.condition:
  270. if state.client is _OPEN:
  271. state.client = _CLOSED
  272. state.condition.notify_all()
  273. return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN)
  274. else:
  275. request = _common.deserialize(
  276. serialized_request, request_deserializer
  277. )
  278. with state.condition:
  279. if request is None:
  280. _abort(
  281. state,
  282. call,
  283. cygrpc.StatusCode.internal,
  284. b"Exception deserializing request!",
  285. )
  286. else:
  287. state.request = request
  288. state.condition.notify_all()
  289. return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN)
  290. return receive_message
  291. def _send_initial_metadata(state: _RPCState) -> ServerCallbackTag:
  292. def send_initial_metadata(unused_send_initial_metadata_event):
  293. with state.condition:
  294. return _possibly_finish_call(state, _SEND_INITIAL_METADATA_TOKEN)
  295. return send_initial_metadata
  296. def _send_message(state: _RPCState, token: str) -> ServerCallbackTag:
  297. def send_message(unused_send_message_event):
  298. with state.condition:
  299. state.condition.notify_all()
  300. return _possibly_finish_call(state, token)
  301. return send_message
  302. class _Context(grpc.ServicerContext):
  303. _rpc_event: cygrpc.BaseEvent
  304. _state: _RPCState
  305. request_deserializer: Optional[DeserializingFunction]
  306. def __init__(
  307. self,
  308. rpc_event: cygrpc.BaseEvent,
  309. state: _RPCState,
  310. request_deserializer: Optional[DeserializingFunction],
  311. ):
  312. self._rpc_event = rpc_event
  313. self._state = state
  314. self._request_deserializer = request_deserializer
  315. def is_active(self) -> bool:
  316. with self._state.condition:
  317. return _is_rpc_state_active(self._state)
  318. def time_remaining(self) -> float:
  319. return max(self._rpc_event.call_details.deadline - time.time(), 0)
  320. def cancel(self) -> None:
  321. self._rpc_event.call.cancel()
  322. def add_callback(self, callback: NullaryCallbackType) -> bool:
  323. with self._state.condition:
  324. if self._state.callbacks is None:
  325. return False
  326. else:
  327. self._state.callbacks.append(callback)
  328. return True
  329. def disable_next_message_compression(self) -> None:
  330. with self._state.condition:
  331. self._state.disable_next_compression = True
  332. def invocation_metadata(self) -> Optional[MetadataType]:
  333. return self._rpc_event.invocation_metadata
  334. def peer(self) -> str:
  335. return _common.decode(self._rpc_event.call.peer())
  336. def peer_identities(self) -> Optional[Sequence[bytes]]:
  337. return cygrpc.peer_identities(self._rpc_event.call)
  338. def peer_identity_key(self) -> Optional[str]:
  339. id_key = cygrpc.peer_identity_key(self._rpc_event.call)
  340. return id_key if id_key is None else _common.decode(id_key)
  341. def auth_context(self) -> Mapping[str, Sequence[bytes]]:
  342. auth_context = cygrpc.auth_context(self._rpc_event.call)
  343. auth_context_dict = {} if auth_context is None else auth_context
  344. return {
  345. _common.decode(key): value
  346. for key, value in auth_context_dict.items()
  347. }
  348. def set_compression(self, compression: grpc.Compression) -> None:
  349. with self._state.condition:
  350. self._state.compression_algorithm = compression
  351. def send_initial_metadata(self, initial_metadata: MetadataType) -> None:
  352. with self._state.condition:
  353. if self._state.client is _CANCELLED:
  354. _raise_rpc_error(self._state)
  355. else:
  356. if self._state.initial_metadata_allowed:
  357. operation = _get_initial_metadata_operation(
  358. self._state, initial_metadata
  359. )
  360. self._rpc_event.call.start_server_batch(
  361. (operation,), _send_initial_metadata(self._state)
  362. )
  363. self._state.initial_metadata_allowed = False
  364. self._state.due.add(_SEND_INITIAL_METADATA_TOKEN)
  365. else:
  366. raise ValueError("Initial metadata no longer allowed!")
  367. def set_trailing_metadata(self, trailing_metadata: MetadataType) -> None:
  368. with self._state.condition:
  369. self._state.trailing_metadata = trailing_metadata
  370. def trailing_metadata(self) -> Optional[MetadataType]:
  371. return self._state.trailing_metadata
  372. def abort(self, code: grpc.StatusCode, details: str) -> None:
  373. # treat OK like other invalid arguments: fail the RPC
  374. if code == grpc.StatusCode.OK:
  375. _LOGGER.error(
  376. "abort() called with StatusCode.OK; returning UNKNOWN"
  377. )
  378. code = grpc.StatusCode.UNKNOWN
  379. details = ""
  380. with self._state.condition:
  381. self._state.code = code
  382. self._state.details = _common.encode(details)
  383. self._state.aborted = True
  384. raise Exception()
  385. def abort_with_status(self, status: grpc.Status) -> None:
  386. self._state.trailing_metadata = status.trailing_metadata
  387. self.abort(status.code, status.details)
  388. def set_code(self, code: grpc.StatusCode) -> None:
  389. with self._state.condition:
  390. self._state.code = code
  391. def code(self) -> grpc.StatusCode:
  392. return self._state.code
  393. def set_details(self, details: str) -> None:
  394. with self._state.condition:
  395. self._state.details = _common.encode(details)
  396. def details(self) -> bytes:
  397. return self._state.details
  398. def _finalize_state(self) -> None:
  399. pass
  400. class _RequestIterator(object):
  401. _state: _RPCState
  402. _call: cygrpc.Call
  403. _request_deserializer: Optional[DeserializingFunction]
  404. def __init__(
  405. self,
  406. state: _RPCState,
  407. call: cygrpc.Call,
  408. request_deserializer: Optional[DeserializingFunction],
  409. ):
  410. self._state = state
  411. self._call = call
  412. self._request_deserializer = request_deserializer
  413. def _raise_or_start_receive_message(self) -> None:
  414. if self._state.client is _CANCELLED:
  415. _raise_rpc_error(self._state)
  416. elif not _is_rpc_state_active(self._state):
  417. raise StopIteration()
  418. else:
  419. self._call.start_server_batch(
  420. (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),),
  421. _receive_message(
  422. self._state, self._call, self._request_deserializer
  423. ),
  424. )
  425. self._state.due.add(_RECEIVE_MESSAGE_TOKEN)
  426. def _look_for_request(self) -> Any:
  427. if self._state.client is _CANCELLED:
  428. _raise_rpc_error(self._state)
  429. elif (
  430. self._state.request is None
  431. and _RECEIVE_MESSAGE_TOKEN not in self._state.due
  432. ):
  433. raise StopIteration()
  434. else:
  435. request = self._state.request
  436. self._state.request = None
  437. return request
  438. raise AssertionError() # should never run
  439. def _next(self) -> Any:
  440. with self._state.condition:
  441. self._raise_or_start_receive_message()
  442. while True:
  443. self._state.condition.wait()
  444. request = self._look_for_request()
  445. if request is not None:
  446. return request
  447. def __iter__(self) -> _RequestIterator:
  448. return self
  449. def __next__(self) -> Any:
  450. return self._next()
  451. def next(self) -> Any:
  452. return self._next()
  453. def _unary_request(
  454. rpc_event: cygrpc.BaseEvent,
  455. state: _RPCState,
  456. request_deserializer: Optional[DeserializingFunction],
  457. ) -> Callable[[], Any]:
  458. def unary_request():
  459. with state.condition:
  460. if not _is_rpc_state_active(state):
  461. return None
  462. else:
  463. rpc_event.call.start_server_batch(
  464. (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),),
  465. _receive_message(
  466. state, rpc_event.call, request_deserializer
  467. ),
  468. )
  469. state.due.add(_RECEIVE_MESSAGE_TOKEN)
  470. while True:
  471. state.condition.wait()
  472. if state.request is None:
  473. if state.client is _CLOSED:
  474. details = '"{}" requires exactly one request message.'.format(
  475. rpc_event.call_details.method
  476. )
  477. _abort(
  478. state,
  479. rpc_event.call,
  480. cygrpc.StatusCode.unimplemented,
  481. _common.encode(details),
  482. )
  483. return None
  484. elif state.client is _CANCELLED:
  485. return None
  486. else:
  487. request = state.request
  488. state.request = None
  489. return request
  490. return unary_request
  491. def _call_behavior(
  492. rpc_event: cygrpc.BaseEvent,
  493. state: _RPCState,
  494. behavior: ArityAgnosticMethodHandler,
  495. argument: Any,
  496. request_deserializer: Optional[DeserializingFunction],
  497. send_response_callback: Optional[Callable[[ResponseType], None]] = None,
  498. ) -> Tuple[Union[ResponseType, Iterator[ResponseType]], bool]:
  499. from grpc import _create_servicer_context # pytype: disable=pyi-error
  500. with _create_servicer_context(
  501. rpc_event, state, request_deserializer
  502. ) as context:
  503. try:
  504. response_or_iterator = None
  505. if send_response_callback is not None:
  506. response_or_iterator = behavior(
  507. argument, context, send_response_callback
  508. )
  509. else:
  510. response_or_iterator = behavior(argument, context)
  511. return response_or_iterator, True
  512. except Exception as exception: # pylint: disable=broad-except
  513. with state.condition:
  514. if state.aborted:
  515. _abort(
  516. state,
  517. rpc_event.call,
  518. cygrpc.StatusCode.unknown,
  519. b"RPC Aborted",
  520. )
  521. elif exception not in state.rpc_errors:
  522. try:
  523. details = "Exception calling application: {}".format(
  524. exception
  525. )
  526. except Exception: # pylint: disable=broad-except
  527. details = (
  528. "Calling application raised unprintable Exception!"
  529. )
  530. _LOGGER.exception(
  531. traceback.format_exception(
  532. type(exception),
  533. exception,
  534. exception.__traceback__,
  535. )
  536. )
  537. traceback.print_exc()
  538. _LOGGER.exception(details)
  539. _abort(
  540. state,
  541. rpc_event.call,
  542. cygrpc.StatusCode.unknown,
  543. _common.encode(details),
  544. )
  545. return None, False
  546. def _take_response_from_response_iterator(
  547. rpc_event: cygrpc.BaseEvent,
  548. state: _RPCState,
  549. response_iterator: Iterator[ResponseType],
  550. ) -> Tuple[ResponseType, bool]:
  551. try:
  552. return next(response_iterator), True
  553. except StopIteration:
  554. return None, True
  555. except Exception as exception: # pylint: disable=broad-except
  556. with state.condition:
  557. if state.aborted:
  558. _abort(
  559. state,
  560. rpc_event.call,
  561. cygrpc.StatusCode.unknown,
  562. b"RPC Aborted",
  563. )
  564. elif exception not in state.rpc_errors:
  565. details = "Exception iterating responses: {}".format(exception)
  566. _LOGGER.exception(details)
  567. _abort(
  568. state,
  569. rpc_event.call,
  570. cygrpc.StatusCode.unknown,
  571. _common.encode(details),
  572. )
  573. return None, False
  574. def _serialize_response(
  575. rpc_event: cygrpc.BaseEvent,
  576. state: _RPCState,
  577. response: Any,
  578. response_serializer: Optional[SerializingFunction],
  579. ) -> Optional[bytes]:
  580. serialized_response = _common.serialize(response, response_serializer)
  581. if serialized_response is None:
  582. with state.condition:
  583. _abort(
  584. state,
  585. rpc_event.call,
  586. cygrpc.StatusCode.internal,
  587. b"Failed to serialize response!",
  588. )
  589. return None
  590. else:
  591. return serialized_response
  592. def _get_send_message_op_flags_from_state(
  593. state: _RPCState,
  594. ) -> Union[int, cygrpc.WriteFlag]:
  595. if state.disable_next_compression:
  596. return cygrpc.WriteFlag.no_compress
  597. else:
  598. return _EMPTY_FLAGS
  599. def _reset_per_message_state(state: _RPCState) -> None:
  600. with state.condition:
  601. state.disable_next_compression = False
  602. def _send_response(
  603. rpc_event: cygrpc.BaseEvent, state: _RPCState, serialized_response: bytes
  604. ) -> bool:
  605. with state.condition:
  606. if not _is_rpc_state_active(state):
  607. return False
  608. else:
  609. if state.initial_metadata_allowed:
  610. operations = (
  611. _get_initial_metadata_operation(state, None),
  612. cygrpc.SendMessageOperation(
  613. serialized_response,
  614. _get_send_message_op_flags_from_state(state),
  615. ),
  616. )
  617. state.initial_metadata_allowed = False
  618. token = _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN
  619. else:
  620. operations = (
  621. cygrpc.SendMessageOperation(
  622. serialized_response,
  623. _get_send_message_op_flags_from_state(state),
  624. ),
  625. )
  626. token = _SEND_MESSAGE_TOKEN
  627. rpc_event.call.start_server_batch(
  628. operations, _send_message(state, token)
  629. )
  630. state.due.add(token)
  631. _reset_per_message_state(state)
  632. while True:
  633. state.condition.wait()
  634. if token not in state.due:
  635. return _is_rpc_state_active(state)
  636. def _status(
  637. rpc_event: cygrpc.BaseEvent,
  638. state: _RPCState,
  639. serialized_response: Optional[bytes],
  640. ) -> None:
  641. with state.condition:
  642. if state.client is not _CANCELLED:
  643. code = _completion_code(state)
  644. details = _details(state)
  645. operations = [
  646. cygrpc.SendStatusFromServerOperation(
  647. state.trailing_metadata, code, details, _EMPTY_FLAGS
  648. ),
  649. ]
  650. if state.initial_metadata_allowed:
  651. operations.append(_get_initial_metadata_operation(state, None))
  652. if serialized_response is not None:
  653. operations.append(
  654. cygrpc.SendMessageOperation(
  655. serialized_response,
  656. _get_send_message_op_flags_from_state(state),
  657. )
  658. )
  659. rpc_event.call.start_server_batch(
  660. operations,
  661. _send_status_from_server(state, _SEND_STATUS_FROM_SERVER_TOKEN),
  662. )
  663. state.statused = True
  664. _reset_per_message_state(state)
  665. state.due.add(_SEND_STATUS_FROM_SERVER_TOKEN)
  666. def _unary_response_in_pool(
  667. rpc_event: cygrpc.BaseEvent,
  668. state: _RPCState,
  669. behavior: ArityAgnosticMethodHandler,
  670. argument_thunk: Callable[[], Any],
  671. request_deserializer: Optional[SerializingFunction],
  672. response_serializer: Optional[SerializingFunction],
  673. ) -> None:
  674. cygrpc.install_context_from_request_call_event(rpc_event)
  675. try:
  676. argument = argument_thunk()
  677. if argument is not None:
  678. response, proceed = _call_behavior(
  679. rpc_event, state, behavior, argument, request_deserializer
  680. )
  681. if proceed:
  682. serialized_response = _serialize_response(
  683. rpc_event, state, response, response_serializer
  684. )
  685. if serialized_response is not None:
  686. _status(rpc_event, state, serialized_response)
  687. except Exception: # pylint: disable=broad-except
  688. traceback.print_exc()
  689. finally:
  690. cygrpc.uninstall_context()
  691. def _stream_response_in_pool(
  692. rpc_event: cygrpc.BaseEvent,
  693. state: _RPCState,
  694. behavior: ArityAgnosticMethodHandler,
  695. argument_thunk: Callable[[], Any],
  696. request_deserializer: Optional[DeserializingFunction],
  697. response_serializer: Optional[SerializingFunction],
  698. ) -> None:
  699. cygrpc.install_context_from_request_call_event(rpc_event)
  700. def send_response(response: Any) -> None:
  701. if response is None:
  702. _status(rpc_event, state, None)
  703. else:
  704. serialized_response = _serialize_response(
  705. rpc_event, state, response, response_serializer
  706. )
  707. if serialized_response is not None:
  708. _send_response(rpc_event, state, serialized_response)
  709. try:
  710. argument = argument_thunk()
  711. if argument is not None:
  712. if (
  713. hasattr(behavior, "experimental_non_blocking")
  714. and behavior.experimental_non_blocking
  715. ):
  716. _call_behavior(
  717. rpc_event,
  718. state,
  719. behavior,
  720. argument,
  721. request_deserializer,
  722. send_response_callback=send_response,
  723. )
  724. else:
  725. response_iterator, proceed = _call_behavior(
  726. rpc_event, state, behavior, argument, request_deserializer
  727. )
  728. if proceed:
  729. _send_message_callback_to_blocking_iterator_adapter(
  730. rpc_event, state, send_response, response_iterator
  731. )
  732. except Exception: # pylint: disable=broad-except
  733. traceback.print_exc()
  734. finally:
  735. cygrpc.uninstall_context()
  736. def _is_rpc_state_active(state: _RPCState) -> bool:
  737. return state.client is not _CANCELLED and not state.statused
  738. def _send_message_callback_to_blocking_iterator_adapter(
  739. rpc_event: cygrpc.BaseEvent,
  740. state: _RPCState,
  741. send_response_callback: Callable[[ResponseType], None],
  742. response_iterator: Iterator[ResponseType],
  743. ) -> None:
  744. while True:
  745. response, proceed = _take_response_from_response_iterator(
  746. rpc_event, state, response_iterator
  747. )
  748. if proceed:
  749. send_response_callback(response)
  750. if not _is_rpc_state_active(state):
  751. break
  752. else:
  753. break
  754. def _select_thread_pool_for_behavior(
  755. behavior: ArityAgnosticMethodHandler,
  756. default_thread_pool: futures.ThreadPoolExecutor,
  757. ) -> futures.ThreadPoolExecutor:
  758. if hasattr(behavior, "experimental_thread_pool") and isinstance(
  759. behavior.experimental_thread_pool, futures.ThreadPoolExecutor
  760. ):
  761. return behavior.experimental_thread_pool
  762. else:
  763. return default_thread_pool
  764. def _handle_unary_unary(
  765. rpc_event: cygrpc.BaseEvent,
  766. state: _RPCState,
  767. method_handler: grpc.RpcMethodHandler,
  768. default_thread_pool: futures.ThreadPoolExecutor,
  769. ) -> futures.Future:
  770. unary_request = _unary_request(
  771. rpc_event, state, method_handler.request_deserializer
  772. )
  773. thread_pool = _select_thread_pool_for_behavior(
  774. method_handler.unary_unary, default_thread_pool
  775. )
  776. return thread_pool.submit(
  777. state.context.run,
  778. _unary_response_in_pool,
  779. rpc_event,
  780. state,
  781. method_handler.unary_unary,
  782. unary_request,
  783. method_handler.request_deserializer,
  784. method_handler.response_serializer,
  785. )
  786. def _handle_unary_stream(
  787. rpc_event: cygrpc.BaseEvent,
  788. state: _RPCState,
  789. method_handler: grpc.RpcMethodHandler,
  790. default_thread_pool: futures.ThreadPoolExecutor,
  791. ) -> futures.Future:
  792. unary_request = _unary_request(
  793. rpc_event, state, method_handler.request_deserializer
  794. )
  795. thread_pool = _select_thread_pool_for_behavior(
  796. method_handler.unary_stream, default_thread_pool
  797. )
  798. return thread_pool.submit(
  799. state.context.run,
  800. _stream_response_in_pool,
  801. rpc_event,
  802. state,
  803. method_handler.unary_stream,
  804. unary_request,
  805. method_handler.request_deserializer,
  806. method_handler.response_serializer,
  807. )
  808. def _handle_stream_unary(
  809. rpc_event: cygrpc.BaseEvent,
  810. state: _RPCState,
  811. method_handler: grpc.RpcMethodHandler,
  812. default_thread_pool: futures.ThreadPoolExecutor,
  813. ) -> futures.Future:
  814. request_iterator = _RequestIterator(
  815. state, rpc_event.call, method_handler.request_deserializer
  816. )
  817. thread_pool = _select_thread_pool_for_behavior(
  818. method_handler.stream_unary, default_thread_pool
  819. )
  820. return thread_pool.submit(
  821. state.context.run,
  822. _unary_response_in_pool,
  823. rpc_event,
  824. state,
  825. method_handler.stream_unary,
  826. lambda: request_iterator,
  827. method_handler.request_deserializer,
  828. method_handler.response_serializer,
  829. )
  830. def _handle_stream_stream(
  831. rpc_event: cygrpc.BaseEvent,
  832. state: _RPCState,
  833. method_handler: grpc.RpcMethodHandler,
  834. default_thread_pool: futures.ThreadPoolExecutor,
  835. ) -> futures.Future:
  836. request_iterator = _RequestIterator(
  837. state, rpc_event.call, method_handler.request_deserializer
  838. )
  839. thread_pool = _select_thread_pool_for_behavior(
  840. method_handler.stream_stream, default_thread_pool
  841. )
  842. return thread_pool.submit(
  843. state.context.run,
  844. _stream_response_in_pool,
  845. rpc_event,
  846. state,
  847. method_handler.stream_stream,
  848. lambda: request_iterator,
  849. method_handler.request_deserializer,
  850. method_handler.response_serializer,
  851. )
  852. def _find_method_handler(
  853. rpc_event: cygrpc.BaseEvent,
  854. state: _RPCState,
  855. method_with_handler: _Method,
  856. interceptor_pipeline: Optional[_interceptor._ServicePipeline],
  857. ) -> Optional[grpc.RpcMethodHandler]:
  858. def query_handlers(
  859. handler_call_details: _HandlerCallDetails,
  860. ) -> Optional[grpc.RpcMethodHandler]:
  861. return method_with_handler.handler(handler_call_details)
  862. method_name = method_with_handler.name()
  863. if not method_name:
  864. method_name = _common.decode(rpc_event.call_details.method)
  865. handler_call_details = _HandlerCallDetails(
  866. method_name,
  867. rpc_event.invocation_metadata,
  868. )
  869. if interceptor_pipeline is not None:
  870. return state.context.run(
  871. interceptor_pipeline.execute, query_handlers, handler_call_details
  872. )
  873. else:
  874. return state.context.run(query_handlers, handler_call_details)
  875. def _reject_rpc(
  876. rpc_event: cygrpc.BaseEvent,
  877. rpc_state: _RPCState,
  878. status: cygrpc.StatusCode,
  879. details: bytes,
  880. ):
  881. operations = (
  882. _get_initial_metadata_operation(rpc_state, None),
  883. cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
  884. cygrpc.SendStatusFromServerOperation(
  885. None, status, details, _EMPTY_FLAGS
  886. ),
  887. )
  888. rpc_event.call.start_server_batch(
  889. operations,
  890. lambda ignored_event: (
  891. rpc_state,
  892. (),
  893. ),
  894. )
  895. def _handle_with_method_handler(
  896. rpc_event: cygrpc.BaseEvent,
  897. state: _RPCState,
  898. method_handler: grpc.RpcMethodHandler,
  899. thread_pool: futures.ThreadPoolExecutor,
  900. ) -> futures.Future:
  901. with state.condition:
  902. rpc_event.call.start_server_batch(
  903. (cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),),
  904. _receive_close_on_server(state),
  905. )
  906. state.due.add(_RECEIVE_CLOSE_ON_SERVER_TOKEN)
  907. if method_handler.request_streaming:
  908. if method_handler.response_streaming:
  909. return _handle_stream_stream(
  910. rpc_event, state, method_handler, thread_pool
  911. )
  912. else:
  913. return _handle_stream_unary(
  914. rpc_event, state, method_handler, thread_pool
  915. )
  916. else:
  917. if method_handler.response_streaming:
  918. return _handle_unary_stream(
  919. rpc_event, state, method_handler, thread_pool
  920. )
  921. else:
  922. return _handle_unary_unary(
  923. rpc_event, state, method_handler, thread_pool
  924. )
  925. def _handle_call(
  926. rpc_event: cygrpc.BaseEvent,
  927. method_with_handler: _Method,
  928. interceptor_pipeline: Optional[_interceptor._ServicePipeline],
  929. thread_pool: futures.ThreadPoolExecutor,
  930. concurrency_exceeded: bool,
  931. ) -> Tuple[Optional[_RPCState], Optional[futures.Future]]:
  932. """Handles RPC based on provided handlers.
  933. When receiving a call event from Core, registered method will have its
  934. name as tag, we pass the tag as registered_method_name to this method,
  935. then we can find the handler in registered_method_handlers based on
  936. the method name.
  937. For call event with unregistered method, the method name will be included
  938. in rpc_event.call_details.method and we need to query the generics handlers
  939. to find the actual handler.
  940. """
  941. if not rpc_event.success:
  942. return None, None
  943. if rpc_event.call_details.method or method_with_handler.name():
  944. rpc_state = _RPCState()
  945. try:
  946. method_handler = _find_method_handler(
  947. rpc_event,
  948. rpc_state,
  949. method_with_handler,
  950. interceptor_pipeline,
  951. )
  952. except Exception as exception: # pylint: disable=broad-except
  953. details = "Exception servicing handler: {}".format(exception)
  954. _LOGGER.exception(details)
  955. _reject_rpc(
  956. rpc_event,
  957. rpc_state,
  958. cygrpc.StatusCode.unknown,
  959. b"Error in service handler!",
  960. )
  961. return rpc_state, None
  962. if method_handler is None:
  963. _reject_rpc(
  964. rpc_event,
  965. rpc_state,
  966. cygrpc.StatusCode.unimplemented,
  967. b"Method not found!",
  968. )
  969. return rpc_state, None
  970. elif concurrency_exceeded:
  971. _reject_rpc(
  972. rpc_event,
  973. rpc_state,
  974. cygrpc.StatusCode.resource_exhausted,
  975. b"Concurrent RPC limit exceeded!",
  976. )
  977. return rpc_state, None
  978. else:
  979. return (
  980. rpc_state,
  981. _handle_with_method_handler(
  982. rpc_event, rpc_state, method_handler, thread_pool
  983. ),
  984. )
  985. else:
  986. return None, None
  987. @enum.unique
  988. class _ServerStage(enum.Enum):
  989. STOPPED = "stopped"
  990. STARTED = "started"
  991. GRACE = "grace"
  992. class _ServerState(object):
  993. lock: threading.RLock
  994. completion_queue: cygrpc.CompletionQueue
  995. server: cygrpc.Server
  996. generic_handlers: List[grpc.GenericRpcHandler]
  997. registered_method_handlers: Dict[str, grpc.RpcMethodHandler]
  998. interceptor_pipeline: Optional[_interceptor._ServicePipeline]
  999. thread_pool: futures.ThreadPoolExecutor
  1000. stage: _ServerStage
  1001. termination_event: threading.Event
  1002. shutdown_events: List[threading.Event]
  1003. maximum_concurrent_rpcs: Optional[int]
  1004. active_rpc_count: int
  1005. rpc_states: Set[_RPCState]
  1006. due: Set[str]
  1007. server_deallocated: bool
  1008. # pylint: disable=too-many-arguments
  1009. def __init__(
  1010. self,
  1011. completion_queue: cygrpc.CompletionQueue,
  1012. server: cygrpc.Server,
  1013. generic_handlers: Sequence[grpc.GenericRpcHandler],
  1014. interceptor_pipeline: Optional[_interceptor._ServicePipeline],
  1015. thread_pool: futures.ThreadPoolExecutor,
  1016. maximum_concurrent_rpcs: Optional[int],
  1017. ):
  1018. self.lock = threading.RLock()
  1019. self.completion_queue = completion_queue
  1020. self.server = server
  1021. self.generic_handlers = list(generic_handlers)
  1022. self.interceptor_pipeline = interceptor_pipeline
  1023. self.thread_pool = thread_pool
  1024. self.stage = _ServerStage.STOPPED
  1025. self.termination_event = threading.Event()
  1026. self.shutdown_events = [self.termination_event]
  1027. self.maximum_concurrent_rpcs = maximum_concurrent_rpcs
  1028. self.active_rpc_count = 0
  1029. self.registered_method_handlers = {}
  1030. # TODO(https://github.com/grpc/grpc/issues/6597): eliminate these fields.
  1031. self.rpc_states = set()
  1032. self.due = set()
  1033. # A "volatile" flag to interrupt the daemon serving thread
  1034. self.server_deallocated = False
  1035. def _add_generic_handlers(
  1036. state: _ServerState, generic_handlers: Iterable[grpc.GenericRpcHandler]
  1037. ) -> None:
  1038. with state.lock:
  1039. state.generic_handlers.extend(generic_handlers)
  1040. def _add_registered_method_handlers(
  1041. state: _ServerState, method_handlers: Dict[str, grpc.RpcMethodHandler]
  1042. ) -> None:
  1043. with state.lock:
  1044. state.registered_method_handlers.update(method_handlers)
  1045. def _add_insecure_port(state: _ServerState, address: bytes) -> int:
  1046. with state.lock:
  1047. return state.server.add_http2_port(address)
  1048. def _add_secure_port(
  1049. state: _ServerState,
  1050. address: bytes,
  1051. server_credentials: grpc.ServerCredentials,
  1052. ) -> int:
  1053. with state.lock:
  1054. return state.server.add_http2_port(
  1055. address, server_credentials._credentials
  1056. )
  1057. def _request_call(state: _ServerState) -> None:
  1058. state.server.request_call(
  1059. state.completion_queue, state.completion_queue, _REQUEST_CALL_TAG
  1060. )
  1061. state.due.add(_REQUEST_CALL_TAG)
  1062. def _request_registered_call(state: _ServerState, method: str) -> None:
  1063. registered_call_tag = method
  1064. state.server.request_registered_call(
  1065. state.completion_queue,
  1066. state.completion_queue,
  1067. method,
  1068. registered_call_tag,
  1069. )
  1070. state.due.add(registered_call_tag)
  1071. # TODO(https://github.com/grpc/grpc/issues/6597): delete this function.
  1072. def _stop_serving(state: _ServerState) -> bool:
  1073. if not state.rpc_states and not state.due:
  1074. state.server.destroy()
  1075. for shutdown_event in state.shutdown_events:
  1076. shutdown_event.set()
  1077. state.stage = _ServerStage.STOPPED
  1078. return True
  1079. else:
  1080. return False
  1081. def _on_call_completed(state: _ServerState) -> None:
  1082. with state.lock:
  1083. state.active_rpc_count -= 1
# pylint: disable=too-many-branches
def _process_event_and_continue(
    state: _ServerState, event: cygrpc.BaseEvent
) -> bool:
    """Handle one completion-queue event for the serving loop.

    Three kinds of event tags are handled:

    * ``_SHUTDOWN_TAG``: the shutdown requested by ``_begin_shutdown_once``
      has completed.
    * ``_REQUEST_CALL_TAG`` or a registered method name: a new incoming call.
      It is dispatched via ``_handle_call``. While the server is still
      STARTED, a replacement request of the same kind is posted.
    * Anything else: a per-RPC tag. It is called with the event and returns
      ``(rpc_state, callbacks)``.

    Args:
      state: The server state. It is mutated under ``state.lock``.
      event: An event returned by ``state.completion_queue.poll``.

    Returns:
      False once ``_stop_serving`` reports the server stopped (the serving
      loop should exit), True otherwise.
    """
    should_continue = True
    if event.tag is _SHUTDOWN_TAG:
        with state.lock:
            state.due.remove(_SHUTDOWN_TAG)
            if _stop_serving(state):
                should_continue = False
    elif (
        event.tag is _REQUEST_CALL_TAG
        or event.tag in state.registered_method_handlers.keys()
    ):
        # New call. Registered-method calls are tagged with the method name
        # (see _request_registered_call); generic calls use _REQUEST_CALL_TAG.
        registered_method_name = None
        if event.tag in state.registered_method_handlers.keys():
            registered_method_name = event.tag
            method_with_handler = _RegisteredMethod(
                registered_method_name,
                state.registered_method_handlers.get(
                    registered_method_name, None
                ),
            )
        else:
            method_with_handler = _GenericMethod(
                state.generic_handlers,
            )
        with state.lock:
            state.due.remove(event.tag)
            # Evaluated before this call is counted below.
            concurrency_exceeded = (
                state.maximum_concurrent_rpcs is not None
                and state.active_rpc_count >= state.maximum_concurrent_rpcs
            )
            rpc_state, rpc_future = _handle_call(
                event,
                method_with_handler,
                state.interceptor_pipeline,
                state.thread_pool,
                concurrency_exceeded,
            )
            # Rejected calls return an rpc_state but no future: they are
            # tracked in rpc_states but do not count toward the limit.
            if rpc_state is not None:
                state.rpc_states.add(rpc_state)
            if rpc_future is not None:
                state.active_rpc_count += 1
                rpc_future.add_done_callback(
                    lambda unused_future: _on_call_completed(state)
                )
            if state.stage is _ServerStage.STARTED:
                # Re-arm: post another request of the same kind so that the
                # next call for this method (or the next generic call) is
                # accepted.
                if (
                    registered_method_name
                    in state.registered_method_handlers.keys()
                ):
                    _request_registered_call(state, registered_method_name)
                else:
                    _request_call(state)
            elif _stop_serving(state):
                should_continue = False
    else:
        # Per-RPC tag: a callable returning (rpc_state, callbacks). A
        # non-None rpc_state is removed from the tracked set afterwards.
        rpc_state, callbacks = event.tag(event)
        for callback in callbacks:
            try:
                callback()
            except Exception:  # pylint: disable=broad-except
                _LOGGER.exception("Exception calling callback!")
        if rpc_state is not None:
            with state.lock:
                state.rpc_states.remove(rpc_state)
                if _stop_serving(state):
                    should_continue = False
    return should_continue
  1154. def _serve(state: _ServerState) -> None:
  1155. while True:
  1156. timeout = time.time() + _DEALLOCATED_SERVER_CHECK_PERIOD_S
  1157. event = state.completion_queue.poll(timeout)
  1158. if state.server_deallocated:
  1159. _begin_shutdown_once(state)
  1160. if event.completion_type != cygrpc.CompletionType.queue_timeout:
  1161. if not _process_event_and_continue(state, event):
  1162. return
  1163. # We want to force the deletion of the previous event
  1164. # ~before~ we poll again; if the event has a reference
  1165. # to a shutdown Call object, this can induce spinlock.
  1166. event = None
  1167. def _begin_shutdown_once(state: _ServerState) -> None:
  1168. with state.lock:
  1169. if state.stage is _ServerStage.STARTED:
  1170. state.server.shutdown(state.completion_queue, _SHUTDOWN_TAG)
  1171. state.stage = _ServerStage.GRACE
  1172. state.due.add(_SHUTDOWN_TAG)
def _stop(state: _ServerState, grace: Optional[float]) -> threading.Event:
    """Begin stopping the server and return an event set once it has stopped.

    Args:
      state: The server state.
      grace: If None, all in-flight calls are cancelled immediately and this
        function blocks until the server has stopped. Otherwise a background
        thread waits up to ``grace`` seconds (less if the server stops first)
        and then cancels the remaining calls; this function returns at once.

    Returns:
      A ``threading.Event`` that ``_stop_serving`` sets once the server has
      fully stopped. It is already set if the server was STOPPED on entry.
    """
    with state.lock:
        if state.stage is _ServerStage.STOPPED:
            # Never started, or already fully stopped: nothing to wait for.
            shutdown_event = threading.Event()
            shutdown_event.set()
            return shutdown_event
        else:
            _begin_shutdown_once(state)
            shutdown_event = threading.Event()
            # _stop_serving sets every event in this list.
            state.shutdown_events.append(shutdown_event)
            if grace is None:
                state.server.cancel_all_calls()
            else:

                def cancel_all_calls_after_grace():
                    # Returns early if the server stops before grace elapses.
                    shutdown_event.wait(timeout=grace)
                    with state.lock:
                        state.server.cancel_all_calls()

                thread = threading.Thread(target=cancel_all_calls_after_grace)
                thread.start()
                return shutdown_event
    # Only reached when grace is None; waits outside the lock so the serving
    # thread can make progress.
    shutdown_event.wait()
    return shutdown_event
  1195. def _start(state: _ServerState) -> None:
  1196. with state.lock:
  1197. if state.stage is not _ServerStage.STOPPED:
  1198. raise ValueError("Cannot start already-started server!")
  1199. state.server.start()
  1200. state.stage = _ServerStage.STARTED
  1201. # Request a call for each registered method so we can handle any of them.
  1202. for method in state.registered_method_handlers.keys():
  1203. _request_registered_call(state, method)
  1204. # Also request a call for non-registered method.
  1205. _request_call(state)
  1206. thread = threading.Thread(target=_serve, args=(state,))
  1207. thread.daemon = True
  1208. thread.start()
  1209. def _validate_generic_rpc_handlers(
  1210. generic_rpc_handlers: Iterable[grpc.GenericRpcHandler],
  1211. ) -> None:
  1212. for generic_rpc_handler in generic_rpc_handlers:
  1213. service_attribute = getattr(generic_rpc_handler, "service", None)
  1214. if service_attribute is None:
  1215. raise AttributeError(
  1216. '"{}" must conform to grpc.GenericRpcHandler type but does '
  1217. 'not have "service" method!'.format(generic_rpc_handler)
  1218. )
  1219. def _augment_options(
  1220. base_options: Sequence[ChannelArgumentType],
  1221. compression: Optional[grpc.Compression],
  1222. xds: bool,
  1223. ) -> Sequence[ChannelArgumentType]:
  1224. compression_option = _compression.create_channel_option(compression)
  1225. maybe_server_call_tracer_factory_option = (
  1226. _observability.create_server_call_tracer_factory_option(xds)
  1227. )
  1228. return (
  1229. tuple(base_options)
  1230. + compression_option
  1231. + maybe_server_call_tracer_factory_option
  1232. )
  1233. class _Server(grpc.Server):
  1234. _state: _ServerState
  1235. # pylint: disable=too-many-arguments
  1236. def __init__(
  1237. self,
  1238. thread_pool: futures.ThreadPoolExecutor,
  1239. generic_handlers: Sequence[grpc.GenericRpcHandler],
  1240. interceptors: Sequence[grpc.ServerInterceptor],
  1241. options: Sequence[ChannelArgumentType],
  1242. maximum_concurrent_rpcs: Optional[int],
  1243. compression: Optional[grpc.Compression],
  1244. xds: bool,
  1245. ):
  1246. completion_queue = cygrpc.CompletionQueue()
  1247. server = cygrpc.Server(_augment_options(options, compression, xds), xds)
  1248. server.register_completion_queue(completion_queue)
  1249. self._state = _ServerState(
  1250. completion_queue,
  1251. server,
  1252. generic_handlers,
  1253. _interceptor.service_pipeline(interceptors),
  1254. thread_pool,
  1255. maximum_concurrent_rpcs,
  1256. )
  1257. self._cy_server = server
  1258. def add_generic_rpc_handlers(
  1259. self, generic_rpc_handlers: Iterable[grpc.GenericRpcHandler]
  1260. ) -> None:
  1261. _validate_generic_rpc_handlers(generic_rpc_handlers)
  1262. _add_generic_handlers(self._state, generic_rpc_handlers)
  1263. def add_registered_method_handlers(
  1264. self,
  1265. service_name: str,
  1266. method_handlers: Dict[str, grpc.RpcMethodHandler],
  1267. ) -> None:
  1268. # Can't register method once server started.
  1269. with self._state.lock:
  1270. if self._state.stage is _ServerStage.STARTED:
  1271. return
  1272. # TODO(xuanwn): We should validate method_handlers first.
  1273. method_to_handlers = {
  1274. _common.fully_qualified_method(service_name, method): method_handler
  1275. for method, method_handler in method_handlers.items()
  1276. }
  1277. for fully_qualified_method in method_to_handlers.keys():
  1278. self._cy_server.register_method(fully_qualified_method)
  1279. _add_registered_method_handlers(self._state, method_to_handlers)
  1280. def add_insecure_port(self, address: str) -> int:
  1281. return _common.validate_port_binding_result(
  1282. address, _add_insecure_port(self._state, _common.encode(address))
  1283. )
  1284. def add_secure_port(
  1285. self, address: str, server_credentials: grpc.ServerCredentials
  1286. ) -> int:
  1287. return _common.validate_port_binding_result(
  1288. address,
  1289. _add_secure_port(
  1290. self._state, _common.encode(address), server_credentials
  1291. ),
  1292. )
  1293. def start(self) -> None:
  1294. _start(self._state)
  1295. def wait_for_termination(self, timeout: Optional[float] = None) -> bool:
  1296. # NOTE(https://bugs.python.org/issue35935)
  1297. # Remove this workaround once threading.Event.wait() is working with
  1298. # CTRL+C across platforms.
  1299. return _common.wait(
  1300. self._state.termination_event.wait,
  1301. self._state.termination_event.is_set,
  1302. timeout=timeout,
  1303. )
  1304. def stop(self, grace: Optional[float]) -> threading.Event:
  1305. return _stop(self._state, grace)
  1306. def __del__(self):
  1307. if hasattr(self, "_state"):
  1308. # We can not grab a lock in __del__(), so set a flag to signal the
  1309. # serving daemon thread (if it exists) to initiate shutdown.
  1310. self._state.server_deallocated = True
  1311. def create_server(
  1312. thread_pool: futures.ThreadPoolExecutor,
  1313. generic_rpc_handlers: Sequence[grpc.GenericRpcHandler],
  1314. interceptors: Sequence[grpc.ServerInterceptor],
  1315. options: Sequence[ChannelArgumentType],
  1316. maximum_concurrent_rpcs: Optional[int],
  1317. compression: Optional[grpc.Compression],
  1318. xds: bool,
  1319. ) -> _Server:
  1320. _validate_generic_rpc_handlers(generic_rpc_handlers)
  1321. return _Server(
  1322. thread_pool,
  1323. generic_rpc_handlers,
  1324. interceptors,
  1325. options,
  1326. maximum_concurrent_rpcs,
  1327. compression,
  1328. xds,
  1329. )