Source code for nannyml.drift.model_inputs.multivariate.data_reconstruction.results

#  Author:   Niels Nuyttens  <niels@nannyml.com>
#
#  License: Apache Software License 2.0

"""Implementation of the Data Reconstruction Drift Calculator."""

import pandas as pd
import plotly.graph_objects as go

from nannyml.drift.base import DriftResult
from nannyml.exceptions import InvalidArgumentsException
from nannyml.plots import CHUNK_KEY_COLUMN_NAME
from nannyml.plots._step_plot import _step_plot


class DataReconstructionDriftCalculatorResult(DriftResult):
    """Contains the results of the data reconstruction drift calculation and adds functionality like plotting."""

    def plot(self, kind: str = 'drift', *args, **kwargs) -> go.Figure:
        """Renders a line plot of the ``reconstruction_error`` of the data reconstruction drift calculation results.

        Chunks are set on a time-based X-axis by using the period containing their observations.
        Chunks of different partitions (``reference`` and ``analysis``) are represented using different colors and
        a vertical separation if the drift results contain multiple partitions.

        The different plot kinds that are available:

        - ``drift``: plots drift per :class:`~nannyml.chunk.Chunk` for a chunked data set.

        Parameters
        ----------
        kind: str, default='drift'
            The kind of plot to render. Only ``'drift'`` is supported.
        *args, **kwargs
            Forwarded unchanged to the underlying plotting helper.

        Returns
        -------
        fig: plotly.graph_objects.Figure
            A ``Figure`` object containing the requested drift plot. Can be saved to disk or shown rendered on screen
            using ``fig.show()``.

        Raises
        ------
        InvalidArgumentsException
            When ``kind`` is not one of the supported plot kinds.

        Examples
        --------
        >>> import nannyml as nml
        >>> ref_df, ana_df, _ = nml.load_synthetic_binary_classification_dataset()
        >>> metadata = nml.extract_metadata(ref_df, model_type=nml.ModelType.CLASSIFICATION_BINARY)
        >>> drift_calc = nml.DataReconstructionDriftCalculator(model_metadata=metadata, chunk_period='W')
        >>> drift_calc.fit(ref_df)
        >>> drifts = drift_calc.calculate(ana_df)
        >>> # create the data reconstruction drift plot and display it
        >>> drifts.plot(kind='drift').show()
        """
        if kind == 'drift':
            # Unpack the extra arguments so the helper receives them as real
            # positional/keyword arguments instead of one tuple and one dict.
            return _plot_drift(self.data, *args, **kwargs)

        # Only list the kinds this result class actually supports.
        raise InvalidArgumentsException(f"unknown plot kind '{kind}'. Please provide one of: ['drift'].")
def _plot_drift(data: pd.DataFrame, *args, **kwargs) -> go.Figure:
    """Build the step plot of the reconstruction error per chunk.

    Parameters
    ----------
    data: pd.DataFrame
        Drift calculation results. The columns ``lower_threshold`` and ``upper_threshold`` are read here;
        ``reconstruction_error``, ``alert`` and the chunk key column are referenced by name when plotting.
    *args, **kwargs
        Accepted for call compatibility; currently unused.

    Returns
    -------
    fig: plotly.graph_objects.Figure
        The figure produced by ``_step_plot``.
    """
    # Work on a copy: the helper column added below must not leak into the
    # caller's result DataFrame every time a plot is rendered.
    plot_data = data.copy()

    # The separator line should only appear when more than one partition is present.
    # NOTE(review): assumes the partition label lives in a 'partition' column -- confirm
    # against the calculator output. Without that column, fall back to the previous
    # check, which counts distinct rows.
    if 'partition' in plot_data.columns:
        plot_partition_separator = plot_data['partition'].nunique() > 1
    else:
        plot_partition_separator = len(plot_data.value_counts()) > 1

    # The step plot takes a single column holding (lower, upper) threshold pairs.
    plot_data['thresholds'] = list(zip(plot_data.lower_threshold, plot_data.upper_threshold))

    fig = _step_plot(
        table=plot_data,
        metric_column_name='reconstruction_error',
        chunk_column_name=CHUNK_KEY_COLUMN_NAME,
        drift_column_name='alert',
        threshold_column_name='thresholds',
        hover_labels=['Chunk', 'Reconstruction error', 'Target data'],
        title='Data Reconstruction Drift',
        y_axis_title='Reconstruction Error',
        v_line_separating_analysis_period=plot_partition_separator,
    )

    return fig