lonePatient / Bert-Multi-Label-Text-Classification

This repo contains a PyTorch implementation of a pretrained BERT model for multi-label text classification.

TypeError: __init__() takes 1 positional argument but 2 were given

FayeXXX opened this issue · comments

class BertForMultiLable(BertPreTrainedModel):
    def __init__(self, config):
        super(BertForMultiLable, self).__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.apply(self.init_weights)

When we try to run the BERT model following your steps, everything goes fine until a TypeError is raised, pointing at the init_weights() call on line 11,

self.apply(self.init_weights)

of bert_for_multi_label.py. We tried to modify it to

self.apply(self._init_weights)

class BertForMultiLable(BertPreTrainedModel):
    def __init__(self, config):
        super(BertForMultiLable, self).__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.apply(self._init_weights)

and after that change it seems to work. I am wondering whether you also encountered this problem, and whether it is caused by the Python version or by the inheritance from the parent class?

If you use torch 1.0.0, you may hit this problem; 1.1.0 is fine.

I have tried 1.1.0, but it still doesn't work. Have you solved this problem?

@FayeXXX just check your pytorch_transformers version. If it is >= 1.2.0, make the following change in your __init__():
replace self.apply(self.init_weights) with self.init_weights()

class BertForMultiLable(BertPreTrainedModel):
    def __init__(self, config):
        super(BertForMultiLable, self).__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()
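If you want code that runs on either side of the 1.2.0 boundary, you can branch on the installed version string first. A minimal sketch; `version_tuple` is a hypothetical helper, not part of the library:

```python
# Sketch: pick the init pattern based on the installed version string.
# version_tuple is a hypothetical helper, not a library function.
def version_tuple(v):
    # "1.2.0" -> (1, 2, 0); drops local suffixes like "+cu101"
    return tuple(int(p) for p in v.split("+")[0].split(".")[:3])

installed = "1.2.0"  # e.g. pytorch_transformers.__version__
if version_tuple(installed) >= (1, 2, 0):
    print("use self.init_weights()")
else:
    print("use self.apply(self.init_weights)")
```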

@lonePatient has already explained it in the readme of the file.

note: for pytorch_transformers >= 1.2.0, modify self.apply(self.init_weights) to self.init_weights() in the pybert/model/nn/bert_for_multi_label.py file.

Try this fix
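For context, the error comes from how nn.Module.apply works: it calls the passed function once per submodule with that submodule as an argument, while in newer library versions init_weights() is a method that takes no module parameter (it calls self.apply(self._init_weights) internally). A minimal pure-Python sketch with mock classes, not the real pytorch_transformers API, reproduces the mismatch:

```python
# Mock stand-ins illustrating the TypeError; NOT the real library classes.
class Module:
    def __init__(self):
        self._children = []

    def apply(self, fn):
        # Like nn.Module.apply: calls fn(submodule) on every child, then fn(self)
        for child in self._children:
            child.apply(fn)
        fn(self)
        return self

class PreTrainedModel(Module):
    def _init_weights(self, module):
        pass  # per-module weight initialization

    def init_weights(self):
        # Newer API: no module argument; applies _init_weights itself
        self.apply(self._init_weights)

model = PreTrainedModel()

# Old pattern: apply passes an extra module argument to init_weights
try:
    model.apply(model.init_weights)
except TypeError as e:
    print(type(e).__name__)  # prints "TypeError"

# New pattern: call init_weights() directly
model.init_weights()
```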