Source code for nannyml.plots.components.stacked_bar_plot

#  Author:   Niels Nuyttens  <niels@nannyml.com>
#
#  License: Apache Software License 2.0
from typing import Any, Dict, Optional, Union, cast

import matplotlib
import numpy as np
import pandas as pd
import seaborn as sns
from plotly.graph_objects import Bar, Figure

from nannyml.chunk import Chunker
from nannyml.exceptions import InvalidArgumentsException
from nannyml.plots.colors import Colors
from nannyml.plots.components.hover import Hover, render_x_coordinate
from nannyml.plots.util import ensure_numpy, is_time_based_x_axis


[docs]def calculate_value_counts( data: Union[np.ndarray, pd.Series], chunker: Chunker, missing_category_label, max_number_of_categories, timestamps: Optional[Union[np.ndarray, pd.Series]] = None, column_name: Optional[str] = None, ): if isinstance(data, np.ndarray): if column_name is None: raise InvalidArgumentsException("'column_name' can not be None when 'data' is of type 'np.ndarray'.") data = pd.Series(data, name=column_name) else: column_name = data.name data = data.astype("category") cat_str = [str(value) for value in data.cat.categories.values] data = data.cat.rename_categories(cat_str) data = data.cat.add_categories([missing_category_label, 'Other']) data = data.fillna(missing_category_label) if max_number_of_categories: top_categories = data.value_counts().index.tolist()[:max_number_of_categories] if data.nunique() > max_number_of_categories + 1: data.loc[~data.isin(top_categories)] = 'Other' data = data.cat.remove_unused_categories() categories_ordered = data.value_counts().index.tolist() categorical_data = pd.Categorical(data, categories_ordered) # TODO: deal with None timestamps if isinstance(timestamps, pd.Series): timestamps = timestamps.reset_index() data_with_chunk_keys = pd.concat( [ chunk.data.assign(chunk_key=chunk.key, chunk_index=chunk.chunk_index) for chunk in chunker.split(pd.concat([pd.Series(categorical_data, name=column_name), timestamps], axis=1)) ] ) value_counts_table = ( data_with_chunk_keys.groupby(['chunk_key', 'chunk_index'])[column_name] .value_counts() .to_frame('value_counts') .sort_values(by=['chunk_index', 'value_counts']) .reset_index() .rename(columns={'level_2': column_name, 'chunk_index': 'chunk_indices'}) ) value_counts_table['value_counts_total'] = value_counts_table['chunk_key'].map( value_counts_table.groupby('chunk_key')['value_counts'].sum() ) value_counts_table['value_counts_normalised'] = ( value_counts_table['value_counts'] / value_counts_table['value_counts_total'] ) return value_counts_table
def stacked_bar(
    figure: Figure,
    stacked_bar_table: pd.DataFrame,
    color: str,
    chunk_start_dates: Optional[Union[np.ndarray, pd.Series]] = None,
    chunk_end_dates: Optional[Union[np.ndarray, pd.Series]] = None,
    chunk_indices: Optional[Union[np.ndarray, pd.Series]] = None,
    subplot_args: Optional[Dict[str, Any]] = None,
    annotation: Optional[str] = None,
    **kwargs,
) -> Figure:
    """Draw one stacked bar per chunk, with one bar segment per category.

    Parameters
    ----------
    figure : Figure
        The figure to add the bar traces, outline shape and annotation to.
    stacked_bar_table : pd.DataFrame
        Table holding ``chunk_key``, ``chunk_indices``, ``value_counts``,
        ``value_counts_normalised`` and a categorical column. The first column that is not
        ``chunk_key``, ``chunk_indices`` or ``value_counts`` is used as the category column.
    color : str
        Color used for the bar marker lines, the dashed outline and the annotation text.
    chunk_start_dates, chunk_end_dates, chunk_indices : Optional[Union[np.ndarray, pd.Series]]
        X-axis values. They serve as fallbacks when the table has no ``start_datetime``,
        ``end_datetime`` or ``chunk_indices`` column, and set the extent of the outline.
    subplot_args : Optional[Dict[str, Any]]
        Extra keyword arguments forwarded to ``add_trace``, ``add_shape`` and
        ``add_annotation``.
    annotation : Optional[str]
        Text placed above the bars. Nothing is added when it is empty or ``None``.
    **kwargs
        Extra keyword arguments forwarded to every ``Bar`` trace.

    Returns
    -------
    Figure
        The figure that was passed in, with the new elements added.
    """
    subplot_args = {} if subplot_args is None else subplot_args

    chunk_indices, chunk_start_dates, chunk_end_dates = ensure_numpy(chunk_indices, chunk_start_dates, chunk_end_dates)
    time_based = is_time_based_x_axis(chunk_start_dates, chunk_end_dates)

    bookkeeping_columns = ('chunk_key', 'chunk_indices', 'value_counts')
    candidate_columns = [col for col in stacked_bar_table.columns if col not in bookkeeping_columns]
    column_name = candidate_columns[0]
    categories = stacked_bar_table[column_name].cat.categories

    # One color per category, blended across three anchor colors.
    palette = sns.blend_palette(
        [Colors.INDIGO_PERSIAN, Colors.GRAY, Colors.BLUE_SKY_CRAYOLA], n_colors=len(categories)
    )
    category_colors = list(palette.as_hex())
    category_colors_transparent = [
        f'rgba{matplotlib.colors.to_rgba(matplotlib.colors.to_rgb(hex_color), 1)}' for hex_color in category_colors
    ]

    figure.update_layout(
        dict(barmode='relative', paper_bgcolor='rgba(255,255,255,1)', plot_bgcolor='rgba(255,255,255,1)')
    )

    # plot bars
    for position, category in enumerate(categories):
        segment_color = category_colors_transparent[position]
        rows = stacked_bar_table.loc[stacked_bar_table[column_name] == category]

        # Prefer the coordinates stored in the table, fall back to the given arrays.
        if time_based:
            x = rows.get('start_datetime', chunk_start_dates)
        else:
            x = rows.get('chunk_indices', chunk_indices)

        x_coordinate = render_x_coordinate(
            rows['chunk_indices'],
            rows.get('start_datetime', chunk_start_dates),
            rows.get('end_datetime', chunk_end_dates),
        )

        hover = Hover(template="Chunk %{chunk_key}: %{x_coordinate}; (%{value_counts_normalised}, %{value_counts})")
        hover.add(rows['chunk_key'], name='chunk_key')
        hover.add(x_coordinate, name='x_coordinate')
        hover.add(rows['value_counts_normalised'], name='value_counts_normalised')
        hover.add(rows['value_counts'], name='value_counts')

        bar = Bar(
            name=category,
            x=x,
            y=rows['value_counts_normalised'],
            orientation='v',
            marker=dict(line_color=color, color=segment_color, line_width=0),
            yperiodalignment="start",
            offset=0,
            customdata=hover.get_custom_data(),
            hovertemplate=hover.get_template(),
            hoverlabel=dict(bgcolor=segment_color, font=dict(color='white')),
            **kwargs,
        )
        figure.add_trace(bar, **subplot_args)

    # Shade chunk type
    if time_based:
        x0 = chunk_start_dates.min()
        x1 = chunk_end_dates.max()
    else:
        x0 = chunk_indices.min()
        x1 = chunk_indices.max() + 1

    outline_color = f'rgba{matplotlib.colors.to_rgba(matplotlib.colors.to_rgb(color), 0.5)}'
    figure.add_shape(
        y0=0,
        y1=1.05,
        x0=x0,
        x1=x1,
        line_color=outline_color,
        layer='above',
        line_width=2,
        line=dict(dash='dash'),
        **subplot_args,
    )

    if annotation:
        annotation_x = pd.Series(chunk_start_dates).mean() if time_based else chunk_indices.mean()
        figure.add_annotation(
            x=annotation_x,
            y=1.025,
            text=annotation,
            font=dict(color=color),
            align="center",
            textangle=0,
            showarrow=False,
            **subplot_args,
        )

    return figure
