Decorate a Bayesian model fit, or a sample from it, with the data types of its variables and dimensions. Meant to be used before calling spread_draws() or gather_draws() so that the values those functions return are translated back into useful data types.

recover_types(model, ...)

Arguments

model

A supported Bayesian model fit. Tidybayes supports a variety of model objects; for a full list of supported models, see tidybayes-models.

...

Lists (or data frames) providing data prototypes used to convert columns returned by spread_draws() and gather_draws() back into useful data types. See Details.

Value

A decorated version of model.

Details

Each argument in ... is a list or data.frame whose named elements serve as prototypes. The model is decorated with a list of constructors, one per name, each of which converts a numeric column into the data type of the corresponding prototype.

Then, when spread_draws() or gather_draws() is called on the decorated model, each list entry whose name matches a variable or a dimension in variable_spec is used as a prototype for that variable or dimension; that is, its type is taken to be the expected type of that variable or dimension. Those types are used to translate numeric values of variables back into useful values (for example, levels of a factor).
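The name-matching step can be pictured with plain data frames. In the minimal sketch below, `draws` stands in for the numeric output of spread_draws() and `constructors` for the list stored on a decorated model; both objects and the helper `apply_constructors()` are hypothetical illustrations, not tidybayes internals.

```r
# Hypothetical stand-in for the numeric columns spread_draws() would return
draws <- data.frame(
  .draw = c(1L, 1L, 2L, 2L),
  condition = c(1L, 2L, 1L, 2L),
  condition_mean = c(0.2, 1.0, 0.1, 0.9)
)

# Hypothetical list of constructors, keyed by variable/dimension name
constructors <- list(
  condition = function(i) factor(i, levels = 1:2, labels = c("A", "B"))
)

# Convert every column whose name has a matching constructor;
# columns without a match (here .draw and condition_mean) are left as-is
apply_constructors <- function(draws, constructors) {
  for (name in intersect(names(draws), names(constructors))) {
    draws[[name]] <- constructors[[name]](draws[[name]])
  }
  draws
}

typed <- apply_constructors(draws, constructors)
str(typed$condition)
```

Only `condition` changes type, because it is the only column with a same-named constructor.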

The most common use of recover_types is to automatically translate dimensions of a variable that correspond to levels of a factor in the original data back into levels of that factor. The simplest way to do this is to pass in the data frame from which the original data came.
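Passing the original data frame works because the integer index of a factor dimension is the position of that level in levels() of the factor. compose_data() converts factors to these integer codes, and as.integer() produces the same coding; the sketch below encodes and then decodes a factor to show why the prototype must have the same levels, in the same order, as the data the model was fit to. The data frame `d` and the mismatched prototype are made up for illustration.

```r
# Made-up data with a factor whose level order is not alphabetical
d <- data.frame(dose = factor(c("low", "high", "mid", "low"),
                              levels = c("low", "mid", "high")))

# Integer coding: each value becomes the position of its level in levels(d$dose)
codes <- as.integer(d$dose)   # 1 3 2 1

# Decoding with the original column as the prototype restores the labels
decoded <- factor(codes, levels = seq_along(levels(d$dose)),
                  labels = levels(d$dose))

# A prototype with the same labels in a different level order mislabels values
wrong_prototype <- factor(c("high", "low", "mid"))   # levels: high, low, mid
mislabeled <- factor(codes, levels = seq_along(levels(wrong_prototype)),
                     labels = levels(wrong_prototype))
```

Supplying `levels = seq_along(...)` explicitly is a choice made here so that indices absent from `codes` would not shift the labels.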

Supported types of prototypes are factor, ordered, and logical. For example:

  • if prototypes$v is a factor, the v column in the returned draws is translated into a factor using factor(v, labels=levels(prototypes$v), ordered=is.ordered(prototypes$v)).

  • if prototypes$v is a logical, the v column is translated into a logical using as.logical(v).
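The two rules above can be checked in base R by applying the documented expressions to numeric columns. `prototypes` and the numeric vectors below are illustrative values. Note that the factor rule, as written, assumes every level's index appears in the column, because factor() called with only `labels` builds its levels from the values actually present.

```r
# Illustrative prototypes: an ordered factor and a logical column
prototypes <- data.frame(
  size = factor(c("S", "M", "L"), levels = c("S", "M", "L"), ordered = TRUE),
  treated = c(TRUE, FALSE, TRUE)
)

# Numeric columns as they might come back from a model (hypothetical values)
size_index <- c(3, 1, 2, 1)
treated_num <- c(1, 0, 0, 1)

# Factor/ordered rule from the documentation
size <- factor(size_index, labels = levels(prototypes$size),
               ordered = is.ordered(prototypes$size))

# Logical rule from the documentation
treated <- as.logical(treated_num)
```

Because the prototype is ordered, the result keeps the ordering, so comparisons such as `size[1] > size[2]` are meaningful.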

Additional data types can be supported by providing a custom implementation of the generic function as_constructor.
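Assuming as_constructor dispatches via S3 like most R generics, supporting a new type means writing a method for that type's class that returns a function from numeric values to the new type. To keep the sketch runnable without tidybayes, it defines its own stand-in generic, `as_constructor_sketch()` (hypothetical). The mapping chosen for the character method, index i to the i-th sorted unique value, is also an assumption for illustration. With tidybayes the method would instead be named `as_constructor.<class>`; check how the generic is exposed in your installed version.

```r
# Stand-in generic mirroring the documented as_constructor() design (hypothetical)
as_constructor_sketch <- function(x) UseMethod("as_constructor_sketch")

# Default: leave values unchanged
as_constructor_sketch.default <- function(x) function(i) i

# Logical: numeric 0/1 back to FALSE/TRUE, matching the documented rule
as_constructor_sketch.logical <- function(x) function(i) as.logical(i)

# Custom type (assumption): a character prototype, where index i is taken
# to mean the i-th value of sort(unique(x))
as_constructor_sketch.character <- function(x) {
  values <- sort(unique(x))
  function(i) values[i]
}

# Build a constructor from a prototype, then apply it to numeric indices
to_site <- as_constructor_sketch(c("north", "south", "east", "north"))
sites <- to_site(c(2, 1, 3))
```

Each method captures what it needs from the prototype (here, the sorted unique values) in a closure, so the returned constructor can later be applied to any numeric column.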

Author

Matthew Kay

Examples

# \donttest{

library(tidybayes)
library(dplyr)
library(magrittr)
library(rstan)
#> Loading required package: StanHeaders
#> rstan (Version 2.21.8, GitRev: 2e1f913d3ca3)
#> For execution on a local, multicore CPU with excess RAM we recommend calling
#> options(mc.cores = parallel::detectCores()).
#> To avoid recompilation of unchanged Stan programs, we recommend calling
#> rstan_options(auto_write = TRUE)
#> 
#> Attaching package: ‘rstan’
#> The following object is masked from ‘package:coda’:
#> 
#>     traceplot
#> The following object is masked from ‘package:magrittr’:
#> 
#>     extract
#> The following objects are masked from ‘package:posterior’:
#> 
#>     ess_bulk, ess_tail

# Here's an example dataset with a categorical predictor (`condition`) with several levels:
set.seed(5)
n = 10
n_condition = 5
ABC = tibble(
  condition = factor(rep(c("A","B","C","D","E"), n)),
  response = rnorm(n * 5, c(0,1,2,1,-1), 0.5)
)

# We'll fit the following model to it:
stan_code = "
  data {
    int<lower=1> n;
    int<lower=1> n_condition;
    int<lower=1, upper=n_condition> condition[n];
    real response[n];
  }
  parameters {
    real overall_mean;
    vector[n_condition] condition_zoffset;
    real<lower=0> response_sd;
    real<lower=0> condition_mean_sd;
  }
  transformed parameters {
    vector[n_condition] condition_mean;
    condition_mean = overall_mean + condition_zoffset * condition_mean_sd;
  }
  model {
    response_sd ~ cauchy(0, 1);       // => half-cauchy(0, 1)
    condition_mean_sd ~ cauchy(0, 1); // => half-cauchy(0, 1)
    overall_mean ~ normal(0, 5);

    //=> condition_mean ~ normal(overall_mean, condition_mean_sd)
    condition_zoffset ~ normal(0, 1);

    for (i in 1:n) {
      response[i] ~ normal(condition_mean[condition[i]], response_sd);
    }
  }
"

m = stan(model_code = stan_code, data = compose_data(ABC), control = list(adapt_delta=0.99),
  # 1 chain / few iterations just so example runs quickly
  # do not use in practice
  chains = 1, iter = 500)
#> 
#> SAMPLING FOR MODEL 'c2d0fdaa698b1cc34d5e522dcb8c1731' NOW (CHAIN 1).
#> Chain 1: 
#> Chain 1: Gradient evaluation took 9e-06 seconds
#> Chain 1: 1000 transitions using 10 leapfrog steps per transition would take 0.09 seconds.
#> Chain 1: Adjust your expectations accordingly!
#> Chain 1: 
#> Chain 1: 
#> Chain 1: Iteration:   1 / 500 [  0%]  (Warmup)
#> Chain 1: Iteration:  50 / 500 [ 10%]  (Warmup)
#> Chain 1: Iteration: 100 / 500 [ 20%]  (Warmup)
#> Chain 1: Iteration: 150 / 500 [ 30%]  (Warmup)
#> Chain 1: Iteration: 200 / 500 [ 40%]  (Warmup)
#> Chain 1: Iteration: 250 / 500 [ 50%]  (Warmup)
#> Chain 1: Iteration: 251 / 500 [ 50%]  (Sampling)
#> Chain 1: Iteration: 300 / 500 [ 60%]  (Sampling)
#> Chain 1: Iteration: 350 / 500 [ 70%]  (Sampling)
#> Chain 1: Iteration: 400 / 500 [ 80%]  (Sampling)
#> Chain 1: Iteration: 450 / 500 [ 90%]  (Sampling)
#> Chain 1: Iteration: 500 / 500 [100%]  (Sampling)
#> Chain 1: 
#> Chain 1:  Elapsed Time: 0.076293 seconds (Warm-up)
#> Chain 1:                0.076287 seconds (Sampling)
#> Chain 1:                0.15258 seconds (Total)
#> Chain 1: 
#> Warning: Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
#> Running the chains for more iterations may help. See
#> https://mc-stan.org/misc/warnings.html#bulk-ess
#> Warning: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
#> Running the chains for more iterations may help. See
#> https://mc-stan.org/misc/warnings.html#tail-ess

# without using recover_types(), the `condition` column returned by spread_draws()
# will be an integer:
m %>%
  spread_draws(condition_mean[condition]) %>%
  median_qi()
#> # A tibble: 5 × 7
#>   condition condition_mean .lower .upper .width .point .interval
#>       <int>          <dbl>  <dbl>  <dbl>  <dbl> <chr>  <chr>    
#> 1         1          0.188 -0.196  0.559   0.95 median qi       
#> 2         2          0.999  0.676  1.36    0.95 median qi       
#> 3         3          1.85   1.49   2.18    0.95 median qi       
#> 4         4          1.04   0.657  1.33    0.95 median qi       
#> 5         5         -0.883 -1.22  -0.523   0.95 median qi       

# If we apply recover_types() first, subsequent calls to other tidybayes functions will
# automatically back-convert factors so that they are labeled with their original levels
# (as long as the dimension is given the same name as the corresponding column in ABC)
m %<>% recover_types(ABC)

# now the `condition` column will be a factor with levels "A", "B", "C", ...
m %>%
  spread_draws(condition_mean[condition]) %>%
  median_qi()
#> # A tibble: 5 × 7
#>   condition condition_mean .lower .upper .width .point .interval
#>   <fct>              <dbl>  <dbl>  <dbl>  <dbl> <chr>  <chr>    
#> 1 A                  0.188 -0.196  0.559   0.95 median qi       
#> 2 B                  0.999  0.676  1.36    0.95 median qi       
#> 3 C                  1.85   1.49   2.18    0.95 median qi       
#> 4 D                  1.04   0.657  1.33    0.95 median qi       
#> 5 E                 -0.883 -1.22  -0.523   0.95 median qi       

# }