We actually have a dataset that is very imbalanced: riskware and adware contain roughly 50x the number of samples as trojan_banker.
Feature Selection
# Feature selection for the category model ------------------------------------

# Columns that carry metadata rather than behavioural features; these must
# not be used as predictors.
metadata_cols <- c(
  "Category", "Family", "Category_file", "reboot_state",
  "path", "file", "Hash"
)

# Restrict to the metadata columns that actually exist in the data.
all_cols <- names(andmal_after)
metadata_cols_present <- intersect(metadata_cols, all_cols)

# Predictors for Model 1 (Category prediction): everything except metadata
# (Family and Category themselves included).
feature_cols_model1 <- setdiff(all_cols, metadata_cols_present)

# Also drop any rank/colour helper columns added during exploration.
is_helper_col <- str_detect(
  feature_cols_model1,
  "rank|color|fam_color|rank_in_cat"
)
feature_cols_model1 <- feature_cols_model1[!is_helper_col]
On my first attempt at training the random forest model, I left in the Hash column. The LIME and SHAP values were both very strong for this column on almost any sample. This seemed odd, so, as someone who is not an expert in cybersecurity, I researched what the hash was. It turns out it is just a SHA-256 — a cryptographic hash function applied to arbitrary data, in this case to the APK (Android Package Kit) file. This essentially acts as a unique identifier for each sample, so we must get rid of it.
The random forest model for category prediction demonstrates surprisingly strong performance in distinguishing between malware categories. The precision score for Zero_Day is among the worst; this is to be expected since these are novel pieces of code. Across all of the scores, we can see that No_category and Zero_Day are the hardest to classify — but in a way, both of these classes represent new or unorthodox code. I suspect the small number of samples we had for file-injector is one of the main reasons it performed so poorly.
Model Interpretability
Instance Selection
# Instance selection for model interpretability --------------------------------
# Pick one test-set instance from each of the three most populated categories
# so LIME / SHAP explanations can be compared across them.

# Seed kept for reproducibility (the selection itself is deterministic: we
# always take the first matching row per category).
set.seed(15)

target_categories <- c("Adware", "Riskware", "Trojan")

selected_instances <- list()
for (category in target_categories) {
  category_indices <- which(test_data$Category == category)
  if (length(category_indices) > 0) {
    # First instance of this category in the test set.
    selected_instances[[category]] <- test_data[category_indices[1], ]
  } else {
    # Previously a silent `invisible()` no-op; surface the problem instead so
    # a missing category is not discovered only when downstream plots break.
    warning("No test instances found for category: ", category, call. = FALSE)
  }
}

# One row per category, for easier downstream handling.
selected_instances_df <- bind_rows(selected_instances)

# Feature matrix (predictors only) used by the explainers.
selected_instances_x <- selected_instances_df[, feature_cols_model1, drop = FALSE]

# Persist the selected instances for reuse in family.qmd.
if (!dir.exists("data/processed")) {
  dir.create("data/processed", recursive = TRUE)
}
saveRDS(selected_instances_df, "data/processed/selected_instances.rds")
LIME Explanations
I select an instance from each of the 3 most populated categories (Adware, Riskware, Trojan), and then obtain the LIME and Shapley values for these 3 instances.
# LIME explanations ------------------------------------------------------------
# Load dependencies explicitly. library() errors immediately if a package is
# missing; the old `if (!require(pkg)) library(pkg)` pattern was a no-op,
# because require() already attaches the package whenever it succeeds.
library(lime)
library(ggplot2)
library(dplyr)

# Seed for reproducibility of the permutation sampling.
set.seed(15)

# Prediction wrapper for ranger models.
# NOTE(review): this function is not referenced anywhere below -- lime
# dispatches through model_type()/predict_model() S3 methods. Confirm those
# are registered elsewhere, or whether this wrapper was meant to be used.
predict_function <- function(model, newdata) {
  pred <- predict(model, newdata)
  return(pred$predictions)
}

# Build the explainer on the same training features used to fit rf_category;
# continuous features are binned into 5 bins.
explainer <- lime(
  train_model1_x,
  model = rf_category,
  bin_continuous = TRUE,
  n_bins = 5
)

# Local explanations for every selected instance: top 10 features, most
# probable label only, 5000 permutations per instance.
lime_explanations <- lime::explain(
  selected_instances_x,
  explainer = explainer,
  n_features = 10,
  n_permutations = 5000,
  n_labels = 1
)

# One feature-weight plot per instance (cases are keyed by row index).
for (i in seq_len(nrow(selected_instances_x))) {
  print(plot_features(lime_explanations[lime_explanations$case == i, ]))
}
fastshap Explanations
Per-Instance SHAP Values
# Per-instance SHAP values ------------------------------------------------------
# Load dependencies explicitly (library() errors on failure; the old
# `if (!require(pkg)) library(pkg)` pattern was a no-op when require succeeded).
library(fastshap)
library(ggplot2)
library(dplyr)

