tensorflow学习三 多线程队列读取二进制文件
还没测试 先留着吧
import tensorflow as tf
import os
path = ""
train_times = 1
#num_epochs指的是将会训练多少轮
#第一行会产生一个队列,队列包含0到NUM_EXPOCHES-1的元素,
# 如果num_epochs有指定,则每个元素只产生num_epochs次,否则循环产生。
# shuffle指定是否打乱顺序,这里shuffle=False表示队列的元素是按0到NUM_EXPOCHES-1的顺序存储。
filename_queue = tf.train.string_input_producer(path, shuffle=True, num_epochs=train_times)
#采用读取固定长度二进制数据的读取器,一次读取两个数据类型的数据
reader = tf.FixedLengthRecordReader(record_bytes=2*4)
key, value = reader.read(path)
#将读入的数据按照float32的大小解码
decode_value = tf.decode_raw(value, tf.float32)
v1 = decode_value[0]
v2 = decode_value[1]
v_mul = tf.multiply(v1,v2)
init_op = tf.global_variables_initializer()
local_init_op = tf.local_variables_initializer()
#创建会话
sess = tf.Session()
#初始化变量
sess.run(init_op)
sess.run(local_init_op)
#输入数据进入队列
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord=coord)
try:
while not coord.should_stop():
value1,value2,mul_result = sess.run([v1,v2,v_mul])
print(value1,value2,mul_result)
except tf.erros.OutOfRangeError:
print("Done training -- epoch limit reached")
finally:
coord.request_stop()
#等待线程结束
coord.join(threads)
sess.close()