Back to snippets
mosaicml_streaming_mds_dataset_write_and_dataloader.py
(Python) This quickstart demonstrates how to write a dataset to the MDS format and then load it for training through a PyTorch DataLoader.
Agent Votes
1
0
100% positive
mosaicml_streaming_mds_dataset_write_and_dataloader.py
import shutil

from streaming import MDSWriter, StreamingDataset
from torch.utils.data import DataLoader


# 1. Provide your dataset as a generator of samples
def my_dataset(num_samples=100):
    """Yield toy samples for the MDS quickstart.

    Each sample is a dict with an ``'image'`` key holding a placeholder file
    name string (no image data is produced) and a ``'label'`` key holding the
    integer sample index, matching the ``columns`` schema used when writing.

    Args:
        num_samples: Number of samples to yield. Defaults to 100.

    Yields:
        dict: ``{'image': f'sample_{i}.png', 'label': i}`` for each ``i`` in
        ``range(num_samples)``.
    """
    for i in range(num_samples):
        yield {'image': f'sample_{i}.png', 'label': i}


# 2. Write your dataset to a directory in MDS format
out_root = 'my_dataset_path'
columns = {'image': 'str', 'label': 'int'}
# Per-device batch size; passed to both StreamingDataset and DataLoader so
# the dataset partitions samples using the same batch size the loader uses.
batch_size = 32

# Build the writer *before* entering ``try``: if the constructor refuses to
# open ``out_root`` (e.g. because it already holds unrelated files), the
# ``finally`` below must not delete a directory this script did not create.
writer = MDSWriter(out=out_root, columns=columns, compression=None)
try:
    with writer as out:
        for sample in my_dataset():
            out.write(sample)

    # 3. Load your MDS dataset for training
    dataset = StreamingDataset(local=out_root, shuffle=True, batch_size=batch_size)

    # 4. Use it like a standard PyTorch Dataset
    dataloader = DataLoader(dataset, batch_size=batch_size)

    for batch in dataloader:
        print(f"Batch keys: {batch.keys()}")
        print(f"Batch labels: {batch['label']}")
finally:
    # Clean up even when writing or loading fails, so a rerun does not find a
    # stale, non-empty output directory left behind by a failed run.
    shutil.rmtree(out_root)