to_tf_dataset: Visible devices cannot be modified after being initialized
logasja opened this issue · comments
Describe the bug
When I use to_tf_dataset with a custom collate_fn and enable parallelism through num_workers, I get the following error once for every worker requested in num_workers.
File "/opt/miniconda/envs/env/lib/python3.11/site-packages/multiprocess/process.py", line 314, in _bootstrap
self.run()
File "/opt/miniconda/envs/env/lib/python3.11/site-packages/multiprocess/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/opt/miniconda/envs/env/lib/python3.11/site-packages/datasets/utils/tf_utils.py", line 438, in worker_loop
tf.config.set_visible_devices([], "GPU") # Make sure workers don't try to allocate GPU memory
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda/envs/env/lib/python3.11/site-packages/tensorflow/python/framework/config.py", line 566, in set_visible_devices
context.context().set_visible_devices(devices, device_type)
File "/opt/miniconda/envs/env/lib/python3.11/site-packages/tensorflow/python/eager/context.py", line 1737, in set_visible_devices
raise RuntimeError(
RuntimeError: Visible devices cannot be modified after being initialized
Steps to reproduce the bug
- Download a dataset using HuggingFace load_dataset
- Define a function that transforms the data in some way to be used in the collate_fn argument
- Provide a batch_size and num_workers value in the to_tf_dataset function
- Either retrieve directly or use tfds benchmark to test the dataset
from datasets import load_dataset
import tensorflow_datasets as tfds
from keras_cv.layers import Resizing

def data_loader(examples):
    # Build the resizing layer first, then call it on the image
    x = Resizing(256, 256, crop_to_aspect_ratio=True)(examples[0]['image'])
    # The output key "image" is a placeholder name
    return {"image": x}

ds = load_dataset("logasja/FDF", split="test")
ds = ds.to_tf_dataset(collate_fn=data_loader, batch_size=16, num_workers=2)
tfds.benchmark(ds)
Expected behavior
Use multiple processes to apply transformations from the collate_fn to the tf dataset on the CPU.
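One possible way to get closer to this behavior is for the worker to tolerate a TensorFlow runtime that is already initialized instead of crashing. The sketch below is only an illustration under that assumption, not the fix used by datasets: the helper name hide_gpus_if_possible is hypothetical, and the TensorFlow call is passed in as a parameter so the logic can be read and run on its own. Note that when the helper returns False the GPUs stay visible to that process, so the worker may still allocate GPU memory.

```python
from typing import Callable, Sequence


def hide_gpus_if_possible(set_visible_devices: Callable[[Sequence, str], None]) -> bool:
    """Try to hide all GPUs; return False instead of raising if it is too late.

    Hypothetical helper for illustration only, not part of datasets or TensorFlow.
    """
    try:
        set_visible_devices([], "GPU")
        return True
    except RuntimeError:
        # This is the error type shown in the traceback above
        # ("Visible devices cannot be modified after being initialized").
        return False


# Intended use inside a worker process (requires TensorFlow):
#   import tensorflow as tf
#   hide_gpus_if_possible(tf.config.set_visible_devices)
```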
Environment info
- datasets version: 2.19.1
- Platform: Linux-6.5.0-1023-oracle-x86_64-with-glibc2.35
- Python version: 3.11.8
- huggingface_hub version: 0.22.2
- PyArrow version: 15.0.2
- Pandas version: 2.2.1
- fsspec version: 2024.2.0