triton-lang / triton

Development repository for the Triton language and compiler

Home Page:https://triton-lang.org/

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

int8 x bfloat16 matmul tests fail on 4090s due to numerical error

rationalism opened this issue · comments

Versions: Ubuntu 22.04, CUDA 12.4, PyTorch nightly 2024-06-08, 4090 GPU (CUDA Compute Capability 8.9/Ada)

All of the int8 x bfloat16 matmul tests fail on main due to numerical error. Eg., here is one example:

(lm_fun) alyssa@alyssa-desktop:~/lm_fun/triton/python$ pytest -vs test/unit/operators/test_matmul.py::test_op[32-32-32-1-1-2-None-None-None-False-False-int8-bfloat16-None-False-None-None]
============================================================================================= test session starts =============================================================================================
platform linux -- Python 3.10.14, pytest-7.4.3, pluggy-1.3.0 -- /home/alyssa/anaconda3/envs/lm_fun/bin/python
cachedir: .pytest_cache
rootdir: /home/alyssa/lm_fun/triton/python
plugins: xdist-3.6.1, anyio-3.7.1
collected 1 item                                                                                                                                                                                              

test/unit/operators/test_matmul.py::test_op[32-32-32-1-1-2-None-None-None-False-False-int8-bfloat16-None-False-None-None] FAILED

================================================================================================== FAILURES ===================================================================================================
____________________________________________________________ test_op[32-32-32-1-1-2-None-None-None-False-False-int8-bfloat16-None-False-None-None] ____________________________________________________________

