首页 > 代码库 > K-近邻算法(KNN)

K-近邻算法(KNN)

概述

简单地说,K-近邻算法(K-Nearest-Neighbors Classification)采用测量不同特征值之间的距离方法进行分类。

  • 优点:精度高、对异常值不敏感、无数据输入假定
  • 缺点:计算复杂度高、空间复杂度高
  • 使用数据范围:数值型和标称型
  • 工作原理:要确定测试样本属于哪一类,就寻找所有训练样本中与该测试样本“距离”最近的前K个样本,然后看这K个样本大部分属于哪一类,那么就认为这个测试样本也属于哪一类。简单的说就是让最相似的K个样本来投票决定。
  • 这里所说的距离,一般最常用的就是多维空间的欧式距离。这里的维度指特征维度,即样本有几个特征就属于几维。

近邻的距离度量表示法

  • 欧氏距离
  • 曼哈顿距离
  • 切比雪夫距离
  • 闵可夫斯基距离
  • 马氏距离
  • 巴氏距离

K值的选择

不要小看了这个K值选择问题,因为它对K近邻算法的结果会产生重大影响。如李航博士的一书「统计学习方法」上所说:

  • 如果选择较小的K值,就相当于用较小的领域中的训练实例进行预测,“学习”近似误差会减小,只有与输入实例较近或相似的训练实例才会对预测结果起作用,与此同时带来的问题是“学习”的估计误差会增大,换句话说,K值的减小就意味着整体模型变得复杂,容易发生过拟合;
  • 如果选择较大的K值,就相当于用较大领域中的训练实例进行预测,其优点是可以减少学习的估计误差,但缺点是学习的近似误差会增大。这时候,与输入实例较远(不相似的)训练实例也会对预测器作用,使预测发生错误,且K值的增大就意味着整体的模型变得简单。K=N,则完全不足取,因为此时无论输入实例是什么,都只是简单的预测它属于在训练实例中最多的累,模型过于简单,忽略了训练实例中大量有用信息。
  • 在实际应用中,K值一般取一个比较小的数值,例如采用交叉验证法(简单来说,就是一部分样本做训练集,一部分做测试集)来选择最优的K值。

KNN算法的优缺点

优点

  • 理论成熟,思想简单,即可以用来做分类也可以用来做回归;
  • 可用于非线性分类;
  • 训练时间复杂度为O(n);
  • 对数据没有假设,准确度高,对outlier不敏感。

缺点

  • 计算量大;
  • 样本不平衡问题(即有些类别的样本数量很多,而其他样本数量很少);
  • 需要大量内存。

具体应用案

案例一

这里有4组数据,且(1,1.1)和(1,1)定义为A类,(0,0)和(0,0.1)为B类。下面对(0.5,0.5)进行分类,判断其为A、B哪一类。
算法过程:
(1)计算已知类别数据集中的点与当前点之间的距离;
(2)按照距离递增次序排序;
(3)选取与当前点距离最小的K个点;
(4)确定前K个点所在类别的出现频率;
(5)返回前K个点出现频率最高的类别作为当前点的预测分类。

具体代码:

from numpy import *
import operator
from os import listdir

def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]
    diffMat = tile(inX, (dataSetSize,1)) - dataSet
    sqDiffMat = diffMat**2
    sqDistances = sqDiffMat.sum(axis=1)
    distances = sqDistances**0.5 #计算距离
    sortedDistIndicies = distances.argsort()     
    classCount={}          
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1#选择距离最小的k个点
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]#排序

def createDataSet():
    group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
    labels = [‘A‘,‘A‘,‘B‘,‘B‘]
    return group, labels

group,labels=createDataSet()
classify0([0.5,0.5],group,labels,3)

输出:

‘B‘

案例二:

  • 数据集介绍
    iris虹膜数据集,是一种植物,通常被用来作为机器学习的案例。该数据集中有150个实例,每个实例包含了花的四个维度,分别为萼片长度、宽度,花瓣长度、宽度。该花共有3种分类:iris setosa、iris versicolor、iris virginica。
  • 利用python的机器学习库sklearn:SKLearnExample.py
    该库中包含了很多机器学习的算法,如KNN。接下来我们介绍如何调用KNN的算法:
from sklearn import neighbors#导入包含KNN算法模块
from sklearn import datasets#导入数据集模块

