## ----include = FALSE----------------------------------------------------------
knitr::opts_chunk$set(
  comment = "#>",   # prefix printed in front of every line of chunk output
  collapse = TRUE,  # merge source and output into a single block
  fig.height = 4,   # default figure height
  fig.width = 7     # default figure width
)

## ----setup, message=FALSE-----------------------------------------------------
library(CLRtools)
library(dplyr)
library(dagitty)
library(ggdag)
library(ggplot2)
library(dplyr)
library(rstan)

## ----echo=FALSE---------------------------------------------------------------
# Assumed causal structure: age points into body size (weight, height) and
# into the other risk factors; weight and height jointly determine bmi; and
# several variables point directly into the outcome, fracture.
dag <- dagitty("dag {
  age -> weight
  age -> height
  age -> smoke
  age -> priorfrac
  age -> premeno
  age -> raterisk

  height -> weight
  weight -> bmi
  height -> bmi

  bmi -> armassist
  bmi -> fracture
  priorfrac -> fracture
  premeno -> fracture
  smoke -> fracture
  momfrac -> fracture
  armassist -> fracture
  raterisk -> fracture
}")

# Outcome node drawn in red, every other node in light grey.
node_colours <- c("TRUE" = "#E31A1C", "FALSE" = "lightgray")

ggdag(dag, layout = "sugiyama", text = FALSE, node_size = 8) +
  geom_dag_point(aes(color = (name == "fracture"))) +
  geom_dag_text(
    aes(label = name),
    size = 5,
    fontface = "bold",
    color = "darkblue"
  ) +
  scale_color_manual(values = node_colours) +
  theme_dag() +
  theme(
    legend.position = "none",
    plot.title = element_text(face = "bold", size = 16)
  ) +
  labs(title = "Causal Diagram for Fracture Risk")


## -----------------------------------------------------------------------------
# Load the GLOW500 dataset shipped with CLRtools into the workspace
data("glow500")

# Recode every column whose values are all "Yes"/"No" into a numeric
# indicator (1 = "Yes", 0 = "No"); any other column is returned unchanged.
# Assigning into `glow500[]` keeps the data-frame structure intact.
glow500[] <- lapply(glow500, function(col) {
  is_yes_no <- all(col %in% c("Yes", "No"))
  if (is_yes_no) as.numeric(col == "Yes") else col
})

# Standardize the continuous covariates to mean 0 and standard deviation 1
continuous_vars <- c("age", "weight", "height", "bmi")
glow500[, continuous_vars] <- scale(glow500[, continuous_vars])

# Keep only the modelling variables and convert them to a named list. The same
# list is passed to both Stan models below; each declares only what it uses.
model_vars <- c(
  "age", "weight", "height", "bmi", "priorfrac", "premeno",
  "momfrac", "armassist", "smoke", "raterisk", "fracture"
)
glow500_subset <- glow500[, model_vars]

# Append the number of observations as `N`, which both Stan data blocks declare
glow500_list <- c(as.list(glow500_subset), list(N = nrow(glow500_subset)))

## -----------------------------------------------------------------------------
# Adjustment sets implied by the DAG for the DIRECT effect of weight on
# fracture (the effect not transmitted through descendants such as bmi).
adjustmentSets(
  dag,
  exposure = "weight",
  outcome = "fracture",
  effect = "direct"
)

# Adjustment sets for the TOTAL effect of weight on fracture
# (effect = "total" is dagitty's default, spelled out here for clarity).
adjustmentSets(
  dag,
  exposure = "weight",
  outcome = "fracture",
  effect = "total"
)

