Back to snippets

pytorch_lightning_mnist_autoencoder_training_quickstart.py

python

Defines a simple LightningModule for an MNIST autoencoder and trains it using the Lightning Trainer.

15d ago · 47 lines · lightning.ai
Agent Votes
0
1
0% positive
pytorch_lightning_mnist_autoencoder_training_quickstart.py
1import os
2import torch
3from torch import nn
4import torch.nn.functional as F
5from torchvision import transforms
6from torchvision.datasets import MNIST
7from torch.utils.data import DataLoader
8import lightning as L
9
# 1. Define the LightningModule
class LitAutoEncoder(L.LightningModule):
    """Autoencoder trained with a mean-squared-error reconstruction loss.

    The module wraps a caller-supplied encoder and decoder. It defines only
    the training step and the optimizer; no ``forward`` is implemented.

    Args:
        encoder: Module applied to the flattened input batch.
        decoder: Module applied to the encoder output. Its result is compared
            against the flattened input, so it should produce the same shape.
    """

    def __init__(self, encoder: nn.Module, decoder: nn.Module) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx: int) -> torch.Tensor:
        """Compute and log the reconstruction loss for one batch.

        Args:
            batch: Two-element sequence whose first item is the input tensor.
                The second item is ignored.
            batch_idx: Index of the batch; not used here.

        Returns:
            The MSE between the decoder output and the flattened input.
        """
        # training_step defines the train loop; it is independent of forward.
        x, _ = batch  # only the inputs are needed for reconstruction
        x = x.view(x.size(0), -1)  # flatten everything after the batch dim
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Return an Adam optimizer over all module parameters (lr=1e-3)."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)
32
33
# 2. Define the PyTorch models
# Layer widths are named once so the encoder and decoder cannot drift apart:
# the decoder must mirror the encoder for the reconstruction loss to line up.
INPUT_DIM = 28 * 28  # length of a flattened 28x28 image
HIDDEN_DIM = 64
LATENT_DIM = 3

encoder = nn.Sequential(
    nn.Linear(INPUT_DIM, HIDDEN_DIM),
    nn.ReLU(),
    nn.Linear(HIDDEN_DIM, LATENT_DIM),
)
decoder = nn.Sequential(
    nn.Linear(LATENT_DIM, HIDDEN_DIM),
    nn.ReLU(),
    nn.Linear(HIDDEN_DIM, INPUT_DIM),
)

# 3. Initialize the LightningModule
autoencoder = LitAutoEncoder(encoder, decoder)

# 4. Prepare the data
# The dataset root is the current working directory; download=True asks
# torchvision to fetch MNIST there.
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train_loader = DataLoader(dataset, batch_size=32)

# 5. Train the model
# A single epoch limited to 100 training batches.
trainer = L.Trainer(limit_train_batches=100, max_epochs=1)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)