tcbegley / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.

[nanoChatGPT] weight tying embedding

apbard opened this issue

Shouldn't this be transposed?
i.e. `self.transformer.wte.weight = torch.t(self.lm_head.weight)`

The weights should indeed be the transpose of one another, but `nn.Embedding` and `nn.Linear` already store their weights with different orientations: `nn.Embedding.weight` is `(num_embeddings, embedding_dim)` while `nn.Linear.weight` is `(out_features, in_features)`. Both are therefore `(vocab_size, n_embd)` here, so the tensors can be tied directly without a transpose.
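
A minimal sketch of why the shapes already line up. The sizes below are illustrative (not taken from nanoChatGPT), and the modules are built standalone rather than as `self.transformer.wte` / `self.lm_head`:

```python
import torch.nn as nn

vocab_size, n_embd = 50257, 768  # illustrative sizes

# nn.Embedding stores its weight as (num_embeddings, embedding_dim)
wte = nn.Embedding(vocab_size, n_embd)

# nn.Linear stores its weight as (out_features, in_features)
lm_head = nn.Linear(n_embd, vocab_size, bias=False)

print(wte.weight.shape)      # torch.Size([50257, 768])
print(lm_head.weight.shape)  # torch.Size([50257, 768])

# Both weights are already (vocab_size, n_embd), so tying is a
# direct assignment -- no transpose needed:
wte.weight = lm_head.weight
assert wte.weight.data_ptr() == lm_head.weight.data_ptr()
```

The transpose the question expects does happen, just inside the forward pass: `nn.Linear` computes `x @ W.T`, so the stored tensors coincide even though the two layers use the weight in opposite orientations.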