IntelLabs / bayesian-torch

A library for Bayesian neural network layers and uncertainty estimation in Deep Learning extending the core of PyTorch

dnn_to_bnn() with LSTM network

jeiglsperger opened this issue

I have a problem with the dnn_to_bnn() transformation of my LSTM network.
I define my LSTM network as follows:

model = []
n_layers = self.suggest_hyperparam_to_optuna('n_layers')
p = self.suggest_hyperparam_to_optuna('dropout_rate')
n_feature = self.n_features
lstm_hidden_dim = self.suggest_hyperparam_to_optuna('lstm_hidden_dim')

model.append(PrepareForlstm())
model.append(torch.nn.LSTM(input_size=n_feature, hidden_size=lstm_hidden_dim, num_layers=n_layers,
                           dropout=p))
model.append(GetOutputZero())
model.append(PrepareForDropout())
model.append(torch.nn.Dropout(p))
model.append(torch.nn.Linear(in_features=lstm_hidden_dim, out_features=self.n_outputs))
model = torch.nn.Sequential(*model)

The classes PrepareForlstm, GetOutputZero, and PrepareForDropout are defined as follows:

import torch


# Module that keeps only the first output of an LSTM layer (the output sequence), dropping (h_n, c_n)
class GetOutputZero(torch.nn.Module):
    def __init__(self):
        super(GetOutputZero, self).__init__()

    def forward(self, x):
        lstm_out, (hn, cn) = x
        return lstm_out


# Module that reshapes the input into the 3-D layout expected by the LSTM layer
class PrepareForlstm(torch.nn.Module):
    def __init__(self):
        super(PrepareForlstm, self).__init__()

    def forward(self, x):
        return x.view(x.shape[0], x.shape[1], -1)


# Module that keeps only the last entry along dimension 1 of the LSTM output, before the dropout layer
class PrepareForDropout(torch.nn.Module):
    def __init__(self):
        super(PrepareForDropout, self).__init__()

    def forward(self, lstm_out):
        return lstm_out[:, -1, :]
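
To make the tensor layouts in this pipeline explicit, here is a minimal shape trace of the original (non-Bayesian) layers. It is a sketch under assumptions: the sizes (sequence length 10, batch size 4, 23 features, hidden size 45, num_layers=2) are hypothetical and not taken from the issue, and the three helper modules are re-declared so the snippet runs on its own. The only library behavior it relies on is that torch.nn.LSTM defaults to batch_first=False, so its input and output are laid out as (seq_len, batch, features).

```python
import torch


# Re-declared helper modules from the issue so this sketch is self-contained.
class PrepareForlstm(torch.nn.Module):
    def forward(self, x):
        return x.view(x.shape[0], x.shape[1], -1)


class GetOutputZero(torch.nn.Module):
    def forward(self, x):
        lstm_out, (hn, cn) = x
        return lstm_out


class PrepareForDropout(torch.nn.Module):
    def forward(self, lstm_out):
        return lstm_out[:, -1, :]


# Hypothetical sizes chosen for illustration only.
SEQ_LEN, BATCH, N_FEATURE, HIDDEN = 10, 4, 23, 45


def trace_shapes(batch_first):
    """Run the layers before Dropout/Linear and record each intermediate shape."""
    lstm = torch.nn.LSTM(input_size=N_FEATURE, hidden_size=HIDDEN,
                         num_layers=2, dropout=0.1, batch_first=batch_first)
    # Lay the dummy input out the way the LSTM expects for this setting.
    if batch_first:
        x = torch.randn(BATCH, SEQ_LEN, N_FEATURE)
    else:
        x = torch.randn(SEQ_LEN, BATCH, N_FEATURE)
    shapes = {"input": tuple(x.shape)}
    h = PrepareForlstm()(x)            # leaves a 3-D input unchanged
    shapes["prepared"] = tuple(h.shape)
    out = GetOutputZero()(lstm(h))     # keep the output sequence, drop (h_n, c_n)
    shapes["lstm_out"] = tuple(out.shape)
    last = PrepareForDropout()(out)    # index -1 along dimension 1
    shapes["selected"] = tuple(last.shape)
    return shapes


# With the default batch_first=False, dimension 1 is the batch, so [:, -1, :]
# keeps the last sample of the batch at every time step: shape (10, 45).
default_shapes = trace_shapes(batch_first=False)
# With batch_first=True, dimension 1 is time, so [:, -1, :] keeps the last
# time step of every sample: shape (4, 45).
batch_first_shapes = trace_shapes(batch_first=True)

# A 2-D input is reshaped by PrepareForlstm: (4, 23) becomes (4, 23, 1).
flat = PrepareForlstm()(torch.randn(BATCH, N_FEATURE))
print(default_shapes, batch_first_shapes, tuple(flat.shape))
```

Which of the two layouts is intended depends on how the input sequences are built; the trace only shows that PrepareForDropout selects a different slice under each setting.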

This network works fine for me. Now I wanted to try out the dnn_to_bnn() API and transform it into a Bayesian LSTM with:

from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn

bnn_prior_parameters = {
    "prior_mu": self.suggest_hyperparam_to_optuna('prior_mu'),
    "prior_sigma": self.suggest_hyperparam_to_optuna('prior_sigma'),
    "posterior_mu_init": self.suggest_hyperparam_to_optuna('posterior_mu_init'),
    "posterior_rho_init": self.suggest_hyperparam_to_optuna('posterior_rho_init'),
    "type": self.suggest_hyperparam_to_optuna('type'),
    "moped_enable": self.suggest_hyperparam_to_optuna('moped_enable')
}
# Replaces the supported layers of `model` with Bayesian counterparts in place
dnn_to_bnn(model, bnn_prior_parameters)
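
Because every entry of bnn_prior_parameters is suggested by Optuna here, a quick sanity check before calling dnn_to_bnn() can catch invalid combinations early. The sketch below is a hypothetical helper: check_bnn_prior_parameters is not part of bayesian-torch. The key names are the ones used in the issue. Two points are assumptions based on the example dictionary in the bayesian-torch README and should be checked against the installed version: that "type" is either "Reparameterization" or "Flipout", and that a "moped_delta" entry is expected when "moped_enable" is True. The positivity check on prior_sigma reflects only that a Gaussian prior scale must be positive.

```python
# Hypothetical helper, not part of bayesian-torch: validates a prior dictionary
# before it is handed to dnn_to_bnn().
REQUIRED_KEYS = ("prior_mu", "prior_sigma", "posterior_mu_init",
                 "posterior_rho_init", "type", "moped_enable")
# Assumption based on the bayesian-torch README: the two supported layer types.
ALLOWED_TYPES = ("Reparameterization", "Flipout")


def check_bnn_prior_parameters(params):
    """Return a list of problems found in params (an empty list means it looks sane)."""
    problems = [f"missing key: {key}" for key in REQUIRED_KEYS if key not in params]
    if params.get("type") not in ALLOWED_TYPES:
        problems.append(f"type must be one of {ALLOWED_TYPES}, got {params.get('type')!r}")
    sigma = params.get("prior_sigma")
    if isinstance(sigma, (int, float)) and sigma <= 0:
        problems.append("prior_sigma should be positive")
    # Assumption: with MOPED enabled, a moped_delta entry is also expected,
    # as in the README example.
    if params.get("moped_enable") and "moped_delta" not in params:
        problems.append("moped_enable is True but moped_delta is missing")
    return problems


# Illustrative values in the style of the README example.
example = {
    "prior_mu": 0.0,
    "prior_sigma": 1.0,
    "posterior_mu_init": 0.0,
    "posterior_rho_init": -3.0,
    "type": "Reparameterization",
    "moped_enable": False,
}
print(check_bnn_prior_parameters(example))  # → []
```

Inside an Optuna objective, a non-empty result could be used to prune the trial before the model is converted, instead of failing later in the forward pass.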

When executing my code, I get the following traceback:

  File "/home/josef/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/josef/.local/lib/python3.8/site-packages/torch/nn/modules/container.py", line 117, in forward
    input = module(input)
  File "/home/josef/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/josef/.local/lib/python3.8/site-packages/bayesian_torch/layers/variational_layers/rnn_variational.py", line 126, in forward
    ff_i, kl_i = self.ih(x_t)
  File "/home/josef/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/josef/.local/lib/python3.8/site-packages/bayesian_torch/layers/variational_layers/linear_variational.py", line 164, in forward
    out = F.linear(input, weight, bias)
  File "/home/josef/.local/lib/python3.8/site-packages/torch/nn/functional.py", line 1690, in linear
    ret = torch.addmm(bias, input, weight.t())
RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x1 and 23x180)

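So the shape of the input does not match the shape of the weight in out = F.linear(input, weight, bias) at line 164 of linear_variational.py. Since F.linear calls torch.addmm(bias, input, weight.t()), mat1 is the 4x1 time-step slice x_t and mat2 is the transposed 23x180 weight: the variational linear layer inside the converted LSTM expects 23 input features per step but receives only 1. I tried to trace back where these shapes come from, but I have no idea why the network no longer works. Does anyone have a clue?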
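
The shape rule behind this error can be sketched in plain Python. Per the traceback, F.linear ends up in torch.addmm(bias, input, weight.t()), so the product is defined only when the input's last dimension equals the weight's in_features. linear_output_shape is a hypothetical helper that mirrors that check for 2-D inputs and a weight of shape (out_features, in_features); it is not a PyTorch function, and the weight shape (180, 23) is simply read back from the 23x180 transposed matrix in the error message.

```python
def linear_output_shape(input_shape, weight_shape):
    """Shape of input @ weight.t() for a 2-D input and a (out_features, in_features) weight."""
    n_rows, n_in = input_shape
    out_features, in_features = weight_shape
    if n_in != in_features:
        # weight.t() has shape (in_features, out_features); the error reports it as mat2.
        raise ValueError(
            f"mat1 and mat2 shapes cannot be multiplied "
            f"({n_rows}x{n_in} and {in_features}x{out_features})"
        )
    return (n_rows, out_features)


# A well-formed call: 4 rows of 23 features into a layer whose weight is (180, 23).
print(linear_output_shape((4, 23), (180, 23)))  # → (4, 180)

# The shapes from the traceback: a 4x1 input slice against a (180, 23) weight.
try:
    linear_output_shape((4, 1), (180, 23))
except ValueError as err:
    print(err)  # → mat1 and mat2 shapes cannot be multiplied (4x1 and 23x180)
```

In other words, the number of features in each time-step slice reaching the converted LSTM has to equal the input_size the layer was built with.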

I realized I had made some mistakes in the architecture of my framework (e.g. in how the sequences were created) and in the hyperparameter selection. After fixing those, it works, so I am closing this issue.