Teaching a Computer How to Write

87 min read

This is a relatively long post! If you're trying to learn from 0 -> 1, I would encourage you to read the whole thing, but feel free to jump around as you wish. There are three main portions: concept, theory, and code.

My purpose here was to build up from the basics and really understand the flow. I provide quite a few models so we can see the progression: a simple neural net, a basic LSTM, a Peephole LSTM, a stacked cascade of Peephole LSTMs, Mixture Density Networks, the Attention Mechanism, an Attention RNN, the Handwriting Prediction Network, and finally all of it thrown together into the full Handwriting Synthesis Network that Graves originally wrote about.

There are other things that I may discuss in the future, like the need to pickle JAX models (because if they're XLA compatible then you can't run inference on your CPU) and issues like that. Another thing I didn't really discuss was temperature and bias for sampling. I also (sadly) didn't cover priming. However, I spent far more time on this than I should have. As always, if you have any questions, feel free to reach out.

Enjoy!

One thing I would highly recommend: if you're interested in the theory of LSTMs and why sigmoid vs. tanh activations were chosen, read Chris Olah's Understanding LSTMs blog post. It does a fantastic job.


✍️ Motivating Visualizations

Today, we're going to learn how to teach a computer to write. I don't mean generating text (which probably would have been a better thing to study in college). I mean learning to write the way a human learns to write with pen and paper. My results were (eventually) pretty good, so here are some motivating visualizations.

Let’s look at one. My family used to have this hung over our kitchen sink when I was a kid. I ate breakfast every day looking at it.

heart-writing-cleansed

heart-writing-gif

heart-mdn-aggregate

heart-attention-gif

heart-mdn

heart-sampling-gif

Again, I'd recommend jumping down to the Synthesis Model Sampling section, which is arguably the best part of this post. That is where I discuss what all these visualizations mean in detail.


🥅 Motivation

The motivation is clear: this is something I have wanted to find the time to do right since college. My engineering thesis was on this Graves paper. During my senior year, I worked with my good friend Tom Wilmots (also a brilliant engineer) to understand and dive into this paper.

I'm going to pull pieces of that work, but times have changed. I wanted to revisit some of what we did, hopefully clean it up, and finally put a nail in this (so my girlfriend / friends don't have to keep hearing about it).

👨‍🏫 History

In college, Tom and I were very interested in the concept of teaching a computer how to write. There is a very famous paper, published in 2013 by Canadian computer scientist Alex Graves, titled Generating Sequences With Recurrent Neural Networks. At Swarthmore, every engineering student has to complete a thesis, called an E90. It's basically a year-long project (although I'd argue it's more like a semester when it all shakes out) focused on producing a piece of work you're proud of.

Tom's and My Engineering Thesis

For the actual paper that we wrote, check it out here: Application of Neural Networks with Handwriting Samples.

🙏 Acknowledgements

Before I dive in, I want to make some acknowledgements, given that this is partly a resumption of earlier work.

  • Tom Wilmots - One of the brightest and best engineers I’ve worked with. He was an Engineering and Economics double major from Swarthmore. Pretty sure I would have failed my E90 thesis without him.
  • Matt Zucker - One of my role models and constant inspirations. Matt was kind enough to be Tom's and my academic advisor for this final engineering project. He is the best professor I've come across.
  • Alex Graves - A researcher whom both Tom and I had the pleasure of working with. He responded to our emails, which I'm still very appreciative of. You can see more about his work at the University of Toronto here. He is the author of this paper, which Matt found for us and which was pretty much the basis of our project. He's also the creator of the Neural Turing Machine, which piqued my interest after I took Theory of Computation with another fantastic professor, Lila Fontes, and learned about Turing machines.
  • David Ha - Another brilliant scientist we had the privilege of corresponding with. Check out his blog here. It's beautiful. He is also very prolific on arXiv, which is always cool to see.

📝 Concept

This section is for non-technical readers who want to understand what we were trying to do. It's relatively simple. At a very high level, we are trying to teach a computer how to generate human-looking handwriting. To do that, we are going to train a neural network on a public dataset called the IAM Online Handwriting Database. For this dataset, a ton of people wrote on an electronic whiteboard that recorded their pen movements. It basically collected sets of stroke data, which are tuples of $(x, y, t)$, where $(x, y)$ are the pen coordinates and $t$ is the timestamp. We'll use this data to train a model so that, across all of the participants, we get a blended sense of how to write like a human.

👾 Software

In college, we had to decide between Tensorflow and Pytorch, and we went with Tensorflow. This time, I wanted to resume our Tensorflow approach with updated designs, but I also wanted to try JAX. JAX is… newer. But it's gotten some hype online, and I think there's a solid amount of adoption across the bigger AI labs now. In my opinion, Tensorflow is dying, Pytorch is the new status quo, and JAX is the new kid on the block. However, I'm not an ML researcher clearing millions of dollars, so take that with a grain of salt. This clickbaity article, which declares "Pytorch is dead. Long live JAX", got a ton of flak online, but regardless… it piqued my interest enough to try it here.

I'll touch on all three frameworks here and probably dive deepest into Tensorflow… but feel free to skip this section.

Tensorflow

Programming Paradigm

Tensorflow has an interesting programming paradigm in which you are more or less creating a graph. You define Tensors, and nothing is actually evaluated until you run your dependency graph.

I have this quote from the Tensorflow API:

There’s only two things that go into Tensorflow.

  1. Building your computational dependency graph.
  2. Running your dependency graph.

That was the old way, and it's no longer totally true. Tensorflow 2.0 made eager execution the default, so operations run immediately instead of waiting for a graph to be built and then run all at once. You can still opt back into graph compilation with tf.function.
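To make the "build the graph, then run it" idea concrete, here's a toy sketch in plain Python. It is not Tensorflow code: `Node`, `const`, `add`, `mul`, and `run` are hypothetical names that mimic the old graph style, shown next to the eager style.

```python
class Node:
    """A deferred computation: an op plus the nodes it depends on."""

    def __init__(self, op, *inputs):
        self.op = op
        self.inputs = inputs


def const(value):
    return Node(lambda: value)


def add(a, b):
    return Node(lambda x, y: x + y, a, b)


def mul(a, b):
    return Node(lambda x, y: x * y, a, b)


def run(node):
    """Walk the dependency graph, evaluating parents first (the old 'session.run' step)."""
    return node.op(*(run(parent) for parent in node.inputs))


# Step 1: build the graph. Nothing is computed yet.
graph = add(mul(const(3.0), const(4.0)), const(1.0))

# Step 2: run the graph.
graph_result = run(graph)

# Eager style: each expression computes its value immediately.
eager_result = 3.0 * 4.0 + 1.0
```

Both paths produce the same number. The difference is *when* the work happens, which is what made debugging the old graph mode so painful.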

Versions - How the times have changed

So, another fun fact: when we were doing this in college, we were on Tensorflow v0.11!!! They hadn't even released 1.0 yet. Now, I'm doing this on Tensorflow 2.16.1, so the times have definitely changed.

being-old

Definitely haven’t been able to keep up with all those changes.

Tensorboard

Another cool thing about Tensorflow worth mentioning is TensorBoard. It's a visualization suite that serves a local website where you can interactively visualize your dependency graph and watch training metrics stream in live. You can do cool things like confirm that the error is actually decreasing over the epochs.

We used this a bit more in college. I didn't get a real chance to dive into the updates made to it since then.

Pytorch

PyTorch is now basically the de facto standard for most serious research labs and AI shops. To me, it seems like models are still sometimes ported to Tensorflow for production, but I'm not totally sure about the convention.

Pytorch seems to sit somewhere between Tensorflow and JAX. Functions don't need to be pure to be used, and you can loop and mutate state in an nn.Module just fine.

I won't be covering Pytorch here, but I will certainly come back around to it in later projects.

JAX

The new up-and-comer! I think it's largely a crowd favorite for its speed. The documentation is noticeably worse. One Redditor summarized it nicely:

(Comment by u/Few-Pomegranate4369 from a discussion in r/MachineLearning)


I hit numerous roadblocks where functions weren’t actually pure and then the JIT compile portion basically failed on startup.

Programming Paradigm

JAX and Pytorch are definitely the closest to traditional imperative Python. JAX's main restriction is that functions must be pure. Tensorflow is also gradually moving away from the compile-your-graph-then-run-it paradigm.

📊 Data

We’re using the IAM Online Handwriting Database. Specifically, I’m looking at data/lineStrokes-all.tar.gz, which is XML data that looks like this:

data

Example Handwriting IAM Data
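Here's a rough, standard-library-only sketch of turning one of these XML files into model inputs. I'm assuming a `StrokeSet > Stroke > Point` layout with `x`, `y`, and `time` attributes, so check the tag names against your copy of lineStrokes. The points are converted into the $(\Delta x, \Delta y, \text{eos})$ offsets that Graves' paper feeds the network, where eos marks the last point of each stroke. `parse_strokes`, `to_offsets`, and the sample XML are hypothetical, not the post's actual preprocessing.

```python
import xml.etree.ElementTree as ET

# A tiny made-up file in the assumed lineStrokes layout
SAMPLE_XML = """
<WhiteboardCaptureSession>
  <StrokeSet>
    <Stroke>
      <Point x="100" y="200" time="0.00"/>
      <Point x="103" y="198" time="0.01"/>
      <Point x="107" y="197" time="0.02"/>
    </Stroke>
    <Stroke>
      <Point x="120" y="205" time="0.30"/>
      <Point x="121" y="210" time="0.31"/>
    </Stroke>
  </StrokeSet>
</WhiteboardCaptureSession>
"""


def parse_strokes(xml_text):
    """Return a list of strokes, each a list of (x, y, t) tuples."""
    root = ET.fromstring(xml_text.strip())
    strokes = []
    for stroke in root.iter("Stroke"):
        points = [
            (float(p.get("x")), float(p.get("y")), float(p.get("time")))
            for p in stroke.iter("Point")
        ]
        if points:
            strokes.append(points)
    return strokes


def to_offsets(strokes):
    """Flatten strokes into (dx, dy, eos) rows; eos = 1.0 on each stroke's last point."""
    rows = []
    prev_x, prev_y = strokes[0][0][0], strokes[0][0][1]
    for stroke in strokes:
        for i, (x, y, _t) in enumerate(stroke):
            eos = 1.0 if i == len(stroke) - 1 else 0.0
            rows.append((x - prev_x, y - prev_y, eos))
            prev_x, prev_y = x, y
    return rows


offsets = to_offsets(parse_strokes(SAMPLE_XML))
```

The offset for the first point of a new stroke (here `(13.0, 8.0, 0.0)`) encodes the pen jump between strokes, so the model learns where to put the pen back down.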


There’s also this note:

The database is divided into 4 parts, a training set, a first validation set, a second validation set and a final test set. The training set may be used for training the recognition system, while the two validation sets may be used for optimizing some meta-parameters. The final test set must be left unseen until the final test is performed. Note that you are allowed to use also other data for training etc, but report all the changes when you publish your experimental results and let the test set unchanged (It contains 3859 sequences, i.e. XML-files - one for each text line).

So that determines our splits: a training set, two validation sets, and a final test set.

🧠 Base Neural Network Theory

I am not going to dive into details as much as we did for our senior E90 thesis, but I do want to cover a couple of the building blocks.

Lions, Bears, and Many Neural Networks, oh my

I would highly encourage you to check out this website: https://www.asimovinstitute.org/neural-network-zoo/. I remember seeing it in college when working on this thesis and was stunned. If you’re too lazy to click, check out the fun picture:

neural-network-zoo

Courtesy of Asimov Institute


We’re going to explore some of the zoo in a bit more detail, specifically, focusing on LSTMs.

Basic Neural Network

basic-nn

Courtesy of AI ML


The core structure of a neural network is the set of connections between its neurons. Each connection carries an activation signal of varying strength. If the incoming signal to a neuron is strong enough, the signal is propagated through the next stages of the network.

There is an input layer that feeds the data into the hidden layer. The outputs from the hidden layer are then passed to the output layer. Every connection between nodes carries a weight that determines how much information gets passed through.

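Hyperparameters

A basic neural network is generally defined by three things:

  • the pattern of connections between all neurons (a hyperparameter: the architecture)
  • the weights of the connections between neurons (learned during training, not a hyperparameter)
  • the activation functions of the neurons (a hyperparameter)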

In our project however, we focus on a specific class of neural networks called Recurrent Neural Networks (RNNs), and the more specific variation of RNNs called Long Short Term Memory networks (LSTMs).

However, let's give a bit more context first. There are really two broad types of neural networks: feedforward and recurrent.

Feedforward Neural Network

These neural networks channel information in one direction.

The figure above shows a feedforward neural network. Its connections contain no cycles, so the same input data is never seen multiple times by the same node.

These networks are well suited to mapping raw data to categories, for example, classifying a face from an image.

Every node outputs a numerical value that it then passes to all its successor nodes. In other words:

\[\begin{align} y_j = f(x_j) \end{align} \tag{1}\]

where

\[\begin{align} x_j = \sum_{i \in P_j} w_{ij} y_i \end{align} \tag{2}\]

where

  • $y_j$ is the output of node $j$
  • $x_j$ is the total weighted input for node $j$
  • $w_{ij}$ is the weight from node $i$ to node $j$
  • $y_i$ is the output from node $i$
  • $P_j$ represents the set of predecessor nodes to node $j$

Also note that $f(x)$ should be a smooth non-linear activation function that maps outputs to a reasonable range. Common activation functions include $\tanh(x)$ and the sigmoid function $\sigma(x) = 1 / (1 + e^{-x})$. The non-linearity is necessary because the network is trying to learn a non-linear pattern. Without it, stacking layers would just compose linear maps into another linear map.
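As a sanity check on equations (1) and (2), here's a pure-Python forward pass. Biases are omitted to match eq. (2), and the weights are made up for illustration. Each row of a layer's weight list holds the incoming weights $w_{ij}$ for one node $j$.

```python
import math


def sigmoid(x):
    """Logistic sigmoid, squashing any real number into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def forward(inputs, layers, activation=math.tanh):
    """Feedforward pass with no biases, following eqs. (1) and (2).

    Each layer is a list of rows; row j holds the weights w_ij from every
    node i in the previous layer into node j of this layer.
    """
    y = list(inputs)
    for weights in layers:
        # x_j = sum_i w_ij * y_i (eq. 2), then y_j = f(x_j) (eq. 1)
        y = [activation(sum(w_ij * y_i for w_ij, y_i in zip(row, y))) for row in weights]
    return y


# 2 inputs -> 3 hidden nodes -> 1 output node, with made-up weights
hidden_layer = [[0.5, -0.2], [0.1, 0.4], [-0.3, 0.8]]
output_layer = [[1.0, -1.0, 0.5]]
y_tanh = forward([1.0, 2.0], [hidden_layer, output_layer])
y_sigmoid = forward([1.0, 2.0], [hidden_layer, output_layer], activation=sigmoid)
```

Swapping the activation changes the output range: tanh lands in $(-1, 1)$ and the sigmoid in $(0, 1)$.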

Backpropagation

Backpropagation is the mechanism by which we pass the error back through the network, starting at the output nodes. Generally, we minimize that error using stochastic gradient descent. There are lots of different ways to define the error, but we can use the sum of squared residuals between the targets $t_k$ and the outputs $y_k$ of the output nodes:

\[\begin{align} E = \frac{1}{2} \sum_{k}(t_k - y_k)^2 \end{align} \tag{3}\]

The gradient descent part comes in next. We compute the gradient of the error with respect to every weight and step along the negative gradient, which is the direction that decreases the error:

\[\begin{align} g_{ij} = - \frac{\partial E}{\partial w_{ij}} \end{align} \tag{4}\]

So overall, we keep nudging each weight a small step (scaled by a learning rate) along $g_{ij}$, which reduces its contribution to the overall error of the outputs.
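Here's a minimal sketch of equations (3) and (4) for the simplest possible case: a single linear output node $y = \sum_i w_i x_i$, trained with per-sample (stochastic) gradient descent on made-up data. A real network would also backpropagate through its hidden layers. This sketch only shows the update rule and the error going down.

```python
def squared_error(targets, outputs):
    """Eq. (3): E = 1/2 * sum_k (t_k - y_k)^2."""
    return 0.5 * sum((t - y) ** 2 for t, y in zip(targets, outputs))


def train_linear_node(samples, weights, learning_rate=0.1, epochs=50):
    """Fit one linear node y = sum_i w_i * x_i with per-sample gradient descent.

    For this node dE/dw_i = -(t - y) * x_i, so the negative gradient g_i from
    eq. (4) is (t - y) * x_i, and each update is w_i <- w_i + learning_rate * g_i.
    Returns the final weights and the summed error seen during each epoch.
    """
    weights = list(weights)
    history = []
    for _ in range(epochs):
        epoch_error = 0.0
        for x, t in samples:
            y = sum(w * xi for w, xi in zip(weights, x))
            epoch_error += squared_error([t], [y])
            g = [(t - y) * xi for xi in x]
            weights = [w + learning_rate * gi for w, gi in zip(weights, g)]
        history.append(epoch_error)
    return weights, history


# made-up targets generated by t = 2 * x1 - 1 * x2
data = [((1.0, 0.0), 2.0), ((0.0, 1.0), -1.0), ((1.0, 1.0), 1.0)]
w, errors = train_linear_node(data, [0.0, 0.0])
```

After training, `w` should end up close to `[2.0, -1.0]`, and `errors` should shrink epoch over epoch.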

The major downfall of this simple network is that it doesn't have full context. With sequences, it has no information about the previous inputs (e.g. the previous words), so the context is missing. That leads us to our next structure.

Recurrent Neural Network

Recurrent Neural Networks (RNNs) have a capacity to remember. This memory stems from the fact that their input is not only the current input vector but also a variation of what they output at previous time steps.

This visualization from Christopher Olah (who, holy hell, I just realized is a co-founder of Anthropic, but whom Tom and I used to follow closely in college) is great:

rnn-unrolled

Courtesy of Chris Olah's Understanding LSTMs


This RNN module is being unrolled over multiple timesteps. Information is passed from the module at time step $t$ to the module at $t+1$.
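Here's a scalar sketch of that unrolling, using made-up weights and a single hidden unit: $h_t = \tanh(w_x x_t + w_h h_{t-1} + b)$. Feeding in two sequences that end with the same inputs but differ at the start shows the memory at work, because the final hidden state still depends on what came earlier.

```python
import math


def rnn_unroll(xs, w_x=0.5, w_h=0.8, b=0.0, h0=0.0):
    """Unroll a scalar vanilla RNN: h_t = tanh(w_x * x_t + w_h * h_{t-1} + b)."""
    h = h0
    hidden_states = []
    for x_t in xs:
        # the previous hidden state h is fed back in alongside the new input
        h = math.tanh(w_x * x_t + w_h * h + b)
        hidden_states.append(h)
    return hidden_states


# Same last two inputs, different history -> different final hidden state
with_history = rnn_unroll([1.0, 0.0, 0.0])
no_history = rnn_unroll([0.0, 0.0, 0.0])
```

Notice that the influence of that first `1.0` shrinks at every step. That decay is a small preview of the long-term dependency problem discussed next.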

Per Tom's and my paper,

An ideal RNN would theoretically be able to remember as far back as was necessary in order to make an accurate prediction. However, as with many things, the theory does not carry over to reality. RNNs have trouble learning long term dependencies due to the vanishing gradient problem. An example of such a long term dependency might be if we are trying to predict the last word in the following sentence: "My family originally comes from Belgium so my native language is PREDICTION". A normal RNN would possibly be able to recognize that the prediction should be a language, but it would need the earlier context of Belgium to be able to accurately predict DUTCH.

Topically, this is why the craze around LLMs is so impressive. There’s a lot more going on with LLMs… which… I will not cover here.

The notion of backpropagation is basically the same, except that we also have the added dimension of time.

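The crux of the issue is that an RNN unrolled over many timesteps behaves like a very deep network. Backpropagating through it multiplies many derivatives together, and when those factors are smaller than one, the product is pushed toward zero. The gradients become too small (eventually underflowing), so in practice the network ceases to be able to learn long-range dependencies.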
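Here's a deliberately simplified scalar sketch of why the gradients vanish. Backpropagating through $T$ sigmoid steps multiplies $T$ local derivatives together. The sketch holds the pre-activation at 0, where the sigmoid's derivative reaches its maximum of 0.25, so this is roughly the best case. With a recurrent weight of 1, the gradient is exactly $0.25^T$. If the recurrent weight makes each factor bigger than 1, you get the opposite problem: exploding gradients.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def sigmoid_grad(x):
    """Derivative of the sigmoid; its maximum is 0.25, reached at x = 0."""
    s = sigmoid(x)
    return s * (1.0 - s)


def gradient_through_time(steps, recurrent_weight=1.0, pre_activation=0.0):
    """Toy d h_T / d h_0 for h_t = sigmoid(w * h_{t-1}), with the pre-activation held fixed.

    The chain rule contributes one factor w * sigmoid'(z) per timestep.
    """
    grad = 1.0
    for _ in range(steps):
        grad *= recurrent_weight * sigmoid_grad(pre_activation)
    return grad


short_range = gradient_through_time(10)   # 0.25 ** 10, already below 1e-6
long_range = gradient_through_time(100)   # 0.25 ** 100, effectively zero
exploding = gradient_through_time(10, recurrent_weight=8.0)  # each factor is 2.0
```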

However, Sepp Hochreiter and Juergen Schmidhuber developed the Long Short Term Memory (LSTM) unit to address this vanishing gradient problem.

Long Short Term Memory Networks

Long Short Term Memory (LSTM) networks are specifically designed to learn long term dependencies.

Every form of RNN has repeating modules that pass information across timesteps, and LSTMs are no different. Where they differ is the inner structure of each module: while a standard RNN might have a single neural layer, LSTMs have four.

lstm-viz

Courtesy of Chris Olah's Understanding LSTMs


Understanding the LSTM Structure

So let's better understand the structure above. Olah's post has a far more comprehensive walkthrough, which I'd encourage you to read.

lstm-viz

Courtesy of Chris Olah's Understanding LSTMs


The top line is key to the LSTM’s ability to remember. It is called the cell state. We’ll reference it as $C_t$.

The first neural network layer is a sigmoid layer. It takes as input the concatenation of the current input $x_t$ and the output of the previous module $h_{t-1}$. This is called the forget gate, and it controls what to forget from the cell state. A sigmoid is a good architecture decision here because it outputs numbers in $[0, 1]$, indicating how much of each component the layer should let through.

We element-wise multiply the output of this sigmoid layer, $f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)$, with the cell state from the previous module, $C_{t-1}$, forgetting the things the layer doesn't see as important.

Then, right in the center of the image above, there are two neural network layers that make up the update gate. First, the concatenation $[h_{t-1}, x_t]$ is pushed through both a sigmoid ($\sigma$) layer and a $\tanh$ layer. The output of the sigmoid layer, $i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)$, determines which values to update. The output of the $\tanh$ layer, $\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)$, proposes candidate values for the new cell state. These two results are element-wise multiplied and added to the current cell state (which we just edited using the forget gate), giving the new cell state $C_t = f_t * C_{t-1} + i_t * \tilde{C}_t$.

The final neural network layer is called the output gate. It determines the relevant portion of the cell state to output as $h_t$. Once again, we feed $[h_{t-1}, x_t]$ through a sigmoid layer whose output, $o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)$, we element-wise multiply with $\tanh(C_t)$ to get $h_t = o_t * \tanh(C_t)$, the output of the LSTM module. Note that this $\tanh$ is not a neural network layer (it has no weights). It is an element-wise operation that squashes the cell state into the range $(-1, 1)$.
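To tie those equations together, here's a scalar (one-unit) sketch of a single LSTM step in plain Python. The `params` layout is my own choice for illustration. Each gate gets a weight on $h_{t-1}$, a weight on $x_t$, and a bias, which is what $W \cdot [h_{t-1}, x_t] + b$ reduces to with one unit.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def lstm_step(x_t, h_prev, c_prev, params):
    """One scalar LSTM step in Olah's notation.

    params maps each gate name ("f", "i", "c", "o") to (w_h, w_x, b), so
    W . [h_{t-1}, x_t] + b becomes w_h * h_prev + w_x * x_t + b.
    """
    def affine(gate):
        w_h, w_x, b = params[gate]
        return w_h * h_prev + w_x * x_t + b

    f_t = sigmoid(affine("f"))          # forget gate: how much of c_prev to keep
    i_t = sigmoid(affine("i"))          # update gate: how much of the candidate to add
    c_tilde = math.tanh(affine("c"))    # candidate cell values
    c_t = f_t * c_prev + i_t * c_tilde  # new cell state
    o_t = sigmoid(affine("o"))          # output gate
    h_t = o_t * math.tanh(c_t)          # new hidden state
    return h_t, c_t


# Forget gate pinned open, update gate pinned shut: the cell "remembers" c_prev
memory_params = {
    "f": (0.0, 0.0, 10.0),
    "i": (0.0, 0.0, -10.0),
    "c": (0.0, 1.0, 0.0),
    "o": (0.0, 0.0, 0.0),
}
h, c = lstm_step(x_t=5.0, h_prev=0.0, c_prev=2.0, params=memory_params)
```

With those biases, the cell state stays at about 2.0 even though the input is large. That is the "remembering" behavior the cell state line at the top of the diagram exists for.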

I'm serious... you guys should check out Olah's Understanding LSTMs blog post. Here he is back in 2015, strongly foreshadowing transformers with his focus on attention (which is truly the hardest part of all this).

olah-attention

Courtesy of Chris Olah's Understanding LSTMs



🧬 Concepts to Code

When I first started this project, I figured I would be able to reuse some of my college code, but looking back, it's quite a mess, and I don't think that's the way to go about it.

I thought for a while about how best to structure this part, meaning both the code and how to present it in this blog post. With all the buzz about JAX, I wanted to try that too, so I thought it'd be helpful to show a side-by-side translation of the Tensorflow vs. JAX code. My hope is that we'll walk through the concepts, build a good understanding of the theory, and then the code will make a bit more sense. One note: I was a bit burnt out on this project by the end, so in the JAX code I tried to lean on optax and flax as much as possible to cut down on bulk.

So we’ll walk through the building blocks (in terms of code) and then show the code translations.

LSTM Cell with Peephole Connections

Theory

The basic LSTM cell (tf.keras.layers.LSTMCell) doesn't support peephole connections, so the Tensorflow version below implements its own cell.

Judging by the very functional code that sjvasquez wrote, I don't think we actually need peepholes, but I figured it would be fun to implement them regardless. Back in the old days, when Tensorflow still supported add-ons, there was some work on this here, but that project was deprecated.

That being said, the JAX / Flax code also doesn't have LSTMs with peepholes out of the box, so… I just used the normal ones. The JAX model actually trained a bit better, but I think part of that was also just patience.
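For reference, here are the peephole LSTM equations as I read them in Graves' paper (eqs. 7-11), which the Tensorflow cell below implements. The peephole terms are the $W_{c\cdot}$ products. Because those matrices are diagonal, each one reduces to an element-wise product with the cell state, which is why the code stores them as the three columns of `peephole_weights`. Without peepholes (the Flax version), those terms simply drop out.

\[\begin{aligned} i_t &= \sigma\left(W_{xi} x_t + W_{hi} h_{t-1} + W_{ci} c_{t-1} + b_i\right) \\ f_t &= \sigma\left(W_{xf} x_t + W_{hf} h_{t-1} + W_{cf} c_{t-1} + b_f\right) \\ c_t &= f_t c_{t-1} + i_t \tanh\left(W_{xc} x_t + W_{hc} h_{t-1} + b_c\right) \\ o_t &= \sigma\left(W_{xo} x_t + W_{ho} h_{t-1} + W_{co} c_t + b_o\right) \\ h_t &= o_t \tanh(c_t) \end{aligned}\]

Note that the output gate peeks at the *new* cell state $c_t$, while the input and forget gates peek at $c_{t-1}$. The code mirrors this ordering by adding the output peephole only after `c` is computed.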

Code

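import tensorflow as tf
from typing import Tuple


# method of the custom peephole LSTM cell (a tf.keras.layers.Layer subclass);
# build() (not shown) creates self.kernel, self.recurrent_kernel, self.bias,
# and self.peephole_weights with shape (num_lstm_units, 3)
def call(self, inputs: tf.Tensor, state: Tuple[tf.Tensor, tf.Tensor]):
    """
    This is basically implementing Graves's equations on page 5
    https://www.cs.toronto.edu/~graves/preprint.pdf
    equations 7-11.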

    From the paper,
    * sigma is the logistic sigmoid function
    * i -> input gate
    * f -> forget gate
    * o -> output gate
    * c -> cell state
    * W_{hi} - hidden-input gate matrix
    * W_{xo} - input-output gate matrix
    * W_{ci}, W_{cf}, W_{co} - cell-to-gate (peephole) matrices, which are diagonal
        + so element m in each gate vector only receives input from
        + element m of the cell vector
    """

    # going to be shape (?, num_lstm_units)
    h_tm1, c_tm1 = state

    # basically the meat of eqs. 7-10: pre-activations for all four gates at once
    z = tf.matmul(inputs, self.kernel) + tf.matmul(h_tm1, self.recurrent_kernel) + self.bias
    i_lin, f_lin, g_lin, o_lin = tf.split(z, num_or_size_splits=4, axis=1)

    if self.should_apply_peephole:
        pw_i = tf.expand_dims(self.peephole_weights[:, 0], axis=0)
        pw_f = tf.expand_dims(self.peephole_weights[:, 1], axis=0)
        i_lin = i_lin + c_tm1 * pw_i
        f_lin = f_lin + c_tm1 * pw_f

    # apply activation functions! see Olah's blog
    i = tf.sigmoid(i_lin)
    f = tf.sigmoid(f_lin)
    g = tf.tanh(g_lin)
    c = f * c_tm1 + i * g

    if self.should_apply_peephole:
        pw_o = tf.expand_dims(self.peephole_weights[:, 2], axis=0)
        o_lin = o_lin + c * pw_o

    o = tf.sigmoid(o_lin)

    # final hidden state -> eq. 11
    h = o * tf.tanh(c)
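    return h, [h, c]


# --- JAX / Flax (nnx) version ---
from typing import Tuple

import jax.numpy as jnp
from flax import nnx


# ModelConfig is the project's own config dataclass (not shown)
class HandwritingModel(nnx.Module):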
    def __init__(
        self,
        config: ModelConfig,
        rngs: nnx.Rngs,
        synthesis_mode: bool = False,
    ) -> None:
        self.config = config
        self.synthesis_mode = synthesis_mode

        # rngs is basically a set of random keys / number generators
        self.lstm_cells = self._build_lstm_stack(rngs)
        if synthesis_mode:
            # i mean we really only care about synthesis mode, but in
            # this case we can make it explicit that if we have it then we should add our
            # attention layer
            self.attention_layer = nnx.Linear(
                config.hidden_size + config.alphabet_size + 3, 3 * config.num_attention_gaussians, rngs=rngs
            )

        # mdn portion
        self.mdn_layer = self._build_mdn_head(rngs)

    def _build_lstm_stack(self, rngs: nnx.Rngs):
        cells = []
        for i in range(self.config.num_layers):
            if i == 0:
                if self.synthesis_mode:
                    # so if we're in synthesis mode, then we need to add the alphabet size
                    # and the 3 dimensions of the input stroke
                    # that's because our alphabet size is the number of characters in our alphabet
                    # and the 3 dimensions of the input stroke are the x, y, and eos values
                    in_size = self.config.alphabet_size + 3
                else:
                    in_size = 3
            else:
                # similar in both (just in synthesis we only care if we need to expand by the alphabet size)
                in_size = self.config.hidden_size + 3
                if self.synthesis_mode:
                    in_size += self.config.alphabet_size

            # ok... being lazy but this is just standard LSTM
            cells.append(
                {"linear": nnx.Linear(in_size + self.config.hidden_size, 4 * self.config.hidden_size, rngs=rngs)}
            )
        return cells

    def lstm_cell(
        self, x: jnp.ndarray, h: jnp.ndarray, c: jnp.ndarray, layer_idx: int
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        # just think about this as grabbing the W and b for our matrix mults
        linear = self.lstm_cells[layer_idx]["linear"]

        combined = jnp.concatenate([x, h], axis=-1)
        gates = linear(combined)

        i, f, g, o = jnp.split(gates, 4, axis=-1)

        # activations
        i = nnx.sigmoid(i)
        f = nnx.sigmoid(f)
        g = nnx.tanh(g)
        o = nnx.sigmoid(o)

        # get new LSTM cell state
        c_new = f * c + i * g
        h_new = o * nnx.tanh(c_new)
        return h_new, c_new

Gaussian Mixture Models

Theory

gmm-viz


Gaussian Mixture Models are an unsupervised technique to learn an underlying probabilistic model.

Brilliant has an incredible explanation walking through the theory here. I’d encourage you to check it out, but at a very high level:

  1. A number of Gaussians is specified by the user
  2. The algo learns various parameters that represent the data while maximizing the likelihood of seeing such data

So if we have $k$ components in a bivariate Gaussian mixture model (the case we care about), we'll learn $k$ mean vectors, $k$ pairs of standard deviations, $k$ correlations, and $k$ mixture weights through expectation maximization (EM).

From Brilliant, EM alternates between two steps:

The first step, known as the expectation step or E step, consists of calculating the expectation of the component assignments $C_k$ for each data point $x_i \in X$ given the model parameters $\phi_k$, $\mu_k$, and $\sigma_k$.

The second step is known as the maximization step or M step, which consists of maximizing the expectations calculated in the E step with respect to the model parameters. This step consists of updating the values $\phi_k$, $\mu_k$, and $\sigma_k$.
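To make the E and M steps concrete, here's a small 1-D sketch in plain Python, using one standard deviation per component instead of the bivariate case. The data and the initial guesses are made up. The E step computes each component's responsibility for each point, and the M step re-estimates $\phi_k$, $\mu_k$, and $\sigma_k$ from those soft assignments.

```python
import math


def normal_pdf(x, mu, sigma):
    return math.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2.0 * math.pi))


def em_step(data, phi, mu, sigma):
    """One EM iteration for a 1-D Gaussian mixture with K = len(phi) components."""
    k_range = range(len(phi))
    # E step: responsibility gamma[n][k] = P(component k | x_n)
    gamma = []
    for x in data:
        weighted = [phi[k] * normal_pdf(x, mu[k], sigma[k]) for k in k_range]
        total = sum(weighted)
        gamma.append([w / total for w in weighted])
    # M step: re-estimate the parameters from the soft assignments
    new_phi, new_mu, new_sigma = [], [], []
    for k in k_range:
        n_k = sum(g[k] for g in gamma)
        mean = sum(g[k] * x for g, x in zip(gamma, data)) / n_k
        var = sum(g[k] * (x - mean) ** 2 for g, x in zip(gamma, data)) / n_k
        new_phi.append(n_k / len(data))
        new_mu.append(mean)
        new_sigma.append(math.sqrt(max(var, 1e-6)))  # floor avoids a collapsed component
    return new_phi, new_mu, new_sigma


def log_likelihood(data, phi, mu, sigma):
    return sum(
        math.log(sum(p * normal_pdf(x, m, s) for p, m, s in zip(phi, mu, sigma)))
        for x in data
    )


data = [-2.1, -1.9, -2.0, -2.2, 1.9, 2.0, 2.1, 2.2]
initial_params = ([0.5, 0.5], [-0.5, 0.5], [1.0, 1.0])  # (phi, mu, sigma)
params = initial_params
for _ in range(20):
    params = em_step(data, *params)
```

EM should never decrease the log-likelihood from one iteration to the next, which makes `log_likelihood` a handy check when writing EM by hand.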

Code

There’s actually not a whole lot of code to provide here. GMMs are more of the technique that we’ll combine with the output of a neural network. That leads us smoothly to our next section.

Mixture Density Networks

Theory

Mixture Density Networks are an extension of GMMs that predict the parameters of a mixture probability distribution.

mdn-viz


Per our paper:

The idea is relatively simple - we take the output from a neural network and parametrize the learned parameters of the GMM. The result is that we can infer probabilistic prediction from our learned parameters. If our neural network is reasonably predicting where the next point might be, the GMM will then learn probabilistic parameters that model the distribution of the next point. This is different in a few key aspects. Namely, we now have target values because our data is sequential. Therefore, when we feed in our targets, we minimize the log likelihood based on those expectations, thus altering the GMM portion of the model to learn the predicted values.

More or less, though, the problem we're trying to solve is predicting the next input given our output vector. Essentially, we're asking for $\text{Pr}(x_{t+1} | y_t)$. I'm not going to show the proof (we didn't in our paper either), but the equation for the conditional probability is shown below:

\[\begin{align} \text{Pr}(x_{t+1} | y_t) = \sum_{j=1}^{M} \pi_{j}^t \mathcal{N} (x_{t+1} \mid \mu_j^t, \sigma_j^t, \rho_j^t) \end{align} \tag{5}\]

where

\[\begin{align} \mathcal{N}(x \mid \mu, \sigma, \rho) = \frac{1}{2\pi \sigma_1 \sigma_2 \sqrt{1-\rho^2}} \exp \left[\frac{-Z}{2(1-\rho^2)}\right] \end{align} \tag{6}\]

and

\[\begin{align} Z = \frac{(x_1 - \mu_1)^2 }{\sigma_1^2} + \frac{(x_2 - \mu_2)^2}{\sigma_2^2} - \frac{2\rho (x_1 - \mu_1) (x_2 - \mu_2) }{\sigma_1 \sigma_2} \end{align} \tag{7}\]
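A quick sanity check on equations (6) and (7): with $\rho = 0$, the cross term in $Z$ vanishes and the bivariate density factors into two independent 1-D Gaussians. That is what we'd expect when the two pen offsets are uncorrelated.

\[\begin{aligned} \mathcal{N}(x \mid \mu, \sigma, 0) &= \frac{1}{2\pi \sigma_1 \sigma_2} \exp\left[-\frac{1}{2}\left(\frac{(x_1 - \mu_1)^2}{\sigma_1^2} + \frac{(x_2 - \mu_2)^2}{\sigma_2^2}\right)\right] \\ &= \frac{1}{\sqrt{2\pi}\,\sigma_1} \exp\left[-\frac{(x_1 - \mu_1)^2}{2\sigma_1^2}\right] \cdot \frac{1}{\sqrt{2\pi}\,\sigma_2} \exp\left[-\frac{(x_2 - \mu_2)^2}{2\sigma_2^2}\right] \end{aligned}\]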

Now, there's a slight variation here because we have a handwriting-specific end-of-stroke probability $e_t$. So we modify our conditional probability formula, resulting in our final calculation:

\[\begin{align} \textrm{Pr}(x_{t+1} \mid y_t ) = \sum_{j=1}^{M} \pi_j^t \; \mathcal{N} (x_{t+1} \mid \mu_j^t, \sigma_j^t, \rho_j^t) \begin{cases} e_t & \textrm{if } (x_{t+1})_3 = 1 \\ 1-e_t & \textrm{otherwise} \end{cases} \end{align} \tag{8}\]

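And that's it! That's our final probability output from the MDN. Once we have this, training is simple. We don't run expectation maximization here. Instead, the mixture parameters are outputs of the network, and we use backpropagation to minimize the negative log-likelihood of the targets: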

\[\begin{align} \mathcal{L}(\mathbf{x}) = - \sum\limits_{t=1}^{T} \log \textrm{Pr}(x_{t+1} \mid y_t) \end{align} \tag{9}\]
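Here's a plain-Python sketch of equations (5) through (9) for a single sequence: the bivariate density, the end-of-stroke term from eq. (8), and the summed negative log-likelihood. It's unbatched and skips numerical safeguards so the correspondence to the equations stays obvious. Real training code would usually work in log space with a log-sum-exp. The function names are my own.

```python
import math


def bivariate_normal(x1, x2, mu1, mu2, s1, s2, rho):
    """Eqs. (6)-(7): density of a correlated 2-D Gaussian."""
    z = (
        ((x1 - mu1) / s1) ** 2
        + ((x2 - mu2) / s2) ** 2
        - 2.0 * rho * (x1 - mu1) * (x2 - mu2) / (s1 * s2)
    )
    norm = 2.0 * math.pi * s1 * s2 * math.sqrt(1.0 - rho ** 2)
    return math.exp(-z / (2.0 * (1.0 - rho ** 2))) / norm


def step_nll(point, pi, mu1, mu2, s1, s2, rho, e):
    """-log Pr(x_{t+1} | y_t) for one (dx, dy, eos) target, per eq. (8)."""
    dx, dy, eos = point
    mixture = sum(
        pi[j] * bivariate_normal(dx, dy, mu1[j], mu2[j], s1[j], s2[j], rho[j])
        for j in range(len(pi))
    )
    eos_prob = e if eos == 1 else 1.0 - e
    return -math.log(mixture) - math.log(eos_prob)


def sequence_loss(points, params_per_step):
    """Eq. (9): sum the per-step negative log-likelihoods over the sequence."""
    return sum(step_nll(p, *params) for p, params in zip(points, params_per_step))


# one component centered at the origin, 50/50 end-of-stroke probability
loss = step_nll((0.0, 0.0, 0), [1.0], [0.0], [0.0], [1.0], [1.0], [0.0], 0.5)
```

For that example, the loss is $\log(2\pi) + \log 2$: the first term comes from the Gaussian peak and the second from the coin-flip end-of-stroke guess.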

Code

Here’s the corresponding code section for my mixture density network.

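import tensorflow as tf

# 6 parameters per mixture component: 1 weight, 2 means, 2 standard deviations, 1 correlation
NUM_MIXTURE_COMPONENTS_PER_COMPONENT = 6


class MixtureDensityLayer(tf.keras.layers.Layer):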
    def __init__(
        self,
        num_components,
        name="mdn",
        temperature=1.0,
        enable_regularization=False,
        sigma_reg_weight=0.01,
        rho_reg_weight=0.01,
        entropy_reg_weight=0.1,
        **kwargs,
    ):
        super(MixtureDensityLayer, self).__init__(name=name, **kwargs)
        self.num_components = num_components
        # Per mixture component: 1 weight, 2 means, 2 standard deviations, 1 correlation,
        # plus a single end-of-stroke (eos) logit shared by all components -> 6K + 1 outputs
        self.output_dim = num_components * NUM_MIXTURE_COMPONENTS_PER_COMPONENT + 1
        self.mod_name = name
        self.temperature = temperature
        self.enable_regularization = enable_regularization
        self.sigma_reg_weight = sigma_reg_weight
        self.rho_reg_weight = rho_reg_weight
        self.entropy_reg_weight = entropy_reg_weight

    def build(self, input_shape):
        graves_initializer = tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=0.075)

        self.input_units = input_shape[-1]
        # weights
        # lots of weight initialization here... could simplify here too

        # biases
        # lots of bias initialization here... could simplify this part by just doing one massive
        # projection and splitting... see the code if you're curious
        super().build(input_shape)

    def call(self, inputs, training=None):
        # temperature only reshapes pi at sampling time; training always uses 1.0
        temperature = self.temperature if not training else 1.0

        pi_logits = tf.matmul(inputs, self.W_pi) + self.b_pi
        pi = tf.nn.softmax(pi_logits / temperature, axis=-1)  # [B, T, K]
        # clipping here... I was getting cooked by NaN creep
        pi = tf.clip_by_value(pi, 1e-6, 1.0)

        mu = tf.matmul(inputs, self.W_mu) + self.b_mu  # [B, T, 2K]
        mu1, mu2 = tf.split(mu, 2, axis=2)

        log_sigma = tf.matmul(inputs, self.W_sigma) + self.b_sigma  # [B, T, 2K]
        # again, this might be overkill but seems realistic for clipping
        log_sigma = tf.clip_by_value(log_sigma, -5.0, 2.0)
        sigma = tf.exp(log_sigma)
        sigma1, sigma2 = tf.split(sigma, 2, axis=2)

        rho_raw = tf.matmul(inputs, self.W_rho) + self.b_rho
        rho = tf.tanh(rho_raw) * 0.9

        eos_logit = tf.matmul(inputs, self.W_eos) + self.b_eos

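        return tf.concat([pi, mu1, mu2, sigma1, sigma2, rho, eos_logit], axis=2)


# --- JAX / Flax (nnx) version ---
from typing import Optional

import jax
import jax.numpy as jnp
from flax import nnx


# ModelConfig and RNNState are the project's own config / state containers (not shown)
class HandwritingModel(nnx.Module):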
    def __init__(
        self,
        config: ModelConfig,
        rngs: nnx.Rngs,
        synthesis_mode: bool = False,
    ) -> None:
        self.config = config
        self.synthesis_mode = synthesis_mode

        # rngs is basically a set of random keys / number generators
        self.lstm_cells = self._build_lstm_stack(rngs)
        if synthesis_mode:
            # i mean we really only care about synthesis mode, but in
            # this case we can make it explicit that if we have it then we should add our
            # attention layer
            self.attention_layer = nnx.Linear(
                config.hidden_size + config.alphabet_size + 3, 3 * config.num_attention_gaussians, rngs=rngs
            )

        # mdn portion
        self.mdn_layer = self._build_mdn_head(rngs)

    #....
    
    def __call__(
        self,
        inputs: jnp.ndarray,
        char_seq: Optional[jnp.ndarray] = None,
        char_lens: Optional[jnp.ndarray] = None,
        initial_state: Optional[RNNState] = None,
        return_state: bool = False,
    ) -> jnp.ndarray:
        batch_size, seq_len, _ = inputs.shape

        if initial_state is None:
            h = jnp.zeros((self.config.num_layers, batch_size, self.config.hidden_size), inputs.dtype)
            c = jnp.zeros_like(h)
            kappa = jnp.zeros((batch_size, self.config.num_attention_gaussians), inputs.dtype)
            window = jnp.zeros((batch_size, self.config.alphabet_size), inputs.dtype)
        else:
            h, c = initial_state.hidden, initial_state.cell
            kappa, window = initial_state.kappa, initial_state.window

        def step(carry, x_t):
            h, c, kappa, window = carry
            h_layers = []
            c_layers = []

            # layer1
            if self.synthesis_mode:
                layer1_input = jnp.concatenate([window, x_t], axis=-1)
            else:
                layer1_input = x_t

            h1, c1 = self.lstm_cell(layer1_input, h[0], c[0], 0)
            h_layers.append(h1)
            c_layers.append(c1)

            # layer1 -> attention
            if self.synthesis_mode and char_seq is not None and char_lens is not None:
                window, kappa = self.compute_attention(h1, kappa, window, x_t, char_seq, char_lens)

            # attention -> layer2 and layer3
            for layer_idx in range(1, self.config.num_layers):
                if self.synthesis_mode:
                    layer_input = jnp.concatenate([x_t, h_layers[-1], window], axis=-1)
                else:
                    layer_input = jnp.concatenate([x_t, h_layers[-1]], axis=-1)

                h_new, c_new = self.lstm_cell(layer_input, h[layer_idx], c[layer_idx], layer_idx)
                h_layers.append(h_new)
                c_layers.append(c_new)

            h_new = jnp.stack(h_layers)
            c_new = jnp.stack(c_layers)

            # mdn output from final hidden state
            mdn_out = self.mdn_layer(h_layers[-1])  # [B, 6M+1]

            return (h_new, c_new, kappa, window), mdn_out

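        # this was the major unlock for JAX performance: lax.scan traces `step`
        # once and compiles the loop over the time dimension into a single XLA loop,
        # instead of unrolling a Python for-loop over every timestep
        # transpose inputs from [B, T, 3] to [T, B, 3] since scan iterates over axis 0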
        inputs_transposed = inputs.swapaxes(0, 1)
        (h, c, kappa, window), outputs = jax.lax.scan(step, (h, c, kappa, window), inputs_transposed)

        # transpose back
        outputs = outputs.swapaxes(0, 1)

        if return_state:
            final_state = RNNState(hidden=h, cell=c, kappa=kappa, window=window)
            return outputs, final_state

        return outputs

Mixture Density Loss

Theory

I already covered the theory above, so I won’t go into it here; I just figured it was easier to split the code between the network and the loss calculation. Note that there’s some pretty aggressive clipping going on, because I had a lot of numerical instability with JAX. Partially because of the implementation and the clipping (the jnp.where(jnp.isnan(...) | jnp.isinf(...), 0.0, ...) guards in compute_loss replace any non-finite loss with 0.0), the loss would just go to 0 rather than the program crashing. To be clear, a loss of zero was not desired.
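To see why the log-space formulation (and the reduce_logsumexp / jax.nn.logsumexp calls) matters, here is a small NumPy sketch that is not taken from either implementation. When every per-component log-density is very negative, as happens for a point far from all the mixture means, exponentiating first underflows to zero and the log becomes -inf; shifting by the maximum before exponentiating keeps everything finite. The function names are hypothetical.

```python
import numpy as np


def naive_log_mixture(log_pi, log_gauss):
    """log sum_k pi_k N_k computed in probability space (can underflow)."""
    with np.errstate(divide="ignore"):  # silence the log(0) warning
        return np.log(np.sum(np.exp(log_pi + log_gauss), axis=-1))


def stable_log_mixture(log_pi, log_gauss):
    """Same quantity via the log-sum-exp trick: shift by the max before exp."""
    a = log_pi + log_gauss
    a_max = np.max(a, axis=-1, keepdims=True)
    return np.squeeze(a_max, axis=-1) + np.log(np.sum(np.exp(a - a_max), axis=-1))


# two equally weighted components whose log-densities are around -1000
log_pi = np.log(np.array([0.5, 0.5]))
log_gauss = np.array([-1000.0, -1001.0])

print(naive_log_mixture(log_pi, log_gauss))   # -inf: exp(-1000) underflows to 0.0
print(stable_log_mixture(log_pi, log_gauss))  # about -1000.38
```

Both functions compute the same mathematical quantity; only the order of operations differs, which is the whole point of moving the loss into log space.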

Code

@tf.keras.utils.register_keras_serializable()
def mdn_loss(y_true, y_pred, stroke_lengths, num_components, eps=1e-8):
    """
    Mixture density negative log-likelihood computed fully in log-space.

    y_true: [B, T, 3]  -> (x, y, eos ∈ {0,1})
    y_pred: [B, T, 6*K + 1] -> (pi, mu1, mu2, sigma1, sigma2, rho, eos_logit)

    The log space change was because I was getting absolutely torched by the
    gradients when using the normal space.
    """
    out_pi, mu1, mu2, sigma1, sigma2, rho, eos_logits = tf.split(
        y_pred,
        [num_components] * 6 + [1],
        axis=2,
    )

    x, y, eos_targets = tf.split(y_true, [1, 1, 1], axis=-1)

    sigma1 = tf.clip_by_value(sigma1, 1e-2, 10.0)
    sigma2 = tf.clip_by_value(sigma2, 1e-2, 10.0)
    rho = tf.clip_by_value(rho, -0.9, 0.9)
    out_pi = tf.clip_by_value(out_pi, eps, 1.0)

    log_2pi = tf.constant(np.log(2.0 * np.pi), dtype=y_pred.dtype)
    one_minus_rho2 = tf.clip_by_value(1.0 - tf.square(rho), eps, 2.0)
    log_one_minus_rho2 = tf.math.log(one_minus_rho2)
    z1 = (x - mu1) / sigma1
    z2 = (y - mu2) / sigma2

    quad = tf.square(z1) + tf.square(z2) - 2.0 * rho * z1 * z2
    quad = tf.clip_by_value(quad, 0.0, 100.0)
    log_norm = -(log_2pi + tf.math.log(sigma1) + tf.math.log(sigma2) + 0.5 * log_one_minus_rho2)
    log_gauss = log_norm - 0.5 * quad / one_minus_rho2  # [B, T, K]

    # log mixture via log-sum-exp
    log_pi = tf.math.log(out_pi)  # [B, T, K]
    log_gmm = tf.reduce_logsumexp(log_pi + log_gauss, axis=-1)  # [B, T]

    # bce (binary / bernoulli cross entropy) computed from logits to help out with stability
    eos_nll = tf.nn.sigmoid_cross_entropy_with_logits(labels=eos_targets, logits=eos_logits)  # [B, T, 1]
    eos_nll = tf.squeeze(eos_nll, axis=-1)  # [B, T]

    nll = -log_gmm + eos_nll  # [B, T]
    if stroke_lengths is not None:
        mask = tf.sequence_mask(stroke_lengths, maxlen=tf.shape(y_true)[1], dtype=nll.dtype)
        nll = nll * mask
        denom = tf.maximum(tf.reduce_sum(mask), 1.0)
        return tf.reduce_sum(nll) / denom

    return tf.reduce_mean(nll)

def compute_loss(
    predictions: jnp.ndarray,
    targets: jnp.ndarray,
    lengths: Optional[jnp.ndarray] = None,
    num_mixtures: int = NUM_BIVARIATE_GAUSSIAN_MIXTURE_COMPONENTS,
) -> jnp.ndarray:
    nc = num_mixtures
    pi, mu1, mu2, s1, s2, rho, eos_pred = jnp.split(predictions, [nc, 2 * nc, 3 * nc, 4 * nc, 5 * nc, 6 * nc], axis=-1)

    pi = nnx.softmax(pi, axis=-1)
    s1 = jnp.exp(jnp.clip(s1, -10, 3))
    s2 = jnp.exp(jnp.clip(s2, -10, 3))
    rho = jnp.clip(nnx.tanh(rho) * 0.95, -0.95, 0.95)
    eos_pred = jnp.clip(nnx.sigmoid(eos_pred), 1e-8, 1 - 1e-8)

    x, y, eos = jnp.split(targets, [1, 2], axis=-1)

    # major change is we compute log probabilities with better numerical stability
    rho_sq = jnp.clip(rho**2, 0, 0.9025)
    one_minus_rho_sq = jnp.maximum(1 - rho_sq, 1e-6)
    norm = -jnp.log(2 * jnp.pi) - jnp.log(s1) - jnp.log(s2) - 0.5 * jnp.log(one_minus_rho_sq)

    z1 = (x - mu1) / jnp.maximum(s1, 1e-6)
    z2 = (y - mu2) / jnp.maximum(s2, 1e-6)

    exp_term = -0.5 / one_minus_rho_sq * (z1**2 + z2**2 - 2 * rho * z1 * z2)
    exp_term = jnp.clip(exp_term, -50, 0)
    log_probs = norm + exp_term
    log_pi = jnp.log(jnp.maximum(pi, 1e-8))
    log_mixture = jax.nn.logsumexp(log_pi + log_probs, axis=-1)

    eos_loss = -jnp.sum(eos * jnp.log(eos_pred) + (1 - eos) * jnp.log(1 - eos_pred), axis=-1)

    loss = -log_mixture + eos_loss
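    # replace non-finite per-step losses with 0.0, so a blown-up run
    # silently reports smaller (or zero) loss instead of crashing
    loss = jnp.where(jnp.isnan(loss) | jnp.isinf(loss), 0.0, loss)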

    if lengths is not None:
        mask = jnp.arange(predictions.shape[1]) < lengths[:, None]
        loss = jnp.where(mask, loss, 0.0)
        total_loss = jnp.sum(loss) / jnp.maximum(jnp.sum(mask), 1)
        return jnp.where(jnp.isnan(total_loss) | jnp.isinf(total_loss), 0.0, total_loss)

    mean_loss = jnp.mean(loss)
    return jnp.where(jnp.isnan(mean_loss) | jnp.isinf(mean_loss), 0.0, mean_loss)

Attention Mechanism

Theory

The attention mechanism really only comes into play with the Synthesis Network, which sadly Tom and I never got to in college. The idea (as with most notions of attention) is that we need to tell our model more specifically where to focus. This isn’t the transformer notion of attention from the famous “Attention is All You Need” paper; instead, we use a mixture of Gaussians to softly indicate which part of the character sequence we should be focusing on. We one-hot encode the input characters so that each character has a distinct numerical representation the window can operate on. So the question we’re basically answering is something like “oh, I see a ‘w’ character, generally how far along do we need to write for that?”, which also helps answer the question of when we need to terminate.
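Concretely, one-hot encoding turns each character into a vector with a single 1 at that character’s index. Below is a minimal NumPy sketch; the tiny ALPHABET string and the helper names are hypothetical placeholders (the real models use ALPHABET_SIZE characters), and tf.one_hot in the TensorFlow model does the equivalent from integer indices.

```python
import numpy as np

# hypothetical toy alphabet, just for illustration
ALPHABET = " abcdefghijklmnopqrstuvwxyz"
CHAR_TO_INDEX = {ch: i for i, ch in enumerate(ALPHABET)}


def encode_text(text: str) -> np.ndarray:
    """Map a string to integer character indices, shape [U]."""
    return np.array([CHAR_TO_INDEX[ch] for ch in text], dtype=np.int32)


def one_hot(indices: np.ndarray, depth: int) -> np.ndarray:
    """One-hot encode integer indices, shape [U] -> [U, depth]."""
    return np.eye(depth, dtype=np.float32)[indices]


c = one_hot(encode_text("hello world"), depth=len(ALPHABET))
print(c.shape)  # (11, 27)
```

Each row of c is one $c_u$ in the notation used below, so the window vector ends up with one entry per alphabet character.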

The mathematical representation is here:

Given a length $U$ character sequence $\mathbf{c}$ and a length $T$ data sequence $\mathbf{x}$, the soft window $w_t$ into $\mathbf{c}$ at timestep $t$ ($1 \leq t \leq T$) is defined by the following discrete convolution with a mixture of $K$ Gaussian functions

\[\begin{align} \phi(t, u) &= \sum_{k=1}^K \alpha^k_t\exp\left(-\beta_t^k\left(\kappa_t^k-u\right)^2\right)\\ w_t &= \sum_{u=1}^U \phi(t, u)c_u \end{align}\]

where $\phi(t, u)$ is the window weight of $c_u$ at timestep $t$.

Intuitively, the $\kappa_t$ parameters control the location of the window, the $\beta_t$ parameters control the width of the window and the $\alpha_t$ parameters control the importance of the window within the mixture.

The size of the soft window vectors is the same as the size of the character vectors $c_u$ (assuming a one-hot encoding, this will be the number of characters in the alphabet).

Note that the window mixture is not normalised and hence does not determine a probability distribution; however the window weight $\phi(t, u)$ can be loosely interpreted as the network’s belief that it is writing character $c_u$ at time $t$.
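The two equations above translate almost line for line into NumPy. This is a sketch for a single timestep and a single (unbatched) sequence, using 1-indexed positions $u = 1, \dots, U$ as in the formula; soft_window is a hypothetical helper, and alpha, beta, and kappa are assumed to already be positive (producing them from the network output is handled separately).

```python
import numpy as np


def soft_window(alpha, beta, kappa, c):
    """Graves-style soft window for one timestep.

    alpha, beta, kappa: shape [K], one value per window Gaussian (all > 0)
    c: one-hot character sequence, shape [U, A]
    returns (phi, w) with phi of shape [U] and w of shape [A]
    """
    U = c.shape[0]
    u = np.arange(1, U + 1, dtype=np.float64)  # positions 1..U, as in the formula
    # phi(t, u) = sum_k alpha_k * exp(-beta_k * (kappa_k - u)^2)
    phi = np.sum(
        alpha[:, None] * np.exp(-beta[:, None] * (kappa[:, None] - u[None, :]) ** 2),
        axis=0,
    )
    # w_t = sum_u phi(t, u) * c_u  (not normalised, following Graves)
    w = phi @ c
    return phi, w


# one Gaussian centred on the second of three characters
c = np.eye(3)  # characters 0, 1, 2 at positions u = 1, 2, 3
phi, w = soft_window(np.array([1.0]), np.array([1.0]), np.array([2.0]), c)
print(phi)  # phi is [e^-1, 1, e^-1], roughly [0.368, 1.0, 0.368]
```

Increasing kappa slides the peak of phi to later characters, which is exactly the “where am I in the text” role described above.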

Code

@tf.keras.utils.register_keras_serializable()
class AttentionMechanism(tf.keras.layers.Layer):
    """
    Attention mechanism for the handwriting synthesis model.
    This is a version of the attention mechanism used in
    the original paper by Alex Graves. It uses a Gaussian
    window to focus on different parts of the character sequence
    at each time step.

    See section: 5.0 / 5.1
    """

    def __init__(self, num_gaussians, num_chars, name="attention", debug=False, **kwargs) -> None:
        super(AttentionMechanism, self).__init__(**kwargs)
        self.num_gaussians = num_gaussians
        self.num_chars = num_chars
        self.name_mod = name
        self.debug = debug

    def call(
        self,
        inputs,  # shape: [batch_size, input_dim]
        prev_kappa,  # shape: [batch_size, num_gaussians]
        char_seq_one_hot,  # shape: [batch_size, char_len, num_chars]
        sequence_lengths,  # shape: [batch_size]
    ) -> tuple[tf.Tensor, tf.Tensor]:
        raw = tf.matmul(inputs, self.attention_kernel) + self.attention_bias
        alpha_hat, beta_hat, kappa_hat = tf.split(raw, 3, axis=1)  # each: [batch_size, num_gaussians]

        eps = tf.constant(1e-6, dtype=inputs.dtype)
        scaling = 0.1  # Gentler activation
        alpha = tf.nn.softplus(alpha_hat * scaling) + eps  # [B, G]
        beta = tf.nn.softplus(beta_hat * scaling) + eps  # [B, G]
        dkap = tf.nn.softplus(kappa_hat * scaling) + eps

        alpha = tf.clip_by_value(alpha, 0.01, 10.0)
        beta = tf.clip_by_value(beta, 0.01, 10.0)
        dkap = tf.clip_by_value(dkap, 1e-5, 0.5)

        kappa = prev_kappa + dkap
        kappa = tf.clip_by_value(kappa, 0.0, 30.0)

        char_len = tf.shape(char_seq_one_hot)[1]
        batch_size = tf.shape(inputs)[0]
        u = tf.cast(tf.range(1, char_len + 1), tf.float32)
        u = tf.reshape(u, [1, 1, -1])  # shape: [1, 1, char_len]
        u = tf.tile(u, [batch_size, self.num_gaussians, 1])  # shape: [batch_size, num_gaussians, char_len]

        alpha = tf.expand_dims(alpha, axis=-1)  # shape: [batch_size, num_gaussians, 1]
        beta = tf.expand_dims(beta, axis=-1)  # shape: [batch_size, num_gaussians, 1]
        kappa = tf.expand_dims(kappa, axis=-1)  # shape: [batch_size, num_gaussians, 1]

        exponent = -beta * tf.square(kappa - u)
        exponent = tf.clip_by_value(exponent, -50.0, 0.0)
        phi = alpha * tf.exp(exponent)  # shape: [batch_size, num_gaussians, char_len]
        phi = tf.reduce_sum(phi, axis=1)  # Sum over gaussians: [B, L]

        sequence_mask = tf.sequence_mask(sequence_lengths, maxlen=char_len, dtype=tf.float32)
        phi = phi * sequence_mask  # mask paddings

        phi = tf.where(tf.math.is_finite(phi), phi, tf.zeros_like(phi))
        # we don't normalize here - Graves calls that out specifically!
        # > Note that the window mixture is not normalised
        # > and hence does not determine a probability distribution; however the window
        # > weight φ(t,u) can be loosely interpreted as the network's belief that it is writ-
        # > ing character cu at time t.
        # still section 5.1

        # window vec
        phi = tf.expand_dims(phi, axis=-1)  # shape: [batch_size, char_len, 1]
        w = tf.reduce_sum(phi * char_seq_one_hot, axis=1)  # shape: [batch_size, num_chars]

        w = tf.where(tf.math.is_finite(w), w, tf.zeros_like(w))
        return w, kappa[:, :, 0]

    def compute_attention(
        self,
        h: jnp.ndarray,  # [B, H]
        prev_kappa: jnp.ndarray,  # [B, G]
        window: jnp.ndarray,  # [B, A]
        x: jnp.ndarray,  # [B, 3]
        char_seq: jnp.ndarray,  # [B, U, A] one-hot
        char_lens: jnp.ndarray,  # [B] lengths
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Compute Gaussian window attention over character sequence."""

        attention_input = jnp.concatenate([window, x, h], axis=-1)
        params = self.attention_layer(attention_input)  # [B, 3G]
        params = nnx.softplus(params)
        alpha, beta, kappa_inc = jnp.split(params, 3, axis=-1)

        # again... probably sliiiiightly overkill
        alpha = jnp.maximum(alpha, 1e-4)
        beta = jnp.maximum(beta, 1e-4)
        kappa_inc = jnp.maximum(kappa_inc, 1e-4)

        # ok this was a trick from svasquez - the dividing by 25.0
        # is to help kappa learn given that 25 is roughly the average
        # number of strokes per sequence
        kappa = prev_kappa + kappa_inc / 25.0

        U = char_seq.shape[1]
        positions = jnp.arange(U, dtype=jnp.float32)[None, None, :]  # [1, 1, U]
        kappa_exp = kappa[:, :, None]  # [B, G, 1]
        alpha_exp = alpha[:, :, None]  # [B, G, 1]
        beta_exp = beta[:, :, None]  # [B, G, 1]

        # gaussian window
        phi = alpha_exp * jnp.exp(-beta_exp * (kappa_exp - positions) ** 2)  # [B, G, U]
        phi = jnp.sum(phi, axis=1)

        # mask out positions beyond char_lens
        mask = jnp.arange(U)[None, :] < char_lens[:, None]  # [B, U]
        phi = jnp.where(mask, phi, 0.0)

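        # note: Graves (section 5.1) leaves the window unnormalised, and the TF
        # version follows that; here phi is normalized over u so the weights sum to 1
        phi = phi / (jnp.sum(phi, axis=-1, keepdims=True) + 1e-8)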

        # Apply to character sequence
        # window: [B, A] = sum_u phi[b,u]*char_seq[b,u,:]
        window_new = jnp.einsum("bu,bua->ba", phi, char_seq)

        return window_new, kappa

Stacked LSTM

Theory

The main distinction between Graves’s setup and a standard LSTM network is that Graves uses a cascade of LSTMs. We still use the MDN to generate a probabilistic prediction; however, the network feeding it is the cascade of LSTMs.

Per our paper:

The LSTM cascade buys us a few different things. As Graves aptly points out, it mitigates the vanishing gradient problem even more greatly than a single LSTM could. This is because of the skip-connections. All hidden layers have access to the input and all hidden layers are also directly connected to the output node. As a result, there are less processing steps from the bottom of the network to the top.

So it looks something like this:

graves-stacked-lstm

Courtesy of Alex Graves's paper


One thing to note is that the dimensionality increases as we go up the stack, given each hidden layer also receives the original input. Tom and I broke this down in our paper here (with $k$ the input dimension and $m$ the number of hidden layers):

Let’s observe the $x_{t-1}$ input. $h_{t-1}^1$ only has $x_{t-1}$ as its input which is in $\mathbb{R}^3$ because $(x, y, eos)$. However, we also pass our input $x_{t-1}$ into $h_{t-1}^2$. We assume that we simply concatenate the original input and the output of the first hidden layer. Because LSTMs do not scale dimensionality, we know the output is going to be in $\mathbb{R}^3$ as well. Therefore, after this concatenation, the input into the second hidden layer will be in $\mathbb{R}^6$. We can follow this process through and see that, the input to the third hidden layer will be in $\mathbb{R}^9$. Finally, we concatenate all of the LSTM cells (i.e. the hidden layers) together, thus getting a final dimension of $\mathbb{R}^{18}$ fed into our MDN. Note, this is for $m=3$ hidden layers, but more generally, we can observe the relation as

\[\begin{align} \textrm{final dimension} = k \frac{m(m+1)}{2} \end{align}\]
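The formula is just the sum of the per-layer widths under the paper’s simplifying assumption that each LSTM’s output is as wide as its input: layer $l$ ends up $lk$ wide, and $k + 2k + \dots + mk = k\,m(m+1)/2$. In the actual models in this post the LSTM output width is the hidden size instead, so a more useful sketch computes each layer’s input width directly. Both helper names are hypothetical; the concatenation order follows HandwritingModel.step above ([window, x] for the first layer, [x, h_prev, window] for the rest), and a hidden size of 400 with a 73-wide window (what the (76, 1600) first-layer kernel in the training log implies) reproduces the kernel shapes in that log.

```python
from typing import List


def paper_final_dim(k: int, m: int) -> int:
    """Width fed to the MDN in the paper's derivation.

    Assumes each LSTM layer's output is as wide as its input, so layer l
    (1-indexed) is l * k wide, and all m layer outputs get concatenated.
    """
    return sum(l * k for l in range(1, m + 1))


def cascade_input_dims(input_dim: int, hidden_size: int, num_layers: int, window_dim: int = 0) -> List[int]:
    """Input width of each LSTM layer in a Graves-style cascade.

    Layer 1 sees [window, x]; every later layer sees [x, h_prev, window].
    Set window_dim=0 for the unconditional (non-synthesis) network.
    """
    dims = [window_dim + input_dim]
    for _ in range(1, num_layers):
        dims.append(input_dim + hidden_size + window_dim)
    return dims


print(paper_final_dim(k=3, m=3))  # 18, matching k * m * (m + 1) / 2
print(cascade_input_dims(input_dim=3, hidden_size=400, num_layers=3, window_dim=73))  # [76, 476, 476]
```

Note that the code in this post feeds only the last hidden layer into the MDN (the MDN weights are (400, ...)), rather than concatenating all hidden layers as in the paper’s derivation.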

My take is that I actually like how I constructed the TensorFlow version more from a composability perspective; I think the code is cleaner. However, c’est la vie.

Code

This is where TensorFlow’s cell vs. layer concept was very nice.

You can see here how the parts all come together smoothly. The custom RNN cell takes the stacked lstm_cells and only defines what happens at a single time step; tf.keras.layers.RNN then runs that cell over the time dimension, so we never have to write another for loop ourselves. This is beneficial because of the batching and GPU wins we get when it eventually comes time to train.
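Both tf.keras.layers.RNN and the jax.lax.scan call in the JAX version follow the same contract: you write the computation for one time step as a function of (carry, x_t), and the framework runs the loop over time. The pure-Python reference below mirrors the semantics the JAX documentation gives for lax.scan (the real one compiles the loop rather than running it in Python); the toy step function is a hypothetical running sum standing in for the LSTM cascade.

```python
import numpy as np


def scan(step, init, xs):
    """Reference semantics of jax.lax.scan: loop step over the leading axis of xs.

    step(carry, x_t) -> (new_carry, y_t); returns (final_carry, stacked ys).
    """
    carry = init
    ys = []
    for x_t in xs:
        carry, y_t = step(carry, x_t)
        ys.append(y_t)
    return carry, np.stack(ys)


def running_sum_step(carry, x_t):
    # toy stand-in for the LSTM step: the state is the sum so far,
    # and the per-step output is the updated state
    new_carry = carry + x_t
    return new_carry, new_carry


# xs is time-major, [T, B] here, which is why the models swap [B, T, ...] -> [T, B, ...]
xs = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])
final, ys = scan(running_sum_step, np.zeros(2), xs)
print(final)     # [ 6. 60.]
print(ys.shape)  # (3, 2)
```

In the real models the carry is (h, c, kappa, window) and y_t is the MDN output for that step, but the control flow is the same.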

@tf.keras.utils.register_keras_serializable()
class DeepHandwritingSynthesisModel(tf.keras.Model):
    """
    A similar implementation to the previous model,
    but now we're throwing the good old attention mechanism back into the mix.
    """

    def __init__(
        self,
        units: int = NUM_LSTM_CELLS_PER_HIDDEN_LAYER,
        num_layers: int = NUM_LSTM_HIDDEN_LAYERS,
        num_mixture_components: int = NUM_BIVARIATE_GAUSSIAN_MIXTURE_COMPONENTS,
        num_chars: int = ALPHABET_SIZE,
        num_attention_gaussians: int = NUM_ATTENTION_GAUSSIAN_COMPONENTS,
        gradient_clip_value: float = GRADIENT_CLIP_VALUE,
        enable_mdn_regularization: bool = False,
        debug=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.units = units
        self.num_layers = num_layers
        self.num_mixture_components = num_mixture_components
        self.num_chars = num_chars
        self.num_attention_gaussians = num_attention_gaussians
        self.gradient_clip_value = gradient_clip_value
        self.enable_mdn_regularization = enable_mdn_regularization
        # Store LSTM cells as tracked attributes instead of list
        self.lstm_cells = []
        for idx in range(num_layers):
            cell = LSTMPeepholeCell(units, idx)
            setattr(self, f'lstm_cell_{idx}', cell)  # Register as tracked attribute
            self.lstm_cells.append(cell)

        self.attention_mechanism = AttentionMechanism(num_gaussians=num_attention_gaussians, num_chars=num_chars)
        self.attention_rnn_cell = AttentionRNNCell(self.lstm_cells, self.attention_mechanism, self.num_chars)
        self.rnn_layer = tf.keras.layers.RNN(self.attention_rnn_cell, return_sequences=True)
        self.mdn_layer = MixtureDensityLayer(num_mixture_components, enable_regularization=enable_mdn_regularization)
        self.debug = debug

        # metrics
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")
        self.nll_tracker = tf.keras.metrics.Mean(name="nll")
        self.eos_accuracy_tracker = tf.keras.metrics.Mean(name="eos_accuracy")
        self.eos_prob_tracker = tf.keras.metrics.Mean(name="eos_prob")

    def call(
        self, inputs: Dict[str, tf.Tensor], training: Optional[bool] = None, mask: Optional[tf.Tensor] = None
    ) -> tf.Tensor:
        input_strokes = inputs["input_strokes"]
        input_chars = inputs["input_chars"]
        input_char_lens = inputs["input_char_lens"]

        # one-hot encode the character sequence and set RNN cell attributes
        char_seq_one_hot = tf.one_hot(input_chars, depth=self.num_chars)
        self.attention_rnn_cell.char_seq_one_hot = char_seq_one_hot
        self.attention_rnn_cell.char_seq_len = input_char_lens

        # initial states
        batch_size = tf.shape(input_strokes)[0]
        initial_states = self.attention_rnn_cell.get_initial_state(batch_size=batch_size, dtype=input_strokes.dtype)
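        # flatten the state dict in the order the RNN layer expects:
        # (h, c) for each LSTM layer, then kappa and the attention window.
        # built from num_layers instead of hard-coding lstm_0..lstm_2
        initial_states_list = [
            initial_states[f"lstm_{idx}_{part}"] for idx in range(self.num_layers) for part in ("h", "c")
        ]
        initial_states_list += [initial_states["kappa"], initial_states["w"]]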

        # then through our RNN (which wraps stacked LSTM cells + attention mechanism)
        # and then through our MDN layer
        outputs = self.rnn_layer(input_strokes, initial_state=initial_states_list, training=training)
        final_output = self.mdn_layer(outputs)
        return final_output

    def __call__(
        self,
        inputs: jnp.ndarray,
        char_seq: Optional[jnp.ndarray] = None,
        char_lens: Optional[jnp.ndarray] = None,
        initial_state: Optional[RNNState] = None,
        return_state: bool = False,
    ) -> jnp.ndarray:
        batch_size, seq_len, _ = inputs.shape

        if initial_state is None:
            h = jnp.zeros((self.config.num_layers, batch_size, self.config.hidden_size), inputs.dtype)
            c = jnp.zeros_like(h)
            kappa = jnp.zeros((batch_size, self.config.num_attention_gaussians), inputs.dtype)
            window = jnp.zeros((batch_size, self.config.alphabet_size), inputs.dtype)
        else:
            h, c = initial_state.hidden, initial_state.cell
            kappa, window = initial_state.kappa, initial_state.window

        def step(carry, x_t):
            h, c, kappa, window = carry
            h_layers = []
            c_layers = []

            # layer1
            if self.synthesis_mode:
                layer1_input = jnp.concatenate([window, x_t], axis=-1)
            else:
                layer1_input = x_t

            h1, c1 = self.lstm_cell(layer1_input, h[0], c[0], 0)
            h_layers.append(h1)
            c_layers.append(c1)

            # layer1 -> attention
            if self.synthesis_mode and char_seq is not None and char_lens is not None:
                window, kappa = self.compute_attention(h1, kappa, window, x_t, char_seq, char_lens)

            # attention -> layer2 and layer3
            for layer_idx in range(1, self.config.num_layers):
                if self.synthesis_mode:
                    layer_input = jnp.concatenate([x_t, h_layers[-1], window], axis=-1)
                else:
                    layer_input = jnp.concatenate([x_t, h_layers[-1]], axis=-1)

                h_new, c_new = self.lstm_cell(layer_input, h[layer_idx], c[layer_idx], layer_idx)
                h_layers.append(h_new)
                c_layers.append(c_new)

            h_new = jnp.stack(h_layers)
            c_new = jnp.stack(c_layers)

            # mdn output from final hidden state
            mdn_out = self.mdn_layer(h_layers[-1])  # [B, 6M+1]

            return (h_new, c_new, kappa, window), mdn_out

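        # this was the major unlock for JAX performance: lax.scan traces `step`
        # once and compiles the loop over the time dimension into a single XLA loop,
        # instead of unrolling a Python for-loop over every timestep
        # transpose inputs from [B, T, 3] to [T, B, 3] since scan iterates over axis 0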
        inputs_transposed = inputs.swapaxes(0, 1)
        (h, c, kappa, window), outputs = jax.lax.scan(step, (h, c, kappa, window), inputs_transposed)

        # transpose back
        outputs = outputs.swapaxes(0, 1)

        if return_state:
            final_state = RNNState(hidden=h, cell=c, kappa=kappa, window=window)
            return outputs, final_state

        return outputs

Final Result

Alright finally! So what do we have, and what can we do now?

We’re now going to feed the output from our LSTM cascade into the GMM in order to build a probabilistic prediction model for the next stroke. The GMM is then evaluated at the actual next point, which tells us how far off the prediction was (as a negative log-likelihood) so that the loss can be properly minimized.
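For a single time step, evaluating the GMM at the actual next point boils down to the quantity below: the negative log of the mixture density at the true offset (x, y), plus a Bernoulli term for the end-of-stroke bit. This is a NumPy sketch of what mdn_loss and compute_loss compute per step, minus their clipping and masking; step_nll is a hypothetical name, and it assumes the mixture parameters have already been passed through their activations (softmax, exp, tanh, sigmoid).

```python
import numpy as np


def step_nll(x, y, eos, pi, mu1, mu2, s1, s2, rho, eos_prob):
    """Negative log-likelihood of one target point under a bivariate GMM + Bernoulli.

    pi, mu1, mu2, s1, s2, rho: shape [K] (activated mixture parameters)
    x, y: target offsets; eos: 0 or 1; eos_prob: predicted P(eos = 1)
    """
    z1 = (x - mu1) / s1
    z2 = (y - mu2) / s2
    one_minus_rho2 = 1.0 - rho**2
    quad = z1**2 + z2**2 - 2.0 * rho * z1 * z2
    # log N(x, y | mu, sigma, rho) for each component
    log_gauss = (
        -np.log(2.0 * np.pi * s1 * s2 * np.sqrt(one_minus_rho2))
        - quad / (2.0 * one_minus_rho2)
    )
    # log sum_k pi_k N_k, via the log-sum-exp trick
    a = np.log(pi) + log_gauss
    a_max = np.max(a)
    log_mixture = a_max + np.log(np.sum(np.exp(a - a_max)))
    # Bernoulli log-likelihood of the end-of-stroke bit
    log_eos = eos * np.log(eos_prob) + (1 - eos) * np.log(1.0 - eos_prob)
    return -(log_mixture + log_eos)


# one standard, uncorrelated Gaussian evaluated at its mean, with eos_prob = 0.5
nll = step_nll(0.0, 0.0, 0, np.array([1.0]), np.zeros(1), np.zeros(1),
               np.ones(1), np.ones(1), np.zeros(1), 0.5)
print(nll)  # log(2*pi) + log(2), roughly 2.531
```

Averaging this over the valid (unpadded) time steps is what the masked losses above do.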

🏋️ Training Results

Vast AI GPU Enabled Execution

Vast AI GPU Enabled Running


2024-04-21 19:01:02.183969: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
train data found. Loading...
test data found. Loading...
valid2 data found. Loading...
valid1 data found. Loading...
2024-04-21 19:01:04.798925: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1928] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 22455 MB memory: -> device: 0, name: NVIDIA GeForce RTX 3090, pci bus id: 0000:82:00.0, compute capability: 8.6
2024-04-21 19:01:05.887036: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:104] Profiler session initializing.
2024-04-21 19:01:05.887070: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:119] Profiler session started.
2024-04-21 19:01:05.887164: I external/local_xla/xla/backends/profiler/gpu/cupti_tracer.cc:1239] Profiler found 1 GPUs
2024-04-21 19:01:05.917572: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:131] Profiler session tear down.
2024-04-21 19:01:05.917763: I external/local_xla/xla/backends/profiler/gpu/cupti_tracer.cc:1364] CUPTI activity buffer flushed
Epoch 1/10000
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1713726072.109654 2329 service.cc:145] XLA service 0x7ad5bc004600 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1713726072.109731 2329 service.cc:153] StreamExecutor device (0): NVIDIA GeForce RTX 3090, Compute Capability 8.6
2024-04-21 19:01:12.346749: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
W0000 00:00:1713726072.691839 2329 assert_op.cc:38] Ignoring Assert operator assert_greater/Assert/AssertGuard/Assert
W0000 00:00:1713726072.694098 2329 assert_op.cc:38] Ignoring Assert operator assert_greater_1/Assert/AssertGuard/Assert
W0000 00:00:1713726072.696267 2329 assert_op.cc:38] Ignoring Assert operator assert_near/Assert/AssertGuard/Assert
2024-04-21 19:01:13.095183: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:465] Loaded cuDNN version 8906
2024-04-21 19:01:14.883021: W external/local_xla/xla/service/hlo_rematerialization.cc:2941] Can't reduce memory use below 17.97GiB (19297974672 bytes) by rematerialization; only reduced to 20.51GiB (22027581828 bytes), down from 20.67GiB (22193496744 bytes) originally
I0000 00:00:1713726076.329853 2329 device_compiler.h:188] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
167/168 ━━━━━━━━━━━━━━━━━━━━ 0s 685ms/step - loss: 2.8113W0000 00:00:1713726191.427557 2333 assert_op.cc:38] Ignoring Assert operator assert_greater/Assert/AssertGuard/Assert
W0000 00:00:1713726191.429182 2333 assert_op.cc:38] Ignoring Assert operator assert_greater_1/Assert/AssertGuard/Assert
W0000 00:00:1713726191.430622 2333 assert_op.cc:38] Ignoring Assert operator assert_near/Assert/AssertGuard/Assert
2024-04-21 19:03:13.488256: W external/local_xla/xla/service/hlo_rematerialization.cc:2941] Can't reduce memory use below 17.97GiB (19298282069 bytes) by rematerialization; only reduced to 19.75GiB (21203023676 bytes), down from 19.87GiB (21340423652 bytes) originally
168/168 ━━━━━━━━━━━━━━━━━━━━ 0s 709ms/step - loss: 2.8097
Epoch 1: Saving model.

Epoch 1: Loss improved from None to 0.0, saving model.
Model parameters after the 1st epoch:
Model: "deep_handwriting_synthesis_model"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ Layer (type) ┃ Output Shape ┃ Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│ lstm_peephole_cell │ ? │ 764,400 │
│ (LSTMPeepholeCell) │ │ │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ lstm_peephole_cell_1 │ ? │ 1,404,400 │
│ (LSTMPeepholeCell) │ │ │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ lstm_peephole_cell_2 │ ? │ 1,404,400 │
│ (LSTMPeepholeCell) │ │ │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ attention (AttentionMechanism) │ ? │ 14,310 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ attention_rnn_cell │ ? │ 3,587,510 │
│ (AttentionRNNCell) │ │ │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ rnn (RNN) │ ? │ 3,587,510 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ mdn (MixtureDensityLayer) │ ? │ 48,521 │
└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
Total params: 7,272,064 (27.74 MB)
Trainable params: 3,636,031 (13.87 MB)
Non-trainable params: 0 (0.00 B)
Optimizer params: 3,636,033 (13.87 MB)

All parameters:

[[lstm_peephole_kernel1]] shape: (76, 1600)
[[lstm_peephole_recurrent_kernel1]] shape: (400, 1600)
[[lstm_peephole_weights1]] shape: (400, 3)
[[lstm_peephole_bias1]] shape: (1600,)
[[lstm_peephole_kernel2]] shape: (476, 1600)
[[lstm_peephole_recurrent_kernel2]] shape: (400, 1600)
[[lstm_peephole_weights2]] shape: (400, 3)
[[lstm_peephole_bias2]] shape: (1600,)
[[lstm_peephole_kernel3]] shape: (476, 1600)
[[lstm_peephole_recurrent_kernel3]] shape: (400, 1600)
[[lstm_peephole_weights3]] shape: (400, 3)
[[lstm_peephole_bias3]] shape: (1600,)
[[kernel]] shape: (476, 30)
[[bias]] shape: (30,)
[[mdn_W_pi]] shape: (400, 20)
[[mdn_W_mu]] shape: (400, 40)
[[mdn_W_sigma]] shape: (400, 40)
[[mdn_W_rho]] shape: (400, 20)
[[mdn_W_eos]] shape: (400, 1)
[[mdn_b_pi]] shape: (20,)
[[mdn_b_mu]] shape: (40,)
[[mdn_b_sigma]] shape: (40,)
[[mdn_b_rho]] shape: (20,)
[[mdn_b_eos]] shape: (1,)

Trainable parameters:

(same here)

Trainable parameter count:

3636031
168/168 ━━━━━━━━━━━━━━━━━━━━ 133s 728ms/step - loss: 2.7931
Epoch 2/10000
60/168 ━━━━━━━━━━━━━━━━━━━━ 1:14 686ms/step - loss: 2.4870
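As a sanity check on the parameter table in the log above, the counts can be reproduced by hand. The sketch below assumes the peephole cell’s variables are exactly the four listed per layer (input kernel, recurrent kernel, a (units, 3) peephole matrix, and a bias), with the four gates packed into 4 × 400 = 1600 columns; the helper names are hypothetical. Note that the rnn row repeats the attention_rnn_cell weights (it wraps the same cell), and “Total params” also counts the 3,636,033 optimizer params, so the number that matters is the 3,636,031 trainable parameters.

```python
def peephole_lstm_params(input_dim: int, units: int) -> int:
    """Parameters of one peephole LSTM layer with the 4 gates packed together."""
    kernel = input_dim * 4 * units  # x_t -> gates
    recurrent = units * 4 * units   # h_{t-1} -> gates
    peephole = units * 3            # diagonal peepholes for the input, forget, output gates
    bias = 4 * units
    return kernel + recurrent + peephole + bias


def dense_params(input_dim: int, output_dim: int) -> int:
    return input_dim * output_dim + output_dim


UNITS, WINDOW, STROKE = 400, 73, 3  # 73 = 76 - 3, read off the first kernel's shape
K_MDN, G_ATT = 20, 10               # from mdn_W_pi (400, 20) and the (476, 30) attention kernel

layer1 = peephole_lstm_params(WINDOW + STROKE, UNITS)
layer23 = peephole_lstm_params(STROKE + UNITS + WINDOW, UNITS)
attention = dense_params(STROKE + UNITS + WINDOW, 3 * G_ATT)
mdn = dense_params(UNITS, 6 * K_MDN + 1)  # pi, mu1, mu2, s1, s2, rho + eos

print(layer1, layer23, attention, mdn)         # 764400 1404400 14310 48521
print(layer1 + 2 * layer23 + attention + mdn)  # 3636031
```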


Ok so that’s all well and good and some fun math and neural network construction, but the meat of this project is about what we’re actually building with this theory. So let’s lay out our to-do list.

Problem #1 - Gradient Explosion Problem

Somehow, on my first run through of this, I was still getting exploding gradients in the later stages of training my model.

As a result, I went through the laborious and time-consuming process of running training on CPU so that I could print out debugging information, and then used TensorBoard’s debugger to inspect which gradients were exploding to nan or the dreaded inf.

Here’s an example of what that looked like:

tensorboard

This was even more annoying because of this issue: https://github.com/tensorflow/tensorflow/issues/59215.
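The core of that debugging loop, finding which gradients went non-finite and keeping the rest bounded, is simple to sketch framework-free. The NumPy helpers below are hypothetical (TensorFlow’s own building blocks are tf.debugging.check_numerics and tf.clip_by_global_norm); clipping by global norm is one common choice and not necessarily how gradient_clip_value is applied in the models above.

```python
import numpy as np


def find_nonfinite(grads: dict) -> list:
    """Names of gradient arrays that contain any nan or inf."""
    return [name for name, g in grads.items() if not np.all(np.isfinite(g))]


def clip_by_global_norm(grads: dict, max_norm: float) -> dict:
    """Rescale all gradients together so their combined L2 norm is <= max_norm."""
    global_norm = np.sqrt(sum(np.sum(np.square(g)) for g in grads.values()))
    scale = min(1.0, max_norm / (global_norm + 1e-12))
    return {name: g * scale for name, g in grads.items()}


# hypothetical gradients keyed by variable name
grads = {
    "lstm_peephole_kernel1": np.array([[3.0, 4.0]]),
    "mdn_W_rho": np.array([np.inf, 1.0]),
}
print(find_nonfinite(grads))  # ['mdn_W_rho']

finite = {"a": np.array([3.0]), "b": np.array([4.0])}  # global norm 5
clipped = clip_by_global_norm(finite, max_norm=1.0)
print(clipped["a"], clipped["b"])  # roughly [0.6] [0.8]
```

Clipping by global norm keeps the direction of the overall update and only shrinks its length, whereas clipping each value independently can change the direction.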

Problem #2 - OOM Galore

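Uh oh, looks like the vast.ai instance I used didn’t have enough GPU memory: TensorFlow created the RTX 3090 device with 22455 MB, and XLA wanted to allocate roughly 20.7 GiB in one go. Here is an example of one of the errors I ran into: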

Out of memory error here
Out of memory while trying to allocate 22271409880 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   29.54MiB
              constant allocation:         4B
        maybe_live_out allocation:   27.74MiB
     preallocated temp allocation:   20.74GiB
                 total allocation:   20.77GiB
Peak buffers:
        Buffer 1:
                Size: 3.40GiB
                Operator: op_type="EmptyTensorList" op_name="gradient_tape/deep_handwriting_synthesis_model_1/rnn_1/while/deep_handwriting_synthesis_model_1/rnn_1/while/attention_rnn_cell_1/lstm_peephole_cell_2_1/MatMul/ReadVariableOp_0/accumulator" source_file="/root/code/venv/lib/python3.11/site-packages/tensorflow/python/framework/ops.py" source_line=1177
                XLA Label: fusion
                Shape: f32[1200,476,1600]
                ==========================

        Buffer 2:
                Size: 3.40GiB
                Operator: op_type="EmptyTensorList" op_name="gradient_tape/deep_handwriting_synthesis_model_1/rnn_1/while/deep_handwriting_synthesis_model_1/rnn_1/while/attention_rnn_cell_1/lstm_peephole_cell_2_1/MatMul/ReadVariableOp_0/accumulator" source_file="/root/code/venv/lib/python3.11/site-packages/tensorflow/python/framework/ops.py" source_line=1177
                XLA Label: fusion
                Shape: f32[1200,476,1600]
                ==========================

        Buffer 3:
                Size: 2.86GiB
                Operator: op_type="EmptyTensorList" op_name="gradient_tape/deep_handwriting_synthesis_model_1/rnn_1/while/deep_handwriting_synthesis_model_1/rnn_1/while/attention_rnn_cell_1/lstm_peephole_cell_2_1/MatMul_1/ReadVariableOp_0/accumulator" source_file="/root/code/venv/lib/python3.11/site-packages/tensorflow/python/framework/ops.py" source_line=1177
                XLA Label: fusion
                Shape: f32[1200,400,1600]
                ==========================

        Buffer 4:
                Size: 2.86GiB
                Operator: op_type="EmptyTensorList" op_name="gradient_tape/deep_handwriting_synthesis_model_1/rnn_1/while/deep_handwriting_synthesis_model_1/rnn_1/while/attention_rnn_cell_1/lstm_peephole_cell_2_1/MatMul_1/ReadVariableOp_0/accumulator" source_file="/root/code/venv/lib/python3.11/site-packages/tensorflow/python/framework/ops.py" source_line=1177
                XLA Label: fusion
                Shape: f32[1200,400,1600]
                ==========================

        Buffer 5:
                Size: 2.86GiB
                Operator: op_type="EmptyTensorList" op_name="gradient_tape/deep_handwriting_synthesis_model_1/rnn_1/while/deep_handwriting_synthesis_model_1/rnn_1/while/attention_rnn_cell_1/lstm_peephole_cell_2_1/MatMul_1/ReadVariableOp_0/accumulator" source_file="/root/code/venv/lib/python3.11/site-packages/tensorflow/python/framework/ops.py" source_line=1177
                XLA Label: fusion
                Shape: f32[1200,400,1600]
                ==========================

        Buffer 6:
                Size: 556.64MiB
                Operator: op_type="EmptyTensorList" op_name="gradient_tape/deep_handwriting_synthesis_model_1/rnn_1/while/deep_handwriting_synthesis_model_1/rnn_1/while/attention_rnn_cell_1/lstm_peephole_cell_1/MatMul/ReadVariableOp_0/accumulator" source_file="/root/code/venv/lib/python3.11/site-packages/tensorflow/python/framework/ops.py" source_line=1177
                XLA Label: fusion
                Shape: f32[1200,76,1600]
                ==========================

        Buffer 7:
                Size: 219.73MiB
                Operator: op_type="While" op_name="deep_handwriting_synthesis_model_1/rnn_1/while" source_file="/root/code/venv/lib/python3.11/site-packages/tensorflow/python/framework/ops.py" source_line=1177
                XLA Label: fusion
                Shape: f32[1200,64,10,75]
                ==========================

        Buffer 8:
                Size: 219.73MiB
                Operator: op_type="While" op_name="deep_handwriting_synthesis_model_1/rnn_1/while" source_file="/root/code/venv/lib/python3.11/site-packages/tensorflow/python/framework/ops.py" source_line=1177
                XLA Label: fusion
                Shape: f32[1200,64,10,75]
                ==========================

        Buffer 9:
                Size: 219.73MiB
                Operator: op_type="While" op_name="deep_handwriting_synthesis_model_1/rnn_1/while" source_file="/root/code/venv/lib/python3.11/site-packages/tensorflow/python/framework/ops.py" source_line=1177
                XLA Label: fusion
                Shape: f32[1200,64,10,75]
                ==========================

        Buffer 10:
                Size: 139.45MiB
                Operator: op_type="While" op_name="deep_handwriting_synthesis_model_1/rnn_1/while" source_file="/root/code/venv/lib/python3.11/site-packages/tensorflow/python/framework/ops.py" source_line=1177
                XLA Label: fusion
                Shape: f32[1200,64,476]
                ==========================

        Buffer 11:
                Size: 139.45MiB
                Operator: op_type="While" op_name="deep_handwriting_synthesis_model_1/rnn_1/while" source_file="/root/code/venv/lib/python3.11/site-packages/tensorflow/python/framework/ops.py" source_line=1177
                XLA Label: fusion
                Shape: f32[1200,64,476]
                ==========================

        Buffer 12:
                Size: 139.45MiB
                Operator: op_type="While" op_name="deep_handwriting_synthesis_model_1/rnn_1/while" source_file="/root/code/venv/lib/python3.11/site-packages/tensorflow/python/framework/ops.py" source_line=1177
                XLA Label: fusion
                Shape: f32[1200,64,476]
                ==========================

        Buffer 13:
                Size: 117.19MiB
                Operator: op_type="While" op_name="deep_handwriting_synthesis_model_1/rnn_1/while" source_file="/root/code/venv/lib/python3.11/site-packages/tensorflow/python/framework/ops.py" source_line=1177
                XLA Label: fusion
                Shape: f32[1200,64,400]
                ==========================

        Buffer 14:
                Size: 117.19MiB
                Operator: op_type="While" op_name="deep_handwriting_synthesis_model_1/rnn_1/while" source_file="/root/code/venv/lib/python3.11/site-packages/tensorflow/python/framework/ops.py" source_line=1177
                XLA Label: fusion
                Shape: f32[1200,64,400]
                ==========================

        Buffer 15:
                Size: 117.19MiB
                Operator: op_type="While" op_name="deep_handwriting_synthesis_model_1/rnn_1/while" source_file="/root/code/venv/lib/python3.11/site-packages/tensorflow/python/framework/ops.py" source_line=1177
                XLA Label: fusion
                Shape: f32[1200,64,400]
                ==========================


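Each peak buffer size in that log is the product of its shape and 4 bytes per float32, which makes it easy to see what is eating memory. Below is a quick sketch that checks the arithmetic. `buffer_bytes` and `fmt` are hypothetical helpers, and treating the leading 1200 as the number of unrolled timesteps is an assumption. Under that reading, the gradient tape appears to keep a per-timestep copy of each weight matrix (the `ReadVariableOp_0/accumulator` buffers), so memory grows linearly with sequence length.

```python
from math import prod

BYTES_PER_F32 = 4

def buffer_bytes(shape, bytes_per_elem=BYTES_PER_F32):
    """Size in bytes of a dense tensor with the given shape."""
    return prod(shape) * bytes_per_elem

def fmt(nbytes):
    """Format a byte count in binary units, like the XLA log does."""
    if nbytes >= 2**30:
        return f"{nbytes / 2**30:.2f}GiB"
    return f"{nbytes / 2**20:.2f}MiB"

# Shapes copied from the peak buffers in the log above.
peak_shapes = {
    "Buffers 1-2": (1200, 476, 1600),
    "Buffers 3-5": (1200, 400, 1600),
    "Buffer 6": (1200, 76, 1600),
    "Buffers 7-9": (1200, 64, 10, 75),
    "Buffers 10-12": (1200, 64, 476),
    "Buffers 13-15": (1200, 64, 400),
}

for name, shape in peak_shapes.items():
    print(f"{name:14s} {str(shape):22s} {fmt(buffer_bytes(shape))}")

# Halving the leading (assumed timestep) dimension halves the buffer.
print(fmt(buffer_bytes((600, 476, 1600))))
```

The loop reproduces every size in the log (3.40GiB, 2.86GiB, 556.64MiB, 219.73MiB, 139.45MiB, 117.19MiB). The last line prints 1.70GiB, half of the largest buffer.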


Sanity Check - Validating Model Dimensions (with AI… so somewhat)

So where does AI come in? I wanted to validate that the shapes of my training parameters all looked good. Again, I print out the shapes on the very first epoch to get some more detail and do a quick dimensionality alignment check. I then fed that into ChatGPT along with the paper and asked it to double-check all my params. Here was its output:

Model Summary

Model Name: deep_handwriting_synthesis_model
Total Parameters: 7,272,063 (27.74 MB)
Trainable Parameters: 3,636,031 (13.87 MB)
Non-Trainable Parameters: 0 (0.00 B)
Optimizer Parameters: 3,636,032 (13.87 MB)


Layer-wise Breakdown

1. LSTM Peephole Cells

You have three LSTM peephole cells:

  • lstm_peephole_cell
  • lstm_peephole_cell_1
  • lstm_peephole_cell_2
Parameters and Shapes

For each LSTM peephole cell:

  • Kernel Shape: (input_dim, 4 * units)
  • Recurrent Kernel Shape: (units, 4 * units)
  • Peephole Weights Shape: (units, 3)
  • Bias Shape: (4 * units,)
Shapes:

First LSTM Peephole Cell (lstm_peephole_cell):

  • Input Dimension: 76
  • Units: 400
  • Kernel Shape: (76, 1600), i.e. (76, 4 * 400)
  • Recurrent Kernel Shape: (400, 1600), i.e. (400, 4 * 400)
  • Peephole Weights Shape: (400, 3), i.e. (400 units, 3 gates)
  • Bias Shape: (1600,), i.e. (4 * 400,)

Second and Third LSTM Peephole Cells (lstm_peephole_cell_1 and lstm_peephole_cell_2):

  • Input Dimension: 476
    This includes concatenated inputs from previous layers and attention outputs.
  • Units: 400
  • Kernel Shape: (476, 1600), i.e. (476, 4 * 400)
  • Recurrent Kernel Shape: (400, 1600), i.e. (400, 4 * 400)
  • Peephole Weights Shape: (400, 3), i.e. (400 units, 3 gates)
  • Bias Shape: (1600,), i.e. (4 * 400,)
Conclusion

The shapes of the LSTM peephole cells’ parameters are correct and align with standard LSTM architectures with peephole connections.


2. Attention Mechanism

  • Layer: attention (AttentionMechanism)
Parameters and Shapes
  • Kernel Shape: (476, 30)
  • Bias Shape: (30,)

Explanation:

  • Input Dimension: 476
    This likely comes from concatenating the window vector w, the current input x_t, and the output from the first LSTM layer s1_out.
  • Output Dimension: 30
    Corresponds to parameters needed for the attention mechanism, such as the attention weights.
Conclusion

The attention mechanism’s parameter shapes are appropriate.


3. Mixture Density Network (MDN) Layer

  • Layer: mdn (MixtureDensityLayer)
Parameters and Shapes
  • mdn_W_pi: (400, 20)
  • mdn_b_pi: (20,)
  • mdn_W_mu: (400, 40)
  • mdn_b_mu: (40,)
  • mdn_W_sigma: (400, 40)
  • mdn_b_sigma: (40,)
  • mdn_W_rho: (400, 20)
  • mdn_b_rho: (20,)
  • mdn_W_eos: (400, 1)
  • mdn_b_eos: (1,)

Explanation:

  • Hidden Units from Last LSTM Layer: 400
  • Number of Mixture Components: 20
MDN Outputs:
  • pi (mixture weights): 20 components
  • mu (means): 2 coordinates per component * 20 components = 40 outputs
  • sigma (standard deviations): 2 per component * 20 components = 40 outputs
  • rho (correlation coefficients): 20 components
  • eos (end-of-stroke probability): 1 output
Parameter Shapes:
  • Weights:
      • mdn_W_pi: (400, 20)
      • mdn_W_mu: (400, 40)
      • mdn_W_sigma: (400, 40)
      • mdn_W_rho: (400, 20)
      • mdn_W_eos: (400, 1)
  • Biases:
      • mdn_b_pi: (20,)
      • mdn_b_mu: (40,)
      • mdn_b_sigma: (40,)
      • mdn_b_rho: (20,)
      • mdn_b_eos: (1,)
Conclusion

The MDN layer’s parameter shapes correctly correspond to the required outputs for a mixture of 20 bivariate Gaussian components.


4. Total Parameters and Trainable Parameters

  • Total Parameters: 7,272,063
  • Trainable Parameters: 3,636,031

Observation:

  • The total parameters are approximately double the trainable parameters.
  • This may be due to counting both the weights and optimizer states (like moments in the Adam optimizer).
  • Optimizer Parameters: 3,636,032 (slightly off by one, which could be due to rounding or reporting).
Conclusion

The parameter counts make sense, and the discrepancy is acceptable.


Your Model’s Parameters

All Parameters

All parameters listed are correctly shaped and correspond to their respective layers.

Trainable Parameters

All trainable parameters match the parameters in the model and have the correct shapes.


Answering Your Question

Yes, based on the model summary and the parameter shapes you’ve provided, everything appears to be in order. The shapes align with the expected dimensions given your model’s architecture:

  • LSTM Peephole Cells:
      • Input dimensions and units are consistent.
      • Kernels and recurrent kernels have correct shapes.
      • Peephole weights and biases are appropriately sized.

  • Attention Mechanism:
      • The input and output dimensions match the concatenated inputs and expected output size.

  • MDN Layer:
      • The weights and biases correspond correctly to the number of mixture components and the outputs required for the mixture density network.

  • Parameter Counts:
      • The total number of parameters is reasonable for a model of this complexity.
      • The trainable parameters are correctly identified.

Final Thoughts

  • Proceed with Training: You can confidently continue running your training code.

  • Monitoring: Keep an eye on training metrics such as loss, especially since you have confirmed that negative loss values are acceptable in your MDN context.
  • Validation: Consider using a validation set to monitor overfitting and generalization performance.
  • Performance Checks: If you encounter any unexpected behavior during training, revisit the data preprocessing steps and ensure that the inputs to your model are correctly shaped and normalized.

Additional Recommendations

  • Logging: Utilize TensorBoard or similar tools to visualize training progress.
  • Batch Size and Sequence Length: Ensure that your batch sizes and sequence lengths are manageable to prevent memory issues, as discussed earlier.
  • Model Saving: Regularly save model checkpoints to avoid losing progress in case of interruptions.
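ChatGPT’s per-layer shapes can also be checked by hand. Below is a small sketch that recomputes the counts from the shapes in the summary (the helper names are hypothetical). The trainable count comes out to exactly 3,636,031. The reported total is exactly trainable plus optimizer parameters, so nothing is being rounded. The optimizer count being one larger than the trainable count would fit an optimizer that keeps one slot per weight plus a scalar step counter, but that is an inference, not something the summary states.

```python
def peephole_lstm_params(input_dim, units):
    """Kernel + recurrent kernel + peephole weights (3 gates) + bias."""
    kernel = input_dim * 4 * units
    recurrent = units * 4 * units
    peephole = units * 3
    bias = 4 * units
    return kernel + recurrent + peephole + bias

def dense_params(input_dim, output_dim):
    """Weight matrix + bias vector."""
    return input_dim * output_dim + output_dim

lstm = peephole_lstm_params(76, 400) + 2 * peephole_lstm_params(476, 400)
attention = dense_params(476, 30)
mdn = sum(dense_params(400, n) for n in (20, 40, 40, 20, 1))  # pi, mu, sigma, rho, eos
trainable = lstm + attention + mdn

reported_optimizer = 3_636_032
print(lstm, attention, mdn, trainable)        # 3573200 14310 48521 3636031
print(trainable + reported_optimizer)         # 7272063, the reported total
print(f"{trainable * 4 / 2**20:.2f} MiB")     # 13.87 MiB of float32 weights
```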


✍️ Visualizations

Learning with Dummy Data

Again, we used dummy data to start with to ensure our various components were learning and converging correctly.

I’m not going to burn too many pixels with these visualizations given I think they’re less interesting.

Here is our entire network, sampling only from the means (not showing the mixture densities), across the entire example datasets. Note how the LSTMs can still handle these larger contexts. It pales in comparison to modern transformer context lengths, but it’s still impressive.

handwriting_loop_lstm_simple

handwriting_zig_lstm_simple
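The MDN head and the "sampling from the means" view mentioned above can be sketched together. This sketch makes several assumptions. It uses separate per-output weights, matching the shapes in the model summary (pi, mu, sigma, rho, eos). It uses Graves’ output transforms: softmax, exp, tanh, and his $e = 1 / (1 + \exp(\hat e))$ form for end-of-stroke. The (x, y) interleaving of mu and sigma is also an assumption. `mdn_head` and `sample_point` are hypothetical names. The post doesn’t say whether the means-only plots used the top component’s mean or a π-weighted mean, so `use_mean` implements one possible reading.

```python
import numpy as np

def mdn_head(h, params):
    """Turn hidden states h [batch, H] into mixture parameters.

    params maps name -> (W, b) with output sizes pi: M, mu: 2M,
    sigma: 2M, rho: M, eos: 1 (M = 20 in the summary).
    """
    raw = {name: h @ W + b for name, (W, b) in params.items()}
    m = raw["rho"].shape[-1]
    logits = raw["pi"] - raw["pi"].max(axis=-1, keepdims=True)  # stable softmax
    pi = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
    mu = raw["mu"].reshape(-1, m, 2)                # assumed (x, y) interleaving
    sigma = np.exp(raw["sigma"]).reshape(-1, m, 2)  # std devs must be positive
    rho = np.tanh(raw["rho"])                       # correlation in (-1, 1)
    eos = 1.0 / (1.0 + np.exp(raw["eos"][:, 0]))    # Graves' 1 / (1 + exp(e_hat))
    return pi, mu, sigma, rho, eos

def sample_point(pi, mu, sigma, rho, eos, rng, use_mean=False):
    """Draw (dx, dy, pen_up) for one sequence at one timestep.

    pi [M], mu [M, 2], sigma [M, 2], rho [M], eos scalar.
    use_mean=True skips the noise and returns the mean of the
    highest-weighted component (one possible "means only" view).
    """
    if use_mean:
        k = int(np.argmax(pi))
        dx, dy = mu[k]
        pen_up = float(eos > 0.5)
    else:
        k = rng.choice(len(pi), p=pi)              # pick a component by weight
        sx, sy = sigma[k]
        cov_xy = rho[k] * sx * sy
        cov = np.array([[sx**2, cov_xy], [cov_xy, sy**2]])
        dx, dy = rng.multivariate_normal(mu[k], cov)
        pen_up = float(rng.random() < eos)         # Bernoulli end-of-stroke
    return np.array([dx, dy, pen_up])

rng = np.random.default_rng(0)
H, M = 400, 20  # hidden units and mixture components from the summary
sizes = {"pi": M, "mu": 2 * M, "sigma": 2 * M, "rho": M, "eos": 1}
params = {n: (rng.normal(0, 0.01, (H, s)), np.zeros(s)) for n, s in sizes.items()}
pi, mu, sigma, rho, eos = mdn_head(rng.normal(size=(1, H)), params)
print(sample_point(pi[0], mu[0], sigma[0], rho[0], eos[0], rng))
```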

Synthesis Model Sampling

So again, given the above information, $\phi(t, u)$ represents the network’s belief that it’s writing character $c_u$ at time $t$. The window position $\kappa$ is monotonically increasing (which makes sense and is enforced mathematically), so the attention only moves forward through the text, and we can see it advance in a pretty stepwise fashion.
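Here is a sketch of how $\phi(t, u)$ and its forward-only position can be computed, following Graves’ soft window. The attention layer’s 30 outputs split into $\hat\alpha, \hat\beta, \hat\kappa$ for $K = 10$ Gaussians. $\kappa_t = \kappa_{t-1} + \exp(\hat\kappa_t)$ can only move forward. $K = 10$ and $U = 75$ are inferred from the 30 attention outputs and the f32[1200,64,10,75] buffers in the OOM log, so treat them as assumptions. The vocabulary size of 73 and the name `soft_window` are also hypothetical. Positions are 0-indexed here, whereas the paper counts characters from 1.

```python
import numpy as np

def soft_window(params, kappa_prev, chars):
    """One step of a Graves-style soft attention window.

    params:     [batch, 3K] raw attention outputs (alpha_hat, beta_hat, kappa_hat)
    kappa_prev: [batch, K] window positions from the previous timestep
    chars:      [batch, U, vocab] one-hot character sequence
    returns w [batch, vocab], phi [batch, U], kappa [batch, K]
    """
    alpha_hat, beta_hat, kappa_hat = np.split(params, 3, axis=-1)
    alpha = np.exp(alpha_hat)               # importance of each Gaussian
    beta = np.exp(beta_hat)                 # sharpness of each Gaussian
    kappa = kappa_prev + np.exp(kappa_hat)  # strictly increases every step
    u = np.arange(chars.shape[1])           # character positions 0..U-1
    # phi[b, u] = sum_k alpha[b, k] * exp(-beta[b, k] * (kappa[b, k] - u) ** 2)
    dist_sq = (kappa[:, :, None] - u[None, None, :]) ** 2
    phi = np.sum(alpha[:, :, None] * np.exp(-beta[:, :, None] * dist_sq), axis=1)
    w = np.einsum("bu,buv->bv", phi, chars)  # window = phi-weighted sum of chars
    return w, phi, kappa

rng = np.random.default_rng(0)
batch, K, U, vocab = 2, 10, 75, 73
chars = np.eye(vocab)[rng.integers(0, vocab, size=(batch, U))]
kappa = np.zeros((batch, K))
for t in range(3):
    params = rng.normal(size=(batch, 3 * K))
    w, phi, kappa_next = soft_window(params, kappa, chars)
    assert np.all(kappa_next > kappa)  # the window never moves backwards
    kappa = kappa_next
print(w.shape, phi.shape, kappa.shape)  # (2, 73) (2, 75) (2, 10)
```

Because $\exp(\hat\kappa) > 0$, the monotonicity is baked into the parameterization rather than learned.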

One of my favorite parts of the visualizations below is the mixture component weights. You can see different Gaussians activating for different parts of the synthesis network. For example, the end-of-stroke signals are owned by their own separate Gaussians.

Most of these were generated like so:

╭─johnlarkin@Mac ~/Documents/coding/generative-handwriting-jax main*
╰─➤  uv run python generative_handwriting/generate/generate_handwriting_cpu.py \
    --checkpoint "checkpoints_saved/synthesis/loss_-2.59/checkpoint_216_cpu.pkl" \
    --text "It has to be symphonic" \
    --bias "0.75" \
    --temperature "0.75" \
    --fps "60" \
    --formats "all" \
    --seed "42"

Another note: my termination logic could probably be improved. Remember, the one-hot encoded text includes a null terminator, so the null terminator sits at index len(line_text). Attention spans the full character sequence: $\phi$ has shape [batch, char_seq_length], so we take our single sample (batch index 0) and look along the character sequence to see where the attention is focused. In code, here’s what I’m doing:

        # char_seq includes null terminator at index len(line_text)
        # only consider stopping once we've generated at least 2 steps per character
        if phi is not None and t >= len(line_text) * 2:
            char_idx = int(jnp.argmax(phi[0]))
            sampled_eos = stroke[2] > 0.5

            # we can stop when:
            # 1. attention has reached the null terminator (char_idx == len(line_text)) AND we sampled EOS
            # 2. attention weight on null terminator is dominant (> 0.5)
            # 3. failsafe: attention is at or past the end of the text and we've generated
            #    far more steps than expected (more than 10 per character)
            null_attention = float(phi[0][len(line_text)]) if len(phi[0]) > len(line_text) else 0.0

            if char_idx == len(line_text) and sampled_eos:
                # this is the exit we hit most often
                break
            elif null_attention > 0.5:
                # attention strongly focused on null terminator
                break
            elif char_idx >= len(line_text) and t > len(line_text) * 10:
                # failsafe: past text and generated way too much
                break
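The stopping rule relies on index len(line_text) of the character sequence being the null terminator. Below is a minimal sketch of that kind of encoding. It assumes a hypothetical alphabet with index 0 reserved for the null terminator and a padded length of 75; the real vocabulary and padding scheme aren’t shown in this post. Padding rows stay all-zero so they carry no character.

```python
import numpy as np

# Hypothetical alphabet; index 0 is reserved for the null terminator.
ALPHABET = "\0" + " abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ.,'!?"
CHAR_TO_IDX = {c: i for i, c in enumerate(ALPHABET)}

def encode_line(line_text, max_len):
    """One-hot encode text plus a trailing null terminator, zero-padded to max_len."""
    if len(line_text) + 1 > max_len:
        raise ValueError("line (plus terminator) does not fit in max_len")
    idx = [CHAR_TO_IDX[c] for c in line_text] + [CHAR_TO_IDX["\0"]]
    one_hot = np.zeros((max_len, len(ALPHABET)), dtype=np.float32)
    one_hot[np.arange(len(idx)), idx] = 1.0  # one 1.0 per occupied row
    return one_hot

text = "It has to be symphonic"
one_hot = encode_line(text, max_len=75)
print(one_hot.shape, one_hot[len(text)].argmax())  # terminator row points at index 0
```

With this layout, phi[0][len(line_text)] in the loop above is the attention weight on the terminator.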

Finally, on the visualization front, I’m generating everything with bias 0.75 and temperature 0.75. I’m not going to discuss those here; the original paper covers biased sampling in more detail.
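For reference, Graves’ biased sampling uses a bias $b \ge 0$ that sharpens the mixture weights and shrinks every standard deviation: $\pi = \mathrm{softmax}(\hat\pi (1 + b))$ and $\sigma = \exp(\hat\sigma - b)$. Below is a sketch, assuming the `--bias` flag maps to that $b$. `apply_bias` is a hypothetical helper. The temperature flag’s implementation isn’t shown in this post, so it is left out.

```python
import numpy as np

def apply_bias(pi_hat, sigma_hat, bias):
    """Graves-style biased sampling adjustment (hypothetical helper).

    pi_hat:    [..., M] mixture-weight logits before the softmax
    sigma_hat: [..., M, 2] log standard deviations before the exp
    bias:      b >= 0; b = 0 leaves the distribution unchanged
    """
    scaled = pi_hat * (1.0 + bias)                         # sharpen the mixture weights
    scaled = scaled - scaled.max(axis=-1, keepdims=True)  # numerically stable softmax
    pi = np.exp(scaled) / np.exp(scaled).sum(axis=-1, keepdims=True)
    sigma = np.exp(sigma_hat - bias)                       # shrink every std dev
    return pi, sigma

pi_hat = np.array([1.0, 0.0])
sigma_hat = np.zeros((2, 2))
for b in (0.0, 0.75):
    pi, sigma = apply_bias(pi_hat, sigma_hat, b)
    print(b, pi.round(3), sigma[0, 0].round(3))
```

With b = 0.75, the top component’s weight rises from about 0.73 to about 0.85, and each σ shrinks by a factor of $e^{-0.75} \approx 0.47$. This trades diversity for legibility.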


Heart has its reasons

One thing to note is that we are still constrained by line length. For example, if we try to write this as a single line, the attention starts to drift and the context breaks down. Part of the reason is that once we exceed the line lengths we trained on (in terms of stroke sequence length or input text length), the model starts to flail.

So note the discrepancy between these two when we introduce a line break:

heart-mdn-aggregate

vs

heart-oneliner-mdn-aggregate

The one-liner shows larger deviations toward the end of the line, which suggests the model is less well trained on lines that long. Note: the MDN heatmaps at the bottom are created by taking the three highest-weighted $\pi$ components at each timestep and aggregating them across all timesteps.
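Here is a sketch of that top-3 aggregation, using the bivariate Gaussian density from the paper. It assumes two things: each component’s density is weighted by its π, and μ has already been converted from predicted offsets to absolute pen positions. `mdn_heatmap` is a hypothetical helper, not the code behind the plots.

```python
import numpy as np

def bivariate_pdf(x, y, mx, my, sx, sy, rho):
    """Density of a correlated 2-D Gaussian evaluated on grids x, y."""
    zx, zy = (x - mx) / sx, (y - my) / sy
    z = zx**2 + zy**2 - 2 * rho * zx * zy
    norm = 2 * np.pi * sx * sy * np.sqrt(1 - rho**2)
    return np.exp(-z / (2 * (1 - rho**2))) / norm

def mdn_heatmap(pi, mu, sigma, rho, xs, ys, top_k=3):
    """Sum the top_k highest-weighted components over every timestep.

    pi [T, M], mu [T, M, 2], sigma [T, M, 2], rho [T, M]; mu is assumed
    to be in absolute canvas coordinates (offsets already cumulatively summed).
    """
    gx, gy = np.meshgrid(xs, ys)
    heat = np.zeros_like(gx, dtype=np.float64)
    for t in range(pi.shape[0]):
        for k in np.argsort(pi[t])[-top_k:]:  # indices of the top_k weights
            heat += pi[t, k] * bivariate_pdf(
                gx, gy, mu[t, k, 0], mu[t, k, 1],
                sigma[t, k, 0], sigma[t, k, 1], rho[t, k],
            )
    return heat
```

The result can be drawn with any image plot (e.g. `imshow`) underneath the sampled strokes.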

In these heatmaps, the end-of-stroke points generally have the highest uncertainty and the most spread-out sigmas, which makes sense given they’re the most variable points.
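Since the model degrades past the line lengths it saw in training, one workaround is to wrap the input text into lines and synthesize each line separately. Below is a minimal sketch using Python’s standard `textwrap` module. The 30-character budget is a hypothetical value, not the model’s actual limit, and the sample text is only an example input.

```python
import textwrap

def split_into_lines(text, max_chars=30):
    """Wrap text on word boundaries so each line has at most max_chars characters.

    max_chars is a hypothetical budget; it should sit below the longest
    text line seen in training. Words longer than max_chars are kept whole.
    """
    return textwrap.wrap(text, width=max_chars, break_long_words=False)

for line in split_into_lines("The heart has its reasons which reason knows nothing of"):
    print(line)  # each line would then be synthesized separately
```

This prints "The heart has its reasons" and "which reason knows nothing of" as two lines, each well within the budget.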


Loved and lost

better-to-have-loved-writing

better-to-have-loved-mdn


It has to be symphonic

symphonic-writing

symphonic-sampling


Is a model a lie?

model-lie-writing

model-lie-writing-colored

model-lie-mdn


Fish folly

I hadn’t heard this one before, but it’s my best friend’s favorite quote.

fish-folly-attention

fish-folly-attention

Conclusion

This - again - was a bit of a bear of a project. It was maybe not my best use of time, but it was a labor of love.

I don’t think I’ll embark on a project of this nature again for a while (sadly). However, I hope you enjoyed it, and feel free to pull the code and dive in yourself.

When re-reading my old draft blog post, I liked the way I ended things. So here it is:

Finally, I want to leave you with a quote from our academic advisor, Matt Zucker. When I asked him how we’d know our model was good enough, he responded with the following.

“Learning never stops.”