Written by Hendrik. Jul 27, 2021
This semester, I took part in introductory courses on Artificial Intelligence and Natural Language Processing. Machine learning, which I had no prior knowledge of, came up at every corner. Unfortunately, it was only ever explained vaguely and we never looked into it very deeply. It's been something of a black box all semester, merely a side note. But this whole machine learning thing seems interesting, and it's made me curious to learn more about it.
Pedro Domingos, author of The Master Algorithm, claims that overfitting is the most important problem in machine learning. That's quite a bold claim. We had only discussed overfitting once during the semester, so it didn't seem like that big of a deal. So, what's all the fuss about? What is overfitting, and how can it be avoided?
One of the three main types of machine learning is called supervised learning. Here the machine is fed a large set of labeled data ("labeled" meaning that it is known what each example represents), which is used to train the model. The machine makes predictions about the data, and because the labels (and thus the correct outcomes) are known, it can use them to adjust the model's parameters until the predictions sufficiently resemble the actual results. A consequence of this is that, generally, the more data is used in training, the better the model will be at making predictions.
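That fitting loop can be sketched in a few lines of plain Python. This is only a toy illustration of my own (a single-parameter model predicting w·x, made-up labeled data, and a hand-rolled update rule), not a real training pipeline:

```python
# A toy supervised-learning loop: labeled examples (x, y) drawn from y = 2x,
# and a one-parameter model that predicts w * x.
data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]  # labeled training data

w = 0.0    # the model's single parameter, initially uninformed
lr = 0.01  # learning rate: how far each correction moves w

for _ in range(1000):         # repeatedly sweep over the training data
    for x, y in data:
        pred = w * x          # the model's prediction
        error = pred - y      # compare against the known label
        w -= lr * error * x   # nudge w to shrink the error

print(round(w, 3))  # converges to 2.0, the true relationship
```

The point is only the shape of the process: predict, compare against the label, adjust the parameters, repeat.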
The goal of learning is to be able to take unlabeled data, where the outcome is not known, and make accurate predictions about what it represents. In other words, the machine has to be able to generalize from the data it has seen.
As humans, it is easy to make generalizations. For example, a car salesman may lie to you in order to make a sale. From this, you might infer that all salesmen lie in order to sell you something. This is generalization, which computers have a hard time doing.
When the machine is shown never-before-seen data, it must generalize from what it has learned. How well it can do so determines how successful the model is.
Overfitting simply means that the machine is not able to make generalizations. This happens because the model has memorized the training data rather than learning the underlying trends. It sees patterns in the data that are not actually there; it models the training data too closely, details and noise included.
A telltale sign that a model has been overfit is that it was able to make accurate predictions during the training phase with labeled data but is unable to do so when shown unlabeled data.
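This train-versus-test gap is easy to reproduce. The sketch below is a toy setup of my own (NumPy polynomial fitting standing in for a learner, not anything from the course): it memorizes ten noisy points with a degree-9 polynomial and then checks the fit on fresh points from the same trend:

```python
import numpy as np

rng = np.random.default_rng(0)

# Ten noisy labeled samples of the simple underlying trend y = x
x_train = np.linspace(0.0, 1.0, 10)
y_train = x_train + rng.normal(0.0, 0.1, size=10)

# Fresh samples from the same trend, standing in for unseen data
x_test = np.linspace(0.05, 0.95, 10)
y_test = x_test + rng.normal(0.0, 0.1, size=10)

# A degree-9 polynomial has enough parameters to pass through all ten
# training points exactly, noise included
coeffs = np.polyfit(x_train, y_train, 9)

train_err = np.mean((np.polyval(coeffs, x_train) - y_train) ** 2)
test_err = np.mean((np.polyval(coeffs, x_test) - y_test) ** 2)

print(train_err)  # essentially zero: the training data is memorized
print(test_err)   # larger: the memorized details do not carry over
```

The near-perfect training score is exactly the telltale sign described above; the fresh points expose how little was actually learned.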
Models that are learning walk a fine line between being blind and hallucinating. It is easy for a model attempting to learn to be too restricted and thus not identify any patterns. It is also easy for a model to be overly complex and recognize patterns in the data that do not exist, thus hallucinating.
Pedro Domingos describes the following:
"A good learner is forever walking the narrow path between blindness and hallucination."
This is the central problem in machine learning: it is difficult to stay on that narrow path.
Restricting what the model can learn is the only safe way to avoid overfitting. An example of this would be only allowing it to learn short, conjunctive concepts that relate to one another.
There are a few more ways to avoid overfitting: holding out part of the labeled data for validation (or doing full cross-validation), adding a regularization term that penalizes overly complex models, stopping training early once performance on held-out data stops improving, and simply gathering more training data.
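A common safeguard, holding out validation data, fits in a short sketch. The setup below is again my own NumPy-polynomial illustration: it picks a model by its score on points it never trained on:

```python
import numpy as np

rng = np.random.default_rng(2)

# Thirty noisy labeled samples of the trend y = x
x = np.linspace(0.0, 1.0, 30)
y = x + rng.normal(0.0, 0.1, size=30)

# Hold out a third of the labeled data: train on 20 points, validate on 10
idx = rng.permutation(30)
x_tr, y_tr = x[idx[:20]], y[idx[:20]]
x_va, y_va = x[idx[20:]], y[idx[20:]]

def val_mse(degree):
    """Fit on the training split, score on the held-out split."""
    coeffs = np.polyfit(x_tr, y_tr, degree)
    return np.mean((np.polyval(coeffs, x_va) - y_va) ** 2)

# Choose the degree by validation performance, not training performance:
# a model that merely memorizes the training split scores badly here
best = min(range(1, 10), key=val_mse)
print(best)
```

Because the held-out points were never available for memorization, they punish exactly the kind of noise-fitting that training error rewards.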
Thanks for reading my post. I’d love to get feedback from you, so feel free to shoot me a tweet!
- Hendrik