Digit Classification with Nearest Neighbors

Chris Tralie

Today we talked about supervised learning, which is the process of learning from a set of labeled examples, or examples that have been sorted into different classes beforehand. We're not always fortunate enough to have data that's labeled like this, but when we do, there are a variety of techniques we can use to learn models of the different classes. Actually, everything we've done in this class so far can be considered supervised learning, from Markov chains trained on text to the Naive Bayes bag of words model and Gaussian Naive Bayes. In every application we looked at with these techniques, we trained on our labeled examples and then tested on some new unseen data that wasn't included in the training set.

In this exercise today, we explore a new supervised learning technique known as nearest neighbors. If we have a way of measuring a distance between two different data points, then we can apply this technique. For example, let's suppose we had a labeled set of data points in two classes: red circles and blue squares. Then, let's say we wanted to guess which of the two classes some new data point was in. We'll depict this new data point as a black triangle, as shown below.

The K-nearest neighbors technique simply finds the K closest labeled examples, as measured by the distance, and uses them to vote on the class identity of this new point. In the above example, we choose K = 5 for the 5 nearest neighbors, and we happen to get 4 votes for a red circle and 1 vote for a blue square, so we would label this new data point as a red circle.
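To make the voting concrete, here's a minimal sketch in numpy on some made-up 2D points; the coordinates, labels, and the choice of Euclidean distance here are just for illustration, not part of the digits example below.

```python
import numpy as np

# Made-up labeled examples: class 0 = "red circle", class 1 = "blue square"
X = np.array([[1.0, 1.2], [0.8, 2.0], [1.5, 1.7], [2.0, 1.0], [1.2, 0.5],
              [2.3, 1.6], [4.5, 3.8], [3.9, 4.9]])
y = np.array([0, 0, 0, 0, 0, 1, 1, 1])

query = np.array([1.6, 1.4])  # the new "black triangle" point
K = 5

# Euclidean distance from the query to every labeled example
dists = np.sqrt(np.sum((X - query)**2, axis=1))

# Indices of the K closest examples, then a majority vote over their labels
neighbors = np.argsort(dists)[:K]
votes = np.bincount(y[neighbors])
print(np.argmax(votes))  # 0: four red circles outvote one blue square
```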

Overall, we can think of nearest neighbors as a supervised learning technique that memorizes examples. This means it's only as good as the examples, and it will do better with a higher number and variety of examples, which we don't always have access to. By contrast, other learning techniques will try to better generalize some knowledge to new examples. But we'll start with this "memorizer" first.

As simple as this technique may seem, it can work very well in practice. Below we'll show K-nearest neighbors on an example of 28x28 images of drawn digits, where the labeled examples are obtained from the MNIST database. In this case, there are 10 unique classes for the digits between 0 and 9, inclusive. Let's first load in our imports and load in all of the MNIST digits.
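Below is a minimal sketch of one way to do this. It assumes MNIST is downloaded with scikit-learn's fetch_openml (the original data may be loaded from a different source); the pixels are scaled to [0, 1] and the images are grouped by label.

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml  # assumption: MNIST pulled from OpenML

# Download MNIST: 70,000 28x28 grayscale images, scaled to [0, 1]
mnist = fetch_openml("mnist_784", version=1, as_frame=False)
X = mnist.data.reshape(-1, 28, 28) / 255.0
y = mnist.target.astype(int)

# Group the images by label so that digits[i][j] is the jth example of digit i
digits = [[img for img in X[y == i]] for i in range(10)]
```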

The digits are set up in a 2D array so that digits[i][j] gives the $j^{\text{th}}$ example of digit $i$. Each digit is itself a $28 \times 28$ 2D array of grayscale values between 0 and 1.
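As a quick sanity check on that layout (assuming the digits structure built above):

```python
print(len(digits))         # 10 lists, one per digit class 0-9
print(len(digits[0]))      # number of examples of the digit 0
print(digits[0][0].shape)  # (28, 28)
```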

Next, let's try to think about how to define a distance between two digit images. First, let's look at the range of values in a digit. We'll pick out the first 0 as an example.
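Here's a sketch of how we might inspect it, assuming the imports and digits structure from above:

```python
I = digits[0][0]  # the first example of the digit 0

# The grayscale values should live between 0 and 1
print("min:", I.min(), "max:", I.max())

plt.imshow(I, cmap='gray')
plt.colorbar()
plt.show()
```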

Sam had the idea in class that we might try to compare which pixels were black in one image compared to another. To do this, we can mark all pixels that are under a certain threshold. This is called binary quantization. For example, let's suppose the threshold is 0.4. Here's how we might do this for the image above.
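A sketch with explicit loops over the pixels (the image I and the 0.4 threshold come from above):

```python
thresh = 0.4
B = np.zeros(I.shape)

# Mark every pixel whose grayscale value is under the threshold
for i in range(I.shape[0]):
    for j in range(I.shape[1]):
        if I[i][j] < thresh:
            B[i][j] = 1

plt.imshow(B, cmap='gray')
plt.show()
```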

Below is a simpler way to accomplish thresholding with a single line of code using numpy broadcasting. Not only is the code shorter, but if you can stick to numpy, the code will run much faster because it's compiled C/Fortran under the hood.
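Something like the following, assuming the same image I as before:

```python
# The comparison broadcasts over every pixel at once
B = I < thresh

plt.imshow(B, cmap='gray')
plt.show()
```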

To compare two digits, we can threshold both of them and compute the Hamming distance between them: the number of corresponding pixels that differ after thresholding. The Hamming distance between a 2D array $X$ and a 2D array $Y$ is defined as

$\sum_{i, j} |X[i, j] - Y[i, j]|$

Let's compare the first 1 to the first 7 this way, using numpy subtract broadcasting and np.sum to avoid a loop.
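A sketch, assuming the digits structure and the 0.4 threshold from above:

```python
# Threshold the first example of a 1 and the first example of a 7
X = (digits[1][0] < 0.4).astype(int)
Y = (digits[7][0] < 0.4).astype(int)

# Hamming distance: elementwise subtraction via broadcasting, then sum the mismatches
print(np.sum(np.abs(X - Y)))
```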

Now we're finally ready to apply this to K-nearest neighbors. We'll define a function handle for this distance, as well as another distance that I had in mind, the Euclidean distance, which is defined as

$\sqrt{\sum_{i, j} (X[i, j] - Y[i, j])^2}$

You'll see that each of these works pretty well if you try them out below, but the Euclidean distance has the advantage that we don't have to choose a threshold as a parameter. We'll compare them both below.
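Here's one way the two function handles might look (the names hamming and euclidean are my own):

```python
def hamming(X, Y, thresh=0.4):
    """Number of pixels that differ after thresholding both images."""
    XB = (X < thresh).astype(int)
    YB = (Y < thresh).astype(int)
    return np.sum(np.abs(XB - YB))

def euclidean(X, Y):
    """Euclidean distance between two images, treated as flat vectors."""
    return np.sqrt(np.sum((X - Y)**2))
```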

We'll set up a little interactive canvas where we can draw digits and retrieve their K nearest neighbors. We'll use np.argsort to help us find the nearest neighbors. Try it out for yourself!
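The canvas code itself is omitted here, but the sketch below shows the lookup it would run on a drawn 28x28 query image; the function name classify_digit and the default K = 5 are my own choices.

```python
def classify_digit(query, digits, dist_fn, K=5):
    """Guess the class of a 28x28 query image by a vote of its K nearest neighbors."""
    dists = []
    labels = []
    # Brute force: measure the distance from the query to every labeled example
    # (slow over the full dataset; subsample the examples for an interactive demo)
    for label in range(10):
        for example in digits[label]:
            dists.append(dist_fn(query, example))
            labels.append(label)
    # np.argsort gives the indices of the K smallest distances
    neighbors = np.argsort(np.array(dists))[:K]
    votes = np.bincount(np.array(labels)[neighbors], minlength=10)
    return np.argmax(votes)

# Example: classify one of the labeled 3s (it will trivially match itself first)
print(classify_digit(digits[3][1], digits, euclidean, K=5))
```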

Finally, it's worth noting that the above approach is a brute force nearest neighbors approach that uses sorting. There are tons of ways to improve this. One of them is to use a data structure known as a KD Tree, which is able to home in on the region of the space that contains the nearest neighbors much more quickly, without checking every example. It is roughly analogous to binary search performed spatially. Sadly, KD Trees do suffer from what's known as the "curse of dimensionality," so in high dimensions one often uses an approximate nearest neighbors scheme instead.
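For an exact (not approximate) tree-based version, scipy happens to provide a KD Tree; a rough sketch on the flattened digits is below. Whether this actually beats brute force on 784-dimensional vectors is exactly where the curse of dimensionality bites.

```python
from scipy.spatial import KDTree

# Flatten every 28x28 image into a 784-dimensional vector
data = np.array([img.flatten() for label in range(10) for img in digits[label]])
labels = np.array([label for label in range(10) for _ in digits[label]])

tree = KDTree(data)

# Query the 5 nearest neighbors of a digit under the Euclidean distance
dists, idx = tree.query(digits[3][1].flatten(), k=5)
print(labels[idx])
```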