accelerate_quickstart_pytorch_training_loop_multi_gpu_tpu.py

A basic example of modifying a standard PyTorch training loop with the Hugging Face Accelerate library so the same script runs on CPU, a single GPU, multiple GPUs, or TPUs without device-specific code.

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from accelerate import Accelerator

def training_loop():
    # 1. Initialize the Accelerator
    accelerator = Accelerator()

    # 2. Set up the model, optimizer, and data as usual -- no .to(device) calls needed
    model = torch.nn.Sequential(
        torch.nn.Flatten(),
        torch.nn.Linear(28 * 28, 128),
        torch.nn.ReLU(),
        torch.nn.Linear(128, 10)
    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    dataset = datasets.MNIST("./data", train=True, download=True, transform=transforms.ToTensor())
    train_dataloader = DataLoader(dataset, shuffle=True, batch_size=32)

    # 3. Prepare everything with accelerator.prepare()
    # This handles device placement (GPU/TPU) and distributed data sampling
    model, optimizer, train_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader
    )

    model.train()
    for epoch in range(5):
        for batch in train_dataloader:
            inputs, targets = batch

            outputs = model(inputs)
            loss = F.cross_entropy(outputs, targets)

            # 4. Replace loss.backward() with accelerator.backward(loss),
            # which also applies gradient scaling when mixed precision is enabled
            accelerator.backward(loss)

            optimizer.step()
            optimizer.zero_grad()

        # accelerator.print only prints on the main process,
        # avoiding duplicate output in multi-GPU/TPU runs
        accelerator.print(f"Epoch {epoch} complete")

if __name__ == "__main__":
    training_loop()
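
To actually use multiple GPUs or TPU cores, the script is launched through the Accelerate CLI rather than plain python: run accelerate config once to describe the hardware, then

accelerate launch accelerate_quickstart_pytorch_training_loop_multi_gpu_tpu.py

The unmodified script still runs under plain python for single-device debugging. In notebook environments (e.g., Colab TPUs) where a CLI launch isn't available, Accelerate's notebook_launcher can spawn the workers instead; a minimal sketch, with num_processes chosen here only for illustration:

from accelerate import notebook_launcher

# Spawn two worker processes, each executing training_loop();
# set num_processes to match the devices actually available.
notebook_launcher(training_loop, num_processes=2)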