test_functions.py 4.2 KB

#!/usr/bin/env python
#
# Author: Mike McKerns (mmckerns @caltech and @uqfoundation)
# Copyright (c) 2019-2024 The Uncertainty Quantification Foundation.
# License: 3-clause BSD. The full license text is available at:
# - https://github.com/uqfoundation/dill/blob/master/LICENSE
  7. import functools
  8. import dill
  9. import sys
  10. dill.settings['recurse'] = True
  11. def function_a(a):
  12. return a
  13. def function_b(b, b1):
  14. return b + b1
  15. def function_c(c, c1=1):
  16. return c + c1
  17. def function_d(d, d1, d2=1):
  18. """doc string"""
  19. return d + d1 + d2
  20. function_d.__module__ = 'a module'
  21. exec('''
  22. def function_e(e, *e1, e2=1, e3=2):
  23. return e + sum(e1) + e2 + e3''')
  24. globalvar = 0
  25. @functools.lru_cache(None)
  26. def function_with_cache(x):
  27. global globalvar
  28. globalvar += x
  29. return globalvar
  30. def function_with_unassigned_variable():
  31. if False:
  32. value = None
  33. return (lambda: value)
  34. def test_issue_510():
  35. # A very bizzare use of functions and methods that pickle doesn't get
  36. # correctly for odd reasons.
  37. class Foo:
  38. def __init__(self):
  39. def f2(self):
  40. return self
  41. self.f2 = f2.__get__(self)
  42. import dill, pickletools
  43. f = Foo()
  44. f1 = dill.copy(f)
  45. assert f1.f2() is f1
  46. def test_functions():
  47. dumped_func_a = dill.dumps(function_a)
  48. assert dill.loads(dumped_func_a)(0) == 0
  49. dumped_func_b = dill.dumps(function_b)
  50. assert dill.loads(dumped_func_b)(1,2) == 3
  51. dumped_func_c = dill.dumps(function_c)
  52. assert dill.loads(dumped_func_c)(1) == 2
  53. assert dill.loads(dumped_func_c)(1, 2) == 3
  54. dumped_func_d = dill.dumps(function_d)
  55. assert dill.loads(dumped_func_d).__doc__ == function_d.__doc__
  56. assert dill.loads(dumped_func_d).__module__ == function_d.__module__
  57. assert dill.loads(dumped_func_d)(1, 2) == 4
  58. assert dill.loads(dumped_func_d)(1, 2, 3) == 6
  59. assert dill.loads(dumped_func_d)(1, 2, d2=3) == 6
  60. function_with_cache(1)
  61. globalvar = 0
  62. dumped_func_cache = dill.dumps(function_with_cache)
  63. assert function_with_cache(2) == 3
  64. assert function_with_cache(1) == 1
  65. assert function_with_cache(3) == 6
  66. assert function_with_cache(2) == 3
  67. empty_cell = function_with_unassigned_variable()
  68. cell_copy = dill.loads(dill.dumps(empty_cell))
  69. assert 'empty' in str(cell_copy.__closure__[0])
  70. try:
  71. cell_copy()
  72. except Exception:
  73. # this is good
  74. pass
  75. else:
  76. raise AssertionError('cell_copy() did not read an empty cell')
  77. exec('''
  78. dumped_func_e = dill.dumps(function_e)
  79. assert dill.loads(dumped_func_e)(1, 2) == 6
  80. assert dill.loads(dumped_func_e)(1, 2, 3) == 9
  81. assert dill.loads(dumped_func_e)(1, 2, e2=3) == 8
  82. assert dill.loads(dumped_func_e)(1, 2, e2=3, e3=4) == 10
  83. assert dill.loads(dumped_func_e)(1, 2, 3, e2=4) == 12
  84. assert dill.loads(dumped_func_e)(1, 2, 3, e2=4, e3=5) == 15''')
  85. def test_code_object():
  86. import warnings
  87. from dill._dill import ALL_CODE_PARAMS, CODE_PARAMS, CODE_VERSION, _create_code
  88. code = function_c.__code__
  89. warnings.filterwarnings('ignore', category=DeprecationWarning) # issue 597
  90. LNOTAB = getattr(code, 'co_lnotab', b'')
  91. if warnings.filters: del warnings.filters[0]
  92. fields = {f: getattr(code, 'co_'+f) for f in CODE_PARAMS}
  93. fields.setdefault('posonlyargcount', 0) # python >= 3.8
  94. fields.setdefault('lnotab', LNOTAB) # python <= 3.9
  95. fields.setdefault('linetable', b'') # python >= 3.10
  96. fields.setdefault('qualname', fields['name']) # python >= 3.11
  97. fields.setdefault('exceptiontable', b'') # python >= 3.11
  98. fields.setdefault('endlinetable', None) # python == 3.11a
  99. fields.setdefault('columntable', None) # python == 3.11a
  100. for version, _, params in ALL_CODE_PARAMS:
  101. args = tuple(fields[p] for p in params.split())
  102. try:
  103. _create_code(*args)
  104. if version >= (3,10):
  105. _create_code(fields['lnotab'], *args)
  106. except Exception as error:
  107. raise Exception("failed to construct code object with format version {}".format(version)) from error
  108. if __name__ == '__main__':
  109. test_functions()
  110. test_issue_510()
  111. test_code_object()