原文链接: TensorFlow 队列 queue
上一篇: TensorFlow 图像处理函数
参考
https://blog.csdn.net/menghaocheng/article/details/79621482
RandomShuffleQueue: 按照随机顺序进行的队列;
FIFOQueue: 按照先进先出的顺序排序的队列;
PaddingFIFOQueue: 通过填充来支持可变大小疑是的先进先出队列;
PriorityQueu:按照优先顺序进行排序的队列;
FIFOQueue:先入先出队列;
QueueRunner:队列管理器;
Coordinator:协调器。
新方法 | 描述 |
tf.FIFOQueue(3,'float') | 创建一个长度为3的队列,数据类型为'float' |
q.enqueue(vals,name=None) | 将一个元素编入该队列,如果在执行该操作时队列已满,那么将会阻塞直到幸免于难编入队列之中 |
q.enqueue_many(vals,name=None) | 将零个或多个元素编入队列之中 |
q.dequeue(name=None) | 把元素从队列中移出,如果在执行该操作时队列已空,那么将会阻塞直到元素出列,返回出列的tensors的tuple |
q.dequeue_many(n,name=None) | 将一个或多个元素从队列中移出 注:本例中未使用到这个方法,但聪明的你应该已经联想到了。 |
简单例子
import tensorflow as tf
with tf.Session() as sess:
q = tf.FIFOQueue(3, 'float')
init = q.enqueue_many(([0.1, 0.2, 0.3],))
init2 = q.dequeue()
init3 = q.enqueue(4.)
print(sess.run(init))
print(sess.run(init2))
print(sess.run(init3))
quelen = sess.run(q.size())
for i in range(quelen):
print(sess.run(q.dequeue()))
放入没有返回值
None
0.1
None
0.2
0.3
4.0
先执行两次inc 这样队列中的数据变为 0.3, 1.1, 1.2, 再将数据打印出来
如果dequeue 时队列为空,则程序会一直等待,直到写入一个数据
import tensorflow as tf
# 创建的图:一个先入先出队列,以及初始化,出队,+1,入队操作
q = tf.FIFOQueue(10, "float")
init = q.enqueue_many(([0.1, 0.2, 0.3],))
x = q.dequeue()
y = x + 1
q_inc = q.enqueue([y])
# 开启一个session,session是会话,会话的潜在含义是状态保持,各种tensor的状态保持
with tf.Session() as sess:
sess.run(init)
for i in range(2):
sess.run(q_inc)
quelen = sess.run(q.size())
for i in range(quelen):
print(sess.run(q.dequeue()))
0.3
1.1
1.2
队列的操作是在主线程的对话中依次完成的,这样的操作会造成数据的读取和输入较慢,处理相对困难。因此我们引入TF的队列管理器QueueRunner
新方法 | 描述 |
qr=tf.train.QueuRunner(q,enqueue_ops=[...]*2) | 创建队列管理器:指定管理的是队列q;指定入列操作;用两个线程去完成此项任务 |
qq.create_threads(sess,start=True) | 启动线程,并指定会话(TF所有的操作都是在会话中执行的) |
import tensorflow as tf
with tf.Session() as sess:
q = tf.FIFOQueue(1000, 'float32')
counter = tf.Variable(0.0)
add_op = tf.assign_add(counter, tf.constant(1.0))
enqueueData_op = q.enqueue(counter)
qr = tf.train.QueueRunner(q, enqueue_ops=[add_op, enqueueData_op] * 2)
sess.run(tf.global_variables_initializer())
enqueue_threads = qr.create_threads(sess, start=True)
for i in range(10):
print(sess.run(q.dequeue()))
发现打印了10次输出后开始报错了。原因是主线程执行了10次出列操作,然后随着with上下文的退出而关闭了会话,而入列线程还在继续!我们改一下:
import tensorflow as tf
q = tf.FIFOQueue(1000, 'float32')
counter = tf.Variable(0.0)
add_op = tf.assign_add(counter, tf.constant(1.0))
enqueueData_op = q.enqueue(counter)
sess = tf.Session()
qr = tf.train.QueueRunner(q, enqueue_ops=[add_op, enqueueData_op] * 2)
sess.run(tf.global_variables_initializer())
enqueue_threads = qr.create_threads(sess, start=True)
for i in range(10):
print(sess.run(q.dequeue()))
只是去掉了with上下文管理,主线程在执行完10次后没有关闭会话。再次运行,可以看到此时的会话并没有看错,但是程序出没有结束,而是被挂起。造成这种情况的原因是add操作和入队操作没有同步,即TensorFlow在队列设计时为优化IO系统,队列的操作一般使用批处理,这样入队线程没有发送结束的信息而程序主纯种期望将程序结束,因此造成程序堵塞程序被挂起。
【Coordinator】(线程同步与停止):
新方法 | 描述 |
tf.train.Coordinator() | 创建线程协调器 |
coord.request_stop() | 主线程请求结束 |
coord.join(enqueue_threads) | 等待线程结束 |
之所以会有相同的数字,是线程读取的c的值相同了
import tensorflow as tf
q = tf.FIFOQueue(10, 'float32')
counter = tf.Variable(0.0)
add_op = tf.assign_add(counter, tf.constant(1.0))
enqueueData_op = q.enqueue(counter)
sess = tf.Session()
qr = tf.train.QueueRunner(q, enqueue_ops=[add_op, enqueueData_op] * 2)
sess.run(tf.global_variables_initializer())
coord = tf.train.Coordinator()
enqueue_threads = qr.create_threads(sess, coord=coord, start=True)
for i in range(10):
print(i, sess.run(q.dequeue()))
coord.request_stop()
coord.join(enqueue_threads)
0 1.0
1 1.0
2 2.0
3 3.0
4 4.0
5 6.0
6 6.0
7 6.0
8 7.0
9 7.0
现在只使用一个线程写入
enqueueData_op = q.enqueue(tf.assign_add(counter, tf.constant(1.0)))
qr = tf.train.QueueRunner(q, enqueue_ops=[enqueueData_op])
十个线程,每个线程执行add,然后放入队列
add_op = tf.assign_add(counter, tf.constant(1.0))
enqueueData_op = q.enqueue(counter)
# enqueueData_op = q.enqueue(tf.assign_add(counter, tf.constant(1.0)))
sess = tf.Session()
qr = tf.train.QueueRunner(q, enqueue_ops=[add_op, enqueueData_op] * 10)
import tensorflow as tf
q = tf.FIFOQueue(10, 'float32')
counter = tf.Variable(0.0)
enqueueData_op = q.enqueue(tf.assign_add(counter, tf.constant(1.0)))
sess = tf.Session()
qr = tf.train.QueueRunner(q, enqueue_ops=[enqueueData_op])
sess.run(tf.global_variables_initializer())
coord = tf.train.Coordinator()
enqueue_threads = qr.create_threads(sess, coord=coord, start=True)
for i in range(100):
print(i, sess.run(q.dequeue()))
coord.request_stop()
coord.join(enqueue_threads)
第一个数字是0,是因为第一次执行的时候已经是加一后的值
0 1.0
1 2.0
2 3.0
3 4.0
。。。。
93 94.0
94 95.0
95 96.0
96 97.0
97 98.0
98 99.0
99 100.0