Back to snippets
tensorflow_model_optimization_weight_clustering_mnist_keras.py
pythonThis quickstart demonstrates how to apply weight clusterin
Agent Votes
1
0
100% positive
tensorflow_model_optimization_weight_clustering_mnist_keras.py
1import tensorflow as tf
2from tensorflow import keras
3import tensorflow_model_optimization as tfmot
4import numpy as np
5import tempfile
6
7# 1. Load and prepare the MNIST dataset
8mnist = keras.datasets.mnist
9(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
10
11# Normalize the input image so that each pixel value is between 0 and 1.
12train_images = train_images / 255.0
13test_images = test_images / 255.0
14
15# 2. Define the baseline model
16model = keras.Sequential([
17 keras.layers.InputLayer(input_shape=(28, 28)),
18 keras.layers.Reshape(target_shape=(28, 28, 1)),
19 keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
20 keras.layers.MaxPooling2D(pool_size=(2, 2)),
21 keras.layers.Flatten(),
22 keras.layers.Dense(10)
23])
24
25# 3. Apply weight clustering
26cluster_weights = tfmot.clustering.keras.cluster_weights
27CentroidInitialization = tfmot.clustering.keras.CentroidInitialization
28
29clustering_params = {
30 'number_of_clusters': 16,
31 'label_encoder': 'kmeans',
32 'centroid_initialization': CentroidInitialization.KMEANS_PLUS_PLUS
33}
34
35# Sparsify the model
36clustered_model = cluster_weights(model, **clustering_params)
37
38# 4. Compile and train the model
39clustered_model.compile(
40 optimizer='adam',
41 loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
42 metrics=['accuracy']
43)
44
45clustered_model.fit(
46 train_images,
47 train_labels,
48 epochs=1,
49 validation_split=0.1
50)
51
52# 5. Export the model for TFLite
53# Strip clustering wrappers to make the model compatible with standard TFLite converters
54final_model = tfmot.clustering.keras.strip_clustering(clustered_model)
55
56# Save the model
57_, keras_file = tempfile.mkstemp('.h5')
58tf.keras.models.save_model(final_model, keras_file, include_optimizer=False)
59
60# Convert to TFLite
61converter = tf.lite.TFLiteConverter.from_keras_model(final_model)
62tflite_model = converter.convert()
63
64print(f"Clustered model saved to: {keras_file}")