kuprel / min-dalle

min(DALL·E) is a fast, minimal port of DALL·E Mini to PyTorch

Parallel forward

neverix opened this issue

The model's decoder currently supports only sequential decoding, because of the way attn_state is implemented. A parallel forward pass over all tokens could be implemented by setting attn_state to None and handling all cases inside the generation code.
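To make the difference concrete, here is a minimal sketch in plain Python (not min-dalle's PyTorch code) of a toy single-head causal self-attention layer. The function names (`forward`, `parallel_forward`, `cached_step`), the dict-shaped attn_state with "keys"/"values" lists, and the identity Q/K/V projections are hypothetical simplifications. With attn_state=None the layer scores every position in one call and masks future positions with -inf. With a cache it handles one token per call, the way sequential decoding does. Both paths should give the same outputs.

```python
import math
from typing import Dict, List, Optional

Vector = List[float]
Cache = Dict[str, List[Vector]]  # hypothetical attn_state layout: {"keys": [...], "values": [...]}


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def softmax_weighted_sum(scores: List[float], values: List[Vector]) -> Vector:
    """Softmax over scores, then the weighted sum of values."""
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]  # exp(-inf) == 0.0, so masked slots vanish
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(values[0]))]


def parallel_forward(xs: List[Vector]) -> List[Vector]:
    """All positions in one call; future positions are masked with -inf."""
    scale = 1.0 / math.sqrt(len(xs[0]))
    outputs = []
    for i, q in enumerate(xs):
        scores = [dot(q, k) * scale if j <= i else -math.inf for j, k in enumerate(xs)]
        outputs.append(softmax_weighted_sum(scores, xs))
    return outputs


def cached_step(x: Vector, attn_state: Cache) -> Vector:
    """One new token: append its key/value to the cache, then attend over the cache."""
    attn_state["keys"].append(x)    # identity projections: key == value == x
    attn_state["values"].append(x)
    scale = 1.0 / math.sqrt(len(x))
    scores = [dot(x, k) * scale for k in attn_state["keys"]]
    return softmax_weighted_sum(scores, attn_state["values"])


def forward(xs: List[Vector], attn_state: Optional[Cache] = None) -> List[Vector]:
    if attn_state is None:
        return parallel_forward(xs)                  # parallel: whole sequence, causal mask
    return [cached_step(x, attn_state) for x in xs]  # sequential: one token at a time


tokens = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0], [0.3, -0.2]]
parallel = forward(tokens)
state: Cache = {"keys": [], "values": []}
sequential = [forward([t], state)[0] for t in tokens]
print(all(math.isclose(a, b) for p, s in zip(parallel, sequential) for a, b in zip(p, s)))  # True
```

In PyTorch the masked path would normally be a single batched matmul with a triangular mask rather than a Python loop; the loop here only shows which scores are kept.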

This would help solve #58.

I'm not sure what you mean. Are you suggesting a parallel forward pass over the 256 image tokens? That wouldn't work, because each token depends on the previous token. And if you meant parallel over the layers, that wouldn't work either, since each layer depends on the previous layer's output. Maybe you meant a parallel backward pass?

Right now the code can't simply run a forward pass over all tokens at once because of how caching is implemented. It has to step through every token one at a time instead of just masking the attention.

Oh, I see. It would be for when you want to do a forward pass over all tokens at once instead of sampling them one after the other.
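A toy way to see the distinction: sampling has to be sequential because each output becomes the next input, but once a token sequence is known (for example when scoring or training on it), the prediction at every position can be computed in a single pass with no feedback loop. The sketch below is not a transformer; `next_token`, `sample`, `predict_all`, and `VOCAB` are hypothetical, and a prefix sum stands in for a causally masked decoder.

```python
from itertools import accumulate
from typing import List

VOCAB = 16  # hypothetical vocabulary size for this toy model


def next_token(prefix: List[int]) -> int:
    """Toy stand-in for a causal decoder: the prediction after a prefix
    depends on every token in it (here, the prefix sum modulo VOCAB)."""
    return sum(prefix) % VOCAB


def sample(first: int, length: int) -> List[int]:
    """Generation: each new token is fed back in, so the steps must run in order."""
    tokens = [first]
    while len(tokens) < length:
        tokens.append(next_token(tokens))
    return tokens


def predict_all(tokens: List[int]) -> List[int]:
    """Forward over a known sequence: predictions for every prefix in one pass,
    with no model output fed back in. Entry i is the prediction after tokens[:i + 1]."""
    return [s % VOCAB for s in accumulate(tokens)]


tokens = sample(3, 6)
print(tokens)                                  # [3, 3, 6, 12, 8, 0]
print(predict_all(tokens)[:-1] == tokens[1:])  # True
```

With a real decoder, the single pass would be one masked forward over the whole sequence rather than a prefix sum; the point here is only the dependency structure.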

#80 solves this.