tracel-ai / burn

Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals.

Home Page: https://burn.dev


Feature Request: Partial Derivative

loganbnielsen opened this issue

Feature description

I would like to be able to take partial derivatives of a neural network's output with respect to its inputs.

In PyTorch like this: https://stackoverflow.com/a/66709533/10969548

In TensorFlow like this: https://stackoverflow.com/a/65968334/10969548

Feature motivation

This feature is useful whenever a model's objective function explicitly uses partial derivatives (e.g. differential equation solvers).

Linking an existing ticket which was closed because additional information was missing:

#121

Is the missing information about what a mixed partial derivative is? Maybe we can work with a pretty simple example:

f(x,y) = x^2 + 3y + xy

Then the cross partial would be 1 (you take the partial with respect to x or y, and then the partial w.r.t. the other variable).
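Spelled out for f(x,y) = x^2 + 3y + xy:

```math
\frac{\partial f}{\partial x} = 2x + y, \qquad
\frac{\partial f}{\partial y} = 3 + x, \qquad
\frac{\partial^2 f}{\partial y\,\partial x} = \frac{\partial^2 f}{\partial x\,\partial y} = 1
```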

Using @nathanielsimard's code from #121, the cross partial would be computed the same way as:

```rust
fn run<B: Backend>() {
    let a = ADTensor::<B, 2>::random(...);
    let b = ADTensor::<B, 2>::random(...);
    let y = some_function(&a, &b);

    let grads = y.backward();

    let grad_a = grads.wrt(&a); // d some_function / da
    let grad_b = grads.wrt(&b); // d some_function / db

    // extension of the provided code
    let grad_ab = grad_a.wrt(&b); // d^2 some_function / (db da)
    let grad_ba = grad_b.wrt(&a); // d^2 some_function / (da db)

    // grad_ab == grad_ba -- Young's Theorem:
    // https://en.wikipedia.org/wiki/Symmetry_of_second_derivatives
}
```

(I don't know if the new lines I added are legal code; I haven't done much with Burn yet. I'm presently working through the Burn Book MNIST classification example.)

I'm not sure about the details of how this is implemented efficiently in TensorFlow or PyTorch. Is this something I should research? Or how else can I be helpful?

@nathanielsimard could you point me to the docs for the `wrt` method you referenced in #121? For some reason I'm having a hard time finding it. There may have been some API changes since that post, since ADTensor doesn't appear to be a type anymore either.

@loganbnielsen They are now:

```rust
let mut grads = loss.backward();     // Compute the gradients.
let x_d = x.grad(&grads);            // Can retrieve multiple times.
let x_d = x.grad_remove(&mut grads); // Can retrieve only one time.
```
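For context, here is a minimal sketch of how those calls fit together end to end. It assumes a recent Burn release; the exact signature of `Tensor::random` and the `AutodiffBackend` trait path have changed between versions, so treat it as a sketch rather than copy-paste code.

```rust
use burn::tensor::{backend::AutodiffBackend, Distribution, Tensor};

fn first_order_grads<B: AutodiffBackend>(device: &B::Device) {
    // Mark x as requiring gradients so the autodiff graph tracks it.
    let x = Tensor::<B, 2>::random([2, 3], Distribution::Default, device).require_grad();

    // Any differentiable computation; here an elementwise square summed to a scalar.
    let loss = (x.clone() * x.clone()).sum();

    // First-order gradients only: d loss / d x.
    let mut grads = loss.backward();
    let _x_d = x.grad(&grads);            // Borrowing read; can be called repeatedly.
    let _x_d = x.grad_remove(&mut grads); // Removes the gradient from the container.
}
```

Note that this only covers first-order gradients; computing something like `grad_ab` in the earlier snippet would additionally require the backward pass itself to be differentiable, which is what this feature request is asking for.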