diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9215b35f..4bb05ee0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,17 +18,12 @@ repos: - id: name-tests-test args: ["--django"] - - repo: https://github.com/pycqa/isort - rev: 8.0.1 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.13.1 hooks: - - id: isort - name: isort - - - repo: https://github.com/psf/black-pre-commit-mirror - rev: 26.3.1 - hooks: - - id: black - language_version: python3 + - id: ruff-format + - id: ruff-check + args: [--fix] - repo: https://github.com/asottile/blacken-docs rev: 1.20.0 @@ -36,23 +31,6 @@ repos: - id: blacken-docs additional_dependencies: [black==22.3.0] - - repo: https://github.com/pycqa/flake8 - rev: 7.3.0 - hooks: - - id: flake8 - exclude: docs/source/conf.py, __pycache__ - additional_dependencies: - [ - flake8-bugbear, - flake8-builtins, - flake8-quotes>=3.3.2, - flake8-comprehensions, - pandas-vet, - flake8-print, - pep8-naming, - doc8, - ] - - repo: https://github.com/pycqa/pydocstyle rev: 6.3.0 hooks: @@ -62,15 +40,8 @@ repos: - repo: https://github.com/nbQA-dev/nbQA rev: 1.9.1 hooks: - - id: nbqa-isort - args: [--nbqa-mutate, --nbqa-dont-skip-bad-cells] - additional_dependencies: [isort==5.6.4] - - id: nbqa-black - args: [--nbqa-mutate, --nbqa-dont-skip-bad-cells] - additional_dependencies: [black>=22.3.0] - - id: nbqa-flake8 - args: [--nbqa-dont-skip-bad-cells, "--extend-ignore=E402,E203"] - additional_dependencies: [flake8==3.8.3] + - id: nbqa-ruff + args: [--nbqa-mutate] - repo: https://github.com/PyCQA/bandit rev: 1.9.4 diff --git a/docs/source/conf.py b/docs/source/conf.py index c611f98e..d3f154af 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- # copyright: skbase developers, BSD-3-Clause License (see LICENSE file) """Configure skbase Sphinx documentation.""" @@ -298,7 +297,7 @@ def find_source(): # Example configuration for intersphinx: refer to the Python standard library. intersphinx_mapping = { - "python": ("https://docs.python.org/{.major}".format(sys.version_info), None), + "python": (f"https://docs.python.org/{sys.version_info.major}", None), "scikit-learn": ("https://scikit-learn.org/stable/", None), "sktime": ("https://www.sktime.net/en/stable/", None), } diff --git a/pyproject.toml b/pyproject.toml index 17c90e45..c5c6d5ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,18 +47,9 @@ dev = [ linters = [ "mypy", - "isort", - "flake8", - "black", + "ruff", "pydocstyle", "nbqa", - "flake8-bugbear", - "flake8-builtins", - "flake8-quotes", - "flake8-comprehensions", - "pandas-vet", - "flake8-print", - "pep8-naming", "doc8", ] @@ -112,16 +103,6 @@ addopts = [ "--cov-report=html", ] -[tool.isort] -profile = "black" -src_paths = ["skbase/*"] -multi_line_output = 3 -known_first_party = ["skbase"] - -[tool.black] -line-length = 88 -extend-exclude = "^/setup.py docs/conf.py" - [tool.pydocstyle] convention = "numpy" @@ -137,6 +118,79 @@ max-line-length = 88 ignore = ["D004"] ignore_path = ["docs/_build", "docs/source/api_reference/auto_generated"] +[tool.ruff] +line-length = 88 +extend-exclude = ["setup.py", "docs/source/conf.py"] +target-version = "py310" +extend-include = ["*.ipynb"] + +[tool.ruff.lint] +select = [ + "D", # pydocstyle + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "I", # isort + "UP", # pyupgrade + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "PIE", # flake8-pie + "T20", # flake8-print + "RET", # flake8-return + "SIM", # flake8-simplify + "TID", # flake8-tidy-imports + "TCH", # flake8-type-checking + "RUF", # Ruff-specific rules +] +ignore = [ + "E203", # Whitespace before punctuation + "E402", # Module import not at top of file + "E501", # Line too long + "E731", # Do not assign a lambda expression, use a def + "RET504", # Unnecessary variable assignment before `return` statement + "S101", # Use of `assert` detected + "C408", # Unnecessary dict call - rewrite as a literal + "UP031", # Use format specifier instead of % + "UP009", # UTF-8 encoding declaration is unnecessary + "S102", # Use of exec + "C414", # Unnecessary `list` call within `sorted()` + "S301", # pickle and modules that wrap it can be unsafe + "C416", # Unnecessary list comprehension - rewrite as a generator + "S310", # Audit URL open for permitted schemes + "S202", # Uses of `tarfile.extractall()` + "S307", # Use of possibly insecure function + "C417", # Unnecessary `map` usage (rewrite using a generator expression) + "S605", # Starting a process with a shell, possible injection detected + "E741", # Ambiguous variable name + "S107", # Possible hardcoded password + "S105", # Possible hardcoded password + "PT018", # Checks for assertions that combine multiple independent condition + "S602", # sub process call with shell=True unsafe + "C419", # Unnecessary list comprehension, some are flagged yet are not + "C409", # Unnecessary `list` literal passed to `tuple()` (rewrite as a `tuple` literal) + "S113", # Probable use of httpx call without timeout +] +allowed-confusables = ["σ"] + +[tool.ruff.lint.per-file-ignores] +"setup.py" = ["S101"] +"**/__init__.py" = [ + "F401", # unused import +] +"**/tests/**" = [ + "D", # docstring + "S605", # Starting a process with a shell: seems safe, but may be changed in the future; consider rewriting without `shell` + "S607", # Starting a process with a partial executable path + "RET504", # Unnecessary variable assignment before `return` statement + "PT004", # Fixture `tmpdir_unittest_fixture` does not return anything, add leading underscore + "PT011", # `pytest.raises(ValueError)` is too broad, set the `match` parameter or use a more specific exception + "PT012", # `pytest.raises()` block should contain a single simple statement + "PT019", # Fixture `_` without value is injected as parameter, use `@pytest.mark.usefixtures` instead +] + +[tool.ruff.lint.pydocstyle] +convention = "numpy" + [tool.bandit] exclude_dirs = ["*/tests/*", "*/testing/*"] diff --git a/skbase/_exceptions.py b/skbase/_exceptions.py index 50953aaa..d4de229d 100644 --- a/skbase/_exceptions.py +++ b/skbase/_exceptions.py @@ -5,20 +5,18 @@ # conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING """Custom exceptions used in ``skbase``.""" -from typing import List - -__author__: List[str] = ["fkiraly", "mloning", "rnkuhns"] -__all__: List[str] = ["FixtureGenerationError", "NotFittedError"] +__author__: list[str] = ["fkiraly", "mloning", "rnkuhns"] +__all__: list[str] = ["FixtureGenerationError", "NotFittedError"] class FixtureGenerationError(Exception): """Raised when a fixture fails to generate.""" - def __init__(self, fixture_name="", err=None): # noqa: B042 + def __init__(self, fixture_name="", err=None): self.fixture_name = fixture_name self.err = err msg = f"fixture {fixture_name} failed to generate. {err}" - super().__init__(msg) # noqa: B042 + super().__init__(msg) class NotFittedError(ValueError, AttributeError): diff --git a/skbase/base/__init__.py b/skbase/base/__init__.py index 6c80e99f..18b2444e 100644 --- a/skbase/base/__init__.py +++ b/skbase/base/__init__.py @@ -7,8 +7,6 @@ sktime design principles in your project. """ -from typing import List - from skbase.base._base import BaseEstimator, BaseObject from skbase.base._meta import ( BaseMetaEstimator, @@ -17,12 +15,12 @@ BaseMetaObjectMixin, ) -__author__: List[str] = ["mloning", "RNKuhns", "fkiraly"] -__all__: List[str] = [ - "BaseObject", +__author__: list[str] = ["mloning", "RNKuhns", "fkiraly"] +__all__: list[str] = [ "BaseEstimator", "BaseMetaEstimator", - "BaseMetaObject", "BaseMetaEstimatorMixin", + "BaseMetaObject", "BaseMetaObjectMixin", + "BaseObject", ] diff --git a/skbase/base/_base.py b/skbase/base/_base.py index c6a0d16a..3f714708 100644 --- a/skbase/base/_base.py +++ b/skbase/base/_base.py @@ -58,15 +58,15 @@ class name: BaseEstimator import warnings from collections import defaultdict from copy import deepcopy -from typing import List +from typing import ClassVar from skbase._exceptions import NotFittedError from skbase.base._clone_base import _check_clone, _clone from skbase.base._pretty_printing._object_html_repr import _object_html_repr from skbase.base._tagmanager import _FlagManager -__author__: List[str] = ["fkiraly", "mloning", "RNKuhns", "tpvasconcelos"] -__all__: List[str] = ["BaseEstimator", "BaseObject"] +__author__: list[str] = ["fkiraly", "mloning", "RNKuhns", "tpvasconcelos"] +__all__: list[str] = ["BaseEstimator", "BaseObject"] class BaseObject(_FlagManager): @@ -75,7 +75,7 @@ class BaseObject(_FlagManager): Extends scikit-learn's BaseEstimator to include sktime style interface for tags. """ - _config = { + _config: ClassVar[dict] = { "display": "diagram", "print_changed_only": True, "check_clone": False, # whether to execute validity checks in clone @@ -86,7 +86,7 @@ def __init__(self): """Construct BaseObject.""" self._init_flags(flag_attr_name="_tags") self._init_flags(flag_attr_name="_config") - super(BaseObject, self).__init__() + super().__init__() def __eq__(self, other): """Equality dunder. Checks equal class and parameters. @@ -210,7 +210,7 @@ def _get_clone_plugins(cls): in ``skbase.base._clone_plugins``, and implement the methods ``_check`` and ``_clone``. """ - return None + return @classmethod def _get_init_signature(cls): @@ -449,10 +449,10 @@ def _is_suffix(x, y): def _get_alias(x, d): """Return alias of x in d.""" # if key is in valid_params, key is replaced by key (itself) - if any(x == y for y in d.keys()): + if x in d: return x - suff_list = [y for y in d.keys() if _is_suffix(x, y)] + suff_list = [y for y in d if _is_suffix(x, y)] # if key is a __ suffix of exactly one key in valid_params, # it is replaced by that key @@ -468,7 +468,7 @@ def _get_alias(x, d): # if ns == 1 return suff_list[0] - alias_dict = {_get_alias(x, valid_params): d[x] for x in d.keys()} + alias_dict = {_get_alias(x, valid_params): d[x] for x in d} return alias_dict @@ -1222,18 +1222,18 @@ class TagAliaserMixin: # dictionary of aliases # key = old tag; value = new tag, aliased by old tag # override this in a child class - alias_dict = {"old_tag": "new_tag", "tag_to_remove": ""} + alias_dict: ClassVar[dict] = {"old_tag": "new_tag", "tag_to_remove": ""} # dictionary of removal version # key = old tag; value = version in which tag will be removed, as string - deprecate_dict = {"old_tag": "0.12.0", "tag_to_remove": "99.99.99"} + deprecate_dict: ClassVar[dict] = {"old_tag": "0.12.0", "tag_to_remove": "99.99.99"} # package name used for deprecation warnings _package_name = "" def __init__(self): """Construct TagAliaserMixin.""" - super(TagAliaserMixin, self).__init__() + super().__init__() @classmethod def get_class_tags(cls): @@ -1277,7 +1277,7 @@ def get_class_tags(cls): class attribute via nested inheritance. NOT overridden by dynamic tags set by ``set_tags`` or ``clone_tags``. """ - collected_tags = super(TagAliaserMixin, cls).get_class_tags() + collected_tags = super().get_class_tags() cls._deprecate_tag_warn(collected_tags) collected_tags = cls._complete_dict(collected_tags) return collected_tags @@ -1333,7 +1333,7 @@ def get_class_tag(cls, tag_name, tag_value_default=None): old_tag_name = tag_name new_tag_name = alias_dict[old_tag_name] if tag_name in alias_dict.values(): - old_tag_name = [k for k, v in alias_dict.items() if v == tag_name][0] + old_tag_name = next(k for k, v in alias_dict.items() if v == tag_name) new_tag_name = tag_name tag_changed = new_tag_name != old_tag_name @@ -1353,7 +1353,7 @@ def get_class_tag(cls, tag_name, tag_value_default=None): return old_tag_val # case 2: old tag was queried, but old tag not present # then: return value of new tag - elif old_tag_queried: + if old_tag_queried: return cls._get_class_flag( new_tag_name, tag_value_default, @@ -1400,7 +1400,7 @@ def get_tags(self): class attribute via nested inheritance and then any overrides and new tags from ``_tags_dynamic`` object attribute. """ - collected_tags = super(TagAliaserMixin, self).get_tags() + collected_tags = super().get_tags() self._deprecate_tag_warn(collected_tags) collected_tags = self._complete_dict(collected_tags) return collected_tags @@ -1458,7 +1458,7 @@ def get_tag(self, tag_name, tag_value_default=None, raise_error=True): old_tag_name = tag_name new_tag_name = alias_dict[old_tag_name] if tag_name in alias_dict.values(): - old_tag_name = [k for k, v in alias_dict.items() if v == tag_name][0] + old_tag_name = next(k for k, v in alias_dict.items() if v == tag_name) new_tag_name = tag_name tag_changed = new_tag_name != old_tag_name @@ -1479,7 +1479,7 @@ def get_tag(self, tag_name, tag_value_default=None, raise_error=True): return old_tag_val # case 2: old tag was queried, but old tag not present # then: return value of new tag - elif old_tag_queried: + if old_tag_queried: return self._get_flag( new_tag_name, tag_value_default, @@ -1562,8 +1562,7 @@ def _complete_dict(cls, tag_dict, direction="both"): for old_tag in alias_dict: cls._translate_tags(new_tag_dict, tag_dict, old_tag, direction) return new_tag_dict - else: - return tag_dict + return tag_dict @classmethod def _deprecate_tag_warn(cls, tags): @@ -1578,7 +1577,7 @@ def _deprecate_tag_warn(cls, tags): DeprecationWarning for each tag in tags that is aliased by cls.alias_dict """ for tag_name in tags: - if tag_name in cls.alias_dict.keys(): + if tag_name in cls.alias_dict: version = cls.deprecate_dict[tag_name] new_tag = cls.alias_dict[tag_name] pkg_name = cls._package_name @@ -1656,8 +1655,7 @@ def is_fitted(self): """ if hasattr(self, "_is_fitted"): return self._is_fitted - else: - return False + return False def check_is_fitted(self, method_name=None): """Check if the estimator has been fitted. @@ -1818,8 +1816,7 @@ def getattr_safe(obj, attr): if hasattr(obj, attr): attr = getattr(obj, attr) return attr, True - else: - return None, False + return None, False except Exception: return None, False diff --git a/skbase/base/_clone_base.py b/skbase/base/_clone_base.py index d3ed05f3..5cd759a2 100644 --- a/skbase/base/_clone_base.py +++ b/skbase/base/_clone_base.py @@ -20,7 +20,7 @@ * clone(obj) -> type(obj) - method to clone obj """ -__all__ = ["_clone", "_check_clone"] +__all__ = ["_check_clone", "_clone"] from skbase.base._clone_plugins import DEFAULT_CLONE_PLUGINS @@ -110,7 +110,7 @@ def _check_clone(original, clone): self_params = original.get_params(deep=False) # check that all attributes are written to the clone - for attrname in self_params.keys(): + for attrname in self_params: if not hasattr(clone, attrname): raise RuntimeError( f"error in {original}.clone, __init__ must write all arguments " @@ -118,7 +118,7 @@ def _check_clone(original, clone): f"Please check __init__ of {original}." ) - clone_attrs = {attr: getattr(clone, attr) for attr in self_params.keys()} + clone_attrs = {attr: getattr(clone, attr) for attr in self_params} # check equality of parameters post-clone and pre-clone clone_attrs_valid, msg = deep_equals(self_params, clone_attrs, return_msg=True) diff --git a/skbase/base/_clone_plugins.py b/skbase/base/_clone_plugins.py index c4165ae9..3885433b 100644 --- a/skbase/base/_clone_plugins.py +++ b/skbase/base/_clone_plugins.py @@ -15,13 +15,13 @@ * clone(obj) -> type(obj) - method to clone obj """ -from functools import lru_cache +from functools import cache from inspect import isclass # imports wrapped in functions to avoid exceptions on skbase init # wrapped in _safe_import to avoid exceptions on skbase init -@lru_cache(maxsize=None) +@cache def _is_sklearn_present(): """Check whether scikit-learn is present.""" from skbase.utils.dependencies import _check_soft_dependencies @@ -29,7 +29,7 @@ def _is_sklearn_present(): return _check_soft_dependencies("scikit-learn") -@lru_cache(maxsize=None) +@cache def _get_sklearn_clone(): """Get sklearn's clone function.""" from skbase.utils.dependencies._import import _safe_import @@ -196,13 +196,12 @@ def _clone(self, obj): if not self.safe: return deepcopy(obj) - else: - raise TypeError( - "Cannot clone object '%s' (type %s): " - "it does not seem to be a scikit-base object or scikit-learn " - "estimator, as it does not implement a " - "'get_params' method." % (repr(obj), type(obj)) - ) + raise TypeError( + "Cannot clone object '%s' (type %s): " + "it does not seem to be a scikit-base object or scikit-learn " + "estimator, as it does not implement a " + "'get_params' method." % (repr(obj), type(obj)) + ) DEFAULT_CLONE_PLUGINS = [ diff --git a/skbase/base/_meta.py b/skbase/base/_meta.py index 3a4e70db..dfef2493 100644 --- a/skbase/base/_meta.py +++ b/skbase/base/_meta.py @@ -10,6 +10,7 @@ """Implements functionality for meta objects composed of other objects.""" from inspect import isclass +from typing import ClassVar from skbase.base._base import BaseEstimator, BaseObject from skbase.base._pretty_printing._object_html_repr import _VisualBlock @@ -39,7 +40,7 @@ class has values that follow the named object specification. For example, # _steps_attr points to the attribute of self # which contains the heterogeneous set of estimators # this must be an iterable of (name: str, estimator) pairs for the default - _tags = {"named_object_parameters": "steps"} + _tags: ClassVar[dict] = {"named_object_parameters": "steps"} def is_composite(self): """Check if the object is composite. @@ -139,9 +140,7 @@ def _get_fitted_params(self): """ fitted_params = self._get_fitted_params_default() - fitted_named_object_attr = self.get_tag( - "fitted_named_object_parameters" - ) # type: ignore + fitted_named_object_attr = self.get_tag("fitted_named_object_parameters") # type: ignore named_objects_fitted_params = self._get_params( fitted_named_object_attr, fitted=True @@ -253,7 +252,7 @@ def _set_params(self, attr: str, **params): items = getattr(self, attr) names = [] if items and isinstance(items, (list, tuple)): - names = list(zip(*items))[0] + names = next(zip(*items, strict=False)) for name in list(params.keys()): if "__" not in name and name in names: self._replace_object(attr, name, params.pop(name)) @@ -323,14 +322,14 @@ def _check_names(self, names, make_unique=True): A sequence of unique string names that follow named object API rules. """ if len(set(names)) != len(names): - raise ValueError("Names provided are not unique: {0!r}".format(list(names))) + raise ValueError(f"Names provided are not unique: {list(names)!r}") # Get names that match direct parameter invalid_names = set(names).intersection(self.get_params(deep=False)) invalid_names = invalid_names.union({name for name in names if "__" in name}) if invalid_names: raise ValueError( "Object names conflict with constructor argument or " - "contain '__': {0!r}".format(sorted(invalid_names)) + f"contain '__': {sorted(invalid_names)!r}" ) if make_unique: names = make_strings_unique(names) @@ -363,10 +362,7 @@ def _coerce_object_tuple(self, obj, clone=False): name = obj[0] else: - if isinstance(obj, tuple) and len(obj) == 1: - _obj = obj[0] - else: - _obj = obj + _obj = obj[0] if isinstance(obj, tuple) and len(obj) == 1 else obj name = type(_obj).__name__ if clone: @@ -460,22 +456,25 @@ def is_obj_is_tuple(obj): # We've already guarded against objs being dict when allow_dict is False # So here we can just check dictionary elements - if isinstance(objs, dict) and not all( - isinstance(name, str) and isinstance(obj, cls_type) - for name, obj in objs.items() - ): - raise TypeError(msg) - - elif not all(any(is_obj_is_tuple(x)) for x in objs): + if ( + isinstance(objs, dict) + and not all( + isinstance(name, str) and isinstance(obj, cls_type) + for name, obj in objs.items() + ) + ) or not all(any(is_obj_is_tuple(x)) for x in objs): raise TypeError(msg) msg_no_mix = ( f"Elements of {attr_name} must either all be objects, " f"or all (str, objects) tuples. A mix of the two is not allowed." ) - if not allow_mix and not all(is_obj_is_tuple(x)[0] for x in objs): - if not all(is_obj_is_tuple(x)[1] for x in objs): - raise TypeError(msg_no_mix) + if ( + not allow_mix + and not all(is_obj_is_tuple(x)[0] for x in objs) + and not all(is_obj_is_tuple(x)[1] for x in objs) + ): + raise TypeError(msg_no_mix) return self._coerce_to_named_object_tuples(objs, clone=clone, make_unique=True) @@ -500,9 +499,11 @@ def _get_names_and_objects(self, named_objects, make_unique=False): The """ if isinstance(named_objects, dict): - names, objs = zip(*named_objects.items()) + names, objs = zip(*named_objects.items(), strict=False) else: - names, objs = zip(*[self._coerce_object_tuple(x) for x in named_objects]) + names, objs = zip( + *[self._coerce_object_tuple(x) for x in named_objects], strict=False + ) # Optionally make names unique if make_unique: @@ -555,7 +556,7 @@ def _coerce_to_named_object_tuples(self, objs, clone=False, make_unique=True): named_objects, make_unique=make_unique ) # Repack the objects - named_objects = list(zip(names, objs)) + named_objects = list(zip(names, objs, strict=False)) return named_objects def _dunder_concat( @@ -632,8 +633,7 @@ def _dunder_concat( def concat(x, y): if concat_order == "left": return x + y - else: - return y + x + return y + x # get attr_name from self and other # can be list of ests, list of (str, est) tuples, or list of mixture of these @@ -658,16 +658,15 @@ def concat(x, y): new_objs = concat(self_objs, other_objs) # create the "steps" param for the composite # if all the names are equal to class names, we eat them away - if all(type(x[1]).__name__ == x[0] for x in zip(new_names, new_objs)): + if all( + type(x[1]).__name__ == x[0] for x in zip(new_names, new_objs, strict=False) + ): step_param = {attr_name: list(new_objs)} else: - step_param = {attr_name: list(zip(new_names, new_objs))} + step_param = {attr_name: list(zip(new_names, new_objs, strict=False))} # retrieve other parameters, from composite_params attribute - if composite_params is None: - composite_params = {} - else: - composite_params = composite_params.copy() + composite_params = {} if composite_params is None else composite_params.copy() # construct the composite with both step and additional params composite_params.update(step_param) @@ -806,7 +805,7 @@ def _tagchain_is_linked( for _, est in estimators: if est.get_tag(mid_tag_name) == mid_tag_val: return True, True - if not est.get_tag(left_tag_name) == left_tag_val: + if est.get_tag(left_tag_name) != left_tag_val: return False, False return True, False diff --git a/skbase/base/_pretty_printing/__init__.py b/skbase/base/_pretty_printing/__init__.py index c94c1888..1d46b64c 100644 --- a/skbase/base/_pretty_printing/__init__.py +++ b/skbase/base/_pretty_printing/__init__.py @@ -6,7 +6,5 @@ # conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING """Functionality for pretty printing BaseObjects.""" -from typing import List - -__author__: List[str] = ["RNKuhns"] -__all__: List[str] = [] +__author__: list[str] = ["RNKuhns"] +__all__: list[str] = [] diff --git a/skbase/base/_pretty_printing/_object_html_repr.py b/skbase/base/_pretty_printing/_object_html_repr.py index 1c4fc017..d882d3cc 100644 --- a/skbase/base/_pretty_printing/_object_html_repr.py +++ b/skbase/base/_pretty_printing/_object_html_repr.py @@ -100,7 +100,7 @@ def _get_visual_block(base_object): return _VisualBlock( "single", base_object, names=base_object, name_details=base_object ) - elif base_object is None: + if base_object is None: return _VisualBlock("single", base_object, names="None", name_details="None") # check if base_object looks like a meta base_object wraps base_object @@ -137,7 +137,9 @@ def _write_base_object_html( kind = est_block.kind out.write(f'
') - est_infos = zip(est_block.estimators, est_block.names, est_block.name_details) + est_infos = zip( + est_block.estimators, est_block.names, est_block.name_details, strict=False + ) for est, name, name_details in est_infos: if kind == "serial": @@ -335,7 +337,7 @@ def _write_base_object_html( #$id div.sk-text-repr-fallback { display: none; } -""".replace(" ", "").replace("\n", "") # noqa +""".replace(" ", "").replace("\n", "") def _object_html_repr(base_object): diff --git a/skbase/base/_pretty_printing/_pprint.py b/skbase/base/_pretty_printing/_pprint.py index b1a575bc..66a2889b 100644 --- a/skbase/base/_pretty_printing/_pprint.py +++ b/skbase/base/_pretty_printing/_pprint.py @@ -27,8 +27,6 @@ def __repr__(self): class KeyValTupleParam(KeyValTuple): """Dummy class for correctly rendering key-value tuples from parameters.""" - pass - def _changed_params(base_object): """Return dict (param_name: value) of parameters with non-default values.""" @@ -48,11 +46,9 @@ def has_changed(k, v): if isinstance(v, BaseObject) and v.__class__ != init_params[k].__class__: return True # Use repr as a last resort. It may be expensive. - if repr(v) != repr(init_params[k]) and not ( + return repr(v) != repr(init_params[k]) and not ( _is_scalar_nan(init_params[k]) and _is_scalar_nan(v) - ): - return True - return False + ) return {k: v for k, v in params.items() if has_changed(k, v)} @@ -131,7 +127,7 @@ def __init__( # (they are treated as dicts) self.n_max_elements_to_show = n_max_elements_to_show - def format(self, obj, context, maxlevels, level): # noqa + def format(self, obj, context, maxlevels, level): return _safe_repr( obj, context, maxlevels, level, changed_only=self.changed_only ) @@ -386,10 +382,7 @@ def _safe_repr(obj, context, maxlevels, level, changed_only=False): context[objid] = 1 readable = True recursive = False - if changed_only: - params = _changed_params(obj) - else: - params = obj.get_params(deep=False) + params = _changed_params(obj) if changed_only else obj.get_params(deep=False) components = [] append = components.append level += 1 diff --git a/skbase/base/_pretty_printing/tests/test_pprint.py b/skbase/base/_pretty_printing/tests/test_pprint.py index a7322c18..7c922916 100644 --- a/skbase/base/_pretty_printing/tests/test_pprint.py +++ b/skbase/base/_pretty_printing/tests/test_pprint.py @@ -15,7 +15,7 @@ def __init__(self, foo, bar=84): self.foo = foo self.bar = bar - super(CompositionDummy, self).__init__() + super().__init__() @pytest.mark.skipif( diff --git a/skbase/base/_tagmanager.py b/skbase/base/_tagmanager.py index 87c2af23..757772f0 100644 --- a/skbase/base/_tagmanager.py +++ b/skbase/base/_tagmanager.py @@ -144,7 +144,7 @@ def _get_flag( flag_value = collected_flags.get(flag_name, flag_value_default) - if raise_error and flag_name not in collected_flags.keys(): + if raise_error and flag_name not in collected_flags: raise ValueError(f"Tag with name {flag_name} could not be found.") return flag_value @@ -209,7 +209,7 @@ def _clone_flags(self, estimator, flag_names=None, flag_attr_name="_flags"): # if flag_set is passed, intersect keys with flags in estimator if not isinstance(flag_names, list): flag_names = [flag_names] - flag_names = [key for key in flag_names if key in flags_est.keys()] + flag_names = [key for key in flag_names if key in flags_est] update_dict = {key: flags_est[key] for key in flag_names} diff --git a/skbase/lookup/__init__.py b/skbase/lookup/__init__.py index 726cd29b..b4b335fc 100644 --- a/skbase/lookup/__init__.py +++ b/skbase/lookup/__init__.py @@ -17,12 +17,11 @@ # is based on the sklearn estimator retrieval utility of the same name # See https://github.com/scikit-learn/scikit-learn/blob/main/COPYING and # https://github.com/sktime/sktime/blob/main/LICENSE -from typing import List from skbase.lookup._lookup import all_objects, get_package_metadata -__all__: List[str] = ["all_objects", "get_package_metadata"] -__author__: List[str] = [ +__all__: list[str] = ["all_objects", "get_package_metadata"] +__author__: list[str] = [ "fkiraly", "mloning", "katiebuc", diff --git a/skbase/lookup/_lookup.py b/skbase/lookup/_lookup.py index 20175fdd..d697b1a2 100644 --- a/skbase/lookup/_lookup.py +++ b/skbase/lookup/_lookup.py @@ -22,19 +22,19 @@ import pkgutil import re import warnings -from collections.abc import Iterable +from collections.abc import Iterable, Mapping, MutableMapping, Sequence from copy import deepcopy from functools import lru_cache from operator import itemgetter from types import ModuleType -from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union +from typing import Any from skbase.base import BaseObject from skbase.utils.stdout_mute import StdoutMute from skbase.validate import check_sequence -__all__: List[str] = ["all_objects", "get_package_metadata"] -__author__: List[str] = [ +__all__: list[str] = ["all_objects", "get_package_metadata"] +__author__: list[str] = [ "fkiraly", "mloning", "katiebuc", @@ -106,7 +106,7 @@ def _is_non_public_module(module_name: str) -> bool: def _is_ignored_module( - module_name: str, modules_to_ignore: Union[str, List[str], Tuple[str]] = None + module_name: str, modules_to_ignore: str | list[str] | tuple[str] | None = None ) -> bool: """Determine if module is one of the ignored modules. @@ -141,7 +141,7 @@ def _is_ignored_module( def _filter_by_class( - klass: type, class_filter: Optional[Union[type, Sequence[type]]] = None + klass: type, class_filter: type | Sequence[type] | None = None ) -> bool: """Determine if a class is a subclass of the supplied classes. @@ -217,7 +217,7 @@ def _filter_by_tags(obj, tag_filter=None, as_dataframe=True): # case: tag_filter is dict # check that all keys are str - if not all(isinstance(t, str) for t in tag_filter.keys()): + if not all(isinstance(t, str) for t in tag_filter): raise ValueError(f"{type_msg} {tag_filter}") cond_sat = True @@ -279,7 +279,7 @@ def _walk(root, exclude=None, prefix=""): def _import_module( - module: Union[str, importlib.machinery.SourceFileLoader], + module: str | importlib.machinery.SourceFileLoader, suppress_import_stdout: bool = True, ) -> ModuleType: """Import a module, while optionally suppressing import standard out. @@ -319,8 +319,8 @@ def _import_module( def _determine_module_path( - package_name: str, path: Optional[Union[str, pathlib.Path]] = None -) -> Tuple[ModuleType, str, importlib.machinery.SourceFileLoader]: + package_name: str, path: str | pathlib.Path | None = None +) -> tuple[ModuleType, str, importlib.machinery.SourceFileLoader]: """Determine a package's path information. Parameters @@ -402,11 +402,11 @@ def _get_module_info( module: ModuleType, is_pkg: bool, path: str, - package_base_classes: Union[type, Tuple[type, ...]], + package_base_classes: type | tuple[type, ...], exclude_non_public_items: bool = True, - class_filter: Optional[Union[type, Sequence[type]]] = None, - tag_filter: Optional[Union[str, Sequence[str], Mapping[str, Any]]] = None, - classes_to_exclude: Optional[Union[type, Sequence[type]]] = None, + class_filter: type | Sequence[type] | None = None, + tag_filter: str | Sequence[str] | Mapping[str, Any] | None = None, + classes_to_exclude: type | Sequence[type] | None = None, ) -> dict: # of ModuleInfo type # Make package_base_classes a tuple if it was supplied as a class base_classes_none = False @@ -417,16 +417,14 @@ def _get_module_info( base_classes_none = True package_base_classes = (package_base_classes,) - exclude_classes: Optional[Sequence[type]] - if classes_to_exclude is None: - exclude_classes = classes_to_exclude - elif isinstance(classes_to_exclude, Sequence): + exclude_classes: Sequence[type] | None + if classes_to_exclude is None or isinstance(classes_to_exclude, Sequence): exclude_classes = classes_to_exclude elif inspect.isclass(classes_to_exclude): exclude_classes = (classes_to_exclude,) - designed_imports: List[str] = getattr(module, "__all__", []) - authors: Union[str, List[str]] = getattr(module, "__author__", []) + designed_imports: list[str] = getattr(module, "__all__", []) + authors: str | list[str] = getattr(module, "__author__", []) if isinstance(authors, (list, tuple)): authors = ", ".join(authors) # Compile information on classes in the module @@ -575,15 +573,15 @@ def _get_members_uw(module, predicate=None): def get_package_metadata( package_name: str, - path: Optional[str] = None, + path: str | None = None, recursive: bool = True, exclude_non_public_items: bool = True, exclude_non_public_modules: bool = True, - modules_to_ignore: Union[str, List[str], Tuple[str]] = ("tests",), - package_base_classes: Union[type, Tuple[type, ...]] = (BaseObject,), - class_filter: Optional[Union[type, Sequence[type]]] = None, - tag_filter: Optional[Union[str, Sequence[str], Mapping[str, Any]]] = None, - classes_to_exclude: Optional[Union[type, Sequence[type]]] = None, + modules_to_ignore: str | list[str] | tuple[str] = ("tests",), + package_base_classes: type | tuple[type, ...] = (BaseObject,), + class_filter: type | Sequence[type] | None = None, + tag_filter: str | Sequence[str] | Mapping[str, Any] | None = None, + classes_to_exclude: type | Sequence[type] | None = None, suppress_import_stdout: bool = True, ) -> Mapping: # of ModuleInfo type """Return a dictionary mapping all package modules to their metadata. @@ -721,10 +719,7 @@ def get_package_metadata( continue if recursive and is_pkg: - if "." in name: - name_ending = name[len(package_name) + 1 :] - else: - name_ending = name + name_ending = name[len(package_name) + 1 :] if "." in name else name updated_path: str if "." in name_ending: @@ -759,7 +754,7 @@ def all_objects( return_tags=None, suppress_import_stdout=True, package_name="skbase", - path: Optional[str] = None, + path: str | None = None, modules_to_ignore=None, class_lookup=None, ): @@ -963,13 +958,12 @@ class name if ``return_names=False`` and ``return_tags is not None``. if all_estimators: if isinstance(all_estimators[0], tuple): all_estimators = [ - (name, est) + _get_return_tags(est, return_tags) + (name, est, *_get_return_tags(est, return_tags)) for (name, est) in all_estimators ] else: all_estimators = [ - (est,) + _get_return_tags(est, return_tags) - for est in all_estimators + (est, *_get_return_tags(est, return_tags)) for est in all_estimators ] columns = columns + return_tags @@ -1027,20 +1021,19 @@ def _get_err_msg(estimator_type): if class_lookup is None or not isinstance(class_lookup, dict): return ( f"Parameter `estimator_type` must be None, a class, or a list of " - f"class, but found: {repr(estimator_type)}" - ) - else: - return ( - f"Parameter `estimator_type` must be None, a string, a class, or a list" - f" of [string or class]. Valid string values are: " - f"{tuple(class_lookup.keys())}, but found: " - f"{repr(estimator_type)}" + f"class, but found: {estimator_type!r}" ) + return ( + f"Parameter `estimator_type` must be None, a string, a class, or a list" + f" of [string or class]. Valid string values are: " + f"{tuple(class_lookup.keys())}, but found: " + f"{estimator_type!r}" + ) for i, estimator_type in enumerate(object_types): if isinstance(estimator_type, str): if not isinstance(class_lookup, dict) or ( - estimator_type not in class_lookup.keys() + estimator_type not in class_lookup ): raise ValueError(_get_err_msg(estimator_type)) estimator_type = class_lookup[estimator_type] @@ -1079,7 +1072,7 @@ class StdoutMuteNCatchMNF(StdoutMute): except catch and suppress ModuleNotFoundError. """ - def _handle_exit_exceptions(self, type, value, traceback): # noqa: A002 + def _handle_exit_exceptions(self, type, value, traceback): """Handle exceptions raised during __exit__. Parameters @@ -1104,12 +1097,11 @@ def _handle_exit_exceptions(self, type, value, traceback): # noqa: A002 def _coerce_to_tuple(x): if x is None: return () - elif isinstance(x, tuple): + if isinstance(x, tuple): return x - elif isinstance(x, list): + if isinstance(x, list): return tuple(x) - else: - return (x,) + return (x,) @lru_cache(maxsize=100) @@ -1144,7 +1136,7 @@ def _walk_and_retrieve_all_objs(root, package_name, modules_to_ignore): prefix = f"{package_name}." def _is_base_class(name): - return name.startswith("_") or name.startswith("Base") + return name.startswith(("_", "Base")) all_estimators = [] diff --git a/skbase/lookup/tests/test_lookup.py b/skbase/lookup/tests/test_lookup.py index fc029316..5508665e 100644 --- a/skbase/lookup/tests/test_lookup.py +++ b/skbase/lookup/tests/test_lookup.py @@ -11,7 +11,6 @@ import sys from copy import deepcopy from types import ModuleType -from typing import List import pandas as pd import pytest @@ -46,8 +45,8 @@ NotABaseObject, ) -__author__: List[str] = ["RNKuhns", "fkiraly"] -__all__: List[str] = [] +__author__: list[str] = ["RNKuhns", "fkiraly"] +__all__: list[str] = [] MODULE_METADATA_EXPECTED_KEYS = ( @@ -186,10 +185,7 @@ def _check_package_metadata_result(results): isinstance(k, str) and isinstance(v, dict) for k, v in mod_metadata["classes"].items() ) - ): - return False - # Then verify sub-dict values for each class have required keys - elif not all( + ) or not all( k in c_meta for c_meta in mod_metadata["classes"].values() for k in REQUIRED_CLASS_METADATA_KEYS @@ -202,10 +198,7 @@ def _check_package_metadata_result(results): isinstance(k, str) and isinstance(v, dict) for k, v in mod_metadata["functions"].items() ) - ): - return False - # Then verify sub-dict values for each function have required keys - elif not all( + ) or not all( k in f_meta for f_meta in mod_metadata["functions"].values() for k in REQUIRED_FUNCTION_METADATA_KEYS @@ -415,9 +408,9 @@ def test_walk_returns_expected_format(fixture_skbase_root_path): """Check walk function returns expected format.""" def _test_walk_return(p): - assert ( - isinstance(p, tuple) and len(p) == 3 - ), "_walk should return tuple of length 3" + assert isinstance(p, tuple) and len(p) == 3, ( + "_walk should return tuple of length 3" + ) assert ( isinstance(p[0], str) and isinstance(p[1], bool) @@ -790,13 +783,13 @@ def test_get_package_metadata_returns_expected_results( # Verify class metadata attributes correct for klass, klass_metadata in results[module]["classes"].items(): if klass_metadata["klass"] in SKBASE_BASE_CLASSES: - assert ( - klass_metadata["is_base_class"] is True - ), f"{klass} should be base class." + assert klass_metadata["is_base_class"] is True, ( + f"{klass} should be base class." + ) else: - assert ( - klass_metadata["is_base_class"] is False - ), f"{klass} should not be base class." + assert klass_metadata["is_base_class"] is False, ( + f"{klass} should not be base class." + ) if issubclass(klass_metadata["klass"], BaseObject): assert klass_metadata["is_base_object"] is True @@ -982,13 +975,12 @@ def test_all_objects_returns_expected_types( if isinstance(modules_to_ignore, str): modules_to_ignore = (modules_to_ignore,) if ( - modules_to_ignore is not None - and "tests" in modules_to_ignore + modules_to_ignore is not None and "tests" in modules_to_ignore # and "mock_package" in modules_to_ignore ): - assert ( - len(objs) == 0 - ), "Search of `skbase` should only return objects from tests module." + assert len(objs) == 0, ( + "Search of `skbase` should only return objects from tests module." + ) else: # We expect at least one object to be returned so we verify output type/format _check_all_object_output_types( diff --git a/skbase/testing/__init__.py b/skbase/testing/__init__.py index 012948fb..d9193026 100644 --- a/skbase/testing/__init__.py +++ b/skbase/testing/__init__.py @@ -1,13 +1,11 @@ # -*- coding: utf-8 -*- """:mod:`skbase.testing` provides a framework to test ``BaseObject``-s.""" -from typing import List - from skbase.testing.test_all_objects import ( BaseFixtureGenerator, QuickTester, TestAllObjects, ) -__all__: List[str] = ["BaseFixtureGenerator", "QuickTester", "TestAllObjects"] -__author__: List[str] = ["fkiraly"] +__all__: list[str] = ["BaseFixtureGenerator", "QuickTester", "TestAllObjects"] +__author__: list[str] = ["fkiraly"] diff --git a/skbase/testing/test_all_objects.py b/skbase/testing/test_all_objects.py index a1bdeea9..021eebbf 100644 --- a/skbase/testing/test_all_objects.py +++ b/skbase/testing/test_all_objects.py @@ -9,7 +9,7 @@ import types from copy import deepcopy from inspect import getfullargspec, isclass, signature -from typing import List +from typing import ClassVar import numpy as np import pytest @@ -23,7 +23,7 @@ from skbase.utils.deep_equals import deep_equals from skbase.utils.dependencies import _check_soft_dependencies -__author__: List[str] = ["fkiraly"] +__author__: list[str] = ["fkiraly"] class BaseFixtureGenerator: @@ -106,12 +106,12 @@ class BaseFixtureGenerator: valid_base_types = None # which sequence the conditional fixtures are generated in - fixture_sequence = ["object_class", "object_instance"] + fixture_sequence: ClassVar[list[str]] = ["object_class", "object_instance"] # which fixtures are indirect, e.g., have an additional pytest.fixture block # to generate an indirect fixture at runtime. Example: object_instance # warning: direct fixtures retain state changes within the same test - indirect_fixtures = ["object_instance"] + indirect_fixtures: ClassVar[list[str]] = ["object_instance"] def pytest_generate_tests(self, metafunc): """Test parameterization routine for pytest. @@ -176,7 +176,7 @@ def generator_dict(self): fixts = [gen.replace("_generate_", "") for gen in gens] generator_dict = {} - for var, gen in zip(fixts, gens): + for var, gen in zip(fixts, gens, strict=False): generator_dict[var] = getattr(self, gen) return generator_dict @@ -185,8 +185,7 @@ def is_excluded(self, test_name, est): """Shorthand to check whether test test_name is excluded for object est.""" if self.excluded_tests is None: return False - else: - return test_name in self.excluded_tests.get(est.__name__, []) + return test_name in self.excluded_tests.get(est.__name__, []) # the following functions define fixture generation logic for pytest_generate_tests # each function is of signature (test_name:str, **kwargs) -> List of fixtures @@ -419,36 +418,38 @@ def run_tests( # if function is decorated with mark.parametrize, add variable settings # NOTE: currently this works only with single-variable mark.parametrize - if hasattr(test_fun, "pytestmark"): - if len([x for x in test_fun.pytestmark if x.name == "parametrize"]) > 0: - # get the three lists from pytest - ( - pytest_fixture_vars, - pytest_fixture_prod, - pytest_fixture_names, - ) = self._get_pytest_mark_args(test_fun) - # add them to the three lists from conditional fixtures - fixture_vars, fixture_prod, fixture_names = self._product_fixtures( - fixture_vars, - fixture_prod, - fixture_names, - pytest_fixture_vars, - pytest_fixture_prod, - pytest_fixture_names, - ) + if ( + hasattr(test_fun, "pytestmark") + and len([x for x in test_fun.pytestmark if x.name == "parametrize"]) > 0 + ): + # get the three lists from pytest + ( + pytest_fixture_vars, + pytest_fixture_prod, + pytest_fixture_names, + ) = self._get_pytest_mark_args(test_fun) + # add them to the three lists from conditional fixtures + fixture_vars, fixture_prod, fixture_names = self._product_fixtures( + fixture_vars, + fixture_prod, + fixture_names, + pytest_fixture_vars, + pytest_fixture_prod, + pytest_fixture_names, + ) def print_if_verbose(msg): if int(verbose) > 0: - print(msg) # noqa: T001, T201 + print(msg) # noqa: T201 # loop B: for each test, we loop over all fixtures - for params, fixt_name in zip(fixture_prod, fixture_names): + for params, fixt_name in zip(fixture_prod, fixture_names, strict=False): # this is needed because pytest unwraps 1-tuples automatically # but subsequent code assumes params is k-tuple, no matter what k is if len(fixture_vars) == 1: params = (params,) key = f"{test_name}[{fixt_name}]" - args = dict(zip(fixture_vars, params)) + args = dict(zip(fixture_vars, params, strict=False)) for f in test_fun_vars: if f not in args: @@ -509,10 +510,7 @@ def _subset_generator_dict(obj, generator_dict): """ obj_generator_dict = generator_dict - if isclass(obj): - object_class = obj - else: - object_class = type(obj) + object_class = obj if isclass(obj) else type(obj) def _generate_object_class(test_name, **kwargs): return [object_class], [object_class.__name__] @@ -573,10 +571,9 @@ def to_str(obj): return [str(x) for x in obj] def get_id(mark): - if "ids" in mark.kwargs.keys(): + if "ids" in mark.kwargs: return mark.kwargs["ids"] - else: - return to_str(range(len(mark.args[1]))) + return to_str(range(len(mark.args[1]))) pytest_fixture_vars = [x.args[0] for x in marks] pytest_fixt_raw = [x.args[1] for x in marks] @@ -618,16 +615,16 @@ def _product_fixtures( return fixture_vars_return, fixture_prod_return, fixture_names_return def _make_builtin_fixture_equivalents(self, name): - """Utility for QuickTester, creates equivalent fixtures for pytest runs.""" + """Create equivalent fixtures for pytest runs.""" import io import logging import tempfile from pathlib import Path values = {} - if "tmp_path" == name: + if name == "tmp_path": return Path(tempfile.mkdtemp()) - if "capsys" == name: + if name == "capsys": # crude emulation using StringIO return type( "Capsys", @@ -639,12 +636,12 @@ def _make_builtin_fixture_equivalents(self, name): }, )() - if "monkeypatch" == name: + if name == "monkeypatch": from _pytest.monkeypatch import MonkeyPatch return MonkeyPatch() - if "caplog" == name: + if name == "caplog": class Caplog: def __init__(self): @@ -735,7 +732,7 @@ def test_object_tags(self, object_class): assert hasattr(object_class, "get_class_tags") all_tags = object_class.get_class_tags() assert isinstance(all_tags, dict) - assert all(isinstance(key, str) for key in all_tags.keys()) + assert all(isinstance(key, str) for key in all_tags) if hasattr(object_class, "_tags"): tags = object_class._tags msg = ( @@ -747,9 +744,7 @@ def test_object_tags(self, object_class): if self.valid_tags is None: invalid_tags = tags else: - invalid_tags = [ - tag for tag in tags.keys() if tag not in self.valid_tags - ] + invalid_tags = [tag for tag in tags if tag not in self.valid_tags] assert len(invalid_tags) == 0, ( f"_tags of {object_class} contains invalid tags: {invalid_tags}. " f"For a list of valid tags, see {self.__class__.__name__}.valid_tags." @@ -766,7 +761,7 @@ def test_object_tags(self, object_class): def test_inheritance(self, object_class): """Check that object inherits from BaseObject.""" assert issubclass(object_class, BaseObject), ( - f"object: {object_class} " f"is not a sub-class of " f"BaseObject." + f"object: {object_class} is not a sub-class of BaseObject." ) # Usually should inherit only from one BaseObject type if self.valid_base_types is not None: @@ -988,13 +983,13 @@ def param_filter(p): def test_valid_object_class_tags(self, object_class): """Check that object class tags are in self.valid_tags.""" if self.valid_tags is None: - return None - for tag in object_class.get_class_tags().keys(): + return + for tag in object_class.get_class_tags(): assert tag in self.valid_tags def test_valid_object_tags(self, object_instance): """Check that object tags are in self.valid_tags.""" if self.valid_tags is None: - return None - for tag in object_instance.get_tags().keys(): + return + for tag in object_instance.get_tags(): assert tag in self.valid_tags diff --git a/skbase/testing/utils/__init__.py b/skbase/testing/utils/__init__.py index 550c8c89..6fcd6756 100644 --- a/skbase/testing/utils/__init__.py +++ b/skbase/testing/utils/__init__.py @@ -1,6 +1,4 @@ # -*- coding: utf-8 -*- """Utilities for the test framework.""" -from typing import List - -__all__: List[str] = [] +__all__: list[str] = [] diff --git a/skbase/testing/utils/_conditional_fixtures.py b/skbase/testing/utils/_conditional_fixtures.py index 01689311..53b21112 100644 --- a/skbase/testing/utils/_conditional_fixtures.py +++ b/skbase/testing/utils/_conditional_fixtures.py @@ -4,22 +4,22 @@ Exports create_conditional_fixtures_and_names utility """ +from collections.abc import Callable from copy import deepcopy -from typing import Callable, Dict, List from skbase._exceptions import FixtureGenerationError from skbase.utils._nested_iter import _remove_single from skbase.validate import check_sequence -__author__: List[str] = ["fkiraly"] -__all__: List[str] = ["create_conditional_fixtures_and_names"] +__author__: list[str] = ["fkiraly"] +__all__: list[str] = ["create_conditional_fixtures_and_names"] def create_conditional_fixtures_and_names( test_name: str, - fixture_vars: List[str], - generator_dict: Dict[str, Callable], - fixture_sequence: List[str] = None, + fixture_vars: list[str], + generator_dict: dict[str, Callable], + fixture_sequence: list[str] | None = None, raise_exceptions: bool = False, deepcopy_fixtures: bool = False, ): @@ -101,7 +101,7 @@ def create_conditional_fixtures_and_names( fixture_vars = check_sequence( fixture_vars, sequence_type=list, element_type=str, sequence_name="fixture_vars" ) - fixture_vars = [var for var in fixture_vars if var in generator_dict.keys()] + fixture_vars = [var for var in fixture_vars if var in generator_dict] # order fixture_vars according to fixture_sequence if provided if fixture_sequence is not None: @@ -153,7 +153,7 @@ def get_fixtures(fixture_var, **kwargs): except Exception as err: error = FixtureGenerationError(fixture_name=fixture_var, err=err) if raise_exceptions: - raise error + raise error from err fixture_prod = [error] fixture_names = [f"Error:{fixture_var}"] @@ -176,12 +176,12 @@ def get_fixtures(fixture_var, **kwargs): if i == 0: kwargs = {} else: - kwargs = dict(zip(old_fixture_vars, fixture)) + kwargs = dict(zip(old_fixture_vars, fixture, strict=False)) # retrieve conditional fixtures, conditional on fixture values in kwargs new_fixtures, new_fixture_names_r = get_fixtures(fixture_var, **kwargs) # new fixture values are concatenation/product of old values plus new new_fixture_prod += [ - fixture + (new_fixture,) for new_fixture in new_fixtures + (*fixture, new_fixture) for new_fixture in new_fixtures ] # new fixture name is concatenation of name so far and "dash-new name" # if the new name is empty string, don't add a dash diff --git a/skbase/tests/conftest.py b/skbase/tests/conftest.py index dec1b79b..6f253196 100644 --- a/skbase/tests/conftest.py +++ b/skbase/tests/conftest.py @@ -1,20 +1,20 @@ # -*- coding: utf-8 -*- """Common functionality for skbase unit tests.""" -from typing import List +from typing import ClassVar from skbase.base import BaseEstimator, BaseObject -__all__: List[str] = [ +__all__: list[str] = [ "SKBASE_BASE_CLASSES", + "SKBASE_CLASSES_BY_MODULE", + "SKBASE_FUNCTIONS_BY_MODULE", "SKBASE_MODULES", - "SKBASE_PUBLIC_MODULES", "SKBASE_PUBLIC_CLASSES_BY_MODULE", - "SKBASE_CLASSES_BY_MODULE", "SKBASE_PUBLIC_FUNCTIONS_BY_MODULE", - "SKBASE_FUNCTIONS_BY_MODULE", + "SKBASE_PUBLIC_MODULES", ] -__author__: List[str] = ["fkiraly", "RNKuhns"] +__author__: list[str] = ["fkiraly", "RNKuhns"] # bug 442 fixed: metaclasses now discovered correctly on all Python versions IMPORT_CLS = ("CommonMagicMeta", "MagicAttribute") @@ -332,7 +332,7 @@ class Parent(BaseObject): """Parent class to test BaseObject's usage.""" - _tags = {"A": "1", "B": 2, "C": 1234, "3": "D"} + _tags: ClassVar[dict] = {"A": "1", "B": 2, "C": 1234, "3": "D"} def __init__(self, a="something", b=7, c=None): """Initialize the class.""" @@ -343,36 +343,31 @@ def __init__(self, a="something", b=7, c=None): def some_method(self): """To be implemented by child class.""" - pass # Fixture class for testing tag system, child overrides tags class Child(Parent): """Child class that is child of FixtureClassParent.""" - _tags = {"A": 42, "3": "E"} - __author__ = ["fkiraly", "RNKuhns"] + _tags: ClassVar[dict] = {"A": 42, "3": "E"} + __author__: ClassVar[list[str]] = ["fkiraly", "RNKuhns"] def some_method(self): """Child class' implementation.""" - pass def some_other_method(self): """To be implemented in the child class.""" - pass # Fixture class for testing tag system, child overrides tags class ClassWithABTrue(Parent): """Child class that sets A, B tags to True.""" - _tags = {"A": True, "B": True} - __author__ = ["fkiraly", "RNKuhns"] + _tags: ClassVar[dict] = {"A": True, "B": True} + __author__: ClassVar[list[str]] = ["fkiraly", "RNKuhns"] def some_method(self): """Child class' implementation.""" - pass def some_other_method(self): """To be implemented in the child class.""" - pass diff --git a/skbase/tests/mock_package/__init__.py b/skbase/tests/mock_package/__init__.py index ec467ed8..cc6f4a25 100644 --- a/skbase/tests/mock_package/__init__.py +++ b/skbase/tests/mock_package/__init__.py @@ -1,6 +1,4 @@ # -*- coding: utf-8 -*- """Mock package for skbase testing.""" -from typing import List - -__author__: List[str] = ["fkiraly", "RNKuhns"] +__author__: list[str] = ["fkiraly", "RNKuhns"] diff --git a/skbase/tests/mock_package/test_mock_package.py b/skbase/tests/mock_package/test_mock_package.py index 2ecbf6b9..ebe859ef 100644 --- a/skbase/tests/mock_package/test_mock_package.py +++ b/skbase/tests/mock_package/test_mock_package.py @@ -2,17 +2,16 @@ """Mock package for testing skbase functionality.""" from copy import deepcopy -from typing import List from skbase.base import BaseObject -__all__: List[str] = [ +__all__: list[str] = [ + "AnotherClass", "CompositionDummy", "InheritsFromBaseObject", - "AnotherClass", "NotABaseObject", ] -__author__: List[str] = ["fkiraly", "RNKuhns"] +__author__: list[str] = ["fkiraly", "RNKuhns"] class CompositionDummy(BaseObject): @@ -23,7 +22,7 @@ def __init__(self, foo, bar=84): self.foo_ = deepcopy(foo) self.bar = bar - super(CompositionDummy, self).__init__() + super().__init__() @classmethod def get_test_params(cls, parameter_set="default"): diff --git a/skbase/tests/test_base.py b/skbase/tests/test_base.py index 23b4582c..0880e5cc 100644 --- a/skbase/tests/test_base.py +++ b/skbase/tests/test_base.py @@ -21,55 +21,55 @@ __author__ = ["fkiraly", "RNKuhns"] __all__ = [ - "test_get_class_tags", - "test_get_class_tag", - "test_get_tags", - "test_get_tag", - "test_get_tag_raises", - "test_set_tags", - "test_set_tags_works_with_missing_tags_dynamic_attribute", + "test_baseobject_repr", + "test_baseobject_repr_mimebundle_", + "test_baseobject_str", + "test_clone", + "test_clone_2", + "test_clone_class_rather_than_instance_raises_error", + "test_clone_estimator_types", + "test_clone_none_and_empty_array_nan_sparse_matrix", + "test_clone_raises_error_for_nonconforming_objects", + "test_clone_sklearn_composite", "test_clone_tags", - "test_is_composite", "test_components", - "test_components_raises_error_base_class_is_not_class", "test_components_raises_error_base_class_is_not_baseobject_subclass", - "test_param_alias", - "test_nested_set_params_and_alias", - "test_reset", - "test_reset_composite", + "test_components_raises_error_base_class_is_not_class", + "test_create_test_instance", + "test_create_test_instances_and_names", + "test_eq_dunder", + "test_get_class_tag", + "test_get_class_tags", "test_get_init_signature", "test_get_init_signature_raises_error_for_invalid_signature", "test_get_param_names", "test_get_params", - "test_get_params_invariance", "test_get_params_after_set_params", + "test_get_params_invariance", + "test_get_tag", + "test_get_tag_raises", + "test_get_tags", + "test_get_test_params", + "test_get_test_params_raises_error_when_params_required", + "test_has_implementation_of", + "test_is_composite", + "test_nested_set_params_and_alias", + "test_param_alias", + "test_raises_on_get_params_for_param_arg_not_assigned_to_attribute", + "test_repr_html_wraps", + "test_reset", + "test_reset_composite", "test_set_params", "test_set_params_raises_error_non_existent_param", "test_set_params_raises_error_non_interface_composite", - "test_raises_on_get_params_for_param_arg_not_assigned_to_attribute", "test_set_params_with_no_param_to_set_returns_object", - "test_clone", - "test_clone_2", - "test_clone_raises_error_for_nonconforming_objects", - "test_clone_none_and_empty_array_nan_sparse_matrix", - "test_clone_estimator_types", - "test_clone_class_rather_than_instance_raises_error", - "test_clone_sklearn_composite", - "test_baseobject_repr", - "test_baseobject_str", - "test_baseobject_repr_mimebundle_", - "test_repr_html_wraps", - "test_get_test_params", - "test_get_test_params_raises_error_when_params_required", - "test_create_test_instance", - "test_create_test_instances_and_names", - "test_has_implementation_of", - "test_eq_dunder", + "test_set_tags", + "test_set_tags_works_with_missing_tags_dynamic_attribute", ] import inspect from copy import deepcopy -from typing import Any, Dict, Type +from typing import Any, ClassVar import numpy as np import pytest @@ -112,7 +112,7 @@ def __init__(self, a, *args): class RequiredParam(BaseObject): """BaseObject class with _required_parameters.""" - _required_parameters = ["a"] + _required_parameters: ClassVar[list[str]] = ["a"] def __init__(self, a, b=7): self.a = a @@ -198,7 +198,7 @@ def fixture_reset_tester(): @pytest.fixture -def fixture_class_child_tags(fixture_class_child: Type[Child]): +def fixture_class_child_tags(fixture_class_child: type[Child]): """Pytest fixture for tags of Child.""" return fixture_class_child.get_class_tags() @@ -265,7 +265,7 @@ def fixture_class_instance_no_param_interface(): def test_get_class_tags( - fixture_class_child: Type[Child], fixture_class_child_tags: Any + fixture_class_child: type[Child], fixture_class_child_tags: Any ): """Test get_class_tags class method of BaseObject for correctness. @@ -280,7 +280,7 @@ def test_get_class_tags( assert child_tags == fixture_class_child_tags, msg -def test_get_class_tag(fixture_class_child: Type[Child], fixture_class_child_tags: Any): +def test_get_class_tag(fixture_class_child: type[Child], fixture_class_child_tags: Any): """Test get_class_tag class method of BaseObject for correctness. Raises @@ -307,7 +307,7 @@ def test_get_class_tag(fixture_class_child: Type[Child], fixture_class_child_tag assert child_tag_default_none is None, msg -def test_get_tags(fixture_tag_class_object: Child, fixture_object_tags: Dict[str, Any]): +def test_get_tags(fixture_tag_class_object: Child, fixture_object_tags: dict[str, Any]): """Test get_tags method of BaseObject for correctness. Raises @@ -321,7 +321,7 @@ def test_get_tags(fixture_tag_class_object: Child, fixture_object_tags: Dict[str assert object_tags == fixture_object_tags, msg -def test_get_tag(fixture_tag_class_object: Child, fixture_object_tags: Dict[str, Any]): +def test_get_tag(fixture_tag_class_object: Child, fixture_object_tags: dict[str, Any]): """Test get_tag method of BaseObject for correctness. Raises @@ -364,8 +364,8 @@ def test_get_tag_raises(fixture_tag_class_object: Child): def test_set_tags( fixture_object_instance_set_tags: Any, - fixture_object_set_tags: Dict[str, Any], - fixture_object_dynamic_tags: Dict[str, int], + fixture_object_set_tags: dict[str, Any], + fixture_object_dynamic_tags: dict[str, int], ): """Test set_tags method of BaseObject for correctness. @@ -387,7 +387,7 @@ def test_set_tags_works_with_missing_tags_dynamic_attribute( """Test set_tags will still work if _tags_dynamic is missing.""" base_obj = deepcopy(fixture_tag_class_object) attr_name = "_tags_dynamic" - delattr(base_obj, attr_name) # noqa + delattr(base_obj, attr_name) assert not hasattr(base_obj, "_tags_dynamic") base_obj.set_tags(some_tag="something") tags = base_obj.get_tags() @@ -399,7 +399,7 @@ def test_clone_tags(): """Test clone_tags works as expected.""" class TestClass(BaseObject): - _tags = {"some_tag": True, "another_tag": 37} + _tags: ClassVar[dict] = {"some_tag": True, "another_tag": 37} class AnotherTestClass(BaseObject): pass @@ -463,7 +463,7 @@ class AnotherTestClass(BaseObject): assert test_obj_tags.get(tag) == another_base_obj_tags[tag] -def test_is_composite(fixture_composition_dummy: Type[CompositionDummy]): +def test_is_composite(fixture_composition_dummy: type[CompositionDummy]): """Test is_composite tag for correctness. Raises @@ -478,9 +478,9 @@ def test_is_composite(fixture_composition_dummy: Type[CompositionDummy]): def test_components( - fixture_object: Type[BaseObject], - fixture_class_parent: Type[Parent], - fixture_composition_dummy: Type[CompositionDummy], + fixture_object: type[BaseObject], + fixture_class_parent: type[Parent], + fixture_composition_dummy: type[CompositionDummy], ): """Test component retrieval. @@ -514,7 +514,7 @@ def test_components( def test_components_raises_error_base_class_is_not_class( - fixture_object: Type[BaseObject], fixture_composition_dummy: Type[CompositionDummy] + fixture_object: type[BaseObject], fixture_composition_dummy: type[CompositionDummy] ): """Test _component method raises error if base_class param is not class.""" non_composite = fixture_composition_dummy(foo=42) @@ -533,7 +533,7 @@ def test_components_raises_error_base_class_is_not_class( def test_components_raises_error_base_class_is_not_baseobject_subclass( - fixture_composition_dummy: Type[CompositionDummy], + fixture_composition_dummy: type[CompositionDummy], ): """Test _component method raises error if base_class is not BaseObject subclass.""" @@ -563,18 +563,18 @@ def test_param_alias(): composite = CompositionDummy(foo=non_composite) # this should write to a of foo, because there is only one suffix called a - composite.set_params(**{"a": 424242}) + composite.set_params(a=424242) assert composite.get_params()["foo__a"] == 424242 # this should write to bar of composite, because "bar" is a full parameter string # there is a suffix in foo, but if the full string is there, it writes to that - composite.set_params(**{"bar": 424243}) + composite.set_params(bar=424243) assert composite.get_params()["bar"] == 424243 # trying to write to bad_param should raise an exception # since bad_param is neither a suffix nor a full parameter string with pytest.raises(ValueError, match=r"Invalid parameter keys provided to"): - composite.set_params(**{"bad_param": 424242}) + composite.set_params(bad_param=424242) # new example: highly nested composite with identical suffixes non_composite1 = composite @@ -584,10 +584,10 @@ def test_param_alias(): # trying to write to a should raise an exception # since there are two suffix a, and a is not a full parameter string with pytest.raises(ValueError, match=r"does not uniquely determine parameter key"): - uber_composite.set_params(**{"a": 424242}) + uber_composite.set_params(a=424242) # same as above, should overwrite "bar" of uber_composite - uber_composite.set_params(**{"bar": 424243}) + uber_composite.set_params(bar=424243) assert uber_composite.get_params()["bar"] == 424243 @@ -611,14 +611,14 @@ def test_nested_set_params_and_alias(): # this should write to a of foo # potential error here is that composite does not have foo__a to start with # so error catching or writing foo__a to early could cause an exception - composite.set_params(**{"foo": non_composite, "foo__a": 424242}) + composite.set_params(foo=non_composite, foo__a=424242) assert composite.get_params()["foo__a"] == 424242 non_composite = AliasTester(a=42, bar=4242) composite = CompositionDummy(foo=0) # same, and recognizing that foo__a is the only matching suffix in the end state - composite.set_params(**{"foo": non_composite, "a": 424242}) + composite.set_params(foo=non_composite, a=424242) assert composite.get_params()["foo__a"] == 424242 # new example: highly nested composite with identical suffixes @@ -629,20 +629,18 @@ def test_nested_set_params_and_alias(): # trying to write to a should raise an exception # since there are two suffix a, and a is not a full parameter string with pytest.raises(ValueError, match=r"does not uniquely determine parameter key"): - uber_composite.set_params( - **{"a": 424242, "foo": non_composite1, "bar": non_composite2} - ) + uber_composite.set_params(a=424242, foo=non_composite1, bar=non_composite2) uber_composite = CompositionDummy(foo=non_composite1, bar=42) # same as above, should overwrite "bar" of uber_composite - uber_composite.set_params(**{"bar": 424243}) + uber_composite.set_params(bar=424243) assert uber_composite.get_params()["bar"] == 424243 # Test parameter interface (get_params, set_params, reset and related methods) # Some tests of get_params and set_params are adapted from sklearn tests -def test_reset(fixture_reset_tester: Type[ResetTester]): +def test_reset(fixture_reset_tester: type[ResetTester]): """Test reset method for correct behaviour, on a simple estimator. Raises @@ -669,7 +667,7 @@ def test_reset(fixture_reset_tester: Type[ResetTester]): assert hasattr(x, "foo") -def test_reset_composite(fixture_reset_tester: Type[ResetTester]): +def test_reset_composite(fixture_reset_tester: type[ResetTester]): """Test reset method for correct behaviour, on a composite estimator.""" y = fixture_reset_tester(42) x = fixture_reset_tester(a=y) @@ -684,20 +682,20 @@ def test_reset_composite(fixture_reset_tester: Type[ResetTester]): assert not hasattr(x.a, "d") -def test_get_init_signature(fixture_class_parent: Type[Parent]): +def test_get_init_signature(fixture_class_parent: type[Parent]): """Test error is raised when invalid init signature is used.""" init_sig = fixture_class_parent._get_init_signature() init_sig_is_list = isinstance(init_sig, list) init_sig_elements_are_params = all( isinstance(p, inspect.Parameter) for p in init_sig ) - assert ( - init_sig_is_list and init_sig_elements_are_params - ), "`_get_init_signature` is not returning expected result." + assert init_sig_is_list and init_sig_elements_are_params, ( + "`_get_init_signature` is not returning expected result." + ) def test_get_init_signature_raises_error_for_invalid_signature( - fixture_invalid_init: Type[InvalidInitSignatureTester], + fixture_invalid_init: type[InvalidInitSignatureTester], ): """Test error is raised when invalid init signature is used.""" with pytest.raises(RuntimeError): @@ -706,9 +704,9 @@ def test_get_init_signature_raises_error_for_invalid_signature( @pytest.mark.parametrize("sort", [True, False]) def test_get_param_names( - fixture_object: Type[BaseObject], - fixture_class_parent: Type[Parent], - fixture_class_parent_expected_params: Dict[str, Any], + fixture_object: type[BaseObject], + fixture_class_parent: type[Parent], + fixture_class_parent_expected_params: dict[str, Any], sort: bool, ): """Test that get_param_names returns list of string parameter names.""" @@ -723,10 +721,10 @@ def test_get_param_names( def test_get_params( - fixture_class_parent: Type[Parent], - fixture_class_parent_expected_params: Dict[str, Any], + fixture_class_parent: type[Parent], + fixture_class_parent_expected_params: dict[str, Any], fixture_class_instance_no_param_interface: NoParamInterface, - fixture_composition_dummy: Type[CompositionDummy], + fixture_composition_dummy: type[CompositionDummy], ): """Test get_params returns expected parameters.""" # Simple test of returned params @@ -750,8 +748,8 @@ def test_get_params( def test_get_params_invariance( - fixture_class_parent: Type[Parent], - fixture_composition_dummy: Type[CompositionDummy], + fixture_class_parent: type[Parent], + fixture_composition_dummy: type[CompositionDummy], ): """Test that get_params(deep=False) is subset of get_params(deep=True).""" composite = fixture_composition_dummy(foo=fixture_class_parent(), bar=84) @@ -760,7 +758,7 @@ def test_get_params_invariance( assert all(item in deep_params.items() for item in shallow_params.items()) -def test_get_params_after_set_params(fixture_class_parent: Type[Parent]): +def test_get_params_after_set_params(fixture_class_parent: type[Parent]): """Test that get_params returns the same thing before and after set_params. Based on scikit-learn check in check_estimator. @@ -780,7 +778,7 @@ def test_get_params_after_set_params(fixture_class_parent: Type[Parent]): test_values = [-np.inf, np.inf, None] test_params = deepcopy(orig_params) - for param_name in orig_params.keys(): + for param_name in orig_params: default_value = orig_params[param_name] for value in test_values: test_params[param_name] = value @@ -801,9 +799,9 @@ def test_get_params_after_set_params(fixture_class_parent: Type[Parent]): def test_set_params( - fixture_class_parent: Type[Parent], - fixture_class_parent_expected_params: Dict[str, Any], - fixture_composition_dummy: Type[CompositionDummy], + fixture_class_parent: type[Parent], + fixture_class_parent_expected_params: dict[str, Any], + fixture_composition_dummy: type[CompositionDummy], ): """Test set_params works as expected.""" # Simple case of setting a parameter @@ -826,7 +824,7 @@ def test_set_params( def test_set_params_raises_error_non_existent_param( fixture_class_parent_instance: Parent, - fixture_composition_dummy: Type[CompositionDummy], + fixture_composition_dummy: type[CompositionDummy], ): """Test set_params raises an error when passed a non-existent parameter name.""" # non-existing parameter in svc @@ -843,7 +841,7 @@ def test_set_params_raises_error_non_existent_param( def test_set_params_raises_error_non_interface_composite( fixture_class_instance_no_param_interface: NoParamInterface, - fixture_composition_dummy: Type[CompositionDummy], + fixture_composition_dummy: type[CompositionDummy], ): """Test set_params raises error when setting param of non-conforming composite.""" # When a composite is made up of a class that doesn't have the BaseObject @@ -870,7 +868,7 @@ def __init__(self, param=5): def test_set_params_with_no_param_to_set_returns_object( - fixture_class_parent: Type[Parent], + fixture_class_parent: type[Parent], ): """Test set_params correctly returns self when no parameters are set.""" base_obj = fixture_class_parent() @@ -908,19 +906,19 @@ def test_clone_2(fixture_class_parent_instance: Parent): def test_clone_raises_error_for_nonconforming_objects( - fixture_invalid_init: Type[InvalidInitSignatureTester], - fixture_buggy: Type[Buggy], - fixture_modify_param: Type[ModifyParam], + fixture_invalid_init: type[InvalidInitSignatureTester], + fixture_buggy: type[Buggy], + fixture_modify_param: type[ModifyParam], ): """Test that clone raises an error on nonconforming BaseObjects.""" buggy = fixture_buggy() - buggy.set_config(**{"check_clone": True}) + buggy.set_config(check_clone=True) buggy.a = 2 with pytest.raises(RuntimeError): buggy.clone() varg_obj = fixture_invalid_init(a=7) - varg_obj.set_config(**{"check_clone": True}) + varg_obj.set_config(check_clone=True) with pytest.raises(RuntimeError): varg_obj.clone() @@ -939,17 +937,17 @@ def test_config_after_clone_tags(clone_config): """Test clone also clones config works as expected.""" class TestClass(BaseObject): - _tags = {"some_tag": True, "another_tag": 37} - _config = {"check_clone": 0} + _tags: ClassVar[dict] = {"some_tag": True, "another_tag": 37} + _config: ClassVar[dict] = {"check_clone": 0} test_obj = TestClass() - test_obj.set_config(**{"check_clone": 42, "foo": "bar"}) + test_obj.set_config(check_clone=42, foo="bar") if not clone_config: # if clone_config config is set to False: # config key check_clone should be default, 0 # the new config key foo should not be present - test_obj.set_config(**{"clone_config": False}) + test_obj.set_config(clone_config=False) expected = 0 else: # if clone_config config is set to True: @@ -959,14 +957,14 @@ class TestClass(BaseObject): test_obj_clone = test_obj.clone() - assert "check_clone" in test_obj_clone.get_config().keys() + assert "check_clone" in test_obj_clone.get_config() assert test_obj_clone.get_config()["check_clone"] == expected if clone_config: - assert "foo" in test_obj_clone.get_config().keys() + assert "foo" in test_obj_clone.get_config() assert test_obj_clone.get_config()["foo"] == "bar" else: - assert "foo" not in test_obj_clone.get_config().keys() + assert "foo" not in test_obj_clone.get_config() @pytest.mark.parametrize("clone_config", [True, False]) @@ -974,7 +972,7 @@ def test_nested_config_after_clone_tags(clone_config): """Test clone also clones config of nested objects.""" class TestClass(BaseObject): - _config = {"check_clone": 0} + _config: ClassVar[dict] = {"check_clone": 0} class TestNestedClass(BaseObject): def __init__(self, obj, obj_iterable): @@ -982,16 +980,16 @@ def __init__(self, obj, obj_iterable): self.obj_iterable = obj_iterable test_obj = TestNestedClass( - obj=TestClass().set_config(**{"check_clone": 1, "foo": "bar"}), - obj_iterable=[TestClass().set_config(**{"check_clone": 2, "foo": "barz"})], + obj=TestClass().set_config(check_clone=1, foo="bar"), + obj_iterable=[TestClass().set_config(check_clone=2, foo="barz")], ) if not clone_config: # if clone_config config is set to False: # config key check_clone should be default, 0 # the new config key foo should not be present - test_obj.obj.set_config(**{"clone_config": False}) - test_obj.obj_iterable[0].set_config(**{"clone_config": False}) + test_obj.obj.set_config(clone_config=False) + test_obj.obj_iterable[0].set_config(clone_config=False) expected_obj = 0 expected_obj_iterable = 0 else: @@ -1004,19 +1002,19 @@ def __init__(self, obj, obj_iterable): test_obj_clone = test_obj.clone().obj test_obj_iterable_clone = test_obj.clone().obj_iterable[0] - assert "check_clone" in test_obj_clone.get_config().keys() - assert "check_clone" in test_obj_iterable_clone.get_config().keys() + assert "check_clone" in test_obj_clone.get_config() + assert "check_clone" in test_obj_iterable_clone.get_config() assert test_obj_clone.get_config()["check_clone"] == expected_obj assert test_obj_iterable_clone.get_config()["check_clone"] == expected_obj_iterable if clone_config: - assert "foo" in test_obj_clone.get_config().keys() - assert "foo" in test_obj_iterable_clone.get_config().keys() + assert "foo" in test_obj_clone.get_config() + assert "foo" in test_obj_iterable_clone.get_config() assert test_obj_clone.get_config()["foo"] == "bar" assert test_obj_iterable_clone.get_config()["foo"] == "barz" else: - assert "foo" not in test_obj_clone.get_config().keys() - assert "foo" not in test_obj_iterable_clone.get_config().keys() + assert "foo" not in test_obj_clone.get_config() + assert "foo" not in test_obj_iterable_clone.get_config() @pytest.mark.skipif( @@ -1033,7 +1031,7 @@ def __init__(self, obj, obj_iterable): ], ) def test_clone_none_and_empty_array_nan_sparse_matrix( - fixture_class_parent: Type[Parent], c_value + fixture_class_parent: type[Parent], c_value ): from sklearn.base import clone @@ -1052,7 +1050,7 @@ def test_clone_none_and_empty_array_nan_sparse_matrix( assert base_obj.c is new_base_obj2.c -def test_clone_estimator_types(fixture_class_parent: Type[Parent]): +def test_clone_estimator_types(fixture_class_parent: type[Parent]): """Test clone works for parameters that are types rather than instances.""" base_obj = fixture_class_parent(c=fixture_class_parent) new_base_obj = base_obj.clone() @@ -1065,7 +1063,7 @@ def test_clone_estimator_types(fixture_class_parent: Type[Parent]): reason="skip test if sklearn is not available", ) # sklearn is part of the dev dependency set, test should be executed with that def test_clone_class_rather_than_instance_raises_error( - fixture_class_parent: Type[Parent], + fixture_class_parent: type[Parent], ): """Test clone raises expected error when cloning a class instead of an instance.""" from sklearn.base import clone @@ -1109,8 +1107,8 @@ def test_clone_sklearn_composite_retains_config(): # Tests of BaseObject pretty printing representation inspired by sklearn def test_baseobject_repr( - fixture_class_parent: Type[Parent], - fixture_composition_dummy: Type[CompositionDummy], + fixture_class_parent: type[Parent], + fixture_composition_dummy: type[CompositionDummy], ): """Test BaseObject repr works as expected.""" # Simple test where all parameters are left at defaults @@ -1149,16 +1147,16 @@ def test_baseobject_repr( long_base_obj_repr = fixture_class_parent(a=["long_params"] * 1000) assert len(repr(long_base_obj_repr)) == 535 - named_objs = [(f"Step {i+1}", Child()) for i in range(25)] + named_objs = [(f"Step {i + 1}", Child()) for i in range(25)] base_comp = CompositionDummy(foo=Parent(c=Child(c=named_objs))) assert len(repr(base_comp)) == 1362 def test_baseobject_str(fixture_class_parent_instance: Parent): """Test BaseObject string representation works.""" - assert ( - str(fixture_class_parent_instance) == "Parent()" - ), "String representation of instance not working." + assert str(fixture_class_parent_instance) == "Parent()", ( + "String representation of instance not working." + ) # Check that local config works as expected fixture_class_parent_instance.set_config(print_changed_only=False) @@ -1200,7 +1198,7 @@ def test_get_test_params(fixture_class_parent_instance: Parent): def test_get_test_params_raises_error_when_params_required( - fixture_required_param: Type[RequiredParam], + fixture_required_param: type[RequiredParam], ): """Test get_test_params raises an error when parameters are required.""" with pytest.raises(ValueError): @@ -1208,7 +1206,7 @@ def test_get_test_params_raises_error_when_params_required( def test_create_test_instance( - fixture_class_parent: Type[Parent], fixture_class_parent_instance: Parent + fixture_class_parent: type[Parent], fixture_class_parent_instance: Parent ): """Test first that create_test_instance logic works.""" base_obj = fixture_class_parent.create_test_instance() @@ -1275,9 +1273,9 @@ def test_has_implementation_of( class ConfigTester(BaseObject): - _config = {"foo_config": 42, "bar": "a"} + _config: ClassVar[dict] = {"foo_config": 42, "bar": "a"} - clsvar = 210 + clsvar: ClassVar[int] = 210 def __init__(self, a, b=42): self.a = a @@ -1286,9 +1284,9 @@ def __init__(self, a, b=42): class AnotherConfigTester(BaseObject): - _config = {"print_changed_only": False, "bar": "a"} + _config: ClassVar[dict] = {"print_changed_only": False, "bar": "a"} - clsvar = 210 + clsvar: ClassVar[int] = 210 def __init__(self, a, b=42): self.a = a @@ -1382,9 +1380,9 @@ def test_get_set_config(): """Tests get_config and set_config methods.""" class _TestConfig(BaseObject): - _config = {"foo_config": 42, "bar": "a"} + _config: ClassVar[dict] = {"foo_config": 42, "bar": "a"} - clsvar = 210 + clsvar: ClassVar[int] = 210 def __init__(self, a, b=42): self.a = a diff --git a/skbase/tests/test_baseestimator.py b/skbase/tests/test_baseestimator.py index 606b71bc..8851978b 100644 --- a/skbase/tests/test_baseestimator.py +++ b/skbase/tests/test_baseestimator.py @@ -49,18 +49,18 @@ def test_has_is_fitted(fixture_estimator_instance): """Test BaseEstimator has `is_fitted` property.""" has_private_is_fitted = hasattr(fixture_estimator_instance, "_is_fitted") has_is_fitted = hasattr(fixture_estimator_instance, "is_fitted") - assert ( - has_private_is_fitted and has_is_fitted - ), "BaseEstimator does not have `is_fitted` property;" + assert has_private_is_fitted and has_is_fitted, ( + "BaseEstimator does not have `is_fitted` property;" + ) def test_has_check_is_fitted(fixture_estimator_instance): """Test BaseEstimator has `check_is_fitted` method.""" has_check_is_fitted = hasattr(fixture_estimator_instance, "check_is_fitted") is_method = inspect.ismethod(fixture_estimator_instance.check_is_fitted) - assert ( - has_check_is_fitted and is_method - ), "`BaseEstimator` does not have `check_is_fitted` method." + assert has_check_is_fitted and is_method, ( + "`BaseEstimator` does not have `check_is_fitted` method." + ) def test_is_fitted(fixture_estimator_instance): @@ -68,9 +68,9 @@ def test_is_fitted(fixture_estimator_instance): expected_value_unfitted = ( fixture_estimator_instance.is_fitted == fixture_estimator_instance._is_fitted ) - assert ( - expected_value_unfitted - ), "`BaseEstimator` property `is_fitted` does not return `_is_fitted` value." + assert expected_value_unfitted, ( + "`BaseEstimator` property `is_fitted` does not return `_is_fitted` value." + ) def test_check_is_fitted_raises_error_when_unfitted(fixture_estimator_instance): diff --git a/skbase/tests/test_exceptions.py b/skbase/tests/test_exceptions.py index fb83a509..db706a1c 100644 --- a/skbase/tests/test_exceptions.py +++ b/skbase/tests/test_exceptions.py @@ -6,13 +6,11 @@ test_exceptions_raise_error - Test that skbase exceptions raise expected error. """ -from typing import List - import pytest from skbase._exceptions import FixtureGenerationError, NotFittedError -__author__: List[str] = ["RNKuhns"] +__author__: list[str] = ["RNKuhns"] ALL_EXCEPTIONS = [FixtureGenerationError, NotFittedError] diff --git a/skbase/tests/test_lookup.py b/skbase/tests/test_lookup.py index bb985b10..c3de3a35 100644 --- a/skbase/tests/test_lookup.py +++ b/skbase/tests/test_lookup.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- """Tests for skbase.lookup utilities.""" import importlib @@ -14,7 +15,7 @@ def test_all_objects_returns_class_name_for_alias(tmp_path, monkeypatch): # create a tmp module to test all_objects behaviour (root / "__init__.py").write_text( - "from .module import AliasName\n" "__all__ = ['AliasName']\n" + "from .module import AliasName\n__all__ = ['AliasName']\n" ) (root / "module.py").write_text( "from skbase.base import BaseObject\n\n" diff --git a/skbase/tests/test_lookup_metaclasses.py b/skbase/tests/test_lookup_metaclasses.py index 8608de48..9d4adb39 100644 --- a/skbase/tests/test_lookup_metaclasses.py +++ b/skbase/tests/test_lookup_metaclasses.py @@ -3,13 +3,12 @@ import importlib import inspect -from typing import List from skbase.lookup import get_package_metadata from skbase.lookup._lookup import _get_members_uw from skbase.utils.dependencies._import import CommonMagicMeta, MagicAttribute -__author__: List[str] = ["SimonBlanke"] +__author__: list[str] = ["SimonBlanke"] def test_get_members_uw_discovers_metaclass_classes(): diff --git a/skbase/tests/test_meta.py b/skbase/tests/test_meta.py index e5d1a631..19934807 100644 --- a/skbase/tests/test_meta.py +++ b/skbase/tests/test_meta.py @@ -100,18 +100,18 @@ def test_basemetaestimator_has_is_fitted(fixture_metaestimator_instance): """Test BaseEstimator has `is_fitted` property.""" has_private_is_fitted = hasattr(fixture_metaestimator_instance, "_is_fitted") has_is_fitted = hasattr(fixture_metaestimator_instance, "is_fitted") - assert ( - has_private_is_fitted and has_is_fitted - ), "`BaseMetaEstimator` does not have `is_fitted` property or `_is_fitted` attr." + assert has_private_is_fitted and has_is_fitted, ( + "`BaseMetaEstimator` does not have `is_fitted` property or `_is_fitted` attr." + ) def test_basemetaestimator_has_check_is_fitted(fixture_metaestimator_instance): """Test BaseEstimator has `check_is_fitted` method.""" has_check_is_fitted = hasattr(fixture_metaestimator_instance, "check_is_fitted") is_method = inspect.ismethod(fixture_metaestimator_instance.check_is_fitted) - assert ( - has_check_is_fitted and is_method - ), "`BaseMetaEstimator` does not have `check_is_fitted` method." + assert has_check_is_fitted and is_method, ( + "`BaseMetaEstimator` does not have `check_is_fitted` method." + ) @pytest.mark.parametrize("is_fitted_value", (True, False)) @@ -122,9 +122,9 @@ def test_basemetaestimator_is_fitted(fixture_metaestimator_instance, is_fitted_v fixture_metaestimator_instance.is_fitted == fixture_metaestimator_instance._is_fitted ) - assert ( - expected_value_unfitted - ), "`BaseMetaEstimator` property `is_fitted` does not return `_is_fitted` value." + assert expected_value_unfitted, ( + "`BaseMetaEstimator` property `is_fitted` does not return `_is_fitted` value." + ) def test_basemetaestimator_check_is_fitted_raises_error_when_unfitted( @@ -189,12 +189,12 @@ def test_set_params_resets_fitted_state(): meta_obj.set_params(steps=new_steps) # Fitted state should be gone after set_params - assert not hasattr( - meta_obj, "fitted_attr_" - ), "fitted_attr_ should be removed by reset() during set_params(steps=...)" - assert not hasattr( - meta_obj, "another_fitted_" - ), "another_fitted_ should be removed by reset() during set_params(steps=...)" + assert not hasattr(meta_obj, "fitted_attr_"), ( + "fitted_attr_ should be removed by reset() during set_params(steps=...)" + ) + assert not hasattr(meta_obj, "another_fitted_"), ( + "another_fitted_ should be removed by reset() during set_params(steps=...)" + ) # Test 2: Replacing individual step should also trigger reset meta_obj = MetaObjectTester(steps=steps) @@ -202,6 +202,6 @@ def test_set_params_resets_fitted_state(): meta_obj.set_params(foo=ComponentDummy(77)) - assert not hasattr( - meta_obj, "fitted_attr_" - ), "fitted_attr_ should be removed by reset() during set_params(foo=...)" + assert not hasattr(meta_obj, "fitted_attr_"), ( + "fitted_attr_ should be removed by reset() during set_params(foo=...)" + ) diff --git a/skbase/tests/test_tagaliaser.py b/skbase/tests/test_tagaliaser.py index ce3378a2..53145c03 100644 --- a/skbase/tests/test_tagaliaser.py +++ b/skbase/tests/test_tagaliaser.py @@ -2,6 +2,7 @@ """Tests the aliasing logic in the Tag Aliaser.""" import re +from typing import ClassVar import pytest @@ -12,19 +13,19 @@ class AliaserTestClass(_TagAliaserMixin, _BaseObject): """Class for testing tag aliasing logic.""" - _tags = { + _tags: ClassVar[dict] = { "new_tag_1": "new_tag_1_value", "old_tag_1": "old_tag_1_value", "new_tag_2": "new_tag_2_value", "old_tag_3": "old_tag_3_value", } - alias_dict = { + alias_dict: ClassVar[dict] = { "old_tag_1": "new_tag_1", "old_tag_2": "new_tag_2", "old_tag_3": "new_tag_3", } - deprecate_dict = { + deprecate_dict: ClassVar[dict] = { "old_tag_1": "42.0.0", "old_tag_2": "84.0.0", "old_tag_3": "126.0.0", diff --git a/skbase/utils/__init__.py b/skbase/utils/__init__.py index 529f53a5..5788deec 100644 --- a/skbase/utils/__init__.py +++ b/skbase/utils/__init__.py @@ -3,8 +3,6 @@ # copyright: skbase developers, BSD-3-Clause License (see LICENSE file) """Utility functionality used through `skbase`.""" -from typing import List - from skbase.utils._iter import make_strings_unique from skbase.utils._nested_iter import flatten, is_flat, unflat_len, unflatten from skbase.utils._utils import subset_dict_keys @@ -16,8 +14,8 @@ set_random_state, ) -__author__: List[str] = ["RNKuhns", "fkiraly"] -__all__: List[str] = [ +__author__: list[str] = ["RNKuhns", "fkiraly"] +__all__: list[str] = [ "check_random_state", "deep_equals", "flatten", diff --git a/skbase/utils/_iter.py b/skbase/utils/_iter.py index 5d31f459..23704080 100644 --- a/skbase/utils/_iter.py +++ b/skbase/utils/_iter.py @@ -11,9 +11,9 @@ __author__ = ["fkiraly", "RNKuhns"] __all__ = [ - "_scalar_to_seq", - "_remove_type_text", "_format_seq_to_str", + "_remove_type_text", + "_scalar_to_seq", "make_strings_unique", ] @@ -61,15 +61,12 @@ def _scalar_to_seq(scalar, sequence_type=None): # We'll treat str like regular scalar and not a sequence if isinstance(scalar, Sequence) and not isinstance(scalar, str): return scalar - elif sequence_type is None: + if sequence_type is None: return (scalar,) - elif issubclass(sequence_type, Sequence) and sequence_type != Sequence: + if issubclass(sequence_type, Sequence) and sequence_type != Sequence: # Note calling (scalar,) is done to avoid str unpacking return sequence_type((scalar,)) # type: ignore - else: - raise ValueError( - "`sequence_type` must be a subclass of collections.abc.Sequence." - ) + raise ValueError("`sequence_type` must be a subclass of collections.abc.Sequence.") def _remove_type_text(input_): @@ -80,8 +77,7 @@ def _remove_type_text(input_): m = re.match("^$", input_) if m: return m[1] - else: - return input_ + return input_ def _format_seq_to_str(seq, sep=", ", last_sep=None, remove_type_text=True): @@ -134,9 +130,9 @@ def _format_seq_to_str(seq, sep=", ", last_sep=None, remove_type_text=True): if isinstance(seq, str): return seq # Allow casting of scalars to strings - elif isinstance(seq, (int, float, bool, type)): + if isinstance(seq, (int, float, bool, type)): return _remove_type_text(seq) - elif not isinstance(seq, Sequence): + if not isinstance(seq, Sequence): raise TypeError( "`seq` must be a sequence or scalar str, int, float, bool or class." ) diff --git a/skbase/utils/_nested_iter.py b/skbase/utils/_nested_iter.py index 055e539f..7fb94c6d 100644 --- a/skbase/utils/_nested_iter.py +++ b/skbase/utils/_nested_iter.py @@ -4,13 +4,12 @@ """Functionality for working with nested sequences.""" import collections -from typing import List -__author__: List[str] = ["RNKuhns", "fkiraly"] -__all__: List[str] = [ +__author__: list[str] = ["RNKuhns", "fkiraly"] +__all__: list[str] = [ + "_remove_single", "flatten", "is_flat", - "_remove_single", "unflat_len", "unflatten", ] @@ -42,8 +41,7 @@ def _remove_single(x): """ if len(x) == 1: return x[0] - else: - return x + return x def flatten(obj): @@ -73,8 +71,7 @@ def flatten(obj): obj, (collections.abc.Iterable, collections.abc.Sequence) ) or isinstance(obj, str): return [obj] - else: - return type(obj)([y for x in obj for y in flatten(x)]) + return type(obj)([y for x in obj for y in flatten(x)]) def unflatten(obj, template): @@ -109,7 +106,7 @@ def unflatten(obj, template): ls = [unflat_len(x) for x in template] for i in range(1, len(ls)): ls[i] += ls[i - 1] - ls = [0] + ls + ls = [0, *ls] res = [unflatten(obj[ls[i] : ls[i + 1]], template[i]) for i in range(len(ls) - 1)] @@ -146,8 +143,7 @@ def unflat_len(obj): obj, (collections.abc.Iterable, collections.abc.Sequence) ) or isinstance(obj, str): return 1 - else: - return sum([unflat_len(x) for x in obj]) + return sum([unflat_len(x) for x in obj]) def is_flat(obj): diff --git a/skbase/utils/_utils.py b/skbase/utils/_utils.py index ba4eca26..9b402bd9 100644 --- a/skbase/utils/_utils.py +++ b/skbase/utils/_utils.py @@ -3,16 +3,17 @@ # copyright: skbase developers, BSD-3-Clause License (see LICENSE file) """Functionality for working with sequences.""" -from typing import Any, Iterable, List, MutableMapping, Optional, Union +from collections.abc import Iterable, MutableMapping +from typing import Any -__author__: List[str] = ["RNKuhns"] -__all__: List[str] = ["subset_dict_keys"] +__author__: list[str] = ["RNKuhns"] +__all__: list[str] = ["subset_dict_keys"] def subset_dict_keys( input_dict: MutableMapping[Any, Any], - keys: Union[Iterable, int, float, bool, str, type], - prefix: Optional[str] = None, + keys: Iterable | int | float | bool | str | type, + prefix: str | None = None, remove_prefix: bool = True, ): """Subset dictionary so it only contains specified keys. @@ -76,17 +77,13 @@ def rem_prefix(x): return x[len(prefix__) :] # The way this is used below, this else shouldn't really execute # But its here for completeness in case something goes wrong - else: - return x # pragma: no cover + return x # pragma: no cover # Handle passage of certain scalar values if isinstance(keys, (str, float, int, bool, type)): keys = [keys] - if prefix is not None: - keys = [f"{prefix}__{key}" for key in keys] - else: - keys = list(keys) + keys = [f"{prefix}__{key}" for key in keys] if prefix is not None else list(keys) subsetted_dict = {rem_prefix(k): v for k, v in input_dict.items() if k in keys} return subsetted_dict diff --git a/skbase/utils/deep_equals/_common.py b/skbase/utils/deep_equals/_common.py index bc486f23..ba6d1551 100644 --- a/skbase/utils/deep_equals/_common.py +++ b/skbase/utils/deep_equals/_common.py @@ -1,8 +1,10 @@ # -*- coding: utf-8 -*- """Common utility functions for skbase.utils.deep_equals.""" +from typing import Any -def _ret(is_equal, msg="", string_arguments: list = None, return_msg=False): + +def _ret(is_equal, msg="", string_arguments: list[Any] | None = None, return_msg=False): """Return is_equal and msg, formatted with string_arguments if return_msg=True. Parameters @@ -28,8 +30,7 @@ def _ret(is_equal, msg="", string_arguments: list = None, return_msg=False): elif isinstance(string_arguments, (list, tuple)) and len(string_arguments) > 0: msg = msg.format(*string_arguments) return is_equal, msg - else: - return is_equal + return is_equal def _make_ret(return_msg): diff --git a/skbase/utils/deep_equals/_deep_equals.py b/skbase/utils/deep_equals/_deep_equals.py index cf73a5a7..8b98c922 100644 --- a/skbase/utils/deep_equals/_deep_equals.py +++ b/skbase/utils/deep_equals/_deep_equals.py @@ -112,8 +112,7 @@ def _is_npnan(x): return isinstance(x, float) and np.isnan(x) - else: - return False + return False def _coerce_list(x): @@ -131,8 +130,7 @@ def _numpy_equals_plugin(x, y, return_msg=False, deep_equals=None): if not numpy_available or not _is_npndarray(x): return None - else: - import numpy as np + import numpy as np ret = _make_ret(return_msg) @@ -144,15 +142,14 @@ def _numpy_equals_plugin(x, y, return_msg=False, deep_equals=None): return ret(False, f".dtype, x.dtype = {x.dtype} != y.dtype = {y.dtype}") if x.dtype == "str": return ret(np.array_equal(x, y), ".values") - elif x.dtype == "object": + if x.dtype == "object": x_flat = x.flatten() y_flat = y.flatten() for i in range(len(x_flat)): is_equal, msg = deep_equals(x_flat[i], y_flat[i], return_msg=True) return ret(is_equal, f"[{i}]" + msg) return ret(True, "") # catches len(x_flat) == 0 - else: - return ret(np.array_equal(x, y, equal_nan=True), ".values") + return ret(np.array_equal(x, y, equal_nan=True), ".values") def _pandas_equals_plugin(x, y, return_msg=False, deep_equals=None): @@ -189,9 +186,8 @@ def _pandas_equals(x, y, return_msg=False, deep_equals=None): else: msg = "" return ret(index_equal and values_equal, msg) - else: - return ret(x.equals(y), ".series_equals, x = {} != y = {}", [x, y]) - elif isinstance(x, pd.DataFrame): + return ret(x.equals(y), ".series_equals, x = {} != y = {}", [x, y]) + if isinstance(x, pd.DataFrame): # check column names for equality if not x.columns.equals(y.columns): return ret( @@ -209,48 +205,56 @@ def _pandas_equals(x, y, return_msg=False, deep_equals=None): # and would upset the type check, e.g., RangeIndex(2) vs Index([0, 1]) xix = x.index yix = y.index - if hasattr(xix, "dtype") and hasattr(xix, "dtype"): - if not xix.dtype == yix.dtype: + if hasattr(xix, "dtype") and hasattr(xix, "dtype") and xix.dtype != yix.dtype: + return ret( + False, + ".index.dtype, x.index.dtype = {} != y.index.dtype = {}", + [xix.dtype, yix.dtype], + ) + if ( + hasattr(xix, "dtypes") + and hasattr(yix, "dtypes") + and not x.dtypes.equals(y.dtypes) + ): + return ret( + False, + ".index.dtypes, x.dtypes = {} != y.index.dtypes = {}", + [xix.dtypes, yix.dtypes], + ) + ix_eq = xix.equals(yix) + if not ix_eq: + if len(xix) != len(yix): return ret( False, - ".index.dtype, x.index.dtype = {} != y.index.dtype = {}", - [xix.dtype, yix.dtype], + ".index.len, x.index.len = {} != y.index.len = {}", + [len(xix), len(yix)], ) - if hasattr(xix, "dtypes") and hasattr(yix, "dtypes"): - if not x.dtypes.equals(y.dtypes): + if hasattr(xix, "name") and hasattr(yix, "name") and xix.name != yix.name: return ret( False, - ".index.dtypes, x.dtypes = {} != y.index.dtypes = {}", - [xix.dtypes, yix.dtypes], + ".index.name, x.index.name = {} != y.index.name = {}", + [xix.name, yix.name], ) - ix_eq = xix.equals(yix) - if not ix_eq: - if not len(xix) == len(yix): + if ( + hasattr(xix, "names") + and hasattr(yix, "names") + and len(xix.names) != len(yix.names) + ): return ret( False, - ".index.len, x.index.len = {} != y.index.len = {}", - [len(xix), len(yix)], + ".index.names, x.index.names = {} != y.index.name = {}", + [xix.names, yix.names], + ) + if ( + hasattr(xix, "names") + and hasattr(yix, "names") + and not np.all(xix.names == yix.names) + ): + return ret( + False, + ".index.names, x.index.names = {} != y.index.name = {}", + [xix.names, yix.names], ) - if hasattr(xix, "name") and hasattr(yix, "name"): - if not xix.name == yix.name: - return ret( - False, - ".index.name, x.index.name = {} != y.index.name = {}", - [xix.name, yix.name], - ) - if hasattr(xix, "names") and hasattr(yix, "names"): - if not len(xix.names) == len(yix.names): - return ret( - False, - ".index.names, x.index.names = {} != y.index.name = {}", - [xix.names, yix.names], - ) - if not np.all(xix.names == yix.names): - return ret( - False, - ".index.names, x.index.names = {} != y.index.name = {}", - [xix.names, yix.names], - ) elts_eq = np.all(xix == yix) return ret(elts_eq, ".index.equals, x = {} != y = {}", [xix, yix]) # if columns, dtypes are equal and at least one is object, recurse over Series @@ -260,24 +264,26 @@ def _pandas_equals(x, y, return_msg=False, deep_equals=None): if not is_equal: return ret(False, f"[{c!r}]" + msg) return ret(True, "") - else: - return ret(x.equals(y), ".df_equals, x = {} != y = {}", [x, y]) - elif isinstance(x, pd.Index): - if hasattr(x, "dtype") and hasattr(y, "dtype"): - if not x.dtype == y.dtype: - return ret(False, f".dtype, x.dtype = {x.dtype} != y.dtype = {y.dtype}") - if hasattr(x, "dtypes") and hasattr(y, "dtypes"): - if not x.dtypes.equals(y.dtypes): - return ret( - False, f".dtypes, x.dtypes = {x.dtypes} != y.dtypes = {y.dtypes}" - ) + return ret(x.equals(y), ".df_equals, x = {} != y = {}", [x, y]) + if isinstance(x, pd.Index): + if hasattr(x, "dtype") and hasattr(y, "dtype") and x.dtype != y.dtype: + return ret( + False, + f".dtype, x.dtype = {x.dtype} != y.dtype = {y.dtype}", + ) + if ( + hasattr(x, "dtypes") + and hasattr(y, "dtypes") + and not x.dtypes.equals(y.dtypes) + ): + msg = f".dtypes, x.dtypes = {x.dtypes} != y.dtypes = {y.dtypes}" + return ret(False, msg) return ret(x.equals(y), "index.equals, x = {} != y = {}", [x, y]) - else: - raise RuntimeError( - f"Unexpected type of pandas object in _pandas_equals: type(x)={type(x)}," - f" type(y)={type(y)}, both should be one of " - "pd.Series, pd.DataFrame, pd.Index" - ) + raise RuntimeError( + f"Unexpected type of pandas object in _pandas_equals: type(x)={type(x)}," + f" type(y)={type(y)}, both should be one of " + "pd.Series, pd.DataFrame, pd.Index" + ) def _tuple_equals(x, y, return_msg=False, deep_equals=None): @@ -522,12 +528,12 @@ def deep_equals_curried(x, y, return_msg=False): if isinstance(x, (list, tuple)): dec = deep_equals_curried return ret(*_tuple_equals(x, y, return_msg=True, deep_equals=dec)) - elif isinstance(x, dict): + if isinstance(x, dict): dec = deep_equals_curried return ret(*_dict_equals(x, y, return_msg=True, deep_equals=dec)) - elif _is_npnan(x): + if _is_npnan(x): return ret(_is_npnan(y), f"type(x)={type(x)} != type(y)={type(y)}") - elif isclass(x): + if isclass(x): return ret(x == y, f".class, x={x.__name__} != y={y.__name__}") if plugins is not None: @@ -589,8 +595,7 @@ def _safe_any_unequal(x, y): any_un = any(unequal) if isinstance(any_un, bool): return any_un - else: - return False + return False except Exception: return False @@ -600,8 +605,7 @@ def _safe_any_unequal(x, y): any_un = np.any(x != y) or np.any(_coerce_list(x != y)) if isinstance(any_un, bool) or any_un.dtype == "bool": return any_un - else: - return False + return False except Exception: return False diff --git a/skbase/utils/dependencies/__init__.py b/skbase/utils/dependencies/__init__.py index 11bafbdc..0cfc3eed 100644 --- a/skbase/utils/dependencies/__init__.py +++ b/skbase/utils/dependencies/__init__.py @@ -11,8 +11,8 @@ from skbase.utils.dependencies._import import _safe_import __all__ = [ + "_check_estimator_deps", "_check_python_version", "_check_soft_dependencies", - "_check_estimator_deps", "_safe_import", ] diff --git a/skbase/utils/dependencies/_dependencies.py b/skbase/utils/dependencies/_dependencies.py index ee29200a..6e85e8e0 100644 --- a/skbase/utils/dependencies/_dependencies.py +++ b/skbase/utils/dependencies/_dependencies.py @@ -221,8 +221,7 @@ def _is_version_req_satisfied(pkg_env_version, pkg_version_req): return False if pkg_version_req != SpecifierSet(""): return pkg_env_version in pkg_version_req - else: - return True + return True pkg_version_reqs = [] pkg_env_versions = [] @@ -275,7 +274,7 @@ def _is_version_req_satisfied(pkg_env_version, pkg_version_req): # now we check compatibility with the version specifier if non-empty if not any(req_sat): - zp = zip(package_req, pkg_names, pkg_env_versions, req_sat) + zp = zip(package_req, pkg_names, pkg_env_versions, req_sat, strict=False) reqs_not_satisfied = [x for x in zp if x[3] is False] actual_vers = [f"{x[1]} {x[2]}" for x in reqs_not_satisfied] pkg_env_version_str = ", ".join(actual_vers) @@ -465,10 +464,7 @@ def _check_python_version( return True # now we know that est_version is not compatible with sys_version - if isclass(obj): - class_name = obj.__name__ - else: - class_name = type(obj).__name__ + class_name = obj.__name__ if isclass(obj) else type(obj).__name__ if not isinstance(msg, str): msg = ( @@ -542,10 +538,7 @@ def _check_env_marker(obj, package=None, msg=None, severity="error"): return True # now we know that est_marker is not compatible with the environment - if isclass(obj): - class_name = obj.__name__ - else: - class_name = type(obj).__name__ + class_name = obj.__name__ if isclass(obj) else type(obj).__name__ if not isinstance(msg, str): msg = ( @@ -671,10 +664,7 @@ def _normalize_version(version): """ if version is None: return None - if not isinstance(version, Version): - version_obj = Version(version) - else: - version_obj = version + version_obj = Version(version) if not isinstance(version, Version) else version normalized_version = f"{version_obj.major}.{version_obj.minor}.{version_obj.micro}" return normalized_version @@ -727,13 +717,13 @@ def _raise_at_severity( if severity == "error": raise exception_type(msg) - elif severity == "warning": + if severity == "warning": warnings.warn(msg, category=warning_type, stacklevel=stacklevel) elif severity == "none": - return None + return else: raise ValueError( f"Error in calling {caller}, severity " f'argument must be "error", "warning", or "none", found {severity!r}.' ) - return None + return diff --git a/skbase/utils/dependencies/_import.py b/skbase/utils/dependencies/_import.py index 9ba41f2b..7162a449 100644 --- a/skbase/utils/dependencies/_import.py +++ b/skbase/utils/dependencies/_import.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- """Import a module/class, return a Mock object or None if import fails.""" import importlib @@ -118,19 +119,18 @@ def _safe_import(import_path, pkg_name=None, condition=True, return_object="Magi "other than ImportError or AttributeError:" f": {e}. ", ImportWarning, + stacklevel=2, ) - pass if return_object == "MagicMock": mock_obj = _create_mock_class(obj_name) return mock_obj - elif return_object == "None": + if return_object == "None": return None - else: - raise RuntimeError( - "Error in skbase _safe_import, return_object argument must be " - f"'MagicMock' or 'None', but found {return_object}" - ) + raise RuntimeError( + "Error in skbase _safe_import, return_object argument must be " + f"'MagicMock' or 'None', but found {return_object}" + ) class CommonMagicMeta(type): diff --git a/skbase/utils/dependencies/tests/test_check_dependencies.py b/skbase/utils/dependencies/tests/test_check_dependencies.py index fd0d78cf..f384bdd8 100644 --- a/skbase/utils/dependencies/tests/test_check_dependencies.py +++ b/skbase/utils/dependencies/tests/test_check_dependencies.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- """Tests for _check_soft_dependencies utility.""" +from typing import ClassVar from unittest.mock import patch import pytest @@ -57,12 +58,12 @@ def test_check_soft_deps(): def test_check_soft_dependencies_nested(): """Test check_soft_dependencies with .""" - ALWAYS_INSTALLED = "pytest" # noqa: N806 - ALWAYS_INSTALLED2 = "numpy" # noqa: N806 - ALWAYS_INSTALLED_W_V = "pytest>=0.5.0" # noqa: N806 - ALWAYS_INSTALLED_W_V2 = "numpy>=0.1.0" # noqa: N806 - NEVER_INSTALLED = "nonexistent__package_foo_bar" # noqa: N806 - NEVER_INSTALLED_W_V = "pytest<0.1.0" # noqa: N806 + ALWAYS_INSTALLED = "pytest" + ALWAYS_INSTALLED2 = "numpy" + ALWAYS_INSTALLED_W_V = "pytest>=0.5.0" + ALWAYS_INSTALLED_W_V2 = "numpy>=0.1.0" + NEVER_INSTALLED = "nonexistent__package_foo_bar" + NEVER_INSTALLED_W_V = "pytest<0.1.0" # Test that the function does not raise an error when all dependencies are installed _check_soft_dependencies(ALWAYS_INSTALLED) @@ -155,7 +156,7 @@ def test_check_python_version( mock_sys.version = "3.8.1" class DummyObjectClass(BaseObject): - _tags = { + _tags: ClassVar[dict] = { "python_version": ">=3.7.1", # PEP 440 version specifier, e.g., ">=3.7" "python_dependencies": None, # PEP 440 dependency strs, e.g., "pandas>=1.0" "env_marker": None, # PEP 508 environment marker, e.g., "os_name=='posix'" diff --git a/skbase/utils/dependencies/tests/test_safe_import.py b/skbase/utils/dependencies/tests/test_safe_import.py index 152fc2d6..7004f4bf 100644 --- a/skbase/utils/dependencies/tests/test_safe_import.py +++ b/skbase/utils/dependencies/tests/test_safe_import.py @@ -64,8 +64,8 @@ def test_import_existing_object(): def test_multiple_inheritance_from_mock(): """Test multiple inheritance from dynamic MagicMock.""" - Class1 = _safe_import("foobar.foo.FooBar") # noqa: N806 - Class2 = _safe_import("barfoobar.BarFooBar") # noqa: N806 + Class1 = _safe_import("foobar.foo.FooBar") + Class2 = _safe_import("barfoobar.BarFooBar") class NewClass(Class1, Class2): """This should not trigger an error. @@ -75,8 +75,6 @@ class NewClass(Class1, Class2): identical to MagicMock. """ - pass - def test_soft_dependency_chains(): """Test soft dependency chains. diff --git a/skbase/utils/git_diff.py b/skbase/utils/git_diff.py index f7ab1538..8b28c7bc 100644 --- a/skbase/utils/git_diff.py +++ b/skbase/utils/git_diff.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- """Git related utilities to identify changed modules.""" __author__ = ["fkiraly"] @@ -6,7 +7,6 @@ import importlib.util import inspect from functools import lru_cache -from typing import List @lru_cache @@ -53,7 +53,7 @@ def _get_path_from_module(module_str): raise ImportError(f"Error finding module {module_str!r}") from e -def _run_git_diff(cmd: List[str]) -> str: +def _run_git_diff(cmd: list[str]) -> str: # Safety note: cmd is always a hard-coded list constructed in this module only. # No user input is ever injected → safe from shell injection. result = __import__("subprocess").run( # nosec B404, B603 @@ -165,9 +165,9 @@ def _get_packages_with_changed_specs(): packages = [] for line in changed_lines: - if line.find("'") > line.find('"') and line.find('"') != -1: - sep = '"' - elif line.find("'") == -1: + if (line.find("'") > line.find('"') and line.find('"') != -1) or line.find( + "'" + ) == -1: sep = '"' else: sep = "'" diff --git a/skbase/utils/random_state.py b/skbase/utils/random_state.py index c655eb50..3faa7335 100644 --- a/skbase/utils/random_state.py +++ b/skbase/utils/random_state.py @@ -58,7 +58,7 @@ def set_random_state(estimator, random_state=None, deep=True, root_policy="copy" keys.append(key) seeds = sample_dependent_seed(random_state, n_seeds=len(keys)) - to_set = dict(zip(keys, seeds)) + to_set = dict(zip(keys, seeds, strict=False)) if root_policy == "copy" and "random_state" in to_set: to_set["random_state"] = random_state_orig diff --git a/skbase/utils/stderr_mute.py b/skbase/utils/stderr_mute.py index bab4b119..5626809d 100644 --- a/skbase/utils/stderr_mute.py +++ b/skbase/utils/stderr_mute.py @@ -35,7 +35,7 @@ def __enter__(self): self._stderr = sys.stderr sys.stderr = io.StringIO() - def __exit__(self, type, value, traceback): # noqa: A002 + def __exit__(self, type, value, traceback): """Context manager exit point.""" # restore stderr if active # if not active, nothing needs to be done, since stderr was not replaced @@ -49,7 +49,7 @@ def __exit__(self, type, value, traceback): # noqa: A002 # return statement not needed as type was None, but included for clarity return True - def _handle_exit_exceptions(self, type, value, traceback): # noqa: A002 + def _handle_exit_exceptions(self, type, value, traceback): """Handle exceptions raised during __exit__. Parameters diff --git a/skbase/utils/stdout_mute.py b/skbase/utils/stdout_mute.py index 6783b22e..02254d62 100644 --- a/skbase/utils/stdout_mute.py +++ b/skbase/utils/stdout_mute.py @@ -35,7 +35,7 @@ def __enter__(self): self._stdout = sys.stdout sys.stdout = io.StringIO() - def __exit__(self, type, value, traceback): # noqa: A002 + def __exit__(self, type, value, traceback): """Context manager exit point.""" # restore stdout if active # if not active, nothing needs to be done, since stdout was not replaced @@ -49,7 +49,7 @@ def __exit__(self, type, value, traceback): # noqa: A002 # return statement not needed as type was None, but included for clarity return True - def _handle_exit_exceptions(self, type, value, traceback): # noqa: A002 + def _handle_exit_exceptions(self, type, value, traceback): """Handle exceptions raised during __exit__. Parameters diff --git a/skbase/utils/tests/test_deep_equals.py b/skbase/utils/tests/test_deep_equals.py index 9ab3d168..7d3f339c 100644 --- a/skbase/utils/tests/test_deep_equals.py +++ b/skbase/utils/tests/test_deep_equals.py @@ -12,7 +12,7 @@ EXAMPLES = [ 42, [], - ((((())))), + (()), [[[[()]]]], 3.5, 4.2, @@ -122,13 +122,11 @@ def copy_except_if_sklearn(obj): """ if not _check_soft_dependencies("scikit-learn", severity="none"): return deepcopy(obj) - else: - from sklearn.base import BaseEstimator + from sklearn.base import BaseEstimator - if isinstance(obj, BaseEstimator): - return obj - else: - return deepcopy(obj) + if isinstance(obj, BaseEstimator): + return obj + return deepcopy(obj) # Add JAX examples diff --git a/skbase/utils/tests/test_iter.py b/skbase/utils/tests/test_iter.py index 329d914e..850dace7 100644 --- a/skbase/utils/tests/test_iter.py +++ b/skbase/utils/tests/test_iter.py @@ -75,8 +75,8 @@ def test_format_seq_to_str(): def test_format_seq_to_str_raises(): """Test _format_seq_to_str raises error when input is unexpected type.""" - with pytest.raises(TypeError, match="`seq` must be a sequence or scalar.*"): - _format_seq_to_str((c for c in [1, 2, 3])) + with pytest.raises(TypeError, match=r"`seq` must be a sequence or scalar.*"): + _format_seq_to_str(c for c in [1, 2, 3]) def test_scalar_to_seq_expected_output(): @@ -97,13 +97,13 @@ def test_scalar_to_seq_raises(): """Test scalar_to_seq raises error when `sequence_type` is unexpected type.""" with pytest.raises( ValueError, - match="`sequence_type` must be a subclass of collections.abc.Sequence.", + match=r"`sequence_type` must be a subclass of collections.abc.Sequence.", ): _scalar_to_seq(7, sequence_type=int) with pytest.raises( ValueError, - match="`sequence_type` must be a subclass of collections.abc.Sequence.", + match=r"`sequence_type` must be a subclass of collections.abc.Sequence.", ): _scalar_to_seq(7, sequence_type=dict) diff --git a/skbase/utils/tests/test_nested_iter.py b/skbase/utils/tests/test_nested_iter.py index 4ff89875..91058530 100644 --- a/skbase/utils/tests/test_nested_iter.py +++ b/skbase/utils/tests/test_nested_iter.py @@ -70,7 +70,7 @@ def test_unflat_len(): assert unflat_len((1, 2)) == 2 assert unflat_len([1, (2, 3), 4, 5]) == 5 assert unflat_len([1, 2, (c for c in (2, 3, 4))]) == 5 - assert unflat_len((c for c in [1, 2, (c for c in (2, 3, 4))])) == 5 + assert unflat_len(c for c in [1, 2, (c for c in (2, 3, 4))]) == 5 def test_is_flat(): @@ -78,8 +78,8 @@ def test_is_flat(): assert is_flat([1, 2, 3, 4, 5]) is True assert is_flat([1, (2, 3), 4, 5]) is False # Check with flat generator - assert is_flat((c for c in [1, 2, 3])) is True + assert is_flat(c for c in [1, 2, 3]) is True # Check with nested generator assert is_flat([1, 2, (c for c in (2, 3, 4))]) is False # Check with generator nested in a generator - assert is_flat((c for c in [1, 2, (c for c in (2, 3, 4))])) is False + assert is_flat(c for c in [1, 2, (c for c in (2, 3, 4))]) is False diff --git a/skbase/utils/tests/test_random_state.py b/skbase/utils/tests/test_random_state.py index 844bc6de..f1adc7be 100644 --- a/skbase/utils/tests/test_random_state.py +++ b/skbase/utils/tests/test_random_state.py @@ -36,10 +36,7 @@ def set_seed(obj): return set_random_state( obj, random_state=42, deep=deep, root_policy=root_policy ) - else: - return obj.set_random_state( - random_state=42, deep=deep, self_policy=root_policy - ) + return obj.set_random_state(random_state=42, deep=deep, self_policy=root_policy) class DummyDummy(BaseObject): """Has no random_state attribute.""" @@ -47,7 +44,7 @@ class DummyDummy(BaseObject): def __init__(self, foo): self.foo = foo - super(DummyDummy, self).__init__() + super().__init__() class SeedCompositionDummy(BaseObject): """Potentially composite object, for testing.""" @@ -56,7 +53,7 @@ def __init__(self, foo, random_state=None): self.foo = foo self.random_state = random_state - super(SeedCompositionDummy, self).__init__() + super().__init__() simple = SeedCompositionDummy(foo=1, random_state=41) seedless = DummyDummy(foo=42) diff --git a/skbase/utils/tests/test_std_mute.py b/skbase/utils/tests/test_std_mute.py index b248a35b..5665f7f3 100644 --- a/skbase/utils/tests/test_std_mute.py +++ b/skbase/utils/tests/test_std_mute.py @@ -23,10 +23,14 @@ def test_std_mute(mute, expected): stdout_io = io.StringIO() try: - with redirect_stderr(stderr_io), redirect_stdout(stdout_io): - with StderrMute(mute), StdoutMute(mute): - sys.stdout.write("test stdout") - sys.stderr.write("test sterr") - 1 / 0 + with ( + redirect_stderr(stderr_io), + redirect_stdout(stdout_io), + StderrMute(mute), + StdoutMute(mute), + ): + sys.stdout.write("test stdout") + sys.stderr.write("test sterr") + _ = 1 / 0 except ZeroDivisionError: assert expected == [stdout_io.getvalue(), stderr_io.getvalue()] diff --git a/skbase/validate/__init__.py b/skbase/validate/__init__.py index 97d5b7a5..d5989610 100644 --- a/skbase/validate/__init__.py +++ b/skbase/validate/__init__.py @@ -3,8 +3,6 @@ # copyright: skbase developers, BSD-3-Clause License (see LICENSE file) """Tools for validating and comparing BaseObjects and collections of BaseObjects.""" -from typing import List - from skbase.validate._named_objects import ( check_sequence_named_objects, is_named_object_tuple, @@ -12,8 +10,8 @@ ) from skbase.validate._types import check_sequence, check_type, is_sequence -__author__: List[str] = ["RNKuhns", "fkiraly"] -__all__: List[str] = [ +__author__: list[str] = ["RNKuhns", "fkiraly"] +__all__: list[str] = [ "check_sequence", "check_sequence_named_objects", "check_type", diff --git a/skbase/validate/_named_objects.py b/skbase/validate/_named_objects.py index 488a832f..3ed76031 100644 --- a/skbase/validate/_named_objects.py +++ b/skbase/validate/_named_objects.py @@ -84,9 +84,7 @@ def is_named_object_tuple(obj, object_type=None): object_type = BaseObject if not isinstance(obj, tuple) or len(obj) != 2: return False - if not isinstance(obj[0], str) or not isinstance(obj[1], object_type): - return False - return True + return isinstance(obj[0], str) and isinstance(obj[1], object_type) def is_sequence_named_objects( diff --git a/skbase/validate/_types.py b/skbase/validate/_types.py index 21df51b5..a62568aa 100644 --- a/skbase/validate/_types.py +++ b/skbase/validate/_types.py @@ -5,19 +5,20 @@ import collections import inspect -from typing import Any, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any from skbase.utils._iter import _format_seq_to_str, _remove_type_text, _scalar_to_seq -__author__: List[str] = ["RNKuhns", "fkiraly"] -__all__: List[str] = ["check_sequence", "check_type", "is_sequence"] +__author__: list[str] = ["RNKuhns", "fkiraly"] +__all__: list[str] = ["check_sequence", "check_type", "is_sequence"] def check_type( input_: Any, expected_type: type, allow_none: bool = False, - input_name: Optional[str] = None, + input_name: str | None = None, use_subclass: bool = False, ) -> Any: """Check the input is the expected type. @@ -90,30 +91,26 @@ def check_type( type_check = issubclass if use_subclass else isinstance if (allow_none and input_ is None) or type_check(input_, expected_type): return input_ - else: - chk_msg = "subclass type" if use_subclass else "be type" - expected_type_str = _remove_type_text(expected_type) - input_type_str = _remove_type_text(type(input_)) - if allow_none: - type_msg = f"{expected_type_str} or None" - else: - type_msg = f"{expected_type_str}" - raise TypeError( - f"`{input_name}` should {chk_msg} {type_msg}, but found {input_type_str}." - ) + chk_msg = "subclass type" if use_subclass else "be type" + expected_type_str = _remove_type_text(expected_type) + input_type_str = _remove_type_text(type(input_)) + type_msg = f"{expected_type_str} or None" if allow_none else f"{expected_type_str}" + raise TypeError( + f"`{input_name}` should {chk_msg} {type_msg}, but found {input_type_str}." + ) def _convert_scalar_seq_type_input_to_tuple( - type_input: Optional[Union[type, Tuple[type, ...]]], - none_default: Optional[type] = None, - type_input_subclass: Optional[type] = None, - input_name: str = None, -) -> Tuple[type, ...]: + type_input: type | tuple[type, ...] | None, + none_default: type | None = None, + type_input_subclass: type | None = None, + input_name: str | None = None, +) -> tuple[type, ...]: """Convert input that is scalar or sequence of types to always be a tuple.""" if none_default is None: none_default = collections.abc.Sequence - seq_output: Tuple[type, ...] + seq_output: tuple[type, ...] if type_input is None: seq_output = (none_default,) # if a sequence of types received as sequence_type, convert to tuple of types @@ -134,8 +131,8 @@ def _convert_scalar_seq_type_input_to_tuple( def is_sequence( input_seq: Any, - sequence_type: Optional[Union[type, Tuple[type, ...]]] = None, - element_type: Optional[Union[type, Tuple[type, ...]]] = None, + sequence_type: type | tuple[type, ...] | None = None, + element_type: type | tuple[type, ...] | None = None, ) -> bool: """Indicate if an object is a sequence with optional check of element types. @@ -223,11 +220,11 @@ def is_sequence( def check_sequence( input_seq: Sequence[Any], - sequence_type: Optional[Union[type, Tuple[type, ...]]] = None, - element_type: Optional[Union[type, Tuple[type, ...]]] = None, - coerce_output_type_to: type = None, + sequence_type: type | tuple[type, ...] | None = None, + element_type: type | tuple[type, ...] | None = None, + coerce_output_type_to: type | None = None, coerce_scalar_input: bool = False, - sequence_name: str = None, + sequence_name: str | None = None, ) -> Sequence[Any]: """Check whether an object is a sequence with optional check of element types. diff --git a/skbase/validate/tests/test_type_validations.py b/skbase/validate/tests/test_type_validations.py index 0e35b464..260346d8 100644 --- a/skbase/validate/tests/test_type_validations.py +++ b/skbase/validate/tests/test_type_validations.py @@ -65,7 +65,7 @@ def test_check_type_output(fixture_estimator_instance, fixture_object_instance): with pytest.raises(TypeError, match=r"`input` should be type.*"): check_type(BaseEstimator, expected_type=BaseObject) - with pytest.raises(TypeError, match="^`input` should be.*"): + with pytest.raises(TypeError, match=r"^`input` should be.*"): check_type("something", expected_type=int, allow_none=True) # Verify optional use of issubclass instead of isinstance @@ -80,13 +80,13 @@ def test_check_type_raises_error_if_expected_type_is_wrong_format(): `expected_type` should be a type or tuple of types. """ - with pytest.raises(TypeError, match="^`expected_type` should be.*"): + with pytest.raises(TypeError, match=r"^`expected_type` should be.*"): check_type(7, expected_type=11) - with pytest.raises(TypeError, match="^`expected_type` should be.*"): + with pytest.raises(TypeError, match=r"^`expected_type` should be.*"): check_type(7, expected_type=[int]) - with pytest.raises(TypeError, match="^`expected_type` should be.*"): + with pytest.raises(TypeError, match=r"^`expected_type` should be.*"): check_type(None, expected_type=[int]) @@ -100,7 +100,7 @@ def test_is_sequence_output(): # True for any sequence assert is_sequence([1, 2, 3]) is True # But false for generators, since they are iterable but not sequences - assert is_sequence((c for c in [1, 2, 3])) is False + assert is_sequence(c for c in [1, 2, 3]) is False # Test use of sequence_type restriction assert is_sequence([1, 2, 3, 4], sequence_type=list) is True @@ -171,20 +171,20 @@ def test_check_sequence_output(): # But false for generators, since they are iterable but not sequences with pytest.raises( TypeError, - match="Invalid sequence: Input sequence expected to be a a sequence.", + match=r"Invalid sequence: Input sequence expected to be a a sequence.", ): - assert check_sequence((c for c in [1, 2, 3])) + assert check_sequence(c for c in [1, 2, 3]) # Test use of sequence_type restriction assert check_sequence([1, 2, 3, 4], sequence_type=list) == [1, 2, 3, 4] with pytest.raises( TypeError, - match="Invalid sequence: Input sequence expected to be a tuple.", + match=r"Invalid sequence: Input sequence expected to be a tuple.", ): check_sequence([1, 2, 3, 4], sequence_type=tuple) with pytest.raises( TypeError, - match="Invalid sequence: Input sequence expected to be a list.", + match=r"Invalid sequence: Input sequence expected to be a list.", ): check_sequence((1, 2, 3, 4), sequence_type=list) assert check_sequence((1, 2, 3, 4), sequence_type=tuple) == (1, 2, 3, 4) @@ -203,22 +203,22 @@ def test_check_sequence_output(): with pytest.raises( TypeError, - match="Invalid sequence: .*", + match=r"Invalid sequence: .*", ): check_sequence([1, 2, 3], element_type=float) with pytest.raises( TypeError, - match="Invalid sequence: .*", + match=r"Invalid sequence: .*", ): check_sequence([1, 2, 3, 4], sequence_type=tuple, element_type=int) with pytest.raises( TypeError, - match="Invalid sequence: .*", + match=r"Invalid sequence: .*", ): check_sequence([1, 2, 3, 4], sequence_type=list, element_type=float) with pytest.raises( TypeError, - match="Invalid sequence: .*", + match=r"Invalid sequence: .*", ): check_sequence([1, 2, 3, 4], sequence_type=tuple, element_type=float) @@ -231,7 +231,7 @@ def test_check_sequence_output(): assert check_sequence([1, "something", 4.5]) == [1, "something", 4.5] with pytest.raises( TypeError, - match="Invalid sequence: .*", + match=r"Invalid sequence: .*", ): check_sequence([1, "something", 4.5], element_type=float) @@ -242,7 +242,7 @@ def test_check_sequence_output(): # Test with 3rd party types works in default way via exact type with pytest.raises( TypeError, - match="Invalid sequence: .*", + match=r"Invalid sequence: .*", ): check_sequence([1.2, 4.7], element_type=np.float64) input_seq = [np.float64(1.2), np.float64(4.7)] @@ -257,7 +257,7 @@ def test_check_sequence_output(): ] with pytest.raises( TypeError, - match="Invalid sequence: .*", + match=r"Invalid sequence: .*", ): check_sequence([np.nan, 4], element_type=int) @@ -300,7 +300,7 @@ def test_check_sequence_scalar_input_coercion(): # Still raise an error if element type is not expected with pytest.raises( TypeError, - match="Invalid sequence: .*", + match=r"Invalid sequence: .*", ): check_sequence( 7, @@ -323,12 +323,12 @@ def test_check_sequence_with_seq_of_class_and_instance_input( with pytest.raises( TypeError, - match="Invalid sequence: .*", + match=r"Invalid sequence: .*", ): check_sequence(list(input_seq), sequence_type=tuple, element_type=BaseObject) with pytest.raises( TypeError, - match="Invalid sequence: .*", + match=r"Invalid sequence: .*", ): # Verify we detect when list elements are not instances of valid class type check_sequence([1, 2, 3], element_type=BaseObject) @@ -341,7 +341,7 @@ def test_check_sequence_with_seq_of_class_and_instance_input( ) == list(input_seq) with pytest.raises( TypeError, - match="Invalid sequence: .*", + match=r"Invalid sequence: .*", ): # Verify we detect when list elements are not instances of valid types check_sequence([1, 2, 3], element_type=BaseObject)