gdalle / DifferentiationInterface.jl

An interface to various automatic differentiation backends in Julia.

Home page: https://gdalle.github.io/DifferentiationInterface.jl/DifferentiationInterface

Add Reactant

wsmoses opened this issue

Reactant isn't quite an AD engine in its own right, but given how it traces code, how much better performance and compatibility it can deliver, and the fact that it can be combined with an AD engine, it seems like a natural fit.

Why not, but I don't understand what the package does or how it would fit within DI, so you'll have to hold my hand on this one ^^

Yeah, so basically Reactant exports a compile utility that takes a Julia function and its arguments and compiles the function into a fast, optimized version. It does require that all data one cares about be passed in as ConcreteRArrays rather than Arrays (and structs of those, etc.). We do have a tracer utility, though, which takes an argument and automatically converts its arrays to ConcreteRArrays for you.
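To make the "structs of those" part concrete, here is a dependency-free sketch of what such a conversion pass does: walk an argument recursively, replace every numeric array with a device-array wrapper, and leave everything else untouched. `MockRArray` and `convert_arrays` are hypothetical names standing in for `ConcreteRArray` and Reactant's tracer utility; the real utility's name and coverage may differ.

```julia
# Hypothetical stand-in for Reactant.ConcreteRArray: it just wraps host data.
struct MockRArray{A<:AbstractArray}
    data::A
end

# Example user struct holding arrays (parametric, so field types can change).
struct Layer{W,B}
    weight::W
    bias::B
end

# Recursive conversion: numeric arrays get wrapped, containers are rebuilt,
# anything else (scalars, strings, functions) passes through unchanged.
convert_arrays(x) = x
convert_arrays(x::AbstractArray{<:Number}) = MockRArray(copy(x))
convert_arrays(x::Tuple) = map(convert_arrays, x)
convert_arrays(x::NamedTuple) = map(convert_arrays, x)
convert_arrays(l::Layer) = Layer(convert_arrays(l.weight), convert_arrays(l.bias))

params = (layer = Layer(ones(3, 2), zeros(2)), lr = 0.1, name = "demo")
traced = convert_arrays(params)
```

Here each container type gets its own method to keep the pass explicit; a real tracer would need to handle arbitrary user structs generically.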

This can of course compile code with an autodiff call inside, such as a call to Enzyme. See here: https://github.com/EnzymeAD/Reactant.jl/blob/292dc03593ceb1a7a1f022fd7d3289bd69b000b5/test/basic.jl#L81. Enzyme usage should end up with quite good performance, since the fancy optimizations that Reactant applies will be able to interoperate with AD.

We also now have a brief TL;DR in the README: https://github.com/EnzymeAD/Reactant.jl/tree/main

So how would you see this in the context of DI? As a variant of the Enzyme backend which Reactant-compiles the function during preparation?

And do what with it? Compile and convert everything before and after use?
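One reading of "compile during preparation, then convert before and after use" is sketched below in plain Julia. Preparation converts a sample input and compiles once; every later call converts the input in, runs the cached compiled function, and converts the result back. `MockRArray`, `mock_compile`, `prepare_sketch` and `call_prepared` are hypothetical. `mock_compile` performs no real compilation and only mimics the shape of `Reactant.compile(f, args)`.

```julia
# Hypothetical stand-in for Reactant.ConcreteRArray.
struct MockRArray{A<:AbstractArray}
    data::A
end

to_device(x::AbstractArray) = MockRArray(copy(x))
to_host(x::MockRArray) = copy(x.data)
to_host(x::Tuple) = map(to_host, x)
to_host(x) = x

# Wrap outputs so the "compiled" function returns device arrays.
wrap_device(y::AbstractArray) = MockRArray(y)
wrap_device(y::Tuple) = map(wrap_device, y)
wrap_device(y) = y

# Mimics the shape of Reactant.compile(f, args): returns a callable taking
# device arrays and returning device arrays. No actual compilation happens.
mock_compile(f, example_args::Tuple) =
    (dargs...) -> wrap_device(f(map(a -> a.data, dargs)...))

struct PreparedSketch{F,C}
    f::F
    compiled::C
end

# "Preparation": convert an example input and compile once.
prepare_sketch(f, x::AbstractArray) =
    PreparedSketch(f, mock_compile(f, (to_device(x),)))

# Each call: convert in, run the compiled function, convert out.
call_prepared(prep::PreparedSketch, x::AbstractArray) =
    to_host(prep.compiled(to_device(x)))

prep = prepare_sketch(x -> (sum(x), 2 .* x), ones(3))
val, dbl = call_prepared(prep, [1.0, 2.0, 3.0])
```

Storing `compiled` in the preparation result means the compile step runs once, while the conversions happen on every call.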

I would welcome a PR with an example!

I'm not quite sure how to make an example PR, but here's an example from our tests: https://github.com/EnzymeAD/Reactant.jl/blob/59d9304948cc7e28acdd4351db5e069d62a4f4ec/test/basic.jl#L92

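```julia
using Enzyme, Reactant, Test

# Function to differentiate: sum of cosines over all entries
function sumcos(x)
    return sum(cos.(x))
end

# Reverse-mode gradient with Enzyme, also returning the primal value
function resgrad_ip(x)
    dx = Enzyme.make_zero(x)  # shadow that accumulates the gradient in place
    res = Enzyme.autodiff(ReverseWithPrimal, sumcos, Active, Duplicated(x, dx))
    return (res, dx)
end

# Inputs must be ConcreteRArrays before compiling (API as in the linked test)
c = Reactant.ConcreteRArray(ones(3, 2))
f = Reactant.compile(resgrad_ip, (c,))  # traces and compiles the function, AD call included
orig, r = f(c)

@test orig[2] ≈ sum(cos.(ones(3, 2)))  # res[2] holds the primal value
@test r ≈ -sin.(ones(3, 2))            # d/dx cos(x) = -sin(x)
```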

Essentially, you have to convert all inputs to ConcreteRArrays, then compile your desired function (which in this case calls an autodiff tool).
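The expected values in that test follow from d/dx cos(x) = -sin(x): the primal is `sum(cos.(x))` and the gradient is `-sin.(x)`. A dependency-free way to sanity-check those expectations, without Reactant or Enzyme, is a central finite-difference gradient; `fd_gradient` is a hypothetical helper written only for this check.

```julia
sumcos(x) = sum(cos.(x))

# Central finite differences: g[i] ≈ (f(x + h*e_i) - f(x - h*e_i)) / (2h)
function fd_gradient(f, x::AbstractArray{Float64}; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

x = ones(3, 2)
g = fd_gradient(sumcos, x)
```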

Happy to hop on a call or help however else!

This would be really great to get in and would give me a lot more use cases for DI.

Also, fun fact: the Reactant backend wouldn't suffer from any potential closure issues (since it would trace closures out anyway).

I'm open to testing it, but I can't do much until it is registered.

Oh, it's been registered for a while now!

My bad, I only checked the GitHub repo and the releases aren't tagged there

I wonder how to add this to DI with as little code duplication as possible. I see several options:

  • A backend wrapper ReactantBackend(AutoSomething)
  • A function wrapper ReactantFunction(f)
  • An option in the preparation prepare_operator(f, backend, x; reactant=true)
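The first option can be sketched in plain Julia to show why it limits duplication: the wrapper holds any inner backend, and a single generic method forwards to the inner backend's implementation inside a compile step. All names here (`AbstractBackendSketch`, `FiniteDiffSketch`, `ReactantBackendSketch`, `gradient_sketch`, `compile_sketch`) are hypothetical and are not DI or ADTypes API. `compile_sketch` is a no-op standing in for Reactant compilation, and finite differences stand in for a real AD backend.

```julia
abstract type AbstractBackendSketch end

# Hypothetical inner backend: central finite differences stand in for e.g. Enzyme.
struct FiniteDiffSketch <: AbstractBackendSketch
    h::Float64
end

function gradient_sketch(f, b::FiniteDiffSketch, x::Vector{Float64})
    g = similar(x)
    for i in eachindex(x)
        e = zero(x); e[i] = b.h
        g[i] = (f(x .+ e) - f(x .- e)) / (2 * b.h)
    end
    return g
end

# Option 1: a wrapper around any inner backend.
struct ReactantBackendSketch{B<:AbstractBackendSketch} <: AbstractBackendSketch
    inner::B
end

# No-op stand-in for "trace and compile this function with Reactant".
compile_sketch(g) = g

# One generic method covers every inner backend: compile the inner call, then run it.
function gradient_sketch(f, b::ReactantBackendSketch, x::Vector{Float64})
    compiled = compile_sketch(y -> gradient_sketch(f, b.inner, y))
    return compiled(x)
end

backend = ReactantBackendSketch(FiniteDiffSketch(1e-6))
g = gradient_sketch(x -> sum(abs2, x), backend, [1.0, 2.0])
```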

The first one seems like the easiest to put in a Reactant.jl extension.