首页 > 代码库 > 【机器学习】手写数字识别算法

【机器学习】手写数字识别算法

1.数据准备

样本数据获取忽略,实际上就是将32*32的图片上数字格式化成一个向量,如下:

 

技术分享

本demo所有样本数据都是基于这种格式的

训练数据:将图片数据转成1*1024的数组,作为一个训练数据。

训练数据集:https://github.com/zimuqi/machine_Learning/tree/master/ch02/trainingDigits

测试数据集:https://github.com/zimuqi/machine_Learning/tree/master/ch02/testDigits

样本的文件名格式为:真实值_xxx.txt

转换代码:

1 def img2vector(filename):
2     returnVect=zeros((1,1024))
3     fr=open(filename)
4     for i in range(32):
5         lineStr=fr.readline()
6         for j in range(32):
7             returnVect[0,32*i+j]=int(lineStr[j])
8     return returnVect

 

2.测试算法

 1 def handwritingClassTest():
 2     hwLabels=[]    # 训练样本的标签数组
 3     traningFileList=listdir("trainingDigits")    # 获取所有的训练样本目录下的文件名
 4     m=len(traningFileList)
 5     traningMat=zeros((m,1024))    # 初始化训练样本数列
 6 
 7     for i in range(m):
 8         fileNameStr=traningFileList[i]    # 获取文件名
 9         fileStr=fileNameStr.split(".")[0]   
10         clasNumStr=int(fileStr.split("_")[0])    # 获取样本的实际值 放入标签数组
11         hwLabels.append(clasNumStr)
12         traningMat[i,:]=img2vector("trainingDigits/{}".format(fileNameStr))    # 将样本转化成1*1024的行放入训练样本数列
13 
14     testFileList=listdir("testDigits")    # 测试样本目录
15     error=0
16     mtest=len(testFileList)
17     for i in range(mtest):
18         fileNameStr=testFileList[i]
19         fileStr=fileNameStr.split(".")[0]
20         clasNumStr=int(fileStr.split("_")[0])
21         testMat=img2vector("testDigits/{}".format(fileNameStr))
22         res=classify(testMat,traningMat,hwLabels,3)     # 使用分类器分类
23         print "came bank with:{} the real anwser is:{}".format(clasNumStr,res)
24         if clasNumStr!=res:    # 对比与真实的结果 计算错误率
25             error+=1
26 
27     print "total:{}".format(mtest)
28     print "error:{}".format(error)
29     print "error:{}".format(float(error/mtest))

这个案例中 算法的识别率为:98.84%

classify是分类器 上上一篇文章中有写到,具体了解可以点击这里

 

【机器学习】手写数字识别算法