coord_check for a model whose forward() returns the loss directly
ad8e opened this issue
Kevin Yin commented
Some transformer implementations (such as x-transformers) take in a sequence of length seq_len+1, split it into input=x[:-1] and target=x[1:], and compute the loss directly in forward(). This is efficient because the input and target overlap in all but one position, so a single sequence serves as both. It also means that forward() returns the loss rather than the model's outputs (logits).
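To make the pattern concrete, here is a minimal, framework-free sketch, assuming a toy bigram model in place of a real transformer. The class `ToyBigramLM` is hypothetical and is not part of x-transformers or any other library; it only shows the seq_len+1 split and a forward() that returns a scalar loss instead of logits.

```python
import math
from typing import Sequence


class ToyBigramLM:
    """Hypothetical stand-in for a model whose forward() returns the loss.

    It keeps a table logits[prev_token][next_token] and, like the
    wrapper described above, receives seq_len + 1 tokens at once.
    """

    def __init__(self, vocab_size: int) -> None:
        self.vocab_size = vocab_size
        # Uniform logits: every next token is equally likely.
        self.logits = [[0.0] * vocab_size for _ in range(vocab_size)]

    def forward(self, tokens: Sequence[int]) -> float:
        # Split one sequence of length seq_len + 1 into overlapping
        # input and target views; no separate target tensor is needed.
        inp, target = tokens[:-1], tokens[1:]
        total = 0.0
        for prev, nxt in zip(inp, target):
            row = self.logits[prev]
            log_z = math.log(sum(math.exp(v) for v in row))
            total += log_z - row[nxt]  # cross-entropy at one position
        # Return the mean loss, not the logits.
        return total / len(target)

    # Make model(batch) behave like model.forward(batch).
    __call__ = forward
```

With uniform logits over 4 tokens, `ToyBigramLM(4)([0, 1, 2, 3, 0])` returns log(4), about 1.386: the caller never sees logits, only the loss.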
It would be nice if coord_check had an option supporting this use case, where forward() returns the loss directly: for example, adding a loss_from_forward argument to the function signatures and inserting this branch:
    elif loss_from_forward:
        # forward() takes the full seq_len+1 batch and returns the loss
        if cuda:
            batch = batch.cuda()
        loss = model(batch)
at line 317 in commit 1981497.
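As a rough illustration of how the proposed flag would change the loss computation, here is a sketch of a standalone dispatch helper. The name `compute_loss` and its parameters (`lossfn`, `loss_from_forward`, `cuda`) are hypothetical choices for this example and do not reproduce mup's actual coord-check code; the point is only that one branch calls the model on the whole batch and takes its return value as the loss, while the other applies a separate loss function to the model's output.

```python
from typing import Any, Callable, Optional


def compute_loss(
    model: Callable[..., Any],
    batch: Any,
    target: Any = None,
    lossfn: Optional[Callable[[Any, Any], Any]] = None,
    loss_from_forward: bool = False,
    cuda: bool = False,
) -> Any:
    """Hypothetical helper mirroring the proposed loss_from_forward branch."""
    if loss_from_forward:
        # The model splits the seq_len + 1 batch itself and returns the loss.
        if cuda:
            batch = batch.cuda()
        return model(batch)
    if lossfn is None:
        raise ValueError("lossfn is required unless loss_from_forward=True")
    # Usual path: the model returns outputs, and the loss is computed here.
    if cuda:
        batch, target = batch.cuda(), target.cuda()
    output = model(batch)
    return lossfn(output, target)
```

Keeping the branch a simple early return means the rest of the coord-check loop (backward pass, recording activation sizes) would not need to know which kind of model it is running.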