Before we jump into the technicalities: This post is, of course, dedicated to McElreath, who wrote one of the most intriguing books on Bayesian (or should we just say: scientific?) modeling we’re aware of. If you haven’t read Statistical Rethinking and are interested in modeling, you should definitely check it out. In this post, we’re not going to try to re-tell the story: Our clear focus will, instead, be a demonstration of how to do MCMC with tfprobability.¹
Concretely, this post has two parts. The first is a quick overview of how to use tfd_joint_distribution_sequential to construct a model, and then sample from it using Hamiltonian Monte Carlo. This part can be consulted for quick code look-up, or as a frugal template of the whole process. The second part then walks through a multi-level model in more detail, showing how to extract, post-process and visualize sampling results as well as diagnostic outputs.
Reedfrogs
The data comes with the rethinking package.
'data.frame': 48 obs. of 5 variables:
$ density : int 10 10 10 10 10 10 10 10 10 10 ...
$ pred : Factor w/ 2 levels "no","pred": 1 1 1 1 1 1 1 1 2 2 ...
$ size : Factor w/ 2 levels "big","small": 1 1 1 1 2 2 2 2 1 1 ...
$ surv : int 9 10 7 10 9 9 10 9 4 9 ...
$ propsurv: num 0.9 1 0.7 1 0.9 0.9 1 0.9 0.4 0.9 ...
The task is modeling survivor counts among tadpoles, where tadpoles are held in tanks of different sizes (equivalently, different numbers of inhabitants). Each row in the dataset describes one tank, with its initial count of inhabitants (density) and number of survivors (surv).
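As a quick consistency check on these columns, using only the first ten rows visible in the str() output above: propsurv is simply surv divided by density, one value per tank.

```r
# First ten rows of reedfrogs, as shown in the str() output above
density <- rep(10L, 10)
surv    <- c(9L, 10L, 7L, 10L, 9L, 9L, 10L, 9L, 4L, 9L)
propsurv_shown <- c(0.9, 1, 0.7, 1, 0.9, 0.9, 1, 0.9, 0.4, 0.9)

# One row per tank: the proportion of survivors is surv / density
propsurv <- surv / density
all.equal(propsurv, propsurv_shown)  # TRUE
```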
In the technical overview part, we build a simple unpooled model that describes every tank in isolation. Then, in the detailed walk-through, we’ll see how to construct a varying intercepts model that allows for information sharing between tanks.
Constructing models with tfd_joint_distribution_sequential
tfd_joint_distribution_sequential represents a model as a list of conditional distributions.
This is easiest to see on a real example, so we’ll jump right in, creating an unpooled model of the tadpole data.
This is how the model specification would look in Stan:
model{
vector[48] p;
a ~ normal( 0 , 1.5 );
for ( i in 1:48 ) {
p[i] = a[tank[i]];
p[i] = inv_logit(p[i]);
}
S ~ binomial( N , p );
}
And here is the same model with tfd_joint_distribution_sequential:
The model consists of two distributions: tfd_multivariate_normal_diag specifies the prior for the 48 per-tank logits (one mean and one scale per tank); then tfd_binomial generates survival counts for each tank, given those logits.
Note how the first distribution is unconditional, while the second depends on the first. Note too how the second has to be wrapped in tfd_independent to avoid wrong broadcasting: it makes the 48 per-tank binomials count as a single event, so that their log probabilities are summed over tanks instead of being returned per tank. (This is an aspect of tfd_joint_distribution_sequential usage that deserves to be documented more systematically, which is surely going to happen.² Just consider that this functionality was added to TFP master only three weeks ago!)
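To make the broadcasting issue concrete, here is a base R analogue of the shapes involved (plain R, not TFP), using two samples and, for brevity, the first ten tanks from the str() output. Per-tank binomial log probabilities form a 2 x 10 matrix. Adding the prior's two per-sample values to it silently recycles into yet another 2 x 10 matrix, while summing over tanks first (which is what treating the tanks as one event amounts to) yields the intended one number per sample.

```r
set.seed(1)
n_samples <- 2                 # two joint samples
n_start <- rep(10, 10)         # tank sizes of the first ten tanks
n_tanks <- length(n_start)

# Per-tank logits from the Normal(0, 1.5) prior: a 2 x 10 matrix
l <- matrix(rnorm(n_samples * n_tanks, 0, 1.5), nrow = n_samples)
# Tank sizes laid out like l (each row holds all tank sizes)
size <- matrix(n_start, nrow = n_samples, ncol = n_tanks, byrow = TRUE)
# Simulated survival counts, same layout
s <- matrix(rbinom(length(l), size = size, prob = plogis(l)), nrow = n_samples)

# Prior log density of each sample's logits: one number per sample
prior_lp <- rowSums(matrix(dnorm(l, 0, 1.5, log = TRUE), nrow = n_samples))
# Per-tank binomial log probabilities: a 2 x 10 "batch"
lik_lp <- matrix(dbinom(s, size = size, prob = plogis(l), log = TRUE),
                 nrow = n_samples)

# Without reducing over tanks, adding the two silently recycles into 2 x 10
wrong <- prior_lp + lik_lp
# Treating the tanks as one event (the tfd_independent idea): sum, then add
right <- prior_lp + rowSums(lik_lp)
dim(wrong)     # 2 10
length(right)  # 2
```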
As an aside, the model specification here ends up shorter than in Stan, as tfd_binomial can optionally be parameterized by logits, which makes an explicit inv_logit step unnecessary.
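The logit parameterization is easy to verify numerically. The base R sketch below uses a helper of our own (binom_logpmf_logits, not part of tfprobability) that writes the binomial log probability directly in terms of the logit l, as log C(n, k) + k*l - n*log(1 + exp(l)), and compares it with dbinom() applied to plogis(l), R's inverse logit. The counts are taken from the first rows of the data; the logits are arbitrary.

```r
# Binomial log-pmf parameterized by the logit l instead of the probability p
binom_logpmf_logits <- function(k, n, l) {
  softplus <- pmax(l, 0) + log1p(exp(-abs(l)))  # numerically stable log(1 + exp(l))
  lchoose(n, k) + k * l - n * softplus
}

k <- c(9, 10, 7, 4)          # survivors (values from the first rows of the data)
n <- c(10, 10, 10, 10)       # tank sizes
l <- c(2.1, 3.2, 1.0, -0.5)  # arbitrary logits for illustration

via_logits <- binom_logpmf_logits(k, n, l)
via_probs  <- dbinom(k, size = n, prob = plogis(l), log = TRUE)
all.equal(via_logits, via_probs)  # TRUE
```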
As with every TFP distribution, you can do a quick functionality check by sampling from the model:³
[[1]]
Tensor("MultivariateNormalDiag/sample/affine_linear_operator/forward/add:0",
shape=(2, 48), dtype=float32)
[[2]]
Tensor("IndependentJointDistributionSequential/sample/Beta/sample/Reshape:0",
shape=(2, 48), dtype=float32)
and computing log probabilities:
Now, let’s see how we can sample from this model using Hamiltonian Monte Carlo.
Running Hamiltonian Monte Carlo in TFP
We define a Hamiltonian Monte Carlo kernel with dynamic step size adaptation based on a desired acceptance probability.
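The adaptation idea fits in a few lines of base R. This is a simplified stand-in, not TFP code: after each step, the step size is multiplied up when the acceptance probability exceeds the target and divided down otherwise. As far as we know, mcmc_simple_step_size_adaptation adjusts the step size multiplicatively in this spirit during its adaptation steps, but its exact update rule may differ. The acceptance probability here is a hypothetical function (toy_accept_prob) that falls as the step size grows, roughly how HMC behaves.

```r
# Hypothetical acceptance probability as a function of step size, standing in
# for actually running HMC: the larger the step, the more proposals get rejected
toy_accept_prob <- function(step_size) exp(-step_size^2)

# Multiplicative adaptation: grow the step when acceptance is above target,
# shrink it otherwise
adapt_step_size <- function(step_size, target = 0.8, rate = 0.01, n_steps = 1000) {
  for (i in seq_len(n_steps)) {
    if (toy_accept_prob(step_size) > target) {
      step_size <- step_size * (1 + rate)
    } else {
      step_size <- step_size / (1 + rate)
    }
  }
  step_size
}

final_step <- adapt_step_size(0.1)
toy_accept_prob(final_step)   # close to the target of 0.8
```

Starting from a step size that is too large works the same way, only in the other direction.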
We then run the sampler, passing in an initial state. If we want to run $n$ chains, that state has to have a leading dimension of $n$ for every parameter in the model (here, we have just one: the vector of per-tank logits).
The sampling function, mcmc_sample_chain, may optionally be passed a trace_fn that tells TFP which kinds of meta information to save. Here, we save acceptance ratios and step sizes.
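For readers who want to see the moving parts, here is a deliberately plain base R sketch of what such a sampler does, for the unpooled model restricted to the ten tanks shown in the str() output. It is not TFP's implementation: chains run one after another instead of in parallel, the step size is fixed (no adaptation), and all names (log_post, hmc_step, sample_chains) are our own. As with mcmc_sample_chain, the initial state has one row per chain, and we keep a per-iteration acceptance record, roughly what the trace_fn described above saves.

```r
set.seed(42)

# Data: the first ten tanks from the str() output above
n_start <- rep(10, 10)
n_surv  <- c(9, 10, 7, 10, 9, 9, 10, 9, 4, 9)
prior_sd <- 1.5

# Unnormalized log posterior of the unpooled model: Normal(0, 1.5) prior on
# each logit plus the binomial likelihood of the observed counts
log_post <- function(l) {
  sum(dnorm(l, 0, prior_sd, log = TRUE)) +
    sum(dbinom(n_surv, n_start, plogis(l), log = TRUE))
}
# Analytic gradient of log_post with respect to the logits
grad_log_post <- function(l) -l / prior_sd^2 + (n_surv - n_start * plogis(l))

# One HMC transition: leapfrog integration, then a Metropolis accept/reject test
hmc_step <- function(l, step_size, n_leapfrog) {
  p0 <- rnorm(length(l))                            # fresh momentum
  l_new <- l
  p <- p0 + 0.5 * step_size * grad_log_post(l_new)  # initial half step
  for (j in seq_len(n_leapfrog)) {
    l_new <- l_new + step_size * p
    if (j < n_leapfrog) p <- p + step_size * grad_log_post(l_new)
  }
  p <- p + 0.5 * step_size * grad_log_post(l_new)   # final half step
  log_accept <- (log_post(l_new) - 0.5 * sum(p^2)) -
    (log_post(l) - 0.5 * sum(p0^2))
  accepted <- log(runif(1)) < log_accept
  list(state = if (accepted) l_new else l, accepted = accepted)
}

# Run several chains from an initial state with one row per chain, discard
# burn-in, and keep the draws plus a per-iteration acceptance trace
sample_chains <- function(init, n_results, n_burnin, step_size = 0.2, n_leapfrog = 3) {
  n_chains <- nrow(init)
  draws <- array(NA_real_, dim = c(n_results, n_chains, ncol(init)))
  is_accepted <- matrix(NA, nrow = n_results, ncol = n_chains)
  for (ch in seq_len(n_chains)) {
    l <- init[ch, ]
    for (i in seq_len(n_burnin + n_results)) {
      step <- hmc_step(l, step_size, n_leapfrog)
      l <- step$state
      if (i > n_burnin) {
        draws[i - n_burnin, ch, ] <- l
        is_accepted[i - n_burnin, ch] <- step$accepted
      }
    }
  }
  list(all_states = draws, is_accepted = is_accepted)
}

n_chain <- 4
init <- matrix(rnorm(n_chain * length(n_surv), 0, 1.5), nrow = n_chain)
res <- sample_chains(init, n_results = 500, n_burnin = 500)
dim(res$all_states)        # 500 4 10: samples x chains x tanks
colMeans(res$is_accepted)  # acceptance rate per chain
```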
When sampling is finished, we can access the samples as res$all_states:
Tensor("mcmc_sample_chain/trace_scan/TensorArrayStack/TensorArrayGatherV3:0",
shape=(500, 4, 48), dtype=float32)
This is the shape of the samples for l, the 48 per-tank logits: 500 samples times 4 chains times 48 parameters.
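Post-processing an array of this shape is mostly a matter of pooling its first two dimensions. The base R sketch below uses a random stand-in array (hypothetical, not actual sampler output) of shape 500 x 4 x 48 and computes per-parameter posterior means and standard deviations over all 2000 draws, plus per-chain means, which are useful for spotting chains that disagree.

```r
set.seed(7)
n_results <- 500; n_chains <- 4; n_params <- 48
# Random stand-in for sampler output (samples x chains x parameters), not real draws
draws <- array(rnorm(n_results * n_chains * n_params),
               dim = c(n_results, n_chains, n_params))

# Pool samples and chains: one row per draw, one column per parameter
pooled <- matrix(draws, nrow = n_results * n_chains, ncol = n_params)
post_mean <- colMeans(pooled)
post_sd   <- apply(pooled, 2, sd)

# Per-chain means (chains x parameters), handy for spotting disagreeing chains
chain_means <- apply(draws, c(2, 3), mean)
dim(chain_means)  # 4 48
```

Because column-major order keeps each parameter's draws contiguous, matrix() can do the pooling without any explicit reshaping.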
From these samples, we can compute the effective sample size and $\hat{R}$ (via mcmc_potential_scale_reduction):
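For intuition about $\hat{R}$, here is the classic Gelman-Rubin formula for a single parameter in base R: it compares the variance between chains with the variance within them, and approaches 1 when the chains agree. This is a textbook version written for illustration (gelman_rubin is our own name). TFP's mcmc_potential_scale_reduction computes a closely related quantity that may include additional corrections, so the numbers need not match exactly.

```r
# Classic Gelman-Rubin potential scale reduction for a single parameter.
# x: matrix of draws with one column per chain (samples x chains)
gelman_rubin <- function(x) {
  n <- nrow(x)
  W <- mean(apply(x, 2, var))           # average within-chain variance
  B <- n * var(colMeans(x))             # between-chain variance
  var_plus <- (n - 1) / n * W + B / n   # pooled estimate of the posterior variance
  sqrt(var_plus / W)
}

set.seed(3)
agree <- matrix(rnorm(500 * 4), ncol = 4)   # four chains, same distribution
disagree <- sweep(matrix(rnorm(500 * 4), ncol = 4), 2, c(0, 0, 3, 3), "+")
gelman_rubin(agree)     # close to 1
gelman_rubin(disagree)  # well above 1
```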
Diagnostic information, in turn, is available in res$trace:
After this quick outline, let’s move on to the topic promised in the title: multi-level modeling, or partial pooling. This time, we’ll also take a closer look at sampling results and diagnostic outputs.
Multi-level tadpoles⁴
The multi-level model (or varying intercepts model, in this case; we’ll get to varying slopes in a later post) adds a hyperprior to the model. Instead of fixing the mean and standard deviation of the normal prior the per-tank logits are drawn from, we let the model learn them from the data. The per-tank logits are still assumed to be normally distributed, but the mean (a_bar) and standard deviation (sigma) of that normal are now parameters themselves, regularized by a normal prior for the mean and an exponential prior for the standard deviation.
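Written out, the model reads as follows, using the symbols of the Stan code above ($S$, $N$, $p$, $a$) and $\bar{a}$ for a_bar. The numeric constants of the hyperpriors are an assumption for illustration: Normal(0, 1.5) mirrors the prior of the unpooled model, and Exponential(1) is a common choice for a scale parameter.

$$
\begin{aligned}
S_i &\sim \operatorname{Binomial}(N_i, p_i), \qquad i = 1, \dots, 48\\
\operatorname{logit}(p_i) &= a_i\\
a_i &\sim \operatorname{Normal}(\bar{a}, \sigma)\\
\bar{a} &\sim \operatorname{Normal}(0, 1.5)\\
\sigma &\sim \operatorname{Exponential}(1)
\end{aligned}
$$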
For the Stan-savvy, here is the Stan formulation of this model.
And here it is with TFP:
Technically, dependencies in tfd_joint_distribution_sequential are resolved by position in the list: a function’s arguments are bound to the preceding distributions, nearest first. In the learned prior for the logits,
sigma refers to the distribution immediately above, and a_bar to the one above that.
Analogously, in the distribution of survival counts,
l refers to the distribution immediately preceding its own definition.
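The positional convention can be mimicked in a few lines of base R. The helper sample_sequential below is hypothetical, written only to illustrate the idea: each entry of the model list is a function, and its arguments are filled, by position, with the values drawn so far, most recent first. That is why the function generating the logits lists sigma before a_bar. The prior constants are assumptions, and the ten tanks are those from the str() output.

```r
set.seed(11)
n_start <- rep(10, 10)      # tank sizes of the first ten tanks
n_tanks <- length(n_start)

# Hypothetical helper: sample a list of (conditional) distributions in order.
# Each function's arguments are filled, by position, with the values drawn so
# far, the most recent one first.
sample_sequential <- function(dists) {
  draws <- list()
  for (d in dists) {
    n_args <- length(formals(d))
    previous <- rev(draws)[seq_len(n_args)]
    draws[[length(draws) + 1]] <- do.call(d, previous)
  }
  draws
}

model <- list(
  function() rnorm(1, 0, 1.5),                                   # a_bar
  function() rexp(1, 1),                                         # sigma
  function(sigma, a_bar) rnorm(n_tanks, a_bar, sigma),           # per-tank logits l
  function(l) rbinom(n_tanks, size = n_start, prob = plogis(l))  # survivors
)

draws <- sample_sequential(model)
str(draws)
```

Argument names only serve readability here; binding happens purely by position.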
Again, let’s sample from this model to see if shapes are correct.
They are.
[[1]]
Tensor("Normal/sample_1/Reshape:0", shape=(2,), dtype=float32)
[[2]]
Tensor("Exponential/sample_1/Reshape:0", shape=(2,), dtype=float32)
[[3]]
Tensor("SampleJointDistributionSequential/sample_1/Normal/sample/Reshape:0",
shape=(2, 48), dtype=float32)
[[4]]
Tensor("IndependentJointDistributionSequential/sample_1/Beta/sample/Reshape:0",
shape=(2, 48), dtype=float32)
And to make sure we get one overall log_prob per batch:
Tensor("JointDistributionSequential/log_prob/add_3:0", shape=(2,), dtype=float32)
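In plain R terms, one overall log_prob per batch means that each joint draw is scored as a whole: the log densities of a_bar, sigma, all per-tank logits and all survival counts are added up into a single number. The sketch below does this for a batch of two draws and the first ten tanks; joint_log_prob is our own name, and the prior constants (Normal(0, 1.5), Exponential(1)) are assumptions for illustration.

```r
set.seed(5)
n_start <- rep(10, 10)     # tank sizes of the first ten tanks
n_tanks <- length(n_start)
n_samples <- 2             # a batch of two joint draws

# Ancestral draws: a_bar and sigma are length 2; logits l and counts s are 2 x 10
a_bar <- rnorm(n_samples, 0, 1.5)
sigma <- rexp(n_samples, 1)
l <- matrix(rnorm(n_samples * n_tanks, mean = a_bar, sd = sigma), nrow = n_samples)
size <- matrix(n_start, nrow = n_samples, ncol = n_tanks, byrow = TRUE)
s <- matrix(rbinom(length(l), size = size, prob = plogis(l)), nrow = n_samples)

# Joint log probability: the four components are added up, summing over tanks,
# so each row (joint draw) gets a single number
joint_log_prob <- function(a_bar, sigma, l, s) {
  size <- matrix(n_start, nrow = nrow(l), ncol = ncol(l), byrow = TRUE)
  lp_l <- matrix(dnorm(l, mean = a_bar, sd = sigma, log = TRUE), nrow = nrow(l))
  lp_s <- matrix(dbinom(s, size = size, prob = plogis(l), log = TRUE), nrow = nrow(l))
  dnorm(a_bar, 0, 1.5, log = TRUE) + dexp(sigma, 1, log = TRUE) +
    rowSums(lp_l) + rowSums(lp_s)
}

lp <- joint_log_prob(a_bar, sigma, l, s)
length(lp)  # 2
```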
Training this model works like before, except that now the initial state comprises three parameters, a_bar, sigma and l:
Here is the sampling routine:
This time, mcmc_trace is a list of three tensors, one per parameter (a_bar, sigma, and the 48 logits l):
[[1]]
Tensor("mcmc_sample_chain/trace_scan/TensorArrayStack/TensorArrayGatherV3:0",
shape=(500, 4), dtype=float32)
[[2]]
Tensor("mcmc_sample_chain/trace_scan/TensorArrayStack_1/TensorArrayGatherV3:0",
shape=(500, 4), dtype=float32)
[[3]]
Tensor("mcmc_sample_chain/trace_scan/TensorArrayStack_2/TensorArrayGatherV3:0",
shape=(500, 4, 48), dtype=float32)
Now let’s create graph nodes for the results and information we’re interested in.
And we’re ready to actually run the chains.
This time, let’s actually inspect those results.
Multi-level tadpoles: Results
First, how do the chains behave?
Trace plots
Let’s extract the samples for a_bar and sigma, as well as those for one of the per-tank logits, a_1:
Here’s a trace plot for a_bar:
And here for sigma and a_1:
How about the posterior distributions of the parameters, first and foremost, the varying intercepts a_1 … a_48?
Posterior distributions
Now let’s see the corresponding posterior means and highest posterior density intervals (HPDIs). The summary also includes the hyperpriors, as we’ll want to display a complete precis-like output soon.
Posterior means and HPDIs
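A highest posterior density interval can be computed directly from draws as the shortest interval that contains the desired probability mass. The helper hpdi below is our own sketch, not the code behind the table in this post; prob = 0.89 is the default of the rethinking package's HPDI(), chosen here only for illustration. On a skewed stand-in posterior, the HPDI is never wider than the equal-tailed interval of the same mass.

```r
# Highest posterior density interval from draws: the shortest interval
# that contains a fraction `prob` of the sorted draws
hpdi <- function(x, prob = 0.89) {
  x <- sort(x)
  n <- length(x)
  k <- ceiling(prob * n)               # number of draws the interval must contain
  widths <- x[k:n] - x[1:(n - k + 1)]  # width of every candidate interval
  i <- which.min(widths)
  c(lower = x[i], upper = x[i + k - 1])
}

set.seed(9)
draws <- rexp(2000)                 # a skewed stand-in for posterior draws
hpdi(draws)                         # starts near zero, where the density is highest
quantile(draws, c(0.055, 0.945))    # equal-tailed interval of the same mass
```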
Now for an equivalent to precis. We already computed means, standard deviations and HPDIs. Let’s add n_eff, the effective number of samples (the ess column below), and rhat, the Gelman-Rubin statistic.
Comprehensive summary (a.k.a. “precis”)
# A tibble: 50 x 7
key mean sd lower upper ess rhat
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 a_bar 1.35 0.266 0.792 1.87 405. 1.00
2 sigma 1.64 0.218 1.23 2.05 83.6 1.00
3 a_1 2.14 0.887 0.451 3.92 33.5 1.04
4 a_2 3.16 1.13 1.09 5.48 23.7 1.03
5 a_3 1.01 0.698 -0.333 2.31 65.2 1.02
6 a_4 3.02 1.04 1.06 5.05 31.1 1.03
7 a_5 2.11 0.843 0.625 3.88 49.0 1.05
8 a_6 2.06 0.904 0.496 3.87 39.8 1.03
9 a_7 3.20 1.27 1.11 6.12 14.2 1.02
10 a_8 2.21 0.894 0.623 4.18 44.7 1.04
# ... with 40 more rows
For the varying intercepts, effective sample sizes are pretty low (as low as 14.2 for a_7, out of 2000 draws), indicating we might want to investigate possible reasons.
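Low effective sample sizes are what autocorrelated draws look like: each new draw carries little information beyond the previous one. This base R sketch estimates the ESS of a single chain by summing its autocorrelations up to the first negative one, a simplified textbook rule (ess_single_chain is our own name). TFP’s mcmc_effective_sample_size has its own truncation settings, so numbers from the two need not agree.

```r
# Effective sample size of a single chain: n / (1 + 2 * sum of autocorrelations),
# truncating the sum at the first negative autocorrelation
ess_single_chain <- function(x) {
  n <- length(x)
  rho <- acf(x, lag.max = n - 1, plot = FALSE)$acf[-1]  # lags 1, 2, ...
  first_neg <- which(rho < 0)[1]
  if (!is.na(first_neg)) rho <- rho[seq_len(first_neg - 1)]
  n / (1 + 2 * sum(rho))
}

set.seed(13)
n <- 2000
independent <- rnorm(n)
sticky <- as.numeric(arima.sim(list(ar = 0.9), n = n))  # strongly autocorrelated
ess_single_chain(independent)  # on the order of n
ess_single_chain(sticky)       # a small fraction of n
```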
Let’s also display posterior survival probabilities, analogously to figure 13.2 in the book.
Posterior survival probabilities
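The computation behind such a plot is short: push posterior draws of a_bar and sigma through the model to simulate logits for new, hypothetical tanks, then map them to probabilities with the inverse logit. Since the actual draws aren’t reproduced here, this base R sketch substitutes normal approximations built from the means and standard deviations in the summary table above (a_bar: 1.35, 0.266; sigma: 1.64, 0.218). That substitution is purely for illustration.

```r
set.seed(21)
n_draws <- 8000
# Stand-ins for posterior draws, NOT actual MCMC output: normal approximations
# built from the means and sds of a_bar and sigma in the summary table
a_bar_draws <- rnorm(n_draws, 1.35, 0.266)
sigma_draws <- abs(rnorm(n_draws, 1.64, 0.218))   # a scale must be positive

# One new, simulated tank per draw: sample its logit from Normal(a_bar, sigma),
# then map it to a survival probability with the inverse logit
new_logits <- rnorm(n_draws, mean = a_bar_draws, sd = sigma_draws)
surv_prob  <- plogis(new_logits)

quantile(surv_prob, c(0.1, 0.5, 0.9))
# hist(surv_prob, breaks = 50) would show the whole distribution
```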
Finally, we want to make sure we see the shrinkage behavior displayed in figure 13.1 in the book.
Shrinkage
We see results similar in spirit to McElreath’s: estimates are shrunk toward the mean (the cyan-colored line). Also, shrinkage seems to be stronger in smaller tanks, which are the lower-numbered ones on the left of the plot.
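Part of this can be read off the summary table directly. The sketch below takes the posterior means for a_1 to a_8 and a_bar reported above, converts them to the probability scale, and checks whether each estimate has moved from the tank’s raw proportion in the direction of the population-level value. Applying plogis() to a posterior mean logit is a shortcut of ours; it is not the posterior mean of the probability.

```r
# Raw survival proportions of tanks 1 to 8 (propsurv in the str() output)
raw <- c(0.9, 1, 0.7, 1, 0.9, 0.9, 1, 0.9)
# Posterior means of a_1 ... a_8 and of a_bar, from the summary table
post_logit <- c(2.14, 3.16, 1.01, 3.02, 2.11, 2.06, 3.20, 2.21)
a_bar_mean <- 1.35

# Crude comparison on the probability scale: plogis() of a posterior mean logit,
# which is not the same as the posterior mean of the probability
estimate   <- plogis(post_logit)
population <- plogis(a_bar_mean)   # roughly 0.79

shrink <- data.frame(
  tank = 1:8, raw = raw, estimate = round(estimate, 3),
  # did the estimate move away from the raw proportion toward the population value?
  toward_mean = sign(estimate - raw) == sign(population - raw)
)
shrink
```

Seven of the eight tanks move toward the population value. Tank 8 ends up marginally above its raw 0.9 on this crude comparison, plausibly a side effect of the shortcut and of Monte Carlo error rather than a real effect.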
Outlook
In this post, we saw how to construct a varying intercepts model with tfprobability, as well as how to extract sampling results and relevant diagnostics. In an upcoming post, we’ll move on to varying slopes.
With non-negligible probability, our example will build on one of McElreath’s again…
Thanks for reading!
1. For a supplementary introduction to Bayesian modeling focusing on complete coverage, yet starting from the very beginning, you might want to consult Ben Lambert’s Student’s Guide to Bayesian Statistics.
2. As of today, lots of useful information is available in Modeling with JointDistribution and Multilevel Modeling Primer, but some experimentation may be needed to adapt the (numerous!) examples to your needs.
3. Updated footnote, as of May 13th: When this post was written, we were still experimenting with the use of tf.function from R, so it seemed safest to code the complete example in graph mode. The next post on MCMC will use eager execution, and show how to achieve good performance by placing the actual sampling procedure on the graph.
4. Yep, it’s a quote.