K-Nearest Neighbors (KNN)

Introduction

In this module, we will explore how the K-Nearest Neighbors (KNN) algorithm works and apply it to classify fruits using physical characteristics like mass, width, height, and color score.

KNN is a simple yet powerful algorithm for classification and regression that uses the proximity of data points to make predictions.


How KNN Works

📌 Core Steps

  1. Choose the number of neighbors K (usually an odd number like 3, 5, or 7).
  2. Calculate distance (e.g., Euclidean) from the new point to all existing data points.
  3. Identify the K closest neighbors.
  4. Predict by majority vote (for classification) or average value (for regression).

Example:
If you’re given data on the size and color of known apples, oranges, lemons, etc., and then introduced to a new “mystery fruit”, KNN can tell you what kind of fruit it most likely is — based on how similar it is to others.
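To make these steps concrete, here is a minimal from-scratch sketch in base R (for illustration only; the rest of this module uses knn() from the class package, and the names knn_predict, toy_x, toy_y are just for this sketch):

# Minimal KNN classifier: distance, nearest neighbors, majority vote
knn_predict <- function(train, labels, new_point, k = 3) {
  dists <- sqrt(rowSums(sweep(train, 2, new_point)^2))  # Euclidean distance to every training point
  nearest <- order(dists)[1:k]                          # indices of the k closest points
  names(which.max(table(labels[nearest])))              # majority vote among their labels
}

# Tiny illustration with two made-up classes
toy_x <- matrix(c(1, 2, 2, 3, 3, 1, 6, 5, 7, 7), ncol = 2, byrow = TRUE)
toy_y <- c("A", "A", "A", "B", "B")
knn_predict(toy_x, toy_y, new_point = c(3, 2), k = 3)   # "A"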


Euclidean Distance

KNN relies heavily on distance. The Euclidean distance is most common:

For two points A = (x₁, y₁) and B = (x₂, y₂), \(d = \sqrt{(x_2 - x_1)^2 + (y_2 - y_1)^2}\)

In multi-feature datasets, each feature is a dimension in the space.
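More generally, for two points \(A = (a_1, \dots, a_n)\) and \(B = (b_1, \dots, b_n)\) with n features, \(d(A, B) = \sqrt{\sum_{i=1}^{n} (a_i - b_i)^2}\). In R this is simply sqrt(sum((a - b)^2)) (or dist(rbind(a, b))); the feature values below are made up for illustration:

# Euclidean distance between two fruits described by four features
a <- c(150, 7.0, 7.5, 0.72)   # mass, width, height, color_score
b <- c(160, 7.2, 7.8, 0.75)
sqrt(sum((a - b)^2))

Note how the mass difference dominates this distance, which is why the features are normalized later in this module.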


KNN in Action: Example Codes

# Install package once (if needed)
install.packages("class")

library(class)

# Training data (2D)

X_train <- matrix(c(1, 2, 2, 3, 3, 1, 6, 5, 7, 7), ncol = 2, byrow = TRUE)

y_train <- factor(c("A", "A", "A", "B", "B"))

# New point to classify

X_test <- matrix(c(3, 2), ncol = 2)

# Run KNN (k = 3)

predicted_label <- knn(train = X_train, test = X_test, cl = y_train, k = 3)

paste("Predicted label:", predicted_label)

🥝 Real Dataset: Fruit Sorting with KNN

Dataset Loading and Setup

library(readr)
library(dplyr)

library(ggplot2)
library(class)
library(caret)
fruit_data <- read.delim("fruit_data_with_colors.txt", sep="\t", header=TRUE)

fruit_data$fruit_label <- as.factor(fruit_data$fruit_label)
fruit_data$fruit_name <- as.factor(fruit_data$fruit_name)

Data Exploration

summary(fruit_data)
 fruit_label    fruit_name fruit_subtype           mass           width      
 1:19        apple   :19   Length:59          Min.   : 76.0   Min.   :5.800  
 2: 5        lemon   :16   Class :character   1st Qu.:140.0   1st Qu.:6.600  
 3:19        mandarin: 5   Mode  :character   Median :158.0   Median :7.200  
 4:16        orange  :19                      Mean   :163.1   Mean   :7.105  
                                              3rd Qu.:177.0   3rd Qu.:7.500  
                                              Max.   :362.0   Max.   :9.600  
     height        color_score    
 Min.   : 4.000   Min.   :0.5500  
 1st Qu.: 7.200   1st Qu.:0.7200  
 Median : 7.600   Median :0.7500  
 Mean   : 7.693   Mean   :0.7629  
 3rd Qu.: 8.200   3rd Qu.:0.8100  
 Max.   :10.500   Max.   :0.9300  
table(fruit_data$fruit_name)

   apple    lemon mandarin   orange 
      19       16        5       19 
ggplot(fruit_data, aes(x = width, y = height, color = fruit_name)) +
  geom_point(size = 3, alpha = 0.7) +
  labs(title = "Fruit Types by Size") +
  theme_minimal()

Feature Normalization and Splitting

features <- fruit_data[, c("mass", "width", "height", "color_score")]

preproc <- preProcess(features, method = c("center", "scale"))
features_scaled <- predict(preproc, features)

set.seed(123)
train_idx <- createDataPartition(fruit_data$fruit_label, p = 0.7, list = FALSE)
train_x <- features_scaled[train_idx, ]
test_x <- features_scaled[-train_idx, ]
train_y <- fruit_data$fruit_label[train_idx]
test_y <- fruit_data$fruit_label[-train_idx]

Model Evaluation: Choosing Best K

k_values <- seq(1, 15, 2)
accuracy_list <- sapply(k_values, function(k) {
  pred <- knn(train_x, test_x, train_y, k)
  mean(pred == test_y)
})

plot(k_values, accuracy_list, type = "b", col = "blue", pch = 19,
     xlab = "K", ylab = "Accuracy", main = "KNN Accuracy by K")

(best_k <- k_values[which.max(accuracy_list)])
[1] 1
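Choosing K by test-set accuracy can be optimistic, since the test set is effectively reused for tuning. A common alternative is to tune K by cross-validation on the training data; the sketch below uses caret::train, with illustrative fold/repeat counts and the names cv_df and knn_cv used only here:

# Optional: tune K by repeated cross-validation on the training data
cv_df <- data.frame(train_x, label = train_y)
ctrl <- trainControl(method = "repeatedcv", number = 5, repeats = 3)
knn_cv <- train(label ~ ., data = cv_df, method = "knn",
                trControl = ctrl, tuneGrid = data.frame(k = seq(1, 15, 2)))
knn_cv$bestTune   # cross-validated choice of K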

Final Model and Confusion Matrix

final_pred <- knn(train_x, test_x, train_y, best_k)

cm <- confusionMatrix(final_pred, test_y)

# Precision, Recall, F1-score by class
metrics <- data.frame(
  Class = levels(test_y),
  Precision = cm$byClass[, "Pos Pred Value"],
  Recall = cm$byClass[, "Sensitivity"],
  F1 = 2 * ((cm$byClass[, "Pos Pred Value"] * cm$byClass[, "Sensitivity"]) /
            (cm$byClass[, "Pos Pred Value"] + cm$byClass[, "Sensitivity"]))
)

metrics
         Class Precision Recall F1
Class: 1     1         1      1  1
Class: 2     2         1      1  1
Class: 3     3         1      1  1
Class: 4     4         1      1  1
# label map: derive the label -> name pairing from the data itself
# (the numeric labels do not follow the alphabetical order of the fruit_name levels)
label_pairs <- unique(fruit_data[, c("fruit_label", "fruit_name")])
label_map <- setNames(as.character(label_pairs$fruit_name),
                      as.character(label_pairs$fruit_label))

# convert the confusion matrix into a data frame for the heatmap
cm_df <- as.data.frame(cm$table)
cm_df$Reference <- factor(cm_df$Reference, levels = names(label_map), labels = label_map)
cm_df$Prediction <- factor(cm_df$Prediction, levels = names(label_map), labels = label_map)

# heatmap visualization
ggplot(cm_df, aes(x = Reference, y = Prediction, fill = Freq)) +
  geom_tile(color = "white") +
  geom_text(aes(label = Freq), color = "black", size = 6) +
  scale_fill_gradient(low = "white", high = "steelblue") +
  labs(title = "KNN Confusion Matrix (Fruit Names)", x = "True Label", y = "Predicted Label") +
  theme_minimal()

🔀 Model Comparison: KNN vs Logistic Regression

library(nnet)   # for multinom()
library(e1071)  # suggested by caret for some computations

# build train/test data frames from the same scaled features
train_df <- data.frame(train_x)
train_df$label <- train_y
test_df <- data.frame(test_x)
test_df$label <- test_y

# Multinomial logistic regression
logit_model <- multinom(label ~ ., data = train_df)
# weights:  24 (15 variable)
initial  value 60.996952 
iter  10 value 13.831555
iter  20 value 12.482299
iter  30 value 12.435228
final  value 12.434969 
converged
# prediction
logit_pred <- predict(logit_model, newdata = test_df)

cm_logit <- confusionMatrix(logit_pred, test_y)

# comparison
cm_logit_df <- as.data.frame(cm_logit$table)
cm_logit_df$Reference <- factor(cm_logit_df$Reference, levels = names(label_map), labels = label_map)
cm_logit_df$Prediction <- factor(cm_logit_df$Prediction, levels = names(label_map), labels = label_map)

ggplot(cm_logit_df, aes(x = Reference, y = Prediction, fill = Freq)) +
  geom_tile(color = "white") +
  geom_text(aes(label = Freq), color = "black", size = 6) +
  scale_fill_gradient(low = "white", high = "tomato") +
  labs(title = "Logistic Regression Confusion Matrix (Fruit Names)", x = "True Label", y = "Predicted Label") +
  theme_minimal()

Model Comparison Table

comparison_df <- data.frame(
  Model = c("KNN", "Logistic Regression"),
  Accuracy = c(cm$overall["Accuracy"], cm_logit$overall["Accuracy"]),
  Kappa = c(cm$overall["Kappa"], cm_logit$overall["Kappa"])
)

comparison_df
                Model Accuracy     Kappa
1                 KNN      1.0 1.0000000
2 Logistic Regression      0.8 0.7151899

🥭 Classifying a Mystery Fruit

new_fruit <- data.frame(mass = 160, width = 7.2, height = 7.5, color_score = 0.76)
new_scaled <- predict(preproc, new_fruit)

mystery_pred <- knn(train_x, new_scaled, train_y, k = best_k, prob = TRUE)

cat("Prediction:", mystery_pred)
Prediction: 3
cat("Confidence:", attr(mystery_pred, "prob"))
Confidence: 1
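The prediction comes back as a numeric label; the label_map built earlier translates it into a fruit name:

# map the numeric label back to a fruit name
cat("Predicted fruit:", label_map[as.character(mystery_pred)])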

KNN vs. K-Means

| Feature        | KNN                              | K-Means                          |
|----------------|----------------------------------|----------------------------------|
| Supervised?    | Yes                              | No                               |
| Used For       | Classification / Regression      | Clustering                       |
| Distance Use   | Classify based on neighbors      | Assign to closest centroid       |
| Label Needed?  | Yes                              | No                               |
| Type           | Lazy learner (no training phase) | Iterative optimization algorithm |
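To make the contrast concrete, the same scaled fruit features can be clustered with k-means, which never sees the labels (a quick sketch; centers = 4 simply mirrors the number of fruit types, and km is just a name used here):

# Unsupervised counterpart: cluster the scaled features without labels
set.seed(123)
km <- kmeans(features_scaled, centers = 4, nstart = 25)
table(Cluster = km$cluster, Fruit = fruit_data$fruit_name)   # compare clusters to the true fruit names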

Strengths and Limitations

Advantages:

  • Easy to understand and implement
  • No training phase needed
  • Flexible for multi-class problems

Limitations:

  • Expensive at prediction time
  • Sensitive to irrelevant features
  • Struggles in high-dimensional spaces


Conclusion

In this module, we explored the K-Nearest Neighbors (KNN) algorithm from its theoretical foundation to practical implementation using a fruit classification task. Through both toy examples and real datasets, we gained a deeper understanding of how similarity-based classification works.

Key Takeaways:

  • KNN is a non-parametric, instance-based learning method that predicts new data points by comparing their distance to known labeled data.
  • Normalization of features was critical, especially since Euclidean distance is sensitive to scale.
  • Model performance varied with the value of K, with K = 1 showing perfect classification for this particular dataset. However, this could lead to overfitting in noisier or more complex data.
  • Visualization using confusion matrices helped us interpret model results at a glance, especially when labeled with meaningful class names (like apple, orange, etc.).
  • We compared KNN to Logistic Regression, a probabilistic model, and found that while KNN achieved perfect accuracy in this specific task, logistic regression showed slightly lower accuracy and Kappa. This demonstrates the potential of KNN for small, structured, and well-separated datasets.

When to Use KNN:

KNN is particularly useful when:

  • You want a simple baseline model
  • You have a small to medium-sized dataset
  • Interpretability and flexibility are more important than speed

However, keep its limitations in mind:

  • Computational cost (especially with large datasets)
  • Sensitivity to irrelevant features or high dimensionality

Future Directions:

  • Try changing the distance metric (e.g., Manhattan, cosine) and evaluate impact.
  • Apply dimensionality reduction techniques like PCA to see how performance and interpretability are affected (see the sketch after this list).
  • Compare KNN with other classifiers such as Decision Trees, Random Forests, or SVMs for larger datasets or more complex feature spaces.
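As a starting point for the PCA direction above, here is a brief sketch that projects the scaled features onto their first two principal components and reruns KNN on the same split (an illustration, not a tuned pipeline; pca, pc_train, pc_test, and pca_pred are names used only here):

# Sketch: PCA on the scaled features, then KNN on the first two components
pca <- prcomp(features_scaled)
pc_train <- pca$x[train_idx, 1:2]
pc_test  <- pca$x[-train_idx, 1:2]
pca_pred <- knn(pc_train, pc_test, train_y, k = best_k)
mean(pca_pred == test_y)   # accuracy on the same test split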

In summary, KNN provides a solid foundation for understanding proximity-based machine learning and offers valuable intuition for feature-space dynamics in classification tasks.


References