grf::quantile_forest() fits a random forest designed for estimating conditional quantiles. Currently, this is the only engine provided here, since quantile regression is the typical use case.

Tuning Parameters

This model has 3 tuning parameters:

  • mtry: # Randomly Selected Predictors (type: integer, default: see below)

  • trees: # Trees (type: integer, default: 2000L)

  • min_n: Minimal Node Size (type: integer, default: 5)

mtry depends on the number of columns in the design matrix. The default in grf::quantile_forest() is min(ceiling(sqrt(ncol(X)) + 20), ncol(X)).
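
As a quick illustration of how that default behaves, the sketch below evaluates the formula quoted above for a few design-matrix widths. The helper name default_mtry() is hypothetical (it is not part of grf or parsnip); only base R is used.

```r
# Hypothetical helper: evaluates the default mtry formula quoted above
# for a design matrix with p columns. Not part of grf or parsnip.
default_mtry <- function(p) {
  pmin(ceiling(sqrt(p) + 20), p)
}

p <- c(5, 26, 27, 100, 1000)
mtry_by_width <- data.frame(p = p, mtry = default_mtry(p))
print(mtry_by_width)
# With 26 or fewer columns, the cap ncol(X) applies and every column is a
# candidate at each split; from 27 columns on, mtry grows like sqrt(p) + 20.
```

In other words, for narrow design matrices the default does no random subsetting of predictors at all, which is worth keeping in mind when tuning mtry.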

For categorical predictors, a one-hot encoding is always used. This makes splitting efficient, but has implications for the choice of mtry. A factor with many levels becomes a large number of columns in the design matrix, so columns derived from that factor may be selected frequently as candidate splits. This differs from other implementations of random forests. For more details, see the grf discussion.
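
To make the effect on the design matrix concrete, here is a minimal base R sketch. It assumes a simulated data frame (dat, with two numeric predictors and one 26-level factor) and builds the indicator columns with model.matrix(); this only illustrates one-hot encoding in general and is not necessarily how the engine constructs its matrix internally.

```r
set.seed(1)
n <- 100

# Simulated data (an assumption for illustration): two numeric predictors
# and one factor with 26 levels.
dat <- data.frame(
  x = rnorm(n),
  z = rnorm(n),
  f = factor(sample(letters, n, replace = TRUE), levels = letters)
)

# One indicator column per factor level (no intercept, so no level is dropped).
f_onehot <- model.matrix(~ f - 1, data = dat)

# Combine with the numeric predictors to form the full design matrix.
X <- cbind(x = dat$x, z = dat$z, f_onehot)

ncol(X) # 28 columns: 2 numeric + 26 indicators

# Share of columns that come from the single factor.
share_from_factor <- ncol(f_onehot) / ncol(X)
share_from_factor
```

Here one factor accounts for 26 of the 28 columns, so a random draw of mtry columns is dominated by indicators for that factor.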

Translation from parsnip to the original package

rand_forest(
  mode = "regression", # you must specify `mode = "regression"`
  mtry = integer(1),
  trees = integer(1),
  min_n = integer(1)
) %>%
  set_engine("grf_quantiles") %>%
  translate()
#> Random Forest Model Specification (regression)
#>
#> Main Arguments:
#>   mtry = integer(1)
#>   trees = integer(1)
#>   min_n = integer(1)
#>
#> Computational engine: grf_quantiles
#>
#> Model fit template:
#> grf::quantile_forest(X = missing_arg(), Y = missing_arg(), mtry = min_cols(~integer(1),
#>     x), num.trees = integer(1), min.node.size = min_rows(~integer(1),
#>     x), quantiles = c(0.1, 0.5, 0.9), num.threads = 1L, seed = stats::runif(1,
#>     0, .Machine$integer.max))
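
For orientation, the fit template above corresponds roughly to calling grf directly, as in the sketch below. It assumes the grf package is installed and uses simulated X and Y (hypothetical data); the fixed seed and the values passed for mtry, num.trees and min.node.size are illustrative choices, not what parsnip would fill in. The return type of predict() has changed across grf versions, so the code handles both a list with a predictions element and a bare matrix.

```r
library(grf)

set.seed(1)
n <- 200

# Simulated predictors and response (assumptions for illustration).
X <- matrix(rnorm(n * 3), ncol = 3, dimnames = list(NULL, c("x1", "x2", "x3")))
Y <- X[, 1] + rnorm(n)

# Direct call mirroring the argument names in the fit template.
qf <- quantile_forest(
  X = X, Y = Y,
  mtry = min(ceiling(sqrt(ncol(X)) + 20), ncol(X)),
  num.trees = 500,
  min.node.size = 5,
  quantiles = c(0.1, 0.5, 0.9),
  num.threads = 1L,
  seed = 1
)

# Recent grf versions return a list with a `predictions` matrix;
# older versions returned the matrix itself.
p <- predict(qf, newdata = X[1:5, , drop = FALSE], quantiles = c(0.1, 0.5, 0.9))
preds <- if (is.list(p)) p$predictions else p
dim(preds) # one row per new observation, one column per quantile
```

Going through rand_forest() and set_engine() instead adds the parsnip conveniences visible in the template, such as capping mtry and min_n via min_cols() and min_rows().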

Case weights

Case weights are not supported.

Examples

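library(grf)
library(epipredict) # attaches parsnip; provides the "grf_quantiles" engine and pivot_quantiles_wider()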
tib <- data.frame(
  y = rnorm(100), x = rnorm(100), z = rnorm(100),
  f = factor(sample(letters[1:3], 100, replace = TRUE))
)
spec <- rand_forest(engine = "grf_quantiles", mode = "regression")
out <- fit(spec, formula = y ~ x + z, data = tib)
predict(out, new_data = tib[1:5, ]) %>%
  pivot_quantiles_wider(.pred)
#> # A tibble: 5 × 3
#>   `0.1`   `0.5` `0.9`
#>   <dbl>   <dbl> <dbl>
#> 1 -1.34 -0.323  0.894
#> 2 -1.78 -0.284  1.06 
#> 3 -1.78 -0.284  0.976
#> 4 -1.22 -0.0254 1.08 
#> 5 -1.28 -0.0308 1.32 

# -- adjusting the desired quantiles

spec <- rand_forest(mode = "regression") %>%
  set_engine(engine = "grf_quantiles", quantiles = c(1:9 / 10))
out <- fit(spec, formula = y ~ x + z, data = tib)
predict(out, new_data = tib[1:5, ]) %>%
  pivot_quantiles_wider(.pred)
#> # A tibble: 5 × 9
#>   `0.1`  `0.2`  `0.3`  `0.4`   `0.5`   `0.6` `0.7` `0.8` `0.9`
#>   <dbl>  <dbl>  <dbl>  <dbl>   <dbl>   <dbl> <dbl> <dbl> <dbl>
#> 1 -1.34 -1.03  -0.580 -0.372 -0.323   0.205  0.291 0.579 0.894
#> 2 -1.53 -1.03  -0.620 -0.503 -0.284  -0.0144 0.249 0.537 0.976
#> 3 -1.53 -1.03  -0.620 -0.503 -0.284   0.0701 0.249 0.538 0.976
#> 4 -1.22 -0.668 -0.620 -0.372 -0.0826  0.269  0.597 0.775 1.08 
#> 5 -1.28 -0.790 -0.448 -0.279 -0.0308  0.230  0.538 0.750 1.32 

# -- a more complicated task

library(dplyr)
dat <- case_death_rate_subset %>%
  filter(time_value > as.Date("2021-10-01"))
rec <- epi_recipe(dat) %>%
  step_epi_lag(case_rate, death_rate, lag = c(0, 7, 14)) %>%
  step_epi_ahead(death_rate, ahead = 7) %>%
  step_epi_naomit()
frost <- frosting() %>%
  layer_predict() %>%
  layer_threshold(.pred)
spec <- rand_forest(mode = "regression") %>%
  set_engine(engine = "grf_quantiles", quantiles = c(.25, .5, .75))

ewf <- epi_workflow(rec, spec, frost) %>%
  fit(dat) %>%
  forecast()
ewf %>%
  rename(forecast_date = time_value) %>%
  mutate(target_date = forecast_date + 7L) %>%
  pivot_quantiles_wider(.pred)
#> # A tibble: 56 × 6
#>    geo_value forecast_date target_date `0.25` `0.5` `0.75`
#>    <chr>     <date>        <date>       <dbl> <dbl>  <dbl>
#>  1 ak        2021-12-31    2022-01-07  0.196  0.269  0.424
#>  2 al        2021-12-31    2022-01-07  0.153  0.201  0.301
#>  3 ar        2021-12-31    2022-01-07  0.450  0.505  0.601
#>  4 as        2021-12-31    2022-01-07  0      0      0    
#>  5 az        2021-12-31    2022-01-07  0.485  0.654  0.860
#>  6 ca        2021-12-31    2022-01-07  0.176  0.218  0.260
#>  7 co        2021-12-31    2022-01-07  0.470  0.579  0.709
#>  8 ct        2021-12-31    2022-01-07  0.347  0.420  0.459
#>  9 dc        2021-12-31    2022-01-07  0.0811 0.230  0.446
#> 10 de        2021-12-31    2022-01-07  0.235  0.385  0.530
#> # ℹ 46 more rows