Gradient agreement as an optimization objective for meta-learning

The gradient agreement algorithm is an interesting, recently introduced enhancement to meta-learning algorithms. In MAML (Model-Agnostic Meta-Learning) and Reptile, we try to find a model parameter that generalizes across several related tasks, so that we can learn a new task quickly from fewer data points. Recall from the previous posts that we randomly initialize the model parameter and then sample a random batch of tasks T_i from the task distribution p(T). For each sampled task T_i, we minimize the loss by calculating gradients and obtain the updated parameters \theta_{i}^{\prime}. This forms our inner loop:

\theta_{i}^{\prime}=\theta-\alpha \nabla_{\theta} L_{T_{i}}\left(f_{\theta}\right)

After calculating the optimal parameters for each of the sampled tasks, we perform meta optimization. We calculate the loss of each adapted model f_{\theta_{i}^{\prime}} on a new set of data points from the same task, compute the gradients of these losses with respect to the initial parameter \theta (the losses are evaluated at the parameters \theta_{i}^{\prime} obtained in the inner loop), and update \theta:

\theta=\theta-\beta \nabla_{\theta} \sum_{T_{i} \sim p(T)} L_{T_{i}}\left(f_{\theta_{i}^{\prime}}\right)

What’s really going on in the previous equation? If you examine it closely, you’ll notice that we’re merely summing (that is, averaging up to a constant factor) the gradients across tasks to update our model parameter \theta, which implies that all tasks contribute equally to the update.

But what’s wrong with this? Let’s say we’ve sampled four tasks: three of them have a gradient update in one direction, but the fourth has a gradient update in a direction that completely differs from the others. This disagreement can seriously affect the update of the model’s initial parameter, since the gradients of all tasks contribute equally to it. As you can see in the following diagram, tasks T_1 to T_3 have gradients in one direction, but task T_4 has a gradient in a completely different direction:

[Figure: gradient directions of tasks T_1 to T_4]

So, what should we do now? How can we tell which tasks have strong gradient agreement and which have strong disagreement? One idea is to associate a weight w_i with each task’s gradient, so that the weight reflects how much that task should count. We rewrite the outer gradient update by multiplying each task’s gradient by its weight:

\theta=\theta-\beta \sum_{i} w_{i} \nabla_{\theta} L_{T_{i}}\left(f_{\theta_{i}^{\prime}}\right)

Okay, how do we calculate these weights? Each weight is proportional to the inner product of the task’s gradient and the average gradient of all tasks in the sampled batch. What does this imply? If the gradient of a task points in the same direction as the average gradient of the batch, its weight increases, so it contributes more to updating the model parameter. Conversely, if the gradient of a task points in a direction that differs greatly from the average gradient, its weight decreases, so it contributes less. We’ll see exactly how these weights are computed in the following section.

We can apply the gradient agreement algorithm not only to MAML but also to Reptile. The Reptile update equation becomes:

\theta=\theta+\alpha \sum_{i} w_{i}\left(\theta_{i}^{\prime}-\theta\right)
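As a rough illustration, the weighted Reptile step fits in a few lines of Python. This is a minimal sketch, assuming parameters are plain lists of floats and the weights w_i are already known; the function name `weighted_reptile_step` is hypothetical, and its `step_size` argument plays the role of \alpha in the equation above.

```python
def weighted_reptile_step(theta, adapted_params, weights, step_size):
    """Return theta + step_size * sum_i w_i * (theta_i' - theta).

    theta: current initial parameters (list of floats)
    adapted_params: one list theta_i' per sampled task, from the inner loop
    weights: gradient agreement weights w_i, one per task
    step_size: outer step size (alpha in the Reptile equation)
    """
    new_theta = list(theta)
    for w, theta_i in zip(weights, adapted_params):
        for d in range(len(theta)):
            new_theta[d] += step_size * w * (theta_i[d] - theta[d])
    return new_theta


# Two hypothetical tasks whose adapted parameters moved along different axes.
print(weighted_reptile_step([0.0, 0.0], [[1.0, 0.0], [0.0, 1.0]], [0.75, 0.25], step_size=0.5))
# [0.375, 0.125]
```

The task with the larger weight pulls the initial parameters further toward its adapted parameters.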

Weight calculation

We’ve seen that, by associating weights with the gradients, we can tell which tasks have strong gradient agreement and which have strong gradient disagreement, and that these weights are proportional to the inner product of a task’s gradient and the average gradient of all tasks in the sampled batch. The weights are calculated as follows:

w_{i}=\frac{\sum_{j \in T}\left(g_{i}^{\top} g_{j}\right)}{\sum_{k \in T}\left|\sum_{j \in T}\left(g_{k}^{\top} g_{j}\right)\right|}
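The formula translates almost directly into code. Below is a minimal sketch in plain Python, assuming each task’s update vector g_i is a list of floats; the function name `gradient_agreement_weights` and the uniform fallback for an all-zero denominator are assumptions for illustration, not part of the formula.

```python
def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def gradient_agreement_weights(grads):
    """w_i = sum_j g_i.g_j / sum_k |sum_j g_k.g_j| for task update vectors g_i."""
    # Numerator for each task: inner products of g_i with every g_j (including itself).
    scores = [sum(dot(g_i, g_j) for g_j in grads) for g_i in grads]
    # Normalization factor: sum of the absolute values of all numerators.
    norm = sum(abs(s) for s in scores)
    if norm == 0.0:
        # Assumption (not part of the formula): uniform weights if every score is zero.
        return [1.0 / len(grads)] * len(grads)
    return [s / norm for s in scores]


# Hypothetical update vectors: T_1 to T_3 agree, T_4 points the opposite way.
grads = [[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [-1.0, 0.0]]
print(gradient_agreement_weights(grads))  # [0.25, 0.25, 0.25, -0.25]
```

Here T_4 receives a negative weight, so its term is subtracted rather than added in the outer update. If T_4’s vector were [0.0, 1.0] instead (orthogonal rather than opposite), the weights would be [0.3, 0.3, 0.3, 0.1].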

Let’s say we’ve sampled a batch of tasks. For each task in the batch, we sample k data points, calculate the loss, perform gradient descent, and find the optimal parameter \theta_{i}^{\prime}. Along with this, we store each task’s gradient update vector g_i, calculated as g_{i}=\theta-\theta_{i}^{\prime}.

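So, the weight of the i^{th} task is the sum of the inner products of g_i with every g_j in the batch (including g_i itself), divided by a normalization factor. This numerator is proportional to the inner product of g_i and the average update vector g_{average}, and the normalization factor is the sum of the absolute values of these numerators over all tasks. As a result, a task whose update vector points against the average receives a negative weight.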
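To see why the numerator tracks agreement with the average, write \bar{g} = \frac{1}{n} \sum_{j \in T} g_j for the average update vector of the n sampled tasks (\bar{g} and n are notation introduced here for the derivation). Pulling g_i out of the sum gives:

\sum_{j \in T}\left(g_{i}^{\top} g_{j}\right)=g_{i}^{\top} \sum_{j \in T} g_{j}=n\, g_{i}^{\top} \bar{g}

Substituting this into both the numerator and the denominator, the factor n cancels:

w_{i}=\frac{g_{i}^{\top} \bar{g}}{\sum_{k \in T}\left|g_{k}^{\top} \bar{g}\right|}, \qquad \sum_{i}\left|w_{i}\right|=1

So the sign of w_i is the sign of g_i^{\top} \bar{g}, and the magnitudes of all weights always add up to 1.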

Now let’s see, step by step, exactly how gradient agreement is used with MAML in supervised learning.

Step 1:

Let us say we have a model f parameterized by \theta and a distribution over tasks p(T). First, we randomly initialize the model parameter \theta.

Step 2:

Now, we sample a batch of tasks from the task distribution, i.e., T_i \sim p(T). Let us say we have sampled two tasks; then T = \{T_1, T_2\}.
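The post does not fix a particular task distribution, so the sketch below assumes a hypothetical one for illustration: each task is a 1-D regression problem y = a x + b, and sampling a task means drawing its slope a and intercept b uniformly from [-2, 2]. The function name `sample_tasks` is likewise hypothetical.

```python
import random


def sample_tasks(rng, num_tasks):
    """Sample tasks T_i ~ p(T) from a hypothetical task family.

    Each task is a 1-D regression problem y = a * x + b, represented by its
    (slope, intercept) pair drawn uniformly from [-2, 2].
    """
    return [(rng.uniform(-2.0, 2.0), rng.uniform(-2.0, 2.0)) for _ in range(num_tasks)]


rng = random.Random(0)                  # fixed seed so the sketch is reproducible
tasks = sample_tasks(rng, num_tasks=2)  # T = {T_1, T_2}
print(tasks)
```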

Step 3: (inner loop)

For each task T_i in T, we sample k data points and prepare our train and test datasets:

D_i^{train} = \{(x_1, y_1), (x_2, y_2), \ldots, (x_k, y_k)\}

D_i^{test} = \{(x_1, y_1), (x_2, y_2), \ldots, (x_k, y_k)\}
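Continuing with the same hypothetical linear task family, a task’s train and test sets can be drawn as two separate samples of k points each. This is a sketch, assuming inputs x are drawn uniformly from [-1, 1] and labels are noise-free; `sample_dataset` is a hypothetical name.

```python
import random


def sample_dataset(rng, task, k):
    """Draw k labelled points (x, y) from a hypothetical linear task y = a*x + b."""
    a, b = task
    xs = [rng.uniform(-1.0, 1.0) for _ in range(k)]
    return [(x, a * x + b) for x in xs]


rng = random.Random(1)
task = (2.0, 1.0)  # hypothetical task T_i: y = 2x + 1
k = 5
d_train = sample_dataset(rng, task, k)  # D_i^train, used in the inner loop
d_test = sample_dataset(rng, task, k)   # D_i^test, fresh points used in the outer loop
```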

Wait. What are the train and test sets? We use the train set in the inner loop to find the optimal parameters \theta_i' and the test set in the outer loop to find the optimal parameter \theta. Both contain x features and y labels, and they are separate samples drawn from the same task.

Now we apply a supervised learning algorithm to D_i^{train}: we calculate the loss, minimize it with gradient descent, and obtain the optimal parameters \theta'_i:

\theta'_i = \theta - \alpha \nabla_{\theta} L_{T_i}(f_{\theta})

Along with this, we also store the gradient update vector as: g_{i}=\theta-\theta_{i}^{\prime}

So for each task, we sample k data points, minimize the loss on the train set D_i^{train}, and get the optimal parameters \theta'_i. As we sampled two tasks, we will have two optimal parameters \{\theta_1', \theta_2'\} and a gradient update vector for each of these two tasks, g = \{(\theta - \theta'_1), (\theta - \theta'_2)\}.
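Step 3 for the two sampled tasks can be sketched as follows, assuming a hypothetical linear model f_\theta(x) = \theta_0 x + \theta_1 trained with mean squared error and a single gradient step per task. The names `mse_grad` and `inner_loop` and the toy data sets are illustrative assumptions. The function returns both the adapted parameters \theta_i' and the update vectors g_i = \theta - \theta_i' needed for the weights.

```python
def mse_grad(theta, data):
    """Gradient of the mean squared error of f(x) = theta[0] * x + theta[1]."""
    k = len(data)
    g = [0.0, 0.0]
    for x, y in data:
        err = theta[0] * x + theta[1] - y
        g[0] += 2.0 * err * x / k
        g[1] += 2.0 * err / k
    return g


def inner_loop(theta, train_sets, alpha):
    """One gradient step per task; returns adapted params and update vectors."""
    adapted, update_vectors = [], []
    for d_train in train_sets:
        grad = mse_grad(theta, d_train)
        theta_i = [t - alpha * g for t, g in zip(theta, grad)]  # theta_i'
        adapted.append(theta_i)
        update_vectors.append([t - ti for t, ti in zip(theta, theta_i)])  # g_i = theta - theta_i'
    return adapted, update_vectors


theta = [0.0, 0.0]
train_sets = [
    [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0)],     # hypothetical T_1: y = 2x + 1
    [(0.0, -1.0), (1.0, -2.0), (2.0, -3.0)],  # hypothetical T_2: y = -x - 1
]
adapted, g = inner_loop(theta, train_sets, alpha=0.1)
print(adapted, g)
```

With these toy data sets the two update vectors point in roughly opposite directions, which is exactly the kind of disagreement the weights in step 4 are meant to detect.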

Step 4: (outer loop)

Now, before performing meta optimization, we’ll calculate the weights as follows:

w_{i}=\frac{\sum_{j \in T}\left(g_{i}^{\top} g_{j}\right)}{\sum_{k \in T}\left|\sum_{j \in T}\left(g_{k}^{\top} g_{j}\right)\right|}

After calculating the weights, we perform meta optimization by associating the weights with the gradients. We minimize the loss on the test set D_i^{test} by calculating the gradients with respect to \theta, evaluated at the parameters \theta_i' obtained in the previous step, and multiplying each task’s gradient by its weight.

If our meta learning algorithm is MAML, then the update equation is as follows:

\theta=\theta-\beta \sum_{i} w_{i} \nabla_{\theta} L_{T_{i}}\left(f_{\theta_{i}^{\prime}}\right)
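In this update the loss is evaluated at \theta_i', but the gradient is taken with respect to \theta, so it flows back through the inner-loop step \theta_i' = \theta - \alpha \nabla_{\theta} L_{T_i}(f_{\theta}). Applying the chain rule (the Hessian is symmetric, so no transpose is needed) gives the following, where the Hessian term is computed on D_i^{train} and the last gradient on D_i^{test}:

\nabla_{\theta} L_{T_{i}}\left(f_{\theta_{i}^{\prime}}\right)=\left(I-\alpha \nabla_{\theta}^{2} L_{T_{i}}\left(f_{\theta}\right)\right) \nabla_{\theta_{i}^{\prime}} L_{T_{i}}\left(f_{\theta_{i}^{\prime}}\right)

Replacing the bracket by the identity I drops the second-order term and gives the well-known first-order approximation of MAML. Either way, each task’s term is multiplied by its weight w_i before summing.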

If our meta learning algorithm is Reptile, then the update equation is as follows:

\theta=\theta+\alpha \sum_{i} w_{i}\left(\theta_{i}^{\prime}-\theta\right)

For comparison, the standard MAML update, in which every task contributes equally, is:

\theta = \theta - \beta \nabla_{\theta} \sum_{T_i \sim p(T)} L_{T_i} (f_{\theta'_i})

Step 5:

We repeat steps 2 to 4 for n iterations.
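Putting steps 1 to 5 together, here is a minimal end-to-end sketch of gradient agreement with MAML. It is not the authors’ implementation: it assumes a hypothetical task family of 1-D linear regression problems y = a x + b, a linear model f_\theta(x) = \theta_0 x + \theta_1 with mean squared error, and plain Python lists instead of a deep learning framework. Because this model’s loss is quadratic, its Hessian is constant, so the MAML meta-gradient can be computed exactly with the chain rule instead of automatic differentiation. All function names and hyperparameter values are illustrative.

```python
import random


def mse_grad(theta, data):
    """Gradient of the mean squared error of f(x) = theta[0] * x + theta[1]."""
    k = len(data)
    g = [0.0, 0.0]
    for x, y in data:
        err = theta[0] * x + theta[1] - y
        g[0] += 2.0 * err * x / k
        g[1] += 2.0 * err / k
    return g


def mse_hessian(data):
    """Hessian of that loss; for a linear model it does not depend on theta."""
    k = len(data)
    h = [[0.0, 0.0], [0.0, 0.0]]
    for x, _ in data:
        h[0][0] += 2.0 * x * x / k
        h[0][1] += 2.0 * x / k
        h[1][1] += 2.0 / k
    h[1][0] = h[0][1]
    return h


def agreement_weights(grads):
    """w_i = sum_j g_i.g_j / sum_k |sum_j g_k.g_j|."""
    scores = [sum(sum(a * b for a, b in zip(g_i, g_j)) for g_j in grads) for g_i in grads]
    norm = sum(abs(s) for s in scores)
    if norm == 0.0:  # assumption: uniform weights when every score is zero
        return [1.0 / len(grads)] * len(grads)
    return [s / norm for s in scores]


def sample_data(rng, task, k):
    """k points from the hypothetical task y = a*x + b."""
    a, b = task
    xs = [rng.uniform(-1.0, 1.0) for _ in range(k)]
    return [(x, a * x + b) for x in xs]


def train(iterations=1000, num_tasks=2, k=10, alpha=0.1, beta=0.01, seed=0):
    rng = random.Random(seed)
    theta = [rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)]  # Step 1: random init
    for _ in range(iterations):  # Step 5: repeat steps 2 to 4
        # Step 2: sample a batch of tasks T_i ~ p(T)
        tasks = [(rng.uniform(-2.0, 2.0), rng.uniform(-2.0, 2.0)) for _ in range(num_tasks)]
        update_vectors, meta_grads = [], []
        for task in tasks:  # Step 3: inner loop
            d_train = sample_data(rng, task, k)
            d_test = sample_data(rng, task, k)
            grad = mse_grad(theta, d_train)
            theta_i = [t - alpha * g for t, g in zip(theta, grad)]
            update_vectors.append([t - ti for t, ti in zip(theta, theta_i)])  # g_i
            # Meta-gradient: (I - alpha * Hessian_train) times test-loss gradient at theta_i'
            h = mse_hessian(d_train)
            gt = mse_grad(theta_i, d_test)
            meta_grads.append([
                gt[0] - alpha * (h[0][0] * gt[0] + h[0][1] * gt[1]),
                gt[1] - alpha * (h[1][0] * gt[0] + h[1][1] * gt[1]),
            ])
        # Step 4: outer loop, weighting each task's meta-gradient by w_i
        weights = agreement_weights(update_vectors)
        step = [sum(w * mg[d] for w, mg in zip(weights, meta_grads)) for d in range(2)]
        theta = [t - beta * s for t, s in zip(theta, step)]
    return theta


theta = train()
print(theta)
```

Since the absolute weights sum to 1, each outer step is at most \beta times the largest task meta-gradient, no matter how many tasks are sampled per batch.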

