facebookresearch / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.

Home Page: https://facebookresearch.github.io/xformers/

memory_efficient_attention forward pass produces inconsistent results

ShijunK opened this issue

❓ Questions and Help

The forward pass of memory_efficient_attention produces inconsistent results when it is run twice on the same inputs.

I am not sure what is going on. Is it an incorrect build, or a problem with specific version combinations?

The test failed for these combinations:

xformers                             torch  CUDA  GPU              CUDA compute capability  Status
v0.0.20+1dc3d7a (built from source)  1.13   11.7  Quadro RTX 6000  7.5                      Failed
v0.0.20+1dc3d7a (built from source)  1.13   11.7  A100             8.0                      Failed
v0.0.21+320b5ad (built from source)  1.13   11.7  Quadro RTX 6000  7.5                      Failed
v0.0.22+1e065bc (built from source)  1.13   11.7  Quadro RTX 6000  7.5                      Failed
v0.0.23+1254a16 (built from source)  1.13   11.7  Quadro RTX 6000  7.5                      Failed

but passed for these:

xformers                             torch  CUDA  GPU              CUDA compute capability  Status
v0.0.20+1dc3d7a (built from source)  1.13   11.7  RTX A6000        8.6                      Passed
v0.0.24+f7e46d5 (built from source)  2.2    11.8  A100             8.0                      Passed
v0.0.24+f7e46d5 (built from source)  2.2    11.8  RTX A6000        8.6                      Passed
v0.0.24+f7e46d5 (built from source)  2.2    12.1  RTX A6000        8.6                      Passed
v0.0.24+f7e46d5 (built from source)  2.2    12.1  H100             9.0                      Passed
0.0.22.post7 (pip install)           2.1    11.8  A100             8.0                      Passed
0.0.23 (pip install)                 2.1.1  11.8  A100             8.0                      Passed
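
To make it easier to add rows to the tables above, a small helper along these lines could collect the same columns on any machine. This is only a sketch: `describe_environment` is a hypothetical name, and it relies on the standard `torch.version.cuda`, `torch.cuda.get_device_name` and `torch.cuda.get_device_capability` calls. Missing packages or a missing GPU are reported as `None`.

```python
import importlib


def describe_environment():
    """Collect the table columns (xformers, torch, CUDA, GPU, capability)."""
    info = {
        "xformers": None,
        "torch": None,
        "cuda": None,
        "gpu": None,
        "compute_capability": None,
    }
    try:
        torch = importlib.import_module("torch")
    except ImportError:
        return info  # torch is not installed, nothing more to report

    info["torch"] = torch.__version__
    info["cuda"] = torch.version.cuda  # None for CPU-only builds
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability(0)
        info["gpu"] = torch.cuda.get_device_name(0)
        info["compute_capability"] = f"{major}.{minor}"

    try:
        info["xformers"] = importlib.import_module("xformers").__version__
    except ImportError:
        pass  # leave as None when xformers is not installed
    return info


if __name__ == "__main__":
    for key, value in describe_environment().items():
        print(f"{key:20s} {value}")
```

The output complements `python -m xformers.info`, which additionally lists the available attention ops.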

Command

pytest test_simple.py -v

To Reproduce

Steps to reproduce the behavior (for the combination v0.0.20+1dc3d7a built from source, torch 1.13, CUDA 11.7, A100):

  1. git checkout v0.0.20
  2. git submodule update --init --recursive
  3. Install CUDA 11.7 locally
  4. Install torch 1.13+cu117 in a venv
  5. python setup.py bdist_wheel
  6. Install the built xformers wheel in the venv
  7. pytest test_simple.py -v

test code:

import sys
import pytest
import torch

import xformers.ops
from xformers import info


@pytest.mark.parametrize("batch_size", [(1), (4), (8)])
@pytest.mark.parametrize(
    "seq_len",
    [
        (2**1),
        (2**3),
        (2**6),
        (2**9),
    ],
)
@pytest.mark.parametrize(
    "k_seq_len",
    [
        (2**1),
        (2**3),
        (2**6),
        (2**9),  # 512
    ],
)
@pytest.mark.parametrize("dim_model", [(128)])
@pytest.mark.parametrize(
    "dtype,rtol,atol",
    [
        (torch.float32, 2e-5, 3e-4),
        (torch.float16, 4e-4, 4e-3),
    ],
)
def test_mem_efficient(
    batch_size,
    seq_len,
    k_seq_len,
    dim_model,
    dtype,
    rtol,
    atol,
):
    dropout = 0.0
    
    device = torch.device("cuda")

    q = torch.randn(
        (batch_size, seq_len, dim_model),
        requires_grad=False,
        device=device,
        dtype=dtype,
    )

    # k and v share one tensor, so the keys double as values.
    k = v = torch.randn(
        (batch_size, k_seq_len, dim_model),
        requires_grad=False,
        device=device,
        dtype=dtype,
    )

    # Pin the CUTLASS forward/backward ops so both calls use the same kernel.
    op = (
        xformers.ops.fmha.cutlass.FwOp,
        xformers.ops.fmha.cutlass.BwOp,
    )

    # Run the forward pass twice on identical inputs.
    with torch.no_grad():
        result_a = xformers.ops.memory_efficient_attention(q, k, v, p=dropout, op=op)
        result_b = xformers.ops.memory_efficient_attention(q, k, v, p=dropout, op=op)

    # The two results should agree within the per-dtype tolerances.
    is_close = torch.isclose(
        result_a,
        result_b,
        rtol=rtol,
        atol=atol,
    )
    assert torch.all(is_close)

if __name__ == "__main__":
    info.print_info()
    sys.exit(pytest.main(["--color=yes", "-s", "-vv", __file__]))

output:

memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-2-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-2-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-2-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-8-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-8-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-8-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-64-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-64-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-64-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-512-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-512-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-512-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-2-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-2-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-2-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-8-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-8-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-8-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-64-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-64-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-64-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-512-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-512-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-512-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-2-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-2-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-2-8] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-8-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-8-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-8-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-64-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-64-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-64-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-512-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-512-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-512-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-2-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-2-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-2-8] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-8-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-8-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-8-8] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-64-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-64-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-64-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-512-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-512-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-512-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-2-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-2-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-2-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-8-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-8-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-8-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-64-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-64-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-64-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-512-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-512-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-512-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-2-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-2-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-2-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-8-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-8-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-8-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-64-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-64-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-64-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-512-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-512-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-512-8] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-2-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-2-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-2-8] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-8-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-8-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-8-8] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-64-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-64-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-64-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-512-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-512-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-512-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-2-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-2-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-2-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-8-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-8-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-8-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-64-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-64-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-64-8] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-512-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-512-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-512-8] FAILED
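
The pass/fail pattern in the log above is hard to read by eye, so a short parser along these lines could tally results per dtype and batch size. It is a sketch under one assumption: the test ID layout is `<dtype>-<rtol>-<atol>-<dim_model>-<k_seq_len>-<seq_len>-<batch_size>`, which is inferred from the stacking order of the `parametrize` decorators in the test and from how the values vary in the log. `summarize` is a hypothetical helper name.

```python
import re
import sys
from collections import Counter

# Matches lines such as:
#   memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-8-1] FAILED
LINE_RE = re.compile(
    r"test_mem_efficient\[(?P<id>[^\]]+)\]\s+(?P<status>PASSED|FAILED)"
)


def summarize(log_text):
    """Count results per (dtype, batch_size, status) triple."""
    counts = Counter()
    for match in LINE_RE.finditer(log_text):
        # rtol is printed as "2e-05", which itself contains "-",
        # so split off the four trailing integer fields from the right.
        prefix, dim_model, k_seq_len, seq_len, batch_size = match["id"].rsplit("-", 4)
        dtype = prefix.split("-", 1)[0]  # "dtype0" (float32) or "dtype1" (float16)
        counts[(dtype, int(batch_size), match["status"])] += 1
    return counts


if __name__ == "__main__":
    # Usage: pytest memory_efficient_test.py -v | python summarize_log.py
    for key, n in sorted(summarize(sys.stdin.read()).items()):
        print(key, n)
```

Grouping by another field (for example `seq_len`) only requires changing the key of the counter.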


Expected behavior

I expect xformers memory-efficient attention to produce close enough forward results when executed twice on the same inputs, across torch versions (1.13 and 2.2), CUDA versions (11.7, 11.8, 12.1), GPUs with different compute capabilities (7.5, 8.0, 8.6, 9.0), and different q and k sequence lengths, batch sizes, and data types.
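
To judge what "close enough" means in practice, it may help to report how large the difference between the two runs is instead of only a boolean. The helper below is a sketch (`diff_report` is a hypothetical name) that applies the same rule as `torch.isclose`, namely `|a - b| <= atol + rtol * |b|`, and also flags NaNs and bitwise equality. It runs on CPU tensors as well, so it can be tried without a GPU.

```python
import torch


def diff_report(a, b, rtol, atol):
    """Summarize how far two tensors are apart under torch.isclose's rule."""
    a32, b32 = a.float(), b.float()  # compare in float32 even for float16 inputs
    abs_diff = (a32 - b32).abs()
    # torch.isclose accepts an element when |a - b| <= atol + rtol * |b|.
    allowed = atol + rtol * b32.abs()
    return {
        "bitwise_equal": torch.equal(a, b),
        "max_abs_diff": abs_diff.max().item(),
        # NaN comparisons are False, so NaNs are not counted here;
        # they are reported separately through "has_nan".
        "num_mismatched": int((abs_diff > allowed).sum().item()),
        "has_nan": bool(torch.isnan(a32).any() or torch.isnan(b32).any()),
    }


if __name__ == "__main__":
    x = torch.randn(4, 8)
    print(diff_report(x, x.clone(), rtol=2e-5, atol=3e-4))
```

In the test above, printing `diff_report(result_a, result_b, rtol, atol)` before the assertion would show whether a failure is rounding-level noise or a much larger discrepancy.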

Additional context

One failing environment:

xFormers 0.0.20+1dc3d7a.d20240311
memory_efficient_attention.cutlassF:               available
memory_efficient_attention.cutlassB:               available
memory_efficient_attention.flshattF:               available
memory_efficient_attention.flshattB:               available
memory_efficient_attention.smallkF:                available
memory_efficient_attention.smallkB:                available
memory_efficient_attention.tritonflashattF:        unavailable
memory_efficient_attention.tritonflashattB:        unavailable
indexing.scaled_index_addF:                        available
indexing.scaled_index_addB:                        available
indexing.index_select:                             available
swiglu.dual_gemm_silu:                             available
swiglu.gemm_fused_operand_sum:                     available
swiglu.fused.p.cpp:                                available
is_triton_available:                               False
is_functorch_available:                            False
pytorch.version:                                   1.13.0+cu117
pytorch.cuda:                                      available
gpu.compute_capability:                            8.0
gpu.name:                                          A100-SXM-80GB
build.info:                                        available
build.cuda_version:                                1107
build.python_version:                              3.10.13
build.torch_version:                               1.13.0+cu117
build.env.TORCH_CUDA_ARCH_LIST:                    None
build.env.XFORMERS_BUILD_TYPE:                     None
build.env.XFORMERS_ENABLE_DEBUG_ASSERTIONS:        None
build.env.NVCC_FLAGS:                              None
build.env.XFORMERS_PACKAGE_FROM:                   None
build.nvcc_version:                                11.7.64
source.privacy:                                    open source

One passing environment:

xFormers 0.0.23+cu118
memory_efficient_attention.cutlassF:               available
memory_efficient_attention.cutlassB:               available
memory_efficient_attention.decoderF:               available
memory_efficient_attention.flshattF@v2.3.6:        available
memory_efficient_attention.flshattB@v2.3.6:        available
memory_efficient_attention.smallkF:                available
memory_efficient_attention.smallkB:                available
memory_efficient_attention.tritonflashattF:        unavailable
memory_efficient_attention.tritonflashattB:        unavailable
memory_efficient_attention.triton_splitKF:         available
indexing.scaled_index_addF:                        available
indexing.scaled_index_addB:                        available
indexing.index_select:                             available
swiglu.dual_gemm_silu:                             available
swiglu.gemm_fused_operand_sum:                     available
swiglu.fused.p.cpp:                                available
is_triton_available:                               True
pytorch.version:                                   2.1.1+cu118
pytorch.cuda:                                      available
gpu.compute_capability:                            8.0
gpu.name:                                          A100-SXM-80GB
build.info:                                        available
build.cuda_version:                                1108
build.python_version:                              3.10.13
build.torch_version:                               2.1.1+cu118
build.env.TORCH_CUDA_ARCH_LIST:                    5.0+PTX 6.0 6.1 7.0 7.5 8.0+PTX 9.0
build.env.XFORMERS_BUILD_TYPE:                     Release
build.env.XFORMERS_ENABLE_DEBUG_ASSERTIONS:        None
build.env.NVCC_FLAGS:                              None
build.env.XFORMERS_PACKAGE_FROM:                   wheel-v0.0.23
build.nvcc_version:                                11.8.89
source.privacy:                                    open source

All xformers versions (v0.0.20+1dc3d7a, v0.0.21+320b5ad, v0.0.22+1e065bc, v0.0.23+1254a16, v0.0.24+f7e46d5) were built from source, except 0.0.22.post7 and 0.0.23, which were installed with pip.

Hi,
If you want deterministic (reproducible) results, you need to enable it in PyTorch:
https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
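
For reference, enabling that mode could look like the sketch below. The `CUBLAS_WORKSPACE_CONFIG` setting is the one the PyTorch documentation asks for with CUDA 10.2 or later, and it has to be set before CUDA is initialized. Whether this alone resolves the failures reported above is not confirmed in this thread.

```python
import os

# Per the PyTorch docs, some cuBLAS operations are only deterministic on
# CUDA >= 10.2 when this variable is set (":16:8" is the other accepted value).
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

import torch

# Ask PyTorch to use deterministic implementations where they exist and to
# raise an error for operations that have none.
torch.use_deterministic_algorithms(True)

print("deterministic mode:", torch.are_deterministic_algorithms_enabled())
```

In the test file, these lines would go at module level, before any test runs.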