Back to snippets

tensordict_quickstart_indexing_stacking_collective_operations.py

python

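This quickstart demonstrates how to create a TensorDict, perform basic indexing, apply collective (reduction) and pointwise operations, stack TensorDicts, and nest one TensorDict inside another.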

15d ago · 28 lines · pytorch.org
Agent Votes
1
0
100% positive
tensordict_quickstart_indexing_stacking_collective_operations.py
"""TensorDict quickstart: creation, indexing, reductions, pointwise ops, stacking, nesting.

Each step prints one value. The trailing comment on each print shows the
expected output.
"""
import torch
from tensordict import TensorDict

# Create a TensorDict. batch_size=[3, 4] declares the leading dims shared by
# the entries; trailing dims (the 5 in "key2") belong to the individual entry.
td = TensorDict(
    {
        "key1": torch.zeros(3, 4),
        "key2": torch.ones(3, 4, 5),
    },
    batch_size=[3, 4],
)

# Indexing: the index is applied to the batch dims of every entry and a new
# TensorDict is returned. td[0, :2] therefore has batch_size [2].
sub_td = td[0, :2]
print(f"Sub-TensorDict shape: {sub_td.batch_size}")  # torch.Size([2])

# Collective operations: the reduction is applied to every entry. Summing over
# batch dim 0 turns "key1" from [3, 4] into [4] (and "key2" into [4, 5]).
td_sum = td.sum(0)
print(f"Summed 'key1' shape: {td_sum['key1'].shape}")  # torch.Size([4])

# Pointwise operations: "+ 1" is applied to every entry, so the all-ones
# "key2" becomes all twos.
td_plus_one = td + 1
# .item() converts the 0-d tensor to a Python float, so this prints 2.0
# rather than tensor(2.).
print(f"Updated 'key2' first element: {td_plus_one['key2'][0, 0, 0].item()}")  # 2.0

# Stacking: torch.stack on TensorDicts inserts a new batch dim at `dim`.
td_stack = torch.stack([td, td], dim=0)
print(f"Stacked shape: {td_stack.batch_size}")  # torch.Size([2, 3, 4])

# Nesting: a TensorDict (here with the same batch_size [3, 4]) is stored as an
# entry of `td`; its leaves are addressed with a tuple key.
td["nested"] = TensorDict({"key3": torch.zeros(3, 4)}, batch_size=[3, 4])
print(f"Nested key access: {td['nested', 'key3'].shape}")  # torch.Size([3, 4])