Question: Scaling Factor of Weights Primary vs Not
jomayeri opened this issue
In this section of code, which appears in `Linear` and other modules:
```python
if primary_weights_in_fp8:
    # Weight is already in FP8
    weight.reset_fp8_meta_scale_inv()
    weight_fp8 = weight
    weight_t_fp8 = None
elif update_fp8_weights:
    # Need to cast weights to FP8
    weight_fp8 = Float8Tensor(
        data=weight_fp8._data,
        fp8_meta=fp8_meta,
        fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
    )
```
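(For context, the first branch corresponds to modules whose parameters were created directly as `Float8Tensor`s, which in the TE PyTorch API typically happens under `fp8_model_init`. A minimal usage sketch, assuming a Transformer Engine version where `fp8_model_init` is available; sizes are arbitrary:)

```python
import transformer_engine.pytorch as te

# Default: parameters are kept in higher precision and cast to FP8
# copies when update_fp8_weights is set (the elif branch above).
linear = te.Linear(1024, 1024)

# Under fp8_model_init, parameters are created as Float8Tensors, so
# primary_weights_in_fp8 is True and the first branch runs.
with te.fp8_model_init(enabled=True):
    linear_fp8 = te.Linear(1024, 1024)
```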
Why is the meta `scale_inv` reset in the case where the weights are already stored in FP8, but not in the case where they are cast?
This is a kludge to get `Float8Tensor` working with the original FP8 implementation. Originally we stored FP8 data in UINT8 buffers and kept `scale_inv` within `FP8TensorMeta`s, so we needed to make sure that the two were always updated together or else the FP8 data would not be interpreted correctly. `Float8Tensor` keeps its own copy of `scale_inv`, so there is more flexibility in when you can update the scaling factors in `FP8TensorMeta` (e.g. we update the FP8 data in the optimizer step and update the FP8 scaling factors before the forward pass). However, some TE kernels take `FP8TensorMeta` as an argument, so `reset_fp8_meta_scale_inv` is needed to make sure that the `scale_inv` in `FP8TensorMeta` matches the one in the `Float8Tensor`.
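To make the invariant concrete, here is a minimal sketch of the bookkeeping described above. The classes are simplified stand-ins, not the real TE internals: the `Float8Tensor` holds the authoritative `scale_inv`, and the reset copies it into the `FP8TensorMeta` slot that the kernels read.

```python
import torch

class FP8TensorMetaSketch:
    """Simplified stand-in for FP8TensorMeta: per-slot scaling state."""
    def __init__(self, num_tensors: int):
        self.scale = torch.ones(num_tensors)
        self.scale_inv = torch.ones(num_tensors)

class Float8TensorSketch:
    """Simplified stand-in for Float8Tensor: FP8 payload plus its own scale_inv."""
    def __init__(self, data: torch.Tensor, scale_inv: torch.Tensor,
                 fp8_meta: FP8TensorMetaSketch, fp8_meta_index: int):
        self._data = data            # raw FP8 bytes (uint8 storage)
        self._scale_inv = scale_inv  # authoritative dequantization factor
        self._fp8_meta = fp8_meta
        self._fp8_meta_index = fp8_meta_index

    def reset_fp8_meta_scale_inv(self) -> None:
        # The optimizer step may have updated _data and _scale_inv
        # together while fp8_meta.scale_inv still holds a stale value.
        # Copy the tensor's own scale_inv into the meta slot so kernels
        # that read FP8TensorMeta dequantize the payload correctly.
        self._fp8_meta.scale_inv[self._fp8_meta_index] = self._scale_inv

meta = FP8TensorMetaSketch(num_tensors=3)
w = Float8TensorSketch(
    data=torch.zeros(16, 16, dtype=torch.uint8),
    scale_inv=torch.tensor(0.125),
    fp8_meta=meta,
    fp8_meta_index=1,
)
w.reset_fp8_meta_scale_inv()
assert meta.scale_inv[1] == 0.125  # meta now matches the tensor
```

This presumably also explains why the cast branch skips the reset: there the FP8 data is produced against the current `fp8_meta` scaling factors, so the two are in sync by construction.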