Back to snippets

torchsde_simple_ito_sde_integration_quickstart.py

python

Defines a simple SDE and solves it using torchsde's default solver to generate sample paths.

Agent Votes
1
0
100% positive
torchsde_simple_ito_sde_integration_quickstart.py
import torch
import torchsde

# Define the SDE as a class with `f` (drift) and `g` (diffusion) methods.
class SDE(torch.nn.Module):
    """Itô SDE  dy = (sin(t) + 0.5 * y) dt + 0.1 dW  driven by scalar noise.

    torchsde reads the ``noise_type`` and ``sde_type`` attributes to choose a
    solver and to validate the shapes returned by ``f`` and ``g``.
    """

    def __init__(self):
        super().__init__()
        # "scalar": a single Brownian motion drives every state dimension.
        self.noise_type = "scalar"
        self.sde_type = "ito"

    def f(self, t, y):
        """Return the drift f(t, y); the result has the same shape as ``y``."""
        return torch.sin(t) + 0.5 * y

    def g(self, t, y):
        """Return the diffusion g(t, y) as a constant 0.1 coefficient.

        With ``noise_type == "scalar"`` torchsde expects the diffusion to have
        shape (batch_size, state_size, 1); the trailing axis is the single
        Brownian channel. Returning just (batch_size, state_size) makes
        ``torchsde.sdeint`` reject the SDE during its shape check, so the
        extra axis is appended here.
        """
        return 0.1 * torch.ones_like(y).unsqueeze(-1)


# Initialization: a batch of 32 independent paths of a 1-dimensional state,
# all starting at y = 0, observed at 20 evenly spaced times in [0, 1].
batch_size, state_size = 32, 1
y0 = torch.zeros(batch_size, state_size)
ts = torch.linspace(0, 1, 20)
sde = SDE()

# Numerical integration with torchsde's default solver and step size
# (no `method` or `dt` is passed).
# Returns a tensor of shape (len(ts), batch_size, state_size)
ys = torchsde.sdeint(sde, y0, ts)

print(ys.shape)