首页 > 代码库 > 『TensorFlow』第四弹_classification分类学习_拨云见日

『TensorFlow』第四弹_classification分类学习_拨云见日

本节是以mnist手写数字识别为例,实现了分类网络(又双叒叕见mnist......) ,不过这个分类器实现很底层,能学到不少tensorflow的用法习惯,语言不好描述,值得注意的地方使用大箭头标注了(原来tensorflow的向前传播的调用是这么实现的...之前自己好蠢):

 1 import tensorflow as tf
 2 from tensorflow.examples.tutorials.mnist import input_data
 3 
 4 ‘‘‘数据下载‘‘‘
 5 # one_hot标签:[0,0...1,0...]
 6 mnist = input_data.read_data_sets(Mnist_data,one_hot=True)
 7 
 8 ‘‘‘生成层函数‘‘‘
 9 def add_layer(input,in_size,out_size,n_layer=layer,activation_function=None):
10     layer_name = layer%s % n_layer
11     with tf.name_scope(layer_name):
12         with tf.name_scope(weights):
13             Weights = tf.Variable(tf.random_normal([in_size,out_size]),name=W)
14             tf.summary.histogram(layer_name+/weights,Weights)
15         with tf.name_scope(biases):
16             biases = tf.Variable(tf.zeros([1,out_size]) + 0.1)
17             tf.summary.histogram(layer_name + /biases, biases)
18         with tf.name_scope(Wx_plus_b):
19             # [in]*[[out]*in]+[out]
20             Wx_plus_b = tf.matmul(input,Weights) + biases
21         if activation_function is None:
22             outputs = Wx_plus_b
23         else:
24             outputs = activation_function(Wx_plus_b)
25         tf.summary.histogram(layer_name + /outputs, outputs)
26         return outputs
27 
28 ‘‘‘准确率‘‘‘
29 def compute_accuracy(v_xs,v_ys):
30     global prediction
31     y_pre = sess.run(prediction,feed_dict={xs:v_xs}) #<------------------------
32     # tf.equal()对比预测值的索引和实际label的索引是否一样,一样返回True,不一样返回False
33     correct_prediction = tf.equal(tf.argmax(y_pre,1),tf.argmax(v_ys,1))
34     # 将pred即True或False转换为1或0,并对所有的判断结果求均值
35     accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
36     result = sess.run(accuracy,feed_dict={xs:v_xs,ys:v_ys}) #<------------------------
37     return result
38 
39 
40 ‘‘‘占位符‘‘‘
41 xs = tf.placeholder(tf.float32,[None,784])
42 ys = tf.placeholder(tf.float32,[None,10])
43 
44 ‘‘‘添加层‘‘‘
45 prediction = add_layer(xs,784,10,activation_function=tf.nn.softmax)  # softmax:分类用激活函数
46 
47 ‘‘‘计算loss‘‘‘
48 # 交叉熵损失函数,参数分别为预测值pred和实际label值y,reduce_mean为求batch平均loss
49 cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys*tf.log(prediction),reduction_indices=[1])) #<------------------------
50 train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) # 梯度下降优化器
51 
52 ‘‘‘会话生成‘‘‘
53 with tf.Session() as sess:
54     sess.run(tf.global_variables_initializer())
55     for i in range(1000):
56         batch_xs,batch_ys = mnist.train.next_batch(100) # 逐个batch的去取数据
57         sess.run(train_step,feed_dict={xs:batch_xs,ys:batch_ys})
58         if i % 50 == 0:
59             print(compute_accuracy(mnist.test.images,mnist.test.labels))

调试&深入理解:

添加print(y_pre,v_ys):

1 ‘‘‘准确率‘‘‘
2 def compute_accuracy(v_xs,v_ys):
3     global prediction
4     y_pre = sess.run(prediction,feed_dict={xs:v_xs}) 
5     correct_prediction = tf.equal(tf.argmax(y_pre,1),tf.argmax(v_ys,1))
6     accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
7     result = sess.run(accuracy,feed_dict={xs:v_xs,ys:v_ys}) 
8     print(y_pre,v_ys) #<------------------------
9     return result

 

 

 可以看出来网络传出batch*10的矩阵:

[[  4.28935193e-10   1.35160810e-14   1.13027321e-08 ...,   9.99981165e-01
    2.55498271e-08   8.07754896e-06]
 [  1.76165886e-05   4.70404247e-16   9.99964714e-01 ...,   1.73039527e-16
    1.82767259e-08   7.25102686e-11]
 [  2.34255353e-08   9.98970389e-01   1.14979991e-06 ...,   1.53422152e-04
    5.52863348e-04   1.47934190e-06]
 ..., 
 [  5.98009420e-11   2.45282070e-08   4.45539960e-08 ...,   6.09112252e-03
    4.90780873e-03   1.23983808e-01]
 [  5.55463521e-05   6.45160526e-02   2.27339001e-06 ...,   1.37123177e-04
    6.74951375e-01   1.96006240e-05]
 [  2.16443485e-08   7.36915445e-15   4.69437926e-11 ...,   3.93592785e-17
    3.30396951e-13   3.78093485e-15]] [[ 0.  0.  0. ...,  1.  0.  0.]
 [ 0.  0.  1. ...,  0.  0.  0.]
 [ 0.  1.  0. ...,  0.  0.  0.]
 ..., 
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]]

 一直很好奇tf.argmax参数的1是什么,添加print(tf.argmax(v_ys,1).eval())后输出:

[7 2 1 ..., 4 5 6]

 

 查看源代码

1 def argmax(input, axis=None, name=None, dimension=None):

 

也就是batch*1的预测值,1指维度,渍渍渍,还以为会是个很高深的函数,感觉好失望......

 

 

 

整体输出如下:

/home/hellcat/anaconda2/envs/python3_6/bin/python /home/hellcat/PycharmProjects/data_analysis/TensorFlow/classification.py
Extracting Mnist_data/train-images-idx3-ubyte.gz
Extracting Mnist_data/train-labels-idx1-ubyte.gz
Extracting Mnist_data/t10k-images-idx3-ubyte.gz
Extracting Mnist_data/t10k-labels-idx1-ubyte.gz
2017-05-19 10:46:49.230735: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn‘t compiled to use SSE4.1 instructions, but these are available on your machine and could speed up CPU computations.
2017-05-19 10:46:49.230758: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn‘t compiled to use SSE4.2 instructions, but these are available on your machine and could speed up CPU computations.
2017-05-19 10:46:49.230764: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn‘t compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
0.0894
0.6471
0.7355
0.7808
0.8029
0.8222
0.8311
0.8393
0.8458
0.8493
0.8508
0.8589
0.8531
0.8592
0.864
0.8648
0.8694
0.8696
0.8717
0.8742

Process finished with exit code 0

 

 

总结:

虽然之前我也成功的运行过甚至抄写过更为复杂的生经网络结构,但是自己实际的调试了一个这种比较完备的神经网络,核实了每一步的输出之后,对于框架整体还是有了更为深刻的理解,很多地方都豁然开朗,拨云见日。

『TensorFlow』第四弹_classification分类学习_拨云见日