首页 > 代码库 > 数据挖掘经典算法——K-means算法

数据挖掘经典算法——K-means算法

算法描述

  K-means算法是一种被广泛使用的基于划分的聚类算法,目的是将n个对象会分成k个簇。算法的具体描述如下:

随机选取k个对象作为簇中心;

Do

          计算所有对象到这k个簇中心的距离,将距离最近的归入相应的簇;

          重新计算每个簇的中心;

          计算准则函数V;

While 准则函数的值稳定(或变化小于某个阈值)

 

  其中准则函数V的定义如下:

         

  其中,ui表示第i个簇Si的中心。最终经过T次迭代获取到最终的分类结果,对于第t+1次迭代之后得到的中心,有如下定义:

           

算法的优缺点

  优点:

    1)         算法描述简单,高效;

    2)         适合用于处理大数据,得到的算法的复杂度大约为O(nkt),n表示对象的数量,k是划分的簇数量,t为迭代次数。通常情况下能够保证k<<n,算法的效率有一定的保障;

    3)         算法比较适合处理簇之间划分明确的对象集合;

  缺点:

    1)         k值必须手动的给出,选取k值就显得特别重要了;

    2)         不同的初始对象会带来不同的划分结果;

    3)         如果对象集合内部包含一些小范围孤立对象,这种基于局部最优的聚类划分算法可能会产生一些错误的划分;

    4)         通常判断对象之间远近的依据是欧拉距离,可以尽快得到结果,但同时也带来了一些缺点,比如采用K-means算法处理一下非凸面的簇时。

