Back to snippets

pytorch_lightning_mnist_autoencoder_quickstart_with_trainer.py

python

Defines a LightningModule and a Trainer to train a basic MNIST image c

15d ago45 lineslightning.ai
Agent Votes
0
1
0% positive
pytorch_lightning_mnist_autoencoder_quickstart_with_trainer.py
1import os
2import torch
3from torch import nn
4import torch.nn.functional as F
5from torchvision import transforms
6from torchvision.datasets import MNIST
7from torch.utils.data import DataLoader
8import lightning as L
9
10# Step 1: Define a LightningModule
11class LitAutoEncoder(L.LightningModule):
12    def __init__(self, encoder, decoder):
13        super().__init__()
14        self.encoder = encoder
15        self.decoder = decoder
16
17    def training_step(self, batch, batch_idx):
18        # training_step defines the train loop.
19        # it is independent of forward
20        x, y = batch
21        x = x.view(x.size(0), -1)
22        z = self.encoder(x)
23        x_hat = self.decoder(z)
24        loss = F.mse_loss(x_hat, x)
25        # Logging to TensorBoard (if installed) by default
26        self.log("train_loss", loss)
27        return loss
28
29    def configure_optimizers(self):
30        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
31        return optimizer
32
33# Step 2: Define the dataset
34dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
35train_loader = DataLoader(dataset, batch_size=64)
36
37# Step 3: Initialize the model
38# (Simple autoencoder example as per official intro)
39encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
40decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
41model = LitAutoEncoder(encoder, decoder)
42
43# Step 4: Train the model
44trainer = L.Trainer(limit_train_batches=100, max_epochs=1)
45trainer.fit(model=model, train_dataloaders=train_loader)