首页 > 代码库 > 13 Tensorflow机制(翻译)

13 Tensorflow机制(翻译)

    代码: tensorflow/examples/tutorials/mnist/

    本文的目的是来展示如何使用Tensorflow训练和评估手写数字识别问题。本文的观众是那些对使用Tensorflow进行机器学习感兴趣的人。

    本文的目的并不是讲解机器学习。

    请确认您已经安装了Tensorflow。

 

    教程文件

文件 作用
mnist.py 用来创建一个完全连接的MNIST模型。
fully_connected_feed.py 使用下载的数据集训练模型。

    运行fully_connected_feed.py文件开始训练。

python fully_connected_feed.py

 

    准备数据

    MNIST是机器学习的一个经典问题。这个问题是识别28*28像素图片上的数字,从0到9。

技术分享

    更多信息,请参考Yann LeCun‘s MNIST page 或者 Chris Olah‘s visualizations of MNIST。

 

    数据下载

    在run_training()方法之前,input_data.read_data_sets()方法可以让数据下载到本机训练文件夹,解压数据并返回一个DataSet实例。

data_sets = input_data.read_data_sets(FLAGS.train_dir, FLAGS.fake_data)

    注意:fake_data是用来进行单元测试的,读者可以忽略。

数据集 作用
data_sets.train 55000图片和标签,用来训练。
data_sets.validation 5000图片和标签,用来在迭代中校验模型准确度。
data_sets.test 10000图片和标签,用来测试训练模型准确度。

   

    输入和占位符

    placeholder_inputs()函数创建两个tf.placeholder,用来定义输入的形状,包括fetch_size。

images_placeholder = tf.placeholder(tf.float32, shape=(batch_size, mnist.IMAGE_PIXELS))
labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size))

    在训练循环中,图片和标签数据集会被切分成batch_size大小,跟占位符匹配,然后通过feed_dict参数传递到sess.run()方法中。

 

    创建图

    创建占位符后,mnist.py文件中会通过三个步骤来创建图:inference(), loss(), 和training()。

  1. inference() - 运行网络来进行预测。
  2. loss() - 用来计算损失值。
  3. training() - 计算梯度。

技术分享

    inference层

    inference()函数创建图,返回预测结果。

    它把图片占位符当作输入,并在上面构建一对完全连接的层,使用ReLU激活后,连接一个10个节点的线性层。

    每一层都位于tf.name_scope声明的命名空间中。

with tf.name_scope(hidden1):

    在该命名空间中,权重和偏置会产生tf.Variable实例,并具有所需的形状。

weights = tf.Variable(tf.truncated_normal([IMAGE_PIXELS, hidden1_units], stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))), name=weights)
biases = tf.Variable(tf.zeros([hidden1_units]), name=biases)

 

    待续...

    原文:《TensorFlow Mechanics 101》:https://www.tensorflow.org/get_started/mnist/mechanics

 

   

 

13 Tensorflow机制(翻译)