首页 > 代码库 > TFboy养成记 CNN
TFboy养成记 CNN
# -*- coding: utf-8 -*-
"""
Created on Sun Jul 2 19:59:43 2017

@author: Administrator

Two convolutional layers + one fully connected layer trained on MNIST
with the TensorFlow 1.x graph/session API.
"""
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data


def compute_accuracy(v_xs, v_ys):
    """Return the fraction of samples whose predicted class matches the label.

    Runs the module-level ``prediction`` op in the module-level ``sess``
    (so it must be called while that session is open) with dropout disabled
    (``keep_prob`` = 1.0). The argmax comparison is done in NumPy so that no
    new ops are added to the graph on every call.

    Args:
        v_xs: Flattened images fed to the ``xs`` placeholder ([N, 28*28]).
        v_ys: One-hot labels ([N, 10]).

    Returns:
        Accuracy as a Python float in [0, 1].
    """
    y_pre = sess.run(prediction, feed_dict={xs: v_xs, keep_prob: 1.0})
    return float(np.mean(np.argmax(y_pre, 1) == np.argmax(v_ys, 1)))


def getWeights(shape):
    """Return a weight Variable of *shape* drawn from a truncated normal (stddev 0.1)."""
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))


def getBias(shape):
    """Return a bias Variable of *shape* with every element set to 0.1."""
    return tf.Variable(tf.constant(0.1, shape=shape))


def conv2d(x, W):
    """2-D convolution with stride 1 and 'SAME' padding (height/width preserved)."""
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')


def maxpool(x):
    """2x2 max pooling with stride 2 and 'SAME' padding (height/width halved)."""
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                          strides=[1, 2, 2, 1], padding='SAME')


# Directory and one_hot must be separate arguments; the original passed them
# as one string, so labels were not one-hot encoded.
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
xs = tf.placeholder(tf.float32, [None, 28 * 28])
ys = tf.placeholder(tf.float32, [None, 10])

keep_prob = tf.placeholder(tf.float32)
x_image = tf.reshape(xs, [-1, 28, 28, 1])

# Conv layer 1: 5x5 kernels, 1 -> 32 channels.
# Note: max pooling turns the 28x28 map into 14x14.
W_c1 = getWeights([5, 5, 1, 32])
b_c1 = getBias([32])
h_c1 = tf.nn.relu(conv2d(x_image, W_c1) + b_c1)
h_p1 = maxpool(h_c1)

# Conv layer 2: 5x5 kernels, 32 -> 64 channels.
# After this max pooling the feature map is 7x7x64.
W_c2 = getWeights([5, 5, 32, 64])
b_c2 = getBias([64])
h_c2 = tf.nn.relu(conv2d(h_p1, W_c2) + b_c2)
h_p2 = maxpool(h_c2)

# Fully connected layer: 7*7*64 -> 1024, followed by dropout.
W_fc1 = getWeights([7 * 7 * 64, 1024])
b_fc1 = getBias([1024])
h_p2_flat = tf.reshape(h_p2, [-1, 7 * 7 * 64])
h_fc1 = tf.nn.relu(tf.matmul(h_p2_flat, W_fc1) + b_fc1)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

# Output layer: 1024 -> 10 class logits. Softmax (not ReLU) is required for
# class probabilities; ReLU outputs exact zeros, making log(prediction) -inf.
W_fc2 = getWeights([1024, 10])
b_fc2 = getBias([10])
logits = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
prediction = tf.nn.softmax(logits)

# Loss is computed from the logits so TF evaluates softmax+log in a
# numerically stable way instead of taking log of (possibly 0) probabilities.
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=ys, logits=logits))
train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_step,
                 feed_dict={xs: batch_xs, ys: batch_ys, keep_prob: 0.5})
        # Report test accuracy every 50 steps (a bare `i % 50` is truthy on
        # the other 49 steps, which inverted the intended condition).
        if i % 50 == 0:
            print(compute_accuracy(mnist.test.images, mnist.test.labels))
TFboy养成记 CNN
声明:以上内容来自用户投稿及互联网公开渠道收集整理发布,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任,若内容有误或涉及侵权可进行投诉: 投诉/举报 工作人员会在5个工作日内联系你,一经查实,本站将立刻删除涉嫌侵权内容。