Calculates Leave-One-Covariate-Out (LOCO) scores. Despite the name, this implementation can leave out one or more features at a time.

Details

LOCO measures feature importance by comparing model performance with and without each feature. For each feature, the model is retrained without that feature, and the performance difference (reduced_model_loss - full_model_loss) is the feature's importance. A positive difference means the model predicts worse without the feature, so higher values indicate more important features; values near zero or below suggest the feature does not improve predictions.
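The refit-and-compare idea above can be sketched in base R. This is a minimal illustration under stated assumptions, not the package's implementation: it assumes a linear model (`lm`), a single holdout split, and mean squared error as the loss. The helper `loco_sketch()` and the simulated data are hypothetical.

```r
# Minimal LOCO sketch in base R (illustrative only, not the package internals).
set.seed(1)
n = 300
dat = data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))
dat$y = 3 * dat$x1 + 1 * dat$x2 + rnorm(n, sd = 0.5) # x3 is pure noise

# Single holdout split: 200 training rows, 100 test rows
train_ids = sample(n, size = 200)
train = dat[train_ids, ]
test = dat[-train_ids, ]

mse = function(truth, response) mean((truth - response)^2)

# Hypothetical helper: `drop` names one or several features to leave out.
loco_sketch = function(drop, train, test, target = "y") {
  all_features = setdiff(names(train), target)
  # Full model uses every feature; reduced model is refit without `drop`
  full = lm(reformulate(all_features, response = target), data = train)
  reduced = lm(
    reformulate(setdiff(all_features, drop), response = target),
    data = train
  )
  loss_full = mse(test[[target]], predict(full, newdata = test))
  loss_reduced = mse(test[[target]], predict(reduced, newdata = test))
  # Importance: reduced_model_loss - full_model_loss
  loss_reduced - loss_full
}

importance = sapply(c("x1", "x2", "x3"), loco_sketch, train = train, test = test)
print(round(importance, 3))

# Leaving out a group of features at once
print(loco_sketch(c("x1", "x2"), train = train, test = test))
```

In this sketch the noise feature `x3` should score close to zero, and its score can be slightly negative because refitting without an irrelevant feature sometimes reduces test loss by chance.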

References

Lei, Jing; G'Sell, Max; Rinaldo, Alessandro; Tibshirani, Ryan J.; Wasserman, Larry (2018). “Distribution-Free Predictive Inference for Regression.” Journal of the American Statistical Association, 113(523), 1094–1111. ISSN 0162-1459, doi:10.1080/01621459.2017.1307116.

Methods

Method new()

Creates a new instance of this R6 class.

Usage

LOCO$new(
  task,
  learner,
  measure,
  resampling = NULL,
  features = NULL,
  iters_refit = 1L,
  obs_loss = FALSE,
  aggregation_fun = median
)

Arguments

task

(mlr3::Task) Task to compute importance for.

learner

(mlr3::Learner) Learner to use for prediction.

measure

(mlr3::Measure) Measure to use for scoring.

resampling

(mlr3::Resampling) Resampling strategy. Defaults to holdout.

features

(character()) Features to compute importance for. Defaults to all features.

iters_refit

(integer(1): 1L) Number of refit iterations per resampling iteration.

obs_loss

(logical(1): FALSE) Whether to compute importance from observation-wise losses (the original LOCO formulation). If FALSE, importance is computed from aggregated scores.

aggregation_fun

(function) Function used to aggregate observation-wise loss differences when obs_loss = TRUE. Defaults to median, as in the original LOCO formulation.
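The observation-wise variant controlled by obs_loss and aggregation_fun can be sketched in base R as follows. This is a hedged illustration rather than the package's code: it assumes absolute error as the per-observation loss, a linear model, and one holdout split. The helper `loco_obs_sketch()` and the simulated data are hypothetical; the variable names loss_ref, loss_feature, and obs_diff mirror the columns shown in the obs_losses output.

```r
# Observation-wise LOCO sketch in base R (illustrative only).
set.seed(2)
n = 300
dat = data.frame(x1 = rnorm(n), x2 = rnorm(n))
dat$y = 2 * dat$x1 + rnorm(n, sd = 0.5) # x2 is pure noise

train_ids = sample(n, size = 200)
train = dat[train_ids, ]
test = dat[-train_ids, ]

# Hypothetical helper: per-observation loss differences, then aggregation.
loco_obs_sketch = function(feature, train, test, aggregation_fun = median) {
  full = lm(y ~ ., data = train)
  # Refit on the training data with `feature` removed
  reduced = lm(y ~ ., data = train[, setdiff(names(train), feature), drop = FALSE])
  loss_ref = abs(test$y - predict(full, newdata = test))
  loss_feature = abs(test$y - predict(reduced, newdata = test))
  obs_diff = loss_feature - loss_ref
  aggregation_fun(obs_diff)
}

imp_median = sapply(c("x1", "x2"), loco_obs_sketch, train = train, test = test)
imp_mean = sapply(
  c("x1", "x2"), loco_obs_sketch,
  train = train, test = test, aggregation_fun = mean
)
print(rbind(median = imp_median, mean = imp_mean))
```

Aggregating with the median makes the score less sensitive to a few test observations with very large loss differences than the mean would be, which is one reason to prefer it for observation-wise losses.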


Method clone()

The objects of this class are cloneable with this method.

Usage

LOCO$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Examples

library(mlr3learners)
task = tgen("friedman1")$generate(n = 200)
loco = LOCO$new(
  task = task,
  learner = lrn("regr.ranger", num.trees = 50),
  measure = msr("regr.mse"),
  obs_loss = TRUE
)
#>  No <Resampling> provided, using holdout resampling with default ratio.
loco$compute()
#> Key: <feature>
#>          feature  importance
#>           <char>       <num>
#>  1:   important1  2.94732045
#>  2:   important2  0.86723577
#>  3:   important3  0.01139204
#>  4:   important4  3.64957000
#>  5:   important5  0.59223158
#>  6: unimportant1 -0.48550724
#>  7: unimportant2 -0.33015697
#>  8: unimportant3 -0.14008650
#>  9: unimportant4 -0.15561077
#> 10: unimportant5 -0.67586552

# Observation-wise absolute errors (regr.mae), aggregated with an explicitly specified median
loco_obsloss = LOCO$new(
  task = task,
  learner = lrn("regr.ranger", num.trees = 50),
  measure = msr("regr.mae"), # to use absolute differences observation-wise
  obs_loss = TRUE,
  aggregation_fun = median
)
#>  No <Resampling> provided, using holdout resampling with default ratio.
loco_obsloss$compute()
#> Key: <feature>
#>          feature  importance
#>           <char>       <num>
#>  1:   important1  0.73135175
#>  2:   important2  0.22281111
#>  3:   important3  0.03009412
#>  4:   important4  0.93552712
#>  5:   important5  0.07754968
#>  6: unimportant1 -0.18921770
#>  7: unimportant2 -0.17120105
#>  8: unimportant3 -0.06660810
#>  9: unimportant4 -0.11810734
#> 10: unimportant5 -0.08932307
loco_obsloss$obs_losses
#>      row_ids      feature iteration iter_refit     truth response_ref
#>        <int>       <char>     <int>      <int>     <num>        <num>
#>   1:       1   important1         1          1  5.177457     9.124809
#>   2:       1   important2         1          1  5.177457     9.124809
#>   3:       1   important3         1          1  5.177457     9.124809
#>   4:       1   important4         1          1  5.177457     9.124809
#>   5:       1   important5         1          1  5.177457     9.124809
#>  ---                                                                 
#> 666:     197 unimportant1         1          1 20.228495    16.597960
#> 667:     197 unimportant2         1          1 20.228495    16.597960
#> 668:     197 unimportant3         1          1 20.228495    16.597960
#> 669:     197 unimportant4         1          1 20.228495    16.597960
#> 670:     197 unimportant5         1          1 20.228495    16.597960
#>      response_feature loss_ref loss_feature    obs_diff
#>                 <num>    <num>        <num>       <num>
#>   1:        10.776419 3.947352     5.598962  1.65160984
#>   2:        10.222532 3.947352     5.045075  1.09772321
#>   3:         9.142916 3.947352     3.965459  0.01810729
#>   4:        11.481945 3.947352     6.304488  2.35713592
#>   5:        10.464568 3.947352     5.287111  1.33975936
#>  ---                                                   
#> 666:        17.042869 3.630535     3.185626 -0.44490956
#> 667:        17.047377 3.630535     3.181118 -0.44941741
#> 668:        17.722991 3.630535     2.505504 -1.12503138
#> 669:        17.227836 3.630535     3.000659 -0.62987634
#> 670:        17.080910 3.630535     3.147585 -0.48295066