Clusters data samples by their SHAP values using the k-means method to identify subgroups of individuals with distinct patterns of feature contributions.
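As a rough illustration of the idea, the sketch below clusters a synthetic SHAP matrix with base R's `stats::kmeans`. The matrix `shap_toy` and the feature names are made up for the example, and `nstart = 10` is an added choice; this is not the internal code of `SHAPclust`, only a minimal sketch of k-means applied to per-sample SHAP profiles.

```r
# Synthetic SHAP matrix: rows are samples, columns are features.
# These values are simulated, not produced by eSHAP_plot.
set.seed(246)
shap_toy <- rbind(
  matrix(rnorm(40, mean = -0.1, sd = 0.02), ncol = 4),
  matrix(rnorm(40, mean = 0.1, sd = 0.02), ncol = 4)
)
colnames(shap_toy) <- paste0("feature_", 1:4)

# Cluster samples by their SHAP profiles with base R k-means.
km <- stats::kmeans(
  shap_toy,
  centers = 2,
  algorithm = "Hartigan-Wong",
  iter.max = 1000,
  nstart = 10
)

# Cluster label per sample, then mean SHAP value per feature within each cluster.
cluster_id <- km$cluster
cluster_profiles <- aggregate(
  as.data.frame(shap_toy),
  by = list(cluster = cluster_id),
  FUN = mean
)
print(cluster_profiles)
```

Each row of `cluster_profiles` summarizes how strongly, and in which direction, the features contribute on average within one subgroup.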

Usage

SHAPclust(
  task,
  trained_model,
  splits,
  shap_Mean_wide,
  shap_Mean_long,
  num_of_clusters = 4,
  seed = 246,
  subset = 1,
  algorithm = "Hartigan-Wong",
  iter.max = 1000
)

Arguments

task

an mlr3 task for binary classification

trained_model

an mlr3 trained learner object

splits

an mlr3 object defining data splits for train and test sets

shap_Mean_wide

the data frame of SHAP values in wide format, as returned by eSHAP_plot

shap_Mean_long

the data frame of SHAP values in long format, as returned by eSHAP_plot

num_of_clusters

number of clusters to make based on SHAP values, default: 4

seed

an integer seed for reproducibility, default: 246

subset

fraction of the instances to use, from 0 to 1, where 1 means all instances; default: 1

algorithm

character string naming the k-means algorithm: one of "Hartigan-Wong", "Lloyd", "Forgy", or "MacQueen"; default: "Hartigan-Wong"

iter.max

maximum number of k-means iterations allowed, default: 1000
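The two SHAP inputs carry the same values in different shapes. The sketch below assumes the long format has one row per sample and feature (with columns `sample_num`, `feature`, and `Phi`, as in the printed example output) and that the wide format has one row per sample and one column per feature. The toy values and the use of `stats::reshape` are illustrative assumptions, not a description of how eSHAP_plot builds its outputs.

```r
# Toy long-format SHAP table: one row per (sample, feature) pair.
shap_long_toy <- data.frame(
  sample_num = rep(1:3, each = 2),
  feature    = rep(c("age", "Cell.size"), times = 3),
  Phi        = c(-0.015, 0.000, -0.029, 0.107, 0.000, 0.000)
)

# Pivot to wide format: one row per sample, one column per feature.
shap_wide_toy <- stats::reshape(
  shap_long_toy,
  idvar = "sample_num",
  timevar = "feature",
  direction = "wide"
)

# reshape() prefixes the value column name ("Phi."); strip it for readability.
names(shap_wide_toy) <- sub("^Phi\\.", "", names(shap_wide_toy))
rownames(shap_wide_toy) <- NULL
print(shap_wide_toy)
```

A wide table of this shape is what a k-means routine typically consumes, since each row is one sample's full vector of feature contributions.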

Value

A list containing five elements:

shap_plot_onerow

An interactive plot of the SHAP values for each feature, grouped into the specified number of clusters, with one facet per cluster.

combined_plot

A ggplot2 figure combining confusion matrices for each cluster, providing insights into the model's performance within each identified subgroup.

kmeans_fvals_desc

A summary table containing statistical descriptions of the clusters based on feature values.

shap_Mean_wide_kmeans

A data frame containing clustered SHAP values along with predictions and ground truth information.

kmeans_info

Information about the k-means clustering process, including cluster centers and assignment details.
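To make the per-cluster performance view more concrete, the sketch below builds one confusion matrix per cluster from a small synthetic table and computes the accuracy within each cluster. The data frame `clustered` and its column names (`cluster`, `truth`, `pred_class`) are assumptions for illustration; the actual columns of shap_Mean_wide_kmeans may differ.

```r
# Synthetic stand-in for a table of clustered predictions.
lv <- c("benign", "malignant")
clustered <- data.frame(
  cluster = c(1, 1, 1, 2, 2, 2, 2, 3, 3, 3),
  truth = factor(
    c("benign", "benign", "malignant", "malignant", "malignant",
      "benign", "malignant", "benign", "benign", "benign"),
    levels = lv
  ),
  pred_class = factor(
    c("benign", "benign", "benign", "malignant", "malignant",
      "malignant", "malignant", "benign", "benign", "benign"),
    levels = lv
  )
)

# One confusion matrix (truth x prediction) per cluster.
conf_by_cluster <- lapply(
  split(clustered, clustered$cluster),
  function(d) table(truth = d$truth, predicted = d$pred_class)
)

# Share of correct predictions within each cluster.
acc_by_cluster <- vapply(
  conf_by_cluster,
  function(m) sum(diag(m)) / sum(m),
  numeric(1)
)
print(conf_by_cluster)
print(round(acc_by_cluster, 2))
```

Comparing these matrices across clusters shows whether the model's errors concentrate in a particular subgroup.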

References

Zargari Marandi, R., 2024. ExplaineR: an R package to explain machine learning models. Bioinformatics Advances, 4(1), vbae049. https://doi.org/10.1093/bioadv/vbae049

See also

Other functions to visualize and interpret machine learning models: eSHAP_plot.

Examples

# \donttest{
library("explainer")
seed <- 246
set.seed(seed)
# Check that the required packages are installed
if (!requireNamespace("mlbench", quietly = TRUE)) stop("mlbench not installed.")
if (!requireNamespace("mlr3learners", quietly = TRUE)) stop("mlr3learners not installed.")
if (!requireNamespace("ranger", quietly = TRUE)) stop("ranger not installed.")
# Load BreastCancer dataset
utils::data("BreastCancer", package = "mlbench")
target_col <- "Class"
positive_class <- "malignant"
mydata <- BreastCancer[, -1]
mydata <- na.omit(mydata)
sex <- sample(
  c("Male", "Female"),
  size = nrow(mydata),
  replace = TRUE
)
mydata$age <- as.numeric(sample(
  seq(18, 60),
  size = nrow(mydata),
  replace = TRUE
))
mydata$sex <- factor(
  sex,
  levels = c("Male", "Female"),
  labels = c(1, 0)
)
maintask <- mlr3::TaskClassif$new(
  id = "my_classification_task",
  backend = mydata,
  target = target_col,
  positive = positive_class
)
splits <- mlr3::partition(maintask)
mylrn <- mlr3::lrn(
  "classif.ranger",
  predict_type = "prob"
)
mylrn$train(maintask, splits$train)
SHAP_output <- eSHAP_plot(
  task = maintask,
  trained_model = mylrn,
  splits = splits,
  sample.size = 2, # kept small here; can also be 30 or more
  seed = seed,
  subset = 0.02 # fraction of instances to use; can be up to 1
)
#> Warning: Ignoring unknown aesthetics: text
shap_Mean_wide <- SHAP_output[[2]]
shap_Mean_long <- SHAP_output[[3]]
SHAP_plot_clusters <- SHAPclust(
  task = maintask,
  trained_model = mylrn,
  splits = splits,
  shap_Mean_wide = shap_Mean_wide,
  shap_Mean_long = shap_Mean_long,
  num_of_clusters = 3, # choose the number of clusters
  seed = seed,
  subset = 0.02, # should match the value passed to eSHAP_plot
  algorithm = "Hartigan-Wong",
  iter.max = 10
)
#> Key: <sample_num, feature, Phi>
#>     sample_num         feature           Phi cluster     mean_phi     f_val
#>          <int>          <char>         <num>   <int>        <num>     <num>
#>  1:          1     Bare.nuclei -0.1399936508       2 0.0493383117 0.0000000
#>  2:          1     Bl.cromatin -0.0828492063       2 0.0489450397 0.0000000
#>  3:          1      Cell.shape -0.0367019841       2 0.0745567280 0.0000000
#>  4:          1       Cell.size  0.0000000000       2 0.0319849206 0.0000000
#>  5:          1    Cl.thickness -0.0178948413       2 0.0262465278 0.5000000
#>  6:          1    Epith.c.size -0.0055142857       2 0.0059076389 0.5000000
#>  7:          1   Marg.adhesion  0.0000000000       2 0.0238763889 0.0000000
#>  8:          1         Mitoses  0.0000000000       2 0.0043198413 0.0000000
#>  9:          1 Normal.nucleoli -0.0798761905       2 0.0537973214 0.0000000
#> 10:          1             age -0.0151575397       2 0.0140473214 0.9444444
#> 11:          1             sex -0.0010932540       2 0.0005277778       NaN
#> 12:          2     Bare.nuclei -0.0050662698       3 0.0493383117 1.0000000
#> 13:          2     Bl.cromatin  0.0610115079       3 0.0489450397 1.0000000
#> 14:          2      Cell.shape  0.1521496032       3 0.0745567280 1.0000000
#> 15:          2       Cell.size  0.1072460317       3 0.0319849206 1.0000000
#> 16:          2    Cl.thickness  0.0501396825       3 0.0262465278 1.0000000
#> 17:          2    Epith.c.size  0.0181162698       3 0.0059076389 1.0000000
#> 18:          2   Marg.adhesion  0.0858515873       3 0.0238763889 1.0000000
#> 19:          2         Mitoses  0.0172793651       3 0.0043198413 1.0000000
#> 20:          2 Normal.nucleoli  0.0572436508       3 0.0537973214 1.0000000
#> 21:          2             age -0.0293571429       3 0.0140473214 1.0000000
#> 22:          2             sex  0.0010178571       3 0.0005277778       NaN
#> 23:          3     Bare.nuclei  0.0000000000       1 0.0493383117 0.0000000
#> 24:          3     Bl.cromatin  0.0000000000       1 0.0489450397 0.5000000
#> 25:          3      Cell.shape  0.0000000000       1 0.0745567280 0.0000000
#> 26:          3       Cell.size  0.0000000000       1 0.0319849206 0.0000000
#> 27:          3    Cl.thickness  0.0002857143       1 0.0262465278 0.6666667
#> 28:          3    Epith.c.size  0.0000000000       1 0.0059076389 0.0000000
#> 29:          3   Marg.adhesion  0.0000000000       1 0.0238763889 0.0000000
#> 30:          3         Mitoses  0.0000000000       1 0.0043198413 0.0000000
#> 31:          3 Normal.nucleoli  0.0000000000       1 0.0537973214 0.0000000
#> 32:          3             age  0.0002857143       1 0.0140473214 0.0000000
#> 33:          3             sex  0.0000000000       1 0.0005277778       NaN
#> 34:          4     Bare.nuclei -0.0522933261       2 0.0493383117 0.0000000
#> 35:          4     Bl.cromatin -0.0519194444       2 0.0489450397 0.0000000
#> 36:          4      Cell.shape -0.1093753247       2 0.0745567280 0.0000000
#> 37:          4       Cell.size -0.0206936508       2 0.0319849206 0.0000000
#> 38:          4    Cl.thickness -0.0366658730       2 0.0262465278 0.0000000
#> 39:          4    Epith.c.size  0.0000000000       2 0.0059076389 0.0000000
#> 40:          4   Marg.adhesion -0.0096539683       2 0.0238763889 0.0000000
#> 41:          4         Mitoses  0.0000000000       2 0.0043198413 0.0000000
#> 42:          4 Normal.nucleoli -0.0780694444       2 0.0537973214 0.0000000
#> 43:          4             age -0.0113888889       2 0.0140473214 0.4722222
#> 44:          4             sex  0.0000000000       2 0.0005277778       NaN
#>     sample_num         feature           Phi cluster     mean_phi     f_val
#>     unscaled_f_val correct_prediction    pred_prob pred_class
#>              <num>             <fctr>        <num>     <fctr>
#>  1:              1            Correct 0.0043809524     benign
#>  2:              1            Correct 0.0043809524     benign
#>  3:              1            Correct 0.0043809524     benign
#>  4:              1            Correct 0.0043809524     benign
#>  5:              4            Correct 0.0043809524     benign
#>  6:              3            Correct 0.0043809524     benign
#>  7:              1            Correct 0.0043809524     benign
#>  8:              1            Correct 0.0043809524     benign
#>  9:              1            Correct 0.0043809524     benign
#> 10:             54            Correct 0.0043809524     benign
#> 11:              2            Correct 0.0043809524     benign
#> 12:              3            Correct 0.9575134921  malignant
#> 13:              8            Correct 0.9575134921  malignant
#> 14:              7            Correct 0.9575134921  malignant
#> 15:              8            Correct 0.9575134921  malignant
#> 16:              7            Correct 0.9575134921  malignant
#> 17:              4            Correct 0.9575134921  malignant
#> 18:              6            Correct 0.9575134921  malignant
#> 19:              4            Correct 0.9575134921  malignant
#> 20:              8            Correct 0.9575134921  malignant
#> 21:             56            Correct 0.9575134921  malignant
#> 22:              2            Correct 0.9575134921  malignant
#> 23:              1            Correct 0.0005714286     benign
#> 24:              3            Correct 0.0005714286     benign
#> 25:              1            Correct 0.0005714286     benign
#> 26:              1            Correct 0.0005714286     benign
#> 27:              5            Correct 0.0005714286     benign
#> 28:              2            Correct 0.0005714286     benign
#> 29:              1            Correct 0.0005714286     benign
#> 30:              1            Correct 0.0005714286     benign
#> 31:              1            Correct 0.0005714286     benign
#> 32:             20            Correct 0.0005714286     benign
#> 33:              2            Correct 0.0005714286     benign
#> 34:              1            Correct 0.0000000000     benign
#> 35:              1            Correct 0.0000000000     benign
#> 36:              1            Correct 0.0000000000     benign
#> 37:              1            Correct 0.0000000000     benign
#> 38:              1            Correct 0.0000000000     benign
#> 39:              2            Correct 0.0000000000     benign
#> 40:              1            Correct 0.0000000000     benign
#> 41:              1            Correct 0.0000000000     benign
#> 42:              1            Correct 0.0000000000     benign
#> 43:             37            Correct 0.0000000000     benign
#> 44:              2            Correct 0.0000000000     benign
#>     unscaled_f_val correct_prediction    pred_prob pred_class
#> Warning: Ignoring unknown aesthetics: text
#> Warning: Groups with fewer than two datapoints have been dropped.
#>  Set `drop = FALSE` to consider such groups for position adjustment purposes.
#> Warning: no non-missing arguments to max; returning -Inf
#> Warning: Computation failed in `stat_ydensity()`.
#> Caused by error in `$<-.data.frame`:
#> ! replacement has 1 row, data has 0
#> Warning: Groups with fewer than two datapoints have been dropped.
#>  Set `drop = FALSE` to consider such groups for position adjustment purposes.
#> Warning: no non-missing arguments to max; returning -Inf
#> Warning: Computation failed in `stat_ydensity()`.
#> Caused by error in `$<-.data.frame`:
#> ! replacement has 1 row, data has 0
#> Warning: 'ggimage' is missing. Will not plot arrows and zero-shading.
#> Warning: 'rsvg' is missing. Will not plot arrows and zero-shading.
# }
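The example passes the same `seed` and `subset` to both functions so that they work on the same instances. A minimal sketch of why that matters is given below: drawing a fraction of row indices after setting the same seed yields the same rows each time. The helper `sample_fraction` is hypothetical and its rounding rule is an assumption; it does not reproduce how the package selects instances internally.

```r
# Hypothetical helper: draw a reproducible fraction of row indices.
sample_fraction <- function(n_rows, subset = 1, seed = 246) {
  stopifnot(subset > 0, subset <= 1)
  set.seed(seed)
  # Assumed rounding rule: nearest integer, but always keep at least one row.
  n_keep <- max(1L, as.integer(round(n_rows * subset)))
  sort(sample.int(n_rows, size = n_keep))
}

idx_a <- sample_fraction(200, subset = 0.02)
idx_b <- sample_fraction(200, subset = 0.02)

# Same seed and same fraction select the same rows.
print(identical(idx_a, idx_b))
print(length(idx_a))
```

With a very small `subset`, only a handful of instances remain, so the number of clusters has to be chosen with that sample size in mind.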