Calculates Leave-One-Covariate-Out (LOCO) scores. Despite the name, this implementation can leave out one or more features at a time.
Details
LOCO measures feature importance by comparing model performance with and without each feature. For each feature, the model is refit without it, and the difference in loss (reduced_model_loss - full_model_loss) quantifies that feature's importance. Higher values indicate more important features.
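The core computation can be sketched in base R, independent of this package: fit the full model, refit it with one feature left out, and take the difference in test loss. This is a minimal illustration with hypothetical toy data and an `lm` learner, not the package's implementation (which handles resampling, refit iterations, and arbitrary mlr3 learners):

```r
set.seed(1)
# Toy regression data: y depends strongly on x1, not at all on x2
n <- 200
x1 <- rnorm(n)
x2 <- rnorm(n)
dat <- data.frame(y = 3 * x1 + rnorm(n), x1 = x1, x2 = x2)
train <- 1:150
test <- 151:200

# Mean squared error of a fitted model on held-out data
mse <- function(model, newdata) mean((newdata$y - predict(model, newdata))^2)

full <- lm(y ~ x1 + x2, data = dat[train, ])
full_loss <- mse(full, dat[test, ])

# LOCO score per feature: refit without it, subtract the full-model loss
loco_score <- sapply(c("x1", "x2"), function(feat) {
  kept <- setdiff(c("x1", "x2"), feat)
  reduced <- lm(reformulate(kept, response = "y"), data = dat[train, ])
  mse(reduced, dat[test, ]) - full_loss
})
loco_score  # x1 scores high; x2 scores near zero
```

Dropping `x1` removes most of the signal, so its score is large and positive, while dropping the irrelevant `x2` barely changes the test loss.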
References
Lei, Jing, Max G'Sell, Alessandro Rinaldo, Ryan J. Tibshirani, and Larry Wasserman (2018). "Distribution-Free Predictive Inference for Regression." Journal of the American Statistical Association, 113(523), 1094-1111. ISSN 0162-1459, doi:10.1080/01621459.2017.1307116.
Super classes
xplainfi::FeatureImportanceMethod
-> xplainfi::LeaveOutIn
-> LOCO
Methods
Method new()
Creates a new instance of this R6 class.
Usage
LOCO$new(
  task,
  learner,
  measure,
  resampling = NULL,
  features = NULL,
  iters_refit = 1L,
  obs_loss = FALSE,
  aggregation_fun = median
)
Arguments
task
(mlr3::Task) Task to compute importance for.
learner
(mlr3::Learner) Learner to use for prediction.
measure
(mlr3::Measure) Measure to use for scoring.
resampling
(mlr3::Resampling) Resampling strategy. Defaults to holdout.
features
(character()) Features to compute importance for. Defaults to all features.
iters_refit
(integer(1): 1L) Number of refit iterations per resampling iteration.
obs_loss
(logical(1): FALSE) Whether to use observation-wise loss calculation (original LOCO formulation). If FALSE, uses aggregated scores.
aggregation_fun
(function) Function to aggregate observation-wise losses when obs_loss = TRUE. Defaults to median for the original LOCO formulation.
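With obs_loss = TRUE, the loss is computed per test observation for both the full and the reduced model, and aggregation_fun summarizes the per-observation differences. A hedged base-R sketch of what this aggregation amounts to, using absolute error and toy data (illustrative only, not the package internals):

```r
set.seed(42)
# Toy data: y depends on x1 but not x2
n <- 100
x1 <- rnorm(n)
x2 <- rnorm(n)
dat <- data.frame(y = 2 * x1 + rnorm(n), x1 = x1, x2 = x2)
train <- 1:70
test <- 71:100

full <- lm(y ~ x1 + x2, data = dat[train, ])
reduced <- lm(y ~ x2, data = dat[train, ])  # leave out x1

# Observation-wise absolute errors on the test set
loss_full <- abs(dat$y[test] - predict(full, dat[test, ]))
loss_reduced <- abs(dat$y[test] - predict(reduced, dat[test, ]))

# Aggregate the per-observation differences with the median
# (corresponds to obs_loss = TRUE, aggregation_fun = median)
importance_x1 <- median(loss_reduced - loss_full)
```

The median makes the score robust to a few test observations where the reduced model happens to predict well or badly by chance, which is why the original LOCO formulation uses it.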
Examples
library(mlr3learners)
task = tgen("friedman1")$generate(n = 200)
loco = LOCO$new(
  task = task,
  learner = lrn("regr.ranger", num.trees = 50),
  measure = msr("regr.mse"),
  obs_loss = TRUE
)
#> ℹ No <Resampling> provided, using holdout resampling with default ratio.
loco$compute()
#> Key: <feature>
#> feature importance
#> <char> <num>
#> 1: important1 2.94732045
#> 2: important2 0.86723577
#> 3: important3 0.01139204
#> 4: important4 3.64957000
#> 5: important5 0.59223158
#> 6: unimportant1 -0.48550724
#> 7: unimportant2 -0.33015697
#> 8: unimportant3 -0.14008650
#> 9: unimportant4 -0.15561077
#> 10: unimportant5 -0.67586552
# Using observation-wise losses to compute the median instead
loco_obsloss = LOCO$new(
  task = task,
  learner = lrn("regr.ranger", num.trees = 50),
  measure = msr("regr.mae"), # absolute differences observation-wise
  obs_loss = TRUE,
  aggregation_fun = median
)
#> ℹ No <Resampling> provided, using holdout resampling with default ratio.
loco_obsloss$compute()
#> Key: <feature>
#> feature importance
#> <char> <num>
#> 1: important1 0.73135175
#> 2: important2 0.22281111
#> 3: important3 0.03009412
#> 4: important4 0.93552712
#> 5: important5 0.07754968
#> 6: unimportant1 -0.18921770
#> 7: unimportant2 -0.17120105
#> 8: unimportant3 -0.06660810
#> 9: unimportant4 -0.11810734
#> 10: unimportant5 -0.08932307
loco_obsloss$obs_losses
#> row_ids feature iteration iter_refit truth response_ref
#> <int> <char> <int> <int> <num> <num>
#> 1: 1 important1 1 1 5.177457 9.124809
#> 2: 1 important2 1 1 5.177457 9.124809
#> 3: 1 important3 1 1 5.177457 9.124809
#> 4: 1 important4 1 1 5.177457 9.124809
#> 5: 1 important5 1 1 5.177457 9.124809
#> ---
#> 666: 197 unimportant1 1 1 20.228495 16.597960
#> 667: 197 unimportant2 1 1 20.228495 16.597960
#> 668: 197 unimportant3 1 1 20.228495 16.597960
#> 669: 197 unimportant4 1 1 20.228495 16.597960
#> 670: 197 unimportant5 1 1 20.228495 16.597960
#> response_feature loss_ref loss_feature obs_diff
#> <num> <num> <num> <num>
#> 1: 10.776419 3.947352 5.598962 1.65160984
#> 2: 10.222532 3.947352 5.045075 1.09772321
#> 3: 9.142916 3.947352 3.965459 0.01810729
#> 4: 11.481945 3.947352 6.304488 2.35713592
#> 5: 10.464568 3.947352 5.287111 1.33975936
#> ---
#> 666: 17.042869 3.630535 3.185626 -0.44490956
#> 667: 17.047377 3.630535 3.181118 -0.44941741
#> 668: 17.722991 3.630535 2.505504 -1.12503138
#> 669: 17.227836 3.630535 3.000659 -0.62987634
#> 670: 17.080910 3.630535 3.147585 -0.48295066