fairscale_sharded_data_parallel_model_training_quickstart.py
This quickstart demonstrates how to wrap a standard PyTorch model with ShardedDataParallel and an OSS-sharded optimizer from FairScale.
import torch
import torch.nn as nn
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDP
from fairscale.optim.oss import OSS

# 1. Define a simple model
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(10, 10),
            nn.ReLU(),
            nn.Linear(10, 10)
        )

    def forward(self, x):
        return self.ffn(x)

def train():
    # 2. Initialize the distributed process group
    # Hardcoded rank/world_size keep this snippet runnable as a single process;
    # a real job launches one process per GPU.
    torch.distributed.init_process_group(backend="nccl", init_method="tcp://localhost:29501", rank=0, world_size=1)
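    # In a multi-GPU job (e.g. launched with torchrun), rank and world size
    # come from environment variables rather than being hardcoded; a sketch of
    # the equivalent call:
    #   torch.distributed.init_process_group(backend="nccl", init_method="env://")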

    device = torch.device("cuda:0")
    model = MyModel().to(device)

    # 3. Wrap optimizer with OSS (Optimizer State Sharding)
    # This shards the optimizer state across data-parallel ranks
    base_optimizer = torch.optim.SGD
    optimizer = OSS(params=model.parameters(), optim=base_optimizer, lr=1e-2)
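    # Because OSS shards optimizer state, checkpointing needs a gather first.
    # A minimal sketch using OSS's consolidation API (the filename is
    # illustrative; rank 0 receives the full state):
    #   optimizer.consolidate_state_dict(recipient_rank=0)
    #   if torch.distributed.get_rank() == 0:
    #       torch.save(optimizer.state_dict(), "optim_state.pt")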

    # 4. Wrap model with ShardedDataParallel
    # This provides a DDP-like interface with reduced memory footprint
    model = ShardedDP(model, optimizer)
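    # During backward, ShardedDataParallel reduces each gradient to the rank
    # that owns the corresponding optimizer-state shard, instead of
    # all-reducing full gradients on every rank as DDP does.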

    # 5. Standard training step
    input_data = torch.randn(8, 10).to(device)
    optimizer.zero_grad()
    output = model(input_data)
    loss = output.sum()
    loss.backward()
    optimizer.step()
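
    # A full run would repeat the step above over batches; a minimal sketch,
    # assuming a hypothetical DataLoader `loader` yielding (8, 10) tensors:
    #   for batch in loader:
    #       optimizer.zero_grad()
    #       loss = model(batch.to(device)).sum()
    #       loss.backward()
    #       optimizer.step()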

    print("Step completed successfully.")
    torch.distributed.destroy_process_group()

if __name__ == "__main__":
    # Ensure you have a GPU available to run this NCCL-based example
    if torch.cuda.is_available():
        train()
    else:
        print("CUDA not available. FairScale ShardedDP requires a GPU and NCCL.")
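
# To scale beyond this single-process demo, launch one process per GPU, e.g.:
#   torchrun --nproc_per_node=2 fairscale_sharded_data_parallel_model_training_quickstart.py
# (This assumes the env-var based init_process_group sketched in train().)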