I would like to apply layer normalization to a recurrent neural network using tf.keras. In TensorFlow 2.0, there is a LayerNormalization class in tf.keras.layers.experimental, but it's unclear how to use it within a recurrent layer like LSTM at each time step (which is how it was designed to be used). Should I create a custom cell, or is there a simpler way?

For example, applying dropout at each time step is as easy as setting the recurrent_dropout argument when creating an LSTM layer, but there is no recurrent_layer_normalization argument.



You can create a custom cell by inheriting from the SimpleRNNCell class, like this:

```python
import numpy as np
from tensorflow.keras.activations import get as get_activation
from tensorflow.keras.layers import LayerNormalization, RNN, SimpleRNNCell
from tensorflow.keras.models import Sequential

class SimpleRNNCellWithLayerNorm(SimpleRNNCell):
    def __init__(self, units, activation="tanh", **kwargs):
        # Build the underlying SimpleRNN step with no activation (linear).
        super().__init__(units, activation=None, **kwargs)
        # Store the real activation under a separate name: SimpleRNNCell.__init__
        # sets self.activation, and super().call() needs it to stay linear.
        self.norm_activation = get_activation(activation)
        self.layer_norm = LayerNormalization()

    def call(self, inputs, states, training=None):
        # Linear step: inputs @ kernel + bias + previous_state @ recurrent_kernel
        outputs, _ = super().call(inputs, states, training=training)
        norm_out = self.norm_activation(self.layer_norm(outputs))
        # The normalized, activated output is also the new hidden state.
        return norm_out, [norm_out]
```

This implementation runs a regular SimpleRNN cell step with no activation (a purely linear step), applies layer norm to the result, and then applies the activation. The result is returned both as the output and as the new hidden state. You can then use it like this:

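```python
model = Sequential([
    RNN(SimpleRNNCellWithLayerNorm(20), return_sequences=True,
        input_shape=[None, 20]),
    # The last recurrent layer returns only its final output, shape (batch, 5),
    # which matches Y_train below.
    RNN(SimpleRNNCellWithLayerNorm(5)),
])

model.compile(loss="mse", optimizer="sgd")
X_train = np.random.randn(100, 50, 20)
Y_train = np.random.randn(100, 5)
history = model.fit(X_train, Y_train, epochs=2)
```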
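To make the per-time-step computation explicit, here is a NumPy-only sketch of the forward pass an RNN layer wrapping this cell performs (no training, no Keras). The function names are hypothetical, gamma and beta stand in for LayerNormalization's learned scale and offset, and epsilon=1e-3 matches the Keras LayerNormalization default. The zero initial state is also an assumption of this sketch.

```python
import numpy as np

def layer_norm(z, gamma, beta, eps=1e-3):
    """Normalize z over its last axis (the units), then scale and shift."""
    mean = z.mean(axis=-1, keepdims=True)
    var = z.var(axis=-1, keepdims=True)
    return gamma * (z - mean) / np.sqrt(var + eps) + beta

def ln_simple_rnn(X, W, U, b, gamma, beta, activation=np.tanh):
    """Unroll a layer-normalized simple RNN over X of shape (batch, time, input_dim).

    Each step: linear combination of input and previous state, then layer
    norm, then activation; the result is both the output and the new state.
    """
    batch, steps, _ = X.shape
    h = np.zeros((batch, W.shape[1]))   # zero initial state
    outputs = []
    for t in range(steps):
        z = X[:, t] @ W + h @ U + b      # the linear (activation=None) cell step
        h = activation(layer_norm(z, gamma, beta))
        outputs.append(h)
    return np.stack(outputs, axis=1)     # (batch, time, units), like return_sequences=True

rng = np.random.default_rng(42)
units = 4
X = rng.normal(size=(2, 5, 3))
W = rng.normal(size=(3, units))
U = rng.normal(size=(units, units))
Y = ln_simple_rnn(X, W, U, np.zeros(units), np.ones(units), np.zeros(units))
print(Y.shape)  # (2, 5, 4)
```

Because normalization happens inside the loop, every time step's pre-activation is rescaled independently, which is the "at each time step" behavior the question asks about.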

For GRU and LSTM cells, people generally apply layer norm to the gates (after the linear combination of the inputs and states, and before the sigmoid activation), so it's a bit trickier to implement. Alternatively, you can probably get good results by simply applying layer norm before the activation and recurrent_activation functions, which is easier to implement.
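A minimal NumPy sketch of the gate-normalization approach for a single LSTM time step follows. It is not a tf.keras cell and not the TensorFlow Addons implementation; the names (ln_lstm_step, layer_norm) are hypothetical. It normalizes each of the four gate pre-activations separately and lets beta replace the usual bias; some variants instead normalize the input and recurrent terms separately, or also normalize the cell state.

```python
import numpy as np

def layer_norm(z, gamma, beta, eps=1e-3):
    """Normalize z over its last axis, then scale (gamma) and shift (beta)."""
    mean = z.mean(axis=-1, keepdims=True)
    var = z.var(axis=-1, keepdims=True)
    return gamma * (z - mean) / np.sqrt(var + eps) + beta

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def ln_lstm_step(x, h, c, W, U, gammas, betas):
    """One LSTM time step with layer norm on each gate's pre-activation.

    x: (batch, input_dim); h, c: (batch, units)
    W: (input_dim, 4 * units); U: (units, 4 * units)
    gammas, betas: four (units,) arrays each, one pair per gate;
    beta takes the place of the usual bias term.
    """
    z = x @ W + h @ U                      # linear combination of inputs and states
    parts = np.split(z, 4, axis=-1)        # input, forget, candidate, output
    zi, zf, zg, zo = [layer_norm(p, g, b) for p, g, b in zip(parts, gammas, betas)]
    i, f, o = sigmoid(zi), sigmoid(zf), sigmoid(zo)  # normalize, then sigmoid
    g = np.tanh(zg)                        # candidate cell values
    c_new = f * c + i * g
    h_new = o * np.tanh(c_new)
    return h_new, c_new

rng = np.random.default_rng(0)
batch, input_dim, units = 3, 8, 4
x = rng.normal(size=(batch, input_dim))
h = np.zeros((batch, units))
c = np.zeros((batch, units))
W = rng.normal(size=(input_dim, 4 * units))
U = rng.normal(size=(units, 4 * units))
gammas = [np.ones(units) for _ in range(4)]
betas = [np.zeros(units) for _ in range(4)]
h, c = ln_lstm_step(x, h, c, W, U, gammas, betas)
print(h.shape, c.shape)  # (3, 4) (3, 4)
```

Turning this into a custom Keras cell would mean creating W, U, and one LayerNormalization layer per gate, and exposing state_size = [units, units] for the (h, c) pair.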


TensorFlow Addons provides a pre-built LayerNormLSTMCell out of the box.

See the TensorFlow Addons documentation for tfa.rnn.LayerNormLSTMCell for more details. You may have to install the tensorflow-addons package before you can import this cell:

```shell
pip install tensorflow-addons
```
