diff --git a/src/quax/_core.py b/src/quax/_core.py index 3562da5..9473166 100644 --- a/src/quax/_core.py +++ b/src/quax/_core.py @@ -102,6 +102,11 @@ def full_lower(self) -> Union[ArrayLike, "_QuaxTracer"]: else: return self + def to_concrete_value(self) -> ArrayLike | None: # pyright: ignore[reportIncompatibleMethodOverride] + if isinstance(self.value, _DenseArrayValue): + return core.to_concrete_value(self.value.array) + return None + def _default_process( primitive: jexc.Primitive, values: Sequence[Union[ArrayLike, "Value"]], params diff --git a/tests/unit/test_core.py b/tests/unit/test_core.py index f48c3c5..520efbb 100644 --- a/tests/unit/test_core.py +++ b/tests/unit/test_core.py @@ -117,3 +117,16 @@ def test_default_path(): got = quax.quaxify(lax.betainc)(jnp.array(1.0), x, y) assert jnp.array_equal(got, exp) + + +# See https://github.com/patrick-kidger/quax/issues/58 +def test_concrete_bool_conversion(): + """Test that quaxify doesn't break functions needing concrete boolean values.""" + xbool = jnp.array([True, False, True], dtype=bool) + x1 = jnp.array([1.0, 2.0, 3.0], dtype=float) + + compress = quax.quaxify(jnp.compress) + got = compress(xbool, x1) + exp = jnp.compress(xbool, x1) + + assert jnp.array_equal(got, exp) diff --git a/tests/unit/test_numpy/test_jax_array.py b/tests/unit/test_numpy/test_jax_array.py index 57b6124..e1e17b7 100644 --- a/tests/unit/test_numpy/test_jax_array.py +++ b/tests/unit/test_numpy/test_jax_array.py @@ -1,5 +1,7 @@ """Test with JAX inputs.""" +import equinox as eqx +import jax import jax.numpy as jnp import jax.tree as jtu import numpy as np @@ -49,7 +51,7 @@ ("argpartition", (x, 1), {}), ("argsort", (x,), {}), *( - pytest.param("argwhere", (x,), {}, marks=xfail_quax58), + ("argwhere", (x,), {}), ("argwhere", (x,), {"size": x.size}), # TODO: not need static size ), ("around", (x,), {}), @@ -70,9 +72,7 @@ ("average", (x,), {}), ("bartlett", (3,), {}), *( - pytest.param( - "bincount", (jnp.asarray([0, 1, 1, 2, 2, 2]),), {}, marks=xfail_quax58 - ), + ("bincount", (jnp.asarray([0, 1, 1, 2, 2, 2]),), {}), ("bincount", (jnp.asarray([0, 1, 1, 2, 2, 2]),), {"length": 3}), ), ("bitwise_and", (xbool, xbool), {}), @@ -93,14 +93,14 @@ ("can_cast", (x, int), {}), ("cbrt", (x,), {}), ("ceil", (x,), {}), - pytest.param("choose", (0, [x, x]), {}, marks=xfail_quax58), + ("choose", (0, [x, x]), {}), ("clip", (x, 1, 2), {}), ("column_stack", ([x, x],), {}), pytest.param("complex128", (1,), {}, marks=pytest.mark.xfail), ("complex64", (1,), {}), pytest.param("complex_", (1,), {}, marks=pytest.mark.xfail), pytest.param("complexfloating", (1,), {}, marks=pytest.mark.xfail), - pytest.param("compress", (xbool, x), {}, marks=xfail_quax58), + ("compress", (xbool, x), {}), ("concat", ([x, x],), {}), ("concatenate", ([x, x],), {}), ("conj", (x,), {}), @@ -150,7 +150,7 @@ ("fill_diagonal", (jnp.eye(3), 2), {"inplace": False}), ("fix", (x,), {}), *( - pytest.param("flatnonzero", (x,), {}, marks=xfail_quax58), + ("flatnonzero", (x,), {}), ("flatnonzero", (x,), {"size": x.size}), ), ("flip", (x,), {}), @@ -277,7 +277,7 @@ ("negative", (x,), {}), ("nextafter", (x, y), {}), *( - pytest.param("nonzero", (x,), {}, marks=xfail_quax58), + ("nonzero", (x,), {}), ("nonzero", (x,), {"size": x.size}), ), ("not_equal", (x, y), {}), @@ -310,12 +310,7 @@ ("rad2deg", (x,), {}), ("radians", (x,), {}), ("ravel", (x,), {}), - pytest.param( - "ravel_multi_index", - (jnp.array([[0, 1], [0, 1]]), (2, 2)), - {}, - marks=xfail_quax58, - ), + ("ravel_multi_index", (jnp.array([[0, 1], [0, 1]]), (2, 2)), {}), ("real", (x,), {}), ("reciprocal", (x,), {}), ("remainder", (x, y), {}), @@ -327,7 +322,7 @@ ("rint", (x,), {}), ("roll", (x, 4), {}), ("rollaxis", (x, -1), {}), - pytest.param("roots", (x[:, 0],), {}, marks=xfail_quax58), + ("roots", (x[:, 0],), {}), ("rot90", (x,), {}), ("round", (x,), {}), # pytest.param("round_", (x,), {}, marks=pytest.mark.deprecated), @@ -371,7 +366,7 @@ ("transpose", (x,), {}), ("tril", (jnp.eye(4),), {}), ("tril_indices_from", (jnp.eye(4),), {}), - pytest.param( + ( "trim_zeros", ( jnp.concatenate( @@ -379,7 +374,6 @@ ), ), {}, - marks=xfail_quax58, ), ("triu", (x,), {}), ("triu_indices_from", (x,), {}), @@ -387,27 +381,27 @@ ("trunc", (x,), {}), pytest.param("ufunc", (x,), {}, marks=mark_todo), *( - pytest.param("union1d", (x[:, 0], y[:, 0]), {}, marks=xfail_quax58), + ("union1d", (x[:, 0], y[:, 0]), {}), ("union1d", (x[:, 0], y[:, 0]), {"size": x[:, 0].size}), ), *( - pytest.param("unique", (x,), {}, marks=xfail_quax58), + ("unique", (x,), {}), ("unique", (x,), {"size": x.size}), ), *( - pytest.param("unique_all", (x,), {}, marks=xfail_quax58), + ("unique_all", (x,), {}), ("unique_all", (x,), {"size": x.size}), ), *( - pytest.param("unique_counts", (x,), {}, marks=xfail_quax58), + ("unique_counts", (x,), {}), ("unique_counts", (x,), {"size": x.size}), ), *( - pytest.param("unique_inverse", (x,), {}, marks=xfail_quax58), + ("unique_inverse", (x,), {}), ("unique_inverse", (x,), {"size": x.size}), ), *( - pytest.param("unique_values", (x,), {}, marks=xfail_quax58), + ("unique_values", (x,), {}), ("unique_values", (x,), {"size": x.size}), ), pytest.param( @@ -430,8 +424,9 @@ ], ) def test_numpy_functions(func_name, args, kw): - """Test lax vs qlax functions.""" + """Test numpy vs qnumpy functions.""" func = getattr(jnp, func_name) + # Jax exp = func(*args, **kw) exp = exp if isinstance(exp, tuple | list) else (exp,) @@ -442,6 +437,18 @@ def test_numpy_functions(func_name, args, kw): assert jtu.all(jtu.map(jnp.allclose, got, exp)) + # Check that where JIT'ed JAX is supported, Quaxed + JIT is too. + try: + _ = eqx.filter_jit(func)(*args, **kw) + except (jax.errors.ConcretizationTypeError, ValueError, TypeError): + pass + else: + # Quaxed + JIT + got_jit = eqx.filter_jit(quaxify(func))(*args, **kw) + got_jit = got_jit if isinstance(got_jit, tuple | list) else (got_jit,) + + assert jtu.all(jtu.map(jnp.allclose, got_jit, exp)) + ############################################################################### diff --git a/tests/unit/test_numpy/test_myarray.py b/tests/unit/test_numpy/test_myarray.py index 931d05e..bc67179 100644 --- a/tests/unit/test_numpy/test_myarray.py +++ b/tests/unit/test_numpy/test_myarray.py @@ -1,5 +1,7 @@ """Test with JAX inputs.""" +import equinox as eqx +import jax import jax.numpy as jnp import jax.random as jr import jax.tree as jtu @@ -16,6 +18,7 @@ ) mark_todo = pytest.mark.skip("TODO") mark_nomd = pytest.mark.xfail(reason="Can't be supported with MD on primitives") +mark_tracerleak = pytest.mark.xfail(reason="Tracers are leaking") x = MyArray(jnp.array([[1, 2], [3, 4]], dtype=float)) y = MyArray(jnp.array([[5, 6], [7, 8]], dtype=float)) @@ -111,7 +114,7 @@ ("broadcast_to", (x, (2, 2)), {}, True), ("cbrt", (x,), {}, True), ("ceil", (x,), {}, True), - pytest.param("choose", (0, [x, x]), {}, True, marks=xfail_quax58), + ("choose", (0, [x, x]), {}, True), ("clip", (x, 1, 2), {}, True), ("column_stack", ([x, x],), {}, True), pytest.param("complex128", (x,), {}, True, marks=pytest.mark.xfail), @@ -481,8 +484,9 @@ ) def test_numpy_functions(func_name, args, kw, expect_myarray): """Test lax vs qlax functions.""" - # Jax func = getattr(jnp, func_name) + + # Jax jax_args, jax_kw = jtu.map(unwrap, (args, kw), is_leaf=is_myarray) exp = func(*jax_args, **jax_kw) exp = exp if isinstance(exp, tuple | list) else (exp,) @@ -494,6 +498,18 @@ def test_numpy_functions(func_name, args, kw, expect_myarray): assert jtu.all(jtu.map(jnp.allclose, got, exp)) + # Check that where JIT'ed JAX is supported, Quaxed + JIT is too. + try: + _ = eqx.filter_jit(func)(*args, **kw) + except (jax.errors.ConcretizationTypeError, ValueError, TypeError): + pass + else: + # Quaxed + JIT + got_jit = eqx.filter_jit(quax.quaxify(func))(*args, **kw) + got_jit = got_jit if isinstance(got_jit, tuple | list) else (got_jit,) + + assert jtu.all(jtu.map(jnp.allclose, got_jit, exp)) + # ###############################################################################