Testing NNlib / Lux / Flux
gdalle opened this issue
Lower-hanging fruit: NNlib.jl, because there are fewer weird structs, mostly arrays.
Cross-referencing:
Slow
- Replace several calls to `grad_test` with a vector of scenarios, like in `softmax.jl` and then `scatter.jl`
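To make the "vector of scenarios" idea concrete, here is a minimal sketch, assuming DifferentiationInterface.jl (its `gradient` function and the re-exported `AutoForwardDiff` backend) with ForwardDiff.jl installed. The `scenarios` list and the `fd_gradient` reference are hypothetical illustrations; they are neither NNlib's `grad_test` nor DifferentiationInterfaceTest's own scenario type.

```julia
using Test
using DifferentiationInterface  # exports `gradient` and re-exports ADTypes backends like AutoForwardDiff
import ForwardDiff              # loading it activates DI's ForwardDiff extension

# Hypothetical scenario list: each entry pairs a scalar-valued function with an input.
scenarios = [
    (name = "sum of squares", f = x -> sum(abs2, x), x = [1.0, 2.0, 3.0]),
    (name = "log-sum-exp", f = x -> log(sum(exp, x)), x = [0.5, -1.0, 2.0]),
]

# Hypothetical reference gradient via central finite differences.
function fd_gradient(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zero(x)
        e[i] = h                                  # perturb one coordinate at a time
        g[i] = (f(x + e) - f(x - e)) / (2h)
    end
    return g
end

# One loop replaces many hand-written gradient checks.
@testset "$(s.name)" for s in scenarios
    @test isapprox(gradient(s.f, AutoForwardDiff(), s.x), fd_gradient(s.f, s.x); rtol=1e-5)
end
```

Adding a test case then means appending one entry to `scenarios` rather than writing another gradient-check call.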
Fast
- Replace `grad_test` and watch the world crumble: https://github.com/FluxML/NNlib.jl/blob/master/test/test_utils.jl
If you want to be adventurous, you can change https://github.com/LuxDL/LuxTestUtils.jl and all downstream CPU tests in Lux will be triggered (we'd just need to copy one of the Buildkite files from LuxLib to trigger the CUDA + AMDGPU tests).
Don't tempt me, Avik
On a serious note though, I had to write it to deal mostly with arrays, or at least to convert structures to arrays: https://github.com/LuxDL/LuxTestUtils.jl/blob/143a51f0d2fb4cbc75ea583c706ff5194be103d2/src/LuxTestUtils.jl#L387-L398. That could be helpful for writing your test suite. (But it is also terribly inefficient and only tests correctness; definitely don't combine `@test_gradients` with `@jet`.)
Are the tests of LuxTestUtils already worth running locally, or should we wait for the downstream CI every time?
No, the tests there practically do nothing; it all runs via the downstream CI.
But the Lux test suite doesn't take long: about 10 minutes on a nicer machine (like the Buildkite ones), while the GitHub Actions runners take longer, around 30 minutes.
If you want to test locally, set `RETESTITEMS_NWORKERS` and it will be much faster.
So the workflow is to:
- fork LuxTestUtils.jl and Lux.jl
- put my own gradient callers in LuxTestUtils.jl
- dev LuxTestUtils.jl into the test environment of Lux.jl
- test Lux.jl
right?
If you want to test locally, yes.
Any suggestions on dealing with multiple arguments? Is wrapping them in a `ComponentVector` always gonna work, or are there non-array structs in the mix?
DifferentiationInterface only accepts a single input.
I'm thinking of using `Base.splat` on a `ComponentVector`: https://docs.julialang.org/en/v1/base/base/#Base.splat
Based on how the tests are written, for multiple arguments I assume any non-array is non-differentiable (this is a testing package, so I can assume that), so those get filtered out in https://github.com/LuxDL/LuxTestUtils.jl/blob/143a51f0d2fb4cbc75ea583c706ff5194be103d2/src/LuxTestUtils.jl#L357-L383. After that there are two possibilities: 1) the backend supports multiple arguments, in which case the arguments are simply forwarded; 2) in all other cases, the arguments are packed into a `ComponentArray` and a closure unflattens it to provide the correct arguments.
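The filter-then-flatten strategy (case 2) can be sketched as follows. This is a minimal illustration, not the LuxTestUtils code: it assumes DifferentiationInterface.jl with ForwardDiff.jl, the helpers `unflatten` and `gradient_multiarg` are hypothetical, and plain `vcat`/`reshape` stand in for ComponentArrays so the example has no extra dependency.

```julia
using DifferentiationInterface  # exports `gradient` and re-exports ADTypes backends like AutoForwardDiff
import ForwardDiff              # loading it activates DI's ForwardDiff extension

# Hypothetical helper: cut a flat vector into pieces shaped like the arrays in `templates`.
function unflatten(v::AbstractVector, templates::Tuple)
    lens = collect(map(length, templates))
    stops = cumsum(lens)
    starts = stops .- lens .+ 1
    return ntuple(i -> reshape(v[starts[i]:stops[i]], size(templates[i])), length(templates))
end

# Hypothetical helper: gradient of `f(args...)` with respect to every array argument.
# Non-array arguments are treated as non-differentiable constants.
function gradient_multiarg(f, backend, args...)
    diff_idx = findall(a -> a isa AbstractArray, collect(Any, args))
    diff_args = Tuple(args[i] for i in diff_idx)
    flat = reduce(vcat, map(vec, diff_args))  # pack all differentiable inputs into one vector

    # Single-input closure, since DifferentiationInterface differentiates w.r.t. one input.
    function g(v)
        parts = unflatten(v, diff_args)
        full = collect(Any, args)
        for (j, i) in enumerate(diff_idx)
            full[i] = parts[j]                # put each unflattened array back in its slot
        end
        return f(full...)
    end

    # Unflatten the flat gradient so it matches the shapes of the original arrays.
    return unflatten(gradient(g, backend, flat), diff_args)
end
```

For example, `gradient_multiarg((x, y, s) -> s * sum(x .* y), AutoForwardDiff(), [1.0, 2.0], [3.0, 4.0], 2)` returns one gradient per array argument and ignores the integer `s`. The sketch assumes at least one array argument; with none, `reduce(vcat, ())` would throw.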
I'll see what I can do once our own testing interface stabilizes. Step one would be to replace your gradient calls, but we can actually aim to replace your entire testing macro.
Our function `DifferentiationInterfaceTest.test_differentiation` does something very similar: https://gdalle.github.io/DifferentiationInterface.jl/dev/api/#DifferentiationInterfaceTest.test_differentiation
Correct. I had planned to replace the API with something like `skip = [AutoTracker(), ...]` and `broken = [AutoReverseDiff(), ...]`. But eventually we might use DI.
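A `skip`/`broken` API along those lines could be layered on DI roughly as follows. This is a sketch under stated assumptions: `test_gradient_backends` is a hypothetical name, the backends are the ADTypes structs re-exported by DifferentiationInterface.jl, and only ForwardDiff.jl is assumed to be installed. The `@test_skip` and `@test_broken` macros come from Julia's `Test` stdlib.

```julia
using Test
using DifferentiationInterface  # re-exports ADTypes backends (AutoForwardDiff, AutoReverseDiff, AutoTracker, ...)
import ForwardDiff              # the only AD package this sketch assumes is installed

# Hypothetical helper: check `gradient(f, b, x) ≈ expected` for each backend `b`,
# honoring `skip` and `broken` lists of backend objects.
function test_gradient_backends(f, x, expected; backends, skip = [], broken = [])
    @testset "$(nameof(typeof(b)))" for b in backends
        if b in skip
            # Recorded as Broken without evaluating, so the backend need not be loaded.
            @test_skip isapprox(gradient(f, b, x), expected)
        elseif b in broken
            # Must return false or throw; a pass is reported as "Unexpected Pass".
            @test_broken isapprox(gradient(f, b, x), expected)
        else
            @test isapprox(gradient(f, b, x), expected)
        end
    end
end
```

A call such as `test_gradient_backends(x -> sum(abs2, x), [1.0, 2.0], [2.0, 4.0]; backends = [AutoForwardDiff(), AutoTracker()], skip = [AutoTracker()])` runs only the ForwardDiff check. A backend listed in `broken` is still executed, so once it starts working the "Unexpected Pass" signals that it can be removed from the list.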