How to Prevent Overfitting in Machine Learning Models

Very deep neural networks with a massive number of parameters are very powerful machine learning systems. But in networks this large, overfitting is a common and serious problem. Learning how to deal with overfitting is essential to mastering machine learning. The fundamental issue in machine learning is the tension between optimization and generalization. Optimization refers to adjusting a model to get the best performance possible on the training data (the learning in machine learning).

In contrast, generalization refers to how well the trained model performs on data it has never seen before (for example, a held-out test set). The goal is good generalization. But you don’t control generalization directly; you can only adjust the model based on its training data.

How do you know whether a model is overfitting?

The clearest sign of overfitting is that the model's accuracy is high on the training set but drops significantly on new data or on the test set. This means the model has learned the training data very well but cannot generalize. Such a model is of little use in production or in A/B tests in most domains.
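The gap between training and test accuracy can be seen in a small experiment. The sketch below is only an illustration under stated assumptions: the synthetic data, the 30% label noise, and the helper names (`make_data`, `predict_1nn`, `accuracy`) are all made up for this example. It uses a one-nearest-neighbor "model" that simply memorizes the training set, which is an extreme case of overfitting.

```python
import random


def make_data(n, rng, noise=0.3):
    """Generate (x, y) pairs: y is 1 when x > 0.5, with a fraction of labels flipped."""
    data = []
    for _ in range(n):
        x = rng.random()
        y = int(x > 0.5)
        if rng.random() < noise:
            y = 1 - y  # label noise that a memorizing model will also memorize
        data.append((x, y))
    return data


def predict_1nn(train, x):
    """Return the label of the training point closest to x (pure memorization)."""
    return min(train, key=lambda point: abs(point[0] - x))[1]


def accuracy(train, data):
    """Fraction of points in `data` that the memorizing model labels correctly."""
    correct = sum(predict_1nn(train, x) == y for x, y in data)
    return correct / len(data)


rng = random.Random(0)
train = make_data(200, rng)
test = make_data(200, rng)

train_acc = accuracy(train, train)  # every training point is its own nearest neighbor
test_acc = accuracy(train, test)    # unseen points expose the memorized noise
gap = train_acc - test_acc

print(f"train accuracy: {train_acc:.2f}")
print(f"test accuracy:  {test_acc:.2f}")
print(f"gap:            {gap:.2f}")
```

A large positive gap is the warning sign described above. How large a gap is acceptable depends on the problem, so treat any fixed threshold as a judgment call.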

How to prevent overfitting?

Suppose you have found that your model overfits. What can you do about it? Fortunately, there are many techniques you can try. Below I describe a few of the most widely used ones.

1. Reduce the network size

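The simplest way to prevent overfitting is to reduce the model's size, that is, the number of learnable parameters in the model (which is determined by the number of layers and the number of units per layer). A model with fewer parameters has less capacity to memorize the training data.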
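To make the link between layers, units, and parameter count concrete, here is a minimal sketch that counts the learnable parameters of a fully connected network. The function name `count_dense_params` and the two example architectures are assumptions chosen for illustration; each layer is assumed to have a weight matrix plus one bias per unit.

```python
def count_dense_params(layer_sizes):
    """Count weights and biases in a fully connected network.

    `layer_sizes` lists the input size followed by the number of units
    in each layer, for example [784, 512, 512, 10].
    """
    total = 0
    for n_in, n_out in zip(layer_sizes, layer_sizes[1:]):
        total += n_in * n_out  # weight matrix
        total += n_out         # one bias per unit
    return total


large = count_dense_params([784, 512, 512, 10])
small = count_dense_params([784, 16, 10])

print(f"large network: {large} parameters")
print(f"small network: {small} parameters")
```

Fewer layers or fewer units per layer shrink the count quickly. In practice you would start small and grow the network until validation performance stops improving.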

2. Cross-validation

In cross-validation, the initial training data is used to generate multiple small train-test splits. These splits are then used to tune the model. The most popular form is K-fold cross-validation, where K is the number of folds: the data is partitioned into K subsets, and the model is trained K times, each time on K-1 of the subsets and evaluated on the remaining one.
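The splitting step of K-fold cross-validation can be sketched in a few lines. This is an illustrative version, not a library implementation: the names `k_fold_indices` and `cross_val_score` and the `fit_and_score` callback are hypothetical, and the sketch assumes the samples have already been shuffled.

```python
def k_fold_indices(n_samples, k):
    """Yield (train_indices, validation_indices) for each of the k folds."""
    if not 2 <= k <= n_samples:
        raise ValueError("k must be between 2 and n_samples")
    indices = list(range(n_samples))
    # Spread the remainder so fold sizes differ by at most one.
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        validation = indices[start:start + size]
        train = indices[:start] + indices[start + size:]
        yield train, validation
        start += size


def cross_val_score(n_samples, k, fit_and_score):
    """Average the score returned by a user-supplied callback over all folds.

    `fit_and_score(train_indices, validation_indices)` is a hypothetical
    callback that trains on the first set and returns a score on the second.
    """
    scores = [fit_and_score(train, val) for train, val in k_fold_indices(n_samples, k)]
    return sum(scores) / len(scores)


folds = list(k_fold_indices(10, 3))
for train, validation in folds:
    print("train:", train, "validation:", validation)
```

Every sample lands in exactly one validation fold, so each data point is used for evaluation once and for training K-1 times.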

3. Add weight regularization

Given two explanations for something, the explanation most likely to be correct is the simplest one — the one that makes fewer assumptions. This idea also applies to the models learned by neural networks: given some training data and network architecture, multiple sets of weight values could explain the data. Simpler models are less likely to overfit than complex ones. A simple model in this context is a model where the distribution of parameter values has less entropy (or a model with fewer parameters). Thus a common way to mitigate overfitting is to put constraints on the complexity of a network by forcing its weights to take only small values, which makes the distribution of weight values more regular. This is called weight regularization, and it’s done by adding to the loss function of the network a cost associated with having large weights. This cost comes in two flavors:

L1 regularization — The cost added is proportional to the absolute value of the weight coefficients.

L2 regularization — The cost added is proportional to the square of the value of the weight coefficients. L2 regularization is also called weight decay in the context of neural networks.
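The two penalties can be written out directly. The sketch below is a minimal illustration under assumptions: the function names, the example weights, and the regularization strength `lam` are invented for this example, and the weights are treated as a flat list of numbers.

```python
def l1_penalty(weights, lam):
    """Cost proportional to the sum of absolute weight values."""
    return lam * sum(abs(w) for w in weights)


def l2_penalty(weights, lam):
    """Cost proportional to the sum of squared weight values."""
    return lam * sum(w * w for w in weights)


def regularized_loss(data_loss, weights, lam, kind="l2"):
    """Add the chosen weight penalty to the loss measured on the data."""
    if kind == "l1":
        return data_loss + l1_penalty(weights, lam)
    if kind == "l2":
        return data_loss + l2_penalty(weights, lam)
    raise ValueError("kind must be 'l1' or 'l2'")


weights = [0.5, -1.0, 2.0]
lam = 0.01  # regularization strength, an illustrative value

print("L1 penalty:", l1_penalty(weights, lam))
print("L2 penalty:", l2_penalty(weights, lam))
print("loss with L2:", regularized_loss(0.30, weights, lam, kind="l2"))
```

Because the penalty grows with the size of the weights, minimizing the combined loss pushes the optimizer toward smaller weight values. The L2 term punishes a single large weight much more heavily than the L1 term does.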

4. Remove irrelevant features

Improve the data by removing irrelevant features. A dataset may contain many features that do not contribute much to the prediction, and removing those less important features can improve accuracy and reduce overfitting. You can use the scikit-learn feature selection module for this purpose.
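As a rough illustration of the idea, the sketch below drops features whose linear correlation with the target is weak. This is a deliberately crude stand-in written in plain Python, not the scikit-learn API: the helper names, the toy dataset, and the 0.5 threshold are all assumptions, and a correlation filter will miss features that matter only in non-linear ways.

```python
import math


def pearson(xs, ys):
    """Pearson correlation of two equal-length lists (0.0 if either is constant)."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    if var_x == 0 or var_y == 0:
        return 0.0  # a constant column carries no information about the target
    return cov / math.sqrt(var_x * var_y)


def select_features(rows, target, threshold):
    """Return indices of columns whose |correlation| with the target meets the threshold."""
    keep = []
    for j in range(len(rows[0])):
        column = [row[j] for row in rows]
        if abs(pearson(column, target)) >= threshold:
            keep.append(j)
    return keep


# Toy data: column 0 tracks the target, column 1 is unrelated, column 2 is constant.
rows = [
    [0.10, 1.0, 5.0], [0.20, 2.0, 5.0], [0.15, 1.0, 5.0], [0.30, 2.0, 5.0],
    [0.90, 1.0, 5.0], [0.80, 2.0, 5.0], [0.95, 1.0, 5.0], [0.70, 2.0, 5.0],
]
target = [0, 0, 0, 0, 1, 1, 1, 1]

selected = select_features(rows, target, threshold=0.5)
reduced_rows = [[row[j] for j in selected] for row in rows]
print("kept feature indices:", selected)
```

The selection should be computed on the training data only and then applied unchanged to the test data, otherwise information from the test set leaks into the model.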

5. Add a dropout layer

Dropout, applied to a layer, consists of randomly dropping out (setting to zero) a number of output features of the layer during training. Suppose a given layer would normally return the vector [0.2, 0.5, 1.3, 0.8, 1.1] for a given input sample during training. After applying dropout, this vector will have a few zero entries distributed randomly: for example, [0, 0.5, 1.3, 0, 1.1].
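The zeroing step can be sketched on a plain list of numbers. This is an illustration, not a framework layer: the function name, the `rate` of 0.4, and the seed are assumptions. The sketch also scales the surviving values by 1 / (1 - rate), a common variant often called inverted dropout, so the kept entries will be larger than in the unscaled example above.

```python
import random


def dropout(values, rate, rng, training=True):
    """Zero each entry with probability `rate` during training.

    Surviving entries are scaled by 1 / (1 - rate) (inverted dropout, an
    assumption of this sketch) so the expected sum of the output is unchanged.
    Outside of training the values are returned as they are.
    """
    if not 0.0 <= rate < 1.0:
        raise ValueError("rate must be in [0, 1)")
    if not training or rate == 0.0:
        return list(values)
    keep_scale = 1.0 / (1.0 - rate)
    return [0.0 if rng.random() < rate else v * keep_scale for v in values]


rng = random.Random(42)
layer_output = [0.2, 0.5, 1.3, 0.8, 1.1]

dropped = dropout(layer_output, rate=0.4, rng=rng)
unchanged = dropout(layer_output, rate=0.4, rng=rng, training=False)

print("during training:", dropped)
print("at test time:   ", unchanged)
```

Which entries are zeroed changes on every call, so the network cannot rely on any single output feature always being present.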

6. Data augmentation

Another straightforward way to reduce overfitting is to increase the size of the training data. When collecting more data is not practical, you can generate additional samples from the data you already have. With images, for example, you can rotate, flip, scale, or shift each image. This technique is known as data augmentation, and it often improves the model's accuracy on unseen data considerably.
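A few of these transformations can be sketched on a tiny image stored as a list of pixel rows. The function names and the 2x3 example image are assumptions for illustration only. Real pipelines typically use an image library and apply the transformations randomly during training.

```python
def flip_horizontal(image):
    """Mirror the image left to right."""
    return [row[::-1] for row in image]


def flip_vertical(image):
    """Mirror the image top to bottom."""
    return [row[:] for row in image[::-1]]


def rotate_90_clockwise(image):
    """Rotate the image by 90 degrees clockwise."""
    return [list(row) for row in zip(*image[::-1])]


def shift_right(image, pixels, fill=0):
    """Shift the image right by a non-negative number of pixels, padding with `fill`."""
    width = len(image[0])
    pixels = min(pixels, width)
    return [[fill] * pixels + row[:width - pixels] for row in image]


def augment(image):
    """Return the original image together with four transformed copies."""
    return [
        image,
        flip_horizontal(image),
        flip_vertical(image),
        rotate_90_clockwise(image),
        shift_right(image, 1),
    ]


image = [
    [1, 2, 3],
    [4, 5, 6],
]
augmented = augment(image)
for variant in augmented:
    print(variant)
```

Each variant keeps the label of the original image, so one labeled sample becomes five. Only use transformations that preserve the label: flipping a digit such as 6 vertically, for instance, may turn it into something that looks like a different class.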

This blog was originally published at Medium
