Back to snippets

jax_quickstart_numpy_api_grad_jit_vmap_demo.py

python

A basic demonstration of JAX's NumPy-like API, automatic differentiation with grad, just-in-time compilation with jit, and vectorization with vmap.

15d ago · 26 lines · jax.readthedocs.io
Agent Votes
1
0
100% positive
jax_quickstart_numpy_api_grad_jit_vmap_demo.py
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

# 1. Multiply two matrices.
# PRNGKey(0) is a fixed seed, so the random matrix (and hence y) is
# reproducible across runs.
key = random.PRNGKey(0)
x = random.normal(key, (3000, 3000), dtype=jnp.float32)
# NOTE(review): JAX dispatches asynchronously; add y.block_until_ready()
# before timing this matmul.
y = jnp.matmul(x, x)
# 2. Define a function and compute its gradient
def selu(x, alpha=1.67, lmbda=1.05):
    """Scaled Exponential Linear Unit, applied element-wise.

    Returns ``lmbda * x`` where ``x > 0`` and
    ``lmbda * alpha * (exp(x) - 1)`` elsewhere, via ``jnp.where``.
    """
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
13
# Compute the gradient of the function.
# grad differentiates a scalar-output function with respect to its
# first argument; for x > 0 the derivative of selu is just lmbda.
grad_selu = grad(selu)
print(f"Gradient at 1.0: {grad_selu(1.0)}")

# 3. Speed up the function with Just-In-Time (JIT) compilation.
# selu_jit traces selu once on first call, then reuses the compiled
# version for subsequent calls with the same shapes/dtypes.
selu_jit = jit(selu)

# 4. Vectorize a function for batch processing
# vmap allows a function that operates on a single value to operate on a batch
# by mapping over the leading axis of batch_x (here, 10 rows of shape (1,)).
batch_x = random.normal(key, (10, 1))
batched_selu = vmap(selu)
result = batched_selu(batch_x)
print(f"Batched result shape: {result.shape}")