pytorch / FBGEMM

FB (Facebook) + GEMM (General Matrix-Matrix Multiplication) - https://code.fb.com/ml-applications/fbgemm/


[QST] Int4 Decoding: `ThreadID`, `ElementID`

jeromeku opened this issue · comments

Thanks for the great blogpost describing the step-by-step process of optimizing the GQA decoding kernel.

Having some trouble understanding Figures 15 and 16, which describe the tensor-core thread, element, and swizzled layouts.

Specifically, what do the Element-IDs refer to: what are the logical (8, 16) coordinates being mapped to (0, 1, 2, 3) in the middle layout of Figure 15? I understand the Thread-ID layout (top layout in Figure 15) as a mapping of logical elements to thread ownership (Thread 0 holds elements (0,0) and (0,1), Thread 1 holds (0,2) and (0,3), etc.).

Also, can you expand on the point about how the new layout with swizzling achieves contiguity?

Hello @jeromeku, thank you for your questions!

Specifically, what do the Element-IDs refer to: what are the logical (8, 16) coordinates being mapped to (0, 1, 2, 3) in the middle layout of Figure 15? I understand the Thread-ID layout (top layout in Figure 15) as a mapping of logical elements to thread ownership (Thread 0 holds elements (0,0) and (0,1), Thread 1 holds (0,2) and (0,3), etc.).

You can think of a tensor core fragment as a logical storage unit. Physically, the data in the fragment is stored in thread-local buffers, which are distributed among all 32 threads in a warp. For the P fragment (Figure 15), there are 8 * 16 = 128 elements in the fragment, so each thread stores 128 / 32 = 4 elements in its buffer. Let's call this buffer frag_p.x; in this case, frag_p.x has length 4, and you can access it with frag_p.x[i] where 0 <= i < 4. The element-ID shown in the figure is exactly this index i for each thread.

So, let's take the first cell (row 0, column 0) in the figure as an example. The first cell is stored in frag_p.x[0] of Thread 0, while the last cell (row 7, column 15) is stored in frag_p.x[3] of Thread 31.
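To make this concrete, here is a small host-side C++ sketch (not code from the kernel) of one plausible way the logical (row, col) coordinates of the 8x16 P fragment could map to a (thread-ID, element-ID) pair. The formula is an assumption chosen only so that it reproduces the examples above (Thread 0 holding (0,0)/(0,1), Thread 1 holding (0,2)/(0,3), and (7,15) landing in frag_p.x[3] of Thread 31); the authoritative mapping is the one drawn in the figure, so treat this purely as an illustration of the thread-local-buffer idea.

```cpp
#include <cstdio>

// Illustrative mapping for an 8x16 tensor-core fragment distributed over a
// 32-thread warp, 4 elements per thread.
// Assumption (not the official layout): each thread owns a pair of adjacent
// columns, rows 0-3 fill element slots 0-1 and rows 4-7 fill slots 2-3.
struct Owner {
    int thread_id;   // 0..31 within the warp
    int element_id;  // index i into the thread-local buffer frag_p.x[i], 0..3
};

Owner owner_of(int row, int col) {
    Owner o;
    o.thread_id  = (row % 4) * 8 + col / 2;  // 8 threads cover one 16-wide row
    o.element_id = (row / 4) * 2 + col % 2;  // rows 4-7 land in slots 2 and 3
    return o;
}

int main() {
    // The two cells discussed above:
    Owner first = owner_of(0, 0);   // expected: Thread 0,  frag_p.x[0]
    Owner last  = owner_of(7, 15);  // expected: Thread 31, frag_p.x[3]
    std::printf("(0,0)  -> thread %d, frag_p.x[%d]\n", first.thread_id, first.element_id);
    std::printf("(7,15) -> thread %d, frag_p.x[%d]\n", last.thread_id, last.element_id);
    return 0;
}
```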

Also, can you expand on the point about how the new layout with swizzling achieves contiguity?

Actually, we made the data layout contiguous to reduce shared memory bank conflicts in Optimization 9. We basically transposed the shared memory layout (see the figure below). In the old layout, the stride of dim 0 of each fragment is 8192, but in the new layout it is 8, which is the size of dim 1. With this new layout, we are able to increase the number of concurrent shared memory bank accesses from 4 to 16.

(Figure: frag_c_overview — shared memory layout of the fragments before and after the transposition.)
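To see why the stride matters, here is a rough host-side sketch showing how the dim-0 stride determines which shared memory banks the rows of a fragment fall into. The 4-byte element size and the simple one-column access below are assumptions made only for illustration; the kernel's actual data type and access pattern differ, so this toy example will not reproduce the exact 4-to-16 numbers quoted above. It only demonstrates that a stride of 8192 keeps all rows of a fragment on the same bank, while the transposed stride of 8 spreads them across several banks that can be accessed concurrently.

```cpp
#include <cstdio>
#include <set>

// Shared memory has 32 banks, each 4 bytes wide.
int bank_of(long byte_offset) { return static_cast<int>((byte_offset / 4) % 32); }

int main() {
    const int  elem_bytes = 4;     // assumption: 4-byte elements, for illustration only
    const long old_stride = 8192;  // dim-0 stride of a fragment in the old layout
    const long new_stride = 8;     // dim-0 stride in the new layout = size of dim 1

    std::set<int> old_banks, new_banks;
    for (int row = 0; row < 8; ++row) {
        // Byte offset of element (row, 0) of one fragment under each layout.
        old_banks.insert(bank_of(row * old_stride * elem_bytes));
        new_banks.insert(bank_of(row * new_stride * elem_bytes));
    }
    // Old layout: rows are 8192 elements apart, so they all hit the same bank.
    // New layout: rows are only 8 elements apart, so they spread over several banks.
    std::printf("distinct banks touched, old stride: %zu\n", old_banks.size());
    std::printf("distinct banks touched, new stride: %zu\n", new_banks.size());
    return 0;
}
```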

I hope this helps.