MINI-PROJECT 3: Customizing models

Author

Delphi Tooling Team

Published

April 22, 2026

Introduction

This notebook covers the third mini-project, focusing on building, customizing, and evaluating forecasting models using epipredict.

Load packages

The {InsightNetApr26} package ensures that all required Delphi tools are installed with the correct versions/branches.

# Bootstrap: install `pak` only if it is missing, then use it to install
# InsightNetApr26 when that package is not yet available.
if (!requireNamespace("pak", quietly = TRUE)) install.packages("pak")
if (!requireNamespace("InsightNetApr26", quietly = TRUE)) {
    pak::pkg_install("cmu-delphi/InsightNet-apr-2026/InsightNetApr26")
}

# Run the package's setup check before loading the libraries below.
InsightNetApr26::verify_setup()

library(epidatr)
library(epiprocess)
library(epipredict)
library(dplyr)
library(ggplot2)
library(parsnip)

Initial Setup

We’ll use built-in COVID-19 case and death rates for training.

# covid_case_death_rates: built-in daily COVID-19 case and death rates per 100k
head(covid_case_death_rates)
An `epi_df` object, 6 x 4 with metadata:
* geo_type  = state
* time_type = day
* as_of     = 2023-03-10
Latency (lag from as_of to latest observation by time series):
* lag across all time series = 799 days

# A tibble: 6 × 4
  geo_value time_value case_rate death_rate
  <chr>     <date>         <dbl>      <dbl>
1 ak        2020-12-31      35.9      0.158
2 al        2020-12-31      65.1      0.438
3 ar        2020-12-31      66.0      1.27 
4 as        2020-12-31       0        0    
5 az        2020-12-31      76.8      1.10 
6 ca        2020-12-31      95.9      0.755
# Forecast origin and the four states used throughout the notebook
forecast_date <- as.Date("2021-08-01")
used_locations <- c("mi", "ny", "tx", "pa")

# Training set: the selected states, with every observation dated on or
# before the forecast date
jhu <- filter(
    covid_case_death_rates,
    geo_value %in% used_locations,
    time_value <= forecast_date
)

head(jhu)
An `epi_df` object, 6 x 4 with metadata:
* geo_type  = state
* time_type = day
* as_of     = 2023-03-10
Latency (lag from as_of to latest observation by time series):
* lag across all time series = 798–799 days

# A tibble: 6 × 4
  geo_value time_value case_rate death_rate
  <chr>     <date>         <dbl>      <dbl>
1 mi        2020-12-31      27.9      0.850
2 ny        2020-12-31      58.3      0.704
3 pa        2020-12-31      55.9      1.41 
4 tx        2020-12-31      83.0      0.576
5 mi        2021-01-01      27.9      0.850
6 ny        2021-01-01      61.3      0.735
# Plot the death_rate and case_rate series of the training data
autoplot(jhu, death_rate, case_rate)

Visualization Helper

We define a helper function to create fan charts of the forecasts.

#' Fan chart of quantile forecasts drawn over recent observations
#'
#' Expands the quantile column of `results` into one column per level and
#' draws, for each `geo_value` facet, the observed series from `window_df`,
#' shaded 90% / 80% / 50% bands, the median line, and a dashed vertical line
#' at `fcst_date`.
#'
#' @param results Predictions with `geo_value`, `target_date` and a
#'   `.pred_distn` quantile column containing the levels 0.05, 0.1, 0.25,
#'   0.5, 0.75, 0.9 and 0.95.
#' @param window_df Observed data to overlay; needs `geo_value`,
#'   `time_value` and the column named by `value_name`. Defaults to the
#'   built-in COVID data after 2021-07-01 for `used_locations`.
#' @param value_name Name (string) of the observed column in `window_df`.
#' @param fcst_date Date at which the dashed reference line is drawn.
#' @param title,subtitle Plot title and subtitle.
#' @param y_label Y-axis label. `NULL` (default) keeps the death-rate label
#'   when `value_name` is `"death_rate"` and otherwise falls back to
#'   `value_name`, so charts of other signals are not mislabelled.
#' @return A ggplot object.
fan_chart <- function(results,
                      window_df = covid_case_death_rates |>
                          filter(
                              time_value > as.Date("2021-07-01"),
                              geo_value %in% used_locations
                          ),
                      value_name = "death_rate",
                      fcst_date = forecast_date,
                      title = "",
                      subtitle = "",
                      y_label = NULL) {
    # Only describe the axis as deaths when deaths are what is plotted.
    if (is.null(y_label)) {
        y_label <- if (identical(value_name, "death_rate")) {
            "Deaths per 100k (7-day avg)"
        } else {
            value_name
        }
    }

    results |>
        pivot_quantiles_wider(.pred_distn) |>
        arrange(geo_value, target_date) |>
        ggplot(aes(x = target_date)) +
        geom_vline(
            xintercept = fcst_date, linetype = "dashed",
            color = "gray50", linewidth = 0.5
        ) +
        # Observations come from a different data frame, so this layer must
        # not inherit the forecast aesthetics.
        geom_line(
            data = window_df,
            aes(x = time_value, y = .data[[value_name]]),
            color = "gray20", linewidth = 0.6, inherit.aes = FALSE
        ) +
        # Nested ribbons: 90%, 80% and 50% intervals, darkest in the middle.
        geom_ribbon(aes(ymin = `0.05`, ymax = `0.95`), fill = "dodgerblue4", alpha = 0.10) +
        geom_ribbon(aes(ymin = `0.1`, ymax = `0.9`), fill = "dodgerblue4", alpha = 0.15) +
        geom_ribbon(aes(ymin = `0.25`, ymax = `0.75`), fill = "dodgerblue4", alpha = 0.20) +
        geom_line(aes(y = `0.5`), color = "orange", linewidth = 0.8) +
        facet_wrap(~geo_value, scales = "free_y") +
        labs(
            title = title, subtitle = subtitle,
            x = "Date", y = y_label
        ) +
        theme_bw()
}

