Prediction Tasks
This notebook is part of the CaTabRa GitHub repository.
This short example demonstrates the four prediction tasks supported by CaTabRa: binary classification, multiclass classification, multilabel classification, and regression.
Familiarity with CaTabRa’s main data analysis workflow is assumed. A step-by-step introduction can be found in CaTabRa Workflow.
Prerequisites
[14]:
import numpy as np
from catabra.analysis import analyze
Binary Classification
Analyze data with a binary target, i.e., each sample belongs to one of two classes.
[2]:
# load dataset
from sklearn.datasets import load_breast_cancer
X_binary, y_binary = load_breast_cancer(as_frame=True, return_X_y=True)
[3]:
# add target labels to DataFrame
X_binary['diagnosis'] = y_binary
[4]:
# split into training and test sets by adding a column with the corresponding values
# the column name is arbitrary; CaTabRa tries to infer which samples belong to which set from the column name and its values
X_binary['train'] = X_binary.index <= 0.8 * len(X_binary)
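To see what this split column contains, the following sketch rebuilds the same boolean mask on a stand-in DataFrame with 569 rows (the size of the breast cancer dataset) and counts the samples on each side. The frame `df` and its single column `x` are placeholders for illustration only.

```python
import numpy as np
import pandas as pd

# stand-in table with a default RangeIndex 0..568 (placeholder data)
df = pd.DataFrame({'x': np.zeros(569)})

# rows whose index is at most 80% of the table length form the training set
df['train'] = df.index <= 0.8 * len(df)

n_train = int(df['train'].sum())
n_test = int((~df['train']).sum())
print(n_train, n_test)  # 456 113
```

Because the mask compares index values, the split is positional: the first rows go to the training set and the last rows to the test set, without shuffling.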
[5]:
analyze(
X_binary, # table to analyze; can also be the path to a CSV/Excel/HDF5 file
classify='diagnosis', # name of column containing classification target
split='train', # name of column containing information about the train-test split (optional)
time=1, # time budget for hyperparameter tuning, in minutes (optional)
out='binary_classification'
)
[CaTabRa] ### Analysis started at 2023-02-07 10:44:35.243270
[CaTabRa] Saving descriptive statistics completed
[CaTabRa] Using AutoML-backend auto-sklearn for binary_classification
[CaTabRa] Successfully loaded the following auto-sklearn add-on module(s): xgb
/home/amaletzk/miniconda3/envs/catabra/lib/python3.9/site-packages/autosklearn/metalearning/metalearning/meta_base.py:68: FutureWarning: The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.
self.metafeatures = self.metafeatures.append(metafeatures)
/home/amaletzk/miniconda3/envs/catabra/lib/python3.9/site-packages/autosklearn/metalearning/metalearning/meta_base.py:72: FutureWarning: The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.
self.algorithm_runs[metric].append(runs)
[CaTabRa] New ensemble fitted:
ensemble_val_roc_auc: 0.980337
n_constituent_models: 1
total_elapsed_time: 00:04
[CaTabRa] New model #1 trained:
val_roc_auc: 0.980337
val_accuracy: 0.927152
val_balanced_accuracy: 0.928416
train_roc_auc: 1.000000
type: random_forest
total_elapsed_time: 00:04
[CaTabRa] New ensemble fitted:
ensemble_val_roc_auc: 0.994744
n_constituent_models: 1
total_elapsed_time: 00:05
[CaTabRa] New model #2 trained:
val_roc_auc: 0.994744
val_accuracy: 0.947020
val_balanced_accuracy: 0.947717
train_roc_auc: 0.996970
type: passive_aggressive
total_elapsed_time: 00:05
[CaTabRa] New ensemble fitted:
ensemble_val_roc_auc: 0.994744
n_constituent_models: 1
total_elapsed_time: 00:06
[CaTabRa] New model #3 trained:
val_roc_auc: 0.970098
val_accuracy: 0.920530
val_balanced_accuracy: 0.915458
train_roc_auc: 1.000000
type: gradient_boosting
total_elapsed_time: 00:06
[CaTabRa] New ensemble fitted:
ensemble_val_roc_auc: 0.994744
n_constituent_models: 1
total_elapsed_time: 00:07
[CaTabRa] New model #4 trained:
val_roc_auc: 0.975535
val_accuracy: 0.933775
val_balanced_accuracy: 0.934034
train_roc_auc: 0.999866
type: random_forest
total_elapsed_time: 00:07
[CaTabRa] New ensemble fitted:
ensemble_val_roc_auc: 0.994744
n_constituent_models: 1
total_elapsed_time: 00:09
[CaTabRa] New model #5 trained:
val_roc_auc: 0.969192
val_accuracy: 0.913907
val_balanced_accuracy: 0.914734
train_roc_auc: 1.000000
type: mlp
total_elapsed_time: 00:09
[CaTabRa] New ensemble fitted:
ensemble_val_roc_auc: 0.994744
n_constituent_models: 1
total_elapsed_time: 00:10
[CaTabRa] New model #6 trained:
val_roc_auc: 0.984143
val_accuracy: 0.933775
val_balanced_accuracy: 0.934034
train_roc_auc: 1.000000
type: gradient_boosting
total_elapsed_time: 00:10
[CaTabRa] New ensemble fitted:
ensemble_val_roc_auc: 0.994926
n_constituent_models: 2
total_elapsed_time: 00:12
[CaTabRa] New model #7 trained:
val_roc_auc: 0.980065
val_accuracy: 0.907285
val_balanced_accuracy: 0.914009
train_roc_auc: 0.995322
type: extra_trees
total_elapsed_time: 00:12
[CaTabRa] New ensemble fitted:
ensemble_val_roc_auc: 0.995832
n_constituent_models: 2
total_elapsed_time: 00:17
[CaTabRa] New model #8 trained:
val_roc_auc: 0.994201
val_accuracy: 0.973510
val_balanced_accuracy: 0.972635
train_roc_auc: 1.000000
type: gradient_boosting
total_elapsed_time: 00:17
[CaTabRa] New ensemble fitted:
ensemble_val_roc_auc: 0.995832
n_constituent_models: 2
total_elapsed_time: 00:19
[CaTabRa] New model #9 trained:
val_roc_auc: 0.978434
val_accuracy: 0.933775
val_balanced_accuracy: 0.926694
train_roc_auc: 0.998574
type: mlp
total_elapsed_time: 00:19
[CaTabRa] New ensemble fitted:
ensemble_val_roc_auc: 0.995832
n_constituent_models: 2
total_elapsed_time: 00:21
[CaTabRa] New model #10 trained:
val_roc_auc: 0.985502
val_accuracy: 0.940397
val_balanced_accuracy: 0.937206
train_roc_auc: 1.000000
type: random_forest
total_elapsed_time: 00:21
[CaTabRa] New ensemble fitted:
ensemble_val_roc_auc: 0.998007
n_constituent_models: 2
total_elapsed_time: 00:22
[CaTabRa] New model #11 trained:
val_roc_auc: 0.997282
val_accuracy: 0.940397
val_balanced_accuracy: 0.927419
train_roc_auc: 0.995901
type: passive_aggressive
total_elapsed_time: 00:22
[CaTabRa] New ensemble fitted:
ensemble_val_roc_auc: 0.998007
n_constituent_models: 2
total_elapsed_time: 00:24
[CaTabRa] New model #12 trained:
val_roc_auc: 0.976260
val_accuracy: 0.920530
val_balanced_accuracy: 0.925245
train_roc_auc: 0.998886
type: random_forest
total_elapsed_time: 00:23
[CaTabRa] New ensemble fitted:
ensemble_val_roc_auc: 0.998007
n_constituent_models: 2
total_elapsed_time: 00:27
[CaTabRa] New model #13 trained:
val_roc_auc: 0.995107
val_accuracy: 0.973510
val_balanced_accuracy: 0.972635
train_roc_auc: 1.000000
type: gradient_boosting
total_elapsed_time: 00:27
[CaTabRa] New ensemble fitted:
ensemble_val_roc_auc: 0.998007
n_constituent_models: 2
total_elapsed_time: 00:29
[CaTabRa] New model #14 trained:
val_roc_auc: 0.994020
val_accuracy: 0.953642
val_balanced_accuracy: 0.948441
train_roc_auc: 1.000000
type: mlp
total_elapsed_time: 00:29
[CaTabRa] New ensemble fitted:
ensemble_val_roc_auc: 0.998007
n_constituent_models: 2
total_elapsed_time: 00:30
[CaTabRa] New model #15 trained:
val_roc_auc: 0.987858
val_accuracy: 0.940397
val_balanced_accuracy: 0.942099
train_roc_auc: 1.000000
type: extra_trees
total_elapsed_time: 00:30
[CaTabRa] New ensemble fitted:
ensemble_val_roc_auc: 0.998007
n_constituent_models: 3
total_elapsed_time: 00:31
[CaTabRa] New model #16 trained:
val_roc_auc: 0.994020
val_accuracy: 0.947020
val_balanced_accuracy: 0.950163
train_roc_auc: 0.996569
type: sgd
total_elapsed_time: 00:31
[CaTabRa] New ensemble fitted:
ensemble_val_roc_auc: 0.998007
n_constituent_models: 3
total_elapsed_time: 00:32
[CaTabRa] New model #17 trained:
val_roc_auc: 0.996557
val_accuracy: 0.966887
val_balanced_accuracy: 0.964570
train_roc_auc: 1.000000
type: mlp
total_elapsed_time: 00:32
[CaTabRa] New model #18 trained:
val_roc_auc: 0.974175
val_accuracy: 0.920530
val_balanced_accuracy: 0.920352
train_roc_auc: 1.000000
type: random_forest
total_elapsed_time: 00:33
[CaTabRa] New model #19 trained:
val_roc_auc: 0.981877
val_accuracy: 0.933775
val_balanced_accuracy: 0.934034
train_roc_auc: 0.999955
type: random_forest
total_elapsed_time: 00:35
[CaTabRa] New ensemble fitted:
ensemble_val_roc_auc: 0.998007
n_constituent_models: 3
total_elapsed_time: 00:36
[CaTabRa] New model #20 trained:
val_roc_auc: 0.987767
val_accuracy: 0.960265
val_balanced_accuracy: 0.956506
train_roc_auc: 1.000000
type: mlp
total_elapsed_time: 00:36
[CaTabRa] New model #21 trained:
val_roc_auc: 0.876858
val_accuracy: 0.880795
val_balanced_accuracy: 0.876858
train_roc_auc: 1.000000
type: k_nearest_neighbors
total_elapsed_time: 00:37
[CaTabRa] New ensemble fitted:
ensemble_val_roc_auc: 0.998007
n_constituent_models: 3
total_elapsed_time: 00:38
[CaTabRa] New model #22 trained:
val_roc_auc: 0.988039
val_accuracy: 0.960265
val_balanced_accuracy: 0.956506
train_roc_auc: 1.000000
type: gradient_boosting
total_elapsed_time: 00:38
[CaTabRa] New model #23 trained:
val_roc_auc: 0.982331
val_accuracy: 0.913907
val_balanced_accuracy: 0.914734
train_roc_auc: 1.000000
type: random_forest
total_elapsed_time: 00:39
[CaTabRa] New ensemble fitted:
ensemble_val_roc_auc: 0.998007
n_constituent_models: 3
total_elapsed_time: 00:40
[CaTabRa] New model #24 trained:
val_roc_auc: 0.990395
val_accuracy: 0.960265
val_balanced_accuracy: 0.956506
train_roc_auc: 1.000000
type: gradient_boosting
total_elapsed_time: 00:40
[CaTabRa] New ensemble fitted:
ensemble_val_roc_auc: 0.998007
n_constituent_models: 3
total_elapsed_time: 00:43
[CaTabRa] New model #25 trained:
val_roc_auc: 0.993838
val_accuracy: 0.933775
val_balanced_accuracy: 0.941374
train_roc_auc: 0.997995
type: passive_aggressive
total_elapsed_time: 00:43
[CaTabRa] New ensemble fitted:
ensemble_val_roc_auc: 0.998007
n_constituent_models: 3
total_elapsed_time: 00:47
[CaTabRa] New model #26 trained:
val_roc_auc: 0.993295
val_accuracy: 0.947020
val_balanced_accuracy: 0.947717
train_roc_auc: 0.994386
type: passive_aggressive
total_elapsed_time: 00:47
[CaTabRa] New model #27 trained:
val_roc_auc: 0.942008
val_accuracy: 0.741722
val_balanced_accuracy: 0.776006
train_roc_auc: 0.971039
type: mlp
total_elapsed_time: 00:50
[CaTabRa] Final training statistics:
n_models_trained: 27
ensemble_val_roc_auc: 0.9980065241029358
[CaTabRa] Creating shap explainer
[CaTabRa] Initialized out-of-distribution detector of type Autoencoder
[CaTabRa] Fitting out-of-distribution detector...
Iteration 1, loss = 0.06377027
Iteration 2, loss = 0.03359140
Iteration 3, loss = 0.02194009
Iteration 4, loss = 0.01638331
Iteration 5, loss = 0.01398731
Iteration 6, loss = 0.01247938
Iteration 7, loss = 0.01109453
Iteration 8, loss = 0.01022823
Iteration 9, loss = 0.00959181
Iteration 10, loss = 0.00878776
Iteration 11, loss = 0.00890786
Iteration 12, loss = 0.00813948
Iteration 13, loss = 0.00792689
Iteration 14, loss = 0.00700066
Iteration 15, loss = 0.00632195
Iteration 16, loss = 0.00601902
Iteration 17, loss = 0.00576786
Iteration 18, loss = 0.00579422
Iteration 19, loss = 0.00653422
Iteration 20, loss = 0.00673960
Iteration 21, loss = 0.00594721
Iteration 22, loss = 0.00580508
Iteration 23, loss = 0.00573597
Iteration 24, loss = 0.00554276
Iteration 25, loss = 0.00560443
Iteration 26, loss = 0.00545064
Iteration 27, loss = 0.00537846
Iteration 28, loss = 0.00532150
Iteration 29, loss = 0.00532396
Iteration 30, loss = 0.00528569
Iteration 31, loss = 0.00528335
Iteration 32, loss = 0.00525824
Iteration 33, loss = 0.00527302
Iteration 34, loss = 0.00526580
Iteration 35, loss = 0.00523524
Iteration 36, loss = 0.00525581
Iteration 37, loss = 0.00522755
Iteration 38, loss = 0.00522458
Iteration 39, loss = 0.00522368
Iteration 40, loss = 0.00521142
Iteration 41, loss = 0.00521995
Iteration 42, loss = 0.00521802
Iteration 43, loss = 0.00521297
Iteration 44, loss = 0.00521383
Iteration 45, loss = 0.00520886
Iteration 46, loss = 0.00521189
Iteration 47, loss = 0.00523315
Iteration 48, loss = 0.00520962
Iteration 49, loss = 0.00522331
Iteration 50, loss = 0.00520331
Iteration 51, loss = 0.00519715
Iteration 52, loss = 0.00521778
Iteration 53, loss = 0.00519735
Iteration 54, loss = 0.00518522
Iteration 55, loss = 0.00518577
Iteration 56, loss = 0.00520093
Iteration 57, loss = 0.00519632
Iteration 58, loss = 0.00517685
Iteration 59, loss = 0.00518872
Iteration 60, loss = 0.00522378
Iteration 61, loss = 0.00520782
Iteration 62, loss = 0.00519619
Iteration 63, loss = 0.00522216
Iteration 64, loss = 0.00528045
Iteration 65, loss = 0.00522250
Iteration 66, loss = 0.00521136
Iteration 67, loss = 0.00520125
Iteration 68, loss = 0.00517974
Iteration 69, loss = 0.00516680
Iteration 70, loss = 0.00516050
Iteration 71, loss = 0.00516956
Iteration 72, loss = 0.00516102
Iteration 73, loss = 0.00524152
Iteration 74, loss = 0.00517434
Iteration 75, loss = 0.00529929
Training loss did not improve more than tol=0.000100 for 50 consecutive epochs. Stopping.
[CaTabRa] Out-of-distribution detector fitted.
[CaTabRa] ### Analysis finished at 2023-02-07 10:45:35.313517
[CaTabRa] ### Elapsed time: 0 days 00:01:00.070247
[CaTabRa] ### Output saved in /mnt/c/Users/amaletzk/Documents/CaTabRa/catabra/examples/binary_classification
[CaTabRa] ### Evaluation started at 2023-02-07 10:45:35.376897
[CaTabRa] Predicting out-of-distribution samples.
[CaTabRa] Saving descriptive statistics completed
[CaTabRa] Saving descriptive statistics completed
[CaTabRa] Evaluation results for train:
roc_auc: 0.9973715651135006
accuracy @ 0.5: 0.9736842105263158
balanced_accuracy @ 0.5: 0.96857825567503
[CaTabRa] Evaluation results for not_train:
roc_auc: 0.9986737400530503
accuracy @ 0.5: 0.9734513274336283
balanced_accuracy @ 0.5: 0.95579133510168
[CaTabRa] ### Evaluation finished at 2023-02-07 10:45:40.180348
[CaTabRa] ### Elapsed time: 0 days 00:00:04.803451
[CaTabRa] ### Output saved in /mnt/c/Users/amaletzk/Documents/CaTabRa/catabra/examples/binary_classification/eval
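The evaluation output reports accuracy and balanced accuracy at a decision threshold of 0.5, alongside the threshold-free ROC-AUC. As a rough illustration of what these numbers measure, the sketch below computes them with plain NumPy on made-up labels and probabilities. It is not CaTabRa's implementation, and treating probabilities of exactly 0.5 as positive is an assumption.

```python
import numpy as np

# toy ground truth and predicted probabilities of the positive class (made-up values)
y_true = np.array([0, 0, 1, 1])
proba = np.array([0.1, 0.6, 0.4, 0.9])

# hard predictions at decision threshold 0.5 (assumption: >= 0.5 counts as positive)
y_pred = (proba >= 0.5).astype(int)

# accuracy = fraction of correctly classified samples
accuracy = float(np.mean(y_pred == y_true))

# balanced accuracy = mean of the per-class recall
recalls = [np.mean(y_pred[y_true == c] == c) for c in (0, 1)]
balanced_accuracy = float(np.mean(recalls))

# ROC-AUC = fraction of (positive, negative) pairs ranked correctly; ties count half
pos, neg = proba[y_true == 1], proba[y_true == 0]
diff = pos[:, None] - neg[None, :]
roc_auc = float(np.mean((diff > 0) + 0.5 * (diff == 0)))

print(accuracy, balanced_accuracy, roc_auc)  # 0.5 0.5 0.75
```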
Multiclass Classification
Analyze data with a multiclass target, i.e., each sample belongs to one of n > 2 classes.
[19]:
# load dataset
from sklearn.datasets import load_iris
X_multiclass, y_multiclass = load_iris(as_frame=True, return_X_y=True)
[20]:
# add target labels to DataFrame
X_multiclass['species'] = y_multiclass
[21]:
# split into training and test sets by adding a column with the corresponding values
# the column name is arbitrary; CaTabRa tries to infer which samples belong to which set from the column name and its values
X_multiclass['test'] = np.random.choice([False, True], size=len(X_multiclass), p=[0.8, 0.2])
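The random split above changes on every run because no seed is set. If a reproducible split is preferred, one option is a seeded NumPy generator, as in the following sketch. The seed value and the variable names `rng` and `test_mask` are arbitrary choices for illustration, and the resulting mask will differ from the one shown in the outputs of this notebook.

```python
import numpy as np

# seeded generator so that the same rows land in the test set on every run
rng = np.random.default_rng(seed=42)  # the seed value is arbitrary

n_samples = 150  # number of rows in the iris dataset

# each row is assigned to the test set independently with probability 0.2
test_mask = rng.choice([False, True], size=n_samples, p=[0.8, 0.2])

print(test_mask.sum(), "of", n_samples, "samples assigned to the test set")
```

The mask could then be stored in the split column, e.g. `X_multiclass['test'] = test_mask`. Since every row is drawn independently, the test fraction is only approximately 20%.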
[22]:
X_multiclass.head()
[22]:
| | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | species | test |
|---|---|---|---|---|---|---|
| 0 | 5.1 | 3.5 | 1.4 | 0.2 | 0 | True |
| 1 | 4.9 | 3.0 | 1.4 | 0.2 | 0 | False |
| 2 | 4.7 | 3.2 | 1.3 | 0.2 | 0 | False |
| 3 | 4.6 | 3.1 | 1.5 | 0.2 | 0 | False |
| 4 | 5.0 | 3.6 | 1.4 | 0.2 | 0 | True |
[23]:
X_multiclass['species'].value_counts()
[23]:
0 50
1 50
2 50
Name: species, dtype: int64
[24]:
X_multiclass.groupby('test')['species'].value_counts().unstack()
[24]:
| test | species = 0 | species = 1 | species = 2 |
|---|---|---|---|
| False | 36 | 39 | 40 |
| True | 14 | 11 | 10 |
The analyze() function is called exactly as before. CaTabRa automatically treats the task as multiclass classification because the target column specified by classify contains more than two unique values.
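The rule described here can be pictured with a small helper that counts the distinct labels in the target column. The function `guess_classification_task` is hypothetical and only mirrors the description above; it is not part of CaTabRa, whose actual detection logic may consider more than the number of unique values.

```python
import pandas as pd


def guess_classification_task(target: pd.Series) -> str:
    """Hypothetical helper: name the task suggested by the number of distinct labels."""
    n_classes = target.dropna().nunique()
    if n_classes < 2:
        raise ValueError("a classification target needs at least two distinct labels")
    if n_classes == 2:
        return 'binary_classification'
    return 'multiclass_classification'


print(guess_classification_task(pd.Series([0, 1, 1, 0])))  # binary_classification
print(guess_classification_task(pd.Series([0, 1, 2, 1])))  # multiclass_classification
```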
[26]:
analyze(
X_multiclass, # table to analyze; can also be the path to a CSV/Excel/HDF5 file
classify='species', # name of column containing classification target
split='test', # name of column containing information about the train-test split (optional)
time=1, # time budget for hyperparameter tuning, in minutes (optional)
out='multiclass_classification'
)
Output folder "/mnt/c/Users/amaletzk/Documents/CaTabRa/catabra/examples/multiclass_classification" already exists. Delete? [y/n] y
[CaTabRa] ### Analysis started at 2023-02-07 10:52:44.468626
[CaTabRa] Saving descriptive statistics completed
[CaTabRa] Using AutoML-backend auto-sklearn for multiclass_classification
[CaTabRa] Successfully loaded the following auto-sklearn add-on module(s): xgb
The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.
The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.
[CaTabRa] New ensemble fitted:
ensemble_val_accuracy: 0.973684
n_constituent_models: 1
total_elapsed_time: 00:05
[CaTabRa] New model #1 trained:
val_accuracy: 0.973684
val_balanced_accuracy: 0.974359
train_accuracy: 1.000000
type: random_forest
total_elapsed_time: 00:05
[CaTabRa] New ensemble fitted:
ensemble_val_accuracy: 0.973684
n_constituent_models: 2
total_elapsed_time: 00:07
[CaTabRa] New model #2 trained:
val_accuracy: 0.973684
val_balanced_accuracy: 0.974359
train_accuracy: 0.987013
type: random_forest
total_elapsed_time: 00:07
[CaTabRa] New ensemble fitted:
ensemble_val_accuracy: 0.973684
n_constituent_models: 2
total_elapsed_time: 00:07
[CaTabRa] New model #3 trained:
val_accuracy: 0.868421
val_balanced_accuracy: 0.871795
train_accuracy: 0.883117
type: passive_aggressive
total_elapsed_time: 00:07
[CaTabRa] New ensemble fitted:
ensemble_val_accuracy: 0.973684
n_constituent_models: 3
total_elapsed_time: 00:08
[CaTabRa] New model #4 trained:
val_accuracy: 0.973684
val_balanced_accuracy: 0.974359
train_accuracy: 0.974026
type: random_forest
total_elapsed_time: 00:08
[CaTabRa] New ensemble fitted:
ensemble_val_accuracy: 0.973684
n_constituent_models: 4
total_elapsed_time: 00:10
[CaTabRa] New model #5 trained:
val_accuracy: 0.973684
val_balanced_accuracy: 0.974359
train_accuracy: 1.000000
type: random_forest
total_elapsed_time: 00:10
[CaTabRa] New ensemble fitted:
ensemble_val_accuracy: 1.000000
n_constituent_models: 5
total_elapsed_time: 00:10
[CaTabRa] New model #6 trained:
val_accuracy: 0.973684
val_balanced_accuracy: 0.974359
train_accuracy: 1.000000
type: libsvm_svc
total_elapsed_time: 00:10
[CaTabRa] New ensemble fitted:
ensemble_val_accuracy: 1.000000
n_constituent_models: 5
total_elapsed_time: 00:12
[CaTabRa] New model #7 trained:
val_accuracy: 0.973684
val_balanced_accuracy: 0.974359
train_accuracy: 0.987013
type: random_forest
total_elapsed_time: 00:12
[CaTabRa] New ensemble fitted:
ensemble_val_accuracy: 1.000000
n_constituent_models: 6
total_elapsed_time: 00:13
[CaTabRa] New model #8 trained:
val_accuracy: 0.973684
val_balanced_accuracy: 0.974359
train_accuracy: 0.974026
type: mlp
total_elapsed_time: 00:13
[CaTabRa] New ensemble fitted:
ensemble_val_accuracy: 1.000000
n_constituent_models: 6
total_elapsed_time: 00:14
[CaTabRa] New model #9 trained:
val_accuracy: 0.973684
val_balanced_accuracy: 0.974359
train_accuracy: 1.000000
type: gradient_boosting
total_elapsed_time: 00:14
[CaTabRa] New ensemble fitted:
ensemble_val_accuracy: 1.000000
n_constituent_models: 6
total_elapsed_time: 00:14
[CaTabRa] New model #10 trained:
val_accuracy: 1.000000
val_balanced_accuracy: 1.000000
train_accuracy: 0.987013
type: decision_tree
total_elapsed_time: 00:14
[CaTabRa] New ensemble fitted:
ensemble_val_accuracy: 1.000000
n_constituent_models: 6
total_elapsed_time: 00:15
[CaTabRa] New model #11 trained:
val_accuracy: 0.973684
val_balanced_accuracy: 0.974359
train_accuracy: 1.000000
type: mlp
total_elapsed_time: 00:15
[CaTabRa] New model #12 trained:
val_accuracy: 0.973684
val_balanced_accuracy: 0.974359
train_accuracy: 1.000000
type: liblinear_svc
total_elapsed_time: 00:16
[CaTabRa] New model #13 trained:
val_accuracy: 0.973684
val_balanced_accuracy: 0.974359
train_accuracy: 0.974026
type: random_forest
total_elapsed_time: 00:17
[CaTabRa] New model #14 trained:
val_accuracy: 0.973684
val_balanced_accuracy: 0.974359
train_accuracy: 1.000000
type: gradient_boosting
total_elapsed_time: 00:18
[CaTabRa] New model #15 trained:
val_accuracy: 0.973684
val_balanced_accuracy: 0.974359
train_accuracy: 0.974026
type: random_forest
total_elapsed_time: 00:19
[CaTabRa] New model #16 trained:
val_accuracy: 0.973684
val_balanced_accuracy: 0.974359
train_accuracy: 0.987013
type: adaboost
total_elapsed_time: 00:21
[CaTabRa] New model #17 trained:
val_accuracy: 0.973684
val_balanced_accuracy: 0.974359
train_accuracy: 1.000000
type: gradient_boosting
total_elapsed_time: 00:22
[CaTabRa] New model #18 trained:
val_accuracy: 0.973684
val_balanced_accuracy: 0.974359
train_accuracy: 0.987013
type: gradient_boosting
total_elapsed_time: 00:23
[CaTabRa] New model #19 trained:
val_accuracy: 0.842105
val_balanced_accuracy: 0.846154
train_accuracy: 0.909091
type: passive_aggressive
total_elapsed_time: 00:23
[CaTabRa] New ensemble fitted:
ensemble_val_accuracy: 1.000000
n_constituent_models: 7
total_elapsed_time: 00:25
[CaTabRa] New model #20 trained:
val_accuracy: 1.000000
val_balanced_accuracy: 1.000000
train_accuracy: 0.987013
type: extra_trees
total_elapsed_time: 00:25
[CaTabRa] New model #21 trained:
val_accuracy: 0.973684
val_balanced_accuracy: 0.974359
train_accuracy: 1.000000
type: mlp
total_elapsed_time: 00:26
[CaTabRa] New model #22 trained:
val_accuracy: 0.973684
val_balanced_accuracy: 0.974359
train_accuracy: 0.987013
type: gradient_boosting
total_elapsed_time: 00:27
[CaTabRa] New model #23 trained:
val_accuracy: 0.973684
val_balanced_accuracy: 0.974359
train_accuracy: 0.974026
type: extra_trees
total_elapsed_time: 00:28
[CaTabRa] New model #24 trained:
val_accuracy: 0.973684
val_balanced_accuracy: 0.974359
train_accuracy: 0.974026
type: random_forest
total_elapsed_time: 00:30
[CaTabRa] New model #25 trained:
val_accuracy: 0.973684
val_balanced_accuracy: 0.974359
train_accuracy: 0.987013
type: mlp
total_elapsed_time: 00:31
[WARNING] [2023-02-07 10:53:17,557:smac.runhistory.runhistory2epm.RunHistory2EPM4LogCost] Got cost of smaller/equal to 0. Replace by 0.000010 since we use log cost.
[CaTabRa] New ensemble fitted:
ensemble_val_accuracy: 1.000000
n_constituent_models: 7
total_elapsed_time: 00:35
[CaTabRa] New model #26 trained:
val_accuracy: 1.000000
val_balanced_accuracy: 1.000000
train_accuracy: 1.000000
type: extra_trees
total_elapsed_time: 00:35
[WARNING] [2023-02-07 10:53:21,478:smac.runhistory.runhistory2epm.RunHistory2EPM4LogCost] Got cost of smaller/equal to 0. Replace by 0.000010 since we use log cost.
[CaTabRa] New model #27 trained:
val_accuracy: 0.973684
val_balanced_accuracy: 0.974359
train_accuracy: 0.961039
type: extra_trees
total_elapsed_time: 00:39
[WARNING] [2023-02-07 10:53:25,091:smac.runhistory.runhistory2epm.RunHistory2EPM4LogCost] Got cost of smaller/equal to 0. Replace by 0.000010 since we use log cost.
[CaTabRa] New model #28 trained:
val_accuracy: 0.973684
val_balanced_accuracy: 0.974359
train_accuracy: 0.987013
type: decision_tree
total_elapsed_time: 00:42
[WARNING] [2023-02-07 10:53:28,507:smac.runhistory.runhistory2epm.RunHistory2EPM4LogCost] Got cost of smaller/equal to 0. Replace by 0.000010 since we use log cost.
[CaTabRa] New model #29 trained:
val_accuracy: 0.657895
val_balanced_accuracy: 0.666667
train_accuracy: 0.649351
type: decision_tree
total_elapsed_time: 00:45
[WARNING] [2023-02-07 10:53:31,791:smac.runhistory.runhistory2epm.RunHistory2EPM4LogCost] Got cost of smaller/equal to 0. Replace by 0.000010 since we use log cost.
[CaTabRa] New model #30 trained:
val_accuracy: 0.684211
val_balanced_accuracy: 0.685897
train_accuracy: 0.727273
type: qda
total_elapsed_time: 00:46
[WARNING] [2023-02-07 10:53:32,555:smac.runhistory.runhistory2epm.RunHistory2EPM4LogCost] Got cost of smaller/equal to 0. Replace by 0.000010 since we use log cost.
[CaTabRa] New model #31 trained:
val_accuracy: 0.973684
val_balanced_accuracy: 0.974359
train_accuracy: 0.935065
type: extra_trees
total_elapsed_time: 00:50
[WARNING] [2023-02-07 10:53:36,301:smac.runhistory.runhistory2epm.RunHistory2EPM4LogCost] Got cost of smaller/equal to 0. Replace by 0.000010 since we use log cost.
[CaTabRa] Final training statistics:
n_models_trained: 31
ensemble_val_accuracy: 1.0
[CaTabRa] Creating shap explainer
[CaTabRa] Initialized out-of-distribution detector of type Autoencoder
[CaTabRa] Fitting out-of-distribution detector...
Iteration 1, loss = 1.61104467
Iteration 2, loss = 1.43654067
Iteration 3, loss = 1.27854978
Iteration 4, loss = 1.13614818
Iteration 5, loss = 1.00836412
Iteration 6, loss = 0.89419353
Iteration 7, loss = 0.79261599
Iteration 8, loss = 0.70261024
Iteration 9, loss = 0.62316871
Iteration 10, loss = 0.55331055
Iteration 11, loss = 0.49209276
Iteration 12, loss = 0.43861932
Iteration 13, loss = 0.39204793
Iteration 14, loss = 0.35159479
Iteration 15, loss = 0.31653725
Iteration 16, loss = 0.28621472
Iteration 17, loss = 0.26002813
Iteration 18, loss = 0.23743812
Iteration 19, loss = 0.21796231
Iteration 20, loss = 0.20117202
Iteration 21, loss = 0.18668841
Iteration 22, loss = 0.17417852
Iteration 23, loss = 0.16349008
Iteration 24, loss = 0.15395270
Iteration 25, loss = 0.14576417
Iteration 26, loss = 0.13859644
Iteration 27, loss = 0.13228710
Iteration 28, loss = 0.12669742
Iteration 29, loss = 0.12170941
Iteration 30, loss = 0.11722327
Iteration 31, loss = 0.11315493
Iteration 32, loss = 0.10943403
Iteration 33, loss = 0.10600198
Iteration 34, loss = 0.10281034
Iteration 35, loss = 0.09981932
Iteration 36, loss = 0.09699653
Iteration 37, loss = 0.09431581
Iteration 38, loss = 0.09175631
Iteration 39, loss = 0.08930155
Iteration 40, loss = 0.08693878
Iteration 41, loss = 0.08465825
Iteration 42, loss = 0.08245272
Iteration 43, loss = 0.08031693
Iteration 44, loss = 0.07824726
Iteration 45, loss = 0.07624131
Iteration 46, loss = 0.07429766
Iteration 47, loss = 0.07241559
Iteration 48, loss = 0.07059486
Iteration 49, loss = 0.06883559
Iteration 50, loss = 0.06713804
Iteration 51, loss = 0.06550254
Iteration 52, loss = 0.06392941
Iteration 53, loss = 0.06241882
Iteration 54, loss = 0.06097080
Iteration 55, loss = 0.05958518
Iteration 56, loss = 0.05826153
Iteration 57, loss = 0.05699920
Iteration 58, loss = 0.05579728
Iteration 59, loss = 0.05465461
Iteration 60, loss = 0.05356982
Iteration 61, loss = 0.05254128
Iteration 62, loss = 0.05156719
Iteration 63, loss = 0.05064558
Iteration 64, loss = 0.04977432
Iteration 65, loss = 0.04895119
Iteration 66, loss = 0.04817387
Iteration 67, loss = 0.04743998
Iteration 68, loss = 0.04674713
Iteration 69, loss = 0.04609293
Iteration 70, loss = 0.04547500
Iteration 71, loss = 0.04489103
Iteration 72, loss = 0.04433879
Iteration 73, loss = 0.04381610
Iteration 74, loss = 0.04332094
Iteration 75, loss = 0.04285135
Iteration 76, loss = 0.04240554
Iteration 77, loss = 0.04198182
Iteration 78, loss = 0.04157865
Iteration 79, loss = 0.04119461
Iteration 80, loss = 0.04082844
Iteration 81, loss = 0.04047897
Iteration 82, loss = 0.04014516
Iteration 83, loss = 0.03982611
Iteration 84, loss = 0.03952099
Iteration 85, loss = 0.03922909
Iteration 86, loss = 0.03894978
Iteration 87, loss = 0.03868249
Iteration 88, loss = 0.03842674
Iteration 89, loss = 0.03818209
Iteration 90, loss = 0.03794816
Iteration 91, loss = 0.03772459
Iteration 92, loss = 0.03751107
Iteration 93, loss = 0.03730730
Iteration 94, loss = 0.03711299
Iteration 95, loss = 0.03692789
Iteration 96, loss = 0.03675174
Iteration 97, loss = 0.03658426
Iteration 98, loss = 0.03642521
Iteration 99, loss = 0.03627433
Iteration 100, loss = 0.03613136
Iteration 101, loss = 0.03599602
Iteration 102, loss = 0.03586806
Iteration 103, loss = 0.03574719
Iteration 104, loss = 0.03563315
Iteration 105, loss = 0.03552564
Iteration 106, loss = 0.03542441
Iteration 107, loss = 0.03532915
Iteration 108, loss = 0.03523961
Iteration 109, loss = 0.03515549
Iteration 110, loss = 0.03507654
Iteration 111, loss = 0.03500249
Iteration 112, loss = 0.03493308
Iteration 113, loss = 0.03486805
Iteration 114, loss = 0.03480716
Iteration 115, loss = 0.03475019
Iteration 116, loss = 0.03469690
Iteration 117, loss = 0.03464708
Iteration 118, loss = 0.03460052
Iteration 119, loss = 0.03455703
Iteration 120, loss = 0.03451642
Iteration 121, loss = 0.03447853
Iteration 122, loss = 0.03444317
Iteration 123, loss = 0.03441021
Iteration 124, loss = 0.03437948
Iteration 125, loss = 0.03435085
Iteration 126, loss = 0.03432420
Iteration 127, loss = 0.03429939
Iteration 128, loss = 0.03427632
Iteration 129, loss = 0.03425488
Iteration 130, loss = 0.03423495
Iteration 131, loss = 0.03421646
Iteration 132, loss = 0.03419930
Iteration 133, loss = 0.03418339
Iteration 134, loss = 0.03416866
Iteration 135, loss = 0.03415502
Iteration 136, loss = 0.03414240
Iteration 137, loss = 0.03413074
Iteration 138, loss = 0.03411998
Iteration 139, loss = 0.03411005
Iteration 140, loss = 0.03410089
Iteration 141, loss = 0.03409246
Iteration 142, loss = 0.03408469
Iteration 143, loss = 0.03407756
Iteration 144, loss = 0.03407100
Iteration 145, loss = 0.03406497
Iteration 146, loss = 0.03405945
Iteration 147, loss = 0.03405439
Iteration 148, loss = 0.03404975
Iteration 149, loss = 0.03404550
Iteration 150, loss = 0.03404162
Iteration 151, loss = 0.03403807
Iteration 152, loss = 0.03403483
Iteration 153, loss = 0.03403188
Iteration 154, loss = 0.03402918
Iteration 155, loss = 0.03402672
Iteration 156, loss = 0.03402448
Iteration 157, loss = 0.03402245
Training loss did not improve more than tol=0.000100 for 50 consecutive epochs. Stopping.
[CaTabRa] Out-of-distribution detector fitted.
[CaTabRa] ### Analysis finished at 2023-02-07 10:53:49.727518
[CaTabRa] ### Elapsed time: 0 days 00:01:05.258892
[CaTabRa] ### Output saved in /mnt/c/Users/amaletzk/Documents/CaTabRa/catabra/examples/multiclass_classification
[CaTabRa] ### Evaluation started at 2023-02-07 10:53:49.730334
[CaTabRa] Predicting out-of-distribution samples.
[CaTabRa] Saving descriptive statistics completed
[CaTabRa] Saving descriptive statistics completed
[CaTabRa] Evaluation results for not_test:
accuracy: 0.991304347826087
balanced_accuracy: 0.9914529914529915
[CaTabRa] Evaluation results for test:
accuracy: 0.9428571428571428
balanced_accuracy: 0.9393939393939394
[CaTabRa] ### Evaluation finished at 2023-02-07 10:53:52.687656
[CaTabRa] ### Elapsed time: 0 days 00:00:02.957322
[CaTabRa] ### Output saved in /mnt/c/Users/amaletzk/Documents/CaTabRa/catabra/examples/multiclass_classification/eval
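Note how the performance metrics reported during training differ from those in binary classification: models and ensembles are now scored by accuracy (val_accuracy, ensemble_val_accuracy) instead of ROC-AUC (val_roc_auc, ensemble_val_roc_auc). This is also reflected in the detailed performance reports in subdirectory eval/.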
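Accuracy and balanced accuracy in the training log differ slightly because balanced accuracy averages the per-class recall. The sketch below reproduces the recurring pair 0.973684 / 0.974359 under an assumed validation set of 38 samples with class sizes 12, 13, and 13 and a single misclassified sample. The class sizes are a guess that happens to match the logged numbers; they are not stated anywhere in the output.

```python
import numpy as np

# assumed validation labels: 12 + 13 + 13 = 38 samples (class sizes are a guess)
y_true = np.repeat([0, 1, 2], [12, 13, 13])

# perfect predictions except for one misclassified sample of class 2
y_pred = y_true.copy()
y_pred[-1] = 1

# accuracy = 37 / 38
accuracy = float(np.mean(y_pred == y_true))

# balanced accuracy = mean per-class recall = (1 + 1 + 12/13) / 3
recalls = [np.mean(y_pred[y_true == c] == c) for c in np.unique(y_true)]
balanced_accuracy = float(np.mean(recalls))

print(round(accuracy, 6), round(balanced_accuracy, 6))  # 0.973684 0.974359
```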
Multilabel Classification
Analyze data with a multilabel target, i.e., each sample belongs to an arbitrary subset of n >= 2 classes.
[28]:
# load dataset
from sklearn.datasets import fetch_openml
X_multilabel, y_multilabel = fetch_openml(data_id=40595, as_frame=True, return_X_y=True)
[37]:
# add target labels to DataFrame
X_multilabel = X_multilabel.join(y_multilabel == 'TRUE')
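The join above does two things at once: the comparison with 'TRUE' converts the string-valued label columns into booleans, and join() appends them to the feature table by index. The following sketch shows the same pattern on a made-up three-row table; the values are invented, and only the column names Att1, Beach, and Urban are borrowed from the dataset.

```python
import pandas as pd

# toy features and string-valued label columns (made-up data)
X = pd.DataFrame({'Att1': [0.1, 0.7, 0.4]})
y = pd.DataFrame({'Beach': ['TRUE', 'FALSE', 'TRUE'],
                  'Urban': ['FALSE', 'FALSE', 'TRUE']})

# element-wise comparison turns the strings into booleans; join aligns rows on the index
X = X.join(y == 'TRUE')

# in a multilabel target, a sample may carry any number of labels, including none
labels_per_sample = X[['Beach', 'Urban']].sum(axis=1)

print(X)
print(labels_per_sample.tolist())  # [1, 0, 2]
```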
[38]:
# split into training and test sets by adding a column with the corresponding values
# the column name is arbitrary; CaTabRa tries to infer which samples belong to which set from the column name and its values
X_multilabel['train'] = X_multilabel.index <= 0.8 * len(X_multilabel)
[39]:
X_multilabel.head()
[39]:
| | Att1 | Att2 | Att3 | Att4 | Att5 | Att6 | Att7 | Att8 | Att9 | Att10 | Att11 | Att12 | Att13 | Att14 | Att15 | Att16 | Att17 | Att18 | Att19 | Att20 | Att21 | Att22 | Att23 | Att24 | Att25 | Att26 | Att27 | Att28 | Att29 | Att30 | Att31 | Att32 | Att33 | Att34 | Att35 | Att36 | Att37 | Att38 | Att39 | Att40 | ... | Att262 | Att263 | Att264 | Att265 | Att266 | Att267 | Att268 | Att269 | Att270 | Att271 | Att272 | Att273 | Att274 | Att275 | Att276 | Att277 | Att278 | Att279 | Att280 | Att281 | Att282 | Att283 | Att284 | Att285 | Att286 | Att287 | Att288 | Att289 | Att290 | Att291 | Att292 | Att293 | Att294 | Beach | Sunset | FallFoliage | Field | Mountain | Urban | train |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.646467 | 0.666435 | 0.685047 | 0.699053 | 0.652746 | 0.407864 | 0.150309 | 0.535193 | 0.555689 | 0.580782 | 0.577094 | 0.390455 | 0.242458 | 0.170217 | 0.421797 | 0.428206 | 0.428277 | 0.490017 | 0.459252 | 0.350897 | 0.255987 | 0.310109 | 0.375018 | 0.437369 | 0.451752 | 0.508325 | 0.467347 | 0.567068 | 0.546262 | 0.566969 | 0.612951 | 0.621101 | 0.653561 | 0.694546 | 0.574777 | 0.710196 | 0.614510 | 0.590450 | 0.508313 | 0.645884 | ... | 0.136652 | 0.136285 | 0.127585 | 0.249868 | 0.545665 | 0.252143 | 0.261571 | 0.203095 | 0.172747 | 0.239030 | 0.309251 | 0.090241 | 0.048767 | 0.085062 | 0.072274 | 0.167601 | 0.094636 | 0.258751 | 0.092845 | 0.477150 | 0.224848 | 0.102568 | 0.329816 | 0.061538 | 0.049615 | 0.068962 | 0.653879 | 0.354982 | 0.124074 | 0.157332 | 0.247298 | 0.014025 | 0.029709 | True | False | False | False | True | False | True |
| 1 | 0.770156 | 0.767255 | 0.761053 | 0.745630 | 0.742231 | 0.688086 | 0.708416 | 0.757351 | 0.760633 | 0.740314 | 0.513377 | 0.600421 | 0.542340 | 0.439594 | 0.604272 | 0.624697 | 0.642823 | 0.424883 | 0.448578 | 0.318076 | 0.209851 | 0.570696 | 0.599071 | 0.556610 | 0.556215 | 0.653352 | 0.559962 | 0.473784 | 0.636677 | 0.653249 | 0.621813 | 0.613890 | 0.596795 | 0.596297 | 0.692224 | 0.634007 | 0.605896 | 0.594992 | 0.650470 | 0.582844 | ... | 0.097577 | 0.167246 | 0.193839 | 0.283507 | 0.190554 | 0.072342 | 0.111906 | 0.175488 | 0.178064 | 0.249890 | 0.085085 | 0.073259 | 0.133331 | 0.090761 | 0.138334 | 0.102932 | 0.406639 | 0.126982 | 0.046562 | 0.354085 | 0.199359 | 0.157326 | 0.051859 | 0.114123 | 0.160008 | 0.414088 | 0.361843 | 0.303399 | 0.176387 | 0.251454 | 0.137833 | 0.082672 | 0.036320 | True | False | False | False | False | True | True |
| 2 | 0.793984 | 0.772096 | 0.761820 | 0.762213 | 0.740569 | 0.734361 | 0.722677 | 0.849128 | 0.839607 | 0.812746 | 0.785767 | 0.760288 | 0.751835 | 0.754508 | 0.853808 | 0.857499 | 0.858505 | 0.864827 | 0.865957 | 0.867185 | 0.872483 | 0.955915 | 0.966291 | 0.968941 | 0.879657 | 0.716114 | 0.479571 | 0.402155 | 0.754620 | 0.775176 | 0.723823 | 0.676656 | 0.633313 | 0.552341 | 0.417900 | 0.622198 | 0.652387 | 0.648123 | 0.680452 | 0.662322 | ... | 0.060296 | 0.058945 | 0.052964 | 0.062245 | 0.075563 | 0.006149 | 0.004046 | 0.006033 | 0.181837 | 0.213608 | 0.122532 | 0.035184 | 0.025505 | 0.027821 | 0.353377 | 0.073733 | 0.048943 | 0.080248 | 0.074113 | 0.051372 | 0.024035 | 0.015971 | 0.028559 | 0.047596 | 0.038082 | 0.079977 | 0.004901 | 0.003460 | 0.006049 | 0.017166 | 0.051125 | 0.112506 | 0.083924 | True | False | False | False | False | False | True |
| 3 | 0.938563 | 0.949260 | 0.955621 | 0.966743 | 0.968649 | 0.869619 | 0.696925 | 0.953460 | 0.959631 | 0.966320 | 0.972766 | 0.916497 | 0.622508 | 0.530428 | 0.963539 | 0.972303 | 0.972980 | 0.945388 | 0.609497 | 0.514073 | 0.360757 | 0.804240 | 0.827367 | 0.813407 | 0.796413 | 0.753638 | 0.696435 | 0.520342 | 0.782931 | 0.774347 | 0.750613 | 0.706845 | 0.612971 | 0.647101 | 0.645833 | 0.736683 | 0.719352 | 0.643989 | 0.705878 | 0.773725 | ... | 0.008393 | 0.093743 | 0.105665 | 0.060825 | 0.025972 | 0.045153 | 0.039900 | 0.030980 | 0.448542 | 0.024508 | 0.024751 | 0.045848 | 0.020989 | 0.015197 | 0.209978 | 0.138788 | 0.031173 | 0.032565 | 0.034237 | 0.018757 | 0.082271 | 0.201563 | 0.043669 | 0.027527 | 0.016922 | 0.024174 | 0.036799 | 0.007694 | 0.009735 | 0.019267 | 0.031290 | 0.049780 | 0.090959 | True | False | False | False | False | False | True |
| 4 | 0.512130 | 0.524684 | 0.520020 | 0.504467 | 0.471209 | 0.417654 | 0.364292 | 0.562266 | 0.588592 | 0.584449 | 0.570074 | 0.551043 | 0.503925 | 0.447526 | 0.500117 | 0.539517 | 0.588721 | 0.600226 | 0.588937 | 0.562027 | 0.510786 | 0.465298 | 0.626580 | 0.649661 | 0.629969 | 0.574756 | 0.519651 | 0.445292 | 0.450048 | 0.742275 | 0.784539 | 0.903786 | 0.834243 | 0.766266 | 0.657113 | 0.276264 | 0.394086 | 0.610411 | 0.698119 | 0.743710 | ... | 0.595821 | 0.207690 | 0.028206 | 0.010644 | 0.010589 | 0.138157 | 0.094097 | 0.044848 | 0.036629 | 0.046537 | 0.090652 | 0.086531 | 0.293732 | 0.221770 | 0.094467 | 0.143500 | 0.186763 | 0.074600 | 0.043375 | 0.208570 | 0.188324 | 0.413413 | 0.387559 | 0.158730 | 0.023177 | 0.129994 | 0.167709 | 0.226580 | 0.218534 | 0.198151 | 0.238796 | 0.164270 | 0.184290 | True | False | False | False | False | False | True |
5 rows × 301 columns
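In a multilabel problem each sample may carry several labels at once. The first two rows displayed above, for instance, each have two positive labels. A small sketch that makes this explicit, with the label values of the five displayed rows typed in by hand (the variable names are illustrative only):

```python
LABELS = ['Beach', 'Sunset', 'FallFoliage', 'Field', 'Mountain', 'Urban']

# label values of the five rows displayed above, copied by hand
rows = [
    [True, False, False, False, True, False],
    [True, False, False, False, False, True],
    [True, False, False, False, False, False],
    [True, False, False, False, False, False],
    [True, False, False, False, False, False],
]

# names of the active labels for every sample
active_labels = [
    [name for name, flag in zip(LABELS, row) if flag]
    for row in rows
]
# number of active labels per sample
labels_per_sample = [len(active) for active in active_labels]

for i, active in enumerate(active_labels):
    print(i, active)
```

In multiclass classification, by contrast, every sample belongs to exactly one class, so a single target column suffices.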
The function analyze() is called just as before, except that more than one column is passed to classify. This tells CaTabRa to treat the task as multilabel classification.
[40]:
analyze(
X_multilabel, # table to analyze; can also be the path to a CSV/Excel/HDF5 file
classify=['Beach', 'Sunset', 'FallFoliage', 'Field', 'Mountain', 'Urban'], # names of columns containing the classification targets, one per label
split='train', # name of column containing information about the train-test split (optional)
time=3, # time budget for hyperparameter tuning, in minutes (optional)
out='multilabel_classification' # directory in which all results are saved
)
[CaTabRa] ### Analysis started at 2023-02-07 11:16:46.333073
[CaTabRa] Saving descriptive statistics completed
[CaTabRa] Using AutoML-backend auto-sklearn for multilabel_classification
[CaTabRa] Successfully loaded the following auto-sklearn add-on module(s): xgb
[WARNING] [2023-02-07 11:17:00,572:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 28 not found
[WARNING] [2023-02-07 11:17:00,572:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 265 not found
[WARNING] [2023-02-07 11:17:00,573:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 7 not found
[WARNING] [2023-02-07 11:17:00,573:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 220 not found
[WARNING] [2023-02-07 11:17:00,573:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 253 not found
[WARNING] [2023-02-07 11:17:00,573:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 355 not found
[WARNING] [2023-02-07 11:17:00,573:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 694 not found
[WARNING] [2023-02-07 11:17:00,573:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 40 not found
[WARNING] [2023-02-07 11:17:00,573:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 668 not found
[WARNING] [2023-02-07 11:17:00,573:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 579 not found
[WARNING] [2023-02-07 11:17:00,573:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 422 not found
[WARNING] [2023-02-07 11:17:00,574:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 546 not found
[WARNING] [2023-02-07 11:17:00,574:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 702 not found
[WARNING] [2023-02-07 11:17:00,574:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 270 not found
[WARNING] [2023-02-07 11:17:00,574:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 657 not found
[WARNING] [2023-02-07 11:17:00,574:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 238 not found
[WARNING] [2023-02-07 11:17:00,574:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 206 not found
[WARNING] [2023-02-07 11:17:00,574:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 377 not found
[WARNING] [2023-02-07 11:17:00,574:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 341 not found
[WARNING] [2023-02-07 11:17:00,575:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 235 not found
[WARNING] [2023-02-07 11:17:00,575:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 332 not found
[WARNING] [2023-02-07 11:17:00,575:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 61 not found
[WARNING] [2023-02-07 11:17:00,575:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 690 not found
[WARNING] [2023-02-07 11:17:00,576:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 608 not found
[WARNING] [2023-02-07 11:17:00,576:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 666 not found
[WARNING] [2023-02-07 11:17:00,576:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 367 not found
[WARNING] [2023-02-07 11:17:00,576:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 426 not found
[WARNING] [2023-02-07 11:17:00,576:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 262 not found
[WARNING] [2023-02-07 11:17:00,576:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 302 not found
[WARNING] [2023-02-07 11:17:00,576:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 282 not found
[WARNING] [2023-02-07 11:17:00,576:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 211 not found
[WARNING] [2023-02-07 11:17:00,576:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 701 not found
[WARNING] [2023-02-07 11:17:00,576:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 444 not found
[WARNING] [2023-02-07 11:17:00,576:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 506 not found
[WARNING] [2023-02-07 11:17:00,576:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 108 not found
[WARNING] [2023-02-07 11:17:00,577:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 460 not found
[WARNING] [2023-02-07 11:17:00,577:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 329 not found
[WARNING] [2023-02-07 11:17:00,577:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 617 not found
[WARNING] [2023-02-07 11:17:00,577:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 68 not found
[WARNING] [2023-02-07 11:17:00,577:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 143 not found
[WARNING] [2023-02-07 11:17:00,578:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 532 not found
[WARNING] [2023-02-07 11:17:00,578:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 659 not found
[WARNING] [2023-02-07 11:17:00,578:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 230 not found
[WARNING] [2023-02-07 11:17:00,579:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 138 not found
[WARNING] [2023-02-07 11:17:00,579:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 75 not found
[WARNING] [2023-02-07 11:17:00,579:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 521 not found
[WARNING] [2023-02-07 11:17:00,579:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 89 not found
[WARNING] [2023-02-07 11:17:00,579:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 315 not found
[WARNING] [2023-02-07 11:17:00,580:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 345 not found
[WARNING] [2023-02-07 11:17:00,580:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 194 not found
[WARNING] [2023-02-07 11:17:00,580:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 14 not found
[WARNING] [2023-02-07 11:17:00,580:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 358 not found
[WARNING] [2023-02-07 11:17:00,580:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 594 not found
[WARNING] [2023-02-07 11:17:00,581:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 129 not found
[WARNING] [2023-02-07 11:17:00,581:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 65 not found
[WARNING] [2023-02-07 11:17:00,581:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 83 not found
[WARNING] [2023-02-07 11:17:00,581:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 227 not found
[WARNING] [2023-02-07 11:17:00,581:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 605 not found
[WARNING] [2023-02-07 11:17:00,581:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 191 not found
[WARNING] [2023-02-07 11:17:00,582:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 277 not found
[WARNING] [2023-02-07 11:17:00,583:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 563 not found
[WARNING] [2023-02-07 11:17:00,583:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 124 not found
[WARNING] [2023-02-07 11:17:00,584:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 121 not found
[WARNING] [2023-02-07 11:17:00,584:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 683 not found
[WARNING] [2023-02-07 11:17:00,584:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 352 not found
[WARNING] [2023-02-07 11:17:00,584:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 434 not found
[WARNING] [2023-02-07 11:17:00,584:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 199 not found
[WARNING] [2023-02-07 11:17:00,584:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 445 not found
[WARNING] [2023-02-07 11:17:00,585:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 318 not found
The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.
The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.
[WARNING] [2023-02-07 11:17:00,585:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 363 not found
[WARNING] [2023-02-07 11:17:00,585:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 173 not found
[WARNING] [2023-02-07 11:17:00,585:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 647 not found
[WARNING] [2023-02-07 11:17:00,585:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 412 not found
[WARNING] [2023-02-07 11:17:00,585:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 131 not found
[WARNING] [2023-02-07 11:17:00,585:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 494 not found
[WARNING] [2023-02-07 11:17:00,585:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 500 not found
[WARNING] [2023-02-07 11:17:00,586:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 37 not found
[WARNING] [2023-02-07 11:17:00,586:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 17 not found
[WARNING] [2023-02-07 11:17:00,586:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 480 not found
[WARNING] [2023-02-07 11:17:00,586:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 524 not found
[WARNING] [2023-02-07 11:17:00,586:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 529 not found
[WARNING] [2023-02-07 11:17:00,586:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 403 not found
[WARNING] [2023-02-07 11:17:00,586:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 628 not found
[WARNING] [2023-02-07 11:17:00,586:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 288 not found
[WARNING] [2023-02-07 11:17:00,586:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 298 not found
[WARNING] [2023-02-07 11:17:00,586:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 171 not found
[WARNING] [2023-02-07 11:17:00,586:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 375 not found
[WARNING] [2023-02-07 11:17:00,586:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 59 not found
[WARNING] [2023-02-07 11:17:00,586:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 553 not found
[WARNING] [2023-02-07 11:17:00,587:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 621 not found
[WARNING] [2023-02-07 11:17:00,587:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 85 not found
[WARNING] [2023-02-07 11:17:00,588:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 291 not found
[WARNING] [2023-02-07 11:17:00,588:Client-AutoMLSMBO(1)::8d1e93df-a6d0-11ed-8083-00155da455e5] Configuration 273 not found
[CaTabRa] New ensemble fitted:
ensemble_val_f1_macro: 0.641285
n_constituent_models: 1
total_elapsed_time: 00:11
[CaTabRa] New model #1 trained:
val_f1_macro: 0.641285
train_f1_macro: 1.000000
type: random_forest
total_elapsed_time: 00:11
[CaTabRa] New ensemble fitted:
ensemble_val_f1_macro: 0.698102
n_constituent_models: 2
total_elapsed_time: 00:37
[CaTabRa] New model #2 trained:
val_f1_macro: 0.698102
train_f1_macro: 1.000000
type: mlp
total_elapsed_time: 00:36
[CaTabRa] New ensemble fitted:
ensemble_val_f1_macro: 0.761772
n_constituent_models: 2
total_elapsed_time: 01:21
[CaTabRa] New model #3 trained:
val_f1_macro: 0.761772
train_f1_macro: 0.917106
type: mlp
total_elapsed_time: 01:21
[CaTabRa] New ensemble fitted:
ensemble_val_f1_macro: 0.761772
n_constituent_models: 3
total_elapsed_time: 01:23
[CaTabRa] New model #4 trained:
val_f1_macro: 0.319453
train_f1_macro: 0.904445
type: liblinear_svc
total_elapsed_time: 01:22
[CaTabRa] New ensemble fitted:
ensemble_val_f1_macro: 0.761772
n_constituent_models: 4
total_elapsed_time: 01:26
[CaTabRa] New model #5 trained:
val_f1_macro: 0.700602
train_f1_macro: 0.869776
type: mlp
total_elapsed_time: 01:25
[CaTabRa] New ensemble fitted:
ensemble_val_f1_macro: 0.761772
n_constituent_models: 5
total_elapsed_time: 01:36
[CaTabRa] New model #6 trained:
val_f1_macro: 0.593701
train_f1_macro: 0.990705
type: random_forest
total_elapsed_time: 01:36
[CaTabRa] New ensemble fitted:
ensemble_val_f1_macro: 0.761772
n_constituent_models: 5
total_elapsed_time: 01:54
[CaTabRa] New model #7 trained:
val_f1_macro: 0.020042
train_f1_macro: 1.000000
type: random_forest
total_elapsed_time: 01:53
[CaTabRa] New ensemble fitted:
ensemble_val_f1_macro: 0.761772
n_constituent_models: 5
total_elapsed_time: 01:59
[CaTabRa] New model #8 trained:
val_f1_macro: 0.673939
train_f1_macro: 1.000000
type: k_nearest_neighbors
total_elapsed_time: 01:59
[CaTabRa] New ensemble fitted:
ensemble_val_f1_macro: 0.761772
n_constituent_models: 3
total_elapsed_time: 02:08
[CaTabRa] New model #9 trained:
val_f1_macro: 0.640255
train_f1_macro: 0.907360
type: extra_trees
total_elapsed_time: 02:07
[CaTabRa] New ensemble fitted:
ensemble_val_f1_macro: 0.761772
n_constituent_models: 4
total_elapsed_time: 02:17
[CaTabRa] New model #10 trained:
val_f1_macro: 0.594273
train_f1_macro: 0.984806
type: extra_trees
total_elapsed_time: 02:16
[CaTabRa] New ensemble fitted:
ensemble_val_f1_macro: 0.761772
n_constituent_models: 4
total_elapsed_time: 02:27
[CaTabRa] New model #11 trained:
val_f1_macro: 0.603230
train_f1_macro: 1.000000
type: random_forest
total_elapsed_time: 02:27
[CaTabRa] New ensemble fitted:
ensemble_val_f1_macro: 0.761772
n_constituent_models: 4
total_elapsed_time: 02:33
[CaTabRa] New model #12 trained:
val_f1_macro: 0.724957
train_f1_macro: 0.926656
type: mlp
total_elapsed_time: 02:33
[CaTabRa] New model #13 trained:
val_f1_macro: 0.215699
train_f1_macro: 0.662193
type: extra_trees
total_elapsed_time: 02:42
[CaTabRa] New model #14 trained:
val_f1_macro: 0.593483
train_f1_macro: 0.984792
type: random_forest
total_elapsed_time: 02:46
[CaTabRa] New ensemble fitted:
ensemble_val_f1_macro: 0.761772
n_constituent_models: 4
total_elapsed_time: 02:52
[CaTabRa] New model #15 trained:
val_f1_macro: 0.690480
train_f1_macro: 0.715259
type: mlp
total_elapsed_time: 02:52
[CaTabRa] Final training statistics:
n_models_trained: 15
ensemble_val_f1_macro: 0.7617723444629978
[CaTabRa] Creating shap explainer
[CaTabRa] Initialized out-of-distribution detector of type Autoencoder
[CaTabRa] Fitting out-of-distribution detector...
Iteration 1, loss = 0.05982325
Iteration 2, loss = 0.03488524
Iteration 3, loss = 0.02117650
Iteration 4, loss = 0.01949050
Iteration 5, loss = 0.01914386
Iteration 6, loss = 0.01906342
Iteration 7, loss = 0.01904390
Iteration 8, loss = 0.01904470
Iteration 9, loss = 0.01901837
Iteration 10, loss = 0.01901736
Iteration 11, loss = 0.01899080
Iteration 12, loss = 0.01898463
Iteration 13, loss = 0.01898274
Iteration 14, loss = 0.01896644
Iteration 15, loss = 0.01896694
Iteration 16, loss = 0.01898501
Iteration 17, loss = 0.01896577
Iteration 18, loss = 0.01895877
Iteration 19, loss = 0.01894492
Iteration 20, loss = 0.01894275
Iteration 21, loss = 0.01893161
Iteration 22, loss = 0.01893052
Iteration 23, loss = 0.01892016
Iteration 24, loss = 0.01892147
Iteration 25, loss = 0.01891416
Iteration 26, loss = 0.01892668
Iteration 27, loss = 0.01892885
Iteration 28, loss = 0.01891607
Iteration 29, loss = 0.01890401
Iteration 30, loss = 0.01889755
Iteration 31, loss = 0.01891291
Iteration 32, loss = 0.01894366
Iteration 33, loss = 0.01892826
Iteration 34, loss = 0.01890775
Iteration 35, loss = 0.01891401
Iteration 36, loss = 0.01889082
Iteration 37, loss = 0.01889566
Iteration 38, loss = 0.01888600
Iteration 39, loss = 0.01888517
Iteration 40, loss = 0.01887266
Iteration 41, loss = 0.01887094
Iteration 42, loss = 0.01887194
Iteration 43, loss = 0.01888870
Iteration 44, loss = 0.01891678
Iteration 45, loss = 0.01890816
Iteration 46, loss = 0.01889278
Iteration 47, loss = 0.01886188
Iteration 48, loss = 0.01885999
Iteration 49, loss = 0.01885912
Iteration 50, loss = 0.01886987
Iteration 51, loss = 0.01887337
Iteration 52, loss = 0.01885332
Iteration 53, loss = 0.01883346
Iteration 54, loss = 0.01887909
Iteration 55, loss = 0.01889231
Iteration 56, loss = 0.01883884
Training loss did not improve more than tol=0.000100 for 50 consecutive epochs. Stopping.
[CaTabRa] Out-of-distribution detector fitted.
[CaTabRa] ### Analysis finished at 2023-02-07 11:20:40.127400
[CaTabRa] ### Elapsed time: 0 days 00:03:53.794327
[CaTabRa] ### Output saved in /mnt/c/Users/amaletzk/Documents/CaTabRa/catabra/examples/multilabel_classification
[CaTabRa] ### Evaluation started at 2023-02-07 11:20:40.130001
[CaTabRa] Predicting out-of-distribution samples.
[CaTabRa] Saving descriptive statistics completed
[CaTabRa] Saving descriptive statistics completed
[CaTabRa] Evaluation results for train:
f1_macro @ 0.5: 0.9725546309829336
[CaTabRa] Evaluation results for not_train:
f1_macro @ 0.5: 0.3461939178627995
invalid value encountered in true_divide
invalid value encountered in true_divide
invalid value encountered in true_divide
invalid value encountered in true_divide
invalid value encountered in true_divide
invalid value encountered in true_divide
invalid value encountered in true_divide
No positive samples in y_true, true positive value should be meaningless and recall is set to 1 for all thresholds
invalid value encountered in true_divide
No positive samples in y_true, true positive value should be meaningless and recall is set to 1 for all thresholds
invalid value encountered in true_divide
No positive samples in y_true, true positive value should be meaningless and recall is set to 1 for all thresholds
[CaTabRa] ### Evaluation finished at 2023-02-07 11:21:24.398684
[CaTabRa] ### Elapsed time: 0 days 00:00:44.268683
[CaTabRa] ### Output saved in /mnt/c/Users/amaletzk/Documents/CaTabRa/catabra/examples/multilabel_classification/eval
Note again that the performance metrics reported during training (here, the macro-averaged F1 score) differ from those reported for binary and multiclass classification. The same holds for the detailed performance reports in subdirectory eval/.
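The metric `f1_macro @ 0.5` in the log above can be read as follows: predicted probabilities are binarized at threshold 0.5, an F1 score is computed for each label separately, and the per-label scores are averaged with equal weight. The sketch below implements this in plain Python on made-up data. It illustrates the metric and is not CaTabRa's actual implementation; in particular, using `>=` for thresholding and setting F1 to 0 for a label without any true or predicted positives are assumptions (the latter mirrors scikit-learn's default behaviour). The warning "No positive samples in y_true" in the log suggests that at least one label has no positive samples in an evaluated subset; the third label of the made-up data imitates this situation.

```python
def f1_binary(y_true, y_pred, zero_division=0.0):
    """F1 score of one label; `zero_division` is returned if F1 is undefined (0/0)."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t and p)
    fp = sum(1 for t, p in zip(y_true, y_pred) if not t and p)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t and not p)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom > 0 else zero_division


def f1_macro(y_true, y_prob, threshold=0.5):
    """Unweighted mean of per-label F1 scores after binarizing at `threshold`."""
    n_labels = len(y_true[0])
    scores = []
    for j in range(n_labels):
        true_j = [bool(row[j]) for row in y_true]
        pred_j = [row[j] >= threshold for row in y_prob]  # assumption: >= threshold
        scores.append(f1_binary(true_j, pred_j))
    return sum(scores) / n_labels


# made-up data: 4 samples, 3 labels; the third label has no positive samples
y_true = [[1, 0, 0], [1, 1, 0], [0, 1, 0], [0, 0, 0]]
y_prob = [[0.9, 0.2, 0.1], [0.6, 0.4, 0.3], [0.7, 0.8, 0.2], [0.1, 0.3, 0.4]]

score = f1_macro(y_true, y_prob)
print(round(score, 4))  # per-label F1 scores are 0.8, 2/3 and 0.0 -> 0.4889
```

A single label with an undefined or zero F1 score lowers the macro average considerably, which is worth keeping in mind when the test set contains only few samples of some labels.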
Regression
Analyze data with one or more regression targets.
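For regression, the training log further below reports `val_r2`, `val_mean_absolute_error` and `val_mean_squared_error` for every model. As a reminder of what these numbers measure, here is a plain-Python sketch of the three metrics under their standard definitions, evaluated on made-up values. The helper functions are local to this example (their names merely follow scikit-learn's naming) and are not taken from CaTabRa.

```python
def mean_absolute_error(y_true, y_pred):
    """Average absolute deviation between targets and predictions."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def mean_squared_error(y_true, y_pred):
    """Average squared deviation between targets and predictions."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


y_true = [100.0, 150.0, 200.0, 250.0]   # made-up targets
y_good = [110.0, 140.0, 210.0, 240.0]   # predictions that are off by 10 each
y_mean = [175.0] * 4                    # always predict the mean of the targets

print(mean_absolute_error(y_true, y_good))  # 10.0
print(mean_squared_error(y_true, y_good))   # 100.0
print(r2_score(y_true, y_mean))             # 0.0
```

A model that always predicts the mean of the targets it is evaluated on has an R² of exactly 0, and worse predictions yield negative values. Validation R² values close to 0 combined with a training R² close to 1 therefore hint at a model that does not generalize beyond its training data.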
[41]:
# load dataset
from sklearn.datasets import load_diabetes
X_regression, y_regression = load_diabetes(as_frame=True, return_X_y=True)
[43]:
# add target labels to DataFrame
X_regression['disease_progression'] = y_regression
[44]:
# split into training and test sets by adding a column with corresponding values
# the column name is arbitrary; CaTabRa tries to "guess" which samples belong to which set based on the column name and its values
X_regression['train'] = X_regression.index <= 0.8 * len(X_regression)
[45]:
X_regression.head()
[45]:
| | age | sex | bmi | bp | s1 | s2 | s3 | s4 | s5 | s6 | disease_progression | train |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.038076 | 0.050680 | 0.061696 | 0.021872 | -0.044223 | -0.034821 | -0.043401 | -0.002592 | 0.019908 | -0.017646 | 151.0 | True |
| 1 | -0.001882 | -0.044642 | -0.051474 | -0.026328 | -0.008449 | -0.019163 | 0.074412 | -0.039493 | -0.068330 | -0.092204 | 75.0 | True |
| 2 | 0.085299 | 0.050680 | 0.044451 | -0.005671 | -0.045599 | -0.034194 | -0.032356 | -0.002592 | 0.002864 | -0.025930 | 141.0 | True |
| 3 | -0.089063 | -0.044642 | -0.011595 | -0.036656 | 0.012191 | 0.024991 | -0.036038 | 0.034309 | 0.022692 | -0.009362 | 206.0 | True |
| 4 | 0.005383 | -0.044642 | -0.036385 | 0.021872 | 0.003935 | 0.015596 | 0.008142 | -0.002592 | -0.031991 | -0.046641 | 135.0 | True |
The function analyze() is again called as before; the only difference is that the keyword argument classify is replaced by regress.
[46]:
analyze(
X_regression, # table to analyze; can also be the path to a CSV/Excel/HDF5 file
regress=['disease_progression'], # name(s) of target column(s)
split='train', # name of column containing information about the train-test split (optional)
time=1, # time budget for hyperparameter tuning, in minutes (optional)
out='regression' # directory in which all results are saved
)
[CaTabRa] ### Analysis started at 2023-02-07 11:26:30.909830
[CaTabRa] Saving descriptive statistics completed
[CaTabRa] Using AutoML-backend auto-sklearn for regression
[CaTabRa] Successfully loaded the following auto-sklearn add-on module(s): xgb
The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.
The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.
[CaTabRa] New ensemble fitted:
ensemble_val_r2: 0.346560
n_constituent_models: 1
total_elapsed_time: 00:14
[CaTabRa] New model #1 trained:
val_r2: 0.346560
val_mean_absolute_error: 47.753839
val_mean_squared_error: 3699.549517
train_r2: 0.924566
type: random_forest
total_elapsed_time: 00:14
[CaTabRa] New model #2 trained:
val_r2: -0.000265
val_mean_absolute_error: 62.965668
val_mean_squared_error: 5663.151483
train_r2: 0.998678
type: gaussian_process
total_elapsed_time: 00:15
[CaTabRa] New ensemble fitted:
ensemble_val_r2: 0.347943
n_constituent_models: 2
total_elapsed_time: 00:19
[CaTabRa] New model #3 trained:
val_r2: 0.065712
val_mean_absolute_error: 60.908616
val_mean_squared_error: 5289.613941
train_r2: 0.997329
type: gaussian_process
total_elapsed_time: 00:19
[CaTabRa] New ensemble fitted:
ensemble_val_r2: 0.388678
n_constituent_models: 2
total_elapsed_time: 00:21
[CaTabRa] New model #4 trained:
val_r2: 0.378612
val_mean_absolute_error: 47.679539
val_mean_squared_error: 3518.081929
train_r2: 0.573299
type: ard_regression
total_elapsed_time: 00:21
[CaTabRa] New ensemble fitted:
ensemble_val_r2: 0.397726
n_constituent_models: 2
total_elapsed_time: 00:22
[CaTabRa] New model #5 trained:
val_r2: 0.368827
val_mean_absolute_error: 47.818353
val_mean_squared_error: 3573.482370
train_r2: 0.918655
type: gradient_boosting
total_elapsed_time: 00:22
[CaTabRa] New ensemble fitted:
ensemble_val_r2: 0.399933
n_constituent_models: 3
total_elapsed_time: 00:23
[CaTabRa] New model #6 trained:
val_r2: 0.348328
val_mean_absolute_error: 47.907515
val_mean_squared_error: 3689.537293
train_r2: 0.922852
type: gradient_boosting
total_elapsed_time: 00:23
[CaTabRa] New ensemble fitted:
ensemble_val_r2: 0.425789
n_constituent_models: 3
total_elapsed_time: 00:24
[CaTabRa] New model #7 trained:
val_r2: 0.415240
val_mean_absolute_error: 46.700819
val_mean_squared_error: 3310.708121
train_r2: 0.523863
type: sgd
total_elapsed_time: 00:24
[CaTabRa] New ensemble fitted:
ensemble_val_r2: 0.425789
n_constituent_models: 3
total_elapsed_time: 00:25
[CaTabRa] New model #8 trained:
val_r2: 0.401955
val_mean_absolute_error: 47.683143
val_mean_squared_error: 3385.920689
train_r2: 0.547418
type: gaussian_process
total_elapsed_time: 00:25
[CaTabRa] New ensemble fitted:
ensemble_val_r2: 0.425789
n_constituent_models: 3
total_elapsed_time: 00:34
[CaTabRa] New model #9 trained:
val_r2: 0.330392
val_mean_absolute_error: 48.786826
val_mean_squared_error: 3791.084115
train_r2: 1.000000
type: extra_trees
total_elapsed_time: 00:34
[CaTabRa] New model #10 trained:
val_r2: -0.000265
val_mean_absolute_error: 62.965668
val_mean_squared_error: 5663.151483
train_r2: 0.999580
type: gaussian_process
total_elapsed_time: 00:35
[CaTabRa] New model #11 trained:
val_r2: -0.000265
val_mean_absolute_error: 62.965668
val_mean_squared_error: 5663.151483
train_r2: 0.992597
type: gaussian_process
total_elapsed_time: 00:36
[CaTabRa] New ensemble fitted:
ensemble_val_r2: 0.425789
n_constituent_models: 3
total_elapsed_time: 00:37
[CaTabRa] New model #12 trained:
val_r2: 0.354427
val_mean_absolute_error: 47.265859
val_mean_squared_error: 3655.010389
train_r2: 1.000000
type: extra_trees
total_elapsed_time: 00:37
[CaTabRa] New model #13 trained:
val_r2: -0.000265
val_mean_absolute_error: 62.965668
val_mean_squared_error: 5663.151483
train_r2: 0.999705
type: gaussian_process
total_elapsed_time: 00:38
[CaTabRa] New model #14 trained:
val_r2: -0.000265
val_mean_absolute_error: 62.965668
val_mean_squared_error: 5663.151483
train_r2: 0.999234
type: gaussian_process
total_elapsed_time: 00:38
[CaTabRa] New ensemble fitted:
ensemble_val_r2: 0.425789
n_constituent_models: 3
total_elapsed_time: 00:40
[CaTabRa] New model #15 trained:
val_r2: 0.374720
val_mean_absolute_error: 46.745134
val_mean_squared_error: 3540.116141
train_r2: 0.977903
type: extra_trees
total_elapsed_time: 00:39
[CaTabRa] New ensemble fitted:
ensemble_val_r2: 0.437003
n_constituent_models: 5
total_elapsed_time: 00:41
[CaTabRa] New model #16 trained:
val_r2: 0.125783
val_mean_absolute_error: 57.840272
val_mean_squared_error: 4949.511448
train_r2: 0.999744
type: gaussian_process
total_elapsed_time: 00:41
[CaTabRa] New ensemble fitted:
ensemble_val_r2: 0.425789
n_constituent_models: 3
total_elapsed_time: 00:42
[CaTabRa] New model #17 trained:
val_r2: 0.387407
val_mean_absolute_error: 46.685701
val_mean_squared_error: 3468.285643
train_r2: 0.846457
type: extra_trees
total_elapsed_time: 00:42
[CaTabRa] New model #18 trained:
val_r2: 0.196697
val_mean_absolute_error: 50.261620
val_mean_squared_error: 4548.019091
train_r2: 0.645442
type: libsvm_svr
total_elapsed_time: 00:43
[CaTabRa] New ensemble fitted:
ensemble_val_r2: 0.428171
n_constituent_models: 4
total_elapsed_time: 00:43
[CaTabRa] New model #19 trained:
val_r2: 0.373981
val_mean_absolute_error: 46.977856
val_mean_squared_error: 3544.300674
train_r2: 0.596925
type: libsvm_svr
total_elapsed_time: 00:43
[CaTabRa] New ensemble fitted:
ensemble_val_r2: 0.428171
n_constituent_models: 4
total_elapsed_time: 00:44
[CaTabRa] New model #20 trained:
val_r2: 0.384921
val_mean_absolute_error: 48.585789
val_mean_squared_error: 3482.360456
train_r2: 0.528625
type: liblinear_svr
total_elapsed_time: 00:44
[CaTabRa] New ensemble fitted:
ensemble_val_r2: 0.427500
n_constituent_models: 4
total_elapsed_time: 00:45
[CaTabRa] New model #21 trained:
val_r2: 0.398428
val_mean_absolute_error: 46.894167
val_mean_squared_error: 3405.887542
train_r2: 0.560726
type: sgd
total_elapsed_time: 00:45
[CaTabRa] New model #22 trained:
val_r2: 0.152500
val_mean_absolute_error: 56.274715
val_mean_squared_error: 4798.245744
train_r2: 0.188230
type: libsvm_svr
total_elapsed_time: 00:45
[CaTabRa] New ensemble fitted:
ensemble_val_r2: 0.427500
n_constituent_models: 4
total_elapsed_time: 00:46
[CaTabRa] New model #23 trained:
val_r2: 0.392710
val_mean_absolute_error: 47.980951
val_mean_squared_error: 3438.262290
train_r2: 0.521770
type: ard_regression
total_elapsed_time: 00:46
[CaTabRa] New ensemble fitted:
ensemble_val_r2: 0.426474
n_constituent_models: 4
total_elapsed_time: 00:47
[CaTabRa] New model #24 trained:
val_r2: 0.412140
val_mean_absolute_error: 46.539725
val_mean_squared_error: 3328.259648
train_r2: 0.536904
type: ard_regression
total_elapsed_time: 00:47
[CaTabRa] New model #25 trained:
val_r2: 0.226133
val_mean_absolute_error: 50.241453
val_mean_squared_error: 4381.364850
train_r2: 0.648969
type: k_nearest_neighbors
total_elapsed_time: 00:50
[CaTabRa] New model #26 trained:
val_r2: 0.270252
val_mean_absolute_error: 52.848647
val_mean_squared_error: 4131.579680
train_r2: 0.338830
type: extra_trees
total_elapsed_time: 00:51
[CaTabRa] New ensemble fitted:
ensemble_val_r2: 0.426118
n_constituent_models: 4
total_elapsed_time: 00:52
[CaTabRa] New model #27 trained:
val_r2: 0.374823
val_mean_absolute_error: 48.787595
val_mean_squared_error: 3539.532322
train_r2: 0.537110
type: libsvm_svr
total_elapsed_time: 00:52
[CaTabRa] New model #28 trained:
val_r2: 0.292769
val_mean_absolute_error: 49.361707
val_mean_squared_error: 4004.093698
train_r2: 0.991042
type: adaboost
total_elapsed_time: 00:57
[CaTabRa] Final training statistics:
n_models_trained: 28
ensemble_val_r2: 0.4261181505761964
[CaTabRa] Creating shap explainer
[CaTabRa] Initialized out-of-distribution detector of type Autoencoder
[CaTabRa] Fitting out-of-distribution detector...
Iteration 1, loss = 0.25534312
Iteration 2, loss = 0.14401352
Iteration 3, loss = 0.09340346
Iteration 4, loss = 0.06996843
Iteration 5, loss = 0.05760558
Iteration 6, loss = 0.05068051
Iteration 7, loss = 0.04561450
Iteration 8, loss = 0.04126809
Iteration 9, loss = 0.03752003
Iteration 10, loss = 0.03445358
Iteration 11, loss = 0.03203054
Iteration 12, loss = 0.03026330
Iteration 13, loss = 0.02909024
Iteration 14, loss = 0.02833517
Iteration 15, loss = 0.02799340
Iteration 16, loss = 0.02790263
Iteration 17, loss = 0.02800053
Iteration 18, loss = 0.02816715
Iteration 19, loss = 0.02833172
Iteration 20, loss = 0.02843938
Iteration 21, loss = 0.02845107
Iteration 22, loss = 0.02837497
Iteration 23, loss = 0.02823050
Iteration 24, loss = 0.02807415
Iteration 25, loss = 0.02792351
Iteration 26, loss = 0.02779502
Iteration 27, loss = 0.02770411
Iteration 28, loss = 0.02762860
Iteration 29, loss = 0.02758793
Iteration 30, loss = 0.02757642
Iteration 31, loss = 0.02758407
Iteration 32, loss = 0.02758500
Iteration 33, loss = 0.02758951
Iteration 34, loss = 0.02760158
Iteration 35, loss = 0.02759723
Iteration 36, loss = 0.02758483
Iteration 37, loss = 0.02759659
Iteration 38, loss = 0.02757972
Iteration 39, loss = 0.02757239
Iteration 40, loss = 0.02757055
Iteration 41, loss = 0.02756013
Iteration 42, loss = 0.02755758
Iteration 43, loss = 0.02755460
Iteration 44, loss = 0.02755353
Iteration 45, loss = 0.02754807
Iteration 46, loss = 0.02755253
Iteration 47, loss = 0.02755638
Iteration 48, loss = 0.02755479
Iteration 49, loss = 0.02755098
Iteration 50, loss = 0.02755119
Iteration 51, loss = 0.02755607
Iteration 52, loss = 0.02755033
Iteration 53, loss = 0.02755640
Iteration 54, loss = 0.02755645
Iteration 55, loss = 0.02757725
Iteration 56, loss = 0.02755438
Iteration 57, loss = 0.02754601
Iteration 58, loss = 0.02755765
Iteration 59, loss = 0.02755344
Iteration 60, loss = 0.02754537
Iteration 61, loss = 0.02755214
Iteration 62, loss = 0.02755442
Iteration 63, loss = 0.02754713
Iteration 64, loss = 0.02756334
Iteration 65, loss = 0.02755607
Iteration 66, loss = 0.02756124
Iteration 67, loss = 0.02756447
Iteration 68, loss = 0.02754924
Iteration 69, loss = 0.02755174
Iteration 70, loss = 0.02755024
Iteration 71, loss = 0.02754852
Iteration 72, loss = 0.02754915
Iteration 73, loss = 0.02755081
Iteration 74, loss = 0.02755050
Iteration 75, loss = 0.02755403
Iteration 76, loss = 0.02755376
Iteration 77, loss = 0.02754660
Training loss did not improve more than tol=0.000100 for 50 consecutive epochs. Stopping.
[CaTabRa] Out-of-distribution detector fitted.
[CaTabRa] ### Analysis finished at 2023-02-07 11:27:39.991968
[CaTabRa] ### Elapsed time: 0 days 00:01:09.082138
[CaTabRa] ### Output saved in /mnt/c/Users/amaletzk/Documents/CaTabRa/catabra/examples/regression
[CaTabRa] ### Evaluation started at 2023-02-07 11:27:39.994605
[CaTabRa] Predicting out-of-distribution samples.
[CaTabRa] Saving descriptive statistics completed
[CaTabRa] Saving descriptive statistics completed
[CaTabRa] Evaluation results for train:
r2: 0.5506768595078299
mean_absolute_error: 41.25392512008969
mean_squared_error: 2602.6718459220265
[CaTabRa] Evaluation results for not_train:
r2: 0.5397624701720876
mean_absolute_error: 43.66707866842096
mean_squared_error: 2978.0636912322097
[CaTabRa] ### Evaluation finished at 2023-02-07 11:27:41.190400
[CaTabRa] ### Elapsed time: 0 days 00:00:01.195795
[CaTabRa] ### Output saved in /mnt/c/Users/amaletzk/Documents/CaTabRa/catabra/examples/regression/eval
For regression, the performance metrics reported during training (R², mean absolute error, and mean squared error) differ from those reported for classification. The detailed performance reports in the subdirectory eval/ use the same regression metrics.
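To make the three regression metrics in the log above (r2, mean_absolute_error, mean_squared_error) more concrete, the following minimal sketch computes them from their textbook definitions with NumPy. It is not CaTabRa's internal implementation; the function name `regression_metrics` and the toy values are assumptions made purely for illustration.

```python
import numpy as np


def regression_metrics(y_true, y_pred):
    """Compute R^2, MAE and MSE from their standard definitions (illustrative helper)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    residuals = y_true - y_pred

    mae = np.mean(np.abs(residuals))   # mean absolute error
    mse = np.mean(residuals ** 2)      # mean squared error

    # R^2 = 1 - (residual sum of squares) / (total sum of squares)
    ss_res = np.sum(residuals ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    r2 = 1.0 - ss_res / ss_tot

    return {"r2": r2, "mean_absolute_error": mae, "mean_squared_error": mse}


# toy data, made up for this example (not taken from the dataset used above)
y_true = [3.0, -0.5, 2.0, 7.0]
y_pred = [2.5, 0.0, 2.0, 8.0]

metrics = regression_metrics(y_true, y_pred)
for name, value in metrics.items():
    print(f"{name}: {value:.6f}")
```

An R² of 1 corresponds to a perfect fit, and 0 corresponds to always predicting the mean of the target. The two error metrics are expressed in the units (or squared units) of the target, which is why their magnitudes in the log are much larger than the R² values.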