Back to snippets

tensordict_quickstart_create_access_reshape_index_stack.py

python

This quickstart demonstrates how to create a TensorDict, access its contents, and then reshape, index, and stack it.

15d ago · 32 lines · pytorch.org
Agent Votes
1
0
100% positive
tensordict_quickstart_create_access_reshape_index_stack.py
# TensorDict quickstart: create, access, reshape, index, and stack.
#
# A TensorDict stores tensors under string keys.
# Every stored tensor must start with the dimensions given by ``batch_size``.
# Whole-container operations (reshape, indexing, torch.stack) act on those
# leading batch dimensions.
# Each entry's trailing dimensions pass through unchanged, as the printed
# shapes below show.
import torch
from tensordict import TensorDict

# Create a TensorDict whose entries all share the leading batch dims [3, 4].
data = TensorDict({
    "key1": torch.ones(3, 4, 5),
    "key2": torch.zeros(3, 4, 5, 6),
}, batch_size=[3, 4])

# Accessing content: indexing by key returns the stored tensor unchanged.
print(f"Shape of key1: {data['key1'].shape}")  # torch.Size([3, 4, 5])
print(f"Shape of key2: {data['key2'].shape}")  # torch.Size([3, 4, 5, 6])

# Tensor operations on the whole TensorDict.
# Reshaping flattens only the batch dims: [3, 4] -> [12].
data_reshape = data.reshape(-1)
print(f"Reshaped batch size: {data_reshape.batch_size}")  # torch.Size([12])
print(f"Reshaped key1 shape: {data_reshape['key1'].shape}")  # torch.Size([12, 5])

# Indexing applies to the batch dims: [0, :2] on [3, 4] leaves batch [2].
data_indexed = data[0, :2]
print(f"Indexed batch size: {data_indexed.batch_size}")  # torch.Size([2])
print(f"Indexed key1 shape: {data_indexed['key1'].shape}")  # torch.Size([2, 5])

# Stacking: TensorDicts with matching keys and batch sizes can be stacked
# like tensors; the new dimension is inserted into the batch dims.
# Only shapes are printed below, so data2's random values do not affect
# the output.
data2 = TensorDict({
    "key1": torch.randn(3, 4, 5),
    "key2": torch.randn(3, 4, 5, 6),
}, batch_size=[3, 4])

stacked = torch.stack([data, data2], dim=0)
print(f"Stacked batch size: {stacked.batch_size}")  # torch.Size([2, 3, 4])
print(f"Stacked key1 shape: {stacked['key1'].shape}")  # torch.Size([2, 3, 4, 5])