diff --git a/check_linter_assertions.py b/check_linter_assertions.py index f563553c..3c2008f3 100644 --- a/check_linter_assertions.py +++ b/check_linter_assertions.py @@ -158,7 +158,7 @@ def run_linter(linter: str) -> str: str: `stdout`. """ p = subprocess.Popen( - [linter, source_dir], + [sys.executable, "-m", linter, source_dir], stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) diff --git a/plum/dispatcher.py b/plum/dispatcher.py index bf018c39..9d489e2c 100644 --- a/plum/dispatcher.py +++ b/plum/dispatcher.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union, overload from .function import Function from .overload import get_overloads @@ -23,7 +23,21 @@ def __init__(self): self.functions: Dict[str, Function] = {} self.classes: Dict[str, Dict[str, Function]] = {} - def __call__(self, method: Optional[T] = None, precedence: int = 0) -> T: + @overload + def __call__(self, method: Callable[..., Any], precedence: int = 0) -> Function: + ... + + @overload + def __call__( + self, method: None, precedence: int = 0 + ) -> Callable[[Callable[..., Any]], Function]: + ... + + def __call__( + self, + method: Optional[Callable[..., Any]] = None, + precedence: int = 0, + ) -> Union[Function, Callable[[Callable[..., Any]], Function]]: """Decorator to register for a particular signature. Args: diff --git a/plum/function.py b/plum/function.py index 7b97c006..d255b82e 100644 --- a/plum/function.py +++ b/plum/function.py @@ -388,7 +388,7 @@ def _handle_not_found_lookup_error( raise ex return method, return_type - def __call__(self, *args, **kw_args): + def __call__(self, *args: object, **kw_args: object) -> object: method, return_type = self._resolve_method_with_cache(args=args) return _convert(method(*args, **kw_args), return_type) diff --git a/tests/typechecked/test_overload.py b/tests/typechecked/test_overload.py index 47f4795b..cee7ad83 100644 --- a/tests/typechecked/test_overload.py +++ b/tests/typechecked/test_overload.py @@ -6,24 +6,24 @@ @overload -def f(x: int) -> int: # E: pyright(marked as overload) +def f(x: int, y: int) -> int: # E: pyright(marked as overload) return x @overload -def f(x: str) -> str: # E: pyright(marked as overload) +def f(x: str, y: str) -> str: # E: pyright(marked as overload) return x @dispatch -def f(x): # E: pyright(overloaded implementation is not consistent) +def f(x, y): # E: pyright(overloaded implementation is not consistent) pass def test_overload() -> None: - assert f(1) == 1 - assert f("1") == "1" + assert f(1, 2) == 1 + assert f("1", "2") == "1" with pytest.raises(NotFoundLookupError): # E: pyright(argument of type "float" cannot be assigned to parameter "x") # E: mypy(no overload variant of "f" matches argument type "float") - f(1.0) + f(1.0, 2.0)