View Javadoc

1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math.stat.clustering;
19  
20  import java.util.ArrayList;
21  import java.util.Collection;
22  import java.util.List;
23  import java.util.Random;
24  
25  /**
26   * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
27   * @param <T> type of the points to cluster
28   * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
29   * @version $Revision: 771076 $ $Date: 2009-05-03 12:28:48 -0400 (Sun, 03 May 2009) $
30   * @since 2.0
31   */
32  public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
33  
34      /** Random generator for choosing initial centers. */
35      private final Random random;
36  
37      /** Build a clusterer.
38       * @param random random generator to use for choosing initial centers
39       */
40      public KMeansPlusPlusClusterer(final Random random) {
41          this.random = random;
42      }
43  
44      /**
45       * Runs the K-means++ clustering algorithm.
46       * 
47       * @param points the points to cluster
48       * @param k the number of clusters to split the data into
49       * @param maxIterations the maximum number of iterations to run the algorithm
50       *     for.  If negative, no maximum will be used
51       * @return a list of clusters containing the points
52       */
53      public List<Cluster<T>> cluster(final Collection<T> points,
54                                      final int k, final int maxIterations) {
55          // create the initial clusters
56          List<Cluster<T>> clusters = chooseInitialCenters(points, k, random);
57          assignPointsToClusters(clusters, points);
58  
59          // iterate through updating the centers until we're done
60          final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations; 
61          for (int count = 0; count < max; count++) {
62              boolean clusteringChanged = false;
63              List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>();
64              for (final Cluster<T> cluster : clusters) {
65                  final T newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
66                  if (!newCenter.equals(cluster.getCenter())) {
67                      clusteringChanged = true;
68                  }
69                  newClusters.add(new Cluster<T>(newCenter));
70              }
71              if (!clusteringChanged) {
72                  return clusters;
73              }
74              assignPointsToClusters(newClusters, points);
75              clusters = newClusters;
76          }
77          return clusters;
78      }
79  
80      /**
81       * Adds the given points to the closest {@link Cluster}.
82       * 
83       * @param <T> type of the points to cluster
84       * @param clusters the {@link Cluster}s to add the points to
85       * @param points the points to add to the given {@link Cluster}s
86       */
87      private static <T extends Clusterable<T>> void
88          assignPointsToClusters(final Collection<Cluster<T>> clusters, final Collection<T> points) {
89          for (final T p : points) {
90              Cluster<T> cluster = getNearestCluster(clusters, p);
91              cluster.addPoint(p);
92          }
93      }
94  
95      /**
96       * Use K-means++ to choose the initial centers.
97       * 
98       * @param <T> type of the points to cluster
99       * @param points the points to choose the initial centers from
100      * @param k the number of centers to choose
101      * @param random random generator to use
102      * @return the initial centers
103      */
104     private static <T extends Clusterable<T>> List<Cluster<T>>
105         chooseInitialCenters(final Collection<T> points, final int k, final Random random) {
106 
107         final List<T> pointSet = new ArrayList<T>(points);
108         final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>();
109 
110         // Choose one center uniformly at random from among the data points.
111         final T firstPoint = pointSet.remove(random.nextInt(pointSet.size()));
112         resultSet.add(new Cluster<T>(firstPoint));
113 
114         final double[] dx2 = new double[pointSet.size()];
115         while (resultSet.size() < k) {
116             // For each data point x, compute D(x), the distance between x and 
117             // the nearest center that has already been chosen.
118             int sum = 0;
119             for (int i = 0; i < pointSet.size(); i++) {
120                 final T p = pointSet.get(i);
121                 final Cluster<T> nearest = getNearestCluster(resultSet, p);
122                 final double d = p.distanceFrom(nearest.getCenter());
123                 sum += d * d;
124                 dx2[i] = sum;
125             }
126 
127             // Add one new data point as a center. Each point x is chosen with
128             // probability proportional to D(x)2
129             final double r = random.nextDouble() * sum;
130             for (int i = 0 ; i < dx2.length; i++) {
131                 if (dx2[i] >= r) {
132                     final T p = pointSet.remove(i);
133                     resultSet.add(new Cluster<T>(p));
134                     break;
135                 }
136             }
137         }
138 
139         return resultSet;
140 
141     }
142 
143     /**
144      * Returns the nearest {@link Cluster} to the given point
145      * 
146      * @param <T> type of the points to cluster
147      * @param clusters the {@link Cluster}s to search
148      * @param point the point to find the nearest {@link Cluster} for
149      * @return the nearest {@link Cluster} to the given point
150      */
151     private static <T extends Clusterable<T>> Cluster<T>
152         getNearestCluster(final Collection<Cluster<T>> clusters, final T point) {
153         double minDistance = Double.MAX_VALUE;
154         Cluster<T> minCluster = null;
155         for (final Cluster<T> c : clusters) {
156             final double distance = point.distanceFrom(c.getCenter());
157             if (distance < minDistance) {
158                 minDistance = distance;
159                 minCluster = c;
160             }
161         }
162         return minCluster;
163     }
164 
165 }