## ----setup, include = FALSE---------------------------------------------------
# LOCAL is TRUE only when the environment variable LOCAL is exactly "TRUE"
# (it can be set in ~/.Renviron, opened with
# file.edit(normalizePath("~/.Renviron"))). It is used as the `purl` option
# below and as the `eval` option of the benchmark chunk.
LOCAL <- Sys.getenv("LOCAL") == "TRUE"
# To force it on in an interactive session, uncomment the next line:
# LOCAL <- TRUE

# Global chunk options: purl only when LOCAL, compact "#>"-prefixed output.
knitr::opts_chunk$set(
  purl = LOCAL,
  collapse = TRUE,
  comment = "#>"
)

## ----cache=TRUE, eval=LOCAL---------------------------------------------------
library(gamlss)
library(SelectBoost.gamlss)

set.seed(2025)
# Sample size per family; a larger n makes the timing differences clearer.
n <- 5000

# Family objects under test, keyed by gamlss.dist family name. Each
# constructor is called with its default arguments.
families <- lapply(
  list(
    NO = gamlss.dist::NO,
    PO = gamlss.dist::PO,
    LOGNO = gamlss.dist::LOGNO,
    GA = gamlss.dist::GA,
    IG = gamlss.dist::IG,
    LO = gamlss.dist::LO,
    LOGITNO = gamlss.dist::LOGITNO,
    GEOM = gamlss.dist::GEOM,
    BE = gamlss.dist::BE
  ),
  function(make_family) make_family()
)

# Simulators keyed by the same names as `families`; each returns `n` draws
# from its distribution at fixed parameter values.
gen_fun <- list(
  NO = function(n) {
    gamlss.dist::rNO(n, mu = 0, sigma = 1)
  },
  PO = function(n) {
    gamlss.dist::rPO(n, mu = 2.5)
  },
  LOGNO = function(n) {
    gamlss.dist::rLOGNO(n, mu = 0, sigma = 0.6)
  },
  GA = function(n) {
    gamlss.dist::rGA(n, mu = 2, sigma = 0.5)
  },
  IG = function(n) {
    gamlss.dist::rIG(n, mu = 2, sigma = 0.5)
  },
  LO = function(n) {
    gamlss.dist::rLO(n, mu = 0, sigma = 1)
  },
  LOGITNO = function(n) {
    gamlss.dist::rLOGITNO(n, mu = 0, sigma = 1)
  },
  GEOM = function(n) {
    gamlss.dist::rGEOM(n, mu = 3)
  },
  BE = function(n) {
    gamlss.dist::rBE(n, mu = 0.4, sigma = 0.2)
  }
)

# Benchmark fast vs generic log-likelihood evaluation for one family.
#
# Simulates `n_obs` observations with `gen_fun[[fname]]`, fits an
# intercept-only gamlss model with `families[[fname]]`, and hands the fit to
# `fast_vs_generic_ll()` (presumably exported by SelectBoost.gamlss).
#
# fname: name of an entry present in both `families` and `gen_fun`.
# n_obs: number of simulated observations; defaults to the global `n`.
# reps:  benchmark repetitions forwarded to `fast_vs_generic_ll()`.
#
# Returns the `fast_vs_generic_ll()` result, or NULL when the family or
# generator is missing, the fit fails, or the benchmark errors. A message
# explains each skip, so failures are visible rather than silently dropped.
bench_one <- function(fname, n_obs = n, reps = 50) {
  # Report why this family is skipped and yield the NULL sentinel.
  skip <- function(why) {
    message("bench_one(\"", fname, "\") skipped: ", why)
    NULL
  }

  fam <- families[[fname]]
  gen <- gen_fun[[fname]]
  if (is.null(fam) || is.null(gen)) {
    return(skip("no family or generator defined"))
  }

  dat <- data.frame(y = gen(n_obs))

  fit <- tryCatch(
    gamlss::gamlss(y ~ 1, data = dat, family = fam),
    error = function(e) e
  )
  if (inherits(fit, "error")) {
    return(skip(paste("gamlss fit failed:", conditionMessage(fit))))
  }

  # Predictions are taken on `newdata`; reusing the fitting data is fine.
  # Guarded so one failing family does not abort the whole lapply() run.
  tryCatch(
    fast_vs_generic_ll(fit, newdata = dat, reps = reps),
    error = function(e) {
      skip(paste("benchmark failed:", conditionMessage(e)))
    }
  )
}

# Run every benchmark; families whose benchmark returned NULL are dropped.
res_list <- lapply(names(families), bench_one)
names(res_list) <- names(families)
res_list <- Filter(Negate(is.null), res_list)

# Present results: stack the per-family tables, tagging each with its family.
res <- do.call(rbind, lapply(names(res_list), function(nm) {
  tab <- res_list[[nm]]
  tab$family <- nm
  tab
}))

res
# do.call(rbind, list()) is NULL when no family succeeded; nrow(NULL) is NULL
# and `if (NULL)` would error, so test for that case explicitly.
if (!is.null(res) && nrow(res) > 0) {
  # Plot median times by family (only when timings carry a unit attribute,
  # as microbenchmark output does).
  time_unit <- attr(res_list[[1]], "unit")
  if (!is.null(time_unit)) {
    # Simple grouped barplot: generic vs fast for each family.
    op <- par(mfrow = c(1, 1), mar = c(8, 4, 2, 1))
    fams <- unique(res$family)
    is_fast <- res$method == "fast"
    is_gen <- res$method == "generic"
    med_fast <- tapply(res$median[is_fast], res$family[is_fast], median)
    med_gen <- tapply(res$median[is_gen], res$family[is_gen], median)
    mat <- rbind(med_gen[fams], med_fast[fams])
    barplot(mat, beside = TRUE, las = 2, legend.text = c("generic", "fast"),
            main = "Median time (microbenchmark units)", ylab = time_unit)
    par(op)
  }
} else {
  message("No family could be benchmarked; nothing to plot.")
}

