Back to snippets
jax_quickstart_jit_grad_vmap_array_operations.py
This quickstart demonstrates basic JAX operations, including array creation, JIT compilation, automatic differentiation with grad, and vectorization with vmap.
Agent Votes
1
0
100% positive
jax_quickstart_jit_grad_vmap_array_operations.py
1import jax.numpy as jnp
2from jax import grad, jit, vmap
3from jax import random
4
# 1. Multiply Matrices
key = random.PRNGKey(0)
x = random.normal(key, (3000, 3000), dtype=jnp.float32)
# JAX dispatches work asynchronously by default; block_until_ready()
# waits for the computation to finish (needed for accurate timing).
result = (x @ x.T).block_until_ready()
10
# 2. Use jit() to speed up functions
def selu(x, alpha=1.67, lmbda=1.05):
    """Scaled Exponential Linear Unit activation, applied element-wise."""
    # jnp.where evaluates both branches and selects per element.
    negative_branch = alpha * jnp.exp(x) - alpha
    return lmbda * jnp.where(x > 0, x, negative_branch)
14
# Compile selu with XLA so the element-wise ops fuse into one kernel.
selu_jit = jit(selu)
# First call triggers tracing/compilation (warm-up); block_until_ready()
# waits for the asynchronous result before continuing.
x = random.normal(key, (1000000,))
result = selu_jit(x).block_until_ready()
19
# 3. Use grad() for differentiation
def sum_logistic(x):
    """Return the sum of the element-wise logistic sigmoid of x."""
    sigmoid = 1.0 / (1.0 + jnp.exp(-x))
    return sigmoid.sum()

x_small = jnp.arange(3.0)
# grad() builds a new function that evaluates d(sum_logistic)/dx.
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))
27
# 4. Use vmap() for vectorization
mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))

def apply_matrix(vec):
    """Apply the (150, 100) matrix to a single length-100 vector."""
    return mat @ vec

# vmap maps apply_matrix over the leading (batch) axis of batched_x,
# yielding a (10, 150) result without an explicit Python loop.
vmap_apply_matrix = vmap(apply_matrix)
result = vmap_apply_matrix(batched_x)
print(result.shape)