SHAP values are used to cluster data samples with the k-means method, identifying subgroups of individuals that share similar patterns of feature contributions.
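To make the idea concrete, the sketch below clusters a small, made-up SHAP matrix with base R's stats::kmeans(). Each row is one sample's vector of feature contributions. The data, the column names feature_a and feature_b, and the nstart setting are assumptions for illustration; this is not the explainer package's internal code.

```r
# Made-up SHAP values: rows are samples, columns are per-feature contributions
set.seed(246)
n <- 10
# Group 1: feature_a carries most of the contribution
group1 <- cbind(feature_a = rnorm(n, 0.6, 0.05), feature_b = rnorm(n, 0, 0.05))
# Group 2: feature_b carries most of the contribution
group2 <- cbind(feature_a = rnorm(n, 0, 0.05), feature_b = rnorm(n, 0.6, 0.05))
shap_wide <- rbind(group1, group2)

# k-means on the SHAP vectors; several random starts (nstart) guard against a poor initialization
km <- stats::kmeans(
  shap_wide,
  centers = 2,
  iter.max = 1000,
  nstart = 10,
  algorithm = "Hartigan-Wong"
)

# Cluster sizes and the mean SHAP profile of each cluster
print(table(km$cluster))
print(km$centers)
```

Each sample receives a cluster label in km$cluster, and km$centers holds the average SHAP profile of each cluster, which is the kind of per-subgroup pattern this function is meant to surface.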
Usage
SHAPclust(
  task,
  trained_model,
  splits,
  shap_Mean_wide,
  shap_Mean_long,
  num_of_clusters = 4,
  seed = 246,
  subset = 1,
  algorithm = "Hartigan-Wong",
  iter.max = 1000
)
Arguments
- task
an mlr3 task for binary classification
- trained_model
a trained mlr3 learner object
- splits
a list of train and test row indices, as returned by mlr3::partition()
- shap_Mean_wide
the data frame of SHAP values in wide format, returned by eSHAP_plot()
- shap_Mean_long
the data frame of SHAP values in long format, returned by eSHAP_plot()
- num_of_clusters
number of clusters to form from the SHAP values, default: 4
- seed
an integer seed for reproducibility, default: 246
- subset
the proportion of instances to use, from 0 to 1, where 1 means all instances, default: 1
- algorithm
the k-means algorithm to use: "Hartigan-Wong" (default), "Lloyd", "Forgy", or "MacQueen"; these match the algorithm names accepted by stats::kmeans()
- iter.max
maximum number of k-means iterations allowed, default: 1000
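The sketch below shows one plausible way these arguments could fit together. It assumes, without reference to the package source, that subset draws a random fraction of rows under the given seed and that num_of_clusters, algorithm, and iter.max are passed through to stats::kmeans(). The helper cluster_shap_subset() and the toy matrix are hypothetical and not part of explainer.

```r
# Hypothetical helper (not part of explainer) mirroring the SHAPclust() arguments
cluster_shap_subset <- function(shap_wide, num_of_clusters = 4, seed = 246,
                                subset = 1, algorithm = "Hartigan-Wong",
                                iter.max = 1000) {
  stopifnot(subset > 0, subset <= 1)
  set.seed(seed) # note: resets the global RNG state
  # Keep a random fraction of rows (subset = 1 keeps every row),
  # but never fewer rows than clusters
  n_keep <- max(num_of_clusters, round(subset * nrow(shap_wide)))
  rows <- sort(sample(nrow(shap_wide), n_keep))
  km <- stats::kmeans(
    shap_wide[rows, , drop = FALSE],
    centers = num_of_clusters,
    iter.max = iter.max,
    algorithm = algorithm
  )
  list(rows = rows, cluster = km$cluster, centers = km$centers)
}

# Usage on a made-up 200 x 3 SHAP matrix, keeping 25% of the rows
set.seed(1)
toy <- matrix(rnorm(600), ncol = 3, dimnames = list(NULL, c("f1", "f2", "f3")))
res <- cluster_shap_subset(toy, num_of_clusters = 3, subset = 0.25)
```

Calling the helper twice with the same seed returns the same rows and cluster labels, which is the kind of reproducibility the seed argument is intended to provide.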
Value
A list containing five elements:
- shap_plot_onerow
An interactive plot of the SHAP values for each feature, with one facet per cluster.
- combined_plot
A ggplot2 figure combining confusion matrices for each cluster, providing insights into the model's performance within each identified subgroup.
- kmeans_fvals_desc
A summary table containing statistical descriptions of the clusters based on feature values.
- shap_Mean_wide_kmeans
A data frame containing clustered SHAP values along with predictions and ground truth information.
- kmeans_info
Information about the k-means clustering process, including cluster centers and assignment details.
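As a rough illustration of what combined_plot and kmeans_fvals_desc summarize, the sketch below builds one confusion matrix per cluster with table() and averages a feature value within each cluster with aggregate(). The data frame, its column names (cluster, truth, response, age), and all values are made up; the package's own output columns may differ.

```r
# Made-up per-sample results: cluster label, ground truth, prediction, and one feature
lv <- c("malignant", "benign")
clustered <- data.frame(
  cluster  = c(1, 1, 1, 2, 2, 2),
  truth    = factor(c("malignant", "benign", "malignant", "benign", "benign", "malignant"), levels = lv),
  response = factor(c("malignant", "benign", "benign", "benign", "benign", "malignant"), levels = lv),
  age      = c(45, 52, 38, 29, 61, 47)
)

# One confusion matrix (rows: truth, columns: prediction) per cluster
conf_by_cluster <- lapply(split(clustered, clustered$cluster), function(d) {
  table(truth = d$truth, response = d$response)
})

# Mean feature value within each cluster
age_by_cluster <- aggregate(age ~ cluster, data = clustered, FUN = mean)

print(conf_by_cluster)
print(age_by_cluster)
```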
References
Zargari Marandi, R., 2024. ExplaineR: an R package to explain machine learning models. Bioinformatics Advances, 4(1), p.vbae049, https://doi.org/10.1093/bioadv/vbae049.
See also
Other functions to visualize and interpret machine learning models: eSHAP_plot.
Examples
# \donttest{
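library("explainer")
# Attach dplyr: without it, eSHAP_plot() may fail with 'could not find function "ungroup"' (assumed workaround)
library("dplyr")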
seed <- 246
set.seed(seed)
# Check that the packages used below are installed
if (!requireNamespace("mlbench", quietly = TRUE)) stop("mlbench not installed.")
if (!requireNamespace("mlr3learners", quietly = TRUE)) stop("mlr3learners not installed.")
if (!requireNamespace("ranger", quietly = TRUE)) stop("ranger not installed.")
# Load BreastCancer dataset
utils::data("BreastCancer", package = "mlbench")
target_col <- "Class"
positive_class <- "malignant"
mydata <- BreastCancer[, -1] # drop the Id column
mydata <- na.omit(mydata) # drop rows with missing values
# Add synthetic sex and age columns (random values, for illustration only)
sex <- sample(
  c("Male", "Female"),
  size = nrow(mydata),
  replace = TRUE
)
mydata$age <- as.numeric(sample(
  seq(18, 60),
  size = nrow(mydata),
  replace = TRUE
))
mydata$sex <- factor(
  sex,
  levels = c("Male", "Female"),
  labels = c(1, 0)
)
# Define the binary classification task with "malignant" as the positive class
maintask <- mlr3::TaskClassif$new(
  id = "my_classification_task",
  backend = mydata,
  target = target_col,
  positive = positive_class
)
# Split rows into train and test sets, then fit a random forest on the train set
splits <- mlr3::partition(maintask)
mylrn <- mlr3::lrn(
  "classif.ranger",
  predict_type = "prob"
)
mylrn$train(maintask, splits$train)
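# Compute SHAP values on a small subset to keep the example fast
SHAP_output <- eSHAP_plot(
  task = maintask,
  trained_model = mylrn,
  splits = splits,
  sample.size = 2, # small for speed; 30 or more also works
  seed = seed,
  subset = 0.02 # proportion of instances, up to 1
)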
shap_Mean_wide <- SHAP_output[[2]]
shap_Mean_long <- SHAP_output[[3]]
# Cluster the samples by their SHAP values into three groups
SHAP_plot_clusters <- SHAPclust(
  task = maintask,
  trained_model = mylrn,
  splits = splits,
  shap_Mean_wide = shap_Mean_wide,
  shap_Mean_long = shap_Mean_long,
  num_of_clusters = 3, # choose to suit your data
  seed = seed,
  subset = 0.02, # should match the subset used in eSHAP_plot()
  algorithm = "Hartigan-Wong",
  iter.max = 10
)
# }