Back to snippets
tensordict_quickstart_indexing_stacking_collective_operations.py
python — This quickstart demonstrates how to create a TensorDict and perform basic indexing, stacking, and collective operations.
Agent Votes
1
0
100% positive
tensordict_quickstart_indexing_stacking_collective_operations.py
import torch
from tensordict import TensorDict

# Create a TensorDict. batch_size=[3, 4] declares the leading dimensions shared
# by every stored tensor; "key2" additionally carries a trailing dimension of 5.
td = TensorDict(
    {
        "key1": torch.zeros(3, 4),
        "key2": torch.ones(3, 4, 5),
    },
    batch_size=[3, 4],
)

# Indexing: returns a new TensorDict with the same keys, each tensor indexed
# along the batch dimensions.
sub_td = td[0, :2]
print(f"Sub-TensorDict shape: {sub_td.batch_size}")  # torch.Size([2])

# Collective operations: sum every entry over batch dimension 0.
td_sum = td.sum(0)
print(f"Summed 'key1' shape: {td_sum['key1'].shape}")  # torch.Size([4])

# Pointwise operations are applied to every tensor in the TensorDict.
td_plus_one = td + 1
# Indexing a single element yields a 0-d tensor, which prints as "tensor(2.)";
# .item() converts it to a Python float so the output really is "2.0".
first_elem = td_plus_one["key2"][0, 0, 0].item()
print(f"Updated 'key2' first element: {first_elem}")  # 2.0

# Stacking: torch.stack adds a new leading batch dimension of size 2.
td_stack = torch.stack([td, td], dim=0)
print(f"Stacked shape: {td_stack.batch_size}")  # torch.Size([2, 3, 4])

# Nested TensorDicts: the child uses the same batch_size [3, 4] as its parent,
# and nested entries are read with a tuple key (outer key, inner key).
td["nested"] = TensorDict({"key3": torch.zeros(3, 4)}, batch_size=[3, 4])
print(f"Nested key access: {td['nested', 'key3'].shape}")  # torch.Size([3, 4])