SHI-Labs / Neighborhood-Attention-Transformer

Neighborhood Attention Transformer, arXiv 2022 / CVPR 2023. Dilated Neighborhood Attention Transformer, arXiv 2022.

Does forward always use fp16?

rayleizhu opened this issue

I noticed that you add @custom_fwd(cast_inputs=torch.float16) above the forward() method. Does this mean the NATTEN operator always runs inference in fp16?

@custom_fwd(cast_inputs=torch.float16)

I think it was resolved, but just in case anyone else finds this: those decorators are there to tell PyTorch's Automatic Mixed Precision (AMP) that our custom modules support half precision. If they're removed, NATTEN modules won't run in half precision.
So to answer the question: no; it only runs in FP16 when AMP is enabled, and stays in FP32 otherwise.

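For anyone unfamiliar with those decorators, here is a minimal sketch of the pattern using the torch.cuda.amp decorators; the op itself (MyNeighborhoodOp) is a made-up placeholder, not NATTEN's actual kernel. Inside an autocast region, custom_fwd(cast_inputs=torch.float16) casts floating-point CUDA inputs to fp16 before forward runs and locally disables autocast; outside an autocast region, inputs pass through in their original dtype.

import torch
from torch.cuda.amp import custom_fwd, custom_bwd

class MyNeighborhoodOp(torch.autograd.Function):
    # Hypothetical op: illustrates the decorator pattern only.

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, query, key):
        # Under AMP, query/key arrive here already cast to fp16;
        # without AMP, they keep whatever dtype the caller used.
        ctx.save_for_backward(query, key)
        return query @ key.transpose(-2, -1)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        query, key = ctx.saved_tensors
        return grad_out @ key, grad_out.transpose(-2, -1) @ query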

Hi, I have one question. If we fix the data type as fp16 here, the values may overflow or underflow in fp16. Is that correct?

Yes, that is correct, but to my knowledge you should refrain from explicit typecasting and let AMP decide when to typecast. How typecast values are clipped is probably not up to higher-level software; it depends more on your hardware and CUDA API version.

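For context on the overflow/underflow concern (a standalone illustration, independent of NATTEN): fp16 can only represent magnitudes up to 65504, and values smaller than roughly 6e-8 flush to zero.

import torch

x = torch.tensor([60000.0, 70000.0], dtype=torch.float16)
print(x)      # tensor([60000., inf], dtype=torch.float16) -- 70000 already overflows
print(x * x)  # tensor([inf, inf], dtype=torch.float16) -- both squares overflow
print(torch.tensor(1e-8, dtype=torch.float16))  # tensor(0., dtype=torch.float16) -- underflow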

The issue is that you have to fix the data type here, either fp32 or fp16; otherwise you will encounter a runtime error. Have you encountered this problem before? BTW, I am training a huge transformer and encountered NaN losses halfway through training. Since I am using this NATTEN package, I am not sure whether the problem is caused by it. I am still trying to figure out the potential bug.

The issue is that you have to fix the data type here, either fp32 or fp16; otherwise you will encounter a runtime error. Have you encountered this problem before?
No sorry, I haven't really run into this before.

As is, if I enable mixed precision training (wrap my main loop with the AMP wrapper), it will cast to fp16; otherwise it will remain in fp32.
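Concretely, the usual AMP setup looks like the sketch below; the model, optimizer, and data are dummy placeholders rather than a NATTEN network, and only the autocast/GradScaler wiring is the point. Ops decorated with custom_fwd(cast_inputs=torch.float16) receive fp16 inputs inside the autocast block and fp32 inputs everywhere else.

import torch
import torch.nn as nn

# Placeholder model and data, just to show the AMP wiring.
model = nn.Linear(64, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()

for _ in range(10):
    inputs = torch.randn(8, 64, device="cuda")
    targets = torch.randint(0, 10, (8,), device="cuda")
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        # Inside this block, autocast-aware ops (including ones decorated
        # with custom_fwd(cast_inputs=torch.float16)) run in fp16.
        loss = loss_fn(model(inputs), targets)
    scaler.scale(loss).backward()  # scale the loss so fp16 gradients don't underflow
    scaler.step(optimizer)         # unscales gradients; skips the step on inf/NaN
    scaler.update()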

BTW, I am training a huge transformer and encountered NaN losses halfway through training. Since I am using this NATTEN package, I am not sure whether the problem is caused by it. I am still trying to figure out the potential bug.

We strongly encourage not using mixed precision in larger models.
It is not uncommon to see this happen; it's mostly a precision issue and not limited to NATTEN.
You can optionally disable mixed precision for NA/DiNA either by changing that one line (see the sketch after the snippet below), or by forcing the blocks of your code that include NA to run in full precision, regardless of whether AMP is enabled:

with torch.autocast(device_type="cuda", enabled=False):
    # Run these operations in full precision all the time,
    # casting the input back up in case an earlier autocast op left it in fp16.
    x = na2d(x.float())
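
As for "changing that one line": that refers to the cast_inputs argument of the decorator. A hypothetical sketch (not the actual NATTEN source) of an op pinned to full precision:

import torch
from torch.cuda.amp import custom_fwd, custom_bwd

class ForceFP32Op(torch.autograd.Function):
    # Hypothetical op: same decorator pattern, but pinned to fp32.

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)  # inputs are cast up to fp32 even under AMP
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.relu()

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * (x > 0)

With cast_inputs=torch.float32, AMP casts the incoming tensors up to fp32 and runs the op with autocast disabled locally, so it stays in full precision even when the surrounding model runs in fp16.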

References:

https://pytorch.org/docs/stable/amp.html