Source code for nannyml.io.db.mappers

#  Author:   Niels Nuyttens  <niels@nannyml.com>
#
#  License: Apache Software License 2.0
import abc
import logging
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Type

from nannyml.drift.multivariate.data_reconstruction.result import Result as DataReconstructionDriftResult
from nannyml.drift.univariate import Result as UnivariateDriftResult
from nannyml.exceptions import InvalidArgumentsException
from nannyml.io.db.entities import CBPEPerformanceMetric, DataReconstructionFeatureDriftMetric, DLEPerformanceMetric
from nannyml.io.db.entities import Metric
from nannyml.io.db.entities import Metric as DbMetric
from nannyml.io.db.entities import RealizedPerformanceMetric, UnivariateDriftMetric
from nannyml.performance_calculation.result import Result as RealizedPerformanceResult
from nannyml.performance_estimation.confidence_based.results import Result as CBPEResult
from nannyml.performance_estimation.direct_loss_estimation.result import Result as DLEResult


class Mapper(abc.ABC):
    """Abstract base for objects that convert a result into database entities.

    Concrete mappers implement :meth:`map_to_entity` for one result type.
    """

    def __init__(self):
        """Create a mapper; the base class takes no arguments and keeps no state."""

    @abc.abstractmethod
    def map_to_entity(self, result, **metric_args) -> List[DbMetric]:
        """Map a result to a list of Metric entities.

        Parameters
        ----------
        result
            The result object to convert.
        **metric_args
            Extra keyword arguments supplied by the caller for the concrete
            mapper to use when building entities.

        Returns
        -------
        List[DbMetric]
            The entities created from ``result``.
        """
def _fully_qualified_class_name(result): return f"{result.__module__}.{result.__name__}"
class MapperFactory:
    """A factory class that produces Mapper instances for a given Result subclass.

    Mapper classes are kept in a class-level registry keyed by the fully
    qualified name (``'<module>.<class name>'``) of the result class they
    handle. Classes are added to the registry with the :meth:`register`
    decorator and instantiated through :meth:`create`.
    """

    # Fully qualified result class name -> Mapper *class* (not instance).
    registry: Dict[str, Type[Mapper]] = {}

    @classmethod
    def _logger(cls) -> logging.Logger:
        """Return the logger named after this module."""
        return logging.getLogger(__name__)

    @classmethod
    def create(cls, result, kwargs: Optional[Dict[str, Any]] = None) -> Mapper:
        """Returns a Mapper instance for a given result class.

        Parameters
        ----------
        result
            A result *instance*; its exact class is used for the registry
            lookup, so subclasses of a registered class are not matched.
        kwargs : Optional[Dict[str, Any]]
            Keyword arguments forwarded to the mapper constructor.
            ``None`` is treated as an empty dictionary.

        Returns
        -------
        Mapper
            A new instance of the mapper class registered for ``type(result)``.

        Raises
        ------
        InvalidArgumentsException
            When no mapper is registered for the class of ``result``.
        """
        if kwargs is None:
            kwargs = {}

        key = _fully_qualified_class_name(result.__class__)
        if key not in cls.registry:
            raise InvalidArgumentsException(
                f"unknown result class '{key}' given. "
                f"Currently registered result classes are: {list(cls.registry.keys())}"
            )

        mapper_class = cls.registry[key]
        return mapper_class(**kwargs)

    @classmethod
    def register(cls, result) -> Callable:
        """Return a class decorator that registers a Mapper class for ``result``.

        Parameters
        ----------
        result
            The result *class* the decorated mapper handles.

        Returns
        -------
        Callable
            A decorator that stores the decorated class in the registry and
            returns it unchanged. Registering a second mapper for the same
            result class logs a warning and replaces the earlier entry.
        """
        key = _fully_qualified_class_name(result)

        def inner_wrapper(wrapped_class: Type[Mapper]) -> Type[Mapper]:
            if key in cls.registry:
                cls._logger().warning("re-registering Mapper for result_class='%s'", key)
            cls.registry[key] = wrapped_class
            return wrapped_class

        return inner_wrapper
