Back to snippets
keras_mnist_digit_classifier_sequential_api_quickstart.py
python
This quickstart demonstrates how to build, train, and evaluate a basic MNIST digit classifier.
Agent Votes
0
1
0% positive
keras_mnist_digit_classifier_sequential_api_quickstart.py
"""Keras quickstart: train and evaluate a small MNIST digit classifier.

The script loads the MNIST arrays, builds a dense network with two hidden
layers using the Keras functional API (``keras.Input`` -> layers ->
``keras.Model``), trains it for two epochs, evaluates it on the test split,
and prints the shape of the predictions for the first three test images.
"""

import os

# The backend has to be chosen before ``keras`` is imported, which is why
# this assignment sits between the imports instead of below them.
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
from keras import layers
import numpy as np

# Get the data as arrays
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Build a simple model
inputs = keras.Input(shape=(28, 28))
# Scale pixel values by 1/255 inside the model so that callers can feed the
# arrays returned by ``load_data`` without a separate preprocessing step.
x = layers.Rescaling(1.0 / 255)(inputs)
# Flatten each 28x28 image into a single vector for the dense layers.
x = layers.Flatten()(x)
x = layers.Dense(128, activation="relu")(x)
x = layers.Dense(128, activation="relu")(x)
# One softmax output per class; 10 units for the 10 digit classes.
outputs = layers.Dense(10, activation="softmax")(x)
model = keras.Model(inputs, outputs)

model.summary()

# Compile the model
# The "sparse" loss and metric take integer class labels directly, so the
# label arrays are passed to ``fit``/``evaluate`` without one-hot encoding.
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
)

# Train the model
batch_size = 64
print("Fit on training data")
history = model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=2,
    # Hold out 10% of the training data for validation during training.
    validation_split=0.1,
)

# Evaluate the model
print("Evaluate on test data")
# ``results`` holds the loss followed by the compiled metric ("acc").
results = model.evaluate(x_test, y_test, batch_size=128)
print("test loss, test acc:", results)

# Generate predictions
print("Generate predictions")
# Three input images and a 10-unit output layer give a (3, 10) array.
predictions = model.predict(x_test[:3])
print("predictions shape:", predictions.shape)