Is the treatment of embedding bias in to_hf_weights.py correct?
xiaoda99 opened this issue
Hello,
mesh-transformer-jax uses a linear layer with a bias for the embedding, while the HF model has no wte.embedding.bias. The code below shows how the conversion handles this:
https://github.com/kingoflolz/mesh-transformer-jax/blob/master/to_hf_weights.py#L386-L397
I think this treatment is incorrect. IMO, there is no way to absorb a linear layer's bias into its weights.

If we set w' = w + b, then

y = x * w + b
y' = x * w' = x * (w + b) = x * w + x * b

The only case in which y == y' for arbitrary x is b == 0, which is generally not true.
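The mismatch is easy to reproduce numerically. A minimal NumPy sketch (the shapes and values are illustrative, not taken from the repo): for a generic dense input, folding the bias into the weights changes the output.

```python
import numpy as np

vocab, d = 5, 3
rng = np.random.default_rng(0)
w = rng.normal(size=(vocab, d))  # illustrative weight matrix
b = rng.normal(size=d)           # the bias in question

# A generic dense input whose rows do NOT sum to 1:
x = np.array([[1.0, 2.0, 0.0, 0.0, 0.0],
              [0.0, 0.0, 0.5, 0.5, 1.0]])

y = x @ w + b           # linear layer with bias
y_folded = x @ (w + b)  # bias folded into the weights

print(np.allclose(y, y_folded))  # → False
```

Since b broadcasts over the rows of w, x @ (w + b) equals x @ w plus each row's sum times b, so the two outputs agree only when every row of x sums to 1.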
- Da Xiao
You are right that in the general case, y != y'. However, in this code, x is the result of a one-hot encoding of the token IDs.
This means that x is always a matrix with only 0's and 1's such that there is at most one 1 in each row. Therefore, every row of x * w is either a row from w or a row filled with 0's.
Furthermore, as long as all of the token IDs are nonnegative integers less than the vocabulary size, (the non-parallelized version of) x won't have any all-zero rows, so every row of x * w is a row of w. Hence (x * w) + b = x * (w + b).
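The one-hot case can be checked with the same kind of NumPy sketch (shapes and token IDs are illustrative): when every row of x contains exactly one 1, the bias fold is exact.

```python
import numpy as np

vocab, d = 5, 3
rng = np.random.default_rng(0)
w = rng.normal(size=(vocab, d))  # illustrative weight matrix
b = rng.normal(size=d)           # the bias in question

token_ids = np.array([0, 3, 3, 1])  # all in [0, vocab)
x = np.eye(vocab)[token_ids]        # one-hot rows: exactly one 1 per row

y = x @ w + b           # linear layer with bias
y_folded = x @ (w + b)  # bias folded into the weights

print(np.allclose(y, y_folded))  # → True
```

Each row of x sums to 1, so the broadcasted term x @ (bias part) reduces to b itself, which is exactly why the conversion in to_hf_weights.py is sound for valid token IDs.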
got it. thx!