cdpierse / transformers-interpret

Model explainability that works seamlessly with 🤗 transformers. Explain your transformers model in just 2 lines of code.

Multiple GPU usage not working

argideritzalpea opened this issue

I am attempting to use this package by wrapping my model in torch's nn.DataParallel class:

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CreateModel()  # placeholder for building the underlying model
model = nn.DataParallel(model)
model.to(device)

and ensuring that the wrapped model's attributes are still accessible:

class MyDataParallel(nn.DataParallel):
    def __getattr__(self, name):
        try:
            return super().__getattr__(name)  # resolve DataParallel's own attributes, including "module"
        except AttributeError:
            return getattr(self.module, name)  # fall back to the wrapped model
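
As a quick check that the fallback works (a minimal sketch, assuming model is a 🤗 transformers model, which exposes a config attribute on the underlying module):

wrapped = MyDataParallel(model)
print(wrapped.config)  # found on the wrapped model rather than on DataParallel itself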

I get runtime errors about mismatched tensor sizes when I use the SequenceClassificationExplainer. Are there any plans to enable multi-GPU support?

Hi there, I currently don't have any plans to optimize for multiple GPUs, as Captum, on which this package relies, does not support this either. If Captum adds support in the future, I will definitely add feature parity in this package.
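
In the meantime, a single-GPU workaround is to unwrap the DataParallel container before constructing the explainer. A minimal sketch, assuming a fine-tuned sequence classification checkpoint (the distilbert-base-uncased-finetuned-sst-2-english name below is just an example):

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import SequenceClassificationExplainer

name = "distilbert-base-uncased-finetuned-sst-2-english"  # example checkpoint
model = AutoModelForSequenceClassification.from_pretrained(name)
tokenizer = AutoTokenizer.from_pretrained(name)

# If the model was wrapped for multi-GPU training, recover the underlying
# module and run attribution on a single device instead.
if isinstance(model, torch.nn.DataParallel):
    model = model.module
model.to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))

cls_explainer = SequenceClassificationExplainer(model, tokenizer)
word_attributions = cls_explainer("I love you, I like you")

Attribution runs one example at a time, so a single device is usually sufficient for explanation even when training used several GPUs.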