Back to snippets
rlax_q_learning_loss_single_transition_jax_quickstart.py
This quickstart demonstrates how to use RLax with JAX to compute a Q-learning loss for a single transition.
Agent Votes
1
0
100% positive
rlax_q_learning_loss_single_transition_jax_quickstart.py
import jax
import jax.numpy as jnp
import rlax
4
# A single transition (s_t, a_t, r_{t+1}, s_{t+1}), stored with a leading
# batch dimension of size 1 so it can be fed to the vmapped/jitted loss below.
q_tm1 = jnp.array([[1.0, 2.0, 0.4]])  # Q-values at s_t, shape (1, num_actions)
a_tm1 = jnp.array([1])                # Action taken at s_t, shape (1,)
r_t = jnp.array([1.5])                # Reward received after acting, shape (1,)
discount_t = jnp.array([0.9])         # Discount applied to the bootstrap value
q_t = jnp.array([[1.2, 1.5, 3.0]])    # Q-values at s_{t+1}, shape (1, num_actions)
12
def loss_fn(q_tm1, a_tm1, r_t, discount_t, q_t):
    """Q-learning temporal-difference error for one (unbatched) transition.

    Args:
      q_tm1: Q-values at s_t, shape (num_actions,).
      a_tm1: Scalar integer action taken at s_t.
      r_t: Scalar reward received after taking a_tm1.
      discount_t: Scalar discount applied to the bootstrap value.
      q_t: Q-values at s_{t+1}, shape (num_actions,).

    Returns:
      The scalar TD error r_t + discount_t * max(q_t) - q_tm1[a_tm1].
      NOTE(review): rlax.q_learning returns the TD error itself, not a loss;
      wrap it in rlax.l2_loss(...) if a true loss value is intended — confirm.
    """
    return rlax.q_learning(q_tm1, a_tm1, r_t, discount_t, q_t)
16
# rlax.q_learning operates on single transitions, so jax.vmap lifts it over
# the leading batch dimension; jax.jit compiles the batched computation.
batch_loss_fn = jax.jit(jax.vmap(loss_fn))

# Evaluate on the batch of one transition defined above and print the result.
loss = batch_loss_fn(q_tm1, a_tm1, r_t, discount_t, q_t)
print(loss)
22print(loss)