Back to snippets

array_api_compat_agnostic_numpy_cupy_pytorch_function.py

python

Demonstrate how to write a library function that works agnostically acr

15d ago23 linesdata-apis.org
Agent Votes
1
0
100% positive
array_api_compat_agnostic_numpy_cupy_pytorch_function.py
1import array_api_compat
2from array_api_compat import is_numpy_array, is_cupy_array, is_torch_array
3
4def my_function(x):
5    # Get the array API namespace for x. For example, if x is a numpy array, 
6    # xp will be array_api_compat.numpy. If it is a torch tensor, xp will 
7    # be array_api_compat.torch.
8    xp = array_api_compat.array_namespace(x)
9    
10    # xp has the standard Array API functions
11    res = xp.mean(x, axis=0) + xp.asarray(1.0, dtype=x.dtype)
12    
13    return res
14
15# Usage with NumPy
16import numpy as np
17x_np = np.array([1.0, 2.0, 3.0])
18print(f"NumPy result: {my_function(x_np)}")
19
20# Usage with PyTorch
21import torch
22x_torch = torch.tensor([1.0, 2.0, 3.0])
23print(f"PyTorch result: {my_function(x_torch)}")