-
Notifications
You must be signed in to change notification settings - Fork 9
State-Centered Temporal Processes #828
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 81 commits
Commits
Show all changes
94 commits
Select commit
Hold shift + click to select a range
680bb1e
Merge branch 'main' of https://github.com/CDCgov/PyRenew
cdc-mitzimorris 2cb876b
update
cdc-mitzimorris 60db8df
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris 32a5314
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris d6213f2
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris 96f27c9
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris 1cb6fa2
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris f62e1e4
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris 0c6785d
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris 1ee62b9
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris 0629461
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris efeadee
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris 371ba98
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris 0304bed
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris ffeea65
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris 50e7261
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris dae6af8
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris 5cb3097
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris 1d80ccc
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris e73b401
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris b1473b5
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris 0b929b5
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris 3ee00a7
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris 307982a
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris b862bc6
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris 2c665a5
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris 60d6458
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris ec8c464
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris c018bf7
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris d0207dd
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris f3c706a
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris 684c6c5
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris ca2454f
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris 0f38afc
merge
cdc-mitzimorris d8e7a57
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris 7e9b5fe
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris e1d8014
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris 83ddbf0
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris 69ea4ea
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris 555e87b
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris fa5a7cb
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris 69cdab0
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris c28a89f
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris fd091ca
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris 8cee471
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris b2a1e1a
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris 2006afd
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris a31ec85
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris 0db35a9
implementation and unit tests for centered versions of temporal proce…
cdc-mitzimorris a92d58b
checkpointing - test cleanup
cdc-mitzimorris 9911f4a
checkpointing
cdc-mitzimorris 7795672
fix unit test
cdc-mitzimorris b7030a3
benchmark test suite
cdc-mitzimorris c0d4684
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] fe81470
Merge branch 'mem_810_centered_parameterization' of github-bf06:CDCgo…
cdc-mitzimorris ee0c276
lint fix
cdc-mitzimorris 1e99920
more unit tests
cdc-mitzimorris c8a8764
checkpointing
cdc-mitzimorris 0c852ab
refactoring benchmarks
cdc-mitzimorris c67fe92
Day-of-week effects applied on observation time axis (i.e., not befor…
cdc-mitzimorris 197d9da
simplify benchmarks
cdc-mitzimorris 3fe2f2b
remove dependency on R forecasttools package by substituting local po…
cdc-mitzimorris fec5f4c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] f910b2b
remove dependency on R forecasttools package by substituting local po…
cdc-mitzimorris a3e34ab
Merge branch 'mem_810_centered_parameterization' of github-bf06:CDCgo…
cdc-mitzimorris 7d9619a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 360e8f0
more informative benchmark outputs
cdc-mitzimorris ab623e6
Merge branch 'mem_810_centered_parameterization' of github-bf06:CDCgo…
cdc-mitzimorris 07f5bb6
checkpointing
cdc-mitzimorris 72a4f19
fixing real data loading
cdc-mitzimorris 670ee27
fix typo
cdc-mitzimorris b62cddf
fix typo
cdc-mitzimorris 1f0b68f
deptry fix
cdc-mitzimorris 7a9031e
changes per bot review
cdc-mitzimorris ed614e6
cleanup
cdc-mitzimorris 70e116c
cleanup
cdc-mitzimorris a16ce91
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 26616bf
tweak benchmarks report
cdc-mitzimorris 543135b
Merge branch 'mem_810_centered_parameterization' of github-bf06:CDCgo…
cdc-mitzimorris b1367ac
simplify PR; remove all benchmarks code and tests
cdc-mitzimorris 29eff16
more benchmarks cleanup
cdc-mitzimorris 95ff442
changes per code review
cdc-mitzimorris 7be88ce
changes per code review
cdc-mitzimorris 9297962
Update pyrenew/latent/state_centered_distributions.py
cdc-mitzimorris 90431bd
Update pyrenew/latent/state_centered_distributions.py
cdc-mitzimorris 83f77f1
Update pyrenew/latent/state_centered_distributions.py
cdc-mitzimorris 3e88378
Merge branch 'mem_810_centered_parameterization' of github-bf06:CDCgo…
cdc-mitzimorris ef3923a
latent __init__.py exposes state centered distributions
cdc-mitzimorris 047b62c
changes per code review
cdc-mitzimorris 897235f
Update test/test_temporal_processes.py
cdc-mitzimorris 0d4f45b
Update test/test_temporal_processes.py
cdc-mitzimorris 0f34c29
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] c648c59
Merge branch 'mem_810_centered_parameterization' of github-bf06:CDCgo…
cdc-mitzimorris 8af0474
changes to unit tests per code review
cdc-mitzimorris File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
|
cdc-mitzimorris marked this conversation as resolved.
dylanhmorris marked this conversation as resolved.
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,387 @@ | ||||||
| """NumPyro distributions for state-centered temporal-process priors.""" | ||||||
|
|
||||||
| from __future__ import annotations | ||||||
|
|
||||||
| import jax | ||||||
| import jax.numpy as jnp | ||||||
| from jax import lax, random | ||||||
| from jax.typing import ArrayLike | ||||||
| from numpyro.distributions import constraints | ||||||
| from numpyro.distributions.continuous import Normal | ||||||
| from numpyro.distributions.distribution import Distribution | ||||||
| from numpyro.distributions.util import validate_sample | ||||||
| from numpyro.util import is_prng_key | ||||||
|
|
||||||
|
|
||||||
| class StateRandomWalk(Distribution): | ||||||
| r""" | ||||||
| State-centered random-walk prior on a post-initial state path. | ||||||
|
cdc-mitzimorris marked this conversation as resolved.
Outdated
|
||||||
|
|
||||||
| Given a deterministic initial state $x_0$ = ``initial_loc``: | ||||||
|
|
||||||
| $$ | ||||||
| x_t \sim \mathrm{Normal}(x_{t-1}, \sigma), \quad t = 1, \dots, T | ||||||
| $$ | ||||||
|
|
||||||
| The sampled value is the post-initial path | ||||||
| $[x_1, x_2, \ldots, x_{\mathrm{num\_steps}}]$ of length ``num_steps``. | ||||||
| """ | ||||||
|
dylanhmorris marked this conversation as resolved.
|
||||||
|
|
||||||
| arg_constraints = { | ||||||
| "scale": constraints.positive, | ||||||
| "initial_loc": constraints.real, | ||||||
| } | ||||||
| support = constraints.real_vector | ||||||
| reparametrized_params = ["scale", "initial_loc"] | ||||||
| pytree_aux_fields = ("num_steps",) | ||||||
|
|
||||||
| def __init__( | ||||||
| self, | ||||||
| scale: ArrayLike, | ||||||
| initial_loc: ArrayLike = 0.0, | ||||||
| num_steps: int = 1, | ||||||
| *, | ||||||
| validate_args: bool | None = None, | ||||||
| ) -> None: | ||||||
| """Construct a state-centered random-walk distribution.""" | ||||||
| if not isinstance(num_steps, int) or num_steps <= 0: | ||||||
| raise ValueError(f"num_steps must be a positive integer; got {num_steps!r}") | ||||||
| self.scale = scale | ||||||
|
cdc-mitzimorris marked this conversation as resolved.
Outdated
|
||||||
| self.initial_loc = initial_loc | ||||||
| self.num_steps = num_steps | ||||||
|
cdc-mitzimorris marked this conversation as resolved.
Outdated
|
||||||
|
|
||||||
| batch_shape = lax.broadcast_shapes( | ||||||
| jnp.shape(scale), | ||||||
| jnp.shape(initial_loc), | ||||||
| ) | ||||||
| super().__init__(batch_shape, (num_steps,), validate_args=validate_args) | ||||||
|
|
||||||
| def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike: | ||||||
| """ | ||||||
| Forward-sample a post-initial random-walk state path. | ||||||
|
|
||||||
| Returns | ||||||
| ------- | ||||||
| ArrayLike | ||||||
| Array of shape ``sample_shape + batch_shape + (num_steps,)``. | ||||||
| """ | ||||||
| assert is_prng_key(key) | ||||||
|
|
||||||
| per_step_shape = sample_shape + self.batch_shape | ||||||
| scale = jnp.broadcast_to(jnp.asarray(self.scale), per_step_shape) | ||||||
| initial_loc = jnp.broadcast_to(jnp.asarray(self.initial_loc), per_step_shape) | ||||||
| noise = random.normal(key, shape=per_step_shape + (self.num_steps,)) | ||||||
| increments = scale[..., jnp.newaxis] * noise | ||||||
| return initial_loc[..., jnp.newaxis] + jnp.cumsum(increments, axis=-1) | ||||||
|
|
||||||
| @validate_sample | ||||||
| def log_prob(self, value: ArrayLike) -> ArrayLike: | ||||||
| """ | ||||||
| Compute the log-density of an observed post-initial state path. | ||||||
|
|
||||||
| Parameters | ||||||
| ---------- | ||||||
| value | ||||||
| Post-initial path of shape | ||||||
| ``sample_shape + batch_shape + (num_steps,)``. | ||||||
|
|
||||||
| Returns | ||||||
| ------- | ||||||
| ArrayLike | ||||||
| Log-density of shape ``sample_shape + batch_shape``. | ||||||
| """ | ||||||
| scale = jnp.asarray(self.scale) | ||||||
| initial_loc = jnp.asarray(self.initial_loc) | ||||||
| init_with_event = jnp.expand_dims(initial_loc, -1) | ||||||
| init_bcast = jnp.broadcast_to(init_with_event, value.shape[:-1] + (1,)) | ||||||
| v = jnp.concatenate([init_bcast, value], axis=-1) | ||||||
| step_probs = Normal(v[..., :-1], jnp.expand_dims(scale, -1)).log_prob( | ||||||
| v[..., 1:] | ||||||
| ) | ||||||
| return jnp.sum(step_probs, axis=-1) | ||||||
|
|
||||||
|
|
||||||
| class StateAR1(Distribution): | ||||||
|
cdc-mitzimorris marked this conversation as resolved.
|
||||||
| r""" | ||||||
| State-centered AR(1) prior on a length-``num_steps`` state path. | ||||||
|
|
||||||
| Generative form: | ||||||
|
|
||||||
| $$ | ||||||
| x_0 \sim \mathrm{Normal}(\mu_0, \sigma_{\text{stat}}) | ||||||
| $$ | ||||||
| $$ | ||||||
| x_t \sim \mathrm{Normal}(\phi \, x_{t-1}, \sigma), \quad t = 1, \dots, T-1 | ||||||
| $$ | ||||||
|
|
||||||
| where $\sigma_{\text{stat}} = \sigma / \sqrt{1 - \phi^2}$ is the | ||||||
| stationary standard deviation, $\mu_0$ is ``initial_loc``, $\phi$ is | ||||||
| ``autoreg``, and $\sigma$ is ``scale``. | ||||||
|
|
||||||
| The sampled value is the full path $[x_0, x_1, \ldots, x_{T-1}]$. | ||||||
|
cdc-mitzimorris marked this conversation as resolved.
Outdated
|
||||||
|
|
||||||
| Parameters | ||||||
| ---------- | ||||||
| autoreg | ||||||
| AR(1) coefficient $\phi$. For stationarity, $|\phi| < 1$; this is | ||||||
| not enforced. | ||||||
| scale | ||||||
| Innovation standard deviation $\sigma$. Must be positive. | ||||||
| initial_loc | ||||||
| Prior mean $\mu_0$ of the initial state $x_0$. Defaults to ``0.0``. | ||||||
| num_steps | ||||||
| Length of the state path. Must be a positive integer. | ||||||
| validate_args | ||||||
| Forwarded to the base [`numpyro.distributions.Distribution`][]. | ||||||
|
cdc-mitzimorris marked this conversation as resolved.
Outdated
|
||||||
| """ | ||||||
|
|
||||||
| arg_constraints = { | ||||||
| "autoreg": constraints.real, | ||||||
| "scale": constraints.positive, | ||||||
| "initial_loc": constraints.real, | ||||||
| } | ||||||
| support = constraints.real_vector | ||||||
| reparametrized_params = ["autoreg", "scale", "initial_loc"] | ||||||
| pytree_aux_fields = ("num_steps",) | ||||||
|
|
||||||
| def __init__( | ||||||
| self, | ||||||
| autoreg: ArrayLike, | ||||||
| scale: ArrayLike, | ||||||
| initial_loc: ArrayLike = 0.0, | ||||||
| num_steps: int = 1, | ||||||
| *, | ||||||
| validate_args: bool | None = None, | ||||||
| ) -> None: | ||||||
| """ | ||||||
| Construct a state-centered AR(1) distribution. | ||||||
|
|
||||||
| Raises | ||||||
| ------ | ||||||
| ValueError | ||||||
| If ``num_steps`` is not a positive integer. | ||||||
| """ | ||||||
| if not isinstance(num_steps, int) or num_steps <= 0: | ||||||
| raise ValueError(f"num_steps must be a positive integer; got {num_steps!r}") | ||||||
| self.autoreg = autoreg | ||||||
| self.scale = scale | ||||||
| self.initial_loc = initial_loc | ||||||
| self.num_steps = num_steps | ||||||
|
|
||||||
| batch_shape = lax.broadcast_shapes( | ||||||
| jnp.shape(autoreg), | ||||||
| jnp.shape(scale), | ||||||
| jnp.shape(initial_loc), | ||||||
| ) | ||||||
| event_shape = (num_steps,) | ||||||
| super().__init__(batch_shape, event_shape, validate_args=validate_args) | ||||||
|
cdc-mitzimorris marked this conversation as resolved.
Outdated
|
||||||
|
|
||||||
| def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike: | ||||||
| """ | ||||||
| Forward-sample a state path. | ||||||
|
|
||||||
| Returns | ||||||
| ------- | ||||||
| ArrayLike | ||||||
| Array of shape ``sample_shape + batch_shape + (num_steps,)``. | ||||||
| """ | ||||||
| assert is_prng_key(key) | ||||||
|
|
||||||
| per_step_shape = sample_shape + self.batch_shape | ||||||
| autoreg = jnp.broadcast_to(jnp.asarray(self.autoreg), per_step_shape) | ||||||
| scale = jnp.broadcast_to(jnp.asarray(self.scale), per_step_shape) | ||||||
| initial_loc = jnp.broadcast_to(jnp.asarray(self.initial_loc), per_step_shape) | ||||||
|
cdc-mitzimorris marked this conversation as resolved.
|
||||||
| stationary_sd = scale / jnp.sqrt(1 - autoreg**2) | ||||||
|
|
||||||
| noise = random.normal(key, shape=(self.num_steps,) + per_step_shape) | ||||||
| z0 = noise[0] | ||||||
| x0 = initial_loc + stationary_sd * z0 | ||||||
|
|
||||||
| if self.num_steps == 1: | ||||||
| return x0[..., jnp.newaxis] | ||||||
|
|
||||||
| def step( | ||||||
| prev: ArrayLike, z_t: ArrayLike | ||||||
| ) -> tuple[ArrayLike, ArrayLike]: # numpydoc ignore=GL08 | ||||||
| new = autoreg * prev + scale * z_t | ||||||
| return new, new | ||||||
|
|
||||||
| _, xs = lax.scan(step, x0, noise[1:]) | ||||||
| path_time_first = jnp.concatenate([x0[jnp.newaxis], xs], axis=0) | ||||||
| return jnp.moveaxis(path_time_first, 0, -1) | ||||||
|
|
||||||
| @validate_sample | ||||||
| def log_prob(self, value: ArrayLike) -> ArrayLike: | ||||||
| """ | ||||||
| Compute the log-density of an observed state path. | ||||||
|
|
||||||
| Parameters | ||||||
| ---------- | ||||||
| value | ||||||
| State path of shape ``sample_shape + batch_shape + (num_steps,)``. | ||||||
|
|
||||||
| Returns | ||||||
| ------- | ||||||
| ArrayLike | ||||||
| Log-density of shape ``sample_shape + batch_shape``. | ||||||
| """ | ||||||
| scale = jnp.asarray(self.scale) | ||||||
| autoreg = jnp.asarray(self.autoreg) | ||||||
| stationary_sd = scale / jnp.sqrt(1 - autoreg**2) | ||||||
|
|
||||||
| init_prob = Normal(self.initial_loc, stationary_sd).log_prob(value[..., 0]) | ||||||
|
|
||||||
| scale_t = jnp.expand_dims(scale, -1) | ||||||
| autoreg_t = jnp.expand_dims(autoreg, -1) | ||||||
| step_locs = autoreg_t * value[..., :-1] | ||||||
| step_probs = Normal(step_locs, scale_t).log_prob(value[..., 1:]) | ||||||
| return init_prob + jnp.sum(step_probs, axis=-1) | ||||||
|
|
||||||
|
|
||||||
| class StateDifferencedAR1(Distribution): | ||||||
|
cdc-mitzimorris marked this conversation as resolved.
|
||||||
| r""" | ||||||
| State-centered differenced AR(1) prior on a length-``num_steps`` post-initial path. | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
|
||||||
| Generative form, given a deterministic initial state $x_0$ = ``initial_loc``: | ||||||
|
|
||||||
| $$ | ||||||
| x_1 \sim \mathrm{Normal}(x_0, \sigma_{\text{stat}}) | ||||||
| $$ | ||||||
| $$ | ||||||
| x_t \sim \mathrm{Normal}(x_{t-1} + \phi \, (x_{t-1} - x_{t-2}), \sigma), | ||||||
| \quad t \geq 2 | ||||||
| $$ | ||||||
|
cdc-mitzimorris marked this conversation as resolved.
Outdated
|
||||||
|
|
||||||
| where $\sigma_{\text{stat}} = \sigma / \sqrt{1 - \phi^2}$, $\phi$ is | ||||||
|
cdc-mitzimorris marked this conversation as resolved.
Outdated
|
||||||
| ``autoreg``, and $\sigma$ is ``scale``. | ||||||
|
|
||||||
| The sampled value is the post-initial path | ||||||
| $[x_1, x_2, \ldots, x_{\mathrm{num\_steps}}]$ of length ``num_steps``. | ||||||
| The initial state $x_0$ is not part of the sample; it is supplied as | ||||||
| ``initial_loc`` and used to score the first transition. | ||||||
|
|
||||||
| Parameters | ||||||
| ---------- | ||||||
| autoreg | ||||||
| AR(1) coefficient $\phi$ on first differences. For stationarity, | ||||||
| $|\phi| < 1$; this is not enforced. | ||||||
| scale | ||||||
| Innovation standard deviation $\sigma$. Must be positive. | ||||||
| initial_loc | ||||||
| Deterministic initial state $x_0$. Used to score the first | ||||||
| transition; not itself sampled. | ||||||
| num_steps | ||||||
| Length of the post-initial path. Must be a positive integer. | ||||||
| validate_args | ||||||
| Forwarded to the base [`numpyro.distributions.Distribution`][]. | ||||||
| """ | ||||||
|
|
||||||
| arg_constraints = { | ||||||
| "autoreg": constraints.real, | ||||||
| "scale": constraints.positive, | ||||||
| "initial_loc": constraints.real, | ||||||
| } | ||||||
| support = constraints.real_vector | ||||||
| reparametrized_params = ["autoreg", "scale", "initial_loc"] | ||||||
| pytree_aux_fields = ("num_steps",) | ||||||
|
|
||||||
| def __init__( | ||||||
| self, | ||||||
| autoreg: ArrayLike, | ||||||
| scale: ArrayLike, | ||||||
| initial_loc: ArrayLike = 0.0, | ||||||
| num_steps: int = 1, | ||||||
| *, | ||||||
| validate_args: bool | None = None, | ||||||
| ) -> None: | ||||||
| """ | ||||||
| Construct a state-centered differenced AR(1) distribution. | ||||||
|
|
||||||
| Raises | ||||||
| ------ | ||||||
| ValueError | ||||||
| If ``num_steps`` is not a positive integer. | ||||||
| """ | ||||||
| if not isinstance(num_steps, int) or num_steps <= 0: | ||||||
| raise ValueError(f"num_steps must be a positive integer; got {num_steps!r}") | ||||||
| self.autoreg = autoreg | ||||||
| self.scale = scale | ||||||
| self.initial_loc = initial_loc | ||||||
| self.num_steps = num_steps | ||||||
|
|
||||||
| batch_shape = lax.broadcast_shapes( | ||||||
| jnp.shape(autoreg), | ||||||
| jnp.shape(scale), | ||||||
| jnp.shape(initial_loc), | ||||||
| ) | ||||||
| event_shape = (num_steps,) | ||||||
| super().__init__(batch_shape, event_shape, validate_args=validate_args) | ||||||
|
|
||||||
| def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike: | ||||||
| """ | ||||||
| Forward-sample a post-initial path. | ||||||
|
|
||||||
| Returns | ||||||
| ------- | ||||||
| ArrayLike | ||||||
| Array of shape ``sample_shape + batch_shape + (num_steps,)``. | ||||||
| """ | ||||||
| assert is_prng_key(key) | ||||||
|
|
||||||
| per_step_shape = sample_shape + self.batch_shape | ||||||
| autoreg = jnp.broadcast_to(jnp.asarray(self.autoreg), per_step_shape) | ||||||
| scale = jnp.broadcast_to(jnp.asarray(self.scale), per_step_shape) | ||||||
| initial_loc = jnp.broadcast_to(jnp.asarray(self.initial_loc), per_step_shape) | ||||||
| stationary_sd = scale / jnp.sqrt(1 - autoreg**2) | ||||||
|
|
||||||
| noise = random.normal(key, shape=(self.num_steps,) + per_step_shape) | ||||||
| z1 = noise[0] | ||||||
| x1 = initial_loc + stationary_sd * z1 | ||||||
|
|
||||||
| if self.num_steps == 1: | ||||||
| return x1[..., jnp.newaxis] | ||||||
|
|
||||||
| def step( | ||||||
| carry: tuple[ArrayLike, ArrayLike], z_t: ArrayLike | ||||||
| ) -> tuple[tuple[ArrayLike, ArrayLike], ArrayLike]: # numpydoc ignore=GL08 | ||||||
| prev_2, prev_1 = carry | ||||||
| new = prev_1 + autoreg * (prev_1 - prev_2) + scale * z_t | ||||||
| return (prev_1, new), new | ||||||
|
|
||||||
| _, xs = lax.scan(step, (initial_loc, x1), noise[1:]) | ||||||
| path_time_first = jnp.concatenate([x1[jnp.newaxis], xs], axis=0) | ||||||
| return jnp.moveaxis(path_time_first, 0, -1) | ||||||
|
|
||||||
| @validate_sample | ||||||
| def log_prob(self, value: ArrayLike) -> ArrayLike: | ||||||
| """ | ||||||
| Compute the log-density of an observed post-initial path. | ||||||
|
|
||||||
| Parameters | ||||||
| ---------- | ||||||
| value | ||||||
| Post-initial path of shape | ||||||
| ``sample_shape + batch_shape + (num_steps,)``. | ||||||
|
|
||||||
| Returns | ||||||
| ------- | ||||||
| ArrayLike | ||||||
| Log-density of shape ``sample_shape + batch_shape``. | ||||||
| """ | ||||||
| scale = jnp.asarray(self.scale) | ||||||
| autoreg = jnp.asarray(self.autoreg) | ||||||
| initial_loc = jnp.asarray(self.initial_loc) | ||||||
| stationary_sd = scale / jnp.sqrt(1 - autoreg**2) | ||||||
|
|
||||||
| init_prob = Normal(initial_loc, stationary_sd).log_prob(value[..., 0]) | ||||||
|
|
||||||
| init_with_event = jnp.expand_dims(initial_loc, -1) | ||||||
| init_bcast = jnp.broadcast_to(init_with_event, value.shape[:-1] + (1,)) | ||||||
| v = jnp.concatenate([init_bcast, value], axis=-1) | ||||||
|
|
||||||
| prev_delta = v[..., 1:-1] - v[..., :-2] | ||||||
| scale_t = jnp.expand_dims(scale, -1) | ||||||
| autoreg_t = jnp.expand_dims(autoreg, -1) | ||||||
| means = v[..., 1:-1] + autoreg_t * prev_delta | ||||||
| step_probs = Normal(means, scale_t).log_prob(v[..., 2:]) | ||||||
| return init_prob + jnp.sum(step_probs, axis=-1) | ||||||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.