Implements conditional sampling using Adversarial Random Forests (ARF). ARF can handle mixed data types (continuous and categorical) and provides flexible conditional sampling by modeling the joint distribution.
Details
The ARFSampler fits an Adversarial Random Forest model on the task data, then uses it to generate samples from \(P(X_j | X_{-j})\) where \(X_j\) is the feature of interest and \(X_{-j}\) are the conditioning features.
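The fit-then-sample flow described above can be sketched directly with the arf package. This is a minimal illustration, not the ARFSampler implementation: the use of iris, the chosen feature, and the exact arguments passed are assumptions made for the example.

```r
library(arf)

set.seed(1)

# Fit the forest, then estimate the leaf-wise density parameters (psi)
forest = adversarial_rf(iris, verbose = FALSE, parallel = FALSE)
psi = forde(forest, iris, finite_bounds = "no", parallel = FALSE)

# Evidence: the conditioning features X_{-j} for five observations
feature = "Sepal.Length"
evidence = iris[1:5, setdiff(names(iris), feature)]

# Draw one value of X_j per evidence row
synth = forge(
  psi,
  n_synth = 1,
  evidence = evidence,
  evidence_row_mode = "separate",
  verbose = FALSE,
  parallel = FALSE
)
```

Each row of `synth` pairs one row of evidence with a newly drawn value for `Sepal.Length`.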
References
Watson, David S., Blesch, Kristin, Kapar, Jan, Wright, Marvin N. (2023). “Adversarial Random Forests for Density Estimation and Generative Modeling.” In Proceedings of The 26th International Conference on Artificial Intelligence and Statistics, 5357–5375. https://proceedings.mlr.press/v206/watson23a.html.
Blesch, Kristin, Koenen, Niklas, Kapar, Jan, Golchian, Pegah, Burk, Lukas, Loecher, Markus, Wright, Marvin N. (2025). “Conditional Feature Importance with Generative Modeling Using Adversarial Random Forests.” Proceedings of the AAAI Conference on Artificial Intelligence, 39(15), 15596–15604. doi:10.1609/aaai.v39i15.33712.
Super classes
xplainfi::FeatureSampler -> xplainfi::ConditionalSampler -> ARFSampler
Public fields
arf_model
Adversarial Random Forest model
psi
Distribution parameters estimated from ARF
Methods
Method new()
Creates a new instance of the ARFSampler class.
To fit the ARF in parallel, set arf_args = list(parallel = TRUE) and register a parallel backend (see arf::arf).
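A minimal sketch of parallel fitting, assuming the doParallel package is installed and used as the foreach backend. The backend and the choice of two cores are illustrative assumptions; any registered foreach backend should serve the same purpose.

```r
library(mlr3)
library(xplainfi)

# Register a foreach backend (assumption: doParallel with 2 cores)
doParallel::registerDoParallel(cores = 2)

task = tgen("2dnormals")$generate(n = 100)

# parallel = TRUE is forwarded to arf::adversarial_rf via arf_args
sampler = ARFSampler$new(task, arf_args = list(parallel = TRUE))

# Release the workers once fitting is done
doParallel::stopImplicitCluster()
```

Note that arf_args controls only the model fit. The separate parallel argument of the constructor governs whether sampling runs in parallel.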
Usage
ARFSampler$new(
  task,
  conditioning_set = NULL,
  finite_bounds = "no",
  round = TRUE,
  stepsize = 0,
  verbose = FALSE,
  parallel = FALSE,
  arf_args = NULL
)
Arguments
task
(mlr3::Task) Task to sample from
conditioning_set
(character | NULL) Default conditioning set to use in $sample(). This parameter only affects the sampling behavior, not the ARF model fitting.
finite_bounds
(character(1): "no") Passed to arf::forde(). Default is "no" for compatibility. "local" may improve extrapolation but can cause issues with some data.
round
(logical(1): TRUE) Whether to round continuous variables back to their original precision.
stepsize
(numeric(1): 0) Number of rows of evidence to process at a time when parallel is TRUE. Default (0) spreads evidence evenly over registered workers.
verbose
(logical(1): FALSE) Whether to print progress messages. Default is FALSE, whereas the default in arf is TRUE.
parallel
(logical(1): FALSE) Whether to use parallel processing via foreach. See examples in arf::forge().
arf_args
(list) Additional arguments passed to arf::adversarial_rf.
Method sample()
Samples values for the given feature(s) conditionally on other features using the fitted ARF.
Usage
ARFSampler$sample(
  feature,
  data = self$task$data(),
  conditioning_set = NULL,
  round = NULL,
  stepsize = NULL,
  verbose = NULL,
  parallel = NULL,
  ...
)
Arguments
feature
(character) Feature(s) of interest to sample (one or several).
data
(data.table) Data containing conditioning features. Defaults to $task$data(), but typically a dedicated test set is provided.
conditioning_set
(character(n) | NULL) Features to condition on. If NULL, uses the stored parameter if available, otherwise defaults to all other features.
round
(logical(1) | NULL) Whether to round continuous variables. If NULL, uses the stored parameter value.
stepsize
(numeric(1) | NULL) Number of rows of evidence to process at a time when parallel is TRUE. If NULL, uses the stored parameter value.
verbose
(logical(1) | NULL) Whether to print progress messages. If NULL, uses the stored parameter value.
parallel
(logical(1) | NULL) Whether to use parallel processing. If NULL, uses the stored parameter value.
...
Further arguments passed to arf::forge().
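The NULL defaults mean that each call falls back to the values stored at construction unless an argument is supplied explicitly. A short sketch of that behavior, assuming that $sample() returns a table with one row per row of data:

```r
library(mlr3)
library(xplainfi)

task = tgen("2dnormals")$generate(n = 100)
data = task$data()

# Stored defaults: condition on x2 and round sampled values
sampler = ARFSampler$new(task, conditioning_set = "x2", round = TRUE)

# No overrides: the stored conditioning_set and round are used
sampled_default = sampler$sample("x1", data)

# Per-call override: skip rounding for this call only
sampled_unrounded = sampler$sample("x1", data, round = FALSE)
```

The override applies only to that call; the stored value of round is expected to remain TRUE for later calls.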
Examples
library(mlr3)
library(xplainfi)
task = tgen("2dnormals")$generate(n = 100)
# Create a sampler that conditions on x2 by default
sampler = ARFSampler$new(task, conditioning_set = "x2", verbose = FALSE)
data = task$data()
# Uses the stored parameters (conditioning_set = "x2")
sampled_data = sampler$sample("x1", data)
# Sampler with rounding disabled; conditions on all other features by default
sampler_custom = ARFSampler$new(task, round = FALSE)
sampled_custom = sampler_custom$sample("x1", data)