From b5a14729dee2ec0d6577782781d6bc9eac0dc945 Mon Sep 17 00:00:00 2001 From: nstarman Date: Sun, 23 Nov 2025 18:49:19 -0500 Subject: [PATCH 1/6] test: py3.14 Signed-off-by: nstarman --- .github/workflows/ci.yml | 11 +++++++---- .pre-commit-config.yaml | 3 --- pyproject.toml | 3 ++- uv.lock | 15 ++++++++++----- 4 files changed, 19 insertions(+), 13 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 36c78642..d18a0ba5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,15 +25,18 @@ jobs: - name: "3.12" python-version: "3.12" extra-install: "" - - name: "3.12-pre-beartype" - python-version: "3.12" - extra-install: "uv pip install --upgrade --pre beartype" - name: "3.13" python-version: "3.13" extra-install: "" - name: "3.13-pre-beartype" python-version: "3.13" - extra-install: "uv pip install --upgrade --pre beartype" + extra-install: "pip install --upgrade --pre beartype" + - name: "3.14" + python-version: "3.14" + extra-install: "" + - name: "3.14-pre-beartype" + python-version: "3.14" + extra-install: "pip install --upgrade --pre beartype" name: Test ${{ matrix.value.name }} steps: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 022ad491..8dc74792 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,9 +2,6 @@ ci: autoupdate_commit_msg: "chore: update pre-commit hooks" autofix_commit_msg: "style: pre-commit fixes" -default_language_version: - python: "3.10" - repos: - repo: meta hooks: diff --git a/pyproject.toml b/pyproject.toml index 5605da2e..53277c56 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,8 @@ dynamic = ["version"] requires-python = ">=3.10" dependencies = [ - "beartype>=0.16.2", + "beartype>=0.22.2; python_version>='3.14'", + "beartype>=0.16.2; python_version<'3.14'", "typing-extensions>=4.9.0", "rich>=10.0" ] diff --git a/uv.lock b/uv.lock index 18b2d264..61554470 100644 --- a/uv.lock +++ b/uv.lock @@ -2,7 +2,8 @@ version = 1 revision = 3 requires-python = ">=3.10" resolution-markers = [ - "python_full_version >= '3.12'", + "python_full_version >= '3.14'", + "python_full_version >= '3.12' and python_full_version < '3.14'", "python_full_version == '3.11.*'", "python_full_version < '3.11'", ] @@ -486,7 +487,8 @@ name = "cibuildwheel" version = "3.3.1" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.12'", + "python_full_version >= '3.14'", + "python_full_version >= '3.12' and python_full_version < '3.14'", "python_full_version == '3.11.*'", ] dependencies = [ @@ -1128,7 +1130,8 @@ name = "ipython" version = "9.8.0" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.12'", + "python_full_version >= '3.14'", + "python_full_version >= '3.12' and python_full_version < '3.14'", "python_full_version == '3.11.*'", ] dependencies = [ @@ -1846,7 +1849,8 @@ name = "numpy" version = "2.3.5" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.12'", + "python_full_version >= '3.14'", + "python_full_version >= '3.12' and python_full_version < '3.14'", "python_full_version == '3.11.*'", ] sdist = { url = "https://files.pythonhosted.org/packages/76/65/21b3bc86aac7b8f2862db1e808f1ea22b028e30a225a34a5ede9bf8678f2/numpy-2.3.5.tar.gz", hash = "sha256:784db1dcdab56bf0517743e746dfb0f885fc68d948aba86eeec2cba234bdf1c0", size = 20584950, upload-time = "2025-11-16T22:52:42.067Z" } @@ -2065,7 +2069,8 @@ wheels = [ [package.metadata] requires-dist = [ - { name = "beartype", specifier = ">=0.16.2" }, + { name = "beartype", marker = "python_full_version < '3.14'", specifier = ">=0.16.2" }, + { name = "beartype", marker = "python_full_version >= '3.14'", specifier = ">=0.22.2" }, { name = "rich", specifier = ">=10.0" }, { name = "typing-extensions", specifier = ">=4.9.0" }, ] From 0b8b5f7b28fbe22e909b940136f1344d1eede5c8 Mon Sep 17 00:00:00 2001 From: nstarman Date: Thu, 12 Feb 2026 20:28:18 -0500 Subject: [PATCH 2/6] feat: deprecate union aliasing Signed-off-by: nstarman --- docs/comparison.md | 23 ++ docs/union_aliases.md | 48 ++- src/plum/_alias.py | 310 +++++++++++------- .../{test_alias.py => test_alias_upto313.py} | 16 +- tests/test_util.py | 5 +- 5 files changed, 269 insertions(+), 133 deletions(-) rename tests/{test_alias.py => test_alias_upto313.py} (89%) diff --git a/docs/comparison.md b/docs/comparison.md index a642a4a0..34bbe565 100644 --- a/docs/comparison.md +++ b/docs/comparison.md @@ -85,6 +85,27 @@ def f(x: int, y: Number): return "second" ``` +% invisible-code-block: python +% +% import sys + +% skip: start if(sys.version_info < (3, 14), reason="Union repr changed in Python 3.14+") + +```python +>>> try: f(1, 1) +... except Exception as e: print(f"{type(e).__name__}: {e}") +AmbiguousLookupError: `f(1, 1)` is ambiguous. +Candidates: + f(x: int | numbers.Number, y: int) + @ ... + f(x: int, y: numbers.Number) + @ ... +``` + +% skip: end + +% skip: start if(sys.version_info >= (3, 14), reason="Union repr changed in Python 3.14+") + ```python >>> try: f(1, 1) ... except Exception as e: print(f"{type(e).__name__}: {e}") @@ -96,6 +117,8 @@ Candidates: @ ... ``` +% skip: end + Just to sanity check that things are indeed working correctly: ```python diff --git a/docs/union_aliases.md b/docs/union_aliases.md index f02c08c5..c81e25d9 100644 --- a/docs/union_aliases.md +++ b/docs/union_aliases.md @@ -41,6 +41,22 @@ by aliasing a union, you change the way it is displayed. Union aliases must be activated explicitly, because the feature monkeypatches `Union.__str__` and `Union.__repr__`. +% invisible-code-block: python +% +% import sys + +% skip: start if(sys.version_info < (3, 14), reason="Union repr changed in Python 3.14+") + +```python +>>> from plum import set_union_alias +>>> set_union_alias(Scalar, alias="Scalar") +numpy.bool | numpy.float16 | ... +``` + +% skip: end + +% skip: start if(sys.version_info >= (3, 14), reason="Union repr changed in Python 3.14+") + ```python >>> from plum import activate_union_aliases, set_union_alias @@ -50,6 +66,8 @@ monkeypatches `Union.__str__` and `Union.__repr__`. typing.Union[Scalar] ``` +% skip: end + After this, `help(add)` now prints the following: % skip: next "Example" @@ -68,6 +86,30 @@ For example, printing just `Scalar` would omit the type parameter(s). Let's see with a few more examples how this works: +% invisible-code-block: python +% +% import sys + +% skip: start if(sys.version_info < (3, 14), reason="Union repr changed in Python 3.14+") + +```python +>>> Scalar +numpy.bool | numpy.float16 | ... + +>>> Union[tuple(scalar_types)] +numpy.bool | numpy.float16 | ... + +>>> Union[tuple(scalar_types) + (tuple,)] # Scalar or tuple +numpy.bool | numpy.float16 | ... | tuple + +>>> Union[tuple(scalar_types) + (tuple, list)] # Scalar or tuple or list +numpy.bool | numpy.float16 | ... | tuple | list +``` + +% skip: end + +% skip: start if(sys.version_info >= (3, 14), reason="Union repr changed in Python 3.14+") + ```python >>> Scalar typing.Union[Scalar] @@ -76,12 +118,14 @@ typing.Union[Scalar] typing.Union[Scalar] >>> Union[tuple(scalar_types) + (tuple,)] # Scalar or tuple -typing.Union[Scalar, tuple] + typing.Union[Scalar, tuple] >>> Union[tuple(scalar_types) + (tuple, list)] # Scalar or tuple or list -typing.Union[Scalar, tuple, list] + typing.Union[Scalar, tuple, list] ``` +% skip: end + If we don't include all of `scalar_types`, we won't see `Scalar`, as desired: % invisible-code-block: python diff --git a/src/plum/_alias.py b/src/plum/_alias.py index 54eefd46..60042f30 100644 --- a/src/plum/_alias.py +++ b/src/plum/_alias.py @@ -32,135 +32,191 @@ "set_union_alias", ) +import sys from functools import wraps -from typing import Any, TypeVar, Union, _type_repr, get_args +from typing import TypeVar, Union, _type_repr, get_args +from typing_extensions import deprecated UnionT = TypeVar("UnionT") -_union_type = type(Union[int, float]) # noqa: UP007 -_original_repr = _union_type.__repr__ -_original_str = _union_type.__str__ - - -@wraps(_original_repr) -def _new_repr(self: object) -> str: - """Print a `typing.Union`, replacing all aliased unions by their aliased names. - - Returns: - str: Representation of a `typing.Union` taking into account union aliases. - """ - args_tuple = get_args(self) - args_set = set(args_tuple) - - # Find all aliased unions contained in this union. - found_unions: list[set[Any]] = [] - found_positions: list[int] = [] - found_aliases: list[str] = [] - for union, alias in reversed(_ALIASED_UNIONS): - union_set = set(union) - if union_set <= args_set: - for i, arg in enumerate(args_tuple): - if arg in union_set: - found_unions.append(union_set) - found_positions.append(i) - found_aliases.append(alias) +_ALIASED_UNIONS: list = [] + +if sys.version_info < (3, 14): + _union_type = type(Union[int, float]) # noqa: UP007 + _original_repr = _union_type.__repr__ + _original_str = _union_type.__str__ + + @wraps(_original_repr) + def _new_repr(self: object) -> str: + """Print a `typing.Union`, replacing all aliased unions by their aliased names. + + Returns: + str: Representation of a `typing.Union` taking into account union aliases. + """ + args = get_args(self) + args_set = set(args) + + # Find all aliased unions contained in this union. + found_unions = [] + found_positions = [] + found_aliases = [] + for union, alias in reversed(_ALIASED_UNIONS): + union_set = set(union) + if union_set <= args_set: + found = False + for i, arg in enumerate(args): + if arg in union_set: + found_unions.append(union_set) + found_positions.append(i) + found_aliases.append(alias) + found = True + break + if not found: # pragma: no cover + # This branch should never be reached. + raise AssertionError( + "Could not identify union. This should never happen." + ) + + # Delete any unions that are contained in strictly bigger unions. We check for + # strictly inequality because any union includes itself. + for i in range(len(found_unions) - 1, -1, -1): + for union in found_unions: + if found_unions[i] < union: + del found_unions[i] + del found_positions[i] + del found_aliases[i] break - else: # pragma: no cover - # This should never be reached! If `union_set <= args_set`, we - # should find at least one argument. - msg = f"Unexpectedly failed to find argument for union `{union}`." - raise AssertionError(msg) - - # Delete any unions that are contained in strictly bigger unions. We check - # for strictly inequality because any union includes itself. - for i in range(len(found_unions) - 1, -1, -1): - for union_candidate in found_unions: - if found_unions[i] < union_candidate: - del found_unions[i] - del found_positions[i] - del found_aliases[i] - break - - # Create a set with all arguments of all found unions. - found_args: set[Any] = set() - for union_set in found_unions: - found_args |= union_set - - # Insert the aliases right before the first found argument. When we insert - # an element, the positions of following insertions need to be appropriately - # incremented. - args: list[Any] = list(args_tuple) - # Sort by insertion position to ensure that all following insertions are at - # higher indices. This makes the bookkeeping simple. - for delta, (i, alias) in enumerate( - sorted(zip(found_positions, found_aliases, strict=True), key=lambda x: x[0]) - ): - args.insert(i + delta, alias) - - # Filter all elements of unions that are aliased. - new_args: list[Any] = [] - for arg in args: - if arg not in found_args: - new_args.append(arg) - args = new_args - - # Generate a string representation. - args_repr = [a if isinstance(a, str) else _type_repr(a) for a in args] - # Like `typing` does, print `Optional` whenever possible. - if len(args) == 2: - if args[0] is type(None): # noqa: E721 - return f"typing.Optional[{args_repr[1]}]" - elif args[1] is type(None): # noqa: E721 - return f"typing.Optional[{args_repr[0]}]" - # We would like to just print `args_repr[0]` whenever `len(args) == 1`, but - # this might break code that parses how unions print. - return "typing.Union[" + ", ".join(args_repr) + "]" - - -@wraps(_original_str) -def _new_str(self: object) -> str: - """Does the same as :func:`_new_repr`. - - Returns: - str: Representation of the `typing.Union` taking into account union aliases. - """ - return _new_repr(self) - - -def activate_union_aliases() -> None: - """When printing `typing.Union`s, replace all aliased unions by the aliased names. - This monkey patches `__repr__` and `__str__` for `typing.Union`.""" - _union_type.__repr__ = _new_repr # type: ignore[method-assign] - _union_type.__str__ = _new_str # type: ignore[method-assign] - - -def deactivate_union_aliases() -> None: - """Undo what :func:`.alias.activate` did. This restores the original `__repr__` - and `__str__` for `typing.Union`.""" - _union_type.__repr__ = _original_repr # type: ignore[method-assign] - _union_type.__str__ = _original_str # type: ignore[method-assign] - - -_ALIASED_UNIONS: list[tuple[tuple[Any, ...], str]] = [] - - -def set_union_alias(union: UnionT, alias: str) -> UnionT: - """Change how a `typing.Union` is printed. This does not modify `union`. - - Args: - union (type or type hint): A union. - alias (str): How to print `union`. - - Returns: - type or type hint: `union`. - """ - args = get_args(union) if isinstance(union, _union_type) else (union,) - for existing_union, existing_alias in _ALIASED_UNIONS: - if set(existing_union) == set(args) and alias != existing_alias: - if isinstance(union, _union_type): - union_str = _original_str(union) - else: - union_str = repr(union) - raise RuntimeError(f"`{union_str}` already has alias `{existing_alias}`.") - _ALIASED_UNIONS.append((args, alias)) - return union + + # Create a set with all arguments of all found unions. + found_args = set() + for union in found_unions: + found_args |= union + + # Insert the aliases right before the first found argument. When we insert an + # element, the positions of following insertions need to be appropriately + # incremented. + args = list(args) + # Sort by insertion position to ensure that all following insertions are + # at higher indices. This makes the bookkeeping simple. + for delta, (i, alias) in enumerate( + sorted( + zip(found_positions, found_aliases, strict=False), key=lambda x: x[0] + ) + ): + args.insert(i + delta, alias) + + # Filter all elements of unions that are aliased. + new_args = () + for arg in args: + if arg not in found_args: + new_args += (arg,) + args = new_args + + # Generate a string representation. + args_repr = [a if isinstance(a, str) else _type_repr(a) for a in args] + # Like `typing` does, print `Optional` whenever possible. + if len(args) == 2: + if args[0] is type(None): # noqa: E721 + return f"typing.Optional[{args_repr[1]}]" + elif args[1] is type(None): # noqa: E721 + return f"typing.Optional[{args_repr[0]}]" + # We would like to just print `args_repr[0]` whenever `len(args) == 1`, but + # this might break code that parses how unions print. + return "typing.Union[" + ", ".join(args_repr) + "]" + + @wraps(_original_str) + def _new_str(self: object) -> str: + """Does the same as :func:`_new_repr`. + + Returns: + str: Representation of the `typing.Union` taking into account union aliases. + """ + return _new_repr(self) + + @deprecated( + "`activate_union_aliases` is deprecated and will be removed in a future version.", # noqa: E501 + stacklevel=2, + ) + def activate_union_aliases() -> None: + """When printing `typing.Union`s, replace aliased unions by the aliased names. + This monkey patches `__repr__` and `__str__` for `typing.Union`.""" + _union_type.__repr__ = _new_repr + _union_type.__str__ = _new_str + + @deprecated( + "`deactivate_union_aliases` is deprecated and will be removed in a future version.", # noqa: E501 + stacklevel=2, + ) + def deactivate_union_aliases() -> None: + """Undo what :func:`.alias.activate` did. This restores the original `__repr__` + and `__str__` for `typing.Union`.""" + _union_type.__repr__ = _original_repr + _union_type.__str__ = _original_str + + @deprecated( + "`set_union_alias` is deprecated and will be removed in a future version.", # noqa: E501 + stacklevel=2, + ) + def set_union_alias(union: UnionT, alias: str) -> UnionT: + """Change how a `typing.Union` is printed. This does not modify `union`. + + Args: + union (type or type hint): A union. + alias (str): How to print `union`. + + Returns: + type or type hint: `union`. + """ + args = get_args(union) if isinstance(union, _union_type) else (union,) + for existing_union, existing_alias in _ALIASED_UNIONS: + if set(existing_union) == set(args) and alias != existing_alias: + if isinstance(union, _union_type): + union_str = _original_str(union) + else: + union_str = repr(union) + raise RuntimeError( + f"`{union_str}` already has alias `{existing_alias}`." + ) + _ALIASED_UNIONS.append((args, alias)) + return union + + +else: + + @deprecated( + "`activate_union_aliases` is deprecated and will be removed in a future version.", # noqa: E501 + category=RuntimeWarning, + stacklevel=2, + ) + def activate_union_aliases() -> None: + """When printing `typing.Union`s, replace aliased unions by the aliased names. + This monkey patches `__repr__` and `__str__` for `typing.Union`.""" + + @deprecated( + "`deactivate_union_aliases` is deprecated and will be removed in a future version.", # noqa: E501 + category=RuntimeWarning, + stacklevel=2, + ) + def deactivate_union_aliases() -> None: + """Undo what :func:`.alias.activate` did. This restores the original `__repr__` + and `__str__` for `typing.Union`.""" + if sys.version_info < (3, 14): + _union_type.__repr__ = _original_repr + _union_type.__str__ = _original_str + + @deprecated( + "`set_union_alias` is deprecated and will be removed in a future version.", # noqa: E501 + category=RuntimeWarning, + stacklevel=2, + ) + def set_union_alias(union: UnionT, alias: str) -> UnionT: + """Change how a `typing.Union` is printed. This does not modify `union`. + + Args: + union (type or type hint): A union. + alias (str): How to print `union`. + + Returns: + type or type hint: `union`. + """ + return union diff --git a/tests/test_alias.py b/tests/test_alias_upto313.py similarity index 89% rename from tests/test_alias.py rename to tests/test_alias_upto313.py index bcadf6f5..4a7aa7c5 100644 --- a/tests/test_alias.py +++ b/tests/test_alias_upto313.py @@ -1,18 +1,28 @@ +import sys from typing import Union import pytest -from plum import activate_union_aliases, deactivate_union_aliases, set_union_alias +import plum +from plum import set_union_alias from plum._alias import _ALIASED_UNIONS +# These tests are for Python <= 3.13 only. +pytestmark = [ + pytest.mark.skipif( + sys.version_info >= (3, 14), + reason="Union aliasing tests for Python <= 3.13", + ), +] + @pytest.fixture() def union_aliases(): """Activate union aliases during the test and remove all aliases after the test finishes.""" - activate_union_aliases() + plum.activate_union_aliases() yield - deactivate_union_aliases() + plum.deactivate_union_aliases() _ALIASED_UNIONS.clear() diff --git a/tests/test_util.py b/tests/test_util.py index 6bb824a9..322dcd98 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -21,7 +21,10 @@ class A: assert repr_short(int) == "int" assert repr_short(A) == "tests.test_util.test_repr_short..A" - assert repr_short(Union[int, float]) == "typing.Union[int, float]" # noqa: UP007 + if sys.version_info >= (3, 14): + assert repr_short(Union[int, float]) == "int | float" # noqa: UP007 + else: + assert repr_short(Union[int, float]) == "typing.Union[int, float]" # noqa: UP007 assert repr_short(int | float) == "int | float" From 0444c5b4e3c045dcf57d9166aedb4061c4e9cdad Mon Sep 17 00:00:00 2001 From: nstarman Date: Thu, 12 Feb 2026 21:18:07 -0500 Subject: [PATCH 3/6] refactor: some simplifications Signed-off-by: nstarman --- docs/union_aliases.md | 4 ++- src/plum/_signature.py | 64 ++++++++++++++++++------------------- src/plum/_type.py | 23 ++++++------- tests/test_alias_314plus.py | 64 +++++++++++++++++++++++++++++++++++++ tests/test_alias_upto313.py | 1 - 5 files changed, 110 insertions(+), 46 deletions(-) create mode 100644 tests/test_alias_314plus.py diff --git a/docs/union_aliases.md b/docs/union_aliases.md index c81e25d9..f3de8bc6 100644 --- a/docs/union_aliases.md +++ b/docs/union_aliases.md @@ -142,9 +142,11 @@ typing.Union[numpy.int8, numpy.int16, numpy.int32, numpy.longlong, numpy.int64, You can deactivate union aliases with `deactivate_union_aliases`: ```python +>>> import warnings >>> from plum import deactivate_union_aliases ->>> deactivate_union_aliases() +>>> with warnings.catch_warnings(action="ignore"): +... deactivate_union_aliases() % skip: next "Result depends on NumPy version." >>> Scalar diff --git a/src/plum/_signature.py b/src/plum/_signature.py index 3e3e6b9c..7d9c54f4 100644 --- a/src/plum/_signature.py +++ b/src/plum/_signature.py @@ -12,12 +12,18 @@ from rich.console import Console, ConsoleOptions from rich.segment import Segment -import beartype.door +from beartype.door import TypeHint as TypeHintWrapper from beartype.peps import resolve_pep563 as beartype_resolve_pep563 from ._bear import is_bearable from ._type import is_faithful, resolve_type_hint -from ._util import Comparable, Missing, TypeHint, _MissingType, wrap_lambda +from ._util import ( + Comparable, + Missing, + TypeHint, + _MissingType, + wrap_lambda, +) from .repr import repr_short, rich_repr @@ -116,30 +122,21 @@ def __rich_console__( yield Segment("precedence=" + repr(self.precedence)) yield Segment(")") - def __eq__(self, other: Any) -> bool: - if isinstance(other, Signature): - if self.varargs is Missing: - self_varargs = Missing - else: - self_varargs = beartype.door.TypeHint(self.varargs) + def __eq__(self, other: Any, /) -> bool: + if not isinstance(other, Signature): + return False - if other.varargs is Missing: - other_varargs = Missing - else: - other_varargs = beartype.door.TypeHint(other.varargs) - - # We don't need to check faithfulness, because that is automatically derived - # from the arguments. - return ( - tuple(beartype.door.TypeHint(t) for t in self.types), - self_varargs, - self.precedence, - ) == ( - tuple(beartype.door.TypeHint(t) for t in other.types), - other_varargs, - other.precedence, - ) - return False + # We don't need to check faithfulness, because that is automatically + # derived from the arguments. + return ( + tuple(TypeHintWrapper(t) for t in self.types), + Missing if self.varargs is Missing else TypeHintWrapper(self.varargs), + self.precedence, + ) == ( + tuple(TypeHintWrapper(t) for t in other.types), + Missing if other.varargs is Missing else TypeHintWrapper(other.varargs), + other.precedence, + ) def __hash__(self) -> int: return hash((Signature, *self.types, self.varargs)) @@ -171,19 +168,20 @@ def __le__(self, other: object, /) -> bool: ): return False - # Expand the types and compare. We implement the subset relationship, but, very - # importantly, deviate from the subset relationship in exactly one place. + # Expand the types and compare. We implement the subset relationship, + # but, very importantly, deviate from the subset relationship in exactly + # one place. self_types = self.expand_varargs(len(other.types)) other_types = other.expand_varargs(len(self.types)) if all( [ - beartype.door.TypeHint(x) == beartype.door.TypeHint(y) + TypeHintWrapper(x) == TypeHintWrapper(y) for x, y in zip(self_types, other_types, strict=True) ] ): if self.has_varargs and other.has_varargs: - self_varargs = beartype.door.TypeHint(self.varargs) - other_varargs = beartype.door.TypeHint(other.varargs) + self_varargs = TypeHintWrapper(self.varargs) + other_varargs = TypeHintWrapper(other.varargs) return bool(self_varargs <= other_varargs) # Having variable arguments makes you slightly larger. @@ -197,7 +195,7 @@ def __le__(self, other: object, /) -> bool: elif all( [ - beartype.door.TypeHint(x) <= beartype.door.TypeHint(y) + TypeHintWrapper(x) <= TypeHintWrapper(y) for x, y in zip(self_types, other_types, strict=True) ] ): @@ -213,8 +211,8 @@ def __le__(self, other: object, /) -> bool: # is `1.0`, then reasonably the variable arguments should be # ignored and `(int, *A)` should be considered more specific # than `(Number, *B)`. - self_varargs = beartype.door.TypeHint(self.varargs) - other_varargs = beartype.door.TypeHint(other.varargs) + self_varargs = TypeHintWrapper(self.varargs) + other_varargs = TypeHintWrapper(other.varargs) return bool(self_varargs <= other_varargs) elif self.has_varargs: diff --git a/src/plum/_type.py b/src/plum/_type.py index 5044a88c..8a733816 100644 --- a/src/plum/_type.py +++ b/src/plum/_type.py @@ -249,18 +249,19 @@ def resolve_type_hint(x: object, /) -> object: elif isinstance(x, list): return [resolve_type_hint(arg) for arg in x] elif isinstance(x, type): - if isinstance(x, ResolvableType): - if isinstance(x, ModuleType) and not x.retrieve(): - # If the type could not be retrieved, then just return the - # wrapper. Namely, `x.resolve()` will then return `x`, which means - # that the below call will result in an infinite recursion. - return x - return resolve_type_hint(x.resolve()) - else: + if not isinstance(x, ResolvableType): + return x + elif isinstance(x, ModuleType) and not x.retrieve(): + # If the type could not be retrieved, then just return the + # wrapper. Namely, `x.resolve()` will then return `x`, which + # means that the below call will result in an infinite + # recursion. return x - # For example, `Is[lambda x: x > 0]` is an example of a `BeartypeValidator`. We - # shouldn't resolve those. + return resolve_type_hint(x.resolve()) + + # For example, `Is[lambda x: x > 0]` is an example of a `BeartypeValidator`. + # We shouldn't resolve those. elif isinstance(x, BeartypeValidator): return x @@ -298,7 +299,7 @@ class UnfaithfulType: return _is_faithful(resolve_type_hint(x)) -UNION_TYPES = frozenset({typing.Union, UnionType, typing.Optional}) +UNION_TYPES = (typing.Union, UnionType, typing.Optional) class _SupportsDunderFaithful(typing.Protocol): diff --git a/tests/test_alias_314plus.py b/tests/test_alias_314plus.py new file mode 100644 index 00000000..3a1e4a07 --- /dev/null +++ b/tests/test_alias_314plus.py @@ -0,0 +1,64 @@ +import sys +from typing import Union + +import pytest + +from plum import set_union_alias + +# These tests are for Python >= 3.14 only. +pytestmark = [ + pytest.mark.skipif( + sys.version_info < (3, 14), + reason="Union aliasing tests for Python >= 3.14", + ), +] + + +@pytest.mark.parametrize("display", [str, repr]) +def test_union_alias(display): + # Check that printing is normal before registering any aliases. + assert display(Union[int, str]) == "int | str" # noqa: UP007 + + # Register a simple alias and check that it prints correctly. + IntStr = set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 + assert display(IntStr) == "int | str" # noqa: UP007 + assert display(Union[int, str]) == "int | str" # noqa: UP007 + + # Register a bigger alias. + set_union_alias(Union[int, str, float], alias="IntStrFloat") # noqa: UP007 + assert display(Union[int, str, float]) == "int | str | float" # noqa: UP007 + + +@pytest.mark.parametrize("display", [str, repr]) +def test_uniontype_alias(display): + # Check that printing is normal before registering any aliases. + assert display(int | str) == "int | str" + + # Register a simple alias and check that it prints correctly. + IntStr = set_union_alias(int | str, alias="IntStr") # noqa: UP007 + assert display(IntStr) == "int | str" # noqa: UP007 + assert display(int | str) == "int | str" # noqa: UP007 + + # Register a bigger alias. + set_union_alias(int | str | float, alias="IntStrFloat") # noqa: UP007 + assert display(int | str | float) == "int | str | float" # noqa: UP007 + + +def test_optional(): + assert repr(Union[int, None]) == "int | None" # noqa: UP007 + assert repr(Union[None, int]) == "None | int" # noqa: UP007 + assert repr(int | None) == "int | None" + + +def test_double_registration(): + # We can register with the same alias, but not with a different alias. + + set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 + set_union_alias(Union[int, str], alias="OtherIntStr") # noqa: UP007 + + set_union_alias(int | str, alias="IntStr") # This is OK. + set_union_alias(int | str, alias="OtherIntStr") + + set_union_alias(int, alias="MyInt") + set_union_alias(int, alias="MyInt") # This is OK. + set_union_alias(int, alias="MyOtherInt") diff --git a/tests/test_alias_upto313.py b/tests/test_alias_upto313.py index 4a7aa7c5..a2b7f5f2 100644 --- a/tests/test_alias_upto313.py +++ b/tests/test_alias_upto313.py @@ -71,7 +71,6 @@ def test_optional(union_aliases): def test_double_registration(union_aliases): # We can register with the same alias, but not with a different alias. - set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 set_union_alias(int | str, alias="IntStr") # This is OK. with pytest.raises(RuntimeError, match=r"already has alias"): From 9e9178cbe8bd112435e11e2d280ed57dfe196a48 Mon Sep 17 00:00:00 2001 From: nstarman Date: Thu, 12 Feb 2026 21:08:30 -0500 Subject: [PATCH 4/6] feat: union alias in python 3.14 Signed-off-by: nstarman --- docs/union_aliases.md | 3 +- src/plum/_alias.py | 138 ++++++++++------- src/plum/repr.py | 17 ++- tests/conftest.py | 13 ++ tests/test_alias_314plus.py | 289 ++++++++++++++++++++++++++++++++---- 5 files changed, 371 insertions(+), 89 deletions(-) diff --git a/docs/union_aliases.md b/docs/union_aliases.md index f3de8bc6..aaf77c20 100644 --- a/docs/union_aliases.md +++ b/docs/union_aliases.md @@ -145,7 +145,8 @@ You can deactivate union aliases with `deactivate_union_aliases`: >>> import warnings >>> from plum import deactivate_union_aliases ->>> with warnings.catch_warnings(action="ignore"): +>>> with warnings.catch_warnings(): +... warnings.simplefilter("ignore") ... deactivate_union_aliases() % skip: next "Result depends on NumPy version." diff --git a/src/plum/_alias.py b/src/plum/_alias.py index 60042f30..18942195 100644 --- a/src/plum/_alias.py +++ b/src/plum/_alias.py @@ -34,18 +34,19 @@ import sys from functools import wraps -from typing import TypeVar, Union, _type_repr, get_args -from typing_extensions import deprecated +from typing import Any, TypeVar, Union, _type_repr, get_args +from typing_extensions import TypeAliasType, deprecated UnionT = TypeVar("UnionT") -_ALIASED_UNIONS: list = [] +_union_type = type(Union[int, float]) # noqa: UP007 -if sys.version_info < (3, 14): - _union_type = type(Union[int, float]) # noqa: UP007 +if sys.version_info < (3, 14): # pragma: specific no cover 3.14 _original_repr = _union_type.__repr__ _original_str = _union_type.__str__ + _ALIASED_UNIONS: dict[tuple[Any, ...], str] = {} + @wraps(_original_repr) def _new_repr(self: object) -> str: """Print a `typing.Union`, replacing all aliased unions by their aliased names. @@ -60,7 +61,7 @@ def _new_repr(self: object) -> str: found_unions = [] found_positions = [] found_aliases = [] - for union, alias in reversed(_ALIASED_UNIONS): + for union, alias in reversed(_ALIASED_UNIONS.items()): union_set = set(union) if union_set <= args_set: found = False @@ -77,40 +78,30 @@ def _new_repr(self: object) -> str: "Could not identify union. This should never happen." ) - # Delete any unions that are contained in strictly bigger unions. We check for - # strictly inequality because any union includes itself. + # Delete any unions that are contained in strictly bigger unions. We + # check for strictly inequality because any union includes itself. for i in range(len(found_unions) - 1, -1, -1): - for union in found_unions: - if found_unions[i] < union: + for union_ in found_unions: + if found_unions[i] < set(union_): del found_unions[i] del found_positions[i] del found_aliases[i] break # Create a set with all arguments of all found unions. - found_args = set() - for union in found_unions: - found_args |= union - - # Insert the aliases right before the first found argument. When we insert an - # element, the positions of following insertions need to be appropriately - # incremented. - args = list(args) - # Sort by insertion position to ensure that all following insertions are - # at higher indices. This makes the bookkeeping simple. - for delta, (i, alias) in enumerate( - sorted( - zip(found_positions, found_aliases, strict=False), key=lambda x: x[0] - ) - ): - args.insert(i + delta, alias) + found_args = set().union(*found_unions) if found_unions else set() + + # Build a mapping from original position to aliases to insert before it. + inserts: dict[int, list[str]] = {} + for pos, alias in zip(found_positions, found_aliases, strict=False): + inserts.setdefault(pos, []).append(alias) + # Interleave aliases at the appropriate positions. + args = tuple( + v for i, arg in enumerate(args) for v in (*inserts.pop(i, []), arg) + ) # Filter all elements of unions that are aliased. - new_args = () - for arg in args: - if arg not in found_args: - new_args += (arg,) - args = new_args + args = tuple(arg for arg in args if arg not in found_args) # Generate a string representation. args_repr = [a if isinstance(a, str) else _type_repr(a) for a in args] @@ -140,8 +131,8 @@ def _new_str(self: object) -> str: def activate_union_aliases() -> None: """When printing `typing.Union`s, replace aliased unions by the aliased names. This monkey patches `__repr__` and `__str__` for `typing.Union`.""" - _union_type.__repr__ = _new_repr - _union_type.__str__ = _new_str + _union_type.__repr__ = _new_repr # type: ignore[method-assign] + _union_type.__str__ = _new_str # type: ignore[method-assign] @deprecated( "`deactivate_union_aliases` is deprecated and will be removed in a future version.", # noqa: E501 @@ -150,13 +141,9 @@ def activate_union_aliases() -> None: def deactivate_union_aliases() -> None: """Undo what :func:`.alias.activate` did. This restores the original `__repr__` and `__str__` for `typing.Union`.""" - _union_type.__repr__ = _original_repr - _union_type.__str__ = _original_str + _union_type.__repr__ = _original_repr # type: ignore[method-assign] + _union_type.__str__ = _original_str # type: ignore[method-assign] - @deprecated( - "`set_union_alias` is deprecated and will be removed in a future version.", # noqa: E501 - stacklevel=2, - ) def set_union_alias(union: UnionT, alias: str) -> UnionT: """Change how a `typing.Union` is printed. This does not modify `union`. @@ -168,7 +155,7 @@ def set_union_alias(union: UnionT, alias: str) -> UnionT: type or type hint: `union`. """ args = get_args(union) if isinstance(union, _union_type) else (union,) - for existing_union, existing_alias in _ALIASED_UNIONS: + for existing_union, existing_alias in _ALIASED_UNIONS.items(): if set(existing_union) == set(args) and alias != existing_alias: if isinstance(union, _union_type): union_str = _original_str(union) @@ -177,11 +164,11 @@ def set_union_alias(union: UnionT, alias: str) -> UnionT: raise RuntimeError( f"`{union_str}` already has alias `{existing_alias}`." ) - _ALIASED_UNIONS.append((args, alias)) + _ALIASED_UNIONS[args] = alias return union - -else: +else: # pragma: specific no cover 3.13 3.12 3.11 3.10 + _ALIASED_UNIONS: dict[tuple[Any, ...], TypeAliasType] = {} @deprecated( "`activate_union_aliases` is deprecated and will be removed in a future version.", # noqa: E501 @@ -200,23 +187,60 @@ def activate_union_aliases() -> None: def deactivate_union_aliases() -> None: """Undo what :func:`.alias.activate` did. This restores the original `__repr__` and `__str__` for `typing.Union`.""" - if sys.version_info < (3, 14): - _union_type.__repr__ = _original_repr - _union_type.__str__ = _original_str - @deprecated( - "`set_union_alias` is deprecated and will be removed in a future version.", # noqa: E501 - category=RuntimeWarning, - stacklevel=2, - ) - def set_union_alias(union: UnionT, alias: str) -> UnionT: - """Change how a `typing.Union` is printed. This does not modify `union`. + def set_union_alias(union: UnionT, /, alias: str) -> UnionT: + """Register a union alias for use in plum's dispatch system. + + When used with plum's dispatch system, the union will be automatically + transformed into a `TypeAliasType` during signature extraction, allowing + dispatch to key off the alias name instead of the union structure. Args: - union (type or type hint): A union. - alias (str): How to print `union`. + union (type or type hint): A union type or a single type. + alias (str): Alias name for the union. - Returns: - type or type hint: `union`. """ + # Handle both union types and single types, matching < 3.14 behaviour. + args = get_args(union) if isinstance(union, _union_type) else (union,) + + # Check for conflicting aliases + for existing_union, existing_alias in _ALIASED_UNIONS.items(): + if set(existing_union) == set(args) and alias != repr(existing_alias): + union_str = repr(union) + raise RuntimeError( + f"`{union_str}` already has alias `{existing_alias!r}`." + ) + + new_alias = TypeAliasType(alias, union, type_params=()) # type: ignore[misc] + + _ALIASED_UNIONS[args] = new_alias + return union + + +def _transform_union_alias(x: object, /) -> object: + """Transform a Union type hint to a TypeAliasType if it's registered in the alias + registry. This is used by plum's dispatch machinery to use aliased names for unions. + + Args: + x (type or type hint): Type hint, potentially a Union. + + Returns: + type or type hint: If `x` is a Union registered in `_ALIASED_UNIONS`, returns + the TypeAliasType. Otherwise returns `x` unchanged. + """ + # TypeAliasType instances are already transformed, return as-is + if isinstance(x, TypeAliasType): + return x + + # Get the union args to check if it's registered + args = get_args(x) if isinstance(x, _union_type) else None + if args: + args_set = set(args) + # Look for a matching alias in the registry + for union_args, type_alias in _ALIASED_UNIONS.items(): + if set(union_args) == args_set: + return type_alias + + # Not a union or not aliased, return as-is + return x diff --git a/src/plum/repr.py b/src/plum/repr.py index ebd9282c..10f6597c 100644 --- a/src/plum/repr.py +++ b/src/plum/repr.py @@ -5,7 +5,6 @@ "repr_pyfunction", "rich_repr", ] - import inspect import os import sys @@ -14,12 +13,15 @@ from collections.abc import Callable, Iterable from functools import partial from typing import Any, TypeVar, overload +from typing_extensions import TypeAliasType import rich from rich.color import Color from rich.style import Style from rich.text import Text +from ._alias import _transform_union_alias + T = TypeVar("T") path_style = Style(color=Color.from_ansi(7)) @@ -41,6 +43,9 @@ def repr_type(x: object, /) -> Text: Returns: :class:`rich.Text`: Representation. """ + # Apply union aliasing if `x` is a union. This allows us to have the correct + # syntax highlighting for aliased unions. + x = _transform_union_alias(x) if isinstance(x, type): if x.__module__ in ["builtins", "typing", "typing_extensions"]: @@ -60,14 +65,20 @@ def repr_short(x: object, /) -> str: """Representation as a string, but in shorter form. This just calls :func:`typing._type_repr`. + If the type is a union registered in plum's alias registry, the alias name + is used instead. + Args: x (object): Object. Returns: str: Shorter representation of `x`. """ - # :func:`typing._type_repr` is an internal function, but it should be available in - # Python versions 3.9 through 3.13. + if isinstance(transformed := _transform_union_alias(x), TypeAliasType): + # It's an aliased union — use the alias name + return str(transformed.__name__) + # :func:`typing._type_repr` is an internal function, but it should be + # available in Python versions 3.9 through 3.14. return typing._type_repr(x) diff --git a/tests/conftest.py b/tests/conftest.py index 6a49bcd9..d840ee3e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,22 @@ +"""Fixtures for testing.""" + +from unittest.mock import patch + import pytest import plum from plum._promotion import _convert, _promotion_rule +@pytest.fixture(autouse=True) +def _clean_union_aliases(): + """Give each test its own empty alias registry, restored automatically.""" + from plum._alias import _ALIASED_UNIONS + + with patch.dict(_ALIASED_UNIONS, clear=True): + yield + + @pytest.fixture def dispatch() -> plum.Dispatcher: """Provide a fresh Dispatcher for testing.""" diff --git a/tests/test_alias_314plus.py b/tests/test_alias_314plus.py index 3a1e4a07..e2788168 100644 --- a/tests/test_alias_314plus.py +++ b/tests/test_alias_314plus.py @@ -1,47 +1,105 @@ +import functools as ft import sys from typing import Union import pytest -from plum import set_union_alias +import beartype +import beartype.door + +import plum # These tests are for Python >= 3.14 only. -pytestmark = [ - pytest.mark.skipif( - sys.version_info < (3, 14), - reason="Union aliasing tests for Python >= 3.14", - ), -] +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 14), + reason="Union aliasing tests for Python >= 3.14", +) + + +def test_activate_union_aliases_deprecated() -> None: + """Test that activate_union_aliases raises a deprecation warning.""" + with pytest.warns(RuntimeWarning, match="`activate_union_aliases` is deprecated"): + plum.activate_union_aliases() + + +def test_deactivate_union_aliases_deprecated() -> None: + """Test that deactivate_union_aliases raises a deprecation warning.""" + with pytest.warns(RuntimeWarning, match="`deactivate_union_aliases` is deprecated"): + plum.deactivate_union_aliases() + + +@pytest.mark.parametrize("union", [int | str, Union[int, str]]) # noqa: UP007 +def test_repr_short_uses_alias(union) -> None: + """Test that repr_short substitutes registered union aliases.""" + plum.set_union_alias(union, alias="IntStr") + + # Aliased union should use the alias name + assert plum.repr.repr_short(int | str) == "IntStr" + assert plum.repr.repr_short(Union[int, str]) == "IntStr" # noqa: UP007 + # Non-aliased unions should be unchanged + assert "IntStr" not in plum.repr.repr_short(int | float) + assert "IntStr" not in plum.repr.repr_short(Union[int, float]) # noqa: UP007 -@pytest.mark.parametrize("display", [str, repr]) + # Plain types should be unchanged + assert plum.repr.repr_short(int) == "int" + assert plum.repr.repr_short(float) == "float" + + # Signature printing should use the alias + sig = plum.Signature(int | str, float) + assert "IntStr" in repr(sig) + assert repr(sig) == "Signature(IntStr, float)" + + # Signature printing should use the alias + sig = plum.Signature(Union[int, str], float) # noqa: UP007 + assert "IntStr" in repr(sig) + assert repr(sig) == "Signature(IntStr, float)" + + +@pytest.mark.parametrize("display", [plum.repr.repr_short]) def test_union_alias(display): + plum.set_union_alias(int | str, alias="IntStr") + # Check that printing is normal before registering any aliases. - assert display(Union[int, str]) == "int | str" # noqa: UP007 + assert display(Union[int, str]) == "IntStr" # noqa: UP007 # Register a simple alias and check that it prints correctly. - IntStr = set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 - assert display(IntStr) == "int | str" # noqa: UP007 - assert display(Union[int, str]) == "int | str" # noqa: UP007 + IntStr = plum.set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 + assert display(IntStr) == "IntStr" # noqa: UP007 + assert display(Union[int, str]) == "IntStr" # noqa: UP007 # Register a bigger alias. - set_union_alias(Union[int, str, float], alias="IntStrFloat") # noqa: UP007 - assert display(Union[int, str, float]) == "int | str | float" # noqa: UP007 + plum.set_union_alias(Union[int, str, float], alias="IntStrFloat") # noqa: UP007 + assert display(Union[int, str, float]) == "IntStrFloat" # noqa: UP007 -@pytest.mark.parametrize("display", [str, repr]) +@pytest.mark.parametrize("display", [plum.repr.repr_short]) def test_uniontype_alias(display): + plum.set_union_alias(int | str, alias="IntStr") + # Check that printing is normal before registering any aliases. - assert display(int | str) == "int | str" + assert display(int | str) == "IntStr" # Register a simple alias and check that it prints correctly. - IntStr = set_union_alias(int | str, alias="IntStr") # noqa: UP007 - assert display(IntStr) == "int | str" # noqa: UP007 - assert display(int | str) == "int | str" # noqa: UP007 + IntStr = plum.set_union_alias(int | str, alias="IntStr") # noqa: UP007 + assert display(IntStr) == "IntStr" # noqa: UP007 + assert display(int | str) == "IntStr" # noqa: UP007 # Register a bigger alias. - set_union_alias(int | str | float, alias="IntStrFloat") # noqa: UP007 - assert display(int | str | float) == "int | str | float" # noqa: UP007 + plum.set_union_alias(int | str | float, alias="IntStrFloat") # noqa: UP007 + assert display(int | str | float) == "IntStrFloat" # noqa: UP007 + + +def test_repr_short_with_type_alias_type_passthrough(): + """Test that repr_short handles a TypeAliasType passed directly (not as a union). + + This exercises the early-return path in _transform_union_alias where the + input is already a TypeAliasType instance. + """ + from typing_extensions import TypeAliasType + + alias = TypeAliasType("MyAlias", int | str) + assert plum.repr.repr_short(alias) == "MyAlias" def test_optional(): @@ -50,15 +108,190 @@ def test_optional(): assert repr(int | None) == "int | None" +def test_double_registration_union_same_alias() -> None: + """Test that registering the same union with the same alias is OK.""" + plum.set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 + plum.set_union_alias(Union[int, str], alias="IntStr") # This is OK # noqa: UP007 + + +def test_double_registration_uniontype_same_alias() -> None: + """Test that registering the same union with the same alias is OK.""" + plum.set_union_alias(int | str, alias="IntStr") + plum.set_union_alias(int | str, alias="IntStr") # This is OK + + +def test_double_registration_different_alias() -> None: + """Test that registering the same union with a different alias raises an error.""" + plum.set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 + with pytest.raises(RuntimeError, match=r"already has alias"): + plum.set_union_alias(Union[int, str], alias="OtherIntStr") # noqa: UP007 + + def test_double_registration(): # We can register with the same alias, but not with a different alias. + plum.set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 + with pytest.raises(RuntimeError, match=r"already has alias"): + plum.set_union_alias(Union[int, str], alias="OtherIntStr") # noqa: UP007 + with pytest.raises(RuntimeError, match=r"already has alias"): + plum.set_union_alias(int | str, alias="OtherIntStr") + + # The same applies for Union types. + plum.set_union_alias(str | None, alias="OptStr") + with pytest.raises(RuntimeError, match=r"already has alias"): + plum.set_union_alias(str | None, alias="OtherOptStr") # noqa: UP007 + with pytest.raises(RuntimeError, match=r"already has alias"): + plum.set_union_alias(Union[str | None], alias="OtherIntStr") # noqa: UP007 + + # We can also register plain types, but the same rules apply. + plum.set_union_alias(int, alias="MyInt") + plum.set_union_alias(int, alias="MyInt") # This is OK. + with pytest.raises(RuntimeError, match=r"already has alias"): + plum.set_union_alias(int, alias="MyOtherInt") + + +def test_set_union_alias_generated_type_alias() -> None: + """Test that set_union_alias generates a TypeAliasType for unions.""" + plum.set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 + + IntStr = plum._alias._ALIASED_UNIONS[(int, str)] + + # The returned value should be a TypeAliasType + assert hasattr(IntStr, "__name__") + assert IntStr.__name__ == "IntStr" + assert hasattr(IntStr, "__value__") + # The underlying value should be the union + from typing import get_args + + assert set(get_args(IntStr.__value__)) == {int, str} + + +def test_dispatch_with_union_alias(dispatch: plum.Dispatcher) -> None: + """Test that dispatch works correctly with union aliases.""" + # Register an alias for Union[int, str] + plum.set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 + + # Define a function using the alias + @dispatch + def process(x: int | str) -> str: + return "int or str" + + @dispatch + def process(x: float) -> str: + return "float" + + # Test dispatch + assert process(42) == "int or str" + assert process("hello") == "int or str" + assert process(3.14) == "float" + + +def test_dispatch_with_union_directly(dispatch: plum.Dispatcher) -> None: + """Test that dispatch works when using Union directly if registered.""" + # Register the alias + plum.set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 + + # Define a function using Union directly (should still work) + @dispatch + def process(x: int | str) -> str: + return "int or str" + + @dispatch + def process(x: float) -> str: + return "float" + + # Test dispatch + assert process(42) == "int or str" + assert process("hello") == "int or str" + assert process(3.14) == "float" + + +def test_signature_printing_with_alias(dispatch: plum.Dispatcher) -> None: + """Test that function signatures are nicely printed with TypeAliasType names.""" + # Register an alias + plum.set_union_alias(int | str, alias="IntStr") + + @dispatch + def example(x: int | str, y: float) -> str: + return "test" + + # Check that the signature contains the alias name + # The signature should show "IntStr" rather than "Union[int, str]" + sig_str = str(example.methods[0].signature) + assert "IntStr" in sig_str + + +def test_beartype_strict_mode_compatibility(dispatch: plum.Dispatcher) -> None: + """Test that strict beartype works with plum dispatch on aliased unions.""" + original_is_bearable = plum._is_bearable + + # Temporarily set strict mode for this test + plum._is_bearable = ft.partial( + beartype.door.is_bearable, + conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On), + ) + + try: + # Register an alias + plum.set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 + + # Define a function using the alias with plum dispatch + @dispatch + def strict_process(x: int | str) -> str: + return f"processed: {x}" + + # These should work + assert strict_process(42) == "processed: 42" + assert strict_process("hello") == "processed: hello" + + # This should not match the signature (float is not in IntStr) + with pytest.raises(plum.NotFoundLookupError): + strict_process(3.14) + finally: + # Restore original + plum._is_bearable = original_is_bearable + + +def test_multiple_aliases_in_signature(dispatch: plum.Dispatcher) -> None: + """Test that multiple aliased unions in the same signature work correctly.""" + # Register multiple aliases + plum.set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 + plum.set_union_alias(Union[float, bool], alias="FloatBool") # noqa: UP007 + + @dispatch + def multi(x: int | str, y: float | bool) -> str: + return f"{x}, {y}" + + # Test various combinations + assert multi(42, 3.14) == "42, 3.14" + assert multi("hello", True) == "hello, True" + assert multi(100, False) == "100, False" + assert multi("test", 2.5) == "test, 2.5" + + +def test_alias_in_method_repr(dispatch: plum.Dispatcher) -> None: + """Test that aliased union names appear in `method.methods` repr.""" + plum.set_union_alias(int | str, alias="IntOrStr") + + @dispatch + def method(x: int | str) -> str: + return f"Integer: {x}" + + assert "IntOrStr" in repr(method.methods) + + +def test_alias_priority_in_dispatch(dispatch: plum.Dispatcher) -> None: + """Test that aliased unions are treated like a union in dispatch.""" + # Register an alias + plum.set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 - set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 - set_union_alias(Union[int, str], alias="OtherIntStr") # noqa: UP007 + @dispatch + def handle(x: int | str) -> str: + return "alias" - set_union_alias(int | str, alias="IntStr") # This is OK. - set_union_alias(int | str, alias="OtherIntStr") + @dispatch + def handle(x: int) -> str: + return "int" - set_union_alias(int, alias="MyInt") - set_union_alias(int, alias="MyInt") # This is OK. - set_union_alias(int, alias="MyOtherInt") + # The more specific 'int' should match first + assert handle(42) == "int" + assert handle("hello") == "alias" From e606ed22c78ee12c79ccbfc01a61bb04695293d1 Mon Sep 17 00:00:00 2001 From: nstarman Date: Fri, 6 Mar 2026 12:06:17 -0500 Subject: [PATCH 5/6] feat: de-deprecate (de)activate_union_aliases Signed-off-by: nstarman --- .github/workflows/ci.yml | 4 ++-- docs/union_aliases.md | 4 ++-- src/plum/_alias.py | 44 ++++++++++++++++--------------------- tests/conftest.py | 6 ++++- tests/test_alias_314plus.py | 22 ++++++++++--------- 5 files changed, 40 insertions(+), 40 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d18a0ba5..7736f3c9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -30,13 +30,13 @@ jobs: extra-install: "" - name: "3.13-pre-beartype" python-version: "3.13" - extra-install: "pip install --upgrade --pre beartype" + extra-install: "uv pip install --upgrade --pre beartype" - name: "3.14" python-version: "3.14" extra-install: "" - name: "3.14-pre-beartype" python-version: "3.14" - extra-install: "pip install --upgrade --pre beartype" + extra-install: "uv pip install --upgrade --pre beartype" name: Test ${{ matrix.value.name }} steps: diff --git a/docs/union_aliases.md b/docs/union_aliases.md index aaf77c20..22c9ec8e 100644 --- a/docs/union_aliases.md +++ b/docs/union_aliases.md @@ -118,10 +118,10 @@ typing.Union[Scalar] typing.Union[Scalar] >>> Union[tuple(scalar_types) + (tuple,)] # Scalar or tuple - typing.Union[Scalar, tuple] +typing.Union[Scalar, tuple] >>> Union[tuple(scalar_types) + (tuple, list)] # Scalar or tuple or list - typing.Union[Scalar, tuple, list] +typing.Union[Scalar, tuple, list] ``` % skip: end diff --git a/src/plum/_alias.py b/src/plum/_alias.py index 18942195..847745fb 100644 --- a/src/plum/_alias.py +++ b/src/plum/_alias.py @@ -35,12 +35,14 @@ import sys from functools import wraps from typing import Any, TypeVar, Union, _type_repr, get_args -from typing_extensions import TypeAliasType, deprecated +from typing_extensions import TypeAliasType UnionT = TypeVar("UnionT") _union_type = type(Union[int, float]) # noqa: UP007 +_ALIASES_ARE_ACTIVE: bool = True + if sys.version_info < (3, 14): # pragma: specific no cover 3.14 _original_repr = _union_type.__repr__ _original_str = _union_type.__str__ @@ -124,25 +126,21 @@ def _new_str(self: object) -> str: """ return _new_repr(self) - @deprecated( - "`activate_union_aliases` is deprecated and will be removed in a future version.", # noqa: E501 - stacklevel=2, - ) def activate_union_aliases() -> None: """When printing `typing.Union`s, replace aliased unions by the aliased names. This monkey patches `__repr__` and `__str__` for `typing.Union`.""" + global _ALIASES_ARE_ACTIVE _union_type.__repr__ = _new_repr # type: ignore[method-assign] _union_type.__str__ = _new_str # type: ignore[method-assign] + _ALIASES_ARE_ACTIVE = True - @deprecated( - "`deactivate_union_aliases` is deprecated and will be removed in a future version.", # noqa: E501 - stacklevel=2, - ) def deactivate_union_aliases() -> None: """Undo what :func:`.alias.activate` did. This restores the original `__repr__` and `__str__` for `typing.Union`.""" + global _ALIASES_ARE_ACTIVE _union_type.__repr__ = _original_repr # type: ignore[method-assign] _union_type.__str__ = _original_str # type: ignore[method-assign] + _ALIASES_ARE_ACTIVE = False def set_union_alias(union: UnionT, alias: str) -> UnionT: """Change how a `typing.Union` is printed. This does not modify `union`. @@ -170,23 +168,15 @@ def set_union_alias(union: UnionT, alias: str) -> UnionT: else: # pragma: specific no cover 3.13 3.12 3.11 3.10 _ALIASED_UNIONS: dict[tuple[Any, ...], TypeAliasType] = {} - @deprecated( - "`activate_union_aliases` is deprecated and will be removed in a future version.", # noqa: E501 - category=RuntimeWarning, - stacklevel=2, - ) def activate_union_aliases() -> None: - """When printing `typing.Union`s, replace aliased unions by the aliased names. - This monkey patches `__repr__` and `__str__` for `typing.Union`.""" + """When printing `typing.Union`, replace aliased unions by the aliased names.""" + global _ALIASES_ARE_ACTIVE + _ALIASES_ARE_ACTIVE = True - @deprecated( - "`deactivate_union_aliases` is deprecated and will be removed in a future version.", # noqa: E501 - category=RuntimeWarning, - stacklevel=2, - ) def deactivate_union_aliases() -> None: - """Undo what :func:`.alias.activate` did. This restores the original `__repr__` - and `__str__` for `typing.Union`.""" + """When printing `typing.Union`s, print as normal.""" + global _ALIASES_ARE_ACTIVE + _ALIASES_ARE_ACTIVE = False def set_union_alias(union: UnionT, /, alias: str) -> UnionT: """Register a union alias for use in plum's dispatch system. @@ -205,10 +195,10 @@ def set_union_alias(union: UnionT, /, alias: str) -> UnionT: # Check for conflicting aliases for existing_union, existing_alias in _ALIASED_UNIONS.items(): - if set(existing_union) == set(args) and alias != repr(existing_alias): + if set(existing_union) == set(args) and alias != existing_alias.__name__: union_str = repr(union) raise RuntimeError( - f"`{union_str}` already has alias `{existing_alias!r}`." + f"`{union_str}` already has alias `{existing_alias.__name__}`." ) new_alias = TypeAliasType(alias, union, type_params=()) # type: ignore[misc] @@ -229,6 +219,10 @@ def _transform_union_alias(x: object, /) -> object: type or type hint: If `x` is a Union registered in `_ALIASED_UNIONS`, returns the TypeAliasType. Otherwise returns `x` unchanged. """ + # Fast path: if aliases are not active, return `x` immediately. + if not _ALIASES_ARE_ACTIVE: + return x + # TypeAliasType instances are already transformed, return as-is if isinstance(x, TypeAliasType): return x diff --git a/tests/conftest.py b/tests/conftest.py index d840ee3e..decc75dc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,9 +11,13 @@ @pytest.fixture(autouse=True) def _clean_union_aliases(): """Give each test its own empty alias registry, restored automatically.""" + import plum._alias as _alias_mod from plum._alias import _ALIASED_UNIONS - with patch.dict(_ALIASED_UNIONS, clear=True): + with ( + patch.dict(_ALIASED_UNIONS, clear=True), + patch.object(_alias_mod, "_ALIASES_ARE_ACTIVE", True), + ): yield diff --git a/tests/test_alias_314plus.py b/tests/test_alias_314plus.py index e2788168..dc1d6110 100644 --- a/tests/test_alias_314plus.py +++ b/tests/test_alias_314plus.py @@ -16,16 +16,18 @@ ) -def test_activate_union_aliases_deprecated() -> None: - """Test that activate_union_aliases raises a deprecation warning.""" - with pytest.warns(RuntimeWarning, match="`activate_union_aliases` is deprecated"): - plum.activate_union_aliases() +def test_activate_union_aliases() -> None: + """Test that activate_union_aliases sets _ALIASES_ARE_ACTIVE to True.""" + plum._alias._ALIASES_ARE_ACTIVE = False + plum.activate_union_aliases() + assert plum._alias._ALIASES_ARE_ACTIVE is True -def test_deactivate_union_aliases_deprecated() -> None: - """Test that deactivate_union_aliases raises a deprecation warning.""" - with pytest.warns(RuntimeWarning, match="`deactivate_union_aliases` is deprecated"): - plum.deactivate_union_aliases() +def test_deactivate_union_aliases() -> None: + """Test that deactivate_union_aliases sets _ALIASES_ARE_ACTIVE to False.""" + plum._alias._ALIASES_ARE_ACTIVE = True + plum.deactivate_union_aliases() + assert plum._alias._ALIASES_ARE_ACTIVE is False @pytest.mark.parametrize("union", [int | str, Union[int, str]]) # noqa: UP007 @@ -60,7 +62,7 @@ def test_repr_short_uses_alias(union) -> None: def test_union_alias(display): plum.set_union_alias(int | str, alias="IntStr") - # Check that printing is normal before registering any aliases. + # Check that the alias is used after registration. assert display(Union[int, str]) == "IntStr" # noqa: UP007 # Register a simple alias and check that it prints correctly. @@ -77,7 +79,7 @@ def test_union_alias(display): def test_uniontype_alias(display): plum.set_union_alias(int | str, alias="IntStr") - # Check that printing is normal before registering any aliases. + # Check that the alias is used after registration. assert display(int | str) == "IntStr" # Register a simple alias and check that it prints correctly. From 0d2ed52aeed577026d1f64e2101c9beb097a5a54 Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Fri, 13 Mar 2026 17:12:42 -0400 Subject: [PATCH 6/6] Apply suggestions from code review Co-authored-by: Wessel Co-authored-by: Nathaniel Starkman Signed-off-by: nstarman --- docs/union_aliases.md | 312 +++++++++---------- src/plum/_alias.py | 482 ++++++++++++++--------------- src/plum/repr.py | 4 +- tests/test_alias_314plus.py | 598 ++++++++++++++++++------------------ tests/test_alias_upto313.py | 12 +- 5 files changed, 705 insertions(+), 703 deletions(-) diff --git a/docs/union_aliases.md b/docs/union_aliases.md index 22c9ec8e..6fc8fc2f 100644 --- a/docs/union_aliases.md +++ b/docs/union_aliases.md @@ -1,155 +1,157 @@ -(union-aliases)= -# Union Aliases - -To understand what union aliases are and what problem they solve, consider the -following example. -Suppose that we would want to implement a special addition function, and we would -want to implement it for all NumPy scalar types: - -```python -import numpy as np - -from typing import Union -from plum import dispatch - - -scalar_types = tuple(np.sctypeDict.values()) # All NumPy scalar types -Scalar = Union[scalar_types] # Union of all NumPy scalar types - - -@dispatch -def add(x: Scalar, y: Scalar): - return x + y -``` - -This looks all fine, until you look at the documentation. -In particular, `help(add)` prints - - -``` -Help on Function in module __main__: - -add(x: Union[numpy.int8, numpy.int16, numpy.int32, numpy.int64, numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64, numpy.float16, numpy.float32, numpy.float64, numpy.float128, numpy.complex64, numpy.complex128, numpy.complex256, bool, object, bytes, str, numpy.void], y: Union[numpy.int8, numpy.int16, numpy.int32, numpy.int64, numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64, numpy.float16, numpy.float32, numpy.float64, numpy.float128, numpy.complex64, numpy.complex128, numpy.complex256, bool, object, bytes, str, numpy.void]) -``` - -While the documentation is accurate, it is not at all helpful to expand the union in -its many elements, because it obscures the key message: `add(x, y)` is implemented -for all _scalars_. -A better option would be to print `add(x: Scalar, y: Scalar)`. -This is precisely what union aliases do: -by aliasing a union, you change the way it is displayed. -Union aliases must be activated explicitly, because the feature -monkeypatches `Union.__str__` and `Union.__repr__`. - -% invisible-code-block: python -% -% import sys - -% skip: start if(sys.version_info < (3, 14), reason="Union repr changed in Python 3.14+") - -```python ->>> from plum import set_union_alias ->>> set_union_alias(Scalar, alias="Scalar") -numpy.bool | numpy.float16 | ... -``` - -% skip: end - -% skip: start if(sys.version_info >= (3, 14), reason="Union repr changed in Python 3.14+") - -```python ->>> from plum import activate_union_aliases, set_union_alias - ->>> activate_union_aliases() - ->>> set_union_alias(Scalar, alias="Scalar") -typing.Union[Scalar] -``` - -% skip: end - -After this, `help(add)` now prints the following: - -% skip: next "Example" - -```python -Help on Function in module __main__: - -add(x: Union[Scalar], y: Union[Scalar]) -``` - -Hurray! -Note that the documentation prints `Union[Scalar]` rather than just `Scalar`. -This is intentional: it is to prevent breaking code that depends on how unions -print. -For example, printing just `Scalar` would omit the type parameter(s). - -Let's see with a few more examples how this works: - -% invisible-code-block: python -% -% import sys - -% skip: start if(sys.version_info < (3, 14), reason="Union repr changed in Python 3.14+") - -```python ->>> Scalar -numpy.bool | numpy.float16 | ... - ->>> Union[tuple(scalar_types)] -numpy.bool | numpy.float16 | ... - ->>> Union[tuple(scalar_types) + (tuple,)] # Scalar or tuple -numpy.bool | numpy.float16 | ... | tuple - ->>> Union[tuple(scalar_types) + (tuple, list)] # Scalar or tuple or list -numpy.bool | numpy.float16 | ... | tuple | list -``` - -% skip: end - -% skip: start if(sys.version_info >= (3, 14), reason="Union repr changed in Python 3.14+") - -```python ->>> Scalar -typing.Union[Scalar] - ->>> Union[tuple(scalar_types)] -typing.Union[Scalar] - ->>> Union[tuple(scalar_types) + (tuple,)] # Scalar or tuple -typing.Union[Scalar, tuple] - ->>> Union[tuple(scalar_types) + (tuple, list)] # Scalar or tuple or list -typing.Union[Scalar, tuple, list] -``` - -% skip: end - -If we don't include all of `scalar_types`, we won't see `Scalar`, as desired: - -% invisible-code-block: python -% -% import sys - -% skip: next "Result depends on NumPy version." - -```python ->>> Union[tuple(scalar_types[:-1])] -typing.Union[numpy.int8, numpy.int16, numpy.int32, numpy.longlong, numpy.int64, numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64, numpy.ulonglong, numpy.float16, numpy.float32, numpy.float64, numpy.longdouble, numpy.complex64, numpy.complex128, numpy.clongdouble, numpy.str_, numpy.bytes_, numpy.void, numpy.bool] -``` - -You can deactivate union aliases with `deactivate_union_aliases`: - -```python ->>> import warnings ->>> from plum import deactivate_union_aliases - ->>> with warnings.catch_warnings(): -... warnings.simplefilter("ignore") -... deactivate_union_aliases() - -% skip: next "Result depends on NumPy version." ->>> Scalar -typing.Union[numpy.int8, numpy.int16, numpy.int32, numpy.longlong, numpy.int64, numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64, numpy.ulonglong, numpy.float16, numpy.float32, numpy.float64, numpy.longdouble, numpy.complex64, numpy.complex128, numpy.clongdouble, numpy.str_, numpy.bytes_, numpy.void, numpy.bool, numpy.object_] -``` +(union-aliases)= +# Union Aliases + +To understand what union aliases are and what problem they solve, consider the +following example. +Suppose that we would want to implement a special addition function, and we would +want to implement it for all NumPy scalar types: + +```python +import numpy as np + +from typing import Union +from plum import dispatch + + +scalar_types = tuple(np.sctypeDict.values()) # All NumPy scalar types +Scalar = Union[scalar_types] # Union of all NumPy scalar types + + +@dispatch +def add(x: Scalar, y: Scalar): + return x + y +``` + +This looks all fine, until you look at the documentation. +In particular, `help(add)` prints + + +``` +Help on Function in module __main__: + +add(x: Union[numpy.int8, numpy.int16, numpy.int32, numpy.int64, numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64, numpy.float16, numpy.float32, numpy.float64, numpy.float128, numpy.complex64, numpy.complex128, numpy.complex256, bool, object, bytes, str, numpy.void], y: Union[numpy.int8, numpy.int16, numpy.int32, numpy.int64, numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64, numpy.float16, numpy.float32, numpy.float64, numpy.float128, numpy.complex64, numpy.complex128, numpy.complex256, bool, object, bytes, str, numpy.void]) +``` + +While the documentation is accurate, it is not at all helpful to expand the +union in its many elements, because it obscures the key message: `add(x, y)` is +implemented for all _scalars_. A better option would be to print `add(x: +Scalar, y: Scalar)`. This is precisely what union aliases do: by aliasing a +union, you change the way it is displayed. On Python 3.13 and earlier, union +aliases work by monkeypatching `typing.Union.__str__` and +`typing.Union.__repr__`, and therefore must be activated explicitly. On Python +3.14 and later, `typing.Union`'s representation can no longer be monkeypatched; +union aliases instead only affect how Plum formats unions in its own printed +output. + +% invisible-code-block: python +% +% import sys + +% skip: start if(sys.version_info < (3, 14), reason="Union repr changed in Python 3.14+") + +```python +>>> from plum import set_union_alias + +>>> set_union_alias(Scalar, alias="Scalar") +numpy.bool | numpy.float16 | ... +``` + +% skip: end + +% skip: start if(sys.version_info >= (3, 14), reason="Representation of unions changed in Python 3.14.") + +```python +>>> from plum import activate_union_aliases, set_union_alias + +>>> activate_union_aliases() + +>>> set_union_alias(Scalar, alias="Scalar") +typing.Union[Scalar] +``` + +% skip: end + +After this, `help(add)` now prints the following: + +% skip: next "Example" + +```python +Help on Function in module __main__: + +add(x: Union[Scalar], y: Union[Scalar]) +``` + +Hurray! +Note that the documentation prints `Union[Scalar]` rather than just `Scalar`. +This is intentional: it is to prevent breaking code that depends on how unions +print. +For example, printing just `Scalar` would omit the type parameter(s). + +Let's see with a few more examples how this works: + +% invisible-code-block: python +% +% import sys + +% skip: start if(sys.version_info < (3, 14), reason="Representation of unions changed in Python 3.14.") + +```python +>>> Scalar +numpy.bool | numpy.float16 | ... + +>>> Union[tuple(scalar_types)] +numpy.bool | numpy.float16 | ... + +>>> Union[tuple(scalar_types) + (tuple,)] # Scalar or tuple +numpy.bool | numpy.float16 | ... | tuple + +>>> Union[tuple(scalar_types) + (tuple, list)] # Scalar or tuple or list +numpy.bool | numpy.float16 | ... | tuple | list +``` + +% skip: end + +% skip: start if(sys.version_info >= (3, 14), reason="Representation of unions changed in Python 3.14.") + +```python +>>> Scalar +typing.Union[Scalar] + +>>> Union[tuple(scalar_types)] +typing.Union[Scalar] + +>>> Union[tuple(scalar_types) + (tuple,)] # Scalar or tuple +typing.Union[Scalar, tuple] + +>>> Union[tuple(scalar_types) + (tuple, list)] # Scalar or tuple or list +typing.Union[Scalar, tuple, list] +``` + +% skip: end + +If we don't include all of `scalar_types`, we won't see `Scalar`, as desired: + +% invisible-code-block: python +% +% import sys + +% skip: next "Result depends on NumPy version." + +```python +>>> Union[tuple(scalar_types[:-1])] +typing.Union[numpy.int8, numpy.int16, numpy.int32, numpy.longlong, numpy.int64, numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64, numpy.ulonglong, numpy.float16, numpy.float32, numpy.float64, numpy.longdouble, numpy.complex64, numpy.complex128, numpy.clongdouble, numpy.str_, numpy.bytes_, numpy.void, numpy.bool] +``` + +You can deactivate union aliases with `deactivate_union_aliases`: + +```python +>>> import warnings + +>>> from plum import deactivate_union_aliases + +>>> deactivate_union_aliases() + +% skip: next "Result depends on NumPy version." +>>> Scalar +typing.Union[numpy.int8, numpy.int16, numpy.int32, numpy.longlong, numpy.int64, numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64, numpy.ulonglong, numpy.float16, numpy.float32, numpy.float64, numpy.longdouble, numpy.complex64, numpy.complex128, numpy.clongdouble, numpy.str_, numpy.bytes_, numpy.void, numpy.bool, numpy.object_] +``` diff --git a/src/plum/_alias.py b/src/plum/_alias.py index 847745fb..3c7d25f0 100644 --- a/src/plum/_alias.py +++ b/src/plum/_alias.py @@ -1,240 +1,242 @@ -"""This module monkey patches `__repr__` and `__str__` of `typing.Union` to control how -`typing.Unions` are displayed. - -Example:: - - >> plum.activate_union_aliases() - - >> IntOrFloat = typing.Union[int, float] - - >> IntOrFloat - Union[int, float] - - >> plum.set_union_alias(IntOrFloat, "IntOrFloat") - - >> IntOrFloat - typing.Union[IntOrFloat] - - >> typing.Union[int, float] - typing.Union[IntOrFloat] - - >> typing.Union[int, float, str] - typing.Union[IntOrFloat, str] - -Note that `IntOrFloat` prints to `typing.Union[IntOrFloat]` rather than just -`IntOrFloat`. This is deliberate, with the goal of not breaking code that relies on -parsing how unions print. -""" - -__all__ = ( - "activate_union_aliases", - "deactivate_union_aliases", - "set_union_alias", -) - -import sys -from functools import wraps -from typing import Any, TypeVar, Union, _type_repr, get_args -from typing_extensions import TypeAliasType - -UnionT = TypeVar("UnionT") - -_union_type = type(Union[int, float]) # noqa: UP007 - -_ALIASES_ARE_ACTIVE: bool = True - -if sys.version_info < (3, 14): # pragma: specific no cover 3.14 - _original_repr = _union_type.__repr__ - _original_str = _union_type.__str__ - - _ALIASED_UNIONS: dict[tuple[Any, ...], str] = {} - - @wraps(_original_repr) - def _new_repr(self: object) -> str: - """Print a `typing.Union`, replacing all aliased unions by their aliased names. - - Returns: - str: Representation of a `typing.Union` taking into account union aliases. - """ - args = get_args(self) - args_set = set(args) - - # Find all aliased unions contained in this union. - found_unions = [] - found_positions = [] - found_aliases = [] - for union, alias in reversed(_ALIASED_UNIONS.items()): - union_set = set(union) - if union_set <= args_set: - found = False - for i, arg in enumerate(args): - if arg in union_set: - found_unions.append(union_set) - found_positions.append(i) - found_aliases.append(alias) - found = True - break - if not found: # pragma: no cover - # This branch should never be reached. - raise AssertionError( - "Could not identify union. This should never happen." - ) - - # Delete any unions that are contained in strictly bigger unions. We - # check for strictly inequality because any union includes itself. - for i in range(len(found_unions) - 1, -1, -1): - for union_ in found_unions: - if found_unions[i] < set(union_): - del found_unions[i] - del found_positions[i] - del found_aliases[i] - break - - # Create a set with all arguments of all found unions. - found_args = set().union(*found_unions) if found_unions else set() - - # Build a mapping from original position to aliases to insert before it. - inserts: dict[int, list[str]] = {} - for pos, alias in zip(found_positions, found_aliases, strict=False): - inserts.setdefault(pos, []).append(alias) - # Interleave aliases at the appropriate positions. - args = tuple( - v for i, arg in enumerate(args) for v in (*inserts.pop(i, []), arg) - ) - - # Filter all elements of unions that are aliased. - args = tuple(arg for arg in args if arg not in found_args) - - # Generate a string representation. - args_repr = [a if isinstance(a, str) else _type_repr(a) for a in args] - # Like `typing` does, print `Optional` whenever possible. - if len(args) == 2: - if args[0] is type(None): # noqa: E721 - return f"typing.Optional[{args_repr[1]}]" - elif args[1] is type(None): # noqa: E721 - return f"typing.Optional[{args_repr[0]}]" - # We would like to just print `args_repr[0]` whenever `len(args) == 1`, but - # this might break code that parses how unions print. - return "typing.Union[" + ", ".join(args_repr) + "]" - - @wraps(_original_str) - def _new_str(self: object) -> str: - """Does the same as :func:`_new_repr`. - - Returns: - str: Representation of the `typing.Union` taking into account union aliases. - """ - return _new_repr(self) - - def activate_union_aliases() -> None: - """When printing `typing.Union`s, replace aliased unions by the aliased names. - This monkey patches `__repr__` and `__str__` for `typing.Union`.""" - global _ALIASES_ARE_ACTIVE - _union_type.__repr__ = _new_repr # type: ignore[method-assign] - _union_type.__str__ = _new_str # type: ignore[method-assign] - _ALIASES_ARE_ACTIVE = True - - def deactivate_union_aliases() -> None: - """Undo what :func:`.alias.activate` did. This restores the original `__repr__` - and `__str__` for `typing.Union`.""" - global _ALIASES_ARE_ACTIVE - _union_type.__repr__ = _original_repr # type: ignore[method-assign] - _union_type.__str__ = _original_str # type: ignore[method-assign] - _ALIASES_ARE_ACTIVE = False - - def set_union_alias(union: UnionT, alias: str) -> UnionT: - """Change how a `typing.Union` is printed. This does not modify `union`. - - Args: - union (type or type hint): A union. - alias (str): How to print `union`. - - Returns: - type or type hint: `union`. - """ - args = get_args(union) if isinstance(union, _union_type) else (union,) - for existing_union, existing_alias in _ALIASED_UNIONS.items(): - if set(existing_union) == set(args) and alias != existing_alias: - if isinstance(union, _union_type): - union_str = _original_str(union) - else: - union_str = repr(union) - raise RuntimeError( - f"`{union_str}` already has alias `{existing_alias}`." - ) - _ALIASED_UNIONS[args] = alias - return union - -else: # pragma: specific no cover 3.13 3.12 3.11 3.10 - _ALIASED_UNIONS: dict[tuple[Any, ...], TypeAliasType] = {} - - def activate_union_aliases() -> None: - """When printing `typing.Union`, replace aliased unions by the aliased names.""" - global _ALIASES_ARE_ACTIVE - _ALIASES_ARE_ACTIVE = True - - def deactivate_union_aliases() -> None: - """When printing `typing.Union`s, print as normal.""" - global _ALIASES_ARE_ACTIVE - _ALIASES_ARE_ACTIVE = False - - def set_union_alias(union: UnionT, /, alias: str) -> UnionT: - """Register a union alias for use in plum's dispatch system. - - When used with plum's dispatch system, the union will be automatically - transformed into a `TypeAliasType` during signature extraction, allowing - dispatch to key off the alias name instead of the union structure. - - Args: - union (type or type hint): A union type or a single type. - alias (str): Alias name for the union. - - """ - # Handle both union types and single types, matching < 3.14 behaviour. - args = get_args(union) if isinstance(union, _union_type) else (union,) - - # Check for conflicting aliases - for existing_union, existing_alias in _ALIASED_UNIONS.items(): - if set(existing_union) == set(args) and alias != existing_alias.__name__: - union_str = repr(union) - raise RuntimeError( - f"`{union_str}` already has alias `{existing_alias.__name__}`." - ) - - new_alias = TypeAliasType(alias, union, type_params=()) # type: ignore[misc] - - _ALIASED_UNIONS[args] = new_alias - - return union - - -def _transform_union_alias(x: object, /) -> object: - """Transform a Union type hint to a TypeAliasType if it's registered in the alias - registry. This is used by plum's dispatch machinery to use aliased names for unions. - - Args: - x (type or type hint): Type hint, potentially a Union. - - Returns: - type or type hint: If `x` is a Union registered in `_ALIASED_UNIONS`, returns - the TypeAliasType. Otherwise returns `x` unchanged. - """ - # Fast path: if aliases are not active, return `x` immediately. - if not _ALIASES_ARE_ACTIVE: - return x - - # TypeAliasType instances are already transformed, return as-is - if isinstance(x, TypeAliasType): - return x - - # Get the union args to check if it's registered - args = get_args(x) if isinstance(x, _union_type) else None - if args: - args_set = set(args) - # Look for a matching alias in the registry - for union_args, type_alias in _ALIASED_UNIONS.items(): - if set(union_args) == args_set: - return type_alias - - # Not a union or not aliased, return as-is - return x +"""This module monkey patches `__repr__` and `__str__` of `typing.Union` to control how +`typing.Unions` are displayed. + +Example:: + + >> plum.activate_union_aliases() + + >> IntOrFloat = typing.Union[int, float] + + >> IntOrFloat + Union[int, float] + + >> plum.set_union_alias(IntOrFloat, "IntOrFloat") + + >> IntOrFloat + typing.Union[IntOrFloat] + + >> typing.Union[int, float] + typing.Union[IntOrFloat] + + >> typing.Union[int, float, str] + typing.Union[IntOrFloat, str] + +Note that `IntOrFloat` prints to `typing.Union[IntOrFloat]` rather than just +`IntOrFloat`. This is deliberate, with the goal of not breaking code that relies on +parsing how unions print. +""" + +__all__ = ( + "activate_union_aliases", + "deactivate_union_aliases", + "set_union_alias", +) + +import sys +from functools import wraps +from typing import Any, TypeVar, Union, _type_repr, get_args +from typing_extensions import TypeAliasType + +UnionT = TypeVar("UnionT") + +_union_type = type(Union[int, float]) # noqa: UP007 + +_ALIASES_ARE_ACTIVE: bool = True + +if sys.version_info < (3, 14): # pragma: specific no cover 3.14 + _original_repr = _union_type.__repr__ + _original_str = _union_type.__str__ + + _ALIASED_UNIONS: dict[tuple[Any, ...], str] = {} + + @wraps(_original_repr) + def _new_repr(self: object) -> str: + """Print a `typing.Union`, replacing all aliased unions by their aliased names. + + Returns: + str: Representation of a `typing.Union` taking into account union aliases. + """ + args = get_args(self) + args_set = set(args) + + # Find all aliased unions contained in this union. + found_unions = [] + found_positions = [] + found_aliases = [] + for union, alias in reversed(_ALIASED_UNIONS.items()): + union_set = set(union) + if union_set <= args_set: + found = False + for i, arg in enumerate(args): + if arg in union_set: + found_unions.append(union_set) + found_positions.append(i) + found_aliases.append(alias) + found = True + break + if not found: # pragma: no cover + # This branch should never be reached. + raise AssertionError( + "Could not identify union. This should never happen." + ) + + # Delete any unions that are contained in strictly bigger unions. We + # check for strictly inequality because any union includes itself. + for i in range(len(found_unions) - 1, -1, -1): + for union_ in found_unions: + if found_unions[i] < set(union_): + del found_unions[i] + del found_positions[i] + del found_aliases[i] + break + + # Create a set with all arguments of all found unions. + found_args = set().union(*found_unions) if found_unions else set() + + # Build a mapping from original position to aliases to insert before it. + inserts: dict[int, list[str]] = {} + for pos, alias in zip(found_positions, found_aliases, strict=False): + inserts.setdefault(pos, []).append(alias) + # Interleave aliases at the appropriate positions. + args = tuple( + v for i, arg in enumerate(args) for v in (*inserts.pop(i, []), arg) + ) + + # Filter all elements of unions that are aliased. + args = tuple(arg for arg in args if arg not in found_args) + + # Generate a string representation. + args_repr = [a if isinstance(a, str) else _type_repr(a) for a in args] + # Like `typing` does, print `Optional` whenever possible. + if len(args) == 2: + if args[0] is type(None): # noqa: E721 + return f"typing.Optional[{args_repr[1]}]" + elif args[1] is type(None): # noqa: E721 + return f"typing.Optional[{args_repr[0]}]" + # We would like to just print `args_repr[0]` whenever `len(args) == 1`, but + # this might break code that parses how unions print. + return "typing.Union[" + ", ".join(args_repr) + "]" + + @wraps(_original_str) + def _new_str(self: object) -> str: + """Does the same as :func:`_new_repr`. + + Returns: + str: Representation of the `typing.Union` taking into account union aliases. + """ + return _new_repr(self) + + def activate_union_aliases() -> None: + """When printing `typing.Union`s, replace aliased unions by the aliased names. + This monkey patches `__repr__` and `__str__` for `typing.Union`.""" + global _ALIASES_ARE_ACTIVE + _union_type.__repr__ = _new_repr # type: ignore[method-assign] + _union_type.__str__ = _new_str # type: ignore[method-assign] + _ALIASES_ARE_ACTIVE = True + + def deactivate_union_aliases() -> None: + """Undo what :func:`.alias.activate` did. This restores the original `__repr__` + and `__str__` for `typing.Union`.""" + global _ALIASES_ARE_ACTIVE + _union_type.__repr__ = _original_repr # type: ignore[method-assign] + _union_type.__str__ = _original_str # type: ignore[method-assign] + _ALIASES_ARE_ACTIVE = False + + def set_union_alias(union: UnionT, alias: str) -> UnionT: + """Change how a `typing.Union` is printed. This does not modify `union`. + + Args: + union (type or type hint): A union. + alias (str): How to print `union`. + + Returns: + type or type hint: `union`. + """ + args = get_args(union) if isinstance(union, _union_type) else (union,) + for existing_union, existing_alias in _ALIASED_UNIONS.items(): + if set(existing_union) == set(args) and alias != existing_alias: + if isinstance(union, _union_type): + union_str = _original_str(union) + else: + union_str = repr(union) + raise RuntimeError( + f"`{union_str}` already has alias `{existing_alias}`." + ) + _ALIASED_UNIONS[args] = alias + return union + +else: # pragma: specific no cover 3.13 3.12 3.11 3.10 + _ALIASED_UNIONS: dict[tuple[Any, ...], TypeAliasType] = {} + + def activate_union_aliases() -> None: + """When printing `typing.Union`, replace aliased unions by the aliased names.""" + global _ALIASES_ARE_ACTIVE + _ALIASES_ARE_ACTIVE = True + + def deactivate_union_aliases() -> None: + """When printing `typing.Union`s, print as normal.""" + global _ALIASES_ARE_ACTIVE + _ALIASES_ARE_ACTIVE = False + + def set_union_alias(union: UnionT, /, alias: str) -> UnionT: + """Register a union alias for use in plum's printing of dispatch signatures. + + This does not modify the given `union` in any way. It only controls how + the union is printed when it is registered as a union alias. + + Args: + union (type or type hint): A union type or a single type. + alias (str): Alias name for the union. + + Returns: + type or type hint: The given union. + + """ + # Handle both union types and single types, matching < 3.14 behaviour. + args = get_args(union) if isinstance(union, _union_type) else (union,) + + # Check for conflicting aliases + for existing_union, existing_alias in _ALIASED_UNIONS.items(): + if set(existing_union) == set(args) and alias != existing_alias.__name__: + union_str = repr(union) + raise RuntimeError( + f"`{union_str}` already has alias `{existing_alias.__name__}`." + ) + + new_alias = TypeAliasType(alias, union, type_params=()) # type: ignore[misc] + + _ALIASED_UNIONS[args] = new_alias + + return union + + +def _transform_union_alias(x: object, /) -> object: + """Transform a Union type hint to a TypeAliasType if it's registered in the alias + registry. This is used by plum's dispatch machinery to use aliased names for unions. + + Args: + x (type or type hint): Type hint, potentially a Union. + + Returns: + type or type hint: If `x` is a Union registered in `_ALIASED_UNIONS`, returns + the TypeAliasType. Otherwise returns `x` unchanged. + """ + # Fast path: if aliases are not active, return `x` immediately. + if not _ALIASES_ARE_ACTIVE: + return x + + # TypeAliasType instances are already transformed, return as-is + if isinstance(x, TypeAliasType): + return x + + # Get the union args to check if it's registered + args = get_args(x) if isinstance(x, _union_type) else None + if args: + args_set = set(args) + # Look for a matching alias in the registry + for union_args, type_alias in _ALIASED_UNIONS.items(): + if set(union_args) == args_set: + return type_alias + + # Not a union or not aliased, return as-is + return x diff --git a/src/plum/repr.py b/src/plum/repr.py index 10f6597c..fe53aedb 100644 --- a/src/plum/repr.py +++ b/src/plum/repr.py @@ -65,7 +65,7 @@ def repr_short(x: object, /) -> str: """Representation as a string, but in shorter form. This just calls :func:`typing._type_repr`. - If the type is a union registered in plum's alias registry, the alias name + If the type is a union registered in Plum's alias registry, the alias name is used instead. Args: @@ -75,7 +75,7 @@ def repr_short(x: object, /) -> str: str: Shorter representation of `x`. """ if isinstance(transformed := _transform_union_alias(x), TypeAliasType): - # It's an aliased union — use the alias name + # It's an aliased union — use the alias name. return str(transformed.__name__) # :func:`typing._type_repr` is an internal function, but it should be # available in Python versions 3.9 through 3.14. diff --git a/tests/test_alias_314plus.py b/tests/test_alias_314plus.py index dc1d6110..23b047d4 100644 --- a/tests/test_alias_314plus.py +++ b/tests/test_alias_314plus.py @@ -1,299 +1,299 @@ -import functools as ft -import sys -from typing import Union - -import pytest - -import beartype -import beartype.door - -import plum - -# These tests are for Python >= 3.14 only. -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 14), - reason="Union aliasing tests for Python >= 3.14", -) - - -def test_activate_union_aliases() -> None: - """Test that activate_union_aliases sets _ALIASES_ARE_ACTIVE to True.""" - plum._alias._ALIASES_ARE_ACTIVE = False - plum.activate_union_aliases() - assert plum._alias._ALIASES_ARE_ACTIVE is True - - -def test_deactivate_union_aliases() -> None: - """Test that deactivate_union_aliases sets _ALIASES_ARE_ACTIVE to False.""" - plum._alias._ALIASES_ARE_ACTIVE = True - plum.deactivate_union_aliases() - assert plum._alias._ALIASES_ARE_ACTIVE is False - - -@pytest.mark.parametrize("union", [int | str, Union[int, str]]) # noqa: UP007 -def test_repr_short_uses_alias(union) -> None: - """Test that repr_short substitutes registered union aliases.""" - plum.set_union_alias(union, alias="IntStr") - - # Aliased union should use the alias name - assert plum.repr.repr_short(int | str) == "IntStr" - assert plum.repr.repr_short(Union[int, str]) == "IntStr" # noqa: UP007 - - # Non-aliased unions should be unchanged - assert "IntStr" not in plum.repr.repr_short(int | float) - assert "IntStr" not in plum.repr.repr_short(Union[int, float]) # noqa: UP007 - - # Plain types should be unchanged - assert plum.repr.repr_short(int) == "int" - assert plum.repr.repr_short(float) == "float" - - # Signature printing should use the alias - sig = plum.Signature(int | str, float) - assert "IntStr" in repr(sig) - assert repr(sig) == "Signature(IntStr, float)" - - # Signature printing should use the alias - sig = plum.Signature(Union[int, str], float) # noqa: UP007 - assert "IntStr" in repr(sig) - assert repr(sig) == "Signature(IntStr, float)" - - -@pytest.mark.parametrize("display", [plum.repr.repr_short]) -def test_union_alias(display): - plum.set_union_alias(int | str, alias="IntStr") - - # Check that the alias is used after registration. - assert display(Union[int, str]) == "IntStr" # noqa: UP007 - - # Register a simple alias and check that it prints correctly. - IntStr = plum.set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 - assert display(IntStr) == "IntStr" # noqa: UP007 - assert display(Union[int, str]) == "IntStr" # noqa: UP007 - - # Register a bigger alias. - plum.set_union_alias(Union[int, str, float], alias="IntStrFloat") # noqa: UP007 - assert display(Union[int, str, float]) == "IntStrFloat" # noqa: UP007 - - -@pytest.mark.parametrize("display", [plum.repr.repr_short]) -def test_uniontype_alias(display): - plum.set_union_alias(int | str, alias="IntStr") - - # Check that the alias is used after registration. - assert display(int | str) == "IntStr" - - # Register a simple alias and check that it prints correctly. - IntStr = plum.set_union_alias(int | str, alias="IntStr") # noqa: UP007 - assert display(IntStr) == "IntStr" # noqa: UP007 - assert display(int | str) == "IntStr" # noqa: UP007 - - # Register a bigger alias. - plum.set_union_alias(int | str | float, alias="IntStrFloat") # noqa: UP007 - assert display(int | str | float) == "IntStrFloat" # noqa: UP007 - - -def test_repr_short_with_type_alias_type_passthrough(): - """Test that repr_short handles a TypeAliasType passed directly (not as a union). - - This exercises the early-return path in _transform_union_alias where the - input is already a TypeAliasType instance. - """ - from typing_extensions import TypeAliasType - - alias = TypeAliasType("MyAlias", int | str) - assert plum.repr.repr_short(alias) == "MyAlias" - - -def test_optional(): - assert repr(Union[int, None]) == "int | None" # noqa: UP007 - assert repr(Union[None, int]) == "None | int" # noqa: UP007 - assert repr(int | None) == "int | None" - - -def test_double_registration_union_same_alias() -> None: - """Test that registering the same union with the same alias is OK.""" - plum.set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 - plum.set_union_alias(Union[int, str], alias="IntStr") # This is OK # noqa: UP007 - - -def test_double_registration_uniontype_same_alias() -> None: - """Test that registering the same union with the same alias is OK.""" - plum.set_union_alias(int | str, alias="IntStr") - plum.set_union_alias(int | str, alias="IntStr") # This is OK - - -def test_double_registration_different_alias() -> None: - """Test that registering the same union with a different alias raises an error.""" - plum.set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 - with pytest.raises(RuntimeError, match=r"already has alias"): - plum.set_union_alias(Union[int, str], alias="OtherIntStr") # noqa: UP007 - - -def test_double_registration(): - # We can register with the same alias, but not with a different alias. - plum.set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 - with pytest.raises(RuntimeError, match=r"already has alias"): - plum.set_union_alias(Union[int, str], alias="OtherIntStr") # noqa: UP007 - with pytest.raises(RuntimeError, match=r"already has alias"): - plum.set_union_alias(int | str, alias="OtherIntStr") - - # The same applies for Union types. - plum.set_union_alias(str | None, alias="OptStr") - with pytest.raises(RuntimeError, match=r"already has alias"): - plum.set_union_alias(str | None, alias="OtherOptStr") # noqa: UP007 - with pytest.raises(RuntimeError, match=r"already has alias"): - plum.set_union_alias(Union[str | None], alias="OtherIntStr") # noqa: UP007 - - # We can also register plain types, but the same rules apply. - plum.set_union_alias(int, alias="MyInt") - plum.set_union_alias(int, alias="MyInt") # This is OK. - with pytest.raises(RuntimeError, match=r"already has alias"): - plum.set_union_alias(int, alias="MyOtherInt") - - -def test_set_union_alias_generated_type_alias() -> None: - """Test that set_union_alias generates a TypeAliasType for unions.""" - plum.set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 - - IntStr = plum._alias._ALIASED_UNIONS[(int, str)] - - # The returned value should be a TypeAliasType - assert hasattr(IntStr, "__name__") - assert IntStr.__name__ == "IntStr" - assert hasattr(IntStr, "__value__") - # The underlying value should be the union - from typing import get_args - - assert set(get_args(IntStr.__value__)) == {int, str} - - -def test_dispatch_with_union_alias(dispatch: plum.Dispatcher) -> None: - """Test that dispatch works correctly with union aliases.""" - # Register an alias for Union[int, str] - plum.set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 - - # Define a function using the alias - @dispatch - def process(x: int | str) -> str: - return "int or str" - - @dispatch - def process(x: float) -> str: - return "float" - - # Test dispatch - assert process(42) == "int or str" - assert process("hello") == "int or str" - assert process(3.14) == "float" - - -def test_dispatch_with_union_directly(dispatch: plum.Dispatcher) -> None: - """Test that dispatch works when using Union directly if registered.""" - # Register the alias - plum.set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 - - # Define a function using Union directly (should still work) - @dispatch - def process(x: int | str) -> str: - return "int or str" - - @dispatch - def process(x: float) -> str: - return "float" - - # Test dispatch - assert process(42) == "int or str" - assert process("hello") == "int or str" - assert process(3.14) == "float" - - -def test_signature_printing_with_alias(dispatch: plum.Dispatcher) -> None: - """Test that function signatures are nicely printed with TypeAliasType names.""" - # Register an alias - plum.set_union_alias(int | str, alias="IntStr") - - @dispatch - def example(x: int | str, y: float) -> str: - return "test" - - # Check that the signature contains the alias name - # The signature should show "IntStr" rather than "Union[int, str]" - sig_str = str(example.methods[0].signature) - assert "IntStr" in sig_str - - -def test_beartype_strict_mode_compatibility(dispatch: plum.Dispatcher) -> None: - """Test that strict beartype works with plum dispatch on aliased unions.""" - original_is_bearable = plum._is_bearable - - # Temporarily set strict mode for this test - plum._is_bearable = ft.partial( - beartype.door.is_bearable, - conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On), - ) - - try: - # Register an alias - plum.set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 - - # Define a function using the alias with plum dispatch - @dispatch - def strict_process(x: int | str) -> str: - return f"processed: {x}" - - # These should work - assert strict_process(42) == "processed: 42" - assert strict_process("hello") == "processed: hello" - - # This should not match the signature (float is not in IntStr) - with pytest.raises(plum.NotFoundLookupError): - strict_process(3.14) - finally: - # Restore original - plum._is_bearable = original_is_bearable - - -def test_multiple_aliases_in_signature(dispatch: plum.Dispatcher) -> None: - """Test that multiple aliased unions in the same signature work correctly.""" - # Register multiple aliases - plum.set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 - plum.set_union_alias(Union[float, bool], alias="FloatBool") # noqa: UP007 - - @dispatch - def multi(x: int | str, y: float | bool) -> str: - return f"{x}, {y}" - - # Test various combinations - assert multi(42, 3.14) == "42, 3.14" - assert multi("hello", True) == "hello, True" - assert multi(100, False) == "100, False" - assert multi("test", 2.5) == "test, 2.5" - - -def test_alias_in_method_repr(dispatch: plum.Dispatcher) -> None: - """Test that aliased union names appear in `method.methods` repr.""" - plum.set_union_alias(int | str, alias="IntOrStr") - - @dispatch - def method(x: int | str) -> str: - return f"Integer: {x}" - - assert "IntOrStr" in repr(method.methods) - - -def test_alias_priority_in_dispatch(dispatch: plum.Dispatcher) -> None: - """Test that aliased unions are treated like a union in dispatch.""" - # Register an alias - plum.set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 - - @dispatch - def handle(x: int | str) -> str: - return "alias" - - @dispatch - def handle(x: int) -> str: - return "int" - - # The more specific 'int' should match first - assert handle(42) == "int" - assert handle("hello") == "alias" +import functools as ft +import sys +from typing import Union + +import pytest + +import beartype +import beartype.door + +import plum + +# These tests are for Python >= 3.14 only. +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 14), + reason="Union aliasing tests for Python >= 3.14", +) + + +def test_activate_union_aliases() -> None: + """Test that activate_union_aliases sets _ALIASES_ARE_ACTIVE to True.""" + plum._alias._ALIASES_ARE_ACTIVE = False + plum.activate_union_aliases() + assert plum._alias._ALIASES_ARE_ACTIVE is True + + +def test_deactivate_union_aliases() -> None: + """Test that deactivate_union_aliases sets _ALIASES_ARE_ACTIVE to False.""" + plum._alias._ALIASES_ARE_ACTIVE = True + plum.deactivate_union_aliases() + assert plum._alias._ALIASES_ARE_ACTIVE is False + + +@pytest.mark.parametrize("union", [int | str, Union[int, str]]) # noqa: UP007 +def test_repr_short_uses_alias(union) -> None: + """Test that repr_short substitutes registered union aliases.""" + plum.set_union_alias(union, alias="IntStr") + + # Aliased union should use the alias name + assert plum.repr.repr_short(int | str) == "IntStr" + assert plum.repr.repr_short(Union[int, str]) == "IntStr" # noqa: UP007 + + # Non-aliased unions should be unchanged + assert "IntStr" not in plum.repr.repr_short(int | float) + assert "IntStr" not in plum.repr.repr_short(Union[int, float]) # noqa: UP007 + + # Plain types should be unchanged + assert plum.repr.repr_short(int) == "int" + assert plum.repr.repr_short(float) == "float" + + # Signature printing should use the alias + sig = plum.Signature(int | str, float) + assert "IntStr" in repr(sig) + assert repr(sig) == "Signature(IntStr, float)" + + # Signature printing should use the alias + sig = plum.Signature(Union[int, str], float) # noqa: UP007 + assert "IntStr" in repr(sig) + assert repr(sig) == "Signature(IntStr, float)" + + +@pytest.mark.parametrize("display", [plum.repr.repr_short]) +def test_union_alias(display): + plum.set_union_alias(int | str, alias="IntStr") + + # Check that the alias is used after registration. + assert display(Union[int, str]) == "IntStr" # noqa: UP007 + + # Register a simple alias and check that it prints correctly. + IntStr = plum.set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 + assert display(IntStr) == "IntStr" # noqa: UP007 + assert display(Union[int, str]) == "IntStr" # noqa: UP007 + + # Register a bigger alias. + plum.set_union_alias(Union[int, str, float], alias="IntStrFloat") # noqa: UP007 + assert display(Union[int, str, float]) == "IntStrFloat" # noqa: UP007 + + +@pytest.mark.parametrize("display", [plum.repr.repr_short]) +def test_uniontype_alias(display): + plum.set_union_alias(int | str, alias="IntStr") + + # Check that the alias is used after registration. + assert display(int | str) == "IntStr" + + # Register a simple alias and check that it prints correctly. + IntStr = plum.set_union_alias(int | str, alias="IntStr") # noqa: UP007 + assert display(IntStr) == "IntStr" # noqa: UP007 + assert display(int | str) == "IntStr" # noqa: UP007 + + # Register a bigger alias. + plum.set_union_alias(int | str | float, alias="IntStrFloat") # noqa: UP007 + assert display(int | str | float) == "IntStrFloat" # noqa: UP007 + + +def test_repr_short_with_type_alias_type_passthrough(): + """Test that repr_short handles a TypeAliasType passed directly (not as a union). + + This exercises the early-return path in _transform_union_alias where the + input is already a TypeAliasType instance. + """ + from typing_extensions import TypeAliasType + + alias = TypeAliasType("MyAlias", int | str) + assert plum.repr.repr_short(alias) == "MyAlias" + + +def test_optional(): + assert repr(Union[int, None]) == "int | None" # noqa: UP007 + assert repr(Union[None, int]) == "None | int" # noqa: UP007 + assert repr(int | None) == "int | None" + + +def test_double_registration_union_same_alias() -> None: + """Test that registering the same union with the same alias is OK.""" + plum.set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 + plum.set_union_alias(Union[int, str], alias="IntStr") # This is OK # noqa: UP007 + + +def test_double_registration_uniontype_same_alias() -> None: + """Test that registering the same union with the same alias is OK.""" + plum.set_union_alias(int | str, alias="IntStr") + plum.set_union_alias(int | str, alias="IntStr") # This is OK + + +def test_double_registration_different_alias() -> None: + """Test that registering the same union with a different alias raises an error.""" + plum.set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 + with pytest.raises(RuntimeError, match=r"already has alias"): + plum.set_union_alias(Union[int, str], alias="OtherIntStr") # noqa: UP007 + + +def test_double_registration(): + # We can register with the same alias, but not with a different alias. + plum.set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 + with pytest.raises(RuntimeError, match=r"already has alias"): + plum.set_union_alias(Union[int, str], alias="OtherIntStr") # noqa: UP007 + with pytest.raises(RuntimeError, match=r"already has alias"): + plum.set_union_alias(int | str, alias="OtherIntStr") + + # The same applies for Union types. + plum.set_union_alias(str | None, alias="OptStr") + with pytest.raises(RuntimeError, match=r"already has alias"): + plum.set_union_alias(str | None, alias="OtherOptStr") # noqa: UP007 + with pytest.raises(RuntimeError, match=r"already has alias"): + plum.set_union_alias(Union[str | None], alias="OtherIntStr") # noqa: UP007 + + # We can also register plain types, but the same rules apply. + plum.set_union_alias(int, alias="MyInt") + plum.set_union_alias(int, alias="MyInt") # This is OK. + with pytest.raises(RuntimeError, match=r"already has alias"): + plum.set_union_alias(int, alias="MyOtherInt") + + +def test_set_union_alias_generated_type_alias() -> None: + """Test that set_union_alias generates a TypeAliasType for unions.""" + plum.set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 + + IntStr = plum._alias._ALIASED_UNIONS[(int, str)] + + # The returned value should be a TypeAliasType + assert hasattr(IntStr, "__name__") + assert IntStr.__name__ == "IntStr" + assert hasattr(IntStr, "__value__") + # The underlying value should be the union + from typing import get_args + + assert set(get_args(IntStr.__value__)) == {int, str} + + +def test_dispatch_with_union_alias(dispatch: plum.Dispatcher) -> None: + """Test that dispatch works correctly with union aliases.""" + # Register an alias for Union[int, str] + plum.set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 + + # Define a function using the alias + @dispatch + def process(x: int | str) -> str: + return "int or str" + + @dispatch + def process(x: float) -> str: + return "float" + + # Test dispatch + assert process(42) == "int or str" + assert process("hello") == "int or str" + assert process(3.14) == "float" + + +def test_dispatch_with_union_directly(dispatch: plum.Dispatcher) -> None: + """Test that dispatch works when using Union directly if registered.""" + # Register the alias + plum.set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 + + # Define a function using Union directly (should still work) + @dispatch + def process(x: int | str) -> str: + return "int or str" + + @dispatch + def process(x: float) -> str: + return "float" + + # Test dispatch + assert process(42) == "int or str" + assert process("hello") == "int or str" + assert process(3.14) == "float" + + +def test_signature_printing_with_alias(dispatch: plum.Dispatcher) -> None: + """Test that function signatures are nicely printed with TypeAliasType names.""" + # Register an alias + plum.set_union_alias(int | str, alias="IntStr") + + @dispatch + def example(x: int | str, y: float) -> str: + return "test" + + # Check that the signature contains the alias name + # The signature should show "IntStr" rather than "Union[int, str]" + sig_str = str(example.methods[0].signature) + assert "IntStr" in sig_str + + +def test_beartype_strict_mode_compatibility(dispatch: plum.Dispatcher) -> None: + """Test that strict beartype works with plum dispatch on aliased unions.""" + original_is_bearable = plum._is_bearable + + # Temporarily set strict mode for this test + plum._is_bearable = ft.partial( + beartype.door.is_bearable, + conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On), + ) + + try: + # Register an alias + plum.set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 + + # Define a function using the alias with plum dispatch + @dispatch + def strict_process(x: int | str) -> str: + return f"processed: {x}" + + # These should work + assert strict_process(42) == "processed: 42" + assert strict_process("hello") == "processed: hello" + + # This should not match the signature (float is not in IntStr) + with pytest.raises(plum.NotFoundLookupError): + strict_process(3.14) + finally: + # Restore original + plum._is_bearable = original_is_bearable + + +def test_multiple_aliases_in_signature(dispatch: plum.Dispatcher) -> None: + """Test that multiple aliased unions in the same signature work correctly.""" + # Register multiple aliases + plum.set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 + plum.set_union_alias(Union[float, bool], alias="FloatBool") # noqa: UP007 + + @dispatch + def multi(x: int | str, y: float | bool) -> str: + return f"{x}, {y}" + + # Test various combinations + assert multi(42, 3.14) == "42, 3.14" + assert multi("hello", True) == "hello, True" + assert multi(100, False) == "100, False" + assert multi("test", 2.5) == "test, 2.5" + + +def test_alias_in_method_repr(dispatch: plum.Dispatcher) -> None: + """Test that aliased union names appear in `method.methods` repr.""" + plum.set_union_alias(int | str, alias="IntOrStr") + + @dispatch + def method(x: int | str) -> str: + return f"Integer: {x}" + + assert "IntOrStr" in repr(method.methods) + + +def test_alias_priority_in_dispatch(dispatch: plum.Dispatcher) -> None: + """Test that aliased unions are treated like a union in dispatch.""" + # Register an alias + plum.set_union_alias(Union[int, str], alias="IntStr") # noqa: UP007 + + @dispatch + def handle(x: int | str) -> str: + return "alias" + + @dispatch + def handle(x: int) -> str: + return "int" + + # The more specific 'int' should match first + assert handle(42) == "int" + assert handle("hello") == "alias" diff --git a/tests/test_alias_upto313.py b/tests/test_alias_upto313.py index a2b7f5f2..4f5ed82a 100644 --- a/tests/test_alias_upto313.py +++ b/tests/test_alias_upto313.py @@ -7,13 +7,11 @@ from plum import set_union_alias from plum._alias import _ALIASED_UNIONS -# These tests are for Python <= 3.13 only. -pytestmark = [ - pytest.mark.skipif( - sys.version_info >= (3, 14), - reason="Union aliasing tests for Python <= 3.13", - ), -] +# These tests are for Python 3.13 and earlier only. +pytestmark = pytest.mark.skipif( + sys.version_info >= (3, 14), + reason="Union aliasing tests for Python 3.13 and earlier.", +) @pytest.fixture()