Are pythia_v0 and the new pythia_v1 models using the same input embedding matrix?
levmckinney opened this issue · comments
I've noticed something really odd while messing around with the Pythia models for the tuned_lens project.
It seems like the input embeddings were not reset from the v0 models to the current ones. Was that intentional?
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = torch.device('cpu')

model_a = AutoModelForCausalLM.from_pretrained('EleutherAI/pythia-410m-deduped-v0')
model_a = model_a.to(device)
model_b = AutoModelForCausalLM.from_pretrained('EleutherAI/pythia-410m-deduped')
model_b = model_b.to(device)

tokenizer = AutoTokenizer.from_pretrained('EleutherAI/pythia-410m-deduped')
input_ids = tokenizer.encode("it was the best of times, it was the worst of times", return_tensors="pt")

# Swap the v0 input embeddings into the current model.
model_b.set_input_embeddings(model_a.get_input_embeddings())
with torch.no_grad():
    print("Outputs with the input embeddings swapped")
    print(tokenizer.decode(model_b(input_ids).logits.argmax(dim=-1)[0].tolist()))

# Control: re-initialize the embeddings and swap again.
model_a.get_input_embeddings().reset_parameters()
model_b.set_input_embeddings(model_a.get_input_embeddings())
with torch.no_grad():
    print("Sanity check: outputs with the input embeddings reset")
    print(tokenizer.decode(model_b(input_ids).logits.argmax(dim=-1)[0].tolist()))
```
Output:

```
Outputs with the input embeddings swapped
, a first of the to and was the best of times,
Sanity check: outputs with the input embeddings reset
se-. for-..of-- for.
```
FYI this does not seem to happen with the output embeddings.
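One way to probe this directly is to compare the two embedding matrices element-wise and by per-row cosine similarity, rather than going through generation. The sketch below uses small random tensors as stand-ins for the actual `get_input_embeddings().weight` matrices (the helper `embedding_similarity` is my own illustration, not part of any library); substituting the real `model_a`/`model_b` weights would reproduce the check:

```python
import torch
import torch.nn.functional as F

def embedding_similarity(emb_a: torch.Tensor, emb_b: torch.Tensor):
    """Compare two (vocab, d_model) embedding matrices.

    Returns whether they are exactly equal, plus the mean cosine
    similarity between corresponding rows (token vectors).
    """
    exact = torch.equal(emb_a, emb_b)
    cos = F.cosine_similarity(emb_a, emb_b, dim=-1).mean().item()
    return exact, cos

# Toy stand-ins for model_a.get_input_embeddings().weight and
# model_b.get_input_embeddings().weight.
torch.manual_seed(0)
base = torch.randn(100, 16)
identical = base.clone()            # same matrix, as if never re-initialized
independent = torch.randn(100, 16)  # a fresh, independent init

print(embedding_similarity(base, identical))    # exact match, cosine ~1
print(embedding_similarity(base, independent))  # no match, cosine ~0
```

If the v0 and v1 matrices were truly carried over, the first case is what you'd see; two independent random inits should look like the second.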
I can replicate this! This is an interesting finding. Notably, the pythia-160m-seed{x} models
(different data shuffle and random init) do not seem to allow such transfer (https://colab.research.google.com/drive/1BYn3kO7CPbWGYrAMbWOinHayby-PHaXP?usp=sharing).
We did not use the input embeddings from Pythia v0 in training Pythia v1 whatsoever (for initialization or otherwise), so this seems to be a coincidence, though I'll dig a little deeper to rule out the possibility that this is an HF bug or an issue from when I uploaded the model checkpoints.
Closing, because this isn't an issue per se! Interesting result, though: evaluating the model on LAMBADA with the swapped embedding matrix resulted in an accuracy of only 3%, so benchmark performance does not appear to transfer.