NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including support for 8-bit floating point (FP8) precision on Hopper and Ada GPUs, providing better performance with lower memory utilization in both training and inference.

Home Page: https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html

Question: Scaling Factor of Weights Primary vs Not

jomayeri opened this issue

In this section of code, which appears in Linear and other modules:

    if primary_weights_in_fp8:
        # Weight is already in FP8
        weight.reset_fp8_meta_scale_inv()
        weight_fp8 = weight
        weight_t_fp8 = None
    elif update_fp8_weights:
        # Need to cast weights to FP8
        weight_fp8 = Float8Tensor(
            data=weight_fp8._data,
            fp8_meta=fp8_meta,
            fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
        )

Why is meta_scale_inv reset in the case where the weights are already stored in FP8, but not in the case where they are cast?

This is a kludge to get Float8Tensor working with the original FP8 implementation. Originally we stored FP8 data in UINT8 buffers and kept scale_inv within FP8TensorMeta, so we needed to make sure the two were always updated together, or else the FP8 data would not be interpreted correctly. Float8Tensor keeps its own copy of scale_inv, so there is more flexibility in when the scaling factors in FP8TensorMeta can be updated (e.g. we update the FP8 data in the optimizer step and update the FP8 scaling factors before the forward pass). However, some TE kernels take FP8TensorMeta as an argument, so reset_fp8_meta_scale_inv is needed to make sure that the scale_inv in FP8TensorMeta matches the one stored in Float8Tensor.
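
To make the synchronization concrete, here is a minimal, self-contained sketch. This is not the actual TE implementation: ToyFP8Meta and ToyFloat8Tensor are hypothetical stand-ins for FP8TensorMeta and Float8Tensor, and the only assumption is the usual dequantization rule real_value = fp8_data * scale_inv.

    import numpy as np

    class ToyFP8Meta:
        """Stand-in for FP8TensorMeta: scale_inv lives outside the tensor."""
        def __init__(self):
            self.scale_inv = np.ones(1, dtype=np.float32)

    class ToyFloat8Tensor:
        """Stand-in for Float8Tensor: quantized data plus its own scale_inv."""
        def __init__(self, values, scale, meta, meta_index=0):
            self._data = np.round(values * scale)      # pretend-quantized payload
            self._scale_inv = np.float32(1.0 / scale)  # the tensor's own copy
            self._meta = meta
            self._meta_index = meta_index

        def reset_fp8_meta_scale_inv(self):
            # Copy the tensor's scale_inv back into the shared meta so that
            # code reading the meta decodes the payload correctly.
            self._meta.scale_inv[self._meta_index] = self._scale_inv

        def dequantize_via_meta(self):
            # Mimics a kernel that reads scale_inv from the meta, not the tensor.
            return self._data * self._meta.scale_inv[self._meta_index]

    meta = ToyFP8Meta()
    w = ToyFloat8Tensor(np.array([0.5, -1.25]), scale=16.0, meta=meta)

    print(w.dequantize_via_meta())  # [ 8. -20.]   stale meta: data misread
    w.reset_fp8_meta_scale_inv()
    print(w.dequantize_via_meta())  # [ 0.5 -1.25] meta synced: decoded correctly

Once both copies agree, code paths that read scale_inv from the meta and those that read it from the tensor decode the same bytes identically, which is what the reset call in the primary_weights_in_fp8 branch above ensures.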