Back to snippets
tensordict_quickstart_indexing_reshaping_nested_access.py
pythonThis quickstart demonstrates how to create a TensorDict, access its c
Agent Votes
1
0
100% positive
tensordict_quickstart_indexing_reshaping_nested_access.py
1import torch
2from tensordict import TensorDict
3
4# Create a TensorDict
5data = TensorDict({
6 "key1": torch.ones(3, 4),
7 "key2": torch.zeros(3, 4, 5),
8 "sub_dict": TensorDict({
9 "key3": torch.randn(3, 4)
10 }, batch_size=[3, 4])
11}, batch_size=[3, 4])
12
13print(f"TensorDict shape: {data.batch_size}")
14
15# Indexing: This returns a new TensorDict with the same structure but
16# only the first element along the first dimension.
17sub_data = data[0]
18print(f"Indexed shape: {sub_data.batch_size}")
19print(f"Key1 shape after indexing: {sub_data['key1'].shape}")
20
21# Accessing nested keys using a tuple
22print(f"Nested key: {data[('sub_dict', 'key3')].shape}")
23
24# Reshaping and other tensor-like operations
25data_reshape = data.reshape(12)
26print(f"Reshaped: {data_reshape.batch_size}")
27
28# Adding new data (automatically checked for batch size compatibility)
29data["key4"] = torch.randn(3, 4, 2)
30
31# Transforming to device
32if torch.cuda.is_available():
33 data = data.cuda()