cxxixi / Precipitation-Nowcasting

Notes on ConvLSTM

cxxixi opened this issue

Prescription

I have been implementing the ConvLSTM model for a precipitation estimation project, and a few points confused me while reading the original paper by Shi et al. and the code based on it. These notes explain the principles behind the code and how the code implements the ideas proposed by Shi et al.

import tensorflow as tf

# A new class inherited from tf.nn.rnn_cell.RNNCell
class ConvLSTMCell(tf.nn.rnn_cell.RNNCell):
    
    def __init__(self, shape, filters, kernel, forget_bias=1.0, activation=tf.tanh,
                 normalize=True, peehole=True, data_format='channel_last', reuse=None):

        # Forward the variable-reuse flag to the RNNCell base class.
        super(ConvLSTMCell, self).__init__(_reuse=reuse)
        
        self._kernel = kernel
        self._filters = filters
        self._forget_bias = forget_bias
        self._activation = activation
        self._normalize = normalize
        self._peehole = peehole  # whether the gates also see the cell state (peephole connections)
   
        if data_format == 'channel_last':
            # The state size is [spatial shape] + [num_filters]. E.g. for 64x64 inputs and 4 filters, _size is [64, 64, 4].
            self._size = tf.TensorShape(shape + [self._filters])
            # ndims is the rank of the shape (3 for [64, 64, 4]). Once the batch dimension is prepended, axis 3 is the channel axis.
            self._feature_axis = self._size.ndims
            self._data_format = None

        elif data_format == 'channel_first':
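            # Channels come first in this layout, e.g. [4, 64, 64].
            self._size = tf.TensorShape([self._filters] + shape)
            # With the batch dimension at axis 0, the channel axis is 1.
            self._feature_axis = 1
            self._data_format = 'NC'
        else:
            raise ValueError("Unknown data format")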

According to the official documentation for tf.nn.convolution:
data_format : A string or None.
Specifies whether the channel dimension of the input and output is the last dimension (default, or if data_format does not start with "NC"), or the second dimension (if data_format starts with "NC"). For N=1, the valid values are "NWC" (default) and "NCW". For N=2, the valid values are "NHWC" (default) and "NCHW". For N=3, the valid values are "NDHWC" (default) and "NCDHW".
Returns:
A Tensor with the same type as input of shape
[batch_size] + output_spatial_shape + [out_channels] if data_format is None or does not start with "NC", or [batch_size, out_channels] + output_spatial_shape if data_format starts with "NC".
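
To make the two layouts concrete, here is a minimal pure-Python sketch of the shape bookkeeping done in `__init__`. The helper name `state_layout` is hypothetical and not part of the original code, and the `channel_first` branch assumes the batch dimension sits at axis 0, so the channel axis is 1.

```python
def state_layout(shape, filters, data_format="channel_last"):
    """Return (per-example state shape, channel axis of the batched tensor)."""
    if data_format == "channel_last":
        size = list(shape) + [filters]   # e.g. [64, 64, 4]
        feature_axis = len(size)         # the batch dim pushes channels to the last axis
    elif data_format == "channel_first":
        size = [filters] + list(shape)   # e.g. [4, 64, 64]
        feature_axis = 1                 # assumption: axis 0 is the batch dimension
    else:
        raise ValueError("Unknown data format")
    return size, feature_axis


size, axis = state_layout([64, 64], 4)
batched = [8] + size                     # a batch of 8 states
print(size, axis, batched[axis])         # [64, 64, 4] 3 4
```

In both layouts, indexing the batched shape with the feature axis gives the number of filters, which is what `tf.concat` and `tf.split` rely on later.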

    # Override the properties inherited from the parent class (RNNCell).
    @property
    def state_size(self):
        return tf.nn.rnn_cell.LSTMStateTuple(self._size, self._size)

    @property
    def output_size(self):
        return self._size

    # Override the main method, call.
    def call(self, x, state):  # where do x and state come from? They are passed in at every time step by whatever unrolls the cell, e.g. tf.nn.dynamic_rnn

        c, h = state  # state is a tuple: c is the cell state, h is the hidden state (the cell's output at the previous step)
        x = tf.concat([x, h], axis=self._feature_axis)
        n = x.shape[-1].value       # n: number of input channels (channels of x plus channels of h)
        m = 4 * self._filters if self._filters > 1 else 4     # m: number of output channels; j, i, f and o each need `filters` channels, hence the factor of 4
        W = tf.get_variable('kernel', self._kernel + [n, m])  # for a 3x3 kernel the shape is [3, 3, input_channels, output_channels]
        # Compute an N-D convolution, see https://www.tensorflow.org/versions/master/api_docs/python/tf/nn/convolution
        # x: input, W: filters
        y = tf.nn.convolution(x, W, 'SAME', data_format=self._data_format)
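
On where `x` and `state` come from: in TensorFlow 1.x the cell is normally not called by hand. A wrapper such as `tf.nn.dynamic_rnn` unrolls it over the time axis, passing one frame as `x` and the previous step's `(c, h)` tuple as `state`. The loop below is a minimal pure-Python sketch of that idea, not the TensorFlow implementation; `unroll` and `toy_cell` are hypothetical names used only for illustration.

```python
def unroll(cell, inputs, initial_state):
    """Call cell(x_t, state) once per time step, threading the state through."""
    state = initial_state
    outputs = []
    for x_t in inputs:                # one frame per time step
        h, state = cell(x_t, state)   # same (output, new_state) contract as call()
        outputs.append(h)
    return outputs, state


def toy_cell(x, state):
    """Hypothetical stand-in for a real cell: accumulates the inputs in c."""
    c, h = state
    c = c + x
    h = 2 * c
    return h, (c, h)


outputs, final_state = unroll(toy_cell, [1, 2, 3], (0, 0))
print(outputs)       # [2, 6, 12]
print(final_state)   # (6, 12)
```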

[image: 5]
The forget gate, input gate and output gate share the same structure: each takes both X and H(t-1), the previous hidden state. The author therefore concatenates the two and treats the result as a new X.
Note that tf.nn.convolution is the main change Shi et al. made to the original LSTM model. Replacing matrix multiplication with convolution is what lets the cell capture spatial as well as temporal information.
The only difference between the original LSTM and ConvLSTM is shown in the following picture.
[image: 6]
Every operation between the weights W and the input [X, H(t-1)] in FC-LSTM is replaced by a convolution.
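
The concatenation trick can be checked numerically. The following is a minimal NumPy sketch, not the TensorFlow code: `conv2d_same` is a hypothetical, naive helper that assumes a single example, a channels-last layout and odd kernel sizes. It shows that one convolution over the concatenated [X, H(t-1)] gives the same result as separate convolutions over X and H(t-1) added together, and that the output splits into four equal parts.

```python
import numpy as np


def conv2d_same(x, w):
    """Naive 'SAME' 2-D convolution (cross-correlation, no kernel flip).

    x: [height, width, in_channels]; w: [kh, kw, in_channels, out_channels].
    Assumes odd kernel sizes so the padding is symmetric.
    """
    kh, kw, _, out_ch = w.shape
    ph, pw = kh // 2, kw // 2
    padded = np.pad(x, ((ph, ph), (pw, pw), (0, 0)), mode="constant")
    out = np.zeros(x.shape[:2] + (out_ch,))
    for r in range(x.shape[0]):
        for c in range(x.shape[1]):
            patch = padded[r:r + kh, c:c + kw, :]
            out[r, c] = np.tensordot(patch, w, axes=3)  # contract kh, kw, in_channels
    return out


rng = np.random.default_rng(0)
filters = 2
x = rng.normal(size=(5, 5, 1))         # input frame X_t with 1 channel
h = rng.normal(size=(5, 5, filters))   # previous hidden state H(t-1)
xh = np.concatenate([x, h], axis=-1)   # shape (5, 5, 3)
w = rng.normal(size=(3, 3, xh.shape[-1], 4 * filters))

y = conv2d_same(xh, w)
j, i, f, o = np.split(y, 4, axis=-1)   # four pre-activations
print(y.shape, j.shape)                # (5, 5, 8) (5, 5, 2)

# One convolution over [X, H] equals W_x * X + W_h * H.
separate = conv2d_same(x, w[:, :, :1, :]) + conv2d_same(h, w[:, :, 1:, :])
print(np.allclose(y, separate))        # True
```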

        # Add a bias only when layer normalization is off; layer_norm learns its own offset.
        if not self._normalize:
            y += tf.get_variable("bias", [m], initializer=tf.zeros_initializer())  # the bias starts at zero

        # Split y into four equal sub-tensors along the channel axis.
        # y has shape [batch_size] + output_spatial_shape + [out_channels] (or [batch_size, out_channels] + output_spatial_shape for 'NC'), so each part gets out_channels / 4 channels.
        j, i, f, o = tf.split(y, 4, axis=self._feature_axis)
        # j: candidate cell input; i: input gate; f: forget gate; o: output gate
        if self._peehole:
            i += tf.get_variable('W_ci', c.shape[1:]) * c  # c is C(t-1), the previous cell state
            f += tf.get_variable('W_fi', c.shape[1:]) * c
            # c.shape[0] is the batch dimension, so c.shape[1:] gives the peephole weights one value per position and channel

If peehole is true, the gates can also see the previous cell state C(t-1) (a peephole connection).
Here, i and f are updated by adding the elementwise products W_ci * c and W_fi * c, respectively.
[image: 1]
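
The peephole term is an elementwise product, not a convolution. Because the weights have shape `c.shape[1:]`, they are shared across the batch by broadcasting. A minimal NumPy sketch with made-up sizes and constant values (all numbers here are illustrative assumptions):

```python
import numpy as np

batch, height, width, filters = 2, 3, 3, 4
c_prev = np.ones((batch, height, width, filters))    # C(t-1)
i_pre = np.zeros((batch, height, width, filters))    # input-gate pre-activation
w_ci = np.full((height, width, filters), 0.5)        # same shape as c.shape[1:]

# Hadamard product; w_ci is broadcast over the batch axis.
i_pre = i_pre + w_ci * c_prev
print(i_pre.shape)         # (2, 3, 3, 4)
print(i_pre[0, 0, 0, 0])   # 0.5
```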

        # Apply layer normalization to each pre-activation.
        if self._normalize:
            j = tf.contrib.layers.layer_norm(j)
            i = tf.contrib.layers.layer_norm(i)
            f = tf.contrib.layers.layer_norm(f)  # see more https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/layers/layer_norm

        f = tf.sigmoid(f+self._forget_bias) 
        i = tf.sigmoid(i)
        c = c*f + i*self._activation(j) 
        
        if self._peehole:
            o += tf.get_variable('W_oi',c.shape[1:])*c
        
        if self._normalize:
            o = tf.contrib.layers.layer_norm(o)
            c = tf.contrib.layers.layer_norm(c)
            
        o = tf.sigmoid(o)
        h = o*self._activation(c)
        
        state = tf.nn.rnn_cell.LSTMStateTuple(c,h)
        return h, state  # the output is the hidden state h, not the cell state c

[image: 7]
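
To see the order of the updates in isolation, here is a minimal scalar sketch of the last part of `call`, with the peephole and layer-normalization branches left out. The function name `lstm_step` is hypothetical. It also hints at why `forget_bias` defaults to 1.0: with all pre-activations at zero, the forget gate is sigmoid(1), roughly 0.73, rather than 0.5, so more of the previous cell state is kept early in training.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def lstm_step(j, i, f, o, c_prev, forget_bias=1.0):
    """One scalar state update, in the same order as the cell's call method."""
    f = sigmoid(f + forget_bias)        # forget gate
    i = sigmoid(i)                      # input gate
    c = c_prev * f + i * math.tanh(j)   # new cell state C(t)
    o = sigmoid(o)                      # output gate
    h = o * math.tanh(c)                # new hidden state H(t)
    return c, h


c, h = lstm_step(j=0.0, i=0.0, f=0.0, o=0.0, c_prev=1.0)
print(round(c, 4))   # 0.7311, i.e. sigmoid(1) times the previous cell state
```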

Reference

  1. Convolutional LSTM Network: A Machine Learning Approach for Precipitation Nowcasting
  2. ConvLSTM-github
  3. LSTM introduction