Skip to content

Fix _QuaxTracer concrete bool conversion for plain JAX arrays#110

Open
Copilot wants to merge 4 commits intomainfrom
copilot/fix-quaxify-call-issue
Open

Fix _QuaxTracer concrete bool conversion for plain JAX arrays#110
Copilot wants to merge 4 commits intomainfrom
copilot/fix-quaxify-call-issue

Conversation

Copy link
Copy Markdown
Contributor

Copilot AI commented Apr 9, 2026

_Quaxify.__call__ wraps all dynamic JAX arrays as _QuaxTracer objects inside a _QuaxTrace context. When the wrapped function internally requires Python-level boolean evaluation of a concrete array (e.g. jnp.compress doing if reductions.any(extra):), it fails with TracerBoolConversionError because _QuaxTracer never overrode to_concrete_value(), so is_concrete() always returned False.

# Previously raised TracerBoolConversionError
compress = quax.quaxify(jnp.compress)
compress(jnp.array([True, False, True]), jnp.array([1., 2., 3.]))
# Now works correctly → Array([1., 3.], dtype=float32)

Changes

  • src/quax/_core.py — Override to_concrete_value() on _QuaxTracer: when the wrapped value is a _DenseArrayValue, delegate to core.to_concrete_value(self.value.array). This allows __bool__ (and __array__) to resolve concretely for plain JAX arrays passing through quaxify, while correctly returning None under JIT.

  • tests/unit/test_core.py — Add test_concrete_bool_conversion reproducing the issue MWE.

  • tests/unit/test_numpy/test_jax_array.py — Remove xfail_quax58 from the 15 functions that now pass without a size argument: argwhere, bincount, choose, compress, flatnonzero, nonzero, ravel_multi_index, roots, trim_zeros, union1d, unique* (5 variants). Keep xfail_quax58 on extract, intersect1d, setdiff1d, setxor1d (still fail with TracerArrayConversionError).

  • tests/unit/test_numpy/test_myarray.py — Remove xfail_quax58 from choose (only case that passes; all others still fail because the condition is a custom Value subclass, not a _DenseArrayValue).

Copilot AI linked an issue Apr 9, 2026 that may be closed by this pull request
Copilot AI changed the title [WIP] Fix problem in _Quaxify.__call__ method Fix _QuaxTracer concrete bool conversion for plain JAX arrays Apr 9, 2026
Copilot AI requested a review from nstarman April 9, 2026 13:25
@nstarman nstarman marked this pull request as ready for review April 9, 2026 15:28
Copilot AI review requested due to automatic review settings April 9, 2026 15:28
@nstarman
Copy link
Copy Markdown
Owner

nstarman commented Apr 9, 2026

Excellent. A simple solution. Most tests are passing. Not all, so let's not close #58 quite yet.

Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR fixes a TracerBoolConversionError when quaxify-wrapped functions attempt Python boolean evaluation on concrete (eager) JAX arrays by making _QuaxTracer able to produce a concrete value for _DenseArrayValue-backed arrays.

Changes:

  • Override _QuaxTracer.to_concrete_value() to delegate concretization for _DenseArrayValue payloads.
  • Add a regression test covering concrete boolean conversion via jnp.compress.
  • Remove xfail_quax58 from multiple NumPy-on-JAX tests that now pass; keep xfails for cases still failing with array conversion errors.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.

File Description
src/quax/_core.py Adds _QuaxTracer.to_concrete_value() override to enable concretization for plain JAX arrays.
tests/unit/test_core.py Adds regression test reproducing the concrete-bool conversion failure (issue #58).
tests/unit/test_numpy/test_jax_array.py Un-xfails a set of jnp.* functions now passing without static size.
tests/unit/test_numpy/test_myarray.py Un-xfails choose for MyArray where it now passes.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread src/quax/_core.py Outdated
Comment thread src/quax/_core.py
nstarman added 2 commits April 9, 2026 11:42
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
@nstarman nstarman force-pushed the copilot/fix-quaxify-call-issue branch from 90c9a35 to 60ff43d Compare April 9, 2026 15:47
@nstarman nstarman added this to the v0.3.x milestone Apr 9, 2026
@nstarman
Copy link
Copy Markdown
Owner

nstarman commented Apr 9, 2026

@meeseeksdev help

@nstarman
Copy link
Copy Markdown
Owner

nstarman commented Apr 9, 2026

@meeseeksdev hello

@lumberbot-app
Copy link
Copy Markdown

lumberbot-app bot commented Apr 9, 2026

Helloooo @nstarman, I'm Mr. Meeseeks! Look at me!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants