-
Notifications
You must be signed in to change notification settings - Fork 19.6k
/
Copy pathconstant_initializers.py
153 lines (117 loc) · 4.77 KB
/
constant_initializers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.backend import standardize_dtype
from keras.src.initializers.initializer import Initializer
from keras.src.saving import serialization_lib
@keras_export(["keras.initializers.Constant", "keras.initializers.constant"])
class Constant(Initializer):
    """Initializer that fills tensors with a single constant value.

    Only scalar values are allowed.
    The constant must be castable to whatever dtype is requested when the
    initializer is called.

    Examples:

    >>> # Standalone usage:
    >>> initializer = Constant(10.)
    >>> values = initializer(shape=(2, 2))

    >>> # Usage in a Keras layer:
    >>> initializer = Constant(10.)
    >>> layer = Dense(3, kernel_initializer=initializer)

    Args:
        value: A Python scalar.
    """

    def __init__(self, value=0.0):
        self.value = value

    def __call__(self, shape, dtype=None):
        """Return a tensor of `shape` whose entries all equal `self.value`.

        Args:
            shape: Shape of the tensor.
            dtype: Optional dtype of the tensor; it is resolved through
                `standardize_dtype` before being used.
        """
        target_dtype = standardize_dtype(dtype)
        # Cast the stored value first, then multiply it by a tensor of ones
        # so the result takes on the requested shape and dtype.
        fill = ops.cast(self.value, dtype=target_dtype)
        return fill * ops.ones(shape=shape, dtype=target_dtype)

    def get_config(self):
        """Return a config dict with `value` serialized for saving."""
        serialized_value = serialization_lib.serialize_keras_object(self.value)
        return {"value": serialized_value}

    @classmethod
    def from_config(cls, config):
        """Rebuild the initializer from a dict produced by `get_config`."""
        return cls(serialization_lib.deserialize_keras_object(config["value"]))
@keras_export(["keras.initializers.Zeros", "keras.initializers.zeros"])
class Zeros(Initializer):
    """Initializer that produces tensors filled with 0.

    Examples:

    >>> # Standalone usage:
    >>> initializer = Zeros()
    >>> values = initializer(shape=(2, 2))

    >>> # Usage in a Keras layer:
    >>> initializer = Zeros()
    >>> layer = Dense(units=3, kernel_initializer=initializer)
    """

    def __call__(self, shape, dtype=None):
        """Build a tensor of the given shape where every entry is 0.

        Args:
            shape: Shape of the tensor.
            dtype: Optional dtype of the tensor. Only numeric or boolean dtypes
                are supported. When omitted, `keras.backend.floatx()` is used,
                which defaults to `float32` unless it has been reconfigured
                (via `keras.backend.set_floatx(float_dtype)`).
        """
        return ops.zeros(shape, dtype=standardize_dtype(dtype))
@keras_export(["keras.initializers.Ones", "keras.initializers.ones"])
class Ones(Initializer):
    """Initializer that produces tensors filled with 1.

    Also available via the shortcut function `ones`.

    Examples:

    >>> # Standalone usage:
    >>> initializer = Ones()
    >>> values = initializer(shape=(2, 2))

    >>> # Usage in a Keras layer:
    >>> initializer = Ones()
    >>> layer = Dense(3, kernel_initializer=initializer)
    """

    def __call__(self, shape, dtype=None):
        """Build a tensor of the given shape where every entry is 1.

        Args:
            shape: Shape of the tensor.
            dtype: Optional dtype of the tensor. Only numeric or boolean dtypes
                are supported. When omitted, `keras.backend.floatx()` is used,
                which defaults to `float32` unless it has been reconfigured
                (via `keras.backend.set_floatx(float_dtype)`).
        """
        resolved_dtype = standardize_dtype(dtype)
        return ops.ones(shape, dtype=resolved_dtype)
@keras_export(
    [
        "keras.initializers.IdentityInitializer",
        "keras.initializers.Identity",
        "keras.initializers.identity",
    ]
)
class Identity(Initializer):
    """Initializer that generates the identity matrix.

    Only usable for generating 2D matrices.

    Examples:

    >>> # Standalone usage:
    >>> initializer = Identity()
    >>> values = initializer(shape=(2, 2))

    >>> # Usage in a Keras layer:
    >>> initializer = Identity()
    >>> layer = Dense(3, kernel_initializer=initializer)

    Args:
        gain: Multiplicative factor to apply to the identity matrix.
    """

    def __init__(self, gain=1.0):
        self.gain = gain

    def __call__(self, shape, dtype=None):
        """Returns a tensor object initialized as specified by the initializer.

        Args:
            shape: Shape of the tensor. Must have exactly two dimensions;
                a non-square shape yields a rectangular matrix with ones on
                the main diagonal (as produced by `ops.eye(rows, cols)`).
            dtype: Optional dtype of the tensor. Only numeric or boolean dtypes
                are supported. If not specified, `keras.backend.floatx()`
                is used, which defaults to `float32` unless you configured it
                otherwise (via `keras.backend.set_floatx(float_dtype)`).

        Returns:
            `self.gain` times the identity matrix of the given shape.

        Raises:
            ValueError: If `shape` does not have exactly two dimensions.
        """
        # Validate rank before resolving the dtype so a bad shape is always
        # reported as a shape error.
        if len(shape) != 2:
            raise ValueError(
                "Identity matrix initializer can only be used for 2D matrices. "
                f"Received: shape={shape} of rank {len(shape)}."
            )
        dtype = standardize_dtype(dtype)
        return self.gain * ops.eye(*shape, dtype=dtype)

    def get_config(self):
        """Return the constructor arguments needed to rebuild this initializer.

        Without this override the inherited config is empty, so saving,
        loading or cloning an `Identity(gain=...)` would silently reset
        `gain` to its default of 1.0.
        """
        return {"gain": self.gain}