
Jaxtyping 1.2 brings more detailed error messages and advanced type annotations, helping developers quickly identify and resolve tensor shape issues at runtime.
In a significant update to the popular jaxtyping library, version 1.2 improves runtime type-checking with more detailed error messages. This matters for anyone working with tensors in PyTorch, TensorFlow, NumPy, or JAX. Let's dive into what has changed and why it matters.
The latest release of jaxtyping brings a more robust runtime type-checking mechanism, which is particularly useful for catching shape errors early in development. Here are the key changes:
- jaxtyping supports PyTorch, TensorFlow, NumPy, and JAX, with no dependency on JAX itself.

If you've ever written code using tensors in frameworks like PyTorch or JAX, you've likely encountered shape errors. These errors are often silent: broadcasting can turn mismatched shapes into a valid but wrong result, leading to frustrating downstream issues such as models failing to train for no clear reason. For example:
```python
from torch import Tensor

def f(x: Tensor, y: Tensor):
    # x and y must be one-dimensional
    # x and y must have the same size
    ...
```
In this snippet, the requirements for x and y are only documented in comments. If you or another developer violate these constraints, the code might still run but produce incorrect results.
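To see how such a violation can slip through, here is a minimal sketch in NumPy, whose broadcasting rules PyTorch follows. The body `x + y` is a hypothetical stand-in for whatever `f` actually computes; the point is only that nothing enforces the commented contract.

```python
import numpy as np

def f(x, y):
    # Contract (documented only in prose): x and y are 1-D with equal size.
    return x + y

x = np.zeros(3)       # shape (3,): as intended
y = np.zeros((3, 1))  # shape (3, 1): breaks the contract
out = f(x, y)         # no error is raised
print(out.shape)      # (3, 3): broadcasting silently expanded the result
```

A caller expecting a length-3 vector now holds a 3×3 matrix, and the mistake only surfaces, if at all, somewhere downstream.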
jaxtyping allows you to specify tensor shapes and data types directly in function signatures. This not only makes your code more readable but also ensures that tensors meet the required conditions at runtime. Here’s how it works:
```python
from jaxtyping import Float
from torch import Tensor

def f(x: Float[Tensor, "channels"], y: Float[Tensor, "channels"]): ...
```
- `Float[Tensor, "channels"]` indicates that `x` and `y` are one-dimensional tensors with a floating-point data type.
- Because both annotations use the name `"channels"`, a runtime type checker will require the two tensors to have the same size along this dimension.
**Declaring Shape Consistency Across Multiple Arguments**
By using the same string for shape annotations, you can enforce consistency across multiple arguments. For example:
```python
from jaxtyping import Float
from torch import Tensor

def g(x: Float[Tensor, "height width"],
      y: Float[Tensor, "height width"]):
    ...
```

`x` and `y` must have the same shape `(height, width)`.

To check these annotations at runtime, you can use the beartype library for type-checking and `jaxtyped` from jaxtyping to integrate it seamlessly:

```python
from beartype import beartype
from jaxtyping import Float, jaxtyped
from torch import Tensor

@jaxtyped(typechecker=beartype)
def f(x: Float[Tensor, "channels"],
      y: Float[Tensor, "channels"]):
    ...
```
Let's see how jaxtyping catches shape errors with a practical example:
```python
from beartype import beartype
from jaxtyping import Float, jaxtyped
from torch import Tensor, zeros

@jaxtyped(typechecker=beartype)
def f(x: Float[Tensor, "channels"],
      y: Float[Tensor, "channels"]):
    ...

# Attempting to call the function with tensors of different sizes
f(zeros(3), zeros(4))
```
This raises a `jaxtyping.TypeCheckError` that identifies the offending argument and the expected shape, along the lines of:

```
jaxtyping.TypeCheckError: Type-check failed for argument 'y' in function f.
Expected shape (channels,) but got shape (4,).
```
About the author
Kai built ML infrastructure at a Bay Area startup before developing an obsession with transformer architectures and inference optimisation that eventually pulled him out of product work entirely. A stint at a compute research lab sharpened his instinct for what actually matters in a model release versus what is marketing. He writes from the inside — from the perspective of someone who has debugged the systems he is describing at three in the morning. He is allergic to hype and instinctively drawn to the unglamorous plumbing questions that everyone else skips over.
19 December 2023