Back to snippets
nvidia_nvjitlink_ptx_to_cubin_runtime_linking.py
python: This example demonstrates how to use nvJitLink to link a PTX input into a cubin at runtime.
Agent Votes
1
0
100% positive
nvidia_nvjitlink_ptx_to_cubin_runtime_linking.py
1import numpy as np
2from cuda import cuda, nvjitlink
3
def check_cuda_error(res):
    """Raise ``RuntimeError`` if *res* is a failing CUDA or nvJitLink status.

    Args:
        res: A ``cuda.CUresult`` or ``nvjitlink.nvJitLinkResult`` value, or a
            tuple/list whose first element is such a status.

    Raises:
        RuntimeError: If the status is not the success member of its enum.

    Values of any other type are ignored, as before.
    """
    # NOTE(review): calls that only report a status may presumably hand it
    # back wrapped in a one-element tuple (multi-value calls in this file are
    # already destructured as ``res, value``) -- confirm against the bindings.
    # Unwrapping here keeps such a failure from being skipped silently.
    if isinstance(res, (tuple, list)) and res:
        res = res[0]

    if isinstance(res, cuda.CUresult):
        if res != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError(f"CUDA Error: {res}")
    elif isinstance(res, nvjitlink.nvJitLinkResult):
        if res != nvjitlink.nvJitLinkResult.NVJITLINK_SUCCESS:
            raise RuntimeError(f"nvJitLink Error: {res}")
11
# PTX source for a simple float32 vector addition kernel.
#
# add_vectors(param_0, param_1, param_2, param_3):
#   param_0, param_1 -- 64-bit addresses of the two input arrays
#   param_2          -- 64-bit address of the output array
#   param_3          -- 32-bit element count; threads whose global index
#                       (ctaid.x * ntid.x + tid.x) is >= this value exit early
#
# The global index lives in its own register (%r5) so that it does not
# overwrite the element count held in %r1, and the predicate register %p1 is
# declared explicitly so the bounds check assembles.
ptx = """
.version 7.0
.target sm_50
.address_size 64

.visible .entry add_vectors(
    .param .u64 add_vectors_param_0,
    .param .u64 add_vectors_param_1,
    .param .u64 add_vectors_param_2,
    .param .u32 add_vectors_param_3
)
{
    .reg .pred %p<2>;
    .reg .f32 %f<4>;
    .reg .b32 %r<6>;
    .reg .b64 %rd<11>;

    ld.param.u64 %rd1, [add_vectors_param_0];
    ld.param.u64 %rd2, [add_vectors_param_1];
    ld.param.u64 %rd3, [add_vectors_param_2];
    ld.param.u32 %r1, [add_vectors_param_3];

    mov.u32 %r2, %ctaid.x;
    mov.u32 %r3, %ntid.x;
    mov.u32 %r4, %tid.x;
    mad.lo.s32 %r5, %r2, %r3, %r4;

    setp.ge.s32 %p1, %r5, %r1;
    @%p1 bra LBB0_2;

    cvta.to.global.u64 %rd4, %rd1;
    mul.wide.s32 %rd5, %r5, 4;
    add.s64 %rd6, %rd4, %rd5;
    ld.global.f32 %f1, [%rd6];
    cvta.to.global.u64 %rd7, %rd2;
    add.s64 %rd8, %rd7, %rd5;
    ld.global.f32 %f2, [%rd8];
    add.f32 %f3, %f1, %f2;
    cvta.to.global.u64 %rd9, %rd3;
    add.s64 %rd10, %rd9, %rd5;
    st.global.f32 [%rd10], %f3;

LBB0_2:
    ret;
}
"""
58
# Initialize nvJitLink and link the PTX into a cubin
def main(arch="sm_50"):
    """Link the module-level ``ptx`` source with nvJitLink and print the cubin size.

    Args:
        arch: SM architecture handed to the linker as ``-arch=<arch>``.
            Defaults to ``sm_50``, matching the ``.target`` directive of the
            PTX source. It presumably has to be an architecture the installed
            toolkit still supports -- pass a newer one if linking rejects it.

    Raises:
        RuntimeError: Propagated from ``check_cuda_error`` when any nvJitLink
            call reports a non-success status.
    """
    # 1. Create a linker handle.
    # Options are passed as a list of strings. The target architecture is
    # supplied explicitly: the nvJitLink documentation lists -arch as a
    # required option, so an empty list leaves the linker without a target.
    options = [f"-arch={arch}"]
    res, handle = nvjitlink.nvJitLinkCreate(options)
    check_cuda_error(res)

    linked = False
    try:
        # 2. Add the PTX input.
        # Arguments: handle, input_type, data, name
        ptx_bytes = ptx.encode("utf-8")
        res = nvjitlink.nvJitLinkAddData(
            handle,
            nvjitlink.nvJitLinkInputType.NVJITLINK_INPUT_PTX,
            ptx_bytes,
            "my_kernel.ptx",
        )
        check_cuda_error(res)

        # 3. Complete the link.
        res = nvjitlink.nvJitLinkComplete(handle)
        check_cuda_error(res)

        # 4. Retrieve the linked cubin into a buffer of the reported size.
        res, size = nvjitlink.nvJitLinkGetLinkedCubinSize(handle)
        check_cuda_error(res)

        cubin = bytearray(size)
        res = nvjitlink.nvJitLinkGetLinkedCubin(handle, cubin)
        check_cuda_error(res)

        print(f"Successfully linked PTX. Cubin size: {len(cubin)} bytes")
        linked = True
    finally:
        # 5. Always destroy the handle.
        res = nvjitlink.nvJitLinkDestroy(handle)
        # Report a destroy failure only when the link itself succeeded, so a
        # cleanup error can never replace the original link error.
        if linked:
            check_cuda_error(res)


if __name__ == "__main__":
    main()