# Author: Niels Nuyttens <niels@nannyml.com>
#
# License: Apache Software License 2.0
import logging
import re
from abc import ABC, abstractmethod
from pathlib import PurePath, PurePosixPath
from typing import Any, Dict, Tuple
from urllib.parse import urlsplit
import pandas as pd
from plotly.graph_objs import Figure
from nannyml.exceptions import InvalidArgumentsException, ReaderException, WriterException
# URL schemes treated as web addresses; checked by ``get_protocol_and_path`` and
# ``get_filepath_str``.
HTTP_PROTOCOLS = ['http', 'https']
# URL schemes for which ``get_protocol_and_path`` prepends the URL's host part
# (the netloc without credentials and port) to the returned path.
CLOUD_PROTOCOLS = ['s3', 'gcs', 'gs', 'adl', 'abfs', 'abfss']
class Writer(ABC):
    """Base class for writing out results.

    Concrete writers implement :meth:`_write`. Callers use :meth:`write`,
    which validates the input and wraps any failure raised by ``_write`` in a
    ``WriterException``.
    """

    def __init__(
        self,
        filepath: PurePosixPath,
    ):
        """Create a writer.

        Parameters
        ----------
        filepath : PurePosixPath
            Stored on the instance as ``filepath``. The base class itself does
            not use it.
        """
        self.filepath = filepath

    @property
    def _logger(self) -> logging.Logger:
        """Return the logger named after this module."""
        return logging.getLogger(__name__)

    def write(self, data: pd.DataFrame, plots: Dict[str, Figure] = None, **kwargs) -> Any:
        """Validate the input and delegate to :meth:`_write`.

        Parameters
        ----------
        data : pd.DataFrame
            The data to write. Must not be ``None``.
        plots : Dict[str, Figure], optional
            Figures keyed by name. ``None`` is replaced by an empty dict
            before delegating.
        **kwargs
            Passed through unchanged to :meth:`_write`.

        Returns
        -------
        None
            The return value of :meth:`_write` is discarded.

        Raises
        ------
        InvalidArgumentsException
            If ``data`` is ``None``.
        WriterException
            If :meth:`_write` raises any exception. The original exception is
            attached as ``__cause__``.
        """
        if data is None:
            raise InvalidArgumentsException("Trying to write 'None'")

        if plots is None:
            plots = {}

        # ``**kwargs`` is always a dict (possibly empty), so it needs no ``None`` guard.
        try:
            self._write(data=data, plots=plots, **kwargs)
        except Exception as exc:
            # Chain explicitly so the traceback marks ``exc`` as the direct cause.
            raise WriterException(f"Failed writing data. \n{str(exc)}") from exc

    @abstractmethod
    def _write(self, data: pd.DataFrame, plots: Dict[str, Figure], **kwargs):
        """Perform the actual write. Must be overridden by subclasses.

        Parameters
        ----------
        data : pd.DataFrame
            The data to write; never ``None`` when called via :meth:`write`.
        plots : Dict[str, Figure]
            Figures keyed by name; never ``None`` when called via :meth:`write`.
        **kwargs
            Extra arguments forwarded from :meth:`write`.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError(
            f"'{self.__class__.__name__}' is a subclass of Writer and it must implement the _write method"
        )
class Reader(ABC):
    """Base class for reading data.

    Concrete readers implement :meth:`_read`. Callers use :meth:`read`, which
    wraps any failure raised by ``_read`` in a ``ReaderException``.
    """

    @property
    def _logger(self) -> logging.Logger:
        """Return the logger named after this module."""
        return logging.getLogger(__name__)

    def read(self) -> pd.DataFrame:
        """Read the data by delegating to :meth:`_read`.

        Returns
        -------
        pd.DataFrame
            Whatever :meth:`_read` returns.

        Raises
        ------
        ReaderException
            If :meth:`_read` raises any exception. The original exception is
            attached as ``__cause__``.
        """
        try:
            return self._read()
        except Exception as exc:
            # Chain explicitly so the traceback marks ``exc`` as the direct cause.
            raise ReaderException(f"Failed reading data. \n{str(exc)}") from exc

    @abstractmethod
    def _read(self) -> pd.DataFrame:
        """Perform the actual read. Must be overridden by subclasses.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError(
            f"'{self.__class__.__name__}' is a subclass of Reader and it must implement the _read method"
        )
def get_protocol_and_path(filepath: str) -> Tuple[str, str]:
    """Split a raw file path or URL into a protocol and a path.

    Parameters
    ----------
    filepath : str
        A local path (``data/file.csv``, ``C:\\data\\file.csv``) or a URL
        (``s3://bucket/file.csv``, ``https://example.com/file.csv``).

    Returns
    -------
    Tuple[str, str]
        ``(protocol, path)`` where:

        * Local paths, i.e. Windows drive paths and anything without a
          ``scheme://`` prefix, yield ``("file", filepath)`` unchanged.
        * ``http``/``https`` URLs yield everything after ``://`` (host, path
          and query), so that ``get_filepath_str`` can rebuild the URL by
          prefixing the protocol again.
        * ``file://`` URLs yield the URL path; a leading ``/C:/...`` or
          ``/C|/...`` is rewritten to ``C:/...``.
        * For protocols in ``CLOUD_PROTOCOLS`` the host part of the URL, with
          any ``user@`` prefix and ``:port`` suffix removed, is prepended to
          the URL path, e.g. ``s3://bucket/key.csv`` yields
          ``("s3", "bucket/key.csv")``.
        * Any other protocol yields only the URL path.
    """
    # Windows drive paths look like "<letter>:" and would otherwise be parsed
    # as a URL scheme; anything without "scheme://" is a plain local path too.
    if re.match(r"^[a-zA-Z]:[\\/]", filepath) or re.match(r"^[a-zA-Z\d]+://", filepath) is None:
        return "file", filepath

    parsed_path = urlsplit(filepath)
    protocol = parsed_path.scheme or "file"

    if protocol in HTTP_PROTOCOLS:
        # Keep host and query string: returning only ``parsed_path.path`` would
        # drop the host, making the URL impossible to reconstruct later.
        return protocol, filepath.split("://", 1)[-1]

    path = parsed_path.path
    if protocol == "file":
        # "file:///C:/dir" parses to the path "/C:/dir"; strip the leading
        # slash and normalise the legacy "C|" drive separator to "C:".
        windows_path = re.match(r"^/([a-zA-Z])[:|]([\\/].*)$", path)
        if windows_path:
            path = ":".join(windows_path.groups())

    if parsed_path.netloc and protocol in CLOUD_PROTOCOLS:
        # Drop optional "credentials@" and ":port" from the netloc, then make
        # the remaining host the first component of the path.
        host_with_port = parsed_path.netloc.rsplit("@", 1)[-1]
        host = host_with_port.rsplit(":", 1)[0]
        path = host + path

    return protocol, path
def get_filepath_str(path: PurePath, protocol: str) -> str:
    """Render ``path`` as a POSIX-style string, re-attaching HTTP(S) protocols.

    Parameters
    ----------
    path : PurePath
        The path to render; converted with ``as_posix()``.
    protocol : str
        The protocol that belongs to ``path``.

    Returns
    -------
    str
        ``"<protocol>://<posix path>"`` when ``protocol`` is listed in
        ``HTTP_PROTOCOLS``, otherwise just the POSIX path string.
    """
    posix_path = path.as_posix()
    if protocol not in HTTP_PROTOCOLS:
        return posix_path
    return f"{protocol}://{posix_path}"