In a previous post, we showed how to use tfprobability, the R interface to TensorFlow Probability, to build a multilevel, or partial pooling, model of tadpole survival in tanks of different sizes (and thus, with differing numbers of inhabitants).
A completely pooled model would have resulted in a global estimate of survival count, irrespective of tank, while an unpooled model would have learned to predict survival count for each tank separately. The former approach does not take into account different circumstances; the latter does not make use of common information. (Also, the unpooled model has no predictive use unless we want to make predictions for the very same entities we used to train it.)
In contrast, a partially pooled model lets you make predictions for familiar as well as new entities: for a new entity, just use the appropriate prior.
Assuming we are in fact interested in the same entities, why would we want to apply partial pooling? For the same reasons so much effort in machine learning goes into devising regularization mechanisms. We don’t want to overfit too much to actual measurements, be they related to the same entity or a class of entities. If I want to predict my heart rate as I wake up next morning, based on a single measurement I’m taking now (let’s say it’s evening and I’m frantically typing a blog post), I had better take into account some facts about heart rate behavior in general, instead of just projecting the exact value measured right now into the future.
In the tadpole example, this means we expect generalization to work better for tanks with many inhabitants than for more solitary environments. For the latter, we had better take a peek at survival rates from other tanks to supplement the sparse, idiosyncratic information available. Or, in technical terms: in the latter case, we hope for the model to shrink its estimates toward the overall mean more noticeably than in the former.
This type of information sharing is already very useful, but it gets better. The tadpole model is a varying intercepts model, as McElreath calls it (or random intercepts, as it is sometimes, confusingly, called¹). It is a model of intercepts because that is how we make predictions for entities (here: tanks) when no predictor variables are present. So if we can pool information about intercepts, why not pool information about slopes as well? This will additionally allow us to make use of relationships between variables learnt on different entities in the training set.
So as you might have guessed by now, varying slopes (or random slopes, if you will) is the topic of today’s post. Again, we take up an example from McElreath’s book, and show how to accomplish the same thing with tfprobability.
Coffee, please
Unlike in the tadpole case, this time we work with simulated data. This is the data McElreath uses to introduce the varying slopes modeling technique; he then goes on to apply it to one of the book’s most featured datasets, the pro-social (or rather, indifferent!) chimpanzees. For today, we stay with the simulated data for two reasons: first, the subject matter per se is non-trivial enough; and second, we want to keep careful track of what our model does, and whether its output is sufficiently close to the results McElreath obtained from Stan.²
So, the scenario is this.³ Cafés vary in how popular they are. In a popular café, when you order coffee, you’re likely to wait. In a less popular café, you’ll likely be served much faster. That’s one thing. Second, all cafés tend to be more crowded in the mornings than in the afternoons. Thus in the morning, you’ll wait longer than in the afternoon; this goes for the popular as well as the less popular cafés.
In terms of intercepts and slopes, we can picture the morning waits as intercepts, and the resultant afternoon waits as arising due to the slopes of the lines joining each morning and afternoon wait, respectively.
So when we partially pool intercepts, we have one “intercept prior” (itself constrained by a prior, of course), and a set of café-specific intercepts that vary around it. When we partially pool slopes, we have a “slope prior” reflecting the overall relationship between morning and afternoon waits, and a set of café-specific slopes reflecting the individual relationships. Intuitively, this means that if you have never been to the Café Gerbeaud in Budapest but have been to cafés before, you might have a less-than-uninformed idea about how long you are going to wait. It also means that if you normally get your coffee at your favorite corner café in the mornings, and now you pass by there in the afternoon, you have an approximate idea how long it’s going to take (namely, fewer minutes than in the mornings).
So is that all? Actually, no. In our scenario, intercepts and slopes are related. If, at a less popular café, I always get my coffee before two minutes have passed, there is little room for improvement. At a highly popular café, though, if it could easily take ten minutes in the mornings, then there is quite some potential for a decrease in waiting time in the afternoon. So in my prediction for this afternoon’s waiting time, I should factor in this relationship between intercepts and slopes.
So, now that we have an idea of what this is all about, let’s see how we can model these effects with tfprobability. But first, we actually have to generate the data.
Simulate the data
We directly follow McElreath in the way the data are generated.
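The simulation code itself is not reproduced here. The following base R sketch (plus MASS for `mvrnorm()`) follows the same generative process: it draws one intercept and one slope per café from a bivariate normal, then simulates ten waits per café, alternating between morning and afternoon. The values b = -1 and rho = -0.7 match those mentioned later in this post. The remaining values (a = 3.5, sigma_a = 1, sigma_b = 0.5, sigma = 0.5) and the seed are our assumptions, so exact numbers will differ from the glimpse further down.

```r
# Sketch of the café data simulation (generative process as in McElreath's book).
# b = -1 and rho = -0.7 are the values quoted later in this post;
# a, sigma_a, sigma_b, sigma and the seed are our assumptions.
library(MASS)  # for mvrnorm()

set.seed(5)

a <- 3.5        # average morning wait (minutes)
b <- -1         # average afternoon-minus-morning difference
sigma_a <- 1    # standard deviation of café intercepts
sigma_b <- 0.5  # standard deviation of café slopes
rho <- -0.7     # correlation between intercepts and slopes

mu <- c(a, b)
sigmas <- c(sigma_a, sigma_b)
rho_mat <- matrix(c(1, rho, rho, 1), nrow = 2)
sigma_mat <- diag(sigmas) %*% rho_mat %*% diag(sigmas)

n_cafes <- 20
n_visits <- 10

# one (intercept, slope) pair per café, drawn from a bivariate normal
vary_effects <- mvrnorm(n_cafes, mu, sigma_mat)
a_cafe <- vary_effects[, 1]
b_cafe <- vary_effects[, 2]

# each café is visited 10 times, alternating morning (0) and afternoon (1)
afternoon <- rep(0:1, n_visits * n_cafes / 2)
cafe_id <- rep(1:n_cafes, each = n_visits)
mu_wait <- a_cafe[cafe_id] + b_cafe[cafe_id] * afternoon
sigma <- 0.5    # standard deviation of waits within a café
wait <- rnorm(n_visits * n_cafes, mu_wait, sigma)

d <- data.frame(cafe = cafe_id, afternoon = afternoon, wait = wait)
str(d)
```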
Take a glimpse at the data:
Observations: 200
Variables: 3
$ cafe <int> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3,...
$ afternoon <int> 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,...
$ wait <dbl> 3.9678929, 3.8571978, 4.7278755, 2.7610133, 4.1194827, 3.54365,...
On to building the model.
The model
As in the previous post on multi-level modeling, we use tfd_joint_distribution_sequential to define the model and Hamiltonian Monte Carlo for sampling. Consider taking a look at the first section of that post for a quick reminder of the overall procedure.
Before we code the model, let’s quickly get library loading out of the way. Importantly, again just like in the previous post, we need to install a master build of TensorFlow Probability, as we’re making use of very new features not yet available in the current release version. The same goes for the R packages tensorflow and tfprobability: please install the respective development versions from GitHub.
Now here is the model definition. We’ll go through it step by step in an instant.
The first five distributions are priors. First, we have the prior for the correlation matrix. Basically, this would be an LKJ distribution of shape 2x2 and with concentration parameter equal to 2. For performance reasons, we work with a version that inputs and outputs Cholesky factors instead.
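To see what the Cholesky version represents, here is a base R illustration (not the tfprobability code) of how a 2x2 correlation matrix relates to its lower-triangular Cholesky factor. The first row of the factor is always (1, 0), and the second row is determined entirely by the single correlation r.

```r
# The Cholesky factor of a 2x2 correlation matrix, in base R.
r <- -0.7
corr <- matrix(c(1, r, r, 1), nrow = 2)

# chol() returns the upper-triangular U with t(U) %*% U == corr;
# transposing gives the lower-triangular L with L %*% t(L) == corr
L <- t(chol(corr))
L           # first row: 1, 0; second row: r, sqrt(1 - r^2)
L %*% t(L)  # recovers corr
```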
What kind of prior is this? As McElreath keeps reminding us, nothing is more instructive than sampling from the prior. For us to see what’s going on, we use the base LKJ distribution, not the Cholesky one:
So this prior is moderately skeptical about strong correlations, but pretty open to learning from data.
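For the 2x2 case, there is a convenient shortcut. An LKJ(η) distribution over 2x2 correlation matrices has density proportional to (1 - r²)^(η - 1) in the single correlation r. Equivalently, (r + 1)/2 follows a Beta(η, η) distribution. The base R sketch below uses this to sample from the η = 2 prior without TensorFlow. With η = 2, only 5.6% of the prior mass falls on correlations stronger than 0.8 in absolute value, which quantifies the moderate skepticism.

```r
# Prior draws for the correlation under LKJ(eta = 2), 2x2 case.
# Uses the closed-form marginal: (r + 1) / 2 ~ Beta(eta, eta).
set.seed(777)
eta <- 2
r <- 2 * rbeta(1e5, eta, eta) - 1

mean(abs(r) > 0.8)               # share of strong correlations; exact value is 0.056
quantile(r, c(0.05, 0.5, 0.95))  # symmetric around zero
```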
The next distribution in line is the prior for the standard deviation of the waiting times, which is used by the very last distribution in the list.
Next is the prior distribution of the standard deviations for the intercepts and slopes. This prior is the same for both cases, but we specify a sample_shape of 2 to get two individual samples.
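In base R, the effect of sample_shape can be mimicked by drawing a two-column matrix, with one column per standard deviation. The exponential(1) prior below is an assumption for illustration, since the model code that would show the actual prior is not reproduced here.

```r
# Prior draws for the two café-level standard deviations, with sample_shape 2
# mimicked by a two-column matrix. ASSUMPTION: exponential(1) prior.
set.seed(42)
n_draws <- 1000
sigma_cafe <- matrix(rexp(n_draws * 2, rate = 1), ncol = 2,
                     dimnames = list(NULL, c("sigma_a", "sigma_b")))
dim(sigma_cafe)       # c(1000, 2): one row per draw, one column per sd
colMeans(sigma_cafe)  # both near 1, the mean of an exponential(1)
```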
Now that we have the respective prior standard deviations, we move on to the prior means of the intercepts and slopes. Both are normal distributions.
On to the heart of the model, where the partial pooling happens. We are going to construct partially pooled intercepts and slopes for all of the cafés. As we said above, intercepts and slopes are not independent; they are correlated. Thus, we need to use a multivariate normal distribution. The means are given by the prior means defined right above, while the covariance matrix is built from the prior standard deviations and the prior correlation matrix. The output shape here is determined by the number of cafés: we want an intercept and a slope for every café.
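In base R terms, the construction looks roughly like this. Scaling the rows of the correlation Cholesky factor by the standard deviations gives a lower-triangular factor of the covariance matrix. That is the kind of scale_tril argument a Cholesky-parameterized multivariate normal expects (in tfprobability, to our knowledge, tfd_multivariate_normal_tri_l()). The means, standard deviations and correlation below are illustrative values.

```r
# Building the bivariate normal over (intercept, slope) from standard deviations
# and a correlation Cholesky factor, then drawing one pair per café.
# All numbers are illustrative.
set.seed(1)
mu <- c(3.5, -1)                    # prior means of intercept and slope
sigmas <- c(1, 0.5)                 # standard deviations of intercepts and slopes
r <- -0.7
corr <- matrix(c(1, r, r, 1), nrow = 2)
L_corr <- t(chol(corr))             # lower Cholesky factor of the correlation matrix

# Scaling the rows of the correlation factor by the standard deviations gives
# a lower-triangular factor of the covariance matrix
scale_tril <- diag(sigmas) %*% L_corr
cov_mat <- scale_tril %*% t(scale_tril)   # equals diag(sigmas) %*% corr %*% diag(sigmas)

# One (intercept, slope) pair per café: mu + scale_tril %*% z with z ~ N(0, I)
n_cafes <- 20
z <- matrix(rnorm(2 * n_cafes), nrow = 2)
effects <- t(mu + scale_tril %*% z)       # 20 x 2
colnames(effects) <- c("a_cafe", "b_cafe")
```

Working with the factor rather than the full covariance matrix avoids a Cholesky decomposition inside every density evaluation. This is the same performance motivation as for the Cholesky version of the LKJ prior.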
Finally, we sample the actual waiting times. The last distribution in the model picks out, for each observation, the matching intercept and slope from the multivariate normal draw. It uses them to compute the mean of the waiting-time distribution, depending on which café we’re in and whether it’s morning or afternoon.
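The “picking out” amounts to indexing the per-café effects by each observation’s café id. Here is a base R sketch with made-up numbers for three cafés:

```r
# Picking each observation's café-specific intercept and slope and computing
# the mean wait. Small illustrative inputs (hypothetical values, not from the post).
effects <- cbind(a_cafe = c(4.0, 2.0, 5.0),   # one row per café
                 b_cafe = c(-1.0, -0.5, -2.0))
cafe <- c(1L, 1L, 2L, 2L, 3L, 3L)
afternoon <- c(0L, 1L, 0L, 1L, 0L, 1L)

# Row indexing with a vector of café ids plays the role of a gather operation
mu_wait <- effects[cafe, "a_cafe"] + effects[cafe, "b_cafe"] * afternoon
mu_wait   # 4.0 3.0 2.0 1.5 5.0 3.0
```

In TensorFlow (for example with tf$gather()), indices are 0-based, so 1-based café ids like these would need to be shifted by one.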
Before running the sampling, it’s always a good idea to do a quick check on the model.
We sample from the model and then check the log probability.
We want a scalar log probability per member in the batch, which is what we get.
tf.Tensor([-466.1392 -149.92587 -196.51688], shape=(3,), dtype=float32)
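The same kind of check can be mirrored in plain R. A function computes the joint log density for one parameter setting, and applying it to a batch of three settings drawn from the priors yields three scalars. The priors used here are assumptions for illustration: LKJ(2), exponential(1) standard deviations, and normal(5, 2) and normal(-1, 0.5) for the mean intercept and slope. The LKJ term is also left unnormalized, so the numbers are not comparable to the tensor above.

```r
# Joint log density of the café model, evaluated for a "batch" of parameter
# settings: one scalar per batch member. ASSUMED priors, for illustration only:
# LKJ(2) on the correlation (unnormalized), exponential(1) on all standard
# deviations, normal(5, 2) on a, normal(-1, 0.5) on b.
log_joint <- function(p, d) {
  r <- p$rho
  s <- p$sigma_cafe                               # sd of intercepts, sd of slopes
  lp <- (2 - 1) * log(1 - r^2)                    # LKJ(eta = 2), 2x2, unnormalized
  lp <- lp + dexp(p$sigma, 1, log = TRUE) + sum(dexp(s, 1, log = TRUE))
  lp <- lp + dnorm(p$a, 5, 2, log = TRUE) + dnorm(p$b, -1, 0.5, log = TRUE)

  # bivariate normal log density of every café's (intercept, slope) pair
  L <- t(chol(diag(s) %*% matrix(c(1, r, r, 1), 2) %*% diag(s)))
  z <- forwardsolve(L, rbind(p$a_cafe - p$a, p$b_cafe - p$b))  # L^{-1} (x - mu)
  lp <- lp + sum(-0.5 * colSums(z^2) - sum(log(diag(L))) - log(2 * pi))

  # likelihood of the observed waiting times
  mu_wait <- p$a_cafe[d$cafe] + p$b_cafe[d$cafe] * d$afternoon
  lp + sum(dnorm(d$wait, mu_wait, p$sigma, log = TRUE))
}

# Draw one parameter setting from the (assumed) priors
draw_params <- function(n_cafes) {
  r <- 2 * rbeta(1, 2, 2) - 1
  s <- rexp(2, 1)
  p <- list(rho = r, sigma_cafe = s, sigma = rexp(1, 1),
            a = rnorm(1, 5, 2), b = rnorm(1, -1, 0.5))
  L <- t(chol(diag(s) %*% matrix(c(1, r, r, 1), 2) %*% diag(s)))
  eff <- c(p$a, p$b) + L %*% matrix(rnorm(2 * n_cafes), 2)
  p$a_cafe <- eff[1, ]
  p$b_cafe <- eff[2, ]
  p
}

set.seed(3)
# placeholder data with the same layout as the simulated café data
d <- data.frame(cafe = rep(1:20, each = 10), afternoon = rep(0:1, 100))
d$wait <- rnorm(200, 3.5 - d$afternoon, 0.5)

batch <- replicate(3, draw_params(20), simplify = FALSE)
lps <- vapply(batch, log_joint, numeric(1), d = d)
lps   # three finite numbers, one per batch member
```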
Running the chains
The actual Monte Carlo sampling works just like in the previous post, with one exception. Sampling happens in unconstrained parameter space, but at the end we need valid correlation matrix parameters rho and valid standard deviations sigma and sigma_cafe. Conversion between spaces is done via TFP bijectors. Luckily, we don’t have to implement these conversions ourselves; all we need to do is specify appropriate bijectors. For the normally distributed parameters in the model, there is nothing to do, as they already live on the whole real line.
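For intuition, here is what the two kinds of constraining transformations do, written out in base R. Exponentiation maps real numbers to positive standard deviations. For the correlation Cholesky factor, we follow our understanding of TFP’s CorrelationCholesky bijector: fill the lower triangle with the unconstrained values, put ones on the diagonal, then normalize each row to unit length. In tfprobability, these roles are presumably played by bijectors such as tfb_exp() and tfb_correlation_cholesky(). The helper functions below are hypothetical and are not that API.

```r
# What the constraining transformations do, sketched in base R.
# Standard deviations: exp() maps any real number to a positive one.
to_positive <- function(x) exp(x)

# 2x2 correlation Cholesky factor from a single unconstrained real x:
# 1 on the diagonal, x below it, then each row rescaled to unit length
# (our understanding of how TFP's CorrelationCholesky bijector works).
to_corr_cholesky <- function(x) {
  L <- matrix(c(1, x, 0, 1), nrow = 2)  # filled column-wise: rows (1, 0) and (x, 1)
  L / sqrt(rowSums(L^2))                # divide each row by its Euclidean norm
}

to_positive(c(-3, 0, 2))      # all positive
L <- to_corr_cholesky(-1.5)
L %*% t(L)                    # unit diagonal, off-diagonal -1.5 / sqrt(1 + 1.5^2)
```

Because every row has unit length, L %*% t(L) always has ones on its diagonal. The off-diagonal element x / sqrt(1 + x^2) stays strictly between -1 and 1 for any real x.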
Now we can set up the Hamiltonian Monte Carlo sampler.
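The tfprobability sampler setup is not reproduced here. To build intuition about what a single HMC iteration does, here is a generic, textbook version in base R with a fixed step size, run on a toy target (a standard bivariate normal) instead of the café model. Each iteration draws a fresh momentum, simulates Hamiltonian dynamics with the leapfrog integrator, and accepts or rejects the end point with a Metropolis step. The acceptance rate it records is the kind of diagnostic the post traces. Step size and number of leapfrog steps are our choices.

```r
# Minimal Hamiltonian Monte Carlo on a toy target (standard bivariate normal):
# leapfrog integration followed by a Metropolis accept/reject step.
# Generic textbook HMC with a fixed step size, not tfprobability's implementation.
log_target <- function(q) -0.5 * sum(q^2)
grad_log_target <- function(q) -q

hmc_step <- function(q, step_size, n_leapfrog) {
  p <- rnorm(length(q))                                   # fresh momentum
  q_new <- q
  p_new <- p + 0.5 * step_size * grad_log_target(q_new)   # initial half step
  for (i in seq_len(n_leapfrog)) {
    q_new <- q_new + step_size * p_new                    # full position step
    if (i < n_leapfrog) {
      p_new <- p_new + step_size * grad_log_target(q_new) # full momentum step
    }
  }
  p_new <- p_new + 0.5 * step_size * grad_log_target(q_new)  # final half step
  # Hamiltonian = potential energy (-log target) + kinetic energy
  h_old <- -log_target(q) + 0.5 * sum(p^2)
  h_new <- -log_target(q_new) + 0.5 * sum(p_new^2)
  accepted <- log(runif(1)) < h_old - h_new
  list(q = if (accepted) q_new else q, accepted = accepted)
}

set.seed(2024)
n_iter <- 5000
q <- c(3, -3)                                             # poor starting point
draws <- matrix(NA_real_, nrow = n_iter, ncol = 2)
accepted <- logical(n_iter)
for (iter in seq_len(n_iter)) {
  res <- hmc_step(q, step_size = 0.15, n_leapfrog = 10)
  q <- res$q
  draws[iter, ] <- q
  accepted[iter] <- res$accepted
}
mean(accepted)                   # acceptance rate
colMeans(draws[-(1:500), ])      # near c(0, 0) after discarding warmup
```

Smaller step sizes raise the acceptance rate but require more leapfrog steps to travel the same distance. This tradeoff is why samplers usually adapt the step size during warmup.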
Again, we can obtain additional diagnostics (here: step sizes and acceptance rates) by registering a trace function:
Here, then, is the sampling function. Note how we use tf_function to compile it into a graph. At least as of today, this makes a huge difference in sampling performance when using eager execution.
So how do our samples look, and what do we get in terms of posteriors? Let’s see.
Results
At this moment, mcmc_trace is a list of tensors of different shapes, depending on how we defined the parameters. We need to do a bit of post-processing before we can summarize and display the results.
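This post-processing can be sketched in base R. Each parameter’s draws are reshaped so that every (step, chain) pair becomes a row and every scalar element of the parameter becomes a named column, such as sigma_cafe_1 and sigma_cafe_2. The input layout, a named list of [n_steps, n_chains, ...] arrays, is an assumption.

```r
# Flattening MCMC output into one column per scalar parameter.
# ASSUMPTION: every element of `trace` is an array with dimensions
# [n_steps, n_chains, ...], mimicking a list of sampled tensors.
flatten_trace <- function(trace) {
  cols <- lapply(names(trace), function(nm) {
    x <- trace[[nm]]
    dims <- dim(x)
    n_draws <- dims[1] * dims[2]
    n_elem <- if (length(dims) > 2) prod(dims[-(1:2)]) else 1
    # column-major reshape: rows are (step, chain) pairs, columns are elements
    m <- matrix(x, nrow = n_draws, ncol = n_elem)
    colnames(m) <- if (n_elem == 1) nm else paste(nm, seq_len(n_elem), sep = "_")
    m
  })
  dims <- dim(trace[[1]])
  ids <- data.frame(step = rep(seq_len(dims[1]), times = dims[2]),
                    chain = rep(seq_len(dims[2]), each = dims[1]))
  cbind(ids, do.call(cbind, cols))
}

# Synthetic example: 100 steps, 4 chains
set.seed(1)
trace <- list(
  sigma = array(rexp(400), dim = c(100, 4)),
  sigma_cafe = array(rexp(800), dim = c(100, 4, 2))
)
draws <- flatten_trace(trace)
names(draws)   # "step" "chain" "sigma" "sigma_cafe_1" "sigma_cafe_2"
```

R reshapes in column-major order, so a [n_steps, n_chains, 2, 2] Cholesky array would come out in the element order [1,1], [2,1], [1,2], [2,2]. A row-major flattening, as used for the rho_1 to rho_4 names in the table further down, swaps the second and third elements.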
Trace plots
How well do the chains mix?
Awesome! (The first two elements of rho, the Cholesky factor of the correlation matrix, are fixed at 1 and 0 by construction, so their traces are flat.)
Now, on to some summary statistics on the posteriors of the parameters.
Parameters
Like last time, we display posterior means and standard deviations, as well as the highest posterior density interval (HPDI). We add effective sample sizes and rhat values.
# A tibble: 49 x 7
key mean sd lower upper ess rhat
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 rho_1 1 0 1 1 NaN NaN
2 rho_2 0 0 0 0 NaN NaN
3 rho_3 -0.517 0.176 -0.831 -0.195 42.4 1.01
4 rho_4 0.832 0.103 0.644 1.000 46.5 1.02
5 sigma 0.473 0.0264 0.420 0.523 424. 1.00
6 sigma_cafe_1 0.967 0.163 0.694 1.29 97.9 1.00
7 sigma_cafe_2 0.607 0.129 0.386 0.861 42.3 1.03
8 b -1.14 0.141 -1.43 -0.864 95.1 1.00
9 a 3.66 0.218 3.22 4.07 75.3 1.01
10 a_cafe_1 4.20 0.192 3.83 4.57 83.9 1.01
11 b_cafe_1 -1.13 0.251 -1.63 -0.664 63.6 1.02
12 a_cafe_2 2.17 0.195 1.79 2.54 59.3 1.01
13 b_cafe_2 -0.923 0.260 -1.42 -0.388 46.0 1.01
14 a_cafe_3 4.40 0.195 4.02 4.79 56.7 1.01
15 b_cafe_3 -1.97 0.258 -2.52 -1.51 43.9 1.01
16 a_cafe_4 3.22 0.199 2.80 3.57 58.7 1.02
17 b_cafe_4 -1.20 0.254 -1.70 -0.713 36.3 1.01
18 a_cafe_5 1.86 0.197 1.45 2.20 52.8 1.03
19 b_cafe_5 -0.113 0.263 -0.615 0.390 34.6 1.04
20 a_cafe_6 4.26 0.210 3.87 4.67 43.4 1.02
21 b_cafe_6 -1.30 0.277 -1.80 -0.713 41.4 1.05
22 a_cafe_7 3.61 0.198 3.23 3.98 44.9 1.01
23 b_cafe_7 -1.02 0.263 -1.51 -0.489 37.7 1.03
24 a_cafe_8 3.95 0.189 3.59 4.31 73.1 1.01
25 b_cafe_8 -1.64 0.248 -2.10 -1.13 60.7 1.02
26 a_cafe_9 3.98 0.212 3.57 4.37 76.3 1.03
27 b_cafe_9 -1.29 0.273 -1.83 -0.776 57.8 1.05
28 a_cafe_10 3.60 0.187 3.24 3.96 104. 1.01
29 b_cafe_10 -1.00 0.245 -1.47 -0.512 70.4 1.00
30 a_cafe_11 1.95 0.200 1.56 2.35 55.9 1.03
31 b_cafe_11 -0.449 0.266 -1.00 0.0619 42.5 1.04
32 a_cafe_12 3.84 0.195 3.46 4.22 76.0 1.02
33 b_cafe_12 -1.17 0.259 -1.65 -0.670 62.5 1.03
34 a_cafe_13 3.88 0.201 3.50 4.29 62.2 1.02
35 b_cafe_13 -1.81 0.270 -2.30 -1.29 48.3 1.03
36 a_cafe_14 3.19 0.212 2.82 3.61 65.9 1.07
37 b_cafe_14 -0.961 0.278 -1.49 -0.401 49.9 1.06
38 a_cafe_15 4.46 0.212 4.08 4.91 62.0 1.09
39 b_cafe_15 -2.20 0.290 -2.72 -1.59 47.8 1.11
40 a_cafe_16 3.41 0.193 3.02 3.78 62.7 1.02
41 b_cafe_16 -1.07 0.253 -1.54 -0.567 48.5 1.05
42 a_cafe_17 4.22 0.201 3.82 4.60 58.7 1.01
43 b_cafe_17 -1.24 0.273 -1.74 -0.703 43.8 1.01
44 a_cafe_18 5.77 0.210 5.34 6.18 66.0 1.02
45 b_cafe_18 -1.05 0.284 -1.61 -0.511 49.8 1.02
46 a_cafe_19 3.23 0.203 2.88 3.65 52.7 1.02
47 b_cafe_19 -0.232 0.276 -0.808 0.243 45.2 1.01
48 a_cafe_20 3.74 0.212 3.35 4.21 48.2 1.04
49 b_cafe_20 -1.09 0.281 -1.58 -0.506 36.5 1.05
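For readers wondering what the lower/upper and rhat columns measure, here is a plain base R sketch of both quantities. It is not the code that produced the table. hpdi() returns the narrowest interval containing a given share of the draws (95% is our choice of default). rhat() is the classic Gelman-Rubin potential scale reduction factor without chain splitting, so its values can differ somewhat from other implementations.

```r
# Highest posterior density interval: the narrowest interval containing
# a fraction `prob` of the draws.
hpdi <- function(x, prob = 0.95) {
  x <- sort(x)
  n <- length(x)
  k <- ceiling(prob * n)                 # number of draws the interval must contain
  starts <- seq_len(n - k + 1)
  widths <- x[starts + k - 1] - x[starts]
  i <- which.min(widths)
  c(lower = x[i], upper = x[i + k - 1])
}

# Potential scale reduction factor (classic Gelman-Rubin, no chain splitting).
# `draws` is a matrix with one column per chain.
rhat <- function(draws) {
  n <- nrow(draws)
  w <- mean(apply(draws, 2, var))        # mean within-chain variance
  b <- n * var(colMeans(draws))          # between-chain variance
  var_plus <- (n - 1) / n * w + b / n    # pooled variance estimate
  sqrt(var_plus / w)
}

set.seed(10)
x <- rnorm(1e5)
hpdi(x, 0.95)   # close to (-1.96, 1.96) for a standard normal

mixed <- matrix(rnorm(4000), ncol = 4)                                # 4 well-mixed chains
stuck <- sweep(matrix(rnorm(4000), ncol = 4), 2, c(0, 0, 0, 3), "+")  # one chain offset
rhat(mixed)     # close to 1
rhat(stuck)     # well above 1
```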
So what do we have? In the rows a_cafe_n and b_cafe_n, the posterior mean slope is negative for every café. (When the table is rendered interactively, this shows up as a neat alternation of white and red coloring.)
The inferred slope prior (b) is around -1.14, which is not too far off from the value we used for simulating the data: -1.
The rho posterior estimates, admittedly, are less useful unless you are accustomed to composing Cholesky factors in your head. We compute the resulting posterior correlations and their mean:
-0.5166775
The value we used for sampling was -0.7, so we see the regularization effect. In case you’re wondering, for the same data Stan yields an estimate of -0.5.
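Composing the factors takes one matrix product per draw. Below is a base R sketch that assumes the draws are available as an array chol_draws of shape [n_draws, 2, 2]. The name is hypothetical, and the array is filled with synthetic factors here.

```r
# Turning posterior draws of the correlation Cholesky factor into correlations.
# ASSUMPTION: chol_draws is an array of shape [n_draws, 2, 2] (hypothetical name);
# here it is filled with synthetic factors built from known correlations.
set.seed(99)
n_draws <- 500
true_r <- runif(n_draws, -0.9, 0.1)
chol_draws <- array(0, dim = c(n_draws, 2, 2))
chol_draws[, 1, 1] <- 1
chol_draws[, 2, 1] <- true_r
chol_draws[, 2, 2] <- sqrt(1 - true_r^2)

# For each draw, compose L %*% t(L) and read off the off-diagonal element
post_r <- apply(chol_draws, 1, function(L) (L %*% t(L))[1, 2])
mean(post_r)
```

In the 2x2 case, (L %*% t(L))[1, 2] reduces to L[2, 1], which is the third element in row-major order. This is why the posterior mean of rho_3 in the table (-0.517) matches the mean correlation printed above.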
Finally, let’s display equivalents of McElreath’s figures illustrating shrinkage on the parameter scale (café-specific intercepts and slopes) as well as on the outcome scale (morning and afternoon waiting times).
Shrinkage
As expected, we see that the individual intercepts and slopes are pulled towards the mean, and the further away they are from the center, the more strongly they are pulled.
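In McElreath’s figures, each café’s partially pooled estimate is compared with its unpooled estimate, which uses the mean morning wait as the intercept and the mean afternoon-minus-morning difference as the slope. The sketch below computes these unpooled estimates in base R, along with each café’s distance from the center. It runs on placeholder data with the same layout as the simulated data set. shrink_amount() is a hypothetical helper that would take posterior means from the model.

```r
# Unpooled per-café estimates on the parameter scale: the starting points
# from which partial pooling shrinks. Placeholder data with the same layout
# as the simulated café data (values are illustrative).
set.seed(8)
d <- data.frame(cafe = rep(1:20, each = 10), afternoon = rep(0:1, 100))
d$wait <- rnorm(200, 3.5 - d$afternoon, 0.5)

is_pm <- d$afternoon == 1
morning_mean <- tapply(d$wait[!is_pm], d$cafe[!is_pm], mean)
afternoon_mean <- tapply(d$wait[is_pm], d$cafe[is_pm], mean)

a_unpooled <- morning_mean                   # intercept: mean morning wait
b_unpooled <- afternoon_mean - morning_mean  # slope: afternoon minus morning

# Distance of each café's unpooled estimate from the center of all cafés;
# partial pooling is expected to move far-out cafés the most.
center <- c(mean(a_unpooled), mean(b_unpooled))
dist_from_center <- sqrt((a_unpooled - center[1])^2 + (b_unpooled - center[2])^2)

# Given partially pooled posterior means (hypothetical inputs a_pooled, b_pooled),
# the amount of shrinkage per café would be:
shrink_amount <- function(a_pooled, b_pooled) {
  sqrt((a_unpooled - a_pooled)^2 + (b_unpooled - b_pooled)^2)
}
```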
The same behavior is visible on the outcome scale.
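On the outcome scale, a café’s morning wait is its intercept, and its afternoon wait is the intercept plus the slope. Computing these per posterior draw, rather than from summary values, keeps the draws usable for intervals as well as means. The draw matrices below are hypothetical and filled with synthetic values.

```r
# Moving from the parameter scale to the outcome scale, draw by draw.
# ASSUMPTION: a_draws and b_draws are [n_draws, n_cafes] matrices of posterior
# samples for café intercepts and slopes (hypothetical names, synthetic values).
set.seed(11)
n_draws <- 1000
n_cafes <- 20
a_draws <- matrix(rnorm(n_draws * n_cafes, 3.5, 0.2), nrow = n_draws)
b_draws <- matrix(rnorm(n_draws * n_cafes, -1, 0.25), nrow = n_draws)

morning_wait <- a_draws                 # afternoon = 0
afternoon_wait <- a_draws + b_draws     # afternoon = 1

# posterior mean waits per café, the quantities plotted on the outcome scale
outcome_means <- cbind(morning = colMeans(morning_wait),
                       afternoon = colMeans(afternoon_wait))
# per-draw values also give intervals, e.g. 5% and 95% quantiles per café
afternoon_bounds <- apply(afternoon_wait, 2, quantile, probs = c(0.05, 0.95))
```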
Wrapping up
By now, we hope we have convinced you of the power inherent in Bayesian modeling, as well as conveyed some ideas on how this is achievable with TensorFlow Probability. As with every DSL, though, it takes time to proceed from understanding worked examples to designing your own models. And not just time: it helps to have seen a lot of different models, focusing on different tasks and applications. On this blog, we plan to loosely follow up on Bayesian modeling with TFP, picking up some of the tasks and challenges elaborated on in the later chapters of McElreath’s book. Thanks for reading!
¹ Cf. the Wikipedia article on multilevel models for a collection of terms encountered when dealing with this subject, and, e.g., Gelman’s dissection of the various ways random effects are defined.

² We won’t overload this post by explicitly comparing results here, but we did that when writing the code.

³ Disclaimer: We have not verified whether this is an adequate model of the world, but it really doesn’t matter either.