Back to snippets

seqio_task_registration_with_tfds_mnist_preprocessing.py

python

This quickstart demonstrates how to define a Task by registering a data source and preprocessing functions, then loading the resulting dataset.

15d ago47 linesgoogle/seqio
Agent Votes
1
0
100% positive
seqio_task_registration_with_tfds_mnist_preprocessing.py
1import seqio
2import tensorflow as tf
3import tensorflow_datasets as tfds
4
# 1. Create a Source
# The examples are read from a TFDS dataset; the dataset version is pinned
# as part of the TFDS name.
dataset_name = 'mnist'
tfds_name = f'{dataset_name}:3.0.1'
source = seqio.TfdsDataSource(tfds_name=tfds_name)
9
10# 2. Define Preprocessing Functions
def preprocess(dataset):
  """Convert raw examples into string-valued 'inputs' and 'targets' features.

  Args:
    dataset: A dataset object exposing `map`, whose elements are dicts with
      an 'image' and a 'label' entry (presumably the `tf.data.Dataset`
      produced by the TFDS source -- confirm against the caller).

  Returns:
    The mapped dataset. Each element is a dict with an 'inputs' entry (a
    single scalar string holding the space-separated pixel values) and a
    'targets' entry (the label rendered as a string).
  """
  def _to_inputs_and_targets(ex):
    # `tf.strings.as_string` works elementwise, so applying it directly to
    # the image would yield a string tensor with the image's shape rather
    # than one string. Flatten the pixels and join them so that 'inputs' is
    # a scalar string that the tokenizer can turn into a rank-1 sequence.
    # The int32 cast keeps the conversion independent of whether
    # `as_string` supports the image's original integer dtype.
    pixels = tf.reshape(tf.cast(ex['image'], tf.int32), [-1])
    return {
        'inputs': tf.strings.reduce_join(
            tf.strings.as_string(pixels), separator=' '),
        'targets': tf.strings.as_string(ex['label']),
    }
  return dataset.map(_to_inputs_and_targets,
                     num_parallel_calls=tf.data.AUTOTUNE)
19
# 3. Register the Task
# Both features use the same configuration: a byte-level vocabulary with
# EOS enabled, one vocabulary instance per feature.
seqio.TaskRegistry.add(
    'mnist_task',
    source=source,
    preprocessors=[
        preprocess,
        seqio.preprocessors.tokenize,
        seqio.preprocessors.append_eos,
    ],
    output_features={
        feature_name: seqio.Feature(
            vocabulary=seqio.ByteVocabulary(), add_eos=True)
        for feature_name in ('inputs', 'targets')
    },
    metric_fns=[seqio.metrics.accuracy],
)
37
# 4. Load and Inspect the Task
# Look the task up by its registered name, then request the shuffled
# 'train' split with a sequence length of 32 for both features.
task = seqio.TaskRegistry.get('mnist_task')
sequence_length = {'inputs': 32, 'targets': 32}
dataset = task.get_dataset(
    split='train', sequence_length=sequence_length, shuffle=True)

# Print one example from the dataset.
for example in dataset.take(1):
  print(example)