nissymori / JAX-CORL

Clean single-file implementation of offline RL algorithms in JAX

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

[CQL] How to speed up?

nissymori opened this issue · comments

@partial(jax.jit, static_argnames=("self", "bc", "batch_size", "n"))
    def train_n_step(self, dataset, batch_size, n, bc=False):
        for _ in range(n):
            batch = batch_to_jax(subsample_batch(dataset, batch_size))
            metrics = self.train(batch, bc)
        return metrics

を追加して,forjitやってみた(どこのrepoにも残していない.).以下,1000 stepにかかる時間の比較.

method time jit time
forloop(original) 3.8s -
forjit(10) 5.6s 74s
forjit(100) 3.3s 740s(approx.)

forjit(100)は早いけど,jitの時間長すぎてoriginalと比べて早くはない.originalが相当いい実装なのかも.とりあえずこコピーは置いておいて,一旦IQLを片付けよう.