首页 > 代码库 > tensorflow bilstm官方示例
tensorflow bilstm官方示例
'''
A Bidirectional Recurrent Neural Network (LSTM) implementation example using TensorFlow library.
This example is using the MNIST database of handwritten digits (http://yann.lecun.com/exdb/mnist/)
Long Short Term Memory paper: http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf

Author: Aymeric Damien
Project: https://github.com/aymericdamien/TensorFlow-Examples/
'''

from __future__ import print_function

import tensorflow as tf
from tensorflow.contrib import rnn
import numpy as np

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

# To classify images using a bidirectional recurrent neural network, we consider
# every image row as a sequence of pixels. Because MNIST image shape is 28*28px,
# we will then handle 28 sequences of 28 steps for every sample.

# Parameters
learning_rate = 0.001

# Roughly the total number of training samples drawn over the whole run:
# the training loop continues while step * batch_size < training_iters.
training_iters = 100000

# Number of samples in each training minibatch.
batch_size = 128

# Print minibatch loss/accuracy every `display_step` training steps.
display_step = 10

# Network Parameters
# n_steps * n_input is the whole image: each image row is fed as one time step.
n_input = 28  # MNIST data input (img shape: 28*28)
n_steps = 28  # timesteps

# Size of the hidden state of each LSTM cell.
n_hidden = 128  # hidden layer num of features
n_classes = 10  # MNIST total classes (0-9 digits)

# tf Graph input
# The leading None leaves the batch dimension unspecified.
x = tf.placeholder("float", [None, n_steps, n_input])
y = tf.placeholder("float", [None, n_classes])

# Define weights
weights = {
    # Hidden layer weights => 2*n_hidden because of forward + backward cells
    'out': tf.Variable(tf.random_normal([2 * n_hidden, n_classes]))
}
biases = {
    'out': tf.Variable(tf.random_normal([n_classes]))
}


def BiRNN(x, weights, biases):
    """Build a bidirectional LSTM over the rows of each image and return logits.

    Uses the module-level ``n_steps`` and ``n_hidden`` settings.

    Args:
        x: float tensor of shape (batch_size, n_steps, n_input).
        weights: dict whose 'out' entry is a (2*n_hidden, n_classes) variable.
        biases: dict whose 'out' entry is a (n_classes,) variable.

    Returns:
        Unscaled logits tensor of shape (batch_size, n_classes).
    """
    # Prepare data shape to match `static_bidirectional_rnn` requirements.
    # Current data input shape: (batch_size, n_steps, n_input)
    # Required shape: a list of 'n_steps' tensors of shape (batch_size, n_input)
    x = tf.unstack(x, n_steps, 1)

    # Define lstm cells with tensorflow
    # Forward direction cell
    lstm_fw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
    # Backward direction cell
    lstm_bw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)

    # Build the bidirectional RNN exactly once, then normalize the return value.
    # Current TensorFlow returns (outputs, output_state_fw, output_state_bw);
    # per the upstream example, old versions returned only the outputs.
    # Retrying the call inside `except Exception` (as the upstream code did)
    # adds a second copy of the RNN ops to the graph and hides unrelated errors.
    result = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x,
                                          dtype=tf.float32)
    outputs = result[0] if isinstance(result, tuple) else result

    # Linear activation, using rnn inner loop last output.
    # NOTE(review): presumably the backward half of outputs[-1] has only seen
    # the last image row; kept as in the upstream example — confirm intended.
    return tf.matmul(outputs[-1], weights['out']) + biases['out']


pred = BiRNN(x, weights, biases)

# Define loss and optimizer
# softmax_cross_entropy_with_logits: measures the probability error in discrete
# classification tasks in which the classes are mutually exclusive; it returns
# a 1-D tensor of length batch_size holding the softmax cross-entropy loss.
# reduce_mean with no axis argument averages over every element, i.e. the mean
# loss across the minibatch.
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

# Evaluate model
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Initializing the variables
init = tf.global_variables_initializer()

# Launch the graph
with tf.Session() as sess:
    sess.run(init)
    step = 1
    # Keep training until reach max iterations
    while step * batch_size < training_iters:
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        # Reshape data to get 28 seq of 28 elements
        batch_x = batch_x.reshape((batch_size, n_steps, n_input))
        # Run optimization op (backprop)
        sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
        if step % display_step == 0:
            # Fetch batch loss and accuracy in one run, i.e. a single forward
            # pass over the minibatch instead of two.
            loss, acc = sess.run([cost, accuracy],
                                 feed_dict={x: batch_x, y: batch_y})
            print("Iter " + str(step * batch_size) + ", Minibatch Loss= " +
                  "{:.6f}".format(loss) + ", Training Accuracy= " +
                  "{:.5f}".format(acc))
        step += 1
    print("Optimization Finished!")

    # Calculate accuracy for 128 mnist test images
    test_len = 128
    test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
    test_label = mnist.test.labels[:test_len]
    print("Testing Accuracy:",
          sess.run(accuracy, feed_dict={x: test_data, y: test_label}))
官方关于 BiLSTM 的例子写得很清楚。因为是第一次看,还是要查许多资料,尤其是数据处理方面。
数据的处理(https://segmentfault.com/a/1190000008793389)
拼接(tf.concat 沿已有的维度拼接;tf.stack 新增一个维度后堆叠):
t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 0) ==> [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 1) ==> [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]
tf.stack([t1, t2], 0) ==> [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
tf.stack([t1, t2], 1) ==> [[[1, 2, 3], [7, 8, 9]], [[4, 5, 6], [10, 11, 12]]]
tf.stack([t1, t2], 2) ==> [[[1, 7], [2, 8], [3, 9]], [[4, 10], [5, 11], [6, 12]]]
从 shape 的角度看(下面带 * 的维度是 tf.stack 新增的那一维):
t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 0) # [2,3] + [2,3] ==> [4, 3]
tf.concat([t1, t2], 1) # [2,3] + [2,3] ==> [2, 6]
tf.stack([t1, t2], 0) # [2,3] + [2,3] ==> [2*,2,3]
tf.stack([t1, t2], 1) # [2,3] + [2,3] ==> [2,2*,3]
tf.stack([t1, t2], 2) # [2,3] + [2,3] ==> [2,3,2*]
抽取(tf.slice(input, begin, size) 从下标 begin 开始、按 size 给出的各维长度取出一块;tf.gather 按给定下标沿第 0 维取元素):
input = [[[1, 1, 1], [2, 2, 2]],
[[3, 3, 3], [4, 4, 4]],
[[5, 5, 5], [6, 6, 6]]]
tf.slice(input, [1, 0, 0], [1, 1, 3]) ==> [[[3, 3, 3]]]
tf.slice(input, [1, 0, 0], [1, 2, 3]) ==> [[[3, 3, 3],
[4, 4, 4]]]
tf.slice(input, [1, 0, 0], [2, 1, 3]) ==> [[[3, 3, 3]],
[[5, 5, 5]]]
tf.gather(input, [0, 2]) ==> [[[1, 1, 1], [2, 2, 2]],
[[5, 5, 5], [6, 6, 6]]]
tensorflow bilstm官方示例
声明:以上内容来自用户投稿及互联网公开渠道收集整理发布,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任,若内容有误或涉及侵权可进行投诉: 投诉/举报 工作人员会在5个工作日内联系你,一经查实,本站将立刻删除涉嫌侵权内容。