pruned error
renmada opened this issue · comments
model is bert_base
When I prune, I set num_heads to 8.
Before pruning, the shape of the key (value) weight is [768, 768].
After pruning, the shape of the key (value) weight becomes [512, 768],
so the saved weights can't be loaded by transformers.
Hi @renmada,
It worked as intended. The pruned model can be consumed by this fastformers repository or onnxruntime v1.8.0+.
Sorry, I'm not sure I understand your reply. Please help me confirm whether my understanding matches yours.
When you say "It worked as intended", do you mean that pruning the weight from [768, 768] to [512, 768] is expected?
Yes, that's how head pruning works. Originally, you had 12 heads (12 * 64 = 768). After pruning, you have 8 heads (8 * 64 = 512). Here, the head size is 64.
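You can see the same shape change with the Hugging Face transformers library's built-in `prune_heads` method. A minimal sketch (this uses a randomly initialized 1-layer BERT for illustration, not the fastformers pruning code itself):

```python
from transformers import BertConfig, BertModel

# BERT-base attention dimensions: 12 heads * 64 head size = 768 hidden size.
config = BertConfig(hidden_size=768, num_attention_heads=12, num_hidden_layers=1)
model = BertModel(config)

attn = model.encoder.layer[0].attention
# Before pruning: key weight is [768, 768].
print(tuple(attn.self.key.weight.shape))

# Prune 4 heads in layer 0, leaving 8 heads (8 * 64 = 512 output rows).
model.prune_heads({0: [0, 1, 2, 3]})

# After pruning: key weight is [512, 768].
print(tuple(attn.self.key.weight.shape))
```

Because the query/key/value projections now have 512 output rows instead of 768, the saved checkpoint no longer matches the default BERT-base config, which is why a plain `from_pretrained` load fails unless the consumer (fastformers or onnxruntime) knows the pruned head layout.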