-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
Copy pathdistribution.py
2188 lines (1830 loc) · 85.1 KB
/
distribution.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2018 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Base classes for probability distributions."""
import abc
import collections
import contextlib
import functools
import inspect
import logging
import types
import decorator
import six
import tensorflow.compat.v2 as tf
from tensorflow_probability.python.distributions import kullback_leibler
from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import auto_composite_tensor
from tensorflow_probability.python.internal import batch_shape_lib
from tensorflow_probability.python.internal import distribution_util
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import name_util
from tensorflow_probability.python.internal import nest_util
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import slicing
from tensorflow_probability.python.internal import tensorshape_util
# Symbol import needed to avoid BUILD-dependency cycle
from tensorflow_probability.python.math.generic import log1mexp
from tensorflow.python.util import deprecation # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.util import nest # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.util import tf_inspect # pylint: disable=g-direct-tensorflow-import
# Public symbols exported by this module.
__all__ = [
    'Distribution',
]

# Maps each public `Distribution` method name to the private method a subclass
# implements. `_DistributionMeta.__new__` walks this table: when a subclass
# defines a documented private method (e.g. `_log_prob`), the public wrapper
# (e.g. `log_prob`) is copied onto the subclass and the private method's
# docstring is appended to the copy's docstring.
_DISTRIBUTION_PUBLIC_METHOD_WRAPPERS = {
    'batch_shape': '_batch_shape',
    'batch_shape_tensor': '_batch_shape_tensor',
    'cdf': '_cdf',
    'covariance': '_covariance',
    'cross_entropy': '_cross_entropy',
    'entropy': '_entropy',
    'event_shape': '_event_shape',
    'event_shape_tensor': '_event_shape_tensor',
    'experimental_default_event_space_bijector': (
        '_default_event_space_bijector'),
    'experimental_sample_and_log_prob': '_sample_and_log_prob',
    'kl_divergence': '_kl_divergence',
    'log_cdf': '_log_cdf',
    'log_prob': '_log_prob',
    'log_survival_function': '_log_survival_function',
    'mean': '_mean',
    'mode': '_mode',
    'prob': '_prob',
    'sample': '_sample_n',
    'stddev': '_stddev',
    'survival_function': '_survival_function',
    'variance': '_variance',
}

# Public wrappers that `_DistributionMeta.__new__` copies onto every subclass
# (unless the subclass overrides them), even when the subclass provides no
# docstring for the private counterpart.
_ALWAYS_COPY_PUBLIC_METHOD_WRAPPERS = ['kl_divergence', 'cross_entropy']

# Unique module-level sentinel object. Its uses are outside this view;
# NOTE(review): presumably marks "argument not supplied" -- confirm.
UNSET_VALUE = object()
# Keep this line textually unchanged: an external rewrite script replaces it.
JAX_MODE = False # Overwritten by rewrite script.
@six.add_metaclass(abc.ABCMeta)
class _BaseDistribution(tf.Module):
  """Abstract root class of the `Distribution` hierarchy.

  `_DistributionMeta.__new__` recognizes this class as the parent of
  `Distribution` itself and skips all of its subclass processing for it.
  """
def _copy_fn(fn):
  """Create a deep copy of fn.

  The copy shares `fn`'s code object, globals and closure cells, but is a
  distinct function object. Assigning to its attributes (e.g. `__doc__`)
  therefore does not affect `fn`.

  Args:
    fn: a callable (must be a Python function, since `__code__` is read).

  Returns:
    A `FunctionType`: a deep copy of fn.

  Raises:
    TypeError: if `fn` is not a callable.
  """
  if not callable(fn):
    raise TypeError('fn is not callable: {}'.format(fn))
  # The blessed way to copy a function: copy.deepcopy fails to create a
  # non-reference copy. `types.FunctionType` is the type of `lambda: None`, and
  # its constructor, `function(code, globals[, name[, argdefs[, closure]]])`,
  # builds a new function object from a code object plus globals. That lets us
  # reuse the old function's code, globals, closure, etc.
  fn_copy = types.FunctionType(
      code=fn.__code__, globals=fn.__globals__,
      name=fn.__name__, argdefs=fn.__defaults__,
      closure=fn.__closure__)
  # The constructor call above does not carry over keyword-only defaults.
  # Without this, a copy of `def f(*, k=1)` would raise TypeError when called
  # without `k`. Copy the dict so the two functions don't share it.
  if fn.__kwdefaults__:
    fn_copy.__kwdefaults__ = dict(fn.__kwdefaults__)
  return fn_copy
def _update_docstring(old_str, append_str):
  """Splice `append_str` into `old_str` just before its last 'Args:' line.

  Every line of `append_str` is indented first, which is needed for correct
  markdown generation. A line counts as an 'Args:' header when, after
  stripping whitespace, it equals 'args:' case-insensitively. When no such
  line exists, the indented text is appended at the end instead.

  Args:
    old_str: The existing docstring; `None` is treated as ''.
    append_str: Text to insert.

  Returns:
    The combined docstring.
  """
  base_doc = old_str or ''
  doc_lines = base_doc.split('\n')
  indented_addition = '\n'.join(
      ' %s' % addition_line for addition_line in append_str.split('\n'))

  args_line_indices = [
      idx for idx, doc_line in enumerate(doc_lines)
      if doc_line.strip().lower() == 'args:']
  if not args_line_indices:
    return base_doc + '\n\n' + indented_addition

  split_at = args_line_indices[-1]
  head = '\n'.join(doc_lines[:split_at])
  tail = '\n'.join(doc_lines[split_at:])
  return head + '\n\n' + indented_addition + '\n\n' + tail
def _remove_dict_keys_with_value(dict_, val):
  """Returns a new dict without the entries of `dict_` whose value *is* `val`.

  The comparison uses identity (`is`), not equality. The input dict is not
  modified, and the surviving entries keep their original order.
  """
  kept = {}
  for key, value in dict_.items():
    if value is not val:
      kept[key] = value
  return kept
def _set_sample_static_shape_for_tensor(x,
                                        event_shape,
                                        batch_shape,
                                        sample_shape):
  """Refine the static shape of a sample `Tensor` from the known shape parts.

  Helper to `_set_sample_static_shape`. `x` is treated as having layout
  `sample_shape + batch_shape + event_shape`. Every part whose rank is
  statically known is used to update `x`'s static shape in place through
  `tensorshape_util.set_shape`.

  Args:
    x: `Tensor` whose static shape is refined.
    event_shape: `TensorShape`-like static event shape (may be partial).
    batch_shape: `TensorShape`-like static batch shape (may be partial).
    sample_shape: Sample shape; only its statically-known value is used.

  Returns:
    `x`, after its static shape has been updated.
  """
  static_sample_shape = tf.TensorShape(tf.get_static_value(sample_shape))
  rank = tensorshape_util.rank(x.shape)
  n_sample = tensorshape_util.rank(static_sample_shape)
  n_batch = tensorshape_util.rank(batch_shape)
  n_event = tensorshape_util.rank(event_shape)

  # The total rank follows once all three parts have a known rank.
  parts_known = all(r is not None for r in (n_sample, n_batch, n_event))
  if rank is None and parts_known:
    rank = n_sample + n_batch + n_event
    tensorshape_util.set_shape(x, [None] * rank)

  # Leading dimensions come from the sample shape.
  if rank is not None and n_sample is not None:
    tensorshape_util.set_shape(
        x,
        tensorshape_util.concatenate(static_sample_shape,
                                     [None] * (rank - n_sample)))

  # Trailing dimensions come from the event shape.
  if rank is not None and n_event is not None:
    tensorshape_util.set_shape(
        x, tf.TensorShape([None] * (rank - n_event)).concatenate(event_shape))

  # The middle (batch) dimensions need both neighbouring ranks.
  if n_batch is None:
    return x
  if rank is not None:
    if n_sample is None and n_event is not None:
      n_sample = rank - n_batch - n_event
    elif n_event is None and n_sample is not None:
      n_event = rank - n_batch - n_sample
  if n_sample is not None and n_event is not None:
    batch_layout = tf.TensorShape([None] * n_sample).concatenate(
        batch_shape).concatenate([None] * n_event)
    tensorshape_util.set_shape(x, batch_layout)
  return x
class _DistributionMeta(abc.ABCMeta):
  """Helper metaclass for tfp.Distribution.

  When a `Distribution` subclass is created, this metaclass:
    1. Copies public wrappers such as `log_prob` onto the subclass and appends
       the docstring of the subclass's private implementation (e.g.
       `_log_prob`).
    2. If the subclass redefines `__init__` but inherits
       `_parameter_properties`, wraps the inherited method so a warning is
       logged when it is called (passthrough `__init__`s are exempt).
    3. Wraps the subclass `__init__` so that `self._parameters` is always
       populated after construction.
  """

  def __new__(mcs, classname, baseclasses, attrs):
    """Control the creation of subclasses of the Distribution class.

    The main purpose of this method is to properly propagate docstrings
    from private Distribution methods, like `_log_prob`, into their
    public wrappers as inherited by the Distribution base class
    (e.g. `log_prob`).

    Args:
      classname: The name of the subclass being created.
      baseclasses: A tuple of parent classes.
      attrs: A dict mapping new attributes to their values.

    Returns:
      The class object.

    Raises:
      TypeError: If `baseclasses` is empty (i.e. `Distribution` does not
        subclass `_BaseDistribution`).
      AttributeError: If the base class does not implement one of the public
        methods listed in `_DISTRIBUTION_PUBLIC_METHOD_WRAPPERS`.
      ValueError: If a base-class public method lacks a docstring while the
        subclass supplies a docstring for the private counterpart.
    """
    if not baseclasses:  # Nothing to be done for Distribution
      raise TypeError('Expected non-empty baseclass. Does Distribution '
                      'not subclass _BaseDistribution?')
    # `base == _BaseDistribution` is tested first so that `Distribution` (not
    # yet bound while it is itself being created) is never looked up.
    which_base = [
        base for base in baseclasses
        if base == _BaseDistribution or issubclass(base, Distribution)]
    base = which_base[0] if which_base else None
    if base is None or base == _BaseDistribution:
      # Nothing to be done for Distribution or unrelated subclass.
      return super(_DistributionMeta, mcs).__new__(
          mcs, classname, baseclasses, attrs)
    if not issubclass(base, Distribution):
      raise TypeError('First parent class declared for {} must be '
                      'Distribution, but saw "{}"'.format(
                          classname, base.__name__))
    for attr, special_attr in _DISTRIBUTION_PUBLIC_METHOD_WRAPPERS.items():
      if attr in attrs:
        # The method is being overridden, do not update its docstring.
        continue
      class_attr_value = attrs.get(attr, None)
      base_attr_value = getattr(base, attr, None)
      if not base_attr_value:
        raise AttributeError(
            'Internal error: expected base class "{}" to '
            'implement method "{}"'.format(base.__name__, attr))
      class_special_attr_value = attrs.get(special_attr, None)
      class_special_attr_docstring = (
          None if class_special_attr_value is None else
          tf_inspect.getdoc(class_special_attr_value))
      if (class_special_attr_docstring or
          attr in _ALWAYS_COPY_PUBLIC_METHOD_WRAPPERS):
        # Copy the base wrapper so that editing `__doc__` below does not
        # mutate the function shared with the base class.
        class_attr_value = _copy_fn(base_attr_value)
        attrs[attr] = class_attr_value
      if not class_special_attr_docstring:
        # No docstring to append.
        continue
      class_attr_docstring = tf_inspect.getdoc(base_attr_value)
      if class_attr_docstring is None:
        raise ValueError(
            'Expected base class fn to contain a docstring: {}.{}'.format(
                base.__name__, attr))
      class_attr_value.__doc__ = _update_docstring(
          class_attr_value.__doc__,
          'Additional documentation from `{}`:\n\n{}'.format(
              classname, class_special_attr_docstring))

    # Now we'll intercept the default __init__ if it exists.
    default_init = attrs.get('__init__', None)
    if default_init is None:
      # The class has no __init__ because its abstract. (And we won't add one.)
      return super(_DistributionMeta, mcs).__new__(
          mcs, classname, baseclasses, attrs)

    # Warn when a subclass inherits `_parameter_properties` from its parent
    # (this is unsafe, since the subclass will in general have different
    # parameters). Exceptions are:
    #  - Subclasses that don't define their own `__init__` (handled above by
    #    the short-circuit when `default_init is None`).
    #  - Subclasses that define a passthrough `__init__(self, *args, **kwargs)`.
    # pylint: disable=protected-access
    init_argspec = tf_inspect.getfullargspec(default_init)
    if ('_parameter_properties' not in attrs
        # Passthrough exception: may only take `self` and at least one of
        # `*args` and `**kwargs`.
        and (len(init_argspec.args) > 1
             or not (init_argspec.varargs or init_argspec.varkw))):

      @functools.wraps(base._parameter_properties)
      def wrapped_properties(*args, **kwargs):  # pylint: disable=missing-docstring
        """Wrapper to warn if `parameter_properties` is inherited."""
        properties = base._parameter_properties(*args, **kwargs)
        # Warn *after* calling the base method, so that we don't bother warning
        # if it just raised NotImplementedError anyway.
        logging.warning("""