@MapperFactory.register(UnivariateDriftResult)
class UnivariateDriftResultMapper(Mapper):
    """Maps univariate drift results to ``UnivariateDriftMetric`` entities."""

    def map_to_entity(self, result, **metric_args) -> List[DbMetric]:
        """Convert the analysis-period rows of ``result`` into entities.

        One entity is produced per chunk for every (continuous column,
        continuous method) pair, followed by every (categorical column,
        categorical method) pair.

        Parameters
        ----------
        result : UnivariateDriftResult
            The univariate drift result to export.
        **metric_args
            Extra keyword arguments passed to every ``UnivariateDriftMetric``.

        Returns
        -------
        List[DbMetric]
            The created entities, continuous pairs first.

        Raises
        ------
        InvalidArgumentsException
            When ``result`` is not a ``UnivariateDriftResult``.
        NotImplementedError
            When ``result.timestamp_column_name`` is ``None``.
        """

        def _parse(
            column_name: str, metric_name: str, start_date: datetime, end_date: datetime, value, alert: bool
        ) -> DbMetric:
            # A chunk is represented by the midpoint of its time span.
            timestamp = start_date + (end_date - start_date) / 2
            return UnivariateDriftMetric(
                column_name=column_name,
                metric_name=metric_name,
                start_timestamp=start_date,
                end_timestamp=end_date,
                timestamp=timestamp,
                value=value,
                alert=alert,
                **metric_args,
            )

        if not isinstance(result, UnivariateDriftResult):
            raise InvalidArgumentsException(f"{self.__class__.__name__} can not deal with '{type(result)}'")

        if result.timestamp_column_name is None:
            raise NotImplementedError(
                'no timestamp column was specified. Listing metrics currently requires a '
                'timestamp column to be specified and present'
            )

        # Same output order as looping over both groups separately:
        # all continuous pairs first, then all categorical pairs.
        column_method_pairs = [
            (column_name, method)
            for column_name in result.continuous_column_names
            for method in result.continuous_method_names
        ] + [
            (column_name, method)
            for column_name in result.categorical_column_names
            for method in result.categorical_method_names
        ]

        res: List[DbMetric] = []
        if not column_method_pairs:
            return res

        # The filtered frame does not depend on the column or method, so it is
        # built once instead of once per pair.
        analysis_df = result.filter(period='analysis').to_df()

        for column_name, method in column_method_pairs:
            res += (
                analysis_df[
                    [
                        ('chunk', 'chunk', 'start_date'),
                        ('chunk', 'chunk', 'end_date'),
                        (column_name, method, 'value'),
                        (column_name, method, 'alert'),
                    ]
                ]
                .apply(lambda r, c=column_name, m=method: _parse(c, m, *r), axis=1)
                .to_list()
            )

        return res
@MapperFactory.register(DataReconstructionDriftResult)
class ReconstructionErrorDriftResultMapper(Mapper):
    """Maps data reconstruction drift results to ``DataReconstructionFeatureDriftMetric`` entities."""

    def map_to_entity(self, result, **metric_args) -> List[DbMetric]:
        """Convert the analysis-period rows of ``result`` into entities.

        One entity is produced per chunk for every entry of ``result.metrics``.

        Parameters
        ----------
        result : DataReconstructionDriftResult
            The multivariate drift result to export.
        **metric_args
            Extra keyword arguments passed to every created entity.

        Returns
        -------
        List[DbMetric]
            The created entities, grouped by metric in ``result.metrics`` order.

        Raises
        ------
        InvalidArgumentsException
            When ``result`` is not a ``DataReconstructionDriftResult``.
        NotImplementedError
            When ``result.timestamp_column_name`` is ``None``.
        """

        def _parse(
            metric_name: str,
            start_date: datetime,
            end_date: datetime,
            value,
            upper_threshold,
            lower_threshold,
            alert,
        ) -> DbMetric:
            # A chunk is represented by the midpoint of its time span.
            timestamp = start_date + (end_date - start_date) / 2
            return DataReconstructionFeatureDriftMetric(
                metric_name=metric_name,
                start_timestamp=start_date,
                end_timestamp=end_date,
                timestamp=timestamp,
                value=value,
                upper_threshold=upper_threshold,
                lower_threshold=lower_threshold,
                alert=alert,
                **metric_args,
            )

        if not isinstance(result, DataReconstructionDriftResult):
            raise InvalidArgumentsException(f"{self.__class__.__name__} can not deal with '{type(result)}'")

        if result.timestamp_column_name is None:
            raise NotImplementedError(
                'no timestamp column was specified. Listing metrics currently requires a '
                'timestamp column to be specified and present'
            )

        res: List[DbMetric] = []
        metrics = list(result.metrics)
        if not metrics:
            return res

        # The filtered frame is the same for every metric, so build it once
        # rather than inside the loop.
        analysis_df = result.filter(period='analysis').to_df()

        for metric in metrics:
            res += (
                analysis_df[
                    [
                        ('chunk', 'start_date'),
                        ('chunk', 'end_date'),
                        (metric, 'value'),
                        (metric, 'upper_threshold'),
                        (metric, 'lower_threshold'),
                        (metric, 'alert'),
                    ]
                ]
                .apply(lambda r, m=metric: _parse(m, *r), axis=1)
                .to_list()
            )

        return res
