[DRAFT PROPOSAL] Outline for AMD >100TFLOPS matmul for 7900XTX bounty

flammit opened this issue

For anyone considering working on this bounty, here's some background information that might be helpful.

The following is based on using the ROCm debugging tools to see how MIOpen's gemm achieves >100 TFLOPS:

GPU_DUMP_CODE_OBJECT=1 AMD_LOG_LEVEL=3 ./MIOpenDriver gemmfp16 --iter 1000 --time 1 --a_w 4096 --a_h 4096 --b_w 4096

Note this debugging was done a while ago, so the output may differ on a more recent version of ROCm. I will follow up with results and more detailed instructions.
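
(Aside: with GPU_DUMP_CODE_OBJECT=1 the runtime dumps each kernel's code object to disk; I read the dumps with LLVM's objdump along the lines below, though the dump file name and the exact flags may vary by ROCm/LLVM version.)

llvm-objdump -d --mcpu=gfx1100 <dumped_code_object>.hsaco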

This pseudocode can be translated into the appropriate UOps and Lazy graph once the necessary structures exist in core tinygrad.

Here's pseudocode for the rocBLAS/Tensile-generated kernel:

// AMD TENSILE GEMM PSEUDOCODE for 4096x4096 with ~100 TFLOPS
// Cijk_Ailk_Bljk_HHS_BH_MT128x128x16_MI16x16x16x1_SN_1LDSB0_APM1_ABV0_ACED0_AF0EM1_AF1EM1_AMAS3_ASE_ASGT_ASAE01_ASCE01_ASEM1_AAC0_BL1_BS1_DTL0_DTVA0_DVO0_ETSP_EPS1_FL0_GRVW8_GSU1_GSUASB_GLS0_ISA1100_IU1_K1_KLA_LBSPP128_LPA0_LPB8_LDL1_LRVW16_LWPMn1_LDW0_FMA_MIAV1_MDA2_NTA0_NTB0_NTC0_NTD0_NEPBS0_NLCA1_NLCB1_ONLL1_OPLV0_PK0_PAP0_PGR1_PLR1_RK0_SIA1_SS1_SU32_SUM0_SUS128_SCIUI1_SPO0_SRVW0_SSO0_SVW4_SNLL0_TT4_64_TLDS1_USFGROn1_VAW2_VSn1_VW4_WSGRA1_WSGRB1_WS32_WG32_4_1_WGM4

// GEMM problem params
#define M 4096
#define N 4096
#define K 4096

// tensor core params
#define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16
#define WARP_SIZE 32

// post-TC OptOp.* params
#define M_UPCAST 4 // number of WMMAs along M axis per warp/threadgroup per iteration (after OptOps.TC, OptOps.UPCAST remaining global M dim)
#define N_UPCAST 4 // number of WMMAs along N axis per warp/threadgroup per iteration (after OptOps.TC, OptOps.UPCAST remaining global N dim)
#define K_UPCAST 1 // number of WMMAs along K axis per warp/threadgroup per iteration (after OptOps.TC, OptOps.UNROLL remaining reduce K dim) - only 1 supported below
#define M_THREADGROUPS 2 // number of threadgroups along M axis per kernel (after non-TC OptOps.UPCAST, OptOps.LOCAL remaining global M dim)
#define N_THREADGROUPS 2 // number of threadgroups along N axis per kernel (after non-TC OptOps.UPCAST, OptOps.LOCAL remaining global N dim)
#define K_THREADGROUPS 1 // number of threadgroups along K axis per kernel (after non-TC OptOps.UNROLL, OptOps.GROUP remaining reduce K dim) - only 1 supported below

// shared memory params
#define A_SMEM_SIZE ((M_THREADGROUPS * M_UPCAST * WMMA_M) * (WMMA_K) * sizeof(half))
#define B_SMEM_SIZE ((WMMA_K) * (N_THREADGROUPS * N_UPCAST * WMMA_N) * sizeof(half))
#define SMEM_SIZE (2 * (A_SMEM_SIZE + B_SMEM_SIZE)) // double-buffered
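// => A_SMEM_SIZE = (2*4*16)*16*2 bytes = 4096, B_SMEM_SIZE = 4096, SMEM_SIZE = 16384
//    (the offset:16384 in the later ds_store instructions is the jump into the high/next buffer)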

// iteration params
#define BLOCK_K (K/WMMA_K) // 256

// kernel launch params
#define GLOBAL_X (M/WMMA_M/M_UPCAST/M_THREADGROUPS) // 32
#define GLOBAL_Y (N/WMMA_N/N_UPCAST/N_THREADGROUPS) // 32
#define GLOBAL_Z 1
#define THREADS (WARP_SIZE*M_THREADGROUPS*N_THREADGROUPS) // 128 = 4 WARPs, each WARP doing 4x4 WMMA

// kernel calculates a M=128 x N=128 block of GEMM output using 128 threads and gridDims=[32,32,1]
__global__ void amd_fp16_gemm_kernel(const half *a, const half *b, half *d) {
    __shared__ char shared[SMEM_SIZE];

    half *low_temp_a  = (half *)(shared);
    half *low_temp_b  = (half *)(shared + A_SMEM_SIZE);
    half *high_temp_a = (half *)(shared + A_SMEM_SIZE + B_SMEM_SIZE);
    half *high_temp_b = (half *)(shared + 2 * A_SMEM_SIZE + B_SMEM_SIZE);

    // initialize the smem double-buffer pointers: low buffers as current, high buffers as next
    half *current_temp_a = low_temp_a;
    half *current_temp_b = low_temp_b;
    half *next_temp_a    = high_temp_a;
    half *next_temp_b    = high_temp_b;

    // initialize the A WMMA fragments per WARP (M_UPCAST);
    half16 a_frag_0 = half16(0.0);
    half16 a_frag_1 = half16(0.0);
    half16 a_frag_2 = half16(0.0);
    half16 a_frag_3 = half16(0.0);

    // initialize the B WMMA fragments per WARP (N_UPCAST);
    half16 b_frag_0 = half16(0.0);
    half16 b_frag_1 = half16(0.0);
    half16 b_frag_2 = half16(0.0);
    half16 b_frag_3 = half16(0.0);

    // initialize the accumulators for the 4x4=16 WMMAs per thread (M_UPCAST*N_UPCAST)
    float8 wmma0_0 = float8(0.0);
    float8 wmma0_1 = float8(0.0);
    float8 wmma0_2 = float8(0.0);
    float8 wmma0_3 = float8(0.0);
    float8 wmma1_0 = float8(0.0);
    float8 wmma1_1 = float8(0.0);
    float8 wmma1_2 = float8(0.0);
    float8 wmma1_3 = float8(0.0);
    float8 wmma2_0 = float8(0.0);
    float8 wmma2_1 = float8(0.0);
    float8 wmma2_2 = float8(0.0);
    float8 wmma2_3 = float8(0.0);
    float8 wmma3_0 = float8(0.0);
    float8 wmma3_1 = float8(0.0);
    float8 wmma3_2 = float8(0.0);
    float8 wmma3_3 = float8(0.0);
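    // note: these 16 float8 accumulators occupy v[0:127] in the WMMA disassembly below (128 VGPRs)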

    // calc global load offsets
    /*
    v_mov_b32_e32 v5, v2                                       // 000000279794: 7E0A0302
    v_add_co_u32 v6, vcc_lo, 16, v5                            // 000000279798: D7006A06 00020A90
    v_mov_b32_e32 v7, v1                                       // 0000002797A0: 7E0E0301
    v_add_co_u32 v8, vcc_lo, 2, v7                             // 0000002797A4: D7006A08 00020E82
    ...
    v_mul_lo_u32 v10, s50, v7                                  // 0000002797C8: D72C000A 00020E32
    v_add_co_u32 v197, vcc_lo, v4, v10                         // 0000002797D0: D7006AC5 00021504
    v_add_nc_u32_e32 v197, 8, v197                             // 0000002797D8: 4B8B8A88
    v_lshlrev_b32_e32 v197, 1, v197                            // 0000002797DC: 318B8A81
    v_mul_lo_u32 v10, s50, v8                                  // 0000002797E0: D72C000A 00021032
    v_add_co_u32 v198, vcc_lo, v4, v10                         // 0000002797E8: D7006AC6 00021504
    v_add_nc_u32_e32 v198, 8, v198                             // 0000002797F0: 4B8D8C88
    v_lshlrev_b32_e32 v198, 1, v198                            // 0000002797F4: 318D8C81
    */
    size_t block_a_off = ...;
    /*
    v_mul_lo_u32 v10, s52, v5                                  // 0000002797F8: D72C000A 00020A34
    v_add_co_u32 v199, vcc_lo, v9, v10                         // 000000279800: D7006AC7 00021509
    v_add_nc_u32_e32 v199, 8, v199                             // 000000279808: 4B8F8E88
    v_lshlrev_b32_e32 v199, 1, v199                            // 00000027980C: 318F8E81
    v_mul_lo_u32 v10, s52, v6                                  // 000000279810: D72C000A 00020C34
    v_add_co_u32 v200, vcc_lo, v9, v10                         // 000000279818: D7006AC8 00021509
    v_add_nc_u32_e32 v200, 8, v200                             // 000000279820: 4B919088
    v_lshlrev_b32_e32 v200, 1, v200                            // 000000279824: 31919081
    */
    size_t block_b_off = ...;

    size_t store_shared_a_off = ...;
    size_t store_shared_b_off = ...;
    size_t load_shared_a_off = ...;
    size_t load_shared_b_off = ...;
    // TODO: figure out exact alias layout (probably optimized for buffer_load_b128/ds_store_b128 from global, swizzles in the ds_load_16 for B buffer)

    // load first iteration of global A to registers - 16 FP16 per thread (WMMA_M * M_UPCAST * M_THREADGROUPS * WMMA_K / THREADS)
    /*
    buffer_load_b128 v[202:205], v197, s[8:11], 0 offen        // 000000279988: E05C0000 8042CAC5
    buffer_load_b128 v[206:209], v198, s[8:11], 0 offen        // 000000279990: E05C0000 8042CEC6
    */
    half8 global_a1 = *(half8 *)(a + block_a_off    );
    half8 global_a2 = *(half8 *)(a + block_a_off + 4); // 2 from 000000279798 LSHL'd 1

    // load first iteration of global B to registers - 16 FP16 per thread (WMMA_K * N_UPCAST * N_THREADGROUPS * WMMA_N / THREADS)
    /*
    buffer_load_b128 v[210:213], v199, s[12:15], 0 offen       // 000000279998: E05C0000 8043D2C7
    buffer_load_b128 v[214:217], v200, s[12:15], 0 offen       // 0000002799A0: E05C0000 8043D6C8
    */
    half8 global_b1 = *(half8 *)(b + block_b_off);
    half8 global_b2 = *(half8 *)(b + block_b_off + 32); // 16 from 000000279798 LSHL'd 1.  TODO: this doesn't seem right

    // wait for vmcnt(0), then store first iteration of the A global loads to shared memory
    /*
    ...
    s_waitcnt vmcnt(0)                                         // 000000279CA0: BF8903F7
    ds_store_b128 v195, v[202:205]                             // 000000279CA4: DB7C0000 0000CAC3
    ds_store_b128 v195, v[206:209] offset:512                  // 000000279CAC: DB7C0200 0000CEC3
    */
    *(half8 *)(current_temp_a + store_shared_a_off + 0)   = global_a1;
    *(half8 *)(current_temp_a + store_shared_a_off + 512) = global_a2; // TODO: is that offset right??

    // store first iteration of the B global loads to shared memory
    /*
    ds_store_b128 v196, v[210:213]                             // 000000279CB4: DB7C0000 0000D2C4
    ds_store_b128 v196, v[214:217] offset:576                  // 000000279CBC: DB7C0240 0000D6C4
    */
    *(half8 *)(current_temp_b + store_shared_b_off + 0)   = global_b1;
    *(half8 *)(current_temp_b + store_shared_b_off + 576) = global_b2; // TODO: is that offset right??

    // do all but the last K iterations (last one won't need next iteration loads)
    for (int k = 0; k < BLOCK_K-1; k++) {
        // wait for lgkmcnt(0) -- wait for current iteration stores to shared memory to finish, and barrier
        /*
        ...
        s_waitcnt lgkmcnt(0)                                       // 000000279CCC: BF89FC07
        s_waitcnt_lgkmcnt null, 0x0                                // 000000279CD0: BDFC0000
        s_barrier                                                  // 000000279CD4: BFBD0000
        */
        __syncthreads();

        // load from shared memory the first 2 columns (2x16 FP16) of B for the current iteration: use 32 ds_load_u16 each at different offsets
        /*
        ds_load_u16 v130, v218                                     // 000000279CD8: D8F00000 820000DA
        ds_load_u16_d16_hi v130, v218 offset:256                   // 000000279CE0: DA9C0100 820000DA
        ds_load_u16 v131, v218 offset:512                          // 000000279CE8: D8F00200 830000DA
        ds_load_u16_d16_hi v131, v218 offset:768                   // 000000279CF0: DA9C0300 830000DA
        ds_load_u16 v132, v218 offset:1024                         // 000000279CF8: D8F00400 840000DA
        ds_load_u16_d16_hi v132, v218 offset:1280                  // 000000279D00: DA9C0500 840000DA
        ds_load_u16 v133, v218 offset:1536                         // 000000279D08: D8F00600 850000DA
        ds_load_u16_d16_hi v133, v218 offset:1792                  // 000000279D10: DA9C0700 850000DA
        ds_load_u16 v134, v218 offset:2048                         // 000000279D18: D8F00800 860000DA
        ds_load_u16_d16_hi v134, v218 offset:2304                  // 000000279D20: DA9C0900 860000DA
        ds_load_u16 v135, v218 offset:2560                         // 000000279D28: D8F00A00 870000DA
        ds_load_u16_d16_hi v135, v218 offset:2816                  // 000000279D30: DA9C0B00 870000DA
        ds_load_u16 v136, v218 offset:3072                         // 000000279D38: D8F00C00 880000DA
        ds_load_u16_d16_hi v136, v218 offset:3328                  // 000000279D40: DA9C0D00 880000DA
        ds_load_u16 v137, v218 offset:3584                         // 000000279D48: D8F00E00 890000DA
        ds_load_u16_d16_hi v137, v218 offset:3840                  // 000000279D50: DA9C0F00 890000DA
        ds_load_u16 v138, v218 offset:2                            // 000000279D58: D8F00002 8A0000DA
        ds_load_u16_d16_hi v138, v218 offset:258                   // 000000279D60: DA9C0102 8A0000DA
        ds_load_u16 v139, v218 offset:514                          // 000000279D68: D8F00202 8B0000DA
        ds_load_u16_d16_hi v139, v218 offset:770                   // 000000279D70: DA9C0302 8B0000DA
        ds_load_u16 v140, v218 offset:1026                         // 000000279D78: D8F00402 8C0000DA
        ds_load_u16_d16_hi v140, v218 offset:1282                  // 000000279D80: DA9C0502 8C0000DA
        ds_load_u16 v141, v218 offset:1538                         // 000000279D88: D8F00602 8D0000DA
        ds_load_u16_d16_hi v141, v218 offset:1794                  // 000000279D90: DA9C0702 8D0000DA
        ds_load_u16 v142, v218 offset:2050                         // 000000279D98: D8F00802 8E0000DA
        ds_load_u16_d16_hi v142, v218 offset:2306                  // 000000279DA0: DA9C0902 8E0000DA
        ds_load_u16 v143, v218 offset:2562                         // 000000279DA8: D8F00A02 8F0000DA
        ds_load_u16_d16_hi v143, v218 offset:2818                  // 000000279DB0: DA9C0B02 8F0000DA
        ds_load_u16 v144, v218 offset:3074                         // 000000279DB8: D8F00C02 900000DA
        ds_load_u16_d16_hi v144, v218 offset:3330                  // 000000279DC0: DA9C0D02 900000DA
        ds_load_u16 v145, v218 offset:3586                         // 000000279DC8: D8F00E02 910000DA
        ds_load_u16_d16_hi v145, v218 offset:3842                  // 000000279DD0: DA9C0F02 910000DA
        ... // note not sure why they load an extra 4 elements of the 3rd column
        ds_load_u16 v146, v218 offset:4                            // 000000279DD8: D8F00004 920000DA
        ds_load_u16_d16_hi v146, v218 offset:260                   // 000000279DE0: DA9C0104 920000DA
        ds_load_u16 v147, v218 offset:516                          // 000000279DE8: D8F00204 930000DA
        ds_load_u16_d16_hi v147, v218 offset:772                   // 000000279DF0: DA9C0304 930000DA
        */
        // load first two 16-element columns of B
        b_frag_0 = half16(current_temp_b[load_shared_b_off+0], current_temp_b[load_shared_b_off+256], current_temp_b[load_shared_b_off+512], ...);
        b_frag_1 = half16(current_temp_b[load_shared_b_off+2], current_temp_b[load_shared_b_off+258], current_temp_b[load_shared_b_off+514], ...);

        // load from global memory the next 16 FP16 of A per thread for the next iteration
        /*
        buffer_load_b128 v[202:205], v197, s[8:11], 0 offen        // 000000279DF8: E05C0000 8042CAC5
        buffer_load_b128 v[206:209], v198, s[8:11], 0 offen        // 000000279E00: E05C0000 8042CEC6
        */
        global_a1 = *(half8 *)(a + block_a_off    ); // (per-iteration advance of block_a_off elided)
        global_a2 = *(half8 *)(a + block_a_off + 4); // 2 from 000000279798 LSHL'd 1

        // load from global memory the next 16 FP16 of B per thread for the next iteration
        /*
        buffer_load_b128 v[210:213], v199, s[12:15], 0 offen       // 000000279E08: E05C0000 8043D2C7
        buffer_load_b128 v[214:217], v200, s[12:15], 0 offen       // 000000279E10: E05C0000 8043D6C8
        */
        global_b1 = *(half8 *)(b + block_b_off); // (per-iteration advance of block_b_off elided)
        global_b2 = *(half8 *)(b + block_b_off + 32); // 16 from 000000279798 LSHL'd 1.  TODO: this doesn't seem right

        // load from shared memory the last 2 columns (2x16 FP16) of B for the current iteration: use 32 ds_load_u16 each at different offsets
        /*
        ...
        ds_load_u16 v148, v218 offset:1028                         // 000000279E60: D8F00404 940000DA
        ds_load_u16_d16_hi v148, v218 offset:1284                  // 000000279E68: DA9C0504 940000DA
        ds_load_u16 v149, v218 offset:1540                         // 000000279E70: D8F00604 950000DA
        ds_load_u16_d16_hi v149, v218 offset:1796                  // 000000279E78: DA9C0704 950000DA
        ds_load_u16 v150, v218 offset:2052                         // 000000279E80: D8F00804 960000DA
        ds_load_u16_d16_hi v150, v218 offset:2308                  // 000000279E88: DA9C0904 960000DA
        ds_load_u16 v151, v218 offset:2564                         // 000000279E90: D8F00A04 970000DA
        ds_load_u16_d16_hi v151, v218 offset:2820                  // 000000279E98: DA9C0B04 970000DA
        ds_load_u16 v152, v218 offset:3076                         // 000000279EA0: D8F00C04 980000DA
        ds_load_u16_d16_hi v152, v218 offset:3332                  // 000000279EA8: DA9C0D04 980000DA
        ds_load_u16 v153, v218 offset:3588                         // 000000279EB0: D8F00E04 990000DA
        ds_load_u16_d16_hi v153, v218 offset:3844                  // 000000279EB8: DA9C0F04 990000DA
        ds_load_u16 v154, v218 offset:6                            // 000000279EC0: D8F00006 9A0000DA
        ds_load_u16_d16_hi v154, v218 offset:262                   // 000000279EC8: DA9C0106 9A0000DA
        ds_load_u16 v155, v218 offset:518                          // 000000279ED0: D8F00206 9B0000DA
        ds_load_u16_d16_hi v155, v218 offset:774                   // 000000279ED8: DA9C0306 9B0000DA
        ds_load_u16 v156, v218 offset:1030                         // 000000279EE0: D8F00406 9C0000DA
        ds_load_u16_d16_hi v156, v218 offset:1286                  // 000000279EE8: DA9C0506 9C0000DA
        ds_load_u16 v157, v218 offset:1542                         // 000000279EF0: D8F00606 9D0000DA
        ds_load_u16_d16_hi v157, v218 offset:1798                  // 000000279EF8: DA9C0706 9D0000DA
        ds_load_u16 v158, v218 offset:2054                         // 000000279F00: D8F00806 9E0000DA
        ds_load_u16_d16_hi v158, v218 offset:2310                  // 000000279F08: DA9C0906 9E0000DA
        ds_load_u16 v159, v218 offset:2566                         // 000000279F10: D8F00A06 9F0000DA
        ds_load_u16_d16_hi v159, v218 offset:2822                  // 000000279F18: DA9C0B06 9F0000DA
        ds_load_u16 v160, v218 offset:3078                         // 000000279F20: D8F00C06 A00000DA
        ds_load_u16_d16_hi v160, v218 offset:3334                  // 000000279F28: DA9C0D06 A00000DA
        ds_load_u16 v161, v218 offset:3590                         // 000000279F30: D8F00E06 A10000DA
        ds_load_u16_d16_hi v161, v218 offset:3846                  // 000000279F38: DA9C0F06 A10000DA
        */
        b_frag_2 = half16(current_temp_b[load_shared_b_off+4], current_temp_b[load_shared_b_off+260], current_temp_b[load_shared_b_off+516], ...);
        b_frag_3 = half16(current_temp_b[load_shared_b_off+6], current_temp_b[load_shared_b_off+262], current_temp_b[load_shared_b_off+518], ...);

        // load from shared memory 64 FP16 of A for the current iteration: 8 x ds_load_b128s
        /*
        ds_load_b128 v[163:166], v219                              // 000000279F40: DBFC0000 A30000DB
        ds_load_b128 v[167:170], v219 offset:16                    // 000000279F48: DBFC0010 A70000DB
        ds_load_b128 v[171:174], v219 offset:1152                  // 000000279F50: DBFC0480 AB0000DB
        ds_load_b128 v[175:178], v219 offset:1168                  // 000000279F58: DBFC0490 AF0000DB
        ds_load_b128 v[179:182], v219 offset:2304                  // 000000279F60: DBFC0900 B30000DB
        ds_load_b128 v[183:186], v219 offset:2320                  // 000000279F68: DBFC0910 B70000DB
        ds_load_b128 v[187:190], v219 offset:3456                  // 000000279F70: DBFC0D80 BB0000DB
        ds_load_b128 v[191:194], v219 offset:3472                  // 000000279F78: DBFC0D90 BF0000DB
        */
        a_frag_0 = *(half16 *)(current_temp_a + load_shared_a_off +    0); // assume the 16 FP16 load will be broken into two contiguous 8 FP16/b128 loads
        a_frag_1 = *(half16 *)(current_temp_a + load_shared_a_off + 1152); // TODO: why is this offset??
        a_frag_2 = *(half16 *)(current_temp_a + load_shared_a_off + 2304); // TODO: why is this offset??
        a_frag_3 = *(half16 *)(current_temp_a + load_shared_a_off + 3456); // TODO: why is this offset??

        // store the next iteration of the A global loads to shared memory
        /*
        s_waitcnt vmcnt(3)                                         // 000000279F80: BF890FF7
        ds_store_b128 v195, v[202:205] offset:16384                // 000000279F84: DB7C4000 0000CAC3
        s_waitcnt vmcnt(2)                                         // 000000279F8C: BF890BF7
        ds_store_b128 v195, v[206:209] offset:16896                // 000000279F90: DB7C4200 0000CEC3
        */
        *(half8 *)(next_temp_a + store_shared_a_off + 0)   = global_a1;
        *(half8 *)(next_temp_a + store_shared_a_off + 512) = global_a2; // TODO: is that offset right??

        // store the next iteration of the B global loads to shared memory
        /*
        s_waitcnt vmcnt(1)                                         // 000000279F98: BF8907F7
        ds_store_b128 v196, v[210:213] offset:16384                // 000000279F9C: DB7C4000 0000D2C4
        s_waitcnt vmcnt(0)                                         // 000000279FA4: BF8903F7
        ds_store_b128 v196, v[214:217] offset:16960                // 000000279FA8: DB7C4240 0000D6C4
        */
        *(half8 *)(next_temp_b + store_shared_b_off + 0)   = global_b1;
        *(half8 *)(next_temp_b + store_shared_b_off + 576) = global_b2; // TODO: is that offset right?? maybe LDS pad for bank conflicts

        // wait for all the shared memory loads to finish (but NOT the last 4 shared memory stores)
        /*
        s_waitcnt lgkmcnt(4)                                       // 000000279FB0: BF89FC47
        s_nop 1                                                    // 000000279FB4: BF800001
        */

        // execute all 16 WMMAs
        /*
        v_wmma_f32_16x16x16_f16 v[0:7], v[163:170], v[130:137], v[0:7]// 000000279FB8: CC404000 1C0305A3
        v_wmma_f32_16x16x16_f16 v[8:15], v[163:170], v[138:145], v[8:15]// 000000279FC0: CC404008 1C2315A3
        v_wmma_f32_16x16x16_f16 v[16:23], v[163:170], v[146:153], v[16:23]// 000000279FC8: CC404010 1C4325A3
        v_wmma_f32_16x16x16_f16 v[24:31], v[163:170], v[154:161], v[24:31]// 000000279FD0: CC404018 1C6335A3
        v_wmma_f32_16x16x16_f16 v[32:39], v[171:178], v[130:137], v[32:39]// 000000279FD8: CC404020 1C8305AB
        v_wmma_f32_16x16x16_f16 v[40:47], v[171:178], v[138:145], v[40:47]// 000000279FE0: CC404028 1CA315AB
        v_wmma_f32_16x16x16_f16 v[48:55], v[171:178], v[146:153], v[48:55]// 000000279FE8: CC404030 1CC325AB
        v_wmma_f32_16x16x16_f16 v[56:63], v[171:178], v[154:161], v[56:63]// 000000279FF0: CC404038 1CE335AB
        v_wmma_f32_16x16x16_f16 v[64:71], v[179:186], v[130:137], v[64:71]// 000000279FF8: CC404040 1D0305B3
        v_wmma_f32_16x16x16_f16 v[72:79], v[179:186], v[138:145], v[72:79]// 00000027A000: CC404048 1D2315B3
        v_wmma_f32_16x16x16_f16 v[80:87], v[179:186], v[146:153], v[80:87]// 00000027A008: CC404050 1D4325B3
        v_wmma_f32_16x16x16_f16 v[88:95], v[179:186], v[154:161], v[88:95]// 00000027A010: CC404058 1D6335B3
        v_wmma_f32_16x16x16_f16 v[96:103], v[187:194], v[130:137], v[96:103]// 00000027A018: CC404060 1D8305BB
        v_wmma_f32_16x16x16_f16 v[104:111], v[187:194], v[138:145], v[104:111]// 00000027A020: CC404068 1DA315BB
        v_wmma_f32_16x16x16_f16 v[112:119], v[187:194], v[146:153], v[112:119]// 00000027A028: CC404070 1DC325BB
        v_wmma_f32_16x16x16_f16 v[120:127], v[187:194], v[154:161], v[120:127]// 00000027A030: CC404078 1DE335BB
        */
        wmma0_0 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag_0, b_frag_0, wmma0_0);
        wmma0_1 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag_0, b_frag_1, wmma0_1);
        wmma0_2 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag_0, b_frag_2, wmma0_2);
        wmma0_3 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag_0, b_frag_3, wmma0_3);
        wmma1_0 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag_1, b_frag_0, wmma1_0);
        wmma1_1 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag_1, b_frag_1, wmma1_1);
        wmma1_2 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag_1, b_frag_2, wmma1_2);
        wmma1_3 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag_1, b_frag_3, wmma1_3);
        wmma2_0 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag_2, b_frag_0, wmma2_0);
        wmma2_1 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag_2, b_frag_1, wmma2_1);
        wmma2_2 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag_2, b_frag_2, wmma2_2);
        wmma2_3 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag_2, b_frag_3, wmma2_3);
        wmma3_0 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag_3, b_frag_0, wmma3_0);
        wmma3_1 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag_3, b_frag_1, wmma3_1);
        wmma3_2 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag_3, b_frag_2, wmma3_2);
        wmma3_3 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag_3, b_frag_3, wmma3_3);

        // swap smem buffers for the next iteration
        const bool next_is_high = (k % 2 == 0); // iteration k+1 uses the high buffers when k is even
        current_temp_a = next_is_high ? high_temp_a : low_temp_a;
        current_temp_b = next_is_high ? high_temp_b : low_temp_b;
        next_temp_a    = next_is_high ? low_temp_a  : high_temp_a;
        next_temp_b    = next_is_high ? low_temp_b  : high_temp_b;
    }

    // wait for the last iteration shared memory loads to complete, and sync
    /*
    s_waitcnt lgkmcnt(0)                                       // 00000027A424: BF89FC07
    s_waitcnt_lgkmcnt null, 0x0                                // 00000027A428: BDFC0000
    s_barrier                                                  // 00000027A42C: BFBD0000
    */
    __syncthreads();

    // load from shared memory 4 columns (4x16 FP16) of B for the last iteration: 64 x ds_load_16s
    /*
	ds_load_u16 v130, v218                                     // 00000027A430: D8F00000 820000DA
	ds_load_u16_d16_hi v130, v218 offset:256                   // 00000027A438: DA9C0100 820000DA
	ds_load_u16 v131, v218 offset:512                          // 00000027A440: D8F00200 830000DA
	ds_load_u16_d16_hi v131, v218 offset:768                   // 00000027A448: DA9C0300 830000DA
	ds_load_u16 v132, v218 offset:1024                         // 00000027A450: D8F00400 840000DA
	ds_load_u16_d16_hi v132, v218 offset:1280                  // 00000027A458: DA9C0500 840000DA
	ds_load_u16 v133, v218 offset:1536                         // 00000027A460: D8F00600 850000DA
	ds_load_u16_d16_hi v133, v218 offset:1792                  // 00000027A468: DA9C0700 850000DA
	ds_load_u16 v134, v218 offset:2048                         // 00000027A470: D8F00800 860000DA
	ds_load_u16_d16_hi v134, v218 offset:2304                  // 00000027A478: DA9C0900 860000DA
	ds_load_u16 v135, v218 offset:2560                         // 00000027A480: D8F00A00 870000DA
	ds_load_u16_d16_hi v135, v218 offset:2816                  // 00000027A488: DA9C0B00 870000DA
	ds_load_u16 v136, v218 offset:3072                         // 00000027A490: D8F00C00 880000DA
	ds_load_u16_d16_hi v136, v218 offset:3328                  // 00000027A498: DA9C0D00 880000DA
	ds_load_u16 v137, v218 offset:3584                         // 00000027A4A0: D8F00E00 890000DA
	ds_load_u16_d16_hi v137, v218 offset:3840                  // 00000027A4A8: DA9C0F00 890000DA
	ds_load_u16 v138, v218 offset:2                            // 00000027A4B0: D8F00002 8A0000DA
	ds_load_u16_d16_hi v138, v218 offset:258                   // 00000027A4B8: DA9C0102 8A0000DA
	ds_load_u16 v139, v218 offset:514                          // 00000027A4C0: D8F00202 8B0000DA
	ds_load_u16_d16_hi v139, v218 offset:770                   // 00000027A4C8: DA9C0302 8B0000DA
	ds_load_u16 v140, v218 offset:1026                         // 00000027A4D0: D8F00402 8C0000DA
	ds_load_u16_d16_hi v140, v218 offset:1282                  // 00000027A4D8: DA9C0502 8C0000DA
	ds_load_u16 v141, v218 offset:1538                         // 00000027A4E0: D8F00602 8D0000DA
	ds_load_u16_d16_hi v141, v218 offset:1794                  // 00000027A4E8: DA9C0702 8D0000DA
	ds_load_u16 v142, v218 offset:2050                         // 00000027A4F0: D8F00802 8E0000DA
	ds_load_u16_d16_hi v142, v218 offset:2306                  // 00000027A4F8: DA9C0902 8E0000DA
	ds_load_u16 v143, v218 offset:2562                         // 00000027A500: D8F00A02 8F0000DA
	ds_load_u16_d16_hi v143, v218 offset:2818                  // 00000027A508: DA9C0B02 8F0000DA
	ds_load_u16 v144, v218 offset:3074                         // 00000027A510: D8F00C02 900000DA
	ds_load_u16_d16_hi v144, v218 offset:3330                  // 00000027A518: DA9C0D02 900000DA
	ds_load_u16 v145, v218 offset:3586                         // 00000027A520: D8F00E02 910000DA
	ds_load_u16_d16_hi v145, v218 offset:3842                  // 00000027A528: DA9C0F02 910000DA
	ds_load_u16 v146, v218 offset:4                            // 00000027A530: D8F00004 920000DA
	ds_load_u16_d16_hi v146, v218 offset:260                   // 00000027A538: DA9C0104 920000DA
	ds_load_u16 v147, v218 offset:516                          // 00000027A540: D8F00204 930000DA
	ds_load_u16_d16_hi v147, v218 offset:772                   // 00000027A548: DA9C0304 930000DA
	ds_load_u16 v148, v218 offset:1028                         // 00000027A550: D8F00404 940000DA
	ds_load_u16_d16_hi v148, v218 offset:1284                  // 00000027A558: DA9C0504 940000DA
	ds_load_u16 v149, v218 offset:1540                         // 00000027A560: D8F00604 950000DA
	ds_load_u16_d16_hi v149, v218 offset:1796                  // 00000027A568: DA9C0704 950000DA
	ds_load_u16 v150, v218 offset:2052                         // 00000027A570: D8F00804 960000DA
	ds_load_u16_d16_hi v150, v218 offset:2308                  // 00000027A578: DA9C0904 960000DA
	ds_load_u16 v151, v218 offset:2564                         // 00000027A580: D8F00A04 970000DA
	ds_load_u16_d16_hi v151, v218 offset:2820                  // 00000027A588: DA9C0B04 970000DA
	ds_load_u16 v152, v218 offset:3076                         // 00000027A590: D8F00C04 980000DA
	ds_load_u16_d16_hi v152, v218 offset:3332                  // 00000027A598: DA9C0D04 980000DA
	ds_load_u16 v153, v218 offset:3588                         // 00000027A5A0: D8F00E04 990000DA
	ds_load_u16_d16_hi v153, v218 offset:3844                  // 00000027A5A8: DA9C0F04 990000DA
	ds_load_u16 v154, v218 offset:6                            // 00000027A5B0: D8F00006 9A0000DA
	ds_load_u16_d16_hi v154, v218 offset:262                   // 00000027A5B8: DA9C0106 9A0000DA
	ds_load_u16 v155, v218 offset:518                          // 00000027A5C0: D8F00206 9B0000DA
	ds_load_u16_d16_hi v155, v218 offset:774                   // 00000027A5C8: DA9C0306 9B0000DA
	ds_load_u16 v156, v218 offset:1030                         // 00000027A5D0: D8F00406 9C0000DA
	ds_load_u16_d16_hi v156, v218 offset:1286                  // 00000027A5D8: DA9C0506 9C0000DA
	ds_load_u16 v157, v218 offset:1542                         // 00000027A5E0: D8F00606 9D0000DA
	ds_load_u16_d16_hi v157, v218 offset:1798                  // 00000027A5E8: DA9C0706 9D0000DA
	ds_load_u16 v158, v218 offset:2054                         // 00000027A5F0: D8F00806 9E0000DA
	ds_load_u16_d16_hi v158, v218 offset:2310                  // 00000027A5F8: DA9C0906 9E0000DA
	ds_load_u16 v159, v218 offset:2566                         // 00000027A600: D8F00A06 9F0000DA
	ds_load_u16_d16_hi v159, v218 offset:2822                  // 00000027A608: DA9C0B06 9F0000DA
	ds_load_u16 v160, v218 offset:3078                         // 00000027A610: D8F00C06 A00000DA
	ds_load_u16_d16_hi v160, v218 offset:3334                  // 00000027A618: DA9C0D06 A00000DA
	ds_load_u16 v161, v218 offset:3590                         // 00000027A620: D8F00E06 A10000DA
	ds_load_u16_d16_hi v161, v218 offset:3846                  // 00000027A628: DA9C0F06 A10000DA
    */
    b_frag_0 = half16(current_temp_b[load_shared_b_off+0], current_temp_b[load_shared_b_off+256], current_temp_b[load_shared_b_off+512], ...);
    b_frag_1 = half16(current_temp_b[load_shared_b_off+2], current_temp_b[load_shared_b_off+258], current_temp_b[load_shared_b_off+514], ...);
    b_frag_2 = half16(current_temp_b[load_shared_b_off+4], current_temp_b[load_shared_b_off+260], current_temp_b[load_shared_b_off+516], ...);
    b_frag_3 = half16(current_temp_b[load_shared_b_off+6], current_temp_b[load_shared_b_off+262], current_temp_b[load_shared_b_off+518], ...);

    // load from shared memory 4 rows (4x16 FP16) of A for the last iteration: 8 x ds_load_128s
    /*
    ds_load_b128 v[163:166], v219                              // 00000027A630: DBFC0000 A30000DB
	ds_load_b128 v[167:170], v219 offset:16                    // 00000027A638: DBFC0010 A70000DB
	ds_load_b128 v[171:174], v219 offset:1152                  // 00000027A640: DBFC0480 AB0000DB
	ds_load_b128 v[175:178], v219 offset:1168                  // 00000027A648: DBFC0490 AF0000DB
	ds_load_b128 v[179:182], v219 offset:2304                  // 00000027A650: DBFC0900 B30000DB
	ds_load_b128 v[183:186], v219 offset:2320                  // 00000027A658: DBFC0910 B70000DB
	ds_load_b128 v[187:190], v219 offset:3456                  // 00000027A660: DBFC0D80 BB0000DB
	ds_load_b128 v[191:194], v219 offset:3472                  // 00000027A668: DBFC0D90 BF0000DB
    */
    a_frag_0 = *(half16 *)(current_temp_a + load_shared_a_off +    0);
    a_frag_1 = *(half16 *)(current_temp_a + load_shared_a_off + 1152);
    a_frag_2 = *(half16 *)(current_temp_a + load_shared_a_off + 2304);
    a_frag_3 = *(half16 *)(current_temp_a + load_shared_a_off + 3456);

    // wait for the shared memory loads to complete
    /*
    s_waitcnt lgkmcnt(0)                                       // 00000027A670: BF89FC07
    s_nop 1
    */

    // execute the last iteration of 16 WMMAs
    /*
    v_wmma_f32_16x16x16_f16 v[0:7], v[163:170], v[130:137], v[0:7]// 00000027A678: CC404000 1C0305A3
    v_wmma_f32_16x16x16_f16 v[8:15], v[163:170], v[138:145], v[8:15]// 00000027A680: CC404008 1C2315A3
    v_wmma_f32_16x16x16_f16 v[16:23], v[163:170], v[146:153], v[16:23]// 00000027A688: CC404010 1C4325A3
    v_wmma_f32_16x16x16_f16 v[24:31], v[163:170], v[154:161], v[24:31]// 00000027A690: CC404018 1C6335A3
    v_wmma_f32_16x16x16_f16 v[32:39], v[171:178], v[130:137], v[32:39]// 00000027A698: CC404020 1C8305AB
    v_wmma_f32_16x16x16_f16 v[40:47], v[171:178], v[138:145], v[40:47]// 00000027A6A0: CC404028 1CA315AB
    v_wmma_f32_16x16x16_f16 v[48:55], v[171:178], v[146:153], v[48:55]// 00000027A6A8: CC404030 1CC325AB
    v_wmma_f32_16x16x16_f16 v[56:63], v[171:178], v[154:161], v[56:63]// 00000027A6B0: CC404038 1CE335AB
    v_wmma_f32_16x16x16_f16 v[64:71], v[179:186], v[130:137], v[64:71]// 00000027A6B8: CC404040 1D0305B3
    v_wmma_f32_16x16x16_f16 v[72:79], v[179:186], v[138:145], v[72:79]// 00000027A6C0: CC404048 1D2315B3
    v_wmma_f32_16x16x16_f16 v[80:87], v[179:186], v[146:153], v[80:87]// 00000027A6C8: CC404050 1D4325B3
    v_wmma_f32_16x16x16_f16 v[88:95], v[179:186], v[154:161], v[88:95]// 00000027A6D0: CC404058 1D6335B3
    v_wmma_f32_16x16x16_f16 v[96:103], v[187:194], v[130:137], v[96:103]// 00000027A6D8: CC404060 1D8305BB
    v_wmma_f32_16x16x16_f16 v[104:111], v[187:194], v[138:145], v[104:111]// 00000027A6E0: CC404068 1DA315BB
    v_wmma_f32_16x16x16_f16 v[112:119], v[187:194], v[146:153], v[112:119]// 00000027A6E8: CC404070 1DC325BB
    v_wmma_f32_16x16x16_f16 v[120:127], v[187:194], v[154:161], v[120:127]// 00000027A6F0: CC404078 1DE335BB
    */
    wmma0_0 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag_0, b_frag_0, wmma0_0);
    wmma0_1 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag_0, b_frag_1, wmma0_1);
    wmma0_2 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag_0, b_frag_2, wmma0_2);
    wmma0_3 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag_0, b_frag_3, wmma0_3);
    wmma1_0 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag_1, b_frag_0, wmma1_0);
    wmma1_1 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag_1, b_frag_1, wmma1_1);
    wmma1_2 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag_1, b_frag_2, wmma1_2);
    wmma1_3 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag_1, b_frag_3, wmma1_3);
    wmma2_0 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag_2, b_frag_0, wmma2_0);
    wmma2_1 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag_2, b_frag_1, wmma2_1);
    wmma2_2 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag_2, b_frag_2, wmma2_2);
    wmma2_3 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag_2, b_frag_3, wmma2_3);
    wmma3_0 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag_3, b_frag_0, wmma3_0);
    wmma3_1 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag_3, b_frag_1, wmma3_1);
    wmma3_2 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag_3, b_frag_2, wmma3_2);
    wmma3_3 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag_3, b_frag_3, wmma3_3);

    // convert the f32 accumulators to half and store to D using buffer_store_b64 (8 D elems * 16 WMMAs = 128 elems, 4 FP16 at a time for 32 b64 stores)
    /*
    ...
    buffer_store_b64 v[140:141], v136, s[16:19], 0 offen       // 00000027A92C: E06C0000 80448C88
    ...
    buffer_store_b64 v[144:145], v136, s[16:19], 0 offen       // 00000027A970: E06C0000 80449088
    ...
    buffer_store_b64 v[148:149], v136, s[16:19], 0 offen       // 00000027A9B4: E06C0000 80449488
    ...
    */
    *(half4 *)(d + ... ) = half4(...);
    // 32 times
    *(half4 *)(d + ... ) = half4(...);
}
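
For orientation, here is a minimal sketch of how a kernel with these launch params could be driven from the host. This is not from the dump; the HIP calls (hipMalloc, hipLaunchKernelGGL) are standard, but the includes and setup are my assumptions:

#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>

int main() {
    half *a, *b, *d;
    hipMalloc(&a, (size_t)M * K * sizeof(half)); // A is MxK
    hipMalloc(&b, (size_t)K * N * sizeof(half)); // B is KxN
    hipMalloc(&d, (size_t)M * N * sizeof(half)); // D is the MxN output
    // ... fill a and b ...

    dim3 grid(GLOBAL_X, GLOBAL_Y, GLOBAL_Z); // [32, 32, 1]
    dim3 block(THREADS);                     // 128 threads = 4 wave32 warps
    hipLaunchKernelGGL(amd_fp16_gemm_kernel, grid, block, 0 /* smem is static */, 0 /* default stream */, a, b, d);
    hipDeviceSynchronize();
    return 0;
}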

Fixed up the mix-up between A and B loads from shared memory.

Fundamentally, this is our standard GEMM with the following applied_opts (OptOps.TC selects the 16x16x16 WMMA, the two UPCASTs correspond to M_UPCAST/N_UPCAST=4, and the two LOCALs to M_THREADGROUPS/N_THREADGROUPS=2):

[
  Opt(op=OptOps.TC, axis=0, amt=2),
  Opt(op=OptOps.UPCAST, axis=0, amt=4),
  Opt(op=OptOps.UPCAST, axis=1, amt=4),
  Opt(op=OptOps.LOCAL, axis=0, amt=2),
  Opt(op=OptOps.LOCAL, axis=1, amt=2)
]

The key difference from what we currently generate is that shared memory is used to double-buffer the inputs, and the global-memory loads, shared-memory loads, shared-memory stores, and WMMAs are all interleaved to hide latency.
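
To make that loop structure explicit, here is a skeleton of the double-buffered pipeline distilled from the pseudocode above (not compilable as-is; the helper names load_global_to_regs, store_regs_to_smem, load_frags_from_smem, and do_16_wmmas are hypothetical stand-ins for the load/store/WMMA sequences shown in the kernel):

int cur = 0;                                 // index of the "current" smem buffer
load_global_to_regs(regs, 0);                // first K-tile: global -> registers
store_regs_to_smem(smem[cur], regs);         // registers -> shared
for (int k = 0; k < BLOCK_K - 1; k++) {
    __syncthreads();                         // stores to smem[cur] are now visible
    load_frags_from_smem(frags, smem[cur]);  // shared -> WMMA fragments
    load_global_to_regs(regs, k + 1);        // prefetch the next K-tile from global
    store_regs_to_smem(smem[cur ^ 1], regs); // stage the prefetch in the other buffer
    do_16_wmmas(acc, frags);                 // compute overlaps the stores above
    cur ^= 1;                                // swap buffers
}
__syncthreads();
load_frags_from_smem(frags, smem[cur]);      // last K-tile needs no prefetch
do_16_wmmas(acc, frags);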