Skip to content
2 changes: 1 addition & 1 deletion check_linter_assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def run_linter(linter: str) -> str:
str: `stdout`.
"""
p = subprocess.Popen(
[linter, source_dir],
[sys.executable, "-m", linter, source_dir],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
Expand Down
18 changes: 16 additions & 2 deletions plum/dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union
from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union, overload

from .function import Function
from .overload import get_overloads
Expand All @@ -23,7 +23,21 @@ def __init__(self):
self.functions: Dict[str, Function] = {}
self.classes: Dict[str, Dict[str, Function]] = {}

def __call__(self, method: Optional[T] = None, precedence: int = 0) -> T:
@overload
def __call__(self, method: Callable[..., Any], precedence: int = 0) -> Function:
...

@overload
def __call__(
self, method: None, precedence: int = 0
) -> Callable[[Callable[..., Any]], Function]:
...

def __call__(
self,
method: Optional[Callable[..., Any]] = None,
precedence: int = 0,
) -> Union[Function, Callable[[Callable[..., Any]], Function]]:
"""Decorator to register for a particular signature.

Args:
Expand Down
2 changes: 1 addition & 1 deletion plum/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def _handle_not_found_lookup_error(
raise ex
return method, return_type

def __call__(self, *args, **kw_args):
def __call__(self, *args: object, **kw_args: object) -> object:
method, return_type = self._resolve_method_with_cache(args=args)
return _convert(method(*args, **kw_args), return_type)

Expand Down
12 changes: 6 additions & 6 deletions tests/typechecked/test_overload.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,24 @@


@overload
def f(x: int) -> int: # E: pyright(marked as overload)
def f(x: int, y: int) -> int: # E: pyright(marked as overload)
return x


@overload
def f(x: str) -> str: # E: pyright(marked as overload)
def f(x: str, y: str) -> str: # E: pyright(marked as overload)
return x


@dispatch
def f(x): # E: pyright(overloaded implementation is not consistent)
def f(x, y): # E: pyright(overloaded implementation is not consistent)
pass


def test_overload() -> None:
assert f(1) == 1
assert f("1") == "1"
assert f(1, 2) == 1
assert f("1", "2") == "1"
with pytest.raises(NotFoundLookupError):
# E: pyright(argument of type "float" cannot be assigned to parameter "x")
# E: mypy(no overload variant of "f" matches argument type "float")
f(1.0)
f(1.0, 2.0)