# Author: Niels Nuyttens <niels@nannyml.com>
#
# License: Apache Software License 2.0
"""The classes representing the results of a target distribution calculation."""
import pandas as pd
import plotly.graph_objects as go
from nannyml.exceptions import InvalidArgumentsException
from nannyml.metadata.base import ModelMetadata
from nannyml.plots import CHUNK_KEY_COLUMN_NAME
from nannyml.plots._step_plot import _step_plot
[docs]class TargetDistributionResult:
"""Contains target distribution data and utilities to plot it."""
def __init__(self, target_distribution: pd.DataFrame, model_metadata: ModelMetadata):
"""Creates a new instance of the TargetDistributionResults."""
self.data = target_distribution
self.metadata = model_metadata
[docs] def plot(self, kind: str = 'distribution', distribution: str = 'metric', *args, **kwargs) -> go.Figure:
"""Renders a line plot of the target distribution.
Chunks are set on a time-based X-axis by using the period containing their observations.
Chunks of different partitions (``reference`` and ``analysis``) are represented using different colors and
a vertical separation if the drift results contain multiple partitions.
Parameters
----------
kind: str
The kind of plot to show. Restricted to the value 'distribution'.
distribution: str, default='metric'
The kind of distribution to plot. Restricted to the values 'metric' or 'statistical'
Returns
-------
fig: plotly.graph_objects.Figure
A ``Figure`` object containing the requested drift plot. Can be saved to disk or shown rendered on screen
using ``fig.show()``.
Examples
--------
>>> import nannyml as nml
>>> ref_df, ana_df, _ = nml.load_synthetic_binary_classification_dataset()
>>> metadata = nml.extract_metadata(ref_df, model_type=nml.ModelType.CLASSIFICATION_BINARY)
>>> target_distribution_calc = nml.TargetDistributionCalculator(model_metadata=metadata, chunk_period='W')
>>> target_distribution_calc.fit(ref_df)
>>> target_distribution = target_distribution_calc.calculate(ana_df)
>>> # plot the distribution of the mean
>>> target_distribution.plot(kind='metric').show()
>>> # plot the Chi square statistic
>>> target_distribution.plot(kind='statistical').show()
"""
if kind == 'distribution':
return self._plot_distribution(distribution)
else:
raise InvalidArgumentsException(f"unknown plot kind '{kind}'. " f"Please provide one of: ['distribution'].")
def _plot_distribution(self, distribution: str) -> go.Figure:
plot_partition_separator = len(self.data.value_counts()) > 1
if distribution == 'metric':
fig = _step_plot(
table=self.data,
metric_column_name='metric_target_drift',
chunk_column_name=CHUNK_KEY_COLUMN_NAME,
drift_column_name='alert',
hover_labels=['Chunk', 'Rate', 'Target data'],
title=f'Target distribution over time for {self.metadata.target_column_name}',
y_axis_title='Rate of positive occurrences',
v_line_separating_analysis_period=plot_partition_separator,
partial_target_column_name='targets_missing_rate',
statistically_significant_column_name='significant',
)
return fig
elif distribution == 'statistical':
fig = _step_plot(
table=self.data,
metric_column_name='statistical_target_drift',
chunk_column_name=CHUNK_KEY_COLUMN_NAME,
drift_column_name='alert',
hover_labels=['Chunk', 'Chi-square statistic', 'Target data'],
title=f'Chi-square statistic over time for {self.metadata.target_column_name} ',
y_axis_title='Chi-square statistic',
v_line_separating_analysis_period=plot_partition_separator,
partial_target_column_name='targets_missing_rate',
statistically_significant_column_name='significant',
)
return fig