chengchingwen / Transformers.jl

Julia Implementation of Transformer models


Adding support for checkpointing

pevnak opened this issue · comments

I am copying this here from my post on Slack so that it does not get lost.

I think it might be worth adding rudimentary support for checkpointing, along the lines of

using Zygote, Transformers

# Wrap a block so that Zygote recomputes its forward pass during the
# backward pass instead of storing the intermediate values.
struct Checkpointed{S} <: Transformers.Layers.AbstractTransformerBlock
    f::S
end

Base.show(io::IO, c::Checkpointed) = print(io, c.f)

(m::Checkpointed)(args...) = Zygote.checkpointed(m.f, args...)

and then wrap the blocks in Checkpointed as

decoder = Transformers.Layers.Chain(Transformer(map(Checkpointed, decoder.layers[1].blocks)), decoder.layers[2]) 

and while it is probably not the nicest representation, it seems to work.
The running times are approximately 50% longer, which I think is expected, since the forward pass needs to be computed twice.

I do not know if this is something that is wanted. If so, I might try to turn it into a more proper solution and improve it. Ideally, one would like an option to download a model from HF and add checkpointing to it. I think that HF has this option.

Sounds good to have!

HF handles it in the forward method of the hf-models (equiv. Layers.Transformer). I'm not sure Checkpointed as an AbstractTransformerBlock is the best place to add the checkpoint functionality. Some alternative ideas I currently have in mind:

  1. Generalize Checkpointed{S} <: LayerStruct and overload Checkpointed{<:Transformer} to add checkpointing per block (see the sketch after this list).
  2. Modify Layers.applyblocks to allow hooks and use Zygote.checkpointed as the hook function.
  3. Similar to 2. but provide a HookedTransformerBlock <: AbstractTransformerBlock.
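
For reference, a rough sketch of idea 1, generalizing the Checkpointed definition above. This assumes Layers.Transformer stores its blocks in a blocks field (as the snippet above suggests) and that each block maps a NamedTuple to a NamedTuple; any extra interface methods LayerStruct requires are omitted:

using Zygote, Transformers

struct Checkpointed{S} <: Transformers.Layers.LayerStruct
    f::S
end

# generic fallback: checkpoint the whole wrapped layer
(m::Checkpointed)(args...) = Zygote.checkpointed(m.f, args...)

# specialization for Transformer: checkpoint each block separately,
# so only one block's intermediates need to be recomputed at a time
function (m::Checkpointed{<:Transformers.Layers.Transformer})(nt::NamedTuple)
    for block in m.f.blocks
        nt = Zygote.checkpointed(block, nt)
    end
    return nt
end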

The wrapping function can be implemented with a postwalk, like Layers.set_dropout.
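
A minimal sketch of such a wrapper using Functors.fmap (a postwalk-style traversal); set_checkpoint is a hypothetical name, not an existing API:

using Functors, Transformers

# walk the model tree and wrap every transformer block in Checkpointed;
# `exclude` stops the recursion at block boundaries so whole blocks are
# handed to the wrapping function
set_checkpoint(model) = Functors.fmap(
    Checkpointed, model;
    exclude = x -> x isa Transformers.Layers.AbstractTransformerBlock)

decoder = set_checkpoint(decoder)  # usage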

I will look at your suggestions. Checkpointed as an AbstractTransformerBlock was a quick-and-dirty trick. I like the postwalk trick.

One thing you'll want to think about is stateful layers like Dropout and BatchNorm, which would not behave the same in subsequent calls. For the former, I think some mechanism to snapshot RNG state would be required, and for the latter maybe an explicit overload?

It seems the problem is that we cannot know whether a Dropout or BatchNorm is executed under a checkpointed environment?

@ToucheSir I have not thought about this. Is there still a switch to toggle train and test mode? That would effectively solve the problem.

That doesn't sound the same. One could always turn off all dropouts completely, but normally we would want the checkpoint computed with the same dropout state as the first forward call, so that the gradients with and without checkpointing are the same.

If the pullback is only called once, I believe BatchNorm and co. should actually not require any special handling. Otherwise, the approach would be to traverse the model looking for these layers, saving their current train/test status, doing the checkpointing, and then restoring the saved status.
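
Something like this rough sketch of the save/restore idea (with_saved_mode is a hypothetical helper; it relies on the mutable active field of Flux's norm layers):

using Flux

function with_saved_mode(f, model)
    norms = filter(l -> l isa Flux.BatchNorm, Flux.modules(model))
    saved = [l.active for l in norms]   # nothing = auto, true/false = forced
    foreach(Flux.testmode!, norms)      # freeze running-stat updates
    try
        return f()
    finally
        for (l, a) in zip(norms, saved)
            l.active = a                # restore the original flag
        end
    end
end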

As Peter notes, Dropout is trickier because you still need the RNG state around to create a mask. The most straightforward solution using struct Checkpointed would be to recurse through the model looking for Dropout layers and snapshot their RNG state beforehand. Then that state can be restored whenever the checkpointing runs. I haven't quite thought about how this interacts with RNGs shared between layers (as is the default), but that should be solvable.
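
A heavily simplified sketch of the RNG snapshot idea (RNGCheckpointed is a hypothetical name; this assumes a Julia version where the task-local RNG supports copy/copy!, and that the wrapped layers draw from the default RNG):

using Random, Zygote

struct RNGCheckpointed{S}
    f::S
end

function (m::RNGCheckpointed)(args...)
    snapshot = copy(Random.default_rng())          # save RNG state
    replay = function (xs...)
        Zygote.ignore() do
            copy!(Random.default_rng(), snapshot)  # rewind before every pass
        end
        m.f(xs...)
    end
    return Zygote.checkpointed(replay, args...)    # both passes see the same masks
end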

Medium-long term, we may want to consider a mechanism like https://github.com/vchuravy/ScopedValues.jl for exposing whether checkpointing is currently active in Flux itself. Then layers can query that info and change their behaviour accordingly without a wrapper.
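
For illustration, a tiny sketch of what that could look like with ScopedValues.jl (all names hypothetical):

using ScopedValues

const CHECKPOINT_ACTIVE = ScopedValue(false)   # queryable by any layer

checkpointed_region(f, xs...) = with(CHECKPOINT_ACTIVE => true) do
    f(xs...)
end

# a layer could then branch on CHECKPOINT_ACTIVE[] to e.g. replay a
# cached dropout mask instead of sampling a new one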

@ToucheSir I wonder if we could subtype the AContext in Zygote with a CheckpointedContext and overload the pullback behavior for dropout and the like?

Maybe, but generally we'd like to avoid coupling Flux to Zygote wherever possible (e.g. no custom pullbacks).

I would say we would only need to couple NNlib to Zygote, since dropout has been moved out of Flux.

Yeah, NNlib has no dep (hard or weak) on Zygote right now, and it'd be better to keep it that way. Porting Zygote.checkpointed to use the ChainRules API shouldn't be an issue; we just need to decide whether it lives in Flux or NNlib.
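
A sketch of what the ChainRules-based port might look like (checkpoint is a hypothetical name mirroring Zygote.checkpointed, not an existing Flux/NNlib API):

using ChainRulesCore

checkpoint(f, xs...) = f(xs...)

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode},
                              ::typeof(checkpoint), f, xs...)
    y = f(xs...)   # forward pass; the pullback is deliberately not kept
    function checkpoint_pullback(dy)
        # recompute the forward pass via the AD backend to rebuild it
        _, back = rrule_via_ad(config, f, xs...)
        return (NoTangent(), back(dy)...)
    end
    return y, checkpoint_pullback
end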

These are good points, @ToucheSir. I will come back to this in a two-week timeframe; I am a bit busy with academic stuff right now.