pytorch / torchtitan

A native PyTorch Library for large model training



Code change that changes the model semantics

kwen2501 opened this issue

The relevant excerpt of the Transformer's `forward`:

```python
def forward(self, tokens: torch.Tensor):
    """
    Args:
        tokens (torch.Tensor): Input token indices.

    Returns:
        torch.Tensor: Output logits after applying the Transformer model.
    """
    # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
    h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
    for layer in self.layers.values():
        h = layer(h, self.freqs_cis)
    ...  # rest of forward elided
```

  • The if-else in the line `h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens` may seem a somewhat invasive change to the original model, because it changes the model's semantics.

`tokens` are "Input token indices" according to the docstring. They should pass through an embedding (`self.tok_embeddings`) that expands them into feature vectors, i.e. maps discrete indices into a continuous feature space, and only then be processed by the Transformer layers in that feature space. Passing `tokens` directly to the Transformer layers has unclear semantics.

  • The change may also become type-unsafe once the forward function's signature is typed. Token indices are generally int64, so the Transformer module's signature may look like:

```python
def forward(self, tokens: torch.LongTensor):
```

whereas the `h` values that a non-first pipeline stage receives through `tokens` are usually bfloat16 or float32.

Dtype-annotated signatures are not uncommon; see, for example, the GPT-2 model in Hugging Face Transformers. (In fact, most models there appear to use detailed types.)
https://github.com/huggingface/transformers/blob/481a95781404e48b1c80940be17e8279dec82fe8/src/transformers/models/gpt2/modeling_gpt2.py#L975-L990
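To make both concerns concrete, here is a minimal sketch using a hypothetical `StrictStage` module with toy sizes (vocabulary 16, dimension 8). Python does not enforce annotations at runtime, so the sketch adds an explicit dtype check as a stand-in for what a type checker or input validation would flag. The first stage turns int64 indices of shape (batch, seq_len) into float32 features of shape (batch, seq_len, dim); a later stage without an embedding then receives those features through a parameter annotated as `torch.LongTensor`.

```python
import torch
import torch.nn as nn


class StrictStage(nn.Module):
    """Hypothetical module whose forward enforces its LongTensor annotation."""

    def __init__(self, vocab_size: int = 16, dim: int = 8, embed: bool = True):
        super().__init__()
        # Non-first PP stages have no embedding, as in the passthrough pattern.
        self.tok_embeddings = nn.Embedding(vocab_size, dim) if embed else None

    def forward(self, tokens: torch.LongTensor) -> torch.Tensor:
        # Annotations are not enforced at runtime, so check the dtype explicitly.
        if tokens.dtype != torch.int64:
            raise TypeError(f"expected int64 token indices, got {tokens.dtype}")
        return self.tok_embeddings(tokens) if self.tok_embeddings else tokens


torch.manual_seed(0)
first = StrictStage(embed=True)
later = StrictStage(embed=False)

tokens = torch.randint(0, 16, (2, 5))  # int64, shape (2, 5)
h = first(tokens)  # embedding output: float32, shape (2, 5, 8)
print(tokens.dtype, h.dtype, tuple(h.shape))  # torch.int64 torch.float32 (2, 5, 8)

try:
    later(h)  # with PP, a later stage receives h through `tokens`
except TypeError as err:
    print(err)  # expected int64 token indices, got torch.float32
```

The same `forward` is correct for the first stage and wrong for every other stage, which is the mismatch between the documented type and what the passthrough actually delivers.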

I agree that the semantics of the input signature are questionable. Can we improve that part?

  • What if we defined the input type as `Union[torch.LongTensor, torch.Tensor]`? Honestly, we would probably just annotate it as `torch.Tensor` and leave it at that.
  • I think we could document that `inputs` refers to token indices when the model is not used with pipeline parallelism (PP); with PP, it refers to token indices for the first PP stage and to the layer inputs (hidden states) for the other stages.
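A minimal sketch of these two suggestions combined, assuming a hypothetical `DocumentedStage` module with a `first_stage` flag and a single `nn.Linear` standing in for the Transformer layers. The argument is annotated as plain `torch.Tensor`, and the docstring states what it means for each stage; the explicit checks are an illustrative addition, not something proposed in the thread.

```python
import torch
import torch.nn as nn


class DocumentedStage(nn.Module):
    """Hypothetical model/stage illustrating the documented `inputs` contract."""

    def __init__(self, vocab_size: int = 16, dim: int = 8, first_stage: bool = True):
        super().__init__()
        # Only the first stage (or the unsplit model) owns the embedding.
        self.tok_embeddings = nn.Embedding(vocab_size, dim) if first_stage else None
        self.proj = nn.Linear(dim, dim)  # stand-in for the Transformer layers

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Args:
            inputs (torch.Tensor): Token indices (int64, shape [batch, seq_len])
                if this module owns ``tok_embeddings``, i.e. the unsplit model or
                the first PP stage; otherwise the hidden states (floating point,
                shape [batch, seq_len, dim]) produced by the previous PP stage.
        """
        if self.tok_embeddings is not None:
            if inputs.is_floating_point():
                raise TypeError("this stage expects token indices")
            h = self.tok_embeddings(inputs)
        else:
            if not inputs.is_floating_point():
                raise TypeError("this stage expects hidden states")
            h = inputs
        return self.proj(h)


torch.manual_seed(0)
first = DocumentedStage(first_stage=True)
second = DocumentedStage(first_stage=False)
tokens = torch.randint(0, 16, (2, 5))
out = second(first(tokens))
print(tuple(out.shape))  # (2, 5, 8)
```

Annotating with plain `torch.Tensor` keeps the signature honest for every stage, and the docstring carries the stage-dependent meaning that a dtype annotation cannot express.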

On the other hand, I don't see this as a bad or invasive change. We can document it better, but its net effect is to make the code easier to work with for PP, with no other issues.

If we were to define a model from scratch, then yes, we could define it in whatever way suits us.

But if we are given an existing model to serve, it is questionable whether we should recommend this approach to our audience.

This approach would also break once a layer's signature differs from the Transformer's forward signature in the number of arguments it takes.

I see this as an opportunity to show how simple it can be to use the manual PP frontend when the model code is written in a certain way.
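As a sketch of what "a certain way of writing the model code" buys, the toy model below follows the passthrough pattern from the excerpt, extended (as an assumption) to `norm` and `output`, with layers kept in an `nn.ModuleDict`. A hypothetical `make_stage` helper builds each PP stage by deep-copying the model and dropping the modules other stages own. `ToyTransformer`, `ToyLayer`, `make_stage`, and the placeholder `freqs_cis` buffer are illustrative names, not torchtitan's actual code or the PyTorch pipelining API; the sketch runs both stages in one process only to check that chaining them reproduces the full model.

```python
import copy

import torch
import torch.nn as nn


class ToyLayer(nn.Module):
    """Stand-in for a Transformer block: h -> h + MLP(h); freqs_cis is ignored."""

    def __init__(self, dim: int):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, h: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        return h + self.mlp(h)


class ToyTransformer(nn.Module):
    def __init__(self, vocab_size: int = 32, dim: int = 8, n_layers: int = 4):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        # Keyed by layer index, so deleting layers keeps the remaining names stable.
        self.layers = nn.ModuleDict({str(i): ToyLayer(dim) for i in range(n_layers)})
        self.norm = nn.LayerNorm(dim)
        self.output = nn.Linear(dim, vocab_size)
        self.register_buffer("freqs_cis", torch.zeros(1), persistent=False)  # placeholder

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # Passthrough: any module set to None is skipped.
        h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
        for layer in self.layers.values():
            h = layer(h, self.freqs_cis)
        h = self.norm(h) if self.norm else h
        return self.output(h) if self.output else h


def make_stage(model: ToyTransformer, first: int, last: int) -> ToyTransformer:
    """Hypothetical helper: keep layers [first, last) and drop other stages' modules."""
    stage = copy.deepcopy(model)
    n_layers = len(model.layers)
    if first > 0:
        stage.tok_embeddings = None  # only the first stage embeds tokens
    if last < n_layers:
        stage.norm = None  # only the last stage normalizes and produces logits
        stage.output = None
    for i in range(n_layers):
        if not first <= i < last:
            del stage.layers[str(i)]
    return stage


torch.manual_seed(0)
model = ToyTransformer().eval()
stage0 = make_stage(model, 0, 2)
stage1 = make_stage(model, 2, 4)

tokens = torch.randint(0, 32, (2, 5))
with torch.no_grad():
    full = model(tokens)
    piped = stage1(stage0(tokens))
print(torch.allclose(full, piped))  # True
```

Note that `stage1` still receives its float hidden states through the parameter named `tokens`, which is exactly the naming and typing concern raised in the issue.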

Also, I think this is in line with torchtitan's goal of not being too opinionated and instead showing off the native PT-D tech. In this case, let's show how both frontends can be applied to the model easily, with a very consistent UX between them. (There are still rough edges, but once we fix them, it would be nice to show that the same model can easily be fed into either frontend.)