Base class for SAGE (Shapley Additive Global Importance), a feature importance method based on Shapley values with marginalization. This is an abstract class; use MarginalSAGE or ConditionalSAGE instead.
Details
SAGE uses Shapley values to fairly distribute the total prediction performance among all features. Unlike perturbation-based methods, SAGE marginalizes features by integrating over their distribution. This is approximated by averaging predictions over a reference dataset.
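As a rough illustration of this approximation, the base R sketch below evaluates a model on a coalition of "known" features by pairing every test row with every reference row, keeping the coalition features at their test values, and averaging the predictions. The helper name `marginal_predict`, the linear model, and the toy data are assumptions made for illustration; they are not part of the package API and the actual implementation may differ.

```r
# Hypothetical helper (not part of the package): average predictions over a
# reference dataset for all features outside the coalition.
marginal_predict <- function(model, test, reference, coalition) {
  n_test <- nrow(test)
  n_ref <- nrow(reference)
  # Expanded dataset of size n_test * n_ref: each test row meets each reference row
  expanded <- reference[rep(seq_len(n_ref), times = n_test), , drop = FALSE]
  test_idx <- rep(seq_len(n_test), each = n_ref)
  # Coalition features keep the values of the test observation
  for (feature in coalition) {
    expanded[[feature]] <- test[[feature]][test_idx]
  }
  preds <- predict(model, newdata = expanded)
  # Average over the reference rows belonging to each test observation
  as.numeric(tapply(preds, test_idx, mean))
}

set.seed(1)
train <- data.frame(x1 = rnorm(50), x2 = rnorm(50))
train$y <- 2 * train$x1 - train$x2 + rnorm(50, sd = 0.1)
model <- lm(y ~ x1 + x2, data = train)
test <- data.frame(x1 = c(-1, 0, 1), x2 = c(0.5, 0.5, 0.5))
reference <- train[1:20, c("x1", "x2")]

# Empty coalition: every feature is marginalized, so all predictions coincide
p_empty <- marginal_predict(model, test, reference, character(0))
# Full coalition: nothing is marginalized, so this matches the plain prediction
p_full <- marginal_predict(model, test, reference, c("x1", "x2"))
```

The two extreme coalitions are useful sanity checks: the empty coalition yields one constant prediction for all test rows, and the full coalition reproduces the model's ordinary predictions.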
Standard Error Calculation: The standard errors (SE) reported in $convergence_history reflect the uncertainty in Shapley value estimation across different random permutations within a single resampling iteration. These SEs quantify the Monte Carlo sampling error for a fixed trained model and are only valid for inference about the importance of features for that specific model. They do not capture broader uncertainty from model variability across different train/test splits or resampling iterations.
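To make the scope of these standard errors concrete, the sketch below computes an estimate and its Monte Carlo standard error from per-permutation marginal contributions of a single feature. The contribution values are made up, and the formula `sd / sqrt(n)` is the textbook Monte Carlo standard error, assumed here rather than taken from the package source.

```r
# Made-up marginal contributions of one feature, one value per permutation,
# all obtained from the same fixed trained model
contributions <- c(0.12, 0.08, 0.15, 0.10, 0.11, 0.09, 0.14, 0.13, 0.07, 0.11)

n_perm <- length(contributions)
# Point estimate: average contribution across permutations
sage_estimate <- mean(contributions)
# Monte Carlo standard error of that average
se <- sd(contributions) / sqrt(n_perm)

# More permutations shrink this SE, but retraining the model on another
# train/test split could still shift sage_estimate by more than the SE suggests.
```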
References
Covert, Ian, Lundberg, Scott M., Lee, Su-In (2020). “Understanding Global Feature Contributions With Additive Importance Measures.” In Advances in Neural Information Processing Systems, volume 33, 17212–17223. https://proceedings.neurips.cc/paper/2020/hash/c7bf0b7c1a86d5eb3be2c722cf2cf746-Abstract.html.
Public fields
n_permutations
(integer(1)) Number of permutations to sample.
reference_data
(data.table) Reference dataset for marginalization.
sampler
(FeatureSampler) Sampler object for marginalization.
convergence_history
(data.table) History of SAGE values during computation.
converged
(logical(1)) Whether convergence was detected.
n_permutations_used
(integer(1)) Actual number of permutations used.
Methods
Method new()
Creates a new instance of the SAGE class.
Usage
SAGE$new(
  task,
  learner,
  measure,
  resampling = NULL,
  features = NULL,
  n_permutations = 10L,
  reference_data = NULL,
  batch_size = 5000L,
  sampler = NULL,
  max_reference_size = 100L,
  early_stopping = FALSE,
  convergence_threshold = 0.01,
  se_threshold = Inf,
  min_permutations = 10L,
  check_interval = 2L
)
Arguments
task, learner, measure, resampling, features
Passed to FeatureImportanceMethod.
n_permutations
(integer(1): 10L) Number of permutations per coalition to sample for Shapley value estimation. The total number of evaluated coalitions is 1 (empty) + n_permutations * n_features.
reference_data
(data.table | NULL) Optional reference dataset. If NULL, the training data is used. For each coalition to evaluate, an expanded dataset of size n_test * n_reference is created and evaluated in batches of batch_size.
batch_size
(integer(1): 5000L) Maximum number of observations to process in a single prediction call.
sampler
(FeatureSampler) Sampler for marginalization. Only relevant for ConditionalSAGE.
max_reference_size
(integer(1): 100L) Maximum size of the reference dataset. If the reference dataset is larger, it is subsampled.
early_stopping
(logical(1): FALSE) Whether to enable early stopping based on convergence detection.
convergence_threshold
(numeric(1): 0.01) Relative change threshold for convergence detection.
se_threshold
(numeric(1): Inf) Standard error threshold for convergence detection.
min_permutations
(integer(1): 10L) Minimum number of permutations before checking convergence.
check_interval
(integer(1): 2L) Check convergence every N permutations.
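The sizes mentioned for n_permutations, reference_data, and batch_size can be worked through with a little arithmetic. The sketch below uses made-up values for the number of features and test observations, and it assumes that batches are formed separately for each coalition, which may not match how the package actually groups prediction calls.

```r
# Made-up problem size; only the formulas come from the argument descriptions
n_features <- 5L
n_permutations <- 10L
n_test <- 200L
n_reference <- 100L   # at most max_reference_size rows after subsampling
batch_size <- 5000L

# 1 empty coalition plus one coalition per feature and permutation
n_coalitions <- 1L + n_permutations * n_features                   # 51
# Each coalition is evaluated on every test row paired with every reference row
rows_per_coalition <- n_test * n_reference                         # 20000
# Prediction calls per coalition (assumption: batching is per coalition)
batches_per_coalition <- ceiling(rows_per_coalition / batch_size)  # 4

# Split the expanded row indices into consecutive batches of at most batch_size
batch_id <- ceiling(seq_len(rows_per_coalition) / batch_size)
batches <- split(seq_len(rows_per_coalition), batch_id)
```

Lowering max_reference_size or the number of test observations reduces rows_per_coalition directly, while batch_size only changes how many prediction calls are needed and how much memory each call uses.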
Method compute()
Compute SAGE values.
Usage
SAGE$compute(
  store_backends = TRUE,
  batch_size = NULL,
  early_stopping = NULL,
  convergence_threshold = NULL,
  se_threshold = NULL,
  min_permutations = NULL,
  check_interval = NULL
)
Arguments
store_backends
(logical(1)) Whether to store backends.
batch_size
(integer(1): 5000L) Maximum number of observations to process in a single prediction call.
early_stopping
(logical(1)) Whether to check for convergence and stop early.
convergence_threshold
(numeric(1)) Relative change threshold for convergence detection.
se_threshold
(numeric(1)) Standard error threshold for convergence detection.
min_permutations
(integer(1)) Minimum number of permutations before checking convergence.
check_interval
(integer(1)) Check convergence every N permutations.
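One plausible way the early stopping arguments could interact is sketched below. The helpers `has_converged` and `check_points` are hypothetical, and the rule they implement (largest relative change below convergence_threshold and largest SE below se_threshold, checked every check_interval permutations starting at min_permutations) is an assumption; the package may combine these criteria differently.

```r
# Hypothetical convergence rule, not the package's actual implementation
has_converged <- function(current, previous, se,
                          convergence_threshold = 0.01, se_threshold = Inf) {
  # Relative change per feature since the last check, guarded against division by zero
  rel_change <- abs(current - previous) / pmax(abs(previous), 1e-8)
  max(rel_change) < convergence_threshold && max(se) < se_threshold
}

# Permutation counts at which a check would run under the assumed schedule
check_points <- function(n_permutations, min_permutations = 10L, check_interval = 2L) {
  if (min_permutations > n_permutations) {
    return(integer(0))
  }
  seq.int(min_permutations, n_permutations, by = check_interval)
}

# Made-up SAGE estimates at two consecutive checks, with their standard errors
previous <- c(x1 = 0.500, x2 = 0.200)
current <- c(x1 = 0.502, x2 = 0.201)
se <- c(x1 = 0.010, x2 = 0.008)

# Relative changes are 0.004 and 0.005, both below the default threshold
converged_default <- has_converged(current, previous, se)
# A strict SE threshold blocks convergence even though the estimates are stable
converged_strict <- has_converged(current, previous, se, se_threshold = 0.005)
checks <- check_points(20L, min_permutations = 10L, check_interval = 4L)
```

Under this assumed schedule, the default se_threshold of Inf means only the relative change criterion can prevent stopping.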