-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
Copy pathspecial.py
2659 lines (2157 loc) · 94.6 KB
/
special.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2020 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Implements special functions in TensorFlow."""
import functools
# Dependency imports
import numpy as np
import tensorflow.compat.v2 as tf
from tensorflow_probability.python.internal import custom_gradient as tfp_custom_gradient
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.math import generic
# Public names exported by `from <this module> import *`.
__all__ = [
    'atan_difference',
    'betainc',
    'betaincinv',
    'dawsn',
    'erfcinv',
    'erfcx',
    'igammainv',
    'igammacinv',
    'round_exponential_bump_function',
    'lambertw',
    'lambertw_winitzki_approx',
    'logerfc',
    'logerfcx',
    'log_gamma_correction',
    'log_gamma_difference',
    'lbeta',
    'owens_t',
]

# Backend switch. In the visible code it only decides whether bfloat16 is
# listed in `_f16bit_dtypes`. NOTE(review): presumably rewritten to True when
# this module is converted for the NumPy backend; confirm.
NUMPY_MODE = False
def atan_difference(x, y, name=None):
  """Difference of arctan(x) and arctan(y).

  Computes arctan(x) - arctan(y) while avoiding catastrophic cancellation. This
  is done by resorting to the identity:

  ```none
  arctan(x) - arctan(y) = arctan((x - y) / (1 + x * y)) +
                          pi * sign(x) * 1_{x * y < -1}
  ```

  where `1_A` is the indicator function on the set `A`. On the set
  `x * y == -1`, where the quotient above is undefined, the exact value
  `pi * sign(x) / 2` is returned.

  For a derivation of this fact, see [1].

  #### References

  [1] De Stefano, Sum of Arctangents
      https://sites.google.com/site/micdestefano/mathematics/trigonometry/sum-of-arctangents

  Args:
    x: Floating-point Tensor. Should be broadcastable with `y`.
    y: Floating-point Tensor. Should be broadcastable with `x`.
    name: Optional Python `str` naming the operation.

  Returns:
    z: Tensor of same shape and dtype as `x` and `y`.
  """
  with tf.name_scope(name or 'atan_difference'):
    dtype = dtype_util.common_dtype([x, y], tf.float32)
    x = tf.convert_to_tensor(x, dtype=dtype)
    y = tf.convert_to_tensor(y, dtype=dtype)

    xy = x * y
    on_singularity = tf.math.equal(xy, -1.)

    # `1 + x * y` is exactly zero precisely when `x * y == -1`. Those entries
    # are replaced by the final `tf.where`. However, dividing by zero in the
    # masked-out branch would still make its gradient `0 / 0 = NaN`, which
    # `tf.where` propagates. So divide by 1 there instead; all other entries
    # are unchanged.
    denominator = tf.where(on_singularity, tf.ones_like(xy), 1. + xy)
    difference = tf.math.atan((x - y) / denominator)
    difference = difference + tf.where(
        xy < -1., np.pi * tf.math.sign(x), 0.)

    # On the singular set the value is exactly pi * sign(x) / 2. The added
    # `naive - stop_gradient(naive)` term is zero in value, but it gives
    # those entries the true derivative of arctan(x) - arctan(y) instead of
    # the zero derivative of a constant.
    naive = tf.math.atan(x) - tf.math.atan(y)
    singular_value = (np.pi * tf.math.sign(x) / 2. +
                      (naive - tf.stop_gradient(naive)))
    return tf.where(on_singularity, singular_value, difference)
# Half-precision (16-bit) floating-point dtypes available on the current
# backend. The betainc/betaincinv helpers below evaluate these in float32 and
# cast the results back. bfloat16 is only listed outside NUMPY_MODE
# (presumably because NumPy has no native bfloat16 type).
if NUMPY_MODE:
  _f16bit_dtypes = [tf.float16]
else:
  _f16bit_dtypes = [tf.bfloat16, tf.float16]
def _betainc_naive(a, b, x):
  """Evaluates the regularized incomplete beta function with `tf.math.betainc`.

  The three inputs are converted to their common dtype and explicitly
  broadcast to a common shape before the kernel call. bfloat16 and float16
  inputs are evaluated in float32, to stay consistent with the XLA
  implementation of betainc, and the result is cast back.

  Args:
    a: Floating-point Tensor, broadcastable with `b` and `x`.
    b: Floating-point Tensor, broadcastable with `a` and `x`.
    x: Floating-point Tensor, broadcastable with `a` and `b`.

  Returns:
    Tensor with the broadcast shape and common dtype of the inputs.
  """
  input_dtype = dtype_util.common_dtype([a, b, x], tf.float32)
  promote = input_dtype in _f16bit_dtypes

  operands = [tf.convert_to_tensor(t, dtype=input_dtype) for t in (a, b, x)]
  if promote:
    operands = [tf.cast(t, tf.float32) for t in operands]

  target_shape = functools.reduce(
      ps.broadcast_shape, [ps.shape(t) for t in operands])
  value = tf.math.betainc(
      *[tf.broadcast_to(t, target_shape) for t in operands])

  # Undo the float32 promotion so callers see their original dtype.
  return tf.cast(value, input_dtype) if promote else value
def _betainc_even_partial_numerator(iteration, a, b, x, dtype):
  """Even partial numerator of the betainc continued fraction, with gradient.

  With `m = iteration`, this is
    d_{2m} = m * (b - m) * x / ((a + 2m) * (a + 2m - 1)),
  as given in https://dlmf.nist.gov/8.17.E23.

  Args:
    iteration: Scalar Tensor holding `m`.
    a, b, x: Tensors of dtype `dtype`.
    dtype: TF floating dtype of the computation.

  Returns:
    value: d_{2m}.
    gradient: `[d/da, d/db]` of d_{2m}, concatenated along the last axis.
  """
  m = iteration
  a_2m = a + tf.constant(2., dtype=dtype) * m
  a_2m_less_1 = a_2m - tf.constant(1., dtype=dtype)
  denom = a_2m * a_2m_less_1

  # d_{2m} is linear in b, with slope m * x / denom.
  slope_b = m * x / denom
  d = slope_b * (b - m)
  # Only the denominator depends on a; its a-derivative is
  # (a + 2m) + (a + 2m - 1).
  slope_a = -d * (a_2m + a_2m_less_1) / denom

  return d, tf.concat([slope_a, slope_b], axis=-1)
def _betainc_odd_partial_numerator(iteration, a, b, x, dtype):
  """Odd partial numerator of the betainc continued fraction, with gradient.

  With `m = iteration`, this is
    d_{2m+1} = -(a + m) * (a + b + m) * x / ((a + 2m) * (a + 2m + 1)),
  as given in https://dlmf.nist.gov/8.17.E23.

  Args:
    iteration: Scalar Tensor holding `m`.
    a, b, x: Tensors of dtype `dtype`.
    dtype: TF floating dtype of the computation.

  Returns:
    value: d_{2m+1}.
    gradient: `[d/da, d/db]` of d_{2m+1}, concatenated along the last axis.
  """
  m = iteration
  a_m = a + m
  a_2m = a_m + m
  a_2m_plus_1 = a_2m + tf.constant(1., dtype=dtype)
  denom = a_2m * a_2m_plus_1

  # d_{2m+1} is linear in b, with slope -(a + m) * x / denom.
  slope_b = -a_m * x / denom
  d = slope_b * (a_m + b)
  # Product rule on the numerator plus quotient rule on the denominator.
  slope_a = -d * ((a_2m + a_2m_plus_1) / denom) - x * (
      tf.constant(2., dtype=dtype) * a_m + b) / denom

  return d, tf.concat([slope_a, slope_b], axis=-1)
def _betainc_modified_lentz_method(a, b, x, dtype, use_continued_fraction):
  """Returns the continued fraction for betainc by modified Lentz's method.

  Evaluates the continued fraction `1 / (1 + d_1 / (1 + d_2 / (1 + ...)))`,
  together with its derivatives with respect to `a` and `b`.
  - The partial numerators `d_n` (n >= 2) come from
    `_betainc_even_partial_numerator` and `_betainc_odd_partial_numerator`;
    loop iteration `m` applies d_{2m} and then d_{2m+1}.
  - `d_1 = -(a + b) * x / (a + 1)` is folded into the initialization.
  - Derivatives are propagated in forward mode alongside the Lentz recurrences.

  Args:
    a: Tensor of dtype `dtype`.
    b: Tensor of dtype `dtype`.
    x: Tensor of dtype `dtype`.
    dtype: TF floating dtype of the computation.
    use_continued_fraction: Boolean Tensor. Entries that are False start out
      stopped. NOTE(review): stopped and converged entries are not masked;
      they keep being updated while any other entry still iterates.
      `_betainc_partials` discards the non-selected entries afterwards.

  Returns:
    convergent: The value of the continued fraction.
    convergent_grad_a: Its derivative with respect to `a`.
    convergent_grad_b: Its derivative with respect to `b`.
  """
  # This function implements the method described in the appendix of [1] for
  # evaluating continued fractions.
  # [1] Thompson, Ian J., and A. Ross Barnett.
  #     Coulomb and Bessel functions of complex arguments and order.
  #     Journal of Computational Physics 64.2 (1986): 490-509.
  #     https://www.fresco.org.uk/papers/Thompson-JCP64p490.pdf
  numpy_dtype = dtype_util.as_numpy_dtype(dtype)
  one = tf.constant(1., dtype=dtype)
  eps = tf.constant(np.finfo(numpy_dtype).eps, dtype=dtype)
  tiny = tf.constant(np.finfo(numpy_dtype).tiny, dtype=dtype)

  # max_iterations and tolerance were taken from Cephes.
  if numpy_dtype == np.float32:
    max_iterations = 100
    tolerance = eps
  else:
    max_iterations = 300
    tolerance = tf.constant(3., dtype=dtype) * eps

  # Lentz's guard value: ratios whose magnitude falls below it are replaced by
  # it, which avoids division by zero.
  small = tf.sqrt(tiny)

  def continued_fraction_step(
      iteration,
      values,
      gradients,
      partial_numerator_fn):
    # One Lentz update for a single partial numerator.
    # NOTE: `a`, `b` and `x` are closure variables. They are looked up when
    # the step runs, i.e. after they have been rebound below with an extra
    # trailing axis.
    ratio_numerators, ratio_denominators, convergent = values
    dratio_numerators, dratio_denominators, dconvergent = gradients
    partial_numerator, dpartial_numerator = partial_numerator_fn(
        iteration, a, b, x, dtype)

    # new_ratio_numerators = C_n = A_n / A_{n - 1}
    new_ratio_numerators = one + partial_numerator / ratio_numerators
    new_ratio_numerators = tf.where(
        tf.abs(new_ratio_numerators) < small, small, new_ratio_numerators)
    # new_ratio_denominators = D_n = B_{n - 1} / B_n
    new_ratio_denominators = one + partial_numerator * ratio_denominators
    new_ratio_denominators = tf.where(
        tf.abs(new_ratio_denominators) < small, small, new_ratio_denominators)
    new_ratio_denominators = tf.math.reciprocal(new_ratio_denominators)
    # new_convergent = h_n = A_n / B_n = h_{n - 1} * C_n * D_n
    delta = new_ratio_numerators * new_ratio_denominators
    new_convergent = convergent * delta

    # Derivatives of C_n = 1 + d / C_{n-1} (quotient rule).
    new_dratio_numerators = (dpartial_numerator * ratio_numerators -
                             partial_numerator * dratio_numerators)
    new_dratio_numerators = new_dratio_numerators / tf.math.square(
        ratio_numerators)
    # Derivatives of D_n = 1 / (1 + d * D_{n-1}): -(d(1 + d D_{n-1})) * D_n^2.
    new_dratio_denominators = (dpartial_numerator * ratio_denominators +
                               partial_numerator * dratio_denominators)
    new_dratio_denominators = -new_dratio_denominators * tf.math.square(
        new_ratio_denominators)

    # Product rule on h_n = h_{n-1} * C_n * D_n.
    new_dconvergent = dconvergent * delta + (
        convergent * new_dratio_numerators * new_ratio_denominators)
    new_dconvergent = new_dconvergent + (
        convergent * new_dratio_denominators * new_ratio_numerators)

    new_values = (new_ratio_numerators, new_ratio_denominators, new_convergent)
    new_gradients = (
        new_dratio_numerators, new_dratio_denominators, new_dconvergent)

    return new_values, new_gradients, delta

  def continued_fraction_evaluation(should_stop, iteration, values, gradients):
    # We run two steps of modified Lentz's method per iteration.
    # First step of the iteration: the even one.
    new_values, new_gradients, _ = continued_fraction_step(
        iteration, values, gradients, _betainc_even_partial_numerator)

    # Second step of the iteration: the odd one.
    new_values, new_gradients, delta = continued_fraction_step(
        iteration, new_values, new_gradients, _betainc_odd_partial_numerator)

    # Converged once the last multiplicative update is ~1.
    should_stop = should_stop | (tf.math.abs(delta - one) < tolerance)

    return should_stop, iteration + one, new_values, new_gradients

  # Assume all input Tensors have the same shape. The extra dimension is
  # needed to compute the gradients with respect to a and b.
  a, b, x, use_continued_fraction = [
      z[..., tf.newaxis] for z in [a, b, x, use_continued_fraction]]

  apb = a + b
  ap1 = a + one

  # Initialization and first step of modified Lentz's method.
  # Here 1 - apb * x / ap1 == 1 + d_1.
  initial_ratio_numerators = tf.ones_like(x)
  initial_ratio_denominators = one - apb * x / ap1
  initial_ratio_denominators = tf.where(
      tf.abs(initial_ratio_denominators) < small,
      small,
      initial_ratio_denominators)
  initial_ratio_denominators = tf.math.reciprocal(initial_ratio_denominators)
  initial_convergent = initial_ratio_denominators
  initial_values = (
      initial_ratio_numerators, initial_ratio_denominators, initial_convergent)

  # Closed-form [d/da, d/db] of 1 / (1 - (a + b) * x / (a + 1)).
  initial_dratio_denominators = (tf.concat([one - b, ap1], axis=-1) * x /
                                 tf.math.square(x * apb - ap1))
  initial_dratio_numerators = tf.zeros_like(initial_dratio_denominators)
  initial_dconvergent = initial_dratio_denominators
  initial_gradients = (
      initial_dratio_numerators,
      initial_dratio_denominators,
      initial_dconvergent)

  (_, _, values, gradients) = tf.while_loop(
      cond=lambda stop, *_: tf.reduce_any(~stop),
      body=continued_fraction_evaluation,
      loop_vars=(
          ~use_continued_fraction,
          tf.constant(1., dtype=dtype),
          initial_values,
          initial_gradients),
      maximum_iterations=max_iterations)

  # Remove the previously added extra dimension: it is no longer needed.
  convergent = tf.squeeze(values[-1], axis=-1)
  convergent_grad_a, convergent_grad_b = tf.unstack(gradients[-1], axis=-1)

  return convergent, convergent_grad_a, convergent_grad_b
def _betainc_der_continued_fraction(a, b, x, dtype, use_continued_fraction):
  """Partial derivatives of betainc w.r.t. `a` and `b` via a continued fraction.

  Uses the expansion https://dlmf.nist.gov/8.17.E22,
    betainc(a, b, x) = x**a (1 - x)**b / (a B(a, b)) * CF(a, b, x).
  It is used for inputs outside the region handled by
  `_betainc_der_power_series`. The expansion converges quickly for
  x < (a - 1) / (a + b - 2). Elsewhere the reflection
  betainc(a, b, x) = 1 - betainc(b, a, 1 - x) (https://dlmf.nist.gov/8.17.E4)
  is applied first.

  Args:
    a: Tensor of dtype `dtype`.
    b: Tensor of dtype `dtype`.
    x: Tensor of dtype `dtype`.
    dtype: TF floating dtype of the computation.
    use_continued_fraction: Boolean Tensor selecting the entries to iterate.

  Returns:
    A tuple `(grad_a, grad_b)`.
  """
  one = tf.constant(1., dtype=dtype)
  two = tf.constant(2., dtype=dtype)

  reflect = (x >= (a - one) / (a + b - two))
  # The right-hand side is evaluated with the original a, b, x.
  a, b, x = (
      tf.where(reflect, b, a),
      tf.where(reflect, a, b),
      tf.where(reflect, one - x, x))

  fraction, dfraction_da, dfraction_db = _betainc_modified_lentz_method(
      a, b, x, dtype, use_continued_fraction)

  # Prefactor x**a (1 - x)**b / (a B(a, b)) and its log-derivatives.
  prefactor = tf.math.exp(
      tf.math.xlogy(a, x) + tf.math.xlog1py(b, -x) -
      tf.math.log(a) - lbeta(a, b))
  psi_sum = tf.math.digamma(a + b)
  dlog_prefactor_da = (tf.math.log(x) - tf.math.reciprocal(a) +
                       psi_sum - tf.math.digamma(a))
  dlog_prefactor_db = tf.math.log1p(-x) + psi_sum - tf.math.digamma(b)

  d_da = prefactor * (dfraction_da + fraction * dlog_prefactor_da)
  d_db = prefactor * (dfraction_db + fraction * dlog_prefactor_db)

  # Under reflection the roles of a and b swap and the sign flips.
  return (tf.where(reflect, -d_db, d_da),
          tf.where(reflect, -d_da, d_db))
def _betainc_der_power_series(a, b, x, dtype, use_power_series):
  """Returns the partial derivatives of betainc with respect to a and b.

  The series below is differentiated term by term. The loop sums
  `sum_{n >= 0} p_n / (a + n)`, where `p_0 = 1` and
  `p_n = p_{n-1} * (n - b) * x / n`, while tracking its a- and b-derivatives.

  Args:
    a: Tensor of dtype `dtype`.
    b: Tensor of dtype `dtype`.
    x: Tensor of dtype `dtype`.
    dtype: TF floating dtype of the computation.
    use_power_series: Boolean Tensor. Entries that are False are evaluated at
      the dummy point a = b = x = 0.5 and start out stopped.

  Returns:
    A tuple `(grad_a, grad_b)`.
  """
  # This function evaluates betainc(a, b, x) by its series representation:
  #   x ** a * 2F1(a, 1 - b; a + 1; x) / (a * B(a, b)) ,
  # where 2F1 is the Gaussian hypergeometric function.
  # We apply this function when the input (a, b, x) satisfies at least one
  # of the following conditions:
  #   C1: (x < a / (a + b)) & (b * x <= 1) & (x <= 0.95)
  #   C2: (x >= a / (a + b)) & (a * (1 - x) <= 1) & (x >= 0.05)
  numpy_dtype = dtype_util.as_numpy_dtype(dtype)
  eps = tf.constant(np.finfo(numpy_dtype).eps, dtype=dtype)
  half = tf.constant(0.5, dtype=dtype)
  one = tf.constant(1., dtype=dtype)

  # Avoid returning NaN or infinity when the input does not satisfy either
  # C1 or C2.
  safe_a = tf.where(use_power_series, a, half)
  safe_b = tf.where(use_power_series, b, half)
  safe_x = tf.where(use_power_series, x, half)

  # When x >= a / (a + b), we must apply the symmetry relation given here:
  # https://dlmf.nist.gov/8.17.E4
  #   betainc(a, b, x) = 1 - betainc(b, a, 1 - x)
  use_symmetry_relation = (safe_x >= safe_a / (safe_a + safe_b))
  safe_a_orig = safe_a
  safe_a = tf.where(use_symmetry_relation, safe_b, safe_a)
  safe_b = tf.where(use_symmetry_relation, safe_a_orig, safe_b)
  safe_x = tf.where(use_symmetry_relation, one - safe_x, safe_x)

  # max_iterations was set by experimentation and tolerance was taken from
  # Cephes.
  max_iterations = 300 if numpy_dtype == np.float32 else 600
  tolerance = eps / safe_a

  # Evaluate the series that defines the following expression:
  #   2F1(a, 1 - b; a + 1; x) / a
  def power_series_evaluation(should_stop, values, gradients):
    # product = p_{n-1} on entry, and product_grad_b = d p_{n-1} / db.
    # da and db are the running derivatives of series_sum.
    n, product, series_sum = values
    product_grad_b, da, db = gradients

    x_div_n = safe_x / n
    factor = (n - safe_b) * x_div_n
    apn = safe_a + n

    new_product = product * factor
    term = new_product / apn
    # d factor / db = -x / n.
    new_product_grad_b = factor * product_grad_b - product * x_div_n
    # d/da of p_n / (a + n) is -p_n / (a + n)^2.
    new_da = da - new_product / tf.math.square(apn)
    new_db = db + new_product_grad_b / apn

    values = n + one, new_product, series_sum + term
    gradients = new_product_grad_b, new_da, new_db

    # NOTE: stopped entries are not masked; they keep accumulating
    # (ever smaller) terms while any other entry still iterates.
    return should_stop | (tf.math.abs(term) <= tolerance), values, gradients

  # n = 0 term: p_0 / a = 1 / a, whose a-derivative is -1 / a^2.
  initial_n = one
  initial_product = tf.ones_like(safe_a)
  initial_series_sum = one / safe_a
  initial_values = (initial_n, initial_product, initial_series_sum)

  initial_product_grad_b = tf.zeros_like(safe_b)
  initial_da = -tf.math.reciprocal(tf.math.square(safe_a))
  initial_db = initial_product_grad_b
  initial_gradients = (initial_product_grad_b, initial_da, initial_db)

  (_, values, gradients) = tf.while_loop(
      cond=lambda stop, *_: tf.reduce_any(~stop),
      body=power_series_evaluation,
      loop_vars=(
          ~use_power_series,
          initial_values,
          initial_gradients),
      maximum_iterations=max_iterations)

  _, _, series_sum = values
  _, series_grad_a, series_grad_b = gradients

  # Prefactor x**a / B(a, b) of the series representation.
  normalization = tf.math.exp(
      tf.math.xlogy(safe_a, safe_x) - lbeta(safe_a, safe_b))

  digamma_apb = tf.math.digamma(safe_a + safe_b)
  grad_a = normalization * (series_grad_a + series_sum * (
      digamma_apb - tf.math.digamma(safe_a) + tf.math.log(safe_x)))
  grad_b = normalization * (series_grad_b + series_sum * (
      digamma_apb - tf.math.digamma(safe_b)))

  # If we are taking advantage of the symmetry relation, then we have to
  # adjust grad_a and grad_b.
  grad_a_orig = grad_a
  grad_a = tf.where(use_symmetry_relation, -grad_b, grad_a)
  grad_b = tf.where(use_symmetry_relation, -grad_a_orig, grad_b)

  return grad_a, grad_b
def _betainc_partials(a, b, x):
  """Computes the gradient of `betainc(a, b, x)` w.r.t. each of its inputs.

  bfloat16/float16 inputs are evaluated in float32, like `_betainc_naive`, and
  the results are cast back.

  Args:
    a: Floating-point Tensor, broadcastable with `b` and `x`.
    b: Floating-point Tensor, broadcastable with `a` and `x`.
    x: Floating-point Tensor, broadcastable with `a` and `b`.

  Returns:
    A tuple `(grad_a, grad_b, grad_x)` with the broadcast shape.
    - `grad_a` and `grad_b` are 0 where `x` is 0 or 1.
    - All three are NaN where `a <= 0`, `b <= 0`, `x < 0` or `x > 1`.
  """
  input_dtype = dtype_util.common_dtype([a, b, x], tf.float32)
  promote = input_dtype in _f16bit_dtypes
  work_dtype = tf.float32 if promote else input_dtype
  np_work_dtype = dtype_util.as_numpy_dtype(work_dtype)
  zero = tf.constant(0., dtype=work_dtype)
  one = tf.constant(1., dtype=work_dtype)

  operands = [tf.convert_to_tensor(t, dtype=input_dtype) for t in (a, b, x)]
  if promote:
    operands = [tf.cast(t, work_dtype) for t in operands]
  target_shape = functools.reduce(
      ps.broadcast_shape, [ps.shape(t) for t in operands])
  a, b, x = [tf.broadcast_to(t, target_shape) for t in operands]

  # d/dx betainc is the beta density x**(a-1) (1-x)**(b-1) / B(a, b); see
  # http://functions.wolfram.com/06.21.20.0001.01
  grad_x = tf.math.exp(
      tf.math.xlogy(a - one, x) + tf.math.xlog1py(b - one, -x) - lbeta(a, b))

  # The a- and b-derivatives are computed in forward mode, using either a
  # power series or a continued fraction depending on (a, b, x).
  mean = a / (a + b)
  series_ok = (
      ((x < mean) & (b * x <= one) & (x <= 0.95)) |
      ((x >= mean) & (a * (one - x) <= one) & (x >= 0.05)))
  series_grads = _betainc_der_power_series(a, b, x, work_dtype, series_ok)
  fraction_grads = _betainc_der_continued_fraction(
      a, b, x, work_dtype, ~series_ok)
  grad_a, grad_b = [
      tf.where(series_ok, from_series, from_fraction)
      for from_series, from_fraction in zip(series_grads, fraction_grads)]

  # Per the code accompanying Boik & Robinson-Cox, "Derivatives of the
  # Incomplete Beta Function" (https://www.jstatsoft.org/article/view/v003i01),
  # grad_a = grad_b = 0 at x == 0 and x == 1.
  at_endpoint = tf.math.equal(x, zero) | tf.math.equal(x, one)
  grad_a, grad_b = [tf.where(at_endpoint, zero, g) for g in (grad_a, grad_b)]

  out_of_domain = (a <= zero) | (b <= zero) | (x < zero) | (x > one)
  nan = np_work_dtype(np.nan)
  grads = [tf.where(out_of_domain, nan, g) for g in (grad_a, grad_b, grad_x)]

  if promote:
    grads = [tf.cast(g, input_dtype) for g in grads]
  return tuple(grads)
def _betainc_fwd(a, b, x):
  """Forward pass of the `betainc` custom gradient.

  Returns the function value and the raw inputs, which `_betainc_bwd`
  receives as `aux`.
  """
  residuals = (a, b, x)
  return _betainc_naive(*residuals), residuals
def _betainc_bwd(aux, g):
  """Reverse-mode rule for `betainc`.

  Scales each partial derivative by the upstream gradient `g` and hands the
  products to `generic.fix_gradient_for_broadcasting` (presumably to reduce
  them back to the input shapes).
  """
  a, b, x = aux
  partials = _betainc_partials(a, b, x)
  return generic.fix_gradient_for_broadcasting(
      [a, b, x], [partial * g for partial in partials])
def _betainc_jvp(primals, tangents):
  """Forward-mode (JVP) rule for `betainc`, as needed by JAX custom derivatives.

  Returns the primal output and the directional derivative
  `pa * da + pb * db + px * dx`.
  """
  a, b, x = primals
  da, db, dx = tangents
  value = _betainc_custom_gradient(a, b, x)
  pa, pb, px = _betainc_partials(a, b, x)
  tangent_out = pa * da + pb * db + px * dx
  return (value, tangent_out)
@tfp_custom_gradient.custom_gradient(
    vjp_fwd=_betainc_fwd,
    vjp_bwd=_betainc_bwd,
    jvp_fn=_betainc_jvp)
def _betainc_custom_gradient(a, b, x):
  """`betainc(a, b, x)` whose derivatives come from `_betainc_partials`.

  The value is `_betainc_naive`. Reverse- and forward-mode derivatives are
  routed through `_betainc_fwd`/`_betainc_bwd` and `_betainc_jvp`.
  """
  value = _betainc_naive(a, b, x)
  return value
def betainc(a, b, x, name=None):
  """Computes the regularized incomplete beta function element-wise.

  Gradients with respect to all three arguments are provided through a custom
  gradient.

  Args:
    a: Floating-point Tensor. Must be broadcastable with `b` and `x`.
    b: Floating-point Tensor. Must be broadcastable with `a` and `x`.
    x: Floating-point Tensor. Must be broadcastable with `a` and `b`.
    name: A name for the operation (optional).

  Returns:
    betainc: Floating-point Tensor, the regularized incomplete beta
      function computed element-wise.
  """
  with tf.name_scope(name or 'betainc'):
    common = dtype_util.common_dtype([a, b, x], tf.float32)
    a, b, x = [tf.convert_to_tensor(t, dtype=common) for t in (a, b, x)]
    return _betainc_custom_gradient(a, b, x)
# The implementation of the inverse of the regularized incomplete beta function
# is based on ideas and equations available in the following references:
# [1] Milton Abramowitz and Irene A. Stegun
#     Handbook of Mathematical Functions with Formulas, Graphs, and
#     Mathematical Tables
#     US Government Printing Office, 1964 (reprinted 1972)
#     https://archive.org/details/AandS-mono600
# [2] William Press, Saul Teukolsky, William Vetterling and Brian Flannery
#     Numerical Recipes: The Art of Scientific Computing
#     Cambridge University Press, 2007 (Third Edition)
#     http://numerical.recipes/book/book.html
# [3] John Maddock, Paul A. Bristow, et al.
#     The Incomplete Beta Function Inverses
#     https://www.boost.org/doc/libs/1_79_0/libs/math/doc/html/special.html
# [4] Stephen L. Moshier
#     Cephes Mathematical Library
#     https://netlib.org/cephes/
def _betaincinv_initial_approx(a, b, y, dtype):
  """Computes an initial approximation for `betaincinv(a, b, y)`.

  One of three approximations is chosen per element:

  * `min(a, b) >= 1`: Abramowitz & Stegun 26.5.22 [1].
  * `min(a, b) < 1 <= max(a, b)`: Numerical Recipes eqs. 6.4.7-6.4.8 [2].
  * `max(a, b) < 1`: the Boost approximation [3].

  All three candidates are computed for every element and `tf.where` selects
  one. The unselected ones may be inf/NaN without affecting the returned value.

  Args:
    a: Floating-point Tensor of dtype `dtype`.
    b: Floating-point Tensor of dtype `dtype`.
    y: Floating-point Tensor of dtype `dtype`; the target value of `betainc`.
    dtype: TF floating dtype, used to pick machine constants.

  Returns:
    Tensor of initial guesses, clipped to `[tiny, 1 - eps]` for `dtype`.
  """
  numpy_dtype = dtype_util.as_numpy_dtype(dtype)
  tiny = np.finfo(numpy_dtype).tiny
  eps = np.finfo(numpy_dtype).eps
  one = numpy_dtype(1.)
  two = numpy_dtype(2.)
  three = numpy_dtype(3.)
  five = numpy_dtype(5.)
  six = numpy_dtype(6.)
  # (maxexp - 1) * log(2): exp() of anything at or below this stays finite.
  max_log = numpy_dtype((np.finfo(numpy_dtype).maxexp - 1.) * np.log(2.))

  # When min(a, b) >= 1, we use the approximation proposed by [1].
  # Equation 26.5.22 [1, page 945].
  yp = -tf.math.ndtri(y)
  inv_2a_minus_one = tf.math.reciprocal(two * a - one)
  inv_2b_minus_one = tf.math.reciprocal(two * b - one)
  lmb = (tf.math.square(yp) - three) / six
  h = two * tf.math.reciprocal(inv_2a_minus_one + inv_2b_minus_one)
  w = (yp * tf.math.sqrt(h + lmb) / h -
       (inv_2b_minus_one - inv_2a_minus_one) *
       (lmb + five / six - two / (three * h)))
  result_for_large_a_and_b = a / (a + b * tf.math.exp(two * w))

  # When min(a, b) < 1 and max(a, b) >= 1, we use the approximation proposed by
  # [2]. This approximation depends on the following approximation for betainc:
  #   betainc(a, b, x) ~=
  #       x ** a / (integral_approx * a) , when x <= mean ,
  #       1 - (1 - x) ** b / (integral_approx * b) , when x > mean ,
  # where:
  #   integral_approx = (mean ** a) / a + (mean_complement ** b) / b ,
  #   mean = a / (a + b) ,
  #   mean_complement = 1 - mean = b / (a + b) .
  # We invert betainc(a, b, x) with respect to x in the proper regime.
  # Equation 6.4.7 [2, page 271].
  a_plus_b = a + b
  mean = a / a_plus_b
  mean_complement = b / a_plus_b
  integral_approx_part_a = tf.math.exp(tf.math.xlogy(a, mean) - tf.math.log(a))
  integral_approx_part_b = tf.math.exp(tf.math.xlogy(b, mean_complement) -
                                       tf.math.log(b))
  integral_approx = integral_approx_part_a + integral_approx_part_b

  # Solve Equation 6.4.8 [2, page 271] for x in the respective regimes.
  # The switch point integral_approx_part_a / integral_approx is the
  # approximate betainc value at x == mean.
  # - Below it: x = (a * y * integral_approx) ** (1 / a).
  # - Above it: x = 1 - (b * (1 - y) * integral_approx) ** (1 / b).
  inv_a = tf.math.reciprocal(a)
  inv_b = tf.math.reciprocal(b)
  result_for_small_a_or_b = tf.where(
      y <= (integral_approx_part_a / integral_approx),
      tf.math.exp(tf.math.xlogy(inv_a, y) + tf.math.xlogy(inv_a, a) +
                  tf.math.xlogy(inv_a, integral_approx)),
      -tf.math.expm1(tf.math.xlog1py(inv_b, -y) + tf.math.xlogy(inv_b, b) +
                     tf.math.xlogy(inv_b, integral_approx)))

  # And when max(a, b) < 1, we use the approximation proposed by [3] for the
  # same domain:
  #   betaincinv(a, b, y) ~= xg / (1 + xg) ,
  # where:
  #   xg = (a * y * Beta(a, b)) ** (1 / a) .
  log_xg = tf.math.xlogy(inv_a, a) + tf.math.xlogy(inv_a, y) + (
      inv_a * lbeta(a, b))
  # Cap the exponent so that exp() cannot overflow.
  xg = tf.math.exp(tf.math.minimum(log_xg, max_log))
  result_for_small_a_and_b = xg / (one + xg)

  # Return the appropriate result for parameters a and b.
  result = tf.where(
      tf.math.minimum(a, b) >= one,
      result_for_large_a_and_b,
      tf.where(
          tf.math.maximum(a, b) < one,
          result_for_small_a_and_b,
          result_for_small_a_or_b))

  # Keep the starting point strictly inside (0, 1) for the log-based Halley
  # iterations.
  return tf.clip_by_value(result, tiny, one - eps)
def _betaincinv_computation(a, b, y):
  """Returns the inverse of `betainc(a, b, x)` with respect to `x`.

  Solves `betainc(a, b, x) = y` for `x`.
  - Starts from `_betaincinv_initial_approx`.
  - Refines it with Halley's method.
  - Falls back to bisection whenever a step would leave the current bracket.

  Args:
    a: Floating-point Tensor, broadcastable with `b` and `y`.
    b: Floating-point Tensor, broadcastable with `a` and `y`.
    y: Floating-point Tensor, broadcastable with `a` and `b`.

  Returns:
    Tensor with the broadcast shape and common dtype of the inputs.
    - Entries with `y == 0` or `y == 1` map to `y` itself.
    - Entries with `a <= 0`, `b <= 0`, `y < 0` or `y > 1` are NaN.
  """
  dtype_orig = dtype_util.common_dtype([a, b, y], tf.float32)
  # We promote bfloat16 and float16 to float32 to make this function consistent
  # with betainc.
  should_promote_dtype = (dtype_orig in _f16bit_dtypes)
  dtype = tf.float32 if should_promote_dtype else dtype_orig
  numpy_dtype = dtype_util.as_numpy_dtype(dtype)
  zero = numpy_dtype(0.)
  tiny = np.finfo(numpy_dtype).tiny
  eps = np.finfo(numpy_dtype).eps
  half = numpy_dtype(0.5)
  one = numpy_dtype(1.)
  two = numpy_dtype(2.)
  halley_correction_min = numpy_dtype(0.5)
  halley_correction_max = numpy_dtype(1.5)

  a, b, y = [tf.convert_to_tensor(z, dtype=dtype_orig) for z in [a, b, y]]
  if should_promote_dtype:
    a, b, y = [tf.cast(z, dtype) for z in [a, b, y]]
  broadcast_shape = functools.reduce(
      ps.broadcast_shape, [ps.shape(a), ps.shape(b), ps.shape(y)])
  a, b, y = [tf.broadcast_to(z, broadcast_shape) for z in [a, b, y]]

  # When tfp_math.betainc(a, b, 0.5) < y, we apply the symmetry relation given
  # here: https://dlmf.nist.gov/8.17.E4
  #   betainc(a, b, x) = 1 - betainc(b, a, 1 - x) .
  # If dtype is float32, we have additional conditions to apply this relation:
  #   (a < 1) & (b < 1) & (tfp_math.betainc(a, b, a / (a + b)) < y) .
  error_at_half = betainc(a, b, half) - y
  if numpy_dtype == np.float32:
    a_and_b_are_small = (a < one) & (b < one)
    error_at_mean = betainc(a, b, a / (a + b)) - y
    use_symmetry_relation = (error_at_half < zero) & a_and_b_are_small & (
        error_at_mean < zero)
  else:
    use_symmetry_relation = (error_at_half < zero)

  a_orig, y_orig = (a, y)
  a = tf.where(use_symmetry_relation, b, a)
  b = tf.where(use_symmetry_relation, a_orig, b)
  y = tf.where(use_symmetry_relation, one - y, y)

  a_minus_1 = a - one
  b_minus_1 = b - one
  lbeta_a_and_b = lbeta(a, b)
  two_tiny = two * tiny

  # max_iterations was taken from [4] and tolerance was set by experimentation.
  if numpy_dtype == np.float32:
    max_iterations = 10
    tolerance = numpy_dtype(8.) * eps
  else:
    max_iterations = 8
    tolerance = numpy_dtype(4096.) * eps

  def root_finding_iteration(should_stop, low, high, candidate):
    # `y` is the (possibly reflected) target. `tf.while_loop` runs this body
    # before `y` is rebound to `y_orig` further below.
    error = betainc(a, b, candidate) - y
    # f / f', where f' is the beta density at `candidate`.
    error_over_der = error / tf.math.exp(
        tf.math.xlogy(a_minus_1, candidate) +
        tf.math.xlog1py(b_minus_1, -candidate) -
        lbeta_a_and_b)
    # f'' / f' for the beta density.
    second_der_over_der = a_minus_1 / candidate - b_minus_1 / (one - candidate)
    # Following [2, section 9.4.2, page 463], we limit the influence of the
    # Halley's correction to the Newton's method, since this correction can
    # reduce the Newton's region of convergence. We set minimum and maximum
    # values for this correction by experimentation.
    halley_correction = tf.clip_by_value(
        one - half * error_over_der * second_der_over_der,
        halley_correction_min,
        halley_correction_max)
    halley_delta = error_over_der / halley_correction
    halley_candidate = tf.where(
        should_stop, candidate, candidate - halley_delta)

    # Fall back to bisection if the current step would take the new candidate
    # out of bounds.
    new_candidate = tf.where(
        halley_candidate <= low,
        half * (candidate + low),
        tf.where(
            halley_candidate >= high,
            half * (candidate + high),
            halley_candidate))

    # Shrink the bracket toward the side the step moved to.
    new_delta = candidate - new_candidate
    new_delta_is_negative = (new_delta < zero)
    new_low = tf.where(new_delta_is_negative, candidate, low)
    new_high = tf.where(new_delta_is_negative, high, candidate)

    adjusted_tolerance = tf.math.maximum(tolerance * new_candidate, two_tiny)
    should_stop = (should_stop | (tf.math.abs(new_delta) < adjusted_tolerance) |
                   tf.math.equal(new_low, new_high))

    return should_stop, new_low, new_high, new_candidate

  initial_candidate = _betaincinv_initial_approx(a, b, y, dtype)
  # Bracket the solution with the interval (low, high).
  initial_low = tf.zeros_like(y)
  if numpy_dtype == np.float32:
    initial_high = tf.ones_like(y) * tf.where(
        a_and_b_are_small & (error_at_mean < zero), half, one)
  else:
    initial_high = tf.ones_like(y) * half

  # Only entries whose result is overwritten by the trivial-case handling below
  # (y == 0 or y == 1) may skip root finding. `y` must not be compared with
  # `initial_high`: that is a bound on `x` and may equal 0.5. Doing so would
  # freeze every entry with y == 0.5 at its (possibly crude) initial
  # approximation.
  initial_should_stop = tf.equal(y_orig, zero) | tf.equal(y_orig, one)

  (_, _, _, result) = tf.while_loop(
      cond=lambda stop, *_: tf.reduce_any(~stop),
      body=root_finding_iteration,
      loop_vars=(
          initial_should_stop,
          initial_low,
          initial_high,
          initial_candidate),
      maximum_iterations=max_iterations)

  # If we are taking advantage of the symmetry relation, we have to adjust the
  # input y and the solution.
  y = y_orig
  result = tf.where(
      use_symmetry_relation, one - tf.math.maximum(result, eps), result)

  # Handle trivial cases.
  result = tf.where(tf.equal(y, zero) | tf.equal(y, one), y, result)

  # Determine if the inputs are out of range (should return NaN output).
  result_is_nan = (a <= zero) | (b <= zero) | (y < zero) | (y > one)
  result = tf.where(result_is_nan, numpy_dtype(np.nan), result)

  # If we promoted the dtype, then we have to convert the result back to the
  # original dtype.
  if should_promote_dtype:
    result = tf.cast(result, dtype_orig)

  return result
def _betaincinv_partials(a, b, y, return_value=False):
  """Returns the partial derivatives of `betaincinv(a, b, y)`.

  The partials follow from implicit differentiation of `betainc(a, b, x) = y`
  with `x = betaincinv(a, b, y)`. Writing `I = betainc`:

    dx/da = -(dI/da) / (dI/dx)
    dx/db = -(dI/db) / (dI/dx)
    dx/dy = 1 / (dI/dx)

  Args:
    a: Floating-point Tensor (or value convertible to one).
    b: Floating-point Tensor (or value convertible to one).
    y: Floating-point Tensor (or value convertible to one).
    return_value: Python `bool`. If `True`, `x = betaincinv(a, b, y)` is
      appended as the last element of the result.

  Returns:
    A tuple `(partial_a, partial_b, partial_y)`, or
    `(partial_a, partial_b, partial_y, x)` when `return_value` is `True`.
    The result is always a tuple, whether or not the dtype was promoted.
    When the inputs are 16-bit floats, the results are cast back to that
    original dtype.
  """
  dtype_orig = dtype_util.common_dtype([a, b, y], tf.float32)
  # We promote bfloat16 and float16 to float32 to make this function consistent
  # with betaincinv.
  should_promote_dtype = (dtype_orig in _f16bit_dtypes)
  dtype = tf.float32 if should_promote_dtype else dtype_orig
  a, b, y = [tf.convert_to_tensor(z, dtype=dtype_orig) for z in [a, b, y]]
  if should_promote_dtype:
    a, b, y = [tf.cast(z, dtype) for z in [a, b, y]]

  # We use the fact that betainc and betaincinv are inverses of each other to
  # compute the gradients. Going through the custom-gradient wrapper (rather
  # than the raw computation) keeps these partials differentiable.
  x = _betaincinv_custom_gradient(a, b, y)
  betainc_partial_a, betainc_partial_b, betainc_partial_x = _betainc_partials(
      a, b, x)

  partial_a = -betainc_partial_a / betainc_partial_x
  partial_b = -betainc_partial_b / betainc_partial_x
  partial_y = tf.math.reciprocal(betainc_partial_x)

  if return_value:
    results = (partial_a, partial_b, partial_y, x)
  else:
    results = (partial_a, partial_b, partial_y)

  # If we promoted the dtype, then we have to convert the results back to the
  # original dtype. Keep the container a tuple so the return type does not
  # depend on the input dtype.
  if should_promote_dtype:
    results = tuple(tf.cast(z, dtype_orig) for z in results)
  return results
def _betaincinv_fwd(a, b, y):
  """Forward pass of the `betaincinv` custom gradient.

  Returns:
    A pair `(value, residuals)`. `value` is `_betaincinv_computation(a, b, y)`.
    `residuals` is the tuple `(a, b, y)` that `_betaincinv_bwd` unpacks.
  """
  residuals = (a, b, y)
  value = _betaincinv_computation(a, b, y)
  return value, residuals
def _betaincinv_bwd(aux, g):
  """Backward (VJP) pass of the `betaincinv` custom gradient.

  Args:
    aux: The residual tuple `(a, b, y)` produced by `_betaincinv_fwd`.
    g: Upstream gradient with respect to the `betaincinv` output.

  Returns:
    The upstream gradient scaled by each partial derivative, for `a`, `b` and
    `y`. The result is passed through `generic.fix_gradient_for_broadcasting`
    (presumably to reduce it back to each input's shape).
  """
  a, b, y = aux
  # pylint: disable=unbalanced-tuple-unpacking
  partial_a, partial_b, partial_y = _betaincinv_partials(a, b, y)
  scaled = [partial * g for partial in (partial_a, partial_b, partial_y)]
  return generic.fix_gradient_for_broadcasting([a, b, y], scaled)
def _betaincinv_jvp(primals, tangents):
  """Forward-mode (JVP) rule for `betaincinv`; supports JAX custom derivatives.

  Args:
    primals: Tuple `(a, b, y)` of primal inputs.
    tangents: Tuple `(da, db, dy)` of tangents matching `primals`.

  Returns:
    A pair `(x, dx)`. `x = betaincinv(a, b, y)`, and `dx` is the sum of each
    partial derivative times its tangent.
  """
  (a, b, y), (da, db, dy) = primals, tangents
  partial_a, partial_b, partial_y, value = _betaincinv_partials(
      a, b, y, return_value=True)
  directional = partial_a * da + partial_b * db + partial_y * dy
  return (value, directional)
@tfp_custom_gradient.custom_gradient(
    vjp_fwd=_betaincinv_fwd,
    vjp_bwd=_betaincinv_bwd,
    jvp_fn=_betaincinv_jvp)
def _betaincinv_custom_gradient(a, b, y):
  """Evaluates `betaincinv(a, b, y)` with hand-written derivative rules.

  The value comes from `_betaincinv_computation`. Reverse-mode and
  forward-mode derivatives are supplied through the decorator by
  `_betaincinv_fwd`/`_betaincinv_bwd` and `_betaincinv_jvp`.
  """
  inverse = _betaincinv_computation(a, b, y)
  return inverse
def betaincinv(a, b, y, name=None):
  """Computes the inverse of `tfp.math.betainc` with respect to `x`.

  Given `a`, `b` and `y`, this returns the `x` satisfying
  `y = tfp.math.betainc(a, b, x)`. Gradients with respect to all three inputs
  are provided by a custom gradient.

  Args:
    a: Floating-point Tensor. Must be broadcastable with `b` and `y`.
    b: Floating-point Tensor. Must be broadcastable with `a` and `y`.
    y: Floating-point Tensor. Must be broadcastable with `a` and `b`.
    name: A name for the operation (optional).

  Returns:
    betaincinv: Floating-point Tensor, the element-wise inverse of the
      regularized incomplete beta function. Its dtype is the common dtype of
      the inputs (float32 when none can be inferred).
  """
  with tf.name_scope(name or 'betaincinv'):
    common = dtype_util.common_dtype([a, b, y], tf.float32)
    a, b, y = (tf.convert_to_tensor(t, dtype=common) for t in (a, b, y))
    return _betaincinv_custom_gradient(a, b, y)
def _dawsn_naive(x):
  """Returns Dawson's integral evaluated elementwise at `x`.

  The value is assembled from piecewise rational approximations in `|x|`:

  * `|x| < 3.25`: `|x| * P(x**2) / Q(x**2)`.
  * `3.25 <= |x| < 6.25` and `6.25 <= |x| <= 1e9` (NaN inputs also land
    here): `0.5 * (1 / |x| + t * P(t) / (|x| * Q(t)))` with `t = 1 / x**2`.
    Each of these two intervals has its own coefficient pair.
  * `|x| > 1e9`: only the leading term `0.5 / |x|`.

  Every branch is multiplied by `sign(x)`, so the result is odd in `x`.

  Args:
    x: Floating-point Tensor, or a value convertible to one.

  Returns:
    Tensor of Dawson's integral values, in the common dtype of `x`
    (defaulting to float32).
  """
  dtype = dtype_util.common_dtype([x], tf.float32)
  np_dtype = dtype_util.as_numpy_dtype(dtype)
  x = tf.convert_to_tensor(x, dtype=dtype)

  def _as_coeffs(values):
    # Polynomial coefficients (highest degree first, as `tf.math.polyval`
    # expects), converted to the working NumPy dtype.
    return [np_dtype(v) for v in values]

  # |x| < 3.25: numerator / denominator polynomials in x**2.
  small_num = _as_coeffs([
      1.13681498971755972054E-11,
      8.49262267667473811108E-10,
      1.94434204175553054283E-8,
      9.53151741254484363489E-7,
      3.07828309874913200438E-6,
      3.52513368520288738649E-4,
      -8.50149846724410912031E-4,
      4.22618223005546594270E-2,
      -9.17480371773452345351E-2,
      9.99999999999999994612E-1])
  small_den = _as_coeffs([
      2.40372073066762605484E-11,
      1.48864681368493396752E-9,
      5.21265281010541664570E-8,
      1.27258478273186970203E-6,
      2.32490249820789513991E-5,
      3.25524741826057911661E-4,
      3.48805814657162590916E-3,
      2.79448531198828973716E-2,
      1.58874241960120565368E-1,
      5.74918629489320327824E-1,
      1.00000000000000000539E0])
  # 3.25 <= |x| < 6.25: polynomials in 1 / x**2.
  medium_num = _as_coeffs([
      5.08955156417900903354E-1,
      -2.44754418142697847934E-1,
      9.41512335303534411857E-2,
      -2.18711255142039025206E-2,
      3.66207612329569181322E-3,
      -4.23209114460388756528E-4,
      3.59641304793896631888E-5,
      -2.14640351719968974225E-6,
      9.10010780076391431042E-8,
      -2.40274520828250956942E-9,
      3.59233385440928410398E-11])
  medium_den = _as_coeffs([
      1.00000000000000000000E0,
      -6.31839869873368190192E-1,
      2.36706788228248691528E-1,
      -5.31806367003223277662E-2,
      8.48041718586295374409E-3,
      -9.47996768486665330168E-4,
      7.81025592944552338085E-5,
      -4.55875153252442634831E-6,
      1.89100358111421846170E-7,
      -4.91324691331920606875E-9,
      7.18466403235734541950E-11])
  # 6.25 <= |x| <= 1e9: polynomials in 1 / x**2.
  large_num = _as_coeffs([
      -5.90592860534773254987E-1,
      6.29235242724368800674E-1,
      -1.72858975380388136411E-1,
      1.64837047825189632310E-2,
      -4.86827613020462700845E-4])
  large_den = _as_coeffs([
      1.00000000000000000000E0,
      -2.69820057197544900361E0,
      1.73270799045947845857E0,
      -3.93708582281939493482E-1,
      3.44278924041233391079E-2,
      -9.73655226040941223894E-4])

  magnitude = tf.math.abs(x)
  sign = tf.math.sign(x)
  half_sign = 0.5 * sign
  x_sq = tf.math.square(x)
  inv_x_sq = tf.math.reciprocal(x_sq)

  def _expansion(num, den):
    # 1 / |x| + t * num(t) / (|x| * den(t)), with t = 1 / x**2.
    return tf.math.reciprocal(magnitude) + inv_x_sq * (
        tf.math.polyval(num, inv_x_sq) /
        (magnitude * tf.math.polyval(den, inv_x_sq)))

  small = sign * (magnitude * tf.math.polyval(small_num, x_sq) /
                  tf.math.polyval(small_den, x_sq))
  medium = half_sign * _expansion(medium_num, medium_den)
  large = half_sign * _expansion(large_num, large_den)
  very_large = half_sign * tf.math.reciprocal(magnitude)

  # Select the branch per element; each later `tf.where` overrides the
  # previous choice on its (narrower) region.
  result = tf.where(magnitude > 1e9, very_large, large)
  result = tf.where(magnitude < 6.25, medium, result)
  return tf.where(magnitude < 3.25, small, result)
def _dawsn_fwd(x):
  """Forward pass of the `dawsn` custom gradient.

  Returns:
    A pair `(value, residuals)`. `value` is `_dawsn_naive(x)`, and `residuals`
    is the one-element tuple `(x,)` consumed by `_dawsn_bwd`.
  """
  residuals = (x,)
  value = _dawsn_naive(x)
  return value, residuals
def _dawsn_bwd(aux, g):
  """Backward (VJP) pass of the `dawsn` custom gradient.

  Scales `g` by `1 - 2 * x * dawsn(x)`, the derivative of Dawson's integral.
  The value `dawsn(x)` is recomputed through `_dawsn_custom_gradient`,
  presumably so that higher-order derivatives also use the custom rules.

  Args:
    aux: The residual tuple `(x,)` produced by `_dawsn_fwd`.
    g: Upstream gradient with respect to the `dawsn` output.

  Returns:
    The gradient with respect to `x`.
  """
  (x,) = aux
  dawsn_x = _dawsn_custom_gradient(x)
  slope = 1. - 2 * x * dawsn_x
  return g * slope
def _dawsn_jvp(primals, tangents):
  """Forward-mode (JVP) rule for `dawsn`; supports JAX custom derivatives.

  Args:
    primals: One-element tuple `(x,)`.
    tangents: One-element tuple `(dx,)` matching `primals`.

  Returns:
    A pair `(dawsn(x), dx * (1 - 2 * x * dawsn(x)))`.
  """
  (x,), (dx,) = primals, tangents
  value = _dawsn_custom_gradient(x)
  return value, dx * (1. - 2 * x * value)
@tfp_custom_gradient.custom_gradient(
vjp_fwd=_dawsn_fwd,
vjp_bwd=_dawsn_bwd,
jvp_fn=_dawsn_jvp)