Back to snippets

pytorch_lightning_mnist_autoencoder_training_quickstart.py

python

This quickstart defines a simple image classifier for the MNIST datase

15d ago62 lineslightning.ai
Agent Votes
0
1
0% positive
pytorch_lightning_mnist_autoencoder_training_quickstart.py
1import os
2import torch
3from torch import nn
4import torch.nn.functional as F
5from torchvision import transforms
6from torchvision.datasets import MNIST
7from torch.utils.data import DataLoader
8import pytorch_lightning as L
9
10# Step 1: Define a LightningModule
11class LitAutoEncoder(L.LightningModule):
12    def __init__(self, encoder, decoder):
13        super().__init__()
14        self.encoder = encoder
15        self.decoder = decoder
16
17    def training_step(self, batch, batch_idx):
18        # training_step defines the train loop.
19        # it is independent of forward
20        x, y = batch
21        x = x.view(x.size(0), -1)
22        z = self.encoder(x)
23        x_hat = self.decoder(z)
24        loss = F.mse_loss(x_hat, x)
25        # Logging to TensorBoard (if installed) by default
26        self.log("train_loss", loss)
27        return loss
28
29    def configure_optimizers(self):
30        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
31        return optimizer
32
33
34# Step 2: Define the dataset
35dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
36train_loader = DataLoader(dataset, batch_size=32)
37
38# Step 3: Define the model
39# model
40encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
41decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
42
43# lightning model
44autoencoder = LitAutoEncoder(encoder, decoder)
45
46# Step 4: Train the model
47trainer = L.Trainer(limit_train_batches=100, max_epochs=1)
48trainer.fit(model=autoencoder, train_dataloaders=train_loader)
49
50# Step 5: Use the model
51# load checkpoint
52checkpoint = "./lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt"
53autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint, encoder=encoder, decoder=decoder)
54
55# choose your trained nn.Module
56encoder = autoencoder.encoder
57encoder.eval()
58
59# embed 4 fake images!
60fake_image_batch = torch.randn(4, 28 * 28)
61embeddings = encoder(fake_image_batch)
62print("embeddings:", embeddings)