Nearest Neighbors

In this section, we will implement the k-nearest neighbors (KNN) algorithm to classify data.

Let's start with an example!

I have collected and plotted (fictional) data on trees:

It appears that there is some association between the height of a tree, the diameter of its trunk, and the type of tree. Birch trees are shorter and have a smaller diameter, oak trees are taller and have a larger diameter, while sycamore trees are in the middle. We can see that not every tree we have measured follows this pattern, but there is a general trend. Now suppose that I have measured the height and diameter of another tree, but I don't know what type of tree it is. I have plotted this in green and circled it so you can spot it easily:

I want to guess what type of tree this is (in other words, I want to classify the tree as either oak, sycamore, or birch). Looking at the plot, I can see it is most similar in its characteristics to the other birch trees I have already measured. Therefore, I would guess that this tree is a birch tree. Here I have implicitly applied the nearest-neighbors algorithm: I have looked at the data points closest to the tree I want to classify and used them to guess what type of tree it actually is.

We can do this pretty easily by looking at our plot for this data set, but we want to be able to have a computer classify objects for us - whether we want to classify types of trees, types of movies, whether a person has a certain disease, etc.

To describe more formally what we did above to classify the tree: we first plotted our data point, which has an x-coordinate and a y-coordinate. To predict what type of tree it was, we found the nearest points already on the scatter plot. We then predicted that the classification of our added point should be the same as that of the nearest points.

But suppose it was not so obvious what the point was:

How could we classify this point?

We will use the same idea and find the nearest neighbor:

Since this point is closest to a sycamore tree, using the nearest neighbor approach, we would classify it as a sycamore tree.

But we notice that some birch trees are also nearby neighbors to our mystery tree. We can consider more than one neighbor in our classification. Let's compare the 5 nearest neighbors to our mystery tree:

We can see that three of these points are birch trees and two are sycamore trees. Since there are more birch trees than sycamore trees among the five nearest points, we would classify this tree as birch. We can see that this algorithm does not always give the same result - it depends on the number of neighbors we choose to consider - but it can give us a reasonable estimate. We will quantify this more in the next section and use this algorithm to make predictions.
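To make the majority vote concrete, here is a tiny sketch in Python (the language we will use later in this section) that picks the most common label among the five nearest neighbors from the example above:

```python
from collections import Counter

# Labels of the 5 nearest neighbors from the plot above: three birch, two sycamore
neighbors = ["birch", "birch", "birch", "sycamore", "sycamore"]

# The most common label among the neighbors becomes our prediction
prediction = Counter(neighbors).most_common(1)[0][0]
print(prediction)  # birch
```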

We chose 5 nearest neighbors. 5 is an odd number. Why did we choose an odd number?

KNN with Code

Now that we have seen the idea behind k-nearest neighbors, let's see how we can do this using code!

In the following Jupyter Notebook, I have a data set of students who go to either UC Berkeley or Stanford. The data table has the coordinates of where each student lives during the school year. Below that table is a graph of the data.

In order to do our k-nearest neighbor classification, we need to do three things:

1. Split the data set
2. Find the Euclidean distances
3. Classify the test points

Split the Data Set

We want to split the data into a training set and a test set so that we can check how accurate our predictions are. The first thing we want to do is shuffle the data. We can do this using pandas' .sample() method, which takes in the number of rows we want to sample. We would pass 100 as the argument to .sample() because we want the entire 100-row table back in a shuffled order.

Then we have to take 3/4 of the data for our training set and the remaining 1/4 as our test set to see how accurate our predictor is!
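As a rough sketch, assuming the notebook's table is a pandas DataFrame named `students` with 100 rows (that variable name is an assumption, not necessarily what the notebook uses), the shuffle and split could look like this:

```python
# Assumption: `students` is the notebook's 100-row DataFrame of coordinates and schools
shuffled = students.sample(100)   # sampling all 100 rows returns the table in shuffled order
train = shuffled.iloc[:75]        # first 3/4 of the shuffled rows become the training set
test = shuffled.iloc[75:]         # remaining 1/4 is held out as the test set
```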

Euclidean Distances

In order to compare two pieces of data, we need to find the distance between them. Using the formula below, we can compute the Euclidean distance between two data points. A small distance means that the two points are similar, while a larger distance means that the two points are more different. We want to find the distance from the test set example we are looking to classify to all of the training set points.
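Written out, for two points $a = (a_1, \ldots, a_n)$ and $b = (b_1, \ldots, b_n)$, the Euclidean distance is

$$d(a, b) = \sqrt{\sum_{i=1}^{n} (a_i - b_i)^2}$$

which is exactly what the function below computes.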

def distance(arr1, arr2):
    # Euclidean distance between two coordinate arrays (e.g. numpy arrays)
    return sum((arr1 - arr2) ** 2) ** 0.5
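For example, called with the coordinate arrays for the points (0, 0) and (3, 4), this returns the square root of 3² + 4², which is 5.0.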

Classifying

Lastly, we need to create a function that takes our Euclidean distance function and the training set and uses them to classify our test cases. Our classifying function should:

1. Choose the row from the test set that we want to classify and compute the distance between that example and every row in the training set.
2. Sort the training table by distance from the test example, from the smallest distance (the most similar training points) to the biggest distance (the least similar training points), i.e. with descending = False.
3. Grab the k rows with the smallest distances from the sorted table.
4. Group these k training cases by their location (either Berkeley or Stanford) and see which of the two classifications is more common. Since we always use an odd number of nearest neighbors, we will always have a winner.
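Here is a minimal sketch of such a classifier using pandas. It assumes the training DataFrame has columns named 'x' and 'y' for the coordinates and 'school' for the label (those column names are assumptions, not necessarily the notebook's), and it reuses the distance function from above; pandas' sort_values with ascending=True plays the role of descending = False.

```python
import numpy as np

def classify(test_row, train, k):
    """Predict the school for one test example using its k nearest neighbors.

    test_row: a pandas Series with 'x' and 'y' coordinates (assumed column names).
    train:    the training DataFrame with 'x', 'y', and 'school' columns (assumed names).
    """
    train = train.copy()
    # Step 1: distance from the test example to every training row
    train["distance"] = train.apply(
        lambda row: distance(np.array([row["x"], row["y"]]),
                             np.array([test_row["x"], test_row["y"]])),
        axis=1,
    )
    # Steps 2 and 3: sort from smallest to biggest distance and keep the k nearest rows
    nearest = train.sort_values("distance", ascending=True).head(k)
    # Step 4: the most common school among the k neighbors is the prediction
    return nearest["school"].value_counts().idxmax()
```

For example, classify(test.iloc[0], train, 5) would predict the school of the first test example from its 5 nearest neighbors in the training set.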
