-
Notifications
You must be signed in to change notification settings - Fork 2k
/
Copy pathexports_initializers.ts
210 lines (195 loc) · 6.73 KB
/
exports_initializers.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
// tslint:disable-next-line:max-line-length
import {Constant, ConstantArgs, GlorotNormal, GlorotUniform, HeNormal, HeUniform, Identity, IdentityArgs, Initializer, LeCunNormal, LeCunUniform, Ones, Orthogonal, OrthogonalArgs, RandomNormal, RandomNormalArgs, RandomUniform, RandomUniformArgs, SeedOnlyInitializerArgs, TruncatedNormal, TruncatedNormalArgs, VarianceScaling, VarianceScalingArgs, Zeros} from './initializers';
/**
 * Creates an initializer whose generated tensors are filled entirely with 0.
 *
 * @returns A new `Zeros` initializer instance.
 *
 * @doc {heading: 'Initializers', namespace: 'initializers'}
 */
export function zeros(): Zeros {
  const initializer = new Zeros();
  return initializer;
}
/**
 * Creates an initializer whose generated tensors are filled entirely with 1.
 *
 * @returns A new `Ones` initializer, typed as the base `Initializer`.
 *
 * @doc {heading: 'Initializers', namespace: 'initializers'}
 */
export function ones(): Initializer {
  const initializer = new Ones();
  return initializer;
}
/**
 * Creates an initializer whose generated values are all set to a single
 * constant.
 *
 * @param args Configuration forwarded unchanged to the `Constant` constructor.
 * @returns A new `Constant` initializer.
 *
 * @doc {heading: 'Initializers', namespace: 'initializers'}
 */
export function constant(args: ConstantArgs): Initializer {
  const initializer = new Constant(args);
  return initializer;
}
/**
 * Creates an initializer that draws random values from a uniform
 * distribution.
 *
 * Generated values lie uniformly between the configured `minval` and
 * `maxval`.
 *
 * @param args Configuration forwarded unchanged to the `RandomUniform`
 *     constructor.
 * @returns A new `RandomUniform` initializer.
 *
 * @doc {heading: 'Initializers', namespace: 'initializers'}
 */
export function randomUniform(args: RandomUniformArgs): Initializer {
  const initializer = new RandomUniform(args);
  return initializer;
}
/**
 * Creates an initializer that draws random values from a normal
 * distribution.
 *
 * @param args Configuration forwarded unchanged to the `RandomNormal`
 *     constructor.
 * @returns A new `RandomNormal` initializer.
 *
 * @doc {heading: 'Initializers', namespace: 'initializers'}
 */
export function randomNormal(args: RandomNormalArgs): Initializer {
  const initializer = new RandomNormal(args);
  return initializer;
}
/**
 * Creates an initializer that draws random values from a truncated normal
 * distribution.
 *
 * The values resemble those of a `RandomNormal` initializer, except that any
 * value more than two standard deviations away from the mean is discarded and
 * drawn again. This is the recommended initializer for neural network weights
 * and filters.
 *
 * @param args Configuration forwarded unchanged to the `TruncatedNormal`
 *     constructor.
 * @returns A new `TruncatedNormal` initializer.
 *
 * @doc {heading: 'Initializers', namespace: 'initializers'}
 */
export function truncatedNormal(args: TruncatedNormalArgs): Initializer {
  const initializer = new TruncatedNormal(args);
  return initializer;
}
/**
 * Creates an initializer that produces the identity matrix.
 *
 * Intended for square 2D matrices only.
 *
 * @param args Configuration forwarded unchanged to the `Identity`
 *     constructor.
 * @returns A new `Identity` initializer.
 *
 * @doc {heading: 'Initializers', namespace: 'initializers'}
 */
export function identity(args: IdentityArgs): Initializer {
  const initializer = new Identity(args);
  return initializer;
}
/**
 * Creates an initializer that adapts its scale to the shape of the weights.
 *
 * With `distribution = NORMAL`, samples come from a truncated normal
 * distribution centered on zero with `stddev = sqrt(scale / n)`, where `n` is:
 *   - the number of input units in the weight tensor, if `mode = FAN_IN`;
 *   - the number of output units, if `mode = FAN_OUT`;
 *   - the average of the input and output unit counts, if `mode = FAN_AVG`.
 *
 * With `distribution = UNIFORM`, samples come from a uniform distribution
 * over `[-limit, limit]` with `limit = sqrt(3 * scale / n)`.
 *
 * @param config Configuration forwarded unchanged to the `VarianceScaling`
 *     constructor.
 * @returns A new `VarianceScaling` initializer.
 *
 * @doc {heading: 'Initializers', namespace: 'initializers'}
 */
