Source code for nannyml.io.base

#  Author:   Niels Nuyttens  <niels@nannyml.com>
#
#  License: Apache Software License 2.0

import logging
import re
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional, Tuple, Type
from urllib.parse import urlsplit

import pandas as pd

from nannyml._typing import Result
from nannyml.exceptions import InvalidArgumentsException, ReaderException, WriterException

# URL schemes for which ``_get_protocol_and_path`` keeps everything after "://" as the
# path and for which ``get_filepath_str`` puts the "<scheme>://" prefix back in front.
HTTP_PROTOCOLS = ['http', 'https']
# URL schemes for which ``_get_protocol_and_path`` prepends the URL's network location
# (minus any "user@" prefix and ":port" suffix) to the returned path.
CLOUD_PROTOCOLS = ['s3', 'gcs', 'gs', 'adl', 'abfs', 'abfss', 'az']


class Writer(ABC):
    """Base class for writing ``Result`` instances to an external medium such as disk, database or API."""

    @property
    def _logger(self) -> logging.Logger:
        """Return the logger named after this module."""
        return logging.getLogger(__name__)

    def write(self, result: Result, **kwargs) -> Any:
        """Validate ``result`` and delegate the actual writing to :meth:`_write`.

        Parameters
        ----------
        result : Result
            The result to write. Must not be ``None``.
        **kwargs
            Extra keyword arguments, forwarded unchanged to :meth:`_write`.

        Returns
        -------
        None
            The return value of :meth:`_write` is discarded.

        Raises
        ------
        InvalidArgumentsException
            When ``result`` is ``None``.
        WriterException
            When :meth:`_write` raises any exception. The original exception is
            chained as ``__cause__`` so the root failure stays visible in tracebacks.
        """
        if result is None:
            raise InvalidArgumentsException("Trying to write 'None'")

        # ``**kwargs`` is always a dict (possibly empty), so it can be forwarded as-is.
        try:
            self._write(result, **kwargs)
        except Exception as exc:
            # Explicit chaining marks the wrapped error as the direct cause.
            raise WriterException(f"Failed writing data. \n{str(exc)}") from exc

    @abstractmethod
    def _write(self, result: Result, **kwargs):
        """Perform the actual write; every concrete subclass must override this."""
        raise NotImplementedError(
            f"'{self.__class__.__name__}' is a subclass of Writer and it must implement the _write method"
        )
class WriterFactory:
    """A factory class that produces :class:`~nannyml.io.base.Writer` instances for a given ``key``.

    The value for this ``key`` is passed along explicitly by the user, either by providing it directly during
    :class:`~nannyml.io.base.Writer` initialization or passed along in the ``nann.yml`` configuration file.
    """

    # Maps a registration key to the Writer subclass that was registered for it.
    registry: Dict[str, Type[Writer]] = {}

    @classmethod
    def _logger(cls) -> logging.Logger:
        """Return the logger named after this module."""
        return logging.getLogger(__name__)

    @classmethod
    def create(cls, key, kwargs: Optional[Dict[str, Any]] = None) -> Writer:
        """Returns a :class:`~nannyml.io.base.Writer` instance for a given string.

        Parameters
        ----------
        key
            The key a Writer class was registered under.
        kwargs : Optional[Dict[str, Any]]
            Keyword arguments passed to the Writer constructor. ``None`` means no arguments.

        Raises
        ------
        InvalidArgumentsException
            When no Writer class is registered for ``key``.
        """
        init_args = {} if kwargs is None else kwargs

        if key in cls.registry:
            return cls.registry[key](**init_args)

        raise InvalidArgumentsException(
            f"unknown key '{key}' given. " f"Currently registered keys are: {list(cls.registry.keys())}"
        )

    @classmethod
    def register(cls, key) -> Callable:
        """Return a class decorator that stores the decorated class in the registry under ``key``.

        Registering a key that is already present logs a warning and replaces the previous entry.
        The decorated class itself is returned unchanged.
        """

        def _decorate(writer_cls: Type[Writer]) -> Type[Writer]:
            already_known = key in cls.registry
            if already_known:
                cls._logger().warning(f"re-registering Writer for key='{key}'")
            cls.registry[key] = writer_cls
            return writer_cls

        return _decorate
class Reader(ABC):
    """Base class for reading data"""

    @property
    def _logger(self) -> logging.Logger:
        """Return the logger named after this module."""
        return logging.getLogger(__name__)

    def read(self) -> pd.DataFrame:
        """Read data by delegating to the subclass-specific :meth:`_read`.

        Returns
        -------
        pd.DataFrame
            Whatever :meth:`_read` returns.

        Raises
        ------
        ReaderException
            When :meth:`_read` raises any exception. The original exception is
            chained as ``__cause__`` so the root failure stays visible in tracebacks.
        """
        try:
            return self._read()
        except Exception as exc:
            # Explicit chaining marks the wrapped error as the direct cause.
            raise ReaderException(f"Failed reading data. \n{str(exc)}") from exc

    @abstractmethod
    def _read(self) -> pd.DataFrame:
        """Perform the actual read; every concrete subclass must override this."""
        raise NotImplementedError(
            f"'{self.__class__.__name__}' is a subclass of Reader and it must implement the _read method"
        )
def _get_protocol_and_path(filepath: str) -> Tuple[str, str]: if re.match(r"^[a-zA-Z]:[\\/]", filepath) or re.match(r"^[a-zA-Z\d]+://", filepath) is None: return "file", filepath parsed_path = urlsplit(filepath) protocol = parsed_path.scheme or "file" path = parsed_path.path if protocol in HTTP_PROTOCOLS: path = filepath.split("://", 1)[-1] return protocol, path if protocol == "file": windows_path = re.match(r"^/([a-zA-Z])[:|]([\\/].*)$", path) if windows_path: path = ":".join(windows_path.groups()) if parsed_path.netloc: if protocol in CLOUD_PROTOCOLS: host_with_port = parsed_path.netloc.rsplit("@", 1)[-1] host = host_with_port.rsplit(":", 1)[0] path = host + path return protocol, path
def get_filepath_str(path: str, protocol: str) -> str:
    """Return ``path`` as a string usable for the given ``protocol``.

    Parameters
    ----------
    path : str
        The path part, without a ``<scheme>://`` prefix.
    protocol : str
        The protocol the path belongs to.

    Returns
    -------
    str
        ``"<protocol>://<path>"`` when ``protocol`` is in ``HTTP_PROTOCOLS``,
        otherwise ``path`` unchanged.
    """
    if protocol not in HTTP_PROTOCOLS:
        return path
    return protocol + "://" + path