Model-Agnostic Meta-Learning (MAML) is one of the recently introduced and most widely used meta-learning algorithms, and it has been a major breakthrough in meta-learning research. Learning to learn is the key focus of meta-learning: we learn from a number of related tasks, each containing only a small number of data points, and the meta-learner produces a quick learner that can generalize well to a new related task even with fewer training samples.
The basic idea of MAML is to find better initial parameters so that, starting from them, the model can learn new tasks quickly with fewer gradient steps.
So, what do we mean by that?
Let us say we are performing a classification task using a neural network. How do we train the network? We start by initializing the weights randomly and then train the network by minimizing the loss. How do we minimize the loss? We minimize it using gradient descent.
Okay, but how do we use gradient descent to minimize the loss? We use gradient descent to find the optimal weights that give us the minimal loss, taking multiple gradient steps until we reach convergence. In MAML, we try to find good initial weights by learning from a distribution of similar tasks. So, for a new task, we don't have to start with randomly initialized weights; instead, we can start from these optimal initial weights, which take fewer gradient steps to reach convergence and need only a few data points for training.
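To make "fewer gradient steps" concrete, here is a minimal Python sketch. It assumes a hypothetical one-dimensional toy loss L(\theta) = (\theta - 3)^2 in place of a real network's loss; the helper `gradient_descent`, the learning rate, and the convergence tolerance are illustrative choices, not part of MAML. Starting close to the minimum reaches the tolerance in fewer steps than starting far from it.

```python
def gradient_descent(theta, grad_fn, lr=0.1, tol=1e-6, max_steps=10_000):
    """Run plain gradient descent; return (final theta, number of steps taken)."""
    for step in range(max_steps):
        g = grad_fn(theta)
        if abs(g) < tol:  # gradient is (almost) zero: treat as converged
            return theta, step
        theta = theta - lr * g  # one gradient step
    return theta, max_steps

# Hypothetical toy loss L(theta) = (theta - 3)^2, minimized at theta = 3
grad = lambda theta: 2 * (theta - 3.0)

far_theta, far_steps = gradient_descent(-10.0, grad)   # start far from the minimum
near_theta, near_steps = gradient_descent(2.5, grad)   # start close to the minimum
print("far start:", far_steps, "steps; near start:", near_steps, "steps")
```

Both runs end at the same minimum; only the number of steps differs.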
Let us understand MAML in simple terms. Say we have three related tasks T_1, T_2, and T_3. First, we randomly initialize our model parameter \theta. We train our network on task T_1 by minimizing the loss L with gradient descent, which gives us the optimal parameter \theta_1' for that task.
Similarly, for tasks T_2 and T_3, we start off with a randomly initialized model parameter \theta and minimize the loss by finding the right set of parameters with gradient descent. Let us say \theta_2' and \theta_3' are the optimal parameters for tasks T_2 and T_3, respectively. As you can see in the figure below, we start each task with a randomly initialized parameter \theta and minimize the loss by finding the optimal parameters \theta_1', \theta_2', and \theta_3' for tasks T_1, T_2, and T_3, respectively.
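The per-task training described above can be sketched in Python as a toy under stated assumptions: each task T_i is a hypothetical quadratic loss ||\theta - c_i||^2 whose minimum c_i plays the role of \theta_i', and the task centres, the random seed, and the helpers `make_task` and `train` are invented for illustration. Every task restarts from the same randomly initialized \theta.

```python
import random

def make_task(center):
    """Hypothetical toy task: loss ||theta - center||^2, returned as its gradient function."""
    return lambda theta: [2 * (t - c) for t, c in zip(theta, center)]

def train(theta, grad_fn, lr=0.1, tol=1e-6, max_steps=10_000):
    """Plain gradient descent; returns (theta', number of steps taken)."""
    for step in range(max_steps):
        g = grad_fn(theta)
        if max(abs(x) for x in g) < tol:  # every gradient component is ~0: converged
            return theta, step
        theta = [t - lr * x for t, x in zip(theta, g)]  # builds a new list, input is not mutated
    return theta, max_steps

random.seed(0)
theta = [random.uniform(-10, 10) for _ in range(2)]  # randomly initialized theta

# Three related toy tasks whose minima lie close together (assumed values)
tasks = {
    "T_1": make_task([1.0, 2.0]),
    "T_2": make_task([2.0, 1.5]),
    "T_3": make_task([1.5, 2.5]),
}

optimal = {}
for name, grad_fn in tasks.items():
    # Each task starts again from the same random theta
    theta_prime, steps = train(theta, grad_fn)
    optimal[name] = theta_prime
    print(name, [round(x, 3) for x in theta_prime], "after", steps, "steps")
```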
But instead of initializing \theta at a random position, i.e., with random values, if we initialize \theta at a position that is common to all three tasks, we don't need to take as many gradient steps, and training takes less time. This is exactly what MAML tries to do: it tries to find an optimal parameter \theta that is common to many related tasks, so that by initializing \theta at this optimal position we can train on a new task relatively quickly, with few data points and without taking many gradient steps.
As shown in the figure below, we shift \theta to a position that is common to all the different optimal \theta' values.
So, for a new related task, say T_4, we don't have to start with a randomly initialized parameter \theta. Instead, we can start with the optimal \theta value, so it takes fewer gradient steps to reach convergence.
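The toy below illustrates why a shared starting point helps a new task T_4. It is only an illustration of the intuition: as a crude, hypothetical stand-in for the initialization MAML learns, it simply averages the task optima \theta_1', \theta_2', and \theta_3'. This averaging is not MAML's actual update rule, which the post leaves for later; the quadratic tasks, their centres, the seed, and the helpers `make_task` and `train` are all assumptions.

```python
import random

def make_task(center):
    """Hypothetical toy task: loss ||theta - center||^2, returned as its gradient function."""
    return lambda theta: [2 * (t - c) for t, c in zip(theta, center)]

def train(theta, grad_fn, lr=0.1, tol=1e-6, max_steps=10_000):
    """Plain gradient descent; returns (theta', number of steps taken)."""
    for step in range(max_steps):
        g = grad_fn(theta)
        if max(abs(x) for x in g) < tol:
            return theta, step
        theta = [t - lr * x for t, x in zip(theta, g)]
    return theta, max_steps

random.seed(0)
random_theta = [random.uniform(-10, 10) for _ in range(2)]

# theta_1', theta_2', theta_3' for the three related toy tasks
centers = [[1.0, 2.0], [2.0, 1.5], [1.5, 2.5]]
optima = [train(random_theta, make_task(c))[0] for c in centers]

# Crude stand-in for a meta-learned initialization: the average of the task optima.
# This is NOT the MAML update rule; it only mimics "a position common to all tasks".
shared_theta = [sum(coords) / len(optima) for coords in zip(*optima)]

# A new related task T_4 whose minimum lies near the others
t4 = make_task([1.8, 2.2])
_, steps_random = train(random_theta, t4)
_, steps_shared = train(shared_theta, t4)
print("from random theta:", steps_random, "steps; from shared theta:", steps_shared, "steps")
```

On these quadratics the step count grows only with the logarithm of the starting distance, so the gap here is modest; with real, non-convex network losses the benefit depends on the model and the tasks.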
So, in MAML, we try to find this optimal \theta value common to related tasks, which helps us learn from fewer data points and reduces our training time. MAML is model agnostic, meaning that we can apply it to any model that is trainable with gradient descent. But how exactly does MAML work? How do we shift the model parameters to an optimal position? We will explore this in detail in the next blog post.
Got any questions? Feel free to ask me in the comments section below.