-
Notifications
You must be signed in to change notification settings - Fork 45.6k
/
Copy pathloop_fns.py
207 lines (168 loc) · 7.84 KB
/
loop_fns.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
# Copyright 2024 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for creating loop functions."""
from absl import logging
from orbit.utils import tpu_summaries
import tensorflow as tf, tf_keras
def create_loop_fn(step_fn):
  """Creates a loop function driven by a Python `while` loop.

  Args:
    step_fn: A function taking a nested structure of `tf.data.Iterator` or
      `DistributedIterator`. There are no constraints on the return value of
      the function (except that it must be compatible with any `reduce_fn`
      provided to the returned `loop_fn`).

  Returns:
    A loop function taking required `iterator` and `num_steps` parameters, as
    well as optional `state` and `reduce_fn` parameters for accumulating state
    over multiple iterations of the loop. See the `loop_fn` definition below
    for additional details.
  """

  def loop_fn(iterator, num_steps, state=None, reduce_fn=None):
    """Makes `num_steps` calls to `step_fn(iterator)`.

    Additionally, state may be accumulated across iterations of the loop.
    Conceptually, state accumulation is handled roughly as follows:

        for _ in range(num_steps):
          step_outputs = step_fn(iterator)
          state = reduce_fn(state, step_outputs)
        return state

    However, the implementation is slightly more complicated in order to
    support looping until the iterator is exhausted (when `num_steps == -1`)
    and to properly catch exceptions when running under async remote eager (as
    is the case in TPU training setups involving separate coordinator/worker
    machines).

    Args:
      iterator: A nested structure of `tf.data.Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps in the loop. If `num_steps == -1`, will
        iterate until exhausting the iterator.
      state: An optional initial state before running the loop.
      reduce_fn: A callable taking two inputs, `state` and `value`, where
        `state` is the previous output from `reduce_fn`, and `value` is the
        output from `step_fn`.

    Returns:
      The final state returned by `reduce_fn`, or `None` if `state` and
      `reduce_fn` are not provided.
    """
    steps_done = 0
    run_until_exhausted = num_steps == -1
    try:
      # Wrapping the loop body in `async_scope` ensures that an
      # `OutOfRangeError` raised by an async remote eager worker is surfaced
      # here (and cleared below) rather than being lost.
      with tf.experimental.async_scope():
        while run_until_exhausted or steps_done < num_steps:
          step_outputs = step_fn(iterator)
          if reduce_fn is not None:
            state = reduce_fn(state, step_outputs)
          steps_done += 1
        return state
    except (StopIteration, tf.errors.OutOfRangeError):
      logging.info("The dataset iterator is exhausted after %d steps.",
                   steps_done)
      # Clear the pending async error so subsequent ops are not poisoned.
      tf.experimental.async_clear_error()
      return state

  return loop_fn
def create_tf_while_loop_fn(step_fn):
  """Creates a loop function compatible with TF's AutoGraph loop conversion.

  Args:
    step_fn: A function taking a nested structure of `tf.data.Iterator` or
      `DistributedIterator`. Currently, any return values are ignored.

  Returns:
    A loop function taking required `iterator` and `num_steps` parameters. If
    called inside a `tf.function`, the loop will be converted by AutoGraph into
    a `tf.while_loop` construct. See the `loop_fn` definition below for
    additional details.
  """

  def loop_fn(iterator, num_steps):
    """Makes `num_steps` calls to `step_fn(iterator)`.

    Args:
      iterator: A nested structure of `tf.data.Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps in the loop. Should be passed as a
        `tf.Tensor`. Iterating until iterator exhaustion is not supported.

    Raises:
      ValueError: If `num_steps` is not a `tf.Tensor`.
    """
    # A Python-int `num_steps` would be baked into the traced graph, forcing a
    # retrace for each distinct value; requiring a tensor avoids that.
    if not isinstance(num_steps, tf.Tensor):
      raise ValueError(
          "`num_steps` should be a `tf.Tensor`. Passing a Python value can "
          "cause unnecessary retracing when wrapped by `tf.function`.")

    # `tf.range` (rather than Python `range`) is what lets AutoGraph convert
    # this loop into a `tf.while_loop` when traced.
    for _ in tf.range(num_steps):
      # Clear out the outer name scope so the ops created inside
      # `tf.while_loop` don't get "while/" as name prefix.
      with tf.name_scope(""):
        step_fn(iterator)

  return loop_fn
