SHAPclust clusters data samples by their SHAP values using k-means, identifying subgroups of individuals with similar patterns of feature contributions.
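The core idea can be sketched with base R's stats::kmeans(). Each sample is represented by its vector of SHAP values (one column per feature), and samples are grouped by how similar those vectors are rather than by their raw feature values. The matrix below is synthetic and the feature names are borrowed from the BreastCancer example; this is an illustration of the technique, not the package's internal code.

```r
set.seed(246)

# Synthetic wide SHAP matrix: 40 samples x 3 features.
# The first 20 samples get negative contributions, the last 20 positive ones.
shap_wide <- rbind(
  matrix(rnorm(20 * 3, mean = -0.1, sd = 0.02), ncol = 3),
  matrix(rnorm(20 * 3, mean =  0.1, sd = 0.02), ncol = 3)
)
colnames(shap_wide) <- c("Bare.nuclei", "Cell.size", "Cell.shape")

# Cluster the samples on their SHAP vectors.
# nstart = 10 (several random starts) is an addition for robustness in this sketch.
km <- stats::kmeans(shap_wide, centers = 2, iter.max = 1000,
                    nstart = 10, algorithm = "Hartigan-Wong")

# One cluster label per sample, plus the mean SHAP profile of each cluster
table(km$cluster)
round(km$centers, 2)
```

The cluster centers are mean SHAP profiles, which is what makes the subgroups interpretable: each center says which features push that subgroup's predictions up or down.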

SHAPclust(
  task,
  trained_model,
  splits,
  shap_Mean_wide,
  shap_Mean_long,
  num_of_clusters = 4,
  seed = 246,
  subset = 1,
  algorithm = "Hartigan-Wong",
  iter.max = 1000
)
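shap_Mean_wide and shap_Mean_long carry the same SHAP values in two layouts. As a sketch, assuming a wide table with one row per sample (sample_num) and one column per feature, and a long table with the sample_num, feature and Phi columns seen in the printed example output further down, base R can convert one into the other. The two values per feature are rounded from that printed output. The real eSHAP_plot() outputs contain additional columns (for example mean_phi and f_val), so this only illustrates the layouts.

```r
# Hypothetical wide SHAP table: one row per sample, one column per feature
shap_wide <- data.frame(
  sample_num  = 1:2,
  Bare.nuclei = c(-0.127, 0.000),
  Cell.size   = c(-0.165, 0.008)
)
features <- setdiff(names(shap_wide), "sample_num")

# Long layout: one row per (sample, feature) pair, SHAP value in column Phi
shap_long <- data.frame(
  sample_num = rep(shap_wide$sample_num, times = length(features)),
  feature    = rep(features, each = nrow(shap_wide)),
  Phi        = unlist(shap_wide[features], use.names = FALSE),
  stringsAsFactors = FALSE
)
shap_long <- shap_long[order(shap_long$sample_num, shap_long$feature), ]
shap_long
```

The wide layout is what k-means needs (one row per sample), while the long layout suits faceted plotting with one row per feature contribution.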

Arguments

task

an mlr3 task for binary classification

trained_model

an mlr3 trained learner object

splits

an mlr3 split object (e.g. the list returned by mlr3::partition()) defining the train and test sets

shap_Mean_wide

the data frame of SHAP values in wide format, as returned by eSHAP_plot() (second element of its output)

shap_Mean_long

the data frame of SHAP values in long format, as returned by eSHAP_plot() (third element of its output)

num_of_clusters

number of clusters to form from the SHAP values, default: 4

seed

an integer seed for reproducibility, default: 246

subset

fraction of instances to use, from 0 to 1, where 1 means all instances; default: 1

algorithm

character string naming the k-means algorithm: "Hartigan-Wong" (default), "Lloyd", "Forgy" or "MacQueen"

iter.max

maximum number of k-means iterations allowed, default: 1000
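The algorithm choices listed above are exactly those of stats::kmeans(), so it is reasonable to assume (not confirmed here) that num_of_clusters, algorithm and iter.max are passed on to it, and that seed matters because k-means starts from randomly chosen centers. The sketch below, on synthetic data, shows an elbow check for choosing the number of clusters, how fixing the seed makes results repeatable, and how to compare algorithms and inspect the iterations each one used.

```r
set.seed(246)

# Synthetic SHAP matrix with three groups of 15 samples (3 features each)
shap_wide <- rbind(
  matrix(rnorm(45, mean = -0.1, sd = 0.01), ncol = 3),
  matrix(rnorm(45, mean =  0.0, sd = 0.01), ncol = 3),
  matrix(rnorm(45, mean =  0.1, sd = 0.01), ncol = 3)
)

# Number of clusters: look for the "elbow" in the total within-cluster sum of squares
wss <- sapply(1:6, function(k) {
  stats::kmeans(shap_wide, centers = k, iter.max = 1000, nstart = 10)$tot.withinss
})
round(wss, 5)

# Seed: the same seed gives the same random starts, hence the same clusters
run_kmeans <- function(seed, algorithm) {
  set.seed(seed)
  stats::kmeans(shap_wide, centers = 3, iter.max = 1000, algorithm = algorithm)
}
identical(run_kmeans(246, "Hartigan-Wong")$cluster,
          run_kmeans(246, "Hartigan-Wong")$cluster)

# Algorithm and iter.max: compare methods and check how many iterations each needed
for (alg in c("Hartigan-Wong", "Lloyd", "Forgy", "MacQueen")) {
  fit <- run_kmeans(246, alg)
  cat(sprintf("%-13s iterations: %d  tot.withinss: %.5f\n",
              alg, as.integer(fit$iter), fit$tot.withinss))
}
```

If an algorithm reports as many iterations as iter.max, stats::kmeans() also warns that it did not converge, which is a signal to raise iter.max.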

Value

A list containing five elements:

shap_plot_onerow

An interactive plot of the SHAP values for each feature, with one facet per cluster.

combined_plot

A ggplot2 figure combining the confusion matrices of all clusters, showing the model's performance within each identified subgroup.

kmeans_fvals_desc

A summary table containing statistical descriptions of the clusters based on feature values.

shap_Mean_wide_kmeans

A data frame containing clustered SHAP values along with predictions and ground truth information.

kmeans_info

Information about the k-means clustering process, including cluster centers and assignment details.
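The per-cluster outputs can be approximated in spirit with base R. Assuming a hypothetical per-sample table that holds the cluster label, the true class, the predicted class and one feature value (all values below are synthetic), the sketch tabulates one confusion matrix per cluster, as visualised by combined_plot, and summarises a feature by cluster, similar in intent to kmeans_fvals_desc. The exact statistics SHAPclust reports are not documented here, so mean and median are illustrative choices.

```r
set.seed(246)

# Hypothetical per-sample results: cluster label, true class and one feature value
n <- 30
res <- data.frame(
  cluster   = rep(1:3, each = 10),
  truth     = factor(sample(c("benign", "malignant"), n, replace = TRUE),
                     levels = c("benign", "malignant")),
  Cell.size = sample(1:10, n, replace = TRUE)
)

# Predictions: copy the truth, then flip roughly 20% of labels to mimic model errors
flip <- runif(n) < 0.2
res$pred <- res$truth
res$pred[flip] <- ifelse(res$truth[flip] == "benign", "malignant", "benign")

# One 2 x 2 confusion matrix per cluster
conf_by_cluster <- lapply(split(res, res$cluster), function(d) {
  table(truth = d$truth, pred = d$pred)
})
conf_by_cluster[["1"]]

# Per-cluster accuracy
acc <- sapply(conf_by_cluster, function(m) sum(diag(m)) / sum(m))
acc

# Per-cluster description of a feature value
aggregate(Cell.size ~ cluster, data = res,
          FUN = function(v) c(mean = mean(v), median = median(v)))
```

Because truth and pred share the same factor levels, every per-cluster table is 2 x 2 even when a cluster contains only one class.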

References

Zargari Marandi, R., 2024. ExplaineR: an R package to explain machine learning models. Bioinformatics Advances, 4(1), p.vbae049, https://doi.org/10.1093/bioadv/vbae049.

See also

Other functions to visualize and interpret machine learning models: eSHAP_plot.

Examples

# \donttest{
library("explainer")
seed <- 246
set.seed(seed)
# Check that the suggested packages are installed
if (!requireNamespace("mlbench", quietly = TRUE)) stop("mlbench not installed.")
if (!requireNamespace("mlr3learners", quietly = TRUE)) stop("mlr3learners not installed.")
if (!requireNamespace("ranger", quietly = TRUE)) stop("ranger not installed.")
# Load BreastCancer dataset
utils::data("BreastCancer", package = "mlbench")
target_col <- "Class"
positive_class <- "malignant"
# Drop the Id column and rows with missing values
mydata <- BreastCancer[, -1]
mydata <- na.omit(mydata)
# Add synthetic (randomly generated) sex and age features for illustration
sex <- sample(
  c("Male", "Female"),
  size = nrow(mydata),
  replace = TRUE
)
mydata$age <- as.numeric(sample(
  seq(18, 60),
  size = nrow(mydata),
  replace = TRUE
))
mydata$sex <- factor(
  sex,
  levels = c("Male", "Female"),
  labels = c(1, 0)
)
# Define the binary classification task
maintask <- mlr3::TaskClassif$new(
  id = "my_classification_task",
  backend = mydata,
  target = target_col,
  positive = positive_class
)
# Split into train/test sets and train a random forest (ranger)
splits <- mlr3::partition(maintask)
mylrn <- mlr3::lrn(
  "classif.ranger",
  predict_type = "prob"
)
mylrn$train(maintask, splits$train)
SHAP_output <- eSHAP_plot(
  task = maintask,
  trained_model = mylrn,
  splits = splits,
  sample.size = 2, # small for a quick run; use 30 or more in practice
  seed = seed,
  subset = 0.02 # fraction of instances, up to 1
)
#> Warning: Ignoring unknown aesthetics: text
shap_Mean_wide <- SHAP_output[[2]]
shap_Mean_long <- SHAP_output[[3]]
SHAP_plot_clusters <- SHAPclust(
  task = maintask,
  trained_model = mylrn,
  splits = splits,
  shap_Mean_wide = shap_Mean_wide,
  shap_Mean_long = shap_Mean_long,
  num_of_clusters = 3, # your choice
  seed = seed,
  subset = 0.02, # should match the subset used in eSHAP_plot
  algorithm = "Hartigan-Wong",
  iter.max = 10
)
#>     sample_num         feature           Phi cluster     mean_phi     f_val
#>  1:          1     Bare.nuclei -1.269421e-01       2 0.0509097619 0.0000000
#>  2:          1     Bl.cromatin -6.999405e-02       2 0.0414809524 0.2500000
#>  3:          1      Cell.shape -1.032143e-02       2 0.0448956349 0.2222222
#>  4:          1       Cell.size -1.649984e-01       2 0.0500060317 0.0000000
#>  5:          1    Cl.thickness -1.129325e-02       2 0.0100390476 0.0000000
#>  6:          1    Epith.c.size -2.531587e-02       2 0.0093425397 0.2500000
#>  7:          1   Marg.adhesion -6.559524e-03       2 0.0143293651 0.2500000
#>  8:          1         Mitoses -1.144444e-03       2 0.0047004762 0.0000000
#>  9:          1 Normal.nucleoli -7.803810e-02       2 0.0322451587 0.0000000
#> 10:          1             age -3.290873e-03       2 0.0096596825 0.4782609
#> 11:          1             sex  0.000000e+00       2 0.0005798413 1.0000000
#> 12:          2     Bare.nuclei  0.000000e+00       1 0.0509097619 0.0000000
#> 13:          2     Bl.cromatin -7.762302e-03       1 0.0414809524 0.0000000
#> 14:          2      Cell.shape  1.086937e-01       1 0.0448956349 0.3333333
#> 15:          2       Cell.size  7.969444e-03       1 0.0500060317 0.1250000
#> 16:          2    Cl.thickness -7.388889e-03       1 0.0100390476 0.5000000
#> 17:          2    Epith.c.size  0.000000e+00       1 0.0093425397 0.0000000
#> 18:          2   Marg.adhesion  0.000000e+00       1 0.0143293651 0.0000000
#> 19:          2         Mitoses  0.000000e+00       1 0.0047004762 0.0000000
#> 20:          2 Normal.nucleoli  0.000000e+00       1 0.0322451587 0.0000000
#> 21:          2             age  8.115079e-03       1 0.0096596825 1.0000000
#> 22:          2             sex  0.000000e+00       1 0.0005798413 0.0000000
#> 23:          3     Bare.nuclei  2.711071e-02       1 0.0509097619 0.8888889
#> 24:          3     Bl.cromatin -4.012897e-02       1 0.0414809524 0.5000000
#> 25:          3      Cell.shape  4.722222e-05       1 0.0448956349 1.0000000
#> 26:          3       Cell.size -7.780556e-03       1 0.0500060317 1.0000000
#> 27:          3    Cl.thickness  2.715198e-02       1 0.0100390476 1.0000000
#> 28:          3    Epith.c.size -6.109524e-03       1 0.0093425397 1.0000000
#> 29:          3   Marg.adhesion  2.766667e-03       1 0.0143293651 1.0000000
#> 30:          3         Mitoses  8.369444e-03       1 0.0047004762 1.0000000
#> 31:          3 Normal.nucleoli  2.432222e-02       1 0.0322451587 1.0000000
#> 32:          3             age -1.735516e-02       1 0.0096596825 0.0000000
#> 33:          3             sex  1.400000e-03       1 0.0005798413 1.0000000
#> 34:          4     Bare.nuclei  0.000000e+00       1 0.0509097619 0.0000000
#> 35:          4     Bl.cromatin  0.000000e+00       1 0.0414809524 0.0000000
#> 36:          4      Cell.shape  0.000000e+00       1 0.0448956349 0.0000000
#> 37:          4       Cell.size  6.726984e-03       1 0.0500060317 0.2500000
#> 38:          4    Cl.thickness -8.333333e-05       1 0.0100390476 0.1666667
#> 39:          4    Epith.c.size  0.000000e+00       1 0.0093425397 0.2500000
#> 40:          4   Marg.adhesion  0.000000e+00       1 0.0143293651 0.0000000
#> 41:          4         Mitoses  0.000000e+00       1 0.0047004762 0.0000000
#> 42:          4 Normal.nucleoli  0.000000e+00       1 0.0322451587 0.0000000
#> 43:          4             age -3.400000e-03       1 0.0096596825 0.3478261
#> 44:          4             sex  0.000000e+00       1 0.0005798413 1.0000000
#> 45:          5     Bare.nuclei  1.004960e-01       3 0.0509097619 1.0000000
#> 46:          5     Bl.cromatin  8.951944e-02       3 0.0414809524 1.0000000
#> 47:          5      Cell.shape  1.054159e-01       3 0.0448956349 1.0000000
#> 48:          5       Cell.size  6.255476e-02       3 0.0500060317 1.0000000
#> 49:          5    Cl.thickness  4.277778e-03       3 0.0100390476 0.6666667
#> 50:          5    Epith.c.size  1.528730e-02       3 0.0093425397 1.0000000
#> 51:          5   Marg.adhesion  6.232063e-02       3 0.0143293651 1.0000000
#> 52:          5         Mitoses  1.398849e-02       3 0.0047004762 1.0000000
#> 53:          5 Normal.nucleoli  5.886548e-02       3 0.0322451587 1.0000000
#> 54:          5             age -1.613730e-02       3 0.0096596825 0.1304348
#> 55:          5             sex  1.499206e-03       3 0.0005798413 1.0000000
#>     sample_num         feature           Phi cluster     mean_phi     f_val
#>     unscaled_f_val correct_prediction   pred_prob pred_class
#>  1:              1            Correct 0.002000000     benign
#>  2:              2            Correct 0.002000000     benign
#>  3:              3            Correct 0.002000000     benign
#>  4:              1            Correct 0.002000000     benign
#>  5:              2            Correct 0.002000000     benign
#>  6:              2            Correct 0.002000000     benign
#>  7:              2            Correct 0.002000000     benign
#>  8:              1            Correct 0.002000000     benign
#>  9:              1            Correct 0.002000000     benign
#> 10:             38            Correct 0.002000000     benign
#> 11:              2            Correct 0.002000000     benign
#> 12:              1            Correct 0.109626984     benign
#> 13:              1            Correct 0.109626984     benign
#> 14:              4            Correct 0.109626984     benign
#> 15:              2            Correct 0.109626984     benign
#> 16:              5            Correct 0.109626984     benign
#> 17:              1            Correct 0.109626984     benign
#> 18:              1            Correct 0.109626984     benign
#> 19:              1            Correct 0.109626984     benign
#> 20:              1            Correct 0.109626984     benign
#> 21:             59            Correct 0.109626984     benign
#> 22:              1            Correct 0.109626984     benign
#> 23:              9            Correct 0.935619048  malignant
#> 24:              3            Correct 0.935619048  malignant
#> 25:             10            Correct 0.935619048  malignant
#> 26:             10            Correct 0.935619048  malignant
#> 27:              8            Correct 0.935619048  malignant
#> 28:              6            Correct 0.935619048  malignant
#> 29:              8            Correct 0.935619048  malignant
#> 30:              9            Correct 0.935619048  malignant
#> 31:             10            Correct 0.935619048  malignant
#> 32:             27            Correct 0.935619048  malignant
#> 33:              2            Correct 0.935619048  malignant
#> 34:              1            Correct 0.003243651     benign
#> 35:              1            Correct 0.003243651     benign
#> 36:              1            Correct 0.003243651     benign
#> 37:              3            Correct 0.003243651     benign
#> 38:              3            Correct 0.003243651     benign
#> 39:              2            Correct 0.003243651     benign
#> 40:              1            Correct 0.003243651     benign
#> 41:              1            Correct 0.003243651     benign
#> 42:              1            Correct 0.003243651     benign
#> 43:             35            Correct 0.003243651     benign
#> 44:              2            Correct 0.003243651     benign
#> 45:             10            Correct 0.995395238  malignant
#> 46:              8            Correct 0.995395238  malignant
#> 47:             10            Correct 0.995395238  malignant
#> 48:             10            Correct 0.995395238  malignant
#> 49:              6            Correct 0.995395238  malignant
#> 50:             10            Correct 0.995395238  malignant
#> 51:             10            Correct 0.995395238  malignant
#> 52:              9            Correct 0.995395238  malignant
#> 53:             10            Correct 0.995395238  malignant
#> 54:             30            Correct 0.995395238  malignant
#> 55:              2            Correct 0.995395238  malignant
#>     unscaled_f_val correct_prediction   pred_prob pred_class
#> Warning: Ignoring unknown aesthetics: text
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: no non-missing arguments to max; returning -Inf
#> Warning: Computation failed in `stat_ydensity()`:
#> replacement has 1 row, data has 0
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: no non-missing arguments to max; returning -Inf
#> Warning: Computation failed in `stat_ydensity()`:
#> replacement has 1 row, data has 0
# }
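The repeated warnings above are probably a consequence of the tiny subset (subset = 0.02): only five samples were clustered, and according to the printed table, samples 1 and 5 ended up alone in clusters 2 and 3. stat_ydensity() is the stat behind ggplot2's geom_violin(), and it cannot estimate a density from a single point, which matches the "fewer than two data points" messages (one per feature). Counting samples per cluster makes this visible; the labels below are copied from the printed output, and the link to violin plots inside SHAPclust is an inference, not something the documentation states.

```r
# Cluster label of each sampled instance, copied from the printed table above
cluster_of_sample <- c(`1` = 2, `2` = 1, `3` = 1, `4` = 1, `5` = 3)

# Number of samples per cluster
sizes <- table(cluster_of_sample)
sizes

# Clusters with fewer than two samples: no density can be estimated for them
too_small <- names(sizes)[sizes < 2]
too_small
```

Using a larger subset, or fewer clusters, should leave every cluster with enough samples to avoid these warnings.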