Calculates Leave-One-Covariate-Out (LOCO) scores. Despite the name, this implementation can leave out one or more features at a time.
Details
LOCO measures feature importance by comparing model performance with and without each feature. For each feature, the model is retrained without that feature, and the resulting performance difference (reduced_model_loss - full_model_loss) is the feature's importance. Higher values indicate more important features, since leaving them out degrades performance more.
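For illustration, the following is a minimal sketch of this computation for a single feature using mlr3 directly; the feature name "important1" is just an example, and LOCO$new() handles resampling, refitting, and aggregation across all features internally:

library(mlr3)
library(mlr3learners)

task = tgen("friedman1")$generate(n = 200)
split = partition(task, ratio = 0.67)
learner = lrn("regr.ranger", num.trees = 50)
measure = msr("regr.mse")

# Loss of the full model on the holdout set
learner$train(task, row_ids = split$train)
full_model_loss = learner$predict(task, row_ids = split$test)$score(measure)

# Loss of the reduced model, refit without "important1"
task_reduced = task$clone()$select(setdiff(task$feature_names, "important1"))
learner$train(task_reduced, row_ids = split$train)
reduced_model_loss = learner$predict(task_reduced, row_ids = split$test)$score(measure)

# LOCO importance of "important1"
reduced_model_loss - full_model_loss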
References
Lei J, G'Sell M, Rinaldo A, Tibshirani RJ, Wasserman L (2018). “Distribution-Free Predictive Inference for Regression.” Journal of the American Statistical Association, 113(523), 1094–1111. ISSN 0162-1459, doi:10.1080/01621459.2017.1307116.
Super classes
xplainfi::FeatureImportanceMethod
-> xplainfi::LeaveOutIn
-> LOCO
Methods
Method new()
Creates a new instance of this R6 class.
Usage
LOCO$new(
  task,
  learner,
  measure,
  resampling = NULL,
  features = NULL,
  iters_refit = 1L,
  obs_loss = FALSE,
  aggregation_fun = median
)
Arguments
task
(mlr3::Task) Task to compute importance for.
learner
(mlr3::Learner) Learner to use for prediction.
measure
(mlr3::Measure) Measure to use for scoring.
resampling
(mlr3::Resampling) Resampling strategy. Defaults to holdout.
features
(character()) Features to compute importance for. Defaults to all features.
iters_refit
(integer(1): 1L) Number of refit iterations per resampling iteration.
obs_loss
(logical(1): FALSE) Whether to use observation-wise loss calculation (original LOCO formulation). If FALSE, uses aggregated scores.
aggregation_fun
(function) Function to aggregate observation-wise losses when obs_loss = TRUE. Defaults to median for the original LOCO formulation.
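To make the two modes concrete, here is an illustrative sketch with made-up per-observation losses; the names loss_ref and loss_feature mirror the columns of the $obs_losses field shown in the examples below:

set.seed(1)
# Hypothetical per-observation losses of the full (reference) model
# and of the reduced model with one feature left out
loss_ref = abs(rnorm(100))
loss_feature = abs(rnorm(100, mean = 0.3))

# obs_loss = FALSE: aggregate each model's losses first, then take the difference
mean(loss_feature) - mean(loss_ref)

# obs_loss = TRUE (original LOCO formulation): take per-observation
# differences first, then aggregate with aggregation_fun (default: median)
median(loss_feature - loss_ref)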
Examples
library(mlr3learners)
task = tgen("friedman1")$generate(n = 200)
loco = LOCO$new(
  task = task,
  learner = lrn("regr.ranger", num.trees = 50),
  measure = msr("regr.mse")
)
#> ℹ No <Resampling> provided
#> Using `resampling = rsmp("holdout")` with default `ratio = 0.67`.
loco$compute()
#> Key: <feature>
#> feature importance
#> <char> <num>
#> 1: important1 1.77642119
#> 2: important2 0.78483613
#> 3: important3 0.07920417
#> 4: important4 3.62372184
#> 5: important5 0.61506361
#> 6: unimportant1 0.12979459
#> 7: unimportant2 -0.01680615
#> 8: unimportant3 0.15576714
#> 9: unimportant4 -0.12916100
#> 10: unimportant5 0.36595416
# Using observation-wise losses to compute the median instead
loco_obsloss = LOCO$new(
  task = task,
  learner = lrn("regr.ranger", num.trees = 50),
  measure = msr("regr.mae"), # to use absolute differences observation-wise
  obs_loss = TRUE,
  aggregation_fun = median
)
#> ℹ No <Resampling> provided
#> Using `resampling = rsmp("holdout")` with default `ratio = 0.67`.
loco_obsloss$compute()
#> Key: <feature>
#> feature importance
#> <char> <num>
#> 1: important1 0.443246431
#> 2: important2 0.304168859
#> 3: important3 0.111510773
#> 4: important4 0.805161323
#> 5: important5 0.097861187
#> 6: unimportant1 -0.020882300
#> 7: unimportant2 0.145543352
#> 8: unimportant3 0.003488681
#> 9: unimportant4 0.031361178
#> 10: unimportant5 0.045435103
loco_obsloss$obs_losses
#> row_ids feature iteration iter_refit truth response_ref
#> <int> <char> <int> <int> <num> <num>
#> 1: 1 important1 1 1 13.690317 13.820347
#> 2: 1 important2 1 1 13.690317 13.820347
#> 3: 1 important3 1 1 13.690317 13.820347
#> 4: 1 important4 1 1 13.690317 13.820347
#> 5: 1 important5 1 1 13.690317 13.820347
#> ---
#> 666: 195 unimportant1 1 1 6.371902 9.950791
#> 667: 195 unimportant2 1 1 6.371902 9.950791
#> 668: 195 unimportant3 1 1 6.371902 9.950791
#> 669: 195 unimportant4 1 1 6.371902 9.950791
#> 670: 195 unimportant5 1 1 6.371902 9.950791
#> response_feature loss_ref loss_feature obs_diff
#> <num> <num> <num> <num>
#> 1: 13.306309 0.1300299 0.3840078 0.25397785
#> 2: 13.529533 0.1300299 0.1607837 0.03075379
#> 3: 14.141133 0.1300299 0.4508165 0.32078658
#> 4: 15.158180 0.1300299 1.4678634 1.33783347
#> 5: 12.552168 0.1300299 1.1381483 1.00811837
#> ---
#> 666: 9.193108 3.5788897 2.8212058 -0.75768388
#> 667: 9.066730 3.5788897 2.6948286 -0.88406104
#> 668: 9.204020 3.5788897 2.8321183 -0.74677132
#> 669: 9.335935 3.5788897 2.9640331 -0.61485659
#> 670: 8.895666 3.5788897 2.5237638 -1.05512585
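Given the aggregation described above, the reported importances can be reproduced from $obs_losses; with a single resampling and refit iteration, the per-feature median of obs_diff should match the importance column returned by $compute():

# Per-feature median of the observation-wise loss differences
loco_obsloss$obs_losses[, .(importance = median(obs_diff)), by = feature]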