A good place to begin your machine learning journey is the k-Nearest Neighbors (kNN) algorithm, since it’s very easy to understand and implement (and one of the most fun!). kNN is a supervised learning algorithm, meaning it learns from labelled examples. In this guide we use it for classification: determining which category a given input belongs to.
k-Nearest Neighbors is also called a Lazy Learning algorithm. Whereas many machine learning algorithms develop generalizations about the data while it’s being loaded (also known as Eager Learning), kNN delays any computation related to the data until we ask the model to classify a new input. This makes the algorithm relatively easy to implement, but also computationally expensive at query time, since every new input has to be compared against the stored data. Regardless, kNN is a good algorithm to begin with because the intuition behind it is fairly easy to grasp.
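To make the lazy/eager distinction concrete, here is a minimal sketch of a lazy learner in Python. The class name `LazyKNN` and its `fit`/`predict` methods are hypothetical (loosely following the common fit/predict convention), not any library’s API. The point is that `fit` only stores the data, while all the distance computations happen inside `predict`. It uses Python’s `math.dist` (Euclidean distance, Python 3.8+).

```python
import math
from collections import Counter


class LazyKNN:
    """Hypothetical minimal lazy learner: fit() only stores data, predict() does the work."""

    def __init__(self, k):
        self.k = k
        self.points = []
        self.labels = []

    def fit(self, points, labels):
        # "Training" is just memorising the labelled data; no model is built here.
        self.points = list(points)
        self.labels = list(labels)
        return self

    def predict(self, query):
        # All computation happens at query time: measure the distance to every
        # stored point, keep the k closest, and take a majority vote.
        ranked = sorted(
            zip(self.points, self.labels),
            key=lambda pair: math.dist(pair[0], query),
        )
        top_labels = [label for _, label in ranked[: self.k]]
        return Counter(top_labels).most_common(1)[0][0]


model = LazyKNN(k=3).fit([(1, 1), (1, 2), (8, 9), (9, 8)], ["a", "a", "b", "b"])
print(model.predict((2, 1)))  # a
```

Note that calling `predict` twice repeats the full scan of the stored data, which is exactly why lazy learners can be slow when fast queries matter.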
At a super high level, when we input a new datapoint x to be classified, the k-Nearest Neighbors algorithm works by taking the k datapoints most similar to x. Whichever class the majority of those k datapoints belong to is ultimately the class that x is assigned.
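The voting step can be sketched in a few lines with `collections.Counter`. The function name `majority_class` is hypothetical, and the labels in the example are made-up values standing in for the classes of x’s k nearest neighbors.

```python
from collections import Counter


def majority_class(neighbor_labels):
    """Return the class that appears most often among the k nearest neighbors."""
    counts = Counter(neighbor_labels)
    # most_common(1) returns a list like [(label, count)]; on a tie, the label
    # encountered first wins.
    label, _count = counts.most_common(1)[0]
    return label


# Example: the labels of x's k = 5 nearest neighbors (made-up values).
print(majority_class(["pants", "shorts", "pants", "pants", "shorts"]))  # pants
```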
Although the algorithm itself is easy to implement, data wrangling and analysis are often still difficult with kNN. Let’s look at an example using one of my favourite philosophical questions: what differentiates pants from shorts? Assume we have a labelled dataset of pants and shorts, with descriptions of each instance such as price, color, length, thickness, etc.
Existential datapoint questions
Like many of us, our datapoint is having an existential crisis: “Where do I belong?” Thankfully, through Euclidean geometry, we can help answer her question by looking at where the datapoints she is most similar to belong.
There are many ways to calculate the similarity between two pieces of information. For example, we can use the Hamming Distance (the number of positions at which two equal-length sequences, such as bit strings, differ) or the Levenshtein Distance (the minimum number of single-character edits it takes to change one string into another). However, the distance measure we are probably most used to is Euclidean Distance (the straight-line distance between two points).
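For comparison, here are minimal Python sketches of the two other distance measures mentioned above, written directly from their definitions. The function names are our own; Levenshtein distance uses the classic dynamic-programming approach, keeping only one previous row of the table.

```python
def hamming_distance(a, b):
    """Number of positions at which two equal-length sequences differ."""
    if len(a) != len(b):
        raise ValueError("Hamming distance requires sequences of equal length")
    return sum(x != y for x, y in zip(a, b))


def levenshtein_distance(a, b):
    """Minimum number of single-character insertions, deletions, or
    substitutions needed to turn a into b."""
    # At the start of row i, previous[j] holds the distance between a[:i-1] and b[:j].
    previous = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        current = [i]
        for j, cb in enumerate(b, start=1):
            current.append(min(
                previous[j] + 1,               # delete ca
                current[j - 1] + 1,            # insert cb
                previous[j - 1] + (ca != cb),  # substitute (free if equal)
            ))
        previous = current
    return previous[-1]


print(hamming_distance("10110", "10011"))          # 2
print(levenshtein_distance("kitten", "sitting"))   # 3
```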
In 2 dimensions, we can calculate the distance between two points p and q using the Pythagorean Theorem, or more formally as

$$d(p, q) = \sqrt{(p_1 - q_1)^2 + (p_2 - q_2)^2}$$

where 1 and 2 represent the first two dimensions of p and q (in this case, the horizontal and vertical coordinates). When we move to 3 or more dimensions, we can rewrite this as

$$d(p, q) = \sqrt{\sum_{i=1}^{n} (p_i - q_i)^2}$$

where there are n dimensions. If this math scares you, don’t worry: we’ll stick to 2 dimensions for this guide, so you can rely on our friend Pythagoras.
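The n-dimensional formula translates almost term for term into Python. The helper name `euclidean_distance` is our own; the last line shows that the standard library’s `math.dist` (Python 3.8+) computes the same quantity.

```python
import math


def euclidean_distance(p, q):
    """Straight-line distance between two points with the same number of dimensions."""
    if len(p) != len(q):
        raise ValueError("points must have the same number of dimensions")
    # Sum the squared differences along each dimension i, then take the square root.
    return math.sqrt(sum((p_i - q_i) ** 2 for p_i, q_i in zip(p, q)))


print(euclidean_distance((0, 0), (3, 4)))        # 5.0, the classic 3-4-5 triangle
print(euclidean_distance((1, 2, 3), (4, 6, 3)))  # 5.0 in three dimensions
print(math.dist((1, 2, 3), (4, 6, 3)))           # 5.0 via the standard library
```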
At a slightly lower level (using Euclidean geometry): given a new datapoint x, if we plot our labelled data in some coordinate system, we can retrieve the k datapoints closest to where x would be plotted. Intuitively, whichever class most of those k datapoints belong to indicates which class x belongs to.
The obvious answer to my question above about differentiating between pants and shorts is “length” (with an implicit ‘duh’). Unsurprisingly, this is actually a really good insight for thinking about how we will plot our dataset of pants and shorts.
We want to plot all the garments in our dataset using axes that correspond to the attributes (or features) that tell us the most about whether a given garment is pants or shorts. To keep thinking in 2 dimensions, we’ll select two features (resulting in 2 axes): length and price.
Now we can plot the pants and shorts on the graph by designating length as one axis and price as the other. Using the Euclidean distance between x and the other datapoints, we can determine the nearest neighbors. In Python, the distance calculation will look something like this:
import math
distance = math.sqrt((length1 - length2) ** 2 + (price1 - price2) ** 2)
After calculating the distance from x to every datapoint (using the length and price of each), we can obtain the k nearest neighbors by sorting the list of datapoints by their distance to x and taking the first k elements:
neighbors = datapoints[:k]
Finally, we can iterate over neighbors and count how many are pants: if pants are the majority, x is classified as a pair of pants. A fairly short(s) algorithm!
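Putting the three steps together (distance, sort and slice, count), here is a minimal end-to-end sketch. The `garments` values are made-up toy data for illustration only, and `classify` is a hypothetical helper, not a library function. It assumes exactly two classes and raw, unnormalized features.

```python
import math

# Made-up toy data for illustration only: (length_cm, price_usd, label).
garments = [
    (102, 45, "pants"),
    (98, 60, "pants"),
    (105, 30, "pants"),
    (48, 25, "shorts"),
    (42, 35, "shorts"),
    (50, 20, "shorts"),
]


def classify(length, price, datapoints, k=3):
    # 1. Distance from the new garment x to every labelled garment.
    with_distances = [
        (math.sqrt((length - g_length) ** 2 + (price - g_price) ** 2), label)
        for g_length, g_price, label in datapoints
    ]
    # 2. Sort by distance and keep the first k entries: the k nearest neighbors.
    with_distances.sort(key=lambda pair: pair[0])
    neighbors = with_distances[:k]
    # 3. Count the pants among the neighbors; a strict majority means pants.
    pants_votes = sum(1 for _, label in neighbors if label == "pants")
    return "pants" if pants_votes > k / 2 else "shorts"


print(classify(95, 40, garments))  # pants
print(classify(55, 30, garments))  # shorts
```

With an even k, a tie falls through to "shorts" in this sketch, which is one reason an odd k is a common choice for two-class problems.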
Why do we take the k nearest neighbors to x instead of just the single nearest neighbor?
Using only the single nearest neighbor to x often lets outliers in the data dictate the classification. Taking k neighbors results in a more “democratic” selection, since one unusual datapoint can be outvoted.
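To see the outlier effect in action, here is a small sketch with made-up data in which one mislabelled “shorts” entry sits among the pants. The helper `knn_predict` is a hypothetical name, not a library function. With k = 1 that single outlier decides the class; with k = 3 it is outvoted.

```python
import math
from collections import Counter


def knn_predict(query, points, labels, k):
    """Majority label among the k points closest to query (Euclidean distance)."""
    order = sorted(range(len(points)), key=lambda i: math.dist(points[i], query))
    return Counter(labels[i] for i in order[:k]).most_common(1)[0][0]


# Made-up (length_cm, price_usd) data with one outlier: a "shorts" entry that
# sits right among the pants (for example, a mislabelled item).
points = [(100, 40), (104, 42), (98, 38), (101, 41), (45, 30), (50, 28)]
labels = ["pants", "pants", "pants", "shorts", "shorts", "shorts"]

query = (102, 41)
print(knn_predict(query, points, labels, k=1))  # shorts: the outlier decides alone
print(knn_predict(query, points, labels, k=3))  # pants: the outlier is outvoted
```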
How can we find the optimal value for k?
Although k is a hyperparameter that we choose ourselves, some values work noticeably better than others for a given dataset. A common rule of thumb is to set $k = \sqrt{n}$, where n is the number of instances in the data. However, there are more advanced techniques, such as cross-validation, for finding a good value of k.
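A minimal sketch of the square-root rule of thumb. The function name `rule_of_thumb_k` is hypothetical, and the adjustment to an odd k for two-class problems is an extra assumption added here (a common convention to avoid tied votes), not part of the rule itself. Treat the result as a starting point to tune from, not an optimum.

```python
import math


def rule_of_thumb_k(n_instances, n_classes=2):
    """Heuristic starting value for k: roughly the square root of the dataset size.

    Assumption (not part of the sqrt(n) rule itself): for two classes, bump an
    even k up to the next odd number so a majority vote can't end in a tie.
    """
    k = max(1, round(math.sqrt(n_instances)))
    if n_classes == 2 and k % 2 == 0:
        k += 1
    return k


print(rule_of_thumb_k(100))  # 11 (sqrt(100) = 10, bumped to odd)
print(rule_of_thumb_k(50))   # 7 (sqrt(50) is about 7.07)
```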
This has been a very brief introduction to the math and intuition behind kNN. There is still a lot more to the algorithm that I haven’t touched on, such as normalizing features and eliminating bias, but I’d start by actually implementing the algorithm. For a more complete understanding, check out these links:
- Scikit Learn (try using kNN for yourself)
- Wikipedia (actually a very comprehensive resource)
- Stanford NLP (math-y but very to-the-point)