首页 > 代码库 > 【Java】K-means算法Java实现以及图像分割

【Java】K-means算法Java实现以及图像分割

1.K-means算法简述以及代码原型

数据挖掘中一个重要算法是K-means,我这里就不做详细介绍。如果感兴趣的话可以移步陈皓的博客:   

 http://www.csdn.net/article/2012-07-03/2807073-k-means 讲得很好

    总的来讲,k-means聚类需要以下几个步骤:

         ①.初始化数据

         ②.计算初始的中心点,可以随机选择

         ③.计算每个点到每个聚类中心的距离,并且划分到距离最短的聚类中心簇中

         ④.计算每个聚类簇的平均值,这个均值作为新的聚类中心,重复步骤3

         ⑤.如果达到最大循环或者是聚类中心不再变化或者聚类中心变化幅度小于一定范围时,停止循环。

    恩,原理就是这样,超级简单。但是Java算法实现起来代码量并不小。这个代码也不算是完全自己写的啦,也有些借鉴。我把k-means实现封装在了一个类里面,这样就可以随时调用了呢。

      

import java.util.ArrayList;
import java.util.Random;

public class kmeans {
	private int k;//簇数
	private int m;//迭代次数
	private int dataSetLength;//数据集长度
	private ArrayList<double[]> dataSet;//数据集合
	private ArrayList<double[]> center;//中心链表
	private ArrayList<ArrayList<double[]>> cluster;//簇
	private ArrayList<Float> jc;//误差平方和,这个是用来计算中心聚点的移动哦
	private Random random;
	
