Category Prediction Model

This page trains and evaluates a random forest model with 500 trees to predict the malware category using features from the AndMal2020 dataset.

Load Libraries and Data

# Load the preprocessed dataset produced by preprocessing.qmd
andmal_after <- readRDS("data/processed/andmal_after.rds")

# Show how many samples fall into each malware category
print(table(andmal_after$Category))

        Adware       Backdoor   FileInfector    No_Category            PUA 
          5142            546            119            884            625 
    Ransomware       Riskware      Scareware         Trojan  Trojan_Banker 
          1550           6792            424           4025            123 
Trojan_Dropper     Trojan_SMS     Trojan_Spy       Zero_Day 
           733            911           1039           2146 

Our dataset is actually very imbalanced: Riskware and Adware each contain roughly 50x as many samples as Trojan_Banker.

Feature Selection

# Metadata columns to exclude from the feature set: identifiers and label
# columns that must not leak into the model as predictors.
metadata_cols <- c("Category", "Family", "Category_file", "reboot_state",
                   "path", "file", "Hash")

# All columns present in the loaded dataset
all_cols <- names(andmal_after)

# Only drop metadata columns that actually exist in this data frame
metadata_cols_present <- intersect(metadata_cols, all_cols)

# Feature columns for Category prediction: everything except metadata
# (both Family and Category are labels, so both are excluded here)
feature_cols_model1 <- setdiff(all_cols, metadata_cols_present)

# Also exclude any rank/colour helper columns added during exploration.
# Use base grepl() instead of str_detect() so this chunk does not depend
# on stringr being attached (it is never loaded in this document).
feature_cols_model1 <- feature_cols_model1[
  !grepl("rank|color|fam_color|rank_in_cat", feature_cols_model1)
]

On my first attempt at training the random forest model I left in the Hash column. The LIME and SHAP values were both very strong for this column on almost any sample. This seemed very odd, so, not being an expert in cybersecurity, I researched what Hash was. It turns out it is just SHA-256 — a cryptographic hash function for arbitrary data, applied here to the APK (Android Package Kit) file. This essentially acts as a unique identifier for each sample, so we must remove it.

Train/Test Split

# Stratified train/test split on Category.
# `train_test_split` (the training fraction, e.g. 0.8) is defined in an
# earlier chunk -- TODO confirm it is set before this chunk renders.
set.seed(15)

# requireNamespace() checks availability without attaching the package;
# caret functions below are called with an explicit caret:: prefix, so
# attaching (via require()) is unnecessary.
caret_available <- requireNamespace("caret", quietly = TRUE)

if (caret_available) {
  # Preferred path: caret's stratified partitioning
  train_index <- caret::createDataPartition(
    andmal_after$Category,
    p = train_test_split,
    list = FALSE,
    times = 1
  )
  train_data <- andmal_after[train_index, ]
  test_data <- andmal_after[-train_index, ]
} else {
  # Base R fallback: sample the training fraction within each category.
  # Indices are collected per category and flattened once at the end,
  # avoiding the O(n^2) pattern of growing a vector with c() in a loop.
  categories <- unique(andmal_after$Category)
  train_indices <- unlist(lapply(categories, function(category) {
    cat_indices <- which(andmal_after$Category == category)
    n_train <- round(length(cat_indices) * train_test_split)
    sample(cat_indices, n_train)
  }))
  train_data <- andmal_after[train_indices, ]
  test_data <- andmal_after[-train_indices, ]
}

I use an 80/20 train/test data split.

Train Model

# Load required libraries for this chunk (supports individual rendering).
# NOTE: the previous `if (!require(pkg)) library(pkg)` pattern was
# redundant -- require() already attaches the package on success, and on
# failure library() gives the clearer error -- so load directly.
library(ranger)
library(dplyr)

# Split predictors and response for train and test sets.
# drop = FALSE keeps a data frame even if only one feature column remains.
train_model1_x <- train_data[, feature_cols_model1, drop = FALSE]
train_model1_y <- train_data$Category
test_model1_x <- test_data[, feature_cols_model1, drop = FALSE]
test_model1_y <- test_data$Category

