Prototypical Networks in few-shot learning

Prototypical networks are another simple, efficient, and widely used few-shot learning algorithm. Like siamese networks, they try to learn a metric space in which to perform classification. The basic idea of the prototypical network is to create a prototypical representation of each class and classify a query point (a new point) based on the distance between the class prototype and the query point. Now let's understand how prototypical networks work by going through them in detail.

Let's say we have a support set comprising images of lions, elephants, and dogs, as shown below.

So we have three classes: {lion, elephant, dog}. Now we need to create a prototypical representation for each of these three classes. How can we build the prototypes of these three classes? First, we learn the embeddings of each data point using some embedding function.

The embedding function f_{\phi}() can be any function that extracts features. Since our input is an image, we can use a convolutional network as our embedding function to extract features from the input image.
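To make this concrete, here is a minimal sketch of such an embedding function, assuming PyTorch and a four-block convolutional encoder (the framework and the exact architecture are assumptions, not fixed by the text):

```python
import torch
import torch.nn as nn

def conv_block(in_channels, out_channels):
    # Conv -> BatchNorm -> ReLU -> MaxPool: a common building block for f_phi
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2),
    )

class EmbeddingNet(nn.Module):
    """A small convolutional embedding function f_phi for images."""
    def __init__(self, in_channels=3, hidden=64):
        super().__init__()
        self.encoder = nn.Sequential(
            conv_block(in_channels, hidden),
            conv_block(hidden, hidden),
            conv_block(hidden, hidden),
            conv_block(hidden, hidden),
        )

    def forward(self, x):
        # x: (batch, channels, height, width) -> flattened embedding vectors
        return self.encoder(x).flatten(start_dim=1)
```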

Once we learn the embeddings of each data point, we take the mean of the embeddings of the data points in each class to form the class prototype, as shown below. So, a class prototype is basically the mean of the embeddings of the data points in a class.
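Here is a minimal sketch of the prototype computation, assuming the support embeddings and their integer class labels are already available as PyTorch tensors (the function and variable names are illustrative):

```python
import torch

def compute_prototypes(support_embeddings, support_labels, num_classes):
    # support_embeddings: (num_support, embedding_dim)
    # support_labels:     (num_support,) integer class ids in [0, num_classes)
    prototypes = []
    for c in range(num_classes):
        class_embeddings = support_embeddings[support_labels == c]
        prototypes.append(class_embeddings.mean(dim=0))  # mean embedding = prototype
    return torch.stack(prototypes)  # (num_classes, embedding_dim)
```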

Similarly, when a new data point comes in, that is, a query point for which we want to predict the label, we generate its embeddings using the same embedding function that we used to create the class prototypes. In other words, we generate the embeddings for our query point using the same convolutional network.

Once we have the embedding for our query point, we compute the distance between the class prototypes and the query point embedding to find which class the query point belongs to. We can use the Euclidean distance as the distance measure between the class prototypes and the query point embedding, as shown here:
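A minimal sketch of the squared Euclidean distance between every query embedding and every class prototype might look like this (names are illustrative):

```python
def euclidean_distance(queries, prototypes):
    # queries:    (num_queries, embedding_dim)
    # prototypes: (num_classes, embedding_dim)
    # Broadcast to (num_queries, num_classes, embedding_dim) and sum the squared
    # differences, giving squared distances of shape (num_queries, num_classes).
    return ((queries.unsqueeze(1) - prototypes.unsqueeze(0)) ** 2).sum(dim=2)
```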

After finding the distances between the class prototypes and the query point embedding, we apply a softmax over the negative of these distances to get the probabilities. Since we have three classes (lion, elephant, and dog), we will get three probabilities, and the class with the highest probability will be the class of our query point.
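A sketch of this step, building on the euclidean_distance helper above (again, an illustration rather than a fixed implementation):

```python
import torch.nn.functional as F

def class_probabilities(queries, prototypes):
    # Smaller distance should mean higher probability, so we feed the
    # negative distances to the softmax.
    distances = euclidean_distance(queries, prototypes)  # (num_queries, num_classes)
    return F.softmax(-distances, dim=1)                  # (num_queries, num_classes)
```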

Since we want our network to learn from a few data points, that is, since we want to perform few-shot learning, we train our network in the same way. We use episodic training: in each episode, we randomly sample a few data points from each class in our dataset, call them the support set, and train the network using only this support set instead of the whole dataset. Similarly, we randomly sample points from the dataset as query points and try to predict their classes. In this way, our network learns how to learn from a smaller number of data points.
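A sketch of one training episode, reusing the helpers above, might look like the following; the support and query tensors are assumed to be sampled elsewhere for each episode:

```python
import torch.nn.functional as F

def train_episode(model, optimizer, support_images, support_labels,
                  query_images, query_labels, num_classes):
    # One episode: embed the support and query sets, build prototypes from the
    # support embeddings, then classify the query points against them.
    optimizer.zero_grad()

    support_embeddings = model(support_images)
    query_embeddings = model(query_images)

    prototypes = compute_prototypes(support_embeddings, support_labels, num_classes)
    distances = euclidean_distance(query_embeddings, prototypes)

    # Cross-entropy over the softmax of the negative distances.
    loss = F.cross_entropy(-distances, query_labels)
    loss.backward()
    optimizer.step()
    return loss.item()
```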

The overall flow of the prototypical network is shown in the following diagram. As you can see, first we generate the embeddings for all the data points in our support set and build the class prototypes by taking the mean of the embeddings of the data points in each class. We also generate the embeddings for our query point. Then we compute the distance between the class prototypes and the query point embedding, using the Euclidean distance as the distance measure. Finally, we apply a softmax over the negative of these distances to get the probabilities. As you can see in the following figure, since our query point is a lion, the probability for lion is the highest, at 0.9:
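Putting the pieces together, here is a toy end-to-end run of the sketches above on random tensors, just to illustrate the shapes in a 3-way, 5-shot setting (the values are illustrative only):

```python
import torch

model = EmbeddingNet()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

support_images = torch.randn(15, 3, 28, 28)             # 3 classes x 5 shots
support_labels = torch.arange(3).repeat_interleave(5)   # [0,0,0,0,0,1,...,2]
query_images = torch.randn(1, 3, 28, 28)                # a single query point
query_labels = torch.tensor([0])

loss = train_episode(model, optimizer, support_images, support_labels,
                     query_images, query_labels, num_classes=3)
print(f"episode loss: {loss:.4f}")
```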

Prototypical networks are used not only for one-shot/few-shot learning, but also for zero-shot learning. Consider the case where you have no data points for each class, but you do have meta information containing a high-level description of each class. In such cases, we learn the embeddings of each class's meta information to form the class prototypes and then perform classification with those prototypes.
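As a rough sketch, the only change from the few-shot setting is that the prototypes come from an encoder over the class meta information (for example, attribute vectors) rather than from mean image embeddings; the distance and softmax steps stay the same. The AttributeEncoder name and architecture here are hypothetical:

```python
import torch.nn as nn

class AttributeEncoder(nn.Module):
    """Maps a per-class attribute vector (the high-level description) into the
    same embedding space as the images; its outputs serve as class prototypes."""
    def __init__(self, attribute_dim, embedding_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(attribute_dim, 128),
            nn.ReLU(),
            nn.Linear(128, embedding_dim),
        )

    def forward(self, class_attributes):
        # class_attributes: (num_classes, attribute_dim) -> (num_classes, embedding_dim)
        return self.net(class_attributes)
```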

Got any questions? Ask me in the comments section below.
