Allow using a custom Process class
ShakedDovrat opened this issue · comments
Thank you for creating this great package.
I would like to create a pipeline where some of the stages use PyTorch (with GPU usage). PyTorch cannot access the GPU from inside a `multiprocessing.Process` subprocess. For that reason PyTorch includes a `torch.multiprocessing.Process` class which has the same API as `multiprocessing.Process`.
I would like the ability to use a custom `Process` class instead of the default `multiprocessing.Process`, so I can use PyTorch in the pipeline. Without it I'm afraid pypeln is unusable for me.
For instance, add an optional `process_class` argument to `map` (and other functions) with a default value of `multiprocessing.Process`.
Alternatively, maybe there's a workaround for what I need that I'm unaware of. In that case, please let me know.
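As a sketch of the idea, a stage runner with a pluggable process class might look like this. The `run_stage` helper and the `process_class` parameter name are hypothetical, not pypeln's actual API; the point is that any class with the `multiprocessing.Process` interface (such as `torch.multiprocessing.Process`) can be dropped in:

```python
import multiprocessing


def _square(x, q):
    # worker target: compute the result and send it back on the queue
    q.put(x * x)


def run_stage(x, process_class=multiprocessing.Process):
    """Run one 'stage' in a worker started from a pluggable Process class.

    Passing e.g. torch.multiprocessing.Process here (hypothetical usage)
    would keep CUDA usable in the child, since both classes share the
    same start/join API.
    """
    q = multiprocessing.Queue()
    p = process_class(target=_square, args=(x, q))
    p.start()
    result = q.get()
    p.join()
    return result


if __name__ == "__main__":
    print(run_stage(7))
```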
Hey @ShakedDovrat! I do believe we can expose a config option to let users specify which `Process` class they want to use. Currently there is a `use_threads` flag which changes a `multiprocessing.Process` to a `threading.Thread` since they follow the same API; adding a `worker_class` option would make this more general. I see two ways of doing this:
1. Add the option to all API functions
The workers are initialized here: `pypeln/pypeln/process/worker.py`, line 224 in f4160d0. Then this is called by the `start` method here: `pypeln/pypeln/process/worker.py`, line 131 in f4160d0. To pass the information through, you need to add the `worker_class` field here: `pypeln/pypeln/process/worker.py`, lines 54 to 66 in f4160d0, and here: `pypeln/pypeln/process/stage.py`, lines 15 to 24 in f4160d0. After that you have to add this option to all public functions that want to expose it.
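Threading the option through the worker could look roughly like this. The `Worker` shape below is a simplified stand-in for pypeln's actual dataclasses, not its real internals:

```python
import multiprocessing
import threading
import typing
from dataclasses import dataclass


@dataclass
class Worker:
    """Simplified sketch of a worker that owns a pluggable worker_class field."""

    target: typing.Callable
    worker_class: type = multiprocessing.Process
    process: typing.Any = None

    def start(self):
        # start() instantiates whichever Process-like class was configured
        self.process = self.worker_class(target=self.target)
        self.process.start()

    def join(self):
        self.process.join()


# usage: swap in threading.Thread (or torch.multiprocessing.Process)
# without changing any other code, since the APIs match
results = []
w = Worker(target=lambda: results.append(1), worker_class=threading.Thread)
w.start()
w.join()
```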
2. Add a context manager
This simplifies a lot of stuff since during `start_workers` you would just have to check whether a global variable is set and use it. The API could look like this:
```python
with pl.process.config(worker_class=torch.multiprocessing.Process):
    # run your pipeline here
```
Option 2 sounds way easier to implement but sets all workers to the same class (which is probably what you want 99% of the time); the other method is more general but requires the user to specify the class per stage, which can be tedious.
Thank you @cgarciae! I will look into it.
BTW: you might want to check https://github.com/pytorch/data