# mtry: number of candidate features tried at each split; use the
# classification default floor(sqrt(p))
mtry_model1 <- floor(sqrt(length(feature_cols_model1)))

# Train a probability forest: predictions will be per-class probability
# matrices rather than hard labels (needed for LIME/SHAP later).
# `num_trees` and `num_threads` come from an earlier configuration
# chunk -- TODO confirm they are defined before this chunk renders.
start_time <- Sys.time()

rf_category <- ranger(
  x = train_model1_x,
  y = train_model1_y,
  num.trees = num_trees,
  mtry = mtry_model1,
  min.node.size = 1,
  num.threads = num_threads,
  classification = TRUE,  # only needed for numeric y; harmless with a factor
  probability = TRUE,
  importance = "impurity",  # Gini importance, reported in a later section
  verbose = TRUE
)

end_time <- Sys.time()
# Wall-clock training time in minutes
training_time <- difftime(end_time, start_time, units = "mins")

Evaluate Model

# Score the held-out test set; a probability forest returns a matrix of
# per-class probabilities (one row per sample, one column per category).
pred_model1 <- predict(rf_category, test_model1_x)
pred_categories <- pred_model1$predictions

# Hard class = column with the highest probability in each row.
# max.col(ties.method = "first") matches which.max's first-max tie-breaking.
top_class_idx <- max.col(pred_categories, ties.method = "first")
pred_categories_class <- factor(
  colnames(pred_categories)[top_class_idx],
  levels = levels(test_model1_y)
)

# Confusion matrix plus overall and per-class metrics (always via caret)
cm_model1 <- caret::confusionMatrix(pred_categories_class, test_model1_y)

metrics_model1 <- cm_model1$overall
per_class_model1 <- cm_model1$byClass

Model Performance Metrics

Category Prediction Model Performance Metrics
Category Overall Accuracy Precision Recall F1 Score
Class: Adware 0.8064 0.7280 0.9163 0.8114
Class: Backdoor 0.8064 0.9259 0.6881 0.7895
Class: FileInfector 0.8064 1.0000 0.4783 0.6471
Class: No_Category 0.8064 0.8750 0.2784 0.4224
Class: PUA 0.8064 0.8750 0.6720 0.7602
Class: Ransomware 0.8064 0.7727 0.8774 0.8218
Class: Riskware 0.8064 0.8603 0.8844 0.8722
Class: Scareware 0.8064 0.8642 0.8333 0.8485
Class: Trojan 0.8064 0.8381 0.8745 0.8559
Class: Trojan_Banker 0.8064 0.8000 0.5000 0.6154
Class: Trojan_Dropper 0.8064 0.9082 0.6096 0.7295
Class: Trojan_SMS 0.8064 0.8553 0.7143 0.7784
Class: Trojan_Spy 0.8064 0.9055 0.8792 0.8922
Class: Zero_Day 0.8064 0.6486 0.5035 0.5669
Overall Metrics:
      Accuracy          Kappa  AccuracyLower  AccuracyUpper   AccuracyNull 
     0.8064323      0.7666103      0.7952103      0.8172957      0.2712745 
AccuracyPValue  McnemarPValue 
     0.0000000            NaN 

Per-Class Metrics:
                      Sensitivity Specificity Pos Pred Value Neg Pred Value
Class: Adware           0.9163424   0.9115133      0.7279753      0.9768319
Class: Backdoor         0.6880734   0.9987748      0.9259259      0.9930964
Class: FileInfector     0.4782609   1.0000000      1.0000000      0.9975976
Class: No_Category      0.2784091   0.9985507      0.8750000      0.9743434
Class: PUA              0.6720000   0.9975415      0.8750000      0.9916497
Class: Ransomware       0.8774194   0.9829642      0.7727273      0.9918350
Class: Riskware         0.8843888   0.9465461      0.8603152      0.9565097
Class: Scareware        0.8333333   0.9977651      0.8641975      0.9971574
Class: Trojan           0.8745342   0.9676268      0.8380952      0.9757561
Class: Trojan_Banker    0.5000000   0.9993978      0.8000000      0.9975957
Class: Trojan_Dropper   0.6095890   0.9981481      0.9081633      0.9883863
Class: Trojan_SMS       0.7142857   0.9954395      0.8552632      0.9892872
Class: Trojan_Spy       0.8792271   0.9960408      0.9054726      0.9947971
Class: Zero_Day         0.5034965   0.9744374      0.6486486      0.9544190
                      Precision    Recall        F1  Prevalence Detection Rate
Class: Adware         0.7279753 0.9163424 0.8113695 0.205353576    0.188174191
Class: Backdoor       0.9259259 0.6880734 0.7894737 0.021773871    0.014982022
Class: FileInfector   1.0000000 0.4782609 0.6470588 0.004594487    0.002197363
Class: No_Category    0.8750000 0.2784091 0.4224138 0.035157811    0.009788254
Class: PUA            0.8750000 0.6720000 0.7601810 0.024970036    0.016779864
Class: Ransomware     0.7727273 0.8774194 0.8217523 0.061925689    0.054334798
Class: Riskware       0.8603152 0.8843888 0.8721859 0.271274471    0.239912105
Class: Scareware      0.8641975 0.8333333 0.8484848 0.016779864    0.013983220
Class: Trojan         0.8380952 0.8745342 0.8559271 0.160807032    0.140631243
Class: Trojan_Banker  0.8000000 0.5000000 0.6153846 0.004794247    0.002397123
Class: Trojan_Dropper 0.9081633 0.6095890 0.7295082 0.029165002    0.017778666
Class: Trojan_SMS     0.8552632 0.7142857 0.7784431 0.036356372    0.025968837
Class: Trojan_Spy     0.9054726 0.8792271 0.8921569 0.041350380    0.036356372
Class: Zero_Day       0.6486486 0.5034965 0.5669291 0.085697163    0.043148222
                      Detection Prevalence Balanced Accuracy
Class: Adware                  0.258489812         0.9139279
Class: Backdoor                0.016180583         0.8434241
Class: FileInfector            0.002197363         0.7391304
Class: No_Category             0.011186576         0.6384799
Class: PUA                     0.019176988         0.8347707
Class: Ransomware              0.070315621         0.9301918
Class: Riskware                0.278865362         0.9154674
Class: Scareware               0.016180583         0.9155492
Class: Trojan                  0.167798642         0.9210805
Class: Trojan_Banker           0.002996404         0.7496989
Class: Trojan_Dropper          0.019576508         0.8038686
Class: Trojan_SMS              0.030363564         0.8548626
Class: Trojan_Spy              0.040151818         0.9376339
Class: Zero_Day                0.066520176         0.7389670

Feature Importance

Top 20 Most Important Features:
                                                             feature importance
1                                                 Memory_SharedClean   279.2457
2                                                Memory_PrivateDirty   271.0982
3                                                   Memory_HeapAlloc   265.9934
4                                                 Memory_SharedDirty   259.8224
5                                                    Memory_HeapSize   254.7220
6  API_DeviceInfo_android.telephony.TelephonyManager_getSubscriberId   247.2338
7                                                          env_calls   244.2941
8                                                    env_probe_count   236.0168
9                                                      log_env_probe   229.8714
10     API_DeviceInfo_android.telephony.TelephonyManager_getDeviceId   227.9130
11                                                       dirty_ratio   226.4893
12                                                   Memory_PssTotal   225.2356
13                           API_Network_java.net.URL_openConnection   223.8871
14                                                       shared_frac   223.8149
15                API_Crypto-Hash_java.security.MessageDigest_digest   221.4126
16                                               Memory_PrivateClean   221.0624
17                                                          id_calls   219.4206
18                        API_Command_java.lang.ProcessBuilder_start   214.2442
19                                                      Memory_Views   213.5331
20                                                      log_PssTotal   213.4108
# Persist the trained model. Create the models directory first so the
# save cannot fail on a fresh checkout (mirrors the data/processed guard
# used elsewhere in this document).
if (!dir.exists("data/models")) {
  dir.create("data/models", recursive = TRUE)
}
saveRDS(rf_category, "data/models/rf_category_model.rds")

