webbery / pytorch-transformer-chatbot

A simple chitchat chatbot built on the Transformer API introduced in PyTorch v1.2

PyTorch_Transformer_Chatbot

A simple Chinese generative chatbot implementation based on the new PyTorch Transformer API (PyTorch v1.x / Python 3.x)

[Figure: transformer_fig]

ToDo

  • Dynamic Memory Networks
  • Beam Search
  • Search hyperparams
  • Attention Visualization
def forward(self, enc_input: torch.Tensor, dec_input: torch.Tensor) -> torch.Tensor:
    # enc_input: (N, S) and dec_input: (N, T) batches of token ids
    x_enc_embed = self.input_embedding(enc_input.long())  # (N, S, E)
    x_dec_embed = self.input_embedding(dec_input.long())  # (N, T, E)

    # Masking: True marks padding positions that attention must ignore
    src_key_padding_mask = enc_input == self.vocab.PAD_ID # tensor([[False, False, False,  True,  ...,  True]])
    tgt_key_padding_mask = dec_input == self.vocab.PAD_ID
    memory_key_padding_mask = src_key_padding_mask
    # Causal mask: decoder position i may only attend to positions <= i
    tgt_mask = self.transfomrer.generate_square_subsequent_mask(dec_input.size(1))

    # nn.Transformer (without batch_first) expects sequence-first inputs,
    # so swap the batch and sequence axes: (N, S, E) -> (S, N, E)
    # einsum ref: https://pytorch.org/docs/stable/torch.html#torch.einsum
    # https://obilaniu6266h16.wordpress.com/2016/02/04/einstein-summation-in-numpy/
    x_enc_embed = torch.einsum('ijk->jik', x_enc_embed)
    x_dec_embed = torch.einsum('ijk->jik', x_dec_embed)

    # transformer ref: https://pytorch.org/docs/stable/nn.html#torch.nn.Transformer
    # The key padding masks stay batch-first: (N, S) and (N, T)
    feature = self.transfomrer(src = x_enc_embed,
                               tgt = x_dec_embed,
                               src_key_padding_mask = src_key_padding_mask,
                               tgt_key_padding_mask = tgt_key_padding_mask,
                               memory_key_padding_mask=memory_key_padding_mask,
                               tgt_mask = tgt_mask.to(dec_input.device)) # src: (S,N,E) tgt: (T,N,E)

    logits = self.proj_vocab_layer(feature)     # (T, N, vocab_size)
    logits = torch.einsum('ijk->jik', logits)   # back to batch-first: (N, T, vocab_size)

    return logits
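The masks and the axis swap at the top of `forward` can be illustrated without PyTorch. The plain-Python sketch below is only an approximation for reading purposes and assumes a hypothetical `PAD_ID = 1`: `key_padding_mask` mirrors `enc_input == self.vocab.PAD_ID` (True marks a position attention should skip), `subsequent_mask` mirrors the float mask from `generate_square_subsequent_mask` (0.0 where position i may attend to j <= i, -inf above the diagonal), and `to_seq_first` mimics the `'ijk->jik'` einsum that turns (N, S, E) into (S, N, E). The helper names are hypothetical.

```python
import math
from typing import List

PAD_ID = 1  # hypothetical padding id; the real value comes from the vocabulary


def key_padding_mask(batch: List[List[int]], pad_id: int = PAD_ID) -> List[List[bool]]:
    """True where a token is padding, like `enc_input == PAD_ID` (stays (N, S))."""
    return [[tok == pad_id for tok in seq] for seq in batch]


def subsequent_mask(size: int) -> List[List[float]]:
    """Causal mask: 0.0 on and below the diagonal, -inf above it."""
    return [[0.0 if j <= i else -math.inf for j in range(size)] for i in range(size)]


def to_seq_first(x):
    """Swap the first two axes of a nested (N, S, E) list, giving (S, N, E)."""
    return [[x[n][s] for n in range(len(x))] for s in range(len(x[0]))]


if __name__ == "__main__":
    batch = [[5, 8, 2, 1, 1], [7, 3, 9, 4, 2]]  # two right-padded id sequences
    print(key_padding_mask(batch))
    for row in subsequent_mask(3):
        print(row)
```

Note that only the embeddings are transposed: PyTorch expects key padding masks in (N, S) shape regardless of the sequence-first layout, which is why `forward` passes them unchanged.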

Experiments

Run order

python build_vocab.py # build the vocabulary
python train.py # train the seq2seq model
python inference.py # run inference tests
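`build_vocab.py` produces the vocabulary whose `PAD_ID` the model relies on for masking. As a rough, hypothetical illustration only (the real script may work quite differently, and its special-token names and ids are not shown here), the sketch below maps tokens to ids with reserved special tokens and right-pads a batch to equal length. `SPECIALS`, `build_vocab`, and `encode_batch` are names invented for this example.

```python
from collections import Counter
from typing import Dict, List

# Hypothetical special tokens; the real vocabulary may use other names and ids.
SPECIALS = ["<unk>", "<pad>", "<bos>", "<eos>"]


def build_vocab(sentences: List[List[str]], min_freq: int = 1) -> Dict[str, int]:
    """Map each token to an id, reserving the first ids for special tokens."""
    counts = Counter(tok for sent in sentences for tok in sent)
    vocab = {tok: i for i, tok in enumerate(SPECIALS)}
    # Most frequent tokens first; ties broken alphabetically for determinism.
    for tok, freq in sorted(counts.items(), key=lambda kv: (-kv[1], kv[0])):
        if freq >= min_freq and tok not in vocab:
            vocab[tok] = len(vocab)
    return vocab


def encode_batch(sentences: List[List[str]], vocab: Dict[str, int]) -> List[List[int]]:
    """Convert tokens to ids (unknown -> <unk>) and right-pad to the longest sequence."""
    unk, pad = vocab["<unk>"], vocab["<pad>"]
    ids = [[vocab.get(tok, unk) for tok in sent] for sent in sentences]
    max_len = max(len(seq) for seq in ids)
    return [seq + [pad] * (max_len - len(seq)) for seq in ids]


if __name__ == "__main__":
    vocab = build_vocab([["hello", "there"], ["hello"]])
    print(vocab)
    print(encode_batch([["hello", "there"], ["hi"]], vocab))
```

The padded positions carry the `<pad>` id, which is exactly what a comparison like `enc_input == self.vocab.PAD_ID` picks out when building the key padding masks.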

Requirements

pip install mxnet
pip install gluonnlp
pip install konlpy
pip install python-mecab-ko
pip install chatspace
pip install tb-nightly
pip install future
pip install pathlib # not needed on Python 3.4+, where pathlib is part of the standard library

Reference Repositories

Languages

Language:Python 100.0%