
nvidia_cutlass_python_half_precision_gemm_quickstart.py

python

This quickstart demonstrates how to perform a simple half-precision matrix multiplication (GEMM) with the CUTLASS Python interface.

NVIDIA/cutlass
import torch
import cutlass

# Define the matrix dimensions
M, N, K = 128, 128, 128

# Create input tensors on the GPU.
# The CUTLASS Python interface accepts torch tensors, numpy arrays, and cupy arrays.
A = torch.randn((M, K), dtype=torch.float16, device="cuda")
B = torch.randn((K, N), dtype=torch.float16, device="cuda")
C = torch.zeros((M, N), dtype=torch.float16, device="cuda")  # bias/accumulator input
D = torch.zeros((M, N), dtype=torch.float16, device="cuda")  # output

# Create a GEMM operation.
# The Gemm class selects an appropriate kernel for the given element type and layout.
plan = cutlass.op.Gemm(element=torch.float16, layout=cutlass.LayoutType.RowMajor)

# Execute the operation: D = alpha * (A @ B) + beta * C, with alpha=1 and beta=0
# by default. Note that run() takes the bias tensor C and the output tensor D as
# separate arguments. The kernel is compiled on the first call and cached after that.
plan.run(A, B, C, D)

# Verify the result against PyTorch's built-in matmul.
# Tolerances are loose because inputs and output are float16.
expected = torch.mm(A, B)
torch.testing.assert_close(D, expected, atol=1e-2, rtol=1e-2)
print("GEMM check passed!")
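The operation the kernel computes is the general GEMM D = alpha * (A @ B) + beta * C. As a sanity check on machines without a GPU, the same computation can be sketched with a plain NumPy reference (the helper name `gemm_reference` is made up for illustration; accumulating in float32 mirrors the typical half-precision GEMM configuration):

```python
import numpy as np

def gemm_reference(A, B, C, alpha=1.0, beta=0.0):
    # General GEMM: D = alpha * (A @ B) + beta * C,
    # accumulating in float32 even for float16 inputs.
    A32, B32, C32 = (x.astype(np.float32) for x in (A, B, C))
    return alpha * (A32 @ B32) + beta * C32

# Small example mirroring the snippet above (beta=0, C all zeros).
A = np.random.randn(4, 8).astype(np.float16)
B = np.random.randn(8, 4).astype(np.float16)
C = np.zeros((4, 4), dtype=np.float16)
D = gemm_reference(A, B, C)
```

With beta=0 and C zeroed this reduces to a plain matrix product, which is why comparing against `torch.mm(A, B)` is a valid check in the snippet.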