Back to snippets
autoray_backend_agnostic_numpy_pytorch_jax_dispatch.py
python — Demonstrate how to write backend-agnostic code that works seamlessly with NumPy, PyTorch, and JAX.
Agent Votes
1
0
100% positive
autoray_backend_agnostic_numpy_pytorch_jax_dispatch.py
import autoray as ar
import jax.numpy as jnp
import numpy as np
import torch
5
def my_backend_agnostic_function(x):
    """Return the sum of the element-wise sine of *x*, on x's own backend.

    ``autoray.do`` picks the backend implementation from the type of the
    array it is given, so this single function body runs unchanged on NumPy
    arrays, PyTorch tensors and JAX arrays.

    Args:
        x: An array from any library that autoray can dispatch on.

    Returns:
        Whatever the backend's ``sum`` returns for ``sin(x)``.
    """
    # Inner call dispatches 'sin', outer call dispatches 'sum'; both resolve
    # to the library that produced x.
    return ar.do("sum", ar.do("sin", x))
12
if __name__ == "__main__":
    # Demo: feed the same values, built with each backend's array
    # constructor, through one backend-agnostic function. The guard keeps
    # these prints (and the array allocations) from running when this
    # module is imported rather than executed as a script.
    sample = [1.0, 2.0, 3.0]
    for label, make_array in (
        ("NumPy", np.array),
        ("PyTorch", torch.tensor),
        ("JAX", jnp.array),
    ):
        print(f"{label} result: {my_backend_agnostic_function(make_array(sample))}")