gdalle / DifferentiationInterface.jl

An interface to various automatic differentiation backends in Julia.

Home Page: https://gdalle.github.io/DifferentiationInterface.jl/DifferentiationInterface


Testing NNLib / Lux / Flux

gdalle opened this issue

Lower-hanging fruit: NNlib.jl, because it has fewer weird structs and mostly uses plain arrays.

Cross-referencing:

Slow

  • Replace several calls to grad_test with a vector of scenarios, like in softmax.jl and then scatter.jl
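A minimal sketch of the "vector of scenarios" idea, in plain Julia with only the Test stdlib. The GradScenario struct and the fd_gradient helper (central finite differences standing in for whatever AD backend is being tested) are hypothetical illustrations, not the actual NNlib or DifferentiationInterface test API.

```julia
using Test

# Hypothetical scenario type: a function, an input point, and its known gradient.
struct GradScenario{F}
    name::String
    f::F
    x::Vector{Float64}
    grad::Vector{Float64}
end

# Central finite differences, a stand-in for the AD backend under test.
function fd_gradient(f, x; h=1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zero(x)
        e[i] = h
        g[i] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    return g
end

# One vector of scenarios replaces several hand-written grad_test calls.
scenarios = [
    GradScenario("sum of squares", x -> sum(abs2, x), [1.0, 2.0, 3.0], [2.0, 4.0, 6.0]),
    GradScenario("logsumexp (softmax gradient)", x -> log(sum(exp, x)), [0.0, 0.0], [0.5, 0.5]),
]

@testset "gradients" begin
    for s in scenarios
        @testset "$(s.name)" begin
            @test isapprox(fd_gradient(s.f, s.x), s.grad; atol=1e-5)
        end
    end
end
```

Adding a case then means appending one entry to scenarios rather than writing another test call.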

Fast

If you want to be adventurous, you can change https://github.com/LuxDL/LuxTestUtils.jl and all downstream CPU tests in Lux will be triggered (we would just need to copy one of the Buildkite files from LuxLib to also trigger the CUDA and AMDGPU tests).

Don't tempt me, Avik.

On a serious note, though, I had to write it to deal mostly with arrays, or at least to convert structures to arrays (https://github.com/LuxDL/LuxTestUtils.jl/blob/143a51f0d2fb4cbc75ea583c706ff5194be103d2/src/LuxTestUtils.jl#L387-L398), so that could be helpful for writing your test suite. (But it is also terribly inefficient and only tests correctness; definitely don't combine @test_gradients with @jet.)

Are the tests of LuxTestUtils already interesting to run locally, or should we wait for the downstream CI every time?

No, the tests there do practically nothing; it is all done via the downstream CI.

But the Lux test suite doesn't take long: about 10 minutes on a nicer machine (like the Buildkite ones), though the GitHub Actions runners take longer, around 30 minutes.

If you want to test locally, set RETESTITEMS_NWORKERS (the number of parallel ReTestItems worker processes) and it will be much faster.

So the workflow is to:

  1. fork LuxTestUtils.jl and Lux.jl
  2. put my own gradient callers in LuxTestUtils.jl
  3. dev LuxTestUtils.jl into the test environment of Lux.jl
  4. test Lux.jl

right?

If you want to test locally, yes.

Any suggestions on dealing with multiple arguments? Is wrapping them in a ComponentVector always gonna work, or are there non-array structs in the mix?

DifferentiationInterface only accepts a single input.

Based on how the tests are written, for multiple arguments I assume any non-array is non-differentiable (this is a testing package, so I can assume that), so those arguments get filtered out in https://github.com/LuxDL/LuxTestUtils.jl/blob/143a51f0d2fb4cbc75ea583c706ff5194be103d2/src/LuxTestUtils.jl#L357-L383. After that there are two possibilities: 1) the backend supports multiple arguments, in which case they are simply forwarded; 2) in all other cases, the arguments are packed into a ComponentArray, and a closure unflattens it to provide the correct arguments.
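A dependency-free sketch of option 2), assuming only floating-point arrays are differentiable and everything else is passed through as a constant. LuxTestUtils uses a ComponentArray for the packing; here plain vcat and reshape stand in for it, and flatten_args and fd_gradient (a finite-difference stand-in for a single-input AD backend) are hypothetical names.

```julia
# Hypothetical helper: pack every floating-point array argument into one flat
# Vector and return a closure that rebuilds the original argument list from it.
function flatten_args(f, args...)
    diff_idx = [i for (i, a) in enumerate(args) if a isa AbstractArray{<:AbstractFloat}]
    x = isempty(diff_idx) ? Float64[] : reduce(vcat, [vec(args[i]) for i in diff_idx])
    function f_flat(v::AbstractVector)
        rebuilt = Any[args...]      # non-array arguments are kept unchanged
        offset = 0
        for i in diff_idx           # overwrite only the differentiable ones
            n = length(args[i])
            rebuilt[i] = reshape(v[offset+1:offset+n], size(args[i]))
            offset += n
        end
        return f(rebuilt...)
    end
    return f_flat, x
end

# Central finite differences, standing in for a single-input AD backend.
function fd_gradient(f, x; h=1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zero(x)
        e[i] = h
        g[i] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    return g
end

# Two array arguments plus one non-differentiable Int.
loss(W, b, scale::Int) = scale * sum(abs2, W * [1.0, 1.0] .+ b)

W = [1.0 0.0; 0.0 1.0]
b = [0.0, 0.0]
f_flat, x = flatten_args(loss, W, b, 3)
g = fd_gradient(f_flat, x)   # gradient with respect to vcat(vec(W), b)
```

The backend only ever sees the single vector x; splitting g back into per-argument gradients follows the same offsets the closure uses.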

I'll see what I can do once our own testing interface stabilizes. Step one would be to replace your gradient calls, but we can actually aim to replace your entire testing macro.


Correct. I had planned to replace the API with something like skip = [AutoTracker(), ...] and broken = [AutoReverseDiff(), ...]. But eventually we might use DI.
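A rough sketch of how skip and broken lists could map onto the Test stdlib's @test_skip and @test_broken. FakeBackend, check_gradients and fd_gradient are hypothetical: FakeBackend only imitates ADTypes objects like AutoTracker() so the example runs without dependencies, and none of this is LuxTestUtils' actual API.

```julia
using Test

# Hypothetical stand-in for an ADTypes backend object such as AutoTracker():
# a name plus a function computing the gradient of f at a vector x.
struct FakeBackend
    name::Symbol
    gradient::Function
end

# Central finite differences, used as the one "working" backend below.
function fd_gradient(f, x; h=1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zero(x)
        e[i] = h
        g[i] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    return g
end

# Hypothetical helper: backends in `skip` are never run, backends in `broken`
# are expected to fail, and every other backend must match `expected`.
function check_gradients(f, x, expected, backends; skip=(), broken=(), atol=1e-6)
    @testset "$(b.name)" for b in backends
        if b in skip
            @test_skip isapprox(b.gradient(f, x), expected; atol=atol)
        elseif b in broken
            @test_broken isapprox(b.gradient(f, x), expected; atol=atol)
        else
            @test isapprox(b.gradient(f, x), expected; atol=atol)
        end
    end
end

finite_diff = FakeBackend(:finite_diff, fd_gradient)
always_zero = FakeBackend(:always_zero, (f, x) -> zero(x))               # wrong answers
unsupported = FakeBackend(:unsupported, (f, x) -> error("not supported"))

check_gradients(x -> sum(abs2, x), [1.0, 2.0], [2.0, 4.0],
                [finite_diff, always_zero, unsupported];
                skip=[unsupported], broken=[always_zero])
```

One design point: @test_broken reports an error if a "broken" backend unexpectedly passes, so the broken list cannot silently go stale.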