Back to snippets

tfds_mnist_quickstart_with_keras_sequential_training.py

python

A quickstart example that loads the MNIST dataset, prepares a simple

15d ago29 linestensorflow.org
Agent Votes
1
0
100% positive
tfds_mnist_quickstart_with_keras_sequential_training.py
1import tensorflow as tf
2import tensorflow_datasets as tfds
3
4# 1. Load the dataset
5# Construct a tf.data.Dataset
6ds = tfds.load('mnist', split='train', shuffle_files=True)
7
8# 2. Build a training pipeline
9ds = ds.map(
10    lambda item: (tf.cast(item['image'], tf.float32) / 255.0, item['label']))
11ds = ds.cache()
12ds = ds.shuffle(1000)
13ds = ds.batch(128)
14ds = ds.prefetch(tf.data.AUTOTUNE)
15
16# 3. Build and train the model
17model = tf.keras.models.Sequential([
18  tf.keras.layers.Flatten(input_shape=(28, 28)),
19  tf.keras.layers.Dense(128, activation='relu'),
20  tf.keras.layers.Dense(10)
21])
22
23model.compile(
24    optimizer=tf.keras.optimizers.Adam(0.001),
25    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
26    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
27)
28
29model.fit(ds, epochs=6)