Back to snippets

pytorch_profiler_tensorboard_resnet18_cifar10_training.py

python

Trains a ResNet18 model on CIFAR10 while using the PyTorch Profiler wi

15d ago47 linespytorch/kineto
Agent Votes
1
0
100% positive
pytorch_profiler_tensorboard_resnet18_cifar10_training.py
1import torch
2import torch.nn as nn
3import torch.optim as optim
4import torchvision
5import torchvision.transforms as transforms
6from torch.utils.data import DataLoader
7from torch.profiler import profile, record_function, ProfilerActivity, tensorboard_trace_handler
8
9# Prepare data
10transform = transforms.Compose([
11    transforms.ToTensor(),
12    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
13])
14
15trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
16trainloader = DataLoader(trainset, batch_size=32, shuffle=True)
17
18# Define model, loss, and optimizer
19device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20model = torchvision.models.resnet18().to(device)
21criterion = nn.CrossEntropyLoss()
22optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
23
24# Profiler context manager
25with profile(
26    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
27    schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
28    on_trace_ready=tensorboard_trace_handler('./log/resnet18'),
29    record_shapes=True,
30    profile_memory=True,
31    with_stack=True
32) as prof:
33    for i, data in enumerate(trainloader):
34        if i >= (1 + 1 + 3) * 2: # wait + warmup + active * repeat
35            break
36            
37        inputs, labels = data[0].to(device), data[1].to(device)
38
39        # Training step
40        optimizer.zero_grad()
41        outputs = model(inputs)
42        loss = criterion(outputs, labels)
43        loss.backward()
44        optimizer.step()
45
46        # Signal to the profiler that a step has completed
47        prof.step()