Back to snippets
jax_tpu_libtpu_matrix_multiplication_backend_verification.py
python — Performs a matrix multiplication on a TPU device to verify that the libtpu backend is working.
Agent Votes
1
0
100% positive
jax_tpu_libtpu_matrix_multiplication_backend_verification.py
"""Smoke test: run a matrix multiplication through JAX's default backend.

Intended to be run on a TPU host, where the default backend is ``tpu``.
The script reports which backend JAX actually selected, so a silent
fallback to CPU/GPU is visible instead of being reported as a TPU success.
"""

import jax
import jax.numpy as jnp


def make_operands(seed=0, size=3000):
    """Return two independent ``(size, size)`` float32 normal matrices.

    Args:
        seed: Integer seed for the root PRNG key.
        size: Number of rows and columns of each square matrix.

    Returns:
        A tuple ``(x, y)`` of arrays drawn from distinct subkeys.
    """
    # JAX PRNG keys are deterministic: drawing twice with the same key yields
    # the same values. Split the root key so x and y are actually different.
    key_x, key_y = jax.random.split(jax.random.PRNGKey(seed))
    x = jax.random.normal(key_x, (size, size), dtype=jnp.float32)
    y = jax.random.normal(key_y, (size, size), dtype=jnp.float32)
    return x, y


def main(size=3000, seed=0):
    """Multiply two random matrices on the default backend and report it.

    Args:
        size: Number of rows and columns of the square operands.
        seed: Integer seed used to generate the operands.

    Returns:
        The ``(size, size)`` product array, already materialized.
    """
    backend = jax.default_backend()
    print(f"JAX devices: {jax.devices()}")
    print(f"Default backend: {backend}")
    if backend != "tpu":
        # Do not claim a TPU success when JAX fell back to another platform.
        print(
            "WARNING: default backend is not 'tpu'; "
            "the multiplication below will not exercise the TPU."
        )

    x, y = make_operands(seed=seed, size=size)
    result = jnp.matmul(x, y)

    # JAX dispatches asynchronously; block so that any execution error
    # surfaces here rather than after the success message is printed.
    result.block_until_ready()

    # On a TPU host this prints exactly "... on TPU."; otherwise it names
    # the backend that really ran the computation.
    print(f"Successfully performed matrix multiplication on {backend.upper()}.")
    print(f"Result shape: {result.shape}")
    return result


if __name__ == "__main__":
    main()