## ----setup, include=FALSE-----------------------------------------------------
# Default chunk options applied to every code block of the rendered document
knitr::opts_chunk$set(
  out.width = "100%",
  fig.align = "center",
  comment = "#>",
  collapse = TRUE,
  warning = FALSE,
  message = FALSE
)

# Fix the random number generator so repeated runs give the same output
set.seed(seed = 42)

## ----eval=FALSE---------------------------------------------------------------
# # Install from GitHub
# devtools::install_github("Zaoqu-Liu/SVG")
# 
# # Install all optional dependencies
# install.packages(c("geometry", "RANN", "BRISC", "CompQuadForm",
#                    "spatstat.geom", "spatstat.explore"))

## ----load-package-------------------------------------------------------------
# Attach the SVG package
library("SVG")

# Bring the bundled example dataset into the workspace
data("example_svg_data")

# List the elements the dataset contains
cat("Dataset components:\n")
names(x = example_svg_data)

## ----data-summary-------------------------------------------------------------
# Pull the individual pieces out of the example dataset
gene_info <- example_svg_data$gene_info       # per-gene metadata
coords <- example_svg_data$spatial_coords     # spot coordinates
expr_log <- example_svg_data$logcounts        # log-normalized expression
expr_counts <- example_svg_data$counts        # raw counts

# Size of the dataset: spots are columns, genes are rows
cat("Data Dimensions:\n")
cat("  Number of spots:", dim(expr_counts)[2], "\n")
cat("  Number of genes:", dim(expr_counts)[1], "\n")

# Split of genes by the is_svg flag
cat("  True SVGs:", sum(gene_info$is_svg), "\n")
cat("  Non-SVGs:", sum(!gene_info$is_svg), "\n")

# Tally of genes per pattern_type label
cat("\nSpatial Pattern Types:\n")
print(table(gene_info$pattern_type))

## ----spatial-layout, fig.width=7, fig.height=6--------------------------------
# Adjust the margins for this figure; the previous settings are restored below
oldpar <- par(mar = c(4, 4, 3, 2))

# Scatter of the spot positions, drawn with a 1:1 aspect ratio
plot(
  x = coords[, 1], y = coords[, 2],
  main = "Spatial Spot Layout (Hexagonal Grid)",
  xlab = "X coordinate (μm)",
  ylab = "Y coordinate (μm)",
  col = adjustcolor("steelblue", alpha.f = 0.7),
  pch = 19, cex = 0.9, asp = 1
)
grid(col = "gray80", lty = 2)

# Print the number of spots in the top margin
mtext(
  text = paste("n =", nrow(coords), "spots"),
  side = 3, line = 0.3, cex = 0.9
)
par(oldpar)

## ----expr-distribution, fig.width=10, fig.height=4----------------------------
# Two panels side by side; the previous settings are restored at the end
oldpar <- par(mfrow = c(1, 2), mar = c(4, 4, 3, 1))

# Left panel: per-gene mean raw count on a log10(x + 1) scale
hist(
  x = log10(rowMeans(expr_counts) + 1),
  main = "Gene Expression Distribution",
  xlab = expression(log[10](Mean~Count + 1)),
  ylab = "Number of Genes",
  breaks = 30, col = "lightblue", border = "white"
)

# Right panel: per-spot total raw count on a log10 scale
hist(
  x = log10(colSums(expr_counts)),
  main = "Spot Library Size Distribution",
  xlab = expression(log[10](Library~Size)),
  ylab = "Number of Spots",
  breaks = 30, col = "lightgreen", border = "white"
)
par(oldpar)

