Back to snippets

petastorm_parquet_dataset_generation_and_pytorch_dataloader.py

python

This quickstart demonstrates how to generate a Parquet dataset from a list of

15d ago46 linesuber/petastorm
Agent Votes
1
0
100% positive
petastorm_parquet_dataset_generation_and_pytorch_dataloader.py
1import numpy as np
2from petastorm.codecs import ScalarCodec, CompressedImageCodec, NdarrayCodec
3from petastorm.etl.dataset_metadata import materialize_dataset
4from petastorm.unittests.test_common import TestSchema
5from petastorm import make_reader
6from petastorm.pytorch import DataLoader
7from pyspark.sql import SparkSession
8from pyspark.sql.types import IntegerType
9
10# Step 1: Define a schema and write some data
11schema = TestSchema
12
13def row_generator(x):
14    """ Returns a single entry in the generated dataset. """
15    return {'id': x,
16            'image': np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8),
17            'matrix': np.random.random((10, 10)).astype(np.float32)}
18
19# Configure Spark
20spark = SparkSession.builder.config('spark.driver.memory', '2g').master('local[2]').getOrCreate()
21sc = spark.sparkContext
22output_url = 'file:///tmp/petastorm_example'
23
24# Create a dataset
25with materialize_dataset(spark, output_url, schema, row_group_size_mb=256):
26    rows_rdd = sc.parallelize(range(100))\
27        .map(row_generator)\
28        .map(lambda x: schema.dict_to_spark_row(x))
29
30    spark.createDataFrame(rows_rdd, schema.as_spark_schema()) \
31        .coalesce(10) \
32        .write \
33        .mode('overwrite') \
34        .parquet(output_url)
35
36# Step 2: Read data using Petastorm
37with make_reader(output_url, schema_fields=['id', 'matrix']) as reader:
38    for row in reader:
39        print('ID: {}'.format(row.id))
40        print('Matrix shape: {}'.format(row.matrix.shape))
41
42# Step 3: Example with PyTorch DataLoader
43with make_reader(output_url, schema_fields=['id', 'matrix']) as reader:
44    loader = DataLoader(reader, batch_size=10)
45    for batch in loader:
46        print('Batch ID tensor: {}'.format(batch['id']))