dreamquark-ai / tabnet

PyTorch implementation of the TabNet paper: https://arxiv.org/pdf/1908.07442.pdf

Home Page: https://dreamquark-ai.github.io/tabnet/


Support for complex-valued datasets

SantaTitular opened this issue

Feature request

Would it be possible for TabNet to support complex-valued datasets, such as radar data, complex-valued time signals, etc.?

What is the expected behavior?
The behavior should be exactly the same. Only the output representation would change, since the output can be visualized as magnitude, phase, real part, or imaginary part.

What is motivation or use case for adding/changing the behavior?
For electromagnetic (EM) and radar datasets, the phase carries useful information that is often lost when the data is transformed to real values beforehand.

How should this be implemented in your opinion?
There are already lightweight implementations of complex-valued NNs in PyTorch, e.g. complextorch. The main change is how the weights are set up (one set for the real part and another for the imaginary part).
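To make the "two sets of weights" idea concrete, here is a minimal NumPy sketch (not the complextorch API) of a complex-valued linear map built from two real weight matrices. Writing W = W_r + iW_i and x = x_r + ix_i, the product splits into a real part W_r x_r - W_i x_i and an imaginary part W_r x_i + W_i x_r, so only real arithmetic is needed. The helper name `complex_linear` is hypothetical.

```python
import numpy as np

def complex_linear(x_re, x_im, w_re, w_im):
    """Compute (w_re + i*w_im) @ (x_re + i*x_im) using only real matrix products.

    Hypothetical helper for illustration: w_* have shape (n_out, n_in),
    x_* have shape (n_in,) or (n_in, batch).
    """
    # Real part: W_r x_r - W_i x_i
    out_re = w_re @ x_re - w_im @ x_im
    # Imaginary part: W_r x_i + W_i x_r
    out_im = w_re @ x_im + w_im @ x_re
    return out_re, out_im

rng = np.random.default_rng(0)
w = rng.normal(size=(3, 4)) + 1j * rng.normal(size=(3, 4))
x = rng.normal(size=4) + 1j * rng.normal(size=4)

out_re, out_im = complex_linear(x.real, x.imag, w.real, w.imag)
# Agrees with NumPy's native complex matrix-vector product
assert np.allclose(out_re + 1j * out_im, w @ x)
```

The same decomposition is what lets a complex layer be emulated with two ordinary real-valued layers when a framework's layers do not accept complex inputs directly.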

Are you willing to work on this yourself?
yes

This does not seem like a very common use case. I guess you can already use the attention grouping strategy to represent complex numbers by grouping each rho and theta pair together. Would that do the trick?
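A minimal sketch of that preprocessing, assuming the complex features arrive as an (n_samples, n_complex) NumPy array: each complex column becomes an adjacent (rho, theta) pair, and one index group is built per original complex feature. The helper name `complex_to_polar_features` is hypothetical.

```python
import numpy as np

def complex_to_polar_features(X):
    """Turn an (n_samples, n_complex) complex array into real features.

    Complex column k becomes magnitude (rho) at column 2k and phase (theta)
    at column 2k + 1. Also returns one [rho, theta] index group per
    original complex feature.
    """
    n_samples, n_complex = X.shape
    feats = np.empty((n_samples, 2 * n_complex), dtype=np.float64)
    feats[:, 0::2] = np.abs(X)    # rho
    feats[:, 1::2] = np.angle(X)  # theta, in radians within (-pi, pi]
    groups = [[2 * k, 2 * k + 1] for k in range(n_complex)]
    return feats, groups

X = np.array([[1 + 1j, -2 + 0j],
              [0 + 3j, 1 - 1j]])
feats, groups = complex_to_polar_features(X)
print(groups)  # [[0, 1], [2, 3]]
```

The groups could then be handed to the model, e.g. `TabNetRegressor(grouped_features=groups)`, assuming the installed pytorch-tabnet version exposes the `grouped_features` argument (check its documentation). Note that theta wraps around at ±pi; encoding the phase as (cos theta, sin theta) instead avoids that discontinuity at the cost of one extra column per feature.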

Thanks for the fast reply!

Yes, I agree that it is quite niche, but I have found some papers showing its use in particular physics-related applications.

I'm still resolving some environment dependencies to install TabNet, but I imagine that, theoretically, you could. However, I'm not sure that TabNet's torch layers accept the cfloat dtype, in which case we would need a 2D convolution for every layer. Since TabNet is implemented at such a high level, I'm not sure whether it would be simpler (or even possible) to just use a wrapper like the one below. Also, to train the network for standard classification on complex numbers, you would typically need to apply a magnitude (abs) at the output, or simply use regression with an MSE loss such as:

import torch
import torch.nn as nn
import torch.nn.functional as F

def complex_relu(x):
    # Apply ReLU separately to the real and imaginary parts.
    # torch.complex keeps the matching precision (float32 -> complex64).
    return torch.complex(F.relu(x.real), F.relu(x.imag))

class ComplexReLU(nn.Module):
    # nn.Module wrapper so complex_relu can be used inside nn.Sequential
    def forward(self, x):
        return complex_relu(x)

class ComplexMSELoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, inputs, targets):
        if inputs.is_complex():
            # Mean squared modulus of the error, mean(|inputs - targets|^2),
            # i.e. the mean of (real diff)^2 + (imag diff)^2
            diff = inputs - targets
            return torch.mean(diff.real ** 2 + diff.imag ** 2)
        else:
            # Real-valued outputs: fall back to the standard MSE
            return F.mse_loss(inputs, targets)

Edit: forgot the wrapper.
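For reference, the complex MSE that the loss above is meant to compute (stated here with the usual definition) averages the squared modulus of the error, which equals the sum of the squared real and imaginary errors:

```latex
\mathrm{MSE}(\hat{z}, z)
  = \frac{1}{N}\sum_{n=1}^{N} \lvert \hat{z}_n - z_n \rvert^{2}
  = \frac{1}{N}\sum_{n=1}^{N} \Big[ (\operatorname{Re}\hat{z}_n - \operatorname{Re} z_n)^{2}
    + (\operatorname{Im}\hat{z}_n - \operatorname{Im} z_n)^{2} \Big]
```

Adding the two real-valued differences before squaring, (Δ_re + Δ_im)^2, gives a different quantity: it contains an extra cross term 2 Δ_re Δ_im, so errors of opposite sign in the real and imaginary parts can cancel.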

Are you predicting complex numbers as well?

You can still preprocess your dataset so that complex numbers are already split into real and imaginary parts. Make sure that you group the inputs so that attention is applied per complex input, not separately to the real and imaginary parts. For the outputs, you can use TabNetRegressor to predict both the real and imaginary parts together.
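A minimal sketch of that preprocessing in NumPy: complex inputs and targets are interleaved into (real, imag) column pairs, one attention group is built per complex input, and real-valued predictions are mapped back to complex numbers. The helper names `to_real_imag`, `from_real_imag` and `pair_groups` are hypothetical.

```python
import numpy as np

def to_real_imag(Z):
    """Interleave parts: complex column k -> real columns 2k (real) and 2k + 1 (imag)."""
    Z = np.asarray(Z).reshape(len(Z), -1)  # 1D targets become a single column
    out = np.empty((Z.shape[0], 2 * Z.shape[1]), dtype=np.float64)
    out[:, 0::2] = Z.real
    out[:, 1::2] = Z.imag
    return out

def from_real_imag(R):
    """Inverse of to_real_imag: pair columns 2k and 2k + 1 back into one complex column."""
    return R[:, 0::2] + 1j * R[:, 1::2]

def pair_groups(n_complex):
    """One [real, imag] index group per complex input feature."""
    return [[2 * k, 2 * k + 1] for k in range(n_complex)]

X = np.array([[1 + 2j, 3 - 1j],
              [0.5j, -4 + 0j]])        # 2 samples, 2 complex inputs
y = np.array([2 - 1j, 1 + 1j])        # 1 complex target per sample

X_real = to_real_imag(X)              # shape (2, 4)
y_real = to_real_imag(y)              # shape (2, 2): a 2D multi-output target
groups = pair_groups(X.shape[1])      # [[0, 1], [2, 3]]

# Real-valued predictions in the same layout map back to complex values
assert np.allclose(from_real_imag(y_real)[:, 0], y)
```

With pytorch-tabnet this would look roughly like `TabNetRegressor(grouped_features=groups)` followed by `fit(X_real, y_real)`, assuming an installed version that supports `grouped_features`; the regressor's predictions can then be passed through `from_real_imag`.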

That would be the simplest way in my opinion, and it is what I would try before rewriting the library. However, if you want to try it, feel free to open a PR, but making this easily usable looks like a lot of work and would require big changes.

It depends on the strategy! I've read papers that use complex numbers as teacher signals, and others that simply apply a function mapping the complex domain to the real one.

I agree with you; I'll try different strategies before opening a PR, then. I finally managed to get a clean installation!
Edit: removed an unnecessary add-on.