knn = neighbors.KNeighborsClassifier()#调用分类器方法

iris = datasets.load_iris()#导入数据

print iris#分类规则:iris setosa、iris versicolor、iris virginica分别为用0、1、2表示

knn.fit(iris.data, iris.target)#建立模型

predictedLabel = knn.predict([[0.1, 0.2, 0.3, 0.4]])#预测新的对象属于哪一类

print predictedLabel

结果:

[0]

以上就是如何使用python里面的sklearn库来进行KNN算法的调用。
接下来介绍适合通过自己写程序来实现KNN的算法。

案例三

基本步骤:

  • 加载数据集;
  • 计算距离;
  • 返回最近的K个邻居;
  • 用“少数服从多数”的原则进行归类划分;
  • 测算预测值的准确率。
import csv#用于读取数据
import random
import math
import operator

#导入数据
def loadDataset(filename, split, trainingSet=[] , testSet=[]):#加载数据集
    with open(filename, ‘rb‘) as csvfile:#将filename导入为csv格式的文件。(‘rb’读写模式)
        lines = csv.reader(csvfile)#读取文件行数
        dataset = list(lines)#转化为list的数据结构
        for x in range(len(dataset)-1):
            for y in range(4):
                dataset[x][y] = float(dataset[x][y])
            if random.random() < split:#将数据分为两部分,分别加到训练集和测试集中
                trainingSet.append(dataset[x])
            else:
                testSet.append(dataset[x])

#计算距离
def euclideanDistance(instance1, instance2, length):#传入两个实例及维度
    distance = 0
    for x in range(length):#所有维度距离的平方和
        distance += pow((instance1[x] - instance2[x]), 2)
    return math.sqrt(distance)

#返回最近的K个label
def getNeighbors(trainingSet, testInstance, k):#testInstance测试集中的一个数据
    distances = []#定义一个空的容器
    length = len(testInstance)-1
    for x in range(len(trainingSet)):#计算测试集(一个)到每一个训练集的距离
        dist = euclideanDistance(testInstance, trainingSet[x], length)
        distances.append((trainingSet[x], dist))#将所有的距离放在定义好的空容器diastances
    distances.sort(key=operator.itemgetter(1))#距离从小到大排序
    neighbors = []
    for x in range(k):
        neighbors.append(distances[x][0])
    return neighbors#返回最近的k个邻居