export function varianceScaling(config: VarianceScalingArgs): Initializer {
  const initializer = new VarianceScaling(config);
  return initializer;
}
/**
 * Glorot uniform initializer, also called Xavier uniform initializer.
 *
 * It draws samples from a uniform distribution within `[-limit, limit]`,
 * where `limit` is `sqrt(6 / (fan_in + fan_out))`. Here `fan_in` is the
 * number of input units in the weight tensor and `fan_out` is the number of
 * output units in the weight tensor.
 *
 * Reference:
 *   Glorot & Bengio, AISTATS 2010
 *   http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
 *
 * @param args Seed-only configuration forwarded to the `GlorotUniform`
 *     constructor.
 * @returns A new `GlorotUniform` initializer.
 *
 * @doc {heading: 'Initializers', namespace: 'initializers'}
 */
export function glorotUniform(args: SeedOnlyInitializerArgs): Initializer {
  return new GlorotUniform(args);
}
/**
 * Glorot normal initializer, also called Xavier normal initializer.
 *
 * It draws samples from a truncated normal distribution centered on 0 with
 * `stddev = sqrt(2 / (fan_in + fan_out))`. Here `fan_in` is the number of
 * input units in the weight tensor and `fan_out` is the number of output
 * units in the weight tensor.
 *
 * Reference:
 *   Glorot & Bengio, AISTATS 2010
 *   http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
 *
 * @param args Seed-only configuration forwarded to the `GlorotNormal`
 *     constructor.
 * @returns A new `GlorotNormal` initializer.
 *
 * @doc {heading: 'Initializers', namespace: 'initializers'}
 */
export function glorotNormal(args: SeedOnlyInitializerArgs): Initializer {
  return new GlorotNormal(args);
}
/**
 * He normal initializer.
 *
 * It draws samples from a truncated normal distribution centered on 0 with
 * `stddev = sqrt(2 / fanIn)`, where `fanIn` is the number of input units in
 * the weight tensor.
 *
 * Reference:
 *   He et al., https://arxiv.org/abs/1502.01852
 *
 * @param args Seed-only configuration forwarded to the `HeNormal`
 *     constructor.
 * @returns A new `HeNormal` initializer.
 *
 * @doc {heading: 'Initializers', namespace: 'initializers'}
 */
export function heNormal(args: SeedOnlyInitializerArgs): Initializer {
  return new HeNormal(args);
}
/**
 * He uniform initializer.
 *
 * It draws samples from a uniform distribution within `[-limit, limit]`,
 * where `limit` is `sqrt(6 / fanIn)` and `fanIn` is the number of input
 * units in the weight tensor.
 *
 * Reference:
 *   He et al., https://arxiv.org/abs/1502.01852
 *
 * @param args Seed-only configuration forwarded to the `HeUniform`
 *     constructor.
 * @returns A new `HeUniform` initializer.
 *
 * @doc {heading: 'Initializers', namespace: 'initializers'}
 */
export function heUniform(args: SeedOnlyInitializerArgs): Initializer {
  return new HeUniform(args);
}
/**
 * LeCun normal initializer.
 *
 * It draws samples from a truncated normal distribution centered on 0 with
 * `stddev = sqrt(1 / fanIn)`, where `fanIn` is the number of input units in
 * the weight tensor.
 *
 * References:
 *   [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
 *   [Efficient Backprop](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf)
 *
 * @param args Seed-only configuration forwarded to the `LeCunNormal`
 *     constructor.
 * @returns A new `LeCunNormal` initializer.
 *
 * @doc {heading: 'Initializers', namespace: 'initializers'}
 */
export function leCunNormal(args: SeedOnlyInitializerArgs): Initializer {
  return new LeCunNormal(args);
}
/**
 * Creates a LeCun uniform initializer.
 *
 * Samples are drawn uniformly from the interval `[-limit, limit]`, where
 * `limit = sqrt(3 / fanIn)` and `fanIn` is the number of input units in the
 * weight tensor.
 *
 * @param args Seed-only configuration forwarded unchanged to the
 *     `LeCunUniform` constructor.
 * @returns A new `LeCunUniform` initializer.
 *
 * @doc {heading: 'Initializers', namespace: 'initializers'}
 */
export function leCunUniform(args: SeedOnlyInitializerArgs): Initializer {
  const initializer = new LeCunUniform(args);
  return initializer;
}
/**
 * Initializer that generates a random orthogonal matrix.
 *
 * Reference:
 *   [Saxe et al., https://arxiv.org/abs/1312.6120](https://arxiv.org/abs/1312.6120)
 *
 * @param args Configuration forwarded to the `Orthogonal` constructor.
 * @returns A new `Orthogonal` initializer.
 *
 * @doc {heading: 'Initializers', namespace: 'initializers'}
 */
export function orthogonal(args: OrthogonalArgs): Initializer {
  return new Orthogonal(args);
}