Basic Modeling

We start with four basic forecasters: ARX, flatline, the CDC baseline, and a climatological model.

ARX Forecaster

Autoregressive model using lagged values as features.

# Fit the ARX forecaster on death_rate alone, leaving every other argument
# at its default; `$predictions` holds one row per location.
fcst_arx <- arx_forecaster(jhu, outcome = "death_rate")
fcst_arx
fcst_arx$predictions
# A tibble: 4 × 5
  geo_value  .pred .pred_distn forecast_date target_date
  <chr>      <dbl>   <qtls(7)> <date>        <date>     
1 mi        0.0691    [0.0691] 2021-08-01    2021-08-08 
2 ny        0.0529    [0.0529] 2021-08-01    2021-08-08 
3 pa        0.0488    [0.0488] 2021-08-01    2021-08-08 
4 tx        0.141      [0.141] 2021-08-01    2021-08-08 
# Inspect fitted coefficients
# It shows how each lag contributes to the forecast
hardhat::extract_fit_engine(fcst_arx$epi_workflow)

Call:
stats::lm(formula = ..y ~ ., data = data)

Coefficients:
      (Intercept)   lag_0_death_rate   lag_7_death_rate  lag_14_death_rate  
          0.02366            0.85465            0.21921           -0.21633  
# Use pivot_quantiles_wider to expand a quantile distribution column
# into individual columns
pivot_quantiles_wider(fcst_arx$predictions, .pred_distn)
# A tibble: 4 × 11
  geo_value  .pred forecast_date target_date `0.05`  `0.1` `0.25`  `0.5` `0.75`
  <chr>      <dbl> <date>        <date>       <dbl>  <dbl>  <dbl>  <dbl>  <dbl>
1 mi        0.0691 2021-08-01    2021-08-08  0      0      0.0404 0.0691 0.0978
2 ny        0.0529 2021-08-01    2021-08-08  0      0      0.0242 0.0529 0.0817
3 pa        0.0488 2021-08-01    2021-08-08  0      0      0.0201 0.0488 0.0775
4 tx        0.141  2021-08-01    2021-08-08  0.0164 0.0670 0.112  0.141  0.169 
# ℹ 2 more variables: `0.9` <dbl>, `0.95` <dbl>

Flatline Forecaster

Predicts the most recently observed value with increasingly wide quantiles.

# Flatline baseline on death_rate with default arguments; the predictions
# printed below target 7 days after the forecast date.
fcst_flat <- flatline_forecaster(jhu, outcome = "death_rate")
fcst_flat$predictions
# A tibble: 4 × 5
  geo_value  .pred .pred_distn forecast_date target_date
  <chr>      <dbl>   <qtls(7)> <date>        <date>     
1 mi        0.0573    [0.0573] 2021-08-01    2021-08-08 
2 ny        0.0347    [0.0347] 2021-08-01    2021-08-08 
3 pa        0.0324    [0.0324] 2021-08-01    2021-08-08 
4 tx        0.136      [0.136] 2021-08-01    2021-08-08 

CDC Baseline Forecaster

CDC FluSight flat-line baseline.

