THUDM / SwissArmyTransformer

SwissArmyTransformer is a flexible and powerful library to develop your own Transformer variants.

Home Page: https://THUDM.github.io/SwissArmyTransformer


During BERT decoding, past_key_values is used to accelerate calculation. Do we have a similar implementation?

etrigger opened this issue

I did not find such a caching mechanism using past_key_values in SAT. Is it possible to add this?
Thanks.

Yes, and it's even simpler here.
You can just do `model.add_mixin('auto-regressive', CachedAutoregressiveMixin())`. You don't need to consider past_key_values when implementing the model (in most cases); this mixin and filling_sequence (the autoregressive API) will manage the cache for you.
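A minimal sketch of how this fits together, assuming the import paths of recent sat releases (older releases exposed the same modules under the SwissArmyTransformer namespace); the checkpoint name, prompt token ids, and sampling parameters are illustrative, and exact constructor arguments vary across versions, so consult the example linked below for version-matched usage:

```python
import torch
from sat.model import AutoModel
from sat.model.mixins import CachedAutoregressiveMixin
from sat.generation.autoregressive_sampling import filling_sequence
from sat.generation.sampling_strategies import BaseStrategy

# Load any sat model; 'llama-7b' is an illustrative checkpoint name.
model, args = AutoModel.from_pretrained('llama-7b')
# The mixin stores key/value caches ("mems") between decoding steps,
# playing the role of past_key_values in HuggingFace-style decoding.
model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
model.eval().cuda()

# filling_sequence fills the -1 placeholder positions autoregressively,
# reusing the cache so each step only computes the newest token.
prompt_ids = [1, 306, 4997]  # illustrative token ids
max_new_tokens = 32
seq = torch.tensor(prompt_ids + [-1] * max_new_tokens, dtype=torch.long).cuda()

with torch.no_grad():
    tokens, *_ = filling_sequence(
        model, seq,
        batch_size=1,
        # Constructor arguments for BaseStrategy differ across sat versions.
        strategy=BaseStrategy(temperature=0.9, top_k=40),
    )
print(tokens)  # the sequence with the -1 slots replaced by generated tokens
```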

For a complete example, see the LLaMA inference example in the repo.