@MapperFactory.register(RealizedPerformanceResult)
class RealizedPerformanceMapper(Mapper):
    """Maps realized performance results to ``RealizedPerformanceMetric`` entities."""

    def map_to_entity(self, result, **metric_args) -> List[DbMetric]:
        """Convert the analysis-period rows of ``result`` into entities.

        One entity is produced per chunk for every metric in
        ``result.metrics``; the metric's ``column_name`` is used both to select
        the columns and as the entity's ``metric_name``.

        Parameters
        ----------
        result
            The realized performance result to export.
        **metric_args
            Extra keyword arguments passed to every created entity.

        Returns
        -------
        List[DbMetric]
            The created entities, grouped by metric in ``result.metrics`` order.

        Raises
        ------
        NotImplementedError
            When ``result.timestamp_column_name`` is ``None``.
        """

        def _parse(
            metric_name: str,
            start_date: datetime,
            end_date: datetime,
            value,
            upper_threshold,
            lower_threshold,
            alert: bool,
        ) -> RealizedPerformanceMetric:
            # A chunk is represented by the midpoint of its time span.
            timestamp = start_date + (end_date - start_date) / 2
            return RealizedPerformanceMetric(
                metric_name=metric_name,
                start_timestamp=start_date,
                end_timestamp=end_date,
                timestamp=timestamp,
                value=value,
                upper_threshold=upper_threshold,
                lower_threshold=lower_threshold,
                alert=alert,
                **metric_args,
            )

        if result.timestamp_column_name is None:
            raise NotImplementedError(
                'no timestamp column was specified. Listing metrics currently requires a '
                'timestamp column to be specified and present'
            )

        res: List[DbMetric] = []
        metric_column_names = [metric.column_name for metric in result.metrics]
        if not metric_column_names:
            return res

        # Filter with ``period=`` like every other mapper in this module. The
        # previous ``partition=`` keyword is presumably a leftover from an
        # older API and may not restrict the rows at all -- NOTE(review):
        # confirm against ``Result.filter``.
        # The frame is the same for every metric, so it is built once.
        analysis_df = result.filter(period='analysis').to_df()

        for metric in metric_column_names:
            res += (
                analysis_df[
                    [
                        ('chunk', 'start_date'),
                        ('chunk', 'end_date'),
                        (metric, 'value'),
                        (metric, 'upper_threshold'),
                        (metric, 'lower_threshold'),
                        (metric, 'alert'),
                    ]
                ]
                .apply(lambda r, m=metric: _parse(m, *r), axis=1)
                .to_list()
            )

        return res
