justinchiu / hmmlm-jax

hmmlm (Jax version)

Dependencies

pip install -r requirements.txt
bash install_jax_gpu.sh
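After running the install script, you can confirm that JAX picked up the GPU build with a quick check. This uses only standard JAX calls and is not a script shipped with this repo.

```python
import jax

# Devices visible to JAX's default backend. After a successful GPU install,
# this list should contain GPU devices rather than only a CPU device.
devices = jax.devices()
backend = jax.default_backend()
print("backend:", backend)
print("devices:", devices)
```

If the backend prints as `cpu`, the GPU wheel or the CUDA setup did not take effect.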

For testing with PyTorch, install PyTorch and TVM.

PyTorch

WANDB_MODE=dryrun python main.py --lr 0.01 --model factoredhmm --assignment brown --states_per_word 256 --train_spw 128 --num_clusters 128 --num_classes 32768 --bsz 16 --eval_bsz 16 --bptt 32 --dataset ptb --iterator bptt --reset_eos 1 --no_shuffle_train 0 --optimizer adamw --state fac --tw slIlrIrp
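The flag names in the command above suggest the following structure, though this README does not document it. `--assignment brown` groups words into `--num_clusters` clusters, and each cluster owns a block of `--states_per_word` hidden states. That reading would account for `--num_classes 32768 = 128 × 256`. The sketch below is a toy illustration under that assumption, not the repository's code. It builds a block-sparse emission mask and scores a sentence with the log-space HMM forward algorithm. All names and sizes in it are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Hypothetical toy sizes. Under the assumed layout, the command above would use
# NUM_CLUSTERS=128 and STATES_PER_WORD=256, giving 32768 states in total.
NUM_CLUSTERS = 4
STATES_PER_WORD = 3
NUM_STATES = NUM_CLUSTERS * STATES_PER_WORD
VOCAB = 10

def emission_mask(word2cluster):
    """[NUM_STATES, VOCAB] bool mask: state s may emit word w only if s lies in
    the contiguous block of states owned by w's cluster (assumed layout)."""
    state_cluster = jnp.arange(NUM_STATES) // STATES_PER_WORD  # owning cluster per state
    return state_cluster[:, None] == word2cluster[None, :]

@jax.jit
def hmm_log_likelihood(words, log_start, log_trans, log_emit):
    """Forward algorithm in log space: log p(words), summed over all state paths."""
    alpha = log_start + log_emit[:, words[0]]
    def step(alpha, w):
        # alpha'[j] = logsumexp_i(alpha[i] + log_trans[i, j]) + log_emit[j, w]
        alpha = logsumexp(alpha[:, None] + log_trans, axis=0) + log_emit[:, w]
        return alpha, None
    alpha, _ = jax.lax.scan(step, alpha, words[1:])
    return logsumexp(alpha)

# Toy "clustering": word w goes to cluster w % NUM_CLUSTERS, so no cluster is empty.
word2cluster = jnp.arange(VOCAB) % NUM_CLUSTERS
mask = emission_mask(word2cluster)

k_start, k_trans, k_emit = jax.random.split(jax.random.PRNGKey(0), 3)
log_start = jax.nn.log_softmax(jax.random.normal(k_start, (NUM_STATES,)))
# Rows of log_trans are normalized: transitions out of state i sum to 1.
log_trans = jax.nn.log_softmax(jax.random.normal(k_trans, (NUM_STATES, NUM_STATES)), axis=-1)
# Disallowed (state, word) pairs get -inf before normalizing over the vocabulary.
emit_logits = jnp.where(mask, jax.random.normal(k_emit, (NUM_STATES, VOCAB)), -jnp.inf)
log_emit = jax.nn.log_softmax(emit_logits, axis=-1)

words = jnp.array([0, 5, 2, 7])
print("log p(words) =", hmm_log_likelihood(words, log_start, log_trans, log_emit))
```

The mask is what makes large state counts tractable in this picture. Each word's emission column has at most `STATES_PER_WORD` non-`-inf` entries, so a real implementation could restrict each forward step to those states instead of all `NUM_STATES`. This dense toy version does not exploit that.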