## ----eval=FALSE---------------------------------------------------------------
# direct_glow_prior <- "
# data {
#   int<lower=0> N;          // Number of observations for simulation
#   vector[N] age;           // Vector of age values
#   vector[N] weight;        // Vector of weight values
# }
# generated quantities {
#   // Mediator model parameters
#   real aBM = normal_rng(0, 0.7);   // Intercept for BMI model
#   real bBM = normal_rng(0, 0.7);   // Slope for weight in BMI model
#   real<lower=0> sigma = exponential_rng(1);  // Standard deviation for BMI residuals
# 
#   // Fracture model parameters
#   real a = normal_rng(0, 0.7);      // Intercept for fracture model
#   real bAGE = normal_rng(0, 0.7);   // Coefficient for age
#   real bBMI = normal_rng(0, 0.7);   // Coefficient for BMI
#   real bWEIGHT = normal_rng(0, 0.7);// Coefficient for weight
# 
#   vector[N] mu_BMI;    // Predicted BMI values from mediator model
#   vector[N] p;         // Predicted fracture probabilities from outcome model
# 
#   // Loop over all observations to generate predicted BMI and fracture probabilities
#   for (i in 1:N) {
#     mu_BMI[i] = aBM + bBM * weight[i];   // BMI predicted by weight
#     p[i] = inv_logit(a + bAGE * age[i] + bBMI * mu_BMI[i] + bWEIGHT * weight[i]);  // Fracture probability
#   }
# }
# "
# 
# N <- 1000
# weight_seq <- seq(min(glow500_list$weight), max(glow500_list$weight), length.out = N)  # Create sequence of weight values
# age_seq <- seq(min(glow500_list$age), max(glow500_list$age), length.out = N)          # Create sequence of age values
# 
# # Run Stan prior predictive simulation with fixed parameters (no data fitting)
# direct_prior <- stan(
#   model_code = direct_glow_prior,
#   data = list(N = N, weight = weight_seq, age = age_seq),
#   chains = 4,
#   iter = 2000,
#   seed = 1000,
#   algorithm = 'Fixed_param'
# )
# 
# ## Extract parameters
# M_mu_BMI <- extract(direct_prior)$mu_BMI[1:200, ]
# p_sim_direct <- extract(direct_prior)$p[1:200, ]
# 
# # Save the first 200 extracted prior draws (not the full fit) for the CRAN vignette
# saveRDS(list(mu_BMI = M_mu_BMI, p = p_sim_direct), file = "direct_prior_p.rds")

## -----------------------------------------------------------------------------
N <- 1000
# Grid of standardized weight values, built exactly as in the eval=FALSE chunk
# above, so grid point j corresponds to column j of the stored draws.
weight_seq <- seq(
  min(glow500_list$weight),
  max(glow500_list$weight),
  length.out = N
)

# Load the prior predictive draws precomputed by the eval=FALSE chunk above
# (one row per draw, one column per grid point). mustWork = TRUE turns a
# missing file into an explicit error instead of a confusing readRDS("").
direct_prior <- readRDS(
  system.file(
    "extdata", "direct_prior_p.rds",
    package = "CLRtools", mustWork = TRUE
  )
)
M_mu_BMI <- direct_prior$mu_BMI
p_sim_direct <- direct_prior$p

# The draws are plotted against weight_seq, so their columns must line up
stopifnot(
  ncol(M_mu_BMI) == length(weight_seq),
  ncol(p_sim_direct) == length(weight_seq)
)

# NOTE(review): the stored draws were simulated with age increasing alongside
# weight (age_seq in the chunk above), so these curves mix both effects.

# Semi-transparent lines so that overlapping prior draws show their density
draw_col <- alpha("black", 0.1)

### First plot: prior predictive mean BMI against weight, one line per draw
plot(NULL, xlim = range(weight_seq), ylim = c(-6, 6),
     xlab = "Weight", ylab = "mu_BMI")
# matlines() draws one line per column, so transpose to one column per draw
matlines(weight_seq, t(M_mu_BMI), lty = 1, col = draw_col)

### Second plot: prior predictive fracture probability against weight
plot(NULL, xlim = range(weight_seq), ylim = c(0, 1),
     xlab = "Weight", ylab = "Probability")
matlines(weight_seq, t(p_sim_direct), lty = 1, col = draw_col)

## -----------------------------------------------------------------------------
# Direct effect of weight on fracture, adjusting for age and bmi.
# A linear regression of bmi on weight (the mediator model) is fitted jointly;
# its parameters (aBM, bBM, sigma) do not enter the fracture model.
# Uses the `array[N] int` syntax: the older `int x[N]` form was removed in
# Stan 2.33 (the new form needs Stan/rstan >= 2.26).
direct_glow <- "
data {
  int<lower=0> N;                           // Number of observations
  vector[N] age;                            // Age (standardized)
  vector[N] bmi;                            // Body Mass Index (mediator, standardized)
  vector[N] weight;                         // Weight (exposure, standardized)
  array[N] int<lower=0, upper=1> fracture;  // Fracture outcome (binary)
}

parameters {
  // Parameters for outcome model (fracture as outcome)
  real a;               // Intercept for outcome model
  real bWEIGHT;         // Effect of weight on fracture
  real bAGE;            // Effect of age on fracture
  real bBMI;            // Effect of BMI on fracture

  // Parameters for mediator model (BMI as outcome)
  real aBM;             // Intercept for mediator model
  real bBM;             // Effect of weight on BMI
  real<lower=0> sigma;  // Standard deviation of BMI residuals
}

