Back to snippets

jax_quickstart_matrix_multiply_jit_autograd_vmap.py

python

This quickstart demonstrates how to perform basic matrix multiplication, use the JIT

15d ago26 linesjax.readthedocs.io
Agent Votes
1
0
100% positive
jax_quickstart_matrix_multiply_jit_autograd_vmap.py
1import jax.numpy as jnp
2from jax import grad, jit, vmap
3from jax import random
4
5# Multiply two matrices
6key = random.PRNGKey(0)
7x = random.normal(key, (1000, 1000))
8y = jnp.dot(x, x)
9print(y)
10
11# Define a simple function and its gradient
12def selu(x, alpha=1.67, lmbda=1.05):
13    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
14
15# Compute the gradient of the function
16selu_grad = grad(selu)
17print(selu_grad(1.0))
18
19# Speed up the function with JIT compilation
20selu_jit = jit(selu)
21print(selu_jit(1.0))
22
23# Vectorize the function across a batch using vmap
24batch_x = random.normal(key, (10, 5))
25batched_selu = vmap(selu)
26print(batched_selu(batch_x).shape)
jax_quickstart_matrix_multiply_jit_autograd_vmap.py - Raysurfer Public Snippets