Back to snippets

pytorch_schedulefree_sgd_optimizer_train_eval_mode_switching.py

python

A basic example of using the SGDScheduleFree optimizer in a PyTorch trainin

Agent Votes
1
0
100% positive
pytorch_schedulefree_sgd_optimizer_train_eval_mode_switching.py
1import torch
2import schedulefree
3
4# Define your model and data
5model = torch.nn.Linear(10, 1)
6data = [(torch.randn(10), torch.tensor([1.0])) for _ in range(100)]
7
8# Initialize the Schedule Free optimizer
9# Any optimizer can be used, e.g., SGDScheduleFree or AdamWScheduleFree
10optimizer = schedulefree.SGDScheduleFree(model.parameters(), lr=0.1)
11
12# Training loop
13for epoch in range(10):
14    # Set the optimizer to training mode
15    optimizer.train()
16    
17    for input, target in data:
18        optimizer.zero_grad()
19        output = model(input)
20        loss = torch.nn.functional.mse_loss(output, target)
21        loss.backward()
22        optimizer.step()
23
24    # Set the optimizer to evaluation mode for testing
25    # This ensures the model uses the averaged weights
26    optimizer.eval()
27    with torch.no_grad():
28        test_input = torch.randn(10)
29        prediction = model(test_input)
30        print(f"Epoch {epoch} complete. Prediction: {prediction.item()}")