1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math.stat.clustering;
19
20 import java.util.ArrayList;
21 import java.util.Collection;
22 import java.util.List;
23 import java.util.Random;
24
25
26
27
28
29
30
31
32 public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
33
34
35 private final Random random;
36
37
38
39
40 public KMeansPlusPlusClusterer(final Random random) {
41 this.random = random;
42 }
43
44
45
46
47
48
49
50
51
52
53 public List<Cluster<T>> cluster(final Collection<T> points,
54 final int k, final int maxIterations) {
55
56 List<Cluster<T>> clusters = chooseInitialCenters(points, k, random);
57 assignPointsToClusters(clusters, points);
58
59
60 final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
61 for (int count = 0; count < max; count++) {
62 boolean clusteringChanged = false;
63 List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>();
64 for (final Cluster<T> cluster : clusters) {
65 final T newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
66 if (!newCenter.equals(cluster.getCenter())) {
67 clusteringChanged = true;
68 }
69 newClusters.add(new Cluster<T>(newCenter));
70 }
71 if (!clusteringChanged) {
72 return clusters;
73 }
74 assignPointsToClusters(newClusters, points);
75 clusters = newClusters;
76 }
77 return clusters;
78 }
79
80
81
82
83
84
85
86
87 private static <T extends Clusterable<T>> void
88 assignPointsToClusters(final Collection<Cluster<T>> clusters, final Collection<T> points) {
89 for (final T p : points) {
90 Cluster<T> cluster = getNearestCluster(clusters, p);
91 cluster.addPoint(p);
92 }
93 }
94
95
96
97
98
99
100
101
102
103
104 private static <T extends Clusterable<T>> List<Cluster<T>>
105 chooseInitialCenters(final Collection<T> points, final int k, final Random random) {
106
107 final List<T> pointSet = new ArrayList<T>(points);
108 final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>();
109
110
111 final T firstPoint = pointSet.remove(random.nextInt(pointSet.size()));
112 resultSet.add(new Cluster<T>(firstPoint));
113
114 final double[] dx2 = new double[pointSet.size()];
115 while (resultSet.size() < k) {
116
117
118 int sum = 0;
119 for (int i = 0; i < pointSet.size(); i++) {
120 final T p = pointSet.get(i);
121 final Cluster<T> nearest = getNearestCluster(resultSet, p);
122 final double d = p.distanceFrom(nearest.getCenter());
123 sum += d * d;
124 dx2[i] = sum;
125 }
126
127
128
129 final double r = random.nextDouble() * sum;
130 for (int i = 0 ; i < dx2.length; i++) {
131 if (dx2[i] >= r) {
132 final T p = pointSet.remove(i);
133 resultSet.add(new Cluster<T>(p));
134 break;
135 }
136 }
137 }
138
139 return resultSet;
140
141 }
142
143
144
145
146
147
148
149
150
151 private static <T extends Clusterable<T>> Cluster<T>
152 getNearestCluster(final Collection<Cluster<T>> clusters, final T point) {
153 double minDistance = Double.MAX_VALUE;
154 Cluster<T> minCluster = null;
155 for (final Cluster<T> c : clusters) {
156 final double distance = point.distanceFrom(c.getCenter());
157 if (distance < minDistance) {
158 minDistance = distance;
159 minCluster = c;
160 }
161 }
162 return minCluster;
163 }
164
165 }