EnzymeAD / Enzyme

High-performance automatic differentiation of LLVM and MLIR.

Home page: https://enzyme.mit.edu

Support for llvm.smin and llvm.smax

rmoyard opened this issue

I have some code that generates LLVM IR containing llvm.smin and llvm.smax, and I get the following error when running the Enzyme passes:

cannot handle (reverse) unknown intrinsic
llvm.smin
  %36 = tail call i64 @llvm.smin.i64(i64 %35, i64 3) #6

Are these operations supported by Enzyme? If not, what is the best way to add them?

Oh hi, someone else from Toronto!
Enzyme doesn't really support computing gradients of integers, but I assume you are trying to get gradients for some float-based variable, in which case it's possible that our analysis isn't strong enough to understand what's really going on here. You can try enzyme-loose-ta, see here: https://enzyme.mit.edu/getting_started/UsingEnzyme/#loose-type-analysis
You might also need to add support in this style:

if (looseTypeAnalysis) {

Edit: here is the code handling smax:

if (ID == Intrinsic::umax || ID == Intrinsic::smax)

You should be able to add umin and smin here in a small PR. Since you also have the LLVM IR, could you please also add a minimal test case to test/Enzyme/ReverseMode?

Hi @ZuseZ4, nice to see another Torontonian here 👍 Thanks a lot for your answer!

That's correct: I have a float-based variable. Here is the Python function, which slices an array and sums it. We (Catalyst) take its gradient, and this is where we use Enzyme.

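import catalyst
import jax.numpy as jnp

def f(x):
    # Sum every other element: x[0] + x[2]
    return jnp.sum(x[::2])

@catalyst.qjit
def grad_f(x):
    # The gradient is computed by Enzyme inside the compiled program
    return catalyst.grad(f)(x)

x = jnp.array([0.1, 0.2, 0.3, 0.4])
print(grad_f(x))  # mathematically, the gradient is [1, 0, 1, 0]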

It is lowered to MLIR and then to LLVM IR (where we add __enzyme_autodiff); we then run -O2 and finally the Enzyme passes. Below is the resulting LLVM IR of the original function, just before Enzyme runs. More details here: https://gist.github.com/rmoyard/8402c4991eed5bced0c9a23cfaf8ecf3

define internal void @f.cloned(ptr nocapture readnone %0, ptr nocapture readonly %1, i64 %2, i64 %3, i64 %4, ptr nocapture readnone %5, ptr nocapture writeonly %6, i64 %7) {
.critedge:
  ...
  %46 = tail call i32 @llvm.smax.i32(i32 %45, i32 0)
  %47 = zext nneg i32 %46 to i64
  %48 = tail call i64 @llvm.smin.i64(i64 %47, i64 3)
  %49 = getelementptr double, ptr %1, i64 %48
  %50 = load double, ptr %49, align 8
  ...
  ret void
}

I don't think the arguments of smin and smax depend on the function's arguments, so they should probably not be active variables in the gradient calculation. Shouldn't they be ignored by the adjoint generator? What do you think? It is also possible that we do not generate the correct argument flags for the function (noneed, constant, ...).
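To make the activity argument concrete, here is a minimal hand-written reverse-mode sketch in plain Python (not Enzyme, Catalyst, or JAX) of what the gradient of the strided sum has to do. It assumes a simple loop over the selected elements, since the IR above elides that part with "..."; the helper names are hypothetical, and the clamp to [0, 3] mirrors the llvm.smax/llvm.smin calls in the dump. Only the loaded doubles receive adjoints: the integer index arithmetic is recomputed in the reverse pass but never differentiated.

```python
def clamp_index(i: int, upper: int) -> int:
    """Integer-only index arithmetic, mirroring smax(i, 0) followed by smin(..., upper).

    Its inputs and output are integers, so in reverse mode it is inactive:
    no derivative is ever propagated through it.
    """
    return min(max(i, 0), upper)


def strided_sum(x: list[float], step: int = 2) -> float:
    """Forward pass: x[0] + x[step] + ..., like jnp.sum(x[::2]) for step=2."""
    n = len(x)
    total = 0.0
    for k in range((n + step - 1) // step):
        idx = clamp_index(k * step, n - 1)  # inactive integer computation
        total += x[idx]                     # active: load of a double
    return total


def strided_sum_grad(x: list[float], step: int = 2, d_out: float = 1.0) -> list[float]:
    """Reverse pass: scatter the output adjoint back to each loaded element."""
    n = len(x)
    d_x = [0.0] * n  # shadow of x; the integer indices have no shadow at all
    for k in reversed(range((n + step - 1) // step)):
        idx = clamp_index(k * step, n - 1)  # recomputed, not differentiated
        d_x[idx] += d_out                   # adjoint of `total += x[idx]`
    return d_x


x = [0.1, 0.2, 0.3, 0.4]
print(strided_sum_grad(x))  # [1.0, 0.0, 1.0, 0.0]
```

Recomputing the index in the reverse pass instead of caching it is simply the choice made for this sketch; either way, the clamped index never needs a derivative.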

Meanwhile, I will give loose type analysis a try, but I would prefer a common solution that works for all functions.

Update:
I have enabled looseTypeAnalysis and added the couple of lines for smin in the adjoint generator, but now I hit another issue with a different operation:
Enzyme: Cannot deduce adding type (cast) of %47 = zext nneg i32 %46 to i64
From the LLVM docs, it corresponds to:

The ‘zext’ instruction zero extends its operand to type ty2.
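For reference, zero extension can be modeled on plain bit patterns. The small Python sketch below (helper names are hypothetical) shows that zext keeps the unsigned value while sext keeps the signed value, so the two agree whenever the operand is non-negative. That is the situation here: the operand of "zext nneg" is the result of llvm.smax(%45, 0), and LLVM's nneg flag asserts that the operand is non-negative (the result is poison otherwise). In both cases the result is still an integer, so it should be inactive for differentiation.

```python
def to_signed(bits: int, width: int) -> int:
    """Interpret a width-bit pattern as a two's-complement signed integer."""
    sign_bit = 1 << (width - 1)
    return bits - (1 << width) if bits & sign_bit else bits


def zext(bits: int, from_width: int, to_width: int) -> int:
    """Zero extension: the new high bits are all 0, so the unsigned value is kept."""
    assert to_width > from_width
    return bits & ((1 << from_width) - 1)


def sext(bits: int, from_width: int, to_width: int) -> int:
    """Sign extension: the new high bits copy the sign bit, so the signed value is kept."""
    assert to_width > from_width
    value = to_signed(bits & ((1 << from_width) - 1), from_width)
    return value & ((1 << to_width) - 1)


# A negative i32 (-1 == 0xFFFFFFFF) extends differently to i64:
print(hex(zext(0xFFFFFFFF, 32, 64)))  # 0xffffffff
print(hex(sext(0xFFFFFFFF, 32, 64)))  # 0xffffffffffffffff

# A non-negative i32, such as the result of smax(x, 0), extends identically:
print(zext(3, 32, 64) == sext(3, 32, 64))  # True
```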

That looks more like type analysis failing to deduce that this value is inactive. Do you have an LLVM reproducer? This should be fairly simple to fix or add to type analysis.

Oh sorry, I missed the IR dump.

Yeah, you probably also want to add a similar looseTypeAnalysis case for the zext cast. Probably just extra cases for integers here, right after the float ones:

Do you mind opening a PR for your previous fix, as well as for this one?

The real solution is to teach Type Analysis how to propagate types through smin/smax, which it presumably cannot currently do.

You can see the results of type analysis here: https://fwd.gymni.ch/Mljbnu. Adding the -enzyme-print-type flag shows how the types propagate. Adding these intrinsics properly would be done here:

case Intrinsic::ctpop:

Sorry, specifically, TypeAnalysis is clearly missing smax/smin, which is stopping type propagation here.
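To illustrate why a missing rule stops propagation, here is a deliberately simplified toy model in Python. It is not Enzyme's TypeAnalysis or its TypeTree API: it uses a three-valued type (unknown, integer, float), a made-up instruction format, and hypothetical transfer rules, assuming that smin/smax mark their result and operands as integers and that zext copies a known type in either direction. Running it with and without the smin/smax rules shows the index values from the IR dump staying "unknown" when those rules are absent.

```python
from enum import Enum


class Ty(Enum):
    UNKNOWN = "unknown"
    INTEGER = "integer"
    FLOAT = "float"


# Toy program: (result, opcode, operands). Value names mirror the IR dump;
# constant operands and the load's pointer operand are left out for brevity.
PROGRAM = [
    ("%46", "smax", ["%45"]),   # smax(%45, 0)
    ("%47", "zext", ["%46"]),
    ("%48", "smin", ["%47"]),   # smin(%47, 3)
    ("%50", "load_double", []),
]


def analyze(program, integer_intrinsics):
    """Propagate types to a fixed point using a few hypothetical transfer rules."""
    types = {}

    def learn(name, ty):
        # Record `ty` for `name` only if nothing is known yet; report whether anything changed.
        if types.get(name, Ty.UNKNOWN) is Ty.UNKNOWN and ty is not Ty.UNKNOWN:
            types[name] = ty
            return True
        return False

    changed = True
    while changed:
        changed = False
        for result, op, operands in program:
            if op in integer_intrinsics:
                # smin/smax only operate on integers: result and operands are integers.
                for name in [result, *operands]:
                    changed |= learn(name, Ty.INTEGER)
            elif op == "zext":
                # Toy rule: copy a known type from operand to result and back.
                src = operands[0]
                changed |= learn(result, types.get(src, Ty.UNKNOWN))
                changed |= learn(src, types.get(result, Ty.UNKNOWN))
            elif op == "load_double":
                changed |= learn(result, Ty.FLOAT)
    return types


def show(types, names=("%45", "%46", "%47", "%48", "%50")):
    return {n: types.get(n, Ty.UNKNOWN).value for n in names}


# Without smin/smax rules, nothing is learned about the index computation.
print(show(analyze(PROGRAM, integer_intrinsics=set())))
# With a rule for smin only, zext carries the type back to %46, but %45 stays unknown.
print(show(analyze(PROGRAM, integer_intrinsics={"smin"})))
# With both rules, every index value is known to be an integer.
print(show(analyze(PROGRAM, integer_intrinsics={"smin", "smax"})))
```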

@wsmoses Thanks, I am trying to add smax and smin to the type analysis!