# Source code for nannyml.distribution.categorical.result

import copy
import math
from typing import Any, Dict, List, Optional, Union

import numpy as np
import pandas as pd
import plotly.graph_objs as go
from typing_extensions import Self

from nannyml import Chunker
from nannyml._typing import Key
from nannyml.base import AbstractResult
from nannyml.drift.univariate.result import Result as DriftResult
from nannyml.exceptions import InvalidArgumentsException
from nannyml.plots import Colors, Figure, is_time_based_x_axis
from nannyml.plots.components.stacked_bar_plot import alert as stacked_bar_alert
from nannyml.plots.components.stacked_bar_plot import stacked_bar


class Result(AbstractResult):
    """Holds the results of a categorical distribution calculation and provides plotting.

    The results frame is flat (no multilevel columns): one row per
    (column_name, chunk, category value) combination, holding the per-chunk
    value counts alongside the chunk metadata columns
    (``key``, ``start_datetime``, ``end_datetime``, ``start_index``,
    ``end_index``, ``chunk_index``, ``period``).
    """

    def __init__(
        self,
        results_data: pd.DataFrame,
        column_names: List[str],
        timestamp_column_name: Optional[str],
        chunker: Chunker,
    ):
        """Create a new categorical distribution ``Result``.

        Parameters
        ----------
        results_data: pd.DataFrame
            Per-chunk, per-column distribution data.
        column_names: List[str]
            The categorical columns the distributions were calculated for.
        timestamp_column_name: Optional[str]
            Name of the timestamp column used for chunking, or ``None``.
        chunker: Chunker
            The chunker used to split the data into chunks.
        """
        super().__init__(results_data, column_names)
        self.timestamp_column_name = timestamp_column_name
        self.chunker = chunker
        self.column_names = column_names

    def to_df(self, multilevel: bool = True) -> pd.DataFrame:
        """Export the results to a DataFrame.

        NOTE: the ``multilevel`` parameter is accepted for interface
        compatibility but is currently ignored; the flat results frame is
        always returned.
        """
        return self.data

    def _filter(
        self,
        period: str,
        metrics: Optional[List[str]] = None,
        column_names: Optional[Union[str, List[str]]] = None,
        *args,
        **kwargs,
    ) -> Self:
        """Return a deep copy of this result, restricted to a period and optional columns.

        Parameters
        ----------
        period: str
            One of ``'reference'``, ``'analysis'`` or ``'all'`` (no period filtering).
        metrics: Optional[List[str]]
            Accepted for interface compatibility; not used by this result type.
        column_names: Optional[Union[str, List[str]]]
            A single column name or list of column names to keep.
        """
        data = self.data
        if period != 'all':
            data = data.loc[data['period'] == period, :]
        data = data.reset_index(drop=True)

        # Allow passing a single column name as a plain string.
        if isinstance(column_names, str):
            column_names = [column_names]
        if column_names:
            data = data.loc[data['column_name'].isin(column_names), :]

        res = copy.deepcopy(self)
        res.data = data
        return res

    @property
    def chunk_keys(self) -> pd.Series:
        """The chunk key column, one entry per result row."""
        return self.data['key']

    @property
    def chunk_start_dates(self) -> pd.Series:
        """The chunk start datetime column, one entry per result row."""
        return self.data['start_datetime']

    @property
    def chunk_end_dates(self) -> pd.Series:
        """The chunk end datetime column, one entry per result row."""
        return self.data['end_datetime']

    @property
    def chunk_start_indices(self) -> pd.Series:
        """The chunk start index column, one entry per result row."""
        return self.data['start_index']

    @property
    def chunk_end_indices(self) -> pd.Series:
        """The chunk end index column, one entry per result row."""
        return self.data['end_index']

    @property
    def chunk_indices(self) -> pd.Series:
        """The chunk index column, one entry per result row."""
        return self.data['chunk_index']

    @property
    def chunk_periods(self) -> pd.Series:
        """The chunk period column ('reference'/'analysis'), one entry per result row."""
        return self.data['period']

    def value_counts(self, key: Optional[Key] = None, column_name: Optional[str] = None) -> pd.DataFrame:
        """Return the per-chunk value counts for a single column.

        Either ``key`` or ``column_name`` must be given; when both are given
        the column name is taken from the key.

        Returns
        -------
        pd.DataFrame
            Columns: the column name itself (as a categorical dtype),
            ``chunk_key``, the chunk boundary columns, ``chunk_indices`` and
            the ``value_counts`` / ``value_counts_total`` /
            ``value_counts_normalised`` measurements.

        Raises
        ------
        InvalidArgumentsException
            When neither ``key`` nor ``column_name`` is provided.
        """
        if not key and not column_name:
            raise InvalidArgumentsException(
                "cannot retrieve value counts when key and column_name are both not set. "
                "Please provide either a key or a column."
            )

        if key:
            (column_name,) = key.properties

        data = self.filter(column_names=[column_name]).data
        res = data[
            [
                'value',
                'key',
                'start_datetime',
                'end_datetime',
                'start_index',
                'end_index',
                'chunk_index',
                'value_counts',
                'value_counts_total',
                'value_counts_normalised',
            ]
        ].rename(
            columns={'value': column_name, 'key': 'chunk_key', 'chunk_index': 'chunk_indices'},
        )
        res[column_name] = res[column_name].astype('category')
        return res

    def _get_property_for_key(self, key: Key, property_name: str) -> Optional[pd.Series]:
        """Look up a chunk property column for the column identified by ``key``.

        Returns ``None`` when the property column is not present in the data.
        """
        (column_name,) = key.properties
        return (
            self.data.loc[self.data['column_name'] == column_name, property_name]
            if property_name in self.data.columns
            else None
        )

    def keys(self) -> List[Key]:
        """Return one Key per calculated column."""
        return [Key(properties=(c,), display_names=(c,)) for c in self.column_names]

    def plot(self, drift_result: Optional[DriftResult] = None, *args, **kwargs) -> go.Figure:
        """Render a stacked bar plot per column, showing distribution changes over time.

        Parameters
        ----------
        drift_result: Optional[nannyml.drift.univariate.Result]
            The result of a univariate drift calculation. When set it will be used to
            lookup alerts that occurred for each column / drift method combination in
            the drift calculation result. For each of these combinations a distribution
            plot of the column will be rendered showing the alerts for each drift
            method. When the `drift_result` parameter is not set no alerts will be
            rendered on the distribution plots.

        Raises
        ------
        InvalidArgumentsException
            When ``drift_result`` is not a univariate drift Result, or when it
            is incompatible with this result (see ``check_is_compatible_with``).
        """
        if drift_result and not isinstance(drift_result, DriftResult):
            raise InvalidArgumentsException(
                'currently the drift_result parameter only supports results of the ' 'UnivariateDriftCalculator.'
            )

        if drift_result:
            self.check_is_compatible_with(drift_result)

        return (
            _plot_categorical_distribution_with_alerts(self, drift_result)
            if drift_result
            else _plot_categorical_distribution(self)
        )

    def check_is_compatible_with(self, drift_result: DriftResult):
        """Validate that ``drift_result`` can be used to annotate these distribution plots.

        Raises
        ------
        InvalidArgumentsException
            When some distribution columns are missing from the drift result, or
            when one result is time-based and the other is not.
        """
        # Check if all distribution columns are present in the drift result
        drift_column_names = set([col for tup in drift_result.keys() for col, _ in tup])
        distribution_column_names = set(self.column_names)
        missing_columns = distribution_column_names.difference(drift_column_names)
        if len(missing_columns) > 0:
            raise InvalidArgumentsException(
                "cannot render distribution plots with warnings. Following columns are not "
                f"in the drift results: {list(missing_columns)}"
            )

        # Check if both results use the same X-axis
        drift_result_is_time_based = is_time_based_x_axis(drift_result.chunk_start_dates, drift_result.chunk_end_dates)
        distr_result_is_time_based = is_time_based_x_axis(self.chunk_start_dates, self.chunk_end_dates)
        if drift_result_is_time_based != distr_result_is_time_based:
            raise InvalidArgumentsException(
                "cannot render distribution plots with warnings. Drift results are"
                f"{'' if drift_result_is_time_based else ' not'} time-based, distribution results are"
                f"{'' if distr_result_is_time_based else ' not'} time-based. Drift and distribution results should "
                f"both be time-based (have a timestamp column) or not."
            )
