-
Notifications
You must be signed in to change notification settings - Fork 615
/
Copy pathcyclical_learning_rate.py
324 lines (271 loc) · 12.3 KB
/
cyclical_learning_rate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Cyclical Learning Rate Schedule policies for TensorFlow."""
import tensorflow as tf
from tensorflow_addons.utils.types import FloatTensorLike
from typeguard import typechecked
from typing import Union, Callable
@tf.keras.utils.register_keras_serializable(package="Addons")
class CyclicalLearningRate(tf.keras.optimizers.schedules.LearningRateSchedule):
    """A LearningRateSchedule that uses cyclical schedule.

    The learning rate follows a triangular wave between
    `initial_learning_rate` and `maximal_learning_rate` whose rising (and
    falling) edge lasts `step_size` steps. The amplitude of the wave is
    multiplied by `scale_fn`, evaluated on either the cycle index or the step
    depending on `scale_mode`.
    """

    @typechecked
    def __init__(
        self,
        initial_learning_rate: Union[FloatTensorLike, Callable],
        maximal_learning_rate: Union[FloatTensorLike, Callable],
        step_size: FloatTensorLike,
        scale_fn: Callable,
        scale_mode: str = "cycle",
        name: str = "CyclicalLearningRate",
    ):
        """Applies cyclical schedule to the learning rate.

        See Cyclical Learning Rates for Training Neural Networks.
        https://arxiv.org/abs/1506.01186

        ```python
        import tensorflow_addons as tfa

        lr_schedule = tfa.optimizers.CyclicalLearningRate(
            initial_learning_rate=1e-4,
            maximal_learning_rate=1e-2,
            step_size=2000,
            scale_fn=lambda x: 1.,
            scale_mode="cycle",
            name="MyCyclicScheduler")

        model.compile(optimizer=tf.keras.optimizers.SGD(
                                                learning_rate=lr_schedule),
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])

        model.fit(data, labels, epochs=5)
        ```

        You can pass this schedule directly into a
        `tf.keras.optimizers.legacy.Optimizer` as the learning rate.

        Args:
            initial_learning_rate: A scalar `float32` or `float64` `Tensor` or
                a Python number. The initial (lower-bound) learning rate. Its
                dtype determines the dtype of the computed schedule.
            maximal_learning_rate: A scalar `float32` or `float64` `Tensor` or
                a Python number. The maximum learning rate, reached at the
                peak of each cycle (before `scale_fn` is applied).
            step_size: A scalar `float32` or `float64` `Tensor` or a
                Python number. Number of training iterations it takes to go
                from `initial_learning_rate` to `maximal_learning_rate`
                (half of a full cycle).
            scale_fn: A callable of one argument returning the factor that
                scales the amplitude of the wave.
            scale_mode: ['cycle', 'iterations']. With "cycle", `scale_fn`
                receives the 1-based cycle index; with any other value it
                receives the step. Both are passed as tensors in the
                schedule's dtype.
            name: (Optional) Name for the operation.
        """
        super().__init__()
        # NOTE(review): the annotations admit a Callable for the learning
        # rates, but __call__ passes them straight to tf.convert_to_tensor /
        # tf.cast — confirm callables are actually supported.
        self.initial_learning_rate = initial_learning_rate
        self.maximal_learning_rate = maximal_learning_rate
        self.step_size = step_size
        self.scale_fn = scale_fn
        self.scale_mode = scale_mode
        self.name = name

    def __call__(self, step):
        """Return the learning rate for `step`.

        Args:
            step: The current training step (a tensor or Python number). It
                is cast to the dtype of `initial_learning_rate`.

        Returns:
            The learning rate, with the dtype of `initial_learning_rate`.
        """
        with tf.name_scope(self.name or "CyclicalLearningRate"):
            initial_learning_rate = tf.convert_to_tensor(
                self.initial_learning_rate, name="initial_learning_rate"
            )
            dtype = initial_learning_rate.dtype
            maximal_learning_rate = tf.cast(self.maximal_learning_rate, dtype)
            step_size = tf.cast(self.step_size, dtype)
            step_as_dtype = tf.cast(step, dtype)
            # 1-based index of the cycle that contains `step`.
            cycle = tf.floor(1 + step_as_dtype / (2 * step_size))
            # Distance from the peak in half-cycle units: 1 at cycle
            # boundaries, 0 at the peak.
            x = tf.abs(step_as_dtype / step_size - 2 * cycle + 1)
            # Hand scale_fn a value in the schedule's floating dtype in both
            # modes; passing the raw (possibly integer) step breaks float
            # scale functions such as `gamma ** x` and the final product.
            mode_step = cycle if self.scale_mode == "cycle" else step_as_dtype
            return initial_learning_rate + (
                maximal_learning_rate - initial_learning_rate
            ) * tf.maximum(tf.cast(0, dtype), (1 - x)) * self.scale_fn(mode_step)

    def get_config(self):
        """Return the constructor arguments needed to rebuild this schedule."""
        return {
            "initial_learning_rate": self.initial_learning_rate,
            "maximal_learning_rate": self.maximal_learning_rate,
            "scale_fn": self.scale_fn,
            "step_size": self.step_size,
            "scale_mode": self.scale_mode,
            "name": self.name,
        }
@tf.keras.utils.register_keras_serializable(package="Addons")
class TriangularCyclicalLearningRate(CyclicalLearningRate):
    """Cyclical learning rate whose wave keeps a constant amplitude."""

    @typechecked
    def __init__(
        self,
        initial_learning_rate: Union[FloatTensorLike, Callable],
        maximal_learning_rate: Union[FloatTensorLike, Callable],
        step_size: FloatTensorLike,
        scale_mode: str = "cycle",
        name: str = "TriangularCyclicalLearningRate",
    ):
        """Applies triangular cyclical schedule to the learning rate.

        See Cyclical Learning Rates for Training Neural Networks.
        https://arxiv.org/abs/1506.01186

        ```python
        import tensorflow_addons as tfa

        lr_schedule = tfa.optimizers.TriangularCyclicalLearningRate(
            initial_learning_rate=1e-4,
            maximal_learning_rate=1e-2,
            step_size=2000,
            scale_mode="cycle",
            name="MyCyclicScheduler")

        model.compile(optimizer=tf.keras.optimizers.SGD(
                                                learning_rate=lr_schedule),
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])

        model.fit(data, labels, epochs=5)
        ```

        You can pass this schedule directly into a
        `tf.keras.optimizers.legacy.Optimizer` as the learning rate.

        Args:
            initial_learning_rate: A scalar `float32` or `float64` `Tensor` or
                a Python number. The initial learning rate.
            maximal_learning_rate: A scalar `float32` or `float64` `Tensor` or
                a Python number. The maximum learning rate.
            step_size: A scalar `float32` or `float64` `Tensor` or a
                Python number. Number of training iterations it takes to go
                from `initial_learning_rate` to `maximal_learning_rate`.
            scale_mode: ['cycle', 'iterations']. Mode to apply during cyclic
                schedule. The scale factor here is constant, so this does not
                change the resulting values.
            name: (Optional) Name for the operation.
        """
        super().__init__(
            initial_learning_rate=initial_learning_rate,
            maximal_learning_rate=maximal_learning_rate,
            step_size=step_size,
            # Constant scale: every cycle has the full amplitude.
            scale_fn=lambda x: 1.0,
            scale_mode=scale_mode,
            name=name,
        )

    def get_config(self):
        """Return the constructor arguments needed to rebuild this schedule.

        `scale_fn` is omitted because this class fixes it.
        """
        return {
            "initial_learning_rate": self.initial_learning_rate,
            "maximal_learning_rate": self.maximal_learning_rate,
            "step_size": self.step_size,
            "scale_mode": self.scale_mode,
            "name": self.name,
        }
@tf.keras.utils.register_keras_serializable(package="Addons")
class Triangular2CyclicalLearningRate(CyclicalLearningRate):
    """Cyclical learning rate whose amplitude is scaled by 1 / 2**(x - 1)."""

    @typechecked
    def __init__(
        self,
        initial_learning_rate: Union[FloatTensorLike, Callable],
        maximal_learning_rate: Union[FloatTensorLike, Callable],
        step_size: FloatTensorLike,
        scale_mode: str = "cycle",
        name: str = "Triangular2CyclicalLearningRate",
    ):
        """Applies triangular2 cyclical schedule to the learning rate.

        See Cyclical Learning Rates for Training Neural Networks.
        https://arxiv.org/abs/1506.01186

        ```python
        import tensorflow_addons as tfa

        lr_schedule = tfa.optimizers.Triangular2CyclicalLearningRate(
            initial_learning_rate=1e-4,
            maximal_learning_rate=1e-2,
            step_size=2000,
            scale_mode="cycle",
            name="MyCyclicScheduler")

        model.compile(optimizer=tf.keras.optimizers.SGD(
                                                learning_rate=lr_schedule),
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])

        model.fit(data, labels, epochs=5)
        ```

        You can pass this schedule directly into a
        `tf.keras.optimizers.legacy.Optimizer` as the learning rate.

        Args:
            initial_learning_rate: A scalar `float32` or `float64` `Tensor` or
                a Python number. The initial learning rate.
            maximal_learning_rate: A scalar `float32` or `float64` `Tensor` or
                a Python number. The maximum learning rate.
            step_size: A scalar `float32` or `float64` `Tensor` or a
                Python number. Number of training iterations it takes to go
                from `initial_learning_rate` to `maximal_learning_rate`.
            scale_mode: ['cycle', 'iterations']. Value fed to the scale
                function. With "cycle", the amplitude halves on every new
                cycle.
            name: (Optional) Name for the operation.
        """
        super().__init__(
            initial_learning_rate=initial_learning_rate,
            maximal_learning_rate=maximal_learning_rate,
            step_size=step_size,
            # 1 for x == 1, 0.5 for x == 2, 0.25 for x == 3, ...
            scale_fn=lambda x: 1 / (2.0 ** (x - 1)),
            scale_mode=scale_mode,
            name=name,
        )

    def get_config(self):
        """Return the constructor arguments needed to rebuild this schedule.

        `scale_fn` is omitted because this class fixes it.
        """
        return {
            "initial_learning_rate": self.initial_learning_rate,
            "maximal_learning_rate": self.maximal_learning_rate,
            "step_size": self.step_size,
            "scale_mode": self.scale_mode,
            "name": self.name,
        }
@tf.keras.utils.register_keras_serializable(package="Addons")
class ExponentialCyclicalLearningRate(CyclicalLearningRate):
    """Cyclical learning rate whose amplitude is scaled by `gamma ** x`."""

    @typechecked
    def __init__(
        self,
        initial_learning_rate: Union[FloatTensorLike, Callable],
        maximal_learning_rate: Union[FloatTensorLike, Callable],
        step_size: FloatTensorLike,
        scale_mode: str = "iterations",
        gamma: FloatTensorLike = 1.0,
        name: str = "ExponentialCyclicalLearningRate",
    ):
        """Applies exponential cyclical schedule to the learning rate.

        See Cyclical Learning Rates for Training Neural Networks.
        https://arxiv.org/abs/1506.01186

        ```python
        import tensorflow_addons as tfa

        lr_schedule = tfa.optimizers.ExponentialCyclicalLearningRate(
            initial_learning_rate=1e-4,
            maximal_learning_rate=1e-2,
            step_size=2000,
            scale_mode="cycle",
            gamma=0.96,
            name="MyCyclicScheduler")

        model.compile(optimizer=tf.keras.optimizers.SGD(
                                                learning_rate=lr_schedule),
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])

        model.fit(data, labels, epochs=5)
        ```

        You can pass this schedule directly into a
        `tf.keras.optimizers.legacy.Optimizer` as the learning rate.

        Args:
            initial_learning_rate: A scalar `float32` or `float64` `Tensor` or
                a Python number. The initial learning rate.
            maximal_learning_rate: A scalar `float32` or `float64` `Tensor` or
                a Python number. The maximum learning rate.
            step_size: A scalar `float32` or `float64` `Tensor` or a
                Python number. Number of training iterations it takes to go
                from `initial_learning_rate` to `maximal_learning_rate`.
            scale_mode: ['cycle', 'iterations']. Value used as the exponent
                of `gamma`. Defaults to "iterations" (the step).
            gamma: A scalar `float32` or `float64` `Tensor` or a
                Python number. Base of the exponential amplitude scaling.
                The default of 1.0 keeps the amplitude constant.
            name: (Optional) Name for the operation.
        """
        super().__init__(
            initial_learning_rate=initial_learning_rate,
            maximal_learning_rate=maximal_learning_rate,
            step_size=step_size,
            scale_fn=lambda x: gamma**x,
            scale_mode=scale_mode,
            name=name,
        )
        # Kept only so get_config can serialize it; scale_fn captured it above.
        self.gamma = gamma

    def get_config(self):
        """Return the constructor arguments needed to rebuild this schedule.

        `scale_fn` is omitted because it is derived from `gamma`.
        """
        return {
            "initial_learning_rate": self.initial_learning_rate,
            "maximal_learning_rate": self.maximal_learning_rate,
            "step_size": self.step_size,
            "scale_mode": self.scale_mode,
            "gamma": self.gamma,
            "name": self.name,
        }