NTT123 / soft-dtw-jax

Soft-DTW loss in JAX

Apply SoftDTW for EATS

v-nhandt21 opened this issue

I have some questions:

I have tried the source code of Maghoumi; it takes too much GPU memory for a sequence of 1000 frames.

Have you applied this Soft-DTW to EATS?

No, I haven't.

Can it be used as a loss function and backpropagated through?

Yes.
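For illustration, here is a minimal, naive Soft-DTW sketch in JAX (not the API of this repo; the function name, gamma value, and shapes are assumptions). Because it is built entirely from differentiable JAX ops, it can be used directly as a loss and differentiated with `jax.grad`:

```python
import jax
import jax.numpy as jnp


def soft_dtw(x, y, gamma=1.0):
    """Naive Soft-DTW between two sequences x: (N, d) and y: (M, d)."""
    # Pairwise squared-Euclidean cost matrix D[i, j] = ||x_i - y_j||^2.
    D = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    N, M = D.shape

    def softmin(*args):
        # Smoothed minimum: -gamma * logsumexp(-values / gamma).
        return -gamma * jax.scipy.special.logsumexp(jnp.stack(args) / -gamma)

    # R[i, j] = soft alignment cost of the prefixes x[:i+1], y[:j+1].
    R = jnp.zeros((N, M))
    R = R.at[0, 0].set(D[0, 0])
    for j in range(1, M):                      # first row: only moves right
        R = R.at[0, j].set(R[0, j - 1] + D[0, j])
    for i in range(1, N):                      # first column: only moves down
        R = R.at[i, 0].set(R[i - 1, 0] + D[i, 0])
    for i in range(1, N):                      # interior cells
        for j in range(1, M):
            R = R.at[i, j].set(
                D[i, j] + softmin(R[i - 1, j], R[i, j - 1], R[i - 1, j - 1])
            )
    return R[N - 1, M - 1]


# Use it as a loss and backpropagate through it.
loss_fn = lambda pred, target: soft_dtw(pred, target, gamma=0.1)
pred = jnp.ones((20, 80))     # e.g. 20 predicted frames of 80-dim mel features
target = jnp.zeros((30, 80))  # 30 reference frames
grads = jax.grad(loss_fn)(pred, target)   # gradient w.r.t. the prediction
```

Note the Python loops unroll the recursion, so this sketch is only practical for short sequences; a production implementation would typically scan over anti-diagonals instead.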

How does the implementation perform compared to this: https://github.com/Maghoumi/pytorch-softdtw-cuda?

I don't know. You have to try it out.
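If you want to compare, a minimal timing sketch in JAX could look like the following (all names are hypothetical; plug in whichever Soft-DTW loss you are testing). The `block_until_ready()` calls matter because JAX dispatches work asynchronously:

```python
import time
import jax
import jax.numpy as jnp


def benchmark(loss_fn, x, y, n_iters=20):
    """Average time per gradient step of loss_fn(x, y), in seconds."""
    grad_fn = jax.jit(jax.grad(loss_fn))
    grad_fn(x, y).block_until_ready()          # warm-up / compilation
    start = time.perf_counter()
    for _ in range(n_iters):
        grad_fn(x, y).block_until_ready()
    return (time.perf_counter() - start) / n_iters
```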

I have tried the source code of Maghoumi; it takes too much GPU memory for a sequence of 1000 frames.

I don't see a simple way to reduce the memory usage of the Soft-DTW loss: the recursion materializes an N×M alignment matrix per sequence pair (and the backward pass needs it too), so memory grows quadratically with sequence length. You would have to reduce the batch size and/or use gradient accumulation.
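A minimal gradient-accumulation sketch in JAX, assuming a generic `loss_fn(params, x, y)` built on top of a Soft-DTW loss (the names and batching scheme are assumptions, not this repo's API). Micro-batches are processed one at a time, so fewer alignment matrices are alive on the GPU at once:

```python
import jax
import jax.numpy as jnp


def accumulated_grads(loss_fn, params, x_batches, y_batches):
    """Average gradients over a list of micro-batches."""
    grad_fn = jax.jit(jax.grad(loss_fn))
    grads = None
    for xb, yb in zip(x_batches, y_batches):
        g = grad_fn(params, xb, yb)
        # Sum gradients across micro-batches, leaf by leaf.
        grads = g if grads is None else jax.tree_util.tree_map(jnp.add, grads, g)
    n = len(x_batches)
    return jax.tree_util.tree_map(lambda t: t / n, grads)
```

The optimizer update is then applied once per effective batch, using the averaged gradients, which mimics a larger batch size at the cost of extra forward/backward passes.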