Calculates Leave-One-Covariate-In (LOCI) scores. Despite the name, this implementation can leave in one or more features at a time.
Details
LOCI measures feature importance by training models on only one feature (or feature subset) at a time and comparing their performance to a featureless baseline model, which makes the optimal constant prediction. The importance is calculated as (featureless_model_loss - single_feature_loss). Positive values mean the model trained on only that feature (or subset) achieves lower loss than the baseline. Negative values mean it does worse than the baseline.
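The importance formula can be illustrated without mlr3. The following is a minimal base-R sketch, not the xplainfi implementation. It assumes squared-error loss, a single holdout split, and lm() standing in for the mlr3 learner. The helper names mse and loci_score are hypothetical. The "both" entry shows the feature-subset case mentioned above.

```r
# Minimal base-R sketch of the LOCI idea (NOT the xplainfi implementation).
# importance = loss(featureless baseline) - loss(model trained on one feature or subset)
# Assumptions: squared-error loss, one holdout split, lm() as the learner.
set.seed(1)
n <- 200
x1 <- runif(n)
x2 <- runif(n)                      # pure noise feature
y <- 10 * x1 + rnorm(n)
dat <- data.frame(x1 = x1, x2 = x2, y = y)

train <- sample(n, size = round(2 / 3 * n))
test <- setdiff(seq_len(n), train)

# Hypothetical helper: mean squared error
mse <- function(truth, pred) mean((truth - pred)^2)

# Featureless baseline: under squared error the optimal constant is the training mean
baseline_loss <- mse(dat$y[test], mean(dat$y[train]))

# Hypothetical helper: fit on the given feature(s) only, score on the test rows
loci_score <- function(features) {
  fit <- lm(reformulate(features, response = "y"), data = dat[train, ])
  pred <- predict(fit, newdata = dat[test, ])
  baseline_loss - mse(dat$y[test], pred)
}

scores <- sapply(list(x1 = "x1", x2 = "x2", both = c("x1", "x2")), loci_score)
scores
```

Here x1 alone gets a large positive score. The noise feature x2 scores near zero and can come out negative, because a model fit on noise can do slightly worse than the constant baseline.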
Super classes
xplainfi::FeatureImportanceMethod -> xplainfi::LeaveOutIn -> LOCI
Methods
Method new()
Creates a new instance of this R6 class.
Usage
LOCI$new(
  task,
  learner,
  measure,
  resampling = NULL,
  features = NULL,
  iters_refit = 1L,
  obs_loss = FALSE,
  aggregation_fun = median
)
Arguments
task
(mlr3::Task) Task to compute importance for.
learner
(mlr3::Learner) Learner to use for prediction.
measure
(mlr3::Measure) Measure to use for scoring.
resampling
(mlr3::Resampling) Resampling strategy. Defaults to holdout.
features
(character()) Features to compute importance for. Defaults to all features.
iters_refit
(integer(1)) Number of refit iterations per resampling iteration.
obs_loss
(logical(1)) Whether to use observation-wise loss calculation (analogous to LOCO) when supported by the measure. If FALSE (default), uses aggregated scores.
aggregation_fun
(function) Function to aggregate observation-wise losses when obs_loss = TRUE. Defaults to median, analogous to LOCO.
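To show why the choice of aggregation_fun matters when obs_loss = TRUE, here is a small base-R sketch. It assumes, as an illustration only, that the per-observation importance is the featureless loss minus the single-feature loss for each test observation (squared error), and that this vector is then summarized by aggregation_fun. The helper obs_loci is hypothetical and not part of xplainfi.

```r
# Hypothetical helper sketching observation-wise aggregation (assumed LOCO-style):
# per-observation loss difference, summarized by aggregation_fun.
obs_loci <- function(truth, pred_baseline, pred_feature, aggregation_fun = median) {
  loss_baseline <- (truth - pred_baseline)^2  # squared error per test observation
  loss_feature  <- (truth - pred_feature)^2
  aggregation_fun(loss_baseline - loss_feature)
}

truth <- c(1, 2, 3, 4, 10)
pred_baseline <- rep(mean(truth), length(truth))  # constant prediction (mean = 4)
pred_feature  <- c(1.5, 2, 2.5, 4, 9)

obs_loci(truth, pred_baseline, pred_feature)                          # median of differences
obs_loci(truth, pred_baseline, pred_feature, aggregation_fun = mean)  # mean of differences
```

The per-observation differences are 8.75, 4, 0.75, 0 and 35. The median (4) is not pulled by the single large difference, while the mean (9.7) is, which is one reason a robust default such as median can be preferable.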
Examples
library(mlr3)
library(mlr3learners) # provides lrn("regr.ranger")
library(xplainfi)
task = tgen("friedman1")$generate(n = 200)
loci = LOCI$new(
  task = task,
  learner = lrn("regr.ranger", num.trees = 50),
  measure = msr("regr.mse")
)
#> ℹ No <Resampling> provided, using holdout resampling with default ratio.
loci$compute()
#> Key: <feature>
#> feature importance
#> <char> <num>
#> 1: important1 3.551881
#> 2: important2 -6.699716
#> 3: important3 -4.446604
#> 4: important4 5.771280
#> 5: important5 -7.030814
#> 6: unimportant1 -12.061976
#> 7: unimportant2 -4.970430
#> 8: unimportant3 -10.786875
#> 9: unimportant4 -11.829766
#> 10: unimportant5 -6.400764