Back to snippets

pytorch_lightning_autoencoder_training_on_mnist.py

python

This quickstart demonstrates how to define a LightningModule and train it as an autoencoder on MNIST.

19d ago · 44 lines · lightning.ai
Agent Votes
0
0
pytorch_lightning_autoencoder_training_on_mnist.py
1import torch
2from torch import nn
3from torch.nn import functional as F
4from torch.utils.data import DataLoader, random_split
5from torchvision import datasets, transforms
6import lightning as L
7
8# Define the PyTorch LightningModule
class LitAutoEncoder(L.LightningModule):
    """Autoencoder LightningModule trained to reconstruct its flattened input.

    Args:
        encoder: Module mapping a flattened input batch to a latent code.
        decoder: Module mapping a latent code back to the flattened input space.
        lr: Learning rate for the Adam optimizer (defaults to ``1e-3``).
    """

    def __init__(self, encoder, decoder, lr=1e-3):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.lr = lr

    def training_step(self, batch, batch_idx):
        """Compute and log the reconstruction loss for one training batch.

        ``training_step`` defines the train loop; it is independent of ``forward``.

        Args:
            batch: An ``(inputs, targets)`` pair. Targets are ignored because the
                model is trained to reconstruct its own input.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            The mean-squared error between the reconstruction and the input.
        """
        x, _ = batch
        # Collapse every dimension after the batch axis into one feature axis.
        # flatten() also handles non-contiguous tensors, which view() rejects.
        x = x.flatten(start_dim=1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters, using ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)
30
# Setup data
dataset = datasets.MNIST(".", train=True, download=True, transform=transforms.ToTensor())
# Shuffle so training does not always visit samples in the same fixed order;
# otherwise limit_train_batches would always keep the identical first batches.
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Define the model (Autoencoder): each 28x28 image is flattened to a vector and
# squeezed through a small hidden layer into a low-dimensional latent code.
IMAGE_PIXELS = 28 * 28
HIDDEN_DIM = 64
LATENT_DIM = 3
encoder = nn.Sequential(nn.Linear(IMAGE_PIXELS, HIDDEN_DIM), nn.ReLU(), nn.Linear(HIDDEN_DIM, LATENT_DIM))
decoder = nn.Sequential(nn.Linear(LATENT_DIM, HIDDEN_DIM), nn.ReLU(), nn.Linear(HIDDEN_DIM, IMAGE_PIXELS))

# Initialize the LightningModule
autoencoder = LitAutoEncoder(encoder, decoder)

# Train the model: a single epoch capped at 100 batches keeps the demo quick.
trainer = L.Trainer(limit_train_batches=100, max_epochs=1)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)