bookmark_manager.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. # Copyright (c) "Neo4j"
  2. # Neo4j Sweden AB [https://neo4j.com]
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # https://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from __future__ import annotations
  16. import typing as t
  17. from .._async_compat.concurrency import AsyncCooperativeLock
  18. from .._async_compat.util import AsyncUtil
  19. from ..api import (
  20. AsyncBookmarkManager,
  21. Bookmarks,
  22. )
  23. TBmSupplier = t.Callable[[], t.Union[Bookmarks, t.Awaitable[Bookmarks]]]
  24. TBmConsumer = t.Callable[[Bookmarks], t.Union[None, t.Awaitable[None]]]
  25. def _bookmarks_to_set(
  26. bookmarks: Bookmarks | t.Iterable[str],
  27. ) -> set[str]:
  28. if isinstance(bookmarks, Bookmarks):
  29. return set(bookmarks.raw_values)
  30. return set(map(str, bookmarks))
  31. class AsyncNeo4jBookmarkManager(AsyncBookmarkManager):
  32. def __init__(
  33. self,
  34. initial_bookmarks: Bookmarks | t.Iterable[str] | None = None,
  35. bookmarks_supplier: TBmSupplier | None = None,
  36. bookmarks_consumer: TBmConsumer | None = None,
  37. ) -> None:
  38. super().__init__()
  39. self._bookmarks_supplier = bookmarks_supplier
  40. self._bookmarks_consumer = bookmarks_consumer
  41. if not initial_bookmarks:
  42. self._bookmarks = set()
  43. else:
  44. if not hasattr(initial_bookmarks, "raw_values"):
  45. initial_bookmarks = Bookmarks.from_raw_values(
  46. t.cast(t.Iterable[str], initial_bookmarks)
  47. )
  48. self._bookmarks = set(
  49. t.cast(Bookmarks, initial_bookmarks).raw_values
  50. )
  51. self._lock = AsyncCooperativeLock()
  52. async def update_bookmarks(
  53. self,
  54. previous_bookmarks: t.Collection[str],
  55. new_bookmarks: t.Collection[str],
  56. ) -> None:
  57. if not new_bookmarks:
  58. return
  59. with self._lock:
  60. self._bookmarks.difference_update(previous_bookmarks)
  61. self._bookmarks.update(new_bookmarks)
  62. if self._bookmarks_consumer:
  63. curr_bms_snapshot = Bookmarks.from_raw_values(self._bookmarks)
  64. if self._bookmarks_consumer:
  65. await AsyncUtil.callback(
  66. self._bookmarks_consumer, curr_bms_snapshot
  67. )
  68. async def get_bookmarks(self) -> set[str]:
  69. with self._lock:
  70. bms = set(self._bookmarks)
  71. if self._bookmarks_supplier:
  72. extra_bms = await AsyncUtil.callback(self._bookmarks_supplier)
  73. bms.update(extra_bms.raw_values)
  74. return bms