`MultiTaskDataLoader.__len__` is inaccurate when used with `instances_per_epoch`
lgessler opened this issue
Problem
When using `MultiTaskDataLoader` with more than one task and the `instances_per_epoch` feature, the number of batches in the epoch is overestimated: the progress bar shows T*B instead of B, where T is the number of tasks and B is the actual number of batches for that epoch. For example, I see this output with T=2, where validation starts at 8/16 (50%) because the epoch really had only 8 batches:
```
# Training ...
metric_1: ..., metric_2: ..., batch_loss: 2.6009, loss: 6.0547 ||: 50%|##### | 8/16 [00:00<00:00, 9.60it/s]
2022-07-19 12:46:53,360 - INFO - my.package - Validating
# ...
```
Steps to Reproduce
Configure an environment that uses `MultiTaskDataLoader` with more than one task and with `instances_per_epoch` set to an integer.
Cause
The branch of `MultiTaskDataLoader.__len__` that is taken when `instances_per_epoch` is not `None` assumes that each dataset will have `self._instances_per_epoch` instances for the epoch, estimating a total of `num_tasks * self._instances_per_epoch` instances.
However, the implementation of `MultiTaskDataLoader._get_instances_for_epoch` guarantees that the instances across all tasks sum to approximately `self._instances_per_epoch`, so the length estimate is too large by a factor of roughly `num_tasks`.
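To make the mismatch concrete, here is a minimal Python sketch of the arithmetic only. It is not AllenNLP code: the helper names (`batches_for`, `estimated_len`, `actual_len`) are hypothetical, and it assumes that each task's instances are batched separately and that each task receives `floor(proportion * instances_per_epoch / sum_of_proportions)` instances per epoch.

```python
import math
from typing import Dict


def batches_for(num_instances: int, batch_size: int) -> int:
    # Assumption: each task is batched on its own, so n instances
    # produce ceil(n / batch_size) batches.
    return math.ceil(num_instances / batch_size)


def estimated_len(num_tasks: int, instances_per_epoch: int, batch_size: int) -> int:
    # The estimate described in the issue: every task is assumed to
    # contribute the full instances_per_epoch.
    return sum(batches_for(instances_per_epoch, batch_size) for _ in range(num_tasks))


def actual_len(proportions: Dict[str, float], instances_per_epoch: int, batch_size: int) -> int:
    # What the epoch actually yields: instances_per_epoch is split across
    # tasks, so the per-task counts sum to approximately instances_per_epoch
    # (flooring can drop a few instances).
    total = sum(proportions.values())
    return sum(
        batches_for(math.floor(p * instances_per_epoch / total), batch_size)
        for p in proportions.values()
    )
```

With two equally weighted tasks, `instances_per_epoch=64` and a batch size of 8, `estimated_len(2, 64, 8)` is 16 while `actual_len({"task_a": 1.0, "task_b": 1.0}, 64, 8)` is 8, the same 8/16 pattern as in the log above. With a single task the two agree, which is why the problem only appears when T > 1.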
Suggested Solution
Modify `MultiTaskDataLoader.__len__` to use the same logic as `_get_instances_for_epoch` when computing the number of batches. I'm happy to open a PR for this myself.
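One way to follow this suggestion, sketched under stated assumptions: a single hypothetical helper computes the per-task instance counts, and both the per-epoch instance selection and `__len__` read from it, so the two cannot drift apart. `EpochLengthSketch`, `_instances_per_task` and `instances_for_epoch` are illustrative names, not AllenNLP's API, and per-task batching with `math.ceil` stands in for whatever batching the configured scheduler actually performs.

```python
import math
from typing import Dict


class EpochLengthSketch:
    """Hypothetical model of a multitask loader's per-epoch bookkeeping.

    Only the arithmetic is modelled; no data is loaded.
    """

    def __init__(self, proportions: Dict[str, float], instances_per_epoch: int, batch_size: int) -> None:
        self._proportions = proportions
        self._instances_per_epoch = instances_per_epoch
        self._batch_size = batch_size

    def _instances_per_task(self) -> Dict[str, int]:
        # Single source of truth: how many instances each task gets this epoch.
        # Assumed allocation: floor(proportion * instances_per_epoch / total).
        total = sum(self._proportions.values())
        return {
            task: math.floor(p * self._instances_per_epoch / total)
            for task, p in self._proportions.items()
        }

    def instances_for_epoch(self) -> Dict[str, int]:
        # Stand-in for the per-epoch instance selection; reuses the shared helper.
        return self._instances_per_task()

    def __len__(self) -> int:
        # Count batches from the same allocation, so the progress-bar total
        # matches what the epoch actually yields (assumes per-task batching).
        return sum(
            math.ceil(n / self._batch_size)
            for n in self._instances_per_task().values()
        )
```

For example, `EpochLengthSketch({"a": 3.0, "b": 1.0}, 100, 8)` allocates 75 and 25 instances, and `len(...)` returns 10 + 4 = 14 batches. Routing both code paths through one helper is the main design point; in the real loader the batch count would presumably come from the scheduler rather than the simple ceiling used here.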
This issue is being closed due to lack of activity. If you think it still needs to be addressed, please comment on this thread 👇
Not yet addressed