首页 > 代码库 > K-最近邻算法

K-最近邻算法

介绍

KNN算法全名为k-Nearest Neighbor,就是K最近邻的意思。KNN也是一种分类算法。但是与之前说的决策树分类算法相比,这个算法算是最简单的一个了。算法的主要过程为:

1、给定一个训练集数据,每个训练集数据都是已经分好类的。
2、设定一个初始的测试数据a,计算a到训练集所有数据的欧几里得距离,并排序。                       

3、选出训练集中离a距离最近的K个训练集数据。

4、比较k个训练集数据,选出里面出现最多的分类类型,此分类类型即为最终测试数据a的分类。

下面百度百科上的一张简图:

技术分享

KNN算法实现

首先测试数据需要2块,1个是训练集数据,就是已经分好类的数据,比如上图中的非绿色的点。还有一个是测试数据,就是上面的绿点,当然这里的测试数据不会是一个,而是一组。这里的数据与数据之间的距离用数据的特征向量做计算,特征向量可以是多维度的。通过计算特征向量与特征向量之间的欧几里得距离来推算相似度。定义训练集数据trainInput.txt:

a 1 2 3 4 5 
b 5 4 3 2 1 
c 3 3 3 3 3 
d -3 -3 -3 -3 -3 
a 1 2 3 4 4 
b 4 4 3 2 1 
c 3 3 3 2 4 
d 0 0 1 1 -2 
待测试数据testInput,只有特征向量值:

1 2 3 2 4 
2 3 4 2 1 
8 7 2 3 5 
-3 -2 2 4 0 
-4 -4 -4 -4 -4 
1 2 3 4 4 
4 4 3 2 1 
3 3 3 2 4 
0 0 1 1 -2 
下面是主程序:

package DataMing_KNN;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;

import org.apache.activemq.filter.ComparisonExpression;

/**
 * k最近邻算法工具类
 * 
 * @author lyq
 * 
 */
public class KNNTool {
	// 为4个类别设置权重,默认权重比一致
	public int[] classWeightArray = new int[] { 1, 1, 1, 1 };
	// 测试数据地址
	private String testDataPath;
	// 训练集数据地址
	private String trainDataPath;
	// 分类的不同类型
	private ArrayList<String> classTypes;
	// 结果数据
	private ArrayList<Sample> resultSamples;
	// 训练集数据列表容器
	private ArrayList<Sample> trainSamples;
	// 训练集数据
	private String[][] trainData;
	// 测试集数据
	private String[][] testData;

	public KNNTool(String trainDataPath, String testDataPath) {
		this.trainDataPath = trainDataPath;
		this.testDataPath = testDataPath;
		readDataFormFile();
	}

	/**
	 * 从文件中阅读测试数和训练数据集
	 */
	private void readDataFormFile() {
		ArrayList<String[]> tempArray;

		tempArray = fileDataToArray(trainDataPath);
		trainData = http://www.mamicode.com/new String[tempArray.size()][];>Sample样本数据类:

package DataMing_KNN;

/**
 * 样本数据类
 * 
 * @author lyq
 * 
 */
public class Sample implements Comparable<Sample>{
	// 样本数据的分类名称
	private String className;
	// 样本数据的特征向量
	private String[] features;
	//测试样本之间的间距值,以此做排序
	private Integer distance;
	
	public Sample(String[] features){
		this.features = features;
	}
	
	public Sample(String className, String[] features){
		this.className = className;
		this.features = features;
	}

	public String getClassName() {
		return className;
	}

	public void setClassName(String className) {
		this.className = className;
	}

	public String[] getFeatures() {
		return features;
	}

	public void setFeatures(String[] features) {
		this.features = features;
	}

	public Integer getDistance() {
		return distance;
	}

	public void setDistance(int distance) {
		this.distance = distance;
	}

	@Override
	public int compareTo(Sample o) {
		// TODO Auto-generated method stub
		return this.getDistance().compareTo(o.getDistance());
	}
	
}
测试场景类:

/**
 * k最近邻算法场景类型
 * @author lyq
 *
 */
public class Client {
	public static void main(String[] args){
		String trainDataPath = "C:\\Users\\lyq\\Desktop\\icon\\trainInput.txt";
		String testDataPath = "C:\\Users\\lyq\\Desktop\\icon\\testinput.txt";
		
		KNNTool tool = new KNNTool(trainDataPath, testDataPath);
		tool.knnCompute(3);
		
	}
	


}
执行的结果为:

测试数据特征:1 2 3 2 4 分类:a
测试数据特征:2 3 4 2 1 分类:c
测试数据特征:8 7 2 3 5 分类:b
测试数据特征:-3 -2 2 4 0 分类:a
测试数据特征:-4 -4 -4 -4 -4 分类:d
测试数据特征:1 2 3 4 4 分类:a
测试数据特征:4 4 3 2 1 分类:b
测试数据特征:3 3 3 2 4 分类:c
测试数据特征:0 0 1 1 -2 分类:d

程序的输出结果如上所示,如果不相信的话可以自己动手计算进行验证。

KNN算法的注意点:

1、knn算法的训练集数据必须要相对公平,各个类型的数据数量应该是平均的,否则当A数据由1000个B数据由100个,到时无论如何A数据的样本还是占优的。

2、knn算法如果纯粹凭借分类的多少做判断,还是可以继续优化的,比如近的数据的权重可以设大,最后根据所有的类型权重和进行比较,而不是单纯的凭借数量。

3、knn算法的缺点是计算量大,这个从程序中也应该看得出来,里面每个测试数据都要计算到所有的训练集数据之间的欧式距离,时间复杂度就已经为O(n*n),如果真实数据的n非常大,这个算法的开销的确态度,所以KNN不适合大规模数据量的分类。

KNN算法编码时遇到的困难:

按理来说这么简单的KNN算法本应该是没有多少的难度,但是在多欧式距离的排序上被深深的坑了一段时间,本人起初用Collections.sort(list)的方式进行按距离排序,也把Sample类实现了Compareable接口,但是排序就是不变,最后才知道,distance的int类型要改为Integer引用类型,在compareTo重载方法中调用distance的.CompareTo()方法就成功了,这个小细节平时没注意,难道属性的比较最终一定要调用到引用类型的compareTo()方法?这个小问题竟然花费了我一段时间,最后仔细的比较了一下网上的例子最后才发现......

K-最近邻算法