Image Similarity with Python Part II: Nearest Neighbor Search


I've had several questions about my previous post, Determining how similar two images are with Python + Perceptual Hashing.

I mentioned using a BallTree data structure to compare a lot of images, and will go over that in this post.

To quickly summarize the previous post:
  • I demonstrated how to hash an image using a perceptual hashing algorithm, which you can then use to compare the similarity of two images.

  • I used the imagehash library for this.

The hamming distance between two hashes is used to measure similarity: compare the hashes at each index, adding one for each index where they differ.

So a hamming distance of zero means that they are the same.
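As a quick sketch, the hamming distance between two equal-length bit sequences can be computed in a few lines of plain Python:

```python
def hamming_distance(hash_a, hash_b):
    """Count the indices at which two equal-length bit sequences differ."""
    assert len(hash_a) == len(hash_b)
    return sum(a != b for a, b in zip(hash_a, hash_b))

hamming_distance([1, 0, 1, 1], [1, 1, 1, 0])  # differs at indices 1 and 3 -> 2
hamming_distance([1, 0, 1, 1], [1, 0, 1, 1])  # identical -> 0
```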

Comparing lots of images

If you want to compare a lot of images, it is not efficient to compare each image to all of the others every time you run a query.

My fashion app idea

Around 2015 I had an app idea.

Users would:

  • Upload an image of a clothing item.
  • Crop it closely around the item of interest.
  • View similar items and hopefully identify the specific item in the uploaded image.

I had a collection of images of clothing items, and the idea was that the app could help users identify an item of clothing (the brand, etc.) that they were interested in, or identify the item and suggest other similar items at various price points.

At the center of that idea was the need to frequently compare a lot of images to each other and rank how similar they are.

Nearest Neighbors Search

A BallTree data structure, like this one from Scikit-learn, allows for efficient nearest neighbors search, and this is what I attempted to use.

However, there was one problem with using it that I had forgotten about until I revisited my code from then as I was trying to write up this post.

The input data gets converted into floats.

The hamming distance is calculated by iterating through two bit sequences of the same length and comparing them at each index.

Since the data gets converted into floats once you pass it to the BallTree object, the distance measures will be wrong.

I will detail this a bit more at the end of the post, but for now I am just going to get into what I DID do that worked!

Annoy from Spotify

What I did was switch to another library!

Annoy (Approximate Nearest Neighbors Oh Yeah) is used at Spotify for their music recommendations.

You can read more about it in the docs, but one thing it does is build a forest of trees to store your data.

It was pretty straightforward to get up and running.

  • Start off with an activated virtual environment.
  • Install any packages you don't have.

Import all of the necessary packages.

import os
import random
import imagehash  # pip install imagehash
import numpy as np  # pip install numpy
from PIL import Image  # pip install pillow
from annoy import AnnoyIndex  # pip install annoy

You can install all of these packages with pip.

Annoy input data
  1. An identifier, which must be an integer.
  2. The vector, which for us will be the bit array of our image hash.

Converting the image hash to a bit array

First, here's how you can convert the perceptual hash returned from the imagehash library into a bit array, which we will use with Annoy.

image_one = 'example_image_file.jpg'
img = Image.open(image_one)
img_hash = imagehash.whash(img)

# img_hash.hash returns a boolean array, which we will convert into 0s and 1s
hash_array = img_hash.hash.astype('int').flatten()

It looks like this:


array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
       0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0])

Annoy example

Now for the real example.

I have a directory called images, with around 900 images of handbags, that I'm going to use here.

images_dir = "images"
images_list = [img for img in os.listdir(images_dir)]

So I just used a list comprehension to put them all in a list.

Compile the training data

As mentioned, we need an integer id for each image, along with its corresponding vector of hash data.

vector_length = 0
id_to_vec = {}
for count, f in enumerate(images_list):
    img = Image.open(''.join([images_dir, '/', f]))
    img_hash = imagehash.whash(img)
    hash_array = img_hash.hash.astype('int').flatten()
    vector_length = hash_array.shape[0]
    id_to_vec[count] = hash_array

Now we've got a dictionary id_to_vec with integer ids for the data, from the count, as well as the vector of hash data in 0s and 1s.

The next step is to set up the AnnoyIndex and add this data to it. You could do this all in one step, but I wanted to separate it out for this post.

f = vector_length
dist_function = "hamming"

t = AnnoyIndex(f, dist_function)
for key, value in id_to_vec.items():
    t.add_item(key, value)

So we iterated through the id_to_vec dictionary and added the keys and values to t.


Annoy is a C++ library with Python bindings, and the vector length tells it how much memory to allocate for each vector.

Now we can build the trees. I just picked 200 for the number of trees as I was playing around with different results.

num_trees = 200
t.build(num_trees)

Now the forest is ready, and we can query it.

Query for nearest neighbors

I'm going to find the nearest neighbors of this image, img-24.jpg from my images_list.

tree query image

And since the input data maps an integer id to the image vector data, we need to query with the id of the query image.

In this case, the ids came from the enumerate count, so each id is the same as the index of the image in the original images_list.

query_index = images_list.index('img-24.jpg')
num_neighbors = 9

neighbors = t.get_nns_by_item(query_index, num_neighbors, include_distances=True)

So here I've queried the tree with the method .get_nns_by_item() and passed:

  1. The index id of the query image.
  2. Number of nearest neighbors to return.
  3. include_distances - setting this to True will also return the distance from each of these neighbors to the query image.
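With include_distances=True, the method returns a tuple of two lists: the neighbor ids and their distances. A small sketch of mapping those ids back to filenames (the ids, distances, and filenames here are made up for illustration):

```python
# Hypothetical query result, shaped like the (ids, distances) tuple
# that get_nns_by_item returns with include_distances=True
neighbor_ids, distances = [24, 101, 57], [0.0, 4.0, 6.0]

# Stand-in for the real images_list built from os.listdir
images_list = ["img-%d.jpg" % i for i in range(900)]

# Pair each neighbor's filename with its distance to the query image
results = [(images_list[i], d) for i, d in zip(neighbor_ids, distances)]
print(results[0])  # the query image itself comes back first, at distance zero
```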

In the code above, as well as in the previous post, I used the wavelet hash:

    img_hash = imagehash.whash(img)

And one thing it does is convert the image to grayscale, so the color data is not preserved in the hash.
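To see why that matters, here is a rough sketch of the ITU-R 601 luma weighting that Pillow uses for its grayscale ('L') mode (the exact rounding in Pillow may differ; the weights and example colors here are for illustration). Two very different colors can land on the same gray value:

```python
def to_gray(r, g, b):
    """Approximate RGB -> grayscale using ITU-R 601 luma weights,
    roughly what Pillow's 'L' mode does."""
    return round(0.299 * r + 0.587 * g + 0.114 * b)

# A pure red and a blue-green of equal luma become the same gray pixel,
# so a grayscale-based hash cannot tell them apart
to_gray(255, 0, 0)    # 76
to_gray(0, 100, 154)  # 76
```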

Now there is also a colorhash option, which I think is relatively new.

    img_hash = imagehash.colorhash(img)

I will show the nearest neighbors using both of these hashes (separately) below.

Wavelet hash

Here are the 9 nearest neighbors of the query image, when using imagehash.whash.

wave hash nearest neighbors

Random Sample of Images

Here is a random sample of 81 images from the dataset overall for comparison.

There are 900+ images in the directory.

random sample of images


Color hash

And here are the 9 nearest neighbors using imagehash.colorhash.

colorhash nearest neighbors

And there are a few other hashing options in the imagehash library as well, so I would definitely recommend trying them all out to see the different results.

Scikit-learn BallTree

Just a few comments on what I tried with this.

I am not really sure how the hamming distance metric could be used out of the box with nearest neighbor search in Scikit-learn, because all of the inputs seem to get converted into floats, and there is nothing you can do about it.

Also, with the way that the BallTree works, even writing your own custom metric (instead of using the built-in hamming distance) doesn't really work, because the metric is not only run against the input data.

The BallTree also creates its own tree node bounds and calculates the metric between those nodes and your data as well.

So I couldn't find a meaningful way to represent the image hash data that didn't have problems when calculating the distance between the node bounds.
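A toy illustration of that node-bounds problem, in plain Python with a hypothetical two-vector node: a ball tree summarizes a group of points with a float centroid, and a hamming-style comparison against that centroid no longer means much:

```python
def hamming(a, b):
    """Count coordinates at which two equal-length sequences differ."""
    return sum(x != y for x, y in zip(a, b))

v1 = [1, 0, 1, 1]
v2 = [0, 0, 1, 1]

# A tree node's centroid is the float mean of its member vectors
centroid = [(x + y) / 2 for x, y in zip(v1, v2)]  # [0.5, 0.0, 1.0, 1.0]

# 0.5 is "not equal" to both 0 and 1, so the centroid sits a full
# mismatch away from every member in each averaged coordinate
hamming(centroid, v1)  # 1
hamming(centroid, v2)  # 1
```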

If anyone reading this has any insight on use cases of scikit-learn's BallTree with the hamming distance, I would love to hear it!

Thanks for reading!

This is one solution for how to compare a large dataset of images for similarity and retrieve the most similar images for a particular query image.

I really liked the Annoy library and appreciate the ease of using the hamming distance with bit arrays.

Hopefully this helped some of you.

Let me know what kind of image search projects you're working on!

