test_recursive.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
#!/usr/bin/env python
#
# Author: Mike McKerns (mmckerns @caltech and @uqfoundation)
# Copyright (c) 2019-2024 The Uncertainty Quantification Foundation.
# License: 3-clause BSD. The full license text is available at:
# - https://github.com/uqfoundation/dill/blob/master/LICENSE
  7. import dill
  8. from functools import partial
  9. import warnings
  10. def copy(obj, byref=False, recurse=False):
  11. if byref:
  12. try:
  13. return dill.copy(obj, byref=byref, recurse=recurse)
  14. except Exception:
  15. pass
  16. else:
  17. raise AssertionError('Copy of %s with byref=True should have given a warning!' % (obj,))
  18. warnings.simplefilter('ignore')
  19. val = dill.copy(obj, byref=byref, recurse=recurse)
  20. warnings.simplefilter('error')
  21. return val
  22. else:
  23. return dill.copy(obj, byref=byref, recurse=recurse)
  24. class obj1(object):
  25. def __init__(self):
  26. super(obj1, self).__init__()
  27. class obj2(object):
  28. def __init__(self):
  29. super(obj2, self).__init__()
  30. class obj3(object):
  31. super_ = super
  32. def __init__(self):
  33. obj3.super_(obj3, self).__init__()
  34. def test_super():
  35. assert copy(obj1(), byref=True)
  36. assert copy(obj1(), byref=True, recurse=True)
  37. assert copy(obj1(), recurse=True)
  38. assert copy(obj1())
  39. assert copy(obj2(), byref=True)
  40. assert copy(obj2(), byref=True, recurse=True)
  41. assert copy(obj2(), recurse=True)
  42. assert copy(obj2())
  43. assert copy(obj3(), byref=True)
  44. assert copy(obj3(), byref=True, recurse=True)
  45. assert copy(obj3(), recurse=True)
  46. assert copy(obj3())
  47. def get_trigger(model):
  48. pass
  49. class Machine(object):
  50. def __init__(self):
  51. self.child = Model()
  52. self.trigger = partial(get_trigger, self)
  53. self.child.trigger = partial(get_trigger, self.child)
  54. class Model(object):
  55. pass
  56. def test_partial():
  57. assert copy(Machine(), byref=True)
  58. assert copy(Machine(), byref=True, recurse=True)
  59. assert copy(Machine(), recurse=True)
  60. assert copy(Machine())
  61. class Machine2(object):
  62. def __init__(self):
  63. self.go = partial(self.member, self)
  64. def member(self, model):
  65. pass
  66. class SubMachine(Machine2):
  67. def __init__(self):
  68. super(SubMachine, self).__init__()
  69. def test_partials():
  70. assert copy(SubMachine(), byref=True)
  71. assert copy(SubMachine(), byref=True, recurse=True)
  72. assert copy(SubMachine(), recurse=True)
  73. assert copy(SubMachine())
  74. class obj4(object):
  75. def __init__(self):
  76. super(obj4, self).__init__()
  77. a = self
  78. class obj5(object):
  79. def __init__(self):
  80. super(obj5, self).__init__()
  81. self.a = a
  82. self.b = obj5()
  83. def test_circular_reference():
  84. assert copy(obj4())
  85. obj4_copy = dill.loads(dill.dumps(obj4()))
  86. assert type(obj4_copy) is type(obj4_copy).__init__.__closure__[0].cell_contents
  87. assert type(obj4_copy.b) is type(obj4_copy.b).__init__.__closure__[0].cell_contents
  88. def f():
  89. def g():
  90. return g
  91. return g
  92. def test_function_cells():
  93. assert copy(f())
  94. def fib(n):
  95. assert n >= 0
  96. if n <= 1:
  97. return n
  98. else:
  99. return fib(n-1) + fib(n-2)
  100. def test_recursive_function():
  101. global fib
  102. fib2 = copy(fib, recurse=True)
  103. fib3 = copy(fib)
  104. fib4 = fib
  105. del fib
  106. assert fib2(5) == 5
  107. for _fib in (fib3, fib4):
  108. try:
  109. _fib(5)
  110. except Exception:
  111. # This is expected to fail because fib no longer exists
  112. pass
  113. else:
  114. raise AssertionError("Function fib shouldn't have been found")
  115. fib = fib4
  116. def collection_function_recursion():
  117. d = {}
  118. def g():
  119. return d
  120. d['g'] = g
  121. return g
  122. def test_collection_function_recursion():
  123. g = copy(collection_function_recursion())
  124. assert g()['g'] is g
  125. if __name__ == '__main__':
  126. with warnings.catch_warnings():
  127. warnings.simplefilter('error')
  128. test_super()
  129. test_partial()
  130. test_partials()
  131. test_circular_reference()
  132. test_function_cells()
  133. test_recursive_function()
  134. test_collection_function_recursion()