model {
  // Priors for mediator model parameters
  aBM ~ normal(0, 0.7);
  bBM ~ normal(0, 0.7);
  sigma ~ exponential(1);

  // Priors for outcome model parameters
  a ~ normal(0, 0.7);
  bAGE ~ normal(0, 0.7);
  bBMI ~ normal(0, 0.7);
  bWEIGHT ~ normal(0, 0.7);

  // Likelihood for mediator: observed BMI normal around a linear function of weight
  bmi ~ normal(aBM + bBM * weight, sigma);

  // Likelihood for outcome: Bernoulli on the logit scale. Same model as
  // binomial(1, inv_logit(eta)), but numerically stabler and vectorized.
  fracture ~ bernoulli_logit(a + bAGE * age + bBMI * bmi + bWEIGHT * weight);
}

generated quantities {
  vector[N] mu_BMI = aBM + bBM * weight;                                     // Predicted BMI means
  vector[N] p = inv_logit(a + bAGE * age + bBMI * bmi + bWEIGHT * weight);  // Fracture probabilities
  vector[N] log_lik;           // Pointwise log likelihood (for model comparison)
  array[N] int fracture_pred;  // Posterior predictive fracture outcomes

  for (i in 1:N) {
    log_lik[i] = bernoulli_lpmf(fracture[i] | p[i]);
    fracture_pred[i] = bernoulli_rng(p[i]);
  }
}
"
# Fit the Bayesian model for the direct effect using Stan
fit_direct_glow <- stan(
  model_code = direct_glow,
  data = glow500_list,
  chains = 4,
  iter = 2000,
  warmup = 1000,
  seed = 1000
)

## -----------------------------------------------------------------------------
# Extracting posterior predictive samples
O_rep <- rstan::extract(fit_direct_glow)$fracture_pred

summarize_results(
  model = fit_direct_glow, 
  ypredict = O_rep, 
  data = glow500_subset, 
  outcome = 'fracture', 
  var.param = c('a' = 'a', 'age' = 'bAGE', 'bmi' = 'bBMI', 'weight' = 'bWEIGHT', 'aBM' = 'aBM', 'bBM' = 'bBM'), 
  rounding = 6, 
  prob = 0.89, 
  point.est = 'median'
)


## -----------------------------------------------------------------------------
# CLRtools sampler diagnostics for the outcome- and mediator-model coefficients
direct_diag_params <- c("a", "bAGE", "bBMI", "bWEIGHT", "aBM", "bBM")
diagnostic_bayes(model = fit_direct_glow, var.param = direct_diag_params)

## ----eval=FALSE---------------------------------------------------------------
# total_glow_prior <- "
# data {
#   int<lower=0> N;          // Number of observations for simulation
#   vector[N] age;           // Vector of age values
#   vector[N] weight;        // Vector of weight values
#   vector[N] height;        // Vector of height values
# }
# generated quantities {
#   // Fracture model parameters
#   real a = normal_rng(0, 0.7);        // Intercept for fracture model
#   real bAGE = normal_rng(0, 0.7);     // Coefficient for age
#   real bHEIGHT = normal_rng(0, 0.7);  // Coefficient for height
#   real bWEIGHT = normal_rng(0, 0.7);  // Coefficient for weight
# 
#   vector[N] p;
# 
#   // Loop over all observations to generate fracture probabilities
#   for (i in 1:N) {
#     p[i] = inv_logit(a + bAGE * age[i] + bHEIGHT * height[i] + bWEIGHT * weight[i]);
# 
#   }
# }
# "
# N <- 1000
# height_seq <- seq(min(glow500_list$height), max(glow500_list$height), length.out = N)  # Create sequence of height values
# weight_seq <- seq(min(glow500_list$weight), max(glow500_list$weight), length.out = N)  # Create sequence of weight values
# age_seq <- seq(min(glow500_list$age), max(glow500_list$age), length.out = N)           # Create sequence of age values
# 
# # Run Stan prior predictive simulation with fixed parameters (no data fitting)
# total_prior <- stan(
#   model_code = total_glow_prior,
#   data = list(N = N, height = height_seq, weight = weight_seq, age = age_seq),
#   chains = 4,
#   iter = 2000,
#   seed = 1000,
#   algorithm = 'Fixed_param'
# )
# 
# p_sim_total <- extract(total_prior)$p[1:200, ]
# 
# # Save the first 200 extracted prior draws (not the full fit) for the CRAN vignette
# saveRDS(list(p = p_sim_total ), file = "total_prior_p.rds")

