首页 > 代码库 > 【算法】Kmeans

【算法】Kmeans

package com.pachira.d;import java.util.ArrayList;import java.util.HashMap;import java.util.LinkedHashMap;import java.util.List;public class Kmeans {    /**     *  Kmeans聚类算法     *  基本思想:     *  以空间中k个点为中心进行聚类,对最靠近他们的对象归类。通过迭代的方法,逐次更新各聚类中心的值,直至得到最好的聚类结果。     *       *  过程描述:     *  输入:k, data[n], eps;     * (1) 选择k个初始中心点,例如c[0]=data[0],…c[k-1]=data[k-1];     * (2) 对于data[0]….data[n], 分别与c[0]…c[k-1]比较,假定与c[i]差值最少,就标记为i;     * (3) 对于所有标记为i点,重新计算c[i]={ 所有标记为i的data[j]之和}/标记为i的个数;     * (4) 重复(2)(3),直到所有c[i]值的变化小于给定阈值。     *       *  其他说明:     *  1、Kmeans的变种,其距离计算不是欧基米德距离,有可能会出现问题;     *  2、海量数据聚类,欧基米德距离要比余弦相似性好(Inderjit S.Dhillon James FAN 和 Yuqiang Guan论文)     *       *  data[n]的每个元素往往是一个向量;     *       */    /**     * 初始化聚类中心点      * @param k 中心点数     * @param data 带聚类的数据集     * @return 中心点集     */    public static double[] getPoints(int k, int[] data){        double[] points = new double[k];        for (int i = 0; i < k; i++) {            points[i] = (double)data[i];        }        return points;    }    /**     * 计算元素和每个中心点的距离,将该元素归为最小距离的中心点中     * @param points 中心点集     * @param data 元素集     * @return 聚类结果     */    public static LinkedHashMap<Double, List<Integer>> culcate(double[] points, int[] data){        LinkedHashMap<Double, List<Integer>> map = new LinkedHashMap<Double, List<Integer>>();        for (int i = 0; i < data.length; i++) {            //get one point to culcate the distance            int d = data[i];            double minDistance = Double.MAX_VALUE;            double key = -1;            for(int j = 0; j < points.length; j++){                //欧基米德距离                double tmp = Math.sqrt(Math.pow((d - points[j]), 2));                if(tmp < minDistance){                    minDistance = tmp;                    key = points[j];                }            }//            System.out.println(key);            if(map.containsKey(key)){                List<Integer> cus = map.get(key);                cus.add(d);            }else{                List<Integer> cus = new ArrayList<Integer>();                cus.add(d);                map.put(key, cus);            }        }        return map;    }    /**     * 重置中心点     * @param 聚类结果     * @return 重置后的中心点集     */    public static double[] resetPoint(HashMap<Double, List<Integer>> map){        double[] tmp = new double[map.keySet().size()];        int index = 0;        for(double key: map.keySet()){            List<Integer> val = map.get(key);            double total = 0;            for (int i = 0; i < val.size(); i++) {                total += val.get(i);            }            if(val.size() == 0){                tmp[index++] = key;            }else{                key = total / val.size();                tmp[index++] = key;            }        }        return tmp;    }    /**     * Kmeans     * @param data 待聚类元素集合     * @param k 类别数目(中心点数)     * @param eps 收敛阈值     * @return 聚类结果     */    public static LinkedHashMap<Double, List<Integer>> kmeans(int[] data, int k, double eps){        double[] points = getPoints(k, data);        LinkedHashMap<Double, List<Integer>> tmp = null;        while(true){            tmp = culcate(points, data);            show(tmp);            double[] tpoints = resetPoint(tmp);            boolean flag = true;            for (int i = 0; i < tpoints.length; i++) {                if(Math.abs(points[i] - tpoints[i]) > eps){                    flag = false;                    break;                }            }            if(flag)break;            points = tpoints;        }        return null;    }    /**     * 显示聚类结果     * @param map     */    public static void show(LinkedHashMap<Double, List<Integer>> map){        for (double key: map.keySet()) {            System.out.println(String.format("%.2f", key) + "\t" + map.get(key));        }        System.out.println("=================================");    }    public static void main(String[] args) {        int k = 10;        double eps = 0.001;        int[] data = http://www.mamicode.com/{45, 26, 45, 65, 49, 27, 44, 26, 40, 63, 35, 63, 47, 24, 65, 62, 38, 8, 43, 65, 34, 36, 80, 34, 62, 60, 54, 66, 86, 47, 73, 15, 40, 7, 12, 35, 88, 5, 9, 20, 94, 28, 70, 78, 87, 78, 43, 80, 25, 88, 46, 21, 52, 49, 36, 64, 52, 59, 24, 56, 54, 10, 81, 78, 66, 28, 53, 48, 2, 89, 44, 79, 16, 55, 27, 6, 0, 46, 76, 87, 30, 90, 40, 51, 98, 97, 55, 72, 32, 79, 61, 39, 74, 58, 55, 58, 32, 4, 76, 19};        kmeans(data, k, eps);    }}

 

【算法】Kmeans