## ----pattern-visualization, fig.width=12, fig.height=10-----------------------
# Translate numeric expression values into colours from a 100-step gradient
#
# x:   numeric vector of expression values; NA entries yield NA colours
# pal: "RdYlBu" selects the eleven-stop blue-yellow-red ramp; any other value
#      selects the navy-white-firebrick ramp
#
# Returns a character vector of colours, one per element of x. Values are
# rescaled by the minimum and maximum of x itself, so the mapping is relative
# to the vector passed in.
expr_colors <- function(x, pal = "RdYlBu") {
  lo <- min(x, na.rm = TRUE)
  hi <- max(x, na.rm = TRUE)
  # The small constant keeps the division finite when all values are equal
  x_unit <- (x - lo) / (hi - lo + 1e-10)

  stops <- if (pal == "RdYlBu") {
    c("#313695", "#4575B4", "#74ADD1", "#ABD9E9",
      "#E0F3F8", "#FFFFBF", "#FEE090", "#FDAE61",
      "#F46D43", "#D73027", "#A50026")
  } else {
    c("navy", "white", "firebrick3")
  }
  ramp <- colorRampPalette(stops)(100)

  # Gradient step in 1..100; the minimum of x maps to step 1
  step <- pmax(1, ceiling(x_unit * 99) + 1)
  ramp[step]
}

# Choose one example gene per spatial pattern: the first gene_info row of each
# type. These row indices are used directly as row indices of expr_log below.
pattern_types <- c("gradient", "hotspot", "periodic", "cluster")
pattern_genes <- vapply(pattern_types, function(pt) {
  which(gene_info$pattern_type == pt)[1]
}, integer(1))

# First gene not flagged as an SVG, used as the negative example
non_svg_gene <- which(!gene_info$is_svg)[1]

# 2 x 3 layout: four pattern panels, one non-SVG panel, one legend panel
oldpar <- par(mfrow = c(2, 3), mar = c(4, 4, 3, 1))

# Plot each pattern type
for (i in seq_along(pattern_types)) {
  gene_idx <- pattern_genes[i]
  gene_name <- rownames(expr_log)[gene_idx]
  gene_expr <- expr_log[gene_idx, ]

  plot(coords[, 1], coords[, 2],
       pch = 19, cex = 1.3,
       col = expr_colors(gene_expr),
       xlab = "X", ylab = "Y",
       main = paste0(gene_name, "\n(", pattern_types[i], " pattern)"),
       asp = 1)
}

# Plot non-SVG
gene_expr <- expr_log[non_svg_gene, ]
plot(coords[, 1], coords[, 2],
     pch = 19, cex = 1.3,
     col = expr_colors(gene_expr),
     xlab = "X", ylab = "Y",
     main = paste0(rownames(expr_log)[non_svg_gene], "\n(non-SVG, random)"),
     asp = 1)

# Colour legend in the remaining sixth panel. The high-level plot() call below
# already advances to the next panel, so no separate plot.new() is issued:
# calling both would skip the free panel and push the legend onto a new page.
par(mar = c(1, 1, 1, 1))
# Build the bar from expr_colors() itself so it shows the same gradient as the
# spots (expr_colors() scales each gene to its own min/max, hence "Low"/"High")
legend_colors <- expr_colors(seq(0, 1, length.out = 100))
legend_image <- as.raster(matrix(rev(legend_colors), ncol = 1))
plot(c(0, 1), c(0, 1), type = "n", axes = FALSE, xlab = "", ylab = "")
text(0.5, 0.95, "Expression Level", cex = 1.2, font = 2)
rasterImage(legend_image, 0.3, 0.1, 0.7, 0.85)
text(0.75, 0.1, "Low", cex = 0.9)
text(0.75, 0.85, "High", cex = 0.9)
par(oldpar)

## ----meringue-run-------------------------------------------------------------
# MERINGUE detection on a k-nearest-neighbour spatial network
results_meringue <- CalSVG_MERINGUE(
  expr_matrix = expr_log,
  spatial_coords = coords,
  network_method = "knn",   # how the spatial network is built
  k = 10,                   # neighbours per spot
  alternative = "greater",  # one-sided test for positive autocorrelation
  adjust_method = "BH",     # Benjamini-Hochberg correction
  verbose = FALSE
)

