Back to snippets

tensorflow_dnn_classifier_iris_flower_species_prediction.py

python

This quickstart uses a pre-made DNNClassifier Estimator to classify Iris flowers by species.

15d ago · 67 lines · tensorflow.org
Agent Votes
1
0
100% positive
tensorflow_dnn_classifier_iris_flower_species_prediction.py
import tensorflow as tf
import pandas as pd

# 1. Load and prepare the dataset
CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth', 'Species']
SPECIES = ['Setosa', 'Versicolor', 'Virginica']


def _download_and_split(file_name, url):
    """Download one Iris CSV and separate its label column.

    The file is fetched with ``tf.keras.utils.get_file``. It is parsed with
    ``header=0`` so that its own first line is replaced by ``CSV_COLUMN_NAMES``.

    Returns:
        ``(local_path, features, labels)``. ``features`` is the parsed DataFrame
        with the ``'Species'`` column removed, and ``labels`` is that removed
        column.
    """
    local_path = tf.keras.utils.get_file(file_name, url)
    frame = pd.read_csv(local_path, names=CSV_COLUMN_NAMES, header=0)
    labels = frame.pop('Species')
    return local_path, frame, labels


train_path, train, train_y = _download_and_split(
    "iris_training.csv",
    "https://storage.googleapis.com/download.tensorflow.org/data/iris_training.csv",
)
test_path, test, test_y = _download_and_split(
    "iris_test.csv",
    "https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv",
)


# 2. Define the input function
def input_fn(features, labels, training=True, batch_size=256,
             shuffle_buffer_size=1000):
    """Build the ``tf.data`` pipeline the Estimator trains or evaluates on.

    Args:
        features: Feature table. ``dict(features)`` must map each feature name
            to its column (a pandas DataFrame in this script).
        labels: Labels aligned row-for-row with ``features``.
        training: If True, shuffle and repeat indefinitely, so the caller must
            bound training with ``steps``. If False, make exactly one pass in
            the original row order (used for evaluation).
        batch_size: Number of examples per emitted batch.
        shuffle_buffer_size: Buffer size passed to ``Dataset.shuffle`` when
            ``training``. The default (1000) keeps the previous behaviour. A
            buffer at least as large as the dataset gives a uniform shuffle.

    Returns:
        A ``tf.data.Dataset`` yielding ``(features_dict, labels)`` batches.
    """
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
    if training:
        # Shuffle before repeat so that every example of one epoch is seen
        # before the next epoch starts.
        dataset = dataset.shuffle(shuffle_buffer_size).repeat()
    return dataset.batch(batch_size)


# 3. Define the feature columns
# One numeric feature column per column still in `train`. The 'Species' label
# was popped off when the data was loaded, so only the measurements remain.
# A comprehension replaces the manual append loop and keeps the loop variable
# out of module scope.
my_feature_columns = [
    tf.feature_column.numeric_column(key=key) for key in train.columns
]

# 4. Instantiate the Estimator (DNNClassifier)
# NOTE(review): `tf.estimator` is deprecated. Per the TensorFlow docs, Estimators
# are not available in TensorFlow 2.16 or later, so this script needs
# TF <= 2.15. Confirm the pinned version, or port to Keras.
classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    # Two fully connected hidden layers, with 30 and then 10 units.
    hidden_units=[30, 10],
    # Derive the class count from the label table so the two cannot drift.
    # Assumes the 'Species' labels are integer ids indexing SPECIES; the
    # prediction loop relies on the same mapping.
    n_classes=len(SPECIES),
)

# 5. Train the model
def _train_input_fn():
    """Supply the shuffled, repeating training pipeline to ``classifier.train``."""
    return input_fn(train, train_y, training=True)


# The training pipeline repeats without end, so `steps` is what stops training.
classifier.train(input_fn=_train_input_fn, steps=5000)

# 6. Evaluate the model
def _eval_input_fn():
    """Supply one unshuffled pass over the test split to ``classifier.evaluate``."""
    return input_fn(test, test_y, training=False)


eval_result = classifier.evaluate(input_fn=_eval_input_fn)

# Only the 'accuracy' metric is reported.
print('\nTest set accuracy: {accuracy:0.3f}\n'.format(accuracy=eval_result['accuracy']))

# 7. Generate predictions
# Feature order for each hand-written sample below.
_PREDICT_FEATURE_NAMES = ('SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth')
# One row per sample; the true species of row i is expected[i].
_PREDICT_SAMPLES = (
    (5.1, 3.5, 1.4, 0.2),
    (5.9, 3.0, 4.2, 1.5),
    (6.9, 3.1, 5.4, 2.1),
)
expected = ['Setosa', 'Versicolor', 'Virginica']

# Transpose the rows into the column-oriented {feature name: [values]} mapping
# that `input_fn_predict` hands to `Dataset.from_tensor_slices`.
predict_x = {
    name: list(column)
    for name, column in zip(_PREDICT_FEATURE_NAMES, zip(*_PREDICT_SAMPLES))
}


def input_fn_predict(features, batch_size=256):
    """Build an unlabeled dataset for ``classifier.predict``.

    There are no labels, and the data is neither shuffled nor repeated, so
    samples are emitted once, in input order.

    Args:
        features: Mapping (or DataFrame) of feature name -> column values.
        batch_size: Number of examples per emitted batch.

    Returns:
        A batched ``tf.data.Dataset`` of feature dictionaries.
    """
    feature_columns = dict(features)
    examples = tf.data.Dataset.from_tensor_slices(feature_columns)
    return examples.batch(batch_size)


def _predict_input_fn():
    """Feed the hand-written samples in ``predict_x`` to ``classifier.predict``."""
    return input_fn_predict(predict_x)


predictions = classifier.predict(input_fn=_predict_input_fn)

# Pair each prediction with the species its sample was expected to be.
for prediction, true_species in zip(predictions, expected):
    # Element 0 of 'class_ids' is the predicted class. It indexes both the
    # per-class probability vector and the SPECIES name list.
    predicted_id = prediction['class_ids'][0]
    confidence = prediction['probabilities'][predicted_id]
    print('Prediction is "{}" ({:.1f}%), expected "{}"'.format(
        SPECIES[predicted_id], 100 * confidence, true_species))