第6章 识别手写字体

第6章 识别手写字体















    MNIST网站: http://yann.lecun.com/exdb/mnist/










package com.luoxq.ann;

import java.util.Arrays;
import java.util.Random;

public class MnistTest {

public static void main(String... args) {
int[] shape = {28 * 28, 10};
NeuralNetwork nn = new NeuralNetwork(shape);
Mnist mnist = new Mnist();
System.out.println("Shape: " + Arrays.toString(shape));
System.out.println("Initial correct rate: " + test(nn, mnist));
int epochs = 1000;
double rate = 0.5;
System.out.println("Learning rate: " + rate);
long time = System.currentTimeMillis();
Mnist.Data[] data = http://www.mamicode.com/mnist.getTrainingSlice(0, 60000);
for (int epoch = 1; epoch <= epochs; epoch++) {
for (int sample = 0; sample < data.length; sample++) {
nn.train(data[sample].input, data[sample].output, rate);
long seconds = (System.currentTimeMillis() - time) / 1000;
System.out.println(epoch + ", " + seconds + ", " +
test(nn, mnist));

private static int test(NeuralNetwork nn, Mnist mnist) {
int correct = 0;
Mnist.Data[] data = http://www.mamicode.com/mnist.getTestSlice(0, 10000);
for (int sample = 0; sample < data.length; sample++) {
if (max(nn.f(data[sample].input)) == data[sample].label) {
return correct;

private static int max(double[] d) {
double max = d[0];
int idx = 0;
for (int i = 1; i < d.length; i++) {
if (max < d[i]) {
max = d[i];
idx = i;
return idx;


    我们先用一个由10个神经元组成的单层神经网络试试看。结果出乎意料地好：我们很快就获得了超过90%的正确率。单层网络几乎只是对每个数字的像素分布做简单统计，能获得如此高的识别率，还是很神奇的。但正确率达到90%之后，继续训练已经没有多大效果，网络达到了饱和。我们必须换一种方法了。下面的输出中，每行依次为：训练轮次、累计耗时（秒）、10000个测试样本中识别正确的数量。


Shape: [784, 10]

Initial correct rate: 1373

Learning rate: 0.5



1, 4, 6429

2, 8, 7663

3, 13, 8963

4, 17, 9029

5, 22, 9016

6, 27, 9062

7, 31, 9063

8, 36, 9066

9, 41, 9072

10, 45, 9057

11, 50, 9084

12, 55, 9072

13, 61, 9062

14, 66, 9050

15, 70, 9077

16, 75, 9052

17, 79, 9068

18, 84, 9055

19, 88, 9060

20, 93, 9064





Shape: [784, 50, 10]

Initial correct rate: 944

Learning rate: 1.0



1, 24, 7459

2, 59, 9232

3, 99, 9313

4, 131, 9379

5, 153, 9412

6, 176, 9443

7, 200, 9412

8, 226, 9447

9, 248, 9462

10, 269, 9461

11, 290, 9465

12, 314, 9493

13, 343, 9477

14, 368, 9499

15, 392, 9502

16, 420, 9509

17, 447, 9482

18, 472, 9508

19, 496, 9491

20, 518, 9536

21, 545, 9523

22, 569, 9549

23, 593, 9527

24, 618, 9527

25, 643, 9520

26, 667, 9513

27, 689, 9507

28, 712, 9527

29, 734, 9501

30, 758, 9521

31, 781, 9508

32, 804, 9534

33, 827, 9534

34, 850, 9550

35, 875, 9569








package com.luoxq.ann;

import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.util.Random;
import java.util.zip.GZIPInputStream;

/**
 * Loads the MNIST handwritten-digit data set from gzipped IDX files (resolved against the
 * working directory) and exposes the 60,000 training and 10,000 test samples.
 *
 * <p>Created by luoxq on 17/4/15.
 */
public class Mnist {

    /** One labelled MNIST sample in both raw and network-ready form. */
    static class Data {
        /** Raw 28x28 pixel bytes in row-major order (index = y * 28 + x), unsigned values in signed bytes. */
        public byte[] data;
        /** The digit label, read from the label file as an unsigned byte. */
        public int label;
        /** The pixels scaled to [0, 1]. */
        public double[] input;
        /** One-hot encoding of {@link #label}, length 10. */
        public double[] output;
    }

    /**
     * Writes 20 randomly chosen training images as PNG files named {@code <i>_<label>.png} into
     * the working directory, skipping files that already exist.
     */
    public static void main(String... args) throws Exception {
        Mnist mnist = new Mnist();
        System.out.println("Data loaded.");
        Random rand = new Random(System.nanoTime());
        for (int i = 0; i < 20; i++) {
            int idx = rand.nextInt(mnist.trainingSet.length);
            Data d = mnist.getTrainingData(idx);
            BufferedImage img = new BufferedImage(28, 28, BufferedImage.TYPE_INT_RGB);
            for (int x = 0; x < 28; x++) {
                for (int y = 0; y < 28; y++) {
                    img.setRGB(x, y, toRgb(d.data[y * 28 + x]));
                }
            }
            File output = new File(i + "_" + d.label + ".png");
            if (!output.exists()) {
                ImageIO.write(img, "png", output);
            }
        }
    }

    /**
     * Converts a grayscale pixel to an inverted gray RGB value, so that high pixel values render
     * dark on a light background.
     *
     * @param bb pixel value, interpreted as unsigned 0..255
     * @return a 0xRRGGBB value with R = G = B = 255 - pixel
     */
    static int toRgb(byte bb) {
        int b = (255 - (0xff & bb));
        return (b << 16 | b << 8 | b) & 0xffffff;
    }

    Data[] trainingSet;
    Data[] testSet;

    /** Creates an instance and immediately loads both data sets from disk via {@link #load()}. */
    public Mnist() {
        load();
    }

    /**
     * Creates an instance over already-loaded data without touching the file system (useful for
     * tests or custom subsets).
     */
    Mnist(Data[] trainingSet, Data[] testSet) {
        this.trainingSet = trainingSet;
        this.testSet = testSet;
    }

    /** Shuffles the training set in place with a fresh, unseeded {@link Random}. */
    public void shuffle() {
        shuffle(new Random());
    }

    /**
     * Shuffles the training set in place using the Fisher–Yates algorithm, so every permutation is
     * equally likely and every sample is kept exactly once.
     *
     * @param rand source of randomness; pass a seeded instance for a reproducible order
     */
    public void shuffle(Random rand) {
        for (int i = trainingSet.length - 1; i > 0; i--) {
            int j = rand.nextInt(i + 1);
            Data tmp = trainingSet[i];
            trainingSet[i] = trainingSet[j];
            trainingSet[j] = tmp;
        }
    }

    /** Returns the training sample at {@code idx}. */
    public Data getTrainingData(int idx) {
        return trainingSet[idx];
    }

    /** Returns a new array holding {@code count} training samples starting at {@code start}. */
    public Data[] getTrainingSlice(int start, int count) {
        Data[] ret = new Data[count];
        System.arraycopy(trainingSet, start, ret, 0, count);
        return ret;
    }

    /** Returns the test sample at {@code idx}. */
    public Data getTestData(int idx) {
        return testSet[idx];
    }

    /** Returns a new array holding {@code count} test samples starting at {@code start}. */
    public Data[] getTestSlice(int start, int count) {
        Data[] ret = new Data[count];
        System.arraycopy(testSet, start, ret, 0, count);
        return ret;
    }

    /**
     * Loads the training and test sets from the four standard MNIST .gz files.
     *
     * @throws RuntimeException if a file cannot be read or the set sizes are not 60000/10000
     */
    public void load() {
        trainingSet = load("train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz");
        testSet = load("t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz");
        if (trainingSet.length != 60000 || testSet.length != 10000) {
            throw new RuntimeException("Unexpected training/test data size: " + trainingSet.length + "/" + testSet.length);
        }
    }

    /** Reads one image file and its matching label file and pairs them into {@link Data} samples. */
    private Data[] load(String imgFile, String labelFile) {
        byte[][] images = loadImages(imgFile);
        byte[] labels = loadLabels(labelFile);
        if (images.length != labels.length) {
            throw new RuntimeException("Images and label doesn‘t match: " + imgFile + " " + labelFile);
        }
        int len = images.length;
        Data[] data = new Data[len];
        for (int i = 0; i < len; i++) {
            data[i] = new Data();
            data[i].data = images[i];
            data[i].label = 0xff & labels[i];
            data[i].input = dataToInput(images[i]);
            data[i].output = labelToOutput(labels[i]);
        }
        return data;
    }

    /** Returns a length-10 one-hot vector for {@code label} (read as unsigned). */
    private double[] labelToOutput(byte label) {
        double[] o = new double[10];
        // Mask like Data.label so a byte >= 128 can never become a negative index.
        o[0xff & label] = 1;
        return o;
    }

    /** Scales unsigned pixel bytes to doubles in [0, 1]. */
    private double[] dataToInput(byte[] b) {
        double[] d = new double[b.length];
        for (int i = 0; i < b.length; i++) {
            d[i] = (b[i] & 0xff) / 255.0;
        }
        return d;
    }

    /**
     * Reads a gzipped IDX image file. The header is magic 0x00000803, image count, rows and cols,
     * each a big-endian int; the pixel bytes follow.
     *
     * @return one row-major byte array of rows*cols pixels per image
     * @throws RuntimeException wrapping any read or format error, with the file name
     */
    private byte[][] loadImages(String imgFile) {
        try (DataInputStream in = new DataInputStream(new GZIPInputStream(new FileInputStream(imgFile)))) {
            int magic = in.readInt();
            if (magic != 0x00000803) {
                throw new RuntimeException("wrong magic: 0x" + Integer.toHexString(magic));
            }
            int count = in.readInt();
            int rows = in.readInt();
            int cols = in.readInt();
            if (rows != 28 || cols != 28) {
                throw new RuntimeException("Unexpected row and col count: " + rows + "x" + cols);
            }
            byte[][] data = new byte[count][rows * cols];
            for (int i = 0; i < count; i++) {
                // readFully blocks until the buffer is filled; a plain read() on a
                // GZIPInputStream may return fewer bytes.
                in.readFully(data[i]);
            }
            return data;
        } catch (Exception ex) {
            throw new RuntimeException("Failed to read file: " + imgFile, ex);
        }
    }

    /**
     * Reads a gzipped IDX label file. The header is magic 0x00000801 and label count, each a
     * big-endian int; one byte per label follows.
     *
     * @throws RuntimeException wrapping any read or format error, with the file name
     */
    private byte[] loadLabels(String labelFile) {
        try (DataInputStream in = new DataInputStream(new GZIPInputStream(new FileInputStream(labelFile)))) {
            int magic = in.readInt();
            if (magic != 0x00000801) {
                throw new RuntimeException("wrong magic: 0x" + Integer.toHexString(magic));
            }
            int count = in.readInt();
            byte[] data = new byte[count];
            in.readFully(data);
            return data;
        } catch (Exception ex) {
            throw new RuntimeException("Failed to read file: " + labelFile, ex);
        }
    }
}






