Tensor operations with masks for PyTorch.
Sometimes you need to perform operations on PyTorch tensors while ignoring masked elements, for example:
>>> import torch
>>> import torchmasked
>>> input = torch.tensor([1., 2., 3.])
>>> result = torch.sum(input)
>>> print(result)
tensor(6.)
>>> mask = torch.tensor([1, 1, 0]).byte()  # 1 = keep, 0 = ignore
>>> masked_result = torchmasked.masked_sum(input, mask)
>>> print(masked_result)
tensor(3.)  # input[2] is masked out and ignored
In such cases, this package can help.
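To make the semantics concrete, the masked sum from the example can be reproduced in plain PyTorch. This is a minimal sketch, not torchmasked's implementation: `masked_sum_sketch` is a hypothetical name, and it assumes, as the example suggests, that nonzero mask entries mean "keep" and zero entries mean "ignore".

```python
import torch


def masked_sum_sketch(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper (not part of torchmasked): sum of the unmasked elements."""
    # Assumption: nonzero mask entries are kept, zero entries are ignored.
    keep = mask.bool()
    # Replace ignored elements with 0 so they contribute nothing to the sum.
    return torch.where(keep, x, torch.zeros_like(x)).sum()


x = torch.tensor([1., 2., 3.])
m = torch.tensor([1, 1, 0]).byte()
print(masked_sum_sketch(x, m))  # tensor(3.)
```

The sketch uses `torch.where` rather than `x * mask` because multiplying an `inf` or `nan` in a masked position by 0 still yields `nan`, whereas `torch.where` discards the masked value entirely.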
Tested on Python 3.6+ and PyTorch 1.4+.
From PyPI:
pip install torchmasked
From source:
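pip install git+https://github.com/Renovamen/torchmasked.git --upgrade
# or, from a local clone of the repository:
git clone https://github.com/Renovamen/torchmasked.git
cd torchmasked
python setup.py install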
Each function is used in the same way as its PyTorch counterpart, with an additional mask argument (as in the example above). Refer to the PyTorch documentation or this package's source code for details.
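Because the functions mirror PyTorch's reductions, a masked reduction along a dimension is the natural extension. The sketch below shows what a masked mean over `dim` amounts to in plain PyTorch. `masked_mean_sketch` and its `dim`/`keepdim` parameters are hypothetical illustrations modeled on `torch.mean`; they are not taken from torchmasked's source.

```python
import torch


def masked_mean_sketch(x, mask, dim=None, keepdim=False):
    """Hypothetical helper (not torchmasked's API): mean over unmasked entries only."""
    # Assumption: nonzero mask entries are kept, zero entries are ignored.
    keep = mask.bool()
    # Zero out ignored elements for the numerator.
    kept_values = torch.where(keep, x, torch.zeros_like(x))
    # Count kept elements in the same dtype as x for the denominator.
    counts = keep.to(x.dtype)
    if dim is None:
        return kept_values.sum() / counts.sum()
    return kept_values.sum(dim=dim, keepdim=keepdim) / counts.sum(dim=dim, keepdim=keepdim)


x = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
m = torch.tensor([[1, 1, 0], [0, 1, 1]]).byte()
print(masked_mean_sketch(x, m, dim=1))  # tensor([1.5000, 5.5000])
print(masked_mean_sketch(x, m))         # tensor(3.5000)
```

Note that a slice whose mask is entirely zero divides 0 by 0 and yields `nan`; a real implementation has to decide how to handle that case.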
- torchmasked.masked_max (masked version of torch.max)
- torchmasked.masked_min (torch.min)
- torchmasked.masked_sum (torch.sum)
- torchmasked.masked_mean (torch.mean)
- torchmasked.masked_softmax (torch.nn.functional.softmax) / torchmasked.nn.MaskedSoftmax (torch.nn.Softmax)
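For the max, min, and softmax variants, a common plain-PyTorch technique is to fill ignored positions with an infinity before applying the ordinary operation. The sketch below illustrates that technique; the three `*_sketch` functions are hypothetical and are not claimed to be how torchmasked implements `masked_max`, `masked_min`, or `masked_softmax`. As before, it assumes nonzero mask entries mean "keep".

```python
import torch
import torch.nn.functional as F


def masked_max_sketch(x, mask, dim):
    # Ignored positions become -inf, so they can never be selected as the max.
    filled = x.masked_fill(~mask.bool(), float("-inf"))
    return filled.max(dim=dim)  # (values, indices), like torch.max(x, dim)


def masked_min_sketch(x, mask, dim):
    # Ignored positions become +inf, so they can never be selected as the min.
    filled = x.masked_fill(~mask.bool(), float("inf"))
    return filled.min(dim=dim)


def masked_softmax_sketch(x, mask, dim):
    # exp(-inf) == 0, so ignored positions receive zero probability and the
    # remaining probabilities along `dim` still sum to 1.
    filled = x.masked_fill(~mask.bool(), float("-inf"))
    return F.softmax(filled, dim=dim)


x = torch.tensor([[1., 5., 3.], [2., 0., 4.]])
m = torch.tensor([[1, 0, 1], [1, 1, 0]]).byte()
print(masked_max_sketch(x, m, dim=1).values)  # tensor([3., 2.])
print(masked_softmax_sketch(x, m, dim=1))
```

If every element along `dim` is masked, the max is `-inf` and the softmax row becomes `nan`, so callers may need to guard against fully masked slices.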