
Explore the inner workings of JAX through "Autodidax," a beginner-friendly version of JAX’s core system, and unravel the mysteries behind transformations as interpreters in this detailed tutorial.
If you’ve ever been curious about how JAX works under the hood but found its implementation daunting, this tutorial is for you. By diving into "Autodidax," a simplified version of JAX’s core system, you’ll gain insights into the fundamental ideas and jargon used in JAX. This article covers Part 1, focusing on transformations as interpreters, including standard evaluation, jvp, and vmap.
In JAX, functions are transformed by interpreting them differently. Instead of just applying primitive operations to numerical inputs to produce outputs, we can override how those primitives are applied and let different kinds of values flow through the program. For example, we might carry a tangent alongside each value so the program also computes a directional derivative (JVP), or carry a whole batch of values so it is vectorized (vmap).
Consider a simple function:
```python
def f(x):
    y = sin(x) * 2.
    z = -y + x
    return z
```
Here, sin, multiplication (*), addition (+), and negation (-) are primitive operations. JAX allows us to transform this function by interpreting these primitives in different ways.
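One detail is easy to miss: *, + and unary - are Python operators, so they only reach the mul, add and neg primitives if the values flowing through f are objects that overload those operators. Autodidax handles this with its tracer classes. The following is a minimal sketch of the idea, using a hypothetical Traced wrapper whose only "interpretation" is to log each primitive before evaluating it (Traced, apply_prim and IMPLS are illustrative names, not part of JAX):

```python
import math

log = []  # records each primitive application, in order

# Hypothetical evaluation rules, keyed by primitive name.
IMPLS = {'add': lambda x, y: x + y, 'mul': lambda x, y: x * y,
         'neg': lambda x: -x, 'sin': math.sin}

def unwrap(v):
    return v.val if isinstance(v, Traced) else v

def apply_prim(name, *args):
    # The "interpretation": log the primitive, then evaluate it on unwrapped values.
    log.append(name)
    return Traced(IMPLS[name](*map(unwrap, args)))

class Traced:
    """Hypothetical wrapper whose operators are routed to named primitives."""
    def __init__(self, val):
        self.val = val
    def __add__(self, other): return apply_prim('add', self, other)
    __radd__ = __add__
    def __mul__(self, other): return apply_prim('mul', self, other)
    __rmul__ = __mul__
    def __neg__(self): return apply_prim('neg', self)

def sin(x): return apply_prim('sin', x)

def f(x):
    y = sin(x) * 2.
    z = -y + x
    return z

out = f(Traced(3.0))
print(log)       # ['sin', 'mul', 'neg', 'add']
print(out.val)   # equals 3 - 2*sin(3)
```

Swapping the logging for a different rule, such as propagating a derivative, changes what f computes without touching f itself.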
To implement transformations, we need a way to intercept and override the application of these primitives. We start by defining them as Primitive objects:
```python
from typing import NamedTuple

class Primitive(NamedTuple):
    name: str

add_p = Primitive('add')
mul_p = Primitive('mul')
neg_p = Primitive('neg')
sin_p = Primitive('sin')
cos_p = Primitive('cos')
reduce_sum_p = Primitive('reduce_sum')
greater_p = Primitive('greater')
less_p = Primitive('less')
transpose_p = Primitive('transpose')
broadcast_p = Primitive('broadcast')

def add(x, y): return bind1(add_p, x, y)
def mul(x, y): return bind1(mul_p, x, y)
def neg(x): return bind1(neg_p, x)
def sin(x): return bind1(sin_p, x)
def cos(x): return bind1(cos_p, x)
```
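The bind1 function stands in for the binding logic: rather than computing anything itself, it applies a primitive to the given inputs using whichever interpretation is currently active. Because every primitive application funnels through bind1, changing what bind1 does is enough to intercept and transform the whole program.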
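In Autodidax, bind (of which bind1 is the single-output convenience form) hands each primitive to an interpreter instead of evaluating it. A minimal sketch of that dispatch follows, assuming a simplified global stack where the top interpreter always wins (Autodidax instead picks the interpreter from the levels of the argument tracers). CountingInterpreter is a hypothetical second interpretation that counts primitives before evaluating them:

```python
import math
from typing import NamedTuple

class Primitive(NamedTuple):
    name: str

add_p, mul_p, neg_p, sin_p = (Primitive(n) for n in ['add', 'mul', 'neg', 'sin'])

class EvalInterpreter:
    """Standard evaluation: apply each primitive to concrete numbers."""
    rules = {add_p: lambda x, y: x + y, mul_p: lambda x, y: x * y,
             neg_p: lambda x: -x, sin_p: math.sin}
    def process_primitive(self, prim, args):
        return self.rules[prim](*args)

class CountingInterpreter(EvalInterpreter):
    """Hypothetical interpreter: counts each primitive, then evaluates it normally."""
    def __init__(self):
        self.counts = {}
    def process_primitive(self, prim, args):
        self.counts[prim.name] = self.counts.get(prim.name, 0) + 1
        return super().process_primitive(prim, args)

# Simplified, global interpreter stack; plain evaluation sits at the bottom.
interpreter_stack = [EvalInterpreter()]

def bind1(prim, *args):
    # Hand the primitive to whichever interpreter is on top of the stack.
    return interpreter_stack[-1].process_primitive(prim, args)

def add(x, y): return bind1(add_p, x, y)
def mul(x, y): return bind1(mul_p, x, y)
def neg(x): return bind1(neg_p, x)
def sin(x): return bind1(sin_p, x)

def g(x):  # f from above, written with explicit wrapper calls
    return add(neg(mul(sin(x), 2.0)), x)

print(g(1.0))                    # plain evaluation
counter = CountingInterpreter()
interpreter_stack.append(counter)
g(1.0)                           # same program, different interpretation
interpreter_stack.pop()
print(counter.counts)            # {'sin': 1, 'mul': 1, 'neg': 1, 'add': 1}
```

In this sketch, pushing an interpreter, running the function and popping the interpreter again is the whole shape of a transformation.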

In standard evaluation, we simply apply the primitive operations as they are defined:
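```python
import numpy as np

# Standard evaluation: apply each primitive directly to concrete NumPy values.
impl_rules = {
    add_p: np.add,
    mul_p: np.multiply,
    neg_p: np.negative,
    sin_p: np.sin,
    cos_p: np.cos,
    # Add more primitives as needed
}

def bind1(prim, *args):
    return impl_rules[prim](*args)
```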
This is the default behavior, but we can override it for transformations.
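The JVP (Jacobian-vector product) transformation computes a function's output together with its directional derivative. Each value now travels as a (primal, tangent) pair, and instead of applying the primitive directly, we apply its JVP rule, which maps the input primals and tangents to the output primal and tangent:
```python
# Each arg is a (primal, tangent) pair; each rule returns (primal_out, tangent_out).
def bind1(prim, *args):
    (x, x_dot), *rest = args
    if prim == add_p:
        (y, y_dot), = rest
        return x + y, x_dot + y_dot
    elif prim == mul_p:
        (y, y_dot), = rest
        return x * y, x_dot * y + x * y_dot  # product rule
    elif prim == neg_p:
        return -x, -x_dot
    # Add more primitives as needed, e.g. sin: (sin(x), cos(x) * x_dot)
```
Each rule is the familiar differentiation rule for its primitive: add passes tangents straight through, and mul applies the product rule. Evaluating a function under these rules yields its output and its derivative along the input tangent in a single forward pass, without ever building a symbolic derivative.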
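Tracing f through these rules by hand shows what the JVP interpreter computes. Below, s names the intermediate sin(x), n names -y, and a dot marks each tangent. The constant 2. has a zero tangent, so the product rule reduces to scaling:

$$
\begin{aligned}
s &= \sin x, & \dot s &= \cos x \,\dot x \\
y &= 2s, & \dot y &= 2\,\dot s \\
n &= -y, & \dot n &= -\dot y \\
z &= n + x, & \dot z &= \dot n + \dot x = (1 - 2\cos x)\,\dot x
\end{aligned}
$$

With an input tangent of 1, the output tangent is f'(x) = 1 - 2cos(x), produced in the same forward pass as f(x) itself.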
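The vmap transformation vectorizes a function so that it handles a whole batch of inputs at once. Rather than looping and applying the function to each element, each value carries an extra batch axis whose position is recorded as its batch dimension (None for unbatched values), and each primitive's batching rule applies the primitive to the entire batch:
```python
import numpy as np
# Each arg is a (value, batch_dim) pair; batch_dim is None for unbatched values.
# Assumes operands share a per-example shape, so unbatched ones simply broadcast.
def bind1(prim, *args):
    vals = [v if d is None else np.moveaxis(v, d, 0) for v, d in args]  # batch axis -> 0
    out_bdim = None if all(d is None for _, d in args) else 0
    if prim == add_p:
        return np.add(*vals), out_bdim
    elif prim == mul_p:
        return np.multiply(*vals), out_bdim
    # Add more primitives as needed
```
Because each primitive runs once over the whole batch instead of once per example, the vectorized function needs no Python-level loop.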
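Elementwise primitives such as add and mul only need their operands' batch axes aligned, but a primitive with an axis parameter, like reduce_sum_p from the primitive list, must account for where the batch axis sits. Below is a minimal sketch of two batching rules written as standalone NumPy functions on (array, batch_dim) pairs. The function names are illustrative, and the reduce_sum rule is modeled on the axis arithmetic Autodidax uses, assuming a single reduction axis:

```python
import numpy as np

def move_bdim_to_front(x, bdim):
    """Return x with its batch axis at position 0 (unbatched values pass through)."""
    return x if bdim is None else np.moveaxis(x, bdim, 0)

def add_batching_rule(x, x_bdim, y, y_bdim):
    # Elementwise: put both batch axes at position 0, then apply np.add once.
    # Assumes operands share a per-example shape (unbatched ones broadcast).
    out_bdim = None if x_bdim is None and y_bdim is None else 0
    return np.add(move_bdim_to_front(x, x_bdim), move_bdim_to_front(y, y_bdim)), out_bdim

def reduce_sum_batching_rule(x, x_bdim, axis):
    # `axis` refers to the per-example shape: skip over the batch axis if needed,
    # and shift the batch axis down if the reduced axis sat in front of it.
    # Assumes x is batched (x_bdim is not None).
    new_axis = axis + (x_bdim <= axis)
    out_bdim = x_bdim - (axis < x_bdim)
    return np.sum(x, axis=new_axis), out_bdim

rng = np.random.default_rng(0)

# A batch of 5 per-example (3, 4) matrices, stored with the batch axis at position 1.
xs = rng.normal(size=(3, 5, 4))
out, out_bdim = reduce_sum_batching_rule(xs, 1, axis=1)
looped = np.stack([np.sum(xs[:, i, :], axis=1) for i in range(5)])
print(out.shape, out_bdim)                                      # (3, 5) 1
print(np.allclose(move_bdim_to_front(out, out_bdim), looped))  # True

# Adding two batches whose batch axes sit in different places.
a = rng.normal(size=(3, 5))  # batch axis 1
b = rng.normal(size=(5, 3))  # batch axis 0
s, s_bdim = add_batching_rule(a, 1, b, 0)
print(np.allclose(s, np.stack([a[:, i] + b[i] for i in range(5)])))  # True
```

Comparing the batched result against an explicit per-example loop, as done here, is a cheap way to validate a new batching rule.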
JAX can compose multiple transformations, leading to stacks of interpreters. For example, we might want to compute the JVP and then vectorize the result:
def transformed
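Composition falls out of the same design once interpreters sit on a stack. In Autodidax, each boxed value (a tracer) remembers the level of the interpreter that created it, bind dispatches to the highest level among its arguments, and the JVP and batching rules are themselves written with the primitive wrappers, so whatever they compute is handed down to the interpreter below. The self-contained sketch below illustrates that arrangement. The classes, the stack list and these jvp/vmap functions are simplified, hypothetical stand-ins rather than Autodidax's actual code: only elementwise primitives are handled, and batch axes always sit at position 0.

```python
import numpy as np
from typing import NamedTuple

class Primitive(NamedTuple):
    name: str

add_p, mul_p, neg_p = Primitive('add'), Primitive('mul'), Primitive('neg')
sin_p, cos_p = Primitive('sin'), Primitive('cos')

def bind1(prim, *args):
    # Dispatch to the highest-level interpreter that boxed any argument (0 = eval).
    level = max((a.level for a in args if isinstance(a, Tracer)), default=0)
    return stack[level].process(prim, args)

def add(x, y): return bind1(add_p, x, y)
def mul(x, y): return bind1(mul_p, x, y)
def neg(x): return bind1(neg_p, x)
def sin(x): return bind1(sin_p, x)
def cos(x): return bind1(cos_p, x)

class Tracer:
    """A boxed value whose operators route back through the primitive wrappers."""
    def __add__(self, other): return add(self, other)
    __radd__ = __add__
    def __mul__(self, other): return mul(self, other)
    __rmul__ = __mul__
    def __neg__(self): return neg(self)

class EvalInterpreter:
    rules = {add_p: np.add, mul_p: np.multiply, neg_p: np.negative,
             sin_p: np.sin, cos_p: np.cos}
    def process(self, prim, args):
        return self.rules[prim](*args)

stack = [EvalInterpreter()]  # interpreter stack; list index = level

class JVPTracer(Tracer):
    def __init__(self, level, primal, tangent):
        self.level, self.primal, self.tangent = level, primal, tangent

class JVPInterpreter:
    def __init__(self, level):
        self.level = level
    def lift(self, x):  # values from lower levels are constants: zero tangent
        mine = isinstance(x, JVPTracer) and x.level == self.level
        return x if mine else JVPTracer(self.level, x, 0.0)
    def process(self, prim, args):
        xs = [self.lift(a) for a in args]
        p, t = [x.primal for x in xs], [x.tangent for x in xs]
        # The rules call the primitive wrappers, so lower interpreters handle them.
        if prim == add_p: out, dout = add(*p), add(*t)
        elif prim == mul_p: out, dout = mul(*p), add(mul(t[0], p[1]), mul(p[0], t[1]))
        elif prim == neg_p: out, dout = neg(p[0]), neg(t[0])
        elif prim == sin_p: out, dout = sin(p[0]), mul(cos(p[0]), t[0])
        elif prim == cos_p: out, dout = cos(p[0]), neg(mul(sin(p[0]), t[0]))
        return JVPTracer(self.level, out, dout)

class BatchTracer(Tracer):
    def __init__(self, level, val):  # the batch axis is always axis 0 here
        self.level, self.val = level, val

class BatchInterpreter:
    def __init__(self, level):
        self.level = level
    def process(self, prim, args):
        # Elementwise primitives only; unbatched operands broadcast against the batch.
        vals = [a.val if isinstance(a, BatchTracer) and a.level == self.level else a
                for a in args]
        return BatchTracer(self.level, bind1(prim, *vals))

def jvp(f, primal, tangent):
    level = len(stack)
    stack.append(JVPInterpreter(level))
    try:
        out = stack[level].lift(f(JVPTracer(level, primal, tangent)))
        return out.primal, out.tangent
    finally:
        stack.pop()

def vmap(f, xs):
    level = len(stack)
    stack.append(BatchInterpreter(level))
    try:
        out = f(BatchTracer(level, xs))
        return out.val if isinstance(out, BatchTracer) and out.level == level else out
    finally:
        stack.pop()

def f(x):
    y = sin(x) * 2.
    z = -y + x
    return z

xs = np.linspace(0.0, 3.0, 4)
print(jvp(f, 3.0, 1.0))                       # (f(3), f'(3)), where f'(x) = 1 - 2 cos(x)
print(vmap(lambda x: jvp(f, x, 1.0)[1], xs))  # f'(x) at every point of xs, no loop
```

Because the batching interpreter simply forwards each primitive to the level below, the opposite nesting, jvp(lambda x: vmap(f, x), xs, np.ones_like(xs)), also works in this sketch and gives the same derivatives. The stack order only changes which interpreter's tracers wrap the other's.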
About the author
Kai built ML infrastructure at a Bay Area startup before developing an obsession with transformer architectures and inference optimisation that eventually pulled him out of product work entirely. A stint at a compute research lab sharpened his instinct for what actually matters in a model release versus what is marketing. He writes from the inside — from the perspective of someone who has debugged the systems he is describing at three in the morning. He is allergic to hype and instinctively drawn to the unglamorous plumbing questions that everyone else skips over.
More from The Engineer →This Week's Edition
30 April 2024