Modify training script to load pre-trained model
fauxneticien opened this issue
Nay San commented
Note that in `train_xsa_e2e.py` a new model is randomly initialised:
```python
model = X_Transformer_E2E_LID(n_lang=args.lang,
                              dropout=0.1,
                              input_dim=args.dim,
                              feat_dim=256,
                              n_heads=4,
                              d_k=256,
                              d_v=256,
                              d_ff=2048,
                              max_seq_len=args.maxlength,
                              device=device)
```
And the model is saved to a `*.ckpt` file:

```python
torch.save(model.state_dict(), args.savedir + '{}.ckpt'.format(args.model))
```
For transfer learning, we will want to load a pre-trained model using a checkpoint file supplied by Hexin, so we'll need to figure out how to load this model (with `torch.load`?):

https://pytorch.org/docs/stable/generated/torch.load.html#torch.load
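As a rough sketch of the load step: since the script saves `model.state_dict()`, the checkpoint file holds a state dict rather than the full model, so loading should be `torch.load` followed by `model.load_state_dict(...)`. The demo below uses a small `nn.Linear` as a stand-in for `X_Transformer_E2E_LID`, and the checkpoint path is hypothetical; the real code would construct the model exactly as in `train_xsa_e2e.py` and point at Hexin's `*.ckpt` file.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in for X_Transformer_E2E_LID (small model, just for the demo).
model = nn.Linear(4, 2)

# Mirrors the torch.save(model.state_dict(), ...) call in train_xsa_e2e.py.
ckpt_path = os.path.join(tempfile.gettempdir(), "demo.ckpt")
torch.save(model.state_dict(), ckpt_path)

# Transfer-learning side: build a new, randomly initialised model with the
# same architecture, then overwrite its weights with the checkpoint's.
new_model = nn.Linear(4, 2)
state_dict = torch.load(ckpt_path, map_location="cpu")
new_model.load_state_dict(state_dict)

# The two models' parameters now match exactly.
assert torch.equal(model.weight, new_model.weight)
assert torch.equal(model.bias, new_model.bias)
```

`map_location` is worth passing so a checkpoint saved on GPU can still be loaded on a CPU-only machine.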