SHI-Labs / Neighborhood-Attention-Transformer

Neighborhood Attention Transformer, arxiv 2022 / CVPR 2023. Dilated Neighborhood Attention Transformer, arxiv 2022

Is it necessary to write dedicated FP16 kernels?

rayleizhu opened this issue

Thanks for your great work. It provides a very good template for building attention extensions.

I noticed that you use dedicated FP16 kernels with instructions like __hfma2, e.g.:

__global__ void natten1dqkrpb_cuda_forward_kernel_fp16(

So, if I directly reuse the kernel currently used for FP32 and dispatch it with AT_DISPATCH_FLOATING_TYPES_AND_HALF as described in https://discuss.pytorch.org/t/how-can-i-write-the-cuda-code-to-support-fp16-calculation/107181, will it be much slower?

Thank you for your interest. I would recommend taking a look at our older commits to see the previous version of the CUDA extension: it's much simpler, but also naive. Usually, the more specialized you make your kernels, the faster they can run. That's why we have the new "tiled" kernels, which use shared memory based on assumptions about the problem (i.e., dim and kernel size).

As for FP16, there are two separate things: FP16 support and FP16 utilization. Dispatching with AT_DISPATCH_FLOATING_TYPES_AND_HALF lets you support FP16, but it does nothing to utilize it: it simply compiles and calls your kernel with float, double, and half (technically more types than that, but those are basically the three). Supporting FP16 matters because if you're using AMP, it's better for your module to accept FP16 to save memory, avoid extra casting, and so on (though that isn't always feasible, because certain operations, e.g., softmax, need full precision).

FP16 utilization is a different story, and that's where half2 comes in. I may get some details wrong here (I'm not a CUDA expert), so someone might correct me, but the idea is that multiplying half scalars instead of float scalars doesn't increase throughput by much, because a scalar half operation just isn't that much cheaper, cycle-wise. It's a good way to save memory, but not really faster. With half2 (which I wish PyTorch supported, since we've had to override some things to work it in), you change the addressing and pack two half scalars into a half2, which is the same size as a float. With the special multiply and add instructions (such as __hfma2), you're operating on two values at once, which means more parallelism. As a result, you get a major speedup by using HALF the threadblocks you would need at full precision, because each thread essentially computes two cells instead of one. That's FP16 utilization. With these specialized kernels you get a speedup of up to 80-90%, whereas without them it's barely 20-30%.
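The arithmetic behind "half the threadblocks" can be sketched on the host. A half2 is two 16-bit values packed into one 32-bit word (the same size as a float), so a kernel that handles one half2 per thread needs half as many threads, and therefore half as many blocks, as a scalar kernel. The helpers below (`pack_pair`, `low_lane`, `high_lane`, `num_blocks`) are hypothetical illustrations that only move raw 16-bit patterns around and compute launch sizes; they do no FP16 math, which on the device would go through intrinsics such as `__hfma2`.

```cpp
#include <cstdint>

static_assert(sizeof(std::uint32_t) == sizeof(float),
              "a packed pair of 16-bit values occupies one float-sized word");

// Pack two raw 16-bit patterns (e.g. two FP16 values) into one 32-bit word.
constexpr std::uint32_t pack_pair(std::uint16_t lo, std::uint16_t hi) {
  return static_cast<std::uint32_t>(lo) | (static_cast<std::uint32_t>(hi) << 16);
}

constexpr std::uint16_t low_lane(std::uint32_t packed) {
  return static_cast<std::uint16_t>(packed & 0xFFFFu);
}

constexpr std::uint16_t high_lane(std::uint32_t packed) {
  return static_cast<std::uint16_t>(packed >> 16);
}

// Blocks needed when each thread produces `elems_per_thread` output cells.
constexpr std::int64_t num_blocks(std::int64_t n, std::int64_t elems_per_thread,
                                  std::int64_t threads_per_block) {
  const std::int64_t threads = (n + elems_per_thread - 1) / elems_per_thread;
  return (threads + threads_per_block - 1) / threads_per_block;
}

// Example: 2^20 output cells with 256 threads per block.
static_assert(num_blocks(1 << 20, 1, 256) == 4096, "scalar: one cell per thread");
static_assert(num_blocks(1 << 20, 2, 256) == 2048, "half2: two cells per thread");
```

For instance, 0x3C00 and 0x4000 are the FP16 bit patterns of 1.0 and 2.0; packing them yields one 32-bit word from which each lane can be recovered unchanged.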

If you look at PyTorch's source, a lot of its native kernels (the ones we could find) didn't support, and therefore didn't utilize, half2 for some reason (this was months ago; I'm not sure it's still the case). That's the reason for a lot of the rewriting, and for writing our own dispatcher, which makes sure ATen doesn't instantiate our kernels with other reduced-precision types that break compilation.

I hope this clarifies things.

Very clear.

So in version 0.11,

  1. AT_DISPATCH_FLOATING_TYPES_AND2 is for FP16 support but not FP16 utilization, right?
  2. Could you give a rough idea of the speed comparisons: v0.11-FP32 vs. v0.11-FP16, and v0.11-FP16 vs. v0.12-FP16?

AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, d_rpb.scalar_type(), "nattenqkrpb_backward_cuda", ([&] {

Exactly. 0.11 supported FP16; 0.12 utilizes it.
AT_DISPATCH_FLOATING_TYPES_AND2 basically compiles the same kernel with different data types for basic support: kHalf and kBFloat16, along with float and double.
It's defined in ATen, PyTorch's C++ tensor library.
In 0.12, we borrow the same dispatcher and write one specifically for FP16:

```cpp
#define AT_DISPATCH_HALF_TYPES(SCALARTYPE1, TYPE, NAME, ...)                 \
  [&] {                                                                      \
    const auto& the_type = TYPE;                                             \
    /* don't use TYPE again in case it is an expensive or side-effect op */ \
    at::ScalarType _st = ::detail::scalar_type(the_type);                    \
    RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st);                                 \
    switch (_st) {                                                           \
      AT_PRIVATE_CASE_TYPE(                                                  \
          NAME,                                                              \
          SCALARTYPE1,                                                       \
          decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE1>::t),          \
          __VA_ARGS__)                                                       \
      default:                                                               \
        break;                                                               \
    }                                                                        \
  }()
```

This both restricts dispatch to FP16 only and drops kBFloat16, which is harder to cast to half and half2.
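One detail worth noting in this macro: the switch has a single case plus `default: break;`, so a tensor of any other dtype is silently skipped rather than rejected, and the caller has to route other dtypes elsewhere. A host-only sketch of that control flow follows; `ScalarType`, `Half`, and `dispatch_half_only` are hypothetical stand-ins (the dummy `Half` only holds a bit pattern in place of `at::Half`), and returning a bool is a design choice added here so the caller can tell whether anything ran.

```cpp
#include <cstdint>

// Hypothetical runtime dtype tag (stand-in for at::ScalarType).
enum class ScalarType { Float, Double, Half, BFloat16 };

// Dummy stand-in for at::Half: just the raw 16-bit pattern.
struct Half {
  std::uint16_t bits;
};

// Mirrors the control flow of AT_DISPATCH_HALF_TYPES: only the Half case runs
// the body. The bool return is an addition; the macro's `default: break;`
// gives the caller no such signal.
template <typename Fn>
bool dispatch_half_only(ScalarType st, Fn&& fn) {
  switch (st) {
    case ScalarType::Half:
      fn(Half{});
      return true;
    default:
      return false;  // Float, Double, BFloat16: nothing is launched
  }
}
```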

As for speed, I'm afraid I don't have detailed measurements across different test cases, but basically, just supporting FP16 gives kernels very little throughput improvement (15% to 25% for a given dim/head combination across different resolutions).
When utilizing it, per-kernel improvements range from 70% to 90%.
But keep in mind that per-kernel improvements don't translate one-to-one into model throughput. For instance, training throughput on a Tiny variant may jump from 3700 imgs/sec to 4800 imgs/sec, which is barely 30% overall; without the so-called utilization it would be closer to 3700 to 3900.

Also keep in mind that NVIDIA libraries use even more advanced FP16 utilization. For instance, GEMM works out differently in FP16 than in FP32 (SGEMM) and yields at least a threefold improvement.

Thanks for so quick and detailed reply. It helps me a lot.

One more question: I noticed that acc_type is used in v0.11. I understand this is to reduce the risk of numerical overflow. But no such acc_type is used in the v0.12 FP16 kernels. Why? Do half/half2 operators avoid numerical overflow automatically?

That's basically the accumulator type, since each kernel typically performs more than one operation to compute its output cell. And yes, you're right: for the FP16 scalar type, the accumulator scalar ends up being float again.

using accscalar_t = at::acc_type<scalar_t, true>;

The problem with keeping it that way is that it again hurts FP16 utilization, because outputs are computed in half precision, then cast back to full precision to be accumulated, and this is repeated several times.
Additionally, there's the problem of casting half2 back into half.

That said, half/half2 operations don't prevent numerical overflow on their own; in our experience, though, it hasn't caused any issues so far.
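The accumulator trade-off can be reproduced on the host with float and double standing in for half and float (standard C++ has no portable half type, so this is an analogy, not FP16 code). A narrow accumulator risks overflow (FP16's largest finite value is 65504) and, as shown here, precision loss: once the running sum's spacing exceeds the increment, it stops growing. A float sum of ones stalls at 2^24 = 16777216; FP16, with its 11-bit significand, stalls at 2048. The function name is hypothetical.

```cpp
#include <cstdint>

// Sum `n` copies of 1.0f using accumulator type acc_t, analogous to choosing
// scalar_t vs. at::acc_type<scalar_t, true> for the running sum.
template <typename acc_t>
acc_t accumulate_ones(std::int64_t n) {
  acc_t sum = 0;
  const float x = 1.0f;  // the input element type stays float
  for (std::int64_t i = 0; i < n; ++i) {
    sum += static_cast<acc_t>(x);  // rounding happens in acc_t's precision
  }
  return sum;
}
```

With IEEE round-to-nearest arithmetic, accumulate_ones<float>(1 << 25) returns 16777216, while accumulate_ones<double>(1 << 25) returns the exact 33554432.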

Hi, I want to confirm one thing: the throughput comparisons (e.g., this figure in the README) were measured with FP32 inference, right?

  1. As far as I know, torch.bmm() supports tensor core operations (nvcuda::wmma).
  2. However, your current implementation of NA does not utilize tensor cores.
  3. Moreover, I think it is difficult for NA to support that, as key/value tokens are determined per query, so the computation can only be divided into 1*KERNEL_SIZE_SQUARE*CHAN_STRIDE GEMMs instead of the 16*16*CHAN_STRIDE GEMMs supported by tensor cores.

Am I correct?

  1. I would say that torch.bmm does not necessarily call a specific CUDA kernel (at least not in our experience); that depends on other factors such as environment, architecture, and input.
  2. No, our implementation does not utilize tensor cores.
  3. It's hard to say: first, because there are two forward kernels, and they differ in their operations; second, because it's all a matter of implementation, not method. We faced similar questions when developing the tiled version of NA, and it all came down to implementation. There will always be more efficient architecture-specific kernels that improve throughput, and likewise better algorithms. We don't have plans to explore them at this time, but we do hope to eventually see contributions from the community.
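To illustrate point 3 of the question (keys are chosen per query), here is a naive host-side sketch of 1D neighborhood attention logits: each query i takes a dot product with a window of K keys whose start is clamped so the window stays inside the sequence, so the work is a set of independent 1×K rows rather than one dense tile. It assumes odd K, length >= K, no dilation, no relative positional bias, and a row-major [length, dim] layout; the function names are hypothetical and this is not NATTEN's kernel.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Start of the K-wide neighborhood for query i in a sequence of `length` tokens.
// Near the borders the window is shifted inward, so every query keeps K keys.
int neighborhood_start(int i, int length, int kernel_size) {
  return std::min(std::max(i - kernel_size / 2, 0), length - kernel_size);
}

// Naive QK for one head: q and k are [length, dim] row-major;
// the result is [length, kernel_size] row-major.
std::vector<float> na1d_qk(const std::vector<float>& q, const std::vector<float>& k,
                           int length, int dim, int kernel_size) {
  std::vector<float> attn(static_cast<std::size_t>(length) * kernel_size, 0.0f);
  for (int i = 0; i < length; ++i) {
    const int start = neighborhood_start(i, length, kernel_size);  // per-query keys
    for (int j = 0; j < kernel_size; ++j) {
      float dot = 0.0f;
      for (int c = 0; c < dim; ++c) {
        dot += q[static_cast<std::size_t>(i) * dim + c] *
               k[static_cast<std::size_t>(start + j) * dim + c];
      }
      attn[static_cast<std::size_t>(i) * kernel_size + j] = dot;
    }
  }
  return attn;
}
```

Because the key window moves with each query, consecutive output rows share only part of their keys, which is what makes mapping this onto fixed 16×16 tensor-core tiles non-trivial.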

Closing this due to inactivity. If you still have questions feel free to open it back up.