Implements conditional sampling using Adversarial Random Forests (ARF). ARF can handle mixed data types (continuous and categorical) and provides flexible conditional sampling by modeling the joint distribution.

Details

The ARFSampler fits an Adversarial Random Forest model on the task data, then uses it to generate samples from \(P(X_j | X_{-j})\) where \(X_j\) is the feature of interest and \(X_{-j}\) are the conditioning features.
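To make the target distribution concrete, the following base R sketch contrasts conditional sampling from \(P(X_1 | X_2)\) with marginal sampling. It does not use ARF: it assumes a toy standard bivariate normal with a known correlation (rho, chosen here for illustration), where the conditional distribution is available in closed form. ARF plays the same role when the joint distribution has to be estimated from data.

```r
set.seed(1)
n = 5000
rho = 0.8  # assumed correlation for this toy example

# Toy data: standard bivariate normal with correlation rho
x2 = rnorm(n)
x1 = rho * x2 + sqrt(1 - rho^2) * rnorm(n)

# Conditional sampling: X1 | X2 = x2 follows N(rho * x2, 1 - rho^2)
x1_cond = rnorm(n, mean = rho * x2, sd = sqrt(1 - rho^2))

# Marginal sampling for contrast: permuting x1 ignores x2 entirely
x1_marg = sample(x1)

cor_cond = cor(x1_cond, x2)  # stays close to rho
cor_marg = cor(x1_marg, x2)  # close to 0
```

Conditionally sampled values keep the dependence on the conditioning feature, whereas the permuted values break it.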

References

Watson, David S., Blesch, Kristin, Kapar, Jan, Wright, Marvin N. (2023). “Adversarial Random Forests for Density Estimation and Generative Modeling.” In Proceedings of The 26th International Conference on Artificial Intelligence and Statistics, 5357–5375. https://proceedings.mlr.press/v206/watson23a.html.

Blesch, Kristin, Koenen, Niklas, Kapar, Jan, Golchian, Pegah, Burk, Lukas, Loecher, Markus, Wright, Marvin N. (2025). “Conditional Feature Importance with Generative Modeling Using Adversarial Random Forests.” Proceedings of the AAAI Conference on Artificial Intelligence, 39(15), 15596–15604. doi:10.1609/aaai.v39i15.33712.

Public fields

arf_model

Adversarial Random Forest model

psi

Distribution parameters estimated from ARF

Methods


Method new()

Creates a new instance of the ARFSampler class. To fit the ARF in parallel, set arf_args = list(parallel = TRUE) and register a parallel backend (see arf::arf).

Usage

ARFSampler$new(
  task,
  conditioning_set = NULL,
  finite_bounds = "no",
  round = TRUE,
  stepsize = 0,
  verbose = FALSE,
  parallel = FALSE,
  arf_args = NULL
)

Arguments

task

(mlr3::Task) Task to sample from

conditioning_set

(character | NULL) Default conditioning set to use in $sample(). This parameter only affects the sampling behavior, not the ARF model fitting.

finite_bounds

(character(1): "no") Passed to arf::forde(). Default is "no" for compatibility. "local" may improve extrapolation but can cause issues with some data.

round

(logical(1): TRUE) Whether to round continuous variables back to their original precision.

stepsize

(numeric(1): 0) Number of rows of evidence to process at a time when parallel is TRUE. The default (0) spreads evidence evenly over the registered workers.

verbose

(logical(1): FALSE) Whether to print progress messages. Defaults to FALSE here, whereas the default in arf is TRUE.

parallel

(logical(1): FALSE) Whether to use parallel processing via foreach. See examples in arf::forge().

arf_args

(list) Additional arguments passed to arf::adversarial_rf.
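As a way to picture what stepsize controls, the base R sketch below splits evidence rows into chunks that are processed one at a time. chunk_rows and its n_workers argument are hypothetical, and treating stepsize = 0 as "one roughly equal chunk per worker" is an assumption based on the description above, not the code arf actually runs.

```r
# Hypothetical helper: split row indices into chunks of at most `stepsize` rows
chunk_rows = function(n_rows, stepsize = 0, n_workers = 1) {
  if (stepsize == 0) {
    # Assumed reading of the default: one roughly equal chunk per worker
    stepsize = ceiling(n_rows / n_workers)
  }
  idx = seq_len(n_rows)
  unname(split(idx, ceiling(idx / stepsize)))
}

# Default: 10 rows spread over 4 workers
chunks_default = chunk_rows(10, stepsize = 0, n_workers = 4)

# Fixed step size: at most 4 rows per chunk
chunks_fixed = chunk_rows(10, stepsize = 4)
lengths(chunks_fixed)  # 4 4 2
```

Smaller chunks lower the memory needed per step at the cost of more iterations.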


Method sample()

Samples values for the feature(s) of interest conditionally on other features, using the fitted ARF.

Usage

ARFSampler$sample(
  feature,
  data = self$task$data(),
  conditioning_set = NULL,
  round = NULL,
  stepsize = NULL,
  verbose = NULL,
  parallel = NULL,
  ...
)

Arguments

feature

(character) Feature(s) of interest to sample (can be single or multiple)

data

(data.table) Data containing conditioning features. Defaults to $task$data(), but typically a dedicated test set is provided.

conditioning_set

(character(n) | NULL) Features to condition on. If NULL, uses the stored parameter if available, otherwise defaults to all other features.

round

(logical(1) | NULL) Whether to round continuous variables. If NULL, uses the stored parameter value.

stepsize

(numeric(1) | NULL) Number of rows of evidence to process at a time when parallel is TRUE. If NULL, uses the stored parameter value.

verbose

(logical(1) | NULL) Whether to print progress messages. If NULL, uses the stored parameter value.

parallel

(logical(1) | NULL) Whether to use parallel processing. If NULL, uses the stored parameter value.

...

Further arguments passed to arf::forge().

Returns

Modified copy of the input data with the feature(s) sampled conditionally
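The NULL defaults described in the arguments above amount to a fallback chain: a value given in the call wins, otherwise the value stored at construction is used, and conditioning_set finally falls back to all other features. The base R sketch below illustrates that logic. resolve_param and resolve_conditioning_set are hypothetical helpers written for this illustration, not functions of the package.

```r
# Hypothetical helper: prefer the call-time value, else the stored one
resolve_param = function(call_value, stored_value) {
  if (is.null(call_value)) stored_value else call_value
}

# Hypothetical helper mirroring the documented conditioning_set fallback
resolve_conditioning_set = function(call_value, stored_value, feature, all_features) {
  cs = resolve_param(call_value, stored_value)
  # Nothing given in the call or at construction: condition on all other features
  if (is.null(cs)) setdiff(all_features, feature) else cs
}

features = c("x1", "x2", "x3")

cs_all = resolve_conditioning_set(NULL, NULL, "x1", features)     # "x2" "x3"
cs_stored = resolve_conditioning_set(NULL, "x2", "x1", features)  # "x2"
cs_call = resolve_conditioning_set("x3", "x2", "x1", features)    # "x3"
```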


Method clone()

The objects of this class are cloneable with this method.

Usage

ARFSampler$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Examples

library(mlr3)
task = tgen("2dnormals")$generate(n = 100)
# Create sampler with a stored default conditioning set
sampler = ARFSampler$new(task, conditioning_set = "x2", verbose = FALSE)
data = task$data()
# Uses the stored parameters (conditions on "x2")
sampled_data = sampler$sample("x1", data)

# Example with custom parameters: no rounding, no stored conditioning set
sampler_custom = ARFSampler$new(task, round = FALSE)
sampled_custom = sampler_custom$sample("x1", data)