ValueError: NOT_FOUND when trying to save train state in docker container
hayden-donnelly opened this issue · comments
I'm getting the following error when I try to save my train state from within a docker container:
Traceback (most recent call last):
File "/project/test.py", line 32, in <module>
checkpointer.save(os.path.abspath('checkpoints/checkpoint1'), state)
File "/usr/local/lib/python3.9/site-packages/orbax/checkpoint/checkpointer.py", line 81, in save
self._handler.save(tmpdir, item, *args, **kwargs)
File "/usr/local/lib/python3.9/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 756, in save
asyncio.run(async_save(directory, item, *args, **kwargs))
File "/usr/local/lib/python3.9/asyncio/runners.py", line 44, in run
return loop.run_until_complete(main)
File "/usr/local/lib/python3.9/asyncio/base_events.py", line 647, in run_until_complete
return future.result()
File "/usr/local/lib/python3.9/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 754, in async_save
future.result() # Block on result.
ValueError: NOT_FOUND: Error opening "cast" driver: Error opening "zarr" driver: Error writing "step/.zarray" in OCDBT database at local file "/project/checkpoints/checkpoint1.orbax-checkpoint-tmp-1690941767098696/": Error writing local file "/project/checkpoints/checkpoint1.orbax-checkpoint-tmp-1690941767098696/manifest.ocdbt": Error getting file info: /project/checkpoints/checkpoint1.orbax-checkpoint-tmp-1690941767098696/manifest.ocdbt.__lock [OS error: No such file or directory] [tensorstore_spec[1]='{\"base\":{\"create\":true,\"driver\":\"zarr\",\"dtype\":\"int64\",\"kvstore\":{\"base\":{\"driver\":\"file\",\"path\":\"/project/checkpoints/checkpoint1.orbax-checkpoint-tmp-1690941767098696/\"},\"cache_pool\":\"cache_pool#ocdbt\",\"config\":{\"max_decoded_node_bytes\":100000000,\"max_inline_value_bytes\":1024},\"driver\":\"ocdbt\",\"experimental_read_coalescing_threshold_bytes\":1000000,\"path\":\"step/\"},\"metadata\":{\"chunks\":[],\"compressor\":{\"id\":\"zstd\",\"level\":1},\"shape\":[]},\"open\":true,\"recheck_cached_data\":false,\"recheck_cached_metadata\":false},\"context\":{\"cache_pool\":{},\"cache_pool#ocdbt\":{\"total_bytes_limit\":100000000},\"data_copy_concurrency\":{},\"file_io_concurrency\":{\"limit\":128},\"ocdbt_coordinator\":{}},\"driver\":\"cast\",\"dtype\":\"int64\"}'] [source locations='tensorstore/kvstore/kvstore.cc:268\ntensorstore/kvstore/kvstore.cc:268\ntensorstore/driver/driver.cc:114\ntensorstore/driver/driver.cc:114'] 
[tensorstore_spec='{\"context\":{\"cache_pool\":{},\"cache_pool#ocdbt\":{\"total_bytes_limit\":100000000},\"data_copy_concurrency\":{},\"file_io_concurrency\":{\"limit\":128},\"ocdbt_coordinator\":{}},\"create\":true,\"driver\":\"zarr\",\"dtype\":\"int64\",\"kvstore\":{\"base\":{\"driver\":\"file\",\"path\":\"/project/checkpoints/checkpoint1.orbax-checkpoint-tmp-1690941767098696/\"},\"cache_pool\":\"cache_pool#ocdbt\",\"config\":{\"max_decoded_node_bytes\":100000000,\"max_inline_value_bytes\":1024},\"driver\":\"ocdbt\",\"experimental_read_coalescing_threshold_bytes\":1000000,\"path\":\"step/\"},\"metadata\":{\"chunks\":[],\"compressor\":{\"id\":\"zstd\",\"level\":1},\"shape\":[]},\"open\":true,\"recheck_cached_data\":false,\"recheck_cached_metadata\":false}']
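The operation that actually fails in the trace is TensorStore creating a lock file (`manifest.ocdbt.__lock`) inside the temporary checkpoint directory. A quick, Orbax-independent way to check whether the process can create files at that location at all is a write probe; the probe filename below is made up for illustration:

```python
import os

# Directory the checkpoint would be saved under (same relative path
# as in the reproduction script).
target = 'checkpoints'
os.makedirs(target, exist_ok=True)

# Mimic the failing operation: create a brand-new file in the target
# directory, write to it, then clean it up.
probe = os.path.join(target, '.write_probe')
with open(probe, 'w') as f:
    f.write('ok')
os.remove(probe)
print('write probe succeeded')
```

If this raises `PermissionError` or `FileNotFoundError`, the problem is the directory itself (ownership/permissions), not Orbax.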
Here's the code to reproduce:
import flax.linen as nn
from flax.training import train_state
import optax
import orbax.checkpoint as ocp
import jax
import jax.numpy as jnp
import os

def create_train_state(module, rng):
    x = jnp.ones([1, 256, 256, 1])
    variables = module.init(rng, x)
    params = variables['params']
    tx = optax.adam(1e-3)
    ts = train_state.TrainState.create(
        apply_fn=module.apply, params=params, tx=tx
    )
    return ts

class TestModel(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(4, kernel_size=(3, 3))(x)
        return x

if __name__ == '__main__':
    init_rng = jax.random.PRNGKey(0)
    model = TestModel()
    state = create_train_state(model, init_rng)
    del init_rng
    checkpointer = ocp.Checkpointer(ocp.PyTreeCheckpointHandler(use_ocdbt=True))
    checkpointer.save(os.path.abspath('checkpoints/checkpoint1'), state)
And here's my docker setup:
Dockerfile:
FROM python:3.9.17-slim-bullseye
WORKDIR /project
COPY requirements.txt requirements.txt
RUN python -m pip install --upgrade pip
RUN python -m pip install jupyterlab flax orbax-checkpoint jax
EXPOSE 8888
ENTRYPOINT ["jupyter", "lab", "--ip=0.0.0.0", "--allow-root", "--no-browser", "--NotebookApp.token=''", "--NotebookApp.password=''"]
docker-compose.yaml:
services:
  test:
    build: .
    ports:
      - 8888:8888
    volumes:
      - .:/project
    deploy:
      resources:
        reservations:
          devices:
            - capabilities: [gpu]
I use the following commands to build and enter my docker container:
docker-compose build
docker-compose up -d
docker-compose exec test bash
Then I create the checkpoint directory:
mkdir checkpoints
From here you can run the reproduction code.
I've been able to reproduce this error in a couple of different docker environments, but this one is the simplest. For some reason it does not reproduce in Colab.
I was able to run your code without errors from a simple virtual env without docker, so I think it's something to do with the file permissions of your project and checkpoints folders. You may want to create the checkpoints folder from docker or from the python script instead of with mkdir from the cli. The checkpoint handler needs full permissions on the checkpoint path. Lmk if you can fix it.
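Creating the directory from the script (rather than a pre-existing `mkdir` from the shell) might look like this; the path matches the reproduction script, and the writability check is just a sanity check before handing the path to Orbax:

```python
import os

# Create the checkpoint directory from Python so it is owned by the
# user the script runs as (root inside this container).
ckpt_dir = os.path.abspath('checkpoints')
os.makedirs(ckpt_dir, exist_ok=True)

# Verify the process can actually write there before saving.
assert os.access(ckpt_dir, os.W_OK), f'{ckpt_dir} is not writable'
```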
I just tried running it in ubuntu without docker and it worked, so it seems docker is the problem. Creating the checkpoints folder with python doesn't seem to change anything. I'll keep experimenting.
So the problem was the ownership of my mounted volume. Because I mounted /project from my host machine, I couldn't control its ownership from inside the container. I solved this by adding RUN mkdir /checkpoints to my dockerfile to create a separate directory owned by the root user, then saving checkpoints to that directory instead of the one inside /project. The one problem with this is that everything inside /checkpoints is lost when I stop the container. I got around this by only using /checkpoints as a temp directory and immediately copying any checkpoints to /project/checkpoints after they're created.
Here's the updated dockerfile:
FROM python:3.9.17-slim-bullseye
RUN mkdir /checkpoints
WORKDIR /project
COPY requirements.txt requirements.txt
RUN python -m pip install --upgrade pip
RUN python -m pip install -r requirements.txt
EXPOSE 8888
ENTRYPOINT ["jupyter", "lab", "--ip=0.0.0.0", "--allow-root", "--no-browser", "--NotebookApp.token=''", "--NotebookApp.password=''"]
So the root directory of my container now looks like this:
(usual linux stuff)
/checkpoints
/project
    /checkpoints
    /test.py
    ...
And /checkpoints is owned by the root user:
root@12b24ced75c4:/# ls -ld checkpoints
drwxr-xr-x 2 root root 4096 Aug 2 20:18 checkpoints
Whereas /project, and by extension /project/checkpoints, are not:
root@12b24ced75c4:/# ls -ld project
drwxrwxrwx 1 1000 1000 4096 Aug 2 20:17 project
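For completeness, the same ownership check can be done from Python; `describe` is a hypothetical helper, and inside the container you would point it at `/checkpoints` and `/project` and compare against `os.geteuid()`:

```python
import os

def describe(path):
    """Return (owner uid, group gid, permission bits) for a path,
    i.e. the same numbers `ls -lnd` prints."""
    st = os.stat(path)
    return st.st_uid, st.st_gid, st.st_mode & 0o777

# Inspect the current directory as an example; inside the container
# you'd call describe('/checkpoints') and describe('/project').
uid, gid, mode = describe('.')
print(f'running as uid={os.geteuid()}, '
      f'cwd owned by uid={uid} gid={gid} mode={oct(mode)}')
```

A mismatch between `os.geteuid()` and the directory's owner uid (with no write bit for "other") is exactly the situation that makes the lock-file creation fail.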
Finally, here's the updated python file with the copying technique that I mentioned:
import flax.linen as nn
from flax.training import train_state
import optax
import orbax.checkpoint as ocp
import jax
import jax.numpy as jnp
import os
import shutil

def create_train_state(module, rng):
    x = jnp.ones([1, 256, 256, 1])
    variables = module.init(rng, x)
    params = variables['params']
    tx = optax.adam(1e-3)
    ts = train_state.TrainState.create(
        apply_fn=module.apply, params=params, tx=tx
    )
    return ts

class TestModel(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(4, kernel_size=(3, 3))(x)
        return x

if __name__ == '__main__':
    init_rng = jax.random.PRNGKey(0)
    model = TestModel()
    state = create_train_state(model, init_rng)
    del init_rng
    checkpointer = ocp.Checkpointer(ocp.PyTreeCheckpointHandler(use_ocdbt=True))
    # Save to root owned checkpoints dir.
    checkpointer.save(os.path.abspath('../checkpoints/checkpoint1'), state)
    # Copy from root owned checkpoints dir, to checkpoints dir in mounted volume.
    shutil.copytree('../checkpoints/checkpoint1', 'checkpoints/checkpoint1')
    # Restore from checkpoints dir in mounted volume.
    state = checkpointer.restore(os.path.abspath('checkpoints/checkpoint1'))
@ChromeHearts thanks for pointing me in the right direction.
This is certainly very weird. Your docker was actually running as root, so it shouldn't have issues saving checkpoints directly to the checkpoints folder. In addition, since you could copy the checkpoints from temp to the mounted folder, your python script had no issues writing either. It's definitely something to do with your docker or local mount. I re-did your setup in docker and was able to checkpoint directly to the local mount folder.
sudo docker-compose exec test bash
root@e8a981978d40:/project# ls -l
total 16
-rw-r--r-- 1 1003 1003 268 Aug 3 02:26 Dockerfile
-rw-r--r-- 1 1003 1003 105 Aug 3 02:27 docker-compose.yaml
-rw-r--r-- 1 1003 1003 940 Aug 3 02:07 main.py
drwxr-xr-x 5 1003 1003 4096 Aug 3 02:08 py39
root@e8a981978d40:/project# mkdir checkpoints
root@e8a981978d40:/project# python main.py
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
save_path='/project/checkpoints/checkpoint1'
root@e8a981978d40:/project# ls -l
total 20
-rw-r--r-- 1 1003 1003 268 Aug 3 02:26 Dockerfile
drwxr-xr-x 3 root root 4096 Aug 3 02:35 checkpoints
-rw-r--r-- 1 1003 1003 105 Aug 3 02:27 docker-compose.yaml
-rw-r--r-- 1 1003 1003 940 Aug 3 02:07 main.py
drwxr-xr-x 5 1003 1003 4096 Aug 3 02:08 py39
root@e8a981978d40:/project# find checkpoints
checkpoints
checkpoints/checkpoint1
checkpoints/checkpoint1/checkpoint
checkpoints/checkpoint1/d
checkpoints/checkpoint1/d/ce942794f70ea11a64fa0742f009b653
checkpoints/checkpoint1/d/2b248e926ce267f2604fe6215090a51b
checkpoints/checkpoint1/d/7d4d3dd57291b20fd8866db6035f8025
checkpoints/checkpoint1/manifest.ocdbt
root@e8a981978d40:/project#
The main.py is simply your python script (1st version, without the copy). I managed to save the checkpoint without issues. I suggest avoiding the copy from temp to the mounted volume: docker temp folders are not meant for storing large datasets; they are slow and have limited storage.
It seems that we have at least one solution, so closing this issue.