#!/usr/bin/env python
#
# Author: Mike McKerns (mmckerns @caltech and @uqfoundation)
# Copyright (c) 2008-2016 California Institute of Technology.
# Copyright (c) 2016-2024 The Uncertainty Quantification Foundation.
# License: 3-clause BSD.  The full license text is available at:
#  - https://github.com/uqfoundation/dill/blob/master/LICENSE
"""Tests for dill.copy and dill.source/dill.detect on closure-based decorators."""

import dill

# Enable dill's 'recurse' setting for every pickle made in this module
# (dill.copy in test_mixins); see dill's settings docs for its exact semantics.
dill.settings['recurse'] = True
  10. def wtf(x,y,z):
  11. def zzz():
  12. return x
  13. def yyy():
  14. return y
  15. def xxx():
  16. return z
  17. return zzz,yyy
  18. def quad(a=1, b=1, c=0):
  19. inverted = [False]
  20. def invert():
  21. inverted[0] = not inverted[0]
  22. def dec(f):
  23. def func(*args, **kwds):
  24. x = f(*args, **kwds)
  25. if inverted[0]: x = -x
  26. return a*x**2 + b*x + c
  27. func.__wrapped__ = f
  28. func.invert = invert
  29. func.inverted = inverted
  30. return func
  31. return dec
  32. @quad(a=0,b=2)
  33. def double_add(*args):
  34. return sum(args)
  35. fx = sum([1,2,3])
  36. ### to make it interesting...
  37. def quad_factory(a=1,b=1,c=0):
  38. def dec(f):
  39. def func(*args,**kwds):
  40. fx = f(*args,**kwds)
  41. return a*fx**2 + b*fx + c
  42. return func
  43. return dec
  44. @quad_factory(a=0,b=4,c=0)
  45. def quadish(x):
  46. return x+1
  47. quadratic = quad_factory()
  48. def doubler(f):
  49. def inner(*args, **kwds):
  50. fx = f(*args, **kwds)
  51. return 2*fx
  52. return inner
  53. @doubler
  54. def quadruple(x):
  55. return 2*x
  56. def test_mixins():
  57. # test mixins
  58. assert double_add(1,2,3) == 2*fx
  59. double_add.invert()
  60. assert double_add(1,2,3) == -2*fx
  61. _d = dill.copy(double_add)
  62. assert _d(1,2,3) == -2*fx
  63. #_d.invert() #FIXME: fails seemingly randomly
  64. #assert _d(1,2,3) == 2*fx
  65. assert _d.__wrapped__(1,2,3) == fx
  66. # XXX: issue or feature? in python3.4, inverted is linked through copy
  67. if not double_add.inverted[0]:
  68. double_add.invert()
  69. # test some stuff from source and pointers
  70. ds = dill.source
  71. dd = dill.detect
  72. assert ds.getsource(dd.freevars(quadish)['f']) == '@quad_factory(a=0,b=4,c=0)\ndef quadish(x):\n return x+1\n'
  73. assert ds.getsource(dd.freevars(quadruple)['f']) == '@doubler\ndef quadruple(x):\n return 2*x\n'
  74. assert ds.importable(quadish, source=False) == 'from %s import quadish\n' % __name__
  75. assert ds.importable(quadruple, source=False) == 'from %s import quadruple\n' % __name__
  76. assert ds.importable(quadratic, source=False) == 'from %s import quadratic\n' % __name__
  77. assert ds.importable(double_add, source=False) == 'from %s import double_add\n' % __name__
  78. assert ds.importable(quadruple, source=True) == 'def doubler(f):\n def inner(*args, **kwds):\n fx = f(*args, **kwds)\n return 2*fx\n return inner\n\n@doubler\ndef quadruple(x):\n return 2*x\n'
  79. #***** #FIXME: this needs work
  80. result = ds.importable(quadish, source=True)
  81. a,b,c,_,result = result.split('\n',4)
  82. assert result == 'def quad_factory(a=1,b=1,c=0):\n def dec(f):\n def func(*args,**kwds):\n fx = f(*args,**kwds)\n return a*fx**2 + b*fx + c\n return func\n return dec\n\n@quad_factory(a=0,b=4,c=0)\ndef quadish(x):\n return x+1\n'
  83. assert set([a,b,c]) == set(['a = 0', 'c = 0', 'b = 4'])
  84. result = ds.importable(quadratic, source=True)
  85. a,b,c,result = result.split('\n',3)
  86. assert result == '\ndef dec(f):\n def func(*args,**kwds):\n fx = f(*args,**kwds)\n return a*fx**2 + b*fx + c\n return func\n'
  87. assert set([a,b,c]) == set(['a = 1', 'c = 0', 'b = 1'])
  88. result = ds.importable(double_add, source=True)
  89. a,b,c,d,_,result = result.split('\n',5)
  90. assert result == 'def quad(a=1, b=1, c=0):\n inverted = [False]\n def invert():\n inverted[0] = not inverted[0]\n def dec(f):\n def func(*args, **kwds):\n x = f(*args, **kwds)\n if inverted[0]: x = -x\n return a*x**2 + b*x + c\n func.__wrapped__ = f\n func.invert = invert\n func.inverted = inverted\n return func\n return dec\n\n@quad(a=0,b=2)\ndef double_add(*args):\n return sum(args)\n'
  91. assert set([a,b,c,d]) == set(['a = 0', 'c = 0', 'b = 2', 'inverted = [True]'])
  92. #*****
  93. if __name__ == '__main__':
  94. test_mixins()