
keras_mnist_training_with_tensorboard_profiling_callback.py


Trains a Keras model on MNIST data and uses the TensorBoard callback to log training metrics and profile a short range of training batches.

Source: tensorflow.org
```python
import datetime

import tensorflow as tf

# Load the MNIST dataset and scale pixel values from [0, 255] to [0, 1]
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Define the model: flatten each 28x28 image, one hidden layer with dropout, 10-class softmax output
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])

# Integer class labels (0-9) pair with sparse_categorical_crossentropy
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Create a TensorBoard callback that writes to a timestamped run directory.
# histogram_freq=1 logs weight histograms at the end of every epoch.
# profile_batch selects the batch(es) to profile: '500,520' profiles batches 500 through 520,
# a single int profiles one batch, and 0 disables profiling.
# Profiling many batches slows training and produces large trace files, so keep the range short.
# The Profile tab in TensorBoard requires the tensorboard_plugin_profile package.
logs_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logs_dir,
                                                      histogram_freq=1,
                                                      profile_batch='500,520')

# Train the model with the callback, then inspect the logs with:
#   tensorboard --logdir logs/fit
model.fit(x=x_train,
          y=y_train,
          epochs=5,
          validation_data=(x_test, y_test),
          callbacks=[tensorboard_callback])
```
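The model in the snippet is small enough to count its trainable parameters by hand. The sketch below does that arithmetic in plain Python; `dense_params` is a hypothetical helper, not a Keras function. A dense layer has one weight per input/unit pair plus one bias per unit, while `Flatten` and `Dropout` add no parameters.

```python
# Hypothetical helper: parameters of a dense layer = weights (in_features * units) + biases (units).
def dense_params(in_features: int, units: int) -> int:
    return in_features * units + units


flat_features = 28 * 28                        # Flatten turns each 28x28 image into 784 values
hidden = dense_params(flat_features, 128)      # Dense(128, activation='relu')
output = dense_params(128, 10)                 # Dense(10, activation='softmax')
total = hidden + output                        # Flatten and Dropout contribute nothing

print(flat_features, hidden, output, total)    # 784 100480 1290 101770
```

The total of 101,770 is what `model.summary()` should report for this architecture.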
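Whether `profile_batch='500,520'` falls inside the first epoch depends on how many batches each epoch contains. A rough check, assuming `model.fit()` uses its default `batch_size` of 32 (the snippet does not set one) and that training batches are counted from 1 across the run:

```python
import math

num_train = 60_000   # MNIST training images returned by mnist.load_data()
batch_size = 32      # Keras default when model.fit() is given arrays and no batch_size
epochs = 5           # as in the snippet

steps_per_epoch = math.ceil(num_train / batch_size)
total_steps = steps_per_epoch * epochs

start, stop = 500, 520
# Assumption: batch s (counted from 1) belongs to epoch (s - 1) // steps_per_epoch + 1
start_epoch = (start - 1) // steps_per_epoch + 1
stop_epoch = (stop - 1) // steps_per_epoch + 1

print(steps_per_epoch, total_steps, start_epoch, stop_epoch)  # 1875 9375 1 1
```

Profiling a range well past batch 1 avoids capturing one-time warm-up work such as graph tracing, while still finishing within the first epoch here.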
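The `profile_batch` argument accepts a single non-negative integer, a pair of integers, or the legacy comma-separated string used in the snippet. The function below is a hypothetical illustration of how those forms reduce to a `(start, stop)` range; it is not the Keras implementation, and Keras performs its own validation.

```python
from typing import Tuple, Union

ProfileBatch = Union[int, str, Tuple[int, int]]


def normalize_profile_batch(profile_batch: ProfileBatch) -> Tuple[int, int]:
    """Hypothetical helper: turn a profile_batch value into a (start, stop) pair.

    (0, 0) means profiling is disabled.
    """
    if isinstance(profile_batch, str):
        # Legacy form such as '500,520' or '10'
        parts = [int(p) for p in profile_batch.split(",")]
    elif isinstance(profile_batch, int):
        parts = [profile_batch]
    else:
        parts = list(profile_batch)

    if len(parts) == 1:
        start = stop = parts[0]   # a single value profiles exactly one batch
    elif len(parts) == 2:
        start, stop = parts
    else:
        raise ValueError(f"expected one or two values, got {profile_batch!r}")

    if start < 0 or stop < start:
        raise ValueError(f"invalid profile range: {profile_batch!r}")
    return (start, stop)


print(normalize_profile_batch("500,520"))  # (500, 520)
print(normalize_profile_batch((10, 20)))   # (10, 20)
print(normalize_profile_batch(5))          # (5, 5)
print(normalize_profile_batch(0))          # (0, 0)
```

Passing the range as a tuple, for example `profile_batch=(500, 520)`, is the non-legacy spelling of the same setting.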
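Because `logs_dir` embeds a timestamp, each run writes to its own subdirectory of `logs/fit/`, and TensorBoard launched with `--logdir logs/fit` lists the runs side by side. A quick check of the directory name the snippet builds, using a fixed datetime in place of `datetime.datetime.now()`; `run_log_dir` is a hypothetical helper:

```python
import datetime


def run_log_dir(base: str, now: datetime.datetime) -> str:
    # Same construction as the snippet: base path + YYYYMMDD-HHMMSS timestamp
    return base + now.strftime("%Y%m%d-%H%M%S")


print(run_log_dir("logs/fit/", datetime.datetime(2024, 3, 5, 14, 7, 9)))  # logs/fit/20240305-140709
```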
keras_mnist_training_with_tensorboard_profiling_callback.py - Raysurfer Public Snippets