在tensorflow中session.run()用来将数据传入计算图,计算并返回出给定变量/placeholder的结果。
在看论文代码的时候遇到一段复杂的feed_dict, 本文记录了对sess.run()的复习。
1.tensorflow Session.run()
session.run()
的函数定义如下,可以在交互式python中sess = tf.Session; ?sess.run
,也可以在源码 line846中查看到。首先来看函数的参数定义:
run(self, fetches, feed_dict=None, options=None, run_metadata=None)
其中常用的fetches
和feed_dict
就是常用的传入参数。fetches主要指从计算图中取回计算结果进行放回的那些placeholder和变量,而feed_dict
则是将对应的数据传入计算图中占位符,它是字典数据结构只在调用方法内有效。
参考这个例子额的解释,最下面的fetch和feed,原始定义 在make_callable
下面让我们来看看官方代码中对run()
函数的解释:
def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
"""Runs operations and evaluates tensors in `fetches`.
运行操作和对fetches中的张量进行计算
This method runs one "step" of TensorFlow computation, by
running the necessary graph fragment to execute every `Operation`
and evaluate every `Tensor` in `fetches`, substituting the values in
`feed_dict` for the corresponding input values.
这一方法将在tensorflow中运行一次计算,通过将feed_dict中的数据馈入计算图中,
运行计算图定义的操作并最终得到fectch中tensor的评测结果
The `fetches` argument may be a single graph element, or an arbitrarily
nested list, tuple, namedtuple, dict, or OrderedDict containing graph
elements at its leaves. A graph element can be one of the following types:
fecches是从计算图中取出对应变量的参数,可以是单个图元素、任意的列表、元组、字典等等形式的图元素。
图元素包括操作、张量、稀疏张量、句柄、字符串等等。
* A `tf.Operation`.
The corresponding fetched value will be `None`.
* A `tf.Tensor`.
The corresponding fetched value will be a numpy ndarray containing the
value of that tensor.
* A `tf.SparseTensor`.
The corresponding fetched value will be a
`tf.compat.v1.SparseTensorValue`
containing the value of that sparse tensor.
* A `get_tensor_handle` op. The corresponding fetched value will be a
numpy ndarray containing the handle of that tensor.
* A `string` which is the name of a tensor or operation in the graph.
The value returned by `run()` has the same shape as the `fetches` argument,
where the leaves are replaced by the corresponding values returned by
TensorFlow.
run的返回值与fetches的形状一致
Example:
```python
a = tf.constant([10, 20])
b = tf.constant([1.0, 2.0])
# 'fetches' can be a singleton
v = session.run(a)
# v is the numpy array [10, 20] # 这里就是单个元素作为fetch数值
# 'fetches' can be a list.
v = session.run([a, b]) # 这里作为list取回值
# v is a Python list with 2 numpy arrays: the 1-D array [10, 20] and the
# 1-D array [1.0, 2.0]
# 'fetches' can be arbitrary lists, tuples, namedtuple, dicts:
MyData = collections.namedtuple('MyData', ['a', 'b'])
v = session.run({'k1': MyData(a, b), 'k2': [b, a]})
# v is a dict with
# v['k1'] is a MyData namedtuple with 'a' (the numpy array [10, 20]) and
# 'b' (the numpy array [1.0, 2.0])
# v['k2'] is a list with the numpy array [1.0, 2.0] and the numpy array
# [10, 20].
```
feed_dict可以使