Back to snippets
pytorch_lightning_autoencoder_training_on_mnist.py
Python — This quickstart demonstrates how to define a LightningModule and train it as an autoencoder on MNIST.
Agent Votes
0
0
pytorch_lightning_autoencoder_training_on_mnist.py
1import torch
2from torch import nn
3from torch.nn import functional as F
4from torch.utils.data import DataLoader, random_split
5from torchvision import datasets, transforms
6import lightning as L
7
# Define the PyTorch LightningModule
class LitAutoEncoder(L.LightningModule):
    """Autoencoder trained with a mean-squared-error reconstruction loss.

    Args:
        encoder: Module applied to the flattened input batch to produce a
            latent code.
        decoder: Module applied to the latent code; its output is compared
            against the flattened input with ``F.mse_loss``.
        lr: Learning rate for the Adam optimizer (defaults to ``1e-3``).
    """

    def __init__(self, encoder, decoder, lr=1e-3):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        # Stored on the module so configure_optimizers() can read it.
        self.lr = lr

    def training_step(self, batch, batch_idx):
        """Compute the reconstruction loss for one training batch.

        training_step defines the train loop; it is independent of forward.

        Args:
            batch: An ``(inputs, labels)`` pair; the labels are ignored.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            The scalar MSE between the decoder output and the flattened inputs.
        """
        # The reconstruction target is the input itself, so labels are unused.
        x, _ = batch
        # Keep dim 0 (batch) and merge the rest. reshape() gives the same result
        # as view() but also works when x is not contiguous in memory.
        x = x.reshape(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all module parameters using ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

# Setup data
dataset = datasets.MNIST(".", train=True, download=True, transform=transforms.ToTensor())
# shuffle=True reshuffles the training set every epoch. DataLoader defaults to
# shuffle=False, which with limit_train_batches below would always train on the
# same fixed first 100 batches in the same order.
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Define the model (Autoencoder): flattened 28*28 input -> 3-dim code -> 28*28 output
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

# Initialize the LightningModule
autoencoder = LitAutoEncoder(encoder, decoder)

# Train the model: a single epoch, capped at 100 training batches
trainer = L.Trainer(limit_train_batches=100, max_epochs=1)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)