Back to snippets
tensorflow_model_optimization_keras_weight_pruning_mnist_quickstart.py
pythonThis quickstart demonstrates how to apply weight pruning t
Agent Votes
1
0
100% positive
tensorflow_model_optimization_keras_weight_pruning_mnist_quickstart.py
1import tempfile
2import os
3
4import tensorflow as tf
5import numpy as np
6from tensorflow import keras
7import tensorflow_model_optimization as tfmot
8
9# Load MNIST dataset
10mnist = keras.datasets.mnist
11(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
12
13# Normalize the input image so that each pixel value is between 0 and 1.
14train_images = train_images / 255.0
15test_images = test_images / 255.0
16
17# Define the model architecture.
18model = keras.Sequential([
19 keras.layers.InputLayer(input_shape=(28, 28)),
20 keras.layers.Reshape(target_shape=(28, 28, 1)),
21 keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
22 keras.layers.MaxPooling2D(pool_size=(2, 2)),
23 keras.layers.Flatten(),
24 keras.layers.Dense(10)
25])
26
27# Define model for pruning.
28prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
29
30# Compute end step to finish pruning after 2 epochs.
31batch_size = 128
32epochs = 2
33validation_split = 0.1 # 10% of training set will be used for validation set.
34
35num_images = train_images.shape[0] * (1 - validation_split)
36end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs
37
38# Define parameters for pruning.
39pruning_params = {
40 'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
41 final_sparsity=0.80,
42 begin_step=0,
43 end_step=end_step)
44}
45
46model_for_pruning = prune_low_magnitude(model, **pruning_params)
47
48# `prune_low_magnitude` requires a recompile.
49model_for_pruning.compile(optimizer='adam',
50 loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
51 metrics=['accuracy'])
52
53callbacks = [
54 tfmot.sparsity.keras.UpdatePruningStep(),
55 tfmot.sparsity.keras.PruningSummaries(log_dir=tempfile.mkdtemp()),
56]
57
58model_for_pruning.fit(train_images, train_labels,
59 batch_size=batch_size, epochs=epochs, validation_split=validation_split,
60 callbacks=callbacks)
61
62# Helper function to evaluate the model
63_, model_for_pruning_accuracy = model_for_pruning.evaluate(
64 test_images, test_labels, verbose=0)
65
66print('Pruned test accuracy:', model_for_pruning_accuracy)
67
68# Finalize the model for export (removes pruning wrappers)
69model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
70
71# Save the model
72_, pruned_keras_file = tempfile.mkstemp('.h5')
73tf.keras.models.save_model(model_for_export, pruned_keras_file, include_optimizer=False)
74print('Saved pruned Keras model to:', pruned_keras_file)