fauxneticien / E2E-language-diarization-transfer

Source code of the paper "End-to-End Language Diarization for Bilingual Code-switching Speech"

Modify training script to load pre-trained model

fauxneticien opened this issue · comments

Note that in train_xsa_e2e.py a new model is randomly initialised:

    model = X_Transformer_E2E_LID(n_lang=args.lang,
                                  dropout=0.1,
                                  input_dim=args.dim,
                                  feat_dim=256,
                                  n_heads=4,
                                  d_k=256,
                                  d_v=256,
                                  d_ff=2048,
                                  max_seq_len=args.maxlength,
                                  device=device)

and the trained model is saved to a *.ckpt file:

    torch.save(model.state_dict(), args.savedir + '{}.ckpt'.format(args.model))

For transfer learning, we want to initialise from a pre-trained checkpoint file supplied by Hexin rather than from random weights, so we'll need to figure out how to load it into the model (with torch.load and load_state_dict?).

https://pytorch.org/docs/stable/generated/torch.load.html#torch.load
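A minimal sketch of what this could look like, since the checkpoint was written with `torch.save(model.state_dict(), ...)`: read the state dict back with `torch.load` (using `map_location` so a GPU-saved checkpoint also loads on CPU) and apply it with `load_state_dict`. The helper name `load_pretrained` and the `strict` flag usage here are my suggestions, not anything in the repo.

```python
import torch

def load_pretrained(model, ckpt_path, device="cpu", strict=True):
    """Load weights saved via torch.save(model.state_dict(), ...) into `model`.

    map_location lets a checkpoint saved on GPU be loaded on CPU (or vice versa).
    strict=False would tolerate missing/unexpected keys, e.g. if the output
    layer (n_lang) differs between the pre-trained and target models.
    """
    state_dict = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state_dict, strict=strict)
    return model
```

In train_xsa_e2e.py this would slot in right after the `X_Transformer_E2E_LID(...)` constructor call, e.g. `load_pretrained(model, path_to_hexin_ckpt, device)` (the checkpoint path argument is hypothetical), before the training loop starts.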