Distribution subclass %s inherits `_parameter_properties` from its parent (%s)
while also redefining `__init__`. The inherited annotations cover the following
parameters: %s. It is likely that these do not match the subclass parameters.
This may lead to errors when computing batch shapes, slicing into batch
dimensions, calling `.copy()`, flattening the distribution as a CompositeTensor
(e.g., when it is passed or returned from a `tf.function`), and possibly other
cases. The recommended pattern for distribution subclasses is to define a new
`_parameter_properties` method with the subclass parameters, and to store the
corresponding parameter values as `self._parameters` in `__init__`, after
calling the superclass constructor:
```
class MySubclass(tfd.SomeDistribution):
def __init__(self, param_a, param_b):
parameters = dict(locals())
# ... do subclass initialization ...
super(MySubclass, self).__init__(**base_class_params)
# Ensure that the subclass (not base class) parameters are stored.
self._parameters = parameters
def _parameter_properties(self, dtype, num_classes=None):
return dict(
# Annotations may optionally specify properties, such as `event_ndims`,
# `default_constraining_bijector_fn`, `specifies_shape`, etc.; see
# the `ParameterProperties` documentation for details.
param_a=tfp.util.ParameterProperties(),
param_b=tfp.util.ParameterProperties())
```
""", classname, base.__name__, str(properties.keys()))
        return properties

      attrs['_parameter_properties'] = wrapped_properties

    # For a comparison of different methods for wrapping functions, see:
    # https://hynek.me/articles/decorators/
    @decorator.decorator
    def wrapped_init(wrapped, self_, *args, **kwargs):
      """A 'top-level `__init__`' which is always called."""
      # We can't use `wrapped` because it results in a self reference which
      # confounds `tf.function`.
      del wrapped
      # Note: if we ever want to have things set in `self` before `__init__` is
      # called, here is the place to do it.
      self_._parameters = None
      default_init(self_, *args, **kwargs)
      # Note: if we ever want to override things set in `self` by subclass
      # `__init__`, here is the place to do it.
      if self_._parameters is None:
        # We prefer subclasses will set `parameters = dict(locals())` because
        # this has nearly zero overhead. However, failing to do this, we will
        # resolve the input arguments dynamically and only when needed.
        # The stand-in for `self` must have a unique identity, because entries
        # are removed by identity below. (`tuple()` is the shared empty-tuple
        # singleton in CPython. Using it would also drop every argument whose
        # value is `()`.)
        dummy_self = object()
        self_._parameters = self_._no_dependency(lambda: (  # pylint: disable=g-long-lambda
            _remove_dict_keys_with_value(
                inspect.getcallargs(default_init, dummy_self, *args, **kwargs),
                dummy_self)))
      elif hasattr(self_._parameters, 'pop'):
        # Dict-like parameters (e.g. from `dict(locals())`): drop the entry
        # that refers back to the instance itself.
        self_._parameters = self_._no_dependency(
            _remove_dict_keys_with_value(self_._parameters, self_))
    # pylint: enable=protected-access

    attrs['__init__'] = wrapped_init(default_init)  # pylint: disable=no-value-for-parameter,assignment-from-no-return
    return super(_DistributionMeta, mcs).__new__(
        mcs, classname, baseclasses, attrs)
@six.add_metaclass(_DistributionMeta)
class Distribution(_BaseDistribution):
  """A generic probability distribution base class.

  `Distribution` is a base class for constructing and organizing properties
  (e.g., mean, variance) of random variables (e.g., Bernoulli, Gaussian).

  #### Subclassing

  Subclasses are expected to implement a leading-underscore version of the
  same-named function. The argument signature should be identical except for
  the omission of `name='...'`. For example, to enable `log_prob(value,
  name='log_prob')` a subclass should implement `_log_prob(value)`.

  Subclasses can append to public-level docstrings by providing
  docstrings for their method specializations. For example:

  ```python
  @distribution_util.AppendDocstring('Some other details.')
  def _log_prob(self, value):
    ...
  ```

  would add the string "Some other details." to the `log_prob` function
  docstring. This is implemented as a simple decorator to keep the Python
  linter from complaining about missing Args/Returns/Raises sections in the
  partial docstrings.

  TFP methods generally assume that Distribution subclasses implement at least
  the following methods:

  - `_sample_n`.
  - `_log_prob` or `_prob`.
  - `_event_shape` and `_event_shape_tensor`.
  - `_parameter_properties` OR `_batch_shape` and `_batch_shape_tensor`.

  Batch shape methods can be automatically derived from `parameter_properties`
  in most cases, so it's usually not necessary to implement them directly.
  Exceptions include Distributions that accept non-Tensor parameters (for
  example, a distribution parameterized by a callable), or that have nonstandard
  batch semantics (for example, `BatchReshape`).

  Some functionality may depend on implementing additional methods. It is common
  for Distribution subclasses to implement:

  - Relevant statistics, such as `_mean`, `_mode`, `_variance` and/or `_stddev`.
  - At least one of `_log_cdf`, `_cdf`, `_survival_function`, or
    `_log_survival_function`.
  - `_quantile`.
  - `_entropy`.
  - `_default_event_space_bijector`.
  - `_parameter_properties` (to support automatic batch shape derivation,
    batch slicing and other features).
  - `_sample_and_log_prob`.
  - `_maximum_likelihood_parameters`.

  Note that subclasses of existing Distributions that redefine `__init__` do
  *not* automatically inherit
  `_parameter_properties` annotations from their parent: the subclass must
  explicitly implement its own `_parameter_properties` method to support the
  features, such as batch slicing, that this enables.

  #### Broadcasting, batching, and shapes

  All distributions support batches of independent distributions of that type.
  The batch shape is determined by broadcasting together the parameters.

  The shape of arguments to `__init__`, `cdf`, `log_cdf`, `prob`, and
  `log_prob` reflect this broadcasting, as does the return value of `sample`.

  `sample_n_shape = [n] + batch_shape + event_shape`, where `sample_n_shape` is
  the shape of the `Tensor` returned from `sample(n)`, `n` is the number of
  samples, `batch_shape` defines how many independent distributions there are,
  and `event_shape` defines the shape of samples from each of those independent
  distributions. Samples are independent along the `batch_shape` dimensions, but
  not necessarily so along the `event_shape` dimensions (depending on the
  particulars of the underlying distribution).

  Using the `Uniform` distribution as an example:

  ```python
  minval = 3.0
  maxval = [[4.0, 6.0],
            [10.0, 12.0]]

  # Broadcasting:
  # This instance represents 4 Uniform distributions. Each has a lower bound at
  # 3.0 as the `minval` parameter was broadcasted to match `maxval`'s shape.
  u = Uniform(minval, maxval)

  # `event_shape` is `TensorShape([])`.
  event_shape = u.event_shape
  # `event_shape_t` is a `Tensor` which will evaluate to [].
  event_shape_t = u.event_shape_tensor()

  # Sampling returns a sample per distribution. `samples` has shape
  # [5, 2, 2], which is [n] + batch_shape + event_shape, where n=5,
  # batch_shape=[2, 2], and event_shape=[].
  samples = u.sample(5)

  # The broadcasting holds across methods. Here we use `cdf` as an example. The
  # same holds for `log_cdf` and the likelihood functions.

  # `cum_prob_broadcast` has shape [2, 2] as the `value` argument was
  # broadcasted to the shape of the `Uniform` instance.
  cum_prob_broadcast = u.cdf(4.0)

  # `cum_prob_per_dist`'s shape is [2, 2], one per distribution. No
  # broadcasting occurred.
  cum_prob_per_dist = u.cdf([[4.0, 5.0],
                             [6.0, 7.0]])

  # INVALID as the `value` argument is not broadcastable to the distribution's
  # shape.
  cum_prob_invalid = u.cdf([4.0, 5.0, 6.0])
  ```

  #### Shapes

  There are three important concepts associated with TensorFlow Distributions
  shapes:

  - Event shape describes the shape of a single draw from the distribution;
    it may be dependent across dimensions. For scalar distributions, the event
    shape is `[]`. For a 5-dimensional MultivariateNormal, the event shape is
    `[5]`.
  - Batch shape describes independent, not identically distributed draws, aka a
    "collection" or "bunch" of distributions.
  - Sample shape describes independent, identically distributed draws of batches
    from the distribution family.

  The event shape and the batch shape are properties of a Distribution object,
  whereas the sample shape is associated with a specific call to `sample` or
  `log_prob`.

  For detailed usage examples of TensorFlow Distributions shapes, see
  [this tutorial](
  https://github.com/tensorflow/probability/blob/main/tensorflow_probability/examples/jupyter_notebooks/Understanding_TensorFlow_Distributions_Shapes.ipynb)

  #### Parameter values leading to undefined statistics or distributions.

  Some distributions do not have well-defined statistics for all initialization
  parameter values. For example, the beta distribution is parameterized by
  positive real numbers `concentration1` and `concentration0`, and does not have
  a well-defined mode if `concentration1 < 1` or `concentration0 < 1`.

  The user is given the option of raising an exception or returning `NaN`.

  ```python
  a = tf.exp(tf.matmul(logits, weights_a))
  b = tf.exp(tf.matmul(logits, weights_b))

  # Will raise exception if ANY batch member has a < 1 or b < 1.
  dist = tfd.Beta(a, b, allow_nan_stats=False)
  mode = dist.mode()

  # Will return NaN for batch members with either a < 1 or b < 1.
  dist = tfd.Beta(a, b, allow_nan_stats=True)  # Default behavior
  mode = dist.mode()
  ```

  When `validate_args=True`, an exception is raised if *invalid* parameters are
  passed, e.g.

  ```python
  # Will raise an exception if any Op is run.
  negative_a = -1.0 * a  # beta distribution by definition has a > 0.
  dist = tfd.Beta(negative_a, b, allow_nan_stats=True, validate_args=True)
  dist.mean()
  ```

  With `validate_args=False`, invalid parameters may silently produce incorrect
  outputs instead.
  """
def __init__(self,
             dtype,
             reparameterization_type,
             validate_args,
             allow_nan_stats,
             parameters=None,
             graph_parents=None,
             name=None):
  """Constructs the `Distribution`.

  **This is a private method for subclass use.**

  Args:
    dtype: The type of the event samples. `None` implies no type-enforcement.
    reparameterization_type: Instance of `ReparameterizationType`.
      `tfd.FULLY_REPARAMETERIZED` means samples are fully reparameterized and
      straight-through gradients are supported. `tfd.NOT_REPARAMETERIZED`
      means such gradients are partially or entirely unsupported.
    validate_args: Python `bool`. When `True`, distribution parameters are
      checked for validity, at a possible runtime cost. When `False`, invalid
      inputs may silently render incorrect outputs.
    allow_nan_stats: Python `bool`. When `True`, statistics (e.g., mean, mode,
      variance) use "`NaN`" to indicate an undefined result. When `False`, an
      exception is raised if one or more batch members of a statistic are
      undefined.
    parameters: Python `dict` of parameters used to instantiate this
      `Distribution`.
    graph_parents: Python `list` of graph prerequisites of this
      `Distribution`.
    name: Python `str` name prefixed to Ops created by this class. Defaults to
      the subclass name, converted to lower snake case.

  Raises:
    ValueError: if any member of graph_parents is `None` or not a `Tensor`.
  """
  name = name_util.camel_to_lower_snake(name or type(self).__name__)
  constructor_name_scope = name_util.get_name_scope_name(name)
  if '/' in constructor_name_scope:
    # Use the (locally unique) last component of the scope as the name.
    name = constructor_name_scope.split('/')[-2]
  name = name_util.strip_invalid_chars(name)
  super(Distribution, self).__init__(name=name)
  self._constructor_name_scope = constructor_name_scope
  self._name = name

  if graph_parents is None:
    graph_parents = []
  for idx, parent in enumerate(graph_parents):
    if parent is None or not tf.is_tensor(parent):
      raise ValueError(
          'Graph parent item %d is not a Tensor; %s.' % (idx, parent))

  self._dtype = self._no_dependency(dtype)
  self._reparameterization_type = reparameterization_type
  self._allow_nan_stats = allow_nan_stats
  self._validate_args = validate_args
  self._parameters = self._no_dependency(parameters)
  self._parameters_sanitized = False
  self._graph_parents = graph_parents

  # Inside a deferred-assertion context no init-time checks are built.
  self._defer_all_assertions = (
      auto_composite_tensor.is_deferred_assertion_context())
  if self._defer_all_assertions:
    self._initial_parameter_control_dependencies = ()
  else:
    self._initial_parameter_control_dependencies = tuple(
        dep for dep in self._parameter_control_dependencies(is_init=True)
        if dep is not None)
  # Collapse any surviving assertions into a single grouped op.
  if self._initial_parameter_control_dependencies:
    self._initial_parameter_control_dependencies = (
        tf.group(*self._initial_parameter_control_dependencies),)
@property
def _composite_tensor_params(self):
  """Names of all parameters expected to be `Tensor`s.

  CompositeTensor must separate the dynamic (tensor) parts of a distribution
  from static metadata such as `validate_args`. This returns the non-shape
  tensor parameter names followed by the shape-related ones.
  """
  nonshape_names = self._composite_tensor_nonshape_params
  shape_names = self._composite_tensor_shape_params
  return nonshape_names + shape_names
@property
def _composite_tensor_nonshape_params(self):
  """Names of tensor parameters whose properties do not `specifies_shape`.

  JAX flattening mirrors CompositeTensor, except that shape-related parameters
  are treated as metadata. This tuple lists the tensor parameters that are
  *not* shape-related, in `parameter_properties()` order.
  """
  names = []
  for param_name, props in self.parameter_properties().items():
    if not props.specifies_shape:
      names.append(param_name)
  return tuple(names)
@property
def _composite_tensor_shape_params(self):
  """Names of tensor parameters whose properties `specifies_shape`.

  These are collected as tensors by CompositeTensor but treated as metadata
  when flattening for JAX. Names appear in `parameter_properties()` order.
  """
  names = []
  for param_name, props in self.parameter_properties().items():
    if props.specifies_shape:
      names.append(param_name)
  return tuple(names)
@classmethod
def _parameter_properties(cls, dtype, num_classes=None):
  """Subclass hook returning the class's `ParameterProperties` dict.

  The base implementation always raises. Subclasses override it to enable
  the features listed in `parameter_properties`.

  Args:
    dtype: Float `dtype` to assume for continuous-valued parameters.
    num_classes: Optional number of classes for categorical-like families.

  Raises:
    NotImplementedError: always, in the base class.
  """
  # Fixed: the message previously had an unbalanced backtick.
  raise NotImplementedError(
      '`_parameter_properties` is not implemented: {}.'.format(cls.__name__))
@classmethod
def parameter_properties(cls, dtype=tf.float32, num_classes=None):
  """Returns a dict mapping constructor arg names to property annotations.

  The dict should contain an entry for each `Tensor`-valued constructor
  argument of the distribution. Subclasses are not required to implement
  `_parameter_properties`, so this method may raise `NotImplementedError`.

  A `_parameter_properties` implementation enables features including:

  - Distribution batch slicing (`sliced_distribution = distribution[i:j]`).
  - Automatic inference of `_batch_shape` and `_batch_shape_tensor`, which
    must otherwise be computed explicitly.
  - Automatic instantiation of the distribution within TFP's internal
    property tests.
  - Automatic construction of 'trainable' instances of the distribution,
    using appropriate bijectors to respect parameter constraints, so the
    family can serve as a surrogate posterior in variational inference.

  Future versions may use these annotations for more, e.g. returning
  Distribution instances from `tf.vectorized_map`.

  Args:
    dtype: Optional float `dtype` to assume for continuous-valued parameters.
      Some constraining bijectors need it in advance because certain constants
      (e.g., `tfb.Softplus.low`) must share the dtype of the values they
      transform.
    num_classes: Optional `int` `Tensor` number of classes to assume when
      inferring parameter shapes for categorical-like distributions.
      Otherwise ignored.

  Returns:
    parameter_properties: A `dict` mapping constructor argument names (`str`)
      to `tfp.python.internal.parameter_properties.ParameterProperties`
      instances.

  Raises:
    NotImplementedError: if the distribution class does not implement
      `_parameter_properties`.
  """
  with tf.name_scope('parameter_properties'):
    properties = cls._parameter_properties(dtype, num_classes=num_classes)
  return properties
@classmethod
@deprecation.deprecated('2021-03-01',
                        'The `param_shapes` method of `tfd.Distribution` is '
                        'deprecated; use `parameter_properties` instead.')
def param_shapes(cls, sample_shape, name='DistributionParamShapes'):
  """Shapes of parameters given the desired shape of a call to `sample()`.

  Class method describing the shapes the constructor's parameters must have
  so that `sample()` on the resulting instance returns `sample_shape`. The
  shapes come from each parameter's `shape_fn` in `parameter_properties()`.

  Args:
    sample_shape: `Tensor` or python list/tuple. Desired shape of a call to
      `sample()`.
    name: name to prepend ops with.

  Returns:
    `dict` of parameter name to `int32` `Tensor` shapes.
  """
  with tf.name_scope(name):
    return {
        param_name: tf.convert_to_tensor(
            props.shape_fn(sample_shape), dtype=tf.int32)
        for param_name, props in cls.parameter_properties().items()
    }
@classmethod
@deprecation.deprecated(
    '2021-03-01', 'The `param_static_shapes` method of `tfd.Distribution` is '
    'deprecated; use `parameter_properties` instead.')
def param_static_shapes(cls, sample_shape):
  """param_shapes with static (i.e. `TensorShape`) shapes.

  Class method describing the static parameter shapes needed so that
  `sample()` returns `sample_shape`. The shapes come from `param_shapes`, and
  each one must evaluate to a constant.

  Args:
    sample_shape: `TensorShape` or python list/tuple. Desired shape of a call
      to `sample()`.

  Returns:
    `dict` of parameter name to `TensorShape`.

  Raises:
    ValueError: if `sample_shape` is a `TensorShape` and is not fully defined,
      or if any resulting parameter shape is not statically known.
  """
  if isinstance(sample_shape, tf.TensorShape):
    if not tensorshape_util.is_fully_defined(sample_shape):
      raise ValueError('TensorShape sample_shape must be fully defined')
    sample_shape = tensorshape_util.as_list(sample_shape)

  static_params = {}
  for param_name, dynamic_shape in cls.param_shapes(sample_shape).items():
    static_value = tf.get_static_value(dynamic_shape)
    if static_value is None:
      raise ValueError(
          'sample_shape must be a fully-defined TensorShape or list/tuple')
    static_params[param_name] = tf.TensorShape(static_value)
  return static_params
@property
def name(self):
    """Name prepended to all ops created by this `Distribution`.

    Returns `None` when no `_name` attribute has been set on the instance.
    """
    return getattr(self, '_name', None)
@property
def dtype(self):
    """The `DType` of `Tensor`s handled by this `Distribution`.

    Returns `None` when no `_dtype` attribute has been set on the instance.
    """
    return getattr(self, '_dtype', None)
@property
def parameters(self):
    """Dictionary of parameters used to instantiate this `Distribution`.

    On first access the stored parameters are sanitized. Entries whose key
    starts with '__' (e.g. '__class__') or whose value is the distribution
    itself (e.g. 'self') are dropped. Such entries show up when a subclass
    builds its parameters with `dict(locals())`. The sanitized mapping is
    cached on the instance.

    Returns:
      A fresh `dict` copy of the (sanitized) constructor parameters.
    """
    if not getattr(self, '_parameters_sanitized', False):
        raw = self._parameters
        if callable(raw):
            # Parameters may have been stored as a deferred thunk.
            raw = raw()
        kept = {}
        for key, value in raw.items():
            if key.startswith('__') or value is self:
                continue
            kept[key] = value
        self._parameters = self._no_dependency(kept)
        self._parameters_sanitized = True

    # The parameter evaluation may still be deferred here; force it now.
    current = self._parameters
    if callable(current):
        current = current()
    return dict(current)
def _params_event_ndims(self):
    """Returns a dict mapping constructor argument names to per-event rank.

    The ranks come from `type(self).parameter_properties()`, which makes this
    a convenience wrapper. Each rank is evaluated against this instance via
    `ParameterProperties.instance_event_ndims(self)`. A parameter is included
    only when it is tensor-valued (`is_tensor`) and its rank is neither
    `NO_EVENT_NDIMS` nor `None`.

    Returns:
      params_event_ndims: Per-event parameter ranks, a `str->int dict`.

    Raises:
      NotImplementedError: if the distribution class does not implement
        `_parameter_properties` (so batch slicing is unsupported). The original
        error is attached as the cause.
    """
    try:
        properties = type(self).parameter_properties()
    except NotImplementedError as err:
        # Chain explicitly so the underlying failure stays visible as the cause.
        raise NotImplementedError(
            '{} does not support batch slicing; must implement '
            '_parameter_properties.'.format(type(self))) from err
    params_event_ndims = {}
    from tensorflow_probability.python.internal import parameter_properties  # pylint: disable=g-import-not-at-top
    for (k, param) in properties.items():
        ndims = param.instance_event_ndims(self)
        if param.is_tensor and (
                ndims is not parameter_properties.NO_EVENT_NDIMS and
                ndims is not None):
            params_event_ndims[k] = ndims
    return params_event_ndims
def __getitem__(self, slices):
    """Slices the batch axes of this distribution into a new instance.

    Only batch dimensions are indexed; event dimensions are untouched.

    ```python
    b = tfd.Bernoulli(logits=tf.zeros([3, 5, 7, 9]))
    b.batch_shape  # => [3, 5, 7, 9]
    b2 = b[:, tf.newaxis, ..., -2:, 1::2]
    b2.batch_shape  # => [3, 1, 5, 2, 4]

    x = tf.random.normal([5, 3, 2, 2])
    cov = tf.matmul(x, x, transpose_b=True)
    chol = tf.linalg.cholesky(cov)
    loc = tf.random.normal([4, 1, 3, 1])
    mvn = tfd.MultivariateNormalTriL(loc, chol)
    mvn.batch_shape  # => [4, 5, 3]
    mvn.event_shape  # => [2]
    mvn2 = mvn[:, 3:, ..., ::-1, tf.newaxis]
    mvn2.batch_shape  # => [4, 2, 3, 1]
    mvn2.event_shape  # => [2]
    ```

    Args:
      slices: the index expression passed to the `[]` operator.

    Returns:
      dist: A new `tfd.Distribution` instance with sliced parameters.
    """
    # Plain slicing never overrides any constructor parameters.
    no_overrides = {}
    return slicing.batch_slice(self, no_overrides, slices)
def __iter__(self):
    # Explicitly reject iteration. A defined `__getitem__` would otherwise make
    # Python fall back to the legacy sequence-iteration protocol.
    class_name = type(self).__name__
    raise TypeError('{!r} object is not iterable'.format(class_name))
@property
def reparameterization_type(self):
    """How samples drawn from this distribution are reparameterized.

    The value is currently one of the static instances
    `tfd.FULLY_REPARAMETERIZED` or `tfd.NOT_REPARAMETERIZED`.

    Returns:
      An instance of `ReparameterizationType`.
    """
    reparam = self._reparameterization_type
    return reparam
@property
def allow_nan_stats(self):
    """Python `bool` describing behavior when a stat is undefined.

    Some statistics are legitimately infinite and are reported as +/- infinity
    where that makes sense. For example, the variance of a Cauchy distribution
    is infinite. Other statistics are genuinely undefined:
    - The mode is undefined when a distribution's pdf does not achieve a
      maximum within its support.
    - The variance E[(X - mean)**2] is undefined whenever the mean is
      undefined. An example is Student's T with df = 1, whose mean cannot be
      called either + or - infinity.
    This flag selects how such undefined statistics are handled.

    Returns:
      allow_nan_stats: Python `bool`.
    """
    flag = self._allow_nan_stats
    return flag
@property
def validate_args(self):
    """Python `bool`: whether possibly expensive argument checks are enabled."""
    flag = self._validate_args
    return flag
@property
def experimental_shard_axis_names(self):
    """The list or structure of lists of active shard axis names.

    The base implementation has no active shard axes and returns a new empty
    list on each access.
    """
    return list()
def copy(self, **override_parameters_kwargs):
    """Creates a new instance of this distribution from its parameters.

    Note: the copy distribution may continue to depend on the original
    initialization arguments (parameter values are reused, not duplicated).

    Args:
      **override_parameters_kwargs: String/value dictionary of initialization
        arguments to override with new values.

    Returns:
      distribution: A new instance of `type(self)` initialized from the union
        of self.parameters and override_parameters_kwargs, i.e.,
        `dict(self.parameters, **override_parameters_kwargs)`.
    """
    # When the distribution supports slicing, copy via an identity
    # (`Ellipsis`) batch slice so that provenance from the origin variables is
    # tracked; see the comment on PROVENANCE_ATTR in the `slicing` module.
    try:
        sliced = slicing.batch_slice(self, override_parameters_kwargs, Ellipsis)
    except NotImplementedError:
        pass  # Slicing is unsupported; construct the copy directly below.
    else:
        return sliced

    merged = dict(self.parameters, **override_parameters_kwargs)
    new_dist = type(self)(**merged)
    # Store the merged dict and mark it sanitized so the `parameters`
    # property does not re-filter it on the copy.
    # pylint: disable=protected-access
    new_dist._parameters = self._no_dependency(merged)
    new_dist._parameters_sanitized = True
    # pylint: enable=protected-access
    return new_dist
def _broadcast_parameters_with_batch_shape(self, batch_shape):
    """Broadcasts each parameter's batch shape with the given `batch_shape`.

    Semantically this matches wrapping with the `BatchBroadcast` distribution.
    Unlike that wrapper, the result has the same type as the original, and
    every parameter Tensor is reified at the broadcast batch shape.

    It acts as a pseudo-inverse of batch slicing:

    ```python
    dist = tfd.Normal(0., 1.)
    # ==> `dist.batch_shape == []`
    broadcast_dist = dist._broadcast_parameters_with_batch_shape([3])
    # ==> `broadcast_dist.batch_shape == [3]`
    #     `broadcast_dist.loc.shape == [3]`
    #     `broadcast_dist.scale.shape == [3]`
    sliced_dist = broadcast_dist[0]
    # ==> `sliced_dist.batch_shape == []`.
    ```

    Args:
      batch_shape: Integer `Tensor` batch shape.

    Returns:
      broadcast_dist: copy of this distribution in which each parameter's
        batch shape is determined by broadcasting its current batch shape with
        the given `batch_shape`.
    """
    broadcast_params = batch_shape_lib.broadcast_parameters_with_batch_shape(
        self, batch_shape)
    return self.copy(**broadcast_params)
def _batch_shape_tensor(self, **parameter_kwargs):
    """Infers batch shape from parameters.

    The overall batch shape is inferred by broadcasting the batch shapes of
    all parameters, roughly:

    ```python
    parameter_batch_shapes = []
    for name, properties in self.parameter_properties().items():
      parameter = self.parameters[name]
      ndims = properties.instance_event_ndims(self)
      shape = base_shape(parameter)
      parameter_batch_shapes.append(shape[:len(shape) - ndims])
    ```

    Here a parameter's `base_shape` is its batch shape if it defines one (e.g.,
    if it is a Distribution, LinearOperator, etc.), and its Tensor shape
    otherwise. Parameters with structured batch shape (in particular,
    non-autobatched JointDistributions) are not currently supported.

    Args:
      **parameter_kwargs: Optional keyword arguments overriding the parameter
        values in `self.parameters`. Typically this is used to avoid multiple
        Tensor conversions of the same value.

    Returns:
      batch_shape_tensor: `Tensor` broadcast batch shape of all parameters.

    Raises:
      NotImplementedError: if the batch shape cannot be inferred, i.e. the
        subclass implements neither `_batch_shape_tensor` nor
        `_parameter_properties`. The original error is attached as the cause.
    """
    try:
        return batch_shape_lib.inferred_batch_shape_tensor(
            self, **parameter_kwargs)
    except NotImplementedError as err:
        # Chain explicitly so the underlying failure stays visible as the cause.
        raise NotImplementedError('Cannot compute batch shape of distribution '
                                  '{}: you must implement at least one of '
                                  '`_batch_shape_tensor` or '
                                  '`_parameter_properties`.'.format(self)) from err
def batch_shape_tensor(self, name='batch_shape_tensor'):
"""Shape of a single sample from a single event index as a 1-D `Tensor`.
The batch dimensions are indexes into independent, non-identical
parameterizations of this distribution.
Args:
name: name to give to the op