In this ML Made Simple, we will clearly explain what the k-nearest neighbors (k-NN) algorithm is and how it works. You don't need to have read part #1 on Linear Regression, but consider doing so.

The k-NN algorithm is a method that can be used for both regression and classification. That means you can use it to estimate a continuous value (e.g. the price of a house) or to identify group membership (e.g. apple or banana).

It's probably the easiest machine learning algorithm: there is nothing to "learn" and no parameters to optimize, because it uses the dataset directly to make a prediction.

Pros

  • Incredibly simple
  • Reasonably accurate (though it doesn't compete with more sophisticated models)
  • Can be used for regression and classification

Cons

  • Computationally expensive
  • Sensitive to outliers

I will show you how k-NN works using the famous Iris dataset as an example. It looks like this (we will only use 2 features for the sake of simplicity):

| Petal length (cm) | Petal width (cm) | Iris type  |
|-------------------|------------------|------------|
| 1.4               | 0.2              | Setosa     |
| 4.7               | 1.4              | Versicolor |
| 4.0               | 1.3              | Versicolor |
| 5.5               | 1.8              | Virginica  |
| 4.4               | 1.2              | Versicolor |
| 1.7               | 0.8              | Setosa     |
| ...               | ...              | ...        |

The goal is to predict the type of an iris given its two features: petal length and petal width.

Visualize the dataset

One of the first things to do in machine learning is to plot your data, because humans are really bad at visualizing a big table of numbers.

Let's take a look at our dataset:
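If you want to reproduce this scatter plot yourself, here is a minimal sketch, assuming scikit-learn (which ships a copy of the Iris dataset) and matplotlib are installed. The red/green/blue color assignment is an arbitrary choice for this example.

```python
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris

iris = load_iris()
X = iris.data[:, 2:4]   # keep only petal length and petal width
y = iris.target         # 0 = setosa, 1 = versicolor, 2 = virginica

# One scatter per class, so each iris type gets its own color
for label, color in zip(range(3), ["red", "green", "blue"]):
    mask = y == label
    plt.scatter(X[mask, 0], X[mask, 1], c=color, label=iris.target_names[label])

plt.xlabel("Petal length (cm)")
plt.ylabel("Petal width (cm)")
plt.legend()
plt.show()
```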

Now what we want to do is much clearer: if I add a point on this plot, I want to know whether it's a red, green, or blue dot. As you may have guessed, we will do it by looking at the nearest neighbors (the closest dots).

Let's say that I picked an iris with a petal length of 4 cm and a petal width of 1 cm, and I want to predict whether it's a Setosa, Versicolor, or Virginica. If I told you to guess, you would probably just look around the position (4, 1) on the plot and find which category appears the most. That's also what k-NN does.

How does an algorithm "look" at what is around a position? Well, it needs a function to evaluate the distance between two points.

Find nearest neighbors

The most common choice is the Euclidean distance, which is exactly what we usually mean by "distance". It is computed using the Pythagorean theorem:

$$d(A,B) = \sqrt{(A_1-B_1)^2 + (A_2-B_2)^2 + ... +(A_n-B_n)^2}$$

This is the Pythagorean theorem generalized to an n-dimensional space; here is a neater way to write it:

$$d(A,B) = \sqrt{\sum_{i=1}^{n} (A_i-B_i)^2}$$

There are other interesting ways to measure distance, like the Minkowski distance or the Manhattan distance, but we won't talk about them here.
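In code, the Euclidean distance is a one-liner. Here is a sketch in Python, assuming the points are given as coordinate arrays:

```python
import numpy as np

def euclidean_distance(a, b):
    """Euclidean distance between two points a and b (arrays of coordinates)."""
    a, b = np.asarray(a), np.asarray(b)
    return np.sqrt(np.sum((a - b) ** 2))

# Distance between my iris at (4, 1) and the first Versicolor in the table (4.7, 1.4)
print(euclidean_distance([4.0, 1.0], [4.7, 1.4]))  # ~0.806
```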

Now that we can evaluate the distance between any pair of points, it's easy to find the closest instances (a dot on the plot or a row in the table). The algorithm simply computes the distance between the position of my iris and every instance in the dataset, then picks the nearest ones (those with the lowest distance value). The issue with this technique is that it's computationally expensive, but that is how k-NN works.

Wait. How many instances does k-NN need to choose?

Good question. Actually, I still haven't explained what the k in the algorithm's name is. It's a parameter, and you choose its value yourself (there are ways to choose it automatically, but we won't see them in this blog post). This k represents how many nearest neighbors are chosen. You may pick 1, 2, or any number that works well (usually between 1 and 10).
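Here is a sketch of this neighbor-search step in Python, using the rows from the table above and my iris at (4, 1) as the query point:

```python
import numpy as np

def k_nearest_neighbors(X, query, k):
    """Return the indices of the k rows of X closest to the query point."""
    distances = np.sqrt(np.sum((X - query) ** 2, axis=1))  # distance to every instance
    return np.argsort(distances)[:k]                       # indices of the k smallest

# The 6 rows shown in the table above (petal length, petal width)
X = np.array([[1.4, 0.2], [4.7, 1.4], [4.0, 1.3], [5.5, 1.8], [4.4, 1.2], [1.7, 0.8]])
print(k_nearest_neighbors(X, np.array([4.0, 1.0]), k=3))  # [2 4 1]
```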

Compute the output

The only thing left is to use the nearest neighbors we just found to figure out what the category of my iris probably is. The algorithm does it just like anyone would: it counts which type of iris is the most common in the neighborhood, and that's all.
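A sketch of this majority vote, using the labels of the 3 neighbors found above (rows 2, 4 and 1 of the table, all Versicolor):

```python
from collections import Counter

def classify(neighbor_labels):
    """k-NN classification: return the most common label among the neighbors."""
    return Counter(neighbor_labels).most_common(1)[0][0]

print(classify(["Versicolor", "Versicolor", "Versicolor"]))  # Versicolor
```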

I told you that k-NN applies to both regression and classification problems, but we have only talked about classification. That's because the only difference between the two is in this step.

In k-NN regression, the output isn't a category but a property value. This value is the mean of the nearest neighbors' property values. For example, if you want to predict the height of a flower based on the length and width of its petals, you compute the mean of the heights of the flowers with the closest petal sizes.
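The regression variant of the last step is just as short. A sketch, with made-up heights for the three nearest flowers:

```python
import numpy as np

def predict_value(neighbor_values):
    """k-NN regression: average the target values of the nearest neighbors."""
    return np.mean(neighbor_values)

# If the 3 nearest flowers are 30 cm, 28 cm and 35 cm tall (made-up numbers),
# the predicted height is their mean
print(predict_value([30.0, 28.0, 35.0]))  # 31.0
```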

Congratulations, you now know how the k-NN algorithm works!


Decision boundaries

This is a little bonus.

Decision boundaries are the lines where your algorithm changes its predicted category; they are where it "thinks" the separation between two classes is.

Drawing these lines is a nice way to understand how your algorithm behaves, on top of being surprisingly satisfying. This is what they look like for our example:

I love it.
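If you want to draw this kind of plot yourself, here is a sketch using scikit-learn's KNeighborsClassifier (the choice of k = 5 is arbitrary): classify every point of a fine grid, then color the background by predicted class.

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier

iris = load_iris()
X, y = iris.data[:, 2:4], iris.target   # petal length and petal width
model = KNeighborsClassifier(n_neighbors=5).fit(X, y)

# Predict the class of every point on a grid covering the plot area
xx, yy = np.meshgrid(np.linspace(0, 7.5, 300), np.linspace(0, 3, 300))
zz = model.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

plt.contourf(xx, yy, zz, alpha=0.3)       # colored regions = decision boundaries
plt.scatter(X[:, 0], X[:, 1], c=y)        # the original dataset on top
plt.xlabel("Petal length (cm)")
plt.ylabel("Petal width (cm)")
plt.show()
```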

You may want to take a look at this Python Notebook, an actual implementation of k-NN.