K-Nearest Neighbors (KNN)

In this module, we will explore how the K-Nearest Neighbors (KNN) algorithm works and apply it to classify fruits using physical characteristics like mass, width, height, and color score.

🔎 Introduction to KNN

KNN is a supervised learning algorithm that predicts the class or value of a new data point from labeled examples it has already seen. For a given point, it looks at the “K” nearest points in the training data (K is a number chosen by the user) and assigns the point the most common class, or the average value, of those neighbors.

Instead of creating a complicated model, KNN stores all the data and waits until a prediction needs to be made. When a new, unknown data point is introduced, KNN calculates the distance between this point and every other point in the dataset. It then identifies the “K” closest points (K is a number chosen by the user) and uses those neighbors to predict the new point’s category or value.

The basic idea behind KNN is that similar things are found near each other. For a classification task, once the K nearest neighbors are found, the new point is assigned to the class that appears most frequently among them.

If it’s a regression task, meaning you want to predict a number instead of a category, the new point’s value is usually the average of the neighboring values. KNN is called a “lazy learner” because it doesn’t do any actual learning until it’s time to make a prediction. It relies entirely on the structure of the data it’s given.
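For the regression case, a minimal base-R sketch is shown below. The helper name knn_regress and the toy width/mass numbers are illustrative assumptions, not part of the fruit dataset used later in this module.

```r
# Minimal KNN regression sketch (hypothetical helper, made-up data).
knn_regress <- function(train_x, train_y, new_x, k = 3) {
  # Euclidean distance from the new point to every training row
  diffs <- sweep(train_x, 2, new_x)      # subtract new_x from each row
  dists <- sqrt(rowSums(diffs^2))
  nearest <- order(dists)[seq_len(k)]    # indices of the k closest rows
  mean(train_y[nearest])                 # average of the neighbors' values
}

# Toy data: one feature (width in cm) and a numeric target (mass in g)
train_x <- matrix(c(6.0, 6.5, 7.0, 7.5, 8.0), ncol = 1)
train_y <- c(120, 140, 160, 180, 200)

# The three nearest widths to 7.1 are 7.0, 7.5, and 6.5
pred_mass <- knn_regress(train_x, train_y, new_x = 7.1, k = 3)
print(pred_mass)
```

Nothing is fitted ahead of time here: all the work happens inside the call that makes the prediction, which is what “lazy learner” refers to.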

This can be applied to different kinds of sorting, like fruit sorting. For example, if you record the characteristics of a set of known fruits (height, width, mass, etc.), KNN can measure the distance between each of those fruits and a new “mystery fruit”. If the mystery fruit sits closest to the known apples, it is most likely an apple rather than an orange.

So, if you input measurements for apples, oranges, and bananas, and then measure a new “mystery fruit,” KNN can find the fruits most similar to it based on those characteristics. If, say, four of the five nearest fruits are apples, the mystery fruit would likely be classified as an apple too. In this way, KNN makes it easy to sort new items based on known examples without having to build a complex model ahead of time.

As noted above, being a lazy learner means KNN simply stores the data and makes decisions only when needed. That makes it easy to set up, but predictions become slow on large datasets because every new point must be compared with every stored point.

It’s especially useful for problems where you expect that similar inputs should have similar outputs (like fruits!), and it’s flexible enough to be used in a variety of real-world situations. It’s often used by streaming services, like Amazon, Netflix, or Hulu, to gauge what kind of shows users prefer. In healthcare, KNN can help predict medical diagnoses by comparing a patient’s symptoms to those of similar past cases. It has also been used in various instances of image recognition, like identifying handwritten numbers or objects in photos, and in banking, where it can spot unusual transactions that might be fraud.


📌 How KNN Works: Core Steps of KNN

  1. Choose the number of neighbors K (usually an odd number like 3, 5, or 7).

    The first thing you do is pick how many neighbors the algorithm should consider when it makes a prediction. K is a positive whole number, like 3, 5, or 7. Most people go with an odd number so that a two-class vote cannot end in a tie (with more than two classes, ties are still possible). A popular rule of thumb is to start with K near the square root of the number of data points in your dataset. A smaller K makes the model more sensitive to small changes, so it can get thrown off by outliers. A bigger K smooths things out and makes the model more stable, but if K is too large, it can start blending together classes that shouldn’t be. So picking the right K really depends on the data.

  2. Calculate distance (e.g., Euclidean) from the new point to all existing data points.

    Once you’ve chosen K, the next step is to figure out how close your new, unknown data point is to all the other points in your training data. This is where distance metrics come in. The most common one is Euclidean distance, which is just the straight-line distance between two points. You could also use Manhattan distance, which moves along a grid, or Minkowski distance, which is a more general formula. Essentially, you’re measuring how similar or different the new point is from everything else based on the features you’re tracking, like weight, height, or whatever.

  3. Identify the K closest neighbors.

    Once all the distances are calculated, you sort the training data by how close each point is to your new one. Then you just take the K points that are closest. These are the neighbors that are going to help decide what label or value the new point gets. The idea is that points that are close together in this space are likely to be similar.

  4. Predict by majority vote (for classification) or average value (for regression).

    Now that you have your K neighbors, you use them to actually make a prediction. If you’re doing classification, you look at the labels of the K neighbors and go with the one that shows up the most. Like if you’re trying to classify a mystery fruit and 3 of the 5 nearest neighbors are apples, then you’ll probably call it an apple. If you’re doing regression, which is when you’re trying to predict a number, you just take the average of the neighbors’ values and use that as your prediction for the new point.

Example:
If you’re given data on the size and color of known apples, oranges, lemons, etc., and then introduced to a new “mystery fruit”, KNN can tell you what kind of fruit it most likely is — based on how similar it is to others.
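The four steps above can be sketched from scratch in a few lines of base R. This is an illustration under assumptions, not the implementation used later in the module (which relies on knn() from the class package): the helper name knn_classify and the width/height values are made up.

```r
# From-scratch KNN classifier following the four steps (hypothetical helper).
knn_classify <- function(train_x, train_y, new_x, k = 3) {
  # Step 2: Euclidean distance from new_x to every training point
  dists <- sqrt(rowSums(sweep(train_x, 2, new_x)^2))
  # Step 3: keep the k closest training points
  nearest <- order(dists)[seq_len(k)]
  # Step 4: majority vote among their labels
  # (which.max returns the first label in case of a tie)
  votes <- table(train_y[nearest])
  names(votes)[which.max(votes)]
}

# Toy fruit data: columns are width and height in cm (made-up values)
train_x <- matrix(c(7.0, 7.2,
                    7.4, 7.0,
                    7.1, 7.5,
                    6.0, 9.0,
                    6.2, 9.4,
                    5.9, 8.8),
                  ncol = 2, byrow = TRUE)
train_y <- c("apple", "apple", "apple", "lemon", "lemon", "lemon")

# Step 1: choose K, then classify a new fruit
mystery <- c(7.2, 7.3)
label <- knn_classify(train_x, train_y, mystery, k = 3)
print(label)
```

The mystery fruit’s three nearest neighbors are all apples in this toy data, so the vote is unanimous.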


📚 Euclidean Distance

The distance between points is a key part of KNN because it tells us which points are “closest”, so we can pick the right neighbors. The most commonly used metric for continuous variables is Euclidean distance, but you could also use others like Manhattan distance.

Euclidean distance is the most common choice because it:

  1. Is easy to compute
  2. Makes geometric sense for continuous variables (e.g., height, weight, age)
  3. Matches our intuitive notion of “nearness” in a regular space

The Euclidean distance is the “straight-line” distance between two points in Euclidean space. In KNN and many ML algorithms, we treat data points as if they live in n-dimensional Euclidean space, where each feature (or variable) is a dimension. So if a dataset has 4 features (like height, weight, age, income), it lives in 4D Euclidean space. We can then calculate Euclidean distances between points to compare them.

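For two points A = (x₁, y₁) and B = (x₂, y₂) in two dimensions, the Euclidean distance is: \(d(A,B) = \sqrt{(x_1 - x_2)^2 + (y_1 - y_2)^2}\)

For two points A = (a₁, a₂, ..., aₙ) and B = (b₁, b₂, ..., bₙ) with n features, this generalizes to: \(d(A,B) = \sqrt{(a_1 - b_1)^2 + (a_2 - b_2)^2 + \dots + (a_n - b_n)^2}\)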

In multi-feature datasets, each feature is a dimension in the space.
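As a quick check of the formula, the sketch below computes the Euclidean distance by hand and compares it with dist() from base R, alongside the Manhattan and Minkowski alternatives mentioned earlier. The two 4-feature points are made-up values chosen only for illustration.

```r
# Two points in 4-dimensional feature space (made-up values)
a <- c(1, 2, 3, 4)
b <- c(2, 4, 3, 1)

# Euclidean distance written out from the formula
euclid_manual <- sqrt(sum((a - b)^2))

# The same distance, plus two alternatives, via stats::dist()
pts <- rbind(a, b)
euclid    <- as.numeric(dist(pts, method = "euclidean"))
manhattan <- as.numeric(dist(pts, method = "manhattan"))
minkowski <- as.numeric(dist(pts, method = "minkowski", p = 3))

print(c(manual = euclid_manual, euclidean = euclid,
        manhattan = manhattan, minkowski_p3 = minkowski))
```

Here the squared differences are 1, 4, 0, and 9, so the Euclidean distance is the square root of 14, while the Manhattan distance is the sum of the absolute differences, 6.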

Example Code

# Load necessary package
# install.packages("class")  # run this once if not already installed
library(class)

# Training data (2D points)
X_train <- matrix(c(1, 2,
                    2, 3,
                    3, 1,
                    6, 5,
                    7, 7), 
                  ncol = 2, byrow = TRUE)

# Labels for training data
y_train <- factor(c("A", "A", "A", "B", "B"))

# New data point to classify
X_test <- matrix(c(3, 2), ncol = 2)

# Run KNN (k = 3)
predicted_label <- knn(train = X_train, test = X_test, cl = y_train, k = 3)

# Output the prediction
print(paste("Predicted label:", predicted_label))
[1] "Predicted label: A"

🤖 Choosing K: KNN Cross-Validation

In KNN, choosing the right value of K (the number of neighbors) is extremely important because it controls the model’s complexity:

  • Small K (risk of overfitting):

    1. The model becomes very sensitive to noise and small fluctuations in the data.
    2. It may fit outliers and mistakes in the training set.
    3. This can cause overfitting, where the model memorizes the training data too closely but performs poorly on new, unseen data.

  • Large K (risk of underfitting):

    1. The model looks at more neighbors and becomes more stable and smoother.
    2. It may ignore important local patterns by averaging too much.
    3. This can cause underfitting, where the model is too simple and cannot capture the complexity of the data.

Thus, choosing the best K is a balance between overfitting and underfitting. Cross-validation is often used to find the optimal K. 

Why use Cross-Validation to choose K?

  • You cannot just guess the best K value.

  • Cross-validation helps systematically test different K values on parts of the data not used for training, simulating how the model would perform on truly unseen data.

  • You typically try a range of K values (e.g., from 1 to 20) and select the one that gives the highest validation accuracy.

Therefore, try multiple K values and compare them with cross-validation. This is the most common and reliable method.

Example Code

# Install packages if needed
# install.packages("caret")
# Example in R using iris dataset
library(class)
library(caret)
# Split the data
set.seed(123)
train_index <- createDataPartition(iris$Species, p = 0.7, list = FALSE)
train_data <- iris[train_index, ]
test_data <- iris[-train_index, ]

# Scale features
train_scaled <- scale(train_data[, 1:4])
test_scaled <- scale(test_data[, 1:4], center = attr(train_scaled, "scaled:center"), 
                                         scale = attr(train_scaled, "scaled:scale"))

# Try different K values and record accuracy
accuracies <- sapply(1:20, function(k) {
  pred <- knn(train = train_scaled, test = test_scaled, cl = train_data$Species, k = k)
  mean(pred == test_data$Species)
})

# Plot accuracy vs K
plot(1:20, accuracies, type = "b", col = "blue", pch = 19,
     xlab = "K", ylab = "Accuracy", main = "KNN Accuracy for Different K")

This plot illustrates how the accuracy of the K-Nearest Neighbors (KNN) algorithm varies with different values of K, the number of nearest neighbors considered during classification. As K increases, the model becomes more generalized, averaging the behavior of more neighboring points. Conversely, smaller K values make the model more sensitive to local patterns and noise in the data. Note that this example scores each K on a single held-out test set, which is a simpler stand-in for full cross-validation.
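For a cross-validated version of the same search, one option is leave-one-out cross-validation with knn.cv() from the class package. The sketch below is an illustration under assumptions rather than the module’s own method: it standardizes all of iris up front for simplicity, and the names k_grid, loocv_acc, and best_k_cv are made up for this example.

```r
library(class)

# Standardize the four iris measurements so each contributes comparably
x <- scale(iris[, 1:4])
y <- iris$Species

set.seed(123)  # knn.cv breaks voting ties at random
k_grid <- 1:20

# Leave-one-out cross-validation: each flower is predicted from all the others
loocv_acc <- sapply(k_grid, function(k) {
  pred <- knn.cv(train = x, cl = y, k = k)
  mean(pred == y)
})

# which.max returns the first (smallest) K among ties for the top accuracy
best_k_cv <- k_grid[which.max(loocv_acc)]
print(data.frame(k = k_grid, accuracy = round(loocv_acc, 3)))
cat("K with the highest leave-one-out accuracy:", best_k_cv, "\n")
```

Because every observation takes a turn as the test point, the accuracy for each K is averaged over 150 predictions instead of depending on one particular split.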


🥝 Real Dataset: Fruit Sorting with KNN

Now we’re going to apply the KNN algorithm to a real dataset. We’ve learned the theoretical aspects so far, but now let’s see how it actually works in practice! We’ll walk through the code step-by-step, interpret outputs like scatterplots, accuracy plots, and confusion matrices, and answer key questions like: What’s a confusion matrix? Why normalize data? Why split it? And why compare models?

Dataset Loading and Setup

library(readr)
library(dplyr)

library(ggplot2)
library(class)
library(caret)
library(curl)
f <- curl("https://raw.githubusercontent.com/msbrendakim/AN588-KNN-FruitsClassification/master/fruit_data_with_colors.txt")
fruit_data<-read.delim(f, sep="\t", header=TRUE)

fruit_data$fruit_label <- as.factor(fruit_data$fruit_label)
fruit_data$fruit_name <- as.factor(fruit_data$fruit_name)

We start by loading our fruit dataset, stored in fruit_data_with_colors.txt, a tab-separated file with measurements like mass and color score. Here, we convert fruit_label and fruit_name to factors, which tells R these are categories (e.g., apple, orange), not numbers. This is crucial for KNN, a classification algorithm that assigns categories based on nearby points.

Data Exploration

summary(fruit_data)
 fruit_label    fruit_name fruit_subtype           mass           width      
 1:19        apple   :19   Length:59          Min.   : 76.0   Min.   :5.800  
 2: 5        lemon   :16   Class :character   1st Qu.:140.0   1st Qu.:6.600  
 3:19        mandarin: 5   Mode  :character   Median :158.0   Median :7.200  
 4:16        orange  :19                      Mean   :163.1   Mean   :7.105  
                                              3rd Qu.:177.0   3rd Qu.:7.500  
                                              Max.   :362.0   Max.   :9.600  
     height        color_score    
 Min.   : 4.000   Min.   :0.5500  
 1st Qu.: 7.200   1st Qu.:0.7200  
 Median : 7.600   Median :0.7500  
 Mean   : 7.693   Mean   :0.7629  
 3rd Qu.: 8.200   3rd Qu.:0.8100  
 Max.   :10.500   Max.   :0.9300  
table(fruit_data$fruit_name)

   apple    lemon mandarin   orange 
      19       16        5       19 
ggplot(fruit_data, aes(x = width, y = height, color = fruit_name)) +
  geom_point(size = 3, alpha = 0.7) +
  labs(title = "Fruit Types by Size") +
  theme_minimal()

Before we model, we need to understand our data. The summary(fruit_data) output shows stats for our features. For example, mass ranges from 76 to 362 grams, and color score is between 0.55 and 0.93. This tells us our features have different scales, which we’ll need to fix later for KNN.

The table(fruit_data$fruit_name) output shows we have 19 apples, 16 lemons, 5 mandarins, and 19 oranges. Notice mandarins are underrepresented with only 5 samples, which might make them harder to classify accurately compared to apples or oranges.

The scatterplot plots width versus height, with each point colored by fruit type. When you look at this plot, you might see mandarins clustering in one area (shorter and narrower) and lemons in another (taller and narrower). If the fruit types form distinct clusters, that’s a great sign for KNN, because it relies on nearby points being similar. But if points overlap a lot, like apples blending with oranges, KNN might struggle to separate them. This plot gives us a first look at whether our features (width and height) are good for distinguishing fruits, and it hints that mass and color score might add even more separation.

Feature Normalization and Splitting

features <- fruit_data[, c("mass", "width", "height", "color_score")]

preproc <- preProcess(features, method = c("center", "scale"))
features_scaled <- predict(preproc, features)

set.seed(123)
train_idx <- createDataPartition(fruit_data$fruit_label, p = 0.7, list = FALSE)
train_x <- features_scaled[train_idx, ]
test_x <- features_scaled[-train_idx, ]
train_y <- fruit_data$fruit_label[train_idx]
test_y <- fruit_data$fruit_label[-train_idx]

Here, we’re prepping our data for KNN. We select our features (mass, width, height, and color score) into features. Since KNN calculates distances, we need to normalize these features because they’re on different scales (e.g., mass in grams vs. color score from 0 to 1). The preProcess function with method = c("center", "scale") standardizes each feature to have a mean of 0 and a standard deviation of 1, and predict(preproc, features) applies this to get features_scaled.
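To see why scaling matters, the sketch below compares raw and standardized distances for three hypothetical fruits. The values are made up (they only roughly follow the ranges in the summary above), and base R’s scale() is used here as a stand-in for the preProcess step.

```r
# Three hypothetical fruits: a and b differ almost only in color,
# while c is much larger than both
fruits <- data.frame(
  mass        = c(150, 155, 300),
  width       = c(7.0, 7.0, 9.0),
  height      = c(7.5, 7.5, 9.5),
  color_score = c(0.55, 0.93, 0.56)
)
rownames(fruits) <- c("fruit_a", "fruit_b", "fruit_c")

# Raw distances: mass (in grams) dominates, color_score barely registers
raw_dist <- as.matrix(dist(fruits))

# Standardized distances: every feature has mean 0 and sd 1
scaled_dist <- as.matrix(dist(scale(fruits)))

print(round(raw_dist, 2))
print(round(scaled_dist, 2))
```

On the raw scale, the large color gap between fruit_a and fruit_b contributes almost nothing next to their 5 g difference in mass. After standardizing, that color gap becomes a substantial part of the distance.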

Then, we split the data into 70% training and 30% testing sets using createDataPartition. Holding out a test set shows how well our model generalizes to unseen data, mimicking real-world scenarios where new fruits arrive. The set.seed(123) call makes the split reproducible, and because createDataPartition samples within each class, the training and test sets keep roughly the same mix of fruit types as the full dataset. train_x and train_y are our training features and labels, while test_x and test_y are held back for testing.

Model Evaluation: Choosing Best K

k_values <- seq(1, 15, 2)
accuracy_list <- sapply(k_values, function(k) {
  pred <- knn(train_x, test_x, train_y, k)
  mean(pred == test_y)
})

plot(k_values, accuracy_list, type = "b", col = "blue", pch = 19,
     xlab = "K", ylab = "Accuracy", main = "KNN Accuracy by K")

max_accuracy <- max(accuracy_list)
threshold <- 0.02
good_k <- k_values[accuracy_list >= (max_accuracy - threshold)]
best_k <- max(good_k)
cat("Best K (within 0.02 of max accuracy):", best_k, "\n")
Best K (within 0.02 of max accuracy): 7 

We test odd K values (1, 3, 5, ..., 15) to find the best number of neighbors. For each K, knn predicts the test set labels, and we compute the accuracy. The plot shows accuracy versus K. Rather than simply taking the K with the highest accuracy, the code keeps every K whose accuracy is within 0.02 of the maximum and picks the largest of them, which gives K=7.

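The plot shows accuracy peaking at 1.0 for K=1, but K=1 risks overfitting because each prediction depends on a single neighbor. K=7 passes the 0.02 threshold, and the final model below confirms that it also classifies every test fruit correctly on this split. It is a better choice for a fruit sorting system because it uses 7 neighbors, reducing sensitivity to noise while maintaining high accuracy. With only 15 test fruits, one misclassified fruit would lower accuracy by roughly 7 percentage points, so the selected K should be read as a reasonable choice rather than a precise optimum.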
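The selection rule can be traced on a small example. The sketch below assumes a 15-fruit test set and uses made-up counts of correct predictions, so the accuracies are hypothetical and are not the values from the plot.

```r
# Hypothetical accuracies for K = 1, 3, ..., 15 on a 15-fruit test set
k_values      <- seq(1, 15, 2)
n_test        <- 15
correct_count <- c(15, 14, 15, 15, 14, 13, 13, 12)   # made-up counts
accuracy_list <- correct_count / n_test

# Each test fruit moves accuracy by 1/15, far more than the 0.02 tolerance,
# so "within 0.02 of the maximum" can only match K values tied for the maximum
step <- 1 / n_test

good_k <- k_values[accuracy_list >= max(accuracy_list) - 0.02]
best_k <- max(good_k)   # prefer the largest (smoothest) K among the ties

cat("Accuracy step per fruit:", round(step, 3), "\n")
cat("Candidate K values:", good_k, "\n")
cat("Selected K:", best_k, "\n")
```

With these made-up counts, K = 1, 5, and 7 tie for the top accuracy and the rule selects 7. On a larger test set, the 0.02 tolerance would also admit K values that are slightly below the maximum.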

Final Model and Confusion Matrix

final_pred <- knn(train_x, test_x, train_y, best_k)

cm <- confusionMatrix(final_pred, test_y)

# Precision, Recall, F1-score by class
metrics <- data.frame(
  Class = levels(test_y),
  Precision = cm$byClass[, "Pos Pred Value"],
  Recall = cm$byClass[, "Sensitivity"],
  F1 = 2 * ((cm$byClass[, "Pos Pred Value"] * cm$byClass[, "Sensitivity"]) /
            (cm$byClass[, "Pos Pred Value"] + cm$byClass[, "Sensitivity"]))
)

metrics
         Class Precision Recall F1
Class: 1     1         1      1  1
Class: 2     2         1      1  1
Class: 3     3         1      1  1
Class: 4     4         1      1  1
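# label map: look up each numeric label's fruit name in the data
# (label order is not alphabetical: 1 = apple, 2 = mandarin, 3 = orange, 4 = lemon)
label_pairs <- unique(fruit_data[, c("fruit_label", "fruit_name")])
label_map <- as.character(label_pairs$fruit_name)
names(label_map) <- as.character(label_pairs$fruit_label)
label_map <- label_map[levels(fruit_data$fruit_label)]

# convert the confusion matrix to a data frame for the heatmap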
cm_df <- as.data.frame(cm$table)
cm_df$Reference <- factor(cm_df$Reference, levels = names(label_map), labels = label_map)
cm_df$Prediction <- factor(cm_df$Prediction, levels = names(label_map), labels = label_map)

# heatmap visualization
ggplot(cm_df, aes(x = Reference, y = Prediction, fill = Freq)) +
  geom_tile(color = "white") +
  geom_text(aes(label = Freq), color = "black", size = 6) +
  scale_fill_gradient(low = "white", high = "steelblue") +
  labs(title = "KNN Confusion Matrix (Fruit Names)", x = "True Label", y = "Predicted Label") +
  theme_minimal()

We use best_k (K=7) to run our final KNN model with knn, making predictions on the test set (final_pred). The confusionMatrix function compares these predictions to test_y, giving us a table of correct and incorrect classifications. We also compute precision, recall, and F1-score per class. Precision is how often our “apple” predictions are correct; recall is how many of the actual apples we identified; F1 balances both. Finally, we create a heatmap of the confusion matrix, mapping numeric labels (1, 2, 3, 4) to fruit names (apple, mandarin, orange, lemon) for clarity.
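To make these definitions concrete, the sketch below computes precision, recall, and F1 by hand from a small confusion matrix. The counts are hypothetical (they are not the module’s results) and include a few mistakes so that the three metrics differ.

```r
# Hypothetical 3-class confusion matrix: rows = predicted, columns = actual
cm_toy <- matrix(c(5, 1, 0,
                   0, 3, 1,
                   0, 0, 5),
                 nrow = 3, byrow = TRUE,
                 dimnames = list(Prediction = c("apple", "lemon", "orange"),
                                 Reference  = c("apple", "lemon", "orange")))

true_pos  <- diag(cm_toy)
precision <- true_pos / rowSums(cm_toy)   # correct / everything predicted as the class
recall    <- true_pos / colSums(cm_toy)   # correct / everything truly in the class
f1        <- 2 * precision * recall / (precision + recall)

print(round(data.frame(precision, recall, f1), 3))
```

In this toy matrix one lemon was predicted as an apple, so apple precision drops to 5/6 even though apple recall stays at 1.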

The metrics table shows perfect scores (1.0 for precision, recall, and F1) for all classes. This looks amazing, but it’s suspicious! Perfect scores often mean the test set is too small (here only 15 of the 59 fruits) or the data is too clean, which makes the model appear to work flawlessly. In a real fruit sorting system, new fruits might have more variability (e.g., bruised apples), so this perfection might not hold. The heatmap shows the same result: the diagonal (correct predictions, like apple predicted as apple) has all the counts, and off-diagonal cells (mistakes) are zero. Our model nailed the test set, but we should be cautious, because it might not generalize to new fruits. For example, mandarins have only 5 samples in total, so the model has very little data to learn that class from and the test set can barely check it.

🔀 Model Comparison: KNN vs Logistic Regression

A model comparison table helps us evaluate which algorithm is best suited for our task.

library(nnet)  # for multinom()
library(e1071) # used internally by caret's confusionMatrix()

# using the same features in data
train_df <- data.frame(train_x)
train_df$label <- train_y
test_df <- data.frame(test_x)
test_df$label <- test_y

# Multinomial logistic regression
logit_model <- multinom(label ~ ., data = train_df)
# weights:  24 (15 variable)
initial  value 60.996952 
iter  10 value 13.831555
iter  20 value 12.482299
iter  30 value 12.435228
final  value 12.434969 
converged
# prediction
logit_pred <- predict(logit_model, newdata = test_df)

cm_logit <- confusionMatrix(logit_pred, test_y)

# comparison
cm_logit_df <- as.data.frame(cm_logit$table)
cm_logit_df$Reference <- factor(cm_logit_df$Reference, levels = names(label_map), labels = label_map)
cm_logit_df$Prediction <- factor(cm_logit_df$Prediction, levels = names(label_map), labels = label_map)

ggplot(cm_logit_df, aes(x = Reference, y = Prediction, fill = Freq)) +
  geom_tile(color = "white") +
  geom_text(aes(label = Freq), color = "black", size = 6) +
  scale_fill_gradient(low = "white", high = "tomato") +
  labs(title = "Logistic Regression Confusion Matrix (Fruit Names)", x = "True Label", y = "Predicted Label") +
  theme_minimal()

Model Comparison Table

comparison_df <- data.frame(
  Model = c("KNN", "Logistic Regression"),
  Accuracy = c(cm$overall["Accuracy"], cm_logit$overall["Accuracy"]),
  Kappa = c(cm$overall["Kappa"], cm_logit$overall["Kappa"])
)

comparison_df
                Model Accuracy     Kappa
1                 KNN      1.0 1.0000000
2 Logistic Regression      0.8 0.7151899

In the model comparison table, KNN has an accuracy of 1.0, meaning it got all 15 test fruits correct, while logistic regression has an accuracy of 0.8, meaning it got 12 out of 15 correct.

Kappa measures agreement between predicted and actual labels, adjusted for the agreement expected by chance. A Kappa of 1.0 (KNN) means perfect agreement, while 0.715 (logistic regression) indicates good but not perfect agreement. It is lower than the raw accuracy of 0.8 because Kappa discounts the matches that would be expected by chance.
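Kappa can be computed by hand as (observed agreement − chance agreement) / (1 − chance agreement). The sketch below applies this to a hypothetical confusion matrix. The counts are made up, so the result is not the 0.715 reported above.

```r
# Hypothetical confusion matrix: rows = predicted, columns = actual
cm_toy <- matrix(c(5, 1, 0,
                   0, 3, 1,
                   0, 0, 5),
                 nrow = 3, byrow = TRUE)

n          <- sum(cm_toy)
observed   <- sum(diag(cm_toy)) / n                          # plain accuracy
by_chance  <- sum(rowSums(cm_toy) * colSums(cm_toy)) / n^2   # agreement expected by chance
kappa_stat <- (observed - by_chance) / (1 - by_chance)

cat("Accuracy:", round(observed, 3), "\n")
cat("Chance agreement:", round(by_chance, 3), "\n")
cat("Kappa:", round(kappa_stat, 3), "\n")
```

Chance agreement is estimated from the row and column totals, so a model that merely predicts the most common classes often gets little credit from Kappa.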

You might be wondering, “How did KNN get a perfect score like this?”

  • KNN classifies fruits by finding the 7 nearest neighbors in feature space (mass, width, height, color score). Our dataset is small (59 fruits, 15 in the test set), and the scatterplot showed distinct clusters (e.g., apples vs. lemons). With K=7, KNN likely found clear majorities for each test fruit, leading to perfect predictions. With so few test fruits, though, a perfect score can partly reflect a favorable random split.

  • Logistic regression assumes linear boundaries between classes. If two fruit types (for example, apples and oranges) overlap in size or color score, it might misclassify them, which here led to 3 misclassifications (0.8 accuracy). Its Kappa of 0.715 shows it’s still decent but struggles with non-linear patterns, which KNN handles better.

  • With only 15 test fruits, small changes in predictions have a big impact. KNN getting 15/15 correct gives 1.0, while logistic regression getting 12/15 gives 0.8: a difference of just 3 fruits but a large gap in percentage. This means a single small test set cannot rule out overfitting, so the gap between the two models should be interpreted with caution.
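One way to quantify how little a small test set pins down is an exact binomial confidence interval for accuracy. The sketch below uses binom.test() from base R and assumes a test set of 15 fruits (the size implied by an accuracy of exactly 0.8), with 15 and 12 correct predictions respectively. Treating test fruits as independent trials is a simplifying assumption.

```r
# Exact (Clopper-Pearson) 95% intervals for accuracy on a 15-fruit test set
ci_perfect <- binom.test(x = 15, n = 15)$conf.int   # all 15 correct
ci_twelve  <- binom.test(x = 12, n = 15)$conf.int   # 12 of 15 correct

cat("15/15 correct:", round(ci_perfect, 3), "\n")
cat("12/15 correct:", round(ci_twelve, 3), "\n")
```

The two intervals overlap considerably, which is a reminder that the difference between the models on this split could shrink or grow with a different sample of test fruits.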


🥭 Classifying a Mystery Fruit

new_fruit <- data.frame(mass = 160, width = 7.2, height = 7.5, color_score = 0.76)
new_scaled <- predict(preproc, new_fruit)

mystery_pred <- knn(train_x, new_scaled, train_y, k = best_k, prob = TRUE)

cat("Prediction:", mystery_pred)
Prediction: 3
cat("Confidence:", attr(mystery_pred, "prob"))
Confidence: 0.8571429

Finally, we classify a mystery fruit with mass=160g, width=7.2, height=7.5, and color_score=0.76. We normalize it using the same preproc object to match our training data’s scale, then run knn with best_k=7.

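The prediction is 3, which corresponds to “orange” in this dataset. The fruit labels are not in alphabetical order (1 = apple, 2 = mandarin, 3 = orange, 4 = lemon; compare the class counts in the summary above). The confidence is approximately 0.86, which is the share of votes for the winning class: 6 of the 7 nearest neighbors (K=7) are oranges. In a real-world scenario, like a fruit sorting machine, we’d want to check that this vote share stays high, since with a small K one odd neighbor could throw the prediction off. This result shows KNN’s power for quick predictions while also showing why the choice of K matters.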
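The “prob” attribute returned by knn() is the proportion of neighbor votes for the winning class. The toy example below, with made-up one-dimensional points, is built so that six of the seven nearest neighbors share a class.

```r
library(class)

# Toy 1-D training set: six "A" points near the query, three "B" points
train_x <- matrix(c(1, 2, 3, 4, 5, 6, 7, 100, 101), ncol = 1)
train_y <- factor(c("A", "A", "A", "A", "A", "A", "B", "B", "B"))
query   <- matrix(3.5, ncol = 1)

pred <- knn(train_x, query, train_y, k = 7, prob = TRUE)

# The 7 nearest points are 1 through 7: six "A" votes and one "B" vote
vote_share <- attr(pred, "prob")
cat("Predicted class:", as.character(pred), "\n")
cat("Share of votes for that class:", round(vote_share, 4), "\n")
```

Six of seven votes gives 6/7 ≈ 0.857, the same value reported for the mystery fruit above.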


🆚 KNN vs. K-Means

Feature        KNN                               K-Means
Supervised?    Yes                               No
Used For       Classification / Regression       Clustering
Distance Use   Classify based on neighbors       Assign to closest centroid
Label Needed?  Yes                               No
Type           Lazy learner (no training phase)  Iterative optimization algorithm
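The difference in inputs is easy to see in code. The sketch below runs both methods on the built-in iris data. The 100-row training sample, K = 5, and three centers are arbitrary choices made for illustration.

```r
library(class)

x <- scale(iris[, 1:4])

# KNN is supervised: it needs the Species labels of the training rows
set.seed(1)
train_rows <- sample(nrow(x), 100)
knn_pred <- knn(train = x[train_rows, ], test = x[-train_rows, ],
                cl = iris$Species[train_rows], k = 5)

# K-means is unsupervised: it only sees the measurements and returns cluster ids
set.seed(1)
km <- kmeans(x, centers = 3, nstart = 10)

# KNN output is already expressed as species names
print(table(predicted = knn_pred, actual = iris$Species[-train_rows]))

# K-means clusters are just numbers; cross-tabulate to see how they line up
print(table(cluster = km$cluster, species = iris$Species))
```

KNN returns class names directly, while the K-means cluster numbers carry no meaning until they are compared with known labels.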

❗ Strengths and Limitations

Advantages

  • Easy to understand and implement
  • No training phase needed
  • Flexible for multi-class problems

Limitations

  • Expensive at prediction time
  • Sensitive to irrelevant features
  • Struggles in high-dimensional space


✅ Conclusion

In this module, we explored the K-Nearest Neighbors (KNN) algorithm from its theoretical foundation to practical implementation using a fruit classification task. Through both toy examples and real datasets, we gained a deeper understanding of how similarity-based classification works.

Key Takeaways:

  • KNN is a non-parametric, instance-based learning method that predicts new data points by comparing their distance to known labeled data.
  • Normalization of features was critical, especially since Euclidean distance is sensitive to scale.
  • Model performance varied with the value of K. K = 1 classified the test set perfectly, but a single-neighbor model can overfit noisier or more complex data, so we selected K = 7, which was equally accurate on this split.
  • Visualization using confusion matrices helped us interpret model results at a glance, especially when labeled with meaningful class names (like apple, orange, etc.).
  • We compared KNN to Logistic Regression, a probabilistic model, and found that while KNN achieved perfect accuracy in this specific task, logistic regression showed lower accuracy (0.8) and Kappa (0.715). This demonstrates the potential of KNN for small, structured, and well-separated datasets.

When to Use KNN:

KNN is particularly useful when:

  • You want a simple baseline model
  • You have a small to medium-sized dataset
  • Interpretability and flexibility are more important than speed

However, its limitations should be kept in mind:

  • Computational cost (especially with large datasets)
  • Sensitivity to irrelevant features or high dimensionality

Future Directions:

  • Try changing the distance metric (e.g., Manhattan, cosine) and evaluate impact.

    Distance Type       When to use
    Euclidean Distance  Continuous, real-valued features (e.g., height, weight, income)
    Manhattan Distance  Grid-like spaces (e.g., city blocks)
    Cosine Similarity   When you care about direction more than distance (e.g., text data)
    Hamming Distance    Categorical data (e.g., bit strings)
  • Apply dimensionality reduction techniques like PCA to see how performance and interpretability are affected.

  • Compare KNN with other classifiers such as Decision Trees, Random Forests, or SVMs for larger datasets or more complex feature spaces.
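As a starting point for the PCA suggestion, the sketch below compares leave-one-out KNN accuracy on the four standardized iris features with accuracy on the first two principal components. It is a minimal illustration under assumptions: iris stands in for the fruit data, and K = 5 and two components are arbitrary choices.

```r
library(class)

# PCA on standardized iris measurements
pca <- prcomp(iris[, 1:4], center = TRUE, scale. = TRUE)
pc_scores <- pca$x[, 1:2]          # keep the first two principal components

set.seed(42)  # knn.cv breaks voting ties at random
acc_full <- mean(knn.cv(scale(iris[, 1:4]), iris$Species, k = 5) == iris$Species)
acc_pca  <- mean(knn.cv(pc_scores, iris$Species, k = 5) == iris$Species)

cat("Leave-one-out accuracy, 4 scaled features:", round(acc_full, 3), "\n")
cat("Leave-one-out accuracy, first 2 PCs:      ", round(acc_pca, 3), "\n")
```

Comparing the two numbers shows how much predictive information survives the reduction to two dimensions, which also makes the neighbors easy to plot.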

In summary, KNN provides a solid foundation for understanding proximity-based machine learning and offers valuable intuition for feature-space dynamics in classification tasks.


References