NVIDIA / MatX

An efficient C++17 GPU numerical computing library with Python-like syntax

Home Page: https://nvidia.github.io/MatX

[FEA] Allow Real * Complex GEMMs

cliffburdick opened this issue

Currently BLAS only supports GEMMs with both inputs of the same type (real/complex). We should support a mix of real/complex inputs even if it takes a conversion internally.

Can this be done at the API level with a simple constexpr branch that checks the types and dispatches to the complex operator?

void gemm(...) {
  if constexpr (is_complex_v<Atype> && !is_complex_v<Btype>) {
    // complex * real: wrap B so both inputs are complex
    gemm(C, A, complex(B), stream);
  } else if constexpr (!is_complex_v<Atype> && is_complex_v<Btype>) {
    // real * complex: wrap A so both inputs are complex
    gemm(C, complex(A), B, stream);
  } else {
    // existing launch code
  }
}

We need to async allocate a new tensor and upconvert to the required type. For example, if it's complex * real, we need to allocate space for a complex tensor, assign the real tensor to it, then do the GEMM. I think what you're suggesting would only work if the GEMM libraries could take operators.
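The allocate, assign, then multiply sequence described here can be sketched on the host as follows. This is only an illustration under stated assumptions: `Matrix`, `gemm_same_type`, `to_complex`, and `gemm_complex_real` are hypothetical names, not MatX APIs, and a plain `std::vector` stands in for an asynchronously allocated device tensor.

```cpp
#include <cassert>
#include <complex>
#include <cstddef>
#include <vector>

// Hypothetical row-major host matrix; a stand-in for a device tensor.
template <typename T>
struct Matrix {
  std::size_t rows, cols;
  std::vector<T> data;
  Matrix(std::size_t r, std::size_t c) : rows(r), cols(c), data(r * c) {}
  T &operator()(std::size_t i, std::size_t j) { return data[i * cols + j]; }
  const T &operator()(std::size_t i, std::size_t j) const {
    return data[i * cols + j];
  }
};

// Stand-in for the existing path: both inputs share one element type.
template <typename T>
void gemm_same_type(Matrix<T> &C, const Matrix<T> &A, const Matrix<T> &B) {
  assert(A.cols == B.rows && C.rows == A.rows && C.cols == B.cols);
  for (std::size_t i = 0; i < A.rows; ++i) {
    for (std::size_t j = 0; j < B.cols; ++j) {
      T acc{};
      for (std::size_t k = 0; k < A.cols; ++k) acc += A(i, k) * B(k, j);
      C(i, j) = acc;
    }
  }
}

// Allocate a complex matrix and assign the real values into it
// (imaginary parts are zero).
template <typename T>
Matrix<std::complex<T>> to_complex(const Matrix<T> &M) {
  Matrix<std::complex<T>> out(M.rows, M.cols);
  for (std::size_t i = 0; i < M.data.size(); ++i)
    out.data[i] = std::complex<T>(M.data[i], T(0));
  return out;
}

// complex * real: upconvert B into a temporary, then reuse the same-type GEMM.
template <typename T>
void gemm_complex_real(Matrix<std::complex<T>> &C,
                       const Matrix<std::complex<T>> &A, const Matrix<T> &B) {
  Matrix<std::complex<T>> Bc = to_complex(B);
  gemm_same_type(C, A, Bc);
}
```

The cost of this approach is the extra temporary: the real input is copied into a buffer twice its size before the multiply runs.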

I think I wrote it so that gemm does this automatically if one of the inputs is an operator.

Great, so then we just need to extend that to promote any input types to the output type.
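One way to picture "promote any input to the output type" is a helper that converts only when the element types differ. The sketch below is host-only and assumes flat row-major `std::vector` buffers; `promote_to` and `gemm_promoted` are hypothetical names for illustration and do not reflect how MatX implements this.

```cpp
#include <cassert>
#include <complex>
#include <cstddef>
#include <type_traits>
#include <vector>

// Return the buffer with element type Out. If the input already has that
// type, the original is returned by reference and nothing is allocated;
// otherwise a converted copy is returned by value.
template <typename Out, typename In>
decltype(auto) promote_to(const std::vector<In> &in) {
  if constexpr (std::is_same_v<Out, In>) {
    return (in);  // parentheses make this a const reference
  } else {
    std::vector<Out> out(in.size());
    for (std::size_t i = 0; i < in.size(); ++i)
      out[i] = static_cast<Out>(in[i]);
    return out;
  }
}

// C (m x n) = A (m x k) * B (k x n), all row-major. Both inputs are promoted
// to the element type of C before the multiply.
template <typename Out, typename TA, typename TB>
void gemm_promoted(std::vector<Out> &C, const std::vector<TA> &A,
                   const std::vector<TB> &B, std::size_t m, std::size_t k,
                   std::size_t n) {
  assert(A.size() == m * k && B.size() == k * n && C.size() == m * n);
  // Binding to const auto& keeps a converted temporary alive for this scope.
  const auto &Ap = promote_to<Out>(A);
  const auto &Bp = promote_to<Out>(B);
  for (std::size_t i = 0; i < m; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      Out acc{};
      for (std::size_t p = 0; p < k; ++p) acc += Ap[i * k + p] * Bp[p * n + j];
      C[i * n + j] = acc;
    }
  }
}
```

Because the type check happens at compile time, the existing same-type case pays nothing extra; only mixed-type calls incur the conversion copy.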

@cliffburdick @luitjens I would like to give it a shot. Will try to send a PR soon :)