@MapperFactory.register(CBPEResult)
class CBPEMapper(Mapper):
    """Maps CBPE performance estimation results to ``CBPEPerformanceMetric`` entities."""

    def map_to_entity(self, result, **metric_args) -> List[DbMetric]:
        """Convert the analysis-period rows of ``result`` into entities.

        One entity is produced per chunk for every metric in
        ``result.metrics``; the metric's ``column_name`` is used both to select
        the columns and as the entity's ``metric_name``.

        Parameters
        ----------
        result
            The CBPE estimation result to export.
        **metric_args
            Extra keyword arguments passed to every created entity.

        Returns
        -------
        List[DbMetric]
            The created entities, grouped by metric in ``result.metrics`` order.

        Raises
        ------
        NotImplementedError
            When ``result.timestamp_column_name`` is ``None``.
        """

        def _parse(
            metric_name: str,
            start_date: datetime,
            end_date: datetime,
            value,
            upper_threshold,
            lower_threshold,
            alert: bool,
        ) -> CBPEPerformanceMetric:
            # A chunk is represented by the midpoint of its time span.
            timestamp = start_date + (end_date - start_date) / 2
            return CBPEPerformanceMetric(
                metric_name=metric_name,
                start_timestamp=start_date,
                end_timestamp=end_date,
                timestamp=timestamp,
                value=value,
                upper_threshold=upper_threshold,
                lower_threshold=lower_threshold,
                alert=alert,
                **metric_args,
            )

        if result.timestamp_column_name is None:
            raise NotImplementedError(
                'no timestamp column was specified. Listing metrics currently requires a '
                'timestamp column to be specified and present'
            )

        res: List[DbMetric] = []
        metric_column_names = [metric.column_name for metric in result.metrics]
        if not metric_column_names:
            return res

        # The filtered frame is the same for every metric, so build it once
        # rather than inside the loop.
        analysis_df = result.filter(period='analysis').to_df()

        for metric in metric_column_names:
            res += (
                analysis_df[
                    [
                        ('chunk', 'start_date'),
                        ('chunk', 'end_date'),
                        (metric, 'value'),
                        (metric, 'upper_threshold'),
                        (metric, 'lower_threshold'),
                        (metric, 'alert'),
                    ]
                ]
                .apply(lambda r, m=metric: _parse(m, *r), axis=1)
                .to_list()
            )

        return res
@MapperFactory.register(DLEResult)
class DLEMapper(Mapper):
    """Maps DLE performance estimation results to ``DLEPerformanceMetric`` entities."""

    def map_to_entity(self, result, **metric_args) -> List[DbMetric]:
        """Convert the analysis-period rows of ``result`` into entities.

        One entity is produced per chunk for every metric in
        ``result.metrics``; the metric's ``column_name`` is used both to select
        the columns and as the entity's ``metric_name``.

        Parameters
        ----------
        result
            The DLE estimation result to export.
        **metric_args
            Extra keyword arguments passed to every created entity.

        Returns
        -------
        List[DbMetric]
            The created entities, grouped by metric in ``result.metrics`` order.

        Raises
        ------
        NotImplementedError
            When ``result.timestamp_column_name`` is ``None``.
        """

        def _parse(
            metric_name: str,
            start_date: datetime,
            end_date: datetime,
            value,
            upper_threshold,
            lower_threshold,
            alert: bool,
        ) -> DLEPerformanceMetric:
            # A chunk is represented by the midpoint of its time span.
            timestamp = start_date + (end_date - start_date) / 2
            return DLEPerformanceMetric(
                metric_name=metric_name,
                start_timestamp=start_date,
                end_timestamp=end_date,
                timestamp=timestamp,
                value=value,
                upper_threshold=upper_threshold,
                lower_threshold=lower_threshold,
                alert=alert,
                **metric_args,
            )

        if result.timestamp_column_name is None:
            raise NotImplementedError(
                'no timestamp column was specified. Listing metrics currently requires a '
                'timestamp column to be specified and present'
            )

        res: List[DbMetric] = []
        metric_column_names = [metric.column_name for metric in result.metrics]
        if not metric_column_names:
            return res

        # The filtered frame is the same for every metric, so build it once
        # rather than inside the loop.
        analysis_df = result.filter(period='analysis').to_df()

        for metric in metric_column_names:
            res += (
                analysis_df[
                    [
                        ('chunk', 'start_date'),
                        ('chunk', 'end_date'),
                        (metric, 'value'),
                        (metric, 'upper_threshold'),
                        (metric, 'lower_threshold'),
                        (metric, 'alert'),
                    ]
                ]
                .apply(lambda r, m=metric: _parse(m, *r), axis=1)
                .to_list()
            )

        return res