[BUG] get_backend is not thread-safe
tran-khoa opened this issue · comments
Describe the bug
Using einops with multiple threads can lead to a race condition, as the backend dictionary is updated while being iterated over in another thread.
Traceback (most recent call last):
File "/p/software/jurecadc/stages/2024/software/Python/3.11.3-GCCcore-12.3.0/lib/python3.11/threading.py", line 995, in _bootstrap
self._bootstrap_inner()
File "/p/software/jurecadc/stages/2024/software/Python/3.11.3-GCCcore-12.3.0/lib/python3.11/threading.py", line 1038, in _bootstrap_inner
self.run()
File "/p/software/jurecadc/stages/2024/software/Python/3.11.3-GCCcore-12.3.0/lib/python3.11/threading.py", line 975, in run
self._target(*self._args, **self._kwargs)
File "/p/home/jusers/tran4/jureca/llfs/llfs/data/iterator.py", line 70, in _prefetch_thread
self._buffer.put(self._obtain_sample())
^^^^^^^^^^^^^^^^^^^^^
File "/p/home/jusers/tran4/jureca/llfs/llfs/data/iterator.py", line 64, in _obtain_sample
sample = self.jax_pipeline(sample)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/p/home/jusers/tran4/jureca/llfs/projs/ll_barlow/experiment.py", line 57, in __call__
images = filter_vmap(self.tokenizer_fn)(images)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/p/project/jinm60/users/tran4/env_llfs/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/p/project/jinm60/users/tran4/env_llfs/lib/python3.11/site-packages/jax/_src/api.py", line 1258, in vmap_f
out_flat = batching.batch(
^^^^^^^^^^^^^^^
File "/p/project/jinm60/users/tran4/env_llfs/lib/python3.11/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/p/project/jinm60/users/tran4/env_llfs/lib/python3.11/site-packages/equinox/_vmap_pmap.py", line 204, in _fun_wrapper
_out = self._fun(*_args)
^^^^^^^^^^^^^^^^^
File "/p/home/jusers/tran4/jureca/llfs/llfs/nn/transformers/vit.py", line 100, in __call__
x = eo.rearrange(x, "c p -> p c")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/p/project/jinm60/users/tran4/env_llfs/lib/python3.11/site-packages/einops/einops.py", line 483, in rearrange
return reduce(cast(Tensor, tensor), pattern, reduction='rearrange', **axes_lengths)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/p/project/jinm60/users/tran4/env_llfs/lib/python3.11/site-packages/einops/einops.py", line 412, in reduce
return _apply_recipe(recipe, tensor, reduction_type=reduction)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/p/project/jinm60/users/tran4/env_llfs/lib/python3.11/site-packages/einops/einops.py", line 233, in _apply_recipe
backend = get_backend(tensor)
^^^^^^^^^^^^^^^^^^^
File "/p/project/jinm60/users/tran4/env_llfs/lib/python3.11/site-packages/einops/_backends.py", line 27, in get_backend
for framework_name, backend in _backends.items():
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: dictionary changed size during iteration
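The final line of the traceback is the error CPython raises whenever a dict gains or loses a key while it is being iterated. A minimal single-threaded sketch of the same failure, using a hypothetical `registry` dict as a stand-in for `_backends` (this is not einops code):

```python
# Hypothetical stand-in for a backend registry; not einops code.
registry = {"numpy": object()}


def iterate_and_insert(d):
    """Insert a new key while iterating, as a second thread could do."""
    try:
        for name, backend in d.items():
            d["jax"] = object()  # the dict grows in the middle of the loop
    except RuntimeError as exc:
        return str(exc)
    return None


message = iterate_and_insert(registry)
print(message)
```

In the threaded case the insert comes from another thread instead of the loop body, which is why the failure only shows up on some runs.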
Reproduction steps
Run the following script several times; the failure is intermittent.
import einops as eo
import jax.numpy as jnp
import threading

x = jnp.ones((2, 2))
y = jnp.zeros((2, 2))

def thread(*args, **kwargs):
    global x
    x = eo.rearrange(x, "n c -> (n c)")
    print(x)

threading.Thread(target=thread, daemon=True).start()
y = eo.rearrange(y, "n c -> (n c)")
print(y)
Expected behavior
No race condition
Your platform
einops 0.6.1, python 3.11.3, jax v0.4.14
Hmmm, good point. Does wrapping _backends.items() in a list solve the problem?
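To make the suggestion concrete, here is a small sketch of iterating over a snapshot. The `registry` dict and `find_and_register` function are hypothetical names for illustration, not einops internals. The loop walks a copy of the items, so inserting into the dict inside the loop no longer raises:

```python
# Hypothetical registry; the string values stand in for backend objects.
registry = {"numpy": "numpy-backend"}


def find_and_register(d):
    """Iterate over a snapshot of the items while the dict itself grows."""
    seen = []
    # list(...) copies the items up front, so the loop walks the copy.
    for name, backend in list(d.items()):
        seen.append(name)
        d["jax"] = "jax-backend"  # safe here: the snapshot is unaffected
    return seen


seen = find_and_register(registry)
print(seen, sorted(registry))
```

This only demonstrates the single-threaded behaviour of a snapshot. Whether building the snapshot is safe while another thread inserts depends on the interpreter, so treat it as an illustration rather than a proof of thread safety.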
I don't think so. Assume two threads T1 and T2 need the same backend and neither finds it in the dict.
If T1 then imports the respective backend, T2 may encounter
if BackendSubclass.framework_name not in _backends:
after T1 has already registered it, and therefore never finds the backend that T1 imported.
A simple solution would be to introduce a lock, something like:
def get_backend(tensor) -> 'AbstractBackend':
    """
    Takes a correct backend (e.g. numpy backend if tensor is numpy.ndarray) for a tensor.
    If needed, imports package and creates backend
    """
    for framework_name, backend in _backends.items():
        if backend.is_appropriate_type(tensor):
            return backend

    # `lock` is assumed to be a module-level threading.Lock()
    with lock:
        # Try to find backend again
        for framework_name, backend in _backends.items():
            if backend.is_appropriate_type(tensor):
                return backend

        # Find backend subclasses recursively
        backend_subclasses = []
        backends = AbstractBackend.__subclasses__()
        while backends:
            backend = backends.pop()
            backends += backend.__subclasses__()
            backend_subclasses.append(backend)

        for BackendSubclass in backend_subclasses:
            if _debug_importing:
                print('Testing for subclass of ', BackendSubclass)
            if BackendSubclass.framework_name not in _backends:
                # check that module was already imported. Otherwise it can't be imported
                if BackendSubclass.framework_name in sys.modules:
                    if _debug_importing:
                        print('Imported backend for ', BackendSubclass.framework_name)
                    backend = BackendSubclass()
                    _backends[backend.framework_name] = backend
                    if backend.is_appropriate_type(tensor):
                        return backend

    raise RuntimeError('Tensor type unknown to einops {}'.format(type(tensor)))
If T1 then imports the respective backend, T2 may encounter
if BackendSubclass.framework_name not in _backends: therefore not finding the backend that has already been imported by T1.
True, but not an issue: the rest of the function is idempotent, and it is no problem if a backend is created twice.
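A small sketch of the idempotence argument, with hypothetical names (`FakeBackend`, `register`) that are not einops code: registering a second instance under the same `framework_name` overwrites the existing key, so the registry still ends up with a single entry.

```python
# Hypothetical registry and backend class, for illustration only.
registry = {}


class FakeBackend:
    framework_name = "fake"


def register(d):
    """Create a backend and store it under its framework name."""
    backend = FakeBackend()
    d[backend.framework_name] = backend
    return backend


first = register(registry)   # e.g. done by thread T1
second = register(registry)  # e.g. repeated by thread T2
print(len(registry), registry["fake"] is second)
```

The two callers end up holding different instances. That is harmless under the assumption that backend objects carry no per-instance state that callers rely on.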