
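SAGE with conditional sampling: features outside the current coalition are marginalized using their conditional distribution given the observed features, rather than their marginal distribution. By default, conditional sampling uses an adversarial random forest (ARF) via ARFSampler.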
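To make conditional marginalization concrete, here is a minimal base-R sketch of a permutation-based SAGE estimator. It is not xplainfi's implementation. It assumes jointly Gaussian features with known mean `mu` and covariance `Sigma`, so the features outside the coalition can be drawn exactly from their conditional distribution; this Gaussian sampler stands in for the ARF sampler. The helpers `sample_conditional()` and `sage_conditional()` are hypothetical. As a simplification, one permutation is shared across all rows per iteration, and the loss is squared error of the averaged prediction.

```r
# Hypothetical sketch, not xplainfi's implementation.
# Assumption: features are jointly Gaussian with known mean `mu` and
# covariance `Sigma`, so X_out | X_in can be sampled exactly (a stand-in
# for the ARF sampler ConditionalSAGE uses by default).

# Return `n_samples` copies of `x_row`, keeping the `known` features fixed and
# drawing all other features from their conditional Gaussian distribution.
sample_conditional = function(x_row, known, mu, Sigma, n_samples) {
  p = length(mu)
  unknown = setdiff(seq_len(p), known)
  out = matrix(rep(x_row, each = n_samples), nrow = n_samples)
  if (length(unknown) == 0) return(out)
  if (length(known) == 0) {
    cond_mu = mu[unknown]
    cond_S = Sigma[unknown, unknown, drop = FALSE]
  } else {
    S_uk = Sigma[unknown, known, drop = FALSE]
    S_kk_inv = solve(Sigma[known, known, drop = FALSE])
    cond_mu = as.vector(mu[unknown] + S_uk %*% S_kk_inv %*% (x_row[known] - mu[known]))
    cond_S = Sigma[unknown, unknown, drop = FALSE] - S_uk %*% S_kk_inv %*% t(S_uk)
  }
  z = matrix(rnorm(n_samples * length(unknown)), nrow = n_samples)
  out[, unknown] = z %*% chol(cond_S) +
    matrix(cond_mu, nrow = n_samples, ncol = length(unknown), byrow = TRUE)
  out
}

# Permutation estimator of SAGE values with squared-error loss.
# v(S) = mean over rows of (y - E[f(X) | X_S = x_S])^2, where the conditional
# expectation is approximated by averaging f over `n_samples` conditional draws.
sage_conditional = function(f, X, y, mu, Sigma, n_permutations = 10L, n_samples = 50L) {
  p = ncol(X)
  loss_given = function(known) {
    preds = vapply(seq_len(nrow(X)), function(i) {
      mean(f(sample_conditional(X[i, ], known, mu, Sigma, n_samples)))
    }, numeric(1))
    mean((y - preds)^2)
  }
  phi = setNames(numeric(p), colnames(X))
  for (m in seq_len(n_permutations)) {
    perm = sample(p)
    known = integer(0)
    prev = loss_given(known)  # all features marginalized
    for (j in perm) {
      known = c(known, j)
      cur = loss_given(known)
      phi[j] = phi[j] + (prev - cur)  # loss reduction from adding feature j
      prev = cur
    }
  }
  phi / n_permutations
}

# Toy data: three independent standard-normal features; the "model" uses only x1.
set.seed(1)
n = 200
X = matrix(rnorm(n * 3), ncol = 3, dimnames = list(NULL, c("x1", "x2", "x3")))
f = function(M) 3 * M[, 1]
y = f(X)
phi = sage_conditional(f, X, y, mu = rep(0, 3), Sigma = diag(3), n_permutations = 5L)
print(round(phi, 2))
```

Because `f` depends only on `x1` and the features are independent, nearly all of the importance goes to `x1` (roughly the variance of `y`, about 9), while `x2` and `x3` stay close to zero.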

Super classes

xplainfi::FeatureImportanceMethod -> xplainfi::SAGE -> ConditionalSAGE

Methods



Method new()

Creates a new instance of the ConditionalSAGE class.

Usage

ConditionalSAGE$new(
  task,
  learner,
  measure,
  resampling = NULL,
  features = NULL,
  n_permutations = 10L,
  reference_data = NULL,
  sampler = NULL,
  max_reference_size = NULL
)

Arguments

task, learner, measure, resampling, features

Passed to SAGE.

n_permutations

(integer(1)) Number of feature permutations to sample when estimating the SAGE values.

reference_data

(data.table) Optional reference dataset.

sampler

(ConditionalSampler) Optional custom sampler. If NULL (the default), an ARFSampler with default settings is used.

max_reference_size

(integer(1)) Maximum size of the reference dataset.
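The sketch below shows how the optional constructor arguments might be combined. It uses only the arguments documented above plus standard mlr3 helpers (`tgen()`, `lrn()`, `msr()`, `rsmp()`); `library(mlr3learners)` is assumed to be needed for `regr.ranger`. The chosen values (three-fold cross-validation, a three-feature subset, a reference size of 50) are illustrative assumptions, and the behavior of `features` and `max_reference_size` is assumed to follow the argument descriptions.

```r
library(mlr3)
library(mlr3learners)  # assumed: provides the "regr.ranger" learner
library(xplainfi)

task = tgen("friedman1")$generate(n = 200)
sage = ConditionalSAGE$new(
  task = task,
  learner = lrn("regr.ranger", num.trees = 50),
  measure = msr("regr.mse"),
  resampling = rsmp("cv", folds = 3),                        # instead of the default holdout
  features = c("important1", "important2", "unimportant1"),  # illustrative subset
  n_permutations = 5L,
  max_reference_size = 50L                                   # illustrative cap
)
res = sage$compute()
print(res)
```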


Method clone()

The objects of this class are cloneable with this method.

Usage

ConditionalSAGE$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Examples

library(mlr3)
task = tgen("friedman1")$generate(n = 100)
sage = ConditionalSAGE$new(
  task = task,
  learner = lrn("regr.ranger", num.trees = 50),
  measure = msr("regr.mse"),
  n_permutations = 3L
)
#>  No <ConditionalSampler> provided, using <ARFSampler> with default settings.
#>  No <Resampling> provided, using holdout resampling with default ratio.
sage$compute()
#> Key: <feature>
#>          feature  importance
#>           <char>       <num>
#>  1:   important1  4.58861248
#>  2:   important2 -0.10848016
#>  3:   important3  0.08622897
#>  4:   important4  4.78590789
#>  5:   important5  2.26613012
#>  6: unimportant1 -0.51814805
#>  7: unimportant2  0.28900725
#>  8: unimportant3 -0.69336275
#>  9: unimportant4 -0.05010833
#> 10: unimportant5 -0.14724626

# Use batching for memory efficiency with large datasets
sage$compute(batch_size = 1000)
#> Key: <feature>
#>          feature  importance
#>           <char>       <num>
#>  1:   important1  4.58861248
#>  2:   important2 -0.10848016
#>  3:   important3  0.08622897
#>  4:   important4  4.78590789
#>  5:   important5  2.26613012
#>  6: unimportant1 -0.51814805
#>  7: unimportant2  0.28900725
#>  8: unimportant3 -0.69336275
#>  9: unimportant4 -0.05010833
#> 10: unimportant5 -0.14724626