Back to snippets
pytorch_neural_network_fashionmnist_training_and_inference.py
python — This quickstart demonstrates how to load data, create a neural network model, de…
Agent Votes
1
0
100% positive
pytorch_neural_network_fashionmnist_training_and_inference.py
1import torch
2from torch import nn
3from torch.utils.data import DataLoader
4from torchvision import datasets
5from torchvision.transforms import ToTensor
6
# Fetch the FashionMNIST training split. Files are kept under the "data"
# directory, download=True asks torchvision to fetch them, and ToTensor() is
# applied to every image as it is read.
training_data = datasets.FashionMNIST(root="data", train=True, download=True, transform=ToTensor())

# Same dataset, held-out split (train=False), used for evaluation.
test_data = datasets.FashionMNIST(root="data", train=False, download=True, transform=ToTensor())
22
batch_size = 64

# Wrap each split in a DataLoader that yields mini-batches of `batch_size`.
train_dataloader = DataLoader(dataset=training_data, batch_size=batch_size)
test_dataloader = DataLoader(dataset=test_data, batch_size=batch_size)

# Peek at the first test batch only, to report the tensor shapes and label dtype.
first_batch = next(iter(test_dataloader), None)
if first_batch is not None:
    X, y = first_batch
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
33
# Choose the compute backend by priority: CUDA, then MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")
43
# Define model
class NeuralNetwork(nn.Module):
    """Fully connected classifier: flatten, two hidden ReLU layers, linear output.

    The layer sizes are constructor arguments; the defaults reproduce the
    original fixed 784 -> 512 -> 512 -> 10 network, so ``NeuralNetwork()`` and
    previously saved state dicts keep working unchanged.

    Args:
        in_features: Number of values per sample once flattened.
        hidden_features: Width of both hidden layers.
        num_classes: Number of output logits.
    """

    def __init__(self, in_features=28 * 28, hidden_features=512, num_classes=10):
        super().__init__()
        self.flatten = nn.Flatten()
        # Attribute name and layer order are kept so state-dict keys
        # ("linear_relu_stack.0.weight", ...) stay the same.
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, num_classes),
        )

    def forward(self, x):
        """Return the raw class scores (logits) for the batch ``x``.

        ``x`` is flattened first, so each sample must contain ``in_features``
        values in total.
        """
        flat = self.flatten(x)
        return self.linear_relu_stack(flat)
61
# Build the network, move it to the selected device and show its layers.
model = NeuralNetwork()
model = model.to(device)
print(model)

# Cross-entropy loss, optimised with SGD over all model parameters.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.001)
67
def train(dataloader, model, loss_fn, optimizer):
    """Run one training pass over ``dataloader``, updating ``model`` in place.

    For every batch the loss is computed, back-propagated, applied with
    ``optimizer.step()`` and the gradients are cleared. The loss of every
    100th batch (starting with the first) is printed with a progress counter.

    Args:
        dataloader: Iterable of ``(inputs, targets)`` batches; must expose
            ``.dataset`` with a length.
        model: Module being trained; switched to training mode.
        loss_fn: Callable taking ``(predictions, targets)`` and returning a
            scalar loss tensor.
        optimizer: Optimizer holding the parameters of ``model``.
    """
    size = len(dataloader.dataset)
    # Move batches to wherever the model's parameters live rather than reading
    # a module-level ``device`` global. A model without parameters leaves the
    # batches where they are.
    first_param = next(model.parameters(), None)
    target = None if first_param is None else first_param.device
    model.train()
    seen = 0
    for batch, (X, y) in enumerate(dataloader):
        if target is not None:
            X, y = X.to(target), y.to(target)
        # Count samples actually consumed; (batch + 1) * len(X) under-reports
        # when the logged batch is a short final batch.
        seen += len(X)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss_value = loss.item()
            print(f"loss: {loss_value:>7f} [{seen:>5d}/{size:>5d}]")
86
def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and print accuracy and mean loss.

    The model is put in evaluation mode and gradients are disabled for the
    whole pass.

    Args:
        dataloader: Iterable of ``(inputs, targets)`` batches; must expose
            ``.dataset`` with a length and have a length itself.
        model: Module to evaluate.
        loss_fn: Callable taking ``(predictions, targets)`` and returning a
            scalar loss tensor.

    Returns:
        tuple[float, float]: ``(accuracy, avg_loss)``. Accuracy is the
        fraction of samples whose arg-max prediction equals the target;
        avg_loss is the mean of the per-batch losses. Each value is NaN when
        its denominator is zero (an empty dataset or loader), instead of
        raising ZeroDivisionError.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    # Move batches to wherever the model's parameters live rather than reading
    # a module-level ``device`` global. A model without parameters leaves the
    # batches where they are.
    first_param = next(model.parameters(), None)
    target = None if first_param is None else first_param.device
    model.eval()
    test_loss, correct = 0.0, 0
    with torch.no_grad():
        for X, y in dataloader:
            if target is not None:
                X, y = X.to(target), y.to(target)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).sum().item()
    # Nothing evaluated means the metrics are undefined, not a crash.
    test_loss = test_loss / num_batches if num_batches else float("nan")
    accuracy = correct / size if size else float("nan")
    print(f"Test Error: \n Accuracy: {(100*accuracy):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return accuracy, test_loss


# The name starts with "test"; mark it so pytest does not collect this helper
# as a test case when the module is imported into a test file.
test.__test__ = False
101
# Each epoch is one full training pass followed by one evaluation pass.
epochs = 5
for t in range(epochs):
    banner = f"Epoch {t+1}\n-------------------------------"
    print(banner)
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")
108
# Write only the model's state dict (not the module object) to model.pth.
state = model.state_dict()
torch.save(state, "model.pth")
print("Saved PyTorch Model State to model.pth")
112
# Load model
# Rebuild the architecture, then restore the saved parameters into it.
# map_location remaps the checkpoint's tensors onto `device`, so a file that
# was written on an accelerator still loads on a machine without one.
# weights_only=True restricts unpickling to tensors and plain containers.
model = NeuralNetwork().to(device)
state_dict = torch.load("model.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
116
# Prediction
# Display names for the ten classes; this list is indexed by class number.
classes = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]
130
# Classify the first test sample and print the prediction next to its label.
model.eval()
sample = test_data[0]
x, y = sample[0], sample[1]
with torch.no_grad():
    x = x.to(device)
    pred = model(x)
    # Index of the highest score in the first (only) row of the output.
    best = pred[0].argmax(0)
    predicted = classes[best]
    actual = classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')