DeepTrackAI / DeepTrack2

Speeding up `PyTorchContinuousGenerator` by copying multiple features into GPU memory

jacopoabramo opened this issue

Greetings,

I'm currently using DeepTrack for a project to build a convolutional model for particle tracking. I started with the model builder class included in the package, but I now need to create a similar model in PyTorch. Since I still want to use DeepTrack's feature system, I'm using the PyTorchContinuousGenerator class to generate my data. So far this has worked quite well. However, I'm now training a quantized version of the model with Brevitas, and the training loop is much slower than the original version. I suspect this is because the generator does not place the data directly in GPU memory, so every batch has to be copied from CPU to GPU memory. Would it be possible to create a new generator class, or adapt the existing one, so that feature data is copied into GPU memory right after it is generated?
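For context, the two strategies in question can be sketched in plain PyTorch. This is a minimal illustration, not DeepTrack code: `samples` is a hypothetical pool of numpy arrays standing in for generated feature output, and `batch_on_host` / `batch_on_device` are illustrative names. Since the continuous generator appears to reuse each stored sample for several batches (see the `usage` counter in the snippets further down), copying each sample to the device once at generation time should avoid repeating the host-to-device copy for every batch.

```python
import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for samples produced by a DeepTrack feature.
samples = [np.random.rand(64, 64, 1).astype(np.float32) for _ in range(8)]


def batch_on_host(samples, idx, batch_size):
    """Strategy 1: keep samples on the host and copy every batch to the device."""
    subset = samples[idx * batch_size:(idx + 1) * batch_size]
    # np.stack builds a new host array; .to(device) then copies it to the
    # device on every call (it is a no-op when device is the CPU).
    return torch.from_numpy(np.stack(subset)).to(device)


# Strategy 2: copy each sample to the device once, when it is generated...
device_samples = [torch.from_numpy(s).to(device) for s in samples]


def batch_on_device(device_samples, idx, batch_size):
    """...then build batches from tensors that already live on the device."""
    return torch.stack(device_samples[idx * batch_size:(idx + 1) * batch_size])
```

Both functions return the same batch; only the place where the copy happens differs.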

Thank you.

Hi!

This should not be difficult to implement. In fact, I think it's something we should consider doing by default. The PyTorch generator is almost identical to the standard ContinuousGenerator; the only difference is that __getitem__ is overridden to convert the internal numpy arrays to torch tensors.

Instead, one could hook into the construct_datapoint function, which is called immediately after a sample is generated. I will have a quick look.
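The hook idea is essentially the template-method pattern: the base class calls an overridable method on every freshly generated sample, and a subclass converts the sample there before delegating back to the base implementation. A minimal pure-Python sketch, where `BaseGenerator` is a hypothetical stand-in for ContinuousGenerator (its real constructor and internals differ) whose records mirror the d["data"] / d["usage"] keys read in `__getitem__` in the snippet that follows, and `convert` stands in for a conversion such as moving data to the GPU:

```python
class BaseGenerator:
    """Hypothetical stand-in for a generator that calls a hook on each new sample."""

    def __init__(self, make_sample):
        self.make_sample = make_sample
        self.current_data = []

    def construct_datapoint(self, image):
        # Default hook: wrap the sample in a record with a usage counter.
        return {"data": image, "usage": 0}

    def add_sample(self):
        image = self.make_sample()
        # The hook runs exactly once per generated sample.
        self.current_data.append(self.construct_datapoint(image))


class ConvertingGenerator(BaseGenerator):
    def construct_datapoint(self, image):
        # Convert once, right after generation, then defer to the base class.
        if isinstance(image, list):
            image = [self.convert(x) for x in image]
        else:
            image = self.convert(image)
        return super().construct_datapoint(image)

    @staticmethod
    def convert(x):
        # Placeholder conversion (hypothetical); e.g. numpy -> torch on GPU.
        return ("converted", x)
```

Because the conversion happens in the hook, a sample that is reused for many batches is converted only once.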

Not tested, but I imagine something like:

class PyTorchContinuousGenerator(ContinuousGenerator):
    """Extends the ContinuousGenerator to support PyTorch models.

    This class is used to generate batches of data for PyTorch models."""

    @staticmethod
    def image_to_torch(image):

        device = "gpu" if torch.cuda.is_available() else "cpu"

        array = image.to_numpy()._value
        torch_array = torch.from_numpy(array).to(device)
        torch_image = Image(torch_array, copy=False)

        torch_image.merge_properties_from(image)

        return torch_image

    def construct_datapoint(self, image: Image):

        if isinstance(image, list):
            image = [self.image_to_torch(x) for x in image]
        else:
            image = self.image_to_torch(x)

        return super().construct_datapoint(image)

    def __getitem__(self, idx):

        batch_size = self._batch_size

        subset = self.current_data[idx * batch_size : (idx + 1) * batch_size]

        for d in subset:
            d["usage"] += 1

        data = [self.batch_function(d["data"]) for d in subset]
        labels = [self.label_function(d["data"]) for d in subset]

        return torch.stack(data), torch.stack(labels)

should work.

Let me know if you need any more help with the implementation based on this.

Hi @BenjaminMidtvedt, thank you for the quick reply, really appreciate it!

I just tested the generator, a couple of things to correct:

  • device = "gpu" if torch.cuda.is_available() else "cpu" -> "gpu" should actually be "cuda". More generally, "gpu" is not a valid device type at all; this is the error I get: RuntimeError: Expected one of cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, ort, mps, xla, lazy, vulkan, meta, hpu, privateuseone device type at start of device string: gpu. Usually cuda is the right choice, since I think it's the most commonly used one.
  • there's a typo in the construct_datapoint method: image = self.image_to_torch(x) -> this should be image = self.image_to_torch(image)
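The first fix reflects standard PyTorch behaviour: device types are strings such as "cpu" and "cuda", and torch.device raises a RuntimeError for an unknown type like "gpu". A small sketch of picking the device safely; `pick_device` and `is_valid_device_string` are illustrative helper names, not part of DeepTrack:

```python
import torch


def pick_device() -> torch.device:
    # PyTorch calls NVIDIA GPUs "cuda"; "gpu" is not a valid device type.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def is_valid_device_string(name: str) -> bool:
    # torch.device raises RuntimeError for unknown device types such as "gpu".
    try:
        torch.device(name)
    except RuntimeError:
        return False
    return True


device = pick_device()
x = torch.zeros(3).to(device)  # falls back to the CPU when CUDA is unavailable
```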

After correcting these 2 things data is generated correctly, but then the __getitem__ method fails with this traceback:

TypeError: no implementation found for 'torch.stack' on types that implement __torch_function__: [<class 'deeptrack.image.Image'>]

I'm assuming that the Image class needs to implement __torch_function__; I just don't know in detail what that implementation would require.
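For reference, PyTorch lets a non-tensor class take part in torch functions through the `__torch_function__` protocol: torch calls this classmethod with the function, the argument types, and the positional and keyword arguments, and the class decides what to do. A minimal sketch of a wrapper that simply unwraps itself and calls the original function. `Wrapped` and `_unwrap` are hypothetical names, and this is not how DeepTrack's Image is actually implemented:

```python
import torch


class Wrapped:
    """Hypothetical minimal wrapper standing in for a class like Image."""

    def __init__(self, value):
        self._value = value

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Called by torch for torch API functions that receive a Wrapped
        # instance, including inside lists such as torch.stack's input.
        kwargs = {} if kwargs is None else kwargs
        return func(*_unwrap(args), **_unwrap(kwargs))


def _unwrap(obj):
    # Recursively replace Wrapped instances by the tensors they hold.
    if isinstance(obj, Wrapped):
        return obj._value
    if isinstance(obj, (list, tuple)):
        return type(obj)(_unwrap(v) for v in obj)
    if isinstance(obj, dict):
        return {k: _unwrap(v) for k, v in obj.items()}
    return obj
```

With this in place, torch.stack([Wrapped(a), Wrapped(b)]) returns an ordinary tensor instead of raising the TypeError above.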

Right! Forgot that we don't have bindings for torch functions yet. Will be added.

You can simply do torch.stack([x._value for x in data]) to expose the underlying torch tensors. You no longer need the Image wrapper.

Could you be more specific? I tried applying your suggested change but I get AttributeError: 'numpy.ndarray' object has no attribute '_value'. I'm sure I'm forgetting something but it's not clear what.

I see. You have a mixture of numpy arrays and torch tensors at this point?

You can do

data = [self.batch_function(dt.image.strip(d["data"])) for d in subset]
labels = [self.label_function(dt.image.strip(d["data"])) for d in subset]

return torch.stack(data), torch.stack(labels)

which should be more stable. strip removes the Image wrapper correctly and safely. Then you just need to ensure that the batch_function / label_function you provide accept and return torch tensors. This will depend on the specifics of your application.
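The difference from x._value is that the attribute access only works on wrapped objects, which is why it raised AttributeError on a plain numpy array. A strip-style helper can be sketched as a recursive unwrap that passes anything else through unchanged. This is a hypothetical illustration using a stand-in `Wrapper` class, not DeepTrack's actual dt.image.strip:

```python
class Wrapper:
    """Hypothetical wrapper holding a value in `_value`, like an Image."""

    def __init__(self, value):
        self._value = value


def strip(value):
    # Unwrap (possibly nested) Wrapper instances and recurse into lists and
    # tuples; anything else (numpy arrays, torch tensors, numbers) is
    # returned unchanged instead of raising AttributeError.
    if isinstance(value, Wrapper):
        return strip(value._value)
    if isinstance(value, (list, tuple)):
        return type(value)(strip(v) for v in value)
    return value
```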

Hi @BenjaminMidtvedt, sorry for the late reply. Your suggestion helped me figure out what to do. In the end, this is the code snippet I used:

import torch
import deeptrack as dt
from deeptrack.generators import ContinuousGenerator
from deeptrack.image import Image

class PyTorchContinuousGenerator(ContinuousGenerator):
    """Extends the ContinuousGenerator to support PyTorch models.

    This class is used to generate batches of data for PyTorch models."""

    @staticmethod
    def image_to_torch(image):

        device = "cuda" if torch.cuda.is_available() else "cpu"

        array = image.to_numpy()._value
        torch_array = torch.from_numpy(array).to(device)
        torch_image = Image(torch_array, copy=False)

        torch_image.merge_properties_from(image)

        return torch_image

    def construct_datapoint(self, image: Image):

        if isinstance(image, list):
            image = [self.image_to_torch(x) for x in image]
        else:
            image = self.image_to_torch(image)

        return super().construct_datapoint(image)

    def __getitem__(self, idx):

        batch_size = self._batch_size

        subset = self.current_data[idx * batch_size : (idx + 1) * batch_size]

        for d in subset:
            d["usage"] += 1

        data = [self.batch_function(dt.image.strip(d["data"])) for d in subset]
        labels = [self.label_function(d["data"]) for d in subset]

        return torch.stack(data), torch.stack(labels)

def get_normalized_torch_label(image: dt.Image) -> torch.Tensor:
    return torch.from_numpy(image.get_property("position") / IMAGE_SIZE)

torch_data_generator = PyTorchContinuousGenerator(
    feature = imaged_particle,
    label_function = get_normalized_torch_label,
    min_data_size = 10,
    max_data_size = 20,
    batch_size = batch_size,
    shuffle_batch = True
)
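In this snippet the data tensors are moved to the GPU in construct_datapoint, while get_normalized_torch_label builds the labels on the CPU, with whatever dtype the stored position has. If the labels should live on the same device as the data, one option is a label function along the following lines. It is a sketch that assumes `position` is an array-like of coordinates. `IMAGE_SIZE` is given a placeholder value, and `normalized_label_on_device` is a hypothetical name:

```python
import numpy as np
import torch

IMAGE_SIZE = 64  # placeholder value; use the image size from your own pipeline
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def normalized_label_on_device(position) -> torch.Tensor:
    # position: array-like, e.g. what image.get_property("position") returns.
    # np.asarray also accepts tuples/lists; float32 matches torch's default dtype.
    label = torch.from_numpy(np.asarray(position, dtype=np.float32) / IMAGE_SIZE)
    # Put the label on the same device as the data tensors.
    return label.to(device)
```

It could be passed to the generator as label_function=lambda image: normalized_label_on_device(image.get_property("position")).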

Training is now much faster, thank you very much. The implementation is a bit unbalanced, but for what I need it is more than enough. I guess the remaining step on your side is to implement __torch_function__ in the Image class, so that it can be adapted more easily for PyTorch generators. I'll close the issue.