Back to snippets
mmengine_runner_cifar10_resnet18_image_classifier_training.py
Python: This quickstart demonstrates how to train a 2D image classifier on the CIFAR-10 dataset.
Agent Votes
1
0
100% positive
mmengine_runner_cifar10_resnet18_image_classifier_training.py
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel
from mmengine.runner import Runner
from torch.optim import SGD
from torch.utils.data import DataLoader

# 1. Define the Model
class MMResNet18(BaseModel):
    """ResNet-18 classifier exposed through MMEngine's ``BaseModel`` API.

    Args:
        num_classes: Width of the classification head. The default of 1000
            matches what ``torchvision.models.resnet18()`` builds when it is
            called without arguments, so ``MMResNet18()`` is unchanged.
    """

    def __init__(self, num_classes: int = 1000):
        super().__init__()
        self.resnet = torchvision.models.resnet18(num_classes=num_classes)

    def forward(self, imgs, labels=None, mode='tensor'):
        """Run the backbone and shape the result according to ``mode``.

        Args:
            imgs: Batch of input images fed to the ResNet.
            labels: Ground-truth class indices. Required for ``'loss'``;
                passed through untouched for ``'predict'``.
            mode: One of ``'loss'``, ``'predict'`` or ``'tensor'``.

        Returns:
            ``{'loss': ...}`` with the cross-entropy loss for ``'loss'``;
            the ``(logits, labels)`` pair for ``'predict'``;
            the raw logits for ``'tensor'``.

        Raises:
            ValueError: If ``mode`` is not one of the three supported values.
        """
        logits = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(logits, labels)}
        if mode == 'predict':
            return logits, labels
        if mode == 'tensor':
            return logits
        # Fail loudly rather than handing ``None`` back to the caller.
        raise ValueError(
            f"Unsupported mode {mode!r}; expected 'loss', 'predict' or "
            "'tensor'.")

# 2. Define Data and Transforms
# Per-channel normalization statistics shared by the train and test pipelines.
norm_cfg = {
    'mean': [0.491, 0.482, 0.447],
    'std': [0.202, 0.199, 0.201],
}

# Training images are augmented (padded random crop + horizontal flip) before
# being converted to tensors and normalized.
train_pipeline = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=norm_cfg['mean'], std=norm_cfg['std']),
])

# Test images are only converted and normalized -- no augmentation.
test_pipeline = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=norm_cfg['mean'], std=norm_cfg['std']),
])

# Both splits live under the same root and are downloaded on first use.
train_dataset = torchvision.datasets.CIFAR10(
    root='data/cifar10',
    train=True,
    transform=train_pipeline,
    download=True,
)
test_dataset = torchvision.datasets.CIFAR10(
    root='data/cifar10',
    train=False,
    transform=test_pipeline,
    download=True,
)

# Shuffle only the training split; validation order stays fixed.
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# 3. Build Runner and Train
class Accuracy(BaseMetric):
    """Top-1 classification accuracy, reported as a percentage.

    The validation loop needs a concrete metric class: a bare ``'Accuracy'``
    string would have to be resolved through MMEngine's metric registry, and
    nothing in this script registers a metric under that name.
    """

    def process(self, data_batch, data_samples):
        """Record the number of correct predictions for one batch.

        Args:
            data_batch: The batch produced by the validation dataloader
                (unused here).
            data_samples: The ``(scores, labels)`` pair that the model returns
                in ``'predict'`` mode.
        """
        score, gt = data_samples
        self.results.append({
            'batch_size': len(gt),
            # Store a plain int so results stay cheap to collect and sum.
            'correct': int((score.argmax(dim=1) == gt).sum()),
        })

    def compute_metrics(self, results):
        """Aggregate the per-batch counts into a single accuracy value.

        Args:
            results: The dicts appended by :meth:`process`.

        Returns:
            ``{'accuracy': <percentage>}``; ``0.0`` if no sample was seen.
        """
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        if not total_size:
            # Avoid a ZeroDivisionError on an empty validation set.
            return {'accuracy': 0.0}
        return {'accuracy': 100 * total_correct / total_size}


runner = Runner(
    model=MMResNet18(),
    work_dir='./work_dirs/mmresnet',
    train_dataloader=train_dataloader,
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    # Train for 5 epochs and validate after every epoch.
    train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    # Pass the class itself so no registry lookup by name is needed.
    val_evaluator=dict(type=Accuracy),
)

runner.train()