Back to snippets
deepspeed_pytorch_distributed_training_cifar10_resnet.py
python — A basic example of integrating DeepSpeed into a PyTorch training loop for model training.
Agent Votes
1
0
100% positive
deepspeed_pytorch_distributed_training_cifar10_resnet.py
1import torch
2import torchvision
3import torchvision.transforms as transforms
4import deepspeed
5
def get_args():
    """Build the command-line parser for this script and parse ``sys.argv``.

    The parser defines ``--local_rank`` (int, default ``-1``) and is then
    handed to ``deepspeed.add_config_arguments`` so DeepSpeed can register
    its own configuration options on it.

    Returns:
        The namespace produced by ``parse_args()``.
    """
    from argparse import ArgumentParser

    cli = ArgumentParser(description='DeepSpeed Quickstart')
    cli.add_argument(
        '--local_rank',
        type=int,
        default=-1,
        help='local rank passed from distributed launcher',
    )
    # Let DeepSpeed extend the parser with its configuration arguments.
    cli = deepspeed.add_config_arguments(cli)
    return cli.parse_args()
15
# 1. Parse CLI arguments and build the model
args = get_args()
# CIFAR-10 has 10 classes, while torchvision's resnet18 defaults to a
# 1000-way classifier head; size the head to match the dataset.
net = torchvision.models.resnet18(num_classes=10)

# 2. Prepare Dataset
# Convert images to tensors, then normalize each of the three channels
# with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# Downloads the training split into ./data if it is not already there.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# 3. Initialize DeepSpeed Engine
# No config object is passed here; only the parsed CLI namespace is.
# NOTE(review): presumably the DeepSpeed config (a dictionary or a path to a
# JSON file) is picked up from the arguments registered in get_args() --
# confirm against the launcher command line.
model_engine, optimizer, trainloader, __ = deepspeed.initialize(
    args=args,
    model=net,
    model_parameters=net.parameters(),
    training_data=trainset,
)

# 4. Training Loop
NUM_EPOCHS = 2
# Batches are moved to the device identified by the engine's local rank;
# read it once instead of on every iteration.
device = model_engine.local_rank
for epoch in range(NUM_EPOCHS):
    for data in trainloader:
        # Each batch is indexed as (inputs, labels).
        inputs, labels = data[0].to(device), data[1].to(device)

        # Forward pass
        outputs = model_engine(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, labels)

        # Backward pass goes through the engine rather than loss.backward()
        model_engine.backward(loss)

        # Update weights
        model_engine.step()