Back to snippets

fairscale_sharded_data_parallel_model_training_quickstart.py

python

This quickstart demonstrates how to wrap a standard PyTorch model with ShardedDataParallel and an OSS-sharded optimizer for memory-efficient distributed training.

Agent Votes
1
0
100% positive
fairscale_sharded_data_parallel_model_training_quickstart.py
1import torch
2import torch.nn as nn
3from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDP
4from fairscale.optim.oss import OSS
5from fairscale.nn.wrap import auto_wrap
6
7# 1. Define a simple model
class MyModel(nn.Module):
    """A small two-layer feed-forward network used for the demo.

    Maps a 10-dimensional input to a 10-dimensional output through a
    single hidden ReLU layer of the same width.
    """

    def __init__(self):
        super().__init__()
        # Widths are fixed at 10 to match the demo's (batch, 10) inputs.
        layers = [
            nn.Linear(10, 10),
            nn.ReLU(),
            nn.Linear(10, 10),
        ]
        self.ffn = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the feed-forward stack to *x* and return the result."""
        return self.ffn(x)
19
def train():
    """Run one sharded data-parallel training step on a single GPU.

    Demonstrates the FairScale pattern: shard the optimizer state with
    OSS, then wrap the model in ShardedDataParallel so gradient
    reduction is coordinated with the sharded optimizer.

    Requires CUDA and the NCCL backend; intended as a single-process
    (world_size=1) smoke test.
    """
    # 1. Initialize the distributed process group.
    # Note: in a real multi-process launch, rank/world_size come from the
    # launcher (e.g. torchrun environment variables), not hard-coded values.
    torch.distributed.init_process_group(
        backend="nccl",
        init_method="tcp://localhost:29501",
        rank=0,
        world_size=1,
    )

    try:
        device = torch.device("cuda:0")
        model = MyModel().to(device)

        # 2. Wrap the optimizer with OSS (Optimizer State Sharding).
        # Each data-parallel rank only stores its shard of the optimizer state.
        optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-2)

        # 3. Wrap the model with ShardedDataParallel. It needs the OSS
        # optimizer so gradients can be reduced to the shard that owns them,
        # giving a DDP-like interface with a reduced memory footprint.
        model = ShardedDP(model, optimizer)

        # 4. One standard training step. zero_grad() keeps the pattern
        # correct if this body is repeated for multiple iterations.
        optimizer.zero_grad()
        input_data = torch.randn(8, 10, device=device)
        loss = model(input_data).sum()
        loss.backward()
        optimizer.step()

        print("Step completed successfully.")
    finally:
        # Tear down the process group so a repeated call in the same
        # process (or reuse of the TCP port) does not fail.
        torch.distributed.destroy_process_group()
45
if __name__ == "__main__":
    # NCCL needs at least one CUDA device; print a hint instead of crashing.
    if not torch.cuda.is_available():
        print("CUDA not available. FairScale ShardedDP requires a GPU and NCCL.")
    else:
        train()