luopeixiang / named_entity_recognition

中文命名实体识别(包括多种模型:HMM,CRF,BiLSTM,BiLSTM+CRF的具体实现)

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

代码是否多余?

WeiYangBin opened this issue · comments

    def train_step(self, batch_sents, batch_tags, word2id, tag2id):
        self.model.train()
        self.step += 1
        # 准备数据
        tensorized_sents, lengths = tensorized(batch_sents, word2id)
        tensorized_sents = tensorized_sents.to(self.device)
        targets, lengths = tensorized(batch_tags, tag2id)
        targets = targets.to(self.device)

        # forward
        scores = self.model(tensorized_sents, lengths)

        # 计算损失 更新参数
        self.optimizer.zero_grad()
        loss = self.cal_loss_func(scores, targets, tag2id).to(self.device)
        loss.backward()
        self.optimizer.step()

        return loss.item()

以上代码来自与model/bilstm_crf.py文件
我想问下train_step函数下的self.model.train()起到一个什么作用

self.model.train()可以让model变成训练模式,一些操作例如 dropout和batch normalization的只有在model.train()的情况下才会生效。

我刚刚试着将这行代码删除,print的结果是几乎一样的