Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.11", "3.12", "3.13"]
python-version: ["3.12", "3.13", "3.14"]
runs-on: [ubuntu-latest, macos-latest, windows-latest]

steps:
Expand All @@ -50,7 +50,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.11"]
python-version: ["3.12"]
runs-on: [ubuntu-latest]

steps:
Expand Down
2 changes: 1 addition & 1 deletion .python-version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.11
3.12
22 changes: 15 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "quaxed"
dynamic = ["version"]
description = "Pre-quaxed libraries for multiple dispatch over abstract array types in JAX"
readme = "README.md"
requires-python = ">=3.11"
requires-python = ">=3.12"
authors = [
{ name = "Nathaniel Starkman", email = "nstarman@users.noreply.github.com" },
]
Expand All @@ -16,17 +16,17 @@ classifiers = [
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Topic :: Scientific/Engineering",
"Typing :: Typed",
]
dependencies = [
"equinox>=0.13.2",
"jax>=0.5.3",
"jaxtyping>=0.3.1",
"optype>=0.8.0",
"optype>=0.14.0",
"plum-dispatch>=2.5.2",
"quax>=0.2.1",
]
Expand Down Expand Up @@ -76,7 +76,7 @@ nox = [
test = [
"optional-dependencies>=0.4.0",
"pytest>=8.3",
"pytest-cov>= 6.2.1",
"pytest-cov>=6.2.1",
"pytest-env>=1.1.5",
"pytest-github-actions-annotate-failures>=0.3.0", # only applies to GHA
"sybil[pytest]>=9.2.0",
Expand Down Expand Up @@ -110,7 +110,7 @@ port.exclude_lines = [

[tool.mypy]
files = ["src"]
python_version = "3.11"
python_version = "3.12"
warn_unused_configs = true
strict = true
enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
Expand All @@ -130,7 +130,7 @@ warn_return_any = false


[tool.pylint]
py-version = "3.11"
py-version = "3.12"
ignore-paths = [".*/_version.py"]
reports.output-format = "colorized"
similarities.ignore-imports = "yes"
Expand Down Expand Up @@ -208,14 +208,22 @@ convention = "numpy"
[tool.uv]
constraint-dependencies = [
"appnope>=0.1.2",
"backcall>=0.2.0",
"bleach>6.0",
"cffi>=1.14",
"decorator>=5.1.1",
"future>=1.0.0",
"iniconfig>=2.0.0",
"matplotlib>=3.7.1",
"matplotlib-inline>=0.1.6",
"nest-asyncio>=1.5.0",
"opt-einsum>=3.2.1",
"pickleshare>=0.7.5",
"ply>3.11",
"psutil>=5.9.0",
"pycparser>=2.20",
"pyparsing>=3.0.0",
"pyzmq>=25.0",
"pyzmq>=26.0",
"scipy>=1.11.2",
"wcwidth>=0.2.0"
]
4 changes: 2 additions & 2 deletions src/quaxed/_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
)

from collections.abc import Callable, Hashable
from typing import Any, TypeAlias
from typing import Any

import jax
from quax import quaxify

AxisName: TypeAlias = Hashable
type AxisName = Hashable


# =============================================================================
Expand Down
6 changes: 1 addition & 5 deletions src/quaxed/_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
"""Utility functions for quaxed."""

from typing import TypeVar

import quax

T = TypeVar("T")


def quaxify(func: T, *, filter_spec: bool | tuple[bool, ...] = True) -> T:
def quaxify[T](func: T, *, filter_spec: bool | tuple[bool, ...] = True) -> T:
"""Quaxify, but makes mypy happy."""
return quax.quaxify(func, filter_spec=filter_spec)
6 changes: 2 additions & 4 deletions src/quaxed/numpy/_higher_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import functools
import warnings
from collections.abc import Callable, Collection
from typing import Any, TypeVar
from typing import Any

import equinox as eqx
import jax
Expand All @@ -19,10 +19,8 @@

from ._core import asarray, expand_dims as _expand_dims, squeeze

T = TypeVar("T")


def expand_dims(a: T, axis: int | tuple[int, ...]) -> T:
def expand_dims[T](a: T, axis: int | tuple[int, ...]) -> T:
dynamic, static = eqx.partition(a, eqx.is_array_like)
expanded_dynamic = jax.tree.map(lambda x: _expand_dims(x, axis), dynamic)
return eqx.combine(expanded_dynamic, static)
Expand Down
11 changes: 6 additions & 5 deletions tests/myarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from collections.abc import Sequence
from dataclasses import replace
from typing import Any, Self
from typing import Any, Final, Self

import equinox as eqx
import jax
Expand All @@ -18,6 +18,7 @@
from quaxed._types import DType

JAX_VERSION = packaging.version.parse(jax.__version__)
JAX_VERSION_LT_8: Final = packaging.version.Version("0.8.0") > JAX_VERSION


class MyArray(ArrayValue):
Expand Down Expand Up @@ -1163,8 +1164,8 @@ def reduce_prod_p(x: MyArray, /, **kw) -> MyArray:


@register(lax.reduce_sum_p)
def reduce_sum_p(x: MyArray, *, axes: tuple[int, ...]) -> MyArray:
return replace(x, array=lax.reduce_sum_p.bind(x.array, axes=axes))
def reduce_sum_p(x: MyArray, **kw) -> MyArray:
return replace(x, array=lax.reduce_sum_p.bind(x.array, **kw))


# ==============================================================================
Expand Down Expand Up @@ -1709,8 +1710,8 @@ def tanh_p(x: MyArray, /, **kw: Any) -> MyArray:


@register(lax.top_k_p)
def top_k_p(operand: MyArray, k: int = 0) -> MyArray:
return [MyArray(x) for x in lax.top_k(operand.array, k)]
def top_k_p(operand: MyArray, /, k: int = 0, **kw: Any) -> list[MyArray]:
return [MyArray(x) for x in lax.top_k(operand.array, k, **kw)]


# ==============================================================================
Expand Down
4 changes: 1 addition & 3 deletions tests/test_lax/test_myarray.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Test with JAX inputs."""

from typing import TypeAlias

import jax.numpy as jnp
import jax.tree as jtu
import pytest
Expand All @@ -13,7 +11,7 @@
from ..conftest import OptDeps
from ..myarray import MyArray

AnyTuple: TypeAlias = tuple[object, ...]
type AnyTuple = tuple[object, ...]

mark_todo = pytest.mark.skip(reason="TODO")
mark_deprecated_jax7 = (
Expand Down
Loading
Loading