# Show the first ten rows of the result table
cat("Top 10 SVGs by MERINGUE:\n")
head(
  results_meringue,
  n = 10
)[, c("gene", "observed", "expected", "z_score", "p.value", "p.adj")]

## ----meringue-viz, fig.width=10, fig.height=4---------------------------------
# Two panels side by side; the previous settings are restored at the end
oldpar <- par(mfrow = c(1, 2), mar = c(4, 4, 3, 1))

# Left panel: observed Moran's I per gene, with the expected value marked
hist(
  x = results_meringue$observed,
  main = "Distribution of Moran's I Statistics",
  xlab = "Moran's I",
  ylab = "Number of Genes",
  breaks = 40, col = "steelblue", border = "white"
)
abline(v = results_meringue$expected[1], lty = 2, lwd = 2, col = "red")
legend(
  x = "topright", legend = "E[I] under null",
  lty = 2, lwd = 2, col = "red"
)

# Right panel: unadjusted p-values, with the 0.05 cutoff marked
hist(
  x = results_meringue$p.value,
  main = "P-value Distribution",
  xlab = "P-value",
  ylab = "Number of Genes",
  breaks = 40, col = "coral", border = "white"
)
abline(v = 0.05, lty = 2, lwd = 2, col = "darkred")
par(oldpar)

## ----binspect-run-------------------------------------------------------------
# binSpect detection: k-means binarization on a Delaunay network
results_binspect <- CalSVG_binSpect(
  expr_matrix = expr_log,
  spatial_coords = coords,
  bin_method = "kmeans",        # how expression is binarized
  network_method = "delaunay",  # how the spatial network is built
  do_fisher_test = TRUE,        # run Fisher's test
  adjust_method = "fdr",        # FDR correction
  verbose = FALSE
)

# Show the first ten rows of the result table
cat("Top 10 SVGs by binSpect:\n")
head(
  results_binspect,
  n = 10
)[, c("gene", "estimate", "p.value", "p.adj", "score")]

## ----binspect-viz, fig.width=10, fig.height=4---------------------------------
# Two panels side by side; the previous settings are restored at the end
oldpar <- par(mfrow = c(1, 2), mar = c(4, 4, 3, 1))

# Left panel: log2 of the estimate column; the 0.01 offset keeps zero
# estimates finite, and the dashed line marks log2 = 0
or_log <- log2(results_binspect$estimate + 0.01)
hist(
  x = or_log,
  main = "Distribution of Odds Ratios",
  xlab = expression(log[2](Odds~Ratio)),
  ylab = "Number of Genes",
  breaks = 40, col = "mediumpurple", border = "white"
)
abline(v = 0, lty = 2, lwd = 2, col = "red")

# Right panel: the score column of the binSpect result
hist(
  x = results_binspect$score,
  main = "Distribution of binSpect Scores",
  xlab = "binSpect Score",
  ylab = "Number of Genes",
  breaks = 40, col = "darkorange", border = "white"
)
par(oldpar)

## ----sparkx-run---------------------------------------------------------------
# SPARK-X detection with the "mixture" kernel option
# Note: this call is given the raw counts, not the log-transformed matrix
results_sparkx <- CalSVG_SPARKX(
  expr_matrix = expr_counts,  # Raw counts recommended
  spatial_coords = coords,
  kernel_option = "mixture",  # All 11 kernels
  adjust_method = "BY",       # Benjamini-Yekutieli (conservative)
  verbose = FALSE
)

# Show the first ten rows of the result table
cat("Top 10 SVGs by SPARK-X:\n")
head(results_sparkx, n = 10)[, c("gene", "p.value", "p.adj")]

## ----sparkx-viz, fig.width=10, fig.height=4-----------------------------------
# Two panels side by side; the previous settings are restored at the end
oldpar <- par(mfrow = c(1, 2), mar = c(4, 4, 3, 1))

