Recurrent Batch Normalization

Previous work applies batch normalization only to the input-to-hidden transformation of RNNs. The paper demonstrates that it is both possible and beneficial to batch-normalize the hidden-to-hidden transition as well, thereby reducing internal covariate shift between time steps.

The Batch-Norm LSTM consistently leads to faster convergence and improved generalization on various sequential problems such as sequence classification, language modeling and question answering.

It is well known that in deep feed-forward neural networks, internal covariate shift degrades the efficiency of training. Covariate shift is a change in the distribution of the inputs to a model. It occurs continuously during the training of feed-forward networks, because changing the parameters of one layer affects the distribution of the inputs to all layers above it. As a result, the upper layers must continually adapt to the shifting input distribution and cannot learn effectively. Internal covariate shift may play an especially important role in recurrent neural networks, which, when unrolled over time, resemble very deep feed-forward networks.

Batch normalization is a technique for controlling the distributions of feed-forward neural network activations, thereby reducing internal covariate shift. It standardizes the activations going into each layer, making their means and variances invariant to changes in the parameters of the underlying layers. This effectively decouples each layer’s parameters from those of the other layers, leading to a better-conditioned optimization problem. Indeed, deep neural networks trained with batch normalization converge significantly faster and generalize better.


However, batch normalization has proven difficult to apply in recurrent architectures (Laurent et al., 2016; Amodei et al., 2015). It has found limited use in stacked RNNs, where the normalization is applied “vertically”, i.e. to the input of each RNN, but not “horizontally” between timesteps.

The LSTM has an additional memory cell c_t whose update is nearly linear, which allows the gradient to flow back through time more easily. The update of the LSTM cell is regulated by a set of gates. The forget gate f_t determines the extent to which information is carried over from the previous timestep, and the input gate i_t controls the flow of information from the current input x_t. The output gate o_t allows the model to read from the cell. This carefully controlled interaction with the cell is what allows the LSTM to robustly retain information for long periods of time.
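To make the gate description concrete, here is a minimal NumPy sketch of a single LSTM step using the standard formulation (f̃_t, ĩ_t, õ_t, g̃_t) = W_h h_{t-1} + W_x x_t + b, c_t = σ(f̃_t) ⊙ c_{t-1} + σ(ĩ_t) ⊙ tanh(g̃_t), h_t = σ(õ_t) ⊙ tanh(c_t). The helper name lstm_step, the (f, i, o, g) ordering of the stacked gates, the weight shapes and the random initialization are assumptions made for illustration.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t, h_prev, c_prev, W_x, W_h, b):
    """One LSTM step (hypothetical helper).

    x_t: (batch, n_in); h_prev, c_prev: (batch, n_hid)
    W_x: (n_in, 4*n_hid); W_h: (n_hid, 4*n_hid); b: (4*n_hid,)
    Assumed gate order along the last axis: (f, i, o, g).
    """
    pre = h_prev @ W_h + x_t @ W_x + b           # all four pre-activations at once
    f_pre, i_pre, o_pre, g_pre = np.split(pre, 4, axis=-1)
    f_t, i_t, o_t = sigmoid(f_pre), sigmoid(i_pre), sigmoid(o_pre)
    g_t = np.tanh(g_pre)                         # candidate cell update
    c_t = f_t * c_prev + i_t * g_t               # nearly linear cell update
    h_t = o_t * np.tanh(c_t)                     # output gate reads from the cell
    return h_t, c_t

rng = np.random.default_rng(0)
batch, n_in, n_hid = 2, 3, 4
W_x = rng.normal(scale=0.1, size=(n_in, 4 * n_hid))
W_h = rng.normal(scale=0.1, size=(n_hid, 4 * n_hid))
b = np.zeros(4 * n_hid)
h = c = np.zeros((batch, n_hid))
for x_t in rng.normal(size=(5, batch, n_in)):    # 5 timesteps of toy input
    h, c = lstm_step(x_t, h, c, W_x, W_h, b)
print(h.shape)  # (2, 4)
```

Note that c_t depends on c_{t-1} only through an elementwise product with the forget gate, which is the nearly linear path described above.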

Batch Normalization

Batch Normalization (Ioffe & Szegedy, 2015) is a recently proposed network reparameterization which aims to reduce internal covariate shift. It does so by standardizing the activations using empirical estimates of their means and standard deviations.

At training time, the statistics E[h] and Var[h] are estimated by the sample mean and sample variance of the current minibatch. This allows for backpropagation through the statistics, preserving the convergence properties of stochastic gradient descent. During inference, the statistics are typically estimated based on the entire training set, so as to produce a deterministic prediction.
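A minimal NumPy sketch of the two modes, assuming the transform BN(h; γ, β) = β + γ ⊙ (h - E[h]) / sqrt(Var[h] + ε), with a small constant ε for numerical stability. Population statistics are estimated here by a plain average of the minibatch means and variances; the PopulationStats class, the toy data and ε = 1e-5 are illustrative assumptions. Only the forward pass is shown; in an autodiff framework the minibatch mean and variance would be part of the computation graph, so gradients flow through them.

```python
import numpy as np

def batch_norm(h, gamma, beta, mean, var, eps=1e-5):
    """BN(h; gamma, beta) = beta + gamma * (h - mean) / sqrt(var + eps)."""
    return beta + gamma * (h - mean) / np.sqrt(var + eps)

class PopulationStats:
    """Hypothetical helper: averages minibatch statistics for use at inference."""
    def __init__(self, dim):
        self.mean_sum = np.zeros(dim)
        self.var_sum = np.zeros(dim)
        self.count = 0

    def update(self, batch_mean, batch_var):
        self.mean_sum += batch_mean
        self.var_sum += batch_var
        self.count += 1

    def get(self):
        return self.mean_sum / self.count, self.var_sum / self.count

rng = np.random.default_rng(0)
dim = 3
gamma, beta = np.ones(dim), np.zeros(dim)
stats = PopulationStats(dim)

# Training: normalize with the current minibatch's sample mean and variance.
for _ in range(100):
    h = rng.normal(loc=5.0, scale=2.0, size=(64, dim))
    mu, var = h.mean(axis=0), h.var(axis=0)
    out = batch_norm(h, gamma, beta, mu, var)
    stats.update(mu, var)

# Inference: use the averaged statistics, so the output for one example
# no longer depends on the other examples in its batch.
pop_mean, pop_var = stats.get()
single = rng.normal(loc=5.0, scale=2.0, size=(1, dim))
print(batch_norm(single, gamma, beta, pop_mean, pop_var))
```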

Batch-Normalized LSTM

The batch normalization transform relies on batch statistics to standardize the LSTM activations. It would seem natural to share the statistics used for normalization across time, just as recurrent neural networks share their parameters over time. However, the authors find that simply averaging statistics over time severely degrades performance. Although LSTM activations do converge to a stationary distribution, their statistics during the initial transient differ significantly (see Figure 5 in Appendix A of the paper). Consequently, the authors recommend using separate statistics for each timestep to preserve information about the initial transient phase in the activations.
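A minimal NumPy sketch of one batch-normalized LSTM step, based on the formulation in which the recurrent term, the input term and the cell state are normalized separately: (f̃_t, ĩ_t, õ_t, g̃_t) = BN(W_h h_{t-1}; γ_h, β_h) + BN(W_x x_t; γ_x, β_x) + b, c_t = σ(f̃_t) ⊙ c_{t-1} + σ(ĩ_t) ⊙ tanh(g̃_t), h_t = σ(õ_t) ⊙ tanh(BN(c_t; γ_c, β_c)). The sketch is training-mode only: each call normalizes with the current minibatch, and the statistics are recorded per timestep rather than shared across time. Setting β_h and β_x to zero (the bias b already supplies a shift), the parameter dictionary layout and γ = 0.1 are illustrative assumptions.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def bn(a, gamma, beta, eps=1e-5):
    """Training-mode batch norm: standardize with this minibatch's statistics."""
    mean, var = a.mean(axis=0), a.var(axis=0)
    return beta + gamma * (a - mean) / np.sqrt(var + eps), mean, var

def bn_lstm_step(x_t, h_prev, c_prev, p):
    """Hypothetical helper: one BN-LSTM step; returns this timestep's statistics."""
    zero = np.zeros_like(p["b"])                 # assumed: no separate shift for h and x terms
    # With h_prev = 0 at the first step, W_h h_prev is all zeros and normalizes to 0.
    nh, mh, vh = bn(h_prev @ p["W_h"], p["gamma_h"], zero)
    nx, mx, vx = bn(x_t @ p["W_x"], p["gamma_x"], zero)
    f_pre, i_pre, o_pre, g_pre = np.split(nh + nx + p["b"], 4, axis=-1)
    c_t = sigmoid(f_pre) * c_prev + sigmoid(i_pre) * np.tanh(g_pre)
    nc, mc, vc = bn(c_t, p["gamma_c"], p["beta_c"])  # cell is normalized only for the output
    h_t = sigmoid(o_pre) * np.tanh(nc)
    return h_t, c_t, {"h": (mh, vh), "x": (mx, vx), "c": (mc, vc)}

rng = np.random.default_rng(0)
batch, n_in, n_hid, T = 8, 3, 4, 6
p = {
    "W_x": rng.normal(scale=0.1, size=(n_in, 4 * n_hid)),
    "W_h": rng.normal(scale=0.1, size=(n_hid, 4 * n_hid)),
    "b": np.zeros(4 * n_hid),
    "gamma_h": np.full(4 * n_hid, 0.1),
    "gamma_x": np.full(4 * n_hid, 0.1),
    "gamma_c": np.full(n_hid, 0.1),
    "beta_c": np.zeros(n_hid),
}
h = c = np.zeros((batch, n_hid))
per_step_stats = []                              # one entry per timestep, not shared
for x_t in rng.normal(size=(T, batch, n_in)):
    h, c, s = bn_lstm_step(x_t, h, c, p)
    per_step_stats.append(s)
print(len(per_step_stats))  # 6
```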

In their experiments, the authors estimate the population statistics separately for each timestep 1, …, T_max, where T_max is the length of the longest training sequence. When a test sequence is longer than T_max, they use the population statistics of timestep T_max for all timesteps beyond it.

During training, the authors estimate the statistics across the minibatch, independently for each timestep. At test time, they use estimates obtained by averaging the minibatch estimates over the training set.
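A minimal sketch of per-timestep population statistics, assuming a hypothetical PerTimestepStats helper: during training it accumulates the minibatch mean and variance separately for each timestep, and at test time it returns their averages, reusing the statistics of the last timestep for any t beyond T_max. Timesteps are 0-indexed here, so index T_max - 1 corresponds to timestep T_max in the text. The toy activations, whose mean decays toward a stationary value, are invented to mimic an initial transient.

```python
import numpy as np

class PerTimestepStats:
    """Hypothetical helper: population mean/variance for timesteps 0..t_max-1."""
    def __init__(self, t_max, dim):
        self.t_max = t_max
        self.mean_sum = np.zeros((t_max, dim))
        self.var_sum = np.zeros((t_max, dim))
        self.count = np.zeros(t_max)

    def update(self, t, batch):
        """Accumulate the minibatch estimates for timestep t during training."""
        self.mean_sum[t] += batch.mean(axis=0)
        self.var_sum[t] += batch.var(axis=0)
        self.count[t] += 1

    def lookup(self, t):
        """Averaged statistics for timestep t; beyond T_max, reuse the last timestep."""
        t = min(t, self.t_max - 1)
        return self.mean_sum[t] / self.count[t], self.var_sum[t] / self.count[t]

rng = np.random.default_rng(0)
t_max, dim = 5, 2
stats = PerTimestepStats(t_max, dim)
for _ in range(200):                             # 200 toy training minibatches
    for t in range(t_max):
        # Toy activations whose mean decays toward 0 (an initial transient).
        acts = rng.normal(loc=3.0 * np.exp(-t), scale=1.0, size=(32, dim))
        stats.update(t, acts)

m0, _ = stats.lookup(0)                          # early timestep: mean near 3
m_far, _ = stats.lookup(50)                      # t = 50 > T_max: falls back to the last timestep
print(m0, m_far)
```

Sharing one set of statistics across all timesteps would blend the transient mean at t = 0 with the near-stationary ones, which is the kind of averaging the authors report as harmful.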

Initializing γ for Gradient Flow

Figure 1(b) of the paper plots the expected derivative of the tanh nonlinearity as its input distribution widens. When the input standard deviation is low, the input tends to be close to the origin, where the derivative is close to 1. As the standard deviation increases, the expected derivative decreases, because the input is more likely to fall in the saturating regime. At unit standard deviation, the expected derivative is already well below 1.
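The trend can be reproduced with a quick Monte Carlo estimate of E[tanh'(x)] for x ~ N(0, σ²), using tanh'(x) = 1 - tanh²(x). The zero-mean Gaussian input, the sample size and the chosen σ values are assumptions for illustration, not the paper's exact experimental setup.

```python
import numpy as np

def expected_tanh_derivative(std, n=200_000, seed=0):
    """Monte Carlo estimate of E[tanh'(x)] for x ~ N(0, std^2)."""
    x = np.random.default_rng(seed).normal(scale=std, size=n)
    return np.mean(1.0 - np.tanh(x) ** 2)   # tanh'(x) = 1 - tanh(x)^2

for std in (0.1, 0.5, 1.0, 2.0):
    print(f"std={std}: E[tanh'] ~ {expected_tanh_derivative(std):.3f}")
```

At σ = 1 the estimate comes out around 0.6, while at σ = 0.1 it stays close to 1.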

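The authors conjecture that this is what causes the gradient to vanish. Because batch normalization standardizes its input to unit variance before scaling it by γ, the standard deviation of the normalized pre-activation equals γ; initializing γ to 1 therefore places the tanh input at unit standard deviation. They recommend initializing γ to a small value instead.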
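A small NumPy sketch of why a small γ helps: after standardization (with β = 0), the spread of the tanh input equals γ, so γ directly controls how far into the saturating regime the tanh operates. The raw pre-activation distribution is arbitrary, the helper name is hypothetical, and γ = 0.1 is used only as an example of a small value.

```python
import numpy as np

def bn_tanh_mean_derivative(gamma, n=100_000, seed=0):
    """Hypothetical helper: std of the BN output and mean tanh' at that input."""
    rng = np.random.default_rng(seed)
    a = rng.normal(loc=3.0, scale=5.0, size=n)            # arbitrary raw pre-activations
    z = gamma * (a - a.mean()) / np.sqrt(a.var() + 1e-5)  # standardize, then scale by gamma
    return z.std(), np.mean(1.0 - np.tanh(z) ** 2)

for gamma in (1.0, 0.1):
    std, deriv = bn_tanh_mean_derivative(gamma)
    print(f"gamma={gamma}: std of tanh input = {std:.2f}, mean tanh' = {deriv:.3f}")
```

Whatever the scale of the raw pre-activations, the normalized input's standard deviation tracks γ, so the expected derivative at initialization is set by γ alone.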