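grf::quantile_forest() fits random forests in a way that makes it easy to
compute quantile regression forests, that is, to estimate several conditional
quantiles of the outcome from a single fitted forest. This is currently the only
engine provided here, since quantile regression is the typical use case.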
Tuning Parameters
This model has 3 tuning parameters:
mtry: # Randomly Selected Predictors (type: integer, default: see below)
trees: # Trees (type: integer, default: 2000L)
min_n: Minimal Node Size (type: integer, default: 5)
The default for mtry depends on the number of columns in the design matrix.
In grf::quantile_forest() it is min(ceiling(sqrt(ncol(X)) + 20), ncol(X)).
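As a quick check on that formula, the sketch below evaluates it for a few column counts. The helper name grf_default_mtry() is hypothetical: it only restates the documented default and is not a function exported by grf or parsnip.

```r
# Hypothetical helper: restates the documented grf default for mtry,
# where p is the number of columns in the design matrix
grf_default_mtry <- function(p) min(ceiling(sqrt(p) + 20), p)

p_values <- c(2, 10, 26, 27, 28, 100, 400)
sapply(p_values, grf_default_mtry)
```

Under this formula every column is a split candidate until the design matrix has more than 26 columns, so feature subsampling only starts to matter for wider designs.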
For categorical predictors, a one-hot encoding is always used. This makes
splitting efficient, but it has implications for the choice of mtry: a factor
with many levels becomes a large number of columns in the design matrix, so its
indicator columns may be selected frequently as candidates for splits. This
differs from other implementations of random forests. For more details, see the
grf discussion.
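To make the effect concrete, here is a base-R sketch that builds a one-hot design matrix for two numeric predictors plus a 26-level factor. It uses stats::model.matrix() with the intercept removed as a stand-in for the engine's encoding; that substitution is an assumption, and the engine's internal column names may differ.

```r
set.seed(1)
n <- 100
dat <- data.frame(
  x = rnorm(n), z = rnorm(n),
  # levels fixed so all 26 indicator columns exist even if a letter is unsampled
  f = factor(sample(letters, n, replace = TRUE), levels = letters)
)

# One indicator column per level; "- 1" drops the intercept so no level is absorbed
one_hot <- model.matrix(~ f - 1, data = dat)
X <- cbind(x = dat$x, z = dat$z, one_hot)
ncol(X)

# Default mtry from the formula above
default_mtry <- min(ceiling(sqrt(ncol(X)) + 20), ncol(X))
default_mtry
```

Only 2 of the 28 columns are the numeric predictors, so with the default mtry of 26 most candidate columns at each split are indicators of the single factor.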
Translation from parsnip to the original package
rand_forest(
  mode = "regression", # you must specify `mode = "regression"`
  mtry = integer(1),
  trees = integer(1),
  min_n = integer(1)
) %>%
  set_engine("grf_quantiles") %>%
  translate()
#> Random Forest Model Specification (regression)
#>
#> Main Arguments:
#> mtry = integer(1)
#> trees = integer(1)
#> min_n = integer(1)
#>
#> Computational engine: grf_quantiles
#>
#> Model fit template:
#> grf::quantile_forest(X = missing_arg(), Y = missing_arg(), mtry = min_cols(~integer(1),
#> x), num.trees = integer(1), min.node.size = min_rows(~integer(1),
#> x), quantiles = c(0.1, 0.5, 0.9), num.threads = 1L, seed = stats::runif(1,
#> 0, .Machine$integer.max))
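The fit template renames parsnip's arguments for grf: trees becomes num.trees, min_n becomes min.node.size, and mtry and min_n are capped through min_cols() and min_rows() so they cannot exceed the data dimensions. The hypothetical helper below mimics that mapping in plain R as an illustration; it is not parsnip's implementation.

```r
# Hypothetical helper, not part of parsnip or epipredict: maps parsnip's main
# arguments onto grf::quantile_forest() argument names, capping mtry at the
# number of columns and min_n at the number of rows (roughly what min_cols()
# and min_rows() do in the fit template)
translate_to_grf <- function(mtry, trees, min_n, X,
                             quantiles = c(0.1, 0.5, 0.9)) {
  list(
    mtry = min(mtry, ncol(X)),
    num.trees = trees,
    min.node.size = min(min_n, nrow(X)),
    quantiles = quantiles, # engine default shown in the template
    num.threads = 1L       # fixed in the template
  )
}

X <- matrix(rnorm(50 * 3), ncol = 3) # 50 rows, 3 columns
args <- translate_to_grf(mtry = 10L, trees = 2000L, min_n = 5L, X = X)
str(args)
```

The template also draws seed from stats::runif(), so calling set.seed() before fit() should make the seed passed to grf, and hence the fit, reproducible.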
Examples
library(epipredict)
library(grf)
tib <- data.frame(
  y = rnorm(100), x = rnorm(100), z = rnorm(100),
  f = factor(sample(letters[1:3], 100, replace = TRUE))
)
spec <- rand_forest(engine = "grf_quantiles", mode = "regression")
out <- fit(spec, formula = y ~ x + z, data = tib)
predict(out, new_data = tib[1:5, ]) %>%
  pivot_quantiles_wider(.pred)
#> # A tibble: 5 × 3
#> `0.1` `0.5` `0.9`
#> <dbl> <dbl> <dbl>
#> 1 -1.34 -0.323 0.894
#> 2 -1.78 -0.284 1.06
#> 3 -1.78 -0.284 0.976
#> 4 -1.22 -0.0254 1.08
#> 5 -1.28 -0.0308 1.32
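In this example y is drawn from a standard normal independently of x and z, so the predicted 0.1, 0.5 and 0.9 columns are, in effect, estimating the unconditional N(0, 1) quantiles. The sketch below prints those reference values next to empirical quantiles of a fresh simulated sample (the seed is arbitrary) for comparison.

```r
probs <- c(0.1, 0.5, 0.9)

# Quantiles of the distribution y was simulated from
theoretical <- qnorm(probs)

# Empirical quantiles of a fresh sample of the same size as tib
set.seed(123)
empirical <- unname(quantile(rnorm(100), probs))

round(rbind(theoretical, empirical), 2)
```

The forest predictions above are noisy versions of roughly -1.28, 0 and 1.28, which is what one would expect with only 100 training rows and no signal in the predictors.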
# -- adjusting the desired quantiles
spec <- rand_forest(mode = "regression") %>%
  set_engine(engine = "grf_quantiles", quantiles = c(1:9 / 10))
out <- fit(spec, formula = y ~ x + z, data = tib)
predict(out, new_data = tib[1:5, ]) %>%
  pivot_quantiles_wider(.pred)
#> # A tibble: 5 × 9
#> `0.1` `0.2` `0.3` `0.4` `0.5` `0.6` `0.7` `0.8` `0.9`
#> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 -1.34 -1.03 -0.580 -0.372 -0.323 0.205 0.291 0.579 0.894
#> 2 -1.53 -1.03 -0.620 -0.503 -0.284 -0.0144 0.249 0.537 0.976
#> 3 -1.53 -1.03 -0.620 -0.503 -0.284 0.0701 0.249 0.538 0.976
#> 4 -1.22 -0.668 -0.620 -0.372 -0.0826 0.269 0.597 0.775 1.08
#> 5 -1.28 -0.790 -0.448 -0.279 -0.0308 0.230 0.538 0.750 1.32
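Each row above is nondecreasing across quantile levels. A plausible reason is that quantile regression forests, which grf's quantile forests build on, predict by inverting a single weighted empirical CDF of the training responses, so all levels are read off one distribution and cannot cross. The sketch below implements that inversion in base R with made-up weights standing in for the forest weights; weighted_quantile() is a hypothetical helper, not a grf function.

```r
# Hypothetical helper: invert a weighted empirical CDF at each requested level
weighted_quantile <- function(y, w, probs) {
  o <- order(y)
  y <- y[o]
  cdf <- cumsum(w[o]) / sum(w)
  # smallest response whose cumulative weight reaches p
  # (falls back to the largest response if rounding leaves cdf just below p)
  vapply(probs, function(p) y[min(which(cdf >= p), length(y))], numeric(1))
}

set.seed(42)
y <- rnorm(200) # training responses
w <- runif(200) # made-up weights in place of forest weights
q <- weighted_quantile(y, w, probs = 1:9 / 10)
round(q, 3)
```

Since the cumulative weights only increase, a higher level can never select a smaller response, which is the monotonicity seen in the output above.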
# -- a more complicated task
library(dplyr)
dat <- case_death_rate_subset %>%
  filter(time_value > as.Date("2021-10-01"))
rec <- epi_recipe(dat) %>%
  step_epi_lag(case_rate, death_rate, lag = c(0, 7, 14)) %>%
  step_epi_ahead(death_rate, ahead = 7) %>%
  step_epi_naomit()
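This recipe adds lagged copies of case_rate and death_rate (lags 0, 7 and 14) as predictors, a 7-day-ahead copy of death_rate as the outcome, and then drops incomplete rows. Below is a base-R sketch of that feature construction for a single series, assuming one location with contiguous daily rows; epipredict itself shifts by time_value within each geo_value, and lag_ahead_features() is a hypothetical helper.

```r
# Hypothetical helper: lagged predictors plus one lead (ahead) outcome for a
# single, gap-free daily series
lag_ahead_features <- function(x, lags, ahead) {
  n <- length(x)
  # positive k shifts values later (lag), negative k shifts them earlier (lead)
  shift <- function(v, k) {
    if (k >= 0) c(rep(NA, k), v[seq_len(n - k)])
    else c(v[(1 - k):n], rep(NA, -k))
  }
  feats <- lapply(lags, function(l) shift(x, l))
  names(feats) <- paste0("lag_", lags)
  out <- as.data.frame(feats)
  out[[paste0("ahead_", ahead)]] <- shift(x, -ahead)
  out
}

x <- 1:30 # stand-in for 30 days of one signal
feats <- lag_ahead_features(x, lags = c(0, 7, 14), ahead = 7)
train <- na.omit(feats) # drop rows missing a lag or the outcome
head(train, 3)
```

With 30 days of data, only days 15 through 23 have both a 14-day lag and a 7-day-ahead value, so the first 14 and last 7 rows drop out of training.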
frost <- frosting() %>%
  layer_predict() %>%
  layer_threshold(.pred)
spec <- rand_forest(mode = "regression") %>%
  set_engine(engine = "grf_quantiles", quantiles = c(.25, .5, .75))
ewf <- epi_workflow(rec, spec, frost) %>%
  fit(dat) %>%
  forecast()
ewf %>%
  rename(forecast_date = time_value) %>%
  mutate(target_date = forecast_date + 7L) %>%
  pivot_quantiles_wider(.pred)
#> # A tibble: 56 × 6
#> geo_value forecast_date target_date `0.25` `0.5` `0.75`
#> <chr> <date> <date> <dbl> <dbl> <dbl>
#> 1 ak 2021-12-31 2022-01-07 0.196 0.269 0.424
#> 2 al 2021-12-31 2022-01-07 0.153 0.201 0.301
#> 3 ar 2021-12-31 2022-01-07 0.450 0.505 0.601
#> 4 as 2021-12-31 2022-01-07 0 0 0
#> 5 az 2021-12-31 2022-01-07 0.485 0.654 0.860
#> 6 ca 2021-12-31 2022-01-07 0.176 0.218 0.260
#> 7 co 2021-12-31 2022-01-07 0.470 0.579 0.709
#> 8 ct 2021-12-31 2022-01-07 0.347 0.420 0.459
#> 9 dc 2021-12-31 2022-01-07 0.0811 0.230 0.446
#> 10 de 2021-12-31 2022-01-07 0.235 0.385 0.530
#> # ℹ 46 more rows
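The frosting in this example ends with layer_threshold(.pred), which keeps predictions inside a range. The sketch below shows the clamping idea in base R, assuming lower = 0 and upper = Inf (intended to match the layer's defaults as we understand them); clamp() and the quantile values are hypothetical. Because clamping is a nondecreasing map, quantiles that were ordered before stay ordered afterwards, and negative rates become 0.

```r
# Hypothetical helper: clamp values into [lower, upper]
clamp <- function(x, lower = 0, upper = Inf) pmin(pmax(x, lower), upper)

# Made-up quantile predictions for one location, one of them negative
q <- c(`0.25` = -0.04, `0.5` = 0.01, `0.75` = 0.08)
clamp(q)
```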