-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
Copy pathbayesian_resnet.py
128 lines (108 loc) · 4.16 KB
/
bayesian_resnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# Copyright 2018 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Builds a Bayesian ResNet18 Model."""
import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import tf_keras
def bayesian_resnet(input_shape,
                    num_classes=10,
                    kernel_posterior_scale_mean=-9.0,
                    kernel_posterior_scale_stddev=0.1,
                    kernel_posterior_scale_constraint=0.2):
  """Constructs a ResNet18 model with Flipout (Bayesian) layers.

  Args:
    input_shape: A `tuple` indicating the Tensor shape.
    num_classes: `int` representing the number of class labels.
    kernel_posterior_scale_mean: Python `int` number for the kernel
      posterior's scale (log variance) mean. The smaller the mean the closer
      is the initialization to a deterministic network.
    kernel_posterior_scale_stddev: Python `float` number for the initial
      kernel posterior's scale stddev.
      ```
      q(W|x) ~ N(mu, var),
      log_var ~ N(kernel_posterior_scale_mean, kernel_posterior_scale_stddev)
      ```
    kernel_posterior_scale_constraint: Python `float` number for the log
      value to constrain the log variance throughout training,
      i.e. log_var <= log(kernel_posterior_scale_constraint).

  Returns:
    tf_keras.Model.
  """
  # One (filters, kernel_size, stride) triple per ResNet stage.
  stage_configs = [
      (64, 3, 1),
      (128, 3, 2),
      (256, 3, 2),
      (512, 3, 2),
  ]

  def _clip_scale(t):
    # Cap the untransformed scale so the posterior variance stays bounded.
    return tf.clip_by_value(t, -1000,
                            tf.math.log(kernel_posterior_scale_constraint))

  # Mean-field normal posterior whose scale is initialized near zero
  # (i.e. close to a deterministic network) and clipped during training.
  posterior_fn = tfp.layers.default_mean_field_normal_fn(
      untransformed_scale_initializer=tf.compat.v1.initializers.random_normal(
          mean=kernel_posterior_scale_mean,
          stddev=kernel_posterior_scale_stddev),
      untransformed_scale_constraint=_clip_scale)

  image_input = tf_keras.layers.Input(shape=input_shape, dtype='float32')

  # Stem: a single 3x3 Flipout convolution.
  net = tfp.layers.Convolution2DFlipout(
      64,
      3,
      strides=1,
      padding='same',
      kernel_posterior_fn=posterior_fn)(image_input)

  # Residual stages.
  for stage_filters, stage_kernel, stage_stride in stage_configs:
    net = _resnet_block(
        net,
        stage_filters,
        stage_kernel,
        stage_stride,
        posterior_fn)

  # Head: BN -> ReLU -> pool -> flatten -> Bayesian dense classifier.
  net = tf_keras.layers.BatchNormalization()(net)
  net = tf_keras.layers.Activation('relu')(net)
  net = tf_keras.layers.AveragePooling2D(4, 1)(net)
  net = tf_keras.layers.Flatten()(net)
  net = tfp.layers.DenseFlipout(
      num_classes,
      kernel_posterior_fn=posterior_fn)(net)

  return tf_keras.Model(inputs=image_input, outputs=net, name='resnet18')
def _resnet_block(x, filters, kernel, stride, kernel_posterior_fn):
  """Network block for ResNet.

  Applies pre-activation (BN -> ReLU), two Flipout convolutions, and a
  residual add, projecting the shortcut whenever the spatial or channel
  dimensions change.

  Args:
    x: Input `Tensor`; assumed channels-last (NHWC), the Keras default
      data_format used by all layers in this file.
    filters: `int` number of output filters for both convolutions.
    kernel: `int` kernel size for both convolutions.
    stride: `int` stride for the first convolution (and the projection).
    kernel_posterior_fn: Posterior function forwarded to the Flipout layers.

  Returns:
    Output `Tensor` of the block.
  """
  x = tf_keras.layers.BatchNormalization()(x)
  x = tf_keras.layers.Activation('relu')(x)

  # Fix: with channels-last tensors the channel count is x.shape[-1];
  # the original compared `filters` to x.shape[1] (the spatial height),
  # which forces a projection shortcut almost always and would skip a
  # genuinely needed projection whenever `filters` equaled the height.
  if stride != 1 or filters != x.shape[-1]:
    shortcut = _projection_shortcut(x, filters, stride, kernel_posterior_fn)
  else:
    shortcut = x

  x = tfp.layers.Convolution2DFlipout(
      filters,
      kernel,
      strides=stride,
      padding='same',
      kernel_posterior_fn=kernel_posterior_fn)(x)
  x = tf_keras.layers.BatchNormalization()(x)
  x = tf_keras.layers.Activation('relu')(x)
  x = tfp.layers.Convolution2DFlipout(
      filters,
      kernel,
      strides=1,
      padding='same',
      kernel_posterior_fn=kernel_posterior_fn)(x)
  x = tf_keras.layers.add([x, shortcut])
  return x
def _projection_shortcut(x, out_filters, stride, kernel_posterior_fn):
  """Projects the shortcut branch to `out_filters` channels via a 1x1 conv."""
  projection = tfp.layers.Convolution2DFlipout(
      out_filters,
      1,
      strides=stride,
      padding='valid',
      kernel_posterior_fn=kernel_posterior_fn)
  return projection(x)