Random Forest Model Summary

The random forest model for category prediction demonstrates surprisingly strong performance in distinguishing between malware categories. The precision score for Zero_Day is among the worst; this is to be expected, since these are novel pieces of code. Judging by any of the scores, No_Category and Zero_Day are the hardest to classify — but in a way, both of these classes represent new or unorthodox code. I suspect the low number of samples we had for FileInfector is one of the main reasons it performed so poorly.

Model Interpretability

Instance Selection

# Select one test-set instance from each of the three most populated
# categories; these are explained with LIME and SHAP below and reused
# by family.qmd.
set.seed(15)

target_categories <- c("Adware", "Riskware", "Trojan")
selected_instances <- list()

for (category in target_categories) {
  category_indices <- which(test_data$Category == category)
  # Take the first test-set instance of the category (deterministic given
  # the seeded split). Categories absent from the test set are simply
  # skipped -- the previous `else { invisible() }` branch was a no-op.
  if (length(category_indices) > 0) {
    selected_instances[[category]] <- test_data[category_indices[1], ]
  }
}

# One data frame, one row per selected instance
selected_instances_df <- bind_rows(selected_instances)

# Feature-only view used as explainer input
selected_instances_x <- selected_instances_df[, feature_cols_model1, drop = FALSE]

# Persist for use in family.qmd
if (!dir.exists("data/processed")) {
  dir.create("data/processed", recursive = TRUE)
}
saveRDS(selected_instances_df, "data/processed/selected_instances.rds")

LIME Explanations

I select an instance from each of the 3 most populated categories (Adware, Riskware, Trojan), then obtain the LIME and Shapley values for these 3 instances.

# Ensure required libraries are loaded
if (!require(lime, quietly = TRUE)) library(lime)
if (!require(ggplot2, quietly = TRUE)) library(ggplot2)
if (!require(dplyr, quietly = TRUE)) library(dplyr)

# Set seed for reproducibility
set.seed(15)

# Create a prediction function wrapper for ranger model
predict_function <- function(model, newdata) {
  pred <- predict(model, newdata)
  return(pred$predictions)
}

# Create LIME explainer
explainer <- lime(
  train_model1_x,
  model = rf_category,
  bin_continuous = TRUE,
  n_bins = 5
)

# Generate explanations for all selected instances
lime_explanations <- lime::explain(
  selected_instances_x,
  explainer = explainer,
  n_features = 10,
  n_permutations = 5000,
  n_labels = 1
)

# Plot LIME explanations
for (i in 1:nrow(selected_instances_x)) {
  category_name <- selected_instances_df$Category[i]
  print(plot_features(lime_explanations[lime_explanations$case == i, ]))
}

fastshap Explanations

Per-Instance SHAP Values

# SHAP value setup (fastshap) for the selected instances.
# NOTE: attaching fastshap after lime means fastshap::explain() masks
# lime::explain() for any subsequent unqualified explain() call.
library(fastshap)
library(ggplot2)
library(dplyr)

# Set seed for reproducibility (fastshap uses Monte-Carlo sampling)
set.seed(15)

# Prediction wrapper fastshap calls on perturbed data; for a ranger
# probability forest this returns a samples x classes matrix.
pred_wrapper <- function(object, newdata) {
  predict(object, newdata)$predictions
}


# Estimate and visualise SHAP values for each selected instance: run
# fastshap, pull the SHAP column for the predicted class, and plot the
# 20 largest-magnitude feature contributions.
shap_values_list <- list()

