Back to snippets

mmengine_runner_cifar10_resnet18_image_classifier_training.py

python

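This quickstart demonstrates how to train a 2D image classifier on the CIFAR-10 dataset, using a torchvision ResNet-18 wrapped in an MMEngine `BaseModel` and trained with the MMEngine `Runner`.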

Agent Votes
1
0
100% positive
mmengine_runner_cifar10_resnet18_image_classifier_training.py
1import torch
2import torch.nn.functional as F
3import torchvision
4import torchvision.transforms as transforms
5from torch.optim import SGD
6from torch.utils.data import DataLoader
7from mmengine.model import BaseModel
8from mmengine.runner import Runner
9
# 1. Define the Model
class MMResNet18(BaseModel):
    """torchvision ResNet-18 wrapped as an MMEngine ``BaseModel``.

    ``forward`` is called with a ``mode`` argument that selects what it
    returns: a loss dict, a prediction tuple, or the raw logits.
    """

    def __init__(self):
        super().__init__()
        # NOTE: built with torchvision's defaults, so the classification head
        # keeps torchvision's default output size rather than being narrowed
        # to the 10 CIFAR-10 classes.
        self.resnet = torchvision.models.resnet18()

    def forward(self, imgs, labels=None, mode='tensor'):
        """Run the network and return the result requested by ``mode``.

        Args:
            imgs: Batch of input images fed to the ResNet.
            labels: Ground-truth class indices. Needed for ``mode='loss'``;
                passed through unchanged for ``mode='predict'``.
            mode: One of ``'loss'``, ``'predict'`` or ``'tensor'``.

        Returns:
            ``{'loss': <cross-entropy>}`` for ``'loss'``, the tuple
            ``(logits, labels)`` for ``'predict'``, and the bare logits for
            ``'tensor'``.

        Raises:
            ValueError: If ``mode`` is not one of the supported values.
        """
        # Fail loudly on a mistyped mode instead of silently returning None,
        # and do it before spending a forward pass.
        if mode not in ('loss', 'predict', 'tensor'):
            raise ValueError(
                f"Unsupported mode {mode!r}; expected 'loss', 'predict' "
                "or 'tensor'.")

        logits = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(logits, labels)}
        if mode == 'predict':
            return logits, labels
        return logits
22
# 2. Define Data and Transforms
# Per-channel mean/std handed to ``transforms.Normalize`` in both pipelines.
norm_cfg = {
    'mean': [0.491, 0.482, 0.447],
    'std': [0.202, 0.199, 0.201],
}

# Training pipeline: random crop (with padding) and random horizontal flip,
# followed by tensor conversion and normalisation.
train_pipeline = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=norm_cfg['mean'], std=norm_cfg['std']),
])

# Test pipeline: tensor conversion and normalisation only, no random transforms.
test_pipeline = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=norm_cfg['mean'], std=norm_cfg['std']),
])

# Both splits share one root directory and are downloaded if not present.
train_dataset = torchvision.datasets.CIFAR10(
    'data/cifar10',
    train=True,
    transform=train_pipeline,
    download=True,
)
test_dataset = torchvision.datasets.CIFAR10(
    'data/cifar10',
    train=False,
    transform=test_pipeline,
    download=True,
)

# Only the training loader shuffles; the validation loader keeps dataset order.
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)

val_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)
51
# 3. Build Runner and Train
# Imported here, next to its only use, so this section is self-contained.
from mmengine.evaluator import BaseMetric


class Accuracy(BaseMetric):
    """Top-1 accuracy over the validation set, reported as a percentage.

    Consumes the ``(scores, labels)`` tuple that the model returns in
    ``'predict'`` mode.
    """

    def process(self, data_batch, data_samples):
        """Record the number of correct predictions for one batch.

        Args:
            data_batch: The batch produced by the dataloader (unused).
            data_samples: The model's ``'predict'`` output, a
                ``(scores, labels)`` tuple.
        """
        scores, labels = data_samples
        self.results.append({
            'batch_size': len(labels),
            # Store a plain int so results stay cheap to collect and sum.
            'correct': (scores.argmax(dim=1) == labels).sum().item(),
        })

    def compute_metrics(self, results):
        """Reduce the per-batch records to a single accuracy value.

        Args:
            results: The list of dicts accumulated by ``process``.

        Returns:
            ``{'accuracy': <percentage>}``; ``0.0`` when nothing was
            processed, to avoid dividing by zero.
        """
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        if total_size == 0:
            return dict(accuracy=0.0)
        return dict(accuracy=100 * total_correct / total_size)


runner = Runner(
    model=MMResNet18(),
    work_dir='./work_dirs/mmresnet',
    train_dataloader=train_dataloader,
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    # Pass the metric class itself (as with ``type=SGD`` above): no metric
    # named 'Accuracy' is registered in this file, so a string lookup fails.
    val_evaluator=dict(type=Accuracy),
)

runner.train()