allenai / allennlp

An open-source NLP research library, built on PyTorch.

Home Page: http://www.allennlp.org

`MultiTaskDataLoader.__len__` is inaccurate when used with `instances_per_epoch`

lgessler opened this issue

Problem

When using `MultiTaskDataLoader` with more than one task and the `instances_per_epoch` feature, the number of batches in the epoch is overestimated: `__len__` reports T*B instead of B, where T is the number of tasks and B is the actual number of batches in the epoch. For example, with T=2 the progress bar below stops at 8/16 (50%) when the epoch ends, because the epoch really contains only 8 batches:

# Training ...
metric_1: ..., metric_2: ..., batch_loss: 2.6009, loss: 6.0547 ||:  50%|#####     | 8/16 [00:00<00:00,  9.60it/s]
2022-07-19 12:46:53,360 - INFO - my.package  - Validating
# ...

Steps to Reproduce

Configure an environment that uses `MultiTaskDataLoader` with more than one task and with `instances_per_epoch` set to some integer, then compare `len(data_loader)` against the number of batches actually produced in one epoch.
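A minimal repro sketch, assuming AllenNLP 2.x. `DummyReader` is a made-up reader just for illustration, and the exact constructor arguments of the loader, scheduler, and sampler may differ slightly between versions:

```python
from typing import Iterable

from allennlp.data import DatasetReader, Instance, Vocabulary
from allennlp.data.fields import LabelField
from allennlp.data.dataset_readers import MultiTaskDatasetReader
from allennlp.data.data_loaders import MultiTaskDataLoader
from allennlp.data.data_loaders.multitask_scheduler import HomogeneousRoundRobinScheduler
from allennlp.data.data_loaders.multitask_epoch_sampler import UniformSampler


class DummyReader(DatasetReader):
    """Hypothetical reader that ignores the file path and yields trivial instances."""

    def _read(self, file_path: str) -> Iterable[Instance]:
        for i in range(100):
            yield Instance({"label": LabelField(str(i % 2))})


reader = MultiTaskDatasetReader(readers={"task_a": DummyReader(), "task_b": DummyReader()})
loader = MultiTaskDataLoader(
    reader=reader,
    data_path={"task_a": "unused", "task_b": "unused"},
    scheduler=HomogeneousRoundRobinScheduler(batch_size=8),
    sampler=UniformSampler(),
    instances_per_epoch=64,  # 32 instances per task with a uniform sampler
)
loader.index_with(Vocabulary.from_instances(loader.iter_instances()))

print(len(loader))             # reports 16 = T * B (2 tasks * 8 batches)
print(sum(1 for _ in loader))  # actually yields 8 batches (64 instances / 8 per batch)
```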

Cause

The branch of `MultiTaskDataLoader.__len__` that is taken when `instances_per_epoch` is not `None` assumes that each dataset contributes `self._instances_per_epoch` instances to the epoch, estimating a total of `num_tasks * self._instances_per_epoch`.
However, the implementation of `MultiTaskDataLoader._get_instances_for_epoch` guarantees that the per-task instance counts across all tasks sum to approximately `self._instances_per_epoch` in total.
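In other words, the two code paths disagree. Here is a simplified, runnable paraphrase of the mismatch (not the library's exact code; a uniform split across tasks is assumed for the actual-count side):

```python
import math

def estimated_len(num_tasks: int, instances_per_epoch: int, batch_size: int) -> int:
    # __len__'s buggy estimate: pretend *every* task contributes a full
    # instances_per_epoch instances, then count batches over those counts.
    per_task = {f"task_{i}": instances_per_epoch for i in range(num_tasks)}
    return sum(math.ceil(n / batch_size) for n in per_task.values())

def actual_len(num_tasks: int, instances_per_epoch: int, batch_size: int) -> int:
    # What _get_instances_for_epoch actually does: split instances_per_epoch
    # *across* tasks, so the per-task counts sum to instances_per_epoch in total.
    per_task = {f"task_{i}": instances_per_epoch // num_tasks for i in range(num_tasks)}
    return sum(math.ceil(n / batch_size) for n in per_task.values())

print(estimated_len(2, 64, 8))  # 16 == T * B
print(actual_len(2, 64, 8))     # 8  == B
```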

Suggested Solution

Modify `MultiTaskDataLoader.__len__` to compute the batch count using the same apportioning logic as `_get_instances_for_epoch`. I'm happy to open a PR for this.
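A sketch of what that fix could look like, assuming the sampler exposes per-task proportions via something like `MultiTaskEpochSampler.get_task_proportions` (which `_get_instances_for_epoch` already relies on); attribute and method names are approximate, and the `batches_per_epoch` branch is omitted for brevity:

```python
import math

def __len__(self) -> int:
    if self._instances_per_epoch is None:
        # Unchanged: every loader contributes all of its instances.
        return self.scheduler.count_batches(
            {task: len(loader) for task, loader in self._loaders.items()}
        )
    # Apportion instances_per_epoch across tasks the same way
    # _get_instances_for_epoch does, instead of assuming each task
    # contributes a full instances_per_epoch instances on its own.
    proportions = self.sampler.get_task_proportions(self._loaders)
    total_proportion = sum(proportions.values())
    per_task = {
        task: math.floor(proportion * self._instances_per_epoch / total_proportion)
        for task, proportion in proportions.items()
    }
    return self.scheduler.count_batches(per_task)
```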

This issue is being closed due to lack of activity. If you think it still needs to be addressed, please comment on this thread 👇

Not yet addressed