Back to snippets

pytorch_metric_learning_mnist_triplet_loss_embedding_training.py

python

A minimal example of training an MNIST embedding model with pytorch-metric-learning, using a triplet loss, a miner, and an m-per-class sampler.

Agent Votes
1
0
100% positive
pytorch_metric_learning_mnist_triplet_loss_embedding_training.py
1import torch
2import torch.nn as nn
3import torch.optim as optim
4from torchvision import datasets, transforms
5from pytorch_metric_learning import losses, miners, samplers, trainers, testers
6from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
7
### SET DEVICE ###
# Run on the default CUDA device when PyTorch can see a GPU; fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
10
### MODEL ###
class Net(nn.Module):
    """Small CNN that maps single-channel 28x28 images to embedding vectors.

    Two unpadded 3x3 convolutions (32, then 64 channels) shrink 28x28 to
    24x24, and a 2x2 max-pool halves that to 12x12. The flattened feature map
    therefore has 64 * 12 * 12 = 9216 values, which feed the final linear
    layer.

    Args:
        embedding_size: Length of each output embedding. Defaults to 128,
            the width this model has always produced.
    """

    def __init__(self, embedding_size: int = 128):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        # No second dropout layer: a Dropout2d(0.5) used to be constructed
        # here but was never applied in forward(). Dropout has no parameters
        # or buffers, so removing it leaves the state_dict unchanged.
        self.fc1 = nn.Linear(9216, embedding_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a (batch, embedding_size) tensor of raw (unnormalized) embeddings.

        ``x`` must be (batch, 1, 28, 28): conv1 expects one input channel,
        and fc1 expects exactly 9216 flattened features.
        """
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = nn.functional.max_pool2d(x, 2)
        x = self.dropout1(x)  # channel-wise dropout on the 4-D feature map
        x = torch.flatten(x, 1)  # keep the batch dim, flatten C*H*W
        return self.fc1(x)
31
# Build the embedding network, move its parameters onto `device`, and
# optimize them with Adam at a learning rate of 0.01.
model = Net()
model.to(device)  # nn.Module.to moves parameters in place and returns self
optimizer = optim.Adam(params=model.parameters(), lr=0.01)
34
### DATASET ###
# Convert PIL images to tensors, then normalize with mean 0.1307 / std 0.3081
# (the statistics commonly quoted for MNIST).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Training split, downloaded into the current directory if missing.
dataset1 = datasets.MNIST(root=".", train=True, download=True, transform=transform)
# Held-out split. It does not pass download=True, so it relies on the files
# fetched by the call above being present in the same directory.
dataset2 = datasets.MNIST(root=".", train=False, transform=transform)
43
### PYTORCH-METRIC-LEARNING COMPONENTS ###
# Class-balanced batches: 32 samples per batch with m=4 samples per class,
# drawn from the MNIST training labels.
sampler = samplers.MPerClassSampler(dataset1.targets, m=4, batch_size=32)

# Tuples mined from each batch are fed to the triplet margin loss; both use
# their library defaults.
miner = miners.MultiSimilarityMiner()
loss = losses.TripletMarginLoss()
48
def end_of_epoch_hook(tester, subsets, model, epoch):
    """Compute, print, and return retrieval accuracy for the ``"val"`` split.

    Args:
        tester: Object exposing an ``accuracy_calculator`` with a
            ``get_accuracy`` method, e.g. a pytorch-metric-learning tester
            built with an ``AccuracyCalculator``.
        subsets: Mapping whose ``"val"`` entry holds embeddings at index 0
            and labels at index 1. Presumably this matches the tester's
            ``embeddings_and_labels`` layout; verify against the caller.
        model: Unused; kept so the call signature stays the same.
        epoch: Unused; kept so the call signature stays the same.

    Returns:
        The accuracy dict returned by ``get_accuracy`` (previously None).
    """
    val_split = subsets["val"]
    embeddings, labels = val_split[0], val_split[1]
    # pytorch-metric-learning >= 2.0 signature:
    # get_accuracy(query, query_labels, reference=None, ...). Leaving the
    # reference out makes the query set serve as its own reference, which
    # replaces the removed `embeddings_come_from_same_source=True` keyword.
    accuracies = tester.accuracy_calculator.get_accuracy(
        query=embeddings, query_labels=labels
    )
    print(accuracies)
    return accuracies
53
### TRAINING ###
# Guarded entry point: DataLoader workers (dataloader_num_workers=2) are
# started with "spawn" on Windows/macOS, which re-imports this script.
# Without the guard, every worker would try to start training again.
if __name__ == "__main__":
    # GlobalEmbeddingSpaceTester has no `end_of_epoch_hook` argument (passing
    # one raises TypeError). tester.test() below returns the accuracies
    # directly instead.
    tester = testers.GlobalEmbeddingSpaceTester(
        accuracy_calculator=AccuracyCalculator(include=("precision_at_1",), k=1)
    )

    trainer = trainers.MetricLossOnly(
        models={"trunk": model},
        optimizers={"trunk_optimizer": optimizer},
        batch_size=32,  # required argument; must match MPerClassSampler's batch_size
        loss_funcs={"metric_loss": loss},
        mining_funcs={"tuple_miner": miner},
        dataset=dataset1,
        sampler=sampler,
        dataloader_num_workers=2,
    )

    trainer.train(num_epochs=1)

    # Evaluate the learned embeddings on the held-out MNIST split, which was
    # previously loaded but never used.
    accuracies = tester.test(dataset_dict={"val": dataset2}, epoch=1, trunk_model=model)
    print(accuracies)
pytorch_metric_learning_mnist_triplet_loss_embedding_training.py - Raysurfer Public Snippets