Kmeans算法的Java实现

  1 import java.awt.BorderLayout;
  2 import java.awt.Canvas;
  3 import java.awt.Color;
  4 import java.awt.Dimension;
  5 import java.awt.Frame;
  6 import java.awt.Graphics;
  7 import java.awt.Graphics2D;
  8 import java.awt.Paint;
  9 import java.awt.event.WindowAdapter;
 10 import java.awt.event.WindowEvent;
 11 import java.util.ArrayList;
 12 import java.util.HashSet;
 13 import java.util.List;
 14 import java.util.Random;
 15 import java.util.Set;
 16 
 17 import ocr.algorithm.KMeans.Cluster;
 18 import ocr.algorithm.KMeans.KMeansNode;
 19 
 20 /**
 21  * 这是一个简单的Kmeans算法实现
 22  * 假设前提包括:
 23  * 1.集合中每个对象都是一个位于二维平面的点
 24  * 2.对象之间的距离判定以欧氏距离为标准
 25  * 3.这只是一个样例程序,主要用于叙述K-means算法的主干部分,一些特殊的情况未曾考虑(数据溢出,性能优化)
 26  * @author yahokuma
 27  * @email Hazin.lu@gmail.com
 28  */
 29 public class KMeans {
 30     
 31     private List<KMeansNode> datas = new ArrayList<KMeans.KMeansNode>();
 32     
 33     public static class KMeansNode {
 34         private double x;
 35         private double y;
 36         
 37         
 38         
 39         public double getX() {
 40             return x;
 41         }
 42 
 43         public double getY() {
 44             return y;
 45         }
 46 
 47         public KMeansNode(double x,double y){
 48             this.x = x;
 49             this.y = y;
 50         }
 51         
 52         public double distance(KMeansNode n){
 53             return 
 54                 Math.pow( x - n.x , 2 ) + Math.pow( y - n.y , 2 );
 55         }
 56         
 57     }
 58     
 59     public static class Cluster{
 60         private List<KMeansNode> nodes = new ArrayList<KMeans.KMeansNode>();
 61         private KMeansNode center = null;
 62         public KMeansNode getCenter() {
 63             return center;
 64         }
 65         public void addNode(KMeansNode n){
 66             this.nodes.add(n);
 67         }
 68         public Cluster(KMeansNode c){
 69             this.center = c;
 70         }
 71         public void calculateCenter(){
 72             double x = 0,y = 0;
 73             for (KMeansNode n : nodes) {
 74                 x += n.x;
 75                 y += n.y;
 76             }
 77             this.center = new KMeansNode( x / nodes.size(), y / nodes.size());
 78         }
 79         
 80         public double criterion(){
 81             double criterion = 0;
 82             calculateCenter();
 83             for (KMeansNode n : nodes) {
 84                 criterion += center.distance(n);
 85             }
 86             
 87             return criterion;
 88         }
 89         
 90         public void clear(){
 91             this.nodes.clear();
 92         }
 93         
 94         public List<KMeansNode> getNodes(){
 95             return this.nodes;
 96         }
 97         
 98         public void print(){
 99             System.out.println("Contains "+ nodes.size() + " Nodes !");
100             System.out.println("Center Node is ( "+ getCenter().x + "," + getCenter().y + " )");
101         }
102     }
103     
104     public KMeans(List<KMeansNode> datas){
105         this.datas = datas;
106     }
107     
108     
109     private List<KMeansNode> findRandNodes(int k){
110         List<KMeansNode> rNodes = new ArrayList<KMeans.KMeansNode>();
111         Set<Integer> rIndexes = new HashSet<Integer>(); 
112         Random r = new Random();
113         Integer rInt = null;
114         for (int i = 0; i < k; i++) {
115             rInt = r.nextInt(datas.size());
116             while(rIndexes.contains(rInt))
117                 rInt = r.nextInt(datas.size());
118             
119             rIndexes.add(rInt);
120             rNodes.add( datas.get(rInt));
121         }
122         
123         return rNodes;
124     }
125     
126     private double calculateCriterion(List<Cluster> clusters){
127         double res = 0;
128         for (Cluster c : clusters) {
129             res += c.criterion();
130         }
131         return res;
132     }
133     
134     
135     
136     public List<Cluster> partition(int k){
137         List<KMeansNode> centerNodes = findRandNodes(k);
138         List<Cluster> clusters = new ArrayList<KMeans.Cluster>();
139         for (KMeansNode c : centerNodes) {
140             clusters.add(new Cluster(c));
141         }
142         
143         double minDistance = Double.MAX_VALUE , distance;
144         Cluster minCluster = null;
145         double lastCriterion , criterion= Double.MAX_VALUE;
146         
147         do{
148             for (Cluster c : clusters) {
149                 c.clear();
150             }
151             lastCriterion = criterion;
152             
153             
154             for (KMeansNode n : datas) {
155                 minDistance = Double.MAX_VALUE;
156                 for (Cluster c : clusters) {
157                     distance = c.getCenter().distance(n);
158                     if( distance < minDistance ){
159                         minDistance = distance;
160                         minCluster = c;
161                     }
162                 }
163                 minCluster.addNode(n);
164             }
165             criterion = calculateCriterion(clusters);
166             
167         }while( criterion != lastCriterion);
168         
169         return clusters;
170         
171     }
172     
173     /**
174     *随机生成了1000个平面点,将其划分为4个簇(k=4)
175     *由于点的坐标都是随机生成的,在空间上分布均匀;
176     *从结果中可以看出K-means对于处理这种边界不分明的对象集合时并不能很好的进行区分;
177     *但是一般情况,经过处理还是会将整个平面均匀得划分成四个部分
178     **/
179     public static void main(String args[]){
180         Random r = new Random();
181         List<KMeansNode> nodes = new ArrayList<KMeans.KMeansNode>();
182         for (int i = 0; i < 1000; i++) {
183             nodes.add(new KMeansNode(r.nextDouble() * 1000, r.nextDouble() * 1000));
184         }
185         
186         KMeans kmeans = new KMeans(nodes);
187         List<Cluster> clusters = kmeans.partition(4);
188         for( Cluster c :  clusters){
189             c.print();
190         }
191         
192         
193          Frame frame = new Frame("K-means Test!");  
194          frame.addWindowListener(new WindowAdapter(){  
195              public void windowClosing(WindowEvent e) {  
196                  System.exit(0);  
197              }  
198          });  
199          frame.add(new KMeansCanvas(clusters),BorderLayout.CENTER);  
200          frame.pack();  
201          frame.setVisible(true);  
202     }
203     
204 }
205 
206 
207 /**
208  *  
209  * @author yahokuma
210  * @email Hazin.lu@gmail.com
211  */
212 class KMeansCanvas extends Canvas {  
213     public final static Paint[] PAINT_COLOR = {Color.BLUE,Color.RED, Color.ORANGE, Color.BLACK};
214     
215     private List<Cluster> clusters = null;
216     public KMeansCanvas(List<Cluster> clusters) {  
217         this.setBackground(Color.WHITE); 
218         this.clusters = clusters;
219     }  
220     @Override  
221     public void paint(Graphics g) {  
222         drawKarel(g);  
223     }  
224     @Override  
225     public Dimension getPreferredSize() {  
226         return new Dimension(1000,1000);  
227     }  
228     
229     
230     private void drawKarel(Graphics g) {  
231         
232     
233         Random r = new Random();
234         int i = 0 ;
235         for (Cluster c : clusters) {
236             Graphics2D g2d= (Graphics2D) g;  
237             g2d.setPaint(PAINT_COLOR[i++]);
238             for (KMeansNode n : c.getNodes()) {
239                 g2d.drawRect((int)n.getX(), (int)n.getY() , 2, 2);
240                 g2d.fillRect((int)n.getX(), (int)n.getY() , 2, 2);
241             }
242         }
243     }  
244     private static final long serialVersionUID = 1L;  
245 }  

 参考资料

http://en.wikipedia.org/wiki/K-means_clustering

http://blog.csdn.net/aladdina/article/details/4141177

http://www.cnblogs.com/jerrylead/archive/2011/04/06/2006910.html

http://zh.wikipedia.org/wiki/K%E5%B9%B3%E5%9D%87%E7%AE%97%E6%B3%95