Source code for nannyml.io.base

#  Author:   Niels Nuyttens  <niels@nannyml.com>
#
#  License: Apache Software License 2.0

import logging
import re
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional, Tuple, Type
from urllib.parse import urlsplit

import pandas as pd

from nannyml._typing import Result
from nannyml.exceptions import InvalidArgumentsException, ReaderException, WriterException

HTTP_PROTOCOLS = ['http', 'https']
CLOUD_PROTOCOLS = ['s3', 'gcs', 'gs', 'adl', 'abfs', 'abfss']


class Writer(ABC):
    """Base class for writing Result instances to an external medium such as disk, database or API."""

    @property
    def _logger(self) -> logging.Logger:
        """Logger named after this module."""
        return logging.getLogger(__name__)

    def write(self, result: Result, **kwargs) -> Any:
        """Validate ``result`` and delegate the actual writing to :meth:`_write`.

        Parameters
        ----------
        result : Result
            The result to write. Must not be ``None``.
        **kwargs
            Forwarded unchanged to :meth:`_write`.

        Raises
        ------
        InvalidArgumentsException
            When ``result`` is ``None``.
        WriterException
            When :meth:`_write` raises any exception. The original exception is attached as the cause.
        """
        if result is None:
            raise InvalidArgumentsException("Trying to write 'None'")

        # ``**kwargs`` is always a dict (possibly empty), so it needs no ``None`` guard.
        try:
            self._write(result, **kwargs)
        except Exception as exc:
            # Chain explicitly so the underlying failure stays visible as ``__cause__``.
            raise WriterException(f"Failed writing data. \n{str(exc)}") from exc

    @abstractmethod
    def _write(self, result: Result, **kwargs):
        """Perform the actual write. Concrete subclasses must override this method."""
        raise NotImplementedError(
            f"'{self.__class__.__name__}' is a subclass of Writer and it must implement the _write method"
        )
class WriterFactory:
    """A factory class that produces Writer instances for a given ``key``.

    The value for this ``key`` is passed along explicitly by the user, either by providing it directly during
    ``Writer`` initialization or passed along in the ``nann.yml`` configuration file.
    """

    # Maps a registration key to the Writer *class* (not an instance) that ``create`` instantiates.
    registry: Dict[str, Type[Writer]] = {}

    @classmethod
    def _logger(cls) -> logging.Logger:
        """Logger named after this module."""
        return logging.getLogger(__name__)

    @classmethod
    def create(cls, key, kwargs: Optional[Dict[str, Any]] = None) -> Writer:
        """Returns a Writer instance for a given string.

        Parameters
        ----------
        key
            The key a Writer class was registered under.
        kwargs : Optional[Dict[str, Any]]
            Keyword arguments passed to the Writer class constructor. ``None`` means no arguments.

        Raises
        ------
        InvalidArgumentsException
            When no Writer class is registered for ``key``.
        """
        if kwargs is None:
            kwargs = {}

        if key not in cls.registry:
            raise InvalidArgumentsException(
                f"unknown key '{key}' given. " f"Currently registered keys are: {list(cls.registry.keys())}"
            )

        writer_class = cls.registry[key]
        return writer_class(**kwargs)

    @classmethod
    def register(cls, key) -> Callable[[Type[Writer]], Type[Writer]]:
        """Return a class decorator that registers the decorated Writer class under ``key``.

        Registering a second class under an existing key logs a warning and replaces the earlier entry.
        The decorated class itself is returned unchanged.
        """

        def inner_wrapper(wrapped_class: Type[Writer]) -> Type[Writer]:
            if key in cls.registry:
                cls._logger().warning(f"re-registering Writer for key='{key}'")
            cls.registry[key] = wrapped_class
            return wrapped_class

        return inner_wrapper
class Reader(ABC):
    """Base class for reading data"""

    @property
    def _logger(self) -> logging.Logger:
        """Logger named after this module."""
        return logging.getLogger(__name__)

    def read(self) -> pd.DataFrame:
        """Return the data produced by :meth:`_read`.

        Raises
        ------
        ReaderException
            When :meth:`_read` raises any exception. The original exception is attached as the cause.
        """
        try:
            return self._read()
        except Exception as exc:
            # Chain explicitly so the underlying failure stays visible as ``__cause__``.
            raise ReaderException(f"Failed reading data. \n{str(exc)}") from exc

    @abstractmethod
    def _read(self) -> pd.DataFrame:
        """Perform the actual read. Concrete subclasses must override this method."""
        raise NotImplementedError(
            f"'{self.__class__.__name__}' is a subclass of Reader and it must implement the _read method"
        )
def _get_protocol_and_path(filepath: str) -> Tuple[str, str]:
    """Split a file location into a ``(protocol, path)`` tuple.

    Parameters
    ----------
    filepath : str
        A local path or a ``scheme://`` style URL.

    Returns
    -------
    Tuple[str, str]
        * Windows drive paths and strings without a ``scheme://`` prefix: ``("file", filepath)``, untouched.
        * HTTP(S) URLs: the scheme and ``host + path``.
        * ``file://`` URLs: ``"file"`` and the URL path, with a leading ``/C:/...`` or ``/C|/...``
          rewritten to ``C:/...``.
        * Cloud protocols with a network location: the scheme and ``host + path``, where credentials
          and port are stripped from the host.
        * Anything else: the scheme and the URL path.
    """
    # A Windows drive path ("C:\..." or "C:/...") or a string without "scheme://" is a plain local file.
    is_windows_drive = re.match(r"^[a-zA-Z]:[\\/]", filepath) is not None
    has_scheme = re.match(r"^[a-zA-Z\d]+://", filepath) is not None
    if is_windows_drive or not has_scheme:
        return "file", filepath

    parsed_path = urlsplit(filepath)
    protocol = parsed_path.scheme or "file"
    path = parsed_path.path

    if protocol in HTTP_PROTOCOLS:
        # Keep the host: ``urlsplit`` puts it in ``netloc``, not ``path``. Without it,
        # ``_get_filepath_str`` would rebuild a host-less URL such as "https:///data.csv".
        return protocol, parsed_path.netloc + path

    if protocol == "file":
        # "file:///C:/dir/x" parses to the path "/C:/dir/x"; turn it back into "C:/dir/x".
        windows_path = re.match(r"^/([a-zA-Z])[:|]([\\/].*)$", path)
        if windows_path:
            path = ":".join(windows_path.groups())

    if parsed_path.netloc and protocol in CLOUD_PROTOCOLS:
        # Drop "user:password@" and ":port" so only the host is prepended to the path.
        host_with_port = parsed_path.netloc.rsplit("@", 1)[-1]
        host = host_with_port.rsplit(":", 1)[0]
        path = host + path

    return protocol, path


def _get_filepath_str(path: str, protocol: str) -> str:
    """Return ``path`` prefixed with ``"<protocol>://"`` for HTTP(S) protocols, otherwise ``path`` unchanged."""
    if protocol in HTTP_PROTOCOLS:
        return f"{protocol}://{path}"
    return path