memory_efficient_attention forward pass produces inconsistent results
ShijunK opened this issue
❓ Questions and Help
I'm not sure what is going on. Is it an incorrect build, or a problem with specific version combinations?
The test fails for these combinations:
| xformers | torch | CUDA | GPU | Compute capability | Status |
|---|---|---|---|---|---|
| v0.0.20+1dc3d7a (built from source) | 1.13 | 11.7 | Quadro RTX 6000 | 7.5 | Failed |
| v0.0.20+1dc3d7a (built from source) | 1.13 | 11.7 | A100 | 8.0 | Failed |
| v0.0.21+320b5ad (built from source) | 1.13 | 11.7 | Quadro RTX 6000 | 7.5 | Failed |
| v0.0.22+1e065bc (built from source) | 1.13 | 11.7 | Quadro RTX 6000 | 7.5 | Failed |
| v0.0.23+1254a16 (built from source) | 1.13 | 11.7 | Quadro RTX 6000 | 7.5 | Failed |
It passes for these combinations:
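| xformers | torch | CUDA | GPU | Compute capability | Status |
|---|---|---|---|---|---|
| v0.0.20+1dc3d7a (built from source) | 1.13 | 11.7 | RTX A6000 | 8.6 | Passed |
| v0.0.24+f7e46d5 (built from source) | 2.2 | 11.8 | A100 | 8.0 | Passed |
| v0.0.24+f7e46d5 (built from source) | 2.2 | 11.8 | RTX A6000 | 8.6 | Passed |
| v0.0.24+f7e46d5 (built from source) | 2.2 | 12.1 | RTX A6000 | 8.6 | Passed |
| v0.0.24+f7e46d5 (built from source) | 2.2 | 12.1 | H100 | 9.0 | Passed |
| 0.0.22.post7 (pip install) | 2.1 | 11.8 | A100 | 8.0 | Passed |
| 0.0.23 (pip install) | 2.1.1 | 11.8 | A100 | 8.0 | Passed |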
Command
pytest test_simple.py -v
To Reproduce
Steps to reproduce the behavior:
(for the combination v0.0.20+1dc3d7a built from source, torch 1.13, CUDA 11.7, A100)
- git checkout v0.0.20
- git submodule update --init --recursive
- Install CUDA 11.7 locally
- Install torch 1.13+cu117 in a venv
- python setup.py bdist_wheel
- Install the resulting xformers wheel in the venv
- pytest test_simple.py -v
Test code:
import sys

import pytest
import torch
import xformers.ops
from xformers import info


@pytest.mark.parametrize("batch_size", [1, 4, 8])
@pytest.mark.parametrize(
    "seq_len",
    [
        2**1,
        2**3,
        2**6,
        2**9,
    ],
)
@pytest.mark.parametrize(
    "k_seq_len",
    [
        2**1,
        2**3,
        2**6,
        2**9,  # 512
    ],
)
@pytest.mark.parametrize("dim_model", [128])
@pytest.mark.parametrize(
    "dtype,rtol,atol",
    [
        (torch.float32, 2e-5, 3e-4),
        (torch.float16, 4e-4, 4e-3),
    ],
)
def test_mem_efficient(
    batch_size,
    seq_len,
    k_seq_len,
    dim_model,
    dtype,
    rtol,
    atol,
):
    dropout = 0.0
    device = torch.device("cuda")
    # Query: [batch_size, seq_len, dim_model]
    q = torch.randn(
        (batch_size, seq_len, dim_model),
        requires_grad=False,
        device=device,
        dtype=dtype,
    )
    # Key and value share one tensor: [batch_size, k_seq_len, dim_model]
    k = v = torch.randn(
        (batch_size, k_seq_len, dim_model),
        requires_grad=False,
        device=device,
        dtype=dtype,
    )
    # Force the CUTLASS operators so both calls run the same kernel
    op = (xformers.ops.fmha.cutlass.FwOp, xformers.ops.fmha.cutlass.BwOp)
    with torch.no_grad():
        # Two forward passes on identical inputs
        result_a = xformers.ops.memory_efficient_attention(q, k, v, p=dropout, op=op)
        result_b = xformers.ops.memory_efficient_attention(q, k, v, p=dropout, op=op)
    # Both runs should agree element-wise within the dtype-specific tolerances
    is_close = torch.isclose(
        result_a,
        result_b,
        rtol=rtol,
        atol=atol,
    )
    assert torch.all(is_close)


if __name__ == "__main__":
    info.print_info()
    sys.exit(pytest.main(["--color=yes", "-s", "-vv", __file__]))
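Comparing two runs of the same kernel only shows that they disagree, not which run is wrong. A minimal sketch, not part of the original report, that checks an output against a naive PyTorch reference computed in float32. It assumes the default softmax scale of memory_efficient_attention is 1/sqrt(K) for inputs of shape [B, M, K], matching the reference formula in the xformers docstring; the helper names `naive_attention` and `check_against_reference` are hypothetical.

```python
import torch


def naive_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Reference softmax(Q K^T / sqrt(K)) V for [B, M, K] inputs, computed in float32."""
    q32, k32, v32 = q.float(), k.float(), v.float()
    scale = q32.shape[-1] ** -0.5  # assumed default scale, 1/sqrt(K)
    attn = (q32 * scale) @ k32.transpose(-2, -1)  # [B, M, N] attention logits
    attn = attn.softmax(dim=-1)  # normalize over the key dimension
    return attn @ v32  # [B, M, K]


def check_against_reference(result, q, k, v, rtol, atol):
    """Raise AssertionError, with mismatch statistics, if `result` deviates from the reference."""
    expected = naive_attention(q, k, v)
    # assert_close reports how many elements mismatch and the largest abs/rel differences
    torch.testing.assert_close(result.float(), expected, rtol=rtol, atol=atol)
```

Calling `check_against_reference(result_a, q, k, v, rtol, atol)` and the same for `result_b` inside `test_mem_efficient` would show which of the two runs drifts. Tolerances may need loosening when comparing float16 outputs against a float32 reference.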
Output:
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-2-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-2-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-2-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-8-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-8-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-8-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-64-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-64-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-64-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-512-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-512-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-512-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-2-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-2-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-2-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-8-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-8-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-8-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-64-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-64-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-64-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-512-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-512-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-512-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-2-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-2-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-2-8] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-8-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-8-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-8-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-64-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-64-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-64-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-512-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-512-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-512-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-2-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-2-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-2-8] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-8-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-8-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-8-8] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-64-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-64-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-64-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-512-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-512-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-512-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-2-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-2-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-2-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-8-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-8-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-8-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-64-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-64-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-64-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-512-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-512-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-512-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-2-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-2-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-2-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-8-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-8-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-8-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-64-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-64-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-64-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-512-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-512-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-512-8] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-2-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-2-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-2-8] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-8-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-8-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-8-8] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-64-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-64-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-64-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-512-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-512-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-512-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-2-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-2-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-2-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-8-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-8-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-8-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-64-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-64-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-64-8] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-512-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-512-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-512-8] FAILED
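The parameter IDs in the output above are easier to scan once decoded. A small stdlib-only sketch that turns each `pytest -v` result line into a dict and tallies failures by one parameter. The order of the four trailing integer fields (dim_model, k_seq_len, seq_len, batch_size) is an assumption inferred from how pytest builds IDs for stacked `parametrize` decorators (the decorator nearest the function comes first, the top one varies fastest); the function names are hypothetical.

```python
import re
from collections import Counter

# Matches lines such as
# "memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-8-1] FAILED"
LINE_RE = re.compile(r"::test_mem_efficient\[(?P<id>[^\]]+)\]\s+(?P<status>PASSED|FAILED)")

# Assumed order of the trailing integer fields in each ID (see lead-in).
INT_FIELDS = ("dim_model", "k_seq_len", "seq_len", "batch_size")


def parse_result_line(line):
    """Decode one `pytest -v` result line into a dict, or return None if it is not one."""
    match = LINE_RE.search(line)
    if match is None:
        return None
    # rtol/atol values such as "2e-05" contain '-', so split the integers off the right end.
    head, *ints = match.group("id").rsplit("-", len(INT_FIELDS))
    result = dict(zip(INT_FIELDS, map(int, ints)))
    result["dtype"] = head.split("-", 1)[0]  # "dtype0" (float32) or "dtype1" (float16)
    result["status"] = match.group("status")
    return result


def failures_by(lines, field):
    """Count FAILED results grouped by one decoded parameter."""
    counts = Counter()
    for line in lines:
        parsed = parse_result_line(line)
        if parsed is not None and parsed["status"] == "FAILED":
            counts[parsed[field]] += 1
    return counts
```

For example, `failures_by(output_lines, "batch_size")` groups the FAILED lines above by batch size, and passing `"dtype"` separates the float32 and float16 runs.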
Expected behavior
I expect xformers memory-efficient attention to produce forward results that agree within tolerance when executed twice on the same inputs, across torch versions (1.13 and 2.2), CUDA versions (11.7, 11.8, 12.1), GPUs with different compute capabilities (7.5, 8.0, 8.6, 9.0), and different q/k sequence lengths, batch sizes, and data types.
Environment
Please copy and paste the output from the environment collection script from PyTorch (or fill out the checklist below manually).
You can run the script with:
# For security purposes, please check the contents of collect_env.py before running it.
python -m torch.utils.collect_env
- PyTorch Version (e.g., 1.0):
- OS (e.g., Linux):
- How you installed PyTorch (conda, pip, source):
- Build command you used (if compiling from source):
- Python version:
- CUDA/cuDNN version:
- GPU models and configuration:
- Any other relevant information:
Additional context
One failing environment:
xFormers 0.0.20+1dc3d7a.d20240311
memory_efficient_attention.cutlassF: available
memory_efficient_attention.cutlassB: available
memory_efficient_attention.flshattF: available
memory_efficient_attention.flshattB: available
memory_efficient_attention.smallkF: available
memory_efficient_attention.smallkB: available
memory_efficient_attention.tritonflashattF: unavailable
memory_efficient_attention.tritonflashattB: unavailable
indexing.scaled_index_addF: available
indexing.scaled_index_addB: available
indexing.index_select: available
swiglu.dual_gemm_silu: available
swiglu.gemm_fused_operand_sum: available
swiglu.fused.p.cpp: available
is_triton_available: False
is_functorch_available: False
pytorch.version: 1.13.0+cu117
pytorch.cuda: available
gpu.compute_capability: 8.0
gpu.name: A100-SXM-80GB
build.info: available
build.cuda_version: 1107
build.python_version: 3.10.13
build.torch_version: 1.13.0+cu117
build.env.TORCH_CUDA_ARCH_LIST: None
build.env.XFORMERS_BUILD_TYPE: None
build.env.XFORMERS_ENABLE_DEBUG_ASSERTIONS: None
build.env.NVCC_FLAGS: None
build.env.XFORMERS_PACKAGE_FROM: None
build.nvcc_version: 11.7.64
source.privacy: open source
One passing environment:
xFormers 0.0.23+cu118
memory_efficient_attention.cutlassF: available
memory_efficient_attention.cutlassB: available
memory_efficient_attention.decoderF: available
memory_efficient_attention.flshattF@v2.3.6: available
memory_efficient_attention.flshattB@v2.3.6: available
memory_efficient_attention.smallkF: available
memory_efficient_attention.smallkB: available
memory_efficient_attention.tritonflashattF: unavailable
memory_efficient_attention.tritonflashattB: unavailable
memory_efficient_attention.triton_splitKF: available
indexing.scaled_index_addF: available
indexing.scaled_index_addB: available
indexing.index_select: available
swiglu.dual_gemm_silu: available
swiglu.gemm_fused_operand_sum: available
swiglu.fused.p.cpp: available
is_triton_available: True
pytorch.version: 2.1.1+cu118
pytorch.cuda: available
gpu.compute_capability: 8.0
gpu.name: A100-SXM-80GB
build.info: available
build.cuda_version: 1108
build.python_version: 3.10.13
build.torch_version: 2.1.1+cu118
build.env.TORCH_CUDA_ARCH_LIST: 5.0+PTX 6.0 6.1 7.0 7.5 8.0+PTX 9.0
build.env.XFORMERS_BUILD_TYPE: Release
build.env.XFORMERS_ENABLE_DEBUG_ASSERTIONS: None
build.env.NVCC_FLAGS: None
build.env.XFORMERS_PACKAGE_FROM: wheel-v0.0.23
build.nvcc_version: 11.8.89
source.privacy: open source
All xformers versions (v0.0.20+1dc3d7a, v0.0.21+320b5ad, v0.0.22+1e065bc, v0.0.23+1254a16, v0.0.24+f7e46d5) were built from source, except 0.0.22.post7 and 0.0.23, which were installed with pip.
Hi,
If you want deterministic (reproducible) results, you need to enable it in PyTorch:
https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
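Following the reply, a minimal sketch of enabling PyTorch's deterministic mode at the top of the test module. Per the linked PyTorch documentation, deterministic cuBLAS on CUDA 10.2 and later also requires the CUBLAS_WORKSPACE_CONFIG environment variable to be set before CUDA work starts. Whether the xformers CUTLASS kernels used in the test consult this flag is not established in this thread, so treat this as something to try rather than a confirmed fix.

```python
import os

import torch

# Must take effect before cuBLAS is initialized (i.e. before the first CUDA matmul);
# setdefault keeps any value already exported in the shell.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

# Raise an error on PyTorch operations that have no deterministic implementation
# (pass warn_only=True to only warn instead).
torch.use_deterministic_algorithms(True)

# Disable cuDNN autotuning, which may select different algorithms between runs.
torch.backends.cudnn.benchmark = False
```

Raising instead of warning makes it visible when an operation in the test path has no deterministic implementation.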