Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/quax/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ def full_lower(self) -> Union[ArrayLike, "_QuaxTracer"]:
else:
return self

def to_concrete_value(self) -> ArrayLike | None: # pyright: ignore[reportIncompatibleMethodOverride]
Comment thread
nstarman marked this conversation as resolved.
if isinstance(self.value, _DenseArrayValue):
return core.to_concrete_value(self.value.array)
return None


def _default_process(
primitive: jexc.Primitive, values: Sequence[Union[ArrayLike, "Value"]], params
Expand Down
13 changes: 13 additions & 0 deletions tests/unit/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,16 @@ def test_default_path():
got = quax.quaxify(lax.betainc)(jnp.array(1.0), x, y)

assert jnp.array_equal(got, exp)


# See https://github.com/patrick-kidger/quax/issues/58
def test_concrete_bool_conversion():
"""Test that quaxify doesn't break functions needing concrete boolean values."""
xbool = jnp.array([True, False, True], dtype=bool)
x1 = jnp.array([1.0, 2.0, 3.0], dtype=float)

compress = quax.quaxify(jnp.compress)
got = compress(xbool, x1)
exp = jnp.compress(xbool, x1)

assert jnp.array_equal(got, exp)
55 changes: 31 additions & 24 deletions tests/unit/test_numpy/test_jax_array.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Test with JAX inputs."""

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree as jtu
import numpy as np
Expand Down Expand Up @@ -49,7 +51,7 @@
("argpartition", (x, 1), {}),
("argsort", (x,), {}),
*(
pytest.param("argwhere", (x,), {}, marks=xfail_quax58),
("argwhere", (x,), {}),
("argwhere", (x,), {"size": x.size}), # TODO: not need static size
),
("around", (x,), {}),
Expand All @@ -70,9 +72,7 @@
("average", (x,), {}),
("bartlett", (3,), {}),
*(
pytest.param(
"bincount", (jnp.asarray([0, 1, 1, 2, 2, 2]),), {}, marks=xfail_quax58
),
("bincount", (jnp.asarray([0, 1, 1, 2, 2, 2]),), {}),
("bincount", (jnp.asarray([0, 1, 1, 2, 2, 2]),), {"length": 3}),
),
("bitwise_and", (xbool, xbool), {}),
Expand All @@ -93,14 +93,14 @@
("can_cast", (x, int), {}),
("cbrt", (x,), {}),
("ceil", (x,), {}),
pytest.param("choose", (0, [x, x]), {}, marks=xfail_quax58),
("choose", (0, [x, x]), {}),
("clip", (x, 1, 2), {}),
("column_stack", ([x, x],), {}),
pytest.param("complex128", (1,), {}, marks=pytest.mark.xfail),
("complex64", (1,), {}),
pytest.param("complex_", (1,), {}, marks=pytest.mark.xfail),
pytest.param("complexfloating", (1,), {}, marks=pytest.mark.xfail),
pytest.param("compress", (xbool, x), {}, marks=xfail_quax58),
("compress", (xbool, x), {}),
("concat", ([x, x],), {}),
("concatenate", ([x, x],), {}),
("conj", (x,), {}),
Expand Down Expand Up @@ -150,7 +150,7 @@
("fill_diagonal", (jnp.eye(3), 2), {"inplace": False}),
("fix", (x,), {}),
*(
pytest.param("flatnonzero", (x,), {}, marks=xfail_quax58),
("flatnonzero", (x,), {}),
("flatnonzero", (x,), {"size": x.size}),
),
("flip", (x,), {}),
Expand Down Expand Up @@ -277,7 +277,7 @@
("negative", (x,), {}),
("nextafter", (x, y), {}),
*(
pytest.param("nonzero", (x,), {}, marks=xfail_quax58),
("nonzero", (x,), {}),
("nonzero", (x,), {"size": x.size}),
),
("not_equal", (x, y), {}),
Expand Down Expand Up @@ -310,12 +310,7 @@
("rad2deg", (x,), {}),
("radians", (x,), {}),
("ravel", (x,), {}),
pytest.param(
"ravel_multi_index",
(jnp.array([[0, 1], [0, 1]]), (2, 2)),
{},
marks=xfail_quax58,
),
("ravel_multi_index", (jnp.array([[0, 1], [0, 1]]), (2, 2)), {}),
("real", (x,), {}),
("reciprocal", (x,), {}),
("remainder", (x, y), {}),
Expand All @@ -327,7 +322,7 @@
("rint", (x,), {}),
("roll", (x, 4), {}),
("rollaxis", (x, -1), {}),
pytest.param("roots", (x[:, 0],), {}, marks=xfail_quax58),
("roots", (x[:, 0],), {}),
("rot90", (x,), {}),
("round", (x,), {}),
# pytest.param("round_", (x,), {}, marks=pytest.mark.deprecated),
Expand Down Expand Up @@ -371,43 +366,42 @@
("transpose", (x,), {}),
("tril", (jnp.eye(4),), {}),
("tril_indices_from", (jnp.eye(4),), {}),
pytest.param(
(
"trim_zeros",
(
jnp.concatenate(
(np.array([0.0, 0, 0]), x[:, 0], jnp.array([0.0, 0, 0]))
),
),
{},
marks=xfail_quax58,
),
("triu", (x,), {}),
("triu_indices_from", (x,), {}),
("true_divide", (x, y), {}),
("trunc", (x,), {}),
pytest.param("ufunc", (x,), {}, marks=mark_todo),
*(
pytest.param("union1d", (x[:, 0], y[:, 0]), {}, marks=xfail_quax58),
("union1d", (x[:, 0], y[:, 0]), {}),
("union1d", (x[:, 0], y[:, 0]), {"size": x[:, 0].size}),
),
*(
pytest.param("unique", (x,), {}, marks=xfail_quax58),
("unique", (x,), {}),
("unique", (x,), {"size": x.size}),
),
*(
pytest.param("unique_all", (x,), {}, marks=xfail_quax58),
("unique_all", (x,), {}),
("unique_all", (x,), {"size": x.size}),
),
*(
pytest.param("unique_counts", (x,), {}, marks=xfail_quax58),
("unique_counts", (x,), {}),
("unique_counts", (x,), {"size": x.size}),
),
*(
pytest.param("unique_inverse", (x,), {}, marks=xfail_quax58),
("unique_inverse", (x,), {}),
("unique_inverse", (x,), {"size": x.size}),
),
*(
pytest.param("unique_values", (x,), {}, marks=xfail_quax58),
("unique_values", (x,), {}),
("unique_values", (x,), {"size": x.size}),
),
pytest.param(
Expand All @@ -430,8 +424,9 @@
],
)
def test_numpy_functions(func_name, args, kw):
"""Test lax vs qlax functions."""
"""Test numpy vs qnumpy functions."""
func = getattr(jnp, func_name)

# Jax
exp = func(*args, **kw)
exp = exp if isinstance(exp, tuple | list) else (exp,)
Expand All @@ -442,6 +437,18 @@ def test_numpy_functions(func_name, args, kw):

assert jtu.all(jtu.map(jnp.allclose, got, exp))

# Check that where JIT'ed JAX is supported, Quaxed + JIT is too.
try:
_ = eqx.filter_jit(func)(*args, **kw)
except (jax.errors.ConcretizationTypeError, ValueError, TypeError):
pass
else:
# Quaxed + JIT
got_jit = eqx.filter_jit(quaxify(func))(*args, **kw)
got_jit = got_jit if isinstance(got_jit, tuple | list) else (got_jit,)

assert jtu.all(jtu.map(jnp.allclose, got_jit, exp))


###############################################################################

Expand Down
20 changes: 18 additions & 2 deletions tests/unit/test_numpy/test_myarray.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Test with JAX inputs."""

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree as jtu
Expand All @@ -16,6 +18,7 @@
)
mark_todo = pytest.mark.skip("TODO")
mark_nomd = pytest.mark.xfail(reason="Can't be supported with MD on primitives")
mark_tracerleak = pytest.mark.xfail(reason="Tracers are leaking")

