`TapeGlobal` and thoughts on the tape variants
emchristiansen opened this issue
TapeGlobal
Here's another tape-tracking API to consider, as implemented in client code:
use std::sync::Mutex;

// Imports are a best guess; around dfdx 0.11 the tape types are
// re-exported from the prelude.
use dfdx::prelude::*;

type FloatInner = f32;
type DfdxDevice = Cpu;

thread_local! {
    static TAPE_GLOBAL: once_cell::sync::Lazy<
        Mutex<Option<OwnedTape<FloatInner, DfdxDevice>>>,
    > = once_cell::sync::Lazy::new(|| Default::default());
}

#[derive(Debug, Clone, Default)]
pub struct TapeGlobal;

impl TapeGlobal {
    /// Install a fresh, empty tape for the current thread.
    pub fn init() {
        TAPE_GLOBAL.with(|tape| {
            *tape.lock().unwrap() = Some(OwnedTape::default());
        });
    }

    /// Drop the current thread's tape, discarding any recorded ops.
    pub fn reset() {
        TAPE_GLOBAL.with(|tape| {
            *tape.lock().unwrap() = None;
        });
    }

    /// Take the tape out of the global slot, leaving None behind.
    pub fn get() -> OwnedTape<FloatInner, DfdxDevice> {
        TAPE_GLOBAL.with(|tape| {
            tape.lock()
                .unwrap()
                .take()
                .expect("Tape must be initialized before calling get")
        })
    }

    /// Put a tape back into the global slot, which must currently be empty.
    pub fn set(value: OwnedTape<FloatInner, DfdxDevice>) {
        TAPE_GLOBAL.with(|tape| {
            let mut locked_tape = tape.lock().unwrap();
            assert!(
                locked_tape.is_none(),
                "Tape must be None before calling set"
            );
            *locked_tape = Some(value);
        });
    }
}

impl Merge<NoneTape> for TapeGlobal {
    fn merge(self, _: NoneTape) -> Self {
        self
    }
}

impl Merge<TapeGlobal> for TapeGlobal {
    fn merge(self, _other: Self) -> Self {
        self
    }
}

impl Tape<FloatInner, DfdxDevice> for TapeGlobal {
    const OWNS_TAPE: bool = true;

    fn add_backward_op<F>(&mut self, operation: F)
    where
        F: 'static
            + FnOnce(
                &mut Gradients<FloatInner, DfdxDevice>,
            ) -> Result<(), <DfdxDevice as HasErr>::Err>,
    {
        // Every backward op is routed through the single global tape.
        let mut tape = TapeGlobal::get();
        tape.add_backward_op(operation);
        TapeGlobal::set(tape);
    }
}

pub struct TapeGlobalTensor(
    pub Tensor<Rank0, FloatInner, DfdxDevice, TapeGlobal>,
);

impl HasErr for TapeGlobalTensor {
    type Err = <Tensor<Rank0, FloatInner, DfdxDevice, TapeGlobal> as HasErr>::Err;
}

impl Backward<FloatInner, DfdxDevice> for TapeGlobalTensor {
    fn try_backward(
        self,
    ) -> Result<Gradients<FloatInner, DfdxDevice>, Self::Err> {
        // Swap the accumulated global tape back in before running backward.
        let (t, _) = self.0.split_tape();
        let tape = TapeGlobal::get();
        t.put_tape(tape).try_backward()
    }
}
TapeGlobal does what it says: it maintains a global, thread-local tape, thus avoiding the partitioned-gradients problem*.
Here's how you use it:
TapeGlobal::init();
let x = dev.tensor(1.0).put_tape(TapeGlobal);
let y = x.clone() * x;
dbg!(TapeGlobalTensor(y).backward());
// If calling in a loop, you'd call TapeGlobal::reset() then
// TapeGlobal::init() each time to clear the gradients.
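For example, a loop might look like this sketch (num_steps and the parameter update are placeholders):

for _ in 0..num_steps {
    // Clear out the previous iteration's gradients, then start fresh.
    TapeGlobal::reset();
    TapeGlobal::init();
    let x = dev.tensor(1.0).put_tape(TapeGlobal);
    let y = x.clone() * x;
    let grads = TapeGlobalTensor(y).backward();
    // ... apply grads to parameters here ...
}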
Let me know if you'd like a PR with something like this.
Thoughts on the tape variants
As you know, I have very non-standard model inputs and outputs, and I've been playing with different tape-tracking APIs in the hope of finding one that's easy to use and not error-prone, specifically: OwnedTape<_, _>, Arc<Mutex<OwnedTape<_, _>>>, Arc<Mutex<Arc<Mutex<OwnedTape<_, _>>>>>, and TapeGlobal.
Here are my thoughts:
1. OwnedTape<_, _>: Simple and avoids Arc, but suffers from the partitioned-gradients problem, so the programmer must either 1) mentally track which tape has which gradients, or 2) do a final accumulate across all the model outputs to merge all the tapes into one.
2. Arc<Mutex<OwnedTape<_, _>>>: Uses Arc (sorta bad) and only partially solves the partitioned-gradients problem.
3. Arc<Mutex<Arc<Mutex<OwnedTape<_, _>>>>>: Uses nested Arc<Mutex<_>> (a bad code smell) and still does not fully solve the partitioned-gradients problem, e.g. if you have dependency paths that originate from model parameters and never interact with the main tape (this can occur in a conditional-computation setting). This behavior isn't terrible, because these gradients will be zero, but you will get an unwrap error if you expect them to be defined.
4. TapeGlobal: This does fully solve the partitioned-gradients problem, at the expense of maintaining an ugly global variable and losing type parameterization.
I'm personally still not sure which API I prefer between 1, 3, and 4 (I don't love any of them).
But I think 2 is essentially useless and rather dangerous, as you can accidentally omit gradients from your tape, and the shared state makes the tape accumulation difficult to reason about.
*Partitioned-gradients problem: when the gradients for your network are partitioned across several tapes and you, as the programmer, have to worry about which tape has which gradients.
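To make the footnote concrete, here's a minimal sketch of the plain OwnedTape case from 1), along with the "final accumulate" workaround (assuming dev is a Cpu and dfdx's Merge impl between two OwnedTapes):

let x = dev.tensor(1.0f32);
// Each output gets its own tape, so the gradients end up partitioned:
let y1 = x.clone().put_tape(OwnedTape::default()) * x.clone();
let y2 = x.clone().put_tape(OwnedTape::default()) * x.clone();
// y1.backward() alone would only see the first path. The "final
// accumulate" fix merges the tapes before calling backward:
let (t1, tape1) = y1.split_tape();
let (t2, tape2) = y2.split_tape();
let grads = (t1 + t2).put_tape(tape1.merge(tape2)).backward();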
I don't see a way to do global/shared tapes without using some form of Arc/Mutex.
FWIW, optimizer.update returns an error that indicates whether some tensors that were supposed to have gradients did not, so that was intended to capture this error.
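In code, that check looks roughly like this (opt, model, and grads are placeholder names, not a specific dfdx signature):

// update's Err case reports tensors that were expected to have
// gradients but didn't, rather than silently skipping them.
match opt.update(&mut model, grads) {
    Ok(()) => {}
    Err(err) => eprintln!("params without gradients: {err:?}"),
}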
Arc<Mutex<_>>
FYI you can define the TapeGlobal state like this:
use std::cell::RefCell;

thread_local! {
    static TAPE_GLOBAL: RefCell<Option<OwnedTape<FloatInner, DfdxDevice>>> =
        RefCell::new(None);
}
IIUC, thread_local! ensures there are no threading concurrency issues, and the lack of async ensures there are no cooperative-futures concurrency issues. So you don't need anything like an Arc<Mutex<_>>; a RefCell<_> suffices.
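For example, get and set from above become something like this (same semantics, just borrow_mut instead of lock):

impl TapeGlobal {
    pub fn get() -> OwnedTape<FloatInner, DfdxDevice> {
        TAPE_GLOBAL.with(|tape| {
            tape.borrow_mut()
                .take()
                .expect("Tape must be initialized before calling get")
        })
    }

    pub fn set(value: OwnedTape<FloatInner, DfdxDevice>) {
        TAPE_GLOBAL.with(|tape| {
            let mut tape = tape.borrow_mut();
            assert!(tape.is_none(), "Tape must be None before calling set");
            *tape = Some(value);
        });
    }
}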
FWIW I don't love TapeGlobal, but I prefer it to the other approaches.
Missing gradients
I think relying on that optimizer.update behavior becomes problematic when you have conditional computation (as I do, of course!).
In this case you expect only a fraction of the parameters to have gradients for any given pass.
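A sketch of why (use_first, w1, and w2 are hypothetical Rank0 parameter tensors):

TapeGlobal::init();
// Only the branch actually taken records backward ops on the tape.
let y = if use_first {
    x.clone().put_tape(TapeGlobal) * w1.clone()
} else {
    x.clone().put_tape(TapeGlobal) * w2.clone()
};
let grads = TapeGlobalTensor(y).backward();
// grads holds an entry for exactly one of w1/w2, so an update step that
// insists every parameter has a gradient will flag the other as missing,
// even though "missing" here just means zero.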
Are these variants the only way to reuse a taped tensor more than once during a forward pass? E.g., when x is used twice in:
TapeGlobal::init();
let x = dev.tensor(1.0).put_tape(TapeGlobal);
let y = x.clone() * x;
dbg!(TapeGlobalTensor(y).backward());
What's the difference between this and retaping x (instead of cloning)? I'm trying to continue work on #437, but am suspicious that some of the retaping is incorrect (particularly lines 105 and 115 in rl-ppo-continuous.rs).
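For concreteness, here is one possible reading of the two spellings being compared, using only the split_tape/put_tape API from above (whether this matches the retaping in #437 is exactly the question):

// Cloning: both factors flow through the tape x already carries.
let y = x.clone() * x;

// One possible "retape": pull the tape off, then put it back on one copy.
let (x_plain, tape) = x.split_tape();
let y = x_plain.clone().put_tape(tape) * x_plain;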