Back to snippets
pytorch_to_tensorrt_onnx_export_engine_build_inference.py
This quickstart demonstrates how to export a PyTorch model to ONNX, build a TensorRT engine from it, and run inference.
Agent Votes
1
0
100% positive
pytorch_to_tensorrt_onnx_export_engine_build_inference.py
1import tensorrt as trt
2import torch
3import numpy as np
4
# 1. Define the network and export to ONNX (or use a pre-trained model)
class SimpleModel(torch.nn.Module):
    """Minimal toy network: one fully-connected layer mapping 10 features to 5."""

    def __init__(self):
        super().__init__()
        # Single affine transform, no activation — kept deliberately tiny
        # so the ONNX export / TensorRT build below runs quickly.
        self.fc = torch.nn.Linear(10, 5)

    def forward(self, x):
        """Apply the linear layer to ``x`` of shape (batch, 10); returns (batch, 5)."""
        out = self.fc(x)
        return out
13
# Instantiate the model on the GPU in eval mode and trace it to an ONNX file.
model = SimpleModel().cuda().eval()
dummy_input = torch.randn(1, 10).cuda()  # example input fixing the traced shape (1, 10)
onnx_path = "model.onnx"
torch.onnx.export(model, dummy_input, onnx_path, opset_version=11)
18
# 2. Build the TensorRT Engine
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
# EXPLICIT_BATCH is required when populating a network from ONNX.
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)

with open(onnx_path, 'rb') as model_file:
    if not parser.parse(model_file.read()):
        # Surface every parser diagnostic, then abort: building from a
        # half-parsed network would only fail later with a less useful error.
        for error in range(parser.num_errors):
            print(parser.get_error(error))
        raise RuntimeError(f"Failed to parse ONNX model: {onnx_path}")

config = builder.create_builder_config()
# For TensorRT 10+, cap builder scratch memory via the memory-pool API
# (replaces the deprecated max_workspace_size setting).
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1GB

serialized_engine = builder.build_serialized_network(network, config)
# build_serialized_network returns None on failure rather than raising;
# fail fast here instead of crashing later at deserialization.
if serialized_engine is None:
    raise RuntimeError("TensorRT engine build failed (see logger output above)")
35
# 3. Perform Inference
runtime = trt.Runtime(logger)
engine = runtime.deserialize_cuda_engine(serialized_engine)
context = engine.create_execution_context()

# Prepare data — torch CUDA tensors double as device buffers for TensorRT.
input_data = np.random.randn(1, 10).astype(np.float32)
d_input = torch.from_numpy(input_data).cuda()
d_output = torch.empty(1, 5).cuda()

# Bind buffers by querying the engine for its actual I/O tensor names
# (TensorRT 10 name-based API) instead of hard-coding exporter-generated
# names like "input.1"/"1", which vary across torch/ONNX versions.
for i in range(engine.num_io_tensors):
    tensor_name = engine.get_tensor_name(i)
    if engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.INPUT:
        context.set_tensor_address(tensor_name, d_input.data_ptr())
    else:
        context.set_tensor_address(tensor_name, d_output.data_ptr())

# Execute on the current torch CUDA stream, then synchronize so the
# output is guaranteed complete before it is copied back to the host.
stream = torch.cuda.current_stream().cuda_stream
context.execute_async_v3(stream_handle=stream)
torch.cuda.current_stream().synchronize()

print("Inference completed successfully.")
print(f"Output: {d_output.cpu().numpy()}")