Estimating Confusion Matrix Elements for Multiclass Classification

This tutorial explains how to use NannyML to estimate the confusion matrix for multiclass classification models in the absence of target data. To find out how CBPE estimates performance, read the explanation of Confidence-based Performance Estimation.

Note

The following example uses timestamps. These are optional but have an impact on the way data is chunked and results are plotted. You can read more about them in the data requirements.

Just The Code

>>> import nannyml as nml
>>> from IPython.display import display

>>> reference_df, analysis_df, _ = nml.load_synthetic_multiclass_classification_dataset()

>>> display(reference_df.head(3))

>>> estimator = nml.CBPE(
...     y_pred_proba={
...         'prepaid_card': 'y_pred_proba_prepaid_card',
...         'highstreet_card': 'y_pred_proba_highstreet_card',
...         'upmarket_card': 'y_pred_proba_upmarket_card'},
...     y_pred='y_pred',
...     y_true='y_true',
...     timestamp_column_name='timestamp',
...     problem_type='classification_multiclass',
...     metrics=['confusion_matrix'],
...     normalize_confusion_matrix='all',
...     chunk_size=6000,
... )

>>> estimator.fit(reference_df)

>>> results = estimator.estimate(analysis_df)
>>> display(results.filter(period='analysis').to_df())

>>> metric_fig = results.plot()
>>> metric_fig.show()

Walkthrough

For simplicity, this guide is based on a synthetic dataset in which the monitored model predicts which type of credit card product new customers should be assigned to. Check out Credit Card Dataset to learn more about this dataset.

To monitor a model, NannyML first needs to learn about it from a reference dataset. It can then monitor the data that is subject to actual analysis, provided as the analysis dataset. You can read more about this in our section on data periods.

We start by loading the dataset we’ll be using:

>>> import nannyml as nml
>>> from IPython.display import display

>>> reference_df, analysis_df, _ = nml.load_synthetic_multiclass_classification_dataset()

>>> display(reference_df.head(3))

   id  acq_channel  app_behavioral_score  requested_credit_limit  app_channel  credit_bureau_score  stated_income  is_customer  timestamp
0  0   Partner3     1.80823               350                     web          309                  15000          True         2020-05-02 02:01:30
1  1   Partner2     4.38257               500                     mobile       418                  23000          True         2020-05-02 02:03:33
2  2   Partner2     -0.787575             400                     web          507                  24000          False        2020-05-02 02:04:49

   y_pred_proba_prepaid_card  y_pred_proba_highstreet_card  y_pred_proba_upmarket_card  y_pred        y_true
0  0.97                       0.03                          0                           prepaid_card  prepaid_card
1  0.87                       0.13                          0                           prepaid_card  prepaid_card
2  0.47                       0.35                          0.18                        prepaid_card  upmarket_card
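In the three rows above, y_pred is the class with the highest predicted probability, and each row's probabilities sum to 1. The snippet below is a small sanity check of that relationship. It is written against the class-to-column mapping that is passed to CBPE as y_pred_proba further down, and it uses those three rows typed in by hand. The helper predicted_class is hypothetical and not part of NannyML.

```python
# Class-to-column mapping, as passed to CBPE's y_pred_proba parameter.
y_pred_proba = {
    'prepaid_card': 'y_pred_proba_prepaid_card',
    'highstreet_card': 'y_pred_proba_highstreet_card',
    'upmarket_card': 'y_pred_proba_upmarket_card',
}

# The three reference rows displayed above (probability columns and y_pred only).
rows = [
    {'y_pred_proba_prepaid_card': 0.97, 'y_pred_proba_highstreet_card': 0.03,
     'y_pred_proba_upmarket_card': 0.0, 'y_pred': 'prepaid_card'},
    {'y_pred_proba_prepaid_card': 0.87, 'y_pred_proba_highstreet_card': 0.13,
     'y_pred_proba_upmarket_card': 0.0, 'y_pred': 'prepaid_card'},
    {'y_pred_proba_prepaid_card': 0.47, 'y_pred_proba_highstreet_card': 0.35,
     'y_pred_proba_upmarket_card': 0.18, 'y_pred': 'prepaid_card'},
]


def predicted_class(row, mapping):
    """Hypothetical helper: return the class whose probability column is largest."""
    return max(mapping, key=lambda cls: row[mapping[cls]])


for row in rows:
    probs = [row[col] for col in y_pred_proba.values()]
    # The per-class probabilities of one observation should sum to 1.
    assert abs(sum(probs) - 1.0) < 1e-9
    # y_pred should agree with the most probable class.
    assert predicted_class(row, y_pred_proba) == row['y_pred']
```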

Next we create the Confidence-based Performance Estimation (CBPE) estimator. To initialize an estimator that estimates the confusion_matrix, we specify the following parameters:

  • y_pred_proba: a dictionary that maps the class names to the name of the column in the reference data that contains the predicted probabilities for that class.

  • y_pred: the name of the column in the reference data that contains the predicted classes.

  • y_true: the name of the column in the reference data that contains the true classes.

  • timestamp_column_name (Optional): the name of the column in the reference data that contains timestamps.

  • metrics: a list of metrics to estimate. In this example we will estimate the confusion_matrix metric.

  • chunk_size (Optional): the number of observations in each chunk of data used to estimate performance. For more information about chunking configurations check out the chunking tutorial.

  • problem_type: the type of problem being monitored. In this example we will monitor a multiclass classification problem.

  • normalize_confusion_matrix (Optional): how to normalize the confusion matrix. The normalization options are:

    • None: no normalization; each cell contains the (estimated) count.

    • “true”: normalize over the true class of the observations.

    • “pred”: normalize over the predicted class of the observations.

    • “all”: normalize over all observations.

  • thresholds (Optional): the thresholds used to calculate the alert flag. For more information about thresholds, check out the thresholds tutorial.
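With chunk_size=6000, the data is split into consecutive chunks of 6000 rows. The results further down show ten such chunks for the analysis data, keyed [0:5999] to [54000:59999]. The function below is a minimal sketch of that size-based splitting, using inclusive end indices as in the chunk keys. The name size_based_chunk_keys is hypothetical. The 60,000-row count is inferred from the last chunk key. How NannyML treats a final incomplete chunk may differ from this sketch, which simply keeps it as a shorter chunk (see the chunking tutorial).

```python
def size_based_chunk_keys(n_rows: int, chunk_size: int) -> list[str]:
    """Hypothetical helper: build '[start:end]' keys with inclusive end indices."""
    keys = []
    for start in range(0, n_rows, chunk_size):
        # If n_rows is not a multiple of chunk_size, the last chunk is shorter;
        # keeping it is an assumption of this sketch.
        end = min(start + chunk_size, n_rows) - 1
        keys.append(f'[{start}:{end}]')
    return keys


keys = size_based_chunk_keys(60_000, 6000)
print(len(keys))            # 10
print(keys[0], keys[-1])    # [0:5999] [54000:59999]
```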

Note

Since we are estimating the confusion matrix, the count values in each cell of the confusion matrix are estimates. We normalize the estimates just as if they were true counts. This means that when we normalize over the true class, the estimates in each row will sum to 1. When we normalize over the predicted class, the estimates in each column will sum to 1. When we normalize over all observations, the estimates in the entire matrix will sum to 1.
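The note can be made concrete with a few lines of plain Python. The sketch below applies the four normalization options to a 3×3 matrix of made-up estimated counts. Rows are true classes and columns are predicted classes, in the order prepaid_card, highstreet_card, upmarket_card. It illustrates the arithmetic only and is not NannyML's code.

```python
# Made-up estimated counts (rows: true class, columns: predicted class).
counts = [
    [1570.0, 249.2, 227.7],
    [272.0, 1483.7, 249.8],
    [238.0, 243.1, 1466.5],
]


def normalize(matrix, how=None):
    """Apply one of the normalization options to a square confusion matrix."""
    if how is None:  # no normalization: return a copy of the counts
        return [row[:] for row in matrix]
    if how == 'true':  # divide each row by its sum
        return [[v / sum(row) for v in row] for row in matrix]
    if how == 'pred':  # divide each column by its sum
        col_sums = [sum(col) for col in zip(*matrix)]
        return [[v / s for v, s in zip(row, col_sums)] for row in matrix]
    if how == 'all':  # divide every cell by the grand total
        total = sum(sum(row) for row in matrix)
        return [[v / total for v in row] for row in matrix]
    raise ValueError(f'unknown normalization: {how!r}')


print([round(sum(row), 6) for row in normalize(counts, 'true')])         # [1.0, 1.0, 1.0]
print([round(sum(col), 6) for col in zip(*normalize(counts, 'pred'))])   # [1.0, 1.0, 1.0]
print(round(sum(sum(row) for row in normalize(counts, 'all')), 6))       # 1.0
```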

>>> estimator = nml.CBPE(
...     y_pred_proba={
...         'prepaid_card': 'y_pred_proba_prepaid_card',
...         'highstreet_card': 'y_pred_proba_highstreet_card',
...         'upmarket_card': 'y_pred_proba_upmarket_card'},
...     y_pred='y_pred',
...     y_true='y_true',
...     timestamp_column_name='timestamp',
...     problem_type='classification_multiclass',
...     metrics=['confusion_matrix'],
...     normalize_confusion_matrix='all',
...     chunk_size=6000,
... )

The CBPE estimator is then fitted using the fit() method on the reference data.

>>> estimator.fit(reference_df)

The fitted estimator can be used to estimate performance on other data, for which performance cannot be calculated because targets are not available. Typically, this is the latest production data, where targets are missing. In our example this is the analysis_df data.
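Estimating confusion matrix cells without targets relies on the predicted probabilities. Here is a simplified illustration of the expected-value idea, assuming well-calibrated probabilities. An observation predicted as class j adds its probability of truly belonging to class i to cell (i, j). The sketch below applies that rule to made-up observations. It is not NannyML's implementation; the explanation of Confidence-based Performance Estimation describes how CBPE actually works.

```python
CLASSES = ['prepaid_card', 'highstreet_card', 'upmarket_card']

# Made-up observations: (calibrated class probabilities, predicted class).
observations = [
    ({'prepaid_card': 0.97, 'highstreet_card': 0.03, 'upmarket_card': 0.00}, 'prepaid_card'),
    ({'prepaid_card': 0.47, 'highstreet_card': 0.35, 'upmarket_card': 0.18}, 'prepaid_card'),
    ({'prepaid_card': 0.10, 'highstreet_card': 0.70, 'upmarket_card': 0.20}, 'highstreet_card'),
    ({'prepaid_card': 0.05, 'highstreet_card': 0.15, 'upmarket_card': 0.80}, 'upmarket_card'),
]


def expected_confusion_matrix(observations, classes):
    """Cell [i][j] sums P(true class = i) over the observations predicted as j."""
    index = {cls: k for k, cls in enumerate(classes)}
    matrix = [[0.0] * len(classes) for _ in classes]
    for probas, y_pred in observations:
        j = index[y_pred]
        for cls, p in probas.items():
            matrix[index[cls]][j] += p
    return matrix


estimated = expected_confusion_matrix(observations, CLASSES)
for cls, row in zip(CLASSES, estimated):
    print(cls, [round(v, 2) for v in row])
```

Each column sums to the number of observations predicted as that class, and the whole matrix sums to the number of observations. Normalizing over all observations therefore divides by that total.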

NannyML can then output a dataframe that contains all the results. Let’s have a look at the results for the analysis period only.

>>> results = estimator.estimate(analysis_df)
>>> display(results.filter(period='analysis').to_df())

The dataframe has a two-level column index. It is shown here split by top-level column group, with one row per chunk:

chunk
   key            chunk_index  start_index  end_index  start_date           end_date             period
0  [0:5999]       0            0            5999       2020-09-01 03:10:01  2020-09-13 16:15:10  analysis
1  [6000:11999]   1            6000         11999      2020-09-13 16:15:32  2020-09-25 19:48:42  analysis
2  [12000:17999]  2            12000        17999      2020-09-25 19:50:04  2020-10-08 02:53:47  analysis
3  [18000:23999]  3            18000        23999      2020-10-08 02:57:34  2020-10-20 15:48:19  analysis
4  [24000:29999]  4            24000        29999      2020-10-20 15:49:06  2020-11-01 22:04:40  analysis
5  [30000:35999]  5            30000        35999      2020-11-01 22:04:59  2020-11-14 03:55:33  analysis
6  [36000:41999]  6            36000        41999      2020-11-14 03:55:49  2020-11-26 09:19:06  analysis
7  [42000:47999]  7            42000        47999      2020-11-26 09:19:22  2020-12-08 14:33:56  analysis
8  [48000:53999]  8            48000        53999      2020-12-08 14:34:25  2020-12-20 18:30:30  analysis
9  [54000:59999]  9            54000        59999      2020-12-20 18:31:09  2021-01-01 22:57:55  analysis

true_prepaid_card_pred_prepaid_card
   value      sampling_error  realized  upper_confidence_boundary  lower_confidence_boundary  upper_threshold  lower_threshold  alert
0  0.261674   0.00555482      nan       0.278339                   0.24501                    0.265912         0.243455         False
1  0.252942   0.00555482      nan       0.269606                   0.236277                   0.265912         0.243455         False
2  0.256846   0.00555482      nan       0.27351                    0.240181                   0.265912         0.243455         False
3  0.257472   0.00555482      nan       0.274136                   0.240807                   0.265912         0.243455         False
4  0.246073   0.00555482      nan       0.262738                   0.229409                   0.265912         0.243455         False
5  0.166777   0.00555482      nan       0.183441                   0.150112                   0.265912         0.243455         True
6  0.169317   0.00555482      nan       0.185981                   0.152652                   0.265912         0.243455         True
7  0.165042   0.00555482      nan       0.181707                   0.148378                   0.265912         0.243455         True
8  0.168032   0.00555482      nan       0.184696                   0.151367                   0.265912         0.243455         True
9  0.162861   0.00555482      nan       0.179525                   0.146196                   0.265912         0.243455         True

true_prepaid_card_pred_highstreet_card
   value      sampling_error  realized  upper_confidence_boundary  lower_confidence_boundary  upper_threshold  lower_threshold  alert
0  0.0415311  0.00265159      nan       0.0494858                  0.0335763                  0.0526699        0.0369635        False
1  0.039106   0.00265159      nan       0.0470608                  0.0311512                  0.0526699        0.0369635        False
2  0.0402242  0.00265159      nan       0.0481789                  0.0322694                  0.0526699        0.0369635        False
3  0.0406657  0.00265159      nan       0.0486205                  0.032711                   0.0526699        0.0369635        False
4  0.0418344  0.00265159      nan       0.0497891                  0.0338796                  0.0526699        0.0369635        False
5  0.0710278  0.00265159      nan       0.0789826                  0.0630731                  0.0526699        0.0369635        True
6  0.0740581  0.00265159      nan       0.0820128                  0.0661033                  0.0526699        0.0369635        True
7  0.0729706  0.00265159      nan       0.0809254                  0.0650159                  0.0526699        0.0369635        True
8  0.0727937  0.00265159      nan       0.0807484                  0.0648389                  0.0526699        0.0369635        True
9  0.0720836  0.00265159      nan       0.0800383                  0.0641288                  0.0526699        0.0369635        True

true_prepaid_card_pred_upmarket_card
   value      sampling_error  realized  upper_confidence_boundary  lower_confidence_boundary  upper_threshold  lower_threshold  alert
0  0.0379462  0.00262565      nan       0.0458231                  0.0300692                  0.0421872        0.0289128        False
1  0.0380279  0.00262565      nan       0.0459048                  0.0301509                  0.0421872        0.0289128        False
2  0.0373997  0.00262565      nan       0.0452766                  0.0295227                  0.0421872        0.0289128        False
3  0.0394442  0.00262565      nan       0.0473212                  0.0315673                  0.0421872        0.0289128        False
4  0.0380074  0.00262565      nan       0.0458843                  0.0301304                  0.0421872        0.0289128        False
5  0.0623     0.00262565      nan       0.070177                   0.0544231                  0.0421872        0.0289128        True
6  0.0585906  0.00262565      nan       0.0664675                  0.0507136                  0.0421872        0.0289128        True
7  0.0598608  0.00262565      nan       0.0677378                  0.0519839                  0.0421872        0.0289128        True
8  0.0607362  0.00262565      nan       0.0686131                  0.0528592                  0.0421872        0.0289128        True
9  0.0601227  0.00262565      nan       0.0679997                  0.0522458                  0.0421872        0.0289128        True

true_highstreet_card_pred_prepaid_card
   value      sampling_error  realized  upper_confidence_boundary  lower_confidence_boundary  upper_threshold  lower_threshold  alert
0  0.0453291  0.00267108      nan       0.0533423                  0.0373158                  0.0486211        0.0396456        False
1  0.0438881  0.00267108      nan       0.0519014                  0.0358749                  0.0486211        0.0396456        False
2  0.0425611  0.00267108      nan       0.0505744                  0.0345479                  0.0486211        0.0396456        False
3  0.0415748  0.00267108      nan       0.0495881                  0.0335616                  0.0486211        0.0396456        False
4  0.0427521  0.00267108      nan       0.0507653                  0.0347388                  0.0486211        0.0396456        False
5  0.0621998  0.00267108      nan       0.070213                   0.0541865                  0.0486211        0.0396456        True
6  0.0627503  0.00267108      nan       0.0707635                  0.054737                   0.0486211        0.0396456        True
7  0.06332    0.00267108      nan       0.0713333                  0.0553068                  0.0486211        0.0396456        True
8  0.0638833  0.00267108      nan       0.0718965                  0.05587                    0.0486211        0.0396456        True
9  0.0596525  0.00267108      nan       0.0676658                  0.0516393                  0.0486211        0.0396456        True

true_highstreet_card_pred_highstreet_card
   value      sampling_error  realized  upper_confidence_boundary  lower_confidence_boundary  upper_threshold  lower_threshold  alert
0  0.247291   0.00562464      nan       0.264165                   0.230417                   0.262292         0.228341         False
1  0.256042   0.00562464      nan       0.272916                   0.239169                   0.262292         0.228341         False
2  0.247692   0.00562464      nan       0.264566                   0.230818                   0.262292         0.228341         False
3  0.242519   0.00562464      nan       0.259393                   0.225645                   0.262292         0.228341         False
4  0.250714   0.00562464      nan       0.267588                   0.23384                    0.262292         0.228341         False
5  0.263583   0.00562464      nan       0.280457                   0.24671                    0.262292         0.228341         True
6  0.269904   0.00562464      nan       0.286778                   0.25303                    0.262292         0.228341         True
7  0.275653   0.00562464      nan       0.292527                   0.258779                   0.262292         0.228341         True
8  0.266011   0.00562464      nan       0.282885                   0.249137                   0.262292         0.228341         True
9  0.270229   0.00562464      nan       0.287103                   0.253355                   0.262292         0.228341         True

true_highstreet_card_pred_upmarket_card
   value      sampling_error  realized  upper_confidence_boundary  lower_confidence_boundary  upper_threshold  lower_threshold  alert
0  0.0416287  0.00239047      nan       0.0488002                  0.0344573                  0.0470817        0.039385         False
1  0.0415422  0.00239047      nan       0.0487137                  0.0343708                  0.0470817        0.039385         False
2  0.0427649  0.00239047      nan       0.0499363                  0.0355935                  0.0470817        0.039385         False
3  0.0413573  0.00239047      nan       0.0485288                  0.0341859                  0.0470817        0.039385         False
4  0.0448168  0.00239047      nan       0.0519882                  0.0376454                  0.0470817        0.039385         False
5  0.0560677  0.00239047      nan       0.0632392                  0.0488963                  0.0470817        0.039385         True
6  0.0535273  0.00239047      nan       0.0606988                  0.0463559                  0.0470817        0.039385         True
7  0.0549509  0.00239047      nan       0.0621223                  0.0477794                  0.0470817        0.039385         True
8  0.056456   0.00239047      nan       0.0636275                  0.0492846                  0.0470817        0.039385         True
9  0.0561397  0.00239047      nan       0.0633112                  0.0489683                  0.0470817        0.039385         True

true_upmarket_card_pred_prepaid_card
   value      sampling_error  realized  upper_confidence_boundary  lower_confidence_boundary  upper_threshold  lower_threshold  alert
0  0.0396632  0.00254241      nan       0.0472904                  0.0320359                  0.0441763        0.0331904        False
1  0.0386701  0.00254241      nan       0.0462974                  0.0310429                  0.0441763        0.0331904        False
2  0.0390932  0.00254241      nan       0.0467204                  0.0314659                  0.0441763        0.0331904        False
3  0.03862    0.00254241      nan       0.0462473                  0.0309928                  0.0441763        0.0331904        False
4  0.0381747  0.00254241      nan       0.0458019                  0.0305475                  0.0441763        0.0331904        False
5  0.0495235  0.00254241      nan       0.0571508                  0.0418963                  0.0441763        0.0331904        True
6  0.0502663  0.00254241      nan       0.0578936                  0.0426391                  0.0441763        0.0331904        True
7  0.0471376  0.00254241      nan       0.0547648                  0.0395104                  0.0441763        0.0331904        True
8  0.0475851  0.00254241      nan       0.0552123                  0.0399578                  0.0441763        0.0331904        True
9  0.0491533  0.00254241      nan       0.0567806                  0.0415261                  0.0441763        0.0331904        True

true_upmarket_card_pred_highstreet_card
   value      sampling_error  realized  upper_confidence_boundary  lower_confidence_boundary  upper_threshold  lower_threshold  alert
0  0.0405114  0.00248954      nan       0.04798                    0.0330428                  0.046008         0.0348254        False
1  0.0423516  0.00248954      nan       0.0498202                  0.0348829                  0.046008         0.0348254        False
2  0.0397506  0.00248954      nan       0.0472193                  0.032282                   0.046008         0.0348254        False
3  0.041649   0.00248954      nan       0.0491176                  0.0341804                  0.046008         0.0348254        False
4  0.040785   0.00248954      nan       0.0482537                  0.0333164                  0.046008         0.0348254        False
5  0.0630554  0.00248954      nan       0.070524                   0.0555868                  0.046008         0.0348254        True
6  0.0642049  0.00248954      nan       0.0716735                  0.0567363                  0.046008         0.0348254        True
7  0.0650432  0.00248954      nan       0.0725119                  0.0575746                  0.046008         0.0348254        True
8  0.0620285  0.00248954      nan       0.0694972                  0.0545599                  0.046008         0.0348254        True
9  0.0635208  0.00248954      nan       0.0709894                  0.0560522                  0.046008         0.0348254        True

true_upmarket_card_pred_upmarket_card
   value      sampling_error  realized  upper_confidence_boundary  lower_confidence_boundary  upper_threshold  lower_threshold  alert
0  0.244425   0.00561357      nan       0.261266                   0.227584                   0.270106         0.236227         False
1  0.24743    0.00561357      nan       0.264271                   0.230589                   0.270106         0.236227         False
2  0.253669   0.00561357      nan       0.270509                   0.236828                   0.270106         0.236227         False
3  0.256698   0.00561357      nan       0.273539                   0.239858                   0.270106         0.236227         False
4  0.256843   0.00561357      nan       0.273683                   0.240002                   0.270106         0.236227         False
5  0.205466   0.00561357      nan       0.222306                   0.188625                   0.270106         0.236227         True
6  0.197382   0.00561357      nan       0.214223                   0.180541                   0.270106         0.236227         True
7  0.196022   0.00561357      nan       0.212862                   0.179181                   0.270106         0.236227         True
8  0.202474   0.00561357      nan       0.219315                   0.185634                   0.270106         0.236227         True
9  0.206238   0.00561357      nan       0.223078                   0.189397                   0.270106         0.236227         True
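As a quick check of the normalization note, the nine value entries of a chunk sum to 1 (up to rounding), because the estimator was created with normalize_confusion_matrix='all'. The snippet below copies the value entries of chunks 0 and 5 from the table above into 3×3 matrices (rows are true classes, columns are predicted classes). The diagonal sum, the estimated share of correct predictions, is derived here for illustration only and is not part of the results above.

```python
# 'value' entries copied from the results table above.
# Rows: true class, columns: predicted class (prepaid, highstreet, upmarket).
chunk_0 = [
    [0.261674, 0.0415311, 0.0379462],
    [0.0453291, 0.247291, 0.0416287],
    [0.0396632, 0.0405114, 0.244425],
]
chunk_5 = [
    [0.166777, 0.0710278, 0.0623],
    [0.0621998, 0.263583, 0.0560677],
    [0.0495235, 0.0630554, 0.205466],
]

for name, matrix in [('chunk 0', chunk_0), ('chunk 5', chunk_5)]:
    total = sum(sum(row) for row in matrix)         # close to 1 with normalize='all'
    diagonal = sum(matrix[k][k] for k in range(3))  # estimated share of correct predictions
    print(name, round(total, 4), round(diagonal, 4))
# chunk 0 1.0 0.7534
# chunk 5 1.0 0.6358
```

The diagonal drops from about 0.75 in chunk 0 to about 0.64 in chunk 5, the first chunk in which alerts are raised.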

Apart from the chunk-related columns, the results contain the following columns for each estimated metric (here, each cell of the confusion matrix):

  • value - the estimate of a metric for a specific chunk.

  • sampling_error - the estimate of the Sampling Error.

  • realized - when target values are available for a chunk, the realized performance metric is also calculated and included in the results. The analysis data in this example has no targets, so this column is nan.

  • upper_confidence_boundary and lower_confidence_boundary - these values show the confidence band of the relevant metric and are equal to the estimated value +/- 3 times the estimated sampling error.

  • upper_threshold and lower_threshold - crossing these thresholds raises an alert on a significant performance change. The thresholds are calculated during the fit phase from the realized performance of the monitored model on chunks of the reference period: they lie 3 standard deviations above and below the mean performance across those chunks.

  • alert - flag indicating a potentially significant performance change. True if the estimated value crosses the upper or lower threshold.
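The relationships between value, sampling_error, the confidence boundaries, the thresholds and alert can be sketched in plain Python. The functions below are hypothetical helpers, not NannyML API. The threshold helper follows the "mean +/- 3 standard deviations of the reference chunks" description above. Using the population standard deviation is an assumption of this sketch, and the reference values passed to it are made up. The remaining usage lines reuse numbers for true_prepaid_card_pred_prepaid_card from the results table.

```python
from statistics import mean, pstdev


def confidence_band(value, sampling_error, n_sigma=3):
    """Return (lower, upper): the estimate minus/plus n_sigma sampling errors."""
    return value - n_sigma * sampling_error, value + n_sigma * sampling_error


def std_thresholds(reference_chunk_values, n_sigma=3):
    """Return (lower, upper): mean of reference chunk values minus/plus n_sigma std devs.

    Using the population standard deviation (pstdev) is an assumption of this sketch.
    """
    mu = mean(reference_chunk_values)
    sigma = pstdev(reference_chunk_values)
    return mu - n_sigma * sigma, mu + n_sigma * sigma


def is_alert(value, lower_threshold, upper_threshold):
    """True when the estimate lies outside [lower_threshold, upper_threshold]."""
    return value < lower_threshold or value > upper_threshold


# Made-up realized values on reference chunks, only to exercise std_thresholds.
print(std_thresholds([0.25, 0.26, 0.25, 0.26]))

# true_prepaid_card_pred_prepaid_card: numbers copied from the results table.
lower_t, upper_t, sampling_error = 0.243455, 0.265912, 0.00555482
for chunk, value in [(0, 0.261674), (5, 0.166777)]:
    low, high = confidence_band(value, sampling_error)
    print(chunk, round(low, 4), round(high, 4), is_alert(value, lower_t, upper_t))
```

For chunk 0 this gives a band of roughly 0.2450 to 0.2783 and no alert. For chunk 5, the estimate 0.166777 lies below the lower threshold 0.243455, so an alert is raised, matching the table.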

These results can also be plotted. Our plot contains several key elements:

  • The purple step plot shows the estimated performance in each chunk of the analysis period. Thick square markers indicate the middle of each chunk.

  • The light purple area around the estimated performance in the analysis period is the confidence band, calculated as the estimated performance +/- 3 times the estimated sampling error.

  • The gray vertical line splits the reference and analysis periods.

  • The red horizontal dashed lines show upper and lower thresholds for alerting purposes.

  • The red diamond-shaped point markers in the middle of a chunk indicate that an alert has been raised. Alerts are caused by the estimated performance crossing the upper or lower threshold.

>>> metric_fig = results.plot()
>>> metric_fig.show()
[Figure: estimated confusion matrix elements per chunk, reference and analysis periods]

Additional information such as the chunk index range and chunk date range (if timestamps were provided) is shown in the hover for each chunk (these are interactive plots, though only static views are included here).

Insights

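After reviewing the performance estimation results, we can see the indications of performance change that NannyML has detected based on the model’s inputs and outputs alone. Chunks 0 to 4 stay within the thresholds. From chunk 5 onward (starting 2020-11-01), every cell of the estimated confusion matrix raises an alert. The estimated true_prepaid_card_pred_prepaid_card and true_upmarket_card_pred_upmarket_card fractions fall well below their lower thresholds. Meanwhile, true_highstreet_card_pred_highstreet_card and all off-diagonal (misclassification) cells rise above their upper thresholds.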

What’s next

The Data Drift functionality can help us understand whether data drift is causing the performance problem. When the target values become available, we can use the realized performance calculation to compare the realized and estimated confusion matrix results.
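For comparison, a realized confusion matrix is obtained by counting (y_true, y_pred) pairs once targets exist, then normalizing the counts in the same way. The sketch below does this for the three reference rows displayed at the start of the walkthrough, normalizing over all observations. realized_confusion_matrix is a hypothetical helper, not NannyML's realized performance calculator.

```python
from collections import Counter

CLASSES = ['prepaid_card', 'highstreet_card', 'upmarket_card']


def realized_confusion_matrix(y_true, y_pred, classes, normalize_all=True):
    """Count (true, predicted) pairs; optionally divide by the number of observations."""
    pairs = Counter(zip(y_true, y_pred))
    matrix = [[pairs[(t, p)] for p in classes] for t in classes]
    if normalize_all:
        n = len(y_true)
        matrix = [[count / n for count in row] for row in matrix]
    return matrix


# y_true and y_pred of the three reference rows displayed earlier.
y_true = ['prepaid_card', 'prepaid_card', 'upmarket_card']
y_pred = ['prepaid_card', 'prepaid_card', 'prepaid_card']
for cls, row in zip(CLASSES, realized_confusion_matrix(y_true, y_pred, CLASSES)):
    print(cls, [round(v, 3) for v in row])
```

Two of the three rows are true prepaid_card predicted as prepaid_card (2/3). The remaining row is true upmarket_card predicted as prepaid_card (1/3).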