arogozhnikov / einops

Flexible and powerful tensor operations for readable and reliable code (for pytorch, jax, TF and others)

Home Page: https://einops.rocks

[BUG] get_backend is not thread-safe

tran-khoa opened this issue

Describe the bug
Using einops with multiple threads can lead to a race condition, as the backend dictionary is updated while being iterated over in another thread.

Traceback (most recent call last):
  File "/p/software/jurecadc/stages/2024/software/Python/3.11.3-GCCcore-12.3.0/lib/python3.11/threading.py", line 995, in _bootstrap
    self._bootstrap_inner()
  File "/p/software/jurecadc/stages/2024/software/Python/3.11.3-GCCcore-12.3.0/lib/python3.11/threading.py", line 1038, in _bootstrap_inner
    self.run()
  File "/p/software/jurecadc/stages/2024/software/Python/3.11.3-GCCcore-12.3.0/lib/python3.11/threading.py", line 975, in run
    self._target(*self._args, **self._kwargs)
  File "/p/home/jusers/tran4/jureca/llfs/llfs/data/iterator.py", line 70, in _prefetch_thread
    self._buffer.put(self._obtain_sample())
                     ^^^^^^^^^^^^^^^^^^^^^
  File "/p/home/jusers/tran4/jureca/llfs/llfs/data/iterator.py", line 64, in _obtain_sample
    sample = self.jax_pipeline(sample)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/p/home/jusers/tran4/jureca/llfs/projs/ll_barlow/experiment.py", line 57, in __call__
    images = filter_vmap(self.tokenizer_fn)(images)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/p/project/jinm60/users/tran4/env_llfs/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/p/project/jinm60/users/tran4/env_llfs/lib/python3.11/site-packages/jax/_src/api.py", line 1258, in vmap_f
    out_flat = batching.batch(
               ^^^^^^^^^^^^^^^
  File "/p/project/jinm60/users/tran4/env_llfs/lib/python3.11/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/p/project/jinm60/users/tran4/env_llfs/lib/python3.11/site-packages/equinox/_vmap_pmap.py", line 204, in _fun_wrapper
    _out = self._fun(*_args)
           ^^^^^^^^^^^^^^^^^
  File "/p/home/jusers/tran4/jureca/llfs/llfs/nn/transformers/vit.py", line 100, in __call__
    x = eo.rearrange(x, "c p -> p c")
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/p/project/jinm60/users/tran4/env_llfs/lib/python3.11/site-packages/einops/einops.py", line 483, in rearrange
    return reduce(cast(Tensor, tensor), pattern, reduction='rearrange', **axes_lengths)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/p/project/jinm60/users/tran4/env_llfs/lib/python3.11/site-packages/einops/einops.py", line 412, in reduce
    return _apply_recipe(recipe, tensor, reduction_type=reduction)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/p/project/jinm60/users/tran4/env_llfs/lib/python3.11/site-packages/einops/einops.py", line 233, in _apply_recipe
    backend = get_backend(tensor)
              ^^^^^^^^^^^^^^^^^^^
  File "/p/project/jinm60/users/tran4/env_llfs/lib/python3.11/site-packages/einops/_backends.py", line 27, in get_backend
    for framework_name, backend in _backends.items():
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: dictionary changed size during iteration
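
The final RuntimeError is what CPython raises when a dict gains a key while it is being iterated. A minimal single-threaded sketch of that failure follows; the `registry` dict and its keys are illustrative stand-ins rather than einops internals, and the insertion inside the loop plays the role of a second thread registering a backend.

```python
# Toy stand-in for the backend dictionary; all names here are hypothetical.
registry = {"numpy": object()}

def iterate_while_inserting(d):
    """Insert a key mid-iteration, as a concurrent thread might, and return the error text."""
    try:
        for name, backend in d.items():
            d["jax"] = object()  # simulates another thread registering a backend
    except RuntimeError as exc:
        return str(exc)
    return None

message = iterate_while_inserting(registry)
print(message)
```

With real threads the insertion only occasionally lands in the middle of the loop, which is why the error is intermittent.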

Reproduction steps
Run the script below several times; the race does not trigger on every run.

import einops as eo
import jax.numpy as jnp
import threading

x = jnp.ones((2, 2))
y = jnp.zeros((2, 2))

def thread(*args, **kwargs):
    global x
    x = eo.rearrange(x, "n c -> (n c)")
    print(x)

threading.Thread(target=thread, daemon=True).start()

y = eo.rearrange(y, "n c -> (n c)")
print(y)

Expected behavior
No race condition

Your platform
einops 0.6.1, python 3.11.3, jax v0.4.14

Hmmm, good point. Does wrapping _backends.items() in list() solve the problem?
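
As a rough sketch of what that suggestion amounts to: iterating over `list(d.items())` walks a snapshot, so an insertion made during the loop no longer raises. The `registry` dict below is a hypothetical stand-in for `_backends`, and the in-loop insertion again stands in for a second thread.

```python
# Toy stand-in for the backend dictionary; all names here are hypothetical.
registry = {"numpy": object()}

def iterate_snapshot(d):
    """Iterate over a snapshot of the items while the dict itself is mutated."""
    seen = []
    for name, backend in list(d.items()):
        d["jax"] = object()  # simulates another thread registering a backend
        seen.append(name)
    return seen

seen = iterate_snapshot(registry)
print(seen, sorted(registry))
```

This only removes the iteration error. Whether the rest of the lookup tolerates concurrent callers is a separate question.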

I don't think so. Assume two threads, T1 and T2, need the same backend and neither finds it in the dict.
If T1 then imports the respective backend, T2 may encounter
if BackendSubclass.framework_name not in _backends: after T1 has already registered it. The check is then false, so T2 skips the backend that T1 imported and falls through to the RuntimeError at the end of get_backend.

A simple solution would be to introduce a lock, something like

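import sys
import threading

# Module-level lock guarding writes to _backends
lock = threading.Lock()

def get_backend(tensor) -> 'AbstractBackend':
    """
    Takes a correct backend (e.g. numpy backend if tensor is numpy.ndarray) for a tensor.
    If needed, imports package and creates backend
    """
    # Fast path without the lock: iterate over a snapshot, because another
    # thread may insert into _backends while this loop is running
    for framework_name, backend in list(_backends.items()):
        if backend.is_appropriate_type(tensor):
            return backend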

    with lock:
        # Try to find backend again
        for framework_name, backend in _backends.items():
            if backend.is_appropriate_type(tensor):
                return backend

        # Find backend subclasses recursively
        backend_subclasses = []
        backends = AbstractBackend.__subclasses__()
        while backends:
            backend = backends.pop()
            backends += backend.__subclasses__()
            backend_subclasses.append(backend)

        for BackendSubclass in backend_subclasses:
            if _debug_importing:
                print('Testing for subclass of ', BackendSubclass)
            if BackendSubclass.framework_name not in _backends:
                # check that module was already imported. Otherwise it can't be imported
                if BackendSubclass.framework_name in sys.modules:
                    if _debug_importing:
                        print('Imported backend for ', BackendSubclass.framework_name)
                    backend = BackendSubclass()
                    _backends[backend.framework_name] = backend
                    if backend.is_appropriate_type(tensor):
                        return backend

    raise RuntimeError('Tensor type unknown to einops {}'.format(type(tensor)))

> If T1 then imports the respective backend, T2 may encounter
> if BackendSubclass.framework_name not in _backends: therefore not finding the backend that has already been imported by T1.

True, but not an issue: the rest of the function is idempotent, and it is no problem if a backend is created twice.
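
To illustrate the idempotence argument under simplified assumptions: if two callers each construct a backend and store it under the same key, the second store overwrites the first and the dict still holds one entry per framework. `ToyBackend` and `register` below are hypothetical and are not einops code.

```python
class ToyBackend:
    framework_name = "jax"

def register(registry):
    """Create a backend and store it under its framework name, as each racing caller would."""
    backend = ToyBackend()
    registry[backend.framework_name] = backend
    return backend

registry = {}
first = register(registry)    # e.g. thread T1
second = register(registry)   # e.g. thread T2, unaware of T1's registration
print(list(registry), registry["jax"] is second)
```

The duplicate instance costs one extra construction. This sketch assumes backends carry no state that would make two instances behave differently.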