-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
/
Copy pathtest_mcmc_external.py
88 lines (74 loc) · 2.79 KB
/
test_mcmc_external.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# Copyright 2024 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import numpy.testing as npt
import pytest
from pymc import Data, Model, Normal, sample
@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
def test_external_nuts_sampler(recwarn, nuts_sampler):
    """Sample a toy model through each NUTS backend and compare with ``"pymc"``.

    The model is sampled twice with ``nuts_sampler`` using identical settings
    (including a fixed ``random_seed``), then once more with the built-in
    ``"pymc"`` sampler as a reference. The test then checks:

    * which warnings were recorded (FutureWarning, DeprecationWarning and
      RuntimeWarning are ignored; only nutpie is expected to emit a
      UserWarning about ``initvals``);
    * that the ``Data`` containers appear in ``constant_data`` and the
      observed variable in ``observed_data``;
    * the chain/draw sizes and the ``tuning_steps`` posterior attribute;
    * that the two identically-seeded runs produced identical ``x`` draws;
    * that the posterior attribute keys match those of the reference run.
    """
    # Optional backends are skipped (not failed) when they cannot be imported.
    if nuts_sampler != "pymc":
        pytest.importorskip(nuts_sampler)

    sampler_settings = dict(
        nuts_sampler=nuts_sampler,
        random_seed=123,
        chains=2,
        tune=500,
        draws=500,
        progressbar=False,
        initvals={"x": 0.0},
    )

    with Model():
        mu_rv = Normal("x", 100, 5)
        y_data = Data("y", [1, 2, 3, 4])
        # Not used by the likelihood; registered so its presence in
        # constant_data can be asserted below.
        Data("z", [100, 190, 310, 405])
        Normal("L", mu=mu_rv, sigma=0.1, observed=y_data)

        first_run = sample(**sampler_settings)
        second_run = sample(**sampler_settings)
        reference_run = sample(**{**sampler_settings, "nuts_sampler": "pymc"})

    ignored_categories = (FutureWarning, DeprecationWarning, RuntimeWarning)
    recorded_warnings = set()
    for record in recwarn:
        if record.category in ignored_categories:
            continue
        recorded_warnings.add((record.category, record.message.args[0]))

    nutpie_initvals_warning = (
        UserWarning,
        "`initvals` are currently not passed to nutpie sampler. "
        "Use `init_mean` kwarg following nutpie specification instead.",
    )
    expected_warnings = {nutpie_initvals_warning} if nuts_sampler == "nutpie" else set()
    assert recorded_warnings == expected_warnings

    for data_name in ("y", "z"):
        assert data_name in first_run.constant_data
    assert "L" in first_run.observed_data

    posterior = first_run.posterior
    assert posterior.chain.size == 2
    assert posterior.draw.size == 500
    assert posterior.tuning_steps == 500
    # Identical seed and settings must reproduce the draws exactly.
    np.testing.assert_array_equal(posterior.x, second_run.posterior.x)

    assert reference_run.posterior.attrs.keys() == posterior.attrs.keys()
def test_step_args():
    """Check that step arguments reach the numpyro NUTS sampler.

    ``target_accept=0.5`` is passed at the top level of ``sample`` and
    ``max_treedepth`` through the ``nuts`` dict. The mean
    ``acceptance_rate`` recorded in ``sample_stats`` is then asserted to be
    close to the requested 0.5 target (``decimal=1`` tolerance).
    """
    # numpyro is an optional backend: skip rather than error out when it is
    # not installed, matching the guard used in test_external_nuts_sampler.
    pytest.importorskip("numpyro")

    with Model():
        Normal("a")
        idata = sample(
            nuts_sampler="numpyro",
            target_accept=0.5,
            nuts={"max_treedepth": 10},
            random_seed=1411,
            progressbar=False,
        )

    npt.assert_almost_equal(idata.sample_stats.acceptance_rate.mean(), 0.5, decimal=1)