<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" ><generator uri="https://jekyllrb.com/" version="3.10.0">Jekyll</generator><link href="https://johnlarkin1.github.io/feed.xml" rel="self" type="application/atom+xml" /><link href="https://johnlarkin1.github.io/" rel="alternate" type="text/html" /><updated>2026-04-03T23:33:37+00:00</updated><id>https://johnlarkin1.github.io/feed.xml</id><title type="html">Where will you go next?</title><subtitle>John Larkin&apos;s personal coding blog and thought journal. Exploring various ideas in software engineering, math, data, and more in my (limited) free time.
</subtitle><author><name>johnlarkin1</name></author><entry><title type="html">Vanity Mirror</title><link href="https://johnlarkin1.github.io/2026/larkin-vanity-mirror/" rel="alternate" type="text/html" title="Vanity Mirror" /><published>2026-02-08T00:00:00+00:00</published><updated>2026-02-08T00:00:00+00:00</updated><id>https://johnlarkin1.github.io/2026/larkin-vanity-mirror</id><content type="html" xml:base="https://johnlarkin1.github.io/2026/larkin-vanity-mirror/"><![CDATA[<h1 id="context">Context</h1>

<p>In the age of AI, shipping has become easier than ever. And also, borderline more addictive than ever. I will certainly rant about it in a Substack post at some point in the future, but as my projects grow, I wanted an easy way to keep track of various metrics (sadly, I’m curious and vain and want to see what people like). I built this “Vanity Mirror” as a way to do that, and figured it was fine to share publicly (even if another <a href="https://react2shell.com/">React2Shell</a> RCE vulnerability occurs… are hackers really gonna want limited read-only permissions to my Google Analytics properties?).</p>

<h1 id="demo">Demo</h1>

<p>Feel free to check out the website here:</p>

<div class="project-registry">
  <a href="https://larkin-vanity-mirror.vercel.app" target="_blank" rel="noopener" class="registry-card web">
    <span class="lang-icon">🪞</span>
    <span class="lang-badge">Web</span>
    <span class="registry-name">Vanity Mirror</span>
  </a>
</div>

<p>But also it’s embedded here:</p>

<div class="vanity-mirror-iframe-wrapper">
  <iframe class="vanity-mirror-iframe" src="https://larkin-vanity-mirror.vercel.app/blog" title="Larkin Vanity Mirror Dashboard" width="1440" height="900" loading="lazy" allowfullscreen="">
  </iframe>
</div>

<h1 id="domain-name">Domain Name?</h1>

<p>I was too lazy / too broke to buy an official domain (although I’m sure the market for <code class="language-plaintext highlighter-rouge">larkin-vanity-mirror.xyz</code> can’t be too high). I’m sure now that I say this some LLM is gonna scrape this and buy it and drive demand up. c’est la vie.</p>

<h1 id="favorite-part">Favorite Part</h1>

<p>Regardless, my favorite part about this is that I took the shortcut of making this a <a href="https://developer.mozilla.org/en-US/docs/Web/Progressive_web_apps">PWA</a> so now it’s very easily hooked up into my mobile experience.</p>

<div class="video-container-mobile">
  <div class="video-wrapper-dark">
    <video src="https://www.dropbox.com/scl/fi/f6jbjd325w1irgknjw4l5/vanity-mirror-screen-recording.mp4?rlkey=swt2k91wtift2epfeqtf2z3oh&amp;st=gl332k68&amp;raw=1" muted="" autoplay="" loop="" controls="">
    </video>
  </div>
</div>

<p>To download, just follow these steps:</p>

<ol>
  <li>Go to <a href="https://larkin-vanity-mirror.vercel.app/"><strong>larkin-vanity-mirror.vercel.app</strong></a> on your mobile device</li>
  <li>Click context menu three dots in bottom right (on newer iOS)</li>
  <li>Click <code class="language-plaintext highlighter-rouge">Share</code></li>
  <li>Scroll down and go to <code class="language-plaintext highlighter-rouge">Add to Home Screen</code></li>
  <li>Voila 🎉</li>
</ol>

<p>Feel free to email / let me know if there’s enough interest and I can try to generalize it. Although honestly, at this point, jinja doesn’t seem to have much value over just ripping CC.</p>

<p>Thanks!</p>]]></content><author><name>johnlarkin1</name></author><category term="Development" /><category term="Reflection" /><summary type="html"><![CDATA[Context]]></summary></entry><entry><title type="html">Multi Armed Bandit</title><link href="https://johnlarkin1.github.io/2026/multi-armed-bandit/" rel="alternate" type="text/html" title="Multi Armed Bandit" /><published>2026-02-01T00:00:00+00:00</published><updated>2026-02-01T00:00:00+00:00</updated><id>https://johnlarkin1.github.io/2026/multi-armed-bandit</id><content type="html" xml:base="https://johnlarkin1.github.io/2026/multi-armed-bandit/"><![CDATA[<!-- 
<div class="markdown-alert markdown-alert-caution">
<p>This was meant to be for a take-home... I submitted some version of the first draft, but then couldn't stop and here we are. If parts trail off, it's because I shouldn't have even gone this deep into other more pressing matter
</p>
</div>

<br> -->

<div class="project-registry">
  <a href="https://github.com/johnlarkin1/multi-armed-bandit" target="_blank" rel="noopener" class="registry-card github">
    <span class="lang-icon">🐙</span>
    <span class="lang-badge">Source</span>
    <span class="registry-name">GitHub</span>
  </a>
</div>

<h1 id="motivation">Motivation</h1>

<p>Here is a motivating visual to build up some momentum to read on. This is our dashboard tool to compare various multi-armed bandit strategies. We’ll understand this more thoroughly at the end of this blog post.</p>

<div class="video-container">
  <div class="video-wrapper-dark">
    <video src="/videos/multi-armed-bandit/multi-armed-bandit.mp4" type="video/mp4" muted="" autoplay="" loop="" controls="" style="width: 100%; height: auto;">
    </video>
  </div>
</div>

<h1 id="context">Context</h1>

<p>Recently, I responded to some recruiters and fielded a couple of interviews.</p>

<p>I generally abhor interviewing. There are parts I absolutely love - meeting new people, learning about new technical challenges, studying up on businesses or industries - but there are also parts I <em>abhor</em>. Getting grilled on usage of the Web Speech API (man oh man was I in the wrong interview) or how to <a href="https://leetcode.com/problems/decode-string/description/">decode a string</a> in 2026 does feel… a bit perplexing. I’ll rant about it on Substack at some point in time.</p>

<p>However! I do genuinely enjoy take homes (as exemplified by <a href="/2024/book-brain">Book Brain</a>), despite them often being a bigger time constraint and more of a commitment.</p>

<p>This blog post is going to go over a concept and problem that (embarrassingly enough), I hadn’t yet seen before the take home. For more context, I had accepted another offer in the same timeframe, and withdrew from this specific takehome process. It’s unfortunate too because I do genuinely believe the company will be a $10BN company in no time, and the engineering seems fascinating.</p>

<p>While I ultimately withdrew from this interviewing cycle, and sent them only my thoughts on the problem, this blog post is going to talk about a take home question I received from that company. I’m anonymizing the company to keep the sanctity of their interview process.</p>

<p>The company restricted AI usage during the take-home, so I did a ton of research / watched YouTube videos. However, for this blog post, some details of implementation will be left to Claude. The repo includes documentation and detail, with various transcripts between Claude and me. So let’s begin with the problem.</p>

<h1 id="setup">Setup</h1>

<p>This blog post is going to focus on the <a href="https://en.wikipedia.org/wiki/Multi-armed_bandit">multi-armed bandit</a> problem, which is commonly abbreviated as MAB. There is a lot here, so I won’t be able to cover everything, but I’ll cover the parts that the corresponding Github repo covers.</p>

<h1 id="multi-armed-bandit-problem-mab">Multi-Armed Bandit Problem (MAB)</h1>

<p>The <a href="https://en.wikipedia.org/wiki/Multi-armed_bandit">traditional multi-armed bandit</a> is pretty well encapsulated by a hypothetical situation. I’ll give you the long / fun version, and then I’ll give you an abbreviated Wikipedia version.</p>

<hr />

<p>Imagine, you wake up.</p>

<p><img src="/assets/svg/multi-armed-bandit/life-is-full.svg" alt="Life is full" class="center-super-medium-shrink lightbox-image" /></p>

<p>You live in a beautiful city (let’s say Cincinnati).</p>

<p><img src="https://assets.simpleviewinc.com/sv-cincy/image/upload/c_fill,h_840,q_75,w_1200/v1/cms_resources/clients/cincy/msvachphotography_Instagram_1244_ig_17864840302845987_89e72393-d2a6-4837-bb9c-865845b1366b.jpg" alt="Cincinnati Skyline" class="center-shrink lightbox-image" /></p>

<div class="image-caption">Kudos to @msvachphotography for the shot from Mt. Echo Park</div>
<p><br /></p>

<p>But then you realize you have too much money in your pockets. You decide to gamble (I discourage this, especially after seeing how the sausage is made).</p>

<p>So you hit the casino!</p>

<p><img src="/assets/svg/multi-armed-bandit/too-much-mula.svg" alt="Life is full" class="center-shrink lightbox-image" /></p>

<p>However, because it’s Cincinnati, this is a very nice casino. You actually have a chance to win. However, they only have single-armed bandits - commonly known as slot machines! These are unique slot machines, and their underlying probability distributions become more apparent over time.</p>

<p>Despite having too much money in your pockets, you love winning, so you do want to win. Your problem therefore is to figure out the optimal strategy for which machines to play, when to play those machines, how many times to play them, and when you need to switch.</p>

<hr />

<p>Wikipedia more blandly (but also more succinctly) puts this as:</p>

<blockquote>
  <p>More generally, it is a problem in which a decision maker iteratively selects one of multiple fixed choices (i.e., arms or actions) when the properties of each choice are only partially known at the time of allocation, and may become better understood as time passes. A fundamental aspect of bandit problems is that choosing an arm does not affect the properties of the arm or other arms.[4]</p>
</blockquote>

<h2 id="stochastic-mab-approaches">Stochastic MAB Approaches</h2>

<p>Before we go any further, let’s fully dissect this problem.</p>

<p>There are really two main focuses that I covered in code and fully studied up on. I will not be talking about $\epsilon$-greedy approaches, but here are <a href="https://www.geeksforgeeks.org/machine-learning/epsilon-greedy-algorithm-in-reinforcement-learning/">some</a> <a href="https://www.geeksforgeeks.org/machine-learning/epsilon-greedy-algorithm-in-reinforcement-learning/">other</a> <a href="https://www.geeksforgeeks.org/machine-learning/epsilon-greedy-algorithm-in-reinforcement-learning/">resources</a>. We’re actually going to focus on UCB vs Thompson Sampling, which are two methods that work very well. I’ll discuss further below, in the implementation, my thoughts on how I modified them to handle the take-home explicitly.</p>

<h3 id="upper-confidence-bound"><a href="https://en.wikipedia.org/wiki/Upper_Confidence_Bound">Upper Confidence Bound</a></h3>

<p>The theory behind UCB is that we are trying to optimistically explore. UCB1 is meant to balance the level of exploration vs exploitation.</p>

<p>I am not going to go into the full derivation, but it references something called <a href="https://en.wikipedia.org/wiki/Hoeffding%27s_inequality">Hoeffding’s Inequality</a> to build up a framework.</p>

<p>It eventually lets us get to:</p>

\[UCB_i(t) = \bar{x}_i + \underbrace{c \cdot \sqrt{\frac{\ln(t)}{n_i}}}_{\text{exploration bonus}}\]

<p>Where:</p>
<ul>
  <li>$\bar{x}_i$ = empirical success rate of server $i$</li>
  <li>$t$ = total number of requests across all servers</li>
  <li>$n_i$ = number of times server $i$ has been tried</li>
  <li>$c$ = exploration constant (default: $\sqrt{2}$)</li>
</ul>

<p>Normally, you’ll see this kind of folded up with $c$ being part of the square root, but that exploration bonus was key in my modified UCB approach.</p>
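<p>To make the formula concrete, here is a minimal Python sketch of UCB1 arm selection. This is my own illustrative version, not the repo’s exact code; the name <code class="language-plaintext highlighter-rouge">ucb1_select</code> is hypothetical.</p>

```python
import math

def ucb1_select(successes, pulls, t, c=math.sqrt(2)):
    """Pick the arm maximizing x_bar_i + c * sqrt(ln(t) / n_i).

    successes[i] / pulls[i] is the empirical success rate of arm i,
    t is the total number of pulls so far, and c scales the
    exploration bonus (default sqrt(2), as in the formula above).
    """
    best_arm, best_score = 0, float("-inf")
    for i, n_i in enumerate(pulls):
        if n_i == 0:
            return i  # an untried arm has an unbounded bonus: try it
        score = successes[i] / n_i + c * math.sqrt(math.log(t) / n_i)
        if score > best_score:
            best_arm, best_score = i, score
    return best_arm
```

<p>Tuning $c$ up makes the strategy explore longer; tuning it down makes it exploit the current best arm sooner.</p>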

<h3 id="thompson-sampling"><a href="https://en.wikipedia.org/wiki/Thompson_sampling">Thompson Sampling</a></h3>

<p>With this approach, the derivation can actually make a bit more sense (in my opinion). It’s also (probably relatedly) the approach I like the most.</p>

<p>We model the process for the specific outcome of the arm $a$ as a <a href="https://en.wikipedia.org/wiki/Bernoulli_distribution">Bernoulli distribution</a>. Basically, it means we have a $p$ probability of getting a 1 (in this case, a reward; in our specific case further down, a successful downstream server request). The value 0 has a probability $q = 1 - p$ of occurring.</p>

<p>We can then model this uncertainty about the Bernoulli parameter $p$ as a <a href="https://en.wikipedia.org/wiki/Beta_distribution">beta distribution</a>. We’re trying to figure out the probability $p$ for each arm $a$ (or further on as we’ll see, the downstream server).</p>

<p>Think of using our beta distribution as a heuristic for what we actually think about each arm. With Thompson sampling, we’re basically maintaining a best-guess distribution for each of the arms and updating it as we go and learn more information. I believe the technical term for this is <em>conjugate prior</em>: the beta distribution is conjugate to the Bernoulli likelihood, so our posterior is also a beta distribution.</p>

<p>Formally, the beta distribution has an $\alpha$ and a $\beta$ that control the shape of the distribution. They are exponents of the variable and the variable’s complement respectively. So again, this can be written as:</p>

\[f(x; \alpha, \beta) = \text{constant} \cdot x^{\alpha - 1} \, (1 - x)^{\beta - 1}\]

<p>Then our logic is pretty straightforward given how we’re modeling this. For every success of the arm, we can update our $\alpha$ with a simple $\alpha' = \alpha + 1$, and for every failure, we can update our $\beta$ (given it’s modeling the complement) as $\beta' = \beta + 1$.</p>
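<p>In code, the whole Beta-Bernoulli loop is only a few lines. Here is a hedged sketch (my own illustrative version, not the repo’s exact strategy class; the function names are hypothetical):</p>

```python
import random

def thompson_select(alphas, betas, rng=random):
    """Draw one sample from each arm's Beta(alpha, beta) posterior
    and play the arm whose sampled success probability is highest."""
    draws = [rng.betavariate(a, b) for a, b in zip(alphas, betas)]
    return max(range(len(draws)), key=draws.__getitem__)

def thompson_update(alphas, betas, arm, success):
    """A success bumps alpha; a failure bumps beta (the complement)."""
    if success:
        alphas[arm] += 1
    else:
        betas[arm] += 1
```

<p>Starting from $\alpha = \beta = 1$ (a uniform prior), the posteriors sharpen around each arm’s true $p$ as outcomes accumulate, which is exactly what the visualization below animates.</p>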

<p>A picture is worth a thousand words, so an interactive visualization must be worth at least a million right? This is a Claude generated vanilla JS + Chart.js artifact. I’d recommend autoplaying or doing the <code class="language-plaintext highlighter-rouge">Run Thompson Round</code>, but you can also see results by adding success and failures to the various arms. The main point is that you’ll see how our beta distributions should steadily converge to the real $p$ with increasing accuracy.</p>

<div class="interactive-beta-viz" data-arms="3" data-true-probs="0.7,0.4,0.55"></div>

<h1 id="multi-armed-bandit-variants">Multi-Armed Bandit Variants</h1>

<p>The situation I described above is really the stochastic MAB. There’s a finite set of arms, and the reward distribution is unknown. As I learned throughout this process, there are many variants and generalizations of this problem. Specifically, these are <em>generalizations</em> where the MAB is extended by adding some information or structure to the problem. Namely:</p>

<ul>
  <li><a href="https://en.wikipedia.org/wiki/Multi-armed_bandit#Adversarial_bandit">adversarial bandits</a>
    <ul>
      <li>this is probably my favorite variant. the notion is that you have an adversary that is trying to <strong>maximize</strong> your regret, while you’re trying to minimize your regret. so they’re basically trying to trick or con your algorithm.</li>
      <li>if you’re asking yourself (like I did), ok well then why doesn’t the adversary just assign $r_{a,t} = 0$ as the reward function for all arms $a$ at time $t$, well… you shouldn’t really think about it in terms of reward. Reward is relative. We instead want to think about it in terms of <em>regret</em> which I’ll talk more about later. There are two subvariants (<a href="https://www.cs.cornell.edu/~rdk/papers/anytime.pdf">oblivious adversary</a> and <a href="https://ui.adsabs.harvard.edu/abs/2006cs........2053D/abstract">adaptive adversary</a>), but we’re not going to discuss those - although a very interesting extension is the <a href="https://en.wikipedia.org/wiki/Multi-armed_bandit#:~:text=%5Bedit%5D-,Exp3,-%5Bedit%5D">EXP3</a> algorithm.</li>
    </ul>
  </li>
  <li><a href="https://towardsdatascience.com/an-overview-of-contextual-bandits-53ac3aa45034/">contextual bandits</a>
    <ul>
      <li>the notion here is that instead of learning $E[r \mid a]$ where again $r$ is the reward and $a$ is the arm you pick, you’re learning $E[r \mid x, a]$ where $x$ is some additional bit of context at time $t$ that you’re exposed to.</li>
    </ul>
  </li>
  <li><a href="https://doogkong.github.io/2017/slides/Yue.pdf">dueling bandits</a>
    <ul>
      <li>an interesting variant where instead of being exposed to the reward, your information is limited to just picking two bandits and only knowing which one is better comparatively… but again it’s stochastic. So you can inquire about the same two arms and it’s very feasible that you’ll get different results for the comparison. The whole notion is that you’re building up this preference matrix. Seems like an incredibly difficult problem.</li>
    </ul>
  </li>
</ul>

<h1 id="bandit-with-knapsack-bwk-variant">Bandit with Knapsack (BwK) Variant</h1>

<p>I’m going to preempt the reader and discuss another variant, where I’ll spend a bit more time. That model is the Bandit with Knapsack problem.</p>

<p>The original paper is from <a href="https://sites.google.com/site/ashwinkumarbv/home">Ashwinkumar Badanidiyuru</a>, <a href="https://www.cs.cornell.edu/~rdk/">Robert Kleinberg</a>, and <a href="https://scholar.google.com/citations?user=f2x233wAAAAJ&amp;hl=en">Aleksandrs Slivkins</a>. People who I’d love to be an iota as smart as. You can see the paper <a href="https://www.alphaxiv.org/abs/1305.2545">here</a>. It’s a 55-page paper, and I’d be lying if I said I read past the <strong>Preliminaries</strong> section. Section 3+ have some heavy math that is over my head.</p>

<p>The problem statement is relatively simple though. Your arms now have resources associated with them that they consume. I honestly think it’s easier to draw it out mathematically and reference the actual paper (also, shoutout to <a href="https://www.alphaxiv.org/">alphaxiv</a>; it’s got most of the normal arXiv features, just with some AI-native question answering and highlighting, which has been nice).</p>

<h2 id="formal-declaration">Formal Declaration</h2>

<p>I’d like to state that the paper starts out with the generalized form of <em>many</em> resources being managed and consumed. It makes sense given it’s a professional paper and the general case is more interesting. However, you can imagine $d$ being 1 and that we have a single resource that we’re managing.</p>

<p>So again, we have a finite set $X$ of arms from 1 to $m$. An individual arm can be declared as $x$. Formally, we can say</p>

\[X = \{ 1,\, 2,\, \ldots,\, x, \, \ldots, \,m-1,\, m \}\]

<p>There are $T$ rounds (which, interestingly enough, is known ahead of time in this variant). So $t$ indexes the round (one round per time increment).</p>

\[t \in \{1,\,2,\, \ldots,\, T-1,\, T \}\]

<p>There are $d$ resources, where $d \geq 1$, indexed by $i$ from $1,\, \ldots,\, d$. (The $d$ in our specific example is still going to be the number of servers, because each server has its own rate limit.)</p>

<p>So the problem now changes because at round $t$ when arm $x$ is pulled we now don’t just get a reward, but we instead get a reward and a consumption vector indicating how much of the resources were consumed. In other words,</p>

\[\left( r_t, c_{t,1}, \ldots , c_{t,d} \right)\]

<p>The paper declares this as $\pi_x$ where $\pi_x$ is an <strong>unknown latent distribution</strong> over $[0,1]^{d+1}$.</p>

<p>Now “latent spaces” have gotten a ton of usage since LLMs blew up, but basically this just means there is some distribution, and it is fixed, but it’s unknown to the learner.</p>

<p>Just to also break down the syntax, since $[0,1]^{d+1}$ can be a bit misleading: this just means</p>

\[[0,1]^{d+1} = \underbrace{[0,1] \times [0,1] \times \cdots \times [0,1]}_{d+1\ \text{times}}\]

<p>So it’s really just a vector of length $d+1$ (the +1 is because we have $d$ resources, but then one reward $r$, so it’s kind of a shorthand).</p>

<p>$\pi_x$ is a <strong>joint probability distribution</strong> over $(r, c_1, \ldots, c_d)$, i.e.,
\((r,\, c_1,\, \ldots,\, c_d) \sim \pi_x\)</p>

<p>meaning when you pull an arm, you draw one vector from this distribution.</p>

<p>This of course leads us to budgeting. Each resource $i$ has a budget $B_i \geq 0$.</p>

<p>The overall process stops as soon as we have exhausted <strong>ANY</strong> resource budget.</p>
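<p>Putting the pieces together, the BwK interaction loop looks something like this. This is a generic harness I am sketching purely for illustration (not from the paper or my repo); <code class="language-plaintext highlighter-rouge">select_arm</code> stands in for whichever BwK strategy you plug in.</p>

```python
def run_bwk(arms, budgets, T, select_arm):
    """Pull arms for up to T rounds, stopping as soon as ANY
    resource budget is exhausted.

    arms[x]() returns one draw (reward, [c_1, ..., c_d]) from the
    latent distribution pi_x; budgets has one entry per resource.
    """
    remaining = list(budgets)
    total_reward = 0.0
    for t in range(1, T + 1):
        x = select_arm(t, remaining)
        reward, costs = arms[x]()  # one sample from pi_x
        total_reward += reward
        for i, c in enumerate(costs):
            remaining[i] -= c
        if any(b <= 0 for b in remaining):
            break  # some budget B_i just ran out
    return total_reward, remaining
```

<p>Note that the strategy gets to see the remaining budgets each round, which is exactly the signal the primal-dual approach below exploits.</p>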

<h2 id="algorithms-presented">Algorithms Presented</h2>

<p>The paper presents two algorithms - <code class="language-plaintext highlighter-rouge">BalancedExploration</code> and <code class="language-plaintext highlighter-rouge">PrimalDualWithBK</code>.</p>

<h3 id="balancedexploration"><code class="language-plaintext highlighter-rouge">BalancedExploration</code></h3>
<p>At a high level, <code class="language-plaintext highlighter-rouge">BalancedExploration</code> tries to explore as much as possible while avoiding suboptimal strategies. It tries to converge to an LP-perfect distribution. LP-perfect here refers to an LP relaxation called LP-primal (LP = linear programming). So basically, if they can reduce some of the constraints in this LP-primal approach, then they can have an optimal algorithm. This LP-primal not only reduces the constraints; it also assumes that we know the average reward for each arm (removing the uncertainty), and it lets us perform fractional tasks rather than full tasks (this gets into the integer programming formulation, which is helpful for the second part).</p>

<p>The algorithm is “simple,” as the authors put it, but somewhat abstracted. In each phase, it eliminates any mix of tasks that is obviously not LP-perfect. It maintains a confidence interval of potentially LP-perfect distributions.</p>

<p>Then for each task, it tries to explore that task as much as possible, and gathers the information. It then repeats until it runs out of time or resources.</p>

<p>Transparently, I get it at this level, but I don’t understand the underlying math pinning it. That confidence interval calculation is… unclear to me, and I don’t even have an implementation for it in my repo (which is the point of this post).</p>

<p>Ah, actually! After giving Claude enough context and framing for this, it does make sense for my repo. It’s still using UCB / LCB for reward and cost respectively, and then forming that as the score, i.e.:</p>

<details>
  <summary style="padding: 10px; border-radius: 5px; cursor: pointer; color: #D77656; font-weight: bold; border: 1px solid rgba(215, 118, 86, 0.4);">
    <svg xmlns="http://www.w3.org/2000/svg" width="20" height="17" viewBox="12 22 96 70" style="vertical-align: middle; margin-right: 8px;">
      <path d="M0 0 C23.76 0 47.52 0 72 0 C72 9.24 72 18.48 72 28 C75.96 28 79.92 28 84 28 C84 32.62 84 37.24 84 42 C80.04 42 76.08 42 72 42 C72 46.62 72 51.24 72 56 C70.02 56 68.04 56 66 56 C66 60.62 66 65.24 66 70 C64.02 70 62.04 70 60 70 C60 65.38 60 60.76 60 56 C58.02 56 56.04 56 54 56 C54 60.62 54 65.24 54 70 C52.02 70 50.04 70 48 70 C48 65.38 48 60.76 48 56 C40.08 56 32.16 56 24 56 C24 60.62 24 65.24 24 70 C22.02 70 20.04 70 18 70 C18 65.38 18 60.76 18 56 C16.02 56 14.04 56 12 56 C12 60.62 12 65.24 12 70 C10.02 70 8.04 70 6 70 C6 65.38 6 60.76 6 56 C4.02 56 2.04 56 0 56 C0 51.38 0 46.76 0 42 C-3.96 42 -7.92 42 -12 42 C-12 37.38 -12 32.76 -12 28 C-8.04 28 -4.08 28 0 28 C0 18.76 0 9.52 0 0 Z" fill="#D77656" transform="translate(24,22)" />
      <path d="M0 0 C1.98 0 3.96 0 6 0 C6 4.62 6 9.24 6 14 C4.02 14 2.04 14 0 14 C0 9.38 0 4.76 0 0 Z" fill="#000000" transform="translate(78,36)" />
      <path d="M0 0 C1.98 0 3.96 0 6 0 C6 4.62 6 9.24 6 14 C4.02 14 2.04 14 0 14 C0 9.38 0 4.76 0 0 Z" fill="#000000" transform="translate(36,36)" />
    </svg>
    Claude Code: BalancedExploration Explanation
  </summary>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>

 ▐▛███▜▌   Claude Code v2.1.17
▝▜█████▛▘  Opus 4.5 · Claude Max
  ▘▘ ▝▝    ~/Documents/coding/multi-armed-bandit

❯ /clear
  ⎿  (no content)

❯ Can you explain how the `BalancedExploration` BwK algorithm could be applied here?

⏺ The Balanced Exploration algorithm from the Bandits with Knapsacks (BwK) literature is a natural fit for your load balancer because you
  have a resource constraint built into the problem: the retry penalty system.

  BwK Problem Mapping
  ┌──────────────────────┬─────────────────────────────────────────────────────────────┐
  │     BwK Concept      │                     Your Load Balancer                      │
  ├──────────────────────┼─────────────────────────────────────────────────────────────┤
  │ Arms                 │ Downstream servers (T1/T2/T3 tiers)                         │
  ├──────────────────────┼─────────────────────────────────────────────────────────────┤
  │ Reward               │ Successful request (+1)                                     │
  ├──────────────────────┼─────────────────────────────────────────────────────────────┤
  │ Resource consumption │ Retries needed (penalty after 3 attempts)                   │
  ├──────────────────────┼─────────────────────────────────────────────────────────────┤
  │ Budget               │ Implicit via score formula: success - 0.5 * penalty_retries │
  └──────────────────────┴─────────────────────────────────────────────────────────────┘
  How Balanced Exploration Works

  Standard MAB algorithms (like your Thompson/UCB strategies) only track reward (success rate). BwK algorithms track both:

  1. Reward distribution: P(success | server)
  2. Cost distribution: E[retries needed | server]

  The key insight: a server with 80% success rate that usually fails on first attempt but succeeds on retry might be worse than a 70%
  success server that either works immediately or fails completely.

  Algorithm Sketch

  class BalancedExplorationStrategy(BaseStrategy):
      def __init__(self):
          # Track both success AND retry cost per server
          self.reward_estimates = {}  # E[success]
          self.cost_estimates = {}    # E[attempts needed]
          self.confidence_bounds = {}

      def select_server(self, servers, attempt_number):
          # Compute "bang per buck" with confidence bounds
          for server in servers:
              reward_ucb = self.reward_estimates[server] + confidence_bonus
              cost_lcb = max(self.cost_estimates[server] - confidence_bonus, epsilon)

              # Key BwK idea: optimize reward/cost ratio
              server.score = reward_ucb / cost_lcb

          # Balanced exploration: mix greedy + uniform exploration
          if should_explore():  # Based on remaining budget
              return uniform_random(servers)
          return argmax(server.score for server in servers)

</code></pre></div></div>
</details>

<h3 id="primaldualwithbk">PrimalDualWithBK</h3>

<p>The difference here is that we basically have a shadow parameter per resource that is the <em>dual vector</em> portion of this. It represents unit costs for resources and then helps balance resource consumption.</p>

<p>This $\lambda$ is our internal or shadow price that indicates how much we’re willing to pay for a given resource.</p>

<p>Primal-dual is different because it’s adaptive to our remaining budget. If we’re burning through resources too quickly, $\lambda$ rises and we do our best to avoid costly servers.</p>

<details>
  <summary style="padding: 10px; border-radius: 5px; cursor: pointer; color: #D77656; font-weight: bold; border: 1px solid rgba(215, 118, 86, 0.4);">
    <svg xmlns="http://www.w3.org/2000/svg" width="20" height="17" viewBox="12 22 96 70" style="vertical-align: middle; margin-right: 8px;">
      <path d="M0 0 C23.76 0 47.52 0 72 0 C72 9.24 72 18.48 72 28 C75.96 28 79.92 28 84 28 C84 32.62 84 37.24 84 42 C80.04 42 76.08 42 72 42 C72 46.62 72 51.24 72 56 C70.02 56 68.04 56 66 56 C66 60.62 66 65.24 66 70 C64.02 70 62.04 70 60 70 C60 65.38 60 60.76 60 56 C58.02 56 56.04 56 54 56 C54 60.62 54 65.24 54 70 C52.02 70 50.04 70 48 70 C48 65.38 48 60.76 48 56 C40.08 56 32.16 56 24 56 C24 60.62 24 65.24 24 70 C22.02 70 20.04 70 18 70 C18 65.38 18 60.76 18 56 C16.02 56 14.04 56 12 56 C12 60.62 12 65.24 12 70 C10.02 70 8.04 70 6 70 C6 65.38 6 60.76 6 56 C4.02 56 2.04 56 0 56 C0 51.38 0 46.76 0 42 C-3.96 42 -7.92 42 -12 42 C-12 37.38 -12 32.76 -12 28 C-8.04 28 -4.08 28 0 28 C0 18.76 0 9.52 0 0 Z" fill="#D77656" transform="translate(24,22)" />
      <path d="M0 0 C1.98 0 3.96 0 6 0 C6 4.62 6 9.24 6 14 C4.02 14 2.04 14 0 14 C0 9.38 0 4.76 0 0 Z" fill="#000000" transform="translate(78,36)" />
      <path d="M0 0 C1.98 0 3.96 0 6 0 C6 4.62 6 9.24 6 14 C4.02 14 2.04 14 0 14 C0 9.38 0 4.76 0 0 Z" fill="#000000" transform="translate(36,36)" />
    </svg>
    Claude Code: PrimalDual Explanation
  </summary>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
⏺ Great question—these are two fundamentally different algorithmic paradigms for BwK.

  Core Philosophical Difference
  ┌────────────────────────────────────────────┬───────────────────────────────────────────────────┐
  │            Balanced Exploration            │                  Primal-Dual BwK                  │
  ├────────────────────────────────────────────┼───────────────────────────────────────────────────┤
  │ Ratio-based: Optimize reward/cost directly │ Price-based: Learn a "shadow price" for resources │
  ├────────────────────────────────────────────┼───────────────────────────────────────────────────┤
  │ Explore to reduce uncertainty              │ Adapt prices based on consumption rate            │
  ├────────────────────────────────────────────┼───────────────────────────────────────────────────┤
  │ UCB-style confidence bounds                │ Lagrangian relaxation of constraints              │
  └────────────────────────────────────────────┴───────────────────────────────────────────────────┘
  Primal-Dual Intuition

  Instead of optimizing a ratio, Primal-Dual converts the constrained problem into an unconstrained one:

  Original: maximize E[reward] subject to E[cost] ≤ budget

  Primal-Dual: maximize E[reward] - λ * E[cost]
               where λ (dual variable) is learned online

  The dual variable λ is the shadow price—how much you're willing to "pay" to consume one unit of resource.

  Algorithm Sketch

  class PrimalDualBwKStrategy(BaseStrategy):
      def __init__(self):
          self.lambda_price = 0.0  # Shadow price for retries
          self.total_penalty_retries = 0
          self.total_requests = 0

      def select_server(self, servers, attempt_number):
          # Pick server maximizing: E[reward] - λ * E[cost]
          for server in servers:
              reward = self.success_estimate[server]
              cost = self.retry_cost_estimate[server]
              server.score = reward - self.lambda_price * cost

          return argmax(server.score)

      def update(self, server_url, success, attempts_used):
          # Update reward/cost estimates (standard)
          self._update_estimates(server_url, success, attempts_used)

          # KEY DIFFERENCE: Update the shadow price
          self.total_penalty_retries += max(0, attempts_used - 3)
          self.total_requests += 1

          # If consuming resources too fast, increase price
          target_rate = 0.5  # Expected penalty retries per request
          actual_rate = self.total_penalty_retries / self.total_requests

          # Multiplicative weights update
          self.lambda_price *= (1 + eta * (actual_rate - target_rate))

</code></pre></div></div>
</details>

<h1 id="take-home-multi-arm-bandit-variant">Take Home Multi-Arm Bandit Variant</h1>
<p>The take-home I received had an interesting twist on this. The twist: <strong>you are only penalized for a failing server request after $k$ tries.</strong> So you are still trying to maximize your “score” (i.e. reward), but you’re also given some leeway.</p>
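<p>To make the twist concrete, here is a hedged sketch of what a per-request score looks like under that rule. The function name and the unit penalty weight are my own assumptions (with the free-retry allowance defaulting to $k=3$), not the take-home’s actual spec:</p>

```python
def request_score(success: bool, attempts_used: int, k: int = 3,
                  reward: float = 1.0, penalty_per_retry: float = 1.0) -> float:
    """Reward a success; only charge for retries beyond the free allowance k."""
    penalized_retries = max(0, attempts_used - k)
    return (reward if success else 0.0) - penalty_per_retry * penalized_retries
```

<p>So a success on the second attempt costs nothing extra, while a success on the fifth attempt still pays for two penalized retries.</p>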

<p>It was not until deep research with Claude / ChatGPT that I learned the problem could (I think) best be framed as a <a href="#bandit-with-knapsack-bwk-variant"><strong>BwK</strong></a> problem.</p>

<h1 id="flaky-server---bwk-framing">Flaky Server - BwK Framing</h1>

<p>For more context on the take-home, the MAB portion was framed as building a load balancer where the downstream servers are flaky and you’re trying to minimize penalties (which are assessed after your load balancer request fails). They simply sent a binary (which I actually dislike; I think it’s very sketchy to send a binary with no details, certs, signatures, notarization, etc.). The binary opened up the following:</p>
<ul>
  <li>10 ports with Config 1 (constant error rate)</li>
  <li>10 ports with Config 2 (constant error rate + constant rate limit)</li>
  <li>10 ports with Config 3 (constant error rate + complex rate limit)</li>
</ul>

<h2 id="approach">Approach</h2>

<h3 id="aggressively-inspecting-the-binary">Aggressively Inspecting the Binary</h3>
<p>No fucking way am I blindly running a binary on my personal computer.</p>

<p>I am familiar with some basic CLI tools for inspecting binaries (<code class="language-plaintext highlighter-rouge">otool</code>, <code class="language-plaintext highlighter-rouge">strings</code>, <code class="language-plaintext highlighter-rouge">xattr</code> from the Dropbox days). However, this was something that I freely threw at Claude, with explicit instructions not to run the binary and not to tell me anything about the underlying load balancer config implementations (I’ll get to the de-compilation step in a bit).</p>

<p>I also knew that for all commands actually starting the load balancer binary, we would be running them in a restricted mode using <a href="https://igorstechnoclub.com/sandbox-exec/">sandbox-exec</a>, which I hadn’t stumbled upon until this project. The blog I just linked does a fantastic job, so you should feel comfortable giving it some site traffic and peeking into that one. TL;DR: it’s a way to run a binary in a sandboxed environment so that it only has access to the resources you permit.</p>

<p>All of this looked good, so I was onto the actual implementation.</p>

<h3 id="load-balancer">Load Balancer</h3>

<p>This was obviously the meat of the problem and the most fun to reason and think about, probably because it was the most math / stats intensive. I wrote a couple of versions myself, tried them and saw the failures (Claude found bugs with how I was calculating the beta distribution’s variance, for example), and kept iterating. It’s the part of the code I know the best, and I can walk through the various implementations.</p>

<p>The later versions where we get into the BwK approaches (<code class="language-plaintext highlighter-rouge">v6</code> - <code class="language-plaintext highlighter-rouge">v8</code>) are implementations by Claude, but still interesting to see how they perform relative to the original ideas.</p>

<p>At this point, I’m pretty burnt out on this project and I’m technically on vacation, so I am going to summarize and leave it as an exercise to the reader to investigate the code and understand the underlying logic.</p>

<p><strong>These versions are all basic MAB approaches, not BwK specific.</strong></p>

<table>
  <thead>
    <tr>
      <th style="text-align: center">Code Version</th>
      <th style="text-align: center">Method</th>
      <th style="text-align: center">Description</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td style="text-align: center">V1</td>
      <td style="text-align: center">Larkin Intuition</td>
      <td style="text-align: center">We still model things as a Beta distribution.<br /> We have a <code class="language-plaintext highlighter-rouge">DISCOVER_LIMIT</code>. While we’re in <code class="language-plaintext highlighter-rouge">DISCOVER_MODE</code>, we select the arm / server with highest beta variance, and fire off attempts to that server. If that fails, we re-evaluate. We continue until we fail. After the discover limit, then we statically pick the best server to send requests to.</td>
    </tr>
    <tr>
      <td style="text-align: center">V2</td>
      <td style="text-align: center">Vanilla UCB</td>
      <td style="text-align: center">This is the UCB method described above. We first prioritize any untried servers (since technically they have an infinite UCB score). Then for each server, we calculate the UCB score using the formula:<br /> \(UCB = \text{success_rate} + \sqrt{\frac{2 \ln(\text{total_requests})}{\text{num_attempts}}}\)</td>
    </tr>
    <tr>
      <td style="text-align: center">V3</td>
      <td style="text-align: center">Adjusted UCB</td>
      <td style="text-align: center">Very similar to the above, however this time we play games with our exploration constant. It’s no longer $\sqrt{2}$; it’s 3 (chosen arbitrarily, just bigger than $\sqrt{2}$) for the first three attempts, and then 1 after that, once we’re starting to get penalized.</td>
    </tr>
    <tr>
      <td style="text-align: center">V4</td>
      <td style="text-align: center">Vanilla Thompson Sampling</td>
      <td style="text-align: center">What we described above, we pick the server with the highest $p$ and then we go from there. Either way if it’s a success or a failure, we update our $\alpha$ and $\beta$.</td>
    </tr>
    <tr>
      <td style="text-align: center">V5</td>
      <td style="text-align: center">Modified Thompson Sampling</td>
      <td style="text-align: center">In a somewhat similar game to the modified UCB, we scale alpha and beta based on the number of requests to encourage exploration. We use an exponential decay and if we’re at 3 attempts or more, we do not scale at all and just revert back to normal TS. Our <code class="language-plaintext highlighter-rouge">scale_factor</code> then becomes <code class="language-plaintext highlighter-rouge">max(2, total/variance_scale) / total</code> where <code class="language-plaintext highlighter-rouge">total = alpha + beta</code>. We then multiply $\alpha$ and $\beta$ by those coefficients.</td>
    </tr>
  </tbody>
</table>
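<p>For a concrete feel of how V2 and V4 actually pick a server, here is a minimal sketch of both scoring rules (illustrative only; the names and numbers are made up, not the repo’s code):</p>

```python
import math
import random

def ucb_score(successes: int, attempts: int, total_requests: int) -> float:
    """V2-style UCB1: empirical success rate plus an exploration bonus."""
    if attempts == 0:
        return float("inf")  # untried servers get explored first
    return successes / attempts + math.sqrt(2 * math.log(total_requests) / attempts)

def thompson_sample(alpha: float, beta: float) -> float:
    """V4-style Thompson sampling: draw a plausible success rate from Beta(alpha, beta)."""
    return random.betavariate(alpha, beta)

# Pick the arm with the best UCB score.
stats = {"s1": (8, 10), "s2": (1, 5)}  # (successes, attempts)
total = sum(a for _, a in stats.values())
best_ucb = max(stats, key=lambda s: ucb_score(stats[s][0], stats[s][1], total))
```

<p>V2 takes the argmax of the UCB scores deterministically, while V4 re-draws from each server’s Beta posterior on every request, so its choices are stochastic.</p>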

<p><strong>These approaches, in honesty, were CC-generated, but they are rate-limit aware and targeted at BwK-style approaches.</strong></p>

<table>
  <thead>
    <tr>
      <th style="text-align: center">Code Version</th>
      <th style="text-align: center">Method</th>
      <th style="text-align: center">Description</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td style="text-align: center">V6</td>
      <td style="text-align: center">Thompson Masked</td>
      <td style="text-align: center">A slight departure from the original Thompson Sampling. Here, <code class="language-plaintext highlighter-rouge">429</code>s indicate that we have been rate limited, and we exclude rate-limited servers from the selection pool. Note, we also mark a server as rate-limited if we’ve gotten a 429 in the past second. The key idea is that 429s are treated differently than failures: we do not update $\beta$ when we get one, we instead just mark the server as rate limited. If all of our servers are rate limited, we pick the server whose rate limit is most likely to expire soonest. This is probably best for Config Type T2.</td>
    </tr>
    <tr>
      <td style="text-align: center">V7</td>
      <td style="text-align: center">Sliding Window</td>
      <td style="text-align: center">Here, given that we have the notion of temporal and dynamic rate limiting, we only remember a set amount of request history. I chose 30 basically arbitrarily. Again, perhaps ideally we could learn the rate limits and dynamically adapt this. Our $\alpha$ and $\beta$ params are only updated based on that window of history.</td>
    </tr>
    <tr>
      <td style="text-align: center">V8</td>
      <td style="text-align: center">Blocking Bandit</td>
      <td style="text-align: center">And here is the adaptive cooldown / blocking that <code class="language-plaintext highlighter-rouge">V7</code> was lacking. The difference is now if we hit a 429 we start to exponentially increase the wait time to block the incoming requests from going to a server that we know is rate-limited.</td>
    </tr>
  </tbody>
</table>
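<p>The V8 blocking behavior can be sketched roughly as follows (my illustrative sketch with an assumed doubling schedule and cap, not the repo’s actual V8 implementation):</p>

```python
class CooldownTracker:
    """Per-server exponential cooldown after 429s: each repeat doubles the wait, up to a cap."""

    def __init__(self, base: float = 1.0, cap: float = 60.0):
        self.base = base
        self.cap = cap
        self.blocked_until: dict[str, float] = {}  # server -> timestamp when it becomes usable
        self.strikes: dict[str, int] = {}          # consecutive 429s per server

    def record_429(self, server: str, now: float) -> None:
        strikes = self.strikes.get(server, 0) + 1
        self.strikes[server] = strikes
        wait = min(self.cap, self.base * (2 ** (strikes - 1)))  # 1s, 2s, 4s, ... capped
        self.blocked_until[server] = now + wait

    def record_success(self, server: str) -> None:
        # A successful request clears the strike count and any block.
        self.strikes.pop(server, None)
        self.blocked_until.pop(server, None)

    def available(self, servers, now: float):
        return [s for s in servers if self.blocked_until.get(s, 0.0) <= now]

tracker = CooldownTracker(base=1.0, cap=60.0)
tracker.record_429("a", now=0.0)  # first 429: "a" is blocked for 1 second
```

<p>The first 429 blocks a server for <code class="language-plaintext highlighter-rouge">base</code> seconds, each repeat doubles the wait up to <code class="language-plaintext highlighter-rouge">cap</code>, and a success resets the server’s strikes.</p>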

<h3 id="simulation-harness">Simulation Harness</h3>

<p>The simulation harness is almost entirely vibe-coded but basically sends requests to our load balancer at the prescribed rate of 10 RPS. For more information, I would check the <code class="language-plaintext highlighter-rouge">flaky-load-balancer/flaky_load_balancer/harness.py</code> file out. It’s on GH <a href="https://github.com/johnlarkin1/multi-armed-bandit/blob/main/flaky-load-balancer/flaky_load_balancer/harness.py">here</a>.</p>

<h3 id="dashboard">Dashboard</h3>

<p>The dashboard was a fun vibe-coded NextJS app. There’s a decent amount of functionality here, so I’ll cover some of the highlights. This NextJS project is meant to summarize and compare the results from the various strategies (<code class="language-plaintext highlighter-rouge">V1</code>-<code class="language-plaintext highlighter-rouge">V8</code>) against the various config types (<code class="language-plaintext highlighter-rouge">T1</code>-<code class="language-plaintext highlighter-rouge">T3</code>). It also has a comparison route that compares all of them for a given run.</p>

<p>It connects and listens to the FastAPI server (basically to our load balancer) so that we get SSE streams for things like the heartbeat, metrics, and connected state. So what I would suggest is running <code class="language-plaintext highlighter-rouge">make harness</code> and that will start your FastAPI load balancer, start the dashboard, start the downstream <code class="language-plaintext highlighter-rouge">flakyservers</code> binary, and then start firing off requests.</p>

<p>Here is a demo:</p>

<div class="video-container">
  <div class="video-wrapper-dark">
    <video src="/videos/multi-armed-bandit/multi-armed-bandit.mp4" type="video/mp4" muted="" autoplay="" loop="" controls="" style="width: 100%; height: auto;">
    </video>
  </div>
</div>

<p>And furthermore, here are some screenshots from the comparison page:</p>

<p><img src="/images/multi-armed-bandit/comparison.png" alt="compare" class="center lightbox-image" /></p>

<p><img src="/images/multi-armed-bandit/compare-viz.png" alt="compare-viz" class="center lightbox-image" /></p>

<h1 id="conclusion">Conclusion</h1>

<p>So! What were the results?</p>

<p><img src="/images/multi-armed-bandit/results.png" alt="results" class="center lightbox-image" /></p>

<p>Unsurprisingly, our Modified Thompson Sampling seemed to do the best on <code class="language-plaintext highlighter-rouge">T1</code>; somewhat surprisingly, the Sliding Window did the best on <code class="language-plaintext highlighter-rouge">T2</code> (probably because the underlying binary’s rate limit is sinusoidal, and there was some benefit to the cadence and the window size being used). Finally, for <code class="language-plaintext highlighter-rouge">T3</code>, the Blocking Bandit and Thompson Masked seemed to do the best.</p>

<hr />

<p>There’s a lot more I could talk about here, but this has already spilled over on the time budgeting so I will end here. If interested, feel free to reach out!</p>]]></content><author><name>johnlarkin1</name></author><category term="Development" /><category term="AI" /><summary type="html"><![CDATA[&lt;!– This was meant to be for a take-home... I submitted some version of the first draft, but then couldn't stop and here we are. If parts trail off, it's because I shouldn't have even gone this deep into other more pressing matter]]></summary></entry><entry><title type="html">iMessage Data Foundry</title><link href="https://johnlarkin1.github.io/2026/imessage-data-foundry/" rel="alternate" type="text/html" title="iMessage Data Foundry" /><published>2026-01-18T00:00:00+00:00</published><updated>2026-01-18T00:00:00+00:00</updated><id>https://johnlarkin1.github.io/2026/imessage-data-foundry</id><content type="html" xml:base="https://johnlarkin1.github.io/2026/imessage-data-foundry/"><![CDATA[<div class="project-registry">
  <a href="https://github.com/johnlarkin1/imessage-data-foundry" target="_blank" rel="noopener" class="registry-card github">
    <span class="lang-icon">🐙</span>
    <span class="lang-badge">Source</span>
    <span class="registry-name">GitHub</span>
  </a>
  <a href="https://pypi.org/project/imessage-data-foundry/" target="_blank" rel="noopener" class="registry-card python">
    <span class="lang-icon">🐍</span>
    <span class="lang-badge">Python</span>
    <span class="registry-name">PyPI</span>
  </a>
</div>

<h1 id="context">Context</h1>

<p>Recently, for one of my projects, I needed synthetic data generated to match a macOS-compatible iMessage <code class="language-plaintext highlighter-rouge">chat.db</code> as well as the <code class="language-plaintext highlighter-rouge">AddressBook.db</code>.</p>

<p>This is going to be a short post, because the <a href="https://github.com/johnlarkin1/imessage-data-foundry">GitHub repo’s</a> <code class="language-plaintext highlighter-rouge">README.md</code> has a lot more information. So check that out.</p>

<p>Alternatively, feel free to watch this demo video:</p>

<div class="video-container">
  <div class="video-wrapper-dark">
    <iframe src="https://www.youtube.com/embed/t_6QnWvlkCI" title="iMessage Data Foundry Demo" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" allowfullscreen="">
    </iframe>
  </div>
</div>

<h1 id="installation">Installation</h1>

<p>There are many ways to install it:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>$ uv tool install imessage-data-foundry
$ uvx imessage-data-foundry
$ pip install imessage-data-foundry
$ pipx install imessage-data-foundry
</code></pre></div></div>

<h1 id="conclusion">Conclusion</h1>

<p>Thanks! Feel free to check out the GH repo or reach out if there are any questions / concerns. Also it’s open source so feel free to submit issues / PRs.</p>]]></content><author><name>johnlarkin1</name></author><category term="Development" /><category term="AI" /><summary type="html"><![CDATA[🐙 Source GitHub 🐍 Python PyPI]]></summary></entry><entry><title type="html">larkin-mcp</title><link href="https://johnlarkin1.github.io/2025/larkin-mcp/" rel="alternate" type="text/html" title="larkin-mcp" /><published>2025-12-14T00:00:00+00:00</published><updated>2025-12-14T00:00:00+00:00</updated><id>https://johnlarkin1.github.io/2025/larkin-mcp</id><content type="html" xml:base="https://johnlarkin1.github.io/2025/larkin-mcp/"><![CDATA[<p><img src="/images/larkin-mcp/hero.png" alt="larkin-mcp" class="center-shrink lightbox-image" /></p>

<div class="image-caption">You can either interact with mine, or clone the repo <a href="https://github.com/johnlarkin1/yourname-mcp">here</a>, and get started with the second one.</div>
<p><br /></p>

<div class="project-registry">
  <a href="https://pypi.org/project/larkin-mcp/" target="_blank" rel="noopener" class="registry-card python">
    <span class="lang-icon">🐍</span>
    <span class="lang-badge">Python</span>
    <span class="registry-name">PyPI</span>
  </a>
  <a href="https://www.npmjs.com/package/@johnlarkin1/larkin-mcp" target="_blank" rel="noopener" class="registry-card typescript">
    <span class="lang-icon">📘</span>
    <span class="lang-badge">TypeScript</span>
    <span class="registry-name">npm</span>
  </a>
  <a href="https://crates.io/crates/larkin-mcp" target="_blank" rel="noopener" class="registry-card rust">
    <span class="lang-icon">🦀</span>
    <span class="lang-badge">Rust</span>
    <span class="registry-name">crates.io</span>
  </a>
</div>

<div class="image-caption">Check out any of the links above for the various published packages. Note, Claude did the css here.</div>
<p><br /></p>
<div class="template-card-wrapper">
  <a href="https://github.com/johnlarkin1/yourname-mcp" target="_blank" rel="noopener" class="template-card">
    <span class="template-icon">📋</span>
    <span class="template-text">yourname-mcp template</span>
    <span class="template-arrow">→</span>
  </a>
</div>

<p><br /></p>

<p>I’m working on a much bigger project, but honestly, needed to take a break from that. It has been a grind. I have burned many early mornings on that.</p>

<p>So as a break, I have wanted to explore building my own MCP server and templatizing this to make it easier for others to install and set this up as well. This is not going to be a long post, but I’m hoping the repos speak for themselves, and this provides ample motivation.</p>

<!--
# Table of Contents

- [Table of Contents](#table-of-contents)
- [Motivation](#motivation)
  - [Personal Insights](#personal-insights)
  - [Interactive Timeline](#interactive-timeline)
    - [Example 1:](#example-1)
    - [Example 2:](#example-2)
  - [Personalized Study Guide](#personalized-study-guide)
- [Context](#context)
- [Why?](#why)
- [`yourname-mcp`](#yourname-mcp)
  - [Demo](#demo)
  - [Security](#security)
  - [Rust](#rust)
- [Conclusion](#conclusion)
-->

<h1 id="motivation">Motivation</h1>

<p>To provide some motivation (and perhaps earn a few stars on the template repo), here are practical examples of what you can do with this specific MCP server.</p>

<h2 id="personal-insights">Personal Insights</h2>

<blockquote>
  <p>What do you think was John Larkin’s hardest tennis match?</p>
</blockquote>

<p><strong>Result:</strong></p>

<p><img src="/images/larkin-mcp/hardest-match.png" alt="larkin-mcp" class="center-shrink lightbox-image" /></p>

<p><strong>Rude</strong>!! Hallucination. I didn’t get <em>bageled</em>, I got <em>breadsticked</em>. In other words, it was 1-6 not 0-6. But yes, shoutout to Phillip Locklear…</p>

<h2 id="interactive-timeline">Interactive Timeline</h2>

<h3 id="example-1">Example 1:</h3>

<p><strong>Prompt:</strong></p>

<blockquote>
  <p>Can you give me John’s experience’s as a beautiful timeline? Please create a html file with that visualization</p>
</blockquote>

<p><strong>Result:</strong> <a href="/assets/html/larkin-mcp/john-larkin-profile.html" target="_blank" rel="noopener">View the timeline <svg class="external-link-icon" width="12" height="12" viewBox="0 0 12 12" fill="none" xmlns="http://www.w3.org/2000/svg" style="display:inline-block;vertical-align:middle;margin-left:2px;"><path d="M10.5 1.5L5.5 6.5M10.5 1.5H7M10.5 1.5V5M10.5 7V10C10.5 10.2761 10.2761 10.5 10 10.5H2C1.72386 10.5 1.5 10.2761 1.5 10V2C1.5 1.72386 1.72386 1.5 2 1.5H5" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round"></path></svg></a></p>

<div class="mcp-demo-iframe-wrapper">
  <iframe class="mcp-demo-iframe-container" src="/assets/html/larkin-mcp/john-larkin-profile.html" title="John Larkin Timeline" width="1120" height="630" allowfullscreen="">
  </iframe>
</div>
<p><br /></p>

<h3 id="example-2">Example 2:</h3>

<blockquote>
  <p>Can you use your frontend-design skill and build a beautiful interactive timeline of John’s work experience and personal project timeline as a single html file visualization?</p>
</blockquote>

<p><strong>Result:</strong> <a href="/assets/html/larkin-mcp/john-larkin-timeline.html" target="_blank" rel="noopener">View the timeline <svg class="external-link-icon" width="12" height="12" viewBox="0 0 12 12" fill="none" xmlns="http://www.w3.org/2000/svg" style="display:inline-block;vertical-align:middle;margin-left:2px;"><path d="M10.5 1.5L5.5 6.5M10.5 1.5H7M10.5 1.5V5M10.5 7V10C10.5 10.2761 10.2761 10.5 10 10.5H2C1.72386 10.5 1.5 10.2761 1.5 10V2C1.5 1.72386 1.72386 1.5 2 1.5H5" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round"></path></svg></a></p>

<div class="mcp-demo-iframe-wrapper">
  <iframe class="mcp-demo-iframe-container" src="/assets/html/larkin-mcp/john-larkin-timeline.html" title="John Larkin Experience Timeline" width="1120" height="630" allowfullscreen="">
  </iframe>
</div>
<p><br /></p>

<p>Honestly, the second one is pretty slick, although it’s a bit… devoid of personality, I guess.</p>

<p>fwiw, here is the usage in CC:</p>

<p><img src="/images/larkin-mcp/claude-code-example.png" alt="larkin-mcp" class="center-super-shrink lightbox-image" /></p>

<h2 id="personalized-study-guide">Personalized Study Guide</h2>

<p><strong>Prompt:</strong></p>

<blockquote>
  <p>Can you help John Larkin prepare for an Anthropic interview given his resume and past experience? Please search and find open roles and then prepare a study guide for his various gaps.</p>
</blockquote>

<p><strong>Result:</strong></p>

<p>Not sharing the whole thing, but you can see this from Claude Desktop:</p>

<p><img src="/images/larkin-mcp/claude-desktop-example.png" alt="larkin-mcp" class="center-small lightbox-image" /></p>

<h1 id="context">Context</h1>

<p>I wanted to set up a local MCP server that you can install to ask questions about the user. There are two versions:</p>

<ul>
  <li><code class="language-plaintext highlighter-rouge">larkin-mcp</code> - my materialized repo that has details about myself (largely professional, markdown files are online, but I’m guessing in the age of the internet, this level of detail is fine).</li>
  <li><code class="language-plaintext highlighter-rouge">yourname-mcp</code> - the templated repo where you can clone this, run a script, and optionally publish (caution: the info that you put in your <code class="language-plaintext highlighter-rouge">resources/content</code> markdown files will then be indexable / probably ingested by some AI… but my theory is that most of that stuff is already going to be out there)</li>
</ul>

<h1 id="why">Why?</h1>

<p>Yeah, so this was something my PM girlfriend asked me almost immediately. Why do this? Can’t you just feed your resume into ChatGPT and it’ll basically be able to do the same? I think yes, partially, but (at least in my case) my resume is still missing a ton of context. So I think my response is multi-fold:</p>

<blockquote>
  <p>Can’t you just feed your resume into ChatGPT and ask questions of that?</p>
</blockquote>

<ol>
  <li>Feeding in your resume as a <code class="language-plaintext highlighter-rouge">pdf</code> or <code class="language-plaintext highlighter-rouge">md</code> file is going to bloat your context window. MCP provides more selective invocations.</li>
  <li>I don’t want to do that every time I need something with my context and personality</li>
  <li>It’s still missing a ton of context about who I am and some more ephemeral things about me. (note: i know that 90% of companies won’t care about that, and 99.9% of recruiters won’t care about it)</li>
  <li>I wanted to be able to distribute this. There’s a world I could imagine where recruiters just run <code class="language-plaintext highlighter-rouge">uvx larkin-mcp</code> and then ask questions to get a feel for my work and who I am</li>
  <li>I want to control the level of detail and insight that this MCP server has</li>
  <li>I wanted to build an MCP server… I hadn’t done it, even at work.</li>
  <li>I wanted to explore the tooling around it as well.</li>
  <li>I wanted to build an MCP server in Typescript and Rust explicitly, given I’m trying to work on my Rust skills and I’m less involved in those communities</li>
  <li>I thought it would be a useful thing to templatize and set up some infrastructure so less technical users could <code class="language-plaintext highlighter-rouge">git clone &lt;repo&gt; &amp;&amp; ./run-install.sh</code> and that would ask them a couple of questions, analyze their resume, convert it into markdown, they could write some markdown to provide more context, and then boom, they could also publish it and others could use it if they wanted.</li>
  <li>As stated previously, I needed a break from my other project.</li>
</ol>

<p>And if you’re thinking like <em>well, what about Claude memory or ChatGPT memory?</em>, I’m really not a fan of that. I don’t think Simon Willison is either. And I don’t trust it to not sycophant it up or pull information that perhaps I don’t want for the questions I’m asking.</p>

<p>Hopefully, that’s enough rationale for personal motivation.</p>

<h1 id="yourname-mcp"><code class="language-plaintext highlighter-rouge">yourname-mcp</code></h1>

<p>This is hopefully your template of interest. The point is that this has enough scaffolding that you can run the install script, populate a couple markdown files, upload to PyPI, and then you’re off and running. There will be more info in the actual repo <a href="https://github.com/johnlarkin1/yourname-mcp">here</a>.</p>

<h2 id="demo">Demo</h2>

<p>Here is a demo showcasing the functionality:</p>

<div class="video-container">
  <div class="video-wrapper-dark">
    <video src="https://www.dropbox.com/scl/fi/v7ljkkxf3p8d24vk8wlpk/yourname-mcp-demo-lg.mp4?rlkey=95ha9lg6gpwngufkdvq9l2t0q&amp;st=dlhyft2u&amp;raw=1" muted="" autoplay="" loop="" controls="" style="width: 100%; height: auto;">
    </video>
  </div>
</div>

<h2 id="security">Security</h2>

<p>I - like basically every other engineer - am slightly cautious about MCP. There are going to be a large number of attacks, given the trust people are placing in MCP and in binary executables (i.e. <code class="language-plaintext highlighter-rouge">bunx</code> or <code class="language-plaintext highlighter-rouge">uvx</code>).</p>

<p>This is from 6 days ago (at time of writing):</p>

<blockquote class="reddit-embed-bq" style="height:316px" data-embed-theme="dark" data-embed-height="316"><a href="https://www.reddit.com/r/MCPservers/comments/1poelh4/is_anyone_else_terrified_by_the_lack_of_security/">Is anyone else terrified by the lack of security in standard MCP?</a><br /> by<a href="https://www.reddit.com/user/RaceInteresting3814/">u/RaceInteresting3814</a> in<a href="https://www.reddit.com/r/MCPservers/">MCPservers</a></blockquote>
<script async="" src="https://embed.reddit.com/widgets.js" charset="UTF-8"></script>

<p><br /></p>

<p>Even with this project… while I utilize <code class="language-plaintext highlighter-rouge">uvx</code> and <code class="language-plaintext highlighter-rouge">bunx</code> for the convenience, I am 100% afraid of impersonations, security attacks, and people injecting malicious code via careless distributors. This is obviously nuanced. I am a huge fan of making software easily disseminated, but the increase in malicious code and actors (which is only exacerbated by the AI wave) is extremely alarming. I mean, just look at npm in the <a href="https://semgrep.dev/blog/2025/chalk-debug-and-color-on-npm-compromised-in-new-supply-chain-attack/">past</a> <a href="https://securitylabs.datadoghq.com/articles/shai-hulud-2.0-npm-worm/">couple</a> <a href="https://www.crowdstrike.com/en-us/blog/crowdstrike-falcon-prevents-npm-package-supply-chain-attacks/?utm_source=chatgpt.com">months</a>?</p>

<h2 id="rust">Rust</h2>

<p>I could have used something like <a href="https://crates.io/crates/cargo-binstall"><code class="language-plaintext highlighter-rouge">cargo-binstall</code></a>, but didn’t quite get to it. As a result, if you want to set this up in Claude Code or Claude Desktop, you’ll need to do something like <code class="language-plaintext highlighter-rouge">cargo install larkin-mcp</code> and then point to that corresponding built binary:</p>

<div class="language-shell highlighter-rouge"><div class="highlight"><pre class="highlight"><code>   Compiling larkin-mcp v1.0.2
    Finished <span class="sb">`</span>release<span class="sb">`</span> profile <span class="o">[</span>optimized] target<span class="o">(</span>s<span class="o">)</span> <span class="k">in </span>14.83s
  Installing /Users/johnlarkin/.cargo/bin/larkin-mcp
   Installed package <span class="sb">`</span>larkin-mcp v1.0.2<span class="sb">`</span> <span class="o">(</span>executable <span class="sb">`</span>larkin-mcp<span class="sb">`</span><span class="o">)</span>
</code></pre></div></div>
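<p>For instance, a minimal Claude Desktop config entry pointing at that built binary might look like the following (the path is the one <code class="language-plaintext highlighter-rouge">cargo install</code> reported above; adjust for your machine):</p>

```json
{
  "mcpServers": {
    "larkin-mcp": {
      "command": "/Users/johnlarkin/.cargo/bin/larkin-mcp"
    }
  }
}
```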

<p>Rust was my favorite to implement, although the code structure is perhaps not as idiomatic Rust as it should be. In my opinion, <code class="language-plaintext highlighter-rouge">rmcp</code>, which is the canonical framework for Rust MCP servers, is slightly less ergonomic. It matches a lot of the Python decorators with Rust macros, but there are some tricks around public traits and understanding what is actually going on underneath the function calls.</p>

<h1 id="conclusion">Conclusion</h1>

<p>If you like this, or think it will be useful, please check out the basically templated repo <code class="language-plaintext highlighter-rouge">yourname-mcp</code> where the <code class="language-plaintext highlighter-rouge">README.md</code> will walk you through what you need to do! Always feel free to email or leave comments if need be.</p>]]></content><author><name>johnlarkin1</name></author><category term="Development" /><category term="AI" /><summary type="html"><![CDATA[]]></summary></entry><entry><title type="html">Understanding Muon</title><link href="https://johnlarkin1.github.io/2025/understanding-muon/" rel="alternate" type="text/html" title="Understanding Muon" /><published>2025-10-28T00:00:00+00:00</published><updated>2025-10-28T00:00:00+00:00</updated><id>https://johnlarkin1.github.io/2025/understanding-muon</id><content type="html" xml:base="https://johnlarkin1.github.io/2025/understanding-muon/"><![CDATA[<div class="markdown-alert markdown-alert-note">
<p>So while I tried to mainly focus on optimizers, this post kinda splayed out some. It was my first time trying <b><a href="https://pyodide.org/en/stable/">Pyodide</a></b> and incorporating that logic into my blog. It was my first time using <b><a href="https://docs.manim.community/en/stable/">manim</a></b>, which was exciting because I'm a big fan of the 3Blue1Brown channel. I also introduced quizzes (see AdamW section) for more interactivity. All of this is open source though, so if you have any questions, I'd be flattered if you emailed, but obviously you can just ask ChatGPT / Claude. 
</p>
</div>

<p><br /></p>

<h1 id="motivating-visualization">Motivating Visualization</h1>

<div class="video-container">
  <div class="video-wrapper-dark">
    <video src="https://www.dropbox.com/scl/fi/399366yvev1jq03cvmu5w/muon-overview.mp4?rlkey=w8sh3t2ucnvboo4l72apzfmdj&amp;st=vvvy6k3s&amp;raw=1" muted="" autoplay="" loop="" controls="" style="width: 100%; height: auto;">
    </video>
  </div>
</div>

<div class="image-caption">Read on to understand the above visualization. My manim skills aren't fantastic so the timing of above could be improved.</div>
<p><br /></p>

<p>Today, we’re going to try and understand as much of this animation as possible. We’ll cover optimizers as a construct, look at an example, take a walk through history (again high level) and then we’ll investigate Muon, which is a more recent optimizer that has been sweeping the community. Note, we will not cover Newton-Schulz iteration or approximation of the SVD calc, but I’m hoping to cover that in another blog post.</p>

<div class="markdown-alert markdown-alert-tip">
<p>Also if you're curious the visualization code (which is a bit of a mess) is <b><a href="https://github.com/johnlarkin1/understanding-muon">here.</a></b></p>
</div>

<h1 id="background">Background</h1>

<p><a href="https://github.com/karpathy/nanochat"><code class="language-plaintext highlighter-rouge">nanochat</code></a> just dropped a couple of weeks ago and one element that I was extremely interested in was <a href="https://kellerjordan.github.io/posts/muon/">muon</a>. It’s a pretty recent state of the art optimizer that has shown competitive performance in training speed challenges.</p>

<p>First of all, if you are not familiar with some of this, you should start with Keller Jordan’s blog that I linked above. He’s the creator of the approach and it’s pretty ingenious. Second of all, if you’re not familiar with linear algebra at all (which is ok), I’d recommend this <a href="https://little-book-of.github.io/linear-algebra/">Little Book of Linear Algebra</a>. I ran through it over the past couple weeks so that I could ensure a strong base / have a refresher for some of the concepts that I haven’t seen since college. You can check out the <a href="https://github.com/johnlarkin1/little-book-of-linalg">Jupyter notebooks here</a>.</p>

<p>This post is going to try and take you as close from $0 \to 1$ as possible (one huge benefit of running through the book + lab linked above is my latex got way better. Not going to help me land a job at Anthropic but c’est la vie).</p>

<!--
# Table of Contents

- [Motivating Visualization](#motivating-visualization)
- [Background](#background)
- [Table of Contents](#table-of-contents)
- [(optional) Reading + Videos](#optional-reading--videos)
- [Deep Learning (simplified)](#deep-learning-simplified)
- [Tour of Popular Optimizers](#tour-of-popular-optimizers)
  - [Loss Function](#loss-function)
    - [Visualization](#visualization)
  - [Stochastic Gradient Descent](#stochastic-gradient-descent)
  - [SGD with Momentum](#sgd-with-momentum)
    - [Computational Cost of Momentum](#computational-cost-of-momentum)
    - [Variations](#variations)
  - [Adaptive Learning Rates (AdaGrad / RMSProp)](#adaptive-learning-rates-adagrad--rmsprop)
    - [AdaGrad (2010)](#adagrad-2010)
      - [Variations](#variations-1)
    - [RMSProp (2012)](#rmsprop-2012)
      - [Variations](#variations-2)
  - [Bias Correction (finally meeting Adam Optimizer, 2015)](#bias-correction-finally-meeting-adam-optimizer-2015)
    - [Comparison so Far](#comparison-so-far)
    - [Plain English](#plain-english)
    - [Viz](#viz)
  - [Weight Decay Coupling (the "W" in AdamW, 2017)](#weight-decay-coupling-the-w-in-adamw-2017)
    - [L2 Regularization](#l2-regularization)
    - [Viz](#viz-1)
- [Muon (MomentUm Orthogonalized by Newton-Schulz) (2025)](#muon-momentum-orthogonalized-by-newton-schulz-2025)
  - [Theory](#theory)
    - [Odd Polynomial Matrix](#odd-polynomial-matrix)
    - [Newton-Schulz Iteration](#newton-schulz-iteration)
    - [Overview](#overview)
  - [Implementation](#implementation)
- [Conclusion](#conclusion)
-->

<h1 id="optional-reading--videos">(optional) Reading + Videos</h1>

<p>These are a couple of helpful resources for you all to get started. I would actually think that if you’re starting from close to scratch or near scratch (haven’t studied AdamW) then you should probably come back to these after my article.</p>

<ul>
  <li>Videos
    <ul>
      <li><a href="https://www.youtube.com/watch?v=bO5nvE289ec"><strong>This Simple Optimizer Is Revolutionizing How We Train AI (Muon)</strong></a> (p.s. god the amount of clickbaiting people do is just suffocating me… however, this is a good video)</li>
    </ul>
  </li>
  <li>Reading
    <ul>
      <li><a href="https://kellerjordan.github.io/posts/muon/"><strong>Muon: An optimizer for hidden layers in neural networks</strong></a> - <em>Keller Jordan</em></li>
      <li><a href="https://jeremybernste.in/writing/deriving-muon"><strong>Deriving Muon</strong></a> - <em>Jeremy Bernstein</em></li>
      <li><a href="https://www.lakernewhouse.com/writing/muon-1"><strong>Understanding Muon</strong></a> - <em>Laker Newhouse</em>
        <ul>
          <li>this series (after doing my own research and investigation) is hilariously written. lots of Matrix allusions</li>
        </ul>
      </li>
    </ul>
  </li>
</ul>

<h1 id="deep-learning-simplified">Deep Learning (simplified)</h1>

<p>I’m not going to take you from the very beginning, but the language of deep learning is basically just… linear algebra.</p>

<p>We have these “deep learning” models that are really neural networks. All that means is that they’re layers of parameters (weights and biases) that take various inputs and make predictions. They normally are <em>affine transformations</em> followed by a (usually) non-linear activation.</p>

<p>Generally, the flow for training in deep learning goes like this:</p>

<ol>
  <li>forward pass (feeding data in)</li>
  <li>loss function (so we know how we did)</li>
  <li>backward pass (so we know how to adapt)</li>
  <li>gradient descent (or flavors thereof… where we actually adjust our weights)</li>
</ol>

<p>There’s fascinating math at all points of this process. However, we’re going to spend the day focusing on step 4 - and specifically on the subset of <strong>optimizers</strong>. Modern optimizers modify gradients using momentum, adaptive learning rates, etc.</p>
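<p>The four steps above can be sketched end-to-end in a few lines of NumPy. This is my own toy example (a single linear layer with an analytic gradient), not code from any real framework:</p>

```python
import numpy as np

# Toy end-to-end version of the four steps: a single linear layer
# y_hat = X @ w trained with MSE loss and plain gradient descent.
rng = np.random.default_rng(0)
X = rng.normal(size=(32, 3))            # batch of inputs
true_w = np.array([1.0, -2.0, 0.5])
y = X @ true_w                          # targets

w = np.zeros(3)                         # parameters (theta)
eta = 0.1                               # learning rate

for _ in range(200):
    y_hat = X @ w                            # 1. forward pass
    loss = np.mean((y_hat - y) ** 2)         # 2. loss function
    grad = 2 * X.T @ (y_hat - y) / len(X)    # 3. backward pass (analytic here)
    w = w - eta * grad                       # 4. gradient descent update

print(loss, w)  # loss shrinks toward 0, w approaches true_w
```

Real models swap step 3 for automatic differentiation and step 4 for one of the optimizers below, but the loop shape stays the same.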

<p>Here is a high level visualization of what’s happening:</p>

<div class="video-container">
  <div class="video-wrapper-dark">
    <video src="https://www.dropbox.com/scl/fi/11h7n3gwa30gmo57yj0zo/ch1-ml-training-process.mp4?rlkey=z29nnmou3ab8zvvj5hliphi25&amp;st=nov43ilo&amp;raw=1" muted="" autoplay="" loop="" controls="" style="width: 100%; height: auto;">
    </video>
  </div>
</div>

<div class="image-caption">Courtesy of me and Claude hammering on manim</div>
<p><br /></p>

<p>Note that $\eta$ here is the learning rate.</p>

<h1 id="tour-of-popular-optimizers">Tour of Popular Optimizers</h1>

<p>Ok, the canonical example with optimizers is that we’re basically trying to find the lowest point in a valley. This assumes our search space is $\mathbb{R}^3$, really, but that’s fine for now.</p>

<p>So like let’s take an actual example with the Grand Canyon. Imagine you’re standing on top of the Grand Canyon - how are you going to find the lowest point in the Grand Canyon?</p>

<p><img src="https://www.jasonweissphotography.com/images/960/grand-canyon-toroweap-sunrise.jpg" alt="top-of-grand-canyon" class="center-shrink lightbox-image" /></p>

<div class="image-caption"><a href="https://www.jasonweissphotography.com/photo/grand-canyon-sunrise-toroweap/">Kudos</a> to Jason Weiss</div>
<p><br /></p>

<p>Now, the optimizer is basically telling us <em>how</em> to walk down that space. It’s obviously a lot easier if we have a topographic map, but we certainly do not in deep learning, and even with the topographic map, it can be tough to search across.</p>

<p><img src="https://databayou.com/grand/images/grandcanyonelevation.webp" alt="grand-canyon-topo" class="center-shrink lightbox-image" /></p>

<div class="image-caption"><a href="https://databayou.com/grand/canyon.html">Kudos</a> to DataByYou</div>
<p><br /></p>

<p>In this analogy, elevation is basically how “wrong” we are. You can think of it as the output of our loss function $L(\hat{y}, y)$. So we compute gradients to determine which direction reduces that loss. However, we still don’t know how big each step would be (the $\eta$ mentioned above) or how to adjust over time or how to avoid getting caught in local minima, etc.</p>

<h2 id="loss-function">Loss Function</h2>

<h3 id="visualization">Visualization</h3>

<p>I don’t have a loss function that is equivalent to the Grand Canyon (sadly), but we are going to look at the <a href="https://www.sfu.ca/~ssurjano/stybtang.html">Styblinski-Tang function</a> as our example loss function. This isn’t going to be accurate, but imagine that the loss surface of our deep learning process lives in just 3D and has a shape we can write down as a function. For a two-dimensional input, the Styblinski-Tang function looks like this:</p>

\[\begin{align}
f(\mathbf{x}) &amp;= \frac{1}{2}\sum_{i=1}^{d} \big(x_i^4 - 16x_i^2 + 5x_i \big) \\
f(x,y) &amp;= \frac{1}{2}\big[ (x^4 - 16x^2 + 5x) + (y^4 - 16y^2 + 5y) \big]
\end{align}\]

<p>Here’s a visualization of this function:</p>

<div class="video-container">
  <div class="video-wrapper-dark">
    <video src="https://www.dropbox.com/scl/fi/evcgoniyxavkrqs8gajru/ch2-loss-function.mp4?rlkey=r74573378jt9njeltk19w5912&amp;st=rmyax9oh&amp;raw=1" muted="" autoplay="" loop="" controls="" style="width: 100%; height: auto;">
    </video>
  </div>
</div>

<div class="image-caption">Courtesy of me and Claude hammering on manim</div>
<p><br /></p>

<h2 id="stochastic-gradient-descent">Stochastic Gradient Descent</h2>

<p>Conceptually, with standard stochastic gradient descent (SGD), we update our weights so that we move in the opposite direction of the gradient (since the gradient points in the direction of steepest ascent).</p>

<p>Mathematically speaking, this is:</p>

\[\theta_{t+1} = \theta_t - \eta \nabla_{\theta} L (\theta_t)\]

<p>SGD works pretty well, but it’s far from the best. Think back to our Grand Canyon analogy. Imagine there are steep stairs that zig-zag back and forth down the canyon. There might be a ramp that is less steep but gets us to the lowest point in the valley more directly. If our landscape is anything more dynamic than a vanilla bowl, the best path is almost certainly not straight, and therefore SGD isn’t the most <em>efficient</em>. This is basically what happens to SGD in ravines: there is high curvature in one dimension, but not in another.</p>

<p>Furthermore, the step size isn’t dynamic enough. A single global step size ignores how much each model parameter (and its gradient) actually needs to be adjusted, so we can overshoot our targets.</p>

<p>Here’s an example of where SGD could get caught in a local minimum.</p>

<div class="video-container">
  <div class="video-wrapper-dark">
    <video src="https://www.dropbox.com/scl/fi/nx9o5mbx5kc206ksf2wc1/ch3-sgd-trap.mp4?rlkey=znd4b6bl69dg3roi1t7aox3sq&amp;st=21l0we8j&amp;raw=1" muted="" autoplay="" loop="" controls="" style="width: 100%; height: auto;">
    </video>
  </div>
</div>

<div class="image-caption">Courtesy of me and Claude hammering on manim</div>
<p><br /></p>

<p>And if 3D isn’t really your style (especially given my <code class="language-plaintext highlighter-rouge">manim</code> skills are pretty poor), here’s some Python code that will visualize SGD on a 2D contour plot:</p>

<!-- prettier-ignore-start -->
<div class="interactive-python">
<pre><code class="language-python">
import numpy as np
import matplotlib.pyplot as plt
from itertools import product

def styblinski_tang_fn(x: float, y: float) -&gt; float:
    return 0.5 * ((x**4 - 16 * x**2 + 5 * x) + (y**4 - 16 * y**2 + 5 * y))

def styblinski_tang_grad(x: float, y: float) -&gt; np.ndarray:
    dfx = 2 * x**3 - 16 * x + 2.5
    dfy = 2 * y**3 - 16 * y + 2.5
    return np.array([dfx, dfy], dtype=float)

eta = 0.01
steps = 80
theta = np.array([3.5, -3.5], dtype=float)

"""SGD!!! This is the important part here. Implementing the exact math above."""
path = [theta.copy()]
for _ in range(steps):
    grad = styblinski_tang_grad(*theta)
    theta -= eta * grad
    path.append(theta.copy())
path = np.array(path)

"""find stationary points (we can just look at derivative because repeated)"""
roots = np.roots([2.0, 0.0, -16.0, 2.5])
roots = np.real(roots[np.isreal(roots)])          # keep real roots
"""this is basically using second derivative to determine minima"""
minima_1d = [r for r in roots if (6*r*r - 16) &gt; 0]  # two minima
local_minima_2d = np.array(list(product(minima_1d, repeat=2)), dtype=float)
vals = np.array([styblinski_tang_fn(x, y) for x, y in local_minima_2d])
gmin_idx = np.argmin(vals)
gmin_pt = local_minima_2d[gmin_idx]
gmin_val = vals[gmin_idx]

"""viz"""
x = y = np.linspace(-5, 5, 300)
X, Y = np.meshgrid(x, y)
Z = styblinski_tang_fn(X, Y)
plt.figure(figsize=(7, 6))
cs = plt.contour(X, Y, Z, levels=40, cmap="viridis", alpha=0.85)
plt.clabel(cs, inline=True, fmt="%.0f", fontsize=8)
plt.plot(path[:, 0], path[:, 1], 'r.-', label='GD Path', zorder=2)
plt.scatter(path[0, 0], path[0, 1], color='orange', s=80, label='Start', zorder=3)
plt.scatter(path[-1, 0], path[-1, 1], color='blue', s=80, label='End', zorder=3)
mask = np.ones(len(local_minima_2d), dtype=bool)
mask[gmin_idx] = False
if np.any(mask):
    plt.scatter(local_minima_2d[mask, 0], local_minima_2d[mask, 1],
                marker='v', s=120, edgecolor='k', facecolor='white',
                label='Local minima', zorder=4)
plt.scatter(gmin_pt[0], gmin_pt[1], marker='*', s=220, edgecolor='k',
            facecolor='gold', label=f'Global min ({gmin_pt[0]:.4f}, {gmin_pt[1]:.4f})\n f={gmin_val:.4f}', zorder=5)
plt.title("Gradient Descent on Styblinski–Tang: Local vs Global Minima")
plt.xlabel("x"); plt.ylabel("y"); plt.legend(loc='upper right'); plt.grid(alpha=0.3); plt.tight_layout();
plt.show()
</code></pre>
</div>
<!-- prettier-ignore-end-->

<h2 id="sgd-with-momentum">SGD with Momentum</h2>

<p>So the natural question is: how can we do better than plain SGD?</p>

<p>This idea has been around forever (1964) compared to Muon, which is basically 2025. Boris Polyak introduced momentum with a physical intuition: if you roll a heavy ball down a hill with valleys along the way, it doesn’t get trapped in every local minimum. It has momentum to carry it over local minima, which helps it find the global minimum.</p>

<p>Mathematically, it’s a pretty simple extension from our previous. The general idea is that now we have two equations governing how we update our parameters:</p>

\[\begin{align}
v_{t+1} &amp;= \beta v_t - \eta \nabla_{\theta} L (\theta_{t}) \\
\theta_{t+1} &amp;= \theta_{t} + v_{t+1}
\end{align}\]

<p>We’ve got some new parameters, so let’s define those:</p>

<ul>
  <li>$v_{t}$ - is the “velocity”: the accumulated gradient, basically our physical momentum</li>
  <li>$\beta$ - is the “momentum coefficient”; it controls how much history we remember and how much we want to propagate</li>
  <li>$\eta$ - is still our learning rate</li>
</ul>

<p>A key insight is that if you take $\beta \to 0$ and substitute $v_{t+1}$ into the update, the whole thing falls back to SGD (which is good).</p>
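<p>A quick numerical sanity check of that claim (my own sketch, on a simple quadratic bowl rather than anything from the papers):</p>

```python
import numpy as np

def grad(theta):
    # gradient of the quadratic bowl f(theta) = 0.5 * ||theta||^2
    return theta

def run_momentum(theta0, eta=0.1, beta=0.9, steps=50):
    theta = np.array(theta0, float)
    v = np.zeros_like(theta)
    for _ in range(steps):
        v = beta * v - eta * grad(theta)   # v_{t+1} = beta * v_t - eta * grad
        theta = theta + v                  # theta_{t+1} = theta_t + v_{t+1}
    return theta

def run_sgd(theta0, eta=0.1, steps=50):
    theta = np.array(theta0, float)
    for _ in range(steps):
        theta = theta - eta * grad(theta)
    return theta

# With beta = 0, the momentum update collapses to plain SGD.
print(np.allclose(run_momentum([3.0, -2.0], beta=0.0), run_sgd([3.0, -2.0])))  # True
```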

<p>A core paradigm shift here was that this was the first time gradient descent carried with it the notion of memory. It’s a bit more stateful.</p>

<p>Once again, a 3D version, and a 2D version.</p>

<div class="video-container">
  <div class="video-wrapper-dark">
    <video src="https://www.dropbox.com/scl/fi/ihq86vaebowf17g7dacql/ch4-sgd-mom.mp4?rlkey=1c5xshmajrfy66z7brnfq61hu&amp;st=oladf1ox&amp;raw=1" muted="" autoplay="" loop="" controls="" style="width: 100%; height: auto;">
    </video>
  </div>
</div>

<div class="image-caption">Courtesy of me and Claude hammering on manim</div>
<p><br /></p>

<p>And the 2D visualization:</p>

<div class="interactive-python">
<pre><code class="language-python">
import numpy as np
import matplotlib.pyplot as plt
from itertools import product

def styblinski_tang_fn(x: float, y: float) -&gt; float:
    return 0.5 * ((x**4 - 16 * x**2 + 5 * x) + (y**4 - 16 * y**2 + 5 * y))

def styblinski_tang_grad(x: float, y: float) -&gt; np.ndarray:
    dfx = 2 * x**3 - 16 * x + 2.5
    dfy = 2 * y**3 - 16 * y + 2.5
    return np.array([dfx, dfy], dtype=float)

def stationary_points_and_global_min():
    roots = np.roots([2.0, 0.0, -16.0, 2.5])
    roots = np.real(roots[np.isreal(roots)])
    minima_1d = [r for r in roots if (6*r*r - 16) &gt; 0]
    mins2d = np.array(list(product(minima_1d, repeat=2)), dtype=float)
    vals = np.array([styblinski_tang_fn(x, y) for x, y in mins2d])
    gidx = np.argmin(vals)
    return mins2d, mins2d[gidx], vals[gidx]

def run_sgd(theta0, eta=0.02, steps=1200):
    theta = np.array(theta0, float)
    path = [theta.copy()]
    for _ in range(steps):
        theta -= eta * styblinski_tang_grad(*theta)
        path.append(theta.copy())
    return np.array(path)

"""
again, re-call beta is our momentum coefficient
eta is still our learning rate
extension: Nesterov Momentum
"""
def run_momentum(theta0, eta=0.02, beta=0.90, steps=1200):
    theta = np.array(theta0, float)
    v = np.zeros_like(theta)
    path = [theta.copy()]
    for _ in range(steps):
        grad = styblinski_tang_grad(*theta)
        v = beta * v - eta * grad
        theta = theta + v
        path.append(theta.copy())
    return np.array(path)

"""params"""
theta_start = np.array([4.1, 4.5], dtype=float)
eta = 0.02
beta = 0.90
steps = 1200
use_nesterov = False  # flip to True to experiment

sgd_path = run_sgd(theta_start, eta=eta, steps=steps)
mom_path = run_momentum(theta_start, eta=eta, beta=beta, steps=steps)
mins2d, gmin_pt, gmin_val = stationary_points_and_global_min()

"""viz"""
x = y = np.linspace(-5, 5, 400)
X, Y = np.meshgrid(x, y)
Z = styblinski_tang_fn(X, Y)

plt.figure(figsize=(8, 7))
cs = plt.contour(X, Y, Z, levels=50, cmap="viridis", alpha=0.85)
plt.clabel(cs, inline=True, fmt="%.0f", fontsize=7)
plt.plot(sgd_path[:, 0], sgd_path[:, 1], 'r.-', lw=1.5, ms=3, label='SGD')
plt.plot(mom_path[:, 0], mom_path[:, 1], 'b.-', lw=1.5, ms=3,
         label=f'Momentum (β={beta}, nesterov={use_nesterov})')
plt.scatter(sgd_path[0, 0], sgd_path[0, 1], c='orange', s=80, label='Start', zorder=3)
plt.scatter(sgd_path[-1, 0], sgd_path[-1, 1], c='red', s=70, label='SGD End', zorder=3)
plt.scatter(mom_path[-1, 0], mom_path[-1, 1], c='blue', s=70, label='Momentum End', zorder=3)
vals = np.array([styblinski_tang_fn(x0, y0) for x0, y0 in mins2d])
mask = np.ones(len(mins2d), dtype=bool)
mask[np.argmin(vals)] = False
if np.any(mask):
    plt.scatter(mins2d[mask, 0], mins2d[mask, 1],
                marker='v', s=120, edgecolor='k', facecolor='white',
                label='Local minima', zorder=4)
plt.scatter(gmin_pt[0], gmin_pt[1], marker='*', s=220, edgecolor='k',
            facecolor='gold', label=f'Global min ({gmin_pt[0]:.4f}, {gmin_pt[1]:.4f})\n f={gmin_val:.4f}', zorder=5)

plt.title("SGD vs Momentum on Styblinski–Tang: Escaping a Local Minimum")
plt.xlabel("x"); plt.ylabel("y")
plt.legend(loc='lower right'); plt.grid(alpha=0.3); plt.tight_layout()
plt.show()
</code></pre>

</div>

<h3 id="computational-cost-of-momentum">Computational Cost of Momentum</h3>

<p>So while momentum is great and improves training, let’s look at the change. In SGD, we have:</p>

<ul>
  <li>for each parameter $\theta_i$:
    <ul>
      <li>parameter itself</li>
      <li>gradient $\nabla_{\theta_{i}} L$</li>
    </ul>
  </li>
</ul>

<p>But now that we’re carrying around velocity $v_i$ for each parameter, we store:</p>

<ul>
  <li>memory wise:
    <ul>
      <li>one extra tensor the same size as $\theta$ (i.e. one extra value per parameter)</li>
    </ul>
  </li>
  <li>comp wise:
    <ul>
      <li>this is relatively inexpensive given it’s basically one more computation to make</li>
      <li>but a good callout is that it’s not free</li>
    </ul>
  </li>
</ul>

<p>Again, even if you didn’t read the code above and just ran it, note this part:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">run_momentum</span><span class="p">(</span><span class="n">theta0</span><span class="p">,</span> <span class="n">eta</span><span class="o">=</span><span class="mf">0.02</span><span class="p">,</span> <span class="n">beta</span><span class="o">=</span><span class="mf">0.90</span><span class="p">,</span> <span class="n">steps</span><span class="o">=</span><span class="mi">1200</span><span class="p">):</span>
    <span class="n">theta</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">theta0</span><span class="p">,</span> <span class="nb">float</span><span class="p">)</span>
    <span class="n">v</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">theta</span><span class="p">)</span>
    <span class="n">path</span> <span class="o">=</span> <span class="p">[</span><span class="n">theta</span><span class="p">.</span><span class="n">copy</span><span class="p">()]</span>
    <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">steps</span><span class="p">):</span>
        <span class="n">grad</span> <span class="o">=</span> <span class="n">styblinski_tang_grad</span><span class="p">(</span><span class="o">*</span><span class="n">theta</span><span class="p">)</span>
        <span class="n">v</span> <span class="o">=</span> <span class="n">beta</span> <span class="o">*</span> <span class="n">v</span> <span class="o">-</span> <span class="n">eta</span> <span class="o">*</span> <span class="n">grad</span>
        <span class="n">theta</span> <span class="o">=</span> <span class="n">theta</span> <span class="o">+</span> <span class="n">v</span>
        <span class="n">path</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">theta</span><span class="p">.</span><span class="n">copy</span><span class="p">())</span>
    <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">path</span><span class="p">)</span>
</code></pre></div></div>

<p>That <code class="language-plaintext highlighter-rouge">v</code> didn’t exist before with standard SGD.</p>

<p>This is a general tradeoff we’ll need to keep in mind with optimizer design. At the massive scale of modern training, every extra tensor and operation has a significant impact, leading to real $ signs.</p>
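<p>To make that concrete, here’s a back-of-the-envelope sketch. The model size and fp32-state assumption are mine, purely for illustration:</p>

```python
# Back-of-the-envelope optimizer state memory, assuming fp32 (4 bytes) state.
n_params = 7_000_000_000           # e.g. a hypothetical 7B-parameter model
bytes_per_param = 4

weights_gb = n_params * bytes_per_param / 1e9
sgd_extra_gb = 0 * weights_gb       # SGD: no extra state
momentum_extra_gb = 1 * weights_gb  # momentum: one velocity tensor, same shape as theta
adam_extra_gb = 2 * weights_gb      # Adam (coming up): first + second moment tensors

print(f"weights:  {weights_gb:.0f} GB")     # 28 GB
print(f"momentum: +{momentum_extra_gb:.0f} GB of state")  # +28 GB
print(f"Adam:     +{adam_extra_gb:.0f} GB of state")      # +56 GB
```

So the "one extra tensor" of momentum literally doubles optimizer-adjacent memory, before we even get to Adam's two moment buffers.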

<h3 id="variations">Variations</h3>

<p>I won’t go into these in detail, but as with everything, there are numerous variations.</p>

<ul>
  <li><a href="https://en.wikipedia.org/wiki/Stochastic_gradient_descent#Momentum">Nesterov momentum (a.k.a NAG)</a></li>
</ul>

<h2 id="adaptive-learning-rates-adagrad--rmsprop">Adaptive Learning Rates (AdaGrad / RMSProp)</h2>

<p>Great, so momentum is going to help us smooth learning. The next area of improvement people focused on was $\eta$. It sucks that it’s the same for every parameter, so the whole idea was to make the learning rate adaptive per parameter.</p>

<p>This section is where the math starts to get a bit more interesting.</p>

<h3 id="adagrad-2010">AdaGrad (2010)</h3>

<p>Adaptive gradient came first. <a href="https://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf">Here’s the original paper</a> written in 2010 by John Duchi, Elad Hazan, and Yoram Singer.</p>

<p>The general idea is:</p>

<ul>
  <li>we keep track of how large each parameter’s past gradients have been</li>
  <li>we use the history to scale down updates for params that have seen a lot of gradient action</li>
</ul>

<p>So the core idea here is that we’re going to track the <strong>sum of each parameter’s squared gradients over time</strong>. This helps with issues like vanishing and exploding gradients (which were also an annoyance in <a href="/2025/teaching-a-computer-to-write/"><strong>Teaching a Computer How to Write</strong></a>).</p>

<p>In other words,</p>

\[r_{t,i} = \sum_{k=1}^t g_{k,i}^2\]

<p>So basically $r_{t,i}$, for the $i$th parameter at time $t$, tells you how much gradient “energy” that coordinate has seen so far.</p>

<p>Then our <strong>update rule</strong> rescales the learning rate for each parameter coordinate:</p>

\[\theta_{t+1, i} = \theta_{t,i} - \frac{\eta}{\sqrt{r_{t,i}} + \varepsilon} g_{t,i}\]

<p>This can be written in a vectorized format like:</p>

\[\theta_{t+1} = \theta_{t} - \eta D_{t}^{-1/2} g_{t}\]

<p>where $D_t = \text{diag}(r_t)$ and each diagonal element corresponds to one coordinate’s cumulative gradient magnitude. So we’re basically folding the per-coordinate index $i$ into the shape of the vectors themselves.</p>

<p>Again, DL loves big matrices.</p>
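<p>A tiny check (my own sketch, ignoring $\varepsilon$) that the matrix form really is just the per-coordinate rule:</p>

```python
import numpy as np

# The vectorized update eta * D_t^{-1/2} g with D_t = diag(r_t)
# is just the elementwise rule eta * g / sqrt(r_t).
rng = np.random.default_rng(42)
g = rng.normal(size=4)             # a pretend gradient
r = rng.uniform(1.0, 5.0, size=4)  # accumulated squared gradients (positive)
eta = 0.1

D = np.diag(r)
# sqrt of a diagonal matrix is elementwise, so inv(sqrt(D)) = diag(1/sqrt(r))
matrix_form = eta * np.linalg.inv(np.sqrt(D)) @ g
elementwise = eta * g / np.sqrt(r)

print(np.allclose(matrix_form, elementwise))  # True
```

In practice everyone implements the elementwise version; the diagonal-matrix form is just the notation that generalizes nicely.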

<p>I am not going to try and do a 3D visualization given those take me awhile to get to an acceptable place.</p>

<div class="interactive-python">
<pre><code class="language-python">
import numpy as np
import matplotlib.pyplot as plt
from sympy import Matrix, pprint, init_printing
from itertools import product
from IPython.display import display
init_printing(use_unicode=True)

def styblinski_tang_fn(x: float, y: float) -&gt; float:
    return 0.5 * ((x**4 - 16 * x**2 + 5 * x) + (y**4 - 16 * y**2 + 5 * y))

def styblinski_tang_grad(x: float, y: float) -&gt; np.ndarray:
    dfx = 2 * x**3 - 16 * x + 2.5
    dfy = 2 * y**3 - 16 * y + 2.5
    return np.array([dfx, dfy], dtype=float)

def stationary_points_and_global_min():
    roots = np.roots([2.0, 0.0, -16.0, 2.5])
    roots = np.real(roots[np.isreal(roots)])
    minima_1d = [r for r in roots if (6*r*r - 16) &gt; 0]
    mins2d = np.array(list(product(minima_1d, repeat=2)), dtype=float)
    vals = np.array([styblinski_tang_fn(x, y) for x, y in mins2d])
    gidx = np.argmin(vals)
    return mins2d, mins2d[gidx], vals[gidx]

def run_sgd(theta0, eta=0.02, steps=1200):
    theta = np.array(theta0, float)
    path = [theta.copy()]
    for _ in range(steps):
        theta -= eta * styblinski_tang_grad(*theta)
        path.append(theta.copy())
    return np.array(path)

def run_momentum(theta0, eta=0.02, beta=0.90, steps=1200):
    theta = np.array(theta0, float)
    v = np.zeros_like(theta)
    path = [theta.copy()]
    for _ in range(steps):
        grad = styblinski_tang_grad(*theta)
        v = beta * v - eta * grad
        theta = theta + v
        path.append(theta.copy())
    return np.array(path)

def run_adagrad(theta0, eta=0.40, eps=1e-8, steps=1200):
    """
    r_t &lt;- r_{t-1} + g_t^2
    theta &lt;- theta - (eta / (sqrt(r_t) + eps)) * g_t
    """
    theta = np.array(theta0, float)
    r = np.zeros_like(theta)         
    path = [theta.copy()]
    for step in range(steps):
        g = styblinski_tang_grad(*theta)
        r = r + g * g
        lr = eta / (np.sqrt(r) + eps) # elementwise effective LR
        if step % 100 == 0 and step &lt; 600:
            D = np.diag(r)
            print(f"\nStep {step}:  Dt = diag(r_step)")
            display(Matrix(D))
        if step == steps - 1:
            D = np.diag(r)
            print(f"\nFinal Step {step}:  Dt = diag(r_step)")
            display(Matrix(D))
        theta = theta - lr * g
        path.append(theta.copy())
    return np.array(path)

"""params"""
theta_start = np.array([4.1, 4.5], dtype=float)
eta = 0.02
beta = 0.90
steps = 1200
eta_adagrad = 0.40
eps_adagrad = 1e-8

sgd_path = run_sgd(theta_start, eta=eta, steps=steps)
mom_path = run_momentum(theta_start, eta=eta, beta=beta, steps=steps)
ada_path = run_adagrad(theta_start, eta=eta_adagrad, eps=eps_adagrad, steps=steps)
mins2d, gmin_pt, gmin_val = stationary_points_and_global_min()

"""viz"""
x = y = np.linspace(-5, 5, 400)
X, Y = np.meshgrid(x, y)
Z = styblinski_tang_fn(X, Y)

plt.figure(figsize=(8, 7))
cs = plt.contour(X, Y, Z, levels=50, alpha=0.85)   # (kept close; removed explicit cmap for portability)
plt.clabel(cs, inline=True, fmt="%.0f", fontsize=7)

plt.plot(sgd_path[:, 0], sgd_path[:, 1], 'r.-', lw=1.5, ms=3, label='SGD')
plt.plot(mom_path[:, 0], mom_path[:, 1], 'b.-', lw=1.5, ms=3,
         label=f'Momentum (β={beta})')
plt.plot(ada_path[:, 0], ada_path[:, 1], 'g.-', lw=1.5, ms=3,
         label=f'AdaGrad (η₀={eta_adagrad})')

plt.scatter(sgd_path[0, 0], sgd_path[0, 1], c='orange', s=80, label='Start', zorder=3)
plt.scatter(sgd_path[-1, 0], sgd_path[-1, 1], c='red', s=70, label='SGD End', zorder=3)
plt.scatter(mom_path[-1, 0], mom_path[-1, 1], c='blue', s=70, label='Momentum End', zorder=3)
plt.scatter(ada_path[-1, 0], ada_path[-1, 1], c='green', s=70, label='AdaGrad End', zorder=3)

vals = np.array([styblinski_tang_fn(x0, y0) for x0, y0 in mins2d])
mask = np.ones(len(mins2d), dtype=bool)
mask[np.argmin(vals)] = False
if np.any(mask):
    plt.scatter(mins2d[mask, 0], mins2d[mask, 1],
                marker='v', s=120, edgecolor='k', facecolor='white',
                label='Local minima', zorder=4)
plt.scatter(gmin_pt[0], gmin_pt[1], marker='*', s=220, edgecolor='k',
            facecolor='gold', label=f'Global min ({gmin_pt[0]:.4f}, {gmin_pt[1]:.4f})\n f={gmin_val:.4f}', zorder=5)

plt.title("SGD vs Momentum vs AdaGrad on Styblinski–Tang")
plt.xlabel("x"); plt.ylabel("y")
plt.legend(loc='lower right'); plt.grid(alpha=0.3); plt.tight_layout()
plt.show()
</code></pre>

</div>

<h4 id="variations-1">Variations</h4>

<p>Arguably, RMSProp is a variation of AdaGrad, but I decided to split it out given how often RMSProp is discussed.</p>

<p>However, similar to AdaGrad, there’s also</p>

<ul>
  <li><a href="https://optimization.cbe.cornell.edu/index.php?title=AdaGrad#AdaDelta">AdaDelta</a>
    <ul>
      <li>basically does an exponential weighted average</li>
    </ul>
  </li>
</ul>

<h3 id="rmsprop-2012">RMSProp (2012)</h3>

<p>RMSProp, or Root Mean Square Propagation, allows the effective learning rate to increase or decrease; it does away with the effective LR shrinking monotonically.</p>

<p>Confusingly but importantly, RMSProp is identical to AdaDelta, just without the running average over parameter updates.</p>

<p>The whole notion of RMSProp is that we keep an <strong>exponential weighted moving average</strong> (EMA) of recent gradients per parameter.</p>

<p>We scale the raw gradient by the inverse root of that EMA.</p>

<p>In other words,</p>

\[\begin{align}
s_t &amp;= \rho s_{t-1} + (1-\rho) g_t^2 \\
\theta_{t+1} &amp;= \theta_t - \eta \frac{g_t}{\sqrt{s_t} + \varepsilon}
\end{align}\]

<p>Sometimes people use $\beta$ instead of $\rho$. But here is what these mean:</p>

<ul>
  <li>$s_t$ - accumulated moving average of squared gradients at time $t$</li>
  <li>$\rho$ - the decay rate, typically between 0.9 and 0.99</li>
  <li>$g_t$ - still represents our gradient at time $t$</li>
</ul>

<p>And once again, for matrix math, similar to AdaGrad we can play a similar game with vectorizing it:</p>

<p>\(\theta_{t+1} = \theta_{t} - \eta \tilde{D}_t^{-\frac{1}{2}} g_t\)
where 
\(\tilde{D}_t = \text{diag}(s_t + \varepsilon)\)</p>

<p>So the net result is that large, consistently steep coordinates get downscaled, and quiet coordinates get a healthier step. Because we use a moving window, step sizes don’t vanish over time.</p>

<p>The EMA is meant to focus on recent gradients, maintaining a steady effective learning rate and preventing premature decay. With AdaGrad, the effective LR shrinks monotonically and can stall on long runs.</p>
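<p>You can see that difference with a tiny experiment (my own sketch): feed both accumulators a constant gradient and compare the effective learning rate $\eta / \sqrt{\cdot}$ after many steps.</p>

```python
import numpy as np

# Constant gradient g = 1: AdaGrad's accumulator r grows without bound, so
# its effective LR shrinks toward 0. RMSProp's EMA s converges to g^2 = 1,
# so its effective LR settles near eta.
eta, rho, g = 0.1, 0.9, 1.0
r, s = 0.0, 0.0
for _ in range(1000):
    r += g * g                        # AdaGrad: sum of squared gradients
    s = rho * s + (1 - rho) * g * g   # RMSProp: EMA of squared gradients

print(eta / np.sqrt(r))  # ~0.003, still shrinking with every step
print(eta / np.sqrt(s))  # ~0.1, stable
```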

<p>Again, I am not going to try and do a 3D visualization given those take me awhile to get to an acceptable place.</p>

<div class="interactive-python">
<pre><code class="language-python">
import numpy as np
import matplotlib.pyplot as plt
from itertools import product
from sympy import Matrix
from IPython.display import display

def styblinski_tang_fn(x: float, y: float) -&gt; float:
    return 0.5 * ((x**4 - 16 * x**2 + 5 * x) + (y**4 - 16 * y**2 + 5 * y))


def styblinski_tang_grad(x: float, y: float) -&gt; np.ndarray:
    dfx = 2 * x**3 - 16 * x + 2.5
    dfy = 2 * y**3 - 16 * y + 2.5
    return np.array([dfx, dfy], dtype=float)


def stationary_points_and_global_min():
    roots = np.roots([2.0, 0.0, -16.0, 2.5])
    roots = np.real(roots[np.isreal(roots)])
    minima_1d = [r for r in roots if (6 * r * r - 16) &gt; 0]
    mins2d = np.array(list(product(minima_1d, repeat=2)), dtype=float)
    vals = np.array([styblinski_tang_fn(x, y) for x, y in mins2d])
    gidx = np.argmin(vals)
    return mins2d, mins2d[gidx], vals[gidx]


def run_sgd(theta0, eta=0.02, steps=1200):
    theta = np.array(theta0, float)
    path = [theta.copy()]
    for _ in range(steps):
        theta -= eta * styblinski_tang_grad(*theta)
        path.append(theta.copy())
    return np.array(path)


def run_momentum(theta0, eta=0.02, beta=0.90, steps=1200):
    theta = np.array(theta0, float)
    v = np.zeros_like(theta)
    path = [theta.copy()]
    for _ in range(steps):
        g = styblinski_tang_grad(*theta)
        v = beta * v - eta * g
        theta = theta + v
        path.append(theta.copy())
    return np.array(path)


def run_adagrad(theta0, eta=0.40, eps=1e-8, steps=1200):
    theta = np.array(theta0, float)
    r = np.zeros_like(theta)
    path = [theta.copy()]
    for _ in range(steps):
        g = styblinski_tang_grad(*theta)
        r = r + g * g
        lr_eff = eta / (np.sqrt(r) + eps)
        theta = theta - lr_eff * g
        path.append(theta.copy())
    return np.array(path)

def run_rmsprop(theta0, eta=1e-2, rho=0.9, eps=1e-8, steps=1200):
    """
    s_t = rho * s_{t-1} + (1 - rho) * g_t^2
    theta &lt;- theta - eta * g_t / (sqrt(s_t) + eps)
    """
    theta = np.array(theta0, float)
    s = np.zeros_like(theta)
    path = [theta.copy()]
    for step in range(steps):
        g = styblinski_tang_grad(*theta)
        s = rho * s + (1 - rho) * (g * g)
        if step % 100 == 0 and step &lt; 600:
            S = np.diag(s)
            print(f"\nStep {step}:  s_t (EMA of squared gradients)")
            display(Matrix(S))
        if step == steps - 1:
            S = np.diag(s)
            print(f"\nFinal Step {step}:  s_t (EMA of squared gradients)")
            display(Matrix(S))
        theta = theta - eta * g / (np.sqrt(s) + eps)
        path.append(theta.copy())
    return np.array(path)

def run_rmsprop_centered(theta0, eta=1e-2, rho=0.9, eps=1e-8, steps=1200):
    """
    m_t = rho * m_{t-1} + (1 - rho) * g_t
    s_t = rho * s_{t-1} + (1 - rho) * g_t^2
    denom = sqrt(s_t - m_t^2) + eps   # variance-based
    """
    theta = np.array(theta0, float)
    m = np.zeros_like(theta)
    s = np.zeros_like(theta)
    path = [theta.copy()]
    for _ in range(steps):
        g = styblinski_tang_grad(*theta)
        m = rho * m + (1 - rho) * g
        s = rho * s + (1 - rho) * (g * g)
        denom = np.sqrt(np.maximum(s - m * m, 0.0)) + eps
        theta = theta - eta * g / denom
        path.append(theta.copy())
    return np.array(path)


theta_start = np.array([4.1, 4.5], dtype=float)
steps = 1200

eta_sgd = 0.02
eta_mom, beta = 0.02, 0.90
eta_adagrad = 0.40
eta_rms, rho, eps = 1e-2, 0.9, 1e-8
eta_rms_c = 1e-2

sgd_path = run_sgd(theta_start, eta=eta_sgd, steps=steps)
mom_path = run_momentum(theta_start, eta=eta_mom, beta=beta, steps=steps)
ada_path = run_adagrad(theta_start, eta=eta_adagrad, steps=steps)
rms_path = run_rmsprop(theta_start, eta=eta_rms, rho=rho, eps=eps, steps=steps)
rmsc_path = run_rmsprop_centered(theta_start, eta=eta_rms_c, rho=rho, eps=eps, steps=steps)

mins2d, gmin_pt, gmin_val = stationary_points_and_global_min()

x = y = np.linspace(-5, 5, 400)
X, Y = np.meshgrid(x, y)
Z = styblinski_tang_fn(X, Y)

plt.figure(figsize=(9, 8))
cs = plt.contour(X, Y, Z, levels=50, alpha=0.85)
plt.clabel(cs, inline=True, fmt="%.0f", fontsize=7)

plt.plot(sgd_path[:, 0], sgd_path[:, 1], '.-', lw=1.2, ms=3, label='SGD')
plt.plot(mom_path[:, 0], mom_path[:, 1], '.-', lw=1.2, ms=3, label=f'Momentum (β={beta})')
plt.plot(ada_path[:, 0], ada_path[:, 1], '.-', lw=1.2, ms=3, label='AdaGrad')
plt.plot(rms_path[:, 0], rms_path[:, 1], '.-', lw=1.2, ms=3, label=f'RMSProp (ρ={rho})')
plt.plot(rmsc_path[:, 0], rmsc_path[:, 1], '.-', lw=1.2, ms=3, label='RMSProp (centered)')

plt.scatter(sgd_path[0, 0], sgd_path[0, 1], s=80, label='Start', zorder=3)
plt.scatter(sgd_path[-1, 0], sgd_path[-1, 1], s=60, label='SGD End', zorder=3)
plt.scatter(mom_path[-1, 0], mom_path[-1, 1], s=60, label='Momentum End', zorder=3)
plt.scatter(ada_path[-1, 0], ada_path[-1, 1], s=60, label='AdaGrad End', zorder=3)
plt.scatter(rms_path[-1, 0], rms_path[-1, 1], s=60, label='RMSProp End', zorder=3)
plt.scatter(rmsc_path[-1, 0], rmsc_path[-1, 1], s=60, label='RMSProp (centered) End', zorder=3)

vals = np.array([styblinski_tang_fn(x0, y0) for x0, y0 in mins2d])
mask = np.ones(len(mins2d), dtype=bool)
mask[np.argmin(vals)] = False
if np.any(mask):
    plt.scatter(mins2d[mask, 0], mins2d[mask, 1],
                marker='v', s=120, edgecolor='k', facecolor='white',
                label='Local minima', zorder=4)
plt.scatter(gmin_pt[0], gmin_pt[1], marker='*', s=220, edgecolor='k',
            facecolor='gold', label=f'Global min ({gmin_pt[0]:.4f}, {gmin_pt[1]:.4f})\n f={gmin_val:.4f}', zorder=5)

plt.title("SGD vs Momentum vs AdaGrad vs RMSProp (and Centered) on Styblinski–Tang")
plt.xlabel("x")
plt.ylabel("y")
plt.legend(loc='lower right')
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()
</code></pre>

</div>

<p>Again, the important code part is here:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">run_rmsprop</span><span class="p">(</span><span class="n">theta0</span><span class="p">,</span> <span class="n">eta</span><span class="o">=</span><span class="mf">1e-2</span><span class="p">,</span> <span class="n">rho</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-8</span><span class="p">,</span> <span class="n">steps</span><span class="o">=</span><span class="mi">1200</span><span class="p">):</span>
    <span class="s">"""
    s_t = rho * s_{t-1} + (1 - rho) * g_t^2
    theta &lt;- theta - eta * g_t / (sqrt(s_t) + eps)
    """</span>
    <span class="n">theta</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">theta0</span><span class="p">,</span> <span class="nb">float</span><span class="p">)</span>
    <span class="n">s</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">theta</span><span class="p">)</span>
    <span class="n">path</span> <span class="o">=</span> <span class="p">[</span><span class="n">theta</span><span class="p">.</span><span class="n">copy</span><span class="p">()]</span>
    <span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">steps</span><span class="p">):</span>
        <span class="n">g</span> <span class="o">=</span> <span class="n">styblinski_tang_grad</span><span class="p">(</span><span class="o">*</span><span class="n">theta</span><span class="p">)</span>
        <span class="n">s</span> <span class="o">=</span> <span class="n">rho</span> <span class="o">*</span> <span class="n">s</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">rho</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">g</span> <span class="o">*</span> <span class="n">g</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">step</span> <span class="o">%</span> <span class="mi">100</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">step</span> <span class="o">&lt;</span> <span class="mi">600</span><span class="p">:</span>
            <span class="n">S</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">diag</span><span class="p">(</span><span class="n">s</span><span class="p">)</span>
            <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"</span><span class="se">\n</span><span class="s">Step </span><span class="si">{</span><span class="n">step</span><span class="si">}</span><span class="s">:  s_t (EMA of squared gradients)"</span><span class="p">)</span>
            <span class="n">display</span><span class="p">(</span><span class="n">Matrix</span><span class="p">(</span><span class="n">S</span><span class="p">))</span>
        <span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="n">steps</span> <span class="o">-</span> <span class="mi">1</span><span class="p">:</span>
            <span class="n">S</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">diag</span><span class="p">(</span><span class="n">s</span><span class="p">)</span>
            <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"</span><span class="se">\n</span><span class="s">Final Step </span><span class="si">{</span><span class="n">step</span><span class="si">}</span><span class="s">:  s_t (EMA of squared gradients)"</span><span class="p">)</span>
            <span class="n">display</span><span class="p">(</span><span class="n">Matrix</span><span class="p">(</span><span class="n">S</span><span class="p">))</span>
        <span class="n">theta</span> <span class="o">=</span> <span class="n">theta</span> <span class="o">-</span> <span class="n">eta</span> <span class="o">*</span> <span class="n">g</span> <span class="o">/</span> <span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">s</span><span class="p">)</span> <span class="o">+</span> <span class="n">eps</span><span class="p">)</span>
        <span class="n">path</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">theta</span><span class="p">.</span><span class="n">copy</span><span class="p">())</span>
    <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">path</span><span class="p">)</span>
</code></pre></div></div>

<h4 id="variations-2">Variations</h4>

<ul>
  <li>RMSProp (centered)</li>
</ul>

<h2 id="bias-correction-finally-meeting-adam-optimizer-2015">Bias Correction (finally meeting Adam Optimizer, 2015)</h2>

<p>RMSProp is fantastic but still subject to getting caught in local minima.</p>

<p>Ok, finally, in 2015 <a href="https://arxiv.org/abs/1412.6980">Adam</a> was introduced. It basically marries the momentum idea with the second-moment scaling from RMSProp / AdaGrad. The key addition, however, is bias-correcting the EMAs, which start at zero and are therefore biased toward zero early on. The update uses the <strong>direction</strong> $\hat{m}_t$ and the <strong>scale</strong> $\sqrt{\hat{v}_t}$.</p>

<p>Mathematically, we now have:</p>

<ul>
  <li><strong>momentum</strong> part (exp avg of raw gradients, our first moment (i.e. understanding magnitude of gradient updates)) \(m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t\)</li>
  <li><strong>rms prop</strong> part (exp avg of squared gradients, our second moment (i.e. understanding energy / dispersion of gradient updates)) \(v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\)</li>
  <li><strong>bias correction</strong> part (new) - getting around the fact that both start from 0, so we divide by $ 1 - \beta_i^t $ \(\hat{m}_t = \frac{m_t}{1- \beta_1^t} \qquad \hat{v}_t = \frac{v_t}{1-\beta_2^t}\)</li>
</ul>

<p>with our final update being:</p>

\[\theta_{t+1} = \theta_t - \eta \frac{\hat{m}_t}{\sqrt{\hat{v}_t}+\varepsilon}\]

<p>Again, same thing with the vectorization, we’re always just modifying our $D$ matrix:</p>

\[\theta_{t+1} = \theta_t - \eta D_{t}^{-\frac{1}{2}}\hat{m}_t, \quad D_t = \text{diag}(\hat{v}_t + \varepsilon)\]
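<p>To see why the bias correction matters, here is a minimal sketch (toy numbers of my own, assuming a constant gradient): at $t=1$ the raw EMA is only $(1-\beta_1)$ of the true gradient, and dividing by $1-\beta_1^t$ recovers it exactly.</p>

```python
beta1, beta2 = 0.9, 0.999
g = 2.0  # pretend the gradient is constant at 2.0

m = v = 0.0
for t in range(1, 4):
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1**t)   # bias-corrected first moment
    v_hat = v / (1 - beta2**t)   # bias-corrected second moment
    print(t, round(m, 4), round(m_hat, 4))

# Output:
# 1 0.2 2.0    (raw m underestimates g=2.0 by 10x; m_hat is exact)
# 2 0.38 2.0
# 3 0.542 2.0
```

<p>For a constant gradient the corrected estimates are exact at every step, while the raw EMAs take dozens of steps to warm up; the same holds for $\hat{v}_t$, which equals $g^2$ throughout.</p>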

<h3 id="comparison-so-far">Comparison so Far</h3>

<p>I had ChatGPT create this table, which does a good job of capturing the nuances between them:</p>

<table>
  <thead>
    <tr>
      <th>Optimizer</th>
      <th>Tracks mean of gradients?</th>
      <th>Tracks mean of squared gradients?</th>
      <th>Bias correction?</th>
      <th>Update uses</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td><strong>Momentum (Polyak)</strong></td>
      <td>✅ $m_t = \beta m_{t-1} + (1-\beta) g_t$</td>
      <td>❌</td>
      <td>❌</td>
      <td>$ \theta_{t+1} = \theta_t - \eta m_t $</td>
    </tr>
    <tr>
      <td><strong>RMSProp (Hinton)</strong></td>
      <td>❌</td>
      <td>✅ $s_t = \rho s_{t-1} + (1-\rho) g_t^2$</td>
      <td>❌</td>
      <td>$ \theta_{t+1} = \theta_t - \eta \dfrac{g_t}{\sqrt{s_t}+\varepsilon} $</td>
    </tr>
    <tr>
      <td><strong>Adam (Kingma &amp; Ba)</strong></td>
      <td>✅ $m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t$</td>
      <td>✅ $v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2$</td>
      <td>✅ divides by $1-\beta_i^t$</td>
      <td>$ \theta_{t+1} = \theta_t - \eta \dfrac{\hat m_t}{\sqrt{\hat v_t}+\varepsilon} $</td>
    </tr>
  </tbody>
</table>

<h3 id="plain-english">Plain English</h3>

<p>My plain-English understanding of what each step adds:</p>

<ul>
  <li>SGD - takes only the size of the current gradient into account</li>
  <li>SGD with momentum - adds smoothing via an EMA of past gradients (introduces $m_t$)</li>
  <li>RMSProp - scales each coordinate by a recent average of squared gradients rather than the full history</li>
  <li>Adam - combines both and adds bias correction to fix underestimation at early timesteps</li>
</ul>

<h3 id="viz">Viz</h3>
<div class="interactive-python">
<pre><code class="language-python">
import numpy as np
import matplotlib.pyplot as plt
from itertools import product
from sympy import Matrix
from IPython.display import display

def styblinski_tang_fn(x: float, y: float) -&gt; float:
    return 0.5 * ((x**4 - 16 * x**2 + 5 * x) + (y**4 - 16 * y**2 + 5 * y))


def styblinski_tang_grad(x: float, y: float) -&gt; np.ndarray:
    dfx = 2 * x**3 - 16 * x + 2.5
    dfy = 2 * y**3 - 16 * y + 2.5
    return np.array([dfx, dfy], dtype=float)


def stationary_points_and_global_min():
    roots = np.roots([2.0, 0.0, -16.0, 2.5])
    roots = np.real(roots[np.isreal(roots)])
    minima_1d = [r for r in roots if (6 * r * r - 16) &gt; 0]
    mins2d = np.array(list(product(minima_1d, repeat=2)), dtype=float)
    vals = np.array([styblinski_tang_fn(x, y) for x, y in mins2d])
    gidx = np.argmin(vals)
    return mins2d, mins2d[gidx], vals[gidx]


def run_sgd(theta0, eta=0.02, steps=1200):
    theta = np.array(theta0, float)
    path = [theta.copy()]
    for _ in range(steps):
        theta -= eta * styblinski_tang_grad(*theta)
        path.append(theta.copy())
    return np.array(path)


def run_momentum(theta0, eta=0.02, beta=0.90, steps=1200):
    theta = np.array(theta0, float)
    v = np.zeros_like(theta)
    path = [theta.copy()]
    for _ in range(steps):
        g = styblinski_tang_grad(*theta)
        v = beta * v - eta * g
        theta = theta + v
        path.append(theta.copy())
    return np.array(path)


def run_adagrad(theta0, eta=0.40, eps=1e-8, steps=1200):
    theta = np.array(theta0, float)
    r = np.zeros_like(theta)
    path = [theta.copy()]
    for _ in range(steps):
        g = styblinski_tang_grad(*theta)
        r = r + g * g
        lr_eff = eta / (np.sqrt(r) + eps)
        theta = theta - lr_eff * g
        path.append(theta.copy())
    return np.array(path)

def run_rmsprop(theta0, eta=1e-2, rho=0.9, eps=1e-8, steps=1200):
    """
    s_t = rho * s_{t-1} + (1 - rho) * g_t^2
    theta &lt;- theta - eta * g_t / (sqrt(s_t) + eps)
    """
    theta = np.array(theta0, float)
    s = np.zeros_like(theta)
    path = [theta.copy()]
    for step in range(steps):
        g = styblinski_tang_grad(*theta)
        s = rho * s + (1 - rho) * (g * g)
        theta = theta - eta * g / (np.sqrt(s) + eps)
        path.append(theta.copy())
    return np.array(path)

def run_rmsprop_centered(theta0, eta=1e-2, rho=0.9, eps=1e-8, steps=1200):
    """
    m_t = rho * m_{t-1} + (1 - rho) * g_t
    s_t = rho * s_{t-1} + (1 - rho) * g_t^2
    denom = sqrt(s_t - m_t^2) + eps   # variance-based
    """
    theta = np.array(theta0, float)
    m = np.zeros_like(theta)
    s = np.zeros_like(theta)
    path = [theta.copy()]
    for _ in range(steps):
        g = styblinski_tang_grad(*theta)
        m = rho * m + (1 - rho) * g
        s = rho * s + (1 - rho) * (g * g)
        denom = np.sqrt(np.maximum(s - m * m, 0.0)) + eps
        theta = theta - eta * g / denom
        path.append(theta.copy())
    return np.array(path)

def run_adam(theta0, eta=1e-2, beta1=0.9, beta2=0.999, eps=1e-8, steps=1200):
    theta = np.array(theta0, float)
    m = np.zeros_like(theta)
    v = np.zeros_like(theta)
    path = [theta.copy()]
    for t in range(1, steps + 1):
        g = styblinski_tang_grad(*theta)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * (g * g)
        m_hat = m / (1 - beta1**t)
        v_hat = v / (1 - beta2**t)
        theta = theta - eta * m_hat / (np.sqrt(v_hat) + eps)
        path.append(theta.copy())
    return np.array(path)

theta_start = np.array([4.1, 4.5], dtype=float)
steps = 1200

eta_sgd = 0.02
eta_mom, beta = 0.02, 0.90
eta_adagrad = 0.40
eta_rms, rho, eps = 1e-2, 0.9, 1e-8
eta_rms_c = 1e-2

sgd_path = run_sgd(theta_start, eta=eta_sgd, steps=steps)
mom_path = run_momentum(theta_start, eta=eta_mom, beta=beta, steps=steps)
ada_path = run_adagrad(theta_start, eta=eta_adagrad, steps=steps)
rms_path = run_rmsprop(theta_start, eta=eta_rms, rho=rho, eps=eps, steps=steps)
rmsc_path = run_rmsprop_centered(theta_start, eta=eta_rms_c, rho=rho, eps=eps, steps=steps)
adam_path = run_adam(theta_start)
mins2d, gmin_pt, gmin_val = stationary_points_and_global_min()

x = y = np.linspace(-5, 5, 400)
X, Y = np.meshgrid(x, y)
Z = styblinski_tang_fn(X, Y)

plt.figure(figsize=(9, 8))
cs = plt.contour(X, Y, Z, levels=50, alpha=0.85)
plt.clabel(cs, inline=True, fmt="%.0f", fontsize=7)

plt.plot(sgd_path[:, 0], sgd_path[:, 1], '.-', lw=1.2, ms=3, label='SGD')
plt.plot(mom_path[:, 0], mom_path[:, 1], '.-', lw=1.2, ms=3, label=f'Momentum (β={beta})')
plt.plot(ada_path[:, 0], ada_path[:, 1], '.-', lw=1.2, ms=3, label='AdaGrad')
plt.plot(rms_path[:, 0], rms_path[:, 1], '.-', lw=1.2, ms=3, label=f'RMSProp (ρ={rho})')
plt.plot(rmsc_path[:, 0], rmsc_path[:, 1], '.-', lw=1.2, ms=3, label='RMSProp (centered)')
plt.plot(adam_path[:, 0], adam_path[:, 1], '.-', lw=1.2, ms=3, label='Adam')

plt.scatter(sgd_path[0, 0], sgd_path[0, 1], s=80, label='Start', zorder=3)
plt.scatter(sgd_path[-1, 0], sgd_path[-1, 1], s=60, label='SGD End', zorder=3)
plt.scatter(mom_path[-1, 0], mom_path[-1, 1], s=60, label='Momentum End', zorder=3)
plt.scatter(ada_path[-1, 0], ada_path[-1, 1], s=60, label='AdaGrad End', zorder=3)
plt.scatter(rms_path[-1, 0], rms_path[-1, 1], s=60, label='RMSProp End', zorder=3)
plt.scatter(rmsc_path[-1, 0], rmsc_path[-1, 1], s=60, label='RMSProp (centered) End', zorder=3)
plt.scatter(adam_path[-1, 0], adam_path[-1, 1], s=60, label='Adam End', zorder=3)

vals = np.array([styblinski_tang_fn(x0, y0) for x0, y0 in mins2d])
mask = np.ones(len(mins2d), dtype=bool)
mask[np.argmin(vals)] = False
if np.any(mask):
    plt.scatter(mins2d[mask, 0], mins2d[mask, 1],
                marker='v', s=120, edgecolor='k', facecolor='white',
                label='Local minima', zorder=4)
plt.scatter(gmin_pt[0], gmin_pt[1], marker='*', s=220, edgecolor='k',
            facecolor='gold', label=f'Global min ({gmin_pt[0]:.4f}, {gmin_pt[1]:.4f})\n f={gmin_val:.4f}', zorder=5)

plt.title("SGD vs Momentum vs AdaGrad vs RMSProp (and Centered) vs Adam on Styblinski–Tang")
plt.xlabel("x")
plt.ylabel("y")
plt.legend(loc='upper left')
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()
</code></pre>

</div>

<h2 id="weight-decay-coupling-the-w-in-adamw-2017">Weight Decay Coupling (the “W” in AdamW, 2017)</h2>

<p>A very slight distinction, but this was a key change that led to better generalization. AdamW was used to train BERT, GPT, and others, and most modern frameworks (PyTorch, TensorFlow, JAX ecosystems) now provide it out of the box.</p>

<h3 id="l2-regularization">L2 Regularization</h3>

<p>Ok, so before we get to AdamW cleanly, we have to discuss L2 regularization (the penalty commonly written as $\frac{\lambda}{2} \| \theta \|^2$). The idea is that we add a penalty to our loss function so that the model doesn’t overfit: we want to minimize both the loss and the size of the weights.</p>

<p>When you use Adam, you compute the gradient of your total loss. So basically from start to finish, walking through:</p>

\[\begin{align}

L_{total} ( \theta ) &amp;= L_{data} (\theta) + \frac{\lambda}{2} \| \theta \|^2 \\

\nabla_{\theta} L_{total} (\theta) &amp;= \nabla_{\theta} L_{data} (\theta) + \lambda \theta

\end{align}\]

<p>This means that every gradient update has two parts:</p>

<ol>
  <li>a data term</li>
  <li>a regularization term</li>
</ol>

<p>And again, $\lambda$ is the regularization strength: basically how much we want to penalize large weights.</p>

<p>Now let’s incorporate this into Adam, where we normally compute the gradient of our total loss:</p>

\[g_t = \nabla_{\theta} L_{total} (\theta_t) = \nabla_{\theta} L_{data} (\theta_t) + \lambda \theta_t\]

<p>But the downside is that the $+ \lambda \theta_t$ term becomes part of the gradient update. That’s an issue for us because Adam does its adaptive scaling magic:</p>

\[\theta_{t+1} = \theta_{t} - \eta \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \varepsilon}\]

<p>but that adaptive scaling portion also now includes the $\lambda \theta_t$ portion meaning some weights get decayed more than others, all depending on their individual $v_t$ values.</p>
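<p>A minimal single-step sketch of the difference (toy numbers of my own, not from the runs elsewhere in this post): with coupled L2, the decay term $\lambda \theta$ rides through the adaptive scaling and gets divided by each coordinate’s $\sqrt{v}$; with decoupled decay, every weight shrinks by the same fraction $\eta \lambda$.</p>

```python
import numpy as np

theta = np.array([1.0, 1.0])
g_data = np.array([10.0, 0.1])  # one steep coordinate, one quiet one
eta, lam, eps = 0.1, 0.01, 1e-8

# Coupled (Adam + L2): lambda * theta is folded into the gradient,
# so its contribution is divided by sqrt(v), unevenly per coordinate.
g = g_data + lam * theta
v = g * g  # single-step stand-in for v_hat
decay_coupled = eta * (lam * theta) / (np.sqrt(v) + eps)

# Decoupled (AdamW): decay is applied outside the adaptive scaling.
decay_decoupled = eta * lam * theta

print(decay_coupled)    # roughly [0.0001, 0.0091], depends on each v
print(decay_decoupled)  # [0.001, 0.001], a uniform shrink
```

<p>Under the coupled scheme the steep coordinate is barely decayed while the quiet one is over-decayed; decoupling restores a uniform penalty, which is the whole point of AdamW.</p>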

<div class="markdown-quiz" data-quiz-id="adam-w-vs-adam-pt1">
  <span class="markdown-quiz__eyebrow">
    Knowledge check
  </span>

  
    <h3 class="markdown-quiz__question">If we're not applying L2 regularization as part of our loss function, is AdamW going to behave any differently from Adam?</h3>
  

  
    <p class="markdown-quiz__prompt">Select the answer</p>
  

  
  <ol class="markdown-quiz__choices">
    
      
      
      
        
        
        
        <li class="markdown-quiz__option" data-correct="false" data-feedback="that&#39;s not quite right! Both the loss function and the optimizer can apply L2 regularization. Even if it&#39;s not explicitly in our loss function, having the L2 regularization in the AdamW optimizer can help to penalize large weights">
          <label for="adam-w-vs-adam-pt1-choice-1">
            <input type="radio" name="adam-w-vs-adam-pt1" id="adam-w-vs-adam-pt1-choice-1" value="No, L2 regularization is linked to AdamW so if it isn&#39;t in our loss function, then we don&#39;t need AdamW" />
            <span>No, L2 regularization is linked to AdamW so if it isn't in our loss function, then we don't need AdamW</span>
          </label>
        </li>
      
    
      
      
      
        
        
        
        <li class="markdown-quiz__option" data-correct="true" data-feedback="Yes, there are multiple points of injection for L2 regularization, and using AdamW can still be beneficial even if your loss function / training code doesn&#39;t include an explicit L2 term.">
          <label for="adam-w-vs-adam-pt1-choice-2">
            <input type="radio" name="adam-w-vs-adam-pt1" id="adam-w-vs-adam-pt1-choice-2" value="Yes, it can still be beneficial" />
            <span>Yes, it can still be beneficial</span>
          </label>
        </li>
      
    
  </ol>

  <div class="markdown-quiz__actions">
    <button class="markdown-quiz__submit" type="button">
      Check answer
    </button>
    <button class="markdown-quiz__reset" type="button">
      Try again
    </button>
  </div>

  <div class="markdown-quiz__feedback" role="status" aria-live="polite"></div>

  
    <div class="markdown-quiz__explanation" hidden="">
      L2 regularization is a broader concept, and whether we include it in the loss function or as part of the optimizer are distinct choices.
    </div>
  
</div>

<p>Here’s another one. Note how we can call the AdamW optimizer in PyTorch:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">optimizer</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">optim</span><span class="p">.</span><span class="n">AdamW</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="n">learning_rate</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="n">this_is_equiv_to_lambda</span><span class="p">)</span>
</code></pre></div></div>

<div class="markdown-quiz" data-quiz-id="adam-w-vs-adam-pt2">
  <span class="markdown-quiz__eyebrow">
    Knowledge check
  </span>

  
    <h3 class="markdown-quiz__question">If I explicitly wanted to use Adam instead of AdamW, could I modify the above PyTorch code to get that (and if so, how)?</h3>
  

  
    <p class="markdown-quiz__prompt">Select the answer</p>
  

  
  <ol class="markdown-quiz__choices">
    
      
      
      
        
        
        
        <li class="markdown-quiz__option" data-correct="false" data-feedback="Revisit the mathematical formulas above">
          <label for="adam-w-vs-adam-pt2-choice-1">
            <input type="radio" name="adam-w-vs-adam-pt2" id="adam-w-vs-adam-pt2-choice-1" value="trick question! you can&#39;t" />
            <span>trick question! you can't</span>
          </label>
        </li>
      
    
      
      
      
        
        
        
        <li class="markdown-quiz__option" data-correct="false" data-feedback="Nope! No other parameters">
          <label for="adam-w-vs-adam-pt2-choice-2">
            <input type="radio" name="adam-w-vs-adam-pt2" id="adam-w-vs-adam-pt2-choice-2" value="yes, there&#39;s another AdamW parameter in pytorch that enables vanilla Adam" />
            <span>yes, there's another AdamW parameter in pytorch that enables vanilla Adam</span>
          </label>
        </li>
      
    
      
      
      
        
        
        
        <li class="markdown-quiz__option" data-correct="true" data-feedback="exactly, if we drop our lambda to 0 then the math falls out and Adam is equiv to AdamW">
          <label for="adam-w-vs-adam-pt2-choice-3">
            <input type="radio" name="adam-w-vs-adam-pt2" id="adam-w-vs-adam-pt2-choice-3" value="yes, we can change the weight_decay" />
            <span>yes, we can change the weight_decay</span>
          </label>
        </li>
      
    
  </ol>

  <div class="markdown-quiz__actions">
    <button class="markdown-quiz__submit" type="button">
      Check answer
    </button>
    <button class="markdown-quiz__reset" type="button">
      Try again
    </button>
  </div>

  <div class="markdown-quiz__feedback" role="status" aria-live="polite"></div>

  
    <div class="markdown-quiz__explanation" hidden="">
      While subtle, decoupling the weight-decay penalty from the adaptive scaling has resulted in massive wins for the AdamW optimizer
    </div>
  
</div>

<h3 id="viz-1">Viz</h3>

<div class="interactive-python">
<pre><code class="language-python">
import numpy as np
import matplotlib.pyplot as plt
from itertools import product
from sympy import Matrix
from IPython.display import display

def styblinski_tang_fn(x: float, y: float) -&gt; float:
    return 0.5 * ((x**4 - 16 * x**2 + 5 * x) + (y**4 - 16 * y**2 + 5 * y))

def styblinski_tang_grad(x: float, y: float) -&gt; np.ndarray:
    dfx = 2 * x**3 - 16 * x + 2.5
    dfy = 2 * y**3 - 16 * y + 2.5
    return np.array([dfx, dfy], dtype=float)

def stationary_points_and_global_min():
    roots = np.roots([2.0, 0.0, -16.0, 2.5])
    roots = np.real(roots[np.isreal(roots)])
    minima_1d = [r for r in roots if (6 * r * r - 16) &gt; 0]
    mins2d = np.array(list(product(minima_1d, repeat=2)), dtype=float)
    vals = np.array([styblinski_tang_fn(x, y) for x, y in mins2d])
    gidx = np.argmin(vals)
    return mins2d, mins2d[gidx], vals[gidx]

def run_sgd(theta0, eta=0.02, steps=1200):
    theta = np.array(theta0, float)
    path = [theta.copy()]
    for _ in range(steps):
        theta -= eta * styblinski_tang_grad(*theta)
        path.append(theta.copy())
    return np.array(path)

def run_momentum(theta0, eta=0.02, beta=0.90, steps=1200):
    theta = np.array(theta0, float)
    v = np.zeros_like(theta)
    path = [theta.copy()]
    for _ in range(steps):
        g = styblinski_tang_grad(*theta)
        v = beta * v - eta * g
        theta = theta + v
        path.append(theta.copy())
    return np.array(path)

def run_adagrad(theta0, eta=0.40, eps=1e-8, steps=1200):
    theta = np.array(theta0, float)
    r = np.zeros_like(theta)
    path = [theta.copy()]
    for _ in range(steps):
        g = styblinski_tang_grad(*theta)
        r = r + g * g
        lr_eff = eta / (np.sqrt(r) + eps)
        theta = theta - lr_eff * g
        path.append(theta.copy())
    return np.array(path)

def run_rmsprop(theta0, eta=1e-2, rho=0.9, eps=1e-8, steps=1200):
    """
    s_t = rho * s_{t-1} + (1 - rho) * g_t^2
    theta &lt;- theta - eta * g_t / (sqrt(s_t) + eps)
    """
    theta = np.array(theta0, float)
    s = np.zeros_like(theta)
    path = [theta.copy()]
    for step in range(steps):
        g = styblinski_tang_grad(*theta)
        s = rho * s + (1 - rho) * (g * g)
        theta = theta - eta * g / (np.sqrt(s) + eps)
        path.append(theta.copy())
    return np.array(path)

def run_rmsprop_centered(theta0, eta=1e-2, rho=0.9, eps=1e-8, steps=1200):
    """
    m_t = rho * m_{t-1} + (1 - rho) * g_t
    s_t = rho * s_{t-1} + (1 - rho) * g_t^2
    denom = sqrt(s_t - m_t^2) + eps   # variance-based
    """
    theta = np.array(theta0, float)
    m = np.zeros_like(theta)
    s = np.zeros_like(theta)
    path = [theta.copy()]
    for _ in range(steps):
        g = styblinski_tang_grad(*theta)
        m = rho * m + (1 - rho) * g
        s = rho * s + (1 - rho) * (g * g)
        denom = np.sqrt(np.maximum(s - m * m, 0.0)) + eps
        theta = theta - eta * g / denom
        path.append(theta.copy())
    return np.array(path)

def run_adam(theta0, eta=1e-2, beta1=0.9, beta2=0.999, eps=1e-8, steps=1200):
    theta = np.array(theta0, float)
    m = np.zeros_like(theta)
    v = np.zeros_like(theta)
    path = [theta.copy()]
    for t in range(1, steps + 1):
        g = styblinski_tang_grad(*theta)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * (g * g)
        m_hat = m / (1 - beta1**t)
        v_hat = v / (1 - beta2**t)
        theta = theta - eta * m_hat / (np.sqrt(v_hat) + eps)
        path.append(theta.copy())
    return np.array(path)

def run_adamw(theta0, eta=1e-2, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01, steps=1200):
    """
    AdamW: decoupled weight decay
      theta &lt;- theta - eta * ( m_hat / (sqrt(v_hat)+eps) )  # adaptive step
      theta &lt;- theta - eta * weight_decay * theta           # uniform shrink
    Note: setting weight_decay=0.0 makes AdamW identical to Adam.
    """
    theta = np.array(theta0, float)
    m = np.zeros_like(theta)
    v = np.zeros_like(theta)
    path = [theta.copy()]
    for t in range(1, steps + 1):
        g = styblinski_tang_grad(*theta)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * (g * g)
        m_hat = m / (1 - beta1**t)
        v_hat = v / (1 - beta2**t)

        # adaptive update
        theta = theta - eta * (m_hat / (np.sqrt(v_hat) + eps))
        # decoupled weight decay (uniform; not scaled by v_hat)
        theta = theta - eta * weight_decay * theta

        path.append(theta.copy())
    return np.array(path)

"""----- params -----"""
theta_start = np.array([4.1, 4.5], dtype=float)
steps = 1200

eta_sgd = 0.02
eta_mom, beta = 0.02, 0.90
eta_adagrad = 0.40
eta_rms, rho, eps = 1e-2, 0.9, 1e-8
eta_rms_c = 1e-2

"""Adam / AdamW hyperparams"""
eta_adam = 1e-2
beta1, beta2 = 0.9, 0.999
eps_adam = 1e-8
wd = 1e-2     # try 0.0 (Adam-equivalent) vs 1e-3 vs 1e-2

"""----- runs -----"""
sgd_path  = run_sgd(theta_start, eta=eta_sgd, steps=steps)
mom_path  = run_momentum(theta_start, eta=eta_mom, beta=beta, steps=steps)
ada_path  = run_adagrad(theta_start, eta=eta_adagrad, steps=steps)
rms_path  = run_rmsprop(theta_start, eta=eta_rms, rho=rho, eps=eps, steps=steps)
rmsc_path = run_rmsprop_centered(theta_start, eta=eta_rms_c, rho=rho, eps=eps, steps=steps)
adam_path = run_adam(theta_start, eta=eta_adam, beta1=beta1, beta2=beta2, eps=eps_adam, steps=steps)
adamw_path = run_adamw(theta_start, eta=eta_adam, beta1=beta1, beta2=beta2, eps=eps_adam, weight_decay=wd, steps=steps)

mins2d, gmin_pt, gmin_val = stationary_points_and_global_min()

"""----- viz -----"""
x = y = np.linspace(-5, 5, 400)
X, Y = np.meshgrid(x, y)
Z = styblinski_tang_fn(X, Y)

plt.figure(figsize=(9, 8))
cs = plt.contour(X, Y, Z, levels=50, alpha=0.85)
plt.clabel(cs, inline=True, fmt="%.0f", fontsize=7)

plt.plot(sgd_path[:, 0],   sgd_path[:, 1],   '.-', lw=1.2, ms=3, label='SGD')
plt.plot(mom_path[:, 0],   mom_path[:, 1],   '.-', lw=1.2, ms=3, label=f'Momentum (β={beta})')
plt.plot(ada_path[:, 0],   ada_path[:, 1],   '.-', lw=1.2, ms=3, label='AdaGrad')
plt.plot(rms_path[:, 0],   rms_path[:, 1],   '.-', lw=1.2, ms=3, label=f'RMSProp (ρ={rho})')
plt.plot(rmsc_path[:, 0],  rmsc_path[:, 1],  '.-', lw=1.2, ms=3, label='RMSProp (centered)')
plt.plot(adam_path[:, 0],  adam_path[:, 1],  '.-', lw=1.2, ms=3, label='Adam')
plt.plot(adamw_path[:, 0], adamw_path[:, 1], '.-', lw=1.2, ms=3, label=f'AdamW (wd={wd})')

plt.scatter(sgd_path[0, 0], sgd_path[0, 1], s=80, label='Start', zorder=3)
plt.scatter(sgd_path[-1, 0], sgd_path[-1, 1], s=60, label='SGD End', zorder=3)
plt.scatter(mom_path[-1, 0], mom_path[-1, 1], s=60, label='Momentum End', zorder=3)
plt.scatter(ada_path[-1, 0], ada_path[-1, 1], s=60, label='AdaGrad End', zorder=3)
plt.scatter(rms_path[-1, 0], rms_path[-1, 1], s=60, label='RMSProp End', zorder=3)
plt.scatter(rmsc_path[-1, 0], rmsc_path[-1, 1], s=60, label='RMSProp (centered) End', zorder=3)
plt.scatter(adam_path[-1, 0], adam_path[-1, 1], s=60, label='Adam End', zorder=3)
plt.scatter(adamw_path[-1, 0], adamw_path[-1, 1], s=60, label='AdamW End', zorder=3)

plt.scatter(gmin_pt[0], gmin_pt[1], marker='*', s=220, edgecolor='k',
            facecolor='gold', label=f'Global min ({gmin_pt[0]:.4f}, {gmin_pt[1]:.4f})\n f={gmin_val:.4f}', zorder=5)

vals = np.array([styblinski_tang_fn(x0, y0) for x0, y0 in mins2d])
mask = np.ones(len(mins2d), dtype=bool)
mask[np.argmin(vals)] = False
if np.any(mask):
    plt.scatter(mins2d[mask, 0], mins2d[mask, 1],
                marker='v', s=120, edgecolor='k', facecolor='white',
                label='Local minima', zorder=4)

plt.title("SGD, Momentum, AdaGrad, RMSProp (+centered), Adam, AdamW on Styblinski–Tang")
plt.xlabel("x"); plt.ylabel("y")
plt.legend(loc='upper left')
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()
</code></pre>

</div>

<p>Before we get to Muon, note that this example is almost entirely contrived. The Styblinski–Tang function is a good example of a loss surface where it’s easy to get caught in a local minimum and hard to find the global one, but just because <code class="language-plaintext highlighter-rouge">SGD with Momentum</code> happens to find the global minimum in my contrived setup does not mean it generalizes well. In practice, AdamW has been the de facto winner.</p>

<h1 id="muon-momentum-orthogonalized-by-newton-schulz-2025">Muon (MomentUm Orthogonalized by Newton-Schulz) (2025)</h1>

<p>Alright jeez, all of the above was a bit accidental, but I wanted to give you a build-up / very quick run-through of the various optimizers that have evolved. Again, the space is definitely iterative (pun intended), and these optimizers all build off of each other. Muon is no different.</p>

<h2 id="theory">Theory</h2>

<p>The idea is that we’re still working with our momentum matrix. In practice, the momentum matrix tends to become low-rank, which means only a couple of directions dominate the updates.</p>

<p>Muon tries to orthogonalize our momentum matrix. Rare directions are amplified by the orthogonalization. Again, recall from the <a href="https://little-book-of.github.io/linear-algebra/">Little Book of Linear Algebra</a> that this means:</p>

<p>\(\text{Ortho}(M) = \text{argmin}_O \| O - M \|_F\)
where $OO^T = I$ or $O^T O = I$</p>

<p>Ok, while this is hard to compute directly… what do we turn to besides our good friend - the swiss army knife of linalg - SVD (singular value decomposition)?</p>

\[M = U S V^T\]

<p>So we would compute the SVD and then replace $S$ with the identity matrix, i.e. set every singular value to 1, giving $\text{Ortho}(M) = UV^T$.</p>

<p>However, once again, SVD is computationally expensive, so we need a cheaper way to approximate this orthogonalization.</p>
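<p>As a quick numerical sanity check, here’s a minimal NumPy sketch of the exact (expensive) SVD route - the function name <code class="language-plaintext highlighter-rouge">ortho_via_svd</code> is my own, not from any library:</p>

```python
import numpy as np

def ortho_via_svd(M):
    # exact orthogonalization: keep U and V^T, replace the singular values with 1s
    U, S, Vt = np.linalg.svd(M, full_matrices=False)
    return U @ Vt

M = np.random.default_rng(0).normal(size=(4, 3))
O = ortho_via_svd(M)
# M is tall here, so O^T O = I (the columns of O are orthonormal)
print(np.allclose(O.T @ O, np.eye(3)))  # True
```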

<h3 id="odd-polynomial-matrix">Odd Polynomial Matrix</h3>

<p>An odd matrix polynomial, in its simplest form, looks like:</p>

\[\rho (X) = aX + b(X X^T)X\]

<p>so we could do:</p>

\[\rho (M) = aM + b(MM^T)M\]

<p>So let’s go ahead and do some math, substituting in the SVD $M = USV^T$.</p>

\[\begin{aligned}
\rho (M) &amp;= aM + b(MM^T)M  \\
\rho (M) &amp;= (a + b(MM^T))M \\
\rho (M) &amp;= (a + b((USV^T)(VSU^T)))(USV^T) \\
\rho (M) &amp;= (a + b(USV^TVSU^T))(USV^T) \\ \\
&amp;\quad \text{because $V$ is orthonormal, $V^TV = I$} \\ \\
\rho (M) &amp;= (a + b(USSU^T))(USV^T) \\ \\
&amp;\quad \text{and $S$ is diagonal so $SS = S^2$} \\ \\
\rho (M) &amp;= (a + b(US^2U^T))(USV^T) \\
\rho (M) &amp;= a(USV^T) + b(US^2U^TUSV^T) \\ \\ 
&amp;\quad \text{because $U$ is orthonormal, $U^TU = I$} \\ \\
\rho (M) &amp;= a(USV^T) + b(US^2SV^T) \\
\rho (M) &amp;= a(USV^T) + b(US^3V^T) \\ \\
&amp;\quad \text{simplifying gives} \\ \\
\rho (M) &amp;= U(aS + bS^3)V^T
\end{aligned}\]

<p>So… <strong>applying an odd matrix polynomial to $M$ acts directly on its singular values: it’s equivalent to applying the same polynomial to each singular value individually and then reconstructing the matrix with the original $U$ and $V^T$</strong>.</p>

<p>This extends to higher-order odd polynomials, so take it for granted or derive it for yourself:</p>

\[\begin{align}
\rho (M) &amp;= aM + b(MM^T)M + c(MM^T)^2 M \\
\vdots \\
\rho (M) &amp;= U(aS + bS^3 + cS^5)V^T \\
\end{align}\]
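<p>We can verify this claim numerically - a quick sketch (the variable names are mine): the singular values of $\rho(M) = aM + b(MM^T)M$ should be exactly $|as + bs^3|$ for each singular value $s$ of $M$:</p>

```python
import numpy as np

rng = np.random.default_rng(1)
M = rng.normal(size=(3, 3))
a, b = 1.5, -0.5

# rho(M) = aM + b(M M^T) M
rho_M = a * M + b * (M @ M.T) @ M

S = np.linalg.svd(M, compute_uv=False)
# claim: rho(M) = U (aS + bS^3) V^T, so its singular values are |a*s + b*s^3|
expected = np.sort(np.abs(a * S + b * S**3))[::-1]
actual = np.linalg.svd(rho_M, compute_uv=False)  # sorted descending
print(np.allclose(actual, expected))  # True
```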

<p>Again, we want $S$ to end up as the identity (all singular values equal to 1)… So this becomes an optimization problem in itself. We’re trying to pick the coefficients $a, b, c$ so that repeatedly applying the polynomial drives us toward <code class="language-plaintext highlighter-rouge">S = np.eye(S.shape[0])</code>.</p>

<p>So how do we pick the coefficients that will get us there?</p>

<h3 id="newton-schulz-iteration">Newton-Schulz Iteration</h3>

<p>Again, <a href="https://www.youtube.com/watch?v=bO5nvE289ec">this video</a> is fantastic. However, this part was a little too abstracted. We’ll turn back to <code class="language-plaintext highlighter-rouge">manim</code> here for some more helpful visualizations and understanding.</p>

<p>I was going to dive into the derivation here, but it’s interesting enough that I’m going to cover it in another blog post and link it here.</p>

<p>For now, assume that we have these params:</p>

<ul>
  <li>$a = 3.4445$</li>
  <li>$b = -4.7750$</li>
  <li>$c = 2.0315$</li>
</ul>

<p>and those are going to be the parameters of our Newton-Schulz iteration that help us converge to what we consider a valid $S$ for the singular-value part of the SVD - one whose singular values are close-ish to 1.</p>
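<p>To build intuition, here’s a tiny scalar sketch (my own, not from the paper): since the polynomial acts on each singular value independently, we can iterate $f(s) = as + bs^3 + cs^5$ on a single small, normalized singular value and watch it get pushed toward 1. Note these coefficients were tuned for speed, so the iterates hover near 1 rather than converging exactly:</p>

```python
a, b, c = 3.4445, -4.7750, 2.0315

def ns_step(s):
    # the quintic polynomial applied to one singular value
    return a * s + b * s**3 + c * s**5

s = 0.2  # a small singular value (after normalization they live in (0, 1])
for _ in range(5):
    s = ns_step(s)
# after 5 steps, s is much closer to 1 than where it started
print(0.5 < s < 1.5)  # True
```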

<h3 id="overview">Overview</h3>

<p>So now we have:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>for step in steps:
    compute gradient 
    compute momentum
    normalize momentum matrix
    orthogonalization
    update parameters
</code></pre></div></div>

<p>Now there is also a variant of Muon with decoupled weight decay, similar to what we did with AdamW.</p>

<p>So we have:</p>

\[\begin{align}
G_t &amp;\leftarrow \nabla L_t (\theta_{t-1}) \\
M_t &amp;\leftarrow \beta M_{t-1} + G_t \\
M'_t &amp;\leftarrow \frac{M_t}{\| M_t \|_F} \\
O_t &amp;\leftarrow \text{NewtonSchulz5}(M'_t) \\
\theta_t &amp;\leftarrow \theta_{t-1} - \alpha \left(0.2 \sqrt{\text{max}(n,m)} \cdot O_t + \lambda \theta_{t-1}\right)
\end{align}\]

<h2 id="implementation">Implementation</h2>

<p>I actually want to introduce this section by looking at <a href="https://docs.pytorch.org/docs/stable/generated/torch.optim.Muon.html">PyTorch’s documentation</a>. This was added recently, but let’s look here:</p>

<p><img src="/images/understanding-muon/muon-pytorch.png" alt="muon-pytorch" class="center-small lightbox-image" /></p>

<p>This should look super familiar to the code that we’ve been covering!! The only tricky part is the <code class="language-plaintext highlighter-rouge">AdjustLR</code> step, which deviates slightly between what the video / I cover above (Moonshot’s implementation) and Keller Jordan’s original implementation of $\sqrt{\text{max}(1, \frac{B}{A})}$.</p>
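<p>To make that difference concrete, here’s a small side-by-side sketch for a weight matrix of shape $(A, B)$ - the function names are mine, with the Moonshot-style scale matching the <code class="language-plaintext highlighter-rouge">adjust_lr</code> used in the code below, and the other following the original heuristic as I’ve written it above:</p>

```python
import math

def adjust_lr_moonshot(A, B):
    # Moonshot-style scale: 0.2 * sqrt(max(n, m))
    return 0.2 * math.sqrt(max(A, B))

def adjust_lr_original(A, B):
    # Keller Jordan's heuristic as described above: sqrt(max(1, B/A))
    return math.sqrt(max(1.0, B / A))

# for a 1024 x 4096 layer the two scales differ noticeably
print(round(adjust_lr_moonshot(1024, 4096), 2))  # 12.8
print(round(adjust_lr_original(1024, 4096), 2))  # 2.0
```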

<p>There are a couple of tricky parts with implementing this:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">newton_schulz_5</span><span class="p">(</span><span class="n">M_matrix</span><span class="p">,</span> <span class="n">steps</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-7</span><span class="p">):</span>
    <span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span> <span class="o">=</span> <span class="p">(</span><span class="mf">3.4445</span><span class="p">,</span> <span class="o">-</span><span class="mf">4.7750</span><span class="p">,</span> <span class="mf">2.0315</span><span class="p">)</span> <span class="c1"># from Keller Jordan
</span>    <span class="n">X</span> <span class="o">=</span> <span class="n">M_matrix</span><span class="p">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">copy</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span> <span class="c1"># speed up in practice
</span>    
    <span class="k">if</span> <span class="n">X</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&gt;</span> <span class="n">X</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]:</span>
        <span class="n">X</span> <span class="o">=</span> <span class="n">X</span><span class="p">.</span><span class="n">T</span>

    <span class="n">X</span> <span class="o">=</span> <span class="n">X</span> <span class="o">/</span> <span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">norm</span><span class="p">(</span><span class="n">X</span><span class="p">)</span> <span class="o">+</span> <span class="n">eps</span><span class="p">)</span> <span class="c1"># frobenius norm by def
</span>    <span class="c1"># so this is tricky but we're looking here
</span>    <span class="c1"># \rho (M) &amp;= aM + b(MM^T)M + c(MM^T)^2 M
</span>    <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">steps</span><span class="p">):</span>
        <span class="n">A</span> <span class="o">=</span> <span class="n">X</span> <span class="o">@</span> <span class="n">X</span><span class="p">.</span><span class="n">T</span>
        <span class="n">B</span> <span class="o">=</span> <span class="n">b</span> <span class="o">*</span> <span class="n">A</span> <span class="o">+</span> <span class="n">c</span> <span class="o">*</span> <span class="n">A</span> <span class="o">@</span> <span class="n">A</span>
        <span class="n">X</span> <span class="o">=</span> <span class="n">a</span> <span class="o">*</span> <span class="n">X</span> <span class="o">+</span> <span class="n">B</span> <span class="o">@</span> <span class="n">X</span>
    <span class="k">if</span> <span class="n">X</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&gt;</span> <span class="n">X</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]:</span>
        <span class="n">X</span> <span class="o">=</span> <span class="n">X</span><span class="p">.</span><span class="n">T</span>
    <span class="k">return</span> <span class="n">X</span>


<span class="k">def</span> <span class="nf">run_muon_muonshot</span><span class="p">(</span><span class="n">theta0</span><span class="p">,</span> <span class="n">eta</span><span class="o">=</span><span class="mf">1e-2</span><span class="p">,</span> <span class="n">beta</span><span class="o">=</span><span class="mf">0.95</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="mf">1e-2</span><span class="p">,</span> <span class="n">steps</span><span class="o">=</span><span class="mi">1200</span><span class="p">,</span>
             <span class="n">ns_steps</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-7</span><span class="p">,</span> <span class="n">use_nesterov</span><span class="o">=</span><span class="bp">True</span><span class="p">):</span>
    <span class="n">theta</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">theta0</span><span class="p">,</span> <span class="nb">float</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">theta</span><span class="p">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
        <span class="n">theta</span> <span class="o">=</span> <span class="n">theta</span><span class="p">[:,</span> <span class="bp">None</span><span class="p">]</span>
    <span class="k">elif</span> <span class="n">theta</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="n">theta</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
        <span class="n">theta</span> <span class="o">=</span> <span class="n">theta</span><span class="p">.</span><span class="n">T</span>

    <span class="k">def</span> <span class="nf">adjust_lr</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">):</span>
        <span class="k">return</span> <span class="mf">0.2</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="nb">float</span><span class="p">(</span><span class="nb">max</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">)))</span>

    <span class="n">A</span><span class="p">,</span> <span class="n">B</span> <span class="o">=</span> <span class="n">theta</span><span class="p">.</span><span class="n">shape</span>
    <span class="n">path</span> <span class="o">=</span> <span class="p">[</span><span class="n">theta</span><span class="p">.</span><span class="n">copy</span><span class="p">()]</span>
    <span class="n">B_momentum_buffer</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">theta</span><span class="p">)</span>
    <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">steps</span><span class="p">):</span>
        <span class="n">g</span> <span class="o">=</span> <span class="n">styblinski_tang_grad</span><span class="p">(</span><span class="o">*</span><span class="n">theta</span><span class="p">)</span>  
        <span class="n">B_momentum_buffer</span> <span class="o">=</span> <span class="n">beta</span> <span class="o">*</span> <span class="n">B_momentum_buffer</span> <span class="o">+</span> <span class="n">g</span>
        <span class="c1"># didn't cover nesterov but pytorch has it
</span>        <span class="n">M_eff</span> <span class="o">=</span> <span class="n">g</span> <span class="o">+</span> <span class="n">beta</span> <span class="o">*</span> <span class="n">B_momentum_buffer</span> <span class="k">if</span> <span class="n">use_nesterov</span> <span class="k">else</span> <span class="n">B_momentum_buffer</span>
        <span class="n">O</span> <span class="o">=</span> <span class="n">newton_schulz_5</span><span class="p">(</span><span class="n">M_eff</span><span class="p">,</span> <span class="n">steps</span><span class="o">=</span><span class="n">ns_steps</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="n">eps</span><span class="p">)</span>
        <span class="c1"># decoupled weight decay (uniform shrink)
</span>        <span class="n">theta</span> <span class="o">=</span> <span class="n">theta</span> <span class="o">-</span> <span class="n">eta</span> <span class="o">*</span> <span class="p">(</span><span class="n">adjust_lr</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">)</span> <span class="o">*</span> <span class="n">O</span> <span class="o">+</span> <span class="n">weight_decay</span> <span class="o">*</span> <span class="n">theta</span><span class="p">)</span>
        <span class="n">path</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">theta</span><span class="p">.</span><span class="n">copy</span><span class="p">())</span>

    <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">path</span><span class="p">)</span>
</code></pre></div></div>

<p>Once again, here is visualization code:</p>

<div class="interactive-python">
<pre><code class="language-python">
import numpy as np
import matplotlib.pyplot as plt
from itertools import product
from sympy import Matrix
from IPython.display import display

def styblinski_tang_fn(x: float, y: float) -&gt; float:
    return 0.5 * ((x**4 - 16 * x**2 + 5 * x) + (y**4 - 16 * y**2 + 5 * y))

def styblinski_tang_grad(x: float, y: float) -&gt; np.ndarray:
    dfx = 2 * x**3 - 16 * x + 2.5
    dfy = 2 * y**3 - 16 * y + 2.5
    return np.array([dfx, dfy], dtype=float)

def stationary_points_and_global_min():
    roots = np.roots([2.0, 0.0, -16.0, 2.5])
    roots = np.real(roots[np.isreal(roots)])
    minima_1d = [r for r in roots if (6 * r * r - 16) &gt; 0]
    mins2d = np.array(list(product(minima_1d, repeat=2)), dtype=float)
    vals = np.array([styblinski_tang_fn(x, y) for x, y in mins2d])
    gidx = np.argmin(vals)
    return mins2d, mins2d[gidx], vals[gidx]

def run_sgd(theta0, eta=0.02, steps=1200):
    theta = np.array(theta0, float)
    path = [theta.copy()]
    for _ in range(steps):
        theta -= eta * styblinski_tang_grad(*theta)
        path.append(theta.copy())
    return np.array(path)

def run_momentum(theta0, eta=0.02, beta=0.90, steps=1200):
    theta = np.array(theta0, float)
    v = np.zeros_like(theta)
    path = [theta.copy()]
    for _ in range(steps):
        g = styblinski_tang_grad(*theta)
        v = beta * v - eta * g
        theta = theta + v
        path.append(theta.copy())
    return np.array(path)

def run_adagrad(theta0, eta=0.40, eps=1e-8, steps=1200):
    theta = np.array(theta0, float)
    r = np.zeros_like(theta)
    path = [theta.copy()]
    for _ in range(steps):
        g = styblinski_tang_grad(*theta)
        r = r + g * g
        lr_eff = eta / (np.sqrt(r) + eps)
        theta = theta - lr_eff * g
        path.append(theta.copy())
    return np.array(path)

def run_rmsprop(theta0, eta=1e-2, rho=0.9, eps=1e-8, steps=1200):
    """
    s_t = rho * s_{t-1} + (1 - rho) * g_t^2
    theta &lt;- theta - eta * g_t / (sqrt(s_t) + eps)
    """
    theta = np.array(theta0, float)
    s = np.zeros_like(theta)
    path = [theta.copy()]
    for step in range(steps):
        g = styblinski_tang_grad(*theta)
        s = rho * s + (1 - rho) * (g * g)
        theta = theta - eta * g / (np.sqrt(s) + eps)
        path.append(theta.copy())
    return np.array(path)

def run_rmsprop_centered(theta0, eta=1e-2, rho=0.9, eps=1e-8, steps=1200):
    """
    m_t = rho * m_{t-1} + (1 - rho) * g_t
    s_t = rho * s_{t-1} + (1 - rho) * g_t^2
    denom = sqrt(s_t - m_t^2) + eps   # variance-based
    """
    theta = np.array(theta0, float)
    m = np.zeros_like(theta)
    s = np.zeros_like(theta)
    path = [theta.copy()]
    for _ in range(steps):
        g = styblinski_tang_grad(*theta)
        m = rho * m + (1 - rho) * g
        s = rho * s + (1 - rho) * (g * g)
        denom = np.sqrt(np.maximum(s - m * m, 0.0)) + eps
        theta = theta - eta * g / denom
        path.append(theta.copy())
    return np.array(path)

def run_adam(theta0, eta=1e-2, beta1=0.9, beta2=0.999, eps=1e-8, steps=1200):
    theta = np.array(theta0, float)
    m = np.zeros_like(theta)
    v = np.zeros_like(theta)
    path = [theta.copy()]
    for t in range(1, steps + 1):
        g = styblinski_tang_grad(*theta)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * (g * g)
        m_hat = m / (1 - beta1**t)
        v_hat = v / (1 - beta2**t)
        theta = theta - eta * m_hat / (np.sqrt(v_hat) + eps)
        path.append(theta.copy())
    return np.array(path)

def run_adamw(theta0, eta=1e-2, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01, steps=1200):
    """
    AdamW: decoupled weight decay
      theta &lt;- theta - eta * ( m_hat / (sqrt(v_hat)+eps) )  # adaptive step
      theta &lt;- theta - eta * weight_decay * theta           # uniform shrink
    Note: setting weight_decay=0.0 makes AdamW identical to Adam.
    """
    theta = np.array(theta0, float)
    m = np.zeros_like(theta)
    v = np.zeros_like(theta)
    path = [theta.copy()]
    for t in range(1, steps + 1):
        g = styblinski_tang_grad(*theta)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * (g * g)
        m_hat = m / (1 - beta1**t)
        v_hat = v / (1 - beta2**t)

        # adaptive update
        theta = theta - eta * (m_hat / (np.sqrt(v_hat) + eps))
        # decoupled weight decay (uniform; not scaled by v_hat)
        theta = theta - eta * weight_decay * theta

        path.append(theta.copy())
    return np.array(path)

def newton_schulz_5(M_matrix, steps=5, eps=1e-7):
    # from Keller Jordan
    a, b, c = (3.4445, -4.7750, 2.0315)

    # speed up in practice
    X = M_matrix.astype(np.float32, copy=False)

    transposed = False
    if X.shape[0] &gt; X.shape[1]:
        X = X.T
        transposed = True

    # frobenius norm
    X = X / (np.linalg.norm(X) + eps)
    
    # so this is tricky but we're looking here
    # \rho (M) &amp;= aM + b(MM^T)M + c(MM^T)^2 M
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X

    if transposed:
        X = X.T
    return X


def run_muon_muonshot(theta0, eta=1e-2, beta=0.95, weight_decay=1e-2, steps=1200,
             ns_steps=5, eps=1e-7, use_nesterov=True):

    theta = np.array(theta0, float)
    if theta.ndim == 1:
        theta = theta[:, None]          # (n,) -&gt; (n,1)
    elif theta.shape[0] == 1 and theta.shape[1] &gt; 1:
        theta = theta.T                 # (1,n) -&gt; (n,1)

    def adjust_lr(A, B):
        return 0.2 * np.sqrt(float(max(A, B)))

    A, B = theta.shape

    path = [theta.copy()]
    B_momentum_buffer = np.zeros_like(theta)
    for _ in range(steps):
        g = styblinski_tang_grad(*theta)  
        B_momentum_buffer = beta * B_momentum_buffer + g
        # didn't cover nesterov but pytorch has it
        M_eff = g + beta * B_momentum_buffer if use_nesterov else B_momentum_buffer
        O = newton_schulz_5(M_eff, steps=ns_steps, eps=eps)
        # decoupled weight decay (uniform shrink)
        theta = theta - eta * (adjust_lr(A, B) * O + weight_decay * theta)
        path.append(theta.copy())

    return np.array(path)

theta_start = np.array([4.1, 4.5], dtype=float)
steps = 1200

eta_sgd = 0.02
eta_mom, beta = 0.02, 0.90
eta_adagrad = 0.40
eta_rms, rho, eps = 1e-2, 0.9, 1e-8
eta_rms_c = 1e-2

eta_adam = 1e-2
beta1, beta2 = 0.9, 0.999
eps_adam = 1e-8
wd = 1e-2     # try 0.0 (Adam-equivalent) vs 1e-3 vs 1e-2

sgd_path  = run_sgd(theta_start, eta=eta_sgd, steps=steps)
mom_path  = run_momentum(theta_start, eta=eta_mom, beta=beta, steps=steps)
ada_path  = run_adagrad(theta_start, eta=eta_adagrad, steps=steps)
rms_path  = run_rmsprop(theta_start, eta=eta_rms, rho=rho, eps=eps, steps=steps)
rmsc_path = run_rmsprop_centered(theta_start, eta=eta_rms_c, rho=rho, eps=eps, steps=steps)
adam_path = run_adam(theta_start, eta=eta_adam, beta1=beta1, beta2=beta2, eps=eps_adam, steps=steps)
adamw_path = run_adamw(theta_start, eta=eta_adam, beta1=beta1, beta2=beta2, eps=eps_adam, weight_decay=wd, steps=steps)

eta_muon = 1e-2
beta_mu = 0.95
wd_mu = 1e-2
ns_steps = 5
eps_ns = 1e-7
use_nesterov = True

muon_path_raw = run_muon_muonshot(
    theta_start,
    eta=eta_muon,
    beta=beta_mu,
    weight_decay=wd_mu,
    steps=steps,
    ns_steps=ns_steps,
    eps=eps_ns,
    use_nesterov=use_nesterov
)
muon_path = muon_path_raw.squeeze(-1) if muon_path_raw.ndim == 3 else muon_path_raw  # (T,2,1) -&gt; (T,2)

mins2d, gmin_pt, gmin_val = stationary_points_and_global_min()

x = y = np.linspace(-5, 5, 400)
X, Y = np.meshgrid(x, y)
Z = styblinski_tang_fn(X, Y)

plt.figure(figsize=(9, 8))
cs = plt.contour(X, Y, Z, levels=50, alpha=0.85)
plt.clabel(cs, inline=True, fmt="%.0f", fontsize=7)

plt.plot(sgd_path[:, 0],   sgd_path[:, 1],   '.-', lw=1.2, ms=3, label='SGD')
plt.plot(mom_path[:, 0],   mom_path[:, 1],   '.-', lw=1.2, ms=3, label=f'Momentum (β={beta})')
plt.plot(ada_path[:, 0],   ada_path[:, 1],   '.-', lw=1.2, ms=3, label='AdaGrad')
plt.plot(rms_path[:, 0],   rms_path[:, 1],   '.-', lw=1.2, ms=3, label=f'RMSProp (ρ={rho})')
plt.plot(rmsc_path[:, 0],  rmsc_path[:, 1],  '.-', lw=1.2, ms=3, label='RMSProp (centered)')
plt.plot(adam_path[:, 0],  adam_path[:, 1],  '.-', lw=1.2, ms=3, label='Adam')
plt.plot(adamw_path[:, 0], adamw_path[:, 1], '.-', lw=1.2, ms=3, label=f'AdamW (wd={wd})')

plt.plot(muon_path[:, 0],  muon_path[:, 1],  '.-', lw=1.4, ms=3, label=f'Muon (NS={ns_steps}, β={beta_mu})')
plt.scatter(muon_path[-1, 0],  muon_path[-1, 1],  s=60, label='Muon End', zorder=3)

plt.scatter(sgd_path[0, 0], sgd_path[0, 1], s=80, label='Start', zorder=3)
plt.scatter(sgd_path[-1, 0], sgd_path[-1, 1], s=60, label='SGD End', zorder=3)
plt.scatter(mom_path[-1, 0], mom_path[-1, 1], s=60, label='Momentum End', zorder=3)
plt.scatter(ada_path[-1, 0], ada_path[-1, 1], s=60, label='AdaGrad End', zorder=3)
plt.scatter(rms_path[-1, 0], rms_path[-1, 1], s=60, label='RMSProp End', zorder=3)
plt.scatter(rmsc_path[-1, 0], rmsc_path[-1, 1], s=60, label='RMSProp (centered) End', zorder=3)
plt.scatter(adam_path[-1, 0],  adam_path[-1, 1],  s=60, label='Adam End', zorder=3)
plt.scatter(adamw_path[-1, 0], adamw_path[-1, 1], s=60, label='AdamW End', zorder=3)

plt.scatter(gmin_pt[0], gmin_pt[1], marker='*', s=220, edgecolor='k',
            facecolor='gold', label=f'Global min ({gmin_pt[0]:.4f}, {gmin_pt[1]:.4f})\n f={gmin_val:.4f}', zorder=5)

vals = np.array([styblinski_tang_fn(x0, y0) for x0, y0 in mins2d])
mask = np.ones(len(mins2d), dtype=bool)
mask[np.argmin(vals)] = False
if np.any(mask):
    plt.scatter(mins2d[mask, 0], mins2d[mask, 1],
                marker='v', s=120, edgecolor='k', facecolor='white',
                label='Local minima', zorder=4)

plt.title("SGD, Momentum, AdaGrad, RMSProp (+centered), Adam, AdamW, Muon on Styblinski–Tang")
plt.xlabel("x"); plt.ylabel("y")
plt.legend(loc='upper left')
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()

</code></pre>
</div>

<h1 id="conclusion">Conclusion</h1>

<p>Ok! I hope you have learned something. There is obviously a ton more I could write about here, but I think actually writing the code, understanding the paths that we’re taking, and this very detailed step-through is helpful. Muon is very interesting, and while it’s still pretty hotly debated whether it’ll scale (despite Kimi being trained with it at the 1T-parameter scale), more research will certainly go into this area.</p>

<p>I’m hoping to dive more into the Newton-Schulz iteration and build some interesting visualizations there, but as always, this has burned more of my time than I maybe should have allocated.</p>

<p>Once again, <a href="https://github.com/johnlarkin1/understanding-muon">visualization code is here too</a> if you need.</p>]]></content><author><name>johnlarkin1</name></author><category term="Algorithms" /><category term="AI" /><category term="M.L." /><summary type="html"><![CDATA[So while I tried to mainly focus on optimizers, this post kinda splayed out some. It was my first time trying Pyodide and incorporating that logic into my blog. It was my first time using manim, which was exciting because I'm a big fan of the 3Blue1Brown channel. I also introduced quizzes (see AdamW section) for more interactivity. All of this is open source though, so if you have any questions, I'd be flattered if you emailed, but obviously you can just ask ChatGPT / Claude.]]></summary></entry><entry><title type="html">Disjoint Set Union</title><link href="https://johnlarkin1.github.io/2025/disjoint-set-union/" rel="alternate" type="text/html" title="Disjoint Set Union" /><published>2025-10-13T00:00:00+00:00</published><updated>2025-10-13T00:00:00+00:00</updated><id>https://johnlarkin1.github.io/2025/disjoint-set-union</id><content type="html" xml:base="https://johnlarkin1.github.io/2025/disjoint-set-union/"><![CDATA[<p>Recently, I did an interview. I got absolutely flamed and one of the reasons was I wasn’t familiar with a Disjoint Set Union (and I certainly couldn’t complete the C++ interview in time building this data structure naturally).</p>

<p>I figure I would go back to this and do a deeper dive because I wasn’t as familiar with it.</p>

<!--
# Table of Contents

- [Table of Contents](#table-of-contents)
- [Theory](#theory)
  - [What is a Disjoint Set?](#what-is-a-disjoint-set)
    - [Trees](#trees)
    - [Forests](#forests)
- [Example Usage](#example-usage)
  - [Kruskal's Algo for MST](#kruskals-algo-for-mst)
- [The Problem](#the-problem)
  - [Disjoint Set Operations](#disjoint-set-operations)
    - [Creating a new set](#creating-a-new-set)
    - [Find an item's representative](#find-an-items-representative)
    - [Union / merge subsets](#union--merge-subsets)
- [Visualization](#visualization)
- [Optimizations](#optimizations)
  - [Path Compression](#path-compression)
  - [Union by Rank](#union-by-rank)
- [Rust Implementation](#rust-implementation)
-->

<h1 id="theory">Theory</h1>

<h2 id="what-is-a-disjoint-set">What is a Disjoint Set?</h2>

<p><em>Disjoint set</em>, <em>union find</em>, and <em>disjoint set union</em> are all names for the same data structure. It is a data structure optimized for handling various set operations, and it mainly focuses on two methods: <code class="language-plaintext highlighter-rouge">union</code> and <code class="language-plaintext highlighter-rouge">find</code> (hence one of the names).</p>

<p>The whole goal is: <strong>detecting if a member is in a set, and if sets are connected in a fast and performant manner</strong>.</p>

<p>So we’ll partition our elements into subsets, representing each subset as an inverted tree (i.e. all the child nodes point back toward the root, which acts as the subset’s representative).</p>
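<p>As a minimal sketch of that inverted-tree idea (the helper names are mine), the structure is usually just a parent mapping, where a node that is its own parent is the root / representative of its subset:</p>

```python
def make_set(parent, x):
    # each new element starts as the root of its own one-node tree
    parent[x] = x

def find(parent, x):
    # follow parent pointers up to the root (no optimizations yet)
    while parent[x] != x:
        x = parent[x]
    return x

parent = {}
for node in "abc":
    make_set(parent, node)
parent["b"] = "a"  # point b's tree at a, merging the two subsets

print(find(parent, "b"))  # a
print(find(parent, "c"))  # c
```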

<h3 id="trees">Trees</h3>

<p>As a reminder, trees are a specific form of a graph where:</p>

<ul>
  <li>undirected</li>
  <li>at most 1 path between any 2 nodes</li>
  <li>acyclic</li>
</ul>

<p>Rooted trees come in two types: <strong>out-tree</strong> (edges point away from the root) and <strong>in-tree</strong> (edges point toward the root).</p>

<p><img src="/images/disjoint-union-set/tree-types.png" alt="tree-types" class="center-small lightbox-image" /></p>

<p>Out-trees are probably the most common, but we’re going to be focusing on an in-tree.</p>

<h3 id="forests">Forests</h3>

<p><strong>A forest is a collection of trees</strong>. It’s an undirected acyclic graph, where each connected component is a tree. It’s a disjoint union of trees.</p>

<h1 id="example-usage">Example Usage</h1>

<h2 id="kruskals-algo-for-mst">Kruskal’s Algo for MST</h2>

<p>Kruskal’s algorithm is a way of finding a <a href="https://en.wikipedia.org/wiki/Minimum_spanning_tree">minimum spanning tree</a>. In a very basic phrasing, a minimum spanning tree is a subset of the edges of a connected, undirected graph that connects all the nodes with the minimum possible total edge weight. Basically, we want one “connected component” at the lowest cost.</p>

<p>Imagine we’re building a road network that has to reach all of our target cities in the cheapest way possible (cheapest doesn’t always mean best for travel times).</p>

<p>Kruskal’s is basically:</p>

<ol>
  <li>Sort edges (by weight)</li>
  <li>Pick cheapest edge (if no cycle created)</li>
  <li>Continue while MST is not complete</li>
</ol>

<p>This greedy algorithm utilizes a DSU when we need to check whether adding an edge would create a cycle: if <code class="language-plaintext highlighter-rouge">find</code> returns the same representative for both endpoints, they’re already connected and the edge is skipped.</p>

<p>Basically, in very lightweight pseudocode:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">for</span> <span class="n">edge</span> <span class="ow">in</span> <span class="n">sorted_edges</span><span class="p">:</span>
    <span class="k">if</span> <span class="n">find</span><span class="p">(</span><span class="n">edge</span><span class="p">.</span><span class="n">u</span><span class="p">)</span> <span class="o">!=</span> <span class="n">find</span><span class="p">(</span><span class="n">edge</span><span class="p">.</span><span class="n">v</span><span class="p">):</span>
        <span class="n">mst</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">edge</span><span class="p">)</span>
        <span class="n">union</span><span class="p">(</span><span class="n">edge</span><span class="p">.</span><span class="n">u</span><span class="p">,</span> <span class="n">edge</span><span class="p">.</span><span class="n">v</span><span class="p">)</span>
</code></pre></div></div>
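<p>To make that pseudocode concrete, here is a runnable sketch of Kruskal’s (the graph and weights below are made up for illustration, and the DSU inside is the bare naive version with no optimizations):</p>

```python
def kruskal(num_nodes, edges):
    """edges: list of (weight, u, v) tuples; returns the MST edge list."""
    parent = list(range(num_nodes))

    def find(x):
        # naive find: walk up until we hit the root
        while parent[x] != x:
            x = parent[x]
        return x

    mst = []
    for weight, u, v in sorted(edges):
        root_u, root_v = find(u), find(v)
        if root_u != root_v:  # no cycle created -> safe to take this edge
            mst.append((weight, u, v))
            parent[root_v] = root_u  # union
    return mst

# 4 "cities" with weighted roads between them
edges = [(4, 0, 1), (1, 1, 2), (3, 0, 2), (2, 2, 3)]
mst = kruskal(4, edges)
print(mst)  # [(1, 1, 2), (2, 2, 3), (3, 0, 2)] -- 3 edges, total weight 6
```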

<h1 id="the-problem">The Problem</h1>

<p>Again, how do we quickly check whether $x$ and $y$ are in the same subset? The answer is obviously DSUs.</p>

<h2 id="disjoint-set-operations">Disjoint Set Operations</h2>

<p>And so with that, this data structure is going to have:</p>

<ol>
  <li>Create a new set</li>
  <li>Find an item’s set representative (basically like the root of the subset tree)</li>
  <li>Union, merge subsets</li>
</ol>

<h3 id="creating-a-new-set">Creating a new set</h3>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">DisjointSet</span><span class="p">:</span>
  <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="bp">None</span><span class="p">:</span>
    <span class="bp">self</span><span class="p">.</span><span class="n">parent</span> <span class="o">=</span> <span class="p">{}</span>

  <span class="k">def</span> <span class="nf">make_set</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="bp">None</span><span class="p">:</span>
    <span class="bp">self</span><span class="p">.</span><span class="n">parent</span><span class="p">[</span><span class="n">x</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span>
</code></pre></div></div>

<h3 id="find-an-items-representative">Find an item’s representative</h3>

<p>How can we rapidly check whether two elements are in the same subset? This is basically the whole point of the data structure: we climb up the tree until we hit the root, and two elements are in the same subset exactly when they share a root.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">DisjointSet</span><span class="p">:</span>
  <span class="k">def</span> <span class="nf">find</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
    <span class="k">if</span> <span class="bp">self</span><span class="p">.</span><span class="n">parent</span><span class="p">.</span><span class="n">get</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">==</span> <span class="n">x</span><span class="p">:</span>
      <span class="k">return</span> <span class="n">x</span>
    <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">find</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">parent</span><span class="p">[</span><span class="n">x</span><span class="p">])</span>
</code></pre></div></div>

<h3 id="union--merge-subsets">Union / merge subsets</h3>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">DisjointSet</span><span class="p">:</span>
  <span class="k">def</span> <span class="nf">union</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="bp">None</span><span class="p">:</span>
    <span class="n">root_x</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">find</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
    <span class="n">root_y</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">find</span><span class="p">(</span><span class="n">y</span><span class="p">)</span>

    <span class="c1"># only merge if x and y are not in the same set
</span>    <span class="k">if</span> <span class="n">root_x</span> <span class="o">!=</span> <span class="n">root_y</span><span class="p">:</span>
      <span class="bp">self</span><span class="p">.</span><span class="n">parent</span><span class="p">[</span><span class="n">root_y</span><span class="p">]</span> <span class="o">=</span> <span class="n">root_x</span>

</code></pre></div></div>
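<p>Putting the three operations together (a small hypothetical example; the class is repeated here so the snippet is self-contained):</p>

```python
class DisjointSet:
    def __init__(self) -> None:
        self.parent = {}

    def make_set(self, x: int) -> None:
        self.parent[x] = x

    def find(self, x: int) -> int:
        if self.parent[x] == x:
            return x
        return self.find(self.parent[x])

    def union(self, x: int, y: int) -> None:
        root_x, root_y = self.find(x), self.find(y)
        # only merge if x and y are not in the same set
        if root_x != root_y:
            self.parent[root_y] = root_x

ds = DisjointSet()
for i in range(5):
    ds.make_set(i)
ds.union(0, 1)
ds.union(1, 2)
print(ds.find(2) == ds.find(0))  # True: {0, 1, 2} share a representative
print(ds.find(3) == ds.find(4))  # False: 3 and 4 are still singleton sets
```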

<p>Basically we are just stitching these subsets together: we point the root of $y$’s tree at the root of $x$’s tree, and the two trees become one.</p>

<h1 id="visualization">Visualization</h1>

<p>I thought about having Claude spin up a visualizer, but it didn’t seem worth it. There are lots of good resources; the best I’ve seen is <a href="https://visualgo.net/en/ufds">here at visualgo</a>, and there are visualization slides on the DSU <a href="https://visualgo.net/en/ufds?slide=1">here</a>.</p>

<h1 id="optimizations">Optimizations</h1>

<p>There are two big optimizations that people generally hammer home for DSUs: <strong>path compression</strong> and <strong>union by rank</strong>.</p>

<h2 id="path-compression">Path Compression</h2>

<p>So this is a neat trick that is invoked on the <code class="language-plaintext highlighter-rouge">find</code> call. When we’re climbing up the tree to the root, we “flatten” the tree along the way by making each visited node point directly to the root. That way, the next <code class="language-plaintext highlighter-rouge">find</code> on any of those nodes takes $\mathcal{O}(1)$ time. In pseudocode,</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># original
</span><span class="k">def</span> <span class="nf">find</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
  <span class="k">if</span> <span class="bp">self</span><span class="p">.</span><span class="n">parent</span><span class="p">[</span><span class="n">x</span><span class="p">]</span> <span class="o">==</span> <span class="n">x</span><span class="p">:</span>
    <span class="k">return</span> <span class="n">x</span>
  <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">find</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">parent</span><span class="p">[</span><span class="n">x</span><span class="p">])</span>

<span class="c1"># with path compression
</span><span class="k">def</span> <span class="nf">find</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
  <span class="k">if</span> <span class="bp">self</span><span class="p">.</span><span class="n">parent</span><span class="p">[</span><span class="n">x</span><span class="p">]</span> <span class="o">!=</span> <span class="n">x</span><span class="p">:</span>
    <span class="bp">self</span><span class="p">.</span><span class="n">parent</span><span class="p">[</span><span class="n">x</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">find</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">parent</span><span class="p">[</span><span class="n">x</span><span class="p">])</span>
  <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">parent</span><span class="p">[</span><span class="n">x</span><span class="p">]</span>
</code></pre></div></div>

<p>This helps keep our tree flat and wide. So for example <code class="language-plaintext highlighter-rouge">find(5)</code> would potentially take 5 recursive calls if we had <code class="language-plaintext highlighter-rouge">5 -&gt; 4 -&gt; 3 -&gt; 2 -&gt; 1</code>, but the next <code class="language-plaintext highlighter-rouge">find(4)</code> would be $\mathcal{O}(1)$.</p>
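<p>You can watch the flattening happen with a tiny contrived chain (the values here are just illustrative):</p>

```python
# a degenerate chain rooted at 1: 5 -> 4 -> 3 -> 2 -> 1
parent = {1: 1, 2: 1, 3: 2, 4: 3, 5: 4}

def find(x):
    if parent[x] != x:
        parent[x] = find(parent[x])  # point each visited node at the root
    return parent[x]

root = find(5)  # walks the whole chain once...
print(root)     # 1
print(parent)   # ...and now every node points directly at 1
```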

<h2 id="union-by-rank">Union by Rank</h2>

<p>This is another cool trick. When we <code class="language-plaintext highlighter-rouge">union</code>, we attach the smaller tree under the larger one. That, once again, keeps the trees shallow so that our <code class="language-plaintext highlighter-rouge">find</code> operations stay fast.</p>

<p>To do this, we keep track of <code class="language-plaintext highlighter-rouge">rank</code> - a measure of the tree’s height. When performing <code class="language-plaintext highlighter-rouge">union</code>, we compare ranks and attach the root with the smaller rank under the root with the larger rank.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">DisjointSet</span><span class="p">:</span>
  <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="bp">None</span><span class="p">:</span>
    <span class="bp">self</span><span class="p">.</span><span class="n">parent</span> <span class="o">=</span> <span class="p">{}</span>
    <span class="bp">self</span><span class="p">.</span><span class="n">rank</span> <span class="o">=</span> <span class="p">{}</span>

  <span class="k">def</span> <span class="nf">make_set</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="bp">None</span><span class="p">:</span>
    <span class="bp">self</span><span class="p">.</span><span class="n">parent</span><span class="p">[</span><span class="n">x</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span>
    <span class="bp">self</span><span class="p">.</span><span class="n">rank</span><span class="p">[</span><span class="n">x</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>

  <span class="c1"># find...
</span>  <span class="k">def</span> <span class="nf">union</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="bp">None</span><span class="p">:</span>
    <span class="n">root_x</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">find</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
    <span class="n">root_y</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">find</span><span class="p">(</span><span class="n">y</span><span class="p">)</span>

    <span class="k">if</span> <span class="n">root_x</span> <span class="o">==</span> <span class="n">root_y</span><span class="p">:</span>
      <span class="k">return</span>

    <span class="k">if</span> <span class="bp">self</span><span class="p">.</span><span class="n">rank</span><span class="p">[</span><span class="n">root_x</span><span class="p">]</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="p">.</span><span class="n">rank</span><span class="p">[</span><span class="n">root_y</span><span class="p">]:</span>
      <span class="bp">self</span><span class="p">.</span><span class="n">parent</span><span class="p">[</span><span class="n">root_x</span><span class="p">]</span> <span class="o">=</span> <span class="n">root_y</span>
    <span class="k">elif</span> <span class="bp">self</span><span class="p">.</span><span class="n">rank</span><span class="p">[</span><span class="n">root_x</span><span class="p">]</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="p">.</span><span class="n">rank</span><span class="p">[</span><span class="n">root_y</span><span class="p">]:</span>
      <span class="bp">self</span><span class="p">.</span><span class="n">parent</span><span class="p">[</span><span class="n">root_y</span><span class="p">]</span> <span class="o">=</span> <span class="n">root_x</span>
    <span class="k">else</span><span class="p">:</span>
      <span class="bp">self</span><span class="p">.</span><span class="n">parent</span><span class="p">[</span><span class="n">root_y</span><span class="p">]</span> <span class="o">=</span> <span class="n">root_x</span>
      <span class="bp">self</span><span class="p">.</span><span class="n">rank</span><span class="p">[</span><span class="n">root_x</span><span class="p">]</span> <span class="o">+=</span> <span class="mi">1</span>
</code></pre></div></div>

<p>A super interesting note here though: according to most references, rank is only used to make merging decisions, so it doesn’t <strong>need</strong> to be an exact height. I thought it was weird at first that we don’t bump the rank in the <code class="language-plaintext highlighter-rouge">if</code> / <code class="language-plaintext highlighter-rouge">elif</code> branches, but that’s because we attach the shorter tree directly to the taller root, so the merged tree’s height is still bounded by the bigger rank. This is a very important point: rank is loosely tracked, a rough heuristic that upper-bounds the height (especially once path compression starts shrinking trees beneath it).</p>
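<p>A quick sanity check of that point (a minimal sketch of the dict-based version above): after path compression shrinks the tree, the stored rank can overshoot the real height, and the structure still works fine because rank is only ever compared between roots.</p>

```python
parent, rank = {}, {}

def make_set(x):
    parent[x], rank[x] = x, 0

def find(x):
    if parent[x] != x:
        parent[x] = find(parent[x])  # path compression
    return parent[x]

def union(x, y):
    root_x, root_y = find(x), find(y)
    if root_x == root_y:
        return
    if rank[root_x] < rank[root_y]:
        parent[root_x] = root_y
    elif rank[root_x] > rank[root_y]:
        parent[root_y] = root_x
    else:
        parent[root_y] = root_x
        rank[root_x] += 1

for i in range(4):
    make_set(i)
union(0, 1)  # equal ranks -> 0 becomes root, rank[0] = 1
union(2, 3)  # equal ranks -> 2 becomes root, rank[2] = 1
union(0, 2)  # equal ranks -> 0 becomes root, rank[0] = 2
find(3)      # compression: 3 now points straight at 0
print(rank[0])  # 2 -- still an upper bound, though the tree is now height 1
```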

<h1 id="rust-implementation">Rust Implementation</h1>

<p>I’ve been trying to learn more Rust given it’s everyone’s favorite programming language. So I wanted to build this up again in Rust instead of Python, both for practice and for tighter memory management. The code is pretty readable and clean (somewhat similar to the Python), so I won’t describe too much else. I also track the size of each set for debugging.</p>

<div class="language-rust highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nd">#[derive(Debug,</span> <span class="nd">Clone)]</span>
<span class="k">pub</span> <span class="k">struct</span> <span class="n">DisjointSetUnion</span> <span class="p">{</span>
    <span class="n">parent</span><span class="p">:</span> <span class="nb">Vec</span><span class="o">&lt;</span><span class="nb">usize</span><span class="o">&gt;</span><span class="p">,</span>
    <span class="n">rank</span><span class="p">:</span> <span class="nb">Vec</span><span class="o">&lt;</span><span class="nb">usize</span><span class="o">&gt;</span><span class="p">,</span>
    <span class="n">size</span><span class="p">:</span> <span class="nb">Vec</span><span class="o">&lt;</span><span class="nb">usize</span><span class="o">&gt;</span><span class="p">,</span>
    <span class="n">sets</span><span class="p">:</span> <span class="nb">usize</span><span class="p">,</span>
<span class="p">}</span>

<span class="k">impl</span> <span class="n">DisjointSetUnion</span> <span class="p">{</span>
    <span class="cd">/// create a new disjoint set union with n elements</span>
    <span class="k">pub</span> <span class="k">fn</span> <span class="nf">new</span><span class="p">(</span><span class="n">n</span><span class="p">:</span> <span class="nb">usize</span><span class="p">)</span> <span class="k">-&gt;</span> <span class="k">Self</span> <span class="p">{</span>
        <span class="k">Self</span> <span class="p">{</span>
            <span class="n">parent</span><span class="p">:</span> <span class="p">(</span><span class="mi">0</span><span class="o">..</span><span class="n">n</span><span class="p">)</span><span class="nf">.collect</span><span class="p">(),</span>
            <span class="n">rank</span><span class="p">:</span> <span class="nd">vec!</span><span class="p">[</span><span class="mi">0</span><span class="p">;</span> <span class="n">n</span><span class="p">],</span>
            <span class="n">size</span><span class="p">:</span> <span class="nd">vec!</span><span class="p">[</span><span class="mi">1</span><span class="p">;</span> <span class="n">n</span><span class="p">],</span>
            <span class="n">sets</span><span class="p">:</span> <span class="n">n</span><span class="p">,</span>
        <span class="p">}</span>
    <span class="p">}</span>

    <span class="cd">/// number of disjoint sets</span>
    <span class="k">pub</span> <span class="k">fn</span> <span class="nf">num_disjoint_sets</span><span class="p">(</span><span class="o">&amp;</span><span class="k">self</span><span class="p">)</span> <span class="k">-&gt;</span> <span class="nb">usize</span> <span class="p">{</span>
        <span class="k">self</span><span class="py">.sets</span>
    <span class="p">}</span>

    <span class="cd">/// find the root of the set containing x</span>
    <span class="k">pub</span> <span class="k">fn</span> <span class="nf">find</span><span class="p">(</span><span class="o">&amp;</span><span class="k">mut</span> <span class="k">self</span><span class="p">,</span> <span class="k">mut</span> <span class="n">x</span><span class="p">:</span> <span class="nb">usize</span><span class="p">)</span> <span class="k">-&gt;</span> <span class="nb">usize</span> <span class="p">{</span>
        <span class="k">while</span> <span class="k">self</span><span class="py">.parent</span><span class="p">[</span><span class="n">x</span><span class="p">]</span> <span class="o">!=</span> <span class="n">x</span> <span class="p">{</span>
            <span class="k">let</span> <span class="n">parent</span> <span class="o">=</span> <span class="k">self</span><span class="py">.parent</span><span class="p">[</span><span class="n">x</span><span class="p">];</span>
            <span class="k">self</span><span class="py">.parent</span><span class="p">[</span><span class="n">x</span><span class="p">]</span> <span class="o">=</span> <span class="k">self</span><span class="py">.parent</span><span class="p">[</span><span class="n">parent</span><span class="p">];</span>
            <span class="n">x</span> <span class="o">=</span> <span class="k">self</span><span class="py">.parent</span><span class="p">[</span><span class="n">parent</span><span class="p">];</span>
        <span class="p">}</span>
        <span class="n">x</span>
    <span class="p">}</span>

    <span class="cd">/// union the sets containing x and y</span>
    <span class="k">pub</span> <span class="k">fn</span> <span class="k">union</span><span class="p">(</span><span class="o">&amp;</span><span class="k">mut</span> <span class="k">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="nb">usize</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="nb">usize</span><span class="p">)</span> <span class="k">-&gt;</span> <span class="nb">usize</span> <span class="p">{</span>
        <span class="k">let</span> <span class="n">root_x</span> <span class="o">=</span> <span class="k">self</span><span class="nf">.find</span><span class="p">(</span><span class="n">x</span><span class="p">);</span>
        <span class="k">let</span> <span class="n">root_y</span> <span class="o">=</span> <span class="k">self</span><span class="nf">.find</span><span class="p">(</span><span class="n">y</span><span class="p">);</span>

        <span class="c1">// same component</span>
        <span class="k">if</span> <span class="n">root_x</span> <span class="o">==</span> <span class="n">root_y</span> <span class="p">{</span>
            <span class="k">return</span> <span class="n">root_x</span><span class="p">;</span>
        <span class="p">}</span>

        <span class="c1">// we want smaller rank tree under higher rank tree</span>
        <span class="c1">// to try and keep things as flat as possible</span>
        <span class="k">if</span> <span class="k">self</span><span class="py">.rank</span><span class="p">[</span><span class="n">root_x</span><span class="p">]</span> <span class="o">&lt;</span> <span class="k">self</span><span class="py">.rank</span><span class="p">[</span><span class="n">root_y</span><span class="p">]</span> <span class="p">{</span>
            <span class="k">self</span><span class="py">.parent</span><span class="p">[</span><span class="n">root_x</span><span class="p">]</span> <span class="o">=</span> <span class="n">root_y</span><span class="p">;</span>
            <span class="k">self</span><span class="py">.size</span><span class="p">[</span><span class="n">root_y</span><span class="p">]</span> <span class="o">+=</span> <span class="k">self</span><span class="py">.size</span><span class="p">[</span><span class="n">root_x</span><span class="p">];</span>
            <span class="k">self</span><span class="py">.sets</span> <span class="o">-=</span> <span class="mi">1</span><span class="p">;</span>
            <span class="k">return</span> <span class="n">root_y</span><span class="p">;</span>
        <span class="p">}</span> <span class="k">else</span> <span class="k">if</span> <span class="k">self</span><span class="py">.rank</span><span class="p">[</span><span class="n">root_x</span><span class="p">]</span> <span class="o">&gt;</span> <span class="k">self</span><span class="py">.rank</span><span class="p">[</span><span class="n">root_y</span><span class="p">]</span> <span class="p">{</span>
            <span class="k">self</span><span class="py">.parent</span><span class="p">[</span><span class="n">root_y</span><span class="p">]</span> <span class="o">=</span> <span class="n">root_x</span><span class="p">;</span>
            <span class="k">self</span><span class="py">.size</span><span class="p">[</span><span class="n">root_x</span><span class="p">]</span> <span class="o">+=</span> <span class="k">self</span><span class="py">.size</span><span class="p">[</span><span class="n">root_y</span><span class="p">];</span>
            <span class="k">self</span><span class="py">.sets</span> <span class="o">-=</span> <span class="mi">1</span><span class="p">;</span>
            <span class="k">return</span> <span class="n">root_x</span><span class="p">;</span>
        <span class="p">}</span>

        <span class="c1">// otherwise, they're equal</span>
        <span class="k">self</span><span class="py">.parent</span><span class="p">[</span><span class="n">root_y</span><span class="p">]</span> <span class="o">=</span> <span class="n">root_x</span><span class="p">;</span>
        <span class="k">self</span><span class="py">.rank</span><span class="p">[</span><span class="n">root_x</span><span class="p">]</span> <span class="o">+=</span> <span class="mi">1</span><span class="p">;</span>
        <span class="k">self</span><span class="py">.size</span><span class="p">[</span><span class="n">root_x</span><span class="p">]</span> <span class="o">+=</span> <span class="k">self</span><span class="py">.size</span><span class="p">[</span><span class="n">root_y</span><span class="p">];</span>
        <span class="k">self</span><span class="py">.sets</span> <span class="o">-=</span> <span class="mi">1</span><span class="p">;</span>
        <span class="k">return</span> <span class="n">root_x</span><span class="p">;</span>
    <span class="p">}</span>

    <span class="cd">/// check if x and y are in the same set</span>
    <span class="k">pub</span> <span class="k">fn</span> <span class="nf">connected</span><span class="p">(</span><span class="o">&amp;</span><span class="k">mut</span> <span class="k">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="nb">usize</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="nb">usize</span><span class="p">)</span> <span class="k">-&gt;</span> <span class="nb">bool</span> <span class="p">{</span>
        <span class="k">self</span><span class="nf">.find</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">==</span> <span class="k">self</span><span class="nf">.find</span><span class="p">(</span><span class="n">y</span><span class="p">)</span>
    <span class="p">}</span>

    <span class="cd">/// size of the set containing x</span>
    <span class="k">pub</span> <span class="k">fn</span> <span class="nf">size_of</span><span class="p">(</span><span class="o">&amp;</span><span class="k">mut</span> <span class="k">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="nb">usize</span><span class="p">)</span> <span class="k">-&gt;</span> <span class="nb">usize</span> <span class="p">{</span>
        <span class="c1">// ugh perils of Rust</span>
        <span class="c1">// i wanted to do: self.size[self.find(x)]</span>
        <span class="c1">// but because the borrow checker we cannot</span>
        <span class="c1">// indexing into self.size immutably borrows self.size</span>
        <span class="c1">// and thus self for the duration of the indexing expression</span>
        <span class="c1">// as a result, when we do self.find we need a MUTABLE borrow</span>
        <span class="c1">// of self - so this conflict causes the break</span>
        <span class="k">let</span> <span class="n">root</span> <span class="o">=</span> <span class="k">self</span><span class="nf">.find</span><span class="p">(</span><span class="n">x</span><span class="p">);</span>
        <span class="k">self</span><span class="py">.size</span><span class="p">[</span><span class="n">root</span><span class="p">]</span>
    <span class="p">}</span>

    <span class="k">pub</span> <span class="k">fn</span> <span class="nf">rank_of</span><span class="p">(</span><span class="o">&amp;</span><span class="k">mut</span> <span class="k">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="nb">usize</span><span class="p">)</span> <span class="k">-&gt;</span> <span class="nb">usize</span> <span class="p">{</span>
        <span class="k">let</span> <span class="n">root</span> <span class="o">=</span> <span class="k">self</span><span class="nf">.find</span><span class="p">(</span><span class="n">x</span><span class="p">);</span>
        <span class="k">self</span><span class="py">.rank</span><span class="p">[</span><span class="n">root</span><span class="p">]</span>
    <span class="p">}</span>

    <span class="c1">// Claude added these</span>
    <span class="cd">/// (solely for viz) - reference to the parent array</span>
    <span class="k">pub</span> <span class="k">fn</span> <span class="nf">parent</span><span class="p">(</span><span class="o">&amp;</span><span class="k">self</span><span class="p">)</span> <span class="k">-&gt;</span> <span class="o">&amp;</span><span class="p">[</span><span class="nb">usize</span><span class="p">]</span> <span class="p">{</span>
        <span class="o">&amp;</span><span class="k">self</span><span class="py">.parent</span>
    <span class="p">}</span>

    <span class="cd">/// (solely for viz) - reference to the rank array</span>
    <span class="k">pub</span> <span class="k">fn</span> <span class="nf">rank</span><span class="p">(</span><span class="o">&amp;</span><span class="k">self</span><span class="p">)</span> <span class="k">-&gt;</span> <span class="o">&amp;</span><span class="p">[</span><span class="nb">usize</span><span class="p">]</span> <span class="p">{</span>
        <span class="o">&amp;</span><span class="k">self</span><span class="py">.rank</span>
    <span class="p">}</span>

    <span class="cd">/// (solely for viz) - reference to the size array</span>
    <span class="k">pub</span> <span class="k">fn</span> <span class="nf">size</span><span class="p">(</span><span class="o">&amp;</span><span class="k">self</span><span class="p">)</span> <span class="k">-&gt;</span> <span class="o">&amp;</span><span class="p">[</span><span class="nb">usize</span><span class="p">]</span> <span class="p">{</span>
        <span class="o">&amp;</span><span class="k">self</span><span class="py">.size</span>
    <span class="p">}</span>
<span class="p">}</span>
</code></pre></div></div>

<p>The visualization code was entirely autogenerated by Claude and then I used <a href="https://github.com/charmbracelet/vhs"><code class="language-plaintext highlighter-rouge">vhs</code></a> to create the animation. Here is the demo:</p>

<p><img src="/images/disjoint-union-set/demo.gif" alt="demo-gif" class="basic-center lightbox-image" /></p>

<p><a href="https://github.com/johnlarkin1/disjoint-set-union">Here is the code</a> if you want to check it out. I’m guessing most people will just deep dive with ChatGPT which is ok too!</p>]]></content><author><name>johnlarkin1</name></author><category term="Algorithms" /><category term="Rust" /><summary type="html"><![CDATA[Recently, I did an interview. I got absolutely flamed and one of the reasons was I wasn’t familiar with a Disjoint Set Union (and I certainly couldn’t complete the C++ interview in time building this data structure naturally).]]></summary></entry><entry><title type="html">Teaching a Computer How to Write</title><link href="https://johnlarkin1.github.io/2025/teaching-a-computer-to-write/" rel="alternate" type="text/html" title="Teaching a Computer How to Write" /><published>2025-10-01T00:00:00+00:00</published><updated>2025-10-01T00:00:00+00:00</updated><id>https://johnlarkin1.github.io/2025/teaching-a-computer-to-write</id><content type="html" xml:base="https://johnlarkin1.github.io/2025/teaching-a-computer-to-write/"><![CDATA[<h1 id="️-motivating-visualizations">✍️ Motivating Visualizations</h1>

<p>Today, we’re going to learn how to teach a computer to write. I don’t mean generating text (which probably would have been a better thing to study in college); I mean learning to write like a human learns to write with a pen and paper. My results (eventually) were pretty good. Here are some motivating visualizations.</p>

<p>Let’s look at one. My family used to have this hung over our kitchen sink when I was a kid. I ate breakfast every day looking at it.</p>

<div class="featured-quote">
  <p class="featured-quote__text">The heart has its reasons which reason knows nothing of</p>
  <p class="featured-quote__attribution">
    <span class="featured-quote__author">Blaise Pascal</span>, <span class="featured-quote__source">"Pensées"</span>
  </p>
</div>

<p><img src="/images/generative-handwriting/synth_outputs/heart_has_its_reason/writing_cleansed.png" alt="heart-writing-cleansed" class="basic-center lightbox-image" /></p>

<p><img src="/images/generative-handwriting/synth_outputs/heart_has_its_reason/writing.gif" alt="heart-writing-gif" class="basic-center lightbox-image" /></p>

<p><img src="/images/generative-handwriting/synth_outputs/heart_has_its_reason/mdn_aggregate.png" alt="heart-mdn-aggregate" class="basic-center lightbox-image" /></p>

<p><img src="/images/generative-handwriting/synth_outputs/heart_has_its_reason/attention_combined.gif" alt="heart-attention-gif" class="basic-center lightbox-image" /></p>

<p><img src="/images/generative-handwriting/synth_outputs/heart_has_its_reason/mdn.png" alt="heart-mdn" class="basic-center lightbox-image" /></p>

<p><img src="/images/generative-handwriting/synth_outputs/heart_has_its_reason/sampling.gif" alt="heart-sampling-gif" class="basic-center lightbox-image" /></p>

<p>If you want to skip ahead, I’d recommend jumping down to <a href="#synthesis-model-sampling">Synthesis Model Sampling</a>, which is arguably the best part of this post. I’ll discuss what all these visualizations mean in detail there.</p>

<hr />

<p><br /></p>

<div class="markdown-alert markdown-alert-disclaimer">

<p>This is a relatively long post! If you're trying to learn from 0 -&gt; 1, I'd encourage you to read the whole thing, but feel free to jump around as you wish. There are three main portions: concept, theory, and code.</p>

<p>My purpose here was to build up from the basics and really understand the flow. I provide quite a few models so we can see the progression: from a simple neural net, to a basic LSTM, to a Peephole LSTM, to a stacked cascade of Peephole LSTMs, to Mixture Density Networks, to the Attention Mechanism, to the Attention RNN, to the Handwriting Prediction Network, and finally throwing it all together into the full Handwriting Synthesis Network that Graves originally wrote about.</p>

<p>There are other things that maybe I'll discuss in the future, like the need to pickle JAX models (because if they're XLA compatible then you can't run inference on your CPU) and issues like that. Another thing I didn't really discuss was temperature and bias for sampling. I also (sadly) didn't cover priming. However, I spent far more time on this than I should have. If you have any questions - as always - feel free to reach out if curious. </p>

<p>Enjoy!</p>

</div>

<div class="markdown-alert markdown-alert-note">
<p>One thing that I would highly recommend - if you're interested in the theory of LSTMs and why sigmoid vs tanh activations were chosen, I would really encourage reading Chris Olah's <b><a href="https://colah.github.io/posts/2015-08-Understanding-LSTMs/">Understanding LSTMs</a></b> blog post. It does a fantastic job.</p>
</div>

<p><br /></p>

<hr />

<!--
# Table of Contents

- [✍️ Motivating Visualizations](#️-motivating-visualizations)
- [Table of Contents](#table-of-contents)
- [🥅 Motivation](#-motivation)
- [👨‍🏫 History](#-history)
  - [Tom and My Engineering Thesis](#tom-and-my-engineering-thesis)
- [🙏 Acknowledgements](#-acknowledgements)
- [📝 Concept](#-concept)
- [👾 Software](#-software)
  - [Tensorflow](#tensorflow)
    - [Programming Paradigm](#programming-paradigm)
    - [Versions - How the times have changed](#versions---how-the-times-have-changed)
    - [Tensorboard](#tensorboard)
  - [Pytorch](#pytorch)
  - [JAX](#jax)
    - [Programming Paradigm](#programming-paradigm-1)
- [📊 Data](#-data)
- [🧠 Base Neural Network Theory](#-base-neural-network-theory)
  - [Lions, Bears, and Many Neural Networks, oh my](#lions-bears-and-many-neural-networks-oh-my)
  - [Basic Neural Network](#basic-neural-network)
    - [Hyper Parameters](#hyper-parameters)
  - [Feedforward Neural Network](#feedforward-neural-network)
    - [Backpropagation](#backpropagation)
  - [Recurrent Neural Network](#recurrent-neural-network)
  - [Long Short Term Memory Networks](#long-short-term-memory-networks)
  - [Understanding the LSTM Structure](#understanding-the-llm-structure)
- [🧬 Concepts to Code](#-concepts-to-code)
  - [LSTM Cell with Peephole Connections](#lstm-cell-with-peephole-connections)
    - [Theory](#theory)
    - [Code](#code)
  - [Gaussian Mixture Models](#gaussian-mixture-models)
    - [Theory](#theory-1)
    - [Code](#code-1)
  - [Mixture Density Networks](#mixture-density-networks)
    - [Theory](#theory-2)
    - [Code](#code-2)
  - [Mixture Density Loss](#mixture-density-loss)
    - [Theory](#theory-3)
    - [Code](#code-3)
  - [Attention Mechanism](#attention-mechanism)
    - [Theory](#theory-4)
    - [Code](#code-4)
  - [Stacked LSTM](#stacked-lstm)
    - [Theory](#theory-5)
    - [Code](#code-5)
  - [Final Result](#final-result)
- [🏋️ Training Results](#️-training-results)
  - [Vast AI GPU Enabled Execution](#vast-ai-gpu-enabled-execution)
    - [Problem #1 - Gradient Explosion Problem](#problem-1---gradient-explosion-problem)
    - [Problem #2 - OOM Galore](#problem-2---oom-galore)
    - [Sanity Check - Validating Model Dimensions (with AI... so somewhat)](#sanity-check---validating-model-dimensions-with-ai-so-somewhat)
- [✍️ Visualizations](#️-visualizations)
  - [Learning with Dummy Data](#learning-with-dummy-data)
  - [Synthesis Model Sampling](#synthesis-model-sampling)
    - [Heart has its reasons](#heart-has-its-reasons)
    - [Loved and lost](#loved-and-lost)
    - [It has to be symphonic](#it-has-to-be-symphonic)
    - [Is a model a lie?](#is-a-model-a-lie)
    - [Fish folly](#fish-folly)
- [Conclusion](#conclusion)
-->

<h1 id="-motivation">🥅 Motivation</h1>

<p>This motivation is clear - this is something that I have wanted to find the time to do right since college. My engineering thesis was on this Graves paper. My senior year, I worked with my good friend (also he’s a brilliant engineer) <a href="https://www.linkedin.com/in/tom-wilmots-030781a6/">Tom Wilmots</a> to understand and dive into this paper.</p>

<p>I’m going to pull pieces of that, but times have changed, and I wanted to revisit some of the work we did, hopefully clean it up, and finally put a nail in it (so my girlfriend / friends don’t have to keep hearing about it).</p>

<h1 id="‍-history">👨‍🏫 History</h1>

<p><a href="https://www.linkedin.com/in/tom-wilmots-030781a6/">Tom</a> and I were very interested in the concept of teaching a computer how to write in college. There is a very famous <a href="https://arxiv.org/abs/1308.0850">paper</a> that was published around 2013 from Canadian computer scientist <a href="https://en.wikipedia.org/wiki/Alex_Graves_(computer_scientist)">Alex Graves</a>, titled <a href="https://arxiv.org/abs/1308.0850"><em>Generating Sequences With Recurrent Neural Networks</em></a>. At <a href="http://www.swarthmore.edu/">Swarthmore</a>, you have to do Engineering thesis, called <a href="https://www.swarthmore.edu/engineering/e90-senior-design-project">E90s</a>. It’s basically a year (although I’d argue it’s more of a semester when it all shakes out) long project focused on doing a piece of work you’re proud of.</p>

<h2 id="tom-and-my-engineering-thesis">Tom and My Engineering Thesis</h2>

<p>For the actual paper that we wrote, check it out here:</p>

<div style="text-align: center;">
    <embed src="/pdfs/Handwriting-Synthesis-E90.pdf" width="500" height="375" type="application/pdf" />
</div>

<p>You can also check it out here: <a href="https://arxiv.org/abs/1308.0850"><strong>Application of Neural Networks with Handwriting Samples</strong></a>.</p>

<h1 id="-acknowledgements">🙏 Acknowledgements</h1>

<p>Before I dive in, I do want to make some acknowledgements, given that this is a partial resumption of earlier work.</p>

<ul>
  <li><strong><a href="https://www.linkedin.com/in/tom-wilmots-030781a6/">Tom Wilmots</a></strong> - One of the brightest and best engineers I’ve worked with. He was an Engineering and Economics double major from <a href="http://www.swarthmore.edu/">Swarthmore</a>. Pretty sure I would have failed my E90 thesis without him.</li>
  <li><strong><a href="https://mzucker.github.io/">Matt Zucker</a></strong> - One of my role models and constant inspirations, Matt was kind enough to be Tom and my academic advisor for this final engineering project. He is the best professor I’ve come across.</li>
  <li><strong><a href="https://en.wikipedia.org/wiki/Alex_Graves_(computer_scientist)">Alex Graves</a></strong> - A professor that both Tom and I had the pleasure of working with. <strong>He responded to our emails, which I’m still very appreciative of</strong>. You can see more about his work at the University of Toronto <a href="https://www.cs.toronto.edu/~graves/">here</a>). He is the author of <a href="https://arxiv.org/abs/1308.0850">this paper</a>, which Matt found for us and pretty much was the basis of our project. He’s also the creator of the <a href="https://arxiv.org/abs/1410.5401">Neural Turing Machine</a>, which peaked my interest after having taken <a href="https://www.cs.swarthmore.edu/~fontes/cs46/17s/index.php">Theory of Computation</a>, with my other fantastic professor <a href="https://www.cs.swarthmore.edu/~fontes/">Lila Fontes</a> and learning about <a href="https://en.wikipedia.org/wiki/Turing_machine">Turing machines</a>.</li>
  <li><strong><a href="https://www.linkedin.com/in/david-ha-168a012/">David Ha</a></strong> - Another brilliant scientist who we had the privilege of corresponding with. Check out his blog <a href="http://blog.otoro.net/">here</a>. It’s beautiful. He also is very prolific on <a href="https://arxiv.org/">ArXiv</a> which is always cool to see.</li>
</ul>

<h1 id="-concept">📝 Concept</h1>

<p>This section is going to be for non-technical people to understand what we were trying to do. It’s relatively simple. At a very high level, <strong>we are trying to teach a computer how to generate human-looking handwriting</strong>. To do that, we are going to train a neural network on a public dataset, the <a href="https://fki.tic.heia-fr.ch/databases/iam-on-line-handwriting-database">IAM Online Handwriting Database</a>. This dataset was collected by having a ton of people write on a tablet while the pen position was recorded. It consists of sets of <code class="language-plaintext highlighter-rouge">Stroke</code> data, which are tuples of $(x, y, t)$, where $(x, y)$ are the coordinates on the tablet and $t$ is the timestamp. We’ll use this data to train a model so that, across all of the participants, we get a blended approach of how to write like a human.</p>
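<p>To make the data concrete, here’s a toy sketch (hypothetical code, not the project’s actual pipeline) of how absolute $(x, y, t)$ samples can be turned into per-step pen offsets, which is the kind of representation a sequence model typically consumes:</p>

```python
# Toy illustration (not the project's actual code): converting a
# sequence of absolute (x, y, t) pen samples into per-step offsets.

def strokes_to_offsets(points):
    """Convert absolute (x, y, t) samples into (dx, dy) deltas.

    Working with deltas rather than absolute coordinates keeps the
    inputs in a small, roughly stationary range, which is easier for
    a neural network to model.
    """
    offsets = []
    for (x0, y0, _), (x1, y1, _) in zip(points, points[1:]):
        offsets.append((x1 - x0, y1 - y0))
    return offsets

points = [(100, 200, 0.00), (103, 201, 0.01), (107, 199, 0.02)]
print(strokes_to_offsets(points))  # [(3, 1), (4, -2)]
```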

<h1 id="-software">👾 Software</h1>

<p>In college, we decided between <a href="https://www.tensorflow.org/">Tensorflow</a> and <a href="https://pytorch.org/">Pytorch</a>, and went with Tensorflow. Given the times, I wanted to resume our Tensorflow approach with updated designs, but I also wanted to try <a href="https://github.com/jax-ml/jax">JAX</a>. <a href="https://github.com/jax-ml/jax">JAX</a> is… newer. But it’s gotten some hype online and I think there’s a solid amount of adoption across the bigger AI labs now. In my opinion, Tensorflow is dying, Pytorch is the new status quo, and JAX is the new kid on the block. However, I’m not an ML researcher clearing millions of dollars, so grain of salt. This <a href="https://neel04.github.io/my-website/blog/pytorch_rant/">clickbaity article</a>, which declares <em>“Pytorch is dead. Long live JAX”</em>, got a ton of flak online, but regardless… it piqued my interest enough to try it here.</p>

<p>I’ll cover all three here and yeah probably dive deepest into tensorflow… but feel free to skip this section.</p>

<h2 id="tensorflow">Tensorflow</h2>

<h3 id="programming-paradigm">Programming Paradigm</h3>

<p>Tensorflow has this interesting programming paradigm where you are more or less creating a graph. You define <code class="language-plaintext highlighter-rouge">Tensor</code>s and operations up front, and the actual computation only happens when you run your dependency graph.</p>

<p>I have this quote from the Tensorflow API:</p>

<blockquote>
  <p>There’s only two things that go into Tensorflow.</p>

  <ol>
    <li>Building your computational dependency graph.</li>
    <li>Running your dependency graph.</li>
  </ol>
</blockquote>

<p>This was the old way, but now that’s not totally true. Apparently, Tensorflow 2.0 helped out a lot with the computational model and the notion of eagerly executing, rather than building the graph and then having everything run at once.</p>
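<p>To make the two-phase idea concrete, here’s a toy deferred-evaluation sketch in plain Python (explicitly <em>not</em> real Tensorflow code): we first describe a computation as a graph of nodes, and only later run it with concrete inputs.</p>

```python
# Toy sketch (plain Python, not actual Tensorflow) of the old
# "build the graph, then run it" paradigm: nodes only describe the
# computation; nothing is evaluated until we explicitly run the graph.

class Node:
    def __init__(self, op, inputs=(), value=None):
        self.op, self.inputs, self.value = op, inputs, value

    def run(self, feed):
        # Recursively evaluate the dependency graph.
        if self.op == "const":
            return self.value
        if self.op == "placeholder":
            return feed[self]
        args = [n.run(feed) for n in self.inputs]
        if self.op == "add":
            return args[0] + args[1]
        if self.op == "mul":
            return args[0] * args[1]
        raise ValueError(self.op)

# Phase 1: build the dependency graph (nothing computes here).
x = Node("placeholder")
w = Node("const", value=3)
b = Node("const", value=1)
y = Node("add", inputs=(Node("mul", inputs=(w, x)), b))  # y = 3x + 1

# Phase 2: run the graph with concrete inputs.
print(y.run({x: 4}))  # 13
```

Tensorflow 2.x's eager mode collapses these two phases into ordinary imperative execution, which is a big part of why it feels so different from the v0.x days.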

<h3 id="versions---how-the-times-have-changed">Versions - How the times have changed</h3>

<p>So - another fun fact - when we were doing this in college, we were on <strong>tensorflow version v0.11</strong>!!! They hadn’t even released a major version. Now, I’m doing this on Tensorflow <strong>2.16.1</strong>. So the times have definitely changed.</p>

<p><img src="/images/generative-handwriting/being_old.jpeg" alt="being-old" class="center-shrink" /></p>

<p>Definitely haven’t been able to keep up with all those changes.</p>

<h3 id="tensorboard">Tensorboard</h3>

<p>Another cool thing about <a href="https://www.tensorflow.org/">Tensorflow</a> that should be mentioned is <a href="https://www.tensorflow.org/get_started/summaries_and_tensorboard">Tensorboard</a>. This is a visualization suite that serves a local website where you can interactively visualize your dependency graph and live-stream training metrics. You can do cool things like confirm that the error is actually decreasing over the epochs.</p>

<p>We used this a bit more in college. I didn’t get a real chance to dive into the updates made from this.</p>

<h2 id="pytorch">Pytorch</h2>

<p><a href="https://pytorch.org/">PyTorch</a> is now basically the defacto standard for most serious research labs and AI shops. To me, it seems like things are still somewhat ported to Tensorflow for production, but I’m not totally sure about convention.</p>

<p>Pytorch seems to thread the needle between Tensorflow and JAX. Functions don’t necessarily need to be pure to be utilized; you can loop and mutate state in a <code class="language-plaintext highlighter-rouge">nn.Module</code> just fine.</p>

<p>I won’t be covering pytorch but I certainly will come back around to it in later projects.</p>

<h2 id="jax">JAX</h2>

<p>The new up and comer! I think it’s largely a crowd favorite for its speed. Documentation is obviously worse. One Redditor summarized it nicely:</p>

<blockquote class="reddit-embed-bq" data-embed-theme="dark" data-embed-height="396"><a href="https://www.reddit.com/r/MachineLearning/comments/1b08qv6/comment/ks6u1e2/">Comment</a><br /> by<a href="https://www.reddit.com/user/Few-Pomegranate4369/">u/Few-Pomegranate4369</a> from discussion<a href="https://www.reddit.com/r/MachineLearning/comments/1b08qv6/d_is_it_worth_switching_to_jax_from/"></a><br /> in<a href="https://www.reddit.com/r/MachineLearning/">MachineLearning</a></blockquote>
<script async="" src="https://embed.reddit.com/widgets.js" charset="UTF-8"></script>

<p><br /></p>

<p>I hit numerous roadblocks where functions weren’t actually pure and then the JIT compile portion basically failed on startup.</p>

<h3 id="programming-paradigm-1">Programming Paradigm</h3>

<p>JAX and Pytorch are definitely the closest to traditional imperative Python flow. The main restriction in JAX is that transformed functions must be pure. Tensorflow, meanwhile, is gradually moving away from the compile-your-graph-then-run-it paradigm.</p>

<h1 id="-data">📊 Data</h1>

<p>We’re using the <a href="https://fki.tic.heia-fr.ch/databases/iam-on-line-handwriting-database">IAM Online Handwriting Database</a>. Specifically, I’m looking at <code class="language-plaintext highlighter-rouge">data/lineStrokes-all.tar.gz</code>, which is XML data that looks like this:</p>

<p><img src="/images/generative-handwriting/example_data.png" alt="data" class="center-super-shrink lightbox-image" /></p>

<div class="image-caption">Example Handwriting IAM Data</div>
<p><br /></p>
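<p>If you want to poke at the raw data yourself, here’s a minimal parsing sketch using Python’s <code class="language-plaintext highlighter-rouge">xml.etree</code>. The inline sample mirrors the rough shape of the lineStrokes XML (a <code class="language-plaintext highlighter-rouge">StrokeSet</code> of <code class="language-plaintext highlighter-rouge">Stroke</code> elements containing <code class="language-plaintext highlighter-rouge">Point</code> elements with x/y/time attributes); the real files carry more metadata, so treat the exact schema here as an assumption.</p>

```python
# A minimal parsing sketch for IAM-style lineStrokes XML. The sample
# below only approximates the real schema; the actual files include
# extra session metadata that this sketch ignores.
import xml.etree.ElementTree as ET

SAMPLE = """
<WhiteboardCaptureSession>
  <StrokeSet>
    <Stroke>
      <Point x="835" y="1021" time="1185.2"/>
      <Point x="838" y="1018" time="1185.21"/>
    </Stroke>
  </StrokeSet>
</WhiteboardCaptureSession>
"""

def parse_strokes(xml_text):
    """Return a list of strokes, each a list of (x, y, t) tuples."""
    root = ET.fromstring(xml_text)
    strokes = []
    for stroke in root.iter("Stroke"):
        strokes.append(
            [
                (int(p.get("x")), int(p.get("y")), float(p.get("time")))
                for p in stroke.iter("Point")
            ]
        )
    return strokes

print(parse_strokes(SAMPLE))
```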

<p>There’s also this note:</p>

<blockquote>
  <p>The database is divided into 4 parts, a training set, a first validation set, a second validation set and a final test set. The training set may be used for training the recognition system, while the two validation sets may be used for optimizing some meta-parameters. The final test set must be left unseen until the final test is performed. Note that you are allowed to use also other data for training etc, but report all the changes when you publish your experimental results and let the test set unchanged (It contains 3859 sequences, i.e. XML-files - one for each text line).</p>
</blockquote>

<p>So that determines our training set, validation set, second validation set, and a final test set.</p>

<h1 id="-base-neural-network-theory">🧠 Base Neural Network Theory</h1>

<p>I am not going to dive into details as much as we did for our senior E90 thesis, but I do want to cover a couple of the building blocks.</p>

<h2 id="lions-bears-and-many-neural-networks-oh-my">Lions, Bears, and Many Neural Networks, oh my</h2>

<p>I would highly encourage you to check out this website: <a href="https://www.asimovinstitute.org/neural-network-zoo/">https://www.asimovinstitute.org/neural-network-zoo/</a>. I remember seeing it in college when working on this thesis and was stunned. If you’re too lazy to click, check out the fun picture:</p>

<p><img src="/images/generative-handwriting/neural_network_zoo.png" alt="neural-network-zoo" class="center-super-shrink lightbox-image" /></p>

<div class="image-caption">Courtesy of <a href="https://www.asimovinstitute.org/neural-network-zoo/">Asimov Institute</a></div>
<p><br /></p>

<p>We’re going to explore some of the zoo in a bit more detail, specifically, focusing on <a href="https://en.wikipedia.org/wiki/Long_short-term_memory">LSTMs</a>.</p>

<h2 id="basic-neural-network">Basic Neural Network</h2>

<p><img src="https://aiml.com/wp-content/uploads/2023/08/Illustration-of-a-neural-net-1024x594.png" alt="basic-nn" class="center-shrink lightbox-image" /></p>

<div class="image-caption">Courtesy of <a href="https://aiml.com/what-is-the-basic-architecture-of-an-artificial-neural-network-ann/">AI ML</a></div>
<p><br /></p>

<p>The core structure of a neural network is the connections between all of the neurons. Each connection carries an activation signal of varying strength. If the incoming signal to a neuron is strong enough, then the signal is propagated through the next stages of the network.</p>

<p>There is an input layer that feeds the data into the hidden layer. The outputs from the hidden layer are then passed to the output layer. Every connection between nodes carries a weight determining the amount of information that gets passed through.</p>

<h3 id="hyper-parameters">Hyper Parameters</h3>

<p>For a basic neural network, there are generally three <a href="https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)">hyperparameters</a>:</p>

<ul>
  <li>pattern of connections between all neurons</li>
  <li>weights of connections between neurons</li>
  <li>activation functions of the neurons</li>
</ul>

<p>In our project however, we focus on a specific class of neural networks called Recurrent Neural Networks (RNNs), and the more specific variation of RNNs called Long Short Term Memory networks (<a href="https://en.wikipedia.org/wiki/Long_short-term_memory">LSTMs</a>).</p>

<p>However, let’s give a bit more context. There are really two broad types of neural networks:</p>

<ul>
  <li><strong><a href="https://en.wikipedia.org/wiki/Feedforward_neural_network">Feedforward Neural Network</a></strong></li>
  <li><strong><a href="https://en.wikipedia.org/wiki/Recurrent_neural_network">Recurrent Neural Networks</a></strong></li>
</ul>

<h2 id="feedforward-neural-network">Feedforward Neural Network</h2>

<p>These neural networks channel information in <strong>one direction</strong>.</p>

<p>The figure above is showing a feedforward neural network because <strong>the connections do not allow for the same input data to be seen multiple times by the same node.</strong></p>

<p>These networks are generally well suited to mapping raw data to categories, for example classifying a face from an image.</p>

<p>Every node outputs a numerical value that it then passes to all its successor nodes. In other words:</p>

\[\begin{align}
y_j = f(x_j)
\end{align}
\tag{1}\]

<p>where</p>

\[\begin{align}
x_j = \sum_{i \in P_j} w_{ij} y_i
\end{align}
\tag{2}\]

<p>where</p>

<ul>
  <li>$y_j$ is the output of node $j$</li>
  <li>$x_j$ is the total weighted input for node $j$</li>
  <li>$w_{ij}$ is the weight from node $i$ to node $j$</li>
  <li>$y_i$ is the output from node $i$</li>
  <li>$P_j$ represents the set of predecessor nodes to node $j$</li>
</ul>

<p>Also note, $f(x)$ should be a smooth non-linear activation function that maps outputs to a reasonable domain. Some common activation functions include $\tanh(x)$ or the <a href="https://en.wikipedia.org/wiki/Sigmoid_function">sigmoid function</a>. These complex functions are necessary because the neural network is <em>literally</em> trying to learn a non-linear pattern.</p>
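<p>As a sanity check on equations (1) and (2), here’s a tiny hand-rolled sketch (hypothetical code, not from the project) of a single node’s forward computation:</p>

```python
# Equations (1) and (2) in code: each node sums its weighted inputs
# and applies a nonlinearity. A hand-rolled sketch for intuition only.
import math

def node_output(weights, predecessor_outputs):
    # x_j = sum_i w_ij * y_i  (eq. 2), then y_j = f(x_j)  (eq. 1)
    x_j = sum(w * y for w, y in zip(weights, predecessor_outputs))
    return math.tanh(x_j)

# A node with two predecessors:
y = node_output([0.5, -0.25], [1.0, 2.0])  # tanh(0.5 - 0.5) = 0.0
print(y)
```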

<h3 id="backpropagation">Backpropagation</h3>

<p><a href="https://en.wikipedia.org/wiki/Backpropagation">Backpropagation</a> is the mechanism in which we pass the error back through the network starting at the output node. Generally, we minimize using [stochastic gradient descent][stoch-grad-desc]. Again, lots of different ways we can define our error, but we can use sum of squared residuals between our $k$ targets and the output of $k$ nodes of the network.</p>

\[\begin{align}
E = \frac{1}{2} \sum_{k}(t_k - y_k)^2
\end{align}
\tag{3}\]

<p>The gradient descent part comes in next. We compute the gradient of the error with respect to each weight and move the weights in the direction that decreases the error:</p>

\[\begin{align}
g_{ij} = - \frac{\delta E}{\delta w_{ij}}
\end{align}
\tag{4}\]

<p>So overall, we’re continually altering the weights, each in the direction that reduces its individual contribution to the overall error of the outputs.</p>
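<p>To see equations (3) and (4) in action, here’s a deliberately tiny gradient descent sketch (one weight, one training example; illustrative only, since real networks do this for every weight via backpropagation):</p>

```python
# A tiny gradient-descent sketch for the error in eq. (3): one weight,
# one training pair, repeatedly nudged along -dE/dw (eq. 4).
def train(w, x, t, lr=0.1, steps=100):
    for _ in range(steps):
        y = w * x                 # forward pass
        grad = -(t - y) * x       # dE/dw for E = 0.5 * (t - y)^2
        w -= lr * grad            # step against the gradient
    return w

w = train(w=0.0, x=2.0, t=3.0)
print(w)  # converges toward 1.5, since 1.5 * 2.0 = 3.0
```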

<p>The major downfall of this simple network is that we don’t have full context. With sequences, there’s not enough information about the previous words, so the context is missing. And that leads us to our next structure.</p>

<h2 id="recurrent-neural-network">Recurrent Neural Network</h2>

<p><a href="https://en.wikipedia.org/wiki/Recurrent_neural_network">Recurrent Neural Networks (RNNs)</a> have a capacity to remember. This memory stems from the fact that their input is not only the current input vector but also a variation of what they output at previous time steps.</p>

<p>This visualization from <a href="https://colah.github.io/">Christopher Olah</a> (who, holy hell, I just realized is a co-founder of <a href="https://www.anthropic.com/">Anthropic</a>, and who Tom and I used to follow closely in college) is great:</p>

<p><img src="/images/generative-handwriting/rnn_unrolled.png" alt="rnn-unrolled" class="center-shrink lightbox-image" /></p>

<div class="image-caption">Courtesy of Chris Olah's <a href="https://colah.github.io/posts/2015-08-Understanding-LSTMs/">Understanding LSTMs</a></div>
<p><br /></p>

<p>This RNN module is being unrolled over multiple time steps. Information is passed from the module at time step $t$ to the module at $t+1$.</p>

<p>Per Tom and my paper,</p>

<blockquote>
  <p>An ideal RNN would theoretically be able to remember as far back as was necessary in order to make an accurate prediction. However, as with many things, the theory does not carry over to reality. RNNs have trouble learning long term dependencies due to the vanishing gradient problem. An example of such a long term dependency might be if we are trying to predict the last word in the following sentence ”My family originally comes from Belgium so my native language is PREDICTION”. A normal RNN would possibly be able to recognize that the prediction should be a language but it would need the earlier context of Belgium to be able to accurately predict DUTCH.</p>
</blockquote>

<p>Topically, this is why the craze around LLMs is so impressive. There’s a lot more going on with LLMs… which… I will not cover here.</p>

<p>The notion of <a href="https://en.wikipedia.org/wiki/Backpropagation">backpropagation</a> is basically the same, except we also have the added dimension of time (backpropagation through time).</p>

<p>The crux of the issue is that an unrolled RNN is effectively a very deep network, and repeated multiplication during backpropagation pushes the derivatives toward zero. The gradients become too small and underflow. In practical terms, the network then ceases to be able to learn.</p>
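<p>You can see the vanishing gradient numerically in a few lines of Python: chaining many time steps multiplies many per-step derivatives together, and since each tanh derivative is below 1 away from zero, the product collapses.</p>

```python
# Why gradients vanish: backpropagating through T time steps multiplies
# T per-step derivatives together. If each factor is below 1 (as tanh
# derivatives are away from zero), the product shrinks exponentially.
import math

def tanh_grad(x):
    return 1.0 - math.tanh(x) ** 2

grad = 1.0
for _ in range(100):          # 100 unrolled time steps
    grad *= tanh_grad(1.0)    # each factor is ~0.42
print(grad)                   # astronomically small: early steps learn nothing
```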

<p>However, <a href="https://en.wikipedia.org/wiki/Sepp_Hochreiter">Sepp Hochreiter</a> and <a href="https://en.wikipedia.org/wiki/J%C3%BCrgen_Schmidhuber">Juergen Schmidhuber</a> developed the <a href="https://en.wikipedia.org/wiki/Long_short-term_memory">Long Short Term Memory (LSTM)</a> unit that solved this vanishing gradient problem.</p>

<h2 id="long-short-term-memory-networks">Long Short Term Memory Networks</h2>

<p><a href="https://en.wikipedia.org/wiki/Long_short-term_memory">Long Short Term Memory (LSTM)</a> networks are specifically designed to learn long term dependencies.</p>

<p>Every form of RNN has repeating modules that pass information across timesteps, and LSTMs are no different. Where they differ is in the inner structure of each module. While a standard RNN module might have a single neural network layer, LSTM modules have four.</p>

<p><img src="/images/generative-handwriting/lstm.png" alt="lstm-viz" class="center-shrink lightbox-image" /></p>

<div class="image-caption">Courtesy of Chris Olah's <a href="https://colah.github.io/posts/2015-08-Understanding-LSTMs/">Understanding LSTMs</a></div>
<p><br /></p>

<h3 id="understanding-the-llm-structure">Understanding the LLM Structure</h3>

<p>So let’s better understand the structure above. There’s a far more comprehensive walkthrough in Chris Olah’s post <a href="https://colah.github.io/">here</a>; I’d encourage you to check it out.</p>

<p><img src="/images/generative-handwriting/single_lstm_module.png" alt="lstm-viz" class="center-super-shrink lightbox-image" /></p>

<div class="image-caption">Courtesy of Chris Olah's <a href="https://colah.github.io/posts/2015-08-Understanding-LSTMs/">Understanding LSTMs</a></div>
<p><br /></p>

<p>The top line is key to the LSTM’s ability to remember. It is called the cell state. We’ll reference it as $C_t$.</p>

<p>The first neural network layer is a sigmoid function. It takes as input the concatenation of the current input $x_t$ and the output of the previous module, $h_{t-1}$. This is the forget gate: it is in control of what the cell state forgets. The sigmoid function is a good architectural decision here because it outputs numbers in $[0, 1]$, indicating how much the layer should let through.</p>

<p>We piecewise multiply the output of the sigmoid layer $f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)$, with the cell state from the previous module $C_{t-1}$, forgetting the things that it doesn’t see as important.</p>

<p>Then right in the center of the image above there are two neural network layers which make up the update gate. First, the concatenation $[h_{t-1}, x_t]$ is pushed through both a sigmoid ($\sigma$) layer and a $\tanh$ layer. The output of the sigmoid layer, $i_t = \sigma (W_i \cdot [h_{t-1}, x_t] + b_i)$, determines which values to update, and the output of the $\tanh$ layer, $\tilde{C}_t = \tanh (W_C \cdot [h_{t-1}, x_t] + b_C)$, proposes an entirely new candidate cell state. These two results are then piecewise multiplied and added to the current cell state (which we just edited using the forget gate), yielding the new cell state $C_t$.</p>

<p>The final neural network layer is called the output gate. It determines the relevant portion of the cell state to output as $h_t$. Once again, we feed the concatenation $[h_{t-1}, x_t]$ through a sigmoid layer whose output, $o_t = \sigma (W_o \cdot [h_{t-1}, x_t] + b_o)$, we piecewise multiply with $\tanh(C_t)$. The result of the multiplication is the output of the LSTM module. Note that the <span style="color:purple"><strong>purple</strong></span> $\tanh$ is not a neural network layer, but a pointwise $\tanh$ intended to push the current cell state into a reasonable domain.</p>
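<p>Putting the three gates together, here’s a single LSTM step written out with numpy. This is a sketch of the standard (non-peephole) cell for intuition; the dimensions and fused-gate layout are my own choices, not the post’s model code.</p>

```python
# A single LSTM step in numpy, matching the gate equations above
# (standard LSTM, no peepholes). Illustrative sketch only.
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W, b):
    """W maps the concatenation [h_prev, x_t] to all four gates at once."""
    concat = np.concatenate([h_prev, x_t])
    n = h_prev.shape[0]
    z = W @ concat + b                      # shape (4n,)
    f_t = sigmoid(z[0:n])                   # forget gate
    i_t = sigmoid(z[n:2 * n])               # input gate
    c_hat = np.tanh(z[2 * n:3 * n])         # candidate cell state
    o_t = sigmoid(z[3 * n:4 * n])           # output gate
    c_t = f_t * c_prev + i_t * c_hat        # new cell state
    h_t = o_t * np.tanh(c_t)                # new hidden state
    return h_t, c_t

rng = np.random.default_rng(0)
n_in, n_hidden = 3, 4
W = rng.normal(size=(4 * n_hidden, n_hidden + n_in))
b = np.zeros(4 * n_hidden)
h, c = np.zeros(n_hidden), np.zeros(n_hidden)
h, c = lstm_step(rng.normal(size=n_in), h, c, W, b)
print(h.shape, c.shape)  # (4,) (4,)
```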

<div class="markdown-alert markdown-alert-note">
<p><b>I'm serious... you guys should check out Olah's <a href="https://colah.github.io/posts/2015-08-Understanding-LSTMs/">Understanding LSTMs</a> blog post. Here he is back in 2015, strongly foreshadowing transformers with his focus on attention (which is truly the hardest part of all this).</b></p>
</div>

<p><img src="/images/generative-handwriting/olah-attention.png" alt="olah-attention" class="center-shrink" /></p>

<div class="image-caption">Courtesy of Chris Olah's <a href="https://colah.github.io/posts/2015-08-Understanding-LSTMs/">Understanding LSTMs</a></div>
<p><br /></p>

<p><br /></p>

<h1 id="-concepts-to-code">🧬 Concepts to Code</h1>

<p>When I very first started this project, I kind of figured that I would be able to use some of my college code. But looking back, it’s quite a mess, and I don’t think that’s the way to go about it.</p>

<p>I thought for a while about how best to structure this part, meaning the code, but also how to show it in this blog post. With all the buzz about JAX, I wanted to try that too, so I thought it’d be helpful to show a side by side translation of the Tensorflow vs JAX code. My hope is that we’ll walk through the concepts and have a good understanding of the theory, and then the code will make a bit more sense. One note: I was a bit burnt out on this project by the end, so in the JAX code I tried to use <a href="https://github.com/google-deepmind/optax"><code class="language-plaintext highlighter-rouge">optax</code> (link)</a> and <a href="https://github.com/google/flax"><code class="language-plaintext highlighter-rouge">flax</code> (link)</a> as much as possible to cut down on the bulkiness of the code.</p>

<p>So we’ll walk through the building blocks (in terms of code) and then show the code translations.</p>

<h2 id="lstm-cell-with-peephole-connections">LSTM Cell with Peephole Connections</h2>

<h3 id="theory">Theory</h3>

<p>The basic LSTM cell (<code class="language-plaintext highlighter-rouge">tf.keras.layers.LSTMCell</code>) does not actually have the notion of peephole connections.</p>

<p>According to the very functional code that <a href="https://github.com/sjvasquez">sjvasquez</a> wrote, I don’t think we actually need it, but I figured it would be fun to implement regardless. Back in the old days, when Tensorflow supported add-ons, there was some work around this <a href="https://www.tensorflow.org/addons/api_docs/python/tfa/rnn/PeepholeLSTMCell">here</a>, but that project has since been deprecated.</p>

<p>That being said…. the JAX / Flax code also doesn’t have LSTMs out of the gate with peepholes and so…. I just used the normal ones. The JAX model actually trained a bit better, but I think part of that was also just patience.</p>
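<p>For reference, here are the peephole equations as I recall them from Graves’s paper: the $W_{c\,\cdot}$ terms are the diagonal peephole matrices that let each gate look directly at the cell state (note the output gate peeks at the <em>current</em> cell state $c_t$, while the input and forget gates see $c_{t-1}$):</p>

\[\begin{align}
i_t &= \sigma(W_{xi} x_t + W_{hi} h_{t-1} + W_{ci} c_{t-1} + b_i) \\
f_t &= \sigma(W_{xf} x_t + W_{hf} h_{t-1} + W_{cf} c_{t-1} + b_f) \\
c_t &= f_t c_{t-1} + i_t \tanh(W_{xc} x_t + W_{hc} h_{t-1} + b_c) \\
o_t &= \sigma(W_{xo} x_t + W_{ho} h_{t-1} + W_{co} c_t + b_o) \\
h_t &= o_t \tanh(c_t)
\end{align}\]

<p>Because the $W_{c\,\cdot}$ matrices are diagonal, element $m$ of each gate only receives input from element $m$ of the cell vector, which is exactly what the docstring in the code below describes.</p>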

<h3 id="code">Code</h3>

<!-- prettier-ignore-start -->

<div class="code-toggle">
  <div class="code-toggle__tabs">
    <button class="code-toggle__tab code-toggle__tab--active" data-tab="tensorflow">TensorFlow</button>
    <button class="code-toggle__tab" data-tab="jax">JAX</button>
  </div>
  <div class="code-toggle__content">
    <div class="code-toggle__pane code-toggle__pane--active" data-pane="tensorflow">

<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">:</span> <span class="n">tf</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">tf</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">tf</span><span class="p">.</span><span class="n">Tensor</span><span class="p">]):</span>
    <span class="s">"""
    This is basically implementing Graves's equations on page 5
    https://www.cs.toronto.edu/~graves/preprint.pdf
    equations 5-11.

    From the paper,
    * sigma is the logistic sigmoid function
    * i -&gt; input gate
    * f -&gt; forget gate
    * o -&gt; output gate
    * c -&gt; cell state
    * W_{hi} - hidden-input gate matrix
    * W_{xo} - input-output gate matrix
    * W_{ci} - are diagonal
        + so element m in each gate vector only receives input from
        + element m of the cell vector
    """</span>

    <span class="c1"># going to be shape (?, num_lstm_units)
</span>    <span class="n">h_tm1</span><span class="p">,</span> <span class="n">c_tm1</span> <span class="o">=</span> <span class="n">state</span>

    <span class="c1"># basically the meat of eq, 7, 8, 9, 10
</span>    <span class="n">z</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">kernel</span><span class="p">)</span> <span class="o">+</span> <span class="n">tf</span><span class="p">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">h_tm1</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">recurrent_kernel</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">bias</span>
    <span class="n">i_lin</span><span class="p">,</span> <span class="n">f_lin</span><span class="p">,</span> <span class="n">g_lin</span><span class="p">,</span> <span class="n">o_lin</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">z</span><span class="p">,</span> <span class="n">num_or_size_splits</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>

    <span class="k">if</span> <span class="bp">self</span><span class="p">.</span><span class="n">should_apply_peephole</span><span class="p">:</span>
        <span class="n">pw_i</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">peephole_weights</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
        <span class="n">pw_f</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">peephole_weights</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
        <span class="n">i_lin</span> <span class="o">=</span> <span class="n">i_lin</span> <span class="o">+</span> <span class="n">c_tm1</span> <span class="o">*</span> <span class="n">pw_i</span>
        <span class="n">f_lin</span> <span class="o">=</span> <span class="n">f_lin</span> <span class="o">+</span> <span class="n">c_tm1</span> <span class="o">*</span> <span class="n">pw_f</span>

    <span class="c1"># apply activation functions! see Olah's blog
</span>    <span class="n">i</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">i_lin</span><span class="p">)</span>
    <span class="n">f</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">f_lin</span><span class="p">)</span>
    <span class="n">g</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">g_lin</span><span class="p">)</span>
    <span class="n">c</span> <span class="o">=</span> <span class="n">f</span> <span class="o">*</span> <span class="n">c_tm1</span> <span class="o">+</span> <span class="n">i</span> <span class="o">*</span> <span class="n">g</span>

    <span class="k">if</span> <span class="bp">self</span><span class="p">.</span><span class="n">should_apply_peephole</span><span class="p">:</span>
        <span class="n">pw_o</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">peephole_weights</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
        <span class="n">o_lin</span> <span class="o">=</span> <span class="n">o_lin</span> <span class="o">+</span> <span class="n">c</span> <span class="o">*</span> <span class="n">pw_o</span>

    <span class="n">o</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">o_lin</span><span class="p">)</span>

    <span class="c1"># final hidden state -&gt; eq. 11
</span>    <span class="n">h</span> <span class="o">=</span> <span class="n">o</span> <span class="o">*</span> <span class="n">tf</span><span class="p">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">c</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">h</span><span class="p">,</span> <span class="p">[</span><span class="n">h</span><span class="p">,</span> <span class="n">c</span><span class="p">]</span></code></pre></figure>

</div>
<div class="code-toggle__pane" data-pane="jax">

<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">class</span> <span class="nc">HandwritingModel</span><span class="p">(</span><span class="n">nnx</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span>
        <span class="n">config</span><span class="p">:</span> <span class="n">ModelConfig</span><span class="p">,</span>
        <span class="n">rngs</span><span class="p">:</span> <span class="n">nnx</span><span class="p">.</span><span class="n">Rngs</span><span class="p">,</span>
        <span class="n">synthesis_mode</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="bp">False</span><span class="p">,</span>
    <span class="p">)</span> <span class="o">-&gt;</span> <span class="bp">None</span><span class="p">:</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">config</span> <span class="o">=</span> <span class="n">config</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">synthesis_mode</span> <span class="o">=</span> <span class="n">synthesis_mode</span>

        <span class="c1"># rngs is basically a set of random keys / number generators
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">lstm_cells</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">_build_lstm_stack</span><span class="p">(</span><span class="n">rngs</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">synthesis_mode</span><span class="p">:</span>
            <span class="c1"># i mean we really only care about synthesis mode, but in
</span>            <span class="c1"># this case we can make it explicit that if we have it then we should add our
</span>            <span class="c1"># attention layer
</span>            <span class="bp">self</span><span class="p">.</span><span class="n">attention_layer</span> <span class="o">=</span> <span class="n">nnx</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span>
                <span class="n">config</span><span class="p">.</span><span class="n">hidden_size</span> <span class="o">+</span> <span class="n">config</span><span class="p">.</span><span class="n">alphabet_size</span> <span class="o">+</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span> <span class="o">*</span> <span class="n">config</span><span class="p">.</span><span class="n">num_attention_gaussians</span><span class="p">,</span> <span class="n">rngs</span><span class="o">=</span><span class="n">rngs</span>
            <span class="p">)</span>

        <span class="c1"># mdn portion
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">mdn_layer</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">_build_mdn_head</span><span class="p">(</span><span class="n">rngs</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">_build_lstm_stack</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rngs</span><span class="p">:</span> <span class="n">nnx</span><span class="p">.</span><span class="n">Rngs</span><span class="p">):</span>
        <span class="n">cells</span> <span class="o">=</span> <span class="p">[]</span>
        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">config</span><span class="p">.</span><span class="n">num_layers</span><span class="p">):</span>
            <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
                <span class="k">if</span> <span class="bp">self</span><span class="p">.</span><span class="n">synthesis_mode</span><span class="p">:</span>
                    <span class="c1"># so if we're in synthesis mode, then we need to add the alphabet size
</span>                    <span class="c1"># and the 3 dimensions of the input stroke
</span>                    <span class="c1"># that's because our alphabet size is the number of characters in our alphabet
</span>                    <span class="c1"># and the 3 dimensions of the input stroke are the x, y, and eos values
</span>                    <span class="n">in_size</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">config</span><span class="p">.</span><span class="n">alphabet_size</span> <span class="o">+</span> <span class="mi">3</span>
                <span class="k">else</span><span class="p">:</span>
                    <span class="n">in_size</span> <span class="o">=</span> <span class="mi">3</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="c1"># similar in both (just in synthesis we only care if we need to expand by the alphabet size)
</span>                <span class="n">in_size</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">config</span><span class="p">.</span><span class="n">hidden_size</span> <span class="o">+</span> <span class="mi">3</span>
                <span class="k">if</span> <span class="bp">self</span><span class="p">.</span><span class="n">synthesis_mode</span><span class="p">:</span>
                    <span class="n">in_size</span> <span class="o">+=</span> <span class="bp">self</span><span class="p">.</span><span class="n">config</span><span class="p">.</span><span class="n">alphabet_size</span>

            <span class="c1"># ok... being lazy but this is just standard LSTM
</span>            <span class="n">cells</span><span class="p">.</span><span class="n">append</span><span class="p">(</span>
                <span class="p">{</span><span class="s">"linear"</span><span class="p">:</span> <span class="n">nnx</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_size</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">config</span><span class="p">.</span><span class="n">hidden_size</span><span class="p">,</span> <span class="mi">4</span> <span class="o">*</span> <span class="bp">self</span><span class="p">.</span><span class="n">config</span><span class="p">.</span><span class="n">hidden_size</span><span class="p">,</span> <span class="n">rngs</span><span class="o">=</span><span class="n">rngs</span><span class="p">)}</span>
            <span class="p">)</span>
        <span class="k">return</span> <span class="n">cells</span>

    <span class="k">def</span> <span class="nf">lstm_cell</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">jnp</span><span class="p">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">h</span><span class="p">:</span> <span class="n">jnp</span><span class="p">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">c</span><span class="p">:</span> <span class="n">jnp</span><span class="p">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">layer_idx</span><span class="p">:</span> <span class="nb">int</span>
    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">jnp</span><span class="p">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">jnp</span><span class="p">.</span><span class="n">ndarray</span><span class="p">]:</span>
        <span class="c1"># just think about this as grabbing the W and b for our matrix mults
</span>        <span class="n">linear</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">lstm_cells</span><span class="p">[</span><span class="n">layer_idx</span><span class="p">][</span><span class="s">"linear"</span><span class="p">]</span>

        <span class="n">combined</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">x</span><span class="p">,</span> <span class="n">h</span><span class="p">],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
        <span class="n">gates</span> <span class="o">=</span> <span class="n">linear</span><span class="p">(</span><span class="n">combined</span><span class="p">)</span>

        <span class="n">i</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">g</span><span class="p">,</span> <span class="n">o</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">gates</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>

        <span class="c1"># activations
</span>        <span class="n">i</span> <span class="o">=</span> <span class="n">nnx</span><span class="p">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
        <span class="n">f</span> <span class="o">=</span> <span class="n">nnx</span><span class="p">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">f</span><span class="p">)</span>
        <span class="n">g</span> <span class="o">=</span> <span class="n">nnx</span><span class="p">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">g</span><span class="p">)</span>
        <span class="n">o</span> <span class="o">=</span> <span class="n">nnx</span><span class="p">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">o</span><span class="p">)</span>

        <span class="c1"># get new LSTM cell state
</span>        <span class="n">c_new</span> <span class="o">=</span> <span class="n">f</span> <span class="o">*</span> <span class="n">c</span> <span class="o">+</span> <span class="n">i</span> <span class="o">*</span> <span class="n">g</span>
        <span class="n">h_new</span> <span class="o">=</span> <span class="n">o</span> <span class="o">*</span> <span class="n">nnx</span><span class="p">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">c_new</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">h_new</span><span class="p">,</span> <span class="n">c_new</span></code></pre></figure>

</div>

  </div>
</div>

<!-- prettier-ignore-end -->

<h2 id="gaussian-mixture-models">Gaussian Mixture Models</h2>

<h3 id="theory-1">Theory</h3>

<p><img src="https://miro.medium.com/v2/resize:fit:996/1*kJYirC6ewCqX1M6UiXmLHQ.gif" alt="gmm-viz" class="center-super-shrink lightbox-image" /></p>

<div class="image-caption"><a href="https://miro.medium.com/v2/resize:fit:996/1*kJYirC6ewCqX1M6UiXmLHQ.gif">reference</a></div>
<p><br /></p>

<p><a href="https://brilliant.org/wiki/gaussian-mixture-model/">Gaussian Mixture Models</a> are an unsupervised technique to learn an underlying probabilistic model.</p>

<p>Brilliant has an incredible explanation walking through the theory <a href="https://brilliant.org/wiki/gaussian-mixture-model/">here</a>. I’d encourage you to check it out, but at a very high level:</p>

<ol>
  <li>A number of Gaussians is specified by the user</li>
  <li>The algo learns various parameters that represent the data while maximizing the likelihood of seeing such data</li>
</ol>

<p>So if we have $k$ components, for a <em>multivariate</em> Gaussian mixture model, we’ll learn $k$ means, $k$ variances, $k$ mixture weights, and $k$ correlations through expectation maximization (EM).</p>

<p>From <a href="https://brilliant.org/">Brilliant</a>, there are really two steps for the EM step:</p>

<blockquote>
  <p>The first step, known as the expectation step or E step, consists of calculating the expectation of the component assignments $C_k$ for each data point $x_i \in X$ given the model parameters $\phi_k, \mu_k$ , and $\sigma_k$ .</p>

  <p>The second step is known as the maximization step or M step, which consists of maximizing the expectations calculated in the E step with respect to the model parameters. This step consists of updating the values $\phi_k, \mu_k$ , and $\sigma_k$ .</p>
</blockquote>
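<p>To make the E and M steps above concrete, here’s a minimal sketch of one EM iteration for a one-dimensional two-component mixture. This is purely illustrative (plain numpy, variable names are mine) and is not part of the handwriting model itself:</p>

```python
import numpy as np

def em_step(x, pi, mu, sigma):
    """One EM iteration for a 1-D Gaussian mixture.

    x: (N,) data; pi, mu, sigma: (K,) mixture weights, means, std devs.
    """
    # E step: responsibilities r[n, k] = Pr(component k | x_n)
    dens = np.exp(-0.5 * ((x[:, None] - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))
    r = pi * dens
    r /= r.sum(axis=1, keepdims=True)

    # M step: re-estimate phi (pi), mu, sigma from the responsibilities
    nk = r.sum(axis=0)
    pi_new = nk / len(x)
    mu_new = (r * x[:, None]).sum(axis=0) / nk
    sigma_new = np.sqrt((r * (x[:, None] - mu_new) ** 2).sum(axis=0) / nk)
    return pi_new, mu_new, sigma_new

# two well-separated clusters; EM should recover their means
rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(-2, 0.5, 500), rng.normal(3, 1.0, 500)])
pi, mu, sigma = np.ones(2) / 2, np.array([-1.0, 1.0]), np.ones(2)
for _ in range(50):
    pi, mu, sigma = em_step(x, pi, mu, sigma)
```

After 50 iterations the learned means sit close to the true cluster centers of -2 and 3.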

<h3 id="code-1">Code</h3>

<p>There’s actually not a whole lot of code to provide here. GMMs are more of a technique that we’ll combine with the output of a neural network, which leads us smoothly into the next section.</p>

<h2 id="mixture-density-networks">Mixture Density Networks</h2>

<h3 id="theory-2">Theory</h3>

<p><a href="https://www.katnoria.com/mdn/">Mixture Density Networks</a> combine a neural network with a GMM: the network’s outputs parametrize a mixture probability distribution.</p>

<p><img src="https://towardsdatascience.com/wp-content/uploads/2024/05/1UKuoYsGWis22cOV7KpLjVg.png" alt="mdn-viz" class="basic-center lightbox-image" /></p>

<div class="image-caption">Courtesy of <a href="https://towardsdatascience.com/wp-content/uploads/2024/05/1UKuoYsGWis22cOV7KpLjVg.png">Towards Data Science</a></div>
<p><br /></p>

<p>Per our paper:</p>

<blockquote>
  <p>The idea is relatively simple - we take the output from a neural network and parametrize the learned parameters of the GMM. The result is that we can infer probabilistic prediction from our learned parameters. If our neural network is reasonably predicting where the next point might be, the GMM will then learn probabilistic parameters that model the distribution of the next point. This is different in a few key aspects. Namely, we now have target values because our data is sequential. Therefore, when we feed in our targets, we minimize the log likelihood based on those expectations, thus altering the GMM portion of the model to learn the predicted values.</p>
</blockquote>

<p>More or less though, the problem we’re trying to solve is predicting the next input given our output vector. Essentially, we’re asking for $\text{Pr}(x_{t+1} | y_t)$. I’m not going to show the proof (we didn’t in our paper either), but the equation for the conditional probability is shown below:</p>

\[\begin{align}
\text{Pr}(x_{t+1} | y_t) = \sum_{j=1}^{M} \pi_{j}^t \mathcal{N} (x_{t+1} \mid \mu_j^t, \sigma_j^t, \rho_j^t)
\end{align}
\tag{5}\]

<p>where</p>

\[\begin{align}
\mathcal{N}(x \mid \mu, \sigma, \rho) = \frac{1}{2\pi \sigma_1 \sigma_2 \sqrt[]{1-\rho^2}} \exp \left[\frac{-Z}{2(1-\rho^2)}\right]
\end{align}
\tag{6}\]

<p>and</p>

\[\begin{align}
Z = \frac{(x_1 - \mu_1)^2 }{\sigma_1^2} + \frac{(x_2 - \mu_2)^2}{\sigma_2^2} - \frac{2\rho (x_1 - \mu_1) (x_2 - \mu_2) }{\sigma_1 \sigma_2}
\end{align}
\tag{7}\]
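<p>Equations (6) and (7) translate almost directly into code. Here’s a hedged numpy sketch of the bivariate density (function and argument names are mine, for illustration only):</p>

```python
import numpy as np

def bivariate_normal(x1, x2, mu1, mu2, sigma1, sigma2, rho):
    """Bivariate Gaussian density N(x | mu, sigma, rho) per eqs. (6) and (7)."""
    # Z from eq. (7)
    z = ((x1 - mu1) ** 2 / sigma1 ** 2
         + (x2 - mu2) ** 2 / sigma2 ** 2
         - 2 * rho * (x1 - mu1) * (x2 - mu2) / (sigma1 * sigma2))
    # normalizer from eq. (6)
    norm = 2 * np.pi * sigma1 * sigma2 * np.sqrt(1 - rho ** 2)
    return np.exp(-z / (2 * (1 - rho ** 2))) / norm
```

A quick sanity check: with $\rho = 0$ and unit variances, the density at the mean is $1 / (2\pi)$, the product of two standard univariate normals at their means.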

<p>Now, there’s a slight variation here because we have a handwriting specific <em>end-of-stroke</em> parameter. So we modify our conditional probability formula to result in our final calculation of:</p>

\[\begin{align}
\textrm{Pr}(x_{t+1} \mid y_t ) = \sum\limits_{j=1}\limits^{M} \pi_j^t \; \mathcal{N} (x_{t+1} \mid \mu_j^t, \sigma_j^t, \rho_j^t)
\begin{cases}
e_t &amp; \textrm{if } (x_{t+1})_3 = 1 \\
1-e_t &amp; \textrm{otherwise}
\end{cases}
\end{align}
\tag{8}\]

<p>And that’s it! That’s our final probability output from the MDN. Once we have this, performing our expectation maximization is simple, as the loss function we choose to minimize is just:</p>

\[\begin{align}
\mathcal{L}(\mathbf{x}) = - \sum\limits_{t=1}^{T} \log \textrm{Pr}(x_{t+1} \mid y_t)
\end{align}
\tag{9}\]
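<p>Putting equations (8) and (9) together, the per-timestep loss can be sketched as follows. Again, this is an illustrative numpy version (names like <code class="language-plaintext highlighter-rouge">mdn_loss_t</code> are mine, not from the actual model code), assuming the network has already produced the mixture parameters and the end-of-stroke probability $e_t$:</p>

```python
import numpy as np

def mdn_loss_t(x_next, pi, mu1, mu2, sigma1, sigma2, rho, e_t, eps=1e-8):
    """Negative log-likelihood of one target point, per eqs. (8) and (9).

    x_next: (x1, x2, eos_flag); pi, mu*, sigma*, rho: (M,) mixture params;
    e_t: scalar end-of-stroke probability from the network.
    """
    x1, x2, eos = x_next
    # bivariate density of each mixture component (eqs. 6-7)
    z = ((x1 - mu1) ** 2 / sigma1 ** 2
         + (x2 - mu2) ** 2 / sigma2 ** 2
         - 2 * rho * (x1 - mu1) * (x2 - mu2) / (sigma1 * sigma2))
    dens = np.exp(-z / (2 * (1 - rho ** 2))) / (
        2 * np.pi * sigma1 * sigma2 * np.sqrt(1 - rho ** 2))
    # mixture over M components, weighted by pi (eq. 8)
    mixture = np.sum(pi * dens)
    # Bernoulli end-of-stroke factor (eq. 8)
    bernoulli = e_t if eos == 1 else 1 - e_t
    # negative log-likelihood for this timestep (one term of eq. 9)
    return -np.log(mixture * bernoulli + eps)
```

Summing this over all $T$ timesteps gives the sequence loss $\mathcal{L}(\mathbf{x})$; the small <code class="language-plaintext highlighter-rouge">eps</code> guards against $\log 0$, mirroring the NaN-avoidance clipping in the real model.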

<h3 id="code-2">Code</h3>

<p>Here’s the corresponding code section for my mixture density network.</p>

<!-- prettier-ignore-start -->

<div class="code-toggle">
  <div class="code-toggle__tabs">
    <button class="code-toggle__tab code-toggle__tab--active" data-tab="tensorflow">TensorFlow</button>
    <button class="code-toggle__tab" data-tab="jax">JAX</button>
  </div>
  <div class="code-toggle__content">
    <div class="code-toggle__pane code-toggle__pane--active" data-pane="tensorflow">

<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">class</span> <span class="nc">MixtureDensityLayer</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">layers</span><span class="p">.</span><span class="n">Layer</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span>
        <span class="n">num_components</span><span class="p">,</span>
        <span class="n">name</span><span class="o">=</span><span class="s">"mdn"</span><span class="p">,</span>
        <span class="n">temperature</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span>
        <span class="n">enable_regularization</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span>
        <span class="n">sigma_reg_weight</span><span class="o">=</span><span class="mf">0.01</span><span class="p">,</span>
        <span class="n">rho_reg_weight</span><span class="o">=</span><span class="mf">0.01</span><span class="p">,</span>
        <span class="n">entropy_reg_weight</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span>
        <span class="o">**</span><span class="n">kwargs</span><span class="p">,</span>
    <span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">MixtureDensityLayer</span><span class="p">,</span> <span class="bp">self</span><span class="p">).</span><span class="n">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">num_components</span> <span class="o">=</span> <span class="n">num_components</span>
        <span class="c1"># The number of parameters per mixture component: 2 means, 2 standard deviations, 1 correlation, 1 weight , 1 for eos
</span>        <span class="c1"># so that's our constant num_mixture_components_per_component
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">output_dim</span> <span class="o">=</span> <span class="n">num_components</span> <span class="o">*</span> <span class="n">NUM_MIXTURE_COMPONENTS_PER_COMPONENT</span> <span class="o">+</span> <span class="mi">1</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">mod_name</span> <span class="o">=</span> <span class="n">name</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">temperature</span> <span class="o">=</span> <span class="n">temperature</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">enable_regularization</span> <span class="o">=</span> <span class="n">enable_regularization</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">sigma_reg_weight</span> <span class="o">=</span> <span class="n">sigma_reg_weight</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">rho_reg_weight</span> <span class="o">=</span> <span class="n">rho_reg_weight</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">entropy_reg_weight</span> <span class="o">=</span> <span class="n">entropy_reg_weight</span>

    <span class="k">def</span> <span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span>
        <span class="n">graves_initializer</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">initializers</span><span class="p">.</span><span class="n">TruncatedNormal</span><span class="p">(</span><span class="n">mean</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">stddev</span><span class="o">=</span><span class="mf">0.075</span><span class="p">)</span>

        <span class="bp">self</span><span class="p">.</span><span class="n">input_units</span> <span class="o">=</span> <span class="n">input_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
        <span class="c1"># weights
</span>        <span class="c1"># lots of weight initialization here... could simplify here too
</span>
        <span class="c1"># biases
</span>        <span class="c1"># lots of bias initialization here... could simplify this part by just doing a massive 
</span>        <span class="c1"># and splitting... see the code if you're curious
</span>        <span class="nb">super</span><span class="p">().</span><span class="n">build</span><span class="p">(</span><span class="n">input_shape</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
        <span class="n">temperature</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">training</span> <span class="k">else</span> <span class="bp">self</span><span class="p">.</span><span class="n">temperature</span>

        <span class="n">pi_logits</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">W_pi</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">b_pi</span>
        <span class="n">pi</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">pi_logits</span> <span class="o">/</span> <span class="n">temperature</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>  <span class="c1"># [B, T, K]
</span>        <span class="c1"># clipping here... I was getting cooked by NaN creep
</span>        <span class="n">pi</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">clip_by_value</span><span class="p">(</span><span class="n">pi</span><span class="p">,</span> <span class="mf">1e-6</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">)</span>

        <span class="n">mu</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">W_mu</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">b_mu</span>  <span class="c1"># [B, T, 2K]
</span>        <span class="n">mu1</span><span class="p">,</span> <span class="n">mu2</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">mu</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>

        <span class="n">log_sigma</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">W_sigma</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">b_sigma</span>  <span class="c1"># [B, T, 2K]
</span>        <span class="c1"># again, this might be overkill but seems realistic for clipping
</span>        <span class="n">log_sigma</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">clip_by_value</span><span class="p">(</span><span class="n">log_sigma</span><span class="p">,</span> <span class="o">-</span><span class="mf">5.0</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">)</span>
        <span class="n">sigma</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">log_sigma</span><span class="p">)</span>
        <span class="n">sigma1</span><span class="p">,</span> <span class="n">sigma2</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">sigma</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>

        <span class="n">rho_raw</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">W_rho</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">b_rho</span>
        <span class="n">rho</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">rho_raw</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.9</span>

        <span class="n">eos_logit</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">W_eos</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">b_eos</span>

        <span class="k">return</span> <span class="n">tf</span><span class="p">.</span><span class="n">concat</span><span class="p">([</span><span class="n">pi</span><span class="p">,</span> <span class="n">mu1</span><span class="p">,</span> <span class="n">mu2</span><span class="p">,</span> <span class="n">sigma1</span><span class="p">,</span> <span class="n">sigma2</span><span class="p">,</span> <span class="n">rho</span><span class="p">,</span> <span class="n">eos_logit</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span></code></pre></figure>

</div>
<div class="code-toggle__pane" data-pane="jax">

<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">class</span> <span class="nc">HandwritingModel</span><span class="p">(</span><span class="n">nnx</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span>
        <span class="n">config</span><span class="p">:</span> <span class="n">ModelConfig</span><span class="p">,</span>
        <span class="n">rngs</span><span class="p">:</span> <span class="n">nnx</span><span class="p">.</span><span class="n">Rngs</span><span class="p">,</span>
        <span class="n">synthesis_mode</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="bp">False</span><span class="p">,</span>
    <span class="p">)</span> <span class="o">-&gt;</span> <span class="bp">None</span><span class="p">:</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">config</span> <span class="o">=</span> <span class="n">config</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">synthesis_mode</span> <span class="o">=</span> <span class="n">synthesis_mode</span>

        <span class="c1"># rngs is basically a set of random keys / number generators
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">lstm_cells</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">_build_lstm_stack</span><span class="p">(</span><span class="n">rngs</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">synthesis_mode</span><span class="p">:</span>
            <span class="c1"># i mean we really only care about synthesis mode, but in
</span>            <span class="c1"># this case we can make it explicit that if we have it then we should add our
</span>            <span class="c1"># attention layer
</span>            <span class="bp">self</span><span class="p">.</span><span class="n">attention_layer</span> <span class="o">=</span> <span class="n">nnx</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span>
                <span class="n">config</span><span class="p">.</span><span class="n">hidden_size</span> <span class="o">+</span> <span class="n">config</span><span class="p">.</span><span class="n">alphabet_size</span> <span class="o">+</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span> <span class="o">*</span> <span class="n">config</span><span class="p">.</span><span class="n">num_attention_gaussians</span><span class="p">,</span> <span class="n">rngs</span><span class="o">=</span><span class="n">rngs</span>
            <span class="p">)</span>

        <span class="c1"># mdn portion
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">mdn_layer</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">_build_mdn_head</span><span class="p">(</span><span class="n">rngs</span><span class="p">)</span>

    <span class="c1">#....
</span>    
    <span class="k">def</span> <span class="nf">__call__</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span>
        <span class="n">inputs</span><span class="p">:</span> <span class="n">jnp</span><span class="p">.</span><span class="n">ndarray</span><span class="p">,</span>
        <span class="n">char_seq</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">jnp</span><span class="p">.</span><span class="n">ndarray</span><span class="p">]</span> <span class="o">=</span> <span class="bp">None</span><span class="p">,</span>
        <span class="n">char_lens</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">jnp</span><span class="p">.</span><span class="n">ndarray</span><span class="p">]</span> <span class="o">=</span> <span class="bp">None</span><span class="p">,</span>
        <span class="n">initial_state</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">RNNState</span><span class="p">]</span> <span class="o">=</span> <span class="bp">None</span><span class="p">,</span>
        <span class="n">return_state</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="bp">False</span><span class="p">,</span>
    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">jnp</span><span class="p">.</span><span class="n">ndarray</span><span class="p">:</span>
        <span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">.</span><span class="n">shape</span>

        <span class="k">if</span> <span class="n">initial_state</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
            <span class="n">h</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="p">.</span><span class="n">config</span><span class="p">.</span><span class="n">num_layers</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">config</span><span class="p">.</span><span class="n">hidden_size</span><span class="p">),</span> <span class="n">inputs</span><span class="p">.</span><span class="n">dtype</span><span class="p">)</span>
            <span class="n">c</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">h</span><span class="p">)</span>
            <span class="n">kappa</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">config</span><span class="p">.</span><span class="n">num_attention_gaussians</span><span class="p">),</span> <span class="n">inputs</span><span class="p">.</span><span class="n">dtype</span><span class="p">)</span>
            <span class="n">window</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">config</span><span class="p">.</span><span class="n">alphabet_size</span><span class="p">),</span> <span class="n">inputs</span><span class="p">.</span><span class="n">dtype</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">h</span><span class="p">,</span> <span class="n">c</span> <span class="o">=</span> <span class="n">initial_state</span><span class="p">.</span><span class="n">hidden</span><span class="p">,</span> <span class="n">initial_state</span><span class="p">.</span><span class="n">cell</span>
            <span class="n">kappa</span><span class="p">,</span> <span class="n">window</span> <span class="o">=</span> <span class="n">initial_state</span><span class="p">.</span><span class="n">kappa</span><span class="p">,</span> <span class="n">initial_state</span><span class="p">.</span><span class="n">window</span>

        <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="n">carry</span><span class="p">,</span> <span class="n">x_t</span><span class="p">):</span>
            <span class="n">h</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span> <span class="n">kappa</span><span class="p">,</span> <span class="n">window</span> <span class="o">=</span> <span class="n">carry</span>
            <span class="n">h_layers</span> <span class="o">=</span> <span class="p">[]</span>
            <span class="n">c_layers</span> <span class="o">=</span> <span class="p">[]</span>

            <span class="c1"># layer1
</span>            <span class="k">if</span> <span class="bp">self</span><span class="p">.</span><span class="n">synthesis_mode</span><span class="p">:</span>
                <span class="n">layer1_input</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">window</span><span class="p">,</span> <span class="n">x_t</span><span class="p">],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="n">layer1_input</span> <span class="o">=</span> <span class="n">x_t</span>

            <span class="n">h1</span><span class="p">,</span> <span class="n">c1</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">lstm_cell</span><span class="p">(</span><span class="n">layer1_input</span><span class="p">,</span> <span class="n">h</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">c</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="mi">0</span><span class="p">)</span>
            <span class="n">h_layers</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">h1</span><span class="p">)</span>
            <span class="n">c_layers</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">c1</span><span class="p">)</span>

            <span class="c1"># layer1 -&gt; attention
</span>            <span class="k">if</span> <span class="bp">self</span><span class="p">.</span><span class="n">synthesis_mode</span> <span class="ow">and</span> <span class="n">char_seq</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span> <span class="ow">and</span> <span class="n">char_lens</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
                <span class="n">window</span><span class="p">,</span> <span class="n">kappa</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">compute_attention</span><span class="p">(</span><span class="n">h1</span><span class="p">,</span> <span class="n">kappa</span><span class="p">,</span> <span class="n">window</span><span class="p">,</span> <span class="n">x_t</span><span class="p">,</span> <span class="n">char_seq</span><span class="p">,</span> <span class="n">char_lens</span><span class="p">)</span>

            <span class="c1"># attention -&gt; layer2 and layer3
</span>            <span class="k">for</span> <span class="n">layer_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">config</span><span class="p">.</span><span class="n">num_layers</span><span class="p">):</span>
                <span class="k">if</span> <span class="bp">self</span><span class="p">.</span><span class="n">synthesis_mode</span><span class="p">:</span>
                    <span class="n">layer_input</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">x_t</span><span class="p">,</span> <span class="n">h_layers</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">window</span><span class="p">],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
                <span class="k">else</span><span class="p">:</span>
                    <span class="n">layer_input</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">x_t</span><span class="p">,</span> <span class="n">h_layers</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>

                <span class="n">h_new</span><span class="p">,</span> <span class="n">c_new</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">lstm_cell</span><span class="p">(</span><span class="n">layer_input</span><span class="p">,</span> <span class="n">h</span><span class="p">[</span><span class="n">layer_idx</span><span class="p">],</span> <span class="n">c</span><span class="p">[</span><span class="n">layer_idx</span><span class="p">],</span> <span class="n">layer_idx</span><span class="p">)</span>
                <span class="n">h_layers</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">h_new</span><span class="p">)</span>
                <span class="n">c_layers</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">c_new</span><span class="p">)</span>

            <span class="n">h_new</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">stack</span><span class="p">(</span><span class="n">h_layers</span><span class="p">)</span>
            <span class="n">c_new</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">stack</span><span class="p">(</span><span class="n">c_layers</span><span class="p">)</span>

            <span class="c1"># mdn output from final hidden state
</span>            <span class="n">mdn_out</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">mdn_layer</span><span class="p">(</span><span class="n">h_layers</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>  <span class="c1"># [B, 6M+1]
</span>
            <span class="k">return</span> <span class="p">(</span><span class="n">h_new</span><span class="p">,</span> <span class="n">c_new</span><span class="p">,</span> <span class="n">kappa</span><span class="p">,</span> <span class="n">window</span><span class="p">),</span> <span class="n">mdn_out</span>

        <span class="c1"># this was the major unlock for JAX performance
</span>        <span class="c1"># it allows us to vectorize the computation over the time dimension
</span>        <span class="c1"># transpose inputs from [B, T, 3] to [T, B, 3] for scan
</span>        <span class="n">inputs_transposed</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">.</span><span class="n">swapaxes</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
        <span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span> <span class="n">kappa</span><span class="p">,</span> <span class="n">window</span><span class="p">),</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">lax</span><span class="p">.</span><span class="n">scan</span><span class="p">(</span><span class="n">step</span><span class="p">,</span> <span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span> <span class="n">kappa</span><span class="p">,</span> <span class="n">window</span><span class="p">),</span> <span class="n">inputs_transposed</span><span class="p">)</span>

        <span class="c1"># transpose back
</span>        <span class="n">outputs</span> <span class="o">=</span> <span class="n">outputs</span><span class="p">.</span><span class="n">swapaxes</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>

        <span class="k">if</span> <span class="n">return_state</span><span class="p">:</span>
            <span class="n">final_state</span> <span class="o">=</span> <span class="n">RNNState</span><span class="p">(</span><span class="n">hidden</span><span class="o">=</span><span class="n">h</span><span class="p">,</span> <span class="n">cell</span><span class="o">=</span><span class="n">c</span><span class="p">,</span> <span class="n">kappa</span><span class="o">=</span><span class="n">kappa</span><span class="p">,</span> <span class="n">window</span><span class="o">=</span><span class="n">window</span><span class="p">)</span>
            <span class="k">return</span> <span class="n">outputs</span><span class="p">,</span> <span class="n">final_state</span>

        <span class="k">return</span> <span class="n">outputs</span></code></pre></figure>

</div>

  </div>
</div>

<!-- prettier-ignore-end -->
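<p>The comment in the JAX model above calls <code>jax.lax.scan</code> the major unlock, so it's worth spelling out the contract it expects. Below is a stdlib-only sketch of the semantics (not the real implementation): a step function <code>step(carry, x) -&gt; (new_carry, y)</code> threaded over the leading (time) axis. <code>jax.lax.scan</code> compiles that same loop into a single fused XLA op, which is what makes a sequential RNN fast under <code>jit</code>.</p>

```python
# Stdlib-only sketch of the carry/output contract that jax.lax.scan expects:
# step(carry, x) -> (new_carry, y). The real lax.scan runs this loop as one
# fused XLA operation over the leading (time) axis instead of a Python loop.
def scan(step, init, xs):
    carry, ys = init, []
    for x in xs:
        carry, y = step(carry, x)
        ys.append(y)
    return carry, ys

# Toy step: running sum as the carry, squared input as the per-step output.
final, outs = scan(lambda c, x: (c + x, x * x), 0, [1, 2, 3])
# final == 6, outs == [1, 4, 9]
```

<p>In the model above, the carry is the tuple <code>(h, c, kappa, window)</code> and the per-step output is the MDN parameter vector.</p>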

<h2 id="mixture-density-loss">Mixture Density Loss</h2>

<h3 id="theory-3">Theory</h3>

<p>I already covered the theory above, so I won’t rehash it here; I just found it cleaner to split the code between the network itself and the loss calculation. Note that there’s some fairly aggressive clipping going on, because I hit a lot of numerical instability with JAX. I suspect it was partly my implementation and partly the clipping itself, but the symptom was the loss silently collapsing to 0 rather than the program crashing. To be clear, loss going to zero was not desired.</p>
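<p>Before the code, a quick illustration (plain Python, hypothetical values) of why working in log space matters here: with log-densities as small as a long stroke sequence can produce, naively exponentiating underflows to zero and the log of the mixture becomes <code>-inf</code>, while the max-shifted log-sum-exp recovers the right value.</p>

```python
import math

def logsumexp(vals):
    # subtract the max before exponentiating so the largest term is exp(0) == 1
    m = max(vals)
    return m + math.log(sum(math.exp(v - m) for v in vals))

# component log-densities small enough that exp() underflows in float64
logs = [-1000.0, -1001.0, -1002.0]

naive_sum = sum(math.exp(v) for v in logs)  # underflows to exactly 0.0
stable = logsumexp(logs)                    # ~= -999.59
```

<p>This shift is what <code>tf.reduce_logsumexp</code> does internally, which is why the loss below combines <code>log_pi + log_gauss</code> instead of multiplying densities and taking the log afterward.</p>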

<h3 id="code-3">Code</h3>

<!-- prettier-ignore-start -->

<div class="code-toggle">
  <div class="code-toggle__tabs">
    <button class="code-toggle__tab code-toggle__tab--active" data-tab="tensorflow">TensorFlow</button>
    <button class="code-toggle__tab" data-tab="jax">JAX</button>
  </div>
  <div class="code-toggle__content">
    <div class="code-toggle__pane code-toggle__pane--active" data-pane="tensorflow">

<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="o">@</span><span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">utils</span><span class="p">.</span><span class="n">register_keras_serializable</span><span class="p">()</span>
<span class="k">def</span> <span class="nf">mdn_loss</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">,</span> <span class="n">stroke_lengths</span><span class="p">,</span> <span class="n">num_components</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-8</span><span class="p">):</span>
    <span class="s">"""
    Mixture density negative log-likelihood computed fully in log-space.

    y_true: [B, T, 3]  -&gt; (x, y, eos ∈ {0,1})
    y_pred: [B, T, 6*K + 1] -&gt; (pi, mu1, mu2, sigma1, sigma2, rho, eos_logit)

    The log space change was because I was getting absolutely torched by the
    gradients when using the normal space.
    """</span>
    <span class="n">out_pi</span><span class="p">,</span> <span class="n">mu1</span><span class="p">,</span> <span class="n">mu2</span><span class="p">,</span> <span class="n">sigma1</span><span class="p">,</span> <span class="n">sigma2</span><span class="p">,</span> <span class="n">rho</span><span class="p">,</span> <span class="n">eos_logits</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">split</span><span class="p">(</span>
        <span class="n">y_pred</span><span class="p">,</span>
        <span class="p">[</span><span class="n">num_components</span><span class="p">]</span> <span class="o">*</span> <span class="mi">6</span> <span class="o">+</span> <span class="p">[</span><span class="mi">1</span><span class="p">],</span>
        <span class="n">axis</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span>
    <span class="p">)</span>

    <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">eos_targets</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>

    <span class="n">sigma1</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">clip_by_value</span><span class="p">(</span><span class="n">sigma1</span><span class="p">,</span> <span class="mf">1e-2</span><span class="p">,</span> <span class="mf">10.0</span><span class="p">)</span>
    <span class="n">sigma2</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">clip_by_value</span><span class="p">(</span><span class="n">sigma2</span><span class="p">,</span> <span class="mf">1e-2</span><span class="p">,</span> <span class="mf">10.0</span><span class="p">)</span>
    <span class="n">rho</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">clip_by_value</span><span class="p">(</span><span class="n">rho</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.9</span><span class="p">,</span> <span class="mf">0.9</span><span class="p">)</span>
    <span class="n">out_pi</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">clip_by_value</span><span class="p">(</span><span class="n">out_pi</span><span class="p">,</span> <span class="n">eps</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">)</span>

    <span class="n">log_2pi</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">constant</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="mf">2.0</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">pi</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">y_pred</span><span class="p">.</span><span class="n">dtype</span><span class="p">)</span>
    <span class="n">one_minus_rho2</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">clip_by_value</span><span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">tf</span><span class="p">.</span><span class="n">square</span><span class="p">(</span><span class="n">rho</span><span class="p">),</span> <span class="n">eps</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">)</span>
    <span class="n">log_one_minus_rho2</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">math</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">one_minus_rho2</span><span class="p">)</span>
    <span class="n">z1</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mu1</span><span class="p">)</span> <span class="o">/</span> <span class="n">sigma1</span>
    <span class="n">z2</span> <span class="o">=</span> <span class="p">(</span><span class="n">y</span> <span class="o">-</span> <span class="n">mu2</span><span class="p">)</span> <span class="o">/</span> <span class="n">sigma2</span>

    <span class="n">quad</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">square</span><span class="p">(</span><span class="n">z1</span><span class="p">)</span> <span class="o">+</span> <span class="n">tf</span><span class="p">.</span><span class="n">square</span><span class="p">(</span><span class="n">z2</span><span class="p">)</span> <span class="o">-</span> <span class="mf">2.0</span> <span class="o">*</span> <span class="n">rho</span> <span class="o">*</span> <span class="n">z1</span> <span class="o">*</span> <span class="n">z2</span>
    <span class="n">quad</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">clip_by_value</span><span class="p">(</span><span class="n">quad</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">100.0</span><span class="p">)</span>
    <span class="n">log_norm</span> <span class="o">=</span> <span class="o">-</span><span class="p">(</span><span class="n">log_2pi</span> <span class="o">+</span> <span class="n">tf</span><span class="p">.</span><span class="n">math</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">sigma1</span><span class="p">)</span> <span class="o">+</span> <span class="n">tf</span><span class="p">.</span><span class="n">math</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">sigma2</span><span class="p">)</span> <span class="o">+</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">log_one_minus_rho2</span><span class="p">)</span>
    <span class="n">log_gauss</span> <span class="o">=</span> <span class="n">log_norm</span> <span class="o">-</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">quad</span> <span class="o">/</span> <span class="n">one_minus_rho2</span>  <span class="c1"># [B, T, K]
</span>
    <span class="c1"># log mixture via log-sum-exp
</span>    <span class="n">log_pi</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">math</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">out_pi</span><span class="p">)</span>  <span class="c1"># [B, T, K]
</span>    <span class="n">log_gmm</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">reduce_logsumexp</span><span class="p">(</span><span class="n">log_pi</span> <span class="o">+</span> <span class="n">log_gauss</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>  <span class="c1"># [B, T]
</span>
    <span class="c1"># bce (bernoulli cross entropy) to help out with stability
</span>    <span class="n">eos_nll</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">sigmoid_cross_entropy_with_logits</span><span class="p">(</span><span class="n">labels</span><span class="o">=</span><span class="n">eos_targets</span><span class="p">,</span> <span class="n">logits</span><span class="o">=</span><span class="n">eos_logits</span><span class="p">)</span>  <span class="c1"># [B, T, 1]
</span>    <span class="n">eos_nll</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">eos_nll</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>  <span class="c1"># [B, T]
</span>
    <span class="n">nll</span> <span class="o">=</span> <span class="o">-</span><span class="n">log_gmm</span> <span class="o">+</span> <span class="n">eos_nll</span>  <span class="c1"># [B, T]
</span>    <span class="k">if</span> <span class="n">stroke_lengths</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
        <span class="n">mask</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">sequence_mask</span><span class="p">(</span><span class="n">stroke_lengths</span><span class="p">,</span> <span class="n">maxlen</span><span class="o">=</span><span class="n">tf</span><span class="p">.</span><span class="n">shape</span><span class="p">(</span><span class="n">y_true</span><span class="p">)[</span><span class="mi">1</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">nll</span><span class="p">.</span><span class="n">dtype</span><span class="p">)</span>
        <span class="n">nll</span> <span class="o">=</span> <span class="n">nll</span> <span class="o">*</span> <span class="n">mask</span>
        <span class="n">denom</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">mask</span><span class="p">),</span> <span class="mf">1.0</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">tf</span><span class="p">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">nll</span><span class="p">)</span> <span class="o">/</span> <span class="n">denom</span>

    <span class="k">return</span> <span class="n">tf</span><span class="p">.</span><span class="n">reduce_mean</span><span class="p">(</span><span class="n">nll</span><span class="p">)</span></code></pre></figure>

</div>
<div class="code-toggle__pane" data-pane="jax">

<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">compute_loss</span><span class="p">(</span>
    <span class="n">predictions</span><span class="p">:</span> <span class="n">jnp</span><span class="p">.</span><span class="n">ndarray</span><span class="p">,</span>
    <span class="n">targets</span><span class="p">:</span> <span class="n">jnp</span><span class="p">.</span><span class="n">ndarray</span><span class="p">,</span>
    <span class="n">lengths</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">jnp</span><span class="p">.</span><span class="n">ndarray</span><span class="p">]</span> <span class="o">=</span> <span class="bp">None</span><span class="p">,</span>
    <span class="n">num_mixtures</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">NUM_BIVARIATE_GAUSSIAN_MIXTURE_COMPONENTS</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">jnp</span><span class="p">.</span><span class="n">ndarray</span><span class="p">:</span>
    <span class="n">nc</span> <span class="o">=</span> <span class="n">num_mixtures</span>
    <span class="n">pi</span><span class="p">,</span> <span class="n">mu1</span><span class="p">,</span> <span class="n">mu2</span><span class="p">,</span> <span class="n">s1</span><span class="p">,</span> <span class="n">s2</span><span class="p">,</span> <span class="n">rho</span><span class="p">,</span> <span class="n">eos_pred</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">predictions</span><span class="p">,</span> <span class="p">[</span><span class="n">nc</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">nc</span><span class="p">,</span> <span class="mi">3</span> <span class="o">*</span> <span class="n">nc</span><span class="p">,</span> <span class="mi">4</span> <span class="o">*</span> <span class="n">nc</span><span class="p">,</span> <span class="mi">5</span> <span class="o">*</span> <span class="n">nc</span><span class="p">,</span> <span class="mi">6</span> <span class="o">*</span> <span class="n">nc</span><span class="p">],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>

    <span class="n">pi</span> <span class="o">=</span> <span class="n">nnx</span><span class="p">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">pi</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">s1</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">jnp</span><span class="p">.</span><span class="n">clip</span><span class="p">(</span><span class="n">s1</span><span class="p">,</span> <span class="o">-</span><span class="mi">10</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
    <span class="n">s2</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">jnp</span><span class="p">.</span><span class="n">clip</span><span class="p">(</span><span class="n">s2</span><span class="p">,</span> <span class="o">-</span><span class="mi">10</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
    <span class="n">rho</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">clip</span><span class="p">(</span><span class="n">nnx</span><span class="p">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">rho</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.95</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.95</span><span class="p">,</span> <span class="mf">0.95</span><span class="p">)</span>
    <span class="n">eos_pred</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">clip</span><span class="p">(</span><span class="n">nnx</span><span class="p">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">eos_pred</span><span class="p">),</span> <span class="mf">1e-8</span><span class="p">,</span> <span class="mi">1</span> <span class="o">-</span> <span class="mf">1e-8</span><span class="p">)</span>

    <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">eos</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">targets</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>

    <span class="c1"># major change: compute log probabilities with better numerical stability
</span>    <span class="n">rho_sq</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">clip</span><span class="p">(</span><span class="n">rho</span><span class="o">**</span><span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mf">0.9025</span><span class="p">)</span>
    <span class="n">one_minus_rho_sq</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">maximum</span><span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">rho_sq</span><span class="p">,</span> <span class="mf">1e-6</span><span class="p">)</span>
    <span class="n">norm</span> <span class="o">=</span> <span class="o">-</span><span class="n">jnp</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">jnp</span><span class="p">.</span><span class="n">pi</span><span class="p">)</span> <span class="o">-</span> <span class="n">jnp</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">s1</span><span class="p">)</span> <span class="o">-</span> <span class="n">jnp</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">s2</span><span class="p">)</span> <span class="o">-</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">jnp</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">one_minus_rho_sq</span><span class="p">)</span>

    <span class="n">z1</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mu1</span><span class="p">)</span> <span class="o">/</span> <span class="n">jnp</span><span class="p">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">s1</span><span class="p">,</span> <span class="mf">1e-6</span><span class="p">)</span>
    <span class="n">z2</span> <span class="o">=</span> <span class="p">(</span><span class="n">y</span> <span class="o">-</span> <span class="n">mu2</span><span class="p">)</span> <span class="o">/</span> <span class="n">jnp</span><span class="p">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">s2</span><span class="p">,</span> <span class="mf">1e-6</span><span class="p">)</span>

    <span class="n">exp_term</span> <span class="o">=</span> <span class="o">-</span><span class="mf">0.5</span> <span class="o">/</span> <span class="n">one_minus_rho_sq</span> <span class="o">*</span> <span class="p">(</span><span class="n">z1</span><span class="o">**</span><span class="mi">2</span> <span class="o">+</span> <span class="n">z2</span><span class="o">**</span><span class="mi">2</span> <span class="o">-</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">rho</span> <span class="o">*</span> <span class="n">z1</span> <span class="o">*</span> <span class="n">z2</span><span class="p">)</span>
    <span class="n">exp_term</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">clip</span><span class="p">(</span><span class="n">exp_term</span><span class="p">,</span> <span class="o">-</span><span class="mi">50</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
    <span class="n">log_probs</span> <span class="o">=</span> <span class="n">norm</span> <span class="o">+</span> <span class="n">exp_term</span>
    <span class="n">log_pi</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">jnp</span><span class="p">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">pi</span><span class="p">,</span> <span class="mf">1e-8</span><span class="p">))</span>
    <span class="n">log_mixture</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">logsumexp</span><span class="p">(</span><span class="n">log_pi</span> <span class="o">+</span> <span class="n">log_probs</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>

    <span class="n">eos_loss</span> <span class="o">=</span> <span class="o">-</span><span class="n">jnp</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">eos</span> <span class="o">*</span> <span class="n">jnp</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">eos_pred</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">eos</span><span class="p">)</span> <span class="o">*</span> <span class="n">jnp</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">eos_pred</span><span class="p">),</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>

    <span class="n">loss</span> <span class="o">=</span> <span class="o">-</span><span class="n">log_mixture</span> <span class="o">+</span> <span class="n">eos_loss</span>
    <span class="n">loss</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">where</span><span class="p">(</span><span class="n">jnp</span><span class="p">.</span><span class="n">isnan</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span> <span class="o">|</span> <span class="n">jnp</span><span class="p">.</span><span class="n">isinf</span><span class="p">(</span><span class="n">loss</span><span class="p">),</span> <span class="mf">0.0</span><span class="p">,</span> <span class="n">loss</span><span class="p">)</span>

    <span class="k">if</span> <span class="n">lengths</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
        <span class="n">mask</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="n">predictions</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span> <span class="o">&lt;</span> <span class="n">lengths</span><span class="p">[:,</span> <span class="bp">None</span><span class="p">]</span>
        <span class="n">loss</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">where</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">loss</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">)</span>
        <span class="n">total_loss</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span> <span class="o">/</span> <span class="n">jnp</span><span class="p">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">jnp</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">mask</span><span class="p">),</span> <span class="mi">1</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">jnp</span><span class="p">.</span><span class="n">where</span><span class="p">(</span><span class="n">jnp</span><span class="p">.</span><span class="n">isnan</span><span class="p">(</span><span class="n">total_loss</span><span class="p">)</span> <span class="o">|</span> <span class="n">jnp</span><span class="p">.</span><span class="n">isinf</span><span class="p">(</span><span class="n">total_loss</span><span class="p">),</span> <span class="mf">0.0</span><span class="p">,</span> <span class="n">total_loss</span><span class="p">)</span>

    <span class="n">mean_loss</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">mean</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">jnp</span><span class="p">.</span><span class="n">where</span><span class="p">(</span><span class="n">jnp</span><span class="p">.</span><span class="n">isnan</span><span class="p">(</span><span class="n">mean_loss</span><span class="p">)</span> <span class="o">|</span> <span class="n">jnp</span><span class="p">.</span><span class="n">isinf</span><span class="p">(</span><span class="n">mean_loss</span><span class="p">),</span> <span class="mf">0.0</span><span class="p">,</span> <span class="n">mean_loss</span><span class="p">)</span></code></pre></figure>

</div>

  </div>
</div>
<!-- prettier-ignore-end -->

<h2 id="attention-mechanism">Attention Mechanism</h2>

<h3 id="theory-4">Theory</h3>

<p>The attention mechanism really only comes into play with the Synthesis Network, which sadly <a href="https://www.linkedin.com/in/tom-wilmots-030781a6/">Tom</a> and I never got to in college. The idea (as with most attention schemes) is to tell the model more specifically where to focus. This isn’t the transformer-style attention from the famous “Attention Is All You Need” paper; instead, a mixture of Gaussians probabilistically indicates which part of the character sequence the model should be attending to. We one-hot encode the input characters to give them a clean numerical representation. The question we’re essentially answering is “I see a ‘w’ character, so roughly how far along the stroke sequence should I be while writing it?”, which also helps the model decide when to terminate.</p>

<p>The mathematical representation is here:</p>

<blockquote>
  <p>Given a length $U$ character sequence $\mathbf{c}$ and a length $T$ data sequence $\mathbf{x}$, the soft window $w_t$ into $\mathbf{c}$ at timestep $t$ ($1 \leq t \leq T$) is defined by the following discrete convolution with a mixture of $K$ Gaussian functions</p>

\[\begin{align}
\phi(t, u) &amp;= \sum_{k=1}^K \alpha^k_t\exp\left(-\beta_t^k\left(\kappa_t^k-u\right)^2\right)\\
w_t &amp;= \sum_{u=1}^U \phi(t, u)c_u
\end{align}\]

  <p>where $\phi(t, u)$ is the <em>window weight</em> of $c_u$ at timestep $t$.</p>

  <p>Intuitively, the $\kappa_t$ parameters control the location of the window, the $\beta_t$ parameters control the width of the window and the $\alpha_t$ parameters control the importance of the window within the mixture.</p>

  <p>The size of the soft window vectors is the same as the size of the character vectors $c_u$ (assuming a one-hot encoding, this will be the number of characters in the alphabet).</p>

  <p>Note that the window mixture is not normalised and hence does not determine a probability distribution; however the window weight $\phi(t, u)$ can be loosely interpreted as the network’s belief that it is writing character $c_u$ at time $t$.</p>
</blockquote>
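<p>Before the full TF/JAX implementations below, here is a minimal NumPy sketch of those two equations for a single timestep. All the values (<code>alpha</code>, <code>beta</code>, <code>kappa</code>, the toy one-hot matrix, and the sizes) are made up purely for illustration:</p>

```python
import numpy as np

# Toy sizes: U = char-seq length, A = alphabet size, K = num Gaussians.
U, A, K = 5, 4, 2
rng = np.random.default_rng(0)
c = np.eye(A)[rng.integers(0, A, size=U)]  # [U, A] one-hot character sequence

# Hypothetical window parameters for one timestep t.
alpha = np.array([1.0, 0.5])   # importance of each Gaussian in the mixture
beta = np.array([2.0, 1.0])    # width of each Gaussian
kappa = np.array([1.5, 3.0])   # location of each Gaussian

# phi(t, u) = sum_k alpha_k * exp(-beta_k * (kappa_k - u)^2)
u = np.arange(1, U + 1)                    # character positions 1..U
phi = (alpha[:, None] * np.exp(-beta[:, None] * (kappa[:, None] - u) ** 2)).sum(axis=0)  # [U]

# w_t = sum_u phi(t, u) * c_u  -- the soft window into the character sequence
w_t = phi @ c  # [A]

print(phi.shape, w_t.shape)  # (5,) (4,)
```

<p>Note that, exactly as Graves calls out, <code>phi</code> is not normalized into a probability distribution; each one-hot row of <code>c</code> sums to one, so <code>w_t</code> carries the same total mass as <code>phi</code>.</p>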

<h3 id="code-4">Code</h3>

<!-- prettier-ignore-start -->

<div class="code-toggle">
  <div class="code-toggle__tabs">
    <button class="code-toggle__tab code-toggle__tab--active" data-tab="tensorflow">TensorFlow</button>
    <button class="code-toggle__tab" data-tab="jax">JAX</button>
  </div>
  <div class="code-toggle__content">
    <div class="code-toggle__pane code-toggle__pane--active" data-pane="tensorflow">

<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="o">@</span><span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">utils</span><span class="p">.</span><span class="n">register_keras_serializable</span><span class="p">()</span>
<span class="k">class</span> <span class="nc">AttentionMechanism</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">layers</span><span class="p">.</span><span class="n">Layer</span><span class="p">):</span>
    <span class="s">"""
    Attention mechanism for the handwriting synthesis model.
    This is a version of the attention mechanism used in
    the original paper by Alex Graves. It uses a Gaussian
    window to focus on different parts of the character sequence
    at each time step.

    See section: 5.0 / 5.1
    """</span>

    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_gaussians</span><span class="p">,</span> <span class="n">num_chars</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s">"attention"</span><span class="p">,</span> <span class="n">debug</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="bp">None</span><span class="p">:</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">AttentionMechanism</span><span class="p">,</span> <span class="bp">self</span><span class="p">).</span><span class="n">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">num_gaussians</span> <span class="o">=</span> <span class="n">num_gaussians</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">num_chars</span> <span class="o">=</span> <span class="n">num_chars</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">name_mod</span> <span class="o">=</span> <span class="n">name</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">debug</span> <span class="o">=</span> <span class="n">debug</span>

    <span class="k">def</span> <span class="nf">call</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span>
        <span class="n">inputs</span><span class="p">,</span>  <span class="c1"># shape: [batch_size, num_gaussians, 3]
</span>        <span class="n">prev_kappa</span><span class="p">,</span>  <span class="c1"># shape: [batch_size, num_gaussians]
</span>        <span class="n">char_seq_one_hot</span><span class="p">,</span>  <span class="c1"># shape: [batch_size, char_len, num_chars]
</span>        <span class="n">sequence_lengths</span><span class="p">,</span>  <span class="c1"># shape: [batch_size]
</span>    <span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">tuple</span><span class="p">[</span><span class="n">tf</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">tf</span><span class="p">.</span><span class="n">Tensor</span><span class="p">]:</span>
        <span class="n">raw</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">attention_kernel</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">attention_bias</span>
        <span class="n">alpha_hat</span><span class="p">,</span> <span class="n">beta_hat</span><span class="p">,</span> <span class="n">kappa_hat</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">raw</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>  <span class="c1"># shape: [batch_size, num_gaussians, 1]
</span>
        <span class="n">eps</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">constant</span><span class="p">(</span><span class="mf">1e-6</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">inputs</span><span class="p">.</span><span class="n">dtype</span><span class="p">)</span>
        <span class="n">scaling</span> <span class="o">=</span> <span class="mf">0.1</span>  <span class="c1"># Gentler activation
</span>        <span class="n">alpha</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">softplus</span><span class="p">(</span><span class="n">alpha_hat</span> <span class="o">*</span> <span class="n">scaling</span><span class="p">)</span> <span class="o">+</span> <span class="n">eps</span>  <span class="c1"># [B, G]
</span>        <span class="n">beta</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">softplus</span><span class="p">(</span><span class="n">beta_hat</span> <span class="o">*</span> <span class="n">scaling</span><span class="p">)</span> <span class="o">+</span> <span class="n">eps</span>  <span class="c1"># [B, G]
</span>        <span class="n">dkap</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">softplus</span><span class="p">(</span><span class="n">kappa_hat</span> <span class="o">*</span> <span class="n">scaling</span><span class="p">)</span> <span class="o">+</span> <span class="n">eps</span>

        <span class="n">alpha</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">clip_by_value</span><span class="p">(</span><span class="n">alpha</span><span class="p">,</span> <span class="mf">0.01</span><span class="p">,</span> <span class="mf">10.0</span><span class="p">)</span>
        <span class="n">beta</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">clip_by_value</span><span class="p">(</span><span class="n">beta</span><span class="p">,</span> <span class="mf">0.01</span><span class="p">,</span> <span class="mf">10.0</span><span class="p">)</span>
        <span class="n">dkap</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">clip_by_value</span><span class="p">(</span><span class="n">dkap</span><span class="p">,</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">)</span>

        <span class="n">kappa</span> <span class="o">=</span> <span class="n">prev_kappa</span> <span class="o">+</span> <span class="n">dkap</span>
        <span class="n">kappa</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">clip_by_value</span><span class="p">(</span><span class="n">kappa</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">30.0</span><span class="p">)</span>

        <span class="n">char_len</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">shape</span><span class="p">(</span><span class="n">char_seq_one_hot</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span>
        <span class="n">batch_size</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">shape</span><span class="p">(</span><span class="n">inputs</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
        <span class="n">u</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">cast</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">char_len</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span> <span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">)</span>
        <span class="n">u</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">u</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">])</span>  <span class="c1"># shape: [1, 1, char_len]
</span>        <span class="n">u</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">tile</span><span class="p">(</span><span class="n">u</span><span class="p">,</span> <span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">num_gaussians</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span>  <span class="c1"># shape: [batch_size, num_gaussians, char_len]
</span>
        <span class="n">alpha</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">alpha</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>  <span class="c1"># shape: [batch_size, num_gaussians, 1]
</span>        <span class="n">beta</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">beta</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>  <span class="c1"># shape: [batch_size, num_gaussians, 1]
</span>        <span class="n">kappa</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">kappa</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>  <span class="c1"># shape: [batch_size, num_gaussians, 1]
</span>
        <span class="n">exponent</span> <span class="o">=</span> <span class="o">-</span><span class="n">beta</span> <span class="o">*</span> <span class="n">tf</span><span class="p">.</span><span class="n">square</span><span class="p">(</span><span class="n">kappa</span> <span class="o">-</span> <span class="n">u</span><span class="p">)</span>
        <span class="n">exponent</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">clip_by_value</span><span class="p">(</span><span class="n">exponent</span><span class="p">,</span> <span class="o">-</span><span class="mf">50.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">)</span>
        <span class="n">phi</span> <span class="o">=</span> <span class="n">alpha</span> <span class="o">*</span> <span class="n">tf</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">exponent</span><span class="p">)</span>  <span class="c1"># shape: [batch_size, num_gaussians, char_len]
</span>        <span class="n">phi</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">phi</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>  <span class="c1"># Sum over gaussians: [B, L]
</span>
        <span class="n">sequence_mask</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">sequence_mask</span><span class="p">(</span><span class="n">sequence_lengths</span><span class="p">,</span> <span class="n">maxlen</span><span class="o">=</span><span class="n">char_len</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">)</span>
        <span class="n">phi</span> <span class="o">=</span> <span class="n">phi</span> <span class="o">*</span> <span class="n">sequence_mask</span>  <span class="c1"># mask paddings
</span>
        <span class="n">phi</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">where</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">math</span><span class="p">.</span><span class="n">is_finite</span><span class="p">(</span><span class="n">phi</span><span class="p">),</span> <span class="n">phi</span><span class="p">,</span> <span class="n">tf</span><span class="p">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">phi</span><span class="p">))</span>
        <span class="c1"># we don't normalize here - Graves calls that out specifically!
</span>        <span class="c1"># &gt; Note that the window mixture is not normalised
</span>        <span class="c1"># &gt; and hence does not determine a probability distribution; however the window
</span>        <span class="c1"># &gt; weight φ(t,u) can be loosely interpreted as the network's belief that it is writ-
</span>        <span class="c1"># &gt; ing character cu at time t.
</span>        <span class="c1"># still section 5.1
</span>
        <span class="c1"># window vec
</span>        <span class="n">phi</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">phi</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>  <span class="c1"># shape: [batch_size, char_len, 1]
</span>        <span class="n">w</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">phi</span> <span class="o">*</span> <span class="n">char_seq_one_hot</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>  <span class="c1"># shape: [batch_size, num_chars]
</span>
        <span class="n">w</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">where</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">math</span><span class="p">.</span><span class="n">is_finite</span><span class="p">(</span><span class="n">w</span><span class="p">),</span> <span class="n">w</span><span class="p">,</span> <span class="n">tf</span><span class="p">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">w</span><span class="p">))</span>
        <span class="k">return</span> <span class="n">w</span><span class="p">,</span> <span class="n">kappa</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">0</span><span class="p">]</span></code></pre></figure>

</div>
<div class="code-toggle__pane" data-pane="jax">

<figure class="highlight"><pre><code class="language-python" data-lang="python">    <span class="k">def</span> <span class="nf">compute_attention</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span>
        <span class="n">h</span><span class="p">:</span> <span class="n">jnp</span><span class="p">.</span><span class="n">ndarray</span><span class="p">,</span>  <span class="c1"># [B, H]
</span>        <span class="n">prev_kappa</span><span class="p">:</span> <span class="n">jnp</span><span class="p">.</span><span class="n">ndarray</span><span class="p">,</span>  <span class="c1"># [B, G]
</span>        <span class="n">window</span><span class="p">:</span> <span class="n">jnp</span><span class="p">.</span><span class="n">ndarray</span><span class="p">,</span>  <span class="c1"># [B, A]
</span>        <span class="n">x</span><span class="p">:</span> <span class="n">jnp</span><span class="p">.</span><span class="n">ndarray</span><span class="p">,</span>  <span class="c1"># [B, 3]
</span>        <span class="n">char_seq</span><span class="p">:</span> <span class="n">jnp</span><span class="p">.</span><span class="n">ndarray</span><span class="p">,</span>  <span class="c1"># [B, U, A] one-hot
</span>        <span class="n">char_lens</span><span class="p">:</span> <span class="n">jnp</span><span class="p">.</span><span class="n">ndarray</span><span class="p">,</span>  <span class="c1"># [B] lengths
</span>    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">jnp</span><span class="p">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">jnp</span><span class="p">.</span><span class="n">ndarray</span><span class="p">]:</span>
        <span class="s">"""Compute Gaussian window attention over character sequence."""</span>

        <span class="n">attention_input</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">window</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">h</span><span class="p">],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
        <span class="n">params</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">attention_layer</span><span class="p">(</span><span class="n">attention_input</span><span class="p">)</span>  <span class="c1"># [B, 3G]
</span>        <span class="n">params</span> <span class="o">=</span> <span class="n">nnx</span><span class="p">.</span><span class="n">softplus</span><span class="p">(</span><span class="n">params</span><span class="p">)</span>
        <span class="n">alpha</span><span class="p">,</span> <span class="n">beta</span><span class="p">,</span> <span class="n">kappa_inc</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>

        <span class="c1"># again... probably sliiiiightly overkill
</span>        <span class="n">alpha</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">alpha</span><span class="p">,</span> <span class="mf">1e-4</span><span class="p">)</span>
        <span class="n">beta</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">beta</span><span class="p">,</span> <span class="mf">1e-4</span><span class="p">)</span>
        <span class="n">kappa_inc</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">kappa_inc</span><span class="p">,</span> <span class="mf">1e-4</span><span class="p">)</span>

        <span class="c1"># ok this was a trick from svasquez - the dividing by 25.0
</span>        <span class="c1"># is to help kappa learn given that 25 is roughly the average
</span>        <span class="c1"># number of strokes per sequence
</span>        <span class="n">kappa</span> <span class="o">=</span> <span class="n">prev_kappa</span> <span class="o">+</span> <span class="n">kappa_inc</span> <span class="o">/</span> <span class="mf">25.0</span>

        <span class="n">U</span> <span class="o">=</span> <span class="n">char_seq</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
        <span class="n">positions</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="n">U</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">jnp</span><span class="p">.</span><span class="n">float32</span><span class="p">)[</span><span class="bp">None</span><span class="p">,</span> <span class="bp">None</span><span class="p">,</span> <span class="p">:]</span>  <span class="c1"># [1, 1, U]
</span>        <span class="n">kappa_exp</span> <span class="o">=</span> <span class="n">kappa</span><span class="p">[:,</span> <span class="p">:,</span> <span class="bp">None</span><span class="p">]</span>  <span class="c1"># [B, G, 1]
</span>        <span class="n">alpha_exp</span> <span class="o">=</span> <span class="n">alpha</span><span class="p">[:,</span> <span class="p">:,</span> <span class="bp">None</span><span class="p">]</span>  <span class="c1"># [B, G, 1]
</span>        <span class="n">beta_exp</span> <span class="o">=</span> <span class="n">beta</span><span class="p">[:,</span> <span class="p">:,</span> <span class="bp">None</span><span class="p">]</span>  <span class="c1"># [B, G, 1]
</span>
        <span class="c1"># gaussian window
</span>        <span class="n">phi</span> <span class="o">=</span> <span class="n">alpha_exp</span> <span class="o">*</span> <span class="n">jnp</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="n">beta_exp</span> <span class="o">*</span> <span class="p">(</span><span class="n">kappa_exp</span> <span class="o">-</span> <span class="n">positions</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span>  <span class="c1"># [B, G, U]
</span>        <span class="n">phi</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">phi</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>

        <span class="c1"># mask out positions beyond char_lens
</span>        <span class="n">mask</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="n">U</span><span class="p">)[</span><span class="bp">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">&lt;</span> <span class="n">char_lens</span><span class="p">[:,</span> <span class="bp">None</span><span class="p">]</span>  <span class="c1"># [B, U]
</span>        <span class="n">phi</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">where</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">phi</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">)</span>

        <span class="c1"># Graves notes the window mixture is not normalised;
</span>        <span class="c1"># we normalise here anyway for numerical stability
</span>        <span class="n">phi</span> <span class="o">=</span> <span class="n">phi</span> <span class="o">/</span> <span class="p">(</span><span class="n">jnp</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">phi</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span> <span class="o">+</span> <span class="mf">1e-8</span><span class="p">)</span>

        <span class="c1"># Apply to character sequence
</span>        <span class="c1"># window: [B, A] = sum_u phi[b,u]*char_seq[b,u,:]
</span>        <span class="n">window_new</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">einsum</span><span class="p">(</span><span class="s">"bu,bua-&gt;ba"</span><span class="p">,</span> <span class="n">phi</span><span class="p">,</span> <span class="n">char_seq</span><span class="p">)</span>

        <span class="k">return</span> <span class="n">window_new</span><span class="p">,</span> <span class="n">kappa</span></code></pre></figure>

</div>

  </div>
</div>

<!-- prettier-ignore-end -->

<h2 id="stacked-lstm">Stacked LSTM</h2>

<h3 id="theory-5">Theory</h3>

<p>The main distinction between Graves’s setup and a standard LSTM is that Graves uses a <em>cascade</em> of LSTMs. We still use the MDN to generate the probabilistic prediction, but the neural network feeding it is now this cascade of LSTMs.</p>

<p>Per our paper:</p>

<blockquote>
  <p>The LSTM cascade buys us a few different things. As Graves aptly points out, it mitigates the vanishing gradient problem even more greatly than a single LSTM could. This is because of the skip-connections. All hidden layers have access to the input and all hidden layers are also directly connected to the output node. As a result, there are less processing steps from the bottom of the network to the top.</p>
</blockquote>

<p>So it looks something like this:</p>

<p><img src="/images/generative-handwriting/graves_stacked_lstm.png" alt="graves-stacked-lstm" class="center-super-shrink lightbox-image" /></p>

<div class="image-caption">Courtesy of Alex Graves's <a href="https://arxiv.org/abs/1308.0850">paper</a></div>
<p><br /></p>

<p>One thing to note is that the input dimensionality grows as we stack these hidden layers. Tom and I broke this down in our paper here:</p>

<blockquote>
  <p>Let’s observe the $x_{t-1}$ input. $h_{t-1}^1$ only has $x_{t-1}$ as its input which is in $\mathbb{R}^3$ because $(x, y, eos)$. However, we also pass our input $x_{t-1}$ into $h_{t-1}^2$. We assume that we simply concatenate the original input and the output of the first hidden layer. Because LSTMs do not scale dimensionality, we know the output is going to be in $\mathbb{R}^3$ as well. Therefore, after this concatenation, the input into the second hidden layer will be in $\mathbb{R}^6$. We can follow this process through and see that, the input to the third hidden layer will be in $\mathbb{R}^9$. Finally, we concatenate all of the LSTM cells (i.e. the hidden layers) together, thus getting a final dimension of $\mathbb{R}^{18}$ fed into our MDN. Note, this is for $m=3$ hidden layers, but more generally, we can observe the relation as</p>

\[\begin{align} \textrm{final dimension} = k \frac{m(m+1)}{2} \end{align}\]
</blockquote>
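<p>As a sanity check on that bookkeeping, here is a minimal NumPy sketch (not the actual model; it assumes, as in the quote above, that each hidden layer preserves the dimensionality of its input, with <code class="language-plaintext highlighter-rouge">np.tanh</code> standing in for an LSTM layer):</p>

```python
import numpy as np


def cascade_forward(x, layers):
    """Graves-style cascade: every hidden layer sees the raw input x
    (skip connection) concatenated with the previous layer's output,
    and all hidden outputs are concatenated before the output layer."""
    hidden = []
    prev = None
    for layer in layers:
        layer_in = x if prev is None else np.concatenate([x, prev])
        prev = layer(layer_in)  # dimensionality-preserving stand-in for an LSTM
        hidden.append(prev)
    return np.concatenate(hidden)


# (x, y, eos) input in R^3, m = 3 hidden layers:
# layer inputs are R^3, R^6, R^9, so the concatenated output is in R^18
out = cascade_forward(np.zeros(3), [np.tanh] * 3)
print(out.shape)  # -> (18,)
```

The shape matches the relation above: $k \frac{m(m+1)}{2} = 3 \cdot \frac{3 \cdot 4}{2} = 18$.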

<p>My take: I actually like how I constructed the TensorFlow version more from a composability perspective. I think the code is cleaner. However, c’est la vie.</p>

<h3 id="code-5">Code</h3>

<p>This is where TensorFlow’s <code class="language-plaintext highlighter-rouge">cell</code> vs. <code class="language-plaintext highlighter-rouge">layer</code> distinction was very nice.</p>

<p>You can see here how the parts all come together smoothly. The custom RNN cell takes the <code class="language-plaintext highlighter-rouge">lstm_cells</code> (which are stacked) and operates on each individual time step without our having to introduce another explicit <code class="language-plaintext highlighter-rouge">for</code> loop. This is beneficial because it lets TensorFlow batch the per-step computation and run it efficiently on a GPU when the time comes.</p>
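<p>If you haven’t seen the cell-vs-layer pattern before, here’s a stripped-down sketch using stock Keras cells (rather than the custom peephole and attention cells in the actual model below), just to show the split:</p>

```python
import tensorflow as tf

# A "cell" defines the computation for one time step; tf.keras.layers.RNN
# unrolls it over the whole sequence, handling batching for us.
cells = [tf.keras.layers.LSTMCell(8) for _ in range(3)]
stacked_cell = tf.keras.layers.StackedRNNCells(cells)
rnn_layer = tf.keras.layers.RNN(stacked_cell, return_sequences=True)

x = tf.zeros([2, 5, 3])  # [batch, time, (x, y, eos)]
y = rnn_layer(x)
print(y.shape)  # (2, 5, 8) -- the last cell's units, one output per time step
```

Note that <code class="language-plaintext highlighter-rouge">StackedRNNCells</code> here is a plain vertical stack; the actual model below uses a custom cell to get Graves’s skip connections and the attention window.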

<!-- prettier-ignore-start -->

<div class="code-toggle">
  <div class="code-toggle__tabs">
    <button class="code-toggle__tab code-toggle__tab--active" data-tab="tensorflow">TensorFlow</button>
    <button class="code-toggle__tab" data-tab="jax">JAX</button>
  </div>
  <div class="code-toggle__content">
    <div class="code-toggle__pane code-toggle__pane--active" data-pane="tensorflow">

<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="o">@</span><span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">utils</span><span class="p">.</span><span class="n">register_keras_serializable</span><span class="p">()</span>
<span class="k">class</span> <span class="nc">DeepHandwritingSynthesisModel</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">Model</span><span class="p">):</span>
    <span class="s">"""
    A similar implementation to the previous model,
    but now we're throwing the good old attention mechanism back into the mix.
    """</span>

    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span>
        <span class="n">units</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">NUM_LSTM_CELLS_PER_HIDDEN_LAYER</span><span class="p">,</span>
        <span class="n">num_layers</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">NUM_LSTM_HIDDEN_LAYERS</span><span class="p">,</span>
        <span class="n">num_mixture_components</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">NUM_BIVARIATE_GAUSSIAN_MIXTURE_COMPONENTS</span><span class="p">,</span>
        <span class="n">num_chars</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">ALPHABET_SIZE</span><span class="p">,</span>
        <span class="n">num_attention_gaussians</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">NUM_ATTENTION_GAUSSIAN_COMPONENTS</span><span class="p">,</span>
        <span class="n">gradient_clip_value</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="n">GRADIENT_CLIP_VALUE</span><span class="p">,</span>
        <span class="n">enable_mdn_regularization</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="bp">False</span><span class="p">,</span>
        <span class="n">debug</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span>
        <span class="o">**</span><span class="n">kwargs</span><span class="p">,</span>
    <span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">units</span> <span class="o">=</span> <span class="n">units</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">num_layers</span> <span class="o">=</span> <span class="n">num_layers</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">num_mixture_components</span> <span class="o">=</span> <span class="n">num_mixture_components</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">num_chars</span> <span class="o">=</span> <span class="n">num_chars</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">num_attention_gaussians</span> <span class="o">=</span> <span class="n">num_attention_gaussians</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">gradient_clip_value</span> <span class="o">=</span> <span class="n">gradient_clip_value</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">enable_mdn_regularization</span> <span class="o">=</span> <span class="n">enable_mdn_regularization</span>
        <span class="c1"># Store LSTM cells as tracked attributes instead of list
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">lstm_cells</span> <span class="o">=</span> <span class="p">[]</span>
        <span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_layers</span><span class="p">):</span>
            <span class="n">cell</span> <span class="o">=</span> <span class="n">LSTMPeepholeCell</span><span class="p">(</span><span class="n">units</span><span class="p">,</span> <span class="n">idx</span><span class="p">)</span>
            <span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="sa">f</span><span class="s">'lstm_cell_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s">'</span><span class="p">,</span> <span class="n">cell</span><span class="p">)</span>  <span class="c1"># Register as tracked attribute
</span>            <span class="bp">self</span><span class="p">.</span><span class="n">lstm_cells</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">cell</span><span class="p">)</span>

        <span class="bp">self</span><span class="p">.</span><span class="n">attention_mechanism</span> <span class="o">=</span> <span class="n">AttentionMechanism</span><span class="p">(</span><span class="n">num_gaussians</span><span class="o">=</span><span class="n">num_attention_gaussians</span><span class="p">,</span> <span class="n">num_chars</span><span class="o">=</span><span class="n">num_chars</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">attention_rnn_cell</span> <span class="o">=</span> <span class="n">AttentionRNNCell</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">lstm_cells</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">attention_mechanism</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">num_chars</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">rnn_layer</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">layers</span><span class="p">.</span><span class="n">RNN</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">attention_rnn_cell</span><span class="p">,</span> <span class="n">return_sequences</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">mdn_layer</span> <span class="o">=</span> <span class="n">MixtureDensityLayer</span><span class="p">(</span><span class="n">num_mixture_components</span><span class="p">,</span> <span class="n">enable_regularization</span><span class="o">=</span><span class="n">enable_mdn_regularization</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">debug</span> <span class="o">=</span> <span class="n">debug</span>

        <span class="c1"># metrics
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">loss_tracker</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">metrics</span><span class="p">.</span><span class="n">Mean</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s">"loss"</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">nll_tracker</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">metrics</span><span class="p">.</span><span class="n">Mean</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s">"nll"</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">eos_accuracy_tracker</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">metrics</span><span class="p">.</span><span class="n">Mean</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s">"eos_accuracy"</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">eos_prob_tracker</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">metrics</span><span class="p">.</span><span class="n">Mean</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s">"eos_prob"</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">call</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">tf</span><span class="p">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">training</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">bool</span><span class="p">]</span> <span class="o">=</span> <span class="bp">None</span><span class="p">,</span> <span class="n">mask</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">tf</span><span class="p">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="bp">None</span>
    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">tf</span><span class="p">.</span><span class="n">Tensor</span><span class="p">:</span>
        <span class="n">input_strokes</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">[</span><span class="s">"input_strokes"</span><span class="p">]</span>
        <span class="n">input_chars</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">[</span><span class="s">"input_chars"</span><span class="p">]</span>
        <span class="n">input_char_lens</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">[</span><span class="s">"input_char_lens"</span><span class="p">]</span>

        <span class="c1"># one-hot encode the character sequence and set RNN cell attributes
</span>        <span class="n">char_seq_one_hot</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">one_hot</span><span class="p">(</span><span class="n">input_chars</span><span class="p">,</span> <span class="n">depth</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">num_chars</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">attention_rnn_cell</span><span class="p">.</span><span class="n">char_seq_one_hot</span> <span class="o">=</span> <span class="n">char_seq_one_hot</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">attention_rnn_cell</span><span class="p">.</span><span class="n">char_seq_len</span> <span class="o">=</span> <span class="n">input_char_lens</span>

        <span class="c1"># initial states
</span>        <span class="n">batch_size</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">shape</span><span class="p">(</span><span class="n">input_strokes</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
        <span class="n">initial_states</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">attention_rnn_cell</span><span class="p">.</span><span class="n">get_initial_state</span><span class="p">(</span><span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">input_strokes</span><span class="p">.</span><span class="n">dtype</span><span class="p">)</span>
        <span class="n">initial_states_list</span> <span class="o">=</span> <span class="p">[</span>
            <span class="n">initial_states</span><span class="p">[</span><span class="s">"lstm_0_h"</span><span class="p">],</span>
            <span class="n">initial_states</span><span class="p">[</span><span class="s">"lstm_0_c"</span><span class="p">],</span>
            <span class="n">initial_states</span><span class="p">[</span><span class="s">"lstm_1_h"</span><span class="p">],</span>
            <span class="n">initial_states</span><span class="p">[</span><span class="s">"lstm_1_c"</span><span class="p">],</span>
            <span class="n">initial_states</span><span class="p">[</span><span class="s">"lstm_2_h"</span><span class="p">],</span>
            <span class="n">initial_states</span><span class="p">[</span><span class="s">"lstm_2_c"</span><span class="p">],</span>
            <span class="n">initial_states</span><span class="p">[</span><span class="s">"kappa"</span><span class="p">],</span>
            <span class="n">initial_states</span><span class="p">[</span><span class="s">"w"</span><span class="p">],</span>
        <span class="p">]</span>

        <span class="c1"># then through our RNN (which wraps stacked LSTM cells + attention mechanism)
</span>        <span class="c1"># and then through our MDN layer
</span>        <span class="n">outputs</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">rnn_layer</span><span class="p">(</span><span class="n">input_strokes</span><span class="p">,</span> <span class="n">initial_state</span><span class="o">=</span><span class="n">initial_states_list</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="n">training</span><span class="p">)</span>
        <span class="n">final_output</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">mdn_layer</span><span class="p">(</span><span class="n">outputs</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">final_output</span></code></pre></figure>

</div>
<div class="code-toggle__pane" data-pane="jax">

<figure class="highlight"><pre><code class="language-python" data-lang="python">    <span class="k">def</span> <span class="nf">__call__</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span>
        <span class="n">inputs</span><span class="p">:</span> <span class="n">jnp</span><span class="p">.</span><span class="n">ndarray</span><span class="p">,</span>
        <span class="n">char_seq</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">jnp</span><span class="p">.</span><span class="n">ndarray</span><span class="p">]</span> <span class="o">=</span> <span class="bp">None</span><span class="p">,</span>
        <span class="n">char_lens</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">jnp</span><span class="p">.</span><span class="n">ndarray</span><span class="p">]</span> <span class="o">=</span> <span class="bp">None</span><span class="p">,</span>
        <span class="n">initial_state</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">RNNState</span><span class="p">]</span> <span class="o">=</span> <span class="bp">None</span><span class="p">,</span>
        <span class="n">return_state</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="bp">False</span><span class="p">,</span>
    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">jnp</span><span class="p">.</span><span class="n">ndarray</span><span class="p">:</span>
        <span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">.</span><span class="n">shape</span>

        <span class="k">if</span> <span class="n">initial_state</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
            <span class="n">h</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="p">.</span><span class="n">config</span><span class="p">.</span><span class="n">num_layers</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">config</span><span class="p">.</span><span class="n">hidden_size</span><span class="p">),</span> <span class="n">inputs</span><span class="p">.</span><span class="n">dtype</span><span class="p">)</span>
            <span class="n">c</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">h</span><span class="p">)</span>
            <span class="n">kappa</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">config</span><span class="p">.</span><span class="n">num_attention_gaussians</span><span class="p">),</span> <span class="n">inputs</span><span class="p">.</span><span class="n">dtype</span><span class="p">)</span>
            <span class="n">window</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">config</span><span class="p">.</span><span class="n">alphabet_size</span><span class="p">),</span> <span class="n">inputs</span><span class="p">.</span><span class="n">dtype</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">h</span><span class="p">,</span> <span class="n">c</span> <span class="o">=</span> <span class="n">initial_state</span><span class="p">.</span><span class="n">hidden</span><span class="p">,</span> <span class="n">initial_state</span><span class="p">.</span><span class="n">cell</span>
            <span class="n">kappa</span><span class="p">,</span> <span class="n">window</span> <span class="o">=</span> <span class="n">initial_state</span><span class="p">.</span><span class="n">kappa</span><span class="p">,</span> <span class="n">initial_state</span><span class="p">.</span><span class="n">window</span>

        <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="n">carry</span><span class="p">,</span> <span class="n">x_t</span><span class="p">):</span>
            <span class="n">h</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span> <span class="n">kappa</span><span class="p">,</span> <span class="n">window</span> <span class="o">=</span> <span class="n">carry</span>
            <span class="n">h_layers</span> <span class="o">=</span> <span class="p">[]</span>
            <span class="n">c_layers</span> <span class="o">=</span> <span class="p">[]</span>

            <span class="c1"># layer1
</span>            <span class="k">if</span> <span class="bp">self</span><span class="p">.</span><span class="n">synthesis_mode</span><span class="p">:</span>
                <span class="n">layer1_input</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">window</span><span class="p">,</span> <span class="n">x_t</span><span class="p">],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="n">layer1_input</span> <span class="o">=</span> <span class="n">x_t</span>

            <span class="n">h1</span><span class="p">,</span> <span class="n">c1</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">lstm_cell</span><span class="p">(</span><span class="n">layer1_input</span><span class="p">,</span> <span class="n">h</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">c</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="mi">0</span><span class="p">)</span>
            <span class="n">h_layers</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">h1</span><span class="p">)</span>
            <span class="n">c_layers</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">c1</span><span class="p">)</span>

            <span class="c1"># layer1 -&gt; attention
</span>            <span class="k">if</span> <span class="bp">self</span><span class="p">.</span><span class="n">synthesis_mode</span> <span class="ow">and</span> <span class="n">char_seq</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span> <span class="ow">and</span> <span class="n">char_lens</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
                <span class="n">window</span><span class="p">,</span> <span class="n">kappa</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">compute_attention</span><span class="p">(</span><span class="n">h1</span><span class="p">,</span> <span class="n">kappa</span><span class="p">,</span> <span class="n">window</span><span class="p">,</span> <span class="n">x_t</span><span class="p">,</span> <span class="n">char_seq</span><span class="p">,</span> <span class="n">char_lens</span><span class="p">)</span>

            <span class="c1"># attention -&gt; layer2 and layer3
</span>            <span class="k">for</span> <span class="n">layer_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">config</span><span class="p">.</span><span class="n">num_layers</span><span class="p">):</span>
                <span class="k">if</span> <span class="bp">self</span><span class="p">.</span><span class="n">synthesis_mode</span><span class="p">:</span>
                    <span class="n">layer_input</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">x_t</span><span class="p">,</span> <span class="n">h_layers</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">window</span><span class="p">],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
                <span class="k">else</span><span class="p">:</span>
                    <span class="n">layer_input</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">x_t</span><span class="p">,</span> <span class="n">h_layers</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>

                <span class="n">h_new</span><span class="p">,</span> <span class="n">c_new</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">lstm_cell</span><span class="p">(</span><span class="n">layer_input</span><span class="p">,</span> <span class="n">h</span><span class="p">[</span><span class="n">layer_idx</span><span class="p">],</span> <span class="n">c</span><span class="p">[</span><span class="n">layer_idx</span><span class="p">],</span> <span class="n">layer_idx</span><span class="p">)</span>
                <span class="n">h_layers</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">h_new</span><span class="p">)</span>
                <span class="n">c_layers</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">c_new</span><span class="p">)</span>

            <span class="n">h_new</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">stack</span><span class="p">(</span><span class="n">h_layers</span><span class="p">)</span>
            <span class="n">c_new</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">stack</span><span class="p">(</span><span class="n">c_layers</span><span class="p">)</span>

            <span class="c1"># mdn output from final hidden state
</span>            <span class="n">mdn_out</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">mdn_layer</span><span class="p">(</span><span class="n">h_layers</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>  <span class="c1"># [B, 6M+1]
</span>
            <span class="k">return</span> <span class="p">(</span><span class="n">h_new</span><span class="p">,</span> <span class="n">c_new</span><span class="p">,</span> <span class="n">kappa</span><span class="p">,</span> <span class="n">window</span><span class="p">),</span> <span class="n">mdn_out</span>

        <span class="c1"># this was the major unlock for JAX performance
</span>        <span class="c1"># jax.lax.scan compiles the time loop into one XLA op instead of retracing each step
</span>        <span class="c1"># transpose inputs from [B, T, 3] to [T, B, 3] for scan
</span>        <span class="n">inputs_transposed</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">.</span><span class="n">swapaxes</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
        <span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span> <span class="n">kappa</span><span class="p">,</span> <span class="n">window</span><span class="p">),</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">lax</span><span class="p">.</span><span class="n">scan</span><span class="p">(</span><span class="n">step</span><span class="p">,</span> <span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span> <span class="n">kappa</span><span class="p">,</span> <span class="n">window</span><span class="p">),</span> <span class="n">inputs_transposed</span><span class="p">)</span>

        <span class="c1"># transpose back
</span>        <span class="n">outputs</span> <span class="o">=</span> <span class="n">outputs</span><span class="p">.</span><span class="n">swapaxes</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>

        <span class="k">if</span> <span class="n">return_state</span><span class="p">:</span>
            <span class="n">final_state</span> <span class="o">=</span> <span class="n">RNNState</span><span class="p">(</span><span class="n">hidden</span><span class="o">=</span><span class="n">h</span><span class="p">,</span> <span class="n">cell</span><span class="o">=</span><span class="n">c</span><span class="p">,</span> <span class="n">kappa</span><span class="o">=</span><span class="n">kappa</span><span class="p">,</span> <span class="n">window</span><span class="o">=</span><span class="n">window</span><span class="p">)</span>
            <span class="k">return</span> <span class="n">outputs</span><span class="p">,</span> <span class="n">final_state</span>

        <span class="k">return</span> <span class="n">outputs</span></code></pre></figure>

</div>

  </div>
</div>

<!-- prettier-ignore-end -->

<h2 id="final-result">Final Result</h2>

<p>Alright finally! So what do we have, and what can we do now?</p>

<p>We now feed the output of our LSTM cascade into the GMM to build a probabilistic prediction model for the next stroke point. During training, the GMM is also given the actual next point, so that the deviation can be measured and the loss properly minimized.</p>
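<p>To make that concrete, here is a minimal NumPy sketch (illustrative, not the actual training code) of the per-timestep negative log-likelihood for a bivariate Gaussian mixture plus an end-of-stroke Bernoulli term, following Graves’ formulation:</p>

```python
import numpy as np

def mdn_stroke_nll(target, pi, mu, sigma, rho, eos_p, eos_target, eps=1e-8):
    """Negative log-likelihood of one (dx, dy, eos) step under the MDN.

    target: (2,) actual next-point offset. pi: (M,) mixture weights summing to 1.
    mu: (M, 2) means. sigma: (M, 2) positive std devs. rho: (M,) correlations.
    eos_p: scalar pen-lift probability. eos_target: 0 or 1.
    """
    dx = target[0] - mu[:, 0]
    dy = target[1] - mu[:, 1]
    s1, s2 = sigma[:, 0], sigma[:, 1]
    q = 1.0 - rho ** 2
    # quadratic form for a correlated bivariate Gaussian
    z = (dx / s1) ** 2 + (dy / s2) ** 2 - 2.0 * rho * dx * dy / (s1 * s2)
    density = np.exp(-z / (2.0 * q)) / (2.0 * np.pi * s1 * s2 * np.sqrt(q))
    point_ll = np.log(np.sum(pi * density) + eps)
    eos_ll = np.log((eos_p if eos_target else 1.0 - eos_p) + eps)
    return -(point_ll + eos_ll)
```

<p>A prediction whose mean sits on the actual next point scores a much lower loss than one far away, which is exactly the gradient signal that shapes the mixture during training.</p>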

<h1 id="️-training-results">🏋️ Training Results</h1>

<h2 id="vast-ai-gpu-enabled-execution">Vast AI GPU Enabled Execution</h2>

<details>
  <summary style="background-color: #d4edda; padding: 10px; border-radius: 5px; cursor: pointer; color: #155724; font-weight: bold;">
    Vast AI GPU Enabled Running
  </summary>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>

2024-04-21 19:01:02.183969: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
train data found. Loading...
test data found. Loading...
valid2 data found. Loading...
valid1 data found. Loading...
2024-04-21 19:01:04.798925: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1928] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 22455 MB memory: -&gt; device: 0, name: NVIDIA GeForce RTX 3090, pci bus id: 0000:82:00.0, compute capability: 8.6
2024-04-21 19:01:05.887036: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:104] Profiler session initializing.
2024-04-21 19:01:05.887070: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:119] Profiler session started.
2024-04-21 19:01:05.887164: I external/local_xla/xla/backends/profiler/gpu/cupti_tracer.cc:1239] Profiler found 1 GPUs
2024-04-21 19:01:05.917572: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:131] Profiler session tear down.
2024-04-21 19:01:05.917763: I external/local_xla/xla/backends/profiler/gpu/cupti_tracer.cc:1364] CUPTI activity buffer flushed
Epoch 1/10000
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1713726072.109654 2329 service.cc:145] XLA service 0x7ad5bc004600 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1713726072.109731 2329 service.cc:153] StreamExecutor device (0): NVIDIA GeForce RTX 3090, Compute Capability 8.6
2024-04-21 19:01:12.346749: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
W0000 00:00:1713726072.691839 2329 assert_op.cc:38] Ignoring Assert operator assert_greater/Assert/AssertGuard/Assert
W0000 00:00:1713726072.694098 2329 assert_op.cc:38] Ignoring Assert operator assert_greater_1/Assert/AssertGuard/Assert
W0000 00:00:1713726072.696267 2329 assert_op.cc:38] Ignoring Assert operator assert_near/Assert/AssertGuard/Assert
2024-04-21 19:01:13.095183: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:465] Loaded cuDNN version 8906
2024-04-21 19:01:14.883021: W external/local_xla/xla/service/hlo_rematerialization.cc:2941] Can't reduce memory use below 17.97GiB (19297974672 bytes) by rematerialization; only reduced to 20.51GiB (22027581828 bytes), down from 20.67GiB (22193496744 bytes) originally
I0000 00:00:1713726076.329853 2329 device_compiler.h:188] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
167/168 ━━━━━━━━━━━━━━━━━━━━ 0s 685ms/step - loss: 2.8113W0000 00:00:1713726191.427557 2333 assert_op.cc:38] Ignoring Assert operator assert_greater/Assert/AssertGuard/Assert
W0000 00:00:1713726191.429182 2333 assert_op.cc:38] Ignoring Assert operator assert_greater_1/Assert/AssertGuard/Assert
W0000 00:00:1713726191.430622 2333 assert_op.cc:38] Ignoring Assert operator assert_near/Assert/AssertGuard/Assert
2024-04-21 19:03:13.488256: W external/local_xla/xla/service/hlo_rematerialization.cc:2941] Can't reduce memory use below 17.97GiB (19298282069 bytes) by rematerialization; only reduced to 19.75GiB (21203023676 bytes), down from 19.87GiB (21340423652 bytes) originally
168/168 ━━━━━━━━━━━━━━━━━━━━ 0s 709ms/step - loss: 2.8097
Epoch 1: Saving model.

Epoch 1: Loss improved from None to 0.0, saving model.
Model parameters after the 1st epoch:
Model: "deep_handwriting_synthesis_model"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ Layer (type) ┃ Output Shape ┃ Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│ lstm_peephole_cell │ ? │ 764,400 │
│ (LSTMPeepholeCell) │ │ │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ lstm_peephole_cell_1 │ ? │ 1,404,400 │
│ (LSTMPeepholeCell) │ │ │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ lstm_peephole_cell_2 │ ? │ 1,404,400 │
│ (LSTMPeepholeCell) │ │ │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ attention (AttentionMechanism) │ ? │ 14,310 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ attention_rnn_cell │ ? │ 3,587,510 │
│ (AttentionRNNCell) │ │ │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ rnn (RNN) │ ? │ 3,587,510 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ mdn (MixtureDensityLayer) │ ? │ 48,521 │
└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
Total params: 7,272,064 (27.74 MB)
Trainable params: 3,636,031 (13.87 MB)
Non-trainable params: 0 (0.00 B)
Optimizer params: 3,636,033 (13.87 MB)

All parameters:

[[lstm_peephole_kernel1]] shape: (76, 1600)
[[lstm_peephole_recurrent_kernel1]] shape: (400, 1600)
[[lstm_peephole_weights1]] shape: (400, 3)
[[lstm_peephole_bias1]] shape: (1600,)
[[lstm_peephole_kernel2]] shape: (476, 1600)
[[lstm_peephole_recurrent_kernel2]] shape: (400, 1600)
[[lstm_peephole_weights2]] shape: (400, 3)
[[lstm_peephole_bias2]] shape: (1600,)
[[lstm_peephole_kernel3]] shape: (476, 1600)
[[lstm_peephole_recurrent_kernel3]] shape: (400, 1600)
[[lstm_peephole_weights3]] shape: (400, 3)
[[lstm_peephole_bias3]] shape: (1600,)
[[kernel]] shape: (476, 30)
[[bias]] shape: (30,)
[[mdn_W_pi]] shape: (400, 20)
[[mdn_W_mu]] shape: (400, 40)
[[mdn_W_sigma]] shape: (400, 40)
[[mdn_W_rho]] shape: (400, 20)
[[mdn_W_eos]] shape: (400, 1)
[[mdn_b_pi]] shape: (20,)
[[mdn_b_mu]] shape: (40,)
[[mdn_b_sigma]] shape: (40,)
[[mdn_b_rho]] shape: (20,)
[[mdn_b_eos]] shape: (1,)

Trainable parameters:

(same here)

Trainable parameter count:

3636031
168/168 ━━━━━━━━━━━━━━━━━━━━ 133s 728ms/step - loss: 2.7931
Epoch 2/10000
60/168 ━━━━━━━━━━━━━━━━━━━━ 1:14 686ms/step - loss: 2.4870

</code></pre></div></div>

</details>

<p><br /></p>
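<p>One quick consistency check on that parameter dump: for a peephole LSTM cell, every shape follows mechanically from the input dimension and the number of units, so the printout can be verified with a few lines (a sketch with illustrative names, not part of the actual codebase):</p>

```python
def expected_peephole_lstm_shapes(input_dim, units):
    """Expected parameter shapes for one peephole LSTM cell (4 gates, 3 peepholes)."""
    return {
        "kernel": (input_dim, 4 * units),
        "recurrent_kernel": (units, 4 * units),
        "peephole_weights": (units, 3),
        "bias": (4 * units,),
    }

# First cell reports 76 input features; later cells see 476 (x_t + h + window)
print(expected_peephole_lstm_shapes(76, 400)["kernel"])   # (76, 1600)
print(expected_peephole_lstm_shapes(476, 400)["kernel"])  # (476, 1600)
```

<p>Both match the kernel, recurrent kernel, peephole, and bias shapes in the log above.</p>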

<p>Ok, so that’s all well and good, and some fun math and neural-network construction, but the meat of this project is what we’re actually building with this theory. So let’s lay out our to-do list.</p>

<h3 id="problem-1---gradient-explosion-problem">Problem #1 - Gradient Explosion Problem</h3>

<p>Somehow, on my first run-through of this, I was still getting exploding gradients in the later stages of training my model.</p>

<p>As a result, I chose the laborious and time-consuming route of running training on CPU so that I could print out debugging information, and then ran <code class="language-plaintext highlighter-rouge">tensorboard</code>’s Debugger tool to inspect which gradients were exploding to <code class="language-plaintext highlighter-rouge">nan</code> or the dreaded <code class="language-plaintext highlighter-rouge">inf</code>.</p>

<p>Here’s an example of what that looked like:</p>

<p><img src="/images/generative-handwriting/tensorboard-debugging.png" alt="tensorboard" class="center-shrink lightbox-image" /></p>

<p>This was made even more annoying by this TensorFlow issue: <a href="https://github.com/tensorflow/tensorflow/issues/59215">https://github.com/tensorflow/tensorflow/issues/59215</a>.</p>
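<p>Before spinning up the full TensorBoard workflow, a cheaper first pass is to scan each gradient tensor for bad values every few steps. A minimal sketch, assuming you can get the gradients as NumPy arrays keyed by variable name (e.g. via <code class="language-plaintext highlighter-rouge">tape.gradient</code> and <code class="language-plaintext highlighter-rouge">.numpy()</code> in eager mode):</p>

```python
import numpy as np

def find_bad_gradients(grads):
    """Return {name: (nan_count, inf_count)} for gradients containing nan or inf."""
    bad = {}
    for name, g in grads.items():
        nans, infs = int(np.isnan(g).sum()), int(np.isinf(g).sum())
        if nans or infs:
            bad[name] = (nans, infs)
    return bad

grads = {"ok": np.ones((2, 2)), "blown": np.array([1.0, np.inf, np.nan])}
print(find_bad_gradients(grads))  # {'blown': (1, 1)}
```

<p>The first variable that starts reporting bad values narrows down where the explosion originates, before you reach for the Debugger UI.</p>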

<h3 id="problem-2---oom-galore">Problem #2 - OOM Galore</h3>

<p>Uh oh, looks like the <code class="language-plaintext highlighter-rouge">vast.ai</code> instance I used didn’t have enough GPU memory. Here is an example of one of the errors I ran into:</p>

<details>
  <summary style="background-color: #f8d7da; padding: 10px; border-radius: 5px; cursor: pointer; color: #721c24; font-weight: bold;">
    Out of memory error here
  </summary>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Out of memory while trying to allocate 22271409880 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   29.54MiB
              constant allocation:         4B
        maybe_live_out allocation:   27.74MiB
     preallocated temp allocation:   20.74GiB
                 total allocation:   20.77GiB
Peak buffers:
        Buffer 1:
                Size: 3.40GiB
                Operator: op_type="EmptyTensorList" op_name="gradient_tape/deep_handwriting_synthesis_model_1/rnn_1/while/deep_handwriting_synthesis_model_1/rnn_1/while/attention_rnn_cell_1/lstm_peephole_cell_2_1/MatMul/ReadVariableOp_0/accumulator" source_file="/root/code/venv/lib/python3.11/site-packages/tensorflow/python/framework/ops.py" source_line=1177
                XLA Label: fusion
                Shape: f32[1200,476,1600]
                ==========================

        Buffer 2:
                Size: 3.40GiB
                Operator: op_type="EmptyTensorList" op_name="gradient_tape/deep_handwriting_synthesis_model_1/rnn_1/while/deep_handwriting_synthesis_model_1/rnn_1/while/attention_rnn_cell_1/lstm_peephole_cell_2_1/MatMul/ReadVariableOp_0/accumulator" source_file="/root/code/venv/lib/python3.11/site-packages/tensorflow/python/framework/ops.py" source_line=1177
                XLA Label: fusion
                Shape: f32[1200,476,1600]
                ==========================

        Buffer 3:
                Size: 2.86GiB
                Operator: op_type="EmptyTensorList" op_name="gradient_tape/deep_handwriting_synthesis_model_1/rnn_1/while/deep_handwriting_synthesis_model_1/rnn_1/while/attention_rnn_cell_1/lstm_peephole_cell_2_1/MatMul_1/ReadVariableOp_0/accumulator" source_file="/root/code/venv/lib/python3.11/site-packages/tensorflow/python/framework/ops.py" source_line=1177
                XLA Label: fusion
                Shape: f32[1200,400,1600]
                ==========================

        Buffer 4:
                Size: 2.86GiB
                Operator: op_type="EmptyTensorList" op_name="gradient_tape/deep_handwriting_synthesis_model_1/rnn_1/while/deep_handwriting_synthesis_model_1/rnn_1/while/attention_rnn_cell_1/lstm_peephole_cell_2_1/MatMul_1/ReadVariableOp_0/accumulator" source_file="/root/code/venv/lib/python3.11/site-packages/tensorflow/python/framework/ops.py" source_line=1177
                XLA Label: fusion
                Shape: f32[1200,400,1600]
                ==========================

        Buffer 5:
                Size: 2.86GiB
                Operator: op_type="EmptyTensorList" op_name="gradient_tape/deep_handwriting_synthesis_model_1/rnn_1/while/deep_handwriting_synthesis_model_1/rnn_1/while/attention_rnn_cell_1/lstm_peephole_cell_2_1/MatMul_1/ReadVariableOp_0/accumulator" source_file="/root/code/venv/lib/python3.11/site-packages/tensorflow/python/framework/ops.py" source_line=1177
                XLA Label: fusion
                Shape: f32[1200,400,1600]
                ==========================

        Buffer 6:
                Size: 556.64MiB
                Operator: op_type="EmptyTensorList" op_name="gradient_tape/deep_handwriting_synthesis_model_1/rnn_1/while/deep_handwriting_synthesis_model_1/rnn_1/while/attention_rnn_cell_1/lstm_peephole_cell_1/MatMul/ReadVariableOp_0/accumulator" source_file="/root/code/venv/lib/python3.11/site-packages/tensorflow/python/framework/ops.py" source_line=1177
                XLA Label: fusion
                Shape: f32[1200,76,1600]
                ==========================

        Buffer 7:
                Size: 219.73MiB
                Operator: op_type="While" op_name="deep_handwriting_synthesis_model_1/rnn_1/while" source_file="/root/code/venv/lib/python3.11/site-packages/tensorflow/python/framework/ops.py" source_line=1177
                XLA Label: fusion
                Shape: f32[1200,64,10,75]
                ==========================

        Buffer 8:
                Size: 219.73MiB
                Operator: op_type="While" op_name="deep_handwriting_synthesis_model_1/rnn_1/while" source_file="/root/code/venv/lib/python3.11/site-packages/tensorflow/python/framework/ops.py" source_line=1177
                XLA Label: fusion
                Shape: f32[1200,64,10,75]
                ==========================

        Buffer 9:
                Size: 219.73MiB
                Operator: op_type="While" op_name="deep_handwriting_synthesis_model_1/rnn_1/while" source_file="/root/code/venv/lib/python3.11/site-packages/tensorflow/python/framework/ops.py" source_line=1177
                XLA Label: fusion
                Shape: f32[1200,64,10,75]
                ==========================

        Buffer 10:
                Size: 139.45MiB
                Operator: op_type="While" op_name="deep_handwriting_synthesis_model_1/rnn_1/while" source_file="/root/code/venv/lib/python3.11/site-packages/tensorflow/python/framework/ops.py" source_line=1177
                XLA Label: fusion
                Shape: f32[1200,64,476]
                ==========================

        Buffer 11:
                Size: 139.45MiB
                Operator: op_type="While" op_name="deep_handwriting_synthesis_model_1/rnn_1/while" source_file="/root/code/venv/lib/python3.11/site-packages/tensorflow/python/framework/ops.py" source_line=1177
                XLA Label: fusion
                Shape: f32[1200,64,476]
                ==========================

        Buffer 12:
                Size: 139.45MiB
                Operator: op_type="While" op_name="deep_handwriting_synthesis_model_1/rnn_1/while" source_file="/root/code/venv/lib/python3.11/site-packages/tensorflow/python/framework/ops.py" source_line=1177
                XLA Label: fusion
                Shape: f32[1200,64,476]
                ==========================

        Buffer 13:
                Size: 117.19MiB
                Operator: op_type="While" op_name="deep_handwriting_synthesis_model_1/rnn_1/while" source_file="/root/code/venv/lib/python3.11/site-packages/tensorflow/python/framework/ops.py" source_line=1177
                XLA Label: fusion
                Shape: f32[1200,64,400]
                ==========================

        Buffer 14:
                Size: 117.19MiB
                Operator: op_type="While" op_name="deep_handwriting_synthesis_model_1/rnn_1/while" source_file="/root/code/venv/lib/python3.11/site-packages/tensorflow/python/framework/ops.py" source_line=1177
                XLA Label: fusion
                Shape: f32[1200,64,400]
                ==========================

        Buffer 15:
                Size: 117.19MiB
                Operator: op_type="While" op_name="deep_handwriting_synthesis_model_1/rnn_1/while" source_file="/root/code/venv/lib/python3.11/site-packages/tensorflow/python/framework/ops.py" source_line=1177
                XLA Label: fusion
                Shape: f32[1200,64,400]
                ==========================


         [[]]

</code></pre></div></div>

</details>

<p><br /></p>
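<p>The peak buffers in that dump are backprop accumulators shaped <code class="language-plaintext highlighter-rouge">[seq_len, input_dim, 4 * units]</code>: one copy of each kernel gradient per unrolled timestep. A quick back-of-envelope sketch reproduces the reported sizes and makes clear that sequence length and hidden size are the levers to pull:</p>

```python
def bptt_accumulator_gib(seq_len, input_dim, units, dtype_bytes=4):
    """GiB held by one [seq_len, input_dim, 4 * units] float gradient accumulator."""
    return seq_len * input_dim * 4 * units * dtype_bytes / 2 ** 30

# Buffer 1 from the dump: f32[1200, 476, 1600]
print(round(bptt_accumulator_gib(1200, 476, 400), 2))  # 3.4
# Buffer 3, the recurrent-kernel accumulator: f32[1200, 400, 1600]
print(round(bptt_accumulator_gib(1200, 400, 400), 2))  # 2.86
```

<p>Note that the batch size does not even appear in these dominant buffers, so shrinking the batch alone would not have fixed this particular OOM; truncating backprop through time halves every one of these accumulators when the sequence length is halved.</p>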

<h3 id="sanity-check---validating-model-dimensions-with-ai-so-somewhat">Sanity Check - Validating Model Dimensions (with AI… so somewhat)</h3>

<p>So where does using AI come in? I wanted to validate that the shapes of my training parameters all looked right. Again, I print out the shapes on the very first epoch to get some more detail and do a quick dimensionality-alignment check. I then fed that into ChatGPT along with the paper and asked it to double-check all my params. Here was its output:</p>

<details>
  <summary style="background-color: #d4edda; padding: 10px; border-radius: 5px; cursor: pointer; color: #155724; font-weight: bold;">
    Model Summary
  </summary>
<blockquote><h3 id="model-summary">Model Summary</h3><p><strong>Model Name:</strong> deep_handwriting_synthesis_model<br /> <strong>Total Parameters:</strong> 7,272,063 (27.74 MB)<br /> <strong>Trainable Parameters:</strong> 3,636,031 (13.87 MB)<br /> <strong>Non-Trainable Parameters:</strong> 0 (0.00 B)<br /> <strong>Optimizer Parameters:</strong> 3,636,032 (13.87 MB)</p><hr /><h3 id="layer-wise-breakdown">Layer-wise Breakdown</h3><h4 id="1-lstm-peephole-cells">1. LSTM Peephole Cells</h4><p>You have three LSTM peephole cells:</p><ul><li><code class="language-plaintext highlighter-rouge">lstm_peephole_cell</code></li><li><code class="language-plaintext highlighter-rouge">lstm_peephole_cell_1</code></li><li><code class="language-plaintext highlighter-rouge">lstm_peephole_cell_2</code></li></ul><h5 id="parameters-and-shapes">Parameters and Shapes</h5><p>For each LSTM peephole cell:</p><ul><li><strong>Kernel Shape:</strong> <code class="language-plaintext highlighter-rouge">(input_dim, 4 * units)</code></li><li><strong>Recurrent Kernel Shape:</strong> <code class="language-plaintext highlighter-rouge">(units, 4 * units)</code></li><li><strong>Peephole Weights Shape:</strong> <code class="language-plaintext highlighter-rouge">(units, 3)</code></li><li><strong>Bias Shape:</strong> <code class="language-plaintext highlighter-rouge">(4 * units,)</code></li></ul><h5 id="shapes">Shapes:</h5><p><strong>First LSTM Peephole Cell (lstm_peephole_cell):</strong></p><ul><li><strong>Input Dimension:</strong> 76</li><li><strong>Units:</strong> 400</li><li><strong>Kernel Shape:</strong> <code class="language-plaintext highlighter-rouge">(76, 1600)</code> &nbsp;&nbsp;&nbsp;(76, 4 * 400)</li><li><strong>Recurrent Kernel Shape:</strong> <code class="language-plaintext highlighter-rouge">(400, 1600)</code> &nbsp;&nbsp;&nbsp;(400, 4 * 400)</li><li><strong>Peephole Weights Shape:</strong> <code class="language-plaintext highlighter-rouge">(400, 3)</code> &nbsp;&nbsp;&nbsp;(400 units, 3 
gates)</li><li><strong>Bias Shape:</strong> <code class="language-plaintext highlighter-rouge">(1600,)</code> &nbsp;&nbsp;&nbsp;(4 * 400,)</li></ul><p><strong>Second and Third LSTM Peephole Cells (lstm_peephole_cell_1 and lstm_peephole_cell_2):</strong></p><ul><li><strong>Input Dimension:</strong> 476<br /> This includes concatenated inputs from previous layers and attention outputs.</li><li><strong>Units:</strong> 400</li><li><strong>Kernel Shape:</strong> <code class="language-plaintext highlighter-rouge">(476, 1600)</code> &nbsp;&nbsp;&nbsp;(476, 4 * 400)</li><li><strong>Recurrent Kernel Shape:</strong> <code class="language-plaintext highlighter-rouge">(400, 1600)</code> &nbsp;&nbsp;&nbsp;(400, 4 * 400)</li><li><strong>Peephole Weights Shape:</strong> <code class="language-plaintext highlighter-rouge">(400, 3)</code> &nbsp;&nbsp;&nbsp;(400 units, 3 gates)</li><li><strong>Bias Shape:</strong> <code class="language-plaintext highlighter-rouge">(1600,)</code> &nbsp;&nbsp;&nbsp;(4 * 400,)</li></ul><h5 id="conclusion">Conclusion</h5><p>The shapes of the LSTM peephole cells’ parameters are correct and align with standard LSTM architectures with peephole connections.</p><hr /><h4 id="2-attention-mechanism">2. 
Attention Mechanism</h4><ul><li><strong>Layer:</strong> <code class="language-plaintext highlighter-rouge">attention</code> (AttentionMechanism)</li></ul><h5 id="parameters-and-shapes-1">Parameters and Shapes</h5><ul><li><strong>Kernel Shape:</strong> <code class="language-plaintext highlighter-rouge">(476, 30)</code></li><li><strong>Bias Shape:</strong> <code class="language-plaintext highlighter-rouge">(30,)</code></li></ul><p><strong>Explanation:</strong></p><ul><li><strong>Input Dimension:</strong> 476<br /> This likely comes from concatenating the window vector <code class="language-plaintext highlighter-rouge">w</code>, the current input <code class="language-plaintext highlighter-rouge">x_t</code>, and the output from the first LSTM layer <code class="language-plaintext highlighter-rouge">s1_out</code>.</li><li><strong>Output Dimension:</strong> 30<br /> Corresponds to parameters needed for the attention mechanism, such as the attention weights.</li></ul><h5 id="conclusion-1">Conclusion</h5><p>The attention mechanism’s parameter shapes are appropriate.</p><hr /><h4 id="3-mixture-density-network-mdn-layer">3. 
Mixture Density Network (MDN) Layer</h4><ul><li><strong>Layer:</strong> <code class="language-plaintext highlighter-rouge">mdn</code> (MixtureDensityLayer)</li></ul><h5 id="parameters-and-shapes-2">Parameters and Shapes</h5><ul><li><strong>mdn_W_pi:</strong> <code class="language-plaintext highlighter-rouge">(400, 20)</code></li><li><strong>mdn_b_pi:</strong> <code class="language-plaintext highlighter-rouge">(20,)</code></li><li><strong>mdn_W_mu:</strong> <code class="language-plaintext highlighter-rouge">(400, 40)</code></li><li><strong>mdn_b_mu:</strong> <code class="language-plaintext highlighter-rouge">(40,)</code></li><li><strong>mdn_W_sigma:</strong> <code class="language-plaintext highlighter-rouge">(400, 40)</code></li><li><strong>mdn_b_sigma:</strong> <code class="language-plaintext highlighter-rouge">(40,)</code></li><li><strong>mdn_W_rho:</strong> <code class="language-plaintext highlighter-rouge">(400, 20)</code></li><li><strong>mdn_b_rho:</strong> <code class="language-plaintext highlighter-rouge">(20,)</code></li><li><strong>mdn_W_eos:</strong> <code class="language-plaintext highlighter-rouge">(400, 1)</code></li><li><strong>mdn_b_eos:</strong> <code class="language-plaintext highlighter-rouge">(1,)</code></li></ul><p><strong>Explanation:</strong></p><ul><li><strong>Hidden Units from Last LSTM Layer:</strong> 400</li><li><strong>Number of Mixture Components:</strong> 20</li></ul><h5 id="mdn-outputs">MDN Outputs:</h5><ul><li><strong>pi</strong> (mixture weights): 20 components</li><li><strong>mu</strong> (means): 2 coordinates per component * 20 components = 40 outputs</li><li><strong>sigma</strong> (standard deviations): 2 per component * 20 components = 40 outputs</li><li><strong>rho</strong> (correlation coefficients): 20 components</li><li><strong>eos</strong> (end-of-stroke probability): 1 output</li></ul><h5 id="parameter-shapes">Parameter Shapes:</h5><ul><li><strong>Weights:</strong></li><li><code class="language-plaintext 
highlighter-rouge">mdn_W_pi</code>: <code class="language-plaintext highlighter-rouge">(400, 20)</code></li><li><code class="language-plaintext highlighter-rouge">mdn_W_mu</code>: <code class="language-plaintext highlighter-rouge">(400, 40)</code></li><li><code class="language-plaintext highlighter-rouge">mdn_W_sigma</code>: <code class="language-plaintext highlighter-rouge">(400, 40)</code></li><li><code class="language-plaintext highlighter-rouge">mdn_W_rho</code>: <code class="language-plaintext highlighter-rouge">(400, 20)</code></li><li><code class="language-plaintext highlighter-rouge">mdn_W_eos</code>: <code class="language-plaintext highlighter-rouge">(400, 1)</code></li><li><strong>Biases:</strong></li><li><code class="language-plaintext highlighter-rouge">mdn_b_pi</code>: <code class="language-plaintext highlighter-rouge">(20,)</code></li><li><code class="language-plaintext highlighter-rouge">mdn_b_mu</code>: <code class="language-plaintext highlighter-rouge">(40,)</code></li><li><code class="language-plaintext highlighter-rouge">mdn_b_sigma</code>: <code class="language-plaintext highlighter-rouge">(40,)</code></li><li><code class="language-plaintext highlighter-rouge">mdn_b_rho</code>: <code class="language-plaintext highlighter-rouge">(20,)</code></li><li><code class="language-plaintext highlighter-rouge">mdn_b_eos</code>: <code class="language-plaintext highlighter-rouge">(1,)</code></li></ul><h5 id="conclusion-2">Conclusion</h5><p>The MDN layer’s parameter shapes correctly correspond to the required outputs for a mixture of 20 bivariate Gaussian components.</p><hr /><h4 id="4-total-parameters-and-trainable-parameters">4. 
Total Parameters and Trainable Parameters</h4><ul><li><strong>Total Parameters:</strong> 7,272,063</li><li><strong>Trainable Parameters:</strong> 3,636,031</li></ul><p><strong>Observation:</strong></p><ul><li>The total parameters are approximately double the trainable parameters.</li><li>This may be due to counting both the weights and optimizer states (like moments in the Adam optimizer).</li><li><strong>Optimizer Parameters:</strong> 3,636,032 (slightly off by one, which could be due to rounding or reporting).</li></ul><h5 id="conclusion-3">Conclusion</h5><p>The parameter counts make sense, and the discrepancy is acceptable.</p><hr /><h3 id="your-models-parameters">Your Model’s Parameters</h3><h4 id="all-parameters">All Parameters</h4><p>All parameters listed are correctly shaped and correspond to their respective layers.</p><h4 id="trainable-parameters">Trainable Parameters</h4><p>All trainable parameters match the parameters in the model and have the correct shapes.</p><hr /><h3 id="answering-your-question">Answering Your Question</h3><p>Yes, based on the model summary and the parameter shapes you’ve provided, everything appears to be in order. 
The shapes align with the expected dimensions given your model’s architecture:</p><ul><li><strong>LSTM Peephole Cells:</strong></li><li>Input dimensions and units are consistent.</li><li>Kernels and recurrent kernels have correct shapes.</li><li><p>Peephole weights and biases are appropriately sized.</p></li><li><strong>Attention Mechanism:</strong></li><li><p>The input and output dimensions match the concatenated inputs and expected output size.</p></li><li><strong>MDN Layer:</strong></li><li><p>The weights and biases correspond correctly to the number of mixture components and the outputs required for the mixture density network.</p></li><li><strong>Parameter Counts:</strong></li><li>The total number of parameters is reasonable for a model of this complexity.</li><li>The trainable parameters are correctly identified.</li></ul><hr /><h3 id="final-thoughts">Final Thoughts</h3><ul><li><p><strong>Proceed with Training:</strong> You can confidently continue running your training code.</p></li><li><strong>Monitoring:</strong> Keep an eye on training metrics such as loss, especially since you have confirmed that negative loss values are acceptable in your MDN context.</li><li><strong>Validation:</strong> Consider using a validation set to monitor overfitting and generalization performance.</li><li><strong>Performance Checks:</strong> If you encounter any unexpected behavior during training, revisit the data preprocessing steps and ensure that the inputs to your model are correctly shaped and normalized.</li></ul><hr /><h3 id="additional-recommendations">Additional Recommendations</h3><ul><li><strong>Logging:</strong> Utilize TensorBoard or similar tools to visualize training progress.</li><li><strong>Batch Size and Sequence Length:</strong> Ensure that your batch sizes and sequence lengths are manageable to prevent memory issues, as discussed earlier.</li><li><strong>Model Saving:</strong> Regularly save model checkpoints to avoid losing progress in case of
interruptions.</li></ul></blockquote>
</details>

<p><br /></p>

<h1 id="️-visualizations">✍️ Visualizations</h1>

<h2 id="learning-with-dummy-data">Learning with Dummy Data</h2>

<p>Again, we started with dummy data to ensure our various components were learning and converging correctly.</p>

<p>I’m not going to burn too many pixels with these visualizations given I think they’re less interesting.</p>

<p>Here is our entire network, sampling only from the means (not showing the mixture densities) across the example datasets. One thing to note here is how well the LSTMs can still handle these larger contexts. It pales in comparison to modern transformer context lengths, but it’s still impressive.</p>

<p><img src="/images/generative-handwriting/viz/loop_epoch200_mixtures5.gif" alt="handwriting_loop_lstm_simple" class="center-shrink" /></p>

<p><img src="/images/generative-handwriting/viz/zigzag_epoch200_mixtures5.gif" alt="handwriting_zig_lstm_simple" class="center-shrink" /></p>

<h2 id="synthesis-model-sampling">Synthesis Model Sampling</h2>

<p>So again, given the above information, $\phi(t, u)$ represents the network’s belief that it’s writing character $c_u$ at time $t$. It’s monotonically increasing (which makes sense and is enforced mathematically), and you can see that the increase is fairly stepwise.</p>
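<p>For the curious, the monotonicity falls out of how Graves parameterizes the window: each attention component’s center $\kappa_k$ is only ever incremented by a positive amount. Here’s a minimal NumPy sketch of that idea (variable names are mine, not the repo’s):</p>

```python
import numpy as np

def attention_phi(alpha, beta, kappa, num_chars):
    """Graves-style window weights phi(t, u) for a single timestep.

    alpha, beta, kappa: (K,) parameters for K attention components.
    Returns a (num_chars,) vector of (unnormalized) attention weights.
    """
    u = np.arange(num_chars)  # character positions 0..U-1
    bumps = alpha[:, None] * np.exp(-beta[:, None] * (kappa[:, None] - u) ** 2)
    return bumps.sum(axis=0)

# kappa only ever moves forward: kappa_t = kappa_{t-1} + exp(kappa_hat_t),
# which is what enforces the monotone, roughly stepwise sweep over characters.
kappa = np.zeros(1)
for kappa_hat in [0.1, 0.4, 0.2]:
    kappa = kappa + np.exp(kappa_hat)  # always a positive increment

phi = attention_phi(np.ones(1), 2.0 * np.ones(1), kappa, num_chars=10)
```

After three positive increments, $\kappa \approx 3.8$, so the window’s belief peaks around the fourth character and can never slide backwards.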

<p>One of my favorite portions of these visualizations is the <strong>mixture component weights</strong>. You can see the various Gaussians activating for different parts of the synthesis network. For example, for end-of-stroke signals, we have separate Gaussians owning that portion of the model.</p>

<p>Most of these were generated like so:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="err">╭─</span><span class="n">johnlarkin</span><span class="o">@</span><span class="n">Mac</span> <span class="o">~/</span><span class="n">Documents</span><span class="o">/</span><span class="n">coding</span><span class="o">/</span><span class="n">generative</span><span class="o">-</span><span class="n">handwriting</span><span class="o">-</span><span class="n">jax</span> <span class="err">‹</span><span class="n">main</span><span class="o">*</span><span class="err">›</span>
<span class="err">╰─➤</span>  <span class="n">uv</span> <span class="n">run</span> <span class="n">python</span> <span class="n">generative_handwriting</span><span class="o">/</span><span class="n">generate</span><span class="o">/</span><span class="n">generate_handwriting_cpu</span><span class="p">.</span><span class="n">py</span> \
    <span class="o">--</span><span class="n">checkpoint</span> <span class="s">"checkpoints_saved/synthesis/loss_-2.59/checkpoint_216_cpu.pkl"</span> \
    <span class="o">--</span><span class="n">text</span> <span class="s">"It has to be symphonic"</span> \
    <span class="o">--</span><span class="n">bias</span> <span class="s">"0.75"</span> \
    <span class="o">--</span><span class="n">temperature</span> <span class="s">"0.75"</span> \
    <span class="o">--</span><span class="n">fps</span> <span class="s">"60"</span> \
    <span class="o">--</span><span class="n">formats</span> <span class="s">"all"</span> \
    <span class="o">--</span><span class="n">seed</span> <span class="s">"42"</span>
</code></pre></div></div>

<p>Another note: my termination-condition logic could probably be improved. Remember, we’re doing one-hot encoding, which includes the null terminator, so the null terminator sits at index <code class="language-plaintext highlighter-rouge">len(line_text)</code>. Attention spans the full sequence. Specifically, $\phi$ has shape <code class="language-plaintext highlighter-rouge">[batch, char_seq_length]</code>, so we can take our single sample (i.e. batch index 0) and look across the character sequence to see where the attention currently sits. In code, here’s what I’m doing:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>        # char_seq includes null terminator at index len(line_text)
        if phi is not None and t &gt;= len(line_text) * 2:
            char_idx = int(jnp.argmax(phi[0]))
            sampled_eos = stroke[2] &gt; 0.5

            # we can stop when:
            # 1. attention has reached the null terminator (char_idx == len(line_text)) AND we sampled EOS
            # 2. attention weight on null terminator is dominant (&gt; 0.5)
            # 3. we're well past the text and sampled EOS multiple times
            null_attention = float(phi[0][len(line_text)]) if len(phi[0]) &gt; len(line_text) else 0.0

            if char_idx == len(line_text) and sampled_eos:
                # this hits most
                break
            elif null_attention &gt; 0.5:
                # attention strongly focused on null terminator
                break
            elif char_idx &gt;= len(line_text) and t &gt; len(line_text) * 10:
                # failsafe: past text and generated way too much
                break
</code></pre></div></div>

<p>Finally, on the visualization front, I’m generating everything with bias 0.75 and temperature 0.75. I’m not going to discuss those here, but the original paper goes into more detail.</p>
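<p>For a rough sense of what the bias knob does, here is a sketch of the probability-bias trick from the original Graves paper (this is illustrative, not the repo’s actual sampling code): a bias $b$ shrinks the standard deviations and sharpens the mixture weights, so higher bias means neater, more deterministic handwriting.</p>

```python
import numpy as np

def biased_mixture_params(pi_hat, sigma_hat, bias):
    """Apply a sampling bias b to raw MDN outputs, a la Graves (2013):
    sigma_j = exp(sigma_hat_j - b), pi_j proportional to exp(pi_hat_j * (1 + b)).
    """
    sigma = np.exp(sigma_hat - bias)      # smaller std devs as bias grows
    logits = pi_hat * (1.0 + bias)        # sharpen the mixture weights
    pi = np.exp(logits - logits.max())    # numerically stable softmax
    return pi / pi.sum(), sigma

pi_hat = np.array([1.0, 0.0, -1.0])
sigma_hat = np.array([0.5, 0.5, 0.5])
pi0, sigma0 = biased_mixture_params(pi_hat, sigma_hat, bias=0.0)
pi1, sigma1 = biased_mixture_params(pi_hat, sigma_hat, bias=0.75)
```

With bias 0.75, the dominant component’s weight grows and every sigma shrinks, which is exactly the “cleaner handwriting” effect you see in the samples above.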

<hr />

<h3 id="heart-has-its-reasons">Heart has its reasons</h3>

<div class="featured-quote">
  <p class="featured-quote__text">The heart has its reasons which reason knows nothing of</p>
  <p class="featured-quote__attribution">
    <span class="featured-quote__author">Blaise Pascal</span>, <span class="featured-quote__source">Pensées</span>
  </p>
</div>

<p>One thing to note is that we are still constrained by line length. For example, if we try to specify this as a single line, the attention starts to drift and the context starts to fail. Part of this is that if we exceed the line length we trained on (in terms of stroke-sequence or input-text length), the model starts to flail.</p>

<p>So note the discrepancy between these two when we introduce a line break:</p>

<p><img src="/images/generative-handwriting/synth_outputs/heart_has_its_reason/mdn_aggregate.png" alt="heart-mdn-aggregate" class="center-small lightbox-image" /></p>

<p>vs</p>

<p><img src="/images/generative-handwriting/synth_outputs/heart_has_its_reason_single/mdn_aggregate.png" alt="heart-oneliner-mdn-aggregate" class="center-small lightbox-image" /></p>

<p>You can see how the model is less well trained here, given the higher deviations towards the end of the line. Note: the MDN heatmap graphs on the bottom are created by taking the three highest-weighted $\pi$ components per timestep and then aggregating them across all timesteps.</p>
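<p>Concretely, the per-timestep selection behind those heatmaps looks something like this (a simplified sketch with hypothetical names; the real plotting code also renders the corresponding Gaussians):</p>

```python
import numpy as np

def top_k_components(pi, k=3):
    """Indices and weights of the k highest-weighted mixture components
    at each timestep. pi has shape (T, num_components)."""
    idx = np.argsort(pi, axis=1)[:, ::-1][:, :k]   # descending by weight
    weights = np.take_along_axis(pi, idx, axis=1)
    return idx, weights

# two timesteps, four mixture components
pi = np.array([[0.1, 0.5, 0.2, 0.2],
               [0.7, 0.1, 0.1, 0.1]])
idx, weights = top_k_components(pi)
```

Aggregating these top-3 components across every timestep produces the heatmap: wherever many timesteps put weight, the plot is hot.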

<p>Furthermore, the <code class="language-plaintext highlighter-rouge">eos</code> signals generally have the highest uncertainty and the most spread-out sigmas, which makes sense given that they are the most variable part of the sequence.</p>

<hr />

<h3 id="loved-and-lost">Loved and lost</h3>

<div class="featured-quote">
  <p class="featured-quote__text">Better to have loved and lost than never to have loved at all</p>
  <p class="featured-quote__attribution">
    <span class="featured-quote__author">Lord Alfred Tennyson</span>, <span class="featured-quote__source">In Memoriam A. H. H.</span>
  </p>
</div>

<p><img src="/images/generative-handwriting/synth_outputs/loved_and_lost/writing_colored.gif" alt="better-to-have-loved-writing" class="center-small lightbox-image" /></p>

<p><img src="/images/generative-handwriting/synth_outputs/loved_and_lost/writing_cleansed.png" alt="better-to-have-loved-mdn" class="basic-center lightbox-image" /></p>

<hr />

<h3 id="it-has-to-be-symphonic">It has to be symphonic</h3>

<div class="featured-quote">
  <p class="featured-quote__text">It has to be symphonic</p>
  <p class="featured-quote__attribution">
    <span class="featured-quote__author">Andrew Zimmern</span>, <span class="featured-quote__source">Takeaway (The Potash Twins ft. Andrew Zimmern)</span>
  </p>
</div>

<p><img src="/images/generative-handwriting/synth_outputs/symphonic/writing_colored.gif" alt="symphonic-writing" class="basic-center lightbox-image" /></p>

<p><img src="/images/generative-handwriting/synth_outputs/symphonic/sampling.gif" alt="symphonic-sampling" class="center-small lightbox-image" /></p>

<hr />

<h3 id="is-a-model-a-lie">Is a model a lie?</h3>

<div class="featured-quote">
  <p class="featured-quote__text">A model is a lie that helps you see the truth</p>
  <p class="featured-quote__attribution">
    <span class="featured-quote__author">Howard Skipper</span>, <span class="featured-quote__source">requoted by Siddhartha Mukherjee in "The Emperor of All Maladies"</span>
  </p>
</div>

<p><img src="/images/generative-handwriting/synth_outputs/model_lie/writing_cleansed.png" alt="model-lie-writing" class="basic-center lightbox-image" /></p>

<p><img src="/images/generative-handwriting/synth_outputs/model_lie/writing_colored.gif" alt="model-lie-writing-colored" class="center-small lightbox-image" /></p>

<p><img src="/images/generative-handwriting/synth_outputs/model_lie/mdn_aggregate.png" alt="model-lie-mdn" class="center-small lightbox-image" /></p>

<hr />

<h3 id="fish-folly">Fish folly</h3>

<p>Hadn’t ever heard of this one but it’s my best friend’s favorite quote.</p>

<div class="featured-quote">
  <p class="featured-quote__text">Folly to love a fish. Or anyone who might leave us. But oh, what a gift.</p>
  <p class="featured-quote__attribution">
    <span class="featured-quote__author">Ann V. Klotz</span>, <span class="featured-quote__source">I Think on Thee, Dear Friend</span>
  </p>
</div>

<p><img src="/images/generative-handwriting/synth_outputs/fish_folly/writing_cleansed.png" alt="fish-folly-attention" class="basic-center lightbox-image" /></p>

<p><img src="/images/generative-handwriting/synth_outputs/fish_folly/attention.png" alt="fish-folly-attention" class="center-small lightbox-image" /></p>

<h1 id="conclusion">Conclusion</h1>

<p>This - again - was a bit of a bear of a project. It was maybe not my best use of time, but it was a labor of love.</p>

<p>I don’t think I’ll embark on a project of this nature for a while (sadly). However, I hope you’ve enjoyed it. Feel free to pull the code and dive in yourself.</p>

<p>When re-reading my old draft blog post, I liked the way I ended things. So here it is:</p>

<blockquote>
  <p>Finally, I want to leave with a quote from our academic advisor <a href="https://mzucker.github.io/">Matt Zucker</a>. When I asked him when we know that our model is good enough, he responded with the following.</p>

  <blockquote>
    <p>“Learning never stops.”</p>
  </blockquote>
</blockquote>]]></content><author><name>johnlarkin1</name></author><category term="⭐️ Favorites" /><category term="Algorithms" /><category term="Development" /><category term="AI" /><category term="M.L." /><summary type="html"><![CDATA[✍️ Motivating Visualizations]]></summary></entry><entry><title type="html">Launching Scrollz</title><link href="https://johnlarkin1.github.io/2025/launching-scrollz/" rel="alternate" type="text/html" title="Launching Scrollz" /><published>2025-09-16T00:00:00+00:00</published><updated>2025-09-16T00:00:00+00:00</updated><id>https://johnlarkin1.github.io/2025/launching-scrollz</id><content type="html" xml:base="https://johnlarkin1.github.io/2025/launching-scrollz/"><![CDATA[<p>I’m happy to announce that as of <strong>September 16th, 2025</strong>, our application <a href="https://apps.apple.com/us/app/scrollz-app/id6745718779">Scrollz App</a> is officially in the iOS app store (and the Android store). With that, it’s out to 175 countries.</p>

<h1 id="context">Context</h1>

<p>We built <a href="https://www.scrollz.co/">Scrollz</a> because we were sick of scrolling through and cluttering our inboxes every day with various newsletters. We wanted better search, more sophisticated note-taking, sharing, social features, etc. Lots of that we’re still working on. It’s a cool app, and a couple of us put a solid chunk of time and energy into it, so I’d appreciate a download and a review. Selfishly, I obviously needed to hit my New Year’s resolution of being accepted onto the App Store.</p>

<p><img src="/images/scrollz/main.jpg" alt="scrollz" class="center-image lightbox-image" /></p>

<p>Anyways! Enjoy. Open for thoughts, comments, feedback, etc.</p>]]></content><author><name>johnlarkin1</name></author><category term="Scrollz" /><summary type="html"><![CDATA[I’m happy to announce that as of September 16th, 2025, our application Scrollz App is officially in the iOS app store (and the Android store). With that, it’s out to 175 countries.]]></summary></entry><entry><title type="html">Tennis Scorigami</title><link href="https://johnlarkin1.github.io/2025/tennis-scorigami/" rel="alternate" type="text/html" title="Tennis Scorigami" /><published>2025-06-11T00:00:00+00:00</published><updated>2025-06-11T00:00:00+00:00</updated><id>https://johnlarkin1.github.io/2025/tennis-scorigami</id><content type="html" xml:base="https://johnlarkin1.github.io/2025/tennis-scorigami/"><![CDATA[<p>This post is going to be focused on discussing how we built our <a href="https://www.tennis-scorigami.com/">Tennis Scorigami</a> project from a technical standpoint. I’ll discuss the current architecture, some of the design decisions I made, and where I want the project to go next.</p>

<p>If you haven’t yet checked out the main site, feel free to here:</p>

<div class="tennis-scorigami-unfurl">
  <a href="https://www.tennis-scorigami.com/" target="_blank" rel="noopener noreferrer">
    <div class="unfurl-container">
      <video autoplay="" loop="" muted="" playsinline="" class="unfurl-video">
        <source src="/videos/tennis-scorigami/hero-section.mp4" type="video/mp4" />
      </video>
      <div class="unfurl-overlay">
        <div class="unfurl-content">
          <h3 class="unfurl-title">Tennis Scorigami</h3>
          <p class="unfurl-domain">tennis-scorigami.com</p>
        </div>
      </div>
    </div>
  </a>
</div>

<style>
.tennis-scorigami-unfurl {
  margin: 2rem 0;
  border-radius: 12px;
  overflow: hidden;
  box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06);
  transition: transform 0.2s ease, box-shadow 0.2s ease;
}

.tennis-scorigami-unfurl:hover {
  transform: translateY(-2px);
  box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1), 0 4px 6px -2px rgba(0, 0, 0, 0.05);
}

.tennis-scorigami-unfurl a {
  text-decoration: none;
  color: inherit;
}

.unfurl-container {
  position: relative;
  width: 100%;
  background: #1a1a2e;
  cursor: pointer;
}

.unfurl-video {
  width: 100%;
  height: auto;
  display: block;
  object-fit: cover;
  max-height: 400px;
}

.unfurl-overlay {
  position: absolute;
  bottom: 0;
  left: 0;
  right: 0;
  background: linear-gradient(to top, rgba(0, 0, 0, 0.9) 0%, rgba(0, 0, 0, 0.7) 50%, transparent 100%);
  padding: 3rem 2rem 1.5rem;
}

.unfurl-content {
  color: white;
  max-width: 600px;
}

.unfurl-title {
  font-size: 1.75rem;
  font-weight: 700;
  margin: 0 0 0.25rem 0;
  color: #ffffff;
}

.unfurl-domain {
  font-size: 1rem;
  margin: 0 0 0.5rem 0;
  color: #a8b2d1;
  font-weight: 400;
}

@media (max-width: 768px) {
  .unfurl-video {
    max-height: 250px;
  }
  
  .unfurl-overlay {
    padding: 2rem 1.5rem 1rem;
  }
  
  .unfurl-title {
    font-size: 1.5rem;
  }
  
  .unfurl-domain {
    font-size: 0.875rem;
  }
  
  .unfurl-date {
    right: 1.5rem;
  }
}
</style>

<!--
# Table of Contents

- [Table of Contents](#table-of-contents)
- [Motivation](#motivation)
  - [More Specific Motivation](#more-specific-motivation)
- [Demo](#demo)
- [Features](#features)
- [Challenges](#challenges)
  - [Data Consolidation](#data-consolidation)
  - [Being Cheap](#being-cheap)
  - [Fetching 108k Nodes](#fetching-108k-nodes)
  - [Rendering 108k Nodes](#rendering-108k-nodes)
    - [2D Sigma Graph](#2d-sigma-graph)
    - [3D Force Graph](#3d-force-graph)
    - [Streaming + NDJSON](#streaming--ndjson)
  - [Unfurl Previews](#unfurl-previews)
- [Surprises](#surprises)
- [Tech Stack](#tech-stack)
- [Engineering + Design](#engineering--design)
- [Other Fun Visualizations](#other-fun-visualizations)
  - [Player Rank History](#player-rank-history)
- [Conclusion](#conclusion)
-->

<h1 id="motivation">Motivation</h1>

<p>Given, once again, the impending collapse of my profession due to automation, my lack of time, and the fact that there are details on the main website, I will try not to repeat myself.</p>

<p>Our motivation here was largely love of tennis, data, and friendship. <a href="https://www.linkedin.com/in/sebastian-hoar-a71a5b112/">Sebastian</a> and <a href="https://www.linkedin.com/in/jebhenryhead/">Henry</a> were chatting in the group chat about football scorigami, and Seb asked: I wonder if tennis has any scorigamis? And so began an interesting conversation; the group chat started to explore where we could get data, whether anyone had done this before, and all of that.</p>

<h2 id="more-specific-motivation">More Specific Motivation</h2>

<p>More specifically, besides my craving to “wow” my friends, the real pie in the sky goal for us (read: <a href="https://www.linkedin.com/in/johnlarkin/">me</a>), was to get <a href="https://en.wikipedia.org/wiki/Andy_Roddick">Andy Roddick</a> to re-tweet / view this project.</p>

<p><img src="/images/tennis-scorigami/driving-motivation-pt1.png" alt="driving-motivation" class="center-super-shrink" /></p>

<p>Andy Roddick has been one of our favorite tennis players since growing up, and <a href="https://www.linkedin.com/in/jebhenryhead/">Henry</a> and I watched as many of his matches as we could get a hold of. So that target was a stretch New Years Resolution of mine.</p>

<h1 id="demo">Demo</h1>

<p>If you’re too lazy to visit the <a href="https://www.tennis-scorigami.com/explore">website and explore</a>, here’s a demo:</p>

<div class="video-container">
  <div class="video-wrapper-dark">
    <video src="https://www.dropbox.com/scl/fi/96xzymtsl5j7zkbgdh2gu/tennis-scorigami-demo-smaller.mp4?rlkey=da9r4uyuj45wpakl4daffkyqw&amp;st=7txvjr91&amp;raw=1" muted="" autoplay="" loop="" controls="" style="width: 100%; height: auto;">
    </video>
  </div>
</div>

<h1 id="features">Features</h1>

<p>There’s numerous features here that we’re proud of. I’m going to list some of them, and then I’ll discuss them in further detail below:</p>

<ul>
  <li><strong>Graph-Based Score Sequence Modeling</strong>
    <ul>
      <li>this was originally <a href="https://www.linkedin.com/in/jebhenryhead/">Henry’s</a> idea but given how many nodes we have in this search space, it’s not as feasible to just arrange it as a grid like with the NFL</li>
      <li>we thought a graph as a novel (and visually appealing approach)</li>
      <li>what this means technically is that we pre-computed all permutations and did some background processing so that we could store this information in as close to a frontend-ready format as possible for fast visualization and processing… which gets into our next point</li>
    </ul>
  </li>
  <li><strong>Performance-Optimized Materialized Views</strong>
    <ul>
      <li>we built out specific materialized views to help with the performance (given our existing FE filters) so that we can ensure latency is not noticeable</li>
    </ul>
  </li>
  <li><strong>Streaming Visualizations</strong>
    <ul>
      <li>Still kind of working on setting this up ideally for the 5-set graph. There are 125,062 nodes that we need to render for 5-set match permutations.</li>
      <li>I had to turn to <a href="https://apidog.com/blog/ndjson/">NDJSON</a> (basically just newline json chunked) which I had never used before</li>
      <li>This helped reduce both the latency and the incremental memory on receiving that information and parsing it on the FE</li>
    </ul>
  </li>
  <li><strong><a href="https://orm.drizzle.team/">Drizzle</a></strong>
    <ul>
      <li>Ok fair fine you got me. Technically I used both SQLAlchemy and then ported over to Drizzle so that was a little bit of a mess, but I have heard great things about drizzle and liked that experience a lot. <code class="language-plaintext highlighter-rouge">drizzle-kit</code> is very sleek and the latest major release for drizzle is fantastic.</li>
    </ul>
  </li>
  <li><strong>(somewhat?) Decent Testing</strong>
    <ul>
      <li>Yeah, obviously I wouldn’t say I went crazy here, but I did set up <a href="https://jestjs.io/"><code class="language-plaintext highlighter-rouge">jest</code></a>, <a href="https://playwright.dev/"><code class="language-plaintext highlighter-rouge">playwright</code></a>, and <a href="https://storybook.js.org/"><code class="language-plaintext highlighter-rouge">storybook</code></a></li>
      <li>These are all things that the very talented <a href="https://www.linkedin.com/in/barlock/">Michael Barlock</a> first introduced to the team.</li>
      <li>I learned a lot from him and since then, yeah I’ve been trying to incorporate / adopt these a bit more</li>
    </ul>
  </li>
  <li><strong>Unfurl Link Coverage</strong>
    <ul>
      <li>Perhaps trivial, but try sending <code class="language-plaintext highlighter-rouge">https://tennis-scorigami.com/</code> over iMessage? Yup, it uses the mp4 video with autoplay. What about over Discord? Falls to the native sub-5MB gif that plays. LinkedIn? That same gif? Twitter - just uses a static image as a fallback, and Slack? Also has coverage with the gif / animation support.</li>
      <li>I spent a non-trivial amount of time on this because it’s the little details that might help Andy Roddick actually click retweet (although, yeah, I should probably fix that Twitter share link then)</li>
    </ul>
  </li>
</ul>
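<p>To make the streaming point concrete, here’s a minimal, framework-agnostic sketch of NDJSON consumption (the real frontend is TypeScript; names here are mine). The win over a single JSON array is that each node can be parsed, and handed to the renderer, as soon as its line completes, instead of waiting for the full payload:</p>

```python
import io
import json

def iter_ndjson(stream, chunk_size=8192):
    """Incrementally yield one object per completed NDJSON line."""
    buffer = ""
    while True:
        chunk = stream.read(chunk_size)
        if not chunk:
            break
        buffer += chunk
        while "\n" in buffer:
            line, buffer = buffer.split("\n", 1)
            if line.strip():
                yield json.loads(line)
    if buffer.strip():  # last line may lack a trailing newline
        yield json.loads(buffer)

# nodes arrive one per line rather than as one giant JSON array
stream = io.StringIO('{"id": 1}\n{"id": 2}\n{"id": 3}')
nodes = list(iter_ndjson(stream, chunk_size=4))
```

Because memory usage is bounded by the longest line plus one chunk, this scales to six figures of nodes without the parse-everything-at-once spike a plain <code class="language-plaintext highlighter-rouge">JSON.parse</code> incurs.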

<h1 id="challenges">Challenges</h1>

<h2 id="data-consolidation">Data Consolidation</h2>

<div class="markdown-alert markdown-alert-disclaimer">
<p>There's a couple important notes I want to make here.</p>

<p>Our data cutoff is then (partially) hinged on <a href="https://github.com/jeffsackmann">Jeff Sackmann</a>. Currently it's good until the end of 2024. I hate that. I'm not running LLMs, I shouldn't have a data cutoff. I'm planning on building a web scraper for ATP results and setting up my own data feeds because why the f not, and it's 2025 and you can roll (most) software if push really comes to shove. I can go on a full blown rant on my Substack or something, but I am aware of this limitation, and I dislike it more than you.</p>

<p>Secondly, consolidating all of this data from disparate sources always presents a challenge. There's the age-old problem of the same logical data arriving via different ingestion sources. I have pulled some of the 2024 matches with SportRadar, with RapidApi, and with Sackmann, so consolidating that definitely took a bit of elbow grease. Obviously, I used LLMs for parts of this project, but that part was probably the most hands-on and manually driven. Yes, don't worry - I set up my <a href="https://www.npmjs.com/package/@modelcontextprotocol/server-postgres">Postgres MCP</a> and, after porting to Neon, my <a href="https://neon.com/docs/ai/neon-mcp-server">Neon MCP</a>, but man yeah... still early days I suppose.</p>

</div>
<p><br /></p>

<p>Again, as referenced <a href="https://www.tennis-scorigami.com/about#data-collection">here</a>, tennis data is a hot commodity. It is insanely annoying and hard to get clean data. I am adamant that another side project spinning out of this will be a free, publicly available API for people to query tennis data from.</p>

<p>I tried numerous things:</p>

<ul>
  <li><a href="https://sportradar.com/media-tech/data-content/sports-data-api/">SportRadar</a>
    <ul>
      <li>they’re one of the world’s best data providers</li>
      <li>however, they are absurdly expensive. there’s more info in the Reddit below but they don’t have set plans. as of 2 years ago, for a small time project, they were $1250 a month</li>
      <li>I tried their free trial, ripped as much as I could (of recent tournaments), and then got rate limited, and my trial expired</li>
      <li>Needless to say that was a bit of a miss</li>
    </ul>
  </li>
</ul>

<blockquote class="reddit-embed-bq" style="height:316px" data-embed-theme="dark" data-embed-height="356"><a href="https://www.reddit.com/r/Sportradar/comments/s9j4tl/api_pricing/">API Pricing?</a><br /> by<a href=""></a> in<a href="https://www.reddit.com/r/Sportradar/">Sportradar</a></blockquote>
<script async="" src="https://embed.reddit.com/widgets.js" charset="UTF-8"></script>

<ul>
  <li>(continuation)
    <ul>
      <li>one nice thing about SportRadar though was that they publish an <code class="language-plaintext highlighter-rouge">OpenAPI</code> (API not AI 🙄) spec of their endpoints.</li>
      <li>You can see their <a href="https://api.sportradar.com/tennis/production/v3/openapi/swagger/index.html#/">Swagger docs here</a>, and download the <code class="language-plaintext highlighter-rouge">openapi.yaml</code> file <a href="https://api.sportradar.com/tennis/production/v3/openapi/openapi.yaml">here</a></li>
      <li>This made generating a Python client for interacting with it very easy.</li>
      <li>I even created a SportRadar specific Python client for this.</li>
      <li>That repo was (more or less) entirely created with: <code class="language-plaintext highlighter-rouge">openapi-python-client generate --path ~/Downloads/openapi.yaml --config tennis-config.yaml</code> with my <code class="language-plaintext highlighter-rouge">tennis-config.yaml</code> being as simple as:</li>
    </ul>
  </li>
</ul>

<div class="language-yaml highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="na">project_name_override</span><span class="pi">:</span> <span class="s">sportradar-tennis-v3</span>
<span class="na">package_name_override</span><span class="pi">:</span> <span class="s">sportradar_tennis_v3</span>
</code></pre></div></div>

<p><a href="https://github.com/johnlarkin1/sportradar-tennis-v3"><strong>Check out the GH repo for the SportRadar client here</strong></a></p>

<div class="github-repo-card" data-repo="johnlarkin1/sportradar-tennis-v3">
  <div class="github-repo-loading">
    <div class="loading-spinner"></div>
    <p>Loading repository data...</p>
  </div>
  <div class="github-repo-content" style="display: none">
    <div class="github-repo-header">
      <svg class="github-icon" height="20" width="20" viewBox="0 0 16 16" version="1.1" aria-hidden="true">
        <path fill-rule="evenodd" d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0016 8c0-4.42-3.58-8-8-8z"></path>
      </svg>
      <h3 class="github-repo-name">
        <a href="" target="_blank" rel="noopener noreferrer"></a>
      </h3>
    </div>
    <div class="github-repo-meta">
      <span class="github-repo-author">by <a href="" target="_blank" rel="noopener noreferrer"></a></span>
    </div>
    <p class="github-repo-description"></p>
    <div class="github-repo-stats">
      <div class="github-repo-stat github-repo-language" style="display: none">
        <span class="language-color"></span>
        <span class="language-name"></span>
      </div>
      <div class="github-repo-stat">
        <svg class="github-stat-icon" viewBox="0 0 16 16" version="1.1" width="16" height="16" aria-hidden="true">
          <path fill-rule="evenodd" d="M8 .25a.75.75 0 01.673.418l1.882 3.815 4.21.612a.75.75 0 01.416 1.279l-3.046 2.97.719 4.192a.75.75 0 01-1.088.791L8 12.347l-3.766 1.98a.75.75 0 01-1.088-.79l.72-4.194L.818 6.374a.75.75 0 01.416-1.28l4.21-.611L7.327.668A.75.75 0 018 .25zm0 2.445L6.615 5.5a.75.75 0 01-.564.41l-3.097.45 2.24 2.184a.75.75 0 01.216.664l-.528 3.084 2.769-1.456a.75.75 0 01.698 0l2.77 1.456-.53-3.084a.75.75 0 01.216-.664l2.24-2.183-3.096-.45a.75.75 0 01-.564-.41L8 2.694v.001z"></path>
        </svg>
        <span class="stars-count">0</span>
      </div>
      <div class="github-repo-stat">
        <svg class="github-stat-icon" viewBox="0 0 16 16" version="1.1" width="16" height="16" aria-hidden="true">
          <path fill-rule="evenodd" d="M5 3.25a.75.75 0 11-1.5 0 .75.75 0 011.5 0zm0 2.122a2.25 2.25 0 10-1.5 0v.878A2.25 2.25 0 005.75 8.5h1.5v2.128a2.251 2.251 0 101.5 0V8.5h1.5a2.25 2.25 0 002.25-2.25v-.878a2.25 2.25 0 10-1.5 0v.878a.75.75 0 01-.75.75h-4.5A.75.75 0 015 6.25v-.878zm3.75 7.378a.75.75 0 11-1.5 0 .75.75 0 011.5 0zm3-8.75a.75.75 0 100-1.5.75.75 0 000 1.5z"></path>
        </svg>
        <span class="forks-count">0</span>
      </div>
    </div>
    <div class="github-repo-topics" style="display: none"></div>
  </div>
  <div class="github-repo-error" style="display: none">
    <p>Unable to load repository data</p>
  </div>
</div>

<ul>
  <li><a href="https://rapidapi.com/">RapidAPI</a>
    <ul>
      <li>Ok, but I moved on rapidly (pun intended), because I wasn’t about to pay $1,250 a month (unless I was running betting strats)</li>
      <li>I switched to <a href="https://rapidapi.com/">RapidAPI</a> given their generous free plans</li>
      <li>The data quality also suffered here, and a lot of the APIs had pretty stringent rate limits (2k calls <em>per month</em>)</li>
      <li>Given that, I eventually turned away after pulling what I could</li>
    </ul>
  </li>
  <li><a href="https://github.com/jeffsackmann">Sackmann</a>
    <ul>
      <li>I seriously need to buy Jeff Sackmann a beer</li>
      <li>He’s consolidated <strong>years</strong> of tennis data into a decently well organized format</li>
      <li>Sure, there are duplicate players, strange score formats, partial data, conflicts with accents, etc.</li>
      <li>Lots of the traditional data quality issues, but it’s easier to start from an excess and cleanse than to pull data out of thin air</li>
      <li>So we stuck with Jeff for our historical data</li>
    </ul>
  </li>
</ul>

<h2 id="being-cheap">Being Cheap</h2>

<p>This one will be quick, but another challenge was simply my desire not to spend money, specifically on a hosted database service. I started with <a href="https://supabase.com/">Supabase</a>, which I love and use for many other projects, but then pivoted to <a href="https://aiven.io/">Aiven</a>, which I honestly liked even more. However, Aiven didn’t have connection pooling, and I figured that if this did go viral I would get burned, and people would say I was a bad engineer as we throttled on Aiven’s free-plan limit of roughly 10 open database connections. So finally, I ended up with <a href="https://neon.com/">Neon</a>, because I’ve been wanting to try them and they’re slightly cheaper than <a href="https://supabase.com/">Supabase</a>. Here’s an AI-generated table summarizing the pros, cons, and decisions:</p>

<div class="markdown-alert markdown-alert-ai">
<p>This table was generated using AI.</p>
</div>

<table>
  <thead>
    <tr>
      <th>Provider</th>
      <th>Pros</th>
      <th>Cons</th>
      <th>Decision Rationale</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td><a href="https://supabase.com/">Supabase</a></td>
      <td>- Great developer experience<br />- Integrated auth + storage<br />- Rich UI/dashboard</td>
      <td>- Slightly pricier for scaling<br />- Overhead from extra services if only DB is needed</td>
      <td>Preferred for full-stack apps, but overkill + cost when only DB needed</td>
    </tr>
    <tr>
      <td><a href="https://aiven.io/">Aiven</a></td>
      <td>- Excellent reliability<br />- Flexible managed Postgres<br />- Simple CLI/tools</td>
      <td>- No built-in connection pooling<br />- Free tier limit: ~10 connections</td>
      <td>Risky for viral spikes; free tier would throttle app &amp; reflect poorly on eng.</td>
    </tr>
    <tr>
      <td><a href="https://neon.com/">Neon</a></td>
      <td>- Built-in connection pooling<br />- Autoscaling<br />- Cheaper than Supabase<br />- Separation of storage &amp; compute</td>
      <td>- Newer platform, less mature<br />- Limited ecosystem/integrations compared to Supabase</td>
      <td>Chosen for price/perf tradeoff; avoids pooling issues; good opportunity to test</td>
    </tr>
  </tbody>
</table>

<h2 id="fetching-108k-nodes">Fetching 108k Nodes</h2>

<p>This was actually pretty fine, truth be told, in terms of backend performance. The big win of utilizing <a href="https://en.wikipedia.org/wiki/Materialized_view">materialized views</a> was that I could pull the pertinent information and shape it into exactly the form that my various frontend graph views require.</p>

<p>Once we get to live data, I’ll set up cronjobs or triggers to refresh these and build out my data pipeline a bit more. However, for the moment, these materialized views were sufficient.</p>

<p>Here’s one as an example:</p>

<div class="language-sql highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">CREATE</span> <span class="n">MATERIALIZED</span> <span class="k">VIEW</span> <span class="k">public</span><span class="p">.</span><span class="n">mv_slug_stats_3_men</span>
<span class="n">TABLESPACE</span> <span class="n">pg_default</span>
<span class="k">AS</span> <span class="k">SELECT</span> <span class="n">s</span><span class="p">.</span><span class="n">event_id</span><span class="p">,</span>
    <span class="n">ss</span><span class="p">.</span><span class="n">sequence_id</span> <span class="k">AS</span> <span class="n">id</span><span class="p">,</span>
    <span class="n">ss</span><span class="p">.</span><span class="n">slug</span><span class="p">,</span>
    <span class="n">ss</span><span class="p">.</span><span class="n">depth</span><span class="p">,</span>
    <span class="n">ss</span><span class="p">.</span><span class="n">winner_sets</span><span class="p">,</span>
    <span class="n">ss</span><span class="p">.</span><span class="n">loser_sets</span><span class="p">,</span>
    <span class="n">is_terminal_3</span><span class="p">(</span><span class="n">ss</span><span class="p">.</span><span class="n">winner_sets</span><span class="p">,</span> <span class="n">ss</span><span class="p">.</span><span class="n">loser_sets</span><span class="p">)</span> <span class="k">AS</span> <span class="n">is_terminal</span><span class="p">,</span>
    <span class="mi">3</span> <span class="k">AS</span> <span class="n">best_of</span><span class="p">,</span>
    <span class="n">COALESCE</span><span class="p">(</span><span class="n">s</span><span class="p">.</span><span class="n">played</span><span class="p">,</span> <span class="k">false</span><span class="p">)</span> <span class="k">AS</span> <span class="n">played</span><span class="p">,</span>
    <span class="n">COALESCE</span><span class="p">(</span><span class="n">s</span><span class="p">.</span><span class="n">occurrences</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="k">AS</span> <span class="n">occurrences</span>
   <span class="k">FROM</span> <span class="n">score_sequence</span> <span class="n">ss</span>
     <span class="k">LEFT</span> <span class="k">JOIN</span> <span class="n">mv_sequence_stats_3_men</span> <span class="n">s</span> <span class="k">ON</span> <span class="n">s</span><span class="p">.</span><span class="n">sequence_id</span> <span class="o">=</span> <span class="n">ss</span><span class="p">.</span><span class="n">sequence_id</span>
  <span class="k">WHERE</span> <span class="n">ss</span><span class="p">.</span><span class="n">best_of</span> <span class="o">&lt;=</span> <span class="mi">3</span>
<span class="k">WITH</span> <span class="k">DATA</span><span class="p">;</span>

<span class="c1">-- this is so that we can filter by event_id which is what happens on the frontend when</span>
<span class="c1">-- a user selects either a tournament or a year (more or less)</span>
<span class="k">CREATE</span> <span class="k">INDEX</span> <span class="n">idx_mv_slug_stats_3_men_event</span> <span class="k">ON</span> <span class="k">public</span><span class="p">.</span><span class="n">mv_slug_stats_3_men</span> <span class="k">USING</span> <span class="n">btree</span> <span class="p">(</span><span class="n">event_id</span><span class="p">);</span>
</code></pre></div></div>
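<p>When live data lands and the refresh step gets scheduled, it would presumably look something like the following. Note that <code class="language-plaintext highlighter-rouge">REFRESH MATERIALIZED VIEW CONCURRENTLY</code> requires a unique index on the view (the <code class="language-plaintext highlighter-rouge">event_id</code> index above is not unique), so this sketch adds one first, assuming <code class="language-plaintext highlighter-rouge">id</code> is unique per row:</p>

```sql
-- Sketch of the eventual refresh step (not yet deployed).
-- CONCURRENTLY lets readers keep querying while the view rebuilds,
-- but it requires a UNIQUE index; this assumes id (sequence_id) is unique.
CREATE UNIQUE INDEX IF NOT EXISTS idx_mv_slug_stats_3_men_id
    ON public.mv_slug_stats_3_men (id);

REFRESH MATERIALIZED VIEW CONCURRENTLY public.mv_slug_stats_3_men;
```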

<h2 id="rendering-108k-nodes">Rendering 108k Nodes</h2>

<p>There are numerous repos (some of which I’m using) that help process and render beautiful graphics. I didn’t roll my own physics engine or force-graph library; I used these:</p>

<h3 id="2d-sigma-graph">2D Sigma Graph</h3>

<div class="github-repo-card" data-repo="jacomyal/sigma.js">
  <div class="github-repo-loading">
    <div class="loading-spinner"></div>
    <p>Loading repository data...</p>
  </div>
  <div class="github-repo-content" style="display: none">
    <div class="github-repo-header">
      <svg class="github-icon" height="20" width="20" viewBox="0 0 16 16" version="1.1" aria-hidden="true">
        <path fill-rule="evenodd" d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0016 8c0-4.42-3.58-8-8-8z"></path>
      </svg>
      <h3 class="github-repo-name">
        <a href="" target="_blank" rel="noopener noreferrer"></a>
      </h3>
    </div>
    <div class="github-repo-meta">
      <span class="github-repo-author">by <a href="" target="_blank" rel="noopener noreferrer"></a></span>
    </div>
    <p class="github-repo-description"></p>
    <div class="github-repo-stats">
      <div class="github-repo-stat github-repo-language" style="display: none">
        <span class="language-color"></span>
        <span class="language-name"></span>
      </div>
      <div class="github-repo-stat">
        <svg class="github-stat-icon" viewBox="0 0 16 16" version="1.1" width="16" height="16" aria-hidden="true">
          <path fill-rule="evenodd" d="M8 .25a.75.75 0 01.673.418l1.882 3.815 4.21.612a.75.75 0 01.416 1.279l-3.046 2.97.719 4.192a.75.75 0 01-1.088.791L8 12.347l-3.766 1.98a.75.75 0 01-1.088-.79l.72-4.194L.818 6.374a.75.75 0 01.416-1.28l4.21-.611L7.327.668A.75.75 0 018 .25zm0 2.445L6.615 5.5a.75.75 0 01-.564.41l-3.097.45 2.24 2.184a.75.75 0 01.216.664l-.528 3.084 2.769-1.456a.75.75 0 01.698 0l2.77 1.456-.53-3.084a.75.75 0 01.216-.664l2.24-2.183-3.096-.45a.75.75 0 01-.564-.41L8 2.694v.001z"></path>
        </svg>
        <span class="stars-count">0</span>
      </div>
      <div class="github-repo-stat">
        <svg class="github-stat-icon" viewBox="0 0 16 16" version="1.1" width="16" height="16" aria-hidden="true">
          <path fill-rule="evenodd" d="M5 3.25a.75.75 0 11-1.5 0 .75.75 0 011.5 0zm0 2.122a2.25 2.25 0 10-1.5 0v.878A2.25 2.25 0 005.75 8.5h1.5v2.128a2.251 2.251 0 101.5 0V8.5h1.5a2.25 2.25 0 002.25-2.25v-.878a2.25 2.25 0 10-1.5 0v.878a.75.75 0 01-.75.75h-4.5A.75.75 0 015 6.25v-.878zm3.75 7.378a.75.75 0 11-1.5 0 .75.75 0 011.5 0zm3-8.75a.75.75 0 100-1.5.75.75 0 000 1.5z"></path>
        </svg>
        <span class="forks-count">0</span>
      </div>
    </div>
    <div class="github-repo-topics" style="display: none"></div>
  </div>
  <div class="github-repo-error" style="display: none">
    <p>Unable to load repository data</p>
  </div>
</div>

<h3 id="3d-force-graph">3D Force Graph</h3>

<div class="github-repo-card" data-repo="vasturiano/react-force-graph">
  <div class="github-repo-loading">
    <div class="loading-spinner"></div>
    <p>Loading repository data...</p>
  </div>
  <div class="github-repo-content" style="display: none">
    <div class="github-repo-header">
      <svg class="github-icon" height="20" width="20" viewBox="0 0 16 16" version="1.1" aria-hidden="true">
        <path fill-rule="evenodd" d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0016 8c0-4.42-3.58-8-8-8z"></path>
      </svg>
      <h3 class="github-repo-name">
        <a href="" target="_blank" rel="noopener noreferrer"></a>
      </h3>
    </div>
    <div class="github-repo-meta">
      <span class="github-repo-author">by <a href="" target="_blank" rel="noopener noreferrer"></a></span>
    </div>
    <p class="github-repo-description"></p>
    <div class="github-repo-stats">
      <div class="github-repo-stat github-repo-language" style="display: none">
        <span class="language-color"></span>
        <span class="language-name"></span>
      </div>
      <div class="github-repo-stat">
        <svg class="github-stat-icon" viewBox="0 0 16 16" version="1.1" width="16" height="16" aria-hidden="true">
          <path fill-rule="evenodd" d="M8 .25a.75.75 0 01.673.418l1.882 3.815 4.21.612a.75.75 0 01.416 1.279l-3.046 2.97.719 4.192a.75.75 0 01-1.088.791L8 12.347l-3.766 1.98a.75.75 0 01-1.088-.79l.72-4.194L.818 6.374a.75.75 0 01.416-1.28l4.21-.611L7.327.668A.75.75 0 018 .25zm0 2.445L6.615 5.5a.75.75 0 01-.564.41l-3.097.45 2.24 2.184a.75.75 0 01.216.664l-.528 3.084 2.769-1.456a.75.75 0 01.698 0l2.77 1.456-.53-3.084a.75.75 0 01.216-.664l2.24-2.183-3.096-.45a.75.75 0 01-.564-.41L8 2.694v.001z"></path>
        </svg>
        <span class="stars-count">0</span>
      </div>
      <div class="github-repo-stat">
        <svg class="github-stat-icon" viewBox="0 0 16 16" version="1.1" width="16" height="16" aria-hidden="true">
          <path fill-rule="evenodd" d="M5 3.25a.75.75 0 11-1.5 0 .75.75 0 011.5 0zm0 2.122a2.25 2.25 0 10-1.5 0v.878A2.25 2.25 0 005.75 8.5h1.5v2.128a2.251 2.251 0 101.5 0V8.5h1.5a2.25 2.25 0 002.25-2.25v-.878a2.25 2.25 0 10-1.5 0v.878a.75.75 0 01-.75.75h-4.5A.75.75 0 015 6.25v-.878zm3.75 7.378a.75.75 0 11-1.5 0 .75.75 0 011.5 0zm3-8.75a.75.75 0 100-1.5.75.75 0 000 1.5z"></path>
        </svg>
        <span class="forks-count">0</span>
      </div>
    </div>
    <div class="github-repo-topics" style="display: none"></div>
  </div>
  <div class="github-repo-error" style="display: none">
    <p>Unable to load repository data</p>
  </div>
</div>

<p>Regardless, I wanted a very slick and performant frontend to render all of these nodes. The 5-set 3D force graph is not yet built out. You’ll note that <code class="language-plaintext highlighter-rouge">vasturiano</code> has a large-graph demo <a href="https://vasturiano.github.io/react-force-graph/example/large-graph/">here</a>:</p>

<div style="text-align: center;">
<iframe src="https://vasturiano.github.io/react-force-graph/example/large-graph/" width="800px" height="600px" frameborder="0"></iframe>
</div>

<p>And that looks fantastic. However, if you look at the source code…</p>

<div class="language-javascript highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1">// https://github.com/vasturiano/react-force-graph/blob/master/example/large-graph/index.html</span>
<span class="o">&lt;</span><span class="nx">head</span><span class="o">&gt;</span>
  <span class="o">&lt;</span><span class="nx">style</span><span class="o">&gt;</span> <span class="nx">body</span> <span class="p">{</span> <span class="nl">margin</span><span class="p">:</span> <span class="mi">0</span><span class="p">;</span> <span class="p">}</span> <span class="o">&lt;</span><span class="sr">/style</span><span class="err">&gt;
</span>
  <span class="o">&lt;</span><span class="nx">script</span> <span class="nx">type</span><span class="o">=</span><span class="dl">"</span><span class="s2">importmap</span><span class="dl">"</span><span class="o">&gt;</span><span class="p">{</span> <span class="dl">"</span><span class="s2">imports</span><span class="dl">"</span><span class="p">:</span> <span class="p">{</span>
    <span class="dl">"</span><span class="s2">react</span><span class="dl">"</span><span class="p">:</span> <span class="dl">"</span><span class="s2">https://esm.sh/react</span><span class="dl">"</span><span class="p">,</span>
    <span class="dl">"</span><span class="s2">react-dom</span><span class="dl">"</span><span class="p">:</span> <span class="dl">"</span><span class="s2">https://esm.sh/react-dom/client</span><span class="dl">"</span>
  <span class="p">}}</span><span class="o">&lt;</span><span class="sr">/script</span><span class="err">&gt;
</span>
<span class="c">&lt;!--</span>  <span class="o">&lt;</span><span class="nx">script</span> <span class="nx">type</span><span class="o">=</span><span class="dl">"</span><span class="s2">module</span><span class="dl">"</span><span class="o">&gt;</span><span class="k">import</span> <span class="o">*</span> <span class="k">as</span> <span class="nx">React</span> <span class="k">from</span> <span class="dl">'</span><span class="s1">react</span><span class="dl">'</span><span class="p">;</span> <span class="nb">window</span><span class="p">.</span><span class="nx">React</span> <span class="o">=</span> <span class="nx">React</span><span class="p">;</span><span class="o">&lt;</span><span class="sr">/script&gt;--</span><span class="err">&gt;
</span><span class="c">&lt;!--</span>  <span class="o">&lt;</span><span class="nx">script</span> <span class="nx">src</span><span class="o">=</span><span class="dl">"</span><span class="s2">../../src/packages/react-force-graph-3d/dist/react-force-graph-3d.js</span><span class="dl">"</span> <span class="nx">defer</span><span class="o">&gt;&lt;</span><span class="sr">/script&gt;--</span><span class="err">&gt;
</span><span class="o">&lt;</span><span class="sr">/head</span><span class="err">&gt;
</span>
<span class="o">&lt;</span><span class="nx">body</span><span class="o">&gt;</span>
  <span class="o">&lt;</span><span class="nx">div</span> <span class="nx">id</span><span class="o">=</span><span class="dl">"</span><span class="s2">graph</span><span class="dl">"</span><span class="o">&gt;&lt;</span><span class="sr">/div</span><span class="err">&gt;
</span>
  <span class="o">&lt;</span><span class="nx">script</span> <span class="nx">src</span><span class="o">=</span><span class="dl">"</span><span class="s2">//cdn.jsdelivr.net/npm/@babel/standalone</span><span class="dl">"</span><span class="o">&gt;&lt;</span><span class="sr">/script</span><span class="err">&gt;
</span>  <span class="o">&lt;</span><span class="nx">script</span> <span class="nx">type</span><span class="o">=</span><span class="dl">"</span><span class="s2">text/jsx</span><span class="dl">"</span> <span class="nx">data</span><span class="o">-</span><span class="nx">type</span><span class="o">=</span><span class="dl">"</span><span class="s2">module</span><span class="dl">"</span><span class="o">&gt;</span>
    <span class="k">import</span> <span class="nx">ForceGraph3D</span> <span class="k">from</span> <span class="dl">'</span><span class="s1">https://esm.sh/react-force-graph-3d?external=react</span><span class="dl">'</span><span class="p">;</span>
    <span class="k">import</span> <span class="nx">React</span> <span class="k">from</span> <span class="dl">'</span><span class="s1">react</span><span class="dl">'</span><span class="p">;</span>
    <span class="k">import</span> <span class="p">{</span> <span class="nx">createRoot</span> <span class="p">}</span> <span class="k">from</span> <span class="dl">'</span><span class="s1">react-dom</span><span class="dl">'</span><span class="p">;</span>

    <span class="nx">fetch</span><span class="p">(</span><span class="dl">'</span><span class="s1">../datasets/blocks.json</span><span class="dl">'</span><span class="p">).</span><span class="nx">then</span><span class="p">(</span><span class="nx">res</span> <span class="o">=&gt;</span> <span class="nx">res</span><span class="p">.</span><span class="nx">json</span><span class="p">()).</span><span class="nx">then</span><span class="p">(</span><span class="nx">data</span> <span class="o">=&gt;</span> <span class="p">{</span>
      <span class="nx">createRoot</span><span class="p">(</span><span class="nb">document</span><span class="p">.</span><span class="nx">getElementById</span><span class="p">(</span><span class="dl">'</span><span class="s1">graph</span><span class="dl">'</span><span class="p">)).</span><span class="nx">render</span><span class="p">(</span>
        <span class="o">&lt;</span><span class="nx">ForceGraph3D</span>
          <span class="nx">graphData</span><span class="o">=</span><span class="p">{</span><span class="nx">data</span><span class="p">}</span>
          <span class="nx">nodeLabel</span><span class="o">=</span><span class="p">{</span><span class="nx">node</span> <span class="o">=&gt;</span> <span class="o">&lt;</span><span class="nx">div</span><span class="o">&gt;&lt;</span><span class="nx">b</span><span class="o">&gt;</span><span class="p">{</span><span class="nx">node</span><span class="p">.</span><span class="nx">user</span><span class="p">}</span><span class="o">&lt;</span><span class="sr">/b&gt;: {node.description}&lt;/</span><span class="nx">div</span><span class="o">&gt;</span><span class="p">}</span>
          <span class="nx">nodeAutoColorBy</span><span class="o">=</span><span class="dl">"</span><span class="s2">user</span><span class="dl">"</span>
          <span class="nx">linkDirectionalParticles</span><span class="o">=</span><span class="p">{</span><span class="mi">1</span><span class="p">}</span>
        <span class="sr">/</span><span class="err">&gt;
</span>      <span class="p">);</span>
    <span class="p">});</span>
  <span class="o">&lt;</span><span class="sr">/script</span><span class="err">&gt;
</span><span class="o">&lt;</span><span class="sr">/body</span><span class="err">&gt;
</span></code></pre></div></div>

<p>And then look at the underlying <code class="language-plaintext highlighter-rouge">../datasets/blocks.json</code>, do a little bit of Python/JSON handling… and you’ll note:</p>

<div class="language-shell highlighter-rouge"><div class="highlight"><pre class="highlight"><code>╭─johnlarkin@Mac ~/Documents/coding/tennis-scorigami ‹feature/john/example-blog-post<span class="k">*</span>›
╰─➤  python                                                                                                                          127 ↵
Python 3.11.7 <span class="o">(</span>v3.11.7:fa7a6f2303, Dec  4 2023, 15:22:56<span class="o">)</span> <span class="o">[</span>Clang 13.0.0 <span class="o">(</span>clang-1300.0.29.30<span class="o">)]</span> on darwin
Type <span class="s2">"help"</span>, <span class="s2">"copyright"</span>, <span class="s2">"credits"</span> or <span class="s2">"license"</span> <span class="k">for </span>more information.
<span class="o">&gt;&gt;&gt;</span> import json
<span class="o">&gt;&gt;&gt;</span> import sys
<span class="o">&gt;&gt;&gt;</span> from pathlib import Path
<span class="o">&gt;&gt;&gt;</span> raw_path <span class="o">=</span> Path<span class="o">(</span><span class="s1">'/Users/johnlarkin/Downloads/blocks.json'</span><span class="o">)</span>
<span class="o">&gt;&gt;&gt;</span> raw_content <span class="o">=</span> raw_path.read_text<span class="o">()</span>
<span class="o">&gt;&gt;&gt;</span> data <span class="o">=</span> json.loads<span class="o">(</span>raw_content<span class="o">)</span>
<span class="o">&gt;&gt;&gt;</span> data[<span class="s1">'nodes'</span><span class="o">][</span>:2]
<span class="o">[{</span><span class="s1">'id'</span>: <span class="s1">'4062045'</span>, <span class="s1">'user'</span>: <span class="s1">'mbostock'</span>, <span class="s1">'description'</span>: <span class="s1">'Force-Directed Graph'</span><span class="o">}</span>, <span class="o">{</span><span class="s1">'id'</span>: <span class="s1">'1341021'</span>, <span class="s1">'user'</span>: <span class="s1">'mbostock'</span>, <span class="s1">'description'</span>: <span class="s1">'Parallel Coordinates'</span><span class="o">}]</span>
<span class="o">&gt;&gt;&gt;</span> len<span class="o">(</span>data[<span class="s1">'nodes'</span><span class="o">])</span>
1238 <span class="c"># ... so... not actually that large</span>
<span class="o">&gt;&gt;&gt;</span> len<span class="o">(</span>data[<span class="s1">'links'</span><span class="o">])</span>
2602
</code></pre></div></div>

<p>So… compared to 108k nodes and a similar number of edges… not exactly a lift and shift.</p>

<hr />

<p>However, the 2D Sigma graph (with a bit of streaming) can handle this with relative ease. Note, we are playing a bit of a visual game by doing some edge reduction on the farther-out layers of the 5-set graph, basically following this logic:</p>

<div class="language-typescript highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nx">edgesToRender</span> <span class="o">=</span> <span class="nx">data</span><span class="p">.</span><span class="nx">edges</span><span class="p">.</span><span class="nx">filter</span><span class="p">((</span><span class="nx">edge</span><span class="p">)</span> <span class="o">=&gt;</span> <span class="p">{</span>
  <span class="kd">const</span> <span class="nx">fromNode</span> <span class="o">=</span> <span class="nx">nodeMap</span><span class="p">.</span><span class="kd">get</span><span class="p">(</span><span class="nx">edge</span><span class="p">.</span><span class="nx">frm</span><span class="p">);</span>
  <span class="kd">const</span> <span class="nx">toNode</span> <span class="o">=</span> <span class="nx">nodeMap</span><span class="p">.</span><span class="kd">get</span><span class="p">(</span><span class="nx">edge</span><span class="p">.</span><span class="nx">to</span><span class="p">);</span>

  <span class="k">if</span> <span class="p">(</span><span class="o">!</span><span class="nx">fromNode</span> <span class="o">||</span> <span class="o">!</span><span class="nx">toNode</span><span class="p">)</span> <span class="k">return</span> <span class="kc">false</span><span class="p">;</span>

  <span class="c1">// Always keep early depth edges (structure)</span>
  <span class="k">if</span> <span class="p">(</span><span class="nb">Math</span><span class="p">.</span><span class="nx">max</span><span class="p">(</span><span class="nx">fromNode</span><span class="p">.</span><span class="nx">depth</span><span class="p">,</span> <span class="nx">toNode</span><span class="p">.</span><span class="nx">depth</span><span class="p">)</span> <span class="o">&lt;=</span> <span class="mi">2</span><span class="p">)</span> <span class="k">return</span> <span class="kc">true</span><span class="p">;</span>

  <span class="c1">// Keep edges to/from unscored nodes (discovery)</span>
  <span class="k">if</span> <span class="p">(</span><span class="o">!</span><span class="nx">fromNode</span><span class="p">.</span><span class="nx">played</span> <span class="o">||</span> <span class="o">!</span><span class="nx">toNode</span><span class="p">.</span><span class="nx">played</span><span class="p">)</span> <span class="k">return</span> <span class="kc">true</span><span class="p">;</span>

  <span class="c1">// Keep edges with high occurrence nodes</span>
  <span class="k">if</span> <span class="p">(</span><span class="nx">fromNode</span><span class="p">.</span><span class="nx">occurrences</span> <span class="o">&gt;</span> <span class="mi">100</span> <span class="o">||</span> <span class="nx">toNode</span><span class="p">.</span><span class="nx">occurrences</span> <span class="o">&gt;</span> <span class="mi">100</span><span class="p">)</span> <span class="k">return</span> <span class="kc">true</span><span class="p">;</span>

  <span class="c1">// For deeper levels, only keep a sample</span>
  <span class="k">return</span> <span class="nb">Math</span><span class="p">.</span><span class="nx">random</span><span class="p">()</span> <span class="o">&lt;</span> <span class="mf">0.1</span><span class="p">;</span> <span class="c1">// Keep 10% of remaining edges</span>
<span class="p">});</span>
</code></pre></div></div>

<h3 id="streaming--ndjson">Streaming + NDJSON</h3>

<p>However, the crux of keeping the frontend lightweight for the 5-set rendering was to use NDJSON.</p>

<p>I had heard of NDJSON as a streaming mechanism but hadn’t used it in a production workflow. <a href="https://www.anthropic.com/claude-code">Claude</a> was obviously a huge help here and took care of implementing that part. The crux of the code looks like this:</p>

<div class="language-typescript highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kd">const</span> <span class="nx">response</span> <span class="o">=</span> <span class="k">await</span> <span class="nx">fetchGraphStream</span><span class="p">({</span>
  <span class="cm">/* {...filters} */</span>
  <span class="na">maxEdgesPerDepth</span><span class="p">:</span> <span class="nx">GRAPH_CONFIG</span><span class="p">.</span><span class="nx">maxEdgesPerDepth</span><span class="p">,</span>
  <span class="na">minOccurrences</span><span class="p">:</span> <span class="nx">GRAPH_CONFIG</span><span class="p">.</span><span class="nx">minOccurrences</span><span class="p">,</span>
  <span class="na">signal</span><span class="p">:</span> <span class="nx">abortController</span><span class="p">.</span><span class="nx">signal</span><span class="p">,</span>
<span class="p">});</span>

<span class="kd">const</span> <span class="nx">reader</span> <span class="o">=</span> <span class="nx">response</span><span class="p">.</span><span class="nx">body</span><span class="p">.</span><span class="nx">getReader</span><span class="p">();</span>
<span class="kd">const</span> <span class="nx">decoder</span> <span class="o">=</span> <span class="k">new</span> <span class="nx">TextDecoder</span><span class="p">();</span>
<span class="kd">let</span> <span class="nx">buffer</span> <span class="o">=</span> <span class="dl">""</span><span class="p">;</span>

<span class="cm">/* more code here */</span>

<span class="k">while</span> <span class="p">(</span><span class="kc">true</span><span class="p">)</span> <span class="p">{</span>
  <span class="kd">const</span> <span class="p">{</span> <span class="nx">done</span><span class="p">,</span> <span class="nx">value</span> <span class="p">}</span> <span class="o">=</span> <span class="k">await</span> <span class="nx">reader</span><span class="p">.</span><span class="nx">read</span><span class="p">();</span>
  <span class="k">if</span> <span class="p">(</span><span class="nx">done</span><span class="p">)</span> <span class="k">break</span><span class="p">;</span>

  <span class="nx">buffer</span> <span class="o">+=</span> <span class="nx">decoder</span><span class="p">.</span><span class="nx">decode</span><span class="p">(</span><span class="nx">value</span><span class="p">,</span> <span class="p">{</span> <span class="na">stream</span><span class="p">:</span> <span class="kc">true</span> <span class="p">});</span>
  <span class="kd">const</span> <span class="nx">lines</span> <span class="o">=</span> <span class="nx">buffer</span><span class="p">.</span><span class="nx">split</span><span class="p">(</span><span class="dl">"</span><span class="se">\n</span><span class="dl">"</span><span class="p">);</span>
  <span class="nx">buffer</span> <span class="o">=</span> <span class="nx">lines</span><span class="p">.</span><span class="nx">pop</span><span class="p">()</span> <span class="o">||</span> <span class="dl">""</span><span class="p">;</span>

  <span class="k">for</span> <span class="p">(</span><span class="kd">const</span> <span class="nx">line</span> <span class="k">of</span> <span class="nx">lines</span><span class="p">)</span> <span class="p">{</span>
    <span class="k">if</span> <span class="p">(</span><span class="o">!</span><span class="nx">line</span><span class="p">.</span><span class="nx">trim</span><span class="p">())</span> <span class="k">continue</span><span class="p">;</span>

    <span class="k">try</span> <span class="p">{</span>
      <span class="kd">const</span> <span class="na">message</span><span class="p">:</span> <span class="nx">StreamMessage</span> <span class="o">=</span> <span class="nx">JSON</span><span class="p">.</span><span class="nx">parse</span><span class="p">(</span><span class="nx">line</span><span class="p">);</span>

      <span class="k">switch</span> <span class="p">(</span><span class="nx">message</span><span class="p">.</span><span class="kd">type</span><span class="p">)</span> <span class="p">{</span>
        <span class="k">case</span> <span class="dl">"</span><span class="s2">meta</span><span class="dl">"</span><span class="p">:</span>
          <span class="nx">streamTotalNodes</span> <span class="o">=</span> <span class="nx">message</span><span class="p">.</span><span class="nx">totalNodes</span><span class="p">;</span>
          <span class="nx">streamTotalEdges</span> <span class="o">=</span> <span class="nx">message</span><span class="p">.</span><span class="nx">totalEdges</span><span class="p">;</span>
          <span class="nx">setTotalNodes</span><span class="p">(</span><span class="nx">streamTotalNodes</span><span class="p">);</span>
          <span class="nx">setTotalEdges</span><span class="p">(</span><span class="nx">streamTotalEdges</span><span class="p">);</span>
          <span class="cm">/* more code here */</span>
          <span class="k">break</span><span class="p">;</span>

        <span class="k">case</span> <span class="dl">"</span><span class="s2">nodes</span><span class="dl">"</span><span class="p">:</span>
          <span class="nx">tempNodes</span><span class="p">.</span><span class="nx">push</span><span class="p">(...</span><span class="nx">message</span><span class="p">.</span><span class="nx">data</span><span class="p">);</span>
          <span class="nx">setLoadedNodes</span><span class="p">(</span><span class="nx">tempNodes</span><span class="p">.</span><span class="nx">length</span><span class="p">);</span>
          <span class="cm">/* more code here */</span>
          <span class="k">break</span><span class="p">;</span>

        <span class="k">case</span> <span class="dl">"</span><span class="s2">edges</span><span class="dl">"</span><span class="p">:</span>
          <span class="nx">tempEdges</span><span class="p">.</span><span class="nx">push</span><span class="p">(...</span><span class="nx">message</span><span class="p">.</span><span class="nx">data</span><span class="p">);</span>
          <span class="nx">setLoadedEdges</span><span class="p">(</span><span class="nx">tempEdges</span><span class="p">.</span><span class="nx">length</span><span class="p">);</span>
          <span class="cm">/* more code here */</span>
          <span class="k">break</span><span class="p">;</span>

        <span class="k">case</span> <span class="dl">"</span><span class="s2">complete</span><span class="dl">"</span><span class="p">:</span>
          <span class="nx">setData</span><span class="p">({</span> <span class="na">nodes</span><span class="p">:</span> <span class="nx">tempNodes</span><span class="p">,</span> <span class="na">edges</span><span class="p">:</span> <span class="nx">tempEdges</span> <span class="p">});</span>
          <span class="cm">/* more code here */</span>
          <span class="k">break</span><span class="p">;</span>
      <span class="p">}</span>
    <span class="p">}</span> <span class="k">catch</span> <span class="p">(</span><span class="nx">e</span><span class="p">)</span> <span class="p">{</span>
      <span class="nx">console</span><span class="p">.</span><span class="nx">error</span><span class="p">(</span><span class="dl">"</span><span class="s2">Failed to parse stream message:</span><span class="dl">"</span><span class="p">,</span> <span class="nx">e</span><span class="p">);</span>
    <span class="p">}</span>
  <span class="p">}</span>
<span class="p">}</span>
</code></pre></div></div>

<p>This works because my backend API route creates a <code class="language-plaintext highlighter-rouge">ReadableStream</code> and returns it as the response body, basically doing something like this:</p>

<div class="language-typescript highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kd">const</span> <span class="nx">stream</span> <span class="o">=</span> <span class="k">new</span> <span class="nx">ReadableStream</span><span class="p">({...});</span>
<span class="k">return</span> <span class="k">new</span> <span class="nx">Response</span><span class="p">(</span><span class="nx">stream</span><span class="p">,</span> <span class="p">{...});</span>
</code></pre></div></div>

<p>Then on the frontend, <code class="language-plaintext highlighter-rouge">response.body</code> is that same <code class="language-plaintext highlighter-rouge">ReadableStream</code>, which has the built-in <code class="language-plaintext highlighter-rouge">getReader()</code> method (returning a <code class="language-plaintext highlighter-rouge">ReadableStreamDefaultReader</code>).</p>

<p>Then to use that <code class="language-plaintext highlighter-rouge">reader</code>, it’s as simple as: <code class="language-plaintext highlighter-rouge">const { done, value } = await reader.read();</code>.</p>

<p>This helped a lot: it reduced memory load on the frontend and kept the UI interactive and non-blocking while the data streamed in.</p>
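For completeness, here is a hedged sketch of what the producing side can look like. The message shapes (`meta`, `nodes`, `edges`, `complete`) mirror the client switch statement above, but the helper names and chunk size are illustrative, not the actual implementation:

```typescript
// Sketch of server-side NDJSON production: one JSON message per line.
// StreamMessage mirrors the client code above; chunkSize is illustrative.
type StreamMessage =
  | { type: "meta"; totalNodes: number; totalEdges: number }
  | { type: "nodes"; data: unknown[] }
  | { type: "edges"; data: unknown[] }
  | { type: "complete" };

function* ndjsonMessages(
  nodes: unknown[],
  edges: unknown[],
  chunkSize = 500
): Generator<string> {
  const line = (m: StreamMessage) => JSON.stringify(m) + "\n";
  yield line({ type: "meta", totalNodes: nodes.length, totalEdges: edges.length });
  for (let i = 0; i < nodes.length; i += chunkSize)
    yield line({ type: "nodes", data: nodes.slice(i, i + chunkSize) });
  for (let i = 0; i < edges.length; i += chunkSize)
    yield line({ type: "edges", data: edges.slice(i, i + chunkSize) });
  yield line({ type: "complete" });
}

// Wrap the generator in the ReadableStream that the route handler returns.
function toStream(messages: Generator<string>): ReadableStream<Uint8Array> {
  const encoder = new TextEncoder();
  return new ReadableStream({
    pull(controller) {
      const { done, value } = messages.next();
      if (done) controller.close();
      else controller.enqueue(encoder.encode(value));
    },
  });
}
```

Because every message ends in `\n`, the client's `buffer.split("\n")` plus `buffer = lines.pop()` dance correctly handles messages that arrive split across network chunks.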

<h2 id="unfurl-previews">Unfurl Previews</h2>

<p>Just a quick note, but the various unfurl nuances between different platforms are an absolute headache. Someone needs to sort that out soon. There was basically this decision pipeline for me every time:</p>

<p><img src="/images/tennis-scorigami/unfurl-pipeline.png" alt="unfurl-decision-tree" class="center-small lightbox-image" /></p>
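In practice, most of that pipeline collapses to declaring both Open Graph and Twitter card metadata up front. Here is a hedged sketch in the shape of Next.js App Router's `metadata` export (in a real app you would type it as `Metadata` from <code>next</code> and export it from a layout or page; the titles, descriptions, and image path here are illustrative, not the site's actual values):

```typescript
// Hypothetical Next.js App Router metadata object covering the main unfurl
// targets: Open Graph (Slack, iMessage, LinkedIn, ...) plus Twitter card tags.
const metadata = {
  title: "Tennis Scorigami",
  description: "Which tennis scorelines have never happened?",
  openGraph: {
    title: "Tennis Scorigami",
    description: "Which tennis scorelines have never happened?",
    // 1200x630 is the de facto safe image size across platforms
    images: [{ url: "/og-image.png", width: 1200, height: 630 }],
  },
  twitter: {
    card: "summary_large_image",
    title: "Tennis Scorigami",
    images: ["/og-image.png"],
  },
};
```

The annoying part is that some platforms only read Open Graph tags, while Twitter prefers its own `twitter:*` tags, so you end up declaring both.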

<h1 id="surprises">Surprises</h1>

<div class="markdown-alert markdown-alert-note">
<p>I was very surprised (and satisfied) when we saw the 8.5 number. 8.5% seemed shockingly low given how many 5-set matches have been played (note: tiebreak in the fifth). Part of it is that 6-0 and 0-6 sets are very uncommon, which leaves a huge part of the search tree never played before. I chatted with ChatGPT or Claude about this as well, and thought the response was interesting (and partially helped convince me).</p>
<p>That being said, if there was a data quality or data ingestion issue, don't kill the messenger.</p>
</div>

<p>There are numerous reasons for this though. Here are some of them:</p>

<ul>
  <li>there are 108k possible final score outcomes</li>
  <li>there have only been ~40k five-set matches in the Open Era
    <ul>
      <li>that’s a rough estimate but I have only loaded in 40k</li>
    </ul>
  </li>
  <li>we’re excluding 2025
    <ul>
      <li>this is exciting, it means there are lots of scorigamis left to happen</li>
    </ul>
  </li>
  <li>concentrated probability distribution
    <ul>
      <li>consider this valid but unlikely scoreline: <code class="language-plaintext highlighter-rouge">6-7, 7-6, 6-0, 1-6, 7-5</code></li>
      <li>that would require:
        <ul>
          <li>losing a tiebreak</li>
          <li>then winning a tiebreak</li>
          <li>then bageling opponent</li>
          <li>then getting destroyed 1-6</li>
          <li>then a battle in the final set</li>
        </ul>
      </li>
      <li>this is obviously possible but unlikely at the pro level</li>
    </ul>
  </li>
  <li>order matters
    <ul>
      <li>permutations not combinations! so that increases the space</li>
    </ul>
  </li>
</ul>
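The 108k figure in the first bullet is quick to sanity-check: a set can be won by 7 distinct scores (6-0 through 7-6), and there are 1, 3, and 6 W/L orderings for 3-, 4-, and 5-set matches respectively:

```typescript
// Count distinct best-of-5 final scorelines: 7 winning set scores per set,
// times the number of W/L orderings for each match length.
const SET_SCORES = 7; // 6-0, 6-1, 6-2, 6-3, 6-4, 7-5, 7-6
const patterns = [
  { sets: 3, orderings: 1 }, // WWW
  { sets: 4, orderings: 3 }, // lone L in one of the first 3 slots
  { sets: 5, orderings: 6 }, // two L's among the first 4 slots: C(4,2)
];
const total = patterns.reduce(
  (sum, p) => sum + p.orderings * SET_SCORES ** p.sets,
  0
);
console.log(total); // 108388

// Observed coverage with roughly 9,200 distinct finals actually played:
console.log(((9200 / total) * 100).toFixed(1)); // "8.5"
```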

<p>Here’s an excerpt from the conversation I had with ChatGPT about this:</p>

<blockquote>
  <p>Below is a concise, self-contained recap of the <strong>math</strong> followed by a list of reasons the realised coverage is only ≈ 8 %.</p>

  <hr />

  <h4 id="the-counting">The counting</h4>

  <table>
    <thead>
      <tr>
        <th>Step</th>
        <th>What we are counting</th>
        <th>Result</th>
      </tr>
    </thead>
    <tbody>
      <tr>
        <td><strong>1. Set-level possibilities</strong></td>
        <td>A player can <strong>win</strong> a set by any of seven legal scores (6-0 … 7-6). From the winner’s viewpoint those 7 options are all we need for every set.</td>
        <td>7 outcomes per set</td>
      </tr>
      <tr>
        <td><strong>2. Match-length patterns</strong></td>
        <td>Arrange “W” (winner’s set) and “L” (loser’s set) with the <strong>last set forced to be W</strong>: <br />• 3-0 → <code class="language-plaintext highlighter-rouge">WWW</code> → 1 pattern <br />• 3-1 → choose where the lone <code class="language-plaintext highlighter-rouge">L</code> goes among the first 3 slots ⇒ C(3,1)=3 <br />• 3-2 → choose where the two <code class="language-plaintext highlighter-rouge">L</code>’s go among the first 4 slots ⇒ C(4,2)=6</td>
        <td>1 + 3 + 6 = <strong>10 patterns</strong></td>
      </tr>
      <tr>
        <td><strong>3. Fill each pattern with scores</strong></td>
        <td>Every set in a k-set pattern has 7 choices ⇒ $7^{k}$. <br />• 3-set patterns: $7^{3}=343$ <br />• 4-set patterns: $7^{4}=2\,401$ <br />• 5-set patterns: $7^{5}=16\,807$</td>
        <td>multiply per pattern</td>
      </tr>
      <tr>
        <td><strong>4. Grand total</strong></td>
        <td>$\sum \text{patterns}\times 7^{k}$</td>
        <td>$1\times 343 \;+\; 3\times 2\,401 \;+\; 6\times 16\,807 = \boxed{108\,388}$ distinct best-of-5 finals</td>
      </tr>
    </tbody>
  </table>

  <hr />

  <h4 id="why-only-8--have-shown-up-on-court">Why only ~8 % have shown up on court</h4>

  <ol>
    <li>
      <p><strong>Matches &lt; boxes</strong>
The Open-era database has only <strong>≈ 45 000</strong> men’s best-of-5 matches. Even if every one were unique we could fill at most 45 000 / 108 388 ≈ 41 % of the grid. Coverage must start below that ceiling.</p>
    </li>
    <li>
      <p><strong>Uneven probability mass</strong>
Tennis scoring is <em>not</em> uniform. Straight-set wins with mundane margins (6-3, 6-4, 6-4) occur constantly, while exotic mixtures such as 6-7, 7-6, 6-0, 1-6, 7-5 require a domination-collapse-rebound sequence that is extraordinarily rare. Most of the 108 388 boxes sit under microscopic probability.</p>
    </li>
    <li>
      <p><strong>Five-set inflation</strong>
Five-setters account for <strong>93 %</strong> of the theoretical space (16 807 × 6 patterns), yet only about <strong>15 %</strong> of real matches go the full distance. That mismatch alone wipes out a huge swath of cells.</p>
    </li>
    <li>
      <p><strong>Correlation between sets</strong>
Set scores are not independent draws: if one player is dominating (say 6-1, 6-2) the conditional chance of the next set being a nail-biter (7-6) is lower than the unconditional 1/7. Correlation compresses the practical outcome space far below the naïve $7^{k}$.</p>
    </li>
  </ol>

  <p>Put differently:</p>

\[\text{Observed coverage} \;=\; \frac{\approx 9\,200 \;\text{distinct finals}}{108\,388} \;\approx\; 8.5\%.\]

  <p>Given the <strong>combinatorial overhang</strong>, the <strong>skewed probability distribution</strong>, and the <strong>shortfall of five-set matches</strong>, landing in the single-digit-percent range is not a surprise—it is the mathematically natural result.</p>
</blockquote>

<h1 id="tech-stack">Tech Stack</h1>

<p>Rather than just word vomit what I used here, I’m actually going to just defer to <a href="https://www.anthropic.com/claude-code">Claude</a> and then summarize with a <a href="https://www.mermaidchart.com/">Mermaid</a> mindmap. What a sentence.</p>

<pre><code class="language-mermaid">mindmap
  root((Tennis Scorigami))
    Frontend
      Next.js 14
        App Router
        RSC
        API Routes
      React 19
      TypeScript
      Tailwind CSS
        Custom Tennis Theme
        Dark Mode
      UI Libraries
        Radix UI
        Shadcn
      Visualization
        SigmaJS
        3D Force Graph
        Dynamic Imports
      State
        Jotai
        React Query
    Backend
      Python
        SQLAlchemy
        Alembic
        Pydantic
      Data Sources
        Sackmann Datasets
        SportRadar API
    Database
      PostgreSQL
        Neon
          Neon MCP
          Connection Pooling
        Supabase
        Aiven
      Drizzle ORM
        Type Safety
        Migrations
      Optimization
        Materialized Views
        Strategic Indexes
        Graph Structure
    Infrastructure
      Deployment
        Vercel
        Edge Functions
      Monitoring
        PostHog Analytics
        Error Tracking
      Performance
        CDN Caching
        SSL/TLS
        Turbopack
</code></pre>

<div class="image-caption">Technology stack mindmap</div>
<p><br /></p>

<p>Here’s a fun example of Neon MCP plugging away with Claude Code (feel free to click on the image to magnify):</p>

<p><img src="/images/tennis-scorigami/neon-mcp-example.png" alt="neon-mcp-example" class="center-small lightbox-image" /></p>

<h1 id="engineering--design">Engineering + Design</h1>

<p>Again, I’m short on time for this blog post (always willing to chat more about it), so I’m just going to summarize with a Claude-generated Mermaid diagram as well. I know it’s a bit tough to see, so feel free to zoom in.</p>

<pre><code class="language-mermaid">graph TB
    %%=== Node Classes for Better Readability ===%%
    classDef source fill:#4F46E5,color:#fff,stroke:#4338CA,stroke-width:2px
    classDef db fill:#059669,color:#fff,stroke:#047857,stroke-width:2px
    classDef pipeline fill:#EA580C,color:#fff,stroke:#C2410C,stroke-width:2px
    classDef api fill:#DC2626,color:#fff,stroke:#B91C1C,stroke-width:2px
    classDef frontend fill:#7C3AED,color:#fff,stroke:#6D28D9,stroke-width:2px
    classDef infra fill:#0891B2,color:#fff,stroke:#0E7490,stroke-width:2px
    classDef viz fill:#DB2777,color:#fff,stroke:#BE185D,stroke-width:2px
    classDef analytics fill:#16A34A,color:#fff,stroke:#15803D,stroke-width:2px
    classDef neutral fill:#6B7280,color:#fff,stroke:#4B5563,stroke-width:2px
    classDef faded fill:#E5E7EB,color:#374151,stroke:#D1D5DB,stroke-width:1px

    %%=== Data Sources ===%%
    subgraph "Data Sources"
        DS1[Jeff Sackmann Datasets]
        DS2[SportRadar API]
        class DS1,DS2 source
    end

    %%=== Data Pipeline ===%%
    subgraph "Data Pipeline (Python)"
        PY1[SQLAlchemy Models]
        PY2[Data Ingestion Scripts]
        PY3[Alembic Migrations]
        PY4[Batch Processing]
        class PY1,PY2,PY3,PY4 pipeline

        DS1 --&gt; PY2
        DS2 --&gt; PY2
        PY2 --&gt; PY1
        PY1 --&gt; PY4
        PY3 --&gt; PY1
    end

    %%=== Database ===%%
    subgraph "Database Layer (PostgreSQL on Neon)"
        DB1[(Main Tables)]
        DB2[(Score Sequences)]
        DB3[(Materialized Views)]
        class DB1,DB2,DB3 db

        DB1 --&gt; |"players, matches, tournaments"| DB2
        DB2 --&gt; |"Pre-computed aggregations"| DB3

        subgraph "Materialized Views"
            MV1[mv_graph_sets_3_men]
            MV2[mv_graph_sets_5_men]
            MV3[mv_graph_sets_3_women]
            MV4[...more views]
            class MV1,MV2,MV3,MV4 db
        end

        DB3 --&gt; MV1
        DB3 --&gt; MV2
        DB3 --&gt; MV3
        DB3 --&gt; MV4
    end

    %%=== API Layer ===%%
    subgraph "API Layer (Next.js)"
        API1["/api/v1/matches"]
        API2["/api/v1/graph"]
        API3["/api/v1/search"]
        API4["/api/v1/filters"]
        API5["/api/v1/scores"]
        class API1,API2,API3,API4,API5 api

        CACHE[Cache Layer&lt;br/&gt;5 min revalidation]
        class CACHE neutral

        API1 --&gt; CACHE
        API2 --&gt; CACHE
        API3 --&gt; CACHE
        API4 --&gt; CACHE
        API5 --&gt; CACHE
    end

    %%=== Frontend ===%%
    subgraph "Frontend (Next.js App Router)"
        subgraph "Pages"
            P1[Home Page]
            P2[Explore Page]
            P3[Search Page]
            P4[About]
            class P1,P2,P3,P4 frontend
        end

        subgraph "State Management"
            J1[Jotai Atoms]
            J2[Graph Controls]
            J3[Filter State]
            class J1,J2,J3 frontend
        end

        subgraph "Visualization Components"
            V1[SigmaJS 2D Graph]
            V2[3D Force Graph]
            V3[Streaming Graph]
            class V1,V2,V3 viz
        end

        P2 --&gt; V1
        P2 --&gt; V2
        P2 --&gt; V3

        J1 --&gt; J2
        J1 --&gt; J3
        J2 --&gt; V1
        J2 --&gt; V2
        J3 --&gt; API2
    end

    %%=== Infrastructure ===%%
    subgraph "Infrastructure"
        I1[Drizzle ORM]
        I2[Connection Pool]
        I3[SSL/TLS]
        I4[PostHog Analytics]
        I5[Turbopack]
        class I1,I2,I3,I5 infra
        class I4 analytics
    end

    PY4 --&gt; DB1
    DB3 --&gt; I1
    I1 --&gt; I2
    I2 --&gt; API1
    I2 --&gt; API2
    I2 --&gt; API3
    I2 --&gt; API4
    I2 --&gt; API5

    CACHE --&gt; P1
    CACHE --&gt; P2
    CACHE --&gt; P3
    CACHE --&gt; P4

    P1 --&gt; I4
    P2 --&gt; I4
</code></pre>

<h1 id="other-fun-visualizations">Other Fun Visualizations</h1>

<h2 id="player-rank-history">Player Rank History</h2>

<p>I was intrigued when I saw that <a href="https://github.com/jeffsackmann">Sackmann</a> also had week-over-week player rank history. There’s more I want to do with <a href="https://www.tennis-scorigami.com/">tennis-scorigami</a>, but for now, I thought it was fun to create some of these visualizations:</p>

<p><img src="/images/tennis-scorigami/atp_number_one_timeline.png" alt="atp-number-one-timeline" class="center-image lightbox-image" /></p>

<div class="image-caption">ATP #1 Ranking Timeline</div>
<p><br /></p>

<p><img src="/images/tennis-scorigami/wta_number_one_timeline.png" alt="wta-number-one-timeline" class="center-image lightbox-image" /></p>

<div class="image-caption">WTA #1 Ranking Timeline</div>
<p><br /></p>

<h1 id="conclusion">Conclusion</h1>

<p>This was an awesome project to work on and I still think there’s a ton we could do here. If you find any data quality issues, please reach out. Some thoughts about where we could take this:</p>

<ul>
  <li>Richer match data leading to embeddings and vector search</li>
  <li>WebRTC for real time collaboration in some way</li>
  <li>Popularity of searches</li>
  <li>More exposure of conditional probabilities given a player and a score, what might happen next
    <ul>
      <li>Perhaps leverage this into some type of betting.</li>
    </ul>
  </li>
</ul>

<p>As always, feel free to reach out with any questions.</p>]]></content><author><name>johnlarkin1</name></author><category term="⭐️ Favorites" /><category term="Development" /><category term="Friends" /><summary type="html"><![CDATA[This post focuses on how we built our Tennis Scorigami project from a technical standpoint. I’ll discuss the current architecture, some of the design decisions I made, and where I want the project to go next.]]></summary></entry><entry><title type="html">Walk in the Parquet</title><link href="https://johnlarkin1.github.io/2025/walk-in-the-parquet/" rel="alternate" type="text/html" title="Walk in the Parquet" /><published>2025-03-31T00:00:00+00:00</published><updated>2025-03-31T00:00:00+00:00</updated><id>https://johnlarkin1.github.io/2025/walk-in-the-parquet</id><content type="html" xml:base="https://johnlarkin1.github.io/2025/walk-in-the-parquet/"><![CDATA[<p>At <a href="https://mojo.com/">Mojo</a>, we use <a href="https://parquet.apache.org/">Parquet</a> files to store some of our simulation data. I - however - have been increasingly frustrated by the lack of support on macOS for natively viewing them. They are (normally) compressed with the <a href="https://en.wikipedia.org/wiki/Snappy_(compression)">snappy</a> algorithm, and Apple doesn’t have a native application to open them.</p>

<p>So I decided to build one - to help myself out, my teammates at work out, and hopefully some other random engineers out in the wild. In the very least, this blog post will detail how you can build your own desktop application, specifically in this case using <a href="https://tauri.app/">Tauri</a>.</p>

<p>There’s more information (i.e. lame marketing) here: <a href="https://www.walkintheparquet.com/">walkintheparquet.com</a>. Here’s an iframe if you don’t want to leave this page:</p>

<div style="text-align: center;">
<div style="max-width: 800px; margin: 0 auto; box-shadow: 0 12px 28px rgba(0, 0, 0, 0.4), 0 0 0 1px rgba(255, 255, 255, 0.05); border-radius: 4px; overflow: hidden; background-color: #1a1a1a; padding: 15px;">
<iframe src="https://www.walkintheparquet.com/" width="100%" height="450px" frameborder="0" allowfullscreen=""></iframe>
</div>
</div>

<p><br /></p>

<div class="markdown-alert markdown-alert-tip">
<p>Also! If you have feature requests or issues, you can head over to the Canny board for this project and leave some notes. There's a link at the bottom of the main website (/ iframe above), or you can go <a href="https://walk-in-the-parquet.canny.io/">here</a>. Obviously feel free to email me too!</p>
</div>

<!--
# Table of Contents

- [Table of Contents](#table-of-contents)
- [Driving Motivation](#driving-motivation)
  - [Are there really no other solutions?](#are-there-really-no-other-solutions)
- [What is Parquet?](#what-is-parquet)
- [What's the Problem?](#whats-the-problem)
- [Engineering + Design](#engineering--design)
  - [Desktop Decisions](#desktop-decisions)
  - [Challenges and Conquests](#challenges-and-conquests)
    - [Documentation](#documentation)
    - [Supporting Structs](#supporting-structs)
    - [App Store Annoyance](#app-store-annoyance)
- [Conclusion](#conclusion)
  - [Kudos](#kudos)
-->

<h1 id="driving-motivation">Driving Motivation</h1>

<p>To inspire you a little bit, here’s what we’ve built. This is also available (again) <a href="https://www.walkintheparquet.com/">here</a>, but also available on the <a href="https://apps.apple.com/us/app/walk-in-the-parquet/id6743959514?mt=12">App Store</a>.</p>

<p><img src="/images/walk-in-the-parquet/slideshow.gif" alt="walk-in-the-parquet-show" class="center-image" /></p>

<p>And here it is in the <a href="https://apps.apple.com/us/app/walk-in-the-parquet/id6743959514?mt=12">App Store</a>:</p>

<p><img src="/images/walk-in-the-parquet/app-store.png" alt="app-store" class="center-small" /></p>

<p>Again, this blog post is going to talk a bit more about the actual building process, but if you want to see more about the product and download it, head over to the <a href="https://www.walkintheparquet.com/">main website</a>.</p>

<p>And before we go any deeper, I know there’s this question…</p>

<h2 id="are-there-really-no-other-solutions">Are there really no other solutions?</h2>

<p>Yeah I mean there’s this:</p>

<div style="display: flex; justify-content: center; align-items: center;">
    <div style="margin: 10px;">
        <img src="/images/walk-in-the-parquet/other-app-store-pt1.png" alt="Image 1" style="width: 300px; height: auto;" />
    </div>
    <div style="margin: 10px;">
        <img src="/images/walk-in-the-parquet/other-app-store-pt2.png" alt="Image 2" style="width: 600px; height: auto;" />
    </div>
</div>
<div class="image-caption">Can't wait until people are leaving me the same reviews. But hey mine is free. And it's a v1. </div>
<p><br /></p>

<p>Which… yeah, I don’t love people trying to charge for that.</p>

<p>The one that I’ve seen the most is this: <a href="https://www.parquet-viewer.com/">https://www.parquet-viewer.com/</a>, which I use but they <em>also</em> have some dumb paywalling features. And to be honest, I don’t really love uploading potentially sensitive files to the web.</p>

<p>And finally, it seems like the third alternative is a VSCode Extension, but the downside there is it’s just <code class="language-plaintext highlighter-rouge">json</code> I believe. Again, totally fine - I won’t be upset if you want to do that. It’s not as smooth to query or see some top level analytics, but c’est la vie.</p>

<h1 id="what-is-parquet">What is Parquet?</h1>

<p>Let’s back up a little bit for those not even familiar with Parquet. I won’t go into too much detail because there’s enough information out there on the web, but I’ll give a high level overview.</p>

<p><a href="https://parquet.apache.org/">Parquet</a> is a file type that is optimized and used prevalently for data processing frameworks. It was introduced by Apache and has numerous benefits, some of which I’ll get into below.</p>

<p>The big distinguisher for Parquet is that it’s a columnar storage format. This has some big wins, especially in terms of compression. For example, if a column stores many repeated player-ids, it can achieve a vastly higher compression rate because the column has so much redundancy.</p>
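As a toy illustration of why that redundancy matters, here is run-length encoding, one of several encodings Parquet can apply per column (this is a teaching sketch, not Parquet's actual on-disk format):

```typescript
// Toy run-length encoding: a repetitive column collapses to (value, count)
// pairs, whereas row-oriented storage would interleave these values with
// every other field in the row, destroying the runs.
function runLengthEncode<T>(column: T[]): Array<[T, number]> {
  const runs: Array<[T, number]> = [];
  for (const value of column) {
    const last = runs[runs.length - 1];
    if (last && last[0] === value) last[1] += 1;
    else runs.push([value, 1]);
  }
  return runs;
}

const playerIds = ["p1", "p1", "p1", "p1", "p2", "p2", "p3"];
console.log(runLengthEncode(playerIds));
// [["p1", 4], ["p2", 2], ["p3", 1]]: 7 values stored as 3 runs
```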

<p>These columns are then split into row groups, which are logical, horizontal partitions of the data.</p>

<p>These row groups are incredibly clutch for parallel processing because each one can be read independently. Furthermore, readers can use row-group metadata to skip ahead, so only the <em>relevant</em> row groups are read at all.</p>

<p>So this then becomes a bit of a hyperparameter performance optimization question, right? What’s the ideal size for your row groups? Well… yeah, this takes a bit of experimentation. There’s surprisingly little documentation around what is best, but you’re generally trading compression off against performance. I’ve seen people suggest row groups of around 128MB to 512MB, and AWS seems to default to 128MB <a href="https://docs.aws.amazon.com/prescriptive-guidance/latest/apache-iceberg-on-aws/best-practices-read.html">here</a>.</p>

<p>You can think about it like so:</p>

<table>
  <thead>
    <tr>
      <th>Row Group Size</th>
      <th>Pros</th>
      <th>Cons</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td><strong>Larger Row Groups</strong></td>
      <td>- Better read performance (fewer metadata reads, more sequential IO)</td>
      <td>- Higher memory usage during write</td>
    </tr>
    <tr>
      <td> </td>
      <td>- Better compression (larger chunks compress more efficiently)</td>
      <td>- Slower write performance if memory is constrained</td>
    </tr>
    <tr>
      <td><strong>Smaller Row Groups</strong></td>
      <td>- Lower memory usage during write</td>
      <td>- Slower reads (more metadata overhead and disk seeks)</td>
    </tr>
    <tr>
      <td> </td>
      <td>- Faster writes in streaming or frequent-flush scenarios</td>
      <td>- Worse compression</td>
    </tr>
    <tr>
      <td> </td>
      <td> </td>
      <td>- Less effective filtering (min/max stats less meaningful)</td>
    </tr>
  </tbody>
</table>
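Given that 128MB-512MB guidance, picking a row count per group is just arithmetic on your average row size. A back-of-envelope helper (the 200-byte average row is a made-up example; substitute whatever your schema actually averages):

```typescript
// Back-of-envelope: how many rows per row group to hit a target group size.
function rowsPerGroup(targetBytes: number, avgRowBytes: number): number {
  return Math.floor(targetBytes / avgRowBytes);
}

const MB = 1024 * 1024;
// e.g. ~200 bytes/row (illustrative) targeting the 128MB lower bound:
console.log(rowsPerGroup(128 * MB, 200)); // 671088
```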

<p>There’s also other fields like column chunks and pages. The best overview I’ve seen is actually from <a href="https://celerdata.com/">CelerData</a> <a href="https://celerdata.com/glossary/parquet-file-format">here</a>.</p>

<p>This image from CelerData does a good job breaking out the different parts of the underlying structure:</p>

<p><img src="https://parquet.apache.org/images/FileLayout.gif" alt="parquet-file-layout" class="center-shrink" /></p>

<div class="image-caption">Full credit to CelerData for the image</div>
<p><br /></p>

<p>But! If you don’t like that one, noooo worries. Databricks has a $62B valuation and they also wrote about it <a href="https://www.databricks.com/glossary/what-is-parquet">here</a>. So feel free to check out some other links.</p>

<h1 id="whats-the-problem">What’s the Problem?</h1>

<p>Well, the problem I wanted to address is that there’s not a great way to open these files. I discussed some alternatives and their downsides above, but it’s dumb that my only options were a paywalled App Store app or uploading things to the web and some dude’s random server.</p>

<p><img src="/images/walk-in-the-parquet/no-default-application.png" alt="app-store" class="center-small" /></p>

<p>The other problem? I haven’t worked with Rust in a while, and I still desperately want to get better at it, so that was the selfish motivation. It’s a borderline smooth transition into the next section.</p>

<h1 id="engineering--design">Engineering + Design</h1>

<h2 id="desktop-decisions">Desktop Decisions</h2>

<p>Ah the desktop application game - what a question.</p>

<p>Now, I’ve worked with <a href="https://www.electronjs.org/">Electron</a> at <a href="https://www.dropbox.com/">Dropbox</a>, so I was generally familiar with that architecture and paradigm, although it has been a minute since I’ve dealt with <a href="https://www.electronjs.org/docs/latest/tutorial/process-model#preload-scripts">preload scripts</a> or the <a href="https://www.electronjs.org/docs/latest/tutorial/ipc">ipcMain vs ipcRenderer distinction</a>.</p>

<p>The downside (in this case), and why I didn’t choose Electron, was that I didn’t really want an all-TypeScript backend.</p>

<p>Truthfully, I really wanted a Python backend, both because that’s what I’m best at, and because I wanted to use <a href="https://duckdb.org/"><code class="language-plaintext highlighter-rouge">duckdb</code></a> for loading the files and doing analysis quickly, on-disk, and keeping things lightweight. But I hadn’t loaded Parquet files in TypeScript before, I’d been seeing more and more about <a href="https://tauri.app/">Tauri</a>, and I figured this was a better use case for it.</p>

<p>Additionally, I knew from googling Apache Parquet documentation (at work we integrate with <a href="https://go.dev/">Golang</a>, <a href="https://cplusplus.com/">C++</a>, and <a href="https://www.python.org/">Python</a>) that they DO have Rust support. I remember this vividly because I personally think most of the documentation Apache puts out blows. The other noticeable benefit of Rust and Tauri is that Tauri produces a much lighter-weight desktop application.</p>

<p><a href="https://www.coditation.com/">Coditation</a> sums it up pretty well below (<a href="https://www.coditation.com/blog/electron-vs-tauri#:~:text=and%20CPU%20resources.-,Tauri%20is%20designed%20to%20be%20more%20lightweight%20and%20faster%20than,run%20more%20efficiently%20than%20Electron.">ref</a>):</p>

<blockquote>
  <p>Tauri is designed to be more lightweight and faster than Electron, as it uses less memory and CPU resources, which means that Tauri is designed to run more efficiently than Electron.</p>

  <p>Tauri uses Rust as a native layer instead of JavaScript and web technologies, which results in lower memory usage and CPU usage compared to Electron. Additionally, Tauri is also designed to be more lightweight overall, which means that it has less overhead and a smaller binary size than Electron.</p>
</blockquote>

<p>In other words,</p>

<table>
  <thead>
    <tr>
      <th>Feature</th>
      <th>Electron</th>
      <th>Tauri (what I picked)</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>Backend Lang</td>
      <td>JavaScript/TS</td>
      <td>Rust</td>
    </tr>
    <tr>
      <td>Binary Size</td>
      <td>Large</td>
      <td>Small</td>
    </tr>
    <tr>
      <td>Memory Usage</td>
      <td>Higher</td>
      <td>Lower</td>
    </tr>
  </tbody>
</table>

<h2 id="challenges-and-conquests">Challenges and Conquests</h2>

<h3 id="documentation">Documentation</h3>

<p>So this is a core part of it: recently, I have been one of the many to get hit with the “yoooo how much did you vibe code” question. There have been <a href="https://www.reddit.com/r/ProgrammerHumor/comments/1jcjrzf/vibecoding/?utm_source=share&amp;utm_medium=web3x&amp;utm_name=web3xcss&amp;utm_term=1&amp;utm_content=share_button">many</a> <a href="https://preview.redd.it/viberagingnow-v0-1vjd0a87owpe1.jpeg?auto=webp&amp;s=00830b4959b1426e6280068dd59b528257aa8c3b">good</a> <a href="https://preview.redd.it/vibe-coding-v0-hwsv07yperre1.jpeg?auto=webp&amp;s=004dcaea56ead53a5c453efa24d93d174865fa57">memes</a> about this.</p>

<p>The thing is, <strong>I basically did vibecode the entire website. A simple, lightweight static Next.js frontend is a perfect use case for it</strong>. I’m not at all a frontend designer, so no, of course I’m not going to be ripping that manually or going into Figma first or anything like that. So that was lovely. Way faster and way quicker to ship.</p>

<p>The interesting part (at least for me) was how best to architect this with Tauri and handle that handoff. The challenges were about that design, as well as the blatant lack of LLMs trained on Tauri v2 and the Parquet crate versions I was using in Rust.</p>

<p>Specifically,</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>arrow = "54.3.0"
arrow-schema = "54.3.0"
parquet = "54.3.0"
</code></pre></div></div>

<p>these crates had virtually no LLM support (what a breath of fresh air).</p>

<h3 id="supporting-structs">Supporting Structs</h3>

<p>As a result, it meant working from the documentation and figuring out exactly why some of my string data was being parsed as a <code class="language-plaintext highlighter-rouge">Utf8View</code> vs a <code class="language-plaintext highlighter-rouge">Utf8</code>.</p>

<p>In terms of code, it meant that I had in my <code class="language-plaintext highlighter-rouge">sql.rs</code> parsing engine, a match statement (one of Rust’s best features imo) like this:</p>

<div class="language-rust highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1">// ... many more types before</span>

        <span class="nn">DataType</span><span class="p">::</span><span class="n">Int64</span> <span class="k">=&gt;</span> <span class="p">{</span>
            <span class="k">let</span> <span class="n">array</span> <span class="o">=</span> <span class="n">column</span>
                <span class="nf">.as_any</span><span class="p">()</span>
                <span class="py">.downcast_ref</span><span class="p">::</span><span class="o">&lt;</span><span class="n">Int64Array</span><span class="o">&gt;</span><span class="p">()</span>
                <span class="nf">.ok_or_else</span><span class="p">(||</span> <span class="nn">QueryError</span><span class="p">::</span><span class="nf">Other</span><span class="p">(</span><span class="s">"Failed to downcast to Int64Array"</span><span class="nf">.to_string</span><span class="p">()))</span><span class="o">?</span><span class="p">;</span>
            <span class="nn">serde_json</span><span class="p">::</span><span class="nn">Value</span><span class="p">::</span><span class="nf">Number</span><span class="p">(</span><span class="nn">serde_json</span><span class="p">::</span><span class="nn">Number</span><span class="p">::</span><span class="nf">from</span><span class="p">(</span><span class="n">array</span><span class="nf">.value</span><span class="p">(</span><span class="n">row_idx</span><span class="p">)))</span>
        <span class="p">}</span>

        <span class="nn">DataType</span><span class="p">::</span><span class="n">UInt8</span> <span class="k">=&gt;</span> <span class="p">{</span>
            <span class="k">let</span> <span class="n">array</span> <span class="o">=</span> <span class="n">column</span>
                <span class="nf">.as_any</span><span class="p">()</span>
                <span class="py">.downcast_ref</span><span class="p">::</span><span class="o">&lt;</span><span class="n">UInt8Array</span><span class="o">&gt;</span><span class="p">()</span>
                <span class="nf">.ok_or_else</span><span class="p">(||</span> <span class="nn">QueryError</span><span class="p">::</span><span class="nf">Other</span><span class="p">(</span><span class="s">"Failed to downcast to UInt8Array"</span><span class="nf">.to_string</span><span class="p">()))</span><span class="o">?</span><span class="p">;</span>
            <span class="nn">serde_json</span><span class="p">::</span><span class="nn">Value</span><span class="p">::</span><span class="nf">Number</span><span class="p">(</span><span class="nn">serde_json</span><span class="p">::</span><span class="nn">Number</span><span class="p">::</span><span class="nf">from</span><span class="p">(</span><span class="n">array</span><span class="nf">.value</span><span class="p">(</span><span class="n">row_idx</span><span class="p">)</span> <span class="k">as</span> <span class="nb">u64</span><span class="p">))</span>
        <span class="p">}</span>

        <span class="nn">DataType</span><span class="p">::</span><span class="n">UInt16</span> <span class="k">=&gt;</span> <span class="p">{</span>
            <span class="k">let</span> <span class="n">array</span> <span class="o">=</span> <span class="n">column</span>
                <span class="nf">.as_any</span><span class="p">()</span>
                <span class="py">.downcast_ref</span><span class="p">::</span><span class="o">&lt;</span><span class="n">UInt16Array</span><span class="o">&gt;</span><span class="p">()</span>
                <span class="nf">.ok_or_else</span><span class="p">(||</span> <span class="p">{</span>
                    <span class="nn">QueryError</span><span class="p">::</span><span class="nf">Other</span><span class="p">(</span><span class="s">"Failed to downcast to UInt16Array"</span><span class="nf">.to_string</span><span class="p">())</span>
                <span class="p">})</span><span class="o">?</span><span class="p">;</span>
            <span class="nn">serde_json</span><span class="p">::</span><span class="nn">Value</span><span class="p">::</span><span class="nf">Number</span><span class="p">(</span><span class="nn">serde_json</span><span class="p">::</span><span class="nn">Number</span><span class="p">::</span><span class="nf">from</span><span class="p">(</span><span class="n">array</span><span class="nf">.value</span><span class="p">(</span><span class="n">row_idx</span><span class="p">)</span> <span class="k">as</span> <span class="nb">u64</span><span class="p">))</span>
        <span class="p">}</span>

<span class="c1">// ... many more types after</span>
</code></pre></div></div>

<p>There are numerous <code class="language-plaintext highlighter-rouge">DataType</code>s that get pulled in with <code class="language-plaintext highlighter-rouge">use datafusion::arrow::datatypes::*;</code>. I tried to handle most of them, but of course Parquet files can be arbitrarily complex, so as a <code class="language-plaintext highlighter-rouge">v1.0.0</code> I am not promising complete support. There is basic support for nested structures, as seen here:</p>

<p><img src="/images/walk-in-the-parquet/nested-structure.png" alt="nested-structure" class="center-shrink" /></p>

<p>However, handling this recursively is a bit of a challenge, and I expect there are some corner cases I missed.</p>
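<p>To make the shape of that recursion concrete, here’s a minimal, self-contained sketch. To be clear, this is not the real arrow-rs API: the <code class="language-plaintext highlighter-rouge">Cell</code> enum and <code class="language-plaintext highlighter-rouge">to_json</code> function are toy stand-ins I made up for illustration. The actual code does the same walk over Arrow’s struct and list children.</p>

```rust
// Illustrative only: `Cell` is a toy stand-in for nested Arrow values,
// not the real arrow-rs types. The recursion pattern is the point.
#[derive(Debug)]
enum Cell {
    Int(i64),
    Str(String),
    List(Vec<Cell>),
    Struct(Vec<(String, Cell)>),
}

fn to_json(cell: &Cell) -> String {
    match cell {
        // Leaf values serialize directly...
        Cell::Int(n) => n.to_string(),
        Cell::Str(s) => format!("\"{}\"", s),
        // ...while lists and structs recurse into their children,
        // which is exactly where the missed corner cases tend to hide.
        Cell::List(items) => {
            let inner: Vec<String> = items.iter().map(to_json).collect();
            format!("[{}]", inner.join(","))
        }
        Cell::Struct(fields) => {
            let inner: Vec<String> = fields
                .iter()
                .map(|(name, value)| format!("\"{}\":{}", name, to_json(value)))
                .collect();
            format!("{{{}}}", inner.join(","))
        }
    }
}
```

<p>The nice property of this approach is that arbitrarily deep struct-of-list-of-struct nesting falls out of the same four match arms; the hard part in the real version is that each nested Arrow array has to be downcast to its concrete type before you can recurse.</p>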

<h3 id="app-store-annoyance">App Store Annoyance</h3>

<p>By far, though, the biggest learning was around bundling a package for the App Store and the numerous steps to get that going.</p>

<p>There’s already quite a bit out there about <a href="https://developer.apple.com/documentation/security/notarizing-macos-software-before-distribution">notarization</a> and <a href="https://support.apple.com/guide/security/app-code-signing-process-sec7c917bf14/web">code-signing</a>, but I think the most helpful thing was putting it all into a <code class="language-plaintext highlighter-rouge">post-build.sh</code> script.</p>

<p>So basically after running this:</p>

<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>╭─johnlarkin@Mac ~/Documents/coding/walk-in-the-parquet ‹main›
╰─➤  npm run tauri build <span class="nt">--</span> <span class="nt">--target</span> universal-apple-darwin
</code></pre></div></div>

<p>You can get a <a href="https://v1.tauri.app/v1/guides/building/macos#binary-targets">fat client</a> from Tauri that will support both Apple silicon and Intel-based archs.</p>

<p>After doing that, the setup I finally landed on looked like this (thanks Claude for some of the emojis):</p>

<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c">#!/bin/bash</span>
<span class="nb">set</span> <span class="nt">-euo</span> pipefail

<span class="nv">APP_NAME</span><span class="o">=</span><span class="s2">"Walk in the Parquet"</span>
<span class="nv">ENTITLEMENTS_PATH</span><span class="o">=</span><span class="s2">"src-tauri/entitlements.plist"</span>
<span class="nv">PKG_NAME</span><span class="o">=</span><span class="s2">"WalkInTheParquet.pkg"</span>

<span class="c"># I was trying to automatically detect the app-path / dmg-path, but I was</span>
<span class="c"># occasionally picking the wrong app / dmg, and yeah, I was too lazy to fix this</span>
<span class="c"># APP_PATH=$(find src-tauri/target/universal-apple-darwin -type d -name "$APP_NAME.app" | head -n 1)</span>
<span class="nv">APP_PATH</span><span class="o">=</span><span class="s2">"src-tauri/target/universal-apple-darwin/release/bundle/macos/Walk in the Parquet.app"</span>
<span class="nv">DMG_PATH</span><span class="o">=</span><span class="s2">"src-tauri/target/universal-apple-darwin/release/bundle/dmg/Walk in the Parquet_1.0.0_universal.dmg"</span>

<span class="c"># you basically need `APPLE_ISSUER_ID`, `APPLE_PARQUET_KEY_ID`</span>
<span class="c"># set up in your env</span>
<span class="k">if</span> <span class="o">[[</span> <span class="nt">-z</span> <span class="s2">"</span><span class="k">${</span><span class="nv">APPLE_ISSUER_ID</span><span class="k">:-}</span><span class="s2">"</span> <span class="o">]]</span><span class="p">;</span> <span class="k">then
  </span><span class="nb">echo</span> <span class="s2">"🚨 Error: Environment variable APPLE_ISSUER_ID is not set"</span>
  <span class="nb">exit </span>1
<span class="k">else
  </span><span class="nb">echo</span> <span class="s2">"✅ Environment variable APPLE_ISSUER_ID is set"</span>
<span class="k">fi

if</span> <span class="o">[[</span> <span class="nt">-z</span> <span class="s2">"</span><span class="k">${</span><span class="nv">APPLE_PARQUET_KEY_ID</span><span class="k">:-}</span><span class="s2">"</span> <span class="o">]]</span><span class="p">;</span> <span class="k">then
  </span><span class="nb">echo</span> <span class="s2">"🚨 Error: Environment variable APPLE_PARQUET_KEY_ID is not set"</span>
  <span class="nb">exit </span>1
<span class="k">else
  </span><span class="nb">echo</span> <span class="s2">"✅ Environment variable APPLE_PARQUET_KEY_ID is set"</span>
<span class="k">fi
</span><span class="nb">echo</span> <span class="s2">"🔑 All required environment variables are set"</span>

<span class="k">if</span> <span class="o">[[</span> <span class="o">!</span> <span class="nt">-d</span> <span class="s2">"</span><span class="nv">$APP_PATH</span><span class="s2">"</span> <span class="o">]]</span><span class="p">;</span> <span class="k">then
  </span><span class="nb">echo</span> <span class="s2">"🚨 Error: .app bundle not found at expected path: </span><span class="nv">$APP_PATH</span><span class="s2">"</span>
  <span class="nb">exit </span>1
<span class="k">else
  </span><span class="nb">echo</span> <span class="s2">"✅ .app bundle found at: </span><span class="nv">$APP_PATH</span><span class="s2">"</span>
<span class="k">fi
if</span> <span class="o">[[</span> <span class="nt">-z</span> <span class="s2">"</span><span class="nv">$DMG_PATH</span><span class="s2">"</span> <span class="o">]]</span><span class="p">;</span> <span class="k">then
  </span><span class="nb">echo</span> <span class="s2">"🚨 Error: DMG file not found!"</span>
  <span class="nb">exit </span>1
<span class="k">else
  </span><span class="nb">echo</span> <span class="s2">"✅ DMG file found at: </span><span class="nv">$DMG_PATH</span><span class="s2">"</span>
<span class="k">fi

</span><span class="nb">echo</span> <span class="s2">"🔐 Re-signing .app with entitlements using 3rd Party Application cert..."</span>
codesign <span class="nt">--entitlements</span> <span class="s2">"</span><span class="nv">$ENTITLEMENTS_PATH</span><span class="s2">"</span> <span class="nt">--deep</span> <span class="nt">--force</span> <span class="nt">--options</span> runtime <span class="se">\</span>
  <span class="nt">--sign</span> &lt;REDACTED-BUT-PUT-YOUR-KEYCHAIN-NAME-HERE&gt; <span class="s2">"</span><span class="nv">$APP_PATH</span><span class="s2">"</span>

<span class="nb">echo</span> <span class="s2">"🧳 Rebuilding and signing .pkg with 3rd Party Installer cert..."</span>
productbuild <span class="se">\</span>
  <span class="nt">--component</span> <span class="s2">"</span><span class="nv">$APP_PATH</span><span class="s2">"</span> /Applications <span class="se">\</span>
  <span class="nt">--sign</span> &lt;REDACTED-BUT-PUT-YOUR-KEYCHAIN-NAME-HERE&gt; <span class="se">\</span>
  <span class="s2">"</span><span class="nv">$PKG_NAME</span><span class="s2">"</span>

<span class="nb">echo</span> <span class="s2">"🚀 Submitting DMG to notarization..."</span>
xcrun notarytool submit <span class="s2">"</span><span class="nv">$DMG_PATH</span><span class="s2">"</span> <span class="se">\</span>
  <span class="nt">--key</span> &lt;THIS IS THE PATH TO YOUR p8 KEY you downloaded from APPLE&gt; <span class="se">\</span>
  <span class="nt">--key-id</span> <span class="s2">"</span><span class="nv">$APPLE_PARQUET_KEY_ID</span><span class="s2">"</span> <span class="se">\</span>
  <span class="nt">--issuer</span> <span class="s2">"</span><span class="nv">$APPLE_ISSUER_ID</span><span class="s2">"</span> <span class="se">\</span>
  <span class="nt">--keychain-profile</span> <span class="s2">"notarytool-password"</span> <span class="se">\</span>
  <span class="nt">--wait</span>

<span class="c"># check the arch, then staple and validate the notarization ticket</span>
lipo <span class="nt">-info</span> <span class="s2">"</span><span class="nv">$APP_PATH</span><span class="s2">/Contents/MacOS/walk-in-the-parquet"</span>
xcrun stapler staple <span class="s2">"</span><span class="nv">$DMG_PATH</span><span class="s2">"</span>
xcrun stapler validate <span class="s2">"</span><span class="nv">$DMG_PATH</span><span class="s2">"</span>
hdiutil imageinfo <span class="s2">"</span><span class="nv">$DMG_PATH</span><span class="s2">"</span> | <span class="nb">grep </span>Format

<span class="nb">echo</span> <span class="s2">"📦 .pkg ready to be uploaded via Transporter:"</span>
<span class="nb">echo</span> <span class="s2">"   -&gt; </span><span class="nv">$PKG_NAME</span><span class="s2">"</span>
<span class="nb">echo</span> <span class="s2">""</span>
<span class="nb">echo</span> <span class="s2">"🚀 Open Transporter and upload the package manually if needed."</span>
</code></pre></div></div>

<p>The key part is that you’ll want your <code class="language-plaintext highlighter-rouge">--sign</code> argument to be <code class="language-plaintext highlighter-rouge">3rd Party Mac Developer Application &lt;etc&gt;</code>. That is the third-party developer certificate you use for signing the app (with its Installer counterpart doing the same job for the .pkg).</p>
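<p>If you’re not sure what the exact identity string is on your machine, <code class="language-plaintext highlighter-rouge">security find-identity -v -p codesigning</code> will list them. Here’s a small sketch of pulling the Application cert name out of that output so you can paste it into <code class="language-plaintext highlighter-rouge">--sign</code>; note the sample output below (hashes, names, team ID) is made up for illustration:</p>

```shell
# On a real Mac, replace the sample variable with the actual command output:
#   security find-identity -v -p codesigning
# The identities below are fabricated for illustration.
sample_output='  1) ABC123DEF456 "3rd Party Mac Developer Application: Jane Doe (TEAMID1234)"
  2) 789GHI012JKL "3rd Party Mac Developer Installer: Jane Doe (TEAMID1234)"'

# Grab just the Application identity (the one codesign --sign wants),
# stripping everything but the quoted certificate name.
app_cert=$(printf '%s\n' "$sample_output" \
  | grep '3rd Party Mac Developer Application' \
  | sed -E 's/.*"(.*)".*/\1/')

echo "$app_cert"
# -> 3rd Party Mac Developer Application: Jane Doe (TEAMID1234)
```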

<h1 id="conclusion">Conclusion</h1>

<p>Anyway, I have sunk more time than allocated into this, but it was a fun project, and I’m looking forward to working on this in the future. If you have issues or feedback requests, feel free to blow up that Canny board.</p>

<p>Enjoy the application and I hope I’ve helped some random stranger out there.</p>

<h2 id="kudos">Kudos</h2>

<p>Oh also thank you to my girlfriend for coming up with the name. Better than what I could have thought of.</p>]]></content><author><name>johnlarkin1</name></author><category term="⭐️ Favorites" /><category term="Development" /><summary type="html"><![CDATA[At Mojo, we use Parquet files to store some of our simulation data. I - however - have been increasingly frustrated by the lack of support on MacOS to natively view them. They are (normally) compressed through a snappy algorithm, and Apple doesn’t have a native application to open them.]]></summary></entry></feed>