kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku

Is the treatment of embedding bias in to_hf_weights.py correct?

xiaoda99 opened this issue

Hello,

mesh-transformer-jax uses a linear layer with a bias for the embedding, while the HF model has no wte.embedding.bias. The code below handles this by absorbing the bias into the weights:
https://github.com/kingoflolz/mesh-transformer-jax/blob/master/to_hf_weights.py#L386-L397
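For context, the treatment in question folds the bias of the embedding's linear layer into its weight matrix before exporting. A minimal standalone sketch of that idea (made-up names and shapes, not the actual conversion code):

```python
import jax
import jax.numpy as jnp

vocab, d_model = 8, 4
kw, kb = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(kw, (vocab, d_model))  # linear-layer weight, one row per token
b = jax.random.normal(kb, (d_model,))        # linear-layer bias

# Fold the bias into the weights: b is broadcast-added to every row,
# so looking up row i of w_hf yields w[i] + b.
w_hf = w + b
```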

I think this treatment is incorrect. IMO, there's no way of absorbing a linear layer's bias into its weights.
If we set w' = w + b, then
y = x * w + b
y' = x * w' = x * (w + b) = x * w + x * b
The only case in which y == y' is when b == 0, which is generally not true.

Da Xiao commented:

You are right that in the general case, y != y'. However, in this code, x is the result of a one-hot encoding at:

input_onehot = jax.nn.one_hot(x - shard_start_index, self.in_dim_per_shard)

This means that x is always a matrix with only 0's and 1's such that there is at most one 1 in each row. Therefore, every row of x * w is either a row from w or a row filled with 0's.

Furthermore, as long as all of the token IDs are nonnegative integers less than the vocabulary size, (the non-parallelized version of) x won't have any rows that are all 0's, so every row of x * w is a row from w. Since each row of x then contains exactly one 1, the broadcast bias is picked up unchanged, and hence (x * w) + b = x * (w + b).
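A quick standalone numerical check of this argument (not part of the repo, with made-up shapes):

```python
import jax
import jax.numpy as jnp

vocab, d_model = 8, 4
kw, kb = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(kw, (vocab, d_model))
b = jax.random.normal(kb, (d_model,))

token_ids = jnp.array([3, 0, 5])
x = jax.nn.one_hot(token_ids, vocab)  # each row contains exactly one 1

y_bias   = x @ w + b        # linear layer with an explicit bias
y_folded = x @ (w + b)      # bias absorbed into the weights

# True: because each row of x is one-hot, x @ (w + b) selects w[i] + b,
# which is exactly what x @ w + b produces for that row.
print(jnp.allclose(y_bias, y_folded))
```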

got it. thx!