	//设置原始数据集合
	public void setDataSet(ArrayList<double[]> dataSet){
		this.dataSet=dataSet;
	}
	//获得簇分组
	public  ArrayList<ArrayList<double[]>> getCluster(){
		return this.cluster;
	}
	//构造函数,传入要分的簇的数量
	public kmeans(int k){
		if(k<=0)
			k=1;
		this.k=k;
	}
	//初始化
	private void init(){
		m=0;
		random=new Random();
		if(dataSet==null||dataSet.size()==0)
			initDataSet();
		dataSetLength=dataSet.size();
		if(k>dataSetLength)
			k=dataSetLength;
		center=initCenters();
		cluster=initCluster();
		jc=new ArrayList<Float>();
	}
	//初始化数据集合
	private void initDataSet(){
		dataSet=new ArrayList<double[]>();
		double[][] dataSetArray=new double[][]{{8,2},{3,4},{2,5},{4,2},
				{7,3},{6,2},{4,7},{6,3},{5,3},{6,3},{6,9},
				{1,6},{3,9},{4,1},{8,6}};
		for(int i=0;i<dataSetArray.length;i++)
			dataSet.add(dataSetArray[i]);
	}
	//初始化中心链表,分成几簇就有几个中心
	private ArrayList<double[]> initCenters(){
		ArrayList<double[]> center= new ArrayList<double[]>();
		//生成一个随机数列,
		int[] randoms=new int[k];
		boolean flag;
		int temp=random.nextInt(dataSetLength);
		randoms[0]=temp;
		for(int i=1;i<k;i++){
			flag=true;
			while(flag){
				temp=random.nextInt(dataSetLength);
				int j=0;
				while(j<i){
					if(temp==randoms[j])
						break;
					j++;
				}
				if(j==i)
					flag=false;
			}
			randoms[i]=temp;
		}
		for(int i=0;i<k;i++)
			center.add(dataSet.get(randoms[i]));
		return center;
	}
	//初始化簇集合
	private ArrayList<ArrayList<double[]>> initCluster(){
		ArrayList<ArrayList<double[]>> cluster=
				new ArrayList<ArrayList<double[]>>();
		for(int i=0;i<k;i++)
			cluster.add(new ArrayList<double[]>());
		return cluster;
	}
	//计算距离
	private double distance(double[] element,double[] center){
		double distance=0.0f;
		double x=element[0]-center[0];
		double y=element[1]-center[1];
		double z=element[2]-center[2];
		double sum=x*x+y*y+z*z;
		distance=(double)Math.sqrt(sum);
		return distance;
	}
	//计算最短的距离
	private int minDistance(double[] distance){
		double minDistance=distance[0];
		int minLocation=0;
		for(int i=0;i<distance.length;i++){
			if(distance[i]<minDistance){
				minDistance=distance[i];
				minLocation=i;
			}else if(distance[i]==minDistance){
				if(random.nextInt(10)<5){
					minLocation=i;
				}
			}
		}
		return minLocation;
	}
	//每个点分类
	private void clusterSet(){
		double[] distance=new double[k];
		for(int i=0;i<dataSetLength;i++){
			//计算到每个中心店的距离
			for(int j=0;j<k;j++)
				distance[j]=distance(dataSet.get(i),center.get(j));
			//计算最短的距离
			int minLocation=minDistance(distance);
			//把他加到聚类里
			cluster.get(minLocation).add(dataSet.get(i));
		}
	}
	//计算新的中心
	private void setNewCenter(){
		for(int i=0;i<k;i++){
			int n=cluster.get(i).size();
			if(n!=0){
				double[] newcenter={0,0};
				for(int j=0;j<n;j++){
					newcenter[0]+=cluster.get(i).get(j)[0];
					newcenter[1]+=cluster.get(i).get(j)[1];
				}
				newcenter[0]=newcenter[0]/n;
				newcenter[1]=newcenter[1]/n;
				center.set(i, newcenter);
			}
		}
	}
	//求2点的误差平方
	private double errosquare(double[] element,double[] center){
		double x=element[0]-center[0];
		double y=element[1]-center[1];
		double errosquare=x*x+y*y;
		return errosquare;
	}
	//计算误差平方和准则函数
	private void countRule(){
		float jcf=0;
		for(int i=0;i<cluster.size();i++){
			for(int j=0;j<cluster.get(i).size();j++)
				jcf+=errosquare(cluster.get(i).get(j),center.get(i));
		jc.add(jcf);
		}
	}
	//核心算法
	private void Kmeans(){
		//初始化各种变量,随机选定中心,初始化聚类
		init();
		//开始循环
		while(true){
			//把每个点分到聚类中去
			clusterSet();
			//计算目标函数
			countRule();
			//检查误差变化,因为我规定的计算循环次数为50次,所以就不用计算这个啦,你要愿意用也可以,就是慢一点
			/*
			if(m!=0){
				if(jc.get(m)-jc.get(m-1)==0)
					break;
			}*/
			if(m>=50)
				break;
			//否则继续生成新的中心
			setNewCenter();
			m++;
			cluster.clear();
			cluster=initCluster();

		}
	}
    //只暴露一个接口给外部类
	public void execute(){
		System.out.print("start kmeans\n");
		Kmeans();
		System.out.print("kmeans end\n");
	}
        //用来在外面打印出来已经分好的聚类
	public void printDataArray(ArrayList<double[]> data,String dataArrayName){
		for(int i=0;i<data.size();i++){
			System.out.print("print:"+dataArrayName+"["+i+"]={"+data.get(i)[0]+","+data.get(i)[1]+"}\n");
		}
		System.out.print("==========================");
	}
}
  嗯,代码就是这样。注释写的很详细,也都能看得懂。下面我给一个测试例子。

import java.util.ArrayList;

