首页 > 代码库 > tensorflow读取数据之CSV格式
tensorflow读取数据之CSV格式
tensorflow要想用起来,首先自己得搞定数据输入。官方文档中介绍了几种,1.一次性从内存中读取数据到矩阵中,直接输入;2.从文件中边读边输入,而且已经给设计好了多线程读写模型;3.把网络或者内存中的数据转化为tensorflow的专用格式tfRecord,存文件后再读取。
其中,从文件中边读边输入,官方文档举例是用的CSV格式文件。我在网上找了一份代码,修改了一下,因为他的比较简略,我就补充一下遇到的问题
先贴代码
#coding=utf-8import tensorflow as tf
import numpy as np
defreadMyFileFormat(fileNameQueue):
reader = tf.TextLineReader()
key, value = http://www.mamicode.com/reader.read(fileNameQueue)
record_defaults = [[1], [1], [1]]
col1, col2, col3 = tf.decode_csv(value, record_defaults = record_defaults)
features = tf.pack([col1, col2])
label = col3
return features, label
definputPipeLine(fileNames = ["1.csv","2.csv"], batchSize =4, numEpochs = None):
fileNameQueue = tf.train.string_input_producer(fileNames, num_epochs = numEpochs)
example, label = readMyFileFormat(fileNameQueue)
min_after_dequeue =8
capacity = min_after_dequeue +3 * batchSize
exampleBatch, labelBatch = tf.train.shuffle_batch([example, label], batch_size = batchSize, num_threads = 3, capacity = cap acity, min_after_dequeue = min_after_dequeue)
return exampleBatch, labelBatch
featureBatch, labelBatch = inputPipeLine(["1.csv","2.csv"], batchSize = 4)
with tf.Session() as sess: # Start populating the filename queue.coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
# Retrieve a single instance:try:#while not coord.should_stop():
whileTrue:
example, label = sess.run([featureBatch, labelBatch])print example
except tf.errors.OutOfRangeError:
print‘Done reading‘
finally:
coord.request_stop()
coord.join(threads)
sess.close()
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
其中,record_defaults = [[1], [1], [1]] ,是用于指定矩阵格式以及数据类型的,CSV文件中的矩阵,是NXM的,则此处为1XM,[1]中的1 用于指定数据类型,比如矩阵中如果有小数,则为float,[1]应该变为[1.0]。
col1, col2, col3 = tf.decode_csv(value, record_defaults = record_defaults) , 矩阵中有几列,这里就要写几个参数,比如5列,就要写到col5,不管你到底用多少。否则报错。
tf.pack([col1, col2]) ,好像要求col1与col2是同一数据类型,否则报错。
我的测试数据
-0.76 | 15.67 | -0.12 | 15.67 |
-0.48 | 12.52 | -0.06 | 12.51 |
1.33 | 9.11 | 0.12 | 9.1 |
-0.88 | 20.35 | -0.18 | 20.36 |
-0.25 | 3.99 | -0.01 | 3.99 |
-0.87 | 26.25 | -0.23 | 26.25 |
-1.03 | 2.87 | -0.03 | 2.87 |
-0.51 | 7.81 | -0.04 | 7.81 |
-1.57 | 14.46 | -0.23 | 14.46 |
-0.1 | 10.02 | -0.01 | 10.02 |
-0.56 | 8.92 | -0.05 | 8.92 |
-1.2 | 4.1 | -0.05 | 4.1 |
-0.77 | 5.15 | -0.04 | 5.15 |
-0.88 | 4.48 | -0.04 | 4.48 |
-2.7 | 10.82 | -0.3 | 10.82 |
-1.23 | 2.4 | -0.03 | 2.4 |
-0.77 | 5.16 | -0.04 | 5.15 |
-0.81 | 6.15 | -0.05 | 6.15 |
-0.6 | 5.01 | -0.03 | 5 |
-1.25 | 4.75 | -0.06 | 4.75 |
-2.53 | 7.31 | -0.19 | 7.3 |
-1.15 | 16.39 | -0.19 | 16.39 |
-1.7 | 5.19 | -0.09 | 5.18 |
-0.62 | 3.23 | -0.02 | 3.22 |
-0.74 | 17.43 | -0.13 | 17.41 |
-0.77 | 15.41 | -0.12 | 15.41 |
0 | 47 | 0 | 47.01 |
0.25 | 3.98 | 0.01 | 3.98 |
-1.1 | 9.01 | -0.1 | 9.01 |
-1.02 | 3.87 | -0.04 | 3.87 |
tensorflow读取数据之CSV格式