The vanishing gradient problem is a common issue when training deep neural networks, especially very deep architectures. It makes it difficult for the model to learn from data during training.
What is Vanishing Gradient?
In deep learning, training happens through a method called backpropagation, where the model adjusts its weights using gradients (a kind of slope) of the loss function with respect to each weight. These gradients tell the model how much to change each weight to improve performance.
However, in deep neural networks (ones with many layers), the gradients can become very small as they are propagated backward through the layers. This is called the vanishing gradient problem.
As a result:
- Early layers (closer to the input) receive almost no updates.
- The network stops learning or learns very slowly.
When Does Vanishing Gradient Happen?
- Very Deep Networks: The more layers there are, the more times the gradient is multiplied by small factors on its way back to the input.
- Activation Functions:
- Sigmoid and tanh squash their inputs into a narrow output range (between 0 and 1 for sigmoid, -1 and 1 for tanh).
- Their derivatives are also small (at most 0.25 for sigmoid and 1 for tanh).
- Multiplying many small numbers together makes them even smaller.
Example:
Let’s say you have 10 layers, and the gradient is scaled by about 0.5 at each layer. After 10 layers:
Final gradient = 0.5^10 ≈ 0.00098
This tiny number means almost no learning for the early layers.
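You can see this effect directly in a small PyTorch sketch (a toy setup of my own, with arbitrary layer sizes): a stack of ten sigmoid-activated linear layers, where the first layer ends up with a far smaller gradient than the last.

import torch
import torch.nn as nn

torch.manual_seed(0)

# Ten sigmoid-activated linear layers (sizes chosen only for illustration).
layers = nn.ModuleList([nn.Linear(64, 64) for _ in range(10)])

h = torch.randn(8, 64)
for layer in layers:
    h = torch.sigmoid(layer(h))

h.sum().backward()

# The first layer's gradient is typically orders of magnitude smaller.
print("first layer grad norm:", layers[0].weight.grad.norm().item())
print("last layer grad norm: ", layers[-1].weight.grad.norm().item())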
How to Solve Vanishing Gradient?
1. Use ReLU Activation (or variants)
- ReLU (Rectified Linear Unit): f(x) = max(0, x)
- Its derivative is either 1 (for positive inputs) or 0 (for negative).
- So the gradient is not repeatedly scaled down the way it is with sigmoid.
Example in PyTorch:
Before:
x = torch.sigmoid(linear(x))
After:
x = torch.relu(linear(x))
Other ReLU variants: LeakyReLU, ELU, GELU (often used in transformers); see the sketch below.
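As a rough sketch (the layer sizes below are my own, not from any particular model), here is what a small network looks like with ReLU and GELU as the hidden activations instead of sigmoid:

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(128, 64),
    nn.ReLU(),    # was nn.Sigmoid(); gradient is 1 for positive inputs
    nn.Linear(64, 64),
    nn.GELU(),    # smooth ReLU variant, common in transformers
    nn.Linear(64, 10),
)

out = model(torch.randn(32, 128))  # shape: (32, 10)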
2. Use Batch Normalization
- This technique normalizes the inputs to each layer.
- Helps keep gradients in a stable range.
- Often improves both speed and performance.
Example in PyTorch:
nn.Sequential(
    nn.Linear(128, 64),
    nn.BatchNorm1d(64),
    nn.ReLU()
)
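For a bit more context, here is the same pattern repeated in a slightly deeper stack (the dimensions are my own illustration), with a forward pass to show the expected shapes:

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(128, 64), nn.BatchNorm1d(64), nn.ReLU(),
    nn.Linear(64, 64), nn.BatchNorm1d(64), nn.ReLU(),
    nn.Linear(64, 10),
)

x = torch.randn(32, 128)  # BatchNorm1d needs a batch of more than one sample in training mode
out = model(x)            # shape: (32, 10)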
3. Use Proper Weight Initialization
Some initialization schemes set the initial weights so that activation and gradient magnitudes stay roughly stable from layer to layer.
- Xavier (Glorot) Initialization: Good for tanh and sigmoid
- He Initialization: Good for ReLU
Example in PyTorch:
torch.nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
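To apply this across a whole model, one common pattern (the helper name init_weights below is my own choice) is to pass an initialization function to model.apply:

import torch.nn as nn

def init_weights(module):
    # He (Kaiming) initialization for every Linear layer, suited to ReLU.
    if isinstance(module, nn.Linear):
        nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
        if module.bias is not None:
            nn.init.zeros_(module.bias)

model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))
model.apply(init_weights)  # runs init_weights on every submodule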
4. Use Residual Connections (Skip Connections)
Used in ResNet and similar architectures.
- Skip connections let gradients flow directly across layers.
- Largely resolves vanishing gradient issues and makes very deep networks trainable.
Concept: Instead of computing:
x = F(x)
do:
x = x + F(x)
Example in PyTorch:
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        # Skip connection: the input is added to the output, giving the gradient a direct path.
        return x + F.relu(self.linear(x))
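A quick usage sketch (the depth and width here are arbitrary choices of mine): residual blocks can be stacked much deeper than plain layers without the gradient dying out.

import torch
import torch.nn as nn

# Reuses the ResidualBlock class defined above.
blocks = nn.Sequential(*[ResidualBlock(64) for _ in range(20)])
out = blocks(torch.randn(32, 64))  # shape: (32, 64)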
5. Use Shorter Networks or Pretrained Models
If you don't need a very deep network:
- Use fewer layers.
- Or start from a pretrained model (like ResNet or BERT) whose architecture already builds in these fixes; see the loading sketch below.
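For example, a minimal sketch of reusing a pretrained ResNet from torchvision (assuming a recent torchvision version with the weights-enum API) and swapping in a new classification head:

import torch.nn as nn
import torchvision.models as models

# Pretrained ResNet-18; its residual connections already address vanishing gradients.
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, 10)  # new head for a 10-class task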
References
- Y. Bengio et al., "Learning long-term dependencies with gradient descent is difficult", IEEE Transactions on Neural Networks (1994)
- K. He et al., "Deep Residual Learning for Image Recognition", CVPR (2016)
- S. Ioffe and C. Szegedy, "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift", ICML (2015)
- PyTorch Official Documentation: https://pytorch.org/docs/stable/nn.html