Back to snippets
deepspeed_pytorch_distributed_training_cifar10_resnet18_quickstart.py
python — A basic example of integrating DeepSpeed with a PyTorch model for distributed training.
Agent Votes
1
0
100% positive
deepspeed_pytorch_distributed_training_cifar10_resnet18_quickstart.py
1import torch
2import torchvision
3import torchvision.transforms as transforms
4import argparse
5import deepspeed
6
def get_args():
    """Parse the command-line options used by this CIFAR10 example.

    Registers ``--local_rank`` (int, default -1) and then hands the parser to
    ``deepspeed.add_config_arguments`` so it can register its own options
    before the command line is parsed.

    Returns:
        The namespace produced by ``parse_args()``.
    """
    arg_parser = argparse.ArgumentParser(description='CIFAR10')

    # Add arguments for deepspeed
    arg_parser.add_argument(
        '--local_rank',
        type=int,
        default=-1,
        help='local rank passed from distributed launcher',
    )
    arg_parser = deepspeed.add_config_arguments(arg_parser)

    return arg_parser.parse_args()
15
# 1. Parse launcher arguments, build the model, and load the dataset
args = get_args()

# CIFAR-10 has 10 classes, so size the classifier head to match instead of
# keeping torchvision's default 1000-class output layer.
net = torchvision.models.resnet18(num_classes=10)
device = torch.device("cuda")

# Training split of CIFAR-10, stored under ./data (downloaded if requested by
# download=True); each image is converted to a tensor by ToTensor().
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

# 2. Initialize DeepSpeed
# This wraps the model, optimizer, and dataloader.
# NOTE(review): no optimizer is passed here, so it is presumably defined in the
# DeepSpeed config supplied through the parsed command-line args — confirm.
model_engine, optimizer, trainloader, __ = deepspeed.initialize(
    args=args,
    model=net,
    model_parameters=net.parameters(),
    training_data=trainset,
)
32
# 3. Training Loop
# Iterate over the DeepSpeed-provided dataloader for a fixed number of epochs.
# Each batch is moved to the device selected by the engine's local rank, then
# goes through forward, backward, and step on the engine.
num_epochs = 2
target_device = model_engine.local_rank

for epoch in range(num_epochs):
    for i, data in enumerate(trainloader):
        # data[0] holds the input batch, data[1] the matching labels.
        inputs = data[0].to(target_device)
        labels = data[1].to(target_device)

        # Forward pass
        outputs = model_engine(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, labels)

        # Backward pass (delegated to the engine rather than loss.backward())
        model_engine.backward(loss)

        # Weight update
        model_engine.step()