nparam

For neural machine translation models built with Amazon Sockeye, this documentation and the accompanying scripts show how to calculate the number of model parameters by hand from the hyper-parameter settings.

Here are some examples of model parameter counts under different hyper-parameter settings: param_table.md.

In the following discussion, we consider RNN and Transformer models separately.

Usage

cd nparam
sh get_nparam.sh -p model.hpm

rnn.hpm is an example RNN hyper-parameter file, and transformer.hpm is an example Transformer hyper-parameter file.

Given the hyper-parameter settings, this script prints the parameter names with their shapes, the number of parameters in each subnetwork, and the total number of model parameters.

Note that --bpe-symbols-src and --bpe-symbols-trg in get_nparam.sh may not equal the real vocabulary sizes used in training. If you want the exact total number of model parameters, specify the --exact flag together with the paths to the BPE'ed training data (e.g. ./data-bpe/train.bpe-30000.de, ./data-bpe/train.bpe-30000.en) via --train-bpe-src and --train-bpe-trg. Without the --exact flag, the number of parameters is approximated from --bpe-symbols-src and --bpe-symbols-trg.

RNN

1. Hyper-parameters

  • bpe_symbols: Number of BPE operations for source and target side of training data.
  • rnn_cell_type: RNN cell type for encoder and decoder, either gru or lstm.
  • num_layers: Number of layers for encoder and decoder.
  • num_embed: Embedding size for source and target tokens.
  • rnn_num_hidden: Number of RNN hidden units for encoder and decoder.

2. Parameters

The parameters (parameter name: shape) for an RNN model trained with bpe_symbols=50000:50000, rnn_cell_type=lstm, num_embed=512:512, rnn_num_hidden=512, num_layers=2:2 are shown below.

enc2decinit

decoder_rnn_enc2decinit_0_bias: (512,), 
decoder_rnn_enc2decinit_0_weight: (512, 512), 
decoder_rnn_enc2decinit_1_bias: (512,), 
decoder_rnn_enc2decinit_1_weight: (512, 512),
decoder_rnn_enc2decinit_2_bias: (512,), 
decoder_rnn_enc2decinit_2_weight: (512, 512), 
decoder_rnn_enc2decinit_3_bias: (512,), 
decoder_rnn_enc2decinit_3_weight: (512, 512) 

hidden

decoder_rnn_hidden_bias: (512,), 
decoder_rnn_hidden_weight: (512, 1024)

decoder_lx

decoder_rnn_l0_h2h_bias: (2048,), 
decoder_rnn_l0_h2h_weight: (2048, 512), 
decoder_rnn_l0_i2h_bias: (2048,), 
decoder_rnn_l0_i2h_weight: (2048, 1024), 
decoder_rnn_l1_h2h_bias: (2048,), 
decoder_rnn_l1_h2h_weight: (2048, 512), 
decoder_rnn_l1_i2h_bias: (2048,), 
decoder_rnn_l1_i2h_weight: (2048, 512) 

birnn

encoder_birnn_forward_l0_h2h_bias: (1024,), 
encoder_birnn_forward_l0_h2h_weight: (1024, 256), 
encoder_birnn_forward_l0_i2h_bias: (1024,), 
encoder_birnn_forward_l0_i2h_weight: (1024, 512), 
encoder_birnn_reverse_l0_h2h_bias: (1024,), 
encoder_birnn_reverse_l0_h2h_weight: (1024, 256), 
encoder_birnn_reverse_l0_i2h_bias: (1024,), 
encoder_birnn_reverse_l0_i2h_weight: (1024, 512) 

encoder_lx

encoder_rnn_l0_h2h_bias: (2048,), 
encoder_rnn_l0_h2h_weight: (2048, 512), 
encoder_rnn_l0_i2h_bias: (2048,), 
encoder_rnn_l0_i2h_weight: (2048, 512)

io

source_embed_weight: (49410, 512), 
target_embed_weight: (42767, 512), 
target_output_bias: (42767,), 
target_output_weight: (42767, 512)

The total number of parameters is calculated as follows:

from functools import reduce

# params maps each parameter name to its shape tuple
total_num_params = sum(reduce(lambda s1, s2: s1 * s2, shape) for shape in params.values())

For the model above, the total number of parameters is 79638799.
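
As a concrete check, here is a minimal Python sketch that applies this formula to just the io group above (the io_params dict is copied from the listing):

from functools import reduce

# io-group shapes from the example RNN model above
io_params = {
    "source_embed_weight": (49410, 512),
    "target_embed_weight": (42767, 512),
    "target_output_bias": (42767,),
    "target_output_weight": (42767, 512),
}

nparam_io = sum(reduce(lambda s1, s2: s1 * s2, shape) for shape in io_params.values())
print(nparam_io)  # 69134095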

3. Influence of Hyper-parameters on Parameters

Now let's see how changes to each hyper-parameter affect the shapes and sizes of the parameter matrices.

Suppose bpe_symbols=sb:tb, num_layers=sn:tn, num_embed=se:te, and rnn_num_hidden=h, where s stands for source/encoder and t for target/decoder. Let s1 and s2 denote the sizes of the first and second dimensions of a parameter matrix.

  • bpe_symbols

    io: For source_embed_weight, s1=sb. For target_..., s1=tb.

  • rnn_cell_type

    enc2decinit: For lstm, there are decoder_rnn_enc2decinit_x_... parameters with x=0,...,2*tn-1; for gru, x=0,...,tn-1.

    decoder_lx, encoder_lx: For lstm, s1=4h; while for gru, s1=3h.

    birnn: For lstm, s1=2h; while for gru, s1=3h/2.

  • num_layers

    enc2decinit: For lstm, decoder_rnn_enc2decinit_x_bias/weight exist for x=0,...,2*tn-1; for gru, x=0,...,tn-1.

    decoder_lx: decoder_rnn_lx_..., where x=0,...,tn-1.

    encoder_lx: encoder_rnn_lx_..., where x=0,...,sn-2. Note that these parameters do not exist when sn=1.

  • num_embed

    decoder_lx: For decoder_rnn_l0_i2h_weight, s2=h+se.

    birnn: For i2h_weight, s2=se.

    io: For source_embed_weight, s2=se and for target_embed_weight, s2=te.

  • rnn_num_hidden

    enc2decinit: s1=h. For weight parameters, s2=h.

    hidden: s1=h. For weight parameters, s2=2h.

    decoder_lx: For lstm, s1=4h; while for gru, s1=3h. For weight parameters, s2=h, except for decoder_rnn_l0_i2h_weight, where s2=h+se.

    birnn: For lstm, s1=2h; while for gru, s1=3h/2. For h2h_weight parameters, s2=h/2.

    encoder_lx: For lstm, s1=4h; while for gru, s1=3h. For weight parameters, s2=h.

    io: For target_output_weight parameters, s2=h.

4. Parameters w.r.t. Hyper-parameters

From the previous section, we can derive equations for the number of parameters given the hyper-parameter settings.

Suppose bpe_symbols=sb:tb, num_layers=sn:tn, num_embed=se:te, rnn_num_hidden=h.

enc2decinit

decoder_rnn_enc2decinit_x_bias: (h,), 
decoder_rnn_enc2decinit_x_weight: (h, h)

where x=0,...,2*tn-1 for lstm; x=0,...,tn-1 for gru.

The total number of enc2decinit parameters can be calculated as follows:

if rnn_cell_type == lstm:
	nparam_enc2decinit = 2*tn*h*(h+1)
elif rnn_cell_type == gru:
	nparam_enc2decinit = tn*h*(h+1)
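
As a sanity check, here is a minimal Python sketch of this rule (the helper name is mine, for illustration); it reproduces the enc2decinit count of the example model above:

def nparam_enc2decinit(tn, h, cell="lstm"):
    # one bias (h,) and one weight (h, h) per initialized state;
    # lstm initializes 2*tn states (hidden and cell per layer), gru tn
    n_states = 2 * tn if cell == "lstm" else tn
    return n_states * h * (h + 1)

print(nparam_enc2decinit(tn=2, h=512))  # 1050624, matches the listing above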

hidden

decoder_rnn_hidden_bias: (h,), 
decoder_rnn_hidden_weight: (h, 2h)

The total number of hidden parameters can be calculated as follows:

nparam_hidden = h*(2*h+1)
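
In the same spirit, a one-line sketch (helper name is mine):

def nparam_hidden(h):
    # bias (h,) plus weight (h, 2h)
    return h * (2 * h + 1)

print(nparam_hidden(512))  # 524800, matches the hidden listing above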

decoder_lx

decoder_rnn_lx_h2h_bias: (y,), 
decoder_rnn_lx_h2h_weight: (y, h), 
decoder_rnn_lx_i2h_bias: (y,), 
decoder_rnn_lx_i2h_weight: (y, z)

where x=0,...,tn-1. 
	  y=4h for lstm; y=3h for gru. 
	  z=h+se for x=0; z=h for x!=0.

The total number of decoder_lx parameters can be calculated as follows:

if rnn_cell_type == lstm:
	nparam_decoder_lx = 4*h*(se+2*tn*(h+1))
elif rnn_cell_type == gru:
	nparam_decoder_lx = 3*h*(se+2*tn*(h+1))
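
A minimal sketch of the same computation (helper name is mine); for the example model it matches the decoder_lx listing:

def nparam_decoder_lx(tn, h, se, cell="lstm"):
    y = 4 * h if cell == "lstm" else 3 * h
    # each layer contributes y*(2 + h + z), with z = h + se for
    # layer 0 and z = h otherwise; summed over tn layers:
    return y * (se + tn * (2 + 2 * h))

print(nparam_decoder_lx(tn=2, h=512, se=512))  # 5251072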

birnn

encoder_birnn_forward_l0_h2h_bias: (y,), 
encoder_birnn_forward_l0_h2h_weight: (y, h/2), 
encoder_birnn_forward_l0_i2h_bias: (y,), 
encoder_birnn_forward_l0_i2h_weight: (y, se), 
encoder_birnn_reverse_l0_h2h_bias: (y,), 
encoder_birnn_reverse_l0_h2h_weight: (y, h/2), 
encoder_birnn_reverse_l0_i2h_bias: (y,), 
encoder_birnn_reverse_l0_i2h_weight: (y, se) 

where y=2h for lstm; y=3h/2 for gru.

The total number of birnn parameters can be calculated as follows:

if rnn_cell_type == lstm:
	nparam_birnn = 2*h*(4+h+2*se)
elif rnn_cell_type == gru:
	nparam_birnn = (3*h/2)*(4+h+2*se)
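
A corresponding sketch (helper name is mine; assumes h is even so h/2 is an integer):

def nparam_birnn(h, se, cell="lstm"):
    y = 2 * h if cell == "lstm" else (3 * h) // 2
    # forward and reverse directions each have two biases (y,),
    # an h2h weight (y, h/2) and an i2h weight (y, se)
    return 2 * y * (2 + h // 2 + se)

print(nparam_birnn(h=512, se=512))  # 1576960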

encoder_lx

encoder_rnn_lx_h2h_bias: (y,), 
encoder_rnn_lx_h2h_weight: (y, h), 
encoder_rnn_lx_i2h_bias: (y,), 
encoder_rnn_lx_i2h_weight: (y, h)

where x=0,...,sn-2 (encoder_lx parameters do not exist for sn=1).
	  y=4h for lstm; y=3h for gru.

The total number of encoder_lx parameters can be calculated as follows:

if rnn_cell_type == lstm:
	nparam_encoder_lx = 4*h*(sn-1)*(2+2*h)
elif rnn_cell_type == gru:
	nparam_encoder_lx = 3*h*(sn-1)*(2+2*h)
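
And a sketch for this group (helper name is mine):

def nparam_encoder_lx(sn, h, cell="lstm"):
    y = 4 * h if cell == "lstm" else 3 * h
    # sn - 1 unidirectional layers sit on top of the birnn layer
    return (sn - 1) * y * (2 + 2 * h)

print(nparam_encoder_lx(sn=2, h=512))  # 2101248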

io

source_embed_weight: (sb, se), 
target_embed_weight: (tb, te), 
target_output_bias: (tb,), 
target_output_weight: (tb, h)

The total number of io parameters can be calculated as follows:

nparam_io = sb*se+tb*(1+te+h)
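
A sketch of this count (helper name is mine; sb and tb are the realized vocabulary sizes here, per the note in Usage):

def nparam_io(sb, tb, se, te, h):
    # source embedding, target embedding, output bias, output weight
    return sb * se + tb * (1 + te + h)

print(nparam_io(sb=49410, tb=42767, se=512, te=512, h=512))  # 69134095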

We can now compute the total number of parameters for an RNN model:

nparam = nparam_enc2decinit + nparam_hidden + nparam_decoder_lx + nparam_birnn + nparam_encoder_lx + nparam_io

if rnn_cell_type == lstm:
	nparam = h*(-4*h+8*se+(8*sn+10*tn)*(1+h)+1)+(sb*se+tb*(1+te+h))
elif rnn_cell_type == gru:
	nparam = h*(-2.5*h+6*se+(6*sn+7*tn)*(1+h)+1)+(sb*se+tb*(1+te+h))
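
A minimal sketch of this closed form (the function name is mine); with the example model's settings and realized vocabulary sizes it reproduces the total reported in Section 2:

def rnn_nparam(sb, tb, sn, tn, se, te, h, cell="lstm"):
    io = sb * se + tb * (1 + te + h)
    if cell == "lstm":
        core = h * (-4 * h + 8 * se + (8 * sn + 10 * tn) * (1 + h) + 1)
    else:  # gru; the -2.5*h*h term is written integrally as -5*h*h//2
        core = h * (6 * se + (6 * sn + 7 * tn) * (1 + h) + 1) - 5 * h * h // 2
    return core + io

print(rnn_nparam(sb=49410, tb=42767, sn=2, tn=2, se=512, te=512, h=512))  # 79638799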

Transformer

1. Hyper-parameters

  • bpe_symbols: Number of BPE operations for source and target side of training data.
  • num_layers: Number of layers for encoder and decoder.
  • num_embed: Embedding size for source and target tokens.
  • transformer_feed_forward_num_hidden: Number of hidden units in the Transformer feed-forward layers.

2. Parameters

The parameters (parameter name: shape) for a Transformer model trained with bpe_symbols=30000:30000, num_layers=1:1, num_embed=512, transformer_feed_forward_num_hidden=300 are shown below.

decoder_att

decoder_transformer_0_att_enc_h2o_weight: (512, 512), 
decoder_transformer_0_att_enc_k2h_weight: (512, 512), 
decoder_transformer_0_att_enc_pre_norm_beta: (512,), 
decoder_transformer_0_att_enc_pre_norm_gamma: (512,), 
decoder_transformer_0_att_enc_q2h_weight: (512, 512), 
decoder_transformer_0_att_enc_v2h_weight: (512, 512),
decoder_transformer_0_att_self_h2o_weight: (512, 512), 
decoder_transformer_0_att_self_i2h_weight: (1536, 512), 
decoder_transformer_0_att_self_pre_norm_beta: (512,), 
decoder_transformer_0_att_self_pre_norm_gamma: (512,)

decoder_ff

decoder_transformer_0_ff_h2o_bias: (512,), 
decoder_transformer_0_ff_h2o_weight: (512, 300), 
decoder_transformer_0_ff_i2h_bias: (300,), 
decoder_transformer_0_ff_i2h_weight: (300, 512), 
decoder_transformer_0_ff_pre_norm_beta: (512,), 
decoder_transformer_0_ff_pre_norm_gamma: (512,)

decoder_final

decoder_transformer_final_process_norm_beta: (512,), 
decoder_transformer_final_process_norm_gamma: (512,)

encoder_att

encoder_transformer_0_att_self_h2o_weight: (512, 512), 
encoder_transformer_0_att_self_i2h_weight: (1536, 512), 
encoder_transformer_0_att_self_pre_norm_beta: (512,), 
encoder_transformer_0_att_self_pre_norm_gamma: (512,)

encoder_ff

encoder_transformer_0_ff_h2o_bias: (512,), 
encoder_transformer_0_ff_h2o_weight: (512, 300), 
encoder_transformer_0_ff_i2h_bias: (300,), 
encoder_transformer_0_ff_i2h_weight: (300, 512), 
encoder_transformer_0_ff_pre_norm_beta: (512,), 
encoder_transformer_0_ff_pre_norm_gamma: (512,)

encoder_final

encoder_transformer_final_process_norm_beta: (512,),
encoder_transformer_final_process_norm_gamma: (512,)

io

source_embed_weight: (29624, 512), 
target_embed_weight: (28059, 512), 
target_output_bias: (28059,), 
target_output_weight: (28059, 512)

For the model above, the total number of parameters is 47696883.

3. Influence of Hyper-parameters on Parameters

Now let's see how changes to each hyper-parameter affect the shapes and sizes of the parameter matrices.

Suppose bpe_symbols=sb:tb, num_layers=sn:tn, num_embed=e, and transformer_feed_forward_num_hidden=f. Let s1 and s2 denote the sizes of the first and second dimensions of a parameter matrix.

  • bpe_symbols:

    io: For source_embed_weight, s1=sb. For target_..., s1=tb.

  • num_layers:

    decoder_att, decoder_ff: ..._transformer_x_..., where x=0,...,tn-1.

    encoder_att, encoder_ff: ..._transformer_x_..., where x=0,...,sn-1.

  • num_embed:

    All parameter groups depend on num_embed; see Section 4 below for details.

  • transformer_feed_forward_num_hidden:

    decoder_ff, encoder_ff: For ..._i2h_... matrices, s1=f; for ..._h2o_weight matrices, s2=f.

4. Parameters w.r.t. Hyper-parameters

From the previous section, we can derive equations for the number of parameters given the hyper-parameter settings.

Suppose bpe_symbols=sb:tb, num_layers=sn:tn, num_embed=e, transformer_feed_forward_num_hidden=f.

decoder_att

decoder_transformer_x_att_enc_h2o_weight: (e, e), 
decoder_transformer_x_att_enc_k2h_weight: (e, e), 
decoder_transformer_x_att_enc_pre_norm_beta: (e,), 
decoder_transformer_x_att_enc_pre_norm_gamma: (e,), 
decoder_transformer_x_att_enc_q2h_weight: (e, e), 
decoder_transformer_x_att_enc_v2h_weight: (e, e),
decoder_transformer_x_att_self_h2o_weight: (e, e), 
decoder_transformer_x_att_self_i2h_weight: (3*e, e), 
decoder_transformer_x_att_self_pre_norm_beta: (e,), 
decoder_transformer_x_att_self_pre_norm_gamma: (e,)

where x=0,...,tn-1.

The total number of decoder_att parameters can be calculated as follows:

nparam_decoder_att = tn*4*e*(2*e+1)
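
A minimal sketch (helper name is mine); for the example model it matches the decoder_att listing:

def nparam_decoder_att(tn, e):
    # per layer: 8*e*e weight entries (five (e, e) matrices plus one
    # (3e, e) matrix) and four (e,) norm vectors, i.e. 4e*(2e+1)
    return tn * 4 * e * (2 * e + 1)

print(nparam_decoder_att(tn=1, e=512))  # 2099200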

decoder_ff

decoder_transformer_x_ff_h2o_bias: (e,), 
decoder_transformer_x_ff_h2o_weight: (e, f), 
decoder_transformer_x_ff_i2h_bias: (f,), 
decoder_transformer_x_ff_i2h_weight: (f, e), 
decoder_transformer_x_ff_pre_norm_beta: (e,), 
decoder_transformer_x_ff_pre_norm_gamma: (e,)

where x=0,...,tn-1.

The total number of decoder_ff parameters can be calculated as follows:

nparam_decoder_ff = tn*(2*e*f+3*e+f)
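
Again as a sketch (helper name is mine):

def nparam_decoder_ff(tn, e, f):
    # per layer: two e-by-f weight matrices, biases of sizes e and f,
    # and two (e,) norm vectors
    return tn * (2 * e * f + 3 * e + f)

print(nparam_decoder_ff(tn=1, e=512, f=300))  # 309036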

decoder_final

decoder_transformer_final_process_norm_beta: (e,), 
decoder_transformer_final_process_norm_gamma: (e,)

The total number of decoder_final parameters can be calculated as follows:

nparam_decoder_final = 2*e

encoder_att

encoder_transformer_x_att_self_h2o_weight: (e, e), 
encoder_transformer_x_att_self_i2h_weight: (3*e, e), 
encoder_transformer_x_att_self_pre_norm_beta: (e,), 
encoder_transformer_x_att_self_pre_norm_gamma: (e,)

where x=0,...,sn-1.

The total number of encoder_att parameters can be calculated as follows:

nparam_encoder_att = sn*2*e*(2*e+1)

encoder_ff

encoder_transformer_x_ff_h2o_bias: (e,), 
encoder_transformer_x_ff_h2o_weight: (e, f), 
encoder_transformer_x_ff_i2h_bias: (f,), 
encoder_transformer_x_ff_i2h_weight: (f, e), 
encoder_transformer_x_ff_pre_norm_beta: (e,), 
encoder_transformer_x_ff_pre_norm_gamma: (e,)

where x=0,...,sn-1.

The total number of encoder_ff parameters can be calculated as follows:

nparam_encoder_ff = sn*(2*e*f+3*e+f)

encoder_final

encoder_transformer_final_process_norm_beta: (e,),
encoder_transformer_final_process_norm_gamma: (e,)

The total number of encoder_final parameters can be calculated as follows:

nparam_encoder_final = 2*e

io

source_embed_weight: (sb, e), 
target_embed_weight: (tb, e), 
target_output_bias: (tb,), 
target_output_weight: (tb, e)

The total number of io parameters can be calculated as follows:

nparam_io = sb*e+tb*(2*e+1)
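
A sketch with the example model's realized vocabulary sizes (helper name is mine):

def transformer_nparam_io(sb, tb, e):
    # source embedding, target embedding, output bias, output weight
    return sb * e + tb * (2 * e + 1)

print(transformer_nparam_io(sb=29624, tb=28059, e=512))  # 43927963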

We can now compute the total number of parameters for a Transformer model:

nparam = nparam_decoder_att + nparam_decoder_ff + nparam_decoder_final + nparam_encoder_att + nparam_encoder_ff + nparam_encoder_final + nparam_io

nparam = tn*(8*e*e+7*e+2*e*f+f)+sn*(4*e*e+5*e+2*e*f+f)+4*e+(sb*e+tb*(2*e+1))
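
A minimal sketch of this closed form (the function name is mine); with the example model's settings and realized vocabulary sizes it reproduces the total reported in Section 2:

def transformer_nparam(sb, tb, sn, tn, e, f):
    return (tn * (8 * e * e + 7 * e + 2 * e * f + f)
            + sn * (4 * e * e + 5 * e + 2 * e * f + f)
            + 4 * e
            + sb * e + tb * (2 * e + 1))

print(transformer_nparam(sb=29624, tb=28059, sn=1, tn=1, e=512, f=300))  # 47696883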
