
paddlepaddle_mnist_lenet_training_with_highlevel_api.py

python

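This quickstart demonstrates how to implement handwritten digit recognition: it trains a LeNet model on the MNIST dataset using PaddlePaddle's high-level API (paddle.Model), then evaluates it and runs prediction.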

Source: paddlepaddle.org.cn
```python
import paddle
from paddle.vision.transforms import Normalize

# Step 1: Prepare the dataset
# Normalize computes (x - 127.5) / 127.5, scaling pixel values from [0, 255] to [-1, 1].
# MNIST is downloaded automatically the first time the dataset is created.
transform = Normalize(mean=[127.5], std=[127.5], data_format='CHW')
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)

# Step 2: Define the network structure (LeNet) with 10 output classes, one per digit
lenet = paddle.vision.models.LeNet(num_classes=10)

# Step 3: Encapsulate the model with the high-level API
model = paddle.Model(lenet)

# Step 4: Configure training: optimizer, loss function, and evaluation metric
model.prepare(paddle.optimizer.Adam(parameters=model.parameters()),
              paddle.nn.CrossEntropyLoss(),
              paddle.metric.Accuracy())

# Step 5: Train the model
model.fit(train_dataset, epochs=5, batch_size=64, verbose=1)

# Step 6: Evaluate the model on the test set
model.evaluate(test_dataset, batch_size=64, verbose=1)

# Step 7: Use the model for prediction
# predict returns raw class scores grouped per model output, not digit labels.
predict_results = model.predict(test_dataset)
```
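Step 1 normalizes every pixel with mean 127.5 and standard deviation 127.5, i.e. it computes (x - 127.5) / 127.5. Raw MNIST pixels are 8-bit grayscale values in [0, 255], so this maps them onto [-1, 1], with the midpoint 127.5 landing on 0. The plain-Python sketch below reproduces only that arithmetic; `normalize_pixel` is a hypothetical helper for illustration, not part of Paddle.

```python
# Sketch of the arithmetic behind Normalize(mean=[127.5], std=[127.5]):
# each value x becomes (x - mean) / std.
def normalize_pixel(x, mean=127.5, std=127.5):
    """Hypothetical helper: scale one grayscale pixel value the way Normalize does."""
    return (x - mean) / std

# The extremes of the 8-bit pixel range map to the ends of [-1, 1].
print([normalize_pixel(x) for x in (0, 255)])
```

Centering the inputs around zero like this is a common preprocessing choice that tends to make optimization with Adam better behaved than feeding raw 0 to 255 values.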
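Step 2 relies on the built-in `paddle.vision.models.LeNet`. Assuming the layer layout of Paddle's implementation (a 3x3 convolution with padding 1 and 6 channels, 2x2 max pooling, a 5x5 convolution without padding and 16 channels, 2x2 max pooling, then fully connected layers 400 → 120 → 84 → `num_classes`; verify this against the source of your Paddle version), the spatial size of a 28x28 MNIST image can be traced with the standard output-size formula floor((n + 2p - k) / s) + 1. The hypothetical `conv_out` helper below only does this bookkeeping.

```python
def conv_out(n, kernel, stride=1, padding=0):
    """Output side length of a conv/pool window: floor((n + 2*padding - kernel) / stride) + 1."""
    return (n + 2 * padding - kernel) // stride + 1

# Assumed LeNet layout (see lead-in); MNIST inputs are 1 x 28 x 28.
size = 28
size = conv_out(size, kernel=3, padding=1)  # conv 3x3, pad 1 -> 28
size = conv_out(size, kernel=2, stride=2)   # max pool 2x2    -> 14
size = conv_out(size, kernel=5)             # conv 5x5, pad 0 -> 10
size = conv_out(size, kernel=2, stride=2)   # max pool 2x2    -> 5
channels = 16                               # output channels of the second conv
flattened = channels * size * size
print(size, flattened)
```

Under these assumptions the flattened feature vector has 16 x 5 x 5 = 400 entries, which matches the input width of the first fully connected layer.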
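Step 4 pairs Adam with `paddle.nn.CrossEntropyLoss`. With its defaults (`use_softmax=True`, `reduction='mean'`), the loss applies softmax to LeNet's raw class scores and averages -log p(true class) over the batch. The stdlib sketch below computes that quantity for a tiny hand-made batch; `mean_cross_entropy` and the sample logits are made up for illustration, and this is not Paddle's tensor implementation.

```python
import math

def mean_cross_entropy(batch_logits, labels):
    """Hypothetical helper: softmax cross-entropy averaged over a batch of integer labels."""
    total = 0.0
    for logits, label in zip(batch_logits, labels):
        # Subtract the largest logit before exponentiating to avoid overflow.
        m = max(logits)
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        # -log softmax(logits)[label] == log_sum_exp - logits[label]
        total += log_sum_exp - logits[label]
    return total / len(labels)

# Two samples with three classes each (MNIST would use 10 scores per sample).
logits = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]
labels = [0, 2]
print(mean_cross_entropy(logits, labels))
```

The second sample shows a useful sanity check: uniform scores over C classes give a loss of log C, so an untrained 10-class model typically starts near log 10 ≈ 2.30.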
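Step 7 yields class scores rather than digit labels. In Paddle 2.x, with the default `stack_outputs=False`, `model.predict` returns a list with one entry per model output; for LeNet that single entry holds one array of shape [batch_size, 10] per batch (check this against your version). Taking the arg-max of each row gives the predicted digit, and comparing with the true labels gives top-1 accuracy, the quantity `paddle.metric.Accuracy` reports by default. The sketch below assumes that structure and uses small hand-made lists in place of NumPy arrays; `argmax`, `predicted_labels`, and `top1_accuracy` are hypothetical helpers.

```python
def argmax(row):
    """Index of the largest score in one row (first index wins on ties)."""
    return max(range(len(row)), key=lambda i: row[i])

def predicted_labels(per_batch_scores):
    """Hypothetical helper: flatten per-batch [batch, classes] scores into predicted labels."""
    return [argmax(row) for batch in per_batch_scores for row in batch]

def top1_accuracy(preds, labels):
    """Fraction of predictions that equal the true labels."""
    correct = sum(int(p == y) for p, y in zip(preds, labels))
    return correct / len(labels)

# Stand-in for predict_results[0]: two batches of scores over 3 classes (MNIST has 10).
per_batch_scores = [
    [[0.1, 2.3, -0.5], [1.7, 0.2, 0.0]],
    [[-1.0, 0.4, 3.1]],
]
true_labels = [1, 0, 1]
preds = predicted_labels(per_batch_scores)
print(preds, top1_accuracy(preds, true_labels))
```

With real results, `per_batch_scores` would be `predict_results[0]` and the labels would be collected from `test_dataset`. When NumPy is available, `np.argmax(batch, axis=1)` does the per-batch arg-max in one call and also returns the first index on ties.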