corsix / amx

Apple AMX Instruction Set

Using AMX for non-ML optimisation

sfjohnson opened this issue

Hi again,

Thanks for your excellent research. I've been attempting to optimise OpenJPH, a JPEG2000 implementation, with AMX. Starting with one of the wavelet transform functions, the AMX version is coming out significantly slower than the NEON code generated by the compiler.

First I wanted to ask about AMX_SET() / AMX_CLR(). I timed them in a tight loop and the average came to 7.24 nanoseconds per AMX_SET() / AMX_CLR() pair, which sounds reasonable. What I can't explain is that when I actually put them around my test function, its average run time increases by about 2 milliseconds!
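The tight-loop measurement was roughly this shape (a simplified sketch rather than the exact harness I used; AMX_SET / AMX_CLR are the macros from this repo's aarch64.h):

// Simplified sketch of the AMX_SET()/AMX_CLR() timing loop (not the exact harness).
#include <stdio.h>
#include <time.h>
#include "aarch64.h"   // AMX_SET / AMX_CLR macros from this repo

int main(void) {
  const int iters = 1000000;
  struct timespec t0, t1;
  clock_gettime(CLOCK_MONOTONIC_RAW, &t0);
  for (int i = 0; i < iters; i++) {
    AMX_SET();   // enable AMX state for this thread
    AMX_CLR();   // disable it again
  }
  clock_gettime(CLOCK_MONOTONIC_RAW, &t1);
  double ns = (t1.tv_sec - t0.tv_sec) * 1e9 + (double)(t1.tv_nsec - t0.tv_nsec);
  printf("avg ns per AMX_SET/AMX_CLR pair: %.2f\n", ns / iters);
  return 0;
}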

This is a rough sketch of how the test goes:

void compress_an_image() {
  ...
  // Calls this 1000s of times
  gen_irrev_horz_wvlt_fwd_tx(...);
  ...
}

// Call this 30 times and take the average
void trial() {
  AMX_SET();
  compress_an_image();
  AMX_CLR();
}

My times are:

Original code, NEON generated by compiler:
Avg run time (ms): 49.935734
Avg run time (ms): 49.773602
Avg run time (ms): 49.785702
Avg run time (ms): 50.063000

Original code, with AMX_SET() / AMX_CLR() around trial(), AMX doing nothing useful:
Avg run time (ms): 52.630169
Avg run time (ms): 52.561501
Avg run time (ms): 52.559669
Avg run time (ms): 52.419498

Converting one stage of one of the wavelet transforms to AMX:
Avg run time (ms): 60.170498
Avg run time (ms): 60.380032
Avg run time (ms): 60.384933
Avg run time (ms): 60.366333

And it just gets worse the more code I convert to AMX.

Regardless of the actual AMX code I wrote, there's some weirdness around AMX_SET() / AMX_CLR(). If I put them around gen_irrev_horz_wvlt_fwd_tx(), which gets called 1000s of times during image compression, it's much slower, around 70 ms.

I was wondering if you have any insights. I know there's a lot going on that could be slowing it down, like the kernel having to save and restore the AMX state when context switching, caching, or how the CPU cores share the AMX blocks (I'm still trying to get my head around that).

I also wanted to show you some of the wavelet transform code, if you'd like to have a look. There isn't much AMX code out there, so I wrote it in a way that I thought would give a reasonable speed increase, but so far no luck.

The following only shows the first stage of the transform:

Original code: https://github.com/aous72/OpenJPH/blob/master/src/core/transform/ojph_transform.cpp#L357

void gen_irrev_horz_wvlt_fwd_tx(line_buf* line_src, line_buf *line_ldst, line_buf *line_hdst, ui32 width, bool even) {
    float *src = line_src->f32;
    float *ldst = line_ldst->f32, *hdst = line_hdst->f32;

    const ui32 L_width = (width + (even ? 1 : 0)) >> 1;
    const ui32 H_width = (width + (even ? 0 : 1)) >> 1;

    //extension
    src[-1] = src[1];
    src[width] = src[width-2];
    // predict
    float factor = LIFTING_FACTORS::steps[0];
    const float* sp = src + (even ? 1 : 0);
    float *dph = hdst;
    for (ui32 i = H_width; i > 0; --i, sp+=2)
      *dph++ = sp[0] + factor * (sp[-1] + sp[1]);
}

My AMX (de)optimisation:

#define _amx_ldx(srcBuf, destIndex, flags) AMX_LDX(((uint64_t)&*(srcBuf)) | ((uint64_t)(destIndex)<<56) | (flags))
#define _amx_ldy(srcBuf, destIndex, flags) AMX_LDY(((uint64_t)&*(srcBuf)) | ((uint64_t)(destIndex)<<56) | (flags))
#define _amx_stz(destBuf, srcIndex, flags) AMX_STZ(((uint64_t)&*(destBuf)) | ((uint64_t)(srcIndex)<<56) | (flags))

#define VECFP_MULTIPLE_2 0
#define VECFP_MULTIPLE_4 (1ull << 25)

// *M2 only* Multiple mode (bit 31=1), regular load (bit 53=0)
#define _amx_vecfp_multiple(xOffset, yOffset, zRow, xShuffle, yShuffle, bMode, laneWidthMode, aluMode, flags) \
  AMX_VECFP( \
    ((uint64_t)(xOffset) << 10) | \
    ((uint64_t)(yOffset)) | \
    ((uint64_t)(zRow) << 20) | \
    ((uint64_t)(xShuffle) << 29) | \
    ((uint64_t)(yShuffle) << 27) | \
    (1ull << 31) | \
    ((uint64_t)(bMode) << 32) | \
    ((uint64_t)(laneWidthMode) << 42) | \
    ((uint64_t)(aluMode) << 47) | \
    (flags) \
  )

void gen_irrev_horz_wvlt_fwd_tx(line_buf* line_src, line_buf *line_ldst, line_buf *line_hdst, ui32 width, bool even) {
  static float amxScratch[32] __attribute__((aligned(128))) = { 0.0f };

  // src, ldst, and hdst are aligned to 128 bytes
  float *ldst = line_ldst->f32, *hdst = line_hdst->f32, *src = line_src->f32;

  // even is always true
  // 240 < width < 1920

  const ui32 L_width = (width + (even ? 1 : 0)) >> 1;
  const ui32 H_width = (width + (even ? 0 : 1)) >> 1;

  //extension
  src[-1] = src[1];
  src[width] = src[width-2];
  // predict
  const float* sp = src + (even ? 1 : 0);
  float *dph = hdst;

  amxScratch[0] = LIFTING_FACTORS::steps[0];
  _amx_ldy(&amxScratch[0], 0, LDST_MODE_SINGLE);
  // Do ceil(H_width / 32) iterations
  for (ui32 i = 0; i < (H_width + 31) >> 5; i++) {
    // Process 64 floats from sp down to 32 floats to dph
    _amx_ldx(&sp[-1], 0, LDST_MODE_QUAD);
    // Extension
    _amx_ldx(&sp[63], 4, LDST_MODE_SINGLE);

    _amx_vecfp_multiple(0, 0, 0, 3, 0, 7, 4, 10, VECFP_MULTIPLE_4); // Z =  S3(X) * Y
    _amx_vecfp_multiple(8, 0, 0, 3, 0, 7, 4, 0, VECFP_MULTIPLE_4); // Z += (S3(X<<2)) * Y
    _amx_vecfp_multiple(4, 0, 0, 3, 0, 0, 4, 11, VECFP_MULTIPLE_4); // Z += (S3(X<<1))

    _amx_stz(&dph[0], 0, LDST_MODE_SINGLE);
    _amx_stz(&dph[8], 16, LDST_MODE_SINGLE);
    _amx_stz(&dph[16], 32, LDST_MODE_SINGLE);
    _amx_stz(&dph[24], 48, LDST_MODE_SINGLE);

    dph += 32;
    sp += 64;
  }
}

Disassembly:

                             **************************************************************
                             * ojph::local::gen_irrev_horz_wvlt_fwd_tx(ojph::line_buf*... *
                             **************************************************************
                             undefined __cdecl gen_irrev_horz_wvlt_fwd_tx(line_buf * 
             undefined         w0:1           <RETURN>
             line_buf *        x0:8           param_1
             line_buf *        x1:8           param_2
             line_buf *        x2:8           param_3
             uint              w3:4           param_4
             bool              w4:1           param_5
                             __ZN4ojph5local26gen_irrev_horz_wvlt_fwd_txEPN  XREF[2]:     Entry Point(*), 
                             ojph::local::gen_irrev_horz_wvlt_fwd_tx                      init_wavelet_transform_functions
        0001ed6c 0a 08 40 f9     ldr        x10,[param_1, #0x10]
        0001ed70 48 08 40 f9     ldr        x8,[param_3, #0x10]
        0001ed74 89 00 00 52     eor        w9,param_5,#0x1
        0001ed78 29 01 03 0b     add        w9,w9,param_4
        0001ed7c 40 05 40 bd     ldr        s0,[x10, #0x4]
        0001ed80 40 c1 1f bc     stur       s0,[x10, #-0x4]
        0001ed84 6b 08 00 51     sub        w11,param_4,#0x2
        0001ed88 40 59 6b bc     ldr        s0,[x10,w11, uxtw #2]
        0001ed8c 40 59 23 bc     str        s0,[x10,param_4, uxtw #2]
        0001ed90 8b 57 0c 10     adr        x11,0x37880
        0001ed94 1f 20 03 d5     nop
        0001ed98 6c ce 80 52     mov        w12,#0x673
        0001ed9c 6c f9 b7 72     movk       w12,#0xbfcb, LSL #16
        0001eda0 6c 01 00 b9     str        w12,[x11]=>ojph::local::gen_irrev_horz_wvlt_fw
        0001eda4 2b 10 20 00     __amx_ldy  x11
        0001eda8 3f 09 00 71     cmp        w9,#0x2
        0001edac c3 04 00 54     b.cc       LAB_0001ee44
        0001edb0 29 7d 01 53     lsr        w9,w9,#0x1
        0001edb4 29 7d 00 11     add        w9,w9,#0x1f
        0001edb8 e9 17 49 4b     neg        w9,w9, LSR #0x5
        0001edbc 4a 49 24 8b     add        x10,x10,param_5, UXTW  #0x2
        0001edc0 4a 11 00 d1     sub        x10,x10,#0x4
        0001edc4 0b 00 ea d2     mov        x11,#0x5000000000000000
        0001edc8 0c 40 bc d2     mov        x12,#0xe2000000
        0001edcc ec 00 c2 f2     movk       x12,#0x1007, LSL #32
        0001edd0 ac 00 e0 f2     movk       x12,#0x5, LSL #48
        0001edd4 0d 00 84 d2     mov        x13,#0x2000
        0001edd8 0d 40 bc f2     movk       x13,#0xe200, LSL #16
        0001eddc ed 00 c2 f2     movk       x13,#0x1007, LSL #32
        0001ede0 0e 00 82 d2     mov        x14,#0x1000
        0001ede4 0e 40 bc f2     movk       x14,#0xe200, LSL #16
        0001ede8 0e 00 d2 f2     movk       x14,#0x9000, LSL #32
        0001edec ae 00 e0 f2     movk       x14,#0x5, LSL #48
                             LAB_0001edf0                                    XREF[1]:     0001ee40(j)  
        0001edf0 4f 01 0b aa     orr        x15,x10,x11
        0001edf4 0f 10 20 00     __amx_ldx  x15
        0001edf8 4a 01 04 91     add        x10,x10,#0x100
        0001edfc 4f 01 46 b2     orr        x15,x10,#0x400000000000000
        0001ee00 0f 10 20 00     __amx_ldx  x15
        0001ee04 6c 12 20 00     __amx_ve   x12
        0001ee08 6d 12 20 00     __amx_ve   x13
        0001ee0c 6e 12 20 00     __amx_ve   x14
        0001ee10 a8 10 20 00     __amx_stz  x8
        0001ee14 0f 81 00 91     add        x15,x8,#0x20
        0001ee18 ef 01 44 b2     orr        x15,x15,#0x1000000000000000
        0001ee1c af 10 20 00     __amx_stz  x15
        0001ee20 0f 01 01 91     add        x15,x8,#0x40
        0001ee24 ef 01 43 b2     orr        x15,x15,#0x2000000000000000
        0001ee28 af 10 20 00     __amx_stz  x15
        0001ee2c 0f 81 01 91     add        x15,x8,#0x60
        0001ee30 ef 05 44 b2     orr        x15,x15,#0x3000000000000000
        0001ee34 af 10 20 00     __amx_stz  x15
        0001ee38 08 01 02 91     add        x8,x8,#0x80
        0001ee3c 29 05 00 31     adds       w9,w9,#0x1
        0001ee40 83 fd ff 54     b.cc       LAB_0001edf0
                             LAB_0001ee44                                    XREF[1]:     0001edac(j)  
        0001ee44 c0 03 5f d6     ret

I'm not too good at reading assembly, but it looks like the compiler did an OK job of hoisting all those mov/movk instructions that set up the operands for the AMX instructions out of the loop. NEON should be outputting 4 floats per loop iteration, while my AMX code stores 32 floats per iteration (well, technically 40, but the stores overlap), so in theory it should be a lot faster even if my implementation isn't ideal.

If you've got any ideas or info on cycle counts / pipelining etc they would be greatly appreciated!

AMX vector performance on M1/M2 is disappointing, especially in comparison to M1/M2 performance-core NEON. An M1 performance core using NEON has a theoretical maximum f32 throughput of 102.4 GFLOPS (it can dispatch 4 FMA instructions per cycle, each FMA is 4 lanes wide, each FMA lane counts as two ops, at 3.2 GHz). Perfect four-way multithreading would get you up to 409.6 GFLOPS, and perfect eight-way would get you to 819.2 GFLOPS. M2 should be slightly higher.
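(For anyone wanting to sanity-check that NEON figure, a rough probe along these lines should land near the peak on a single performance core. This is just a sketch with plain intrinsics, not the benchmark behind the tables below, and the accumulator count may need tuning to fully cover FMA latency across the 4 pipes.)

// Rough NEON f32 FMA peak probe (sketch only). 16 independent accumulators
// keep the 4 FMA pipes busy despite multi-cycle FMA latency.
#include <arm_neon.h>
#include <stdio.h>
#include <time.h>

int main(void) {
  float32x4_t acc[16];
  for (int j = 0; j < 16; j++) acc[j] = vdupq_n_f32(0.0f);
  float32x4_t x = vdupq_n_f32(1.0f), y = vdupq_n_f32(1e-9f);

  const long iters = 100000000;
  struct timespec t0, t1;
  clock_gettime(CLOCK_MONOTONIC_RAW, &t0);
  for (long i = 0; i < iters; i++)
    for (int j = 0; j < 16; j++)
      acc[j] = vfmaq_f32(acc[j], x, y);   // acc[j] += x * y: 4 lanes, 8 flops
  clock_gettime(CLOCK_MONOTONIC_RAW, &t1);

  double secs = (t1.tv_sec - t0.tv_sec) + (t1.tv_nsec - t0.tv_nsec) * 1e-9;
  double flops = (double)iters * 16 * 4 * 2;

  float sink = 0.0f;   // keep the accumulators live so the loop isn't optimised away
  for (int j = 0; j < 16; j++) sink += vgetq_lane_f32(acc[j], 0);
  printf("%.1f GFLOPS (sink=%g)\n", flops / secs / 1e9, sink);
  return 0;
}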

For comparison, AMX f32 vector FMAs on an M1 Max look like:

| Z accumulators per thread | 1 thread | 2 threads | 3 threads | 4 threads | 5 threads | 6 threads |
| --- | --- | --- | --- | --- | --- | --- |
| 1 (64 bytes) | 23.2 GFLOPS | 46.4 GFLOPS | 53.3 GFLOPS | 81.1 GFLOPS | 89.0 GFLOPS | 104.1 GFLOPS |
| 2 (128 bytes) | 46.4 GFLOPS | 92.7 GFLOPS | 106.5 GFLOPS | 141.3 GFLOPS | 176.8 GFLOPS | 206.5 GFLOPS |
| 3 (192 bytes) | 69.6 GFLOPS | 139.1 GFLOPS | 160.1 GFLOPS | 213.3 GFLOPS | 250.6 GFLOPS | 244.9 GFLOPS |
| 4 (256 bytes) | 92.7 GFLOPS | 185.4 GFLOPS | 214.0 GFLOPS | 277.6 GFLOPS | 325.5 GFLOPS | 298.0 GFLOPS |
| 5 (320 bytes) | 115.8 GFLOPS | 231.7 GFLOPS | 241.0 GFLOPS | 321.3 GFLOPS | 355.1 GFLOPS | 347.7 GFLOPS |
| 6 (384 bytes) | 139.0 GFLOPS | 277.7 GFLOPS | 271.2 GFLOPS | 361.7 GFLOPS | 387.1 GFLOPS | 386.2 GFLOPS |
| 7 (448 bytes) | 162.2 GFLOPS | 324.2 GFLOPS | 299.9 GFLOPS | 383.4 GFLOPS | 394.0 GFLOPS | 400.9 GFLOPS |
| 8 (512 bytes) | 185.5 GFLOPS | 369.9 GFLOPS | 335.8 GFLOPS | 392.9 GFLOPS | 405.8 GFLOPS | 416.0 GFLOPS |
| 9 (576 bytes) | 178.0 GFLOPS | 353.4 GFLOPS | 325.5 GFLOPS | 396.9 GFLOPS | 398.0 GFLOPS | 409.2 GFLOPS |
| 10 (640 bytes) | 183.1 GFLOPS | 360.6 GFLOPS | 335.3 GFLOPS | 402.4 GFLOPS | 401.2 GFLOPS | 417.2 GFLOPS |
| 11 (704 bytes) | 183.1 GFLOPS | 363.0 GFLOPS | 334.2 GFLOPS | 403.2 GFLOPS | 400.6 GFLOPS | 415.8 GFLOPS |
| 12 (768 bytes) | 185.2 GFLOPS | 370.6 GFLOPS | 335.5 GFLOPS | 378.5 GFLOPS | 397.7 GFLOPS | 419.0 GFLOPS |
| 13 (832 bytes) | 185.2 GFLOPS | 369.4 GFLOPS | 336.0 GFLOPS | 404.2 GFLOPS | 400.9 GFLOPS | 414.1 GFLOPS |
| 14 (896 bytes) | 185.5 GFLOPS | 370.5 GFLOPS | 336.4 GFLOPS | 406.0 GFLOPS | 402.9 GFLOPS | 416.4 GFLOPS |
| 15 (960 bytes) | 185.5 GFLOPS | 370.0 GFLOPS | 336.8 GFLOPS | 405.7 GFLOPS | 402.6 GFLOPS | 409.6 GFLOPS |
| 16 (1024 bytes) | 185.4 GFLOPS | 370.4 GFLOPS | 336.3 GFLOPS | 406.0 GFLOPS | 399.7 GFLOPS | 405.3 GFLOPS |

A single thread can exceed 102.4 GFLOPS, potentially hitting 185.5 GFLOPS, but you need to be using 512 bytes of Z registers to get there. Four threads only get to 406.0 GFLOPS, which is less than the theoretical 409.6 achievable with NEON. More than four threads don't help AMX, but will help NEON.

The same thing on M2 looks like:

| Z accumulators per thread | 1 thread | 2 threads | 3 threads | 4 threads | 5 threads | 6 threads |
| --- | --- | --- | --- | --- | --- | --- |
| 1 (64 bytes) | 25.6 GFLOPS | 41.2 GFLOPS | 61.7 GFLOPS | 78.7 GFLOPS | 98.4 GFLOPS | 117.7 GFLOPS |
| 2 (128 bytes) | 51.2 GFLOPS | 82.3 GFLOPS | 123.5 GFLOPS | 157.7 GFLOPS | 174.1 GFLOPS | 170.4 GFLOPS |
| 3 (192 bytes) | 76.7 GFLOPS | 123.4 GFLOPS | 179.5 GFLOPS | 191.0 GFLOPS | 216.9 GFLOPS | 215.1 GFLOPS |
| 4 (256 bytes) | 102.2 GFLOPS | 164.6 GFLOPS | 237.1 GFLOPS | 231.8 GFLOPS | 258.3 GFLOPS | 263.3 GFLOPS |
| 5 (320 bytes) | 127.8 GFLOPS | 205.7 GFLOPS | 279.1 GFLOPS | 264.8 GFLOPS | 285.7 GFLOPS | 289.5 GFLOPS |
| 6 (384 bytes) | 153.5 GFLOPS | 226.0 GFLOPS | 299.5 GFLOPS | 286.6 GFLOPS | 300.5 GFLOPS | 308.3 GFLOPS |
| 7 (448 bytes) | 179.0 GFLOPS | 246.6 GFLOPS | 300.6 GFLOPS | 291.4 GFLOPS | 302.4 GFLOPS | 306.2 GFLOPS |
| 8 (512 bytes) | 204.4 GFLOPS | 269.7 GFLOPS | 301.6 GFLOPS | 299.4 GFLOPS | 309.2 GFLOPS | 310.4 GFLOPS |
| 9 (576 bytes) | 204.6 GFLOPS | 270.5 GFLOPS | 302.9 GFLOPS | 297.9 GFLOPS | 304.7 GFLOPS | 307.3 GFLOPS |
| 10 (640 bytes) | 204.7 GFLOPS | 270.3 GFLOPS | 303.0 GFLOPS | 300.2 GFLOPS | 306.9 GFLOPS | 308.9 GFLOPS |
| 11 (704 bytes) | 204.6 GFLOPS | 276.5 GFLOPS | 308.4 GFLOPS | 302.1 GFLOPS | 305.8 GFLOPS | 307.5 GFLOPS |
| 12 (768 bytes) | 204.5 GFLOPS | 270.5 GFLOPS | 302.9 GFLOPS | 299.9 GFLOPS | 304.2 GFLOPS | 307.5 GFLOPS |
| 13 (832 bytes) | 204.6 GFLOPS | 275.3 GFLOPS | 307.9 GFLOPS | 299.8 GFLOPS | 306.4 GFLOPS | 307.4 GFLOPS |
| 14 (896 bytes) | 204.2 GFLOPS | 270.5 GFLOPS | 302.9 GFLOPS | 299.6 GFLOPS | 306.9 GFLOPS | 310.6 GFLOPS |
| 15 (960 bytes) | 204.5 GFLOPS | 275.7 GFLOPS | 308.5 GFLOPS | 299.5 GFLOPS | 305.5 GFLOPS | 307.4 GFLOPS |
| 16 (1024 bytes) | 204.6 GFLOPS | 270.5 GFLOPS | 302.8 GFLOPS | 299.8 GFLOPS | 306.9 GFLOPS | 307.4 GFLOPS |

The single-threaded numbers are about 10% higher than M1 Max, which is consistent with the clock speed being about 10% higher. M1 Max gets approximately double the performance from 2 threads, and then only marginal improvement from subsequent threads. M2 only gets small improvements from additional threads, which is consistent with it having only one P-core cluster (versus two on M1 Max).

The four-at-a-time mode on M2 doesn't look much better:

| Z accumulators per thread | 1 thread | 2 threads | 3 threads | 4 threads | 5 threads | 6 threads |
| --- | --- | --- | --- | --- | --- | --- |
| 1 (256 bytes) | 102.3 GFLOPS | 164.9 GFLOPS | 247.0 GFLOPS | 187.5 GFLOPS | 250.9 GFLOPS | 305.7 GFLOPS |
| 2 (512 bytes) | 204.5 GFLOPS | 326.9 GFLOPS | 351.4 GFLOPS | 208.2 GFLOPS | 326.6 GFLOPS | 323.5 GFLOPS |
| 3 (768 bytes) | 204.6 GFLOPS | 326.8 GFLOPS | 351.4 GFLOPS | 211.6 GFLOPS | 320.5 GFLOPS | 324.9 GFLOPS |
| 4 (1024 bytes) | 204.6 GFLOPS | 329.3 GFLOPS | 351.3 GFLOPS | 205.7 GFLOPS | 326.6 GFLOPS | 325.6 GFLOPS |
| 5 (1280 bytes) | 204.6 GFLOPS | 329.1 GFLOPS | 351.2 GFLOPS | 205.4 GFLOPS | 326.5 GFLOPS | 322.4 GFLOPS |
| 6 (1536 bytes) | 204.6 GFLOPS | 328.9 GFLOPS | 351.4 GFLOPS | 208.7 GFLOPS | 318.2 GFLOPS | 322.7 GFLOPS |
| 7 (1792 bytes) | 204.6 GFLOPS | 329.2 GFLOPS | 351.4 GFLOPS | 205.9 GFLOPS | 326.4 GFLOPS | 324.0 GFLOPS |
| 8 (2048 bytes) | 204.5 GFLOPS | 329.1 GFLOPS | 351.4 GFLOPS | 208.1 GFLOPS | 326.5 GFLOPS | 321.3 GFLOPS |
| 9 (2304 bytes) | 204.6 GFLOPS | 329.2 GFLOPS | 351.4 GFLOPS | 207.3 GFLOPS | 323.6 GFLOPS | 326.9 GFLOPS |
| 10 (2560 bytes) | 204.6 GFLOPS | 329.2 GFLOPS | 351.4 GFLOPS | 206.2 GFLOPS | 320.8 GFLOPS | 326.7 GFLOPS |
| 11 (2816 bytes) | 204.5 GFLOPS | 326.9 GFLOPS | 346.5 GFLOPS | 208.2 GFLOPS | 326.3 GFLOPS | 321.5 GFLOPS |
| 12 (3072 bytes) | 204.6 GFLOPS | 329.1 GFLOPS | 351.1 GFLOPS | 205.9 GFLOPS | 326.5 GFLOPS | 326.4 GFLOPS |
| 13 (3328 bytes) | 204.5 GFLOPS | 329.3 GFLOPS | 351.4 GFLOPS | 206.6 GFLOPS | 323.5 GFLOPS | 323.4 GFLOPS |
| 14 (3584 bytes) | 204.6 GFLOPS | 329.2 GFLOPS | 351.4 GFLOPS | 205.5 GFLOPS | 326.4 GFLOPS | 323.2 GFLOPS |
| 15 (3840 bytes) | 204.6 GFLOPS | 329.2 GFLOPS | 351.0 GFLOPS | 205.8 GFLOPS | 326.5 GFLOPS | 322.5 GFLOPS |
| 16 (4096 bytes) | 204.6 GFLOPS | 327.1 GFLOPS | 351.2 GFLOPS | 208.1 GFLOPS | 324.0 GFLOPS | 321.7 GFLOPS |

Furthermore, in order to hit these numbers, two things need to be noted:

  1. For vector operations, the Z accumulators you choose want to be maximally spread out over the permissible range, rather than contiguous (which is why the two-at-a-time / four-at-a-time on M2 use a Z step of 32 / 16).
  2. There's limited-to-no register renaming going on.

Because of point (2), you're in the "1 (256 bytes)" row of the table above. Changing from ceil(H_width / 32) iterations to ceil(H_width / 64) iterations would help here, as each iteration could then address two different Z ranges, getting you to "2 (512 bytes)".
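Concretely, that restructuring could look something like the sketch below. It is untested: it reuses the _amx_ldx / _amx_vecfp_multiple / _amx_stz macros and operand values from the snippet above unchanged, has the second block accumulate into Z rows 8/24/40/56 so the two blocks' accumulators are maximally spread from the first block's 0/16/32/48, and assumes the buffers have the same kind of slack past H_width that the original overlapping stores already rely on.

  // Sketch: two 64-float blocks per iteration, each accumulating into its own
  // set of Z rows (bases 0 and 8), so back-to-back FMA groups don't stall on
  // the same accumulators. VECFP parameters are unchanged from the original.
  for (ui32 i = 0; i < (H_width + 63) >> 6; i++) {
    // First block: Z rows 0, 16, 32, 48
    _amx_ldx(&sp[-1], 0, LDST_MODE_QUAD);
    _amx_ldx(&sp[63], 4, LDST_MODE_SINGLE);
    _amx_vecfp_multiple(0, 0, 0, 3, 0, 7, 4, 10, VECFP_MULTIPLE_4);
    _amx_vecfp_multiple(8, 0, 0, 3, 0, 7, 4, 0, VECFP_MULTIPLE_4);
    _amx_vecfp_multiple(4, 0, 0, 3, 0, 0, 4, 11, VECFP_MULTIPLE_4);

    // Second block: same computation on the next 64 floats, Z rows 8, 24, 40, 56
    _amx_ldx(&sp[63], 0, LDST_MODE_QUAD);
    _amx_ldx(&sp[127], 4, LDST_MODE_SINGLE);
    _amx_vecfp_multiple(0, 0, 8, 3, 0, 7, 4, 10, VECFP_MULTIPLE_4);
    _amx_vecfp_multiple(8, 0, 8, 3, 0, 7, 4, 0, VECFP_MULTIPLE_4);
    _amx_vecfp_multiple(4, 0, 8, 3, 0, 0, 4, 11, VECFP_MULTIPLE_4);

    // Stores, same overlapping pattern as before, for both blocks
    _amx_stz(&dph[0], 0, LDST_MODE_SINGLE);
    _amx_stz(&dph[8], 16, LDST_MODE_SINGLE);
    _amx_stz(&dph[16], 32, LDST_MODE_SINGLE);
    _amx_stz(&dph[24], 48, LDST_MODE_SINGLE);
    _amx_stz(&dph[32], 8, LDST_MODE_SINGLE);
    _amx_stz(&dph[40], 24, LDST_MODE_SINGLE);
    _amx_stz(&dph[48], 40, LDST_MODE_SINGLE);
    _amx_stz(&dph[56], 56, LDST_MODE_SINGLE);

    dph += 64;
    sp += 128;
  }

Note the second block's loads still reuse X rows 0-4, so they wait on the first block's VECFP reads of X; whether it's also worth double-buffering across the X rows is something I'd measure rather than guess.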

Ok interesting, thanks for the data! I think I've got a better understanding of what's a good fit for AMX acceleration and how to get the utilisation as high as possible.