x = MyArray(jnp.array([[1, 2], [3, 4]], dtype=float))
y = MyArray(jnp.array([[5, 6], [7, 8]], dtype=float))
Expand Down Expand Up @@ -111,7 +114,7 @@
("broadcast_to", (x, (2, 2)), {}, True),
("cbrt", (x,), {}, True),
("ceil", (x,), {}, True),
pytest.param("choose", (0, [x, x]), {}, True, marks=xfail_quax58),
("choose", (0, [x, x]), {}, True),
("clip", (x, 1, 2), {}, True),
("column_stack", ([x, x],), {}, True),
pytest.param("complex128", (x,), {}, True, marks=pytest.mark.xfail),
Expand Down Expand Up @@ -481,8 +484,9 @@
)
def test_numpy_functions(func_name, args, kw, expect_myarray):
"""Test lax vs qlax functions."""
# Jax
func = getattr(jnp, func_name)

# Jax
jax_args, jax_kw = jtu.map(unwrap, (args, kw), is_leaf=is_myarray)
exp = func(*jax_args, **jax_kw)
exp = exp if isinstance(exp, tuple | list) else (exp,)
Expand All @@ -494,6 +498,18 @@ def test_numpy_functions(func_name, args, kw, expect_myarray):

assert jtu.all(jtu.map(jnp.allclose, got, exp))

# Check that where JIT'ed JAX is supported, Quaxed + JIT is too.
try:
_ = eqx.filter_jit(func)(*args, **kw)
except (jax.errors.ConcretizationTypeError, ValueError, TypeError):
pass
else:
# Quaxed + JIT
got_jit = eqx.filter_jit(quax.quaxify(func))(*args, **kw)
got_jit = got_jit if isinstance(got_jit, tuple | list) else (got_jit,)

assert jtu.all(jtu.map(jnp.allclose, got_jit, exp))


# ###############################################################################

Expand Down
Loading