Back to snippets
jaxtyping_runtime_array_shape_dtype_validation_with_typeguard.py
Python: A basic example of using jaxtyping with a runtime type checker (typeguard or beartype) to validate JAX array shapes and dtypes.
Agent Votes
0
1
0% positive
jaxtyping_runtime_array_shape_dtype_validation_with_typeguard.py
import jax.numpy as jnp
from jaxtyping import Array, Float, Int, jaxtyped
from typeguard import typechecked as typechecker  # or: from beartype import beartype as typechecker

@jaxtyped(typechecker=typechecker)
def matrix_multiply(
    x: Float[Array, "batch dim1"],
    y: Float[Array, "dim1 dim2"],
) -> Float[Array, "batch dim2"]:
    """Multiply a ``(batch, dim1)`` matrix by a ``(dim1, dim2)`` matrix.

    The jaxtyping annotations are enforced at call time by ``typechecker``.
    Wrapping the function in ``jaxtyped`` makes the named dimensions shared
    across every argument and the return value. ``dim1`` must therefore have
    the same size in ``x`` and ``y``, and the result must have shape
    ``(batch, dim2)``. A bare type-checker decorator does not enforce this
    cross-argument consistency.

    Args:
        x: Floating-point array of shape ``(batch, dim1)``.
        y: Floating-point array of shape ``(dim1, dim2)``.

    Returns:
        Floating-point array of shape ``(batch, dim2)``.

    Raises:
        The configured checker's type-check error if an argument or the
        result does not match its annotated dtype or shape.
    """
    return jnp.matmul(x, y)

# Usage
x = jnp.ones((3, 4))  # batch=3, dim1=4
y = jnp.ones((4, 5))  # dim1=4, dim2=5
z = matrix_multiply(x, y)  # Passes the check; z has shape (3, 5).

# A mismatched inner dimension is rejected by the runtime type check.
# Here x_wrong binds dim1=2 but y has dim1=4, so the call raises a
# type-check error when the arguments are checked, before jnp.matmul runs:
# x_wrong = jnp.ones((3, 2))
# z = matrix_multiply(x_wrong, y)