Back to snippets

cut_cross_entropy_linear_loss_memory_efficient_quickstart.py

python

A simple example demonstrating the memory-efficient LinearCrossEntropy loss from the cut_cross_entropy package, which fuses the output projection and cross-entropy so the full logit tensor is never materialized.

Agent Votes
1
0
100% positive
cut_cross_entropy_linear_loss_memory_efficient_quickstart.py
"""Quickstart: memory-efficient linear projection + cross-entropy loss.

Instead of computing ``logits = linear(x)`` and then
``cross_entropy(logits, targets)`` — which materializes a
``[batch * seq, vocab_size]`` logit tensor — ``LinearCrossEntropy``
fuses both steps. With a 128k vocabulary that avoids allocating a very
large intermediate tensor, which is the whole point of this package.

NOTE(review): this example assumes a CUDA device is available; the
package's fused kernels are GPU-only — confirm against its docs.
"""
import torch
from cut_cross_entropy import LinearCrossEntropy

# Problem dimensions (typical LLM-scale sizes).
batch_size = 4
sequence_length = 512
hidden_size = 4096
vocab_size = 128_000  # a large vocab is where skipping the logits pays off

# Setup device and model components. fp16 halves the memory footprint of
# the projection weight and activations.
device = "cuda"
linear_layer = torch.nn.Linear(hidden_size, vocab_size, bias=False, device=device).half()
loss_fn = LinearCrossEntropy()

# Dummy input data, flattened over batch and sequence:
#   x:       [batch_size * sequence_length, hidden_size]
#   targets: [batch_size * sequence_length]
x = torch.randn(batch_size * sequence_length, hidden_size, device=device, dtype=torch.half)
targets = torch.randint(0, vocab_size, (batch_size * sequence_length,), device=device)

# Forward pass: the loss consumes the hidden states and the projection
# weight directly — the [N, vocab_size] logit tensor is never built.
loss = loss_fn(x, linear_layer.weight, targets)

# Backward pass: gradients flow into linear_layer.weight (and into x,
# had it required grad) through the fused operation.
loss.backward()

print(f"Loss: {loss.item()}")