library(xplainfi)
library(mlr3learners)
#> Loading required package: mlr3

# Data manip and visualization
library(data.table)
library(ggplot2)

Several approaches to inference are possible with the underlying implementation. All of them are work in progress, and the API around them is still being worked out.

Setup

We use a simple linear DGP for demonstration purposes where

  • \(X_1\) and \(X_2\) are strongly correlated (r = 0.7)
  • \(X_1\) and \(X_3\) have an effect on Y
  • \(X_2\) and \(X_4\) don’t have an effect
task = sim_dgp_correlated(n = 2000, r = 0.7)
learner = lrn("regr.ranger", num.trees = 500)
measure = msr("regr.mse")

DAG for correlated features DGP

Variance-correction

When we calculate PFI using an appropriate resampling, such as subsampling with 15 repeats, we can use the approach recommended by Molnar et al. (2023), which is based on the correction proposed by Nadeau & Bengio (2003).

By default, the $importance() method of any importance measure does not output variances or confidence intervals. It merely averages over resampling iterations and over repeats within each resampling iteration (iters_perm here).

pfi = PFI$new(
    task = task,
    learner = learner,
    resampling = rsmp("subsampling", repeats = 15),
    measure = measure,
    iters_perm = 10 # for stability within resampling iters
)

pfi$compute()
pfi$importance()
#> Key: <feature>
#>    feature   importance
#>     <char>        <num>
#> 1:      x1  6.444709393
#> 2:      x2  0.097777103
#> 3:      x3  1.801488945
#> 4:      x4 -0.001099328

If we want unadjusted confidence intervals we can ask for them, but note these are too narrow / optimistic and hence invalid for inference:

pfi_ci_raw = pfi$importance(ci_method = "raw")
pfi_ci_raw
#> Key: <feature>
#>    feature   importance           se   conf_lower   conf_upper
#>     <char>        <num>        <num>        <num>        <num>
#> 1:      x1  6.444709393 0.0710846455  6.292247992  6.597170794
#> 2:      x2  0.097777103 0.0040374343  0.089117668  0.106436538
#> 3:      x3  1.801488945 0.0163236399  1.766478220  1.836499671
#> 4:      x4 -0.001099328 0.0003169572 -0.001779133 -0.000419522

Analogously we can retrieve the Nadeau & Bengio-adjusted standard errors and derived confidence intervals which were demonstrated to have better (but still imperfect) coverage:

pfi_ci_corrected = pfi$importance(ci_method = "nadeau_bengio")
pfi_ci_corrected
#> Key: <feature>
#>    feature   importance           se  conf_lower   conf_upper
#>     <char>        <num>        <num>       <num>        <num>
#> 1:      x1  6.444709393 0.2073141539  6.00006476 6.8893540305
#> 2:      x2  0.097777103 0.0117749378  0.07252237 0.1230318328
#> 3:      x3  1.801488945 0.0476069279  1.69938224 1.9035956504
#> 4:      x4 -0.001099328 0.0009243869 -0.00308194 0.0008832851
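The relationship between the raw and the corrected standard errors can be sketched in base R. This is a minimal illustration, not xplainfi's internal code: the helper name `nb_ci` and the simulated scores are hypothetical, and the sketch assumes the corrected variance of the mean takes the form (1/J + n_test/n_train) · s² for J resampling iterations, combined with a t-quantile on J - 1 degrees of freedom.

```r
# Hypothetical helper: t-based CI for the mean of J per-iteration importance
# scores, optionally with the Nadeau & Bengio variance correction.
nb_ci = function(scores, n_train, n_test, corrected = TRUE, conf_level = 0.95) {
    J = length(scores)
    # Raw variance of the mean treats iterations as independent: s^2 / J.
    # The correction replaces 1/J by (1/J + n_test/n_train).
    var_factor = if (corrected) 1 / J + n_test / n_train else 1 / J
    se = sqrt(var_factor * var(scores))
    q = qt(1 - (1 - conf_level) / 2, df = J - 1)
    c(
        importance = mean(scores),
        se = se,
        conf_lower = mean(scores) - q * se,
        conf_upper = mean(scores) + q * se
    )
}

# Simulated scores standing in for 15 subsampling iterations (not real output)
set.seed(1)
scores = rnorm(15, mean = 6.4, sd = 0.3)

# Assumed split sizes: 2/3 of n = 2000 for training, the rest for testing
raw = nb_ci(scores, n_train = 1333, n_test = 667, corrected = FALSE)
adj = nb_ci(scores, n_train = 1333, n_test = 667, corrected = TRUE)

# Ratio of corrected to raw standard error depends only on J and the split
adj[["se"]] / raw[["se"]]
```

Under these assumptions the ratio equals sqrt(1 + 15 · 667/1333), roughly 2.92, which is in line with the roughly threefold increase in `se` between the two tables above.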

Non-parametric alternative: Empirical quantiles

Both "raw" and "nadeau_bengio" methods assume normally distributed importance scores and use parametric confidence intervals based on the t-distribution. As a non-parametric alternative, we can use empirical quantiles that directly use the resampling distribution to construct confidence-like intervals.

Like the Nadeau & Bengio approach, this method requires independent resampling splits. This means it works with subsampling or bootstrap, but not with cross-validation.

pfi_ci_quantile = pfi$importance(ci_method = "quantile")
pfi_ci_quantile
#> Key: <feature>
#>    feature   importance   conf_lower   conf_upper
#>     <char>        <num>        <num>        <num>
#> 1:      x1  6.444709393  5.981573776 6.8770935941
#> 2:      x2  0.097777103  0.075868045 0.1226624975
#> 3:      x3  1.801488945  1.692275614 1.8953539522
#> 4:      x4 -0.001099328 -0.003684845 0.0005401245
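As a rough sketch of the idea, empirical quantile intervals only need the per-iteration importance scores. The scores below are simulated, and the 2.5% and 97.5% probabilities are an assumption for a 95% level; the quantile definition used by xplainfi may differ.

```r
# Simulated per-iteration importance scores for one feature (not real output).
# The exponential draw is deliberately skewed, so a symmetric interval fits poorly.
set.seed(2)
scores = rexp(15, rate = 1)

conf_level = 0.95
alpha = 1 - conf_level

# Point estimate: average over resampling iterations
importance = mean(scores)

# Interval bounds: empirical quantiles of the resampling distribution
bounds = quantile(scores, probs = c(alpha / 2, 1 - alpha / 2), names = FALSE)

c(importance = importance, conf_lower = bounds[1], conf_upper = bounds[2])
```

With only 15 iterations, the extreme quantiles are interpolated between the few most extreme scores, so the bounds are coarse.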

To highlight the differences between parametric and non-parametric approaches, we visualize all methods:

pfi_cis = rbindlist(
    list(
        pfi_ci_raw[, type := "raw"],
        pfi_ci_corrected[, type := "nadeau_bengio"],
        pfi_ci_quantile[, type := "quantile"]
    ),
    fill = TRUE
)

ggplot(pfi_cis, aes(y = feature, color = type)) +
    geom_errorbar(
        aes(xmin = conf_lower, xmax = conf_upper),
        position = position_dodge(width = 0.6),
        width = .5
    ) +
    geom_point(aes(x = importance), position = position_dodge(width = 0.6)) +
    scale_color_brewer(palette = "Set2") +
    labs(
        title = "Parametric & non-parametric CI methods",
        subtitle = "RF with 15 subsampling iterations",
        color = NULL
    ) +
    theme_minimal(base_size = 14) +
    theme(legend.position = "bottom")

The results highlight just how optimistic the unadjusted, raw confidence intervals are.

Conditional predictive impact (CPI)

CPI is implemented by the cpi package, and provides conditional variable importance using knockoffs. It works with mlr3 and its output on our data looks like this:

library(cpi)

resampling = rsmp("cv", folds = 5)
resampling$instantiate(task)
cpi_res = cpi(
    task = task,
    learner = learner,
    resampling = resampling,
    measure = measure,
    test = "t"
)
setDT(cpi_res)
setnames(cpi_res, "Variable", "feature")
cpi_res[, method := "CPI"]

cpi_res
#>    feature           CPI           SE   test statistic      estimate
#>     <char>         <num>        <num> <char>     <num>         <num>
#> 1:      x1  4.4905247594 0.1393416714      t 32.226718  4.4905247594
#> 2:      x2 -0.0026748965 0.0023132871      t -1.156318 -0.0026748965
#> 3:      x3  1.7350915247 0.0551007024      t 31.489463  1.7350915247
#> 4:      x4 -0.0009994758 0.0009142715      t -1.093194 -0.0009994758
#>          p.value        ci.lo method
#>            <num>        <num> <char>
#> 1: 3.644769e-184  4.261221841    CPI
#> 2:  8.761554e-01 -0.006481679    CPI
#> 3: 2.156035e-177  1.644416914    CPI
#> 4:  8.627797e-01 -0.002504016    CPI

CPI with knockoffs

Since xplainfi also includes knockoffs via the KnockoffSampler and the KnockoffGaussianSampler, the latter implementing the second order Gaussian knockoffs also used by default in cpi, we can recreate its results using CFI with the corresponding sampler.

CFI with a knockoff sampler supports CPI inference directly via ci_method = "cpi":

knockoff_gaussian = KnockoffGaussianSampler$new(task)

cfi = CFI$new(
    task = task,
    learner = learner,
    resampling = resampling,
    measure = measure,
    sampler = knockoff_gaussian
)

cfi$compute()

# CPI uses observation-wise losses with one-sided t-test
cfi_cpi_res = cfi$importance(ci_method = "cpi")
cfi_cpi_res
#> Key: <feature>
#>    feature    importance           se   statistic       p.value   conf_lower
#>     <char>         <num>        <num>       <num>         <num>        <num>
#> 1:      x1  4.6485211373 0.1412096720 32.91928289 1.430822e-190  4.416144207
#> 2:      x2 -0.0002372201 0.0029839960 -0.07949746  5.316775e-01 -0.005147732
#> 3:      x3  1.8639780917 0.0561726066 33.18304428 5.072022e-193  1.771539538
#> 4:      x4 -0.0002621368 0.0007881572 -0.33259452  6.302624e-01 -0.001559141
#>    conf_upper
#>         <num>
#> 1:        Inf
#> 2:        Inf
#> 3:        Inf
#> 4:        Inf

# Rename columns to match cpi package output for comparison
setnames(cfi_cpi_res, c("importance", "conf_lower"), c("CPI", "ci.lo"))
cfi_cpi_res[, method := "CFI+Knockoffs"]
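The comment in the chunk above mentions observation-wise losses and a one-sided t-test. A minimal base R sketch of that test, using simulated loss differences rather than real model output, might look as follows. The variable names are illustrative, and this is not the code used by cpi or xplainfi.

```r
# Simulated observation-wise losses (squared errors) for one feature
set.seed(3)
n = 500
loss_baseline = rchisq(n, df = 1)             # loss with the original feature
loss_post = loss_baseline + rnorm(n, 0.2, 1)  # loss after replacing the feature

# Per-observation importance: increase in loss once the feature is replaced
delta = loss_post - loss_baseline

# One-sided test of H0: mean(delta) <= 0 against H1: mean(delta) > 0
tt = t.test(delta, alternative = "greater", conf.level = 0.95)

# The same quantities computed by hand
se = sd(delta) / sqrt(n)
statistic = mean(delta) / se
conf_lower = mean(delta) - qt(0.95, df = n - 1) * se

c(importance = mean(delta), se = se, statistic = statistic,
  p.value = tt$p.value, conf_lower = conf_lower, conf_upper = Inf)
```

Because the alternative is one-sided, the upper bound is infinite by construction, which is why conf_upper is Inf in the output above.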

The results should be very similar to those computed by cpi(), so let’s compare them:

rbindlist(list(cpi_res, cfi_cpi_res), fill = TRUE) |>
    ggplot(aes(y = feature, x = CPI, color = method)) +
    geom_point(position = position_dodge(width = 0.3)) +
    geom_errorbar(
        aes(xmin = ci.lo, xmax = CPI),
        position = position_dodge(width = 0.3),
        width = 0.5
    ) +
    scale_color_brewer(palette = "Dark2") +
    labs(
        title = "CPI and CFI with Knockoff sampler",
        subtitle = "RF with 5-fold CV",
        color = NULL
    ) +
    theme_minimal(base_size = 14) +
    theme(legend.position = "top")

A notable caveat of the knockoff approach is that knockoffs are not readily available for mixed data (with categorical features).

CPI with ARF

An alternative is available using ARF as conditional sampler rather than knockoffs (CITE cARFi), which we can perform analogously:

arf_sampler = ARFSampler$new(
    task = task,
    finite_bounds = "local",
    min_node_size = 20,
    epsilon = 1e-15
)

cfi_arf = CFI$new(
    task = task,
    learner = learner,
    resampling = resampling,
    measure = measure,
    sampler = arf_sampler
)

cfi_arf$compute()

# CPI uses observation-wise losses with one-sided t-test
cfi_arf_res = cfi_arf$importance(ci_method = "cpi")
cfi_arf_res
#> Key: <feature>
#>    feature   importance           se statistic       p.value    conf_lower
#>     <char>        <num>        <num>     <num>         <num>         <num>
#> 1:      x1  4.149913006 0.1377986710 30.115769 6.416347e-165  3.9231492740
#> 2:      x2  0.005623742 0.0036623984  1.535535  6.240533e-02 -0.0004031606
#> 3:      x3  1.703418858 0.0547593855 31.107341 6.653459e-174  1.6133059238
#> 4:      x4 -0.002416235 0.0007663808 -3.152786  9.991794e-01 -0.0036774039
#>    conf_upper
#>         <num>
#> 1:        Inf
#> 2:        Inf
#> 3:        Inf
#> 4:        Inf

# Rename columns to match cpi package output for comparison
setnames(cfi_arf_res, c("importance", "conf_lower"), c("CPI", "ci.lo"))
cfi_arf_res[, method := "CFI+ARF"]

We can now compare all three methods:

rbindlist(list(cpi_res, cfi_cpi_res, cfi_arf_res), fill = TRUE) |>
    ggplot(aes(y = feature, x = CPI, color = method)) +
    geom_point(position = position_dodge(width = 0.3)) +
    geom_errorbar(
        aes(xmin = ci.lo, xmax = CPI),
        position = position_dodge(width = 0.3),
        width = 0.5
    ) +
    scale_color_brewer(palette = "Dark2") +
    labs(
        title = "CPI and CFI with Knockoffs and ARF",
        subtitle = "RF with 5-fold CV",
        color = NULL
    ) +
    theme_minimal(base_size = 14) +
    theme(legend.position = "top")

As expected, the ARF-based results differ more from the two knockoff-based results than those differ from each other, but all three are roughly in agreement.

Note: xplainfi will gain a dedicated interface to perform CPI, but the API is yet to be worked out.

LOCO (WIP)

(CITATION) proposed inference for LOCO based on the median of the differences between the post-refit and the baseline absolute errors:

\[ \theta_j = \mathrm{med}\left( |Y - \hat{f}_{n_1}^{-j}(X)| - |Y - \hat{f}_{n_1}(X)| \big| D_1 \right) \]

If we apply LOCO as implemented in xplainfi using the median absolute error (MAE) as our measure including the median as the aggregation function, we unfortunately get something else, though:

measure_mae = msr("regr.mae")
measure_mae$aggregator = median

loco = LOCO$new(
    task = task,
    learner = learner,
    resampling = rsmp("cv", folds = 3),
    measure = measure_mae
)

loco$compute()
loco$importance()
#> Key: <feature>
#>    feature importance
#>     <char>      <num>
#> 1:      x1 0.99962897
#> 2:      x2 0.04442239
#> 3:      x3 0.63009668
#> 4:      x4 0.04054127

This is not exactly what the authors propose, because $score() calculates the aggregation function (median) for each resampling iteration first, and takes the difference afterwards, i.e.

\[ \theta_j = \mathrm{med}\left(|Y - \hat{f}_{n_1}^{-j}(X)| \big| D_1 \right) - \mathrm{med}\left(|Y - \hat{f}_{n_1}(X)| \big| D_1 \right) \]

In the default case where the arithmetic mean is used, it does not matter whether we calculate the difference of the means or the mean of the differences, but with the median it does.
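A toy example with made-up numbers illustrates the point. The vectors are purely illustrative and not taken from the model above.

```r
# Toy absolute errors for three observations (made-up numbers)
loss_baseline = c(1, 2, 10)
loss_post = c(5, 3, 11)

# Mean: difference of means equals mean of differences
mean(loss_post) - mean(loss_baseline)
mean(loss_post - loss_baseline)

# Median: the two orders of operation disagree
median(loss_post) - median(loss_baseline)  # 5 - 2
median(loss_post - loss_baseline)          # median of c(4, 1, 1)
```

Both mean-based expressions give 2, whereas the difference of medians is 3 and the median of the differences is 1.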

We can, however, reconstruct it by using the observation-wise losses (in this case, the absolute error):

loco_obsloss = loco$obs_loss()
head(loco_obsloss)
#>    feature iter_rsmp iter_refit row_ids loss_baseline loss_post obs_importance
#>     <char>     <int>      <num>   <int>         <num>     <num>          <num>
#> 1:      x1         1          1       5    0.26081483 1.0828744      0.8220595
#> 2:      x1         1          1       8    0.07863945 1.4662680      1.3876286
#> 3:      x1         1          1      18    0.12976167 0.9243277      0.7945660
#> 4:      x1         1          1      19    0.14154518 1.3045495      1.1630043
#> 5:      x1         1          1      20    0.25832021 0.4959166      0.2375964
#> 6:      x1         1          1      21    0.18253098 0.7409787      0.5584477

obs_importance here refers to the difference loss_post - loss_baseline, so

  • loss_baseline \(= |Y - \hat{f}_{n_1}(X)|\)
  • loss_post \(= |Y - \hat{f}_{n_1}^{-j}(X)|\)
  • obs_importance = loss_post - loss_baseline

This means that by taking the median of obs_importance for each feature \(j\) within each resampling iteration \(k\), we can construct \(\theta_j\) as proposed, once for each set \(D_k\):

loco_thetas = loco_obsloss[, list(theta = median(obs_importance)), by = c("feature", "iter_rsmp")]
loco_thetas
#>     feature iter_rsmp      theta
#>      <char>     <int>      <num>
#>  1:      x1         1 0.89502517
#>  2:      x1         2 0.86756440
#>  3:      x1         3 0.77528481
#>  4:      x2         1 0.02726682
#>  5:      x2         2 0.02843028
#>  6:      x2         3 0.01773224
#>  7:      x3         1 0.51379135
#>  8:      x3         2 0.54167072
#>  9:      x3         3 0.56786544
#> 10:      x4         1 0.01789627
#> 11:      x4         2 0.03044097
#> 12:      x4         3 0.02302404

The authors then propose to construct distribution-free confidence intervals, e.g. using a sign test or a Wilcoxon test. We can, for example, use wilcox.test() to compute confidence intervals around the estimated pseudo-median:

loco_wilcox_ci = loco_obsloss[,
    {
        tt <- wilcox.test(
            obs_importance,
            conf.int = TRUE,
            conf.level = 0.95
        )
        .(
            statistic = tt$statistic,
            estimate = tt$estimate, # the pseudomedian importance
            p.value = tt$p.value,
            conf_lower = tt$conf.int[1],
            conf_upper = tt$conf.int[2]
        )
    },
    by = feature
]

loco_wilcox_ci
#>    feature statistic   estimate       p.value conf_lower conf_upper
#>     <char>     <num>      <num>         <num>      <num>      <num>
#> 1:      x1   1943253 0.92960910 1.151861e-291 0.88933069 0.97052252
#> 2:      x2   1262272 0.03119993  3.880922e-24 0.02514172 0.03730949
#> 3:      x3   1891407 0.59426242 1.067356e-260 0.56662125 0.62207703
#> 4:      x4   1351954 0.03074982  3.657070e-42 0.02623189 0.03538269

Note: The above approach still needs to be checked against the literature to ensure that it corresponds to what was proposed and that the results are valid.
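The sign test, named earlier as the other distribution-free option, can be sketched with base R's binom.test(). The data are simulated, and the two-sided alternative is an assumption chosen to mirror the wilcox.test() default. Whether this matches the proposed procedure is subject to the same caveat.

```r
# Simulated observation-wise importances for one feature (not real output)
set.seed(4)
obs_importance = rnorm(200, mean = 0.1, sd = 1)

# Sign test of H0: median(obs_importance) = 0.
# Under H0, the number of positive values is Binomial(n, 0.5).
# Exact zeros carry no sign information and are dropped.
nonzero = obs_importance[obs_importance != 0]
n_pos = sum(nonzero > 0)

st = binom.test(n_pos, length(nonzero), p = 0.5, alternative = "two.sided")

c(estimate = median(obs_importance), n_positive = n_pos, p.value = st$p.value)
```

Unlike the Wilcoxon signed-rank test, the sign test does not assume a symmetric distribution of the differences, at the cost of lower power.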

The main point of this section is to illustrate that access to the intermediate results (i.e. the observation-wise losses) and a free choice of measure leave room for different approaches to inference.