
Implements conditional sampling using Knockoffs.

Details

The KnockoffSampler generates knockoffs for the features in the task data using knockoff_fun. The resulting knockoff matrix is available in the x_tilde field.

References

Watson D, Wright M (2021). “Testing conditional independence in supervised learning algorithms.” Machine Learning, 110(8), 2107-2129. doi:10.1007/s10994-021-06030-6.

Blesch K, Watson D, Wright M (2023). “Conditional feature importance for mixed data.” AStA Advances in Statistical Analysis, 108(2), 259-278. doi:10.1007/s10182-023-00477-9.

Super classes

xplainfi::FeatureSampler -> xplainfi::ConditionalSampler -> KnockoffSampler

Public fields

x_tilde

Knockoff matrix

Methods


Method new()

Creates a new instance of the KnockoffSampler class.

Usage

KnockoffSampler$new(
  task,
  knockoff_fun = function(x) knockoff::create.second_order(as.matrix(x))
)

Arguments

task

(mlr3::Task) Task to sample from

knockoff_fun

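(function) Function used to create the knockoffs. It is called with the feature data and must return the knockoff matrix. Defaults to second-order Gaussian knockoffs via knockoff::create.second_order().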
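To give some intuition for what a second-order Gaussian knockoff generator computes, here is a minimal base R sketch of the equicorrelated construction. It is not the code used by knockoff::create.second_order() or by xplainfi. The function name gaussian_knockoffs_sketch and the shrink_factor argument are hypothetical, and the fixed shrinkage of 0.9 is an assumption made only to keep the conditional covariance positive definite.

```r
# Hypothetical helper for illustration; not part of xplainfi or knockoff.
gaussian_knockoffs_sketch = function(x, shrink_factor = 0.9) {
  x = as.matrix(x)
  mu = colMeans(x)
  sigma = stats::cov(x)

  # Equicorrelated choice of s on the correlation scale, shrunk slightly
  # (assumption) so that 2 * D - D %*% solve(sigma) %*% D stays positive definite.
  corr = stats::cov2cor(sigma)
  lambda_min = min(eigen(corr, symmetric = TRUE, only.values = TRUE)$values)
  s = shrink_factor * min(2 * lambda_min, 1)
  d = diag(s * diag(sigma), nrow = ncol(x))

  # Conditional distribution of the knockoffs given x under a Gaussian model:
  # mean = x - (x - mu) Sigma^-1 D, covariance = 2 D - D Sigma^-1 D
  sigma_inv_d = solve(sigma, d)
  cond_mean = x - sweep(x, 2, mu) %*% sigma_inv_d
  cond_cov = 2 * d - d %*% sigma_inv_d

  # Draw one knockoff row per observation.
  noise = matrix(stats::rnorm(length(x)), nrow = nrow(x))
  x_tilde = cond_mean + noise %*% chol(cond_cov)
  colnames(x_tilde) = colnames(x)
  x_tilde
}

# Usage on simulated, correlated features
set.seed(1)
z = matrix(stats::rnorm(200 * 3), ncol = 3)
x = z %*% matrix(c(1, 0.5, 0.2, 0, 1, 0.3, 0, 0, 1), nrow = 3)
colnames(x) = c("x1", "x2", "x3")
x_tilde = gaussian_knockoffs_sketch(x)
dim(x_tilde)
```

Judging by the default argument, any function that takes the feature data and returns a matrix of the same shape should fit the knockoff_fun slot, but that contract is inferred from the default rather than stated explicitly here.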


Method sample()

Samples values for the given feature(s), conditional on the other features, using knockoffs.

Usage

KnockoffSampler$sample(feature, data = self$task$data())

Arguments

feature

(character) Feature(s) of interest to sample (can be single or multiple)

data

(data.table) Data containing conditioning features. Defaults to $task$data(), but typically a dedicated test set is provided.

Returns

Modified copy of the input data with the feature(s) sampled conditionally
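Conceptually, knockoff-based conditional sampling replaces the columns of interest with their knockoff counterparts and leaves all other columns untouched. The base R sketch below illustrates that idea only. swap_in_knockoffs is a hypothetical helper, not an xplainfi function, and the placeholder matrix stands in for a real knockoff matrix (its values are not valid knockoffs).

```r
# Hypothetical helper for illustration; not part of xplainfi.
swap_in_knockoffs = function(data, x_tilde, feature) {
  out = data  # work on a copy so the input stays unchanged
  for (f in feature) {
    out[[f]] = x_tilde[, f]  # replace the feature with its knockoff column
  }
  out
}

# Toy data with two features and a target
toy = data.frame(x1 = 1:3, x2 = 4:6, y = c("a", "b", "a"))

# Placeholder standing in for a knockoff matrix (assumption: same column names)
placeholder = matrix(
  c(10, 20, 30, 40, 50, 60),
  ncol = 2,
  dimnames = list(NULL, c("x1", "x2"))
)

res = swap_in_knockoffs(toy, placeholder, "x1")
res
```

Only x1 differs between toy and res. x2 and y are carried over as they are, which mirrors the "modified copy" described in the return value.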


Method clone()

The objects of this class are cloneable with this method.

Usage

KnockoffSampler$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Examples

library(mlr3)
task = tgen("2dnormals")$generate(n = 100)
# Create sampler with default parameters
sampler = KnockoffSampler$new(task)
# Sampling uses the knockoffs stored in the sampler
sampled_data = sampler$sample("x1")
if (FALSE) { # \dontrun{
# Example with sequential knockoffs (https://github.com/kormama1/seqknockoff)
task = tgen("simplex")$generate(n = 100)
sampler_seq = KnockoffSampler$new(task, knockoff_fun = seqknockoff::knockoffs_seq)
sampled_seq = sampler_seq$sample("x1")
} # }