Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,18 @@ jobs:
- name: "3.12"
python-version: "3.12"
extra-install: ""
- name: "3.12-pre-beartype"
python-version: "3.12"
extra-install: "uv pip install --upgrade --pre beartype"
- name: "3.13"
python-version: "3.13"
extra-install: ""
- name: "3.13-pre-beartype"
python-version: "3.13"
extra-install: "uv pip install --upgrade --pre beartype"
- name: "3.14"
python-version: "3.14"
extra-install: ""
- name: "3.14-pre-beartype"
python-version: "3.14"
extra-install: "uv pip install --upgrade --pre beartype"

name: Test ${{ matrix.value.name }}
steps:
Expand Down
3 changes: 0 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@ ci:
autoupdate_commit_msg: "chore: update pre-commit hooks"
autofix_commit_msg: "style: pre-commit fixes"

default_language_version:
python: "3.10"

repos:
- repo: meta
hooks:
Expand Down
23 changes: 23 additions & 0 deletions docs/comparison.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,27 @@ def f(x: int, y: Number):
return "second"
```

% invisible-code-block: python
%
% import sys

% skip: start if(sys.version_info < (3, 14), reason="Union repr changed in Python 3.14+")

```python
>>> try: f(1, 1)
... except Exception as e: print(f"{type(e).__name__}: {e}")
AmbiguousLookupError: `f(1, 1)` is ambiguous.
Candidates:
f(x: int | numbers.Number, y: int)
<function f at ...> @ ...
f(x: int, y: numbers.Number)
<function f at ...> @ ...
```

% skip: end

% skip: start if(sys.version_info >= (3, 14), reason="Union repr changed in Python 3.14+")

```python
>>> try: f(1, 1)
... except Exception as e: print(f"{type(e).__name__}: {e}")
Expand All @@ -96,6 +117,8 @@ Candidates:
<function f at ...> @ ...
```
Comment thread
nstarman marked this conversation as resolved.

% skip: end

Just to sanity check that things are indeed working correctly:

```python
Expand Down
265 changes: 157 additions & 108 deletions docs/union_aliases.md
Original file line number Diff line number Diff line change
@@ -1,108 +1,157 @@
(union-aliases)=
# Union Aliases

To understand what union aliases are and what problem they solve, consider the
following example.
Suppose that we would want to implement a special addition function, and we would
want to implement it for all NumPy scalar types:

```python
import numpy as np

from typing import Union
from plum import dispatch


scalar_types = tuple(np.sctypeDict.values()) # All NumPy scalar types
Scalar = Union[scalar_types] # Union of all NumPy scalar types


@dispatch
def add(x: Scalar, y: Scalar):
return x + y
```

This looks all fine, until you look at the documentation.
In particular, `help(add)` prints


```
Help on Function in module __main__:

add(x: Union[numpy.int8, numpy.int16, numpy.int32, numpy.int64, numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64, numpy.float16, numpy.float32, numpy.float64, numpy.float128, numpy.complex64, numpy.complex128, numpy.complex256, bool, object, bytes, str, numpy.void], y: Union[numpy.int8, numpy.int16, numpy.int32, numpy.int64, numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64, numpy.float16, numpy.float32, numpy.float64, numpy.float128, numpy.complex64, numpy.complex128, numpy.complex256, bool, object, bytes, str, numpy.void])
```

While the documentation is accurate, it is not at all helpful to expand the union in
its many elements, because it obscures the key message: `add(x, y)` is implemented
for all _scalars_.
A better option would be to print `add(x: Scalar, y: Scalar)`.
This is precisely what union aliases do:
by aliasing a union, you change the way it is displayed.
Union aliases must be activated explicitly, because the feature
monkeypatches `Union.__str__` and `Union.__repr__`.

```python
>>> from plum import activate_union_aliases, set_union_alias

>>> activate_union_aliases()

>>> set_union_alias(Scalar, alias="Scalar")
typing.Union[Scalar]
```

After this, `help(add)` now prints the following:

% skip: next "Example"

```python
Help on Function in module __main__:

add(x: Union[Scalar], y: Union[Scalar])
```

Hurray!
Note that the documentation prints `Union[Scalar]` rather than just `Scalar`.
This is intentional: it is to prevent breaking code that depends on how unions
print.
For example, printing just `Scalar` would omit the type parameter(s).

Let's see with a few more examples how this works:

```python
>>> Scalar
typing.Union[Scalar]

>>> Union[tuple(scalar_types)]
typing.Union[Scalar]

>>> Union[tuple(scalar_types) + (tuple,)] # Scalar or tuple
typing.Union[Scalar, tuple]

>>> Union[tuple(scalar_types) + (tuple, list)] # Scalar or tuple or list
typing.Union[Scalar, tuple, list]
```

If we don't include all of `scalar_types`, we won't see `Scalar`, as desired:

% invisible-code-block: python
%
% import sys

% skip: next "Result depends on NumPy version."

```python
>>> Union[tuple(scalar_types[:-1])]
typing.Union[numpy.int8, numpy.int16, numpy.int32, numpy.longlong, numpy.int64, numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64, numpy.ulonglong, numpy.float16, numpy.float32, numpy.float64, numpy.longdouble, numpy.complex64, numpy.complex128, numpy.clongdouble, numpy.str_, numpy.bytes_, numpy.void, numpy.bool]
```

You can deactivate union aliases with `deactivate_union_aliases`:

```python
>>> from plum import deactivate_union_aliases

>>> deactivate_union_aliases()

% skip: next "Result depends on NumPy version."
>>> Scalar
typing.Union[numpy.int8, numpy.int16, numpy.int32, numpy.longlong, numpy.int64, numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64, numpy.ulonglong, numpy.float16, numpy.float32, numpy.float64, numpy.longdouble, numpy.complex64, numpy.complex128, numpy.clongdouble, numpy.str_, numpy.bytes_, numpy.void, numpy.bool, numpy.object_]
```
(union-aliases)=
# Union Aliases

