
Reading Data in TensorFlow: The CSV Format

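To get anything done with TensorFlow, you first have to sort out data input yourself. The official documentation describes several approaches: 1. load the data from memory into a matrix in one go and feed it in directly; 2. read from files while feeding, for which a multithreaded read/write model is already provided; 3. convert data from the network or from memory into TensorFlow's own TFRecord format, save it to a file, and read it back from there.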
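To make the difference between approaches 1 and 2 concrete, here is a small plain-Python sketch. It is not TensorFlow code: CSV_TEXT, load_all and stream_batches are hypothetical names invented for this illustration, and the values are made up.

```python
import csv
import io

# Hypothetical stand-in for a CSV file on disk; the values are made up.
CSV_TEXT = "1,2,3\n4,5,6\n7,8,9\n10,11,12\n"


def load_all(text):
    """Approach 1: parse everything up front into one in-memory matrix."""
    return [[int(field) for field in row] for row in csv.reader(io.StringIO(text))]


def stream_batches(text, batch_size):
    """Approach 2: read row by row and yield one batch at a time."""
    batch = []
    for row in csv.reader(io.StringIO(text)):
        batch.append([int(field) for field in row])
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:  # emit the final, possibly smaller, batch
        yield batch


matrix = load_all(CSV_TEXT)
batches = list(stream_batches(CSV_TEXT, batch_size=3))
print(len(matrix), "rows loaded at once")
print([len(b) for b in batches], "rows per streamed batch")
```

The first function needs the whole data set to fit in memory; the second only ever holds one batch, which is the idea behind reading from files while feeding.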

 

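For reading from files while feeding, the official documentation uses CSV files as its example. I found some code online and modified it. The original was fairly terse, so I will add notes on the problems I ran into.

Code first: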

 

# coding=utf-8
import tensorflow as tf
import numpy as np


def readMyFileFormat(fileNameQueue):
    # Read one line of text at a time from the files in the queue.
    reader = tf.TextLineReader()
    key, value = reader.read(fileNameQueue)
    # One default per CSV column; the type of each default sets the column type.
    record_defaults = [[1], [1], [1]]
    col1, col2, col3 = tf.decode_csv(value, record_defaults=record_defaults)
    features = tf.pack([col1, col2])
    label = col3
    return features, label


def inputPipeLine(fileNames=["1.csv", "2.csv"], batchSize=4, numEpochs=None):
    fileNameQueue = tf.train.string_input_producer(fileNames, num_epochs=numEpochs)
    example, label = readMyFileFormat(fileNameQueue)
    min_after_dequeue = 8
    capacity = min_after_dequeue + 3 * batchSize
    exampleBatch, labelBatch = tf.train.shuffle_batch(
        [example, label], batch_size=batchSize, num_threads=3,
        capacity=capacity, min_after_dequeue=min_after_dequeue)
    return exampleBatch, labelBatch


featureBatch, labelBatch = inputPipeLine(["1.csv", "2.csv"], batchSize=4)

with tf.Session() as sess:
    # Start populating the filename queue.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    # Retrieve a single instance:
    try:
        # while not coord.should_stop():
        while True:
            example, label = sess.run([featureBatch, labelBatch])
            print(example)
    except tf.errors.OutOfRangeError:
        print('Done reading')
    finally:
        coord.request_stop()

    coord.join(threads)
    sess.close()
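The min_after_dequeue and capacity arguments passed to tf.train.shuffle_batch control how well the batches are mixed: capacity bounds the size of the queue, and min_after_dequeue is the minimum number of elements left in it after a dequeue. The following is a rough single-threaded simulation of that idea in plain Python. It is only a sketch and not how TensorFlow implements its queue; shuffle_batches and its seed parameter are hypothetical names, and unlike the real call it also emits a smaller final batch.

```python
import random


def shuffle_batches(items, batch_size, min_after_dequeue, seed=0):
    """Yield shuffled batches drawn from a bounded buffer.

    The buffer is refilled to `capacity` before every draw, so while the
    source still has items each draw leaves at least `min_after_dequeue`
    elements behind.
    """
    rng = random.Random(seed)
    capacity = min_after_dequeue + 3 * batch_size  # same formula as the code above
    source = iter(items)
    buffer, batch, exhausted = [], [], False
    while True:
        # Top the buffer up to capacity (or until the source runs out).
        while not exhausted and len(buffer) < capacity:
            try:
                buffer.append(next(source))
            except StopIteration:
                exhausted = True
        if not buffer:
            break
        # Draw one random element from whatever is currently buffered.
        batch.append(buffer.pop(rng.randrange(len(buffer))))
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:  # simplification: emit the leftover partial batch
        yield batch


batches = list(shuffle_batches(range(30), batch_size=4, min_after_dequeue=8))
print([len(b) for b in batches])  # seven batches of 4, then one of 2
```

A larger min_after_dequeue means each element is drawn from a bigger pool, which mixes the data better at the cost of more memory and a slower start.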


~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

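In the code, record_defaults = [[1], [1], [1]] specifies the layout and data types of the records. If the matrix in the CSV file is N×M, record_defaults is 1×M. The 1 inside each [1] sets the data type of that column: if the matrix contains decimals, the column must be float, so [1] should become [1.0].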
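As an illustration of that rule, the sketch below scans a few CSV lines and builds a matching record_defaults list: [1] for columns that contain only integers and [1.0] otherwise. infer_record_defaults is a hypothetical helper written for this post, not part of TensorFlow, and it assumes every field is numeric and every line has the same number of fields.

```python
def infer_record_defaults(lines):
    """Return one [default] per column: [1] for all-integer columns, else [1.0]."""
    rows = [line.strip().split(",") for line in lines if line.strip()]
    num_cols = len(rows[0])
    defaults = []
    for col in range(num_cols):
        try:
            for row in rows:
                int(row[col])  # raises ValueError for values such as "2.5"
            defaults.append([1])
        except ValueError:
            defaults.append([1.0])
    return defaults


sample = ["1,2.5,3", "4,5,6", "7,8.25,9"]
print(infer_record_defaults(sample))  # [[1], [1.0], [1]]
```

If the columns are later combined into a single tensor they need one common type, so in practice it may be simpler to use [1.0] for every column as soon as any of them contains decimals.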


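col1, col2, col3 = tf.decode_csv(value, record_defaults = record_defaults): there must be as many variables on the left as the matrix has columns. With 5 columns, for example, you have to write all the way out to col5, no matter how many of them you actually use. Otherwise you get an error.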
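The column-count requirement can be imitated in plain Python. The sketch below is not TensorFlow's implementation: parse_line is a hypothetical function that mimics the behaviour described above, converting each field to the type of its default, falling back to the default for an empty field, and refusing lines whose field count differs from the number of defaults.

```python
def parse_line(line, record_defaults):
    """Split one CSV line and convert each field to the type of its default."""
    fields = line.strip().split(",")
    if len(fields) != len(record_defaults):
        raise ValueError(
            "expected %d fields but got %d" % (len(record_defaults), len(fields))
        )
    values = []
    for field, (default,) in zip(fields, record_defaults):
        # An empty field falls back to the default; otherwise cast to its type.
        values.append(default if field == "" else type(default)(field))
    return values


five_defaults = [[1.0]] * 5
col1, col2, col3, col4, col5 = parse_line("0.5,1.5,2.5,3.5,4.5", five_defaults)
features, label = [col1, col2], col5  # col3 and col4 are unpacked but unused

try:
    parse_line("0.5,1.5,2.5,3.5,4.5", [[1.0]] * 3)
except ValueError as err:
    print("mismatch:", err)
```

Since the call returns a list, another option is to keep the result in a single variable and slice it, which avoids naming columns that are never used.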


tf.pack([col1, col2]) appears to require col1 and col2 to have the same data type. Otherwise it raises an error.

My test data:

 

 

-0.76,15.67,-0.12,15.67
-0.48,12.52,-0.06,12.51
1.33,9.11,0.12,9.1
-0.88,20.35,-0.18,20.36
-0.25,3.99,-0.01,3.99
-0.87,26.25,-0.23,26.25
-1.03,2.87,-0.03,2.87
-0.51,7.81,-0.04,7.81
-1.57,14.46,-0.23,14.46
-0.1,10.02,-0.01,10.02
-0.56,8.92,-0.05,8.92
-1.2,4.1,-0.05,4.1
-0.77,5.15,-0.04,5.15
-0.88,4.48,-0.04,4.48
-2.7,10.82,-0.3,10.82
-1.23,2.4,-0.03,2.4
-0.77,5.16,-0.04,5.15
-0.81,6.15,-0.05,6.15
-0.6,5.01,-0.03,5
-1.25,4.75,-0.06,4.75
-2.53,7.31,-0.19,7.3
-1.15,16.39,-0.19,16.39
-1.7,5.19,-0.09,5.18
-0.62,3.23,-0.02,3.22
-0.74,17.43,-0.13,17.41
-0.77,15.41,-0.12,15.41
0,47,0,47.01
0.25,3.98,0.01,3.98
-1.1,9.01,-0.1,9.01
-1.02,3.87,-0.04,3.87
