Train your first Neural Network for Image Similarity Search
The traditional approach to supervised learning is to pick a model and train it on data where each input X is labelled with a target y, and beginner tutorials on deep learning usually analyse the MNIST dataset with an MLP. While this works perfectly well for one-to-one labelled datasets (supervised learning), it is far less effective for weakly supervised learning.
What is weakly supervised learning?
Weak supervision is a branch of machine learning in which noisy, limited, or imprecise sources, or supervision given only at the tuple level (typically pairs, triplets, or quadruplets of data points), are used to provide the signal for labelling large amounts of training data in a supervised learning setting. This approach alleviates the burden of obtaining hand-labelled data sets, which can be costly or impractical.
There are three typical types of weak supervision: incomplete supervision, when only a subset of the training data is labelled; inexact supervision, when the training data come with labels that are not as exact as desired; and inaccurate supervision, when some of the labels in the training data are wrong.
In this tutorial, instead of classifying CIFAR-10 images directly into their most probable labels, we will train a CNN model to search for the images closest to a given sample image, running on Kaggle with GPU acceleration enabled. We will use metric learning for this.
Metric Learning
Distance metric learning (or simply, metric learning) aims at automatically constructing task-specific distance metrics from weakly supervised data, in a machine learning manner. The learned distance metric can then be used to perform various tasks (e.g., k-NN classification, clustering, information retrieval).

We will train on pairs of images drawn from the same class. When referring to the images in such a pair, we'll use the common metric learning names of the anchor (a randomly chosen image) and the positive (another randomly chosen image of the same class). For this tutorial we use the simplest approach to training: a batch will consist of (anchor, positive) pairs spread across the classes. The goal of learning is to move each anchor and its positive closer together while pushing them away from the other instances in the batch. In this case the batch size is dictated by the number of classes (10 for CIFAR-10).
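To make the objective concrete, here is a toy NumPy sketch of the in-batch softmax idea we will implement in Step 3. The embeddings below are random, made-up values purely for illustration; the point is that the pairwise similarity matrix is treated as logits and the "correct" column for row i is i itself.

import numpy as np

# Toy illustration only: 10 anchor and 10 positive embeddings of dimension 8,
# L2-normalised so that dot products are cosine similarities.
rng = np.random.default_rng(0)
anchors = rng.normal(size=(10, 8))
anchors /= np.linalg.norm(anchors, axis=1, keepdims=True)
positives = rng.normal(size=(10, 8))
positives /= np.linalg.norm(positives, axis=1, keepdims=True)

similarities = anchors @ positives.T   # shape (10, 10)
labels = np.arange(10)                 # row i should match column i

# Softmax cross-entropy with the diagonal as the target, including the
# temperature scaling used later in Step 3.
logits = similarities / 0.2
log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
loss = -log_probs[labels, labels].mean()
print(loss)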
Let's get into the tutorial.
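Before Step 1, note that the snippets below assume roughly the following imports (a minimal set; your notebook may organise them differently):

import random
from collections import defaultdict

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import cifar10
from PIL import Image
from sklearn.metrics import ConfusionMatrixDisplay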
Step 1: Load the dataset
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Scale pixel values to [0, 1] and drop the extra label dimension.
x_train = x_train.astype("float32") / 255.0
y_train = np.squeeze(y_train)
x_test = x_test.astype("float32") / 255.0
y_test = np.squeeze(y_test)
The CIFAR-10 data set has ten labels: "Airplane", "Automobile", "Bird", "Cat", "Deer", "Dog", "Frog", "Horse", "Ship", "Truck".
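As an optional sanity check, you can confirm the shapes and the label range before going further:

print(x_train.shape, y_train.shape)  # (50000, 32, 32, 3) (50000,)
print(x_test.shape, y_test.shape)    # (10000, 32, 32, 3) (10000,)
print(np.unique(y_train))            # [0 1 2 3 4 5 6 7 8 9]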
Step 2: Make helper functions/classes to...
- Show a collage of random examples (a function)
- Build batches of anchor and positive pairs (a class)
- Map each class id to the indices of the images with that label (lookup tables)
height_width = 32
num_classes = 10
def show_collage(examples):
    box_size = height_width + 2
    num_rows, num_cols = examples.shape[:2]

    collage = Image.new(
        mode="RGB",
        size=(num_cols * box_size, num_rows * box_size),
        color=(250, 250, 250),
    )
    for row_idx in range(num_rows):
        for col_idx in range(num_cols):
            array = (np.array(examples[row_idx, col_idx]) * 255).astype(np.uint8)
            collage.paste(
                Image.fromarray(array), (col_idx * box_size, row_idx * box_size)
            )

    # Double size for visualisation.
    collage = collage.resize((2 * num_cols * box_size, 2 * num_rows * box_size))
    return collage
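As a quick usage check, you can display a 5x5 collage of random training images:

# Show a collage of 5x5 random images from the training set.
sample_idxs = np.random.randint(0, 50000, size=(5, 5))
examples = x_train[sample_idxs]
show_collage(examples)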
class AnchorPositivePairs(keras.utils.Sequence):
    def __init__(self, num_batchs):
        self.num_batchs = num_batchs

    def __len__(self):
        return self.num_batchs

    def __getitem__(self, _idx):
        x = np.empty((2, num_classes, height_width, height_width, 3), dtype=np.float32)
        for class_idx in range(num_classes):
            examples_for_class = class_idx_to_train_idxs[class_idx]
            anchor_idx = random.choice(examples_for_class)
            positive_idx = random.choice(examples_for_class)
            while positive_idx == anchor_idx:
                positive_idx = random.choice(examples_for_class)
            x[0, class_idx] = x_train[anchor_idx]
            x[1, class_idx] = x_train[positive_idx]
        return x
class_idx_to_train_idxs = defaultdict(list)
for y_train_idx, y in enumerate(y_train):
    class_idx_to_train_idxs[y].append(y_train_idx)

class_idx_to_test_idxs = defaultdict(list)
for y_test_idx, y in enumerate(y_test):
    class_idx_to_test_idxs[y].append(y_test_idx)
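To check that pair generation works, you can pull one batch and visualise it; the first row of the collage shows the anchors and the second row the matching positives:

# Grab a single batch of (anchor, positive) pairs and display it.
examples = next(iter(AnchorPositivePairs(num_batchs=10)))
show_collage(examples)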
Step 3: The Model
We define a custom model with a train_step that first embeds both anchors and positives and then uses their pairwise dot products as logits for a softmax. The model itself simply consists of a sequence of 2D convolutions followed by global pooling with a final linear projection to an embedding space. As is common in metric learning, we normalise the embeddings so that we can use simple dot products to measure similarity.
class EmbeddingModel(keras.Model):
    def train_step(self, data):
        # Note: Workaround for open issue, to be removed.
        if isinstance(data, tuple):
            data = data[0]
        anchors, positives = data[0], data[1]

        with tf.GradientTape() as tape:
            # Run both anchors and positives through the model.
            anchor_embeddings = self(anchors, training=True)
            positive_embeddings = self(positives, training=True)

            # Calculate cosine similarity between anchors and positives. As they have
            # been normalised this is just the pair-wise dot products.
            similarities = tf.einsum(
                "ae,pe->ap", anchor_embeddings, positive_embeddings
            )

            # Since we intend to use these as logits we scale them by a temperature.
            # This value would normally be chosen as a hyperparameter.
            temperature = 0.2
            similarities /= temperature

            # We use these similarities as logits for a softmax. The labels for
            # this call are just the sequence [0, 1, 2, ..., num_classes - 1] since we
            # want the main diagonal values, which correspond to the anchor/positive
            # pairs, to be high. This loss will move embeddings for the
            # anchor/positive pairs together and move all other pairs apart.
            sparse_labels = tf.range(num_classes)
            loss = self.compiled_loss(sparse_labels, similarities)

        # Calculate gradients and apply via optimizer.
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        # Update and return metrics (specifically the one for the loss value).
        self.compiled_metrics.update_state(sparse_labels, similarities)
        return {m.name: m.result() for m in self.metrics}
inputs = layers.Input(shape=(height_width, height_width, 3))
x = layers.Conv2D(filters=16, kernel_size=3, strides=2, activation="relu")(inputs)
x = layers.Conv2D(filters=32, kernel_size=3, strides=2, activation="relu")(x)
x = layers.Conv2D(filters=64, kernel_size=3, strides=2, activation="relu")(x)
x = layers.Conv2D(filters=128, kernel_size=3, strides=2, activation="relu")(x)
x = layers.GlobalAveragePooling2D()(x)
embeddings = layers.Dense(units=8, activation=None)(x)
embeddings = tf.nn.l2_normalize(embeddings, axis=-1)
model = EmbeddingModel(inputs, embeddings)
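If you want to double-check the architecture before training, a standard summary call works here as usual; the output of the network is an 8-dimensional, unit-length embedding per image.

model.summary()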
Step 4: Train the model
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

history = model.fit(AnchorPositivePairs(num_batchs=1000), epochs=20)
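To verify that training converges, you can plot the loss recorded in history (assuming matplotlib is imported as plt, as in the import block above):

plt.plot(history.history["loss"])
plt.xlabel("Epoch")
plt.ylabel("Training loss")
plt.show()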
Step 5: Testing model performance
We embed the test set and calculate all near neighbours. Recall that since the embeddings are unit length, we can calculate cosine similarity via dot products. We will also display a collage of 5 test images and their nearest neighbours. The first column of each row is the query image; the following 10 columns show its nearest neighbours in order of similarity.
near_neighbours_per_example = 10

embeddings = model.predict(x_test)
gram_matrix = np.einsum("ae,be->ab", embeddings, embeddings)
# For each test image, keep the indices with the 11 highest similarities
# (the image itself plus its 10 nearest neighbours).
near_neighbours = np.argsort(gram_matrix.T)[:, -(near_neighbours_per_example + 1) :]
num_collage_examples = 5
examples = np.empty(
    (
        num_collage_examples,
        near_neighbours_per_example + 1,
        height_width,
        height_width,
        3,
    ),
    dtype=np.float32,
)
for row_idx in range(num_collage_examples):
    examples[row_idx, 0] = x_test[row_idx]
    anchor_near_neighbours = reversed(near_neighbours[row_idx][:-1])
    for col_idx, nn_idx in enumerate(anchor_near_neighbours):
        examples[row_idx, col_idx + 1] = x_test[nn_idx]
show_collage(examples)

[Optional]
Step 6: Confusion Matrix
We get a quantified view of the performance by considering the correctness of near neighbours in terms of a confusion matrix.
confusion_matrix = np.zeros((num_classes, num_classes))

# For each class.
for class_idx in range(num_classes):
    # Consider 10 examples.
    example_idxs = class_idx_to_test_idxs[class_idx][:10]
    for y_test_idx in example_idxs:
        # And count the classes of its near neighbours.
        for nn_idx in near_neighbours[y_test_idx][:-1]:
            nn_class_idx = y_test[nn_idx]
            confusion_matrix[class_idx, nn_class_idx] += 1
# Display a confusion matrix.
labels = [
    "Airplane",
    "Automobile",
    "Bird",
    "Cat",
    "Deer",
    "Dog",
    "Frog",
    "Horse",
    "Ship",
    "Truck",
]
disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=labels)
disp.plot(include_values=True, cmap="viridis", ax=None, xticks_rotation="vertical")
plt.show()
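Finally, here is a sketch of how you would actually use the trained model for image similarity search: embed a query image and rank the already-computed test embeddings by dot product. The query below is simply the first test image, chosen for illustration.

query = x_test[0:1]                         # shape (1, 32, 32, 3)
query_embedding = model.predict(query)      # shape (1, 8), already unit length
scores = np.einsum("qe,ne->qn", query_embedding, embeddings)[0]

# Sort by similarity, skipping the top hit (the query itself is in the test set).
top_k = np.argsort(scores)[::-1][1 : near_neighbours_per_example + 1]
print("Closest test image indices:", top_k)
print("Their classes:", [labels[c] for c in y_test[top_k]])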

[Bonus]
Find the public notebook below:

You can learn more about Metric Learning and its advantages from this tutorial video by Mat Kelcey!
Cheers!