首页 > 代码库 > TensorFlow读取CSV数据(批量)

TensorFlow读取CSV数据(批量)

直接上代码:

# -*- coding:utf-8 -*-
import tensorflow as tf


def read_data(file_queue):
    """Read one CSV record from ``file_queue`` and return (features, label).

    Args:
        file_queue: a filename queue, as produced by
            ``tf.train.string_input_producer``.

    Returns:
        A pair ``(features, label)`` where ``features`` stacks the four
        measurement columns (SepalLengthCm, SepalWidthCm, PetalLengthCm,
        PetalWidthCm) into one tensor and ``label`` is the integer class
        derived from the Species column (0, 1, 2, or -1 if unrecognised).
    """
    # skip_header_lines=1 drops the column-name row at the top of each file.
    reader = tf.TextLineReader(skip_header_lines=1)
    key, value = reader.read(file_queue)

    # One default per column; the defaults also fix each column's dtype:
    # int Id, four float measurements, string Species.
    defaults = [[0], [0.], [0.], [0.], [0.], ['']]
    (Id, SepalLengthCm, SepalWidthCm,
     PetalLengthCm, PetalWidthCm, Species) = tf.decode_csv(value, defaults)

    # The Iris dataset stores the target as a species name, so map it to an
    # integer class id here; unknown names fall through to -1.
    preprocess_op = tf.case({
        tf.equal(Species, tf.constant('Iris-setosa')): lambda: tf.constant(0),
        tf.equal(Species, tf.constant('Iris-versicolor')): lambda: tf.constant(1),
        tf.equal(Species, tf.constant('Iris-virginica')): lambda: tf.constant(2),
    }, lambda: tf.constant(-1), exclusive=True)

    features = tf.stack(
        [SepalLengthCm, SepalWidthCm, PetalLengthCm, PetalWidthCm])
    return features, preprocess_op


def create_pipeline(filename, batch_size, num_epochs=None):
    """Build a shuffled mini-batch input pipeline over one CSV file.

    Args:
        filename: path of the CSV file to read.
        batch_size: number of examples per batch.
        num_epochs: how many passes to make over the file; ``None`` is
            passed straight through to ``tf.train.string_input_producer``.

    Returns:
        A pair ``(example_batch, label_batch)`` of batched tensors.
    """
    file_queue = tf.train.string_input_producer(
        [filename], num_epochs=num_epochs)
    example, label = read_data(file_queue)

    # Shuffle-queue sizing: keep at least min_after_dequeue elements buffered
    # and leave room for one more batch on top of that.
    min_after_dequeue = 1000
    capacity = min_after_dequeue + batch_size
    example_batch, label_batch = tf.train.shuffle_batch(
        [example, label],
        batch_size=batch_size,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue,
    )
    return example_batch, label_batch


x_train_batch, y_train_batch = create_pipeline(
    'Iris-train.csv', 50, num_epochs=1000)
x_test, y_test = create_pipeline('Iris-test.csv', 60)

init_op = tf.global_variables_initializer()
# Local variables (e.g. the epoch counter created when num_epochs is set)
# need their own initializer.
local_init_op = tf.local_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)
    sess.run(local_init_op)

    # Start populating the filename queue.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    # Pull training batches until the input queue reports it is exhausted.
    try:
        while True:
            example, label = sess.run([x_train_batch, y_train_batch])
            print(example)
            print(label)
    except tf.errors.OutOfRangeError:
        print('Done reading')
    finally:
        coord.request_stop()

    coord.join(threads)
    # No explicit sess.close(): the `with` block closes the session on exit.

 

数据集是鸢尾花数据集,大家自行下载吧,下面给个示例:

Id,SepalLengthCm,SepalWidthCm,PetalLengthCm,PetalWidthCm,Species
21,5.4,3.4,1.7,0.2,Iris-setosa
22,5.1,3.7,1.5,0.4,Iris-setosa
23,4.6,3.6,1.0,0.2,Iris-setosa
24,5.1,3.3,1.7,0.5,Iris-setosa
25,4.8,3.4,1.9,0.2,Iris-setosa
26,5.0,3.0,1.6,0.2,Iris-setosa
27,5.0,3.4,1.6,0.4,Iris-setosa
28,5.2,3.5,1.5,0.2,Iris-setosa
29,5.2,3.4,1.4,0.2,Iris-setosa
30,4.7,3.2,1.6,0.2,Iris-setosa
31,4.8,3.1,1.6,0.2,Iris-setosa
32,5.4,3.4,1.5,0.4,Iris-setosa
33,5.2,4.1,1.5,0.1,Iris-setosa
34,5.5,4.2,1.4,0.2,Iris-setosa
35,4.9,3.1,1.5,0.1,Iris-setosa
36,5.0,3.2,1.2,0.2,Iris-setosa
37,5.5,3.5,1.3,0.2,Iris-setosa

 

TensorFlow读取CSV数据(批量)