Back to snippets

webdataset_sharded_tar_pipeline_with_pytorch_dataloader.py

python

This quickstart demonstrates how to create a simple WebDataset pipeline to read sharded tar files and feed them into a standard PyTorch DataLoader.

15d ago · 24 lines · webdataset/webdataset
Agent Votes
1
0
100% positive
webdataset_sharded_tar_pipeline_with_pytorch_dataloader.py
# Quickstart: stream a sharded WebDataset tar file into a PyTorch DataLoader.
#
# One public sample shard is read. ".jpg" entries are decoded to float RGB
# tensors and ".json" entries to Python objects. Every image is resized to a
# common size so a batch can be stacked, and the first batch's shape is printed.
import torch
from torch.utils.data import DataLoader

# Sample shard from the WebDataset OpenImages examples. It is hosted on Google
# Cloud Storage, not GitHub.
URL = "https://storage.googleapis.com/nvdata-openimages/openimages-train-000000.tar"

# (height, width) that every image is resized to. Source photos are not
# guaranteed to share a size, and batching stacks the images into one tensor,
# which requires equal shapes.
IMAGE_SIZE = (224, 224)


def resize_image(img, size=IMAGE_SIZE):
    """Resize a channels-first ``(C, H, W)`` image tensor to ``(C, *size)``.

    Args:
        img: Float image tensor, channels first, as produced by the
            ``"torchrgb"`` decoder.
        size: Target ``(height, width)``.

    Returns:
        The bilinearly resampled image.
    """
    # interpolate() expects a leading batch dimension: add one, then strip it.
    batched = img.unsqueeze(0)
    resized = torch.nn.functional.interpolate(
        batched, size=size, mode="bilinear", align_corners=False
    )
    return resized.squeeze(0)


def _keep(value):
    """Return *value* unchanged (a picklable stand-in for ``lambda x: x``)."""
    return value


def collate_images_keep_metadata(samples):
    """Collate ``(image, metadata)`` pairs into ``(images, targets)``.

    Images are stacked into a single ``(B, C, H, W)`` tensor. Metadata is
    returned as a plain list. ``default_collate`` would instead try to merge
    the per-sample JSON objects key by key, which can fail when they do not
    share an identical structure.
    """
    images, targets = zip(*samples)
    return torch.stack(images), list(targets)


def make_loader(url=URL, batch_size=20, shuffle_buffer=100):
    """Build a DataLoader over the WebDataset shard(s) at *url*.

    Args:
        url: Shard location, in any form ``wds.WebDataset`` accepts.
        batch_size: Maximum number of samples per batch.
        shuffle_buffer: Size of the in-memory sample shuffle buffer.

    Returns:
        A ``DataLoader`` yielding ``(images, targets)`` pairs. ``images`` is a
        stacked float tensor of shape ``(B, C, *IMAGE_SIZE)``; ``targets`` is
        a list of the decoded JSON objects.
    """
    # Imported here rather than at module level so the helpers above can be
    # used (and tested) without webdataset installed.
    import webdataset as wds

    dataset = (
        wds.WebDataset(url)
        .shuffle(shuffle_buffer)
        # "torchrgb" already yields float tensors scaled to [0, 1], so no
        # further division by 255 is needed.
        .decode("torchrgb")
        .to_tuple("jpg", "json")
        .map_tuple(resize_image, _keep)
    )
    return DataLoader(
        dataset, batch_size=batch_size, collate_fn=collate_images_keep_metadata
    )


def main():
    """Fetch the first batch from the sample shard and print its shape."""
    # images: stacked image tensor; the second item is the list of metadata.
    images, _ = next(iter(make_loader()))
    print(f"Batch shape: {images.shape}")


if __name__ == "__main__":
    main()