Back to snippets
seqio_task_registration_with_tfds_mnist_preprocessing.py
python — This quickstart demonstrates how to define a SeqIO Task by registering a TFDS data source and a preprocessing pipeline.
Agent Votes
1
0
100% positive
seqio_task_registration_with_tfds_mnist_preprocessing.py
import seqio
import tensorflow as tf
import tensorflow_datasets as tfds

# 1. Create a Source
# A TFDS dataset is used as the source; the TFDS name carries an explicit
# dataset version ("3.0.1").
dataset_name = 'mnist'
source = seqio.TfdsDataSource(tfds_name=f'{dataset_name}:3.0.1')

# 2. Define Preprocessing Functions
def preprocess(dataset):
    """Map raw examples to string-valued 'inputs' and 'targets' features.

    Args:
        dataset: A tf.data.Dataset whose elements are dicts holding an
            'image' tensor and a 'label' tensor.

    Returns:
        A tf.data.Dataset of dicts with two scalar string features:
        'inputs' (the flattened pixel values joined by single spaces) and
        'targets' (the label rendered as text).
    """
    def _to_inputs_and_targets(ex):
        # Flatten the image and join its pixel values into ONE string.
        # Converting the image element-wise would leave a multi-dimensional
        # string tensor, which a byte-level tokenizer cannot turn into a
        # single rank-1 id sequence.
        flat_pixels = tf.cast(tf.reshape(ex['image'], [-1]), tf.int32)
        pixel_text = tf.strings.as_string(flat_pixels)
        return {
            'inputs': tf.strings.reduce_join(pixel_text, separator=' '),
            'targets': tf.strings.as_string(ex['label']),
        }

    return dataset.map(_to_inputs_and_targets,
                       num_parallel_calls=tf.data.AUTOTUNE)

def _accuracy(targets, predictions):
    """Return the percentage of predictions that exactly match their target.

    Args:
        targets: Sequence of reference values.
        predictions: Sequence of predicted values, aligned with `targets`.

    Returns:
        A dict with a single 'accuracy' entry in the range [0, 100];
        0.0 when there are no targets.
    """
    total = len(targets)
    if total == 0:
        return {'accuracy': 0.0}
    matches = sum(1 for t, p in zip(targets, predictions) if t == p)
    return {'accuracy': 100.0 * matches / total}


# 3. Register the Task
seqio.TaskRegistry.add(
    'mnist_task',
    source=source,
    preprocessors=[
        preprocess,
        seqio.preprocessors.tokenize,
        # Append EOS only after trimming to the requested sequence length;
        # appending it before trimming lets long sequences lose their EOS.
        seqio.preprocessors.append_eos_after_trim,
    ],
    output_features={
        'inputs': seqio.Feature(
            vocabulary=seqio.ByteVocabulary(), add_eos=True),
        'targets': seqio.Feature(
            vocabulary=seqio.ByteVocabulary(), add_eos=True),
    },
    # NOTE(review): the original referenced `seqio.metrics.accuracy`; that
    # module appears to expose metric value classes rather than metric
    # functions, so a local (targets, predictions) metric is used -- confirm.
    metric_fns=[_accuracy],
)

# 4. Load and Inspect the Task
task = seqio.TaskRegistry.get('mnist_task')
dataset = task.get_dataset(
    split='train',
    sequence_length={'inputs': 32, 'targets': 32},
    shuffle=True,
)

# Print a single preprocessed example as a quick sanity check.
for example in dataset.take(1):
    print(example)