-
Notifications
You must be signed in to change notification settings - Fork 615
/
Copy pathaverage_model_checkpoint.py
96 lines (84 loc) · 3.5 KB
/
average_model_checkpoint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
from typing import Union

import tensorflow as tf
from typeguard import typechecked

from tensorflow_addons.optimizers.average_wrapper import AveragedOptimizerWrapper
class AverageModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
    r"""The callback that saves average model weights.

    The callback that should be used with optimizers that extend
    `tfa.optimizers.AveragedOptimizerWrapper`, i.e.,
    `tfa.optimizers.MovingAverage` and
    `tfa.optimizers.StochasticAverage` optimizers.
    It saves and, optionally, assigns the averaged weights.

    Args:
        update_weights: If `True`, assign the moving average weights
            to the model, and save them. If `False`, keep the old
            non-averaged weights, but the saved model uses the
            average weights.

        See `tf.keras.callbacks.ModelCheckpoint` for the other args.
        `save_freq` may be `"epoch"` or an integer number of batches,
        as accepted by `tf.keras.callbacks.ModelCheckpoint`.

    Raises:
        TypeError: When the model is not trained with an
            `AveragedOptimizerWrapper` subclass (checked in `set_model`
            and again before each save).
    """

    @typechecked
    def __init__(
        self,
        update_weights: bool,
        filepath: str,
        monitor: str = "val_loss",
        verbose: int = 0,
        save_best_only: bool = False,
        save_weights_only: bool = False,
        mode: str = "auto",
        save_freq: Union[str, int] = "epoch",
        **kwargs,
    ):
        self.update_weights = update_weights
        # Forward by keyword so each value reaches the intended parent
        # parameter regardless of the parent's positional ordering.
        super().__init__(
            filepath,
            monitor=monitor,
            verbose=verbose,
            save_best_only=save_best_only,
            save_weights_only=save_weights_only,
            mode=mode,
            save_freq=save_freq,
            **kwargs,
        )

    def _get_optimizer(self):
        """Return the model's optimizer, unwrapping a loss-scale wrapper.

        A `LossScaleOptimizer` / `LossScaleOptimizerV1` (matched by class
        name) is replaced by its `inner_optimizer`; any other optimizer is
        returned as-is.
        """
        optimizer = self.model.optimizer
        if type(optimizer).__name__ in ("LossScaleOptimizer", "LossScaleOptimizerV1"):
            optimizer = optimizer.inner_optimizer
        return optimizer

    def _get_averaged_optimizer(self):
        """Return the unwrapped optimizer, ensuring it is an averaging one.

        Raises:
            TypeError: If the unwrapped optimizer is not an
                `AveragedOptimizerWrapper`.
        """
        optimizer = self._get_optimizer()
        if not isinstance(optimizer, AveragedOptimizerWrapper):
            raise TypeError(
                "AverageModelCheckpoint is only used when training "
                "with MovingAverage or StochasticAverage"
            )
        return optimizer

    def set_model(self, model):
        """Attach `model` and validate that it uses an averaging optimizer.

        Raises:
            TypeError: If the model's (unwrapped) optimizer is not an
                `AveragedOptimizerWrapper`.
        """
        super().set_model(model)
        self._get_averaged_optimizer()

    def _save_model(self, *args, **kwargs):
        """Save a checkpoint that contains the averaged weights.

        With `update_weights=True` the averaged values are assigned to the
        model's trainable weights and kept there. Otherwise the current
        weights are snapshotted, the averages are assigned only for the
        duration of the save, and the snapshot is restored afterwards.

        Returns:
            Whatever the parent `_save_model` returns.
        """
        # Re-validate here: `set_model` checked once, but the model may have
        # been recompiled with a different optimizer since then.
        optimizer = self._get_averaged_optimizer()

        if self.update_weights:
            optimizer.assign_average_vars(self.model.trainable_weights)
            return super()._save_model(*args, **kwargs)

        # Note: `model.get_weights()` gives us the weights (non-ref)
        # whereas `model.variables` returns references to the variables,
        # so this snapshot is unaffected by the assignment below.
        non_avg_weights = self.model.get_weights()
        try:
            optimizer.assign_average_vars(self.model.trainable_weights)
            return super()._save_model(*args, **kwargs)
        finally:
            # Restore the training weights even if assigning or saving
            # fails, so training never silently continues on the averages.
            self.model.set_weights(non_avg_weights)