Source code for nannyml.io.file_writer

#  Author:   Niels Nuyttens  <niels@nannyml.com>
#
#  License: Apache Software License 2.0

#  Author:   Niels Nuyttens  <niels@nannyml.com>
#
#  License: Apache Software License 2.0
import logging
from copy import deepcopy
from io import BytesIO
from pathlib import Path, PurePosixPath
from typing import Any, Dict

import fsspec

from nannyml._typing import Result
from nannyml.exceptions import InvalidArgumentsException
from nannyml.io.base import Writer, get_filepath_str, get_protocol_and_path


class FileWriter(Writer):
    """Write the plots and data of a calculation result to a filesystem through ``fsspec``.

    For a result whose ``calculator_name`` is ``<calc>``, the layout below ``filepath`` is::

        <filepath>/<calc>/images/<plot name>.png
        <filepath>/<calc>/data/<calc>.pq    (data_format="parquet")
        <filepath>/<calc>/data/<calc>.csv   (data_format="csv")
    """

    _logger = logging.getLogger(__name__)

    # Supported values for ``data_format``, mapped to the extension used for the written data file.
    _DATA_FILE_EXTENSIONS = {"parquet": "pq", "csv": "csv"}

    def __init__(
        self,
        filepath: str,
        data_format: str,
        write_args: Dict[str, Any] = None,
        credentials: Dict[str, Any] = None,
        fs_args: Dict[str, Any] = None,
    ):
        """Create a new FileWriter.

        Parameters
        ----------
        filepath : str
            Root location to write to. The protocol and path are derived from it by
            ``get_protocol_and_path``.
        data_format : str
            Format of the data file: ``"parquet"`` or ``"csv"``. It is validated when writing.
        write_args : Dict[str, Any], optional
            Extra keyword arguments forwarded to ``result.data.to_parquet`` / ``result.data.to_csv``.
        credentials : Dict[str, Any], optional
            Merged into the storage options passed to ``fsspec.filesystem``.
        fs_args : Dict[str, Any], optional
            Merged into the storage options passed to ``fsspec.filesystem``; on key collisions these
            take precedence over ``credentials``. For the local ``file`` protocol, ``auto_mkdir``
            defaults to True unless given explicitly.
        """
        # Copy so the caller's dictionaries are never mutated (e.g. by the setdefault below).
        _fs_args = deepcopy(fs_args) or {}
        _credentials = deepcopy(credentials) or {}

        self._data_format = data_format

        protocol, path = get_protocol_and_path(filepath)
        if protocol == "file":
            _fs_args.setdefault("auto_mkdir", True)

        self._protocol = protocol
        self._storage_options = {**_credentials, **_fs_args}
        self._fs = fsspec.filesystem(self._protocol, **self._storage_options)

        super().__init__(filepath=PurePosixPath(path))

        self._write_args = write_args or {}  # type: Dict[str, Any]

    def _write(self, result: Result):
        """Write all plots (as PNG) and the data of ``result`` below the configured filepath.

        Raises
        ------
        InvalidArgumentsException
            When ``data_format`` is not one of ``"parquet"`` or ``"csv"``. This is raised before
            anything is written, so no partial output is left behind.
        """
        extension = self._data_file_extension()

        write_path = get_filepath_str(self._filepath, self._protocol)
        result_path = Path(write_path) / result.calculator_name

        self._write_images(result, result_path / "images")
        self._write_data(result, result_path / "data", extension)

    def _data_file_extension(self) -> str:
        """Return the data file extension for the configured format, or raise for an unknown format."""
        try:
            return self._DATA_FILE_EXTENSIONS[self._data_format]
        except KeyError:
            raise InvalidArgumentsException(
                f"unknown value for format '{self._data_format}', should be one of 'parquet', 'csv'"
            ) from None

    def _ensure_directory(self, path: Path):
        """Create ``path`` (and its parents) when writing to the local filesystem."""
        # Only local paths get real directories. For remote protocols, Path.mkdir would create
        # stray directories in the local working directory instead of on the target filesystem.
        if self._protocol == "file":
            path.mkdir(parents=True, exist_ok=True)

    def _write_images(self, result: Result, images_path: Path):
        """Render every plot of ``result`` to PNG and write it to ``images_path/<name>.png``."""
        self._ensure_directory(images_path)
        plots = result.plots.items()
        self._logger.debug("writing %d images to %s", len(plots), images_path)
        for name, image in plots:
            _write_bytes_to_filesystem(image.to_image(format='png'), images_path / f'{name}.png', self._fs)

    def _write_data(self, result: Result, data_path: Path, extension: str):
        """Serialize ``result.data`` in the configured format to ``data_path/<calculator_name>.<extension>``."""
        self._ensure_directory(data_path)
        self._logger.debug("writing data to %s", data_path)

        bytes_buffer = BytesIO()
        if self._data_format == "parquet":
            result.data.to_parquet(bytes_buffer, **self._write_args)
        else:  # "csv": the only other value accepted by _data_file_extension
            result.data.to_csv(bytes_buffer, **self._write_args)

        _write_bytes_to_filesystem(
            bytes_buffer.getvalue(), data_path / f"{result.calculator_name}.{extension}", self._fs
        )
def _write_bytes_to_filesystem(bytez, save_path: Path, fs: fsspec.spec.AbstractFileSystem):
    """Write the raw bytes ``bytez`` to ``save_path`` on the fsspec filesystem ``fs``.

    The path is converted to a string before being handed to ``fs.open``. The file is opened in
    binary write mode and closed again when the ``with`` block exits.
    """
    target = str(save_path)
    with fs.open(target, mode="wb") as handle:
        handle.write(bytez)