Using AMX for non-ML optimisation
sfjohnson opened this issue
Hi again,
Thanks for your excellent research. I've been attempting to optimise OpenJPH, a JPEG 2000 implementation, with AMX. Starting off with one of the wavelet transform functions, AMX is coming out significantly slower than the NEON code generated by the compiler.
First I wanted to ask about AMX_SET() / AMX_CLR(). I timed them in a tight loop and the average came to 7.24 nanoseconds per AMX_SET() / AMX_CLR() pair, which sounds reasonable. I'm not sure what's happening, though, because when I actually put them around my test function its average time increases by about 2 milliseconds!
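The tight loop was essentially this (a sketch, not my exact harness; the clock source and iteration count are illustrative):

#include <time.h>
#include <stdio.h>

// Time AMX_SET()/AMX_CLR() pairs in a tight loop (sketch only)
void time_set_clr(void) {
  const int N = 1000000;
  struct timespec t0, t1;
  clock_gettime(CLOCK_MONOTONIC_RAW, &t0);
  for (int i = 0; i < N; i++) {
    AMX_SET();
    AMX_CLR();
  }
  clock_gettime(CLOCK_MONOTONIC_RAW, &t1);
  double ns = (t1.tv_sec - t0.tv_sec) * 1e9 + (double)(t1.tv_nsec - t0.tv_nsec);
  printf("%.2f ns per AMX_SET()/AMX_CLR() pair\n", ns / N);
}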
This is a rough sketch of how the test goes:
void compress_an_image() {
  ...
  // Calls this 1000s of times
  gen_irrev_horz_wvlt_fwd_tx(...);
  ...
}

// Call this 30 times and take the average
void trial() {
  AMX_SET();
  compress_an_image();
  AMX_CLR();
}
My times are:
Original code, NEON generated by compiler:
Avg run time (ms): 49.935734
Avg run time (ms): 49.773602
Avg run time (ms): 49.785702
Avg run time (ms): 50.063000
Original code, with AMX_SET() / AMX_CLR() around trial(), AMX doing nothing useful:
Avg run time (ms): 52.630169
Avg run time (ms): 52.561501
Avg run time (ms): 52.559669
Avg run time (ms): 52.419498
Converting one stage of one of the wavelet transforms to AMX:
Avg run time (ms): 60.170498
Avg run time (ms): 60.380032
Avg run time (ms): 60.384933
Avg run time (ms): 60.366333
And it just gets worse the more code I convert to AMX.
Regardless of the actual AMX code I wrote, there's some weirdness around AMX_SET() / AMX_CLR(). If I put them around gen_irrev_horz_wvlt_fwd_tx(), which gets called 1000s of times during image compression, it's much slower: around 70 ms.
I was wondering if you have any insights. I know there's a lot going on that could be slowing it down, like the kernel having to save and restore the AMX state when context switching, or caching, or how the CPU cores share the AMX blocks (I'm still trying to get my head around that).
I also wanted to show you some of the wavelet transform code, if you'd like to have a look. There isn't much AMX code out there, so I wrote it in a way that I thought would give a reasonable speed increase, but so far no go.
The following only shows the first stage of the transform:
Original code: https://github.com/aous72/OpenJPH/blob/master/src/core/transform/ojph_transform.cpp#L357
void gen_irrev_horz_wvlt_fwd_tx(line_buf* line_src, line_buf *line_ldst, line_buf *line_hdst, ui32 width, bool even) {
  float *src = line_src->f32;
  float *ldst = line_ldst->f32, *hdst = line_hdst->f32;
  const ui32 L_width = (width + (even ? 1 : 0)) >> 1;
  const ui32 H_width = (width + (even ? 0 : 1)) >> 1;
  // extension
  src[-1] = src[1];
  src[width] = src[width-2];
  // predict
  float factor = LIFTING_FACTORS::steps[0];
  const float* sp = src + (even ? 1 : 0);
  float *dph = hdst;
  for (ui32 i = H_width; i > 0; --i, sp += 2)
    *dph++ = sp[0] + factor * (sp[-1] + sp[1]);
}
My AMX (de)optimisation:
// Operand encoding: pointer in the low bits, register row index in bits 56..58,
// plus load/store mode flags
#define _amx_ldx(srcBuf, destIndex, flags) AMX_LDX(((uint64_t)&*(srcBuf)) | ((uint64_t)(destIndex)<<56) | (flags))
#define _amx_ldy(srcBuf, destIndex, flags) AMX_LDY(((uint64_t)&*(srcBuf)) | ((uint64_t)(destIndex)<<56) | (flags))
#define _amx_stz(destBuf, srcIndex, flags) AMX_STZ(((uint64_t)&*(destBuf)) | ((uint64_t)(srcIndex)<<56) | (flags))
// Multiple-mode width select: two-at-a-time (0) or four-at-a-time (bit 25)
#define VECFP_MULTIPLE_2 0
#define VECFP_MULTIPLE_4 (1ull << 25)
// *M2 only* Multiple mode (bit 31=1), regular load (bit 53=0)
#define _amx_vecfp_multiple(xOffset, yOffset, zRow, xShuffle, yShuffle, bMode, laneWidthMode, aluMode, flags) \
AMX_VECFP( \
((uint64_t)(xOffset) << 10) | \
((uint64_t)(yOffset)) | \
((uint64_t)(zRow) << 20) | \
((uint64_t)(xShuffle) << 29) | \
((uint64_t)(yShuffle) << 27) | \
(1ull << 31) | \
((uint64_t)(bMode) << 32) | \
((uint64_t)(laneWidthMode) << 42) | \
((uint64_t)(aluMode) << 47) | \
(flags) \
)
void gen_irrev_horz_wvlt_fwd_tx(line_buf* line_src, line_buf *line_ldst, line_buf *line_hdst, ui32 width, bool even) {
  static float amxScratch[32] __attribute__((aligned(128))) = { 0.0f };
  // src, ldst, and hdst are aligned to 128 bytes
  float *ldst = line_ldst->f32, *hdst = line_hdst->f32, *src = line_src->f32;
  // even is always true
  // 240 < width < 1920
  const ui32 L_width = (width + (even ? 1 : 0)) >> 1;
  const ui32 H_width = (width + (even ? 0 : 1)) >> 1;
  // extension
  src[-1] = src[1];
  src[width] = src[width-2];
  // predict
  const float* sp = src + (even ? 1 : 0);
  float *dph = hdst;
  amxScratch[0] = LIFTING_FACTORS::steps[0];
  _amx_ldy(&amxScratch[0], 0, LDST_MODE_SINGLE);
  // Do ceil(H_width / 32) iterations
  for (ui32 i = 0; i < (H_width + 31) >> 5; i++) {
    // Process 64 floats from sp down to 32 floats to dph
    _amx_ldx(&sp[-1], 0, LDST_MODE_QUAD);
    // Extension
    _amx_ldx(&sp[63], 4, LDST_MODE_SINGLE);
    _amx_vecfp_multiple(0, 0, 0, 3, 0, 7, 4, 10, VECFP_MULTIPLE_4); // Z = S3(X) * Y
    _amx_vecfp_multiple(8, 0, 0, 3, 0, 7, 4, 0, VECFP_MULTIPLE_4);  // Z += (S3(X<<2)) * Y
    _amx_vecfp_multiple(4, 0, 0, 3, 0, 0, 4, 11, VECFP_MULTIPLE_4); // Z += (S3(X<<1))
    _amx_stz(&dph[0], 0, LDST_MODE_SINGLE);
    _amx_stz(&dph[8], 16, LDST_MODE_SINGLE);
    _amx_stz(&dph[16], 32, LDST_MODE_SINGLE);
    _amx_stz(&dph[24], 48, LDST_MODE_SINGLE);
    dph += 32;
    sp += 64;
  }
}
Disassembly:
**************************************************************
* ojph::local::gen_irrev_horz_wvlt_fwd_tx(ojph::line_buf*... *
**************************************************************
undefined __cdecl gen_irrev_horz_wvlt_fwd_tx(line_buf *
undefined w0:1 <RETURN>
line_buf * x0:8 param_1
line_buf * x1:8 param_2
line_buf * x2:8 param_3
uint w3:4 param_4
bool w4:1 param_5
__ZN4ojph5local26gen_irrev_horz_wvlt_fwd_txEPN XREF[2]: Entry Point(*),
ojph::local::gen_irrev_horz_wvlt_fwd_tx init_wavelet_transform_functions
0001ed6c 0a 08 40 f9 ldr x10,[param_1, #0x10]
0001ed70 48 08 40 f9 ldr x8,[param_3, #0x10]
0001ed74 89 00 00 52 eor w9,param_5,#0x1
0001ed78 29 01 03 0b add w9,w9,param_4
0001ed7c 40 05 40 bd ldr s0,[x10, #0x4]
0001ed80 40 c1 1f bc stur s0,[x10, #-0x4]
0001ed84 6b 08 00 51 sub w11,param_4,#0x2
0001ed88 40 59 6b bc ldr s0,[x10,w11, uxtw #2]
0001ed8c 40 59 23 bc str s0,[x10,param_4, uxtw #2]
0001ed90 8b 57 0c 10 adr x11,0x37880
0001ed94 1f 20 03 d5 nop
0001ed98 6c ce 80 52 mov w12,#0x673
0001ed9c 6c f9 b7 72 movk w12,#0xbfcb, LSL #16
0001eda0 6c 01 00 b9 str w12,[x11]=>ojph::local::gen_irrev_horz_wvlt_fw
0001eda4 2b 10 20 00 __amx_ldy x11
0001eda8 3f 09 00 71 cmp w9,#0x2
0001edac c3 04 00 54 b.cc LAB_0001ee44
0001edb0 29 7d 01 53 lsr w9,w9,#0x1
0001edb4 29 7d 00 11 add w9,w9,#0x1f
0001edb8 e9 17 49 4b neg w9,w9, LSR #0x5
0001edbc 4a 49 24 8b add x10,x10,param_5, UXTW #0x2
0001edc0 4a 11 00 d1 sub x10,x10,#0x4
0001edc4 0b 00 ea d2 mov x11,#0x5000000000000000
0001edc8 0c 40 bc d2 mov x12,#0xe2000000
0001edcc ec 00 c2 f2 movk x12,#0x1007, LSL #32
0001edd0 ac 00 e0 f2 movk x12,#0x5, LSL #48
0001edd4 0d 00 84 d2 mov x13,#0x2000
0001edd8 0d 40 bc f2 movk x13,#0xe200, LSL #16
0001eddc ed 00 c2 f2 movk x13,#0x1007, LSL #32
0001ede0 0e 00 82 d2 mov x14,#0x1000
0001ede4 0e 40 bc f2 movk x14,#0xe200, LSL #16
0001ede8 0e 00 d2 f2 movk x14,#0x9000, LSL #32
0001edec ae 00 e0 f2 movk x14,#0x5, LSL #48
LAB_0001edf0 XREF[1]: 0001ee40(j)
0001edf0 4f 01 0b aa orr x15,x10,x11
0001edf4 0f 10 20 00 __amx_ldx x15
0001edf8 4a 01 04 91 add x10,x10,#0x100
0001edfc 4f 01 46 b2 orr x15,x10,#0x400000000000000
0001ee00 0f 10 20 00 __amx_ldx x15
0001ee04 6c 12 20 00 __amx_ve x12
0001ee08 6d 12 20 00 __amx_ve x13
0001ee0c 6e 12 20 00 __amx_ve x14
0001ee10 a8 10 20 00 __amx_stz x8
0001ee14 0f 81 00 91 add x15,x8,#0x20
0001ee18 ef 01 44 b2 orr x15,x15,#0x1000000000000000
0001ee1c af 10 20 00 __amx_stz x15
0001ee20 0f 01 01 91 add x15,x8,#0x40
0001ee24 ef 01 43 b2 orr x15,x15,#0x2000000000000000
0001ee28 af 10 20 00 __amx_stz x15
0001ee2c 0f 81 01 91 add x15,x8,#0x60
0001ee30 ef 05 44 b2 orr x15,x15,#0x3000000000000000
0001ee34 af 10 20 00 __amx_stz x15
0001ee38 08 01 02 91 add x8,x8,#0x80
0001ee3c 29 05 00 31 adds w9,w9,#0x1
0001ee40 83 fd ff 54 b.cc LAB_0001edf0
LAB_0001ee44 XREF[1]: 0001edac(j)
0001ee44 c0 03 5f d6 ret
I'm not too good at reading assembly, but it looks like the compiler did an OK job, hoisting all those mov / movk instructions that set up the operands for the AMX instructions out of the loop to speed things up. The NEON version should output 4 floats per loop iteration while my AMX code stores 32 floats per iteration (well, technically 40, but it overlaps), so in theory it should be a lot faster even if my implementation isn't ideal.
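To make the comparison concrete, I'd guess the compiler's NEON loop is roughly equivalent to this intrinsics sketch (my reconstruction, not the actual compiler output):

#include <arm_neon.h>

// dph[j] = sp[2j] + factor * (sp[2j-1] + sp[2j+1]), 4 outputs per iteration
float32x4_t f = vdupq_n_f32(factor);
for (ui32 i = 0; i + 4 <= H_width; i += 4, sp += 8, dph += 4) {
  // vld2q_f32 deinterleaves: a.val[0] = sp[-1],sp[1],sp[3],sp[5] (odd),
  //                          a.val[1] = sp[0],sp[2],sp[4],sp[6]  (even)
  float32x4x2_t a = vld2q_f32(sp - 1);
  // b.val[0] = sp[1],sp[3],sp[5],sp[7]
  float32x4x2_t b = vld2q_f32(sp + 1);
  // even + factor * (left odd + right odd)
  vst1q_f32(dph, vfmaq_f32(a.val[1], f, vaddq_f32(a.val[0], b.val[0])));
}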
If you've got any ideas or info on cycle counts / pipelining etc they would be greatly appreciated!
AMX vector performance on M1/M2 is disappointing, especially in comparison to M1/M2 performance-core NEON. An M1 performance core using NEON has a theoretical maximum f32 throughput of 102.4 GFLOPS (dispatch 4 FMA instructions per cycle, each FMA is 4 lanes wide, each lane counts as two ops, 3.2 GHz). Perfect four-way multithreading would get you up to 409.6 GFLOPS, and perfect eight-way would get you to 819.2 GFLOPS. M2 should be slightly higher.
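Spelled out (back-of-envelope only, using the figures above):

#include <stdio.h>

int main(void) {
  // 3.2 GHz P-core, 4 FMA dispatches/cycle, 4 f32 lanes per FMA,
  // 2 flops (multiply + add) per lane
  const double ghz = 3.2, pipes = 4, lanes = 4, flops_per_lane = 2;
  double one_core = ghz * pipes * lanes * flops_per_lane; // 102.4 GFLOPS
  printf("1 core: %.1f GFLOPS, 4 cores: %.1f, 8 cores: %.1f\n",
         one_core, 4 * one_core, 8 * one_core);
  return 0;
}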
For comparison, AMX f32 vector FMAs on an M1 Max look like:
Z Accumulators | 1 Thread | 2 Threads | 3 Threads | 4 Threads | 5 Threads | 6 Threads |
---|---|---|---|---|---|---|
1 (64 bytes) per thread | 23.2 GFLOPS | 46.4 GFLOPS | 53.3 GFLOPS | 81.1 GFLOPS | 89.0 GFLOPS | 104.1 GFLOPS |
2 (128 bytes) per thread | 46.4 GFLOPS | 92.7 GFLOPS | 106.5 GFLOPS | 141.3 GFLOPS | 176.8 GFLOPS | 206.5 GFLOPS |
3 (192 bytes) per thread | 69.6 GFLOPS | 139.1 GFLOPS | 160.1 GFLOPS | 213.3 GFLOPS | 250.6 GFLOPS | 244.9 GFLOPS |
4 (256 bytes) per thread | 92.7 GFLOPS | 185.4 GFLOPS | 214.0 GFLOPS | 277.6 GFLOPS | 325.5 GFLOPS | 298.0 GFLOPS |
5 (320 bytes) per thread | 115.8 GFLOPS | 231.7 GFLOPS | 241.0 GFLOPS | 321.3 GFLOPS | 355.1 GFLOPS | 347.7 GFLOPS |
6 (384 bytes) per thread | 139.0 GFLOPS | 277.7 GFLOPS | 271.2 GFLOPS | 361.7 GFLOPS | 387.1 GFLOPS | 386.2 GFLOPS |
7 (448 bytes) per thread | 162.2 GFLOPS | 324.2 GFLOPS | 299.9 GFLOPS | 383.4 GFLOPS | 394.0 GFLOPS | 400.9 GFLOPS |
8 (512 bytes) per thread | 185.5 GFLOPS | 369.9 GFLOPS | 335.8 GFLOPS | 392.9 GFLOPS | 405.8 GFLOPS | 416.0 GFLOPS |
9 (576 bytes) per thread | 178.0 GFLOPS | 353.4 GFLOPS | 325.5 GFLOPS | 396.9 GFLOPS | 398.0 GFLOPS | 409.2 GFLOPS |
10 (640 bytes) per thread | 183.1 GFLOPS | 360.6 GFLOPS | 335.3 GFLOPS | 402.4 GFLOPS | 401.2 GFLOPS | 417.2 GFLOPS |
11 (704 bytes) per thread | 183.1 GFLOPS | 363.0 GFLOPS | 334.2 GFLOPS | 403.2 GFLOPS | 400.6 GFLOPS | 415.8 GFLOPS |
12 (768 bytes) per thread | 185.2 GFLOPS | 370.6 GFLOPS | 335.5 GFLOPS | 378.5 GFLOPS | 397.7 GFLOPS | 419.0 GFLOPS |
13 (832 bytes) per thread | 185.2 GFLOPS | 369.4 GFLOPS | 336.0 GFLOPS | 404.2 GFLOPS | 400.9 GFLOPS | 414.1 GFLOPS |
14 (896 bytes) per thread | 185.5 GFLOPS | 370.5 GFLOPS | 336.4 GFLOPS | 406.0 GFLOPS | 402.9 GFLOPS | 416.4 GFLOPS |
15 (960 bytes) per thread | 185.5 GFLOPS | 370.0 GFLOPS | 336.8 GFLOPS | 405.7 GFLOPS | 402.6 GFLOPS | 409.6 GFLOPS |
16 (1024 bytes) per thread | 185.4 GFLOPS | 370.4 GFLOPS | 336.3 GFLOPS | 406.0 GFLOPS | 399.7 GFLOPS | 405.3 GFLOPS |
A single thread can exceed 102.4 GFLOPS, potentially hitting 185.5 GFLOPS, but you need to be using 512 bytes of Z registers to get there. Four threads only get to 406.0 GFLOPS, which is less than the theoretical 409.6 achievable with NEON. More than four threads don't help AMX, but will help NEON.
The same thing on M2 looks like:
Z Accumulators | 1 Thread | 2 Threads | 3 Threads | 4 Threads | 5 Threads | 6 Threads |
---|---|---|---|---|---|---|
1 (64 bytes) per thread | 25.6 GFLOPS | 41.2 GFLOPS | 61.7 GFLOPS | 78.7 GFLOPS | 98.4 GFLOPS | 117.7 GFLOPS |
2 (128 bytes) per thread | 51.2 GFLOPS | 82.3 GFLOPS | 123.5 GFLOPS | 157.7 GFLOPS | 174.1 GFLOPS | 170.4 GFLOPS |
3 (192 bytes) per thread | 76.7 GFLOPS | 123.4 GFLOPS | 179.5 GFLOPS | 191.0 GFLOPS | 216.9 GFLOPS | 215.1 GFLOPS |
4 (256 bytes) per thread | 102.2 GFLOPS | 164.6 GFLOPS | 237.1 GFLOPS | 231.8 GFLOPS | 258.3 GFLOPS | 263.3 GFLOPS |
5 (320 bytes) per thread | 127.8 GFLOPS | 205.7 GFLOPS | 279.1 GFLOPS | 264.8 GFLOPS | 285.7 GFLOPS | 289.5 GFLOPS |
6 (384 bytes) per thread | 153.5 GFLOPS | 226.0 GFLOPS | 299.5 GFLOPS | 286.6 GFLOPS | 300.5 GFLOPS | 308.3 GFLOPS |
7 (448 bytes) per thread | 179.0 GFLOPS | 246.6 GFLOPS | 300.6 GFLOPS | 291.4 GFLOPS | 302.4 GFLOPS | 306.2 GFLOPS |
8 (512 bytes) per thread | 204.4 GFLOPS | 269.7 GFLOPS | 301.6 GFLOPS | 299.4 GFLOPS | 309.2 GFLOPS | 310.4 GFLOPS |
9 (576 bytes) per thread | 204.6 GFLOPS | 270.5 GFLOPS | 302.9 GFLOPS | 297.9 GFLOPS | 304.7 GFLOPS | 307.3 GFLOPS |
10 (640 bytes) per thread | 204.7 GFLOPS | 270.3 GFLOPS | 303.0 GFLOPS | 300.2 GFLOPS | 306.9 GFLOPS | 308.9 GFLOPS |
11 (704 bytes) per thread | 204.6 GFLOPS | 276.5 GFLOPS | 308.4 GFLOPS | 302.1 GFLOPS | 305.8 GFLOPS | 307.5 GFLOPS |
12 (768 bytes) per thread | 204.5 GFLOPS | 270.5 GFLOPS | 302.9 GFLOPS | 299.9 GFLOPS | 304.2 GFLOPS | 307.5 GFLOPS |
13 (832 bytes) per thread | 204.6 GFLOPS | 275.3 GFLOPS | 307.9 GFLOPS | 299.8 GFLOPS | 306.4 GFLOPS | 307.4 GFLOPS |
14 (896 bytes) per thread | 204.2 GFLOPS | 270.5 GFLOPS | 302.9 GFLOPS | 299.6 GFLOPS | 306.9 GFLOPS | 310.6 GFLOPS |
15 (960 bytes) per thread | 204.5 GFLOPS | 275.7 GFLOPS | 308.5 GFLOPS | 299.5 GFLOPS | 305.5 GFLOPS | 307.4 GFLOPS |
16 (1024 bytes) per thread | 204.6 GFLOPS | 270.5 GFLOPS | 302.8 GFLOPS | 299.8 GFLOPS | 306.9 GFLOPS | 307.4 GFLOPS |
The single-threaded numbers are about 10% higher than M1 Max, which is consistent with the clock speed being about 10% higher. M1 Max gets approximately double the performance from 2 threads, and then only marginal improvement from subsequent threads. M2 gets only small performance improvements from additional threads, which is consistent with it having only one P-core cluster (versus two on M1 Max).
The four-at-a-time mode on M2 doesn't look much better:
Z Accumulators | 1 Thread | 2 Threads | 3 Threads | 4 Threads | 5 Threads | 6 Threads |
---|---|---|---|---|---|---|
1 (256 bytes) per thread | 102.3 GFLOPS | 164.9 GFLOPS | 247.0 GFLOPS | 187.5 GFLOPS | 250.9 GFLOPS | 305.7 GFLOPS |
2 (512 bytes) per thread | 204.5 GFLOPS | 326.9 GFLOPS | 351.4 GFLOPS | 208.2 GFLOPS | 326.6 GFLOPS | 323.5 GFLOPS |
3 (768 bytes) per thread | 204.6 GFLOPS | 326.8 GFLOPS | 351.4 GFLOPS | 211.6 GFLOPS | 320.5 GFLOPS | 324.9 GFLOPS |
4 (1024 bytes) per thread | 204.6 GFLOPS | 329.3 GFLOPS | 351.3 GFLOPS | 205.7 GFLOPS | 326.6 GFLOPS | 325.6 GFLOPS |
5 (1280 bytes) per thread | 204.6 GFLOPS | 329.1 GFLOPS | 351.2 GFLOPS | 205.4 GFLOPS | 326.5 GFLOPS | 322.4 GFLOPS |
6 (1536 bytes) per thread | 204.6 GFLOPS | 328.9 GFLOPS | 351.4 GFLOPS | 208.7 GFLOPS | 318.2 GFLOPS | 322.7 GFLOPS |
7 (1792 bytes) per thread | 204.6 GFLOPS | 329.2 GFLOPS | 351.4 GFLOPS | 205.9 GFLOPS | 326.4 GFLOPS | 324.0 GFLOPS |
8 (2048 bytes) per thread | 204.5 GFLOPS | 329.1 GFLOPS | 351.4 GFLOPS | 208.1 GFLOPS | 326.5 GFLOPS | 321.3 GFLOPS |
9 (2304 bytes) per thread | 204.6 GFLOPS | 329.2 GFLOPS | 351.4 GFLOPS | 207.3 GFLOPS | 323.6 GFLOPS | 326.9 GFLOPS |
10 (2560 bytes) per thread | 204.6 GFLOPS | 329.2 GFLOPS | 351.4 GFLOPS | 206.2 GFLOPS | 320.8 GFLOPS | 326.7 GFLOPS |
11 (2816 bytes) per thread | 204.5 GFLOPS | 326.9 GFLOPS | 346.5 GFLOPS | 208.2 GFLOPS | 326.3 GFLOPS | 321.5 GFLOPS |
12 (3072 bytes) per thread | 204.6 GFLOPS | 329.1 GFLOPS | 351.1 GFLOPS | 205.9 GFLOPS | 326.5 GFLOPS | 326.4 GFLOPS |
13 (3328 bytes) per thread | 204.5 GFLOPS | 329.3 GFLOPS | 351.4 GFLOPS | 206.6 GFLOPS | 323.5 GFLOPS | 323.4 GFLOPS |
14 (3584 bytes) per thread | 204.6 GFLOPS | 329.2 GFLOPS | 351.4 GFLOPS | 205.5 GFLOPS | 326.4 GFLOPS | 323.2 GFLOPS |
15 (3840 bytes) per thread | 204.6 GFLOPS | 329.2 GFLOPS | 351.0 GFLOPS | 205.8 GFLOPS | 326.5 GFLOPS | 322.5 GFLOPS |
16 (4096 bytes) per thread | 204.6 GFLOPS | 327.1 GFLOPS | 351.2 GFLOPS | 208.1 GFLOPS | 324.0 GFLOPS | 321.7 GFLOPS |
Furthermore, there are two things to note in order to hit these numbers:
- For vector operations, the Z accumulators you choose want to be maximally spread out over the permissible range, rather than contiguous (which is why the two-at-a-time / four-at-a-time on M2 use a Z step of 32 / 16).
- There's limited-to-no register renaming going on.
Because of point (2), you're in the "1 (256 bytes) per thread" row of the above table. Changing from "ceil(H_width / 32) iterations" to "ceil(H_width / 64) iterations" would help here, as each iteration could then address two different Z ranges, getting you to "2 (512 bytes) per thread".
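Roughly, reusing your macros (an untested sketch; putting the second group at zRow 8 and storing it from Z rows 8/24/40/56 is my assumption by analogy with your zRow-0 group, which stores from rows 0/16/32/48):

// Do ceil(H_width / 64) iterations, two independent Z groups per iteration
for (ui32 i = 0; i < (H_width + 63) >> 6; i++) {
  // First 64 floats -> Z group at row 0 (rows 0/16/32/48), as before
  _amx_ldx(&sp[-1], 0, LDST_MODE_QUAD);
  _amx_ldx(&sp[63], 4, LDST_MODE_SINGLE);
  _amx_vecfp_multiple(0, 0, 0, 3, 0, 7, 4, 10, VECFP_MULTIPLE_4);
  _amx_vecfp_multiple(8, 0, 0, 3, 0, 7, 4, 0, VECFP_MULTIPLE_4);
  _amx_vecfp_multiple(4, 0, 0, 3, 0, 0, 4, 11, VECFP_MULTIPLE_4);
  // Next 64 floats -> independent Z group at row 8 (rows 8/24/40/56).
  // Note: reusing X0-4 for the second chunk is a simplification; with no
  // renaming, this reload serialises against the first group's VECFPs,
  // but the second FMA chain itself has no dependency on the first
  // group's Z accumulation
  _amx_ldx(&sp[63], 0, LDST_MODE_QUAD);
  _amx_ldx(&sp[127], 4, LDST_MODE_SINGLE);
  _amx_vecfp_multiple(0, 0, 8, 3, 0, 7, 4, 10, VECFP_MULTIPLE_4);
  _amx_vecfp_multiple(8, 0, 8, 3, 0, 7, 4, 0, VECFP_MULTIPLE_4);
  _amx_vecfp_multiple(4, 0, 8, 3, 0, 0, 4, 11, VECFP_MULTIPLE_4);
  // Store group 0 (rows 0/16/32/48) then group 1 (rows 8/24/40/56)
  _amx_stz(&dph[0], 0, LDST_MODE_SINGLE);
  _amx_stz(&dph[8], 16, LDST_MODE_SINGLE);
  _amx_stz(&dph[16], 32, LDST_MODE_SINGLE);
  _amx_stz(&dph[24], 48, LDST_MODE_SINGLE);
  _amx_stz(&dph[32], 8, LDST_MODE_SINGLE);
  _amx_stz(&dph[40], 24, LDST_MODE_SINGLE);
  _amx_stz(&dph[48], 40, LDST_MODE_SINGLE);
  _amx_stz(&dph[56], 56, LDST_MODE_SINGLE);
  dph += 64;
  sp += 128;
}

The two three-instruction FMA chains are independent, so the second group's accumulation doesn't have to wait on the first group's Z writes.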
Ok interesting, thanks for the data! I think I've got a better understanding of what's a good fit for AMX acceleration and how to get the utilisation as high as possible.