An implementation of posterior::summarise_draws() for grouped data frames (dplyr::grouped_df objects), such as those returned by dplyr::group_by() and by the various grouped-data-aware functions in tidybayes (for example, spread_draws(), gather_draws(), add_epred_draws(), and add_predicted_draws()). It provides a quick way to compute a variety of summary statistics and convergence diagnostics from draws.

# S3 method for class 'grouped_df'
summarise_draws(.x, ...)

Arguments

.x

A grouped data frame (dplyr::grouped_df object), such as one returned by dplyr::group_by(), where each group's sub-table (ignoring grouping columns) has the structure of a posterior::draws_df() object: ".chain", ".iteration", and ".draw" columns, with the remaining (non-grouping) columns being draws from variables.
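The .x requirements above can be met by any data frame, not only tidybayes output. The following is a minimal sketch that assumes hypothetical toy draws (the toy_draws table, its `group` column, and the variable `x` are invented for illustration) generated with rnorm() and grouped with dplyr::group_by():

```r
library(posterior)
library(dplyr)
library(tidybayes)

set.seed(1234)

# Hypothetical toy draws: 2 chains x 100 iterations for each of two groups.
# Ignoring `group`, each group's rows have the draws_df structure:
# .chain, .iteration, .draw, plus one variable column, `x`.
toy_draws = tibble(
  group = rep(c("a", "b"), each = 200),
  .chain = rep(rep(1:2, each = 100), times = 2),
  .iteration = rep(1:100, times = 4),
  .draw = rep(1:200, times = 2),
  x = rnorm(400, mean = rep(c(0, 5), each = 200))
)

# One summary row per group (and per variable)
toy_summary = toy_draws %>%
  group_by(group) %>%
  summarise_draws(mean, sd, rhat)
toy_summary
```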

...

Name-value pairs of summary or diagnostic functions. The provided names are used as the names of the columns in the result unless the function returns a named vector, in which case the vector's names are used. The functions can be specified in any format supported by rlang::as_function(). See Examples.
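As a sketch of the naming rules for `...`, the call below mixes a named bare function, a one-sided formula whose result is a named vector, and an unnamed bare function. The column name post_mean is an arbitrary choice for illustration; quantile2() is the posterior quantile function, which names its output "q10", "q90", and so on.

```r
library(posterior)
library(dplyr)
library(tidybayes)

d = posterior::example_draws()

theta_summary = d %>%
  spread_draws(theta[i]) %>%
  summarise_draws(
    # named argument: the name becomes the column name
    post_mean = mean,
    # formula lambda returning a named vector: its names ("q10", "q90")
    # become the column names
    ~quantile2(.x, probs = c(0.1, 0.9)),
    # unnamed bare function: the column is named after the function
    ess_bulk
  )
theta_summary
```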

Value

A data frame (actually, a tibble) with all grouping columns from .x, a "variable" column containing variable names from .x, and the remaining columns containing summary statistics and diagnostics.
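The structure of the returned value can be checked directly. This sketch assumes, based on the "# Groups: i [8]" line in the printed examples, that the result keeps the grouping of .x; call dplyr::ungroup() on it if a plain tibble is needed.

```r
library(posterior)
library(dplyr)
library(tidybayes)

d = posterior::example_draws()

theta_summary = d %>%
  spread_draws(theta[i]) %>%
  summarise_draws(mean, sd)

# grouping column first, then `variable`, then one column per summary
names(theta_summary)
tibble::is_tibble(theta_summary)
# grouping carried over from spread_draws()
group_vars(theta_summary)
```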

Details

While posterior::summarise_draws() can operate on tidy data frames of draws in the posterior::draws_df() format, that format does not support grouping columns. This method extends summarise_draws() to grouped data frames by applying posterior::summarise_draws() to each sub-table of .x defined by its groups.
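The per-group strategy described above can be sketched with dplyr::group_modify(). This is an illustration under our own assumptions, not the tidybayes source: summarise_draws_by_group() is a hypothetical helper, and it converts each per-group summary to a plain data frame so that group_modify() can bind the results back together with the group keys.

```r
library(posterior)
library(dplyr)
library(tidybayes)

# Hypothetical helper (not part of tidybayes): applies
# posterior::summarise_draws() to each group of a grouped data frame
summarise_draws_by_group = function(.x, ...) {
  group_modify(.x, function(sub_draws, keys, ...) {
    # group_modify() passes each group's rows without the grouping columns,
    # so sub_draws has the draws_df structure (.chain, .iteration, .draw,
    # plus variable columns); group_modify() re-attaches the keys afterwards
    as.data.frame(posterior::summarise_draws(posterior::as_draws_df(sub_draws), ...))
  }, ...)
}

d = posterior::example_draws()
theta_draws = spread_draws(d, theta[i])

# with no functions supplied, the default summaries and diagnostics are used
theta_summary = summarise_draws_by_group(theta_draws)
theta_summary
```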

See posterior::summarise_draws() for more details on the summary statistics and diagnostics you can use with this function. If you just want point summaries and intervals (not diagnostics), particularly for plotting, see point_interval(), which returns long-format data tables more suitable for that purpose (especially if you want to plot multiple uncertainty levels).
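For comparison, the following sketch shows the point_interval() family on the same draws. median_qi() is the tidybayes shortcut for a median point summary with quantile intervals; passing several values to .width returns one row per index per interval width, which is the long format that suits plotting multiple uncertainty levels.

```r
library(posterior)
library(dplyr)
library(tidybayes)

d = posterior::example_draws()

# 8 indices x 2 interval widths = 16 rows, with .lower, .upper, and .width
# columns in place of summarise_draws()-style summary columns
theta_intervals = d %>%
  spread_draws(theta[i]) %>%
  median_qi(.width = c(0.66, 0.95))
theta_intervals
```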

Author

Matthew Kay

Examples


library(posterior)
library(dplyr)
library(tidybayes)

d = posterior::example_draws()

# The default posterior::summarise_draws() summarises all variables without
# splitting out indices:
summarise_draws(d)
#> # A tibble: 10 × 10
#>    variable  mean median    sd   mad      q5   q95  rhat ess_bulk ess_tail
#>    <chr>    <dbl>  <dbl> <dbl> <dbl>   <dbl> <dbl> <dbl>    <dbl>    <dbl>
#>  1 mu        4.18   4.16  3.40  3.57  -0.854  9.39  1.02     558.     322.
#>  2 tau       4.16   3.07  3.58  2.89   0.309 11.0   1.01     246.     202.
#>  3 theta[1]  6.75   5.97  6.30  4.87  -1.23  18.9   1.01     400.     254.
#>  4 theta[2]  5.25   5.13  4.63  4.25  -1.97  12.5   1.02     564.     372.
#>  5 theta[3]  3.04   3.99  6.80  4.94 -10.3   11.9   1.01     312.     205.
#>  6 theta[4]  4.86   4.99  4.92  4.51  -3.57  12.2   1.02     695.     252.
#>  7 theta[5]  3.22   3.72  5.08  4.38  -5.93  10.8   1.01     523.     306.
#>  8 theta[6]  3.99   4.14  5.16  4.81  -4.32  11.5   1.02     548.     205.
#>  9 theta[7]  6.50   5.90  5.26  4.54  -1.19  15.4   1.00     434.     308.
#> 10 theta[8]  4.57   4.64  5.25  4.89  -3.79  12.2   1.02     355.     146.

# The grouped_df implementation of summarise_draws() in tidybayes can handle
# output from spread_draws(), which is a grouped data table with the indices
# (here, `i`) left as columns:
d %>%
  spread_draws(theta[i]) %>%
  summarise_draws()
#> # A tibble: 8 × 11
#> # Groups:   i [8]
#>       i variable  mean median    sd   mad     q5   q95  rhat ess_bulk ess_tail
#>   <int> <chr>    <dbl>  <dbl> <dbl> <dbl>  <dbl> <dbl> <dbl>    <dbl>    <dbl>
#> 1     1 theta     6.75   5.97  6.30  4.87  -1.23  18.9  1.01     400.     254.
#> 2     2 theta     5.25   5.13  4.63  4.25  -1.97  12.5  1.02     564.     372.
#> 3     3 theta     3.04   3.99  6.80  4.94 -10.3   11.9  1.01     312.     205.
#> 4     4 theta     4.86   4.99  4.92  4.51  -3.57  12.2  1.02     695.     252.
#> 5     5 theta     3.22   3.72  5.08  4.38  -5.93  10.8  1.01     523.     306.
#> 6     6 theta     3.99   4.14  5.16  4.81  -4.32  11.5  1.02     548.     205.
#> 7     7 theta     6.50   5.90  5.26  4.54  -1.19  15.4  1.00     434.     308.
#> 8     8 theta     4.57   4.64  5.25  4.89  -3.79  12.2  1.02     355.     146.

# Summary functions can also be provided, as in posterior::summarise_draws():
d %>%
  spread_draws(theta[i]) %>%
  summarise_draws(median, mad, rhat, ess_tail)
#> # A tibble: 8 × 6
#> # Groups:   i [8]
#>       i variable median   mad  rhat ess_tail
#>   <int> <chr>     <dbl> <dbl> <dbl>    <dbl>
#> 1     1 theta      5.97  4.87  1.01     254.
#> 2     2 theta      5.13  4.25  1.02     372.
#> 3     3 theta      3.99  4.94  1.01     205.
#> 4     4 theta      4.99  4.51  1.02     252.
#> 5     5 theta      3.72  4.38  1.01     306.
#> 6     6 theta      4.14  4.81  1.02     205.
#> 7     7 theta      5.90  4.54  1.00     308.
#> 8     8 theta      4.64  4.89  1.02     146.
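Because every non-grouping column other than .chain, .iteration, and .draw is treated as a variable, derived quantities added with dplyr::mutate() should be summarised alongside the original ones. In this sketch, theta_minus_mu is an illustrative derived variable, not something defined by the example model.

```r
library(posterior)
library(dplyr)
library(tidybayes)

d = posterior::example_draws()

# spread_draws() groups by `i`; mutate() keeps that grouping, so each of the
# 8 groups yields one row for mu, theta, and theta_minus_mu (24 rows total)
derived_summary = d %>%
  spread_draws(mu, theta[i]) %>%
  mutate(theta_minus_mu = theta - mu) %>%
  summarise_draws(mean, sd)
derived_summary
```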