首页 > 代码库 > kNN算法python实现和简单数字识别

kNN算法python实现和简单数字识别

kNN算法

算法优缺点:

  • 优点:精度高、对异常值不敏感、无输入数据假定
  • 缺点:时间复杂度和空间复杂度都很高
  • 适用数据范围:数值型和标称型

算法的思路:

KNN算法(全称K最近邻算法),算法的思想很简单,简单的说就是物以类聚,也就是说我们从一堆已知的训练集中找出k个与目标最靠近的,然后看他们中最多的分类是哪个,就以这个为依据分类。 

函数解析:

库函数

  • tile()

    tile(A,n)就是将A重复n次

a = np.array([0, 1, 2])np.tile(a, 2)array([0, 1, 2, 0, 1, 2])np.tile(a, (2, 2))array([[0, 1, 2, 0, 1, 2],[0, 1, 2, 0, 1, 2]])np.tile(a, (2, 1, 2))array([[[0, 1, 2, 0, 1, 2]],[[0, 1, 2, 0, 1, 2]]])b = np.array([[1, 2], [3, 4]])np.tile(b, 2)array([[1, 2, 1, 2],[3, 4, 3, 4]])np.tile(b, (2, 1))array([[1, 2],[3, 4],[1, 2],[3, 4]])`

自己实现的函数

createDataSet()生成测试数组
kNNclassify(inputX, dataSet, labels, k)分类函数

  • inputX 输入的参数
  • dataSet 训练集
  • labels 训练集的标号
  • k 最近邻的数目
    1.  1 #coding=utf-8 2 from numpy import * 3 import operator 4  5 def createDataSet(): 6     group = array([[1.0, 0.9], [1.0, 1.0], [0.1, 0.2], [0.0, 0.1]]) 7     labels = [A,A,B,B] 8     return group,labels 9 #inputX表示输入向量(也就是我们要判断它属于哪一类的)10 #dataSet表示训练样本11 #label表示训练样本的标签12 #k是最近邻的参数,选最近k个13 def kNNclassify(inputX, dataSet, labels, k):14     dataSetSize = dataSet.shape[0]#计算有几个训练数据15     #开始计算欧几里得距离16     diffMat = tile(inputX, (dataSetSize,1)) - dataSet17     18     sqDiffMat = diffMat ** 219     sqDistances = sqDiffMat.sum(axis=1)#矩阵每一行向量相加20     distances = sqDistances ** 0.521     #欧几里得距离计算完毕22     sortedDistance = distances.argsort()23     classCount = {}24     for i in xrange(k):25         voteLabel = labels[sortedDistance[i]]26         classCount[voteLabel] = classCount.get(voteLabel,0) + 127     res = max(classCount)28     return res29 30 def main():31     group,labels = createDataSet()32     t = kNNclassify([0,0],group,labels,3)33     print t34     35 if __name__==__main__:36     main()37             

       


kNN应用实例

手写识别系统的实现

数据集:

两个数据集:training和test。分类的标号在文件名中。像素32*32的。数据大概这个样子:

方法:

kNN的使用,不过这个距离算起来比较复杂(1024个特征),主要是要处理如何读取数据这个问题的,比较方面直接调用就可以了。

速度:

速度还是比较慢的,这里数据集是:training 2000+,test 900+(i5的CPU)

k=3的时候要32s+

  1.  1 #coding=utf-8 2 from numpy import * 3 import operator 4 import os 5 import time 6  7 def createDataSet(): 8     group = array([[1.0, 0.9], [1.0, 1.0], [0.1, 0.2], [0.0, 0.1]]) 9     labels = [A,A,B,B]10     return group,labels11 #inputX表示输入向量(也就是我们要判断它属于哪一类的)12 #dataSet表示训练样本13 #label表示训练样本的标签14 #k是最近邻的参数,选最近k个15 def kNNclassify(inputX, dataSet, labels, k):16     dataSetSize = dataSet.shape[0]#计算有几个训练数据17     #开始计算欧几里得距离18     diffMat = tile(inputX, (dataSetSize,1)) - dataSet19     #diffMat = inputX.repeat(dataSetSize, aixs=1) - dataSet20     sqDiffMat = diffMat ** 221     sqDistances = sqDiffMat.sum(axis=1)#矩阵每一行向量相加22     distances = sqDistances ** 0.523     #欧几里得距离计算完毕24     sortedDistance = distances.argsort()25     classCount = {}26     for i in xrange(k):27         voteLabel = labels[sortedDistance[i]]28         classCount[voteLabel] = classCount.get(voteLabel,0) + 129     res = max(classCount)30     return res31 32 def img2vec(filename):33     returnVec = zeros((1,1024))34     fr = open(filename)35     for i in range(32):36         lineStr = fr.readline()37         for j in range(32):38             returnVec[0,32*i+j] = int(lineStr[j])39     return returnVec40     41 def handwritingClassTest(trainingFloder,testFloder,K):42     hwLabels = []43     trainingFileList = os.listdir(trainingFloder)44     m = len(trainingFileList)45     trainingMat = zeros((m,1024))46     for i in range(m):47         fileName = trainingFileList[i]48         fileStr = fileName.split(.)[0]49         classNumStr = int(fileStr.split(_)[0])50         hwLabels.append(classNumStr)51         trainingMat[i,:] = img2vec(trainingFloder+/+fileName)52     testFileList = os.listdir(testFloder)53     errorCount = 0.054     mTest = len(testFileList)55     for i in range(mTest):56         fileName = testFileList[i]57         fileStr = fileName.split(.)[0]58         classNumStr = int(fileStr.split(_)[0])59         vectorUnderTest = img2vec(testFloder+/+fileName)60         classifierResult = kNNclassify(vectorUnderTest, trainingMat, hwLabels, K)61         #print classifierResult,‘ ‘,classNumStr62         if classifierResult != classNumStr:63             errorCount +=164     print tatal error ,errorCount65     print error rate,errorCount/mTest66         67 def main():68     t1 = time.clock()69     handwritingClassTest(trainingDigits,testDigits,3)70     t2 = time.clock()71     print execute ,t2-t172 if __name__==__main__:73     main()74             

     


 



来自为知笔记(Wiz)



kNN算法python实现和简单数字识别