Building and training a Siamese network with triplet loss using Keras with TensorFlow 2.0
This approach is taken from the popular FaceNet paper.
We have a CNN model called EmbeddingModel.
The basic idea of a Siamese Network is to learn how similar its inputs are. Each training example has an anchor sample, a positive sample, and a negative sample.
It is like teaching someone to identify an image: we show them a similar image and a different image, and point out how to tell the two apart.
“A Siamese Neural Network is a class of neural network architectures that contain two or more identical subnetworks. ‘identical’ here means, they have the same configuration with the same parameters and weights. Parameter updating is mirrored across both sub-networks.”
These networks are generally used in verification systems.
We use three images for each training example:
person1_image1.jpg (Anchor Example, represented below in green)
person1_image2.jpg (Positive Example, in blue)
person2_image1.jpg (Negative Example, in red)
All three images of an example pass through the model, and we get three Embeddings: one for the Anchor Example, one for the Positive Example, and one for the Negative Example.
The three instances of the EmbeddingModel shown above are not different instances. It is the same, shared model instance, i.e. the parameters are shared and are updated for all three paths simultaneously.
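A minimal Keras sketch of this weight sharing follows. It assumes flattened 784-dimensional inputs (such as 28x28 MNIST images) and a 64-dimensional embedding built from Dense layers; the layer sizes and variable names are illustrative assumptions and are not taken from the notebook.

```python
import tensorflow as tf

emb_size = 64      # assumed embedding dimension
input_dim = 784    # assumed: flattened 28x28 images

# One embedding model; it is created once and reused for all three inputs.
embedding_model = tf.keras.models.Sequential([
    tf.keras.layers.Input(shape=(input_dim,)),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(emb_size, activation="sigmoid"),
])

# Three separate inputs: anchor, positive, negative.
input_anchor = tf.keras.layers.Input(shape=(input_dim,))
input_positive = tf.keras.layers.Input(shape=(input_dim,))
input_negative = tf.keras.layers.Input(shape=(input_dim,))

# Calling the same model object three times shares its weights across paths.
embedding_anchor = embedding_model(input_anchor)
embedding_positive = embedding_model(input_positive)
embedding_negative = embedding_model(input_negative)

# Concatenate the three embeddings so a single loss can see all of them.
output = tf.keras.layers.concatenate(
    [embedding_anchor, embedding_positive, embedding_negative], axis=1
)

siamese_network = tf.keras.models.Model(
    [input_anchor, input_positive, input_negative], output
)
```

Because the three paths call one model object, the Siamese network holds only the embedding model's weights, and a single gradient step updates them for all three paths.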
If you need a deeper insight, refer to the articles in the reference section to read more.
Triplet loss is a loss function that tries to pull the Embeddings of the Anchor and Positive Examples closer together, and tries to push the Embeddings of the Anchor and Negative Examples away from each other.
It is explained with equations in the notebook.
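As a rough illustration, the loss can be sketched in NumPy as max(d(A, P) - d(A, N) + alpha, 0). This sketch assumes squared Euclidean distance and a margin alpha of 0.2, as commonly used with FaceNet-style triplet loss; the notebook's exact formulation may differ, and the function name is hypothetical.

```python
import numpy as np

def triplet_loss(anchor, positive, negative, alpha=0.2):
    """Per-example triplet loss (assumed form: squared Euclidean distance, margin alpha)."""
    pos_dist = np.sum(np.square(anchor - positive), axis=-1)  # d(A, P)
    neg_dist = np.sum(np.square(anchor - negative), axis=-1)  # d(A, N)
    # Loss is zero once the negative is farther than the positive by at least alpha.
    return np.maximum(pos_dist - neg_dist + alpha, 0.0)

anchor = np.array([[0.0, 0.0]])
near = np.array([[0.0, 1.0]])   # squared distance 1 from the anchor
far = np.array([[3.0, 0.0]])    # squared distance 9 from the anchor

easy = triplet_loss(anchor, positive=near, negative=far)  # 1 - 9 + 0.2 < 0, so loss is 0
hard = triplet_loss(anchor, positive=far, negative=near)  # 9 - 1 + 0.2 = 8.2
print(easy, hard)
```

The first triplet already satisfies the margin, so it contributes no loss; the second has the positive farther away than the negative, so it is penalized.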
The steps below show how to get a development environment running.
Install the prerequisites listed in the requirements.txt file:
pip install -r requirements.txt
Open Jupyter Notebook on your local host, or use Google Colab.
Follow the steps in the notebook if you want to train your own model, or simply run the notebook in this repo.
Importing the Libraries and Helper Functions.
Importing the Data (MNIST Dataset)
Reshaping and Normalizing the Examples.
Creating a function to plot triplets and generate triplet examples.
Creating an Embedding Model (a simple Neural Network).
Using the Embedding Model to create a Siamese Network.
Implementing the Triplet Loss function and the custom loss function.
Creating a small test set.
Compiling the Siamese Network with Triplet Loss.
Training the Siamese Network.
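The data preparation steps in the list above (reshaping, normalizing, and generating triplets) might look roughly like the following NumPy sketch. It uses random stand-in data with the same shape as MNIST images instead of the real dataset; the helper name `sample_triplet_indices`, the flattening to 784 features, and the batch size are assumptions for illustration, not the notebook's code.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for MNIST: 100 random 28x28 uint8 "images" with labels 0..9.
x_raw = rng.integers(0, 256, size=(100, 28, 28), dtype=np.uint8)
y = np.arange(100) % 10  # 10 examples per class

# Reshape each image to a flat vector and scale pixel values to [0, 1].
x = x_raw.reshape(-1, 784).astype("float32") / 255.0

def sample_triplet_indices(y, batch_size, rng):
    """Pick (anchor, positive, negative) indices; assumes each class has >= 2 examples."""
    a_idx = rng.integers(0, len(y), size=batch_size)
    p_idx = np.empty(batch_size, dtype=int)
    n_idx = np.empty(batch_size, dtype=int)
    for i, a in enumerate(a_idx):
        same = np.flatnonzero(y == y[a])
        same = same[same != a]            # positive: same class, different example
        diff = np.flatnonzero(y != y[a])  # negative: any other class
        p_idx[i] = rng.choice(same)
        n_idx[i] = rng.choice(diff)
    return a_idx, p_idx, n_idx

a_idx, p_idx, n_idx = sample_triplet_indices(y, batch_size=8, rng=rng)
anchors, positives, negatives = x[a_idx], x[p_idx], x[n_idx]
print(anchors.shape, positives.shape, negatives.shape)
```

Sampling indices rather than images keeps the helper independent of the image shape, and the three resulting arrays line up row by row as the three inputs of a Siamese network.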