Building fast k-means from scratch

Asad Ali
2 min read · Mar 8, 2022
Photo by Mel Poole on Unsplash

Clustering algorithms are a fundamental part of unsupervised learning, and centroid-based algorithms are very commonly used in industry. K-means is the most widely used centroid-based clustering algorithm. The main advantages of centroid-based algorithms are speed and efficiency, although they tend to be sensitive to outliers. Fundamentally, clustering is all about grouping similar entities together and finding patterns in the data variables that are relevant to our interests.

Fundamentally, any k-means algorithm does four things:

  1. Partition the dataset into k non-empty subsets
  2. Compute seed points or centroids
  3. Based on some distance metric, assign each data point to its nearest centroid
  4. Repeat 2 and 3 until convergence

Although several packages are available for computing k-means, in this article I will start from absolute scratch using plain Python and NumPy, with a simplified approach.

First, let us create a simple 2D array. Let us assume we are interested in 2 centroids.
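A minimal sketch of such a dataset (the blob locations and random seed are arbitrary choices of mine, purely for illustration):

```python
import numpy as np
import matplotlib.pyplot as plt

# Two loose Gaussian blobs that k-means should separate with k = 2.
rng = np.random.default_rng(42)
data = np.vstack([
    rng.normal(loc=[0.0, 0.0], scale=0.5, size=(50, 2)),
    rng.normal(loc=[3.0, 3.0], scale=0.5, size=(50, 2)),
])

plt.scatter(data[:, 0], data[:, 1], s=15)
plt.show()
```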

Visualization of clusters, random centroids

Initialize Centroids

First, we will shuffle the row indices with NumPy's permutation operation and pick two points as the initial centroids.
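A sketch of that initialization, assuming the `data` array from above (the helper name `initialize_centroids` is mine, for illustration):

```python
import numpy as np

def initialize_centroids(data, k=2):
    """Pick k distinct rows of the dataset as initial centroids."""
    # np.random.permutation shuffles the row indices; the first k
    # give us k random, non-repeating starting points.
    indices = np.random.permutation(data.shape[0])[:k]
    return data[indices]

centroids = initialize_centroids(data, k=2)
```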

Compute Distance

Now, coming to step 3 of the above approach, I will compute the distance from each centroid to each of the points and then store the index of the nearest centroid in a Python list or NumPy array. Possible distance metrics are Euclidean distance, Manhattan distance, or Chebyshev distance. The code below is a simple nested "for" loop that visits each index of the dataset, computes the distance metric to each of the possible centroids, and assigns the nearest one using "argmin".
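A sketch of that nested loop, using Euclidean distance (the function name `assign_clusters` is illustrative, not the author's):

```python
import numpy as np

def assign_clusters(data, centroids):
    """Assign each point the index of its nearest centroid."""
    labels = np.empty(data.shape[0], dtype=np.int64)
    for i in range(data.shape[0]):            # every point in the dataset
        distances = np.empty(centroids.shape[0])
        for j in range(centroids.shape[0]):   # every candidate centroid
            # Euclidean distance; swap in Manhattan or Chebyshev here if needed
            distances[j] = np.sqrt(np.sum((data[i] - centroids[j]) ** 2))
        labels[i] = np.argmin(distances)      # nearest centroid wins
    return labels

labels = assign_clusters(data, centroids)
```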

Compute Cluster Means

Now, to compute the new centroid of a cluster, we simply take the mean of the data points assigned to that centroid. Of course, we can replace the mean with another statistic, e.g. the mode or median, as the application requires.
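A sketch of the update step (again, the helper name is mine; I also leave a centroid in place if its cluster comes up empty, a detail the article does not specify):

```python
import numpy as np

def update_centroids(data, labels, old_centroids):
    """Move each centroid to the mean of the points assigned to it."""
    centroids = old_centroids.copy()
    for j in range(old_centroids.shape[0]):
        members = data[labels == j]
        if len(members) > 0:          # keep an empty cluster's centroid as-is
            centroids[j] = members.mean(axis=0)
    return centroids
```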

Putting it all together

This takes us to part 4, which is iterating over the dataset until convergence. Let us put everything into a Python class to complete this part reliably. To speed up the O(n·k) distance calculation, we can use the Numba JIT compiler, which is built to work with NumPy.
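For instance, a Numba-compiled version of the assignment step might look like the sketch below (this assumes Numba is installed; note it compares squared distances, which is enough for finding the minimum):

```python
import numpy as np
from numba import njit

@njit
def assign_clusters_fast(data, centroids):
    """JIT-compiled assignment step; same logic as the nested loop above."""
    n, k = data.shape[0], centroids.shape[0]
    labels = np.empty(n, dtype=np.int64)
    for i in range(n):
        best_j = 0
        best_d = np.inf
        for j in range(k):
            d = 0.0
            for dim in range(data.shape[1]):
                diff = data[i, dim] - centroids[j, dim]
                d += diff * diff      # squared distance suffices for argmin
            if d < best_d:
                best_d = d
                best_j = j
        labels[i] = best_j
    return labels
```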

For the class, we have a constructor that takes in the dataset and the number of clusters. The train method performs the training for a set maximum number of iterations, calling the compute functions from each step above in a callback-like scheme.
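A minimal sketch of such a class, reusing the helpers above; the structure and names (`train`, `max_iterations`) are my reading of the description, not the author's exact code:

```python
import numpy as np

class KMeans:
    """Minimal k-means tying the steps above together."""

    def __init__(self, data, k):
        self.data = data
        self.k = k
        # Step 2: random initial centroids (helper defined earlier).
        self.centroids = initialize_centroids(data, k)
        self.labels = None

    def train(self, max_iterations=100):
        for _ in range(max_iterations):
            # Step 3: assignment (the Numba-compiled version from above).
            self.labels = assign_clusters_fast(self.data, self.centroids)
            # Step 2 again: recompute centroids as cluster means.
            new_centroids = update_centroids(self.data, self.labels, self.centroids)
            # Step 4: stop once the centroids stop moving.
            if np.allclose(new_centroids, self.centroids):
                break
            self.centroids = new_centroids
        return self.labels

model = KMeans(data, k=2)
labels = model.train(max_iterations=100)
```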
