microsoft / mup

maximal update parametrization (µP)

Home Page: https://arxiv.org/abs/2203.03466

coord_check for model that returns loss function directly

ad8e opened this issue · comments

Some transformers (like x-transformers) take in a sequence of length seq_len + 1, split it into input = x[:-1] and target = x[1:], and calculate the loss directly in forward(). This is efficient because the input and target overlap. It means that forward() returns the loss rather than the model's output (e.g. logits).
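To make the pattern concrete, here is a minimal pure-Python sketch of a model whose call returns the loss directly. It is only an illustration: `ToyLossModel`, `split_sequence`, and the toy scoring rule are hypothetical names and are not taken from x-transformers or mup, and a real model would operate on batched tensors rather than Python lists.

```python
import math


def split_sequence(tokens):
    """Split a length-(seq_len + 1) sequence into overlapping input and target."""
    return tokens[:-1], tokens[1:]


class ToyLossModel:
    """Hypothetical stand-in for a model whose forward pass returns the loss."""

    def __init__(self, vocab_size):
        self.vocab_size = vocab_size

    def logits(self, token):
        # Toy rule (an assumption, not a real network): favour the id
        # (token + 1) mod vocab_size with a score of 2.0, all others 0.0.
        return [2.0 if v == (token + 1) % self.vocab_size else 0.0
                for v in range(self.vocab_size)]

    def __call__(self, tokens):
        inputs, targets = split_sequence(tokens)
        total = 0.0
        for inp, tgt in zip(inputs, targets):
            scores = self.logits(inp)
            log_z = math.log(sum(math.exp(s) for s in scores))
            total += log_z - scores[tgt]  # cross-entropy at this position
        # The caller gets a scalar loss back, never the logits.
        return total / len(inputs)


batch = [0, 1, 2, 3, 0]            # seq_len + 1 tokens
inputs, targets = split_sequence(batch)
loss = ToyLossModel(vocab_size=4)(batch)
```

Because the caller never sees the logits, any loop that expects to compute the loss itself from the model's output needs a separate code path for models like this.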

It would be nice if coord_check had an option that supported this use case, where forward() returns the loss directly. For example, a loss_from_forward argument could be added to the function signatures, with this branch inserted:

                elif loss_from_forward:
                    if cuda:
                        batch = batch.cuda()
                    loss = model(batch)

at

else: