# Author: Niels Nuyttens <niels@nannyml.com>
#
# License: Apache Software License 2.0
"""Module containing CBPE estimation results and plotting implementations."""
import copy
from typing import List, Optional
import pandas as pd
from plotly import graph_objects as go
from nannyml._typing import Key, ModelOutputsType, ProblemType
from nannyml._typing import Result as ResultType
from nannyml.base import Abstract1DResult
from nannyml.chunk import Chunker
from nannyml.exceptions import InvalidArgumentsException
from nannyml.performance_estimation.confidence_based.metrics import Metric
from nannyml.plots.blueprints.comparisons import ResultCompareMixin
from nannyml.plots.blueprints.metrics import plot_metrics
from nannyml.usage_logging import UsageEvent, log_usage
SUPPORTED_METRIC_VALUES = ['roc_auc', 'f1', 'precision', 'recall', 'specificity', 'accuracy']
class Result(Abstract1DResult, ResultCompareMixin):
    """Contains results for CBPE estimation and adds plotting functionality."""

    def __init__(
        self,
        results_data: pd.DataFrame,
        metrics: List[Metric],
        y_pred: str,
        y_pred_proba: ModelOutputsType,
        y_true: str,
        chunker: Chunker,
        problem_type: ProblemType,
        timestamp_column_name: Optional[str] = None,
    ):
        """Create a new CBPE estimation Result.

        Parameters
        ----------
        results_data: pd.DataFrame
            Per-chunk estimation results. Uses a two-level column index: a ``chunk`` group with
            chunk metadata (including ``period``) and one group per estimated metric.
        metrics: List[Metric]
            The metric instances that were estimated.
        y_pred: str
            Name of the model prediction column.
        y_pred_proba: ModelOutputsType
            Name(s) of the predicted probability column(s).
        y_true: str
            Name of the target (ground truth) column.
        chunker: Chunker
            The chunker used to split the data into chunks.
        problem_type: ProblemType
            The type of problem (e.g. binary / multiclass classification) being estimated.
        timestamp_column_name: Optional[str], default=None
            Name of the timestamp column, if one was provided.
        """
        super().__init__(results_data)

        self.metrics = metrics
        self.y_pred = y_pred
        self.y_pred_proba = y_pred_proba
        self.y_true = y_true
        self.timestamp_column_name = timestamp_column_name
        self.problem_type = problem_type
        self.chunker = chunker

    def _filter(self, period: str, metrics: Optional[List[str]] = None, *args, **kwargs) -> ResultType:
        """Return a deep copy of this result, restricted to the given period and metrics.

        ``period`` may be ``'all'`` (keep every chunk) or a specific period value matched against
        the ``('chunk', 'period')`` column. ``metrics`` defaults to all metrics in this result.
        """
        if metrics is None:
            metrics = [metric.column_name for metric in self.metrics]

        # Keep the 'chunk' metadata columns plus only the requested metric column groups.
        data = pd.concat([self.data.loc[:, (['chunk'])], self.data.loc[:, (metrics,)]], axis=1)
        if period != 'all':
            data = data.loc[data.loc[:, ('chunk', 'period')] == period, :]
        data = data.reset_index(drop=True)

        # Deep copy so filtering never mutates the original result.
        res = copy.deepcopy(self)
        res.data = data
        res.metrics = [m for m in self.metrics if m.column_name in metrics]

        return res

    def keys(self) -> List[Key]:
        """Return a Key per estimated metric, used to address individual metric results."""
        return [
            Key(properties=(metric.column_name,), display_names=(f'estimated {metric.display_name}',))
            for metric in self.metrics
        ]

    @log_usage(UsageEvent.CBPE_PLOT, metadata_from_kwargs=['kind'])
    def plot(
        self,
        kind: str = 'performance',
        *args,
        **kwargs,
    ) -> go.Figure:
        """Render plots based on CBPE estimation results.

        This function will return a :class:`plotly.graph_objects.Figure` object.

        The following kinds of plots are available:

        - ``performance``: a line plot rendering the estimated performance per :class:`~nannyml.chunk.Chunk` after
          applying the :meth:`~nannyml.performance_estimation.confidence_based.CBPE.estimate` method on a chunked
          dataset.

        Returns
        -------
        fig: :class:`plotly.graph_objs._figure.Figure`
            A :class:`~plotly.graph_objs._figure.Figure` object containing the requested drift plot.

            Can be saved to disk using the :meth:`~plotly.graph_objs._figure.Figure.write_image` method
            or shown rendered on screen using the :meth:`~plotly.graph_objs._figure.Figure.show` method.

        Raises
        ------
        InvalidArgumentsException
            When an unsupported ``kind`` is provided.

        Examples
        --------
        >>> import nannyml as nml
        >>>
        >>> reference_df, analysis_df, target_df = nml.load_synthetic_binary_classification_dataset()
        >>>
        >>> estimator = nml.CBPE(
        >>>     y_true='work_home_actual',
        >>>     y_pred='y_pred',
        >>>     y_pred_proba='y_pred_proba',
        >>>     timestamp_column_name='timestamp',
        >>>     metrics=['f1', 'roc_auc']
        >>> )
        >>>
        >>> estimator.fit(reference_df)
        >>>
        >>> results = estimator.estimate(analysis_df)
        >>> results.plot().show()
        """
        if kind == 'performance':
            return plot_metrics(
                self,
                title='Estimated performance <b>(CBPE)</b>',
                subplot_title_format='Estimated <b>{display_names[0]}</b>',
            )
        else:
            # Fixed typo in the error message ("on of" -> "one of").
            raise InvalidArgumentsException(f"unknown plot kind '{kind}'. " f"Please provide one of: ['performance'].")