# Left panel: -log10 of the unadjusted p-values; the 1e-300 offset keeps
# p-values of exactly zero finite
hist(
  x = -log10(results_sparkx$p.value + 1e-300),
  main = "SPARK-X P-value Distribution",
  xlab = expression(-log[10](p-value)),
  ylab = "Number of Genes",
  breaks = 40, col = "seagreen", border = "white"
)

# Right panel: -log10 adjusted p-value against row position in the result
# table; points above the 0.05 line are drawn in red
pval_log <- -log10(results_sparkx$p.adj + 1e-300)
plot(
  x = seq_along(pval_log), y = pval_log,
  main = "SPARK-X Significance Plot",
  xlab = "Gene Index",
  ylab = expression(-log[10](adjusted~p-value)),
  col = ifelse(pval_log > -log10(0.05), "red", "gray50"),
  pch = 19, cex = 0.6
)
abline(h = -log10(0.05), lty = 2, col = "red")
par(oldpar)

## ----seurat-run---------------------------------------------------------------
# Seurat-style Moran's I detection
results_seurat <- CalSVG_Seurat(
  expr_matrix = expr_log,
  spatial_coords = coords,
  weight_scheme = "inverse_squared",  # Seurat default
  adjust_method = "BH",
  verbose = FALSE
)

# Show the first ten rows of the result table
cat("Top 10 SVGs by Seurat:\n")
head(
  results_seurat,
  n = 10
)[, c("gene", "observed", "expected", "sd", "p.value", "p.adj")]

## ----calsvg-unified-----------------------------------------------------------
# The same MERINGUE analysis, dispatched through the generic CalSVG() wrapper
results_unified <- CalSVG(
  expr_matrix = expr_log,
  spatial_coords = coords,
  # Options: meringue, binspect, sparkx, seurat, nnsvg, markvario
  method = "meringue",
  network_method = "knn",
  k = 10,
  verbose = FALSE
)

# Show the first five rows of the result table
cat("Unified interface results:\n")
head(results_unified, n = 5)[, c("gene", "p.value", "p.adj")]

## ----metrics-calculation------------------------------------------------------
# Ground truth: the per-gene is_svg flag from gene_info, one entry per gene in
# gene_info row order. The benchmarking code below compares it with each
# method's results element by element, so both must list genes in that order.
truth <- gene_info$is_svg

# Compute confusion-matrix counts and derived classification metrics
#
# result:    data frame of detection results containing the column `pval_col`
# truth:     logical vector, TRUE for genes that are truly spatially variable;
#            paired with the rows of `result` by position
# pval_col:  name of the p-value column used to call detections
# threshold: a gene counts as detected when its value is strictly below this
#
# Returns a list with the counts TP, FP, FN, TN and the metrics Sensitivity,
# Specificity, Precision, NPV, F1, Accuracy and MCC. A ratio whose denominator
# is zero is reported as 0. Stops if `truth` and the p-value column differ in
# length, because `&` would otherwise recycle the shorter one silently.
calc_performance <- function(result, truth, pval_col = "p.adj", threshold = 0.05) {
  detected <- result[[pval_col]] < threshold
  if (length(detected) != length(truth)) {
    stop("`truth` must have one entry per row of `result` (column '",
         pval_col, "' has ", length(detected), " values, `truth` has ",
         length(truth), ")", call. = FALSE)
  }
  # A missing p-value counts as "not detected"
  detected[is.na(detected)] <- FALSE

  TP <- sum(truth & detected)
  FP <- sum(!truth & detected)
  FN <- sum(truth & !detected)
  TN <- sum(!truth & !detected)

  # sum() of a logical is an integer, and integer products such as TP * TN
  # overflow to NA for large gene sets, so all arithmetic uses double copies
  tp <- as.numeric(TP)
  fp <- as.numeric(FP)
  fn <- as.numeric(FN)
  tn <- as.numeric(TN)

  # Same zero-denominator convention for every ratio
  ratio <- function(num, den) num / max(den, 1)

  list(
    TP = TP, FP = FP, FN = FN, TN = TN,
    Sensitivity = ratio(tp, tp + fn),
    Specificity = ratio(tn, tn + fp),
    Precision = ratio(tp, tp + fp),
    NPV = ratio(tn, tn + fn),
    F1 = ratio(2 * tp, 2 * tp + fp + fn),
    Accuracy = ratio(tp + tn, tp + tn + fp + fn),
    # The small constant keeps the MCC finite (0) when a margin is empty
    MCC = (tp * tn - fp * fn) /
      sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) + 1e-10)
  )
}

# Put a result table into the gene order of `truth` before it is scored.
# calc_performance() pairs truth with result rows by position, so a table
# returned in any other order would be scored against the wrong labels.
# truth follows the row order of gene_info, which this script already treats
# as the row order of expr_log.
# NOTE(review): the tables are shown above as "Top 10" lists, which suggests
# they may be ranked by significance - confirm the row order CalSVG_* returns.
align_to_truth <- function(res, genes = rownames(expr_log)) {
  idx <- match(genes, res$gene)
  # Leave the table untouched when it cannot be matched gene-for-gene
  if (length(idx) != nrow(res) || anyNA(idx)) {
    return(res)
  }
  res[idx, , drop = FALSE]
}

# Calculate metrics for each method
method_results <- list(
  MERINGUE = results_meringue,
  binSpect = results_binspect,
  `SPARK-X` = results_sparkx,
  Seurat = results_seurat
)
metrics_list <- lapply(method_results, function(res) {
  calc_performance(align_to_truth(res), truth)
})

# One row per method with the metrics shown in the comparison table
metrics_df <- do.call(rbind, lapply(names(metrics_list), function(m) {
  data.frame(
    Method = m,
    Sensitivity = metrics_list[[m]]$Sensitivity,
    Specificity = metrics_list[[m]]$Specificity,
    Precision = metrics_list[[m]]$Precision,
    F1 = metrics_list[[m]]$F1,
    MCC = metrics_list[[m]]$MCC
  )
}))

knitr::kable(metrics_df, digits = 3,
             caption = "Performance Comparison on Simulated Data (FDR < 0.05)")

## ----performance-heatmap, fig.width=9, fig.height=5---------------------------
# Numeric metrics as a matrix: one row per method, one column per metric
metrics_matrix <- as.matrix(metrics_df[, -1])
rownames(metrics_matrix) <- metrics_df$Method

# Colour scale runs from dark (lowest value shown) to light (highest value)
heat_colors <- colorRampPalette(c("#440154", "#31688E", "#35B779", "#FDE725"))(100)
value_range <- range(metrics_matrix, finite = TRUE)

# Create heatmap; the matrix is transposed so metrics run along the x axis
# and methods along the y axis
oldpar <- par(mar = c(5, 8, 4, 6))
image(t(metrics_matrix), axes = FALSE,
      col = heat_colors, zlim = value_range,
      main = "Performance Metrics Heatmap")

axis(1, at = seq(0, 1, length.out = ncol(metrics_matrix)),
     labels = colnames(metrics_matrix), las = 2, cex.axis = 0.9)
axis(2, at = seq(0, 1, length.out = nrow(metrics_matrix)),
     labels = rownames(metrics_matrix), las = 1, cex.axis = 0.9)

# Add values. The label colour follows the cell's position on the colour
# scale: white on the dark half, black on the light half. Cells with no finite
# value are left blank by image(), so their label is drawn in black.
scale_pos <- (metrics_matrix - value_range[1]) /
  max(diff(value_range), .Machine$double.eps)