#对邻居进行分类,找出类别最多的
def getResponse(neighbors):
    classVotes = {}
    for x in range(len(neighbors)):
        response = neighbors[x][-1]
        if response in classVotes:
            classVotes[response] += 1
        else:
            classVotes[response] = 1
    sortedVotes = sorted(classVotes.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedVotes[0][0]

#计算正确率
def getAccuracy(testSet, predictions):
    correct = 0
    for x in range(len(testSet)):
        if testSet[x][-1] == predictions[x]:
            correct += 1
    return (correct/float(len(testSet))) * 100.0

def main():
    # prepare data
    trainingSet=[]#创建两个空的测试集和训练集
    testSet=[]
    split = 0.67#将2/3的数据划分为训练集,1/3划分为测试集
    loadDataset(r‘/home/duxu/exercise/iris.csv‘, split, trainingSet, testSet)
    print ‘Train set: ‘ + repr(len(trainingSet))
    print ‘Test set: ‘ + repr(len(testSet))
    # generate predictions
    predictions=[]
    k = 3
    for x in range(len(testSet)):
        neighbors = getNeighbors(trainingSet, testSet[x], k)
        result = getResponse(neighbors)
        predictions.append(result)
        print(‘> predicted=‘ + repr(result) + ‘, actual=‘ + repr(testSet[x][-1]))
    accuracy = getAccuracy(testSet, predictions)
    print(‘Accuracy: ‘ + repr(accuracy) + ‘%‘)
main()

输出:

Train set: 100
Test set: 49
> predicted=‘Iris-setosa‘, actual=‘Iris-setosa‘
> predicted=‘Iris-setosa‘, actual=‘Iris-setosa‘
> predicted=‘Iris-setosa‘, actual=‘Iris-setosa‘
> predicted=‘Iris-setosa‘, actual=‘Iris-setosa‘
> predicted=‘Iris-setosa‘, actual=‘Iris-setosa‘
> predicted=‘Iris-setosa‘, actual=‘Iris-setosa‘
> predicted=‘Iris-setosa‘, actual=‘Iris-setosa‘
> predicted=‘Iris-setosa‘, actual=‘Iris-setosa‘
> predicted=‘Iris-setosa‘, actual=‘Iris-setosa‘
> predicted=‘Iris-setosa‘, actual=‘Iris-setosa‘
> predicted=‘Iris-setosa‘, actual=‘Iris-setosa‘
> predicted=‘Iris-setosa‘, actual=‘Iris-setosa‘
> predicted=‘Iris-versicolor‘, actual=‘Iris-versicolor‘
> predicted=‘Iris-versicolor‘, actual=‘Iris-versicolor‘
> predicted=‘Iris-versicolor‘, actual=‘Iris-versicolor‘
> predicted=‘Iris-versicolor‘, actual=‘Iris-versicolor‘
> predicted=‘Iris-versicolor‘, actual=‘Iris-versicolor‘
> predicted=‘Iris-versicolor‘, actual=‘Iris-versicolor‘
> predicted=‘Iris-versicolor‘, actual=‘Iris-versicolor‘
> predicted=‘Iris-versicolor‘, actual=‘Iris-versicolor‘
> predicted=‘Iris-versicolor‘, actual=‘Iris-versicolor‘
> predicted=‘Iris-versicolor‘, actual=‘Iris-versicolor‘
> predicted=‘Iris-versicolor‘, actual=‘Iris-versicolor‘
> predicted=‘Iris-versicolor‘, actual=‘Iris-versicolor‘
> predicted=‘Iris-versicolor‘, actual=‘Iris-versicolor‘
> predicted=‘Iris-versicolor‘, actual=‘Iris-versicolor‘
> predicted=‘Iris-versicolor‘, actual=‘Iris-versicolor‘
> predicted=‘Iris-versicolor‘, actual=‘Iris-versicolor‘
> predicted=‘Iris-virginica‘, actual=‘Iris-versicolor‘
> predicted=‘Iris-versicolor‘, actual=‘Iris-versicolor‘
> predicted=‘Iris-versicolor‘, actual=‘Iris-versicolor‘
> predicted=‘Iris-virginica‘, actual=‘Iris-virginica‘
> predicted=‘Iris-virginica‘, actual=‘Iris-virginica‘
> predicted=‘Iris-virginica‘, actual=‘Iris-virginica‘
> predicted=‘Iris-virginica‘, actual=‘Iris-virginica‘
> predicted=‘Iris-virginica‘, actual=‘Iris-virginica‘
> predicted=‘Iris-virginica‘, actual=‘Iris-virginica‘
> predicted=‘Iris-virginica‘, actual=‘Iris-virginica‘
> predicted=‘Iris-virginica‘, actual=‘Iris-virginica‘
> predicted=‘Iris-virginica‘, actual=‘Iris-virginica‘
> predicted=‘Iris-virginica‘, actual=‘Iris-virginica‘
> predicted=‘Iris-virginica‘, actual=‘Iris-virginica‘
> predicted=‘Iris-virginica‘, actual=‘Iris-virginica‘
> predicted=‘Iris-virginica‘, actual=‘Iris-virginica‘
> predicted=‘Iris-virginica‘, actual=‘Iris-virginica‘
> predicted=‘Iris-virginica‘, actual=‘Iris-virginica‘
> predicted=‘Iris-virginica‘, actual=‘Iris-virginica‘
> predicted=‘Iris-virginica‘, actual=‘Iris-virginica‘
> predicted=‘Iris-virginica‘, actual=‘Iris-virginica‘
Accuracy: 97.95918367346938%

从结果看出:训练集有100个实例,测试集有50个实例;接着打印出来了测试集的预测结果和实际分类;最后计算出了预测的正确率约为98%,比较理想。

<script type="text/javascript"> $(function () { $(‘pre.prettyprint code‘).each(function () { var lines = $(this).text().split(‘\n‘).length; var $numbering = $(‘
    ‘).addClass(‘pre-numbering‘).hide(); $(this).addClass(‘has-numbering‘).parent().append($numbering); for (i = 1; i <= lines; i++) { $numbering.append($(‘
  • ‘).text(i)); }; $numbering.fadeIn(1700); }); }); </script>

    K-近邻算法(KNN)