Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 63 additions & 20 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@
import pytensor.gradient as tg

from arviz_base import dict_to_dataset, make_attrs
from pytensor.compile.mode import get_mode
from pytensor.graph.basic import Variable
from pytensor.link.jax.linker import JAXLinker
from pytensor.link.numba.linker import NumbaLinker
from rich.theme import Theme
from threadpoolctl import threadpool_limits
from typing_extensions import Protocol
Expand Down Expand Up @@ -84,6 +87,14 @@
except ImportError:
MemoryStore = type("MemoryStore", (), {})

# Probe for the optional nutpie dependency once, at import time. The resulting
# flag is consulted later when choosing or validating the NUTS sampler.
# The ``try`` body holds only the import so that nothing else can be mistaken
# for a missing dependency.
try:
    import nutpie
except ImportError:
    NUTPIE_INSTALLED = False
else:
    NUTPIE_INSTALLED = True
Comment on lines +90 to +95
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
try:
import nutpie
NUTPIE_INSTALLED = True
except ImportError:
NUTPIE_INSTALLED = False
import importlib
NUTPIE_INSTALLED = bool(importlib.util.find_spec("nutpie"))



sys.setrecursionlimit(10000)

__all__ = [
Expand Down Expand Up @@ -337,15 +348,28 @@ def _sample_external_nuts(
idata_kwargs: dict | None,
compute_convergence_checks: bool,
nuts_sampler_kwargs: dict | None,
compile_kwargs: dict | None = None,
**kwargs,
):
if nuts_sampler_kwargs is None:
nuts_sampler_kwargs = {}
else:
if "backend" in nuts_sampler_kwargs:
warnings.warn(
"backend should be passed explicitly to pm.sample, not inserted in nuts_sampler_kwargs",
FutureWarning,
)
compile_kwargs["mode"] = get_mode(nuts_sampler_kwargs.pop("backend"))

if "gradient_backend" in nuts_sampler_kwargs:
warnings.warn(
"gradient_backend should be passed to pm.sample as `compile_kwargs`, not inserted in nuts_sampler_kwargs",
FutureWarning,
)
compile_kwargs["gradient_backend"] = nuts_sampler_kwargs.pop("gradient_backend")

if sampler == "nutpie":
try:
import nutpie
except ImportError as err:
if not NUTPIE_INSTALLED:
raise ImportError(
"nutpie not found. Install it with conda install -c conda-forge nutpie"
) from err
Expand All @@ -363,11 +387,9 @@ def _sample_external_nuts(
UserWarning,
)

compile_kwargs = {}
nuts_sampler_kwargs = nuts_sampler_kwargs.copy()
for kwarg in ("backend", "gradient_backend"):
if kwarg in nuts_sampler_kwargs:
compile_kwargs[kwarg] = nuts_sampler_kwargs.pop(kwarg)
mode = compile_kwargs.pop("mode", None)
default_backend = "jax" if isinstance(get_mode(mode).linker, JAXLinker) else "numba"
compile_kwargs.setdefault("backend", default_backend)
compiled_model = nutpie.compile_pymc_model(
model,
var_names=var_names,
Expand Down Expand Up @@ -407,7 +429,7 @@ def _sample_external_nuts(
"sampling_time": t_sample,
"tuning_steps": tune,
},
library=nutpie,
inference_library=nutpie,
)
for k, v in attrs.items():
idata.posterior.attrs[k] = v
Expand Down Expand Up @@ -458,7 +480,7 @@ def sample(
quiet: bool = False,
step=None,
var_names: Sequence[str] | None = None,
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] | None = None,
initvals: StartDict | Sequence[StartDict | None] | None = None,
init: str = "auto",
jitter_max_retries: int = 10,
Expand Down Expand Up @@ -491,7 +513,7 @@ def sample(
quiet: bool = False,
step=None,
var_names: Sequence[str] | None = None,
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] | None = None,
initvals: StartDict | Sequence[StartDict | None] | None = None,
init: str = "auto",
jitter_max_retries: int = 10,
Expand Down Expand Up @@ -524,7 +546,7 @@ def sample(
quiet: bool = False,
step=None,
var_names: Sequence[str] | None = None,
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] | None = None,
initvals: StartDict | Sequence[StartDict | None] | None = None,
init: str = "auto",
jitter_max_retries: int = 10,
Expand All @@ -541,6 +563,7 @@ def sample(
blas_cores: int | None | Literal["auto"] = "auto",
model: Model | None = None,
compile_kwargs: dict | None = None,
backend: str | None = None,
**kwargs,
) -> DataTree | MultiTrace | ZarrTrace:
r"""Draw samples from the posterior using the given step methods.
Expand Down Expand Up @@ -593,10 +616,11 @@ def sample(
method will be used, if appropriate to the model.
var_names : list of str, optional
Names of variables to be stored in the trace. Defaults to all free variables and deterministics.
nuts_sampler : str
nuts_sampler : str, optional
Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"].
This requires the chosen sampler to be installed.
All samplers, except "pymc", require the full model to be continuous.
If ``None`` (default), "nutpie" is used if installed and can be compiled to the desired backend.
blas_cores: int or "auto" or None, default = "auto"
The total number of threads blas and openmp functions should use during sampling.
Setting it to "auto" will ensure that the total number of active blas threads is the
Expand Down Expand Up @@ -658,10 +682,12 @@ def sample(
See multiprocessing documentation for details.
model : Model (optional if in ``with`` context)
Model to sample from. The model needs to have free random variables.
backend: str, optional.
Which computational backend to use. Recommended to be one of "numba", "cvm", "jax", and "py".
May require installing extra dependencies.
compile_kwargs: dict, optional
Dictionary with keyword argument to pass to the functions compiled by the step methods.


Returns
-------
trace : pymc.backends.base.MultiTrace | pymc.backends.zarr.ZarrTrace | arviz.InferenceData
Expand Down Expand Up @@ -822,19 +848,35 @@ def sample(
)
)

if nuts_sampler != "pymc":
if not exclusive_nuts:
raise ValueError(
"Model can not be sampled with NUTS alone. It either has discrete variables or a non-differentiable log-probability."
)
if compile_kwargs is None:
compile_kwargs = {}

if backend is not None:
if "mode" in compile_kwargs:
raise ValueError("Can only define one of backend or compile_kwargs['mode']")
compile_kwargs["mode"] = get_mode(backend)

if nuts_sampler is None:
# Nutpie must take all the variables and can only compile to Numba or JAX
can_use_nutpie = (
exclusive_nuts
and NUTPIE_INSTALLED
and isinstance(get_mode(compile_kwargs.get("mode")).linker, NumbaLinker | JAXLinker)
)
nuts_sampler = "nutpie" if can_use_nutpie else "pymc"
elif nuts_sampler != "pymc" and not exclusive_nuts:
raise ValueError(
"Model can not be sampled with NUTS alone. It either has discrete variables or a non-differentiable log-probability."
)

if nuts_sampler != "pymc":
with joined_blas_limiter():
return _sample_external_nuts(
sampler=nuts_sampler,
draws=draws,
tune=tune,
chains=chains,
target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
target_accept=nuts_sampler_kwargs.get("target_accept", 0.8),
random_seed=random_seed,
initvals=initvals,
model=model,
Expand All @@ -844,6 +886,7 @@ def sample(
idata_kwargs=idata_kwargs,
compute_convergence_checks=compute_convergence_checks,
nuts_sampler_kwargs=nuts_sampler_kwargs,
compile_kwargs=compile_kwargs,
**kwargs,
)

Expand Down
124 changes: 124 additions & 0 deletions tests/sampling/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,3 +992,127 @@ def test_quiet_false_shows_logs(self, caplog):

pymc_logs = [r for r in caplog.records if r.name.startswith("pymc")]
assert len(pymc_logs) > 0


class TestNutpieSelection:
    """Check how ``pm.sample`` chooses between the nutpie and PyMC NUTS samplers.

    The tests replace ``get_mode``, the linker classes and (where relevant) the
    ``nutpie`` module inside ``pymc.sampling.mcmc`` with stand-ins.
    ``NUTPIE_INSTALLED`` is pinned explicitly so the outcome does not depend on
    whether nutpie happens to be installed in the test environment.
    """

    @pytest.fixture
    def continuous_model(self):
        """Return a model with a single continuous free variable."""
        with pm.Model() as model:
            pm.Normal("x")
        return model

    @staticmethod
    def _fake_mode(linker):
        """Build a stand-in compilation mode that only exposes ``linker``."""
        mode = mock.Mock()
        mode.linker = linker
        return mode

    def test_auto_selection_numba(self, continuous_model):
        """With a Numba linker and nutpie available, nutpie is picked automatically."""
        MockNumbaLinker = type("MockNumbaLinker", (), {})
        with (
            mock.patch("pymc.sampling.mcmc.NUTPIE_INSTALLED", True),
            mock.patch("pymc.sampling.mcmc.NumbaLinker", MockNumbaLinker),
            mock.patch("pymc.sampling.mcmc.get_mode") as mock_get_mode,
            mock.patch("pymc.sampling.mcmc._sample_external_nuts") as mock_sample_external,
        ):
            mock_get_mode.return_value = self._fake_mode(MockNumbaLinker())

            pm.sample(
                model=continuous_model,
                compile_kwargs={"mode": "NUMBA"},
                tune=10,
                draws=10,
                chains=1,
                progressbar=False,
            )

        mock_sample_external.assert_called_once()
        call_kwargs = mock_sample_external.call_args[1]
        assert call_kwargs["sampler"] == "nutpie"
        assert call_kwargs.get("compile_kwargs") == {"mode": "NUMBA"}

    def test_auto_selection_jax(self, continuous_model):
        """With a JAX linker and nutpie available, nutpie is picked automatically."""
        MockJAXLinker = type("MockJAXLinker", (), {})
        with (
            mock.patch("pymc.sampling.mcmc.NUTPIE_INSTALLED", True),
            mock.patch("pymc.sampling.mcmc.JAXLinker", MockJAXLinker),
            mock.patch("pymc.sampling.mcmc.get_mode") as mock_get_mode,
            mock.patch("pymc.sampling.mcmc._sample_external_nuts") as mock_sample_external,
        ):
            mock_get_mode.return_value = self._fake_mode(MockJAXLinker())

            pm.sample(
                model=continuous_model,
                compile_kwargs={"mode": "JAX"},
                tune=10,
                draws=10,
                chains=1,
                progressbar=False,
            )

        mock_sample_external.assert_called_once()
        call_kwargs = mock_sample_external.call_args[1]
        assert call_kwargs["sampler"] == "nutpie"
        # Only the kwargs handed TO _sample_external_nuts are checked here; the
        # translation of mode -> nutpie backend happens inside that function.
        assert call_kwargs.get("compile_kwargs") == {"mode": "JAX"}

    def test_fallback_cvm(self, continuous_model):
        """A linker that is neither Numba nor JAX must not route to nutpie."""
        with (
            # Pin nutpie as available so the fallback is attributable to the
            # linker check rather than to a missing optional dependency.
            mock.patch("pymc.sampling.mcmc.NUTPIE_INSTALLED", True),
            mock.patch("pymc.sampling.mcmc.get_mode") as mock_get_mode,
            mock.patch("pymc.sampling.mcmc._sample_external_nuts") as mock_sample_external,
            mock.patch("pymc.sampling.mcmc._iter_sample"),
            mock.patch("pymc.sampling.mcmc._mp_sample"),
        ):
            # A generic mock is an instance of neither NumbaLinker nor JAXLinker.
            mock_get_mode.return_value = self._fake_mode(mock.Mock())

            pm.sample(
                model=continuous_model,
                compile_kwargs={"mode": "FAST_RUN"},
                tune=10,
                draws=10,
                chains=1,
                progressbar=False,
            )

        mock_sample_external.assert_not_called()

    def test_explicit_selection(self, continuous_model):
        """An explicit ``nuts_sampler="nutpie"`` is forwarded to the external sampler."""
        with mock.patch("pymc.sampling.mcmc._sample_external_nuts") as mock_sample_external:
            pm.sample(
                model=continuous_model,
                nuts_sampler="nutpie",
                tune=10,
                draws=10,
                chains=1,
                progressbar=False,
            )
        mock_sample_external.assert_called_once()
        assert mock_sample_external.call_args[1]["sampler"] == "nutpie"

    def test_backend_propagation_internal(self, continuous_model):
        """A Numba mode reaches ``nutpie.compile_pymc_model`` as ``backend="numba"``."""
        # MagicMock (not Mock): the sampling code performs item assignment on
        # attributes of the object returned by ``nutpie.sample``.
        fake_nutpie = mock.MagicMock()
        MockNumbaLinker = type("MockNumbaLinker", (), {})

        with (
            # Cover both lookup paths: a fresh ``import nutpie`` and the name
            # already bound in the module under test. ``create=True`` handles
            # environments where nutpie is absent and the name was never bound.
            mock.patch.dict("sys.modules", {"nutpie": fake_nutpie}),
            mock.patch("pymc.sampling.mcmc.nutpie", fake_nutpie, create=True),
            mock.patch("pymc.sampling.mcmc.NUTPIE_INSTALLED", True),
            mock.patch("pymc.sampling.mcmc.NumbaLinker", MockNumbaLinker),
            mock.patch("pymc.sampling.mcmc.get_mode") as mock_get_mode,
        ):
            mock_get_mode.return_value = self._fake_mode(MockNumbaLinker())

            pm.sample(
                model=continuous_model,
                nuts_sampler="nutpie",
                compile_kwargs={"mode": "NUMBA"},
                tune=10,
                draws=10,
                chains=1,
                progressbar=False,
            )

        fake_nutpie.compile_pymc_model.assert_called()
        _, kwargs = fake_nutpie.compile_pymc_model.call_args
        assert kwargs.get("backend") == "numba"
Loading