NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.

Home Page: https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html

Output scale not being used with `te_gemm` in FP8

snarayan21 opened this issue · comments

Hey, I'm using the te_gemm function defined in the PyTorch extensions here, and I'm trying to apply a scaling factor to the output. My gemm inputs are in fp8e4m3 and the output is in bf16.

For the D_scale argument, I am passing in tensors like torch.tensor([4.0], device='cuda') but changing the value of the scaling factor has no impact on the output. Am I doing something wrong here? Or is the scaling factor only applied when the output is of a certain dtype?

Hello @snarayan21. Yes, the purpose of this argument is to be the scaling factor for the output when the operator produces FP8 output. The cuBLAS API does not otherwise have a parameter for scaling the output. The closest thing is alpha: cuBLAS computes D = alpha * A * B + beta * C. Currently te_gemm does not expose a way to specify alpha (and beta is set to 1 by the accumulate option, which performs D = A * B + D).
What is the use case you are interested in that would benefit from that scaling?
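
For a rough picture of the difference, here is a conceptual sketch in plain PyTorch (not the TransformerEngine internals; the tensor names are made up) of where D_scale and the cuBLAS alpha/beta would act:

```python
import torch

# Conceptual sketch only: plain-PyTorch stand-ins, not the actual FP8 kernels.
a_fp8 = torch.randn(16, 32)          # stand-in for the FP8-stored A operand
b_fp8 = torch.randn(32, 8)           # stand-in for the FP8-stored B operand
a_scale_inv = torch.tensor(0.5)      # dequantization scale for A
b_scale_inv = torch.tensor(0.25)     # dequantization scale for B
d_scale = torch.tensor(4.0)          # output quantization scale (FP8 output only)

# What the FP8 GEMM effectively computes in high precision:
d_hi = (a_fp8 * a_scale_inv) @ (b_fp8 * b_scale_inv)

# With a BF16 output, d_hi is simply cast down and d_scale is never applied.
# With an FP8 output, the result is quantized, and that is where d_scale enters:
#     d_fp8 = cast_to_fp8(d_hi * d_scale)
# cuBLAS itself computes D = alpha * (A @ B) + beta * C; alpha is not exposed by
# te_gemm, and accumulate=True corresponds to beta = 1 (i.e. D = A @ B + D).
```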

@ptrendx Oh I see. Since the model I have doesn't use bias terms, it would be nice to just specify alpha... but would I get the same result by modifying the A_scale_inverse passed to te_gemm? That is, could I multiply the existing A_scale_inverse by alpha to get what I want? The model has a few scaling factors that I would like to fuse into the FP8 GEMMs, if possible.

Yes, if you don't have a bias that would work. Just be careful not to overwrite the scale inverse given as an input; instead, create a new tensor to pass to the GEMM, as in the sketch below.
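
A minimal sketch of that, assuming a_scale_inv is the tensor you currently pass as A_scale_inverse (the names here are placeholders, not the exact te_gemm signature):

```python
import torch

alpha = 4.0                                        # scaling factor to fold into the GEMM
a_scale_inv = torch.tensor([0.5], device="cuda")   # placeholder for the existing A_scale_inverse

# Out-of-place multiply: creates a new tensor and leaves the original a_scale_inv
# untouched, so any other code that still reads it keeps seeing the true scale.
scaled_a_scale_inv = a_scale_inv * alpha

# Pass scaled_a_scale_inv (instead of a_scale_inv) as A_scale_inverse to te_gemm;
# since the GEMM dequantizes A with this scale, the output picks up the factor alpha.
```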