JasonGross / guarantees-based-mechanistic-interpretability


logging a line plot every n epochs

JasonGross opened this issue · comments

I'd like to log `(model.W_E[-1] + model.W_pos[-1]) @ model.W_Q[0, 0, :, :] @ model.W_K[0, 0, :, :].T @ (model.W_E[:-1] + model.W_pos[:-1].mean(dim=0)).T` every n epochs on wandb, e.g. when calling

```
poetry run python -m gbmi.exp_max_of_n.train --max-of 2 --non-deterministic --train-for-epochs 3000 --validate-every-epochs 20 --force-adjacent-gap 0,1,2,3 --use-log1p --training-ratio 0.115 --weight-decay 1.0 --betas 0.9 0.98 --optimizer AdamW --use-end-of-sequence
```

and then get an animation or something where I can see how the line plot of this tensor varies over time. (Would be even nicer to see this plot on wandb itself.)
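For reference, the quantity above is a single row of query–key attention scores: the last-token query (embedding plus final positional embedding) scored against every other vocab token, with the key-side positional embedding averaged over positions. A minimal sketch of that computation, using random stand-in weights with TransformerLens-style shapes (the actual model and dimensions will differ):

```python
import torch

# Hypothetical shapes for a 1-layer attention-only transformer:
# d_vocab tokens, n_ctx positions, d_model residual width, d_head head width.
# All weights here are random placeholders, not the trained model's.
d_vocab, n_ctx, d_model, d_head = 64, 2, 32, 32
W_E = torch.randn(d_vocab, d_model)
W_pos = torch.randn(n_ctx, d_model)
W_Q = torch.randn(1, 1, d_model, d_head)  # [layer, head, d_model, d_head]
W_K = torch.randn(1, 1, d_model, d_head)

def qk_score_row(W_E, W_pos, W_Q, W_K):
    """Score of the final-position query token against every non-final
    vocab token, averaging the key-side positional embedding."""
    q = (W_E[-1] + W_pos[-1]) @ W_Q[0, 0]                      # [d_head]
    k = (W_E[:-1] + W_pos[:-1].mean(dim=0)) @ W_K[0, 0]        # [d_vocab-1, d_head]
    return q @ k.T                                             # [d_vocab-1]

scores = qk_score_row(W_E, W_pos, W_Q, W_K)
print(scores.shape)
```

Logging this row every n epochs gives one line-plot frame per logging step, which is exactly the data an animation over training would need.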

If #6 were fixed and we could extract the checkpoints from the runtime object returned by `train_or_load`, this would go down from `priority: medium` to `priority: low`.

@euanong does this seem feasible?

Thinking more on this, I think it's probably better to just fix #6 and allow model checkpointing, and then have a way of fetching the various checkpoints easily. Unless wandb has a way of displaying plots/animations of non-scalar values across training, it seems like a better design to not complicate the config with various objects to log during training, and just reconstruct them as desired.
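To make the "reconstruct as desired" approach concrete: once checkpoints are fetchable, the per-epoch rows can be recomputed offline and stacked into an epochs × tokens matrix for plotting or animation. A hedged sketch, assuming each checkpoint exposes the weight tensors by name (the keys, shapes, and `fake_checkpoint` helper here are illustrative stand-ins, not the repo's actual checkpoint format):

```python
import torch

def qk_score_row(state):
    """Recompute the last-token query -> key score row from one
    checkpoint's weights (TransformerLens-style names assumed)."""
    W_E, W_pos = state["W_E"], state["W_pos"]
    W_Q, W_K = state["W_Q"], state["W_K"]
    q = (W_E[-1] + W_pos[-1]) @ W_Q[0, 0]
    k = (W_E[:-1] + W_pos[:-1].mean(dim=0)) @ W_K[0, 0]
    return q @ k.T

# Hypothetical per-epoch checkpoints with random placeholder weights.
def fake_checkpoint():
    return {
        "W_E": torch.randn(64, 32),
        "W_pos": torch.randn(2, 32),
        "W_Q": torch.randn(1, 1, 32, 32),
        "W_K": torch.randn(1, 1, 32, 32),
    }

checkpoints = {epoch: fake_checkpoint() for epoch in (0, 20, 40)}

# One row per checkpoint; each row of `history` is one animation frame.
history = torch.stack([qk_score_row(s) for s in checkpoints.values()])
print(history.shape)
```

This keeps the training config free of plot-specific objects: the training loop only has to save checkpoints, and any derived quantity can be recomputed after the fact.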

I am going to declare that the fix to #6 was adequate for this.