
Base class for SAGE (Shapley Additive Global Importance), a feature importance method based on Shapley values with marginalization. This is an abstract class; use MarginalSAGE or ConditionalSAGE instead.

Details

SAGE uses Shapley values to fairly distribute the model's total predictive performance among all features. Unlike perturbation-based methods, SAGE marginalizes out the features absent from a coalition by integrating over their distribution. This integral is approximated by averaging predictions over a reference dataset.
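For orientation, the quantities involved can be written as below, following the formulation in Covert et al. (2020). The notation is ours: ℓ is a loss, S a coalition of features, π an ordering of the d features, and Pre_j(π) the features preceding j in π. The reference-data average is the marginal approximation described above; ConditionalSAGE would instead condition on X_S. Whether xplainfi uses a loss or a general performance measure, and with which sign, is not stated here.

```latex
% Value of coalition S: reduction in expected loss when only the features in S
% are known, relative to the mean prediction f_\emptyset = E[f(X)]
v(S) = \mathbb{E}\big[\ell(f_\emptyset, Y)\big] - \mathbb{E}\big[\ell\big(f_S(X_S), Y\big)\big]

% Marginal SAGE: features outside S are integrated over their marginal
% distribution, approximated with n_ref reference rows x^{(k)}
f_S(x_S) = \mathbb{E}\big[f(x_S, X_{\bar S})\big]
  \approx \frac{1}{n_{\mathrm{ref}}} \sum_{k=1}^{n_{\mathrm{ref}}} f\big(x_S, x^{(k)}_{\bar S}\big)

% SAGE value of feature j: average marginal contribution over all d! orderings
\phi_j = \frac{1}{d!} \sum_{\pi} \Big[ v\big(\mathrm{Pre}_j(\pi) \cup \{j\}\big) - v\big(\mathrm{Pre}_j(\pi)\big) \Big]
```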

Standard Error Calculation: The standard errors (SE) reported in $convergence_history reflect the uncertainty in Shapley value estimation across different random permutations within a single resampling iteration. These SEs quantify the Monte Carlo sampling error for a fixed trained model and are only valid for inference about the importance of features for that specific model. They do not capture broader uncertainty from model variability across different train/test splits or resampling iterations.
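A minimal base R sketch of this kind of Monte Carlo standard error, assuming the SE is the standard error of the mean of one feature's per-permutation marginal contributions. The contribution values are made up for illustration and do not come from xplainfi.

```r
# Hypothetical per-permutation marginal contributions of one feature,
# collected within a single resampling iteration (one fixed fitted model).
contributions <- c(0.12, 0.09, 0.15, 0.11, 0.10, 0.13, 0.08, 0.14)

# Point estimate: mean contribution across the sampled permutations
sage_estimate <- mean(contributions)

# Monte Carlo standard error of that mean: sd / sqrt(number of permutations).
# It reflects permutation sampling noise only, not model or split variability.
sage_se <- sd(contributions) / sqrt(length(contributions))

round(c(estimate = sage_estimate, se = sage_se), 4)
```

Drawing more permutations shrinks this SE roughly as 1 / sqrt(n), but it never accounts for how the importance would change with a different train/test split.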

References

Covert, Ian, Lundberg, Scott M, Lee, Su-In (2020). “Understanding Global Feature Contributions With Additive Importance Measures.” In Advances in Neural Information Processing Systems, volume 33, 17212–17223. https://proceedings.neurips.cc/paper/2020/hash/c7bf0b7c1a86d5eb3be2c722cf2cf746-Abstract.html.

Super class

xplainfi::FeatureImportanceMethod -> SAGE

Public fields

n_permutations

(integer(1)) Number of permutations to sample.

reference_data

(data.table) Reference dataset for marginalization.

sampler

(FeatureSampler) Sampler object for marginalization.

convergence_history

(data.table) History of the SAGE value estimates and their standard errors recorded during computation.

converged

(logical(1)) Whether convergence was detected.

n_permutations_used

(integer(1)) Actual number of permutations used. This can be smaller than n_permutations if early stopping was triggered.

Methods


Method new()

Creates a new instance of the SAGE class.

Usage

SAGE$new(
  task,
  learner,
  measure,
  resampling = NULL,
  features = NULL,
  n_permutations = 10L,
  reference_data = NULL,
  batch_size = 5000L,
  sampler = NULL,
  max_reference_size = 100L,
  early_stopping = FALSE,
  convergence_threshold = 0.01,
  se_threshold = Inf,
  min_permutations = 10L,
  check_interval = 2L
)

Arguments

task, learner, measure, resampling, features

Passed to FeatureImportanceMethod.

n_permutations

(integer(1): 10L) Number of feature permutations to sample for Shapley value estimation. Each permutation adds the features one at a time, so the total number of evaluated coalitions is 1 (the empty coalition) + n_permutations * n_features.
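The permutation-sampling estimator behind this count can be sketched in base R. This is a minimal illustration, not xplainfi's implementation: `estimate_shapley()` and its `value_fn` argument are hypothetical names, and `value_fn` stands in for the coalition value v(S), which in SAGE would be the performance gain when only the features in S are available.

```r
# Permutation-sampling Shapley estimator (sketch).
# value_fn: hypothetical function mapping a vector of feature indices to v(S).
estimate_shapley <- function(value_fn, n_features, n_permutations) {
  contrib <- matrix(0, nrow = n_permutations, ncol = n_features)
  v_empty <- value_fn(integer(0))  # empty coalition: evaluated once
  n_evals <- 1L
  for (p in seq_len(n_permutations)) {
    perm <- sample.int(n_features)  # random ordering of the features
    coalition <- integer(0)
    v_prev <- v_empty
    for (j in perm) {
      coalition <- c(coalition, j)
      v_new <- value_fn(coalition)  # one new coalition per feature added
      n_evals <- n_evals + 1L
      contrib[p, j] <- v_new - v_prev  # marginal contribution of feature j
      v_prev <- v_new
    }
  }
  list(values = colMeans(contrib), contributions = contrib, n_evals = n_evals)
}

# Usage on a toy additive game, where each feature's Shapley value is its weight
set.seed(1)
w <- c(0.5, 0.3, 0.2)
res <- estimate_shapley(function(S) sum(w[S]), n_features = 3, n_permutations = 4)
res$values   # matches w up to floating-point error
res$n_evals  # 1 + 4 * 3 evaluations
```

For an additive game every permutation yields the same contributions, so the estimate is exact; in general the per-permutation rows of `contributions` differ, and their spread is what the reported standard errors summarize. The values also sum to v(all features) minus v(empty), the Shapley efficiency property.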

reference_data

(data.table | NULL) Optional reference dataset. If NULL, the training data is used. For each coalition to evaluate, an expanded dataset of size n_test * n_reference is created and evaluated in batches of batch_size.
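The expanded-dataset approach can be sketched in base R as follows. This is an illustration under assumptions, not the package's code: `marginal_predictions()` and `predict_fn` are hypothetical, data.frames stand in for data.table, and the data are assumed to contain feature columns only. Features in the coalition keep the test row's values, all other features come from each reference row, and predictions are averaged per test row.

```r
# Marginal prediction for one coalition via an expanded dataset (sketch).
# predict_fn: hypothetical function taking a data.frame, returning numeric predictions.
marginal_predictions <- function(predict_fn, test, reference, coalition,
                                 batch_size = 5000L) {
  n_test <- nrow(test)
  n_ref  <- nrow(reference)
  # Pair every test row with every reference row: n_test * n_ref rows
  idx_test <- rep(seq_len(n_test), each = n_ref)
  idx_ref  <- rep(seq_len(n_ref), times = n_test)
  expanded <- reference[idx_ref, , drop = FALSE]
  if (length(coalition) > 0) {
    # Known features are taken from the test rows
    expanded[coalition] <- test[idx_test, coalition, drop = FALSE]
  }
  # Predict in chunks of at most batch_size rows
  preds <- numeric(nrow(expanded))
  for (s in seq(1L, nrow(expanded), by = batch_size)) {
    rows <- s:min(s + batch_size - 1L, nrow(expanded))
    preds[rows] <- predict_fn(expanded[rows, , drop = FALSE])
  }
  # Average over the reference rows belonging to each test row
  as.numeric(tapply(preds, idx_test, mean))
}

# Usage with a known linear model f(x) = 2 * x1 + 3 * x2
predict_fn <- function(d) 2 * d$x1 + 3 * d$x2
test      <- data.frame(x1 = c(1, 2), x2 = c(0, 1))
reference <- data.frame(x1 = c(0, 4), x2 = c(1, 3))
p_x1    <- marginal_predictions(predict_fn, test, reference, "x1", batch_size = 3L)
p_empty <- marginal_predictions(predict_fn, test, reference, character(0))
p_full  <- marginal_predictions(predict_fn, test, reference, c("x1", "x2"))
```

With coalition `"x1"`, x2 is replaced by its reference values and averaged out, so each prediction becomes 2 * x1 + 3 * mean(reference$x2). The small `batch_size` here only forces more than one prediction call.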

batch_size

(integer(1): 5000L) Maximum number of observations to process in a single prediction call.

sampler

(FeatureSampler) Sampler for marginalization. Only relevant for ConditionalSAGE.

max_reference_size

(integer(1): 100L) Maximum size of the reference dataset. If the reference data has more rows, it is subsampled.
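Taken together, n_permutations, the reference size, max_reference_size, and batch_size determine the prediction workload. The helper below is a back-of-the-envelope sketch: `sage_cost()` is a hypothetical name, and it assumes each coalition's expanded dataset is batched separately, which the documentation above does not state explicitly.

```r
# Rough SAGE workload estimate (sketch; assumes per-coalition batching).
sage_cost <- function(n_test, n_reference, n_features, n_permutations,
                      max_reference_size = 100L, batch_size = 5000L) {
  n_ref_used   <- min(n_reference, max_reference_size)  # subsampling cap
  n_coalitions <- 1 + n_permutations * n_features        # empty + one per step
  rows_per_coalition  <- n_test * n_ref_used             # expanded dataset size
  calls_per_coalition <- ceiling(rows_per_coalition / batch_size)
  c(coalitions = n_coalitions,
    rows_per_coalition = rows_per_coalition,
    predict_calls = n_coalitions * calls_per_coalition,
    total_rows = n_coalitions * rows_per_coalition)
}

cost <- sage_cost(n_test = 200, n_reference = 1000, n_features = 5,
                  n_permutations = 10)
cost
```

Here the 1000 reference rows are capped at 100, giving 200 * 100 = 20000 rows per coalition, 4 prediction calls per coalition at the default batch size, and 51 coalitions in total. Raising max_reference_size improves the marginalization estimate but scales the work linearly.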

early_stopping

(logical(1): FALSE) Whether to enable early stopping based on convergence detection.

convergence_threshold

(numeric(1): 0.01) Relative change threshold for convergence detection.

se_threshold

(numeric(1): Inf) Standard error threshold for convergence detection.

min_permutations

(integer(1): 10L) Minimum number of permutations before convergence is checked.

check_interval

(integer(1): 2L) Number of permutations between convergence checks.
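To show how the early-stopping parameters could interact, here is a hypothetical convergence rule in base R. The exact criterion xplainfi applies is not documented above; `check_converged()` and the specific definition of "relative change" (largest change of any feature's running mean since the previous check, relative to that previous value) are assumptions for illustration. The defaults mirror those of SAGE$new().

```r
# Hypothetical convergence rule (illustration only, not xplainfi's code).
# contrib: matrix of per-permutation contributions, permutations x features.
check_converged <- function(contrib, convergence_threshold = 0.01,
                            se_threshold = Inf, min_permutations = 10L,
                            check_interval = 2L) {
  n <- nrow(contrib)
  # Only check after min_permutations, and then every check_interval permutations
  if (n < min_permutations || n <= check_interval || n %% check_interval != 0) {
    return(FALSE)
  }
  current  <- colMeans(contrib)
  previous <- colMeans(contrib[seq_len(n - check_interval), , drop = FALSE])
  # Assumed relative change of the running estimates since the last check
  rel_change <- abs(current - previous) / pmax(abs(previous), 1e-12)
  # Monte Carlo standard errors of the running estimates
  se <- apply(contrib, 2, sd) / sqrt(n)
  max(rel_change) < convergence_threshold && max(se) < se_threshold
}

# Usage: perfectly stable contributions for two features over 10 permutations
stable <- matrix(rep(c(0.5, 0.2), each = 10), ncol = 2)
check_converged(stable)                    # stable estimates, SE criterion off
check_converged(stable, se_threshold = 0)  # SE must be strictly below 0
check_converged(stable[1:9, ])             # fewer than min_permutations
```

With the default se_threshold = Inf, the standard-error condition in this sketch is always met, so only the relative-change criterion can delay stopping.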


Method compute()

Compute SAGE values.

Usage

SAGE$compute(
  store_backends = TRUE,
  batch_size = NULL,
  early_stopping = NULL,
  convergence_threshold = NULL,
  se_threshold = NULL,
  min_permutations = NULL,
  check_interval = NULL
)

Arguments

store_backends

(logical(1)) Whether to store backends.

batch_size

(integer(1) | NULL) Maximum number of observations to process in a single prediction call. If NULL, the value set in $new() (default 5000L) is used.

early_stopping

(logical(1)) Whether to check for convergence and stop early.

convergence_threshold

(numeric(1)) Relative change threshold for convergence detection.

se_threshold

(numeric(1)) Standard error threshold for convergence detection.

min_permutations

(integer(1)) Minimum number of permutations before convergence is checked.

check_interval

(integer(1)) Number of permutations between convergence checks.


Method plot_convergence()

Plot convergence history of SAGE values.

Usage

SAGE$plot_convergence(features = NULL)

Arguments

features

(character | NULL) Features to plot. If NULL, plots all features.

Returns

A ggplot2 object


Method clone()

The objects of this class are cloneable with this method.

Usage

SAGE$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.