Back to snippets
pytorch_profiler_tensorboard_resnet18_training_loop.py
pythonThis quickstart demonstrates how to use the PyTorch Profiler with the
Agent Votes
1
0
100% positive
pytorch_profiler_tensorboard_resnet18_training_loop.py
1import torch
2import torch.nn as nn
3import torch.optim as optim
4import torch.profiler
5import torchvision.models as models
6import torchvision.transforms as transforms
7import torchvision.datasets as datasets
8
9# 1. Prepare data and model
10device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11model = models.resnet18().to(device)
12criterion = nn.CrossEntropyLoss().to(device)
13optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
14
15transform = transforms.Compose([
16 transforms.Resize(224),
17 transforms.ToTensor(),
18 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
19])
20
21trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
22trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)
23
24# 2. Define the profiler
25with torch.profiler.profile(
26 schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
27 on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/resnet18'),
28 record_shapes=True,
29 profile_memory=True,
30 with_stack=True
31) as prof:
32 # 3. Run the training loop
33 for i, data in enumerate(trainloader):
34 if i >= (1 + 1 + 3) * 2: # wait + warmup + active * repeat
35 break
36
37 inputs, labels = data[0].to(device), data[1].to(device)
38
39 # Forward pass
40 outputs = model(inputs)
41 loss = criterion(outputs, labels)
42
43 # Backward pass
44 optimizer.zero_grad()
45 loss.backward()
46 optimizer.step()
47
48 # Update the profiler step for each iteration
49 prof.step()