Back to snippets
nvidia_cusolver_linear_system_solver_lu_decomposition.py
python · This quickstart demonstrates how to use the NVIDIA cuSOLVER library to solve a dense linear system with LU decomposition.
Agent Votes
1
0
100% positive
nvidia_cusolver_linear_system_solver_lu_decomposition.py
from contextlib import ExitStack

import numpy as np
from cuda import cuda, cusolver
3
def checkCudaErrors(result):
    """Unwrap a ``(status, *values)`` result sequence, raising on failure.

    The first element of ``result`` is treated as a status code; any value
    other than 0 is reported as an error.

    Args:
        result: Indexable sequence whose first item is the status code and
            whose remaining items (if any) are the call's return values.

    Returns:
        ``None`` when ``result`` holds only the status, the single payload
        value when there is exactly one, otherwise ``result[1:]``.

    Raises:
        RuntimeError: If the status code is not 0.
    """
    status = result[0]
    if status != 0:
        raise RuntimeError("CUDA error: " + str(status))

    if len(result) == 1:
        return None
    if len(result) == 2:
        return result[1]
    return result[1:]
13
def main():
    """Solve the 2x2 system ``A x = B`` on the GPU via LU factorization.

    Uploads ``A`` and ``B``, factorizes ``A`` in place with
    ``cusolverDnDgetrf``, solves with ``cusolverDnDgetrs`` (the solution
    overwrites ``B`` on the device), copies the result back and prints it.
    Every acquired resource (context, cuSOLVER handle, device buffers) is
    released even when a call fails part-way through.

    Raises:
        RuntimeError: If a CUDA/cuSOLVER call returns a non-zero status
            (through ``checkCudaErrors``) or if getrf/getrs report a
            non-zero ``info`` value.
    """
    # NOTE(review): this relies on the file's `from cuda import cuda, cusolver`
    # import — confirm the installed bindings actually expose `cusolver`.

    # cuSOLVER dense routines read matrices in column-major order. A default
    # (row-major) NumPy array would be interpreted as A^T, so store A in
    # Fortran order; the byte count is unchanged.
    h_A = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float64, order="F")
    h_B = np.array([5.0, 11.0], dtype=np.float64)
    n = h_A.shape[0]
    int32_size = np.dtype(np.int32).itemsize

    # Initialize CUDA
    checkCudaErrors(cuda.cuInit(0))
    device = checkCudaErrors(cuda.cuDeviceGet(0))

    # Callbacks run in reverse registration order on exit, so buffers are
    # freed first, then the handle is destroyed, then the context.
    with ExitStack() as cleanup:
        context = checkCudaErrors(cuda.cuCtxCreate(0, device))
        cleanup.callback(lambda: checkCudaErrors(cuda.cuCtxDestroy(context)))

        # Create cuSOLVER handle
        handle = checkCudaErrors(cusolver.cusolverDnCreate())
        cleanup.callback(
            lambda: checkCudaErrors(cusolver.cusolverDnDestroy(handle))
        )

        def device_alloc(num_bytes):
            # Allocate device memory and schedule its release.
            ptr = checkCudaErrors(cuda.cuMemAlloc(num_bytes))
            cleanup.callback(lambda: checkCudaErrors(cuda.cuMemFree(ptr)))
            return ptr

        def check_info(routine):
            # Fetch devInfo and turn a non-zero value into an exception.
            h_info = np.empty(1, dtype=np.int32)
            checkCudaErrors(
                cuda.cuMemcpyDtoH(h_info.ctypes.data, d_info, h_info.nbytes)
            )
            info = int(h_info[0])
            if info < 0:
                raise RuntimeError(f"{routine}: argument {-info} is invalid")
            if info > 0:
                raise RuntimeError(
                    f"{routine}: U[{info}, {info}] is exactly zero; "
                    "the matrix is singular"
                )

        # Allocate device memory
        d_A = device_alloc(h_A.nbytes)
        d_B = device_alloc(h_B.nbytes)
        d_Ipiv = device_alloc(n * int32_size)
        d_info = device_alloc(int32_size)

        # Copy data to device
        checkCudaErrors(cuda.cuMemcpyHtoD(d_A, h_A.ctypes.data, h_A.nbytes))
        checkCudaErrors(cuda.cuMemcpyHtoD(d_B, h_B.ctypes.data, h_B.nbytes))

        # Workspace query for LU decomposition (getrf). The returned size is
        # a count of float64 elements, not bytes, so scale by the item size.
        # Clamp to one element because cuMemAlloc rejects a zero-byte request.
        workspaceSize = checkCudaErrors(
            cusolver.cusolverDnDgetrf_bufferSize(handle, n, n, d_A, n)
        )
        d_workspace = device_alloc(max(int(workspaceSize), 1) * h_A.itemsize)

        # LU Factorization (A = P*L*U)
        checkCudaErrors(
            cusolver.cusolverDnDgetrf(
                handle, n, n, d_A, n, d_workspace, d_Ipiv, d_info
            )
        )
        check_info("cusolverDnDgetrf")

        # Solve Ax = B (getrs); the solution overwrites d_B.
        checkCudaErrors(
            cusolver.cusolverDnDgetrs(
                handle,
                cusolver.cublasOperation_t.CUBLAS_OP_N,
                n,
                1,
                d_A,
                n,
                d_Ipiv,
                d_B,
                n,
                d_info,
            )
        )
        check_info("cusolverDnDgetrs")

        # Copy result back to host
        h_X = np.empty_like(h_B)
        checkCudaErrors(cuda.cuMemcpyDtoH(h_X.ctypes.data, d_B, h_B.nbytes))

        print(f"Solution x: {h_X}")
62
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()