# CDC baseline on daily death rates at horizons of 1, 8, 15 and 22 days
fcst_cdc <- cdc_baseline_forecaster(
    jhu,
    outcome = "death_rate",
    args_list = cdc_baseline_args_list(
        aheads = c(1, 8, 15, 22),
        data_frequency = "1 day"
    )
)
# Show only the 8-day-ahead rows (forecast date 2021-08-01 plus 8 days)
filter(fcst_cdc$predictions, target_date == "2021-08-09")
# A tibble: 4 × 6
  geo_value  .pred ahead .pred_distn forecast_date target_date
  <chr>      <dbl> <dbl>  <qtls(23)> <date>        <date>     
1 mi        0.0573     8    [0.0573] 2021-08-01    2021-08-09 
2 ny        0.0347     8    [0.0347] 2021-08-01    2021-08-09 
3 pa        0.0324     8    [0.0324] 2021-08-01    2021-08-09 
4 tx        0.136      8     [0.136] 2021-08-01    2021-08-09 
# Fan chart of the CDC baseline. The horizons requested above are
# seq(1, 28, 7) = 1, 8, 15, 22 days, so the title states 1–22 days.
fan_chart(
    fcst_cdc$predictions,
    title = "CDC baseline forecast: 1–22 day horizon",
    subtitle = "Flat median; quantiles widen via repeated sampling | dashed = forecast date"
)

Climatological Forecaster

Predicts the median and quantiles based on historical values around the same date in previous years.

# Fetch FluView data for the ten HHS regions. The epiweek range is far wider
# than any real data — presumably to request every available week; confirm
# against the API docs.
fluview_hhs <- pub_fluview(
    regions = paste0("hhs", 1:10),
    epiweeks = epirange(100001, 222201)
)
# Rename to the geo_value/time_value columns, keep `issue` (presumably picked
# up as the archive version — TODO confirm) and the `ili` signal, build an
# archive, then take a single snapshot from it.
fluview <- fluview_hhs |>
    select(
        geo_value = region,
        time_value = epiweek,
        issue,
        ili
    ) |>
    as_epi_archive() |>
    epix_as_of_current()

# Train only on weeks before the forecast date and produce one forecast per
# region and horizon (the 50 rows printed below = 10 regions x 5 horizons).
# NOTE(review): with time_type = "week" the horizons are presumably counted
# in weeks, so seq(0, 28, 7) would mean 0, 7, ..., 28 weeks ahead rather
# than days — confirm this is intended.
fcst_climate <- climatological_forecaster(
    fluview |> filter(time_value < "2023-10-08"),
    outcome = "ili",
    args_list = climate_args_list(
        forecast_horizon = seq(0, 28, 7),
        time_type = "week",
        quantile_by_key = "geo_value",
        forecast_date = as.Date("2023-10-08")
    )
)
fcst_climate$predictions
# A tibble: 50 × 5
   geo_value forecast_date target_date .pred .pred_distn
   <chr>     <date>        <date>      <dbl>   <qtls(7)>
 1 hhs1      2023-10-08    2023-10-08  0.658     [0.648]
 2 hhs10     2023-10-08    2023-10-08  1.03       [1.03]
 3 hhs2      2023-10-08    2023-10-08  1.46       [1.45]
 4 hhs3      2023-10-08    2023-10-08  1.37       [1.37]
 5 hhs4      2023-10-08    2023-10-08  1.30       [1.27]
 6 hhs5      2023-10-08    2023-10-08  1.16       [1.14]
 7 hhs6      2023-10-08    2023-10-08  1.90        [1.9]
 8 hhs7      2023-10-08    2023-10-08  0.927     [0.938]
 9 hhs8      2023-10-08    2023-10-08  0.845     [0.834]
10 hhs9      2023-10-08    2023-10-08  1.42        [1.4]
# ℹ 40 more rows
# Overlay the climatological forecast on the observed ILI series from the
# forecast date up to (not including) 2024-05-01. geo_value becomes a factor
# with levels hhs1..hhs10 — presumably to keep the regions in numeric order.
fan_chart(
    results = fcst_climate$predictions,
    window_df = fluview |>
        filter(time_value >= "2023-10-08", time_value < "2024-05-01") |>
        mutate(geo_value = factor(geo_value, levels = paste0("hhs", 1:10))),
    value_name = "ili",
    fcst_date = as.Date("2023-10-08")
)

Model Customization

We can customize the training engine and hyperparameters.

# Swap the training engine via the trainer:
# linear_reg() — default
# quantile_reg() — optimizes quantile loss directly

# Shared settings for both fits: lags 0/7/14 and a 28-day-ahead target.
# Despite its name, `lags` holds the full argument list, not just the lags.
lags <- arx_args_list(lags = c(0, 7, 14), ahead = 28)

