<?xml version="1.0" encoding="utf-8"?>
<feed xmlns="http://www.w3.org/2005/Atom"><title>Eli Bendersky's website - Machine Learning</title><link href="https://eli.thegreenplace.net/" rel="alternate"></link><link href="https://eli.thegreenplace.net/feeds/machine-learning.atom.xml" rel="self"></link><id>https://eli.thegreenplace.net/</id><updated>2026-02-05T03:38:39-08:00</updated><entry><title>Rewriting pycparser with the help of an LLM</title><link href="https://eli.thegreenplace.net/2026/rewriting-pycparser-with-the-help-of-an-llm/" rel="alternate"></link><published>2026-02-04T19:35:00-08:00</published><updated>2026-02-05T03:38:39-08:00</updated><author><name>Eli Bendersky</name></author><id>tag:eli.thegreenplace.net,2026-02-04:/2026/rewriting-pycparser-with-the-help-of-an-llm/</id><summary type="html">&lt;p&gt;&lt;a class="reference external" href="https://github.com/eliben/pycparser"&gt;pycparser&lt;/a&gt; is my most widely used open
source project (with ~20M daily downloads from PyPI &lt;a class="footnote-reference" href="#footnote-1" id="footnote-reference-1"&gt;[1]&lt;/a&gt;). It's a pure-Python
parser for the C programming language, producing ASTs inspired by &lt;a class="reference external" href="https://docs.python.org/3/library/ast.html"&gt;Python's
own&lt;/a&gt;. Until very recently, it's
been using &lt;a class="reference external" href="https://www.dabeaz.com/ply/ply.html"&gt;PLY: Python Lex-Yacc&lt;/a&gt; for
the core parsing.&lt;/p&gt;
&lt;p&gt;In this post, I'll describe how …&lt;/p&gt;</summary><content type="html">&lt;p&gt;&lt;a class="reference external" href="https://github.com/eliben/pycparser"&gt;pycparser&lt;/a&gt; is my most widely used open
source project (with ~20M daily downloads from PyPI &lt;a class="footnote-reference" href="#footnote-1" id="footnote-reference-1"&gt;[1]&lt;/a&gt;). It's a pure-Python
parser for the C programming language, producing ASTs inspired by &lt;a class="reference external" href="https://docs.python.org/3/library/ast.html"&gt;Python's
own&lt;/a&gt;. Until very recently, it's
been using &lt;a class="reference external" href="https://www.dabeaz.com/ply/ply.html"&gt;PLY: Python Lex-Yacc&lt;/a&gt; for
the core parsing.&lt;/p&gt;
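&lt;p&gt;To make the rest of the post concrete, here is a minimal sketch of how
pycparser is typically used - parsing a small, already-preprocessed C snippet
and dumping the resulting AST (an illustrative example, not an excerpt from the
project's documentation):&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;from pycparser import c_parser

src = '''
int add(int x, int y) {
    return x + y;
}
'''

parser = c_parser.CParser()
ast = parser.parse(src)   # returns a FileAST node
ast.show()                # dumps the tree: FuncDef, Decl, Return, BinaryOp, ...
&lt;/pre&gt;&lt;/div&gt;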
&lt;p&gt;In this post, I'll describe how I collaborated with an LLM coding agent (Codex)
to help me rewrite pycparser to use a hand-written recursive-descent parser and
remove the dependency on PLY. This has been an interesting experience and the
post contains lots of information and is therefore quite long; if you're just
interested in the final result, check out the latest code of pycparser - the
&lt;tt class="docutils literal"&gt;main&lt;/tt&gt; branch already has the new implementation.&lt;/p&gt;
&lt;img alt="meme picture saying &amp;quot;can't come to bed because my AI agent produced something slightly wrong&amp;quot;" class="align-center" src="https://eli.thegreenplace.net/images/2026/cantcometobed.png" /&gt;
&lt;div class="section" id="the-issues-with-the-existing-parser-implementation"&gt;
&lt;h2&gt;The issues with the existing parser implementation&lt;/h2&gt;
&lt;p&gt;While pycparser has been working well overall, there were a number of nagging
issues that persisted over years.&lt;/p&gt;
&lt;div class="section" id="parsing-strategy-yacc-vs-hand-written-recursive-descent"&gt;
&lt;h3&gt;Parsing strategy: YACC vs. hand-written recursive descent&lt;/h3&gt;
&lt;p&gt;I began working on pycparser in 2008, and back then using a YACC-based approach
for parsing a whole language like C seemed like a no-brainer to me. Isn't this
what everyone does when writing a serious parser? Besides, the K&amp;amp;R2 book
famously carries the entire grammar of ANSI C in an appendix - so it
seemed like a simple matter of translating that to PLY-yacc syntax.&lt;/p&gt;
&lt;p&gt;And indeed, it wasn't &lt;em&gt;too&lt;/em&gt; hard, though there definitely were some complications
in building the ASTs for declarations (C's &lt;a class="reference external" href="https://eli.thegreenplace.net/2008/10/18/implementing-cdecl-with-pycparser"&gt;gnarliest part&lt;/a&gt;).&lt;/p&gt;
&lt;p&gt;Shortly after completing pycparser, I got more and more interested in compilation
and started learning about the different kinds of parsers more seriously. Over
time, I grew convinced that &lt;a class="reference external" href="https://eli.thegreenplace.net/tag/recursive-descent-parsing"&gt;recursive descent&lt;/a&gt; is the way to
go - producing parsers that are easier to understand and maintain (and are often
faster!).&lt;/p&gt;
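&lt;p&gt;For readers who haven't written one: a recursive-descent parser is just a set
of mutually-recursive functions, one per grammar rule, each consuming tokens and
returning an AST node. Here is a tiny self-contained sketch for a toy grammar of
additive expressions - purely illustrative, and unrelated to pycparser's actual
code:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;# Toy grammar:  expr := term (('+' | '-') term)*   ;   term := NUMBER
import re

def tokenize(s):
    return re.findall(r'\d+|[+\-]', s)

def parse_expr(tokens, pos=0):
    node, pos = parse_term(tokens, pos)
    while pos &amp;lt; len(tokens) and tokens[pos] in '+-':
        op = tokens[pos]
        rhs, pos = parse_term(tokens, pos + 1)
        node = (op, node, rhs)
    return node, pos

def parse_term(tokens, pos):
    return int(tokens[pos]), pos + 1

print(parse_expr(tokenize('12+34-5'))[0])   # ('-', ('+', 12, 34), 5)
&lt;/pre&gt;&lt;/div&gt;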
&lt;p&gt;It all ties in to the &lt;a class="reference external" href="https://eli.thegreenplace.net/2017/benefits-of-dependencies-in-software-projects-as-a-function-of-effort/"&gt;benefits of dependencies in software projects as a
function of effort&lt;/a&gt;.
Using parser generators is a heavy &lt;em&gt;conceptual&lt;/em&gt; dependency: it's really nice
when you have to churn out many parsers for small languages. But when you have
to maintain a single, very complex parser, as part of a large project - the
benefits quickly dissipate and you're left with a substantial dependency that
you constantly grapple with.&lt;/p&gt;
&lt;/div&gt;
&lt;div class="section" id="the-other-issue-with-dependencies"&gt;
&lt;h3&gt;The other issue with dependencies&lt;/h3&gt;
&lt;p&gt;And then there are the usual problems with dependencies; dependencies get
abandoned, and they may also develop security issues. Sometimes, both of these
become true.&lt;/p&gt;
&lt;p&gt;Many years ago, pycparser forked and started vendoring its own version of PLY.
This was part of transitioning pycparser to a dual Python 2/3 code base when PLY
was slower to adapt. I believe this was the right decision, since PLY &amp;quot;just
worked&amp;quot; and I didn't have to deal with active (and very tedious in the Python
ecosystem, where packaging tools are replaced faster than dirty socks)
dependency management.&lt;/p&gt;
&lt;p&gt;A couple of weeks ago &lt;a class="reference external" href="https://github.com/eliben/pycparser/issues/588"&gt;this issue&lt;/a&gt;
was opened for pycparser. It turns out that some old PLY code triggers security
checks used by some Linux distributions; while this code was fixed in a later
commit of PLY, PLY itself was apparently abandoned and archived in late 2025.
And guess what? That happened in the middle of a large rewrite of the package,
so re-vendoring the pre-archiving commit seemed like a risky proposition.&lt;/p&gt;
&lt;p&gt;On the issue it was suggested that &amp;quot;hopefully the dependent packages move on to
a non-abandoned parser or implement their own&amp;quot;; I originally laughed this idea
off, but then it got me thinking... which is what this post is all about.&lt;/p&gt;
&lt;/div&gt;
&lt;div class="section" id="growing-complexity-of-parsing-a-messy-language"&gt;
&lt;h3&gt;Growing complexity of parsing a messy language&lt;/h3&gt;
&lt;p&gt;The original K&amp;amp;R2 grammar had - famously - a single shift-reduce
conflict having to do with dangling &lt;tt class="docutils literal"&gt;else&lt;/tt&gt;s belonging to the most recent
&lt;tt class="docutils literal"&gt;if&lt;/tt&gt; statement. And indeed, other than the famous &lt;a class="reference external" href="https://en.wikipedia.org/wiki/Lexer_hack"&gt;lexer hack&lt;/a&gt;
used to deal with &lt;a class="reference external" href="https://eli.thegreenplace.net/2011/05/02/the-context-sensitivity-of-cs-grammar-revisited"&gt;C's type name / ID ambiguity&lt;/a&gt;,
pycparser only had this single shift-reduce conflict.&lt;/p&gt;
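&lt;p&gt;To illustrate the dangling &lt;tt class="docutils literal"&gt;else&lt;/tt&gt;: in the snippet below the
&lt;tt class="docutils literal"&gt;else&lt;/tt&gt; could syntactically belong to either &lt;tt class="docutils literal"&gt;if&lt;/tt&gt;, and the
grammar resolves it in favor of the nearest one. One way to see this is to parse
it with pycparser and look at the resulting AST (an illustrative sketch):&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;from pycparser import c_parser

src = '''
void f(int a, int b) {
    if (a)
        if (b)
            a = 1;
        else
            a = 2;
}
'''
ast = c_parser.CParser().parse(src)
ast.show()   # the inner If node carries both branches; the outer If has no else
&lt;/pre&gt;&lt;/div&gt;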
&lt;p&gt;But things got more complicated. Over the years, features were added that
weren't strictly in the standard but were supported by all the industrial
compilers. The more advanced C11 and C23 standards weren't beholden to the
promises of conflict-free YACC parsing (since almost no industrial-strength
compilers use YACC at this point), so all caution went out of the window.&lt;/p&gt;
&lt;p&gt;The latest (PLY-based) release of pycparser has many reduce-reduce conflicts
&lt;a class="footnote-reference" href="#footnote-2" id="footnote-reference-2"&gt;[2]&lt;/a&gt;; these are a severe maintenance hazard because it means the parsing rules
essentially have to be tie-broken by order of appearance in the code. This is
very brittle; pycparser has only managed to maintain its stability and quality
through its comprehensive test suite. Over time, it became harder and harder to
extend, because YACC parsing rules have all kinds of spooky-action-at-a-distance
effects. The straw that broke the camel's back was &lt;a class="reference external" href="https://github.com/eliben/pycparser/pull/590"&gt;this PR&lt;/a&gt; which again proposed to
increase the number of reduce-reduce conflicts &lt;a class="footnote-reference" href="#footnote-3" id="footnote-reference-3"&gt;[3]&lt;/a&gt;.&lt;/p&gt;
&lt;p&gt;This - again - prompted me to think &amp;quot;what if I just dump YACC and switch to
a hand-written recursive descent parser&amp;quot;, and here we are.&lt;/p&gt;
&lt;/div&gt;
&lt;/div&gt;
&lt;div class="section" id="the-mental-roadblock"&gt;
&lt;h2&gt;The mental roadblock&lt;/h2&gt;
&lt;p&gt;None of the challenges described above are new; I've been pondering them for
many years now, and yet biting the bullet and rewriting the parser didn't feel
like something I'd like to get into. By my private estimates it'd take at least
a week of deep heads-down work to port the gritty 2000 lines of YACC grammar
rules to a recursive descent parser &lt;a class="footnote-reference" href="#footnote-4" id="footnote-reference-4"&gt;[4]&lt;/a&gt;. Moreover, it wouldn't be a
particularly &lt;em&gt;fun&lt;/em&gt; project either - I didn't feel like I'd learn much new and
my interests have shifted away from this project. In short, the &lt;a class="reference external" href="https://en.wikipedia.org/wiki/Potential_well"&gt;potential well&lt;/a&gt; was just too deep.&lt;/p&gt;
&lt;/div&gt;
&lt;div class="section" id="why-would-this-even-work-tests"&gt;
&lt;h2&gt;Why would this even work? Tests&lt;/h2&gt;
&lt;p&gt;I've definitely noticed the improvement in capabilities of LLM coding
agents in the past few months, and many reputable people online rave about using
them for increasingly larger projects. That said, would an LLM agent really be
able to accomplish such a complex project on its own? This isn't just a toy;
it's thousands of lines of dense parsing code.&lt;/p&gt;
&lt;p&gt;What gave me hope is the concept of &lt;a class="reference external" href="https://simonwillison.net/2025/Dec/31/the-year-in-llms/#the-year-of-conformance-suites"&gt;conformance suites mentioned by
Simon Willison&lt;/a&gt;.
Agents seem to do well when there's a very clear and rigid
goal function - such as a large, high-coverage conformance test suite.&lt;/p&gt;
&lt;p&gt;And pycparser has a &lt;a class="reference external" href="https://github.com/eliben/pycparser/blob/main/tests/test_c_parser.py"&gt;very extensive one&lt;/a&gt;.
Over 2500 lines of test code parsing various C snippets to ASTs with expected
results, grown over a decade and a half of real issues and bugs reported by
users.&lt;/p&gt;
&lt;p&gt;I figured the LLM can either succeed or fail and throw its hands up in despair,
but it's quite unlikely to produce a &lt;em&gt;wrong&lt;/em&gt; port that would still pass all
the tests. So I set it to run.&lt;/p&gt;
&lt;/div&gt;
&lt;div class="section" id="the-initial-port"&gt;
&lt;h2&gt;The initial port&lt;/h2&gt;
&lt;p&gt;I fired up Codex in pycparser's repository, and wrote this prompt just to
make sure it understands me and can run the tests:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;This repository implements a Python parser producing an AST. It has a very
extensive test suite; can you first verify you successfully run all the tests?
(with python3 -m unittest discover, from the root dir)
&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;Codex figured it out (I gave it the exact command, after all!); my next
prompt was the real thing &lt;a class="footnote-reference" href="#footnote-5" id="footnote-reference-5"&gt;[5]&lt;/a&gt;:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;pycparser uses the PLY library for its lexer and parser, and its own AST
package for the ASTs created by the parser. I want to explore replacing the
Parser component (leaving the Lexer alone for now) with a hand written
recursive descent parser that will pass all the same tests. Can you try that?
&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;Here Codex went to work and churned &lt;em&gt;for over an hour&lt;/em&gt;. Having never
observed an agent work for nearly this long, I kind of assumed it had gone off
the rails and would fail sooner or later. So I was rather surprised and skeptical
when it eventually came back with:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;Replaced the PLY parse path with a handwritten recursive‑descent parser while
keeping the existing lexer and AST helpers intact; CParser now points to the
new RD implementation, with the old PLY version preserved as PLYCParser in
pycparser/c_parser.py. The RD parser includes a token stream, declarator/type
parsing, expression precedence handling, pragmas, and statement parsing to
match the previous AST output and coordinate behavior.
&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;It took me a while to poke around the code and run it until I was convinced - it
had actually done it! It wrote a new recursive descent parser with only
ancillary dependencies on PLY, and that parser passed the test suite. After a
few more prompts, we removed the ancillary dependencies and made the
structure clearer. I hadn't looked too deeply into code quality at this point,
but at least on the functional level - it succeeded. This was very impressive!&lt;/p&gt;
&lt;/div&gt;
&lt;div class="section" id="a-quick-note-on-reviews-and-branches"&gt;
&lt;h2&gt;A quick note on reviews and branches&lt;/h2&gt;
&lt;p&gt;A change like the one described above is impossible to code-review as one PR in
any meaningful way, so I used a different strategy. Before embarking on this
path, I created a new branch; once Codex finished the initial rewrite, I
committed the change, knowing that I would review it in detail, piece by piece,
later on.&lt;/p&gt;
&lt;p&gt;Even though coding agents have their own notion of history and can &amp;quot;revert&amp;quot;
certain changes, I felt much safer relying on Git. In the worst case, if all of
this went south, I could nuke the branch and it would be as if nothing ever happened.
I was determined to only merge this branch onto &lt;tt class="docutils literal"&gt;main&lt;/tt&gt; once I was fully
satisfied with the code. In what follows, I had to &lt;tt class="docutils literal"&gt;git reset&lt;/tt&gt; several times
when I didn't like the direction in which Codex was going. In hindsight, doing
this work in a branch was absolutely the right choice.&lt;/p&gt;
&lt;/div&gt;
&lt;div class="section" id="the-long-tail-of-goofs"&gt;
&lt;h2&gt;The long tail of goofs&lt;/h2&gt;
&lt;p&gt;Once I'd sufficiently convinced myself that the new parser was actually working,
I used Codex to similarly rewrite the lexer and get rid of the PLY dependency
entirely, deleting it from the repository. Then, I started looking more deeply
into code quality - reading the code created by Codex and trying to wrap my head
around it.&lt;/p&gt;
&lt;p&gt;And - oh my - this was quite the journey. Much has been written about the code
produced by agents, and much of it seems to be true. Maybe it's a setting I'm
missing (I'm not using my own custom &lt;tt class="docutils literal"&gt;AGENTS.md&lt;/tt&gt; yet, for instance), but
Codex seems to be that eager programmer that wants to get from A to B whatever
the cost. Readability, minimalism and code clarity are very much secondary
goals.&lt;/p&gt;
&lt;p&gt;Using &lt;tt class="docutils literal"&gt;&lt;span class="pre"&gt;raise...except&lt;/span&gt;&lt;/tt&gt; for control flow? Yep. Abusing Python's dynamic typing
(like having &lt;tt class="docutils literal"&gt;None&lt;/tt&gt;, &lt;tt class="docutils literal"&gt;False&lt;/tt&gt; and other values all mean different things
for a given variable)? For sure. Spreading the logic of a complex function
all over the place instead of putting all the key parts in a single switch
statement? You bet.&lt;/p&gt;
&lt;p&gt;Moreover, the agent is hilariously &lt;em&gt;lazy&lt;/em&gt;. More than once I had to convince it
to do something it initially said was impossible, and then insisted was
impossible again in follow-up messages. The anthropomorphization here is mildly concerning, to be
honest. I could never imagine I would be writing something like the following to
a computer, and yet - here we are: &amp;quot;Remember how we moved X to Y before? You
can do it again for Z, definitely. Just try&amp;quot;.&lt;/p&gt;
&lt;p&gt;My process was to see how I can instruct Codex to fix things, and intervene
myself (by rewriting code) as little as possible. I've &lt;em&gt;mostly&lt;/em&gt; succeeded in
this, and did maybe 20% of the work myself.&lt;/p&gt;
&lt;p&gt;My branch grew &lt;em&gt;dozens&lt;/em&gt; of commits, falling into roughly these categories:&lt;/p&gt;
&lt;ol class="arabic simple"&gt;
&lt;li&gt;The code in X is too complex; why can't we do Y instead?&lt;/li&gt;
&lt;li&gt;The use of X is needlessly convoluted; change Y to Z, and T to V in all
instances.&lt;/li&gt;
&lt;li&gt;The code in X is unclear; please add a detailed comment - with examples - to
explain what it does.&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;Interestingly, after doing (3), the agent was often more effective in giving
the code a &amp;quot;fresh look&amp;quot; and succeeding in either (1) or (2).&lt;/p&gt;
&lt;/div&gt;
&lt;div class="section" id="the-end-result"&gt;
&lt;h2&gt;The end result&lt;/h2&gt;
&lt;p&gt;Eventually, after many hours spent in this process, I was reasonably pleased
with the code. It's far from perfect, of course, but taking the essential
complexities into account, it's something I could see myself maintaining (with
or without the help of an agent). I'm sure I'll find more ways to improve it
in the future, but I have a reasonable degree of confidence that this will be
doable.&lt;/p&gt;
&lt;p&gt;It passes all the tests, so I've been able to release a new version (3.00)
without major issues so far. The only issue I've discovered is that some of
CFFI's tests are overly precise about the phrasing of errors reported by
pycparser; this was &lt;a class="reference external" href="https://github.com/python-cffi/cffi/pull/224"&gt;an easy fix&lt;/a&gt;.&lt;/p&gt;
&lt;p&gt;The new parser is also faster, by about 30% based on my benchmarks! This is
typical of recursive descent when compared with YACC-generated parsers, in my
experience. After reviewing the initial rewrite of the lexer, I've spent a while
instructing Codex on how to make it faster, and it worked reasonably well.&lt;/p&gt;
&lt;/div&gt;
&lt;div class="section" id="followup-static-typing"&gt;
&lt;h2&gt;Followup - static typing&lt;/h2&gt;
&lt;p&gt;While working on this, it became quite obvious that static typing would make the
process easier. LLM coding agents really benefit from closed loops with strict
guardrails (e.g. a test suite to pass), and type-annotations act as such.
For example, had pycparser already been type annotated, Codex would probably not
have overloaded values to multiple types (like &lt;tt class="docutils literal"&gt;None&lt;/tt&gt; vs. &lt;tt class="docutils literal"&gt;False&lt;/tt&gt; vs.
others).&lt;/p&gt;
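&lt;p&gt;As a hypothetical illustration (this is not actual pycparser code): with an
annotation like the one below, a checker such as &lt;tt class="docutils literal"&gt;ty&lt;/tt&gt; or mypy
would reject a &lt;tt class="docutils literal"&gt;return False&lt;/tt&gt; where &lt;tt class="docutils literal"&gt;None&lt;/tt&gt; is expected,
which is exactly the kind of value-overloading described above; the agent gets
corrective feedback without a human in the loop.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;from typing import List, Optional

# Hypothetical helper: return the qualifiers found in a token list, or None if
# there are none. The annotation rules out returning False as a third meaning.
def declared_quals(tokens: List[str]) -&amp;gt; Optional[List[str]]:
    quals = [t for t in tokens if t in ('const', 'volatile', 'restrict')]
    return quals or None
&lt;/pre&gt;&lt;/div&gt;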
&lt;p&gt;In a followup, I asked Codex to type-annotate pycparser (running checks using
&lt;tt class="docutils literal"&gt;ty&lt;/tt&gt;), and this was also a back-and-forth because the process exposed some
issues that required refactoring. Time will tell, but hopefully it will make
further changes in the project simpler for the agent.&lt;/p&gt;
&lt;p&gt;Based on this experience, I'd bet that coding agents will be somewhat more
effective in strongly typed languages like Go, TypeScript and especially Rust.&lt;/p&gt;
&lt;/div&gt;
&lt;div class="section" id="conclusions"&gt;
&lt;h2&gt;Conclusions&lt;/h2&gt;
&lt;p&gt;Overall, this project has been a really good experience, and I'm impressed with
what modern LLM coding agents can do! While there's no reason to expect that
progress in this domain will stop, even if it does - these are already very
useful tools that can significantly improve programmer productivity.&lt;/p&gt;
&lt;p&gt;Could I have done this myself, without an agent's help? Sure. But it would have
taken me &lt;em&gt;much&lt;/em&gt; longer, assuming that I could even muster the will and
concentration to engage in this project. I estimate it would take me at least
a week of full-time work (so 30-40 hours) spread over who knows how long to
accomplish. With Codex, I put in an order of magnitude less work into this
(around 4-5 hours, I'd estimate) and I'm happy with the result.&lt;/p&gt;
&lt;p&gt;It was also &lt;em&gt;fun&lt;/em&gt;. At least in one sense, my professional life can be described
as the pursuit of focus, deep work and &lt;em&gt;flow&lt;/em&gt;. It's not easy for me to get into
this state, but when I do I'm highly productive and find it very enjoyable.
Agents really help me here. When I know I need to write some code and it's
hard to get started, asking an agent to write a prototype is a great catalyst
for my motivation. Hence the meme at the beginning of the post.&lt;/p&gt;
&lt;div class="section" id="does-code-quality-even-matter"&gt;
&lt;h3&gt;Does code quality even matter?&lt;/h3&gt;
&lt;p&gt;One can't avoid a nagging question - does the quality of the code produced
by agents even matter? Clearly, the agents themselves can understand it (if not
today's agent, then at least next year's). Why worry about future
maintainability if the agent can maintain it? In other words, does it make sense
to just go full vibe-coding?&lt;/p&gt;
&lt;p&gt;This is a fair question, and one I don't have an answer to. Right now, for
projects I maintain and &lt;em&gt;stand behind&lt;/em&gt;, it seems obvious to me that the code
should be fully understandable and accepted by me, and the agent is just a tool
helping me get to that state more efficiently. It's hard to say what the future
holds here; it's going to be interesting, for sure.&lt;/p&gt;
&lt;hr class="docutils" /&gt;
&lt;table class="docutils footnote" frame="void" id="footnote-1" rules="none"&gt;
&lt;colgroup&gt;&lt;col class="label" /&gt;&lt;col /&gt;&lt;/colgroup&gt;
&lt;tbody valign="top"&gt;
&lt;tr&gt;&lt;td class="label"&gt;&lt;a class="fn-backref" href="#footnote-reference-1"&gt;[1]&lt;/a&gt;&lt;/td&gt;&lt;td&gt;pycparser has a fair number of &lt;a class="reference external" href="https://deps.dev/pypi/pycparser/3.0.0/dependents"&gt;direct dependents&lt;/a&gt;,
but the majority of downloads comes through &lt;a class="reference external" href="https://github.com/python-cffi/cffi"&gt;CFFI&lt;/a&gt;,
which itself is a major building block for much of the Python ecosystem.&lt;/td&gt;&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;table class="docutils footnote" frame="void" id="footnote-2" rules="none"&gt;
&lt;colgroup&gt;&lt;col class="label" /&gt;&lt;col /&gt;&lt;/colgroup&gt;
&lt;tbody valign="top"&gt;
&lt;tr&gt;&lt;td class="label"&gt;&lt;a class="fn-backref" href="#footnote-reference-2"&gt;[2]&lt;/a&gt;&lt;/td&gt;&lt;td&gt;The table-building report says 177, but that's certainly an
over-dramatization because it's common for a single conflict to
manifest in several ways.&lt;/td&gt;&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;table class="docutils footnote" frame="void" id="footnote-3" rules="none"&gt;
&lt;colgroup&gt;&lt;col class="label" /&gt;&lt;col /&gt;&lt;/colgroup&gt;
&lt;tbody valign="top"&gt;
&lt;tr&gt;&lt;td class="label"&gt;&lt;a class="fn-backref" href="#footnote-reference-3"&gt;[3]&lt;/a&gt;&lt;/td&gt;&lt;td&gt;It didn't help the PR's case that it was almost certainly vibe coded.&lt;/td&gt;&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;table class="docutils footnote" frame="void" id="footnote-4" rules="none"&gt;
&lt;colgroup&gt;&lt;col class="label" /&gt;&lt;col /&gt;&lt;/colgroup&gt;
&lt;tbody valign="top"&gt;
&lt;tr&gt;&lt;td class="label"&gt;&lt;a class="fn-backref" href="#footnote-reference-4"&gt;[4]&lt;/a&gt;&lt;/td&gt;&lt;td&gt;&lt;p class="first"&gt;There was also the lexer to consider, but this seemed like a much
simpler job. My impression is that in the early days of computing,
&lt;tt class="docutils literal"&gt;lex&lt;/tt&gt; gained prominence because of strong regexp support which wasn't
very common yet. These days, with excellent regexp libraries
existing for pretty much every language, the added value of &lt;tt class="docutils literal"&gt;lex&lt;/tt&gt; over
a &lt;a class="reference external" href="https://eli.thegreenplace.net/2013/06/25/regex-based-lexical-analysis-in-python-and-javascript"&gt;custom regexp-based lexer&lt;/a&gt;
isn't very high.&lt;/p&gt;
&lt;p class="last"&gt;That said, it wouldn't make much sense to embark on a journey to rewrite
&lt;em&gt;just&lt;/em&gt; the lexer; the dependency on PLY would still remain, and besides,
PLY's lexer and parser are designed to work well together. So it wouldn't
help me much without tackling the parser beast.&lt;/p&gt;
&lt;/td&gt;&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;table class="docutils footnote" frame="void" id="footnote-5" rules="none"&gt;
&lt;colgroup&gt;&lt;col class="label" /&gt;&lt;col /&gt;&lt;/colgroup&gt;
&lt;tbody valign="top"&gt;
&lt;tr&gt;&lt;td class="label"&gt;&lt;a class="fn-backref" href="#footnote-reference-5"&gt;[5]&lt;/a&gt;&lt;/td&gt;&lt;td&gt;I decided to ask it to port the parser first, leaving the lexer
alone. This was to split the work into reasonable chunks. Besides, I
figured that the parser was the harder job anyway - if it succeeded there,
the lexer would be easy. That assumption turned out to be correct.&lt;/td&gt;&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;
</content><category term="misc"></category><category term="Python"></category><category term="Machine Learning"></category><category term="Compilation"></category><category term="Recursive descent parsing"></category></entry><entry><title>Sparsely-gated Mixture Of Experts (MoE)</title><link href="https://eli.thegreenplace.net/2025/sparsely-gated-mixture-of-experts-moe/" rel="alternate"></link><published>2025-04-18T09:33:00-07:00</published><updated>2025-04-18T16:33:37-07:00</updated><author><name>Eli Bendersky</name></author><id>tag:eli.thegreenplace.net,2025-04-18:/2025/sparsely-gated-mixture-of-experts-moe/</id><summary type="html">&lt;p&gt;In &lt;a class="reference external" href="https://arxiv.org/pdf/1706.03762"&gt;transformer models&lt;/a&gt;, the
&lt;a class="reference external" href="https://eli.thegreenplace.net/2025/notes-on-implementing-attention/"&gt;attention block&lt;/a&gt;
is typically followed by a &lt;em&gt;feed forward&lt;/em&gt; layer (FF), which is a simple fully-connected
NN with a hidden layer and nonlinearity. Here's the code for such a block that
uses ReLU:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;feed_forward_relu&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;W1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;W2&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sd"&gt;&amp;quot;&amp;quot;&amp;quot;Feed-forward layer with ReLU activation.&lt;/span&gt;

&lt;span class="sd"&gt;    Args:&lt;/span&gt;
&lt;span class="sd"&gt;        x: Input …&lt;/span&gt;&lt;/pre&gt;&lt;/div&gt;</summary><content type="html">&lt;p&gt;In &lt;a class="reference external" href="https://arxiv.org/pdf/1706.03762"&gt;transformer models&lt;/a&gt;, the
&lt;a class="reference external" href="https://eli.thegreenplace.net/2025/notes-on-implementing-attention/"&gt;attention block&lt;/a&gt;
is typically followed by a &lt;em&gt;feed forward&lt;/em&gt; layer (FF), which is a simple fully-connected
NN with a hidden layer and nonlinearity. Here's the code for such a block that
uses ReLU:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;feed_forward_relu&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;W1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;W2&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sd"&gt;&amp;quot;&amp;quot;&amp;quot;Feed-forward layer with ReLU activation.&lt;/span&gt;

&lt;span class="sd"&gt;    Args:&lt;/span&gt;
&lt;span class="sd"&gt;        x: Input tensor (B, N, D).&lt;/span&gt;
&lt;span class="sd"&gt;        Wh: Weights for the hidden layer (D, DH).&lt;/span&gt;
&lt;span class="sd"&gt;        Wo: Weights for the output layer (DH, D).&lt;/span&gt;

&lt;span class="sd"&gt;    Returns:&lt;/span&gt;
&lt;span class="sd"&gt;        Output tensor (B, N, D).&lt;/span&gt;
&lt;span class="sd"&gt;    &amp;quot;&amp;quot;&amp;quot;&lt;/span&gt;
    &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W1&lt;/span&gt;  &lt;span class="c1"&gt;# hidden layer (B, N, DH)&lt;/span&gt;
    &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;maximum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# ReLU activation (B, N, DH)&lt;/span&gt;
    &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W2&lt;/span&gt;  &lt;span class="c1"&gt;# output layer (B, N, D)&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;This layer typically holds most of the weights in the transformer, because the
hidden dimension (DH in this post, &lt;tt class="docutils literal"&gt;hidden_dim&lt;/tt&gt; in some papers) is large - 4x
the embedding depth D is common. Intuitively, this makes sense because this
layer does the majority of the heavy lifting; while the attention block mixes
the embeddings of tokens together to express their relationships to one another,
the FF block does the actual &amp;quot;reasoning&amp;quot; on these tokens.&lt;/p&gt;
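&lt;p&gt;For a rough sense of scale (purely illustrative numbers, not taken from any
particular model): with D=4096 and DH four times that, a single FF block already
holds over a hundred million weights.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;# Illustrative parameter count for one FF block (the numbers are assumptions).
D = 4096
DH = 4 * D
ff_params = D * DH + DH * D   # W1 of shape (D, DH) plus W2 of shape (DH, D)
print(f'{ff_params:,}')       # 134,217,728 weights in a single block
&lt;/pre&gt;&lt;/div&gt;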
&lt;p&gt;Transformer blocks are repeated dozens of times in a model, so the total size
of these layers becomes problematic.
One approach for improving efficiency that became very popular is called
&lt;em&gt;sparsely-gated mixture of experts&lt;/em&gt; (&lt;a class="reference external" href="https://arxiv.org/pdf/1701.06538"&gt;paper&lt;/a&gt;).&lt;/p&gt;
&lt;div class="section" id="mixture-of-experts-architecture-moe"&gt;
&lt;h2&gt;Mixture of Experts architecture (MoE)&lt;/h2&gt;
&lt;p&gt;The basic idea of MoE is this:&lt;/p&gt;
&lt;ul class="simple"&gt;
&lt;li&gt;The large FF layer is split into a number (NEXP) of blocks called &amp;quot;experts&amp;quot;. Each
expert is still a FF layer. It takes a vector of size D and transforms it
to another vector of size D.&lt;/li&gt;
&lt;li&gt;There's an additional piece called &amp;quot;router&amp;quot; or &amp;quot;gate&amp;quot;. This is just a
fully-connected layer (D, NEXP) that takes a token and produces a score for
each expert. The router is learned by the model, along with the experts
themselves.&lt;/li&gt;
&lt;li&gt;K experts with the highest scores are selected for each
token, and the token is only fed through these experts.&lt;/li&gt;
&lt;li&gt;The scores are also used to calculate a weighted average from the experts'
outputs, eventually
producing an answer of size D.&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;Here's a diagram for a single token, assuming &lt;tt class="docutils literal"&gt;NEXP=8&lt;/tt&gt; and &lt;tt class="docutils literal"&gt;TOPK=2&lt;/tt&gt; (the two
highest scoring experts are selected for each token, out of a total of eight):&lt;/p&gt;
&lt;img alt="mixture of experts diagram" class="align-center" src="https://eli.thegreenplace.net/images/2025/moe.png" /&gt;
&lt;p&gt;Notes:&lt;/p&gt;
&lt;ul class="simple"&gt;
&lt;li&gt;Experts #1 and #5 are selected because the router produced the highest
scores for them among all the experts. The input token is routed to these
experts, but not to the others.&lt;/li&gt;
&lt;li&gt;The output of each expert is scaled by a corresponding
&lt;em&gt;weight&lt;/em&gt;, calculated by applying a softmax to the scores of the selected
experts (so that each token's weights sum to 1).&lt;/li&gt;
&lt;li&gt;The weighted expert outputs are added up for the final output of the layer.&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;The key point to understand about this architecture is: the experts that were
not among the top K for a token &lt;em&gt;aren't used at all&lt;/em&gt; - the computation required
to propagate the token through these experts is eschewed (both on the forward and
backward passes).&lt;/p&gt;
&lt;p&gt;This is the goal of the MoE architecture - we increase the overall model size,
but keep the computational cost in check by only using a portion of the
parameters for every single token.
This is also reflected in the models' names; for example, the &lt;a class="reference external" href="https://arxiv.org/pdf/2401.04088"&gt;Mixtral model&lt;/a&gt; has size 8x7B; it has 8 experts, and it
would be incorrect to just multiply the size of each expert by 8 because not
all these parameters participate in the calculation of every token &lt;a class="footnote-reference" href="#footnote-1" id="footnote-reference-1"&gt;[1]&lt;/a&gt;.
According to the Mixtral paper, the model only uses 13B active parameters
for each token.&lt;/p&gt;
&lt;p&gt;A summary of the idea behind MoE is:&lt;/p&gt;
&lt;blockquote&gt;
MoE increases the model's capacity without proportionally increasing its
computational cost.&lt;/blockquote&gt;
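&lt;p&gt;To put rough numbers on that statement (again, purely illustrative
dimensions): with 8 experts and top-2 routing, the layer holds 8x the parameters
of a single FF block, but each token only pays the compute of 2 experts.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;# Illustrative capacity-vs-compute comparison (all numbers are assumptions).
D, DH, NEXP, TOPK = 512, 2048, 8, 2
dense_params  = 2 * D * DH          # a single FF layer: W1 plus W2
moe_params    = NEXP * 2 * D * DH   # total parameters across all experts
active_params = TOPK * 2 * D * DH   # parameters actually used per token
print(dense_params, moe_params, active_params)   # 2097152 16777216 4194304
&lt;/pre&gt;&lt;/div&gt;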
&lt;/div&gt;
&lt;div class="section" id="numpy-implementation"&gt;
&lt;h2&gt;Numpy implementation&lt;/h2&gt;
&lt;p&gt;Here's a well-commented implementation of the MoE layer using pure Numpy. First,
some parameters:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;&lt;span class="c1"&gt;# Parameters for a feed-forward layer with a fixed activation function.&lt;/span&gt;
&lt;span class="nd"&gt;@dataclass&lt;/span&gt;
&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;FFParams&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;Wh&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ndarray&lt;/span&gt;
    &lt;span class="n"&gt;Wo&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ndarray&lt;/span&gt;


&lt;span class="c1"&gt;# Parameters for a Mixture of Experts (MoE) layer.&lt;/span&gt;
&lt;span class="nd"&gt;@dataclass&lt;/span&gt;
&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;MoEParams&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="c1"&gt;# Embedding dimension of each token (a.k.a. model dimension, Dmodel)&lt;/span&gt;
    &lt;span class="n"&gt;D&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;

    &lt;span class="c1"&gt;# Hidden dimension in FF layers&lt;/span&gt;
    &lt;span class="n"&gt;DH&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;

    &lt;span class="c1"&gt;# Total number of experts&lt;/span&gt;
    &lt;span class="n"&gt;NEXP&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;

    &lt;span class="c1"&gt;# K in the top-k selection of top experts per token&lt;/span&gt;
    &lt;span class="n"&gt;TOPK&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;

    &lt;span class="c1"&gt;# List of experts: each expert is a forward layer with FFParams.&lt;/span&gt;
    &lt;span class="n"&gt;ff_weights&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;List&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;FFParams&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;

    &lt;span class="c1"&gt;# Router weights: a linear layer (D, NEXP) that maps input to expert scores.&lt;/span&gt;
    &lt;span class="n"&gt;router_weights&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ndarray&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;And now the implementation. Note that it takes a general (B, N, D) input, assuming
batch dimension B and sequence length N:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;moe&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ndarray&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;MoEParams&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sd"&gt;&amp;quot;&amp;quot;&amp;quot;Mixture of Experts (MoE) layer.&lt;/span&gt;

&lt;span class="sd"&gt;    Args:&lt;/span&gt;
&lt;span class="sd"&gt;        x: Input tensor (B, N, D).&lt;/span&gt;
&lt;span class="sd"&gt;        params: MoEParams.&lt;/span&gt;

&lt;span class="sd"&gt;    Returns:&lt;/span&gt;
&lt;span class="sd"&gt;        Output tensor (B, N, D).&lt;/span&gt;
&lt;span class="sd"&gt;    &amp;quot;&amp;quot;&amp;quot;&lt;/span&gt;
    &lt;span class="c1"&gt;# Run input through router to get expert scores for each token.&lt;/span&gt;
    &lt;span class="n"&gt;expert_scores&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;router_weights&lt;/span&gt;  &lt;span class="c1"&gt;# (B, N, NEXP)&lt;/span&gt;

    &lt;span class="c1"&gt;# Select the top-k expert scores and their indices for each token.&lt;/span&gt;
    &lt;span class="n"&gt;top_scores&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;top_experts&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;topk_lastdim&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;expert_scores&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;TOPK&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# (B, N, TOPK)&lt;/span&gt;

    &lt;span class="c1"&gt;# Apply softmax to the top scores to get weights that sum to 1.&lt;/span&gt;
    &lt;span class="n"&gt;weights&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;softmax_lastdim&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;top_scores&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# (B, N, TOPK)&lt;/span&gt;

    &lt;span class="n"&gt;out&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;zeros_like&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;b&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]):&lt;/span&gt;
        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;n&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]):&lt;/span&gt;
            &lt;span class="c1"&gt;# Unvectorized implementation: for each token in the batch and&lt;/span&gt;
            &lt;span class="c1"&gt;# sequence, select the top-k experts and apply them with the&lt;/span&gt;
            &lt;span class="c1"&gt;# calculated weights.&lt;/span&gt;
            &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;expert_idx&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;weight&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;zip&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;top_experts&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;weights&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;]):&lt;/span&gt;
                &lt;span class="n"&gt;expert&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ff_weights&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;expert_idx&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
                &lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;weight&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;feed_forward_relu&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;expert&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Wh&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;expert&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Wo&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;out&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;Calculating the experts themselves is not vectorized here - it is done token
by token. MoE is inherently sparse: different tokens in the same sequence (and
batch) may go through different sets of experts. Vectorizing this
efficiently is tricky in general
and depends on the HW we run the model on &lt;a class="footnote-reference" href="#footnote-2" id="footnote-reference-2"&gt;[2]&lt;/a&gt;.
For a popular approach on GPUs, see the
&lt;a class="reference external" href="https://arxiv.org/pdf/2211.15841"&gt;MegaBlocks paper&lt;/a&gt; from 2022. This remains
an active area of research.&lt;/p&gt;
&lt;p&gt;All that's left are some helper functions:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;topk_lastdim&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sd"&gt;&amp;quot;&amp;quot;&amp;quot;Get the top k elements and their indices.&lt;/span&gt;

&lt;span class="sd"&gt;    x is an arbitrary array with at least two dimensions. The returned&lt;/span&gt;
&lt;span class="sd"&gt;    array has the same shape as x, but its elements are the top k elements&lt;/span&gt;
&lt;span class="sd"&gt;    across the last dimension. The indices of the top k elements are also&lt;/span&gt;
&lt;span class="sd"&gt;    returned.&lt;/span&gt;
&lt;span class="sd"&gt;    &amp;quot;&amp;quot;&amp;quot;&lt;/span&gt;
    &lt;span class="n"&gt;idx&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;argpartition&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;:]&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;take_along_axis&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;idx&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;idx&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;softmax_lastdim&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sd"&gt;&amp;quot;&amp;quot;&amp;quot;Compute softmax across last dimension of x.&lt;/span&gt;

&lt;span class="sd"&gt;    x is an arbitrary array with at least two dimensions. The returned array has&lt;/span&gt;
&lt;span class="sd"&gt;    the same shape as x, but its elements sum up to 1 across the last dimension.&lt;/span&gt;
&lt;span class="sd"&gt;    &amp;quot;&amp;quot;&amp;quot;&lt;/span&gt;
    &lt;span class="c1"&gt;# Subtract the max for numerical stability&lt;/span&gt;
    &lt;span class="n"&gt;ex&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdims&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
    &lt;span class="c1"&gt;# Divide by sums across last dimension&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;ex&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;ex&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdims&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;
&lt;/div&gt;
&lt;div class="section" id="additional-considerations"&gt;
&lt;h2&gt;Additional considerations&lt;/h2&gt;
&lt;p&gt;A major area of focus with MoE architectures is &lt;em&gt;load balancing&lt;/em&gt; among experts.
Without special provisions, the model may learn to prefer certain experts over
others and this leads to inefficient utilization of the model's weights. There
are various approaches to tackle this, for example:&lt;/p&gt;
&lt;ul class="simple"&gt;
&lt;li&gt;Adding noise to the top-k selection process to inject randomness&lt;/li&gt;
&lt;li&gt;Defining a special loss function during training that encourages experts
to receive a roughly equal number of training samples (sketched below)&lt;/li&gt;
&lt;/ul&gt;
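&lt;p&gt;Here is a sketch of the second approach - one common formulation, in the
spirit of later MoE work such as the Switch Transformer, and not part of this
post's original code: penalize the product of how often each expert is chosen
and how much router probability it receives. The term is minimized when both are
uniform across experts, and during training it would be scaled by a small
coefficient and added to the model's main loss.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;def load_balancing_loss(expert_scores, top_experts, NEXP):
    # expert_scores: (B, N, NEXP) raw router scores.
    # top_experts:   (B, N, TOPK) indices of the selected experts.
    # Reuses softmax_lastdim from above; assumes numpy is imported as np.
    probs = softmax_lastdim(expert_scores)         # (B, N, NEXP)
    P = probs.reshape(-1, NEXP).mean(axis=0)       # mean router probability per expert
    counts = np.bincount(top_experts.ravel(), minlength=NEXP)
    f = counts / top_experts.size                  # fraction of assignments per expert
    return NEXP * np.sum(f * P)                    # equals 1.0 when perfectly balanced
&lt;/pre&gt;&lt;/div&gt;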
&lt;/div&gt;
&lt;div class="section" id="code"&gt;
&lt;h2&gt;Code&lt;/h2&gt;
&lt;p&gt;The full code for this post is &lt;a class="reference external" href="https://github.com/eliben/deep-learning-samples/blob/main/transformer-attention/moe.py"&gt;available on GitHub&lt;/a&gt;.&lt;/p&gt;
&lt;hr class="docutils" /&gt;
&lt;table class="docutils footnote" frame="void" id="footnote-1" rules="none"&gt;
&lt;colgroup&gt;&lt;col class="label" /&gt;&lt;col /&gt;&lt;/colgroup&gt;
&lt;tbody valign="top"&gt;
&lt;tr&gt;&lt;td class="label"&gt;&lt;a class="fn-backref" href="#footnote-reference-1"&gt;[1]&lt;/a&gt;&lt;/td&gt;&lt;td&gt;Another way to think about MoE is that each &amp;quot;expert&amp;quot; specializes in
a certain area of the model's capability. For example, one expert would
be good at math, another at prose, etc. This is a very rough
approximation, though, because transformer models consist of dozens of
repeating blocks, and all these different experts end up thoroughly
intermixed as tokens flow through the entire model.&lt;/td&gt;&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;table class="docutils footnote" frame="void" id="footnote-2" rules="none"&gt;
&lt;colgroup&gt;&lt;col class="label" /&gt;&lt;col /&gt;&lt;/colgroup&gt;
&lt;tbody valign="top"&gt;
&lt;tr&gt;&lt;td class="label"&gt;&lt;a class="fn-backref" href="#footnote-reference-2"&gt;[2]&lt;/a&gt;&lt;/td&gt;&lt;td&gt;&lt;p class="first"&gt;In the sparsely-gated mixture of experts paper, this is referred to as &lt;em&gt;The Shrinking Batch Problem&lt;/em&gt;:&lt;/p&gt;
&lt;p class="last"&gt;&lt;em&gt;&amp;quot;In modern CPUs and GPUs, large batch sizes are necessary for computational efficiency, so as
to amortize the overhead of parameter loads and updates. If the gating network chooses k out of
n experts for each example, then for a batch of b examples, each expert receives a much smaller
batch of approximately kb/n &amp;lt;&amp;lt; b examples. This causes a naive MoE implementation to become
very inefficient as the number of experts increases&amp;quot;&lt;/em&gt;&lt;/p&gt;
&lt;/td&gt;&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;/div&gt;
</content><category term="misc"></category><category term="Math"></category><category term="Machine Learning"></category><category term="Python"></category></entry><entry><title>Cross-entropy and KL divergence</title><link href="https://eli.thegreenplace.net/2025/cross-entropy-and-kl-divergence/" rel="alternate"></link><published>2025-04-12T06:54:00-07:00</published><updated>2025-04-13T18:02:41-07:00</updated><author><name>Eli Bendersky</name></author><id>tag:eli.thegreenplace.net,2025-04-12:/2025/cross-entropy-and-kl-divergence/</id><summary type="html">&lt;p&gt;Cross-entropy is widely used in modern ML to compute the loss for classification
tasks. This post is a brief overview of the math behind it and a related
concept called Kullback-Leibler (KL) divergence.&lt;/p&gt;
&lt;div class="section" id="information-content-of-a-single-random-event"&gt;
&lt;h2&gt;Information content of a single random event&lt;/h2&gt;
&lt;p&gt;We'll start with a single event (&lt;em&gt;E&lt;/em&gt;) that has probability …&lt;/p&gt;&lt;/div&gt;</summary><content type="html">&lt;p&gt;Cross-entropy is widely used in modern ML to compute the loss for classification
tasks. This post is a brief overview of the math behind it and a related
concept called Kullback-Leibler (KL) divergence.&lt;/p&gt;
&lt;div class="section" id="information-content-of-a-single-random-event"&gt;
&lt;h2&gt;Information content of a single random event&lt;/h2&gt;
&lt;p&gt;We'll start with a single event (&lt;em&gt;E&lt;/em&gt;) that has probability &lt;em&gt;p&lt;/em&gt;. The information
content (or &amp;quot;degree of surprise&amp;quot;) of this event occurring is defined as:&lt;/p&gt;
&lt;object class="align-center" data="https://eli.thegreenplace.net/images/math/411b392d3cb3d6df381212ad075bb742324667f9.svg" style="height: 43px;" type="image/svg+xml"&gt;\[I(E) = \log_2 \left (\frac{1}{p} \right )\]&lt;/object&gt;
&lt;p&gt;The base 2 here is used so that we can count the information in units of &lt;em&gt;bits&lt;/em&gt;.
Thinking about this definition intuitively, imagine an event with probability
&lt;em&gt;p=1&lt;/em&gt;; using the formula, the information we gain by observing this event
occurring is 0, which makes sense. On the other extreme, as the probability
&lt;em&gt;p&lt;/em&gt; approaches 0, the information we gain is huge. An equivalent way to write
the formula is:&lt;/p&gt;
&lt;object class="align-center" data="https://eli.thegreenplace.net/images/math/3fff25a47de21db8fc5f51d5c8162c4a4ac70884.svg" style="height: 19px;" type="image/svg+xml"&gt;\[I(E) = -\log_2 p\]&lt;/object&gt;
&lt;p&gt;Some numeric examples: suppose we flip a fair coin and it comes out heads. The
probability of this event happening is 1/2, therefore:&lt;/p&gt;
&lt;object class="align-center" data="https://eli.thegreenplace.net/images/math/7617949f4d165a96deecb06720eaaa7a1535e63a.svg" style="height: 36px;" type="image/svg+xml"&gt;\[I(E_{heads})=-\log_2 \frac{1}{2} = 1\]&lt;/object&gt;
&lt;p&gt;Now suppose we roll a fair die and it lands on 4. The probability of this event
happening is 1/6, therefore:&lt;/p&gt;
&lt;object class="align-center" data="https://eli.thegreenplace.net/images/math/f01c001877cad81ba2f23415978b493f18a99941.svg" style="height: 36px;" type="image/svg+xml"&gt;\[I(E_4)=-\log_2 \frac{1}{6} = 2.58\]&lt;/object&gt;
&lt;p&gt;In other words, the degree of surprise for rolling a 4 is higher than the degree
of surprise for flipping to heads - which makes sense, given the probabilities
involved.&lt;/p&gt;
&lt;p&gt;Other than behaving correctly for boundary values, the logarithm function makes
sense for calculating the degree of surprise for another important reason: the
way it behaves for a combination of events.&lt;/p&gt;
&lt;p&gt;Consider this: we flip a fair coin and roll a fair die; the coin comes out
heads, and the die lands on 4. What is the probability of this event happening?
Because the two events are independent, the probability is the product of the
probabilities of the individual events, so 1/12, and then:&lt;/p&gt;
&lt;object class="align-center" data="https://eli.thegreenplace.net/images/math/dce1d3b5f70fdd0d2d208e83aa21218761a31ef8.svg" style="height: 36px;" type="image/svg+xml"&gt;\[I(E_{heads}\cap E_{4})=-\log_2 \frac{1}{12} = 3.58\]&lt;/object&gt;
&lt;p&gt;Note that the information content of the combined event is the precise &lt;em&gt;sum&lt;/em&gt; of the information contents of the individual events.
This is to be expected - we need so many bits for one of the events, and so many
for the other; the total of the bits adds up. The logarithm function gives us
exactly this behavior for probabilities:&lt;/p&gt;
&lt;object class="align-center" data="https://eli.thegreenplace.net/images/math/52ad04456bdd7fb6c8e23f3d071219b42ef7e72e.svg" style="height: 19px;" type="image/svg+xml"&gt;\[\log(p_1 \cap p_2) = \log(p_1 \cdot p_2) = \log(p_1) + \log(p_2)\]&lt;/object&gt;
&lt;/div&gt;
&lt;div class="section" id="entropy"&gt;
&lt;h2&gt;Entropy&lt;/h2&gt;
&lt;p&gt;Given a random variable &lt;em&gt;X&lt;/em&gt; with values &lt;object class="valign-m3" data="https://eli.thegreenplace.net/images/math/d601c93a21050cc76c2120e759f794765487e037.svg" style="height: 11px;" type="image/svg+xml"&gt;x_1\dots x_n&lt;/object&gt; and associated
probabilities &lt;object class="valign-m4" data="https://eli.thegreenplace.net/images/math/4f2e682f2f2e50ecdc8a8709afd7ea86cf5c9baa.svg" style="height: 12px;" type="image/svg+xml"&gt;p_1\dots p_n&lt;/object&gt;, the &lt;em&gt;entropy of X&lt;/em&gt; is defined as the
expected value of information for &lt;em&gt;X&lt;/em&gt;:&lt;/p&gt;
&lt;object class="align-center" data="https://eli.thegreenplace.net/images/math/aa60fc303ec00fbe8dd3e4dd6ec0574daaec661b.svg" style="height: 52px;" type="image/svg+xml"&gt;\[H(X)=-\sum_{j=1}^{n}p_j \log_2 p_j\]&lt;/object&gt;
&lt;p&gt;High entropy means high uncertainty, while low entropy means low uncertainty.
Let's look at a couple of examples:&lt;/p&gt;
&lt;img alt="distribution with single value at probability 1, others at 0" class="align-center" src="https://eli.thegreenplace.net/images/2025/distrib-1-0s.png" /&gt;
&lt;p&gt;This is a random variable with 5 distinct values; the probability of &lt;object class="valign-m3" data="https://eli.thegreenplace.net/images/math/593f4cff5d4210d46e140db57bafc4f692493f76.svg" style="height: 11px;" type="image/svg+xml"&gt;x_1&lt;/object&gt;
is 1, and the rest is 0. The entropy here is 0, because &lt;object class="valign-m4" data="https://eli.thegreenplace.net/images/math/22dc559721b05f362dd835ea5e94678bb7c89f45.svg" style="height: 16px;" type="image/svg+xml"&gt;1\cdot \log 1 = 0&lt;/object&gt;
and also &lt;object class="valign-m4" data="https://eli.thegreenplace.net/images/math/52fbfb13262ae7b2bd6f03a574408208896b0c9a.svg" style="height: 16px;" type="image/svg+xml"&gt;0\cdot \log 0 = 0&lt;/object&gt;  &lt;a class="footnote-reference" href="#footnote-1" id="footnote-reference-1"&gt;[1]&lt;/a&gt;. We gain no
information by observing an event sampled from this distribution, because we
knew ahead of time what would happen.&lt;/p&gt;
&lt;p&gt;Another example is a uniform distribution for the 5 possible outcomes:&lt;/p&gt;
&lt;img alt="distribution with uniform probabilities 0.2 per value" class="align-center" src="https://eli.thegreenplace.net/images/2025/distrib-uniform.png" /&gt;
&lt;p&gt;The entropy for this distribution is:&lt;/p&gt;
&lt;object class="align-center" data="https://eli.thegreenplace.net/images/math/40d3d1432da4d5a07dfb6440da7814c1cc4d72c9.svg" style="height: 56px;" type="image/svg+xml"&gt;\[H(X)=-\sum_{j=1}^{5}0.2 \log_2 0.2 = 2.32\]&lt;/object&gt;
&lt;p&gt;Intuitively: we have 5 different values with equal probabilities, so we'll need
&lt;object class="valign-m4" data="https://eli.thegreenplace.net/images/math/3304ed08b474e66efb2b577f4017f7a6c1114c2b.svg" style="height: 17px;" type="image/svg+xml"&gt;\log_{2} 5=2.32&lt;/object&gt; bits to encode which value occurred. Note that entropy is always
non-negative, because
&lt;object class="valign-m6" data="https://eli.thegreenplace.net/images/math/048f921e590e2392d2ffd293286dc52e74aa6371.svg" style="height: 18px;" type="image/svg+xml"&gt;0\leq p_j \leq 1&lt;/object&gt; and therefore &lt;object class="valign-m6" data="https://eli.thegreenplace.net/images/math/3ee0f682a5c1e670e20f57bc74961d6d491e7575.svg" style="height: 18px;" type="image/svg+xml"&gt;\log_2 p_j \leq 0&lt;/object&gt; for all &lt;em&gt;j&lt;/em&gt;
in a proper probability distribution.&lt;/p&gt;
&lt;p&gt;It's not hard to show that the maximum possible entropy for a random variable
occurs for a uniform distribution. In any other distribution, some values are
more likely than others, which makes the outcome somewhat less surprising.&lt;/p&gt;
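&lt;p&gt;Here's a small Python helper (again, just an illustrative snippet) that reproduces
the entropy of the two distributions shown above:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;import math

def entropy(ps):
    """Entropy (in bits) of a distribution given as a list of probabilities."""
    # Zero-probability values are skipped, since 0*log(0) is taken to be 0.
    return sum(-p * math.log2(p) for p in ps if p)

print(entropy([1.0, 0, 0, 0, 0]))          # 0.0
print(entropy([0.2, 0.2, 0.2, 0.2, 0.2]))  # 2.32
&lt;/pre&gt;&lt;/div&gt;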
&lt;/div&gt;
&lt;div class="section" id="cross-entropy"&gt;
&lt;h2&gt;Cross-entropy&lt;/h2&gt;
&lt;p&gt;Cross-entropy extends the concept of entropy to the case where two different
probability distributions are involved. The typical formulation useful for
machine learning is:&lt;/p&gt;
&lt;object class="align-center" data="https://eli.thegreenplace.net/images/math/b6e46340bd9f28350c89fb8ee386022ecc751dc2.svg" style="height: 52px;" type="image/svg+xml"&gt;\[H(P,Q)=-\sum_{j=1}^{n}p_j \log_2 q_j\]&lt;/object&gt;
&lt;p&gt;Where:&lt;/p&gt;
&lt;ul class="simple"&gt;
&lt;li&gt;&lt;em&gt;P&lt;/em&gt; is the actual observed data distribution&lt;/li&gt;
&lt;li&gt;&lt;em&gt;Q&lt;/em&gt; is the predicted data distribution&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;Similarly to entropy, cross-entropy is non-negative; in fact, it collapses to
the entropy formula when &lt;em&gt;P&lt;/em&gt; and &lt;em&gt;Q&lt;/em&gt; are the same:&lt;/p&gt;
&lt;object class="align-center" data="https://eli.thegreenplace.net/images/math/6e1b6cb172af5f3c986755f0994d5bb41aa04a63.svg" style="height: 52px;" type="image/svg+xml"&gt;\[H(P,P)=-\sum_{j=1}^{n}p_j \log_2 p_j=H(P)\]&lt;/object&gt;
&lt;p&gt;An information-theoretic interpretation of cross-entropy: the average number
of bits required to encode data coming from the actual distribution &lt;em&gt;P&lt;/em&gt;, when our
encoding assumes the data follows &lt;em&gt;Q&lt;/em&gt; instead.&lt;/p&gt;
&lt;p&gt;Here's a numeric example:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mf"&gt;0.2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;Plotted:&lt;/p&gt;
&lt;img alt="plotting p vs q" class="align-center" src="https://eli.thegreenplace.net/images/2025/pq-n-vs-uniform.png" /&gt;
&lt;p&gt;The cross-entropy of these two distributions is 2.32.&lt;/p&gt;
&lt;p&gt;Now let's try a &lt;em&gt;Q&lt;/em&gt; that's slightly closer to &lt;em&gt;P&lt;/em&gt;:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mf"&gt;0.15&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.175&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.35&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.175&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.15&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;
&lt;img alt="plotting p vs q" class="align-center" src="https://eli.thegreenplace.net/images/2025/pq-n-vs-n2.png" /&gt;
&lt;p&gt;The cross-entropy of these two distributions is somewhat lower, 2.16; this is
expected, because they're more similar. In other words, an outcome sampled from
&lt;em&gt;P&lt;/em&gt; when our model predicted &lt;em&gt;Q&lt;/em&gt; is less surprising.&lt;/p&gt;
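&lt;p&gt;These cross-entropy values are easy to reproduce in Python (an illustrative
snippet, not part of any library used later in this post):&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;import math

def cross_entropy(p, q):
    """Cross-entropy H(P, Q) in bits."""
    return -sum(pj * math.log2(qj) for pj, qj in zip(p, q) if pj)

p = [0.1, 0.2, 0.4, 0.2, 0.1]
q_uniform = [0.2, 0.2, 0.2, 0.2, 0.2]
q_closer = [0.15, 0.175, 0.35, 0.175, 0.15]

print(cross_entropy(p, q_uniform))  # 2.32
print(cross_entropy(p, q_closer))   # 2.16
&lt;/pre&gt;&lt;/div&gt;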
&lt;/div&gt;
&lt;div class="section" id="kl-divergence"&gt;
&lt;h2&gt;KL divergence&lt;/h2&gt;
&lt;p&gt;Cross-entropy is useful for tracking the training loss of a model (more on this
in the next section),
but it has some mathematical properties that make it less than ideal
as a statistical tool to compare two probability distributions. Specifically,
&lt;object class="valign-m5" data="https://eli.thegreenplace.net/images/math/d2d1a0f5ebf075a2675d7945869f40c18e3426c9.svg" style="height: 19px;" type="image/svg+xml"&gt;H(P,P)=H(P)&lt;/object&gt;, which isn't (usually) zero; this is the lowest value
possible for cross-entropy. In other words, cross-entropy
always retains the inherent uncertainty of &lt;em&gt;P&lt;/em&gt;.&lt;/p&gt;
&lt;p&gt;The KL divergence fixes this by subtracting &lt;object class="valign-m5" data="https://eli.thegreenplace.net/images/math/c1bafacb42854a23560796fb336e80c95c319031.svg" style="height: 19px;" type="image/svg+xml"&gt;H(P)&lt;/object&gt; from cross-entropy:&lt;/p&gt;
&lt;object class="align-center" data="https://eli.thegreenplace.net/images/math/44f211c9fc3917937101677e3b8f97a0983d2ac4.svg" style="height: 64px;" type="image/svg+xml"&gt;\[D_{KL}(P,Q)=H(P,Q)-H(P)=-\left (\sum_{j=1}^{n}p_j \log_2 q_j - \sum_{j=1}^{n}p_j \log_2 p_j \right )\]&lt;/object&gt;
&lt;p&gt;Manipulating the logarithms, we can also get these alternative formulations:&lt;/p&gt;
&lt;object class="align-center" data="https://eli.thegreenplace.net/images/math/4d4b8725d2df5a7f92fc22dfd10259e99a1ca5c9.svg" style="height: 52px;" type="image/svg+xml"&gt;\[D_{KL}(P,Q)=-\sum_{j=1}^{n}p_j \log_2 \frac{q_j}{p_j}=\sum_{j=1}^{n}p_j \log_2 \frac{p_j}{q_j}\]&lt;/object&gt;
&lt;p&gt;Thus, the KL divergence is more useful as a &lt;a class="reference external" href="https://en.wikipedia.org/wiki/Divergence_(statistics)"&gt;measure of divergence&lt;/a&gt;
between
two probability distributions, since &lt;object class="valign-m5" data="https://eli.thegreenplace.net/images/math/07660ab74385d62ad0200d3248a2d2c35ed2b35c.svg" style="height: 19px;" type="image/svg+xml"&gt;D_{KL}(P,P)=0&lt;/object&gt;. Note, however, that
it's not a true &lt;a class="reference external" href="https://en.wikipedia.org/wiki/Metric_space"&gt;distance metric&lt;/a&gt;
because it's not symmetric:&lt;/p&gt;
&lt;object class="align-center" data="https://eli.thegreenplace.net/images/math/64e352ee5152f703b40e74c72b72f20ac38d7175.svg" style="height: 19px;" type="image/svg+xml"&gt;\[D_{KL}(P,Q)\ne D_{KL}(Q,P)\]&lt;/object&gt;
&lt;/div&gt;
&lt;div class="section" id="uses-in-machine-learning"&gt;
&lt;h2&gt;Uses in machine learning&lt;/h2&gt;
&lt;p&gt;In ML, we often have a model that makes a prediction and a set of training data
which defines a real-world probability distribution. It's natural to define
a loss function in terms of the difference between the two distributions (the
model's prediction and the real data).&lt;/p&gt;
&lt;p&gt;Cross-entropy is very useful as a loss function because it's non-negative and
provides a single scalar number that's lower for similar distributions and
higher for dissimilar distributions. Moreover, if we think of cross-entropy
in terms of KL divergence:&lt;/p&gt;
&lt;object class="align-center" data="https://eli.thegreenplace.net/images/math/a31aa0b9de98d1aa1023600fa24f1397d517ef84.svg" style="height: 19px;" type="image/svg+xml"&gt;\[H(P,Q)=D_{KL}(P,Q)+H(P)\]&lt;/object&gt;
&lt;p&gt;We'll notice that &lt;object class="valign-m5" data="https://eli.thegreenplace.net/images/math/c1bafacb42854a23560796fb336e80c95c319031.svg" style="height: 19px;" type="image/svg+xml"&gt;H(P)&lt;/object&gt; - the entropy of the real-world distribution - does
not depend on the model at all. Therefore, minimizing the cross-entropy is equivalent
to minimizing the KL divergence. I wrote about concrete uses of cross-entropy
as a loss function in previous posts:
&lt;ul class="simple"&gt;
&lt;li&gt;&lt;a class="reference external" href="https://eli.thegreenplace.net/2016/logistic-regression/"&gt;Logistic regression&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;&lt;a class="reference external" href="https://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative/"&gt;Softmax for multiclass classification&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;That said, the KL divergence is also sometimes useful more directly; for example
in the &lt;a class="reference external" href="https://en.wikipedia.org/wiki/Evidence_lower_bound"&gt;evidence lower bound&lt;/a&gt;
used for &lt;a class="reference external" href="https://en.wikipedia.org/wiki/Variational_autoencoder"&gt;Variational autoencoders&lt;/a&gt;.&lt;/p&gt;
&lt;/div&gt;
&lt;div class="section" id="relation-to-maximum-likelihood-estimation"&gt;
&lt;h2&gt;Relation to Maximum Likelihood Estimation&lt;/h2&gt;
&lt;p&gt;There's an interesting relation between the concepts discussed in this post
and &lt;a class="reference external" href="https://en.wikipedia.org/wiki/Maximum_likelihood_estimation"&gt;Maximum Likelihood Estimation&lt;/a&gt;.&lt;/p&gt;
&lt;p&gt;Suppose we have a true probability distribution &lt;em&gt;P&lt;/em&gt;, and a parameterized model
that predicts the probability distribution &lt;object class="valign-m4" data="https://eli.thegreenplace.net/images/math/e8cde2aa522fd55113d3c90a6e02ac752e6ad9de.svg" style="height: 16px;" type="image/svg+xml"&gt;Q_\theta&lt;/object&gt;. &lt;img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /&gt;
stands for all the parameters of our model (e.g. all the weights of a deep
learning network).&lt;/p&gt;
&lt;p&gt;The &lt;em&gt;likelihood&lt;/em&gt; of observing a set of samples &lt;object class="valign-m3" data="https://eli.thegreenplace.net/images/math/703980c1da69640a87feb2a4e0977836026442bb.svg" style="height: 11px;" type="image/svg+xml"&gt;x_1\cdots x_n&lt;/object&gt; drawn
from &lt;em&gt;P&lt;/em&gt; is:&lt;/p&gt;
&lt;object class="align-center" data="https://eli.thegreenplace.net/images/math/eae84dafb00d4e74b504fdb974b9c03925224af1.svg" style="height: 49px;" type="image/svg+xml"&gt;\[L=\prod ^{n}_{i=1}P(x_i)\]&lt;/object&gt;
&lt;p&gt;However, we don't really know &lt;em&gt;P&lt;/em&gt;; what we do know is &lt;object class="valign-m4" data="https://eli.thegreenplace.net/images/math/e8cde2aa522fd55113d3c90a6e02ac752e6ad9de.svg" style="height: 16px;" type="image/svg+xml"&gt;Q_\theta&lt;/object&gt;, so
we can calculate:&lt;/p&gt;
&lt;object class="align-center" data="https://eli.thegreenplace.net/images/math/dd18eafdbf58c511a91a4efb3ad973586161fb40.svg" style="height: 49px;" type="image/svg+xml"&gt;\[L(\theta)=\prod ^{n}_{i=1}Q_\theta(x_i)\]&lt;/object&gt;
&lt;p&gt;The idea is to find an optimal set of parameters &lt;object class="valign-0" data="https://eli.thegreenplace.net/images/math/5f8bf92383eafb1f6f5fbffb6dcb58d1bf1a9319.svg" style="height: 19px;" type="image/svg+xml"&gt;\widehat{\theta}&lt;/object&gt;
such that this likelihood is maximized; in other words:&lt;/p&gt;
&lt;object class="align-center" data="https://eli.thegreenplace.net/images/math/419c9a0618b474805144567b8c9d4f1d290848ca.svg" style="height: 49px;" type="image/svg+xml"&gt;\[\widehat{\theta}=\underset{\theta}{argmax}\ L(\theta)=\underset{\theta}{argmax}\ \prod ^{n}_{i=1}Q_\theta(x_i)\]&lt;/object&gt;
&lt;p&gt;Working with products is inconvenient, however, so a logarithm is used instead
to convert a product to a sum (since &lt;object class="valign-m5" data="https://eli.thegreenplace.net/images/math/53f77c87c1ab5f4c797a46613e0db513cc75dd67.svg" style="height: 19px;" type="image/svg+xml"&gt;log(f(x))&lt;/object&gt; is a monotonically
increasing function, maximizing it is equivalent to maximizing &lt;img alt="f(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/3e03f4706048fbc6c5a252a85d066adf107fcc1f.png" style="height: 18px;" /&gt; itself):&lt;/p&gt;
&lt;object class="align-center" data="https://eli.thegreenplace.net/images/math/7502cc60fad662058324ef62d8e7bf3e8e8bb7b7.svg" style="height: 49px;" type="image/svg+xml"&gt;\[\widehat{\theta}=\underset{\theta}{argmax}\ \log L(\theta)=\underset{\theta}{argmax}\ \sum ^{n}_{i=1}\log Q_\theta(x_i)\]&lt;/object&gt;
&lt;p&gt;This is known as maximizing the &lt;em&gt;log-likelihood&lt;/em&gt;.&lt;/p&gt;
&lt;p&gt;Now a clever statistical trick is employed; first, we multiply the function
we're maximizing by the constant &lt;object class="valign-m6" data="https://eli.thegreenplace.net/images/math/da4e4caaf82d121438b1882f8b0a08baff2aee00.svg" style="height: 22px;" type="image/svg+xml"&gt;\frac{1}{n}&lt;/object&gt; - this doesn't affect the
maxima, of course:&lt;/p&gt;
&lt;object class="align-center" data="https://eli.thegreenplace.net/images/math/c2758a1be51374150e26bde87e58aa29adb2f152.svg" style="height: 49px;" type="image/svg+xml"&gt;\[\widehat{\theta}=\underset{\theta}{argmax}\ \frac{1}{n}\sum ^{n}_{i=1}\log Q_\theta(x_i)\]&lt;/object&gt;
&lt;p&gt;The function inside the &lt;em&gt;argmax&lt;/em&gt; is now the average across &lt;em&gt;n&lt;/em&gt; samples obtained
from the true probability distribution &lt;em&gt;P&lt;/em&gt;. The &lt;a class="reference external" href="https://en.wikipedia.org/wiki/Law_of_large_numbers"&gt;Law of Large numbers&lt;/a&gt;
states that with a large enough &lt;em&gt;n&lt;/em&gt;, this average converges to the expected
value of drawing from this distribution:&lt;/p&gt;
&lt;object class="align-center" data="https://eli.thegreenplace.net/images/math/b9c22053a992846decc249254004e3b57fb9444c.svg" style="height: 49px;" type="image/svg+xml"&gt;\[\widehat{\theta}=\underset{\theta}{argmax}\ \sum ^{n}_{i=1}P(x_i)\log Q_\theta(x_i)\]&lt;/object&gt;
&lt;p&gt;This should start looking familiar; all that's left is to negate the sum and
minimize the negative instead:&lt;/p&gt;
&lt;object class="align-center" data="https://eli.thegreenplace.net/images/math/069d4d7ead9db053ea269d38c24deb3e003ecfe4.svg" style="height: 49px;" type="image/svg+xml"&gt;\[\widehat{\theta}=\underset{\theta}{argmin}\ -\sum ^{n}_{i=1}P(x_i)\log Q_\theta(x_i)\]&lt;/object&gt;
&lt;p&gt;The function we're now minimizing is the &lt;em&gt;cross-entropy&lt;/em&gt; between &lt;em&gt;P&lt;/em&gt; and
&lt;object class="valign-m4" data="https://eli.thegreenplace.net/images/math/e8cde2aa522fd55113d3c90a6e02ac752e6ad9de.svg" style="height: 16px;" type="image/svg+xml"&gt;Q_\theta&lt;/object&gt;. We've shown that maximum likelihood estimation is equivalent
to minimizing the cross-entropy between the true and predicted data
distributions.&lt;/p&gt;
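&lt;p&gt;The Law of Large Numbers step above is easy to check numerically. The following
standalone snippet (reusing the example distributions from earlier in the post,
purely for illustration) draws many samples from &lt;em&gt;P&lt;/em&gt; and verifies that the
sample average of the log-probabilities under &lt;em&gt;Q&lt;/em&gt; approaches the negated
cross-entropy:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;import math
import random

p = [0.1, 0.2, 0.4, 0.2, 0.1]
q = [0.15, 0.175, 0.35, 0.175, 0.15]

# Average of log2(q[x]) over many samples x drawn from p...
n = 200_000
samples = random.choices(range(len(p)), weights=p, k=n)
avg_log_q = sum(math.log2(q[x]) for x in samples) / n

# ...should approach -H(P, Q), the negated cross-entropy.
cross_entropy = -sum(pj * math.log2(qj) for pj, qj in zip(p, q))
print(avg_log_q)       # roughly -2.16
print(-cross_entropy)  # -2.16
&lt;/pre&gt;&lt;/div&gt;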
&lt;hr class="docutils" /&gt;
&lt;table class="docutils footnote" frame="void" id="footnote-1" rules="none"&gt;
&lt;colgroup&gt;&lt;col class="label" /&gt;&lt;col /&gt;&lt;/colgroup&gt;
&lt;tbody valign="top"&gt;
&lt;tr&gt;&lt;td class="label"&gt;&lt;a class="fn-backref" href="#footnote-reference-1"&gt;[1]&lt;/a&gt;&lt;/td&gt;&lt;td&gt;This can be proven by taking the limit &lt;object class="valign-m6" data="https://eli.thegreenplace.net/images/math/cb7cbf90311ea6820f39588c1af3de46e31e3389.svg" style="height: 18px;" type="image/svg+xml"&gt;\lim_{p\to 0} p \log p&lt;/object&gt;
and using L'Hopital's rule to show that it goes to 0.&lt;/td&gt;&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;/div&gt;
</content><category term="misc"></category><category term="Math"></category><category term="Machine Learning"></category></entry><entry><title>Reproducing word2vec with JAX</title><link href="https://eli.thegreenplace.net/2025/reproducing-word2vec-with-jax/" rel="alternate"></link><published>2025-04-05T06:19:00-07:00</published><updated>2025-04-05T13:18:49-07:00</updated><author><name>Eli Bendersky</name></author><id>tag:eli.thegreenplace.net,2025-04-05:/2025/reproducing-word2vec-with-jax/</id><summary type="html">&lt;p&gt;The word2vec model was proposed in a 2013 paper by Google researchers called
&lt;a class="reference external" href="https://arxiv.org/pdf/1301.3781"&gt;&amp;quot;Efficient Estimation of Word Representations in Vector Space&amp;quot;&lt;/a&gt;,
and was further refined by additional papers from the same team. It kick-started
the modern use of &lt;em&gt;embeddings&lt;/em&gt; - dense vector representation of words (and later
tokens) for language models …&lt;/p&gt;</summary><content type="html">&lt;p&gt;The word2vec model was proposed in a 2013 paper by Google researchers called
&lt;a class="reference external" href="https://arxiv.org/pdf/1301.3781"&gt;&amp;quot;Efficient Estimation of Word Representations in Vector Space&amp;quot;&lt;/a&gt;,
and was further refined by additional papers from the same team. It kick-started
the modern use of &lt;em&gt;embeddings&lt;/em&gt; - dense vector representation of words (and later
tokens) for language models.&lt;/p&gt;
&lt;p&gt;Also, the code - with some instructions - &lt;a class="reference external" href="https://code.google.com/archive/p/word2vec/"&gt;was made available openly&lt;/a&gt;.
This post reproduces the word2vec results using JAX, and also talks about
reproducing it using the original C code (see the &lt;em&gt;Original word2vec code&lt;/em&gt;
section for that).&lt;/p&gt;
&lt;div class="section" id="embeddings"&gt;
&lt;h2&gt;Embeddings&lt;/h2&gt;
&lt;p&gt;First, a brief introduction to embeddings.
Wikipedia has a good definition:&lt;/p&gt;
&lt;blockquote&gt;
In natural language processing, a word embedding is a representation of a
word. The embedding is used in text analysis. Typically, the representation is a
real-valued vector that encodes the meaning of the word in such a way that the
words that are closer in the vector space are expected to be similar in meaning&lt;/blockquote&gt;
&lt;p&gt;Here's a framework that made sense to me when I was first learning about
embeddings many years ago:&lt;/p&gt;
&lt;ul class="simple"&gt;
&lt;li&gt;ML models and NNs specifically are all about vector math.&lt;/li&gt;
&lt;li&gt;Words in a human language (like English) are just sequences of characters
with no semantic meaning (there's nothing in the word &amp;quot;dog&amp;quot; that conveys
dog-ness any more than the same concept in other human languages). Also, words
have different lengths which isn't convenient.&lt;/li&gt;
&lt;li&gt;To represent words as vectors, we typically use indices into a vocabulary;
equivalently, this can be seen as a one-hot vector with the value at the
correct vocabulary index being 1, and the rest 0.&lt;/li&gt;
&lt;li&gt;This latter vector representation has no semantic meaning either, because
&amp;quot;Paris&amp;quot; and &amp;quot;France&amp;quot; will be as different from each other as &amp;quot;Paris&amp;quot; and
&amp;quot;Armadillo&amp;quot;. Also, these vectors are huge (a typical vocabulary can have
tens of thousands of words, just for a single language!)&lt;/li&gt;
&lt;li&gt;Therefore, we need some magic to convert words into vectors that carry
meaning.&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;Embeddings are that magic. They are dense vectors of floats - with typically
hundreds or thousands of elements, and serve as representations of these words
in high-dimensional space.&lt;/p&gt;
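&lt;p&gt;To make the one-hot vs. dense distinction concrete, here's a tiny NumPy sketch;
the three-word vocabulary and the random embedding values are made up purely for
illustration:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;import numpy as np

vocab = {"paris": 0, "france": 1, "armadillo": 2}  # made-up tiny vocabulary
V, D = len(vocab), 4                               # vocabulary size, embedding depth

# One-hot representation: as long as the vocabulary, and carries no semantic meaning.
one_hot = np.zeros(V)
one_hot[vocab["paris"]] = 1.0

# Dense embedding: each row of this (V, D) matrix is a word's learned vector.
embeddings = np.random.default_rng(0).normal(size=(V, D))
paris_vec = embeddings[vocab["paris"]]  # equivalent to one_hot @ embeddings
&lt;/pre&gt;&lt;/div&gt;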
&lt;/div&gt;
&lt;div class="section" id="the-word2vec-cbow-architecture"&gt;
&lt;h2&gt;The word2vec CBOW architecture&lt;/h2&gt;
&lt;p&gt;The word2vec paper proposed two related architectures: CBOW (Continuous Bag Of
Words) and Continuous Skip Gram. The two are fairly similar, and in this post
I'm going to focus on CBOW.&lt;/p&gt;
&lt;p&gt;The idea of the CBOW approach is to teach the model to predict a word from its
surrounding words. Here's an example with window size of four &lt;a class="footnote-reference" href="#footnote-1" id="footnote-reference-1"&gt;[1]&lt;/a&gt;:&lt;/p&gt;
&lt;img alt="CBOW - showing word in center of window, with context words around" class="align-center" src="https://eli.thegreenplace.net/images/2025/word2vec-cbow.png" /&gt;
&lt;p&gt;The goal here is to have the model predict that &amp;quot;liberty&amp;quot; should be the word
in the middle, given the context words in peach-colored boxes. This is an
&lt;em&gt;unsupervised&lt;/em&gt; model - it learns by consuming text, sliding its window word
by word over arbitrary amounts of (properly formatted and sanitized) input.&lt;/p&gt;
&lt;p&gt;Concretely, the following diagram shows the model architecture; here are
the dimensions involved:&lt;/p&gt;
&lt;ul class="simple"&gt;
&lt;li&gt;B: batch (for computational efficiency, whole batches are processed together)&lt;/li&gt;
&lt;li&gt;V: vocabulary size (the number of unique words in our vocabulary)&lt;/li&gt;
&lt;li&gt;D: model depth (the size of the dense embedding vectors we're trying to learn)&lt;/li&gt;
&lt;li&gt;W: window size&lt;/li&gt;
&lt;/ul&gt;
&lt;img alt="word2vec CBOW model architecture" class="align-center" src="https://eli.thegreenplace.net/images/2025/word2vec-arch.png" /&gt;
&lt;p&gt;Here's the flow of data in the forward pass:&lt;/p&gt;
&lt;ul class="simple"&gt;
&lt;li&gt;&lt;tt class="docutils literal"&gt;context&lt;/tt&gt; is the context words for a given position. For example, in the
sample diagram above the context would be of length 8. Each element is an
integer representation of a word (its index into the vocabulary). Since
we're processing batches, the shape of this array is (B,2W).&lt;/li&gt;
&lt;li&gt;The context &lt;em&gt;indexes&lt;/em&gt; into a projection matrix &lt;tt class="docutils literal"&gt;P&lt;/tt&gt;, which has the learned
embedding per row - one for each word in the vocabulary. The result is
&lt;tt class="docutils literal"&gt;projection&lt;/tt&gt; with shape (B,2W,D). The first two dimensions remain the same
(because we still have the same batch and window size), but every integer
is replaced with the word's embedding - so an extra dimension is added.&lt;/li&gt;
&lt;li&gt;Next, a &lt;em&gt;mean&lt;/em&gt; (arithmetic average) is taken across the window dimension.
The embeddings of all the words in the window are averaged together. The
result is (B,D) where each row is the average of the embeddings of 2W words.&lt;/li&gt;
&lt;li&gt;Finally, the &lt;em&gt;hidden&lt;/em&gt; layer matrix &lt;tt class="docutils literal"&gt;H&lt;/tt&gt; is used to map the dense
representation back into a sparse one &lt;a class="footnote-reference" href="#footnote-2" id="footnote-reference-2"&gt;[2]&lt;/a&gt; - this is the prediction of the middle
word. Recall that this tries to predict a one-hot encoding of the word's
vocabulary index.&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;For training, the loss is calculated by comparing &lt;tt class="docutils literal"&gt;out&lt;/tt&gt; to the one-hot
encoding of the actual target word for this window, and the calculated gradient
is propagated backwards to train the model.&lt;/p&gt;
&lt;/div&gt;
&lt;div class="section" id="jax-implementation"&gt;
&lt;h2&gt;JAX implementation&lt;/h2&gt;
&lt;p&gt;The JAX implementation of the model described above is clean and compact:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;&lt;span class="nd"&gt;@jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;jit&lt;/span&gt;
&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;word2vec_forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sd"&gt;&amp;quot;&amp;quot;&amp;quot;Forward pass of the word2Vec model.&lt;/span&gt;

&lt;span class="sd"&gt;    context is a (batch_size, 2*window_size) array of word IDs.&lt;/span&gt;

&lt;span class="sd"&gt;    V is the vocabulary size, D is the embedding dimension.&lt;/span&gt;
&lt;span class="sd"&gt;    params[&amp;quot;projection&amp;quot;] is a (V, D) matrix of word embeddings.&lt;/span&gt;
&lt;span class="sd"&gt;    params[&amp;quot;hidden&amp;quot;] is a (D, V) matrix of weights for the hidden layer.&lt;/span&gt;
&lt;span class="sd"&gt;    &amp;quot;&amp;quot;&amp;quot;&lt;/span&gt;
    &lt;span class="c1"&gt;# Indexing into (V, D) matrix with a batch of IDs. The output shape&lt;/span&gt;
    &lt;span class="c1"&gt;# is (batch_size, 2*window_size, D).&lt;/span&gt;
    &lt;span class="n"&gt;projection&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;projection&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;

    &lt;span class="c1"&gt;# Compute average across the context word. The output shape is&lt;/span&gt;
    &lt;span class="c1"&gt;# (batch_size, D).&lt;/span&gt;
    &lt;span class="n"&gt;avg_projection&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;projection&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# (batch_size, D) @ (D, V) -&amp;gt; (batch_size, V)&lt;/span&gt;
    &lt;span class="n"&gt;hidden&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;avg_projection&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;hidden&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;


&lt;span class="nd"&gt;@jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;jit&lt;/span&gt;
&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;word2vec_loss&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sd"&gt;&amp;quot;&amp;quot;&amp;quot;Compute the loss of the word2Vec model.&amp;quot;&amp;quot;&amp;quot;&lt;/span&gt;
    &lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;word2vec_forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# (batch_size, V)&lt;/span&gt;

    &lt;span class="n"&gt;target_onehot&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;one_hot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;  &lt;span class="c1"&gt;# (batch_size, V)&lt;/span&gt;
    &lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;optax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;losses&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;softmax_cross_entropy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;target_onehot&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;loss&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;
&lt;/div&gt;
&lt;div class="section" id="training"&gt;
&lt;h2&gt;Training&lt;/h2&gt;
&lt;p&gt;For training, I've been relying on the same dataset used by the original word2vec
code - a 100MB text file downloaded from &lt;a class="reference external" href="http://mattmahoney.net/dc/text8.zip"&gt;http://mattmahoney.net/dc/text8.zip&lt;/a&gt;.&lt;/p&gt;
&lt;p&gt;This file contains all-lowercase text with no punctuation, so it requires very
little cleaning and processing. What it &lt;em&gt;does&lt;/em&gt; require for higher-quality
training is &lt;em&gt;subsampling&lt;/em&gt;: throwing away some of the most common words (e.g.
&amp;quot;and&amp;quot;, &amp;quot;is&amp;quot;, &amp;quot;not&amp;quot; in English), since they appear so much in the text. Here's
my code for this:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;subsample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;words&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;threshold&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-4&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sd"&gt;&amp;quot;&amp;quot;&amp;quot;Subsample frequent words, return a new list of words.&lt;/span&gt;

&lt;span class="sd"&gt;    Follows the subsampling procedure described in the paper &amp;quot;Distributed&lt;/span&gt;
&lt;span class="sd"&gt;    Representations of Words and Phrases and their Compositionality&amp;quot; by&lt;/span&gt;
&lt;span class="sd"&gt;    Mikolov et al. (2013).&lt;/span&gt;
&lt;span class="sd"&gt;    &amp;quot;&amp;quot;&amp;quot;&lt;/span&gt;
    &lt;span class="n"&gt;word_counts&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Counter&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;words&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;total_count&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;words&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;freqs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="n"&gt;word&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;count&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;total_count&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;word&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;count&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;word_counts&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;items&lt;/span&gt;&lt;span class="p"&gt;()}&lt;/span&gt;

    &lt;span class="c1"&gt;# Common words (freq(word) &amp;gt; threshold) are kept with a computed&lt;/span&gt;
    &lt;span class="c1"&gt;# probability, while rare words are always kept.&lt;/span&gt;
    &lt;span class="n"&gt;p_keep&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;
        &lt;span class="n"&gt;word&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;threshold&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;freqs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;word&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;freqs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;word&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;threshold&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;
        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;word&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;word_counts&lt;/span&gt;
    &lt;span class="p"&gt;}&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;word&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;word&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;words&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;&lt;/span&gt; &lt;span class="n"&gt;p_keep&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;word&lt;/span&gt;&lt;span class="p"&gt;]]&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;We also have to create a vocabulary with some limited size:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;make_vocabulary&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;words&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;top_k&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;20000&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sd"&gt;&amp;quot;&amp;quot;&amp;quot;Creates a vocabulary from a list of words.&lt;/span&gt;

&lt;span class="sd"&gt;    Keeps the top_k most common words and assigns an index to each word. The&lt;/span&gt;
&lt;span class="sd"&gt;    index 0 is reserved for the &amp;quot;&amp;lt;unk&amp;gt;&amp;quot; token.&lt;/span&gt;
&lt;span class="sd"&gt;    &amp;quot;&amp;quot;&amp;quot;&lt;/span&gt;
    &lt;span class="n"&gt;word_counts&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Counter&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;words&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;vocab&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;&amp;lt;unk&amp;gt;&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;}&lt;/span&gt;
    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;word&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;word_counts&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;most_common&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;top_k&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;vocab&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;word&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;vocab&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;vocab&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;The preprocessing step generates the list of subsampled words and the
vocabulary, and stores them in a pickle file for future reference. The
training loop uses these data to train a model from a random initialization.
Pay special attention to the hyper-parameters defined at the top of the
&lt;tt class="docutils literal"&gt;train&lt;/tt&gt; function. I set these to be as close as possible to the original
word2vec code:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;train&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;train_data&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;vocab&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;V&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;vocab&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;D&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;200&lt;/span&gt;
    &lt;span class="n"&gt;LEARNING_RATE&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;1e-3&lt;/span&gt;
    &lt;span class="n"&gt;WINDOW_SIZE&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;8&lt;/span&gt;
    &lt;span class="n"&gt;BATCH_SIZE&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;1024&lt;/span&gt;
    &lt;span class="n"&gt;EPOCHS&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;25&lt;/span&gt;

    &lt;span class="n"&gt;initializer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;initializers&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;glorot_uniform&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;params&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;
        &lt;span class="s2"&gt;&amp;quot;projection&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;initializer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;PRNGKey&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;501337&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;V&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;D&lt;/span&gt;&lt;span class="p"&gt;)),&lt;/span&gt;
        &lt;span class="s2"&gt;&amp;quot;hidden&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;initializer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;PRNGKey&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;501337&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;D&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;V&lt;/span&gt;&lt;span class="p"&gt;)),&lt;/span&gt;
    &lt;span class="p"&gt;}&lt;/span&gt;

    &lt;span class="n"&gt;optimizer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;optax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;adam&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;LEARNING_RATE&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;opt_state&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;init&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;Approximate number of batches:&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;train_data&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;//&lt;/span&gt; &lt;span class="n"&gt;BATCH_SIZE&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;epoch&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;EPOCHS&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;=== Epoch &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;epoch&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;epoch_loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;target_batch&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;context_batch&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;enumerate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="n"&gt;generate_train_vectors&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
                &lt;span class="n"&gt;train_data&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;vocab&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;window_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;WINDOW_SIZE&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;BATCH_SIZE&lt;/span&gt;
            &lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="p"&gt;):&lt;/span&gt;
            &lt;span class="c1"&gt;# Shuffle the batch.&lt;/span&gt;
            &lt;span class="n"&gt;indices&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;permutation&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;target_batch&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
            &lt;span class="n"&gt;target_batch&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;target_batch&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;indices&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
            &lt;span class="n"&gt;context_batch&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;context_batch&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;indices&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;

            &lt;span class="c1"&gt;# Compute the loss and gradients; optimize.&lt;/span&gt;
            &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;grads&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;value_and_grad&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;word2vec_loss&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;
                &lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;target_batch&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;context_batch&lt;/span&gt;
            &lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;updates&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;opt_state&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;update&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;grads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;opt_state&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;params&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;optax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;apply_updates&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;updates&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

            &lt;span class="n"&gt;epoch_loss&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;n&lt;/span&gt; &lt;span class="o"&gt;&amp;gt;&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt; &lt;span class="ow"&gt;and&lt;/span&gt; &lt;span class="n"&gt;n&lt;/span&gt; &lt;span class="o"&gt;%&lt;/span&gt; &lt;span class="mi"&gt;1000&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
                &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;Batch &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;Epoch loss: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;epoch_loss&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="s2"&gt;.2f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;checkpoint_filename&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;checkpoint-&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;epoch&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="s2"&gt;03&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;.pickle&amp;quot;&lt;/span&gt;
        &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;Saving checkpoint to&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;checkpoint_filename&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="nb"&gt;open&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;checkpoint_filename&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;quot;wb&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;file&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;pickle&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dump&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;file&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;The only thing I'm not showing here is the &lt;tt class="docutils literal"&gt;generate_train_vectors&lt;/tt&gt; function,
as it's not particularly interesting; you can find it
&lt;a class="reference external" href="https://github.com/eliben/deep-learning-samples/blob/main/word2vec-jax/train.py"&gt;in the full code&lt;/a&gt;.&lt;/p&gt;
&lt;p&gt;I don't have a particularly powerful GPU, so on my machine training this model
for 25 epochs takes 20-30 minutes.&lt;/p&gt;
&lt;/div&gt;
&lt;div class="section" id="extracting-embeddings-and-finding-word-similarities"&gt;
&lt;h2&gt;Extracting embeddings and finding word similarities&lt;/h2&gt;
&lt;p&gt;The result of the training is the &lt;tt class="docutils literal"&gt;P&lt;/tt&gt; and &lt;tt class="docutils literal"&gt;H&lt;/tt&gt; arrays with trained weights;
&lt;tt class="docutils literal"&gt;P&lt;/tt&gt; is exactly the embedding matrix we need! It maps vocabulary words to their
dense embedding representation. Using &lt;tt class="docutils literal"&gt;P&lt;/tt&gt;, we can create the fun word demos
that made word2vec famous. The full code has a script named &lt;tt class="docutils literal"&gt;&lt;span class="pre"&gt;similar-words.py&lt;/span&gt;&lt;/tt&gt;
that does this. Some examples:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;$ uv run similar-words.py -word paris \
      -checkpoint checkpoint.pickle \
      -traindata train-data.pickle
Words similar to &amp;#39;paris&amp;#39;:
paris           1.00
france          0.50
french          0.49
la              0.42
le              0.41
henri           0.40
toulouse        0.38
brussels        0.38
petit           0.38
les             0.38
&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;And:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;$ uv run similar-words.py -analogy berlin,germany,tokyo \
      -checkpoint checkpoint.pickle \
      -traindata train-data.pickle
Analogies for &amp;#39;berlin is to germany as tokyo is to ?&amp;#39;:
tokyo           0.70
japan           0.45
japanese        0.44
osaka           0.40
china           0.36
germany         0.35
singapore       0.32
han             0.31
gu              0.31
kyushu          0.31
&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;This brings us to the intuition for how word2vec works: the basic idea is that
semantically similar words will appear in the vicinity of roughly similar
context words, but also that words are generally related to words in the
context they appear in. This lets the model learn that some words are more
related than others; for example:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;$ uv run similar-words.py -sims soccer,basketball,chess,cat,bomb \
      -checkpoint checkpoint.pickle \
      -traindata train-data.pickle
Similarities for &amp;#39;soccer&amp;#39; with context words [&amp;#39;basketball&amp;#39;, &amp;#39;chess&amp;#39;, &amp;#39;cat&amp;#39;, &amp;#39;bomb&amp;#39;]:
basketball      0.40
chess           0.22
cat             0.14
bomb            0.13
&lt;/pre&gt;&lt;/div&gt;
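&lt;p&gt;For reference, here's a minimal sketch of how such similarity queries can be
computed from the trained &lt;tt class="docutils literal"&gt;P&lt;/tt&gt; matrix using cosine similarity; the
actual logic lives in &lt;tt class="docutils literal"&gt;&lt;span class="pre"&gt;similar-words.py&lt;/span&gt;&lt;/tt&gt; in the
repository and may differ in its details:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;import numpy as np

def most_similar(word, vocab, P, topn=10):
    """Finds the words whose embeddings have the highest cosine similarity to
    the embedding of the given word. P is the trained (V, D) projection matrix."""
    # Normalize all embeddings to unit length; cosine similarity then becomes
    # a plain matrix-vector product.
    normalized = P / np.linalg.norm(P, axis=1, keepdims=True)
    sims = normalized @ normalized[vocab[word]]
    idx2word = {i: w for w, i in vocab.items()}
    best = np.argsort(-sims)[:topn]
    return [(idx2word[i], sims[i]) for i in best]
&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;Normalizing the rows once up front makes each query a single matrix-vector
product over the whole vocabulary.&lt;/p&gt;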
&lt;/div&gt;
&lt;div class="section" id="optimizations"&gt;
&lt;h2&gt;Optimizations&lt;/h2&gt;
&lt;p&gt;The word2vec model can be optimized in several ways, many of which are focused
on avoiding the giant matrix multiplication by &lt;tt class="docutils literal"&gt;H&lt;/tt&gt; at the very end. The
word2vec authors have a followup paper called &lt;a class="reference external" href="https://arxiv.org/pdf/1310.4546"&gt;&amp;quot;Distributed Representations of
Words and Phrases and their Compositionality&amp;quot;&lt;/a&gt;
where these are described; I'm leaving them out of my implementation, for
simplicity.&lt;/p&gt;
&lt;p&gt;Implementing these optimizations would make it feasible to improve the model's quality
considerably, by increasing the model depth (it's currently 200, which is
very low by modern LLM standards) and the amount of data we train on. That
said, these days word2vec is mostly of historical interest anyway; the
&lt;em&gt;Modern text embeddings&lt;/em&gt; section will have more to say on how embeddings are
trained as part of modern LLMs.&lt;/p&gt;
&lt;/div&gt;
&lt;div class="section" id="original-word2vec-code"&gt;
&lt;h2&gt;Original word2vec code&lt;/h2&gt;
&lt;p&gt;As mentioned above, the original website for the word2vec model is available
on an &lt;a class="reference external" href="https://code.google.com/archive/p/word2vec/"&gt;archived version of Google Code&lt;/a&gt;.
That page is still useful reading, but the Subversion instructions to obtain
the actual code no longer work.&lt;/p&gt;
&lt;p&gt;I was able to find a GitHub mirror with a code export here: &lt;a class="reference external" href="https://github.com/tmikolov/word2vec"&gt;https://github.com/tmikolov/word2vec&lt;/a&gt;
(the username certainly checks out, though it's hard to know for sure!)&lt;/p&gt;
&lt;p&gt;The awesome thing is that this code still builds and runs perfectly, many years
later. Hurray to self-contained C programs with no dependencies; all I needed
was to run &lt;tt class="docutils literal"&gt;make&lt;/tt&gt;, and then use the included shell scripts to download the
data and run training. This code uses the CPU for training; it takes a while,
but I was able to reproduce the similarity / analogy results fairly easily.&lt;/p&gt;
&lt;/div&gt;
&lt;div class="section" id="modern-text-embeddings"&gt;
&lt;h2&gt;Modern text embeddings&lt;/h2&gt;
&lt;p&gt;The word2vec model trains an embedding matrix; this &lt;em&gt;pre-trained&lt;/em&gt; matrix can
then be used as part of other ML models. This approach was used for a while,
but it's no longer popular.&lt;/p&gt;
&lt;p&gt;These days, an embedding matrix is trained as part of a larger model.
For example, GPT-type transformer-based LLMs have an embedding matrix as the
first layer in the model. This is basically just the &lt;tt class="docutils literal"&gt;P&lt;/tt&gt; matrix from the
diagram above &lt;a class="footnote-reference" href="#footnote-3" id="footnote-reference-3"&gt;[3]&lt;/a&gt;. LLMs learn both the
embeddings and their specific task (generating tokens from a given context)
at the same time. This makes some sense because:&lt;/p&gt;
&lt;ul class="simple"&gt;
&lt;li&gt;LLMs process enormous amounts of data, and consuming this data multiple times
to train embeddings separately is wasteful.&lt;/li&gt;
&lt;li&gt;Embeddings trained together with the LLM are inherently tuned to the LLM's
specific task and hyper-parameters (e.g. the kind of tokenizer used, the
model depth etc.)&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;Specifically, modern embedding matrices differ from word2vec in two important
aspects:&lt;/p&gt;
&lt;ul class="simple"&gt;
&lt;li&gt;Instead of being &lt;em&gt;word&lt;/em&gt; embeddings, they are &lt;em&gt;token&lt;/em&gt; embeddings. I wrote much
more on &lt;a class="reference external" href="https://eli.thegreenplace.net/2024/tokens-for-llms-byte-pair-encoding-in-go/"&gt;tokens for LLMs here&lt;/a&gt;.&lt;/li&gt;
&lt;li&gt;The model depth (D) is &lt;em&gt;much&lt;/em&gt; larger; GPT-3 has D=12288, and in newer models
it's probably even larger. Deep embedding vectors help the models capture more
nuance and semantic meaning about tokens. Naturally, they also require much
more data to be trained effectively.&lt;/li&gt;
&lt;/ul&gt;
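&lt;p&gt;To make this concrete, here is a minimal sketch (with made-up shapes, not
taken from any specific model) of what the embedding layer at the front of a
GPT-style model does - it's just a row lookup into the learned (V, D) matrix:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;import numpy as np

V, D = 50000, 12288         # made-up vocabulary size; GPT-3-like depth
P = np.random.randn(V, D)   # the learned token embedding matrix

token_ids = np.array([[17, 42, 3], [5, 5, 99]])   # a (B, N) batch of token ids
emb = P[token_ids]          # (B, N, D) - one row of P per token
print(emb.shape)            # (2, 3, 12288)
&lt;/pre&gt;&lt;/div&gt;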
&lt;/div&gt;
&lt;div class="section" id="full-code"&gt;
&lt;h2&gt;Full code&lt;/h2&gt;
&lt;p&gt;The full code for this post is &lt;a class="reference external" href="https://github.com/eliben/deep-learning-samples/tree/main/word2vec-jax"&gt;available here&lt;/a&gt;.
If you want to reproduce my word2vec results, check out the README file - it
contains full instructions on which scripts to run and in which order.&lt;/p&gt;
&lt;hr class="docutils" /&gt;
&lt;table class="docutils footnote" frame="void" id="footnote-1" rules="none"&gt;
&lt;colgroup&gt;&lt;col class="label" /&gt;&lt;col /&gt;&lt;/colgroup&gt;
&lt;tbody valign="top"&gt;
&lt;tr&gt;&lt;td class="label"&gt;&lt;a class="fn-backref" href="#footnote-reference-1"&gt;[1]&lt;/a&gt;&lt;/td&gt;&lt;td&gt;The window size is
how many words to the left and right of the target word to take into account,
and it's a configurable hyper-parameter during training.&lt;/td&gt;&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;table class="docutils footnote" frame="void" id="footnote-2" rules="none"&gt;
&lt;colgroup&gt;&lt;col class="label" /&gt;&lt;col /&gt;&lt;/colgroup&gt;
&lt;tbody valign="top"&gt;
&lt;tr&gt;&lt;td class="label"&gt;&lt;a class="fn-backref" href="#footnote-reference-2"&gt;[2]&lt;/a&gt;&lt;/td&gt;&lt;td&gt;&lt;p class="first"&gt;The terms &lt;em&gt;dense&lt;/em&gt; and &lt;em&gt;sparse&lt;/em&gt; are used in the post in the following
sense:&lt;/p&gt;
&lt;p&gt;A sparse array is one where almost all entries are 0. This is true
for one-hot vectors representing vocabulary words (all entries are 0
except a single one that has the value 1).&lt;/p&gt;
&lt;p class="last"&gt;A dense array is filled with arbitrary floating-point
values. An embedding vector is dense in this sense - it's typically
short compared to the sparse vector (in the word2vec example used in
this post D=200, while V=20000), but full of data (hence &amp;quot;dense&amp;quot;). An
embedding matrix is dense since it consists of dense vectors (one per
word index).&lt;/p&gt;
&lt;/td&gt;&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;table class="docutils footnote" frame="void" id="footnote-3" rules="none"&gt;
&lt;colgroup&gt;&lt;col class="label" /&gt;&lt;col /&gt;&lt;/colgroup&gt;
&lt;tbody valign="top"&gt;
&lt;tr&gt;&lt;td class="label"&gt;&lt;a class="fn-backref" href="#footnote-reference-3"&gt;[3]&lt;/a&gt;&lt;/td&gt;&lt;td&gt;The rest (mean calculation, hidden layer) isn't needed since
it's only there to &lt;em&gt;train&lt;/em&gt; the word2vec CBOW model.&lt;/td&gt;&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;/div&gt;
</content><category term="misc"></category><category term="Python"></category><category term="Math"></category><category term="Machine Learning"></category></entry><entry><title>Notes on implementing Attention</title><link href="https://eli.thegreenplace.net/2025/notes-on-implementing-attention/" rel="alternate"></link><published>2025-03-26T17:15:00-07:00</published><updated>2025-05-02T01:36:27-07:00</updated><author><name>Eli Bendersky</name></author><id>tag:eli.thegreenplace.net,2025-03-26:/2025/notes-on-implementing-attention/</id><summary type="html">&lt;p&gt;Some notes on implementing attention blocks in pure Python +
Numpy. The focus here is on the exact implementation in code, explaining all the
shapes throughout the process. The motivation for why attention works is not
covered here too deeply - there are plenty of excellent online resources
explaining it.&lt;/p&gt;
&lt;p&gt;Several papers …&lt;/p&gt;</summary><content type="html">&lt;p&gt;Some notes on implementing attention blocks in pure Python +
Numpy. The focus here is on the exact implementation in code, explaining all the
shapes throughout the process. The motivation for why attention works is not
covered here too deeply - there are plenty of excellent online resources
explaining it.&lt;/p&gt;
&lt;p&gt;Several papers are mentioned throughout the code; they are:&lt;/p&gt;
&lt;ul class="simple"&gt;
&lt;li&gt;AIAYN - &lt;a class="reference external" href="https://arxiv.org/abs/1706.03762"&gt;Attention Is All You Need&lt;/a&gt; by
Vaswani et al.&lt;/li&gt;
&lt;li&gt;GPT-3 - &lt;a class="reference external" href="https://arxiv.org/abs/2005.14165"&gt;Language Models are Few-Shot Learners&lt;/a&gt; by Brown et al.&lt;/li&gt;
&lt;/ul&gt;
&lt;div class="section" id="basic-scaled-self-attention"&gt;
&lt;h2&gt;Basic scaled self-attention&lt;/h2&gt;
&lt;p&gt;We'll start with the most basic scaled dot product self-attention, working on a
single sequence of tokens, without masking.&lt;/p&gt;
&lt;p&gt;The input is a 2D array of shape (N, D). N is the length of the sequence (how
many tokens it contains) and D is the embedding depth - the length of the
embedding vector representing each token &lt;a class="footnote-reference" href="#footnote-1" id="footnote-reference-1"&gt;[1]&lt;/a&gt;. D could be something like
512, or more, depending on the model.&lt;/p&gt;
&lt;img alt="input array N by D" class="align-center" src="https://eli.thegreenplace.net/images/2025/nd-array.png" /&gt;
&lt;p&gt;A self-attention module is parameterized with three weight matrices, &lt;tt class="docutils literal"&gt;Wk&lt;/tt&gt;,
&lt;tt class="docutils literal"&gt;Wq&lt;/tt&gt; and &lt;tt class="docutils literal"&gt;Wv&lt;/tt&gt;. Some variants also have accompanying bias vectors, but the
AIAYN paper doesn't use them, so I'll skip them here. In the general case,
the shape of each weight matrix is (D, HS), where HS is some fraction of
D. HS stands for &amp;quot;head size&amp;quot; and we'll see what this means soon. This is a
diagram of a self-attention module (the diagram assumes
N=6, D is some large number and so is HS). In the diagram, &lt;tt class="docutils literal"&gt;&amp;#64;&lt;/tt&gt; stands for
matrix multiplication (Python/Numpy syntax):&lt;/p&gt;
&lt;img alt="schematic of a single attention head" class="align-center" src="https://eli.thegreenplace.net/images/2025/attention-head.png" /&gt;
&lt;p&gt;Here's a basic Numpy implementation of this:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;&lt;span class="c1"&gt;# self_attention the way it happens in the Transformer model. No bias.&lt;/span&gt;
&lt;span class="c1"&gt;# D = model dimension/depth (length of embedding)&lt;/span&gt;
&lt;span class="c1"&gt;# N = input sequence length&lt;/span&gt;
&lt;span class="c1"&gt;# HS = head size&lt;/span&gt;
&lt;span class="c1"&gt;#&lt;/span&gt;
&lt;span class="c1"&gt;# x is the input (N, D), each token in a row.&lt;/span&gt;
&lt;span class="c1"&gt;# Each of W* is a weight matrix of shape (D, HS)&lt;/span&gt;
&lt;span class="c1"&gt;# The result is (N, HS)&lt;/span&gt;
&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;self_attention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Wk&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Wq&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Wv&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="c1"&gt;# Each of these is (N, D) @ (D, HS) = (N, HS)&lt;/span&gt;
    &lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;Wq&lt;/span&gt;
    &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;Wk&lt;/span&gt;
    &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;Wv&lt;/span&gt;

    &lt;span class="c1"&gt;# kq: (N, N) matrix of dot products between each pair of q and k vectors.&lt;/span&gt;
    &lt;span class="c1"&gt;# The division by sqrt(HS) is the scaling.&lt;/span&gt;
    &lt;span class="n"&gt;kq&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;

    &lt;span class="c1"&gt;# att: (N, N) attention matrix. The rows become the weights that sum&lt;/span&gt;
    &lt;span class="c1"&gt;# to 1 for each output vector.&lt;/span&gt;
    &lt;span class="n"&gt;att&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;softmax_lastdim&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;kq&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;att&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;  &lt;span class="c1"&gt;# (N, HS)&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;The &amp;quot;scaled&amp;quot; part is just dividing &lt;tt class="docutils literal"&gt;kq&lt;/tt&gt; by the square root of &lt;tt class="docutils literal"&gt;HS&lt;/tt&gt;, which
is done to keep the values of the dot products manageable (otherwise they would
grow with the size of the contracted dimension).&lt;/p&gt;
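&lt;p&gt;A quick way to see why this matters is to check how dot products of random
vectors grow with the vector length; this is a small standalone check, not part
of the attention code:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;import numpy as np

rng = np.random.default_rng(0)
for d in (16, 256, 4096):
    a = rng.standard_normal((10000, d))
    b = rng.standard_normal((10000, d))
    dots = (a * b).sum(axis=-1)
    # The spread of the raw dot products grows like sqrt(d);
    # after dividing by sqrt(d) it stays around 1.
    print(d, dots.std(), (dots / np.sqrt(d)).std())
&lt;/pre&gt;&lt;/div&gt;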
&lt;p&gt;The only dependency is a function for &lt;a class="reference external" href="https://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative/"&gt;calculating Softmax&lt;/a&gt;
across the last dimension of an input array:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;softmax_lastdim&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sd"&gt;&amp;quot;&amp;quot;&amp;quot;Compute softmax across last dimension of x.&lt;/span&gt;

&lt;span class="sd"&gt;    x is an arbitrary array with at least two dimensions. The returned array has&lt;/span&gt;
&lt;span class="sd"&gt;    the same shape as x, but its elements sum up to 1 across the last dimension.&lt;/span&gt;
&lt;span class="sd"&gt;    &amp;quot;&amp;quot;&amp;quot;&lt;/span&gt;
    &lt;span class="c1"&gt;# Subtract the max for numerical stability&lt;/span&gt;
    &lt;span class="n"&gt;ex&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdims&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
    &lt;span class="c1"&gt;# Divide by sums across last dimension&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;ex&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;ex&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdims&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;When the input is 2D, the &amp;quot;last dimension&amp;quot; is the columns. Colloquially, this
Softmax function acts on each row of &lt;em&gt;x&lt;/em&gt; separately; it applies the Softmax
formula to the elements (columns) of the row, ending up with a row of numbers in
the range &lt;tt class="docutils literal"&gt;[0,1]&lt;/tt&gt; that all sum up to 1.&lt;/p&gt;
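&lt;p&gt;For instance, a quick check of this behavior using the
&lt;tt class="docutils literal"&gt;softmax_lastdim&lt;/tt&gt; function defined above:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;import numpy as np

x = np.array([[1.0, 2.0, 3.0],
              [4.0, 4.0, 4.0]])
s = softmax_lastdim(x)
print(s)               # rows are roughly [0.09, 0.24, 0.67] and [0.33, 0.33, 0.33]
print(s.sum(axis=-1))  # [1. 1.]
&lt;/pre&gt;&lt;/div&gt;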
&lt;p&gt;Another note on the dimensions: it's possible for the &lt;tt class="docutils literal"&gt;Wv&lt;/tt&gt; matrix to have a
different second dimension from &lt;tt class="docutils literal"&gt;Wq&lt;/tt&gt; and &lt;tt class="docutils literal"&gt;Wk&lt;/tt&gt;. If you look at the diagram,
you can see this will work out, since the softmax produces (N, N), and whatever
the second dimension of V is, will be the second dimension of the output. The
AIAYN paper designates these dimensions as &lt;object class="valign-m3" data="https://eli.thegreenplace.net/images/math/85b78059c9068f7b482e10c9a50fad97ad8cdcf5.svg" style="height: 15px;" type="image/svg+xml"&gt;d_k&lt;/object&gt; and &lt;object class="valign-m3" data="https://eli.thegreenplace.net/images/math/c347663196aa1510668fbd932f8f62aa92e8b15d.svg" style="height: 15px;" type="image/svg+xml"&gt;d_v&lt;/object&gt;, but in
practice &lt;object class="valign-m3" data="https://eli.thegreenplace.net/images/math/8ee5144f6c9132f363cbc0a8a61b08a78b578968.svg" style="height: 15px;" type="image/svg+xml"&gt;d_k=d_v&lt;/object&gt; in all the variants it lists. I found that these
dimensions are typically the same in other papers as well. Therefore, for
simplicity I just made them all equal to D in this post; if desired, a variant
with different &lt;object class="valign-m3" data="https://eli.thegreenplace.net/images/math/85b78059c9068f7b482e10c9a50fad97ad8cdcf5.svg" style="height: 15px;" type="image/svg+xml"&gt;d_k&lt;/object&gt; and &lt;object class="valign-m3" data="https://eli.thegreenplace.net/images/math/c347663196aa1510668fbd932f8f62aa92e8b15d.svg" style="height: 15px;" type="image/svg+xml"&gt;d_v&lt;/object&gt; is a fairly trivial modification to
this code.&lt;/p&gt;
&lt;/div&gt;
&lt;div class="section" id="batched-self-attention"&gt;
&lt;h2&gt;Batched self-attention&lt;/h2&gt;
&lt;p&gt;In the real world, the input array is unlikely to be 2D because models are
trained on &lt;em&gt;batches&lt;/em&gt; of input sequences. To leverage the parallelism of modern
hardware, whole batches are typically processed in the same operation.&lt;/p&gt;
&lt;img alt="input array (B, N, D)" class="align-center" src="https://eli.thegreenplace.net/images/2025/bnd-array.png" /&gt;
&lt;p&gt;The batched version of scaled self-attention is very similar to the non-batched
one, due to the magic of Numpy matrix multiplication and broadcasts. Now the
input shape is (B, N, D), where B is the batch dimension. The &lt;tt class="docutils literal"&gt;W*&lt;/tt&gt; matrices
are still (D, HS); multiplying a (B, N, D) array by (D, HS) performs contraction
between the last axis of the first array and the first axis of the second array,
resulting in (B, N, HS). Here's the code, with the dimensions annotated
for each operation:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;&lt;span class="c1"&gt;# self_attention with inputs that have a batch dimension.&lt;/span&gt;
&lt;span class="c1"&gt;# x has shape (B, N, D)&lt;/span&gt;
&lt;span class="c1"&gt;# Each of W* has shape (D, D)&lt;/span&gt;
&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;self_attention_batched&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Wk&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Wq&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Wv&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;Wq&lt;/span&gt;  &lt;span class="c1"&gt;# (B, N, HS)&lt;/span&gt;
    &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;Wk&lt;/span&gt;  &lt;span class="c1"&gt;# (B, N, HS)&lt;/span&gt;
    &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;Wv&lt;/span&gt;  &lt;span class="c1"&gt;# (B, N, HS)&lt;/span&gt;

    &lt;span class="n"&gt;kq&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;swapaxes&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;  &lt;span class="c1"&gt;# (B, N, N)&lt;/span&gt;

    &lt;span class="n"&gt;att&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;softmax_lastdim&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;kq&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# (B, N, N)&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;att&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;  &lt;span class="c1"&gt;# (B, N, HS)&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;Note that the only difference between this and the non-batched version is the
line calculating &lt;tt class="docutils literal"&gt;kq&lt;/tt&gt;:&lt;/p&gt;
&lt;ul class="simple"&gt;
&lt;li&gt;Since &lt;tt class="docutils literal"&gt;k&lt;/tt&gt; is no longer 2D, the notion of &amp;quot;transpose&amp;quot; is ambiguous so we
explicitly ask to swap the last and the penultimate axis, leaving the first
axis (B) intact.&lt;/li&gt;
&lt;li&gt;When calculating the scaling factor we use &lt;tt class="docutils literal"&gt;&lt;span class="pre"&gt;k.shape[-1]&lt;/span&gt;&lt;/tt&gt; to select the
&lt;em&gt;last&lt;/em&gt; dimension of &lt;tt class="docutils literal"&gt;k&lt;/tt&gt;, instead of &lt;tt class="docutils literal"&gt;k.shape[1]&lt;/tt&gt; which only selects the
last dimension for 2D arrays.&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;In fact, this function could also calculate the non-batched version! From now
on, we'll assume that all inputs are batched, and all operations are implicitly
batched. I'm not going to be using the &amp;quot;batched&amp;quot; prefix or suffix on functions
any more.&lt;/p&gt;
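&lt;p&gt;Here is a short usage sketch of &lt;tt class="docutils literal"&gt;self_attention_batched&lt;/tt&gt; with
made-up shapes; it also demonstrates the non-batched case mentioned above:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;import numpy as np

rng = np.random.default_rng(0)
B, N, D, HS = 4, 6, 32, 8
x = rng.standard_normal((B, N, D))
Wq = rng.standard_normal((D, HS))
Wk = rng.standard_normal((D, HS))
Wv = rng.standard_normal((D, HS))

out = self_attention_batched(x, Wk, Wq, Wv)
print(out.shape)    # (4, 6, 8)

# The same function works on a single (N, D) sequence as well, since all
# the operations only touch the last two axes:
print(self_attention_batched(x[0], Wk, Wq, Wv).shape)    # (6, 8)
&lt;/pre&gt;&lt;/div&gt;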
&lt;p&gt;The basic underlying idea of the attention module is to shift around the
multi-dimensional representations of tokens in the sequence towards a better
representation of the entire sequence. The tokens &lt;em&gt;attend to&lt;/em&gt; each other.
Specifically, the matrix produced by the Softmax operation is called the
&lt;em&gt;attention matrix&lt;/em&gt;. It's (N, N); for each token it specifies how much
information from every other token in the sequence should be taken into account.
For example, a higher number in cell (R, C) means that the token at index R
in the sequence attends more strongly to the token at index C.&lt;/p&gt;
&lt;p&gt;Here's a nice example from the AIAYN paper, showing a word sequence and the
weights produced by two attention heads (purple and brown) for a given position
in the input sequence:&lt;/p&gt;
&lt;img alt="attention paper screenshot showing learned attention" class="align-center" src="https://eli.thegreenplace.net/images/2025/aiayn-paper-screenshot.png" /&gt;
&lt;p&gt;This shows how the model is learning to resolve what the word &amp;quot;its&amp;quot; refers to
in the sentence. Let's take just the purple head as an example. The index of
token &amp;quot;its&amp;quot; in the sequence is 8, and the index of &amp;quot;Law&amp;quot; is 1. In the attention
matrix for this head, the value at index (8, 1) will be very high (close to 1),
with other values in the same row much lower.&lt;/p&gt;
&lt;p&gt;While this intuitive explanation isn't critical to understand how attention is
implemented, it will become more important when we talk about &lt;em&gt;masked&lt;/em&gt;
self-attention later on.&lt;/p&gt;
&lt;/div&gt;
&lt;div class="section" id="multi-head-attention"&gt;
&lt;h2&gt;Multi-head attention&lt;/h2&gt;
&lt;p&gt;The attention mechanism we've seen so far has a single set of K, Q and V
matrices. This is called one &amp;quot;head&amp;quot; of attention. In today's models, there
are typically multiple heads. Each head does its attention job separately, and
in the end all these results are concatenated and fed through a linear layer.&lt;/p&gt;
&lt;p&gt;In what follows, NH is the number of heads and HS is the head size.
Typically, NH times HS would be D; for example, the AIAYN paper mentions
several configurations for D=512: NH=8 and HS=64, NH=32 and HS=16, and so on &lt;a class="footnote-reference" href="#footnote-2" id="footnote-reference-2"&gt;[2]&lt;/a&gt;.
However, the math works out even if this isn't the case, because the final linear
(&amp;quot;projection&amp;quot;) layer maps the output back to (N, D).&lt;/p&gt;
&lt;p&gt;Assuming the previous diagram showing a self-attention module is a single head
with input (N, D) and output (N, HS), this is how multiple heads are combined:&lt;/p&gt;
&lt;img alt="schematic of multiple attention heads" class="align-center" src="https://eli.thegreenplace.net/images/2025/multi-head-attention.png" /&gt;
&lt;p&gt;Each of the (NH) heads has its own parameter weights for Q, K and
V. Each attention head outputs a (N, HS) matrix; these are concatenated along
the last dimension to (N, NH * HS), which is passed through a final linear
projection.&lt;/p&gt;
&lt;p&gt;Here's a function implementing (batched) multi-head attention; for now, please
ignore the code inside &lt;tt class="docutils literal"&gt;do_mask&lt;/tt&gt; conditions:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;&lt;span class="c1"&gt;# x has shape (B, N, D)&lt;/span&gt;
&lt;span class="c1"&gt;# In what follows:&lt;/span&gt;
&lt;span class="c1"&gt;#   NH = number of heads&lt;/span&gt;
&lt;span class="c1"&gt;#   HS = head size&lt;/span&gt;
&lt;span class="c1"&gt;# Each W*s is a list of NH weight matrices of shape (D, HS).&lt;/span&gt;
&lt;span class="c1"&gt;# Wp is a weight matrix for the final linear projection, of shape (NH * HS, D)&lt;/span&gt;
&lt;span class="c1"&gt;# The result is (B, N, D)&lt;/span&gt;
&lt;span class="c1"&gt;# If do_mask is True, each attention head is masked from attending to future&lt;/span&gt;
&lt;span class="c1"&gt;# tokens.&lt;/span&gt;
&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;multihead_attention_list&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Wqs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Wks&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Wvs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Wp&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;do_mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="c1"&gt;# Check shapes.&lt;/span&gt;
    &lt;span class="n"&gt;NH&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Wks&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;HS&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Wks&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Wks&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Wqs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Wvs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;W&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;Wqs&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;Wks&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;Wvs&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;W&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="n"&gt;HS&lt;/span&gt;
    &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;Wp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="n"&gt;NH&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;HS&lt;/span&gt;

    &lt;span class="c1"&gt;# List of head outputs&lt;/span&gt;
    &lt;span class="n"&gt;head_outs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;

    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;do_mask&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="c1"&gt;# mask is a lower-triangular (N, N) matrix, with zeros above&lt;/span&gt;
        &lt;span class="c1"&gt;# the diagonal and ones on the diagonal and below.&lt;/span&gt;
        &lt;span class="n"&gt;N&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tril&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ones&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;)))&lt;/span&gt;

    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;Wk&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Wq&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Wv&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;zip&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Wks&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Wqs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Wvs&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="c1"&gt;# Calculate self attention for each head separately&lt;/span&gt;
        &lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;Wq&lt;/span&gt;  &lt;span class="c1"&gt;# (B, N, HS)&lt;/span&gt;
        &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;Wk&lt;/span&gt;  &lt;span class="c1"&gt;# (B, N, HS)&lt;/span&gt;
        &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;Wv&lt;/span&gt;  &lt;span class="c1"&gt;# (B, N, HS)&lt;/span&gt;

        &lt;span class="n"&gt;kq&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;swapaxes&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;  &lt;span class="c1"&gt;# (B, N, N)&lt;/span&gt;

        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;do_mask&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="c1"&gt;# Set the masked positions to -inf, to ensure that a token isn&amp;#39;t&lt;/span&gt;
            &lt;span class="c1"&gt;# affected by tokens that come after it in the softmax.&lt;/span&gt;
            &lt;span class="n"&gt;kq&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;where&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;inf&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;kq&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="n"&gt;att&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;softmax_lastdim&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;kq&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# (B, N, N)&lt;/span&gt;
        &lt;span class="n"&gt;head_outs&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;att&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# (B, N, HS)&lt;/span&gt;

    &lt;span class="c1"&gt;# Concatenate the head outputs and apply the final linear projection&lt;/span&gt;
    &lt;span class="n"&gt;all_heads&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;concatenate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;head_outs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# (B, N, NH * HS)&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;all_heads&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;Wp&lt;/span&gt;  &lt;span class="c1"&gt;# (B, N, D)&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;It is possible to vectorize this code even further; you'll sometimes see the
heads laid out in a separate (4th) dimension instead of being a list. See
the &lt;em&gt;Vectorizing across the heads dimension&lt;/em&gt; section.&lt;/p&gt;
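&lt;p&gt;Before moving on, here is a short usage sketch of
&lt;tt class="docutils literal"&gt;multihead_attention_list&lt;/tt&gt; with made-up sizes (chosen so that
NH * HS = D):&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;import numpy as np

rng = np.random.default_rng(0)
B, N, D, NH = 2, 6, 32, 8
HS = D // NH

x = rng.standard_normal((B, N, D))
Wqs = [rng.standard_normal((D, HS)) for _ in range(NH)]
Wks = [rng.standard_normal((D, HS)) for _ in range(NH)]
Wvs = [rng.standard_normal((D, HS)) for _ in range(NH)]
Wp = rng.standard_normal((NH * HS, D))

out = multihead_attention_list(x, Wqs, Wks, Wvs, Wp, do_mask=True)
print(out.shape)    # (2, 6, 32)
&lt;/pre&gt;&lt;/div&gt;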
&lt;/div&gt;
&lt;div class="section" id="masked-or-causal-self-attention"&gt;
&lt;h2&gt;Masked (or Causal) self-attention&lt;/h2&gt;
&lt;p&gt;Attention modules can be used in both &lt;em&gt;encoder&lt;/em&gt; and &lt;em&gt;decoder&lt;/em&gt; blocks. &lt;em&gt;Encoder&lt;/em&gt;
blocks are useful for things like language understanding or translation; for
these, it makes sense for each token to attend to all the other tokens in the
sequence.&lt;/p&gt;
&lt;p&gt;However, for generative models this presents a problem: if during training a
word attends to future words, the model will just &amp;quot;cheat&amp;quot; and not really learn
how to generate the next word from only past words. Generation is done in a &lt;em&gt;decoder&lt;/em&gt;
block, and for this we need to add masking to attention.&lt;/p&gt;
&lt;p&gt;Conceptually, masking is very simple. Consider the sentence:&lt;/p&gt;
&lt;blockquote&gt;
People like watching funny cat videos&lt;/blockquote&gt;
&lt;p&gt;When our attention code generates the &lt;tt class="docutils literal"&gt;att&lt;/tt&gt; matrix, it's a square (N, N)
matrix with attention weights from each token to each other token in the
sequence:&lt;/p&gt;
&lt;img alt="attention masking" class="align-center" src="https://eli.thegreenplace.net/images/2025/attention-masking.png" /&gt;
&lt;p&gt;What we want is for all the gray cells in this matrix to be zero, to ensure
that a token doesn't attend to future tokens. The blue cells in the matrix
add up to 1 in each row, after the softmax operation.&lt;/p&gt;
&lt;p&gt;Now take a look at the previous code sample and see what happens when
&lt;tt class="docutils literal"&gt;do_mask=True&lt;/tt&gt;:&lt;/p&gt;
&lt;ol class="arabic simple"&gt;
&lt;li&gt;First, a (N, N) lower-triangular array is prepared with zeros above the
diagonal and ones on the diagonal and below.&lt;/li&gt;
&lt;li&gt;Then, before we pass the scaled &lt;object class="valign-m4" data="https://eli.thegreenplace.net/images/math/fa80a943cc46f6c1154320719b40219b80c9e5e4.svg" style="height: 19px;" type="image/svg+xml"&gt;QK^T&lt;/object&gt; to softmax, we set its values
to &lt;object class="valign-0" data="https://eli.thegreenplace.net/images/math/18787d835dea1ca698e365c252f82b506cecfce7.svg" style="height: 8px;" type="image/svg+xml"&gt;-\infty&lt;/object&gt; wherever the mask matrix is 0. This ensures that the
softmax function will assign zeros to outputs at these indices, while still
producing the proper values in the rest of the row.&lt;/li&gt;
&lt;/ol&gt;
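&lt;p&gt;Here is a small standalone sketch of these two steps for N=4, reusing
&lt;tt class="docutils literal"&gt;softmax_lastdim&lt;/tt&gt; from above, with random numbers standing in for
the scaled attention scores:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;import numpy as np

N = 4
rng = np.random.default_rng(0)
scores = rng.standard_normal((N, N))            # stand-in for the scaled QK^T

mask = np.tril(np.ones((N, N)))                 # ones on and below the diagonal
masked = np.where(mask == 0, -np.inf, scores)   # -inf above the diagonal
att = softmax_lastdim(masked)

# Row i has non-zero weights only in columns 0..i, and each row sums to 1.
print(np.round(att, 2))
print(att.sum(axis=-1))    # [1. 1. 1. 1.]
&lt;/pre&gt;&lt;/div&gt;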
&lt;p&gt;Another name for masked self-attention is &lt;em&gt;causal&lt;/em&gt; self-attention. This is a
very good name that comes from &lt;a class="reference external" href="https://en.wikipedia.org/wiki/Causal_system"&gt;causal systems&lt;/a&gt;
in control theory.&lt;/p&gt;
&lt;/div&gt;
&lt;div class="section" id="intuition-what-attention-does"&gt;
&lt;h2&gt;Intuition - what attention does&lt;/h2&gt;
&lt;p&gt;What does the attention block try to accomplish? To think about it intuitively,
let's focus on a single token in the input (ignoring batch) - &lt;tt class="docutils literal"&gt;x[i]&lt;/tt&gt;. For
this token, the attention block produces an output token &lt;tt class="docutils literal"&gt;out[i]&lt;/tt&gt; that
blends &lt;tt class="docutils literal"&gt;x[i]&lt;/tt&gt;'s embedding (multi-dimensional dense vector representation)
with contextual information from all the tokens preceding it in
the sequence, i.e. &lt;tt class="docutils literal"&gt;&lt;span class="pre"&gt;x[:i]&lt;/span&gt;&lt;/tt&gt;.&lt;/p&gt;
&lt;p&gt;The way this is done is first calculating the &lt;em&gt;query&lt;/em&gt; vector &lt;tt class="docutils literal"&gt;q&lt;/tt&gt; for &lt;tt class="docutils literal"&gt;x[i]&lt;/tt&gt; (using
&lt;tt class="docutils literal"&gt;Wq&lt;/tt&gt;). This query can be thought of as &amp;quot;what attributes does this token care
about in its context tokens&amp;quot;.&lt;/p&gt;
&lt;p&gt;Then, for each of the context tokens (including &lt;tt class="docutils literal"&gt;x[i]&lt;/tt&gt; itself) we calculate:&lt;/p&gt;
&lt;ul class="simple"&gt;
&lt;li&gt;&lt;em&gt;Key&lt;/em&gt; (using &lt;tt class="docutils literal"&gt;Wk&lt;/tt&gt;): these are the attributes of the token that queries may
refer to.&lt;/li&gt;
&lt;li&gt;&lt;em&gt;Value&lt;/em&gt; (using &lt;tt class="docutils literal"&gt;Wv&lt;/tt&gt;): these are the associated values tokens carry.&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;When attention calculates &lt;tt class="docutils literal"&gt;q &amp;#64; K.T&lt;/tt&gt; for each token, the result is - for each
context token - the weights to use for mixing in the token's value. Then, when
this is multiplied by &lt;tt class="docutils literal"&gt;V&lt;/tt&gt;, the values are properly weighted.&lt;/p&gt;
&lt;p&gt;So this is a very general approach for the model to learn what kind of
information each token &amp;quot;cares&amp;quot; about in its context tokens, and how to blend
the token's embedding with those of the preceding context tokens, to properly
encode the context the token is encountered in.&lt;/p&gt;
&lt;p&gt;Our implementation, starting with the basic scaled self-attention, implements
this for all tokens in the input sequence simultaneously; hence, we don't just
take a single &lt;tt class="docutils literal"&gt;x[i]&lt;/tt&gt;, calculate its &lt;tt class="docutils literal"&gt;q&lt;/tt&gt; and then multiply that by &lt;tt class="docutils literal"&gt;K.T&lt;/tt&gt;.
Rather, we calculate &lt;tt class="docutils literal"&gt;Q&lt;/tt&gt; from all &lt;tt class="docutils literal"&gt;x&lt;/tt&gt;, and continue using matrix
multiplications to vectorize these calculations across the entire sequence.&lt;/p&gt;
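&lt;p&gt;To verify that the vectorized formulation really computes this per-token
blend, here is a small check (with made-up shapes) that recovers the output for
a single token from the basic, unmasked &lt;tt class="docutils literal"&gt;self_attention&lt;/tt&gt; function
shown earlier:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;import numpy as np

rng = np.random.default_rng(0)
N, D, HS = 6, 16, 8
x = rng.standard_normal((N, D))
Wq = rng.standard_normal((D, HS))
Wk = rng.standard_normal((D, HS))
Wv = rng.standard_normal((D, HS))

full = self_attention(x, Wk, Wq, Wv)     # (N, HS)

i = 3
q_i = x[i] @ Wq                          # (HS,) query for token i
scores = (x @ Wk) @ q_i / np.sqrt(HS)    # (N,) one score per context token
weights = np.exp(scores - scores.max())
weights /= weights.sum()                 # softmax over the context tokens
out_i = weights @ (x @ Wv)               # (HS,) weighted blend of the values
print(np.allclose(out_i, full[i]))       # True
&lt;/pre&gt;&lt;/div&gt;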
&lt;p&gt;It's important to keep in mind that this intuitive explanation suffers from
anthropomorphism. We try to explain what the model does intuitively, but in
reality this is only a very abstract approximation of what's happening (consider
that attention has multiple heads, and also that LLMs typically have dozens
of repeating transformer layers with self-attention blocks, applying the same
mechanism over and over again).&lt;/p&gt;
&lt;/div&gt;
&lt;div class="section" id="cross-attention"&gt;
&lt;h2&gt;Cross-attention&lt;/h2&gt;
&lt;p&gt;So far we've been working with self-attention blocks, where the &lt;em&gt;self&lt;/em&gt; suggests
that elements in the input sequence attend to other elements in the same input
sequence.&lt;/p&gt;
&lt;p&gt;Another variant of attention is &lt;em&gt;cross-attention&lt;/em&gt;, where elements of one
sequence attend to elements in another sequence. This variant exists in the
decoder block of the AIAYN paper. This is a single head of
cross-attention:&lt;/p&gt;
&lt;img alt="cross-attention with different Nq, Nv" class="align-center" src="https://eli.thegreenplace.net/images/2025/cross-attention-head.png" /&gt;
&lt;p&gt;Here we have two sequences with potentially different lengths: &lt;tt class="docutils literal"&gt;xq&lt;/tt&gt; and
&lt;tt class="docutils literal"&gt;xv&lt;/tt&gt;. &lt;tt class="docutils literal"&gt;xq&lt;/tt&gt; is used for the query part of attention, while &lt;tt class="docutils literal"&gt;xv&lt;/tt&gt; is used for
the key and value parts. The rest of the dimensions remain as before. The output
of such a block is shaped (Nq, HS).&lt;/p&gt;
&lt;p&gt;This is an implementation of multi-head cross-attention; it doesn't include
masking, since masking is not typically necessary in cross-attention - it's OK
for elements of &lt;tt class="docutils literal"&gt;xq&lt;/tt&gt; to attend to all elements of &lt;tt class="docutils literal"&gt;xv&lt;/tt&gt; &lt;a class="footnote-reference" href="#footnote-3" id="footnote-reference-3"&gt;[3]&lt;/a&gt;:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;&lt;span class="c1"&gt;# Cross attention between two input sequences that can have different lengths.&lt;/span&gt;
&lt;span class="c1"&gt;# xq has shape (B, Nq, D)&lt;/span&gt;
&lt;span class="c1"&gt;# xv has shape (B, Nv, D)&lt;/span&gt;
&lt;span class="c1"&gt;# In what follows:&lt;/span&gt;
&lt;span class="c1"&gt;#   NH = number of heads&lt;/span&gt;
&lt;span class="c1"&gt;#   HS = head size&lt;/span&gt;
&lt;span class="c1"&gt;# Each W*s is a list of NH weight matrices of shape (D, HS).&lt;/span&gt;
&lt;span class="c1"&gt;# Wp is a weight matrix for the final linear projection, of shape (NH * HS, D)&lt;/span&gt;
&lt;span class="c1"&gt;# The result is (B, Nq, D)&lt;/span&gt;
&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;multihead_cross_attention_list&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;xq&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;xv&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Wqs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Wks&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Wvs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Wp&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="c1"&gt;# Check shapes.&lt;/span&gt;
    &lt;span class="n"&gt;NH&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Wks&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;HS&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Wks&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Wks&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Wqs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Wvs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;W&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;Wqs&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;Wks&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;Wvs&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;W&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="n"&gt;HS&lt;/span&gt;
    &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;Wp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="n"&gt;NH&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;HS&lt;/span&gt;

    &lt;span class="c1"&gt;# List of head outputs&lt;/span&gt;
    &lt;span class="n"&gt;head_outs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;

    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;Wk&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Wq&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Wv&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;zip&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Wks&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Wqs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Wvs&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;xq&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;Wq&lt;/span&gt;  &lt;span class="c1"&gt;# (B, Nq, HS)&lt;/span&gt;
        &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;xv&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;Wk&lt;/span&gt;  &lt;span class="c1"&gt;# (B, Nv, HS)&lt;/span&gt;
        &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;xv&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;Wv&lt;/span&gt;  &lt;span class="c1"&gt;# (B, Nv, HS)&lt;/span&gt;

        &lt;span class="n"&gt;kq&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;swapaxes&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;  &lt;span class="c1"&gt;# (B, Nq, Nv)&lt;/span&gt;

        &lt;span class="n"&gt;att&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;softmax_lastdim&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;kq&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# (B, Nq, Nv)&lt;/span&gt;
        &lt;span class="n"&gt;head_outs&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;att&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# (B, Nq, HS)&lt;/span&gt;

    &lt;span class="c1"&gt;# Concatenate the head outputs and apply the final linear projection&lt;/span&gt;
    &lt;span class="n"&gt;all_heads&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;concatenate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;head_outs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# (B, Nq, NH * HS)&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;all_heads&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;Wp&lt;/span&gt;  &lt;span class="c1"&gt;# (B, Nq, D)&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;
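&lt;p&gt;A quick usage sketch with two sequences of different (made-up) lengths,
showing that the output keeps the length of the query sequence:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;import numpy as np

rng = np.random.default_rng(0)
B, Nq, Nv, D, NH = 2, 5, 9, 32, 8
HS = D // NH

xq = rng.standard_normal((B, Nq, D))
xv = rng.standard_normal((B, Nv, D))
Wqs = [rng.standard_normal((D, HS)) for _ in range(NH)]
Wks = [rng.standard_normal((D, HS)) for _ in range(NH)]
Wvs = [rng.standard_normal((D, HS)) for _ in range(NH)]
Wp = rng.standard_normal((NH * HS, D))

out = multihead_cross_attention_list(xq, xv, Wqs, Wks, Wvs, Wp)
print(out.shape)    # (2, 5, 32)
&lt;/pre&gt;&lt;/div&gt;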
&lt;/div&gt;
&lt;div class="section" id="vectorizing-across-the-heads-dimension"&gt;
&lt;h2&gt;Vectorizing across the heads dimension&lt;/h2&gt;
&lt;p&gt;The &lt;tt class="docutils literal"&gt;multihead_attention_list&lt;/tt&gt; implementation shown above uses lists of weight
matrices as input. While this makes the code clearer, it's not a particularly
friendly format for an optimized implementation - especially on accelerators
like GPUs and TPUs. We can vectorize it further by creating a new dimension for
attention heads.&lt;/p&gt;
&lt;p&gt;To understand the trick being used, consider a basic matmul of (8, 6) by
(6, 2):&lt;/p&gt;
&lt;img alt="basic matrix multiplication" class="align-center" src="https://eli.thegreenplace.net/images/2025/matmul-28.png" /&gt;
&lt;p&gt;Now suppose we want to multiply our LHS by &lt;em&gt;another&lt;/em&gt; (6, 2) matrix. We can do
it all in the same operation by concatenating the two RHS matrices along
columns:&lt;/p&gt;
&lt;img alt="concatenated basic matrix multiplication" class="align-center" src="https://eli.thegreenplace.net/images/2025/matmul-concat-two.png" /&gt;
&lt;p&gt;If the yellow RHS block in both diagrams is identical, the green block of the
result will be as well. And the violet block is just the matmul of the LHS by
the red block of the RHS. This stems from the semantics of matrix
multiplication, and is easy to verify on paper.&lt;/p&gt;
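&lt;p&gt;A quick Numpy check of this claim, with small random matrices:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((8, 6))
B1 = rng.standard_normal((6, 2))
B2 = rng.standard_normal((6, 2))

combined = A @ np.concatenate([B1, B2], axis=1)      # (8, 4)
separate = np.concatenate([A @ B1, A @ B2], axis=1)  # (8, 4)
print(np.allclose(combined, separate))               # True
&lt;/pre&gt;&lt;/div&gt;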
&lt;p&gt;Now back to our multi-head attention. Note that we multiply the input &lt;em&gt;x&lt;/em&gt; by
a whole list of weight matrices - in fact, by &lt;em&gt;three&lt;/em&gt; lists (one list for Q,
one for K, and another for V). We can use the same vectorization technique by
concatenating all these weight matrices into a single one. Assuming that
&lt;tt class="docutils literal"&gt;NH * HS = D&lt;/tt&gt;, the shape of the combined matrix is (D, 3 * D). Here's
the vectorized implementation:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;&lt;span class="c1"&gt;# x has shape (B, N, D)&lt;/span&gt;
&lt;span class="c1"&gt;# In what follows:&lt;/span&gt;
&lt;span class="c1"&gt;#   NH = number of heads&lt;/span&gt;
&lt;span class="c1"&gt;#   HS = head size&lt;/span&gt;
&lt;span class="c1"&gt;#   NH * HS = D&lt;/span&gt;
&lt;span class="c1"&gt;# W is expected to have shape (D, 3 * D), with all the weight matrices for&lt;/span&gt;
&lt;span class="c1"&gt;# Qs, Ks, and Vs concatenated along the last dimension, in this order.&lt;/span&gt;
&lt;span class="c1"&gt;# Wp is a weight matrix for the final linear projection, of shape (D, D).&lt;/span&gt;
&lt;span class="c1"&gt;# The result is (B, N, D).&lt;/span&gt;
&lt;span class="c1"&gt;# If do_mask is True, each attention head is masked from attending to future&lt;/span&gt;
&lt;span class="c1"&gt;# tokens.&lt;/span&gt;
&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;multihead_attention_vec&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;W&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;NH&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Wp&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;do_mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;B&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;D&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;
    &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;W&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;D&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;D&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;qkv&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W&lt;/span&gt;  &lt;span class="c1"&gt;# (B, N, 3 * D)&lt;/span&gt;
    &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;split&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;qkv&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# (B, N, D) each&lt;/span&gt;

    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;do_mask&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="c1"&gt;# mask is a lower-triangular (N, N) matrix, with zeros above&lt;/span&gt;
        &lt;span class="c1"&gt;# the diagonal and ones on the diagonal and below.&lt;/span&gt;
        &lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tril&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ones&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;)))&lt;/span&gt;

    &lt;span class="n"&gt;HS&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;D&lt;/span&gt; &lt;span class="o"&gt;//&lt;/span&gt; &lt;span class="n"&gt;NH&lt;/span&gt;
    &lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;B&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;NH&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;HS&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# (B, NH, N, HS)&lt;/span&gt;
    &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;B&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;NH&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;HS&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# (B, NH, N, HS)&lt;/span&gt;
    &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;B&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;NH&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;HS&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# (B, NH, N, HS)&lt;/span&gt;

    &lt;span class="n"&gt;kq&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;swapaxes&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;  &lt;span class="c1"&gt;# (B, NH, N, N)&lt;/span&gt;

    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;do_mask&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="c1"&gt;# Set the masked positions to -inf, to ensure that a token isn&amp;#39;t&lt;/span&gt;
        &lt;span class="c1"&gt;# affected by tokens that come after it in the softmax.&lt;/span&gt;
        &lt;span class="n"&gt;kq&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;where&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;inf&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;kq&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="n"&gt;att&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;softmax_lastdim&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;kq&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# (B, NH, N, N)&lt;/span&gt;
    &lt;span class="n"&gt;out&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;att&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;  &lt;span class="c1"&gt;# (B, NH, N, HS)&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;B&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;D&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;Wp&lt;/span&gt;  &lt;span class="c1"&gt;# (B, N, D)&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;This code computes Q, K and V in a single matmul, and then splits them
into separate arrays (note that on accelerators these splits and later
transposes may be very cheap or even free as they represent a different access
pattern into the same data).&lt;/p&gt;
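&lt;p&gt;To see why the fused projection works, here's a small standalone sketch (with
made-up dimensions, and assuming W is simply the per-projection weights Wq, Wk and
Wv concatenated along the last axis); splitting the fused result recovers the
three separate projections:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;import numpy as np

rng = np.random.default_rng(0)
B, N, D = 2, 6, 8

x = rng.standard_normal((B, N, D))
Wq = rng.standard_normal((D, D))
Wk = rng.standard_normal((D, D))
Wv = rng.standard_normal((D, D))

# Fused weight matrix (D, 3 * D): the column blocks are [Wq | Wk | Wv].
W = np.concatenate([Wq, Wk, Wv], axis=1)

# A single matmul followed by a split along the last axis...
q, k, v = np.split(x @ W, 3, axis=-1)

# ...is equivalent to three separate projections.
assert np.allclose(q, x @ Wq)
assert np.allclose(k, x @ Wk)
assert np.allclose(v, x @ Wv)
&lt;/pre&gt;&lt;/div&gt;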
&lt;p&gt;Each of Q, K and V is initially (B, N, D), so they are reshaped into a more
convenient layout: first splitting D into (NH, HS), and then
changing the order of dimensions to get (B, NH, N, HS). In this format, both
B and NH are considered batch dimensions that are fully parallelizable.
The &lt;object class="valign-m4" data="https://eli.thegreenplace.net/images/math/fa80a943cc46f6c1154320719b40219b80c9e5e4.svg" style="height: 19px;" type="image/svg+xml"&gt;QK^T&lt;/object&gt; computation can then proceed as before, and Numpy will
automatically perform the matmul over all the batch dimensions.&lt;/p&gt;
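&lt;p&gt;Here's a quick standalone sketch (again with made-up dimensions) of this
batching behavior: the &lt;tt class="docutils literal"&gt;@&lt;/tt&gt; operator performs the matmul on the
last two axes and treats all leading axes as batch dimensions:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;import numpy as np

rng = np.random.default_rng(0)
B, NH, N, HS = 2, 4, 6, 8

q = rng.standard_normal((B, NH, N, HS))
k = rng.standard_normal((B, NH, N, HS))

# The matmul runs on the last two axes, independently for every (B, NH) pair.
scores = q @ k.swapaxes(-1, -2)
assert scores.shape == (B, NH, N, N)

# The same result, computed one (batch, head) pair at a time.
for b in range(B):
    for h in range(NH):
        assert np.allclose(scores[b, h], q[b, h] @ k[b, h].T)
&lt;/pre&gt;&lt;/div&gt;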
&lt;p&gt;Sometimes you'll see an alternative notation for these matrix
multiplications in papers and code: Einstein summation, available in Numpy as
&lt;tt class="docutils literal"&gt;numpy.einsum&lt;/tt&gt;. For
example, in our last code sample the computation of &lt;tt class="docutils literal"&gt;kq&lt;/tt&gt; could also be
written as:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;&lt;span class="n"&gt;kq&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;bhqd,bhkd-&amp;gt;bhqk&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;See &lt;a class="reference external" href="https://eli.thegreenplace.net/2025/understanding-numpys-einsum/"&gt;this post for my detailed notes on this notation&lt;/a&gt;.&lt;/p&gt;
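&lt;p&gt;As a sanity check, here's a small standalone sketch (with made-up dimensions)
showing that the &lt;tt class="docutils literal"&gt;einsum&lt;/tt&gt; spelling matches the plain matmul
spelling, and that the attention-weighted sum over V can be expressed the same
way:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;import numpy as np

rng = np.random.default_rng(0)
B, NH, N, HS = 2, 4, 6, 8

q = rng.standard_normal((B, NH, N, HS))
k = rng.standard_normal((B, NH, N, HS))
v = rng.standard_normal((B, NH, N, HS))

# Scaled attention scores: einsum vs. explicit matmul.
kq_einsum = np.einsum(&amp;quot;bhqd,bhkd-&amp;gt;bhqk&amp;quot;, q, k) / np.sqrt(HS)
kq_matmul = q @ k.swapaxes(-1, -2) / np.sqrt(HS)
assert np.allclose(kq_einsum, kq_matmul)

# The attention-weighted sum over V, written both ways.
att = rng.standard_normal((B, NH, N, N))
assert np.allclose(np.einsum(&amp;quot;bhqk,bhkd-&amp;gt;bhqd&amp;quot;, att, v), att @ v)
&lt;/pre&gt;&lt;/div&gt;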
&lt;/div&gt;
&lt;div class="section" id="code"&gt;
&lt;h2&gt;Code&lt;/h2&gt;
&lt;p&gt;The full code for these samples, with tests, is available
&lt;a class="reference external" href="https://github.com/eliben/deep-learning-samples/tree/main/transformer-attention"&gt;in this repository&lt;/a&gt;.&lt;/p&gt;
&lt;hr class="docutils" /&gt;
&lt;table class="docutils footnote" frame="void" id="footnote-1" rules="none"&gt;
&lt;colgroup&gt;&lt;col class="label" /&gt;&lt;col /&gt;&lt;/colgroup&gt;
&lt;tbody valign="top"&gt;
&lt;tr&gt;&lt;td class="label"&gt;&lt;a class="fn-backref" href="#footnote-reference-1"&gt;[1]&lt;/a&gt;&lt;/td&gt;&lt;td&gt;In LLM papers, D is often called &lt;object class="valign-m3" data="https://eli.thegreenplace.net/images/math/8594e2a5169d08eec15a946ef8fadc74c00423cd.svg" style="height: 15px;" type="image/svg+xml"&gt;d_{model}&lt;/object&gt;.&lt;/td&gt;&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;table class="docutils footnote" frame="void" id="footnote-2" rules="none"&gt;
&lt;colgroup&gt;&lt;col class="label" /&gt;&lt;col /&gt;&lt;/colgroup&gt;
&lt;tbody valign="top"&gt;
&lt;tr&gt;&lt;td class="label"&gt;&lt;a class="fn-backref" href="#footnote-reference-2"&gt;[2]&lt;/a&gt;&lt;/td&gt;&lt;td&gt;In the GPT-3 paper, this is also true for all model variants. For example,
the largest 175B model has NH=96, HS=128 and D=12288.&lt;/td&gt;&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;table class="docutils footnote" frame="void" id="footnote-3" rules="none"&gt;
&lt;colgroup&gt;&lt;col class="label" /&gt;&lt;col /&gt;&lt;/colgroup&gt;
&lt;tbody valign="top"&gt;
&lt;tr&gt;&lt;td class="label"&gt;&lt;a class="fn-backref" href="#footnote-reference-3"&gt;[3]&lt;/a&gt;&lt;/td&gt;&lt;td&gt;It's also not as easy to define mathematically: how do we make a
non-square matrix triangular? And what does it mean when the lengths
of the two inputs are different?&lt;/td&gt;&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;/div&gt;
</content><category term="misc"></category><category term="Math"></category><category term="Machine Learning"></category><category term="Python"></category></entry></feed>