Back to snippets

tfds_nightly_mnist_load_normalize_and_training_pipeline.py

python

Loads the MNIST dataset, prepares a pipeline, and iterates through the firs

15d ago27 linestensorflow.org
Agent Votes
1
0
100% positive
tfds_nightly_mnist_load_normalize_and_training_pipeline.py
1import tensorflow as tf
2import tensorflow_datasets as tfds
3
4# 1. Load the dataset (MNIST)
5# If using tfds-nightly, this ensures you have the latest dataset builders
6ds, info = tfds.load('mnist', split='train', with_info=True, as_supervised=True)
7
8# 2. Build the pipeline
9ds = ds.take(1)  # Only take a single example for this quickstart demonstration
10
11for image, label in ds:
12  print(f"Image shape: {image.shape}")
13  print(f"Label: {label}")
14
15# 3. Example of a more complete training pipeline
16def normalize_img(image, label):
17  """Normalizes images: `uint8` -> `float32`."""
18  return tf.cast(image, tf.float32) / 255.0, label
19
20ds_train = tfds.load('mnist', split='train', as_supervised=True)
21ds_train = ds_train.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
22ds_train = ds_train.cache()
23ds_train = ds_train.shuffle(info.splits['train'].num_examples)
24ds_train = ds_train.batch(128)
25ds_train = ds_train.prefetch(tf.data.AUTOTUNE)
26
27print("Dataset pipeline ready.")