SHAP values are used to cluster data samples using the k-means method to identify subgroups of individuals with specific patterns of feature contributions.
SHAPclust(
task,
trained_model,
splits,
shap_Mean_wide,
shap_Mean_long,
num_of_clusters = 4,
seed = 246,
subset = 1,
algorithm = "Hartigan-Wong",
iter.max = 1000
)
an mlr3 task for binary classification
an mlr3 trained learner object
an mlr3 object defining data splits for train and test sets
the data frame of SHAP values in wide format, as returned by eSHAP_plot()
the data frame of SHAP values in long format, as returned by eSHAP_plot()
number of clusters to make based on SHAP values, default: 4
an integer random seed for reproducibility; default: 246
the fraction of instances to use, between 0 and 1, where 1 means all instances; default: 1
a character string naming the k-means algorithm: "Hartigan-Wong", "Lloyd", "Forgy", or "MacQueen"; default: "Hartigan-Wong".
the maximum number of k-means iterations allowed; default: 1000
A list containing five elements:
An interactive plot displaying the SHAP values for each feature, clustered by the specified number of clusters. Each cluster is shown in a facet.
A ggplot2 figure combining confusion matrices for each cluster, providing insights into the model's performance within each identified subgroup.
A summary table containing statistical descriptions of the clusters based on feature values.
A data frame containing clustered SHAP values along with predictions and ground truth information.
Information about the k-means clustering process, including cluster centers and assignment details.
Zargari Marandi, R., 2024. ExplaineR: an R package to explain machine learning models. Bioinformatics Advances, 4(1), p.vbae049, https://doi.org/10.1093/bioadv/vbae049.
Other functions to visualize and interpret machine learning models: eSHAP_plot.
# \donttest{
# End-to-end example: fit a ranger random forest on mlbench's BreastCancer
# data, compute SHAP values with eSHAP_plot(), then cluster the instances by
# their SHAP values with SHAPclust().
library("explainer")
# Fix the global RNG seed. The same value is also passed to eSHAP_plot() and
# SHAPclust() below so that the whole example is reproducible.
seed <- 246
set.seed(seed)
# Fail fast if a package this example relies on is not installed.
# requireNamespace() loads each namespace without attaching it; loading
# mlr3learners is presumably what makes the "classif.ranger" learner
# available to mlr3::lrn() below (NOTE(review): confirm).
if (!requireNamespace("mlbench", quietly = TRUE)) stop("mlbench not installed.")
if (!requireNamespace("mlr3learners", quietly = TRUE)) stop("mlr3learners not installed.")
if (!requireNamespace("ranger", quietly = TRUE)) stop("ranger not installed.")
# Load the BreastCancer dataset shipped with mlbench
utils::data("BreastCancer", package = "mlbench")
# Name of the binary target column and the level treated as the positive class.
target_col <- "Class"
positive_class <- "malignant"
# Drop the first column (the sample identifier in BreastCancer) and every row
# that contains a missing value.
mydata <- BreastCancer[, -1]
mydata <- na.omit(mydata)
# Add two synthetic, randomly drawn features: sex and age. The draw order
# (sex first, then age) determines which random values each column receives;
# reordering these calls would change the data and the output recorded below.
sex <- sample(
c("Male", "Female"),
size = nrow(mydata),
replace = TRUE
)
# Age: integers 18-60 drawn uniformly with replacement, stored as numeric.
mydata$age <- as.numeric(sample(
seq(18,60),
size = nrow(mydata),
replace = TRUE
))
# Store sex as a factor whose levels are relabelled "1" (Male) and "0" (Female).
mydata$sex <- factor(
sex,
levels = c("Male", "Female"),
labels = c(1, 0)
)
# Wrap the data in an mlr3 classification task.
maintask <- mlr3::TaskClassif$new(
id = "my_classification_task",
backend = mydata,
target = target_col,
positive = positive_class
)
# Randomly partition the task's rows into train and test index sets.
splits <- mlr3::partition(maintask)
# Random forest learner (ranger backend) that predicts class probabilities,
# trained on the training rows only.
mylrn <- mlr3::lrn(
"classif.ranger",
predict_type = "prob"
)
mylrn$train(maintask, splits$train)
# Compute SHAP values. The small sample.size and subset values are chosen so
# the example runs quickly; the inline notes give typical values.
SHAP_output <- eSHAP_plot(
task = maintask,
trained_model = mylrn,
splits = splits,
sample.size = 2, # also 30 or more
seed = seed,
subset = 0.02 # up to 1
)
#> Warning: Ignoring unknown aesthetics: text
# Elements 2 and 3 of eSHAP_plot()'s result are the SHAP tables in wide and
# long format, which SHAPclust() takes as inputs.
shap_Mean_wide <- SHAP_output[[2]]
shap_Mean_long <- SHAP_output[[3]]
# Cluster the instances by their SHAP values with k-means.
# - subset must equal the value used in eSHAP_plot() above.
# - iter.max is lowered from its default of 1000.
# - With subset = 0.02 only a handful of instances are clustered (5 in the
#   recorded output). Some clusters are therefore presumably too small for the
#   density plots, which would explain the ggplot2 warnings recorded after the
#   printed tables.
SHAP_plot_clusters <- SHAPclust(
task = maintask,
trained_model = mylrn,
splits = splits,
shap_Mean_wide = shap_Mean_wide,
shap_Mean_long = shap_Mean_long,
num_of_clusters = 3, # your choice
seed = seed,
subset = 0.02, # match with eSHAP_plot
algorithm="Hartigan-Wong",
iter.max = 10
)
#> sample_num feature Phi cluster mean_phi f_val
#> 1: 1 Bare.nuclei -1.269421e-01 2 0.0509097619 0.0000000
#> 2: 1 Bl.cromatin -6.999405e-02 2 0.0414809524 0.2500000
#> 3: 1 Cell.shape -1.032143e-02 2 0.0448956349 0.2222222
#> 4: 1 Cell.size -1.649984e-01 2 0.0500060317 0.0000000
#> 5: 1 Cl.thickness -1.129325e-02 2 0.0100390476 0.0000000
#> 6: 1 Epith.c.size -2.531587e-02 2 0.0093425397 0.2500000
#> 7: 1 Marg.adhesion -6.559524e-03 2 0.0143293651 0.2500000
#> 8: 1 Mitoses -1.144444e-03 2 0.0047004762 0.0000000
#> 9: 1 Normal.nucleoli -7.803810e-02 2 0.0322451587 0.0000000
#> 10: 1 age -3.290873e-03 2 0.0096596825 0.4782609
#> 11: 1 sex 0.000000e+00 2 0.0005798413 1.0000000
#> 12: 2 Bare.nuclei 0.000000e+00 1 0.0509097619 0.0000000
#> 13: 2 Bl.cromatin -7.762302e-03 1 0.0414809524 0.0000000
#> 14: 2 Cell.shape 1.086937e-01 1 0.0448956349 0.3333333
#> 15: 2 Cell.size 7.969444e-03 1 0.0500060317 0.1250000
#> 16: 2 Cl.thickness -7.388889e-03 1 0.0100390476 0.5000000
#> 17: 2 Epith.c.size 0.000000e+00 1 0.0093425397 0.0000000
#> 18: 2 Marg.adhesion 0.000000e+00 1 0.0143293651 0.0000000
#> 19: 2 Mitoses 0.000000e+00 1 0.0047004762 0.0000000
#> 20: 2 Normal.nucleoli 0.000000e+00 1 0.0322451587 0.0000000
#> 21: 2 age 8.115079e-03 1 0.0096596825 1.0000000
#> 22: 2 sex 0.000000e+00 1 0.0005798413 0.0000000
#> 23: 3 Bare.nuclei 2.711071e-02 1 0.0509097619 0.8888889
#> 24: 3 Bl.cromatin -4.012897e-02 1 0.0414809524 0.5000000
#> 25: 3 Cell.shape 4.722222e-05 1 0.0448956349 1.0000000
#> 26: 3 Cell.size -7.780556e-03 1 0.0500060317 1.0000000
#> 27: 3 Cl.thickness 2.715198e-02 1 0.0100390476 1.0000000
#> 28: 3 Epith.c.size -6.109524e-03 1 0.0093425397 1.0000000
#> 29: 3 Marg.adhesion 2.766667e-03 1 0.0143293651 1.0000000
#> 30: 3 Mitoses 8.369444e-03 1 0.0047004762 1.0000000
#> 31: 3 Normal.nucleoli 2.432222e-02 1 0.0322451587 1.0000000
#> 32: 3 age -1.735516e-02 1 0.0096596825 0.0000000
#> 33: 3 sex 1.400000e-03 1 0.0005798413 1.0000000
#> 34: 4 Bare.nuclei 0.000000e+00 1 0.0509097619 0.0000000
#> 35: 4 Bl.cromatin 0.000000e+00 1 0.0414809524 0.0000000
#> 36: 4 Cell.shape 0.000000e+00 1 0.0448956349 0.0000000
#> 37: 4 Cell.size 6.726984e-03 1 0.0500060317 0.2500000
#> 38: 4 Cl.thickness -8.333333e-05 1 0.0100390476 0.1666667
#> 39: 4 Epith.c.size 0.000000e+00 1 0.0093425397 0.2500000
#> 40: 4 Marg.adhesion 0.000000e+00 1 0.0143293651 0.0000000
#> 41: 4 Mitoses 0.000000e+00 1 0.0047004762 0.0000000
#> 42: 4 Normal.nucleoli 0.000000e+00 1 0.0322451587 0.0000000
#> 43: 4 age -3.400000e-03 1 0.0096596825 0.3478261
#> 44: 4 sex 0.000000e+00 1 0.0005798413 1.0000000
#> 45: 5 Bare.nuclei 1.004960e-01 3 0.0509097619 1.0000000
#> 46: 5 Bl.cromatin 8.951944e-02 3 0.0414809524 1.0000000
#> 47: 5 Cell.shape 1.054159e-01 3 0.0448956349 1.0000000
#> 48: 5 Cell.size 6.255476e-02 3 0.0500060317 1.0000000
#> 49: 5 Cl.thickness 4.277778e-03 3 0.0100390476 0.6666667
#> 50: 5 Epith.c.size 1.528730e-02 3 0.0093425397 1.0000000
#> 51: 5 Marg.adhesion 6.232063e-02 3 0.0143293651 1.0000000
#> 52: 5 Mitoses 1.398849e-02 3 0.0047004762 1.0000000
#> 53: 5 Normal.nucleoli 5.886548e-02 3 0.0322451587 1.0000000
#> 54: 5 age -1.613730e-02 3 0.0096596825 0.1304348
#> 55: 5 sex 1.499206e-03 3 0.0005798413 1.0000000
#> sample_num feature Phi cluster mean_phi f_val
#> unscaled_f_val correct_prediction pred_prob pred_class
#> 1: 1 Correct 0.002000000 benign
#> 2: 2 Correct 0.002000000 benign
#> 3: 3 Correct 0.002000000 benign
#> 4: 1 Correct 0.002000000 benign
#> 5: 2 Correct 0.002000000 benign
#> 6: 2 Correct 0.002000000 benign
#> 7: 2 Correct 0.002000000 benign
#> 8: 1 Correct 0.002000000 benign
#> 9: 1 Correct 0.002000000 benign
#> 10: 38 Correct 0.002000000 benign
#> 11: 2 Correct 0.002000000 benign
#> 12: 1 Correct 0.109626984 benign
#> 13: 1 Correct 0.109626984 benign
#> 14: 4 Correct 0.109626984 benign
#> 15: 2 Correct 0.109626984 benign
#> 16: 5 Correct 0.109626984 benign
#> 17: 1 Correct 0.109626984 benign
#> 18: 1 Correct 0.109626984 benign
#> 19: 1 Correct 0.109626984 benign
#> 20: 1 Correct 0.109626984 benign
#> 21: 59 Correct 0.109626984 benign
#> 22: 1 Correct 0.109626984 benign
#> 23: 9 Correct 0.935619048 malignant
#> 24: 3 Correct 0.935619048 malignant
#> 25: 10 Correct 0.935619048 malignant
#> 26: 10 Correct 0.935619048 malignant
#> 27: 8 Correct 0.935619048 malignant
#> 28: 6 Correct 0.935619048 malignant
#> 29: 8 Correct 0.935619048 malignant
#> 30: 9 Correct 0.935619048 malignant
#> 31: 10 Correct 0.935619048 malignant
#> 32: 27 Correct 0.935619048 malignant
#> 33: 2 Correct 0.935619048 malignant
#> 34: 1 Correct 0.003243651 benign
#> 35: 1 Correct 0.003243651 benign
#> 36: 1 Correct 0.003243651 benign
#> 37: 3 Correct 0.003243651 benign
#> 38: 3 Correct 0.003243651 benign
#> 39: 2 Correct 0.003243651 benign
#> 40: 1 Correct 0.003243651 benign
#> 41: 1 Correct 0.003243651 benign
#> 42: 1 Correct 0.003243651 benign
#> 43: 35 Correct 0.003243651 benign
#> 44: 2 Correct 0.003243651 benign
#> 45: 10 Correct 0.995395238 malignant
#> 46: 8 Correct 0.995395238 malignant
#> 47: 10 Correct 0.995395238 malignant
#> 48: 10 Correct 0.995395238 malignant
#> 49: 6 Correct 0.995395238 malignant
#> 50: 10 Correct 0.995395238 malignant
#> 51: 10 Correct 0.995395238 malignant
#> 52: 9 Correct 0.995395238 malignant
#> 53: 10 Correct 0.995395238 malignant
#> 54: 30 Correct 0.995395238 malignant
#> 55: 2 Correct 0.995395238 malignant
#> unscaled_f_val correct_prediction pred_prob pred_class
#> Warning: Ignoring unknown aesthetics: text
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: no non-missing arguments to max; returning -Inf
#> Warning: Computation failed in `stat_ydensity()`:
#> replacement has 1 row, data has 0
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning: no non-missing arguments to max; returning -Inf
#> Warning: Computation failed in `stat_ydensity()`:
#> replacement has 1 row, data has 0
# }