Getting started with K-means clustering in Python

Dr J Rogel-Salazar | 2023-07-21 | 13 min read

Imagine you are an accomplished marketeer establishing a new campaign for a product and want to find appropriate segments to target, or you are a lawyer interested in grouping together different documents depending on their content, or you are analysing credit card transactions to identify similar patterns. In all those cases, and many more, data science can help you cluster your data. Clustering analysis is an important area of unsupervised learning that helps us group data together. We have discussed in this blog the difference between supervised and unsupervised learning in the past. As a reminder, we use unsupervised learning when labelled data is not available for our purposes but we want to explore common features in the data. In the examples above, as a marketeer we may find common demographic characteristics in our target audience, as a lawyer we may establish common themes in the documents in question, or as a fraud analyst we may establish common transactions that help highlight outliers in someone’s account.

In all those cases, clustering offers a hand at finding those common traits, and there are a variety of clustering algorithms out there. In a previous post, we talked about density-based clustering, where we discussed its use in anomaly detection, similar to the credit card transactions use case above. In that post we argued that other algorithms may be easier to understand and implement, for example k-means, and the aim of this post is to do exactly that.

We will first establish the notion of a cluster and determine an important part in the implementation of k-means: centroids. We will see how k-means approaches the issue of similarity and how the groups are updated on every iteration until a stopping condition is met. We will illustrate this with a Python implementation and will finish by looking at how to use this algorithm via the Scikit-learn library.

You can use the code in this post on your own machine, provided you have Python 3.x installed. Alternatively, you can just sign up for free trial access to the Domino MLOps platform and access the code used in this article directly (use the Access Project button at the end of the article). Let's get started.

K-Means - what does it mean?

We mentioned that we are interested in finding out commonalities among our data observations. One way to determine that commonality or similarity is through a measure of distance among the data points. The shorter the distance, the more similar the observations are. There are different ways in which we can measure that distance and one that is very familiar to a lot of people is the Euclidean distance. That’s right! The same one we are taught when learning the Pythagorean theorem. Let us take a look and consider two data observations over two attributes \(a\) and \(b\). Point \(p_1\) has coordinates \((a_1,b_1 )\) and point \(p_2=(a_2,b_2 )\).

K Means Distance

The distance \(\overline{p_1 p_2 }\) is given by:

$$\overline{p_1 p_2 }=\sqrt{(a_2-a_1 )^2+(b_2-b_1 )^2 }$$


The expression above can be extended to more than 2 attributes and the distance can be measured between any two points. For a dataset with \(n\) observations, we assume there are \(k\) groups or clusters, and our aim is to determine which of those \(k\) groups each observation belongs to. This is an important point to emphasise: the algorithm will not give us the number of clusters, instead we need to define the number \(k\) in advance. We may be able to run the algorithm with different values for \(k\) and determine the best possible solution.
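As a minimal sketch of how the formula carries over to more than two attributes, the distance can be computed with NumPy for points with any number of coordinates. The function name `euclidean_distance` and the sample points are illustrative assumptions, not part of the code that accompanies this post.

```python
import numpy as np

def euclidean_distance(p, q):
    """Euclidean distance between two points with the same number of attributes."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    return np.sqrt(np.sum((q - p) ** 2))

# Two attributes: the classic 3-4-5 right triangle
d2 = euclidean_distance([0, 0], [3, 4])

# Three attributes: the same function works unchanged
d3 = euclidean_distance([1, 2, 3], [3, 4, 4])

print(d2, d3)
```

The only requirement is that both points have the same number of attributes.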

In a nutshell, \(k\)-means clustering tries to minimise the distances between the observations that belong to a cluster and maximise the distance between the different clusters. In that way, we have cohesion between the observations that belong to a group, while observations that belong to a different group are kept further apart. Please note that as we explained in this post, \(k\)-means is exhaustive in the sense that every single observation in the dataset will be forced to be part of one of the \(k\) clusters assumed.
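One common way to write this goal down, given here as the standard textbook formulation rather than something taken from the code in this post, is the within-cluster sum of squares. The symbols \(C_j\) (the set of observations assigned to cluster \(j\)) and \(\mu_j\) (the centre of that cluster) are notation introduced for this sketch:

$$J=\sum_{j=1}^{k}\sum_{x_i \in C_j}\lVert x_i-\mu_j \rVert^2$$

Because the total spread of the data around its overall mean is fixed, making \(J\) small goes hand in hand with pushing the cluster centres apart.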

It should now be clear where the \(k\) in \(k\)-means comes from, but what about the “means” part? Well, it turns out that as part of the algorithm we are also looking to identify the centre for each cluster. We call this a centroid, and as we assign observations to one cluster or the other, we update the position of the cluster centroid. This is done by taking the mean (average if you will) of all the data points that have been included in that cluster. Easy!
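A tiny illustration of the "means" part, using three made-up observations that we assume belong to the same cluster:

```python
import numpy as np

# Three observations assigned to the same (hypothetical) cluster
cluster_points = np.array([[1.0, 2.0],
                           [3.0, 4.0],
                           [5.0, 0.0]])

# The centroid is the attribute-wise mean of the observations
centroid = cluster_points.mean(axis=0)
print(centroid)  # [3. 2.]
```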

A recipe for k-means

The recipe for \(k\)-means is quite straightforward.

  1. Decide how many clusters you want, i.e. choose k
  2. Randomly assign a centroid to each of the k clusters
  3. Calculate the distance of all observations to each of the k centroids
  4. Assign observations to the closest centroid
  5. Find the new location of the centroid by taking the mean of all the observations in each cluster
  6. Repeat steps 3-5 until the centroids do not change position

Et voilà!

Take a look at the representation below where the steps are depicted in a schematic way for a 2-dimensional space. The same steps can be applied to more dimensions (i.e. more features or attributes). For simplicity, in the schematic we only show the distance measured to the closest centroid, but in practice all distances need to be considered.
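The six steps of the recipe can also be sketched compactly with NumPy. This is only an illustrative sketch under a few assumptions: the function name `kmeans_sketch`, the toy data and the rule of leaving a centroid in place when no observation is assigned to it are choices made for this example, and they are not the implementation developed step by step later in the post.

```python
import numpy as np

def kmeans_sketch(X, k, max_iter=100, seed=0):
    """Illustrative k-means loop following steps 1-6 (not optimised)."""
    rng = np.random.default_rng(seed)
    # Step 2: pick k distinct observations as the initial centroids
    centroids = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(max_iter):
        # Step 3: distance from every observation to every centroid
        dists = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
        # Step 4: assign each observation to its closest centroid
        labels = dists.argmin(axis=1)
        # Step 5: move each centroid to the mean of its observations
        # (a centroid with no observations is left where it is)
        new_centroids = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
            for j in range(k)
        ])
        # Step 6: stop once the centroids no longer move
        if np.allclose(new_centroids, centroids):
            break
        centroids = new_centroids
    return labels, centroids

# Toy data: two tight groups around (0, 0) and (10, 10)
X = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0],
              [10.0, 10.0], [10.0, 11.0], [11.0, 10.0]])
labels, centroids = kmeans_sketch(X, k=2)  # Step 1: we choose k=2
print(labels)
print(centroids)
```

Which of the two groups ends up numbered 0 depends on the random initial pick, so only the grouping itself is meaningful.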

Explaining K Means Clustering

Getting Started

For the purposes of our implementation, we will take a look at some data with 2 attributes for simplicity. We will then look at an example with more dimensions. To get us started we will use a dataset we have prepared, available here with the name kmeans_blobs.csv. The dataset contains 4 columns with the following information:

  1. ID: A unique identifier for the observation
  2. x: Attribute corresponding to an x coordinate
  3. y: Attribute corresponding to a y coordinate
  4. Cluster: An identifier for the cluster the observation belongs to

We will discard column 4 for our analysis, but it may be useful to check the results of the application of \(k\)-means. We will do this in our second example later on. Let us start by reading the dataset:

import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
%matplotlib inline

blobs = pd.read_csv('kmeans_blobs.csv')
colnames = list(blobs.columns[1:-1])
blobs.head()
First 5 rows of the K Means dataset

Let us look at the observations in the dataset. We will use the Cluster column to show the different groups that are present in the dataset. Our aim will be to see if the application of the algorithm reproduces closely the groupings.

customcmap = ListedColormap(["crimson", "mediumblue", "darkmagenta"])

fig, ax = plt.subplots(figsize=(8, 6))
plt.scatter(x=blobs['x'], y=blobs['y'], s=150,
            c=blobs['cluster'].astype('category'), 
            cmap = customcmap)
ax.set_xlabel(r'x', fontsize=14)
ax.set_ylabel(r'y', fontsize=14)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.show()
Scatter plot visualization of k means clustering dataset

Let's now look at our recipe.

Steps 1 and 2 - Define \(k\) and initiate the centroids

First we need 1) to decide how many groups we have and 2) assign the initial centroids randomly. In this case let us consider \(k=3\), and as for the centroids, well, they have to be in the same range as the dataset itself. So one option is to randomly pick \(k\) observations and use their coordinates to initialise the centroids:

def initiate_centroids(k, dset):
    '''
    Select k data points as centroids
    k: number of centroids
    dset: pandas dataframe
    '''
    centroids = dset.sample(k)
    return centroids

np.random.seed(42)
k=3
# Work on a copy so that adding columns later does not modify a slice of blobs
df = blobs[['x','y']].copy()
centroids = initiate_centroids(k, df)
centroids
K means centroids

Step 3 - Calculate distance

We now need to calculate the distance between each of the centroids and the data points. We will assign the data point to the centroid that gives us the minimum error. Let us create a function to calculate the root of the sum of squared errors, which is the Euclidean distance we introduced earlier:

def rsserr(a,b):
    '''
    Calculate the root of sum of squared errors. 
    a and b are array-likes (numpy arrays or pandas Series)
    '''
    # np.sqrt (not np.square) gives the root of the summed squares
    return np.sqrt(np.sum((a-b)**2)) 

Let us pick a data point and calculate the error so we can see how this works in practice. We will use the observation at position 36 of df, which is in fact one of the centroids we picked above. As such, we expect that the error for that point and the third centroid is zero. We would therefore assign that data point to the third centroid (index 2). Let’s take a look:

for i in range(centroids.shape[0]):
    err = rsserr(centroids.iloc[i,:], df.iloc[36,:])
    print('Error for centroid {0}: {1:.2f}'.format(i, err))
Error for centroid 0: 4.43
Error for centroid 1: 5.19
Error for centroid 2: 0.00

Step 4 - Assign centroids

We can use the idea from Step 3 to create a function that helps us assign the data points to corresponding centroids. We will calculate all the errors associated to each centroid, and then pick the one with the lowest value for assignation:

def centroid_assignation(dset, centroids):
    '''
    Given a dataframe `dset` and a set of `centroids`, we assign each
    data point in `dset` to a centroid. 
    - dset - pandas dataframe with observations
    - centroids - pandas dataframe with centroids
    '''
    k = centroids.shape[0]
    n = dset.shape[0]
    assignation = []
    assign_errors = []

    for obs in range(n):
        # Estimate error
        all_errors = np.array([])
        for centroid in range(k):
            err = rsserr(centroids.iloc[centroid, :], dset.iloc[obs,:])
            all_errors = np.append(all_errors, err)

        # Get the nearest centroid and the error
        nearest_centroid = int(np.argmin(all_errors))
        nearest_centroid_error = np.amin(all_errors)

        # Add values to corresponding lists
        assignation.append(nearest_centroid)
        assign_errors.append(nearest_centroid_error)

    return assignation, assign_errors
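
The double loop above is easy to follow but slow for larger datasets. As a hedged alternative, the same assignment can be sketched with NumPy broadcasting. The function name `assign_vectorised` and the toy data are assumptions made for this example, and the error returned here is the Euclidean distance to the nearest centroid.

```python
import numpy as np
import pandas as pd

def assign_vectorised(dset, centroids):
    """Return (labels, errors): nearest centroid and Euclidean distance to it."""
    points = dset.to_numpy(dtype=float)          # shape (n, d)
    centres = centroids.to_numpy(dtype=float)    # shape (k, d)
    # Pairwise distances via broadcasting: shape (n, k)
    dists = np.linalg.norm(points[:, None, :] - centres[None, :, :], axis=2)
    labels = dists.argmin(axis=1)
    errors = dists.min(axis=1)
    return labels.tolist(), errors.tolist()

# Small made-up example with two centroids
toy = pd.DataFrame({'x': [0.0, 1.0, 9.0], 'y': [0.0, 1.0, 9.0]})
toy_centroids = pd.DataFrame({'x': [0.0, 10.0], 'y': [0.0, 10.0]})
toy_labels, toy_errors = assign_vectorised(toy, toy_centroids)
print(toy_labels)
print(toy_errors)
```

Both data frames are assumed to hold the same feature columns in the same order.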

Let us add some columns to our data containing the centroid assignations and the error incurred. Furthermore, we can use this to update our scatter plot showing the centroids (denoted with squares) and we colour the observations according to the centroid they have been assigned to:

df['centroid'], df['error'] = centroid_assignation(df, centroids)
df.head()
K means centroids assigned in the dataframe
fig, ax = plt.subplots(figsize=(8, 6))
plt.scatter(df.iloc[:,0], df.iloc[:,1],  marker = 'o', 
            c=df['centroid'].astype('category'), 
            cmap = customcmap, s=80, alpha=0.5)
plt.scatter(centroids.iloc[:,0], centroids.iloc[:,1],  
            marker = 's', s=200, c=[0, 1, 2], 
            cmap = customcmap)
ax.set_xlabel(r'x', fontsize=14)
ax.set_ylabel(r'y', fontsize=14)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.show()
K means cluster scatter plot visualization with centroids

Let us see the total error by adding all the contributions. We will take a look at this error as a measure of convergence. In other words, if the error does not change, we can assume that the centroids have stabilised their location and we can terminate our iterations. In practice, we need to be mindful of having found a local minimum (outside the scope of this post). The exact figure printed depends on the initial centroids and on how the error is defined, so a different run may show a different value.

print("The total error is {0:.2f}".format(df['error'].sum()))
The total error is 11927659.01
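
To get a feel for the local-minimum caveat, one option is to run the algorithm several times from different random starting points and compare the final errors. The sketch below does this with Scikit-learn's KMeans, which is introduced later in this post. The synthetic data, its centres and the number of seeds are assumptions chosen purely for illustration.

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

# Synthetic data; the centres below are assumptions chosen for illustration
X_demo, _ = make_blobs(n_samples=200,
                       centers=[(-6, -6), (-6, 6), (6, -6), (6, 6)],
                       cluster_std=1.0, random_state=1)

# One run per seed, each from a single random initialisation
final_errors = {}
for seed in range(5):
    km = KMeans(n_clusters=4, init='random', n_init=1, random_state=seed).fit(X_demo)
    final_errors[seed] = km.inertia_

best_seed = min(final_errors, key=final_errors.get)
print(final_errors)
print("Lowest error from seed", best_seed)
```

If the runs disagree, keeping the one with the lowest error is a simple safeguard.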

Step 5 - Update centroid location

Now that we have a first attempt at defining our clusters, we need to update the position of the k centroids. We do this by calculating the mean of the position of the observations assigned to each centroid. Let us take a look:

centroids = df.groupby('centroid').agg('mean').loc[:, colnames].reset_index(drop = True)
centroids
K means centroid locations
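
To see what the group-by step does in isolation, here is a small sketch on a made-up data frame where the centroid labels are already attached. The values are invented for illustration only.

```python
import pandas as pd

# Made-up observations with a centroid label already attached
toy = pd.DataFrame({
    'x': [0.0, 2.0, 10.0, 12.0],
    'y': [1.0, 3.0, 5.0, 7.0],
    'centroid': [0, 0, 1, 1],
})

# One row per centroid, holding the mean of each attribute
new_centroids = (toy.groupby('centroid')[['x', 'y']]
                    .mean()
                    .reset_index(drop=True))
print(new_centroids)
```

Each row of the result is the new position of one centroid.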

We can verify that the position has been updated. Let us look again at our scatter plot:

fig, ax = plt.subplots(figsize=(8, 6))
plt.scatter(df.iloc[:,0], df.iloc[:,1],  marker = 'o', 
            c=df['centroid'].astype('category'), 
            cmap = customcmap, s=80, alpha=0.5)
plt.scatter(centroids.iloc[:,0], centroids.iloc[:,1],  
            marker = 's', s=200,
            c=[0, 1, 2], cmap = customcmap)
ax.set_xlabel(r'x', fontsize=14)
ax.set_ylabel(r'y', fontsize=14)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.show()
Updated k means cluster scatter plot visualization with centroids

Step 6 - Repeat steps 3-5

Now we go back to calculate the distance to each centroid, assign observations and update the centroid location. This calls for a function to encapsulate the loop:

def kmeans(dset, k=2, tol=1e-4):
    '''
    K-means implementation for a pandas DataFrame
    `dset`:  DataFrame with observations
    `k`: number of clusters, default k=2
    `tol`: tolerance, default 1E-4
    '''
    # Let us work in a copy, so we don't mess the original
    working_dset = dset.copy()
    # Remember the feature columns so that the 'centroid' column
    # added below never takes part in the distance calculation
    features = list(dset.columns)
    # We define some variables to hold the error, the 
    # stopping signal and a counter for the iterations
    err = []
    goahead = True
    j = 0
    
    # Step 2: Initiate clusters by defining centroids 
    centroids = initiate_centroids(k, dset)

    while(goahead):
        # Step 3 and 4 - Assign centroids and calculate error
        working_dset['centroid'], j_err = centroid_assignation(working_dset[features], centroids) 
        err.append(sum(j_err))
        
        # Step 5 - Update centroid position
        centroids = working_dset.groupby('centroid')[features].agg('mean').reset_index(drop = True)

        # Step 6 - Restart the iteration
        if j>0:
            # Stop when the error improves by no more than the tolerance
            if err[j-1]-err[j]<=tol:
                goahead = False
        j+=1

    working_dset['centroid'], j_err = centroid_assignation(working_dset[features], centroids)
    centroids = working_dset.groupby('centroid')[features].agg('mean').reset_index(drop = True)
    return working_dset['centroid'], j_err, centroids

OK, we are now ready to apply our function. We will clean our dataset first and let the algorithm run:

np.random.seed(42)
df['centroid'], df['error'], centroids =  kmeans(df[['x','y']], 3)
df.head()

Let us see the location of the final centroids:

centroids
Centroid locations in k means clustering dataframe

And in a graphical way, let us see our clusters:

fig, ax = plt.subplots(figsize=(8, 6))
plt.scatter(df.iloc[:,0], df.iloc[:,1],  marker = 'o', 
            c=df['centroid'].astype('category'), 
            cmap = customcmap, s=80, alpha=0.5)
plt.scatter(centroids.iloc[:,0], centroids.iloc[:,1],  
            marker = 's', s=200, c=[0, 1, 2], 
            cmap = customcmap)
ax.set_xlabel(r'x', fontsize=14)
ax.set_ylabel(r'y', fontsize=14)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.show()
K means cluster with updated centroids

As we can see, the three groups have been obtained. In this particular example the data is such that the distinction between the groups is clear. However, we may not be as lucky in every case, so the question of how many groups there are still remains. We can use a scree plot to help us: we run the algorithm for a sequence \(k=1,2,3,4,...\), plot the total error for each value, and look for the “elbow” in the plot indicating a good number of clusters to use:

err_total = []
n = 10

df_elbow = blobs[['x','y']]

for i in range(n):
    _, my_errs, _ = kmeans(df_elbow, i+1)
    err_total.append(sum(my_errs))
fig, ax = plt.subplots(figsize=(8, 6))
plt.plot(range(1,n+1), err_total, linewidth=3, marker='o')
ax.set_xlabel(r'Number of clusters', fontsize=14)
ax.set_ylabel(r'Total error', fontsize=14)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.show()
Elbow rule for k means clustering

We can now apply the “elbow rule”, which is a heuristic to help us determine the number of clusters. If we think of the line shown above as depicting an arm, then the “elbow” is the point where the curve bends most sharply. In this case the “elbow” is located between 2 and 4 clusters, giving us an indication that choosing 3 is a good fit.
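
The same elbow heuristic can be sketched with Scikit-learn's KMeans (covered in the next section), whose `inertia_` attribute holds the sum of squared distances of the observations to their closest centre. The synthetic data and its three centres are assumptions made for this example and are not the blobs dataset used above.

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

# Synthetic data with three well-separated groups (an assumption for this sketch)
X_demo, _ = make_blobs(n_samples=150,
                       centers=[(-5, -5), (0, 5), (5, -5)],
                       cluster_std=0.8, random_state=0)

inertias = []
for n_clusters in range(1, 7):
    km = KMeans(n_clusters=n_clusters, n_init=10, random_state=0).fit(X_demo)
    # inertia_ is the sum of squared distances to the closest centre
    inertias.append(km.inertia_)

for n_clusters, value in zip(range(1, 7), inertias):
    print(n_clusters, round(value, 1))
```

With data like this we would expect a steep drop up to three clusters and only small gains afterwards.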

Using Scikit-learn

We have seen how to make an initial implementation of the algorithm, but in many cases you may want to stand on the shoulders of giants and use other tried and tested modules to help you with your machine learning work. In this case, Scikit-learn is a good choice and it has a very nice implementation for \(k\)-means. If you want to know more about the algorithm and its evaluation you can take a look at Chapter 5 of Data Science and Analytics with Python where I use a wine dataset for the discussion.

In this case we will show how k-means can be implemented in a couple of lines of code using the well-known Iris dataset. We can load it directly from Scikit-learn and we will shuffle the data to ensure the points are not listed in any particular order.

from sklearn.cluster import KMeans
from sklearn import datasets
from sklearn.utils import shuffle

# import some data to play with
iris = datasets.load_iris()
X = iris.data 
y = iris.target
names = iris.feature_names
X, y = shuffle(X, y, random_state=42)

We can call the KMeans implementation to instantiate a model and fit it. The parameter n_clusters is the number of clusters \(k\). In the example below we request 3 clusters:

model = KMeans(n_clusters=3, random_state=42) 
iris_kmeans = model.fit(X)

That is it! We can look at the labels that the algorithm has provided as follows:

iris_kmeans.labels_
array([1, 0, 2, 1, 1, 0, 1, 2, 1, 1, 2, 0, 0, 0, 0, 1, 2, 1, 1, 2, 0, 1,
       0, 2, 2, 2, 2, 2, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0,
       0, 1, 1, 2, 1, 2, 1, 2, 1, 0, 2, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0,
       1, 2, 0, 1, 1, 0, 1, 1, 1, 1, 2, 1, 0, 1, 2, 0, 0, 1, 2, 0, 1, 0,
       0, 1, 1, 2, 1, 2, 2, 1, 0, 0, 1, 2, 0, 0, 0, 1, 2, 0, 2, 2, 0, 1,
       1, 1, 1, 2, 0, 2, 1, 2, 1, 1, 1, 0, 1, 1, 0, 1, 2, 2, 0, 1, 2, 2,
       0, 2, 0, 2, 2, 2, 1, 2, 1, 1, 1, 1, 0, 1, 1, 0, 1, 2], dtype=int32)
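
Once a model has been fitted, it can also assign observations it has not seen before to the nearest cluster centre with `predict`. In the sketch below the two flowers are made-up measurements, and `n_init=10` is set explicitly as an assumption so that the behaviour does not depend on the installed Scikit-learn version.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn import datasets

iris_data = datasets.load_iris().data
km = KMeans(n_clusters=3, n_init=10, random_state=42).fit(iris_data)

# Two made-up flowers: sepal length, sepal width, petal length, petal width (cm)
new_flowers = np.array([[5.0, 3.4, 1.5, 0.2],
                        [6.7, 3.0, 5.5, 2.1]])
new_labels = km.predict(new_flowers)

# predict returns the index of the nearest centre for each row
nearest = np.linalg.norm(
    new_flowers[:, None, :] - km.cluster_centers_[None, :, :], axis=2
).argmin(axis=1)
print(new_labels, nearest)
```

The manual nearest-centre calculation is included only to show what `predict` does under the hood.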

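The cluster numbers assigned by k-means are arbitrary, so they need not coincide with the encoding of the original target. In order to make a comparison, let us reorder the original labels so that they line up with the cluster numbering from our run: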

y = np.choose(y, [1, 2, 0]).astype(int)
y
array([2, 1, 0, 2, 2, 1, 2, 0, 2, 2, 0, 1, 1, 1, 1, 2, 0, 2, 2, 0, 1, 0,
       1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 1, 1, 0, 2, 1, 1, 1, 0, 2, 2, 1,
       1, 2, 0, 0, 2, 0, 2, 0, 2, 1, 0, 2, 1, 1, 1, 2, 0, 1, 1, 1, 2, 1,
       2, 0, 1, 2, 0, 1, 0, 0, 2, 2, 0, 2, 1, 2, 0, 1, 1, 2, 2, 1, 0, 1,
       1, 2, 2, 0, 2, 0, 0, 2, 1, 1, 0, 0, 1, 1, 1, 2, 0, 1, 0, 0, 1, 2,
       2, 0, 2, 0, 1, 0, 2, 0, 2, 2, 2, 1, 2, 2, 1, 2, 0, 0, 1, 2, 0, 0,
       1, 0, 1, 2, 0, 0, 2, 0, 2, 2, 0, 0, 1, 2, 0, 1, 2, 0])
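
A hard-coded reordering only works for the particular numbering obtained in one run. As a hedged alternative, the mapping can be derived from the data by taking, for each cluster, the true class that appears most often, or the numbering can be sidestepped altogether with the adjusted Rand index. The variable names below are assumptions made for this sketch, and the majority-vote mapping assumes each cluster is dominated by a different class.

```python
import numpy as np
from sklearn import datasets
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, confusion_matrix

iris_demo = datasets.load_iris()
true_labels = iris_demo.target
km = KMeans(n_clusters=3, n_init=10, random_state=42).fit(iris_demo.data)

# Rows: true classes, columns: cluster numbers
counts = confusion_matrix(true_labels, km.labels_)

# For each cluster (column), take the true class that appears most often
cluster_to_class = counts.argmax(axis=0)
aligned = cluster_to_class[km.labels_]
agreement = (aligned == true_labels).mean()

print("Cluster -> class mapping:", cluster_to_class)
print("Agreement after alignment:", agreement)

# The adjusted Rand index ignores the numbering altogether
ari = adjusted_rand_score(true_labels, km.labels_)
print("Adjusted Rand index:", ari)
```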

We can check how many observations were correctly assigned. We do this with the help of a confusion matrix:

from sklearn.metrics import confusion_matrix
conf_matrix=confusion_matrix(y, iris_kmeans.labels_) 

fig, ax = plt.subplots(figsize=(7.5, 7.5))
ax.matshow(conf_matrix, cmap=plt.cm.Blues, alpha=0.3)
for i in range(conf_matrix.shape[0]):
    for j in range(conf_matrix.shape[1]):
        ax.text(x=j, y=i,s=conf_matrix[i, j], va='center', 
                ha='center', size='xx-large')
 
plt.xlabel('Predictions', fontsize=18)
plt.ylabel('Actuals', fontsize=18)
plt.title('Confusion Matrix', fontsize=18)
plt.show()
K means cluster confusion matrix

As we can see, most of the observations were correctly identified. In particular, those for cluster 1 seem to have all been captured.

Let us look at the location of the final cluster centres:

iris_kmeans.cluster_centers_
array([[5.006     , 3.428     , 1.462     , 0.246     ],
       [5.9016129 , 2.7483871 , 4.39354839, 1.43387097],
       [6.85      , 3.07368421, 5.74210526, 2.07105263]])
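
The raw array is easier to read with the feature names attached. A minimal sketch, assuming a freshly fitted model (the name `km` and the explicit `n_init=10` are choices made for this example):

```python
import pandas as pd
from sklearn import datasets
from sklearn.cluster import KMeans

iris_demo = datasets.load_iris()
km = KMeans(n_clusters=3, n_init=10, random_state=42).fit(iris_demo.data)

# One row per cluster, one named column per feature
centres = pd.DataFrame(km.cluster_centers_, columns=iris_demo.feature_names)
print(centres.round(2))
```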

And we can look at some 3D graphics. In this case we will plot the following features:

• Petal width
• Sepal length
• Petal length

As you can see, there are a few observations that differ in colour between the two plots.

fig = plt.figure(figsize=(20, 10))
ax1 = fig.add_subplot(1, 2, 1, projection='3d')
ax1.scatter(X[:, 3], X[:, 0], X[:, 2], 
            c=iris_kmeans.labels_.astype(float), 
           edgecolor="k", s=150, cmap=customcmap)
ax1.view_init(20, -50)
ax1.set_xlabel(names[3], fontsize=12)
ax1.set_ylabel(names[0], fontsize=12)
ax1.set_zlabel(names[2], fontsize=12)
ax1.set_title("K-Means Clusters for the Iris Dataset", fontsize=12)

ax2 = fig.add_subplot(1, 2, 2, projection='3d')

for label, name in enumerate(['virginica','setosa','versicolor']):
    ax2.text3D(
        X[y == label, 3].mean(),
        X[y == label, 0].mean(),
        X[y == label, 2].mean() + 2,
        name,
        horizontalalignment="center",
        bbox=dict(alpha=0.2, edgecolor="w", facecolor="w"),
    )

ax2.scatter(X[:, 3], X[:, 0], X[:, 2], 
            c=y, edgecolor="k", s=150, 
            cmap=customcmap)
ax2.view_init(20, -50)
ax2.set_xlabel(names[3], fontsize=12)
ax2.set_ylabel(names[0], fontsize=12)
ax2.set_zlabel(names[2], fontsize=12)
ax2.set_title("Actual Labels for the Iris Dataset", fontsize=12)
plt.show()
K Means clusters for iris data set vs the actual labels in the dataset

Summary

In this post we have explained the ideas behind the \(k\)-means algorithm and provided a simple implementation of these ideas in Python. I hope you agree that it is a very straightforward algorithm to understand and in case you want to use a more robust implementation, Scikit-learn has us covered. Given its simplicity, \(k\)-means is a very popular choice for starting up with clustering analysis.

To get access to the code/data from this article, please use the Access Project button below.

Dr Jesus Rogel-Salazar is a Research Associate in the Photonics Group in the Department of Physics at Imperial College London. He obtained his PhD in quantum atom optics at Imperial College in the group of Professor Geoff New and in collaboration with the Bose-Einstein Condensation Group in Oxford with Professor Keith Burnett. After completion of his doctorate in 2003, he took a postdoc in the Centre for Cold Matter at Imperial and moved on to the Department of Mathematics in the Applied Analysis and Computation Group with Professor Jeff Cash.