paddlepaddle_mlp_mnist_digit_classification_quickstart.py
Trains a simple Multi-Layer Perceptron (MLP) on the MNIST dataset to classify handwritten digits.
```python
import paddle
from paddle.vision.transforms import Normalize

# Step 1: Prepare the dataset.
# Normalize computes (x - 127.5) / 127.5, mapping raw pixels in [0, 255] to [-1, 1].
transform = Normalize(mean=[127.5], std=[127.5], data_format='CHW')
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)

# Step 2: Define the network structure (a two-layer MLP, not a convolutional LeNet)
mnist = paddle.nn.Sequential(
    paddle.nn.Flatten(),           # (N, 1, 28, 28) -> (N, 784)
    paddle.nn.Linear(784, 512),    # hidden layer
    paddle.nn.ReLU(),
    paddle.nn.Dropout(0.2),        # zeroes each activation with probability 0.2 during training
    paddle.nn.Linear(512, 10)      # one logit per digit class (0-9)
)

# Step 3: Wrap the network in the high-level Model API
model = paddle.Model(mnist)

# Step 4: Configure the optimizer, loss function, and metric
model.prepare(
    paddle.optimizer.Adam(parameters=model.parameters()),
    paddle.nn.CrossEntropyLoss(),  # takes raw logits; applies softmax internally
    paddle.metric.Accuracy()
)

# Step 5: Train for 5 epochs with mini-batches of 64 images
model.fit(train_dataset, epochs=5, batch_size=64, verbose=1)

# Step 6: Evaluate loss and accuracy on the test set (returns a dict of results)
eval_result = model.evaluate(test_dataset, batch_size=64, verbose=1)
print(eval_result)
```
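Step 1 relies on `Normalize(mean=[127.5], std=[127.5])`, which computes `(x - mean) / std` for every pixel. Because MNIST pixels are 8-bit intensities in [0, 255], this centers the data on zero with a range of [-1, 1]. The plain-Python sketch below (no Paddle required; `normalize_pixel` is a hypothetical helper, not part of Paddle) reproduces that arithmetic for a few sample values.

```python
# Plain-Python sketch of the per-pixel arithmetic behind
# Normalize(mean=[127.5], std=[127.5]): out = (x - mean) / std.
MEAN = 127.5
STD = 127.5


def normalize_pixel(x: float, mean: float = MEAN, std: float = STD) -> float:
    """Map a raw 8-bit intensity to the normalized range (hypothetical helper)."""
    return (x - mean) / std


# Black, mid-grey, and white pixels land at the ends and center of [-1, 1].
for raw in (0, 127.5, 255):
    print(raw, "->", normalize_pixel(raw))
# prints:
# 0 -> -1.0
# 127.5 -> 0.0
# 255 -> 1.0
```

Dividing by 127.5 rather than first scaling to [0, 1] is simply a compact way to reach a zero-centered range in one step.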
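Only the two `Linear` layers in Step 2 carry learnable parameters; `Flatten`, `ReLU`, and `Dropout` add none. Assuming each `Linear(in, out)` holds an `in * out` weight matrix plus an `out`-element bias (Paddle's default when the bias is enabled), the model's size can be worked out by hand. `linear_params` is a hypothetical helper for illustration.

```python
# Worked parameter count for the Step 2 MLP (plain Python, no Paddle needed).


def linear_params(in_features: int, out_features: int) -> int:
    """Weights (in * out) plus one bias per output unit (hypothetical helper)."""
    return in_features * out_features + out_features


flat = 1 * 28 * 28              # Flatten turns a (1, 28, 28) image into 784 features
layers = [(flat, 512), (512, 10)]

per_layer = [linear_params(i, o) for i, o in layers]
total = sum(per_layer)
print(flat, per_layer, total)   # 784 [401920, 5130] 407050
```

For the real network, `paddle.summary(mnist, (1, 1, 28, 28))` should report the same total, though it is worth confirming against your installed Paddle version.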
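`paddle.nn.Dropout(0.2)` in Step 2 zeroes each activation with probability 0.2 during training. In its default `upscale_in_train` mode, Paddle scales the surviving activations by `1 / (1 - p)` so the expected activation is unchanged, and passes inputs through untouched at evaluation time. The sketch below illustrates this "inverted dropout" idea on a plain list of floats; `inverted_dropout` is a hypothetical name, and Paddle's real kernel works on tensors.

```python
import random


def inverted_dropout(values, p=0.2, training=True, rng=None):
    """Zero each element with probability p while training and scale the
    survivors by 1 / (1 - p); return an unchanged copy at inference time."""
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    if not training or p == 0.0:
        return list(values)
    rng = rng or random.Random()
    keep = 1.0 - p
    # rng.random() >= p is true with probability 1 - p, i.e. the element survives.
    return [v / keep if rng.random() >= p else 0.0 for v in values]


rng = random.Random(0)
activations = [1.0] * 10_000
dropped = inverted_dropout(activations, p=0.2, rng=rng)

zero_fraction = sum(1 for v in dropped if v == 0.0) / len(dropped)
mean_activation = sum(dropped) / len(dropped)
print(zero_fraction)    # close to 0.2
print(mean_activation)  # close to 1.0: the expectation is preserved
```

Because the rescaling happens during training, `model.evaluate` in Step 6 can use the network as-is with no extra correction.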
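The network in Step 2 ends with `Linear(512, 10)` and no softmax layer. That works because `paddle.nn.CrossEntropyLoss` applies softmax to the logits itself by default (`use_softmax=True`) and averages over the batch (`reduction='mean'`). Below is a minimal plain-Python sketch of that computation using the log-sum-exp trick for numerical stability; the function names are hypothetical, and it ignores options such as class weights and soft labels.

```python
import math


def softmax_cross_entropy(logits, label):
    """Cross-entropy for one sample from raw logits:
    -log(softmax(logits)[label]) = logsumexp(logits) - logits[label]."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def mean_cross_entropy(batch_logits, labels):
    """Average per-sample loss, mirroring reduction='mean'."""
    losses = [softmax_cross_entropy(z, y) for z, y in zip(batch_logits, labels)]
    return sum(losses) / len(losses)


# Uniform logits over 10 digit classes give a loss of ln(10).
print(round(softmax_cross_entropy([0.0] * 10, 3), 4))  # 2.3026

# A confident, correct prediction gives a loss close to zero.
confident = [10.0] + [0.0] * 9
print(mean_cross_entropy([confident, [0.0] * 10], [0, 3]))
```

An untrained 10-class classifier typically starts near the ln(10) value, which makes it a useful sanity check for the first training log lines.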
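Step 4 constructs `paddle.optimizer.Adam` with its defaults (learning rate 0.001, `beta1=0.9`, `beta2=0.999`, `epsilon=1e-8`). The sketch below applies the textbook bias-corrected Adam update to a single scalar parameter to show what one optimizer step does. `adam_step` is a hypothetical helper, and Paddle's kernel may arrange the bias correction and epsilon slightly differently.

```python
import math


def adam_step(param, grad, m, v, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """Return (new_param, new_m, new_v) after one Adam update at step t (1-based)."""
    m = beta1 * m + (1 - beta1) * grad         # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad * grad  # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)               # undo the bias from zero initialization
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


# Two steps on one parameter with a constant gradient of 0.5.
p, m, v = 1.0, 0.0, 0.0
for t in (1, 2):
    p, m, v = adam_step(p, 0.5, m, v, t)
    print(t, round(p, 6))
```

With a constant gradient the bias-corrected moments equal the gradient and its square, so each step moves the parameter by almost exactly the learning rate regardless of the gradient's magnitude.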
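`paddle.metric.Accuracy()` in Step 4 defaults to top-1 accuracy: a prediction counts as correct when the class with the largest logit equals the label. The plain-Python sketch below computes that for a tiny hand-made batch; `top1_accuracy` and the sample logits are hypothetical.

```python
def argmax(values):
    """Index of the largest value (the first one wins on ties)."""
    return max(range(len(values)), key=values.__getitem__)


def top1_accuracy(batch_logits, labels):
    """Fraction of samples whose highest-scoring class matches the label."""
    correct = sum(1 for z, y in zip(batch_logits, labels) if argmax(z) == y)
    return correct / len(labels)


logits = [
    [0.1, 2.0, -1.0],  # predicts class 1
    [3.0, 0.0, 0.5],   # predicts class 0
    [0.2, 0.1, 0.9],   # predicts class 2
]
labels = [1, 2, 2]
print(top1_accuracy(logits, labels))  # 2 of 3 correct
```

Since softmax preserves the ordering of logits, taking the argmax of raw logits gives the same prediction as taking it over probabilities.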
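Steps 5 and 6 use a batch size of 64. MNIST has 60,000 training and 10,000 test images, and `Model.fit` keeps the final partial batch unless `drop_last=True` is passed, so the number of optimizer steps follows directly. The arithmetic below assumes those defaults, and also assumes `evaluate` keeps its final partial batch.

```python
import math

TRAIN_IMAGES = 60_000  # size of the MNIST training split
TEST_IMAGES = 10_000   # size of the MNIST test split
BATCH_SIZE = 64
EPOCHS = 5

# A partial last batch still counts as one step when drop_last is False.
steps_per_epoch = math.ceil(TRAIN_IMAGES / BATCH_SIZE)
last_batch = TRAIN_IMAGES - (steps_per_epoch - 1) * BATCH_SIZE
total_steps = steps_per_epoch * EPOCHS
eval_steps = math.ceil(TEST_IMAGES / BATCH_SIZE)

print(steps_per_epoch, last_batch, total_steps, eval_steps)  # 938 32 4690 157
```

So the training run performs 4,690 Adam updates in total, with the last batch of each epoch holding only 32 images.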