-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Feature/print value helper #8229
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: v6
Are you sure you want to change the base?
Changes from 3 commits
262927f
d7ca51b
ed50dd4
f2e9a34
78fdb03
5cfaa44
285114e
6a03ab9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| from pymc.printing import print_value | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,47 @@ | ||
| # pymc/printing.py | ||
|
|
||
| """Helper utilities for debugging PyMC models.""" | ||
|
|
||
| from pytensor.printing import Print | ||
|
|
||
| __all__ = ["print_value"] | ||
|
Comment on lines
+1
to
+7
|
||
|
|
||
|
|
||
| def print_value(var, name=None): | ||
| """Print the value of a variable each time it is computed during sampling. | ||
|
|
||
| This wraps the variable in a ``pytensor.printing.Print`` op, which is a | ||
| pass-through that prints the variable's value as a side effect whenever it | ||
| is evaluated. | ||
|
|
||
| .. warning:: | ||
| This may significantly affect sampling performance. Use only for | ||
| debugging purposes. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| var : TensorVariable | ||
| The PyMC variable to debug-print. | ||
| name : str, optional | ||
| Label shown in the printed output. Defaults to ``var.name``. | ||
|
|
||
| Returns | ||
| ------- | ||
| TensorVariable | ||
| The same variable, wrapped in a Print op (value is unchanged). | ||
|
|
||
| Examples | ||
| -------- | ||
| .. code-block:: python | ||
|
|
||
| import pymc as pm | ||
|
|
||
| with pm.Model(): | ||
| mu = pm.Normal("mu", mu=0, sigma=1) | ||
| mu = pm.print_value(mu, name="mu") # prints mu during sampling | ||
| obs = pm.Normal("obs", mu=mu, sigma=1, observed=[1, 2, 3]) | ||
| idata = pm.sample() | ||
| """ | ||
| if name is None: | ||
| name = getattr(var, "name", "value") | ||
| return Print(name)(var) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| # tests/test_printing.py | ||
krishnavallabha marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| import pymc as pm | ||
| import pytensor.tensor as pt | ||
| from pytensor.printing import Print | ||
krishnavallabha marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| from pymc.printing import print_value | ||
|
|
||
|
|
||
| def test_print_value_is_passthrough(): | ||
| """print_value should not change the variable's value.""" | ||
| x = pt.vector("x") | ||
| x_printed = print_value(x, name="test_x") | ||
|
|
||
| import pytensor | ||
| import numpy as np | ||
|
|
||
| f = pytensor.function([x], x_printed) | ||
| result = f([1.0, 2.0, 3.0]) | ||
| np.testing.assert_array_equal(result, [1.0, 2.0, 3.0]) | ||
|
|
||
|
|
||
| def test_print_value_default_name(): | ||
| """print_value should use var.name if no name is given.""" | ||
| x = pt.vector("my_var") | ||
| x_printed = print_value(x) | ||
| # The Print op's message should match the variable name | ||
| assert x_printed.owner.op.message == "my_var" | ||
|
|
||
|
|
||
| def test_print_value_custom_name(): | ||
| """print_value should use the custom name when provided.""" | ||
| x = pt.vector("x") | ||
| x_printed = print_value(x, name="custom_label") | ||
| assert x_printed.owner.op.message == "custom_label" | ||
|
|
||
|
|
||
| def test_print_value_accessible_from_pm(): | ||
| """print_value should be accessible as pm.print_value.""" | ||
| assert hasattr(pm, "print_value") | ||
Uh oh!
There was an error while loading. Please reload this page.