
jaxtyping_runtime_array_shape_dtype_validation_with_typeguard.py

python

A basic example of using jaxtyping with a runtime type checker (typeguard or beartype) to validate JAX array shapes and dtypes.

google/jaxtyping · 15d ago
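```python
import jax.numpy as jnp
from jaxtyping import Array, Float, jaxtyped
from typeguard import typechecked as typechecker  # or: from beartype import beartype as typechecker


# @jaxtyped makes the named dimensions ("batch", "dim1", "dim2") share one
# binding across all arguments and the return value of a single call;
# the wrapped typechecker performs the actual isinstance checks.
@jaxtyped(typechecker=typechecker)
def matrix_multiply(
    x: Float[Array, "batch dim1"],
    y: Float[Array, "dim1 dim2"],
) -> Float[Array, "batch dim2"]:
    return jnp.matmul(x, y)


# Usage
x = jnp.ones((3, 4))
y = jnp.ones((4, 5))
z = matrix_multiply(x, y)  # passes: batch=3, dim1=4, dim2=5, so z.shape == (3, 5)

# This would fail the shape check at runtime, before jnp.matmul runs:
# x binds dim1=2, but y's first axis is 4. (The exception class depends on
# which typechecker is used.)
# x_wrong = jnp.ones((3, 2))
# z = matrix_multiply(x_wrong, y)

# Float also constrains the dtype, so an integer array is rejected as well:
# x_int = jnp.ones((3, 4), dtype=jnp.int32)
# z = matrix_multiply(x_int, y)
```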
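To see why the mismatched call is rejected, it helps to look at how named dimensions behave: each name is bound to a size the first time it appears and must match everywhere after that within one call. In jaxtyping, this shared binding across arguments is what `jaxtyped` provides. The helper below, `check_shapes`, is a hypothetical toy written only to illustrate that idea on plain shape tuples. It is not jaxtyping's implementation and ignores dtypes, variadic (`...`, `*name`) and fixed-size dimensions.

```python
# Hypothetical toy helper, NOT jaxtyping's implementation: it only sketches
# how named dimensions are bound on first use and checked for consistency.


def check_shapes(specs, shapes):
    """Check shape tuples against space-separated dimension-name specs.

    specs:  list of specs such as "batch dim1" (one per array)
    shapes: list of shape tuples, one per spec
    Returns a dict mapping each dimension name to its bound size.
    Raises TypeError on a rank mismatch or an inconsistent dimension.
    """
    bound = {}
    for spec, shape in zip(specs, shapes):
        names = spec.split()
        if len(names) != len(shape):
            raise TypeError(f"expected rank {len(names)} for {spec!r}, got shape {shape}")
        for name, size in zip(names, shape):
            # setdefault binds the name on first sight and returns the
            # existing binding on every later occurrence.
            if bound.setdefault(name, size) != size:
                raise TypeError(f"dimension {name!r}: bound to {bound[name]}, got {size}")
    return bound


# Same specs as matrix_multiply's two inputs.
specs = ["batch dim1", "dim1 dim2"]
print(check_shapes(specs, [(3, 4), (4, 5)]))  # {'batch': 3, 'dim1': 4, 'dim2': 5}

try:
    check_shapes(specs, [(3, 2), (4, 5)])
except TypeError as err:
    print(err)  # dimension 'dim1': bound to 2, got 4
```

The first argument fixes `dim1 = 2`, so the second argument's leading axis of 4 is reported as a conflict; this mirrors the `x_wrong` case in the snippet.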