numpy.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. #
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing,
  13. # software distributed under the License is distributed on an
  14. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. # KIND, either express or implied. See the License for the
  16. # specific language governing permissions and limitations
  17. # under the License.
  18. from __future__ import annotations
  19. from typing import TYPE_CHECKING, Any
  20. from airflow.utils.module_loading import import_string, qualname
  21. # lazy loading for performance reasons
  22. serializers = [
  23. "numpy.int8",
  24. "numpy.int16",
  25. "numpy.int32",
  26. "numpy.int64",
  27. "numpy.uint8",
  28. "numpy.uint16",
  29. "numpy.uint32",
  30. "numpy.uint64",
  31. "numpy.bool_",
  32. "numpy.float64",
  33. "numpy.float16",
  34. "numpy.complex128",
  35. "numpy.complex64",
  36. ]
  37. if TYPE_CHECKING:
  38. from airflow.serialization.serde import U
  39. deserializers = serializers
  40. __version__ = 1
  41. def serialize(o: object) -> tuple[U, str, int, bool]:
  42. import numpy as np
  43. if np is None:
  44. return "", "", 0, False
  45. name = qualname(o)
  46. if isinstance(
  47. o,
  48. (
  49. np.int_,
  50. np.intc,
  51. np.intp,
  52. np.int8,
  53. np.int16,
  54. np.int32,
  55. np.int64,
  56. np.uint8,
  57. np.uint16,
  58. np.uint32,
  59. np.uint64,
  60. ),
  61. ):
  62. return int(o), name, __version__, True
  63. if isinstance(o, np.bool_):
  64. return bool(np), name, __version__, True
  65. if isinstance(
  66. o, (np.float_, np.float16, np.float32, np.float64, np.complex_, np.complex64, np.complex128)
  67. ):
  68. return float(o), name, __version__, True
  69. return "", "", 0, False
  70. def deserialize(classname: str, version: int, data: str) -> Any:
  71. if version > __version__:
  72. raise TypeError("serialized version is newer than class version")
  73. if classname not in deserializers:
  74. raise TypeError(f"unsupported {classname} found for numpy deserialization")
  75. return import_string(classname)(data)