def _plot_categorical_distribution(
    result: Result,
    title: Optional[str] = 'Column distributions',
    figure: Optional[go.Figure] = None,
    x_axis_time_title: str = 'Time',
    x_axis_chunk_title: str = 'Chunk',
    y_axis_title: str = 'Values',
    figure_args: Optional[Dict[str, Any]] = None,
    subplot_title_format: str = '<b>{display_names[0]}</b> distribution',
    number_of_columns: Optional[int] = None,
) -> go.Figure:
    """Render one stacked-bar distribution subplot per categorical column, without alerts.

    Parameters
    ----------
    result: Result
        The categorical distribution result to plot.
    title: Optional[str]
        Overall figure title.
    figure: Optional[go.Figure]
        An existing figure to draw into; a new one is created when ``None``.
    x_axis_time_title / x_axis_chunk_title: str
        X-axis title used for time-based resp. index-based results.
    y_axis_title: str
        Y-axis title.
    figure_args: Optional[Dict[str, Any]]
        Extra keyword arguments forwarded to the ``Figure`` constructor.
    subplot_title_format: str
        Format string for subplot titles; receives ``display_names``.
    number_of_columns: Optional[int]
        Number of subplot columns; defaults to a single-column layout.
    """
    number_of_plots = len(result.keys())
    if number_of_columns is None:
        # Default to a single-column layout. ``or 1`` guards against an empty
        # result (0 plots), which would otherwise cause a division by zero below.
        number_of_columns = min(number_of_plots, 1) or 1
    number_of_rows = math.ceil(number_of_plots / number_of_columns)

    if figure_args is None:
        figure_args = {}

    if figure is None:
        figure = Figure(
            **dict(
                title=title,
                x_axis_title=x_axis_time_title
                if is_time_based_x_axis(result.chunk_start_dates, result.chunk_end_dates)
                else x_axis_chunk_title,
                y_axis_title=y_axis_title,
                legend=dict(traceorder="grouped", itemclick=False, itemdoubleclick=False),
                height=number_of_plots * 500 / number_of_columns,
                subplot_args=dict(
                    cols=number_of_columns,
                    rows=number_of_rows,
                    subplot_titles=[
                        subplot_title_format.format(display_names=key.display_names) for key in result.keys()
                    ],
                ),
                **figure_args,
            )
        )

    for idx, key in enumerate(result.keys()):
        # Map the flat plot index onto the subplot grid (1-based rows/cols).
        row = (idx // number_of_columns) + 1
        col = (idx % number_of_columns) + 1

        (column_name,) = key.properties

        reference_result = result.filter(period='reference', column_names=[column_name])
        analysis_result = result.filter(period='analysis', column_names=[column_name])

        figure = _plot_stacked_bar(
            figure=figure,
            row=row,
            col=col,
            column_name=column_name,
            reference_value_counts=reference_result.value_counts(key),
            reference_alerts=None,
            reference_chunk_keys=reference_result.chunk_keys,
            reference_chunk_periods=reference_result.chunk_periods,
            reference_chunk_indices=reference_result.chunk_indices,
            reference_chunk_start_dates=reference_result.chunk_start_dates,
            reference_chunk_end_dates=reference_result.chunk_end_dates,
            analysis_value_counts=analysis_result.value_counts(key),
            analysis_alerts=None,
            analysis_chunk_keys=analysis_result.chunk_keys,
            analysis_chunk_periods=analysis_result.chunk_periods,
            analysis_chunk_indices=analysis_result.chunk_indices,
            analysis_chunk_start_dates=analysis_result.chunk_start_dates,
            analysis_chunk_end_dates=analysis_result.chunk_end_dates,
        )

    return figure


def _plot_categorical_distribution_with_alerts(
    result: Result,
    drift_result: DriftResult,
    title: Optional[str] = 'Column distributions',
    figure: Optional[go.Figure] = None,
    x_axis_time_title: str = 'Time',
    x_axis_chunk_title: str = 'Chunk',
    y_axis_title: str = 'Values',
    figure_args: Optional[Dict[str, Any]] = None,
    subplot_title_format: str = '<b>{display_names[0]}</b> distribution (alerts for {display_names[1]})',
    number_of_columns: Optional[int] = None,
) -> go.Figure:
    """Render one stacked-bar distribution subplot per (column, drift method) pair,
    annotating analysis chunks with the alerts from ``drift_result``.

    One subplot is created for every key in ``drift_result`` (not in ``result``),
    so a column appears once per drift method it was evaluated with.
    """
    number_of_plots = len(drift_result.keys())
    if number_of_columns is None:
        # Default to a single-column layout; guard against zero plots
        # (division by zero below).
        number_of_columns = min(number_of_plots, 1) or 1
    number_of_rows = math.ceil(number_of_plots / number_of_columns)

    if figure_args is None:
        figure_args = {}

    if figure is None:
        figure = Figure(
            **dict(
                title=title,
                x_axis_title=x_axis_time_title
                if is_time_based_x_axis(result.chunk_start_dates, result.chunk_end_dates)
                else x_axis_chunk_title,
                y_axis_title=y_axis_title,
                legend=dict(traceorder="grouped", itemclick=False, itemdoubleclick=False),
                height=number_of_plots * 500 / number_of_columns,
                subplot_args=dict(
                    cols=number_of_columns,
                    rows=number_of_rows,
                    subplot_titles=[
                        subplot_title_format.format(display_names=key.display_names) for key in drift_result.keys()
                    ],
                ),
                **figure_args,
            )
        )

    for idx, drift_key in enumerate(drift_result.keys()):
        row = (idx // number_of_columns) + 1
        col = (idx % number_of_columns) + 1

        # A univariate drift key identifies both a column and a drift method.
        (column_name, method_name) = drift_key.properties

        reference_result = result.filter(period='reference', column_names=[column_name])
        reference_result.data.sort_index(inplace=True)
        analysis_result = result.filter(period='analysis', column_names=[column_name])
        analysis_result.data.sort_index(inplace=True)

        # Reference alerts are deliberately not rendered (reference_alerts=None below).
        # reference_alerts = drift_result.filter(period='reference').alerts(drift_key)
        analysis_alerts = drift_result.filter(period='analysis').alerts(drift_key)

        figure = _plot_stacked_bar(
            figure=figure,
            row=row,
            col=col,
            column_name=column_name,
            reference_value_counts=reference_result.value_counts(column_name=column_name),
            reference_alerts=None,
            reference_chunk_keys=reference_result.chunk_keys,
            reference_chunk_periods=reference_result.chunk_periods,
            reference_chunk_indices=reference_result.chunk_indices,
            reference_chunk_start_dates=reference_result.chunk_start_dates,
            reference_chunk_end_dates=reference_result.chunk_end_dates,
            analysis_value_counts=analysis_result.value_counts(column_name=column_name),
            analysis_alerts=analysis_alerts,
            analysis_chunk_keys=analysis_result.chunk_keys,
            analysis_chunk_periods=analysis_result.chunk_periods,
            analysis_chunk_indices=analysis_result.chunk_indices,
            analysis_chunk_start_dates=analysis_result.chunk_start_dates,
            analysis_chunk_end_dates=analysis_result.chunk_end_dates,
        )

    return figure


def _plot_stacked_bar(
    figure: Figure,
    column_name: str,
    reference_value_counts: pd.DataFrame,
    analysis_value_counts: pd.DataFrame,
    reference_alerts: Optional[Union[np.ndarray, pd.Series]] = None,
    reference_chunk_keys: Optional[Union[np.ndarray, pd.Series]] = None,
    reference_chunk_periods: Optional[Union[np.ndarray, pd.Series]] = None,
    reference_chunk_indices: Optional[Union[np.ndarray, pd.Series]] = None,
    reference_chunk_start_dates: Optional[Union[np.ndarray, pd.Series]] = None,
    reference_chunk_end_dates: Optional[Union[np.ndarray, pd.Series]] = None,
    analysis_alerts: Optional[Union[np.ndarray, pd.Series]] = None,
    analysis_chunk_keys: Optional[Union[np.ndarray, pd.Series]] = None,
    analysis_chunk_periods: Optional[Union[np.ndarray, pd.Series]] = None,
    analysis_chunk_indices: Optional[Union[np.ndarray, pd.Series]] = None,
    analysis_chunk_start_dates: Optional[Union[np.ndarray, pd.Series]] = None,
    analysis_chunk_end_dates: Optional[Union[np.ndarray, pd.Series]] = None,
    row: Optional[int] = None,
    col: Optional[int] = None,
) -> Figure:
    """Draw reference and analysis stacked bars for one column into ``figure``.

    When reference results are present, the analysis chunk indices are shifted
    to start right after the last reference chunk so both periods share one
    continuous x-axis. Alert markers are added for analysis chunks when
    ``analysis_alerts`` is provided.

    NOTE(review): ``analysis_value_counts`` is mutated in place (its
    ``chunk_indices`` column is shifted) — callers pass freshly built frames,
    so this is safe in the current call sites.
    """
    is_subplot = row is not None and col is not None
    subplot_args = dict(row=row, col=col) if is_subplot else None

    has_reference_results = reference_chunk_indices is not None and len(reference_chunk_indices) > 0

    if figure is None:
        figure = Figure(title='continuous distribution', x_axis_title='time', y_axis_title='value', height=500)

    figure.update_xaxes(
        dict(mirror=False, showline=False),
        overwrite=True,
        matches='x',
        title=figure.layout.xaxis.title,
        row=row,
        col=col,
    )
    figure.update_yaxes(
        dict(mirror=False, showline=False), overwrite=True, title=figure.layout.yaxis.title, row=row, col=col
    )

    if has_reference_results:
        figure = stacked_bar(
            figure=figure,
            stacked_bar_table=reference_value_counts,
            color=Colors.BLUE_SKY_CRAYOLA,
            chunk_indices=reference_chunk_indices,
            chunk_start_dates=reference_chunk_start_dates,
            chunk_end_dates=reference_chunk_end_dates,
            annotation='Reference',
            showlegend=True,
            legendgrouptitle_text=f'<b>{column_name}</b>',
            legendgroup=column_name,
            subplot_args=subplot_args,
        )

        assert reference_chunk_indices is not None
        # Shift analysis chunks to continue after the reference chunks on the x-axis.
        analysis_chunk_indices = (analysis_chunk_indices + (max(reference_chunk_indices) + 1)).reset_index(drop=True)
        analysis_value_counts['chunk_indices'] += max(reference_chunk_indices) + 1
        if analysis_chunk_start_dates is not None:
            analysis_chunk_start_dates = analysis_chunk_start_dates.reset_index(drop=True)

    figure = stacked_bar(
        figure=figure,
        stacked_bar_table=analysis_value_counts,
        color=Colors.INDIGO_PERSIAN,
        chunk_indices=analysis_chunk_indices,
        chunk_start_dates=analysis_chunk_start_dates,
        chunk_end_dates=analysis_chunk_end_dates,
        annotation='Analysis',
        showlegend=False,
        legendgroup=column_name,
        subplot_args=subplot_args,
    )

    if analysis_alerts is not None:
        figure = stacked_bar_alert(
            figure=figure,
            alerts=analysis_alerts,
            stacked_bar_table=analysis_value_counts,
            color=Colors.RED_IMPERIAL,
            chunk_indices=analysis_chunk_indices,
            chunk_start_dates=analysis_chunk_start_dates,
            chunk_end_dates=analysis_chunk_end_dates,
            showlegend=True,
            legendgroup=column_name,
            subplot_args=subplot_args,
        )

    return figure