Back to snippets

tensorflow_linear_classifier_estimator_iris_flower_classification.py

python

This quickstart demonstrates how to use a pre-made LinearClassifier Estimator to classify Iris flowers.

15d ago · 66 lines · tensorflow.org
Agent votes: 1 / 0 (100% positive)
tensorflow_linear_classifier_estimator_iris_flower_classification.py
import pandas as pd
import tensorflow as tf

# 1. Define column names and load dataset.
# The first four CSV columns are features; the last is the class label.
CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth', 'Species']
# SPECIES[i] is the display name for class id i.
SPECIES = ['Setosa', 'Versicolor', 'Virginica']

# Download the train/test CSVs and get their local file paths.
train_path = tf.keras.utils.get_file(
    "iris_training.csv",
    "https://storage.googleapis.com/download.tensorflow.org/data/iris_training.csv",
)
test_path = tf.keras.utils.get_file(
    "iris_test.csv",
    "https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv",
)

# With `names` given, header=0 tells pandas to discard the file's own first
# row and use CSV_COLUMN_NAMES as the column labels instead.
train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0)
test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0)

# Split off the label column; afterwards train/test hold only the features.
train_y = train.pop('Species')
test_y = test.pop('Species')

# 2. Define input function for the Estimator
def input_fn(features, labels, training=True, batch_size=256, shuffle_buffer=1000):
    """Build a batched ``tf.data.Dataset`` of ``(features_dict, labels)`` pairs.

    Args:
        features: Table of feature columns (e.g. a DataFrame); converted with
            ``dict()`` so each column name becomes a feature key.
        labels: Label values aligned row-for-row with ``features``.
        training: When True, shuffle and repeat the data indefinitely, so the
            caller must bound consumption (e.g. with ``steps``). When False,
            make a single pass in the original order.
        batch_size: Number of examples per batch.
        shuffle_buffer: Shuffle buffer size, used only when ``training``.

    Returns:
        A ``tf.data.Dataset`` yielding batches of ``(features_dict, labels)``.
    """
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
    if training:
        # repeat() with no count makes the stream infinite.
        dataset = dataset.shuffle(shuffle_buffer).repeat()
    return dataset.batch(batch_size)

# 3. Define Feature Columns: one numeric column per column left in `train`
# (the 'Species' label column has already been popped off).
# NOTE(review): tf.feature_column is deprecated in recent TensorFlow releases;
# confirm the installed version still provides it.
my_feature_columns = [tf.feature_column.numeric_column(key=key) for key in train.columns]

# 4. Instantiate the Estimator (LinearClassifier).
# NOTE(review): tf.estimator is deprecated in TensorFlow 2.x and missing from
# recent releases; confirm the installed TensorFlow version still ships it.
classifier = tf.estimator.LinearClassifier(
    feature_columns=my_feature_columns,
    # One output class per entry in SPECIES.
    n_classes=3,
)

# 5. Train the Model. The training dataset repeats indefinitely, so the
# number of steps is what ends training.
TRAIN_STEPS = 5000
classifier.train(
    # Wrapped in a lambda so the Estimator receives a callable, not a dataset.
    input_fn=lambda: input_fn(train, train_y, training=True),
    steps=TRAIN_STEPS,
)

# 6. Evaluate the Model on the test set (single pass, no shuffling).
eval_result = classifier.evaluate(
    input_fn=lambda: input_fn(test, test_y, training=False))

print(f"\nTest set accuracy: {eval_result['accuracy']:0.3f}\n")

# 7. Generate Predictions for three hand-picked samples.
# expected[i] is the species sample i is expected to be; it is used only for
# display next to each prediction.
expected = ['Setosa', 'Versicolor', 'Virginica']
# Keys mirror the feature columns the classifier was trained on; each list
# holds one value per sample, in the same order as `expected`.
predict_x = {
    'SepalLength': [5.1, 5.9, 6.9],
    'SepalWidth': [3.5, 3.0, 3.1],
    'PetalLength': [1.4, 4.2, 5.4],
    'PetalWidth': [0.2, 1.5, 2.1],
}

def input_fn_predict(features, batch_size=256):
    """Build a batched, label-free ``tf.data.Dataset`` for prediction.

    Args:
        features: Mapping of feature name -> sequence of values, one value
            per example.
        batch_size: Number of examples per batch.

    Returns:
        A ``tf.data.Dataset`` yielding batches of feature dicts. It is not
        shuffled or repeated, so examples keep their input order.
    """
    return tf.data.Dataset.from_tensor_slices(dict(features)).batch(batch_size)

predictions = classifier.predict(
    input_fn=lambda: input_fn_predict(predict_x))

# Each prediction is a dict; pair it with the species expected for that sample.
for pred_dict, expected_name in zip(predictions, expected):
    # 'class_ids' is indexed with [0] to take the single predicted class id.
    class_id = pred_dict['class_ids'][0]
    # Probability the model assigned to the predicted class.
    probability = pred_dict['probabilities'][class_id]
    print(f'Prediction is "{SPECIES[class_id]}" ({100 * probability:.1f}%), expected "{expected_name}"')