coreylowman / dfdx

Deep learning in Rust, with shape checked tensors and neural networks

`TapeGlobal` and thoughts on the tape variants

emchristiansen opened this issue

TapeGlobal

Here's another tape-tracking API to consider, as implemented in client code:

// Assumes the relevant dfdx types (Cpu, OwnedTape, NoneTape, Tape, Merge,
// Gradients, Tensor, Rank0, HasErr, Backward) are in scope, e.g. via the
// dfdx prelude.
use std::sync::Mutex;

type FloatInner = f32;
type DfdxDevice = Cpu;

// Thread-local storage for the single active tape.
thread_local! {
  static TAPE_GLOBAL: once_cell::sync::Lazy<
    Mutex<Option<OwnedTape<FloatInner, DfdxDevice>>>,
  > = once_cell::sync::Lazy::new(|| Default::default());
}

#[derive(Debug, Clone, Default)]
pub struct TapeGlobal;

impl TapeGlobal
{
  pub fn init()
  {
    TAPE_GLOBAL.with(|tape| {
      *tape
        .lock()
        .unwrap() = Some(OwnedTape::default());
    });
  }

  pub fn reset()
  {
    TAPE_GLOBAL.with(|tape| {
      *tape
        .lock()
        .unwrap() = None;
    });
  }

  pub fn get() -> OwnedTape<FloatInner, DfdxDevice>
  {
    TAPE_GLOBAL.with(|tape| {
      let mut locked_tape = tape
        .lock()
        .unwrap();
      let out = locked_tape
        .take()
        .expect("Tape must be initialized before calling get");
      out
    })
  }

  pub fn set(value: OwnedTape<FloatInner, DfdxDevice>)
  {
    TAPE_GLOBAL.with(|tape| {
      let mut locked_tape = tape
        .lock()
        .unwrap();
      assert!(
        locked_tape.is_none(),
        "Tape must be None before calling set"
      );
      *locked_tape = Some(value);
    });
  }
}

impl Merge<NoneTape> for TapeGlobal
{
  fn merge(
    self,
    _: NoneTape,
  ) -> Self
  {
    self
  }
}

impl Merge<TapeGlobal> for TapeGlobal
{
  fn merge(
    self,
    _other: Self,
  ) -> Self
  {
    self
  }
}

impl Tape<FloatInner, DfdxDevice> for TapeGlobal
{
  const OWNS_TAPE: bool = true;

  fn add_backward_op<F>(
    &mut self,
    operation: F,
  ) where
    F: 'static
      + FnOnce(
        &mut Gradients<FloatInner, DfdxDevice>,
      ) -> Result<(), <DfdxDevice as HasErr>::Err>,
  {
    let mut tape = TapeGlobal::get();
    tape.add_backward_op(operation);
    TapeGlobal::set(tape);
  }
}

pub struct TapeGlobalTensor(
  pub Tensor<Rank0, FloatInner, DfdxDevice, TapeGlobal>,
);

impl HasErr for TapeGlobalTensor
{
  type Err = <Tensor<Rank0, FloatInner, DfdxDevice, TapeGlobal> as HasErr>::Err;
}

impl Backward<FloatInner, DfdxDevice> for TapeGlobalTensor
{
  fn try_backward(self)
    -> Result<Gradients<FloatInner, DfdxDevice>, Self::Err>
  {
    let (t, _) = self
      .0
      .split_tape();
    let tape = TapeGlobal::get();
    t.put_tape(tape)
      .try_backward()
  }
}

TapeGlobal does what it says: it maintains a global, thread-local tape, thus avoiding the partitioned-gradients problem*.
Here's how you use it:

TapeGlobal::init();
let x = dev
  .tensor(1.0)
  .put_tape(TapeGlobal);
let y = x.clone() * x;
dbg!(TapeGlobalTensor(y).backward());

// If calling in a loop, you'd call TapeGlobal::reset() then
// TapeGlobal::init() each time to clear the gradients.
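For a training loop, that reset/init pattern might look like the following. This is a hypothetical sketch: num_steps, compute_loss, and the optimizer step are stand-ins, not part of the API above.

// Hypothetical training-loop sketch for the reset/init pattern.
// `compute_loss` stands in for any forward pass that returns a Rank0 loss
// tensor carrying the TapeGlobal marker.
for _step in 0..num_steps {
  TapeGlobal::reset(); // drop whatever tape the previous step left behind
  TapeGlobal::init();  // start recording onto a fresh tape

  let loss = compute_loss(&dev); // -> Tensor<Rank0, FloatInner, DfdxDevice, TapeGlobal>
  let gradients = TapeGlobalTensor(loss).backward();

  // ... apply `gradients` to the model parameters here ...
}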

Let me know if you'd like a PR with something like this.

Thoughts on the tape variants

As you know, I have very non-standard model inputs and outputs, and I've been playing with different tape-tracking APIs in the hope of finding one that's easy to use and not error-prone, specifically: OwnedTape<_, _>, Arc<Mutex<OwnedTape<_, _>>>, Arc<Mutex<Arc<Mutex<OwnedTape<_, _>>>>>, and TapeGlobal.

Here are my thoughts:

  1. OwnedTape<_, _>: Simple and avoids Arc, but suffers from the partitioned-gradients problem (see the sketch after this list), so the programmer must either (a) mentally track which tape has which gradients, or (b) do a final accumulate across all the model outputs to merge all the tapes into one.
  2. Arc<Mutex<OwnedTape<_, _>>>: Uses Arc (sorta bad) and only partially solves the partitioned-gradients problem.
  3. Arc<Mutex<Arc<Mutex<OwnedTape<_, _>>>>>: Uses nested Arc<Mutex<_>> (bad code smell) and still does not fully solve the partitioned-gradients problem, e.g. when you have dependency paths that originate from model parameters and never interact with the main tape (this can occur in a conditional-computation setting). This behavior isn't terrible, because those gradients would be zero anyway, but you will get an unwrap error if you expect them to be defined.
  4. TapeGlobal: This fully solves the partitioned-gradients problem, at the expense of maintaining an ugly global variable and losing type parameterization.
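To make the partitioned-gradients problem in item 1 concrete, here's a minimal sketch (assuming only the OwnedTape, put_tape, and backward calls already used above): two parameters traced onto separate OwnedTapes each get a backward pass that only sees its own branch of the computation.

// Minimal sketch of the partitioned-gradients problem with OwnedTape.
let dev = Cpu::default();

// Two "parameters", each traced onto its own OwnedTape.
let a = dev
  .tensor(2.0)
  .put_tape(OwnedTape::<FloatInner, DfdxDevice>::default());
let b = dev
  .tensor(3.0)
  .put_tape(OwnedTape::<FloatInner, DfdxDevice>::default());

// Untaped inputs.
let xa = dev.tensor(5.0);
let xb = dev.tensor(7.0);

// Each product is recorded only on the tape of its taped operand.
let ya = a * xa; // lives on a's tape
let yb = b * xb; // lives on b's tape

// Each backward() therefore yields a Gradients covering only one branch:
let grads_a = ya.backward(); // has d(ya)/da, knows nothing about b
let grads_b = yb.backward(); // has d(yb)/db, knows nothing about a

// The programmer must remember which Gradients holds which parameters, or
// merge everything onto one tape before calling backward().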

I'm personally still not sure which API I prefer between 1, 3, and 4 (I don't love any of them).
But I think 2 is essentially useless and rather dangerous, as you can accidentally omit gradients from your tape, and the shared state makes the tape accumulation difficult to reason about.

*Partitioned-gradients problem: When the gradients for your network are partitioned across several tapes and you, as the programmer, have to worry about which tape has which gradients.

I don't see a way to do global/shared tapes without using some form of Arc/Mutex.

FWIW, optimizer.update returns an error indicating when some tensors that were supposed to have gradients did not get them, so that was intended to catch this case.

Arc<Mutex<_>>

FYI you can define the TapeGlobal state like this:

thread_local! {
  static TAPE_GLOBAL: RefCell<Option<OwnedTape<FloatInner, DfdxDevice>>> =
    RefCell::new(None);
}

IIUC, thread_local! ensures there are no threading concurrency issues, and the lack of async ensures there are no cooperative futures concurrency issues.
So, you don't need anything like an Arc<Mutex<_>>; a RefCell<_> suffices.
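For completeness, here's a sketch of the same accessors on top of that RefCell-based static: the TapeGlobal methods from above, with borrow_mut() in place of lock().unwrap().

// Same TapeGlobal accessors as above, reworked for the RefCell-based static.
impl TapeGlobal
{
  pub fn init()
  {
    TAPE_GLOBAL.with(|tape| *tape.borrow_mut() = Some(OwnedTape::default()));
  }

  pub fn reset()
  {
    TAPE_GLOBAL.with(|tape| *tape.borrow_mut() = None);
  }

  pub fn get() -> OwnedTape<FloatInner, DfdxDevice>
  {
    TAPE_GLOBAL.with(|tape| {
      tape
        .borrow_mut()
        .take()
        .expect("Tape must be initialized before calling get")
    })
  }

  pub fn set(value: OwnedTape<FloatInner, DfdxDevice>)
  {
    TAPE_GLOBAL.with(|tape| {
      let mut borrowed_tape = tape.borrow_mut();
      assert!(
        borrowed_tape.is_none(),
        "Tape must be None before calling set"
      );
      *borrowed_tape = Some(value);
    });
  }
}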

FWIW I don't love TapeGlobal but I prefer it to the other approaches.

Missing gradients

I think relying on that optimizer.update behavior becomes problematic when you have conditional computation (as I do, of course!).
In this case you expect only a fraction of the parameters to have gradients for any given pass.

Are these variants the only way to reuse a taped tensor more than once during a forward pass? E.g., when x is used twice in:

TapeGlobal::init();
let x = dev
  .tensor(1.0)
  .put_tape(TapeGlobal);
let y = x.clone() * x;
dbg!(TapeGlobalTensor(y).backward());

What's the difference between this and retaping x (instead of cloning)? I'm trying to continue work on #437, but am suspicious that some of the retaping is incorrect (particularly lines 105 and 115 in rl-ppo-continuous.rs).