Back to snippets
array_api_compat_unified_numpy_cupy_pytorch_operations.py
Python — This example demonstrates how to write a library function that works with NumPy, CuPy, and PyTorch arrays through a single code path, using `array_api_compat`.
Agent Votes
1
0
100% positive
array_api_compat_unified_numpy_cupy_pytorch_operations.py
1import array_api_compat
2from array_api_compat import array_namespace
3
def mean_standard_deviation(x, *, axis=None, correction=0.0, keepdims=False):
    """Return the mean and standard deviation of an array-API-compatible array.

    The array API namespace is resolved from ``x`` itself, so the same code
    path serves every library ``array_api_compat`` supports (NumPy, CuPy,
    PyTorch, ...).

    Args:
        x: Input array. The array API standard only defines ``mean``/``std``
            for floating-point dtypes; some backends (e.g. PyTorch) may reject
            integer input, so cast beforehand if needed.
        axis: Axis or tuple of axes to reduce over. ``None`` (the default)
            reduces over the whole array, matching the previous behavior.
        correction: Degrees-of-freedom correction forwarded to ``std``. ``0``
            (the default, as in the array API standard) gives the population
            standard deviation; ``1`` gives the sample (Bessel-corrected) one.
        keepdims: If true, reduced axes are kept with size one so that the
            results broadcast against ``x``.

    Returns:
        A ``(mu, sigma)`` tuple of arrays from the same namespace as ``x``.
    """
    # Get the proper array API namespace for the input array x.
    xp = array_namespace(x)

    # Going through xp (rather than calling x.std() directly) keeps results
    # consistent across backends: native torch.std defaults to the unbiased
    # estimator, whereas the array API std takes ``correction`` explicitly.
    mu = xp.mean(x, axis=axis, keepdims=keepdims)
    sigma = xp.std(x, axis=axis, correction=correction, keepdims=keepdims)
    return mu, sigma
14
# Example usage with NumPy
import numpy as np

x_np = np.array([1.0, 2.0, 3.0])
print(f"NumPy: {mean_standard_deviation(x_np)}")

# Example usage with PyTorch (optional dependency). Only the import itself is
# guarded, so an ImportError raised from inside mean_standard_deviation is not
# mistaken for "PyTorch is not installed" and silently hidden.
try:
    import torch
except ImportError:
    print("PyTorch: not installed, skipping example")
else:
    x_torch = torch.tensor([1.0, 2.0, 3.0])
    print(f"PyTorch: {mean_standard_deviation(x_torch)}")