The goal of xplainfi is to collect common feature importance methods under a unified and extensible interface.
For now, it is built specifically around mlr3, because the abstractions mlr3 provides for learners, tasks, measures, etc. greatly simplify the implementation of importance measures.
Installation
You can install the development version of xplainfi from GitHub like so:
# install.packages("pak")
pak::pak("jemus42/xplainfi")
Example: PFI
Here is a basic example of how to calculate permutation feature importance (PFI) for a given learner and task, using repeated cross-validation as the resampling strategy and permuting each feature 5 times within each resampling iteration:
library(xplainfi)
library(mlr3)
library(mlr3learners)
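task = tsk("german_credit")                       # binary classification task shipped with mlr3
learner = lrn("classif.ranger", num.trees = 100)  # random forest via mlr3learners (requires ranger)
measure = msr("classif.ce")                       # classification error, lower is better
pfi = PFI$new(
  task = task,
  learner = learner,
  measure = measure,
  resampling = rsmp("repeated_cv", folds = 3, repeats = 2),  # 3 folds x 2 repeats = 6 iterations
  iters_perm = 5                                             # permutations per feature and iteration
)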
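To make the idea behind PFI concrete, here is a conceptual sketch in base R that does not use xplainfi at all. It assumes the standard PFI recipe (fit once, permute one feature at a time in the test data, record the increase in loss, average over permutations), which we assume xplainfi follows given the orig/perm columns it reports. The data (mtcars), the lm() model, the MSE loss, the 22/10 holdout split and the helper pfi_sketch() are all illustrative choices, not part of xplainfi's API.

```r
# Conceptual sketch of permutation feature importance (PFI) in base R.
# Assumptions (not from xplainfi): lm() on mtcars, a single holdout split,
# mean squared error as the loss, and the hypothetical helper pfi_sketch().
set.seed(1)
train_idx <- sample(nrow(mtcars), size = 22)
train <- mtcars[train_idx, ]
test <- mtcars[-train_idx, ]

fit <- lm(mpg ~ wt + hp + qsec, data = train)

# Loss of a fitted model on a data set: mean squared error on mpg
mse <- function(model, data) mean((data$mpg - predict(model, newdata = data))^2)

pfi_sketch <- function(model, data, features, iters_perm = 5) {
  loss_orig <- mse(model, data)
  sapply(features, function(feat) {
    # Permute one feature in the test data while keeping the model fixed,
    # record the increase in loss, and average over permutation iterations.
    deltas <- replicate(iters_perm, {
      permuted <- data
      permuted[[feat]] <- sample(permuted[[feat]])
      mse(model, permuted) - loss_orig
    })
    mean(deltas)
  })
}

imp <- pfi_sketch(fit, test, features = c("wt", "hp", "qsec"))
print(imp)
```

A positive value means the loss got worse once the feature's link to the target was broken; values near zero or slightly negative suggest the model barely relies on that feature.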
Compute and print PFI scores:
pfi$compute()
#> Key: <feature>
#> feature importance
#> <char> <num>
#> 1: age 9.929091e-04
#> 2: amount 1.288294e-02
#> 3: credit_history 1.218554e-02
#> 4: duration 1.598605e-02
#> 5: employment_duration 3.890717e-03
#> 6: foreign_worker -1.202700e-03
#> 7: housing -8.016999e-04
#> 8: installment_rate 3.599408e-03
#> 9: job -1.002799e-03
#> 10: number_credits -2.402103e-03
#> 11: other_debtors 5.898713e-03
#> 12: other_installment_plans -9.095922e-04
#> 13: people_liable 5.994018e-07
#> 14: personal_status_sex -1.807496e-03
#> 15: present_residence 6.944070e-04
#> 16: property 1.291111e-03
#> 17: purpose 2.486918e-03
#> 18: savings 1.819694e-02
#> 19: status 3.978829e-02
#> 20: telephone 1.293209e-03
#> feature importance
The aggregated scores are stored in pfi$importance and can be retrieved there later.
When PFI is computed with multiple resampling iterations and/or multiple permutation iterations, the individual scores can be retrieved as a data.table:
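A common next step is ranking features by their aggregated importance. The sketch below transcribes a subset of the values printed above into a plain data.frame so that it runs without fitting anything; with the real object you would start from pfi$importance instead. Since a data.table is also a data.frame, the same base R indexing should work on it, though `pfi$importance[order(-importance)]` is the more idiomatic data.table form.

```r
# Rank features by aggregated importance. Values are a hand-transcribed subset
# of the pfi$compute() output above; in practice use pfi$importance directly.
imp <- data.frame(
  feature = c("age", "amount", "credit_history", "duration",
              "savings", "status", "foreign_worker"),
  importance = c(9.929091e-04, 1.288294e-02, 1.218554e-02, 1.598605e-02,
                 1.819694e-02, 3.978829e-02, -1.202700e-03),
  stringsAsFactors = FALSE
)

# Sort from most to least important
ranked <- imp[order(imp$importance, decreasing = TRUE), ]
print(head(ranked, 3))
```

Negative values, such as the one for foreign_worker, mean the error was slightly lower after permutation, which is typical for features the model hardly uses.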
pfi$scores
#> Key: <feature, iter_rsmp, iter_perm>
#> feature iter_rsmp iter_perm classif.ce_orig classif.ce_perm importance
#> <char> <int> <int> <num> <num> <num>
#> 1: age 1 1 0.2095808 0.2305389 0.020958084
#> 2: age 1 2 0.2095808 0.2335329 0.023952096
#> 3: age 1 3 0.2095808 0.2275449 0.017964072
#> 4: age 1 4 0.2095808 0.2215569 0.011976048
#> 5: age 1 5 0.2095808 0.2155689 0.005988024
#> ---
#> 596: telephone 6 1 0.2612613 0.2432432 -0.018018018
#> 597: telephone 6 2 0.2612613 0.2552553 -0.006006006
#> 598: telephone 6 3 0.2612613 0.2612613 0.000000000
#> 599: telephone 6 4 0.2612613 0.2522523 -0.009009009
#> 600: telephone 6 5 0.2612613 0.2402402 -0.021021021
Here, iter_rsmp is the resampling iteration, ranging from 1 to 6 (2 repeats of 3-fold cross-validation, i.e., 3 * 2 = 6 iterations), and iter_perm is the permutation iteration, ranging from 1 to 5 in this case. Each importance value is the difference between the error after permutation (classif.ce_perm) and the original error (classif.ce_orig). While pfi$importance contains the means across all iterations, pfi$scores lets you aggregate them manually in any way you see fit.
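As an example of manual aggregation, the sketch below computes the mean and standard deviation of the importance per feature. To stay self-contained it uses only the ten rows of pfi$scores printed above (age in resampling iteration 1, telephone in iteration 6), so the means differ from pfi$importance, which averages over all 30 rows per feature. It also checks the expected size of the full table.

```r
# Per-feature mean and standard deviation of importance across iterations.
# Rows transcribed from the printed pfi$scores above; in practice, pass
# pfi$scores (a data.table, hence also a data.frame) to aggregate() directly.
scores <- data.frame(
  feature   = rep(c("age", "telephone"), each = 5),
  iter_rsmp = rep(c(1L, 6L), each = 5),
  iter_perm = rep(1:5, times = 2),
  classif.ce_orig = rep(c(0.2095808, 0.2612613), each = 5),
  classif.ce_perm = c(0.2305389, 0.2335329, 0.2275449, 0.2215569, 0.2155689,
                      0.2432432, 0.2552553, 0.2612613, 0.2522523, 0.2402402),
  stringsAsFactors = FALSE
)

# Each score is the permuted error minus the original error
scores$importance <- scores$classif.ce_perm - scores$classif.ce_orig

# One row per feature; the importance column becomes a matrix with mean and sd
agg <- aggregate(importance ~ feature, data = scores,
                 FUN = function(x) c(mean = mean(x), sd = sd(x)))
print(agg)

# Expected number of rows in the full pfi$scores table
n_rsmp <- 3 * 2    # folds * repeats
n_perm <- 5        # iters_perm
n_features <- 20   # features in german_credit
n_rows <- n_rsmp * n_perm * n_features
print(n_rows)
```

With data.table syntax, something like `pfi$scores[, .(mean = mean(importance), sd = sd(importance)), by = feature]` is the idiomatic equivalent.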
In the simplest case, you run PFI with a single resampling iteration (holdout) and a single permutation iteration, and pfi$importance will contain the same importance values as pfi$scores.
pfi_single = PFI$new(
task = task,
learner = learner,
measure = measure
)
pfi_single$compute()
#> Key: <feature>
#> feature importance
#> <char> <num>
#> 1: age 0.003003003
#> 2: amount 0.012012012
#> 3: credit_history 0.024024024
#> 4: duration 0.012012012
#> 5: employment_duration 0.006006006
#> 6: foreign_worker 0.000000000
#> 7: housing 0.006006006
#> 8: installment_rate 0.024024024
#> 9: job -0.003003003
#> 10: number_credits -0.003003003
#> 11: other_debtors 0.012012012
#> 12: other_installment_plans 0.006006006
#> 13: people_liable 0.009009009
#> 14: personal_status_sex 0.003003003
#> 15: present_residence 0.006006006
#> 16: property 0.003003003
#> 17: purpose 0.015015015
#> 18: savings 0.003003003
#> 19: status 0.054054054
#> 20: telephone -0.003003003
#> feature importance
pfi_single$scores
#> Key: <feature, iter_rsmp, iter_perm>
#> feature iter_rsmp iter_perm classif.ce_orig classif.ce_perm
#> <char> <int> <int> <num> <num>
#> 1: age 1 1 0.2732733 0.2762763
#> 2: amount 1 1 0.2732733 0.2852853
#> 3: credit_history 1 1 0.2732733 0.2972973
#> 4: duration 1 1 0.2732733 0.2852853
#> 5: employment_duration 1 1 0.2732733 0.2792793
#> 6: foreign_worker 1 1 0.2732733 0.2732733
#> 7: housing 1 1 0.2732733 0.2792793
#> 8: installment_rate 1 1 0.2732733 0.2972973
#> 9: job 1 1 0.2732733 0.2702703
#> 10: number_credits 1 1 0.2732733 0.2702703
#> 11: other_debtors 1 1 0.2732733 0.2852853
#> 12: other_installment_plans 1 1 0.2732733 0.2792793
#> 13: people_liable 1 1 0.2732733 0.2822823
#> 14: personal_status_sex 1 1 0.2732733 0.2762763
#> 15: present_residence 1 1 0.2732733 0.2792793
#> importance
#> <num>
#> 1: 0.003003003
#> 2: 0.012012012
#> 3: 0.024024024
#> 4: 0.012012012
#> 5: 0.006006006
#> 6: 0.000000000
#> 7: 0.006006006
#> 8: 0.024024024
#> 9: -0.003003003
#> 10: -0.003003003
#> 11: 0.012012012
#> 12: 0.006006006
#> 13: 0.009009009
#> 14: 0.003003003
#> 15: 0.006006006
#> [ reached getOption("max.print") -- omitted 6 rows ]
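With one resampling and one permutation iteration, each feature has exactly one row in pfi_single$scores, so its "mean" is simply that row's classif.ce_perm minus classif.ce_orig. The sketch below checks this for a few features, using values transcribed from the output above rather than the live objects.

```r
# A few rows of pfi_single$scores and pfi_single$importance, transcribed from
# the output above; in practice, compare the two tables directly.
ce_orig <- 0.2732733
ce_perm <- c(age = 0.2762763, amount = 0.2852853,
             credit_history = 0.2972973, job = 0.2702703)
importance_printed <- c(age = 0.003003003, amount = 0.012012012,
                        credit_history = 0.024024024, job = -0.003003003)

# Single iteration: the per-feature mean is just one difference
importance_recomputed <- ce_perm - ce_orig
print(cbind(importance_recomputed, importance_printed))
all.equal(importance_recomputed, importance_printed, tolerance = 1e-5)
```

With the objects from above, a comparison along the lines of `all.equal(pfi_single$importance$importance, pfi_single$scores$importance)` should also return TRUE, assuming both tables are sorted by feature, as their keys suggest.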