def create_tf_while_loop_fn_with_state(step_fn):
  """Creates a TF while loop function with state.

  This function is similar to `create_tf_while_loop_fn`, but allows a `state`
  to be accumulated over multiple iterations of the loop. Note that the
  structure of the `state` cannot be changed across iterations.

  Args:
    step_fn: A function taking a nested structure of `tf.data.Iterator` or
      `DistributedIterator`. Currently, any return values are ignored.

  Returns:
    A loop function taking required `iterator`, `num_steps`, `state` and
    `reduce_fn` parameters. If called inside a `tf.function`, the loop will be
    converted by AutoGraph into a `tf.while_loop` construct. See the `loop_fn`
    definition below for additional details.
  """

  def loop_fn_with_state(iterator, num_steps, state, reduce_fn):
    """Makes `num_steps` calls to `step_fn(iterator)`.

    Args:
      iterator: A nested structure of `tf.data.Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps in the loop. Should be passed as a
        `tf.Tensor`. Iterating until iterator exhaustion is not supported.
      state: An initial state before running the loop.
      reduce_fn: A callable taking two inputs, `state` and `value`, where
        `state` is the previous output from `reduce_fn`, and `value` is the
        output from `step_fn`.

    Returns:
      The final state returned by `reduce_fn`.

    Raises:
      ValueError: If `num_steps` is not a `tf.Tensor`.
    """
    # A Python-int `num_steps` would be baked into the traced graph, forcing a
    # retrace for each distinct value; requiring a tensor avoids that.
    if not isinstance(num_steps, tf.Tensor):
      raise ValueError(
          "`num_steps` should be a `tf.Tensor`. Passing a Python value can "
          "cause unnecessary retracing when wrapped by `tf.function`.")

    def _get_relaxed_tensor_shape(t):
      """Returns a `TensorShape` with all dimensions relaxed to `None`."""
      if not tf.is_tensor(t):
        return None
      shape = t.shape
      if shape.rank is not None and shape.rank > 0:
        return tf.TensorShape([None] * shape.rank)
      # Scalar or unknown-rank shapes need no relaxation.
      return shape

    def _get_relaxed_shape_structure(s):
      """Returns the relaxed shapes of the input nested structure `s`."""
      # Fix: pack against the argument `s` rather than the closed-over
      # `state`; the original version silently ignored its parameter, which
      # only happened to work because the sole call site passes `state`.
      return tf.nest.pack_sequence_as(
          s, [_get_relaxed_tensor_shape(t) for t in tf.nest.flatten(s)])

    # `tf.range` (rather than Python `range`) is what lets AutoGraph convert
    # this loop into a `tf.while_loop` when traced.
    for _ in tf.range(num_steps):
      # Clear out the outer name scope so the ops created inside
      # `tf.while_loop` don't get "while/" as name prefix.
      with tf.name_scope(""):
        # Relax the shapes within the loop, so the shape of `state` can change
        # across iterations. This is useful to aggregate outputs from each
        # step and concat to `state`.
        tf.autograph.experimental.set_loop_options(
            shape_invariants=[(state, _get_relaxed_shape_structure(state))])
        outputs = step_fn(iterator)
        state = reduce_fn(state, outputs)
    return state

  return loop_fn_with_state
class LoopFnWithSummaries(tpu_summaries.OptionalSummariesFunction):
  """Implements a two-program approach for optimizing summaries on TPU.

  This version works with the result of `create_tf_while_loop_fn`.
  """

  def __call__(self, iterator, num_steps):
    """Runs `num_steps` steps, writing summaries on at most the first step.

    If summary recording is enabled, exactly one step is run through the
    summary-enabled program (`self.with_summaries`) and the remaining
    `num_steps - 1` steps through the summary-free program
    (`self.without_summaries`), which avoids paying the summary overhead on
    every step.
    """
    if tf.summary.should_record_summaries():
      # Run a single step with summaries; `tf.constant(1)` keeps the step
      # count a tensor, matching the loop fn's retracing-avoidance contract.
      output = self.with_summaries(iterator, tf.constant(1))
      num_steps -= 1
    # NOTE(review): if summaries are off and `num_steps < 1`, `output` is
    # never assigned and this raises `NameError` — presumably callers always
    # pass `num_steps >= 1`; confirm against call sites.
    if num_steps >= 1:
      output = self.without_summaries(iterator, num_steps)
    return output