To understand what union aliases are and what problem they solve, consider the
following example.
Suppose that we would want to implement a special addition function, and we would
want to implement it for all NumPy scalar types:

```python
import numpy as np

from typing import Union
from plum import dispatch


scalar_types = tuple(np.sctypeDict.values()) # All NumPy scalar types
Scalar = Union[scalar_types] # Union of all NumPy scalar types


@dispatch
def add(x: Scalar, y: Scalar):
return x + y
```

This looks all fine, until you look at the documentation.
In particular, `help(add)` prints


```
Help on Function in module __main__:

add(x: Union[numpy.int8, numpy.int16, numpy.int32, numpy.int64, numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64, numpy.float16, numpy.float32, numpy.float64, numpy.float128, numpy.complex64, numpy.complex128, numpy.complex256, bool, object, bytes, str, numpy.void], y: Union[numpy.int8, numpy.int16, numpy.int32, numpy.int64, numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64, numpy.float16, numpy.float32, numpy.float64, numpy.float128, numpy.complex64, numpy.complex128, numpy.complex256, bool, object, bytes, str, numpy.void])
```

While the documentation is accurate, it is not at all helpful to expand the
union in its many elements, because it obscures the key message: `add(x, y)` is
implemented for all _scalars_. A better option would be to print `add(x:
Scalar, y: Scalar)`. This is precisely what union aliases do: by aliasing a
union, you change the way it is displayed. On Python 3.13 and earlier, union
aliases work by monkeypatching `typing.Union.__str__` and
`typing.Union.__repr__`, and therefore must be activated explicitly. On Python
3.14 and later, `typing.Union`'s representation can no longer be monkeypatched;
union aliases instead only affect how Plum formats unions in its own printed
output.

% invisible-code-block: python
%
% import sys

% skip: start if(sys.version_info < (3, 14), reason="Union repr changed in Python 3.14+")

```python
>>> from plum import set_union_alias

>>> set_union_alias(Scalar, alias="Scalar")
numpy.bool | numpy.float16 | ...
```

% skip: end

% skip: start if(sys.version_info >= (3, 14), reason="Representation of unions changed in Python 3.14.")

```python
>>> from plum import activate_union_aliases, set_union_alias

>>> activate_union_aliases()

>>> set_union_alias(Scalar, alias="Scalar")
typing.Union[Scalar]
```

% skip: end

After this, `help(add)` now prints the following:

% skip: next "Example"

```python
Help on Function in module __main__:

add(x: Union[Scalar], y: Union[Scalar])
```

Hurray!
Note that the documentation prints `Union[Scalar]` rather than just `Scalar`.
This is intentional: it is to prevent breaking code that depends on how unions
print.
For example, printing just `Scalar` would omit the type parameter(s).

Let's see with a few more examples how this works:

% invisible-code-block: python
%
% import sys

% skip: start if(sys.version_info < (3, 14), reason="Representation of unions changed in Python 3.14.")

```python
>>> Scalar
numpy.bool | numpy.float16 | ...

>>> Union[tuple(scalar_types)]
numpy.bool | numpy.float16 | ...

>>> Union[tuple(scalar_types) + (tuple,)] # Scalar or tuple
numpy.bool | numpy.float16 | ... | tuple

>>> Union[tuple(scalar_types) + (tuple, list)] # Scalar or tuple or list
numpy.bool | numpy.float16 | ... | tuple | list
```

% skip: end

% skip: start if(sys.version_info >= (3, 14), reason="Representation of unions changed in Python 3.14.")

```python
>>> Scalar
typing.Union[Scalar]

>>> Union[tuple(scalar_types)]
typing.Union[Scalar]

>>> Union[tuple(scalar_types) + (tuple,)] # Scalar or tuple
typing.Union[Scalar, tuple]

>>> Union[tuple(scalar_types) + (tuple, list)] # Scalar or tuple or list
typing.Union[Scalar, tuple, list]
```

% skip: end

If we don't include all of `scalar_types`, we won't see `Scalar`, as desired:

% invisible-code-block: python
%
% import sys

% skip: next "Result depends on NumPy version."

```python
>>> Union[tuple(scalar_types[:-1])]
typing.Union[numpy.int8, numpy.int16, numpy.int32, numpy.longlong, numpy.int64, numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64, numpy.ulonglong, numpy.float16, numpy.float32, numpy.float64, numpy.longdouble, numpy.complex64, numpy.complex128, numpy.clongdouble, numpy.str_, numpy.bytes_, numpy.void, numpy.bool]
```

You can deactivate union aliases with `deactivate_union_aliases`:

```python
>>> import warnings

>>> from plum import deactivate_union_aliases

>>> deactivate_union_aliases()

% skip: next "Result depends on NumPy version."
>>> Scalar
typing.Union[numpy.int8, numpy.int16, numpy.int32, numpy.longlong, numpy.int64, numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64, numpy.ulonglong, numpy.float16, numpy.float32, numpy.float64, numpy.longdouble, numpy.complex64, numpy.complex128, numpy.clongdouble, numpy.str_, numpy.bytes_, numpy.void, numpy.bool, numpy.object_]
```
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ dynamic = ["version"]

requires-python = ">=3.10"
dependencies = [
"beartype>=0.16.2",
"beartype>=0.22.2; python_version>='3.14'",
"beartype>=0.16.2; python_version<'3.14'",
"typing-extensions>=4.9.0",
"rich>=10.0"
]
Expand Down
Loading
Loading