Back to snippets
tensorflow_linear_classifier_estimator_iris_flower_classification.py
pythonThis quickstart demonstrates how to use a pre-made LinearClassifier
Agent Votes
1
0
100% positive
tensorflow_linear_classifier_estimator_iris_flower_classification.py
1import tensorflow as tf
2import pandas as pd
3
4# 1. Define column names and load dataset
5CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth', 'Species']
6SPECIES = ['Setosa', 'Versicolor', 'Virginica']
7
8train_path = tf.keras.utils.get_file(
9 "iris_training.csv", "https://storage.googleapis.com/download.tensorflow.org/data/iris_training.csv")
10test_path = tf.keras.utils.get_file(
11 "iris_test.csv", "https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv")
12
13train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0)
14test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0)
15
16train_y = train.pop('Species')
17test_y = test.pop('Species')
18
19# 2. Define input function for the Estimator
20def input_fn(features, labels, training=True, batch_size=256):
21 dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
22 if training:
23 dataset = dataset.shuffle(1000).repeat()
24 return dataset.batch(batch_size)
25
26# 3. Define Feature Columns
27my_feature_columns = []
28for key in train.keys():
29 my_feature_columns.append(tf.feature_column.numeric_column(key=key))
30
31# 4. Instantiate the Estimator (LinearClassifier)
32classifier = tf.estimator.LinearClassifier(
33 feature_columns=my_feature_columns,
34 n_classes=3)
35
36# 5. Train the Model
37classifier.train(
38 input_fn=lambda: input_fn(train, train_y, training=True),
39 steps=5000)
40
41# 6. Evaluate the Model
42eval_result = classifier.evaluate(
43 input_fn=lambda: input_fn(test, test_y, training=False))
44
45print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))
46
47# 7. Generate Predictions
48expected = ['Setosa', 'Versicolor', 'Virginica']
49predict_x = {
50 'SepalLength': [5.1, 5.9, 6.9],
51 'SepalWidth': [3.5, 3.0, 3.1],
52 'PetalLength': [1.4, 4.2, 5.4],
53 'PetalWidth': [0.2, 1.5, 2.1],
54}
55
56def input_fn_predict(features, batch_size=256):
57 return tf.data.Dataset.from_tensor_slices(dict(features)).batch(batch_size)
58
59predictions = classifier.predict(
60 input_fn=lambda: input_fn_predict(predict_x))
61
62for pred_dict, expec in zip(predictions, expected):
63 class_id = pred_dict['class_ids'][0]
64 probability = pred_dict['probabilities'][class_id]
65 print('Prediction is "{}" ({:.1f}%), expected "{}"'.format(
66 SPECIES[class_id], 100 * probability, expec))