EleutherAI / pythia

The hub for EleutherAI's work on interpretability and learning dynamics

Are pythia_v0 and the new pythia_v1 models using the same input embedding matrix?

levmckinney opened this issue · comments

I've noticed something really odd while messing around with the pythia models for the tuned_lens project.

It seems like the input embeddings were not re-initialized between the v0 models and the current models. Was that intentional?

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = torch.device('cpu')

# Load the old (v0) and current 410M deduped checkpoints.
model_a = AutoModelForCausalLM.from_pretrained('EleutherAI/pythia-410m-deduped-v0')
model_a = model_a.to(device)
model_b = AutoModelForCausalLM.from_pretrained('EleutherAI/pythia-410m-deduped')
model_b = model_b.to(device)
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/pythia-410m-deduped')

input_ids = tokenizer.encode("it was the best of times, it was the worst of times", return_tensors="pt")

# Swap the v0 input embeddings into the current model and greedy-decode.
model_b.set_input_embeddings(model_a.get_input_embeddings())
with torch.no_grad():
    print("Outputs with the input embeddings swapped")
    print(tokenizer.decode(model_b(input_ids).logits.argmax(dim=-1)[0].tolist()))

# Sanity check: randomly re-initialize the embeddings before swapping them in.
model_a.get_input_embeddings().reset_parameters()
model_b.set_input_embeddings(model_a.get_input_embeddings())
with torch.no_grad():
    print("Sanity check: outputs with the input embeddings reset")
    print(tokenizer.decode(model_b(input_ids).logits.argmax(dim=-1)[0].tolist()))

Output:

Outputs with the input embeddings swapped
, a first of the to and was the best of times,
Sanity check: outputs with the input embeddings reset
se-. for-..of-- for.

FYI this does not seem to happen with the output embeddings.
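One way to make this more quantitative than greedy decoding is to compare the two embedding matrices directly, e.g. with per-token cosine similarity between corresponding rows. A minimal sketch, using random stand-in tensors; for the real check you would pass `model_a.get_input_embeddings().weight` and `model_b.get_input_embeddings().weight`:

```python
import torch

def rowwise_cosine(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Cosine similarity between corresponding rows (tokens) of two
    # embedding matrices of shape (vocab_size, hidden_dim).
    return torch.nn.functional.cosine_similarity(a, b, dim=-1)

# Stand-in matrices; replace with the actual embedding weights.
torch.manual_seed(0)
w = torch.randn(100, 64)
identical = rowwise_cosine(w, w)                      # all 1.0 by construction
unrelated = rowwise_cosine(w, torch.randn(100, 64))   # near 0 on average

print(identical.mean().item(), unrelated.mean().item())
```

If the v0 and v1 input embeddings really do transfer, the per-token similarities should be well above what two independently initialized matrices would give.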

Can replicate

I can replicate this! This is an interesting finding. Notably, the pythia-160m-seed{x} models (different data shuffle and random init) do not seem to allow such a transfer (https://colab.research.google.com/drive/1BYn3kO7CPbWGYrAMbWOinHayby-PHaXP?usp=sharing)

We did not use the input embeddings from Pythia v0 in training Pythia v1 whatsoever (for initialization or otherwise), so this seems to be a coincidence, though I'll dig a little deeper to rule out the possibility this is some HF bug or issue from when I uploaded the model checkpoints.
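To rule out an upload or conversion bug, a simple first check is whether the two weight matrices are elementwise identical, which would indicate copied/shared weights rather than a training coincidence. A sketch with stand-in tensors (in practice, compare `model_a.get_input_embeddings().weight` against `model_b.get_input_embeddings().weight`):

```python
import torch

def embeddings_identical(w_a: torch.Tensor, w_b: torch.Tensor,
                         atol: float = 1e-6) -> bool:
    # True only if the two matrices have the same shape and are
    # elementwise (near-)identical.
    return w_a.shape == w_b.shape and torch.allclose(w_a, w_b, atol=atol)

# Stand-ins for the real checkpoint weights.
torch.manual_seed(0)
w_v0 = torch.randn(50, 16)
w_v1 = w_v0 + 0.5 * torch.randn(50, 16)  # independently trained -> not identical

print(embeddings_identical(w_v0, w_v0))  # True
print(embeddings_identical(w_v0, w_v1))  # False
```

A False here, together with the behavioral transfer above, would suggest the matrices are merely similar in the directions the model actually uses, not duplicated.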

Closing, because this isn't an issue per se! It is an interesting result, though: evaluating the model on LAMBADA with the swapped embedding matrix resulted in an accuracy of only 3%, so benchmark performance does not appear to transfer.
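For reference, LAMBADA-style accuracy here just means the fraction of examples where the model's greedy prediction at the final position matches the gold final-word token. A minimal sketch of that metric on toy tensors (the names and shapes are illustrative, not the actual eval harness):

```python
import torch

def last_token_accuracy(logits: torch.Tensor, targets: torch.Tensor) -> float:
    # logits: (batch, vocab) at the final-word position;
    # targets: (batch,) gold final-word token ids.
    # Returns the fraction of examples where greedy argmax matches the gold id.
    return (logits.argmax(dim=-1) == targets).float().mean().item()

# Toy check: argmax matches the target for 3 of 4 examples.
logits = torch.zeros(4, 10)
targets = torch.tensor([1, 2, 3, 4])
for i, t in enumerate([1, 2, 3, 5]):  # last prediction deliberately wrong
    logits[i, t] = 1.0

print(last_token_accuracy(logits, targets))  # 0.75
```

Random guessing over a ~50k-token vocabulary is effectively 0%, so 3% shows the swapped-embedding model retains only a trace of task performance.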