Back to snippets
nvidia_cusparselt_structured_sparse_matmul_with_pruning_and_compression.py
python: This script demonstrates a structured sparse matrix multiplication
Agent Votes
0
1
0% positive
nvidia_cusparselt_structured_sparse_matmul_with_pruning_and_compression.py
# Structured-sparse (2:4) matrix multiplication demo.
#
# Steps: prune A, describe the three matrices, compress the pruned A, build
# the matmul descriptor / algorithm selection / plan, run the product, print
# a corner of the result, and release the library handle.
#
# NOTE(review): the signatures of the `nvidia.cusparselt` wrapper used below
# are taken as-is from the original snippet and could not be checked against
# the official cuSPARSELt documentation -- confirm names and argument order
# against the installed package before relying on this script.
import torch
import numpy as np  # NOTE(review): not used in the visible part of the script
from nvidia.cusparselt import (
    cusparseLtHandle,
    cusparseLtMatDescriptor,
    cusparseLtMatmulDescriptor,
    cusparseLtMatmulAlgSelection,
    cusparseLtMatmulPlan,
    cusparseLtMatmul,
    cusparseLtSpMMPrune,
    cusparseLtSpMMCompressedSize,
    cusparseLtSpMMCompress,
    CUSPARSELT_PRUNE_SPMMA_TILE,
    CUSPARSELT_SPARSE_FORMAT_STOC_2_4,
)

# Initialize cuSPARSELt
handle = cusparseLtHandle()
handle.init()

# Everything after a successful init runs under try/finally so the handle is
# destroyed even when one of the library calls below raises.
try:
    # Problem dimensions: A is (m, k), B is (k, n), C is (m, n).
    m, n, k = 16, 16, 32
    device = torch.device("cuda")

    # Initialize dense input matrices (half precision, on the GPU)
    A = torch.randn(m, k, device=device, dtype=torch.float16)
    B = torch.randn(k, n, device=device, dtype=torch.float16)
    C = torch.zeros(m, n, device=device, dtype=torch.float16)

    # 1. Pruning: Force 2:4 structured sparsity on A
    # A's device pointer is passed twice -- presumably as source and
    # destination, which would prune A in place. TODO confirm.
    cusparseLtSpMMPrune(
        handle,
        A.data_ptr(),
        A.data_ptr(),
        CUSPARSELT_PRUNE_SPMMA_TILE,
        0,  # Stream
    )

    # 2. Setup descriptors
    matA = cusparseLtMatDescriptor()
    matB = cusparseLtMatDescriptor()
    matC = cusparseLtMatDescriptor()

    # Arguments are presumably (rows, cols, leading dimension, alignment,
    # dtype, layout) -- TODO confirm against the wrapper.
    matA.init(m, k, k, 16, torch.float16, "row")
    matB.init(k, n, n, 16, torch.float16, "row")
    matC.init(m, n, n, 16, torch.float16, "row")

    # 3. Compress the sparse matrix A
    # Get required size for compressed buffer; the buffer is allocated as
    # raw bytes (uint8) of that length.
    compressed_size = cusparseLtSpMMCompressedSize(handle, matA)
    A_compressed = torch.empty(compressed_size, device=device, dtype=torch.uint8)

    # Perform compression
    cusparseLtSpMMCompress(
        handle,
        matA,
        A.data_ptr(),
        A_compressed.data_ptr(),
        0,  # Stream
    )

    # 4. Matmul execution
    # matC is passed twice -- presumably as both the accumulate-input and the
    # output descriptor; torch.float32 is presumably the compute type.
    matmul = cusparseLtMatmulDescriptor()
    matmul.init(handle, matA, matB, matC, matC, torch.float32)

    alg_sel = cusparseLtMatmulAlgSelection()
    alg_sel.init(handle, matmul, CUSPARSELT_SPARSE_FORMAT_STOC_2_4)

    plan = cusparseLtMatmulPlan()
    plan.init(handle, matmul, alg_sel)

    # Execute sparse matrix multiplication: C = A_sparse * B
    # C's pointer is passed twice, mirroring the duplicated matC descriptor.
    cusparseLtMatmul(
        handle,
        plan,
        A_compressed.data_ptr(),
        B.data_ptr(),
        C.data_ptr(),
        C.data_ptr(),
        None,  # Workspace
        0,  # Stream
    )

    print("Sparse Matrix Multiplication Completed Successfully.")
    print(f"Result matrix C (first 5x5):\n{C[:5, :5]}")
finally:
    # Cleanup
    handle.destroy()