fcst_linear <- arx_forecaster(jhu, "death_rate", "death_rate",
    trainer = linear_reg(),
    args_list = lags
)
fcst_qr <- arx_forecaster(jhu, "death_rate", "death_rate",
    trainer = quantile_reg(),
    args_list = lags
)

# Compare engines: 80% interval + median, color-coded by model
bind_rows(
    fcst_linear$predictions |> mutate(model = "Linear"),
    fcst_qr$predictions |> mutate(model = "Quantile reg")
) |>
    pivot_quantiles_wider(.pred_distn) |>
    filter(geo_value %in% used_locations) |>
    ggplot(aes(x = target_date, color = model, fill = model)) +
    geom_vline(
        xintercept = forecast_date, linetype = "dashed",
        color = "gray50", linewidth = 0.5
    ) +
    # Observed death rates for 30 days either side of the forecast date;
    # drawn from the full data set so the post-forecast truth is visible.
    geom_line(
        data = filter(
            covid_case_death_rates,
            geo_value %in% used_locations,
            time_value > forecast_date - 30,
            time_value <= forecast_date + 30
        ),
        aes(x = time_value, y = death_rate),
        color = "black", linewidth = 0.5, inherit.aes = FALSE
    ) +
    # Both models share the same target date, so dodge them side by side.
    geom_linerange(
        aes(ymin = `0.1`, ymax = `0.9`),
        linewidth = 4, alpha = 0.3,
        position = position_dodge(width = 1.5)
    ) +
    geom_point(aes(y = .pred),
        size = 3,
        position = position_dodge(width = 1.5)
    ) +
    scale_color_brewer(palette = "Set1") +
    scale_fill_brewer(palette = "Set1") +
    facet_wrap(~geo_value, scales = "free_y") +
    labs(
        # Keep the stated horizon in sync with `ahead = 28` in `lags` above.
        title = "28-day ahead forecast: model comparison",
        subtitle = "Points = median; bars = 80% prediction interval | dashed = forecast date",
        x = "Date", y = "Deaths per 100k (7-day avg)", color = "Model", fill = "Model"
    ) +
    theme_bw() +
    theme(legend.position = "bottom")

Multi-Signal and Trajectory Forecasting

We can use multiple predictors and forecast over a range of horizons.

# One quantile-regression ARX fit per horizon (0, 7, 14, 21 and 28 days
# ahead), with case-rate lags 0-3/7/14 and death-rate lags 0/7/14. The
# per-horizon predictions are then stacked and drawn as a fan chart.
all_arx <- lapply(seq(0, 28, 7), function(h) {
    horizon_args <- arx_args_list(
        lags  = list(c(0, 1, 2, 3, 7, 14), c(0, 7, 14)),
        ahead = h
    )
    arx_forecaster(
        jhu,
        outcome = "death_rate",
        predictors = c("case_rate", "death_rate"),
        trainer = quantile_reg(),
        args_list = horizon_args
    )
})
results_arx <- bind_rows(lapply(all_arx, function(fit) fit$predictions))
fan_chart(
    results_arx,
    title    = "ARX forecast: 0–28 day horizon",
    subtitle = "Predictors: case rate + death rate | bands = 90% / 80% / 50% | dashed = forecast date"
)

Growth Rate Classification

Finally, we can use arx_classifier to predict categories of growth rates.

# Classify death_rate 14 days ahead (ahead = 2 * 7) into two bins split at
# 0.25 / 7 (about 0.0357, the cut point visible in `.pred_class` below).
# NOTE(review): this filter uses `<`, so the forecast date itself is
# excluded and the reported forecast_date is 2021-07-31, whereas `jhu`
# uses `<=` — confirm the difference is intentional.
classifier <- arx_classifier(
    covid_case_death_rates |>
        filter(geo_value %in% used_locations, time_value < forecast_date),
    outcome = "death_rate",
    predictors = c("death_rate", "case_rate"),
    args_list = arx_class_args_list(
        lags   = list(c(0, 1, 2, 3, 7, 14), c(0, 7, 14)),
        ahead  = 2 * 7,
        breaks = 0.25 / 7
    )
)

classifier$predictions
# A tibble: 4 × 4
  geo_value .pred_class   forecast_date target_date
  <chr>     <fct>         <date>        <date>     
1 mi        (-Inf,0.0357] 2021-07-31    2021-08-14 
2 ny        (0.0357, Inf] 2021-07-31    2021-08-14 
3 pa        (0.0357, Inf] 2021-07-31    2021-08-14 
4 tx        (0.0357, Inf] 2021-07-31    2021-08-14