TiledTensor / TiledCUDA

TiledCUDA is a highly efficient kernel template library designed to elevate CUDA C’s level of abstraction for processing tiles.

How to define the layout of register tile?

haruhi55 opened this issue.

Currently, we have a quick implementation for the register tile.

enum class RegLayout {
  Default = 0,
  // Refer to this slide for details on how data is distributed in each
  // thread's local registers after the TCU's WMMA operation:
  // https://developer.download.nvidia.com/video/gputechconf/gtc/2020/
  // presentations/s21745-developing-cuda-kernels-to-push-tensor-cores-to-the-absolute-limit-on-nvidia-a100.pdf
  // WMMA shape is "m16n16k16".
  WMMA_m16n16k16 = 1,
};

This approach is implementation-driven. Because the register tile feeds tensor core instructions, it has a sub-structure, and we need a more structured way to communicate its layout to other macro kernels such as global-to-register copy, register-level GEMM, and reduction.

First, the design must capture the sophisticated layout required by the tensor core and offer sufficient expressive power, while keeping implementation complexity reasonable and avoiding unnecessary concepts.

A reasonable simplification is that the register tile is a two-dimensional depth-2 nested array, declared as follows:

using RTile = RegTile<RegTile<Element, tl::RowMajor<2, 4>>, tl::RowMajor<height, width>>;
  1. At the bottom level: an atomic 2D array (a 1D array is treated as a degenerate 2D array) whose shape is chosen to match what the hardware instruction operates on.
  2. At the higher level: how many times the atomic array is repeated along each of the two dimensions.

Together, these two levels of nesting define the complete register tile held by a single thread.
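A minimal compile-time sketch of the depth-2 nesting follows. It assumes simplified stand-ins: the `RowMajor`, `NumScalars`, and `RegTile` templates below are hypothetical and are not TiledCUDA's actual `tl::RowMajor` / `RegTile` definitions, and `std::uint16_t` stands in for a half-precision `Element`. The sketch only shows how the per-thread value count falls out of the two levels.

```cpp
#include <cstdint>

// Hypothetical stand-in for tl::RowMajor<rows, cols>; it only records the shape.
template <int kRows_, int kCols_>
struct RowMajor {
    static constexpr int kRows = kRows_;
    static constexpr int kCols = kCols_;
    static constexpr int kNumel = kRows_ * kCols_;
};

// Number of scalar values one element of a tile contributes:
// 1 for a scalar, or the total size of a nested tile (specialized below).
template <typename T>
struct NumScalars {
    static constexpr int value = 1;
};

// Hypothetical RegTile: a 2D array of `Element` whose shape is given by `Layout`.
template <typename Element, typename Layout>
struct RegTile {
    static constexpr int kNumScalars = Layout::kNumel * NumScalars<Element>::value;
    Element data[Layout::kNumel];
};

// A nested RegTile contributes all of its own scalars.
template <typename Element, typename Layout>
struct NumScalars<RegTile<Element, Layout>> {
    static constexpr int value = RegTile<Element, Layout>::kNumScalars;
};

// Bottom level: the atomic array, 2 x 4 = 8 half-precision values
// (std::uint16_t is a stand-in for a half type).
using BaseTile = RegTile<std::uint16_t, RowMajor<2, 4>>;
// Top level: the atomic array repeated 2 x 2 times, e.g. for a 32 x 32 tile.
using RTile = RegTile<BaseTile, RowMajor<2, 2>>;

static_assert(BaseTile::kNumScalars == 8, "one BaseTile holds 8 values per thread");
static_assert(RTile::kNumScalars == 32, "2 x 2 BaseTiles hold 32 values per thread");
static_assert(sizeof(RTile) == 32 * sizeof(std::uint16_t), "nesting adds no padding");
```

Because the outer level only counts repetitions, changing the outer `RowMajor<2, 2>` to another `height x width` changes the register footprint without touching the atomic array.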

@KuangjuX I wrote a quick summary of our discussion.

BaseTile

When a warp-cooperative instruction processes a 16 x 16 block of data, the BaseTile is the set of elements that a single thread holds for that block. Regardless of the element type, this is 16 x 16 / 32 = 8 elements. A BaseTile is represented as a flat list, and its internal structure is not elaborated further.
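To make the 8-element BaseTile concrete, the sketch below computes which elements of a 16 x 16 block one lane holds. It is a host-side model under an assumption: the fragment order is taken to be that of the A operand of `mma.m16n8k16` with 16-bit elements (as loaded by a four-matrix `ldmatrix`). Under that order, lane `t` owns row `t / 4` and columns `2 * (t % 4)` and `2 * (t % 4) + 1` of each 8 x 8 quadrant, and the quadrants come top-left, bottom-left, top-right, bottom-right. This order matches the thread-0 values in the Examples section. The function name `base_tile_offsets` and its `ld` (row stride) parameter are hypothetical.

```cpp
#include <array>
#include <cassert>

// Elements one lane holds for a 16 x 16 block: 16 * 16 / 32 = 8, for any element type.
constexpr int kBaseTileElems = 16 * 16 / 32;
static_assert(kBaseTileElems == 8, "a BaseTile holds 8 elements per thread");

// Offsets (in elements) of the BaseTile that `lane` holds, for a 16 x 16 block
// stored row-major with row stride `ld`. Assumed order (mma.m16n8k16 A operand,
// 16-bit types): top-left, bottom-left, top-right, bottom-right 8 x 8 quadrant,
// two adjacent elements from each.
std::array<int, kBaseTileElems> base_tile_offsets(int lane, int ld) {
    assert(lane >= 0 && lane < 32 && ld >= 16);
    const int row = lane / 4;        // row inside an 8 x 8 quadrant
    const int col = (lane % 4) * 2;  // first of two adjacent columns
    std::array<int, kBaseTileElems> offsets{};
    int k = 0;
    for (int dc = 0; dc <= 8; dc += 8) {      // left quadrants, then right
        for (int dr = 0; dr <= 8; dr += 8) {  // top quadrant, then bottom
            const int base = (row + dr) * ld + col + dc;
            offsets[k++] = base;
            offsets[k++] = base + 1;
        }
    }
    return offsets;
}
```

For example, `base_tile_offsets(0, 32)` yields `0, 1, 256, 257, 8, 9, 264, 265`, and across all 32 lanes every element of the 16 x 16 block is owned by exactly one lane.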

RegTile

RegTile treats the BaseTile as its element, so the user only needs to specify one layout, either RowMajor or ColMajor, which determines how the BaseTiles are laid out in memory.
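The following sketch shows what that single layout decides. It uses hypothetical `RowMajor` / `ColMajor` structs (not TiledCUDA's `tl::` types) and a hypothetical `storage_order` helper. Given the coordinate `(i, j)` of a BaseTile inside the RegTile, the layout returns the slot where that BaseTile is stored; the BaseTile itself stays opaque.

```cpp
#include <array>
#include <cassert>
#include <utility>

// Hypothetical layouts over a kRows x kCols grid of BaseTiles. operator()
// maps the BaseTile coordinate (i, j) to its storage slot in registers.
template <int kRows_, int kCols_>
struct RowMajor {
    static constexpr int kRows = kRows_, kCols = kCols_;
    constexpr int operator()(int i, int j) const { return i * kCols + j; }
};

template <int kRows_, int kCols_>
struct ColMajor {
    static constexpr int kRows = kRows_, kCols = kCols_;
    constexpr int operator()(int i, int j) const { return i + j * kRows; }
};

// BaseTile coordinates listed in the order they are stored under `Layout`.
template <typename Layout>
std::array<std::pair<int, int>, Layout::kRows * Layout::kCols> storage_order() {
    std::array<std::pair<int, int>, Layout::kRows * Layout::kCols> order{};
    const Layout layout{};
    for (int i = 0; i < Layout::kRows; ++i) {
        for (int j = 0; j < Layout::kCols; ++j) {
            const int slot = layout(i, j);
            assert(slot >= 0 && slot < Layout::kRows * Layout::kCols);
            order[slot] = {i, j};
        }
    }
    return order;
}
```

For a 2 x 2 grid, `storage_order<RowMajor<2, 2>>()` gives `(0,0), (0,1), (1,0), (1,1)`, while `storage_order<ColMajor<2, 2>>()` gives `(0,0), (1,0), (0,1), (1,1)`.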

Let's examine the example from the figure above and explore the potential usage of a register tile. The figure illustrates the execution of four ldmatrix instructions to load a 32 x 32 matrix, where:

  1. The elements are half-precision floating-point numbers transferred from shared memory to a thread's local register.
  2. The shared memory is organized in a row-major layout, with the contiguous dimension in shared memory indicated by the orange arrow. The red box represents a BaseTile.
  3. Composing the entire 32 x 32 tile requires 2 BaseTiles along each dimension (32 / 16 = 2), i.e., 4 BaseTiles in total. The inner loop should iterate over the dimension that is contiguous in memory. The RegTile layout specifies the order in which the loaded BaseTiles are stored in the destination memory hierarchy.
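The counting in item 3 can be sketched as a loop nest. The helper `base_tile_origins` is hypothetical, not a TiledCUDA function. It enumerates the BaseTiles of a `rows x cols` tile stored row-major in shared memory. The inner loop runs over the contiguous (column) dimension, and the helper records each BaseTile's starting offset. For the 32 x 32 example it yields 2 x 2 = 4 BaseTiles, one per `ldmatrix` in the figure.

```cpp
#include <cassert>
#include <vector>

constexpr int kBaseTileDim = 16;  // a BaseTile covers a 16 x 16 block

// Shared-memory offsets of the first element of each BaseTile in a rows x cols
// tile stored row-major (row stride = cols), in issue order: the outer loop walks
// down the rows, the inner loop walks along the contiguous columns.
std::vector<int> base_tile_origins(int rows, int cols) {
    assert(rows % kBaseTileDim == 0 && cols % kBaseTileDim == 0);
    std::vector<int> origins;
    for (int i = 0; i < rows / kBaseTileDim; ++i) {      // BaseTiles down
        for (int j = 0; j < cols / kBaseTileDim; ++j) {  // BaseTiles across (contiguous)
            origins.push_back(i * kBaseTileDim * cols + j * kBaseTileDim);
        }
    }
    return origins;
}
```

`base_tile_origins(32, 32)` returns `0, 16, 512, 528`, the first offsets of the four BaseTiles in the examples.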

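Examples (each cell lists the shared-memory offsets that thread 0 loads into one BaseTile):

If the register tile is declared as:

using Tile = RegTile<RowMajor<2, 2>>;

    |  0                                |  1
----+-----------------------------------+-----------------------------------
  0 |  0,1,256,257,8,9,264,265          |  16,17,272,273,24,25,280,281
  1 |  512,513,768,769,520,521,776,777  |  528,529,784,785,536,537,792,793

If the register tile is declared as:

using Tile = RegTile<ColMajor<2, 2>>;

    |  0                                |  1
----+-----------------------------------+-----------------------------------
  0 |  0,1,256,257,8,9,264,265          |  512,513,768,769,520,521,776,777
  1 |  16,17,272,273,24,25,280,281      |  528,529,784,785,536,537,792,793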

When the source layout differs from the target layout, the load also performs the transposition.
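The sketch below rebuilds both example tables for lane 0. It is a host-side model, not TiledCUDA code. It assumes the same `mma.m16n8k16` A-fragment order that the thread-0 values follow, and a 32 x 32 half-precision tile stored row-major in shared memory. The names `lane_values`, `TileLayout`, `BaseTileOffsets`, and `RegTileOffsets` are hypothetical. The two layouts load identical BaseTiles and differ only in the slot each BaseTile `(i, j)` is written to: `i * 2 + j` for RowMajor and `i + j * 2` for ColMajor.

```cpp
#include <array>
#include <cassert>

constexpr int kLd = 32;          // row stride of the 32 x 32 row-major source tile
constexpr int kTilesPerDim = 2;  // 32 / 16 BaseTiles along each dimension

enum class TileLayout { kRowMajor, kColMajor };

using BaseTileOffsets = std::array<int, 8>;  // offsets one lane loads into a BaseTile
using RegTileOffsets = std::array<BaseTileOffsets, kTilesPerDim * kTilesPerDim>;

// For one lane, load every BaseTile (i, j) and store it at the slot the layout picks.
RegTileOffsets lane_values(int lane, TileLayout layout) {
    assert(lane >= 0 && lane < 32);
    RegTileOffsets regs{};
    for (int i = 0; i < kTilesPerDim; ++i) {
        for (int j = 0; j < kTilesPerDim; ++j) {  // inner loop: contiguous dimension
            const int slot = layout == TileLayout::kRowMajor ? i * kTilesPerDim + j
                                                             : i + j * kTilesPerDim;
            const int row = i * 16 + lane / 4;        // this lane's row in the top quadrants
            const int col = j * 16 + (lane % 4) * 2;  // first of two adjacent columns
            int k = 0;
            for (int dc = 0; dc <= 8; dc += 8) {      // left 8 x 8 quadrants, then right
                for (int dr = 0; dr <= 8; dr += 8) {  // top quadrant, then bottom
                    const int off = (row + dr) * kLd + col + dc;
                    regs[slot][k++] = off;
                    regs[slot][k++] = off + 1;
                }
            }
        }
    }
    return regs;
}
```

Reading the result slot by slot, two slots per table row, reproduces the RowMajor and ColMajor tables. The ColMajor result is the RowMajor one with slots 1 and 2 swapped, which is the BaseTile-level transposition described above.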