def _axis_number(axis_id) -> int:
    """Return the numeric position encoded in an axis id ('x' -> 1, 'x2' -> 2, unset -> 1)."""
    if axis_id is None:
        return 1
    suffix = str(axis_id)[1:]
    return int(suffix) if suffix.isdigit() else 1


def alert(
    figure: Figure,
    stacked_bar_table: pd.DataFrame,
    alerts: Union[np.ndarray, pd.Series],
    color: str,
    chunk_start_dates: Optional[Union[np.ndarray, pd.Series]] = None,
    chunk_end_dates: Optional[Union[np.ndarray, pd.Series]] = None,
    chunk_indices: Optional[Union[np.ndarray, pd.Series]] = None,
    subplot_args: Optional[Dict[str, Any]] = None,
    annotation: Optional[str] = None,
    **kwargs,
) -> Figure:
    """Outline the bars of chunks that raised an alert.

    Every bar trace on the highest-numbered x-axis whose first x value equals the first
    chunk start date (or first chunk index) gets per-bar marker lines: ``color`` with
    width 2 where the matching entry of ``alerts`` is truthy, the existing line color
    with width 0 otherwise.

    Parameters
    ----------
    figure : Figure
        Figure holding the bar traces to update. Must contain at least one trace.
    stacked_bar_table : pd.DataFrame
        Not used by this function.
    alerts : Union[np.ndarray, pd.Series]
        One truthy/falsy flag per bar.
    color : str
        Marker line color for bars with an alert.
    chunk_start_dates, chunk_end_dates, chunk_indices : Optional[Union[np.ndarray, pd.Series]]
        X-axis values; the first one is used to find the traces to update.
    subplot_args : Optional[Dict[str, Any]]
        Not used by this function.
    annotation : Optional[str]
        Not used by this function.
    **kwargs
        Not used by this function.

    Returns
    -------
    Figure
        The figure that was passed in, with updated marker lines.

    Raises
    ------
    InvalidArgumentsException
        When none of ``chunk_start_dates``, ``chunk_end_dates`` and ``chunk_indices`` is given.
    """
    if chunk_start_dates is None and chunk_end_dates is None and chunk_indices is None:
        raise InvalidArgumentsException(
            "you must provide either 'chunk_indices' or 'chunk_start_dates' and " "'chunk_end_dates."
        )

    prv_color = figure.data[0].marker.line.color or figure.data[0].line.color
    marker_line_colors = [color if val else prv_color for val in alerts]
    marker_line_widths = [2 if val else 0 for val in alerts]

    if is_time_based_x_axis(chunk_start_dates, chunk_end_dates):
        x_values = cast(Union[np.ndarray, pd.Series], chunk_start_dates)
    else:
        x_values = cast(Union[np.ndarray, pd.Series], chunk_indices)
    # Take the first element by position: `series[0]` is a label lookup and fails or picks
    # the wrong element when the Series index does not start at 0.
    anchor = x_values.iloc[0] if isinstance(x_values, pd.Series) else x_values[0]

    # Compare axis ids by their numeric suffix: as plain strings 'x9' sorts after 'x10',
    # and unset (None) ids cannot be compared with each other at all.
    latest_axis = max(_axis_number(bar.xaxis) for bar in figure.data)

    for bars in figure.data:
        if bars.x[0] == anchor and _axis_number(bars.xaxis) == latest_axis:
            bars.marker.line.color = marker_line_colors
            bars.marker.line.width = marker_line_widths

    return figure