Source code for nannyml.performance_calculation.metrics.binary_classification

#  Author:   Niels Nuyttens  <niels@nannyml.com>
#
#  License: Apache Software License 2.0
from typing import Optional, Tuple

import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score, roc_auc_score

from nannyml._typing import ProblemType
from nannyml.base import _list_missing
from nannyml.exceptions import InvalidArgumentsException
from nannyml.performance_calculation.metrics.base import Metric, MetricFactory, _common_data_cleaning
from nannyml.sampling_error.binary_classification import (
    accuracy_sampling_error,
    accuracy_sampling_error_components,
    auroc_sampling_error,
    auroc_sampling_error_components,
    f1_sampling_error,
    f1_sampling_error_components,
    precision_sampling_error,
    precision_sampling_error_components,
    recall_sampling_error,
    recall_sampling_error_components,
    specificity_sampling_error,
    specificity_sampling_error_components,
)


@MetricFactory.register(metric='roc_auc', use_case=ProblemType.CLASSIFICATION_BINARY)
class BinaryClassificationAUROC(Metric):
    """Area under Receiver Operating Characteristic curve metric."""

    def __init__(self, y_true: str, y_pred: str, y_pred_proba: Optional[str] = None):
        """Creates a new AUROC instance."""
        super().__init__(
            display_name='ROC AUC',
            column_name='roc_auc',
            y_true=y_true,
            y_pred=y_pred,
            y_pred_proba=y_pred_proba,
            lower_threshold_limit=0,
            upper_threshold_limit=1,
        )

        # sampling error
        self._sampling_error_components: Tuple = ()

    def __str__(self):
        return "roc_auc"

    def _fit(self, reference_data: pd.DataFrame):
        _list_missing([self.y_true, self.y_pred_proba], list(reference_data.columns))
        self._sampling_error_components = auroc_sampling_error_components(
            y_true_reference=reference_data[self.y_true],
            y_pred_proba_reference=reference_data[self.y_pred_proba],
        )

    def _calculate(self, data: pd.DataFrame):
        """Redefine to handle NaNs and edge cases."""
        _list_missing([self.y_true, self.y_pred_proba], list(data.columns))

        y_true = data[self.y_true]
        y_pred = data[self.y_pred_proba]

        y_true, y_pred = _common_data_cleaning(y_true, y_pred)

        if y_true.nunique() <= 1:
            return np.nan
        else:
            return roc_auc_score(y_true, y_pred)

    def _sampling_error(self, data: pd.DataFrame) -> float:
        return auroc_sampling_error(self._sampling_error_components, data)
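

# Illustrative usage sketch, not part of the library source: it drives the
# private _fit/_calculate hooks defined above directly to show the data flow.
# In normal use these hooks are invoked by the surrounding performance
# calculator, and the column names below ('actual', 'prediction', 'score')
# are made up for the example.
def _example_auroc_usage() -> float:
    frame = pd.DataFrame(
        {
            'actual': [0, 1, 1, 0, 1, 0],
            'prediction': [0, 1, 0, 0, 1, 1],
            'score': [0.2, 0.9, 0.4, 0.1, 0.8, 0.6],
        }
    )
    metric = BinaryClassificationAUROC(y_true='actual', y_pred='prediction', y_pred_proba='score')
    metric._fit(frame)  # estimates sampling error components from reference data
    return metric._calculate(frame)  # AUROC, or np.nan when 'actual' is constant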


@MetricFactory.register(metric='f1', use_case=ProblemType.CLASSIFICATION_BINARY)
class BinaryClassificationF1(Metric):
    """F1 score metric."""

    def __init__(self, y_true: str, y_pred: str, y_pred_proba: Optional[str] = None):
        """Creates a new F1 instance."""
        super().__init__(
            display_name='F1',
            column_name='f1',
            y_true=y_true,
            y_pred=y_pred,
            y_pred_proba=y_pred_proba,
            lower_threshold_limit=0,
            upper_threshold_limit=1,
        )

        # sampling error
        self._sampling_error_components: Tuple = ()

    def __str__(self):
        return "f1"

    def _fit(self, reference_data: pd.DataFrame):
        _list_missing([self.y_true, self.y_pred], list(reference_data.columns))
        self._sampling_error_components = f1_sampling_error_components(
            y_true_reference=reference_data[self.y_true],
            y_pred_reference=reference_data[self.y_pred],
        )

    def _calculate(self, data: pd.DataFrame):
        """Redefine to handle NaNs and edge cases."""
        _list_missing([self.y_true, self.y_pred], list(data.columns))

        y_true = data[self.y_true]
        y_pred = data[self.y_pred]

        y_true, y_pred = _common_data_cleaning(y_true, y_pred)

        if (y_true.nunique() <= 1) or (y_pred.nunique() <= 1):
            return np.nan
        else:
            return f1_score(y_true, y_pred)

    def _sampling_error(self, data: pd.DataFrame) -> float:
        return f1_sampling_error(self._sampling_error_components, data)
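

# Sketch, not part of the source: the degenerate-input guard above returns NaN
# rather than a misleading score when a chunk's labels or predictions are
# constant after cleaning. Column names here are hypothetical.
def _example_f1_degenerate() -> float:
    frame = pd.DataFrame({'actual': [1, 1, 1, 1], 'prediction': [1, 1, 1, 1]})
    metric = BinaryClassificationF1(y_true='actual', y_pred='prediction')
    return metric._calculate(frame)  # np.nan: both columns hold a single unique value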


@MetricFactory.register(metric='precision', use_case=ProblemType.CLASSIFICATION_BINARY)
class BinaryClassificationPrecision(Metric):
    """Precision metric."""

    def __init__(self, y_true: str, y_pred: str, y_pred_proba: Optional[str] = None):
        """Creates a new Precision instance."""
        super().__init__(
            display_name='Precision',
            column_name='precision',
            y_true=y_true,
            y_pred=y_pred,
            y_pred_proba=y_pred_proba,
            lower_threshold_limit=0,
            upper_threshold_limit=1,
        )

        # sampling error
        self._sampling_error_components: Tuple = ()

    def __str__(self):
        return "precision"

    def _fit(self, reference_data: pd.DataFrame):
        _list_missing([self.y_true, self.y_pred], list(reference_data.columns))
        self._sampling_error_components = precision_sampling_error_components(
            y_true_reference=reference_data[self.y_true],
            y_pred_reference=reference_data[self.y_pred],
        )

    def _calculate(self, data: pd.DataFrame):
        _list_missing([self.y_true, self.y_pred], list(data.columns))

        y_true = data[self.y_true]
        y_pred = data[self.y_pred]

        y_true, y_pred = _common_data_cleaning(y_true, y_pred)

        if (y_true.nunique() <= 1) or (y_pred.nunique() <= 1):
            return np.nan
        else:
            return precision_score(y_true, y_pred)

    def _sampling_error(self, data: pd.DataFrame):
        return precision_sampling_error(self._sampling_error_components, data)


@MetricFactory.register(metric='recall', use_case=ProblemType.CLASSIFICATION_BINARY)
class BinaryClassificationRecall(Metric):
    """Recall metric, also known as 'sensitivity'."""

    def __init__(self, y_true: str, y_pred: str, y_pred_proba: Optional[str] = None):
        """Creates a new Recall instance."""
        super().__init__(
            display_name='Recall',
            column_name='recall',
            y_true=y_true,
            y_pred=y_pred,
            y_pred_proba=y_pred_proba,
            lower_threshold_limit=0,
            upper_threshold_limit=1,
        )

        # sampling error
        self._sampling_error_components: Tuple = ()

    def __str__(self):
        return "recall"

    def _fit(self, reference_data: pd.DataFrame):
        _list_missing([self.y_true, self.y_pred], list(reference_data.columns))
        self._sampling_error_components = recall_sampling_error_components(
            y_true_reference=reference_data[self.y_true],
            y_pred_reference=reference_data[self.y_pred],
        )

    def _calculate(self, data: pd.DataFrame):
        _list_missing([self.y_true, self.y_pred], list(data.columns))

        y_true = data[self.y_true]
        y_pred = data[self.y_pred]

        y_true, y_pred = _common_data_cleaning(y_true, y_pred)

        if (y_true.nunique() <= 1) or (y_pred.nunique() <= 1):
            return np.nan
        else:
            return recall_score(y_true, y_pred)

    def _sampling_error(self, data: pd.DataFrame):
        return recall_sampling_error(self._sampling_error_components, data)
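

# Sketch, not part of the source: on the same predictions, the precision and
# recall metrics above move apart when false positives and false negatives are
# unbalanced. With two missed positives and no false alarms, precision is
# perfect while recall drops. Column names are hypothetical.
def _example_precision_vs_recall() -> tuple:
    frame = pd.DataFrame({'actual': [1, 1, 1, 1, 0], 'prediction': [1, 1, 0, 0, 0]})
    precision = BinaryClassificationPrecision(y_true='actual', y_pred='prediction')._calculate(frame)
    recall = BinaryClassificationRecall(y_true='actual', y_pred='prediction')._calculate(frame)
    return precision, recall  # (1.0, 0.5): tp=2, fp=0, fn=2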


@MetricFactory.register(metric='specificity', use_case=ProblemType.CLASSIFICATION_BINARY)
class BinaryClassificationSpecificity(Metric):
    """Specificity metric."""

    def __init__(self, y_true: str, y_pred: str, y_pred_proba: Optional[str] = None):
        """Creates a new Specificity instance."""
        super().__init__(
            display_name='Specificity',
            column_name='specificity',
            y_true=y_true,
            y_pred=y_pred,
            y_pred_proba=y_pred_proba,
            lower_threshold_limit=0,
            upper_threshold_limit=1,
        )

        # sampling error
        self._sampling_error_components: Tuple = ()

    def __str__(self):
        return "specificity"

    def _fit(self, reference_data: pd.DataFrame):
        _list_missing([self.y_true, self.y_pred], list(reference_data.columns))
        self._sampling_error_components = specificity_sampling_error_components(
            y_true_reference=reference_data[self.y_true],
            y_pred_reference=reference_data[self.y_pred],
        )

    def _calculate(self, data: pd.DataFrame):
        _list_missing([self.y_true, self.y_pred], list(data.columns))

        y_true = data[self.y_true]
        y_pred = data[self.y_pred]

        if y_pred.isna().all():
            raise InvalidArgumentsException(
                f"could not calculate metric '{self.display_name}': "
                "prediction column contains no data"
            )

        y_true, y_pred = _common_data_cleaning(y_true, y_pred)

        if (y_true.nunique() <= 1) or (y_pred.nunique() <= 1):
            return np.nan
        else:
            tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
            return tn / (tn + fp)

    def _sampling_error(self, data: pd.DataFrame):
        return specificity_sampling_error(self._sampling_error_components, data)
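

# Sketch, not part of the source: for binary labels sklearn's confusion_matrix
# ravels as (tn, fp, fn, tp), so tn / (tn + fp) above is the true negative rate.
# A hand-countable check:
def _example_specificity() -> float:
    y_true = pd.Series([0, 0, 0, 1, 1])  # three negatives, two positives
    y_pred = pd.Series([0, 1, 0, 1, 0])  # one negative flagged as positive
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    return tn / (tn + fp)  # 2 / (2 + 1) = 0.667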


@MetricFactory.register(metric='accuracy', use_case=ProblemType.CLASSIFICATION_BINARY)
class BinaryClassificationAccuracy(Metric):
    """Accuracy metric."""

    def __init__(self, y_true: str, y_pred: str, y_pred_proba: Optional[str] = None):
        """Creates a new Accuracy instance."""
        super().__init__(
            display_name='Accuracy',
            column_name='accuracy',
            y_true=y_true,
            y_pred=y_pred,
            y_pred_proba=y_pred_proba,
            lower_threshold_limit=0,
            upper_threshold_limit=1,
        )

        # sampling error
        self._sampling_error_components: Tuple = ()

    def __str__(self):
        return "accuracy"

    def _fit(self, reference_data: pd.DataFrame):
        _list_missing([self.y_true, self.y_pred], list(reference_data.columns))
        self._sampling_error_components = accuracy_sampling_error_components(
            y_true_reference=reference_data[self.y_true],
            y_pred_reference=reference_data[self.y_pred],
        )

    def _calculate(self, data: pd.DataFrame):
        _list_missing([self.y_true, self.y_pred], list(data.columns))

        y_true = data[self.y_true]
        y_pred = data[self.y_pred]

        if y_pred.isna().all():
            raise InvalidArgumentsException(
                f"could not calculate metric '{self.display_name}': "
                "prediction column contains no data"
            )

        y_true, y_pred = _common_data_cleaning(y_true, y_pred)

        if (y_true.nunique() <= 1) or (y_pred.nunique() <= 1):
            return np.nan
        else:
            tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
            return (tp + tn) / (tp + tn + fp + fn)

    def _sampling_error(self, data: pd.DataFrame):
        return accuracy_sampling_error(self._sampling_error_components, data)
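

# Sketch, not part of the source: the hand-rolled (tp + tn) / (tp + tn + fp + fn)
# above matches sklearn.metrics.accuracy_score once rows with missing values
# have been dropped by _common_data_cleaning.
def _example_accuracy_equivalence() -> bool:
    from sklearn.metrics import accuracy_score

    y_true = pd.Series([0, 1, 1, 0, 1])
    y_pred = pd.Series([0, 1, 0, 0, 1])
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    manual = (tp + tn) / (tp + tn + fp + fn)  # 4 / 5 = 0.8
    return bool(np.isclose(manual, accuracy_score(y_true, y_pred)))  # True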