K-means Clustering in MySQL

Clustering is about finding groups of data points that lie close together. K-means is a fairly simple clustering algorithm. In the most basic version you pick a number of clusters (K), assign random “centroids” to them, and iterate these two steps until convergence:

  1. Cluster assignment: assign each data point to the cluster with the closest centroid.
  2. Cluster update: for each cluster, set the centroid to the mean of the data points assigned to it.
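
In symbols, the two steps can be written as follows. The notation is introduced here and is not part of the original post: x_i is a data point, μ_k the centroid of cluster k, c_i the cluster assigned to point i, and S_k the set of points currently assigned to cluster k.

```latex
% Assignment step: each point x_i goes to the cluster with the nearest centroid
c_i \leftarrow \arg\min_{k \in \{1,\dots,K\}} \lVert x_i - \mu_k \rVert^2

% Update step: each centroid becomes the mean of its assigned points
\mu_k \leftarrow \frac{1}{\lvert S_k \rvert} \sum_{i \in S_k} x_i,
\qquad S_k = \{\, i : c_i = k \,\}
```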

While this algorithm sometimes produces suboptimal clusterings, it is fast and really easy to implement in SQL. Suppose we have some address data, already geocoded into latitude-longitude pairs, and we want to find clusters of addresses that lie close to each other. We’ll use two tables: km_data to store the data points and the cluster assigned to each of them, and km_clusters for the cluster centers:

create table km_data (id int primary key, cluster_id int,
    lat double, lng double);
create table km_clusters (id int auto_increment primary key,
    lat double, lng double
);
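
To experiment with these tables, a few rows are enough. The coordinates below are hypothetical values made up for illustration, not the address data used later in this post:

```sql
-- Hypothetical sample rows for illustration only.
-- cluster_id is left NULL; the assignment step fills it in.
INSERT INTO km_data (id, lat, lng) VALUES
    (1, 52.37, 4.89), (2, 52.38, 4.90), (3, 52.36, 4.88),
    (4, 51.92, 4.48), (5, 51.93, 4.47), (6, 51.91, 4.49);
```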

The K-means algorithm can now be implemented with the following procedure.

DELIMITER //
CREATE PROCEDURE kmeans(v_K int)
BEGIN
TRUNCATE km_clusters;
-- initialize cluster centers with the first K data points
INSERT INTO km_clusters (lat, lng) SELECT lat, lng FROM km_data LIMIT v_K;
REPEAT
    -- assign each data point to the cluster with the closest center
    -- (squared Euclidean distance on the raw lat/lng values)
    UPDATE km_data d SET cluster_id = (SELECT id FROM km_clusters c
        ORDER BY POW(d.lat-c.lat,2)+POW(d.lng-c.lng,2) ASC LIMIT 1);
    -- move each cluster center to the mean of its assigned points
    UPDATE km_clusters C, (SELECT cluster_id,
        AVG(lat) AS lat, AVG(lng) AS lng
        FROM km_data GROUP BY cluster_id) D
    SET C.lat=D.lat, C.lng=D.lng WHERE C.id=D.cluster_id;
-- ROW_COUNT() is the number of centers the last UPDATE actually changed
UNTIL ROW_COUNT() = 0 END REPEAT;
END//
DELIMITER ;
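
A minimal way to run the procedure and inspect the result might look like this. The value K = 3 and the column alias n_points are arbitrary choices for this sketch:

```sql
-- Restore the default statement delimiter in case it is still set to //
DELIMITER ;

-- Run K-means with 3 clusters (3 is an arbitrary example value)
CALL kmeans(3);

-- Show each cluster center together with the number of points assigned to it
SELECT c.id, c.lat, c.lng, COUNT(d.id) AS n_points
FROM km_clusters c
LEFT JOIN km_data d ON d.cluster_id = c.id
GROUP BY c.id, c.lat, c.lng;
```

The LEFT JOIN keeps clusters that ended up with no points, so they show up with a count of 0 rather than disappearing from the listing.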

… and the obligatory sample of the output, based on address data from an NGO:

[Figure: K-means algorithm output on geocoded address data]

Discussion

The procedure initializes the cluster centers as the first K data points. This works sufficiently well if the data points are in no particular order, although better methods such as K-means++ exist.
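
If the rows are stored in some meaningful order, one simple alternative is to seed the centers with randomly chosen data points. The sketch below is plain random seeding, not K-means++, and it uses the literal 3 where the procedure would use its parameter v_K:

```sql
-- Sketch: seed the centers with 3 rows drawn at random instead of the first 3.
-- Inside the kmeans procedure, the literal 3 would be the parameter v_K.
TRUNCATE km_clusters;
INSERT INTO km_clusters (lat, lng)
    SELECT lat, lng FROM km_data ORDER BY RAND() LIMIT 3;
```

ORDER BY RAND() sorts the whole table, which is acceptable for small data sets but can become slow on large ones. If several rows share the same coordinates, two centers may still start at the same location.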

The K-means algorithm is run until it has fully converged, that is, the cluster centers no longer move at all. A more robust algorithm might use a different stopping condition, for example detecting when the cluster centers are hardly moving.
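
One way such a stopping condition could look is sketched below. It is a hypothetical variant, not the procedure used for the output above: the name kmeans_tol and the parameters v_eps (movement threshold, in degrees) and v_max_iter (iteration cap) are introduced here for illustration.

```sql
-- Hypothetical variant of kmeans(); v_eps and v_max_iter are assumptions
-- introduced for this sketch.
DELIMITER //
CREATE PROCEDURE kmeans_tol(v_K INT, v_eps DOUBLE, v_max_iter INT)
BEGIN
DECLARE v_shift DOUBLE DEFAULT 0;
DECLARE v_iter INT DEFAULT 0;
TRUNCATE km_clusters;
INSERT INTO km_clusters (lat, lng) SELECT lat, lng FROM km_data LIMIT v_K;
REPEAT
    -- assignment step, unchanged from kmeans()
    UPDATE km_data d SET cluster_id = (SELECT id FROM km_clusters c
        ORDER BY POW(d.lat-c.lat,2)+POW(d.lng-c.lng,2) ASC LIMIT 1);
    -- largest squared distance any center is about to move
    SELECT COALESCE(MAX(POW(c.lat-m.lat,2)+POW(c.lng-m.lng,2)), 0)
    INTO v_shift
    FROM km_clusters c JOIN (SELECT cluster_id,
        AVG(lat) AS lat, AVG(lng) AS lng
        FROM km_data GROUP BY cluster_id) m ON c.id = m.cluster_id;
    -- update step, unchanged from kmeans()
    UPDATE km_clusters c, (SELECT cluster_id,
        AVG(lat) AS lat, AVG(lng) AS lng
        FROM km_data GROUP BY cluster_id) m
    SET c.lat = m.lat, c.lng = m.lng WHERE c.id = m.cluster_id;
    SET v_iter = v_iter + 1;
-- stop when no center moved more than v_eps, or after v_max_iter rounds
UNTIL v_shift <= v_eps * v_eps OR v_iter >= v_max_iter END REPEAT;
END//
DELIMITER ;
```

A call such as CALL kmeans_tol(5, 0.0001, 100); would stop once no center moves by more than 0.0001 degrees (about 11 metres of latitude) or after 100 iterations, whichever comes first. The sketch computes the per-cluster means twice per iteration to stay short; a temporary table would avoid the repeated work.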

2 Comments

  1. Michael McClennen
    Posted 2013/07/12 at 22:19

    Thank you so much for posting this! You saved me a lot of time looking for a good SQL implementation of the k-means clustering algorithm.

    I played around with it a bit, and found that for small data sets (in my case, ~14000 points and ~400 clusters) I got about 30% better performance on the cluster-updating step by defining an auxiliary table as follows:

    CREATE TABLE km_aux (id int primary key,
    cluster_id int);

    and then using the following SQL statements to compute the closest cluster for each point:

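    -- IGNORE is an assumption added here: km_aux.id is a primary key, so
    -- without it the first repeated id would abort the insert. With it, only
    -- the first (closest) cluster inserted for each point is kept.
    INSERT IGNORE INTO km_aux (id, cluster_id)
    SELECT d.id, c.id
    FROM km_data as d JOIN km_clusters as c
    ORDER BY POW(d.lat-c.lat,2)+POW(d.lng-c.lng,2) ASC;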

    UPDATE km_data as d JOIN km_aux as a using (id) SET d.cluster_id = a.cluster_id;

    If the number of data points is n and the number of clusters is k, then this implementation is O(kn * log(kn)) whereas yours is O(kn). But I am hypothesizing that the constant overhead on each invocation of the subquery in your implementation outweighs this difference for small data sets.

    I found this post originally via a Duck-Duck-Go search, but will start reading your blog regularly now :-)

  2. Thiago Mata
    Posted 2015/10/09 at 19:11

    Based on your work I have created a version of it into postgresql. https://gist.github.com/thiagomata/a9737c3455d6248bef9f