google-research / torchsde

Differentiable SDE solvers with GPU support and efficient sensitivity analysis.

Computational time for Brownian Interval

qsh-zh opened this issue

I observed something strange about the computation time for the Brownian Interval.
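For reference, a setup along the following lines matches the snippets below; it is essentially the Quick example from the README, and the cnt attribute is an illustrative counter of drift evaluations rather than anything built into the library:

import torch as th
import torchsde

batch_size, state_size, brownian_size = 32, 3, 2

class SDE(th.nn.Module):
    noise_type = "general"
    sde_type = "ito"

    def __init__(self):
        super().__init__()
        self.cnt = 0  # illustrative counter of drift evaluations
        self.mu = th.nn.Linear(state_size, state_size)
        self.sigma = th.nn.Linear(state_size, state_size * brownian_size)

    def f(self, t, y):  # drift
        self.cnt += 1
        return self.mu(y)

    def g(self, t, y):  # diffusion
        return self.sigma(y).view(batch_size, state_size, brownian_size)

sde = SDE()
y0 = th.full((batch_size, state_size), 0.1)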

sde.cnt = 0
%timeit -n 2 -r 7 torchsde.sdeint(sde, y0, th.tensor([0.0, 1.0]), dt=0.01, method="euler")
print(sde.cnt)
# 1.87 s ± 60.6 ms per loop (mean ± std. dev. of 7 runs, 2 loops each)
# 1428

sde.cnt = 0
%timeit -n 2 -r 7 torchsde.sdeint(sde, y0, th.tensor([0.0, 5.0]), dt=0.05, method="euler")
print(sde.cnt)
# 57.3 ms ± 1.12 ms per loop (mean ± std. dev. of 7 runs, 2 loops each)
# 1414

sde.cnt = 0
%timeit -n 2 -r 7 torchsde.sdeint(sde, y0, th.tensor([0.0, 10.0]), dt=0.1, method="euler")
print(sde.cnt)
# 57.2 ms ± 1.05 ms per loop (mean ± std. dev. of 7 runs, 2 loops each)
# 1414

Here the sde is very similar to the one defined in the Quick example in the README. Across the three examples I scale ts and dt together, so each solve takes the same nominal number of steps (see the check below), and I would expect roughly the same computation time. But the measured times are very different. According to the paper, the worst case should be roughly O(log(T/dt)) if I understand correctly. Why is the first case so slow?
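As a sanity check, all three settings request the same nominal number of Euler steps:

for T, dt in [(1.0, 0.01), (5.0, 0.05), (10.0, 0.1)]:
    print(T, dt, T / dt)
# 1.0 0.01 100.0
# 5.0 0.05 100.0
# 10.0 0.1 100.0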

When I change from th.tensor([0.0, 1.0]), dt=0.01 to th.tensor([0.0, 5.0]), dt=0.05:

[Two screenshots attached in the original issue.]

So this is a bit weird! In particular because one tends to imagine a BrownianInterval being scale-invariant. I've been able to figure out why this happens, but I don't have a fix in mind yet.

The reason has to do with the binary tree heuristic built into the BrownianInterval. Once more than 100 queries have been made, the BrownianInterval averages the step sizes it has observed over those queries and uses that as an estimate of the average step size for the rest of the SDE solve. This estimate is used to build up a binary tree, as per Section E.2, "Backward pass", of the paper (which I cite since it sounds like you've read it).
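As a rough schematic of that tree construction (purely illustrative; none of these names are torchsde's actual internals):

def build_tree(start, end, avg_step):
    # Recursively bisect [start, end] until each leaf spans roughly
    # one average-sized solver step.
    if end - start <= avg_step:
        return (start, end)
    mid = (start + end) / 2
    return build_tree(start, mid, avg_step), build_tree(mid, end, avg_step)

# Once >100 queries have been observed, a tree like this is built over
# the remaining interval, so that subsequent queries cost O(log(T/dt)).
tree = build_tree(0.0, 1.0, avg_step=0.01)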

The dt=0.01 case makes 101 steps, which triggers the calculation of this heuristic. Evaluating that heuristic (building up the binary tree) is what takes up so much time. The other cases make only 100 steps, and the solve actually completes before the heuristic even triggers.

Why do these apparently scale-invariant cases take different numbers of steps? Floating-point inaccuracy:

>>> import torch
>>> x = torch.tensor(0.01)
>>> total = 0
>>> for _ in range(100):
...     total = total + x
...
>>> total
tensor(1.0000)
>>> total < 1
tensor(True)
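The direction of the rounding is the culprit: the nearest float32 to 0.01 is slightly below 0.01, whereas 0.05 and 0.1 round slightly above, so only the first case falls short of its endpoint:

>>> x.item()
0.009999999776482582
>>> torch.tensor(0.05).item()
0.05000000074505806
>>> torch.tensor(0.1).item()
0.10000000149011612

Hence the dt=0.01 solve needs a 101st step to reach t=1.0, crossing the >100-query threshold, while the other two solves finish in exactly 100 steps.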

The real bug here is simply that the heuristic takes so long to compute. I'll need to take a deeper look later to figure out what might be done to resolve this.