# Seed for reproducibility of the Monte-Carlo SHAP simulations.
set.seed(15)

# Prediction wrapper passed to fastshap::explain() for the ranger model.
# NOTE(review): for a multi-class probability forest this returns a matrix
# (one column per class); confirm the installed fastshap version supports
# matrix-valued wrappers -- otherwise SHAP must be computed one class at a time.
pred_wrapper <- function(object, newdata) {
  pred <- predict(object, newdata)
  return(pred$predictions)
}

shap_values_list <- list()

for (i in seq_len(nrow(selected_instances_x))) {
  category_name <- selected_instances_df$Category[i]

  # Coerce to data.frame so feature classes match the training data.
  newdata_df <- as.data.frame(selected_instances_x[i, , drop = FALSE])

  # Monte-Carlo SHAP values for this instance (100 simulations).
  shap_vals <- explain(
    rf_category,
    X = train_model1_x,
    newdata = newdata_df,
    pred_wrapper = pred_wrapper,
    nsim = 100
  )
  shap_values_list[[category_name]] <- shap_vals

  # Predicted class = highest class probability for this instance.
  pred_probs <- predict(rf_category, newdata_df)$predictions
  pred_class <- colnames(pred_probs)[which.max(pred_probs)]

  # fastshap may return features-in-columns or classes-in-columns; normalise
  # to a (feature, shap_value) data frame for plotting.
  # (Bug fix: the original used `} elseif (` -- not valid R -- and
  # `shap_vals[[pred_class]]`, which errors on a plain matrix; `[, pred_class]`
  # works for both matrices and data frames.)
  if (is.data.frame(shap_vals) || is.matrix(shap_vals)) {
    if (pred_class %in% colnames(shap_vals)) {
      # Columns are classes: take the column for the predicted class.
      shap_df <- data.frame(
        feature = rownames(shap_vals),
        shap_value = shap_vals[, pred_class]
      )
    } else if (ncol(shap_vals) == length(feature_cols_model1)) {
      # Columns are features: use the single row for this instance.
      shap_df <- data.frame(
        feature = colnames(shap_vals),
        shap_value = as.numeric(shap_vals[1, ])
      )
    } else if (nrow(shap_vals) == 1) {
      # Fallback: single row, treat columns as features.
      shap_df <- data.frame(
        feature = colnames(shap_vals),
        shap_value = as.numeric(shap_vals[1, ])
      )
    } else {
      # Fallback: take the first column as the SHAP values.
      shap_df <- data.frame(
        feature = rownames(shap_vals),
        shap_value = as.numeric(shap_vals[, 1])
      )
    }

    # Keep the 20 features with the largest absolute SHAP value.
    shap_df <- shap_df %>%
      arrange(desc(abs(shap_value))) %>%
      head(20)

    # Determine the order of magnitude (exponent) for scientific notation,
    # rounded to the nearest multiple of 3 for cleaner axis labels.
    max_abs_value <- max(abs(shap_df$shap_value))
    if (max_abs_value > 0) {
      exponent <- floor(log10(max_abs_value))
      exponent <- round(exponent / 3) * 3
    } else {
      exponent <- 0
    }

    # Custom label function: rescale by 10^(-exponent) and show only the
    # significant digits.
    scale_factor <- 10^(-exponent)
    label_func <- function(x) {
      scaled <- x * scale_factor
      if (abs(exponent) >= 3) {
        sprintf("%.3f", scaled)
      } else {
        sprintf("%.4f", scaled)
      }
    }

    y_axis_label <- if (abs(exponent) >= 3) {
      sprintf("SHAP Value (×10^%d)", exponent)
    } else {
      "SHAP Value"
    }

    # Horizontal bar chart: green for positive, red for negative SHAP values.
    p_shap <- ggplot(shap_df, aes(x = reorder(feature, shap_value), y = shap_value)) +
      geom_col(aes(fill = shap_value > 0)) +
      scale_fill_manual(
        values = c("TRUE" = "#2E8B57", "FALSE" = "#DC143C"),
        labels = c("TRUE" = "Positive", "FALSE" = "Negative"),
        name = "SHAP Value"
      ) +
      scale_y_continuous(labels = label_func) +
      coord_flip() +
      labs(
        title = sprintf("SHAP Values - %s Instance", category_name),
        subtitle = sprintf("Predicted class: %s", pred_class),
        x = "Feature",
        y = y_axis_label
      ) +
      theme_minimal() +
      theme(
        plot.title = element_text(size = 14, face = "bold"),
        plot.subtitle = element_text(size = 12),
        axis.text.y = element_text(size = 8)
      )
    print(p_shap)
  } else {
    cat(sprintf("SHAP values format not recognized for %s instance\n", category_name))
  }
}