BLOCK_M = 32, BLOCK_N = 32, BLOCK_K = 32, SPLIT_K = 1, NWARP = 1, NSTAGE = 2, M = 32, N = 32, K = 32, AT = False, BT = False, ADTYPE = 'int8', BDTYPE = 'bfloat16', INPUT_PRECISION = None
F8_FASTACCUM = False, ACC_DTYPE = None, OUTPUT_DTYPE = None

    @pytest.mark.parametrize(
        "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, INPUT_PRECISION, F8_FASTACCUM, ACC_DTYPE, OUTPUT_DTYPE",
        itertools.chain(
            *[[
                # 1 warp
                (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None),
                (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None),
                (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None),
                (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None),
                (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None),
                (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None),
                (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None),
                (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None),
                (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None),
                # 2 warp
                (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None),
                (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None),
                (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None),
                (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None),
                (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None),
                (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None),
                # 4 warp
                (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None),
                (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None),
                (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None),
                (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None),
                (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None),
                (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None),
                # 8 warp
                (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None),
                (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None),
                (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None),
                # variable input
                (128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE, None, True, None, None),
                (128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE, None, True, None, None),
                (128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE, None, True, None, None),
                (128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE, None, True, None, None),
            ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]],
            # n-stage
            *[[
                (16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE, None, True, None, None),
                (64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE, None, True, None, None),
                (128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE, None, True, None, None),
                (256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE, None, True, None, None),
                (128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE, None, True, None, None),
            ]
              for DTYPE in ["float16", "bfloat16", "float32"]
              for AT in [False, True]
              for BT in [False, True]
              for STAGES in [4]],
            # tf32x3
            *[[
                (16, 16, 16, 1, 1, 2, 32, 32, 80, AT, BT, "float32", "float32", "tf32x3", True, None, None),
                (64, 32, 64, 1, 2, 2, 128, 64, 128, AT, BT, "float32", "float32", "tf32x3", True, None, None),
                (128, 64, 16, 1, 4, 2, 256, 128, 80, AT, BT, "float32", "float32", "tf32x3", True, None, None),
                (256, 128, 32, 1, 8, 2, 512, 256, 160, AT, BT, "float32", "float32", "tf32x3", True, None, None),
                (128, 128, 32, 1, 4, 2, 256, 256, 160, AT, BT, "float32", "float32", "tf32x3", True, None, None),
            ] for AT in [False, True] for BT in [False, True]],
            # mixed-precision
            *[[
                (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, None, FASTACCUM, None, None),
                (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, None, FASTACCUM, None, None),
                (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, None, FASTACCUM, None, None),
            ] for ADTYPE, BDTYPE in [
                ("float8e4nv", "float8e5"),
                ("float8e4nv", "float8e4nv"),
                ("float8e5", "float8e4nv"),
                ("float8e5", "float8e5"),
                ("float8e4b15", "float8e4b15"),
                ("float8e4nv", "float16"),
                ("float16", "float8e5"),
                ("int8", "bfloat16"),
                ("float16", "int8"),
                ("float16", "float32"),
                ("float32", "float16"),
                ("bfloat16", "float32"),
                ("float32", "bfloat16"),
            ] for AT in [False, True] for BT in [False, True] for FASTACCUM in [True, False]],
            # mixed-precision block layout
            *[[
                (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, None, True, None, None),
                (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, None, True, None, None),
                (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, None, True, None, None),
            ] for ADTYPE, BDTYPE in [
                ("float8e4nv", "float16"),
                ("float16", "float8e5"),
                ("float16", "float32"),
                ("float32", "float16"),
                ("bfloat16", "float32"),
                ("float32", "bfloat16"),
            ] for AT in [False, True] for BT in [False, True]],
            # acc-out-dtype and output_dtype
            *[[
                (32, 32, 32, 1, 1, 2, None, None, None, False, False, "float16", "float16", None, True, ACC_DTYPE,
                 OUTPUT_DTYPE),
                (128, 256, 32, 1, 8, 2, None, None, None, False, False, "float16", "float16", None, True, ACC_DTYPE,
                 OUTPUT_DTYPE),
            ] for ACC_DTYPE in [None, "float16", "float32"] for OUTPUT_DTYPE in [None, "float16", "float32"]],
        ),
    )
    def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, INPUT_PRECISION,
                F8_FASTACCUM, ACC_DTYPE, OUTPUT_DTYPE):
        capability = torch.cuda.get_device_capability()
        if capability[0] < 7:
            pytest.skip("Only test tl.dot() on devices with sm >= 70")
        if capability[0] < 8 and (ADTYPE == "bfloat16" or BDTYPE == "bfloat16"):
            pytest.skip("Only test bfloat16 on devices with sm >= 80")
        if capability[0] < 9 and capability[1] < 9 and (ADTYPE == "float8e4nv" or BDTYPE == "float8e4nv"):
            pytest.skip("Only test float8e4nv on devices with sm >= 89")
        if (ADTYPE == "bfloat16" or BDTYPE == "bfloat16") and SPLIT_K != 1:
            pytest.skip("bfloat16 matmuls don't allow split_k for now")
        torch.manual_seed(0)
        # nuke kernel decorators -- will set meta-parameters manually
        kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K}
        pre_hook = None if SPLIT_K == 1 else lambda nargs: nargs['C'].zero_()
        configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE, pre_hook=pre_hook)]
        kernel = triton.ops._matmul.kernel
        kernel.configs = configs
        # kernel.run = kernel.run.run.run
    
        # get matrix shape
        M = BLOCK_M if M is None else M
        N = BLOCK_N if N is None else N
        K = BLOCK_K * SPLIT_K if K is None else K
    
        def is_fp8(dtype):
            return "float8" in dtype
    
        def f8_to_f16(x, dtype):
    
            @triton.jit
            def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr):
                pid = tl.program_id(0)
                offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
                mask = offs < N
                x = tl.load(X + offs, mask=mask)
                tl.store(Y + offs, x, mask=mask)
    
            ret = torch.empty_strided(x.shape, x.stride(), dtype=torch.float16, device=x.device)
            grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']), )
            dtype = getattr(tl, dtype)
            kernel[grid](ret, triton.reinterpret(x, dtype), ret.numel(), BLOCK_SIZE=1024)
            return ret
    
        def upcast_if_fp8(x, dtype):
            if is_fp8(dtype):
                return f8_to_f16(x, dtype)
            return x
    
        def init_input(m, n, dtype, acc_dtype):
            if 'float8' in dtype:
                ewidth = {'float8e4b15': 4, 'float8e4nv': 4, 'float8e5': 5}[dtype]
                sign = torch.randint(2, size=(m, n), device="cuda", dtype=torch.int8) * 128
                val = torch.randint(2**3 - 1, size=(m, n), device="cuda", dtype=torch.int8) << 7 - ewidth
                return sign | val
            if dtype == "int8":
                return torch.randint(-128, 127, (m, n), device="cuda", dtype=torch.int8)
            # Use small range of values to prevent numerical issues.
            min_exp = -4 if acc_dtype == "float16" else -10
            exponents = torch.randint(min_exp, 0, size=(m, n))
            ret = (2.**exponents).to(getattr(torch, dtype)).to("cuda")
            return ret
    
        if is_hip():
            if INPUT_PRECISION == 'tf32x3' or is_fp8(ADTYPE) or is_fp8(BDTYPE):
                pytest.skip("fp8 inputs or tf32x3 precison does not have native support on hip")
        # allocate/transpose inputs
        a = init_input(M, K, ADTYPE, ACC_DTYPE)
        b = init_input(K, N, BDTYPE, ACC_DTYPE)
        a = a if not AT else a.T.contiguous().T
        b = b if not BT else b.T.contiguous().T
        # run test
        th_a = upcast_if_fp8(a, ADTYPE)
        th_b = upcast_if_fp8(b, BDTYPE)
        ab_dtype = triton.ops.get_higher_dtype(th_a.dtype, th_b.dtype)
        acc_dtype = getattr(torch, ACC_DTYPE) if ACC_DTYPE else ab_dtype
        output_dtype = getattr(torch, OUTPUT_DTYPE) if OUTPUT_DTYPE else ab_dtype
        th_c = torch.matmul(th_a.to(output_dtype), th_b.to(output_dtype))
        try:
            if is_fp8(ADTYPE):
                a = triton.reinterpret(a, getattr(tl, ADTYPE))
            if is_fp8(BDTYPE):
                b = triton.reinterpret(b, getattr(tl, BDTYPE))
            tt_c = triton.ops.matmul(a, b, acc_dtype if ACC_DTYPE else None, INPUT_PRECISION, F8_FASTACCUM, output_dtype)
>           torch.testing.assert_close(th_c, tt_c)
E           AssertionError: Tensor-likes are not close!
E           
E           Mismatched elements: 29 / 1024 (2.8%)
E           Greatest absolute difference: 0.3125 at index (20, 12) (up to 1e-05 allowed)
E           Greatest relative difference: 0.09423828125 at index (20, 21) (up to 0.016 allowed)

test/unit/operators/test_matmul.py:199: AssertionError