mosaicml / examples

Fast and flexible reference benchmarks

Matmul error when using output_all_encoded_layers=True and the pooler

MostHumble opened this issue · comments

Hi,

First off thanks for this great contribution!

There seems to be an issue with the handling of the encoder_outputs at the pooler level when passing output_all_encoded_layers=True:

encoder_outputs = self.encoder(
    embedding_output,
    attention_mask,
    output_all_encoded_layers=output_all_encoded_layers,
    subset_mask=subset_mask)

if masked_tokens_mask is None:
    sequence_output = encoder_outputs[-1]
    pooled_output = self.pooler(
        sequence_output) if self.pooler is not None else None
else:
    # TD [2022-03-01]: the indexing here is very tricky.
    attention_mask_bool = attention_mask.bool()
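For reference, this is roughly how I'm calling the model (a minimal sketch; config here is a placeholder for my actual model config, and the add_pooling_layer kwarg name may differ in other copies of the code):

import torch

from bert_layers_mosa import BertModel  # my local copy of the MosaicBERT layers

# Placeholder config / randomly initialized weights; only the flags below
# matter for reproducing the error.
model = BertModel(config, add_pooling_layer=True)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (4, 128))
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    # Fails inside BertPooler.forward with the matmul error below.
    outputs = model(input_ids,
                    attention_mask=attention_mask,
                    output_all_encoded_layers=True)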

When doing that, I'm getting:

File ~/.conda/envs/mimibert/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/PatientTrajectoryForecasting/utils/bert_layers_mosa.py:567, in BertPooler.forward(self, hidden_states, pool)
    561 def forward(self,
    562             hidden_states: torch.Tensor,
    563             pool: Optional[bool] = True) -> torch.Tensor:
    564     # We "pool" the model by simply taking the hidden state corresponding
    565     # to the first token.
    566     first_token_tensor = hidden_states[:, 0] if pool else hidden_states
--> 567     pooled_output = self.dense(first_token_tensor)
    568     pooled_output = self.activation(pooled_output)
    569     return pooled_output

File ~/.conda/envs/mimibert/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.conda/envs/mimibert/lib/python3.11/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x54784 and 768x768)

I believe the issue is that the padding function is not applied to the hidden states before they are appended to all_encoder_layers at the BERT encoder level. The stored layers are therefore still in the unpadded [ntokens_unpad, hidden] shape, so hidden_states[:, 0] in the pooler produces a 1-D vector instead of a [batch, hidden] tensor:

all_encoder_layers = []
if subset_mask is None:
    for layer_module in self.layer:
        hidden_states = layer_module(hidden_states,
                                     cu_seqlens,
                                     seqlen,
                                     None,
                                     indices,
                                     attn_mask=attention_mask,
                                     bias=alibi_attn_mask)
        if output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)
    # Pad inputs and mask. It will insert back zero-padded tokens.
    # Assume ntokens is total number of tokens (padded and non-padded)
    # and ntokens_unpad is total number of non-padded tokens.
    # Then padding performs the following de-compression:
    #     hidden_states[ntokens_unpad, hidden] -> hidden_states[ntokens, hidden]
    hidden_states = bert_padding_module.pad_input(
        hidden_states, indices, batch, seqlen)
else:

(Edit: yes, the change below works for me, though I haven't checked whether anything else depends on the stored layers staying unpadded.)

all_encoder_layers.append(bert_padding_module.pad_input(
    hidden_states, indices, batch, seqlen))
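For clarity, this is how the subset_mask is None branch looks with that change in place (still a sketch against my local bert_layers_mosa.py, same variable names as above):

if subset_mask is None:
    for layer_module in self.layer:
        hidden_states = layer_module(hidden_states,
                                     cu_seqlens,
                                     seqlen,
                                     None,
                                     indices,
                                     attn_mask=attention_mask,
                                     bias=alibi_attn_mask)
        if output_all_encoded_layers:
            # Re-pad before storing, so each stored layer has shape
            # [batch, seqlen, hidden] instead of [ntokens_unpad, hidden].
            all_encoder_layers.append(
                bert_padding_module.pad_input(hidden_states, indices, batch,
                                              seqlen))
    # De-compression of the final hidden state, unchanged.
    hidden_states = bert_padding_module.pad_input(
        hidden_states, indices, batch, seqlen)

Note that the running hidden_states stays unpadded inside the loop, so the attention kernels keep operating on the packed sequence; only the stored copies are padded.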

The same thing should probably be done when subset_mask is not None...
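Something along these lines, assuming that branch loops over all but the last layer before doing the subset indexing, as it does in my copy (untested sketch):

else:
    for i in range(len(self.layer) - 1):
        layer_module = self.layer[i]
        hidden_states = layer_module(hidden_states,
                                     cu_seqlens,
                                     seqlen,
                                     None,
                                     indices,
                                     attn_mask=attention_mask,
                                     bias=alibi_attn_mask)
        if output_all_encoded_layers:
            # Same fix: store a padded copy of each intermediate layer.
            all_encoder_layers.append(
                bert_padding_module.pad_input(hidden_states, indices, batch,
                                              seqlen))
    # ... the last layer then runs on the token subset as before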

Thanks again for your contribution to the community!