text(
  x = as.vector((col(metrics_matrix) - 1) / (ncol(metrics_matrix) - 1)),
  y = as.vector((row(metrics_matrix) - 1) / (nrow(metrics_matrix) - 1)),
  labels = sprintf("%.2f", metrics_matrix),
  cex = 0.9,
  col = as.vector(ifelse(is.na(scale_pos) | scale_pos > 0.5, "black", "white"))
)
par(oldpar)

## ----roc-analysis, fig.width=8, fig.height=7----------------------------------
# Build an ROC curve and its area from one score per gene
#
# scores: numeric scores; by default a smaller score means stronger evidence
#         (as with p-values). NA scores are ranked behind all other scores.
# truth:  logical vector marking the true positives, in the order of `scores`
# higher_is_better: set TRUE when a larger score means stronger evidence
#
# Returns a list with fpr and tpr (sorted for plotting) and auc, the area
# under the curve by the trapezoidal rule.
compute_roc <- function(scores, truth, higher_is_better = FALSE) {
  # Work on a "smaller is better" scale throughout
  if (higher_is_better) {
    scores <- -scores
  }
  worst <- max(scores, na.rm = TRUE) + 1
  scores[is.na(scores)] <- worst

  # Every distinct score is a cutoff; -Inf and Inf are added as end cutoffs
  cutoffs <- sort(unique(c(-Inf, scores, Inf)))
  n_pos <- sum(truth)
  n_neg <- sum(!truth)

  # Row 1 holds the true positive rate, row 2 the false positive rate
  rates <- vapply(cutoffs, function(cutoff) {
    called <- scores <= cutoff
    c(sum(truth & called) / n_pos, sum(!truth & called) / n_neg)
  }, numeric(2))

  by_fpr <- order(rates[2, ], rates[1, ])
  fpr <- rates[2, by_fpr]
  tpr <- rates[1, by_fpr]

  # Trapezoidal rule over consecutive curve points
  last <- length(fpr)
  auc <- sum((fpr[-1] - fpr[-last]) * (tpr[-last] + tpr[-1]) / 2)

  list(fpr = fpr, tpr = tpr, auc = auc)
}

# Compute ROC for each method.
# compute_roc() pairs scores with truth by position, so each method's p-values
# are first put into the gene order of truth (the row order of gene_info, which
# this script already treats as the row order of expr_log). When a table cannot
# be matched gene-for-gene its p-values are used in the order returned.
# NOTE(review): the tables are shown above as "Top 10" lists, which suggests
# they may be ranked by significance - confirm the row order CalSVG_* returns.
pvalues_in_truth_order <- function(res, genes = rownames(expr_log)) {
  idx <- match(genes, res$gene)
  if (length(idx) != nrow(res) || anyNA(idx)) {
    return(res$p.value)
  }
  res$p.value[idx]
}

roc_meringue <- compute_roc(pvalues_in_truth_order(results_meringue), truth)
roc_binspect <- compute_roc(pvalues_in_truth_order(results_binspect), truth)
roc_sparkx <- compute_roc(pvalues_in_truth_order(results_sparkx), truth)
roc_seurat <- compute_roc(pvalues_in_truth_order(results_seurat), truth)

# Draw the four ROC curves on shared axes
oldpar <- par(mar = c(5, 5, 4, 2))

# The first curve sets up the axes
plot(
  x = roc_meringue$fpr, y = roc_meringue$tpr,
  type = "l", lwd = 3, col = "#E41A1C",
  main = "Receiver Operating Characteristic (ROC) Curves",
  xlab = "False Positive Rate (1 - Specificity)",
  ylab = "True Positive Rate (Sensitivity)",
  xlim = c(0, 1), ylim = c(0, 1),
  cex.lab = 1.2, cex.main = 1.3
)

# The remaining curves are overlaid, each in its own colour
invisible(Map(
  function(roc, colour) lines(roc$fpr, roc$tpr, lwd = 3, col = colour),
  list(roc_binspect, roc_sparkx, roc_seurat),
  c("#377EB8", "#4DAF4A", "#984EA3")
))

# Diagonal reference line
abline(a = 0, b = 1, lty = 2, lwd = 2, col = "gray50")

# Legend entries carry each method's AUC
legend(
  x = "bottomright",
  legend = c(
    sprintf("MERINGUE (AUC = %.3f)", roc_meringue$auc),
    sprintf("binSpect (AUC = %.3f)", roc_binspect$auc),
    sprintf("SPARK-X (AUC = %.3f)", roc_sparkx$auc),
    sprintf("Seurat (AUC = %.3f)", roc_seurat$auc)
  ),
  col = c("#E41A1C", "#377EB8", "#4DAF4A", "#984EA3"),
  lwd = 3, cex = 1.0, bty = "n"
)
par(oldpar)

## ----overlap-analysis, fig.width=9, fig.height=5------------------------------
# Get significant genes for each method (adjusted p-value below 0.05).
# which() drops rows whose adjusted p-value is NA; plain logical indexing
# would return an NA gene name for each of them, inflating the gene counts
# and the overlap sets computed from this list.
sig_genes <- lapply(
  list(
    MERINGUE = results_meringue,
    binSpect = results_binspect,
    SPARKX = results_sparkx,
    Seurat = results_seurat
  ),
  function(res) res$gene[which(res$p.adj < 0.05)]
)

# Jaccard index of two sets: size of the intersection over size of the union
#
# a, b: vectors treated as sets (duplicates are ignored)
#
# Returns a number in [0, 1]. Two empty sets are identical, so their index is
# defined as 1 instead of the 0/0 = NaN the plain ratio would give.
jaccard <- function(a, b) {
  n_union <- length(union(a, b))
  if (n_union == 0) {
    return(1)
  }
  length(intersect(a, b)) / n_union
}

methods <- names(sig_genes)

# Every (row, column) pair of methods, with the row index varying fastest so
# the values fill the matrix column by column
idx_grid <- expand.grid(i = seq_along(methods), j = seq_along(methods))

# Square matrix of pairwise Jaccard indices, labelled by method on both sides
jaccard_matrix <- matrix(
  vapply(
    seq_len(nrow(idx_grid)),
    function(p) jaccard(sig_genes[[idx_grid$i[p]]], sig_genes[[idx_grid$j[p]]]),
    numeric(1)
  ),
  nrow = length(methods), ncol = length(methods),
  dimnames = list(methods, methods)
)

# Visualize overlap
oldpar <- par(mfrow = c(1, 2), mar = c(5, 6, 4, 4))

# Cell centres used by image() along both axes; derived from the number of
# methods rather than hard-coded
cell_pos <- seq(0, 1, length.out = length(methods))

# Jaccard similarity heatmap. zlim is fixed to the full [0, 1] range of the
# index so the colours, and the 0.5 label-colour rule below, are absolute.
image(jaccard_matrix, axes = FALSE, zlim = c(0, 1),
      col = colorRampPalette(c("white", "steelblue", "darkblue"))(100),
      main = "Pairwise Jaccard Similarity")
axis(1, at = cell_pos, labels = methods, las = 2, cex.axis = 0.9)
axis(2, at = cell_pos, labels = methods, las = 1, cex.axis = 0.9)

# image() draws element [i, j] at the i-th x position and j-th y position,
# so rows index x and columns index y when the values are written in
text(
  x = cell_pos[as.vector(row(jaccard_matrix))],
  y = cell_pos[as.vector(col(jaccard_matrix))],
  labels = sprintf("%.2f", jaccard_matrix),
  cex = 1.0,
  col = as.vector(ifelse(jaccard_matrix > 0.5, "white", "black"))
)

