Deep Meta-Learning: Learning to Learn in the Concept Space

Now we'll see how to learn to learn in the concept space using deep meta learning. First, how do we perform meta learning? We sample a batch of related tasks and some k data points from each task, and train our meta learner. Instead of training with the vanilla meta learning technique alone, we can combine the power of deep learning with meta learning: when we sample a batch of tasks and some k data points from each task, we learn representations of those k data points using deep neural networks and then perform meta learning on the representations.

Our framework consists of three components:

  • Concept generator
  • Concept discriminator
  • Meta learner

The role of the concept generator is to extract a feature representation of each data point in our dataset, capturing its high-level concept. The role of the concept discriminator is to recognize and classify the concepts produced by the concept generator, while the meta learner learns on those concepts. All three components (the concept generator, the concept discriminator, and the meta learner) are trained jointly. So, we improve vanilla meta learning by integrating it with deep learning. Since the concept generator evolves with new incoming data, we can view this framework as a lifelong learning system.

But what’s really going on here? Look at the following diagram; as you can see, we sample a set of tasks and feed them to our concept generator, which learns the concepts—that is, embeddings—and then feeds those concepts to the meta learner, which learns on these concepts and sends the loss back to the concept generator. Meanwhile, we also feed some external dataset to the concept generator, which learns the concepts for those inputs and sends those concepts to the concept discriminator. The concept discriminator predicts the labels for those concepts, calculates the loss, and sends the loss back to the concept generator. By doing so, we enhance our concept generator’s ability to generalize concepts:

[Figure: deep meta learning in the concept space. The concept generator feeds the concepts of sampled tasks to the meta learner and the concepts of the external dataset to the concept discriminator; both losses flow back to the concept generator.]

But still, why are we doing this? Instead of performing meta learning on a raw dataset, we perform meta learning in the concept space. How do we learn these concepts? These concepts are generated by the concept generator by learning the embeddings of the input. So, we train the concept generator and meta learner on various related tasks; along with this, we improve the concept generator through the concept discriminator by feeding an external dataset to the concept generator so that it can learn the concepts better. This joint training process allows our concept generator to learn various concepts and perform better on related tasks; we feed the external dataset only to enhance the performance of our concept generator, which learns continuously when we feed a new set of inputs. So, it’s a lifelong learning system.

Key components

Now let’s see each of our components in detail.

Concept generator

As we know, the concept generator is used to extract features. We can use a deep neural network, parameterized by \theta_G , to generate the concepts. For example, the concept generator can be a CNN if our input is an image.
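To make this concrete, here is a minimal PyTorch sketch of a CNN concept generator. The input shape (1 x 28 x 28 grayscale images), the channel sizes, and the 64-dimensional concept space are illustrative assumptions, not fixed by the framework:

```python
import torch
import torch.nn as nn

class ConceptGenerator(nn.Module):
    """A small CNN that maps raw images to concept embeddings.

    All sizes are illustrative: 1x28x28 grayscale inputs mapped to
    64-dimensional concepts.
    """
    def __init__(self, concept_dim=64):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),                              # 28x28 -> 14x14
            nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),                              # 14x14 -> 7x7
        )
        self.fc = nn.Linear(32 * 7 * 7, concept_dim)

    def forward(self, x):
        h = self.features(x)
        return self.fc(h.flatten(start_dim=1))            # (batch, concept_dim)
```

Any feature extractor that maps raw inputs to fixed-size embeddings would serve equally well here.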

Concept discriminator

The concept discriminator is essentially a classifier: it predicts the labels for the concepts generated by the concept generator. It can be any supervised learning algorithm, such as an SVM or a decision tree, parameterized by \theta_D .
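Continuing the sketch above, the discriminator can be as simple as a linear classifier over the concepts; the assumed 100 external-dataset classes are a made-up value:

```python
import torch.nn as nn

class ConceptDiscriminator(nn.Module):
    """A linear classifier that predicts external-dataset labels from
    concept embeddings (a stand-in for any supervised model)."""
    def __init__(self, concept_dim=64, num_classes=100):
        super().__init__()
        self.classifier = nn.Linear(concept_dim, num_classes)

    def forward(self, concepts):
        return self.classifier(concepts)  # raw logits
```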

Meta learner

Our meta learner can be any meta learning algorithm, say MAML (model-agnostic meta learning), Meta-SGD, or Reptile, parameterized by \theta_M .
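As one concrete possibility, here is a sketch of a Reptile-style outer update, which simply moves the meta parameters a fraction epsilon toward the task-adapted parameters; the epsilon value and the assumption that task adaptation happens elsewhere are both illustrative:

```python
import torch

def reptile_meta_update(meta_params, task_params, epsilon=0.1):
    """Reptile outer update: nudge each meta parameter toward the
    corresponding parameter obtained after adapting to one task."""
    with torch.no_grad():
        for p_meta, p_task in zip(meta_params, task_params):
            p_meta.add_(epsilon * (p_task - p_meta))
```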

Loss function

We use two sets of loss functions here:

  • Concept discrimination loss
  • Meta learning loss

Concept discrimination loss

We sample some data points (x, y) from our dataset D and feed them to the concept generator, which learns the concepts and sends them to the concept discriminator; the discriminator then tries to predict the classes of those concepts. So, the concept discrimination loss tells us how good our concept discriminator is at predicting the classes, and it can be represented as follows:

L_{(x, y)}\left(\theta_{D}, \theta_{G}\right)

Our loss function can be any loss function suited to the task; for example, it can be the cross-entropy loss if we're performing a classification task.
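Putting the two sketch components together, computing the discrimination loss for a classification task is one forward pass plus a cross-entropy call; the batch size and label range below are hypothetical:

```python
import torch
import torch.nn.functional as F

# Hypothetical external batch: 8 grayscale images, labels in [0, 100).
x_ext = torch.randn(8, 1, 28, 28)
y_ext = torch.randint(0, 100, (8,))

generator = ConceptGenerator()          # sketches from the sections above
discriminator = ConceptDiscriminator()

concepts = generator(x_ext)             # learn the concepts
logits = discriminator(concepts)        # classify the concepts
disc_loss = F.cross_entropy(logits, y_ext)  # L_{(x,y)}(theta_D, theta_G)
```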

Meta learning loss

We sample a batch of tasks from the task distribution, learn their concepts via the concept generator, perform meta learning on those concepts, and then compute the meta learning loss:

L_{T}\left(\theta_{M}, \theta_{G}\right)

Our meta learning loss varies depending on which meta learner we use, such as MAML or Reptile.
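With MAML, for instance, the meta learning loss is the query-set loss evaluated after a gradient-based inner adaptation on the concepts. The sketch below assumes the meta learner is a linear head (w, b) over concept embeddings and uses an illustrative inner learning rate alpha:

```python
import torch
import torch.nn.functional as F

def maml_meta_loss(w, b, support_c, support_y, query_c, query_y, alpha=0.01):
    """MAML-style meta loss on concepts: adapt a linear head (w, b) on
    the support set, then evaluate the adapted head on the query set."""
    inner_loss = F.cross_entropy(support_c @ w + b, support_y)
    g_w, g_b = torch.autograd.grad(inner_loss, (w, b), create_graph=True)
    # create_graph=True keeps the inner update differentiable, so the
    # outer loss L_T can backpropagate into theta_M and theta_G.
    return F.cross_entropy(query_c @ (w - alpha * g_w) + (b - alpha * g_b),
                           query_y)
```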

Our final loss function is a combination of both the concept discrimination loss and the meta learning loss:

loss = L_{T}\left(\theta_{M}, \theta_{G}\right)+\lambda L_{(x, y)}\left(\theta_{D}, \theta_{G}\right)

In the previous equation, \lambda is a hyperparameter that balances the meta learning and concept discrimination losses. So, our objective becomes finding the optimal parameters that minimize this loss:

\min_{\theta_{D}, \theta_{M}, \theta_{G}} \mathbb{E}_{T \sim p(T),\,(x, y) \sim D}\; J\left[L_{T}\left(\theta_{M}, \theta_{G}\right), L_{(x, y)}\left(\theta_{D}, \theta_{G}\right)\right]

We minimize the loss by calculating the gradients and updating our model parameters:

\left(\theta_{D}, \theta_{M}, \theta_{G}\right)=\left(\theta_{D}, \theta_{M}, \theta_{G}\right)-\beta \nabla J\left[L_{T}\left(\theta_{M}, \theta_{G}\right), L_{(x, y)}\left(\theta_{D}, \theta_{G}\right)\right]
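In the running sketch, one joint SGD step over all three parameter sets looks like this; the learning rate beta = 0.001 and lambda = 0.5 are hypothetical values, and meta_loss and disc_loss come from the snippets above:

```python
import torch

# One optimizer over (theta_D, theta_M, theta_G); w, b are the meta
# learner's linear head from the MAML sketch above.
params = (list(generator.parameters())
          + list(discriminator.parameters())
          + [w, b])
optimizer = torch.optim.SGD(params, lr=0.001)   # beta

loss = meta_loss + 0.5 * disc_loss   # J with lambda = 0.5 (hypothetical)
optimizer.zero_grad()
loss.backward()                      # gradients w.r.t. all three sets
optimizer.step()
```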


Now we’ll see how our algorithm works step by step:

Step 1:

Let's say we have a distribution over tasks p(T) and a model composed of our three components, the concept generator, concept discriminator, and meta learner, parameterized by \theta_G , \theta_D , and \theta_M respectively. First, we randomly initialize these parameters.

Step 2:

We sample a batch of tasks from the task distribution, learn their concepts via the concept generator, perform meta learning on those concepts, and then compute the meta learning loss:

L_{T}\left(\theta_{M}, \theta_{G}\right)

Step 3: (inner loop)

We sample some data points (x, y) from our external dataset D, feed them to the concept generator, which learns their concepts, feed those concepts to the concept discriminator, which classifies them, and then compute the concept discrimination loss:

L_{(x, y)}\left(\theta_{D}, \theta_{G}\right)

Step 4: (outer loop)

We combine both of these losses, minimize the combined loss using SGD, and get the updated model parameters:

\left(\theta_{D}, \theta_{M}, \theta_{G}\right)=\left(\theta_{D}, \theta_{M}, \theta_{G}\right)-\beta \nabla J\left[L_{T}\left(\theta_{M}, \theta_{G}\right), L_{(x, y)}\left(\theta_{D}, \theta_{G}\right)\right]

Step 5:

We repeat steps 2 to 4 for n iterations.
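To tie the steps together, here is a minimal end-to-end sketch of the whole loop. It reuses the ConceptGenerator and ConceptDiscriminator sketches from earlier; task sampling is faked with random tensors, and every size, learning rate, and the lambda value are illustrative assumptions, so treat it as a skeleton rather than a faithful implementation:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Step 1: initialize the three components (theta_G, theta_D, theta_M).
generator = ConceptGenerator(concept_dim=64)
discriminator = ConceptDiscriminator(concept_dim=64, num_classes=100)
w = torch.randn(64, 5, requires_grad=True)   # linear head for 5-way tasks
b = torch.zeros(5, requires_grad=True)

params = (list(generator.parameters())
          + list(discriminator.parameters()) + [w, b])
optimizer = torch.optim.SGD(params, lr=0.001)

for iteration in range(100):           # Step 5: repeat for n iterations
    # Step 2: sample a task (faked here as random 5-way data), learn
    # its concepts, and compute the meta learning loss MAML-style.
    support_x = torch.randn(5, 1, 28, 28)
    query_x = torch.randn(10, 1, 28, 28)
    support_y = torch.arange(5)
    query_y = torch.randint(0, 5, (10,))
    support_c, query_c = generator(support_x), generator(query_x)
    inner = F.cross_entropy(support_c @ w + b, support_y)
    g_w, g_b = torch.autograd.grad(inner, (w, b), create_graph=True)
    meta_loss = F.cross_entropy(
        query_c @ (w - 0.01 * g_w) + (b - 0.01 * g_b), query_y)

    # Step 3: sample external data and compute the discrimination loss.
    x_ext = torch.randn(8, 1, 28, 28)
    y_ext = torch.randint(0, 100, (8,))
    disc_loss = F.cross_entropy(discriminator(generator(x_ext)), y_ext)

    # Step 4: combine the losses and jointly update all parameters.
    loss = meta_loss + 0.5 * disc_loss   # lambda = 0.5 (hypothetical)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```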

