pandas.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. #
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing,
  13. # software distributed under the License is distributed on an
  14. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. # KIND, either express or implied. See the License for the
  16. # specific language governing permissions and limitations
  17. # under the License.
  18. from __future__ import annotations
  19. from typing import TYPE_CHECKING
  20. from airflow.utils.module_loading import qualname
  21. # lazy loading for performance reasons
  22. serializers = [
  23. "pandas.core.frame.DataFrame",
  24. ]
  25. deserializers = serializers
  26. if TYPE_CHECKING:
  27. import pandas as pd
  28. from airflow.serialization.serde import U
  29. __version__ = 1
  30. def serialize(o: object) -> tuple[U, str, int, bool]:
  31. import pandas as pd
  32. import pyarrow as pa
  33. from pyarrow import parquet as pq
  34. if not isinstance(o, pd.DataFrame):
  35. return "", "", 0, False
  36. # for now, we *always* serialize into in memory
  37. # until we have a generic backend that manages
  38. # sinks
  39. table = pa.Table.from_pandas(o)
  40. buf = pa.BufferOutputStream()
  41. pq.write_table(table, buf, compression="snappy")
  42. return buf.getvalue().hex().decode("utf-8"), qualname(o), __version__, True
  43. def deserialize(classname: str, version: int, data: object) -> pd.DataFrame:
  44. if version > __version__:
  45. raise TypeError(f"serialized {version} of {classname} > {__version__}")
  46. from pyarrow import parquet as pq
  47. if not isinstance(data, str):
  48. raise TypeError(f"serialized {classname} has wrong data type {type(data)}")
  49. from io import BytesIO
  50. with BytesIO(bytes.fromhex(data)) as buf:
  51. df = pq.read_table(buf).to_pandas()
  52. return df