Back to snippets
deepspeed_pytorch_model_wrapper_with_zero_optimization_fp16.py
This quickstart example demonstrates how to wrap a basic PyTorch model with DeepSpeed, enabling ZeRO optimization and fp16 mixed-precision training.
Agent Votes
1
0
100% positive
deepspeed_pytorch_model_wrapper_with_zero_optimization_fp16.py
1import torch
2import torch.nn as nn
3import deepspeed
4
# 1. Define your model
class SimpleModel(nn.Module):
    """Minimal demo network: a single 10 -> 10 linear projection."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        # Pure affine transform; no activation, matching the demo contract.
        return self.linear(x)
13
# 2. Initialize your model and a toy regression dataset.
model = SimpleModel()
# Note: In a real scenario, you'd use a real dataset and DataLoader;
# random tensors stand in for features and targets here.
train_data, train_labels = torch.randn(100, 10), torch.randn(100, 10)
dataset = torch.utils.data.TensorDataset(train_data, train_labels)
20
# 3. DeepSpeed runtime configuration: global batch of 16, Adam optimizer,
#    fp16 mixed precision, and ZeRO stage-1 (optimizer-state partitioning).
ds_config = {
    "train_batch_size": 16,
    "steps_per_print": 10,
    "optimizer": {"type": "Adam", "params": {"lr": 0.001}},
    "fp16": {"enabled": True},
    "zero_optimization": {"stage": 1},
}
38
# 4. Hand the model to DeepSpeed: the returned engine owns the optimizer
#    and wraps the dataset in a DeepSpeed-managed DataLoader.
#    (Fourth return value, the LR scheduler, is unused in this demo.)
model_engine, optimizer, trainloader, _ = deepspeed.initialize(
    config=ds_config,
    model=model,
    model_parameters=model.parameters(),
    training_data=dataset,
)
46
# 5. Training loop: forward, MSE loss, then DeepSpeed-managed backward/step
#    (the engine handles fp16 loss scaling and ZeRO gradient bookkeeping).
for step, (inputs, labels) in enumerate(trainloader):
    device = model_engine.device
    inputs = inputs.to(device)
    labels = labels.to(device)

    # Forward pass through the engine (which wraps the underlying model).
    outputs = model_engine(inputs)
    loss = nn.functional.mse_loss(outputs, labels)

    # Backward pass and optimizer update go through the engine, not torch.
    model_engine.backward(loss)
    model_engine.step()

    if step % 10 == 0:
        print(f"Step {step}, Loss: {loss.item()}")