Feature Importance Method Class

References

Nadeau, Claude, Bengio, Yoshua (2003). “Inference for the Generalization Error.” Machine Learning, 52(3), 239–281. ISSN 1573-0565, doi:10.1023/A:1024068626366.

Molnar, Christoph, Freiesleben, Timo, König, Gunnar, Herbinger, Julia, Reisinger, Tim, Casalicchio, Giuseppe, Wright, Marvin N., Bischl, Bernd (2023). “Relating the Partial Dependence Plot and Permutation Feature Importance to the Data Generating Process.” In Longo, Luca (ed.), Explainable Artificial Intelligence, 456–479. ISBN 978-3-031-44064-9, doi:10.1007/978-3-031-44064-9_24.

Public fields

label

(character(1)) Method label.

task

(mlr3::Task)

learner

(mlr3::Learner)

measure

(mlr3::Measure)

resampling

(mlr3::Resampling), instantiated upon construction.

resample_result

(mlr3::ResampleResult) of the original learner and task, used for baseline scores.

features

(character: NULL) Features of interest. By default, importances are computed for every feature in task, but this can optionally be restricted to a subset of one or more features. Ignored if groups is specified.

groups

(list: NULL) A (named) list of features (names or indices as in task). If groups is specified, features is ignored. Importances are calculated for one group of features at a time, e.g., in PFI the whole group of features (rather than a single feature) is permuted at each step. Analogously, in WVIM each group of features is left out (or in) for each model refit. Not all methods support groups (e.g., SAGE). See FIXME: vignette or examples; a short sketch is also given at the end of this field list.

param_set

(paradox::ps())

predictions

(data.table) Feature-specific prediction objects provided for some methods (PFI, WVIM). Contains columns for feature of interest, resampling iteration, refit or perturbation iteration, and mlr3::Prediction objects.
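
As referenced in the groups field above, a grouped specification might look like the following (a sketch on a toy task; the group names and the choice of PFI are purely illustrative):

library(mlr3)

PFI$new(
  task = tsk("mtcars"),
  learner = lrn("regr.rpart"),
  measure = msr("regr.mse"),
  # Features are grouped into named groups; importances are then reported per group
  groups = list(
    engine = c("cyl", "disp", "hp"),
    weight = c("wt", "drat")
  )
)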

Methods


Method new()

Creates a new instance of this R6 class. This is typically intended for use by derived classes.

Usage

FeatureImportanceMethod$new(
  task,
  learner,
  measure,
  resampling = NULL,
  features = NULL,
  groups = NULL,
  param_set = paradox::ps(),
  label
)

Arguments

task, learner, measure, resampling, features, groups, param_set, label

Used to set the corresponding fields.


Method compute()

Compute feature importance scores

Usage

FeatureImportanceMethod$compute(store_backends = TRUE)

Arguments

store_backends

(logical(1): TRUE) Whether to store backends.


Method importance()

Get aggregated importance scores. The stored measure object's aggregator (default: mean) is used to aggregate importance scores across resampling iterations and, depending on the method used, across permutations (PerturbationImportance) or refits (LOCO).

Usage

FeatureImportanceMethod$importance(
  relation = NULL,
  standardize = FALSE,
  ci_method = c("none", "raw", "nadeau_bengio", "quantile"),
  conf_level = 0.95,
  ...
)

Arguments

relation

(character(1)) How to relate perturbed scores to originals ("difference" or "ratio"). If NULL, uses the stored parameter value. This is only applicable for methods where importance is based on some relation between baseline and post-modification loss, i.e. PerturbationImportance methods such as PFI or WVIM / LOCO (see the sketch after this argument list). Not available for SAGE methods.

standardize

(logical(1): FALSE) If TRUE, importances are standardized by the highest score so all scores fall in [-1, 1].

ci_method

(character(1): "none") Variance estimation method to use, defaulting to omitting variance estimation ("none"). If "raw", uncorrected variance estimates are provided purely for informative purposes with invalid (too narrow) confidence intervals. If "nadeau_bengio", variance correction is performed according to Nadeau & Bengio (2003) as suggested by Molnar et al. (2023). If "quantile", empirical quantiles are used to construct confidence-like intervals. These methods are model-agnostic and rely on suitable resamplings, e.g. subsampling with 15 repeats for "nadeau_bengio". See details.

conf_level

(numeric(1): 0.95) Confidence level to use for confidence interval construction when ci_method != "none".

...

Additional arguments passed to specialized methods, if any.
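
As a point of reference for relation (a sketch; the notation is assumed here, not taken from the package), the importance of a feature or group j relates the post-modification loss to the baseline loss as

\mathrm{importance}_j = L^{\text{post}}_j - L^{\text{baseline}} \quad (\texttt{"difference"}) \qquad \text{or} \qquad \mathrm{importance}_j = L^{\text{post}}_j / L^{\text{baseline}} \quad (\texttt{"ratio"})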

Details

Variance estimates for importance scores are biased due to the resampling procedure. Molnar et al. (2023) suggest using the variance correction factor of n2/n1 proposed by Nadeau & Bengio (2003), where n2 and n1 are the sizes of the test and training set, respectively. This should be combined with approximately 15 iterations of either bootstrapping or subsampling.
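
As a sketch of the resulting correction (notation assumed here, not from the package documentation): with J resampling iterations, per-iteration importances with mean \bar{s} and sample variance \hat{\sigma}^2, the corrected standard error and confidence interval are

\widehat{\mathrm{SE}} = \sqrt{\hat{\sigma}^2 \left( \frac{1}{J} + \frac{n_2}{n_1} \right)}, \qquad \bar{s} \pm t_{J-1,\, 1-\alpha/2} \cdot \widehat{\mathrm{SE}}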

The use of bootstrapping in this context can lead to problematic information leakage when combined with learners that perform bootstrapping themselves, e.g., Random Forest learners. In such cases, observations may be used as training and test instances simultaneously, leading to erroneous performance estimates.

An approach that yields imperfect but improved variance estimates could be:

library(mlr3)
library(mlr3learners) # provides lrn("regr.ranger")

pfi = PFI$new(
  task = sim_dgp_interactions(n = 1000),
  learner = lrn("regr.ranger", num.trees = 100),
  measure = msr("regr.mse"),
  # Subsampling instead of bootstrapping due to the bootstrap inside RF
  resampling = rsmp("subsampling", repeats = 15),
  iters_perm = 5
)

iters_perm = 5 in this context only improves the stability of the PFI estimate within each resampling iteration, whereas rsmp("subsampling", repeats = 15) is used to account for learner variance and necessitates the variance correction factor.
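
Building on the object constructed above (assuming it was assigned to pfi as sketched), the corrected intervals could then be obtained via:

pfi$compute()
pfi$importance(ci_method = "nadeau_bengio", conf_level = 0.95)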

This approach can in principle also be applied to CFI and RFI, but beware that a conditional sampler such as ARFSampler also needs to be trained on data, which would need to be taken into account by the variance estimation method. Likewise, the "nadeau_bengio" correction was recommended by Molnar et al. for use with PFI, so its use with LOCO or MarginalSAGE is experimental.

Note that even if measure uses an aggregator function that is not the mean, variance estimation currently will always use mean() and var().

Returns

(data.table) Aggregated importance scores with columns "feature" and "importance", and, depending on ci_method, also "se", "conf_lower", and "conf_upper".


Method obs_loss()

Calculate observation-wise importance scores.

Requires that $compute() has been run and that measure is decomposable, i.e. has an observation-wise loss (Measure$obs_loss()) associated with it. This is not the case for measures such as classif.auc, which is not decomposable.

Usage

FeatureImportanceMethod$obs_loss(relation = NULL)

Arguments

relation

(character(1)) How to relate perturbed scores to originals ("difference" or "ratio"). If NULL, uses the stored parameter value. This is only applicable for methods where importance is based on some relation between baseline and post-modification loss, i.e. PerturbationImportance methods such as PFI or WVIM / LOCO. Not available for SAGE methods.
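
One way to check whether a given measure provides an observation-wise loss before calling this method is to inspect its obs_loss field (a sketch; field semantics follow mlr3):

library(mlr3)
# Decomposable: MSE has an observation-wise loss function
msr("regr.mse")$obs_loss
# Not decomposable: AUC provides no observation-wise loss
msr("classif.auc")$obs_loss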


Method reset()

Resets all stored fields populated by $compute(): $resample_result, $scores, $obs_losses, and $predictions.

Usage

FeatureImportanceMethod$reset()


Method print()

Print importance scores

Usage

FeatureImportanceMethod$print(...)

Arguments

...

Passed to print()


Method scores()

Calculate importance scores for each resampling iteration and sub-iteration (e.g., iter_rsmp in PFI).

Iteration-wise importances are computed on the fly depending on the chosen relation ("difference" or "ratio"), to avoid re-computation when only a different relation is needed.

Usage

FeatureImportanceMethod$scores(relation = NULL)

Arguments

relation

(character(1)) How to relate perturbed scores to originals ("difference" or "ratio"). If NULL, uses the stored parameter value. This is only applicable for methods where importance is based on some relation between baseline and post-modification loss, i.e. PerturbationImportance methods such as PFI or WVIM / LOCO. Not available for SAGE methods.
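
For example, per-iteration scores can be inspected under either relation without triggering recomputation (a sketch, assuming a PFI object pfi on which $compute() has already been run):

pfi$scores(relation = "difference")
pfi$scores(relation = "ratio")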


Method clone()

The objects of this class are cloneable with this method.

Usage

FeatureImportanceMethod$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.