
Base class for SAGE (Shapley Additive Global Importance), a feature importance method based on Shapley values with marginalization. This is an abstract class; use MarginalSAGE or ConditionalSAGE instead.

Details

SAGE uses Shapley values to distribute the model's total predictive performance fairly among all features. Unlike perturbation-based methods, SAGE removes the features outside a coalition by marginalizing them, i.e. integrating the model's predictions over their distribution. In practice this integral is approximated by averaging predictions over a reference dataset.
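The permutation-sampling estimator behind SAGE (Covert et al., 2020) can be sketched in base R as follows. This is a minimal illustration under stated assumptions, not the xplainfi implementation: the toy data, the `lm()` model, squared-error loss, and the helper names `coalition_loss()` and `sage_values()` are all hypothetical. Features outside a coalition are marginalized by pairing each test row with every reference row and averaging the predictions.

```r
# Hypothetical sketch of permutation-sampling SAGE with marginal imputation.
# Not xplainfi internals: data, model, loss and helper names are assumptions.
set.seed(1)
n <- 200
X <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))
y <- 2 * X$x1 + X$x2 + rnorm(n, sd = 0.1)  # x3 is irrelevant by construction
fit <- lm(y ~ x1 + x2 + x3, data = cbind(X, y = y))

reference <- X[1:100, ]    # reference data (here: the "training" rows)
test <- X[101:200, ]       # evaluation rows
y_test <- y[101:200]

# Loss when only the features in `coalition` are known; all other features
# are marginalized by averaging predictions over the reference rows.
coalition_loss <- function(coalition) {
  n_test <- nrow(test)
  n_ref <- nrow(reference)
  expanded <- reference[rep(seq_len(n_ref), times = n_test), , drop = FALSE]
  test_rows <- rep(seq_len(n_test), each = n_ref)
  for (f in coalition) expanded[[f]] <- test[[f]][test_rows]
  preds <- predict(fit, newdata = expanded)
  avg_pred <- tapply(preds, test_rows, mean)  # one marginalized prediction per test row
  mean((y_test - avg_pred)^2)
}

# Average marginal contribution (loss reduction) over sampled permutations
sage_values <- function(features, n_permutations) {
  values <- setNames(numeric(length(features)), features)
  empty_loss <- coalition_loss(character(0))   # evaluated once
  for (p in seq_len(n_permutations)) {
    perm <- sample(features)
    prev_loss <- empty_loss
    for (i in seq_along(perm)) {
      loss <- coalition_loss(perm[seq_len(i)])  # add features one at a time
      values[perm[i]] <- values[perm[i]] + (prev_loss - loss)
      prev_loss <- loss
    }
  }
  values / n_permutations
}

vals <- sage_values(names(X), n_permutations = 5)
print(round(vals, 3))
```

Because the contributions within one permutation telescope, the values sum to the loss reduction from the empty to the full coalition, which is the efficiency property of Shapley values.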

References

Covert, Ian, Lundberg, Scott M., Lee, Su-In (2020). “Understanding Global Feature Contributions With Additive Importance Measures.” In Advances in Neural Information Processing Systems, volume 33, 17212–17223. https://proceedings.neurips.cc/paper/2020/hash/c7bf0b7c1a86d5eb3be2c722cf2cf746-Abstract.html.

Super class

xplainfi::FeatureImportanceMethod -> SAGE

Public fields

n_permutations

(integer(1)) Number of permutations to sample.

reference_data

(data.table) Reference dataset for marginalization.

sampler

(FeatureSampler) Sampler object for marginalization.

Methods



Method new()

Creates a new instance of the SAGE class.

Usage

SAGE$new(
  task,
  learner,
  measure,
  resampling = NULL,
  features = NULL,
  n_permutations = 10L,
  reference_data = NULL,
  batch_size = NULL,
  sampler = NULL,
  max_reference_size = NULL
)

Arguments

task, learner, measure, resampling, features

Passed to FeatureImportanceMethod.

n_permutations

(integer(1): 10L) Number of feature permutations to sample for Shapley value estimation. Each permutation adds the features one at a time, so the total number of evaluated coalitions is 1 (empty) + n_permutations * n_features.
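The coalition count can be checked with a small sketch. The helper `n_coalitions()` is hypothetical, and the assumption that each sampled permutation contributes its `n_features` nested prefixes as coalitions is inferred from the formula rather than taken from the xplainfi source.

```r
# Hypothetical helper: coalitions evaluated by permutation sampling
n_coalitions <- function(n_permutations, n_features) {
  1L + n_permutations * n_features  # empty coalition + one per step of each permutation
}
n_coalitions(10L, 4L)  # 41

# One sampled permutation yields n_features nested coalitions (its prefixes)
set.seed(42)
features <- c("x1", "x2", "x3", "x4")
perm <- sample(features)
coalitions <- lapply(seq_along(perm), function(i) perm[seq_len(i)])
length(coalitions)  # 4
```

Doubling `n_permutations` therefore roughly doubles the number of model evaluations, which trades runtime for lower Monte Carlo variance.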

reference_data

(data.table | NULL) Optional reference dataset. If NULL, the training data is used. For each coalition to evaluate, an expanded dataset of size n_test * n_reference is created and evaluated, in batches of batch_size if specified.
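The expansion described above can be sketched with base R data frames. This is an illustration only: xplainfi works with data.table, and the helper `expand_for_coalition()`, the `.test_id` column and the stand-in `predict_fn()` are hypothetical.

```r
# Hypothetical sketch: build the n_test * n_reference expanded dataset
expand_for_coalition <- function(test, reference, coalition) {
  n_test <- nrow(test)
  n_ref <- nrow(reference)
  # Every test row is paired with every reference row
  expanded <- reference[rep(seq_len(n_ref), times = n_test), , drop = FALSE]
  test_id <- rep(seq_len(n_test), each = n_ref)
  # Coalition features keep the test values; the rest come from the reference
  for (f in coalition) expanded[[f]] <- test[[f]][test_id]
  expanded$.test_id <- test_id
  rownames(expanded) <- NULL
  expanded
}

test <- data.frame(x1 = c(1, 2, 3), x2 = c(10, 20, 30))
reference <- data.frame(x1 = c(0, 0), x2 = c(-1, 1))
ex <- expand_for_coalition(test, reference, coalition = "x1")
nrow(ex)  # 6 = 3 test rows * 2 reference rows

# Marginalized prediction: average the model output per test row
predict_fn <- function(d) d$x1 + d$x2  # stand-in for a fitted model
avg <- tapply(predict_fn(ex), ex$.test_id, mean)
avg  # 1, 2, 3: x2 averages out to 0 over the reference rows
```

Memory grows with n_test * n_reference, which is why batch_size and max_reference_size exist.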

batch_size

(integer(1) | NULL) Maximum number of observations to process in a single prediction call. If NULL, processes all at once.
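A minimal sketch of how a batch size can cap the rows per prediction call, assuming consecutive chunks whose predictions are concatenated in order. The helper `predict_in_batches()` and the call counter are hypothetical, not xplainfi internals.

```r
# Hypothetical sketch of batched prediction
predict_in_batches <- function(data, predict_fn, batch_size = NULL) {
  n <- nrow(data)
  if (is.null(batch_size) || batch_size >= n) {
    return(predict_fn(data))  # NULL: all rows in a single call
  }
  # Consecutive chunks of at most batch_size rows
  batches <- split(seq_len(n), ceiling(seq_len(n) / batch_size))
  unlist(lapply(batches, function(idx) predict_fn(data[idx, , drop = FALSE])),
         use.names = FALSE)
}

df <- data.frame(x = 1:10)
calls <- 0
predict_fn <- function(d) {
  calls <<- calls + 1  # count prediction calls
  d$x * 2
}
out <- predict_in_batches(df, predict_fn, batch_size = 4)
calls  # 3 calls: rows 1-4, 5-8, 9-10
```

Batching lowers peak memory at the cost of more prediction calls; the results are identical either way.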

sampler

(FeatureSampler | NULL) Sampler used for marginalization. Only relevant for ConditionalSAGE.

max_reference_size

(integer(1) | NULL) Maximum number of rows in the reference dataset. If the reference dataset has more rows, it is subsampled to this size. If NULL, no subsampling is performed.
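Subsampling can be sketched as below, assuming uniform sampling of rows without replacement; the helper `subsample_reference()` is hypothetical and the exact sampling scheme in xplainfi is not stated here.

```r
# Hypothetical sketch of max_reference_size subsampling
subsample_reference <- function(reference, max_reference_size = NULL) {
  if (is.null(max_reference_size) || nrow(reference) <= max_reference_size) {
    return(reference)  # nothing to do
  }
  # Uniform random rows, without replacement
  reference[sample.int(nrow(reference), max_reference_size), , drop = FALSE]
}

set.seed(1)
ref <- data.frame(x1 = rnorm(1000), x2 = rnorm(1000))
nrow(subsample_reference(ref, 100))   # 100
nrow(subsample_reference(ref, NULL))  # 1000
```

With 100 instead of 1000 reference rows, each expanded dataset shrinks tenfold (n_test * 100 rows), at the cost of a noisier marginalization.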


Method compute()

Compute SAGE values.

Usage

SAGE$compute(store_backends = TRUE, batch_size = NULL)

Arguments

store_backends

(logical(1): TRUE) Whether to store backends.

batch_size

(integer(1) | NULL) Maximum number of observations to process in a single prediction call. If NULL, processes all at once.


Method clone()

The objects of this class are cloneable with this method.

Usage

SAGE$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.