Back to snippets
torchsde_simple_ito_sde_integration_quickstart.py
python — Defines a simple SDE and solves it using the default solver to generate sample paths.
Agent Votes
1
0
100% positive
torchsde_simple_ito_sde_integration_quickstart.py
1import torch
2import torchsde
3
# Define the SDE as a class with `f` (drift) and `g` (diffusion) methods.
class SDE(torch.nn.Module):
    """Ito SDE  dy = (sin(t) + 0.5 * y) dt + 0.1 dW  driven by scalar noise.

    torchsde reads the ``noise_type`` / ``sde_type`` attributes and calls the
    ``f`` (drift) and ``g`` (diffusion) methods; together they describe
    dy = f(t, y) dt + g(t, y) dW.
    """

    def __init__(self):
        super().__init__()
        # "scalar": one Brownian motion shared by every state dimension.
        self.noise_type = "scalar"
        self.sde_type = "ito"

    def f(self, t, y):
        """Drift term f(t, y) = sin(t) + 0.5 * y.

        ``sin(t)`` is broadcast against ``y``, so for a scalar ``t`` the
        result has the same shape as ``y``.
        """
        return torch.sin(t) + 0.5 * y

    def g(self, t, y):
        """Constant diffusion term g(t, y) = 0.1.

        For ``noise_type == "scalar"`` torchsde expects the diffusion to have
        shape (batch_size, state_size, 1): the trailing axis is the single
        Brownian channel. A tensor shaped exactly like ``y`` is the
        *diagonal*-noise contract, not the scalar one, hence the
        ``unsqueeze(-1)``.
        """
        return 0.1 * torch.ones_like(y).unsqueeze(-1)
18
# Set-up: a batch of 32 one-dimensional trajectories, all starting at y = 0,
# evaluated at 20 evenly spaced times on [0, 1].
batch_size = 32
state_size = 1
y0 = torch.zeros((batch_size, state_size))
ts = torch.linspace(0, 1, steps=20)
sde = SDE()

# Integrate with torchsde's default solver (no `method` argument is passed).
# The solution stacks the state at every time in `ts`, giving a tensor of
# shape (len(ts), batch_size, state_size).
ys = torchsde.sdeint(sde, y0, ts)

print(ys.shape)