myownskyW7 / CARAFE

This repo is the official CUDA implementation of the ICCV 2019 oral paper CARAFE: Content-Aware ReAssembly of FEatures

Not working in half precision

tetelias opened this issue

Installed correctly; grad_check runs without errors, as does the sample code. When I try to use either native torch amp or the original NVIDIA one, I get the same error:

  File "/home/someone/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/someone/anaconda3/lib/python3.7/site-packages/torch/nn/modules/container.py", line 117, in forward
    input = module(input)
  File "/home/someone/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/someone/anaconda3/lib/python3.7/site-packages/carafe/carafe.py", line 250, in forward
    x = self.feature_reassemble(x, mask)
  File "/home/someone/anaconda3/lib/python3.7/site-packages/carafe/carafe.py", line 242, in feature_reassemble
    x = carafe(x, mask, self.up_kernel, self.up_group, self.scale_factor)
  File "/home/someone/anaconda3/lib/python3.7/site-packages/carafe/carafe.py", line 114, in forward
    group_size, scale_factor, routput, output)
RuntimeError: expected scalar type Half but found Float

I have the same problem. Did you solve it?

I have solved it. The type of output is torch.float16, but the type of masks and rmasks is torch.float32, so cast masks and rmasks to half before calling the extension:

        if features.is_cuda:
            masks = masks.type(torch.half)
            rmasks = rmasks.type(torch.half)
            carafe_ext.forward(features, rfeatures, masks, rmasks, kernel_size,
                               group_size, scale_factor, routput, output)

I updated the code for this issue so that the cast only happens for half-precision inputs; it worked for me.

        if features.is_cuda:
            # Under amp, features arrive as float16 while masks stay float32;
            # cast the masks so the extension sees a single dtype.
            if features.dtype == torch.half:
                masks = masks.type(torch.half)
                rmasks = rmasks.type(torch.half)
            carafe_ext.forward(features, rfeatures, masks, rmasks, kernel_size,
                               group_size, scale_factor, routput, output)
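
The same idea can be written without hard-coding torch.half, by casting the masks to whatever dtype features has. What follows is only a minimal sketch, not code from the CARAFE repo: cast_like is a hypothetical helper name, and the tensors are small CPU stand-ins with arbitrary shapes so the snippet runs without the CUDA extension.

```python
import torch


def cast_like(reference, *tensors):
    """Return `tensors` cast to the dtype of `reference`.

    Hypothetical helper (not part of carafe). Tensors that already
    have the right dtype are returned unchanged, without a copy.
    """
    return tuple(t.to(reference.dtype) for t in tensors)


# CPU stand-ins for the tensors named in the thread; shapes are arbitrary.
features = torch.randn(1, 4, 8, 8).half()                  # float16, as under amp
masks = torch.softmax(torch.randn(1, 25, 16, 16), dim=1)   # float32
rmasks = torch.zeros_like(masks)                           # float32

# Align the mask dtypes with the features before the extension call.
masks, rmasks = cast_like(features, masks, rmasks)
print(masks.dtype, rmasks.dtype)
```

In carafe.py the cast would go immediately before carafe_ext.forward(...), in place of the explicit torch.half check. The thread only covers the forward call, so whether the backward call needs the same treatment is not established here.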