diff --git a/check_linter_assertions.py b/check_linter_assertions.py index f563553c..3c2008f3 100644 --- a/check_linter_assertions.py +++ b/check_linter_assertions.py @@ -158,7 +158,7 @@ def run_linter(linter: str) -> str: str: `stdout`. """ p = subprocess.Popen( - [linter, source_dir], + [sys.executable, "-m", linter, source_dir], stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) diff --git a/plum/dispatcher.py b/plum/dispatcher.py index bf018c39..9d489e2c 100644 --- a/plum/dispatcher.py +++ b/plum/dispatcher.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union, overload from .function import Function from .overload import get_overloads @@ -23,7 +23,21 @@ def __init__(self): self.functions: Dict[str, Function] = {} self.classes: Dict[str, Dict[str, Function]] = {} - def __call__(self, method: Optional[T] = None, precedence: int = 0) -> T: + @overload + def __call__(self, method: Callable[..., Any], precedence: int = 0) -> Function: + ... + + @overload + def __call__( + self, method: None, precedence: int = 0 + ) -> Callable[[Callable[..., Any]], Function]: + ... + + def __call__( + self, + method: Optional[Callable[..., Any]] = None, + precedence: int = 0, + ) -> Union[Function, Callable[[Callable[..., Any]], Function]]: """Decorator to register for a particular signature. Args: diff --git a/plum/function.py b/plum/function.py index 7b97c006..fe9ba2d1 100644 --- a/plum/function.py +++ b/plum/function.py @@ -1,11 +1,12 @@ +import pydoc +import sys import textwrap from functools import wraps from types import MethodType -from typing import Any, Callable, List, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, Iterator, List, Optional, Tuple, TypeVar, Union from .resolver import AmbiguousLookupError, NotFoundLookupError, Resolver from .signature import Signature, append_default_args, extract_signature -from .type import resolve_type_hint from .util import TypeHint, repr_short __all__ = ["Function"] @@ -23,6 +24,49 @@ SomeExceptionType = TypeVar("SomeExceptionType", bound=Exception) +def _document(f: Callable) -> str: + """Generate documentation for a function `f`. + + The generated documentation contains both the function definition and the + docstring. The docstring is on the same level of indentation of the function + definition. There will be no trailing newlines. + + If the package :mod:`sphinx` is not imported, then the function definition will + be preceded by the string ``. + + If the package :mod:`sphinx` is imported, then the function definition will include + a Sphinx directive to displays the function definition in a nice way. + + Args: + f (function): Function. + + Returns: + str: Documentation for `f`. + """ + # :class:`pydoc._PlainTextDoc` removes styling. This styling will display + # erroneously in Sphinx. + parts = pydoc._PlainTextDoc().document(f).rstrip().split("\n") + + # Separate out the function definition and the lines corresponding to the body. + title = parts[0] + body = parts[1:] + + # Remove indentation from every line of the body. This indentation defaults to + # four spaces. + body = [line[4:] for line in body] + + # If `sphinx` is imported, assume that we're building the documentation. In that + # case, display the function definition in a nice way. + if "sphinx" in sys.modules: + title = ".. py:function:: " + title + "\n :noindex:" + else: + title = "\n\n" + title + title += "\n" # Add a newline to separate the title from the body. + + # Ensure that there are no trailing newlines. This can happen if the body is empty. + return "\n".join([title] + body).rstrip() + + def _convert(obj: Any, target_type: TypeHint) -> Any: """Convert an object to a particular type. Only converts if `target_type` is set. @@ -64,6 +108,99 @@ def f_renamed(*args, **kw_args): a function (see :meth:`Function.owner`), make the corresponding value the owner.""" +class MethodsRegistry: + def __init__(self, function_name: str): + self._all_methods: List[Tuple[Callable, Optional[Signature], int]] = [] + self._resolver: Optional[Resolver] = None + self._function_name: str = function_name + + def add_method( + self, method: Callable, signature: Optional[Signature], precedence: int + ): + self._all_methods.append((method, signature, precedence)) + # since the list of methods has changed, the resolver and cache are invalidated + self.invalidate_resolver_and_cache() + + @property + def methods(self) -> List[Tuple[Callable, Optional[Signature], int]]: + return self._all_methods + + def invalidate_resolver_and_cache(self): + self._resolver = None + + @property + def resolver(self) -> Resolver: + if self._resolver is None: + self._resolver = Resolver(self.get_all_subsignatures()) + return self._resolver + + def get_all_subsignatures(self, strict: bool = True) -> Iterator[Signature]: + # Perform any pending registrations. + for f, signature, precedence in self._all_methods: + + # Obtain the signature if it is not available. + if signature is None: + try: + signature = extract_signature(f, precedence=precedence) + except NameError: + if strict: + raise + else: # pragma: specific no cover 3.8 3.9 + # in case we are using from __future__ import annotations + continue + else: + # Ensure that the implementation is `f`, but make a copy before + # mutating. + signature = signature.__copy__() + signature.implementation = f + + # Ensure that the implementation has the right name, because this name + # will show up in the docstring. + if ( + getattr(signature.implementation, "__name__", None) + != self._function_name + ): + signature.implementation = _change_function_name( + signature.implementation, + self._function_name, + ) + + # Process default values. + yield from append_default_args(signature, f) + + def doc(self, exclude: Union[Callable, None] = None) -> str: + """Concatenate the docstrings of all methods of this function. Remove duplicate + docstrings before concatenating. + + Args: + exclude (function, optional): Exclude this implementation from the + concatenation. + + Returns: + str: Concatenation of all docstrings. + """ + # Generate all docstrings, possibly excluding `exclude`. + if sys.version_info < (3, 10): + strict = True + else: + strict = False + docs = [ + _document(sig.implementation) + for sig in self.get_all_subsignatures(strict=strict) + if not (exclude and sig.implementation == exclude) + ] + # This can yield duplicates, because of extra methods automatically generated by + # :func:`.signature.append_default_args`. We remove these by simply only + # keeping unique docstrings. + unique_docs = [] + for d in docs: + if d not in unique_docs: + unique_docs.append(d) + # The unique documentations have no trailing newlines, so separate them with + # a newline. + return "\n\n".join(unique_docs) + + class _FunctionMeta(type): """:class:`Function` implements `__doc__`, which overrides the docstring of the class. This simple metaclass ensures that `Function.__doc__` still prints as the @@ -92,7 +229,6 @@ def __init__(self, f: Callable, owner: Optional[str] = None) -> None: Function._instances.append(self) self._f: Callable = f - self._cache = {} wraps(f)(self) # Sets `self._doc`. # `owner` is the name of the owner. We will later attempt to resolve to @@ -101,9 +237,7 @@ def __init__(self, f: Callable, owner: Optional[str] = None) -> None: self._owner: Optional[type] = None # Initialise pending and resolved methods. - self._pending: List[Tuple[Callable, Optional[Signature], int]] = [] - self._resolver = Resolver() - self._resolved: List[Tuple[Callable, Signature, int]] = [] + self._methods_registry: MethodsRegistry = MethodsRegistry(self.__name__) @property def owner(self): @@ -125,21 +259,6 @@ def __doc__(self) -> Optional[str]: Upon instantiation, this property is available through `obj.__doc__`. """ - try: - self._resolve_pending_registrations() - except NameError: # pragma: specific no cover 3.7 3.8 3.9 - # When `staticmethod` is combined with - # `from __future__ import annotations`, in Python 3.10 and higher - # `staticmethod` will attempt to inherit `__doc__` (see - # https://docs.python.org/3/library/functions.html#staticmethod). Since - # we are still in class construction, forward references are not yet - # defined, so attempting to resolve all pending methods might fail with a - # `NameError`. This is fine, because later calling `__doc__` on the - # `staticmethod` will again call this `__doc__`, at which point all methods - # will resolve properly. For now, we just ignore the error and undo the - # partially completed :meth:`Function._resolve_pending_registrations` by - # clearing the cache. - self.clear_cache(reregister=False) # Derive the basis of the docstring from `self._f`, removing any indentation. doc = self._doc.strip() @@ -153,7 +272,7 @@ def __doc__(self) -> Optional[str]: # Append the docstrings of all other implementations to it. Exclude the # docstring from `self._f`, because that one forms the basis (see boave). - resolver_doc = self._resolver.doc(exclude=self._f) + resolver_doc = self._methods_registry.doc(exclude=self._f) if resolver_doc: # Add a newline if the documentation is non-empty. if doc: @@ -175,8 +294,14 @@ def __doc__(self, value: str) -> None: @property def methods(self) -> List[Signature]: """list[:class:`.signature.Signature`]: All available methods.""" - self._resolve_pending_registrations() - return self._resolver.signatures + return self._methods_registry.resolver.signatures + + @property + def _resolver(self) -> Resolver: + return self._methods_registry.resolver + + def _clear_cache_dict(self): + self._methods_registry.invalidate_resolver_and_cache() def dispatch( self: Self, method: Optional[Callable] = None, precedence=0 @@ -227,22 +352,9 @@ def decorator(method): return decorator - def clear_cache(self, reregister: bool = True) -> None: - """Clear cache. - - Args: - reregister (bool, optional): Also reregister all methods. Defaults to - `True`. - """ - self._cache.clear() - - if reregister: - # Add all resolved to pending. - self._pending.extend(self._resolved) - - # Clear resolved. - self._resolved = [] - self._resolver = Resolver() + def clear_cache(self) -> None: + """Clear cache.""" + self._methods_registry.invalidate_resolver_and_cache() def register( self, f: Callable, signature: Optional[Signature] = None, precedence=0 @@ -258,44 +370,7 @@ def register( precedence (int, optional): Precedence of the function. If `signature` is given, then this argument will not be used. Defaults to `0`. """ - self._pending.append((f, signature, precedence)) - - def _resolve_pending_registrations(self) -> None: - # Keep track of whether anything registered. - registered = False - - # Perform any pending registrations. - for f, signature, precedence in self._pending: - # Add to resolved registrations. - self._resolved.append((f, signature, precedence)) - - # Obtain the signature if it is not available. - if signature is None: - signature = extract_signature(f, precedence=precedence) - else: - # Ensure that the implementation is `f`, but make a copy before - # mutating. - signature = signature.__copy__() - signature.implementation = f - - # Ensure that the implementation has the right name, because this name - # will show up in the docstring. - if getattr(signature.implementation, "__name__", None) != self.__name__: - signature.implementation = _change_function_name( - signature.implementation, - self.__name__, - ) - - # Process default values. - for subsignature in append_default_args(signature, f): - self._resolver.register(subsignature) - registered = True - - if registered: - self._pending = [] - - # Clear cache. - self.clear_cache(reregister=False) + self._methods_registry.add_method(f, signature, precedence) def _enhance_exception(self, e: SomeExceptionType) -> SomeExceptionType: """Enchance an exception by prepending a prefix to the message of the exception @@ -316,35 +391,6 @@ def _enhance_exception(self, e: SomeExceptionType) -> SomeExceptionType: message = str(e) return type(e)(prefix + message[0].lower() + message[1:]) - def resolve_method( - self, target: Union[Tuple[object, ...], Signature] - ) -> Tuple[Callable, TypeHint]: - """Find the method and return type for arguments. - - Args: - target (object): Target. - - Returns: - function: Method. - type: Return type. - """ - self._resolve_pending_registrations() - - try: - # Attempt to find the method using the resolver. - signature = self._resolver.resolve(target) - method = signature.implementation - return_type = signature.return_type - - except AmbiguousLookupError as e: - raise self._enhance_exception(e) # Specify this function. - - except NotFoundLookupError as e: - e = self._enhance_exception(e) # Specify this function. - method, return_type = self._handle_not_found_lookup_error(e) - - return method, return_type - def _handle_not_found_lookup_error( self, ex: NotFoundLookupError ) -> Tuple[Callable, TypeHint]: @@ -388,7 +434,7 @@ def _handle_not_found_lookup_error( raise ex return method, return_type - def __call__(self, *args, **kw_args): + def __call__(self, *args: object, **kw_args: object) -> object: method, return_type = self._resolve_method_with_cache(args=args) return _convert(method(*args, **kw_args), return_type) @@ -397,33 +443,14 @@ def _resolve_method_with_cache( args: Union[Tuple[object, ...], Signature, None] = None, types: Optional[Tuple[TypeHint, ...]] = None, ) -> Tuple[Callable, TypeHint]: - if args is None and types is None: - raise ValueError( - "Arguments `args` and `types` cannot both be `None`. " - "This should never happen!" - ) - - # Before attempting to use the cache, resolve any unresolved registrations. Use - # an `if`-statement to speed up the common case. - if self._pending: - self._resolve_pending_registrations() - - if types is None: - # Attempt to use the cache based on the types of the arguments. - types = tuple(map(type, args)) try: - return self._cache[types] - except KeyError: - if args is None: - args = Signature(*(resolve_type_hint(t) for t in types)) - - # Cache miss. Run the resolver based on the arguments. - method, return_type = self.resolve_method(args) - # If the resolver is faithful, then we can perform caching using the types - # of the arguments. If the resolver is not faithful, then we cannot. - if self._resolver.is_faithful: - self._cache[types] = method, return_type - return method, return_type + return self._resolver.resolve_method_with_cache(args=args, types=types) + except AmbiguousLookupError as e: + raise self._enhance_exception(e) # Specify this function. + + except NotFoundLookupError as e: + e = self._enhance_exception(e) # Specify this function. + return self._handle_not_found_lookup_error(e) def invoke(self, *types: TypeHint) -> Callable: """Invoke a particular method. @@ -450,8 +477,7 @@ def __get__(self, instance, owner): def __repr__(self) -> str: return ( - f"" + f"" ) diff --git a/plum/promotion.py b/plum/promotion.py index c34f8ae3..d086ca14 100644 --- a/plum/promotion.py +++ b/plum/promotion.py @@ -151,7 +151,6 @@ def _promote_types(t0, t1): return resolve_type_hint(_promotion_rule.invoke(t0, t1)(t0, t1)) # Find the common type. - _promotion_rule._resolve_pending_registrations() common_type = _promote_types(types[0], types[1]) for t in types[2:]: common_type = _promote_types(common_type, t) diff --git a/plum/resolver.py b/plum/resolver.py index dc99fa05..16b96ee2 100644 --- a/plum/resolver.py +++ b/plum/resolver.py @@ -1,8 +1,8 @@ -import pydoc -import sys -from typing import Callable, List, Tuple, Union +from typing import Callable, Iterable, List, Optional, Tuple, Union from plum.signature import Signature +from plum.type import resolve_type_hint +from plum.util import TypeHint __all__ = ["AmbiguousLookupError", "NotFoundLookupError"] @@ -15,49 +15,6 @@ class NotFoundLookupError(LookupError): """A signature cannot be resolved because no applicable method can be found.""" -def _document(f: Callable) -> str: - """Generate documentation for a function `f`. - - The generated documentation contains both the function definition and the - docstring. The docstring is on the same level of indentation of the function - definition. There will be no trailing newlines. - - If the package :mod:`sphinx` is not imported, then the function definition will - be preceded by the string ``. - - If the package :mod:`sphinx` is imported, then the function definition will include - a Sphinx directive to displays the function definition in a nice way. - - Args: - f (function): Function. - - Returns: - str: Documentation for `f`. - """ - # :class:`pydoc._PlainTextDoc` removes styling. This styling will display - # erroneously in Sphinx. - parts = pydoc._PlainTextDoc().document(f).rstrip().split("\n") - - # Separate out the function definition and the lines corresponding to the body. - title = parts[0] - body = parts[1:] - - # Remove indentation from every line of the body. This indentation defaults to - # four spaces. - body = [line[4:] for line in body] - - # If `sphinx` is imported, assume that we're building the documentation. In that - # case, display the function definition in a nice way. - if "sphinx" in sys.modules: - title = ".. py:function:: " + title + "\n :noindex:" - else: - title = "\n\n" + title - title += "\n" # Add a newline to separate the title from the body. - - # Ensure that there are no trailing newlines. This can happen if the body is empty. - return "\n".join([title] + body).rstrip() - - class Resolver: """Method resolver. @@ -66,62 +23,62 @@ class Resolver: is_faithful (bool): Whether all signatures are faithful or not. """ - def __init__(self): - self.signatures: List[Signature] = [] - self.is_faithful: bool = True + def __init__(self, signatures: Iterable[Signature]): + signatures_dict = {hash(s): s for s in signatures} + self.signatures: List[Signature] = list(signatures_dict.values()) + self.is_faithful: bool = all(s.is_faithful for s in self.signatures) + self._cache = {} - def doc(self, exclude: Union[Callable, None] = None) -> str: - """Concatenate the docstrings of all methods of this function. Remove duplicate - docstrings before concatenating. + def __len__(self) -> int: + return len(self.signatures) + + def clear_cache(self): + self._cache = {} + + def resolve_method_with_cache( + self, + args: Union[Tuple[object, ...], Signature, None] = None, + types: Optional[Tuple[TypeHint, ...]] = None, + ) -> Tuple[Callable, TypeHint]: + if args is None and types is None: + raise ValueError( + "Arguments `args` and `types` cannot both be `None`. " + "This should never happen!" + ) + + if types is None: + # Attempt to use the cache based on the types of the arguments. + types = tuple(map(type, args)) + try: + return self._cache[types] + except KeyError: + if args is None: + args = Signature(*(resolve_type_hint(t) for t in types)) + + # Cache miss. Run the resolver based on the arguments. + method, return_type = self._resolve_method(args) + # If the resolver is faithful, then we can perform caching using the types + # of the arguments. If the resolver is not faithful, then we cannot. + if self.is_faithful: + self._cache[types] = method, return_type + return method, return_type + + def _resolve_method( + self, target: Union[Tuple[object, ...], Signature] + ) -> Tuple[Callable, TypeHint]: + """Find the method and return type for arguments. Args: - exclude (function, optional): Exclude this implementation from the - concatenation. + target (object): Target. Returns: - str: Concatenation of all docstrings. - """ - # Generate all docstrings, possibly excluding `exclude`. - docs = [ - _document(sig.implementation) - for sig in self.signatures - if not (exclude and sig.implementation == exclude) - ] - # This can yield duplicates, because of extra methods automatically generated by - # :func:`.signature.append_default_args`. We remove these by simply only - # keeping unique docstrings. - unique_docs = [] - for d in docs: - if d not in unique_docs: - unique_docs.append(d) - # The unique documentations have no trailing newlines, so separate them with - # a newline. - return "\n\n".join(unique_docs) - - def register(self, signature: Signature) -> None: - """Register a new signature. - - Args: - signature (:class:`.signature.Signature`): Signature to add. + function: Method. + type: Return type. """ - existing = [s == signature for s in self.signatures] - if any(existing): - if sum(existing) != 1: - raise AssertionError( - f"The added signature `{signature}` is equal to {sum(existing)} " - f"existing signatures. This should never happen." - ) - self.signatures[existing.index(True)] = signature - else: - self.signatures.append(signature) - - # Use a double negation for slightly better performance. - self.is_faithful = not any(not s.is_faithful for s in self.signatures) - - def __len__(self) -> int: - return len(self.signatures) + signature = self._resolve(target) + return signature.implementation, signature.return_type - def resolve(self, target: Union[Tuple[object, ...], Signature]) -> Signature: + def _resolve(self, target: Union[Tuple[object, ...], Signature]) -> Signature: """Find the most specific signature that satisfies a target. Args: diff --git a/tests/advanced/test_advanced.py b/tests/advanced/test_advanced.py index 6780cb61..d10e7bf1 100644 --- a/tests/advanced/test_advanced.py +++ b/tests/advanced/test_advanced.py @@ -58,7 +58,7 @@ def f(x: float, y: int = y_default, *, option=None): def f_wrong_default(x: int, y: float = y_default): return y - f_wrong_default._resolve_pending_registrations() + f_wrong_default._resolver() # Remove this function from global tracking. Otherwise, it might interfere with # other tests. diff --git a/tests/conftest.py b/tests/conftest.py index 5c936029..3a10e6f7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,27 +7,22 @@ @pytest.fixture def convert(): # Save methods. - _convert._resolve_pending_registrations() - resolved = list(_convert._resolved) + all_methods = _convert._methods_registry._all_methods.copy() yield plum.convert # Clear methods after use. - _convert._resolve_pending_registrations() - _convert._pending = [] - _convert._resolved = resolved - _convert.clear_cache(reregister=True) + _convert._methods_registry._all_methods = all_methods + _convert._methods_registry.invalidate_resolver_and_cache() @pytest.fixture def promote(): # Save methods. - _promotion_rule._resolve_pending_registrations() - resolved = list(_promotion_rule._resolved) + all_methods = _promotion_rule._methods_registry._all_methods.copy() yield plum.promote # Clear methods after use. - _promotion_rule._pending = [] - _promotion_rule._resolved = resolved - _promotion_rule.clear_cache(reregister=True) + _promotion_rule._methods_registry._all_methods = all_methods + _promotion_rule._methods_registry.invalidate_resolver_and_cache() diff --git a/tests/test_autoreload.py b/tests/test_autoreload.py index ea4fc29f..70a1621b 100644 --- a/tests/test_autoreload.py +++ b/tests/test_autoreload.py @@ -1,3 +1,5 @@ +import contextlib + import pytest from plum import Dispatcher @@ -30,6 +32,15 @@ def test_autoreload_activate_deactivate(): assert iar.update_instances == ar._update_instances_original +@contextlib.contextmanager +def autoreload_context_manager(): + ar.activate_autoreload() + try: + yield + finally: + ar.deactivate_autoreload() + + def test_autoreload_correctness(): dispatch = Dispatcher() @@ -57,24 +68,25 @@ def test(x: A1): with pytest.raises(NotFoundLookupError): test(A3()) - ar._update_instances(A1, A2) + with autoreload_context_manager(): + ar._update_instances(A1, A2) - assert isinstance(a, A2) - assert test(a) == 1 + assert isinstance(a, A2) + assert test(a) == 1 - with pytest.raises(NotFoundLookupError): - test(A1()) - assert test(A2()) == 1 - with pytest.raises(NotFoundLookupError): - test(A3()) + with pytest.raises(NotFoundLookupError): + test(A1()) + assert test(A2()) == 1 + with pytest.raises(NotFoundLookupError): + test(A3()) - ar._update_instances(A2, A3) + ar._update_instances(A2, A3) - assert isinstance(a, A3) - assert test(a) == 1 + assert isinstance(a, A3) + assert test(a) == 1 - with pytest.raises(NotFoundLookupError): - test(A1()) - with pytest.raises(NotFoundLookupError): - test(A2()) - assert test(A3()) == 1 + with pytest.raises(NotFoundLookupError): + test(A1()) + with pytest.raises(NotFoundLookupError): + test(A2()) + assert test(A3()) == 1 diff --git a/tests/test_cache.py b/tests/test_cache.py index 76ae41d5..247f5391 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -11,7 +11,7 @@ def assert_cache_performance(f, f_native): def resolve_registrations(): for f in Function._instances: - f._resolve_pending_registrations() + f._resolver def setup_no_cache(): clear_all_cache() @@ -25,14 +25,17 @@ def setup_no_cache(): resolve_registrations() dur = benchmark(f, (1,), n=250, burn=10) - # A cached call should not be more than 50 times slower than a native call. - assert dur <= 50 * dur_native + assert ( + dur <= 50 * dur_native + ), "A cached call should not be more than 50 times slower than a native call" - # A first call should not be more than 2000 times slower than a cached call. - assert dur_first <= 2000 * dur + assert ( + dur_first <= 2000 * dur + ), "A first call should not be more than 2000 times slower than a cached call" - # The cached call should be at least 5 times faster than a first call. - assert dur <= dur_first / 5 + assert ( + dur <= dur_first / 4 + ), "The cached call should be at least 4 times faster than a first call" def test_cache_function(): @@ -136,33 +139,30 @@ def f(x: int): def f(x: float): return 2 - assert len(f._cache) == 0 - assert len(f._resolver) == 0 + assert f._methods_registry._resolver is None assert f(1) == 1 # Check that cache is used. - assert len(f._cache) == 1 - assert len(f._resolver) == 2 + assert len(f._methods_registry._resolver._cache) == 1 + assert len(f._methods_registry._resolver) == 2 # Clear via the dispatcher. dispatch.clear_cache() - assert len(f._cache) == 0 - assert len(f._resolver) == 0 + assert f._methods_registry._resolver is None # Run the function again. assert f(1) == 1 - assert len(f._cache) == 1 - assert len(f._resolver) == 2 + assert len(f._methods_registry._resolver._cache) == 1 + assert len(f._methods_registry._resolver) == 2 # Clear via `clear_all_cache`. clear_all_cache() - assert len(f._cache) == 0 - assert len(f._resolver) == 0 + assert f._methods_registry._resolver is None # Run the function one last time. assert f(1) == 1 - assert len(f._cache) == 1 - assert len(f._resolver) == 2 + assert len(f._methods_registry._resolver._cache) == 1 + assert len(f._methods_registry._resolver) == 2 def test_cache_unfaithful(): @@ -179,4 +179,4 @@ def f(x: List[int]): # Since `f` is not faithful, no cache should be accumulated. assert f(1) == 1 assert f([1]) == 2 - assert len(f._cache) == 0 + assert len(f._methods_registry._resolver._cache) == 0 diff --git a/tests/test_function.py b/tests/test_function.py index c4f8229b..d0b0e993 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -1,11 +1,21 @@ import abc +import sys import textwrap import typing +from unittest.mock import MagicMock import pytest +import plum.resolver from plum import Dispatcher -from plum.function import Function, _change_function_name, _convert, _owner_transfer +from plum.function import ( + Function, + MethodsRegistry, + _change_function_name, + _convert, + _document, + _owner_transfer, +) from plum.resolver import AmbiguousLookupError, NotFoundLookupError from plum.signature import Signature @@ -53,23 +63,13 @@ def f(x: int): def f(x: str): return "str" - assert repr(f) == f"" - - # Register all methods. - assert f(1) == "int" - - assert repr(f) == f"" + assert repr(f) == f"" @dispatch def f(x: float): return "float" - assert repr(f) == f"" - - # Again register all methods. - assert f(1) == "int" - - assert repr(f) == f"" + assert repr(f) == f"" # `A` needs to be in the global scope for owner resolution to work. @@ -253,7 +253,7 @@ def other_implementation(x: str): assert f(1) == "int" assert f(1.0) == "float" assert f("1") == "str" - assert f._resolver.resolve(("1",)).precedence == 1 + assert f._resolver._resolve(("1",)).precedence == 1 def test_function_multi_dispatch(): @@ -270,7 +270,7 @@ def implementation(x): assert f(1) == "int" assert f(1.0) == "float or str" assert f("1") == "float or str" - assert f._resolver.resolve(("1",)).precedence == 1 + assert f._resolver._resolve(("1",)).precedence == 1 # Check that arguments to `f.dispatch_multi` must be tuples or signatures. with pytest.raises(ValueError): @@ -284,9 +284,8 @@ def f(x: int): g = Function(f) g.register(f) - assert g._pending == [(f, None, 0)] - assert g._resolved == [] - assert len(g._resolver) == 0 + assert len(g.methods) == 1 + assert g.methods[0].implementation == f def test_resolve_pending_registrations(): @@ -301,19 +300,16 @@ def f(x: int): # At this point, there should be nothing to register, so a call should not clear # the cache. - assert f._pending == [] - f._resolve_pending_registrations() - assert len(f._cache) == 1 + assert f._resolver + assert len(f._methods_registry._resolver._cache) == 1 @f.dispatch def f(x: str): pass # Now there is something to register. A call should clear the cache. - assert len(f._pending) == 1 - f._resolve_pending_registrations() - assert len(f._pending) == 0 - assert len(f._cache) == 0 + f._resolver + assert f._methods_registry._resolver._cache == {} # Register in two ways using multi and the wrong name. @f.dispatch_multi((float,), Signature(complex)) @@ -577,3 +573,105 @@ def do(self, x: int): # Also test that `invoke` is wrapped, like above. assert A().do.invoke(int).__doc__ == "Docs" assert A.do.invoke(A, int).__doc__ == "Docs" + + +def test_document_nosphinx(): + """Test the following: + (1) remove trailing newlines, + (2) appropriately remove trailing newlines, + (3) appropriately remove indentation, ignoring the first line, + (4) separate the title from the body. + """ + + def f(x): + """Title. + + Hello. + + Args: + x (object): Input. + + """ + + expected_doc = """ + + + f(x) + + Title. + + Hello. + + Args: + x (object): Input. + """ + assert _document(f) == textwrap.dedent(expected_doc).strip() + + +def test_document_sphinx(monkeypatch): + """Like :func:`test_document_nosphinx`, but when :mod:`sphinx` + is imported.""" + # Fake import :mod:`sphinx`. + monkeypatch.setitem(sys.modules, "sphinx", None) + + def f(x): + """Title. + + Hello. + + Args: + x (object): Input. + + """ + + expected_doc = """ + .. py:function:: f(x) + :noindex: + + Title. + + Hello. + + Args: + x (object): Input. + """ + assert _document(f) == textwrap.dedent(expected_doc).strip() + + +def test_doc_in_resolver(monkeypatch): + # Let the `pydoc` documenter simply return the docstring. This makes testing + # simpler. + monkeypatch.setattr(plum.function, "_document", lambda x: x.__doc__) + + r = MethodsRegistry(function_name="something") + + class _MockFunction: + def __init__(self, doc): + self.__doc__ = doc + + class _MockSignature: + def __init__(self, doc): + self.implementation = _MockFunction(doc) + + # Circumvent the use of :meth:`.resolver.Resolver.register`. + r.get_all_subsignatures = MagicMock( + return_value=[ + _MockSignature("first"), + _MockSignature("second"), + _MockSignature("third"), + ] + ) + assert r.doc() == "first\n\nsecond\n\nthird" + + # Test that duplicates are excluded. + all_subsignatures = [ + _MockSignature("first"), + _MockSignature("second"), + _MockSignature("second"), + _MockSignature("third"), + ] + r.get_all_subsignatures = MagicMock(return_value=all_subsignatures) + assert r.doc() == "first\n\nsecond\n\nthird" + + # Test that the explicit exclusion mechanism also works. + assert r.doc(exclude=all_subsignatures[3].implementation) == "first\n\nsecond" diff --git a/tests/test_resolver.py b/tests/test_resolver.py index fcc32433..45a54fd8 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -1,154 +1,41 @@ -import sys -import textwrap -import typing +from typing import Tuple import pytest -import plum.resolver -from plum.resolver import AmbiguousLookupError, NotFoundLookupError, Resolver, _document +from plum.resolver import AmbiguousLookupError, NotFoundLookupError, Resolver from plum.signature import Signature def test_initialisation(): - r = Resolver() + r = Resolver([]) # Without any registered signatures, the resolver should be faithful. assert r.is_faithful -def test_document_nosphinx(): - """Test the following: - (1) remove trailing newlines, - (2) appropriately remove trailing newlines, - (3) appropriately remove indentation, ignoring the first line, - (4) separate the title from the body. - """ - - def f(x): - """Title. - - Hello. - - Args: - x (object): Input. - - """ - - expected_doc = """ - - - f(x) - - Title. - - Hello. - - Args: - x (object): Input. - """ - assert _document(f) == textwrap.dedent(expected_doc).strip() - - -def test_document_sphinx(monkeypatch): - """Like :func:`test_document_nosphinx`, but when :mod:`sphinx` - is imported.""" - # Fake import :mod:`sphinx`. - monkeypatch.setitem(sys.modules, "sphinx", None) - - def f(x): - """Title. - - Hello. - - Args: - x (object): Input. - - """ - - expected_doc = """ - .. py:function:: f(x) - :noindex: - - Title. - - Hello. - - Args: - x (object): Input. - """ - assert _document(f) == textwrap.dedent(expected_doc).strip() - - -def test_doc(monkeypatch): - # Let the `pydoc` documenter simply return the docstring. This makes testing - # simpler. - monkeypatch.setattr(plum.resolver, "_document", lambda x: x.__doc__) - - r = Resolver() - - class _MockFunction: - def __init__(self, doc): - self.__doc__ = doc - - class _MockSignature: - def __init__(self, doc): - self.implementation = _MockFunction(doc) - - # Circumvent the use of :meth:`.resolver.Resolver.register`. - r.signatures = [ - _MockSignature("first"), - _MockSignature("second"), - _MockSignature("third"), - ] - assert r.doc() == "first\n\nsecond\n\nthird" - - # Test that duplicates are excluded. - r.signatures = [ - _MockSignature("first"), - _MockSignature("second"), - _MockSignature("second"), - _MockSignature("third"), - ] - assert r.doc() == "first\n\nsecond\n\nthird" - - # Test that the explicit exclusion mechanism also works. - assert r.doc(exclude=r.signatures[3].implementation) == "first\n\nsecond" - - def test_register(): - r = Resolver() - # Test that faithfulness is tracked correctly. - r.register(Signature(int)) - r.register(Signature(float)) + r = Resolver([Signature(int), Signature(float)]) assert r.is_faithful - r.register(Signature(typing.Tuple[int])) + r = Resolver([Signature(int), Signature(float), Signature(Tuple[int])]) assert not r.is_faithful # Test that signatures can be replaced. - new_s = Signature(float) assert len(r) == 3 + new_s = Signature(float) assert r.signatures[1] is not new_s - r.register(new_s) + r = Resolver([Signature(int), Signature(float), Signature(Tuple[int]), new_s]) assert len(r) == 3 assert r.signatures[1] is new_s - # Test the edge case that should never happen. - r.signatures[2] = Signature(float) - with pytest.raises( - AssertionError, - match=r"(?i)the added signature `(.*)` is equal to 2 existing signatures", - ): - r.register(Signature(float)) - def test_len(): - r = Resolver() + r = Resolver([]) assert len(r) == 0 - r.register(Signature(int)) + r = Resolver([Signature(int)]) assert len(r) == 1 - r.register(Signature(float)) + r = Resolver([Signature(int), Signature(float)]) assert len(r) == 2 - r.register(Signature(float)) + r = Resolver([Signature(int), Signature(float), Signature(float)]) assert len(r) == 2 @@ -182,40 +69,42 @@ class Missing: s_u = Signature(Unrelated) s_m = Signature(Missing) - r = Resolver() - r.register(s_b1) - # Import this after `s_b1` to test all branches. - r.register(s_a) - r.register(s_b2) - # Do not register `s_c1`. - r.register(s_c2) - r.register(s_u) - # Also do not register `s_m`. - + r = Resolver( + [ + s_b1, + # Import this after `s_b1` to test all branches. + s_a, + s_b2, + # Do not register `s_c1`. + s_c2, + s_u, + # Also do not register `s_m`. + ] + ) # Resolve by signature. - assert r.resolve(s_a) == s_a - assert r.resolve(s_b1) == s_b1 - assert r.resolve(s_b2) == s_b2 + assert r._resolve(s_a) == s_a + assert r._resolve(s_b1) == s_b1 + assert r._resolve(s_b2) == s_b2 with pytest.raises(AmbiguousLookupError): - r.resolve(s_c1) - assert r.resolve(s_c2) == s_c2 - assert r.resolve(s_u) == s_u + r._resolve(s_c1) + assert r._resolve(s_c2) == s_c2 + assert r._resolve(s_u) == s_u with pytest.raises(NotFoundLookupError): - r.resolve(s_m) + r._resolve(s_m) # Resolve by type. - assert r.resolve((A(),)) == s_a - assert r.resolve((B1(),)) == s_b1 - assert r.resolve((B2(),)) == s_b2 + assert r._resolve((A(),)) == s_a + assert r._resolve((B1(),)) == s_b1 + assert r._resolve((B2(),)) == s_b2 with pytest.raises(AmbiguousLookupError): - r.resolve((C1(),)) - assert r.resolve((C2(),)) == s_c2 - assert r.resolve((Unrelated(),)) == s_u + r._resolve((C1(),)) + assert r._resolve((C2(),)) == s_c2 + assert r._resolve((Unrelated(),)) == s_u with pytest.raises(NotFoundLookupError): - r.resolve((Missing(),)) + r._resolve((Missing(),)) # Test that precedence can correctly break the ambiguity. s_b1.precedence = 1 - assert r.resolve(s_c1) == s_b1 + assert r._resolve(s_c1) == s_b1 s_b2.precedence = 2 - assert r.resolve(s_c1) == s_b2 + assert r._resolve(s_c1) == s_b2 diff --git a/tests/typechecked/test_overload.py b/tests/typechecked/test_overload.py index 47f4795b..cee7ad83 100644 --- a/tests/typechecked/test_overload.py +++ b/tests/typechecked/test_overload.py @@ -6,24 +6,24 @@ @overload -def f(x: int) -> int: # E: pyright(marked as overload) +def f(x: int, y: int) -> int: # E: pyright(marked as overload) return x @overload -def f(x: str) -> str: # E: pyright(marked as overload) +def f(x: str, y: str) -> str: # E: pyright(marked as overload) return x @dispatch -def f(x): # E: pyright(overloaded implementation is not consistent) +def f(x, y): # E: pyright(overloaded implementation is not consistent) pass def test_overload() -> None: - assert f(1) == 1 - assert f("1") == "1" + assert f(1, 2) == 1 + assert f("1", "2") == "1" with pytest.raises(NotFoundLookupError): # E: pyright(argument of type "float" cannot be assigned to parameter "x") # E: mypy(no overload variant of "f" matches argument type "float") - f(1.0) + f(1.0, 2.0)