
Calculates Leave-One-Covariate-Out (LOCO) scores. Despite the name, this implementation can leave out one or more features at a time.

Details

LOCO measures feature importance by comparing model performance with and without each feature. For each feature, the model is retrained without that feature, and the difference in performance (reduced_model_loss - full_model_loss) quantifies the feature's importance. Higher values indicate more important features.
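
For illustration, the following is a minimal sketch of this idea using mlr3 directly, computing the LOCO score of a single feature by hand. The feature name and holdout split are chosen arbitrarily here and are not part of this class's implementation.

library(mlr3)
library(mlr3learners)

task = tgen("friedman1")$generate(n = 200)
learner = lrn("regr.ranger", num.trees = 50)
resampling = rsmp("holdout")$instantiate(task)

# Full model: train on the holdout training set, score on the test set
learner$train(task, row_ids = resampling$train_set(1))
loss_full = learner$predict(task, row_ids = resampling$test_set(1))$score(msr("regr.mse"))

# Reduced model: drop one feature, retrain, and score again
task_reduced = task$clone()$select(setdiff(task$feature_names, "important1"))
learner$train(task_reduced, row_ids = resampling$train_set(1))
loss_reduced = learner$predict(task_reduced, row_ids = resampling$test_set(1))$score(msr("regr.mse"))

# LOCO importance of "important1": reduced_model_loss - full_model_loss
loss_reduced - loss_full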

References

Lei, Jing, G'Sell, Max, Rinaldo, Alessandro, Tibshirani, Ryan J., Wasserman, Larry (2018). “Distribution-Free Predictive Inference for Regression.” Journal of the American Statistical Association, 113(523), 1094–1111. ISSN 0162-1459, doi:10.1080/01621459.2017.1307116.

Methods

Public methods



Method new()

Creates a new instance of this R6 class.

Usage

LOCO$new(
  task,
  learner,
  measure,
  resampling = NULL,
  features = NULL,
  iters_refit = 1L,
  obs_loss = FALSE,
  aggregation_fun = median
)

Arguments

task

(mlr3::Task) Task to compute importance for.

learner

(mlr3::Learner) Learner to use for prediction.

measure

(mlr3::Measure) Measure to use for scoring.

resampling

(mlr3::Resampling) Resampling strategy. Defaults to holdout.

features

(character()) Features to compute importance for. Defaults to all features.

iters_refit

(integer(1): 1L) Number of refit iterations per resampling iteration.

obs_loss

(logical(1): FALSE) Whether to use observation-wise loss calculation (the original LOCO formulation) instead of aggregated scores; see the sketch below this argument list.

aggregation_fun

(function) Function to aggregate observation-wise losses when obs_loss = TRUE. Defaults to median for original LOCO formulation.

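To make the distinction between the two modes concrete, here is a small sketch with made-up per-observation losses; the numbers are purely illustrative and not produced by this class, and the mean below merely stands in for the measure's own aggregation.

# Hypothetical per-observation losses, for illustration only
loss_full    = c(0.5, 1.2, 0.3, 2.0)  # losses of the full model
loss_reduced = c(0.9, 1.1, 0.8, 2.6)  # losses after refitting without the feature

# obs_loss = FALSE: compare the aggregated losses
mean(loss_reduced) - mean(loss_full)

# obs_loss = TRUE: aggregate the observation-wise differences,
# e.g. with the median as in the original LOCO formulation
median(loss_reduced - loss_full)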

Method clone()

The objects of this class are cloneable with this method.

Usage

LOCO$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Examples

library(mlr3learners)
task = tgen("friedman1")$generate(n = 200)
loco = LOCO$new(
  task = task,
  learner = lrn("regr.ranger", num.trees = 50),
  measure = msr("regr.mse")
)
#>  No <Resampling> provided
#> Using `resampling = rsmp("holdout")` with default `ratio = 0.67`.
loco$compute()
#> Key: <feature>
#>          feature  importance
#>           <char>       <num>
#>  1:   important1  1.77642119
#>  2:   important2  0.78483613
#>  3:   important3  0.07920417
#>  4:   important4  3.62372184
#>  5:   important5  0.61506361
#>  6: unimportant1  0.12979459
#>  7: unimportant2 -0.01680615
#>  8: unimportant3  0.15576714
#>  9: unimportant4 -0.12916100
#> 10: unimportant5  0.36595416

# Using observation-wise losses to compute the median instead
loco_obsloss = LOCO$new(
  task = task,
  learner = lrn("regr.ranger", num.trees = 50),
  measure = msr("regr.mae"), # to use absolute differences observation-wise
  obs_loss = TRUE,
  aggregation_fun = median
)
#>  No <Resampling> provided
#> Using `resampling = rsmp("holdout")` with default `ratio = 0.67`.
loco_obsloss$compute()
#> Key: <feature>
#>          feature   importance
#>           <char>        <num>
#>  1:   important1  0.443246431
#>  2:   important2  0.304168859
#>  3:   important3  0.111510773
#>  4:   important4  0.805161323
#>  5:   important5  0.097861187
#>  6: unimportant1 -0.020882300
#>  7: unimportant2  0.145543352
#>  8: unimportant3  0.003488681
#>  9: unimportant4  0.031361178
#> 10: unimportant5  0.045435103
loco_obsloss$obs_losses
#>      row_ids      feature iteration iter_refit     truth response_ref
#>        <int>       <char>     <int>      <int>     <num>        <num>
#>   1:       1   important1         1          1 13.690317    13.820347
#>   2:       1   important2         1          1 13.690317    13.820347
#>   3:       1   important3         1          1 13.690317    13.820347
#>   4:       1   important4         1          1 13.690317    13.820347
#>   5:       1   important5         1          1 13.690317    13.820347
#>  ---                                                                 
#> 666:     195 unimportant1         1          1  6.371902     9.950791
#> 667:     195 unimportant2         1          1  6.371902     9.950791
#> 668:     195 unimportant3         1          1  6.371902     9.950791
#> 669:     195 unimportant4         1          1  6.371902     9.950791
#> 670:     195 unimportant5         1          1  6.371902     9.950791
#>      response_feature  loss_ref loss_feature    obs_diff
#>                 <num>     <num>        <num>       <num>
#>   1:        13.306309 0.1300299    0.3840078  0.25397785
#>   2:        13.529533 0.1300299    0.1607837  0.03075379
#>   3:        14.141133 0.1300299    0.4508165  0.32078658
#>   4:        15.158180 0.1300299    1.4678634  1.33783347
#>   5:        12.552168 0.1300299    1.1381483  1.00811837
#>  ---                                                    
#> 666:         9.193108 3.5788897    2.8212058 -0.75768388
#> 667:         9.066730 3.5788897    2.6948286 -0.88406104
#> 668:         9.204020 3.5788897    2.8321183 -0.74677132
#> 669:         9.335935 3.5788897    2.9640331 -0.61485659
#> 670:         8.895666 3.5788897    2.5237638 -1.05512585
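
As a sanity check, the reported importance values can be reproduced from the observation-wise results by aggregating obs_diff per feature; this sketch assumes the importance column corresponds to the median of obs_diff within each feature (with a single holdout iteration, as above).

# Sketch: aggregate the observation-wise differences per feature ourselves
library(data.table)
loco_obsloss$obs_losses[, .(importance = median(obs_diff)), by = feature]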