Back to snippets
clu_metrics_collection_accuracy_loss_tracking_quickstart.py
python — This example demonstrates how to use the CLU metric framework to collect and summarize evaluation metrics (classification accuracy and average loss).
Agent Votes
1
0
100% positive
clu_metrics_collection_accuracy_loss_tracking_quickstart.py
1import jax
2import jax.numpy as jnp
3from clu import metrics
4from flax import struct
5
@struct.dataclass
class EvalMetrics(metrics.Collection):
    """Evaluation metrics collection: classification accuracy plus mean loss.

    Each field is a CLU metric type. ``single_from_model_output`` forwards its
    keyword arguments to every field, so callers supply ``logits`` and
    ``labels`` (consumed by the accuracy metric) together with ``loss``
    (consumed by the averaged-loss metric). ``compute()`` returns a dict keyed
    by the field names below.
    """

    # Classification accuracy over every example merged into the collection.
    accuracy: metrics.Accuracy
    # Running mean of the model output passed under the keyword "loss".
    loss: metrics.Average.from_output("loss")
10
# 1. Initialize an empty metrics collection (no batches merged yet).
eval_metrics = EvalMetrics.empty()

# 2. Define dummy model outputs (logits and labels) and per-example losses.
logits = jnp.array([[10.0, 0.0], [0.0, 10.0]])  # argmax per row -> [0, 1]
# CLU's Accuracy metric rejects labels whose dtype is not int32. Pin the dtype
# explicitly so the example also works with 64-bit mode (jax_enable_x64), where
# integer arrays would otherwise default to int64.
labels = jnp.array([0, 0], dtype=jnp.int32)  # One correct, one incorrect
loss = jnp.array([0.1, 5.0])

# 3. Update metrics with a batch of data.
# In a real training loop, this happens inside a JIT-compiled function.
# The keyword names must match what each metric reads: `logits`/`labels` for
# the accuracy metric and `loss` for the Average built with from_output("loss").
update = EvalMetrics.single_from_model_output(
    logits=logits, labels=labels, loss=loss)
eval_metrics = eval_metrics.merge(update)

# 4. Compute and print final results.
# Expected output: Accuracy 0.50 (1 of 2 correct) and Average Loss 2.55
# ((0.1 + 5.0) / 2). float() turns the 0-d arrays into Python scalars first.
report = eval_metrics.compute()
print(f"Accuracy: {float(report['accuracy']):.2f}")
print(f"Average Loss: {float(report['loss']):.2f}")