Implements conditional sampling using Adversarial Random Forests (ARF). ARF can handle mixed data types (continuous and categorical) and provides flexible conditional sampling by modeling the joint distribution.
Details
The ARFSampler fits an Adversarial Random Forest model on the task data, then uses it to generate samples from \(P(X_j | X_{-j})\) where \(X_j\) is the feature of interest and \(X_{-j}\) are the conditioning features.
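The fit-then-sample flow described above can be sketched directly with the arf package. This is a minimal illustration on the built-in iris data, not the sampler's actual internals: the choice of data, the feature of interest, and the way the evidence is assembled are assumptions made for the example.

```r
library(arf)

# Toy data with mixed types: four numeric columns and one factor.
x = iris

# Step 1: fit the adversarial random forest (sequentially, for simplicity).
arf_model = adversarial_rf(x, num_trees = 10L, min_node_size = 2L,
                           verbose = FALSE, parallel = FALSE)

# Step 2: estimate the leaf-wise distribution parameters.
psi = forde(arf_model, x, finite_bounds = "no", parallel = FALSE)

# Step 3: condition on every feature except the feature of interest.
feature = "Sepal.Length"                        # X_j (chosen for illustration)
evidence = x[1:5, setdiff(names(x), feature)]   # X_{-j} for five rows

# Draw one synthetic row per evidence row.
synth = forge(psi, n_synth = 1, evidence = evidence,
              verbose = FALSE, parallel = FALSE)

# The column for the feature of interest holds the conditional draws.
synth[[feature]]
```

Only the column for the feature of interest changes between the evidence and the synthetic rows; the conditioning features are held at their observed values.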
References
Watson, David S., Blesch, Kristin, Kapar, Jan, Wright, Marvin N. (2023). “Adversarial Random Forests for Density Estimation and Generative Modeling.” In Proceedings of The 26th International Conference on Artificial Intelligence and Statistics, 5357–5375. https://proceedings.mlr.press/v206/watson23a.html.
Blesch, Kristin, Koenen, Niklas, Kapar, Jan, Golchian, Pegah, Burk, Lukas, Loecher, Markus, Wright, Marvin N. (2025). “Conditional Feature Importance with Generative Modeling Using Adversarial Random Forests.” Proceedings of the AAAI Conference on Artificial Intelligence, 39(15), 15596–15604. doi:10.1609/aaai.v39i15.33712.
Super classes
xplainfi::FeatureSampler -> xplainfi::ConditionalSampler -> ARFSampler
Public fields
feature_types
(character()) Feature types supported by the sampler. Will be checked against the provided mlr3::Task to ensure compatibility.
arf_model
Adversarial Random Forest model created by arf::adversarial_rf.
psi
Distribution parameters estimated by arf::forde.
Methods
Method new()
Creates a new instance of the ARFSampler class.
To fit the ARF in parallel, register a parallel backend first (see arf::arf) and set parallel = TRUE.
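As a rough sketch of the parallel setup mentioned above: the example below registers a foreach backend before constructing the sampler. Using doParallel with two cores is an assumption made for illustration; any foreach-compatible backend should serve the same purpose.

```r
library(mlr3)
library(xplainfi)

# Register a foreach backend first. doParallel with 2 cores is an
# illustrative choice, not a requirement of the sampler.
doParallel::registerDoParallel(cores = 2)

task = tgen("2dnormals")$generate(n = 100)

# Fit the ARF using the registered workers.
sampler = ARFSampler$new(task, parallel = TRUE)
sampled = sampler$sample("x1")

# Release the implicitly created workers when done.
doParallel::stopImplicitCluster()
```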
Usage
ARFSampler$new(
  task,
  conditioning_set = NULL,
  num_trees = 10L,
  min_node_size = 2L,
  finite_bounds = "no",
  epsilon = 1e-15,
  round = TRUE,
  stepsize = 0,
  verbose = FALSE,
  parallel = FALSE,
  ...
)
Arguments
task
(mlr3::Task) Task to sample from.
conditioning_set
(character | NULL) Default conditioning set to use in $sample(). This parameter only affects the sampling behavior, not the ARF model fitting.
num_trees
(integer(1): 10L) Number of trees for ARF. Passed to arf::adversarial_rf.
min_node_size
(integer(1): 2L) Minimum node size for ARF. Passed to arf::adversarial_rf.
finite_bounds
(character(1): "no") How to handle variable bounds. Passed to arf::forde. Default is "no" for compatibility. "local" may improve extrapolation but can cause issues with some data.
epsilon
(numeric(1): 1e-15) Slack parameter used when finite_bounds != "no". Passed to arf::forde.
round
(logical(1): TRUE) Whether to round continuous variables back to their original precision in sampling. Can be overridden in $sample() calls.
stepsize
(numeric(1): 0) Number of rows of evidence to process at a time when parallel is TRUE. Default (0) spreads evidence evenly over registered workers. Can be overridden in $sample() calls.
verbose
(logical(1): FALSE) Whether to print progress messages. Default is FALSE (arf's default is TRUE). Can be overridden in $sample() calls.
parallel
(logical(1): FALSE) Whether to use parallel processing via foreach. See examples in arf::forge(). Can be overridden in $sample() calls.
...
Additional arguments passed to arf::adversarial_rf.
Method sample()
Sample from stored task. Parameters conditioning_set, round, stepsize, verbose, and parallel use hierarchical resolution: function argument > stored param_set value > hard-coded default.
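The resolution order can be illustrated with a small base R sketch. The helper resolve_param() and the stored list are hypothetical stand-ins written for this example; they are not part of xplainfi and only mimic the precedence described above.

```r
# Hypothetical helper (not part of xplainfi): return the first non-NULL value
# in the order function argument > stored value > hard-coded default.
resolve_param = function(arg, stored, default) {
  if (!is.null(arg)) return(arg)
  if (!is.null(stored)) return(stored)
  default
}

# Stand-in for values stored at construction time.
stored = list(round = FALSE, verbose = NULL)

# No argument given in the call: the stored value is used.
from_stored = resolve_param(NULL, stored$round, TRUE)

# Argument given in the call: it overrides the stored value.
from_arg = resolve_param(TRUE, stored$round, TRUE)

# Neither argument nor stored value: the hard-coded default is used.
from_default = resolve_param(NULL, stored$verbose, FALSE)
```

Under this scheme, passing NULL for an argument in $sample() means "use whatever was set at construction", which is why the $sample() arguments default to NULL.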
Usage
ARFSampler$sample(
  feature,
  row_ids = NULL,
  conditioning_set = NULL,
  round = NULL,
  stepsize = NULL,
  verbose = NULL,
  parallel = NULL
)
Arguments
feature
(character) Feature(s) to sample.
row_ids
(integer() | NULL) Row IDs to use. If NULL, uses all rows.
conditioning_set
(character | NULL) Features to condition on.
round
(logical(1) | NULL) Round continuous variables.
stepsize
(numeric(1) | NULL) Batch size for parallel processing.
verbose
(logical(1) | NULL) Print progress messages.
parallel
(logical(1) | NULL) Use parallel processing.
Method sample_newdata()
Sample from external data (e.g., test set). See $sample() for parameter details.
Usage
ARFSampler$sample_newdata(
  feature,
  newdata,
  conditioning_set = NULL,
  round = NULL,
  stepsize = NULL,
  verbose = NULL,
  parallel = NULL
)
Arguments
feature
(character) Feature(s) to sample.
newdata
(data.table) External data to use.
conditioning_set
(character | NULL) Features to condition on.
round
(logical(1) | NULL) Round continuous variables.
stepsize
(numeric(1) | NULL) Batch size for parallel processing.
verbose
(logical(1) | NULL) Print progress messages.
parallel
(logical(1) | NULL) Use parallel processing.
Examples
library(mlr3)
library(xplainfi)
task = tgen("2dnormals")$generate(n = 100)
# Create sampler with default parameters
sampler = ARFSampler$new(task, conditioning_set = "x2", verbose = FALSE)
# Sample using row_ids from stored task
sampled_data = sampler$sample("x1")
# Or use external data
data = task$data()
sampled_data_ext = sampler$sample_newdata("x1", newdata = data)
# Example with custom parameters
sampler_custom = ARFSampler$new(task, round = FALSE)
sampled_custom = sampler_custom$sample("x1")