首页 > 代码库 > 数据挖掘经典算法——K-means算法
数据挖掘经典算法——K-means算法
算法描述
K-means算法是一种被广泛使用的基于划分的聚类算法,目的是将n个对象会分成k个簇。算法的具体描述如下:
随机选取k个对象作为簇中心;
Do
计算所有对象到这k个簇中心的距离,将距离最近的归入相应的簇;
重新计算每个簇的中心;
计算准则函数V;
While 准则函数的值稳定(或变化小于某个阈值)
其中准则函数V的定义如下:
其中,ui表示第i个簇Si的中心。最终经过T次迭代获取到最终的分类结果,对于第t+1次迭代之后得到的中心,有如下定义:
算法的优缺点
优点:
1) 算法描述简单,高效;
2) 适合用于处理大数据,得到的算法的复杂度大约为O(nkt),n表示对象的数量,k是划分的簇数量,t为迭代次数。通常情况下能够保证k<<n,算法的效率有一定的保障;
3) 算法比较适合处理簇之间划分明确的对象集合;
缺点:
1) k值必须手动的给出,选取k值就显得特别重要了;
2) 不同的初始对象会带来不同的划分结果;
3) 如果对象集合内部包含一些小范围孤立对象,这种基于局部最优的聚类划分算法可能会产生一些错误的划分;
4) 通常判断对象之间远近的依据是欧拉距离,可以尽快得到结果,但同时也带来了一些缺点,比如采用K-means算法处理一下非凸面的簇时。
Kmeans算法的Java实现
1 import java.awt.BorderLayout; 2 import java.awt.Canvas; 3 import java.awt.Color; 4 import java.awt.Dimension; 5 import java.awt.Frame; 6 import java.awt.Graphics; 7 import java.awt.Graphics2D; 8 import java.awt.Paint; 9 import java.awt.event.WindowAdapter; 10 import java.awt.event.WindowEvent; 11 import java.util.ArrayList; 12 import java.util.HashSet; 13 import java.util.List; 14 import java.util.Random; 15 import java.util.Set; 16 17 import ocr.algorithm.KMeans.Cluster; 18 import ocr.algorithm.KMeans.KMeansNode; 19 20 /** 21 * 这是一个简单的Kmeans算法实现 22 * 假设前提包括: 23 * 1.集合中每个对象都是一个位于二维平面的点 24 * 2.对象之间的距离判定以欧氏距离为标准 25 * 3.这只是一个样例程序,主要用于叙述K-means算法的主干部分,一些特殊的情况未曾考虑(数据溢出,性能优化) 26 * @author yahokuma 27 * @email Hazin.lu@gmail.com 28 */ 29 public class KMeans { 30 31 private List<KMeansNode> datas = new ArrayList<KMeans.KMeansNode>(); 32 33 public static class KMeansNode { 34 private double x; 35 private double y; 36 37 38 39 public double getX() { 40 return x; 41 } 42 43 public double getY() { 44 return y; 45 } 46 47 public KMeansNode(double x,double y){ 48 this.x = x; 49 this.y = y; 50 } 51 52 public double distance(KMeansNode n){ 53 return 54 Math.pow( x - n.x , 2 ) + Math.pow( y - n.y , 2 ); 55 } 56 57 } 58 59 public static class Cluster{ 60 private List<KMeansNode> nodes = new ArrayList<KMeans.KMeansNode>(); 61 private KMeansNode center = null; 62 public KMeansNode getCenter() { 63 return center; 64 } 65 public void addNode(KMeansNode n){ 66 this.nodes.add(n); 67 } 68 public Cluster(KMeansNode c){ 69 this.center = c; 70 } 71 public void calculateCenter(){ 72 double x = 0,y = 0; 73 for (KMeansNode n : nodes) { 74 x += n.x; 75 y += n.y; 76 } 77 this.center = new KMeansNode( x / nodes.size(), y / nodes.size()); 78 } 79 80 public double criterion(){ 81 double criterion = 0; 82 calculateCenter(); 83 for (KMeansNode n : nodes) { 84 criterion += center.distance(n); 85 } 86 87 return criterion; 88 } 89 90 public void clear(){ 91 this.nodes.clear(); 92 } 93 94 public List<KMeansNode> getNodes(){ 95 return this.nodes; 96 } 97 98 public void print(){ 99 System.out.println("Contains "+ nodes.size() + " Nodes !"); 100 System.out.println("Center Node is ( "+ getCenter().x + "," + getCenter().y + " )"); 101 } 102 } 103 104 public KMeans(List<KMeansNode> datas){ 105 this.datas = datas; 106 } 107 108 109 private List<KMeansNode> findRandNodes(int k){ 110 List<KMeansNode> rNodes = new ArrayList<KMeans.KMeansNode>(); 111 Set<Integer> rIndexes = new HashSet<Integer>(); 112 Random r = new Random(); 113 Integer rInt = null; 114 for (int i = 0; i < k; i++) { 115 rInt = r.nextInt(datas.size()); 116 while(rIndexes.contains(rInt)) 117 rInt = r.nextInt(datas.size()); 118 119 rIndexes.add(rInt); 120 rNodes.add( datas.get(rInt)); 121 } 122 123 return rNodes; 124 } 125 126 private double calculateCriterion(List<Cluster> clusters){ 127 double res = 0; 128 for (Cluster c : clusters) { 129 res += c.criterion(); 130 } 131 return res; 132 } 133 134 135 136 public List<Cluster> partition(int k){ 137 List<KMeansNode> centerNodes = findRandNodes(k); 138 List<Cluster> clusters = new ArrayList<KMeans.Cluster>(); 139 for (KMeansNode c : centerNodes) { 140 clusters.add(new Cluster(c)); 141 } 142 143 double minDistance = Double.MAX_VALUE , distance; 144 Cluster minCluster = null; 145 double lastCriterion , criterion= Double.MAX_VALUE; 146 147 do{ 148 for (Cluster c : clusters) { 149 c.clear(); 150 } 151 lastCriterion = criterion; 152 153 154 for (KMeansNode n : datas) { 155 minDistance = Double.MAX_VALUE; 156 for (Cluster c : clusters) { 157 distance = c.getCenter().distance(n); 158 if( distance < minDistance ){ 159 minDistance = distance; 160 minCluster = c; 161 } 162 } 163 minCluster.addNode(n); 164 } 165 criterion = calculateCriterion(clusters); 166 167 }while( criterion != lastCriterion); 168 169 return clusters; 170 171 } 172 173 /** 174 *随机生成了1000个平面点,将其划分为4个簇(k=4) 175 *由于点的坐标都是随机生成的,在空间上分布均匀; 176 *从结果中可以看出K-means对于处理这种边界不分明的对象集合时并不能很好的进行区分; 177 *但是一般情况,经过处理还是会将整个平面均匀得划分成四个部分 178 **/ 179 public static void main(String args[]){ 180 Random r = new Random(); 181 List<KMeansNode> nodes = new ArrayList<KMeans.KMeansNode>(); 182 for (int i = 0; i < 1000; i++) { 183 nodes.add(new KMeansNode(r.nextDouble() * 1000, r.nextDouble() * 1000)); 184 } 185 186 KMeans kmeans = new KMeans(nodes); 187 List<Cluster> clusters = kmeans.partition(4); 188 for( Cluster c : clusters){ 189 c.print(); 190 } 191 192 193 Frame frame = new Frame("K-means Test!"); 194 frame.addWindowListener(new WindowAdapter(){ 195 public void windowClosing(WindowEvent e) { 196 System.exit(0); 197 } 198 }); 199 frame.add(new KMeansCanvas(clusters),BorderLayout.CENTER); 200 frame.pack(); 201 frame.setVisible(true); 202 } 203 204 } 205 206 207 /** 208 * 209 * @author yahokuma 210 * @email Hazin.lu@gmail.com 211 */ 212 class KMeansCanvas extends Canvas { 213 public final static Paint[] PAINT_COLOR = {Color.BLUE,Color.RED, Color.ORANGE, Color.BLACK}; 214 215 private List<Cluster> clusters = null; 216 public KMeansCanvas(List<Cluster> clusters) { 217 this.setBackground(Color.WHITE); 218 this.clusters = clusters; 219 } 220 @Override 221 public void paint(Graphics g) { 222 drawKarel(g); 223 } 224 @Override 225 public Dimension getPreferredSize() { 226 return new Dimension(1000,1000); 227 } 228 229 230 private void drawKarel(Graphics g) { 231 232 233 Random r = new Random(); 234 int i = 0 ; 235 for (Cluster c : clusters) { 236 Graphics2D g2d= (Graphics2D) g; 237 g2d.setPaint(PAINT_COLOR[i++]); 238 for (KMeansNode n : c.getNodes()) { 239 g2d.drawRect((int)n.getX(), (int)n.getY() , 2, 2); 240 g2d.fillRect((int)n.getX(), (int)n.getY() , 2, 2); 241 } 242 } 243 } 244 private static final long serialVersionUID = 1L; 245 }
参考资料
http://en.wikipedia.org/wiki/K-means_clustering
http://blog.csdn.net/aladdina/article/details/4141177
http://www.cnblogs.com/jerrylead/archive/2011/04/06/2006910.html
http://zh.wikipedia.org/wiki/K%E5%B9%B3%E5%9D%87%E7%AE%97%E6%B3%95