<h1>Diffie-Hellman Key Exchange</h1> <p>Eli Bendersky's website - Math; published 2019-10-21 by Eli Bendersky</p> <p>This post presents the Diffie-Hellman Key Exchange (DHKE) - an important part of today's practical cryptography. Whenever you're accessing an HTTPS website, it's very likely that your browser and the server negotiated a shared secret key using the DHKE under the hood.</p> <div class="section" id="mathematical-prerequisites"> <h2>Mathematical prerequisites</h2> <p>To understand the math behind DHKE, you should be familiar with basic <em>group theory</em>. 
A group is a set with a binary operation such that combining any two elements of the set produces another element of the set (closure); the operation is associative; the set has an identity element w.r.t. the operation; and each element has an inverse.</p> <p>The group we're most interested in for the sake of understanding Diffie-Hellman is <object class="valign-m7" data="https://eli.thegreenplace.net/images/math/a81e9fa41d32a290c7b1cbd52a99fc08c82f2f7d.svg" style="height: 20px;" type="image/svg+xml">\mathbb{Z}_{p}^{*}</object> - the positive integers smaller than <em>p</em> that are relatively prime to <em>p</em>, with the &quot;multiplication modulo <em>p</em>&quot; operation (another common notation for this group is <object class="valign-m5" data="https://eli.thegreenplace.net/images/math/b920f08f0a72a94bef16c469929c229b6c28e0dc.svg" style="height: 19px;" type="image/svg+xml">(\mathbb{Z}/p\mathbb{Z})^*</object>). This is a finite group. By definition, its cardinality is <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/3cc2f030099018851b9b711164207baa2252eda4.svg" style="height: 18px;" type="image/svg+xml">\phi(p)</object>, where <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/411e715f9ab9075b0a30b4117d209921f0bc2389.svg" style="height: 16px;" type="image/svg+xml">\phi</object> is <a class="reference external" href="https://en.wikipedia.org/wiki/Euler%27s_totient_function">Euler's totient function</a>.</p> <p>As an example, <object class="valign-m5" data="https://eli.thegreenplace.net/images/math/3a731309e0a13f0d6fbef6f970c497bdd912fce0.svg" style="height: 18px;" type="image/svg+xml">\mathbb{Z}_{9}^{*}=\{1,2,4,5,7,8\}</object>. The cardinality of this group is <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/881fe91696b6af703fd166b2d6d78340b7a8bd1b.svg" style="height: 18px;" type="image/svg+xml">\phi(9)=6</object>. 
We can multiply members of the group modulo 9 to get other elements of the group: <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/813ae271ff0b69949e40bc3ce03e564c1be0ba99.svg" style="height: 18px;" type="image/svg+xml">2*5\equiv 1\pmod 9</object>, <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/876cfeae5153e11446596a97418dc9a18a4d942c.svg" style="height: 18px;" type="image/svg+xml">8*4\equiv 5\pmod 9</object> etc.</p> <p>For a prime <em>p</em>, the group contains all the integers from 1 to <em>p-1</em> and its cardinality is <em>p-1</em>.</p> <div class="section" id="cyclic-groups"> <h3>Cyclic groups</h3> <p>Given a group <em>G</em> with the operator <object class="valign-m2" data="https://eli.thegreenplace.net/images/math/c8e2d1a0bf50a27d43ade30cfb048d99feb31ad1.svg" style="height: 13px;" type="image/svg+xml">\odot</object>, we define the <strong>order</strong> of an element <em>g</em> in the group - <em>ord(g)</em> - as the smallest positive integer <em>k</em> such that:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/341f8d88e088dd35dda6183228c753ad3fc18e76.svg" style="height: 44px;" type="image/svg+xml"> $g^k=\underbrace{g\odot g\odot\cdots\odot g}_{k \ times}=1$</object> <p>Where 1 is the identity element of <em>G</em>. Note that we use the exponent notation <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/babd5875bef8d7f6c6ca6b17c626406191ddffdc.svg" style="height: 19px;" type="image/svg+xml">g^k</object> for convenience, even though <object class="valign-m2" data="https://eli.thegreenplace.net/images/math/c8e2d1a0bf50a27d43ade30cfb048d99feb31ad1.svg" style="height: 13px;" type="image/svg+xml">\odot</object> is not necessarily a multiplication - this would work for any group. 
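</p> <p>In <em>Z<sub>n</sub><sup>*</sup></em> the order of an element can be found by direct iteration - apply the operation until we hit the identity. A small Go sketch (the <tt>ord</tt> helper is mine):</p>

```go
package main

import "fmt"

// ord computes the order of g in the multiplicative group modulo n, by
// repeatedly multiplying by g (mod n) until reaching the identity element 1.
// It assumes g is a member of the group (relatively prime to n).
func ord(g, n int) int {
	k, x := 1, g%n
	for x != 1 {
		x = (x * g) % n
		k++
	}
	return k
}

func main() {
	fmt.Println(ord(8, 9), ord(2, 9)) // 2 6
}
```

<p>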
For example, in the group <object class="valign-m5" data="https://eli.thegreenplace.net/images/math/010274226559b48ad047330d3647fbb26e0775ff.svg" style="height: 18px;" type="image/svg+xml">\mathbb{Z}_{9}^{*}</object> shown above, <em>ord(8)</em> is 2, and <em>ord(2)</em> is 6.</p> <p>A group <em>G</em> which contains an element <em>a</em> with the maximal order <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/12fcff68777f4c028a8f53980840e3cc807823f8.svg" style="height: 18px;" type="image/svg+xml">ord(a)=\left|G\right|</object> is called a <strong>cyclic group</strong>. Elements in a cyclic group that have maximal orders are called <em>generators</em> or <em>primitive elements</em>.</p> <p>These elements can generate all the other elements of the group by repeated application of the group operation. In other words, given a generator <em>g</em>, every <object class="valign-m1" data="https://eli.thegreenplace.net/images/math/4cbe37e25ff6e34b50a2ef01190bc26af1cc355e.svg" style="height: 13px;" type="image/svg+xml">a\in G</object> can be expressed as <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/babd5875bef8d7f6c6ca6b17c626406191ddffdc.svg" style="height: 19px;" type="image/svg+xml">g^k</object> for some <em>k</em>.</p> <p>For example, <object class="valign-m5" data="https://eli.thegreenplace.net/images/math/010274226559b48ad047330d3647fbb26e0775ff.svg" style="height: 18px;" type="image/svg+xml">\mathbb{Z}_{9}^{*}</object> is cyclic and its primitive elements are 2, 5 and 8.</p> <p>It can be shown that for a prime <em>p</em>, the group <object class="valign-m7" data="https://eli.thegreenplace.net/images/math/a81e9fa41d32a290c7b1cbd52a99fc08c82f2f7d.svg" style="height: 20px;" type="image/svg+xml">\mathbb{Z}_{p}^{*}</object> is always cyclic and has <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/f126122f73357701354be89d15c15b25b1b7138b.svg" style="height: 18px;" 
type="image/svg+xml">\phi(p-1)</object> primitive elements, though there's no easy way to find them - we just have to test them one by one. The proof of this theorem is quite technical, so I'll leave it for another time.</p> </div> </div> <div class="section" id="the-discrete-logarithm-problem"> <h2>The Discrete Logarithm Problem</h2> <p>The mathematical problem at the heart of the DHKE is the Discrete Logarithm Problem (DLP). In this discussion I'm going to focus on the DLP in the multiplicative group of integers modulo a prime - <object class="valign-m7" data="https://eli.thegreenplace.net/images/math/995f98c769b056e41fda04bafc1efc23710e5494.svg" style="height: 20px;" type="image/svg+xml">\mathbb{Z}^{*}_{p}</object>, and will mention the general DLP later on.</p> <p>Given a finite cyclic group <object class="valign-m7" data="https://eli.thegreenplace.net/images/math/995f98c769b056e41fda04bafc1efc23710e5494.svg" style="height: 20px;" type="image/svg+xml">\mathbb{Z}^{*}_{p}</object> with a prime <em>p</em>, a primitive element <object class="valign-m7" data="https://eli.thegreenplace.net/images/math/0b593ff95f677d3978c68181cc89fa85ea8a335f.svg" style="height: 20px;" type="image/svg+xml">g \in \mathbb{Z}^{*}_{p}</object> and another element <object class="valign-m7" data="https://eli.thegreenplace.net/images/math/f84aeb9463cdde34dc0820f13d960c491a9580b2.svg" style="height: 20px;" type="image/svg+xml">b \in \mathbb{Z}^{*}_{p}</object>, the DLP is to find an integer <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/8e8de33869b2c4646c4854a54f69ad6252ff2ce5.svg" style="height: 16px;" type="image/svg+xml">1\le x\le p-1</object> such that:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/4f0b4b76b5a484549776dcba485bc0eb22fd7df6.svg" style="height: 18px;" type="image/svg+xml"> $g^x\equiv b\pmod{p}$</object> <p>We've seen earlier that such an integer must exist because <em>g</em> is a primitive element of the 
group.</p> <p>The DLP is hard - no one knows how to solve it efficiently. This doesn't mean that an efficient solution doesn't exist - no one has proven that one can't exist. In this, DLP is similar to factoring, which is essential for the security of <a class="reference external" href="http://eli.thegreenplace.net/2019/rsa-theory-and-implementation/">RSA</a>.</p> </div> <div class="section" id="diffie-hellman-key-exchange-dhke"> <h2>Diffie-Hellman Key Exchange (DHKE)</h2> <p>The protocol starts with a <em>setup stage</em>, where the two parties agree on the parameters <em>p</em> and <em>g</em> to be used in the rest of the protocol. These parameters can be entirely public, and are specified in RFCs, such as <a class="reference external" href="https://tools.ietf.org/html/rfc7919">RFC 7919</a>.</p> <p>For the main key exchange protocol, let's assume that Alice and Bob want to compute a shared secret they could later use to send encrypted messages to one another. They know <em>p</em> and <em>g</em> already.</p> <p><strong>Stage 1</strong></p> <p>Alice does:</p> <ul class="simple"> <li>Choose a random <object class="valign-m5" data="https://eli.thegreenplace.net/images/math/880a56359b587e5fddfd05454524bbf400890014.svg" style="height: 18px;" type="image/svg+xml">b_{alice}\in\{{2,\dots,p-2}\}</object></li> <li>Compute <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/65855958b897b94d96926cccf6952ccb10fba4d5.svg" style="height: 19px;" type="image/svg+xml">B_{alice}\equiv g^{b_{alice}} \mod p</object></li> </ul> <p>Bob does:</p> <ul class="simple"> <li>Choose a random <object class="valign-m5" data="https://eli.thegreenplace.net/images/math/36b27e8cd01cfb0b07d3c99ce54b97243b4efe64.svg" style="height: 18px;" type="image/svg+xml">b_{bob}\in\{{2,\dots,p-2}\}</object></li> <li>Compute <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/f4abf9a31bcb3552e85963ca5e1a337b84021e4d.svg" style="height: 19px;" type="image/svg+xml">B_{bob}\equiv 
g^{b_{bob}} \mod p</object></li> </ul> <p>These <em>B</em>s are Alice's and Bob's public keys, while <em>b</em>s are their private keys. Note that due to the DLP, it's hard to compute <em>b</em> from <em>B</em>.</p> <p><strong>Stage 2</strong></p> <p>Alice sends <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/c9c2ee1b6300cddd985e5f9035d40a2f631e436d.svg" style="height: 15px;" type="image/svg+xml">B_{alice}</object> to Bob, while Bob sends <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/246a2e8c5c2d22c9dcbe80604458c2d1b5bcce67.svg" style="height: 15px;" type="image/svg+xml">B_{bob}</object> to Alice.</p> <p><strong>Stage 3</strong></p> <p>Now, Bob can compute <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/e3c06d2a162b0932932d387c54d467c973e20176.svg" style="height: 23px;" type="image/svg+xml">B_{alice}^{b_{bob}}\equiv (g^{b_{alice}})^{b_{bob}}\equiv g^{b_{alice}b_{bob}}\mod p</object>.</p> <p>Alice can compute <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/f87a735ee6b10b0ebbe5d09f09d388d3c0b854bf.svg" style="height: 23px;" type="image/svg+xml">B_{bob}^{b_{alice}}\equiv (g^{b_{bob}})^{b_{alice}}\equiv g^{b_{bob}b_{alice}}\mod p</object>.</p> <p>These are equal, and serve as a shared key between Alice and Bob. They can now use it to encrypt a strong symmetric cipher key (say, AES-256) and use that to communicate in complete privacy.</p> </div> <div class="section" id="authenticated-dhke"> <h2>Authenticated DHKE</h2> <p>The basic DHKE protocol, as described above, is easily vulnerable to a man-in-the-middle (MITM) attack. When Alice and Bob exchange their public keys in stage 2, nothing guarantees to Alice that the key she received comes from Bob. Eve could place herself between Alice and Bob and set up an exchange with each one of them separately, while making them believe they are talking to each other. 
Then she could read all the traffic, while Alice and Bob suspect nothing.</p> <p>The solution to this problem is to use <em>authenticated DHKE</em> instead. The core protocol remains the same, but when Alice and Bob exchange messages, these are signed with a strong signature algorithm. For example, Alice and Bob can use their RSA private keys to sign these messages. Then the MITM attack is impossible because Eve can't send a message to Bob pretending she's Alice without access to Alice's private RSA key.</p> </div> <div class="section" id="forward-secrecy"> <h2>Forward secrecy</h2> <p>In the <a class="reference external" href="http://eli.thegreenplace.net/2019/rsa-theory-and-implementation/">RSA post</a> we've seen how the RSA algorithm can be used to create a shared secret between two parties and thus for secret communication. RSA has a serious flaw when used like that, though. There's a lot of traffic using a single key, which may help an attacker break it. Once broken, this key can be used to read <em>all past</em> communications that used the same key.</p> <p>DHKE, on the other hand, has <em>forward secrecy</em>. A new DHKE shared secret is generated for every session. Breaking this key will expose the secrets of this session, but won't enable the attacker to read all past correspondence. Such keys are called <em>ephemeral</em>.</p> <p>You may ask - can't RSA be made ephemeral? Can't we use a &quot;master&quot; RSA key to authenticate the key exchange, and generate a fresh public/private key pair for each communication? Yes, that's absolutely possible, but DHKE is still preferred because it's more efficient. 
While generating an RSA key pair requires finding two large primes with certain characteristics, generating a new DHKE public key is simply choosing a random integer and computing a single modular exponentiation - this is much faster.</p> </div> <div class="section" id="choosing-safe-primes"> <h2>Choosing &quot;safe&quot; primes</h2> <p>We've seen before that the <em>p</em> and <em>g</em> parameters for DHKE are public. How are these chosen? Can we choose any <em>p</em> and <em>g</em> and have strong security?</p> <p>It turns out that the answer is no, due to some interesting math. Algorithms like <a class="reference external" href="https://en.wikipedia.org/wiki/Index_calculus_algorithm">Index Calculus</a> can be used to crack the DLP in sub-exponential time. They're so powerful that we need primes of 1024 bits to have 80-bit security (meaning the equivalent of brute-forcing an 80-bit symmetric key).</p> <p>When coupled with the <a class="reference external" href="https://en.wikipedia.org/wiki/Pohlig–Hellman_algorithm">Pohlig-Hellman</a> attack, we may get in trouble. This attack uses the <a class="reference external" href="http://eli.thegreenplace.net/2019/the-chinese-remainder-theorem/">CRT</a> to break the DLP in time proportional to the <em>factors</em> of <em>|G|</em> <a class="footnote-reference" href="#id3" id="id1"></a>. Note that when <em>p</em> is a prime, <em>p-1</em> is composite, so it will end up having some factors. Which factors? Hard to say, but we want to maximize them. The best way to maximize them is to pick primes of the form <em>2q+1</em>, where <em>q</em> is a prime. Then <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/4ba6046eb4a34a7a7821084f4fb4a78d0fefe875.svg" style="height: 18px;" type="image/svg+xml">|G|=p-1=2q</object>, and its factors are 2 and <em>q</em>. 
<em>g</em> is chosen such that it generates a sub-group of size <em>q</em>, which ensures we have a large prime <em>|G|</em>.</p> <p>Primes of the form <em>2q+1</em> are called <em>safe primes</em>.</p> <p>For example, <a class="reference external" href="https://tools.ietf.org/html/rfc7919">RFC 7919</a> recommends several parameters, presenting them thus:</p> <div class="highlight"><pre><span></span>The hexadecimal representation of p is: FFFFFFFF FFFFFFFF ADF85458 A2BB4A9A AFDC5620 273D3CF1 D8B9C583 CE2D3695 A9E13641 146433FB CC939DCE 249B3EF9 7D2FE363 630C75D8 F681B202 AEC4617A D3DF1ED5 D5FD6561 2433F51F 5F066ED0 85636555 3DED1AF3 B557135E 7F57C935 984F0C70 E0E68B77 E2A689DA F3EFE872 1DF158A1 36ADE735 30ACCA4F 483A797A BC0AB182 B324FB61 D108A94B B2C8E3FB B96ADAB7 60D7F468 1D4F42A3 DE394DF4 AE56EDE7 6372BB19 0B07A7C8 EE0A6D70 9E02FCE1 CDF7E2EC C03404CD 28342F61 9172FE9C E98583FF 8E4F1232 EEF28183 C3FE3B1B 4C6FAD73 3BB5FCBC 2EC22005 C58EF183 7D1683B2 C6F34A26 C1B2EFFA 886B4238 611FCFDC DE355B3B 6519035B BC34F4DE F99C0238 61B46FC9 D6E6C907 7AD91D26 91F7F7EE 598CB0FA C186D91C AEFE1309 85139270 B4130C93 BC437944 F4FD4452 E2D74DD3 64F2E21E 71F54BFF 5CAE82AB 9C9DF69E E86D2BC5 22363A0D ABC52197 9B0DEADA 1DBF9A42 D5C4484E 0ABCD06B FA53DDEF 3C1B20EE 3FD59D7C 25E41D2B 66C62E37 FFFFFFFF FFFFFFFF The generator is: g = 2 The group size is: q = (p-1)/2 </pre></div> <p>The parameters in this RFC are the only ones approved for the newest TLS standard - version 1.3, which also removes the support for custom groups.</p> <p>The safety of the primes used for DHKE is not a purely theoretical concern! Real attacks have been (and are probably still being) mounted against unsafe choices. 
See <a class="reference external" href="https://nvd.nist.gov/vuln/detail/CVE-2016-0701">CVE-2016-0701</a> for example, and the paper <a class="reference external" href="https://jhalderm.com/pub/papers/subgroup-ndss16.pdf">Measuring small subgroup attacks against Diffie-Hellman</a> for more technical details.</p> </div> <div class="section" id="a-word-on-elliptic-curves"> <h2>A word on elliptic curves</h2> <p>Elliptic curves have been all the rage in cryptography for the <a class="reference external" href="https://tools.ietf.org/html/rfc4492">past couple of decades</a>, and for a good reason. They provide similar security to the &quot;classical&quot; multiplicative modular groups with much smaller keys. If you're using TLS 1.3, the key exchange protocol will most likely be ECDHE - Elliptic Curve Diffie-Hellman Ephemeral.</p> <p>Explaining elliptic curves is a huge topic of its own, so I'll just briefly mention them w.r.t. the material presented in this post.</p> <p>The beauty of abstract algebra is that you can develop mathematics that will apply in the same way to very different groups. We've seen the DLP defined for multiplicative modular groups, but it can also be defined for different groups.</p> <p>Elliptic curves are sets of points (x, y) which fulfill certain polynomial equations <a class="footnote-reference" href="#id4" id="id2"></a>, and when set up properly these points can form cyclic groups under certain operations. A DLP can be defined for these groups, and it's as hard to solve as the classical DLP. Much of the math remains the same - generators, subgroups, and so on. DHKE looks the same as well - Alice and Bob both pick a random group member, and compute an &quot;exponent&quot; (repeated application of the group operation), sending it on the wire. 
They combine their exponents to get a shared secret key, while Eve cannot reconstruct their private exponents from the transmitted information because of the infeasibility of the DLP.</p> <p>Elliptic curve groups are great because - compared to classical multiplicative modular groups - they are less susceptible to sub-exponential attacks. Therefore, to gain ~128 bits of security (i.e. make attacks equivalent to brute-forcing 128-bit values) we can use a key of size 256 bits (as opposed to 3072 bits for classical DH). This makes cryptographic protocols much faster.</p> <hr class="docutils" /> <table class="docutils footnote" frame="void" id="id3" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id1"></a></td><td>Specifically, it's proportional to the size of the <em>subgroup</em> that the generator generates. The sizes of subgroups are related to the factors of <em>|G|</em>, per <a class="reference external" href="https://en.wikipedia.org/wiki/Lagrange%27s_theorem_(group_theory)">Lagrange's Theorem</a>.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id4" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id2"></a></td><td><object class="valign-m4" data="https://eli.thegreenplace.net/images/math/47b3fd1d9acbd8a3feb58d52f6c73c7dba87ffec.svg" style="height: 19px;" type="image/svg+xml">y^2=x^3+ax+b</object>, which should look familiar from analytic geometry in middle school.</td></tr> </tbody> </table> </div> <h1>RSA - theory and implementation</h1> <p>Eli Bendersky's website - Math; published 2019-09-03 by Eli Bendersky</p> <p>RSA has been a staple of public key cryptography for over 40 years, and is still being used today for some tasks in the newest TLS 1.3 standard. This post describes the theory behind RSA - the math that makes it work, as well as some practical considerations; it also presents a complete implementation of RSA key generation, encryption and decryption in Go.</p> <div class="section" id="the-rsa-algorithm"> <h2>The RSA algorithm</h2> <p>The beauty of the RSA algorithm is its simplicity. You don't need much more than some familiarity with elementary number theory to understand it, and the prerequisites can be grokked in a few hours.</p> <p>In this presentation <em>M</em> is the message we want to encrypt, resulting in the ciphertext <em>C</em>. Both <em>M</em> and <em>C</em> are large integers. Refer to the Practical Considerations section for representing arbitrary data with such integers.</p> <p>The RSA algorithm consists of three main phases: key generation, encryption and decryption.</p> <div class="section" id="key-generation"> <h3>Key generation</h3> <p>The first phase in using RSA is generating the public/private keys. This is accomplished in several steps.</p> <p><strong>Step 1</strong>: find two random, very large prime numbers <em>p</em> and <em>q</em> and calculate <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/2df650fff78b85cfb0330f2a2e65e4ac0e1e1ca1.svg" style="height: 12px;" type="image/svg+xml">n=pq</object>. How large should these primes be? The current recommendation is for <em>n</em> to be at least 2048 bits, or over 600 decimal digits. 
We'll assume that the message <em>M</em> - represented as a number - is smaller than <em>n</em> (see Practical Considerations for details on what to do if it's not).</p> <p><strong>Step 2</strong>: select a small odd integer <em>e</em> that is relatively prime to <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/1c7f9bc7f04407dd7fee51ec2ec4df99f20355ee.svg" style="height: 18px;" type="image/svg+xml">\phi(n)</object>, which is <a class="reference external" href="https://en.wikipedia.org/wiki/Euler%27s_totient_function">Euler's totient function</a>. <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/1c7f9bc7f04407dd7fee51ec2ec4df99f20355ee.svg" style="height: 18px;" type="image/svg+xml">\phi(n)</object> is calculated directly from Euler's formula (its proof is on Wikipedia):</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/660f0ef1ba862cad10df79d9274e30ed265331c0.svg" style="height: 51px;" type="image/svg+xml"> $\phi(n) =n \prod_{p\mid n} \left(1-\frac{1}{p}\right)$</object> <p>For <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/2df650fff78b85cfb0330f2a2e65e4ac0e1e1ca1.svg" style="height: 12px;" type="image/svg+xml">n=pq</object> where <em>p</em> and <em>q</em> are primes, we get</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/c4e921f87628b962ed3f77e50dfd51d92a924041.svg" style="height: 40px;" type="image/svg+xml"> $\phi(n)=n\frac{p-1}{p}\frac{q-1}{q}=(p-1)(q-1)$</object> <p>In practice, it's recommended to pick <em>e</em> as one of a set of known prime values, most notably <a class="reference external" href="https://tools.ietf.org/html/rfc2313">65537</a>. 
Picking this known number does not diminish the security of RSA, and has some advantages such as efficiency <a class="footnote-reference" href="#id7" id="id2"></a>.</p> <p><strong>Step 3</strong>: compute <em>d</em> as the multiplicative inverse of <em>e</em> modulo <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/1c7f9bc7f04407dd7fee51ec2ec4df99f20355ee.svg" style="height: 18px;" type="image/svg+xml">\phi(n)</object>. Lemma 3 in <a class="reference external" href="http://eli.thegreenplace.net/2019/the-chinese-remainder-theorem/">this post</a> guarantees that <em>d</em> exists and is unique (and also explains what a modular multiplicative inverse is).</p> <p>At this point we have all we need for the public/private keys. The public key is the pair <object class="valign-m5" data="https://eli.thegreenplace.net/images/math/e97a2ea99cfffbb197c3a2ea0c0e8d6962422e84.svg" style="height: 18px;" type="image/svg+xml">[e,n]</object> and the private key is the pair <object class="valign-m5" data="https://eli.thegreenplace.net/images/math/30c8e363b6a1070055dd59a89f457dd42dbad6a5.svg" style="height: 18px;" type="image/svg+xml">[d,n]</object>. 
In practice, when doing decryption we have access to <em>n</em> already (from the public key), so <em>d</em> is really the only unknown.</p> </div> <div class="section" id="encryption-and-decryption"> <h3>Encryption and decryption</h3> <p>Encryption and decryption are both accomplished with the same <a class="reference external" href="http://eli.thegreenplace.net/2009/03/28/efficient-modular-exponentiation-algorithms">modular exponentiation</a> formula, substituting different values for <em>x</em> and <em>y</em>:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/7ba9e4575b2f901ac6ab1301c9260a0ebb8c4ddb.svg" style="height: 18px;" type="image/svg+xml"> $f(x)=x^y\pmod{n}$</object> <p>For encryption, the input is <em>M</em> and the exponent is <em>e</em>:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/16100b92251780f65ec193d9e8f0fd7b3df7f55e.svg" style="height: 18px;" type="image/svg+xml"> $Enc(M)=M^e\pmod{n}$</object> <p>For decryption, the input is the ciphertext <em>C</em> and the exponent is <em>d</em>:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/e501f753509274a8d9c1792563b70c7afb04b7cb.svg" style="height: 21px;" type="image/svg+xml"> $Dec(C)=C^d\pmod{n}$</object> </div> </div> <div class="section" id="why-does-it-work"> <h2>Why does it work?</h2> <p>Given <em>M</em>, we encrypt it by raising to the power of <em>e</em> modulo <em>n</em>. Apparently, this process is reversible by raising the result to the power of <em>d</em> modulo <em>n</em>, getting <em>M</em> back. 
Why does this work?</p> <p><strong>Proof</strong>:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/b1865666d04fc18858aee6e6bb0e79b861822cc8.svg" style="height: 21px;" type="image/svg+xml"> $Dec(Enc(M))=M^{ed}\pmod{n}$</object> <p>Recall that <em>e</em> and <em>d</em> are multiplicative inverses modulo <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/1c7f9bc7f04407dd7fee51ec2ec4df99f20355ee.svg" style="height: 18px;" type="image/svg+xml">\phi(n)</object>. That is, <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/ff52bee7e9ab7e6ba6c4eaec88d621a058253f8b.svg" style="height: 18px;" type="image/svg+xml">ed\equiv 1\pmod{\phi(n)}</object>. This means that for some integer <em>k</em> we have <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/d9312685b11b6605d92b1cb3f528e78bfdae9ce0.svg" style="height: 18px;" type="image/svg+xml">ed=1+k\phi(n)</object> or <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/068d54df52635d19705da7af64959f05f5415dc0.svg" style="height: 18px;" type="image/svg+xml">ed=1+k(p-1)(q-1)</object>.</p> <p>Let's see what <object class="valign-0" data="https://eli.thegreenplace.net/images/math/0851d104a1204f3680dc479111e1c56b15d50924.svg" style="height: 15px;" type="image/svg+xml">M^{ed}</object> is modulo <em>p</em>. 
Substituting in the formula for <em>ed</em> we get:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/2a143c253f8d87a633a3d784919995f4849e2820.svg" style="height: 23px;" type="image/svg+xml"> $M^{ed}\equiv M(M^{p-1})^{k(q-1)}\pmod{p}$</object> <p>Now we can use <a class="reference external" href="https://en.wikipedia.org/wiki/Fermat%27s_little_theorem">Fermat's little theorem</a>, which states that if <em>M</em> is not divisible by <em>p</em>, we have <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/e59079c61e78d1fa10e39d2394416b925e961e50.svg" style="height: 19px;" type="image/svg+xml">M^{p-1}\equiv 1\pmod{p}</object>. This theorem is a special case of Euler's theorem, the proof of which <a class="reference external" href="http://eli.thegreenplace.net/2009/08/01/a-group-theoretic-proof-of-eulers-theorem">I wrote about here</a>.</p> <p>So we can substitute 1 for <object class="valign-0" data="https://eli.thegreenplace.net/images/math/ce56881e232caecbb33c9e0c42f73da4568bc43e.svg" style="height: 15px;" type="image/svg+xml">M^{p-1}</object> in the latest equation, and raising 1 to any power is still 1:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/992600d7cffd6118684c9fb7bd2884eddbd28c1b.svg" style="height: 21px;" type="image/svg+xml"> $M^{ed}\equiv M\pmod{p}$</object> <p>Note that Fermat's little theorem requires that <em>M</em> is not divisible by <em>p</em>. 
We can safely assume that, because if <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/edc02af17cdf5fefdcee2d00c213bfc9deed163b.svg" style="height: 18px;" type="image/svg+xml">M\equiv 0\pmod{p}</object>, then trivially <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/e3faa78a6c339536793bb1521021b33a8d0e7a01.svg" style="height: 19px;" type="image/svg+xml">M^{ed}\equiv 0\pmod{p}</object> and again <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/4bdc1b0fbf669d81ffa7e4f726380a7090fda112.svg" style="height: 19px;" type="image/svg+xml">M^{ed}\equiv M\pmod{p}</object>.</p> <p>We can similarly show that:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/c924efc575ead500eb25859e78fcd2aa4b166166.svg" style="height: 21px;" type="image/svg+xml"> $M^{ed}\equiv M\pmod{q}$</object> <p>So we have <object class="valign-0" data="https://eli.thegreenplace.net/images/math/01e9660f252f2e39af7563cd3464c24f770bc7db.svg" style="height: 15px;" type="image/svg+xml">M^{ed}\equiv M</object> for the prime factors of <em>n</em>. 
Using a <a class="reference external" href="http://eli.thegreenplace.net/2019/the-chinese-remainder-theorem/">corollary to the Chinese Remainder Theorem</a>, they are then equivalent modulo <em>n</em> itself:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/59989cc9b589764c6339babf180807ad78c02721.svg" style="height: 21px;" type="image/svg+xml"> $M^{ed}\equiv M\pmod{n}$</object> <p>Since we've defined <em>M</em> to be smaller than <em>n</em>, we've shown that <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/933aa016970d0adaae6a5832eafe9f4f73750317.svg" style="height: 18px;" type="image/svg+xml">Dec(Enc(M))=M</object> ∎</p> </div> <div class="section" id="why-is-it-secure"> <h2>Why is it secure?</h2> <p>Without the private key in hand, attackers only have the result of <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/00149eb8468e1ff6a5afe1ac4edc10a3426e6a18.svg" style="height: 18px;" type="image/svg+xml">M^e\pmod {n}</object>, as well as <em>n</em> and <em>e</em> (as they're part of the public key). Could they infer <em>M</em> from these numbers?</p> <p>There is no <em>known</em> general way of doing this without factoring <em>n</em> (see the <a class="reference external" href="http://people.csail.mit.edu/rivest/Rsapaper.pdf">original RSA paper</a>, section IX), and factoring is known to be a difficult problem. 
Specifically, here we assume that <em>M</em> and <em>e</em> are sufficiently large that <object class="valign-0" data="https://eli.thegreenplace.net/images/math/08c3c067bdffe6aa41c60dada94a96fa79a030b9.svg" style="height: 12px;" type="image/svg+xml">M^e&gt;n</object> (otherwise decrypting would be trivial).</p> <p>If factoring were easy, we could factor <em>n</em> into <em>p</em> and <em>q</em>, then compute <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/1c7f9bc7f04407dd7fee51ec2ec4df99f20355ee.svg" style="height: 18px;" type="image/svg+xml">\phi(n)</object> and finally find <em>d</em> from <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/ff52bee7e9ab7e6ba6c4eaec88d621a058253f8b.svg" style="height: 18px;" type="image/svg+xml">ed\equiv 1\pmod{\phi(n)}</object> using the extended Euclidean algorithm.</p> </div> <div class="section" id="practical-considerations"> <h2>Practical considerations</h2> <p>The algorithm described so far is sometimes called <em>textbook RSA</em> (or <em>schoolbook RSA</em>). That's because it deals entirely in numbers, ignoring all kinds of practical matters. In fact, textbook RSA is susceptible to <a class="reference external" href="https://crypto.stackexchange.com/questions/20085/which-attacks-are-possible-against-raw-textbook-rsa">several clever attacks</a> and has to be enhanced with random padding schemes for practical use.</p> <p>A simple padding scheme called PKCS #1 v1.5 has been used for many years and is defined in <a class="reference external" href="https://tools.ietf.org/html/rfc2313">RFC 2313</a>. These days more advanced schemes like <a class="reference external" href="https://tools.ietf.org/html/rfc2437">OAEP</a> are recommended instead, but PKCS #1 v1.5 is very easy to explain and therefore I'll use it for didactic purposes.</p> <p>Suppose we have some binary data <em>D</em> to encrypt.
The approach works for data of any size, but we will focus on just encrypting small pieces of data. In practice this is sufficient because RSA is commonly used only to encrypt a symmetric encryption key, which is much smaller than the RSA key size <a class="footnote-reference" href="#id8" id="id3"></a>. The scheme can work well enough for arbitrary-sized messages though - we'll just split the message into multiple blocks with some pre-determined block size.</p> <p>From <em>D</em> we create a block for encryption - the block has the same length as our RSA key:</p> <img alt="PKCS #1 v1.5 encryption padding scheme" class="align-center" src="https://eli.thegreenplace.net/images/2019/pkcs-15-rsa.png" /> <p>Here <em>PS</em> is the padding, which should occupy all the bytes not taken by the header and <em>D</em> in the block, and should be at least 8 bytes long (if it's shorter, the data may be broken into two separate blocks). It's a sequence of random non-zero bytes generated separately for each encryption. Once we have this full block of data, we convert it to a number, treating the bytes as a big-endian encoding <a class="footnote-reference" href="#id9" id="id4"></a>. We end up with a large number <em>x</em>, on which we then perform the RSA encryption step with <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/76a0d009913eb990fab8299e3574b743b0bed303.svg" style="height: 18px;" type="image/svg+xml">Enc(x)=x^e\pmod{n}</object>. The result is then encoded in binary and sent over the wire.</p> <p>Decryption is done in reverse.
We turn the received byte stream into a number, perform <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/67a657cde6a7d17787a986e10bf64e55a83c65ae.svg" style="height: 19px;" type="image/svg+xml">Dec(C)=C^d\pmod{n}</object>, then strip off the padding (note that the padding has no 0 bytes and is terminated with a 0, so this is easy) and get our original message back.</p> <p>The random padding here makes attacks on textbook RSA impractical, but the scheme as a whole may still be vulnerable to <a class="reference external" href="https://crypto.stackexchange.com/questions/12688/can-you-explain-bleichenbachers-cca-attack-on-pkcs1-v1-5">more sophisticated attacks</a> in some cases. Therefore, more modern schemes like OAEP should be used in practice.</p> </div> <div class="section" id="implementing-rsa-in-go"> <h2>Implementing RSA in Go</h2> <p>I've implemented a simple variant of RSA encryption and decryption as described in this post, in Go. Go makes it particularly easy to implement cryptographic algorithms because of its great support for arbitrary-precision integers with the stdlib <tt class="docutils literal">big</tt> package. Not only does this package support the basics of manipulating numbers, it also supports several primitives specifically for cryptography - for example the <tt class="docutils literal">Exp</tt> method supports efficient modular exponentiation, and the <tt class="docutils literal">ModInverse</tt> method supports finding modular multiplicative inverses. In addition, the <tt class="docutils literal">crypto/rand</tt> package contains randomness primitives specifically designed for cryptographic uses.</p> <p>Go has a production-grade crypto implementation in the standard library. RSA is in <tt class="docutils literal">crypto/rsa</tt>, so for anything real <em>please</em> use that <a class="footnote-reference" href="#id10" id="id5"></a>.
The code shown and linked here is just for educational purposes.</p> <p>The full code, with some tests, is <a class="reference external" href="https://github.com/eliben/code-for-blog/tree/master/2019/rsa">available on GitHub</a>. We'll start by defining the types to hold public and private keys:</p> <div class="highlight"><pre><span></span><span class="kd">type</span> <span class="nx">PublicKey</span> <span class="kd">struct</span> <span class="p">{</span> <span class="nx">N</span> <span class="o">*</span><span class="nx">big</span><span class="p">.</span><span class="nx">Int</span> <span class="nx">E</span> <span class="o">*</span><span class="nx">big</span><span class="p">.</span><span class="nx">Int</span> <span class="p">}</span> <span class="kd">type</span> <span class="nx">PrivateKey</span> <span class="kd">struct</span> <span class="p">{</span> <span class="nx">N</span> <span class="o">*</span><span class="nx">big</span><span class="p">.</span><span class="nx">Int</span> <span class="nx">D</span> <span class="o">*</span><span class="nx">big</span><span class="p">.</span><span class="nx">Int</span> <span class="p">}</span> </pre></div> <p>The code also contains a <tt class="docutils literal">GenerateKeys</tt> function that will randomly generate these keys with an appropriate bit length. 
Given a public key, textbook encryption is simply:</p> <div class="highlight"><pre><span></span><span class="kd">func</span> <span class="nx">encrypt</span><span class="p">(</span><span class="nx">pub</span> <span class="o">*</span><span class="nx">PublicKey</span><span class="p">,</span> <span class="nx">m</span> <span class="o">*</span><span class="nx">big</span><span class="p">.</span><span class="nx">Int</span><span class="p">)</span> <span class="o">*</span><span class="nx">big</span><span class="p">.</span><span class="nx">Int</span> <span class="p">{</span> <span class="nx">c</span> <span class="o">:=</span> <span class="nb">new</span><span class="p">(</span><span class="nx">big</span><span class="p">.</span><span class="nx">Int</span><span class="p">)</span> <span class="nx">c</span><span class="p">.</span><span class="nx">Exp</span><span class="p">(</span><span class="nx">m</span><span class="p">,</span> <span class="nx">pub</span><span class="p">.</span><span class="nx">E</span><span class="p">,</span> <span class="nx">pub</span><span class="p">.</span><span class="nx">N</span><span class="p">)</span> <span class="k">return</span> <span class="nx">c</span> <span class="p">}</span> </pre></div> <p>And decryption is:</p> <div class="highlight"><pre><span></span><span class="kd">func</span> <span class="nx">decrypt</span><span class="p">(</span><span class="nx">priv</span> <span class="o">*</span><span class="nx">PrivateKey</span><span class="p">,</span> <span class="nx">c</span> <span class="o">*</span><span class="nx">big</span><span class="p">.</span><span class="nx">Int</span><span class="p">)</span> <span class="o">*</span><span class="nx">big</span><span class="p">.</span><span class="nx">Int</span> <span class="p">{</span> <span class="nx">m</span> <span class="o">:=</span> <span class="nb">new</span><span class="p">(</span><span class="nx">big</span><span class="p">.</span><span class="nx">Int</span><span class="p">)</span> <span 
class="nx">m</span><span class="p">.</span><span class="nx">Exp</span><span class="p">(</span><span class="nx">c</span><span class="p">,</span> <span class="nx">priv</span><span class="p">.</span><span class="nx">D</span><span class="p">,</span> <span class="nx">priv</span><span class="p">.</span><span class="nx">N</span><span class="p">)</span> <span class="k">return</span> <span class="nx">m</span> <span class="p">}</span> </pre></div> <p>You'll notice that the bodies of these two functions are pretty much the same, except for which exponent they use. Indeed, they are just typed wrappers around the <tt class="docutils literal">Exp</tt> method.</p> <p>Finally, here's the full PKCS #1 v1.5 encryption procedure, as described above:</p> <div class="highlight"><pre><span></span><span class="c1">// EncryptRSA encrypts the message m using public key pub and returns the</span> <span class="c1">// encrypted bytes. The length of m must be &lt;= size_in_bytes(pub.N) - 11,</span> <span class="c1">// otherwise an error is returned. 
The encryption block format is based on</span> <span class="c1">// PKCS #1 v1.5 (RFC 2313).</span> <span class="kd">func</span> <span class="nx">EncryptRSA</span><span class="p">(</span><span class="nx">pub</span> <span class="o">*</span><span class="nx">PublicKey</span><span class="p">,</span> <span class="nx">m</span> <span class="p">[]</span><span class="kt">byte</span><span class="p">)</span> <span class="p">([]</span><span class="kt">byte</span><span class="p">,</span> <span class="kt">error</span><span class="p">)</span> <span class="p">{</span> <span class="c1">// Compute length of key in bytes, rounding up.</span> <span class="nx">keyLen</span> <span class="o">:=</span> <span class="p">(</span><span class="nx">pub</span><span class="p">.</span><span class="nx">N</span><span class="p">.</span><span class="nx">BitLen</span><span class="p">()</span> <span class="o">+</span> <span class="mi">7</span><span class="p">)</span> <span class="o">/</span> <span class="mi">8</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="nx">m</span><span class="p">)</span> <span class="p">&gt;</span> <span class="nx">keyLen</span><span class="o">-</span><span class="mi">11</span> <span class="p">{</span> <span class="k">return</span> <span class="kc">nil</span><span class="p">,</span> <span class="nx">fmt</span><span class="p">.</span><span class="nx">Errorf</span><span class="p">(</span><span class="s">&quot;len(m)=%v, too long&quot;</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="nx">m</span><span class="p">))</span> <span class="p">}</span> <span class="c1">// Following RFC 2313, using block type 02 as recommended for encryption:</span> <span class="c1">// EB = 00 || 02 || PS || 00 || D</span> <span class="nx">psLen</span> <span class="o">:=</span> <span class="nx">keyLen</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="nx">m</span><span 
class="p">)</span> <span class="o">-</span> <span class="mi">3</span> <span class="nx">eb</span> <span class="o">:=</span> <span class="nb">make</span><span class="p">([]</span><span class="kt">byte</span><span class="p">,</span> <span class="nx">keyLen</span><span class="p">)</span> <span class="nx">eb</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="p">=</span> <span class="mh">0x00</span> <span class="nx">eb</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="p">=</span> <span class="mh">0x02</span> <span class="c1">// Fill PS with random non-zero bytes.</span> <span class="k">for</span> <span class="nx">i</span> <span class="o">:=</span> <span class="mi">2</span><span class="p">;</span> <span class="nx">i</span> <span class="p">&lt;</span> <span class="mi">2</span><span class="o">+</span><span class="nx">psLen</span><span class="p">;</span> <span class="p">{</span> <span class="nx">_</span><span class="p">,</span> <span class="nx">err</span> <span class="o">:=</span> <span class="nx">rand</span><span class="p">.</span><span class="nx">Read</span><span class="p">(</span><span class="nx">eb</span><span class="p">[</span><span class="nx">i</span> <span class="p">:</span> <span class="nx">i</span><span class="o">+</span><span class="mi">1</span><span class="p">])</span> <span class="k">if</span> <span class="nx">err</span> <span class="o">!=</span> <span class="kc">nil</span> <span class="p">{</span> <span class="k">return</span> <span class="kc">nil</span><span class="p">,</span> <span class="nx">err</span> <span class="p">}</span> <span class="k">if</span> <span class="nx">eb</span><span class="p">[</span><span class="nx">i</span><span class="p">]</span> <span class="o">!=</span> <span class="mh">0x00</span> <span class="p">{</span> <span class="nx">i</span><span class="o">++</span> <span class="p">}</span> <span class="p">}</span> <span class="nx">eb</span><span 
class="p">[</span><span class="mi">2</span><span class="o">+</span><span class="nx">psLen</span><span class="p">]</span> <span class="p">=</span> <span class="mh">0x00</span> <span class="c1">// Copy the message m into the rest of the encryption block.</span> <span class="nb">copy</span><span class="p">(</span><span class="nx">eb</span><span class="p">[</span><span class="mi">3</span><span class="o">+</span><span class="nx">psLen</span><span class="p">:],</span> <span class="nx">m</span><span class="p">)</span> <span class="c1">// Now the encryption block is complete; we take it as a keyLen-byte big.Int and</span> <span class="c1">// RSA-encrypt it with the public key.</span> <span class="nx">mnum</span> <span class="o">:=</span> <span class="nb">new</span><span class="p">(</span><span class="nx">big</span><span class="p">.</span><span class="nx">Int</span><span class="p">).</span><span class="nx">SetBytes</span><span class="p">(</span><span class="nx">eb</span><span class="p">)</span> <span class="nx">c</span> <span class="o">:=</span> <span class="nx">encrypt</span><span class="p">(</span><span class="nx">pub</span><span class="p">,</span> <span class="nx">mnum</span><span class="p">)</span> <span class="c1">// The result is a big.Int, which we want to convert to a byte slice of</span> <span class="c1">// length keyLen.
It&#39;s highly likely that the size of c in bytes is keyLen,</span> <span class="c1">// but in rare cases we may need to pad it on the left with zeros (this only</span> <span class="c1">// happens if the most significant byte of c is zero, meaning that it&#39;s more</span> <span class="c1">// than 256 times smaller than the modulus).</span> <span class="nx">padLen</span> <span class="o">:=</span> <span class="nx">keyLen</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="nx">c</span><span class="p">.</span><span class="nx">Bytes</span><span class="p">())</span> <span class="k">for</span> <span class="nx">i</span> <span class="o">:=</span> <span class="mi">0</span><span class="p">;</span> <span class="nx">i</span> <span class="p">&lt;</span> <span class="nx">padLen</span><span class="p">;</span> <span class="nx">i</span><span class="o">++</span> <span class="p">{</span> <span class="nx">eb</span><span class="p">[</span><span class="nx">i</span><span class="p">]</span> <span class="p">=</span> <span class="mh">0x00</span> <span class="p">}</span> <span class="nb">copy</span><span class="p">(</span><span class="nx">eb</span><span class="p">[</span><span class="nx">padLen</span><span class="p">:],</span> <span class="nx">c</span><span class="p">.</span><span class="nx">Bytes</span><span class="p">())</span> <span class="k">return</span> <span class="nx">eb</span><span class="p">,</span> <span class="kc">nil</span> <span class="p">}</span> </pre></div> <p>There&#39;s also <tt class="docutils literal">DecryptRSA</tt>, which unwraps this:</p> <div class="highlight"><pre><span></span><span class="c1">// DecryptRSA decrypts the message c using private key priv and returns the</span> <span class="c1">// decrypted bytes, based on block 02 from PKCS #1 v1.5 (RFC 2313).</span> <span class="c1">// It expects len(c) to equal the length in bytes of the private key&#39;s modulus.</span> <span class="c1">// Important: this is a simple implementation not designed
to be resilient to</span> <span class="c1">// timing attacks.</span> <span class="kd">func</span> <span class="nx">DecryptRSA</span><span class="p">(</span><span class="nx">priv</span> <span class="o">*</span><span class="nx">PrivateKey</span><span class="p">,</span> <span class="nx">c</span> <span class="p">[]</span><span class="kt">byte</span><span class="p">)</span> <span class="p">([]</span><span class="kt">byte</span><span class="p">,</span> <span class="kt">error</span><span class="p">)</span> <span class="p">{</span> <span class="nx">keyLen</span> <span class="o">:=</span> <span class="p">(</span><span class="nx">priv</span><span class="p">.</span><span class="nx">N</span><span class="p">.</span><span class="nx">BitLen</span><span class="p">()</span> <span class="o">+</span> <span class="mi">7</span><span class="p">)</span> <span class="o">/</span> <span class="mi">8</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="nx">c</span><span class="p">)</span> <span class="o">!=</span> <span class="nx">keyLen</span> <span class="p">{</span> <span class="k">return</span> <span class="kc">nil</span><span class="p">,</span> <span class="nx">fmt</span><span class="p">.</span><span class="nx">Errorf</span><span class="p">(</span><span class="s">&quot;len(c)=%v, want keyLen=%v&quot;</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="nx">c</span><span class="p">),</span> <span class="nx">keyLen</span><span class="p">)</span> <span class="p">}</span> <span class="c1">// Convert c into a big.Int and decrypt it using the private key.</span> <span class="nx">cnum</span> <span class="o">:=</span> <span class="nb">new</span><span class="p">(</span><span class="nx">big</span><span class="p">.</span><span class="nx">Int</span><span class="p">).</span><span class="nx">SetBytes</span><span class="p">(</span><span class="nx">c</span><span class="p">)</span> <span class="nx">mnum</span> <span
class="o">:=</span> <span class="nx">decrypt</span><span class="p">(</span><span class="nx">priv</span><span class="p">,</span> <span class="nx">cnum</span><span class="p">)</span> <span class="c1">// Write the bytes of mnum into m, left-padding if needed.</span> <span class="nx">m</span> <span class="o">:=</span> <span class="nb">make</span><span class="p">([]</span><span class="kt">byte</span><span class="p">,</span> <span class="nx">keyLen</span><span class="p">)</span> <span class="nb">copy</span><span class="p">(</span><span class="nx">m</span><span class="p">[</span><span class="nx">keyLen</span><span class="o">-</span><span class="nb">len</span><span class="p">(</span><span class="nx">mnum</span><span class="p">.</span><span class="nx">Bytes</span><span class="p">()):],</span> <span class="nx">mnum</span><span class="p">.</span><span class="nx">Bytes</span><span class="p">())</span> <span class="c1">// Expect proper block 02 beginning.</span> <span class="k">if</span> <span class="nx">m</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">!=</span> <span class="mh">0x00</span> <span class="p">{</span> <span class="k">return</span> <span class="kc">nil</span><span class="p">,</span> <span class="nx">fmt</span><span class="p">.</span><span class="nx">Errorf</span><span class="p">(</span><span class="s">&quot;m=%v, want 0x00&quot;</span><span class="p">,</span> <span class="nx">m</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="p">}</span> <span class="k">if</span> <span class="nx">m</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">!=</span> <span class="mh">0x02</span> <span class="p">{</span> <span class="k">return</span> <span class="kc">nil</span><span class="p">,</span> <span class="nx">fmt</span><span class="p">.</span><span class="nx">Errorf</span><span class="p">(</span><span class="s">&quot;m=%v, want 
0x02&quot;</span><span class="p">,</span> <span class="nx">m</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span> <span class="p">}</span> <span class="c1">// Skip over random padding until a 0x00 byte is reached. +2 adjusts the index</span> <span class="c1">// back to the full slice.</span> <span class="nx">endPad</span> <span class="o">:=</span> <span class="nx">bytes</span><span class="p">.</span><span class="nx">IndexByte</span><span class="p">(</span><span class="nx">m</span><span class="p">[</span><span class="mi">2</span><span class="p">:],</span> <span class="mh">0x00</span><span class="p">)</span> <span class="o">+</span> <span class="mi">2</span> <span class="k">if</span> <span class="nx">endPad</span> <span class="p">&lt;</span> <span class="mi">2</span> <span class="p">{</span> <span class="k">return</span> <span class="kc">nil</span><span class="p">,</span> <span class="nx">fmt</span><span class="p">.</span><span class="nx">Errorf</span><span class="p">(</span><span class="s">&quot;end of padding not found&quot;</span><span class="p">)</span> <span class="p">}</span> <span class="k">return</span> <span class="nx">m</span><span class="p">[</span><span class="nx">endPad</span><span class="o">+</span><span class="mi">1</span><span class="p">:],</span> <span class="kc">nil</span> <span class="p">}</span> </pre></div> </div> <div class="section" id="digital-signatures-with-rsa"> <h2>Digital signatures with RSA</h2> <p>RSA can also be used to perform <em>digital signatures</em>. Here's how it works:</p> <ol class="arabic simple"> <li>Key generation and distribution remain the same. Alice has a public key and a private key.
She publishes her public key online.</li> <li>When Alice wants to send Bob a message and have Bob be sure that only she could have sent it, she will <em>encrypt</em> the message with her <em>private</em> key, that is <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/ad1c0c30bf900657c2a36a6361873f2e8801873f.svg" style="height: 19px;" type="image/svg+xml">S=Sign(M)=M^d\pmod{n}</object>. The signature is attached to the message.</li> <li>When Bob receives a message, he can <em>decrypt</em> the signature with Alice's public key: <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/15a4fcf51b7c6984cc437be976d1e1d52e5f749c.svg" style="height: 18px;" type="image/svg+xml">Check(S)=S^e\pmod{n}</object> and if he gets the original message back, the signature was correct.</li> </ol> <p>The correctness proof would be exactly the same as for encryption. No one else could have signed the message, because proper signing would require having the private key of Alice, which only she possesses.</p> <p>This is the textbook signature algorithm. One difference between the practical implementation of signing and encryption is in the padding protocol used. While OAEP is recommended for encryption, <a class="reference external" href="https://en.wikipedia.org/wiki/Probabilistic_signature_scheme">PSS</a> is recommended for signing <a class="footnote-reference" href="#id11" id="id6"></a>. 
I'm not going to implement signing for this post, but the Go standard library has great code for this - for example <tt class="docutils literal">rsa.SignPKCS1v15</tt> and <tt class="docutils literal">rsa.SignPSS</tt>.</p> <hr class="docutils" /> <table class="docutils footnote" frame="void" id="id7" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id2"></a></td><td>For two reasons: one is that we don't have to randomly find another large number - this operation takes time; another is that 65537 has only two bits &quot;on&quot; in its binary representation, which makes <a class="reference external" href="http://eli.thegreenplace.net/2009/03/28/efficient-modular-exponentiation-algorithms">modular exponentiation algorithms faster</a>.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id8" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id3"></a></td><td>A strong AES key is 256 bits, while RSA is commonly 2048 or more. The reason RSA encrypts a symmetric key is efficiency - RSA encryption is much slower than block ciphers, to the extent that it's often impractical to encrypt large streams of data with it. A hybrid scheme - wherein a strong AES key is first encrypted with RSA, and then AES is used to encrypt large data - is very common. 
This is the general idea behind what TLS and similar secure protocols use.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id9" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id4"></a></td><td>Note that the first 8 bits of the data block are 0, which makes it easy to ensure that the number we encrypt is smaller than <em>n</em>.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id10" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id5"></a></td><td>The stdlib implementation is resilient to common kinds of side-channel attacks; for example, it uses algorithms whose run time is independent of certain characteristics of the input, which makes timing attacks less feasible.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id11" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id6"></a></td><td>The reason for a different protocol is that the attacks on encrypted messages and on signatures tend to be different. For example, while for encrypted messages it's unthinkable to let attackers know any characteristics of the original message (the <em>base</em> in the exponentiation), in signing it's usually plainly available.</td></tr> </tbody> </table> </div> The Chinese Remainder Theorem2019-08-28T06:00:00-07:002019-08-28T06:00:00-07:00Eli Benderskytag:eli.thegreenplace.net,2019-08-28:/2019/the-chinese-remainder-theorem/
<p>The Chinese Remainder Theorem (CRT) is very useful in cryptography and other domains. According <a class="reference external" href="https://en.wikipedia.org/wiki/Chinese_remainder_theorem">to Wikipedia</a>, its origin and name come from this riddle in a 3rd century book by a Chinese mathematician:</p> <blockquote> There are certain things whose number is unknown. If we count them by threes, we have two left over; by fives, we have three left over; and by sevens, two are left over. How many things are there?</blockquote> <p>Mathematically, this is a system of linear congruences. In this post we'll go through a simple proof of the <em>existence</em> of a solution. It also demonstrates how to find such a solution, though check the Wikipedia link for a discussion of different methods and their relative efficiency.</p> <p>We'll start with a few prerequisite lemmas needed to prove the CRT.
You may want to skip them on first reading and refer back when going through the CRT proof.</p> <div class="section" id="prerequisites"> <h2>Prerequisites</h2> <p><strong>Lemma 1</strong>: if <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/1a4cf51ff825f4ff0afd531c6a8c9860d6d51896.svg" style="height: 18px;" type="image/svg+xml">d|ab</object> and <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/05f83d8097b6d8ae83319bf25a53212cc97d48c2.svg" style="height: 18px;" type="image/svg+xml">(d,a)=1</object>, then <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/7e30fae9eb28807eec8a567db43b8396a79d881a.svg" style="height: 18px;" type="image/svg+xml">d|b</object>.</p> <p><strong>Proof</strong>: Since <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/05f83d8097b6d8ae83319bf25a53212cc97d48c2.svg" style="height: 18px;" type="image/svg+xml">(d,a)=1</object> we know from <a class="reference external" href="http://eli.thegreenplace.net/2009/07/10/the-gcd-and-linear-combinations">Bézout's identity</a> that there exist integers <em>x</em> and <em>y</em> such that <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/1f8e45ed33065085f62b878d9fdf9151d0f757e0.svg" style="height: 17px;" type="image/svg+xml">dx+ay=1</object>. Multiplying both sides by <em>b</em>, we get: <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/8141e2352af9dea6ef77b76a12fde7161a8e070e.svg" style="height: 17px;" type="image/svg+xml">bdx+bay=b</object>. <em>bdx</em> is divisible by <em>d</em>, and so is <em>bay</em> because <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/1a4cf51ff825f4ff0afd531c6a8c9860d6d51896.svg" style="height: 18px;" type="image/svg+xml">d|ab</object>. 
Therefore <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/7e30fae9eb28807eec8a567db43b8396a79d881a.svg" style="height: 18px;" type="image/svg+xml">d|b</object> ∎</p> <p><strong>Lemma 2</strong>: if <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/146c2b5df7fcf337ad748394da7a872aac087af3.svg" style="height: 18px;" type="image/svg+xml">ac\equiv bc \pmod{m}</object> and <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/8f4b6819405398557257b9fb77c9ad496616c65b.svg" style="height: 18px;" type="image/svg+xml">(c,m)=1</object>, then <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/2daa6d2924d8b164c04f2f9f0723cf966dfab7f8.svg" style="height: 18px;" type="image/svg+xml">a\equiv b \pmod{m}</object>.</p> <p><strong>Proof</strong>: Since <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/146c2b5df7fcf337ad748394da7a872aac087af3.svg" style="height: 18px;" type="image/svg+xml">ac\equiv bc \pmod{m}</object>, we know that <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/68c33cb136d7bb3306ffb221c197c2094a4cec62.svg" style="height: 18px;" type="image/svg+xml">m|(ac-bc)</object>, or <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/95d7155cdbe323b74a4aeeb21c788b04d1039388.svg" style="height: 18px;" type="image/svg+xml">m|c(a-b)</object>. Since <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/1fcb1e450165a208f3b6af15614e23408bb752de.svg" style="height: 18px;" type="image/svg+xml">(m,c)=1</object> we can use Lemma 1 to conclude that <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/b3a7cab922796acb6514c2ed260af9948d515ca1.svg" style="height: 18px;" type="image/svg+xml">m|(a-b)</object>. 
In other words, <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/2daa6d2924d8b164c04f2f9f0723cf966dfab7f8.svg" style="height: 18px;" type="image/svg+xml">a\equiv b \pmod{m}</object> ∎</p> <div class="section" id="modular-multiplicative-inverse"> <h3>Modular multiplicative inverse</h3> <p>A <em>modular multiplicative inverse</em> of an integer <em>a</em> w.r.t. the modulus <em>m</em> is the solution of the linear congruence:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/c7d32701fe73c767ddf59237fe114f1e078e5340.svg" style="height: 18px;" type="image/svg+xml"> $ax\equiv1 \pmod{m}$</object> <p><strong>Lemma 3</strong>: if <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/99dabb93e5e7794513a1a19f669b0084a89ac5d1.svg" style="height: 18px;" type="image/svg+xml">(a,m)=1</object> then <em>a</em> has a unique modular multiplicative inverse modulo <em>m</em>.</p> <p><strong>Proof</strong>: Once again using Bézout's identity, we know from <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/99dabb93e5e7794513a1a19f669b0084a89ac5d1.svg" style="height: 18px;" type="image/svg+xml">(a,m)=1</object> that there exist integers <em>r</em> and <em>s</em> such that <object class="valign-m2" data="https://eli.thegreenplace.net/images/math/f311ad65574a8e08313b2fc884d8d8e196cb3f7e.svg" style="height: 14px;" type="image/svg+xml">ar+ms=1</object>. Therefore <object class="valign-m1" data="https://eli.thegreenplace.net/images/math/93cb307eb473ec8df78477dd1a6221646ed3015d.svg" style="height: 13px;" type="image/svg+xml">ar-1</object> is a multiple of <em>m</em>, or <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/eb337aabcef266f520fd462aecb901dfb0ebd7fb.svg" style="height: 18px;" type="image/svg+xml">ar\equiv 1\pmod{m}</object>. So <em>r</em> is a multiplicative inverse of <em>a</em>.</p> <p>Now let's see why this inverse is unique. 
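Before moving on, an aside: the existence argument is constructive, since the extended Euclidean algorithm computes the Bézout coefficients, and therefore the inverse itself. Here is a minimal Python sketch (the function names are my own, for illustration):

```python
def extended_gcd(a, b):
    """Returns (g, x, y) such that a*x + b*y == g == gcd(a, b)."""
    if b == 0:
        return a, 1, 0
    g, x, y = extended_gcd(b, a % b)
    # From b*x + (a % b)*y == g it follows that
    # a*y + b*(x - (a // b)*y) == g.
    return g, y, x - (a // b) * y

def modular_inverse(a, m):
    """The multiplicative inverse of a modulo m; requires (a, m) = 1."""
    g, r, _ = extended_gcd(a, m)
    assert g == 1, 'a and m must be coprime'
    return r % m

print(modular_inverse(3, 7))  # 5, since 3*5 = 15 ≡ 1 (mod 7)
```

Back to uniqueness.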
Let's assume there are two inverses, <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/83b3fdda5b127e3a4f9bcb7b45d2fa7ef3659493.svg" style="height: 12px;" type="image/svg+xml">r_1</object> and <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/1f7e755308eb8efb09a75b7dbdc677c0b60074bd.svg" style="height: 11px;" type="image/svg+xml">r_2</object>, so <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/275749cbe3134f270d053cdddab253d7a64c940a.svg" style="height: 18px;" type="image/svg+xml">ar_1\equiv 1\pmod{m}</object> and also <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/c7e4092c117d6ae06f011e7361eb69f02e41c8b8.svg" style="height: 18px;" type="image/svg+xml">ar_2\equiv 1\pmod{m}</object>, which means that <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/35145e09c1438e9b639b99a751186ed4cc9f4bbb.svg" style="height: 18px;" type="image/svg+xml">ar_1\equiv ar_2\pmod{m}</object>.</p> <p>Since <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/99dabb93e5e7794513a1a19f669b0084a89ac5d1.svg" style="height: 18px;" type="image/svg+xml">(a,m)=1</object> we can apply Lemma 2 to conclude that <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/5d5f4c8f44d461ef1b62559ddd71fb5061c3e8d1.svg" style="height: 18px;" type="image/svg+xml">r_1\equiv r_2\pmod{m}</object> ∎</p> </div> <div class="section" id="factorization-and-multiplying-moduli"> <h3>Factorization and multiplying moduli</h3> <p><strong>Lemma 4</strong>: if <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/5ad940bf1ab4dd63281bb98110d931c7f009f95f.svg" style="height: 18px;" type="image/svg+xml">a|n</object> and <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/6ab771b7f5fd62895faddd69134e55b5c58ac511.svg" style="height: 18px;" type="image/svg+xml">b|n</object> and <object class="valign-m4" 
data="https://eli.thegreenplace.net/images/math/3c4a3e293dc3b869630ed4dc0a5c7e5acba8d35e.svg" style="height: 18px;" type="image/svg+xml">(a,b)=1</object> then also <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/6d405fb50602a9bb0c6960c894ca148ab210f222.svg" style="height: 18px;" type="image/svg+xml">ab|n</object>.</p> <p><strong>Proof</strong>: Consider the prime factorization of <em>n</em>. <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/5ad940bf1ab4dd63281bb98110d931c7f009f95f.svg" style="height: 18px;" type="image/svg+xml">a|n</object> so <em>a</em> is a product of some subset of these prime factors. The same can be said about <em>b</em>. But <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/3c4a3e293dc3b869630ed4dc0a5c7e5acba8d35e.svg" style="height: 18px;" type="image/svg+xml">(a,b)=1</object>, so <em>a</em> and <em>b</em> don't have any prime factors in common. Therefore the prime factors of <object class="valign-0" data="https://eli.thegreenplace.net/images/math/da23614e02469a0d7c7bd1bdab5c9c474b1904dc.svg" style="height: 13px;" type="image/svg+xml">ab</object> are also a subset of the prime factors of <em>n</em>, and <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/6d405fb50602a9bb0c6960c894ca148ab210f222.svg" style="height: 18px;" type="image/svg+xml">ab|n</object> ∎</p> </div> </div> <div class="section" id="id1"> <h2>The Chinese Remainder Theorem</h2> <p>Assume <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/3095109bb55e0b34ecac71d33040fa004bfdfc7d.svg" style="height: 12px;" type="image/svg+xml">n_1,\dots,n_k</object> are positive integers, pairwise coprime; that is, for any <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/eebecf421c4d33eeab4a0c4da6c20ed8d49e6c6c.svg" style="height: 17px;" type="image/svg+xml">i\neq j</object>, <object class="valign-m6"
data="https://eli.thegreenplace.net/images/math/b2820249e9aae164709f509d84fa260d142ee148.svg" style="height: 20px;" type="image/svg+xml">(n_i,n_j)=1</object>. Let <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/1399d8e50dc6eadcb3e40b623a13734e492a60f7.svg" style="height: 12px;" type="image/svg+xml">a_1,\dots,a_k</object> be arbitrary integers. The system of congruences with an unknown <em>x</em>:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/93757b2beff989fc02e09cd045a9b4ff5959c8c6.svg" style="height: 82px;" type="image/svg+xml"> \begin{align*} x &amp;\equiv a_1 \pmod{n_1} \\ &amp;\vdots \\ x &amp;\equiv a_k \pmod{n_k} \end{align*}</object> <p>has a single solution modulo the product <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/9ee0288589a80853a3cb1ede6482968e0d126e93.svg" style="height: 16px;" type="image/svg+xml">N=n_1\times n_2\times \cdots \times n_k</object>.</p> <p><strong>Proof</strong>: Let <object class="valign-m8" data="https://eli.thegreenplace.net/images/math/962577ced41b97a773cca5462d4d68b22375bd8d.svg" style="height: 24px;" type="image/svg+xml">N_k=\frac{N}{n_k}</object>. 
Then <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/8878d4a84555010f6794370163c9f5c2a1865a93.svg" style="height: 18px;" type="image/svg+xml">(N_k,n_k)=1</object>, so each <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/8f94afd90555960e1ac40d2908475e16922594bc.svg" style="height: 15px;" type="image/svg+xml">N_k</object> has a unique multiplicative inverse modulo <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/b1d70855b10553d5c5a4d03b4018211bcf0114c8.svg" style="height: 11px;" type="image/svg+xml">n_k</object> per Lemma 3 above; let's call this inverse <object class="valign-m5" data="https://eli.thegreenplace.net/images/math/bf042bd7a5b6ab321af6ac1dbba45dd3cba86d40.svg" style="height: 19px;" type="image/svg+xml">N&#x27;_k</object>. Now consider:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/97aca0d663e04fe8ebfadcd87053758dad9b08af.svg" style="height: 21px;" type="image/svg+xml"> $x=a_1 N_1 N&#x27;_1+a_2 N_2 N&#x27;_2+\cdots +a_k N_k N&#x27;_k$</object> <p><object class="valign-m3" data="https://eli.thegreenplace.net/images/math/8f94afd90555960e1ac40d2908475e16922594bc.svg" style="height: 15px;" type="image/svg+xml">N_k</object> is a multiple of every <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/5b05dd3722f57cd7ac250228f9a1aaf3af86311d.svg" style="height: 11px;" type="image/svg+xml">n_i</object> except for <object class="valign-0" data="https://eli.thegreenplace.net/images/math/f4b7e42a4b8c52f40eb9458e68e81c74d70c1c61.svg" style="height: 13px;" type="image/svg+xml">i=k</object>. 
In other words, for <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/b778002ef6ea962f1ebd964f044ff0bb2f7b5503.svg" style="height: 17px;" type="image/svg+xml">i\neq k</object> we have <em>N<sub>i</sub></em> ≡ 0 (mod <em>n<sub>k</sub></em>). On the other hand, for <object class="valign-0" data="https://eli.thegreenplace.net/images/math/f4b7e42a4b8c52f40eb9458e68e81c74d70c1c61.svg" style="height: 13px;" type="image/svg+xml">i=k</object> we have, by construction, <object class="valign-m5" data="https://eli.thegreenplace.net/images/math/71448798bc2f36eadf4c3d0a7c123d77fff9c828.svg" style="height: 19px;" type="image/svg+xml">N_i N&#x27;_i\equiv 1\pmod{n_i}</object>. So for each <em>k</em> we have:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/be7613a4bea5400e35bbe2ea728c447b42f0a8b5.svg" style="height: 20px;" type="image/svg+xml"> $x\equiv a_k N_k N&#x27;_k \equiv a_k \pmod{n_k}$</object> <p>because all the other terms in the sum vanish modulo <em>n<sub>k</sub></em>. Hence <em>x</em> satisfies every congruence in the system.</p> <p>To prove that <em>x</em> is unique modulo <em>N</em>, let's assume there are two solutions: <em>x</em> and <em>y</em>. Both solutions to the CRT should satisfy <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/47cfde2f44ca022ba766fcc25301922bbdadd91b.svg" style="height: 18px;" type="image/svg+xml">x\equiv y\equiv a_k\pmod{n_k}</object>. Therefore <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/aace079d3f57b531e8cd699df5595629fbd6cd72.svg" style="height: 12px;" type="image/svg+xml">x-y</object> is a multiple of <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/b1d70855b10553d5c5a4d03b4018211bcf0114c8.svg" style="height: 11px;" type="image/svg+xml">n_k</object> for each <em>k</em>. 
Since these <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/b1d70855b10553d5c5a4d03b4018211bcf0114c8.svg" style="height: 11px;" type="image/svg+xml">n_k</object> are pairwise coprime, from Lemma 4 we know that <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/aace079d3f57b531e8cd699df5595629fbd6cd72.svg" style="height: 12px;" type="image/svg+xml">x-y</object> is also a multiple of N, or <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/beac92d520cda137c1af24245b956a17792abdc0.svg" style="height: 18px;" type="image/svg+xml">x\equiv y\pmod{N}</object> ∎</p> <div class="section" id="corollary"> <h3>Corollary</h3> <p>If <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/3095109bb55e0b34ecac71d33040fa004bfdfc7d.svg" style="height: 12px;" type="image/svg+xml">n_1,\dots,n_k</object> are pairwise coprime and <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/9ee0288589a80853a3cb1ede6482968e0d126e93.svg" style="height: 16px;" type="image/svg+xml">N=n_1\times n_2\times \cdots \times n_k</object>, then for all integers <em>x</em> and <em>a</em>, <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/824b662cf71dd9ea955a952ddc6d7ed9131d12b3.svg" style="height: 18px;" type="image/svg+xml">x\equiv a\pmod{n_i}</object> for <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/b1fc066bab1cdf156a4603ce180645c09bc992f5.svg" style="height: 17px;" type="image/svg+xml">i=1,2,\dots,k</object> if and only if <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/85363e9928d3b9de1d3d4a7a182a3b899ffe60fa.svg" style="height: 18px;" type="image/svg+xml">x\equiv a\pmod{N}</object>.</p> <p><strong>Proof</strong>: we'll start with the <em>if</em> direction. 
If <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/85363e9928d3b9de1d3d4a7a182a3b899ffe60fa.svg" style="height: 18px;" type="image/svg+xml">x\equiv a\pmod{N}</object> this means <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/c68b089bbf91d70e490912848f13893de8b23a59.svg" style="height: 18px;" type="image/svg+xml">N|(x-a)</object>. But that immediately means that for each <em>i</em>, <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/3cdb745a255300e442dcef60a12cf40caa411571.svg" style="height: 18px;" type="image/svg+xml">n_i|(x-a)</object> as well, or <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/824b662cf71dd9ea955a952ddc6d7ed9131d12b3.svg" style="height: 18px;" type="image/svg+xml">x\equiv a\pmod{n_i}</object>.</p> <p>Now the <em>only if</em> direction. Given <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/824b662cf71dd9ea955a952ddc6d7ed9131d12b3.svg" style="height: 18px;" type="image/svg+xml">x\equiv a\pmod{n_i}</object> for <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/b1fc066bab1cdf156a4603ce180645c09bc992f5.svg" style="height: 17px;" type="image/svg+xml">i=1,2,\dots,k</object>, we can invoke the CRT using <em>a</em> in all congruences. The CRT tells us this system has a single solution modulo <object class="valign-0" data="https://eli.thegreenplace.net/images/math/b51a60734da64be0e618bacbea2865a8a7dcd669.svg" style="height: 12px;" type="image/svg+xml">N</object>. But we know that <em>a</em> is a solution, so it has to be the only one ∎</p> </div> </div> Unification2018-11-12T05:49:00-08:002018-11-12T05:49:00-08:00Eli Benderskytag:eli.thegreenplace.net,2018-11-12:/2018/unification/<p>In logic and computer science, unification is a process of automatically solving equations between symbolic terms. Unification has several interesting applications, notably in logic programming and <a class="reference external" href="https://eli.thegreenplace.net/2018/type-inference/">type inference</a>. In this post I want to present the basic unification algorithm with a complete implementation.</p> <p>Let's start with some terminology. We'll be using <em>terms</em> built from constants, variables and function applications:</p> <ul class="simple"> <li>A lowercase letter represents a constant (could be any kind of constant, like an integer or a string)</li> <li>An uppercase letter represents a variable</li> <li><tt class="docutils literal"><span class="pre">f(...)</span></tt> is an application of function <tt class="docutils literal">f</tt> to some parameters, which are <em>terms</em> themselves</li> </ul> <p>This representation is borrowed from <a class="reference external" href="https://en.wikipedia.org/wiki/First-order_logic">first-order logic</a> and is also used in the Prolog programming language. 
Some examples:</p> <ul class="simple"> <li><tt class="docutils literal">V</tt>: a single variable term</li> <li><tt class="docutils literal">foo(V, k)</tt>: function <tt class="docutils literal">foo</tt> applied to variable V and constant k</li> <li><tt class="docutils literal">foo(bar(k), baz(V))</tt>: a nested function application</li> </ul> <div class="section" id="pattern-matching"> <h2>Pattern matching</h2> <p>Unification can be seen as a generalization of <em>pattern matching</em>, so let's start with that.</p> <p>We're given a constant term and a pattern term. The pattern term has variables. Pattern matching is the problem of finding a variable assignment that will make the two terms match. For example:</p> <ul class="simple"> <li>Constant term: <tt class="docutils literal">f(a, b, bar(t))</tt></li> <li>Pattern term: <tt class="docutils literal">f(a, V, X)</tt></li> </ul> <p>Trivially, the assignment <tt class="docutils literal">V=b</tt> and <tt class="docutils literal">X=bar(t)</tt> works here. Such an assignment is also called a <em>substitution</em>: it maps variables to their assigned values. In a less trivial case, variables can appear multiple times in a pattern:</p> <ul class="simple"> <li>Constant term: <tt class="docutils literal">f(top(a), a, <span class="pre">g(top(a)),</span> t)</tt></li> <li>Pattern term: <tt class="docutils literal">f(V, a, g(V), t)</tt></li> </ul> <p>Here the right substitution is <tt class="docutils literal">V=top(a)</tt>.</p> <p>Sometimes, no valid substitutions exist. 
If we change the constant term in the last example to <tt class="docutils literal">f(top(b), a, <span class="pre">g(top(a)),</span> t)</tt>, then there is no valid substitution because V would have to match <tt class="docutils literal">top(b)</tt> and <tt class="docutils literal">top(a)</tt> simultaneously, which is not possible.</p> </div> <div class="section" id="id1"> <h2>Unification</h2> <p>Unification is just like pattern matching, except that both terms can contain variables. So we can no longer say one is the pattern term and the other the constant term. For example:</p> <ul class="simple"> <li>First term: <tt class="docutils literal">f(a, V, bar(D))</tt></li> <li>Second term: <tt class="docutils literal">f(D, k, bar(a))</tt></li> </ul> <p>Given two such terms, finding a variable substitution that will make them equivalent is called <em>unification</em>. In this case the substitution is <tt class="docutils literal">{D=a, V=k}</tt>.</p> <p>Note that there is an infinite number of possible unifiers for some solvable unification problem. For example, given:</p> <ul class="simple"> <li>First term: <tt class="docutils literal">f(X, Y)</tt></li> <li>Second term: <tt class="docutils literal">f(Z, g(X))</tt></li> </ul> <p>We have the substitution <tt class="docutils literal">{X=Z, Y=g(X)}</tt> but also something like <tt class="docutils literal">{X=K, Z=K, Y=g(K)}</tt> and <tt class="docutils literal">{X=j(K), Z=j(K), <span class="pre">Y=g(j(K))}</span></tt> and so on. The first substitution is the simplest one, and also the most general. It's called the <em>most general unifier</em> or <em>mgu</em>. Intuitively, the <em>mgu</em> can be turned into any other unifier by performing another substitution. For example <tt class="docutils literal">{X=Z, Y=g(X)}</tt> can be turned into <tt class="docutils literal">{X=j(K), Z=j(K), <span class="pre">Y=g(j(K))}</span></tt> by applying the substitution <tt class="docutils literal">{Z=j(K)}</tt> to it. 
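This specialization is easy to check mechanically. The snippet below is my own toy illustration, not part of the post's implementation: it represents terms as plain strings with single uppercase letters as variables, and applies a substitution by repeated string replacement.

```python
def resolve(term, subst):
    # Toy illustration: terms are plain strings and variables are single
    # uppercase letters, so string replacement stands in for substitution.
    # Repeatedly apply all bindings until nothing changes (a fixpoint).
    # Note: a self-referential binding like X=f(X) would loop forever here;
    # ruling that out is the job of the occurs check discussed later.
    prev = None
    while term != prev:
        prev = term
        for var, repl in subst.items():
            term = term.replace(var, repl)
    return term

mgu = {'X': 'Z', 'Y': 'g(X)'}          # the most general unifier
combined = {**mgu, 'Z': 'j(K)'}        # mgu specialized by {Z=j(K)}

print(resolve('f(X, Y)', mgu))         # f(Z, g(Z))
print(resolve('f(X, Y)', combined))    # f(j(K), g(j(K)))
print(resolve('f(Z, g(X))', combined)) # f(j(K), g(j(K)))
```

Under each substitution the two original terms resolve to the same term, which is exactly what makes them unifiers.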
Note that the reverse doesn't work, as we can't turn the second into the first by using a substitution. So we say that <tt class="docutils literal">{X=Z, Y=g(X)}</tt> is the most general unifier for the two given terms, and it's the <em>mgu</em> we want to find.</p> </div> <div class="section" id="an-algorithm-for-unification"> <h2>An algorithm for unification</h2> <p>Solving unification problems may seem simple, but there are a number of subtle corner cases to be aware of. In his 1991 paper <a class="reference external" href="https://www.semanticscholar.org/paper/Correcting-a-Widespread-Error-in-Unification-Norvig/95af3dc93c2e69b2c739a9098c3428a49e54e1b6">Correcting a Widespread Error in Unification Algorithms</a>, Peter Norvig noted a common error that exists in many books presenting the algorithm, including SICP.</p> <p>The correct algorithm is based on J.A. Robinson's 1965 paper &quot;A machine-oriented logic based on the resolution principle&quot;. More efficient algorithms have been developed over time since it was first published, but our focus here will be on correctness and simplicity rather than performance.</p> <p>The following implementation is based on Norvig's, and the full code (with tests) is <a class="reference external" href="https://github.com/eliben/code-for-blog/blob/master/2018/unif/unifier.py">available on Github</a>. This implementation uses Python 3, while Norvig's original is in Common Lisp. There's a slight difference in representations too, as Norvig uses the Lisp-y <tt class="docutils literal">(f X Y)</tt> syntax to denote an application of function <tt class="docutils literal">f</tt>. The two representations are isomorphic, and I'm picking the more classical one which is used in most papers on the subject. 
In any case, if you're interested in the more Lisp-y version, I have some Clojure <a class="reference external" href="https://github.com/eliben/paip-in-clojure/tree/master/src/paip/11_logic">code online</a> that ports Norvig's implementation more directly.</p> <p>We'll start by defining the data structure for terms:</p> <div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">Term</span><span class="p">:</span> <span class="k">pass</span> <span class="k">class</span> <span class="nc">App</span><span class="p">(</span><span class="n">Term</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">fname</span><span class="p">,</span> <span class="n">args</span><span class="o">=</span><span class="p">()):</span> <span class="bp">self</span><span class="o">.</span><span class="n">fname</span> <span class="o">=</span> <span class="n">fname</span> <span class="bp">self</span><span class="o">.</span><span class="n">args</span> <span class="o">=</span> <span class="n">args</span> <span class="c1"># Not shown here: __str__ and __eq__, see full code for the details...</span> <span class="k">class</span> <span class="nc">Var</span><span class="p">(</span><span class="n">Term</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">name</span> <span class="o">=</span> <span class="n">name</span> <span class="k">class</span> <span class="nc">Const</span><span class="p">(</span><span class="n">Term</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span 
class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span> <span class="o">=</span> <span class="n">value</span> </pre></div> <p>An <tt class="docutils literal">App</tt> represents the application of function <tt class="docutils literal">fname</tt> to a sequence of arguments.</p> <div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">unify</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">subst</span><span class="p">):</span> <span class="sd">&quot;&quot;&quot;Unifies term x and y with initial subst.</span> <span class="sd"> Returns a subst (map of name-&gt;term) that unifies x and y, or None if</span> <span class="sd"> they can&#39;t be unified. Pass subst={} if no subst are initially</span> <span class="sd"> known. Note that {} means valid (but empty) subst.</span> <span class="sd"> &quot;&quot;&quot;</span> <span class="k">if</span> <span class="n">subst</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span> <span class="k">return</span> <span class="bp">None</span> <span class="k">elif</span> <span class="n">x</span> <span class="o">==</span> <span class="n">y</span><span class="p">:</span> <span class="k">return</span> <span class="n">subst</span> <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">Var</span><span class="p">):</span> <span class="k">return</span> <span class="n">unify_variable</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">subst</span><span class="p">)</span> <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">Var</span><span class="p">):</span> <span 
class="k">return</span> <span class="n">unify_variable</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">subst</span><span class="p">)</span> <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">App</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">App</span><span class="p">):</span> <span class="k">if</span> <span class="n">x</span><span class="o">.</span><span class="n">fname</span> <span class="o">!=</span> <span class="n">y</span><span class="o">.</span><span class="n">fname</span> <span class="ow">or</span> <span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">args</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">len</span><span class="p">(</span><span class="n">y</span><span class="o">.</span><span class="n">args</span><span class="p">):</span> <span class="k">return</span> <span class="bp">None</span> <span class="k">else</span><span class="p">:</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">args</span><span class="p">)):</span> <span class="n">subst</span> <span class="o">=</span> <span class="n">unify</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">args</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">y</span><span class="o">.</span><span class="n">args</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span 
class="n">subst</span><span class="p">)</span> <span class="k">return</span> <span class="n">subst</span> <span class="k">else</span><span class="p">:</span> <span class="k">return</span> <span class="bp">None</span> </pre></div> <p><tt class="docutils literal">unify</tt> is the main function driving the algorithm. It looks for a <em>substitution</em>, which is a Python dict mapping variable names to terms. When either side is a variable, it calls <tt class="docutils literal">unify_variable</tt> which is shown next. Otherwise, if both sides are function applications, it ensures they apply the same function (otherwise there's no match) and then unifies their arguments one by one, carefully carrying the updated substitution throughout the process.</p> <div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">unify_variable</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">subst</span><span class="p">):</span> <span class="sd">&quot;&quot;&quot;Unifies variable v with term x, using subst.</span> <span class="sd"> Returns updated subst or None on failure.</span> <span class="sd"> &quot;&quot;&quot;</span> <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">Var</span><span class="p">)</span> <span class="k">if</span> <span class="n">v</span><span class="o">.</span><span class="n">name</span> <span class="ow">in</span> <span class="n">subst</span><span class="p">:</span> <span class="k">return</span> <span class="n">unify</span><span class="p">(</span><span class="n">subst</span><span class="p">[</span><span class="n">v</span><span class="o">.</span><span class="n">name</span><span class="p">],</span> <span class="n">x</span><span class="p">,</span> <span class="n">subst</span><span class="p">)</span> <span class="k">elif</span> <span 
class="nb">isinstance</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">Var</span><span class="p">)</span> <span class="ow">and</span> <span class="n">x</span><span class="o">.</span><span class="n">name</span> <span class="ow">in</span> <span class="n">subst</span><span class="p">:</span> <span class="k">return</span> <span class="n">unify</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">subst</span><span class="p">[</span><span class="n">x</span><span class="o">.</span><span class="n">name</span><span class="p">],</span> <span class="n">subst</span><span class="p">)</span> <span class="k">elif</span> <span class="n">occurs_check</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">subst</span><span class="p">):</span> <span class="k">return</span> <span class="bp">None</span> <span class="k">else</span><span class="p">:</span> <span class="c1"># v is not yet in subst and can&#39;t simplify x. Extend subst.</span> <span class="k">return</span> <span class="p">{</span><span class="o">**</span><span class="n">subst</span><span class="p">,</span> <span class="n">v</span><span class="o">.</span><span class="n">name</span><span class="p">:</span> <span class="n">x</span><span class="p">}</span> </pre></div> <p>The key idea here is recursive unification. If <tt class="docutils literal">v</tt> is bound in the substitution, we try to unify its definition with <tt class="docutils literal">x</tt> to guarantee consistency throughout the unification process (and vice versa when <tt class="docutils literal">x</tt> is a variable). There's another function being used here - <tt class="docutils literal">occurs_check</tt>; I'm retaining its classical name from early presentations of unification. 
Its goal is to guarantee that we don't have self-referential variable bindings like <tt class="docutils literal">X=f(X)</tt> that would lead to potentially infinite unifiers.</p> <div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">occurs_check</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">term</span><span class="p">,</span> <span class="n">subst</span><span class="p">):</span> <span class="sd">&quot;&quot;&quot;Does the variable v occur anywhere inside term?</span> <span class="sd"> Variables in term are looked up in subst and the check is applied</span> <span class="sd"> recursively.</span> <span class="sd"> &quot;&quot;&quot;</span> <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">Var</span><span class="p">)</span> <span class="k">if</span> <span class="n">v</span> <span class="o">==</span> <span class="n">term</span><span class="p">:</span> <span class="k">return</span> <span class="bp">True</span> <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">term</span><span class="p">,</span> <span class="n">Var</span><span class="p">)</span> <span class="ow">and</span> <span class="n">term</span><span class="o">.</span><span class="n">name</span> <span class="ow">in</span> <span class="n">subst</span><span class="p">:</span> <span class="k">return</span> <span class="n">occurs_check</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">subst</span><span class="p">[</span><span class="n">term</span><span class="o">.</span><span class="n">name</span><span class="p">],</span> <span class="n">subst</span><span class="p">)</span> <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">term</span><span class="p">,</span> <span class="n">App</span><span 
class="p">):</span> <span class="k">return</span> <span class="nb">any</span><span class="p">(</span><span class="n">occurs_check</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">arg</span><span class="p">,</span> <span class="n">subst</span><span class="p">)</span> <span class="k">for</span> <span class="n">arg</span> <span class="ow">in</span> <span class="n">term</span><span class="o">.</span><span class="n">args</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="k">return</span> <span class="bp">False</span> </pre></div> <p>Let's see how this code handles some of the unification examples discussed earlier in the post. Starting with the pattern matching example, where variables are just on one side:</p> <div class="highlight"><pre><span></span>&gt;&gt;&gt; unify(parse_term(&#39;f(a, b, bar(t))&#39;), parse_term(&#39;f(a, V, X)&#39;), {}) {&#39;V&#39;: b, &#39;X&#39;: bar(t)} </pre></div> <p>Now the examples from the <em>Unification</em> section:</p> <div class="highlight"><pre><span></span>&gt;&gt;&gt; unify(parse_term(&#39;f(a, V, bar(D))&#39;), parse_term(&#39;f(D, k, bar(a))&#39;), {}) {&#39;D&#39;: a, &#39;V&#39;: k} &gt;&gt;&gt; unify(parse_term(&#39;f(X, Y)&#39;), parse_term(&#39;f(Z, g(X))&#39;), {}) {&#39;X&#39;: Z, &#39;Y&#39;: g(X)} </pre></div> <p>Next, let's try one where unification fails due to two conflicting definitions of the variable X.</p> <div class="highlight"><pre><span></span>&gt;&gt;&gt; unify(parse_term(&#39;f(X, Y, X)&#39;), parse_term(&#39;f(r, g(X), p)&#39;), {}) None </pre></div> <p>Lastly, it's instructive to trace through the execution of the algorithm for a non-trivial unification to see how it works.
Let's unify the terms <tt class="docutils literal"><span class="pre">f(X,h(X),Y,g(Y))</span></tt> and <tt class="docutils literal"><span class="pre">f(g(Z),W,Z,X)</span></tt>:</p> <ul class="simple"> <li><tt class="docutils literal">unify</tt> is called, sees the root is an <tt class="docutils literal">App</tt> of function <tt class="docutils literal">f</tt> and loops over the arguments.<ul> <li><tt class="docutils literal">unify(X, g(Z))</tt> invokes <tt class="docutils literal">unify_variable</tt> because <tt class="docutils literal">X</tt> is a variable, and the result is augmenting subst with <tt class="docutils literal">X=g(Z)</tt></li> <li><tt class="docutils literal">unify(h(X), W)</tt> invokes <tt class="docutils literal">unify_variable</tt> because <tt class="docutils literal">W</tt> is a variable, so the subst grows to <tt class="docutils literal">{X=g(Z), W=h(X)}</tt></li> <li><tt class="docutils literal">unify(Y, Z)</tt> invokes <tt class="docutils literal">unify_variable</tt>; since neither <tt class="docutils literal">Y</tt> nor <tt class="docutils literal">Z</tt> are in subst yet, the subst grows to <tt class="docutils literal">{X=g(Z), W=h(X), Y=Z}</tt> (note that the binding between two variables is arbitrary; <tt class="docutils literal">Z=Y</tt> would be equivalent)</li> <li><tt class="docutils literal">unify(g(Y), X)</tt> invokes <tt class="docutils literal">unify_variable</tt>; here things get more interesting, because <tt class="docutils literal">X</tt> is already in the subst, so now we call <tt class="docutils literal">unify</tt> on <tt class="docutils literal">g(Y)</tt> and <tt class="docutils literal">g(Z)</tt> (what <tt class="docutils literal">X</tt> is bound to)<ul> <li>The functions match for both terms (<tt class="docutils literal">g</tt>), so there's another loop over arguments, this time only for unifying <tt class="docutils literal">Y</tt> and <tt class="docutils literal">Z</tt></li> <li><tt class="docutils 
literal">unify_variable</tt> for <tt class="docutils literal">Y</tt> and <tt class="docutils literal">Z</tt> leads to lookup of <tt class="docutils literal">Y</tt> in the subst and then <tt class="docutils literal">unify(Z, Z)</tt>, which returns the unmodified subst; the result is that nothing new is added to the subst, but the unification of <tt class="docutils literal">g(Y)</tt> and <tt class="docutils literal">g(Z)</tt> succeeds, because it agrees with the existing bindings in subst</li> </ul> </li> </ul> </li> <li>The final result is <tt class="docutils literal">{X=g(Z), W=h(X), Y=Z}</tt></li> </ul> </div> <div class="section" id="efficiency"> <h2>Efficiency</h2> <p>The algorithm presented here is not particularly efficient, and when dealing with large unification problems it's wise to consider more advanced options. It copies subst around too much, and repeats work needlessly because it makes no attempt to cache terms that have already been unified.</p> <p>For a good overview of the efficiency of unification algorithms, I recommend checking out two papers:</p> <ul class="simple"> <li>&quot;An Efficient Unification Algorithm&quot; by Martelli and Montanari</li> <li>&quot;Unification: A Multidisciplinary Survey&quot; by Kevin Knight</li> </ul> </div> Partial and Total Orders2018-10-01T06:01:00-07:002018-10-01T06:01:00-07:00Eli Benderskytag:eli.thegreenplace.net,2018-10-01:/2018/partial-and-total-orders/<p>Imagine a set of 2D rectangles of different sizes; let's assume for the sake of simplicity that no two rectangles in this set have <em>exactly</em> the same size.
Here is a sample set:</p> <img alt="Five boxes of different sizes" class="align-center" src="https://eli.thegreenplace.net/images/2018/boxes-order.png" /> <p>We'll say that box X <strong>fits</strong> inside box Y if we could physically enclose X inside Y …</p><p>Imagine a set of 2D rectangles of different sizes; let's assume for the sake of simplicity that no two rectangles in this set have <em>exactly</em> the same size. Here is a sample set:</p> <img alt="Five boxes of different sizes" class="align-center" src="https://eli.thegreenplace.net/images/2018/boxes-order.png" /> <p>We'll say that box X <strong>fits</strong> inside box Y if we could physically enclose X inside Y; in other words, if Y's dimensions are larger than X's. In this example:</p> <ul class="simple"> <li>Box A can fit inside box B, but not the other way around</li> <li>E can fit inside all other boxes, but no other box can fit inside it</li> <li>A, B, D, E can fit inside C, which itself cannot fit in any of the other boxes</li> <li>D cannot fit inside A or B; neither can A or B fit inside D</li> </ul> <p>As we're going to see soon, in this case &quot;fits&quot; is a <em>partial order</em> on a set of 2D rectangular boxes, because even though we can order some of the boxes relative to each other, some other pairs of boxes have no relative order among themselves (for example A and D).</p> <p>If all pairs of boxes in this set had relative ordering - for example, consider the set without box D - we could define a <em>total order</em> on the set. 
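The &quot;fits&quot; relation is easy to express as a predicate in code. A minimal Python sketch follows; the box dimensions are invented to mirror the relationships in the drawing, not the actual sizes:

```python
# A box is a (width, height) pair. X "fits" inside Y iff both of X's
# dimensions are strictly smaller than Y's. Sizes below are illustrative
# guesses chosen to reproduce the relationships described above.
def fits(x, y):
    return x[0] < y[0] and x[1] < y[1]

A, B, C, D, E = (3, 5), (4, 6), (9, 8), (6, 2), (1, 1)

print(fits(A, B), fits(B, A))  # True False: A fits in B, not vice versa
print(fits(A, D), fits(D, A))  # False False: A and D are incomparable
print(all(fits(box, C) for box in (A, B, D, E)))  # True: all fit inside C
```

Pairs like A and D, for which the predicate is false in both directions, are exactly what makes this ordering partial rather than total.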
Another example for this is a set of 2D <em>squares</em> (rather than rectangles); as long as all the squares in the set have unique sizes <a class="footnote-reference" href="#id3" id="id1"></a>, we can always define a total order on them because for any pair of squares either the first can fit in the second, or vice versa.</p> <div class="section" id="mathematical-definition-of-relations"> <h2>Mathematical definition of relations</h2> <p>To develop a mathematically sound approach to ordering, we'll have to dip our feet into set theory and <em>relations</em>. We'll only be talking about binary relations here.</p> <p>Given a set A, a <em>relation on A</em> is a set of pairs with elements taken from A. A bit more rigorously, given that <object class="valign-0" data="https://eli.thegreenplace.net/images/math/bc659bc638626217264a2aa7a0cca55c0cc40ddc.svg" style="height: 12px;" type="image/svg+xml">A\times A</object> is the set containing all possible ordered pairs taken from A (a.k.a. the <em>Cartesian</em> product of A), then R is a relation on A if it's a subset of <object class="valign-0" data="https://eli.thegreenplace.net/images/math/bc659bc638626217264a2aa7a0cca55c0cc40ddc.svg" style="height: 12px;" type="image/svg+xml">A\times A</object>, or <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/6e8b00b1044916a1bd2a6a15ca276c60b4687b15.svg" style="height: 15px;" type="image/svg+xml">R\subseteq A\times A</object>.</p> <p>For example, given the set <object class="valign-m5" data="https://eli.thegreenplace.net/images/math/45bcabcc09a2d7523e72b58250c14b1d1038c22a.svg" style="height: 18px;" type="image/svg+xml">A=\{1,2,3\}</object>, then:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/212e13c924f24ce8e2115145e3169ac4ffbd0a4a.svg" style="height: 19px;" type="image/svg+xml"> $A\times 
A=\{\left(1,1\right),\left(1,2\right),\left(1,3\right),\left(2,1\right),\left(2,2\right),\left(2,3\right),\left(3,1\right),\left(3,2\right),\left(3,3\right)\}$</object> <p>Note that we explicitly defined the pairs to be <em>ordered</em>, meaning that (1,2) and (2,1) are two distinct elements in this set.</p> <p>By definition, any subset of <object class="valign-0" data="https://eli.thegreenplace.net/images/math/bc659bc638626217264a2aa7a0cca55c0cc40ddc.svg" style="height: 12px;" type="image/svg+xml">A\times A</object> is a relation on A. For example <object class="valign-m5" data="https://eli.thegreenplace.net/images/math/fe939ac9c1402f39b62f0176b30bbda8cb269aa9.svg" style="height: 19px;" type="image/svg+xml">R=\{\left(1,1\right),\left(2,2\right),\left(3,3\right)\}</object>. In programming, we often use the term <em>predicate</em> to express a similar idea. A predicate is a function with a binary outcome, and the correspondence to relations is trivial - we just say that all pairs belonging to the relation satisfy the predicate, and vice versa. If we defined a predicate <tt class="docutils literal">R(x,y)</tt> to be true if and only if <tt class="docutils literal"><span class="pre">x==y</span></tt>, we'd get the relation above.</p> <p>A shortcut notation that will make definitions cleaner: we say <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/2a3189528858697b5fc631eaf27ea7cc5a0a0c00.svg" style="height: 16px;" type="image/svg+xml">xRy</object> when <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/3e97fb0a8391d7acb982ec45e84661d02fcb09dd.svg" style="height: 18px;" type="image/svg+xml">\left(x,y\right)\in R</object>. In our example set 1R1, 2R2 and 3R3. This notation is a bit awkward, but it's the accepted standard in math; therefore I'm using it for consistency with other sources.</p> <p>Besides, it becomes nicer when R is an operator. 
If we redefine R as <tt class="docutils literal">==</tt>, it becomes more natural: <tt class="docutils literal"><span class="pre">1==1</span></tt>, <tt class="docutils literal"><span class="pre">2==2</span></tt>, <tt class="docutils literal"><span class="pre">3==3</span></tt>. The equality relation is a perfectly valid relation on a set - its elements are all the pairs where both members are the same value.</p> </div> <div class="section" id="properties-of-relations"> <h2>Properties of relations</h2> <p>There are a number of useful properties relations could have. Here are just a few that we'll need for the rest of the article; for a longer list, see the <a class="reference external" href="https://en.wikipedia.org/wiki/Binary_relation">Wikipedia page</a>.</p> <p><strong>Reflexive</strong>: every element in the set is related to itself, or <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/962581a9498b2d84cd3587621102a61e32e70f77.svg" style="height: 17px;" type="image/svg+xml">\forall x\in A, xRx</object>. The <tt class="docutils literal">==</tt> relation shown above is reflexive.</p> <p><strong>Irreflexive</strong>: no element in the set is related to itself, or <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/796bac36bef12c06c76254d8eec1a27251a1cb47.svg" style="height: 17px;" type="image/svg+xml">\neg\exists x\in A, xRx</object>. For example if we define the <tt class="docutils literal">&lt;</tt> less than relation on numbers, it's irreflexive since no number is less than itself. In our boxes example, the &quot;fits in&quot; relation is irreflexive because no box can fit inside itself.</p> <p><strong>Transitive</strong>: intuitively, &quot;if x fits inside y, and y fits inside z, then x fits inside z&quot;. 
Mathematically <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/50b9110f993774dd55916f5caf85f15c01fa100e.svg" style="height: 18px;" type="image/svg+xml">\forall x,y,z \in A, \left(xRy \wedge yRz \right )\rightarrow xRz</object>. The <tt class="docutils literal">&lt;</tt> relation on numbers is obviously transitive.</p> <p><strong>Symmetric</strong>: if x is related to y, then y is related to x. This might sound obvious with the colloquial meaning of &quot;related&quot;, but not in the mathematical sense. Most relations we deal with aren't symmetric. The definition is <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/210765ef171f3c2c0765b6ac8c0f9321776afc0b.svg" style="height: 17px;" type="image/svg+xml">\forall x,y \in A, xRy \rightarrow yRx</object>. For example, the relation <tt class="docutils literal">==</tt> is symmetric, but <tt class="docutils literal">&lt;</tt> is not symmetric.</p> <p><strong>Antisymmetric</strong>: if x is related to y, then y is <em>not</em> related to x unless x and y are the same element; mathematically <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/128db3f8c70b2d33ff35ced497a688f9d9db8706.svg" style="height: 18px;" type="image/svg+xml">\forall x,y \in A, \left(xRy \wedge yRx \right ) \rightarrow x=y</object>. For example, the relation <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/60fd4c42f3956e697cf94397160a51086fbb6f5b.svg" style="height: 15px;" type="image/svg+xml">\le</object> (less than or equal) is antisymmetric; if <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/9671bd6c8271173673b6deb89be8ab5c4fb98511.svg" style="height: 16px;" type="image/svg+xml">x \le y</object> and also <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/1f6f56272ca54db1a574f2402e651dae340f18a5.svg" style="height: 16px;" type="image/svg+xml">y \le x</object> then it must be that x and y are the same number. 
The relation <tt class="docutils literal">&lt;</tt> is also antisymmetric, though only in an empty sense: no pair x and y satisfies the left side of the definition, so the implication holds; in logic such a statement is said to be <em>vacuously true</em>.</p> </div> <div class="section" id="partial-order"> <h2>Partial order</h2> <p>There are two kinds of partial orders we can define - <em>weak</em> and <em>strong</em>. The <em>weak</em> partial order is the more common one, so let's start with that. Whenever I'm saying just &quot;partial order&quot;, I'll mean a weak partial order.</p> <p>A <em>weak partial order</em> (a.k.a. <em>non-strict</em>) is a relation on a set A that is reflexive, transitive and antisymmetric. The <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/60fd4c42f3956e697cf94397160a51086fbb6f5b.svg" style="height: 15px;" type="image/svg+xml">\le</object> relation on numbers is a classical example:</p> <ul class="simple"> <li>It is reflexive because for any number x we have <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/3f43bb5a3905f424fae578127d51f13208cd264a.svg" style="height: 15px;" type="image/svg+xml">x\le x</object></li> <li>It is transitive because given <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/a4e6f960762caa31e09625ed52234681f6abad1e.svg" style="height: 16px;" type="image/svg+xml">x\le y</object> and <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/bdee2aefcb32e8811b5950ad3fb6410888d2e955.svg" style="height: 16px;" type="image/svg+xml">y\le z</object>, we know that <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/8a8f22b672c4ddf78e9a95b4bad6927d0037e980.svg" style="height: 15px;" type="image/svg+xml">x\le z</object></li> <li>It is antisymmetric because given <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/a4e6f960762caa31e09625ed52234681f6abad1e.svg" style="height: 16px;" type="image/svg+xml">x\le
y</object> and <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/1f6f56272ca54db1a574f2402e651dae340f18a5.svg" style="height: 16px;" type="image/svg+xml">y \le x</object>, we know that x and y are the same number</li> </ul> <p>A <em>strong partial order</em> (a.k.a. <em>strict</em>) is a relation on a set A that is irreflexive, transitive and antisymmetric. The difference between weak and strong partial orders is reflexivity. In weak partial orders, every element is related to itself; in strong partial orders, no element is related to itself. The operator &lt; on numbers is an example of a strict partial order, since it satisfies all the properties; while <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/60fd4c42f3956e697cf94397160a51086fbb6f5b.svg" style="height: 15px;" type="image/svg+xml">\le</object> is reflexive, &lt; is irreflexive.</p> <p>Our rectangular boxes with the &quot;fits&quot; relation are a good example for distinguishing between the two. We can only define a <em>strong</em> partial order on them, because a box cannot fit inside itself.</p> <p>Another good example is a morning dressing routine. The set of clothes to wear is {underwear, pants, jacket, shirt, left sock, right sock, left shoe, right shoe}, and the relation is &quot;has to be worn before&quot;. The following drawing encodes the relation:</p> <img alt="Partial ordering of dressing different clothes; what comes before what" class="align-center" src="https://eli.thegreenplace.net/images/2018/dressing-partial-order.png" /> <p>This kind of drawing is called a <a class="reference external" href="https://en.wikipedia.org/wiki/Hasse_diagram">Hasse diagram</a>, which is useful to graphically represent partially ordered sets <a class="footnote-reference" href="#id4" id="id2"></a>; the arrow represents the relation.
For example, the arrow from &quot;pants&quot; to &quot;left shoe&quot; encodes that pants have to be worn before the left shoe.</p> <p>Note that this relation is irreflexive, because it's meaningless to say that &quot;pants have to be worn before wearing pants&quot;. Therefore, the relation defines a <em>strong</em> partial order on the set.</p> <p>Similarly to the rectangular boxes example, the partial order here lets us order only some of the elements in the set w.r.t. each other. Some elements like socks and a shirt don't have an order defined.</p> </div> <div class="section" id="total-order"> <h2>Total order</h2> <p>A total order is a partial order that has one additional property - any two elements in the set should be related. Mathematically:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/33471be0a98a586cd93417bb3c47c3c8e210c01a.svg" style="height: 18px;" type="image/svg+xml"> $\forall x\in A\forall y\in A, \left(xRy \vee yRx \right )$</object> <p>While a partial order lets us order <em>some</em> elements in a set w.r.t. each other, a total order requires us to be able to order <em>all</em> elements in a set. In the boxes example, we can't define a total order for rectangular boxes (there is no &quot;fits in&quot; relation between boxes A and D, no matter which way we try). We <em>can</em> define a total order between square boxes, however, as long as their sizes are unique.</p> <p>Neither can we define a total order for the dressing diagram shown above, because we can't say either &quot;left socks have to be worn before shirts&quot; or &quot;shirts have to be worn before left socks&quot;.</p> </div> <div class="section" id="examples-from-programming"> <h2>Examples from programming</h2> <p>Partial and total orders frequently come up in programming, especially when thinking about sorts. Sorting an array usually implies finding some <em>total order</em> on its elements. Tie breaking is important, but not always possible.
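A quick Python sketch of sorting without tie breaking (the sample records are invented for illustration):

```python
# Sorting records by a single key yields only a weak ordering: records that
# share a key are indistinguishable to the comparison. Python's built-in
# sort is stable, so such ties keep their original relative order.
records = [('banana', 3), ('apple', 2), ('cherry', 3), ('date', 2)]

# Sort by the numeric field only; 'apple'/'date' tie on 2,
# 'banana'/'cherry' tie on 3.
by_count = sorted(records, key=lambda r: r[1])

print(by_count)
# [('apple', 2), ('date', 2), ('banana', 3), ('cherry', 3)]
```

Under this key, a non-stable sort would be free to emit the tied records in either order.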
If there is no way to tell two elements apart, we cannot mathematically come up with a total order, but we can still sort (and we do have a weak partial order). This is where the distinction between regular and <a class="reference external" href="https://en.cppreference.com/w/cpp/algorithm/stable_sort">stable sorts</a> comes in.</p> <p>Sometimes we're sorting non-linear structures, like dependency graphs in the dressing example from above. In these cases a total order is impossible, but we do have a partial order, which can be used to find a &quot;valid&quot; dressing order - a linear sequence of dressing steps that wouldn't violate any constraints. This can be done with <a class="reference external" href="http://eli.thegreenplace.net/2015/directed-graph-traversal-orderings-and-applications-to-data-flow-analysis/">topological sorting</a>, which finds a valid &quot;linearization&quot; of the dependency graph.</p> <hr class="docutils" /> <table class="docutils footnote" frame="void" id="id3" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id1"></a></td><td>You may notice that saying &quot;unique&quot; when talking about sets can sound superfluous; after all, sets are defined to have distinct elements. That said, it's not clear what &quot;distinct&quot; means. In our case, distinct can refer to the complete identities of the boxes; for example, two boxes can have the exact same dimensions but different colors - so they are not the same as far as the set is concerned. Moreover, in programming, identity is even more nuanced and can be defined for specific types in specific ways.
For these reasons I'm going to call out uniqueness explicitly to avoid confusion.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id4" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id2"></a></td><td>A <em>partially ordered set with R</em> (or <em>poset with R</em>) is a set with a relation R that is a partial order on it.</td></tr> </tbody> </table> </div> Minimal character-based LSTM implementation2018-06-07T05:34:00-07:002018-06-07T05:34:00-07:00Eli Benderskytag:eli.thegreenplace.net,2018-06-07:/2018/minimal-character-based-lstm-implementation/<p>Following up on <a class="reference external" href="https://eli.thegreenplace.net/2018/understanding-how-to-implement-a-character-based-rnn-language-model/">the earlier post</a> deciphering a minimal vanilla RNN implementation, here I'd like to extend the example to a simple LSTM model.</p> <p>Once again, the idea is to combine a well-commented code sample (<a class="reference external" href="https://github.com/eliben/deep-learning-samples/blob/master/min-char-rnn/min-char-lstm.py">available here</a>) with some high-level diagrams and math to enable someone to fully understand the …</p><p>Following up on <a class="reference external" href="https://eli.thegreenplace.net/2018/understanding-how-to-implement-a-character-based-rnn-language-model/">the earlier post</a> deciphering a minimal vanilla RNN implementation, here I'd like to extend the example to a simple LSTM model.</p> <p>Once again, the idea is to combine a well-commented code sample (<a class="reference external" href="https://github.com/eliben/deep-learning-samples/blob/master/min-char-rnn/min-char-lstm.py">available here</a>) with some high-level diagrams and math to enable someone to fully understand the code. 
The LSTM architecture presented herein is the standard one originating from Hochreiter and Schmidhuber's <a class="reference external" href="https://www.google.com/search?q=lstm+hochreiter">1997 paper</a>. It's described pretty much everywhere; <a class="reference external" href="http://colah.github.io/posts/2015-08-Understanding-LSTMs/">Chris Olah's post</a> has particularly nice diagrams and is worth reading.</p> <div class="section" id="lstm-cell-structure"> <h2>LSTM cell structure</h2> <p>From 30,000 feet, LSTMs look just like regular RNNs; there's a &quot;cell&quot; that has a recurrent connection (output tied to input), and when trained this cell is usually unrolled to some fixed length.</p> <p>So we can take the basic RNN structure from the <a class="reference external" href="https://eli.thegreenplace.net/2018/understanding-how-to-implement-a-character-based-rnn-language-model">previous post</a>:</p> <img alt="Basic RNN diagram" class="align-center" src="https://eli.thegreenplace.net/images/2018/rnnbasic.png" /> <p>LSTMs are a bit trickier because there are two recurrent connections; these can be &quot;packed&quot; into a single vector <em>h</em>, so the above diagram still applies. Here's how an LSTM cell looks inside:</p> <img alt="LSTM cell" class="align-center" src="https://eli.thegreenplace.net/images/2018/lstm-cell.png" /> <p><em>x</em> is the input; <em>p</em> holds the probabilities computed from the output <em>y</em> (these symbols are named consistently with my earlier RNN post) and exits the cell at the bottom purely for topological convenience. The two memory vectors are <em>h</em> and <em>c</em> - as mentioned earlier, they could be combined into a single vector, but are shown here separately for clarity.</p> <p>The main idea of LSTMs is to enable training of longer sequences by providing a &quot;fast-path&quot; to back-propagate information farther down in memory. Hence the <em>c</em> vector is not multiplied by any matrices on its path.
The circle-in-circle block means element-wise multiplication of two vectors; plus-in-square is element-wise addition. The funny Greek letter is the Sigmoid non-linearity:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/8b0db8368e8a617143fa6566f42c1e47cd833c9c.svg" style="height: 38px;" type="image/svg+xml"> $\sigma(x) =\frac{1}{1+e^{-x}}$</object> <p>The only other block we haven't seen in the vanilla RNN diagram is the colon-in-square in the bottom-left corner; this is simply the concatenation of <em>h</em> and <em>x</em> into a single column vector. In addition, I've combined the &quot;multiply by matrix <em>W</em>, then add bias <em>b</em>&quot; operation into a single rectangular box to save on precious diagram space.</p> <p>Here are the equations computed by a cell:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/c2cc966ba7ce8075317b87885bc9c432aafe2dba.svg" style="height: 249px;" type="image/svg+xml"> \begin{align*} xh&amp;=x^{[t]}:h^{[t-1]}\\ f&amp;=\sigma(W_f\cdot xh+b_f)\\ i&amp;=\sigma(W_i\cdot xh+b_i)\\ o&amp;=\sigma(W_o\cdot xh+b_o)\\ cc&amp;=tanh(W_{cc}\cdot xh+b_{cc})\\ c^{[t]}&amp;=c^{[t-1]}\odot f +cc\odot i\\ h^{[t]}&amp;=tanh(c^{[t]})\odot o\\ y^{[t]}&amp;=W_{y}\cdot h^{[t]}+b_y\\ p^{[t]}&amp;=softmax(y^{[t]})\\ \end{align*}</object> </div> <div class="section" id="backpropagating-through-an-lstm-cell"> <h2>Backpropagating through an LSTM cell</h2> <p>This works <em>exactly</em> like backprop through a vanilla RNN; we have to carefully compute how the gradient flows through every node and make sure we properly combine gradients at fork points. Most of the elements in the LSTM diagram are familiar from the <a class="reference external" href="https://eli.thegreenplace.net/2018/understanding-how-to-implement-a-character-based-rnn-language-model">previous post</a>.
Let's briefly work through the new ones.</p> <p>First, the Sigmoid function; it's an elementwise function, and computing its derivative is very similar to the <em>tanh</em> function discussed in the previous post. As usual, given <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/e9ef6bd037537d5fe08743736acadccc09e70b06.svg" style="height: 18px;" type="image/svg+xml">f=\sigma(k)</object>, from the chain rule we have the following derivative w.r.t. some weight <em>w</em>:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/57e3f2cab3c9b46a03d763a2f73b83963a1cd500.svg" style="height: 39px;" type="image/svg+xml"> $\frac{\partial f}{\partial w}=\frac{\partial \sigma(k)}{\partial k}\frac{\partial k}{\partial w}$</object> <p>To compute the derivative <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/8aa59f2f536b727cf97239b345ddcc98e41c2c91.svg" style="height: 26px;" type="image/svg+xml">\frac{\partial \sigma(k)}{\partial k}</object>, we'll use the ratio-derivative formula:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/9e006cf5e9f1f8ccac82ba1f2bcdabd710731756.svg" style="height: 42px;" type="image/svg+xml"> $(\frac{f}{g})&#x27;=\frac{f&#x27;g-g&#x27;f}{g^2}$</object> <p>So:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/e3f7af782f52215e8326b389271709a440993984.svg" style="height: 44px;" type="image/svg+xml"> $\sigma &#x27;(k)=\frac{e^{-k}}{(1+e^{-k})^2}$</object> <p>A clever way to express this is:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/eb1953be928287ff01ae23dfb4ff1cb2290854c9.svg" style="height: 20px;" type="image/svg+xml"> $\sigma &#x27;(k)=\sigma(k)(1-\sigma(k))$</object> <p>Going back to the chain rule with <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/e9ef6bd037537d5fe08743736acadccc09e70b06.svg" style="height: 18px;"
type="image/svg+xml">f=\sigma(k)</object>, we get:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/885829ecab969c96daed7f0df6e5864339ad9d8b.svg" style="height: 38px;" type="image/svg+xml"> $\frac{\partial f}{\partial w}=f(1-f)\frac{\partial k}{\partial w}$</object> <p>The other new operation we'll have to find the derivative of is element-wise multiplication. Let's say we have the column vectors <em>x</em>, <em>y</em> and <em>z</em>, each with <em>m</em> rows, and we have <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/660b1e0dacc15aa3737b8170c3ecfdcbc6e77db4.svg" style="height: 18px;" type="image/svg+xml">z(x)=x\odot y</object>. Since <em>z</em> as a function of <em>x</em> has <em>m</em> inputs and <em>m</em> outputs, its Jacobian has dimensions [m,m].</p> <p><object class="valign-m6" data="https://eli.thegreenplace.net/images/math/0ab96cb4e5d8c6ba3ac8038fda07d518bbe1f388.svg" style="height: 18px;" type="image/svg+xml">D_{j}z_{i}</object> is the derivative of the i-th element of <em>z</em> w.r.t. the j-th element of <em>x</em>. 
For <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/660b1e0dacc15aa3737b8170c3ecfdcbc6e77db4.svg" style="height: 18px;" type="image/svg+xml">z(x)=x\odot y</object> this is non-zero only when <em>i</em> and <em>j</em> are equal, and in that case the derivative is <img alt="y_i" class="valign-m4" src="https://eli.thegreenplace.net/images/math/35c2ac2f82d0ff8f9011b596ed7e54bfcc55f471.png" style="height: 12px;" />.</p> <p>Therefore, <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/e6631f3b13f877a8bb7b3b6a0c0d2ca110ecce23.svg" style="height: 18px;" type="image/svg+xml">Dz(x)</object> is a square matrix with the elements of <em>y</em> on the diagonal and zeros elsewhere:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/2450b2e2a827054f5d292822ff292eaa63c77d1b.svg" style="height: 97px;" type="image/svg+xml"> $Dz=\begin{bmatrix} y_1 &amp; 0 &amp; \cdots &amp; 0 \\ 0 &amp; y_2 &amp; \cdots &amp; 0 \\ \vdots &amp; \ddots &amp; \ddots &amp; \vdots \\ 0 &amp; 0 &amp; \cdots &amp; y_m \\ \end{bmatrix}$</object> <p>If we want to backprop some loss <em>L</em> through this function, we get:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/48b17da284ae52bc4b9fdeb7b98b73f398bd4458.svg" style="height: 38px;" type="image/svg+xml"> $\frac{\partial L}{\partial x}=\frac{\partial L}{\partial z}Dz$</object> <p>As <em>x</em> has <em>m</em> elements, the right-hand side of this equation multiplies a [1,m] vector by a [m,m] matrix which is diagonal, resulting in element-wise multiplication with the matrix's diagonal elements. 
In other words:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/e2a6c0742fb006e35e3001d3b3d33f78316fb1e8.svg" style="height: 38px;" type="image/svg+xml"> $\frac{\partial L}{\partial x}=\frac{\partial L}{\partial z}\odot y$</object> <p>In code, it looks like this:</p> <div class="highlight"><pre><span></span><span class="c1"># Assuming dz is the gradient of loss w.r.t. z; dz, y and dx are all</span> <span class="c1"># column vectors.</span> <span class="n">dx</span> <span class="o">=</span> <span class="n">dz</span> <span class="o">*</span> <span class="n">y</span> </pre></div> </div> <div class="section" id="model-quality"> <h2>Model quality</h2> <p>In the <a class="reference external" href="https://eli.thegreenplace.net/2018/understanding-how-to-implement-a-character-based-rnn-language-model/">post about min-char-rnn</a>, we've seen that the vanilla RNN generates fairly low-quality text:</p> <blockquote> one, my dred, roriny. qued bamp gond hilves non froange saws, to mold his a work, you shirs larcs anverver strepule thunboler muste, thum and cormed sightourd so was rewa her besee pilman</blockquote> <p>The LSTM's generated text quality is somewhat better when trained with roughly the same hyper-parameters:</p> <blockquote> the she, over is was besiving the fact to seramed for i said over he will round, such when a where, &quot;i went of where stood it at eye heardul rrawed only coside the showed had off with the refaurtoned</blockquote> <p>I'm fairly sure that it can be made to perform even better with larger memory vectors and more training data. That said, an even more advanced architecture can be helpful here.
Moreover, since this is a <em>character</em>-based model, to really capture effects between words a few words apart we'll need a much deeper LSTM (since I'm unrolling to 16 characters, we can only capture 2-3 words), and hence much more training data and time.</p> <p>Once again, the goal here is not to develop a state-of-the-art language model, but to show a simple, comprehensible example of how an LSTM is implemented end-to-end in Python code. <a class="reference external" href="https://github.com/eliben/deep-learning-samples/blob/master/min-char-rnn/min-char-lstm.py">The full code is here</a> - please let me know if you find any issues with it or something still remains unclear.</p> </div> Understanding how to implement a character-based RNN language model2018-05-25T05:20:00-07:002018-05-25T05:20:00-07:00Eli Benderskytag:eli.thegreenplace.net,2018-05-25:/2018/understanding-how-to-implement-a-character-based-rnn-language-model/<p>In <a class="reference external" href="https://gist.github.com/karpathy/d4dee566867f8291f086">a single gist</a>, <a class="reference external" href="https://cs.stanford.edu/people/karpathy/">Andrej Karpathy</a> did something truly impressive. In a little over 100 lines of Python - without relying on any heavy-weight machine learning frameworks - he presents a fairly complete implementation of training a character-based recurrent neural network (RNN) language model; this includes the full backpropagation learning with Adagrad …</p><p>In <a class="reference external" href="https://gist.github.com/karpathy/d4dee566867f8291f086">a single gist</a>, <a class="reference external" href="https://cs.stanford.edu/people/karpathy/">Andrej Karpathy</a> did something truly impressive.
In a little over 100 lines of Python - without relying on any heavy-weight machine learning frameworks - he presents a fairly complete implementation of training a character-based recurrent neural network (RNN) language model; this includes the full backpropagation learning with Adagrad optimization.</p> <p>I love such minimal examples because they allow me to understand some topic in full depth, connecting the math to the code and having a complete picture of how everything works. In this post I want to present a companion explanation to Karpathy's gist, showing the diagrams and math that hide in its Python code.</p> <p>My own fork of the code <a class="reference external" href="https://github.com/eliben/deep-learning-samples/blob/master/min-char-rnn/min-char-rnn.py">is here</a>; it's semantically equivalent to Karpathy's gist, but includes many more comments and some debugging options. I won't reproduce the whole program here; instead, the idea is that you'd go through the code while reading this article. The diagrams, formulae and explanations here are complementary to the code comments.</p> <div class="section" id="what-rnns-do"> <h2>What RNNs do</h2> <p>I expect readers to have a basic idea of what RNNs do and why they work well for some problems. RNNs are well-suited for problem domains where the input (and/or output) is some sort of a sequence - time-series financial data, words or sentences in natural language, speech, etc.</p> <p>There is <em>a lot</em> of material about this online, and the basics are easy to understand for anyone with even a bit of machine learning background.
However, there is not enough coherent material online about how RNNs are implemented and trained - this is the goal of this post.</p> </div> <div class="section" id="character-based-rnn-language-model"> <h2>Character-based RNN language model</h2> <p>The basic structure of <tt class="docutils literal"><span class="pre">min-char-rnn</span></tt> is represented by this recurrent diagram, where <em>x</em> is the input vector (at time step <em>t</em>), <em>y</em> is the output vector and <em>h</em> is the <em>state vector</em> kept inside the model.</p> <img alt="Basic RNN diagram" class="align-center" src="https://eli.thegreenplace.net/images/2018/rnnbasic.png" /> <p>The line leaving and returning to the cell represents that the state is retained between invocations of the network. When a new time step arrives, some things are still the same (the weights inherent to the network, as we shall soon see) but some things are different - <em>h</em> may have changed. Therefore, unlike stateless NNs, <em>y</em> is not simply a function of <em>x</em>; in RNNs, identical <em>x</em>s can produce different <em>y</em>s, because <em>y</em> is a function of <em>x</em> and <em>h</em>, and <em>h</em> can change between steps.</p> <p>The <em>character-based</em> part of the model's name means that every input vector represents a single character (as opposed to, say, a word or part of an image). <tt class="docutils literal"><span class="pre">min-char-rnn</span></tt> uses one-hot vectors to represent different characters.</p> <p>A <em>language model</em> is a particular kind of machine learning algorithm that learns the statistical structure of language by &quot;reading&quot; a large corpus of text. 
This model can then reproduce authentic language segments - by predicting the next character (or word, for word-based models) based on past characters.</p> </div> <div class="section" id="internal-structure-of-the-rnn-cell"> <h2>Internal structure of the RNN cell</h2> <p>Let's proceed by looking into the internal structure of the RNN cell in <tt class="docutils literal"><span class="pre">min-char-rnn</span></tt>:</p> <img alt="RNN cell for min-char-rnn" class="align-center" src="https://eli.thegreenplace.net/images/2018/min-char-rnn-cell.png" /> <ul class="simple"> <li>Bold-faced symbols in reddish color are the model's parameters, weights for matrix multiplication and biases.</li> <li>The state vector <em>h</em> is shown twice - once for its past value, and once for its currently computed value. Whenever the RNN cell is invoked in sequence, the last computed state <em>h</em> is passed in from the left.</li> <li>In this diagram <em>y</em> is not the final answer of the cell - we compute a softmax function on it to obtain <em>p</em> - the probabilities for output characters <a class="footnote-reference" href="#id7" id="id1"></a>.
I'm using these symbols for consistency with the code of <tt class="docutils literal"><span class="pre">min-char-rnn</span></tt>, though it would probably be more readable to flip the uses of <em>p</em> and <em>y</em> (making <em>y</em> the actual output of the cell).</li> </ul> <p>Mathematically, this cell computes:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/886e94d526c2e538f1ba4414696ae9bf6618f0ff.svg" style="height: 82px;" type="image/svg+xml"> \begin{align*} h^{[t]}&amp;=tanh(W_{hh}\cdot h^{[t-1]}+W_{xh}\cdot x^{[t]}+b_h)\\ y^{[t]}&amp;=W_{hy}\cdot h^{[t]}+b_y\\ p^{[t]}&amp;=softmax(y^{[t]}) \end{align*}</object> </div> <div class="section" id="learning-model-parameters-with-backpropagation"> <h2>Learning model parameters with backpropagation</h2> <p>This section will examine how we can <em>learn</em> the parameters <em>W</em> and <em>b</em> for this model. Mostly it's standard neural-network fare; we'll compute the derivatives of all the steps involved and will then employ backpropagation to find a parameter update based on some computed loss.</p> <p>There's one serious issue we'll have to address first. Backpropagation is usually defined on <em>acyclic</em> graphs, so it's not entirely clear how to apply it to our RNN. Is <em>h</em> an input? An output? Both? In the original high-level diagram of the RNN cell, <em>h</em> is both an input and an output - how can we compute the gradient for it when we don't know its value yet? <a class="footnote-reference" href="#id8" id="id2"></a></p> <p>The way out of this conundrum is to <em>unroll</em> the RNN for a few steps. 
Note that we're already doing this in the detailed diagram by distinguishing between <object class="valign-0" data="https://eli.thegreenplace.net/images/math/057276c060e575533321773afb483e778e6a03f1.svg" style="height: 16px;" type="image/svg+xml">h^{[t]}</object> and <object class="valign-0" data="https://eli.thegreenplace.net/images/math/e4bc0503e20a8e6b82d9c86e10eb2c8e1dfe3471.svg" style="height: 16px;" type="image/svg+xml">h^{[t-1]}</object>. This makes every RNN cell <em>locally acyclic</em>, which makes it possible to use backpropagation on it. This approach has a cool-sounding name - <em>Backpropagation Through Time</em> (BPTT) - although it's really the same as regular backpropagation.</p> <p>Note that the architecture used here is called &quot;synced many-to-many&quot; in Karpathy's <a class="reference external" href="http://karpathy.github.io/2015/05/21/rnn-effectiveness/">Unreasonable Effectiveness of RNNs post</a>, and it's useful for training a simple char-based language model - we immediately observe the output sequence produced by the model while reading the input. 
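As a concrete reminder of what each unrolled step computes, here's the cell's forward pass sketched in numpy (a sketch with variable names of my own choosing, mirroring the math above rather than the exact code of <tt class="docutils literal"><span class="pre">min-char-rnn</span></tt>):

```python
import numpy as np

def rnn_step_forward(x, hprev, Wxh, Whh, Why, bh, by):
    # x is a one-hot column vector; hprev is the state from the previous step.
    h = np.tanh(np.dot(Whh, hprev) + np.dot(Wxh, x) + bh)
    y = np.dot(Why, h) + by
    # Softmax turns the raw scores y into the probability vector p.
    e = np.exp(y - np.max(y))  # shift by max(y) for numerical stability
    p = e / np.sum(e)
    return h, y, p
```

During training, this function would be called once per unrolled step, threading each step's <tt class="docutils literal"><span class="pre">h</span></tt> into the next call.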
Similar unrolling can be applied to other architectures, like encoder-decoder.</p> <p>Here's our RNN again, unrolled for 3 steps:</p> <img alt="Unrolled RNN diagram" class="align-center" src="https://eli.thegreenplace.net/images/2018/rnnunroll.png" /> <p>Now the same diagram, with the gradient flows depicted with orange-ish arrows:</p> <img alt="Unrolled RNN diagram with gradient flow arrows shown" class="align-center" src="https://eli.thegreenplace.net/images/2018/rnnunrollgrad.png" /> <p>With this unrolling, we have everything we need to compute the actual weight updates during learning, because when we want to compute the gradients through step 2, we already have the incoming gradient <object class="valign-m5" data="https://eli.thegreenplace.net/images/math/41bad72882e3d266373df060e8ab3ce36a819679.svg" style="height: 18px;" type="image/svg+xml">\Delta h</object>, and so on.</p> <p>You may now wonder: what is <object class="valign-m5" data="https://eli.thegreenplace.net/images/math/0bdf36986644d54bc1bccf1410d2b9f0f86cf697.svg" style="height: 18px;" type="image/svg+xml">\Delta h[t]</object> for the final step at time <em>t</em>?</p> <p>In some models, sequence lengths are fairly limited. For example, when we translate a single sentence, the sequence length is rarely over a couple dozen words; for such models we can fully unroll the RNN. The <em>h</em> state output of the final step doesn't really &quot;go anywhere&quot;, and we assume its gradient is zero. Similarly, the incoming state <em>h</em> for the first step is zero.</p> <p>Other models work on potentially infinite sequence lengths, or sequences much too long for unrolling. The language model in <tt class="docutils literal"><span class="pre">min-char-rnn</span></tt> is a good example, because it can theoretically ingest and emit text of any length. For these models we'll perform <em>truncated</em> BPTT, by just assuming that the influence of the current state extends only <em>N</em> steps into the future.
We'll then unroll the model <em>N</em> times and assume that <object class="valign-m5" data="https://eli.thegreenplace.net/images/math/9851a3637afe3f6d70466ac3a1d1c104935647fd.svg" style="height: 18px;" type="image/svg+xml">\Delta h[N]</object> is zero. Although it really isn't, for a large enough <em>N</em> this is a fairly safe assumption. RNNs are hard to train on very long sequences for other reasons, anyway (we'll touch upon this point again towards the end of the post).</p> <p>Finally, it's important to remember that although we unroll the RNN cells, all parameters (weights, biases) are <em>shared</em>. This plays an important part in ensuring <em>translation invariance</em> for the models - patterns learned in one place apply to another place <a class="footnote-reference" href="#id9" id="id3"></a>. It leaves the question of how to update the weights, since we compute gradients for them separately in each step. The answer is very simple - just add them up. This is similar to other cases where the output of a cell branches off in two directions - when gradients are computed, their values are added up along the branches - this is just the basic chain rule in action.</p> <p>We now have all the necessary background to understand how an RNN learns. 
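Schematically, truncated BPTT with shared parameters can be sketched like this (a simplified sketch of my own; the per-step forward/backward computations are abstracted behind hypothetical callbacks rather than written out as in the real code):

```python
import numpy as np

def truncated_bptt_sketch(xs, h0, params, step_forward, step_backward):
    """Unroll over the sequence xs; per-step parameter gradients are summed."""
    h, caches = h0, []
    for x in xs:                              # forward through the unrolled steps
        h, cache = step_forward(x, h, params)
        caches.append(cache)
    grads = {k: np.zeros_like(v) for k, v in params.items()}
    dh = np.zeros_like(h0)                    # assume the grad of the final h is zero
    for cache in reversed(caches):            # backward pass, last step first
        dh, step_grads = step_backward(cache, dh, params)
        for k, g in step_grads.items():
            grads[k] += g                     # shared parameters: just add them up
    return grads
```

The actual code in <tt class="docutils literal"><span class="pre">min-char-rnn</span></tt> follows this overall shape, with the per-step computations spelled out explicitly.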
What remains before looking at the code is figuring out how the gradients propagate <em>inside</em> the cell; in other words, the derivatives of each operation comprising the cell.</p> </div> <div class="section" id="flowing-the-gradient-inside-an-rnn-cell"> <h2>Flowing the gradient inside an RNN cell</h2> <p>As we saw above, the formulae for computing the cell's output from its inputs are:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/886e94d526c2e538f1ba4414696ae9bf6618f0ff.svg" style="height: 82px;" type="image/svg+xml"> \begin{align*} h^{[t]}&amp;=tanh(W_{hh}\cdot h^{[t-1]}+W_{xh}\cdot x^{[t]}+b_h)\\ y^{[t]}&amp;=W_{hy}\cdot h^{[t]}+b_y\\ p^{[t]}&amp;=softmax(y^{[t]}) \end{align*}</object> <p>To be able to learn weights, we have to find the derivatives of the cell's output w.r.t. the weights. The full backpropagation process was explained <a class="reference external" href="http://eli.thegreenplace.net/2016/the-chain-rule-of-calculus/">in this post</a>, so here is only a brief refresher.</p> <p>Recall that <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/f75a9c33c546d725557a4d452769bfd8fbb6cc22.svg" style="height: 20px;" type="image/svg+xml">p^{[t]}</object> is the predicted output; we compare it with the &quot;real&quot; output (<object class="valign-0" data="https://eli.thegreenplace.net/images/math/e44181afdf5e5f0f8ad4379f7d5f3ff924379c82.svg" style="height: 16px;" type="image/svg+xml">r^{[t]}</object>) during learning, to find the loss (error):</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/788a210d4bab7831a28e0ae7713ff9c1cd5aef12.svg" style="height: 22px;" type="image/svg+xml"> $L=L(p^{[t]}, r^{[t]})$</object> <p>To perform a gradient descent update, we'll need to find <object class="valign-m7" data="https://eli.thegreenplace.net/images/math/c9e2c4ffca9564929c45a5244c7fb064465ab005.svg" style="height: 24px;" type="image/svg+xml">\frac{\partial L}{\partial 
w}</object>, for every weight value <em>w</em>. To do this, we'll have to:</p> <ol class="arabic simple"> <li>Find the &quot;local&quot; gradients for every mathematical operation leading from <em>w</em> to <em>L</em>.</li> <li>Use the chain rule to propagate the error backwards through these local gradients until we find <object class="valign-m7" data="https://eli.thegreenplace.net/images/math/c9e2c4ffca9564929c45a5244c7fb064465ab005.svg" style="height: 24px;" type="image/svg+xml">\frac{\partial L}{\partial w}</object>.</li> </ol> <p>We start by formulating the chain rule to compute <object class="valign-m7" data="https://eli.thegreenplace.net/images/math/c9e2c4ffca9564929c45a5244c7fb064465ab005.svg" style="height: 24px;" type="image/svg+xml">\frac{\partial L}{\partial w}</object>:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/45ad1052f6c6b78265143f4d41f2f12f1714ebfb.svg" style="height: 45px;" type="image/svg+xml"> $\frac{\partial L}{\partial w}=\frac{\partial L}{\partial p^{[t]}}\frac{\partial p^{[t]}}{\partial w}$</object> <p>Next comes:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/f7a68f105f4e7483f2781a7bebeaad0ce659bf06.svg" style="height: 45px;" type="image/svg+xml"> $\frac{\partial p^{[t]}}{\partial w}=\frac{\partial softmax}{\partial y^{[t]}}\frac{\partial y^{[t]}}{\partial w}$</object> <p>Let's say the weight <em>w</em> we're interested in is part of <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/5b9174fc1cf8afbecdab52326985d41be6fbc2c8.svg" style="height: 15px;" type="image/svg+xml">W_{hh}</object>, so we have to propagate some more:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/53bcf70971d45064242463ddfad70e3ba6fb0ec9.svg" style="height: 42px;" type="image/svg+xml"> $\frac{\partial y^{[t]}}{\partial w}=\frac{\partial y^{[t]}}{\partial h^{[t]}}\frac{\partial h^{[t]}}{\partial w}$</object> <p>We'll then proceed to propagate 
through the <em>tanh</em> function, bias addition and finally the multiplication by <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/5b9174fc1cf8afbecdab52326985d41be6fbc2c8.svg" style="height: 15px;" type="image/svg+xml">W_{hh}</object>, for which the derivative by <em>w</em> is computed directly without further chaining.</p> <p>Let's now see how to compute all the relevant local gradients.</p> </div> <div class="section" id="cross-entropy-loss-gradient"> <h2>Cross-entropy loss gradient</h2> <p>We'll start with the derivative of the loss function, which is cross-entropy in the <tt class="docutils literal"><span class="pre">min-char-rnn</span></tt> model. I went through a detailed derivation of the gradient of softmax followed by cross-entropy in <a class="reference external" href="http://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative">this post</a>; here is only a brief recap:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/b26f68a12667ba254facf9815252f52ebf2238d9.svg" style="height: 38px;" type="image/svg+xml"> $xent(p,q)=-\sum_{k}p(k)log(q(k))$</object> <p>Re-formulating this for our specific case, the loss is a function of <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/f75a9c33c546d725557a4d452769bfd8fbb6cc22.svg" style="height: 20px;" type="image/svg+xml">p^{[t]}</object>, assuming the &quot;real&quot; class <em>r</em> is constant for every training example:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/9ff2ef0e3dbe188129b93dddeb12759fdf909bcb.svg" style="height: 39px;" type="image/svg+xml"> $L(p^{[t]})=-\sum_{k}r(k)log(p^{[t]}(k))$</object> <p>Since inputs and outputs to the cell are 1-hot encoded, let's just use <em>r</em> to denote the index where <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/20aea9dc9718c5f2b3e11b3ebec11518202f0af1.svg" style="height: 18px;" 
type="image/svg+xml">r(k)</object> is non-zero. Then the Jacobian of <em>L</em> is only non-zero at index <em>r</em> and its value there is <object class="valign-m11" data="https://eli.thegreenplace.net/images/math/c4efb22a708d798abd641a16679976b8829f500d.svg" style="height: 27px;" type="image/svg+xml">-\frac{1}{p^{[t]}}(r)</object>.</p> </div> <div class="section" id="softmax-gradient"> <h2>Softmax gradient</h2> <p>A detailed computation of the gradient for the softmax function was also presented in <a class="reference external" href="http://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative">this post</a>. For <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/a197fbc6c2f0e9e1d6b4c51c6fca2756927a3055.svg" style="height: 18px;" type="image/svg+xml">S(y)</object> being the softmax of <em>y</em>, the Jacobian is:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/87fbe94e6b409d31b512cb7a4581c24907d4dd4a.svg" style="height: 42px;" type="image/svg+xml"> $D_{j}S_{i}=\frac{\partial S_i}{\partial y_j}=S_{i}(\delta_{ij}-S_j)$</object> </div> <div class="section" id="fully-connected-layer-gradient"> <h2>Fully-connected layer gradient</h2> <p>Next on our path backwards is:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/ead3bdf11cf41b04164a83008db4a7dd0db5a074.svg" style="height: 24px;" type="image/svg+xml"> $y^{[t]}&amp;=W_{hy}\cdot h^{[t]}+b_y$</object> <p>From my earlier <a class="reference external" href="http://eli.thegreenplace.net/2018/backpropagation-through-a-fully-connected-layer/">post on backpropagating through a fully-connected layer</a>, we know that <object class="valign-m9" data="https://eli.thegreenplace.net/images/math/413d530fbd3e019cc3f49aec6e8f7cb7a8f0c622.svg" style="height: 29px;" type="image/svg+xml">\frac{\partial y^{[t]}}{\partial h^{[t]}}=W_{hy}</object>. 
But that's not all; note that on the forward pass <object class="valign-0" data="https://eli.thegreenplace.net/images/math/057276c060e575533321773afb483e778e6a03f1.svg" style="height: 16px;" type="image/svg+xml">h^{[t]}</object> splits in two - one edge goes into the fully-connected layer, another goes to the next RNN cell as the state. When we backpropagate the loss gradient to <object class="valign-0" data="https://eli.thegreenplace.net/images/math/057276c060e575533321773afb483e778e6a03f1.svg" style="height: 16px;" type="image/svg+xml">h^{[t]}</object>, we have to take both edges into account; more specifically, we have to <em>add</em> the gradients along the two edges. This leads to the following backpropagation equation:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/e7d5afe050e8e2f3f4b867ae4d9eb510fbe2e583.svg" style="height: 45px;" type="image/svg+xml"> $\frac{\partial L}{\partial h^{[t]}} = \frac{\partial y^{[t]}}{\partial h^{[t]}}\frac{\partial L}{\partial y^{[t]}}+\frac{\partial L}{\partial h^{[t+1]}}\frac{\partial h^{[t+1]}}{\partial h^{[t]}} =W_{hy}\cdot \frac{\partial L}{\partial y^{[t]}}+\frac{\partial L}{\partial h^{[t+1]}}\frac{\partial h^{[t+1]}}{\partial h^{[t]}}$</object> <p>In addition, note that this layer already has model parameters that need to be learned - <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/20c1b19b6e71072b92080b2eb00b5b99123cf057.svg" style="height: 18px;" type="image/svg+xml">W_{hy}</object> and <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/9bd872acdafb9ea752b3ba10b2670499cb65469f.svg" style="height: 19px;" type="image/svg+xml">b_y</object> - a &quot;final&quot; destination for backpropagation. 
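In numpy-style code, this backward step could look roughly as follows (a sketch; <tt class="docutils literal"><span class="pre">dy</span></tt> is the loss gradient w.r.t. <em>y</em> at this step and <tt class="docutils literal"><span class="pre">dhnext</span></tt> is the gradient arriving from the next time step - the names are mine, chosen to resemble the code):

```python
import numpy as np

def fc_backward_sketch(dy, h, Why, dhnext):
    # Gradients for the layer's own parameters (dy and h are column vectors).
    dWhy = np.dot(dy, h.T)
    dby = dy
    # h[t] feeds both this layer and the next cell's state input, so the
    # gradients arriving along the two edges are added.
    dh = np.dot(Why.T, dy) + dhnext
    return dWhy, dby, dh
```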
Please refer to my fully-connected layer backpropagation post to see how the gradients for these are computed.</p> </div> <div class="section" id="gradient-of-tanh"> <h2>Gradient of tanh</h2> <p>The vector <object class="valign-0" data="https://eli.thegreenplace.net/images/math/057276c060e575533321773afb483e778e6a03f1.svg" style="height: 16px;" type="image/svg+xml">h^{[t]}</object> is produced by applying a hyperbolic tangent nonlinearity to another fully connected layer.</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/4ce55619f9cd4d96083ec3dadf303cc83a426543.svg" style="height: 22px;" type="image/svg+xml"> $h^{[t]}&amp;=tanh(W_{hh}\cdot h^{[t-1]}+W_{xh}\cdot x^{[t]}+b_h)$</object> <p>To get to the model parameters <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/5b9174fc1cf8afbecdab52326985d41be6fbc2c8.svg" style="height: 15px;" type="image/svg+xml">W_{hh}</object>, <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/4ee22236c608ad6f49adc4807465b6e6896092ec.svg" style="height: 15px;" type="image/svg+xml">W_{xh}</object> and <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/14e5d8599f43750d0cf9dda2d90b085c69079049.svg" style="height: 16px;" type="image/svg+xml">b_h</object>, we have to first backpropagate the loss gradient through <em>tanh</em>. 
<em>tanh</em> is a scalar function; when it's applied to a vector we apply it in <em>element-wise</em> fashion to every element in the vector independently, and collect the results in a similarly-shaped result vector.</p> <p>Its mathematical definition is:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/326a49518bfe326c6be2de37838971407fa5175d.svg" style="height: 39px;" type="image/svg+xml"> $tanh(x)=\frac{e^x-e^{-x}}{e^x+e^{-x}}$</object> <p>To find the derivative of this function, we'll use the quotient rule for derivatives:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/9e006cf5e9f1f8ccac82ba1f2bcdabd710731756.svg" style="height: 42px;" type="image/svg+xml"> $(\frac{f}{g})&#x27;=\frac{f&#x27;g-g&#x27;f}{g^2}$</object> <p>So:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/dc73b540394a92b823ec3eaabe4d02a7735f146f.svg" style="height: 43px;" type="image/svg+xml"> $tanh&#x27;(x)=\frac{(e^x+e^{-x})(e^x+e^{-x})-(e^x-e^{-x})(e^x-e^{-x})}{(e^x+e^{-x})^2}=1-(tanh(x))^2$</object> <p>Just like for softmax, it turns out that there's a convenient way to express the derivative of <em>tanh</em> in terms of <em>tanh</em> itself. When we apply the chain rule to derivatives of <em>tanh</em> - for example, to <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/7f748d5017b1817f6d3912d339e85871b81d93b4.svg" style="height: 18px;" type="image/svg+xml">h=tanh(k)</object> where <em>k</em> is a function of <em>w</em> - we get:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/9137818273fdbac7d5dc0e05df4fcf3f8cb7ea9d.svg" style="height: 39px;" type="image/svg+xml"> $\frac{\partial h}{\partial w}=\frac{\partial tanh(k)}{\partial k}\frac{\partial k}{\partial w}=(1-h^2)\frac{\partial k}{\partial w}$</object> <p>In our case <em>k(w)</em> is a fully-connected layer; to find its derivatives w.r.t.
the weight matrices and bias, please refer to the <a class="reference external" href="http://eli.thegreenplace.net/2018/backpropagation-through-a-fully-connected-layer/">backpropagation through a fully-connected layer post</a>.</p> </div> <div class="section" id="learning-model-parameters-with-adagrad"> <h2>Learning model parameters with Adagrad</h2> <p>We've just gone through all the major parts of the RNN cell and computed local gradients. Armed with these formulae and the chain rule, it should be possible to understand how the <tt class="docutils literal"><span class="pre">min-char-rnn</span></tt> code flows the loss gradient backwards. But that's not the end of the story; once we have the loss derivatives w.r.t. some model parameter, how do we update this parameter?</p> <p>The most straightforward way to do this would be using the gradient descent algorithm, with some constant learning rate. <a class="reference external" href="http://eli.thegreenplace.net/2016/understanding-gradient-descent/">I've written about gradient descent</a> in the past - please take a look for a refresher.</p> <p>Most real-world learning is done with more advanced algorithms these days, however. One such algorithm is called Adagrad, <a class="reference external" href="http://jmlr.org/papers/v12/duchi11a.html">proposed in 2011</a> by some experts in mathematical optimization. <tt class="docutils literal"><span class="pre">min-char-rnn</span></tt> happens to use Adagrad, so here is a simplified explanation of how it works.</p> <p>The main idea is to adjust the learning rate separately per parameter, because in practice some parameters change much more often than others.
This could be due to rare examples in the training data set that affect a parameter that's not often affected; we'd like to amplify these changes because they are rare, and dampen changes to parameters that change often.</p> <p>Therefore the Adagrad algorithm works as follows:</p> <div class="highlight"><pre><span></span><span class="c1"># Same shape as the parameter array x</span> <span class="n">memory</span> <span class="o">=</span> <span class="mi">0</span> <span class="k">while</span> <span class="bp">True</span><span class="p">:</span> <span class="n">dx</span> <span class="o">=</span> <span class="n">compute_grad</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Elementwise: each memory element gets the corresponding dx^2 added to it.</span> <span class="n">memory</span> <span class="o">+=</span> <span class="n">dx</span> <span class="o">*</span> <span class="n">dx</span> <span class="c1"># The actual parameter update for this step. Note how the learning rate is</span> <span class="c1"># modified by the memory. 
epsilon is some very small number to avoid dividing</span> <span class="c1"># by 0.</span> <span class="n">x</span> <span class="o">-=</span> <span class="n">learning_rate</span> <span class="o">*</span> <span class="n">dx</span> <span class="o">/</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">memory</span><span class="p">)</span> <span class="o">+</span> <span class="n">epsilon</span><span class="p">)</span> </pre></div> <p>If a given element in <tt class="docutils literal">dx</tt> was updated significantly in the past, its corresponding <tt class="docutils literal">memory</tt> element will grow and thus the learning rate is effectively decreased.</p> </div> <div class="section" id="gradient-clipping"> <h2>Gradient clipping</h2> <p>If we unroll the RNN cell 10 times, the gradient will be multiplied by <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/5b9174fc1cf8afbecdab52326985d41be6fbc2c8.svg" style="height: 15px;" type="image/svg+xml">W_{hh}</object> ten times on its way from the last cell to the first. For some structures of <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/5b9174fc1cf8afbecdab52326985d41be6fbc2c8.svg" style="height: 15px;" type="image/svg+xml">W_{hh}</object>, this may lead to an &quot;exploding gradient&quot; effect where the value keeps growing <a class="footnote-reference" href="#id10" id="id5"></a>.</p> <p>To mitigate this, <tt class="docutils literal"><span class="pre">min-char-rnn</span></tt> uses the <em>gradient clipping</em> trick. Whenever the gradients are updated, they are &quot;clipped&quot; to some reasonable range (like -5 to 5) so they will never get out of this range. 
This method is crude, but it works reasonably well for training RNNs.</p> <p>The flip side problem of <em>vanishing gradient</em> (wherein the gradients keep getting smaller with each step) is much harder to solve, and usually requires more advanced recurrent NN architectures.</p> </div> <div class="section" id="min-char-rnn-model-quality"> <h2>min-char-rnn model quality</h2> <p>While <tt class="docutils literal"><span class="pre">min-char-rnn</span></tt> is a complete RNN implementation that manages to learn, it's not really good enough for learning a reasonable model for the English language. The model is too simple for this, and suffers seriously from the vanishing gradient problem.</p> <p>For example, when training a 16-step unrolled model on a corpus of Sherlock Holmes books, it produces the following text after 60,000 iterations (learning on about a MiB of text):</p> <blockquote> one, my dred, roriny. qued bamp gond hilves non froange saws, to mold his a work, you shirs larcs anverver strepule thunboler muste, thum and cormed sightourd so was rewa her besee pilman</blockquote> <p>It's not complete gibberish, but not really English either. Just for fun, I wrote a simple <a class="reference external" href="https://github.com/eliben/deep-learning-samples/blob/master/min-char-rnn/markov-model.py">Markov chain generator</a> and trained it on the same text with a 4-character state. Here's a sample of its output:</p> <blockquote> though throughted with to taken as when it diabolice, and intered the stairhead, the stood initions of indeed, as burst, his mr. holmes' room, and now i fellows. the stable. he retails arm</blockquote> <p>Which, you'll admit, is quite a bit better than our &quot;fancy&quot; deep learning approach! And it was much faster to train too...</p> <p>To have a better chance of learning a good model, we'll need a more advanced architecture like LSTM. 
LSTMs employ a bunch of tricks to preserve long-term dependencies through the cells and can learn much better language models. For example, Andrej Karpathy's char-rnn model from the <a class="reference external" href="http://karpathy.github.io/2015/05/21/rnn-effectiveness/">Unreasonable Effectiveness of RNNs post</a> is a multi-layer LSTM, and it can learn fairly nice models for a varied set of domains, ranging from Shakespeare sonnets to C code snippets in the Linux kernel.</p> </div> <div class="section" id="conclusion"> <h2>Conclusion</h2> <p>The goal of this post wasn't to develop a very good RNN model; rather, it was to explain in detail the math behind training a simple RNN. More advanced RNN architectures like LSTM are somewhat more complicated, but all the core ideas are very similar and this post should be helpful in nailing the basics.</p> <p><em>Update:</em> <a class="reference external" href="https://eli.thegreenplace.net/2018/minimal-character-based-lstm-implementation/">An extension of this post to LSTMs</a>.</p> <hr class="docutils" /> <table class="docutils footnote" frame="void" id="id7" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id1"></a></td><td><p class="first">Computing a softmax makes sense because <em>x</em> is encoded with one-hot over a vocabulary-sized vector, meaning there's a 1 in the position of the letter it represents with 0s in all other positions. For example, if we only care about the 26 lower-case alphabet letters, <em>x</em> could be a 26-element vector. To represent 'a' it would have 1 in position 0 and zeros elsewhere; to represent 'd' it would have 1 in position 3 and zeros elsewhere.</p> <p class="last">The output <em>p</em> here models what the RNN cell thinks the next generated character should be.
Using softmax, it would have probabilities for each character in the corresponding position, all of them properly summing up to 1.</p> </td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id8" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id2"></a></td><td><p class="first">A slightly more technical explanation: to compute the gradient for the error w.r.t. weights in the typical backpropagation flow, we'll need input gradients for <object class="valign-m5" data="https://eli.thegreenplace.net/images/math/0ede500c5edc819b5f962923f98724936ef9d593.svg" style="height: 18px;" type="image/svg+xml">p[t]</object> and <object class="valign-m5" data="https://eli.thegreenplace.net/images/math/897ad2ab624c79d6dcb687ad28f7a3767a76712c.svg" style="height: 18px;" type="image/svg+xml">h[t]</object>. Then, when learning happens we use the measured error and propagate it backwards. But what is the measured error for <object class="valign-m5" data="https://eli.thegreenplace.net/images/math/897ad2ab624c79d6dcb687ad28f7a3767a76712c.svg" style="height: 18px;" type="image/svg+xml">h[t]</object>? We don't know it before we compute the error of the next iteration, and so on - a bit of a chicken-egg problem.</p> <p class="last">Unrolling/BPTT helps approximate a solution for this issue. An alternative solution is to use <em>forward-mode</em> gradient propagation instead, with an algorithm called RTRL (Real Time Recurrent Learning). This algorithm works well but has a high computational cost compared to BPTT. I'd love to explore this topic in more depth, as it ties into the difference between forward-mode and reverse-mode auto differentiation. 
But that would be a topic for another post.</p> </td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id9" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id3"></a></td><td>This is similar to convolutional networks, where the convolution filter weights are reused many times when processing a much larger input. In such models the invariance is <em>spatial</em>; in sequence models the invariance is <em>temporal</em>. In fact, space vs. time in models is just a matter of convention, and it turns out that 1D convolutional models perform very well on some sequence tasks!</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id10" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id5"></a></td><td><p class="first">An easy way to think about it is to imagine some initial value <em>v</em>, multiplied by another value <em>c</em> many times. We get <object class="valign-0" data="https://eli.thegreenplace.net/images/math/f24f59611d0e1f3043785fe772138687cfd6da97.svg" style="height: 15px;" type="image/svg+xml">vc^N</object> for <em>N</em> multiplications. If <em>c</em> is larger than 1, it means the result will keep growing with each multiplication. How quickly will depend on the actual value of <em>c</em>, but this is basically an exponential runoff. 
We actually care about the absolute value of <em>c</em>, of course, since runoff is equally bad in the positive or negative direction.</p> <p class="last">Similarly with the absolute value of <em>c</em> smaller than 1, we'll get a &quot;vanishing&quot; effect since the result will keep getting smaller with each iteration.</p> </td></tr> </tbody> </table> </div> Backpropagation through a fully-connected layer2018-05-22T05:47:00-07:002018-05-22T05:47:00-07:00Eli Benderskytag:eli.thegreenplace.net,2018-05-22:/2018/backpropagation-through-a-fully-connected-layer/<p>The goal of this post is to show the math of backpropagating a derivative for a fully-connected (FC) neural network layer consisting of matrix multiplication and bias addition. I have briefly mentioned this in an <a class="reference external" href="http://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative">earlier post dedicated to Softmax</a>, but here I want to give some more attention to …</p><p>The goal of this post is to show the math of backpropagating a derivative for a fully-connected (FC) neural network layer consisting of matrix multiplication and bias addition. I have briefly mentioned this in an <a class="reference external" href="http://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative">earlier post dedicated to Softmax</a>, but here I want to give some more attention to FC layers specifically.</p> <p>Here is a fully-connected layer for input vectors with <em>N</em> elements, producing output vectors with <em>T</em> elements:</p> <img alt="Diagram of a fully connected layer" class="align-center" src="https://eli.thegreenplace.net/images/2018/fclayer.png" /> <p>As a formula, we can write:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/0f980ab4c97ad86b4d0a15ede6e9c05901323702.svg" style="height: 17px;" type="image/svg+xml"> $y=Wx+b$</object> <p>Presumably, this layer is part of a network that ends up computing some loss <em>L</em>. 
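</p>

<p>As a concrete reference for the shapes involved, here is a minimal NumPy sketch of the forward computation (variable names and sizes are illustrative):</p>

```python
import numpy as np

N, T = 4, 3                  # input and output sizes
x = np.random.randn(N, 1)    # input column vector, shape [N, 1]
W = np.random.randn(T, N)    # weight matrix, shape [T, N]
b = np.random.randn(T, 1)    # bias column vector, shape [T, 1]

y = W @ x + b                # output column vector, shape [T, 1]
```

<p>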
We'll assume we already have the derivative of the loss w.r.t. the output of the layer <object class="valign-m9" data="https://eli.thegreenplace.net/images/math/9a5154f5e8d64cc77db745d8d3baa723bc6df829.svg" style="height: 26px;" type="image/svg+xml">\frac{\partial{L}}{\partial{y}}</object>.</p> <p>We'll be interested in two other derivatives: <object class="valign-m7" data="https://eli.thegreenplace.net/images/math/33d2709b664fdd69317758b433b61b13c1cdc62f.svg" style="height: 24px;" type="image/svg+xml">\frac{\partial{L}}{\partial{W}}</object> and <object class="valign-m7" data="https://eli.thegreenplace.net/images/math/5f12a50803653cf2ee02135944343ec70506d31c.svg" style="height: 24px;" type="image/svg+xml">\frac{\partial{L}}{\partial{b}}</object>.</p> <div class="section" id="jacobians-and-the-chain-rule"> <h2>Jacobians and the chain rule</h2> <p>As a reminder from <a class="reference external" href="http://eli.thegreenplace.net/2016/the-chain-rule-of-calculus">The Chain Rule of Calculus</a>, we're dealing with functions that map from <em>n</em> dimensions to <em>m</em> dimensions: <img alt="f:\mathbb{R}^{n} \to \mathbb{R}^{m}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/13f219789047343729036279bb11630db317d98d.png" style="height: 16px;" />. We'll consider the outputs of <em>f</em> to be numbered from 1 to <em>m</em> as <img alt="f_1,f_2 \dots f_m" class="valign-m4" src="https://eli.thegreenplace.net/images/math/93b446c5209263534d09d617bbede21101d6536e.png" style="height: 16px;" />. 
For each such <img alt="f_i" class="valign-m4" src="https://eli.thegreenplace.net/images/math/68bd0dc647944d362ec8df628a22967b91d82c80.png" style="height: 16px;" /> we can compute its partial derivative by any of the <em>n</em> inputs as:</p> <img alt="$D_j f_i(a)=\frac{\partial f_i}{\partial a_j}(a)$" class="align-center" src="https://eli.thegreenplace.net/images/math/30881b5a92e45259714ba01c7a12fbf8f6c56109.png" style="height: 42px;" /> <p>Where <em>j</em> goes from 1 to <em>n</em> and <em>a</em> is a vector with <em>n</em> components. If <em>f</em> is differentiable at <em>a</em> then the derivative of <em>f</em> at <em>a</em> is the <em>Jacobian matrix</em>:</p> <img alt="$Df(a)=\begin{bmatrix} D_1 f_1(a) &amp;amp; \cdots &amp;amp; D_n f_1(a) \\ \vdots &amp;amp; &amp;amp; \vdots \\ D_1 f_m(a) &amp;amp; \cdots &amp;amp; D_n f_m(a) \\ \end{bmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/ab09367d48e9ef4d8bc2314a60313dec700193af.png" style="height: 76px;" /> <p>The multivariate chain rule states: given <img alt="g:\mathbb{R}^n \to \mathbb{R}^m" class="valign-m4" src="https://eli.thegreenplace.net/images/math/b4b7d25491897b053abf7e48688fada4a85368bd.png" style="height: 16px;" /> and <img alt="f:\mathbb{R}^m \to \mathbb{R}^p" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ac8a6cea4e02e885538fc3ef969c5733e84712f9.png" style="height: 16px;" /> and a point <img alt="a \in \mathbb{R}^n" class="valign-m1" src="https://eli.thegreenplace.net/images/math/43a85f2c59f396fe5c4e2c403a0453c463fcfb0d.png" style="height: 13px;" />, if <em>g</em> is differentiable at <em>a</em> and <em>f</em> is differentiable at <img alt="g(a)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e7373233d49e18a0882e0dce41d9d6aa26964d6b.png" style="height: 18px;" /> then the composition <img alt="f \circ g" class="valign-m4" src="https://eli.thegreenplace.net/images/math/1247a6ac0bc07bfdbd790831aa70b0b000bad2e4.png" style="height: 16px;" 
/> is differentiable at <em>a</em> and its derivative is:</p> <img alt="$D(f \circ g)(a)=Df(g(a)) \cdot Dg(a)$" class="align-center" src="https://eli.thegreenplace.net/images/math/00bdefa904bd34df2dfb50cc385e6497c4e5096e.png" style="height: 18px;" /> <p>Which is the matrix multiplication of <img alt="Df(g(a))" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e567730c48bb2f95c258b630b4d6e997043e09ab.png" style="height: 18px;" /> and <img alt="Dg(a)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/2575fc98e794a733a7aa6237fe67246a41e6c8c5.png" style="height: 18px;" />.</p> </div> <div class="section" id="back-to-the-fully-connected-layer"> <h2>Back to the fully-connected layer</h2> <p>Circling back to our fully-connected layer, we have the loss <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/abf7408ae6d9fb4683480735dc1ebc8555b8fef8.svg" style="height: 18px;" type="image/svg+xml">L(y)</object> - a scalar function <object class="valign-m1" data="https://eli.thegreenplace.net/images/math/ddef8b9ca23fb246b2a984c719d812f37a41a406.svg" style="height: 16px;" type="image/svg+xml">L:\mathbb{R}^{T} \to \mathbb{R}</object>. We also have the function <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/f09c295439296549a068b64ffe69a48dd77d1078.svg" style="height: 17px;" type="image/svg+xml">y=Wx+b</object>. If we're interested in the derivative w.r.t the weights, what are the dimensions of this function? Our &quot;variable part&quot; is then <em>W</em>, which has <em>NT</em> elements overall, and the output has <em>T</em> elements, so <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/06178f1d07375b8286afcd48f02bcd34d71537f0.svg" style="height: 19px;" type="image/svg+xml">y:\mathbb{R}^{NT} \to \mathbb{R}^{T}</object> <a class="footnote-reference" href="#id3" id="id1"></a>.</p> <p>The chain rule tells us how to compute the derivative of <em>L</em> w.r.t. 
<em>W</em>:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/ee6bc25a34980031f93f0c7eefccc40663b05c76.svg" style="height: 38px;" type="image/svg+xml"> $\frac{\partial{L}}{\partial{W}}=D(L \circ y)(W)=DL(y(W)) \cdot Dy(W)$</object> <p>Since we're backpropagating, we already know <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/dcb2eda345045dac22c425a1ee19113e047126cf.svg" style="height: 18px;" type="image/svg+xml">DL(y(W))</object>; because of the dimensionality of the <em>L</em> function, the dimensions of <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/dcb2eda345045dac22c425a1ee19113e047126cf.svg" style="height: 18px;" type="image/svg+xml">DL(y(W))</object> are [1,T] (one row, <em>T</em> columns). <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/d0064a180ddb231bb6868ce25c68ef3ec1c2a464.svg" style="height: 18px;" type="image/svg+xml">y(W)</object> has <em>NT</em> inputs and <em>T</em> outputs, so the dimensions of <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/b22fe7345e02ae50c68605696f3a447435cd1f9d.svg" style="height: 18px;" type="image/svg+xml">Dy(W)</object> are [T,NT]. Overall, the dimensions of <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/2f6b6eded4ba20b3eeb59b2b687f84de1e91c04c.svg" style="height: 18px;" type="image/svg+xml">D(L \circ y)(W)</object> are then [1,NT]. This makes sense if you think about it, because as a function of <em>W</em>, the loss has <em>NT</em> inputs and a single scalar output.</p> <p>What remains is to compute <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/b22fe7345e02ae50c68605696f3a447435cd1f9d.svg" style="height: 18px;" type="image/svg+xml">Dy(W)</object>, the Jacobian of <em>y</em> w.r.t. <em>W</em>. 
As mentioned above, it has <em>T</em> rows - one for each output element of <em>y</em>, and <em>NT</em> columns - one for each element in the weight matrix <em>W</em>. Computing such a large Jacobian may seem daunting, but we'll soon see that it's very easy to generalize from a simple example. Let's start with <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/6a53741f2a8810da3cae4efadde63c8e7ee2662f.svg" style="height: 12px;" type="image/svg+xml">y_1</object>:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/7190e002ac69968b674aecacfd5a8531ad9cd208.svg" style="height: 55px;" type="image/svg+xml"> $y_1=\sum_{j=1}^{N}W_{1,j}x_{j}+b_1$</object> <p>What's the derivative of this result element w.r.t. each element in <em>W</em>? When the element is in row 1, the derivative is <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/73058e43db0f4edc791b10f27f913cbc5d361ab6.svg" style="height: 14px;" type="image/svg+xml">x_j</object> (<em>j</em> being the column of <em>W</em>); when the element is in any other row, the derivative is 0.</p> <p>Similarly for <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/b9f59182e34baa532fa4e27471acc714f3105d16.svg" style="height: 12px;" type="image/svg+xml">y_2</object>, we'll have non-zero derivatives only for the second row of <em>W</em> (with the same result of <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/73058e43db0f4edc791b10f27f913cbc5d361ab6.svg" style="height: 14px;" type="image/svg+xml">x_j</object> being the derivative for the <em>j</em>-th column), and so on.</p> <p>Generalizing from the example, if we split the index of <em>W</em> to <em>i</em> and <em>j</em>, we get:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/e28e0d3b44645eb299cceae8dde2319244e86373.svg" style="height: 50px;" type="image/svg+xml"> \begin{align} 
D_{ij}y_t&amp;=\frac{\partial(\sum_{j=1}^{N}W_{t,j}x_{j}+b_t)}{\partial W_{ij}} = \left\{\begin{matrix} x_j &amp; i = t\\ 0 &amp; i \ne t \end{matrix}\right. \end{align}</object> <p>This goes into row <em>t</em>, column <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/ef7b2d987af3c0ceb75381d096c35e8c19085642.svg" style="height: 18px;" type="image/svg+xml">(i-1)N+j</object> in the Jacobian matrix. Overall, we get the following Jacobian matrix with shape [T,NT]:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/8a59a6251d12196f12eaadb6537289e3a6368d53.svg" style="height: 76px;" type="image/svg+xml"> $Dy=\begin{bmatrix} x_1 &amp; x_2 &amp; \cdots &amp; x_N &amp; \cdots &amp; 0 &amp; 0 &amp; \cdots &amp; 0 \\ \vdots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \vdots \\ 0 &amp; 0 &amp; \cdots &amp; 0 &amp; \cdots &amp; x_1 &amp; x_2 &amp; \cdots &amp; x_N \end{bmatrix}$</object> <p>Now we're ready to finally multiply the Jacobians together to complete the chain rule:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/2e9823350972d4874d201a0f8232d89fea710c6f.svg" style="height: 18px;" type="image/svg+xml"> $D(L \circ y)(W)=DL(y(W)) \cdot Dy(W)$</object> <p>The leftmost factor on the right-hand side is this row vector:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/af6d7af820f16b7493d378e8a40daa87031591f4.svg" style="height: 41px;" type="image/svg+xml"> $DL(y)=(\frac{\partial L}{\partial y_1}, \frac{\partial L}{\partial y_2},\cdots,\frac{\partial L}{\partial y_T})$</object> <p>And we're multiplying it by the matrix <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/baf7d8b700759b28ece347bd62793400ef52a8e0.svg" style="height: 16px;" type="image/svg+xml">Dy</object> shown above.
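</p>

<p>The structure of this Jacobian is easy to verify numerically. The following sketch (toy sizes, illustrative names) builds it per the formula above - note that with 0-based indices the column for <em>W[i,j]</em> is <em>i*N+j</em> - and checks each column against a finite-difference derivative:</p>

```python
import numpy as np

np.random.seed(0)
N, T = 3, 2
W = np.random.randn(T, N)
x = np.random.randn(N, 1)

# Jacobian of y = Wx + b w.r.t. W, with W unrolled row-major:
# row t, column (t*N + j) holds x_j; all other entries are 0.
Dy = np.zeros((T, N * T))
for t in range(T):
    for j in range(N):
        Dy[t, t * N + j] = x[j, 0]

# Check each column against a finite-difference derivative.
# The bias cancels in the difference, so it's omitted here.
eps = 1e-6
for i in range(T):
    for j in range(N):
        Wp = W.copy()
        Wp[i, j] += eps
        dy = ((Wp @ x) - (W @ x)) / eps
        assert np.allclose(Dy[:, i * N + j], dy[:, 0], atol=1e-4)
```

<p>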
Each item in the result vector will be a dot product between <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/992673c682a388ea8231ebbd8ea28c9cecae874d.svg" style="height: 18px;" type="image/svg+xml">DL(y)</object> and the corresponding column in the matrix <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/baf7d8b700759b28ece347bd62793400ef52a8e0.svg" style="height: 16px;" type="image/svg+xml">Dy</object>. Since <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/baf7d8b700759b28ece347bd62793400ef52a8e0.svg" style="height: 16px;" type="image/svg+xml">Dy</object> has a single non-zero element in each column, the result is fairly trivial. The first <em>N</em> entries are:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/d0e87e4d5cd9feb9d93d153733a182426d175d7e.svg" style="height: 41px;" type="image/svg+xml"> $\frac{\partial L}{\partial y_1}x_1, \frac{\partial L}{\partial y_1}x_2, \cdots, \frac{\partial L}{\partial y_1}x_N$</object> <p>The next <em>N</em> entries are:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/5bd4d8b6b4eb817b071dbf4ddda71680f4bf0392.svg" style="height: 41px;" type="image/svg+xml"> $\frac{\partial L}{\partial y_2}x_1, \frac{\partial L}{\partial y_2}x_2, \cdots, \frac{\partial L}{\partial y_2}x_N$</object> <p>And so on, until the last (<em>T</em>-th) set of <em>N</em> entries is all <em>x</em>-es multiplied by <object class="valign-m9" data="https://eli.thegreenplace.net/images/math/b44681f2ca721dae2b24a49d88f01463e3a88e50.svg" style="height: 26px;" type="image/svg+xml">\frac{\partial L}{\partial y_T}</object>.</p> <p>To better see how to apply each derivative to a corresponding element in <em>W</em>, we can &quot;re-roll&quot; this result back into a matrix of shape [T,N]:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/7cfccbaaa844f8ae994f8e012f12557919927e31.svg" style="height: 
129px;" type="image/svg+xml"> $\frac{\partial{L}}{\partial{W}}=D(L\circ y)(W)=\begin{bmatrix} \frac{\partial L}{\partial y_1}x_1 &amp; \frac{\partial L}{\partial y_1}x_2 &amp; \cdots &amp; \frac{\partial L}{\partial y_1}x_N \\ \\ \frac{\partial L}{\partial y_2}x_1 &amp; \frac{\partial L}{\partial y_2}x_2 &amp; \cdots &amp; \frac{\partial L}{\partial y_2}x_N \\ \vdots &amp; \vdots &amp; \ddots &amp; \vdots \\ \frac{\partial L}{\partial y_T}x_1 &amp; \frac{\partial L}{\partial y_T}x_2 &amp; \cdots &amp; \frac{\partial L}{\partial y_T}x_N \end{bmatrix}$</object> </div> <div class="section" id="computational-cost-and-shortcut"> <h2>Computational cost and shortcut</h2> <p>While the derivation shown above is complete and mathematically correct, it can also be computationally intensive; in realistic scenarios, the full Jacobian matrix can be <em>really</em> large. For example, let's say our input is a (modestly sized) 128x128 image, so <em>N=16,384</em>. Let's also say that <em>T=100</em>. The weight matrix then has <em>NT=1,638,400</em> elements; respectably big, but nothing out of the ordinary.</p> <p>Now consider the size of the full Jacobian matrix: it's <em>T</em> by <em>NT</em>, or over 160 million elements. At 4 bytes per element that's more than half a GiB!</p> <p>Moreover, to compute every backpropagation we'd be forced to multiply this full Jacobian matrix by a 100-dimensional vector, performing 160 million multiply-and-add operations for the dot products. That's a lot of compute.</p> <p>But the final result <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/89da23923a43fcd95b185bebb6fd362b6d1ac695.svg" style="height: 18px;" type="image/svg+xml">D(L\circ y)(W)</object> is the size of <em>W</em> - 1.6 million elements. Do we really need 160 million computations to get to it? No, because the Jacobian is very <em>sparse</em> - most of it is zeros. 
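</p>

<p>The sparsity, and the resulting cheap product, can be checked directly with a toy-sized sketch (in real code we'd never materialize this Jacobian; the names here are illustrative):</p>

```python
import numpy as np

np.random.seed(1)
N, T = 3, 2
x = np.random.randn(N, 1)
dLdy = np.random.randn(1, T)      # known gradient of the loss w.r.t. y

# Materialize the sparse [T, N*T] Jacobian of y w.r.t. W.
Dy = np.zeros((T, N * T))
for t in range(T):
    Dy[t, t * N : (t + 1) * N] = x[:, 0]

full = dLdy @ Dy                  # [1, N*T] chain-rule product
# One multiplication per element: entry (t*N + j) is dLdy[t] * x[j].
shortcut = (dLdy.T * x.T).reshape(1, N * T)
assert np.allclose(full, shortcut)
```

<p>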
And in fact, when we look at the <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/89da23923a43fcd95b185bebb6fd362b6d1ac695.svg" style="height: 18px;" type="image/svg+xml">D(L\circ y)(W)</object> found above - it's fairly straightforward to compute using a single multiplication per element.</p> <p>Moreover, if we stare at the <object class="valign-m7" data="https://eli.thegreenplace.net/images/math/33d2709b664fdd69317758b433b61b13c1cdc62f.svg" style="height: 24px;" type="image/svg+xml">\frac{\partial{L}}{\partial{W}}</object> matrix a bit, we'll notice it has a familiar pattern: this is just the <a class="reference external" href="https://en.wikipedia.org/wiki/Outer_product">outer product</a> between the vectors <object class="valign-m9" data="https://eli.thegreenplace.net/images/math/9a5154f5e8d64cc77db745d8d3baa723bc6df829.svg" style="height: 26px;" type="image/svg+xml">\frac{\partial{L}}{\partial{y}}</object> and <em>x</em>:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/89bdcdf27feb489a5e3cb1bb8adc7faffcf0207d.svg" style="height: 41px;" type="image/svg+xml"> $\frac{\partial L}{\partial W}=\frac{\partial L}{\partial y}\otimes x$</object> <p>If we have to compute this backpropagation in Python/Numpy, we'll likely write code similar to:</p> <div class="highlight"><pre><span></span><span class="c1"># Assuming dy (gradient of loss w.r.t. 
y) and x are column vectors, by</span> <span class="c1"># performing a dot product between dy (column) and x.T (row) we get the</span> <span class="c1"># outer product.</span> <span class="n">dW</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">dy</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">T</span><span class="p">)</span> </pre></div> </div> <div class="section" id="bias-gradient"> <h2>Bias gradient</h2> <p>We've just seen how to compute weight gradients for a fully-connected layer. Computing the gradients for the bias vector is very similar, and a bit simpler.</p> <p>This is the chain rule equation applied to the bias vector:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/3aee48692aeafe9ffab07037ad374f4c803787a7.svg" style="height: 38px;" type="image/svg+xml"> $\frac{\partial{L}}{\partial{b}}=D(L \circ y)(b)=DL(y(b)) \cdot Dy(b)$</object> <p>The shapes involved here are: <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/ef9baa4141fed9b40c4f1b0ebf189e4d8d28badc.svg" style="height: 18px;" type="image/svg+xml">DL(y(b))</object> is still [1,T], because the number of elements in <em>y</em> remains <em>T</em>. <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/966d0c5b07f027b02b0ca9eb418ed9ac12f63386.svg" style="height: 18px;" type="image/svg+xml">Dy(b)</object> has <em>T</em> inputs (bias elements) and <em>T</em> outputs (<em>y</em> elements), so its shape is [T,T]. 
Therefore, the shape of the gradient <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/7640378fb78362268ffe48bf5d68a266211673e4.svg" style="height: 18px;" type="image/svg+xml">D(L \circ y)(b)</object> is [1,T].</p> <p>To see how we'd fill the Jacobian matrix <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/966d0c5b07f027b02b0ca9eb418ed9ac12f63386.svg" style="height: 18px;" type="image/svg+xml">Dy(b)</object>, let's go back to the formula for <em>y</em>:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/7190e002ac69968b674aecacfd5a8531ad9cd208.svg" style="height: 55px;" type="image/svg+xml"> $y_1=\sum_{j=1}^{N}W_{1,j}x_{j}+b_1$</object> <p>When derived by anything other than <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/c7cd24d955e66b8fe5ce45ded69fd98da5c68ba8.svg" style="height: 17px;" type="image/svg+xml">b_1</object>, this would be 0; when derived by <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/c7cd24d955e66b8fe5ce45ded69fd98da5c68ba8.svg" style="height: 17px;" type="image/svg+xml">b_1</object> the result is 1. The same applies to every other element of <em>y</em>:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/0511363d44324e04aee704c9cc3094a4e8c8c108.svg" style="height: 44px;" type="image/svg+xml"> $\frac{\partial y_i}{\partial b_j}=\left\{\begin{matrix} 1 &amp; i=j \\ 0 &amp; i\neq j \end{matrix}\right.$</object> <p>In matrix form, this is just an identity matrix with dimensions [T,T].
Therefore:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/a64bd506727621a3e78444f7e158769dae30f93b.svg" style="height: 38px;" type="image/svg+xml"> $\frac{\partial{L}}{\partial{b}}=D(L \circ y)(b)=DL(y(b)) \cdot I =DL(y(b))$</object> <p>For a given element of <em>b</em>, its gradient is just the corresponding element in <object class="valign-m9" data="https://eli.thegreenplace.net/images/math/f004c6bbe71887354e0aad67dd7cbe6650eb58e9.svg" style="height: 26px;" type="image/svg+xml">\frac{\partial L}{\partial y}</object>.</p> </div> <div class="section" id="fully-connected-layer-for-a-batch-of-inputs"> <h2>Fully-connected layer for a batch of inputs</h2> <p>The derivation shown above applies to a FC layer with a single input vector <em>x</em> and a single output vector <em>y</em>. When we train models, we almost always try to do so in <em>batches</em> (or <em>mini-batches</em>) to better leverage the parallelism of modern hardware. So a more typical layer computation would be:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/c26a36b850dd7b2e4288e475a590f343ec3a18a3.svg" style="height: 15px;" type="image/svg+xml"> $Y=WX+b$</object> <p>Where the shape of <em>X</em> is [N,B]; <em>B</em> is the batch size, typically a not-too-large power of 2, like 32. <em>W</em> and <em>b</em> still have the same shapes, so the shape of <em>Y</em> is [T,B]. 
Each column in <em>X</em> is a new input vector (for a total of <em>B</em> vectors in a batch); a corresponding column in <em>Y</em> is the output.</p> <p>As before, given <object class="valign-m7" data="https://eli.thegreenplace.net/images/math/d5de3c7d9e0e1bcb4f6c00ea06b4ad808d2ea998.svg" style="height: 24px;" type="image/svg+xml">\frac{\partial{L}}{\partial{Y}}</object>, our goal is to find <object class="valign-m7" data="https://eli.thegreenplace.net/images/math/33d2709b664fdd69317758b433b61b13c1cdc62f.svg" style="height: 24px;" type="image/svg+xml">\frac{\partial{L}}{\partial{W}}</object> and <object class="valign-m7" data="https://eli.thegreenplace.net/images/math/5f12a50803653cf2ee02135944343ec70506d31c.svg" style="height: 24px;" type="image/svg+xml">\frac{\partial{L}}{\partial{b}}</object>. While the end results are fairly simple and pretty much what you'd expect, I still want to go through the full Jacobian computation to show how to find the gradients in a rigorous way.</p> <p>Starting with the weights, the chain rule is:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/0ffed6ae9645ea6bd0d02932e2f0ca20fb8e7bc6.svg" style="height: 38px;" type="image/svg+xml"> $\frac{\partial{L}}{\partial{W}}=D(L \circ Y)(W)=DL(Y(W)) \cdot DY(W)$</object> <p>The dimensions are:</p> <ul class="simple"> <li><object class="valign-m4" data="https://eli.thegreenplace.net/images/math/86485acc2c4461f7817626204bf6c9148dad9d87.svg" style="height: 18px;" type="image/svg+xml">DL(Y(W))</object>: [1,TB] because <em>Y</em> has <em>T</em> outputs for each input vector in the batch.</li> <li><object class="valign-m4" data="https://eli.thegreenplace.net/images/math/573b889d69d85759886840570c6970345209b332.svg" style="height: 18px;" type="image/svg+xml">DY(W)</object>: [TB,TN] since <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/fe5551953a6c071c738578f2ebc316864078cc81.svg" style="height: 18px;" type="image/svg+xml">Y(W)</object> has
<em>TB</em> outputs and <em>TN</em> inputs overall.</li> <li><object class="valign-m4" data="https://eli.thegreenplace.net/images/math/b214435777e236879c609900ba7a118e9f0da022.svg" style="height: 18px;" type="image/svg+xml">D(L\circ Y)(W)</object>: [1,TN] same as in the batch-1 case, because the same weight matrix is used for all inputs in the batch.</li> </ul> <p>Also, we'll use the notation <object class="valign-m5" data="https://eli.thegreenplace.net/images/math/5f40e2ad50a0eb5c2f5019c48563f9c6605f84b6.svg" style="height: 24px;" type="image/svg+xml">x_{i}^{[b]}</object> to talk about the <em>i</em>-th element in the <em>b</em>-th input vector <em>x</em> (out of a total of <em>B</em> such input vectors).</p> <p>With this in hand, let's see how the Jacobians look; starting with <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/86485acc2c4461f7817626204bf6c9148dad9d87.svg" style="height: 18px;" type="image/svg+xml">DL(Y(W))</object>, it's the same as before except that we have to take the batch into account. Each batch element is independent of the others in loss computations, so we'll have:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/943d68b7dbda5009cbfd597b4e0fcc46748204a5.svg" style="height: 47px;" type="image/svg+xml"> $\frac{\partial L}{\partial y_{i}^{[b]}}$</object> <p>As the Jacobian element; how do we arrange them in a 1-dimensional vector with shape [1,TB]? We'll just have to agree on a linearization here - same as we did with <em>W</em> before. We'll go for row-major again, so in 1-D the array <em>Y</em> would be:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/13122f18953df14ebb5ff74f59441194d3adb445.svg" style="height: 26px;" type="image/svg+xml"> $Y=(y_{1}^{[1]},y_{1}^{[2]},\cdots,y_{1}^{[B]}, y_{2}^{[1]},y_{2}^{[2]},\cdots,y_{2}^{[B]},\cdots)$</object> <p>And so on for <em>T</em> elements.
Therefore, the Jacobian of <em>L</em> w.r.t <em>Y</em> is:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/4e79e866dbfbe0d6b6de5f6617762bce00d5f61f.svg" style="height: 48px;" type="image/svg+xml"> $\frac{\partial L}{\partial Y}=( \frac{\partial L}{\partial y_{1}^{[1]}}, \frac{\partial L}{\partial y_{1}^{[2]}},\cdots, \frac{\partial L}{\partial y_{1}^{[B]}}, \frac{\partial L}{\partial y_{2}^{[1]}}, \frac{\partial L}{\partial y_{2}^{[2]}},\cdots, \frac{\partial L}{\partial y_{2}^{[B]}},\cdots)$</object> <p>To find <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/573b889d69d85759886840570c6970345209b332.svg" style="height: 18px;" type="image/svg+xml">DY(W)</object>, let's first see how to compute <em>Y</em>. The <em>i</em>-th element of <em>Y</em> for batch <em>b</em> is:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/2b9c88f44b9cbec2343ce49418ca3e17dd2e0946.svg" style="height: 55px;" type="image/svg+xml"> $y_{i}^{[b]}=\sum_{j=1}^{N}W_{i,j}x_{j}^{[b]}+b_i$</object> <p>Recall that the Jacobian <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/573b889d69d85759886840570c6970345209b332.svg" style="height: 18px;" type="image/svg+xml">DY(W)</object> now has shape [TB,TN]. Previously we had to unroll the [T,N] of the weight matrix into the rows. Now we'll also have to unroll the [T,B] of the output into the columns. As before, first all <em>b</em>-s for <em>t=1</em>, then all <em>b</em>-s for <em>t=2</em>, etc.
If we carefully compute the derivative, we'll see that the Jacobian matrix has similar structure to the single-batch case, just with each line repeated <em>B</em> times for each of the batch elements:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/4fe86285c11962226758ecfad2839b2ce6520d2d.svg" style="height: 291px;" type="image/svg+xml"> $DY(W)=\begin{bmatrix} x_{1}^{[1]} &amp; x_{2}^{[1]} &amp; \cdots &amp; x_{N}^{[1]} &amp; \cdots &amp; 0 &amp; 0 &amp; \cdots &amp; 0 \\ x_{1}^{[2]} &amp; x_{2}^{[2]} &amp; \cdots &amp; x_{N}^{[2]} &amp; \cdots &amp; 0 &amp; 0 &amp; \cdots &amp; 0 \\ \vdots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \vdots \\ x_{1}^{[B]} &amp; x_{2}^{[B]} &amp; \cdots &amp; x_{N}^{[B]} &amp; \cdots &amp; 0 &amp; 0 &amp; \cdots &amp; 0 \\ \vdots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \vdots \\ 0 &amp; 0 &amp; \cdots &amp; 0 &amp; \cdots &amp; x_{1}^{[1]} &amp; x_{2}^{[1]} &amp; \cdots &amp; x_{N}^{[1]} \\ 0 &amp; 0 &amp; \cdots &amp; 0 &amp; \cdots &amp; x_{1}^{[2]} &amp; x_{2}^{[2]} &amp; \cdots &amp; x_{N}^{[2]} \\ \vdots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \vdots \\ 0 &amp; 0 &amp; \cdots &amp; 0 &amp; \cdots &amp; x_{1}^{[B]} &amp; x_{2}^{[B]} &amp; \cdots &amp; x_{N}^{[B]} \\ \end{bmatrix}$</object> <p>Multiplying the two Jacobians together we get the full gradient of <em>L</em> w.r.t. each element in the weight matrix.
Where previously (in the non-batch case) we had:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/1c762a65e0003e82f7dc5108f23126989b64112b.svg" style="height: 42px;" type="image/svg+xml"> $\frac{\partial L}{\partial W_{ij}}=\frac{\partial L}{\partial y_i}x_j$</object> <p>Now, instead, we'll have:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/14fb5d857d1e43a9f692ad436c448a43c0ea041f.svg" style="height: 54px;" type="image/svg+xml"> $\frac{\partial L}{\partial W_{ij}}=\sum_{b=1}^{B}\frac{\partial L}{\partial y_{i}^{[b]}}x_{j}^{[b]}$</object> <p>Which makes total sense: it simply takes the loss gradient computed for each batch element separately and adds them up. This aligns with our intuition of how the gradient for a whole batch is computed - compute the gradient for each batch element separately and add up all the gradients <a class="footnote-reference" href="#id4" id="id2"></a>.</p> <p>As before, there's a clever way to express the final gradient using matrix operations. Note the sum across all batch elements when computing <object class="valign-m10" data="https://eli.thegreenplace.net/images/math/2d41c4c820515c93e916d32532b9bdc7012e8121.svg" style="height: 27px;" type="image/svg+xml">\frac{\partial L}{\partial W_{ij}}</object>. We can express this as the matrix multiplication:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/7f6c176f28de451b2d67fcf7ebf238122de9a970.svg" style="height: 38px;" type="image/svg+xml"> $\frac{\partial L}{\partial W}=\frac{\partial L}{\partial Y}\cdot X^T$</object> <p>This is a good place to recall the computation cost again. Previously we've seen that for a single-input case, the Jacobian can be extremely large ([T,NT] having about 160 million elements). In the batch case, the Jacobian would be even larger since its shape is [TB,NT]; with a reasonable batch of 32, it's something like 5-billion elements strong.
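</p> <p>As a quick numerical check (my own sketch, not part of the original derivation - the array names are arbitrary), the matrix shortcut can be verified against the per-element batch sum:</p>

```python
import numpy as np

# Small shapes following the post's notation: N inputs, T outputs, batch of B.
N, T, B = 4, 3, 5
rng = np.random.default_rng(0)
dY = rng.standard_normal((T, B))   # upstream gradient dL/dY, one column per sample
X = rng.standard_normal((N, B))    # input batch, one column per sample

# Shortcut: dL/dW = dL/dY . X^T, with shape [T, N].
dW = np.dot(dY, X.T)

# Element-by-element computation using the summed-over-batch formula.
dW_loop = np.zeros((T, N))
for i in range(T):
    for j in range(N):
        for b in range(B):
            dW_loop[i, j] += dY[i, b] * X[j, b]

print(np.allclose(dW, dW_loop))  # True
```

<p>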
It's good that we don't actually have to hold the full Jacobian in memory, and that we have a shortcut way of computing the gradient.</p> </div> <div class="section" id="bias-gradient-for-a-batch"> <h2>Bias gradient for a batch</h2> <p>For the bias, we have:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/1228909398e9618223e551b8b1d394ac20d697f1.svg" style="height: 38px;" type="image/svg+xml"> $\frac{\partial{L}}{\partial{b}}=D(L \circ Y)(b)=DL(Y(b)) \cdot DY(b)$</object> <p><object class="valign-m4" data="https://eli.thegreenplace.net/images/math/40699cf4e67bde5205359e04102f7b0011dac800.svg" style="height: 18px;" type="image/svg+xml">DL(Y(b))</object> here has the shape [1,TB]; <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/f260ada7edc55af13f145e5786803198a3452f1e.svg" style="height: 18px;" type="image/svg+xml">DY(b)</object> has the shape [TB,T]. Therefore, the shape of <object class="valign-m7" data="https://eli.thegreenplace.net/images/math/5f12a50803653cf2ee02135944343ec70506d31c.svg" style="height: 24px;" type="image/svg+xml">\frac{\partial{L}}{\partial{b}}</object> is [1,T], as before.</p> <p>From the formula for computing <em>Y</em>:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/7190e002ac69968b674aecacfd5a8531ad9cd208.svg" style="height: 55px;" type="image/svg+xml"> $y_1=\sum_{j=1}^{N}W_{1,j}x_{j}+b_1$</object> <p>We get, for any batch <em>b</em>:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/7ad81febc17ec33d21c8fba2a2e6956a8b43e1ad.svg" style="height: 49px;" type="image/svg+xml"> $\frac{\partial y_{i}^{[b]}}{\partial b_j}=\left\{\begin{matrix} 1 &amp; i=j \\ 0 &amp; i\neq j \end{matrix}\right.$</object> <p>So, whereas <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/f260ada7edc55af13f145e5786803198a3452f1e.svg" style="height: 18px;" type="image/svg+xml">DY(b)</object> was an identity matrix in the no-batch case,
here it looks like this:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/4b8148fb1343ab283fa0f8b0cdb6f3723201df15.svg" style="height: 267px;" type="image/svg+xml"> $DY(b)=\begin{bmatrix} 1 &amp; 0 &amp; 0 &amp; \cdots &amp; 0 \\ 1 &amp; 0 &amp; 0 &amp; \cdots &amp; 0 \\ \vdots &amp; \vdots &amp; \vdots &amp; \ddots &amp; \vdots \\ 1 &amp; 0 &amp; 0 &amp; \cdots &amp; 0 \\ 0 &amp; 1 &amp; 0 &amp; \cdots &amp; 0 \\ 0 &amp; 1 &amp; 0 &amp; \cdots &amp; 0 \\ \vdots &amp; \vdots &amp; \vdots &amp; \ddots &amp; \vdots \\ 0 &amp; 0 &amp; 0 &amp; \cdots &amp; 1 \\ 0 &amp; 0 &amp; 0 &amp; \cdots &amp; 1 \\ \vdots &amp; \vdots &amp; \vdots &amp; \ddots &amp; \vdots \\ 0 &amp; 0 &amp; 0 &amp; \cdots &amp; 1 \\ \end{bmatrix}$</object> <p>With <em>B</em> identical rows at a time, for a total of <em>TB</em> rows. Since <object class="valign-m7" data="https://eli.thegreenplace.net/images/math/c7d9499ae5d7e1fc81bc540909deac668210911d.svg" style="height: 24px;" type="image/svg+xml">\frac{\partial L}{\partial Y}</object> is the same as before, their matrix multiplication result has this in column <em>j</em>:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/910c6d77b9356303c0a01896a09d62fe2963d8ac.svg" style="height: 57px;" type="image/svg+xml"> $\frac{\partial{L}}{\partial{b_j}}=\sum_{b=1}^{B}\frac{\partial L}{\partial y_{j}^{[b]}}$</object> <p>Which just means adding up the gradient effects from every batch element independently.</p> </div> <div class="section" id="addendum-gradient-w-r-t-x"> <h2>Addendum - gradient w.r.t. x</h2> <p>This post started by explaining that the parameters of a fully-connected layer we're usually looking to optimize are the weight matrix and bias. 
In most cases this is true; however, in some other cases we're actually interested in propagating a gradient through <em>x</em> - often when there are more layers before the fully-connected layer in question.</p> <p>Let's find the derivative <object class="valign-m7" data="https://eli.thegreenplace.net/images/math/54869ab2743febebc22269d12572c77e057c816e.svg" style="height: 24px;" type="image/svg+xml">\frac{\partial{L}}{\partial{x}}</object>. The chain rule here is:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/a52a21efbf7f0b992db127d495abddd618677709.svg" style="height: 38px;" type="image/svg+xml"> $\frac{\partial{L}}{\partial{x}}=D(L \circ y)(x)=DL(y(x)) \cdot Dy(x)$</object> <p>Dimensions: <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/2e9597ce1ffd09d94be733216c3f1c1b2ab5f33c.svg" style="height: 18px;" type="image/svg+xml">DL(y(x))</object> is [1, T] as before; <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/a2bef37f23427154c47e53945043549039e36bcf.svg" style="height: 18px;" type="image/svg+xml">Dy(x)</object> has T outputs (elements of <em>y</em>) and N inputs (elements of <em>x</em>), so its dimensions are [T, N]. Therefore, the dimensions of <object class="valign-m7" data="https://eli.thegreenplace.net/images/math/54869ab2743febebc22269d12572c77e057c816e.svg" style="height: 24px;" type="image/svg+xml">\frac{\partial{L}}{\partial{x}}</object> are [1, N].</p> <p>From:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/7190e002ac69968b674aecacfd5a8531ad9cd208.svg" style="height: 55px;" type="image/svg+xml"> $y_1=\sum_{j=1}^{N}W_{1,j}x_{j}+b_1$</object> <p>We know that <object class="valign-m10" data="https://eli.thegreenplace.net/images/math/c209b0f19299fee08359f73898212bb0d0df8c30.svg" style="height: 28px;" type="image/svg+xml">\frac{\partial y_1}{\partial x_j}=W_{1,j}</object>. 
Generalizing this, we get <object class="valign-m10" data="https://eli.thegreenplace.net/images/math/6abadcb3365f09141a9cac088fdbb17418e75171.svg" style="height: 28px;" type="image/svg+xml">\frac{\partial y_i}{\partial x_j}=W_{i,j}</object>; in other words, <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/a2bef37f23427154c47e53945043549039e36bcf.svg" style="height: 18px;" type="image/svg+xml">Dy(x)</object> is just the weight matrix <em>W</em>. So <object class="valign-m8" data="https://eli.thegreenplace.net/images/math/e6c2d66d989f1abdb9e8b492e45f00be1ab2a21b.svg" style="height: 25px;" type="image/svg+xml">\frac{\partial{L}}{\partial{x_i}}</object> is the dot product of <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/2e9597ce1ffd09d94be733216c3f1c1b2ab5f33c.svg" style="height: 18px;" type="image/svg+xml">DL(y(x))</object> with the <em>i</em>-th column of <em>W</em>.</p> <p>Computationally, we can express this as follows:</p> <div class="highlight"><pre><span></span><span class="n">dx</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">dy</span><span class="o">.</span><span class="n">T</span><span class="p">,</span> <span class="n">W</span><span class="p">)</span><span class="o">.</span><span class="n">T</span> </pre></div> <p>Again, recall that our vectors are <em>column</em> vectors. Therefore, to multiply <em>dy</em> from the left by <em>W</em> we have to transpose it to a row vector first. 
The result of this matrix multiplication is a [1, N] row-vector, so we transpose it again to get a column.</p> <p>An alternative method to compute this would transpose <em>W</em> rather than <em>dy</em> and then swap the order:</p> <div class="highlight"><pre><span></span><span class="n">dx</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">W</span><span class="o">.</span><span class="n">T</span><span class="p">,</span> <span class="n">dy</span><span class="p">)</span> </pre></div> <p>These two methods produce exactly the same <em>dx</em>; it's important to be familiar with these tricks, because otherwise it may be confusing to see a transposed <em>W</em> when we expect the actual <em>W</em> from gradient computations.</p> <hr class="docutils" /> <table class="docutils footnote" frame="void" id="id3" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id1"></a></td><td><p class="first">As explained in the <a class="reference external" href="http://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative">softmax post</a>, we <em>linearize</em> the 2D matrix <em>W</em> into a single vector with <em>NT</em> elements using some approach like row-major, where the <em>N</em> elements of the first row go first, then the <em>N</em> elements of the second row, and so on until we have <em>NT</em> elements for all the rows.</p> <p class="last">This is a fully general approach as we can linearize any-dimensional arrays. To work with Jacobians, we're interested in <em>K</em> inputs, no matter where they came from - they could be a linearization of a 4D array. 
As long as we remember which element out of the <em>K</em> corresponds to which original element, we'll be fine.</p> </td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id4" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id2"></a></td><td>In some cases you may hear about <em>averaging</em> the gradients across the batch. Averaging just means dividing the sum by <em>B</em>; it's a constant factor that can be consolidated into the learning rate.</td></tr> </tbody> </table> </div> Depthwise separable convolutions for machine learning2018-04-04T06:21:00-07:002018-04-04T06:21:00-07:00Eli Benderskytag:eli.thegreenplace.net,2018-04-04:/2018/depthwise-separable-convolutions-for-machine-learning/<p>Convolutions are an important tool in modern deep neural networks (DNNs). This post is going to discuss some common types of convolutions, specifically regular and depthwise separable convolutions. My focus will be on the implementation of these operations, showing from-scratch Numpy-based code to compute them and diagrams that explain how …</p><p>Convolutions are an important tool in modern deep neural networks (DNNs). This post is going to discuss some common types of convolutions, specifically regular and depthwise separable convolutions. My focus will be on the implementation of these operations, showing from-scratch Numpy-based code to compute them and diagrams that explain how things work.</p> <p>Note that my main goal here is to explain how depthwise separable convolutions differ from regular ones; if you're completely new to convolutions I suggest reading some more introductory resources first.</p> <p>The code here is compatible with TensorFlow's definition of convolutions in the <a class="reference external" href="https://www.tensorflow.org/api_docs/python/tf/nn">tf.nn</a> module.
After reading this post, the documentation of TensorFlow's convolution ops should be easy to decipher.</p> <div class="section" id="basic-2d-convolution"> <h2>Basic 2D convolution</h2> <p>The basic idea behind a 2D convolution is sliding a small window (usually called a &quot;filter&quot;) over a larger 2D array, and performing a dot product between the filter elements and the corresponding input array elements at every position.</p> <p>Here's a diagram demonstrating the application of a 3x3 convolution filter to a 6x6 array, in 3 different positions. <tt class="docutils literal">W</tt> is the filter, and the yellow-ish array on the right is the result; the red square shows which element in the result array is being computed.</p> <object class="align-center" data="https://eli.thegreenplace.net/images/2018/conv2d-single-block.svg" style="width: 400px;" type="image/svg+xml"> Single-channel 2D convolution</object> <p>The topmost diagram shows the important concept of <em>padding</em>: what should we do when the window goes &quot;out of bounds&quot; on the input array. There are several options, with the following two being most common in DNNs:</p> <ul class="simple"> <li><em>Valid</em> padding: in which only valid, in-bounds windows are considered. This also makes the output smaller than the input, because border elements can't be in the center of a filter (unless the filter is 1x1).</li> <li><em>Same</em> padding: in which we assume there's some constant value outside the bounds of the input (usually 0) and the filter is applied to every element. In this case the output array has the same size as the input array. The diagrams above depict same padding, which I'll keep using throughout the post.</li> </ul> <p>There are other options for the basic 2D convolution case. For example, the filter can be moving over the input in jumps of more than 1, thus not centering on all elements. This is called <em>stride</em>, and in this post I'm always using stride of 1. 
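</p> <p>As an aside (this helper is my own sketch, not from the original post), the effect of the padding mode and stride on the output's spatial size follows the standard formula <tt class="docutils literal">(n + 2p - f) // s + 1</tt>:</p>

```python
def conv2d_output_size(n, f, padding='same', stride=1):
    """Spatial output size along one dimension.

    n is the input size and f the (odd) filter size along that dimension.
    """
    assert padding in ('same', 'valid')
    pad = f // 2 if padding == 'same' else 0  # zeros added on each side
    return (n + 2 * pad - f) // stride + 1

# The 6x6 input with a 3x3 filter from the diagrams above:
print(conv2d_output_size(6, 3, 'same'))    # 6 - output has the input's size
print(conv2d_output_size(6, 3, 'valid'))   # 4 - border positions are dropped
```

<p>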
Convolutions can also be dilated (or <em>atrous</em>), wherein the filter is expanded with gaps between every element. In this post I'm not going to discuss dilated convolutions and other options - there are plenty of resources on these topics online.</p> </div> <div class="section" id="implementing-the-2d-convolution"> <h2>Implementing the 2D convolution</h2> <p>Here is a full Python implementation of the simple 2D convolution. It's called &quot;single channel&quot; to distinguish it from the more general case in which the input has more than two dimensions; we'll get to that shortly.</p> <p>This implementation is fully self-contained, and only needs Numpy to work. All the loops are fully explicit - I specifically avoided vectorizing them for efficiency to maintain clarity:</p> <div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">conv2d_single_channel</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">w</span><span class="p">):</span> <span class="sd">&quot;&quot;&quot;Two-dimensional convolution of a single channel.</span> <span class="sd"> Uses SAME padding with 0s, a stride of 1 and no dilation.</span> <span class="sd"> input: input array with shape (height, width)</span> <span class="sd"> w: filter array with shape (fd, fd) with odd fd.</span> <span class="sd"> Returns a result with the same shape as input.</span> <span class="sd"> &quot;&quot;&quot;</span> <span class="k">assert</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="ow">and</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> 
<span class="o">%</span> <span class="mi">2</span> <span class="o">==</span> <span class="mi">1</span> <span class="c1"># SAME padding with zeros: creating a new padded array to simplify index</span> <span class="c1"># calculations and to avoid checking boundary conditions in the inner loop.</span> <span class="c1"># padded_input is like input, but padded on all sides with</span> <span class="c1"># half-the-filter-width of zeros.</span> <span class="n">padded_input</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">pad</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">pad_width</span><span class="o">=</span><span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s1">&#39;constant&#39;</span><span class="p">,</span> <span class="n">constant_values</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="n">output</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">output</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">output</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]):</span> <span class="c1"># This inner double 
loop computes every output element, by</span> <span class="c1"># multiplying the corresponding window into the input with the</span> <span class="c1"># filter.</span> <span class="k">for</span> <span class="n">fi</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span> <span class="k">for</span> <span class="n">fj</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]):</span> <span class="n">output</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">]</span> <span class="o">+=</span> <span class="n">padded_input</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="n">fi</span><span class="p">,</span> <span class="n">j</span> <span class="o">+</span> <span class="n">fj</span><span class="p">]</span> <span class="o">*</span> <span class="n">w</span><span class="p">[</span><span class="n">fi</span><span class="p">,</span> <span class="n">fj</span><span class="p">]</span> <span class="k">return</span> <span class="n">output</span> </pre></div> </div> <div class="section" id="convolutions-in-3-and-4-dimensions"> <h2>Convolutions in 3 and 4 dimensions</h2> <p>The convolution computed above works in two dimensions; yet, most convolutions used in DNNs are 4-dimensional. For example, TensorFlow's <tt class="docutils literal">tf.nn.conv2d</tt> op takes a 4D input tensor and a 4D filter tensor. How come?</p> <p>The two additional dimensions in the input tensor are <em>channel</em> and <em>batch</em>. A canonical example of channels is color images in RGB format. 
Each pixel has a value for red, green and blue - three channels overall. So instead of seeing it as a matrix of triples, we can see it as a 3D tensor where one dimension is height, another width and another channel (also called the <em>depth</em> dimension).</p> <p>Batch is somewhat different. ML training - with stochastic gradient descent - is often done in batches for performance; we train the model not on a single sample at a time, but a &quot;batch&quot; of samples, usually some power of two. Performing all the operations in tandem on a batch of data makes it easier to leverage the SIMD capabilities of modern processors. So it doesn't have any mathematical significance here - it can be seen as an outer loop over all operations, performing them for a set of inputs and producing a corresponding set of outputs.</p> <p>For filters, the 4 dimensions are height, width, input channel and output channel. Input channel is the same as the input tensor's; output channel collects multiple filters, each of which can be different.</p> <p>This can be slightly difficult to grasp from text, so here's a diagram:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/2018/conv2d-3d.svg" style="width: 300px;" type="image/svg+xml"> Multi-channel 2D convolution</object> <p>In the diagram and the implementation I'm going to ignore the batch dimension, since it's not really mathematically interesting. So the input image has three dimensions - in this diagram height and width are 8 and depth is 3. The filter is 3x3 with depth 3. In each step, the filter is slid over the input <em>in two dimensions</em>, and all of its elements are multiplied with the corresponding elements in the input. 
That's 3x3x3=27 multiplications added into the output element.</p> <p>Note that this is different from a 3D convolution, where a filter is moved across the input in all 3 dimensions; true 3D convolutions are not widely used in DNNs at this time.</p> <p>So, to reiterate, to compute the multi-channel convolution as shown in the diagram above, we compute each of the 64 output elements by a dot-product of the filter with the relevant parts of the input tensor. This produces a single output channel. To produce additional output channels, we perform the convolution with additional filters. So if our filter has dimensions (3, 3, 3, 4) this means 4 different 3x3x3 filters. The output will thus have dimensions 8x8 for the spatials and 4 for depth.</p> <p>Here's the Numpy implementation of this algorithm:</p> <div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">conv2d_multi_channel</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">w</span><span class="p">):</span> <span class="sd">&quot;&quot;&quot;Two-dimensional convolution with multiple channels.</span> <span class="sd"> Uses SAME padding with 0s, a stride of 1 and no dilation.</span> <span class="sd"> input: input array with shape (height, width, in_depth)</span> <span class="sd"> w: filter array with shape (fd, fd, in_depth, out_depth) with odd fd.</span> <span class="sd"> in_depth is the number of input channels, and has to be the same as</span> <span class="sd"> input&#39;s in_depth; out_depth is the number of output channels.</span> <span class="sd"> Returns a result with shape (height, width, out_depth).</span> <span class="sd"> &quot;&quot;&quot;</span> <span class="k">assert</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">w</span><span class="o">.</span><span
class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="ow">and</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">%</span> <span class="mi">2</span> <span class="o">==</span> <span class="mi">1</span> <span class="n">padw</span> <span class="o">=</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">//</span> <span class="mi">2</span> <span class="n">padded_input</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">pad</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">pad_width</span><span class="o">=</span><span class="p">((</span><span class="n">padw</span><span class="p">,</span> <span class="n">padw</span><span class="p">),</span> <span class="p">(</span><span class="n">padw</span><span class="p">,</span> <span class="n">padw</span><span class="p">),</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">)),</span> <span class="n">mode</span><span class="o">=</span><span class="s1">&#39;constant&#39;</span><span class="p">,</span> <span class="n">constant_values</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">,</span> <span class="n">in_depth</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">shape</span> <span class="k">assert</span> <span class="n">in_depth</span> <span class="o">==</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span 
class="p">]</span> <span class="n">out_depth</span> <span class="o">=</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span> <span class="n">output</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">,</span> <span class="n">out_depth</span><span class="p">))</span> <span class="k">for</span> <span class="n">out_c</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">out_depth</span><span class="p">):</span> <span class="c1"># For each output channel, perform 2d convolution summed across all</span> <span class="c1"># input channels.</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">height</span><span class="p">):</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">width</span><span class="p">):</span> <span class="c1"># Now the inner loop also works across all input channels.</span> <span class="k">for</span> <span class="n">c</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">in_depth</span><span class="p">):</span> <span class="k">for</span> <span class="n">fi</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span> <span class="k">for</span> <span class="n">fj</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">w</span><span class="o">.</span><span 
class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]):</span> <span class="n">w_element</span> <span class="o">=</span> <span class="n">w</span><span class="p">[</span><span class="n">fi</span><span class="p">,</span> <span class="n">fj</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span> <span class="n">out_c</span><span class="p">]</span> <span class="n">output</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="n">out_c</span><span class="p">]</span> <span class="o">+=</span> <span class="p">(</span> <span class="n">padded_input</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="n">fi</span><span class="p">,</span> <span class="n">j</span> <span class="o">+</span> <span class="n">fj</span><span class="p">,</span> <span class="n">c</span><span class="p">]</span> <span class="o">*</span> <span class="n">w_element</span><span class="p">)</span> <span class="k">return</span> <span class="n">output</span> </pre></div> <p>An interesting point to note here w.r.t. TensorFlow's <tt class="docutils literal">tf.nn.conv2d</tt> op. If you read its semantics you'll see discussion of <em>layout</em> or <em>data format</em>, which is <tt class="docutils literal">NHWC</tt> by default. NHWC simply means the order of dimensions in a 4D tensor is:</p> <ul class="simple"> <li><strong>N</strong>: batch</li> <li><strong>H</strong>: height (spatial dimension)</li> <li><strong>W</strong>: width (spatial dimension)</li> <li><strong>C</strong>: channel (depth)</li> </ul> <p><tt class="docutils literal">NHWC</tt> is the default layout for TensorFlow; another commonly used layout is <tt class="docutils literal">NCHW</tt>, because it's the format preferred by NVIDIA's DNN libraries. 
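</p> <p>Converting between the two layouts is just an axis permutation; here's a small sketch (my addition, not from the original post) using Numpy:</p>

```python
import numpy as np

# A batch of 10 images, 8x8 pixels, 3 channels, in NHWC layout.
x_nhwc = np.zeros((10, 8, 8, 3))

# Permute the axes to get NCHW: batch, channel, height, width.
x_nchw = np.transpose(x_nhwc, (0, 3, 1, 2))
print(x_nchw.shape)  # (10, 3, 8, 8)
```

<p>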
The code samples here follow the default.</p> </div> <div class="section" id="depthwise-convolution"> <h2>Depthwise convolution</h2> <p>Depthwise convolutions are a variation on the operation discussed so far. In the regular 2D convolution performed over multiple input channels, the filter is as deep as the input and lets us freely mix channels to generate each element in the output. Depthwise convolutions don't do that - each channel is kept separate - hence the name <em>depthwise</em>. Here's a diagram to help explain how that works:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/2018/conv2d-depthwise.svg" style="width: 500px;" type="image/svg+xml"> Depthwise 2D convolution</object> <p>There are three conceptual stages here:</p> <ol class="arabic simple"> <li>Split the input into channels, and split the filter into channels (the number of channels between input and filter must match).</li> <li>For each of the channels, convolve the input with the corresponding filter, producing an output tensor (2D).</li> <li>Stack the output tensors back together.</li> </ol> <p>Here's the code implementing it:</p> <div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">depthwise_conv2d</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">w</span><span class="p">):</span> <span class="sd">&quot;&quot;&quot;Two-dimensional depthwise convolution.</span> <span class="sd"> Uses SAME padding with 0s, a stride of 1 and no dilation. 
A single output</span> <span class="sd"> channel is used per input channel (channel_multiplier=1).</span> <span class="sd"> input: input array with shape (height, width, in_depth)</span> <span class="sd"> w: filter array with shape (fd, fd, in_depth)</span> <span class="sd"> Returns a result with shape (height, width, in_depth).</span> <span class="sd"> &quot;&quot;&quot;</span> <span class="k">assert</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="ow">and</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">%</span> <span class="mi">2</span> <span class="o">==</span> <span class="mi">1</span> <span class="n">padw</span> <span class="o">=</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">//</span> <span class="mi">2</span> <span class="n">padded_input</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">pad</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">pad_width</span><span class="o">=</span><span class="p">((</span><span class="n">padw</span><span class="p">,</span> <span class="n">padw</span><span class="p">),</span> <span class="p">(</span><span class="n">padw</span><span class="p">,</span> <span class="n">padw</span><span class="p">),</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">)),</span> <span class="n">mode</span><span class="o">=</span><span 
class="s1">&#39;constant&#39;</span><span class="p">,</span> <span class="n">constant_values</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">,</span> <span class="n">in_depth</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">shape</span> <span class="k">assert</span> <span class="n">in_depth</span> <span class="o">==</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="n">output</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">,</span> <span class="n">in_depth</span><span class="p">))</span> <span class="k">for</span> <span class="n">c</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">in_depth</span><span class="p">):</span> <span class="c1"># For each input channel separately, apply its corresponsing filter</span> <span class="c1"># to the input.</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">height</span><span class="p">):</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">width</span><span class="p">):</span> <span class="k">for</span> <span class="n">fi</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span> <span class="k">for</span> <span 
class="n">fj</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]):</span> <span class="n">w_element</span> <span class="o">=</span> <span class="n">w</span><span class="p">[</span><span class="n">fi</span><span class="p">,</span> <span class="n">fj</span><span class="p">,</span> <span class="n">c</span><span class="p">]</span> <span class="n">output</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="n">c</span><span class="p">]</span> <span class="o">+=</span> <span class="p">(</span> <span class="n">padded_input</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="n">fi</span><span class="p">,</span> <span class="n">j</span> <span class="o">+</span> <span class="n">fj</span><span class="p">,</span> <span class="n">c</span><span class="p">]</span> <span class="o">*</span> <span class="n">w_element</span><span class="p">)</span> <span class="k">return</span> <span class="n">output</span> </pre></div> <p>In TensorFlow, the corresponding op is <tt class="docutils literal">tf.nn.depthwise_conv2d</tt>; this op has the notion of <em>channel multiplier</em> which lets us compute multiple outputs for each input channel (somewhat like the number of output channels concept in <tt class="docutils literal">conv2d</tt>).</p> </div> <div class="section" id="depthwise-separable-convolution"> <h2>Depthwise separable convolution</h2> <p>The depthwise convolution shown above is more commonly used in combination with an additional step to mix in the channels - <em>depthwise separable convolution</em> <a class="footnote-reference" href="#id2" id="id1"></a>:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/2018/conv2d-depthwise-separable.svg" style="width: 
500px;" type="image/svg+xml"> Depthwise separable convolution</object> <p>After completing the depthwise convolution, an additional step is performed: a 1x1 convolution across channels. This is exactly the same operation as the &quot;convolution in 3 dimensions&quot; discussed earlier - just with a 1x1 spatial filter. This step can be repeated multiple times for different output channels. The output channels all take the output of the depthwise step and mix it up with different 1x1 convolutions. Here's the implementation:</p> <div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">separable_conv2d</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">w_depth</span><span class="p">,</span> <span class="n">w_pointwise</span><span class="p">):</span> <span class="sd">&quot;&quot;&quot;Depthwise separable convolution.</span> <span class="sd"> Performs 2d depthwise convolution with w_depth, and then applies a pointwise</span> <span class="sd"> 1x1 convolution with w_pointwise on the result.</span> <span class="sd"> Uses SAME padding with 0s, a stride of 1 and no dilation. A single output</span> <span class="sd"> channel is used per input channel (channel_multiplier=1) in w_depth.</span> <span class="sd"> input: input array with shape (height, width, in_depth)</span> <span class="sd"> w_depth: depthwise filter array with shape (fd, fd, in_depth)</span> <span class="sd"> w_pointwise: pointwise filter array with shape (in_depth, out_depth)</span> <span class="sd"> Returns a result with shape (height, width, out_depth).</span> <span class="sd"> &quot;&quot;&quot;</span> <span class="c1"># First run the depthwise convolution. 
Its result has the same shape as</span> <span class="c1"># input.</span> <span class="n">depthwise_result</span> <span class="o">=</span> <span class="n">depthwise_conv2d</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">w_depth</span><span class="p">)</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">,</span> <span class="n">in_depth</span> <span class="o">=</span> <span class="n">depthwise_result</span><span class="o">.</span><span class="n">shape</span> <span class="k">assert</span> <span class="n">in_depth</span> <span class="o">==</span> <span class="n">w_pointwise</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="n">out_depth</span> <span class="o">=</span> <span class="n">w_pointwise</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="n">output</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">,</span> <span class="n">out_depth</span><span class="p">))</span> <span class="k">for</span> <span class="n">out_c</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">out_depth</span><span class="p">):</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">height</span><span class="p">):</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">width</span><span class="p">):</span> <span class="k">for</span> <span class="n">c</span> <span class="ow">in</span> 
<span class="nb">range</span><span class="p">(</span><span class="n">in_depth</span><span class="p">):</span> <span class="n">w_element</span> <span class="o">=</span> <span class="n">w_pointwise</span><span class="p">[</span><span class="n">c</span><span class="p">,</span> <span class="n">out_c</span><span class="p">]</span> <span class="n">output</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="n">out_c</span><span class="p">]</span> <span class="o">+=</span> <span class="n">depthwise_result</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="n">c</span><span class="p">]</span> <span class="o">*</span> <span class="n">w_element</span> <span class="k">return</span> <span class="n">output</span> </pre></div> <p>In TensorFlow, this op is called <tt class="docutils literal">tf.nn.separable_conv2d</tt>. Similarly to our implementation, it takes two different filter parameters: <tt class="docutils literal">depthwise_filter</tt> for the depthwise step and <tt class="docutils literal">pointwise_filter</tt> for the mixing step.</p> <p>Depthwise separable convolutions have become popular in DNN models recently, for two reasons:</p> <ol class="arabic simple"> <li>They have fewer parameters than &quot;regular&quot; convolutional layers, and thus are less prone to overfitting.</li> <li>With fewer parameters, they also require fewer operations to compute, and thus are cheaper and faster.</li> </ol> <p>Let's examine the difference between the number of parameters first. 
We'll start with some definitions:</p> <ul class="simple"> <li><tt class="docutils literal">S</tt>: spatial dimension - width and height, assuming square inputs.</li> <li><tt class="docutils literal">F</tt>: filter width and height, assuming square filter.</li> <li><tt class="docutils literal">inC</tt>: number of input channels.</li> <li><tt class="docutils literal">outC</tt>: number of output channels.</li> </ul> <p>We also assume <tt class="docutils literal">SAME</tt> padding as discussed above, so that the spatial size of the output matches the input.</p> <p>In a regular convolution there are <tt class="docutils literal">F*F*inC*outC</tt> parameters, because every filter is 3D and there's one such filter per output channel.</p> <p>In depthwise separable convolutions there are <tt class="docutils literal">F*F*inC</tt> parameters for the depthwise part, and then <tt class="docutils literal">inC*outC</tt> parameters for the mixing part. It should be obvious that for a non-trivial <tt class="docutils literal">outC</tt>, the sum of these two is significantly smaller than <tt class="docutils literal">F*F*inC*outC</tt>.</p> <p>Now on to computational cost. For a regular convolution, we perform <tt class="docutils literal">F*F*inC</tt> operations at each position of the input (to compute the 2D convolution over 3 dimensions). For the whole input, the number of computations is thus <tt class="docutils literal">F*F*inC*S*S</tt> and taking all the output channels we get <tt class="docutils literal">F*F*inC*S*S*outC</tt>.</p> <p>For depthwise separable convolutions we need <tt class="docutils literal">F*F*inC*S*S</tt> operations for the depthwise part; then we need <tt class="docutils literal">S*S*inC*outC</tt> operations for the mixing part. 
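These formulas are easy to wrap in a small helper for experimenting with the counts (a sketch; the function name is ours):

```python
def conv_costs(S, F, inC, outC):
    """Parameter and operation counts for regular vs. depthwise separable
    convolution, following the definitions above (square input of side S,
    square filter of side F, SAME padding, stride 1)."""
    regular_params = F * F * inC * outC
    regular_ops = F * F * inC * S * S * outC
    separable_params = F * F * inC + inC * outC
    separable_ops = F * F * inC * S * S + S * S * inC * outC
    return regular_params, regular_ops, separable_params, separable_ops

print(conv_costs(128, 3, 3, 16))  # (432, 7077888, 75, 1228800)
```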
Let's use some real numbers to get a feel for the difference:</p> <p>We'll assume <tt class="docutils literal">S=128</tt>, <tt class="docutils literal">F=3</tt>, <tt class="docutils literal">inC=3</tt>, <tt class="docutils literal">outC=16</tt>. For regular convolution:</p> <ul class="simple"> <li>Parameters: <tt class="docutils literal">3*3*3*16 = 432</tt></li> <li>Computation cost: <tt class="docutils literal">3*3*3*128*128*16 = ~7e6</tt></li> </ul> <p>For depthwise separable convolution:</p> <ul class="simple"> <li>Parameters: <tt class="docutils literal">3*3*3+3*16 = 75</tt></li> <li>Computation cost: <tt class="docutils literal">3*3*3*128*128+128*128*3*16 = ~1.2e6</tt></li> </ul> <hr class="docutils" /> <table class="docutils footnote" frame="void" id="id2" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id1"></a></td><td>The term <em>separable</em> comes from image processing, where <em>spatially separable convolutions</em> are sometimes used to save on computation resources. A spatial convolution is separable when the 2D convolution filter can be expressed as an outer product of two vectors. This lets us compute some 2D convolutions more cheaply. In the case of DNNs, the spatial filter is not necessarily separable but the channel dimension is separable from the spatial dimensions.</td></tr> </tbody> </table> </div> The Confusion Matrix in statistical tests2018-03-26T05:47:00-07:002018-03-26T05:47:00-07:00Eli Benderskytag:eli.thegreenplace.net,2018-03-26:/2018/the-confusion-matrix-in-statistical-tests/<p>This winter was one of the worst flu seasons in recent years, so I found myself curious to learn more about the diagnostic flu tests available to doctors in addition to the usual &quot;looks like bad cold but no signs of bacteria&quot; strategy. 
There's a wide array of RIDTs (Rapid …</p><p>This winter was one of the worst flu seasons in recent years, so I found myself curious to learn more about the diagnostic flu tests available to doctors in addition to the usual &quot;looks like bad cold but no signs of bacteria&quot; strategy. There's a wide array of RIDTs (Rapid Influenza Diagnostic Tests) available to doctors today <a class="footnote-reference" href="#id3" id="id1"></a>, and reading through the literature quickly has you deciphering statements like:</p> <blockquote> Overall, RIDTs had a modest <em>sensitivity</em> of 62.3% and a high <em>specificity</em> of 98.2%, corresponding to a <em>positive likelihood ratio</em> of 34.5 and a <em>negative likelihood ratio of 0.38</em>. For the clinician, this means that although <em>false-negatives</em> are frequent (occurring in nearly four out of ten negative RIDTs), a positive test is unlikely to be a <em>false-positive</em> result. A diagnosis of influenza can thus confidently be made in the presence of a positive RIDT. However, a negative RIDT result is unreliable and should be confirmed by traditional diagnostic tests if the result is likely to affect patient management.</blockquote> <p>While I had heard about statistical test quality measures like <em>sensitivity</em> before, there are too many terms here to remember for someone not dealing with these things routinely; this post is my attempt at documenting this understanding for future use.</p> <div class="section" id="a-table-of-test-outcomes"> <h2>A table of test outcomes</h2> <p>Let's say there is a condition with a binary outcome (&quot;yes&quot; vs. &quot;no&quot;, 1 vs 0, or whatever you want to call it). Suppose we conduct a test that is designed to detect this condition; the test also has a binary outcome. 
The totality of outcomes can thus be represented with a 2-by-2 table, which is also called the <a class="reference external" href="https://en.wikipedia.org/wiki/Confusion_matrix">Confusion Matrix</a>.</p> <p>Suppose 10000 patients get tested for flu; out of them, 9000 are actually healthy and 1000 are actually sick. For the sick people, a test was positive for 620 and negative for 380. For the healthy people, the same test was positive for 180 and negative for 8820. Let's summarize these results in a table:</p> <img alt="Confusion matrix with numbers only" class="align-center" src="https://eli.thegreenplace.net/images/2018/confusionmatrix.png" /> <p>Now comes our first batch of definitions.</p> <ul class="simple"> <li><strong>True Positive (TP)</strong>: positive test result matches reality - person is actually sick and tested positive.</li> <li><strong>False Positive (FP)</strong>: positive test result doesn't match reality - test is positive but the person is not actually sick.</li> <li><strong>True Negative (TN)</strong>: negative test result matches reality - person is not sick and tested negative.</li> <li><strong>False Negative (FN)</strong>: negative test result doesn't match reality - test is negative but the person is actually sick.</li> </ul> <p>Folks get confused with these often, so here's a useful heuristic: positive vs. negative reflects the test outcome; true vs. 
false reflects whether the test got it right or got it wrong.</p> <p>Since the rest of the definitions build upon these, here's the confusion matrix again now with them embedded:</p> <img alt="Confusion matrix with TP, FP, TN, FN marked" class="align-center" src="https://eli.thegreenplace.net/images/2018/confusionmatrix-tptnfpfn.png" /> </div> <div class="section" id="definition-soup"> <h2>Definition soup</h2> <p>Armed with these and <strong>N</strong> for the <em>total population</em> (10000 in our case), we are now ready to tackle the multitude of definitions statisticians have produced over the years to describe the performance of tests:</p> <ul class="simple"> <li><strong>Prevalence</strong>: how common is the actual disease in the population<ul> <li>(FN+TP)/N</li> <li>In the example: (380+620)/10000=0.1</li> </ul> </li> <li><strong>Accuracy</strong>: how often is the test correct<ul> <li>(TP+TN)/N</li> <li>In the example: (620+8820)/10000=0.944</li> </ul> </li> <li><strong>Misclassification rate</strong>: how often the test is wrong<ul> <li>1 - Accuracy = (FP+FN)/N</li> <li>In the example: (180+380)/10000=0.056</li> </ul> </li> <li><strong>Sensitivity</strong> or <strong>True Positive Rate (TPR)</strong> or <strong>Recall</strong>: when the patient is sick, how often does the test actually predict it correctly<ul> <li>TP/(TP+FN)</li> <li>In the example: 620/(620+380)=0.62</li> </ul> </li> <li><strong>Specificity</strong> or <strong>True Negative Rate (TNR)</strong>: when the patient is not sick, how often does the test actually predict it correctly<ul> <li>TN/(TN+FP)</li> <li>In the example: 8820/(8820+180)=0.98</li> </ul> </li> <li><strong>False Positive Rate (FPR)</strong>: probability of false alarm<ul> <li>1 - Specificity = FP/(TN+FP)</li> <li>In the example: 180/(8820+180)=0.02</li> </ul> </li> <li><strong>False Negative Rate (FNR)</strong>: miss rate, the probability that the test misses a sickness<ul> <li>1 - Sensitivity = FN/(TP+FN)</li> <li>In the 
example: 380/(620+380)=0.38</li> </ul> </li> <li><strong>Precision</strong> or <strong>Positive Predictive Value (PPV)</strong>: when the prediction is positive, how often is it correct<ul> <li>TP/(TP+FP)</li> <li>In the example: 620/(620+180)=0.775</li> </ul> </li> <li><strong>Negative Predictive Value (NPV)</strong>: when the prediction is negative, how often is it correct<ul> <li>TN/(TN+FN)</li> <li>In the example: 8820/(8820+380)=0.959</li> </ul> </li> <li><strong>Positive Likelihood Ratio</strong>: the ratio of the probability of a positive result for a sick person to that for a healthy person (used with odds formulations of probability)<ul> <li>TPR/FPR</li> <li>In the example: 0.62/0.02=31</li> </ul> </li> <li><strong>Negative Likelihood Ratio</strong>: the ratio of the probability of a negative result for a sick person to that for a healthy person<ul> <li>FNR/TNR</li> <li>In the example: 0.38/0.98=0.388</li> </ul> </li> </ul> <p><a class="reference external" href="https://en.wikipedia.org/wiki/Confusion_matrix">The wikipedia page</a> has even more.</p> </div> <div class="section" id="deciphering-our-example"> <h2>Deciphering our example</h2> <p>Now back to the flu test example this post began with. RIDTs are said to have sensitivity of 62.3%; this is just a clever way of saying that for a person with flu, the test will be positive 62.3% of the time. 
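The definition soup above can be double-checked mechanically from the four counts in the example's confusion matrix; here's a plain-Python sanity check (variable names are ours):

```python
# Counts from the example's confusion matrix (10000-patient sample).
TP, FP, FN, TN = 620, 180, 380, 8820
N = TP + FP + FN + TN

prevalence  = (TP + FN) / N        # 0.1
accuracy    = (TP + TN) / N        # 0.944
sensitivity = TP / (TP + FN)       # 0.62  - TPR, recall
specificity = TN / (TN + FP)       # 0.98  - TNR
fpr         = FP / (TN + FP)       # 0.02
fnr         = FN / (TP + FN)       # 0.38
precision   = TP / (TP + FP)       # 0.775 - PPV
npv         = TN / (TN + FN)       # ~0.959
plr         = sensitivity / fpr    # ~31
nlr         = fnr / specificity    # ~0.388
```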
For people who do not have the flu, the test is more accurate since its specificity is 98.2% - only 1.8% of healthy people will be flagged positive.</p> <p>The positive likelihood ratio is said to be 34.5; let's see how it was computed:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/85bd9dc7996bdc9198da6226a73cfb1e900734ea.svg" style="height: 41px;" type="image/svg+xml"> $PLR=\frac{TPR}{FPR}=\frac{Sensitivity}{1-Specificity}=\frac{0.623}{1-0.982}=35$</object> <p>This is to say - a positive result is about 35 times more likely for a sick person than for a healthy one.</p> <p>And the negative likelihood ratio is said to be 0.38:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/895921ee70c59a531ceaff02c94404e5d5316697.svg" style="height: 41px;" type="image/svg+xml"> $NLR=\frac{FNR}{TNR}=\frac{1-Sensitivity}{Specificity}=\frac{1-0.623}{0.982}=0.38$</object> <p>This is to say - a negative result is only about 0.38 times as likely for a sick person as for a healthy one.</p> <p>In other words, a positive result from these flu tests is quite trustworthy, while a negative result is much less reliable - which is exactly what the quoted paragraph at the top of the post ends up saying.</p> </div> <div class="section" id="back-to-bayes"> <h2>Back to Bayes</h2> <p>An astute reader will notice that the previous sections talk about the probability of test outcomes given sickness, when we're usually interested in the opposite - given a positive test, how likely is it that the person is actually sick.</p> <p><a class="reference external" href="https://eli.thegreenplace.net/2018/conditional-probability-and-bayes-theorem/">My previous post on the Bayes theorem</a> covered this issue in depth <a class="footnote-reference" href="#id4" id="id2"></a>. Let's recap, using the actual numbers from our example. 
The events are:</p> <ul class="simple"> <li><img alt="T" class="valign-0" src="https://eli.thegreenplace.net/images/math/c2c53d66948214258a26ca9ca845d7ac0c17f8e7.png" style="height: 12px;" />: test is positive</li> <li><object class="valign-0" data="https://eli.thegreenplace.net/images/math/0e4c77261e251cb98e8cedc2b74772ae6f14318d.svg" style="height: 15px;" type="image/svg+xml">T^C</object>: test is negative</li> <li><object class="valign-0" data="https://eli.thegreenplace.net/images/math/e69f20e9f683920d3fb4329abd951e878b1f9372.svg" style="height: 12px;" type="image/svg+xml">F</object>: person actually sick with flu</li> <li><object class="valign-0" data="https://eli.thegreenplace.net/images/math/9f37c126895eff99088c5545433d3e33692aa267.svg" style="height: 15px;" type="image/svg+xml">F^C</object>: person doesn't have flu</li> </ul> <p>Sensitivity of 0.623 means <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/27fd26004577f906dcaaefbf0553b7232c432d0a.svg" style="height: 18px;" type="image/svg+xml">P(T|F)=0.623</object>; similarly, specificity is <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/72e4f07dc70098feeb50032960ebe0478451e087.svg" style="height: 19px;" type="image/svg+xml">P(T^C|F^C)=0.982</object>. 
We're interested in finding <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/b8a3711b3dc7ec3b1a73a82447c88ce067999fed.svg" style="height: 18px;" type="image/svg+xml">P(F|T)</object>, and we can use the Bayes theorem for that:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/2d2d2b1a53f1b7a76155a9abe9b7b5d35198e669.svg" style="height: 42px;" type="image/svg+xml"> $P(F|T)=\frac{P(T|F)P(F)}{P(T)}$</object> <p>Recall that <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/ce2db3a1be8bbfefc6a7bdee94dcaca4f5426799.svg" style="height: 18px;" type="image/svg+xml">P(F)</object> is the <em>prevalence</em> of flu in the general population; for the sake of this example let's assume it's 0.1; we'll then compute <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/9d3d6e3b10a97d19adafbea8cc72b8e3619a1d27.svg" style="height: 18px;" type="image/svg+xml">P(T)</object> by using the law of total probability as follows:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/7b37bf70ff6ae37f99ca38df902a48adbf14a542.svg" style="height: 21px;" type="image/svg+xml"> $P(T)=P(T|F)P(F)+P(T|F^C)P(F^C)$</object> <p>Obviously, <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/a7828371e81491815c4498bc5dcb7407b63beeed.svg" style="height: 19px;" type="image/svg+xml">P(T|F^C)=1-P(T^C|F^C)=0.018</object>, so:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/9f16b641c00488170b04f5bb409a5d0d21c958c8.svg" style="height: 18px;" type="image/svg+xml"> $P(T)=0.623\ast0.1 + 0.018\ast0.9=0.0785$</object> <p>And then:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/41b1c0bec6fa21e5322dd66a756d5e9d9e7253f3.svg" style="height: 36px;" type="image/svg+xml"> $P(F|T)=\frac{0.623\ast 0.1}{0.0785}=0.79$</object> <p>So the probability of having flu given a positive test and a 10% flu prevalence is 79%. 
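The whole chain - Bayes' theorem plus the law of total probability - fits in a tiny function, which makes it easy to replay with other priors (a sketch; the function name and the RIDT-based defaults are ours):

```python
def p_flu_given_positive(prevalence, sensitivity=0.623, specificity=0.982):
    """P(F|T): probability of flu given a positive test.

    P(T) is expanded with the law of total probability, then Bayes'
    theorem gives P(F|T) = P(T|F)P(F) / P(T).
    """
    p_t = sensitivity * prevalence + (1 - specificity) * (1 - prevalence)
    return sensitivity * prevalence / p_t

print(round(p_flu_given_positive(0.1), 2))  # 0.79
```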
The prevalence strongly affects the outcome! Let's plot <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/b8a3711b3dc7ec3b1a73a82447c88ce067999fed.svg" style="height: 18px;" type="image/svg+xml">P(F|T)</object> as a function of <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/ce2db3a1be8bbfefc6a7bdee94dcaca4f5426799.svg" style="height: 18px;" type="image/svg+xml">P(F)</object> for some reasonable range of values:</p> <img alt="P(F|T) as function of prevalence" class="align-center" src="https://eli.thegreenplace.net/images/2018/pft-prevalence-plot.png" /> <p>Note how low the value of the test becomes with low disease prevalence - we've also observed this phenomenon in <a class="reference external" href="https://eli.thegreenplace.net/2018/conditional-probability-and-bayes-theorem/">the previous post</a>; there's a &quot;tug of war&quot; between the prevalence and the test's sensitivity and specificity. In fact, <a class="reference external" href="https://www.cdc.gov/flu/professionals/diagnosis/rapidlab.htm">the official CDC guidelines page</a> for interpreting RIDT results discusses this:</p> <blockquote> When influenza prevalence is relatively low, the positive predictive value (PPV) is low and false-positive test results are more likely. By contrast, when influenza prevalence is low, the negative predictive value (NPV) is high, and negative results are more likely to be true.</blockquote> <p>And then goes on to present a handy table for estimating PPV based on prevalence and specificity.</p> <p>Naturally, the rapid test is not the only tool in the doctor's toolbox. Flu has other symptoms, and by observing them on the patient the doctor can increase their confidence in the diagnosis. 
For example, if the probability <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/b8a3711b3dc7ec3b1a73a82447c88ce067999fed.svg" style="height: 18px;" type="image/svg+xml">P(F|T)</object> given 10% prevalence is 0.79 (as computed above), the doctor may be significantly less sure of the results if flu symptoms like cough and fever are not present. The CDC discusses this in more detail with an <a class="reference external" href="https://www.cdc.gov/flu/professionals/diagnosis/algorithm-results-not-circulating.htm">algorithm for interpreting flu results</a>.</p> <hr class="docutils" /> <table class="docutils footnote" frame="void" id="id3" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id1"></a></td><td>Slower tests like full viral cultures are also available, and they are very accurate. The problem is that these tests take a long time to complete - days - so they're usually not very useful in treating the disease. Anti-viral medication is only useful in the first 48 hours after disease onset. RIDTs provide results within hours, or even minutes.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id4" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id2"></a></td><td>In that post we didn't distinguish between sensitivity and specificity, but assumed they're equal at 90%. 
It's much more common for these measures to be different, but it doesn't actually complicate the computations.</td></tr> </tbody> </table> </div> Conditional probability and Bayes' theorem2018-03-13T05:32:00-07:002018-03-13T05:32:00-07:00Eli Benderskytag:eli.thegreenplace.net,2018-03-13:/2018/conditional-probability-and-bayes-theorem/<p>One morning, while seeing a mention of a disease on Hacker News, Bob decides on a whim to get tested for it; there are no other symptoms, he's just curious. He convinces his doctor to order a blood test, which is known to be 90% accurate. For 9 out of …</p><p>One morning, while seeing a mention of a disease on Hacker News, Bob decides on a whim to get tested for it; there are no other symptoms, he's just curious. He convinces his doctor to order a blood test, which is known to be 90% accurate. For 9 out of 10 sick people it will detect the disease (but for 1 out of 10 it won't); similarly, for 9 out of 10 healthy people it will report no disease (but for 1 out of 10 it will).</p> <p>Unfortunately for Bob, his test is positive; what's the probability that Bob actually has the disease?</p> <p>You might be tempted to say 90%, but this is wrong. One of the most common fallacies made in probability and statistics is mixing up conditional probabilities. Given event D - &quot;Bob has disease&quot; and event T - &quot;test was positive&quot;, we want to know what is <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/28f810c7c292ba7faa5b47a0b4e0d470f79b19d8.svg" style="height: 18px;" type="image/svg+xml">P(D|T)</object> - the conditional probability of D given T. 
But the test result is actually giving us <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/7455e3a5dfb3c45f4a3b110a7906341362a50e53.svg" style="height: 18px;" type="image/svg+xml">P(T|D)</object> - which is distinct from <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/28f810c7c292ba7faa5b47a0b4e0d470f79b19d8.svg" style="height: 18px;" type="image/svg+xml">P(D|T)</object>.</p> <p>In fact, the problem doesn't provide enough details to answer the question. An important detail that's missing is the <em>prevalence</em> of the disease in the population; that is, the value of <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/f85da048bad405b6f81778501db99b331b689d46.svg" style="height: 18px;" type="image/svg+xml">P(D)</object> without being conditioned on anything. Let's say that it's a moderately common disease with 2% prevalence.</p> <p>To solve this without any clever probability formulae, we can resort to the basic technique of counting by cases. Let's assume there is a sample of 10,000 people <a class="footnote-reference" href="#id4" id="id1"></a>; test aside, how many of them have the disease? 2%, so 200.</p> <img alt="Bayes counting disease calculation prevalence" class="align-center" src="https://eli.thegreenplace.net/images/2018/bayes-count-disease-1.png" /> <p>Of the people who have the disease, 90% will test positive and 10% will test negative. Similarly, of the people with no disease, 90% will test negative and 10% will test positive. Graphically:</p> <img alt="Bayes counting disease calculation prevalence and test" class="align-center" src="https://eli.thegreenplace.net/images/2018/bayes-count-disease-2.png" /> <p>Now we just have to count. There are 980 + 180 = 1160 people who tested positive in the sample population. Of these people, 180 have the disease. In other words, given that Bob is in the &quot;tested positive&quot; population, his chance of having the disease is 180/1160 = 15.5%.
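The counting argument translates directly into a few lines of code; here's a sketch using the numbers from the example (10,000 people, 2% prevalence, 90% accuracy):

```python
# Counting by cases: 10,000 people, 2% prevalence, 90% accurate test.
population = 10_000
sick = int(population * 0.02)         # 200 people have the disease
healthy = population - sick           # 9800 don't

true_positives = int(sick * 0.9)      # 180 sick people test positive
false_positives = int(healthy * 0.1)  # 980 healthy people test positive

# Of everyone who tested positive, what fraction is actually sick?
p_sick_given_positive = true_positives / (true_positives + false_positives)
print(f"{p_sick_given_positive:.1%}")  # 15.5%
```
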
This is <em>far</em> lower than the 90% test accuracy; conditional probability often produces surprising results. To motivate this, consider that the number of <em>true positives</em> (people with the disease that tested positive) is 180, while the number of <em>false positives</em> (people w/o the disease that tested positive) is 980. So the chance of being in the second group is larger.</p> <div class="section" id="conditional-probability"> <h2>Conditional probability</h2> <p>As the examples shown above demonstrate, conditional probabilities involve questions like &quot;what's the chance of A happening, given that B happened&quot;, and they are far from being intuitive. Luckily, the mathematical theory of probability gives us the precise and rigorous tools necessary to reason about such problems with relative elegance.</p> <p>The conditional probability <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/c8937022ac4e55a642f3bd850e6e9b17dd8fc8d3.svg" style="height: 18px;" type="image/svg+xml">P(A|B)</object> means &quot;what is the probability of event A given that we know event B occurred&quot;. Its mathematical definition is:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/990efadb78ac3b3145842995f70010710da00dc2.svg" style="height: 42px;" type="image/svg+xml"> $P(A|B)=\frac{P(A\cap B)}{P(B)}$</object> <p>Notes:</p> <ul class="simple"> <li>Obviously, this is only defined when <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/89a84f08607a9c7fe798b67b1d6f7778c6b2e366.svg" style="height: 18px;" type="image/svg+xml">P(B)&gt;0</object>.</li> <li>Here <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/766349d42cbd0bbefc0311ba2e67a3c1da93625f.svg" style="height: 18px;" type="image/svg+xml">P(A\cap B)</object> is the probability that both A and B occurred.</li> </ul> <p>The first time you look at it, the definition of conditional probability looks somewhat unintuitive. 
Why is the connection made this way? Here's a visualization that I found useful:</p> <img alt="Sample space dots visualization for conditional probability" class="align-center" src="https://eli.thegreenplace.net/images/2018/samplespace.png" /> <p>The dots in the black square represent the &quot;universe&quot;, our whole sampling space (let's call it S, and then <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/2db3db98100616af986bd71ff1b3b779df968b9f.svg" style="height: 18px;" type="image/svg+xml">P(S)=1</object>). A and B are events. Here <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/3a0b4d21054323cff53c2226dfc210377dbf4588.svg" style="height: 22px;" type="image/svg+xml">P(A)=\frac{30}{64}</object> and <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/424554d20e0a324069ca970290cb1f937eacb2f5.svg" style="height: 22px;" type="image/svg+xml">P(B)=\frac{18}{64}</object>. But what is <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/c8937022ac4e55a642f3bd850e6e9b17dd8fc8d3.svg" style="height: 18px;" type="image/svg+xml">P(A|B)</object>? Let's figure it out graphically. We know that the outcome is one of the dots encircled in red. What is the chance we got a dot also encircled in blue? It's the number of dots that are both red and blue, divided by the total number of dots in red. 
Probabilities are calculated as these counts normalized by the size of the whole sample space; all the numbers are divided by 64, so these denominators cancel out; we'll have:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/4d783b6592e32a7119193d0bad39843877f398e9.svg" style="height: 42px;" type="image/svg+xml"> $P(A|B)=\frac{P(A\cap B)}{P(B)} = \frac{9}{18} = \frac{1}{2}$</object> <p>In words - the probability that A happened, given that B happened, is 1/2, which makes sense when you eyeball the diagram, and assuming events are uniformly distributed (that is, no dot is inherently more likely to be the outcome than any other dot).</p> <p>Another explanation that always made sense to me was to multiply both sides of the definition of conditional probability by the denominator, to get:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/942d1e523c4b088a86138644ceef8512fd97e877.svg" style="height: 18px;" type="image/svg+xml"> $P(A|B)P(B)=P(A\cap B)$</object> <p>In words: we know the chance that A happens given B; if we multiply this by the chance that B happens, we get the chance both A and B happened.</p> <p>Finally, since <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/526645ebc23e442dd422c3ee7956b3503578d512.svg" style="height: 18px;" type="image/svg+xml">P(A\cap B)=P(B\cap A)</object>, we can freely exchange A and B in these definitions (they're arbitrary labels, after all), to get:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/d748581a273f112cf0af1add7fc7320283f0b226.svg" style="height: 18px;" type="image/svg+xml"> $\begin{equation} P(A\cap B)=P(A|B)P(B)=P(B|A)P(A) \tag{1} \end{equation}$</object> <p>This is an important equation we'll use later on.</p> </div> <div class="section" id="independence-of-events"> <h2>Independence of events</h2> <p>By definition, two events A and B are <em>independent</em> if:</p> <object class="align-center" 
data="https://eli.thegreenplace.net/images/math/39d4d03d68cb696d28d659e5ba0d3c7c1b474a44.svg" style="height: 18px;" type="image/svg+xml"> $P(A\cap B)=P(A)P(B)$</object> <p>Using conditional probability, we can provide a slightly different definition. Since:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/990efadb78ac3b3145842995f70010710da00dc2.svg" style="height: 42px;" type="image/svg+xml"> $P(A|B)=\frac{P(A\cap B)}{P(B)}$</object> <p>And <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/79ab089f069fc76a74136c5331d4655c256b1129.svg" style="height: 18px;" type="image/svg+xml">P(A\cap B)=P(A)P(B)</object>:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/14ce0bcf839194d2662ca7b52f59521a39c8cabc.svg" style="height: 42px;" type="image/svg+xml"> $P(A|B)=\frac{P(A)P(B)}{P(B)}=P(A)$</object> <p>As long as <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/89a84f08607a9c7fe798b67b1d6f7778c6b2e366.svg" style="height: 18px;" type="image/svg+xml">P(B)&gt;0</object>, for independent A and B we have <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/c399868e8ef09ea7c4e694fce59aad863775bdb0.svg" style="height: 18px;" type="image/svg+xml">P(A|B)=P(A)</object>; in words - B doesn't affect the probability of A in any way. Similarly we can show that for <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/ee817c4528261210d55c260023c288b451672b8c.svg" style="height: 18px;" type="image/svg+xml">P(A)&gt;0</object> we have <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/e86a7ac5a9f6584ffbc2ede7f26664b7bc9aeead.svg" style="height: 18px;" type="image/svg+xml">P(B|A)=P(B)</object>.</p> <p>Independence also extends to the complements of events. 
Recall that <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/d90bd1cf9cdf80fe34935388ba98fc83fced8881.svg" style="height: 19px;" type="image/svg+xml">P(B^C)</object> is the probability that B <em>did not</em> occur, or <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/ace42eb9c0d264d434a445f8efa82a392665ed7d.svg" style="height: 18px;" type="image/svg+xml">1-P(B)</object>; since conditional probabilities obey the usual probability axioms, we have: <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/58060444dee52056dd38f0178bc90d973cb4de6e.svg" style="height: 19px;" type="image/svg+xml">P(B^C|A)=1-P(B|A)</object>. Then, if A and B are independent:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/2fd0e5ba3a9f72858b5289c505ee4ffa1a433047.svg" style="height: 21px;" type="image/svg+xml"> $P(B^C|A)=1-P(B)=P(B^C)$</object> <p>Therefore, <object class="valign-0" data="https://eli.thegreenplace.net/images/math/5ff671c7bd1273544cca53c173582f98ff8a099d.svg" style="height: 15px;" type="image/svg+xml">B^C</object> is independent of A. 
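As a quick numeric sanity check of this result, here's a sketch with two hypothetical independent events (the probabilities 0.5 and 0.4 are arbitrary values picked for illustration):

```python
# Two hypothetical independent events: P(A) = 0.5, P(B) = 0.4.
p_a, p_b = 0.5, 0.4
p_a_and_b = p_a * p_b  # independence: P(A ∩ B) = P(A)P(B)

# P(B^C | A) = P(A ∩ B^C) / P(A) = (P(A) - P(A ∩ B)) / P(A)
p_not_b_given_a = (p_a - p_a_and_b) / p_a

# This equals P(B^C) = 1 - P(B), so B's complement is independent of A.
print(p_not_b_given_a, 1 - p_b)  # both are 0.6
```
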
Similarly the complement of A is independent of B, and the two complements are independent of each other.</p> </div> <div class="section" id="bayes-theorem"> <h2>Bayes' theorem</h2> <p>Starting with equation (1) from above:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/4682360529fab726af4caf0827dbb6edd5f9d902.svg" style="height: 18px;" type="image/svg+xml"> $P(A\cap B)=P(A|B)P(B)=P(B|A)P(A)$</object> <p>And taking the right-hand-side equality and dividing it by <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/d3ecb6c1c04b6ea74af4cacdcb3f1e1bead3b66e.svg" style="height: 18px;" type="image/svg+xml">P(B)</object> (which is positive, per definition), we get Bayes' theorem:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/2e91b38c9f5e26500ede19152eb669f561d9c870.svg" style="height: 42px;" type="image/svg+xml"> $P(A|B)=\frac{P(B|A)P(A)}{P(B)}$</object> <p>This is an extremely useful result, because it links <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/9aa002063e1297678f0283b9e7339a73d8a7f6f6.svg" style="height: 18px;" type="image/svg+xml">P(B|A)</object> with <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/c8937022ac4e55a642f3bd850e6e9b17dd8fc8d3.svg" style="height: 18px;" type="image/svg+xml">P(A|B)</object>. Recall the disease test example, where we're looking for <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/28f810c7c292ba7faa5b47a0b4e0d470f79b19d8.svg" style="height: 18px;" type="image/svg+xml">P(D|T)</object>.
We can use Bayes' theorem:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/7821df244a23990c03bab3a5ad8266797d2b7c4a.svg" style="height: 42px;" type="image/svg+xml"> $P(D|T)=\frac{P(T|D)P(D)}{P(T)}$</object> <p>We know <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/7455e3a5dfb3c45f4a3b110a7906341362a50e53.svg" style="height: 18px;" type="image/svg+xml">P(T|D)</object> and <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/f85da048bad405b6f81778501db99b331b689d46.svg" style="height: 18px;" type="image/svg+xml">P(D)</object>, but what is <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/9d3d6e3b10a97d19adafbea8cc72b8e3619a1d27.svg" style="height: 18px;" type="image/svg+xml">P(T)</object>? You may be tempted to say it's 1 because &quot;well, <em>we know the test is positive</em>&quot; but that would be a mistake. To understand why, we have to dig a bit deeper into the meanings of conditional vs. unconditional probabilities.</p> </div> <div class="section" id="prior-and-posterior-probabilities"> <h2>Prior and posterior probabilities</h2> <p>Fundamentally, conditional probability helps us address the following question:</p> <blockquote> How do we update our beliefs in light of new data?</blockquote> <p><em>Prior</em> probability is our beliefs (probabilities assigned to events) before we see the new data. <em>Posterior</em> probability is our beliefs after we see the new data. In the Bayes equation, prior probabilities are simply the un-conditioned ones, while posterior probabilities are conditional.
This leads to a key distinction:</p> <ul class="simple"> <li><object class="valign-m4" data="https://eli.thegreenplace.net/images/math/7455e3a5dfb3c45f4a3b110a7906341362a50e53.svg" style="height: 18px;" type="image/svg+xml">P(T|D)</object>: posterior probability of the test being positive when we have new data about the person - they have the disease.</li> <li><object class="valign-m4" data="https://eli.thegreenplace.net/images/math/9d3d6e3b10a97d19adafbea8cc72b8e3619a1d27.svg" style="height: 18px;" type="image/svg+xml">P(T)</object>: prior probability of the test being positive before we know anything about the person.</li> </ul> <p>This should make it clearer why we can't just assign <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/00c3bafd4576764728524129480194f205402208.svg" style="height: 18px;" type="image/svg+xml">P(T)=1</object>. Instead, recall the &quot;counting by cases&quot; exercise we did in the first example, where we produced a tree of all possibilities; let's formalize it.</p> </div> <div class="section" id="law-of-total-probability"> <h2>Law of Total Probability</h2> <p>Suppose we have the sample space S and some event B. 
Sometimes it's easier to find the probability of B by first partitioning the space into disjoint pieces:</p> <img alt="Sample space dots visualization for conditional probability" class="align-center" src="https://eli.thegreenplace.net/images/2018/spacepartition.png" /> <p>Then, because the events <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/5aa3f2ac5ea9b6b96e13e2bd945ab77b2cce164a.svg" style="height: 15px;" type="image/svg+xml">A_n</object> are disjoint, we get:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/059517fb7fbb92cfe88968642a0ef5cebcf57f70.svg" style="height: 18px;" type="image/svg+xml"> $P(B)=P(B\cap A_1)+P(B\cap A_2)+P(B\cap A_3)+P(B\cap A_4)$</object> <p>Or, using equation (1):</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/615b1c4ed0b1b0e18f48953ed54e1026aaf8337a.svg" style="height: 18px;" type="image/svg+xml"> $P(B)=P(B|A_1)P(A_1)+P(B|A_2)P(A_2)+P(B|A_3)P(A_3)+P(B|A_4)P(A_4)$</object> </div> <div class="section" id="bayesian-solution-to-the-disease-test-example"> <h2>Bayesian solution to the disease test example</h2> <p>Now we have everything we need to provide a Bayesian solution to the disease test example. Recall that we already know:</p> <ul class="simple"> <li><object class="valign-m4" data="https://eli.thegreenplace.net/images/math/c772c4b7636ec5638c2c0058c93526d105f8b659.svg" style="height: 18px;" type="image/svg+xml">P(T|D)=0.9</object>: test accuracy</li> <li><object class="valign-m4" data="https://eli.thegreenplace.net/images/math/6445317e3d0d9a55e4669fd041aa16e8d203ec32.svg" style="height: 18px;" type="image/svg+xml">P(D)=0.02</object>: disease prevalence in the population</li> </ul> <p>Now we want to compute <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/9d3d6e3b10a97d19adafbea8cc72b8e3619a1d27.svg" style="height: 18px;" type="image/svg+xml">P(T)</object>.
We'll use the law of total probability, with the space partitioning of &quot;has disease&quot; / &quot;does not have disease&quot;:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/5474419ab67a489dc7192dce2940d5f09f25352e.svg" style="height: 21px;" type="image/svg+xml"> $P(T)=P(T|D)P(D)+P(T|D^C)P(D^C)=0.9\ast 0.02+0.1\ast 0.98=0.116$</object> <p>Finally, plugging everything into Bayes' theorem:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/05d2f7a8e457de18576450dd426ee0e3d54e2e28.svg" style="height: 206px;" type="image/svg+xml"> \begin{align*} P(D|T)&amp;=\frac{P(T|D)P(D)}{P(T)}\\ &amp;=\frac{P(T|D)P(D)}{0.116}\\ &amp;=\frac{0.9\ast 0.02}{0.116}=0.155 \end{align*}</object> <p>Which is the same result we got while working through possibilities in the example.</p> </div> <div class="section" id="conditioning-on-multiple-events"> <h2>Conditioning on multiple events</h2> <p>We've just computed <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/28f810c7c292ba7faa5b47a0b4e0d470f79b19d8.svg" style="height: 18px;" type="image/svg+xml">P(D|T)</object> - the conditional probability of event D (patient has disease) on event T (patient tested positive). An important extension of this technique is being able to reason about multiple tests, and how they affect the conditional probability.
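Before tackling multiple tests, the single-test computation above is worth sketching in code (the probabilities are the example's numbers: 90% sensitivity and specificity, 2% prevalence):

```python
# Single-test Bayes computation, with the example's numbers.
p_t_given_d = 0.9      # sensitivity: P(T|D)
p_t_given_not_d = 0.1  # false positive rate: P(T|D^C)
p_d = 0.02             # prevalence: P(D)

# Law of total probability: P(T) = P(T|D)P(D) + P(T|D^C)P(D^C)
p_t = p_t_given_d * p_d + p_t_given_not_d * (1 - p_d)

# Bayes' theorem: P(D|T) = P(T|D)P(D) / P(T)
p_d_given_t = p_t_given_d * p_d / p_t
print(round(p_t, 3), round(p_d_given_t, 3))  # 0.116 0.155
```
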
We'll want to compute <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/6caceb5e79d179c9bc5ee1f30bc94a2cc5a43f09.svg" style="height: 18px;" type="image/svg+xml">P(D|T_1\cap T_2)</object> where <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/2885fa41d340ab94bb0451308cf01996f1916011.svg" style="height: 16px;" type="image/svg+xml">T_1</object> and <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/f725afdfa00dd57660feb233ef8547c9985c924e.svg" style="height: 15px;" type="image/svg+xml">T_2</object> are two events for different tests.</p> <p>Let's assume <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/2885fa41d340ab94bb0451308cf01996f1916011.svg" style="height: 16px;" type="image/svg+xml">T_1</object> is our original test. <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/f725afdfa00dd57660feb233ef8547c9985c924e.svg" style="height: 15px;" type="image/svg+xml">T_2</object> is a slightly different test that's only 80% accurate. Importantly, the tests are <em>independent</em> (they test completely different things) <a class="footnote-reference" href="#id5" id="id2"></a>.</p> <p>We'll start with a naive approach that seems reasonable. For <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/2885fa41d340ab94bb0451308cf01996f1916011.svg" style="height: 16px;" type="image/svg+xml">T_1</object>, we already know that <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/bf10cd55e38ac2b90c6a58763b5c7207e21112ac.svg" style="height: 18px;" type="image/svg+xml">P(D|T_1)=0.155</object>. 
For <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/f725afdfa00dd57660feb233ef8547c9985c924e.svg" style="height: 15px;" type="image/svg+xml">T_2</object>, it's similarly simple to compute:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/0adb7aa21363b77105139ba87a7a6a3b1bb18c1f.svg" style="height: 42px;" type="image/svg+xml"> $P(D|T_2)=\frac{P(T_2|D)P(D)}{P(T_2)}$</object> <p>The disease prevalence is still 2%, and using the law of total probability we get:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/499e223575d2ca5776c3d7493ec72a7034b68ef8.svg" style="height: 21px;" type="image/svg+xml"> $P(T_2)=P(T_2|D)P(D)+P(T_2|D^C)P(D^C)=0.8\ast 0.02+0.2\ast 0.98=0.212$</object> <p>Therefore:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/5d76ab3755acc699caa4dd080dad975b020aeab9.svg" style="height: 42px;" type="image/svg+xml"> $P(D|T_2)=\frac{P(T_2|D)P(D)}{P(T_2)}=\frac{0.8\ast 0.02}{0.212}=0.075$</object> <p>In other words, if a person tests positive with the second test, the chance of being sick is only 7.5%. But what if they tested positive for both tests?</p> <p>Well, since the tests are independent we can do the usual probability trick of combining the complements. We'll compute the probability the person is <em>not</em> sick given positive tests, and then compute the complement of that. <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/34f3f1a1d7cf4f9bb761f5d30e30c73ae74a0d88.svg" style="height: 19px;" type="image/svg+xml">P(D^C|T_1)=1-0.155=0.845</object>, and <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/bc939034f01f15365f691b44910bc5671eb66f17.svg" style="height: 19px;" type="image/svg+xml">P(D^C|T_2)=1-0.075=0.925</object>. 
Therefore:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/4f6a0d5b400338982d6fccd353263ad6bf1129c6.svg" style="height: 21px;" type="image/svg+xml"> $P(D^C|T_1\cap T_2)=P(D^C|T_1)P(D^C|T_2)=0.845\ast 0.925=0.782$</object> <p>And complementing again, we get <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/fc78afd8f0ce0dc8a278906a52710401645f0f55.svg" style="height: 18px;" type="image/svg+xml">P(D|T_1\cap T_2)=1-0.782=0.218</object>. The chance of being sick, having tested positive both times is 21.8%.</p> <p>Unfortunately, this computation is wrong, <em>very</em> wrong. Can you spot why before reading on?</p> <p>We've committed a fairly common blunder in conditional probabilities. Given the independence of <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/078820d4711ac1ab1b075b3b3e452a97424174c8.svg" style="height: 18px;" type="image/svg+xml">P(T_1|D)</object> and <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/8934e9ba84564c2e86c8223fa4edf8ecf142349a.svg" style="height: 18px;" type="image/svg+xml">P(T_2|D)</object>, we've assumed the independence of <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/9e7f4ee884ceb05aa1e7b2dd2453698971ca6689.svg" style="height: 18px;" type="image/svg+xml">P(D|T_1)</object> and <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/7e3df62da0b4a6aaf0fb78afd834a8d9842e463a.svg" style="height: 18px;" type="image/svg+xml">P(D|T_2)</object>, but this is wrong! It's even easy to see why, given our concrete example. Both of them have <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/f85da048bad405b6f81778501db99b331b689d46.svg" style="height: 18px;" type="image/svg+xml">P(D)</object> - the disease prevalence - in the numerator. 
Changing the prevalence will change both <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/9e7f4ee884ceb05aa1e7b2dd2453698971ca6689.svg" style="height: 18px;" type="image/svg+xml">P(D|T_1)</object> and <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/7e3df62da0b4a6aaf0fb78afd834a8d9842e463a.svg" style="height: 18px;" type="image/svg+xml">P(D|T_2)</object> in exactly the same proportion; say, increasing the prevalence 2x will increase both probabilities 2x. They're pretty strongly dependent!</p> <p>The right way of finding <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/6caceb5e79d179c9bc5ee1f30bc94a2cc5a43f09.svg" style="height: 18px;" type="image/svg+xml">P(D|T_1\cap T_2)</object> is working from first principles. <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/7bf9b64bece7c786508608f42ccb59b54c058a2c.svg" style="height: 16px;" type="image/svg+xml">T_1\cap T_2</object> is just another event, so treating it as such and using Bayes theorem we get:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/490d067f11ad71b9824cac46864577eaeaed6e22.svg" style="height: 42px;" type="image/svg+xml"> $P(D|T_1\cap T_2)=\frac{P(T_1\cap T_2|D)P(D)}{P(T_1\cap T_2)}$</object> <p>Here <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/f85da048bad405b6f81778501db99b331b689d46.svg" style="height: 18px;" type="image/svg+xml">P(D)</object> is still 0.02; <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/234e83d1410ef9137bc8dc052b7ea2e038a81337.svg" style="height: 18px;" type="image/svg+xml">P(T_1\cap T_2|D)=0.9\ast0.8=0.72</object>. 
To compute the denominator we'll use the law of total probability again:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/97bb82b17c9ba327a72f472114e7ffe0aa8ed6a3.svg" style="height: 21px;" type="image/svg+xml"> $P(T_1\cap T_2)=P(T_1\cap T_2|D)P(D)+P(T_1\cap T_2|D^C)P(D^C)=0.72\ast 0.02+0.1\ast 0.2\ast 0.98=0.034$</object> <p>Combining them all together we'll get <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/f109539a4f087faa32c6c9bf4006ce9bf5318979.svg" style="height: 18px;" type="image/svg+xml">P(D|T_1\cap T_2)=0.42</object>; the chance of being sick, given two positive tests, is 42%, which is twice as high as our erroneous estimate <a class="footnote-reference" href="#id6" id="id3"></a>.</p> </div> <div class="section" id="bayes-theorem-with-conditioning"> <h2>Bayes' theorem with conditioning</h2> <p>Since conditional probabilities satisfy all probability axioms, many theorems remain true when adding a condition. Here's Bayes' theorem with extra conditioning on event C:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/66add8a0c5c81776826cb9a58662e30e118c71b9.svg" style="height: 42px;" type="image/svg+xml"> $P(A|B\cap C)=\frac{P(B|A\cap C)P(A|C)}{P(B|C)}$</object> <p>In other words, the connection between <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/c8937022ac4e55a642f3bd850e6e9b17dd8fc8d3.svg" style="height: 18px;" type="image/svg+xml">P(A|B)</object> and <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/9aa002063e1297678f0283b9e7339a73d8a7f6f6.svg" style="height: 18px;" type="image/svg+xml">P(B|A)</object> is true even when everything is conditioned on some event C.
To prove it, we can take both sides and expand the definitions of conditional probability until we reach something trivially true:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/28d63e34558ac9a513785738a4eb2c0ce82668c2.svg" style="height: 138px;" type="image/svg+xml"> \begin{align*} P(A|B\cap C)&amp;=\frac{P(B|A\cap C)P(A|C)}{P(B|C)}\\ \frac{P(A\cap B\cap C)}{P(B\cap C)}&amp;=\frac{P(A\cap B\cap C)P(A|C)}{P(A\cap C)P(B|C)}\\ \frac{P(A\cap B\cap C)}{P(B\cap C)}&amp;=\frac{P(A\cap B\cap C)P(A\cap C)}{P(A\cap C)P(B|C)P(C)}\\ \end{align*}</object> <p>Assuming that <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/7784945e846c26d7c35f5a323dfd18c8d359a001.svg" style="height: 18px;" type="image/svg+xml">P(A\cap C)&gt;0</object>, it cancels out (similarly for <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/fd0e822b920e5f9312922a7136ac99a01db7d44e.svg" style="height: 18px;" type="image/svg+xml">P(C)&gt;0</object> in a later step):</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/4d007d49999b1286d02810151a6ca8cf8e945bdf.svg" style="height: 138px;" type="image/svg+xml"> \begin{align*} \frac{P(A\cap B\cap C)}{P(B\cap C)}&amp;=\frac{P(A\cap B\cap C)}{P(B|C)P(C)}\\ \frac{P(A\cap B\cap C)}{P(B\cap C)}&amp;=\frac{P(A\cap B\cap C)P(C)}{P(B\cap C)P(C)}\\ \frac{P(A\cap B\cap C)}{P(B\cap C)}&amp;=\frac{P(A\cap B\cap C)}{P(B\cap C)} \end{align*}</object> <p><em>Q.E.D.</em></p> <p>Using this new result, we can compute our two-test disease exercise in another way. Let's say that <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/2885fa41d340ab94bb0451308cf01996f1916011.svg" style="height: 16px;" type="image/svg+xml">T_1</object> happens first, and we've already computed <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/9e7f4ee884ceb05aa1e7b2dd2453698971ca6689.svg" style="height: 18px;" type="image/svg+xml">P(D|T_1)</object>. 
We can now treat this as the new <em>prior</em> data, and find <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/6caceb5e79d179c9bc5ee1f30bc94a2cc5a43f09.svg" style="height: 18px;" type="image/svg+xml">P(D|T_1\cap T_2)</object> based on the new evidence that <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/f725afdfa00dd57660feb233ef8547c9985c924e.svg" style="height: 15px;" type="image/svg+xml">T_2</object> happened. We'll use the conditioned Bayes formulation with <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/2885fa41d340ab94bb0451308cf01996f1916011.svg" style="height: 16px;" type="image/svg+xml">T_1</object> being C.</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/a06ff1feb6ddc4f774eeab3db36f24bc6f8ebfb1.svg" style="height: 42px;" type="image/svg+xml"> $P(D|T_2\cap T_1)=\frac{P(T_2|D\cap T_1)P(D|T_1)}{P(T_2|T_1)}$</object> <p>We already know that <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/9e7f4ee884ceb05aa1e7b2dd2453698971ca6689.svg" style="height: 18px;" type="image/svg+xml">P(D|T_1)</object> is 0.155; What about <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/55700383bdf6e67d7fc2a7e98d9615982e36e546.svg" style="height: 18px;" type="image/svg+xml">P(T_2|D\cap T_1)</object>? Since the tests are independent, this is actually equivalent to <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/8934e9ba84564c2e86c8223fa4edf8ecf142349a.svg" style="height: 18px;" type="image/svg+xml">P(T_2|D)</object>, which is 0.8. 
The denominator requires a bit more careful computation:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/34ba42765c1ab9f7ffd49867707bd9cc3091f1cd.svg" style="height: 42px;" type="image/svg+xml"> $P(T_2|T_1)=\frac{P(T_1\cap T_2)}{P(T_1)}$</object> <p>We've already found <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/a0674ed54736ae0ea920485197320e0e23abd85a.svg" style="height: 18px;" type="image/svg+xml">P(T_1)=0.116</object> previously, using the law of total probability. Using the same law:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/8347d5e40753c64fbc86668b7b53321fa00bab65.svg" style="height: 21px;" type="image/svg+xml"> $P(T_1\cap T_2)=P(T_1\cap T_2|D)P(D)+P(T_1\cap T_2|D^C)P(D^C)=0.9\ast 0.8\ast 0.02+0.1\ast 0.2\ast 0.98=0.034$</object> <p>Therefore, <object class="valign-m7" data="https://eli.thegreenplace.net/images/math/8853f1766ae90cf8da81a56ab3fcc173baa21878.svg" style="height: 23px;" type="image/svg+xml">P(T_2|T_1)=\frac{0.034}{0.116}=0.293</object> and we now have all the ingredients:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/7bd0854b518fca8ea6c42551abd0b032ce5e1708.svg" style="height: 37px;" type="image/svg+xml"> $P(D|T_2\cap T_1)=\frac{0.8\ast 0.155}{0.293}=0.42$</object> <p>We've reached the same result using two different approaches, which is reassuring. Computing with both tests taken together is a bit quicker, but taking one test at a time is also useful because it lets us <em>update our beliefs</em> over time, given new data.</p> <p>Computing conditional probabilities w.r.t.
multiple parameters is very useful in machine learning - this would be a good topic for a separate article.</p> <hr class="docutils" /> <table class="docutils footnote" frame="void" id="id4" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id1"></a></td><td>The actual number of people is arbitrary - it could be anything else; in the formulae it cancels out anyway. I picked 10,000 because it's a nice number ending with a bunch of zeros and won't produce fractional people for this particular example.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id5" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id2"></a></td><td><p class="first">You may be suspicious of this assumption - how can two tests for the same disease be independent? Being suspicious about probability independence assumptions is a good idea in general, but here the assumption is reasonable.</p> <p class="last">Note that we assume independence given D; in other words, that <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/078820d4711ac1ab1b075b3b3e452a97424174c8.svg" style="height: 18px;" type="image/svg+xml">P(T_1|D)</object> and <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/8934e9ba84564c2e86c8223fa4edf8ecf142349a.svg" style="height: 18px;" type="image/svg+xml">P(T_2|D)</object> are independent. We know the person is sick, and we know that <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/2885fa41d340ab94bb0451308cf01996f1916011.svg" style="height: 16px;" type="image/svg+xml">T_1</object> turned positive - does this affect <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/f725afdfa00dd57660feb233ef8547c9985c924e.svg" style="height: 15px;" type="image/svg+xml">T_2</object>?
That depends on the test; some tests definitely test related things, but some may test unrelated things (say the first looks for a particular by-product of sick cells while the second looks for a gene that is known to be correlated with disease prevalence). It's possible to find plausible connections between almost anything though, so all independence assumptions are &quot;best-effort&quot;.</p> </td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id6" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id3"></a></td><td>My intuition for understanding why it's higher is that there's a tug of war between the test accuracy and the prevalence (the lower the prevalence, the higher the test accuracy has to be to produce reasonable predictive value). But when we recompute with two tests, we still use prevalence just once in the formula, so the two tests combine forces against it.</td></tr> </tbody> </table> </div> Computing remainders by doubling2018-02-12T05:40:00-08:002018-02-12T05:40:00-08:00Eli Benderskytag:eli.thegreenplace.net,2018-02-12:/2018/computing-remainders-by-doubling/<p>I'm going through Stepanov and Rose's <em>From Mathematics to Generic Programming</em>, and on page 48 they present a fast algorithm for computing remainders without using either division or multiplication. Unfortunately, there's not much in terms of proof <a class="footnote-reference" href="#id3" id="id1"></a>, so this post is to document my understanding of the algorithm.</p> <p>The …</p><p>I'm going through Stepanov and Rose's <em>From Mathematics to Generic Programming</em>, and on page 48 they present a fast algorithm for computing remainders without using either division or multiplication.
Unfortunately, there's not much in terms of proof <a class="footnote-reference" href="#id3" id="id1"></a>, so this post is to document my understanding of the algorithm.</p> <p>The algorithm relies on the following lemma: For <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/359f68d4e18136986adfdd925d71beba6906d6f2.svg" style="height: 17px;" type="image/svg+xml">a,b\in\mathbb{N}</object>, given <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/cea4102e00c379c659bd1ac0ec828b107ef8e191.svg" style="height: 18px;" type="image/svg+xml">t=remainder(a,2b)</object>, we have:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/cb7a84e3fdaea31591f54385a69365a074dbc16e.svg" style="height: 43px;" type="image/svg+xml"> $remainder(a,b)=\left\{\begin{matrix} t &amp; t &lt; b\\ t-b &amp; t \geq b \end{matrix}\right.$</object> <p>To prove this, consider the standard quotient-and-remainder representation relating <em>a</em> and <em>b</em>: <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/e467ce1c3e2bed6e3c59d96fb444ec0f18012b2b.svg" style="height: 17px;" type="image/svg+xml">a=qb+r</object>, with <em>q</em> the quotient and <em>r</em> the remainder. <em>q</em> can be either even or odd.
If it's even, we can say that there exists <object class="valign-m1" data="https://eli.thegreenplace.net/images/math/fcd7fa67ec4934ffd698ba002f505cf4cb93cb4f.svg" style="height: 14px;" type="image/svg+xml">k\in\mathbb{N}</object> such that <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/70e471804d82f9d0c2b34ebae8386daaf4b7c163.svg" style="height: 17px;" type="image/svg+xml">q=2k</object>, so:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/7b35830ec6cf002f68272e4c031182d2b69d98e9.svg" style="height: 15px;" type="image/svg+xml"> $a=2kb+r$</object> <p>In this case, the remainder of <em>a</em> divided by <em>2b</em> is trivially <em>r</em> (the same as the remainder of dividing by <em>b</em>). If <em>q</em> is odd, we can say that <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/b832c21a1f7af5db87bdd07c9881f58078a6ca77.svg" style="height: 17px;" type="image/svg+xml">q=2k+1</object>, so:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/d4d371e1aa19719ea652108e2443ce6b4684aa2a.svg" style="height: 18px;" type="image/svg+xml"> $a=(2k+1)b+r=2kb+b+r$</object> <p>In this case, the remainder of <em>a</em> divided by <em>2b</em> is <em>b+r</em>. Now it's obvious why the lemma is true. Without explicitly distinguishing <em>q</em> as even or odd, it just examines the remainder of <em>a</em> divided by <em>2b</em>. If this remainder is smaller than <em>b</em>, then that's also the remainder of dividing by <em>b</em> because <em>q</em> must be even.
On the other hand, if the remainder is larger than <em>b</em>, <em>q</em> must be odd and we have <em>b+r</em> as the remainder, in which case we subtract <em>b</em> to get to <em>r</em>.</p> <p>Now, the algorithm itself, as Python code <a class="footnote-reference" href="#id4" id="id2"></a>:</p> <div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">fast_remainder</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span> <span class="k">if</span> <span class="n">a</span> <span class="o">&lt;</span> <span class="n">b</span><span class="p">:</span> <span class="k">return</span> <span class="n">a</span> <span class="k">if</span> <span class="n">a</span> <span class="o">-</span> <span class="n">b</span> <span class="o">&lt;</span> <span class="n">b</span><span class="p">:</span> <span class="k">return</span> <span class="n">a</span> <span class="o">-</span> <span class="n">b</span> <span class="n">r</span> <span class="o">=</span> <span class="n">fast_remainder</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span> <span class="o">+</span> <span class="n">b</span><span class="p">)</span> <span class="k">if</span> <span class="n">r</span> <span class="o">&lt;</span> <span class="n">b</span><span class="p">:</span> <span class="k">return</span> <span class="n">r</span> <span class="k">return</span> <span class="n">r</span> <span class="o">-</span> <span class="n">b</span> </pre></div> <p>It starts by covering base cases of <em>a</em> being up to <em>2b</em>. Then it recurses to find the remainder of <em>a</em> divided by <em>2b</em>. This is a curious recursive pattern, as the parameters grow rather than shrink! 
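</p>

<p>Before proving termination, here's a quick empirical check that the function indeed computes remainders, comparing it against Python's built-in <tt class="docutils literal">%</tt> operator (the definition from above is repeated so the snippet runs standalone):</p>

```python
# fast_remainder as defined above, checked against Python's % operator.
def fast_remainder(a, b):
    if a < b:
        return a
    if a - b < b:
        return a - b
    # Recurse with a doubled divisor, then apply the lemma to the result.
    r = fast_remainder(a, b + b)
    if r < b:
        return r
    return r - b

for a in range(500):
    for b in range(1, 40):
        assert fast_remainder(a, b) == a % b
print("all remainders match")
```

<p>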
Therefore, it's important to prove that this recursion terminates (if it does, its correctness stems from the lemma).</p> <p>We keep doubling <em>b</em> in every recursive invocation, and the base cases break the recursive cycle once <em>b</em> outgrows <em>a</em>. It will take at most <object class="valign-m5" data="https://eli.thegreenplace.net/images/math/a65d0750d3247b90609ed8fdb790dcf3ac93a463.svg" style="height: 19px;" type="image/svg+xml">\left \lceil log_{2}a\right \rceil</object> steps to reach that point. Therefore, the recursion terminates.</p> <hr class="docutils" /> <table class="docutils footnote" frame="void" id="id3" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id1"></a></td><td>Which is a bit disappointing for a book that was written to show the beauty of math to programmers and is full of proofs for other stuff. For this algorithm the authors just mention &quot;It's not obvious where the work is done, but it works&quot; and then provide a single extended example.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id4" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id2"></a></td><td>This is a slightly adapted version of the algorithm, which also works when <em>a</em> is a multiple of <em>b</em>, such that the remainder is 0.</td></tr> </tbody> </table> Affine transformations2018-01-09T05:15:00-08:002018-01-09T05:15:00-08:00Eli Benderskytag:eli.thegreenplace.net,2018-01-09:/2018/affine-transformations/<p>This is a brief article on affine mappings and their relation to linear mappings, with some applications.</p> <div class="section" id="linear-vs-affine"> <h2>Linear vs. 
Affine</h2> <p>To start discussing affine mappings, we have to first address a common confusion around what it means for a function to be linear.</p> <p>According to <a class="reference external" href="https://en.wikipedia.org/wiki/Linear_function">Wikipedia</a> the term <em>linear function …</em></p></div><p>This is a brief article on affine mappings and their relation to linear mappings, with some applications.</p> <div class="section" id="linear-vs-affine"> <h2>Linear vs. Affine</h2> <p>To start discussing affine mappings, we have to first address a common confusion around what it means for a function to be linear.</p> <p>According to <a class="reference external" href="https://en.wikipedia.org/wiki/Linear_function">Wikipedia</a> the term <em>linear function</em> can refer to two distinct concepts, based on the context:</p> <ol class="arabic simple"> <li>In Calculus, a linear function is a polynomial function of degree zero or one; in other words, a function of the form <img alt="f(x)=ax+b" class="valign-m4" src="https://eli.thegreenplace.net/images/math/a85393d5068f5c4bc36ff7efed535a8f1a686848.png" style="height: 18px;" /> for some constants <tt class="docutils literal">a</tt> and <tt class="docutils literal">b</tt>.</li> <li>In Linear Algebra, a linear function is a linear mapping, or linear <em>transformation</em>.</li> </ol> <p>In this article we're going to be using (2) as the definition of <em>linear</em>, and it will soon become obvious why (1) is confusing when talking about transformations. 
To avoid some of the jumble going forward, I'm going to be using the term <em>mapping</em> instead of <em>function</em>, but in linear algebra the two are interchangeable (<em>transformation</em> is another synonym, which I'm going to be making less effort to avoid since it's not as overloaded <a class="footnote-reference" href="#id8" id="id1"></a>).</p> </div> <div class="section" id="linear-transformations"> <h2>Linear transformations</h2> <p>Since we're talking about linear algebra, let's use the domain of vector spaces for the definitions. A transformation (or mapping) <tt class="docutils literal">f</tt> is linear when for any two vectors <img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" /> and <object class="valign-0" data="https://eli.thegreenplace.net/images/math/d45128696127d3ae74860c6f8b14ce6ca20d15e7.svg" style="height: 13px;" type="image/svg+xml">\vec{w}</object> (assuming the vectors are in the same vector space, say <img alt="\mathbb{R}^2" class="valign-0" src="https://eli.thegreenplace.net/images/math/2b688757b3d0949451e1fa97e71ac5f5f284a5e4.png" style="height: 15px;" />):</p> <ul class="simple"> <li><object class="valign-m4" data="https://eli.thegreenplace.net/images/math/c77fa5b7073e6b81e5b431b6e383a7414858cea0.svg" style="height: 18px;" type="image/svg+xml">f(\vec{v}+\vec{w})=f(\vec{v})+f(\vec{w})</object></li> <li><object class="valign-m4" data="https://eli.thegreenplace.net/images/math/d48c4c3abf0c65851d92030c7f40d799156f5871.svg" style="height: 18px;" type="image/svg+xml">f(k\vec{v})=kf(\vec{v})</object> for any scalar <tt class="docutils literal">k</tt></li> </ul> <p>For example, the mapping <object class="valign-m5" data="https://eli.thegreenplace.net/images/math/6ebc8ee559ec27b734f8f10214bd0a5fd6fc6c54.svg" style="height: 19px;" type="image/svg+xml">f(\vec{v})=\langle 3v_1-4v_2,v_2 \rangle</object> - where <object class="valign-m4"
data="https://eli.thegreenplace.net/images/math/9b12bbf79036cb3e904f971fd86838db1dade1aa.svg" style="height: 12px;" type="image/svg+xml">v_1</object> and <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/2e84f52c0f54659a1f533b25591adb924f2a4131.svg" style="height: 11px;" type="image/svg+xml">v_2</object> are the components of <img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" /> - is linear. The mapping <object class="valign-m5" data="https://eli.thegreenplace.net/images/math/f86fed5746a1646abc0377fbbf9002231177b0fa.svg" style="height: 19px;" type="image/svg+xml">g(\vec{v})=\langle v_2,2v_{1}v_{2} \rangle</object> is <em>not</em> linear.</p> <p>In fact, it can be shown that for the kind of vector spaces we're mostly interested in <a class="footnote-reference" href="#id9" id="id2"></a>, any linear mapping can be represented by a matrix that is multiplied by the input vector. This is because we can represent any vector in terms of the standard basis vectors: <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/dce46a4dda3d1b14590f131161880969b7998cce.svg" style="height: 17px;" type="image/svg+xml">\vec{v}=v_1\vec{e}_1+...+v_n\vec{e}_n</object>. 
Then, since <tt class="docutils literal">f</tt> is linear:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/cac8ee862d974540e4c17cb4f2c4309db0f00193.svg" style="height: 50px;" type="image/svg+xml"> $f(\vec{v})=f(\sum_{i=1}^{n}v_i\vec{e}_i)=\sum_{i=1}^{n}v_if(\vec{e}_i)$</object> <p>If we think of <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/c37ade9231729dad728ad612e88916fc118f8f24.svg" style="height: 18px;" type="image/svg+xml">f(\vec{e}_i)</object> as column vectors, this is precisely the multiplication of a matrix by <img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" />:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/cd0deb0c6714d83a6fca594b4c755d789d34d8a9.svg" style="height: 86px;" type="image/svg+xml"> $f(\vec{v}) = \begin{pmatrix} \mid &amp; \mid &amp; &amp; \mid \\ f(\vec{e}_1) &amp; f(\vec{e}_2) &amp; \cdots &amp; f(\vec{e}_n) \\ \mid &amp; \mid &amp; &amp; \mid \\ \end{pmatrix}\begin{pmatrix} v_1 \\ v_2 \\ ... \\ v_n \end{pmatrix}$</object> <p>This multiplication by a matrix can also be seen as a <em>change of basis</em> for <img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" /> from the standard base to a base defined by <tt class="docutils literal">f</tt>. If you want a refresher on how changes of basis work, take a look at my <a class="reference external" href="http://eli.thegreenplace.net/2015/change-of-basis-in-linear-algebra/">older post on this topic</a>.</p> <p>Let's get back to our earlier example of the mapping <object class="valign-m5" data="https://eli.thegreenplace.net/images/math/6ebc8ee559ec27b734f8f10214bd0a5fd6fc6c54.svg" style="height: 19px;" type="image/svg+xml">f(\vec{v})=\langle 3v_1-4v_2,v_2 \rangle</object>. 
We can represent this mapping with the following matrix:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/f54bea9f647bcffea53eaa9de5831b086cea8987.svg" style="height: 43px;" type="image/svg+xml"> $\begin{pmatrix} 3 &amp; -4 \\ 0 &amp; 1 \end{pmatrix}$</object> <p>Meaning that:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/d5f35dc6342d25f9c4bf13e84986f90a870f3728.svg" style="height: 43px;" type="image/svg+xml"> $f(\vec{v})=\begin{pmatrix} 3 &amp; -4 \\ 0 &amp; 1 \end{pmatrix}\begin{pmatrix} v_1 \\ v_2 \end{pmatrix}$</object> <p>Representing linear mappings this way gives us a number of interesting tools for working with them. For example, the associativity of matrix multiplication means that we can represent compositions of mappings by simply multiplying the mapping matrices together.</p> <p>Consider the following mapping:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/1677a0280bf92fc8725c0a71e5a2705eaabebde8.svg" style="height: 43px;" type="image/svg+xml"> $S=\begin{pmatrix} 2 &amp; 0\\ 0 &amp; 2 \end{pmatrix}$</object> <p>In equational form: <object class="valign-m5" data="https://eli.thegreenplace.net/images/math/3f53317f06530c0cd6af66868f708b09cc719eaa.svg" style="height: 19px;" type="image/svg+xml">S(\vec{v})=\langle 2v_1,2v_2 \rangle</object>. This mapping <em>stretches</em> the input vector 2x in both dimensions. To visualize a mapping, it's useful to examine its effects on some standard vectors. Let's use the vectors <tt class="docutils literal">(0,0)</tt>, <tt class="docutils literal">(0,1)</tt>, <tt class="docutils literal">(1,0)</tt>, <tt class="docutils literal">(1,1)</tt> (the &quot;unit square&quot;). 
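</p>

<p>These matrix representations are easy to experiment with in code. Here's a sketch applying the matrix of <tt class="docutils literal">f</tt> and the stretch <tt class="docutils literal">S</tt> to sample vectors; the <tt class="docutils literal">apply</tt> helper is a hand-rolled 2x2 matrix-by-vector multiplication, and its name is made up for illustration:</p>

```python
# Sketch: apply a 2x2 mapping matrix to a 2-vector by plain multiplication.
def apply(m, v):
    # m is a 2x2 matrix as nested lists; v is a 2-vector (tuple).
    return (m[0][0] * v[0] + m[0][1] * v[1],
            m[1][0] * v[0] + m[1][1] * v[1])

F = [[3, -4], [0, 1]]  # the matrix of f(v) = <3*v1 - 4*v2, v2>
S = [[2, 0], [0, 2]]   # the 2x stretch

print(apply(F, (1, 1)))  # (-1, 1): 3*1 - 4*1 = -1, and v2 stays 1
unit_square = [(0, 0), (0, 1), (1, 0), (1, 1)]
print([apply(S, p) for p in unit_square])  # every point doubled
```

<p>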
In <img alt="\mathbb{R}^2" class="valign-0" src="https://eli.thegreenplace.net/images/math/2b688757b3d0949451e1fa97e71ac5f5f284a5e4.png" style="height: 15px;" /> they represent four points that can be connected together as follows <a class="footnote-reference" href="#id10" id="id3"></a>:</p> <img alt="Unit vectors as points on the plane" class="align-center" src="https://eli.thegreenplace.net/images/2018/points-unit-vectors.png" /> <p>It's easy to see that when transformed with <object class="valign-0" data="https://eli.thegreenplace.net/images/math/02aa629c8b16cd17a44f3a0efec2feed43937642.svg" style="height: 12px;" type="image/svg+xml">S</object>, we'll get:</p> <img alt="Unit vectors transformed with 2x stretch" class="align-center" src="https://eli.thegreenplace.net/images/2018/points-stretch.png" /> <p>It's also well known that rotation (relative to the origin) can be modeled with the following mapping, where <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> is in radians:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/5783a1d0f95db66e5dc6e8499cbb5853d43a2a60.svg" style="height: 43px;" type="image/svg+xml"> $R=\begin{pmatrix} cos\theta &amp; sin\theta \\ -sin\theta &amp; cos\theta \end{pmatrix}$</object> <p>Transforming our unit square with this matrix we get:</p> <img alt="Unit vectors transformed with rotation by one radian" class="align-center" src="https://eli.thegreenplace.net/images/2018/points-rotate.png" /> <p>Finally, let's say we want to combine these transformations. To stretch and then rotate a vector, we would do: <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/40412ae6b5ed1bafeb85baada5ab732975419037.svg" style="height: 18px;" type="image/svg+xml">f(\vec{v})=R(Sv)</object>.
Since matrix multiplication is associative, this can also be rewritten as: <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/fd0a61053256cb3481d5ce29828291712b9b5a8e.svg" style="height: 18px;" type="image/svg+xml">f(\vec{v})=(RS)v</object>. In other words, we can find a matrix <object class="valign-0" data="https://eli.thegreenplace.net/images/math/7b0ecef9a260b7e055cb6c5ab4d53ca3b236a621.svg" style="height: 12px;" type="image/svg+xml">A=RS</object> which represents the combined transformation, and we &quot;find&quot; it by simply multiplying <tt class="docutils literal">R</tt> and <tt class="docutils literal">S</tt> together <a class="footnote-reference" href="#id11" id="id4"></a>:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/7c7b6d5faacf808b4a6551e3bfbee40bdf0dd158.svg" style="height: 43px;" type="image/svg+xml"> $A=\begin{pmatrix} cos\theta &amp; sin\theta \\ -sin\theta &amp; cos\theta \end{pmatrix}\begin{pmatrix} 2 &amp; 0 \\ 0 &amp; 2 \end{pmatrix}=\begin{pmatrix} 2cos\theta &amp; 2sin\theta \\ -2sin\theta &amp; 2cos\theta \end{pmatrix}$</object> <p>And when we multiply our unit square by this matrix we get:</p> <img alt="Unit vectors transformed with rotation and stretch" class="align-center" src="https://eli.thegreenplace.net/images/2018/points-rotate-and-stretch.png" /> </div> <div class="section" id="id5"> <h2>Affine transformations</h2> <p>Now that we have some good context on linear transformations, it's time to get to the main topic of this post - affine transformations.</p> <p>For an affine space (we'll talk about what this is exactly in a later section), every affine transformation is of the form <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/9de6170fb616cd269a8c53f567e94d08b9b813a0.svg" style="height: 18px;" type="image/svg+xml">g(\vec{v})=Av+b</object> where <img alt="A" class="valign-0"
src="https://eli.thegreenplace.net/images/math/6dcd4ce23d88e2ee9568ba546c007c63d9131c1b.png" style="height: 12px;" /> is a matrix representing a linear transformation and <object class="valign-0" data="https://eli.thegreenplace.net/images/math/e9d71f5ee7c92d6dc9e92ffdad17b8bd49418f98.svg" style="height: 13px;" type="image/svg+xml">b</object> is a vector. In other words, an affine transformation combines a linear transformation with a <em>translation</em>.</p> <p>Quite obviously, every linear transformation is affine (just set <object class="valign-0" data="https://eli.thegreenplace.net/images/math/e9d71f5ee7c92d6dc9e92ffdad17b8bd49418f98.svg" style="height: 13px;" type="image/svg+xml">b</object> to the zero vector). However, not every affine transformation is linear. For a non-zero <object class="valign-0" data="https://eli.thegreenplace.net/images/math/e9d71f5ee7c92d6dc9e92ffdad17b8bd49418f98.svg" style="height: 13px;" type="image/svg+xml">b</object>, the linearity rules don't check out. 
Let's say that:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/8aa781f9271cdf38440d17a01899275112fe143a.svg" style="height: 52px;" type="image/svg+xml"> \begin{align*} f(\vec{v})&amp;=A\vec{v}+\vec{b} \\ f(\vec{w})&amp;=A\vec{w}+\vec{b} \end{align*}</object> <p>Then if we try to add these together, we get:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/7550749852d8b7f5ad72f26dd850cc94e16135dd.svg" style="height: 22px;" type="image/svg+xml"> $f(\vec{v}+\vec{w})=A(\vec{v}+\vec{w})+\vec{b}$</object> <p>Whereas:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/594f1da3f413dafcad4090111f4e7886df1b1afe.svg" style="height: 18px;" type="image/svg+xml"> $f(\vec{v})+f(\vec{w})=A\vec{v}+b+A\vec{w}+b=A(\vec{v}+\vec{w})+2b$</object> <p>The violation of the scalar multiplication rule can be checked similarly.</p> <p>Let's examine the affine transformation that stretches a vector by a factor of two (similarly to the <tt class="docutils literal">S</tt> transformation we've discussed before) and translates it by 0.5 for both dimensions:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/b73e9c882baeb6d742f05acc82bf6130520a634c.svg" style="height: 43px;" type="image/svg+xml"> $f(\vec{v})=\begin{pmatrix} 2 &amp; 0 \\ 0 &amp; 2 \end{pmatrix}\vec{v}+\begin{pmatrix} 0.5 \\ 0.5\end{pmatrix}$</object> <p>Here is this transformation visualized:</p> <img alt="Unit vectors translated and stretched" class="align-center" src="https://eli.thegreenplace.net/images/2018/points-translate.png" /> <p>With some clever augmentation, we can represent affine transformations as a multiplication by a single matrix, if we add another dimension to the vectors <a class="footnote-reference" href="#id12" id="id6"></a>:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/5dda35f4a28bb65c45da251170ae5cd17138b8ec.svg" style="height: 65px;" 
type="image/svg+xml"> $f(\vec{v})=T\vec{v}=\begin{pmatrix} 2 &amp; 0 &amp; 0.5 \\ 0 &amp; 2 &amp; 0.5 \\ 0 &amp; 0 &amp; 1 \end{pmatrix} \begin{pmatrix} v_1 \\ v_2 \\ 1 \end{pmatrix}$</object> <p>The translation vector is tacked on the right-hand side of the transform matrix, with a 1 for the extra dimension (the matrix gets 0s in that dimension). The result will always have a 1 in the final dimension, which we can ignore.</p> <p>Affine transforms can be composed similarly to linear transforms, using matrix multiplication. This also makes them associative. As an example, let's compose the scaling+translation transform discussed most recently with the rotation transform mentioned earlier. This is the augmented matrix for the rotation:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/dc7c578ad6e6d7d50aa9ce8ebd8024829d072695.svg" style="height: 65px;" type="image/svg+xml"> $R=\begin{pmatrix} cos\theta &amp; sin\theta &amp; 0 \\ -sin\theta &amp; cos\theta &amp; 0 \\ 0 &amp; 0 &amp; 1 \end{pmatrix}$</object> <p>The composed transform will be <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/0b6f50c9b5f2a9431f500810ae71a40ec939d943.svg" style="height: 18px;" type="image/svg+xml">f(\vec{v})=T(R(\vec{v}))=(TR)\vec{v}</object>. 
Its matrix is:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/1ad583a63663ce317112e10234e7ef065ef97252.svg" style="height: 65px;" type="image/svg+xml"> $TR=\begin{pmatrix} 2 &amp; 0 &amp; 0.5 \\ 0 &amp; 2 &amp; 0.5 \\ 0 &amp; 0 &amp; 1 \end{pmatrix}\begin{pmatrix} cos\theta &amp; sin\theta &amp; 0 \\ -sin\theta &amp; cos\theta &amp; 0 \\ 0 &amp; 0 &amp; 1 \end{pmatrix}=\begin{pmatrix} 2cos\theta &amp; 2sin\theta &amp; 0.5 \\ -2sin\theta &amp; 2cos\theta &amp; 0.5 \\ 0 &amp; 0 &amp; 1 \end{pmatrix}$</object> <p>The visualization is:</p> <img alt="Translation after rotation" class="align-center" src="https://eli.thegreenplace.net/images/2018/points-rotate-translate.png" /> </div> <div class="section" id="affine-subspaces"> <h2>Affine subspaces</h2> <p>The previous section defined affine transformation w.r.t. the concept of <em>affine space</em>, and now it's time to pay the rigor debt. According <a class="reference external" href="https://en.wikipedia.org/wiki/Affine_space">to Wikipedia</a>, an affine space:</p> <blockquote> ... is a geometric structure that generalizes the properties of Euclidean spaces in such a way that these are independent of the concepts of distance and measure of angles, keeping only the properties related to parallelism and ratio of lengths for parallel line segments.</blockquote> <p>Since we've been using vectors and vector spaces so far in the article, let's see the relation between vector spaces and affine spaces. 
The best explanation I found online is the following.</p> <p>Consider the vector space <img alt="\mathbb{R}^2" class="valign-0" src="https://eli.thegreenplace.net/images/math/2b688757b3d0949451e1fa97e71ac5f5f284a5e4.png" style="height: 15px;" />, with two lines:</p> <img alt="Lines for subspace and affine space of R2" class="align-center" src="https://eli.thegreenplace.net/images/2018/subspace-lines.png" /> <p>The blue line can be seen as a vector subspace (also known as <em>linear subspace</em>) of <img alt="\mathbb{R}^2" class="valign-0" src="https://eli.thegreenplace.net/images/math/2b688757b3d0949451e1fa97e71ac5f5f284a5e4.png" style="height: 15px;" />. On the other hand, the green line is not a vector subspace because it doesn't contain the zero vector. The green line is an <em>affine subspace</em>. This leads us to a definition:</p> <blockquote> A subset <object class="valign-m1" data="https://eli.thegreenplace.net/images/math/deec434246ee4364d506b710d495a68faae6cb99.svg" style="height: 13px;" type="image/svg+xml">U \subset V</object> of a vector space <img alt="V" class="valign-0" src="https://eli.thegreenplace.net/images/math/c9ee5681d3c59f7541c27a38b67edf46259e187b.png" style="height: 12px;" /> is an affine space if there exists a <object class="valign-m1" data="https://eli.thegreenplace.net/images/math/66d9cae10caefdd28dcb23fed51b0bb194c40cff.svg" style="height: 13px;" type="image/svg+xml">u \in U</object> such that <object class="valign-m5" data="https://eli.thegreenplace.net/images/math/93f362965ba8f75b9f3cc491918201ef91811888.svg" style="height: 19px;" type="image/svg+xml">U - u = \{x-u \mid x \in U\}</object> is a vector subspace of <img alt="V" class="valign-0" src="https://eli.thegreenplace.net/images/math/c9ee5681d3c59f7541c27a38b67edf46259e187b.png" style="height: 12px;" />.</blockquote> <p>If you recall the definition of affine transformations from earlier on, this should seem familiar - linear and affine subspaces are related by using a 
translation vector. It can also be said that an affine space is a generalization of a linear space, in that it doesn't require a specific origin point. From Wikipedia, again:</p> <blockquote> Any vector space may be considered as an affine space, and this amounts to forgetting the special role played by the zero vector. In this case, the elements of the vector space may be viewed either as points of the affine space or as displacement vectors or translations. When considered as a point, the zero vector is called the origin. Adding a fixed vector to the elements of a linear subspace of a vector space produces an affine subspace. One commonly says that this affine subspace has been obtained by translating (away from the origin) the linear subspace by the translation vector.</blockquote> <p>When mathematicians define new algebraic structures, they don't do it just for fun (well, sometimes they do) but because such structures have some properties which can lead to useful generalizations. Affine spaces and transformations also have interesting properties, which make them useful. For example, an affine transformation always maps a line to a line (and not to, say, a parabola). Any two triangles can be converted one to the other using an affine transform, and so on. This leads to interesting applications in computational geometry and 3D graphics.</p> </div> <div class="section" id="affine-functions-in-linear-regression-and-neural-networks"> <h2>Affine functions in linear regression and neural networks</h2> <p>Here I want to touch upon the linear vs. affine confusion again, in the context of machine learning. 
Recall that <a class="reference external" href="http://eli.thegreenplace.net/2016/linear-regression/">Linear Regression</a> attempts to fit a line onto data in an optimal way, the line being defined as the function:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/0e60e25963ba73aa9e55f1ebb41a3bf2460b7f28.svg" style="height: 18px;" type="image/svg+xml"> $y(x) = mx + b$</object> <p>But as this article explained, <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/370b21bb4fe6d65ddec7d4c585f09a5e49b55652.svg" style="height: 18px;" type="image/svg+xml">y(x)</object> is not actually a linear function; it's an affine function (because of the constant term <object class="valign-0" data="https://eli.thegreenplace.net/images/math/e9d71f5ee7c92d6dc9e92ffdad17b8bd49418f98.svg" style="height: 13px;" type="image/svg+xml">b</object>). Should linear regression be renamed to <em>affine regression</em>? It's probably too late for that :-), but it's good to get the terminology right.</p> <p>Similarly, a single fully connected layer in a neural network is often expressed mathematically as:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/9f626c0ce605723e39bd0dae81451b0cddee09b0.svg" style="height: 22px;" type="image/svg+xml"> $y(\vec{x})=W\vec{x}+\vec{b}$</object> <p>Where <object class="valign-0" data="https://eli.thegreenplace.net/images/math/f8914399eadbd8be3c3196100658870e03c61fee.svg" style="height: 13px;" type="image/svg+xml">\vec{x}</object> is the input vector, <img alt="W" class="valign-0" src="https://eli.thegreenplace.net/images/math/e2415cb7f63df0c9de23362326ad3c37a9adfc96.png" style="height: 12px;" /> is the weight matrix and <object class="valign-0" data="https://eli.thegreenplace.net/images/math/71fa108edb785ca9f729fa3cd5ad18556dd682e4.svg" style="height: 18px;" type="image/svg+xml">\vec{b}</object> is the bias vector.
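</p>

<p>The affine nature of such a layer is easy to demonstrate numerically. The sketch below uses made-up weights and bias, and shows the additivity failure derived earlier: the two results differ by exactly one extra copy of the bias.</p>

```python
# Sketch: a "linear" (actually affine) layer y = Wx + b, with made-up
# numbers, demonstrating that additivity fails when b is non-zero.
def layer(x, W=((1.0, 2.0), (3.0, 4.0)), b=(0.5, 0.5)):
    # Matrix-vector product Wx, plus the bias b, done by hand.
    return tuple(sum(wi * xi for wi, xi in zip(row, x)) + bi
                 for row, bi in zip(W, b))

def add(u, v):
    return tuple(ui + vi for ui, vi in zip(u, v))

v, w = (1.0, 0.0), (0.0, 1.0)
print(layer(add(v, w)))         # f(v + w)
print(add(layer(v), layer(w)))  # f(v) + f(w): larger by exactly b
```

<p>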
This function is also usually referred to as <em>linear</em> although it's actually <em>affine</em>.</p> </div> <div class="section" id="affine-expressions-and-array-accesses"> <h2>Affine expressions and array accesses</h2> <p>Pivoting from algebra to programming, affine functions have a use when discussing one of the most fundamental building blocks of computer science: accessing arrays.</p> <p>Let's start by defining an <em>affine expression</em>:</p> <blockquote> An expression is affine w.r.t. variables <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/3bca803fa0f8dd4ba421a15cbf1a2547ae0285b7.svg" style="height: 12px;" type="image/svg+xml">v_1,v_2,...,v_n</object> if it can be expressed as <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/258cc23dfefcbb9c4cf7ffbe169028181113b5a2.svg" style="height: 15px;" type="image/svg+xml">c_0+c_{1}v_1+...+c_{n}v_n</object> where <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/9fa86460c3375a0934ab62697483f4692cdfb0a2.svg" style="height: 12px;" type="image/svg+xml">c_0,c_1,...,c_n</object> are constants.</blockquote> <p>Affine expressions are interesting because they are often used to index arrays in loops. 
Consider the following loop in C that copies all elements in an MxN matrix &quot;one to the left&quot;:</p> <div class="highlight"><pre><span></span><span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">M</span><span class="p">;</span> <span class="o">++</span><span class="n">i</span><span class="p">)</span> <span class="p">{</span>
  <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">j</span> <span class="o">=</span> <span class="mi">1</span><span class="p">;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">;</span> <span class="o">++</span><span class="n">j</span><span class="p">)</span> <span class="p">{</span>
    <span class="n">arr</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">j</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">arr</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">j</span><span class="p">];</span>
  <span class="p">}</span>
<span class="p">}</span>
</pre></div> <p>Since C's memory layout for multi-dimensional arrays is <a class="reference external" href="http://eli.thegreenplace.net/2015/memory-layout-of-multi-dimensional-arrays">row-major</a>, the statement in the loop assigns a value to <tt class="docutils literal">arr[i*N + j - 1]</tt> at every iteration. <tt class="docutils literal">i*N + j - 1</tt> is an <em>affine expression</em> w.r.t.
variables <tt class="docutils literal">i</tt> and <tt class="docutils literal">j</tt> <a class="footnote-reference" href="#id13" id="id7"></a>.</p> <p>When all expressions in a loop are affine, the loop is amenable to some advanced analyses and optimizations, but this is a topic for another post.</p> <hr class="docutils" /> <table class="docutils footnote" frame="void" id="id8" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id1"></a></td><td>Though it's also not entirely precise. Generally speaking, transformations are more limited than functions. A transformation is defined on a set as a bijection of the set to itself, whereas functions are more general (they can map between different sets, for example).</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id9" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id2"></a></td><td>Finite-dimensional vector spaces with a defined basis.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id10" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id3"></a></td><td>Tossing a bit of rigor aside, we can imagine points and vectors to be isomorphic since both are represented by pairs of numbers on the <img alt="\mathbb{R}^2" class="valign-0" src="https://eli.thegreenplace.net/images/math/2b688757b3d0949451e1fa97e71ac5f5f284a5e4.png" style="height: 15px;" /> plane.
Some resources will mention the <em>Euclidean plane</em> - <object class="valign-0" data="https://eli.thegreenplace.net/images/math/49853b597499c984c2d89848a19153d282da202c.svg" style="height: 15px;" type="image/svg+xml">\mathbb{E}^2</object> when talking about points and lines, but the Euclidean plane can be modeled by a same-dimensional real plane so I'll just be using <img alt="\mathbb{R}^2" class="valign-0" src="https://eli.thegreenplace.net/images/math/2b688757b3d0949451e1fa97e71ac5f5f284a5e4.png" style="height: 15px;" />.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id11" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id4"></a></td><td>I'll admit this result looks fairly obvious. But longer chains of transforms work in exactly the same way, and the fact that we can represent such chains with a single matrix is very useful.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id12" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id6"></a></td><td>This trick has a geometrical explanation: translation in 2D can be modeled as adding a dimension and performing a 3D <em>shear</em> operation, then projecting the resulting object onto a 2D plane again. 
The object will appear shifted.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id13" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id7"></a></td><td>It's actually only affine if <tt class="docutils literal">N</tt> is a compile-time constant or can be proven to be constant throughout the loop.</td></tr> </tbody> </table> </div> Logistic regression2016-11-02T05:45:00-07:002016-11-02T05:45:00-07:00Eli Benderskytag:eli.thegreenplace.net,2016-11-02:/2016/logistic-regression/<p>This article covers logistic regression - arguably the simplest classification model in machine learning; it starts with basic binary classification, and ends up with some techniques for multinomial classification (selecting between multiple possibilities). The final examples using the softmax function can also be viewed as an example of a single-layer fully …</p><p>This article covers logistic regression - arguably the simplest classification model in machine learning; it starts with basic binary classification, and ends up with some techniques for multinomial classification (selecting between multiple possibilities). The final examples using the softmax function can also be viewed as an example of a single-layer fully connected neural network.</p> <p>This article is the theoretical part; in addition, there's quite a bit of accompanying code <a class="reference external" href="https://github.com/eliben/deep-learning-samples/tree/master/logistic-regression">here</a>. 
All the models discussed in the article are implemented from scratch in Python using only Numpy.</p> <div class="section" id="linear-model-for-binary-classification"> <h2>Linear model for binary classification</h2> <p>Using a linear model for binary classification is very similar to <a class="reference external" href="http://eli.thegreenplace.net/2016/linear-regression/">linear regression</a>, except that we expect a binary (yes/no) answer rather than a numeric answer.</p> <p>We want to come up with a parameter vector <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" />, such that for every data vector <strong>x</strong> we can compute <a class="footnote-reference" href="#id10" id="id1"></a>:</p> <img alt="$\hat{y}(x) = \theta_0 x_0 + \theta_1 x_1 + \cdots + \theta_n x_n$" class="align-center" src="https://eli.thegreenplace.net/images/math/ae682f9fda97c28c8e100c87aecad635c7c1d96c.png" style="height: 18px;" /> <p>And then make a binary decision based on the value of <img alt="\hat{y}(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/11533fb1b0218620907f5859e6e22aeb65c12cd8.png" style="height: 18px;" />. A simple way to make a decision is to say &quot;yes&quot; if <img alt="\hat{y}(x)\geq 0" class="valign-m4" src="https://eli.thegreenplace.net/images/math/c30aad52f5af131a89f1a8805e25aa8e354795dc.png" style="height: 18px;" /> and &quot;no&quot; otherwise. Note that this is arbitrary, as we could flip the condition for &quot;yes&quot; and for &quot;no&quot;. 
We could also compare <img alt="\hat{y}(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/11533fb1b0218620907f5859e6e22aeb65c12cd8.png" style="height: 18px;" /> to some value other than zero, and the model would learn equally well <a class="footnote-reference" href="#id12" id="id2"></a>.</p> <p>Let's make this more concrete, also assigning numeric values to &quot;yes&quot; and &quot;no&quot;, which will make some computations simpler later on. For &quot;yes&quot; we'll (again, arbitrarily) select +1, and for &quot;no&quot; we'll go with -1. So, a linear model for binary classification is parameterized by some <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" />, such that:</p> <img alt="$\hat{y}(x) = \theta_0 x_0 + \theta_1 x_1 + \cdots + \theta_n x_n$" class="align-center" src="https://eli.thegreenplace.net/images/math/ae682f9fda97c28c8e100c87aecad635c7c1d96c.png" style="height: 18px;" /> <p>And:</p> <img alt="$class(x)=\left\{\begin{matrix} +1 &amp;amp; \operatorname{if}\ \hat{y}(x)\geq 0\\ -1 &amp;amp; \operatorname{if}\ \hat{y}(x)&amp;lt; 0 \end{matrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/092debeba72a26bd76603bd3ce140fc798e5f692.png" style="height: 43px;" /> <p>It helps to see a graphical example of how this looks in practice. As usual, we'll have to stick to low dimensionality if we want to visualize things, so let's use 2D data points.</p> <p>Since our data is in 2D, we need a 3D <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> (<img alt="\theta_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/ba6201ddbe2fd0bb66e0704ad8b3c6bdb36f37aa.png" style="height: 15px;" /> for the bias).
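In code, this decision rule is just a dot product followed by a sign check. A quick NumPy sketch (the parameters and sample points here are made up for illustration, with the convention that x[0] = 1 is the bias coordinate):

```python
import numpy as np

def classify(theta, x):
    """Classify a data vector x (with x[0] == 1 for the bias) as +1 or -1."""
    y_hat = np.dot(theta, x)
    return 1 if y_hat >= 0 else -1

theta = np.array([1.0, 2.0, -1.0])                  # arbitrary parameters
print(classify(theta, np.array([1.0, 0.2, 0.3])))   # 1 + 0.4 - 0.3 >= 0 -> 1
print(classify(theta, np.array([1.0, -2.0, 0.5])))  # 1 - 4 - 0.5 < 0   -> -1
```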
Let's pick <img alt="\theta=(4,-0.5, -1)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/6cb259a86870d3bd0a5ad2f839d0515bfc70f0d7.png" style="height: 18px;" />. Plotting <img alt="\hat{y}(x)=\theta \cdot x" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e0a45fd444b0526e19a0f22fb3c264b026fb3bcf.png" style="height: 18px;" /> will give us a plane in 3D, but what we're really interested in is just to know whether <img alt="\hat{y}(x) \geq 0" class="valign-m4" src="https://eli.thegreenplace.net/images/math/39ad82f3252b80454caa343952948440827f2961.png" style="height: 18px;" />. So we can draw this plane's intersection with the x/y axis:</p> <img alt="Line for binary classification" class="align-center" src="https://eli.thegreenplace.net/images/2016/binary-classification-line.png" /> <p>We can play with some sample points to see that everything &quot;to the right&quot; of the line gives us <img alt="\hat{y}(x) &amp;gt; 0" class="valign-m4" src="https://eli.thegreenplace.net/images/math/d686dc49d4c08e21f67c22cbb42aab2a1f3d3875.png" style="height: 18px;" />, and everything &quot;to the left&quot; of it gives us <img alt="\hat{y}(x) &amp;lt; 0" class="valign-m4" src="https://eli.thegreenplace.net/images/math/d8a7e77c45cecd8e4ba7c8f7d1f02944e9b55ecf.png" style="height: 18px;" /> <a class="footnote-reference" href="#id13" id="id3"></a>.</p> </div> <div class="section" id="loss-functions-for-binary-classification"> <h2>Loss functions for binary classification</h2> <p>How do we find the right <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> for a classification problem? Similarly to linear regression, we're going to define a &quot;loss function&quot; and then train a classifier by minimizing this loss with gradient descent. 
However, here picking a good loss function is not as simple - it turns out square loss doesn't work very well, as we'll see soon.</p> <p>Let's start by considering the most logical loss function to use for classification - the number of misclassified data samples. This is called the 0/1 loss, and it's the true measure of how well a classifier works. Say we have 1000 samples, our classifier placed 960 of them in the right category, and got the wrong answer for the other 40 samples. So the loss would be 40. A better classifier may get it wrong only 35 times, so its loss would be smaller.</p> <p>It will be helpful to plot loss functions, so let's add another definition we're going to be using a lot here: the <em>margin</em>. For a given sample <strong>x</strong>, and its correct classification <em>y</em>, the margin of classification is <img alt="m=\hat{y}(x)y" class="valign-m4" src="https://eli.thegreenplace.net/images/math/fc8c312b137c8aafaaebd881836e4332cc14e61f.png" style="height: 18px;" />. Recall that <em>y</em> is either +1 or -1, so the margin is either <img alt="\hat{y}(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/11533fb1b0218620907f5859e6e22aeb65c12cd8.png" style="height: 18px;" /> or its negation, depending on the correct answer. Note that the margin is positive when our guess is correct (both <img alt="\hat{y}(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/11533fb1b0218620907f5859e6e22aeb65c12cd8.png" style="height: 18px;" /> and y have the same sign) and negative when our guess is wrong. 
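The margin is a one-liner to compute; here's a small sketch with made-up predictions and labels:

```python
import numpy as np

# Raw linear outputs y_hat and the corresponding correct labels y (+1/-1).
y_hat = np.array([2.3, -0.7, 0.4, -1.5])
y = np.array([1, -1, -1, 1])

margin = y_hat * y   # elementwise: 2.3, 0.7, -0.4, -1.5
# The first two samples were classified correctly (positive margin),
# the last two incorrectly (negative margin).
print(margin > 0)
```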
With this in hand, we define 0/1 loss as:</p> <img alt="$L_{01}(m) = \mathbb{I}(m \leq 0)$" class="align-center" src="https://eli.thegreenplace.net/images/math/e9731883ade0db9b166741b2ff53a8167a8e3ffd.png" style="height: 18px;" /> <p>Where <img alt="\mathbb{I}" class="valign-0" src="https://eli.thegreenplace.net/images/math/3dcdffb11a6b55b62a0c9e29d85dd9120f5945f4.png" style="height: 12px;" /> is an <em>indicator function</em> taking the value 1 when its condition is true and the value 0 otherwise. Here is the plot of <img alt="L_{01}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/3ed6799c7063de4663bdeab8fa126196f41bcd0f.png" style="height: 16px;" /> as a function of margin:</p> <img alt="0/1 loss for binary classification" class="align-center" src="https://eli.thegreenplace.net/images/2016/binary-01-loss.png" /> <p>Unfortunately, the 0/1 loss is fairly hostile to gradient descent optimization, since it's not convex. This is easy to see intuitively. Suppose we have some <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> that gives us a margin of -1.5. The 0/1 loss for this margin is 1, but how can we improve it? Small nudges to <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> will still give us a margin very close to -1.5, which results in exactly the same loss. We don't know which way to nudge <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> since either way we get the same outcome. In other words, there's no slope to follow here.</p> <p>That's not to say all is lost. Some work is being done with optimizing 0/1 losses for classification, but this is a bit outside the mainstream of machine learning. 
Here's an <a class="reference external" href="http://jmlr.org/proceedings/papers/v28/nguyen13a.pdf">interesting paper</a> that discusses some approaches. It's fascinating for computer science geeks since it uses combinatorial search techniques. The rest of this post, however, will use 0/1 loss only as an idealized limit, trying other kinds of loss we can actually run gradient descent with.</p> <p>The first such loss that comes to mind is square loss, the same one we use in linear regression. We'll define the square loss as a function of margin:</p> <img alt="$L_2(m) = (m - 1)^2$" class="align-center" src="https://eli.thegreenplace.net/images/math/ea06356db44999485977e3a7e6ff5e97e617b1bb.png" style="height: 21px;" /> <p>The reason we do this is to get two desired outcomes at important points: at <img alt="m=1" class="valign-m1" src="https://eli.thegreenplace.net/images/math/002d212eace214d48ccf82c7bc33021b1d9cdb91.png" style="height: 13px;" /> we want the loss to be 0, since this is actually the correct classification: we only get <img alt="m=1" class="valign-m1" src="https://eli.thegreenplace.net/images/math/002d212eace214d48ccf82c7bc33021b1d9cdb91.png" style="height: 13px;" /> when either both <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/45f0241f56d9823eb2d24a228d7ffe62c5fdcdc2.svg" style="height: 16px;" type="image/svg+xml">y=1</object> and <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/c5f34fb4e66b84bde15d596cf76efd468983c4d5.svg" style="height: 17px;" type="image/svg+xml">\hat{y}=1</object> or when both <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/ad8ddd3de86ba8af8476af79d20b151a251ec117.svg" style="height: 16px;" type="image/svg+xml">y=-1</object> and <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/4ae2d248963bac1702c9e5e1f1d0769126f0c479.svg" style="height: 17px;" type="image/svg+xml">\hat{y}=-1</object>.</p> <p>Furthermore, to approximate the 0/1 
loss, we want our loss at <img alt="m=0" class="valign-0" src="https://eli.thegreenplace.net/images/math/5e49227d625a223efeaa8d7bc48bb0b87f878bff.png" style="height: 12px;" /> to be 1. Here's a plot of the square loss together with 0/1 loss:</p> <img alt="0/1 loss and square loss for binary classification" class="align-center" src="https://eli.thegreenplace.net/images/2016/binary-01-with-square-loss.png" /> <p>A couple of problems are immediately apparent with the square loss:</p> <ol class="arabic simple"> <li>It penalizes correct classification as well, in case the margin is very positive. This is not something we want! Ideally, we want the loss to be 0 starting with <img alt="m=1" class="valign-m1" src="https://eli.thegreenplace.net/images/math/002d212eace214d48ccf82c7bc33021b1d9cdb91.png" style="height: 13px;" /> and for all subsequent values of <em>m</em>.</li> <li>It very strongly penalizes outliers. One sample that we misclassified badly can shift the training too much.</li> </ol> <p>We could try to fix these problems by using clamping of some sort, but there is another loss function which serves as a much better approximation to 0/1 loss. 
It's called &quot;hinge loss&quot;:</p> <img alt="$L_h(m) = max(0, 1-m)$" class="align-center" src="https://eli.thegreenplace.net/images/math/dd883f12c7f609fe9256e0e6bb4cfdf319d07844.png" style="height: 18px;" /> <p>And its plot, along with the previously shown losses:</p> <img alt="0/1 loss, square loss and hinge loss for binary classification" class="align-center" src="https://eli.thegreenplace.net/images/2016/binary-01-with-square-and-hinge-loss.png" /> <p>Note that the hinge loss also matches 0/1 loss on the two important points: <img alt="m=0" class="valign-0" src="https://eli.thegreenplace.net/images/math/5e49227d625a223efeaa8d7bc48bb0b87f878bff.png" style="height: 12px;" /> and <img alt="m=1" class="valign-m1" src="https://eli.thegreenplace.net/images/math/002d212eace214d48ccf82c7bc33021b1d9cdb91.png" style="height: 13px;" />. It also has some nice properties:</p> <ol class="arabic simple"> <li>It doesn't penalize correct classification after <img alt="m=1" class="valign-m1" src="https://eli.thegreenplace.net/images/math/002d212eace214d48ccf82c7bc33021b1d9cdb91.png" style="height: 13px;" />.</li> <li>It penalizes incorrect classifications, but not as much as square loss.</li> <li>It's convex (at least where it matters - where the loss is nonzero)! If we get <object class="valign-m1" data="https://eli.thegreenplace.net/images/math/5abbd129a48c53a04b0caa6eef4d760329f02149.svg" style="height: 14px;" type="image/svg+xml">m=-1.5</object> we can actually examine the loss in its very close vicinity and find a slope we can use to improve the loss. 
So, unlike 0/1 loss, it's amenable to gradient descent optimization.</li> </ol> <p>There are other loss functions used to train binary classifiers, such as log loss, but I will leave them out of this post.</p> <p>This is a good place to mention that hinge loss leads naturally to <a class="reference external" href="https://en.wikipedia.org/wiki/Support_vector_machine#SVM_and_the_hinge_loss">SVMs</a> (support vector machines), an interesting technique I'll leave for some other time.</p> </div> <div class="section" id="finding-a-classifier-with-gradient-descent"> <h2>Finding a classifier with gradient descent</h2> <p>With a loss function in hand, we can use <a class="reference external" href="http://eli.thegreenplace.net/2016/understanding-gradient-descent/">gradient descent</a> to find a good classifier for some data. The procedure is very similar to what we've been doing for linear regression:</p> <p>Given a loss function, we compute the loss gradient with respect to each <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> and update <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> for the next step:</p> <img alt="$\theta_{j}=\theta_{j}-\eta\frac{\partial L}{\partial \theta_{j}}$" class="align-center" src="https://eli.thegreenplace.net/images/math/561a940034503fe1bb00e86c90ac130cb351d73b.png" style="height: 42px;" /> <p>Where <img alt="\eta" class="valign-m4" src="https://eli.thegreenplace.net/images/math/2899aeb886ad0fa72652bffd5511e452aaf084ab.png" style="height: 12px;" /> is the learning rate.</p> </div> <div class="section" id="computing-gradients-for-our-loss-functions-with-regularization"> <h2>Computing gradients for our loss functions, with regularization</h2> <p>The only remaining part is computing the gradients for the square and hinge loss
functions we've defined. In addition, I'm going to add &quot;<img alt="L_2" class="valign-m3" src="https://eli.thegreenplace.net/images/math/0d2398f5890edff3f40f1686fc3b51528209bf9b.png" style="height: 15px;" /> regularization&quot; to the loss as a means to prevent overfitting for the training data. <a class="reference external" href="https://en.wikipedia.org/wiki/Regularization_(mathematics)">Regularization</a> is an important component of the learning algorithm. <img alt="L_2" class="valign-m3" src="https://eli.thegreenplace.net/images/math/0d2398f5890edff3f40f1686fc3b51528209bf9b.png" style="height: 15px;" /> regularization adds the sum of the squares of all parameters to the loss, and thus &quot;tries&quot; to keep parameters low. This way, we don't end up over-emphasizing one or a group of parameters over the others.</p> <p>Here is square loss with regularization <a class="footnote-reference" href="#id14" id="id4"></a>:</p> <img alt="$L_2=\frac{1}{k}\sum_{i=1}^{k}(m^{(i)}-1)^2+\frac{\beta}{2}\sum_{j=0}^{n}\theta_{j}^2$" class="align-center" src="https://eli.thegreenplace.net/images/math/a9735ff6606b3ad3454c3dfefc541c21b926d541.png" style="height: 56px;" /> <p>This is assuming we have <em>k</em> data points (<em>n+1</em> dimensional) and <em>n+1</em> parameters (including the special 0th parameter representing the bias). The total loss is the square loss averaged over all data points, plus the regularization loss. <img alt="\beta" class="valign-m4" src="https://eli.thegreenplace.net/images/math/6499d503bfc00cadae1440b191c52a8632e2f8c4.png" style="height: 16px;" /> is the regularization &quot;strength&quot; (another hyper-parameter in the learning algorithm).</p> <p>Let's start by computing the derivative of the margin. 
Using superscripts for indexing data items, recall that:</p> <img alt="$m^{(i)}=\hat{y}^{(i)}y^{(i)}=(\theta_0 x_0^{(i)}+\cdots + \theta_n x_n^{(i)})y^{(i)}$" class="align-center" src="https://eli.thegreenplace.net/images/math/bce48f26ac61cbfd37c8bfbaad0004e5c30ccbbc.png" style="height: 26px;" /> <p>Therefore:</p> <img alt="$\frac{\partial m^{(i)}}{\partial \theta_j}=x_j^{(i)}y^{(i)}$" class="align-center" src="https://eli.thegreenplace.net/images/math/fd79e2321a3ee607dbf3840535d1a8a2327e2117.png" style="height: 47px;" /> <p>With this in hand, it's easy to compute the gradient of <img alt="L_2" class="valign-m3" src="https://eli.thegreenplace.net/images/math/0d2398f5890edff3f40f1686fc3b51528209bf9b.png" style="height: 15px;" /> loss.</p> <img alt="$\frac{\partial L_2}{\partial \theta_j}=\frac{2}{k}\sum_{i=1}^{k}(m^{(i)}-1)x_{j}^{(i)}y^{(i)}+\beta\theta_j$" class="align-center" src="https://eli.thegreenplace.net/images/math/2340ff828a85ab17aa5067b4985cf9da4fd5fae7.png" style="height: 54px;" /> <p>Now let's turn to hinge loss. The total loss for the data set with regularization is:</p> <img alt="$L_h=\frac{1}{k}\sum_{i=1}^{k}max(0, 1-m^{(i)})+\frac{\beta}{2}\sum_{j=0}^{n}\theta_{j}^2$" class="align-center" src="https://eli.thegreenplace.net/images/math/2ce4a6debf2650ea4c8a1ff24ce8e42f3d370a6e.png" style="height: 56px;" /> <p>The tricky part here is finding the derivative of the <img alt="max" class="valign-0" src="https://eli.thegreenplace.net/images/math/0706025b2bbcec1ed8d64822f4eccd96314938d0.png" style="height: 8px;" /> function with respect to <img alt="\theta_j" class="valign-m6" src="https://eli.thegreenplace.net/images/math/56adcea6f10a3cd4a439536412c7fb690f803bc9.png" style="height: 18px;" />. 
I find it easier to reason about functions like <img alt="max" class="valign-0" src="https://eli.thegreenplace.net/images/math/0706025b2bbcec1ed8d64822f4eccd96314938d0.png" style="height: 8px;" /> when the different cases are cleanly separated:</p> <img alt="$max(0,1-m^{(i)})=\left\{\begin{matrix} 1-m^{(i)} &amp;amp; \operatorname{if}\ m^{(i)}&amp;lt; 1\\ 0 &amp;amp; \operatorname{if}\ m^{(i)}\geq 1 \end{matrix}\right.$" class="align-center" src="https://eli.thegreenplace.net/images/math/884d533e1ff8dd51ae43a229bc2f86bc72e82c2a.png" style="height: 46px;" /> <p>We already know the derivative of <img alt="m^{(i)}" class="valign-0" src="https://eli.thegreenplace.net/images/math/0971cbdfca7ab3d5c094d8a8e75c77ccf66e4715.png" style="height: 17px;" /> with respect to <img alt="\theta_j" class="valign-m6" src="https://eli.thegreenplace.net/images/math/56adcea6f10a3cd4a439536412c7fb690f803bc9.png" style="height: 18px;" />. So it's easy to derive this expression case-by-case:</p> <img alt="$\frac{\partial max(0,1-m^{(i)})}{\partial \theta_j}=\left\{\begin{matrix} -x_j^{(i)}y^{(i)} &amp;amp; \operatorname{if}\ m^{(i)}&amp;lt; 1\\ 0 &amp;amp; \operatorname{if}\ m^{(i)}\geq 1 \end{matrix}\right.$" class="align-center" src="https://eli.thegreenplace.net/images/math/4feb3f18ab008352c513de8508c4e8f877510167.png" style="height: 54px;" /> <p>And the overall gradient of the hinge loss is:</p> <img alt="$\frac{\partial L_h}{\partial \theta_j}=\frac{1}{k}\sum_{i=1}^{k}\frac{\partial max(0,1-m^{(i)})}{\partial \theta_j}+\beta\theta_j$" class="align-center" src="https://eli.thegreenplace.net/images/math/d3113e543be93630457f9501379fe0b6956d9342.png" style="height: 54px;" /> </div> <div class="section" id="experiments-with-synthetic-data"> <h2>Experiments with synthetic data</h2> <p>Let's see an example of learning a binary classifier in action.
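Before turning to the real code, here's a compact, self-contained sketch of this training loop - gradient descent with the hinge loss gradient derived above - on a tiny made-up data set (the data and hyper-parameters are illustrative only, not from the linked sample):

```python
import numpy as np

def train_hinge(X, y, eta=0.1, beta=0.0, steps=2000):
    """Gradient descent on hinge loss with L2 regularization strength beta.

    Rows of X are data items with a leading 1 for the bias; y holds +1/-1.
    """
    k, n = X.shape
    theta = np.zeros(n)
    for _ in range(steps):
        m = (X @ theta) * y                    # margin of every data item
        active = (m < 1).astype(float)         # items where hinge is nonzero
        # d max(0, 1-m_i)/d theta_j is -x_j*y_i for active items, 0 otherwise
        grad = -(X * (y * active)[:, None]).sum(axis=0) / k + beta * theta
        theta -= eta * grad
    return theta

# Tiny separable data set: the label is +1 when x1 + x2 > 1.
X = np.array([[1, 0.2, 0.1], [1, 0.9, 0.8], [1, 0.1, 0.4], [1, 0.7, 0.9]])
y = np.array([-1, 1, -1, 1])
theta = train_hinge(X, y)
print(np.sign(X @ theta) == y)   # all items should end up classified correctly
```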
<a class="reference external" href="https://github.com/eliben/deep-learning-samples/blob/master/logistic-regression/simple_binary_classifier.py">This code sample</a> generates some synthetic data in two dimensions and then uses the approach described so far in the post to train a binary classifier. Here's a sample data set:</p> <img alt="Synthetic data for binary classification" class="align-center" src="https://eli.thegreenplace.net/images/2016/synthetic-data.png" /> <p>The data points for which the correct answer is positive (<em>y=1</em>) are the green crosses; the ones for which the correct answer is negative (<em>y=-1</em>) are the red dots. Note that I include a small number of negative outliers (red dots where we'd expect only green crosses to be) to test the classifier on realistic, imperfect data.</p> <p>The sample code can use combinatorial search to find a &quot;best&quot; set of parameters that results in the lowest 0/1 loss - the lowest number of misclassified data items. Note that misclassifying some items in this data set is inevitable (with a linear classifier), because of the outliers. Here is the contour line showing how the classification decision is made with parameters found by doing the combinatorial search:</p> <img alt="Synthetic data for binary classification with only 0/1 loss" class="align-center" src="https://eli.thegreenplace.net/images/2016/synthetic-data-only-01-loss.png" /> <p>The 0/1 loss - number of misclassified data items - for this set of parameters is 20 out of 400 data items (95% correct prediction rate).</p> <p>Next, the code trains a classifier using square loss, and another using hinge loss. 
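Counting misclassified items - the 0/1 loss used as the yardstick above - takes only a couple of lines. A sketch with hypothetical data, reusing the earlier example's parameters:

```python
import numpy as np

def zero_one_loss(theta, X, y):
    """Number of misclassified items; rows of X carry a leading 1 for bias."""
    predictions = np.where(X @ theta >= 0, 1, -1)
    return int((predictions != y).sum())

# Hypothetical 2D data (plus the bias column of ones) and labels.
X = np.array([[1, 0.5, 2.0], [1, 3.0, 1.0], [1, 6.0, 0.5], [1, 7.0, 4.0]])
y = np.array([1, 1, -1, -1])
theta = np.array([4.0, -0.5, -1.0])
print(zero_one_loss(theta, X, y))   # 1: only the third item is misclassified
```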
I'm not using regularization for this data set, since with only 3 parameters there can't be too much selective bias between them; in other words, <img alt="\beta=0" class="valign-m4" src="https://eli.thegreenplace.net/images/math/3bb1ac87ba8d8d0c95fd43b91640c0b96f8e72d9.png" style="height: 16px;" />.</p> <p>A classifier trained with square loss misclassifies 32 items (92% success rate). A classifier trained with hinge loss misclassifies 26 items (93.5% success rate, much closer to the &quot;perfect&quot; rate). This is to be expected from the earlier discussion - square loss very strongly penalizes outliers, which makes it more skewed on this data <a class="footnote-reference" href="#id15" id="id5"></a>. Here are the contour plots for all losses that demonstrate this graphically:</p> <img alt="Synthetic data for binary classification with all losses" class="align-center" src="https://eli.thegreenplace.net/images/2016/synthetic-data-all-losses.png" /> </div> <div class="section" id="binary-classification-of-mnist-digits"> <h2>Binary classification of MNIST digits</h2> <p>The <a class="reference external" href="https://en.wikipedia.org/wiki/MNIST_database">MNIST dataset</a> is the &quot;hello world&quot; of machine learning these days. It's a database of grayscale images representing handwritten digits, with a correct label for each of these images.</p> <p>MNIST is usually employed for the more general multinomial classification problem - classifying a given data item into one of multiple classes (0 to 9 in the case of MNIST). We'll address this in a later section.</p> <p>Here, however, we can experiment with training a binary classifier on MNIST. The idea is to train a classifier that recognizes some single label. For example, a classifier answering the question &quot;is this an image of the digit 4&quot;. 
This is a binary classification problem, since there are only two answers - &quot;yes&quot; and &quot;no&quot;.</p> <p><a class="reference external" href="https://github.com/eliben/deep-learning-samples/blob/master/logistic-regression/mnist_binary_classifier.py">Here's a code sample</a> that trains such a classifier, using the hinge loss function (since we've already determined it gives better results than square loss for classification problems).</p> <p>It starts by converting the correct labels of MNIST from the numeric range 0-9 to +1 or -1 based on whether the label is 4:</p> <div class="highlight"><pre><span></span>     0        -1
     1        -1
     4         1
     9        -1
y =  3  ==&gt;  -1
     8        -1
     5        -1
    ...
     4         1
</pre></div> <p>Then all we have is a binary classification problem, albeit one that is 785-dimensional (784 dimensions for each of the 28x28 pixels in the input images, plus one for bias). Visualizing the separating contours would be quite challenging here, but we can now trust the math to know what's going on. Other than this, the code for gradient descent is <em>exactly the same</em> as for the simple 2D synthetic data shown earlier.</p> <p>My goal here is not to design a state-of-the-art machine learning architecture, but to explain how the main parts work. So I didn't tune the model too much, but it's possible to get 98% accuracy on this binary formulation of MNIST by tuning the code a bit. While 98% sounds great, recall that we could get 90% just by saying &quot;no&quot; to every digit :-) Feel free to play with the code to see if you can get even higher numbers; I don't really expect record-beating numbers from this model, though, since it's so simple.</p>
For example, &quot;what is the chance of rain tomorrow?&quot; rather than &quot;will there be rain, yes or no?&quot;. A probability carries extra information: &quot;90% chance of rain&quot; and &quot;56% chance of rain&quot; both map to the same binary &quot;yes&quot; (assuming a 50% cutoff), yet they tell us very different things.</p> <p>Moreover, note that the linear model we've trained actually provides more information already, giving a numerical answer. We choose to cut it off at 0, saying yes for positive and no for negative numbers. But some numbers are more positive (or negative) than others!</p> <p>Quick thought experiment: can we somehow interpret the response before the cutoff as a probability? The main problem here is that probabilities must be in the range [0, 1], while the linear model gives us an arbitrary real number. We may end up with negative probabilities or probabilities over 1, neither of which makes much sense. So we'll want to find some mathematical way to &quot;squish&quot; the result into the valid [0, 1] range.
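As a quick numeric check that such a &quot;squishing&quot; function exists, here's a small sketch applying one common choice (defined formally in a moment) to some made-up model outputs:

```python
import math

def sigmoid(z):
    # Maps any real number into the open interval (0, 1).
    return 1.0 / (1.0 + math.exp(-z))

# Raw linear-model outputs: arbitrary reals, useless as probabilities...
raw_outputs = [-7.2, -0.5, 0.0, 3.1, 8.0]

# ...but after squashing, every value lands strictly inside (0, 1),
# and the ordering of the outputs is preserved.
probs = [sigmoid(z) for z in raw_outputs]
```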
A common way to do this is to use the logistic function:</p> <img alt="$S(z) = \frac{1}{1 + e^{-z}}$" class="align-center" src="https://eli.thegreenplace.net/images/math/62429be191903e2433ba80f92aaf1044568b831d.png" style="height: 38px;" /> <p>It's also known as the &quot;sigmoid&quot; function because of its S-like shape:</p> <img alt="Sigmoid function" class="align-center" src="https://eli.thegreenplace.net/images/2016/sigmoid.png" /> <p>We're going to assign <img alt="\hat{y}(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/11533fb1b0218620907f5859e6e22aeb65c12cd8.png" style="height: 18px;" /> into the <em>z</em> variable of the sigmoid, to get the function:</p> <img alt="$S(x) = \frac{1}{1 + e^{-(\theta_0 x_0 + \theta_1 x_1 + \cdots + \theta_n x_n)}}$" class="align-center" src="https://eli.thegreenplace.net/images/math/2b9f6770ff23ed08c38a9ab5c3b5972f5d002ddb.png" style="height: 39px;" /> <p>And now, the answer we get can be interpreted as a probability between 0 and 1 (without actually touching either asymptote) <a class="footnote-reference" href="#id16" id="id6"></a>. We can train a model to get as close to 1 as possible for training samples where the true answer is &quot;yes&quot; and as close to 0 as possible for training samples where the true answer is &quot;no&quot;. This is called &quot;logistic regression&quot; due to the use of the logistic function.</p> </div> <div class="section" id="training-logistic-regression-with-the-cross-entropy-loss"> <h2>Training logistic regression with the cross-entropy loss</h2> <p>Earlier in this post, we've seen how a number of loss functions fare for the binary classifier problem. It turns out that for logistic regression, a very natural loss function exists that's called <a class="reference external" href="https://en.wikipedia.org/wiki/Cross_entropy#Cross-entropy_error_function_and_logistic_regression">cross-entropy</a> (also sometimes &quot;logistic loss&quot; or &quot;log loss&quot;). 
This loss function is derived from probability and information theory, and its derivation is outside the scope of this post (check out <a class="reference external" href="http://neuralnetworksanddeeplearning.com/chap3.html">Chapter 3 of Michael Nielsen's online book</a> for a nice intuitive explanation for why this loss function makes sense).</p> <p>The formulation of cross-entropy we're going to use here starts from the most general:</p> <img alt="$C(x^{(i)})=-\sum_{t} p^{(i)}_t log(p(y^{(i)}=t|\theta))$" class="align-center" src="https://eli.thegreenplace.net/images/math/a689c6537836933fae93c80a71cd52ff88703a78.png" style="height: 41px;" /> <p>Let's unravel this definition, step by step. The parenthesized superscript <img alt="x^{(i)}" class="valign-0" src="https://eli.thegreenplace.net/images/math/233014006c0adbee71ec71ba3a70f22ad1b906a1.png" style="height: 17px;" /> denotes, as usual, the <em>ith</em> input sample. <em>t</em> runs over all the possible outcomes; <img alt="p_t" class="valign-m4" src="https://eli.thegreenplace.net/images/math/aaf082725869f54161f39f7d9c39fff25c52ac94.png" style="height: 12px;" /> is the actual probability of outcome <em>t</em> and inside the <em>log</em> we have the conditional probability of this outcome given the regression parameters - in other words, this is the model's prediction <a class="footnote-reference" href="#id17" id="id7"></a>.</p> <p>To make this more concrete, in our case we have two possible outcomes in the training data: either <img alt="y^{(i)}=+1" class="valign-m4" src="https://eli.thegreenplace.net/images/math/3e3495884df85359610f062a6a6428fba7891bb8.png" style="height: 21px;" /> or <img alt="y^{(i)}=-1" class="valign-m4" src="https://eli.thegreenplace.net/images/math/8465f16030efd8eab0982f1e60b8ff292317cdbe.png" style="height: 21px;" />. Given any such outcome, its &quot;actual&quot; probability is either 1 (when we get this outcome in the training data) or 0 (when we don't). 
So for any given sample, one of the two possible values of <em>t</em> has <img alt="p^{(i)}_t=0" class="valign-m5" src="https://eli.thegreenplace.net/images/math/eedcbf364060646a9b6abfccb8e9dda67a645ff0.png" style="height: 25px;" /> and the other has <img alt="p^{(i)}=1" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e44b2858aeb1c845d09a851cbea5fdc9c465199e.png" style="height: 21px;" />. Therefore, we get <a class="footnote-reference" href="#id18" id="id8"></a>:</p> <img alt="$C(x^{(i)})=\left\{ \begin{matrix} -log(S(x^{(i)}) &amp;amp; \operatorname{if}\ y^{(i)}=+1 \\ -log(1-S(x^{(i)})) &amp;amp; \operatorname{if}\ y^{(i)}=-1 \end{matrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/97e3fd44d870673c7a74047b82e30c993a9bec59.png" style="height: 46px;" /> <p>The second possibility has <img alt="-log(1-S(x^{(i)}))" class="valign-m4" src="https://eli.thegreenplace.net/images/math/0b34c17378147a8a82db655998c07649ca71ed39.png" style="height: 21px;" /> because we define <img alt="S(z)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/61bc9efb9d2c99669df519617ee7daee7670e156.png" style="height: 18px;" /> to predict the probability of the answer being +1; therefore, the probability of the answer being -1 is <img alt="1-S(z)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/d006e787dd01f802c9c5cb570e39a44cb133b2ce.png" style="height: 18px;" />.</p> <p>This is the cross-entropy loss for a single sample <img alt="x^{(i)}" class="valign-0" src="https://eli.thegreenplace.net/images/math/233014006c0adbee71ec71ba3a70f22ad1b906a1.png" style="height: 17px;" />. 
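The piecewise per-sample loss translates directly into code. A minimal sketch, where <tt class="docutils literal">s</tt> stands for S(x^(i)), the model's predicted probability of the +1 outcome:

```python
import math

def cross_entropy_loss(s, y):
    # s: predicted probability that the answer is +1, strictly in (0, 1).
    # y: the true label, +1 or -1.
    if y == 1:
        return -math.log(s)        # low cost when s is close to 1
    else:
        return -math.log(1.0 - s)  # low cost when s is close to 0
```

A confident correct prediction (s = 0.99, y = +1) costs about 0.01, while the same prediction with y = -1 costs about 4.6; the loss grows without bound as the model gets more certain and more wrong.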
To get the total loss over a data set, we take the average sample loss, as usual:</p> <img alt="$C = \frac{1}{k}\sum_{i=1}^{k} C(x^{(i)})$" class="align-center" src="https://eli.thegreenplace.net/images/math/642351dc03ee1f11eca503f558971282d5c700e7.png" style="height: 54px;" /> <p>Now let's compute the gradient of this loss function, so we can use it to train a model. Starting with the +1 case, we have:</p> <img alt="$C_{+1} = -log(S(x^{(i)}))$" class="align-center" src="https://eli.thegreenplace.net/images/math/986418ed0bf4c05742c9a412a0918ed00108d93d.png" style="height: 23px;" /> <p>Then:</p> <img alt="$\frac{\partial C_{+1}}{\partial \theta_j} = \frac{-1}{S(x^{(i)})}\frac{\partial S(x^{(i)})}{\partial \theta_j}$" class="align-center" src="https://eli.thegreenplace.net/images/math/c91fbcf0bb4630112d1efa5adbb8756c25512c68.png" style="height: 47px;" /> <p>Here it will be helpful to use the following identity, which can be easily verified by going through the math <a class="footnote-reference" href="#id19" id="id9"></a>:</p> <img alt="$S&amp;#x27;(z)=S(z)(1-S(z))$" class="align-center" src="https://eli.thegreenplace.net/images/math/3d880e07d60096518b916e877cd6a8496c39bc37.png" style="height: 20px;" /> <p>Since in our case <img alt="S(x^{(i)})" class="valign-m4" src="https://eli.thegreenplace.net/images/math/8a85ab5b49ac41fe751ac8b29e2f2e76f34650bb.png" style="height: 21px;" /> is actually <img alt="S(\hat{y}(x^{(i})))" class="valign-m4" src="https://eli.thegreenplace.net/images/math/915144b3b0a3b41ff5d71f88e798c702386cfea8.png" style="height: 21px;" /> where <img alt="\hat{y}(x) = \theta_0 x_0 + \theta_1 x_1 + \cdots + \theta_n x_n" class="valign-m4" src="https://eli.thegreenplace.net/images/math/7ad144258d3d91e1ada8fd7f94a7d0b0538faa2d.png" style="height: 18px;" />, we can apply the chain rule:</p> <img alt="$\frac{\partial S(x^{(i)})}{\partial \theta_j}=S(x^{(i)})(1-S(x^{(i)}))x^{(i)}_j$" class="align-center" 
src="https://eli.thegreenplace.net/images/math/6bb3f809b570699a74428c137ea715d97b08b58d.png" style="height: 47px;" /> <p>Substituting back into <img alt="\frac{\partial C_{+1}}{\partial \theta_j}" class="valign-m10" src="https://eli.thegreenplace.net/images/math/cc23ed0ff22b532e2ab3fec04117c8c968318629.png" style="height: 29px;" />, we get:</p> <img alt="\begin{align*} \frac{\partial C_{+1}}{\partial \theta_j} &amp;amp;= \frac{-1}{S(x^{(i)})}S(x^{(i)})(1-S(x^{(i)}))x^{(i)}_j \\ &amp;amp;= (S(x^{(i)})-1)x^{(i)}_j \end{align*}" class="align-center" src="https://eli.thegreenplace.net/images/math/dc2a24be9f61f7066fbaeb48805bb59c51e445c0.png" style="height: 76px;" /> <p>Similarly, for <img alt="C_{-1}=-log(1-S(x^{(i)}))" class="valign-m4" src="https://eli.thegreenplace.net/images/math/91460ff9e118d56fbbce2f0557bb9208a4d438a4.png" style="height: 21px;" /> we can compute:</p> <img alt="$\frac{\partial C_{-1}}{\partial \theta_j} = S(x^{(i)})x^{(i)}_j$" class="align-center" src="https://eli.thegreenplace.net/images/math/5e5fd85290c89289aacb1486d9f706bd9fca8fdc.png" style="height: 42px;" /> <p>Putting it all together, we find that the contribution of <img alt="x^{(i)}" class="valign-0" src="https://eli.thegreenplace.net/images/math/233014006c0adbee71ec71ba3a70f22ad1b906a1.png" style="height: 17px;" /> to the gradient of <img alt="\theta_j" class="valign-m6" src="https://eli.thegreenplace.net/images/math/56adcea6f10a3cd4a439536412c7fb690f803bc9.png" style="height: 18px;" /> is:</p> <img alt="$\frac{\partial C(x^{(i)})}{\partial \theta_j}=\left\{ \begin{matrix} (S(x^{(i)})-1)x^{(i)}_j &amp;amp; \operatorname{if}\ y^{(i)}=+1 \\ S(x^{(i)})x^{(i)}_j &amp;amp; \operatorname{if}\ y^{(i)}=-1 \end{matrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/21cf7ba3c128242b99272a3e47b5ab5c09cb24bf.png" style="height: 56px;" /> <p>Using these formulae, we can train a binary logistic classifier for MNIST that gives us a probability of some input image being a 4, rather 
than a yes/no answer. The <a class="reference external" href="https://github.com/eliben/deep-learning-samples/blob/master/logistic-regression/mnist_binary_classifier.py">binary MNIST code sample</a> trains either a binary or a logistic classifier using a lot of shared infrastructure.</p> <p>The probability gives us more information than just a yes/no answer. Consider, for example, the following image from the MNIST database. When I trained a binary classifier with hinge loss to recognize the digit 4 for 1200 steps, it wrongly predicted that this image (actually a 9) is a 4:</p> <img alt="Image of a 9 from MNIST" class="align-center" src="https://eli.thegreenplace.net/images/2016/mnist-test-9740.png" /> <p>The model clearly made a mistake here, but can we know <em>how</em> wrong it was? It would be hard to know with a binary classifier that only gives us a yes/no answer. However, when I run a logistic regression model on the same image, it tells me it is 53% confident this is a 4. Since our cutoff for yes/no is 50%, this is quite close to the threshold, so I'd say the model didn't make a huge mistake here.
There are many ways to do this; here I'll focus on two: one-vs-all classification and softmax.</p> </div> <div class="section" id="one-vs-all-classification"> <h2>One-vs-all classification</h2> <p>The One-vs-all (OvA), also known as one-vs-rest (OvR) approach is a natural extension of binary classification:</p> <ol class="arabic simple"> <li>For each class <img alt="t\in[0..T-1]" class="valign-m5" src="https://eli.thegreenplace.net/images/math/6eda0dcb5f9805e0e0e4c3d0af82aacdf1295efd.png" style="height: 18px;" /> we train a logistic classifier where we set <em>t</em> as the &quot;correct&quot; answer, and the other classes as the &quot;incorrect&quot; answers (+1 and -1 respectively).</li> <li>The result of each such classifier is the probability that an input sample belongs to class <em>t</em>.</li> <li>Given a new input, we run all <em>T</em> classifiers on it and the one that gives us the highest probability is chosen as the true class of the input.</li> </ol> <p>As a completely synthetic example to make this clearer, suppose that <em>T=3</em>. We take the training data and train 3 logistic regressions. In the first - <img alt="C_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/33e4cbb170d6026eb67de894c0d01e8702fb065d.png" style="height: 15px;" />, we set 0 as the right answer, 1 and 2 as the wrong answers. In the second - <img alt="C_1" class="valign-m4" src="https://eli.thegreenplace.net/images/math/c538a6221da718dd38230dcbb6e1a8fb40561f7a.png" style="height: 16px;" /> we set 1 as the right answer, 0 and 2 as the wrong answers. 
Finally in the third - <img alt="C_2" class="valign-m3" src="https://eli.thegreenplace.net/images/math/e65b6ebf7cbd7ef19069cc4837331af9d119cfe6.png" style="height: 15px;" /> we set 2 as the right answer, 0 and 1 the wrong answers.</p> <p>Now, given a new input vector <strong>x</strong> we run <img alt="C_0(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/5ed83eb3961cbf4855ce46814719658cdc79e5f2.png" style="height: 18px;" />, <img alt="C_1(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/47b77ff17810fb0a0d4f6b86f50d403e8a59a7a7.png" style="height: 18px;" /> and <img alt="C_2(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/c063fddc4bcdd77e1131dc70ec5b578b5ec887ef.png" style="height: 18px;" />. Each of these gives us the probability of <strong>x</strong> belonging to the respective class. If we put all the classifiers in a vector, we get:</p> <img alt="$C(x)=[C_0(x), C_1(x), C_2(x)]$" class="align-center" src="https://eli.thegreenplace.net/images/math/9e4c42a11867dda976b1f7b1ac6aaa46b6625ee9.png" style="height: 19px;" /> <p>We pick the class where the probability is highest. Mathematically, we can use the <a class="reference external" href="https://en.wikipedia.org/wiki/Arg_max">argmax function</a> for this purpose. <em>argmax</em> returns the index of the maximal element in the given vector. For example, given:</p> <img alt="$C(x)=[0.45, 0.42, 0.09]$" class="align-center" src="https://eli.thegreenplace.net/images/math/983ab6a6770f41c06b3eb32f811678aab7f6fb5b.png" style="height: 19px;" /> <p>We get:</p> <img alt="$\underset{t \in [0..2]}{argmax}(C(x))=0$" class="align-center" src="https://eli.thegreenplace.net/images/math/1fa779c771c2d0abcaca9a759ab2e99608842f82.png" style="height: 34px;" /> <p>Therefore, the chosen class is 0. These class/index numbers are just labels of course. 
They can stand for anything depending on the problem domain: medical condition names, digits and so on.</p> <p>This approach doesn't require any additional math over what we've already covered in this post. <a class="reference external" href="https://github.com/eliben/deep-learning-samples/blob/master/logistic-regression/mnist_multinomial_classifier.py">This multinomial MNIST classifier code sample</a> implements it. The error rate it achieves is ~11%, similar to what <a class="reference external" href="http://yann.lecun.com/exdb/publis/index.html#lecun-98">LeCun's 1998 paper</a> achieved with a simple linear classifier. Much better than 11% can be done for MNIST, even with a single-layer linear model. However, my model is very far from the state of the art - there's no preprocessing, no artificially-enlarged training set, no adaptive learning rate; I didn't even spend time tuning the hyperparameters (regularization type and constants, learning rate, batch size, etc.). The goal here was just to demonstrate the basics of logistic regression, not to compete for the state of the art in MNIST.</p> </div> <div class="section" id="softmax"> <h2>Softmax</h2> <p>An alternative to OvA is to use the softmax function.
I covered softmax <a class="reference external" href="http://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative/">in some detail</a> previously; just briefly, softmax is a function <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/1dd52a52398e38c9549b289449de49ba5fbb98b7.svg" style="height: 19px;" type="image/svg+xml">S(\mathbf{a}):\mathbb{R}^{N}\rightarrow \mathbb{R}^{N}</object> such that:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/5470218612381816a8c9a897d43201757560e646.svg" style="height: 46px;" type="image/svg+xml"> $S_j=\frac{e^{a_j}}{\sum_{k=1}^{N}e^{a_k}} \qquad \forall j \in 1..N$</object> <p>It is very useful for multiclass classification, since it lets us generate probabilities of the input belonging to one of <em>N</em> classes. Similarly to the OvA case, here we have to train 10 different parameter vectors, one for each digit. However, unlike OvA, this training doesn't happen separately but occurs at the same time. Instead of training a model to find a single parameter vector each time, we train a parameter <em>matrix</em> once.</p> <p>The model structure is as follows:</p> <img alt="Model of softmax logistic regression" class="align-center" src="https://eli.thegreenplace.net/images/2016/softmax-logistic-model.png" /> <p>I've chosen the number of classes to be 10 to reflect MNIST where we have 10 possible digits to assign to every input. In MNIST <em>N</em> is 785 (784 for each of 28x28 pixels in the image, plus one for bias). &quot;Logits&quot; is a common name to assign to the output of a fully connected layer (which is what we have with the matrix-vector multiplication in the first stage); the logits are arbitrary real numbers. The softmax function is responsible for squeezing them into the range of probabilities (0, 1) and making sure they all add up to 1.</p> <p>This diagram shows what happens to a single input as it goes through the model.
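The single-input flow in the diagram can be sketched in a few lines of NumPy. This is a toy version: the parameter matrix below is random, standing in for a trained one, and only the shapes and the softmax stage match the diagram.

```python
import numpy as np

def softmax(z):
    # Subtracting the max is a standard numerical-stability trick;
    # it doesn't change the result.
    e = np.exp(z - np.max(z))
    return e / e.sum()

rng = np.random.default_rng(0)
N, T = 785, 10                           # 784 pixels + bias input; 10 classes
W = rng.normal(scale=0.01, size=(T, N))  # stand-in for a trained matrix
x = rng.normal(size=N)                   # one (fake) input vector

logits = W.dot(x)          # fully connected layer: arbitrary reals
probs = softmax(logits)    # probabilities in (0, 1), summing to 1
predicted_class = int(np.argmax(probs))
```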
In a realistic program, there will be another dimension - the batch dimension, used to vectorize the computation over a whole batch of inputs.</p> <p>For training this model, we need a loss function. It turns out cross-entropy is a very popular loss function to use for softmax. In the <a class="reference external" href="http://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative/">softmax post</a> I also covered how to compute the gradient of cross-entropy on a softmax, so we're all set to write some code: the <a class="reference external" href="https://github.com/eliben/deep-learning-samples/blob/master/logistic-regression/mnist_softmax_classifier.py">full sample is here</a>. Running it on MNIST for a couple of minutes produces a 9.5% error rate - slightly better than the OvA approach, but very close. This is to be expected, since OvA and softmax compute very similar results (finding the maximal probability from a set of probabilities), just in a different way. Softmax regression is much faster, however, since we can vectorize the training for all 10 digits in the same run.</p> <hr class="docutils" /> <table class="docutils footnote" frame="void" id="id10" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id1"></a></td><td>In this post I'm following many of the conventions established in my post on <a class="reference external" href="http://eli.thegreenplace.net/2016/linear-regression/">linear regression</a>. 
In particular, by construction <img alt="x_0=1" class="valign-m3" src="https://eli.thegreenplace.net/images/math/0c1d7f319728a07a57d000f2379b5215e4130147.png" style="height: 15px;" /> so that <img alt="\theta_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/ba6201ddbe2fd0bb66e0704ad8b3c6bdb36f37aa.png" style="height: 15px;" /> is the bias.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id12" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id2"></a></td><td>Why? Because we have the bias as part of the model, so any constant offset can be absorbed into the learned bias.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id13" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id3"></a></td><td>Note that this outcome is, once again, somewhat arbitrary. We could find another plane that intersects the x/y axis on the same line, and get a different classification. For example, if we flip the sign of all the elements of <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" />, we get the same intersection line. In that case, however, values &quot;to the right&quot; of the line give us <img alt="\hat{y}(x) &amp;lt; 0" class="valign-m4" src="https://eli.thegreenplace.net/images/math/d8a7e77c45cecd8e4ba7c8f7d1f02944e9b55ecf.png" style="height: 18px;" />. Since the labels we attach are arbitrary, this really makes no difference. 
The only important thing is that we find a line that separates &quot;true&quot; from &quot;false&quot; samples, and that we remain consistent with our signs and labels throughout the process.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id14" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id4"></a></td><td>Note that both the loss and the regularization are called <img alt="L_2" class="valign-m3" src="https://eli.thegreenplace.net/images/math/0d2398f5890edff3f40f1686fc3b51528209bf9b.png" style="height: 15px;" />. This is a bit confusing, but both are essentially 2-norms. It's best to ignore the name of the regularization factor and just refer to it as &quot;regularization&quot;. I thought it important to mention up front, though, because other kinds of regularization are used for machine-learning algorithms and I wanted to make it clear which one is being used here.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id15" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id5"></a></td><td>As an exercise, play with the code to increase or decrease the number of outliers (the code makes it easily controllable), and observe the effects on the misclassification rates of the different loss functions.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id16" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id6"></a></td><td>Note that using the logistic function on the model's output is strictly a generalization of the binary classifier.
We can still make a binary interpretation of the result if we're so inclined, interpreting <img alt="S(z) \geq 0.5" class="valign-m4" src="https://eli.thegreenplace.net/images/math/763035b41ff594d664c57d9fcc03c85808d0ccce.png" style="height: 18px;" /> as &quot;yes&quot; and otherwise as &quot;no&quot;. In terms of the input to <img alt="S(z)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/61bc9efb9d2c99669df519617ee7daee7670e156.png" style="height: 18px;" />, this means &quot;yes&quot; for <img alt="z=\hat{y}(x) \geq 0" class="valign-m4" src="https://eli.thegreenplace.net/images/math/2599fe02308a2d43e5b29b2f9387ee45c5c67a1b.png" style="height: 18px;" /> which is exactly the formulation we've been using for the binary classifier.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id17" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id7"></a></td><td>In essence, cross entropy is computed between two probability distributions. Here, one of them is the &quot;real&quot; distribution observed in the <em>y</em> data. The other is what we predict given <em>X</em> data and our regression parameters <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" />. The observed real probability is either 0 or 1 for any given data item, and the corresponding predicted probability is our model's output. 
I also discussed cross-entropy in the <a class="reference external" href="http://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative/">post about softmax</a>.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id18" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id8"></a></td><td>Many resources online condense this formula to a single line without the condition: <img alt="C(x)=-ylog(S(x))-(1-y)log(1-S(x))" class="valign-m4" src="https://eli.thegreenplace.net/images/math/0f9309397d7ef59c72cdb2d861e5532292978ca6.png" style="height: 18px;" />. I'm avoiding this formulation on purpose, because it requires the possible values of <em>y</em> to be 0 and 1, not -1 and +1. Although it's possible to play with constants a bit to reformulate the -1/+1 case in a similarly condensed fashion, I find the version with the condition more explicit and thus easier to follow, even if it requires a bit more typing.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id19" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id9"></a></td><td>See also <a class="reference external" href="http://eli.thegreenplace.net/2016/the-chain-rule-of-calculus/">my post</a> about the chain rule, where this derivation is shown.</td></tr> </tbody> </table> </div> The Softmax function and its derivative (2016-10-18, by Eli Bendersky)<p>The softmax function takes an N-dimensional vector of arbitrary real values and produces another N-dimensional vector with real values in the range (0, 1) that add up to 1.0. It maps <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/1dd52a52398e38c9549b289449de49ba5fbb98b7.svg" style="height: 19px;" type="image/svg+xml">S(\mathbf{a}):\mathbb{R}^{N}\rightarrow \mathbb{R}^{N}</object>:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/cd593d87595e496072aebf5100dd87c37c889f25.svg" style="height: 86px;" type="image/svg+xml"> \[S(\mathbf{a}):\begin{bmatrix} a_1\\ a_2\\ \cdots\\ a_N \end{bmatrix} \rightarrow \begin{bmatrix} S_1\\ S_2\\ \cdots\\ S_N \end{bmatrix}</object> <p>And the actual per-element formula is:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/5470218612381816a8c9a897d43201757560e646.svg" style="height: 46px;" type="image/svg+xml"> $S_j=\frac{e^{a_j}}{\sum_{k=1}^{N}e^{a_k}} \qquad \forall j \in 1..N$</object> <p>It's easy to see that <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/cb8b5683be866b4c177c0c319e14085f25bec523.svg" style="height: 18px;" type="image/svg+xml">S_j</object> is always positive (because of the exponents); moreover, since the numerator appears in the denominator summed up with some other positive numbers, <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/5a34de9dd188a5a6f758bb0f7daabb58e03045ec.svg" style="height: 18px;"
type="image/svg+xml">S_j&lt;1</object>. Therefore, it's in the range (0, 1).</p> <p>For example, the 3-element vector <tt class="docutils literal">[1.0, 2.0, 3.0]</tt> gets transformed into <tt class="docutils literal">[0.09, 0.24, 0.67]</tt>. The order of elements by relative size is preserved, and they add up to 1.0. Let's tweak this vector slightly into <tt class="docutils literal">[1.0, 2.0, 5.0]</tt>. We get the output <tt class="docutils literal">[0.02, 0.05, 0.93]</tt>, which still preserves these properties. Note that as the last element is farther away from the first two, its softmax value dominates the overall slice of size 1.0 in the output. Intuitively, the softmax function is a &quot;soft&quot; version of the maximum function. Instead of just selecting one maximal element, softmax breaks the vector up into parts of a whole (1.0), with the maximal input element getting a proportionally larger chunk, but the other elements getting some of it as well <a class="footnote-reference" href="#id3" id="id1"></a>.</p> <div class="section" id="probabilistic-interpretation"> <h2>Probabilistic interpretation</h2> <p>The properties of softmax (all output values are in the range (0, 1) and sum up to 1.0) make it suitable for a probabilistic interpretation that's very useful in machine learning.
In particular, in multiclass classification tasks, we often want to assign probabilities that our input belongs to one of a set of output classes.</p> <p>If we have N output classes, we're looking for an N-vector of probabilities that sum up to 1; sounds familiar?</p> <p>We can interpret softmax as follows:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/4510f717b770547b90526c714355f4c81d1b4a50.svg" style="height: 19px;" type="image/svg+xml"> $S_j=P(y=j|a)$</object> <p>Where <em>y</em> is the output class numbered <object class="valign-m1" data="https://eli.thegreenplace.net/images/math/310debdf2f7fe03ad7888e95000c78a0efae5500.svg" style="height: 13px;" type="image/svg+xml">1..N</object>. <em>a</em> is any N-vector. The most basic example is <a class="reference external" href="http://eli.thegreenplace.net/2016/logistic-regression/">multiclass logistic regression</a>, where an input vector <em>x</em> is multiplied by a weight matrix <em>W</em>, and the result of this dot product is fed into a softmax function to produce probabilities. This architecture is explored in detail later in the post.</p> <p>It turns out that - from a probabilistic point of view - softmax is optimal for <a class="reference external" href="https://en.wikipedia.org/wiki/Maximum_likelihood_estimation">maximum-likelihood estimation</a> of the model's parameters. This is beyond the scope of this post, though. See chapter 5 of the <a class="reference external" href="http://www.deeplearningbook.org/">&quot;Deep Learning&quot; book</a> for more details.</p> </div> <div class="section" id="some-preliminaries-from-vector-calculus"> <h2>Some preliminaries from vector calculus</h2> <p>Before diving into computing the derivative of softmax, let's start with some preliminaries from vector calculus.</p> <p>Softmax is fundamentally a vector function. It takes a vector as input and produces a vector as output; in other words, it has multiple inputs and multiple outputs. 
Therefore, we cannot just ask for &quot;the derivative of softmax&quot;; we should instead specify:</p> <ol class="arabic simple"> <li>Which component (output element) of softmax we're seeking to find the derivative of.</li> <li>Since softmax has multiple inputs, with respect to which input element the partial derivative is computed.</li> </ol> <p>If this sounds complicated, don't worry. This is exactly why the notation of vector calculus was developed. What we're looking for are the partial derivatives:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/2eae0a040f9eb82a2cf0a596c926aca49a3cdb66.svg" style="height: 42px;" type="image/svg+xml"> $\frac{\partial S_i}{\partial a_j}$</object> <p>This is the partial derivative of the i-th output w.r.t. the j-th input. A shorter way to write it that we'll be using going forward is: <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/ca95d97dc85a733a280ccaab680d01727376e383.svg" style="height: 18px;" type="image/svg+xml">D_{j}S_i</object>.</p> <p>Since softmax is a <object class="valign-m1" data="https://eli.thegreenplace.net/images/math/91b745aec8f7c3a5501975b040a4aef477c31412.svg" style="height: 16px;" type="image/svg+xml">\mathbb{R}^{N}\rightarrow \mathbb{R}^{N}</object> function, the most general derivative we compute for it is the Jacobian matrix:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/7af5ba48ed18f62f0fa31b60ba35e8e94054931c.svg" style="height: 76px;" type="image/svg+xml"> $DS=\begin{bmatrix} D_1 S_1 &amp; \cdots &amp; D_N S_1 \\ \vdots &amp; \ddots &amp; \vdots \\ D_1 S_N &amp; \cdots &amp; D_N S_N \end{bmatrix}$</object> <p>In ML literature, the term &quot;gradient&quot; is commonly used to stand in for the derivative.
Strictly speaking, gradients are only defined for scalar functions (such as loss functions in ML); for vector functions like softmax it's imprecise to talk about a &quot;gradient&quot;. The Jacobian is the fully general derivative of a vector function, but in most places I'll just be saying &quot;derivative&quot;.</p> </div> <div class="section" id="derivative-of-softmax"> <h2>Derivative of softmax</h2> <p>Let's compute <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/166e309484516e7fea86d27f36f42639ab73b471.svg" style="height: 18px;" type="image/svg+xml">D_j S_i</object> for arbitrary <em>i</em> and <em>j</em>:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/dbee1c4ac839a1eef7202f447f754341eec98904.svg" style="height: 53px;" type="image/svg+xml"> $D_j S_i=\frac{\partial S_i}{\partial a_j}= \frac{\partial \frac{e^{a_i}}{\sum_{k=1}^{N}e^{a_k}}}{\partial a_j}$</object> <p>We'll be using the quotient rule of derivatives. For <object class="valign-m9" data="https://eli.thegreenplace.net/images/math/25ee22368ab19a6e8608ac7417cf62e235794e54.svg" style="height: 29px;" type="image/svg+xml">f(x) = \frac{g(x)}{h(x)}</object>:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/c0fd805caf8b7d8336e8c52f2759b3ce73295315.svg" style="height: 43px;" type="image/svg+xml"> $f&#x27;(x) = \frac{g&#x27;(x)h(x) - h&#x27;(x)g(x)}{[h(x)]^2}$</object> <p>In our case, we have:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/167b7392a9d51fbc4016901d48995f091f627e3a.svg" style="height: 82px;" type="image/svg+xml"> \begin{align*} g_i&amp;=e^{a_i} \\ h_i&amp;=\sum_{k=1}^{N}e^{a_k} \end{align*}</object> <p>Note that no matter which <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/c2d2e987a5cb0df2f497d2dba0da0960fb6fbcc0.svg" style="height: 14px;" type="image/svg+xml">a_j</object> we compute the derivative of <object class="valign-m3"
data="https://eli.thegreenplace.net/images/math/969951984c96d748d949ee5e5322f4c2dbb75087.svg" style="height: 16px;" type="image/svg+xml">h_i</object> for, the answer will always be <object class="valign-0" data="https://eli.thegreenplace.net/images/math/a4c5fca09246e4e7c55473070976f788e032c514.svg" style="height: 12px;" type="image/svg+xml">e^{a_j}</object>. This is not the case for <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/d141c63d6e5b4ff91ec2936c9b320454461258a0.svg" style="height: 12px;" type="image/svg+xml">g_i</object>, however. The derivative of <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/d141c63d6e5b4ff91ec2936c9b320454461258a0.svg" style="height: 12px;" type="image/svg+xml">g_i</object> w.r.t. <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/c2d2e987a5cb0df2f497d2dba0da0960fb6fbcc0.svg" style="height: 14px;" type="image/svg+xml">a_j</object> is <object class="valign-0" data="https://eli.thegreenplace.net/images/math/a4c5fca09246e4e7c55473070976f788e032c514.svg" style="height: 12px;" type="image/svg+xml">e^{a_j}</object> only if <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/8e4587fc82ce6377530643c5622b41e53cdf3dd3.svg" style="height: 16px;" type="image/svg+xml">i=j</object>, because only then <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/d141c63d6e5b4ff91ec2936c9b320454461258a0.svg" style="height: 12px;" type="image/svg+xml">g_i</object> has <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/c2d2e987a5cb0df2f497d2dba0da0960fb6fbcc0.svg" style="height: 14px;" type="image/svg+xml">a_j</object> anywhere in it.
Otherwise, the derivative is 0.</p> <p>Going back to our <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/166e309484516e7fea86d27f36f42639ab73b471.svg" style="height: 18px;" type="image/svg+xml">D_j S_i</object>; we'll start with the <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/8e4587fc82ce6377530643c5622b41e53cdf3dd3.svg" style="height: 16px;" type="image/svg+xml">i=j</object> case. Then, using the quotient rule we have:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/d7489693552878c00ad6788a0c8987416cbb0796.svg" style="height: 53px;" type="image/svg+xml"> $\frac{\partial \frac{e^{a_i}}{\sum_{k=1}^{N}e^{a_k}}}{\partial a_j}= \frac{{}e^{a_i}\Sigma-e^{a_j}e^{a_i}}{\Sigma^2}$</object> <p>For simplicity <object class="valign-0" data="https://eli.thegreenplace.net/images/math/cb5615b3fcee824f137c372e351ccca3ff3a3292.svg" style="height: 12px;" type="image/svg+xml">\Sigma</object> stands for <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/2c3662fbb97e3b5c528e8b1cdf89e108bfeed206.svg" style="height: 23px;" type="image/svg+xml">\sum_{k=1}^{N}e^{a_k}</object>. 
Reordering a bit:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/2634d0ab6532983a88a1f55a33cf6a6719a291ee.svg" style="height: 123px;" type="image/svg+xml"> \begin{align*} \frac{\partial \frac{e^{a_i}}{\sum_{k=1}^{N}e^{a_k}}}{\partial a_j}&amp;= \frac{e^{a_i}\Sigma-e^{a_j}e^{a_i}}{\Sigma^2}\\ &amp;=\frac{e^{a_i}}{\Sigma}\frac{\Sigma - e^{a_j}}{\Sigma}\\ &amp;=S_i(1-S_j) \end{align*}</object> <p>The final formula expresses the derivative in terms of <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/3e218c43050832e5df45f69fb2c8b8a01f7f5a52.svg" style="height: 15px;" type="image/svg+xml">S_i</object> itself - a common trick when functions with exponents are involved.</p> <p>Similarly, we can do the <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/09eca402f8bc6311cca3a98625e29e75cc336d31.svg" style="height: 17px;" type="image/svg+xml">i\ne j</object> case:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/d788a4ff0e07827862aaf0ded5befbf1665d90cc.svg" style="height: 123px;" type="image/svg+xml"> \begin{align*} \frac{\partial \frac{e^{a_i}}{\sum_{k=1}^{N}e^{a_k}}}{\partial a_j}&amp;= \frac{0-e^{a_j}e^{a_i}}{\Sigma^2}\\ &amp;=-\frac{e^{a_j}}{\Sigma}\frac{e^{a_i}}{\Sigma}\\ &amp;=-S_j S_i \end{align*}</object> <p>To summarize:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/f776365373202f727625c0be825d55a2fde47882.svg" style="height: 43px;" type="image/svg+xml"> $D_j S_i=\left\{\begin{matrix} S_i(1-S_j) &amp; i=j\\ -S_j S_i &amp; i\ne j \end{matrix}\right$</object> <p>I like seeing this explicit breakdown by cases, but if anyone is taking more pride in being concise and clever than programmers, it's mathematicians. This is why you'll find various &quot;condensed&quot; formulations of the same equation in the literature. 
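</p>

<p>Before looking at those condensed forms, here's a quick numerical sanity check of the two-case formula - a short NumPy sketch (the helper name <tt class="docutils literal">softmax_jacobian</tt> is mine) that builds the full Jacobian and compares every entry against a central finite difference:</p>

```python
import numpy as np

def softmax(a):
    exps = np.exp(a - np.max(a))
    return exps / np.sum(exps)

def softmax_jacobian(a):
    """DS with entries S_i * (1 - S_j) on the diagonal (i = j)
    and -S_j * S_i off the diagonal (i != j)."""
    S = softmax(a)
    J = -np.outer(S, S)   # the i != j case, everywhere...
    J += np.diag(S)       # ...then the diagonal becomes S_i * (1 - S_i)
    return J

# Check every entry against a central finite difference.
a = np.array([1.0, 2.0, 3.0])
J = softmax_jacobian(a)
eps = 1e-6
for i in range(3):
    for j in range(3):
        da = np.zeros_like(a)
        da[j] = eps
        numeric = (softmax(a + da)[i] - softmax(a - da)[i]) / (2 * eps)
        assert abs(J[i, j] - numeric) < 1e-8
```

<p>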
One of the most common ones is using the Kronecker delta function:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/ff38cb90472289e31bd7f79c1c85c455d7962cbb.svg" style="height: 43px;" type="image/svg+xml"> $\delta_{ij}=\left\{\begin{matrix} 1 &amp; i=j\\ 0 &amp; i\ne j \end{matrix}\right$</object> <p>To write:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/6e4b626a68faabba991f9d1e83a12c74fcec0e63.svg" style="height: 19px;" type="image/svg+xml"> $D_j S_i = S_i (\delta_{ij}-S_j)$</object> <p>Which is, of course, the same thing. There are a couple of other formulations one sees in the literature:</p> <ol class="arabic simple"> <li>Using the matrix formulation of the Jacobian directly to replace <object class="valign-0" data="https://eli.thegreenplace.net/images/math/3a6a16552e246af497720ffdfe6091b42d2f8938.svg" style="height: 12px;" type="image/svg+xml">\delta</object> with <object class="valign-0" data="https://eli.thegreenplace.net/images/math/ca73ab65568cd125c2d27a22bbd9e863c10b675d.svg" style="height: 12px;" type="image/svg+xml">I</object> - the identity matrix, whose elements express <object class="valign-0" data="https://eli.thegreenplace.net/images/math/3a6a16552e246af497720ffdfe6091b42d2f8938.svg" style="height: 12px;" type="image/svg+xml">\delta</object> in matrix form.</li> <li>Using &quot;1&quot; as the function name instead of the Kronecker delta, as follows: <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/a4fa3293a004c9dc1f5171ddb590ac9cb7178102.svg" style="height: 20px;" type="image/svg+xml">D_j S_i = S_i (1(i=j)-S_j)</object>.
Here <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/d9e260212cd116b69ffa42e9c9f824b2bcf6a217.svg" style="height: 18px;" type="image/svg+xml">1(i=j)</object> means the value 1 when <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/8e4587fc82ce6377530643c5622b41e53cdf3dd3.svg" style="height: 16px;" type="image/svg+xml">i=j</object> and the value 0 otherwise.</li> </ol> <p>The condensed notation comes in handy when we want to compute more complex derivatives that depend on the softmax derivative; otherwise we'd have to propagate the condition everywhere.</p> </div> <div class="section" id="computing-softmax-and-numerical-stability"> <h2>Computing softmax and numerical stability</h2> <p>A simple way of computing the softmax function on a given vector in Python is:</p> <div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>

<span class="k">def</span> <span class="nf">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;Compute the softmax of vector x.&quot;&quot;&quot;</span>
    <span class="n">exps</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">exps</span> <span class="o">/</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">exps</span><span class="p">)</span>
</pre></div> <p>Let's try it with the sample 3-element vector we've used as an example earlier:</p> <div class="highlight"><pre><span></span>In : softmax([1, 2, 3])
Out: array([ 0.09003057, 0.24472847, 0.66524096])
</pre></div> <p>However, if we run this function with larger numbers (or large negative numbers) we have a problem:</p> <div class="highlight"><pre><span></span>In : softmax([1000, 2000, 3000])
Out: array([ nan, nan, nan])
</pre></div> <p>The numerical range
of the floating-point numbers used by Numpy is limited. For <tt class="docutils literal">float64</tt>, the maximal representable number is on the order of <object class="valign-m1" data="https://eli.thegreenplace.net/images/math/91d9772e2d01d53580c14ba9801ea3303f45cac7.svg" style="height: 16px;" type="image/svg+xml">10^{308}</object>. Exponentiation in the softmax function makes it possible to easily overshoot this number, even for fairly modest-sized inputs.</p> <p>A nice way to avoid this problem is by normalizing the inputs to be not too large or too small, by observing that we can use an arbitrary constant <em>C</em> as follows:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/21c627f153906b6de2c2723f4a20629a610945ba.svg" style="height: 46px;" type="image/svg+xml"> $S_j=\frac{e^{a_j}}{\sum_{k=1}^{N}e^{a_k}}=\frac{Ce^{a_j}}{\sum_{k=1}^{N}Ce^{a_k}}$</object> <p>And then pushing the constant into the exponent, we get:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/c5b631159b49e84338269e0943e00da2fb7f5d21.svg" style="height: 51px;" type="image/svg+xml"> $S_j=\frac{e^{a_j+log(C)}}{\sum_{k=1}^{N}e^{a_k+log(C)}}$</object> <p>Since <em>C</em> is just an arbitrary constant, we can instead write:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/7ae51c811f1348f4762e3eee1a3cc9e8aad1890c.svg" style="height: 49px;" type="image/svg+xml"> $S_j=\frac{e^{a_j+D}}{\sum_{k=1}^{N}e^{a_k+D}}$</object> <p>Where <em>D</em> is also an arbitrary constant. This formula is equivalent to the original <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/cb8b5683be866b4c177c0c319e14085f25bec523.svg" style="height: 18px;" type="image/svg+xml">S_j</object> for any <em>D</em>, so we're free to choose a <em>D</em> that will make our computation better numerically. 
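</p>

<p>Before picking a particular <em>D</em>, it's easy to confirm this shift-invariance numerically with the naive formula, for shifts small enough to avoid overflow:</p>

```python
import numpy as np

def softmax(x):
    exps = np.exp(x)
    return exps / np.sum(exps)

# Adding any constant D to all inputs leaves the output unchanged,
# as long as we stay far from overflow.
a = np.array([1.0, 2.0, 3.0])
for D in (-3.0, 0.5, 10.0):
    assert np.allclose(softmax(a), softmax(a + D))
```

<p>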
A good choice is the maximum between all inputs, negated:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/0433b741304b0b54a6e11be1602b63d4b6326e98.svg" style="height: 18px;" type="image/svg+xml"> $D=-max(a_1, a_2, \cdots, a_N)$</object> <p>This will shift the inputs to a range close to zero, assuming the inputs themselves are not too far from each other. Crucially, it shifts them all to be negative (except the maximal <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/c2d2e987a5cb0df2f497d2dba0da0960fb6fbcc0.svg" style="height: 14px;" type="image/svg+xml">a_j</object> which turns into a zero). Negatives with large exponents &quot;saturate&quot; to zero rather than infinity, so we have a better chance of avoiding NaNs.</p> <div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">stablesoftmax</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;Compute the softmax of vector x in a numerically stable way.&quot;&quot;&quot;</span>
    <span class="n">shiftx</span> <span class="o">=</span> <span class="n">x</span> <span class="o">-</span> <span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
    <span class="n">exps</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">shiftx</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">exps</span> <span class="o">/</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">exps</span><span class="p">)</span>
</pre></div> <p>And now:</p> <div class="highlight"><pre><span></span>In : stablesoftmax([1000, 2000, 3000])
Out: array([ 0., 0., 1.])
</pre></div> <p>Note that this is still imperfect, since mathematically softmax
would never really produce a zero, but this is much better than NaNs, and since the distance between the inputs is very large it's expected to get a result extremely close to zero anyway.</p> </div> <div class="section" id="the-softmax-layer-and-its-derivative"> <h2>The softmax layer and its derivative</h2> <p>A common use of softmax appears in machine learning, in particular in logistic regression: the softmax &quot;layer&quot;, wherein we apply softmax to the output of a fully-connected layer (matrix multiplication):</p> <img alt="Generic softmax layer diagram" class="align-center" src="https://eli.thegreenplace.net/images/2016/softmax-layer-generic.png" /> <p>In this diagram, we have an input <em>x</em> with N features, and T possible output classes. The weight matrix <em>W</em> is used to transform <em>x</em> into a vector with T elements (called &quot;logits&quot; in ML folklore), and the softmax function is used to &quot;collapse&quot; the logits into a vector of probabilities denoting the probability of <em>x</em> belonging to each one of the T output classes.</p> <p>How do we compute the derivative of this &quot;softmax layer&quot; (fully-connected matrix multiplication followed by softmax)? Using the chain rule, of course! You'll find any number of derivations of this derivative online, but I want to approach it from first principles, by carefully applying the <a class="reference external" href="http://eli.thegreenplace.net/2016/the-chain-rule-of-calculus/">multivariate chain rule</a> to the Jacobians of the functions involved.</p> <p>An important point before we get started: you may think that <em>x</em> is a natural variable to compute the derivative for. But it's not. In fact, in machine learning we usually want to find the best weight matrix <em>W</em>, and thus it is <em>W</em> we want to update with every step of <a class="reference external" href="http://eli.thegreenplace.net/2016/understanding-gradient-descent">gradient descent</a>. 
Therefore, we'll be computing the derivative of this layer w.r.t. <em>W</em>.</p> <p>Let's start by rewriting this diagram as a composition of vector functions. First, we have the matrix multiplication, which we denote <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/a0e38e0d2b015bcbf88c39139b08982ae8b9529d.svg" style="height: 18px;" type="image/svg+xml">g(W)</object>. It maps <object class="valign-m1" data="https://eli.thegreenplace.net/images/math/41cbe7438e5529bcab383579b09d611cd97f0444.svg" style="height: 16px;" type="image/svg+xml">\mathbb{R}^{NT}\rightarrow \mathbb{R}^{T}</object>, because the input (matrix <em>W</em>) has <em>N times T</em> elements, and the output has T elements.</p> <p>Next we have the softmax. If we denote the vector of logits as <object class="valign-0" data="https://eli.thegreenplace.net/images/math/b3931f1ce298c536432fd324b3a1ab4337120689.svg" style="height: 12px;" type="image/svg+xml">\lambda</object>, we have <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/9af2279e6f8c350d3e301ff7ed97ff2d23d2b478.svg" style="height: 19px;" type="image/svg+xml">S(\lambda):\mathbb{R}^{T}\rightarrow \mathbb{R}^{T}</object>. 
Overall, we have the function composition:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/10e8a3123f66fe60ae76a3fe83b2a9b73ea3fa57.svg" style="height: 45px;" type="image/svg+xml"> \begin{align*} P(W)&amp;=S(g(W)) \\ &amp;=(S\circ g)(W) \end{align*}</object> <p>By applying the multivariate chain rule, the Jacobian of <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/f6dd867bfc20ac609f598f54ed834172e0985b0b.svg" style="height: 18px;" type="image/svg+xml">P(W)</object> is:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/80f6a3c715eb405a68968e4c579d3a2b562cfab0.svg" style="height: 18px;" type="image/svg+xml"> $DP(W)=D(S\circ g)(W)=DS(g(W))\cdot Dg(W)$</object> <p>We've computed the Jacobian of <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/7f3a73c41d966d0cade30c5b1fadd35290358a15.svg" style="height: 18px;" type="image/svg+xml">S(a)</object> earlier in this post; what's remaining is the Jacobian of <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/a0e38e0d2b015bcbf88c39139b08982ae8b9529d.svg" style="height: 18px;" type="image/svg+xml">g(W)</object>. Since <em>g</em> is a very simple function, computing its Jacobian is easy; the only complication is dealing with the indices correctly. We have to keep track of which weight each derivative is for. 
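</p>

<p>Before deriving <em>Dg</em>, here is the forward computation of the whole layer in code - a minimal sketch with assumed sizes (T = 2 classes, N = 3 features; the names mirror the diagram, not code from elsewhere in the post):</p>

```python
import numpy as np

def softmax(a):
    exps = np.exp(a - np.max(a))  # stable variant
    return exps / np.sum(exps)

def softmax_layer(W, x):
    """Fully-connected layer followed by softmax: P(W) = S(g(W))."""
    logits = W.dot(x)        # g(W): a vector with T elements
    return softmax(logits)   # probabilities over the T classes

# Hypothetical sizes: T = 2 classes, N = 3 input features.
W = np.array([[0.1, 0.2, 0.3],
              [0.4, 0.5, 0.6]])
x = np.array([1.0, -1.0, 2.0])
P = softmax_layer(W, x)
assert np.isclose(P.sum(), 1.0)
```

<p>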
Since <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/50dd3f482e6e8490b6b54b110c2b8e9018c6a607.svg" style="height: 19px;" type="image/svg+xml">g(W):\mathbb{R}^{NT}\rightarrow \mathbb{R}^{T}</object>, its Jacobian has <em>T</em> rows and <em>NT</em> columns:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/0d59698eb2307932fdb5a94b7f089da40688f368.svg" style="height: 76px;" type="image/svg+xml"> $Dg=\begin{bmatrix} D_1 g_1 &amp; \cdots &amp; D_{NT} g_1 \\ \vdots &amp; \ddots &amp; \vdots \\ D_1 g_T &amp; \cdots &amp; D_{NT} g_T \end{bmatrix}$</object> <p>In a sense, the weight matrix <em>W</em> is &quot;linearized&quot; to a vector of length <em>NT</em>. If you're familiar with the <a class="reference external" href="http://eli.thegreenplace.net/2015/memory-layout-of-multi-dimensional-arrays">memory layout of multi-dimensional arrays</a>, it should be easy to understand how it's done. In our case, one simple thing we can do is linearize it in row-major order, where the first row is consecutive, followed by the second row, etc. Mathematically, <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/14147644eaa95a20bf61a81af56045475f386a83.svg" style="height: 18px;" type="image/svg+xml">W_{ij}</object> will get column number <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/ef7b2d987af3c0ceb75381d096c35e8c19085642.svg" style="height: 18px;" type="image/svg+xml">(i-1)N+j</object> in the Jacobian. 
To populate <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/38b655437da0880bd70168fcbadb50ebdbf46ca5.svg" style="height: 16px;" type="image/svg+xml">Dg</object>, let's recall what <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/434575851c19a9826fb6be1ca130ffa3243a2a34.svg" style="height: 12px;" type="image/svg+xml">g_1</object> is:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/64a7924d431e1a8e82f753f1f04943ddd619fedb.svg" style="height: 16px;" type="image/svg+xml"> $g_1=W_{11}x_1+W_{12}x_2+\cdots +W_{1N}x_N$</object> <p>Therefore:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/2a64f2f7fdb74ca1e0e3bf86da7e9874e8855928.svg" style="height: 177px;" type="image/svg+xml"> \begin{align*} D_1g_1&amp;=x_1 \\ D_2g_1&amp;=x_2 \\ \cdots \\ D_Ng_1&amp;=x_N \\ D_{N+1}g_1&amp;=0 \\ \cdots \\ D_{NT}g_1&amp;=0 \end{align*}</object> <p>If we follow the same approach to compute <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/eeb76bb8cb07245435e01abcd03dec71f9c051df.svg" style="height: 12px;" type="image/svg+xml">g_2...g_T</object>, we'll get the Jacobian matrix:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/5b0d880f118ea950dd4c676a9aad2e481d83b0bf.svg" style="height: 76px;" type="image/svg+xml"> $Dg=\begin{bmatrix} x_1 &amp; x_2 &amp; \cdots &amp; x_N &amp; \cdots &amp; 0 &amp; 0 &amp; \cdots &amp; 0 \\ \vdots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \vdots \\ 0 &amp; 0 &amp; \cdots &amp; 0 &amp; \cdots &amp; x_1 &amp; x_2 &amp; \cdots &amp; x_N \end{bmatrix}$</object> <p>Looking at it differently, if we split the index of <em>W</em> to <em>i</em> and <em>j</em>, we get:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/3ca9791a8734377178476d2069bbb072b7e345ac.svg" style="height: 44px;" type="image/svg+xml"> 
\begin{align*} D_{ij}g_t&amp;=\frac{\partial(W_{t1}x_1+W_{t2}x_2+\cdots+W_{tN}x_N)}{\partial W_{ij}} &amp;= \left\{\begin{matrix} x_j &amp; i = t\\ 0 &amp; i \ne t \end{matrix}\right. \end{align*}</object> <p>This goes into row <em>t</em>, column <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/ef7b2d987af3c0ceb75381d096c35e8c19085642.svg" style="height: 18px;" type="image/svg+xml">(i-1)N+j</object> in the Jacobian matrix.</p> <p>Finally, to compute the full Jacobian of the softmax layer, we just do a dot product between <object class="valign-0" data="https://eli.thegreenplace.net/images/math/2ee0d2dca289c3eb54f4cc5e98db8d63e9b0794b.svg" style="height: 12px;" type="image/svg+xml">DS</object> and <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/38b655437da0880bd70168fcbadb50ebdbf46ca5.svg" style="height: 16px;" type="image/svg+xml">Dg</object>. Note that <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/be12618361f03651d2f459ce0fa3ac82aad3b766.svg" style="height: 19px;" type="image/svg+xml">P(W):\mathbb{R}^{NT}\rightarrow \mathbb{R}^{T}</object>, so the Jacobian dimensions work out. Since <object class="valign-0" data="https://eli.thegreenplace.net/images/math/2ee0d2dca289c3eb54f4cc5e98db8d63e9b0794b.svg" style="height: 12px;" type="image/svg+xml">DS</object> is <em>TxT</em> and <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/38b655437da0880bd70168fcbadb50ebdbf46ca5.svg" style="height: 16px;" type="image/svg+xml">Dg</object> is <em>TxNT</em>, their dot product <object class="valign-0" data="https://eli.thegreenplace.net/images/math/9f2059fa4172536236c9acfa22a911f918547e55.svg" style="height: 12px;" type="image/svg+xml">DP</object> is <em>TxNT</em>.</p> <p>In the literature you'll see a much shortened derivation of the derivative of the softmax layer. That's fine, since the two functions involved are simple and well known.
If we carefully compute a dot product between a row in <object class="valign-0" data="https://eli.thegreenplace.net/images/math/2ee0d2dca289c3eb54f4cc5e98db8d63e9b0794b.svg" style="height: 12px;" type="image/svg+xml">DS</object> and a column in <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/38b655437da0880bd70168fcbadb50ebdbf46ca5.svg" style="height: 16px;" type="image/svg+xml">Dg</object>:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/699151b941880c8adf5d363048a97c6731482ed6.svg" style="height: 54px;" type="image/svg+xml"> $D_{ij}P_t=\sum_{k=1}^{T}D_kS_t\cdot D_{ij}g_k$</object> <p><object class="valign-m4" data="https://eli.thegreenplace.net/images/math/38b655437da0880bd70168fcbadb50ebdbf46ca5.svg" style="height: 16px;" type="image/svg+xml">Dg</object> is mostly zeros, so the end result is simpler. The only <em>k</em> for which <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/fca24bbbbf8cac80ccc0253802b13d2749770585.svg" style="height: 18px;" type="image/svg+xml">D_{ij}g_k</object> is nonzero is when <object class="valign-0" data="https://eli.thegreenplace.net/images/math/f4b7e42a4b8c52f40eb9458e68e81c74d70c1c61.svg" style="height: 13px;" type="image/svg+xml">i=k</object>; then it's equal to <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/73058e43db0f4edc791b10f27f913cbc5d361ab6.svg" style="height: 14px;" type="image/svg+xml">x_j</object>. Therefore:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/7f5cbb15243987230b4fa5741769938a78c9c2f2.svg" style="height: 44px;" type="image/svg+xml"> \begin{align*} D_{ij}P_t&amp;=D_iS_tx_j \\ &amp;=S_t(\delta_{ti}-S_i)x_j \end{align*}</object> <p>So it's entirely possible to compute the derivative of the softmax layer without actual Jacobian matrix multiplication; and that's good, because matrix multiplication is expensive! 
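</p>

<p>Here's a sketch (with small assumed sizes: T = 3 classes, N = 2 features) checking this closed form entry by entry against a finite-difference derivative of the whole layer w.r.t. each weight:</p>

```python
import numpy as np

def softmax(a):
    exps = np.exp(a - np.max(a))
    return exps / np.sum(exps)

# Hypothetical sizes: T = 3 classes, N = 2 features.
W = np.array([[0.1, -0.3],
              [0.8,  0.2],
              [-0.5, 0.7]])
x = np.array([1.5, -2.0])
S = softmax(W.dot(x))

eps = 1e-6
for t in range(3):          # output index
    for i in range(3):      # row of W
        for j in range(2):  # column of W
            # The closed form: S_t * (delta_ti - S_i) * x_j.
            closed = S[t] * ((1.0 if t == i else 0.0) - S[i]) * x[j]
            Wp = W.copy(); Wp[i, j] += eps
            Wm = W.copy(); Wm[i, j] -= eps
            numeric = (softmax(Wp.dot(x))[t] - softmax(Wm.dot(x))[t]) / (2 * eps)
            assert abs(closed - numeric) < 1e-7
```

<p>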
The reason we can avoid most computation is that the Jacobian of the fully-connected layer is <em>sparse</em>.</p> <p>That said, I still felt it's important to show how this derivative comes to life from first principles based on the composition of Jacobians for the functions involved. The advantage of this approach is that it works exactly the same for more complex compositions of functions, where the &quot;closed form&quot; of the derivative for each element is much harder to compute otherwise.</p> </div> <div class="section" id="softmax-and-cross-entropy-loss"> <h2>Softmax and cross-entropy loss</h2> <p>We've just seen how the softmax function is used as part of a machine learning network, and how to compute its derivative using the multivariate chain rule. While we're at it, it's worth taking a look at a loss function that's commonly used along with softmax for training a network: cross-entropy.</p> <p><a class="reference external" href="https://en.wikipedia.org/wiki/Cross_entropy">Cross-entropy</a> has an interesting probabilistic and information-theoretic interpretation, but here I'll just focus on the mechanics. For two discrete probability distributions <em>p</em> and <em>q</em>, the cross-entropy function is defined as:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/b26f68a12667ba254facf9815252f52ebf2238d9.svg" style="height: 38px;" type="image/svg+xml"> $xent(p,q)=-\sum_{k}p(k)log(q(k))$</object> <p>Where <em>k</em> goes over all the possible values of the random variable the distributions are defined for. Specifically, in our case there are <em>T</em> output classes, so <em>k</em> would go from 1 to <em>T</em>.</p> <p>We start from the softmax output <em>P</em> - this is one probability distribution <a class="footnote-reference" href="#id4" id="id2"></a>. The other probability distribution is the &quot;correct&quot; classification output, usually denoted by <em>Y</em>.
This is a one-hot encoded vector of size <em>T</em>, where all elements except one are 0.0, and one element is 1.0 - this element marks the correct class for the data being classified. Let's rephrase the cross-entropy loss formula for our domain:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/b02b400caa1de3f720f3c51b4891204a85a0d482.svg" style="height: 54px;" type="image/svg+xml"> $xent(Y, P)=-\sum_{k=1}^{T}Y(k)log(P(k))$</object> <p><em>k</em> goes over all the output classes. <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/1801d6549d7f256091d8d687062875facf870a80.svg" style="height: 18px;" type="image/svg+xml">P(k)</object> is the probability of the class as predicted by the model. <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/369b88be91e9aecb20f084f95946d171096ec2ad.svg" style="height: 18px;" type="image/svg+xml">Y(k)</object> is the &quot;true&quot; probability of the class as provided by the data. Let's mark the sole index where <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/bf2a1a90dbf5ee8f3e1240a2aff2b64220f3e876.svg" style="height: 18px;" type="image/svg+xml">Y(k)=1.0</object> by <em>y</em>. Since for all <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/e0e4ad3507e9dde8cc37658b436305ef9eb14ca0.svg" style="height: 17px;" type="image/svg+xml">k\ne y</object> we have <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/9d1a77958eb2fd853cb41001e41efcfa46a099d3.svg" style="height: 18px;" type="image/svg+xml">Y(k)=0</object>, the cross-entropy formula can be simplified to:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/ca79e575abc3ff07571f9b7bd9ee477c4cac1b7a.svg" style="height: 18px;" type="image/svg+xml"> $xent(Y, P)=-log(P(y))$</object> <p>Actually, let's make it a function of just <em>P</em>, treating <em>y</em> as a constant. 
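In code, the simplification is a one-liner; the probabilities and class index below are made up for illustration:

```python
import numpy as np

P = np.array([0.1, 0.7, 0.2])  # softmax output: a probability distribution
y = 1                          # index of the correct class
Y = np.zeros_like(P)
Y[y] = 1.0                     # one-hot "true" distribution

full = -np.sum(Y * np.log(P))  # -sum_k Y(k) * log(P(k))
simplified = -np.log(P[y])     # -log(P_y): all other terms vanish
assert np.isclose(full, simplified)
```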
Moreover, since in our case <em>P</em> is a vector, we can express <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/033e08901a43a52bb55ac6d36bcb0cebb8781a4e.svg" style="height: 18px;" type="image/svg+xml">P(y)</object> as the <em>y</em>-th element of <em>P</em>, or <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/12b5ad2733328bc7191f23d13e05c4e246bb8e26.svg" style="height: 18px;" type="image/svg+xml">P_y</object>:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/e659a9fdd830a347c3aae214b31013eb52c59dc7.svg" style="height: 19px;" type="image/svg+xml"> $xent(P)=-log(P_y)$</object> <p>The Jacobian of <em>xent</em> is a <em>1xT</em> matrix (a row vector), since the output is a scalar and we have <em>T</em> inputs (the vector <em>P</em> has <em>T</em> elements):</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/2e515cd8235b0385a95e5cbfff5fbcca9a78c631.svg" style="height: 22px;" type="image/svg+xml"> $Dxent=\begin{bmatrix} D_1xent &amp; D_2xent &amp; \cdots &amp; D_Txent \end{bmatrix}$</object> <p>Now recall that <em>P</em> can be expressed as a function of input weights: <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/ad179bfd313d392ad156b509370b8f407e7bd20a.svg" style="height: 18px;" type="image/svg+xml">P(W)=S(g(W))</object>. So we have another function composition:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/ab2f487f02c386d5f532900ffd0927c28ed23b7c.svg" style="height: 18px;" type="image/svg+xml"> $xent(W)=(xent\circ P)(W)=xent(P(W))$</object> <p>And we can, once again, use the multivariate chain rule to find the gradient of <em>xent</em> w.r.t. 
<em>W</em>:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/ef938751c387283a7be6461ab0c244ac09db85be.svg" style="height: 18px;" type="image/svg+xml"> $Dxent(W)=D(xent\circ P)(W)=Dxent(P(W))\cdot DP(W)$</object> <p>Let's check that the dimensions of the Jacobian matrices work out. We already computed <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/3f90f5becd4cc377e50cd6885718feb039eabcc9.svg" style="height: 18px;" type="image/svg+xml">DP(W)</object>; it's <em>TxNT</em>. <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/676107d1b425649d04d82c75a37b391aa99edcf1.svg" style="height: 18px;" type="image/svg+xml">Dxent(P(W))</object> is <em>1xT</em>, so the resulting Jacobian <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/86f6a5ad8eb3128d2d86c826df3d8831403e64ac.svg" style="height: 18px;" type="image/svg+xml">Dxent(W)</object> is <em>1xNT</em>, which makes sense because the whole network has one output (the cross-entropy loss - a scalar value) and <em>NT</em> inputs (the weights).</p> <p>Here again, there's a straightforward way to find a simple formula for <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/86f6a5ad8eb3128d2d86c826df3d8831403e64ac.svg" style="height: 18px;" type="image/svg+xml">Dxent(W)</object>, since many elements in the matrix multiplication end up cancelling out. Note that <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/bb805dc98dfe8b48ded94e4f27a90e74b64371e4.svg" style="height: 18px;" type="image/svg+xml">xent(P)</object> depends only on the <em>y</em>-th element of <em>P</em>. 
Therefore, only <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/fb396ced0aaf5ee006e13bb7b0925ba833e01a12.svg" style="height: 18px;" type="image/svg+xml">D_{y}xent</object> is non-zero in the Jacobian:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/c2b2a7aa200023fd2988991212edc5053a85731e.svg" style="height: 22px;" type="image/svg+xml"> $Dxent=\begin{bmatrix} 0 &amp; 0 &amp; D_{y}xent &amp; \cdots &amp; 0 \end{bmatrix}$</object> <p>And <object class="valign-m10" data="https://eli.thegreenplace.net/images/math/ba6bd8869680cb3dab4a5138b909d4f4155ae6a8.svg" style="height: 26px;" type="image/svg+xml">D_{y}xent=-\frac{1}{P_y}</object>. Going back to the full Jacobian <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/86f6a5ad8eb3128d2d86c826df3d8831403e64ac.svg" style="height: 18px;" type="image/svg+xml">Dxent(W)</object>, we multiply <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/3845e2788792dc92a7072833fa019ce1182f4dbc.svg" style="height: 18px;" type="image/svg+xml">Dxent(P)</object> by each column of <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/31bc3dde97a870d7b85f78efe4d178d38eae0fdb.svg" style="height: 18px;" type="image/svg+xml">D(P(W))</object> to get each element in the resulting row-vector. Recall that the row vector represents the whole weight matrix <em>W</em> &quot;linearized&quot; in row-major order. 
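Before doing this multiplication symbolically, here's a numerical sketch of it (the sizes, seed, and correct-class index are invented for the example): build the sparse <em>1xT</em> row vector Dxent(P), multiply it by the <em>TxNT</em> Jacobian DP(W), and compare one entry of the result against a finite difference.

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax.
    e = np.exp(z - np.max(z))
    return e / e.sum()

T, N = 3, 4
rng = np.random.default_rng(42)
W = rng.standard_normal((T, N))
x = rng.standard_normal(N)
y = 2  # index of the correct class

S = softmax(W @ x)

# Dxent(P): 1xT row vector, nonzero only at position y (value -1/P_y).
Dxent_P = np.zeros((1, T))
Dxent_P[0, y] = -1.0 / S[y]

# DP(W): TxNT Jacobian with D_ij P_t = S_t * (delta(t, i) - S_i) * x_j,
# weights linearized in row-major order.
DP = np.zeros((T, T * N))
for t in range(T):
    for i in range(T):
        DP[t, i * N : (i + 1) * N] = S[t] * ((t == i) - S[i]) * x

# Chain rule: Dxent(W) = Dxent(P) . DP(W), a 1xNT row; reshape like W.
Dxent_W = (Dxent_P @ DP).reshape(T, N)

# Finite-difference check of one weight's effect on -log(P_y).
eps = 1e-6
i, j = 0, 3
Wp, Wm = W.copy(), W.copy()
Wp[i, j] += eps
Wm[i, j] -= eps
numeric = (-np.log(softmax(Wp @ x)[y]) + np.log(softmax(Wm @ x)[y])) / (2 * eps)
assert abs(Dxent_W[i, j] - numeric) < 1e-6
```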
We'll index into it with <em>i</em> and <em>j</em> for clarity (<object class="valign-m6" data="https://eli.thegreenplace.net/images/math/d82e04a1bce5f5f685c8b6ac356997c847fa95a5.svg" style="height: 18px;" type="image/svg+xml">D_{ij}</object> points to element number <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/ef7b2d987af3c0ceb75381d096c35e8c19085642.svg" style="height: 18px;" type="image/svg+xml">(i-1)N+j</object> in the row vector):</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/b61c7d91efebf65b53f3dada643d86b63d06b6b5.svg" style="height: 54px;" type="image/svg+xml"> $D_{ij}xent(W)=\sum_{k=1}^{T}D_{k}xent(P)\cdot D_{ij}P_k(W)$</object> <p>Since only the <em>y</em>-th element in <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/2b703a6ad534070bbe698f8d8a3a1261b5bb4549.svg" style="height: 18px;" type="image/svg+xml">D_{k}xent(P)</object> is non-zero, we get the following, also substituting the derivative of the softmax layer from earlier in the post:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/d7823846aecfb3673906d65e8da6b290b7b2f608.svg" style="height: 68px;" type="image/svg+xml"> \begin{align*} D_{ij}xent(W)&amp;=D_{y}xent(P)\cdot D_{ij}P_y(W) \\ &amp;=-\frac{1}{P_y}\cdot S_y(\delta_{yi}-S_i)x_j \end{align*}</object> <p>By our definition, <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/2ec0ba51607b94096ad077ab55cc181698494e1a.svg" style="height: 18px;" type="image/svg+xml">P_y=S_y</object>, so we get:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/e417398a544821300668c777d55ad489934d744c.svg" style="height: 96px;" type="image/svg+xml"> \begin{align*} D_{ij}xent(W)&amp;=-\frac{1}{S_y}\cdot S_y(\delta_{yi}-S_i)x_j \\ &amp;=-(\delta_{yi}-S_i)x_j \\ &amp;=(S_i-\delta_{yi})x_j \end{align*}</object> <p>Once again, even though in this case the end result is nice and clean, it didn't 
necessarily have to be so. The formula for <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/b0cfb602e63642cc6146ca57731821d6a9866a1e.svg" style="height: 20px;" type="image/svg+xml">D_{ij}xent(W)</object> could end up being a fairly involved sum (or sum of sums). The technique of multiplying Jacobian matrices is oblivious to all this, as the computer can do all the sums for us. All we have to do is compute the individual Jacobians, which is usually easier because they are for simpler, non-composed functions. This is the beauty and utility of the multivariate chain rule.</p> <hr class="docutils" /> <table class="docutils footnote" frame="void" id="id3" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id1"></a></td><td>To play more with sample inputs and Softmax outputs, Michael Nielsen's online book has a <a class="reference external" href="http://neuralnetworksanddeeplearning.com/chap3.html#softmax">nice interactive Javascript visualization</a> - check it out.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id4" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id2"></a></td><td>Take a moment to recall that, by definition, the output of the softmax function is indeed a valid discrete probability distribution.</td></tr> </tbody> </table> </div> The Chain Rule of Calculus2016-10-10T06:24:00-07:002016-10-10T06:24:00-07:00Eli Benderskytag:eli.thegreenplace.net,2016-10-10:/2016/the-chain-rule-of-calculus/<p>The chain rule of derivatives is, in my opinion, the most important formula in differential calculus.
In this post I want to explain how the chain rule works for single-variable and multivariate functions, with some interesting examples along the way.</p> <div class="section" id="preliminaries-composition-of-functions-and-differentiability"> <h2>Preliminaries: composition of functions and differentiability</h2> <p>We denote a function …</p></div><p>The chain rule of derivatives is, in my opinion, the most important formula in differential calculus. In this post I want to explain how the chain rule works for single-variable and multivariate functions, with some interesting examples along the way.</p> <div class="section" id="preliminaries-composition-of-functions-and-differentiability"> <h2>Preliminaries: composition of functions and differentiability</h2> <p>We denote a function <em>f</em> that maps from the domain <em>X</em> to the codomain <em>Y</em> as <img alt="f:X \rightarrow Y" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e2f7fcdddf5b36735350a805eeb7cae36895ab1e.png" style="height: 16px;" />. With this <em>f</em> and given <img alt="g:Y \rightarrow Z" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e664f090a7bb62573ae65c910ef7c81e5f086cf6.png" style="height: 16px;" />, we can define <img alt="g \circ f:X \rightarrow Z" class="valign-m4" src="https://eli.thegreenplace.net/images/math/7f049e7749d289236edefaeb6399795a11afeb44.png" style="height: 16px;" /> as the composition of <em>g</em> and <em>f</em>. It's defined for <img alt="\forall x \in X" class="valign-m1" src="https://eli.thegreenplace.net/images/math/76545a2a780098fe8c8d581192fa77deccae0848.png" style="height: 14px;" /> as:</p> <img alt="$(g \circ f)(x)=g(f(x))$" class="align-center" src="https://eli.thegreenplace.net/images/math/8b9c8e67c9d2ec7fd3eefce043f380512f1230d3.png" style="height: 18px;" /> <p>In calculus we are usually concerned with the real number domain of some dimensionality. 
In the single-variable case, we can think of <img alt="f" class="valign-m4" src="https://eli.thegreenplace.net/images/math/4a0a19218e082a343a1b17e5333409af9d98f0f5.png" style="height: 16px;" /> and <img alt="g" class="valign-m4" src="https://eli.thegreenplace.net/images/math/54fd1711209fb1c0781092374132c66e79e2241b.png" style="height: 12px;" /> as two regular real-valued functions: <img alt="f:\mathbb{R} \rightarrow \mathbb{R}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/62ac71ec4fa066b12854a09cddef9ba062924d68.png" style="height: 16px;" /> and <img alt="g:\mathbb{R} \rightarrow \mathbb{R}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/974c68b8e4454d31c7a2eb389c94bbbfd11ac9da.png" style="height: 16px;" />.</p> <p>As an example, say <img alt="f(x)=x+1" class="valign-m4" src="https://eli.thegreenplace.net/images/math/027c36c348c172740dd168c66fbfe75d8a8da0c3.png" style="height: 18px;" /> and <img alt="g(x)=x^2" class="valign-m4" src="https://eli.thegreenplace.net/images/math/9b74ba3074b06d93dacb65e40b0082897aa85b3d.png" style="height: 19px;" />. Then:</p> <img alt="$(g \circ f)(x)=g(f(x))=g(x+1)=(x+1)^2$" class="align-center" src="https://eli.thegreenplace.net/images/math/f80635cd447f9f82452529c9289d16811394ea6c.png" style="height: 21px;" /> <p>We can compose the functions the other way around as well:</p> <img alt="$(f \circ g)(x)=f(g(x))=f(x^2)=x^2+1$" class="align-center" src="https://eli.thegreenplace.net/images/math/13c07f9e990c72b1edaf651fccec5c4ad7c0f155.png" style="height: 21px;" /> <p>Obviously, we shouldn't expect composition to be commutative. It is, however, associative. 
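In Python, composition and its non-commutativity are easy to demonstrate with these same two functions:

```python
def f(x):
    return x + 1

def g(x):
    return x * x

def compose(outer, inner):
    # Returns the function outer ∘ inner.
    return lambda x: outer(inner(x))

g_after_f = compose(g, f)  # (g ∘ f)(x) = (x + 1)^2
f_after_g = compose(f, g)  # (f ∘ g)(x) = x^2 + 1

assert g_after_f(3) == 16
assert f_after_g(3) == 10  # composition is not commutative
```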
<img alt="h \circ (g \circ f)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ef6897e4aad0050d8f69248de3ecd8aaa3ad51de.png" style="height: 18px;" /> and <img alt="(h \circ g) \circ f" class="valign-m4" src="https://eli.thegreenplace.net/images/math/03ac0c8bb4a409ff1ec1badfee9693280bb2f241.png" style="height: 18px;" /> are equivalent, and both end up being <img alt="h(g(f(x)))" class="valign-m4" src="https://eli.thegreenplace.net/images/math/bc1a23574da8c77a4fc40d5cbbad2c5e1e95da86.png" style="height: 18px;" /> for <img alt="\forall x \in X" class="valign-m1" src="https://eli.thegreenplace.net/images/math/76545a2a780098fe8c8d581192fa77deccae0848.png" style="height: 14px;" />.</p> <p>To better handle compositions in one's head it sometimes helps to denote the independent variable of the outer function (<em>g</em> in our case) by a different letter (such as <img alt="g(a)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e7373233d49e18a0882e0dce41d9d6aa26964d6b.png" style="height: 18px;" />). For simple cases it doesn't matter, but I'll be using this technique occasionally throughout the article. The important thing to remember here is that the name of the independent variable is completely arbitrary, and we should always be able to replace it by another name throughout the formula without any semantic change.</p> <p>The other preliminary I want to mention is <em>differentiability</em>. 
The function <em>f</em> is differentiable at some point <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" /> if the following limit exists:</p> <img alt="$\lim_{h \to 0}\frac{f(x_0+h)-f(x_0)}{h}$" class="align-center" src="https://eli.thegreenplace.net/images/math/34b3ce83a20775cf99b8d204d2b845dfde5727cc.png" style="height: 39px;" /> <p>This limit is then the derivative of <em>f</em> at the point <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" />, or <img alt="{f}&amp;#x27;(x_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e9b2a1134fcdc276843ee4b522039359117026ee.png" style="height: 18px;" />. Another way to express this is <img alt="\frac{d}{dx}f(x_0)" class="valign-m6" src="https://eli.thegreenplace.net/images/math/b0d6f765abf215972d5dbb982f77f1a83c233066.png" style="height: 22px;" />. Note that <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" /> can be any arbitrary point on the real line. I sometimes say something like &quot;<em>f</em> is differentiable at <img alt="g(x_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/9d8c0deeca951fab05e474395fbb9fab226cf1f2.png" style="height: 18px;" />&quot;. 
Here too, <img alt="g(x_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/9d8c0deeca951fab05e474395fbb9fab226cf1f2.png" style="height: 18px;" /> is just a real value that happens to be the value of the function <em>g</em> at <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" />.</p> </div> <div class="section" id="the-single-variable-chain-rule"> <h2>The single-variable chain rule</h2> <p>The chain rule for single-variable functions states: if <em>g</em> is differentiable at <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" /> and <em>f</em> is differentiable at <img alt="g(x_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/9d8c0deeca951fab05e474395fbb9fab226cf1f2.png" style="height: 18px;" />, then <img alt="f \circ g" class="valign-m4" src="https://eli.thegreenplace.net/images/math/1247a6ac0bc07bfdbd790831aa70b0b000bad2e4.png" style="height: 16px;" /> is differentiable at <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" /> and its derivative is:</p> <img alt="$(f \circ g)&amp;#x27;(x_0)={f}&amp;#x27;(g(x_0)){g}&amp;#x27;(x_0)$" class="align-center" src="https://eli.thegreenplace.net/images/math/77fb8b77b35d687c20379179b0178ebdd9b2cee1.png" style="height: 20px;" /> <p>The proof of the chain rule is a bit tricky - I left it for the appendix. However, we can get a better feel for it using some intuition and a couple of examples.</p> <p>First, the intuition.
By definition:</p> <img alt="${g}&amp;#x27;(x_0)=\lim_{h \to 0}\frac{g(x_0+h)-g(x_0)}{h}$" class="align-center" src="https://eli.thegreenplace.net/images/math/cdc3e4a3bced3a7527a15cd76a688d5cc1c06aab.png" style="height: 39px;" /> <p>Multiplying both sides by <em>h</em> we get <a class="footnote-reference" href="#id6" id="id1"></a>:</p> <img alt="${g}&amp;#x27;(x_0)h=\lim_{h \to 0}g(x_0+h)-g(x_0)$" class="align-center" src="https://eli.thegreenplace.net/images/math/daf52cabed3806986d4c8c29dd60e4ce4fa9247d.png" style="height: 29px;" /> <p>Therefore we can say that when <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" /> changes by some very small amount, <img alt="g(x_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/9d8c0deeca951fab05e474395fbb9fab226cf1f2.png" style="height: 18px;" /> changes by <img alt="{g}&amp;#x27;(x_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/fbb4d7279a750f6d80eebeff2e2c25765b304f16.png" style="height: 18px;" /> times that small amount.</p> <p>Similarly <img alt="{f}&amp;#x27;(a_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/d04139c8f65536c3042f975a6966ed49f5f15832.png" style="height: 18px;" /> is the amount of change in the value of <em>f</em> for some very small change from <img alt="a_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/4a5997da73aadd118038761e69d01e24586bf958.png" style="height: 11px;" />. 
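Numerically, this "small shift" intuition reads: g(x0 + h) - g(x0) is approximately g'(x0) * h for small h. A tiny sketch, with an arbitrary example function and step size:

```python
import math

g, g_prime = math.sin, math.cos  # example function and its derivative
x0, h = 1.0, 1e-4                # a point and a small shift

actual_change = g(x0 + h) - g(x0)
predicted_change = g_prime(x0) * h
# The two agree up to a second-order (h^2) error term.
assert abs(actual_change - predicted_change) < 1e-7
```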
However, since in our case we compose <img alt="f \circ g" class="valign-m4" src="https://eli.thegreenplace.net/images/math/1247a6ac0bc07bfdbd790831aa70b0b000bad2e4.png" style="height: 16px;" />, we can say that <img alt="a_0=g(x_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e198d0bc24284bd638c564e0b46edf975d5831d4.png" style="height: 18px;" />, evaluating <img alt="f(g(x_0))" class="valign-m4" src="https://eli.thegreenplace.net/images/math/04a8bf9b7bd565f95f2cb3e0fe6de123b247e3be.png" style="height: 18px;" />. Suppose we shift <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" /> by a small amount <em>h</em>. This causes <img alt="g(x_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/9d8c0deeca951fab05e474395fbb9fab226cf1f2.png" style="height: 18px;" /> to shift by <img alt="{g}&amp;#x27;(x_0)h" class="valign-m4" src="https://eli.thegreenplace.net/images/math/0e1e11ca765684cf07722c40de2bd86b208ca7c1.png" style="height: 18px;" />. So the input <img alt="a_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/4a5997da73aadd118038761e69d01e24586bf958.png" style="height: 11px;" /> of <em>f</em> shifted by <img alt="{g}&amp;#x27;(x_0)h" class="valign-m4" src="https://eli.thegreenplace.net/images/math/0e1e11ca765684cf07722c40de2bd86b208ca7c1.png" style="height: 18px;" /> - this is still a small amount! Therefore, the total change in the value of <em>f</em> should be <img alt="{f}&amp;#x27;(g(x_0)){g}&amp;#x27;(x_0)h" class="valign-m4" src="https://eli.thegreenplace.net/images/math/b761eb11c7502754575d0413e7ba040f4a106d0d.png" style="height: 18px;" /> <a class="footnote-reference" href="#id7" id="id2"></a>.</p> <p>Now, a couple of simple examples. 
Let's take the function <img alt="f(x)=(x+1)^2" class="valign-m4" src="https://eli.thegreenplace.net/images/math/8db433d3f263ad489e31931ef4a3ddccbd7bece0.png" style="height: 19px;" />. The idea is to think of this function as a composition of simpler functions. In this case, one option is: <img alt="g(x)=x+1" class="valign-m4" src="https://eli.thegreenplace.net/images/math/8b2ec3a2221203b211c8a0ed975841cb508b193c.png" style="height: 18px;" /> and then <img alt="w(g(x))=g(x)^2" class="valign-m4" src="https://eli.thegreenplace.net/images/math/2ed44118a1efadf34f5bf169d2ca450246519d1d.png" style="height: 19px;" />, so the original <em>f</em> is now the composition <img alt="w \circ g" class="valign-m4" src="https://eli.thegreenplace.net/images/math/4edc28332d30c68727a56fbd473126441850c4f0.png" style="height: 12px;" />.</p> <p>The derivative of this composition is <img alt="{w}&amp;#x27;(g(x)){g}&amp;#x27;(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/05261e48f79f6e8b129bb26dee7fa8a07bcbf876.png" style="height: 18px;" />, or <img alt="2(x+1)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/39b598cf32e125c7ae18b7623043d5f8133eba78.png" style="height: 18px;" /> since <img alt="{g}&amp;#x27;(x)=1" class="valign-m4" src="https://eli.thegreenplace.net/images/math/36cc7eeced1b708dcf6166dcaae955f733f93ded.png" style="height: 18px;" />. Note that <em>w</em> is differentiable at any point, so this derivative always exists.</p> <p>Another example will use a longer chain of composition. Let's differentiate <img alt="f(x)=sin[(x+1)^2]" class="valign-m5" src="https://eli.thegreenplace.net/images/math/3e3a23e0dd5d4ee105bcca545bddb058917e2c9c.png" style="height: 20px;" />. 
This is a composition of three functions:</p> <img alt="\begin{align*} g(x)&amp;amp;=x+1\\ w(x)&amp;amp;=x^2\\ v(x)&amp;amp;=sin(x) \end{align*}" class="align-center" src="https://eli.thegreenplace.net/images/math/6981c04536025d8e43d07bf9b067252c2028feab.png" style="height: 73px;" /> <p>Function composition is associative, so <em>f</em> can be expressed as either <img alt="v \circ (w \circ g)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/1c2a8a63ec4fb6e489b0896b544e277823228906.png" style="height: 18px;" /> or <img alt="(v \circ w) \circ g" class="valign-m4" src="https://eli.thegreenplace.net/images/math/be47fface92aa5db8bade3049da31d065ef8244b.png" style="height: 18px;" />. Since we already know what the derivative of <img alt="w \circ g" class="valign-m4" src="https://eli.thegreenplace.net/images/math/4edc28332d30c68727a56fbd473126441850c4f0.png" style="height: 12px;" /> is, let's use the former:</p> <img alt="\begin{align*} \frac{df(x)}{dx}=\frac{d v(w(g(x)))}{dx}&amp;amp;={v}&amp;#x27;(w(g(x))){w(g(x))}&amp;#x27;(x)\\ &amp;amp;=cos(w(g(x)))2(x+1)\\ &amp;amp;=2cos[(x+1)^2](x+1) \end{align*}" class="align-center" src="https://eli.thegreenplace.net/images/math/f63f9a07295583911873238c3ee6e84e8c3722ca.png" style="height: 93px;" /> </div> <div class="section" id="the-chain-rule-as-a-computational-procedure"> <h2>The chain rule as a computational procedure</h2> <p>As the last example demonstrates, the chain rule can be applied multiple times in a single derivation. This makes the chain rule a powerful tool for computing derivatives of very complex functions, which can be broken up into compositions of simpler functions. I like to draw a parallel between this process and programming; a function in a programming language can be seen as a computational procedure - we have a set of input parameters and we produce outputs. On the way, several transformations happen that can be expressed mathematically. 
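Both closed-form derivatives computed in the previous section can be checked against central differences; the test point and tolerance below are arbitrary:

```python
import math

def numeric_deriv(fn, x, h=1e-6):
    # Central-difference approximation of the derivative at x.
    return (fn(x + h) - fn(x - h)) / (2 * h)

f1 = lambda x: (x + 1) ** 2             # derivative: 2(x + 1)
f2 = lambda x: math.sin((x + 1) ** 2)   # derivative: 2cos[(x + 1)^2](x + 1)

x0 = 0.7
assert abs(numeric_deriv(f1, x0) - 2 * (x0 + 1)) < 1e-6
assert abs(numeric_deriv(f2, x0) - 2 * math.cos((x0 + 1) ** 2) * (x0 + 1)) < 1e-6
```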
These transformations are composed, so their derivatives can be computed naturally with the chain rule.</p> <p>This may be somewhat abstract, so let's use another example. We'll compute the derivative of the Sigmoid function - a very important function in machine learning:</p> <img alt="$S(x)=\frac{1}{1+e^{-x}}$" class="align-center" src="https://eli.thegreenplace.net/images/math/9a39d0495ce32da5840b76adaf508a0349394c49.png" style="height: 38px;" /> <p>To make the equivalence between functions and computational procedures clearer, let's think how we'd compute <em>S</em> in Python:</p> <div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">sigmoid</span><span class="p">(</span><span class="n">x</span><span class="p">):</span> <span class="k">return</span> <span class="mi">1</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">math</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="n">x</span><span class="p">))</span> </pre></div> <p>This doesn't look much different, but that's just because Python is a high level language with arbitrarily nested expressions. Its VM (or the CPU in general) would execute this computation step by step. 
Let's break it up to be clearer, assuming we can only apply a single operation at every step:</p> <div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">sigmoid</span><span class="p">(</span><span class="n">x</span><span class="p">):</span> <span class="n">f</span> <span class="o">=</span> <span class="o">-</span><span class="n">x</span> <span class="n">g</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">f</span><span class="p">)</span> <span class="n">w</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">+</span> <span class="n">g</span> <span class="n">v</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">w</span> <span class="k">return</span> <span class="n">v</span> </pre></div> <p>I hope you're starting to see the resemblance to our chain rule examples at this point. Sacrificing some rigor in the notation for the sake of expressiveness, we can write:</p> <img alt="$S&amp;#x27;=v&amp;#x27;(w)w&amp;#x27;(g)g&amp;#x27;(f)f&amp;#x27;(x)$" class="align-center" src="https://eli.thegreenplace.net/images/math/b3029d842b915e7bf0ea1aa91372ab071dd8b80e.png" style="height: 20px;" /> <p>This is the chain rule applied to <img alt="v \circ (w \circ (g \circ f))" class="valign-m4" src="https://eli.thegreenplace.net/images/math/caaf8ea9ee60bb84d61d422c6dee5d6cd173f0ab.png" style="height: 18px;" />. 
Solving this is easy because every single derivative in the chain above is trivial:</p> <img alt="\begin{align*} S&amp;#x27;&amp;amp;=v&amp;#x27;(w)w&amp;#x27;(g)g&amp;#x27;(f)(-1)\\ &amp;amp;=v&amp;#x27;(w)w&amp;#x27;(g)e^{-x}(-1)\\ &amp;amp;=v&amp;#x27;(w)(1)e^{-x}(-1)\\ &amp;amp;=\frac{-1}{(1+e^{-x})^2}e^{-x}(-1)\\ &amp;amp;=\frac{e^{-x}}{(1+e^{-x})^2} \end{align*}" class="align-center" src="https://eli.thegreenplace.net/images/math/b987461a2551ca622908f40f791519f3afe3b452.png" style="height: 171px;" /> <p>Now you may be thinking:</p> <ol class="arabic simple"> <li>Every function computable by a program can be broken down to trivial steps like our <tt class="docutils literal">sigmoid</tt> above.</li> <li>Using the chain rule, we can easily find the derivative of such a sequence of steps... therefore:</li> <li>We can easily find the derivative of any function computable by a program!!</li> </ol> <p>And you'll be right. This is precisely the basis for the technique known as <a class="reference external" href="https://en.wikipedia.org/wiki/Automatic_differentiation">automatic differentiation</a>, which is widely used in scientific computing. The most notable use of automatic differentiation in recent times is the backpropagation algorithm - an essential backbone of modern machine learning. I personally find automatic differentiation fascinating, and will write a dedicated article about it in the future.</p> </div> <div class="section" id="multivariate-chain-rule-general-formulation"> <h2>Multivariate chain rule - general formulation</h2> <p>So far this article has been looking at functions with a single input and output: <img alt="f:\mathbb{R} \to \mathbb{R}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/2e28467c90f978580e43c376716981ec5906a01d.png" style="height: 16px;" />.
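As an aside, the sigmoid derivative computed in the previous section is easy to verify numerically; it also equals the form S(x)(1 - S(x)) that is common in machine-learning texts (a fact not shown above, but a one-line algebraic rearrangement):

```python
import math

def sigmoid(x):
    return 1 / (1 + math.exp(-x))

def dsigmoid(x):
    # e^{-x} / (1 + e^{-x})^2, as derived via the chain rule.
    return math.exp(-x) / (1 + math.exp(-x)) ** 2

x0, h = 0.5, 1e-6
numeric = (sigmoid(x0 + h) - sigmoid(x0 - h)) / (2 * h)
assert abs(dsigmoid(x0) - numeric) < 1e-8
# Equivalent, more common form of the same derivative:
assert abs(dsigmoid(x0) - sigmoid(x0) * (1 - sigmoid(x0))) < 1e-12
```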
In the most general case of multi-variate calculus, we're dealing with functions that map from <em>n</em> dimensions to <em>m</em> dimensions: <img alt="f:\mathbb{R}^{n} \to \mathbb{R}^{m}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/13f219789047343729036279bb11630db317d98d.png" style="height: 16px;" />. Because every one of the <em>m</em> outputs of <em>f</em> can be considered a separate function dependent on <em>n</em> variables, it's very natural to deal with such math using vectors and matrices.</p> <p>First let's define some notation. We'll consider the outputs of <em>f</em> to be numbered from 1 to <em>m</em> as <img alt="f_1,f_2 \dots f_m" class="valign-m4" src="https://eli.thegreenplace.net/images/math/93b446c5209263534d09d617bbede21101d6536e.png" style="height: 16px;" />. For each such <img alt="f_i" class="valign-m4" src="https://eli.thegreenplace.net/images/math/68bd0dc647944d362ec8df628a22967b91d82c80.png" style="height: 16px;" /> we can compute its partial derivative by any of the <em>n</em> inputs as:</p> <img alt="$D_j f_i(a)=\frac{\partial f_i}{\partial a_j}(a)$" class="align-center" src="https://eli.thegreenplace.net/images/math/30881b5a92e45259714ba01c7a12fbf8f6c56109.png" style="height: 42px;" /> <p>Where <em>j</em> goes from 1 to <em>n</em> and <em>a</em> is a vector with <em>n</em> components. 
If <em>f</em> is differentiable at <em>a</em> <a class="footnote-reference" href="#id8" id="id3"></a> then the derivative of <em>f</em> at <em>a</em> is the <em>Jacobian matrix</em>:</p> <img alt="$Df(a)=\begin{bmatrix} D_1 f_1(a) &amp;amp; \cdots &amp;amp; D_n f_1(a) \\ \vdots &amp;amp; &amp;amp; \vdots \\ D_1 f_m(a) &amp;amp; \cdots &amp;amp; D_n f_m(a) \\ \end{bmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/ab09367d48e9ef4d8bc2314a60313dec700193af.png" style="height: 76px;" /> <p>The multivariate chain rule states: given <img alt="g:\mathbb{R}^n \to \mathbb{R}^m" class="valign-m4" src="https://eli.thegreenplace.net/images/math/b4b7d25491897b053abf7e48688fada4a85368bd.png" style="height: 16px;" /> and <img alt="f:\mathbb{R}^m \to \mathbb{R}^p" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ac8a6cea4e02e885538fc3ef969c5733e84712f9.png" style="height: 16px;" /> and a point <img alt="a \in \mathbb{R}^n" class="valign-m1" src="https://eli.thegreenplace.net/images/math/43a85f2c59f396fe5c4e2c403a0453c463fcfb0d.png" style="height: 13px;" />, if <em>g</em> is differentiable at <em>a</em> and <em>f</em> is differentiable at <img alt="g(a)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e7373233d49e18a0882e0dce41d9d6aa26964d6b.png" style="height: 18px;" /> then the composition <img alt="f \circ g" class="valign-m4" src="https://eli.thegreenplace.net/images/math/1247a6ac0bc07bfdbd790831aa70b0b000bad2e4.png" style="height: 16px;" /> is differentiable at <em>a</em> and its derivative is:</p> <img alt="$D(f \circ g)(a)=Df(g(a)) \cdot Dg(a)$" class="align-center" src="https://eli.thegreenplace.net/images/math/00bdefa904bd34df2dfb50cc385e6497c4e5096e.png" style="height: 18px;" /> <p>Which is the matrix multiplication of <img alt="Df(g(a))" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e567730c48bb2f95c258b630b4d6e997043e09ab.png" style="height: 18px;" /> and <img alt="Dg(a)" class="valign-m4" 
src="https://eli.thegreenplace.net/images/math/2575fc98e794a733a7aa6237fe67246a41e6c8c5.png" style="height: 18px;" /> <a class="footnote-reference" href="#id9" id="id4"></a>. Intuitively, the multivariate chain rule mirrors the single-variable one (and as we'll soon see, the latter is just a special case of the former) with derivatives replaced by derivative matrices. From linear algebra, we represent linear transformations by matrices, and the composition of two linear transformations is the product of their matrices. Therefore, since derivative matrices - like derivatives in one dimension - are a linear approximation to the function, the chain rule makes sense. This is a really nice connection between linear algebra and calculus, though a full proof of the multivariate rule is very technical and outside the scope of this article.</p> </div> <div class="section" id="multivariate-chain-rule-examples"> <h2>Multivariate chain rule - examples</h2> <p>Since the chain rule deals with compositions of functions, it's natural to present examples from the world of parametric curves and surfaces. For example, suppose we define <img alt="f(x,y,z)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/c5d72ae6186c76bde08c693d4bfdb85e3201125d.png" style="height: 18px;" /> as a scalar function <img alt="\mathbb{R}^3 \to \mathbb{R}" class="valign-m1" src="https://eli.thegreenplace.net/images/math/1862a20e93e78e42aafd20106ceabe142def19f1.png" style="height: 16px;" /> giving the temperature at some point in 3D. 
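</p>

<p>Before diving into the example, the Jacobian is easy to make concrete in code. The following sketch (mine, not from the article; the helper name <em>numerical_jacobian</em> is an assumption) approximates the Jacobian of an arbitrary function mapping <em>n</em> inputs to <em>m</em> outputs using central finite differences with NumPy:</p>

```python
import numpy as np

def numerical_jacobian(f, a, eps=1e-6):
    """Approximate the Jacobian of f at point a by central differences.

    f: maps a length-n vector to a length-m vector.
    a: the length-n point at which to evaluate the Jacobian.
    Returns an (m, n) matrix whose (i, j) entry approximates D_j f_i(a).
    """
    a = np.asarray(a, dtype=float)
    n = a.shape[0]
    m = np.atleast_1d(f(a)).shape[0]
    J = np.zeros((m, n))
    for j in range(n):
        step = np.zeros(n)
        step[j] = eps
        # Central difference in the j-th coordinate direction.
        J[:, j] = (np.atleast_1d(f(a + step)) -
                   np.atleast_1d(f(a - step))) / (2 * eps)
    return J
```

<p>For instance, for f(v) = (v[0]v[1], v[0]+v[1]) the analytic Jacobian at (2, 3) is [[3, 2], [1, 1]], and the approximation recovers it to within the step size.</p>

<p>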
Now imagine that we're moving through this 3D space on a curve defined by a function <img alt="g:\mathbb{R} \to \mathbb{R}^3" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e97099dd54f45a2a71a33d305c517ec97565909d.png" style="height: 19px;" /> which takes time <em>t</em> and gives the coordinates <img alt="x(t),y(t),z(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/4e2bdd3060e49f3494f68f99cb6d2204b2a19e1c.png" style="height: 18px;" /> at that time. We want to compute how the temperature changes as a function of time <em>t</em> - how do we do that? Recall that the temperature is not a direct function of time, but rather is a function of location, while location <em>is</em> a function of time. Therefore, we'll want to compose <img alt="f \circ g" class="valign-m4" src="https://eli.thegreenplace.net/images/math/1247a6ac0bc07bfdbd790831aa70b0b000bad2e4.png" style="height: 16px;" />. Here's a concrete example:</p> <img alt="$g(t)=\begin{pmatrix} t\\ t^2\\ t^3 \end{pmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/cdaff94ebfb318ec24f472be470497e28a091c42.png" style="height: 65px;" /> <p>And:</p> <img alt="$f\begin{pmatrix} x \\ y \\ z \end{pmatrix}=x^2+xyz+5y$" class="align-center" src="https://eli.thegreenplace.net/images/math/0a2fc40b06886d3b54628680192d71a3186d9fc7.png" style="height: 65px;" /> <p>If we reformulate <em>x</em>, <em>y</em> and <em>z</em> as functions of <em>t</em>:</p> <object class="align-center" data="https://eli.thegreenplace.net/images/math/36f726e2fe10b99ab5d216310d2fe91d61f24c46.svg" style="height: 21px;" type="image/svg+xml"> $f(x(t),y(t),z(t))=x(t)^2+x(t)y(t)z(t)+5y(t)$</object> <p>Composing <img alt="f \circ g" class="valign-m4" src="https://eli.thegreenplace.net/images/math/1247a6ac0bc07bfdbd790831aa70b0b000bad2e4.png" style="height: 16px;" />, we get:</p> <img alt="$(f \circ g)(t)=f(g(t))=f(t,t^2,t^3)=t^2+t^6+5t^2=6t^2+t^6$" class="align-center"
src="https://eli.thegreenplace.net/images/math/63ad25f62a0e93b1f8175a627aac0a29a88a3cca.png" style="height: 21px;" /> <p>Since this is a simple function, we can find its derivative directly:</p> <img alt="$(f \circ g)&amp;#x27;(t)=12t+6t^5$" class="align-center" src="https://eli.thegreenplace.net/images/math/d1025880b042d304efe08de37eeafde5a8d9231c.png" style="height: 21px;" /> <p>Now let's repeat this exercise using the multivariate chain rule. To compute <img alt="D(f \circ g)(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/c20cc5474ef67f0ec35bddccdc59b72742a864e1.png" style="height: 18px;" /> we need <img alt="Df(g(t))" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ded52fd957c2b251c84052c335523b80a4e3c945.png" style="height: 18px;" /> and <img alt="Dg(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ec8c49e88582659c617e6563375355ede5fe1090.png" style="height: 18px;" />. Let's start with <img alt="Dg(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ec8c49e88582659c617e6563375355ede5fe1090.png" style="height: 18px;" />. 
<img alt="g(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/851fb8b00904a32dff1c79d40158c7ec9d3d5254.png" style="height: 18px;" /> maps <img alt="\mathbb{R} \to \mathbb{R}^3" class="valign-m1" src="https://eli.thegreenplace.net/images/math/0354b4368db3496b963c21b446ad726b65a0ab90.png" style="height: 16px;" />, so its Jacobian is a 3-by-1 matrix, or column vector:</p> <img alt="$Dg(t)=\begin{bmatrix} 1 \\ 2t \\ 3t^2 \end{bmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/492d3e9013352e0cd44e3c5721cd0535174fb318.png" style="height: 65px;" /> <p>To compute <img alt="Df(g(t))" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ded52fd957c2b251c84052c335523b80a4e3c945.png" style="height: 18px;" /> let's first find <img alt="Df(x,y,z)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/dab2e6dc478f82ef76bff84080623a27fe214dec.png" style="height: 18px;" />. Since <img alt="f(x,y,z)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/c5d72ae6186c76bde08c693d4bfdb85e3201125d.png" style="height: 18px;" /> maps <img alt="\mathbb{R}^3 \to \mathbb{R}" class="valign-m1" src="https://eli.thegreenplace.net/images/math/1862a20e93e78e42aafd20106ceabe142def19f1.png" style="height: 16px;" />, its Jacobian is a 1-by-3 matrix, or row vector:</p> <img alt="$Df(x,y,z)=\begin{bmatrix} 2x+yz &amp;amp; xz+5 &amp;amp; xy \end{bmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/e8d650cac68d341d2c99c2641be3d238e516e51c.png" style="height: 22px;" /> <p>To apply the chain rule, we need <img alt="Df(g(t))" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ded52fd957c2b251c84052c335523b80a4e3c945.png" style="height: 18px;" />:</p> <img alt="$Df(g(t))=\begin{bmatrix} 2t+t^5 &amp;amp; t^4+5 &amp;amp; t^3 \end{bmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/b061977c12dcc918a96473939f6dc01eb7ea7847.png" style="height: 22px;" /> <p>Finally, 
multiplying <img alt="Df(g(t))" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ded52fd957c2b251c84052c335523b80a4e3c945.png" style="height: 18px;" /> by <img alt="Dg(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ec8c49e88582659c617e6563375355ede5fe1090.png" style="height: 18px;" />, we get:</p> <img alt="\begin{align*} D(f \circ g)(t)=Df(g(t)) \cdot Dg(t)&amp;amp;=\begin{bmatrix} 2t+t^5 &amp;amp; t^4+5 &amp;amp; t^3 \end{bmatrix} \cdot \begin{bmatrix} 1 \\ 2t \\ 3t^2 \end{bmatrix}\\ &amp;amp;=2t+t^5+2t^6+10t+3t^5\\ &amp;amp;=12t+6t^5 \end{align*}" class="align-center" src="https://eli.thegreenplace.net/images/math/9c5a5fc3e8024f6d1f2364ad5d0433bb530d4987.png" style="height: 118px;" /> <p>Another interesting way to interpret this result for the case where <img alt="f:\mathbb{R}^3 \to \mathbb{R}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/d307aff95a39ad62cc090e4d6e3bd73b1ffc2b14.png" style="height: 19px;" /> and <img alt="g:\mathbb{R} \to \mathbb{R}^3" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e97099dd54f45a2a71a33d305c517ec97565909d.png" style="height: 19px;" /> is to <a class="reference external" href="http://eli.thegreenplace.net/2016/understanding-gradient-descent">recall that</a> the directional derivative of <em>f</em> in the direction of some vector <img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" /> is:</p> <img alt="$D_{\vec{v}}f=(\nabla f) \cdot \vec{v}$" class="align-center" src="https://eli.thegreenplace.net/images/math/49933775272512c4c8686d9f9692c8ea01e1c97d.png" style="height: 18px;" /> <p>In our case <img alt="(\nabla f)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/cf1f51ce22cf132c44f5cd65c1c6ada1cce0347f.png" style="height: 18px;" /> is the Jacobian of <em>f</em> (because of its dimensionality). 
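</p>

<p>The worked example is easy to check numerically. Here's a short sketch (mine, not part of the original article) that evaluates the right-hand side of the chain rule - the product of the two Jacobians computed above - and compares it with the directly-derived derivative 12t + 6t^5:</p>

```python
import numpy as np

# The article's example: g(t) = (t, t^2, t^3) and f(x, y, z) = x^2 + xyz + 5y.
def Dg(t):
    # Jacobian of g: a 3-by-1 column vector.
    return np.array([[1.0], [2 * t], [3 * t ** 2]])

def Df(x, y, z):
    # Jacobian of f: a 1-by-3 row vector of partial derivatives.
    return np.array([[2 * x + y * z, x * z + 5, x * y]])

def chain_rule_derivative(t):
    # D(f∘g)(t) = Df(g(t)) · Dg(t) is a 1x1 matrix; extract the scalar.
    x, y, z = t, t ** 2, t ** 3
    return (Df(x, y, z) @ Dg(t))[0, 0]

# The Jacobian product agrees with the direct derivative 12t + 6t^5.
for t in [0.0, 0.5, 1.0, 2.0]:
    assert np.isclose(chain_rule_derivative(t), 12 * t + 6 * t ** 5)
```

<p>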
So if we take <img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" /> to be the vector <img alt="Dg(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ec8c49e88582659c617e6563375355ede5fe1090.png" style="height: 18px;" />, and evaluate the gradient at <img alt="g(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/851fb8b00904a32dff1c79d40158c7ec9d3d5254.png" style="height: 18px;" /> we get <a class="footnote-reference" href="#id10" id="id5"></a>:</p> <img alt="$D_{\vec{Dg(t)}}f(t)=(\nabla f(g(t))) \cdot Dg(t)$" class="align-center" src="https://eli.thegreenplace.net/images/math/dc8e045fe902682ada36e08fa0099f95632b7ced.png" style="height: 24px;" /> <p>This gives us some additional intuition for the temperature change question. The change in temperature as a function of time is the directional derivative of <em>f</em> in the direction of the change in location (<img alt="Dg(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ec8c49e88582659c617e6563375355ede5fe1090.png" style="height: 18px;" />).</p> <p>For additional examples of applying the chain rule, see <a class="reference external" href="http://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative/">my post about softmax</a>.</p> </div> <div class="section" id="tricks-with-the-multivariate-chain-rule-derivative-of-products"> <h2>Tricks with the multivariate chain rule - derivative of products</h2> <p>Earlier in the article we've seen how the chain rule helps find derivatives of complicated functions by decomposing them into simpler functions. The multivariate chain rule allows even more of that, as the following example demonstrates. Suppose <img alt="h(x)=f(x)g(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/6b584f7e739b604fe2d90a983216090d25643ad1.png" style="height: 18px;" />. 
Then, the well-known <a class="reference external" href="https://en.wikipedia.org/wiki/Product_rule">product rule</a> of derivatives states that:</p> <img alt="$h&amp;#x27;(x)=f&amp;#x27;(x)g(x)+f(x)g&amp;#x27;(x)$" class="align-center" src="https://eli.thegreenplace.net/images/math/6c77a942dbee351e8229ce7771680b6a2f55c4aa.png" style="height: 20px;" /> <p>Proving this from first principles (the definition of the derivative as a limit) isn't hard, but I want to show how it stems very easily from the multivariate chain rule.</p> <p>Let's begin by re-formulating <img alt="h(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/2a1862ca703d9d9a76538d74b8f4b71df93bafab.png" style="height: 18px;" /> as a composition of two functions. The first takes a vector <img alt="\vec{s}" class="valign-0" src="https://eli.thegreenplace.net/images/math/6a16290a6fe4bd5b30bf2cf959214e8fa4924959.png" style="height: 13px;" /> in <img alt="\mathbb{R}^2" class="valign-0" src="https://eli.thegreenplace.net/images/math/2b688757b3d0949451e1fa97e71ac5f5f284a5e4.png" style="height: 15px;" /> and maps it to <img alt="\mathbb{R}" class="valign-0" src="https://eli.thegreenplace.net/images/math/0ed839b111fe0e3ca2b2f618b940893eaea88a57.png" style="height: 12px;" /> by computing the product of its two components:</p> <img alt="$p(\vec{s})=s_1 s_2$" class="align-center" src="https://eli.thegreenplace.net/images/math/955d480267a38ec452bcdf2774dadc7652a757fa.png" style="height: 18px;" /> <p>The second is a vector-valued function that maps a number <img alt="x \in \mathbb{R}" class="valign-m1" src="https://eli.thegreenplace.net/images/math/ec7e4961c34351c48080f6190b6ec363af9adf25.png" style="height: 13px;" /> to <img alt="\mathbb{R}^2" class="valign-0" src="https://eli.thegreenplace.net/images/math/2b688757b3d0949451e1fa97e71ac5f5f284a5e4.png" style="height: 15px;" /> :</p> <img alt="$s(x)=\begin{pmatrix} f(x)\\ g(x) \end{pmatrix}$" class="align-center" 
src="https://eli.thegreenplace.net/images/math/f5c473fb1fb5ee47e59414a91dc484e182bc6210.png" style="height: 43px;" /> <p>We can compose <img alt="p \circ s" class="valign-m4" src="https://eli.thegreenplace.net/images/math/3f1c954d3481a1a167ae311bc3c3980aaf1ee3a1.png" style="height: 12px;" />, producing a function that takes a scalar and returns a scalar: <img alt="(p \circ s) : \mathbb{R} \to \mathbb{R}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/c827ac5b7598c117c157d0377b8c30a0f9a72b81.png" style="height: 18px;" />. We get:</p> <img alt="$h(x)=(p \circ s)(x) = f(x)g(x)$" class="align-center" src="https://eli.thegreenplace.net/images/math/3cbae5f44d32653bd6bbc66e6ee8bb5e1a4dfe40.png" style="height: 18px;" /> <p>Since we're composing two multivariate functions, we can apply the multivariate chain rule here:</p> <img alt="\begin{align*} D(p \circ s) &amp;amp;= Dp(s(x)) \cdot Ds(x)\\ &amp;amp;=\begin{bmatrix} \frac{\partial p}{\partial s_1}(x) &amp;amp; \frac{\partial p}{\partial s_2}(x) \end{bmatrix}\cdot \begin{bmatrix} {s_1}&amp;#x27;(x)\\ {s_2}&amp;#x27;(x) \end{bmatrix}\\ &amp;amp;=\begin{bmatrix} s_2(x) &amp;amp; s_1(x) \end{bmatrix} \cdot \begin{bmatrix} {s_1}&amp;#x27;(x)\\ {s_2}&amp;#x27;(x) \end{bmatrix}\\ &amp;amp;={s_1}&amp;#x27;(x)s_2(x)+{s_2}&amp;#x27;(x)s_1(x) \end{align*}" class="align-center" src="https://eli.thegreenplace.net/images/math/ee8bd27a8257039f72c8751eb78626521f12a5fa.png" style="height: 147px;" /> <p>Since <img alt="s_1(x)=f(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/dc67440057222e8222ae08269e4ba2a1e58acbb4.png" style="height: 18px;" /> and <img alt="s_2(x)=g(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/89d252d983d49126f2f4a34fcf01fd6c882e4792.png" style="height: 18px;" />, this is exactly the product rule.</p> </div> <div class="section" id="connecting-the-single-variable-and-multivariate-chain-rules"> <h2>Connecting the single-variable and multivariate chain
rules</h2> <p>Given function <img alt="f(x) : \mathbb{R} \to \mathbb{R}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/bd80387ffb4e5cd8702c12837a57f1806ea1d02b.png" style="height: 18px;" />, its Jacobian matrix has a single entry:</p> <img alt="$Df(a)=\begin{bmatrix}D_{x}f(a)\end{bmatrix}= \begin{bmatrix}\frac{df}{dx}(a)\end{bmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/cc95d53415b32e6610c1a45bededb4fb584f0c64.png" style="height: 24px;" /> <p>Therefore, given two functions mapping <img alt="\mathbb{R} \to \mathbb{R}" class="valign-m1" src="https://eli.thegreenplace.net/images/math/4aaeb3aafa05a9ad54c8d7da4e4aecad4dfac1cd.png" style="height: 13px;" />, the derivative of their composition using the multivariate chain rule is:</p> <img alt="$D(f \circ g)(a)=Df(g(a))\cdot Dg(a)=f&amp;#x27;(g(a))g&amp;#x27;(a)$" class="align-center" src="https://eli.thegreenplace.net/images/math/98e554584c9d2d967b9a6759a64126093ef704ce.png" style="height: 20px;" /> <p>Which is precisely the single-variable chain rule. This results from matrix multiplication between two 1x1 matrices, which ends up being just the product of their single entries.</p> </div> <div class="section" id="appendix-proving-the-single-variable-chain-rule"> <h2>Appendix: proving the single-variable chain rule</h2> <p>It turns out that many online resources (including Khan Academy) provide a flawed proof for the chain rule. It's flawed due to a careless division by a quantity that may be zero. This flaw can be corrected by making the proof somewhat more complicated; I won't take that road here - for details see Spivak's <em>Calculus</em>. 
Instead, I'll present a simpler proof inspired by the one I found at <a class="reference external" href="http://math.rice.edu/~cjd/">Casey Douglas's site</a>.</p> <p>We want to prove that:</p> <img alt="$(f \circ g)&amp;#x27;(x)={f}&amp;#x27;(g(x)){g}&amp;#x27;(x)$" class="align-center" src="https://eli.thegreenplace.net/images/math/29f4194c9af3777ae55a15dad972a145eb7797be.png" style="height: 20px;" /> <p>Note that previously we defined derivatives at some concrete point <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" />. Here for the sake of brevity I'll just use <img alt="x" class="valign-0" src="https://eli.thegreenplace.net/images/math/11f6ad8ec52a2984abaafd7c3b516503785c2072.png" style="height: 8px;" /> as an arbitrary point, assuming the derivative exists.</p> <p>Let's start with the definition of <img alt="g&amp;#x27;(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/5fba18e1364e151399f95daac5ed63f09feba9b7.png" style="height: 18px;" />:</p> <img alt="${g}&amp;#x27;(x)=\lim_{h \to 0}\frac{g(x+h)-g(x)}{h}$" class="align-center" src="https://eli.thegreenplace.net/images/math/c19f7ddc43c3046489d7e012c3f213403edf7e8a.png" style="height: 39px;" /> <p>We can reorder it as follows:</p> <img alt="$\lim_{h \to 0}\left [ \frac{g(x+h)-g(x)}{h} - g&amp;#x27;(x) \right ] = 0$" class="align-center" src="https://eli.thegreenplace.net/images/math/74a651394036af8aeaba69650dba26ccb4f90ae7.png" style="height: 43px;" /> <p>Let's give the part in the brackets the name <img alt="\Delta g(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ba109afb4d6264ec2d39fe025fcc5a1dbc58637f.png" style="height: 18px;" />.</p> <p>Similarly, if the function <em>f</em> is differentiable at the point <img alt="a=g(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/01bcc664dda02e9d98a5f37104ff028cf8fd0d62.png" style="height: 18px;" />, we 
have:</p> <img alt="$f&amp;#x27;(a)=\lim_{k \to 0}\frac{f(a+k)-f(a)}{k}$" class="align-center" src="https://eli.thegreenplace.net/images/math/59daea2a46cd244229625131297a773820501571.png" style="height: 39px;" /> <p>We reorder:</p> <img alt="$\lim_{k \to 0}\left [ \frac{f(a+k)-f(a)}{k} - f&amp;#x27;(a) \right ] = 0$" class="align-center" src="https://eli.thegreenplace.net/images/math/4600064fad365f360bd73063324a935a8b73266f.png" style="height: 43px;" /> <p>And call the part in the brackets <img alt="\Delta f(a)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/a85d05996c54eec8a9bef9a60f8e7e4f3231aa51.png" style="height: 18px;" />. The choice of <em>k</em> instead of <em>h</em> as the variable going to zero is arbitrary, and helps simplify the discussion that follows.</p> <p>Let's reorder the definition of <img alt="\Delta g(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ba109afb4d6264ec2d39fe025fcc5a1dbc58637f.png" style="height: 18px;" /> a bit:</p> <img alt="$g(x+h)=g(x)+[g&amp;#x27;(x)+\Delta g(x)]h$" class="align-center" src="https://eli.thegreenplace.net/images/math/59e0263f8a2ebfc0fac9a2b51f42c651b359fe31.png" style="height: 21px;" /> <p>We can apply <em>f</em> to both sides:</p> <img alt="$\begin{equation} f(g(x+h))=f(g(x)+[g&amp;#x27;(x)+\Delta g(x)]h) \tag{1} \end{equation}$" class="align-center" src="https://eli.thegreenplace.net/images/math/3b82da9d6cad509490e687b9e86093791545ea81.png" style="height: 21px;" /> <p>By reordering the definition of <img alt="\Delta f(a)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/a85d05996c54eec8a9bef9a60f8e7e4f3231aa51.png" style="height: 18px;" /> we get:</p> <img alt="$\begin{equation} f(a+k)=f(a)+[f&amp;#x27;(a)+\Delta f(a)]k \tag{2} \end{equation}$" class="align-center" src="https://eli.thegreenplace.net/images/math/88c5b43f3ba89da3853be9342381aa8dd60e024f.png" style="height: 21px;" /> <p>Now taking the right-hand side of (1), we can look at it as <img
alt="f(a+k)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/831a544de22e2a6a8997413b576a67391ba31f53.png" style="height: 18px;" /> since <img alt="a=g(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/01bcc664dda02e9d98a5f37104ff028cf8fd0d62.png" style="height: 18px;" /> and we can define <img alt="k=[g&amp;#x27;(x)+\Delta g(x)]h" class="valign-m5" src="https://eli.thegreenplace.net/images/math/ed49f313283b8f266ffd1e9b4194c36f456a950d.png" style="height: 19px;" />. We still have <em>k</em> going to zero when <em>h</em> goes to zero. Substituting these <em>a</em> and <em>k</em> into (2), we get:</p> <img alt="$f(a+k)=f(g(x))+[f&amp;#x27;(g(x))+\Delta f(g(x))][g&amp;#x27;(x)+\Delta g(x)]h$" class="align-center" src="https://eli.thegreenplace.net/images/math/275b3323c68b711b2458e4c748a887a368e32a40.png" style="height: 21px;" /> <p>So, starting from (1) again, we have:</p> <img alt="\begin{align*} f(g(x+h))&amp;amp;=f(a+k) \\ &amp;amp;=f(g(x))+[f&amp;#x27;(g(x))+\Delta f(g(x))][g&amp;#x27;(x)+\Delta g(x)]h \end{align*}" class="align-center" src="https://eli.thegreenplace.net/images/math/82e67cf24d9eb3dad58e7d30cd89ba1c19e367fb.png" style="height: 45px;" /> <p>Subtracting <img alt="f(g(x))" class="valign-m4" src="https://eli.thegreenplace.net/images/math/92cb7139e348ea05a69782b2cf7221bae86a2b03.png" style="height: 18px;" /> from both sides and dividing by <em>h</em> (which is legal since <em>h</em> is not zero - just very small) we get:</p> <img alt="$\frac{f(g(x+h))-f(g(x))}{h}=[f&amp;#x27;(g(x))+\Delta f(g(x))][g&amp;#x27;(x)+\Delta g(x)]$" class="align-center" src="https://eli.thegreenplace.net/images/math/bfdfef3d46b471aa5d6803c3c5a6b5e26ffe3b37.png" style="height: 39px;" /> <p>Apply a limit to both sides:</p> <img alt="$\lim_{h \to 0} \frac{f(g(x+h))-f(g(x))}{h}= \lim_{h \to 0} [f&amp;#x27;(g(x))+\Delta f(g(x))][g&amp;#x27;(x)+\Delta g(x)]$" class="align-center"
src="https://eli.thegreenplace.net/images/math/0f5c316fcc2877f78b8a739898a31120471dd401.png" style="height: 39px;" /> <p>Now recall that both <img alt="\Delta f(g(x))" class="valign-m4" src="https://eli.thegreenplace.net/images/math/6752886c71e1a95dc360dd4e5ea10dd0b6f76e84.png" style="height: 18px;" /> and <img alt="\Delta g(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ba109afb4d6264ec2d39fe025fcc5a1dbc58637f.png" style="height: 18px;" /> go to 0 when <em>h</em> goes to 0. Taking this into account, we get:</p> <img alt="$\lim_{h \to 0} \frac{f(g(x+h))-f(g(x))}{h}= f&amp;#x27;(g(x))g&amp;#x27;(x)$" class="align-center" src="https://eli.thegreenplace.net/images/math/3954d8d23c8fb53d4cd1732d19939d650ef830ae.png" style="height: 39px;" /> <p><em>Q.E.D.</em></p> <hr class="docutils" /> <table class="docutils footnote" frame="void" id="id6" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id1"></a></td><td>Here, as in the rest of the post, I'm being careless with the usage of <img alt="\lim" class="valign-0" src="https://eli.thegreenplace.net/images/math/6f5c7776306147fe3be3e4b8547a23c62eafddf4.png" style="height: 13px;" />, sometimes leaving its existence to be implicit. 
In general, wherever <em>h</em> appears in a formula we know there's a <img alt="\lim_{h \to 0}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/0f10af054e5ddc3b9603098fec294e0247190efa.png" style="height: 17px;" /> there, whether explicitly or not.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id7" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id2"></a></td><td>An alternative way to think about it is: suppose the functions <em>f</em> and <em>g</em> are linear: <img alt="f(x)=ax+b" class="valign-m4" src="https://eli.thegreenplace.net/images/math/a85393d5068f5c4bc36ff7efed535a8f1a686848.png" style="height: 18px;" /> and <img alt="g(x)=cx+d" class="valign-m4" src="https://eli.thegreenplace.net/images/math/6d712cb582caa0e48a2b029ea4ae29a3e5e40f27.png" style="height: 18px;" />. Then the chain rule is trivially true. But now recall what the derivative is. The derivative at some point <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" /> is the best linear approximation for the function at that point. 
Therefore the chain rule is true for any pair of differentiable functions - even when the functions are not linear, we approximate their rate of change in an infinitesimal area around <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" /> with a linear function.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id8" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id3"></a></td><td>The condition for <em>f</em> being differentiable at <em>a</em> is stronger than simply saying that all partial derivatives exist at <em>a</em>, but I won't spend more time on this subtlety here.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id9" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id4"></a></td><td>As an exercise, verify that the matrix dimensions of <img alt="Df" class="valign-m4" src="https://eli.thegreenplace.net/images/math/5c6bf530660cba6530e83a86f0ed49fe0821d179.png" style="height: 16px;" /> and <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/38b655437da0880bd70168fcbadb50ebdbf46ca5.svg" style="height: 16px;" type="image/svg+xml">Dg</object> make this multiplication valid.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id10" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id5"></a></td><td>It shouldn't be surprising that we get here, since the definition of the directional derivative as the gradient <a class="reference external" href="http://eli.thegreenplace.net/2016/understanding-gradient-descent">was derived</a> using the multivariate chain rule.</td></tr> </tbody> </table> </div> Linear
regression2016-08-06T05:28:00-07:002016-08-06T05:28:00-07:00Eli Benderskytag:eli.thegreenplace.net,2016-08-06:/2016/linear-regression/<p>Linear regression is one of the most basic, and yet most useful approaches for predicting a single quantitative (real-valued) variable given any number of real-valued predictors. This article presents the basics of linear regression for the &quot;simple&quot; (single-variable) case, as well as for the more general multivariate case. <a class="reference external" href="https://github.com/eliben/deep-learning-samples/tree/master/linear-regression">Companion code …</a></p><p>Linear regression is one of the most basic, and yet most useful approaches for predicting a single quantitative (real-valued) variable given any number of real-valued predictors. This article presents the basics of linear regression for the &quot;simple&quot; (single-variable) case, as well as for the more general multivariate case. <a class="reference external" href="https://github.com/eliben/deep-learning-samples/tree/master/linear-regression">Companion code in Python</a> implements the techniques described in the article on simulated and realistic data sets. The code is self-contained, using only Numpy as a dependency.</p> <div class="section" id="simple-linear-regression"> <h2>Simple linear regression</h2> <p>The most basic kind of regression problem has a single <em>predictor</em> (the input) and a single outcome. 
Given a list of input values <img alt="x_i" class="valign-m3" src="https://eli.thegreenplace.net/images/math/34e03e6559b14df9fe5a97bbd2ed10109dfebbd3.png" style="height: 11px;" /> and corresponding output values <img alt="y_i" class="valign-m4" src="https://eli.thegreenplace.net/images/math/35c2ac2f82d0ff8f9011b596ed7e54bfcc55f471.png" style="height: 12px;" />, we have to find parameters <em>m</em> and <em>b</em> such that the linear function:</p> <img alt="$\hat{y}(x) = mx + b$" class="align-center" src="https://eli.thegreenplace.net/images/math/2dabbcda3b1953b08211f7e334698366d647d697.png" style="height: 18px;" /> <p>Is &quot;as close as possible&quot; to the observed outcome <em>y</em>. More concretely, suppose we get this data <a class="footnote-reference" href="#id6" id="id1"></a>:</p> <img alt="Linear regression input data" class="align-center" src="https://eli.thegreenplace.net/images/2016/linreg-data.png" /> <p>We have to find a slope <em>m</em> and intercept <em>b</em> for a line that approximates this data as well as possible. We evaluate how well some pair of <em>m</em> and <em>b</em> approximates the data by defining a &quot;cost function&quot;. 
For linear regression, a good cost function to use is the <a class="reference external" href="https://en.wikipedia.org/wiki/Mean_squared_error">Mean Square Error (MSE)</a> <a class="footnote-reference" href="#id7" id="id2"></a>:</p> <img alt="$\operatorname{MSE}(m, b)=\frac{1}{n}\sum_{i=1}^n(\hat{y_i} - y_i)^2$" class="align-center" src="https://eli.thegreenplace.net/images/math/e4b7b4ce3abd90f20144e6ab468b7870cedf3b07.png" style="height: 50px;" /> <p>Expanding <img alt="\hat{y_i}=m{x_i}+b" class="valign-m4" src="https://eli.thegreenplace.net/images/math/daecd48b7bb0a06ddd4326da5b87ee14fddaeb8e.png" style="height: 17px;" />, we get:</p> <img alt="$\operatorname{MSE}(m, b)=\frac{1}{n}\sum_{i=1}^n(m{x_i} + b - y_i)^2$" class="align-center" src="https://eli.thegreenplace.net/images/math/3de1df776434b29620488aef327a9204757bc493.png" style="height: 50px;" /> <p>Let's turn this into Python code (<a class="reference external" href="https://github.com/eliben/deep-learning-samples/blob/master/linear-regression/simple_linear_regression.py">link to the full code sample</a>):</p> <div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">compute_cost</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;Compute the MSE cost of a prediction based on m, b.</span>
<span class="sd">    x: inputs vector</span>
<span class="sd">    y: observed outputs vector</span>
<span class="sd">    m, b: regression parameters</span>
<span class="sd">    Returns: a scalar cost.</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="n">yhat</span> <span class="o">=</span> <span class="n">m</span> <span class="o">*</span> <span class="n">x</span> <span class="o">+</span> <span class="n">b</span>
    <span class="n">diff</span> <span class="o">=</span> <span class="n">yhat</span> <span class="o">-</span> <span class="n">y</span>
    <span class="c1"># Vectorized computation using a dot product to compute sum of squares.</span>
    <span class="n">cost</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">diff</span><span class="o">.</span><span class="n">T</span><span class="p">,</span> <span class="n">diff</span><span class="p">)</span> <span class="o">/</span> <span class="nb">float</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
    <span class="c1"># Cost is a 1x1 matrix, we need a scalar.</span>
    <span class="k">return</span> <span class="n">cost</span><span class="o">.</span><span class="n">flat</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
</pre></div> <p>Now we're faced with a classical optimization problem: we have some parameters (<em>m</em> and <em>b</em>) we can tweak, and some cost function <img alt="J(m, b)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/d61807c64b6ab8087a11167224df4b5f818aeae3.png" style="height: 18px;" /> we want to minimize. The topic of mathematical optimization is vast, but what ends up working very well for machine learning is a fairly simple algorithm called <em>gradient descent</em>.</p> <p>Imagine plotting <img alt="J(m, b)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/d61807c64b6ab8087a11167224df4b5f818aeae3.png" style="height: 18px;" /> as a 3-dimensional surface, and picking some random point on it. Our goal is to find the lowest point on the surface, but we have no idea where that is. A reasonable guess is to move a bit &quot;downwards&quot; from our current location, and then repeat.</p> <p>&quot;Downwards&quot; is exactly what &quot;gradient descent&quot; means.
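</p>

<p>To build some intuition for this surface, here is a quick sanity check (a sketch of mine, not from the article; it restates the cost for plain 1-D NumPy arrays, where <em>np.dot</em> already yields a scalar): for data generated exactly on a line, the cost shrinks to zero as <em>m</em> approaches the true slope:</p>

```python
import numpy as np

def compute_cost(x, y, m, b):
    # MSE cost as above, restated for 1-D arrays (np.dot then yields a scalar).
    diff = m * x + b - y
    return np.dot(diff, diff) / float(x.shape[0])

x = np.linspace(0, 10, 50)
y = 2.5 * x + 1.0  # points lying exactly on the line y = 2.5x + 1

# Walking m toward the true slope strictly decreases the cost...
costs = [compute_cost(x, y, m, 1.0) for m in (0.0, 1.0, 2.0, 2.5)]
assert costs[0] > costs[1] > costs[2] > costs[3]
# ...and at the true parameters the fit is perfect.
assert np.isclose(costs[-1], 0.0)
```

<p>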
We make a small change to our location (defined by <em>m</em> and <em>b</em>) in the direction in which <img alt="J(m, b)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/d61807c64b6ab8087a11167224df4b5f818aeae3.png" style="height: 18px;" /> decreases most - the direction opposite to the gradient <a class="footnote-reference" href="#id8" id="id3"></a>. We then repeat this process until we reach a minimum, hopefully global. In fact, since the linear regression cost function is <em>convex</em>, we will find the global minimum this way. But in the general case this is not guaranteed, and many sophisticated extensions of gradient descent exist that try to avoid local minima and maximize the chance of finding a global one.</p> <p>Back to our function, <img alt="\operatorname{MSE}(m, b)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e42329899ba53adedf0a7884b1844dba4f01bdee.png" style="height: 18px;" />. The gradient is defined as the vector:</p> <img alt="$\nabla \operatorname{MSE}=\left \langle \frac{\partial \operatorname{MSE}}{\partial m}, \frac{\partial \operatorname{MSE}}{\partial b} \right \rangle$" class="align-center" src="https://eli.thegreenplace.net/images/math/50b0404ea5a8f76da73caae5b8109dd384dbd18e.png" style="height: 43px;" /> <p>To find it, we have to compute the partial derivatives of MSE w.r.t.
the learning parameters <em>m</em> and <em>b</em>:</p> <img alt="\begin{align*} \frac{\partial \operatorname{MSE}}{\partial m}&amp;amp;=\frac{2}{n}\sum_{i=i}^n(m{x_i}+b-y_i)x_i\\ \frac{\partial \operatorname{MSE}}{\partial b}&amp;amp;=\frac{2}{n}\sum_{i=i}^n(m{x_i}+b-y_i) \end{align*}" class="align-center" src="https://eli.thegreenplace.net/images/math/dbd383b0d7ee194a417b88ad117b451531758fe7.png" style="height: 108px;" /> <p>And then update <em>m</em> and <em>b</em> in each step of the learning with:</p> <img alt="\begin{align*} m &amp;amp;= m-\eta \frac{\partial \operatorname{MSE}}{\partial m} \\ b &amp;amp;= b-\eta \frac{\partial \operatorname{MSE}}{\partial b} \\ \end{align*}" class="align-center" src="https://eli.thegreenplace.net/images/math/b0c7ff699fc61836051968db56224e6470b56d3c.png" style="height: 81px;" /> <p>Where <img alt="\eta" class="valign-m4" src="https://eli.thegreenplace.net/images/math/2899aeb886ad0fa72652bffd5511e452aaf084ab.png" style="height: 12px;" /> is a customizable &quot;learning rate&quot;, a hyperparameter. Here is the gradient descent loop in Python. 
Note that we examine the whole data set in every step; for much larger data sets, SGD (Stochastic Gradient Descent) with some reasonable mini-batch would make more sense, but for simple linear regression problems the data size is rarely very big.</p> <div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">gradient_descent</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">nsteps</span><span class="p">,</span> <span class="n">learning_rate</span><span class="o">=</span><span class="mf">0.1</span><span class="p">):</span> <span class="sd">&quot;&quot;&quot;Runs gradient descent optimization to fit a line y^ = x * m + b.</span> <span class="sd"> x, y: input data and observed outputs.</span> <span class="sd"> nsteps: how many steps to run the optimization for.</span> <span class="sd"> learning_rate: learning rate of gradient descent.</span> <span class="sd"> Yields &#39;nsteps + 1&#39; triplets of (m, b, cost) where m, b are the fit</span> <span class="sd"> parameters for the given step, and cost is their cost vs the real y.</span> <span class="sd"> &quot;&quot;&quot;</span> <span class="n">n</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="c1"># Start with m and b initialized to 0s for the first try.</span> <span class="n">m</span><span class="p">,</span> <span class="n">b</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span> <span class="k">yield</span> <span class="n">m</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">compute_cost</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span> 
<span class="n">b</span><span class="p">)</span> <span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">nsteps</span><span class="p">):</span> <span class="n">yhat</span> <span class="o">=</span> <span class="n">m</span> <span class="o">*</span> <span class="n">x</span> <span class="o">+</span> <span class="n">b</span> <span class="n">diff</span> <span class="o">=</span> <span class="n">yhat</span> <span class="o">-</span> <span class="n">y</span> <span class="n">dm</span> <span class="o">=</span> <span class="n">learning_rate</span> <span class="o">*</span> <span class="p">(</span><span class="n">diff</span> <span class="o">*</span> <span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">/</span> <span class="n">n</span> <span class="n">db</span> <span class="o">=</span> <span class="n">learning_rate</span> <span class="o">*</span> <span class="n">diff</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">/</span> <span class="n">n</span> <span class="n">m</span> <span class="o">-=</span> <span class="n">dm</span> <span class="n">b</span> <span class="o">-=</span> <span class="n">db</span> <span class="k">yield</span> <span class="n">m</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">compute_cost</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span> </pre></div> <p>After running this for 30 steps, the gradient converges and the parameters barely change. 
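</p> <p>To tie the pieces together, here is a minimal end-to-end sketch (an addition, not from the original post) that generates fake data with the parameters mentioned in the footnote - slope 2.25, intercept 6, Gaussian noise with standard deviation 1.5 - and runs the same update rule in compact form. The data-generation details and the learning rate are my assumptions; the rate is smaller than 0.1 because this hypothetical <em>x</em> spans [0, 10]:</p>

```python
import numpy as np

# Hypothetical data matching the footnote: slope 2.25, intercept 6,
# Gaussian noise with standard deviation 1.5.
rng = np.random.default_rng(42)
x = rng.uniform(0, 10, 200)
y = 2.25 * x + 6 + rng.normal(0, 1.5, 200)

n = x.shape[0]
m, b = 0.0, 0.0
learning_rate = 0.01  # tuned for this x range; 0.1 would diverge here
for step in range(5000):
    diff = m * x + b - y
    # The partial derivatives of MSE w.r.t. m and b, as derived above.
    m -= learning_rate * 2 * (diff * x).sum() / n
    b -= learning_rate * 2 * diff.sum() / n
```

<p>Running this yields <em>m</em> and <em>b</em> close to the true 2.25 and 6, up to the noise in the data.</p> <p>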
Here's a 3D plot of the cost as a function of the regression parameters, along with a contour plot of the same function. It's easy to see this function is convex, as expected. This makes finding the global minimum simple, since no matter where we start, the gradient will lead us directly to it.</p> <p>To help visualize this, I marked the cost for each successive training step on the contour plot - you can see how the algorithm relentlessly converges to the minimum.</p> <img alt="Linear regression cost and contour" class="align-center" src="https://eli.thegreenplace.net/images/2016/linreg-cost-contour.png" /> <p>The final parameters learned by the regression are 2.2775 for <em>m</em> and 6.0028 for <em>b</em>, which are very close to the actual parameters I used to generate this fake data.</p> <p>Here's a visualization that shows how the regression line improves progressively during learning:</p> <img alt="Regression fit visualization" class="align-center" src="https://eli.thegreenplace.net/images/2016/regressionfit.gif" /> </div> <div class="section" id="evaluating-how-good-the-fit-is"> <h2>Evaluating how good the fit is</h2> <p>In statistics, there are many ways to evaluate how good a &quot;fit&quot; some model is on the given data. One of the most popular ones is the <em>r-squared</em> test (&quot;coefficient of determination&quot;).
It measures the proportion of the total variance in the output (<em>y</em>) that can be explained by the variation in <em>x</em>:</p> <img alt="$R^2 = 1 - \frac{\sum_{i=1}^n (y_i - (m{x_i} + b))^2}{n\cdot var(y)}$" class="align-center" src="https://eli.thegreenplace.net/images/math/2c989c7345d6901a0cf7c17f9b08762ef27c5148.png" style="height: 43px;" /> <p>This is trivial to translate to code:</p> <div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">compute_rsquared</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span> <span class="n">yhat</span> <span class="o">=</span> <span class="n">m</span> <span class="o">*</span> <span class="n">x</span> <span class="o">+</span> <span class="n">b</span> <span class="n">diff</span> <span class="o">=</span> <span class="n">yhat</span> <span class="o">-</span> <span class="n">y</span> <span class="n">SE_line</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">diff</span><span class="o">.</span><span class="n">T</span><span class="p">,</span> <span class="n">diff</span><span class="p">)</span> <span class="n">SE_y</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">y</span><span class="p">)</span> <span class="o">*</span> <span class="n">y</span><span class="o">.</span><span class="n">var</span><span class="p">()</span> <span class="k">return</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">SE_line</span> <span class="o">/</span> <span class="n">SE_y</span> </pre></div> <p>For our regression results, I get <em>r-squared</em> of 0.76, which isn't too bad. Note that the data is very jittery, so it's natural the regression cannot explain all the variance. 
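</p> <p>A handy sanity check (an addition, not from the original post): for the least-squares fit of a simple linear regression, <em>r-squared</em> equals the squared Pearson correlation coefficient between <em>x</em> and <em>y</em>. A sketch assuming numpy and hypothetical data generated like the post's fake data set:</p>

```python
import numpy as np

def compute_rsquared(x, y, m, b):
    # Same computation as above: 1 - SE_line / SE_y, with y.var() using
    # the population (ddof=0) variance.
    diff = m * x + b - y
    return 1 - np.dot(diff, diff) / (len(y) * y.var())

# Hypothetical data: slope 2.25, intercept 6, noise std 1.5.
rng = np.random.default_rng(0)
x = rng.uniform(0, 10, 500)
y = 2.25 * x + 6 + rng.normal(0, 1.5, 500)

m, b = np.polyfit(x, y, 1)            # least-squares fit
r2 = compute_rsquared(x, y, m, b)
corr2 = np.corrcoef(x, y)[0, 1] ** 2  # equals r2 for the least-squares fit
```

<p>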
As an interesting exercise, try to modify the code that generates the data with different standard deviations for the random noise and see the effect on <em>r-squared</em>.</p> </div> <div class="section" id="an-analytical-solution-to-simple-linear-regression"> <h2>An analytical solution to simple linear regression</h2> <p>Using the equations for the partial derivatives of MSE (shown above) it's possible to find the minimum analytically, without having to resort to a computational procedure (gradient descent). We compare the derivatives to zero:</p> <img alt="\begin{align*} \frac{\partial \operatorname{MSE}}{\partial m}&amp;amp;=\frac{2}{n}\sum_{i=i}^n(m{x_i}+b-y_i)x_i = 0\\ \frac{\partial \operatorname{MSE}}{\partial b}&amp;amp;=\frac{2}{n}\sum_{i=i}^n(m{x_i}+b-y_i) = 0 \end{align*}" class="align-center" src="https://eli.thegreenplace.net/images/math/aef02f077919896478d0456619f934dcc5809142.png" style="height: 108px;" /> <p>And solve for <em>m</em> and <em>b</em>. To make the equations easier to follow, let's introduce a bit of notation. <img alt="\bar{x}" class="valign-0" src="https://eli.thegreenplace.net/images/math/8eebe76c6f552df3f8b9480d5544fe47b1028322.png" style="height: 11px;" /> is the mean value of <em>x</em> across all samples. Similarly <img alt="\bar{y}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/1e3bffc7f71c01acbc2c12e015be3086a06f824d.png" style="height: 15px;" /> is the mean value of <em>y</em>. So the sum <img alt="\sum_{i=1}^n x_i" class="valign-m6" src="https://eli.thegreenplace.net/images/math/c42eb1b96dfa184fee1bc0f3a4b713b9c38b2a1a.png" style="height: 20px;" /> is actually <img alt="n\bar{x}" class="valign-0" src="https://eli.thegreenplace.net/images/math/ea6008aefff0c7d79044287c44e890b1fba97c22.png" style="height: 11px;" />. 
Now let's take the second equation from above and see how to simplify it:</p> <img alt="\begin{align*} \frac{\partial \operatorname{MSE}}{\partial b} &amp;amp;= \frac{2}{n}\sum_{i=i}^n(m{x_i}+b-y_i) \\ &amp;amp;= \frac{2}{n}(mn\bar{x}+nb-n\bar{y}) \\ &amp;amp;= 2m\bar{x} + 2b - 2\bar{y} = 0 \end{align*}" class="align-center" src="https://eli.thegreenplace.net/images/math/c97c0c9ca8a66d54974fc914fcf929085dc63879.png" style="height: 119px;" /> <p>Similarly, for the partial derivative by <em>m</em> we can reach:</p> <img alt="$\frac{\partial \operatorname{MSE}}{\partial m}= 2m\overline{x^2} + 2b\bar{x} - 2\overline{xy} = 0$" class="align-center" src="https://eli.thegreenplace.net/images/math/d9545273e11c9e179794f943e2c972bf62c38113.png" style="height: 38px;" /> <p>In these equations, all quantities except <em>m</em> and <em>b</em> are constant. Solving them for the unknowns <em>m</em> and <em>b</em>, we get <a class="footnote-reference" href="#id9" id="id4"></a>:</p> <img alt="$m = \frac{\bar{x}\bar{y} - \overline{xy}}{\bar{x}^2 - \overline{x^2}} \qquad b = \bar{y} - \bar{x}\frac{\bar{x}\bar{y} - \overline{xy}}{\bar{x}^2 - \overline{x^2}}$" class="align-center" src="https://eli.thegreenplace.net/images/math/becd671e8c032d0568e33b986033c181ac5c133b.png" style="height: 38px;" /> <p>If we plug the data values we have for <em>x</em> and <em>y</em> in these equations, we get 2.2777 for <em>m</em> and 6.0103 for <em>b</em> - almost exactly the values we obtained with regression <a class="footnote-reference" href="#id10" id="id5"></a>.</p> <p>Remember that by comparing the partial derivatives to zero we find a <em>critical point</em>, which is not necessarily a minimum. 
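</p> <p>As a quick aside, the closed-form formulas for <em>m</em> and <em>b</em> above are easy to verify in code. A sketch assuming numpy and hypothetical data (the bars in the formulas are sample means):</p>

```python
import numpy as np

# Hypothetical data with known parameters (slope 2.25, intercept 6).
rng = np.random.default_rng(1)
x = rng.uniform(0, 10, 1000)
y = 2.25 * x + 6 + rng.normal(0, 1.5, 1000)

# Sample means corresponding to the barred quantities in the formulas.
xbar, ybar = x.mean(), y.mean()
xybar, x2bar = (x * y).mean(), (x * x).mean()

m = (xbar * ybar - xybar) / (xbar ** 2 - x2bar)
b = ybar - xbar * m
```

<p>The result agrees with numpy's own least-squares fit (<tt>np.polyfit(x, y, 1)</tt>) to within floating-point precision.</p> <p>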
We can use the <a class="reference external" href="https://en.wikipedia.org/wiki/Second_partial_derivative_test">second derivative test</a> to find what kind of critical point that is, by computing the Hessian of the cost:</p> <img alt="$H(m, b) = \begin{pmatrix} \operatorname{MSE}_{mm}(x, y) &amp;amp; \operatorname{MSE}_{mb}(x, y) \\ \operatorname{MSE}_{bm}(x, y) &amp;amp; \operatorname{MSE}_{bb}(x, y) \end{pmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/39c2e86ae1437d3b19bc8e77b66501486550d3bc.png" style="height: 43px;" /> <p>Plugging the numbers and running the test, we can indeed verify that the critical point is a minimum.</p> </div> <div class="section" id="multiple-linear-regression"> <h2>Multiple linear regression</h2> <p>The good thing about simple regression is that it's easy to visualize. The model is trained using just two parameters, and visualizing the cost as a function of these two parameters is possible since we get a 3D plot. Anything beyond that becomes increasingly more difficult to visualize.</p> <p>In simple linear regression, every <em>x</em> is just a number; so is every <em>y</em>. In multiple linear regression this is no longer so, and each data point <em>x</em> is a vector. The model parameters can also be represented by the vector <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" />. To avoid confusion of indices and subscripts, let's agree that we use subscripts to denote components of vectors, while parenthesized superscripts are used to denote different samples. 
So <img alt="x_1^{(6)}" class="valign-m6" src="https://eli.thegreenplace.net/images/math/d01999f5014c6aea058368231c0d2b958fa8a89e.png" style="height: 26px;" /> is the second component of sample 6.</p> <p>Our goal is to find the vector <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> such that the linear function:</p> <img alt="$\hat{y}(x) = \theta_0 x_0 + \theta_1 x_1 + \cdots + \theta_n x_n$" class="align-center" src="https://eli.thegreenplace.net/images/math/ae682f9fda97c28c8e100c87aecad635c7c1d96c.png" style="height: 18px;" /> <p>Is as close as possible to the actual <em>y</em> across all samples. Since working with vectors is easier for this problem, we define <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" /> to always be equal to 1, so that the first term in the equation above denotes the intercept. Expressing the regression coefficients as a vector:</p> <img alt="$\begin{pmatrix} \theta_0\\ \theta_1\\ ...\\ \theta_n \end{pmatrix}\in\mathbb{R}^{n+1}$" class="align-center" src="https://eli.thegreenplace.net/images/math/b16fd3d2b3041f13cb70199837a7c02c756078c7.png" style="height: 86px;" /> <p>We can now rewrite <img alt="\hat{y}(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/11533fb1b0218620907f5859e6e22aeb65c12cd8.png" style="height: 18px;" /> as:</p> <img alt="$\hat{y}(x) = \theta^T x$" class="align-center" src="https://eli.thegreenplace.net/images/math/8156e2dc4e654f77a8664180c168829f6b4cdb0b.png" style="height: 21px;" /> <p>Where both <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> and <em>x</em> are column vectors with <em>n+1</em> elements, as shown above. 
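</p> <p>In numpy this convention is straightforward to express: stack the samples as rows of a matrix <em>X</em> whose first column is all 1s, and all predictions are computed at once. A small sketch with made-up numbers (the data and parameters here are hypothetical):</p>

```python
import numpy as np

# Hypothetical data: 3 samples with 2 features each.
samples = np.array([[0.5, 1.2],
                    [1.0, 0.3],
                    [2.0, 2.0]])
k = samples.shape[0]

# Prepend x_0 = 1 to every sample so that theta[0] acts as the intercept.
X = np.column_stack([np.ones(k), samples])

theta = np.array([6.0, 2.25, -1.0])  # made-up regression parameters
yhat = X @ theta                     # computes theta^T x for every sample
```

<p>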
The mean square error (over <em>k</em> samples) now becomes:</p> <img alt="$\operatorname{MSE}=\frac{1}{k}\sum_{i=1}^k(\hat{y}(x^{(i)}) - y^{(i)})^2$" class="align-center" src="https://eli.thegreenplace.net/images/math/1e0a7c0c85c1827b992671b88e89ba052d37a204.png" style="height: 54px;" /> <p>Now we have to find the partial derivative of this cost by each <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" />. Using the chain rule, it's easy to see that:</p> <img alt="$\frac{\partial \operatorname{MSE}}{\partial \theta_j} = \frac{2}{k}\sum_{i=1}^k(\hat{y}(x^{(i)}) - y^{(i)})x_j^{(i)}$" class="align-center" src="https://eli.thegreenplace.net/images/math/4c2fcfed81c294ef7313198debe3801f50bea92a.png" style="height: 54px;" /> <p>And use this to update the parameters in every training step. The code is actually not much different from the simple regression case; here is a <a class="reference external" href="https://github.com/eliben/deep-learning-samples/blob/master/linear-regression/multiple_linear_regression.py">well documented, completely worked out example</a>. The code takes a realistic dataset from the <a class="reference external" href="http://archive.ics.uci.edu/ml/">UCI machine learning repository</a> with 4 predictors and a single outcome and builds a regression model. 4 predictors plus one intercept give us a 5-dimensional <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" />, which is utterly impossible to visualize, so we have to stick to math in order to analyze it.</p> </div> <div class="section" id="an-analytical-solution-to-multiple-linear-regression"> <h2>An analytical solution to multiple linear regression</h2> <p>Multiple linear regression also has an analytical solution. 
If we compute the derivative of the cost by each <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" />, we'll end up with <em>n+1</em> equations with the same number of variables, which we can solve analytically.</p> <p>An elegant matrix formula that computes <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> from <em>X</em> and <em>y</em> is called the Normal Equation:</p> <img alt="$\theta=(X^TX)^{-1}X^Ty$" class="align-center" src="https://eli.thegreenplace.net/images/math/20baabd9d33dcd26003bc44c7d81ba39e1ad4caa.png" style="height: 21px;" /> <p>I've written about <a class="reference external" href="http://eli.thegreenplace.net/2014/derivation-of-the-normal-equation-for-linear-regression">deriving the normal equation</a> previously, so I won't spend more time on it. The accompanying code computes <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> using the normal equation and compares the result with the <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> obtained from gradient descent.</p> <p>As an exercise, you can double-check that the analytical solution for simple linear regression (formulae for <em>m</em> and <em>b</em>) is just a special case of applying the normal equation in two dimensions.</p> <p>You may wonder: when should we use the analytical solution, and when is gradient descent better? In general, whenever we can use the analytical solution - we should. But it's not always feasible, computationally.</p> <p>Consider a data set with <em>k</em> samples and <em>n</em> features.
Then <em>X</em> is a <em>k x n</em> matrix, and hence <img alt="X^TX" class="valign-0" src="https://eli.thegreenplace.net/images/math/5c817c84ec1f83b23494df6125edd091a7c413dd.png" style="height: 15px;" /> is an <em>n x n</em> matrix. Inverting a matrix is an <img alt="O(n^3)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/62a87bfd600dc05059675e34b881c78648f53401.png" style="height: 19px;" /> operation, so for large <em>n</em>, finding <img alt="(X^TX)^{-1}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/57f592cee6ceac659262d97e61c64f9ca405d7f1.png" style="height: 19px;" /> can take quite a bit of time. Moreover, keeping <img alt="X^TX" class="valign-0" src="https://eli.thegreenplace.net/images/math/5c817c84ec1f83b23494df6125edd091a7c413dd.png" style="height: 15px;" /> in memory can be computationally infeasible if <img alt="X" class="valign-0" src="https://eli.thegreenplace.net/images/math/c032adc1ff629c9b66f22749ad667e6beadf144b.png" style="height: 12px;" /> is huge and sparse, but <img alt="X^TX" class="valign-0" src="https://eli.thegreenplace.net/images/math/5c817c84ec1f83b23494df6125edd091a7c413dd.png" style="height: 15px;" /> is dense. In all these cases, iterative gradient descent is a more practical approach.</p> <p>In addition, the moment we deviate from the linear regression a bit, such as adding nonlinear terms, regularization, or some other model enhancement, the analytical solutions no longer apply.
Gradient descent keeps working just the same, however, as long as we know how to compute the gradient of the new cost function.</p> <hr class="docutils" /> <table class="docutils footnote" frame="void" id="id6" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id1"></a></td><td>This data was generated by using a slope of 2.25, intercept of 6 and added Gaussian noise with a standard deviation of 1.5</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id7" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id2"></a></td><td>Some resources use SSE - the Squared Sum Error, which is just the MSE without the averaging. Yet others have <em>2n</em> in the denominator to make the gradient derivation cleaner. None of this really matters in practice. When finding the minimum analytically, we compare derivatives to zero so constant factors cancel out. 
When running gradient descent, all constant factors are subsumed into the learning rate, which is arbitrary.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id8" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id3"></a></td><td>For a mathematical justification for <em>why</em> the gradient leads us in the direction of most change, see <a class="reference external" href="http://eli.thegreenplace.net/2016/understanding-gradient-descent">this post</a>.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id9" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id4"></a></td><td>An alternative way I've seen this equation written is to express <em>m</em> as:</td></tr> </tbody> </table> <img alt="\begin{align*} m &amp;amp;= \frac{\sum_{i=1}^n(x_i-\bar{x})(y_i-\bar{y})}{\sum_{i=1}^n(x_i-\bar{x})^2} \\ &amp;amp;= \frac{cov(x, y)}{var(x)} \end{align*}" class="align-center" src="https://eli.thegreenplace.net/images/math/53639f1f77080dbe8a6d3a8cd06e08a90de69a8e.png" style="height: 92px;" /> <table class="docutils footnote" frame="void" id="id10" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id5"></a></td><td>Can you figure out why even the analytical solution is a little off from the actual parameters used to generate this data?</td></tr> </tbody> </table> </div> Understanding gradient descent2016-08-05T05:38:00-07:002016-08-05T05:38:00-07:00Eli Benderskytag:eli.thegreenplace.net,2016-08-05:/2016/understanding-gradient-descent/<p>Gradient descent is a standard tool for optimizing complex functions iteratively within a computer program. Its goal is: given some arbitrary function, find a minimum.
For some small subset of functions - those that are <em>convex</em> - there's just a single minimum, which also happens to be global. For most realistic functions …</p><p>Gradient descent is a standard tool for optimizing complex functions iteratively within a computer program. Its goal is: given some arbitrary function, find a minimum. For some small subset of functions - those that are <em>convex</em> - there's just a single minimum, which also happens to be global. For most realistic functions, there may be many minima, and most of them are local. Making sure the optimization finds the &quot;best&quot; minimum and doesn't get stuck in sub-optimal minima is outside the scope of this article. Here we'll just be dealing with the core gradient descent algorithm for finding <em>some</em> minimum from a given starting point.</p> <p>The main premise of gradient descent is: given some current location <em>x</em> in the search space (the domain of the optimized function) we ought to update <em>x</em> for the next step in the direction opposite to the gradient of the function computed at <em>x</em>. But <em>why</em> is this the case? The aim of this article is to explain why, mathematically.</p> <p>This is also the place for a disclaimer: the examples used throughout the article are trivial, low-dimensional, convex functions. We don't really <em>need</em> an algorithmic procedure to find their global minimum - a quick computation would do, or really just eyeballing the function's plot. In reality we will be dealing with non-linear, 1000-dimensional functions where it's utterly impossible to visualize anything, or solve anything analytically. The approach works just the same there, however.</p> <div class="section" id="building-intuition-with-single-variable-functions"> <h2>Building intuition with single-variable functions</h2> <p>The gradient is formally defined for <em>multivariate</em> functions.
However, to start building intuition, it's useful to begin with the two-dimensional case, a single-variable function <img alt="f(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/3e03f4706048fbc6c5a252a85d066adf107fcc1f.png" style="height: 18px;" />.</p> <p>In single-variable functions, the simple derivative plays the role of a gradient. So &quot;gradient descent&quot; would really be &quot;derivative descent&quot;; let's see what that means.</p> <p>As an example, let's take the function <img alt="f(x)=(x-1)^2" class="valign-m4" src="https://eli.thegreenplace.net/images/math/b898d66867ea1e832ab5cda94453ab3a69bae865.png" style="height: 19px;" />. Here's its plot, in red:</p> <img alt="Plot of parabola with tangent lines" class="align-center" src="https://eli.thegreenplace.net/images/2016/plot-parabola-with-tangents.png" /> <p>I marked a couple of points on the plot, in blue, and drew the tangents to the function at these points. Remember, our goal is to find the minimum of the function. To do that, we'll start with a guess for an <em>x</em>, and continuously update it to improve our guess based on some computation. How do we know how to update <em>x</em>? The update has only two possible directions: increase <em>x</em> or decrease <em>x</em>. We have to decide which of the two directions to take.</p> <p>We do that based on the derivative of <img alt="f(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/3e03f4706048fbc6c5a252a85d066adf107fcc1f.png" style="height: 18px;" />.
The derivative at some point <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" /> is defined as the limit <a class="footnote-reference" href="#id5" id="id1"></a>:</p> <img alt="$\frac{d}{dx}f(x_0)=\lim_{h \to 0}\frac{f(x_0+h)-f(x_0)}{h}$" class="align-center" src="https://eli.thegreenplace.net/images/math/bfd7f38f59e2ff0d548c19f8f780605b099ecaf7.png" style="height: 39px;" /> <p>Intuitively, this tells us what happens to <img alt="f(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/3e03f4706048fbc6c5a252a85d066adf107fcc1f.png" style="height: 18px;" /> when we add a very small value to <em>x</em>. For example, in the plot above, at <img alt="x_0=3" class="valign-m3" src="https://eli.thegreenplace.net/images/math/5fa44ff4e2c914452bf56041b4ef99ceb61592f9.png" style="height: 15px;" /> we have:</p> <img alt="\begin{align*} \frac{d}{dx}f(3)&amp;amp;=\lim_{h \to 0}\frac{f(3+h)-f(3)}{h} \\ &amp;amp;=\lim_{h \to 0}\frac{(3+h-1)^2-(3-1)^2}{h} \\ &amp;amp;=\lim_{h \to 0}\frac{h^2+4h}{h} \\ &amp;amp;=\lim_{h \to 0}h+4=4 \end{align*}" class="align-center" src="https://eli.thegreenplace.net/images/math/e572beffc8415b4ba4c8c9419105863e3ce2082f.png" style="height: 168px;" /> <p>This means that the value of <img alt="\frac{df}{dx}" class="valign-m6" src="https://eli.thegreenplace.net/images/math/45e7d07281bf1883224069f5b8d98a4bd6b21693.png" style="height: 23px;" /> at <img alt="x_0=3" class="valign-m3" src="https://eli.thegreenplace.net/images/math/5fa44ff4e2c914452bf56041b4ef99ceb61592f9.png" style="height: 15px;" /> - the <em>slope</em> of the tangent to <em>f</em> there - is 4; for a very small positive change <em>h</em> to <em>x</em> at that point, the value of <img alt="f(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/3e03f4706048fbc6c5a252a85d066adf107fcc1f.png" style="height: 18px;" /> will increase by <em>4h</em>.
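</p> <p>As a quick sanity check (an addition, not from the original post), the limit definition can be approximated numerically with a small finite difference; the step size <tt>h</tt> here is an arbitrary choice:</p>

```python
def f(x):
    return (x - 1) ** 2

def numeric_derivative(f, x0, h=1e-6):
    # Central-difference approximation of the limit definition of the
    # derivative, evaluated at x0.
    return (f(x0 + h) - f(x0 - h)) / (2 * h)

d3 = numeric_derivative(f, 3.0)  # very close to 4, matching the limit above
```

<p>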
Therefore, to get closer to the minimum of <img alt="f(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/3e03f4706048fbc6c5a252a85d066adf107fcc1f.png" style="height: 18px;" /> we should rather <em>decrease</em> <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" /> a bit.</p> <p>Let's take another example point, <img alt="x_0=-1" class="valign-m3" src="https://eli.thegreenplace.net/images/math/c84eef20ea61cf41b13fd1a157968eba20823c8e.png" style="height: 15px;" />. At that point, if we add a little bit to <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" />, <img alt="f(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/3e03f4706048fbc6c5a252a85d066adf107fcc1f.png" style="height: 18px;" /> will <em>decrease</em> by 4x that little bit. So that's exactly what we should do to get closer to the minimum.</p> <p>It turns out that in both cases, we should nudge <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" /> in the direction opposite to the derivative at <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" />. That's the most basic idea behind gradient descent - the derivative shows us the way to the minimum; or rather, it shows us the way to the maximum and we then go in the opposite direction. 
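</p> <p>In code, this nudging loop is just a few lines (a sketch; the starting point, the number of iterations, and the size of each nudge are arbitrary choices), minimizing f(x) = (x - 1)^2 from the plot above:</p>

```python
def df(x):
    # Derivative of f(x) = (x - 1)^2.
    return 2 * (x - 1)

x0, eta = 3.0, 0.1  # arbitrary starting guess and step-size factor
for _ in range(100):
    x0 -= eta * df(x0)  # nudge x0 opposite the derivative
# x0 is now very close to the true minimum at x = 1
```

<p>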
Given some initial guess <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" />, the next guess will be:</p> <img alt="$x_1=x_0-\eta\frac{d}{dx}f(x_0)$" class="align-center" src="https://eli.thegreenplace.net/images/math/d8666c1e2cf8740af45a228730f7c632fc00ed14.png" style="height: 37px;" /> <p>Where <img alt="\eta" class="valign-m4" src="https://eli.thegreenplace.net/images/math/2899aeb886ad0fa72652bffd5511e452aaf084ab.png" style="height: 12px;" /> is what we call a &quot;learning rate&quot;, and is constant for each given update. It's the reason why we don't care much about the magnitude of the derivative at <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" />, only its direction. In general, it makes sense to keep the learning rate fairly small so we only make a tiny step at a time. This makes sense mathematically, because the derivative at a point is defined as the rate of change of <img alt="f(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/3e03f4706048fbc6c5a252a85d066adf107fcc1f.png" style="height: 18px;" /> assuming an infinitesimal change in <em>x</em>. For a large change in <em>x</em>, who knows where we'd end up. It's easy to imagine cases where we'll entirely overshoot the minimum by making too large a step <a class="footnote-reference" href="#id6" id="id2"></a>.</p> </div> <div class="section" id="multivariate-functions-and-directional-derivatives"> <h2>Multivariate functions and directional derivatives</h2> <p>With functions of multiple variables, derivatives become more interesting. We can't just say &quot;the derivative points to where the function is increasing&quot;, because... which derivative?</p> <p>Recall the formal definition of the derivative as the limit for a small step <em>h</em>. 
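As a brief detour before generalizing, the single-variable update rule above is easy to try in code. A minimal sketch of mine (the starting point, learning rate and iteration count are arbitrary choices) for f(x) = (x - 1)^2, whose minimum is at x = 1:

```python
# Repeatedly apply the update x <- x - eta * f'(x) to f(x) = (x - 1)**2.
def df(x):
    return 2 * (x - 1)  # derivative of (x - 1)**2

x = 3.0     # initial guess x0
eta = 0.1   # learning rate; an arbitrary small value
for _ in range(100):
    x = x - eta * df(x)

print(x)  # converges toward 1.0, the minimizer
```

Each step multiplies the distance to the minimum by (1 - 2η), so with η = 0.1 the guess closes in on 1 geometrically.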
When our function has many variables, which one should have the step added? One at a time? All at once? In multivariate calculus, we use partial derivatives as building blocks. Let's use a function of two variables - <img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" /> as an example throughout this section, and define the partial derivatives w.r.t. <em>x</em> and <em>y</em> at some point <img alt="(x_0,y_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/f8b63792829adeff8314a72fa87be1a770dfca85.png" style="height: 18px;" />:</p> <img alt="\begin{align*} \frac{\partial }{\partial x}f(x_0,y_0)&amp;amp;=\lim_{h \to 0}\frac{f(x_0+h,y_0)-f(x_0,y_0)}{h} \\ \frac{\partial }{\partial y}f(x_0,y_0)&amp;amp;=\lim_{h \to 0}\frac{f(x_0,y_0+h)-f(x_0,y_0)}{h} \end{align*}" class="align-center" src="https://eli.thegreenplace.net/images/math/b58dd3cada7292828cf79f3ca8653a99fd94c1f9.png" style="height: 87px;" /> <p>When we have a single-variable function <img alt="f(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/3e03f4706048fbc6c5a252a85d066adf107fcc1f.png" style="height: 18px;" />, there are really only two directions in which we can move from a given point <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" /> - left (decrease <em>x</em>) or right (increase <em>x</em>). With two variables, the number of possible directions is <em>infinite</em>, because we pick a direction to move on a 2D plane. Hopefully this immediately brings &quot;vectors&quot; to mind, since vectors are the perfect tool to deal with such problems. 
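As an aside, the two partial-derivative limits defined above are easy to approximate numerically. A small sketch of mine (with an arbitrary step h), using f(x, y) = x^2 + y^2, the example function plotted later in the post:

```python
# Approximate the two partial derivatives of f(x, y) = x**2 + y**2
# by stepping along one axis at a time, as in the limit definitions.
def f(x, y):
    return x ** 2 + y ** 2

def partial_x(f, x0, y0, h=1e-6):
    return (f(x0 + h, y0) - f(x0, y0)) / h

def partial_y(f, x0, y0, h=1e-6):
    return (f(x0, y0 + h) - f(x0, y0)) / h

print(partial_x(f, 1.0, 2.0), partial_y(f, 1.0, 2.0))  # close to 2 and 4
```

The exact partials are 2x and 2y, so at (1, 2) the approximations land near 2 and 4.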
We can represent the change from the point <img alt="(x_0,y_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/f8b63792829adeff8314a72fa87be1a770dfca85.png" style="height: 18px;" /> as the vector <img alt="\vec{v}=\langle a,b \rangle" class="valign-m5" src="https://eli.thegreenplace.net/images/math/4ef7c8a835491ba5ec6dc5f2b94ebff879938a21.png" style="height: 19px;" /> <a class="footnote-reference" href="#id7" id="id3"></a>. The <em>directional derivative</em> of <img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" /> along <img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" /> at <img alt="(x_0,y_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/f8b63792829adeff8314a72fa87be1a770dfca85.png" style="height: 18px;" /> is defined as its rate of change in the direction of the vector at that point. Mathematically, it's defined as:</p> <img alt="$\begin{equation} D_{\vec{v}}f(x_0,y_0)=\lim_{h \to 0}\frac{f(x_0+ah,y_0+bh)-f(x_0,y_0)}{h} \tag{1} \end{equation}$" class="align-center" src="https://eli.thegreenplace.net/images/math/1af5afd7427f744daa0c75b05697b32f21b2f40c.png" style="height: 39px;" /> <p>The partial derivatives w.r.t. <em>x</em> and <em>y</em> can be seen as special cases of this definition. The partial derivative <img alt="\frac{\partial f}{\partial x}" class="valign-m7" src="https://eli.thegreenplace.net/images/math/75a2ab078215106a1084cf5e262e98f32c1cc3b9.png" style="height: 25px;" /> is just the directional derivative in the direction of the <em>x</em> axis. 
In vector-speak, this is the directional derivative for <img alt="\vec{v}=\langle a,b \rangle=\widehat{e_x}=\langle 1,0 \rangle" class="valign-m5" src="https://eli.thegreenplace.net/images/math/36b4fd6cf884fd12b36c605cb6ec7a7c9b4ee65f.png" style="height: 19px;" />, the standard basis vector for <em>x</em>. Just plug <img alt="a=1,b=0" class="valign-m4" src="https://eli.thegreenplace.net/images/math/7feadfc4043894ed6a3de2cced949a91bea9e5b2.png" style="height: 17px;" /> into (1) to see why. Similarly, the partial derivative <img alt="\frac{\partial f}{\partial y}" class="valign-m9" src="https://eli.thegreenplace.net/images/math/5bc3d10d9714f8f7a95791fe29e497cf0ecbe3b0.png" style="height: 27px;" /> is the directional derivative in the direction of the standard basis vector <img alt="\widehat{e_y}=\langle 0,1 \rangle" class="valign-m6" src="https://eli.thegreenplace.net/images/math/3ce4793144c7bfd02d245b81f8bd44a595721196.png" style="height: 20px;" />.</p> </div> <div class="section" id="a-visual-interlude"> <h2>A visual interlude</h2> <p>Functions of two variables <img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" /> are the last frontier for meaningful visualizations, for which we need 3D to plot the value of <img alt="f" class="valign-m4" src="https://eli.thegreenplace.net/images/math/4a0a19218e082a343a1b17e5333409af9d98f0f5.png" style="height: 16px;" /> for each given <em>x</em> and <em>y</em>. Even in 3D, visualizing gradients is significantly harder than in 2D, and yet we have to try since for anything above two variables all hopes of visualization are lost.</p> <p>Here's the function <img alt="f(x,y)=x^2+y^2" class="valign-m4" src="https://eli.thegreenplace.net/images/math/d3eb0fc536d00e84cd63bb5af98b7e2bc01bde4f.png" style="height: 19px;" /> plotted in a small range around zero. 
I drew the standard basis vectors <img alt="\widehat{x}=\widehat{e_x}" class="valign-m3" src="https://eli.thegreenplace.net/images/math/0ea0752aa73540ee1e464a42d5d1b2b9741d3eab.png" style="height: 17px;" /> and <img alt="\widehat{y}=\widehat{e_y}" class="valign-m6" src="https://eli.thegreenplace.net/images/math/c0bf47cb98b1f01e6b47992929694ec9da20f8f7.png" style="height: 20px;" /> <a class="footnote-reference" href="#id8" id="id4"></a> and some combination of them <img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" />.</p> <img alt="3D parabola with direction vector markers" class="align-center" src="https://eli.thegreenplace.net/images/2016/plot-3d-parabola.png" /> <p>I also marked the point on <img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" /> where the vectors are based. The goal is to help us keep in mind how the independent variables <em>x</em> and <em>y</em> change, and how that affects <img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" />. We change <em>x</em> and <em>y</em> by adding some small vector <img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" /> to their current value. The result is &quot;nudging&quot; <img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" /> in the direction of <img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" />. 
Remember our goal for this article - find <img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" /> such that this &quot;nudge&quot; gets us closer to a minimum.</p> </div> <div class="section" id="finding-directional-derivatives-using-the-gradient"> <h2>Finding directional derivatives using the gradient</h2> <p>As we've seen, the derivative of <img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" /> in the direction of <img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" /> is defined as:</p> <img alt="$D_{\vec{v}}f(x_0,y_0)=\lim_{h \to 0}\frac{f(x_0+ah,y_0+bh)-f(x_0,y_0)}{h}$" class="align-center" src="https://eli.thegreenplace.net/images/math/9f2c62d64f016bd77712873294a0f5e64537b1ab.png" style="height: 39px;" /> <p>Looking at the 3D plot above, this is how much the value of <img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" /> changes when we add <img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" /> to the vector <img alt="\langle x_0,y_0 \rangle" class="valign-m5" src="https://eli.thegreenplace.net/images/math/f74aa2c6fda35535931fad69ec339eaef3693913.png" style="height: 19px;" />. But how do we do that? This limit definition doesn't lend itself to easy analytical manipulation for arbitrary functions. 
Sure, we could plug <img alt="(x_0,y_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/f8b63792829adeff8314a72fa87be1a770dfca85.png" style="height: 18px;" /> and <img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" /> in there and do the computation, but it would be nice to have an easier-to-use formula. Luckily, with the help of the gradient of <img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" /> it becomes much easier.</p> <p>The gradient is a vector value we compute from a scalar function. It's defined as:</p> <img alt="$\nabla f=\left \langle \frac{\partial f}{\partial x},\frac{\partial f}{\partial y} \right \rangle$" class="align-center" src="https://eli.thegreenplace.net/images/math/03eab64984be412b6db132c2534bbecc006af47c.png" style="height: 43px;" /> <p>It turns out that given a vector <img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" />, the directional derivative <img alt="D_{\vec{v}}f" class="valign-m4" src="https://eli.thegreenplace.net/images/math/03a3931c968b3b6f26e82958785539d74db94293.png" style="height: 16px;" /> can be expressed as the following dot product:</p> <img alt="$D_{\vec{v}}f=(\nabla f) \cdot \vec{v}$" class="align-center" src="https://eli.thegreenplace.net/images/math/49933775272512c4c8686d9f9692c8ea01e1c97d.png" style="height: 18px;" /> <p>If this looks like a mental leap too big to trust, please read the Appendix section at the bottom. Otherwise, feel free to verify that the two are equivalent with a couple of examples. 
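One way to run such a check is numerically. Here's a sketch of mine for f(x, y) = x^2 + y^2, with an arbitrarily chosen point and unit-length direction:

```python
# Compare the limit-definition directional derivative with the
# gradient dot product for f(x, y) = x**2 + y**2.
def f(x, y):
    return x ** 2 + y ** 2

x0, y0 = 1.0, 2.0
a, b = 0.6, 0.8  # a**2 + b**2 == 1, so v = <a, b> is normalized

# Straight from the limit definition (1), with a small finite h.
h = 1e-6
from_limit = (f(x0 + a * h, y0 + b * h) - f(x0, y0)) / h

# Via the gradient: grad f = <2x, 2y>, dotted with v.
from_gradient = 2 * x0 * a + 2 * y0 * b

print(from_limit, from_gradient)  # both close to 4.4
```

Both routes agree (up to the finite-step approximation error), which is exactly what the dot-product formula promises.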
For instance, try to find the derivative in the direction of <img alt="\vec{v}=\langle \frac{1}{\sqrt{2}},\frac{1}{\sqrt{2}} \rangle" class="valign-m11" src="https://eli.thegreenplace.net/images/math/d614069c5beaf6fb858de40fa492a7b523a683d9.png" style="height: 27px;" /> at <img alt="(x_0,y_0)=(-1.5,0.25)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/61355565f13944faf85baec62c5fc1a682b0b5d5.png" style="height: 18px;" />. You should get <img alt="\frac{-2.5}{\sqrt{2}}" class="valign-m11" src="https://eli.thegreenplace.net/images/math/0c22fc563236a48f94882876c68f6edc0c3fb4da.png" style="height: 27px;" /> using both methods.</p> </div> <div class="section" id="direction-of-maximal-change"> <h2>Direction of maximal change</h2> <p>We're almost there! Now that we have a relatively simple way of computing any directional derivative from the partial derivatives of a function, we can figure out which direction to take to get the maximal change in the value of <img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" />.</p> <p>We can rewrite:</p> <img alt="$D_{\vec{v}}f=(\nabla f) \cdot \vec{v}$" class="align-center" src="https://eli.thegreenplace.net/images/math/49933775272512c4c8686d9f9692c8ea01e1c97d.png" style="height: 18px;" /> <p>As:</p> <img alt="$D_{\vec{v}}f=\left \| \nabla f \right \| \left \| \vec{v} \right \| cos(\theta)$" class="align-center" src="https://eli.thegreenplace.net/images/math/8227de3117c60690ced3153cdc38d9bccd960fba.png" style="height: 19px;" /> <p>Where <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> is the angle between the two vectors. 
Now, recall that <img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" /> is normalized so its magnitude is 1. Therefore, we only care about the <em>direction</em> of <img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" /> w.r.t. the gradient. When is this equation maximized? When <img alt="\theta=0" class="valign-0" src="https://eli.thegreenplace.net/images/math/a1dffbe89f1ec5a919198de979fca459eb7fdf84.png" style="height: 12px;" />, because then <img alt="cos(\theta)=1" class="valign-m4" src="https://eli.thegreenplace.net/images/math/66a6eb87ec7f340e2e24bd46cdf02ab050013aac.png" style="height: 18px;" />. Since a cosine can never be larger than 1, that's the best we can have.</p> <p>So <img alt="\theta=0" class="valign-0" src="https://eli.thegreenplace.net/images/math/a1dffbe89f1ec5a919198de979fca459eb7fdf84.png" style="height: 12px;" /> gives us the largest positive change in <img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" />. To get <img alt="\theta=0" class="valign-0" src="https://eli.thegreenplace.net/images/math/a1dffbe89f1ec5a919198de979fca459eb7fdf84.png" style="height: 12px;" />, <img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" /> has to point in the same direction as the gradient. 
Similarly, for <img alt="\theta=180^{\circ}" class="valign-m1" src="https://eli.thegreenplace.net/images/math/f35bd3cc416e154fddabe833458147c566028a8c.png" style="height: 13px;" /> we get <img alt="cos(\theta)=-1" class="valign-m4" src="https://eli.thegreenplace.net/images/math/65b96d5ab442e325098894e80d655263a24b14d6.png" style="height: 18px;" /> and therefore the largest <em>negative</em> change in <img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" />. So if we want to decrease <img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" /> the most, <img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" /> has to point in the opposite direction of the gradient.</p> </div> <div class="section" id="gradient-descent-update-for-multivariate-functions"> <h2>Gradient descent update for multivariate functions</h2> <p>To sum up, given some starting point <img alt="(x_0,y_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/f8b63792829adeff8314a72fa87be1a770dfca85.png" style="height: 18px;" />, to nudge it in the direction of the minimum of <img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" />, we first compute the gradient of <img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" /> at <img alt="(x_0,y_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/f8b63792829adeff8314a72fa87be1a770dfca85.png" style="height: 18px;" />. 
Then, we update (using vector notation):</p> <img alt="$\langle x_1,y_1 \rangle=\langle x_0,y_0 \rangle-\eta \nabla{f(x_0,y_0)}$" class="align-center" src="https://eli.thegreenplace.net/images/math/66a0a92b6ff9a4c0d2162a41484ab17115f57bd7.png" style="height: 19px;" /> <p>Generalizing to multiple dimensions, let's say we have the function <img alt="f:\mathbb{R}^n\rightarrow \mathbb{R}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/5b4aba3ea35b9daec583b61ecb5a556ae28103e3.png" style="height: 16px;" /> taking the n-dimensional vector <img alt="\vec{x}=(x_1,x_2 \dots ,x_n)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e8ece11f27b7cf7726e6ea055cfb0718761733e0.png" style="height: 18px;" />. We define the gradient update at step <em>k</em> to be:</p> <img alt="$\vec{x}_{(k)}=\vec{x}_{(k-1)} - \eta \nabla{f(\vec{x}_{(k-1)})}$" class="align-center" src="https://eli.thegreenplace.net/images/math/265d53b7832258e30f00049a1772e9f213140628.png" style="height: 20px;" /> <p>Previously, for the single-variable case we said that the derivative points the way to the minimum. Now we can say that while there are many ways to get to the minimum (eventually), the gradient points us to the <em>fastest</em> way from any given point.</p> </div> <div class="section" id="appendix-directional-derivative-definition-and-gradient"> <h2>Appendix: directional derivative definition and gradient</h2> <p>This is an optional section for those who don't like taking mathematical statements for granted. Now it's time to prove the equation shown earlier in the article, and on which its main result is based:</p> <img alt="$D_{\vec{v}}f=(\nabla f) \cdot \vec{v}$" class="align-center" src="https://eli.thegreenplace.net/images/math/49933775272512c4c8686d9f9692c8ea01e1c97d.png" style="height: 18px;" /> <p>As usual with proofs, it really helps to start by working through an example or two to build up some intuition into why the equation works. 
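Before diving into the proof, the multivariate update rule from the previous section can be sketched in a few lines (my code; the function, starting point, learning rate and step count are arbitrary choices), here for f(x, y) = x^2 + y^2, whose minimum sits at the origin:

```python
# Multivariate gradient descent: x_k = x_{k-1} - eta * grad f(x_{k-1}),
# applied to f(x, y) = x**2 + y**2.
def grad_f(x, y):
    return (2 * x, 2 * y)  # gradient of x**2 + y**2

x, y = -1.5, 0.25  # arbitrary starting point
eta = 0.1          # learning rate
for _ in range(100):
    gx, gy = grad_f(x, y)
    x, y = x - eta * gx, y - eta * gy

print(x, y)  # both approach 0, the minimum of f
```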
Feel free to do that if you'd like, using any function, starting point and direction vector <img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" />.</p> <p>Suppose we define a function <img alt="w(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/0382ffc90ae7b4c24238f68a32bebd14bc53c8d7.png" style="height: 18px;" /> as follows:</p> <img alt="$w(t)=f(x,y)$" class="align-center" src="https://eli.thegreenplace.net/images/math/dc37eb3cf47966d7338e561faffeffbb291085c5.png" style="height: 18px;" /> <p>Where <img alt="x=x(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/97aeb925cf8f501cc8836794ee06fb357b9d9a83.png" style="height: 18px;" /> and <img alt="y=y(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ebacc26a97fccf1aa96e1b59f21fcb2ca66c8924.png" style="height: 18px;" /> defined as:</p> <img alt="\begin{align*} x(t)&amp;amp;=x_0+at \\ y(t)&amp;amp;=y_0+bt \end{align*}" class="align-center" src="https://eli.thegreenplace.net/images/math/27988a5772de0fe761873494e88f7cad887ede85.png" style="height: 45px;" /> <p>In these definitions, <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" />, <img alt="y_0" class="valign-m4" src="https://eli.thegreenplace.net/images/math/2bb5817d0f3bf8490a8c7b1343f84f9635e683a3.png" style="height: 12px;" />, <em>a</em> and <em>b</em> are constants, so both <img alt="x(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/62b10cd9e1396c7ea33fd211e67de2fb29019cfc.png" style="height: 18px;" /> and <img alt="y(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ed8576b7227103b62d3648e7d1bbdff4052b27ff.png" style="height: 18px;" /> are truly functions of a single variable. 
Using <a class="reference external" href="http://eli.thegreenplace.net/2016/the-chain-rule-of-calculus">the chain rule</a>, we know that:</p> <img alt="$\frac{dw}{dt}=\frac{\partial f}{\partial x}\frac{dx}{dt}+\frac{\partial f}{\partial y}\frac{dy}{dt}$" class="align-center" src="https://eli.thegreenplace.net/images/math/d5f4f13aeba35328cd2bea9b247842acb7524724.png" style="height: 41px;" /> <p>Substituting the derivatives of <img alt="x(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/62b10cd9e1396c7ea33fd211e67de2fb29019cfc.png" style="height: 18px;" /> and <img alt="y(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ed8576b7227103b62d3648e7d1bbdff4052b27ff.png" style="height: 18px;" />, we get:</p> <img alt="$\frac{dw}{dt}=a\frac{\partial f}{\partial x}+b\frac{\partial f}{\partial y}$" class="align-center" src="https://eli.thegreenplace.net/images/math/829069469d88717c9d95e3f788ed9e0c6cbeebc6.png" style="height: 41px;" /> <p>One more step, the significance of which will become clear shortly. 
Specifically, the derivative of <img alt="w(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/0382ffc90ae7b4c24238f68a32bebd14bc53c8d7.png" style="height: 18px;" /> at <img alt="t=0" class="valign-0" src="https://eli.thegreenplace.net/images/math/31056375cdff6a052261f18ceb3afe466731302a.png" style="height: 12px;" /> is:</p> <img alt="$\begin{equation} \frac{d}{dt}w(0)=a\frac{\partial}{\partial x}f(x_0,y_0)+b\frac{\partial}{\partial y}f(x_0,y_0) \tag{2} \end{equation}$" class="align-center" src="https://eli.thegreenplace.net/images/math/ea579cad8f6c62a817f2253e1d596178ea673d37.png" style="height: 41px;" /> <p>Now let's see how to compute the derivative of <img alt="w(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/0382ffc90ae7b4c24238f68a32bebd14bc53c8d7.png" style="height: 18px;" /> at <img alt="t=0" class="valign-0" src="https://eli.thegreenplace.net/images/math/31056375cdff6a052261f18ceb3afe466731302a.png" style="height: 12px;" /> using the formal limit definition:</p> <img alt="\begin{align*} \frac{d}{dt}w(0)&amp;amp;=\lim_{h \to 0}\frac{w(h)-w(0)}{h} \\ &amp;amp;=\lim_{h \to 0}\frac{f(x_0+ah,y_0+bh)-f(x_0,y_0)}{h} \end{align*}" class="align-center" src="https://eli.thegreenplace.net/images/math/10a224da7b7ab2424b9f88edcbfe17f273f3bd8b.png" style="height: 84px;" /> <p>But the latter is precisely the definition of the directional derivative in equation (1). 
Therefore, we can say that:</p> <img alt="$\frac{d}{dt}w(0)=D_{\vec{v}}f(x_0,y_0)$" class="align-center" src="https://eli.thegreenplace.net/images/math/4f377110022c468e46cbdb32bfb11a072d11b330.png" style="height: 37px;" /> <p>From this and (2), we get:</p> <img alt="$\frac{d}{dt}w(0)=D_{\vec{v}}f(x_0,y_0)=a\frac{\partial}{\partial x}f(x_0,y_0)+b\frac{\partial}{\partial y}f(x_0,y_0)$" class="align-center" src="https://eli.thegreenplace.net/images/math/d259ebf3697f480823a40247ce7191f9e954a584.png" style="height: 41px;" /> <p>This derivation is not special to the point <img alt="(x_0,y_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/f8b63792829adeff8314a72fa87be1a770dfca85.png" style="height: 18px;" /> - it works just as well for any point where <img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" /> has partial derivatives w.r.t. <em>x</em> and <em>y</em>; therefore, for any point <img alt="(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/d330d6e65470cb03e76e092ee47971f9e931f759.png" style="height: 18px;" /> where <img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" /> is differentiable:</p> <img alt="\begin{align*} D_{\vec{v}}f(x,y)&amp;amp;=a\frac{\partial}{\partial x}f(x,y)+b\frac{\partial}{\partial y}f(x,y) \\ &amp;amp;=\left \langle \frac{\partial f}{\partial x},\frac{\partial f}{\partial y} \right \rangle \cdot \langle a,b \rangle \\ &amp;amp;=(\nabla f) \cdot \vec{v} \qedhere \end{align*}" class="align-center" src="https://eli.thegreenplace.net/images/math/7c306dfedd474d99a62894e6258cea186d8be428.png" style="height: 115px;" /> <hr class="docutils" /> <table class="docutils footnote" frame="void" id="id5" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a 
class="fn-backref" href="#id1"></a></td><td>The notation <img alt="\frac{d}{dx}f(x_0)" class="valign-m6" src="https://eli.thegreenplace.net/images/math/b0d6f765abf215972d5dbb982f77f1a83c233066.png" style="height: 22px;" /> means: the value of the derivative of <img alt="f" class="valign-m4" src="https://eli.thegreenplace.net/images/math/4a0a19218e082a343a1b17e5333409af9d98f0f5.png" style="height: 16px;" /> w.r.t. <em>x</em>, evaluated at <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" />. Another way to say the same would be <img alt="f{}&amp;#x27;(x_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e11c4ee90d42c3261aec6ef9c71893411b11cf34.png" style="height: 18px;" />.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id6" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id2"></a></td><td>That said, in some advanced variations of gradient descent we actually want to probe different areas of the function early on in the process, so a larger step makes sense (remember, realistic functions have many local minima and we want to find the best one). Further along in the optimization process, when we've settled on a general area of the function we want the learning rate to be small so we actually get to the minimum. This approach is called <em>annealing</em> and I'll leave it for some future article.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id7" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id3"></a></td><td>To avoid tracking vector magnitudes, from now on in the article we'll be dealing with <em>normalized</em> direction vectors. 
That is, we always assume that <img alt="\left \| \vec{v} \right \|=1" class="valign-m5" src="https://eli.thegreenplace.net/images/math/d68cb9ca8e7b5fd7fe4a7c4548ed5d98b63292eb.png" style="height: 19px;" />.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id8" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id4"></a></td><td>Yes, <img alt="\widehat{y}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/8cf4f01720ca8008752c182a8d3443aa2b174442.png" style="height: 18px;" /> is actually going in the opposite direction so it's <img alt="-\widehat{e_y}" class="valign-m6" src="https://eli.thegreenplace.net/images/math/160a7a02c9645a3948812151b7a0cf38eb29c562.png" style="height: 20px;" />, but that really doesn't change anything. It was easier to draw :)</td></tr> </tbody> </table> </div> Broadcasting arrays in Numpy2015-12-22T06:00:00-08:002015-12-22T06:00:00-08:00Eli Benderskytag:eli.thegreenplace.net,2015-12-22:/2015/broadcasting-arrays-in-numpy/<p><em>Broadcasting</em> is Numpy's terminology for performing mathematical operations between arrays with different shapes. This article will explain why broadcasting is useful, how to use it and touch upon some of its performance implications.</p> <div class="section" id="motivating-example"> <h2>Motivating example</h2> <p>Say we have a large data set; each datum is a list of parameters. In …</p></div><p><em>Broadcasting</em> is Numpy's terminology for performing mathematical operations between arrays with different shapes. This article will explain why broadcasting is useful, how to use it and touch upon some of its performance implications.</p> <div class="section" id="motivating-example"> <h2>Motivating example</h2> <p>Say we have a large data set; each datum is a list of parameters. 
In Numpy terms, we have a 2-D array, where each row is a datum and the number of rows is the size of the data set. Suppose we want to apply some sort of scaling to all these data - every parameter gets its own scaling factor; in other words, every parameter is multiplied by some factor.</p> <p>Just to have something tangible to think about, let's count calories in foods using a macro-nutrient breakdown. Roughly put, the caloric parts of food are made of fats (9 calories per gram), protein (4 calories per gram) and carbs (4 calories per gram). So if we list some foods (our data), and for each food list its macro-nutrient breakdown (parameters), we can then multiply each nutrient by its caloric value (apply scaling) to compute the caloric breakdown of each food item <a class="footnote-reference" href="#id6" id="id1"></a>:</p> <img alt="Calories macros" class="align-center" src="https://eli.thegreenplace.net/images/2015/cal-data.png" /> <p>With this transformation, we can now compute all kinds of useful information. For example, what is the total number of calories in some food. Or, given a breakdown of my dinner - how many calories did I get from protein. 
And so on.</p> <p>Let's see a naive way of producing this computation with Numpy:</p> <div class="highlight"><pre><span></span><span class="n">In</span> <span class="p">[</span><span class="mi">65</span><span class="p">]:</span> <span class="n">macros</span> <span class="o">=</span> <span class="n">array</span><span class="p">([</span> <span class="p">[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">2.5</span><span class="p">,</span> <span class="mf">3.5</span><span class="p">],</span> <span class="p">[</span><span class="mf">2.9</span><span class="p">,</span> <span class="mf">27.5</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.4</span><span class="p">,</span> <span class="mf">1.3</span><span class="p">,</span> <span class="mf">23.9</span><span class="p">],</span> <span class="p">[</span><span class="mf">14.4</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mf">2.3</span><span class="p">]])</span> <span class="c1"># Create a new array filled with zeros, of the same shape as macros.</span> <span class="n">In</span> <span class="p">[</span><span class="mi">67</span><span class="p">]:</span> <span class="n">result</span> <span class="o">=</span> <span class="n">zeros_like</span><span class="p">(</span><span class="n">macros</span><span class="p">)</span> <span class="n">In</span> <span class="p">[</span><span class="mi">69</span><span class="p">]:</span> <span class="n">cal_per_macro</span> <span class="o">=</span> <span class="n">array</span><span class="p">([</span><span class="mi">9</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">])</span> <span class="c1"># Now multiply each row of macros by cal_per_macro. 
In Numpy, * is</span> <span class="c1"># element-wise multiplication between two arrays.</span> <span class="n">In</span> <span class="p">[</span><span class="mi">70</span><span class="p">]:</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">xrange</span><span class="p">(</span><span class="n">macros</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span> <span class="o">....</span><span class="p">:</span> <span class="n">result</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">macros</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">cal_per_macro</span> <span class="o">....</span><span class="p">:</span> <span class="n">In</span> <span class="p">[</span><span class="mi">71</span><span class="p">]:</span> <span class="n">result</span> <span class="n">Out</span><span class="p">[</span><span class="mi">71</span><span class="p">]:</span> <span class="n">array</span><span class="p">([[</span> <span class="mf">2.7</span><span class="p">,</span> <span class="mf">10.</span> <span class="p">,</span> <span class="mf">14.</span> <span class="p">],</span> <span class="p">[</span> <span class="mf">26.1</span><span class="p">,</span> <span class="mf">110.</span> <span class="p">,</span> <span class="mf">0.</span> <span class="p">],</span> <span class="p">[</span> <span class="mf">3.6</span><span class="p">,</span> <span class="mf">5.2</span><span class="p">,</span> <span class="mf">95.6</span><span class="p">],</span> <span class="p">[</span> <span class="mf">129.6</span><span class="p">,</span> <span class="mf">24.</span> <span class="p">,</span> <span class="mf">9.2</span><span class="p">]])</span> </pre></div> <p>This is a reasonable approach when coding in a 
low-level programming language: allocate the output, loop over input performing some operation, write result into output. In Numpy, however, this is fairly bad for performance because the looping is done in (slow) Python code instead of internally by Numpy in (fast) C code.</p> <p>Since element-wise operators like <tt class="docutils literal">*</tt> work on arbitrary shapes, a better way would be to delegate all the looping to Numpy, by &quot;stretching&quot; the <tt class="docutils literal">cal_per_macro</tt> array vertically and then performing element-wise multiplication with <tt class="docutils literal">macros</tt>; this moves the per-row loop from above into Numpy itself, where it can run much more efficiently:</p> <div class="highlight"><pre><span></span><span class="c1"># Use the &#39;tile&#39; function to replicate cal_per_macro over the number</span> <span class="c1"># of rows &#39;macros&#39; has (rows is the first element of the shape tuple for</span> <span class="c1"># a 2-D array).</span> <span class="n">In</span> <span class="p">[</span><span class="mi">72</span><span class="p">]:</span> <span class="n">cal_per_macro_stretch</span> <span class="o">=</span> <span class="n">tile</span><span class="p">(</span><span class="n">cal_per_macro</span><span class="p">,</span> <span class="p">(</span><span class="n">macros</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="mi">1</span><span class="p">))</span> <span class="n">In</span> <span class="p">[</span><span class="mi">73</span><span class="p">]:</span> <span class="n">cal_per_macro_stretch</span> <span class="n">Out</span><span class="p">[</span><span class="mi">73</span><span class="p">]:</span> <span class="n">array</span><span class="p">([[</span><span class="mi">9</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">],</span> <span 
class="p">[</span><span class="mi">9</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">],</span> <span class="p">[</span><span class="mi">9</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">],</span> <span class="p">[</span><span class="mi">9</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">]])</span> <span class="n">In</span> <span class="p">[</span><span class="mi">74</span><span class="p">]:</span> <span class="n">macros</span> <span class="o">*</span> <span class="n">cal_per_macro_stretch</span> <span class="n">Out</span><span class="p">[</span><span class="mi">74</span><span class="p">]:</span> <span class="n">array</span><span class="p">([[</span> <span class="mf">2.7</span><span class="p">,</span> <span class="mf">10.</span> <span class="p">,</span> <span class="mf">14.</span> <span class="p">],</span> <span class="p">[</span> <span class="mf">26.1</span><span class="p">,</span> <span class="mf">110.</span> <span class="p">,</span> <span class="mf">0.</span> <span class="p">],</span> <span class="p">[</span> <span class="mf">3.6</span><span class="p">,</span> <span class="mf">5.2</span><span class="p">,</span> <span class="mf">95.6</span><span class="p">],</span> <span class="p">[</span> <span class="mf">129.6</span><span class="p">,</span> <span class="mf">24.</span> <span class="p">,</span> <span class="mf">9.2</span><span class="p">]])</span> </pre></div> <p>Nice, it's shorter too. And much, much faster! To measure the speed I created a large random data set, with 1 million rows of 10 parameters each. The loop-in-Python method takes ~2.3 seconds to churn through it. The stretching method takes 30 <em>milliseconds</em>, a ~75x speedup.</p> <p>And now, finally, comes the interesting part. 
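As an aside, the timing comparison above can be reproduced with a sketch along the following lines. The function names here are made up for the sketch, and the data set is smaller than the article's (100,000 rows instead of 1 million) so it runs quickly; exact numbers vary by machine and Numpy version:

```python
import timeit

import numpy as np

# Illustrative data set: 100,000 rows of 10 parameters each, plus a
# 10-element vector of scaling factors.
data = np.random.rand(100_000, 10)
factors = np.random.rand(10)

def scale_with_loop(data, factors):
    # Loop over the rows in Python - slow.
    result = np.zeros_like(data)
    for i in range(data.shape[0]):
        result[i, :] = data[i, :] * factors
    return result

def scale_with_stretch(data, factors):
    # Stretch factors vertically with tile, then multiply element-wise;
    # all the looping happens inside Numpy.
    return data * np.tile(factors, (data.shape[0], 1))

loop_t = timeit.timeit(lambda: scale_with_loop(data, factors), number=1)
stretch_t = timeit.timeit(lambda: scale_with_stretch(data, factors), number=1)
print('loop: {:.4f}s, stretch: {:.4f}s'.format(loop_t, stretch_t))
```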
You see, the operation we just did - stretching one array so that its shape matches that of another and then applying some element-wise operation between them - is actually pretty common. This often happens when we want to take a lower-dimensional array and use it to perform a computation along some axis of a higher-dimensional array. In fact, when taken to the extreme this is exactly what happens when we perform an operation between an array and a scalar - the scalar is <em>stretched</em> across the whole array so that the element-wise operation gets the same scalar value for each element it computes.</p> <p>Numpy generalizes this concept into <em>broadcasting</em> - a set of rules that permit element-wise computations between arrays of different shapes, as long as some constraints apply. We'll discuss the actual constraints later, but for the case at hand a simple example will suffice: our original <tt class="docutils literal">macros</tt> array is 4x3 (4 rows by 3 columns). <tt class="docutils literal">cal_per_macro</tt> is a 3-element array. Since its length matches the number of columns in <tt class="docutils literal">macros</tt>, it's pretty natural to apply some operation between <tt class="docutils literal">cal_per_macro</tt> and every row of <tt class="docutils literal">macros</tt> - each row of <tt class="docutils literal">macros</tt> has the exact same size as <tt class="docutils literal">cal_per_macro</tt>, so the element-wise operation makes perfect sense.</p> <p>Incidentally, this lets Numpy achieve two separate goals - usefulness as well as more consistent and general semantics. Binary operators like <tt class="docutils literal">*</tt> are element-wise, but what happens when we apply them between arrays of different shapes? Should it work or should it be rejected? If it works, how should it work? Broadcasting defines the semantics of these operations.</p> <p>Back to our example. 
Here's yet another way to compute the result data:</p> <div class="highlight"><pre><span></span><span class="n">In</span> <span class="p">[</span><span class="mi">75</span><span class="p">]:</span> <span class="n">macros</span> <span class="o">*</span> <span class="n">cal_per_macro</span> <span class="n">Out</span><span class="p">[</span><span class="mi">75</span><span class="p">]:</span> <span class="n">array</span><span class="p">([[</span> <span class="mf">2.7</span><span class="p">,</span> <span class="mf">10.</span> <span class="p">,</span> <span class="mf">14.</span> <span class="p">],</span> <span class="p">[</span> <span class="mf">26.1</span><span class="p">,</span> <span class="mf">110.</span> <span class="p">,</span> <span class="mf">0.</span> <span class="p">],</span> <span class="p">[</span> <span class="mf">3.6</span><span class="p">,</span> <span class="mf">5.2</span><span class="p">,</span> <span class="mf">95.6</span><span class="p">],</span> <span class="p">[</span> <span class="mf">129.6</span><span class="p">,</span> <span class="mf">24.</span> <span class="p">,</span> <span class="mf">9.2</span><span class="p">]])</span> </pre></div> <p>Simple and elegant, and the fastest approach to boot <a class="footnote-reference" href="#id7" id="id2"></a>.</p> </div> <div class="section" id="defining-broadcasting"> <h2>Defining broadcasting</h2> <p>Broadcasting is often described as an operation between a &quot;smaller&quot; and a &quot;larger&quot; array. This doesn't necessarily have to be the case, as broadcasting applies also to arrays of the same size, though with different shapes. Therefore, I believe the following definition of broadcasting is the most useful one.</p> <p>Element-wise operations on arrays are only valid when the arrays' shapes are either equal or compatible. The equal shapes case is trivial - this is the stretched array from the example above. 
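As a minimal sketch of the equal-shapes case (the values here are arbitrary): two arrays with exactly the same shape combine element by element, with no stretching involved.

```python
import numpy as np

# Two arrays of identical shape: the element-wise operation simply
# pairs up corresponding elements.
a = np.array([[1.0, 2.0], [3.0, 4.0]])
b = np.array([[10.0, 20.0], [30.0, 40.0]])
c = a * b
print(c)  # [[ 10.  40.], [ 90. 160.]]
```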
What does &quot;compatible&quot; mean, though?</p> <p>To determine if two shapes are compatible, Numpy compares their dimensions, starting with the trailing ones and working its way backwards <a class="footnote-reference" href="#id8" id="id3"></a>. If two dimensions are equal, or if one of them equals 1, the comparison continues. Otherwise, you'll see a <tt class="docutils literal">ValueError</tt> raised (saying something like &quot;operands could not be broadcast together with shapes ...&quot;).</p> <p>When one of the shapes runs out of dimensions (because it has fewer dimensions than the other shape), Numpy will use 1 in the comparison process until the other shape's dimensions run out as well.</p> <p>Once Numpy determines that two shapes are compatible, the shape of the result is simply the maximum of the two shapes' sizes in each dimension.</p> <p>Put a little bit more formally, here's a pseudo-algorithm:</p> <div class="highlight"><pre><span></span>Inputs: array A with m dimensions; array B with n dimensions

p = max(m, n)
if m &lt; p:
    left-pad A&#39;s shape with 1s until it also has p dimensions
else if n &lt; p:
    left-pad B&#39;s shape with 1s until it also has p dimensions
result_dims = new list with p elements
for i in p-1 ... 0:
    A_dim_i = A.shape[i]
    B_dim_i = B.shape[i]
    if A_dim_i != 1 and B_dim_i != 1 and A_dim_i != B_dim_i:
        raise ValueError(&quot;could not broadcast&quot;)
    else:
        result_dims[i] = max(A_dim_i, B_dim_i)
</pre></div> </div> <div class="section" id="examples"> <h2>Examples</h2> <p>The definition above is precise and complete; to get a feel for it, we'll need a few examples.</p> <p>I'm using the Numpy convention of describing shapes as tuples. <tt class="docutils literal">macros</tt> is a 4-by-3 array, meaning that it has 4 rows with 3 columns each, or 4x3.
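The pseudo-algorithm above translates almost directly into Python. This is only an illustrative sketch of the shape computation (the function name is made up; this is not how Numpy implements broadcasting internally):

```python
def broadcast_result_shape(shape_a, shape_b):
    """Compute the broadcast result shape of two shape tuples, or raise
    ValueError if the shapes are incompatible."""
    p = max(len(shape_a), len(shape_b))
    # Left-pad the shorter shape with 1s until both have p dimensions.
    a = (1,) * (p - len(shape_a)) + tuple(shape_a)
    b = (1,) * (p - len(shape_b)) + tuple(shape_b)
    result_dims = []
    for dim_a, dim_b in zip(a, b):
        if dim_a != 1 and dim_b != 1 and dim_a != dim_b:
            raise ValueError('could not broadcast')
        # The result size in each dimension is the max of the two sizes.
        result_dims.append(max(dim_a, dim_b))
    return tuple(result_dims)

print(broadcast_result_shape((4, 3), (3,)))     # (4, 3)
print(broadcast_result_shape((5, 4, 1), (5, 1, 3)))  # (5, 4, 3)
```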
The Numpy way of describing the shape of <tt class="docutils literal">macros</tt> is <tt class="docutils literal">(4, 3)</tt>:</p> <div class="highlight"><pre><span></span><span class="n">In</span> <span class="p">[</span><span class="mi">80</span><span class="p">]:</span> <span class="n">macros</span><span class="o">.</span><span class="n">shape</span>
<span class="n">Out</span><span class="p">[</span><span class="mi">80</span><span class="p">]:</span> <span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
</pre></div> <p>When we computed the caloric table using broadcasting, what we did was an operation between <tt class="docutils literal">macros</tt> - a <tt class="docutils literal">(4, 3)</tt> array, and <tt class="docutils literal">cal_per_macro</tt>, a <tt class="docutils literal">(3,)</tt> array <a class="footnote-reference" href="#id9" id="id4"></a>. Therefore, following the broadcasting rules outlined above, the shape <tt class="docutils literal">(3,)</tt> is left-padded with 1 to make comparison with <tt class="docutils literal">(4, 3)</tt> possible. The shapes are then deemed compatible and the result shape is <tt class="docutils literal">(4, 3)</tt>, which is exactly what we observed.</p> <p>Schematically:</p> <div class="highlight"><pre><span></span>(4, 3)                 (4, 3)
        == padding ==&gt;         == result ==&gt; (4, 3)
(3,)                   (1, 3)
</pre></div> <p>Here's another example, broadcasting between a 3-D and a 1-D array:</p> <div class="highlight"><pre><span></span>(3,)                   (1, 1, 3)
          == padding ==&gt;           == result ==&gt; (5, 4, 3)
(5, 4, 3)              (5, 4, 3)
</pre></div> <p>Note, however, that only left-padding with 1s is allowed.
Therefore:</p> <div class="highlight"><pre><span></span>(5,)                   (1, 1, 5)
          == padding ==&gt;           ==&gt; error (5 != 3)
(5, 4, 3)              (5, 4, 3)
</pre></div> <p>Theoretically, had the broadcasting rules been less rigid, we could say that this broadcasting would be valid if we <em>right-padded</em> <tt class="docutils literal">(5,)</tt> with 1s. However, this is not how the rules are defined - therefore these shapes are incompatible.</p> <p>Broadcasting is valid between higher-dimensional arrays too:</p> <div class="highlight"><pre><span></span>(5, 4, 3)                 (1, 5, 4, 3)
             == padding ==&gt;              == result ==&gt; (6, 5, 4, 3)
(6, 5, 4, 3)              (6, 5, 4, 3)
</pre></div> <p>Also, in the beginning of the article I mentioned that broadcasting does not necessarily occur between arrays with different numbers of dimensions. It's perfectly valid to broadcast arrays with the same number of dimensions, as long as they are compatible:</p> <div class="highlight"><pre><span></span>(5, 4, 1)
          == no padding needed ==&gt; result ==&gt; (5, 4, 3)
(5, 1, 3)
</pre></div> <p>Finally, scalars are treated specially as 1-dimensional arrays with size 1:</p> <div class="highlight"><pre><span></span><span class="n">In</span> <span class="p">[</span><span class="mi">93</span><span class="p">]:</span> <span class="n">ones</span><span class="p">((</span><span class="mi">4</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="o">+</span> <span class="mi">1</span>
<span class="n">Out</span><span class="p">[</span><span class="mi">93</span><span class="p">]:</span> <span class="n">array</span><span class="p">([[</span> <span class="mf">2.</span><span class="p">,</span> <span class="mf">2.</span><span class="p">,</span> <span class="mf">2.</span><span class="p">],</span> <span class="p">[</span> <span class="mf">2.</span><span class="p">,</span> <span class="mf">2.</span><span class="p">,</span> <span class="mf">2.</span><span class="p">],</span> <span class="p">[</span> <span
class="mf">2.</span><span class="p">,</span> <span class="mf">2.</span><span class="p">,</span> <span class="mf">2.</span><span class="p">],</span> <span class="p">[</span> <span class="mf">2.</span><span class="p">,</span> <span class="mf">2.</span><span class="p">,</span> <span class="mf">2.</span><span class="p">]])</span> <span class="c1"># Is the same as:</span> <span class="n">In</span> <span class="p">[</span><span class="mi">94</span><span class="p">]:</span> <span class="n">one</span> <span class="o">=</span> <span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="n">In</span> <span class="p">[</span><span class="mi">95</span><span class="p">]:</span> <span class="n">one</span> <span class="n">Out</span><span class="p">[</span><span class="mi">95</span><span class="p">]:</span> <span class="n">array</span><span class="p">([[</span> <span class="mf">1.</span><span class="p">]])</span> <span class="n">In</span> <span class="p">[</span><span class="mi">96</span><span class="p">]:</span> <span class="n">ones</span><span class="p">((</span><span class="mi">4</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="o">+</span> <span class="n">one</span> <span class="n">Out</span><span class="p">[</span><span class="mi">96</span><span class="p">]:</span> <span class="n">array</span><span class="p">([[</span> <span class="mf">2.</span><span class="p">,</span> <span class="mf">2.</span><span class="p">,</span> <span class="mf">2.</span><span class="p">],</span> <span class="p">[</span> <span class="mf">2.</span><span class="p">,</span> <span class="mf">2.</span><span class="p">,</span> <span class="mf">2.</span><span class="p">],</span> <span class="p">[</span> <span class="mf">2.</span><span class="p">,</span> <span class="mf">2.</span><span class="p">,</span> <span class="mf">2.</span><span class="p">],</span> <span 
class="p">[</span> <span class="mf">2.</span><span class="p">,</span> <span class="mf">2.</span><span class="p">,</span> <span class="mf">2.</span><span class="p">]])</span> </pre></div> </div> <div class="section" id="explicit-broadcasting-with-numpy-broadcast"> <h2>Explicit broadcasting with numpy.broadcast</h2> <p>In the examples above, we've seen how Numpy employs broadcasting behind the scenes to match together arrays that have compatible, but not similar, shapes. We can also ask Numpy for a more explicit exposure of broadcasting, using the <tt class="docutils literal">numpy.broadcast</tt> class:</p> <div class="highlight"><pre><span></span><span class="n">In</span> <span class="p">[</span><span class="mi">103</span><span class="p">]:</span> <span class="n">macros</span><span class="o">.</span><span class="n">shape</span> <span class="n">Out</span><span class="p">[</span><span class="mi">103</span><span class="p">]:</span> <span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> <span class="n">In</span> <span class="p">[</span><span class="mi">104</span><span class="p">]:</span> <span class="n">cal_per_macro</span><span class="o">.</span><span class="n">shape</span> <span class="n">Out</span><span class="p">[</span><span class="mi">104</span><span class="p">]:</span> <span class="p">(</span><span class="mi">3</span><span class="p">,)</span> <span class="n">In</span> <span class="p">[</span><span class="mi">105</span><span class="p">]:</span> <span class="n">b</span> <span class="o">=</span> <span class="n">broadcast</span><span class="p">(</span><span class="n">macros</span><span class="p">,</span> <span class="n">cal_per_macro</span><span class="p">)</span> </pre></div> <p>Now <tt class="docutils literal">b</tt> is an object of type <tt class="docutils literal">numpy.broadcast</tt>, and we can query it for the result shape of broadcasting, as well as use it to iterate over pairs of elements 
from the input arrays in the order matched by broadcasting them:</p> <div class="highlight"><pre><span></span><span class="n">In</span> <span class="p">[</span><span class="mi">108</span><span class="p">]:</span> <span class="n">b</span><span class="o">.</span><span class="n">shape</span> <span class="n">Out</span><span class="p">[</span><span class="mi">108</span><span class="p">]:</span> <span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> <span class="n">In</span> <span class="p">[</span><span class="mi">120</span><span class="p">]:</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">j</span> <span class="ow">in</span> <span class="n">b</span><span class="p">:</span> <span class="k">print</span> <span class="s1">&#39;{0}: {1} {2}&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">b</span><span class="o">.</span><span class="n">index</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">)</span> <span class="o">.....</span><span class="p">:</span> <span class="mi">1</span><span class="p">:</span> <span class="mf">0.3</span> <span class="mi">9</span> <span class="mi">2</span><span class="p">:</span> <span class="mf">2.5</span> <span class="mi">4</span> <span class="mi">3</span><span class="p">:</span> <span class="mf">3.5</span> <span class="mi">4</span> <span class="mi">4</span><span class="p">:</span> <span class="mf">2.9</span> <span class="mi">9</span> <span class="mi">5</span><span class="p">:</span> <span class="mf">27.5</span> <span class="mi">4</span> <span class="mi">6</span><span class="p">:</span> <span class="mf">0.0</span> <span class="mi">4</span> <span class="mi">7</span><span class="p">:</span> <span class="mf">0.4</span> <span class="mi">9</span> <span class="mi">8</span><span class="p">:</span> <span 
class="mf">1.3</span> <span class="mi">4</span> <span class="mi">9</span><span class="p">:</span> <span class="mf">23.9</span> <span class="mi">4</span> <span class="mi">10</span><span class="p">:</span> <span class="mf">14.4</span> <span class="mi">9</span> <span class="mi">11</span><span class="p">:</span> <span class="mf">6.0</span> <span class="mi">4</span> <span class="mi">12</span><span class="p">:</span> <span class="mf">2.3</span> <span class="mi">4</span> </pre></div> <p>This lets us see very explicitly how the &quot;stretching&quot; of <tt class="docutils literal">cal_per_macro</tt> is done to match the shape of <tt class="docutils literal">macros</tt>. So if you ever want to perform some complex computation on two arrays whose shapes aren't similar but compatible, and you want to use broadcasting, <tt class="docutils literal">numpy.broadcast</tt> can help.</p> </div> <div class="section" id="computing-outer-products-with-broadcasting"> <h2>Computing outer products with broadcasting</h2> <p>As another cool example of broadcasting rules, consider the outer product of two vectors.</p> <p>In linear algebra, it is customary to deal with column vectors by default, using a transpose to denote a row vector. Therefore, given two vectors <img alt="x" class="valign-0" src="https://eli.thegreenplace.net/images/math/11f6ad8ec52a2984abaafd7c3b516503785c2072.png" style="height: 8px;" /> and <img alt="y" class="valign-m4" src="https://eli.thegreenplace.net/images/math/95cb0bfd2977c761298d9624e4b4d4c72a39974a.png" style="height: 12px;" />, their &quot;outer product&quot; is defined as <img alt="xy^T" class="valign-m4" src="https://eli.thegreenplace.net/images/math/2903d8355c674ee7c2e06b7ba1940714f100243e.png" style="height: 19px;" />.
Treating <img alt="x" class="valign-0" src="https://eli.thegreenplace.net/images/math/11f6ad8ec52a2984abaafd7c3b516503785c2072.png" style="height: 8px;" /> and <img alt="y" class="valign-m4" src="https://eli.thegreenplace.net/images/math/95cb0bfd2977c761298d9624e4b4d4c72a39974a.png" style="height: 12px;" /> as Nx1 matrices this matrix multiplication results in:</p> <img alt="$xy^T=\begin{bmatrix} x_1 \\ x_2 \\ ... \\ x_N \end{bmatrix}[y_1, y_2, ..., y_N]= \begin{bmatrix} x_1y_1 &amp;amp; x_1y_2 &amp;amp; \cdots &amp;amp; x_1y_N \\ x_2y_1 &amp;amp; x_2y_2 &amp;amp; \cdots &amp;amp; x_2y_N \\ \vdots\\ x_Ny_1 &amp;amp; x_Ny_2 &amp;amp; \cdots &amp;amp; x_Ny_N \\ \end{bmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/d4c1d88533494c698a0628fe472bb099b2195a13.png" style="height: 97px;" /> <p>How can we implement this in Numpy? Note that the shape of the row vector is <tt class="docutils literal">(1, N)</tt> <a class="footnote-reference" href="#id10" id="id5"></a>. The shape of the column vector is <tt class="docutils literal">(N, 1)</tt>. Therefore, if we apply an element-wise operation between them, broadcasting will kick in, find that the shapes are compatible and the result shape is <tt class="docutils literal">(N, N)</tt>. The row vector is going to be &quot;stretched&quot; over N rows and the column vector over N columns - so we'll get the outer product! 
Here's an interactive session that demonstrates this:</p> <div class="highlight"><pre><span></span><span class="n">In</span> <span class="p">[</span><span class="mi">137</span><span class="p">]:</span> <span class="n">ten</span> <span class="o">=</span> <span class="n">arange</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">11</span><span class="p">)</span> <span class="n">In</span> <span class="p">[</span><span class="mi">138</span><span class="p">]:</span> <span class="n">ten</span> <span class="n">Out</span><span class="p">[</span><span class="mi">138</span><span class="p">]:</span> <span class="n">array</span><span class="p">([</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">7</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">9</span><span class="p">,</span> <span class="mi">10</span><span class="p">])</span> <span class="n">In</span> <span class="p">[</span><span class="mi">139</span><span class="p">]:</span> <span class="n">ten</span><span class="o">.</span><span class="n">shape</span> <span class="n">Out</span><span class="p">[</span><span class="mi">139</span><span class="p">]:</span> <span class="p">(</span><span class="mi">10</span><span class="p">,)</span> <span class="c1"># Using Numpy&#39;s reshape method to convert the row vector into a</span> <span class="c1"># column vector.</span> <span class="n">In</span> <span class="p">[</span><span class="mi">140</span><span class="p">]:</span> <span class="n">ten</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span 
class="p">))</span> <span class="n">Out</span><span class="p">[</span><span class="mi">140</span><span class="p">]:</span> <span class="n">array</span><span class="p">([[</span> <span class="mi">1</span><span class="p">],</span> <span class="p">[</span> <span class="mi">2</span><span class="p">],</span> <span class="p">[</span> <span class="mi">3</span><span class="p">],</span> <span class="p">[</span> <span class="mi">4</span><span class="p">],</span> <span class="p">[</span> <span class="mi">5</span><span class="p">],</span> <span class="p">[</span> <span class="mi">6</span><span class="p">],</span> <span class="p">[</span> <span class="mi">7</span><span class="p">],</span> <span class="p">[</span> <span class="mi">8</span><span class="p">],</span> <span class="p">[</span> <span class="mi">9</span><span class="p">],</span> <span class="p">[</span><span class="mi">10</span><span class="p">]])</span> <span class="n">In</span> <span class="p">[</span><span class="mi">141</span><span class="p">]:</span> <span class="n">ten</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">shape</span> <span class="n">Out</span><span class="p">[</span><span class="mi">141</span><span class="p">]:</span> <span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="c1"># Let&#39;s see what the &#39;broadcast&#39; class tells us about the resulting</span> <span class="c1"># shape of broadcasting ten and its column-vector version</span> <span class="n">In</span> <span class="p">[</span><span class="mi">142</span><span class="p">]:</span> <span class="n">broadcast</span><span class="p">(</span><span class="n">ten</span><span class="p">,</span> <span class="n">ten</span><span class="o">.</span><span class="n">reshape</span><span 
class="p">((</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="p">)))</span><span class="o">.</span><span class="n">shape</span> <span class="n">Out</span><span class="p">[</span><span class="mi">142</span><span class="p">]:</span> <span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span> <span class="n">In</span> <span class="p">[</span><span class="mi">143</span><span class="p">]:</span> <span class="n">ten</span> <span class="o">*</span> <span class="n">ten</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="n">Out</span><span class="p">[</span><span class="mi">143</span><span class="p">]:</span> <span class="n">array</span><span class="p">([[</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">7</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">9</span><span class="p">,</span> <span class="mi">10</span><span class="p">],</span> <span class="p">[</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">12</span><span class="p">,</span> <span class="mi">14</span><span class="p">,</span> <span class="mi">16</span><span class="p">,</span> <span class="mi">18</span><span class="p">,</span> <span class="mi">20</span><span class="p">],</span> <span class="p">[</span> <span 
class="mi">3</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">9</span><span class="p">,</span> <span class="mi">12</span><span class="p">,</span> <span class="mi">15</span><span class="p">,</span> <span class="mi">18</span><span class="p">,</span> <span class="mi">21</span><span class="p">,</span> <span class="mi">24</span><span class="p">,</span> <span class="mi">27</span><span class="p">,</span> <span class="mi">30</span><span class="p">],</span> <span class="p">[</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">12</span><span class="p">,</span> <span class="mi">16</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">24</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">36</span><span class="p">,</span> <span class="mi">40</span><span class="p">],</span> <span class="p">[</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">15</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">25</span><span class="p">,</span> <span class="mi">30</span><span class="p">,</span> <span class="mi">35</span><span class="p">,</span> <span class="mi">40</span><span class="p">,</span> <span class="mi">45</span><span class="p">,</span> <span class="mi">50</span><span class="p">],</span> <span class="p">[</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">12</span><span class="p">,</span> <span class="mi">18</span><span class="p">,</span> <span class="mi">24</span><span class="p">,</span> <span class="mi">30</span><span class="p">,</span> <span class="mi">36</span><span class="p">,</span> <span class="mi">42</span><span class="p">,</span> <span class="mi">48</span><span 
class="p">,</span> <span class="mi">54</span><span class="p">,</span> <span class="mi">60</span><span class="p">],</span> <span class="p">[</span> <span class="mi">7</span><span class="p">,</span> <span class="mi">14</span><span class="p">,</span> <span class="mi">21</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">35</span><span class="p">,</span> <span class="mi">42</span><span class="p">,</span> <span class="mi">49</span><span class="p">,</span> <span class="mi">56</span><span class="p">,</span> <span class="mi">63</span><span class="p">,</span> <span class="mi">70</span><span class="p">],</span> <span class="p">[</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">16</span><span class="p">,</span> <span class="mi">24</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">40</span><span class="p">,</span> <span class="mi">48</span><span class="p">,</span> <span class="mi">56</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="mi">72</span><span class="p">,</span> <span class="mi">80</span><span class="p">],</span> <span class="p">[</span> <span class="mi">9</span><span class="p">,</span> <span class="mi">18</span><span class="p">,</span> <span class="mi">27</span><span class="p">,</span> <span class="mi">36</span><span class="p">,</span> <span class="mi">45</span><span class="p">,</span> <span class="mi">54</span><span class="p">,</span> <span class="mi">63</span><span class="p">,</span> <span class="mi">72</span><span class="p">,</span> <span class="mi">81</span><span class="p">,</span> <span class="mi">90</span><span class="p">],</span> <span class="p">[</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">30</span><span class="p">,</span> <span class="mi">40</span><span class="p">,</span> <span class="mi">50</span><span 
class="p">,</span> <span class="mi">60</span><span class="p">,</span> <span class="mi">70</span><span class="p">,</span> <span class="mi">80</span><span class="p">,</span> <span class="mi">90</span><span class="p">,</span> <span class="mi">100</span><span class="p">]])</span> </pre></div> <p>The output should be familiar to anyone who's finished grade school, of course.</p> <p>Interestingly, even though Numpy has a function named <tt class="docutils literal">outer</tt> that computes the outer product between two vectors, my timings show that at least in this particular case broadcasting multiplication as shown above is more than twice as fast - so be sure to always measure.</p> </div> <div class="section" id="use-the-right-tool-for-the-job"> <h2>Use the right tool for the job</h2> <p>I'll end this article with another educational example that demonstrates a problem that can be solved in two different ways, one of which is much more efficient because it uses the right tool for the job.</p> <p>Back to the original example of counting calories in foods. Suppose I just want to know how many calories each serving of food has (total from fats, protein and carbs).</p> <p>Given the <tt class="docutils literal">macros</tt> data and a <tt class="docutils literal">cal_per_macro</tt> breakdown, we can use the broadcasting multiplication as seen before to compute a &quot;calories per macro&quot; table efficiently, for each food. 
All that's left is to add together the columns in each row into a sum - this will be the number of calories per serving in that food:</p> <div class="highlight"><pre><span></span><span class="n">In</span> <span class="p">[</span><span class="mi">160</span><span class="p">]:</span> <span class="n">macros</span> <span class="o">*</span> <span class="n">cal_per_macro</span> <span class="n">Out</span><span class="p">[</span><span class="mi">160</span><span class="p">]:</span> <span class="n">array</span><span class="p">([[</span> <span class="mf">2.7</span><span class="p">,</span> <span class="mf">10.</span> <span class="p">,</span> <span class="mf">14.</span> <span class="p">],</span> <span class="p">[</span> <span class="mf">26.1</span><span class="p">,</span> <span class="mf">110.</span> <span class="p">,</span> <span class="mf">0.</span> <span class="p">],</span> <span class="p">[</span> <span class="mf">3.6</span><span class="p">,</span> <span class="mf">5.2</span><span class="p">,</span> <span class="mf">95.6</span><span class="p">],</span> <span class="p">[</span> <span class="mf">129.6</span><span class="p">,</span> <span class="mf">24.</span> <span class="p">,</span> <span class="mf">9.2</span><span class="p">]])</span> <span class="n">In</span> <span class="p">[</span><span class="mi">161</span><span class="p">]:</span> <span class="nb">sum</span><span class="p">(</span><span class="n">macros</span> <span class="o">*</span> <span class="n">cal_per_macro</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="n">Out</span><span class="p">[</span><span class="mi">161</span><span class="p">]:</span> <span class="n">array</span><span class="p">([</span> <span class="mf">26.7</span><span class="p">,</span> <span class="mf">136.1</span><span class="p">,</span> <span class="mf">104.4</span><span class="p">,</span> <span class="mf">162.8</span><span class="p">])</span> 
</pre></div> <p>Here I'm using the <tt class="docutils literal">axis</tt> parameter of the <tt class="docutils literal">sum</tt> function to tell Numpy to sum only over axis 1 (columns), rather than computing the sum of the whole multi-dimensional array.</p> <p>Looks easy. But is there a better way? Indeed, if you think for a moment about the operation we've just performed, a natural solution emerges. We've taken a vector (<tt class="docutils literal">cal_per_macro</tt>), element-wise multiplied it with each row of <tt class="docutils literal">macros</tt> and then added up the results. In other words, we've computed the dot-product of <tt class="docutils literal">cal_per_macro</tt> with each row of <tt class="docutils literal">macros</tt>. In linear algebra there's already an operation that will do this for the whole input table: matrix multiplication. You can work out the details on paper, but it's easy to see that multiplying the matrix <tt class="docutils literal">macros</tt> on the right by <tt class="docutils literal">cal_per_macro</tt> as a column vector, we get the same result. Let's check:</p> <div class="highlight"><pre><span></span><span class="c1"># Create a column vector out of cal_per_macro</span> <span class="n">In</span> <span class="p">[</span><span class="mi">168</span><span class="p">]:</span> <span class="n">cal_per_macro_col_vec</span> <span class="o">=</span> <span class="n">cal_per_macro</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="c1"># Use the &#39;dot&#39; function for matrix multiplication. 
Starting with Python 3.5,</span> <span class="c1"># we&#39;ll be able to use an operator instead: macros @ cal_per_macro_col_vec</span> <span class="n">In</span> <span class="p">[</span><span class="mi">169</span><span class="p">]:</span> <span class="n">macros</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">cal_per_macro_col_vec</span><span class="p">)</span> <span class="n">Out</span><span class="p">[</span><span class="mi">169</span><span class="p">]:</span> <span class="n">array</span><span class="p">([[</span> <span class="mf">26.7</span><span class="p">],</span> <span class="p">[</span> <span class="mf">136.1</span><span class="p">],</span> <span class="p">[</span> <span class="mf">104.4</span><span class="p">],</span> <span class="p">[</span> <span class="mf">162.8</span><span class="p">]])</span> </pre></div> <p>On my machine, using <tt class="docutils literal">dot</tt> is 4-5x faster than composing <tt class="docutils literal">sum</tt> with element-wise multiplication. Even though the latter is implemented in optimized C code in the guts of Numpy, it has the disadvantage of moving too much data around - computing the intermediate matrix representing the broadcasted multiplication is not really necessary for the end product. 
<tt class="docutils literal">dot</tt>, on the other hand, performs the operation in one step using a highly optimized <a class="reference external" href="https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms">BLAS routine</a>.</p> <hr class="docutils" /> <table class="docutils footnote" frame="void" id="id6" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id1"></a></td><td>For the pedantic: I'm taking these numbers from <a class="reference external" href="http://www.calorieking.com">http://www.calorieking.com</a>, and I subtract the fiber from total carbs because it doesn't count for the calories.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id7" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id2"></a></td><td>About 30% faster than the &quot;stretching&quot; method. This is mostly due to the creation of the <tt class="docutils literal"><span class="pre">..._stretch</span></tt> array, which takes time. Once the stretched array is there, the broadcasting method is ~5% faster - this difference being due to a better use of memory (we don't <em>really</em> have to create the whole stretched array, do we? 
It's just repeating the same data so why waste so much memory?)</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id8" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id3"></a></td><td>For the shape <tt class="docutils literal">(4, 3, 2)</tt> the trailing dimension is 2, and working from 2 &quot;backwards&quot; produces: 2, then 3, then 4.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id9" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id4"></a></td><td>Following the usual Python convention, single-element tuples also have a comma, which helps us distinguish them from other entities.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id10" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id5"></a></td><td>More precisely, <tt class="docutils literal">(1, N)</tt> is the shape of a 1-by-N matrix (matrix with one row and N columns). An actual row vector is just a 1D array with the single-dimension shape <tt class="docutils literal">(10,)</tt>. For most purposes, the two are equivalent in Numpy.</td></tr> </tbody> </table> </div> Memory layout of multi-dimensional arrays2015-09-26T06:06:00-07:002015-09-26T06:06:00-07:00Eli Benderskytag:eli.thegreenplace.net,2015-09-26:/2015/memory-layout-of-multi-dimensional-arrays/<p>When working with multi-dimensional arrays, one important decision programmers have to make fairly early on in the project is what memory layout to use for storing the data, and how to access such data in the most efficient manner.
Since computer memory is inherently linear - a one-dimensional structure, mapping multi-dimensional …</p><p>When working with multi-dimensional arrays, one important decision programmers have to make fairly early on in the project is what memory layout to use for storing the data, and how to access such data in the most efficient manner. Since computer memory is inherently linear - a one-dimensional structure, mapping multi-dimensional data on it can be done in several ways. In this article I want to examine this topic in detail, talking about the various memory layouts available and their effect on the performance of the code.</p> <div class="section" id="row-major-vs-column-major"> <h2>Row-major vs. column-major</h2> <p>By far the two most common memory layouts for multi-dimensional array data are <em>row-major</em> and <em>column-major</em>.</p> <p>When working with 2D arrays (matrices), row-major vs. column-major are easy to describe. The row-major layout of a matrix puts the first row in contiguous memory, then the second row right after it, then the third, and so on. Column-major layout puts the first column in contiguous memory, then the second, etc.</p> <p>Higher dimensions are a bit more difficult to visualize, so let's start with some diagrams showing how 2D layouts work.</p> </div> <div class="section" id="d-row-major"> <h2>2D row-major</h2> <p>First, some notes on the nomenclature of this article. Computer memory will be represented as a linear array with low addresses on the left and high addresses on the right. Also, we're going to use programmer notation for matrices: rows and columns start with zero, at the top-left corner of the matrix. 
Row indices go over rows from top to bottom; column indices go over columns from left to right.</p> <p>As mentioned above, in row-major layout, the first row of the matrix is placed in contiguous memory, then the second, and so on:</p> <img alt="Row major 2D" class="align-center" src="https://eli.thegreenplace.net/images/2015/row-major-2D.png" /> <p>Another way to describe row-major layout is that <em>column indices change the fastest</em>. This should be obvious by looking at the linear layout at the bottom of the diagram. If you read the element index pairs from left to right, you'll notice that the column index changes all the time, and the row index only changes once per row.</p> <p>For programmers, another important observation is that given a row index <img alt="i_{row}" class="valign-m3" src="https://eli.thegreenplace.net/images/math/256e11c46808f68dec43d4a7b0e271f05d697785.png" style="height: 15px;" /> and a column index <img alt="i_{col}" class="valign-m3" src="https://eli.thegreenplace.net/images/math/e0ebbfb8bc0af1c2247c6c3f9119be855fed933d.png" style="height: 15px;" />, the offset of the element they denote in the linear representation is:</p> <img alt="$offset=i_{row}*NCOLS+i_{col}$" class="align-center" src="https://eli.thegreenplace.net/images/math/9161443cbcdff4891bbda9b82127634630ad8952.png" style="height: 16px;" /> <p>Where NCOLS is the number of columns per row in the matrix. It's easy to see this equation fits the linear layout in the diagram shown above.</p> </div> <div class="section" id="d-column-major"> <h2>2D column-major</h2> <p>Describing column-major 2D layout is just taking the description of row-major and replacing every appearance of &quot;row&quot; by &quot;column&quot; and vice versa. 
The first column of the matrix is placed in contiguous memory, then the second, and so on:</p> <img alt="Column major 2D" class="align-center" src="https://eli.thegreenplace.net/images/2015/column-major-2D.png" /> <p>In column-major layout, <em>row indices change the fastest</em>. The offset of an element in column-major layout can be found using this equation:</p> <img alt="$offset=i_{col}*NROWS+i_{row}$" class="align-center" src="https://eli.thegreenplace.net/images/math/ab533f15375dcdb69e7affdd1a4c835e146b7751.png" style="height: 16px;" /> <p>Where NROWS is the number of rows per column in the matrix.</p> </div> <div class="section" id="beyond-2d-indexing-and-layout-of-n-dimensional-arrays"> <h2>Beyond 2D - indexing and layout of N-dimensional arrays</h2> <p>Even though matrices are the most common multi-dimensional arrays programmers deal with, they are by no means the only ones. The notation of multi-dimensional arrays is fully generalizable to more than 2 dimensions. These entities are commonly called &quot;N-D arrays&quot; or &quot;tensors&quot;.</p> <p>When we move to 3D and beyond, it's best to leave the row/column notation of matrices behind. This is because this notation doesn't easily translate to 3 dimensions due to a <a class="reference external" href="https://eli.thegreenplace.net/2014/meshgrids-and-disambiguating-rows-and-columns-from-cartesian-coordinates/">common confusion</a> between rows, columns and the Cartesian coordinate system. In 4 dimensions and above, we lose any purely-visual intuition to describe multi-dimensional entities anyway, so it's best to stick to a consistent mathematical notation instead.</p> <p>So let's talk about some arbitrary number of dimensions <em>d</em>, numbered from 1 to <em>d</em>. 
For each dimension <img alt="1\leq i\leq d" class="valign-m3" src="https://eli.thegreenplace.net/images/math/05446b54bd23d571898ab5f1ad448f7ca767f19a.png" style="height: 16px;" />, <img alt="N_i" class="valign-m3" src="https://eli.thegreenplace.net/images/math/855336587fa59262965cdb9a2a6114933586800b.png" style="height: 15px;" /> is the size of the dimension. Also, the index of an element in dimension <img alt="i" class="valign-0" src="https://eli.thegreenplace.net/images/math/042dc4512fa3d391c5170cf3aa61e6a638f84342.png" style="height: 12px;" /> is <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/5b05dd3722f57cd7ac250228f9a1aaf3af86311d.svg" style="height: 11px;" type="image/svg+xml">n_i</object>. For example, in the latest matrix diagram above (where column-layout is shown), we have <img alt="d=2" class="valign-0" src="https://eli.thegreenplace.net/images/math/8587fbaabf40db5bd2eb87f7ec6112beb7200253.png" style="height: 13px;" />. If we choose dimension 1 to be the row and dimension 2 to be the column, then <img alt="N_1=N_2=3" class="valign-m4" src="https://eli.thegreenplace.net/images/math/2f780e2fb7cabe0948456a71d435b0a136de60f9.png" style="height: 16px;" />, and the element in the bottom-left corner of the matrix has <img alt="n_1=2" class="valign-m4" src="https://eli.thegreenplace.net/images/math/0fb7ec5c44ec7de842ea61803aa4c9aec6412770.png" style="height: 16px;" /> and <img alt="n_2=0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/7d5ac2db0b27193cce99b1fe091ef5ce84eee9a9.png" style="height: 15px;" />.</p> <p>In row-major layout of multi-dimensional arrays, the <em>last</em> index is the fastest changing. 
In case of matrices the last index is columns, so this is equivalent to the previous definition.</p> <p>Given a <img alt="d" class="valign-0" src="https://eli.thegreenplace.net/images/math/3c363836cf4e16666669a25da280a1865c2d2874.png" style="height: 13px;" />-dimensional array, with the notation shown above, we compute the memory location of an element from its indices as:</p> <img alt="$offset=n_d + N_d \cdot (n_{d-1} + N_{d-1} \cdot (n_{d-2} + N_{d-2} \cdot (\cdots + N_2 n_1)\cdots))) = \sum_{i=1}^d \left( \prod_{j=i+1}^d N_j \right) n_i$" class="align-center" src="https://eli.thegreenplace.net/images/math/032b2eb5714fa4457fd349eec8f775c3e75584cd.png" style="height: 65px;" /> <p>For a matrix, <img alt="d=2" class="valign-0" src="https://eli.thegreenplace.net/images/math/8587fbaabf40db5bd2eb87f7ec6112beb7200253.png" style="height: 13px;" />, this reduces to:</p> <img alt="$offset=n_2 + N_2 \cdot n_1$" class="align-center" src="https://eli.thegreenplace.net/images/math/c80b9d7877819556e5059016a2426e9727e0f949.png" style="height: 16px;" /> <p>Which is exactly the formula we've seen above for row-major layout, just using a slightly more formal notation.</p> <p>Similarly, in column-major layout of multi-dimensional arrays, the <em>first</em> index is the fastest changing. 
Given a <img alt="d" class="valign-0" src="https://eli.thegreenplace.net/images/math/3c363836cf4e16666669a25da280a1865c2d2874.png" style="height: 13px;" />-dimensional array, we compute the memory location of an element from its indices as:</p> <img alt="$offset=n_1 + N_1 \cdot (n_2 + N_2 \cdot (n_3 + N_3 \cdot (\cdots + N_{d-1} n_d)\cdots))) = \sum_{i=1}^d \left( \prod_{j=1}^{i-1} N_j \right) n_i$" class="align-center" src="https://eli.thegreenplace.net/images/math/35ece5a7b18c317a71e6914ff62c7dd7840952cf.png" style="height: 65px;" /> <p>And again, for a matrix with <img alt="d=2" class="valign-0" src="https://eli.thegreenplace.net/images/math/8587fbaabf40db5bd2eb87f7ec6112beb7200253.png" style="height: 13px;" /> this reduces to the familiar:</p> <img alt="$offset=n_1+N_1\cdot n_2$" class="align-center" src="https://eli.thegreenplace.net/images/math/c084b1f35e57a402567ddc8058ea346d574cd207.png" style="height: 16px;" /> </div> <div class="section" id="example-in-3d"> <h2>Example in 3D</h2> <p>Let's see how this works out in 3D, which we can still visualize. Assuming 3 dimensions: rows, columns and depth. The following diagram shows the memory layout of a 3D array with <img alt="N_1=N_2=N_3=3" class="valign-m4" src="https://eli.thegreenplace.net/images/math/6275e16f8c4fe91d4591dedfec44bb859159bd4c.png" style="height: 16px;" />, in <em>row-major</em>:</p> <img alt="Row major 3D" class="align-center" src="https://eli.thegreenplace.net/images/2015/row-major-3D.png" /> <p>Note how the last dimension (depth, in this case) changes the fastest and the first (row) changes the slowest. The offset for a given element is:</p> <img alt="$offset=n_3+N_3*(n_2+N_2*n_1)$" class="align-center" src="https://eli.thegreenplace.net/images/math/3952a22345f3e71ecbf5b74899d875ca2b9035f2.png" style="height: 18px;" /> <p>For example, the offset of the element with indices 2,1,1 is 22.</p> <p>As an exercise, try to figure out how this array would be laid out in <em>column-major</em> order. 
But beware - there's a caveat! The term <em>column-major</em> may lead you to believe that columns are the slowest-changing index, but this is wrong. The <em>last</em> index is the slowest changing in column-major, and the last index here is depth, not columns. In fact, columns would be right in the middle in terms of change speed. This is exactly why in the discussion above I suggested dropping the row/column notation when going above 2D. In higher dimensions it becomes confusing, so it's best to refer to the relative change rate of the indices, since these are unambiguous.</p> <p>In fact, one could conceive of a hybrid (or &quot;mixed&quot;) layout where the second dimension changes faster than the first or the third. This would be neither row-major nor column-major, but in itself it's a consistent and perfectly valid layout that may benefit some applications. More details on why we would choose one layout over another come later in the article.</p> </div> <div class="section" id="history-fortran-vs-c"> <h2>History: Fortran vs. C</h2> <p>While knowing which layout a particular data set is using is critical for good performance, there's no single answer to the question of which layout &quot;is better&quot; in general. It's not much different from the big-endian vs. little-endian debate; what's important is to pick a consistent standard and stick to it. Unfortunately, as almost always happens in the world of computing, different programming languages and environments picked different standards.</p> <p>Among the programming languages still popular today, Fortran was definitely one of the pioneers. And Fortran (which is still very important for scientific computing) uses column-major layout. I read somewhere that the reason for this is that column vectors are more commonly used and considered &quot;canonical&quot; in linear algebra computations.
Personally I don't buy this, but you can make your own judgement.</p> <p>A slew of modern languages follow Fortran's lead - Matlab, R, Julia, to name a few. One of the strongest reasons for this is that they want to use LAPACK - a fast Fortran library for linear algebra, so using Fortran's layout makes sense.</p> <p>On the other hand, C and C++ use row-major layout. Following their example are a few other popular languages such as Python, Pascal and Mathematica. Since multi-dimensional arrays are a first-class type in the C language, the standard defines the layout very explicitly in section 6.5.2.1 <a class="footnote-reference" href="#id4" id="id1"></a>.</p> <p>In fact, having the first index change the slowest and the last index change the fastest makes sense if you think about how multi-dimensional arrays in C are indexed.</p> <p>Given the declaration:</p> <div class="highlight"><pre><span></span><span class="kt">int</span> <span class="n">x</span><span class="p">[</span><span class="mi">3</span><span class="p">][</span><span class="mi">5</span><span class="p">];</span> </pre></div> <p>Then <tt class="docutils literal">x</tt> is an array of 3 elements, each of which is an array of 5 integers. <tt class="docutils literal"><span class="pre">x[1]</span></tt> is the address of the second array of 5 integers contained in <tt class="docutils literal">x</tt>, and <tt class="docutils literal"><span class="pre">x[1][4]</span></tt> is the fifth integer of the second 5-integer array in <tt class="docutils literal">x</tt>. These indexing rules imply row-major layout.</p> <p>None of this is to say that C could not have chosen column-major layout. It could, but then its multi-dimensional array indexing rules would have to be different as well. The result could be just as consistent as what we have now.</p> <p>Moreover, since C lets you manipulate pointers, you can decide on the layout of data in your program by computing offsets into multi-dimensional arrays on your own.
In fact, this is how most C programs are written.</p> </div> <div class="section" id="memory-layout-example-numpy"> <h2>Memory layout example - numpy</h2> <p>So far we've discussed memory layout purely conceptually - using diagrams and mathematical formulae for index computations. It's worthwhile to see a &quot;real&quot; example of how multi-dimensional arrays are stored in memory. For this purpose, the Numpy library of Python is a great tool since it supports both layout kinds and is easy to play with from an interactive shell.</p> <p>The <a class="reference external" href="http://docs.scipy.org/doc/numpy/reference/generated/numpy.array.html">numpy.array constructor</a> can be used to create multi-dimensional arrays. One of the parameters it accepts is <tt class="docutils literal">order</tt>, which is either &quot;C&quot; for C-style layout (row-major) or &quot;F&quot; for Fortran-style layout (column-major). &quot;C&quot; is the default. Let's see how this looks:</p> <div class="highlight"><pre><span></span><span class="n">In</span> <span class="p">[</span><span class="mi">42</span><span class="p">]:</span> <span class="n">ar2d</span> <span class="o">=</span> <span class="n">numpy</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">],</span> <span class="p">[</span><span class="mi">11</span><span class="p">,</span> <span class="mi">12</span><span class="p">,</span> <span class="mi">13</span><span class="p">],</span> <span class="p">[</span><span class="mi">10</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">40</span><span class="p">]],</span> <span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;uint8&#39;</span><span class="p">,</span> <span class="n">order</span><span class="o">=</span><span 
class="s1">&#39;C&#39;</span><span class="p">)</span> <span class="n">In</span> <span class="p">[</span><span class="mi">43</span><span class="p">]:</span> <span class="s1">&#39; &#39;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="nb">ord</span><span class="p">(</span><span class="n">x</span><span class="p">))</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">ar2d</span><span class="o">.</span><span class="n">data</span><span class="p">)</span> <span class="n">Out</span><span class="p">[</span><span class="mi">43</span><span class="p">]:</span> <span class="s1">&#39;1 2 3 11 12 13 10 20 40&#39;</span> </pre></div> <p>In &quot;C&quot; order, elements of rows are contiguous, as expected. Let's try Fortran layout now:</p> <div class="highlight"><pre><span></span><span class="n">In</span> <span class="p">[</span><span class="mi">44</span><span class="p">]:</span> <span class="n">ar2df</span> <span class="o">=</span> <span class="n">numpy</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">],</span> <span class="p">[</span><span class="mi">11</span><span class="p">,</span> <span class="mi">12</span><span class="p">,</span> <span class="mi">13</span><span class="p">],</span> <span class="p">[</span><span class="mi">10</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">40</span><span class="p">]],</span> <span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;uint8&#39;</span><span class="p">,</span> <span class="n">order</span><span class="o">=</span><span class="s1">&#39;F&#39;</span><span class="p">)</span> <span class="n">In</span> <span class="p">[</span><span 
class="mi">45</span><span class="p">]:</span> <span class="s1">&#39; &#39;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="nb">ord</span><span class="p">(</span><span class="n">x</span><span class="p">))</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">ar2df</span><span class="o">.</span><span class="n">data</span><span class="p">)</span> <span class="n">Out</span><span class="p">[</span><span class="mi">45</span><span class="p">]:</span> <span class="s1">&#39;1 11 10 2 12 20 3 13 40&#39;</span> </pre></div> <p>For a more complex example, let's encode the following 3D array as a <tt class="docutils literal">numpy.array</tt> and see how it's laid out:</p> <img alt="Numeric 3D array" class="align-center" src="https://eli.thegreenplace.net/images/2015/numeric-3D-mat.png" /> <p>This array has two rows (first dimension), 4 columns (second dimension) and depth 2 (third dimension). 
As a nested Python list, this is its representation:</p> <div class="highlight"><pre><span></span><span class="n">In</span> <span class="p">[</span><span class="mi">47</span><span class="p">]:</span> <span class="n">lst3d</span> <span class="o">=</span> <span class="p">[[[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">11</span><span class="p">],</span> <span class="p">[</span><span class="mi">2</span><span class="p">,</span> <span class="mi">12</span><span class="p">],</span> <span class="p">[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">13</span><span class="p">],</span> <span class="p">[</span><span class="mi">4</span><span class="p">,</span> <span class="mi">14</span><span class="p">]],</span> <span class="p">[[</span><span class="mi">5</span><span class="p">,</span> <span class="mi">15</span><span class="p">],</span> <span class="p">[</span><span class="mi">6</span><span class="p">,</span> <span class="mi">16</span><span class="p">],</span> <span class="p">[</span><span class="mi">7</span><span class="p">,</span> <span class="mi">17</span><span class="p">],</span> <span class="p">[</span><span class="mi">8</span><span class="p">,</span> <span class="mi">18</span><span class="p">]]]</span> </pre></div> <p>And the memory layout, in both C and Fortran orders:</p> <div class="highlight"><pre><span></span><span class="n">In</span> <span class="p">[</span><span class="mi">50</span><span class="p">]:</span> <span class="n">ar3d</span> <span class="o">=</span> <span class="n">numpy</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">lst3d</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;uint8&#39;</span><span class="p">,</span> <span class="n">order</span><span class="o">=</span><span class="s1">&#39;C&#39;</span><span class="p">)</span> <span class="n">In</span> <span class="p">[</span><span 
class="mi">51</span><span class="p">]:</span> <span class="s1">&#39; &#39;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="nb">ord</span><span class="p">(</span><span class="n">x</span><span class="p">))</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">ar3d</span><span class="o">.</span><span class="n">data</span><span class="p">)</span> <span class="n">Out</span><span class="p">[</span><span class="mi">51</span><span class="p">]:</span> <span class="s1">&#39;1 11 2 12 3 13 4 14 5 15 6 16 7 17 8 18&#39;</span> <span class="n">In</span> <span class="p">[</span><span class="mi">52</span><span class="p">]:</span> <span class="n">ar3df</span> <span class="o">=</span> <span class="n">numpy</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">lst3d</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;uint8&#39;</span><span class="p">,</span> <span class="n">order</span><span class="o">=</span><span class="s1">&#39;F&#39;</span><span class="p">)</span> <span class="n">In</span> <span class="p">[</span><span class="mi">53</span><span class="p">]:</span> <span class="s1">&#39; &#39;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="nb">ord</span><span class="p">(</span><span class="n">x</span><span class="p">))</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">ar3df</span><span class="o">.</span><span class="n">data</span><span class="p">)</span> <span class="n">Out</span><span class="p">[</span><span class="mi">53</span><span class="p">]:</span> <span class="s1">&#39;1 5 2 6 3 7 4 8 11 15 12 16 13 17 14 18&#39;</span> </pre></div> <p>Note that in C layout (row-major), the first 
dimension (rows) changes the slowest while the third dimension (depth) changes the fastest. In Fortran layout (column-major) the first dimension changes the fastest while the third dimension changes the slowest.</p> </div> <div class="section" id="performance-why-it-s-worth-caring-which-layout-your-data-is-in"> <h2>Performance: why it's worth caring which layout your data is in</h2> <p>After reading the article thus far, one may wonder why any of this matters. Isn't it just another divergence of standards, à la endianness? As long as we all agree on the layout, isn't this just a boring implementation detail? Why would we care about this?</p> <p>The answer is: performance. We're talking about numerical computing here (number crunching on large data sets) where performance is almost always critical. It turns out that matching the way your algorithm works with the data layout can make or break the performance of an application.</p> <p>The short takeaway is: <strong>always traverse the data in the order it was laid out</strong>. If your data sits in memory in row-major layout, iterate over each row before going to the next one, etc. The rest of the section will explain why this is so and will also present a benchmark with some measurements to get a feel for the consequences of this decision.</p> <p>There are two aspects of modern computer architecture that have a large impact on code performance and are relevant to our discussion: caching and vector units. When we iterate over each row of a row-major array, we access the array sequentially. This pattern has <a class="reference external" href="https://en.wikipedia.org/wiki/Locality_of_reference">spatial locality</a>, which makes the code perfect for cache optimization. Moreover, depending on the operations we do with the data, the CPU's vector unit can kick in since it also requires consecutive access.</p> <p>Graphically, it looks something like the following diagram.
Let's say we have the array: <tt class="docutils literal">int <span class="pre">array</span></tt>, and we iterate over each row, jumping to the next one when all the columns in the current one have been visited. The number within each gray cell is the memory address - it grows by 4 since this is an array of integers. The blue numbered arrow enumerates accesses in the order they are made:</p> <img alt="Row access pattern" class="align-center" src="https://eli.thegreenplace.net/images/2015/row-access-pattern.png" /> <p>Here, the optimal usage of caching and vector instructions should be obvious. Since we always access elements sequentially, this is the perfect scenario for the CPU's caches to kick in - we will <em>always hit the cache</em>. In fact, we always hit the fastest cache - L1, because the CPU correctly pre-fetches all data ahead.</p> <p>Moreover, since we always read one 32-bit word <a class="footnote-reference" href="#id5" id="id2"></a> after another, we can leverage the CPU's vector units to load the data (and perhaps process it later). The purple arrows show how this can be done with SSE vector loads that grab 128-bit chunks (four 32-bit words) at a time. In actual code, this can either be done with intrinsics or by relying on the compiler's auto-vectorizer (as we will soon see in an actual code sample).</p> <p>Contrast this with accessing this row-major data one <em>column</em> at a time, iterating over each column before moving to the next one:</p> <img alt="Column access pattern" class="align-center" src="https://eli.thegreenplace.net/images/2015/column-access-pattern.png" /> <p>We lose spatial locality here, unless the array is very narrow. If there are few columns, consecutive rows <em>may</em> be found in the cache. However, in more typical applications the arrays are large and when access #2 happens it's likely that the memory it accesses is nowhere to be found in the cache.
Unsurprisingly, we also lose the vector units since the accesses are not made to consecutive memory.</p> <p>But what should you do if your algorithm <em>needs</em> to access data column-by-column rather than row-by-row? Very simple! This is precisely what column-major layout is for. With column-major data, this access pattern will hit all the same architectural sweet spots we've seen with consecutive access on row-major data.</p> <p>The diagrams above should be convincing enough, but let's do some actual measurements to see just how dramatic these effects are.</p> <p>The full code for the benchmark is <a class="reference external" href="https://github.com/eliben/code-for-blog/tree/master/2015/benchmark-row-col-major">available here</a>, so I'll just show a few selected snippets. We'll start with a basic matrix type laid out in linear memory:</p> <div class="highlight"><pre><span></span><span class="c1">// A simple Matrix of unsigned integers laid out row-major in a 1D array. M is</span> <span class="c1">// number of rows, N is number of columns.</span> <span class="k">struct</span> <span class="n">Matrix</span> <span class="p">{</span> <span class="kt">unsigned</span><span class="o">*</span> <span class="n">data</span> <span class="o">=</span> <span class="k">nullptr</span><span class="p">;</span> <span class="kt">size_t</span> <span class="n">M</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="n">N</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="p">};</span> </pre></div> <p>The matrix uses row-major layout: its elements are accessed using this C expression:</p> <div class="highlight"><pre><span></span><span class="n">x</span><span class="p">.</span><span class="n">data</span><span class="p">[</span><span class="n">row</span> <span class="o">*</span> <span class="n">x</span><span class="p">.</span><span class="n">N</span> <span class="o">+</span> <span
class="n">col</span><span class="p">]</span> </pre></div> <p>Here's a function that adds two such matrices together, using a &quot;bad&quot; access pattern - iterating over the rows in each column before going to the next column. The access pattern is easy to spot by looking at the C code - the inner loop iterates over the fastest-changing index, and in this case it's the rows:</p> <div class="highlight"><pre><span></span><span class="kt">void</span> <span class="nf">AddMatrixByCol</span><span class="p">(</span><span class="n">Matrix</span><span class="o">&amp;</span> <span class="n">y</span><span class="p">,</span> <span class="k">const</span> <span class="n">Matrix</span><span class="o">&amp;</span> <span class="n">x</span><span class="p">)</span> <span class="p">{</span> <span class="n">assert</span><span class="p">(</span><span class="n">y</span><span class="p">.</span><span class="n">M</span> <span class="o">==</span> <span class="n">x</span><span class="p">.</span><span class="n">M</span><span class="p">);</span> <span class="n">assert</span><span class="p">(</span><span class="n">y</span><span class="p">.</span><span class="n">N</span> <span class="o">==</span> <span class="n">x</span><span class="p">.</span><span class="n">N</span><span class="p">);</span> <span class="k">for</span> <span class="p">(</span><span class="kt">size_t</span> <span class="n">col</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">col</span> <span class="o">&lt;</span> <span class="n">y</span><span class="p">.</span><span class="n">N</span><span class="p">;</span> <span class="o">++</span><span class="n">col</span><span class="p">)</span> <span class="p">{</span> <span class="k">for</span> <span class="p">(</span><span class="kt">size_t</span> <span class="n">row</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">row</span> <span class="o">&lt;</span> <span class="n">y</span><span class="p">.</span><span
class="n">M</span><span class="p">;</span> <span class="o">++</span><span class="n">row</span><span class="p">)</span> <span class="p">{</span> <span class="n">y</span><span class="p">.</span><span class="n">data</span><span class="p">[</span><span class="n">row</span> <span class="o">*</span> <span class="n">y</span><span class="p">.</span><span class="n">N</span> <span class="o">+</span> <span class="n">col</span><span class="p">]</span> <span class="o">+=</span> <span class="n">x</span><span class="p">.</span><span class="n">data</span><span class="p">[</span><span class="n">row</span> <span class="o">*</span> <span class="n">x</span><span class="p">.</span><span class="n">N</span> <span class="o">+</span> <span class="n">col</span><span class="p">];</span> <span class="p">}</span> <span class="p">}</span> <span class="p">}</span> </pre></div> <p>And here's a version that uses a better pattern, iterating over the columns in each row before going to the next row:</p> <div class="highlight"><pre><span></span><span class="kt">void</span> <span class="nf">AddMatrixByRow</span><span class="p">(</span><span class="n">Matrix</span><span class="o">&amp;</span> <span class="n">y</span><span class="p">,</span> <span class="k">const</span> <span class="n">Matrix</span><span class="o">&amp;</span> <span class="n">x</span><span class="p">)</span> <span class="p">{</span> <span class="n">assert</span><span class="p">(</span><span class="n">y</span><span class="p">.</span><span class="n">M</span> <span class="o">==</span> <span class="n">x</span><span class="p">.</span><span class="n">M</span><span class="p">);</span> <span class="n">assert</span><span class="p">(</span><span class="n">y</span><span class="p">.</span><span class="n">N</span> <span class="o">==</span> <span class="n">x</span><span class="p">.</span><span class="n">N</span><span class="p">);</span> <span class="k">for</span> <span class="p">(</span><span class="kt">size_t</span> <span class="n">row</span> <span 
class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">row</span> <span class="o">&lt;</span> <span class="n">y</span><span class="p">.</span><span class="n">M</span><span class="p">;</span> <span class="o">++</span><span class="n">row</span><span class="p">)</span> <span class="p">{</span> <span class="k">for</span> <span class="p">(</span><span class="kt">size_t</span> <span class="n">col</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">col</span> <span class="o">&lt;</span> <span class="n">y</span><span class="p">.</span><span class="n">N</span><span class="p">;</span> <span class="o">++</span><span class="n">col</span><span class="p">)</span> <span class="p">{</span> <span class="n">y</span><span class="p">.</span><span class="n">data</span><span class="p">[</span><span class="n">row</span> <span class="o">*</span> <span class="n">y</span><span class="p">.</span><span class="n">N</span> <span class="o">+</span> <span class="n">col</span><span class="p">]</span> <span class="o">+=</span> <span class="n">x</span><span class="p">.</span><span class="n">data</span><span class="p">[</span><span class="n">row</span> <span class="o">*</span> <span class="n">x</span><span class="p">.</span><span class="n">N</span> <span class="o">+</span> <span class="n">col</span><span class="p">];</span> <span class="p">}</span> <span class="p">}</span> <span class="p">}</span> </pre></div> <p>How do the two access patterns compare? Based on the discussion in this article, we'd expect the by-row access pattern to be faster. But how much faster? And what role does vectorization play vs. efficient usage of cache?</p> <p>To try this, I ran the access patterns on matrices of various sizes, and added a variation of the by-row pattern where vectorization is disabled <a class="footnote-reference" href="#id6" id="id3"></a>. 
Here are the results; the vertical bars represent the bandwidth - how many billions of items (32-bit words) were processed (added) by the given function.</p> <img alt="Benchmark results" class="align-center" src="https://eli.thegreenplace.net/images/2015/rowcol-benchmark1.png" /> <p>Some observations:</p> <ul class="simple"> <li>For matrix sizes above 64x64, by-row access is significantly faster than by-column (6-8x, depending on size). In the case of 64x64, what I believe happens is that both matrices fit into the 32-KB L1 cache of my machine, so the by-column pattern actually manages to find the next row in cache. For larger sizes the matrices no longer fit in L1, so the by-column version has to go to L2 frequently.</li> <li>The vectorized version beats the non-vectorized one by 2-3x in all cases. On large matrices the speedup is a bit smaller; I think this is because at 256x256 and beyond the matrices no longer fit in L2 (my machine has 256KB of it) and need slower memory accesses. So the CPU spends a bit more time waiting for memory on average.</li> <li>The overall speedup of the vectorized by-row access over the by-column access is enormous - up to 25x for large matrices.</li> </ul> <p>I'll have to admit that, while I expected the by-row access to be faster, I didn't expect it to be <em>this much</em> faster. Clearly, choosing the proper access pattern for the memory layout of the data is absolutely crucial for the performance of an application.</p> </div> <div class="section" id="summary"> <h2>Summary</h2> <p>This article examined the issue of multi-dimensional array layout from multiple angles. The main takeaway is: know how your data is laid out and access it accordingly. In C-based programming languages, even though the default layout for 2D-arrays is row-major, when we use pointers to dynamically allocated data, we are free to choose whatever layout we like.
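</p> <p>The same freedom exists in NumPy, where the layout is a per-array property rather than a language rule. A small sketch of picking a layout and inspecting the resulting raw memory:</p>

```python
import numpy as np

# The same logical 2D array in both layouts; only the memory order differs.
a = np.arange(12, dtype='uint8').reshape(3, 4)  # row-major (C order) by default
f = np.asfortranarray(a)                        # column-major (Fortran order) copy

print(a.flags['C_CONTIGUOUS'], f.flags['F_CONTIGUOUS'])  # True True

# Logically identical - indexing is unaffected by the underlying layout.
print(a[1, 2] == f[1, 2])  # True

# But the raw memory differs (order='A' dumps bytes in the array's own order):
print(list(a.tobytes(order='A')))  # rows consecutive: 0, 1, 2, 3, 4, ...
print(list(f.tobytes(order='A')))  # columns consecutive: 0, 4, 8, 1, 5, ...
```

<p>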
After all, multi-dimensional arrays are just a logical abstraction above a linear storage system.</p> <p>Due to the wonders of modern CPU architectures, choosing the &quot;right&quot; way to access multi-dimensional data may result in colossal speedups; therefore, this is something that should always be on the programmer's mind when working on large multi-dimensional data sets.</p> <hr class="docutils" /> <table class="docutils footnote" frame="void" id="id4" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id1"></a></td><td>Taken from draft n1570 of the C11 standard.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id5" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id2"></a></td><td>The term &quot;word&quot; used to be clearly associated with a 16-bit entity at some point in the past (with &quot;double word&quot; meaning 32 bits and so on), but these days it's too overloaded. In various references online you'll find &quot;word&quot; to be anything from 16 to 64 bits, depending on the CPU architecture. So I'm going to deliberately side-step the confusion by explicitly mentioning the bit size of words.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id6" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id3"></a></td><td>See the <a class="reference external" href="https://github.com/eliben/code-for-blog/tree/master/2015/benchmark-row-col-major">benchmark repository</a> for the full details, including function attributes and compiler flags. 
A special thanks goes to <a class="reference external" href="https://twitter.com/nadavrot">Nadav Rotem</a> for helping me think through an issue I was initially having due to g++ ignoring my <tt class="docutils literal"><span class="pre">no-tree-vectorize</span></tt> attribute when inlining the function into the benchmark. I turned off inlining to fix this.</td></tr> </tbody> </table> </div> Change of basis in Linear Algebra2015-07-23T05:35:00-07:002015-07-23T05:35:00-07:00Eli Benderskytag:eli.thegreenplace.net,2015-07-23:/2015/change-of-basis-in-linear-algebra/<p>Knowing how to convert a vector to a different basis has many practical applications. Gilbert Strang has a nice quote about the importance of basis changes in his book <a class="footnote-reference" href="#id6" id="id1"></a> (emphasis mine):</p> <blockquote> The standard basis vectors for <img alt="\mathbb{R}^n" class="valign-0" src="https://eli.thegreenplace.net/images/math/98165cf6e8d5d442e040d1fa47aa6845f09294c5.png" style="height: 12px;" /> and <img alt="\mathbb{R}^m" class="valign-0" src="https://eli.thegreenplace.net/images/math/91d9290b46ace1360a8a715bd7a1fa701277697b.png" style="height: 12px;" /> are the columns of <em>I</em>. That choice leads to a standard matrix …</blockquote><p>Knowing how to convert a vector to a different basis has many practical applications. Gilbert Strang has a nice quote about the importance of basis changes in his book <a class="footnote-reference" href="#id6" id="id1"></a> (emphasis mine):</p> <blockquote> The standard basis vectors for <img alt="\mathbb{R}^n" class="valign-0" src="https://eli.thegreenplace.net/images/math/98165cf6e8d5d442e040d1fa47aa6845f09294c5.png" style="height: 12px;" /> and <img alt="\mathbb{R}^m" class="valign-0" src="https://eli.thegreenplace.net/images/math/91d9290b46ace1360a8a715bd7a1fa701277697b.png" style="height: 12px;" /> are the columns of <em>I</em>. 
That choice leads to a standard matrix, and <img alt="T(v)=Av" class="valign-m4" src="https://eli.thegreenplace.net/images/math/bb2fe0bcb727e67170597176144917877c871201.png" style="height: 18px;" /> in the normal way. But these spaces also have other bases, so the same <em>T</em> is represented by other matrices. <strong>A main theme of linear algebra is to choose the bases that give the best matrix for T</strong>.</blockquote> <p>This should serve as a good motivation, but I'll leave the applications for future posts; in this one, I will focus on the mechanics of basis change, starting from first principles.</p> <div class="section" id="the-basis-and-vector-components"> <h2>The basis and vector components</h2> <p>A <em>basis</em> of a vector space <img alt="V" class="valign-0" src="https://eli.thegreenplace.net/images/math/c9ee5681d3c59f7541c27a38b67edf46259e187b.png" style="height: 12px;" /> is a set of vectors in <img alt="V" class="valign-0" src="https://eli.thegreenplace.net/images/math/c9ee5681d3c59f7541c27a38b67edf46259e187b.png" style="height: 12px;" /> that is linearly independent and spans <img alt="V" class="valign-0" src="https://eli.thegreenplace.net/images/math/c9ee5681d3c59f7541c27a38b67edf46259e187b.png" style="height: 12px;" />. An <em>ordered basis</em> is a list, rather than a set, meaning that the order of the vectors in an ordered basis matters. This is important with respect to the topics discussed in this post.</p> <p>Let's now define <em>components</em>. 
If <img alt="U = u_1,u_2,...,u_n" class="valign-m4" src="https://eli.thegreenplace.net/images/math/7c4259c4e451f25663d0e2b0a5171ec904eacf1e.png" style="height: 16px;" /> is an ordered basis for <img alt="V" class="valign-0" src="https://eli.thegreenplace.net/images/math/c9ee5681d3c59f7541c27a38b67edf46259e187b.png" style="height: 12px;" /> and <img alt="v" class="valign-0" src="https://eli.thegreenplace.net/images/math/7a38d8cbd20d9932ba948efaa364bb62651d5ad4.png" style="height: 8px;" /> is a vector in <img alt="V" class="valign-0" src="https://eli.thegreenplace.net/images/math/c9ee5681d3c59f7541c27a38b67edf46259e187b.png" style="height: 12px;" />, then there's a unique <a class="footnote-reference" href="#id7" id="id2"></a> list of scalars <img alt="c_1,c_2,...,c_n" class="valign-m4" src="https://eli.thegreenplace.net/images/math/aa078c610a3018c0b8a60fb3f7625854c7ee0667.png" style="height: 12px;" /> such that:</p> <img alt="$v = c_1u_1+c_2u_2+...+c_nu_n$" class="align-center" src="https://eli.thegreenplace.net/images/math/400f6b84c3ee13d328880d7b29bb7c467c868a33.png" style="height: 14px;" /> <p>These are called the <em>components</em> of <img alt="v" class="valign-0" src="https://eli.thegreenplace.net/images/math/7a38d8cbd20d9932ba948efaa364bb62651d5ad4.png" style="height: 8px;" /> relative to the ordered basis <img alt="U" class="valign-0" src="https://eli.thegreenplace.net/images/math/b2c7c0caa10a0cca5ea7d69e54018ae0c0389dd6.png" style="height: 12px;" />. 
We'll introduce a useful piece of notation here: collect the components <img alt="c_1,c_2,...,c_n" class="valign-m4" src="https://eli.thegreenplace.net/images/math/aa078c610a3018c0b8a60fb3f7625854c7ee0667.png" style="height: 12px;" /> into a column vector and call it <img alt="[v]_{\text{\tiny U}}" class="valign-m5" src="https://eli.thegreenplace.net/images/math/dc587d1ab07e4744144f02d47abbf148b6c339d4.png" style="height: 18px;" />: this is the <em>component vector</em> of <img alt="v" class="valign-0" src="https://eli.thegreenplace.net/images/math/7a38d8cbd20d9932ba948efaa364bb62651d5ad4.png" style="height: 8px;" /> relative to the basis <img alt="U" class="valign-0" src="https://eli.thegreenplace.net/images/math/b2c7c0caa10a0cca5ea7d69e54018ae0c0389dd6.png" style="height: 12px;" />.</p> </div> <div class="section" id="example-finding-a-component-vector"> <h2>Example: finding a component vector</h2> <p>Let's use <img alt="\mathbb{R}^2" class="valign-0" src="https://eli.thegreenplace.net/images/math/2b688757b3d0949451e1fa97e71ac5f5f284a5e4.png" style="height: 15px;" /> as an example. <img alt="U=(2,3), (4,5)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/8bbbb50c23562c7c7dfe92d51af940199d7b366e.png" style="height: 18px;" /> is an ordered basis for <img alt="\mathbb{R}^2" class="valign-0" src="https://eli.thegreenplace.net/images/math/2b688757b3d0949451e1fa97e71ac5f5f284a5e4.png" style="height: 15px;" /> (since the two vectors in it are independent). Say we have <img alt="v=(2,4)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/2c140a090873ddce6a3a86023428c2c72250791e.png" style="height: 18px;" />. What is <img alt="[v]_{\text{\tiny U}}" class="valign-m5" src="https://eli.thegreenplace.net/images/math/dc587d1ab07e4744144f02d47abbf148b6c339d4.png" style="height: 18px;" />? 
We'll need to solve the system of equations:</p> <img alt="$\begin{pmatrix} 2 \\ 4 \end{pmatrix}=c_1\begin{pmatrix} 2 \\ 3\end{pmatrix}+c_2\begin{pmatrix} 4 \\ 5 \end{pmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/1793547a360ef6b1c94da80aebb1698bcd69e20e.png" style="height: 43px;" /> <p>In the 2-D case this is trivial - the solution is <img alt="c_1=3" class="valign-m4" src="https://eli.thegreenplace.net/images/math/a3a4efbce7552fed25d7566f2aa7bb187d035471.png" style="height: 16px;" /> and <img alt="c_2=-1" class="valign-m3" src="https://eli.thegreenplace.net/images/math/fae1728a9257249837a269fc81efb617a999f2a7.png" style="height: 15px;" />. Therefore:</p> <img alt="$[v]_{\text {\tiny U}}=\begin{pmatrix} 3 \\ -1 \end{pmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/548b7f57bfc01932593b5cdaa597b77f531bd03b.png" style="height: 43px;" /> <p>In the more general case of <img alt="\mathbb{R}^n" class="valign-0" src="https://eli.thegreenplace.net/images/math/98165cf6e8d5d442e040d1fa47aa6845f09294c5.png" style="height: 12px;" />, this is akin to solving a linear system of n equations with n variables. Since the basis vectors are, by definition, linearly independent, solving the system is simply inverting a matrix <a class="footnote-reference" href="#id8" id="id3"></a>.</p> </div> <div class="section" id="change-of-basis-matrix"> <h2>Change of basis matrix</h2> <p>Now comes the key part of the post. Say we have two different ordered bases for the same vector space: <img alt="U = u_1,u_2,...,u_n" class="valign-m4" src="https://eli.thegreenplace.net/images/math/7c4259c4e451f25663d0e2b0a5171ec904eacf1e.png" style="height: 16px;" /> and <img alt="W= w_1,w_2,...,w_n" class="valign-m4" src="https://eli.thegreenplace.net/images/math/2eb203def2a32478b504f9250479c3f56defe9c9.png" style="height: 16px;" />. 
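</p> <p>(Before moving on, note that the component-vector computation from the previous section is just a linear solve, so it can be checked numerically. A sketch with NumPy, where the basis vectors form the <em>columns</em> of the matrix:)</p>

```python
import numpy as np

# Basis U = (2,3), (4,5) as the columns of a matrix, and v = (2,4).
U = np.array([[2.0, 4.0],
              [3.0, 5.0]])
v = np.array([2.0, 4.0])

# Solve U @ c = v for the components of v relative to U.
c = np.linalg.solve(U, v)
print(c)  # [ 3. -1.], matching c1 = 3, c2 = -1 above
```

<p>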
For some <img alt="v\in V" class="valign-m1" src="https://eli.thegreenplace.net/images/math/081239435d752122bef07934bbfe0662cc5228e6.png" style="height: 13px;" />, we can find <img alt="[v]_{\text{\tiny U}}" class="valign-m5" src="https://eli.thegreenplace.net/images/math/9c044f231102324a9c84edf98b7f5f37bcdc2e2e.png" style="height: 18px;" /> and <img alt="[v]_{\text{\tiny W}}" class="valign-m5" src="https://eli.thegreenplace.net/images/math/f8ae14559a972991d3c72d1014db829284f86f6a.png" style="height: 18px;" />. How are these two related?</p> <p>Surely, given <img alt="[v]_{\text{\tiny U}}" class="valign-m5" src="https://eli.thegreenplace.net/images/math/dc587d1ab07e4744144f02d47abbf148b6c339d4.png" style="height: 18px;" /> we can find its coefficients in basis <img alt="W" class="valign-0" src="https://eli.thegreenplace.net/images/math/e2415cb7f63df0c9de23362326ad3c37a9adfc96.png" style="height: 12px;" /> the same way as we did in the example above <a class="footnote-reference" href="#id9" id="id4"></a>. It involves solving a linear system of <img alt="n" class="valign-0" src="https://eli.thegreenplace.net/images/math/d1854cae891ec7b29161ccaf79a24b00c274bdaa.png" style="height: 8px;" /> equations. We'll have to redo this operation for every vector <img alt="v" class="valign-0" src="https://eli.thegreenplace.net/images/math/7a38d8cbd20d9932ba948efaa364bb62651d5ad4.png" style="height: 8px;" /> we want to convert. Is there a simpler way?</p> <p>Luckily for science, yes. The key here is to find how the basis vectors of <img alt="U" class="valign-0" src="https://eli.thegreenplace.net/images/math/b2c7c0caa10a0cca5ea7d69e54018ae0c0389dd6.png" style="height: 12px;" /> look in basis <img alt="W" class="valign-0" src="https://eli.thegreenplace.net/images/math/e2415cb7f63df0c9de23362326ad3c37a9adfc96.png" style="height: 12px;" />. 
In other words, we have to find <img alt="[u_1]_{\text{\tiny W}}" class="valign-m5" src="https://eli.thegreenplace.net/images/math/f53ee2bf40857eff2038d6543b07f0cbcf02a651.png" style="height: 18px;" />, <img alt="[u_2]_{\text{\tiny W}}" class="valign-m5" src="https://eli.thegreenplace.net/images/math/f8f94de8a4ccd3f2d7594afd45caae1403310968.png" style="height: 18px;" /> and so on to <img alt="[u_n]_{\text{\tiny W}}" class="valign-m5" src="https://eli.thegreenplace.net/images/math/3f2515b5f7e26704e96f99b9947392836689cf34.png" style="height: 18px;" />.</p> <p>Let's say we do that and find the coefficients to be <img alt="a_{ij}" class="valign-m6" src="https://eli.thegreenplace.net/images/math/f50d06328d8d076870d59691bb4b30fcf23c8f08.png" style="height: 14px;" /> such that:</p> <img alt="$\begin{matrix} u_1=a_{11}w_1+a_{21}w_2+...+a_{n1}w_n \\ u_2=a_{12}w_1+a_{22}w_2+...+a_{n2}w_n \\ ... \\ u_n=a_{1n}w_1+a_{2n}w_2+...+a_{nn}w_n \end{matrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/4c5dd3dc9c5d0acedafd6cd20c09a4498802577a.png" style="height: 80px;" /> <p>Now, given some vector <img alt="v \in V" class="valign-m1" src="https://eli.thegreenplace.net/images/math/bba1fc2f81d8879b4f45b8874c136db6be494079.png" style="height: 13px;" />, suppose its components in basis <img alt="U" class="valign-0" src="https://eli.thegreenplace.net/images/math/b2c7c0caa10a0cca5ea7d69e54018ae0c0389dd6.png" style="height: 12px;" /> are:</p> <img alt="$[v]_{\text{\tiny U}}=\begin{pmatrix} c_1 \\ c_2 \\ ... \\ c_n \end{pmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/83437d273512b461cab0f132626d1d64df3b32ae.png" style="height: 86px;" /> <p>Let's try to figure out how it looks in basis <img alt="W" class="valign-0" src="https://eli.thegreenplace.net/images/math/e2415cb7f63df0c9de23362326ad3c37a9adfc96.png" style="height: 12px;" />. 
The above equation (by definition of components) is equivalent to:</p> <img alt="$v=c_1u_1+c_2u_2+...+c_nu_n$" class="align-center" src="https://eli.thegreenplace.net/images/math/fe3fcc27f581a058afb05460629e332bc2fae909.png" style="height: 14px;" /> <p>Substituting the expansion of the <img alt="u" class="valign-0" src="https://eli.thegreenplace.net/images/math/51e69892ab49df85c6230ccc57f8e1d1606caccc.png" style="height: 8px;" />s in basis <img alt="W" class="valign-0" src="https://eli.thegreenplace.net/images/math/e2415cb7f63df0c9de23362326ad3c37a9adfc96.png" style="height: 12px;" />, we get:</p> <img alt="$v=\begin{matrix} c_1(a_{11}w_1+a_{21}w_2+...+a_{n1}w_n)+ \\ c_2(a_{12}w_1+a_{22}w_2+...+a_{n2}w_n)+ \\ ... \\ c_n(a_{1n}w_1+a_{2n}w_2+...+a_{nn}w_n) \end{matrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/450faf4b2cc27042f6b5fd90cdf39b6588f89e67.png" style="height: 84px;" /> <p>Reordering a bit to find the multipliers of each <img alt="w" class="valign-0" src="https://eli.thegreenplace.net/images/math/aff024fe4ab0fece4091de044c58c9ae4233383a.png" style="height: 8px;" />:</p> <img alt="$v=\begin{matrix} (c_1a_{11}+c_2a_{12}+...+c_na_{1n})w_1+ \\ (c_1a_{21}+c_2a_{22}+...+c_na_{2n})w_2+ \\ ... \\ (c_1a_{n1}+c_2a_{n2}+...+c_na_{nn})w_n \end{matrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/2504c20a0377cc5defb98814b0eed08043a2bfc3.png" style="height: 84px;" /> <p>By our definition of vector components, this equation is equivalent to:</p> <img alt="$[v]_{\text{\tiny W}}=\begin{pmatrix} c_1a_{11}+c_2a_{12}+...+c_na_{1n} \\ c_1a_{21}+c_2a_{22}+...+c_na_{2n} \\ ... 
\\ c_1a_{n1}+c_2a_{n2}+...+c_na_{nn} \end{pmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/0c385a0913584dc008dc0f088d429cf7fc432002.png" style="height: 86px;" /> <p>Now we're in vector notation again, so we can decompose the column vector on the right-hand side into:</p> <img alt="$[v]_{\text{\tiny W}}=\begin{pmatrix} a_{11} &amp;amp; a_{12} &amp;amp; ... &amp;amp; a_{1n} \\ a_{21} &amp;amp; a_{22} &amp;amp; ... &amp;amp; a_{2n} \\ ... &amp;amp; ... &amp;amp; ... \\ a_{n1} &amp;amp; a_{n2} &amp;amp; ... &amp;amp; a_{nn} \end{pmatrix}\begin{pmatrix}c_1 \\ c_2 \\ ... \\ c_n \end{pmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/b55ca14e97d3c8c1be864c00d4995f02f0406845.png" style="height: 86px;" /> <p>This is a matrix times a vector. The vector on the right is <img alt="[v]_{\text{\tiny U}}" class="valign-m5" src="https://eli.thegreenplace.net/images/math/dc587d1ab07e4744144f02d47abbf148b6c339d4.png" style="height: 18px;" />. The matrix should look familiar too because it consists of those <img alt="a_{ij}" class="valign-m6" src="https://eli.thegreenplace.net/images/math/f50d06328d8d076870d59691bb4b30fcf23c8f08.png" style="height: 14px;" /> coefficients we've defined above. In fact, this matrix just represents the basis vectors of <img alt="U" class="valign-0" src="https://eli.thegreenplace.net/images/math/b2c7c0caa10a0cca5ea7d69e54018ae0c0389dd6.png" style="height: 12px;" /> expressed in basis <img alt="W" class="valign-0" src="https://eli.thegreenplace.net/images/math/e2415cb7f63df0c9de23362326ad3c37a9adfc96.png" style="height: 12px;" />.
Let's call this matrix <img alt="A_{\text{\tiny U}\rightarrow \text{\tiny W}}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e98bd0329cf77376132b69670177abb0c09fd70a.png" style="height: 16px;" /> - the change of basis matrix from <img alt="U" class="valign-0" src="https://eli.thegreenplace.net/images/math/b2c7c0caa10a0cca5ea7d69e54018ae0c0389dd6.png" style="height: 12px;" /> to <img alt="W" class="valign-0" src="https://eli.thegreenplace.net/images/math/e2415cb7f63df0c9de23362326ad3c37a9adfc96.png" style="height: 12px;" />. It has <img alt="[u_1]_{\text{\tiny W}}" class="valign-m5" src="https://eli.thegreenplace.net/images/math/f53ee2bf40857eff2038d6543b07f0cbcf02a651.png" style="height: 18px;" /> to <img alt="[u_n]_{\text{\tiny W}}" class="valign-m5" src="https://eli.thegreenplace.net/images/math/3f2515b5f7e26704e96f99b9947392836689cf34.png" style="height: 18px;" /> laid out in its columns:</p> <img alt="$A_{\text{\tiny U}\rightarrow \text{\tiny W}}=\begin{pmatrix}[u_1]_{\text{\tiny W}},[u_2]_{\text{\tiny W}},...,[u_n]_{\text{\tiny W}}]\end{pmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/48c37598af067782c747688c6f2f1b037539c14a.png" style="height: 22px;" /> <p>So we have:</p> <img alt="$[v]_{\text{\tiny W}}=A_{\text{\tiny U}\rightarrow \text{\tiny W}}[v]_{\text{\tiny U}}$" class="align-center" src="https://eli.thegreenplace.net/images/math/40ce253d22d14c257a969dca84539e9d06be237d.png" style="height: 18px;" /> <p>To recap, given two bases <img alt="U" class="valign-0" src="https://eli.thegreenplace.net/images/math/b2c7c0caa10a0cca5ea7d69e54018ae0c0389dd6.png" style="height: 12px;" /> and <img alt="W" class="valign-0" src="https://eli.thegreenplace.net/images/math/e2415cb7f63df0c9de23362326ad3c37a9adfc96.png" style="height: 12px;" />, we can spend some effort to compute the &quot;change of basis&quot; matrix <img alt="A_{\text{\tiny U}\rightarrow \text{\tiny W}}" class="valign-m4" 
src="https://eli.thegreenplace.net/images/math/e98bd0329cf77376132b69670177abb0c09fd70a.png" style="height: 16px;" />, but then we can easily convert any vector in basis <img alt="U" class="valign-0" src="https://eli.thegreenplace.net/images/math/b2c7c0caa10a0cca5ea7d69e54018ae0c0389dd6.png" style="height: 12px;" /> to basis <img alt="W" class="valign-0" src="https://eli.thegreenplace.net/images/math/e2415cb7f63df0c9de23362326ad3c37a9adfc96.png" style="height: 12px;" /> if we simply left-multiply it by this matrix.</p> <p>A reasonable question to ask at this point is - what about converting from <img alt="W" class="valign-0" src="https://eli.thegreenplace.net/images/math/e2415cb7f63df0c9de23362326ad3c37a9adfc96.png" style="height: 12px;" /> to <img alt="U" class="valign-0" src="https://eli.thegreenplace.net/images/math/b2c7c0caa10a0cca5ea7d69e54018ae0c0389dd6.png" style="height: 12px;" />? Well, since the computations above are completely generic and don't special-case either base, we can just flip the roles of <img alt="U" class="valign-0" src="https://eli.thegreenplace.net/images/math/b2c7c0caa10a0cca5ea7d69e54018ae0c0389dd6.png" style="height: 12px;" /> and <img alt="W" class="valign-0" src="https://eli.thegreenplace.net/images/math/e2415cb7f63df0c9de23362326ad3c37a9adfc96.png" style="height: 12px;" /> and get another change of basis matrix, <img alt="A_{\text{\tiny W}\rightarrow \text{\tiny U}}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/7ead67eb4f93b24b2121304e1fa7fe62116cd30d.png" style="height: 16px;" /> - it converts vectors in base <img alt="W" class="valign-0" src="https://eli.thegreenplace.net/images/math/e2415cb7f63df0c9de23362326ad3c37a9adfc96.png" style="height: 12px;" /> to vectors in base <img alt="U" class="valign-0" src="https://eli.thegreenplace.net/images/math/b2c7c0caa10a0cca5ea7d69e54018ae0c0389dd6.png" style="height: 12px;" /> as follows:</p> <img alt="$[v]_{\text{\tiny U}}=A_{\text{\tiny W}\rightarrow \text{\tiny 
U}}[v]_{\text{\tiny W}}$" class="align-center" src="https://eli.thegreenplace.net/images/math/11b3d1909fe5b2e306590dbb6c5ab4b99c911e43.png" style="height: 18px;" /> <p>And this matrix is:</p> <img alt="$A_{\text{\tiny W}\rightarrow \text{\tiny U}}=\begin{pmatrix}[w_1]_{\text{\tiny U}},[w_2]_{\text{\tiny U}},...,[w_n]_{\text{\tiny U}}]\end{pmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/a3d5cd46635be2f2936ce51aa9c484d5c491a5be.png" style="height: 22px;" /> <p>We will soon see that the two change of basis matrices are intimately related; but first, an example.</p> </div> <div class="section" id="example-changing-bases-with-matrices"> <h2>Example: changing bases with matrices</h2> <p>Let's work through another concrete example in <img alt="\mathbb{R}^2" class="valign-0" src="https://eli.thegreenplace.net/images/math/2b688757b3d0949451e1fa97e71ac5f5f284a5e4.png" style="height: 15px;" />. We've used the basis <img alt="U=(2,3), (4,5)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/8bbbb50c23562c7c7dfe92d51af940199d7b366e.png" style="height: 18px;" /> before; let's use it again, and also add the basis <img alt="W=(-1,1), (1,1)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/880182e231343144829cea71b4a367c8308bfff1.png" style="height: 18px;" />. 
We've already seen that for <img alt="v=(2,4)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/2c140a090873ddce6a3a86023428c2c72250791e.png" style="height: 18px;" /> we have:</p> <img alt="$[v]_{\text {\tiny U}}=\begin{pmatrix} 3 \\ -1 \end{pmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/548b7f57bfc01932593b5cdaa597b77f531bd03b.png" style="height: 43px;" /> <p>Similarly, we can solve a set of two equations to find <img alt="[v]_{\text {\tiny W}}" class="valign-m5" src="https://eli.thegreenplace.net/images/math/81eaa18746b62d0b9942378509bc40309c799e6a.png" style="height: 18px;" />:</p> <img alt="$[v]_{\text {\tiny W}}=\begin{pmatrix} 1 \\ 3 \end{pmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/cef00c5fb242059409d28cfd7a9de38cb87839a3.png" style="height: 43px;" /> <p>OK, let's see how a change of basis matrix can be used to easily compute one given the other. First, to find <img alt="A_{\text{\tiny U}\rightarrow \text{\tiny W}}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/b731f16fca15beb2c2a898400d28313be0adbfe9.png" style="height: 16px;" /> we'll need <img alt="[u_1]_{\text {\tiny W}}" class="valign-m5" src="https://eli.thegreenplace.net/images/math/3eeb78c68e01e053bad97be52f362f84bf4ba536.png" style="height: 18px;" /> and <img alt="[u_2]_{\text {\tiny W}}" class="valign-m5" src="https://eli.thegreenplace.net/images/math/a9401eed9ea04bfd2528cc46d8cc664b0fb5b35c.png" style="height: 18px;" />. We know how to do that. 
The result is:</p> <img alt="$[u_1]_{\text {\tiny W}}=\begin{pmatrix} 0.5 \\ 2.5 \end{pmatrix}\qquad[u_2]_{\text {\tiny W}}=\begin{pmatrix} 0.5 \\ 4.5 \end{pmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/5eca6d9d809f75e136f1619b08ac6677448406d6.png" style="height: 43px;" /> <p>Now we can verify that given <img alt="[v]_{\text {\tiny U}}" class="valign-m5" src="https://eli.thegreenplace.net/images/math/344b247c2488e81b7d0dcb75a6f5addd8746d0d9.png" style="height: 18px;" /> and <img alt="A_{\text{\tiny U}\rightarrow \text{\tiny W}}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e599bac152aa7b275b4aa6f8b9ad2104a400cd34.png" style="height: 16px;" />, we can easily find <img alt="[v]_{\text {\tiny W}}" class="valign-m5" src="https://eli.thegreenplace.net/images/math/ef43978234a008fca8fb67a5534bd6880fb1771f.png" style="height: 18px;" />:</p> <img alt="$[v]_{\text{\tiny W}}=A_{\text{\tiny U}\rightarrow \text{\tiny W}}[v]_{\text{\tiny U}}= \\ \begin{pmatrix} 0.5 &amp;amp; 0.5 \\ 2.5 &amp;amp; 4.5 \end{pmatrix} \\ \begin{pmatrix} 3 \\ -1 \end{pmatrix}=\\ \begin{pmatrix} 1 \\ 3 \end{pmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/95af5086112fead5258aabba524167e746969d37.png" style="height: 43px;" /> <p>Indeed, it checks out! Let's also verify the other direction. 
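</p>

<p>The same computations are easy to reproduce numerically. Here's a NumPy sketch; the arrays just restate this example's numbers, with basis vectors laid out as matrix columns:</p>

```python
import numpy as np

# Basis vectors of U and W as matrix columns.
U = np.array([[2.0, 4.0],
              [3.0, 5.0]])
W = np.array([[-1.0, 1.0],
              [1.0, 1.0]])

# Each column of A_{U->W} is a u vector expressed in basis W, i.e. the
# solution of W x = u. np.linalg.solve handles all the columns at once.
A_U_to_W = np.linalg.solve(W, U)
print(A_U_to_W)      # [[0.5 0.5]
                     #  [2.5 4.5]]

v_U = np.array([3.0, -1.0])
v_W = A_U_to_W @ v_U
print(v_W)           # [1. 3.]
```

<p>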
To find <img alt="A_{\text{\tiny W}\rightarrow \text{\tiny U}}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/7ead67eb4f93b24b2121304e1fa7fe62116cd30d.png" style="height: 16px;" /> we'll need <img alt="[w_1]_{\text {\tiny U}}" class="valign-m5" src="https://eli.thegreenplace.net/images/math/6312e3ac9abfb423dcc52556e5fe037845c20cb6.png" style="height: 18px;" /> and <img alt="[w_2]_{\text {\tiny U}}" class="valign-m5" src="https://eli.thegreenplace.net/images/math/783adfcaeb03f35b08d163ff1180bd038d777d0d.png" style="height: 18px;" />:</p> <img alt="$[w_1]_{\text {\tiny U}}=\begin{pmatrix} 4.5 \\ -2.5 \end{pmatrix}\qquad[w_2]_{\text {\tiny U}}=\begin{pmatrix}- 0.5 \\ 0.5 \end{pmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/028bc2e59c73e299a1cdbaf05df5ed605b737512.png" style="height: 43px;" /> <p>And now to find <img alt="[v]_{\text {\tiny U}}" class="valign-m5" src="https://eli.thegreenplace.net/images/math/344b247c2488e81b7d0dcb75a6f5addd8746d0d9.png" style="height: 18px;" />:</p> <img alt="$[v]_{\text{\tiny U}}=A_{\text{\tiny W}\rightarrow \text{\tiny U}}[v]_{\text{\tiny W}}= \\ \begin{pmatrix} 4.5 &amp;amp; -0.5 \\ -2.5 &amp;amp; 0.5 \end{pmatrix} \\ \begin{pmatrix} 1 \\ 3 \end{pmatrix}=\\ \begin{pmatrix} 3 \\ -1 \end{pmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/4e697d4673dcffaf65875566c470a2113defa3b0.png" style="height: 43px;" /> <p>Checks out again! If you have a keen eye, or have recently spent some time solving linear algebra problems, you'll notice something interesting about the two basis change matrices used in this example. One is the inverse of the other! Is this some sort of coincidence? 
No - in fact, it's always true, and we can prove it.</p> </div> <div class="section" id="the-inverse-of-a-change-of-basis-matrix"> <h2>The inverse of a change of basis matrix</h2> <p>We've derived the change of basis matrix from <img alt="U" class="valign-0" src="https://eli.thegreenplace.net/images/math/b2c7c0caa10a0cca5ea7d69e54018ae0c0389dd6.png" style="height: 12px;" /> to <img alt="W" class="valign-0" src="https://eli.thegreenplace.net/images/math/e2415cb7f63df0c9de23362326ad3c37a9adfc96.png" style="height: 12px;" /> to perform the conversion:</p> <img alt="$[v]_{\text{\tiny W}}=A_{\text{\tiny U}\rightarrow \text{\tiny W}}[v]_{\text{\tiny U}}$" class="align-center" src="https://eli.thegreenplace.net/images/math/40ce253d22d14c257a969dca84539e9d06be237d.png" style="height: 18px;" /> <p>Left-multiplying this equation by <img alt="A_{\text{\tiny W}\rightarrow \text{\tiny U}}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/2bcd0ad887dc21b0aa83a4d6dd234afebd2ed56a.png" style="height: 16px;" />:</p> <img alt="$A_{\text{\tiny W}\rightarrow \text{\tiny U}}[v]_{\text{\tiny W}}=\\ A_{\text{\tiny W}\rightarrow \text{\tiny U}}A_{\text{\tiny U}\rightarrow \text{\tiny W}}[v]_{\text{\tiny U}}$" class="align-center" src="https://eli.thegreenplace.net/images/math/487b88a4e75475a46621a4cdbffb7fc37e30c920.png" style="height: 18px;" /> <p>But the left-hand side is now, by our earlier definition, equal to <img alt="[v]_{\text{\tiny U}}" class="valign-m5" src="https://eli.thegreenplace.net/images/math/dc587d1ab07e4744144f02d47abbf148b6c339d4.png" style="height: 18px;" />, so we get:</p> <img alt="$[v]_{\text{\tiny U}}=\\ A_{\text{\tiny W}\rightarrow \text{\tiny U}}A_{\text{\tiny U}\rightarrow \text{\tiny W}}[v]_{\text{\tiny U}}$" class="align-center" src="https://eli.thegreenplace.net/images/math/8d6e70905656f42e21896dda00ce2590f6218766.png" style="height: 18px;" /> <p>Since this is true for every vector <img alt="[v]_{\text{\tiny U}}" class="valign-m5" 
src="https://eli.thegreenplace.net/images/math/dc587d1ab07e4744144f02d47abbf148b6c339d4.png" style="height: 18px;" />, it must be that:</p> <img alt="$A_{\text{\tiny W}\rightarrow \text{\tiny U}}A_{\text{\tiny U}\rightarrow \text{\tiny W}}=I$" class="align-center" src="https://eli.thegreenplace.net/images/math/ff93ca3481c29a874a9ab5e903321b1d6c4e38f0.png" style="height: 15px;" /> <p>From this, we can infer that <img alt="A_{\text{\tiny W}\rightarrow \text{\tiny U}}=A_{\text{\tiny U}\rightarrow \text{\tiny W}}^{-1}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/0055abbae1609478df5d1ff82e895db6c11a01d5.png" style="height: 20px;" /> and vice versa <a class="footnote-reference" href="#id10" id="id5"></a>.</p> </div> <div class="section" id="changing-to-and-from-the-standard-basis"> <h2>Changing to and from the standard basis</h2> <p>You may have noticed that in the examples above, we short-circuited a little bit of rigor by making up a vector (such as <img alt="v=(2,4)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/2c140a090873ddce6a3a86023428c2c72250791e.png" style="height: 18px;" />) without explicitly specifying the basis its components are relative to. This is because we're so used to working with the &quot;standard basis&quot; we often forget it's there.</p> <p>The standard basis (let's call it <img alt="E" class="valign-0" src="https://eli.thegreenplace.net/images/math/e0184adedf913b076626646d3f52c3b49c39ad6d.png" style="height: 12px;" />) consists of unit vectors pointing in the directions of the axes of a Cartesian coordinate system. 
For <img alt="\mathbb{R}^2" class="valign-0" src="https://eli.thegreenplace.net/images/math/2b688757b3d0949451e1fa97e71ac5f5f284a5e4.png" style="height: 15px;" /> we have the basis vectors:</p> <img alt="$e_1=\begin{pmatrix} 1 \\ 0 \end{pmatrix}\qquad e_2=\begin{pmatrix} 0 \\ 1 \end{pmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/cae83340f574d5d86529a8fc7e8bb578337b027b.png" style="height: 43px;" /> <p>And more generally in <img alt="\mathbb{R}^n" class="valign-0" src="https://eli.thegreenplace.net/images/math/98165cf6e8d5d442e040d1fa47aa6845f09294c5.png" style="height: 12px;" /> we have an ordered list of <img alt="n" class="valign-0" src="https://eli.thegreenplace.net/images/math/d1854cae891ec7b29161ccaf79a24b00c274bdaa.png" style="height: 8px;" /> vectors <img alt="\left\{ e_i:1\leq i \leq n \right\}" class="valign-m5" src="https://eli.thegreenplace.net/images/math/465709e067cd5b2878edc2a90ccc3a6074bb1a25.png" style="height: 18px;" /> where <img alt="e_i" class="valign-m3" src="https://eli.thegreenplace.net/images/math/067d6602e65a6d628c3a60782ace6c359848f4bc.png" style="height: 11px;" /> has 1 in the <img alt="i" class="valign-0" src="https://eli.thegreenplace.net/images/math/042dc4512fa3d391c5170cf3aa61e6a638f84342.png" style="height: 12px;" />th position and zeros elsewhere.</p> <p>So when we say <img alt="v=(2,4)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/2c140a090873ddce6a3a86023428c2c72250791e.png" style="height: 18px;" />, what we actually mean is:</p> <img alt="$\begin{matrix} v=2e_1+4e_2 \\[1em] [v]_{\text {\tiny E}}=\begin{pmatrix} 2 \\ 4 \end{pmatrix} \end{matrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/6e3fdb31e8f3be4dab7cd31785a9b649dfd472bd.png" style="height: 80px;" /> <p>The standard basis is so ingrained in our intuition of vectors that we usually neglect to mention it. This is fine, as long as we're only dealing with the standard basis. 
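</p>

<p>In code, the standard basis is just the columns of the identity matrix, and a vector's components relative to it are the vector itself. A small NumPy illustration:</p>

```python
import numpy as np

E = np.eye(2)              # the columns of the identity are e_1 and e_2
e1, e2 = E[:, 0], E[:, 1]

# Saying v = (2, 4) really means v = 2*e1 + 4*e2.
v = 2.0 * e1 + 4.0 * e2
print(v)                   # [2. 4.]
```

<p>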
Once change of basis is required, it's worthwhile to stick to a more consistent notation to avoid confusion. Moreover, it's often useful to change a vector's basis to or from the standard one. Let's see how that works. Recall how we use the change of basis matrix:</p> <img alt="$[v]_{\text{\tiny W}}=A_{\text{\tiny U}\rightarrow \text{\tiny W}}[v]_{\text{\tiny U}}$" class="align-center" src="https://eli.thegreenplace.net/images/math/40ce253d22d14c257a969dca84539e9d06be237d.png" style="height: 18px;" /> <p>Replacing the arbitrary basis <img alt="W" class="valign-0" src="https://eli.thegreenplace.net/images/math/e2415cb7f63df0c9de23362326ad3c37a9adfc96.png" style="height: 12px;" /> by the standard basis <img alt="E" class="valign-0" src="https://eli.thegreenplace.net/images/math/e0184adedf913b076626646d3f52c3b49c39ad6d.png" style="height: 12px;" /> in this equation, we get:</p> <img alt="$[v]_{\text{\tiny E}}=A_{\text{\tiny U}\rightarrow \text{\tiny E}}[v]_{\text{\tiny U}}$" class="align-center" src="https://eli.thegreenplace.net/images/math/db4f797209370d41a5069009b16269967a6ba3ea.png" style="height: 18px;" /> <p>And <img alt="A_{\text{\tiny U}\rightarrow \text{\tiny E}}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/640b44ab89c8033a037e90b0e661937ade5327a4.png" style="height: 16px;" /> is the matrix with <img alt="[u_1]_{\text {\tiny E}}" class="valign-m5" src="https://eli.thegreenplace.net/images/math/d74589a95920181b845aff60ea0ece107b2bd337.png" style="height: 18px;" /> to <img alt="[u_n]_{\text {\tiny E}}" class="valign-m5" src="https://eli.thegreenplace.net/images/math/2dbd88f177d5bc8bdd5ab5aac11c4b465d9a7406.png" style="height: 18px;" /> in its columns. But wait, these are just the basis vectors of <img alt="U" class="valign-0" src="https://eli.thegreenplace.net/images/math/b2c7c0caa10a0cca5ea7d69e54018ae0c0389dd6.png" style="height: 12px;" />! 
So finding the matrix <img alt="A_{\text{\tiny U}\rightarrow \text{\tiny E}}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/640b44ab89c8033a037e90b0e661937ade5327a4.png" style="height: 16px;" /> for any given basis <img alt="U" class="valign-0" src="https://eli.thegreenplace.net/images/math/b2c7c0caa10a0cca5ea7d69e54018ae0c0389dd6.png" style="height: 12px;" /> is trivial - simply line up <img alt="U" class="valign-0" src="https://eli.thegreenplace.net/images/math/b2c7c0caa10a0cca5ea7d69e54018ae0c0389dd6.png" style="height: 12px;" />'s basis vectors as columns in their order to get a matrix. This means that any square, invertible matrix can be seen as a change of basis matrix from the basis spelled out in its columns to the standard basis. This is a natural consequence of how multiplying a matrix by a vector works by <a class="reference external" href="http://eli.thegreenplace.net/2015/visualizing-matrix-multiplication-as-a-linear-combination">linearly combining the matrix's columns</a>.</p> <p>OK, so we know how to find <img alt="[v]_{\text {\tiny E}}" class="valign-m5" src="https://eli.thegreenplace.net/images/math/8ce451ac6ebbdcc88171aa67947c14a62f81a6d8.png" style="height: 18px;" /> given <img alt="[v]_{\text {\tiny U}}" class="valign-m5" src="https://eli.thegreenplace.net/images/math/83026ce3725f05dfcac554c4d85a364a840e8958.png" style="height: 18px;" />. What about the other way around? 
We'll need <img alt="A_{\text{\tiny E}\rightarrow \text{\tiny U}}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/d1ff7b0595a95dd95bac5689d042db02fdbabdf3.png" style="height: 16px;" /> for that, and we know that:</p> <img alt="$A_{\text{\tiny E}\rightarrow \text{\tiny U}}=A_{\text{\tiny U}\rightarrow \text{\tiny E}}^{-1}$" class="align-center" src="https://eli.thegreenplace.net/images/math/60679b364e239e7764aecc891a5579a3fc204ea3.png" style="height: 22px;" /> <p>Therefore:</p> <img alt="$[v]_{\text{\tiny U}}=\\ A_{\text{\tiny E}\rightarrow \text{\tiny U}}[v]_{\text{\tiny E}}=\\ A_{\text{\tiny U}\rightarrow \text{\tiny E}}^{-1}[v]_{\text{\tiny E}}$" class="align-center" src="https://eli.thegreenplace.net/images/math/c9c3e8722c55c44dc21aec2ba823cdb0c1f8a5a0.png" style="height: 22px;" /> </div> <div class="section" id="chaining-basis-changes"> <h2>Chaining basis changes</h2> <p>What happens if we change a vector from one basis to another, and then change the resulting vector to yet another basis? 
I mean, for bases <img alt="U" class="valign-0" src="https://eli.thegreenplace.net/images/math/b2c7c0caa10a0cca5ea7d69e54018ae0c0389dd6.png" style="height: 12px;" />, <img alt="W" class="valign-0" src="https://eli.thegreenplace.net/images/math/e2415cb7f63df0c9de23362326ad3c37a9adfc96.png" style="height: 12px;" /> and <img alt="T" class="valign-0" src="https://eli.thegreenplace.net/images/math/c2c53d66948214258a26ca9ca845d7ac0c17f8e7.png" style="height: 12px;" /> and some arbitrary vector <img alt="v" class="valign-0" src="https://eli.thegreenplace.net/images/math/7a38d8cbd20d9932ba948efaa364bb62651d5ad4.png" style="height: 8px;" />, we'll do:</p> <img alt="$A_{\text{\tiny W}\rightarrow \text{\tiny T}}A_{\text{\tiny U}\rightarrow \text{\tiny W}}[v]_{\text{\tiny U}}$" class="align-center" src="https://eli.thegreenplace.net/images/math/69c8d309f1e7fb5d65bd72570cdadc1769d315f0.png" style="height: 18px;" /> <p>This is simply applying the change of basis by matrix multiplication equation, twice:</p> <img alt="$A_{\text{\tiny W}\rightarrow \text{\tiny T}}(A_{\text{\tiny U}\rightarrow \text{\tiny W}}[v]_{\text{\tiny U}})=\\ A_{\text{\tiny W}\rightarrow \text{\tiny T}}[v]_{\text{\tiny W}}\\ =[v]_{\text{\tiny T}}$" class="align-center" src="https://eli.thegreenplace.net/images/math/f80dd3f39e7e736972e31ec553e8541628ab038c.png" style="height: 19px;" /> <p>What this means is that changes of basis can be chained, which isn't surprising given their linear nature. 
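</p>

<p>Chaining is also easy to check numerically. In the NumPy sketch below, the helper <tt>change_of_basis</tt> and the third basis <tt>T</tt> are made up for illustration; basis vectors are laid out as matrix columns:</p>

```python
import numpy as np

# Basis vectors as columns; T is an arbitrary third basis for this sketch.
U = np.array([[2.0, 4.0],
              [3.0, 5.0]])
W = np.array([[-1.0, 1.0],
              [1.0, 1.0]])
T = np.array([[1.0, 1.0],
              [0.0, 1.0]])

def change_of_basis(X, Y):
    # A_{X->Y}: each column solves Y a = x for the matching column x of X.
    return np.linalg.solve(Y, X)

v_U = np.array([3.0, -1.0])

# Converting U -> W and then W -> T...
chained = change_of_basis(W, T) @ (change_of_basis(U, W) @ v_U)
# ...gives the same result as converting U -> T in one step.
direct = change_of_basis(U, T) @ v_U
print(np.allclose(chained, direct))   # True
```

<p>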
It also means that we've just found <img alt="A_{\text{\tiny U}\rightarrow \text{\tiny T}}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/48f2afad29bdd5da5db7740b67f78541aa502ac6.png" style="height: 16px;" />, since we found how to transform <img alt="[v]_{\text{\tiny U}}" class="valign-m5" src="https://eli.thegreenplace.net/images/math/dc587d1ab07e4744144f02d47abbf148b6c339d4.png" style="height: 18px;" /> to <img alt="[v]_{\text{\tiny T}}" class="valign-m5" src="https://eli.thegreenplace.net/images/math/d3758ce460ab170a8949bab249de63bc9bb0e739.png" style="height: 18px;" /> (using an intermediary basis <img alt="W" class="valign-0" src="https://eli.thegreenplace.net/images/math/e2415cb7f63df0c9de23362326ad3c37a9adfc96.png" style="height: 12px;" />).</p> <img alt="$A_{\text{\tiny U}\rightarrow \text{\tiny T}}=\\ A_{\text{\tiny W}\rightarrow \text{\tiny T}}A_{\text{\tiny U}\rightarrow \text{\tiny W}}$" class="align-center" src="https://eli.thegreenplace.net/images/math/e1235761ce18868df7f936908c80f49c464550bc.png" style="height: 15px;" /> <p>Finally, let's say that the intermediary basis is not just some arbitrary <img alt="W" class="valign-0" src="https://eli.thegreenplace.net/images/math/e2415cb7f63df0c9de23362326ad3c37a9adfc96.png" style="height: 12px;" />, but the standard basis <img alt="E" class="valign-0" src="https://eli.thegreenplace.net/images/math/e0184adedf913b076626646d3f52c3b49c39ad6d.png" style="height: 12px;" />. 
So we have:</p> <img alt="$A_{\text{\tiny U}\rightarrow \text{\tiny T}}=\\ A_{\text{\tiny E}\rightarrow \text{\tiny T}}A_{\text{\tiny U}\rightarrow \text{\tiny E}}=\\ A_{\text{\tiny T}\rightarrow \text{\tiny E}}^{-1}A_{\text{\tiny U}\rightarrow \text{\tiny E}}$" class="align-center" src="https://eli.thegreenplace.net/images/math/5e2c2ecd7ad9e15cfa07d8d9f5eef1c26479c4cd.png" style="height: 22px;" /> <p>We prefer the last form, since finding <img alt="A_{\text{\tiny U}\rightarrow \text{\tiny E}}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/35e8996ab14387788bcd66d1a160d0458efdc05f.png" style="height: 16px;" /> for any basis <img alt="U" class="valign-0" src="https://eli.thegreenplace.net/images/math/b2c7c0caa10a0cca5ea7d69e54018ae0c0389dd6.png" style="height: 12px;" /> is, as we've seen above, trivial.</p> </div> <div class="section" id="example-standard-basis-and-chaining"> <h2>Example: standard basis and chaining</h2> <p>It's time to solidify the ideas of the last two sections with a concrete example. We'll use our familiar bases <img alt="U=(2,3), (4,5)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/8bbbb50c23562c7c7dfe92d51af940199d7b366e.png" style="height: 18px;" /> and <img alt="W=(-1,1), (1,1)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/880182e231343144829cea71b4a367c8308bfff1.png" style="height: 18px;" /> from the previous example, along with the standard basis for <img alt="\mathbb{R}^2" class="valign-0" src="https://eli.thegreenplace.net/images/math/2b688757b3d0949451e1fa97e71ac5f5f284a5e4.png" style="height: 15px;" />. 
Previously, we transformed a vector <img alt="v" class="valign-0" src="https://eli.thegreenplace.net/images/math/7a38d8cbd20d9932ba948efaa364bb62651d5ad4.png" style="height: 8px;" /> from <img alt="U" class="valign-0" src="https://eli.thegreenplace.net/images/math/b2c7c0caa10a0cca5ea7d69e54018ae0c0389dd6.png" style="height: 12px;" /> to <img alt="W" class="valign-0" src="https://eli.thegreenplace.net/images/math/e2415cb7f63df0c9de23362326ad3c37a9adfc96.png" style="height: 12px;" /> and vice-versa using the change of basis matrices between these bases. This time, let's do it by chaining via the standard basis.</p> <p>We'll pick <img alt="v=(2,4)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/2c140a090873ddce6a3a86023428c2c72250791e.png" style="height: 18px;" />. Formally, the components of <img alt="v" class="valign-0" src="https://eli.thegreenplace.net/images/math/7a38d8cbd20d9932ba948efaa364bb62651d5ad4.png" style="height: 8px;" /> relative to the standard basis are:</p> <img alt="$[v]_{\text{\tiny E}} = \begin{pmatrix} 2 \\ 4 \end{pmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/69766b4a4bc3f500ac1f25d3367774958f163084.png" style="height: 43px;" /> <p>In the last example we've already computed the components of <img alt="v" class="valign-0" src="https://eli.thegreenplace.net/images/math/7a38d8cbd20d9932ba948efaa364bb62651d5ad4.png" style="height: 8px;" /> relative to <img alt="U" class="valign-0" src="https://eli.thegreenplace.net/images/math/b2c7c0caa10a0cca5ea7d69e54018ae0c0389dd6.png" style="height: 12px;" /> and <img alt="W" class="valign-0" src="https://eli.thegreenplace.net/images/math/e2415cb7f63df0c9de23362326ad3c37a9adfc96.png" style="height: 12px;" />:</p> <img alt="$[v]_{\text {\tiny U}}=\begin{pmatrix} 3 \\ -1 \end{pmatrix}\qquad [v]_{\text {\tiny W}}=\begin{pmatrix} 1 \\ 3 \end{pmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/c7cac93827ffa171be55db031004356516fb98fa.png" 
style="height: 43px;" /> <p>Previously, one was computed from the other using the &quot;direct&quot; basis change matrices from <img alt="U" class="valign-0" src="https://eli.thegreenplace.net/images/math/b2c7c0caa10a0cca5ea7d69e54018ae0c0389dd6.png" style="height: 12px;" /> to <img alt="W" class="valign-0" src="https://eli.thegreenplace.net/images/math/e2415cb7f63df0c9de23362326ad3c37a9adfc96.png" style="height: 12px;" /> and vice versa. Now we can use chaining via the standard basis to achieve the same result. For example, we know that:</p> <img alt="$[v]_{\text{\tiny W}}=\\ A_{\text{\tiny E}\rightarrow \text{\tiny W}}A_{\text{\tiny U}\rightarrow \text{\tiny E}}[v]_{\text{\tiny U}}$" class="align-center" src="https://eli.thegreenplace.net/images/math/ec831bdd78639c2bf290e705c7efb0cb4908cd16.png" style="height: 18px;" /> <p>Finding the change of basis matrices from some basis to <img alt="E" class="valign-0" src="https://eli.thegreenplace.net/images/math/e0184adedf913b076626646d3f52c3b49c39ad6d.png" style="height: 12px;" /> is just laying out the basis vectors as columns, so we immediately know that:</p> <img alt="$A_{\text{\tiny U}\rightarrow \text{\tiny E}}=\begin{pmatrix} 2 &amp;amp; 4\\ 3 &amp;amp; 5 \end{pmatrix}\qquad \qquad \\ A_{\text{\tiny W}\rightarrow \text{\tiny E}}=\begin{pmatrix} -1 &amp;amp; 1\\ 1 &amp;amp; 1 \end{pmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/56debe1dbb75f938c25fe7b4645b6602bb13e637.png" style="height: 43px;" /> <p>The change of basis matrix from <img alt="E" class="valign-0" src="https://eli.thegreenplace.net/images/math/e0184adedf913b076626646d3f52c3b49c39ad6d.png" style="height: 12px;" /> to some basis is the inverse, so by inverting the above matrices we find:</p> <img alt="$A_{\text{\tiny E}\rightarrow \text{\tiny U}}=A_{\text{\tiny U}\rightarrow \text{\tiny E}}^{-1}=\begin{pmatrix} -2.5 &amp;amp; 2 \\ 1.5 &amp;amp; -1 \end{pmatrix}\qquad \qquad \\ A_{\text{\tiny E}\rightarrow \text{\tiny 
W}}=A_{\text{\tiny W}\rightarrow \text{\tiny E}}^{-1}=\begin{pmatrix} -0.5 &amp;amp; 0.5 \\ 0.5 &amp;amp; 0.5 \end{pmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/8810ea8701a2aba7df82e88302df93c890a26e26.png" style="height: 43px;" /> <p>Now we have all we need to find <img alt="[v]_{\text{\tiny W}}" class="valign-m5" src="https://eli.thegreenplace.net/images/math/f8ae14559a972991d3c72d1014db829284f86f6a.png" style="height: 18px;" /> from <img alt="[v]_{\text{\tiny U}}" class="valign-m5" src="https://eli.thegreenplace.net/images/math/dc587d1ab07e4744144f02d47abbf148b6c339d4.png" style="height: 18px;" />:</p> <img alt="$[v]_{\text{\tiny W}}=\\ A_{\text{\tiny E}\rightarrow \text{\tiny W}}A_{\text{\tiny U}\rightarrow \text{\tiny E}}[v]_{\text{\tiny U}}=\begin{pmatrix} -0.5 &amp;amp; 0.5 \\ 0.5 &amp;amp; 0.5 \end{pmatrix}\begin{pmatrix} 2 &amp;amp; 4\\ 3 &amp;amp; 5 \end{pmatrix}\begin{pmatrix} 3 \\ -1 \end{pmatrix}=\begin{pmatrix} 1 \\ 3 \end{pmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/1f5875aba56faf3fda3c9f4c72b1421529958116.png" style="height: 43px;" /> <p>The other direction can be done similarly.</p> <hr class="docutils" /> <table class="docutils footnote" frame="void" id="id6" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id1"></a></td><td><em>Introduction to Linear Algebra</em>, 4th edition, section 7.2</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id7" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id2"></a></td><td>Why is this list unique? 
Because given a basis <img alt="U" class="valign-0" src="https://eli.thegreenplace.net/images/math/b2c7c0caa10a0cca5ea7d69e54018ae0c0389dd6.png" style="height: 12px;" /> for a vector space <img alt="V" class="valign-0" src="https://eli.thegreenplace.net/images/math/c9ee5681d3c59f7541c27a38b67edf46259e187b.png" style="height: 12px;" />, every <img alt="v\in V" class="valign-m1" src="https://eli.thegreenplace.net/images/math/081239435d752122bef07934bbfe0662cc5228e6.png" style="height: 13px;" /> can be expressed <em>uniquely</em> as a linear combination of the vectors in <img alt="U" class="valign-0" src="https://eli.thegreenplace.net/images/math/b2c7c0caa10a0cca5ea7d69e54018ae0c0389dd6.png" style="height: 12px;" />. The proof for this is very simple - just assume there are two different ways to express <img alt="v" class="valign-0" src="https://eli.thegreenplace.net/images/math/7a38d8cbd20d9932ba948efaa364bb62651d5ad4.png" style="height: 8px;" /> - two alternative sets of components. Subtract one from the other and use linear independence of the basis vectors to conclude that the two ways must be the same one.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id8" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id3"></a></td><td>The matrix here has the basis vectors laid out in its columns. Since the basis vectors are independent, the matrix is invertible. 
In our small example, the matrix equation we're looking to solve is:</td></tr> </tbody> </table> <img alt="$\begin{pmatrix} 2 &amp;amp; 4 \\ 3 &amp;amp; 5 \end{pmatrix}\begin{pmatrix} c_1 \\ c_2 \end{pmatrix}=\begin{pmatrix} 2 \\ 4 \end{pmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/6d840237e5940eaadf2002f888e8537e48e90158.png" style="height: 43px;" /> <table class="docutils footnote" frame="void" id="id9" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id4"></a></td><td>The example converts from the standard basis to some other basis, but converting from a non-standard basis to another requires exactly the same steps: we try to find coefficients such that a combination of some set of basis vectors adds up to some components in another basis.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id10" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id5"></a></td><td>For square matrices <img alt="A" class="valign-0" src="https://eli.thegreenplace.net/images/math/6dcd4ce23d88e2ee9568ba546c007c63d9131c1b.png" style="height: 12px;" /> and <img alt="B" class="valign-0" src="https://eli.thegreenplace.net/images/math/ae4f281df5a5d0ff3cad6371f76d5c29b6d953ec.png" style="height: 12px;" />, if <img alt="AB=I" class="valign-0" src="https://eli.thegreenplace.net/images/math/845d8defe3847392f0e4b18b07786cd7f47ddf74.png" style="height: 12px;" /> then also <img alt="BA=I" class="valign-0" src="https://eli.thegreenplace.net/images/math/b1d87d31f656d8634f1e2d862810272a919c2806.png" style="height: 12px;" />.</td></tr> </tbody> </table> </div> The Normal Equation and matrix calculus2015-05-27T06:19:00-07:002015-05-27T06:19:00-07:00Eli Benderskytag:eli.thegreenplace.net,2015-05-27:/2015/the-normal-equation-and-matrix-calculus/<p>A few months ago I wrote <a 
class="reference external" href="http://eli.thegreenplace.net/2014/derivation-of-the-normal-equation-for-linear-regression">a post</a> on formulating the Normal Equation for linear regression. A crucial part in the formulation is using <a class="reference external" href="http://en.wikipedia.org/wiki/Matrix_calculus">matrix calculus</a> to compute a scalar-by-vector derivative. I didn't spend much time explaining how this step works, instead remarking:</p> <blockquote> Deriving by a vector may feel uncomfortable, but there's nothing to worry about. Recall that here we only use matrix notation to conveniently represent a system of linear formulae. So we derive by each component of the vector, and then combine the resulting derivatives into a vector again.</blockquote> <p>According to the comments received on the post, folks didn't find this convincing and asked for more details. One commenter even said that &quot;matrix calculus feels handwavy&quot;, something which I fully agree with. The reason matrix calculus feels handwavy is that it's not as commonly encountered as &quot;regular&quot; calculus, and hence its identities and intuitions are not as familiar. However, there's really not that much to it, as I want to show here.</p> <p>Let's get started with a simple example, which I'll use to demonstrate the principles. 
Say we have the function:</p> <img alt="$f(v)=a^Tv$" class="align-center" src="https://eli.thegreenplace.net/images/math/94f87149715376908db65a00f793836a4b2092a9.png" style="height: 21px;" /> <p>Where <strong>a</strong> and <strong>v</strong> are vectors with <em>n</em> components <a class="footnote-reference" href="#id4" id="id1"></a>. We want to compute its derivative by <strong>v</strong>. But wait, while a &quot;regular&quot; derivative by a scalar is clearly defined (using limits), what does deriving by a vector mean? It simply means that we derive by each component of the vector separately, and then combine the results into a new vector <a class="footnote-reference" href="#id5" id="id2"></a>. In other words:</p> <img alt="$\frac{\partial f}{\partial v}=\begin{pmatrix}\frac{\partial f}{\partial v_1}\\[1em] \frac{\partial f}{\partial v_2}\\ ...\\ \frac{\partial f}{\partial v_n}\\[1em] \end{pmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/13d227107c5323f47460ad077504fda60726d933.png" style="height: 131px;" /> <p>Let's see how this works out for our function <em>f</em>. 
It may be more convenient to rewrite it by using components rather than vector notation:</p> <img alt="$f(v)=a^Tv=a_1v_1+a_2v_2+...+a_nv_n$" class="align-center" src="https://eli.thegreenplace.net/images/math/e9e17e44bb85d825f304b09247a7f3cfbe11f64e.png" style="height: 21px;" /> <p>Computing the derivatives by each component, we'll get:</p> <img alt="$\begin{matrix}\frac{\partial f}{\partial v_1}=a_1\\[1em] \frac{\partial f}{\partial v_2}=a_2\\ ...\\ \frac{\partial f}{\partial v_n}=a_n\\[1em] \end{matrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/768563e5b7e2e8cddd00830e9b945419f598e4bb.png" style="height: 114px;" /> <p>So we have a sequence of partial derivatives, which we combine into a vector:</p> <img alt="$\frac{\partial f}{\partial v}=\begin{pmatrix}a_1\\ ...\\ a_n\\ \end{pmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/b13cc64568603d73240709c1fb49cfcc7f2a2b62.png" style="height: 65px;" /> <p>Or, in other words <img alt="\frac{\partial f}{\partial v}=a" class="valign-m7" src="https://eli.thegreenplace.net/images/math/1f3eaea99f7fab11ac1b70dc8b618635a9ed4c91.png" style="height: 25px;" />.</p> <p>This example demonstrates the algorithm for computing scalar-by-vector derivatives:</p> <ol class="arabic simple"> <li>Figure out what the dimensions of all vectors and matrices are.</li> <li>Expand the vector equations into their full form (a multiplication of two vectors is either a scalar or a matrix, depending on their orientation, etc.) Note that this will end up with a scalar.</li> <li>Compute the derivative of the scalar by each component of the variable vector separately.</li> <li>Combine the derivatives into a vector.</li> </ol> <p>Similarly to regular calculus, matrix and vector calculus rely on a set of identities to make computations more manageable. 
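Before reaching for identities, the result above can be verified mechanically: approximate each partial derivative with a central finite difference and compare against a. A minimal NumPy sketch (the vectors here are made-up example data):

```python
import numpy as np

# f(v) = a^T v for a fixed vector a; the claim is that df/dv = a.
a = np.array([2.0, -1.0, 3.0])
f = lambda v: a @ v

v0 = np.array([0.5, 1.5, -2.0])   # an arbitrary point
eps = 1e-6

# Central finite difference for each component of v separately.
grad = np.zeros_like(v0)
for i in range(len(v0)):
    step = np.zeros_like(v0)
    step[i] = eps
    grad[i] = (f(v0 + step) - f(v0 - step)) / (2 * eps)

assert np.allclose(grad, a)   # the componentwise derivatives collect into a
```

Since f is linear, the finite difference is exact up to floating point, and the gradient is a regardless of the point v0 chosen.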
We can either go the hard way (computing the derivative of each function from basic principles using limits), or the easy way - applying the plethora of convenient identities that were developed to make this task simpler. The identity for computing the derivative of <img alt="a^Tv" class="valign-0" src="https://eli.thegreenplace.net/images/math/ea7bffcd29c6bad40e358ad7313102670fb1a9cf.png" style="height: 15px;" /> shown above plays the role of <img alt="\frac{d}{dx}ax=a" class="valign-m6" src="https://eli.thegreenplace.net/images/math/999f262480b3690892d0af5651b96160d924997e.png" style="height: 22px;" /> in regular calculus.</p> <p>Now we have the tools to understand how the vector derivatives in the <a class="reference external" href="http://eli.thegreenplace.net/2014/derivation-of-the-normal-equation-for-linear-regression">normal equation article</a> were computed. As a reminder, this is the matrix form of the cost function <em>J</em>:</p> <img alt="$J(\theta)=\theta^TX^TX\theta-2(X\theta)^Ty+y^Ty$" class="align-center" src="https://eli.thegreenplace.net/images/math/2864b88546c007a79dc92271f5e01487ba608e43.png" style="height: 21px;" /> <p>And we're interested in computing <img alt="\frac{\partial J}{\partial \theta}" class="valign-m7" src="https://eli.thegreenplace.net/images/math/27ffac3eede7fce0b342abf8fc10d29f24c68263.png" style="height: 24px;" />. The equation for <em>J</em> consists of three terms added together. The last one <img alt="y^Ty" class="valign-m4" src="https://eli.thegreenplace.net/images/math/81015d6225923cec985bef47ca151ef1cb654c92.png" style="height: 19px;" /> doesn't contribute to the derivative because it doesn't depend on the variable. 
Let's start looking at the second (since it's simpler than the first) - and give it a name, for convenience:</p> <img alt="$P(\theta)=2(X\theta)^Ty$" class="align-center" src="https://eli.thegreenplace.net/images/math/35d3ddf05898e8bc2085030aa399ce98318674f9.png" style="height: 21px;" /> <p>We'll start by recalling what all the dimensions are. <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> is a vector of n components. <img alt="y" class="valign-m4" src="https://eli.thegreenplace.net/images/math/95cb0bfd2977c761298d9624e4b4d4c72a39974a.png" style="height: 12px;" /> is a vector of m components. <img alt="X" class="valign-0" src="https://eli.thegreenplace.net/images/math/c032adc1ff629c9b66f22749ad667e6beadf144b.png" style="height: 12px;" /> is an m-by-n matrix.</p> <p>Let's see what <em>P</em> expands to <a class="footnote-reference" href="#id6" id="id3"></a>:</p> <img alt="$P(\theta)=2\left [ \begin{pmatrix} x_1_1 &amp;amp; x_1_2 &amp;amp; ... &amp;amp; x_1_n\\ x_2_1 &amp;amp; ... &amp;amp; ... &amp;amp; x_2_n\\ ...\\ x_m_1 &amp;amp; ... &amp;amp; ... 
&amp;amp; x_m_n\\ \end{pmatrix}\begin{pmatrix} \theta_1\\ \theta_2\\ ...\\ \theta_n\\ \end{pmatrix} \right ]^T\begin{pmatrix} y_1\\ y_2\\ ...\\ y_m\\ \end{pmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/a7873ed04e274b30852e0f8d9450b5abc200ac17.png" style="height: 91px;" /> <p>Computing the matrix-by-vector multiplication inside the parens:</p> <img alt="$P(x)=2\left [ \begin{pmatrix} x_1_1\theta_1+...+x_1_n\theta_n\\ x_2_1\theta_1+...+x_2_n\theta_n\\ ...\\ x_m_1\theta_1+...+x_m_n\theta_n \end{pmatrix} \right ]^T\begin{pmatrix} y_1\\ y_2\\ ...\\ y_m\\ \end{pmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/6b9b8e2335579f352a19ef3da609be2e8b2d9925.png" style="height: 91px;" /> <p>And finally, multiplying the two vectors together:</p> <img alt="$P(x)=2(x_1_1\theta_1+...+x_1_n\theta_n)y_1+2(x_2_1\theta_1+...+x_2_n\theta_n)y_2+...+2(x_m_1\theta_1+...+x_m_n\theta_n)y_m$" class="align-center" src="https://eli.thegreenplace.net/images/math/3271758ac98b149969516dd809fd35b90aacf056.png" style="height: 18px;" /> <p>Working with such formulae makes you appreciate why mathematicians have long ago come up with shorthand notations like &quot;sigma&quot; summation:</p> <img alt="$P(x)=2\sum_{r=1}^{m}y_r(x_r_1\theta_1+...+x_r_n\theta_n)=2\sum_{r=1}^{m}y_r\sum_{c=1}^{n}x_r_c\theta_c$" class="align-center" src="https://eli.thegreenplace.net/images/math/6c71eb575ab3fafbc7b268be33d0d17a37bb1553.png" style="height: 50px;" /> <p>OK, so we've finally completed step 2 of the algorithm - we have the scalar equation for <em>P</em>. 
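As a sanity check that the expansion didn't lose anything, the sigma form can be compared against the original matrix form on small random data (sizes and values here are arbitrary):

```python
import numpy as np

# Small random instance just to exercise the algebra.
rng = np.random.default_rng(0)
m, n = 4, 3
X = rng.standard_normal((m, n))
theta = rng.standard_normal(n)
y = rng.standard_normal(m)

# Matrix form: P = 2 (X theta)^T y
P_matrix = 2 * (X @ theta) @ y

# Fully expanded sigma form: 2 * sum_r y_r * (sum_c x_rc * theta_c)
P_sigma = 2 * sum(y[r] * sum(X[r, c] * theta[c] for c in range(n))
                  for r in range(m))

assert np.isclose(P_matrix, P_sigma)
```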
Now it's time to compute its derivative by each <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" />:</p> <img alt="$\begin{matrix} \frac{\partial P}{\partial \theta_1}=2(x_1_1y_1+...+x_m_1y_m)\\[1em] \frac{\partial P}{\partial \theta_2}=2(x_1_2y_1+...+x_m_2y_m)\\ ...\\ \frac{\partial P}{\partial \theta_n}=2(x_1_ny_1+...+x_m_ny_m) \end{matrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/889eb3c4e50b4fbdf5380c4d4e31ac4c0c09dddd.png" style="height: 111px;" /> <p>Now comes the most interesting part. If we treat <img alt="\frac{\partial P}{\partial \theta}" class="valign-m7" src="https://eli.thegreenplace.net/images/math/3c653fa292156c8914f1463fcb6869633d37487c.png" style="height: 24px;" /> as a vector of n components, we can rewrite this set of equations using a matrix-by-vector multiplication:</p> <img alt="$\frac{\partial P}{\partial \theta}=2X^Ty$" class="align-center" src="https://eli.thegreenplace.net/images/math/7f75aa0f038ca73c58e95ef604ffb54468a18ae2.png" style="height: 38px;" /> <p>Take a moment to convince yourself this is true. It's just collecting the individual components of <strong>X</strong> into a matrix and the individual components of <strong>y</strong> into a vector. Since <strong>X</strong> is an m-by-n matrix and <strong>y</strong> is an m-by-1 column vector, the dimensions work out and the result is an n-by-1 column vector.</p> <p>So we've just computed the second term of the vector derivative of <em>J</em>. 
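If you'd rather have the computer do the convincing, the same finite-difference trick from before confirms the result on random data (a sketch; sizes and values are made up):

```python
import numpy as np

rng = np.random.default_rng(1)
m, n = 5, 3
X = rng.standard_normal((m, n))
y = rng.standard_normal(m)

P = lambda theta: 2 * (X @ theta) @ y   # scalar-valued function of theta

theta0 = rng.standard_normal(n)
eps = 1e-6
I = np.eye(n)

# Finite-difference gradient, one theta component at a time.
num_grad = np.array([(P(theta0 + eps * I[i]) - P(theta0 - eps * I[i])) / (2 * eps)
                     for i in range(n)])

# Matches the derived identity dP/dtheta = 2 X^T y.
assert np.allclose(num_grad, 2 * X.T @ y, atol=1e-6)
```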
In the process, we've discovered a useful vector derivative identity for a matrix <strong>X</strong> and vectors <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> and <strong>y</strong>:</p> <img alt="$\frac{\partial (X\theta)^T y}{\partial \theta}=X^Ty$" class="align-center" src="https://eli.thegreenplace.net/images/math/bf7325787bc464f067372a6d4ed612ea514d29b6.png" style="height: 41px;" /> <p>OK, now let's get back to the full definition of <em>J</em> and see how to compute the derivative of its first term. We'll give it the name <em>Q</em>:</p> <img alt="$Q(\theta)=\theta^TX^TX\theta$" class="align-center" src="https://eli.thegreenplace.net/images/math/0031acbab8dba6cef63f2605a15a0b7bc826766a.png" style="height: 21px;" /> <p>This derivation is somewhat more complex, since <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> appears twice in the equation. Here's the equation again with all the matrices and vectors fully laid out (note that I've already done the transposes):</p> <img alt="$Q(\theta)=(\theta_1...\theta_n)\begin{pmatrix}x_1_1 &amp;amp; x_2_1 &amp;amp; ... &amp;amp; x_m_1\\ x_1_2 &amp;amp; ... &amp;amp; ... &amp;amp; x_m_2\\ ...\\ x_1_n &amp;amp; ... &amp;amp; ... &amp;amp; x_m_n\\ \end{pmatrix}\begin{pmatrix}x_1_1 &amp;amp; x_1_2 &amp;amp; ... &amp;amp; x_1_n\\ x_2_1 &amp;amp; ... &amp;amp; ... &amp;amp; x_2_n\\ ...\\ x_m_1 &amp;amp; ... &amp;amp; ... &amp;amp; x_m_n\\ \end{pmatrix}\begin{pmatrix} \theta_1\\ \theta_2\\ ...\\ \theta_n\\ \end{pmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/b3f9b4ffe1853d6610f9814fc820d1a71825a06e.png" style="height: 87px;" /> <p>I'll just multiply the two matrices in the middle together. The result is an &quot;<strong>X</strong> squared&quot; matrix, which is n-by-n. 
The element in row <em>r</em> and column <em>c</em> of this square matrix is:</p> <img alt="$\sum_{i=1}^{m}x_i_rx_i_c$" class="align-center" src="https://eli.thegreenplace.net/images/math/f8628d68855e03195fb4fd01806c8655beaf7b30.png" style="height: 50px;" /> <p>Note that &quot;<strong>X</strong> squared&quot; is a symmetric matrix (this fact will be important later on). For simplicity of notation, we'll call its elements <img alt="X^{2}_{rc}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/c565201908a5c75f62849e7c1634b65e0930824c.png" style="height: 19px;" />. Multiplying by the <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> vector on the right we get:</p> <img alt="$Q(\theta)=(\theta_1...\theta_n)\begin{pmatrix}X^{2}_{11}\theta_1+...+X^{2}_{1n}\theta_n\\[1em] X^{2}_{21}\theta_1+...+X^{2}_{2n}\theta_n\\ ...\\ X^{2}_{n1}\theta_1+...+X^{2}_{nn}\theta_n\end{pmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/5821cf256f6cf6debbdac48d6e9bbe698baa0a11.png" style="height: 107px;" /> <p>And left-multiplying by <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> to get the fully unwrapped formula for <em>Q</em>:</p> <img alt="$Q(\theta)=\theta_1(X^{2}_{11}\theta_1+...+X^{2}_{1n}\theta_n)+\theta_2(X^{2}_{21}\theta_1+...+X^{2}_{2n}\theta_n)+...+\theta_n(X^{2}_{n1}\theta_1+...+X^{2}_{nn}\theta_n)$" class="align-center" src="https://eli.thegreenplace.net/images/math/0451f9fa7c61ff3a61be8c1836c15667cd916330.png" style="height: 22px;" /> <p>Once again, it's now time to compute the derivatives. 
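Before doing so, both claims about &quot;X squared&quot; - its element formula and its symmetry - are easy to confirm on a small random matrix (a sketch with arbitrary sizes and data):

```python
import numpy as np

rng = np.random.default_rng(2)
m, n = 4, 3
X = rng.standard_normal((m, n))
X2 = X.T @ X   # the n-by-n "X squared" matrix

# Element (r, c) equals sum over i of x_ir * x_ic ...
for r in range(n):
    for c in range(n):
        assert np.isclose(X2[r, c], sum(X[i, r] * X[i, c] for i in range(m)))

# ... and the matrix is symmetric.
assert np.allclose(X2, X2.T)
```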
Let's focus on <img alt="\frac{\partial Q}{\partial \theta_1}" class="valign-m9" src="https://eli.thegreenplace.net/images/math/5161830b1f644a3c2d1a650ccd6405e0fe5940aa.png" style="height: 27px;" />, from which we can infer the rest:</p> <img alt="$\frac{\partial Q}{\partial \theta_1}=(2\theta_1X^{2}_{11}+\theta_2X^{2}_{12}+...+\theta_nX^{2}_{1n})+\theta_2X^{2}_{21}+...+\theta_nX^{2}_{n1}$" class="align-center" src="https://eli.thegreenplace.net/images/math/f99e5e7024b4d13b0a767b98653b6ccc22fa1abd.png" style="height: 41px;" /> <p>Using the fact that <strong>X</strong> squared is symmetric, we know that <img alt="X^{2}_{12}=X^{2}_{21}" class="valign-m6" src="https://eli.thegreenplace.net/images/math/c14595d1000ad9a8da5be7f37da801eadfdfb698.png" style="height: 21px;" /> and so on. Therefore:</p> <img alt="$\frac{\partial Q}{\partial \theta_1}=2\theta_1X^{2}_{11}+2\theta_2X^{2}_{12}+...+2\theta_nX^{2}_{1n}$" class="align-center" src="https://eli.thegreenplace.net/images/math/832b294f472a23e500616db08d9d6832770af6a3.png" style="height: 40px;" /> <p>The partial derivatives by other <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> components are similar. Collecting the sequence of partial derivatives back into a vector equation, we get:</p> <img alt="$\frac{\partial Q}{\partial \theta}=2X^2\theta=2X^TX\theta$" class="align-center" src="https://eli.thegreenplace.net/images/math/541124d49fa78dcf92a15b14643b2ebc4187eaaf.png" style="height: 38px;" /> <p>Now back to <em>J</em>. 
Recall that for convenience we broke <em>J</em> up into three parts: <em>P</em>, <em>Q</em> and <img alt="y^Ty" class="valign-m4" src="https://eli.thegreenplace.net/images/math/81015d6225923cec985bef47ca151ef1cb654c92.png" style="height: 19px;" />; the latter doesn't depend on <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> so it doesn't play a role in the derivative. Collecting our results from this post, we then get:</p> <img alt="$\frac{\partial J}{\partial \theta}=\frac{\partial Q}{\partial \theta}-\frac{\partial P}{\partial \theta}=2X^TX\theta-2X^Ty$" class="align-center" src="https://eli.thegreenplace.net/images/math/9c3d0d108ada3bfc7290c2328c8e6171bc01d7de.png" style="height: 38px;" /> <p>Which is exactly the equation we were expecting to see.</p> <p>To conclude - if matrix calculus feels handwavy, it's because its identities are less familiar. In a sense, it's handwavy in the same way <img alt="\frac{dx^2}{dx}=2x" class="valign-m6" src="https://eli.thegreenplace.net/images/math/5fa725ae5b10a9249e9480d595770cf34accf533.png" style="height: 24px;" /> is handwavy. We remember the identity so we don't have to recalculate it every time from first principles. Once you get some experience with matrix calculus, parts of equations start looking familiar and you no longer need to engage in the long and tiresome computations demonstrated here. It's perfectly fine to just remember that the derivative of <img alt="\theta^TX\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/7616542d90e084c74423b2a9d93b7a3a6cadcd00.png" style="height: 15px;" /> with a symmetric <strong>X</strong> is <img alt="2X\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/7fa6bcc17eae56f6f3f4a6fdcadae3cb3ee2c5d7.png" style="height: 12px;" />. 
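The full gradient can also be checked end to end: compare a finite-difference gradient of J against the closed form, and verify that the gradient vanishes at the normal-equation solution. A NumPy sketch on random made-up data:

```python
import numpy as np

rng = np.random.default_rng(3)
m, n = 6, 3
X = rng.standard_normal((m, n))
y = rng.standard_normal(m)

# J(theta) = theta^T X^T X theta - 2 (X theta)^T y + y^T y
J = lambda th: th @ X.T @ X @ th - 2 * (X @ th) @ y + y @ y

th0 = rng.standard_normal(n)
eps = 1e-6
I = np.eye(n)
num_grad = np.array([(J(th0 + eps * I[i]) - J(th0 - eps * I[i])) / (2 * eps)
                     for i in range(n)])
analytic = 2 * X.T @ X @ th0 - 2 * X.T @ y
assert np.allclose(num_grad, analytic, atol=1e-4)

# At the normal-equation solution theta* = (X^T X)^{-1} X^T y,
# the gradient is zero, as expected at the minimum.
th_star = np.linalg.solve(X.T @ X, X.T @ y)
assert np.allclose(2 * X.T @ X @ th_star - 2 * X.T @ y, 0, atol=1e-8)
```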
See the &quot;identities&quot; section of the <a class="reference external" href="http://en.wikipedia.org/wiki/Matrix_calculus">wikipedia article on matrix calculus</a> for many more examples.</p> <hr class="docutils" /> <table class="docutils footnote" frame="void" id="id4" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id1"></a></td><td>A few words on notation: by default, a vector <strong>v</strong> is a <em>column</em> vector. To get its row version, we transpose it. Moreover, in the vector derivative equations that follow I'm using <a class="reference external" href="http://en.wikipedia.org/wiki/Matrix_calculus#Layout_conventions">denominator layout notation</a>. This is not super-important though; as the Wikipedia article suggests, many mathematical papers and writings aren't consistent about this and it's perfectly possible to understand the derivations regardless.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id5" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id2"></a></td><td>Yes, this is exactly like computing a gradient of a multivariate function.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id6" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id3"></a></td><td>Take a minute to convince yourself that the dimensions of this equation work out and the result is a scalar.</td></tr> </tbody> </table> Visualizing matrix multiplication as a linear combination2015-03-19T06:03:00-07:002015-03-19T06:03:00-07:00Eli Benderskytag:eli.thegreenplace.net,2015-03-19:/2015/visualizing-matrix-multiplication-as-a-linear-combination/<p>When multiplying two matrices, there's a manual procedure we all know how to go through. 
Each result cell is computed separately as the dot-product of a row in the first matrix with a column in the second matrix. While it's the easiest way to compute the result manually, it may obscure a very interesting property of the operation: <em>multiplying A by B is the linear combination of A's columns using coefficients from B</em>. Another way to look at it is that it's a <em>linear combination of the rows of B using coefficients from A</em>.</p> <p>In this quick post I want to show a colorful visualization that will make this easier to grasp.</p> <div class="section" id="right-multiplication-combination-of-columns"> <h2>Right-multiplication: combination of columns</h2> <p>Let's begin by looking at the right-multiplication of matrix <tt class="docutils literal">X</tt> by a column vector:</p> <img alt="$\begin{pmatrix} x_1 &amp;amp; y_1 &amp;amp; z_1 \\ x_2 &amp;amp; y_2 &amp;amp; z_2 \\ x_3 &amp;amp; y_3 &amp;amp; z_3 \\ \end{pmatrix}* \begin{pmatrix} a \\ b \\ c \\ \end{pmatrix}= \begin{pmatrix} ax_1+by_1+cz_1 \\ ax_2+by_2+cz_2 \\ ax_3+by_3+cz_3 \\ \end{pmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/ba570f74b122c0a20e4488a052b25fbda160c138.png" style="height: 65px;" /> <p>Representing the columns of <tt class="docutils literal">X</tt> by colorful boxes will help visualize this:</p> <img alt="Matrix by vector" class="align-center" src="https://eli.thegreenplace.net/images/2015/veccomb.png" /> <p>Sticking the white box with <tt class="docutils literal">a</tt> in it to a vector just means: multiply this vector by the scalar <tt class="docutils literal">a</tt>. 
The result is another column vector - a linear combination of <tt class="docutils literal">X</tt>'s columns, with <tt class="docutils literal">a</tt>, <tt class="docutils literal">b</tt>, <tt class="docutils literal">c</tt> as the coefficients.</p> <p>Right-multiplying <tt class="docutils literal">X</tt> by a matrix is more of the same. Each resulting column is a different linear combination of <tt class="docutils literal">X</tt>'s columns:</p> <img alt="$\begin{pmatrix} {\color{Red} x_1} &amp;amp; y_1 &amp;amp; z_1 \\ x_2 &amp;amp; y_2 &amp;amp; z_2 \\ x_3 &amp;amp; y_3 &amp;amp; z_3 \\ \end{pmatrix}* \begin{pmatrix} a &amp;amp; d &amp;amp; g \\ b &amp;amp; e &amp;amp; h \\ c &amp;amp; f &amp;amp; i \\ \end{pmatrix}= \begin{pmatrix} ax_1+by_1+cz_1 &amp;amp; dx_1+ey_1+fz_1 &amp;amp; gx_1+hy_1+iz_1 \\ ax_2+by_2+cz_2 &amp;amp; dx_2+ey_2+fz_2 &amp;amp; gx_2+hy_2+iz_2 \\ ax_3+by_3+cz_3 &amp;amp; dx_3+ey_3+fz_3 &amp;amp; gx_3+hy_3+iz_3 \\ \end{pmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/d6065791babbc5b967c06b57322711424097c83c.png" style="height: 65px;" /> <p>Graphically:</p> <img alt="Matrix by matrix" class="align-center" src="https://eli.thegreenplace.net/images/2015/matcomb.png" /> <p>If you look hard at the equation above and squint a bit, you can recognize this column-combination property by examining each column of the result matrix.</p> </div> <div class="section" id="left-multiplication-combination-of-rows"> <h2>Left-multiplication: combination of rows</h2> <p>Now let's examine left-multiplication. 
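Both combination views are easy to confirm in NumPy on small made-up matrices - columns for right-multiplication, rows for left-multiplication:

```python
import numpy as np

rng = np.random.default_rng(4)
X = rng.integers(-5, 6, size=(3, 3)).astype(float)
B = rng.integers(-5, 6, size=(3, 3)).astype(float)

# Right-multiplication: column j of X @ B is a linear combination of
# X's columns, with coefficients taken from column j of B.
right = X @ B
for j in range(3):
    combo = sum(B[k, j] * X[:, k] for k in range(3))
    assert np.allclose(right[:, j], combo)

# Left-multiplication: row i of B @ X is a linear combination of
# X's rows, with coefficients taken from row i of B.
left = B @ X
for i in range(3):
    combo = sum(B[i, k] * X[k, :] for k in range(3))
    assert np.allclose(left[i, :], combo)
```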
Left-multiplying a matrix <tt class="docutils literal">X</tt> by a row vector is a linear combination of <tt class="docutils literal">X</tt>'s <em>rows</em>:</p> <img alt="$\begin{pmatrix} a &amp;amp; b &amp;amp; c \end{pmatrix}* \begin{pmatrix} x_1 &amp;amp; y_1 &amp;amp; z_1 \\ x_2 &amp;amp; y_2 &amp;amp; z_2 \\ x_3 &amp;amp; y_3 &amp;amp; z_3 \\ \end{pmatrix}= \begin{pmatrix} ax_1+bx_2+cx_3 &amp;amp; ay_1+by_2+cy_3 &amp;amp; az_1+bz_2+cz_3 \end{pmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/a36b019264582a035d8df4dc13158854f3477efe.png" style="height: 65px;" /> <p>Is represented graphically thus:</p> <img alt="Vector by matrix" class="align-center" src="https://eli.thegreenplace.net/images/2015/vecrowcomb.png" /> <p>And left-multiplying by a matrix is the same thing repeated for every result row: it becomes the linear combination of the rows of <tt class="docutils literal">X</tt>, with the coefficients taken from the rows of the matrix on the left. Here's the equation form:</p> <img alt="$\begin{pmatrix} a &amp;amp; b &amp;amp; c \\ d &amp;amp; e &amp;amp; f \\ g &amp;amp; h &amp;amp; i \\ \end{pmatrix}* \begin{pmatrix} x_1 &amp;amp; y_1 &amp;amp; z_1 \\ x_2 &amp;amp; y_2 &amp;amp; z_2 \\ x_3 &amp;amp; y_3 &amp;amp; z_3 \\ \end{pmatrix}= \begin{pmatrix} ax_1+bx_2+cx_3 &amp;amp; ay_1+by_2+cy_3 &amp;amp; az_1+bz_2+cz_3 \\ dx_1+ex_2+fx_3 &amp;amp; dy_1+ey_2+fy_3 &amp;amp; dz_1+ez_2+fz_3 \\ gx_1+hx_2+ix_3 &amp;amp; gy_1+hy_2+iy_3 &amp;amp; gz_1+hz_2+iz_3 \\ \end{pmatrix}$" class="align-center" src="https://eli.thegreenplace.net/images/math/35d9e54624bf17576372da3bf144dd4659b225e1.png" style="height: 65px;" /> <p>And the graphical form:</p> <img alt="Matrix by matrix from the left" class="align-center" src="https://eli.thegreenplace.net/images/2015/matrowcomb.png" /> </div> Meshgrids and disambiguating rows and columns from Cartesian coordinates2014-12-28T07:23:00-08:002014-12-28T07:23:00-08:00Eli 
Benderskytag:eli.thegreenplace.net,2014-12-28:/2014/meshgrids-and-disambiguating-rows-and-columns-from-cartesian-coordinates/<p>When plotting 3D graphs, a common source of confusion in Numpy and Matplotlib (and, by extension, I'd assume in Matlab as well) is how to reconcile between matrices that are indexed with rows and columns, and Cartesian coordinates.</p> <p>Let's use the function <img alt="z = f(x,y) = 4x^2+y^2" class="valign-m4" src="https://eli.thegreenplace.net/images/math/2b7a30bd2a6116249aaa03a1546780b31973ac68.png" style="height: 19px;" /> as an example. 
Here's its 3D plot, courtesy <a class="reference external" href="https://www.google.com/search?client=ubuntu&amp;channel=fs&amp;q=plot+2x^2+%2B+y^2&amp;ie=utf-8">of Google</a>:</p> <img alt="3D plot" class="align-center" src="https://eli.thegreenplace.net/images/2014/funcplot-google.png" /> <p>Now let's use Numpy and Matplotlib to make a contour plot of this function.</p> <div class="highlight"><pre><span></span><span class="n">xx</span> <span class="o">=</span> <span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mi">20</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">100</span><span class="p">)</span> <span class="n">yy</span> <span class="o">=</span> <span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mi">20</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">100</span><span class="p">)</span> <span class="n">Z</span> <span class="o">=</span> <span class="n">zeros</span><span class="p">((</span><span class="nb">len</span><span class="p">(</span><span class="n">xx</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="n">yy</span><span class="p">)))</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">xx</span><span class="p">)):</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">yy</span><span class="p">)):</span> <span class="n">Z</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="mi">4</span><span class="o">*</span><span 
class="n">xx</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">**</span><span class="mi">2</span> <span class="o">+</span> <span class="n">yy</span><span class="p">[</span><span class="n">j</span><span class="p">]</span><span class="o">**</span><span class="mi">2</span> </pre></div> <p>If the creation of <tt class="docutils literal">Z</tt> in the above looks fishy to you, you're right, and we'll get to it shortly. However, note that this is a vastly simplified demonstration - often <tt class="docutils literal">Z</tt> is created behind the scenes by a more complex computation.</p> <p>Finally, plotting the contour:</p> <div class="highlight"><pre><span></span><span class="n">contour</span><span class="p">(</span><span class="n">xx</span><span class="p">,</span> <span class="n">yy</span><span class="p">,</span> <span class="n">Z</span><span class="p">)</span> <span class="n">xlabel</span><span class="p">(</span><span class="s1">&#39;x&#39;</span><span class="p">);</span> <span class="n">ylabel</span><span class="p">(</span><span class="s1">&#39;y&#39;</span><span class="p">)</span> </pre></div> <p>We get:</p> <img alt="Contour plot" class="align-center" src="https://eli.thegreenplace.net/images/2014/contour-rowcol.png" /> <p>This plot doesn't look right. In the function we're plotting, the contour lines should be stretched in the <tt class="docutils literal">y</tt> direction, not the <tt class="docutils literal">x</tt> direction (this is obvious both from the formula for <tt class="docutils literal">z</tt> and from the 3D plot shown above). What's going on?</p> <p>This is a simple demonstration of a very common problem many people run into when plotting a matrix as a 3D scalar field (a scalar value for each <tt class="docutils literal">x, y</tt> coordinate). 
While we're used to thinking about <tt class="docutils literal">x</tt> as the &quot;first&quot; coordinate and <tt class="docutils literal">y</tt> as the &quot;second&quot;, in the way Numpy represents matrices this is exactly the opposite. Here's a simple matrix:</p> <div class="highlight"><pre><span></span>array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) </pre></div> <p>Imagine we'd like to plot it. Indexing into a matrix goes by <tt class="docutils literal">[row, col]</tt>, where <tt class="docutils literal">row</tt> counts from top-to-bottom, and <tt class="docutils literal">col</tt> counts from left-to-right. Now, if you just look at the matrix and visually interpose the Cartesian coordinate system on top, top-to-bottom is <tt class="docutils literal">y</tt> and left-to-right is <tt class="docutils literal">x</tt>. In other words, the indexing order is reversed.</p> <p>Here's a visualization that should make it clear:</p> <img alt="XY vs. row, col" class="align-center" src="https://eli.thegreenplace.net/images/2014/xy-rowcol.png" /> <p>There's a very simple solution to this problem - use a transpose. Plotting:</p> <div class="highlight"><pre><span></span><span class="n">contour</span><span class="p">(</span><span class="n">xx</span><span class="p">,</span> <span class="n">yy</span><span class="p">,</span> <span class="n">Z</span><span class="o">.</span><span class="n">T</span><span class="p">)</span> <span class="n">xlabel</span><span class="p">(</span><span class="s1">&#39;x&#39;</span><span class="p">);</span> <span class="n">ylabel</span><span class="p">(</span><span class="s1">&#39;y&#39;</span><span class="p">)</span> </pre></div> <p>Gives us the expected:</p> <img alt="Contour plot" class="align-center" src="https://eli.thegreenplace.net/images/2014/contour-xy.png" /> <p>A matrix transpose exchanges between rows and columns. 
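The flip is easy to see concretely on the 3-by-3 array above: the entry at Cartesian position x=2, y=0 (top row, third from the left) is indexed with y first and x second, and transposing swaps the two roles:

```python
import numpy as np

arr = np.array([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])

# "x" runs left-to-right (columns) and "y" top-to-bottom (rows), so the
# Cartesian point (x=2, y=0) lives at arr[row=0, col=2].
assert arr[0, 2] == 3

# A transpose swaps the indices, so arr.T can be indexed as [x, y]:
assert arr.T[2, 0] == arr[0, 2]
```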
It makes the original matrix's rows count from left-to-right and columns from top-to-bottom, matching the Cartesian coordinate system.</p> <p>Is a transpose always required, then? Not at all. As I mentioned above, the computation of <tt class="docutils literal">Z</tt> wasn't entirely correct, because matrix indices were conflated with Cartesian coordinates. In the double loop shown above it would be more correct to assign <tt class="docutils literal">Z[j, i]</tt>, and in general it's usually recommended to be explicit about <tt class="docutils literal">row, col</tt> or <tt class="docutils literal">x, y</tt> - as the <tt class="docutils literal">i, j</tt> pair is somewhat ambiguous. That said, we don't always easily control the creation of <tt class="docutils literal">Z</tt>, so the transpose is occasionally useful when the data we got is in the wrong order.</p> <div class="section" id="meshgrids"> <h2>Meshgrids</h2> <p>IMHO, by trying to be helpful, the <tt class="docutils literal">contour</tt> API helps spread the confusion. It does so by not enforcing <tt class="docutils literal">x</tt> and <tt class="docutils literal">y</tt> to be 2D data arrays, like all the 3D plotting routines do. It's better to be explicit and require a meshgrid.</p> <p>So what is a meshgrid? <tt class="docutils literal">meshgrid</tt> is a Numpy function that turns vectors such as <tt class="docutils literal">xx</tt> and <tt class="docutils literal">yy</tt> above into coordinate matrices. The idea is simple: when doing multi-dimensional operations (like 3D plotting), it's better to be very explicit about what maps to what. What we really want is three matrices: <tt class="docutils literal">X</tt>, <tt class="docutils literal">Y</tt> and <tt class="docutils literal">Z</tt>, where <tt class="docutils literal">Z[i, j]</tt> is the value of the function in question for <tt class="docutils literal">X[i, j]</tt> and <tt class="docutils literal">Y[i, j]</tt>. 
But more often than not, we don't have <tt class="docutils literal">X</tt> and <tt class="docutils literal">Y</tt> in this form. Rather, we just have vectors with their values running along the axes. This is what <tt class="docutils literal">meshgrid</tt> is for. Here's a simple demonstration (taken from an IPython terminal):</p> <div class="highlight"><pre><span></span><span class="n">In</span> <span class="p">[</span><span class="mi">218</span><span class="p">]:</span> <span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span> <span class="n">In</span> <span class="p">[</span><span class="mi">219</span><span class="p">]:</span> <span class="n">y</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">])</span> <span class="n">In</span> <span class="p">[</span><span class="mi">220</span><span class="p">]:</span> <span class="n">X</span><span class="p">,</span> <span class="n">Y</span> <span class="o">=</span> <span class="n">meshgrid</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span> <span class="n">In</span> <span class="p">[</span><span class="mi">221</span><span class="p">]:</span> <span class="n">X</span> <span class="n">Out</span><span class="p">[</span><span class="mi">221</span><span class="p">]:</span> <span class="n">array</span><span class="p">([[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">],</span> <span 
class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">],</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">]])</span> <span class="n">In</span> <span class="p">[</span><span class="mi">222</span><span class="p">]:</span> <span class="n">Y</span> <span class="n">Out</span><span class="p">[</span><span class="mi">222</span><span class="p">]:</span> <span class="n">array</span><span class="p">([[</span><span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">],</span> <span class="p">[</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">],</span> <span class="p">[</span><span class="mi">6</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">6</span><span class="p">]])</span> </pre></div> <p>The <tt class="docutils literal">X</tt> and <tt class="docutils literal">Y</tt> matrices may appear strange at first sight, but looking more closely reveals that they're exactly the coordinate matrices we need; in tandem, they run over all the 9 pairs needed to map from the original <tt class="docutils literal">x</tt> and <tt class="docutils literal">y</tt> vectors. The values in <tt class="docutils literal">X</tt> increase from left to right; the values in <tt class="docutils literal">Y</tt> increase from top to bottom - the way it should be.</p> <p>And the best part about <tt class="docutils literal">meshgrid</tt> is that it enables vectorized computations, just the way we like them in Numpy. 
So the function we originally created can now be computed and plotted correctly without any loops:</p> <div class="highlight"><pre><span></span><span class="n">xx</span> <span class="o">=</span> <span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mi">20</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">100</span><span class="p">)</span> <span class="n">yy</span> <span class="o">=</span> <span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mi">20</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">100</span><span class="p">)</span> <span class="n">X</span><span class="p">,</span> <span class="n">Y</span> <span class="o">=</span> <span class="n">meshgrid</span><span class="p">(</span><span class="n">xx</span><span class="p">,</span> <span class="n">yy</span><span class="p">)</span> <span class="n">Z</span> <span class="o">=</span> <span class="mi">4</span><span class="o">*</span><span class="n">X</span><span class="o">**</span><span class="mi">2</span> <span class="o">+</span> <span class="n">Y</span><span class="o">**</span><span class="mi">2</span> <span class="n">contour</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">Z</span><span class="p">)</span> </pre></div> <p>Produces the correct plot.</p> <p>Finally, what if we do get <tt class="docutils literal">Z</tt> from somewhere else, computed using matrix indexing rather than Cartesian indexing? Plotting its transpose is one alternative, but there's a better one.
We can create a meshgrid, using its <tt class="docutils literal">indexing</tt> keyword argument, like this:</p> <div class="highlight"><pre><span></span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span> <span class="o">=</span> <span class="n">meshgrid</span><span class="p">(</span><span class="n">xx</span><span class="p">,</span> <span class="n">yy</span><span class="p">,</span> <span class="n">indexing</span><span class="o">=</span><span class="s1">&#39;ij&#39;</span><span class="p">)</span> </pre></div> <p>This tells <tt class="docutils literal">meshgrid</tt> that we're going to plot a function computed using <tt class="docutils literal">row, col</tt>, rather than <tt class="docutils literal">x, y</tt> order, and it will flip the rows and columns accordingly.</p> </div> Derivation of the Normal Equation for linear regression2014-12-22T20:50:00-08:002014-12-22T20:50:00-08:00Eli Benderskytag:eli.thegreenplace.net,2014-12-22:/2014/derivation-of-the-normal-equation-for-linear-regression/<p>I was going through the Coursera &quot;Machine Learning&quot; course, and in the section on multivariate linear regression something caught my eye. Andrew Ng presented the <a class="reference external" href="http://en.wikipedia.org/w/index.php?title=Normal_equation&amp;redirect=no">Normal Equation</a> as an analytical solution to the linear regression problem with a least-squares cost function. He mentioned that in some cases (such as for …</p><p>I was going through the Coursera &quot;Machine Learning&quot; course, and in the section on multivariate linear regression something caught my eye. Andrew Ng presented the <a class="reference external" href="http://en.wikipedia.org/w/index.php?title=Normal_equation&amp;redirect=no">Normal Equation</a> as an analytical solution to the linear regression problem with a least-squares cost function. 
He mentioned that in some cases (such as for small feature sets) using it is more effective than applying gradient descent; unfortunately, he left its derivation out.</p> <p>Here I want to show how the normal equation is derived.</p> <p>First, some terminology. The following symbols are compatible with the machine learning course, not with the exposition of the normal equation on Wikipedia and other sites - semantically it's all the same, just the symbols are different.</p> <p>Given the hypothesis function:</p> <img alt="$h_{\theta}(x)=\theta_0x_0+\theta_1x_1+\cdots+\theta_nx_n$" class="align-center" src="https://eli.thegreenplace.net/images/math/dd8fad9bf111e83d47252d51dd037a6c6c3136aa.png" style="height: 18px;" /> <p>We'd like to minimize the least-squares cost:</p> <img alt="$J(\theta_{0...n})=\frac{1}{2m}\sum_{i=1}^{m}(h_{\theta}(x^{(i)})-y^{(i)})^2$" class="align-center" src="https://eli.thegreenplace.net/images/math/c1abe0768f4deb31ed97f37d760236c94439a780.png" style="height: 50px;" /> <p>Where <img alt="x^{(i)}" class="valign-0" src="https://eli.thegreenplace.net/images/math/233014006c0adbee71ec71ba3a70f22ad1b906a1.png" style="height: 17px;" /> is the <tt class="docutils literal">i</tt>-th sample (from a set of <tt class="docutils literal">m</tt> samples) and <img alt="y^{(i)}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/d34414117d493106f731939df6bb7f1762365d3f.png" style="height: 21px;" /> is the <tt class="docutils literal">i</tt>-th expected result.</p> <p>To proceed, we'll represent the problem in matrix notation; this is natural, since we essentially have a system of linear equations here. 
The regression coefficients <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> we're looking for are the vector:</p> <img alt="$\begin{pmatrix} \theta_0\\ \theta_1\\ ...\\ \theta_n \end{pmatrix}\in\mathbb{R}^{n+1}$" class="align-center" src="https://eli.thegreenplace.net/images/math/b16fd3d2b3041f13cb70199837a7c02c756078c7.png" style="height: 86px;" /> <p>Each of the <tt class="docutils literal">m</tt> input samples is similarly a column vector with <tt class="docutils literal">n+1</tt> rows, <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" /> being 1 for convenience. So we can now rewrite the hypothesis function as:</p> <img alt="$h_{\theta}(x)=\theta^Tx$" class="align-center" src="https://eli.thegreenplace.net/images/math/be661047c89f6a48c7bc563b81207949c251de6a.png" style="height: 21px;" /> <p>When this is summed over all samples, we can dip further into matrix notation. We'll define the &quot;design matrix&quot; <tt class="docutils literal">X</tt> (uppercase X) as a matrix of <tt class="docutils literal">m</tt> rows, in which each row is the <tt class="docutils literal">i</tt>-th sample (the vector <img alt="x^{(i)}" class="valign-0" src="https://eli.thegreenplace.net/images/math/233014006c0adbee71ec71ba3a70f22ad1b906a1.png" style="height: 17px;" />). With this, we can rewrite the least-squares cost as follows, replacing the explicit sum by matrix multiplication:</p> <img alt="$J(\theta)=\frac{1}{2m}(X\theta-y)^T(X\theta-y)$" class="align-center" src="https://eli.thegreenplace.net/images/math/db5e3da78e25c18c8fc88f1291c1ac13a2645388.png" style="height: 36px;" /> <p>Now, using some matrix transpose identities, we can simplify this a bit.
I'll throw the <img alt="\frac{1}{2m}" class="valign-m6" src="https://eli.thegreenplace.net/images/math/7a2a3f6dba54b64f0e88e18c40e0f68c523713ea.png" style="height: 22px;" /> part away since we're going to compare a derivative to zero anyway:</p> <img alt="$J(\theta)=((X\theta)^T-y^T)(X\theta-y)$" class="align-center" src="https://eli.thegreenplace.net/images/math/c1368de1a0634c3fbeb92d67f368f253943d089f.png" style="height: 21px;" /> <img alt="$J(\theta)=(X\theta)^TX\theta-(X\theta)^Ty-y^T(X\theta)+y^Ty$" class="align-center" src="https://eli.thegreenplace.net/images/math/e41fc822adccf1f865b02100f5671e265e7b30bc.png" style="height: 21px;" /> <p>Note that <img alt="X\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/52f2de6065bdc187b876c5696041f3c716c446f5.png" style="height: 12px;" /> is a vector, and so is <tt class="docutils literal">y</tt>; multiplying one by the other yields a scalar, and a scalar is equal to its own transpose. Therefore the order of the multiplication doesn't matter (as long as the dimensions work out), and we can further simplify:</p> <img alt="$J(\theta)=\theta^TX^TX\theta-2(X\theta)^Ty+y^Ty$" class="align-center" src="https://eli.thegreenplace.net/images/math/2864b88546c007a79dc92271f5e01487ba608e43.png" style="height: 21px;" /> <p>Recall that here <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> is our unknown. To find where the above function has a minimum, we will differentiate with respect to <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> and compare the result to 0. Differentiating with respect to a vector may feel uncomfortable, but there's nothing to worry about. Recall that here we only use matrix notation to conveniently represent a system of linear formulae. So we differentiate with respect to each component of the vector, and then combine the resulting derivatives into a vector again.
The result is:</p> <img alt="$\frac{\partial J}{\partial \theta}=2X^TX\theta-2X^{T}y=0$" class="align-center" src="https://eli.thegreenplace.net/images/math/9b142c00e031c9db7f575b0542e86261732a4689.png" style="height: 38px;" /> <p>Or:</p> <img alt="$X^TX\theta=X^{T}y$" class="align-center" src="https://eli.thegreenplace.net/images/math/ab453f9f1f7bd4b1d646b9712fbe0b2fbe01740f.png" style="height: 21px;" /> <p>Now, assuming that the matrix <img alt="X^TX" class="valign-0" src="https://eli.thegreenplace.net/images/math/5c817c84ec1f83b23494df6125edd091a7c413dd.png" style="height: 15px;" /> is invertible, we can multiply both sides by <img alt="(X^TX)^{-1}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/57f592cee6ceac659262d97e61c64f9ca405d7f1.png" style="height: 19px;" /> and get:</p> <img alt="$\theta=(X^TX)^{-1}X^Ty$" class="align-center" src="https://eli.thegreenplace.net/images/math/20baabd9d33dcd26003bc44c7d81ba39e1ad4caa.png" style="height: 21px;" /> <p>Which is the normal equation.</p> <p>[<strong>Update 27-May-2015</strong>: I've written <a class="reference external" href="http://eli.thegreenplace.net/2015/the-normal-equation-and-matrix-calculus/">another post</a> that explains in more detail how these derivatives are computed.]</p> Horner's rule: efficient evaluation of polynomials2010-03-30T15:10:32-07:002010-03-30T15:10:32-07:00Eli Benderskytag:eli.thegreenplace.net,2010-03-30:/2010/03/30/horners-rule-efficient-evaluation-of-polynomials <p>Here's a general degree-n polynomial:</p> <p><img src="https://eli.thegreenplace.net/images/math/79d0e193d7bd5ba889f5992beece98ca4ce715f8.gif" /></p> <p>To evaluate such a polynomial using a computer program, several approaches can be employed.</p> <p>The simplest, naive method is to compute each term of the polynomial separately and then add them up. 
Here's the Python code for it:</p> <div class="highlight"><pre><span style="color: #00007f; font-weight: bold">def</span> <span style="color: #00007f">poly_naive</span>(A, x): p = <span style="color: #007f7f">0</span> <span style="color: #00007f; font-weight: bold">for …</span></pre></div> <p>Here's a general degree-n polynomial:</p> <p><img src="https://eli.thegreenplace.net/images/math/79d0e193d7bd5ba889f5992beece98ca4ce715f8.gif" /></p> <p>To evaluate such a polynomial using a computer program, several approaches can be employed.</p> <p>The simplest, naive method is to compute each term of the polynomial separately and then add them up. Here's the Python code for it:</p> <div class="highlight"><pre><span style="color: #00007f; font-weight: bold">def</span> <span style="color: #00007f">poly_naive</span>(A, x): p = <span style="color: #007f7f">0</span> <span style="color: #00007f; font-weight: bold">for</span> i, a <span style="color: #0000aa">in</span> <span style="color: #00007f">enumerate</span>(A): p += (x ** i) * a <span style="color: #00007f; font-weight: bold">return</span> p </pre></div> <p><tt class="docutils literal"><span class="pre">A</span></tt> is an array of coefficients, lowest first, <img src="https://eli.thegreenplace.net/images/math/4a5997da73aadd118038761e69d01e24586bf958.gif" /> until <img src="https://eli.thegreenplace.net/images/math/278ab95d3a54aae8eaa25c34af66d93a19b5e75f.gif" />.</p> <p>This method is quite inefficient. 
It requires <tt class="docutils literal"><span class="pre">n</span></tt> additions (since there are <tt class="docutils literal"><span class="pre">n+1</span></tt> terms to be added) and <img src="https://eli.thegreenplace.net/images/math/73b6f7da8c4582390c7323a29770ab2e8cb7fb64.gif" /> multiplications.</p> <div class="section" id="iterative-method"> <h3>Iterative method</h3> <p>It's obvious that a lot of repetitive computation is being done by raising <tt class="docutils literal"><span class="pre">x</span></tt> to successive powers. We can make things much more efficient by simply keeping the previous power of <tt class="docutils literal"><span class="pre">x</span></tt> between iterations. This is the &quot;iterative method&quot;:</p> <div class="highlight"><pre><span style="color: #00007f; font-weight: bold">def</span> <span style="color: #00007f">poly_iter</span>(A, x): p = <span style="color: #007f7f">0</span> xn = <span style="color: #007f7f">1</span> <span style="color: #00007f; font-weight: bold">for</span> a <span style="color: #0000aa">in</span> A: p += xn * a xn *= x <span style="color: #00007f; font-weight: bold">return</span> p </pre></div> <p>In this code <tt class="docutils literal"><span class="pre">xn</span></tt> is the current power of <tt class="docutils literal"><span class="pre">x</span></tt>. We don't need to raise <tt class="docutils literal"><span class="pre">x</span></tt> to a power on each iteration of the loop; a single multiplication suffices. It's easy to see that there are <tt class="docutils literal"><span class="pre">2n</span></tt> multiplications and <tt class="docutils literal"><span class="pre">n</span></tt> additions for each computation. The algorithm is now linear instead of quadratic.</p> </div> <div class="section" id="horner-s-rule"> <h3>Horner's rule</h3> <p>It can be further improved, however.
Take a look at this polynomial:</p> <p><img src="https://eli.thegreenplace.net/images/math/03e98fbb410ca88f96c6124bd2fa98a88ed56d25.gif" /></p> <p>It can be rewritten as follows:</p> <p><img src="https://eli.thegreenplace.net/images/math/9a469b0cc8b4304d230b677c9f5c26129d1b73fe.gif" /></p> <p>And in general, we can always rewrite the polynomial:</p> <p><img src="https://eli.thegreenplace.net/images/math/3773c72b0bca68c7d911452088f2b9f459802b78.gif" /></p> <p>As:</p> <p><img src="https://eli.thegreenplace.net/images/math/2c6d90599184f76993d6474a226b8c03e8e7c475.gif" /></p> <p>This rearrangement is usually called &quot;Horner's rule&quot;. We can write the code to implement it as follows:</p> <div class="highlight"><pre><span style="color: #00007f; font-weight: bold">def</span> <span style="color: #00007f">poly_horner</span>(A, x): p = A[-<span style="color: #007f7f">1</span>] i = <span style="color: #00007f">len</span>(A) - <span style="color: #007f7f">2</span> <span style="color: #00007f; font-weight: bold">while</span> i &gt;= <span style="color: #007f7f">0</span>: p = p * x + A[i] i -= <span style="color: #007f7f">1</span> <span style="color: #00007f; font-weight: bold">return</span> p </pre></div> <p>Here we start by assigning <img src="https://eli.thegreenplace.net/images/math/278ab95d3a54aae8eaa25c34af66d93a19b5e75f.gif" /> to <cite>p</cite> and then successively multiplying by <cite>x</cite> and adding the next coefficient. 
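</p> <p>As a quick sanity check of <tt class="docutils literal">poly_horner</tt> (a self-contained sketch; the function is repeated verbatim and the expected value is worked out by hand):</p>

```python
def poly_horner(A, x):
    # Coefficients in A are ordered lowest power first, as in the other examples.
    p = A[-1]
    i = len(A) - 2
    while i >= 0:
        p = p * x + A[i]
        i -= 1
    return p

# 1 + 2x + 3x^2 at x=2: 1 + 4 + 12 = 17
assert poly_horner([1, 2, 3], 2) == 17

# Cross-check against direct term-by-term evaluation.
A = [5, -1, 0, 2, 7]
assert poly_horner(A, 3) == sum(a * 3**i for i, a in enumerate(A))
```

<p>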
This code requires <cite>n</cite> multiplications and <cite>n</cite> additions (I'm ignoring here the modification of the loop variable <tt class="docutils literal"><span class="pre">i</span></tt>, as I ignored it in all other algorithms, where it was implicit in the Python <tt class="docutils literal"><span class="pre">for</span></tt> loop).</p> <p>While asymptotically similar to the iterative method, Horner's method has better constants and thus is faster.</p> <p>Curiously, Horner's rule was discovered in the early 19th century, long before the advent of computers. It's obviously useful for manual computation of polynomials as well, for the same reason: it requires fewer operations.</p> <p>I've timed the 3 algorithms on a random polynomial of degree 500. The one using Horner's rule is about 5 times faster than the naive approach, and 15% faster than the iterative method.</p> </div> A group-theoretic proof of Euler's theorem2009-08-01T08:00:43-07:002009-08-01T08:00:43-07:00Eli Benderskytag:eli.thegreenplace.net,2009-08-01:/2009/08/01/a-group-theoretic-proof-of-eulers-theorem <p>A very important and useful theorem in number theory is named after Leonhard Euler:</p> <p><img src="https://eli.thegreenplace.net/images/math/b9dd84aaa5b3a778d39ea7b95f32fdeed4510389.gif" /></p> <p>Where <img src="https://eli.thegreenplace.net/images/math/20bdd8a8b971fd8582ce58915d9f42ff001daef3.gif" /> is <a class="reference external" href="http://en.wikipedia.org/wiki/Euler%27s_totient_function">Euler's totient</a> function - the count of numbers smaller than <tt class="docutils literal"><span class="pre">n</span></tt> that are coprime to it.</p> <p>Here I want to present a nice proof of this theorem, based on group theory.
I begin with some …</p> <p>A very important and useful theorem in number theory is named after Leonhard Euler:</p> <p><img src="https://eli.thegreenplace.net/images/math/b9dd84aaa5b3a778d39ea7b95f32fdeed4510389.gif" /></p> <p>Where <img src="https://eli.thegreenplace.net/images/math/20bdd8a8b971fd8582ce58915d9f42ff001daef3.gif" /> is <a class="reference external" href="http://en.wikipedia.org/wiki/Euler%27s_totient_function">Euler's totient</a> function - the count of numbers smaller than <tt class="docutils literal"><span class="pre">n</span></tt> that are coprime to it.</p> <p>Here I want to present a nice proof of this theorem, based on group theory. I begin with some preliminary definitions and gradually move towards the final goal.</p> <p><strong>(I) Congruence class</strong>: Let <tt class="docutils literal"><span class="pre">a</span></tt> and <tt class="docutils literal"><span class="pre">n</span> <span class="pre">&gt;</span> <span class="pre">0</span></tt> be integers. The set of all integers that have the same remainder as <tt class="docutils literal"><span class="pre">a</span></tt> when divided by <tt class="docutils literal"><span class="pre">n</span></tt> is called the congruence class of <tt class="pre">a</tt> modulo <tt class="docutils literal"><span class="pre">n</span></tt> and is denoted by <img src="https://eli.thegreenplace.net/images/math/8755e0f4dc820d67ad665a79d34fd0bbb6fd9b1b.gif" />, where:</p> <p><img src="https://eli.thegreenplace.net/images/math/41e7da26fc47068db4c05b2400799401d07f648e.gif" /></p> <p><img src="https://eli.thegreenplace.net/images/math/e8969fcfdd076fdce4b4ca0244b6a6b05964a817.gif" /> is the set of all congruence classes modulo <tt class="docutils literal"><span class="pre">n</span></tt>.</p> <p><strong>(II) Units of</strong> <img src="https://eli.thegreenplace.net/images/math/e8969fcfdd076fdce4b4ca0244b6a6b05964a817.gif" />: If for <img 
src="https://eli.thegreenplace.net/images/math/8755e0f4dc820d67ad665a79d34fd0bbb6fd9b1b.gif" /> we find some <img src="https://eli.thegreenplace.net/images/math/2e7ad37347086c8f38ec146b7ab4eb6ccf519672.gif" /> such that <img src="https://eli.thegreenplace.net/images/math/44b1c7f2b4cbe6e3cb87d55030fc40b1e517c4c0.gif" />, we call <img src="https://eli.thegreenplace.net/images/math/8755e0f4dc820d67ad665a79d34fd0bbb6fd9b1b.gif" /> a unit of <img src="https://eli.thegreenplace.net/images/math/e8969fcfdd076fdce4b4ca0244b6a6b05964a817.gif" />. The set of units of <img src="https://eli.thegreenplace.net/images/math/e8969fcfdd076fdce4b4ca0244b6a6b05964a817.gif" /> is denoted by <img src="https://eli.thegreenplace.net/images/math/1699294c58aa602fb840c2215844cec3979a67ee.gif" /></p> <p>For example, <img src="https://eli.thegreenplace.net/images/math/8faf8e13bb0577f8b98f532aabd1a553d7fc66b9.gif" /> is a unit of <img src="https://eli.thegreenplace.net/images/math/d24fb3c22159cb774f33d4fbebe3b875f73e333d.gif" />, because <img src="https://eli.thegreenplace.net/images/math/0fc30f90ee4511066fa696c51e09ec1dd90ecab3.gif" />.</p> <p><strong>(III)</strong> The congruence class <img src="https://eli.thegreenplace.net/images/math/8755e0f4dc820d67ad665a79d34fd0bbb6fd9b1b.gif" /> is a unit of <img src="https://eli.thegreenplace.net/images/math/e8969fcfdd076fdce4b4ca0244b6a6b05964a817.gif" /> if and only if <img src="https://eli.thegreenplace.net/images/math/b88879c72c5d589d94fcecc37ebae71004b36eb5.gif" /> (the GCD of <tt class="docutils literal"><span class="pre">a</span></tt> and <tt class="docutils literal"><span class="pre">n</span></tt> is 1, in other words they're co-prime).</p> <p>Proof: By definition of units, there exists some <img src="https://eli.thegreenplace.net/images/math/2e7ad37347086c8f38ec146b7ab4eb6ccf519672.gif" /> such that <img src="https://eli.thegreenplace.net/images/math/44b1c7f2b4cbe6e3cb87d55030fc40b1e517c4c0.gif" />. 
Therefore <img src="https://eli.thegreenplace.net/images/math/4de2086bb823b49a9ce0d5565e0d94d277fe21a4.gif" />, which implies that for some <tt class="docutils literal"><span class="pre">q</span></tt>, <img src="https://eli.thegreenplace.net/images/math/2fc6ecaa6433925d2e15d5421a752bcd4eabc3a1.gif" />. Thus <img src="https://eli.thegreenplace.net/images/math/fee61eb9d36e06b6d66cc7d225df42bb9872a2d7.gif" />. So 1 is a linear combination of <tt class="docutils literal"><span class="pre">a</span></tt> and <tt class="docutils literal"><span class="pre">n</span></tt>. <a class="reference external" href="http://eli.thegreenplace.net/2009/07/10/the-gcd-and-linear-combinations/">Therefore</a> <img src="https://eli.thegreenplace.net/images/math/601a1603beda34af308ef779b2550ce0d9145854.gif" />. On the other hand, if <img src="https://eli.thegreenplace.net/images/math/601a1603beda34af308ef779b2550ce0d9145854.gif" />, there exist <tt class="docutils literal"><span class="pre">q</span></tt> and <tt class="docutils literal"><span class="pre">b</span></tt> such that <img src="https://eli.thegreenplace.net/images/math/2a3ae18a531b52442546e23bbf44377e5b584615.gif" />, or <img src="https://eli.thegreenplace.net/images/math/4de2086bb823b49a9ce0d5565e0d94d277fe21a4.gif" />, so <img src="https://eli.thegreenplace.net/images/math/44b1c7f2b4cbe6e3cb87d55030fc40b1e517c4c0.gif" />.</p> <p><strong>(IV)</strong> By definition, since every unit of <img src="https://eli.thegreenplace.net/images/math/e8969fcfdd076fdce4b4ca0244b6a6b05964a817.gif" /> is coprime to <tt class="docutils literal"><span class="pre">n</span></tt>, the number of units of <img src="https://eli.thegreenplace.net/images/math/e8969fcfdd076fdce4b4ca0244b6a6b05964a817.gif" /> (or, the number of elements of <img src="https://eli.thegreenplace.net/images/math/1699294c58aa602fb840c2215844cec3979a67ee.gif" />) is <img src="https://eli.thegreenplace.net/images/math/20bdd8a8b971fd8582ce58915d9f42ff001daef3.gif" />.</p> <p>Let's 
keep this result in mind and prepare some more theorems in order to attack the proof.</p> <p><strong>(V) Lagrange's theorem:</strong> If <em>H</em> is a subgroup of the finite group <em>G</em>, then the order of <em>H</em> is a divisor of the order of <em>G</em>.</p> <p>Proof: Let's first define <img src="https://eli.thegreenplace.net/images/math/eb2ac2e9ece06fb0136369f280a9b9e0fe90d0c7.gif" /> and <img src="https://eli.thegreenplace.net/images/math/74dc1b476dac9aa1fa28a3d82fb39370ae8aa2a9.gif" />. Also, let <img src="https://eli.thegreenplace.net/images/math/ec2243ccce8948e68890aabd1ce859dbb83defe1.gif" /> be the equivalence relation defined in example (III) of the <a class="reference external" href="http://eli.thegreenplace.net/2009/07/17/equivalence-classes-and-group-partitions/">previous post</a>. Since it's an equivalence relation, it partitions <em>G</em> into equivalence classes. Define <img src="https://eli.thegreenplace.net/images/math/424ca85d95fe170ac1502d98d8a623adb807fa41.gif" /> as the equivalence class of <em>a</em> with <img src="https://eli.thegreenplace.net/images/math/ec2243ccce8948e68890aabd1ce859dbb83defe1.gif" />, for any <img src="https://eli.thegreenplace.net/images/math/4cbe37e25ff6e34b50a2ef01190bc26af1cc355e.gif" />.</p> <p>To prove Lagrange's theorem, we're going to show that <img src="https://eli.thegreenplace.net/images/math/424ca85d95fe170ac1502d98d8a623adb807fa41.gif" /> has the same number of elements as <em>H</em>. For this purpose, let's define a function <img src="https://eli.thegreenplace.net/images/math/44ecebaf61b2207527b728e788c781d43c21e248.gif" /> by <img src="https://eli.thegreenplace.net/images/math/06f368612a045f555165ad1e02442d448d61cac2.gif" /> for all <img src="https://eli.thegreenplace.net/images/math/d2282b7258ea3a7b88850baba99bf31584143987.gif" /> and prove that it's a bijection (note that the equivalence class is in general not a subgroup, so a bijection is all we need here).
To do that, we'll have to separately prove that it's onto and one-to-one.</p> <p>But first, let's verify that the stated codomain of <img src="https://eli.thegreenplace.net/images/math/c13d3e630d6430dc77134d5df88542a73dfb1853.gif" /> is correct. If <img src="https://eli.thegreenplace.net/images/math/47710305b38e478c0091b68a632a5a7f1f9574a7.gif" /> then <img src="https://eli.thegreenplace.net/images/math/7e53ba9abb1ae5368c1676e44a3ce6a420c9d702.gif" /> because <img src="https://eli.thegreenplace.net/images/math/16938cce957255f90140fda84a82086811279f0f.gif" />, so by definition of <img src="https://eli.thegreenplace.net/images/math/ec2243ccce8948e68890aabd1ce859dbb83defe1.gif" /> we have <img src="https://eli.thegreenplace.net/images/math/a3a37b36fabdad84d64351523f1d8a025ceb2b6a.gif" />. So indeed the codomain of <img src="https://eli.thegreenplace.net/images/math/c13d3e630d6430dc77134d5df88542a73dfb1853.gif" /> is <img src="https://eli.thegreenplace.net/images/math/424ca85d95fe170ac1502d98d8a623adb807fa41.gif" />.</p> <ol class="arabic simple"> <li>Let's pick some <em>y</em> in <em>G</em> such that <img src="https://eli.thegreenplace.net/images/math/28ccea8fbeb58e8b2aafbe57b2bd1db5e2231ea7.gif" />. By definition of our <img src="https://eli.thegreenplace.net/images/math/ec2243ccce8948e68890aabd1ce859dbb83defe1.gif" /> it means that <img src="https://eli.thegreenplace.net/images/math/4aca76345e9c9a2476f2f490153e95bbd621d732.gif" /> for some <img src="https://eli.thegreenplace.net/images/math/47710305b38e478c0091b68a632a5a7f1f9574a7.gif" />. So <img src="https://eli.thegreenplace.net/images/math/b5a1c5b30321fa43940c63b394adc1f161f4d089.gif" /> has a solution <img src="https://eli.thegreenplace.net/images/math/2e7b84283f31f4ccb9f6b11c1007093203400eba.gif" /> (since <img src="https://eli.thegreenplace.net/images/math/66de997d8baedff86251b7c7fbaf81103eb9a8db.gif" />). 
Therefore <img src="https://eli.thegreenplace.net/images/math/c13d3e630d6430dc77134d5df88542a73dfb1853.gif" /> is onto.</li> <li>Suppose that <img src="https://eli.thegreenplace.net/images/math/84e5a51f34c8e39476b4e62db18d6a88b3f513c7.gif" /> with <img src="https://eli.thegreenplace.net/images/math/4b70a286b390809dc5095ab68b766b056123f843.gif" />. Then <img src="https://eli.thegreenplace.net/images/math/85212195c93331433ff9c4dafb4c066bd4eac844.gif" /> and by cancellation in groups we have <img src="https://eli.thegreenplace.net/images/math/2b8466a1849f730f97e3257cf26339443bf5af38.gif" />, which proves that <img src="https://eli.thegreenplace.net/images/math/c13d3e630d6430dc77134d5df88542a73dfb1853.gif" /> is one-to-one.</li> </ol> <p>So we've proved that <img src="https://eli.thegreenplace.net/images/math/c13d3e630d6430dc77134d5df88542a73dfb1853.gif" /> is a bijection, which means that <img src="https://eli.thegreenplace.net/images/math/85f3f39d0e9a695c2599b186606532ff4e1831c0.gif" /> (we can map each element of <img src="https://eli.thegreenplace.net/images/math/424ca85d95fe170ac1502d98d8a623adb807fa41.gif" /> to one and only one element of <img src="https://eli.thegreenplace.net/images/math/96ceb9b4d8ba9dc94b2358619f4de892b0cb392e.gif" />).</p> <p>We've <a class="reference external" href="http://eli.thegreenplace.net/2009/07/17/equivalence-classes-and-group-partitions/">previously shown</a> that the equivalence classes of <img src="https://eli.thegreenplace.net/images/math/ec2243ccce8948e68890aabd1ce859dbb83defe1.gif" /> partition <em>G</em>. But now we see that the size of each equivalence class is equal to <img src="https://eli.thegreenplace.net/images/math/96ceb9b4d8ba9dc94b2358619f4de892b0cb392e.gif" />. Therefore, all the equivalence classes are of the same size, and <img src="https://eli.thegreenplace.net/images/math/b8be964828998afcc8345c4168ff9e9bee619879.gif" /> where <em>t</em> is the number of equivalence classes.
This proves Lagrange's theorem.</p> <p>We're almost there. To see how all of this relates to Euler's theorem, let's first define the order of an element of a group.</p> <p><strong>(VI) Order of group element:</strong> Let <img src="https://eli.thegreenplace.net/images/math/4cbe37e25ff6e34b50a2ef01190bc26af1cc355e.gif" />. If there exists a positive integer <em>n</em> such that <img src="https://eli.thegreenplace.net/images/math/2d87ae91e6616a94fed293028372e33750f1cfc7.gif" />, then a is said to have <strong>finite order</strong> and the smallest such positive integer is called the <strong>order</strong> of <em>a</em>, denoted by <img src="https://eli.thegreenplace.net/images/math/da44ab0ab8337608c62ebeecc3ea57ebf47707a3.gif" />.</p> <p>We'll also define the subgroup <strong>generated</strong> by an element:</p> <p><strong>(VII) Cyclic subgroup:</strong> <img src="https://eli.thegreenplace.net/images/math/4c75b2ed8003dc064a1436983872dee8d997813b.gif" /> is a cyclic subgroup of <em>G</em> generated by <img src="https://eli.thegreenplace.net/images/math/4cbe37e25ff6e34b50a2ef01190bc26af1cc355e.gif" />. For a finite <em>G</em> this subgroup is also finite, and its size is: <img src="https://eli.thegreenplace.net/images/math/f107fd88a660909965da200640fae4efa5e1a8ba.gif" />.</p> <p>Armed with these definitions, we're now ready for the following corollary of Lagrange's theorem:</p> <p><strong>(VIII) Lagrange theorem corollary:</strong> Let <em>G</em> be a finite group of order <em>n</em>. 
Then:</p> <ol class="arabic simple"> <li>For any <img src="https://eli.thegreenplace.net/images/math/4cbe37e25ff6e34b50a2ef01190bc26af1cc355e.gif" />, <img src="https://eli.thegreenplace.net/images/math/da44ab0ab8337608c62ebeecc3ea57ebf47707a3.gif" /> divides <em>n</em></li> <li>For any <img src="https://eli.thegreenplace.net/images/math/4cbe37e25ff6e34b50a2ef01190bc26af1cc355e.gif" />, <img src="https://eli.thegreenplace.net/images/math/2d87ae91e6616a94fed293028372e33750f1cfc7.gif" /></li> </ol> <p>Proof: As we've seen, <img src="https://eli.thegreenplace.net/images/math/f107fd88a660909965da200640fae4efa5e1a8ba.gif" /> and by Lagrange's theorem <img src="https://eli.thegreenplace.net/images/math/f1d20c2a90429ea2c3ad65b999596d80f1dbbd8a.gif" /> divides <em>n</em> (since <img src="https://eli.thegreenplace.net/images/math/003dd0b9f9b592b676d030e7da22e89a06339deb.gif" /> is a subgroup of <em>G</em>). Therefore (1) is proven. For (2), note that if <em>a</em> has order <em>m</em>, then by (1) we have <img src="https://eli.thegreenplace.net/images/math/cd5cdf72717cbde659275a92956c6c74904afc18.gif" /> for some integer <em>q</em>. Thus <img src="https://eli.thegreenplace.net/images/math/346b1071cf373e6327585c473b0f7ff643c5318b.gif" />. But <em>a</em> has order <em>m</em>, so <img src="https://eli.thegreenplace.net/images/math/8b56cce28b56a3bfe81e1612d183dfc042bd402e.gif" /> and therefore <img src="https://eli.thegreenplace.net/images/math/9ccb8587d38930ca20bae3ad4c6c9d4215e3ced5.gif" />. <em>Q.E.D.</em></p> <p>We now finally have all the tools required to prove Euler's theorem.</p> <p>Proof of Euler's theorem: Let <img src="https://eli.thegreenplace.net/images/math/ae4ae2b95770589320bf2a1844bc34c8afac7f18.gif" /> be the group of units modulo <em>n</em>. The order of <em>G</em> is <img src="https://eli.thegreenplace.net/images/math/20bdd8a8b971fd8582ce58915d9f42ff001daef3.gif" /> (by <strong>(IV)</strong>).
Now, by <strong>(VIII)</strong> part (2), raising any congruence class to the power <img src="https://eli.thegreenplace.net/images/math/20bdd8a8b971fd8582ce58915d9f42ff001daef3.gif" /> must give the identity element. The statement <img src="https://eli.thegreenplace.net/images/math/41bfcd931fdf749bf761b9e98ae5bf9e1e050a5c.gif" /> is equivalent to <img src="https://eli.thegreenplace.net/images/math/5cf274690f4bed52d2f24bf39482492aaaf6d135.gif" /></p> <p><img src="https://eli.thegreenplace.net/images/math/7b47d4175993a732aa2287de666a82273110f26e.gif" /></p> Equivalence classes and group partitions2009-07-17T15:47:57-07:002009-07-17T15:47:57-07:00Eli Benderskytag:eli.thegreenplace.net,2009-07-17:/2009/07/17/equivalence-classes-and-group-partitions <p>In this post I want to show some interesting definitions and theorems about equivalence relations &amp; classes, and groups.</p> <p><em>Relations</em> are an important topic in algebra. Conceptually, a relation is a statement <tt class="docutils literal"><span class="pre">aRb</span></tt> about two elements of a set.
If the elements are integers, then <img src="https://eli.thegreenplace.net/images/math/ccff2fee4b15e0b46f79f86ce5d1de59163bb483.gif" /> is a relation, and so is <img src="https://eli.thegreenplace.net/images/math/291666cb9894498f52e69a8e08f287ca771c204d.gif" />.</p> <p>Here's a formal set-theoretic definition:</p> <p><strong>(I) Binary relation:</strong> A <em>binary relation</em> on a set <tt class="docutils literal"><span class="pre">A</span></tt> is a collection of ordered pairs of elements of <tt class="docutils literal"><span class="pre">A</span></tt>. In other words, it is a subset of <img src="https://eli.thegreenplace.net/images/math/bc659bc638626217264a2aa7a0cca55c0cc40ddc.gif" />. More generally, a binary relation between two sets <tt class="docutils literal"><span class="pre">A</span></tt> and <tt class="docutils literal"><span class="pre">B</span></tt> is a subset of <img src="https://eli.thegreenplace.net/images/math/61589f4d75ca185c6165e5108883b41f5b630222.gif" />.</p> <p>Note it says <em>ordered pairs</em>. What this means is that the order of elements in a relation is important.
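</p>

<p>To make the definition concrete (this small sketch is ours, not from the original post), a relation on a finite set can be modeled in Python directly as a set of ordered pairs:</p>

```python
# The ">=" relation on A = {1, 2, 3}, represented as a subset of A x A.
A = {1, 2, 3}
geq = {(a, b) for a in A for b in A if a >= b}

# The pairs are ordered: (3, 1) is in the relation, but (1, 3) is not.
assert (3, 1) in geq
assert (1, 3) not in geq
```

<p>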
Intuitively, given the relation <img src="https://eli.thegreenplace.net/images/math/14d75ce806cfa95e2165e15f3f40cbc02b2526a1.gif" /> and the set of integers, it's clear that <img src="https://eli.thegreenplace.net/images/math/291666cb9894498f52e69a8e08f287ca771c204d.gif" /> does not generally imply <img src="https://eli.thegreenplace.net/images/math/7edbd8459c300b48c7a5bfdb112e3294e51f3788.gif" />.</p> <p>An important sub-class of relations we'll be most interested in is the <em>equivalence relations</em>:</p> <p><strong>(II) Equivalence relation:</strong> A relation <img src="https://eli.thegreenplace.net/images/math/ec2243ccce8948e68890aabd1ce859dbb83defe1.gif" /> on a set <tt class="docutils literal"><span class="pre">S</span></tt> is called an <em>equivalence relation</em> if it's reflexive, symmetric and transitive:</p> <ul class="simple"> <li>Reflexive: for all <tt class="docutils literal"><span class="pre">a</span></tt> in <tt class="docutils literal"><span class="pre">S</span></tt> it holds that <img src="https://eli.thegreenplace.net/images/math/8837c782dfc7a199063f41e8a89f3d2f6968f863.gif" />.</li> <li>Symmetric: for all <tt class="docutils literal"><span class="pre">a</span></tt> and <tt class="docutils literal"><span class="pre">b</span></tt> in <tt class="docutils literal"><span class="pre">S</span></tt> it holds that if <img src="https://eli.thegreenplace.net/images/math/e30651047d0c214a1dcc4ca726674497a16b692f.gif" /> then <img src="https://eli.thegreenplace.net/images/math/9541931763753defc9183053ae757c301d2de115.gif" />.</li> <li>Transitive: for all <tt class="docutils literal"><span class="pre">a</span></tt>, <tt class="docutils literal"><span class="pre">b</span></tt> and <tt class="docutils literal"><span class="pre">c</span></tt> in <tt class="docutils literal"><span class="pre">S</span></tt> it holds that if <img src="https://eli.thegreenplace.net/images/math/e30651047d0c214a1dcc4ca726674497a16b692f.gif" /> and <img
src="https://eli.thegreenplace.net/images/math/8ca3658ef957edae19b4b119308890e2b3f49796.gif" /> then <img src="https://eli.thegreenplace.net/images/math/0179f9af305ab8e1c766267cfd11318b6b4811bf.gif" />.</li> </ul> <p>Examples: equality is an equivalence relation, but greater-or-equal is not, as <img src="https://eli.thegreenplace.net/images/math/291666cb9894498f52e69a8e08f287ca771c204d.gif" /> doesn't imply <img src="https://eli.thegreenplace.net/images/math/7edbd8459c300b48c7a5bfdb112e3294e51f3788.gif" />: the symmetric condition doesn't hold.</p> <p>Another important equivalence relation is congruence modulo an integer, <img src="https://eli.thegreenplace.net/images/math/0c4ce290323e491e2482e46e01db353836b45d7d.gif" />.</p> <p><strong>(III) Example:</strong> Let <tt class="docutils literal"><span class="pre">G</span></tt> be a group and <tt class="docutils literal"><span class="pre">H</span></tt> a subgroup of <tt class="docutils literal"><span class="pre">G</span></tt>. For <img src="https://eli.thegreenplace.net/images/math/6678de816bbd60aedd8839895cd79707af6a97d8.gif" />, define <img src="https://eli.thegreenplace.net/images/math/e30651047d0c214a1dcc4ca726674497a16b692f.gif" /> if <img src="https://eli.thegreenplace.net/images/math/3071dc07c816d497eb3e1fc0340e7cac7016cf5c.gif" />. Then <img src="https://eli.thegreenplace.net/images/math/ec2243ccce8948e68890aabd1ce859dbb83defe1.gif" /> is an equivalence relation on <tt class="docutils literal"><span class="pre">G</span></tt>.</p> <p>Proof: To prove that some relation is an equivalence relation, we have to prove the three properties of equivalence relations.</p> <p>Reflexive: <img src="https://eli.thegreenplace.net/images/math/a17e86615bef6d2199ab48b9b1dcb5b014565084.gif" /> is the identity element <tt class="docutils literal"><span class="pre">e</span></tt>.
Since <tt class="docutils literal"><span class="pre">H</span></tt> is a subgroup of <tt class="docutils literal"><span class="pre">G</span></tt>, it contains the identity: <img src="https://eli.thegreenplace.net/images/math/3e0c664e92b67cb72df9d035b6b604b6350fe921.gif" />. Therefore <img src="https://eli.thegreenplace.net/images/math/34527efbfb9e31e54e0852df71b1c49e95ccdfbc.gif" />, so <img src="https://eli.thegreenplace.net/images/math/8837c782dfc7a199063f41e8a89f3d2f6968f863.gif" />.</p> <p>Symmetric: Assume that <img src="https://eli.thegreenplace.net/images/math/3071dc07c816d497eb3e1fc0340e7cac7016cf5c.gif" />. Since <tt class="docutils literal"><span class="pre">H</span></tt> is a subgroup, this element has an inverse in <tt class="docutils literal"><span class="pre">H</span></tt>: <img src="https://eli.thegreenplace.net/images/math/9e3f7534ed8dab84ae64ab9f0ec20772c69ddff0.gif" />. Using the associative law of groups several times it's possible to show that <img src="https://eli.thegreenplace.net/images/math/7317708012349dbcc4b3eb70f95119ae9c286203.gif" />. So <img src="https://eli.thegreenplace.net/images/math/553bb94e84e7af685533d0ee89de27b427e65802.gif" />, hence <img src="https://eli.thegreenplace.net/images/math/9541931763753defc9183053ae757c301d2de115.gif" />.</p> <p>Transitive: Suppose <img src="https://eli.thegreenplace.net/images/math/3071dc07c816d497eb3e1fc0340e7cac7016cf5c.gif" /> and <img src="https://eli.thegreenplace.net/images/math/f11dbfe83910887c218813c7a4cb90fc7fed83f8.gif" />.
Then <img src="https://eli.thegreenplace.net/images/math/2dc110b11d970dfda421a11dd07ecc95ded21823.gif" />. So <img src="https://eli.thegreenplace.net/images/math/cad936cd9155324d4258bc2368ee68a2c9a6b88e.gif" />, hence <img src="https://eli.thegreenplace.net/images/math/0179f9af305ab8e1c766267cfd11318b6b4811bf.gif" />.</p> <p>Thus, we've proved that this <img src="https://eli.thegreenplace.net/images/math/ec2243ccce8948e68890aabd1ce859dbb83defe1.gif" /> is an equivalence relation. Let's see a concrete application of the result we've just proved:</p> <p>Consider <img src="https://eli.thegreenplace.net/images/math/55a2a59ab8896d9110c3d3c055c8c8f3b5c88297.gif" /> with the operation <tt class="docutils literal"><span class="pre">+</span></tt>, and let <tt class="docutils literal"><span class="pre">H</span></tt> be the subgroup consisting of all multiples of some <img src="https://eli.thegreenplace.net/images/math/719262cd45830248133d8a9183d18f9b43c1a7cb.gif" />. Then <img src="https://eli.thegreenplace.net/images/math/3071dc07c816d497eb3e1fc0340e7cac7016cf5c.gif" /> actually means that <img src="https://eli.thegreenplace.net/images/math/54a610609dd020d4a60658c7d44b97d9dbc04dfb.gif" /> for some <img src="https://eli.thegreenplace.net/images/math/f8b63b37c9f85fea422386a6a34535a2f2a7cc07.gif" />. In other words, <img src="https://eli.thegreenplace.net/images/math/0c4ce290323e491e2482e46e01db353836b45d7d.gif" />.
This proves that congruence modulo <tt class="docutils literal"><span class="pre">n</span></tt> is an equivalence relation, since it's a special case of (III).</p> <p><strong>(IV) Equivalence class:</strong> If <img src="https://eli.thegreenplace.net/images/math/ec2243ccce8948e68890aabd1ce859dbb83defe1.gif" /> is an equivalence relation on <tt class="docutils literal"><span class="pre">S</span></tt>, then <tt class="docutils literal"><span class="pre">[a]</span></tt>, the <em>equivalence class of a</em>, is defined by <img src="https://eli.thegreenplace.net/images/math/4f1018740fe8826cbe9aae04c0ef6642794359e5.gif" />.</p> <p>For example, let's take the integers <img src="https://eli.thegreenplace.net/images/math/b719c7ce5a7442a3bf64a8fa268fc460dcd2f3a3.gif" /> and define an equivalence relation &quot;congruent modulo 5&quot;. For instance, <img src="https://eli.thegreenplace.net/images/math/15ea75a387b7bed5bffb2ec749050bcb69efda6d.gif" />. The congruence class of 1 modulo 5 (denoted <img src="https://eli.thegreenplace.net/images/math/51c8de74b436372084218b4c20c9b56b31841b9d.gif" />) is <img src="https://eli.thegreenplace.net/images/math/21a88de3471dc39ecebb41416ef2440fa87d41f7.gif" />.</p> <p><strong>(V) Group partition:</strong> If <img src="https://eli.thegreenplace.net/images/math/ec2243ccce8948e68890aabd1ce859dbb83defe1.gif" /> is an equivalence relation on <tt class="docutils literal"><span class="pre">S</span></tt>, then <img src="https://eli.thegreenplace.net/images/math/4e2b5a0104ce54ff58d1bfef5bf627f27b3c3144.gif" /> for all <img src="https://eli.thegreenplace.net/images/math/f27f1b6ae66beb6c65d284fef3c58b10699c72e3.gif" />, and <img src="https://eli.thegreenplace.net/images/math/002cf7d039da1d117c5f9020c2c22d6636a4c559.gif" /> implies that <img src="https://eli.thegreenplace.net/images/math/c0849a8679091eddc0ea75adcac02b4c8e4b8536.gif" />.
In other words, <img src="https://eli.thegreenplace.net/images/math/ec2243ccce8948e68890aabd1ce859dbb83defe1.gif" /> partitions <tt class="docutils literal"><span class="pre">S</span></tt> into disjoint equivalence classes.</p> <p>Proof: the first part is easy. Since <img src="https://eli.thegreenplace.net/images/math/52809cd6fe155fec2dc9497412b1a27ab85f1f6c.gif" /> always holds, <img src="https://eli.thegreenplace.net/images/math/c1f41fa09366f686d3c80f12c4bc1af5026d64c9.gif" />. To prove the second part, we'll show that if <img src="https://eli.thegreenplace.net/images/math/3674bbe6defde731553a80353589f1857d60505f.gif" /> then <img src="https://eli.thegreenplace.net/images/math/eb3b4462d76bf469f89dc6ded2d3acc1ab80a036.gif" />.</p> <p>Suppose that <img src="https://eli.thegreenplace.net/images/math/3674bbe6defde731553a80353589f1857d60505f.gif" />, and let <img src="https://eli.thegreenplace.net/images/math/f70628deb363032a930c8efe9c02ce456be4e741.gif" />. Therefore <img src="https://eli.thegreenplace.net/images/math/f0d3c43681c62263d12367105cd54ac13189adf3.gif" /> and <img src="https://eli.thegreenplace.net/images/math/a5c98d9612eb21bebe6860ac91e618336bd2c45a.gif" />. But <img src="https://eli.thegreenplace.net/images/math/ec2243ccce8948e68890aabd1ce859dbb83defe1.gif" /> is an equivalence relation and thus is transitive and symmetric. So <img src="https://eli.thegreenplace.net/images/math/e30651047d0c214a1dcc4ca726674497a16b692f.gif" />. But this means that <tt class="docutils literal"><span class="pre">a</span></tt> and <tt class="docutils literal"><span class="pre">b</span></tt> are in the same equivalence class: <img src="https://eli.thegreenplace.net/images/math/eb3b4462d76bf469f89dc6ded2d3acc1ab80a036.gif" />.
<em>Q.E.D.</em></p> Detexify recognizes hand-written math symbols2009-07-13T05:31:16-07:002009-07-13T05:31:16-07:00Eli Benderskytag:eli.thegreenplace.net,2009-07-13:/2009/07/13/detexify-recognizes-hand-written-math-symbols Does it ever happen to you that you don't remember the LaTeX code for some mathematical symbol? What can you do then except wade through pages of LaTeX symbols trying to locate the right one? Well, no more! <a href="http://detexify.kirelabs.org/classify.html">Detexify</a> is a great new service that allows you to "draw" the symbol you're looking for: <p> <img src="https://eli.thegreenplace.net/images/2009/07/intg_handwriting.png" title="intg_handwriting" width="301" height="366" class="alignnone size-full wp-image-1801" /> </p> ... and it will suggest the LaTeX code. <p> <img src="https://eli.thegreenplace.net/images/2009/07/intg_suggestions.png" title="intg_suggestions" width="287" height="310" class="alignnone size-full wp-image-1802" /> </p> Detexify is a learning OCR classifier, and can be "trained" by users to improve its performance. Kudos to the <a href="http://kirelabs.org/">creator</a> of Detexify for a great project. It will definitely be useful...
Generating multi-subsets using arithmetic2009-07-11T07:27:13-07:002009-07-11T07:27:13-07:00Eli Benderskytag:eli.thegreenplace.net,2009-07-11:/2009/07/11/generating-multi-subsets-using-arithmetic <p>In the past <a class="reference external" href="http://eli.thegreenplace.net/2005/03/29/application-of-combinations/">I've written</a> about how simple arithmetic can be employed to compute a powerset of a given set.</p> <p>Here I want to show a generalization that uses n-ary arithmetic. But first, let's define the problem:</p> <p>Suppose you have a set of elements and you want to select multi-subsets from it. By multi-subset in this context I mean that an element can appear more than once in it. For example, given the set {0, 1, 2, 3, 4, 5}, then {1, 1, 2} is a multi-subset. So are {5, 5, 5, 5} and {0, 1, 2, 3, 4, 5}. Suppose you want to go over <em>all</em> multi-subsets of a set. How can this be done?</p> <p>Note that generating the powerset is a special case of this problem, restricting each element to appear either 0 or 1 times in the resulting subset.</p> <p>So the solution is a generalization of the <a class="reference external" href="http://eli.thegreenplace.net/2005/03/29/application-of-combinations/">binary-arithmetic solution</a> for the powerset problem.</p> <p>Intuitive motivation: consider the decimal numbers, for example 25. If we use the position of each digit (starting with the units) to convey information, this leads to an interesting observation.
If we have two elements to choose from, 25 may mean 5 occurrences of the first element and 2 of the second. Now, going over all numbers from 0 to 99, we are actually generating all multi-subsets of two elements where each can be picked from 0 to 9 times.</p> <p>Once this is clear, the algorithm is simple. Let's generalize to an n-ary base system, using position to point to an element and the 'digit' at this position to say how many times it appears in a given multi-subset. And the best part: the simple rules of addition with carry can now be used to efficiently generate all multi-subsets, given the number of elements we have (<tt class="docutils literal"><span class="pre">length</span></tt>) and the maximal number of times each can be picked (<tt class="docutils literal"><span class="pre">upto</span></tt>), the minimum being assumed 0.</p> <p>Here's the code:</p> <div class="highlight"><pre>
def multiselects(upto, length):
    # Arithmetically, we create an array of digits
    # (each in the range 0..upto).
    # It's initialized with '1'
    #
    ar = [1] + [0] * (length - 1)

    while True:
        yield ar

        # The index we're currently trying to
        # advance
        #
        idx = 0

        # Advance the current index. If it reaches
        # the limit (upto), perform a carry to the
        # next index (digit)
        #
        while idx &lt; length:
            ar[idx] += 1
            if ar[idx] &lt;= upto:
                break
            else:
                ar[idx] = 0
                idx += 1

        # We've reached the last number...
        #
        if idx == length:
            break
</pre></div> <p>An example run of:</p> <div class="highlight"><pre>
for s in multiselects(2, 3):
    print s
</pre></div> <p>Produces:</p> <div class="highlight"><pre>
[1, 0, 0]
[2, 0, 0]
[0, 1, 0]
[1, 1, 0]
[2, 1, 0]
[0, 2, 0]
[1, 2, 0]
[2, 2, 0]
[0, 0, 1]
[1, 0, 1]
[2, 0, 1]
[0, 1, 1]
[1, 1, 1]
[2, 1, 1]
[0, 2, 1]
[1, 2, 1]
[2, 2, 1]
[0, 0, 2]
[1, 0, 2]
[2, 0, 2]
[0, 1, 2]
[1, 1, 2]
[2, 1, 2]
[0, 2, 2]
[1, 2, 2]
[2, 2, 2]
</pre></div> <p>Note that the solution is general, as the lists it returns are lists of indices. These can be employed with any set to generate multi-subsets.</p> <p><strong>Background and links</strong></p> <p>I came up with this function while working on Project Euler's problem 77.
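</p>

<p>For comparison (this sketch is ours, not part of the original post), the same count arrays can be produced with the standard library's <tt class="docutils literal"><span class="pre">itertools.product</span></tt>, which performs the same n-ary counting:</p>

```python
from itertools import product

def multiselects_product(upto, length):
    # product() varies the last position fastest; reversing each tuple
    # reproduces the "least significant digit first" order of multiselects.
    for digits in product(range(upto + 1), repeat=length):
        yield list(reversed(digits))
```

<p>Unlike the original generator, this also yields the all-zero array (the empty multi-subset) as its first item; the rest of the sequence is identical.</p>

<p>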
I ended up using a different method, but visualizing the possible partitions of primes was very useful.</p> <p>Here are some interesting mathematical links related to this problem:</p> <ul class="simple"> <li><a class="reference external" href="http://mathworld.wolfram.com/EulerTransform.html">Euler transform</a></li> <li><a class="reference external" href="http://mathworld.wolfram.com/PartitionFunctionP.html">Partition function P</a></li> <li><a class="reference external" href="http://mathworld.wolfram.com/PrimePartition.html">Prime partition</a></li> </ul> The GCD and linear combinations2009-07-10T07:49:39-07:002009-07-10T07:49:39-07:00Eli Benderskytag:eli.thegreenplace.net,2009-07-10:/2009/07/10/the-gcd-and-linear-combinations <p>A linear combination of <tt class="docutils literal"><span class="pre">a</span></tt> and <tt class="docutils literal"><span class="pre">b</span></tt> is some integer of the form
<img src="https://eli.thegreenplace.net/images/math/62b14b4483debf849fb89c978fc9d8de667d50ee.gif" />, where <img src="https://eli.thegreenplace.net/images/math/4d7d79c7dff51ad08353b3af8ec8de78276f7d02.gif" />.</p> <p>There's a very interesting theorem that gives a useful connection between linear combinations and the GCD of <tt class="docutils literal"><span class="pre">a</span></tt> and <tt class="docutils literal"><span class="pre">b</span></tt>, called <a class="reference external" href="http://en.wikipedia.org/wiki/B%C3%A9zout%27s_identity">Bézout's identity</a>:</p> <p><strong>Bézout's identity:</strong> <img src="https://eli.thegreenplace.net/images/math/3faea922fd91a480e6e951bdbe6568c36de8a854.gif" /> (the GCD of <tt class="docutils literal"><span class="pre">a</span></tt> and <tt class="docutils literal"><span class="pre">b</span></tt>) is the smallest positive linear combination of non-zero <tt class="docutils literal"><span class="pre">a</span></tt> and <tt class="docutils literal"><span class="pre">b</span></tt>.</p> <p>Both Bézout's identity and its corollary I show below are very useful tools in elementary number theory, being used for the proofs of many of the most fundamental theorems. Let's see why it's true.</p> <p><strong>(I) Intuition:</strong> First I'd like to explain this (surprising at first sight) theorem intuitively. By definition, any common divisor of <tt class="docutils literal"><span class="pre">a</span></tt> and <tt class="docutils literal"><span class="pre">b</span></tt> will divide <img src="https://eli.thegreenplace.net/images/math/62b14b4483debf849fb89c978fc9d8de667d50ee.gif" /> for all <img src="https://eli.thegreenplace.net/images/math/4d7d79c7dff51ad08353b3af8ec8de78276f7d02.gif" />.
In particular, <img src="https://eli.thegreenplace.net/images/math/3faea922fd91a480e6e951bdbe6568c36de8a854.gif" /> also divides any <img src="https://eli.thegreenplace.net/images/math/62b14b4483debf849fb89c978fc9d8de667d50ee.gif" />.</p> <p>Now, assume we've found some small <img src="https://eli.thegreenplace.net/images/math/d7f9da6f2362bc7b8efdffe11cdf4ba00597aba0.gif" /> which isn't the GCD. But we've just said that <img src="https://eli.thegreenplace.net/images/math/3faea922fd91a480e6e951bdbe6568c36de8a854.gif" /> divides all linear combinations, so it also divides <tt class="docutils literal"><span class="pre">x</span></tt>. Therefore, <tt class="docutils literal"><span class="pre">x</span></tt> cannot be smaller than the GCD. In other words, the smallest positive linear combination can only be <img src="https://eli.thegreenplace.net/images/math/3faea922fd91a480e6e951bdbe6568c36de8a854.gif" /> itself.</p> <p><strong>(II) Corollary:</strong> An integer is a linear combination of <tt class="docutils literal"><span class="pre">a</span></tt> and <tt class="docutils literal"><span class="pre">b</span></tt> if and only if it is a multiple of their GCD.</p> <p>To prove Bézout's identity more formally, and along the way to see why the corollary is also true, let's first prove the following:</p> <p><strong>(III)</strong> Let <tt class="docutils literal"><span class="pre">I</span></tt> be a nonempty set of integers that is closed under addition and subtraction, and contains at least one non-zero integer.
Then there exists a smallest positive element <img src="https://eli.thegreenplace.net/images/math/beff059ebb50376acab168901db80ded7949e1fa.gif" />, and <tt class="docutils literal"><span class="pre">I</span></tt> consists of all multiples of <tt class="docutils literal"><span class="pre">b</span></tt> (<img src="https://eli.thegreenplace.net/images/math/e7937e0bb65c4938597a8228eeb17a91f07986f9.gif" />).</p> <p>Proof: <tt class="docutils literal"><span class="pre">I</span></tt> contains at least one non-zero integer. Then it definitely contains at least one positive integer, because it is closed under addition and subtraction. Assume we have <img src="https://eli.thegreenplace.net/images/math/b36dc5cdf3bcc0208bf87c0f98375fd5eabf0ed8.gif" /> for some <img src="https://eli.thegreenplace.net/images/math/f9e61579d79e8c1c1fa9119bd87d3d61f00e8c19.gif" />. Therefore <img src="https://eli.thegreenplace.net/images/math/55f6c1c701b762759d0ab1491f781999ed43db6e.gif" /> and then also <img src="https://eli.thegreenplace.net/images/math/6486f2162b7ae30383261a7e5cdaa6288cfa88c4.gif" />. Thus we have positive integers in <tt class="docutils literal"><span class="pre">I</span></tt>. According to the <a class="reference external" href="http://eli.thegreenplace.net/2009/07/09/the-well-ordering-principle">well-ordering principle</a>, <tt class="docutils literal"><span class="pre">I</span></tt> has a smallest positive element which we'll call <tt class="docutils literal"><span class="pre">b</span></tt>.</p> <p>Now we'll want to show that <img src="https://eli.thegreenplace.net/images/math/e7937e0bb65c4938597a8228eeb17a91f07986f9.gif" />. 
As usual, to prove equalities of sets, it will be shown that they contain one another.</p> <p><img src="https://eli.thegreenplace.net/images/math/55cd1587cf04efebfb6e92f8581388e230593775.gif" /> is obvious - since <tt class="docutils literal"><span class="pre">I</span></tt> contains <tt class="docutils literal"><span class="pre">b</span></tt> and is closed under addition and subtraction, it contains all the multiples of <tt class="docutils literal"><span class="pre">b</span></tt>.</p> <p>To prove <img src="https://eli.thegreenplace.net/images/math/4a8035ec5843674bf59e5e2e2b3d59ea6117698c.gif" /> we'll demonstrate that any element <img src="https://eli.thegreenplace.net/images/math/5cf7becda6c13605eca7634f4c7f794056f756e1.gif" /> is a multiple of <tt class="docutils literal"><span class="pre">b</span></tt>. Using the division algorithm we write <img src="https://eli.thegreenplace.net/images/math/9a77ecf748d7f958c6f632f3906a7a9e61e0ad39.gif" /> for some integers <tt class="docutils literal"><span class="pre">q</span></tt> and <tt class="docutils literal"><span class="pre">0</span> <span class="pre">&lt;=</span> <span class="pre">r</span> <span class="pre">&lt;</span> <span class="pre">b</span></tt>. But this means that <img src="https://eli.thegreenplace.net/images/math/4525123efb61ba233370b72315d8557151fadfaa.gif" /> (because <tt class="docutils literal"><span class="pre">I</span></tt> contains <tt class="docutils literal"><span class="pre">bq</span></tt> and <tt class="docutils literal"><span class="pre">c</span></tt> and is closed under subtraction and addition). However, recall that <tt class="docutils literal"><span class="pre">b</span></tt> was chosen to be the smallest positive element of <tt class="docutils literal"><span class="pre">I</span></tt>, so <tt class="docutils literal"><span class="pre">r</span></tt> must be equal to 0. 
Therefore <tt class="docutils literal"><span class="pre">c</span></tt> is a multiple of <tt class="docutils literal"><span class="pre">b</span></tt>, and we have shown that <img src="https://eli.thegreenplace.net/images/math/4a8035ec5843674bf59e5e2e2b3d59ea6117698c.gif" />. <em>Q.E.D.</em></p> <p>Now back to Bézout's identity. We'll define:</p> <p><img src="https://eli.thegreenplace.net/images/math/9ce9860e1da4b572d6a076d2b48b4219d6980495.gif" /></p> <p>This <tt class="docutils literal"><span class="pre">I</span></tt> is obviously non-empty and is closed under addition and subtraction (by its definition as a linear combination). Note, in particular, that it also contains <tt class="docutils literal"><span class="pre">a</span></tt> and <tt class="docutils literal"><span class="pre">b</span></tt>. By <strong>(III)</strong>, <tt class="docutils literal"><span class="pre">I</span></tt> consists of all multiples of its smallest positive element, which we'll call <tt class="docutils literal"><span class="pre">d</span></tt> here.</p> <p>To show that <img src="https://eli.thegreenplace.net/images/math/021c7f3f9c969802e3a91c61b8bc9c24400350f6.gif" /> we have to show that <tt class="docutils literal"><span class="pre">d|a</span></tt>, <tt class="docutils literal"><span class="pre">d|b</span></tt> and if <tt class="docutils literal"><span class="pre">c|a</span></tt> and <tt class="docutils literal"><span class="pre">c|b</span></tt> then <tt class="docutils literal"><span class="pre">c|d</span></tt>. First, by definition <tt class="docutils literal"><span class="pre">d</span></tt> is a divisor of any element in <tt class="docutils literal"><span class="pre">I</span></tt>, so it also divides <tt class="docutils literal"><span class="pre">a</span></tt> and <tt class="docutils literal"><span class="pre">b</span></tt>. 
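</p> <p>As an aside (not part of the proof itself), the coefficients whose existence Bézout's identity guarantees can be computed constructively with the extended Euclidean algorithm. A minimal sketch, with an illustrative function name of my choosing:</p>

```python
def extended_gcd(a, b):
    """Return (g, x, y) such that a*x + b*y == g == gcd(a, b)."""
    x0, x1, y0, y1 = 1, 0, 0, 1
    while b:
        # Invariant: a == a_orig*x0 + b_orig*y0 and b == a_orig*x1 + b_orig*y1
        q, a, b = a // b, b, a % b
        x0, x1 = x1, x0 - q * x1
        y0, y1 = y1, y0 - q * y1
    return a, x0, y0

g, x, y = extended_gcd(240, 46)
assert g == 2 and 240 * x + 46 * y == 2
```

<p>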
If <tt class="docutils literal"><span class="pre">c|a</span></tt> and <tt class="docutils literal"><span class="pre">c|b</span></tt>, say <tt class="docutils literal"><span class="pre">a=cq</span></tt> and <tt class="docutils literal"><span class="pre">b=cp</span></tt>, then:</p> <p><img src="https://eli.thegreenplace.net/images/math/29c402dd016d08d3af7b751065de7690102403dc.svg" /></p> <p>So <tt class="docutils literal"><span class="pre">c|d</span></tt>, which completes our proof that <tt class="docutils literal"><span class="pre">d=(a,b)</span></tt>. <em>Q.E.D.</em></p> <p>Regarding the corollary, it stems trivially from the definition of <tt class="docutils literal"><span class="pre">I</span></tt> and the proof above.</p> The well-ordering principle2009-07-09T09:25:53-07:002009-07-09T09:25:53-07:00Eli Benderskytag:eli.thegreenplace.net,2009-07-09:/2009/07/09/the-well-ordering-principle <p>The <a class="reference external" href="http://en.wikipedia.org/wiki/Well-ordering_principle">well-ordering principle</a> states:</p> <p><strong>The well-ordering principle:</strong> Any nonempty set of nonnegative integers has a smallest element.</p> <p><em>DUH, you don't say!</em> - seems obvious, doesn't it? This principle is, nevertheless, a very important and fundamental tool for proving other basic principles of number theory.</p> <p>Consider, for instance, the <a class="reference external" href="http://en.wikipedia.org/wiki/Division_algorithm">Division Algorithm</a>:</p> <p><strong>The …</strong></p> <p>The <a class="reference external" href="http://en.wikipedia.org/wiki/Well-ordering_principle">well-ordering principle</a> states:</p> <p><strong>The well-ordering principle:</strong> Any nonempty set of nonnegative integers has a smallest element.</p> <p><em>DUH, you don't say!</em> - seems obvious, doesn't it? 
This principle is, nevertheless, a very important and fundamental tool for proving other basic principles of number theory.</p> <p>Consider, for instance, the <a class="reference external" href="http://en.wikipedia.org/wiki/Division_algorithm">Division Algorithm</a>:</p> <p><strong>The Division Algorithm:</strong> If <tt class="docutils literal"><span class="pre">m</span></tt> and <tt class="docutils literal"><span class="pre">n</span></tt> are integers with <tt class="docutils literal"><span class="pre">n</span> <span class="pre">&gt;</span> <span class="pre">0</span></tt>, then there exist integers <tt class="docutils literal"><span class="pre">q</span></tt> and <tt class="docutils literal"><span class="pre">r</span></tt>, with <tt class="docutils literal"><span class="pre">0</span> <span class="pre">&lt;=</span> <span class="pre">r</span> <span class="pre">&lt;</span> <span class="pre">n</span></tt>, such that <img src="https://eli.thegreenplace.net/images/math/c2245e91a0ec659885e704a303ba2eff8e6045a0.gif" />.</p> <p>Again, this is so basic that one may doubt whether it should even be proved. But the well-ordering principle allows us, in fact, to prove the division algorithm in a rigorous manner:</p> <p>Let <img src="https://eli.thegreenplace.net/images/math/29968af66b01fdd8c5de887a44db0b4e60f8fce7.gif" />. It is obvious that <tt class="docutils literal"><span class="pre">W</span></tt> contains nonnegative integers. Let <img src="https://eli.thegreenplace.net/images/math/d398cf3839320215d8508f5492d19d4b5dbd83b8.gif" />. <em>By the well-ordering principle</em>, <tt class="docutils literal"><span class="pre">V</span></tt> has a smallest element, which we'll call <tt class="docutils literal"><span class="pre">r</span></tt>. 
<img src="https://eli.thegreenplace.net/images/math/7c93b5e9572f2a5efb01641ce59940f15f54e1e4.gif" />, so <img src="https://eli.thegreenplace.net/images/math/2b2c0fb96d6759348beadd2c624aa1d92b13a264.gif" /> for some <tt class="docutils literal"><span class="pre">q</span></tt> and <tt class="docutils literal"><span class="pre">r</span> <span class="pre">&gt;=</span> <span class="pre">0</span></tt> (by the definition of sets <tt class="docutils literal"><span class="pre">W</span></tt> and <tt class="docutils literal"><span class="pre">V</span></tt>, correspondingly).</p> <p>Now, what's left to prove is that <tt class="docutils literal"><span class="pre">r</span> <span class="pre">&lt;</span> <span class="pre">n</span></tt>. Let's assume the opposite, namely that <img src="https://eli.thegreenplace.net/images/math/b1a067328be8dd74e839ad82f6ebd20089567001.gif" />. Rearranging: <img src="https://eli.thegreenplace.net/images/math/6d4c9ea92f063dd1610459a98e79b2332c2a17e4.gif" />. By the definition of <tt class="docutils literal"><span class="pre">V</span></tt>, <img src="https://eli.thegreenplace.net/images/math/48f703813d8cbec496b2dd24d6a87684843c6237.gif" /> (since it has the form <img src="https://eli.thegreenplace.net/images/math/44ed40b22260264b9048f920e66afd51a982e97f.gif" /> for some integer <tt class="docutils literal"><span class="pre">t</span></tt> and is nonnegative). But recall that we called <tt class="docutils literal"><span class="pre">r</span></tt> the smallest element of <tt class="docutils literal"><span class="pre">V</span></tt>, and <img src="https://eli.thegreenplace.net/images/math/4e8431424ab97947ea763a3df6fbbf1eb74b4d53.gif" />, so we have a contradiction.</p> <p>Therefore, we see that <tt class="docutils literal"><span class="pre">r</span> <span class="pre">&lt;</span> <span class="pre">n</span></tt>. This completes the proof. 
<em>Q.E.D.</em></p> Project Euler problem 66 and continued fractions2009-06-19T14:49:07-07:002009-06-19T14:49:07-07:00Eli Benderskytag:eli.thegreenplace.net,2009-06-19:/2009/06/19/project-euler-problem-66-and-continued-fractions <p><a class="reference external" href="http://projecteuler.net/index.php?section=problems&amp;id=66">Problem 66</a> is one of those problems that make Project Euler lots of fun. It doesn't have a brute-force solution, and to solve it one actually has to implement a non-trivial mathematical algorithm and get exposed to several interesting techniques.</p> <p>I will not post the solution or the full code …</p> <p><a class="reference external" href="http://projecteuler.net/index.php?section=problems&amp;id=66">Problem 66</a> is one of those problems that make Project Euler lots of fun. It doesn't have a brute-force solution, and to solve it one actually has to implement a non-trivial mathematical algorithm and get exposed to several interesting techniques.</p> <p>I will not post the solution or the full code for the problem here, just a couple of hints.</p> <p>After a very short bout of Googling, you'll discover that the Diophantine equation:</p> <p><img src="https://eli.thegreenplace.net/images/math/cdc11e760e8b319f652e19c6daf547cbe9d0b0f9.gif" /></p> <p>Is quite famous and is called <a class="reference external" href="http://en.wikipedia.org/wiki/Pell%27s_equation">Pell's equation</a>. From here, further web searches and Wikipedia-reading will bring you to at least two methods for finding the <em>fundamental solution</em>, which is the pair of <tt class="docutils literal"><span class="pre">x</span></tt> and <tt class="docutils literal"><span class="pre">y</span></tt> with minimal <tt class="docutils literal"><span class="pre">x</span></tt> solving it.</p> <p>One of the methods involves computing the continued-fraction representation of the square root of <tt class="docutils literal"><span class="pre">D</span></tt>. 
<a class="reference external" href="http://www.mcs.surrey.ac.uk/Personal/R.Knott/Fibonacci/cfINTRO.html">This page</a> is a must read on this topic, and will help you with other Euler problems as well.</p> <p>I want to post here a code snippet that implements the continued-fraction computation described in that link. Its steps follow the <em>Algebraic algorithm</em> given there:</p> <div class="highlight"><pre><span style="color: #00007f; font-weight: bold">def</span> <span style="color: #00007f">CF_of_sqrt</span>(n):
    <span style="color: #7f007f">&quot;&quot;&quot; Compute the continued fraction representation of the</span>
<span style="color: #7f007f">        square root of N.</span>
<span style="color: #7f007f">        The first element in the returned array is the whole</span>
<span style="color: #7f007f">        part of the fraction. The others are the denominators</span>
<span style="color: #7f007f">        up to (and not including) the point where it starts</span>
<span style="color: #7f007f">        repeating.</span>
<span style="color: #7f007f">        Uses the algorithm explained here:</span>
<span style="color: #7f007f">        http://www.mcs.surrey.ac.uk/Personal/R.Knott/Fibonacci/cfINTRO.html</span>
<span style="color: #7f007f">        In the section named: &quot;Methods of finding continued</span>
<span style="color: #7f007f">        fractions for square roots&quot;</span>
<span style="color: #7f007f">    &quot;&quot;&quot;</span>
    <span style="color: #00007f; font-weight: bold">if</span> is_square(n):
        <span style="color: #00007f; font-weight: bold">return</span> [<span style="color: #00007f">int</span>(math.sqrt(n))]

    ans = []
    step1_num = <span style="color: #007f7f">0</span>
    step1_denom = <span style="color: #007f7f">1</span>

    <span style="color: #00007f; font-weight: bold">while</span> <span style="color: #00007f">True</span>:
        nextn = <span style="color: #00007f">int</span>((math.floor(math.sqrt(n)) + step1_num) / step1_denom)
        ans.append(<span style="color: #00007f">int</span>(nextn))

        step2_num = step1_denom
        step2_denom = step1_num - step1_denom * nextn

        step3_denom = (n - step2_denom ** <span style="color: #007f7f">2</span>) / step2_num
        step3_num = -step2_denom

        <span style="color: #00007f; font-weight: bold">if</span> step3_denom == <span style="color: #007f7f">1</span>:
            ans.append(ans[<span style="color: #007f7f">0</span>] * <span style="color: #007f7f">2</span>)
            <span style="color: #00007f; font-weight: bold">break</span>

        step1_num, step1_denom = step3_num, step3_denom

    <span style="color: #00007f; font-weight: bold">return</span> ans
</pre></div> <p>As I said, this still isn't enough to solve the problem, but with this code in hand, the solution isn't too far. Read some more about Pell's equation and you'll discover how to use this code to reach a solution.</p> <p>It took my program ~30 milliseconds to find an answer to the problem, by the way. It took less than a second to solve a 10-times larger problem (for D &lt;= 10000), so I believe it to be a pretty good implementation.</p> Efficient modular exponentiation algorithms2009-03-28T09:51:29-07:002009-03-28T09:51:29-07:00Eli Benderskytag:eli.thegreenplace.net,2009-03-28:/2009/03/28/efficient-modular-exponentiation-algorithms <p><a class="reference external" href="http://eli.thegreenplace.net/2009/03/21/efficient-integer-exponentiation-algorithms/">Earlier this week</a> I've discussed efficient algorithms for exponentiation.</p> <p>However, for real-life needs of number theoretic computations, just raising numbers to large exponents isn't very useful, because extremely huge numbers start appearing very quickly <a class="footnote-reference" href="#id8" id="id1"></a>, and these don't have much use. 
What's much more useful is <a class="reference external" href="http://en.wikipedia.org/wiki/Modular_exponentiation">modular exponentiation</a>, raising integers …</p> <p><a class="reference external" href="http://eli.thegreenplace.net/2009/03/21/efficient-integer-exponentiation-algorithms/">Earlier this week</a> I've discussed efficient algorithms for exponentiation.</p> <p>However, for real-life needs of number theoretic computations, just raising numbers to large exponents isn't very useful, because extremely huge numbers start appearing very quickly <a class="footnote-reference" href="#id8" id="id1"></a>, and these don't have much use. What's much more useful is <a class="reference external" href="http://en.wikipedia.org/wiki/Modular_exponentiation">modular exponentiation</a>, raising integers to high powers <img src="https://eli.thegreenplace.net/images/math/5ed051f99a1984c11a5b2d4ea770f3dc527912d8.gif" /> <a class="footnote-reference" href="#id9" id="id2"></a></p> <p>Luckily, we can reuse the efficient algorithms developed in the previous article, with very few modifications to perform modular exponentiation as well. This is possible because of some convenient properties of modular arithmetic.</p> <div class="section" id="modular-multiplication"> <h3>Modular multiplication</h3> <p>Given two numbers, <tt class="docutils literal"><span class="pre">a</span></tt> and <tt class="docutils literal"><span class="pre">b</span></tt>, their product modulo <tt class="docutils literal"><span class="pre">n</span></tt> is <img src="https://eli.thegreenplace.net/images/math/60a31df99204b91b44d3bc8b6c0b462d5302182c.gif" />. Consider the number <tt class="docutils literal"><span class="pre">x</span> <span class="pre">&lt;</span> <span class="pre">n</span></tt>, such that <img src="https://eli.thegreenplace.net/images/math/a24e63eaa528ead7411690545b6f1525adf6fd11.gif" />. 
Such a number always exists, and we usually call it the <em>remainder</em> of dividing <tt class="docutils literal"><span class="pre">a</span></tt> by <tt class="docutils literal"><span class="pre">n</span></tt>. Similarly, there is a <tt class="docutils literal"><span class="pre">y</span> <span class="pre">&lt;</span> <span class="pre">n</span></tt>, such that <img src="https://eli.thegreenplace.net/images/math/210006d46db94551446e69e871dd5aa85a917d18.gif" />. It follows from basic rules of modular arithmetic that <img src="https://eli.thegreenplace.net/images/math/13b1a1a7a7cab72a42640bc57c99dff9f9fc78dc.gif" /> <a class="footnote-reference" href="#id10" id="id3"></a></p> <p>Therefore, if we want to know the product of <tt class="docutils literal"><span class="pre">a</span></tt> and <tt class="docutils literal"><span class="pre">b</span></tt> modulo <tt class="docutils literal"><span class="pre">n</span></tt>, we just have to keep their remainders when divided by <tt class="docutils literal"><span class="pre">n</span></tt>. Note: <tt class="docutils literal"><span class="pre">a</span></tt> and <tt class="docutils literal"><span class="pre">b</span></tt> may be arbitrarily large, but <tt class="docutils literal"><span class="pre">x</span></tt> and <tt class="docutils literal"><span class="pre">y</span></tt> are always smaller than <tt class="docutils literal"><span class="pre">n</span></tt>.</p> </div> <div class="section" id="a-naive-algorithm"> <h3>A naive algorithm</h3> <p>What is the most naive way you can think of for computing <img src="https://eli.thegreenplace.net/images/math/f342f6b456e722a0a7bf3fc6b194bcf73aab90e5.gif" />? Raise <tt class="docutils literal"><span class="pre">a</span></tt> to the power <tt class="docutils literal"><span class="pre">b</span></tt>, and then reduce modulo <tt class="docutils literal"><span class="pre">n</span></tt>. 
Right?</p> <p>Indeed, this is a very unsophisticated and slow method, because raising <tt class="docutils literal"><span class="pre">a</span></tt> to the power <tt class="docutils literal"><span class="pre">b</span></tt> can result in a really huge number that takes a long time to compute.</p> <p>For any useful number, this algorithm is so slow that I'm not even going to run it in the tests.</p> </div> <div class="section" id="using-the-properties-of-modular-multiplication"> <h3>Using the properties of modular multiplication</h3> <p>As we've learned above, modular multiplication allows us to just keep the intermediate result <img src="https://eli.thegreenplace.net/images/math/5ed051f99a1984c11a5b2d4ea770f3dc527912d8.gif" /> at each step. Here's the implementation of a simple repeated multiplication algorithm for computing modular exponents this way:</p> <div class="highlight"><pre><span style="color: #00007f; font-weight: bold">def</span> <span style="color: #00007f">modexp_mul</span>(a, b, n):
    r = <span style="color: #007f7f">1</span>
    <span style="color: #00007f; font-weight: bold">for</span> i <span style="color: #0000aa">in</span> <span style="color: #00007f">xrange</span>(b):
        r = r * a % n
    <span style="color: #00007f; font-weight: bold">return</span> r
</pre></div> <p>It's much better than the naive algorithm, but as we saw in the previous article it's quite slow, requiring <tt class="docutils literal"><span class="pre">b</span></tt> multiplications (and reductions modulo <tt class="docutils literal"><span class="pre">n</span></tt>).</p> <p>We can apply the same modular reduction rule to the more efficient exponentiation algorithms we've studied <a class="reference external" href="http://eli.thegreenplace.net/2009/03/21/efficient-integer-exponentiation-algorithms/">before</a>.</p> </div> <div class="section" id="modular-exponentiation-by-squaring"> <h3>Modular exponentiation by squaring</h3> <p>Here's the right-to-left method with modular reductions at each step:</p> 
<div class="highlight"><pre><span style="color: #00007f; font-weight: bold">def</span> <span style="color: #00007f">modexp_rl</span>(a, b, n):
    r = <span style="color: #007f7f">1</span>
    <span style="color: #00007f; font-weight: bold">while</span> <span style="color: #007f7f">1</span>:
        <span style="color: #00007f; font-weight: bold">if</span> b % <span style="color: #007f7f">2</span> == <span style="color: #007f7f">1</span>:
            r = r * a % n
        b /= <span style="color: #007f7f">2</span>
        <span style="color: #00007f; font-weight: bold">if</span> b == <span style="color: #007f7f">0</span>:
            <span style="color: #00007f; font-weight: bold">break</span>
        a = a * a % n
    <span style="color: #00007f; font-weight: bold">return</span> r
</pre></div> <p>We use exactly the same algorithm, but reduce every multiplication <img src="https://eli.thegreenplace.net/images/math/5ed051f99a1984c11a5b2d4ea770f3dc527912d8.gif" />. So the numbers we deal with here are never very large.</p> <p>Similarly, here's the left-to-right method:</p> <div class="highlight"><pre><span style="color: #00007f; font-weight: bold">def</span> <span style="color: #00007f">modexp_lr</span>(a, b, n):
    r = <span style="color: #007f7f">1</span>
    <span style="color: #00007f; font-weight: bold">for</span> bit <span style="color: #0000aa">in</span> reversed(_bits_of_n(b)):
        r = r * r % n
        <span style="color: #00007f; font-weight: bold">if</span> bit == <span style="color: #007f7f">1</span>:
            r = r * a % n
    <span style="color: #00007f; font-weight: bold">return</span> r
</pre></div> <p>With <tt class="docutils literal"><span class="pre">_bits_of_n</span></tt> being, as before:</p> <div class="highlight"><pre><span style="color: #00007f; font-weight: bold">def</span> <span style="color: #00007f">_bits_of_n</span>(n):
    <span style="color: #7f007f">&quot;&quot;&quot; Return the list of the bits in the binary</span>
<span style="color: #7f007f">        representation of n, from LSB to MSB</span>
<span style="color: #7f007f">    
&quot;&quot;&quot;</span>
    bits = []
    <span style="color: #00007f; font-weight: bold">while</span> n:
        bits.append(n % <span style="color: #007f7f">2</span>)
        n /= <span style="color: #007f7f">2</span>
    <span style="color: #00007f; font-weight: bold">return</span> bits
</pre></div> </div> <div class="section" id="relative-performance"> <h3>Relative performance</h3> <p>As I've noted in the <a class="reference external" href="http://eli.thegreenplace.net/2009/03/21/efficient-integer-exponentiation-algorithms/">previous article</a>, the RL method does a worse job of keeping its multiplicands low than the LR method. And indeed, for smaller <tt class="docutils literal"><span class="pre">n</span></tt>, RL is somewhat faster than LR. For larger <tt class="docutils literal"><span class="pre">n</span></tt>, RL is somewhat slower.</p> <p>What's obvious is that now the built-in <tt class="docutils literal"><span class="pre">pow</span></tt> is superior to both hand-coded methods <a class="footnote-reference" href="#id11" id="id4"></a>. My tests show it's anywhere from twice to 10 times as fast.</p> <p>Why is <tt class="docutils literal"><span class="pre">pow</span></tt> so much faster? Is it only the efficiency of C versus Python? Not really. In fact, <tt class="docutils literal"><span class="pre">pow</span></tt> uses an even more sophisticated algorithm for large exponents <a class="footnote-reference" href="#id12" id="id5"></a>. Indeed, for small exponents the runtime of <tt class="docutils literal"><span class="pre">pow</span></tt> is similar to the runtime of the implementations I presented above.</p> </div> <div class="section" id="the-k-ary-lr-method"> <h3>The k-ary LR method</h3> <p>It turns out that the LR method of repeated squaring can be generalized. 
Instead of breaking the exponent into bits of its base-2 representation, we can break it into larger pieces, and save some computations this way.</p> <p>I'll present the k-ary LR method that breaks the exponent into its &quot;digits&quot; in base <img src="https://eli.thegreenplace.net/images/math/1e028fb602c123d0fe4958d8a84229d6803b289e.gif" /> for some integer <tt class="docutils literal"><span class="pre">k</span></tt>. The exponent can be written as:</p> <p><img src="https://eli.thegreenplace.net/images/math/2b3373f91d2f784798a343b046defcaf0bd22786.gif" /></p> <p>Where <img src="https://eli.thegreenplace.net/images/math/8fab90b047823b97522115f88da94c5d6797de3f.gif" /> are the digits of <tt class="docutils literal"><span class="pre">b</span></tt> in base <tt class="docutils literal"><span class="pre">m</span></tt>. <img src="https://eli.thegreenplace.net/images/math/2d4469bf98c45573ce8673265c3c9bde3520e5d2.gif" /> is then:</p> <p><img src="https://eli.thegreenplace.net/images/math/0550b0df0f672136e39d9d050781d467eff82bd8.gif" /></p> <p>We compute this iteratively as follows <a class="footnote-reference" href="#id13" id="id6"></a>:</p> <p>Raise <img src="https://eli.thegreenplace.net/images/math/fa23ac9ecbf9f0cead492e9227e26757a967c284.gif" /> to the <tt class="docutils literal"><span class="pre">m</span></tt>-th power and multiply by <img src="https://eli.thegreenplace.net/images/math/4df976d630809ae3d013ecb8b764cee38121efc2.gif" />. We get <img src="https://eli.thegreenplace.net/images/math/52b89408d7c7c3d7274df85588a4ce6ee8b1a871.gif" />. Next, raise <img src="https://eli.thegreenplace.net/images/math/83b3fdda5b127e3a4f9bcb7b45d2fa7ef3659493.gif" /> to the <tt class="docutils literal"><span class="pre">m</span></tt>-th power and multiply by <img src="https://eli.thegreenplace.net/images/math/b6ed9962207f950a607b7591bd50e6375efc014b.gif" />, obtaining <img src="https://eli.thegreenplace.net/images/math/ec6ac382c692b3afe0496e8ea45a1a780528db13.gif" />. 
If we continue with this, we'll eventually get <img src="https://eli.thegreenplace.net/images/math/2d4469bf98c45573ce8673265c3c9bde3520e5d2.gif" />.</p> <p>This translates into the following code:</p> <div class="highlight"><pre><span style="color: #00007f; font-weight: bold">def</span> <span style="color: #00007f">modexp_lr_k_ary</span>(a, b, n, k=<span style="color: #007f7f">5</span>):
    <span style="color: #7f007f">&quot;&quot;&quot; Compute a ** b (mod n)</span>
<span style="color: #7f007f">        K-ary LR method, with a customizable &#39;k&#39;.</span>
<span style="color: #7f007f">    &quot;&quot;&quot;</span>
    base = <span style="color: #007f7f">2</span> &lt;&lt; (k - <span style="color: #007f7f">1</span>)

    <span style="color: #007f00"># Precompute the table of exponents</span>
    table = [<span style="color: #007f7f">1</span>] * base
    <span style="color: #00007f; font-weight: bold">for</span> i <span style="color: #0000aa">in</span> <span style="color: #00007f">xrange</span>(<span style="color: #007f7f">1</span>, base):
        table[i] = table[i - <span style="color: #007f7f">1</span>] * a % n

    <span style="color: #007f00"># Just like the binary LR method, just with a</span>
    <span style="color: #007f00"># different base</span>
    <span style="color: #007f00">#</span>
    r = <span style="color: #007f7f">1</span>
    <span style="color: #00007f; font-weight: bold">for</span> digit <span style="color: #0000aa">in</span> reversed(_digits_of_n(b, base)):
        <span style="color: #00007f; font-weight: bold">for</span> i <span style="color: #0000aa">in</span> <span style="color: #00007f">xrange</span>(k):
            r = r * r % n
        <span style="color: #00007f; font-weight: bold">if</span> digit:
            r = r * table[digit] % n

    <span style="color: #00007f; font-weight: bold">return</span> r
</pre></div> <p>Note that we save some time by pre-computing the powers of <tt class="docutils literal"><span class="pre">a</span></tt> for exponents that can be digits in base <tt class="docutils literal"><span class="pre">m</span></tt>. 
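</p> <p>As a quick sanity check, here is a self-contained Python 3 re-sketch of the same k-ary idea (the original above targets Python 2, and the helper name below is mine, not the author's), compared against the built-in <tt class="docutils literal"><span class="pre">pow</span></tt>:</p>

```python
def modexp_k_ary(a, b, n, k=5):
    """Compute a ** b (mod n) with the k-ary LR method (Python 3 sketch)."""
    base = 2 << (k - 1)          # base == 2**k
    # Precompute a**0 .. a**(base-1), all reduced mod n.
    table = [1] * base
    for i in range(1, base):
        table[i] = table[i - 1] * a % n
    # Digits of b in base 2**k, LSB first.
    digits = []
    while b:
        digits.append(b % base)
        b //= base
    r = 1
    for digit in reversed(digits):   # process MSB first
        for _ in range(k):
            r = r * r % n            # shift the accumulator left by one digit
        if digit:
            r = r * table[digit] % n
    return r

# Cross-check against the built-in pow on a few inputs.
for a, b, n in [(3, 1000, 7919), (2, 10**6 + 3, 10**9 + 7), (12345, 6789, 97)]:
    assert modexp_k_ary(a, b, n) == pow(a, b, n)
```

<p>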
Also, <tt class="docutils literal"><span class="pre">_digits_of_n</span></tt> is the following generalization of <tt class="docutils literal"><span class="pre">_bits_of_n</span></tt>:</p> <div class="highlight"><pre><span style="color: #00007f; font-weight: bold">def</span> <span style="color: #00007f">_digits_of_n</span>(n, b):
    <span style="color: #7f007f">&quot;&quot;&quot; Return the list of the digits in the base &#39;b&#39;</span>
<span style="color: #7f007f">        representation of n, from LSB to MSB</span>
<span style="color: #7f007f">    &quot;&quot;&quot;</span>
    digits = []
    <span style="color: #00007f; font-weight: bold">while</span> n:
        digits.append(n % b)
        n /= b
    <span style="color: #00007f; font-weight: bold">return</span> digits
</pre></div> </div> <div class="section" id="performance-of-the-k-ary-method"> <h3>Performance of the k-ary method</h3> <p>In my tests, the k-ary LR method with <tt class="docutils literal"><span class="pre">k</span> <span class="pre">=</span> <span class="pre">5</span></tt> is about 25% faster than the binary LR method, and is within 20% of the built-in <tt class="docutils literal"><span class="pre">pow</span></tt> function.</p> <p>Experimenting with the value of <tt class="docutils literal"><span class="pre">k</span></tt> affects these results, but 5 seems to be a good value that produces the best performance in most cases. This is probably why it's also used as the value of <tt class="docutils literal"><span class="pre">k</span></tt> in the implementation of <tt class="docutils literal"><span class="pre">pow</span></tt>.</p> </div> <div class="section" id="python-s-built-in-pow"> <h3>Python's built-in <tt class="docutils literal"><span class="pre">pow</span></tt></h3> <p>I've mentioned Python's <tt class="docutils literal"><span class="pre">pow</span></tt> function several times in this article. The Python version I'm talking about is 2.5, though I doubt this functionality has changed in 2.6 or 3.0. 
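</p> <p>To make the comparison concrete, the three-argument form of <tt class="docutils literal"><span class="pre">pow</span></tt> mentioned in the footnotes performs modular exponentiation directly, without ever materializing the full power:</p>

```python
a, b, n = 7, 560, 561
# pow with a modulus never builds the huge intermediate a ** b ...
fast = pow(a, b, n)
# ... but agrees with the naive compute-then-reduce approach.
assert fast == (a ** b) % n
```

<p>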
The <tt class="docutils literal"><span class="pre">pow</span></tt> I'm interested in is implemented in the <tt class="docutils literal"><span class="pre">long_pow</span></tt> function in <tt class="docutils literal"><span class="pre">objects/longobject.c</span></tt> in the Python source code distribution. As mentioned in <a class="footnote-reference" href="#id12" id="id7"></a>, it uses the binary LR method for small exponents, and the k-ary LR method for large exponents.</p> <p>These implementations closely follow algorithms 14.79 and 14.82 in the excellent <em>Handbook of Applied Cryptography</em>, which is freely <a class="reference external" href="http://www.cacr.math.uwaterloo.ca/hac/">available online</a>.</p> </div> <div class="section" id="summary"> <h3>Summary</h3> <p>As we've seen, exponentiation and modular exponentiation are among those applications in which an efficient algorithm is required for feasibility. Using the trivial/naive algorithms is possible only for small cases which aren't very interesting. 
To process realistically large numbers (such as the ones required for cryptographic algorithms), one needs powerful methods in one's toolbox.</p> <div align="center" class="align-center"><img class="align-center" src="https://eli.thegreenplace.net/images/hline.jpg" style="width: 320px; height: 5px;" /></div> <table class="docutils footnote" frame="void" id="id8" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id1"></a></td><td>For instance, <img src="https://eli.thegreenplace.net/images/math/ec4a5eb840bac711d2930bed8d05d6f60f08050b.gif" /> is a 4772-digit number.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id9" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id2"></a></td><td>Modular exponentiation is essential for the <a class="reference external" href="http://en.wikipedia.org/wiki/RSA">RSA algorithm</a>, for example.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id10" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id3"></a></td><td>To be a bit more rigorous, we start with <img src="https://eli.thegreenplace.net/images/math/a24e63eaa528ead7411690545b6f1525adf6fd11.gif" />. This means that <img src="https://eli.thegreenplace.net/images/math/b4554f9f77d7fc2c17e651e01166c8f1735489e3.gif" />, so also <img src="https://eli.thegreenplace.net/images/math/2486c350f0b60040b3f162ebeead1ff9ebe5f7d0.gif" />. Similarly <img src="https://eli.thegreenplace.net/images/math/4efa4bb9a60158d9497274087d139820d3c827d6.gif" />, so also <img src="https://eli.thegreenplace.net/images/math/84bf619e6f68d1d3bc9dc65522dca7e296a46dc0.gif" />. 
Adding these two we get <img src="https://eli.thegreenplace.net/images/math/a4d655317021fe3417f78df6ac22ee553cc833e7.gif" />, which means that <img src="https://eli.thegreenplace.net/images/math/13b1a1a7a7cab72a42640bc57c99dff9f9fc78dc.gif" />.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id11" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id4"></a></td><td>Using the 3-argument form of <tt class="docutils literal"><span class="pre">pow</span></tt>, you can perform modular exponentiation.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id12" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id5"></a></td><td><tt class="docutils literal"><span class="pre">FIVEARY_CUTOFF</span></tt> in the code of <tt class="docutils literal"><span class="pre">pow</span></tt> is set to 8. This means that for exponents with more than 8 digits, a special 5-ary algorithm is used. 
For smaller exponents, the regular LR binary method is used - just like the one I presented, just coded in C.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id13" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id6"></a></td><td>Note that for <tt class="docutils literal"><span class="pre">m</span> <span class="pre">=</span> <span class="pre">2</span></tt> this is the familiar binary LR method.</td></tr> </tbody> </table> </div> Efficient integer exponentiation algorithms2009-03-21T19:10:57-07:002009-03-21T19:10:57-07:00Eli Benderskytag:eli.thegreenplace.net,2009-03-21:/2009/03/21/efficient-integer-exponentiation-algorithms <p>Did you ever think about the most efficient method to perform integer exponentiation, that is, raising an integer <tt class="docutils literal"><span class="pre">a</span></tt> to an integer power <tt class="docutils literal"><span class="pre">b</span></tt>, when either <tt class="docutils literal"><span class="pre">a</span></tt> or <tt class="docutils literal"><span class="pre">b</span></tt>, or both, are rather large?</p> <div class="section" id="repeated-multiplication"> <h3>Repeated multiplication</h3> <p>The naive method is, of course, repeated multiplications. 
<img src="https://eli.thegreenplace.net/images/math/fde22a2136b496ef6f8dca2c4278792da0e77678.gif" /> is <tt class="docutils literal"><span class="pre">a</span></tt> multiplied by itself <tt class="docutils literal"><span class="pre">b …</span></tt></p></div> <p>Did you ever think about the most efficient method to perform integer exponentiation, that is, raising an integer <tt class="docutils literal"><span class="pre">a</span></tt> to an integer power <tt class="docutils literal"><span class="pre">b</span></tt>, when either <tt class="docutils literal"><span class="pre">a</span></tt> or <tt class="docutils literal"><span class="pre">b</span></tt>, or both, are rather large?</p> <div class="section" id="repeated-multiplication"> <h3>Repeated multiplication</h3> <p>The naive method is, of course, repeated multiplications. <img src="https://eli.thegreenplace.net/images/math/fde22a2136b496ef6f8dca2c4278792da0e77678.gif" /> is <tt class="docutils literal"><span class="pre">a</span></tt> multiplied by itself <tt class="docutils literal"><span class="pre">b</span></tt> times. Here's how it's coded in my pseudo-code of choice, Python:</p> <div class="highlight"><pre><span style="color: #00007f; font-weight: bold">def</span> <span style="color: #00007f">expt_mul</span>(a, b): r = <span style="color: #007f7f">1</span> <span style="color: #00007f; font-weight: bold">for</span> i <span style="color: #0000aa">in</span> <span style="color: #00007f">xrange</span>(b): r *= a <span style="color: #00007f; font-weight: bold">return</span> r </pre></div> <p>Is this efficient? Not really, as we require <tt class="docutils literal"><span class="pre">b</span></tt> multiplications, and as I said earlier <tt class="docutils literal"><span class="pre">b</span></tt> can be very large (think number theory algorithms). 
In fact, there's a <em>much</em> more efficient method.</p> </div> <div class="section" id="exponentiation-by-squaring"> <h3>Exponentiation by squaring</h3> <p>The efficient exponentiation algorithm is based on the simple observation that for an even <tt class="docutils literal"><span class="pre">b</span></tt>, <img src="https://eli.thegreenplace.net/images/math/4d308eabc552e0744ecb53ebb55aeb7b5f6705da.gif" />. This may not look very brilliant, but now consider the following recursive definition:</p> <p><img src="https://eli.thegreenplace.net/images/math/ql_88b3da5b51bbcac021cceb33f708a130_l3.png" /></p> <p>The case of odd <tt class="docutils literal"><span class="pre">b</span></tt> is trivial, as it's obvious that <img src="https://eli.thegreenplace.net/images/math/6c6cc601fdd47eecef30907127482f149d2ed366.gif" />. So now we can compute <img src="https://eli.thegreenplace.net/images/math/fde22a2136b496ef6f8dca2c4278792da0e77678.gif" /> by doing only <tt class="docutils literal"><span class="pre">log(b)</span></tt> squarings and no more than <tt class="docutils literal"><span class="pre">log(b)</span></tt> multiplications, instead of <tt class="docutils literal"><span class="pre">b</span></tt> multiplications - and this is a vast improvement for a large <tt class="docutils literal"><span class="pre">b</span></tt>.</p> <p>This algorithm can be coded in a straightforward way:</p> <div class="highlight"><pre><span style="color: #00007f; font-weight: bold">def</span> <span style="color: #00007f">expt_rec</span>(a, b): <span style="color: #00007f; font-weight: bold">if</span> b == <span style="color: #007f7f">0</span>: <span style="color: #00007f; font-weight: bold">return</span> <span style="color: #007f7f">1</span> <span style="color: #00007f; font-weight: bold">elif</span> b % <span style="color: #007f7f">2</span> == <span style="color: #007f7f">1</span>: <span style="color: #00007f; font-weight: bold">return</span> a * expt_rec(a, b - <span style="color: 
#007f7f">1</span>) <span style="color: #00007f; font-weight: bold">else</span>: p = expt_rec(a, b / <span style="color: #007f7f">2</span>) <span style="color: #00007f; font-weight: bold">return</span> p * p </pre></div> <p>Indeed, this algorithm is about 10 times faster than the naive one for exponents on the order of a few thousand. When the exponent is about 100K, it is more than 100 times faster, and the difference keeps growing for larger exponents.</p> </div> <div class="section" id="an-iterative-implementation"> <h3>An iterative implementation</h3> <p>It will be useful to develop an iterative implementation for the fast exponentiation algorithm. For this purpose, however, we need to dive into some mathematics.</p> <p>We can represent the exponent <tt class="docutils literal"><span class="pre">b</span></tt> as:</p> <p><img src="https://eli.thegreenplace.net/images/math/ecbde38c735d854eb05d28e2b9b7e4b034c8cb0f.gif" /></p> <p>Where <img src="https://eli.thegreenplace.net/images/math/052ed07ef4a94acc0a6e5e21d68a64e602538236.gif" /> are the bits (0 or 1) of <tt class="docutils literal"><span class="pre">b</span></tt> in base 2. 
<img src="https://eli.thegreenplace.net/images/math/fde22a2136b496ef6f8dca2c4278792da0e77678.gif" /> is then:</p> <p><img src="https://eli.thegreenplace.net/images/math/d9cf6b4f10b1f4b4dce1f62d3411a4bdcdfc6fdb.gif" /></p> <p>Or, in other words:</p> <p><img src="https://eli.thegreenplace.net/images/math/87d95a4f7ba3d34779c01387e7c7b52985e48e36.gif" /> for <tt class="docutils literal"><span class="pre">k</span></tt> such that <img src="https://eli.thegreenplace.net/images/math/de3451bd16070e6cbfe61f85a2f5a48798db4399.gif" /></p> <p><img src="https://eli.thegreenplace.net/images/math/b5558f10d4f57a6c991f5bf4702e2a807b11eb9d.gif" /> can be computed by repetitive squaring, and moreover, we can reuse the result from a lower <tt class="docutils literal"><span class="pre">k</span></tt> to compute a higher <tt class="docutils literal"><span class="pre">k</span></tt>. This directly translates into the following iterative algorithm:</p> <div class="highlight"><pre><span style="color: #00007f; font-weight: bold">def</span> <span style="color: #00007f">expt_bin_rl</span>(a, b): r = <span style="color: #007f7f">1</span> <span style="color: #00007f; font-weight: bold">while</span> <span style="color: #007f7f">1</span>: <span style="color: #00007f; font-weight: bold">if</span> b % <span style="color: #007f7f">2</span> == <span style="color: #007f7f">1</span>: r *= a b /= <span style="color: #007f7f">2</span> <span style="color: #00007f; font-weight: bold">if</span> b == <span style="color: #007f7f">0</span>: <span style="color: #00007f; font-weight: bold">break</span> a *= a <span style="color: #00007f; font-weight: bold">return</span> r </pre></div> <p>To understand how the algorithm works, try to relate it to the formula from above. Using a standard &quot;divide by two and look at the LSB&quot; loop, the exponent <tt class="docutils literal"><span class="pre">b</span></tt> is broken into its binary representation. 
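To watch the loop at work, here is an instrumented Python 3 sketch of the same right-to-left algorithm (with <tt class="docutils literal"><span class="pre">//</span></tt> standing in for Python 2's integer <tt class="docutils literal"><span class="pre">/</span></tt>; the trace printing is my addition, purely for illustration):

```python
def expt_bin_rl_traced(a, b):
    # Right-to-left binary exponentiation, printing a, b, r each step.
    r = 1
    while True:
        if b % 2 == 1:
            r *= a       # current bit is 1: fold the current power of a in
        b //= 2
        print(f"a={a}, b={b}, r={r}")
        if b == 0:
            break
        a *= a           # a now holds the next square a^(2^(k+1))
    return r

print(expt_bin_rl_traced(3, 13))  # 3^13 = 1594323
```

Running it on <tt class="docutils literal"><span class="pre">3**13</span></tt> (binary 1101) shows the result being multiplied by a's successive squares exactly at the 1 bits.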
The lowest bits of <tt class="docutils literal"><span class="pre">b</span></tt> are considered first. <tt class="docutils literal"><span class="pre">a</span></tt> is continually squared to hold <img src="https://eli.thegreenplace.net/images/math/b5558f10d4f57a6c991f5bf4702e2a807b11eb9d.gif" />, and is multiplied into the result only when <img src="https://eli.thegreenplace.net/images/math/de3451bd16070e6cbfe61f85a2f5a48798db4399.gif" />.</p> <p>This algorithm is called <em>right-to-left binary exponentiation</em>, because the binary representation of the exponent is computed from right to left (from the LSB to the MSB) <a class="footnote-reference" href="#id4" id="id1"></a>.</p> <p>A related algorithm can be developed if we prefer to look at the binary representation of the exponent from left to right.</p> </div> <div class="section" id="left-to-right-binary-exponentiation"> <h3>Left-to-right binary exponentiation</h3> <p>Going over the bits of <tt class="docutils literal"><span class="pre">b</span></tt> from MSB to LSB, we get:</p> <div class="highlight"><pre><span style="color: #00007f; font-weight: bold">def</span> <span style="color: #00007f">expt_bin_lr</span>(a, b): r = <span style="color: #007f7f">1</span> <span style="color: #00007f; font-weight: bold">for</span> bit <span style="color: #0000aa">in</span> reversed(_bits_of_n(b)): r *= r <span style="color: #00007f; font-weight: bold">if</span> bit == <span style="color: #007f7f">1</span>: r *= a <span style="color: #00007f; font-weight: bold">return</span> r </pre></div> <p>Where <tt class="docutils literal"><span class="pre">_bits_of_n</span></tt> is a method returning the binary representation of its argument as an array of bits from LSB to MSB (which is then reversed, as you see):</p> <div class="highlight"><pre><span style="color: #00007f; font-weight: bold">def</span> <span style="color: #00007f">_bits_of_n</span>(n): <span style="color: #7f007f">&quot;&quot;&quot; Return the list of the bits in the 
binary</span> <span style="color: #7f007f"> representation of n, from LSB to MSB</span> <span style="color: #7f007f"> &quot;&quot;&quot;</span> bits = [] <span style="color: #00007f; font-weight: bold">while</span> n: bits.append(n % <span style="color: #007f7f">2</span>) n /= <span style="color: #007f7f">2</span> <span style="color: #00007f; font-weight: bold">return</span> bits </pre></div> <p>Rationale: consider how you &quot;build&quot; a number from its binary representation when seen from MSB to LSB. You begin with 1 for the MSB (which is always 1, by definition, for numbers &gt; 0). For each new bit you see you double the result, and if the bit is 1, you add 1 <a class="footnote-reference" href="#id5" id="id2"></a>.</p> <p>For example consider the binary 1101. Begin with 1 for the leftmost 1. We have another bit, so we double. That's 2. Now, the new bit is 1, so we add 1, that's 3. We have another bit, so again double, that's 6. The new bit is 0, so nothing is added. And we have one more bit, so once again double, getting 12, and finally adding 1, getting 13. Indeed, 1101 is the binary representation of 13.</p> <p>Back to the exponentiation now. As you see in the code of <tt class="docutils literal"><span class="pre">expt_bin_lr</span></tt>, the binary representation of the exponent is read from MSB to LSB. Since this is the exponent, each &quot;doubling&quot; from the rationale above is squaring, and each &quot;adding 1&quot; is multiplying by the number itself. Hence, the algorithm works.</p> </div> <div class="section" id="performance"> <h3>Performance</h3> <p>As I've mentioned, the squaring method of exponentiation is far more efficient than the naive method of repeated multiplication. In the tests I ran, the iterative left-to-right method is about the same speed as the recursive one, while the iterative right-to-left method is somewhat slower. 
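A small self-contained benchmark sketch along those lines (Python 3 ports of the <tt class="docutils literal"><span class="pre">expt_mul</span></tt> and <tt class="docutils literal"><span class="pre">expt_bin_lr</span></tt> functions shown above; absolute timings will of course vary by machine):

```python
import timeit

def _bits_of_n(n):
    # Bits of n, from LSB to MSB.
    bits = []
    while n:
        bits.append(n % 2)
        n //= 2
    return bits

def expt_mul(a, b):
    # Naive repeated multiplication.
    r = 1
    for _ in range(b):
        r *= a
    return r

def expt_bin_lr(a, b):
    # Left-to-right binary exponentiation.
    r = 1
    for bit in reversed(_bits_of_n(b)):
        r *= r
        if bit == 1:
            r *= a
    return r

# Sanity check: both agree with Python's ** operator.
assert expt_mul(3, 4000) == expt_bin_lr(3, 4000) == 3 ** 4000

naive = timeit.timeit(lambda: expt_mul(3, 4000), number=10)
fast = timeit.timeit(lambda: expt_bin_lr(3, 4000), number=10)
print(f"naive: {naive:.4f}s, binary LR: {fast:.4f}s")
```

The gap widens quickly as the exponent grows, since the naive loop does b bignum multiplications against roughly 2·log2(b) for the binary method.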
In fact, both the recursive and the iterative left-to-right methods are so efficient they're completely on par with Python's built-in <tt class="docutils literal"><span class="pre">pow</span></tt> method <a class="footnote-reference" href="#id6" id="id3"></a>.</p> <p>This is surprising, as I'd actually expect the right-to-left method to be faster, because it skips the reversing of bits when computing the binary representation of the exponent. I'd also expect the built-in <tt class="docutils literal"><span class="pre">pow</span></tt> to be faster.</p> <p>However, thinking harder for a moment, I think I can see why this happens. The RL (right-to-left) version has to multiply larger numbers at all stages, because LR sometimes multiplies by <code>a</code> itself, which is relatively small. Python's bignum implementation can multiply by a small number faster, and this compensates for the need to reverse the bit list. I'll come back to this issue when I'll discuss modular exponentiation. But this is a topic for another article...</p> <div align="center" class="align-center"><img class="align-center" src="https://eli.thegreenplace.net/images/hline.jpg" style="width: 320px; height: 5px;" /></div> <table class="docutils footnote" frame="void" id="id4" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id1"></a></td><td>From the looks of it (featuring the binary representation) you'd think this is a modern algorithm. Not at all! According to Knuth, it was first mentioned by the Persian mathematician Jamshīd al-Kāshī in 1427. 
The left-to-right method presented later in the article is even more ancient - it appeared in a Hindu book in about 200 BC.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id5" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id2"></a></td><td>This holds for any base, by the way. You can similarly build a number from its decimal digits by multiplying by 10 for each digit you see and adding the digit, at each step.</td></tr> </tbody> </table> <table class="docutils footnote" frame="void" id="id6" rules="none"> <colgroup><col class="label" /><col /></colgroup> <tbody valign="top"> <tr><td class="label"><a class="fn-backref" href="#id3"></a></td><td><tt class="docutils literal"><span class="pre">pow</span></tt> is coded in C and uses the iterative left-to-right method I described with some optimizations and complicated tricks.</td></tr> </tbody> </table> </div> Computing modular square roots in Python2009-03-07T11:59:08-08:002009-03-07T11:59:08-08:00Eli Benderskytag:eli.thegreenplace.net,2009-03-07:/2009/03/07/computing-modular-square-roots-in-python <p>Consider the congruence of the form:</p> <p> <p><img src="https://eli.thegreenplace.net/images/math/27eafd28fcb458b435c774d120c67c85b4f381c8.gif" class="align-center" /></p> </p> <p><tt class="docutils literal"><span class="pre">n</span></tt> is a <em>quadratic residue (mod p)</em>. What is <tt class="docutils literal"><span class="pre">x</span></tt>? In normal arithmetic, this is equivalent to finding the square root of a number. 
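As a small numeric illustration (my own example, not from the original post): modulo p = 13, n = 10 is a quadratic residue, and a brute-force search finds its two modular square roots:

```python
p, n = 13, 10
# Brute force: collect every x in Z_p with x^2 = n (mod p).
roots = [x for x in range(p) if (x * x) % p == n]
print(roots)  # [6, 7] -- note that 7 = p - 6, as expected
```

Brute force is fine for a toy modulus like 13, but is hopeless for cryptographic-size primes, which is what motivates the algorithm below.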
In modular arithmetic, <tt class="docutils literal"><span class="pre">x</span></tt> is the <em>modular square root</em> of <tt class="docutils literal"><span class="pre">n</span></tt> modulo <tt class="docutils literal"><span class="pre">p</span></tt>.</p> <p>Now, in the general case, this is …</p> <p>Consider the congruence of the form:</p> <p> <p><img src="https://eli.thegreenplace.net/images/math/27eafd28fcb458b435c774d120c67c85b4f381c8.gif" class="align-center" /></p> </p> <p><tt class="docutils literal"><span class="pre">n</span></tt> is a <em>quadratic residue (mod p)</em>. What is <tt class="docutils literal"><span class="pre">x</span></tt>? In normal arithmetic, this is equivalent to finding the square root of a number. In modular arithmetic, <tt class="docutils literal"><span class="pre">x</span></tt> is the <em>modular square root</em> of <tt class="docutils literal"><span class="pre">n</span></tt> modulo <tt class="docutils literal"><span class="pre">p</span></tt>.</p> <p>Now, in the general case, this is a very difficult problem to solve. In fact, it's equivalent to integer factorization, because no efficient algorithm is known to find the modular square root modulo a composite number, and if the modulus is composite it has to be factored first.</p> <p>But when <tt class="docutils literal"><span class="pre">p</span></tt> is prime, an efficient polynomial algorithm exists for computing <tt class="docutils literal"><span class="pre">x</span></tt>. This is the <a class="reference external" href="http://en.wikipedia.org/wiki/Shanks-Tonelli_algorithm">Tonelli-Shanks algorithm.</a></p> <p>Computing modular square roots is probably not one of those things you do daily, but I ran into it while solving a Project Euler problem. So I'm posting the Python implementation of the Tonelli-Shanks algorithm here. 
It is based on the explanation in the paper <em>&quot;Square roots from 1; 24, 51, 10 to Dan Shanks&quot;</em> by <a class="reference external" href="http://www.math.vt.edu/people/brown/doc.html">Ezra Brown</a>, as I found the Wikipedia algorithm hard to follow.</p> <p>The code is tested, and as far as I can tell works correctly and efficiently:</p> <div class="highlight"><pre><span style="color: #00007f; font-weight: bold">def</span> <span style="color: #00007f">modular_sqrt</span>(a, p): <span style="color: #7f007f">&quot;&quot;&quot; Find a square root of &#39;a&#39; modulo p. p</span> <span style="color: #7f007f"> must be an odd prime.</span> <span style="color: #7f007f"> Solve the congruence of the form:</span> <span style="color: #7f007f"> x^2 = a (mod p)</span> <span style="color: #7f007f"> and return x. Note that p - x is also a root.</span> <span style="color: #7f007f"> 0 is returned if no square root exists for</span> <span style="color: #7f007f"> these a and p.</span> <span style="color: #7f007f"> The Tonelli-Shanks algorithm is used (except</span> <span style="color: #7f007f"> for some simple cases in which the solution</span> <span style="color: #7f007f"> is known from an identity). 
This algorithm</span> <span style="color: #7f007f"> runs in polynomial time (unless the</span> <span style="color: #7f007f"> generalized Riemann hypothesis is false).</span> <span style="color: #7f007f"> &quot;&quot;&quot;</span> <span style="color: #007f00"># Simple cases</span> <span style="color: #007f00">#</span> <span style="color: #00007f; font-weight: bold">if</span> legendre_symbol(a, p) != <span style="color: #007f7f">1</span>: <span style="color: #00007f; font-weight: bold">return</span> <span style="color: #007f7f">0</span> <span style="color: #00007f; font-weight: bold">elif</span> a == <span style="color: #007f7f">0</span>: <span style="color: #00007f; font-weight: bold">return</span> <span style="color: #007f7f">0</span> <span style="color: #00007f; font-weight: bold">elif</span> p == <span style="color: #007f7f">2</span>: <span style="color: #00007f; font-weight: bold">return</span> 0 <span style="color: #00007f; font-weight: bold">elif</span> p % <span style="color: #007f7f">4</span> == <span style="color: #007f7f">3</span>: <span style="color: #00007f; font-weight: bold">return</span> <span style="color: #00007f">pow</span>(a, (p + <span style="color: #007f7f">1</span>) / <span style="color: #007f7f">4</span>, p) <span style="color: #007f00"># Partition p-1 to s * 2^e for an odd s (i.e.</span> <span style="color: #007f00"># reduce all the powers of 2 from p-1)</span> <span style="color: #007f00">#</span> s = p - <span style="color: #007f7f">1</span> e = <span style="color: #007f7f">0</span> <span style="color: #00007f; font-weight: bold">while</span> s % <span style="color: #007f7f">2</span> == <span style="color: #007f7f">0</span>: s /= <span style="color: #007f7f">2</span> e += <span style="color: #007f7f">1</span> <span style="color: #007f00"># Find some &#39;n&#39; with a legendre symbol n|p = -1.</span> <span style="color: #007f00"># Shouldn&#39;t take long.</span> <span style="color: #007f00">#</span> n = <span style="color: #007f7f">2</span> 
<span style="color: #00007f; font-weight: bold">while</span> legendre_symbol(n, p) != -<span style="color: #007f7f">1</span>: n += <span style="color: #007f7f">1</span> <span style="color: #007f00"># Here be dragons!</span> <span style="color: #007f00"># Read the paper &quot;Square roots from 1; 24, 51,</span> <span style="color: #007f00"># 10 to Dan Shanks&quot; by Ezra Brown for more</span> <span style="color: #007f00"># information</span> <span style="color: #007f00">#</span> <span style="color: #007f00"># x is a guess of the square root that gets better</span> <span style="color: #007f00"># with each iteration.</span> <span style="color: #007f00"># b is the &quot;fudge factor&quot; - by how much we&#39;re off</span> <span style="color: #007f00"># with the guess. The invariant x^2 = ab (mod p)</span> <span style="color: #007f00"># is maintained throughout the loop.</span> <span style="color: #007f00"># g is used for successive powers of n to update</span> <span style="color: #007f00"># both a and b</span> <span style="color: #007f00"># r is the exponent - decreases with each update</span> <span style="color: #007f00">#</span> x = <span style="color: #00007f">pow</span>(a, (s + <span style="color: #007f7f">1</span>) / <span style="color: #007f7f">2</span>, p) b = <span style="color: #00007f">pow</span>(a, s, p) g = <span style="color: #00007f">pow</span>(n, s, p) r = e <span style="color: #00007f; font-weight: bold">while</span> <span style="color: #00007f">True</span>: t = b m = <span style="color: #007f7f">0</span> <span style="color: #00007f; font-weight: bold">for</span> m <span style="color: #0000aa">in</span> <span style="color: #00007f">xrange</span>(r): <span style="color: #00007f; font-weight: bold">if</span> t == <span style="color: #007f7f">1</span>: <span style="color: #00007f; font-weight: bold">break</span> t = <span style="color: #00007f">pow</span>(t, <span style="color: #007f7f">2</span>, p) <span style="color: #00007f; font-weight: 
bold">if</span> m == <span style="color: #007f7f">0</span>: <span style="color: #00007f; font-weight: bold">return</span> x gs = <span style="color: #00007f">pow</span>(g, <span style="color: #007f7f">2</span> ** (r - m - <span style="color: #007f7f">1</span>), p) g = (gs * gs) % p x = (x * gs) % p b = (b * g) % p r = m <span style="color: #00007f; font-weight: bold">def</span> <span style="color: #00007f">legendre_symbol</span>(a, p): <span style="color: #7f007f">&quot;&quot;&quot; Compute the Legendre symbol a|p using</span> <span style="color: #7f007f"> Euler&#39;s criterion. p is a prime, a is</span> <span style="color: #7f007f"> relatively prime to p (if p divides</span> <span style="color: #7f007f"> a, then a|p = 0)</span> <span style="color: #7f007f"> Returns 1 if a has a square root modulo</span> <span style="color: #7f007f"> p, -1 otherwise.</span> <span style="color: #7f007f"> &quot;&quot;&quot;</span> ls = <span style="color: #00007f">pow</span>(a, (p - <span style="color: #007f7f">1</span>) / <span style="color: #007f7f">2</span>, p) <span style="color: #00007f; font-weight: bold">return</span> -<span style="color: #007f7f">1</span> <span style="color: #00007f; font-weight: bold">if</span> ls == p - <span style="color: #007f7f">1</span> <span style="color: #00007f; font-weight: bold">else</span> ls </pre></div> Rabin-Miller primality test implementation2009-02-21T12:19:42-08:002009-02-21T12:19:42-08:00Eli Benderskytag:eli.thegreenplace.net,2009-02-21:/2009/02/21/rabin-miller-primality-test-implementation <p>Here's a fairly efficient Python (2.5) and well-documented implementation of the <a class="reference external" href="http://mathworld.wolfram.com/Rabin-MillerStrongPseudoprimeTest.html">Rabin-Miller primality test</a>, based on section 33.8 in CLR's <em>Introduction to Algorithms</em>. 
Due to Python's built-in arbitrary precision arithmetic, this works for numbers of any size.</p> <div class="highlight"><pre><span style="color: #00007f; font-weight: bold">from</span> <span style="color: #00007f">random</span> <span style="color: #00007f; font-weight: bold">import</span> randint <span style="color: #00007f; font-weight: bold">def</span> <span style="color: #00007f">_bits_of_n</span>(n): <span style="color: #7f007f">&quot;&quot;&quot; Return the list of …</span></pre></div> <p>Here's a fairly efficient Python (2.5) and well-documented implementation of the <a class="reference external" href="http://mathworld.wolfram.com/Rabin-MillerStrongPseudoprimeTest.html">Rabin-Miller primality test</a>, based on section 33.8 in CLR's <em>Introduction to Algorithms</em>. Due to Python's built-in arbitrary precision arithmetic, this works for numbers of any size.</p> <div class="highlight"><pre><span style="color: #00007f; font-weight: bold">from</span> <span style="color: #00007f">random</span> <span style="color: #00007f; font-weight: bold">import</span> randint <span style="color: #00007f; font-weight: bold">def</span> <span style="color: #00007f">_bits_of_n</span>(n): <span style="color: #7f007f">&quot;&quot;&quot; Return the list of the bits in the binary</span> <span style="color: #7f007f"> representation of n, from LSB to MSB</span> <span style="color: #7f007f"> &quot;&quot;&quot;</span> bits = [] <span style="color: #00007f; font-weight: bold">while</span> n: bits.append(n % <span style="color: #007f7f">2</span>) n /= <span style="color: #007f7f">2</span> <span style="color: #00007f; font-weight: bold">return</span> bits <span style="color: #00007f; font-weight: bold">def</span> <span style="color: #00007f">_MR_composite_witness</span>(a, n): <span style="color: #7f007f">&quot;&quot;&quot; Witness functions for the Miller-Rabin</span> <span style="color: #7f007f"> test. 
If &#39;a&#39; can be used to prove that</span> <span style="color: #7f007f"> &#39;n&#39; is composite, return True. If False</span> <span style="color: #7f007f"> is returned, there&#39;s a high (though &lt; 1)</span> <span style="color: #7f007f"> probability that &#39;n&#39; is prime.</span> <span style="color: #7f007f"> &quot;&quot;&quot;</span> rem = <span style="color: #007f7f">1</span> <span style="color: #007f00"># Computes a^(n-1) mod n, using modular</span> <span style="color: #007f00"># exponentiation by repeated squaring.</span> <span style="color: #007f00">#</span> <span style="color: #00007f; font-weight: bold">for</span> b <span style="color: #0000aa">in</span> reversed(_bits_of_n(n - <span style="color: #007f7f">1</span>)): x = rem rem = (rem * rem) % n <span style="color: #00007f; font-weight: bold">if</span> rem == <span style="color: #007f7f">1</span> <span style="color: #0000aa">and</span> x != <span style="color: #007f7f">1</span> <span style="color: #0000aa">and</span> x != n - <span style="color: #007f7f">1</span>: <span style="color: #00007f; font-weight: bold">return</span> <span style="color: #00007f">True</span> <span style="color: #00007f; font-weight: bold">if</span> b == <span style="color: #007f7f">1</span>: rem = (rem * a) % n <span style="color: #00007f; font-weight: bold">if</span> rem != <span style="color: #007f7f">1</span>: <span style="color: #00007f; font-weight: bold">return</span> <span style="color: #00007f">True</span> <span style="color: #00007f; font-weight: bold">return</span> <span style="color: #00007f">False</span> <span style="color: #00007f; font-weight: bold">def</span> <span style="color: #00007f">isprime_MR</span>(n, trials=<span style="color: #007f7f">6</span>): <span style="color: #7f007f">&quot;&quot;&quot; Determine whether n is prime using the</span> <span style="color: #7f007f"> probabilistic Miller-Rabin test. 
Follows</span> <span style="color: #7f007f"> the procedure described in section 33.8</span> <span style="color: #7f007f"> in CLR&#39;s Introduction to Algorithms</span> <span style="color: #7f007f"> trials:</span> <span style="color: #7f007f"> The amount of trials of the test.</span> <span style="color: #7f007f"> A larger amount of trials increases</span> <span style="color: #7f007f"> the chances of a correct answer.</span> <span style="color: #7f007f"> 6 is safe enough for all practical</span> <span style="color: #7f007f"> purposes.</span> <span style="color: #7f007f"> &quot;&quot;&quot;</span> <span style="color: #00007f; font-weight: bold">if</span> n &lt; <span style="color: #007f7f">2</span>: <span style="color: #00007f; font-weight: bold">return</span> <span style="color: #00007f">False</span> <span style="color: #00007f; font-weight: bold">for</span> ntrial <span style="color: #0000aa">in</span> <span style="color: #00007f">xrange</span>(trials): <span style="color: #00007f; font-weight: bold">if</span> _MR_composite_witness(randint(<span style="color: #007f7f">1</span>, n - <span style="color: #007f7f">1</span>), n): <span style="color: #00007f; font-weight: bold">return</span> <span style="color: #00007f">False</span> <span style="color: #00007f; font-weight: bold">return</span> <span style="color: #00007f">True</span> </pre></div> <p>The function you should call is <tt class="docutils literal"><span class="pre">isprime_MR</span></tt>.</p> <p>Although this test is probabilistic, the chances of it erring are extremely low. 
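For convenience, here is the same test as a self-contained Python 3 sketch (a direct port of the code above, with <tt class="docutils literal"><span class="pre">//</span></tt> and <tt class="docutils literal"><span class="pre">range</span></tt> replacing Python 2's <tt class="docutils literal"><span class="pre">/</span></tt> and <tt class="docutils literal"><span class="pre">xrange</span></tt>), plus a quick usage check:

```python
from random import randint

def _bits_of_n(n):
    # Bits of n, from LSB to MSB.
    bits = []
    while n:
        bits.append(n % 2)
        n //= 2
    return bits

def _MR_composite_witness(a, n):
    # Return True if 'a' witnesses that 'n' is composite.
    rem = 1
    # Compute a^(n-1) mod n left-to-right, watching for a
    # nontrivial square root of 1 along the way.
    for b in reversed(_bits_of_n(n - 1)):
        x = rem
        rem = (rem * rem) % n
        if rem == 1 and x != 1 and x != n - 1:
            return True
        if b == 1:
            rem = (rem * a) % n
    return rem != 1

def isprime_MR(n, trials=6):
    # Probabilistic Miller-Rabin primality test.
    if n < 2:
        return False
    for _ in range(trials):
        if _MR_composite_witness(randint(1, n - 1), n):
            return False
    return True

print(isprime_MR(2 ** 61 - 1))             # a Mersenne prime -> True
print(isprime_MR(2 ** 61 + 1, trials=25))  # divisible by 3 -> False
```

A prime input can never produce a witness, so the test is only one-sided probabilistic: errors can occur only by declaring a composite "prime".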
According to Bruce Schneier in &quot;Applied Cryptography&quot;, the chances of error for a 256-bit number with 6 trials are less than one in <img src="https://eli.thegreenplace.net/images/math/dced696965fcd541e19ed68b16f8b99fd7bdbead.gif" /> - this is <em>very low</em>.</p> <p>Therefore, you should always use this method instead of the naive one (trying to divide by all primes up to <img src="https://eli.thegreenplace.net/images/math/e13517ff4c4fef8f8f59a599e10028d5eebef947.gif" />), because it's much faster.</p> The limit of sin(h)/h, or deriving the sine function2009-01-13T21:45:45-08:002009-01-13T21:45:45-08:00Eli Benderskytag:eli.thegreenplace.net,2009-01-13:/2009/01/13/the-limit-of-sinhh-or-deriving-the-sine-function <strong>Deriving the sine</strong> It is a basic identity of calculus that <img src="https://eli.thegreenplace.net/images/math/fc929aff5b7aa92d35efbe7e60575e937bf49539.gif" />. But how does one prove it? Well, let's use the definition of derivatives (the symbol <img src="https://eli.thegreenplace.net/images/math/27d5482eebd075de44389774fce28c69f45c8a75.gif" /> is used instead of <img src="https://eli.thegreenplace.net/images/math/6d56447973863053dfb94416852d0392187be5b6.gif" /> for readability): <p><img src="https://eli.thegreenplace.net/images/math/d7aeb7a0b3c196328d4bb72c248cc619ed4e29d8.gif" class="align-center" /></p> Using a trigonometric identity and regrouping we'll get: <p><img src="https://eli.thegreenplace.net/images/math/553fe3c71b38857bcd459c7fc9ddd1afcdd0ec59.gif" class="align-center" /></p> Now, if we could only prove that <img src="https://eli.thegreenplace.net/images/math/1565cd10a93e120bd7cb1671995265af8df1a79a.gif" /> and <img src="https://eli.thegreenplace.net/images/math/27e81183a47ce0516512e5b71508d74020b1e423.gif" /> as <img src="https://eli.thegreenplace.net/images/math/2d779337053618dd98d696c55aedf6ea2d4e286b.gif" />, we'd get … <strong>Deriving the sine</strong> It is a basic identity of calculus that <img
src="https://eli.thegreenplace.net/images/math/fc929aff5b7aa92d35efbe7e60575e937bf49539.gif" />. But how does one prove it? Well, let's use the definition of derivatives (the symbol <img src="https://eli.thegreenplace.net/images/math/27d5482eebd075de44389774fce28c69f45c8a75.gif" /> is used instead of <img src="https://eli.thegreenplace.net/images/math/6d56447973863053dfb94416852d0392187be5b6.gif" /> for readability): <p><img src="https://eli.thegreenplace.net/images/math/d7aeb7a0b3c196328d4bb72c248cc619ed4e29d8.gif" class="align-center" /></p> Using a trigonometric identity and regrouping we'll get: <p><img src="https://eli.thegreenplace.net/images/math/553fe3c71b38857bcd459c7fc9ddd1afcdd0ec59.gif" class="align-center" /></p> Now, could we only prove that <img src="https://eli.thegreenplace.net/images/math/1565cd10a93e120bd7cb1671995265af8df1a79a.gif" /> and <img src="https://eli.thegreenplace.net/images/math/27e81183a47ce0516512e5b71508d74020b1e423.gif" /> as <img src="https://eli.thegreenplace.net/images/math/2d779337053618dd98d696c55aedf6ea2d4e286b.gif" />, we'd get the <img src="https://eli.thegreenplace.net/images/math/562597441eed562140c81684902007f6f275c940.gif" /> we want. But how do we prove those? 
<strong>L'Hopital's rule?</strong> At this point some people feel inclined to use the following "proof": <p><img src="https://eli.thegreenplace.net/images/math/189b37a4280947f43728a12a47048e9897058dee.gif" class="align-center" /></p> Using L'Hopital's rule, this is equivalent to: <p><img src="https://eli.thegreenplace.net/images/math/7a4a3dd39fc7be8922bcd31848a6d1e6ae927f3f.gif" class="align-center" /></p> <p><img src="https://eli.thegreenplace.net/images/math/200d228c3e9c87a61824ce8524a4a714ec306270.gif" class="align-center" /></p> Which indeed goes to 1 as <img src="https://eli.thegreenplace.net/images/math/2d779337053618dd98d696c55aedf6ea2d4e286b.gif" />. There's nothing wrong with using L'Hopital's rule in general, but we can't use it here, because we're creating a circular argument! We can't just assume that <img src="https://eli.thegreenplace.net/images/math/fc929aff5b7aa92d35efbe7e60575e937bf49539.gif" /> (for applying L'Hopital) when we're trying to prove it! We'll have to find another method. <strong>A geometrical proof</strong> Consider this diagram: <p><img src="https://eli.thegreenplace.net/images/2009/01/circle_geometry.png" /></p> For simplicity, this is a unit circle (i.e. the length of QO is 1). QS is perpendicular to OP, and so is RP (which is a tangent). <img src="https://eli.thegreenplace.net/images/math/27d5482eebd075de44389774fce28c69f45c8a75.gif" /> is the angle QOS. From trigonometry, QS equals <img src="https://eli.thegreenplace.net/images/math/1dc91d45f0ae10b205595508b05921083b2842d7.gif" /> (QO is 1, recall) and OS equals <img src="https://eli.thegreenplace.net/images/math/c3668957974dc873ec58d32d59ec82eacff6d12b.gif" />. The triangles QOS and ROP are similar, so the ratio between RP and OP is the same as the ratio between QS and OS, which is <img src="https://eli.thegreenplace.net/images/math/33f736886b99c216de7b572c35a6d83a39a2d1b2.gif" />.
Since OP is 1, RP equals <img src="https://eli.thegreenplace.net/images/math/33f736886b99c216de7b572c35a6d83a39a2d1b2.gif" />. Now let's consider the areas of triangles ROP and QOP, and the "pie" section of the circle defined by Q, O and P. The area of the triangle QOS is: <p><img src="https://eli.thegreenplace.net/images/math/c56c8c2dff1756112c3b48d244c2ed19f668e647.gif" class="align-center" /></p> The area of ROP is similarly <img src="https://eli.thegreenplace.net/images/math/5b691de08b5a69c4ee7c85e70e84e869fab29d0d.gif" />. What is the area of the pie QOP? We'll compute it as follows: <p><img src="https://eli.thegreenplace.net/images/math/c9cdd0869aad07b026ed5f32bd6ed6f7e2d070b2.gif" class="align-center" /></p> The area of a unit circle is <img src="https://eli.thegreenplace.net/images/math/6ac47b6d7372b4087583cfd048d20f4c1571f5cf.gif" />, and its circumference is <img src="https://eli.thegreenplace.net/images/math/0833718ca4569f36e84dbdc7742eaec65e49b150.gif" />, but how do we express the length of arc PQ? <strong>Defining radians</strong> Did you know why the units for angles used in calculus are almost exclusively radians? Because the radian is defined as follows: <blockquote> One radian is the angle subtended at the center of a circle by an arc that is equal in length to the radius of the circle.</blockquote> Here's a diagram (courtesy <a href="http://en.wikipedia.org/wiki/Radian">of Wikipedia</a>): <p><img src="https://eli.thegreenplace.net/images/2009/01/Radian.png" /></p> This is a very convenient definition that allows us to make computations without messing with too many <img src="https://eli.thegreenplace.net/images/math/6ac47b6d7372b4087583cfd048d20f4c1571f5cf.gif" />s. <strong>Back to the proof</strong> So getting back to our arc PQ, it is simply equal to the angle <img src="https://eli.thegreenplace.net/images/math/27d5482eebd075de44389774fce28c69f45c8a75.gif" />, when that one is defined in radians.
That's because by the definition of radian above, if the angle is one radian, the arc length is 1 (since that's the radius of the unit circle). Hence if the angle is <img src="https://eli.thegreenplace.net/images/math/27d5482eebd075de44389774fce28c69f45c8a75.gif" /> radians, the arc length is <img src="https://eli.thegreenplace.net/images/math/27d5482eebd075de44389774fce28c69f45c8a75.gif" /> times 1. So we have: <p><img src="https://eli.thegreenplace.net/images/math/c0d158c96bd0c2fdc07fb01bb1bae83b801a0b4e.gif" class="align-center" /></p> <p><img src="https://eli.thegreenplace.net/images/math/118785c5117da59954c7046a50f4f5fd0c43b6b4.gif" class="align-center" /></p> Now comes the punchline of the proof. It is obvious that the area of the triangle QOP is always smaller than the area of the pie QOP, which in turn is always smaller than the large triangle ROP. Mathematically: <p><img src="https://eli.thegreenplace.net/images/math/f85352fc830617ce91d58d24c044851d164ffbf2.gif" class="align-center" /></p> Dividing this by <img src="https://eli.thegreenplace.net/images/math/c47071837f6a1a6f645f20fedd6b96907d72b8df.gif" />: <p><img src="https://eli.thegreenplace.net/images/math/2579bf5237d2120671276572ba15c959b3c935cf.gif" class="align-center" /></p> <p><img src="https://eli.thegreenplace.net/images/math/4dcf73f7082d86caeb2f4a969d3f4e4f4310c5ed.gif" class="align-center" /></p> Now, if we let <img src="https://eli.thegreenplace.net/images/math/2d779337053618dd98d696c55aedf6ea2d4e286b.gif" />, then <img src="https://eli.thegreenplace.net/images/math/d7962bee27672d1b7eb54f7346f597f118ed4ed0.gif" />, and it follows that <img src="https://eli.thegreenplace.net/images/math/1565cd10a93e120bd7cb1671995265af8df1a79a.gif" /> by the squeeze theorem. 
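<p>A quick numerical sanity check (not a proof, of course) shows the ratio creeping up toward 1 from below, exactly as the squeeze suggests:</p>

```python
import math

# Evaluate sin(h)/h for progressively smaller h (in radians).
ratios = [math.sin(h) / h for h in (0.5, 0.1, 0.01, 0.001)]
for h, r in zip((0.5, 0.1, 0.01, 0.001), ratios):
    print(h, r)
# Each ratio is below 1, and they increase toward 1 as h shrinks.
```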
Recall that we also have to prove that <img src="https://eli.thegreenplace.net/images/math/27e81183a47ce0516512e5b71508d74020b1e423.gif" />, but this is a simple step from <img src="https://eli.thegreenplace.net/images/math/1565cd10a93e120bd7cb1671995265af8df1a79a.gif" /> by using the identity: <p><img src="https://eli.thegreenplace.net/images/math/69fbcbe8f29f8a0ca0d00c3d38fe1e986ae6df69.gif" class="align-center" /></p> Now, if we go back to the limit we've developed for deriving the sine: <p><img src="https://eli.thegreenplace.net/images/math/553fe3c71b38857bcd459c7fc9ddd1afcdd0ec59.gif" class="align-center" /></p> Substituting the limits <img src="https://eli.thegreenplace.net/images/math/1565cd10a93e120bd7cb1671995265af8df1a79a.gif" /> and <img src="https://eli.thegreenplace.net/images/math/27e81183a47ce0516512e5b71508d74020b1e423.gif" /> here we get: <p><img src="https://eli.thegreenplace.net/images/math/fdfe2dc25c14b6822a443bfff309274d1d3bb3f5.gif" class="align-center" /></p> <p><img src="https://eli.thegreenplace.net/images/math/edede5a60922621b26fcaaadbdf0d29ee2b4d7ff.gif" class="align-center" /></p> Variance of the sum of independent random variables2009-01-07T22:06:58-08:002009-01-07T22:06:58-08:00Eli Benderskytag:eli.thegreenplace.net,2009-01-07:/2009/01/07/variance-of-the-sum-of-independent-variables <p> Yesterday I was trying to brush up my skills in probability and came upon this formula on the Wikipedia page <a href="http://en.wikipedia.org/wiki/Variance">about variance</a>: </p> <p><img src="https://eli.thegreenplace.net/images/math/5a8707440af01e8319f02e80f8ea33d4600a4a4b.gif" class="align-center" /></p> <p> The article calls this the <em>Bienaymé formula</em> and gives neither proof nor a link to one. Googling this formula proved equally fruitless in terms of proofs. 
</p> <p> So, I …</p> <p> Yesterday I was trying to brush up my skills in probability and came upon this formula on the Wikipedia page <a href="http://en.wikipedia.org/wiki/Variance">about variance</a>: </p> <p><img src="https://eli.thegreenplace.net/images/math/5a8707440af01e8319f02e80f8ea33d4600a4a4b.gif" class="align-center" /></p> <p> The article calls this the <em>Bienaymé formula</em> and gives neither proof nor a link to one. Googling this formula proved equally fruitless in terms of proofs. </p> <p> So, I set out to find why this works. It took me a few hours of digging through books and removing dust from my University-learned probability skills of 8 years ago, but finally I've made it. Here's how. </p> <p> <em>Note: the Wikipedia article states the Bienaymé formula for uncorrelated variables. Here I'll prove the case of independent variables, which is a more useful and frequently used application of the formula. I'm also proving it for discrete random variables - the continuous case is equivalent.</em> </p> <h2>Expected value and variance</h2> <p> We'll start with a few definitions. Formally, the expected value of a (discrete) random variable X is defined by: </p> <p><img src="https://eli.thegreenplace.net/images/math/6e3bd6378c646ec0f285a69df7db72194f308f5b.gif" class="align-center" /></p> Where <img src="https://eli.thegreenplace.net/images/math/810bdf91cc65f953d130a2f239cee691fa024330.gif" /> is the <a href="http://en.wikipedia.org/wiki/Probability_mass_function">PMF</a> of X, <img src="https://eli.thegreenplace.net/images/math/30eaa4945a586b21346e14bd193b9914db6c2166.gif" />. 
For a function <img src="https://eli.thegreenplace.net/images/math/65405422ff71ebf2db437dbd89a41355f4f19183.gif" />: <p><img src="https://eli.thegreenplace.net/images/math/28e5d4f4c0d4dc5023e96687ce05e8851a8f8329.gif" class="align-center" /></p> <p> The variance of X is defined in terms of the expected value as: </p> <p><img src="https://eli.thegreenplace.net/images/math/fe590d0bcb58c4be73d18e751200721bbc402dc0.gif" class="align-center" /></p> <p> From this we can also obtain: </p> <p><img src="https://eli.thegreenplace.net/images/math/dfcad71ca591b4179d442332299b6a5d5963e628.gif" class="align-center" /></p> <p><img src="https://eli.thegreenplace.net/images/math/03ccc00bf5863f84fa1c081dc26c4b450ee7afcc.gif" class="align-center" /></p> <p><img src="https://eli.thegreenplace.net/images/math/c7e98eaab83b497e5716ffa78dcd80baff9d8c59.gif" class="align-center" /></p> <p><img src="https://eli.thegreenplace.net/images/math/a10dd400a0c745f0f369d8919994ae0652b21024.gif" class="align-center" /></p> <p><img src="https://eli.thegreenplace.net/images/math/8a43a908e597eeb8f279a521c4d833f3ac88f7f9.gif" class="align-center" /></p> Which is more convenient to use in some calculations. 
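<p>That convenience is easy to check numerically. Here is a small sketch comparing the definitional form of the variance against the E[X&sup2;] - (E[X])&sup2; shortcut, using a made-up discrete PMF:</p>

```python
# A made-up discrete distribution (any valid PMF works here).
pmf = {1: 0.2, 2: 0.5, 3: 0.3}

mean = sum(x * p for x, p in pmf.items())         # E[X]
mean_sq = sum(x * x * p for x, p in pmf.items())  # E[X^2]

# Variance straight from the definition E[(X - E[X])^2] ...
var_def = sum((x - mean) ** 2 * p for x, p in pmf.items())
# ... and via the shortcut E[X^2] - (E[X])^2.
var_shortcut = mean_sq - mean ** 2

print(var_def, var_shortcut)  # the two agree (up to float rounding)
```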
<h2>Linear function of a random variable</h2> <p> From the definitions given above it can be easily shown that given a linear function of a random variable: <img src="https://eli.thegreenplace.net/images/math/5abba821a5142e83482cea2117ec22b289d4a3d6.gif" />, the expected value and variance of Y are: </p> <p><img src="https://eli.thegreenplace.net/images/math/d8e69b52ca0d6dcb06ba7b8975266518278145e3.gif" class="align-center" /></p> <p><img src="https://eli.thegreenplace.net/images/math/c420f86130e26c754241fd1f56cbe14cd86d1358.gif" class="align-center" /></p> <p> For the expected value, we can make a stronger claim for any g(x): </p> <p><img src="https://eli.thegreenplace.net/images/math/e4cdac3ba0c020100a9b98b82f3e0ac6b74b2b78.gif" class="align-center" /></p> <h2>Multiple random variables</h2> <p> When multiple random variables are involved, things start getting a bit more complicated. I'll focus on two random variables here, but this is easily extensible to N variables. Given two random variables that participate in an experiment, their joint PMF is: </p> <p><img src="https://eli.thegreenplace.net/images/math/955c39ca88717c6449629b633a7589a910c9555f.gif" class="align-center" /></p> <p> The joint PMF determines the probability of any event that can be specified in terms of the random variables X and Y. 
For example if A is the set of all pairs <img src="https://eli.thegreenplace.net/images/math/f09b5d4028feab230a8f9a4499e21a0b4db3ccce.gif" /> that have a certain property, then: </p> <p><img src="https://eli.thegreenplace.net/images/math/6b1ea5b82b9f9b2609973a2c0959ea2a44c80fd8.gif" class="align-center" /></p> <p> Note that from this PMF we can infer the PMF for a single variable, like this: </p> <p><img src="https://eli.thegreenplace.net/images/math/30eaa4945a586b21346e14bd193b9914db6c2166.gif" class="align-center" /></p> <p><img src="https://eli.thegreenplace.net/images/math/60de9cadbeef9f0d1931ba799bc7790960ca3c61.gif" class="align-center" /></p> <p><img src="https://eli.thegreenplace.net/images/math/fec2fcd06d4ac717101f02690278a8daa3f2e068.gif" class="align-center" /></p> <p> The expected value for functions of two variables naturally extends and takes the form: </p> <p><img src="https://eli.thegreenplace.net/images/math/0d00c9b76a37f91f0a2a1efe986e545ac7c90639.gif" class="align-center" /></p> <h2>Sum of random variables</h2> <p> Let's see how the sum of random variables behaves. From the previous formula: </p> <p><img src="https://eli.thegreenplace.net/images/math/a00cd8185a75e2256e4ec8a11d9c6fe7fd776d06.gif" class="align-center" /></p> <p><img src="https://eli.thegreenplace.net/images/math/30e7d900e6261f0d926e1e0b34a7f8bf1d6bd9e8.gif" class="align-center" /></p> <p> But recall equation (1). The above simply equals: </p> <p><img src="https://eli.thegreenplace.net/images/math/ecece91067e8b6d6ae4250b7b59f92fc367cf433.gif" class="align-center" /></p> <p><img src="https://eli.thegreenplace.net/images/math/3b2b53ab4299f6a63867ed4b5645018abee982e5.gif" class="align-center" /></p> <p> We'll also want to prove that <img src="https://eli.thegreenplace.net/images/math/08167b14267064495a086afab04d97c83751184c.gif" />.
This is only true for independent X and Y, so we'll have to make this assumption (assuming that they're independent means that <img src="https://eli.thegreenplace.net/images/math/9a31b8e023593b070b7bfdcfca07f2707d47265a.gif" />). </p> <p><img src="https://eli.thegreenplace.net/images/math/f7fec121582e25c06c1119e689ba6828d20883b2.gif" class="align-center" /></p> <p> By independence: </p> <p><img src="https://eli.thegreenplace.net/images/math/762a9e3e2ff39253bec8efb60e0eb17f8110af99.gif" class="align-center" /></p> <p><img src="https://eli.thegreenplace.net/images/math/ee6ebbb2633e33514285d312a999a17354e6521c.gif" class="align-center" /></p> <p><img src="https://eli.thegreenplace.net/images/math/0b6f7ad819416477a9e8a948c7a8c9d84be25c12.gif" class="align-center" /></p> <p> A very similar proof can show that for independent X and Y: </p> <p><img src="https://eli.thegreenplace.net/images/math/951d25c3af4d29c3cb9c1e69c4d546f2e7ac7521.gif" class="align-center" /></p> <p> For any functions g and h (because if X and Y are independent, so are g(X) and h(Y)). Now, at last, we're ready to tackle the variance of X + Y.
We start by expanding the definition of variance: </p> <p><img src="https://eli.thegreenplace.net/images/math/10b9e3c7ad7f5621b21440e0113ec31f11b4884d.gif" class="align-center" /></p> By (2): <p><img src="https://eli.thegreenplace.net/images/math/b9fa5c9ce518390c4e2fb4d672c29bcbfb2e91b7.gif" class="align-center" /></p> <p><img src="https://eli.thegreenplace.net/images/math/74f08dfb9aa6322473e84f1afc42e1d9e53fac1f.gif" class="align-center" /></p> <p><img src="https://eli.thegreenplace.net/images/math/45c1416ce8b6179109d77f80ac353251e525de2c.gif" class="align-center" /></p> <p><img src="https://eli.thegreenplace.net/images/math/65ddaa4af89d23925dae7e3cbe4d7526df04ad25.gif" class="align-center" /></p> <p> Now, note that the random variables <img src="https://eli.thegreenplace.net/images/math/c5f6fd358d4c1cbb485216af24b715050ff8a121.gif" /> and <img src="https://eli.thegreenplace.net/images/math/3da3e48f75d2e03d9795c7617016701d1cd28b2c.gif" /> are independent, so: </p> <p><img src="https://eli.thegreenplace.net/images/math/29c181dd186163c688dcfec42ce8fb08e366cbfa.gif" class="align-center" /></p> But using (2) again: <p><img src="https://eli.thegreenplace.net/images/math/918d718f3654bfe32037276651f2f9e1c2472585.gif" class="align-center" /></p> <img src="https://eli.thegreenplace.net/images/math/ec36fba3bda13db973cd60f1910c7955f10757c2.gif" /> is obviously just <img src="https://eli.thegreenplace.net/images/math/769c095daa5985533efb5176c86007611e6f4eb5.gif" />, therefore the above reduces to 0. 
<p> So, coming back to the long expression for the variance of sums, the last term is 0, and we have: </p> <p><img src="https://eli.thegreenplace.net/images/math/97b8d284ae899aa4f6568479f30fb6c40cf7f13d.gif" class="align-center" /></p> <p><img src="https://eli.thegreenplace.net/images/math/2172289164dba259d1e5c274e2611e15862edecb.gif" class="align-center" /></p> <p> As I've mentioned before, proving this for the sum of two variables suffices, because the proof for N variables is a simple mathematical extension, and can be intuitively understood by means of a "mental induction". Therefore: </p> <p><img src="https://eli.thegreenplace.net/images/math/5a8707440af01e8319f02e80f8ea33d4600a4a4b.gif" class="align-center" /></p> <p> For N independent variables <img src="https://eli.thegreenplace.net/images/math/97fd495350d680b99411eaf425194e5b295465a6.gif" />. <img src="https://eli.thegreenplace.net/images/math/7b47d4175993a732aa2287de666a82273110f26e.gif" /> </p> Solution to the RC circuit puzzle2008-12-26T11:21:45-08:002008-12-26T11:21:45-08:00Eli Benderskytag:eli.thegreenplace.net,2008-12-26:/2008/12/26/solution-to-the-rc-circuit-puzzle Here, as promised, is the solution to the <a href="http://eli.thegreenplace.net/2008/12/22/an-rc-circuit-puzzle">RC circuit puzzle</a> I posted earlier this week. Let's look at the circuit again: <p><img src="https://eli.thegreenplace.net/images/2008/12/cap_resistor.png" /></p> The problem with my reasoning was the direction of current in the capacitor. I've quietly assumed that: <p><img src="https://eli.thegreenplace.net/images/math/bbdc708e5df1d3ca3312149e15dbecc98b8fea5a.gif" class="align-center" /></p> But this is wrong for the circuit above. Why? Because we … Here, as promised, is the solution to the <a href="http://eli.thegreenplace.net/2008/12/22/an-rc-circuit-puzzle">RC circuit puzzle</a> I posted earlier this week. 
Let's look at the circuit again: <p><img src="https://eli.thegreenplace.net/images/2008/12/cap_resistor.png" /></p> The problem with my reasoning was the direction of current in the capacitor. I've quietly assumed that: <p><img src="https://eli.thegreenplace.net/images/math/bbdc708e5df1d3ca3312149e15dbecc98b8fea5a.gif" class="align-center" /></p> But this is wrong for the circuit above. Why? Because we must obey the voltage & current directions we've chosen. In passive elements, the positive current flows from the higher voltage to the lower voltage, meaning that in our circuit: <p><img src="https://eli.thegreenplace.net/images/math/31dd3853efeb2e16318323bc12736da4de1277fa.gif" class="align-center" /></p> This small minus sign makes all the difference, and now the solution will be correct. Physically, the intuition is that the current here flows from a discharging capacitor, hence it's "against" the voltage direction. Had it been a capacitor-charging circuit, there would be no confusion. An RC circuit puzzle2008-12-22T22:16:05-08:002008-12-22T22:16:05-08:00Eli Benderskytag:eli.thegreenplace.net,2008-12-22:/2008/12/22/an-rc-circuit-puzzle If you're interested in electronics, you'll find the following simple "paradox" amusing. It's the usual case of "proving that 2+2=5". The fun is finding where the mistake in the reasoning is. Consider the following circuit: <p><img src="https://eli.thegreenplace.net/images/2008/12/cap_resistor.png" /></p> Assume that the capacitor is charged to some initial voltage before the switch … If you're interested in electronics, you'll find the following simple "paradox" amusing. It's the usual case of "proving that 2+2=5". The fun is finding where the mistake in the reasoning is. Consider the following circuit: <p><img src="https://eli.thegreenplace.net/images/2008/12/cap_resistor.png" /></p> Assume that the capacitor is charged to some initial voltage before the switch is closed. At time 0, the switch is closed. 
What is the current in the circuit as a function of time? Let's solve it using the familiar RC circuit methods. We know that <img src="https://eli.thegreenplace.net/images/math/5e23343bb687c00a0eb8ce9ef60e95b356568127.gif" /> because of Kirchhoff's voltage law. We'll differentiate both sides by time: $\dot{V}_{c}(t) = \dot{V}_{R}(t)$ We know that for a capacitor, the relation between current and voltage is: <p><img src="https://eli.thegreenplace.net/images/math/928acbef4f8f2eeb39d3c51ca68ab3e08279393f.gif" class="align-center" /></p> Substituting it into the equation above and also recalling that <img src="https://eli.thegreenplace.net/images/math/6809dfc8324feb51c746bc469c8bd7dbbe3ea32e.gif" />, we get: <p><img src="https://eli.thegreenplace.net/images/math/73fa9296a1e93050c9ba41b7bd8d5ddeaa1d84a6.gif" class="align-center" /></p> But the current through the capacitor and resistor is the same current, so this can be rewritten simply as: <p><img src="https://eli.thegreenplace.net/images/math/eba020de23be652ba084b91a27aa173aade8a360.gif" class="align-center" /></p> This is a simple first order differential equation, the solution of which is: <p><img src="https://eli.thegreenplace.net/images/math/46c7da1f88d1806982454f784d37742fbfa0c332.gif" class="align-center" /></p> For some initial current <img src="https://eli.thegreenplace.net/images/math/7dd1d81670e79a2861ab8214c079d2f03ee310a0.gif" />. But wait a second, how can the exponent be positive, won't it grow to infinity with time? There's obviously a mistake here, somewhere. Can you find it? This problem gave me some headache last night, and today I've successfully stumped a few co-workers with it. I'll post a solution in a couple of days.
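<p>As the solution post above explains, the correct sign gives a decaying exponential. A tiny forward-Euler simulation of the discharging capacitor (component values are arbitrary examples, not from the post) shows the decay rather than the impossible blow-up:</p>

```python
# Forward-Euler simulation of the discharging RC circuit.
# Arbitrary example values: R = 1 kOhm, C = 1 uF, so the
# time constant is RC = 1 ms.
R, C = 1000.0, 1e-6
v, dt = 5.0, 1e-6          # initial capacitor voltage 5 V, 1 us time steps
for _ in range(1000):      # simulate 1 ms = one time constant
    i = v / R              # current through the resistor
    v -= (i / C) * dt      # dv/dt = -i/C: the capacitor discharges
print(v)                   # roughly V0/e, i.e. about 1.84 V - a decay
```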
Posting mathematical formulae in a Wordpress blog2008-12-02T22:52:26-08:002008-12-02T22:52:26-08:00Eli Benderskytag:eli.thegreenplace.net,2008-12-02:/2008/12/02/posting-mathematical-formulae-in-a-wordpress-blog <p> <strong>(Update 11.07.2009: <a href="http://eli.thegreenplace.net/2009/07/11/posting-correctly-aligned-latex-formulae-in-a-wordpress-blog/">I've switched to another plugin</a>, but the rest of this post is still relevant)</strong> </p><p> When your blog often deals with technical matters, and especially math, it is very useful to be able to post complex mathematical formulae / equations. There's only so far that you can go …</p> <p> <strong>(Update 11.07.2009: <a href="http://eli.thegreenplace.net/2009/07/11/posting-correctly-aligned-latex-formulae-in-a-wordpress-blog/">I've switched to another plugin</a>, but the rest of this post is still relevant)</strong> </p><p> When your blog often deals with technical matters, and especially math, it is very useful to be able to post complex mathematical formulae / equations. There's only so far that you can go with "ASCII-equations" like a^2 + b^2 = c^2. Being able to write <img src="https://eli.thegreenplace.net/images/math/e215b13b5947314d303dde025db05c50eabfb9e8.gif" /> is so much nicer... </p><p> Several plugins exist for this in the world of WP. In the "simple" spectrum you can find an interface to <a href="http://www.xm1math.net/phpmathpublisher/">PhpMathPublisher</a>. But when it comes to mathematical equations, you can hardly compete with Latex, and while using it is more complex, this is the best path to take if you don't want to quickly run into limitations. Recall that Latex is used by 90% of academics to publish their papers packed with mathematical equations. The Latex syntax is widely accepted and quite standard among many implementations. </p><p> Enter the <a href="http://wordpress.org/extend/plugins/latex/installation/">WP latex plugin</a>. 
Just install it following the instructions and it will use the web service provided by <a href="http://wordpress.com">Wordpress.com</a> to render inline Latex equations into images for you. The images are stored in a local cache which means that once your post has been generated and viewed, the image is static and safe on your server (you can, of course, cancel this feature if you want to). </p><p> The syntax of the plugin is very simple and you can use both inline equations like <img src="https://eli.thegreenplace.net/images/math/b9dd84aaa5b3a778d39ea7b95f32fdeed4510389.gif" />, or larger equations centered and on a separate line: </p><p> <p><img src="https://eli.thegreenplace.net/images/math/5062ee491e280f17b99d2597642c9da6c40ed30c.gif" class="align-center" /></p> [This is, by the way, Gauss's <a href="http://en.wikipedia.org/wiki/Divergence_theorem">Divergence Theorem</a> which, I recall, was very useful in Calculus II] </p><p> If <a href="http://wordpress.com">Wordpress.com</a> ever ceases providing the Latex rendering service you can always switch to another - there are plenty. This is the real power of the Latex standard - many renderers will understand the same syntax. </p><p> If this isn't hard core enough for you, you can always install your own Latex service. <a href="http://www.forkosh.com/mathtex.html">mathtex</a> is a CGI script you can install on your server. It will communicate with a locally installed Latex program and an image renderer to generate images for you. The problem is - it's not very simple to install Latex on a shared hosting account. It's possible though, and many people have <a href="http://mcuprogramming.com/2007/03/30/installing-latex-on-web-hosts/">done it</a>. So if you don't feel "safe" enough using a remote web service for rendering equations, you can always spend some extra effort and roll your own. The WP latex plugin makes it easy to switch services. </p> Book review: "A Certain Ambiguity: A mathematical novel" by G.
Suri and H. Bal2008-11-14T19:33:59-08:002008-11-14T19:33:59-08:00Eli Benderskytag:eli.thegreenplace.net,2008-11-14:/2008/11/14/book-review-a-certain-ambiguity-a-mathematical-novel-by-g-suri-and-h-bal <p> Wrapped in a thin plot, the authors set out to reconcile mathematics and religious faith in this short 280-page book. And quite surprisingly, they do a far better job than one would expect. <p> <p> Seriously, any book that tries to dig this deep philosophically is an immediate suspect for some half-baked crappy …</p></p></p> <p> Wrapped in a thin plot, the authors set out to reconcile mathematics and religious faith in this short 280-page book. And quite surprisingly, they do a far better job than one would expect. <p> <p> Seriously, any book that tries to dig this deep philosophically is an immediate suspect for some half-baked crappy ending, but "A Certain Ambiguity" manages to end by actually leaving the reader thoughtful. At this, the authors have done a splendid job. </p> <p> The main character is Ravi, an Indian student at Stanford who enrolls in a math class named "Thinking about infinity". Together with the class's lecturer and a small group of friends, he embarks on a quasi-philosophical journey, guided by court records of his grandfather's discussions with a judge in the early 20s. </p> <p> The book contains a lot of interesting math, and while most of it is on a basic level, the philosophical connections are well developed and very believable. The book could easily be a work of non-fiction, as its main theme is quite real and deals with epistemological questions real philosophers have struggled with throughout the centuries. It is unlikely to change your view of life, but it will induce some interesting thinking on important topics. </p> <p> [Spoiler] I was surprised to find out that this book does a good job of explaining faith to people with a rational/mathematical view of life.
However, it only rationalizes the core faith - judge Taylor's "creation axiom", which really can't be disproven. But, as judge Taylor tells Vijay, his deductive method is solid, and only his axioms are in question. The faith judge Taylor rationalizes as an axiom cannot, in any way, connect to the modern monotheistic religions (not to mention the polytheistic ones), because it breaks down immediately as soon as the first deductions are made from it about actual human lives. Yes, that "everything must be created by something" is an axiom that has no refutation at the moment, but any attempt to prove from it that Jesus was born to a virgin and walked on water would have to transcend deductive methods. </p> <p> All in all, this book is really recommended. It actually made me think hard about the philosophical implications of basic math axioms, and encouraged me to read more on the subject. I couldn't possibly ask more of such a small book that can be easily finished in 2-3 sittings. </p> Intersection of 1D segments2008-08-15T11:22:21-07:002008-08-15T11:22:21-07:00Eli Benderskytag:eli.thegreenplace.net,2008-08-15:/2008/08/15/intersection-of-1d-segments <p>There is a simple mathematical problem that sometimes comes up in programming<sup class="footnote"><a href="#fn1" title="I ran into it while implementing a binary application format reader, that needed to support insertion data records. Each data record has a start and an end (memory address). The problem comes up when testing whether two records collide.">1</a></sup>. The problem is:</p> <blockquote> <p>Given two one-dimensional<sup class="footnote"><a href="#fn2" title="One-dimensional here means that they only have a single coordinate, i.e. all can be laid down on a line that's parallel to one of the axes.">2</a></sup> line segments, determine whether they intersect, i.e. have points in common.</p> </blockquote> <p>Here's a graphical representation of the problem. The two segments are drawn one above the other for demonstration …</p>
The two segments are drawn one above the other for demonstration purposes:</p> <img src="https://eli.thegreenplace.net/images/2008/08/twosegs.PNG" /> <p>At first sight, this looks like a problem with many annoying corner cases that takes a lot of dirty code to implement. But it turns out that the solution is actually very simple and clean. The two segments intersect if and only if <em>X2 >= Y1 and Y2 >= X1</em>. That's it.</p> <p>It may be difficult to convince yourself this works by simply looking at the image above, so here is another that makes it much clearer:</p> <img src="https://eli.thegreenplace.net/images/2008/08/manysegs.png" /> <p>In this image we see all the possibilities of the positions of the second segment relative to the first.
It should take only a few seconds to verify that the algorithm returns a correct result for all 5 cases.</p> <p>Here's Python code that implements this solution:</p> <pre lang="python">
def segments_intersect(x1, x2, y1, y2):
    # Assumes x1 <= x2 and y1 <= y2; if this assumption is not safe, the code
    # can be changed to have x1 being min(x1, x2) and x2 being max(x1, x2),
    # and similarly for the ys.
    return x2 >= y1 and y2 >= x1
</pre> <center><img src="https://eli.thegreenplace.net/images/hline.jpg" width="320" height="5" /></center> <p class="footnote" id="fn1"><sup>1</sup> I ran into it while implementing a binary application format reader that needed to support insertion of data records. Each data record has a start and an end (memory address). The problem comes up when testing whether two records collide.</p> <p class="footnote" id="fn2"><sup>2</sup> One-dimensional here means that they only have a single coordinate, i.e. all can be laid down on a line that's parallel to one of the axes.</p> Pythagorean - the theorem with most proofs?2008-01-17T20:59:33-08:002008-01-17T20:59:33-08:00Eli Benderskytag:eli.thegreenplace.net,2008-01-17:/2008/01/17/pythagorean-the-theorem-with-most-proofs <a href="http://www.cut-the-knot.org/pythagoras/index.shtml">This page</a> shows 76 different proofs of the Pythagorean theorem. However, if this isn't hardcore enough, you may want to read "The Pythagorean Proposition" by E. S. Loomis, which lists 367 proofs. <a href="http://www.cut-the-knot.org/pythagoras/index.shtml">This page</a> shows 76 different proofs of the Pythagorean theorem. However, if this isn't hardcore enough, you may want to read "The Pythagorean Proposition" by E. S. Loomis, which lists 367 proofs. 
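As a quick sanity check of the segment-intersection test from the post above, here is a short sketch that exercises it on all five relative positions of the second segment (the specific coordinates are made up for illustration):

```python
def segments_intersect(x1, x2, y1, y2):
    # Two 1D segments [x1, x2] and [y1, y2] (with x1 <= x2, y1 <= y2)
    # intersect iff x2 >= y1 and y2 >= x1.
    return x2 >= y1 and y2 >= x1

# First segment fixed at [3, 7]; second segment placed in each of the
# five possible relative positions from the second image.
cases = [
    ((0, 2),  False),  # entirely to the left
    ((1, 4),  True),   # overlapping the left end
    ((4, 6),  True),   # contained inside
    ((6, 9),  True),   # overlapping the right end
    ((8, 10), False),  # entirely to the right
]

for (y1, y2), expected in cases:
    assert segments_intersect(3, 7, y1, y2) == expected
```

Note that a shared endpoint counts as an intersection here, since the comparisons are non-strict.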
Solution of the Two Envelopes paradox2007-04-08T18:37:19-07:002007-04-08T18:37:19-07:00Eli Benderskytag:eli.thegreenplace.net,2007-04-08:/2007/04/08/solution-of-the-two-envelopes-paradox <p> A long time ago I wrote about the <a href="http://eli.thegreenplace.net/2003/10/24/a-probability-paradox/">Two Envelopes</a> paradox. </p><p> Once you understand the solution, it's hard to see why the paradox is so controversial and so widely misunderstood. As Dominus beautifully explains <a href="http://blog.plover.com/math/envelope.html">here</a>, the solution is: </p><p> There is a fundamental mistake in the reasoning of "50% chance of …</p> <p> A long time ago I wrote about the <a href="http://eli.thegreenplace.net/2003/10/24/a-probability-paradox/">Two Envelopes</a> paradox. </p><p> Once you understand the solution, it's hard to see why the paradox is so controversial and so widely misunderstood. As Dominus beautifully explains <a href="http://blog.plover.com/math/envelope.html">here</a>, the solution is: </p><p> There is a fundamental mistake in the reasoning of "50% chance of the sum in the other envelope being larger". This statement is based on the assumption that the sums are chosen uniformly at random from -inf to +inf (think about it: otherwise, how can we say that there is exactly a 50% chance of *any* number we see in one envelope being the smaller amount). However, <strong>there is no uniform random distribution from -inf to +inf</strong>. That is because in a <a href="http://en.wikipedia.org/wiki/Uniform_distribution_%28continuous%29">uniform distribution</a>, the probability density function is constant, and an integral from -inf to +inf over a constant doesn't converge. That's all there is to it. Simple. </p><p> So back to the original question. When you open one of the envelopes and find some sum of money in there, does it pay you to switch? It doesn't - because you don't know what algorithm / distribution is used to pick the numbers. 
The switching argument doesn't work because it is based on a fallacious assumption of a uniform distribution. </p> Sum of digits and divisibility by 32006-07-11T17:50:10-07:002006-07-11T17:50:10-07:00Eli Benderskytag:eli.thegreenplace.net,2006-07-11:/2006/07/11/sum-of-digits-and-divisibility-by-3 It's a known math curiosity that when we take a decimal (base 10) number and add its digits together, if the sum is divisible by 3 without a remainder, then the number itself is also divisible by 3 without a remainder. Example: <p> <pre><tt>
426 -> 4 + 2 + 6 = 12
12 (mod 3) = 0 …</tt></pre></p> It's a known math curiosity that when we take a decimal (base 10) number and add its digits together, if the sum is divisible by 3 without a remainder, then the number itself is also divisible by 3 without a remainder. Example: <p> <pre><tt>
426 -> 4 + 2 + 6 = 12
12 (mod 3) = 0     (12 / 3 = 4)
426 (mod 3) = 0    (426 / 3 = 142)
</tt></pre> <p> This fact is actually quite simple to prove. Consider the breakdown of 426 into multiples of powers of 10: <p> <pre><tt>426 = 4 * 100 + 2 * 10 + 6</tt></pre> <p> Written another way: <pre><tt>
426 = 4 * (99 + 1) + 2 * (9 + 1) + 6
    = (4 * 99 + 2 * 9) + (4 + 2 + 6)
</tt></pre> Since (4 * 99 + 2 * 9) is divisible by 3 (both 99 and 9 are), this clearly shows that for 426 to be divisible by 3, (4 + 2 + 6) must be divisible by 3. If this isn't immediately obvious, recall that: <ol> <li>If <code>X (mod N) = 0</code>, then <code>Y * X (mod N) = 0</code> for any Y <li>If <code>X (mod N) = 0</code> and <code>Y (mod N) = 0</code>, then <code>X + Y (mod N) = 0</code> for