public class Test {
	public static void main(String[] args){
		kmeans k=new kmeans(2);
		ArrayList<double[]> dataSet=new ArrayList<double[]>();
		dataSet.add(new double[]{2,2,2});
		dataSet.add(new double[]{1,2,2});
		dataSet.add(new double[]{2,1,2});
		dataSet.add(new double[]{1,3,2});
		dataSet.add(new double[]{3,1,2});
		dataSet.add(new double[]{-2,-2,-2});
		dataSet.add(new double[]{-1,-2,-2});
		dataSet.add(new double[]{-2,-1,-2});
		dataSet.add(new double[]{-3,-1,-2});
		dataSet.add(new double[]{-1,-3,-2});


		k.setDataSet(dataSet);
		k.execute();
		ArrayList<ArrayList<double[]>> cluster=k.getCluster();
		for(int i=0;i<cluster.size();i++){
			k.printDataArray(cluster.get(i), "cluster["+i+"]");
		}
	}
}
   没啥难度,也就是输入写初始数据,然后执行k-means在进行分类,最后打印一下。这个原型代码很粗糙,没有添加聚类个数以及循环次数的变量,这些需要自己动手啦。

2.k-means应用图像分割

  我们可以把k-means聚类放在图像分割上,也就是说把一个颜色的像素分为一类,然后再涂一个颜色。像这样。技术分享技术分享
左边就是聚类之前的,右边是聚类之后的,看起来还是满炫酷的。其实聚类算法也是很容易扩展到这里的。
有下面四个提示(因为是作业,我决定先不放马,不然到时候作业雷同我的学分就咖喱gaygay了):
   ①.上面的原型代码是对二维的数据进行分类,那我们也知道,一个颜色有RGB三种原色构成,也就是说我们只需要 在二维的基础上,加上一维数据就吼啦。很简单有木有,改变下数组结构,在距离计算编程三维欧式距离就吼。
   ②.Java有自带的图像处理类,所以读取数据敲击方便。我给一点代码提示哦
//读取指定目录的图片数据,并且写入数组,这个数据要继续处理
	private int[][] getImageData(String path){
		BufferedImage bi=null;
		try{
			bi=ImageIO.read(new File(path));
		}catch (IOException e){
			e.printStackTrace();
		}
		int width=bi.getWidth();
		int height=bi.getHeight();
		int [][] data=http://www.mamicode.com/new int[width][height];>
          //介货是用来输出图像的
<pre name="code" class="java">           private void ImagedataOut(String path){
		Color c0=new Color(255,0,0);
		Color c1=new Color(0,255,0);
		Color c2=new Color(0,0,255);
		Color c3=new Color(128,128,128);
		BufferedImage nbi=new BufferedImage(source.length,source[0].length,BufferedImage.TYPE_INT_RGB);
		for(int i=0;i<source.length;i++){
			for(int j=0;j<source[0].length;j++){
				if(source[i][j].group==0)
					nbi.setRGB(i, j, c0.getRGB());
				else if(source[i][j].group==1)
					nbi.setRGB(i, j, c1.getRGB());
				else if(source[i][j].group==2)
					nbi.setRGB(i, j, c2.getRGB());
				else if (source[i][j].group==3)
					nbi.setRGB(i, j, c3.getRGB());
				//Color c=new Color((int)center[source[i][j].group].r,
				//		(int)center[source[i][j].group].g,(int)center[source[i][j].group].b);
				//nbi.setRGB(i, j, c.getRGB());
			}
		}
		try{
			ImageIO.write(nbi, "jpg", new File(path));
		}catch(IOException e){
			e.printStackTrace();
			}
	}

    很舒爽,你问我dataItem是啥?等我交完作业我就告诉你。
    ③.有一点不同的是,注意数据格式。胖胖开始用的就是int类型,结果在计算新的聚类中心的时候溢出了呢。。。所幸鹏鹏改成了double,但是鹏鹏在计算距离的时候又写错了,最后还是机智的胖胖鹏解决掉了所有的bug。
    ④.注意读取图片的时候保护好数据的顺序,也就是用一个二维数组来存储,这样在写的时候就不用记录像素点的位置,输出的时候也很方便。
   就是这些。。。。等我作业交完就来一次完整的代码讲解!

【Java】K-means算法Java实现以及图像分割