Source code for nannyml.runner

#  Author:   Niels Nuyttens  <niels@nannyml.com>
#
#  License: Apache Software License 2.0

"""Used as an access point to start using NannyML in its most simple form."""

import logging
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

import pandas as pd
from rich.console import Console

from nannyml._typing import Result
from nannyml.config import Config, InputDataConfig, StoreConfig, WriterConfig
from nannyml.drift.multivariate.data_reconstruction import DataReconstructionDriftCalculator
from nannyml.drift.univariate import UnivariateDriftCalculator
from nannyml.exceptions import InvalidArgumentsException
from nannyml.io import FileReader, FilesystemStore, Writer, WriterFactory
from nannyml.io.store import Store
from nannyml.performance_calculation import PerformanceCalculator
from nannyml.performance_estimation.confidence_based import CBPE
from nannyml.performance_estimation.direct_loss_estimation import DLE


[docs]@dataclass class RunContext: STEPS_PER_CALCULATOR = 3 current_step: int total_steps: int current_calculator: str current_calculator_config: Optional[Dict[str, Any]] = None current_calculator_success: bool = True run_success: bool = True result: Optional[Result] = None
[docs] def increase_step(self): self.current_step += 1
[docs]@contextmanager def run_context(config: Config): yield RunContext( current_step=0, total_steps=len([c for c in config.calculators if c.enabled]) * RunContext.STEPS_PER_CALCULATOR, current_calculator='', )
_registry: Dict[str, Type] = { 'univariate_drift': UnivariateDriftCalculator, 'multivariate_drift': DataReconstructionDriftCalculator, 'performance': PerformanceCalculator, 'cbpe': CBPE, 'dle': DLE, } _logger = logging.getLogger(__name__)
[docs]class RunnerLogger: def __init__(self, logger: logging.Logger, console: Optional[Console] = None): self.logger = logger self.console = console
[docs] def log(self, message: object, log_level: int = logging.INFO): if self.logger: self.logger.log(level=log_level, msg=message) if self.console and log_level == logging.INFO: self.console.log(message)
[docs]def run( # noqa: C901 config: Config, logger: logging.Logger = logging.getLogger(__name__), console: Optional[Console] = None, on_fit: Optional[Callable[[RunContext], Any]] = None, on_calculate: Optional[Callable[[RunContext], Any]] = None, on_write: Optional[Callable[[RunContext], Any]] = None, on_calculator_complete: Optional[Callable[[RunContext], Any]] = None, on_run_complete: Optional[Callable[[RunContext], Any]] = None, on_fail: Optional[Callable[[RunContext, Optional[Exception]], Any]] = None, ): run_logger = RunnerLogger(logger, console) try: with run_context(config) as context: run_logger.log("reading reference data", log_level=logging.DEBUG) reference_data = read_data(config.input.reference_data, run_logger) # read analysis data run_logger.log("reading analysis data", log_level=logging.DEBUG) analysis_data = read_data(config.input.analysis_data, run_logger) if config.input.target_data: run_logger.log("reading target data", log_level=logging.DEBUG) target_data = read_data(config.input.target_data, run_logger) if config.input.target_data.join_column: analysis_data = analysis_data.merge(target_data, on=config.input.target_data.join_column) else: analysis_data = analysis_data.join(target_data) for calculator_config in config.calculators: try: context.current_calculator_config = calculator_config.dict() context.current_calculator = calculator_config.name or calculator_config.type if not calculator_config.enabled: continue writers = get_output_writers(calculator_config.outputs, run_logger) store = get_store(calculator_config.store, run_logger) if calculator_config.type not in _registry: raise InvalidArgumentsException(f"unknown calculator type '{calculator_config.type}'") # first step: load or (create + fit) calculator context.increase_step() calc_cls = _registry[calculator_config.type] if store and calculator_config.store: run_logger.log( f"[{context.current_step}/{context.total_steps}] '{context.current_calculator}': " f"loading calculator from store" ) calc = store.load(filename=calculator_config.store.filename, as_type=calc_cls) if calc is None: run_logger.log( f"calculator '{context.current_calculator}' not found in store. " f"Creating, fitting and storing new instance", log_level=logging.DEBUG, ) calc = calc_cls(**calculator_config.params) if on_fit: on_fit(context) calc.fit(reference_data) store.store(obj=calc, filename=calculator_config.store.filename) else: run_logger.log( f"[{context.current_step}/{context.total_steps}] '{context.current_calculator}': " f"creating and fitting calculator" ) calc = calc_cls(**calculator_config.params) if on_fit: on_fit(context) calc.fit(reference_data) # second step: perform calculations context.increase_step() run_logger.log( f"[{context.current_step}/{context.total_steps}] '{context.current_calculator}': " f"running calculator" ) if on_calculate: on_calculate(context) result = ( calc.calculate(analysis_data) if hasattr(calc, 'calculate') else calc.estimate(analysis_data) ) context.result = result # third step: write results context.increase_step() run_logger.log( f"[{context.current_step}/{context.total_steps}] '{context.current_calculator}': " f"writing out results" ) if on_write: on_write(context) for writer, write_args in writers: run_logger.log(f"writing results with {writer} and args {write_args}", log_level=logging.DEBUG) writer.write(result, **write_args) if on_calculator_complete: on_calculator_complete(context) except Exception as exc: context.current_calculator_success = False context.run_success = False if on_fail: on_fail(context, exc) run_logger.log(f"an unexpected exception occurred running '{calculator_config.type}': {exc}") if on_run_complete: on_run_complete(context) except Exception as exc: context.run_success = False if on_fail: on_fail(context, exc) raise exc
[docs]def read_data(input_config: InputDataConfig, logger: Optional[RunnerLogger] = None) -> pd.DataFrame: data = FileReader( filepath=input_config.path, credentials=input_config.credentials, read_args=input_config.read_args ).read() if logger: logger.log(f"read {data.size} rows from {input_config.path}") return data
[docs]def get_output_writers( outputs_config: Optional[List[WriterConfig]], logger: Optional[RunnerLogger] = None ) -> List[Tuple[Writer, Dict[str, Any]]]: if not outputs_config: return [] writers: List[Tuple[Writer, Dict[str, Any]]] = [] for writer_config in outputs_config: writer = WriterFactory.create(writer_config.type, writer_config.params) writers.append((writer, writer_config.write_args or {})) return writers
[docs]def get_store(store_config: Optional[StoreConfig], logger: Optional[RunnerLogger] = None) -> Optional[Store]: if store_config: if logger: logger.log(f"using file system store with path '{store_config.path}'", log_level=logging.DEBUG) store = FilesystemStore( root_path=store_config.path, credentials=store_config.credentials or {}, ) return store return None
def _get_ignore_errors(ignore_errors: bool, config: Config) -> bool: if ignore_errors is None: if config.ignore_errors is None: return False else: return config.ignore_errors else: return ignore_errors