test_session.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. #!/usr/bin/env python
  2. # Author: Leonardo Gama (@leogama)
  3. # Copyright (c) 2022-2024 The Uncertainty Quantification Foundation.
  4. # License: 3-clause BSD. The full license text is available at:
  5. # - https://github.com/uqfoundation/dill/blob/master/LICENSE
  6. import atexit
  7. import os
  8. import sys
  9. import __main__
  10. from contextlib import suppress
  11. from io import BytesIO
  12. import dill
  # Path template of the session file handed to the child process; the '%s'
  # is filled with the refimported flag, e.g. ``session_file % True``.
  session_file = os.path.join(
      os.path.dirname(__file__),
      'session-refimported-%s.pkl',
  )
  14. ###################
  15. # Child process #
  16. ###################
  17. def _error_line(error, obj, refimported):
  18. import traceback
  19. line = traceback.format_exc().splitlines()[-2].replace('[obj]', '['+repr(obj)+']')
  20. return "while testing (with refimported=%s): %s" % (refimported, line.lstrip())
  if __name__ == '__main__' and len(sys.argv) >= 3 and sys.argv[1] == '--child':
      # Child-process mode: test_session_main() dumps __main__ to
      # ``session_file % refimported`` and then runs this file again as
      # ``<python> <this file> --child <refimported>``. The session is loaded
      # into this fresh interpreter's __main__ and then checked.
      refimported = (sys.argv[2] == 'True')
      dill.load_module(session_file % refimported, module='__main__')
      def test_modules(refimported):
          """Assert that modules and imported objects were restored by the session.

          On failure, the AssertionError message is replaced by _error_line().
          That message reports the failing source line with ``[obj]`` expanded
          to the value being checked.
          """
          # FIXME: In this test setting with CPython 3.7, 'calendar' is not included
          # in sys.modules, independent of the value of refimported. Running
          # garbage collection just before loading the session did not help. It
          # fails even when preceded by 'import calendar'. That is why these kinds
          # of tests must run in a subprocess. Failing test sample:
          # assert globals()['day_name'] is sys.modules['calendar'].__dict__['day_name']
          try:
              for obj in ('json', 'url', 'local_mod', 'sax', 'dom'):
                  assert globals()[obj].__name__ in sys.modules
              assert 'calendar' in sys.modules and 'cmath' in sys.modules
              import calendar, cmath
              for obj in ('Calendar', 'isleap'):
                  assert globals()[obj] is sys.modules['calendar'].__dict__[obj]
              assert __main__.day_name.__module__ == 'calendar'
              if refimported:
                  # Only with refimported=True is the imported object expected
                  # to be the very same object as in its source module.
                  assert __main__.day_name is calendar.day_name
              assert __main__.complex_log is cmath.log
          except AssertionError as error:
              error.args = (_error_line(error, obj, refimported),)
              raise
      test_modules(refimported)
      # Stop here so the parent-process part of this file does not run.
      sys.exit()
  48. ####################
  49. # Parent process #
  50. ####################
  51. # Create various kinds of objects to test different internal logics.
  52. ## Modules.
  53. import json # top-level module
  54. import urllib as url # top-level module under alias
  55. from xml import sax # submodule
  56. import xml.dom.minidom as dom # submodule under alias
  57. import test_dictviews as local_mod # non-builtin top-level module
  58. ## Imported objects.
  59. from calendar import Calendar, isleap, day_name # class, function, other object
  60. from cmath import log as complex_log # imported with alias
  ## Local objects.
  # Plain values of several kinds.
  x = 17
  empty = None
  names = ['Alice', 'Bob', 'Carol']


  # A named function and a lambda. After reloading, both are checked to still
  # use __main__'s namespace as their __globals__.
  def squared(x):
      return x**2


  cubed = lambda x: x**3
  class Person:
      """User-defined class whose instance must survive a session round trip."""

      def __init__(self, name, age):
          self.name, self.age = name, age


  person = Person(names[0], x)
  class CalendarSubclass(Calendar):
      """Subclass of an imported class, defined in this module."""

      def weekdays(self):
          # Localized day names, in this calendar's week order.
          return [day_name[day] for day in self.iterweekdays()]


  cal = CalendarSubclass()
  selfref = __main__  # the module holding a reference to itself
  77. # Setup global namespace for session saving tests.
  class TestNamespace:
      """Temporarily replace this module's globals with a fixed test namespace.

      ``test_globals`` is a snapshot of the module globals, taken when the class
      is created. Entering the context does three things:

      * saves the current globals in ``self.backup``;
      * replaces them with that snapshot;
      * adds the keyword arguments given to the constructor.

      Exiting restores the saved globals. Exceptions raised inside the ``with``
      body are not suppressed.
      """
      test_globals = globals().copy()

      def __init__(self, **extra):
          self.extra = extra

      @staticmethod
      def _swap_globals(*namespaces):
          # Mutate the module dict in place, so functions of this module see
          # the new contents through their existing __globals__ reference.
          module_globals = globals()
          module_globals.clear()
          for namespace in namespaces:
              module_globals.update(namespace)

      def __enter__(self):
          self.backup = globals().copy()
          self._swap_globals(self.test_globals, self.extra)
          return self

      def __exit__(self, *exc_info):
          self._swap_globals(self.backup)
  91. def _clean_up_cache(module):
  92. cached = module.__file__.split('.', 1)[0] + '.pyc'
  93. cached = module.__cached__ if hasattr(module, '__cached__') else cached
  94. pycache = os.path.join(os.path.dirname(module.__file__), '__pycache__')
  95. for remove, file in [(os.remove, cached), (os.removedirs, pycache)]:
  96. with suppress(OSError):
  97. remove(file)
  # At interpreter exit, remove the bytecode cache that may have been written
  # when importing the local fixture module (test_dictviews, as local_mod).
  atexit.register(_clean_up_cache, local_mod)
  def _test_objects(main, globals_copy, refimported):
      """Check the objects of a reloaded session against the saved originals.

      Parameters:
          main: the module the session was loaded into.
          globals_copy: the namespace to compare against. The caller passes the
              globals that TestNamespace backed up before the session was reloaded.
          refimported: the dump_module() setting under test, used only to label
              a failure message.

      On failure, the AssertionError message is replaced by _error_line(). That
      message reports the failing source line with ``[obj]`` expanded to the
      value being checked.
      """
      try:
          # Inspect the module we were given instead of a hard-coded __main__.
          main_dict = main.__dict__
          global Person, person, Calendar, CalendarSubclass, cal, selfref
          for obj in ('json', 'url', 'local_mod', 'sax', 'dom'):
              assert globals()[obj].__name__ == globals_copy[obj].__name__
          for obj in ('x', 'empty', 'names'):
              assert main_dict[obj] == globals_copy[obj]
          for obj in ['squared', 'cubed']:
              assert main_dict[obj].__globals__ is main_dict
              assert main_dict[obj](3) == globals_copy[obj](3)
          assert Person.__module__ == __main__.__name__
          assert isinstance(person, Person)
          assert person.age == globals_copy['person'].age
          assert issubclass(CalendarSubclass, Calendar)
          assert isinstance(cal, CalendarSubclass)
          assert cal.weekdays() == globals_copy['cal'].weekdays()
          assert selfref is __main__
      except AssertionError as error:
          error.args = (_error_line(error, obj, refimported),)
          raise
  def test_session_main(refimported):
      """test dump/load_module() for __main__, both in this process and in a subprocess

      Both checks run inside a TestNamespace, so __main__ holds only the
      fixture objects, plus ``flags`` when refimported is true.
      """
      if refimported:
          # An imported object that cannot be pickled by value.
          from sys import flags
          extra_objects = {'flags': flags}
      else:
          extra_objects = {}
      pickle_path = session_file % refimported
      with TestNamespace(**extra_objects) as ns:
          try:
              # Fresh interpreter: dump to a file and let a child process load it.
              dill.dump_module(pickle_path, refimported=refimported)
              from dill.tests.__main__ import python, shell, sp
              child_status = sp.call([python, __file__, '--child', str(refimported)], shell=shell)
              if child_status:
                  sys.exit(child_status)
          finally:
              with suppress(OSError):
                  os.remove(pickle_path)
          # Same interpreter: dump to memory and load the session back in place.
          in_memory = BytesIO()
          dill.dump_module(in_memory, refimported=refimported)
          in_memory.seek(0)
          dill.load_module(in_memory, module='__main__')
          # _test_objects is not in the test namespace; take it from the backup.
          ns.backup['_test_objects'](__main__, ns.backup, refimported)
  def test_session_other():
      """test dump/load_module() for a module other than __main__

      After dumping, the module's non-dunder attributes are deleted. load_module
      must restore all of them, including an attribute that refers back to the
      module itself.
      """
      import test_classdef as module
      atexit.register(_clean_up_cache, module)
      module.selfref = module
      namespace = module.__dict__
      saved_names = [name for name in namespace if not name.startswith('__')]

      buffer = BytesIO()
      dill.dump_module(buffer, module)
      for name in saved_names:
          del namespace[name]

      buffer.seek(0)
      dill.load_module(buffer, module)
      assert all(name in namespace for name in saved_names)
      assert module.selfref is module
  def test_runtime_module():
      """dump_module()/load_module() with a module created at runtime.

      The loaded module is checked in two cases: load_module() given a fresh
      empty module with the same name, and load_module() given no module. In
      neither case may the module end up in sys.modules.
      """
      from types import ModuleType
      modname = '__runtime__'
      runtime = ModuleType(modname)
      runtime.x = 42
      mod = dill.session._stash_modules(runtime)
      if mod is not runtime:
          # Judging by this message, a different module is returned only when
          # some objects would be saved by reference.
          print("There are objects to save by reference that shouldn't be:",
                mod.__dill_imported, mod.__dill_imported_as, mod.__dill_imported_top_level,
                file=sys.stderr)

      # This is also for code coverage, tests the use case of dump_module(refimported=True)
      # without imported objects in the namespace. It's a contrived example because
      # even dill can't be in it. This should work after fixing #462.
      session_buffer = BytesIO()
      dill.dump_module(session_buffer, module=runtime, refimported=True)
      session_dump = session_buffer.getvalue()

      # Pass a new runtime created module with the same name.
      runtime = ModuleType(modname)  # empty
      return_val = dill.load_module(BytesIO(session_dump), module=runtime)
      assert return_val is None
      assert runtime.__name__ == modname
      assert runtime.x == 42
      assert runtime not in sys.modules.values()

      # Pass no module: load_module() must create and return it. A fresh
      # BytesIO is read, so no rewind of session_buffer is needed.
      runtime = dill.load_module(BytesIO(session_dump))
      assert runtime.__name__ == modname
      assert runtime.x == 42
      assert runtime not in sys.modules.values()
  def test_refimported_imported_as():
      """dump_module(refimported=True) records objects imported under an alias.

      Each object is stored in the module under a different name from its
      source. It must appear in ``__dill_imported_as`` as a
      ``(module, name, alias)`` tuple.

      The temporary ``__test__`` entry in sys.modules and the ``dill.executor``
      attribute are always removed, even if dumping or loading fails. The
      executor is shut down as well.
      """
      import collections
      import concurrent.futures
      import types
      import typing
      executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
      dill.executor = executor
      mod = sys.modules['__test__'] = types.ModuleType('__test__')
      try:
          mod.Dict = collections.UserDict  # select by type
          mod.AsyncCM = typing.AsyncContextManager  # select by __module__
          mod.thread_exec = executor  # select by __module__ with regex
          session_buffer = BytesIO()
          dill.dump_module(session_buffer, mod, refimported=True)
          session_buffer.seek(0)
          mod = dill.load(session_buffer)
      finally:
          # Undo the global side effects of the setup above.
          sys.modules.pop('__test__', None)
          del dill.executor
          executor.shutdown()
      assert set(mod.__dill_imported_as) == {
          ('collections', 'UserDict', 'Dict'),
          ('typing', 'AsyncContextManager', 'AsyncCM'),
          ('dill', 'executor', 'thread_exec'),
      }
  def test_load_module_asdict():
      """load_module_asdict() returns the saved namespace as a new dict.

      The test changes __main__ after dumping: x is rebound, y is created and
      empty is deleted. Those changes must not appear in the returned dict,
      and the live globals must be left untouched by the call.
      """
      global empty, names, x, y
      with TestNamespace():
          buffer = BytesIO()
          dill.dump_module(buffer)
          x = y = 0  # change x and create y
          del empty
          snapshot = globals().copy()
          buffer.seek(0)
          loaded = dill.load_module_asdict(buffer)
          assert loaded is not globals()
          assert globals() == snapshot
          assert loaded['__name__'] == '__main__'
          assert loaded['names'] == names
          assert loaded['names'] is not names
          assert loaded['x'] != x
          assert 'y' not in loaded
          assert 'empty' in loaded
  if __name__ == '__main__':
      # Run the whole suite when this file is executed as a script. A
      # ``--child`` run never gets here: it exits after its own checks.
      # test_session_main() runs once for each refimported setting.
      test_session_main(refimported=False)
      test_session_main(refimported=True)
      test_session_other()
      test_runtime_module()
      test_refimported_imported_as()
      test_load_module_asdict()