TMA multicast selection in get_best_configs in gemm.py
joyeamd opened this issue
```python
for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'):
    if m >= 512 and is_multicast_legal[i]:
        best_tma_multicast_config = (2, i == 'A')
        break
```
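To experiment with this logic outside the kernel, the snippet can be wrapped in a standalone function. This is a minimal sketch. It assumes `is_multicast_legal` is a dict mapping `'A'`/`'B'` to booleans and that `(1, True)` is the no-multicast fallback. The function name `select_multicast_current` is hypothetical and does not appear in gemm.py.

```python
def select_multicast_current(m, best_block_m, best_block_n, is_multicast_legal):
    """Current heuristic: use 2-way TMA multicast only when m >= 512.

    Returns (num_tma_multicast, is_multicast_on_a).
    Assumption: (1, True) is the no-multicast fallback.
    """
    best_tma_multicast_config = (1, True)  # assumed default: no multicast
    # The preferred order of operands depends on which block dimension is larger
    for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'):
        if m >= 512 and is_multicast_legal[i]:
            best_tma_multicast_config = (2, i == 'A')
            break
    return best_tma_multicast_config


legal = {'A': True, 'B': True}
print(select_multicast_current(256, 64, 128, legal))   # m < 512, so no multicast
print(select_multicast_current(1024, 64, 128, legal))  # B is tried first
```

Note that the threshold is always checked against `m`, even when the loop tries `B` first.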
When `best_block_m <= best_block_n`, I think `n` should be used instead of `m` to make this choice. I suggest changing it to:
```python
# Apply the threshold to m when best_block_m > best_block_n, otherwise to n
dim_size = m if best_block_m > best_block_n else n
for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'):
    if dim_size >= 512 and is_multicast_legal[i]:
        best_tma_multicast_config = (2, i == 'A')
        break
```
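To see where the proposal changes behavior, both rules can be placed in one function behind a flag. This is a minimal sketch. The function `select_multicast` and its `proposed` flag are hypothetical. The sketch also assumes `(1, True)` is the no-multicast fallback and that `is_multicast_legal` is a dict of booleans.

```python
def select_multicast(m, n, best_block_m, best_block_n, is_multicast_legal, proposed):
    """Return (num_tma_multicast, is_multicast_on_a).

    proposed=False: threshold always checked against m (current code).
    proposed=True:  threshold on m if best_block_m > best_block_n, else on n.
    Assumption: (1, True) is the no-multicast fallback.
    """
    a_first = best_block_m > best_block_n
    dim_size = m if (a_first or not proposed) else n
    for i in ('A', 'B') if a_first else ('B', 'A'):
        if dim_size >= 512 and is_multicast_legal[i]:
            return (2, i == 'A')
    return (1, True)


legal = {'A': True, 'B': True}
# Small m, large n: only the proposal enables multicast
print(select_multicast(256, 4096, 64, 128, legal, proposed=False))
print(select_multicast(256, 4096, 64, 128, legal, proposed=True))
```

When `best_block_m <= best_block_n`, the two rules can disagree in both directions. The proposal enables multicast for small-`m`/large-`n` shapes and disables it for large-`m`/small-`n` shapes, so both cases would need to be benchmarked.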
Sorry for my late reply. `m >= 512` is just a heuristic (tested empirically, not derived theoretically). For now, I am not sure whether your change would bring a speedup. If you test it and verify that it performs better, a PR is welcome!
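Before benchmarking, it helps to know which shapes the change actually affects, because only those need timing. This is a minimal sketch built on several assumptions. `multicast_legal` is a hypothetical, simplified legality check that may differ from the one in gemm.py. It also assumes 132 SMs, and that multicasting A pairs blocks along N while multicasting B pairs them along M.

```python
from itertools import product


def multicast_legal(shape_dim, block_dim, num_multicast=2, num_sms=132):
    # Hypothetical simplified check: the dimension splits evenly into
    # clusters of `num_multicast` blocks, and the SM count is divisible by it.
    return shape_dim % (block_dim * num_multicast) == 0 and num_sms % num_multicast == 0


def select_multicast(m, n, block_m, block_n, proposed):
    # Assumed mapping: multicasting A pairs blocks along N, multicasting B along M.
    legal = {'A': multicast_legal(n, block_n), 'B': multicast_legal(m, block_m)}
    a_first = block_m > block_n
    dim_size = m if (a_first or not proposed) else n
    for i in ('A', 'B') if a_first else ('B', 'A'):
        if dim_size >= 512 and legal[i]:
            return (2, i == 'A')
    return (1, True)  # assumed no-multicast fallback


def disagreeing_shapes(ms, ns, block_m, block_n):
    """List (m, n, current, proposed) for shapes where the two rules differ."""
    out = []
    for m, n in product(ms, ns):
        cur = select_multicast(m, n, block_m, block_n, proposed=False)
        new = select_multicast(m, n, block_m, block_n, proposed=True)
        if cur != new:
            out.append((m, n, cur, new))
    return out


for row in disagreeing_shapes([128, 256, 4096], [256, 4096], block_m=64, block_n=128):
    print(row)
```

With `block_m=64` and `block_n=128`, this prints three shapes: `(128, 4096)` and `(256, 4096)` gain multicast, and `(4096, 256)` loses it. When `block_m > block_n`, both rules check `m`, so the list is empty.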