google / heir

A compiler for homomorphic encryption

Home Page: https://heir.dev/

Write canonicalization patterns for NTT/INTT

j2kun opened this issue

In particular, immediate ntt/intt pairs can be elided. Depends on #182

I've created a simple canonicalization pass: NTT(INTT(x)) => x and INTT(NTT(x)) => x.
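
These two patterns can be sanity-checked numerically. Below is a toy NTT/INTT pair over Z_p in Python (the parameters p = 17, n = 4, w = 4 are illustrative only, not anything HEIR actually lowers to), confirming that each transform inverts the other:

```python
# Toy forward/inverse number-theoretic transform over Z_17 with n = 4.
# w = 4 is a primitive 4th root of unity mod 17 (4^2 = 16 = -1, 4^4 = 1).
P, N, W = 17, 4, 4

def ntt(x):
    # X[k] = sum_j x[j] * w^(j*k) mod p
    return [sum(x[j] * pow(W, j * k, P) for j in range(N)) % P
            for k in range(N)]

def intt(X):
    # x[j] = n^(-1) * sum_k X[k] * w^(-j*k) mod p
    n_inv = pow(N, P - 2, P)  # inverses via Fermat's little theorem
    w_inv = pow(W, P - 2, P)
    return [n_inv * sum(X[k] * pow(w_inv, j * k, P) for k in range(N)) % P
            for j in range(N)]

x = [3, 1, 4, 1]
assert intt(ntt(x)) == x  # INTT(NTT(x)) => x
assert ntt(intt(x)) == x  # NTT(INTT(x)) => x
```

The two asserts correspond exactly to the two rewrite patterns above.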

Now, I started thinking that it would be good to have the following optimization pass as well:
INTT(Op(NTT(x1), NTT(x2))) => Op(x1, x2) where Op = polynomial.add/sub/mul_scalar (and the version with NTT and INTT switched.)

Should this be included in the canonicalization pass, or should I create a new pass for it? 😃

I think this second pattern will be a bit more complicated: Op will be a tensor op in one domain and a polynomial op in the other, and which is cheaper is context dependent (it depends on whether the neighboring ops are in the coefficient or NTT domain).

I think the right approach to figure out what to do next is:

  1. Add a pass that actually introduces ntt/intt ops by replacing polynomial.mul with ntt + tensor mul + intt.
  2. Analyze some natural programs involving polynomial multiplications to look for optimization opportunities. Ideally we can do this by lowering some high-level programs from BGV to polynomial, but that will involve a lot of extra work to actually implement BGV operations.
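
As an illustration of step 1, the rewrite polynomial.mul => ntt + tensor mul + intt can be modeled in a few lines of Python. This sketch uses multiplication mod x^n - 1 (cyclic convolution) with toy parameters p = 17, n = 4; HEIR's RLWE polynomials live mod x^n + 1, which additionally needs a negacyclic twist that is omitted here:

```python
# Toy NTT/INTT over Z_17 with n = 4; w = 4 is a primitive 4th root of
# unity mod 17. Parameters are illustrative only.
P, N, W = 17, 4, 4

def ntt(x):
    return [sum(x[j] * pow(W, j * k, P) for j in range(N)) % P
            for k in range(N)]

def intt(X):
    n_inv, w_inv = pow(N, P - 2, P), pow(W, P - 2, P)
    return [n_inv * sum(X[k] * pow(w_inv, j * k, P) for k in range(N)) % P
            for j in range(N)]

def poly_mul_naive(a, b):
    # schoolbook multiplication mod x^N - 1, coefficients mod P
    return [sum(a[i] * b[(k - i) % N] for i in range(N)) % P
            for k in range(N)]

def poly_mul_ntt(a, b):
    # the rewrite: polynomial.mul => ntt + elementwise tensor mul + intt
    return intt([u * v % P for u, v in zip(ntt(a), ntt(b))])

a, b = [1, 2, 3, 4], [5, 6, 7, 8]
assert poly_mul_ntt(a, b) == poly_mul_naive(a, b)
```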

Alternatively, comb through the HEaaN.mlir paper looking for optimizations related to this: https://dl.acm.org/doi/pdf/10.1145/3591228

I could also imagine a sort of global optimization that has a cost model for ntt/intt vs the ops in each domain, and tries to optimize the total cost jointly. If that's of interest we could try to spec out the optimization model.
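
To make the joint-optimization idea concrete, here is a minimal sketch of such a cost model for a straight-line chain of ops: pick a domain per op by dynamic programming, paying a conversion cost whenever the domain flips. All cost numbers are invented for illustration, and the final conversion back to coefficient form is ignored:

```python
# Hypothetical per-(op, domain) costs: mul is cheap pointwise in the NTT
# (evaluation) domain, while add costs the same in either domain.
COST = {
    ("add", "coeff"): 1, ("add", "ntt"): 1,
    ("mul", "coeff"): 16, ("mul", "ntt"): 1,
}
CONVERT = 4  # cost of a single ntt or intt

def min_cost(ops):
    # dp[d] = cheapest cost of the prefix with the live value in domain d
    dp = {"coeff": 0, "ntt": CONVERT}
    for op in ops:
        dp = {d: min(dp[e] + (0 if e == d else CONVERT) for e in dp)
                 + COST[(op, d)]
              for d in dp}
    return min(dp.values())

# A mul-heavy chain should pay one conversion and stay in the NTT domain.
assert min_cost(["mul", "mul", "mul"]) == CONVERT + 3 * COST[("mul", "ntt")]
```

A real pass would run this over the use-def graph rather than a linear chain, but even the toy version shows how conversion cost trades off against per-op savings.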

> I think this second pattern will be a bit more complicated, since Op will be a tensor op in one domain and a polynomial op in the other […]

Ah, you are right! This requires considering much more complex factors than I thought.

I'll tackle task 1 (making an (i)ntt-introducing pass) and then skim through the HEaaN.mlir optimizations when time allows. 👍

I filed a separate issue for HEaaN.mlir, and we can limit this issue to the basic canonicalization patterns: #635

I think loop fusion would be a good place to start. CC @inbelic since he is working on the initial ntt lowering.

Sweet. Feel free to ask me any questions on how the ntt op lowering is implemented.

There is also a (messy) initial implementation of turning poly.mul into ntt ops here: inbelic@2e0f6d5. You can patch that in to do some analysis with it.

Awesome! Thank you everyone for letting me know and I'll move on to the new issue now :)

The Syfer-MLIR paper also discusses optimizing forward and inverse NTTs: using the distributive property, applying a CSE pass, and keeping all operators/operands of an expression in the transform domain to eliminate unneeded forward and inverse NTTs.
It gets complicated to encode these constraints, but it would be well worth it in HEIR.

> Safer-MLIR paper ...

Could you link to this paper? I don't recall it.

Maybe SyFER-MLIR? They do discuss a few things around NTT, iirc, but mostly things you get "for free" with MLIR which we already do (cse, folding away ntt(intt(x)) and intt(ntt(x)), etc). They do also do things such as moving (i)NTTs through additions (i.e. ntt(a) + ntt(b) becomes ntt(a+b)), which is probably a "safe" rewrite (most reasonable backends shouldn't show a performance difference between add in coeff form and eval form).
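
That "safe" rewrite is just linearity of the transform over Z_p, easy to confirm with a toy NTT (illustrative parameters p = 17, n = 4, w = 4):

```python
P, N, W = 17, 4, 4  # toy modulus, size, and primitive 4th root of unity

def ntt(x):
    return [sum(x[j] * pow(W, j * k, P) for j in range(N)) % P
            for k in range(N)]

a, b = [3, 1, 4, 1], [2, 7, 1, 8]
# ntt(a) + ntt(b) (elementwise, mod p) equals ntt(a + b)
lhs = [(u + v) % P for u, v in zip(ntt(a), ntt(b))]
rhs = ntt([(u + v) % P for u, v in zip(a, b)])
assert lhs == rhs
```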

I'm sorry about the typo/auto-correct; it is indeed Syfer-MLIR.
@AlexanderViand-Intel - the paper moves out n invocations of FFT and IFFT; this is important and does not come for free from MLIR, AFAIK.

So the only FHE-specific optimization I see from their paper is canceling ntt(intt(x)) and intt(ntt(x)), as Alex mentioned, which we already have from #631

We should add Alex's additional suggested pattern (ntt(a) + ntt(b) = ntt(a+b)). I suggested in an earlier comment that this kind of optimization depends on what's going on around it, but for add/sub it should always be a win, since addition is equally expensive in each domain.

Opened a PR for this upstream: llvm/llvm-project#93132

Unless there are any more specific canonicalization patterns we think we can add for ntt/intt, I think we can close this issue.

I will close this issue, but note that until I get #675 working and merged (upstreaming polynomial) the new pattern won't apply.