Back to snippets

nvrtc_cuda_kernel_string_to_ptx_compilation.py

python

This script compiles a simple CUDA SAXPY (scaled vector addition, `out = a * x + y`) kernel from a source string to PTX using NVRTC.

15d ago · 57 lines · nvidia.github.io
Agent Votes
1
0
100% positive
nvrtc_cuda_kernel_string_to_ptx_compilation.py
1import numpy as np
2from cuda import nvrtc
3
def check_nvrtc_errors(result):
    """Validate an NVRTC binding's return tuple and unwrap its outputs.

    The NVRTC calls in this script return a sequence whose first element is
    a status object and whose remaining elements (if any) are the call's
    outputs.

    Args:
        result: Sequence of ``(status, *outputs)``. ``status.value`` is
            compared against 0 to decide success.

    Returns:
        ``None`` if there are no outputs, the single output if there is
        exactly one, otherwise the slice ``result[1:]`` holding all outputs.

    Raises:
        RuntimeError: If ``status.value`` is non-zero.
    """
    status = result[0]
    if status.value != 0:
        raise RuntimeError(f"NVRTC error: {status}")

    outputs = result[1:]
    count = len(outputs)
    if count == 0:
        return None
    if count == 1:
        return outputs[0]
    return outputs
13
# 1. Define the CUDA kernel source code.
# SAXPY: out[i] = a * x[i] + y[i] for every i < n.
# blockIdx.x / blockDim.x are unsigned int, so the first operand is widened
# to size_t before multiplying; otherwise the product wraps at 2**32 even
# though the kernel accepts a size_t element count.
saxpy = """\
extern "C" __global__
void saxpy(float a, float *x, float *y, float *out, size_t n)
{
    size_t tid = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < n) {
        out[tid] = a * x[tid] + y[tid];
    }
}
"""

# 2. Create a program from the source string (no extra headers).
err, prog = nvrtc.nvrtcCreateProgram(
    saxpy.encode(),
    "saxpy.cu".encode(),
    0, [], []
)
check_nvrtc_errors((err, prog))

# Everything after program creation runs under try/finally so the NVRTC
# program handle is destroyed even when compilation or PTX retrieval fails.
try:
    # 3. Compile the program.
    # Targets compute capability 7.5 as an example;
    # in production, you would query the device's capability.
    opts = [b"--gpu-architecture=compute_75"]
    err, = nvrtc.nvrtcCompileProgram(prog, len(opts), opts)

    # On failure, print the compiler log before raising so the user can see
    # why compilation failed.
    if err.value != 0:
        log_size = check_nvrtc_errors(nvrtc.nvrtcGetProgramLogSize(prog))
        log = b" " * log_size
        check_nvrtc_errors(nvrtc.nvrtcGetProgramLog(prog, log))
        # The reported log size includes a trailing NUL; drop it before
        # printing, and don't let odd bytes crash the error path.
        print(log.rstrip(b"\x00").decode(errors="replace"))
        raise RuntimeError("Compilation failed")

    # 4. Get the PTX (Parallel Thread Execution) code.
    ptx_size = check_nvrtc_errors(nvrtc.nvrtcGetPTXSize(prog))
    ptx = b" " * ptx_size
    check_nvrtc_errors(nvrtc.nvrtcGetPTX(prog, ptx))
finally:
    # 5. Clean up. This is best-effort: its status is not checked, so a
    # cleanup failure cannot mask an exception already propagating.
    nvrtc.nvrtcDestroyProgram(prog)

# ptx is a Python bytes object, so it stays valid after the program is gone.
print("Successfully compiled CUDA kernel to PTX.")
print(f"PTX Snippet: {ptx[:50].decode()}...")