Generalisation in Deep Learning
Generalisation in deep learning asks the question of “how well will my AI/deep learning model do in the wild?”, or, “how much did my model actually learn about the real world, rather than just about the data it was trained on?”
In real-life scenarios, what counts as “good” generalisation can be very contextual and hard to measure. In my view, this is for a few reasons:
- We may not have any way to measure performance once the AI is out in the world (deployed).
- Real-world data may look very different to the training data we have available.
- The “correct answer” may be subjective for certain types of task (e.g. creative writing, which contemporary AIs can at least attempt).
The term generalisation is itself quite a general term, so in practice it can be useful to categorise it into a few different kinds of generalisation:
- In-distribution generalisation: possibly the simplest kind of generalisation, referring to when the target data distribution is similar to the training data.
- Out-of-distribution generalisation: a catch-all term referring to instances where the target data is somehow different from the training data.
- Domain generalisation: often used interchangeably with out-of-distribution generalisation; the setting where the target data distribution is different from the training data distribution, often resulting in a drop in accuracy.
- Compositional generalisation: where familiar concepts are combined in unfamiliar ways in the target distribution.
- Few-shot generalisation: where we have only a handful (generally fewer than 20) examples from the target data distribution.
- Zero-shot generalisation: an even harder version of few-shot generalisation, in which we have zero examples from the target data distribution. In this scenario we often depend on having some description or representation of the target data distribution.
There are certainly more types of generalisation than I’ve mentioned here, but this list is meant to give an outline of why generalisation is a large and complex topic. Splitting it up into different types like this can help machine learning/deep learning practitioners think about and solve one problem at a time.
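To make a few of these categories concrete, here is a minimal (and entirely hypothetical) Python sketch of how one might carve a toy dataset of (shape, colour) pairs into in-distribution, compositional, and out-of-distribution evaluation splits:

```python
import itertools
import random

random.seed(0)

# Hypothetical toy data: every example is a (shape, colour) pair.
shapes = ["circle", "square", "triangle"]
colours = ["red", "green", "blue"]
all_pairs = list(itertools.product(shapes, colours))

# In-distribution split: train and test are drawn from the same set of combinations.
random.shuffle(all_pairs)
in_dist_train, in_dist_test = all_pairs[:6], all_pairs[6:]

# Compositional split: every shape and every colour appears during training,
# but some *combinations* (e.g. "blue triangle") are held out for testing.
held_out = {("triangle", "blue"), ("square", "red")}
comp_train = [p for p in all_pairs if p not in held_out]
comp_test = list(held_out)

# Out-of-distribution split: the test set contains a concept never seen in training at all.
ood_train = [p for p in all_pairs if p[0] != "triangle"]
ood_test = [p for p in all_pairs if p[0] == "triangle"]

print("compositional test set:", comp_test)
```

The point is simply that the same underlying data can be split in different ways to probe different kinds of generalisation.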
Different types of generalisation are connected
Although I just said that generalisation can be split up into (at least) the categories I mentioned above, the reality is that all types of generalisation are in some way connected. Let me explain.
For example, we can draw a clear connection between zero-shot generalisation and domain generalisation: in both cases, we have little to no information about the target domain/data distribution, except that it doesn’t appear in the training data.
Secondly, compositional generalisation can be seen as a special kind of domain generalisation, in which the target domain is constructed from unseen combinations of concepts that already occur in the training domain. Similarly, zero-shot generalisation and few-shot generalisation are often studied at the same time in the context of general-purpose language models, since we can easily tack on a task description in addition to some training examples, then think of the zero-shot case as an extreme of the few-shot learning paradigm.
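As a toy illustration of that last point, here is a hypothetical helper that builds a k-shot prompt for a language model; setting k = 0 leaves only the task description, which is exactly the zero-shot case:

```python
# Hypothetical helper: build a k-shot prompt for a general-purpose language model.
# With k = 0 this degenerates to a zero-shot prompt containing only the task description,
# which is one way to see zero-shot learning as the extreme end of the few-shot paradigm.
def build_prompt(task_description, examples, query, k=3):
    lines = [task_description]
    for text, label in examples[:k]:          # the k in-context examples ("shots")
        lines.append(f"Input: {text}\nOutput: {label}")
    lines.append(f"Input: {query}\nOutput:")  # the model completes this line
    return "\n\n".join(lines)

examples = [("the film was wonderful", "positive"), ("a dull, lifeless plot", "negative")]
few_shot = build_prompt("Classify the sentiment of each review.", examples, "an instant classic", k=2)
zero_shot = build_prompt("Classify the sentiment of each review.", examples, "an instant classic", k=0)
```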
Lastly, we have the fact that, fundamentally, all generalisation can be considered in-distribution generalisation from some perspective: we can view all possible types of data as belonging to one big happy distribution of distributions. This is why I think generalisation can be a difficult concept to define and grapple with.
Inductive biases
The concept of generalisation in deep learning comes hand-in-hand with the concept of inductive biases. Inductive biases are, in a way, the assumptions that are somehow encoded in a model’s architecture (or pre-training data) that determine the ways in which it will generalise to unseen data. Often these are useful, deliberate design choices that are intended to make a model more efficient in terms of data or model size. For example:
- The convolutional neural network (CNN) architecture encodes the inductive bias that “if I see a particular pattern/texture, I should extract a similar representation regardless of its relative position in the image”.
- Recurrent neural network (RNN) architectures encode the inductive bias that “all I need at each time step is the next word/token and a vector that represents everything that came before”.
- Transformers have the inductive bias that “information should flow between tokens (words) via the attention mechanism”, which turns out to be an extremely successful inductive bias in the context of today’s AI language models.
In each case, inductive biases are both a blessing and a curse for deep learning models: on the one hand, applying “just the right kind” of inductive bias for the task at hand will greatly improve generalisation performance. On the other hand, inductive biases are still biases: they fundamentally limit the way an AI model “sees” the world, which can harm out-of-distribution generalisation (and, in catastrophic cases, can actually correspond to real-world, harmful kinds of bias).
In general, the larger a deep learning architecture is in terms of parameters, the weaker its inductive biases tend to be; that is, larger models make fewer assumptions about the world. To make the most of this, we have to compensate with a large amount of training data. This is a large part of how contemporary, commercially available AI models (such as chatbots) work in practice.
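To show what an inductive bias looks like in code, here is a minimal PyTorch sketch (assuming you have torch installed) checking that a convolutional layer responds to a shifted image with a correspondingly shifted feature map:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A single conv layer: the same 3x3 filter is applied at every spatial position.
# Circular padding is used here only so that the equivariance check below is exact.
conv = nn.Conv2d(1, 4, kernel_size=3, padding=1, padding_mode="circular", bias=False)

x = torch.randn(1, 1, 16, 16)                 # a random "image"
x_shifted = torch.roll(x, shifts=5, dims=-1)  # the same image, shifted 5 pixels

# Shifting the input just shifts the output feature maps: the CNN's weight sharing
# means a pattern is detected the same way regardless of where it appears.
same = torch.allclose(torch.roll(conv(x), shifts=5, dims=-1), conv(x_shifted), atol=1e-5)
print(same)  # True
```

A fully-connected layer applied to the flattened image would give no such guarantee; that difference is the inductive bias.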
One model to rule them all
I talked in my previous post about scaling laws and how we can define “general intelligence” by an AI model’s ability to adapt (generalise) to new situations. Hypothetically, if we had, say, a multi-modal language model that had mastered all the various kinds of generalisation, we’d likely find it capable of adapting to practically any type of task we threw at it, given sufficient examples/details of the task.
Actually achieving this in practice may not be as simple as throwing more and more of the data we already have (e.g. the internet) at larger and larger amounts of compute. Even if the fundamental methodology stays the same (train on next-token prediction with cross-entropy loss), I think at the very least, AI researchers will have to start (or have already started) to recognise the importance of the composition and diversity of one’s training dataset.
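For reference, the next-token prediction objective mentioned above boils down to something like the following sketch (using a stand-in model; a real language model would condition on the whole preceding context rather than on each token independently):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy setup: a batch of random token ids standing in for real text.
vocab_size, batch, seq_len = 100, 2, 8
tokens = torch.randint(0, vocab_size, (batch, seq_len))

# Stand-in for a real language model: an embedding followed by a linear head,
# mapping token ids (batch, seq_len) to logits over the vocabulary.
model = torch.nn.Sequential(torch.nn.Embedding(vocab_size, 32), torch.nn.Linear(32, vocab_size))

logits = model(tokens)                            # (batch, seq_len, vocab_size)
inputs, targets = logits[:, :-1], tokens[:, 1:]   # predict token t+1 from position t
loss = F.cross_entropy(inputs.reshape(-1, vocab_size), targets.reshape(-1))
print(loss.item())
```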
On top of ever-improving and more generally capable models, there is also a growing body of research on how to get even more performance out of existing models (e.g. by forcing them to think step-by-step, or even constructing a tree of possible “thoughts”). These approaches are very interesting and I’m excited to see where this direction leads in the near future.
Conclusion
I’ve given an overview of the problem of generalisation in deep learning, along with some of my thoughts about how the different types of generalisation are connected. I also gave some of my thoughts on what generalisation means in the extreme case of the ever-approaching artificial general intelligence.
I hope you enjoyed reading this post as much as I enjoyed researching and writing it!
Take care!
Jamie