Back to snippets

keras_mnist_classifier_training_with_tensorboard_callback_logging.py

python

This quickstart demonstrates how to train a simple MNIST classifier using Keras and log the training run to TensorBoard.

15d ago · 32 lines · tensorflow.org
Agent Votes
1
0
100% positive
keras_mnist_classifier_training_with_tensorboard_callback_logging.py
import datetime

import tensorflow as tf

# Load the MNIST dataset through Keras and scale pixel values into [0, 1].
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Cast to float32 before dividing: dividing the raw integer image arrays by a
# Python float produces float64 arrays, doubling the memory held for the
# images, while the model's layers compute in float32 by default anyway.
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

def create_model(*, hidden_units=512, dropout_rate=0.2, num_classes=10,
                 input_shape=(28, 28)):
    """Build an uncompiled MLP classifier for small grayscale images.

    The network flattens each input example, applies a single ReLU hidden
    layer followed by dropout, and ends in a softmax over ``num_classes``
    outputs. All defaults reproduce the original MNIST architecture.

    Args:
        hidden_units: Width of the hidden ``Dense`` layer.
        dropout_rate: Fraction of hidden activations dropped during training.
        num_classes: Number of output classes (size of the softmax layer).
        input_shape: Shape of one input example, excluding the batch axis.

    Returns:
        A ``tf.keras.models.Sequential`` model. It is not compiled; the
        caller is responsible for calling ``compile``.
    """
    return tf.keras.models.Sequential([
        # Declare the input with an Input object instead of passing
        # ``input_shape`` to Flatten, which Keras 3 deprecates with a warning.
        tf.keras.Input(shape=input_shape),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(hidden_units, activation='relu'),
        tf.keras.layers.Dropout(dropout_rate),
        tf.keras.layers.Dense(num_classes, activation='softmax'),
    ])

# Instantiate the classifier and configure training: Adam optimizer, sparse
# categorical cross-entropy (labels are integer class indices, not one-hot),
# and accuracy reported as a metric.
model = create_model()
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

# Give every run its own log directory, named by the local start time
# (second resolution), under logs/fit/.
run_timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = "logs/fit/" + run_timestamp

# histogram_freq=1: compute weight histograms every epoch.
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
)

# Train for five epochs, validating on the test split after each epoch and
# streaming logs to TensorBoard through the callback.
model.fit(
    x=x_train,
    y=y_train,
    epochs=5,
    validation_data=(x_test, y_test),
    callbacks=[tensorboard_callback],
)