`sample_batch` args in `test_top_k` in gflownet.py
InfluenceFunctional opened this issue
Michael Kilgour commented
I just cloned a fresh copy of the repo in WSL and tried a run with all default configs. It looks like the `sample_batch` call on line 906 of gflownet.py is getting the wrong arguments (at least, it crashes every time). The parent function, `test_top_k`, seems to have been added on August 1.
Call:

```python
for b in batch_with_rest(0, self.logger.test.n_top_k, self.batch_size.forward):
    gfn_states += self.sample_batch(
        self.env, len(b), train=False, progress=progress
    )[0]
```
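For context on `len(b)`, here is a sketch of what `batch_with_rest(start, stop, step)` presumably does. This is a hypothetical re-implementation, not the repo's code: it assumes the helper yields consecutive index chunks of size `step` with a smaller final chunk. Under that assumption, each `len(b)` is the number of samples to draw in that batch, i.e. a forward-sample count.

```python
def batch_with_rest(start, stop, step):
    """Hypothetical sketch: yield consecutive chunks of indices of size `step`
    covering [start, stop); the final chunk holds whatever is left over."""
    for i in range(start, stop, step):
        yield list(range(i, min(i + step, stop)))


# E.g. n_top_k = 10 samples drawn with a forward batch size of 4:
batch_sizes = [len(b) for b in batch_with_rest(0, 10, 4)]
print(batch_sizes)  # [4, 4, 2]
```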
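Function definition:

```python
def sample_batch(
    self,
    n_forward: int = 0,
    n_train: int = 0,
    n_replay: int = 0,
    train=True,
    progress=False,
):
    ...  # body omitted
```

Since the signature has no `env` parameter, the call binds `self.env` to `n_forward` and `len(b)` to `n_train`.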
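To make the mismatch concrete, here is a minimal, self-contained sketch. `AgentStub` is hypothetical: it copies only the `sample_batch` signature above and echoes back how each argument was bound. The `range()` call merely stands in for whatever inside the real `sample_batch` treats `n_forward` as a count, so the actual crash site may differ. The keyword-argument call at the end is one possible fix, assumed from the loop's intent, not a confirmed patch.

```python
class AgentStub:
    """Hypothetical stand-in for the GFlowNet agent. Only the sample_batch
    signature is copied from the issue; the body just reports its arguments."""

    def __init__(self):
        self.env = object()  # placeholder for the real environment object

    def sample_batch(
        self,
        n_forward: int = 0,
        n_train: int = 0,
        n_replay: int = 0,
        train=True,
        progress=False,
    ):
        # Echo back how Python bound each argument.
        return {
            "n_forward": n_forward,
            "n_train": n_train,
            "n_replay": n_replay,
            "train": train,
        }


agent = AgentStub()
batch_size = 4  # stands in for len(b)

# Call as written in test_top_k: the env object is bound to n_forward.
buggy = agent.sample_batch(agent.env, batch_size, train=False, progress=False)
print(buggy["n_forward"] is agent.env, buggy["n_train"])  # True 4

# Anything that treats n_forward as a count then fails, e.g. range():
try:
    range(buggy["n_forward"])
except TypeError as err:
    print("TypeError:", err)

# Possible fix (an assumption, not confirmed upstream): pass the size by keyword.
fixed = agent.sample_batch(n_forward=batch_size, train=False, progress=False)
print(fixed["n_forward"], fixed["n_train"])  # 4 0
```

Passing `n_forward` by keyword also guards against the signature's positional order changing again.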