Vanishing Gradient Problem and How to Fix It
What Is the Vanishing Gradient Problem?
Neural networks are typically trained with gradient-based optimizers such as stochastic gradient descent (SGD). Training involves first calculating the prediction error (loss) made by the model and then using that error to estimate a gradient, which is used to update each weight in the network so that less error is made next time. This error gradient is propagated backward through the network, from the output layer to the input layer, by the backpropagation algorithm.
As backpropagation moves backward from the output layer toward the input layer, the gradients often get smaller and smaller and approach zero, leaving the weights of the initial (lower) layers nearly unchanged. As a result, those layers learn very slowly or not at all, and gradient descent may fail to converge to a good solution. This is known as the vanishing gradient problem.
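To make this concrete, consider a deliberately simplified network (an assumption for illustration): a chain of n scalar sigmoid units, a_k = σ(z_k) with z_k = w_k a_{k-1} + b_k and a_0 = x, trained with loss L and learning rate η. Each weight is updated with the gradient-descent rule on the first line, and by the chain rule the gradient for the first weight is a product of one factor per layer:

```latex
w \leftarrow w - \eta \, \frac{\partial L}{\partial w}

\frac{\partial L}{\partial w_1}
  = \frac{\partial L}{\partial a_n}
    \left( \prod_{k=2}^{n} \frac{\partial a_k}{\partial a_{k-1}} \right)
    \frac{\partial a_1}{\partial w_1}
  = \frac{\partial L}{\partial a_n}
    \left( \prod_{k=2}^{n} w_k \, \sigma'(z_k) \right)
    \sigma'(z_1) \, x
```

Since σ'(z) ≤ 1/4, if every |w_k| ≤ 1 then |∂L/∂w_1| ≤ |∂L/∂a_n| · |x| · (1/4)^n, so the update applied to w_1 shrinks exponentially with depth.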
Why does it happen?
Certain activation functions, like the sigmoid function, squish a large input space into a small output range between 0 and 1. Therefore, once the input is large in magnitude, even a significant change in the input causes only a slight change in the output, so the derivative becomes small. In fact, the derivative of the sigmoid, σ'(x) = σ(x)(1 - σ(x)), is at most 0.25 (at x = 0) and approaches zero as |x| grows.
When n hidden layers use an activation like the sigmoid function, backpropagation multiplies n of these small derivatives together (by the chain rule). Thus, the gradient decreases exponentially as we propagate back toward the initial layers.
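As a quick illustration, the plain-Python sketch below evaluates the sigmoid and its derivative σ(x)(1 - σ(x)) at a few points. The helper names `sigmoid` and `sigmoid_grad` are our own; the derivative peaks at 0.25 at x = 0 and shrinks rapidly as |x| grows, which is what "saturation" means here.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic sigmoid: squashes any real input into the range (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def sigmoid_grad(x: float) -> float:
    """Derivative of the sigmoid: sigma(x) * (1 - sigma(x)), at most 0.25."""
    s = sigmoid(x)
    return s * (1.0 - s)


# The derivative is largest at x = 0 and tiny once |x| is large (saturation).
for x in (0.0, 2.0, 5.0, 10.0, -10.0):
    print(f"x={x:6.1f}  sigmoid={sigmoid(x):.6f}  derivative={sigmoid_grad(x):.2e}")
```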
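The exponential decay can be checked numerically with a toy model. This is a sketch under simplifying assumptions: a chain of scalar sigmoid "layers" a_k = sigmoid(weight * a_{k-1}) with no biases and a shared weight of 1.0 (the helper `chain_gradient` is hypothetical). It multiplies the local derivatives exactly as the chain rule prescribes.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def chain_gradient(x: float, n_layers: int, weight: float = 1.0) -> float:
    """Return d a_n / d x for n stacked scalar layers a_k = sigmoid(weight * a_{k-1}).

    The local derivatives are multiplied during the forward sweep; for a scalar
    chain the order of multiplication does not matter.
    """
    grad = 1.0
    a = x
    for _ in range(n_layers):
        s = sigmoid(weight * a)
        grad *= weight * s * (1.0 - s)  # local derivative d a_k / d a_{k-1}
        a = s
    return grad


# With weight = 1.0 each factor is at most 0.25, so the gradient is at most 0.25 ** n.
for n in (1, 5, 10, 20):
    print(n, chain_gradient(0.5, n))
```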
How to fix it?
1. Use non-saturating activation functions
Because of the nature of the sigmoid activation function, it saturates for inputs of large magnitude (negative or positive), where its derivative is close to zero. This saturation turned out to be a major reason behind vanishing gradients, which makes sigmoid a poor choice for the hidden layers of a deep network.
To tackle the saturation of activation functions like sigmoid and tanh, use a non-saturating function such as ReLU or one of its variants (for example, Leaky ReLU or ELU).
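The sketch below compares how a gradient survives a path through 10 layers when each layer uses sigmoid versus ReLU. The pre-activation values are hypothetical (every layer sees 2.0), chosen only to make the comparison concrete; in Keras you would typically get ReLU by passing `activation="relu"` to a `Dense` or `Conv2D` layer.

```python
import math


def sigmoid_grad(x: float) -> float:
    """Derivative of the sigmoid, at most 0.25."""
    s = 1.0 / (1.0 + math.exp(-x))
    return s * (1.0 - s)


def relu_grad(x: float) -> float:
    """Derivative of ReLU(x) = max(0, x): 1 for positive inputs, 0 otherwise."""
    return 1.0 if x > 0 else 0.0


def leaky_relu_grad(x: float, alpha: float = 0.01) -> float:
    """Derivative of Leaky ReLU: 1 for positive inputs, a small slope alpha otherwise."""
    return 1.0 if x > 0 else alpha


# Hypothetical pre-activations seen along one path through 10 stacked layers.
pre_activations = [2.0] * 10

sigmoid_path = math.prod(sigmoid_grad(z) for z in pre_activations)
relu_path = math.prod(relu_grad(z) for z in pre_activations)

print(f"sigmoid: {sigmoid_path:.2e}")  # product of ten factors of about 0.105
print(f"relu:    {relu_path:.2e}")     # 1.00e+00: the gradient passes through unchanged
```

Note that ReLU's derivative is exactly 0 for negative inputs, so units stuck in that region stop learning ("dying ReLU"); variants like Leaky ReLU keep a small nonzero slope there.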
2. Proper weight initialization
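There are different ways to initialize weights, for example, Xavier/Glorot initialization and Kaiming (He) initialization. These schemes scale the random initial weights according to the number of inputs and outputs of a layer, so that the variance of activations and gradients stays roughly constant from layer to layer instead of shrinking (or growing) exponentially. The Keras API has a default weight initializer for each type of layer (for example, Dense layers use Glorot uniform for their kernel by default); the available initializers for tf.keras are listed in the Keras documentation.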
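You can inspect the weights of a layer as follows:
```python
# tf.keras: assuming `model` is an already-built Keras model.
# For a Dense layer this returns a list of NumPy arrays: [kernel, bias].
weights = model.layers[1].get_weights()
```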
Using the Xavier (Glorot) normal initializer with Keras:
```python
import tensorflow as tf

initializer = tf.keras.initializers.GlorotNormal()
layer = tf.keras.layers.Dense(3, kernel_initializer=initializer)
```
Proper initialization does not guarantee that you will resolve these issues, but it makes your network more robust when combined with other methods.
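For intuition about what these initializers compute, here is a minimal NumPy sketch of the two schemes, assuming their commonly used normal variants: Glorot normal with standard deviation sqrt(2 / (fan_in + fan_out)) and He normal with sqrt(2 / fan_in). The helper names are hypothetical, and unlike the Keras initializers (which sample from a truncated normal), this sketch samples from an untruncated normal for simplicity. The (fan_in, fan_out) shape matches the (inputs, units) layout of a Dense kernel.

```python
import numpy as np


def glorot_normal(fan_in: int, fan_out: int, rng: np.random.Generator) -> np.ndarray:
    """Xavier/Glorot normal (untruncated): zero mean, std = sqrt(2 / (fan_in + fan_out))."""
    std = np.sqrt(2.0 / (fan_in + fan_out))
    return rng.normal(0.0, std, size=(fan_in, fan_out))


def he_normal(fan_in: int, fan_out: int, rng: np.random.Generator) -> np.ndarray:
    """Kaiming/He normal (untruncated): zero mean, std = sqrt(2 / fan_in), aimed at ReLU layers."""
    std = np.sqrt(2.0 / fan_in)
    return rng.normal(0.0, std, size=(fan_in, fan_out))


rng = np.random.default_rng(0)
W_glorot = glorot_normal(512, 256, rng)
W_he = he_normal(512, 256, rng)

# Empirical standard deviation vs. the target for each scheme.
print(W_glorot.std(), np.sqrt(2.0 / (512 + 256)))
print(W_he.std(), np.sqrt(2.0 / 512))
```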
3. Residual networks
If you are using Convolutional Neural Networks, for example, and you are suffering from vanishing or exploding gradients, it might make sense to move to an architecture like ResNet (residual networks). Unlike plain stacked networks, ResNets add so-called skip (shortcut) connections that pass a block's input x directly to its output, y = x + F(x). Because the identity path has a derivative of 1, these connections act as gradient highways that let the gradient flow back through the network largely unhindered.
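A toy scalar sketch of why skip connections help: it compares a plain stack a_k = sigmoid(a_{k-1}) with a residual stack a_k = a_{k-1} + sigmoid(a_{k-1}), where the sigmoid stands in for a hypothetical residual branch F. In the residual case each local derivative is 1 + sigmoid'(a_{k-1}), so in this simplified setting the product can never drop below 1. Real ResNet blocks use matrices and learned branches, so this guarantee does not carry over literally; the point is only the effect of the identity path.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def plain_chain_grad(x: float, n_layers: int) -> float:
    """d a_n / d x for a plain stack a_k = sigmoid(a_{k-1})."""
    grad, a = 1.0, x
    for _ in range(n_layers):
        s = sigmoid(a)
        grad *= s * (1.0 - s)        # local derivative sigmoid'(a_{k-1}), at most 0.25
        a = s
    return grad


def residual_chain_grad(x: float, n_layers: int) -> float:
    """d a_n / d x for a residual stack a_k = a_{k-1} + sigmoid(a_{k-1})."""
    grad, a = 1.0, x
    for _ in range(n_layers):
        s = sigmoid(a)
        grad *= 1.0 + s * (1.0 - s)  # the identity path contributes the "1 +"
        a = a + s
    return grad


for n in (5, 10, 20):
    print(n, plain_chain_grad(0.5, n), residual_chain_grad(0.5, n))
```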
4. Batch normalization (BN)
BN layers can also help with this issue. As stated before, the problem arises when a large input space is mapped to a small one, causing the derivatives to vanish. Batch normalization reduces this problem by normalizing each layer's inputs with the mean and variance of the current mini-batch, so they are less likely to reach the saturated outer edges of the sigmoid function. An example of using BN with Keras (TensorFlow):
```python
from tensorflow.keras import Input
from tensorflow.keras.layers import Activation, BatchNormalization, Dense
from tensorflow.keras.models import Sequential

# instantiate model
model = Sequential()
model.add(Input(shape=(14,)))  # 14 input features

# The general use case is to use BN between the linear and non-linear layers in your network,
# because it normalizes the input to your activation function.
# There is, however, considerable debate about whether BN works best before
# the non-linearity of the current layer or after the activation function.

model.add(Dense(64, kernel_initializer='random_uniform'))  # linear layer
model.add(BatchNormalization())  # BN
model.add(Activation('tanh'))  # non-linear layer
```
Batch normalization applies a transformation that maintains the mean output close to 0 and the output standard deviation close to 1.
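To see why normalization helps, here is a NumPy sketch of the training-mode batch-norm forward pass: per-feature mini-batch mean and variance, then a scale gamma and shift beta (learnable in a real BN layer, fixed to 1 and 0 here). The input data is synthetic and deliberately offset so that most raw values land in the sigmoid's saturated region; eps = 1e-3 mirrors Keras's default epsilon. At inference time Keras uses moving averages of the mean and variance instead, which this sketch does not model.

```python
import numpy as np


def batch_norm_forward(x, gamma, beta, eps=1e-3):
    """Training-mode batch norm: normalize each feature (column) with the
    mini-batch mean and variance, then scale by gamma and shift by beta."""
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    x_hat = (x - mean) / np.sqrt(var + eps)
    return gamma * x_hat + beta


def sigmoid_grad(z):
    s = 1.0 / (1.0 + np.exp(-z))
    return s * (1.0 - s)


rng = np.random.default_rng(0)
# Synthetic pre-activations: batch of 256, 64 features, large offset and spread.
z = rng.normal(loc=50.0, scale=20.0, size=(256, 64))
z_bn = batch_norm_forward(z, gamma=np.ones(64), beta=np.zeros(64))

print("mean sigmoid derivative before BN:", sigmoid_grad(z).mean())
print("mean sigmoid derivative after BN: ", sigmoid_grad(z_bn).mean())
```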
References:
- https://datascience.stackexchange.com/a/72352/136830
- https://towardsdatascience.com/the-vanishing-gradient-problem-69bf08b15484
- https://keras.io/api/layers/initializers/#usage-of-initializers
- https://machinelearningmastery.com/how-to-fix-vanishing-gradients-using-the-rectified-linear-activation-function/