- Class tf.GradientTape
- Class tf.contrib.eager.GradientTape
Defined in tensorflow/python/eager/backprop.py.
Record operations for automatic differentiation.
Operations are recorded if they are executed within this context manager and at least one of their inputs is being "watched".
Trainable variables (created by tf.contrib.eager.Variable or tf.get_variable, where trainable=True is the default in both cases) are automatically watched. Tensors can be watched manually by invoking the watch method on this context manager.
For example, consider the function y = x * x. The gradient at x = 3.0 can be computed as:
import tensorflow as tf
x = tf.constant(3.0)
with tf.GradientTape() as g:
  g.watch(x)
  y = x * x
dy_dx = g.gradient(y, x)  # Will compute to 6.0
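To make the recording rule concrete, here is a toy tape in plain Python that supports only scalar multiplication. Var, ToyTape and mul are hypothetical names invented for this sketch, and it is not TensorFlow's implementation. It applies the two conditions stated above (the op runs inside the context, and at least one input is watched) and then accumulates gradients by walking the recorded ops in reverse.

```python
class Var:
    """Hypothetical scalar value standing in for a tensor."""

    def __init__(self, value):
        self.value = value


class ToyTape:
    """Hypothetical toy tape; not TensorFlow's implementation."""

    def __init__(self):
        self.recording = False
        self.watched = set()  # Var objects (hashed by identity)
        self.records = []     # (output, [(input, d_output/d_input), ...])

    def __enter__(self):
        self.recording = True
        return self

    def __exit__(self, *exc_info):
        self.recording = False
        return False  # do not swallow exceptions

    def watch(self, v):
        self.watched.add(v)

    def mul(self, a, b):
        out = Var(a.value * b.value)
        # Record only inside the context and only if some input is watched.
        if self.recording and (a in self.watched or b in self.watched):
            self.watched.add(out)  # results of recorded ops are traced too
            self.records.append((out, [(a, b.value), (b, a.value)]))
        return out

    def gradient(self, target, source):
        # Reverse-mode accumulation: walk the records newest-first.
        grads = {target: 1.0}
        for out, inputs in reversed(self.records):
            upstream = grads.get(out)
            if upstream is None:
                continue  # this op does not feed the target
            for inp, local in inputs:
                grads[inp] = grads.get(inp, 0.0) + upstream * local
        return grads.get(source)  # None if source is not connected


x, c = Var(3.0), Var(3.0)
with ToyTape() as tape:
    tape.watch(x)
    y = tape.mul(x, x)  # recorded: x is watched
    z = tape.mul(c, c)  # not recorded: c is not watched
w = tape.mul(x, x)      # not recorded: outside the context

print(tape.gradient(y, x))  # 6.0
print(tape.gradient(z, c))  # None
print(len(tape.records))    # 1
```

Returning None for an unconnected source mirrors the "or None" case in the return description of gradient().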
GradientTapes can be nested to compute higher-order derivatives. For example:
x = tf.constant(3.0)
with tf.GradientTape() as g:
  g.watch(x)  # The outer tape must also watch x, or d2y_dx2 is None
  with tf.GradientTape() as gg:
    gg.watch(x)
    y = x * x
  dy_dx = gg.gradient(y, x)     # Will compute to 6.0
d2y_dx2 = g.gradient(dy_dx, x)  # Will compute to 2.0
By default, the resources held by a GradientTape are released as soon as the GradientTape.gradient() method is called. To compute multiple gradients over the same computation, create a persistent gradient tape. This allows multiple calls to the gradient() method; the resources are released only when the tape object is garbage collected. For example:
x = tf.constant(3.0)
with tf.GradientTape(persistent=True) as g:
  g.watch(x)
  y = x * x
  z = y * y
dz_dx = g.gradient(z, x)  # 108.0 (4*x^3 at x = 3)
dy_dx = g.gradient(y, x)  # 6.0
del g  # Drop the reference to the tape
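The values in the comments follow from the chain rule, since y = x * x and z = y * y = x^4:

```latex
\left.\frac{dy}{dx}\right|_{x=3} = 2x = 6,
\qquad
\left.\frac{dz}{dx}\right|_{x=3} = \frac{dz}{dy}\,\frac{dy}{dx} = 2y \cdot 2x = 2x^2 \cdot 2x = 4x^3 = 4 \cdot 27 = 108
```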
Note that only tensors with real or complex dtypes are differentiable.
__init__(persistent=False)
Creates a new GradientTape.
Args:
  persistent: Boolean controlling whether a persistent gradient tape is created. False by default, which means at most one call can be made to the gradient() method on this object.
__enter__()
Enters a context inside which operations are recorded on this tape.
__exit__(typ, value, traceback)
Exits the recording context; no further operations are traced.
gradient(target, sources, output_gradients=None)
Computes the gradient using operations recorded in the context of this tape.
Args:
  target: Tensor (or list of tensors) to be differentiated.
  sources: a list or nested structure of Tensors or Variables. target will be differentiated against elements in sources.
  output_gradients: a list of gradients, one for each element of target. Defaults to None.
Returns:
  a list or nested structure of Tensors (or IndexedSlices, or None), one for each element in sources. The returned structure is the same as the structure of sources.
Raises:
  RuntimeError: if called inside the context of the tape, or if called more than once on a non-persistent tape.
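One way to read output_gradients is as the weights of a vector-Jacobian product: each element of target is weighted by its entry in output_gradients, and the weighted contributions are summed for each source. With the default of None every weight is one, so the result is the gradient of the sum of the targets. The sketch below shows only that arithmetic, on a hand-written Jacobian in plain Python. vjp is a hypothetical helper, not a TensorFlow function, and it assumes every target and source is a scalar.

```python
def vjp(jacobian, output_gradients=None):
    """Hypothetical helper: sum_i output_gradients[i] * jacobian[i][j].

    jacobian[i][j] holds d target[i] / d sources[j] for scalar targets/sources.
    """
    n_targets, n_sources = len(jacobian), len(jacobian[0])
    if output_gradients is None:
        # Default seed: a weight of one per target, i.e. the gradient of sum(target).
        output_gradients = [1.0] * n_targets
    return [
        sum(output_gradients[i] * jacobian[i][j] for i in range(n_targets))
        for j in range(n_sources)
    ]


# target = [x0 * x1, x0 + x1] evaluated at (x0, x1) = (2.0, 3.0)
jacobian = [
    [3.0, 2.0],  # d(x0*x1)/dx0 = x1, d(x0*x1)/dx1 = x0
    [1.0, 1.0],  # d(x0+x1)/dx0 = 1,  d(x0+x1)/dx1 = 1
]
print(vjp(jacobian))               # [4.0, 3.0]
print(vjp(jacobian, [10.0, 1.0]))  # [31.0, 21.0]
```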
reset()
Clears all information stored in this tape.
Equivalent to exiting and reentering the tape context manager with a new tape. For example, the following two code blocks are equivalent:
with tf.GradientTape() as t:
  loss = loss_fn()
with tf.GradientTape() as t:
  loss += other_loss_fn()
t.gradient(loss, ...)  # Only differentiates other_loss_fn, not loss_fn

# The following is equivalent to the above
with tf.GradientTape() as t:
  loss = loss_fn()
  t.reset()
  loss += other_loss_fn()
t.gradient(loss, ...)  # Only differentiates other_loss_fn, not loss_fn
This is useful if you don't want to exit the context manager for the tape, or can't because the desired reset point is inside a control flow construct:
with tf.GradientTape() as t:
  loss = ...
  if loss > k:
    t.reset()
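The effect of reset can be imitated with a toy tape in plain Python. Var, ToyTape, add, mul and the two stand-in losses are hypothetical and are not TensorFlow's implementation. Because the toy has no trainable variables, x is re-watched by hand after the reset, whereas TensorFlow watches trainable variables automatically.

```python
class Var:
    """Hypothetical scalar value standing in for a tensor."""

    def __init__(self, value):
        self.value = value


class ToyTape:
    """Hypothetical toy tape with reset(); not TensorFlow's implementation."""

    def __init__(self):
        self.watched = set()
        self.records = []  # (output, [(input, d_output/d_input), ...])

    def watch(self, v):
        self.watched.add(v)

    def reset(self):
        # Forget everything, as if a brand-new tape had been started.
        self.watched.clear()
        self.records.clear()

    def _op(self, value, inputs):
        out = Var(value)
        if any(inp in self.watched for inp, _ in inputs):
            self.watched.add(out)
            self.records.append((out, inputs))
        return out

    def add(self, a, b):
        return self._op(a.value + b.value, [(a, 1.0), (b, 1.0)])

    def mul(self, a, b):
        return self._op(a.value * b.value, [(a, b.value), (b, a.value)])

    def gradient(self, target, source):
        grads = {target: 1.0}
        for out, inputs in reversed(self.records):
            upstream = grads.get(out)
            if upstream is None:
                continue
            for inp, local in inputs:
                grads[inp] = grads.get(inp, 0.0) + upstream * local
        return grads.get(source)


x, five = Var(2.0), Var(5.0)


def run(use_reset):
    tape = ToyTape()
    tape.watch(x)
    loss = tape.mul(x, x)  # stand-in for loss_fn(): x**2
    if use_reset:
        tape.reset()
        tape.watch(x)  # re-watch by hand (no automatic variable watching here)
    loss = tape.add(loss, tape.mul(five, x))  # stand-in for other_loss_fn(): 5*x
    return tape.gradient(loss, x)


print(run(use_reset=False))  # 9.0: d(x**2 + 5x)/dx at x = 2
print(run(use_reset=True))   # 5.0: only the 5x term is differentiated
```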
stop_recording()
Temporarily stops recording operations on this tape.
Operations executed while this context manager is active will not be recorded on the tape. This is useful for reducing the memory used by tracing all computations.
For example:
with tf.GradientTape(persistent=True) as t:
  loss = compute_loss(model)
  with t.stop_recording():
    # The gradient computation below is not traced, saving memory.
    grads = t.gradient(loss, model.variables)
Yields:
  None
Raises:
  RuntimeError: if the tape is not currently recording.
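A minimal sketch of how such a pause can work: a recording flag that the context manager clears on entry and restores on exit, with a RuntimeError when the tape is not recording. ToyTape, its record method and the error message are hypothetical; the toy only logs operation names and is not TensorFlow's implementation.

```python
import contextlib


class ToyTape:
    """Hypothetical toy tape that only logs op names while recording."""

    def __init__(self):
        self.recording = False
        self.ops = []

    def __enter__(self):
        self.recording = True
        return self

    def __exit__(self, *exc_info):
        self.recording = False
        return False

    def record(self, op_name):
        if self.recording:
            self.ops.append(op_name)

    @contextlib.contextmanager
    def stop_recording(self):
        if not self.recording:
            raise RuntimeError("Tape is not currently recording.")
        self.recording = False
        try:
            yield
        finally:
            self.recording = True  # resume even if the body raised


with ToyTape() as t:
    t.record("compute_loss")
    with t.stop_recording():
        t.record("gradient")  # skipped: recording is paused
    t.record("after_pause")

print(t.ops)  # ['compute_loss', 'after_pause']

try:
    with t.stop_recording():  # t has exited its context, so this raises
        pass
except RuntimeError as err:
    print(err)  # Tape is not currently recording.
```

Restoring the flag in a finally clause keeps the tape recording after the paused block even if that block raises.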
watch(tensor)
Ensures that tensor is being traced by this tape.
Args:
  tensor: a Tensor or list of Tensors.
watched_variables()
Returns variables watched by this tape in order of construction.
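As we read it, "in order of construction" means the result is ordered by when each variable was created, not by the order in which the tape started watching it. A toy sketch of that ordering, assuming a per-variable creation counter; ToyVariable and ToyTape are hypothetical, and TensorFlow's own bookkeeping may differ.

```python
import itertools

_next_uid = itertools.count()  # creation counter shared by all toy variables


class ToyVariable:
    """Hypothetical variable that remembers when it was constructed."""

    def __init__(self, name):
        self.name = name
        self.uid = next(_next_uid)


class ToyTape:
    """Hypothetical tape that reports watched variables by creation order."""

    def __init__(self):
        self._watched = {}  # uid -> variable; watching twice is a no-op

    def watch(self, variable):
        self._watched[variable.uid] = variable

    def watched_variables(self):
        return tuple(self._watched[uid] for uid in sorted(self._watched))


a, b, c = ToyVariable("a"), ToyVariable("b"), ToyVariable("c")
tape = ToyTape()
for v in (c, a, b, a):  # watch order differs from creation order
    tape.watch(v)

print([v.name for v in tape.watched_variables()])  # ['a', 'b', 'c']
```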