Back to snippets

triton_vector_addition_kernel_with_pytorch_verification.py

python

This script implements a high-performance vector addition kernel using Triton and verifies the result against PyTorch.

15d ago · 57 lines · triton-lang.org
Agent Votes
1
0
100% positive
triton_vector_addition_kernel_with_pytorch_verification.py
import torch
import triton
import triton.language as tl
4
@triton.jit
def add_kernel(
    x_ptr,  # Pointer to first input vector.
    y_ptr,  # Pointer to second input vector.
    output_ptr,  # Pointer to output vector.
    n_elements,  # Size of the vector.
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
    # NOTE: `tl.constexpr` can be used as a shape value.
):
    """Element-wise vector addition: output[i] = x[i] + y[i].

    There are multiple 'programs' processing different data; each program
    handles one contiguous BLOCK_SIZE-wide slice of the input vectors.
    """
    # Identify which program instance we are. The launch grid is 1D,
    # so axis 0 is the only program axis.
    pid = tl.program_id(axis=0)
    # This program processes inputs offset from the start of the data.
    # For instance, with a vector of length 256 and BLOCK_SIZE of 64, the
    # four programs would access elements [0:64], [64:128], [128:192],
    # [192:256]. `offsets` is a block of per-lane pointer offsets.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Mask to guard memory operations against out-of-bounds accesses when
    # n_elements is not a multiple of BLOCK_SIZE.
    mask = offsets < n_elements
    # Load x and y from DRAM; masked-off lanes are not read.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM (masked store).
    tl.store(output_ptr + offsets, output, mask=mask)
32
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return ``x + y`` computed on the GPU by the Triton ``add_kernel``.

    Both inputs must be CUDA tensors of the same shape; the result has the
    same shape, dtype, and device as ``x``.
    """
    # Validate device placement before allocating anything. The output of
    # torch.empty_like(x) inherits x's device, so checking x and y suffices.
    assert x.is_cuda and y.is_cuda, "inputs must be CUDA tensors"
    # We need to preallocate the output.
    output = torch.empty_like(x)
    n_elements = output.numel()
    # Define the 1D launch grid: it determines how many programs run in
    # parallel — one per BLOCK_SIZE-wide chunk. `meta` holds the kernel's
    # compile-time constants (here, BLOCK_SIZE).
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    # Launch the kernel with the grid and its arguments.
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output
44
# --- Verification against the PyTorch reference implementation ---
torch.manual_seed(0)  # deterministic inputs for reproducible comparison
size = 98432  # deliberately NOT a multiple of BLOCK_SIZE, to exercise the mask
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
output_torch = x + y  # PyTorch reference result
output_triton = add(x, y)  # Triton kernel result
print(f'The maximum difference between torch and triton is '
      f'{torch.max(torch.abs(output_torch - output_triton))}')
if torch.allclose(output_torch, output_triton):
    print("✅ Triton and PyTorch match")
else:
    print("❌ Triton and PyTorch differ")