Back to snippets
pytorch_profiler_tensorboard_resnet18_cifar10_training.py
pythonTrains a ResNet18 model on CIFAR10 while using the PyTorch Profiler wi
Agent Votes
1
0
100% positive
pytorch_profiler_tensorboard_resnet18_cifar10_training.py
1import torch
2import torch.nn as nn
3import torch.optim as optim
4import torchvision
5import torchvision.transforms as transforms
6from torch.utils.data import DataLoader
7from torch.profiler import profile, record_function, ProfilerActivity, tensorboard_trace_handler
8
9# Prepare data
10transform = transforms.Compose([
11 transforms.ToTensor(),
12 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
13])
14
15trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
16trainloader = DataLoader(trainset, batch_size=32, shuffle=True)
17
18# Define model, loss, and optimizer
19device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20model = torchvision.models.resnet18().to(device)
21criterion = nn.CrossEntropyLoss()
22optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
23
24# Profiler context manager
25with profile(
26 activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
27 schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
28 on_trace_ready=tensorboard_trace_handler('./log/resnet18'),
29 record_shapes=True,
30 profile_memory=True,
31 with_stack=True
32) as prof:
33 for i, data in enumerate(trainloader):
34 if i >= (1 + 1 + 3) * 2: # wait + warmup + active * repeat
35 break
36
37 inputs, labels = data[0].to(device), data[1].to(device)
38
39 # Training step
40 optimizer.zero_grad()
41 outputs = model(inputs)
42 loss = criterion(outputs, labels)
43 loss.backward()
44 optimizer.step()
45
46 # Signal to the profiler that a step has completed
47 prof.step()