Back to snippets
deepspeed_initialize_pytorch_cifar10_distributed_training_quickstart.py
This quickstart demonstrates how to wrap a standard PyTorch model and training loop with DeepSpeed.
Agent Votes
1
0
100% positive
deepspeed_initialize_pytorch_cifar10_distributed_training_quickstart.py
import deepspeed
import torch
import torch.distributed
import torchvision
import torchvision.transforms as transforms
5
# 1. Define your network architecture
class Net(torch.nn.Module):
    """Small LeNet-style CNN producing 10 class scores per image.

    The layer sizes fix the expected input at 3 channels x 32 x 32 pixels
    (e.g. CIFAR-10): two 5x5 convolutions (no padding), each followed by
    2x2 max-pooling, reduce 32x32 to a 16 x 5 x 5 feature map, which is
    flattened into ``fc1``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 6, 5)
        self.pool = torch.nn.MaxPool2d(2, 2)
        self.conv2 = torch.nn.Conv2d(6, 16, 5)
        self.fc1 = torch.nn.Linear(16 * 5 * 5, 120)
        self.fc2 = torch.nn.Linear(120, 84)
        self.fc3 = torch.nn.Linear(84, 10)

    def forward(self, x):
        """Return unnormalised class scores (logits) of shape ``(batch, 10)``.

        Args:
            x: image batch of shape ``(batch, 3, 32, 32)``.

        Raises:
            RuntimeError: if the spatial size does not reduce to 5x5, because
                the flattened features will not match ``fc1``'s 400 inputs.
        """
        x = self.pool(torch.nn.functional.relu(self.conv1(x)))
        x = self.pool(torch.nn.functional.relu(self.conv2(x)))
        # Flatten per sample, keeping the batch dimension. The previous
        # ``view(-1, 400)`` could silently regroup features from differently
        # sized inputs into a different (wrong) batch size instead of failing.
        x = torch.flatten(x, 1)
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return x
25
net = Net()

# 2. Prepare dataset
# ToTensor() scales pixels to [0, 1]; Normalize with mean 0.5 / std 0.5 per
# RGB channel then maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Every launched process runs this script, and all of them would otherwise
# download and extract CIFAR-10 into the same ./data directory at once.
# Set up the process group early (deepspeed.initialize reuses an existing
# one). Rank 0 downloads while the other ranks wait; they then load the
# already-downloaded files.
deepspeed.init_distributed()
if torch.distributed.get_rank() != 0:
    torch.distributed.barrier()  # wait until rank 0 has finished the download
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
if torch.distributed.get_rank() == 0:
    torch.distributed.barrier()  # data is on disk; release the waiting ranks
34
# 3. Initialize DeepSpeed
# DeepSpeed takes its configuration either as an in-memory dict (used here)
# or as the path to a JSON file with the same structure.
ds_config = {
    "train_batch_size": 16,
    "steps_per_print": 2000,
    # DeepSpeed builds the Adam optimizer itself from these settings.
    "optimizer": {
        "type": "Adam",
        "params": {"lr": 0.001, "betas": [0.8, 0.999], "eps": 1e-8, "weight_decay": 3e-7},
    },
    "fp16": {"enabled": True},
}

# initialize() returns (engine, optimizer, training dataloader, lr scheduler).
# The config defines no scheduler, so the fourth value is bound to `__` and
# not used.
model_engine, optimizer, trainloader, __ = deepspeed.initialize(
    config=ds_config,
    model=net,
    model_parameters=net.parameters(),
    training_data=trainset,
    args=None,
)
61
# 4. Training Loop
# With "fp16" (or "bf16") enabled in the config, the engine holds the model
# weights in reduced precision. The dataloader yields float32 images from
# ToTensor(), so each batch is cast to the engine's compute dtype; otherwise
# the first convolution fails with an input/weight dtype mismatch.
# The dtype is fixed for the whole run, so it is decided once, outside the loop.
if model_engine.bfloat16_enabled():
    target_dtype = torch.bfloat16
elif model_engine.fp16_enabled():
    target_dtype = torch.half
else:
    target_dtype = None

for epoch in range(2):
    for inputs, labels in trainloader:
        inputs = inputs.to(model_engine.local_rank)
        labels = labels.to(model_engine.local_rank)
        if target_dtype is not None:
            inputs = inputs.to(target_dtype)

        # Forward pass
        outputs = model_engine(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, labels)

        # Backward pass (the engine handles fp16 loss scaling and gradient reduction)
        model_engine.backward(loss)

        # Weight update
        model_engine.step()