# Number of significant genes per method, with the number of true SVGs as a
# reference line; the y-range is extended so that line is always visible
n_detected <- lengths(sig_genes)
n_true <- sum(truth)
barplot(n_detected,
        col = c("#E41A1C", "#377EB8", "#4DAF4A", "#984EA3"),
        ylab = "Number of Significant Genes",
        main = "Number of SVGs Detected",
        ylim = c(0, max(n_detected, n_true, 1) * 1.1),
        las = 1, border = NA)
abline(h = n_true, lty = 2, col = "red", lwd = 2)
legend("topright", legend = sprintf("True SVGs (n=%d)", n_true),
       lty = 2, col = "red", lwd = 2, bty = "n")
par(oldpar)

## ----custom-simulation--------------------------------------------------------
# Fix the random number generator so the simulated dataset is reproducible
set.seed(seed = 2024)

# Simulate a dataset with user-chosen size and pattern settings
sim_data <- simulate_spatial_data(
  n_spots = 400,            # number of spatial locations
  n_genes = 150,            # total number of genes
  n_svg = 30,               # number of true SVGs
  grid_type = "hexagonal",  # spatial arrangement
  # Pattern types to include
  pattern_types = c("gradient", "hotspot", "periodic"),
  mean_counts = 150,        # mean expression level
  dispersion = 2.5          # negative binomial dispersion
)

# Size of the simulated dataset: spots are columns, genes are rows
cat("Custom Simulation Summary:\n")
cat("  Spots:", dim(sim_data$counts)[2], "\n")
cat("  Genes:", dim(sim_data$counts)[1], "\n")
cat("  True SVGs:", sum(sim_data$gene_info$is_svg), "\n")

# Tally of simulated genes per pattern_type label
cat("\nPattern distribution:\n")
print(table(sim_data$gene_info$pattern_type))

## ----parallel-demo------------------------------------------------------------
# Time one MERINGUE run that requests more than one thread
# Note: mclapply doesn't work on Windows; falls back to sequential

n_cores <- 2  # Adjust based on your system

t_start <- Sys.time()
results_parallel <- CalSVG_MERINGUE(
  expr_matrix = expr_log,
  spatial_coords = coords,
  n_threads = n_cores,
  verbose = FALSE
)
t_end <- Sys.time()

# Report the wall-clock time in seconds
cat(sprintf(
  "Parallel execution with %d cores: %.2f seconds\n",
  n_cores,
  as.numeric(difftime(t_end, t_start, units = "secs"))
))

## ----gene-filtering-----------------------------------------------------------
# Pre-filter genes to improve signal-to-noise and reduce computation.
# Every subset uses drop = FALSE so the result stays a matrix even when a
# single gene passes the filter; otherwise it collapses to a vector and
# nrow() returns NULL.

# Strategy 1: Expression threshold - keep genes whose mean is above the
# 10th percentile of gene means
gene_means <- rowMeans(expr_log)
expr_high <- expr_log[gene_means > quantile(gene_means, 0.1), , drop = FALSE]
cat("After expression filter:", nrow(expr_high), "genes\n")

# Strategy 2: Variance filter - applied to the expression-filtered genes,
# keeping those above the 25th percentile of variance
gene_vars <- apply(expr_high, 1, var)
expr_filtered <- expr_high[gene_vars > quantile(gene_vars, 0.25), , drop = FALSE]
cat("After variance filter:", nrow(expr_filtered), "genes\n")

# Strategy 3: Coefficient of variation - also applied to the
# expression-filtered genes; the 0.1 offset keeps the ratio stable for genes
# with a mean close to zero
cv <- apply(expr_high, 1, sd) / (rowMeans(expr_high) + 0.1)
expr_cv <- expr_high[cv > quantile(cv, 0.25), , drop = FALSE]
cat("After CV filter:", nrow(expr_cv), "genes\n")

## ----session-info-------------------------------------------------------------
# Record the R version and loaded packages used to produce this output
utils::sessionInfo()

