gdalle / DifferentiationInterface.jl

An interface to various automatic differentiation backends in Julia.

Home Page: https://gdalle.github.io/DifferentiationInterface.jl/DifferentiationInterface

Make Enzyme dispatches compatible with closures

ChrisRackauckas opened this issue

In the Enzyme setups (https://github.com/gdalle/DifferentiationInterface.jl/blob/main/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L13), it looks like you're using the raw f. This ignores any memory associated with caches, particularly within closures. Fixing this is straightforward, though: you can just copy what SciMLSensitivity does. You wrap f in Duplicated (https://github.com/SciML/SciMLSensitivity.jl/blob/master/src/derivative_wrappers.jl#L697), where the shadow is simply a copy f_cache = Enzyme.make_zero(f). To make this safe for repeated application, you also need to call Enzyme.make_zero!(f_cache) before each reuse, so that its shadow values always start at zero.
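A minimal sketch of that pattern with Enzyme's reverse mode, assuming a recent Enzyme.jl. The closure and the helper names `make_cached` and `cached_derivative!` are hypothetical, chosen only to show a captured cache being written and read; the shadow of the closure comes from `Enzyme.make_zero` and is reset with `Enzyme.make_zero!` before every reuse.

```julia
using Enzyme

# Hypothetical closure that writes to and reads from a captured cache.
function make_cached(buf)
    return x -> begin
        buf[1] = 3.0 * x   # write an x-dependent value into the cache
        buf[1]^2           # read it back: overall f(x) = 9x^2
    end
end

f = make_cached([0.0])

# Shadow of the closure: same structure, captured arrays zeroed.
f_cache = Enzyme.make_zero(f)

# Hypothetical helper: reverse-mode derivative through a Duplicated closure.
function cached_derivative!(f, f_cache, x)
    Enzyme.make_zero!(f_cache)  # reset shadow values before each reuse
    return autodiff(Reverse, Duplicated(f, f_cache), Active, Active(x))[1][1]
end

d1 = cached_derivative!(f, f_cache, 2.0)  # exact derivative 18x is 36 at x = 2
d2 = cached_derivative!(f, f_cache, 2.0)  # reusing f_cache after the reset
```

Allocating `f_cache` once and zeroing it in place keeps repeated calls from allocating a fresh shadow each time, which is the point of the make_zero! step.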

@wsmoses said this was probably a bad idea due to performance degradation, so I'm leaving PR #341 closed for now. Are there other solutions?

Well, the alternative is incorrect results, or simply erroring whenever caches are used. I don't see how that's better.

Honestly, this is where activity information and multiple-argument support are critical.

If you have a closure (which DI currently requires), then you'll end up differentiating every variable captured by the original function. So if you have something like

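```julia
# NN: some complex neural network, captured by the closure below
DI.derivative(x -> NN() + x, AutoEnzyme(), 3.1)  # scalar input, so derivative rather than gradient
```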

you'll now be forced to AD the entire neural network instead of the one scalar. Here, an O(1) derivative becomes arbitrarily more expensive. Without the ability to handle multiple arguments or activities, DI would be forced to AD through the whole NN whenever the closure is marked active.

Frankly, I think it makes sense for DI to first figure out how it (and/or AD.jl) wants to handle multiple arguments; for now, use direct Enzyme autodiff calls, which don't have these limitations, and revisit this question later.

I'm slowly getting a clearer picture of how I can pull it off. But the initial plan was for AbstractDifferentiation to handle multiple arguments, so I want to wait for @mohamed82008's approval before diving into it within DI.

Even if DI handles multiple arguments, you'd still want to duplicate the function, because mishandling enclosed caches can give incorrect derivatives, so I don't see why this should wait. The downside is indeed that you always have to assume all caches can be differentiated, and that is a good reason to allow multiple arguments so that some of them can be marked Const. But my point is that if we want DI to actually be correct, then differentiating through enclosed variables must carry their derivative values forward.

At the very least it needs to be an option: AutoEnzyme(duplicate_function = Val(true)) by default, with Val(false) available as an optimization for anyone who wants to forcibly Const all enclosed values (at their own risk). If there are no enclosed values there's no overhead, and if they are non-constant then the default is correct, so Val(false) is purely a performance optimization and I'd leave it as a user toggle. Adding that to ADTypes would be good for SciMLSensitivity as well, since we'd implement the same thing.
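A sketch of how such a toggle could select the function annotation, assuming a recent Enzyme.jl. The `duplicate_function` keyword is only the proposal above, and `annotate_function` / `toggled_derivative` are hypothetical helpers, not existing DI, ADTypes, or Enzyme API.

```julia
using Enzyme

# Hypothetical: map the proposed Val(true)/Val(false) toggle to an Enzyme annotation.
annotate_function(f, ::Val{true}) = Duplicated(f, Enzyme.make_zero(f))  # safe default
annotate_function(f, ::Val{false}) = Const(f)  # opt-out: captured values treated as constants

# Hypothetical scalar derivative that honors the toggle.
function toggled_derivative(f, x::Float64; duplicate_function = Val(true))
    fa = annotate_function(f, duplicate_function)
    return autodiff(Reverse, fa, Active, Active(x))[1][1]
end

# Closure with an enclosed cache: computes x^2 by way of `c`.
g = let c = [0.0]
    x -> (c[1] = x; c[1]^2)
end

dg = toggled_derivative(g, 3.0)  # default Val(true); exact derivative 2x is 6 at x = 3
```

Dispatching on `Val` keeps the choice a compile-time constant, so the Const path carries no runtime branch.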

My point about support for multiple arguments and/or activity is that they could remedy the performance issue in my example.

If DI supported specifying the function as Const/Duplicated (i.e., its activity), the problem would be trivially remedied.

Alternatively, if multiple arguments were supported (perhaps with a Const input), you could pass the NN and/or the closure data through it and again avoid the issue.
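With a direct Enzyme call, the multiple-argument route looks roughly like this sketch, assuming a recent Enzyme.jl. The random matrix `W` is a stand-in for the captured network's data and `h` is a hypothetical function; the point is that `Const(W)` tells Enzyme not to compute any derivative with respect to it.

```julia
using Enzyme

# Stand-in for the neural network's parameters.
W = rand(100, 100)

# Pass the network data explicitly instead of capturing it in a closure.
h(W, x) = sum(abs2, W) + x

# Const(W): no derivative with respect to W is computed; only x is Active.
res = autodiff(Reverse, h, Active, Const(W), Active(3.1))
dW = res[1][1]  # `nothing` for the Const argument
dx = res[1][2]  # derivative with respect to x
```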

I don't disagree with that. My point, though, is that even if DI makes all of the inputs arguments, the default activity of the function would likely be Const unless the documentation showed people how to change it. I don't think that's the right default for DI, since many common Julia functions would then give wrong values. You'd basically have to say: don't pass f, the interface is Duplicated(f, make_zero(f)). That shouldn't be left to users of DI, who should expect the simple thing to be correct. If DI.gradient(f, x) is wrong and they need DI.gradient(Duplicated(f, make_zero(f)), x) to avoid dropping derivatives on enclosed caches, I would think something has gone wrong with the interface. My suggestion is to make the required assumption inside AutoEnzyme. That is still optimal when there are no caches; yes, it is effectively a safety copy made so that caching functions work out of the box, but with an option to turn it off at the user's own risk.

But also, DI shouldn't wait until multi-argument activities are supported before doing any of this. Otherwise it will mishandle user-written closures until then, which is arguably a pretty nasty bug that warrants a hotfix. Yes, this means constants enclosed in functions will slow things down a bit, because you'll differentiate more than you need to, but it also means enclosed cache variables will correctly propagate derivatives, which matters more for a high-level interface.

I didn't test this exactly, but I would think an MWE would be as simple as:

```julia
a = [1.0]         # global cache, overwritten on every call
function f(x)
    a[1] = 1.0
    a[1] += x
    a[1]^2        # f(x) == (1 + x)^2, so f'(x) == 2(1 + x)
end
```

would give an incorrect derivative with DI without this change, which to me is a red flag that needs to be fixed. When the multi-argument form arrives, we can debate whether users should have to enable the fix or whether it should be on by default, but I don't think we should wait to make this work.
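Because the exact derivative of the MWE is 2(1 + x), any backend's answer can be checked against a finite-difference reference. A minimal sketch using only Base Julia; `make_mwe` (the MWE rewritten as a closure over its cache) and `central_difference` are illustrative names, not part of DI or Enzyme.

```julia
# The MWE, rewritten as a closure over its cache `a`.
make_mwe(a) = x -> begin
    a[1] = 1.0
    a[1] += x
    a[1]^2
end

# Central finite difference: truncation error O(h^2), no AD involved.
central_difference(f, x; h = 1e-6) = (f(x + h) - f(x - h)) / (2h)

f = make_mwe([1.0])
x0 = 2.0
fd = central_difference(f, x0)
exact = 2 * (1 + x0)  # f(x) = (1 + x)^2, so f'(x) = 2(1 + x)
println("finite difference: ", fd, ", exact: ", exact)
```

A correct AD result at x0 = 2.0 should agree with both numbers to within finite-difference roundoff.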

And to be clear, I don't think Enzyme's interface should do this, but Enzyme is a much lower-level utility aimed at a different kind of user.

I tend to agree with Chris on this one. Until I add activities or multiple arguments, better safe and slow than fast and wrong.

I see what you're saying, but I still feel this is an edge case that is more likely to cause problems for users than it fixes.

In particular, the only case where this is needed is when you read from and write to a captured buffer. Except for PreallocationTools code, this is extremely rare in Julia functions one wants to differentiate (largely because such code tends to avoid mutation, especially mutation of captured variables).

However, by marking the entire closure as Duplicated, you now need Enzyme to successfully differentiate every operation in the closure, including those that have nothing to do with reading and writing a captured buffer. If the closure calls a function Enzyme does not currently handle, you'll get an error with the Duplicated function, whereas marking it Const would succeed.

To be clear, I see the arguments for both sides, but I'm wondering which trade-off is better.

Honestly, given that I doubt much code outside of PreallocationTools would be affected, I wonder if it makes sense to just add a PreallocationTools mode to DI (which may be useful in its own right).

That's not really the case, though. It's not rare; it's very common, and the documentation and tutorials of many packages explicitly recommend writing non-allocating code. Here is one of many examples:

https://docs.sciml.ai/DiffEqDocs/stable/tutorials/faster_ode_example/#Example-Accelerating-Linear-Algebra-PDE-Semi-Discretization

Such functions are designed to be fully mutating, non-allocating, and type-stable, which puts them squarely within what Enzyme handles. And if the closure is not duplicated, these functions will not error; they will silently give the wrong answer, which is not the nicest behavior.

I think you're thinking specifically of Flux, which uses functors: effectively allocating, type-unstable functional code that carries parameters around in its objects, and those parameters may not need to be differentiated. Flux is the odd one out, not everything else. I can't think of another library engineered like Flux, whereas most scientific models, PDE solvers, etc. are engineered like the example linked above, where pre-allocated buffers are either passed around or enclosed and then used to get an allocation-free runtime. In any case, I'd argue that the Flux case should opt out of duplicating the closure as a performance improvement, rather than scientific models, PDE solvers, etc. having to opt in to duplicating the function to get the right gradient on repeated applications with caches.

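```julia
using LinearAlgebra, OrdinaryDiffEq, BenchmarkTools

# N, Ax, Ay, the parameters p and the initial condition r0 are defined
# earlier in the linked tutorial.

# Pre-allocated buffers used by gm3! and overwritten on every call,
# so that the right-hand side does not allocate.
const Ayu = zeros(N, N)
const uAx = zeros(N, N)
const Du = zeros(N, N)
const Ayv = zeros(N, N)
const vAx = zeros(N, N)
const Dv = zeros(N, N)
function gm3!(dr, r, p, t)
    a, α, ubar, β, D1, D2 = p
    # Non-allocating views into the state and its time derivative
    u = @view r[:, :, 1]
    v = @view r[:, :, 2]
    du = @view dr[:, :, 1]
    dv = @view dr[:, :, 2]
    # In-place matrix products write into the cached buffers
    mul!(Ayu, Ay, u)
    mul!(uAx, u, Ax)
    mul!(Ayv, Ay, v)
    mul!(vAx, v, Ax)
    @. Du = D1 * (Ayu + uAx)
    @. Dv = D2 * (Ayv + vAx)
    @. du = Du + a * u * u / v + ubar - α * u
    @. dv = Dv + a * u * u - β * v
    return nothing
end
prob = ODEProblem(gm3!, r0, (0.0, 0.1), p)
@btime solve(prob, Tsit5());
```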
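For an in-place right-hand side in this style, the Duplicated-closure approach amounts to a vector-Jacobian product in which the enclosed buffer gets its own shadow. A minimal sketch, assuming a recent Enzyme.jl; `make_rhs` is a hypothetical, much smaller stand-in for gm3!, with one captured buffer `tmp`.

```julia
using Enzyme

# Hypothetical in-place right-hand side: `tmp` is a pre-allocated buffer
# captured by the closure and overwritten on every call.
function make_rhs(n)
    tmp = zeros(n)
    return function rhs!(du, u)
        @. tmp = 2.0 * u    # write into the enclosed cache
        @. du = tmp * tmp   # read it back: du = 4u^2 elementwise
        return nothing
    end
end

rhs! = make_rhs(2)
u = [1.0, 2.0]
du = zeros(2)

drhs! = Enzyme.make_zero(rhs!)  # shadow closure with a zeroed copy of `tmp`
ddu = ones(2)                   # seed vector for the vector-Jacobian product
u_bar = zeros(2)                # accumulates the pullback with respect to u

autodiff(Reverse, Duplicated(rhs!, drhs!), Const,
         Duplicated(du, ddu), Duplicated(u, u_bar))
# The Jacobian is diagonal with entries 8u, so u_bar should equal 8 .* u.
```

Reverse mode consumes the seed, so `ddu` has to be refilled (and `drhs!` reset with `Enzyme.make_zero!`) before another product is computed with the same buffers.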