## -----------------------------------------------------------------------------
N <- 1000
# Grid of standardized weight values, built exactly as in the eval=FALSE chunk
# above, so grid point j corresponds to column j of the stored draws.
weight_seq <- seq(
  min(glow500_list$weight),
  max(glow500_list$weight),
  length.out = N
)

# Load the prior predictive fracture probabilities precomputed by the
# eval=FALSE chunk above (one row per draw, one column per grid point).
# mustWork = TRUE turns a missing file into an explicit error.
total_prior <- readRDS(
  system.file(
    "extdata", "total_prior_p.rds",
    package = "CLRtools", mustWork = TRUE
  )
)
p_sim_total <- total_prior$p

# The draws are plotted against weight_seq, so their columns must line up
stopifnot(ncol(p_sim_total) == length(weight_seq))

# Plot prior predictive fracture probability curves against weight, one per draw.
# NOTE(review): the draws were simulated with age and height increasing
# alongside weight (age_seq/height_seq above), so each curve mixes all three.
plot(NULL, xlim = range(weight_seq), ylim = c(0, 1),
     xlab = "weight", ylab = "Probability")
# matlines() draws one line per column, so transpose to one column per draw
matlines(weight_seq, t(p_sim_total), lty = 1, col = alpha("black", 0.1))


## -----------------------------------------------------------------------------
# Total effect of weight on fracture, adjusting for age and height.
# bmi is deliberately left out: in the DAG it is a mediator (weight -> bmi ->
# fracture). Check the adjustment set against the total-effect adjustmentSets() output above.
# Uses the `array[N] int` syntax: the older `int x[N]` form was removed in
# Stan 2.33 (the new form needs Stan/rstan >= 2.26).
total_glow <- "
data {
  int<lower=0> N;                           // Number of observations
  vector[N] age;                            // Age (standardized)
  vector[N] height;                         // Height (standardized)
  vector[N] weight;                         // Weight (exposure, standardized)
  array[N] int<lower=0, upper=1> fracture;  // Fracture outcome (binary)
}

parameters {
  // Parameters for outcome model
  real a;          // Intercept for outcome model
  real bWEIGHT;    // Effect of weight on fracture
  real bAGE;       // Effect of age on fracture
  real bHEIGHT;    // Effect of height on fracture
}

model {
  // Priors for outcome model parameters
  a ~ normal(0, 0.7);
  bAGE ~ normal(0, 0.7);
  bWEIGHT ~ normal(0, 0.7);
  bHEIGHT ~ normal(0, 0.7);

  // Likelihood: Bernoulli on the logit scale. Same model as
  // binomial(1, inv_logit(eta)), but numerically stabler and vectorized.
  fracture ~ bernoulli_logit(a + bAGE * age + bHEIGHT * height + bWEIGHT * weight);
}

generated quantities {
  vector[N] p = inv_logit(a + bAGE * age + bHEIGHT * height + bWEIGHT * weight);  // Fracture probabilities
  vector[N] log_lik;           // Pointwise log likelihood (for model comparison)
  array[N] int fracture_pred;  // Posterior predictive fracture outcomes

  for (i in 1:N) {
    log_lik[i] = bernoulli_lpmf(fracture[i] | p[i]);
    fracture_pred[i] = bernoulli_rng(p[i]);
  }
}
"
# Fit the Bayesian model for the total effect using Stan
fit_total_glow <- stan(
  model_code = total_glow,
  data = glow500_list,
  chains = 4,
  iter = 2000,
  warmup = 1000,
  seed = 1000
)

## -----------------------------------------------------------------------------
# Extracting posterior predictive samples
O_rep <- rstan::extract(fit_total_glow)$fracture_pred

summarize_results(
  model = fit_total_glow, 
  ypredict = O_rep, 
  data= glow500_subset, 
  outcome = 'fracture', 
  var.param = c('a'='a', 'age' = "bAGE", 'height'='bHEIGHT', 'weight'='bWEIGHT'), 
  rounding = 6, 
  prob = 0.89, 
  point.est = 'median')

## -----------------------------------------------------------------------------
# CLRtools sampler diagnostics for the total-effect model's coefficients
total_diag_params <- c("a", "bAGE", "bWEIGHT", "bHEIGHT")
diagnostic_bayes(model = fit_total_glow, var.param = total_diag_params)