for (i in seq_len(nrow(selected_instances_x))) {
  category_name <- selected_instances_df$Category[i]
  
  # Ensure newdata is a data.frame to match the training feature class
  newdata_df <- as.data.frame(selected_instances_x[i, , drop = FALSE])
  
  # Monte-Carlo SHAP estimate (100 simulations) for this instance.
  # Use an explicit namespace: lime also exports explain(), so a bare
  # explain() call silently depends on package attach order.
  shap_vals <- fastshap::explain(
    rf_category,
    X = train_model1_x,
    newdata = newdata_df,
    pred_wrapper = pred_wrapper,
    nsim = 100
  )
  
  shap_values_list[[category_name]] <- shap_vals
  
  # Class the model actually predicts for this instance (highest
  # probability in the single-row prediction matrix)
  pred_probs <- predict(rf_category, newdata_df)$predictions
  pred_class <- colnames(pred_probs)[which.max(pred_probs)]
  
  # fastshap's return layout varies (matrix or data.frame; columns may be
  # classes or features), so probe the shape before extracting values.
  if (is.data.frame(shap_vals) || is.matrix(shap_vals)) {
    if (pred_class %in% colnames(shap_vals)) {
      # One column per class: take the predicted class's column
      shap_df <- data.frame(
        feature = rownames(shap_vals),
        shap_value = shap_vals[[pred_class]]
      )
    } else if (ncol(shap_vals) == length(feature_cols_model1)) {
      # One column per feature: single row of per-feature SHAP values
      shap_df <- data.frame(
        feature = colnames(shap_vals),
        shap_value = as.numeric(shap_vals[1, ])
      )
    } else {
      # Unrecognised layout: fall back to the only row, else first column
      if (nrow(shap_vals) == 1) {
        shap_df <- data.frame(
          feature = colnames(shap_vals),
          shap_value = as.numeric(shap_vals[1, ])
        )
      } else {
        shap_df <- data.frame(
          feature = rownames(shap_vals),
          shap_value = as.numeric(shap_vals[, 1])
        )
      }
    }
    
    # Keep the 20 features with the largest absolute contribution
    shap_df <- shap_df %>%
      arrange(desc(abs(shap_value))) %>%
      head(20)
    
    # Determine the order of magnitude (exponent) for axis labelling,
    # rounded to a multiple of 3 (engineering notation)
    max_abs_value <- max(abs(shap_df$shap_value))
    if (max_abs_value > 0) {
      exponent <- floor(log10(max_abs_value))
      exponent <- round(exponent / 3) * 3
    } else {
      exponent <- 0
    }
    
    # Axis label formatter: rescale tick values by the chosen exponent
    # and show a fixed number of decimal places
    scale_factor <- 10^(-exponent)
    label_func <- function(x) {
      scaled <- x * scale_factor
      if (abs(exponent) >= 3) {
        sprintf("%.3f", scaled)
      } else {
        sprintf("%.4f", scaled)
      }
    }
    
    # Only mention the scale in the axis title when it is not ~1
    y_axis_label <- if (abs(exponent) >= 3) {
      sprintf("SHAP Value (×10^%d)", exponent)
    } else {
      "SHAP Value"
    }
    
    # Horizontal bar chart of the top contributions: positive SHAP
    # values green, negative red
    p_shap <- ggplot(shap_df, aes(x = reorder(feature, shap_value), y = shap_value)) +
      geom_col(aes(fill = shap_value > 0)) +
      scale_fill_manual(
        values = c("TRUE" = "#2E8B57", "FALSE" = "#DC143C"),
        labels = c("TRUE" = "Positive", "FALSE" = "Negative"),
        name = "SHAP Value"
      ) +
      scale_y_continuous(labels = label_func) +
      coord_flip() +
      labs(
        title = sprintf("SHAP Values - %s Instance", category_name),
        subtitle = sprintf("Predicted class: %s", pred_class),
        x = "Feature",
        y = y_axis_label
      ) +
      theme_minimal() +
      theme(
        plot.title = element_text(size = 14, face = "bold"),
        plot.subtitle = element_text(size = 12),
        axis.text.y = element_text(size = 8)
      )
    
    print(p_shap)
  } else {
    cat(sprintf("SHAP values format not recognized for %s instance\n", category_name))
  }
}