
This function adds a postprocessing layer that extracts a point forecast from a distributional forecast. NOTE: With the default arguments, this layer overwrites the .pred column and so discards the distributional information. You should usually call it AFTER layer_quantile_distn(), or set the name argument to a new column name.
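To give a rough idea of what collapsing a quantile forecast to a single number involves, the base R sketch below defines two hypothetical helpers, median_from_quantiles() and mean_from_quantiles(). Neither is part of epipredict. The sketch assumes linear interpolation between the known quantiles and flat tails beyond the outermost ones. epipredict's own median and mean may interpolate and extrapolate differently, which matters most for type = "mean" because the mean depends on the tails.

```r
# Hypothetical helpers (NOT part of epipredict): collapse one quantile
# forecast, given as matching vectors of quantile levels and values,
# to a single point.

# Median: read off the 0.5 quantile, linearly interpolating between the
# neighbouring levels when 0.5 itself is not among them.
median_from_quantiles <- function(values, levels) {
  stopifnot(length(values) == length(levels), !is.unsorted(levels))
  stats::approx(levels, values, xout = 0.5, rule = 2)$y
}

# Mean: uses E[X] = integral of the quantile function Q(p) over (0, 1),
# approximated by a midpoint rule on n equal cells. Q is assumed linear
# between the known quantiles and constant beyond the outermost ones
# (rule = 2), which is an assumption about the unobserved tails.
mean_from_quantiles <- function(values, levels, n = 1000) {
  stopifnot(length(values) == length(levels), !is.unsorted(levels))
  p <- (seq_len(n) - 0.5) / n
  q <- stats::approx(levels, values, xout = p, rule = 2)$y
  mean(q)
}

levels <- c(0.25, 0.5, 0.75)
values <- c(1, 2, 5) # a right-skewed toy forecast
median_from_quantiles(values, levels)
#> [1] 2
mean_from_quantiles(values, levels)
#> [1] 2.75
```

For this right-skewed toy forecast the mean exceeds the median, so the choice of type can change the point forecast noticeably.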

Usage

layer_point_from_distn(
  frosting,
  ...,
  type = c("median", "mean"),
  name = NULL,
  id = rand_id("point_from_distn")
)

Arguments

frosting

a frosting postprocessor

...

Unused; included for consistency with other layers.

type

character. Either "median" (the default) or "mean".

name

character. The name for the output column. The default NULL will overwrite the .pred column, removing the distribution information.

id

a random id string

Value

an updated frosting postprocessor.

Examples

library(dplyr)
library(epipredict)
jhu <- case_death_rate_subset %>%
  filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny"))

r <- epi_recipe(jhu) %>%
  step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
  step_epi_ahead(death_rate, ahead = 7) %>%
  step_epi_naomit()

wf <- epi_workflow(r, quantile_reg(quantile_levels = c(.25, .5, .75))) %>%
  fit(jhu)

f1 <- frosting() %>%
  layer_predict() %>%
  layer_quantile_distn() %>% # adds a .pred_distn column holding selected quantiles
  layer_point_from_distn() %>% # mutate `.pred` to contain only a point prediction
  layer_naomit(.pred)
wf1 <- wf %>% add_frosting(f1)

p1 <- forecast(wf1)
p1
#> An `epi_df` object, 3 x 4 with metadata:
#> * geo_type  = state
#> * time_type = day
#> * as_of     = 2022-05-31 19:08:25.791826
#> 
#> # A tibble: 3 × 4
#>   geo_value time_value .pred        .pred_distn
#> * <chr>     <date>     <dbl>             <dist>
#> 1 ak        2021-12-31 0.167 quantiles(0.12)[2]
#> 2 ca        2021-12-31 0.177 quantiles(0.21)[2]
#> 3 ny        2021-12-31 0.272 quantiles(0.25)[2]

f2 <- frosting() %>%
  layer_predict() %>%
  layer_point_from_distn() %>% # mutate `.pred` to contain only a point prediction
  layer_naomit(.pred)
wf2 <- wf %>% add_frosting(f2)

p2 <- forecast(wf2)
p2
#> An `epi_df` object, 3 x 3 with metadata:
#> * geo_type  = state
#> * time_type = day
#> * as_of     = 2022-05-31 19:08:25.791826
#> 
#> # A tibble: 3 × 3
#>   geo_value time_value .pred
#> * <chr>     <date>     <dbl>
#> 1 ak        2021-12-31 0.167
#> 2 ca        2021-12-31 0.177
#> 3 ny        2021-12-31 0.272
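To keep both the point forecast and the full distribution without calling layer_quantile_distn(), you can give the point forecast its own column through the name argument. The sketch below repeats the setup from the examples above so that it runs on its own; the column name "point" is an arbitrary choice, and the sketch assumes the same package version and data (case_death_rate_subset) as those examples. Following the description of name, .pred should keep the distribution while point holds the medians.

```r
library(dplyr)
library(epipredict)

jhu <- case_death_rate_subset %>%
  filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny"))

r <- epi_recipe(jhu) %>%
  step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
  step_epi_ahead(death_rate, ahead = 7) %>%
  step_epi_naomit()

wf <- epi_workflow(r, quantile_reg(quantile_levels = c(.25, .5, .75))) %>%
  fit(jhu)

# Write the median to a new column, `point`; `.pred` is left untouched,
# so it still holds the distributional forecast
f3 <- frosting() %>%
  layer_predict() %>%
  layer_point_from_distn(name = "point")
wf3 <- wf %>% add_frosting(f3)

p3 <- forecast(wf3)
p3
```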