About two weeks ago, we introduced TensorFlow Probability (TFP), showing how to create and sample from distributions and put them to use in a Variational Autoencoder (VAE) that learns its prior. Today, we move on to a different specimen in the VAE model zoo: the Vector Quantised Variational Autoencoder (VQ-VAE) described in Neural Discrete Representation Learning (Oord et al. 2017). This model differs from most VAEs in that its approximate posterior is not continuous, but discrete - hence the “quantised” in the article’s title. We’ll quickly look at what this means, and then dive directly into the code, combining Keras layers, eager execution, and TFP.
Discrete codes
Many phenomena are best thought of, and modeled, as discrete. This holds for phonemes and lexemes in language, higher-level structures in images (think objects instead of pixels), and tasks that necessitate reasoning and planning. The latent code used in most VAEs, however, is continuous - usually it’s a multivariate Gaussian. Continuous-space VAEs have been found very successful in reconstructing their input, but often they suffer from something called posterior collapse: The decoder is so powerful that it may create realistic output given just any input. This means there is no incentive to learn an expressive latent space.
In VQ-VAE, however, each input sample gets mapped deterministically to one of a set of embedding vectors 1. Together, these embedding vectors constitute the prior for the latent space. As such, an embedding vector contains a lot more information than a mean and a variance, and is thus much harder for the decoder to ignore.
The question then is: Where is that magical hat, for us to pull out meaningful embeddings?
Learning a discrete embedding space
From the above conceptual description, we now have two questions to answer. First, by what mechanism do we assign input samples (that went through the encoder) to appropriate embedding vectors? And second: How can we learn embedding vectors that actually are useful representations - that when fed to a decoder, will result in entities perceived as belonging to the same species?
As regards assignment, a tensor emitted from the encoder is simply mapped to its nearest neighbor in embedding space, using Euclidean distance. The embedding vectors are then updated using exponential moving averages 2. As we’ll see soon, this means that they are actually not being learned using gradient descent - a feature worth pointing out as we don’t come across it every day in deep learning.
Concretely, how then should the loss function and training process look? This is probably easiest to see in code.
Coding the VQ-VAE
The complete code for this example, including utilities for model saving and image visualization, is available on GitHub as part of the Keras examples. The order of presentation here may differ from the actual execution order for expository purposes, so to actually run the code, please consider using the example on GitHub.
Setup and data loading
As in all our prior posts on VAEs, we use eager execution, which presupposes the TensorFlow implementation of Keras.
As in our previous post on doing VAE with TFP, we’ll use Kuzushiji-MNIST (Clanuwat et al. 2018) as input. Now is the time to look at what we ended up generating that time and place your bet: How will that compare against the discrete latent space of VQ-VAE?
Hyperparameters
In addition to the “usual” hyperparameters we have in deep learning, the VQ-VAE infrastructure introduces a few model-specific ones. First of all, the embedding space has dimensionality number of embedding vectors times embedding vector size:
The latent space in our example will be of size one, that is, we have a single embedding vector representing the latent code for each input sample. This will be fine for our dataset, but it should be noted that van den Oord et al. used far higher-dimensional latent spaces on e.g. ImageNet and CIFAR-10 3.
Encoder model
The encoder uses convolutional layers to extract image features. Its output is a 3-d tensor of shape batchsize * 1 * code_size.
As always, let’s make use of the fact that we’re using eager execution, and see a few example outputs.
tf.Tensor(
[[[ 0.00516277 -0.00746826 0.0268365 ... -0.012577 -0.07752544
-0.02947626]]
...
[[-0.04757921 -0.07282603 -0.06814402 ... -0.10861694 -0.01237121
0.11455103]]], shape=(64, 1, 16), dtype=float32)
Now, each of these 16d vectors needs to be mapped to the embedding vector it is closest to. This mapping is taken care of by another model: vector_quantizer.
Vector quantizer model
This is how we will instantiate the vector quantizer:
This model serves two purposes: First, it acts as a store for the embedding vectors. Second, it matches encoder output to available embeddings.
Here, the current state of embeddings is stored in codebook. ema_means and ema_count are for bookkeeping purposes only (note how they are set to be non-trainable). We’ll see them in use shortly.
In addition to the actual embeddings, in its call method vector_quantizer holds the assignment logic.
First, we compute the Euclidean distance of each encoding to the vectors in the codebook (tf$norm).
We assign each encoding to the embedding that is closest by that distance (tf$argmin) and one-hot-encode the assignments (tf$one_hot). Finally, we isolate the corresponding vector by masking out all others and summing up what’s left over (multiplication followed by tf$reduce_sum).
Regarding the axis argument used with many TensorFlow functions, please take into consideration that in contrast to their k_* siblings, raw TensorFlow (tf$*) functions expect axis numbering to be 0-based. We also have to add the L’s after the numbers to conform to TensorFlow’s datatype requirements.
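To make the assignment logic concrete, here is a minimal sketch of the same three steps (distance, argmin plus one-hot, mask and sum). Note that this is plain NumPy in Python, not the R/TensorFlow code of this post; the function name quantize and the toy codebook values are our own assumptions, chosen purely for illustration.

```python
import numpy as np

def quantize(encodings, codebook):
    """Map each encoding to its nearest codebook vector (Euclidean distance).

    encodings: array of shape (batch, latent_size, code_size)
    codebook:  array of shape (num_codes, code_size)
    Hypothetical helper for illustration, not the post's vector_quantizer.
    """
    # Distance of every encoding to every code: (batch, latent_size, num_codes)
    diffs = encodings[:, :, None, :] - codebook[None, None, :, :]
    distances = np.linalg.norm(diffs, axis=3)
    # Index of the closest code for each encoding: (batch, latent_size)
    assignments = np.argmin(distances, axis=2)
    # One-hot encode the assignments: (batch, latent_size, num_codes)
    one_hot = np.eye(codebook.shape[0])[assignments]
    # Mask out all non-selected codes, then sum over the code axis
    nearest = np.sum(one_hot[..., None] * codebook[None, None, :, :], axis=2)
    return nearest, one_hot

# Toy example: 3 codes of size 2, a batch of 2 encodings, latent size 1
codebook = np.array([[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]])
encodings = np.array([[[0.9, 1.2]], [[-0.1, 0.2]]])
nearest, one_hot = quantize(encodings, codebook)
print(nearest[:, 0, :])  # first encoding maps to code [1, 1], second to [0, 0]
```

The mask-and-sum step is equivalent to indexing the codebook directly with the argmin result; the one-hot form is kept because the one-hot assignments are needed again when the moving averages are updated.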
Now that we’ve seen how the codes are stored, let’s add functionality for updating them. As we said above, they are not learned via gradient descent. Instead, they are exponential moving averages, continually updated by whatever new “class member” they get assigned.
So here is a function update_ema that will take care of this.
update_ema uses TensorFlow moving_averages to
- first, keep track of the number of currently assigned samples per code (updated_ema_count), and
- second, compute and assign the current exponential moving average (updated_ema_means).
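The two bookkeeping steps can be sketched as follows. Again, this is a NumPy illustration under our own assumptions rather than the post's R function: the decay value, the small eps added to avoid division by zero, and the choice to track a running sum of assigned encodings that is then divided by the running count are all hypothetical details.

```python
import numpy as np

def update_ema(ema_count, ema_means, codes, one_hot, decay=0.99, eps=1e-5):
    """One exponential-moving-average step (illustrative sketch).

    ema_count: (num_codes,)            running count of assignments per code
    ema_means: (num_codes, code_size)  running sum of assigned encodings
    codes:     (batch, latent_size, code_size)  encoder outputs
    one_hot:   (batch, latent_size, num_codes)  one-hot assignments
    decay and eps are assumed values, not taken from the post.
    """
    # Number of encodings assigned to each code in this batch
    batch_count = one_hot.sum(axis=(0, 1))
    # Sum of the encodings assigned to each code in this batch
    batch_sum = np.einsum("blk,blc->kc", one_hot, codes)
    # Moving averages: new = decay * old + (1 - decay) * current
    new_count = decay * ema_count + (1.0 - decay) * batch_count
    new_means = decay * ema_means + (1.0 - decay) * batch_sum
    # New codebook: averaged sum divided by averaged count
    codebook = new_means / (new_count + eps)[:, None]
    return new_count, new_means, codebook

# Toy example: both encodings in the batch are assigned to code 1
ema_count = np.array([1.0, 1.0])
ema_means = np.array([[0.0, 0.0], [2.0, 2.0]])
codes = np.array([[[1.0, 1.0]], [[3.0, 3.0]]])
one_hot = np.array([[[0.0, 1.0]], [[0.0, 1.0]]])
new_count, new_means, codebook = update_ema(
    ema_count, ema_means, codes, one_hot, decay=0.5
)
```

No gradients are involved anywhere in this update, which is why it has to be called explicitly during training.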
Before we look at the training loop, let’s quickly complete the scene by adding the last actor, the decoder.
Decoder model
The decoder is pretty standard, performing a series of deconvolutions and finally, returning a probability for each image pixel.
Now we’re ready to train. One thing we haven’t really talked about yet is the cost function: Given the differences in architecture (compared to standard VAEs), will the losses still look as expected (the usual add-up of reconstruction loss and KL divergence)? We’ll see that in a second.
Training loop
Here’s the optimizer we’ll use. Losses will be calculated inline.
The training loop, as usual, is a loop over epochs, where each iteration is a loop over batches obtained from the dataset.
For each batch, we have a forward pass, recorded by a gradientTape, based on which we calculate the loss.
The tape will then determine the gradients of all trainable weights throughout the model, and the optimizer will use those gradients to update the weights.
So far, all of this conforms to a scheme we’ve oftentimes seen before. One point to note though: In this same loop, we also call update_ema to recalculate the moving averages, as those are not operated on during backprop.
Here is the essential functionality: 4
Now, for the actual action. Inside the context of the gradient tape, we first determine which encoded input sample gets assigned to which embedding vector.
Now, for this assignment operation there is no gradient. Instead what we can do is pass the gradients from decoder input straight through to encoder output.
Here tf$stop_gradient exempts nearest_codebook_entries from the chain of gradients, so encoder and decoder are linked by codes:
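The straight-through idea can be written compactly. The notation here is our own shorthand, not taken from the code: z_e stands for the encoder output, e_k for the nearest codebook entry, z_q for the decoder input, and sg for the stop-gradient operator (identity in the forward pass, zero gradient in the backward pass).

```latex
z_q = z_e + \operatorname{sg}\left[ e_k - z_e \right],
\qquad
\text{forward: } z_q = e_k,
\qquad
\text{backward: } \frac{\partial z_q}{\partial z_e} = I
```

In the forward pass the two z_e terms cancel, so the decoder sees the codebook entry. In the backward pass the bracketed term contributes nothing, so whatever gradient arrives at the decoder input is handed unchanged to the encoder output.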
In sum, backprop will take care of the decoder’s as well as the encoder’s weights, whereas the latent embeddings are updated using moving averages, as we’ve seen already.
Now we’re ready to tackle the losses. There are three components:
- First, the reconstruction loss, which is just the negative log probability of the actual input under the distribution learned by the decoder.
- Second, we have the commitment loss, defined as the mean squared deviation of the encoded input samples from the nearest neighbors they’ve been assigned to: We want the network to “commit” to a concise set of latent codes!
- Finally, we have the usual KL divergence to a prior. As, a priori, all assignments are equally probable, this component of the loss is constant and can often be dispensed with. We’re adding it here mainly for illustrative purposes.
Summing up all three components, we arrive at the overall loss 5:
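As a rough numerical illustration of how the three components combine, here is a NumPy sketch. It is not the post's R/TFP code: the function name vq_vae_loss is hypothetical, we assume a Bernoulli likelihood over flattened binary pixels, we assume (following the paper) that beta scales the commitment term, and the default beta of 0.25 is likewise an assumption.

```python
import numpy as np

def vq_vae_loss(x, pixel_probs, codes, nearest, num_codes, beta=0.25):
    """Combine the three loss components (illustrative sketch).

    x:           (batch, pixels) binary input images, flattened
    pixel_probs: (batch, pixels) per-pixel probabilities from the decoder
    codes:       (batch, latent_size, code_size) encoder outputs
    nearest:     (batch, latent_size, code_size) assigned codebook entries
    """
    # Reconstruction: negative Bernoulli log probability of the input,
    # summed over pixels and averaged over the batch
    p = np.clip(pixel_probs, 1e-7, 1.0 - 1e-7)
    log_prob = np.sum(x * np.log(p) + (1.0 - x) * np.log(1.0 - p), axis=1)
    reconstruction_loss = -np.mean(log_prob)
    # Commitment: mean squared deviation of encodings from their assigned codes
    commitment_loss = np.mean((codes - nearest) ** 2)
    # Prior: under a uniform prior over num_codes codes, every assignment has
    # negative log probability log(num_codes), so this term is a constant
    latent_size = codes.shape[1]
    prior_loss = latent_size * np.log(num_codes)
    total = reconstruction_loss + beta * commitment_loss + prior_loss
    return total, reconstruction_loss, commitment_loss, prior_loss

# Toy example: one image with two pixels, one latent of size 2, two codes
x = np.array([[1.0, 0.0]])
pixel_probs = np.array([[0.5, 0.5]])
codes = np.array([[[1.0, 1.0]]])
nearest = np.array([[[1.0, 3.0]]])
total, rec, commit, prior = vq_vae_loss(
    x, pixel_probs, codes, nearest, num_codes=2, beta=0.5
)
```

Because the prior term does not depend on any weights, dropping it changes the reported loss value but not the gradients.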
Before we look at the results, let’s see what happens inside gradientTape at a single glance:
Results
And here we go. This time, we can’t have the 2d “morphing view” one generally likes to display with VAEs (there just is no 2d latent space). Instead, the two images below are (1) letters generated from random input and (2) reconstructed actual letters, each saved after training for nine epochs.
Two things jump to the eye: First, the generated letters are significantly sharper than their continuous-prior counterparts (from the previous post). And second, would you have been able to tell the random image from the reconstruction image?
Conclusion
At this point, we’ve hopefully convinced you of the power and effectiveness of this discrete-latents approach. However, you might secretly have hoped we’d apply this to more complex data, such as the elements of speech we mentioned in the introduction, or higher-resolution images as found in ImageNet. 6
The truth is that there’s a continuous tradeoff between the number of new and exciting techniques we can show, and the time we can spend on iterations to successfully apply these techniques to complex datasets. In the end it’s you, our readers, who will put these techniques to meaningful use on relevant, real world data.
Clanuwat, Tarin, Mikel Bober-Irizar, Asanobu Kitamoto, Alex Lamb, Kazuaki Yamamoto, and David Ha. 2018. “Deep Learning for Classical Japanese Literature.” December 3. https://arxiv.org/abs/cs.CV/1812.01718.
Oord, Aaron van den, Oriol Vinyals, and Koray Kavukcuoglu. 2017. “Neural Discrete Representation Learning.” CoRR abs/1711.00937. http://arxiv.org/abs/1711.00937.
1. Assuming a 1d latent space, that is. The authors actually used 1d, 2d and 3d spaces in their experiments.
2. In the paper, the authors actually mention this as one of two ways to learn the prior, the other one being vector quantisation.
3. To be specific, the authors indicate that they used a field of 32 x 32 latents for ImageNet, and 8 x 8 x 10 for CIFAR-10.
4. The code on GitHub additionally contains functionality to display generated images, output the losses, and save checkpoints.
5. Here beta is a scaling parameter found surprisingly unimportant by the paper authors.
6. Although we have to say we find that Kuzushiji-MNIST beats MNIST by far, in complexity and aesthetics!