
Implements conditional sampling using Adversarial Random Forests (ARF). ARF can handle mixed data types (continuous and categorical) and provides flexible conditional sampling by modeling the joint distribution.

Details

The ARFSampler fits an Adversarial Random Forest model on the task data, then uses it to generate samples from \(P(X_j \mid X_{-j})\), where \(X_j\) is the feature of interest and \(X_{-j}\) are the conditioning features.
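
A minimal base-R sketch of the idea behind conditional sampling from a fitted forest density, under simplifying assumptions: each leaf is summarized by a coverage weight and independent Gaussian marginals for two continuous features. The leaf table `leaves` and the helper `sample_x1_given_x2()` are hypothetical stand-ins, not the contents of psi; the real model estimated by arf::forde is richer (for example, it also covers categorical features). Roughly, leaf weights are rescaled by the likelihood of the observed conditioning value, a leaf is drawn, and \(X_j\) is drawn from that leaf's marginal.

```r
# Hypothetical toy leaf table: coverage weight plus Gaussian mean/sd
# for two continuous features x1 and x2 in each leaf.
leaves = data.frame(
  weight = c(0.5, 0.3, 0.2),
  mu1 = c(-1, 0, 2), sd1 = c(0.5, 0.5, 0.5),
  mu2 = c(-1, 1, 3), sd2 = c(0.5, 0.5, 0.5)
)

# Draw x1 given an observed value of x2:
# 1. reweight each leaf by weight * density of x2_obs under that leaf,
# 2. sample leaves with the normalized weights,
# 3. sample x1 from the chosen leaves' marginals.
sample_x1_given_x2 = function(x2_obs, leaves, n = 1L) {
  w = leaves$weight * dnorm(x2_obs, mean = leaves$mu2, sd = leaves$sd2)
  w = w / sum(w)
  idx = sample.int(nrow(leaves), size = n, replace = TRUE, prob = w)
  rnorm(n, mean = leaves$mu1[idx], sd = leaves$sd1[idx])
}

set.seed(1)
draws = sample_x1_given_x2(x2_obs = 3, leaves = leaves, n = 1000)
mean(draws)  # near 2: x2 = 3 is plausible almost only under leaf 3
```

Conditioning on x2 = -1 instead would shift nearly all of the weight to leaf 1, so the draws of x1 would center near -1.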

References

Watson, D. S., Blesch, K., Kapar, J., Wright, M. N. (2023). “Adversarial Random Forests for Density Estimation and Generative Modeling.” In Proceedings of The 26th International Conference on Artificial Intelligence and Statistics, 5357–5375. https://proceedings.mlr.press/v206/watson23a.html.

Blesch, K., Koenen, N., Kapar, J., Golchian, P., Burk, L., Loecher, M., Wright, M. N. (2025). “Conditional Feature Importance with Generative Modeling Using Adversarial Random Forests.” Proceedings of the AAAI Conference on Artificial Intelligence, 39(15), 15596–15604. doi:10.1609/aaai.v39i15.33712.

Public fields

feature_types

(character()) Feature types supported by the sampler. Will be checked against the provided mlr3::Task to ensure compatibility.

arf_model

Adversarial Random Forest model created by arf::adversarial_rf.

psi

Distribution parameters estimated by arf::forde.

Methods

Method new()

Creates a new instance of the ARFSampler class. To fit the ARF in parallel, register a parallel backend first (see arf::arf) and set parallel = TRUE.

Usage

ARFSampler$new(
  task,
  conditioning_set = NULL,
  num_trees = 10L,
  min_node_size = 2L,
  finite_bounds = "no",
  epsilon = 1e-15,
  round = TRUE,
  stepsize = 0,
  verbose = FALSE,
  parallel = FALSE,
  ...
)

Arguments

task

(mlr3::Task) Task to sample from.

conditioning_set

(character | NULL) Default conditioning set to use in $sample(). This parameter only affects the sampling behavior, not the ARF model fitting.

num_trees

(integer(1): 10L) Number of trees for ARF. Passed to arf::adversarial_rf.

min_node_size

(integer(1): 2L) Minimum node size for ARF. Passed to arf::adversarial_rf.

finite_bounds

(character(1): "no") How to handle variable bounds. Passed to arf::forde. Default is "no" for compatibility. "local" may improve extrapolation but can cause issues with some data.

epsilon

(numeric(1): 1e-15) Slack parameter for when finite_bounds != "no". Passed to arf::forde.

round

(logical(1): TRUE) Whether to round continuous variables back to their original precision in sampling. Can be overridden in $sample() calls.

stepsize

(numeric(1): 0) Number of rows of evidence to process at a time when parallel is TRUE. Default (0) spreads evidence evenly over registered workers. Can be overridden in $sample() calls.

verbose

(logical(1): FALSE) Whether to print progress messages. Default is FALSE (arf's default is TRUE). Can be overridden in $sample() calls.

parallel

(logical(1): FALSE) Whether to use parallel processing via foreach. See examples in arf::forge(). Can be overridden in $sample() calls.

...

Additional arguments passed to arf::adversarial_rf.
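
The round argument restores the precision of the original data. A rough base-R sketch of that idea, assuming precision is measured as the largest number of decimal places among the observed values; the helper `decimals()` is hypothetical, ignores values printed in scientific notation, and arf's own rounding logic may differ.

```r
# Hypothetical helper: largest number of decimal places in observed values.
decimals = function(x) {
  s = as.character(x)              # e.g. "2.25", "3"
  frac = sub("^[^.]*\\.?", "", s)  # keep only the digits after the point
  max(nchar(frac))
}

observed = c(1.5, 2.25, 3)       # at most two decimal places
synthetic = c(1.23456, 2.98765)  # raw draws from a fitted density
round(synthetic, decimals(observed))
# [1] 1.23 2.99
```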


Method sample()

Sample from the stored task. The parameters conditioning_set, round, stepsize, verbose, and parallel are resolved hierarchically: function argument > stored param_set value > hard-coded default.
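
The precedence rule can be sketched in base R as follows. `resolve_param()` and the `stored` list are hypothetical illustrations of the documented order, not ARFSampler internals.

```r
# Hypothetical helper mirroring the documented precedence:
# function argument > stored param_set value > hard-coded default.
resolve_param = function(arg, stored, default) {
  if (!is.null(arg)) return(arg)        # explicit argument wins
  if (!is.null(stored)) return(stored)  # then the value stored at construction
  default                               # finally the built-in default
}

# Simulated stored settings, e.g. after ARFSampler$new(task, round = FALSE)
stored = list(round = FALSE, verbose = NULL)

resolve_param(NULL, stored$round, TRUE)     # FALSE: stored value is used
resolve_param(TRUE, stored$round, TRUE)     # TRUE: argument overrides it
resolve_param(NULL, stored$verbose, FALSE)  # FALSE: falls back to default
```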

Usage

ARFSampler$sample(
  feature,
  row_ids = NULL,
  conditioning_set = NULL,
  round = NULL,
  stepsize = NULL,
  verbose = NULL,
  parallel = NULL
)

Arguments

feature

(character) Feature(s) to sample.

row_ids

(integer() | NULL) Row IDs to use. If NULL, uses all rows.

conditioning_set

(character | NULL) Features to condition on. If NULL, the conditioning_set stored at construction is used.

round

(logical(1) | NULL) Round continuous variables.

stepsize

(numeric(1) | NULL) Batch size for parallel processing.

verbose

(logical(1) | NULL) Print progress messages.

parallel

(logical(1) | NULL) Use parallel processing.

Returns

Modified copy with sampled feature(s).
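
How stepsize could translate into batches of evidence rows, as a base-R sketch: `batch_rows()` is a hypothetical helper, and treating stepsize = 0 as an even split across workers follows the description of stepsize in $new() rather than arf's actual code. With a registered foreach backend, the worker count could be taken from foreach::getDoParWorkers().

```r
# Hypothetical helper: split row indices 1..n_rows into batches.
# stepsize = 0 means "divide the rows evenly over n_workers".
batch_rows = function(n_rows, stepsize = 0, n_workers = 1L) {
  if (stepsize == 0) stepsize = ceiling(n_rows / n_workers)
  batch_id = ceiling(seq_len(n_rows) / stepsize)
  split(seq_len(n_rows), batch_id)
}

lengths(batch_rows(100, stepsize = 0, n_workers = 4))  # four batches of 25
lengths(batch_rows(100, stepsize = 30))                # 30, 30, 30, 10
```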


Method sample_newdata()

Sample from external data (e.g., a test set). See $sample() for parameter details.

Usage

ARFSampler$sample_newdata(
  feature,
  newdata,
  conditioning_set = NULL,
  round = NULL,
  stepsize = NULL,
  verbose = NULL,
  parallel = NULL
)

Arguments

feature

(character) Feature(s) to sample.

newdata

(data.table) External data to use.

conditioning_set

(character | NULL) Features to condition on. If NULL, the conditioning_set stored at construction is used.

round

(logical(1) | NULL) Round continuous variables.

stepsize

(numeric(1) | NULL) Batch size for parallel processing.

verbose

(logical(1) | NULL) Print progress messages.

parallel

(logical(1) | NULL) Use parallel processing.

Returns

Modified copy with sampled feature(s).


Method clone()

The objects of this class are cloneable with this method.

Usage

ARFSampler$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Examples

library(mlr3)
task = tgen("2dnormals")$generate(n = 100)
# Create a sampler that conditions on x2 by default
sampler = ARFSampler$new(task, conditioning_set = "x2", verbose = FALSE)
# Sample x1 for all rows of the stored task (row_ids = NULL)
sampled_data = sampler$sample("x1")
# Or use external data
data = task$data()
sampled_data_ext = sampler$sample_newdata("x1", newdata = data)

# Example with custom parameters
sampler_custom = ARFSampler$new(task, round = FALSE)
sampled_custom = sampler_custom$sample("x1")