diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 9215b35f..4bb05ee0 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -18,17 +18,12 @@ repos:
- id: name-tests-test
args: ["--django"]
- - repo: https://github.com/pycqa/isort
- rev: 8.0.1
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.13.1
hooks:
- - id: isort
- name: isort
-
- - repo: https://github.com/psf/black-pre-commit-mirror
- rev: 26.3.1
- hooks:
- - id: black
- language_version: python3
+ - id: ruff-format
+ - id: ruff-check
+ args: [--fix]
- repo: https://github.com/asottile/blacken-docs
rev: 1.20.0
@@ -36,23 +31,6 @@ repos:
- id: blacken-docs
additional_dependencies: [black==22.3.0]
- - repo: https://github.com/pycqa/flake8
- rev: 7.3.0
- hooks:
- - id: flake8
- exclude: docs/source/conf.py, __pycache__
- additional_dependencies:
- [
- flake8-bugbear,
- flake8-builtins,
- flake8-quotes>=3.3.2,
- flake8-comprehensions,
- pandas-vet,
- flake8-print,
- pep8-naming,
- doc8,
- ]
-
- repo: https://github.com/pycqa/pydocstyle
rev: 6.3.0
hooks:
@@ -62,15 +40,8 @@ repos:
- repo: https://github.com/nbQA-dev/nbQA
rev: 1.9.1
hooks:
- - id: nbqa-isort
- args: [--nbqa-mutate, --nbqa-dont-skip-bad-cells]
- additional_dependencies: [isort==5.6.4]
- - id: nbqa-black
- args: [--nbqa-mutate, --nbqa-dont-skip-bad-cells]
- additional_dependencies: [black>=22.3.0]
- - id: nbqa-flake8
- args: [--nbqa-dont-skip-bad-cells, "--extend-ignore=E402,E203"]
- additional_dependencies: [flake8==3.8.3]
+ - id: nbqa-ruff
+ args: [--nbqa-mutate]
- repo: https://github.com/PyCQA/bandit
rev: 1.9.4
diff --git a/docs/source/conf.py b/docs/source/conf.py
index c611f98e..d3f154af 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -1,5 +1,4 @@
#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
"""Configure skbase Sphinx documentation."""
@@ -298,7 +297,7 @@ def find_source():
# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
- "python": ("https://docs.python.org/{.major}".format(sys.version_info), None),
+ "python": (f"https://docs.python.org/{sys.version_info.major}", None),
"scikit-learn": ("https://scikit-learn.org/stable/", None),
"sktime": ("https://www.sktime.net/en/stable/", None),
}
diff --git a/pyproject.toml b/pyproject.toml
index 17c90e45..c5c6d5ff 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -47,18 +47,9 @@ dev = [
linters = [
"mypy",
- "isort",
- "flake8",
- "black",
+ "ruff",
"pydocstyle",
"nbqa",
- "flake8-bugbear",
- "flake8-builtins",
- "flake8-quotes",
- "flake8-comprehensions",
- "pandas-vet",
- "flake8-print",
- "pep8-naming",
"doc8",
]
@@ -112,16 +103,6 @@ addopts = [
"--cov-report=html",
]
-[tool.isort]
-profile = "black"
-src_paths = ["skbase/*"]
-multi_line_output = 3
-known_first_party = ["skbase"]
-
-[tool.black]
-line-length = 88
-extend-exclude = "^/setup.py docs/conf.py"
-
[tool.pydocstyle]
convention = "numpy"
@@ -137,6 +118,79 @@ max-line-length = 88
ignore = ["D004"]
ignore_path = ["docs/_build", "docs/source/api_reference/auto_generated"]
+[tool.ruff]
+line-length = 88
+extend-exclude = ["setup.py", "docs/source/conf.py"]
+target-version = "py310"
+extend-include = ["*.ipynb"]
+
+[tool.ruff.lint]
+select = [
+ "D", # pydocstyle
+ "E", # pycodestyle errors
+ "W", # pycodestyle warnings
+ "F", # pyflakes
+ "I", # isort
+ "UP", # pyupgrade
+ "B", # flake8-bugbear
+ "C4", # flake8-comprehensions
+ "PIE", # flake8-pie
+ "T20", # flake8-print
+ "RET", # flake8-return
+ "SIM", # flake8-simplify
+ "TID", # flake8-tidy-imports
+ "TCH", # flake8-type-checking
+ "RUF", # Ruff-specific rules
+]
+ignore = [
+ "E203", # Whitespace before punctuation
+ "E402", # Module import not at top of file
+ "E501", # Line too long
+ "E731", # Do not assign a lambda expression, use a def
+ "RET504", # Unnecessary variable assignment before `return` statement
+ "S101", # Use of `assert` detected
+ "C408", # Unnecessary dict call - rewrite as a literal
+ "UP031", # Use format specifier instead of %
+ "UP009", # UTF-8 encoding declaration is unnecessary
+ "S102", # Use of exec
+ "C414", # Unnecessary `list` call within `sorted()`
+ "S301", # pickle and modules that wrap it can be unsafe
+ "C416", # Unnecessary list comprehension - rewrite as a generator
+ "S310", # Audit URL open for permitted schemes
+ "S202", # Uses of `tarfile.extractall()`
+ "S307", # Use of possibly insecure function
+ "C417", # Unnecessary `map` usage (rewrite using a generator expression)
+ "S605", # Starting a process with a shell, possible injection detected
+ "E741", # Ambiguous variable name
+ "S107", # Possible hardcoded password
+ "S105", # Possible hardcoded password
+ "PT018", # Checks for assertions that combine multiple independent condition
+ "S602", # sub process call with shell=True unsafe
+ "C419", # Unnecessary list comprehension, some are flagged yet are not
+ "C409", # Unnecessary `list` literal passed to `tuple()` (rewrite as a `tuple` literal)
+ "S113", # Probable use of httpx call without timeout
+]
+allowed-confusables = ["σ"]
+
+[tool.ruff.lint.per-file-ignores]
+"setup.py" = ["S101"]
+"**/__init__.py" = [
+ "F401", # unused import
+]
+"**/tests/**" = [
+ "D", # docstring
+ "S605", # Starting a process with a shell: seems safe, but may be changed in the future; consider rewriting without `shell`
+ "S607", # Starting a process with a partial executable path
+ "RET504", # Unnecessary variable assignment before `return` statement
+ "PT004", # Fixture `tmpdir_unittest_fixture` does not return anything, add leading underscore
+ "PT011", # `pytest.raises(ValueError)` is too broad, set the `match` parameter or use a more specific exception
+ "PT012", # `pytest.raises()` block should contain a single simple statement
+ "PT019", # Fixture `_` without value is injected as parameter, use `@pytest.mark.usefixtures` instead
+]
+
+[tool.ruff.lint.pydocstyle]
+convention = "numpy"
+
[tool.bandit]
exclude_dirs = ["*/tests/*", "*/testing/*"]
diff --git a/skbase/_exceptions.py b/skbase/_exceptions.py
index 50953aaa..d4de229d 100644
--- a/skbase/_exceptions.py
+++ b/skbase/_exceptions.py
@@ -5,20 +5,18 @@
# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
"""Custom exceptions used in ``skbase``."""
-from typing import List
-
-__author__: List[str] = ["fkiraly", "mloning", "rnkuhns"]
-__all__: List[str] = ["FixtureGenerationError", "NotFittedError"]
+__author__: list[str] = ["fkiraly", "mloning", "rnkuhns"]
+__all__: list[str] = ["FixtureGenerationError", "NotFittedError"]
class FixtureGenerationError(Exception):
"""Raised when a fixture fails to generate."""
- def __init__(self, fixture_name="", err=None): # noqa: B042
+ def __init__(self, fixture_name="", err=None):
self.fixture_name = fixture_name
self.err = err
msg = f"fixture {fixture_name} failed to generate. {err}"
- super().__init__(msg) # noqa: B042
+ super().__init__(msg)
class NotFittedError(ValueError, AttributeError):
diff --git a/skbase/base/__init__.py b/skbase/base/__init__.py
index 6c80e99f..18b2444e 100644
--- a/skbase/base/__init__.py
+++ b/skbase/base/__init__.py
@@ -7,8 +7,6 @@
sktime design principles in your project.
"""
-from typing import List
-
from skbase.base._base import BaseEstimator, BaseObject
from skbase.base._meta import (
BaseMetaEstimator,
@@ -17,12 +15,12 @@
BaseMetaObjectMixin,
)
-__author__: List[str] = ["mloning", "RNKuhns", "fkiraly"]
-__all__: List[str] = [
- "BaseObject",
+__author__: list[str] = ["mloning", "RNKuhns", "fkiraly"]
+__all__: list[str] = [
"BaseEstimator",
"BaseMetaEstimator",
- "BaseMetaObject",
"BaseMetaEstimatorMixin",
+ "BaseMetaObject",
"BaseMetaObjectMixin",
+ "BaseObject",
]
diff --git a/skbase/base/_base.py b/skbase/base/_base.py
index c6a0d16a..3f714708 100644
--- a/skbase/base/_base.py
+++ b/skbase/base/_base.py
@@ -58,15 +58,15 @@ class name: BaseEstimator
import warnings
from collections import defaultdict
from copy import deepcopy
-from typing import List
+from typing import ClassVar
from skbase._exceptions import NotFittedError
from skbase.base._clone_base import _check_clone, _clone
from skbase.base._pretty_printing._object_html_repr import _object_html_repr
from skbase.base._tagmanager import _FlagManager
-__author__: List[str] = ["fkiraly", "mloning", "RNKuhns", "tpvasconcelos"]
-__all__: List[str] = ["BaseEstimator", "BaseObject"]
+__author__: list[str] = ["fkiraly", "mloning", "RNKuhns", "tpvasconcelos"]
+__all__: list[str] = ["BaseEstimator", "BaseObject"]
class BaseObject(_FlagManager):
@@ -75,7 +75,7 @@ class BaseObject(_FlagManager):
Extends scikit-learn's BaseEstimator to include sktime style interface for tags.
"""
- _config = {
+ _config: ClassVar[dict] = {
"display": "diagram",
"print_changed_only": True,
"check_clone": False, # whether to execute validity checks in clone
@@ -86,7 +86,7 @@ def __init__(self):
"""Construct BaseObject."""
self._init_flags(flag_attr_name="_tags")
self._init_flags(flag_attr_name="_config")
- super(BaseObject, self).__init__()
+ super().__init__()
def __eq__(self, other):
"""Equality dunder. Checks equal class and parameters.
@@ -210,7 +210,7 @@ def _get_clone_plugins(cls):
in ``skbase.base._clone_plugins``, and implement
the methods ``_check`` and ``_clone``.
"""
- return None
+ return
@classmethod
def _get_init_signature(cls):
@@ -449,10 +449,10 @@ def _is_suffix(x, y):
def _get_alias(x, d):
"""Return alias of x in d."""
# if key is in valid_params, key is replaced by key (itself)
- if any(x == y for y in d.keys()):
+ if x in d:
return x
- suff_list = [y for y in d.keys() if _is_suffix(x, y)]
+ suff_list = [y for y in d if _is_suffix(x, y)]
# if key is a __ suffix of exactly one key in valid_params,
# it is replaced by that key
@@ -468,7 +468,7 @@ def _get_alias(x, d):
# if ns == 1
return suff_list[0]
- alias_dict = {_get_alias(x, valid_params): d[x] for x in d.keys()}
+ alias_dict = {_get_alias(x, valid_params): d[x] for x in d}
return alias_dict
@@ -1222,18 +1222,18 @@ class TagAliaserMixin:
# dictionary of aliases
# key = old tag; value = new tag, aliased by old tag
# override this in a child class
- alias_dict = {"old_tag": "new_tag", "tag_to_remove": ""}
+ alias_dict: ClassVar[dict] = {"old_tag": "new_tag", "tag_to_remove": ""}
# dictionary of removal version
# key = old tag; value = version in which tag will be removed, as string
- deprecate_dict = {"old_tag": "0.12.0", "tag_to_remove": "99.99.99"}
+ deprecate_dict: ClassVar[dict] = {"old_tag": "0.12.0", "tag_to_remove": "99.99.99"}
# package name used for deprecation warnings
_package_name = ""
def __init__(self):
"""Construct TagAliaserMixin."""
- super(TagAliaserMixin, self).__init__()
+ super().__init__()
@classmethod
def get_class_tags(cls):
@@ -1277,7 +1277,7 @@ def get_class_tags(cls):
class attribute via nested inheritance. NOT overridden by dynamic
tags set by ``set_tags`` or ``clone_tags``.
"""
- collected_tags = super(TagAliaserMixin, cls).get_class_tags()
+ collected_tags = super().get_class_tags()
cls._deprecate_tag_warn(collected_tags)
collected_tags = cls._complete_dict(collected_tags)
return collected_tags
@@ -1333,7 +1333,7 @@ def get_class_tag(cls, tag_name, tag_value_default=None):
old_tag_name = tag_name
new_tag_name = alias_dict[old_tag_name]
if tag_name in alias_dict.values():
- old_tag_name = [k for k, v in alias_dict.items() if v == tag_name][0]
+ old_tag_name = next(k for k, v in alias_dict.items() if v == tag_name)
new_tag_name = tag_name
tag_changed = new_tag_name != old_tag_name
@@ -1353,7 +1353,7 @@ def get_class_tag(cls, tag_name, tag_value_default=None):
return old_tag_val
# case 2: old tag was queried, but old tag not present
# then: return value of new tag
- elif old_tag_queried:
+ if old_tag_queried:
return cls._get_class_flag(
new_tag_name,
tag_value_default,
@@ -1400,7 +1400,7 @@ def get_tags(self):
class attribute via nested inheritance and then any overrides
and new tags from ``_tags_dynamic`` object attribute.
"""
- collected_tags = super(TagAliaserMixin, self).get_tags()
+ collected_tags = super().get_tags()
self._deprecate_tag_warn(collected_tags)
collected_tags = self._complete_dict(collected_tags)
return collected_tags
@@ -1458,7 +1458,7 @@ def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
old_tag_name = tag_name
new_tag_name = alias_dict[old_tag_name]
if tag_name in alias_dict.values():
- old_tag_name = [k for k, v in alias_dict.items() if v == tag_name][0]
+ old_tag_name = next(k for k, v in alias_dict.items() if v == tag_name)
new_tag_name = tag_name
tag_changed = new_tag_name != old_tag_name
@@ -1479,7 +1479,7 @@ def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
return old_tag_val
# case 2: old tag was queried, but old tag not present
# then: return value of new tag
- elif old_tag_queried:
+ if old_tag_queried:
return self._get_flag(
new_tag_name,
tag_value_default,
@@ -1562,8 +1562,7 @@ def _complete_dict(cls, tag_dict, direction="both"):
for old_tag in alias_dict:
cls._translate_tags(new_tag_dict, tag_dict, old_tag, direction)
return new_tag_dict
- else:
- return tag_dict
+ return tag_dict
@classmethod
def _deprecate_tag_warn(cls, tags):
@@ -1578,7 +1577,7 @@ def _deprecate_tag_warn(cls, tags):
DeprecationWarning for each tag in tags that is aliased by cls.alias_dict
"""
for tag_name in tags:
- if tag_name in cls.alias_dict.keys():
+ if tag_name in cls.alias_dict:
version = cls.deprecate_dict[tag_name]
new_tag = cls.alias_dict[tag_name]
pkg_name = cls._package_name
@@ -1656,8 +1655,7 @@ def is_fitted(self):
"""
if hasattr(self, "_is_fitted"):
return self._is_fitted
- else:
- return False
+ return False
def check_is_fitted(self, method_name=None):
"""Check if the estimator has been fitted.
@@ -1818,8 +1816,7 @@ def getattr_safe(obj, attr):
if hasattr(obj, attr):
attr = getattr(obj, attr)
return attr, True
- else:
- return None, False
+ return None, False
except Exception:
return None, False
diff --git a/skbase/base/_clone_base.py b/skbase/base/_clone_base.py
index d3ed05f3..5cd759a2 100644
--- a/skbase/base/_clone_base.py
+++ b/skbase/base/_clone_base.py
@@ -20,7 +20,7 @@
* clone(obj) -> type(obj) - method to clone obj
"""
-__all__ = ["_clone", "_check_clone"]
+__all__ = ["_check_clone", "_clone"]
from skbase.base._clone_plugins import DEFAULT_CLONE_PLUGINS
@@ -110,7 +110,7 @@ def _check_clone(original, clone):
self_params = original.get_params(deep=False)
# check that all attributes are written to the clone
- for attrname in self_params.keys():
+ for attrname in self_params:
if not hasattr(clone, attrname):
raise RuntimeError(
f"error in {original}.clone, __init__ must write all arguments "
@@ -118,7 +118,7 @@ def _check_clone(original, clone):
f"Please check __init__ of {original}."
)
- clone_attrs = {attr: getattr(clone, attr) for attr in self_params.keys()}
+ clone_attrs = {attr: getattr(clone, attr) for attr in self_params}
# check equality of parameters post-clone and pre-clone
clone_attrs_valid, msg = deep_equals(self_params, clone_attrs, return_msg=True)
diff --git a/skbase/base/_clone_plugins.py b/skbase/base/_clone_plugins.py
index c4165ae9..3885433b 100644
--- a/skbase/base/_clone_plugins.py
+++ b/skbase/base/_clone_plugins.py
@@ -15,13 +15,13 @@
* clone(obj) -> type(obj) - method to clone obj
"""
-from functools import lru_cache
+from functools import cache
from inspect import isclass
# imports wrapped in functions to avoid exceptions on skbase init
# wrapped in _safe_import to avoid exceptions on skbase init
-@lru_cache(maxsize=None)
+@cache
def _is_sklearn_present():
"""Check whether scikit-learn is present."""
from skbase.utils.dependencies import _check_soft_dependencies
@@ -29,7 +29,7 @@ def _is_sklearn_present():
return _check_soft_dependencies("scikit-learn")
-@lru_cache(maxsize=None)
+@cache
def _get_sklearn_clone():
"""Get sklearn's clone function."""
from skbase.utils.dependencies._import import _safe_import
@@ -196,13 +196,12 @@ def _clone(self, obj):
if not self.safe:
return deepcopy(obj)
- else:
- raise TypeError(
- "Cannot clone object '%s' (type %s): "
- "it does not seem to be a scikit-base object or scikit-learn "
- "estimator, as it does not implement a "
- "'get_params' method." % (repr(obj), type(obj))
- )
+ raise TypeError(
+ "Cannot clone object '%s' (type %s): "
+ "it does not seem to be a scikit-base object or scikit-learn "
+ "estimator, as it does not implement a "
+ "'get_params' method." % (repr(obj), type(obj))
+ )
DEFAULT_CLONE_PLUGINS = [
diff --git a/skbase/base/_meta.py b/skbase/base/_meta.py
index 3a4e70db..dfef2493 100644
--- a/skbase/base/_meta.py
+++ b/skbase/base/_meta.py
@@ -10,6 +10,7 @@
"""Implements functionality for meta objects composed of other objects."""
from inspect import isclass
+from typing import ClassVar
from skbase.base._base import BaseEstimator, BaseObject
from skbase.base._pretty_printing._object_html_repr import _VisualBlock
@@ -39,7 +40,7 @@ class has values that follow the named object specification. For example,
# _steps_attr points to the attribute of self
# which contains the heterogeneous set of estimators
# this must be an iterable of (name: str, estimator) pairs for the default
- _tags = {"named_object_parameters": "steps"}
+ _tags: ClassVar[dict] = {"named_object_parameters": "steps"}
def is_composite(self):
"""Check if the object is composite.
@@ -139,9 +140,7 @@ def _get_fitted_params(self):
"""
fitted_params = self._get_fitted_params_default()
- fitted_named_object_attr = self.get_tag(
- "fitted_named_object_parameters"
- ) # type: ignore
+ fitted_named_object_attr = self.get_tag("fitted_named_object_parameters") # type: ignore
named_objects_fitted_params = self._get_params(
fitted_named_object_attr, fitted=True
@@ -253,7 +252,7 @@ def _set_params(self, attr: str, **params):
items = getattr(self, attr)
names = []
if items and isinstance(items, (list, tuple)):
- names = list(zip(*items))[0]
+ names = next(zip(*items, strict=False))
for name in list(params.keys()):
if "__" not in name and name in names:
self._replace_object(attr, name, params.pop(name))
@@ -323,14 +322,14 @@ def _check_names(self, names, make_unique=True):
A sequence of unique string names that follow named object API rules.
"""
if len(set(names)) != len(names):
- raise ValueError("Names provided are not unique: {0!r}".format(list(names)))
+ raise ValueError(f"Names provided are not unique: {list(names)!r}")
# Get names that match direct parameter
invalid_names = set(names).intersection(self.get_params(deep=False))
invalid_names = invalid_names.union({name for name in names if "__" in name})
if invalid_names:
raise ValueError(
"Object names conflict with constructor argument or "
- "contain '__': {0!r}".format(sorted(invalid_names))
+ f"contain '__': {sorted(invalid_names)!r}"
)
if make_unique:
names = make_strings_unique(names)
@@ -363,10 +362,7 @@ def _coerce_object_tuple(self, obj, clone=False):
name = obj[0]
else:
- if isinstance(obj, tuple) and len(obj) == 1:
- _obj = obj[0]
- else:
- _obj = obj
+ _obj = obj[0] if isinstance(obj, tuple) and len(obj) == 1 else obj
name = type(_obj).__name__
if clone:
@@ -460,22 +456,25 @@ def is_obj_is_tuple(obj):
# We've already guarded against objs being dict when allow_dict is False
# So here we can just check dictionary elements
- if isinstance(objs, dict) and not all(
- isinstance(name, str) and isinstance(obj, cls_type)
- for name, obj in objs.items()
- ):
- raise TypeError(msg)
-
- elif not all(any(is_obj_is_tuple(x)) for x in objs):
+ if (
+ isinstance(objs, dict)
+ and not all(
+ isinstance(name, str) and isinstance(obj, cls_type)
+ for name, obj in objs.items()
+ )
+ ) or not all(any(is_obj_is_tuple(x)) for x in objs):
raise TypeError(msg)
msg_no_mix = (
f"Elements of {attr_name} must either all be objects, "
f"or all (str, objects) tuples. A mix of the two is not allowed."
)
- if not allow_mix and not all(is_obj_is_tuple(x)[0] for x in objs):
- if not all(is_obj_is_tuple(x)[1] for x in objs):
- raise TypeError(msg_no_mix)
+ if (
+ not allow_mix
+ and not all(is_obj_is_tuple(x)[0] for x in objs)
+ and not all(is_obj_is_tuple(x)[1] for x in objs)
+ ):
+ raise TypeError(msg_no_mix)
return self._coerce_to_named_object_tuples(objs, clone=clone, make_unique=True)
@@ -500,9 +499,11 @@ def _get_names_and_objects(self, named_objects, make_unique=False):
The
"""
if isinstance(named_objects, dict):
- names, objs = zip(*named_objects.items())
+ names, objs = zip(*named_objects.items(), strict=False)
else:
- names, objs = zip(*[self._coerce_object_tuple(x) for x in named_objects])
+ names, objs = zip(
+ *[self._coerce_object_tuple(x) for x in named_objects], strict=False
+ )
# Optionally make names unique
if make_unique:
@@ -555,7 +556,7 @@ def _coerce_to_named_object_tuples(self, objs, clone=False, make_unique=True):
named_objects, make_unique=make_unique
)
# Repack the objects
- named_objects = list(zip(names, objs))
+ named_objects = list(zip(names, objs, strict=False))
return named_objects
def _dunder_concat(
@@ -632,8 +633,7 @@ def _dunder_concat(
def concat(x, y):
if concat_order == "left":
return x + y
- else:
- return y + x
+ return y + x
# get attr_name from self and other
# can be list of ests, list of (str, est) tuples, or list of mixture of these
@@ -658,16 +658,15 @@ def concat(x, y):
new_objs = concat(self_objs, other_objs)
# create the "steps" param for the composite
# if all the names are equal to class names, we eat them away
- if all(type(x[1]).__name__ == x[0] for x in zip(new_names, new_objs)):
+ if all(
+ type(x[1]).__name__ == x[0] for x in zip(new_names, new_objs, strict=False)
+ ):
step_param = {attr_name: list(new_objs)}
else:
- step_param = {attr_name: list(zip(new_names, new_objs))}
+ step_param = {attr_name: list(zip(new_names, new_objs, strict=False))}
# retrieve other parameters, from composite_params attribute
- if composite_params is None:
- composite_params = {}
- else:
- composite_params = composite_params.copy()
+ composite_params = {} if composite_params is None else composite_params.copy()
# construct the composite with both step and additional params
composite_params.update(step_param)
@@ -806,7 +805,7 @@ def _tagchain_is_linked(
for _, est in estimators:
if est.get_tag(mid_tag_name) == mid_tag_val:
return True, True
- if not est.get_tag(left_tag_name) == left_tag_val:
+ if est.get_tag(left_tag_name) != left_tag_val:
return False, False
return True, False
diff --git a/skbase/base/_pretty_printing/__init__.py b/skbase/base/_pretty_printing/__init__.py
index c94c1888..1d46b64c 100644
--- a/skbase/base/_pretty_printing/__init__.py
+++ b/skbase/base/_pretty_printing/__init__.py
@@ -6,7 +6,5 @@
# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
"""Functionality for pretty printing BaseObjects."""
-from typing import List
-
-__author__: List[str] = ["RNKuhns"]
-__all__: List[str] = []
+__author__: list[str] = ["RNKuhns"]
+__all__: list[str] = []
diff --git a/skbase/base/_pretty_printing/_object_html_repr.py b/skbase/base/_pretty_printing/_object_html_repr.py
index 1c4fc017..d882d3cc 100644
--- a/skbase/base/_pretty_printing/_object_html_repr.py
+++ b/skbase/base/_pretty_printing/_object_html_repr.py
@@ -100,7 +100,7 @@ def _get_visual_block(base_object):
return _VisualBlock(
"single", base_object, names=base_object, name_details=base_object
)
- elif base_object is None:
+ if base_object is None:
return _VisualBlock("single", base_object, names="None", name_details="None")
# check if base_object looks like a meta base_object wraps base_object
@@ -137,7 +137,9 @@ def _write_base_object_html(
kind = est_block.kind
out.write(f'
')
- est_infos = zip(est_block.estimators, est_block.names, est_block.name_details)
+ est_infos = zip(
+ est_block.estimators, est_block.names, est_block.name_details, strict=False
+ )
for est, name, name_details in est_infos:
if kind == "serial":
@@ -335,7 +337,7 @@ def _write_base_object_html(
#$id div.sk-text-repr-fallback {
display: none;
}
-""".replace(" ", "").replace("\n", "") # noqa
+""".replace(" ", "").replace("\n", "")
def _object_html_repr(base_object):
diff --git a/skbase/base/_pretty_printing/_pprint.py b/skbase/base/_pretty_printing/_pprint.py
index b1a575bc..66a2889b 100644
--- a/skbase/base/_pretty_printing/_pprint.py
+++ b/skbase/base/_pretty_printing/_pprint.py
@@ -27,8 +27,6 @@ def __repr__(self):
class KeyValTupleParam(KeyValTuple):
"""Dummy class for correctly rendering key-value tuples from parameters."""
- pass
-
def _changed_params(base_object):
"""Return dict (param_name: value) of parameters with non-default values."""
@@ -48,11 +46,9 @@ def has_changed(k, v):
if isinstance(v, BaseObject) and v.__class__ != init_params[k].__class__:
return True
# Use repr as a last resort. It may be expensive.
- if repr(v) != repr(init_params[k]) and not (
+ return repr(v) != repr(init_params[k]) and not (
_is_scalar_nan(init_params[k]) and _is_scalar_nan(v)
- ):
- return True
- return False
+ )
return {k: v for k, v in params.items() if has_changed(k, v)}
@@ -131,7 +127,7 @@ def __init__(
# (they are treated as dicts)
self.n_max_elements_to_show = n_max_elements_to_show
- def format(self, obj, context, maxlevels, level): # noqa
+ def format(self, obj, context, maxlevels, level):
return _safe_repr(
obj, context, maxlevels, level, changed_only=self.changed_only
)
@@ -386,10 +382,7 @@ def _safe_repr(obj, context, maxlevels, level, changed_only=False):
context[objid] = 1
readable = True
recursive = False
- if changed_only:
- params = _changed_params(obj)
- else:
- params = obj.get_params(deep=False)
+ params = _changed_params(obj) if changed_only else obj.get_params(deep=False)
components = []
append = components.append
level += 1
diff --git a/skbase/base/_pretty_printing/tests/test_pprint.py b/skbase/base/_pretty_printing/tests/test_pprint.py
index a7322c18..7c922916 100644
--- a/skbase/base/_pretty_printing/tests/test_pprint.py
+++ b/skbase/base/_pretty_printing/tests/test_pprint.py
@@ -15,7 +15,7 @@ def __init__(self, foo, bar=84):
self.foo = foo
self.bar = bar
- super(CompositionDummy, self).__init__()
+ super().__init__()
@pytest.mark.skipif(
diff --git a/skbase/base/_tagmanager.py b/skbase/base/_tagmanager.py
index 87c2af23..757772f0 100644
--- a/skbase/base/_tagmanager.py
+++ b/skbase/base/_tagmanager.py
@@ -144,7 +144,7 @@ def _get_flag(
flag_value = collected_flags.get(flag_name, flag_value_default)
- if raise_error and flag_name not in collected_flags.keys():
+ if raise_error and flag_name not in collected_flags:
raise ValueError(f"Tag with name {flag_name} could not be found.")
return flag_value
@@ -209,7 +209,7 @@ def _clone_flags(self, estimator, flag_names=None, flag_attr_name="_flags"):
# if flag_set is passed, intersect keys with flags in estimator
if not isinstance(flag_names, list):
flag_names = [flag_names]
- flag_names = [key for key in flag_names if key in flags_est.keys()]
+ flag_names = [key for key in flag_names if key in flags_est]
update_dict = {key: flags_est[key] for key in flag_names}
diff --git a/skbase/lookup/__init__.py b/skbase/lookup/__init__.py
index 726cd29b..b4b335fc 100644
--- a/skbase/lookup/__init__.py
+++ b/skbase/lookup/__init__.py
@@ -17,12 +17,11 @@
# is based on the sklearn estimator retrieval utility of the same name
# See https://github.com/scikit-learn/scikit-learn/blob/main/COPYING and
# https://github.com/sktime/sktime/blob/main/LICENSE
-from typing import List
from skbase.lookup._lookup import all_objects, get_package_metadata
-__all__: List[str] = ["all_objects", "get_package_metadata"]
-__author__: List[str] = [
+__all__: list[str] = ["all_objects", "get_package_metadata"]
+__author__: list[str] = [
"fkiraly",
"mloning",
"katiebuc",
diff --git a/skbase/lookup/_lookup.py b/skbase/lookup/_lookup.py
index 20175fdd..d697b1a2 100644
--- a/skbase/lookup/_lookup.py
+++ b/skbase/lookup/_lookup.py
@@ -22,19 +22,19 @@
import pkgutil
import re
import warnings
-from collections.abc import Iterable
+from collections.abc import Iterable, Mapping, MutableMapping, Sequence
from copy import deepcopy
from functools import lru_cache
from operator import itemgetter
from types import ModuleType
-from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union
+from typing import Any
from skbase.base import BaseObject
from skbase.utils.stdout_mute import StdoutMute
from skbase.validate import check_sequence
-__all__: List[str] = ["all_objects", "get_package_metadata"]
-__author__: List[str] = [
+__all__: list[str] = ["all_objects", "get_package_metadata"]
+__author__: list[str] = [
"fkiraly",
"mloning",
"katiebuc",
@@ -106,7 +106,7 @@ def _is_non_public_module(module_name: str) -> bool:
def _is_ignored_module(
- module_name: str, modules_to_ignore: Union[str, List[str], Tuple[str]] = None
+ module_name: str, modules_to_ignore: str | list[str] | tuple[str] | None = None
) -> bool:
"""Determine if module is one of the ignored modules.
@@ -141,7 +141,7 @@ def _is_ignored_module(
def _filter_by_class(
- klass: type, class_filter: Optional[Union[type, Sequence[type]]] = None
+ klass: type, class_filter: type | Sequence[type] | None = None
) -> bool:
"""Determine if a class is a subclass of the supplied classes.
@@ -217,7 +217,7 @@ def _filter_by_tags(obj, tag_filter=None, as_dataframe=True):
# case: tag_filter is dict
# check that all keys are str
- if not all(isinstance(t, str) for t in tag_filter.keys()):
+ if not all(isinstance(t, str) for t in tag_filter):
raise ValueError(f"{type_msg} {tag_filter}")
cond_sat = True
@@ -279,7 +279,7 @@ def _walk(root, exclude=None, prefix=""):
def _import_module(
- module: Union[str, importlib.machinery.SourceFileLoader],
+ module: str | importlib.machinery.SourceFileLoader,
suppress_import_stdout: bool = True,
) -> ModuleType:
"""Import a module, while optionally suppressing import standard out.
@@ -319,8 +319,8 @@ def _import_module(
def _determine_module_path(
- package_name: str, path: Optional[Union[str, pathlib.Path]] = None
-) -> Tuple[ModuleType, str, importlib.machinery.SourceFileLoader]:
+ package_name: str, path: str | pathlib.Path | None = None
+) -> tuple[ModuleType, str, importlib.machinery.SourceFileLoader]:
"""Determine a package's path information.
Parameters
@@ -402,11 +402,11 @@ def _get_module_info(
module: ModuleType,
is_pkg: bool,
path: str,
- package_base_classes: Union[type, Tuple[type, ...]],
+ package_base_classes: type | tuple[type, ...],
exclude_non_public_items: bool = True,
- class_filter: Optional[Union[type, Sequence[type]]] = None,
- tag_filter: Optional[Union[str, Sequence[str], Mapping[str, Any]]] = None,
- classes_to_exclude: Optional[Union[type, Sequence[type]]] = None,
+ class_filter: type | Sequence[type] | None = None,
+ tag_filter: str | Sequence[str] | Mapping[str, Any] | None = None,
+ classes_to_exclude: type | Sequence[type] | None = None,
) -> dict: # of ModuleInfo type
# Make package_base_classes a tuple if it was supplied as a class
base_classes_none = False
@@ -417,16 +417,14 @@ def _get_module_info(
base_classes_none = True
package_base_classes = (package_base_classes,)
- exclude_classes: Optional[Sequence[type]]
- if classes_to_exclude is None:
- exclude_classes = classes_to_exclude
- elif isinstance(classes_to_exclude, Sequence):
+ exclude_classes: Sequence[type] | None
+ if classes_to_exclude is None or isinstance(classes_to_exclude, Sequence):
exclude_classes = classes_to_exclude
elif inspect.isclass(classes_to_exclude):
exclude_classes = (classes_to_exclude,)
- designed_imports: List[str] = getattr(module, "__all__", [])
- authors: Union[str, List[str]] = getattr(module, "__author__", [])
+ designed_imports: list[str] = getattr(module, "__all__", [])
+ authors: str | list[str] = getattr(module, "__author__", [])
if isinstance(authors, (list, tuple)):
authors = ", ".join(authors)
# Compile information on classes in the module
@@ -575,15 +573,15 @@ def _get_members_uw(module, predicate=None):
def get_package_metadata(
package_name: str,
- path: Optional[str] = None,
+ path: str | None = None,
recursive: bool = True,
exclude_non_public_items: bool = True,
exclude_non_public_modules: bool = True,
- modules_to_ignore: Union[str, List[str], Tuple[str]] = ("tests",),
- package_base_classes: Union[type, Tuple[type, ...]] = (BaseObject,),
- class_filter: Optional[Union[type, Sequence[type]]] = None,
- tag_filter: Optional[Union[str, Sequence[str], Mapping[str, Any]]] = None,
- classes_to_exclude: Optional[Union[type, Sequence[type]]] = None,
+ modules_to_ignore: str | list[str] | tuple[str] = ("tests",),
+ package_base_classes: type | tuple[type, ...] = (BaseObject,),
+ class_filter: type | Sequence[type] | None = None,
+ tag_filter: str | Sequence[str] | Mapping[str, Any] | None = None,
+ classes_to_exclude: type | Sequence[type] | None = None,
suppress_import_stdout: bool = True,
) -> Mapping: # of ModuleInfo type
"""Return a dictionary mapping all package modules to their metadata.
@@ -721,10 +719,7 @@ def get_package_metadata(
continue
if recursive and is_pkg:
- if "." in name:
- name_ending = name[len(package_name) + 1 :]
- else:
- name_ending = name
+ name_ending = name[len(package_name) + 1 :] if "." in name else name
updated_path: str
if "." in name_ending:
@@ -759,7 +754,7 @@ def all_objects(
return_tags=None,
suppress_import_stdout=True,
package_name="skbase",
- path: Optional[str] = None,
+ path: str | None = None,
modules_to_ignore=None,
class_lookup=None,
):
@@ -963,13 +958,12 @@ class name if ``return_names=False`` and ``return_tags is not None``.
if all_estimators:
if isinstance(all_estimators[0], tuple):
all_estimators = [
- (name, est) + _get_return_tags(est, return_tags)
+ (name, est, *_get_return_tags(est, return_tags))
for (name, est) in all_estimators
]
else:
all_estimators = [
- (est,) + _get_return_tags(est, return_tags)
- for est in all_estimators
+ (est, *_get_return_tags(est, return_tags)) for est in all_estimators
]
columns = columns + return_tags
@@ -1027,20 +1021,19 @@ def _get_err_msg(estimator_type):
if class_lookup is None or not isinstance(class_lookup, dict):
return (
f"Parameter `estimator_type` must be None, a class, or a list of "
- f"class, but found: {repr(estimator_type)}"
- )
- else:
- return (
- f"Parameter `estimator_type` must be None, a string, a class, or a list"
- f" of [string or class]. Valid string values are: "
- f"{tuple(class_lookup.keys())}, but found: "
- f"{repr(estimator_type)}"
+ f"class, but found: {estimator_type!r}"
)
+ return (
+ f"Parameter `estimator_type` must be None, a string, a class, or a list"
+ f" of [string or class]. Valid string values are: "
+ f"{tuple(class_lookup.keys())}, but found: "
+ f"{estimator_type!r}"
+ )
for i, estimator_type in enumerate(object_types):
if isinstance(estimator_type, str):
if not isinstance(class_lookup, dict) or (
- estimator_type not in class_lookup.keys()
+ estimator_type not in class_lookup
):
raise ValueError(_get_err_msg(estimator_type))
estimator_type = class_lookup[estimator_type]
@@ -1079,7 +1072,7 @@ class StdoutMuteNCatchMNF(StdoutMute):
except catch and suppress ModuleNotFoundError.
"""
- def _handle_exit_exceptions(self, type, value, traceback): # noqa: A002
+ def _handle_exit_exceptions(self, type, value, traceback):
"""Handle exceptions raised during __exit__.
Parameters
@@ -1104,12 +1097,11 @@ def _handle_exit_exceptions(self, type, value, traceback): # noqa: A002
def _coerce_to_tuple(x):
if x is None:
return ()
- elif isinstance(x, tuple):
+ if isinstance(x, tuple):
return x
- elif isinstance(x, list):
+ if isinstance(x, list):
return tuple(x)
- else:
- return (x,)
+ return (x,)
@lru_cache(maxsize=100)
@@ -1144,7 +1136,7 @@ def _walk_and_retrieve_all_objs(root, package_name, modules_to_ignore):
prefix = f"{package_name}."
def _is_base_class(name):
- return name.startswith("_") or name.startswith("Base")
+ return name.startswith(("_", "Base"))
all_estimators = []
diff --git a/skbase/lookup/tests/test_lookup.py b/skbase/lookup/tests/test_lookup.py
index fc029316..5508665e 100644
--- a/skbase/lookup/tests/test_lookup.py
+++ b/skbase/lookup/tests/test_lookup.py
@@ -11,7 +11,6 @@
import sys
from copy import deepcopy
from types import ModuleType
-from typing import List
import pandas as pd
import pytest
@@ -46,8 +45,8 @@
NotABaseObject,
)
-__author__: List[str] = ["RNKuhns", "fkiraly"]
-__all__: List[str] = []
+__author__: list[str] = ["RNKuhns", "fkiraly"]
+__all__: list[str] = []
MODULE_METADATA_EXPECTED_KEYS = (
@@ -186,10 +185,7 @@ def _check_package_metadata_result(results):
isinstance(k, str) and isinstance(v, dict)
for k, v in mod_metadata["classes"].items()
)
- ):
- return False
- # Then verify sub-dict values for each class have required keys
- elif not all(
+ ) or not all(
k in c_meta
for c_meta in mod_metadata["classes"].values()
for k in REQUIRED_CLASS_METADATA_KEYS
@@ -202,10 +198,7 @@ def _check_package_metadata_result(results):
isinstance(k, str) and isinstance(v, dict)
for k, v in mod_metadata["functions"].items()
)
- ):
- return False
- # Then verify sub-dict values for each function have required keys
- elif not all(
+ ) or not all(
k in f_meta
for f_meta in mod_metadata["functions"].values()
for k in REQUIRED_FUNCTION_METADATA_KEYS
@@ -415,9 +408,9 @@ def test_walk_returns_expected_format(fixture_skbase_root_path):
"""Check walk function returns expected format."""
def _test_walk_return(p):
- assert (
- isinstance(p, tuple) and len(p) == 3
- ), "_walk should return tuple of length 3"
+ assert isinstance(p, tuple) and len(p) == 3, (
+ "_walk should return tuple of length 3"
+ )
assert (
isinstance(p[0], str)
and isinstance(p[1], bool)
@@ -790,13 +783,13 @@ def test_get_package_metadata_returns_expected_results(
# Verify class metadata attributes correct
for klass, klass_metadata in results[module]["classes"].items():
if klass_metadata["klass"] in SKBASE_BASE_CLASSES:
- assert (
- klass_metadata["is_base_class"] is True
- ), f"{klass} should be base class."
+ assert klass_metadata["is_base_class"] is True, (
+ f"{klass} should be base class."
+ )
else:
- assert (
- klass_metadata["is_base_class"] is False
- ), f"{klass} should not be base class."
+ assert klass_metadata["is_base_class"] is False, (
+ f"{klass} should not be base class."
+ )
if issubclass(klass_metadata["klass"], BaseObject):
assert klass_metadata["is_base_object"] is True
@@ -982,13 +975,12 @@ def test_all_objects_returns_expected_types(
if isinstance(modules_to_ignore, str):
modules_to_ignore = (modules_to_ignore,)
if (
- modules_to_ignore is not None
- and "tests" in modules_to_ignore
+ modules_to_ignore is not None and "tests" in modules_to_ignore
# and "mock_package" in modules_to_ignore
):
- assert (
- len(objs) == 0
- ), "Search of `skbase` should only return objects from tests module."
+ assert len(objs) == 0, (
+ "Search of `skbase` should only return objects from tests module."
+ )
else:
# We expect at least one object to be returned so we verify output type/format
_check_all_object_output_types(
diff --git a/skbase/testing/__init__.py b/skbase/testing/__init__.py
index 012948fb..d9193026 100644
--- a/skbase/testing/__init__.py
+++ b/skbase/testing/__init__.py
@@ -1,13 +1,11 @@
# -*- coding: utf-8 -*-
""":mod:`skbase.testing` provides a framework to test ``BaseObject``-s."""
-from typing import List
-
from skbase.testing.test_all_objects import (
BaseFixtureGenerator,
QuickTester,
TestAllObjects,
)
-__all__: List[str] = ["BaseFixtureGenerator", "QuickTester", "TestAllObjects"]
-__author__: List[str] = ["fkiraly"]
+__all__: list[str] = ["BaseFixtureGenerator", "QuickTester", "TestAllObjects"]
+__author__: list[str] = ["fkiraly"]
diff --git a/skbase/testing/test_all_objects.py b/skbase/testing/test_all_objects.py
index a1bdeea9..021eebbf 100644
--- a/skbase/testing/test_all_objects.py
+++ b/skbase/testing/test_all_objects.py
@@ -9,7 +9,7 @@
import types
from copy import deepcopy
from inspect import getfullargspec, isclass, signature
-from typing import List
+from typing import ClassVar
import numpy as np
import pytest
@@ -23,7 +23,7 @@
from skbase.utils.deep_equals import deep_equals
from skbase.utils.dependencies import _check_soft_dependencies
-__author__: List[str] = ["fkiraly"]
+__author__: list[str] = ["fkiraly"]
class BaseFixtureGenerator:
@@ -106,12 +106,12 @@ class BaseFixtureGenerator:
valid_base_types = None
# which sequence the conditional fixtures are generated in
- fixture_sequence = ["object_class", "object_instance"]
+ fixture_sequence: ClassVar[list[str]] = ["object_class", "object_instance"]
# which fixtures are indirect, e.g., have an additional pytest.fixture block
# to generate an indirect fixture at runtime. Example: object_instance
# warning: direct fixtures retain state changes within the same test
- indirect_fixtures = ["object_instance"]
+ indirect_fixtures: ClassVar[list[str]] = ["object_instance"]
def pytest_generate_tests(self, metafunc):
"""Test parameterization routine for pytest.
@@ -176,7 +176,7 @@ def generator_dict(self):
fixts = [gen.replace("_generate_", "") for gen in gens]
generator_dict = {}
- for var, gen in zip(fixts, gens):
+ for var, gen in zip(fixts, gens, strict=False):
generator_dict[var] = getattr(self, gen)
return generator_dict
@@ -185,8 +185,7 @@ def is_excluded(self, test_name, est):
"""Shorthand to check whether test test_name is excluded for object est."""
if self.excluded_tests is None:
return False
- else:
- return test_name in self.excluded_tests.get(est.__name__, [])
+ return test_name in self.excluded_tests.get(est.__name__, [])
# the following functions define fixture generation logic for pytest_generate_tests
# each function is of signature (test_name:str, **kwargs) -> List of fixtures
@@ -419,36 +418,38 @@ def run_tests(
# if function is decorated with mark.parametrize, add variable settings
# NOTE: currently this works only with single-variable mark.parametrize
- if hasattr(test_fun, "pytestmark"):
- if len([x for x in test_fun.pytestmark if x.name == "parametrize"]) > 0:
- # get the three lists from pytest
- (
- pytest_fixture_vars,
- pytest_fixture_prod,
- pytest_fixture_names,
- ) = self._get_pytest_mark_args(test_fun)
- # add them to the three lists from conditional fixtures
- fixture_vars, fixture_prod, fixture_names = self._product_fixtures(
- fixture_vars,
- fixture_prod,
- fixture_names,
- pytest_fixture_vars,
- pytest_fixture_prod,
- pytest_fixture_names,
- )
+ if (
+ hasattr(test_fun, "pytestmark")
+ and len([x for x in test_fun.pytestmark if x.name == "parametrize"]) > 0
+ ):
+ # get the three lists from pytest
+ (
+ pytest_fixture_vars,
+ pytest_fixture_prod,
+ pytest_fixture_names,
+ ) = self._get_pytest_mark_args(test_fun)
+ # add them to the three lists from conditional fixtures
+ fixture_vars, fixture_prod, fixture_names = self._product_fixtures(
+ fixture_vars,
+ fixture_prod,
+ fixture_names,
+ pytest_fixture_vars,
+ pytest_fixture_prod,
+ pytest_fixture_names,
+ )
def print_if_verbose(msg):
if int(verbose) > 0:
- print(msg) # noqa: T001, T201
+ print(msg) # noqa: T201
# loop B: for each test, we loop over all fixtures
- for params, fixt_name in zip(fixture_prod, fixture_names):
+ for params, fixt_name in zip(fixture_prod, fixture_names, strict=False):
# this is needed because pytest unwraps 1-tuples automatically
# but subsequent code assumes params is k-tuple, no matter what k is
if len(fixture_vars) == 1:
params = (params,)
key = f"{test_name}[{fixt_name}]"
- args = dict(zip(fixture_vars, params))
+ args = dict(zip(fixture_vars, params, strict=False))
for f in test_fun_vars:
if f not in args:
@@ -509,10 +510,7 @@ def _subset_generator_dict(obj, generator_dict):
"""
obj_generator_dict = generator_dict
- if isclass(obj):
- object_class = obj
- else:
- object_class = type(obj)
+ object_class = obj if isclass(obj) else type(obj)
def _generate_object_class(test_name, **kwargs):
return [object_class], [object_class.__name__]
@@ -573,10 +571,9 @@ def to_str(obj):
return [str(x) for x in obj]
def get_id(mark):
- if "ids" in mark.kwargs.keys():
+ if "ids" in mark.kwargs:
return mark.kwargs["ids"]
- else:
- return to_str(range(len(mark.args[1])))
+ return to_str(range(len(mark.args[1])))
pytest_fixture_vars = [x.args[0] for x in marks]
pytest_fixt_raw = [x.args[1] for x in marks]
@@ -618,16 +615,16 @@ def _product_fixtures(
return fixture_vars_return, fixture_prod_return, fixture_names_return
def _make_builtin_fixture_equivalents(self, name):
- """Utility for QuickTester, creates equivalent fixtures for pytest runs."""
+ """Create equivalent fixtures for pytest runs."""
import io
import logging
import tempfile
from pathlib import Path
values = {}
- if "tmp_path" == name:
+ if name == "tmp_path":
return Path(tempfile.mkdtemp())
- if "capsys" == name:
+ if name == "capsys":
# crude emulation using StringIO
return type(
"Capsys",
@@ -639,12 +636,12 @@ def _make_builtin_fixture_equivalents(self, name):
},
)()
- if "monkeypatch" == name:
+ if name == "monkeypatch":
from _pytest.monkeypatch import MonkeyPatch
return MonkeyPatch()
- if "caplog" == name:
+ if name == "caplog":
class Caplog:
def __init__(self):
@@ -735,7 +732,7 @@ def test_object_tags(self, object_class):
assert hasattr(object_class, "get_class_tags")
all_tags = object_class.get_class_tags()
assert isinstance(all_tags, dict)
- assert all(isinstance(key, str) for key in all_tags.keys())
+ assert all(isinstance(key, str) for key in all_tags)
if hasattr(object_class, "_tags"):
tags = object_class._tags
msg = (
@@ -747,9 +744,7 @@ def test_object_tags(self, object_class):
if self.valid_tags is None:
invalid_tags = tags
else:
- invalid_tags = [
- tag for tag in tags.keys() if tag not in self.valid_tags
- ]
+ invalid_tags = [tag for tag in tags if tag not in self.valid_tags]
assert len(invalid_tags) == 0, (
f"_tags of {object_class} contains invalid tags: {invalid_tags}. "
f"For a list of valid tags, see {self.__class__.__name__}.valid_tags."
@@ -766,7 +761,7 @@ def test_object_tags(self, object_class):
def test_inheritance(self, object_class):
"""Check that object inherits from BaseObject."""
assert issubclass(object_class, BaseObject), (
- f"object: {object_class} " f"is not a sub-class of " f"BaseObject."
+ f"object: {object_class} is not a sub-class of BaseObject."
)
# Usually should inherit only from one BaseObject type
if self.valid_base_types is not None:
@@ -988,13 +983,13 @@ def param_filter(p):
def test_valid_object_class_tags(self, object_class):
"""Check that object class tags are in self.valid_tags."""
if self.valid_tags is None:
- return None
- for tag in object_class.get_class_tags().keys():
+ return
+ for tag in object_class.get_class_tags():
assert tag in self.valid_tags
def test_valid_object_tags(self, object_instance):
"""Check that object tags are in self.valid_tags."""
if self.valid_tags is None:
- return None
- for tag in object_instance.get_tags().keys():
+ return
+ for tag in object_instance.get_tags():
assert tag in self.valid_tags
diff --git a/skbase/testing/utils/__init__.py b/skbase/testing/utils/__init__.py
index 550c8c89..6fcd6756 100644
--- a/skbase/testing/utils/__init__.py
+++ b/skbase/testing/utils/__init__.py
@@ -1,6 +1,4 @@
# -*- coding: utf-8 -*-
"""Utilities for the test framework."""
-from typing import List
-
-__all__: List[str] = []
+__all__: list[str] = []
diff --git a/skbase/testing/utils/_conditional_fixtures.py b/skbase/testing/utils/_conditional_fixtures.py
index 01689311..53b21112 100644
--- a/skbase/testing/utils/_conditional_fixtures.py
+++ b/skbase/testing/utils/_conditional_fixtures.py
@@ -4,22 +4,22 @@
Exports create_conditional_fixtures_and_names utility
"""
+from collections.abc import Callable
from copy import deepcopy
-from typing import Callable, Dict, List
from skbase._exceptions import FixtureGenerationError
from skbase.utils._nested_iter import _remove_single
from skbase.validate import check_sequence
-__author__: List[str] = ["fkiraly"]
-__all__: List[str] = ["create_conditional_fixtures_and_names"]
+__author__: list[str] = ["fkiraly"]
+__all__: list[str] = ["create_conditional_fixtures_and_names"]
def create_conditional_fixtures_and_names(
test_name: str,
- fixture_vars: List[str],
- generator_dict: Dict[str, Callable],
- fixture_sequence: List[str] = None,
+ fixture_vars: list[str],
+ generator_dict: dict[str, Callable],
+ fixture_sequence: list[str] | None = None,
raise_exceptions: bool = False,
deepcopy_fixtures: bool = False,
):
@@ -101,7 +101,7 @@ def create_conditional_fixtures_and_names(
fixture_vars = check_sequence(
fixture_vars, sequence_type=list, element_type=str, sequence_name="fixture_vars"
)
- fixture_vars = [var for var in fixture_vars if var in generator_dict.keys()]
+ fixture_vars = [var for var in fixture_vars if var in generator_dict]
# order fixture_vars according to fixture_sequence if provided
if fixture_sequence is not None:
@@ -153,7 +153,7 @@ def get_fixtures(fixture_var, **kwargs):
except Exception as err:
error = FixtureGenerationError(fixture_name=fixture_var, err=err)
if raise_exceptions:
- raise error
+ raise error from err
fixture_prod = [error]
fixture_names = [f"Error:{fixture_var}"]
@@ -176,12 +176,12 @@ def get_fixtures(fixture_var, **kwargs):
if i == 0:
kwargs = {}
else:
- kwargs = dict(zip(old_fixture_vars, fixture))
+ kwargs = dict(zip(old_fixture_vars, fixture, strict=False))
# retrieve conditional fixtures, conditional on fixture values in kwargs
new_fixtures, new_fixture_names_r = get_fixtures(fixture_var, **kwargs)
# new fixture values are concatenation/product of old values plus new
new_fixture_prod += [
- fixture + (new_fixture,) for new_fixture in new_fixtures
+ (*fixture, new_fixture) for new_fixture in new_fixtures
]
# new fixture name is concatenation of name so far and "dash-new name"
# if the new name is empty string, don't add a dash
diff --git a/skbase/tests/conftest.py b/skbase/tests/conftest.py
index dec1b79b..6f253196 100644
--- a/skbase/tests/conftest.py
+++ b/skbase/tests/conftest.py
@@ -1,20 +1,20 @@
# -*- coding: utf-8 -*-
"""Common functionality for skbase unit tests."""
-from typing import List
+from typing import ClassVar
from skbase.base import BaseEstimator, BaseObject
-__all__: List[str] = [
+__all__: list[str] = [
"SKBASE_BASE_CLASSES",
+ "SKBASE_CLASSES_BY_MODULE",
+ "SKBASE_FUNCTIONS_BY_MODULE",
"SKBASE_MODULES",
- "SKBASE_PUBLIC_MODULES",
"SKBASE_PUBLIC_CLASSES_BY_MODULE",
- "SKBASE_CLASSES_BY_MODULE",
"SKBASE_PUBLIC_FUNCTIONS_BY_MODULE",
- "SKBASE_FUNCTIONS_BY_MODULE",
+ "SKBASE_PUBLIC_MODULES",
]
-__author__: List[str] = ["fkiraly", "RNKuhns"]
+__author__: list[str] = ["fkiraly", "RNKuhns"]
# bug 442 fixed: metaclasses now discovered correctly on all Python versions
IMPORT_CLS = ("CommonMagicMeta", "MagicAttribute")
@@ -332,7 +332,7 @@
class Parent(BaseObject):
"""Parent class to test BaseObject's usage."""
- _tags = {"A": "1", "B": 2, "C": 1234, "3": "D"}
+ _tags: ClassVar[dict] = {"A": "1", "B": 2, "C": 1234, "3": "D"}
def __init__(self, a="something", b=7, c=None):
"""Initialize the class."""
@@ -343,36 +343,31 @@ def __init__(self, a="something", b=7, c=None):
def some_method(self):
"""To be implemented by child class."""
- pass
# Fixture class for testing tag system, child overrides tags
class Child(Parent):
"""Child class that is child of FixtureClassParent."""
- _tags = {"A": 42, "3": "E"}
- __author__ = ["fkiraly", "RNKuhns"]
+ _tags: ClassVar[dict] = {"A": 42, "3": "E"}
+ __author__: ClassVar[list[str]] = ["fkiraly", "RNKuhns"]
def some_method(self):
"""Child class' implementation."""
- pass
def some_other_method(self):
"""To be implemented in the child class."""
- pass
# Fixture class for testing tag system, child overrides tags
class ClassWithABTrue(Parent):
"""Child class that sets A, B tags to True."""
- _tags = {"A": True, "B": True}
- __author__ = ["fkiraly", "RNKuhns"]
+ _tags: ClassVar[dict] = {"A": True, "B": True}
+ __author__: ClassVar[list[str]] = ["fkiraly", "RNKuhns"]
def some_method(self):
"""Child class' implementation."""
- pass
def some_other_method(self):
"""To be implemented in the child class."""
- pass
diff --git a/skbase/tests/mock_package/__init__.py b/skbase/tests/mock_package/__init__.py
index ec467ed8..cc6f4a25 100644
--- a/skbase/tests/mock_package/__init__.py
+++ b/skbase/tests/mock_package/__init__.py
@@ -1,6 +1,4 @@
# -*- coding: utf-8 -*-
"""Mock package for skbase testing."""
-from typing import List
-
-__author__: List[str] = ["fkiraly", "RNKuhns"]
+__author__: list[str] = ["fkiraly", "RNKuhns"]
diff --git a/skbase/tests/mock_package/test_mock_package.py b/skbase/tests/mock_package/test_mock_package.py
index 2ecbf6b9..ebe859ef 100644
--- a/skbase/tests/mock_package/test_mock_package.py
+++ b/skbase/tests/mock_package/test_mock_package.py
@@ -2,17 +2,16 @@
"""Mock package for testing skbase functionality."""
from copy import deepcopy
-from typing import List
from skbase.base import BaseObject
-__all__: List[str] = [
+__all__: list[str] = [
+ "AnotherClass",
"CompositionDummy",
"InheritsFromBaseObject",
- "AnotherClass",
"NotABaseObject",
]
-__author__: List[str] = ["fkiraly", "RNKuhns"]
+__author__: list[str] = ["fkiraly", "RNKuhns"]
class CompositionDummy(BaseObject):
@@ -23,7 +22,7 @@ def __init__(self, foo, bar=84):
self.foo_ = deepcopy(foo)
self.bar = bar
- super(CompositionDummy, self).__init__()
+ super().__init__()
@classmethod
def get_test_params(cls, parameter_set="default"):
diff --git a/skbase/tests/test_base.py b/skbase/tests/test_base.py
index 23b4582c..0880e5cc 100644
--- a/skbase/tests/test_base.py
+++ b/skbase/tests/test_base.py
@@ -21,55 +21,55 @@
__author__ = ["fkiraly", "RNKuhns"]
__all__ = [
- "test_get_class_tags",
- "test_get_class_tag",
- "test_get_tags",
- "test_get_tag",
- "test_get_tag_raises",
- "test_set_tags",
- "test_set_tags_works_with_missing_tags_dynamic_attribute",
+ "test_baseobject_repr",
+ "test_baseobject_repr_mimebundle_",
+ "test_baseobject_str",
+ "test_clone",
+ "test_clone_2",
+ "test_clone_class_rather_than_instance_raises_error",
+ "test_clone_estimator_types",
+ "test_clone_none_and_empty_array_nan_sparse_matrix",
+ "test_clone_raises_error_for_nonconforming_objects",
+ "test_clone_sklearn_composite",
"test_clone_tags",
- "test_is_composite",
"test_components",
- "test_components_raises_error_base_class_is_not_class",
"test_components_raises_error_base_class_is_not_baseobject_subclass",
- "test_param_alias",
- "test_nested_set_params_and_alias",
- "test_reset",
- "test_reset_composite",
+ "test_components_raises_error_base_class_is_not_class",
+ "test_create_test_instance",
+ "test_create_test_instances_and_names",
+ "test_eq_dunder",
+ "test_get_class_tag",
+ "test_get_class_tags",
"test_get_init_signature",
"test_get_init_signature_raises_error_for_invalid_signature",
"test_get_param_names",
"test_get_params",
- "test_get_params_invariance",
"test_get_params_after_set_params",
+ "test_get_params_invariance",
+ "test_get_tag",
+ "test_get_tag_raises",
+ "test_get_tags",
+ "test_get_test_params",
+ "test_get_test_params_raises_error_when_params_required",
+ "test_has_implementation_of",
+ "test_is_composite",
+ "test_nested_set_params_and_alias",
+ "test_param_alias",
+ "test_raises_on_get_params_for_param_arg_not_assigned_to_attribute",
+ "test_repr_html_wraps",
+ "test_reset",
+ "test_reset_composite",
"test_set_params",
"test_set_params_raises_error_non_existent_param",
"test_set_params_raises_error_non_interface_composite",
- "test_raises_on_get_params_for_param_arg_not_assigned_to_attribute",
"test_set_params_with_no_param_to_set_returns_object",
- "test_clone",
- "test_clone_2",
- "test_clone_raises_error_for_nonconforming_objects",
- "test_clone_none_and_empty_array_nan_sparse_matrix",
- "test_clone_estimator_types",
- "test_clone_class_rather_than_instance_raises_error",
- "test_clone_sklearn_composite",
- "test_baseobject_repr",
- "test_baseobject_str",
- "test_baseobject_repr_mimebundle_",
- "test_repr_html_wraps",
- "test_get_test_params",
- "test_get_test_params_raises_error_when_params_required",
- "test_create_test_instance",
- "test_create_test_instances_and_names",
- "test_has_implementation_of",
- "test_eq_dunder",
+ "test_set_tags",
+ "test_set_tags_works_with_missing_tags_dynamic_attribute",
]
import inspect
from copy import deepcopy
-from typing import Any, Dict, Type
+from typing import Any, ClassVar
import numpy as np
import pytest
@@ -112,7 +112,7 @@ def __init__(self, a, *args):
class RequiredParam(BaseObject):
"""BaseObject class with _required_parameters."""
- _required_parameters = ["a"]
+ _required_parameters: ClassVar[list[str]] = ["a"]
def __init__(self, a, b=7):
self.a = a
@@ -198,7 +198,7 @@ def fixture_reset_tester():
@pytest.fixture
-def fixture_class_child_tags(fixture_class_child: Type[Child]):
+def fixture_class_child_tags(fixture_class_child: type[Child]):
"""Pytest fixture for tags of Child."""
return fixture_class_child.get_class_tags()
@@ -265,7 +265,7 @@ def fixture_class_instance_no_param_interface():
def test_get_class_tags(
- fixture_class_child: Type[Child], fixture_class_child_tags: Any
+ fixture_class_child: type[Child], fixture_class_child_tags: Any
):
"""Test get_class_tags class method of BaseObject for correctness.
@@ -280,7 +280,7 @@ def test_get_class_tags(
assert child_tags == fixture_class_child_tags, msg
-def test_get_class_tag(fixture_class_child: Type[Child], fixture_class_child_tags: Any):
+def test_get_class_tag(fixture_class_child: type[Child], fixture_class_child_tags: Any):
"""Test get_class_tag class method of BaseObject for correctness.
Raises
@@ -307,7 +307,7 @@ def test_get_class_tag(fixture_class_child: Type[Child], fixture_class_child_tag
assert child_tag_default_none is None, msg
-def test_get_tags(fixture_tag_class_object: Child, fixture_object_tags: Dict[str, Any]):
+def test_get_tags(fixture_tag_class_object: Child, fixture_object_tags: dict[str, Any]):
"""Test get_tags method of BaseObject for correctness.
Raises
@@ -321,7 +321,7 @@ def test_get_tags(fixture_tag_class_object: Child, fixture_object_tags: Dict[str
assert object_tags == fixture_object_tags, msg
-def test_get_tag(fixture_tag_class_object: Child, fixture_object_tags: Dict[str, Any]):
+def test_get_tag(fixture_tag_class_object: Child, fixture_object_tags: dict[str, Any]):
"""Test get_tag method of BaseObject for correctness.
Raises
@@ -364,8 +364,8 @@ def test_get_tag_raises(fixture_tag_class_object: Child):
def test_set_tags(
fixture_object_instance_set_tags: Any,
- fixture_object_set_tags: Dict[str, Any],
- fixture_object_dynamic_tags: Dict[str, int],
+ fixture_object_set_tags: dict[str, Any],
+ fixture_object_dynamic_tags: dict[str, int],
):
"""Test set_tags method of BaseObject for correctness.
@@ -387,7 +387,7 @@ def test_set_tags_works_with_missing_tags_dynamic_attribute(
"""Test set_tags will still work if _tags_dynamic is missing."""
base_obj = deepcopy(fixture_tag_class_object)
attr_name = "_tags_dynamic"
- delattr(base_obj, attr_name) # noqa
+ delattr(base_obj, attr_name)
assert not hasattr(base_obj, "_tags_dynamic")
base_obj.set_tags(some_tag="something")
tags = base_obj.get_tags()
@@ -399,7 +399,7 @@ def test_clone_tags():
"""Test clone_tags works as expected."""
class TestClass(BaseObject):
- _tags = {"some_tag": True, "another_tag": 37}
+ _tags: ClassVar[dict] = {"some_tag": True, "another_tag": 37}
class AnotherTestClass(BaseObject):
pass
@@ -463,7 +463,7 @@ class AnotherTestClass(BaseObject):
assert test_obj_tags.get(tag) == another_base_obj_tags[tag]
-def test_is_composite(fixture_composition_dummy: Type[CompositionDummy]):
+def test_is_composite(fixture_composition_dummy: type[CompositionDummy]):
"""Test is_composite tag for correctness.
Raises
@@ -478,9 +478,9 @@ def test_is_composite(fixture_composition_dummy: Type[CompositionDummy]):
def test_components(
- fixture_object: Type[BaseObject],
- fixture_class_parent: Type[Parent],
- fixture_composition_dummy: Type[CompositionDummy],
+ fixture_object: type[BaseObject],
+ fixture_class_parent: type[Parent],
+ fixture_composition_dummy: type[CompositionDummy],
):
"""Test component retrieval.
@@ -514,7 +514,7 @@ def test_components(
def test_components_raises_error_base_class_is_not_class(
- fixture_object: Type[BaseObject], fixture_composition_dummy: Type[CompositionDummy]
+ fixture_object: type[BaseObject], fixture_composition_dummy: type[CompositionDummy]
):
"""Test _component method raises error if base_class param is not class."""
non_composite = fixture_composition_dummy(foo=42)
@@ -533,7 +533,7 @@ def test_components_raises_error_base_class_is_not_class(
def test_components_raises_error_base_class_is_not_baseobject_subclass(
- fixture_composition_dummy: Type[CompositionDummy],
+ fixture_composition_dummy: type[CompositionDummy],
):
"""Test _component method raises error if base_class is not BaseObject subclass."""
@@ -563,18 +563,18 @@ def test_param_alias():
composite = CompositionDummy(foo=non_composite)
# this should write to a of foo, because there is only one suffix called a
- composite.set_params(**{"a": 424242})
+ composite.set_params(a=424242)
assert composite.get_params()["foo__a"] == 424242
# this should write to bar of composite, because "bar" is a full parameter string
# there is a suffix in foo, but if the full string is there, it writes to that
- composite.set_params(**{"bar": 424243})
+ composite.set_params(bar=424243)
assert composite.get_params()["bar"] == 424243
# trying to write to bad_param should raise an exception
# since bad_param is neither a suffix nor a full parameter string
with pytest.raises(ValueError, match=r"Invalid parameter keys provided to"):
- composite.set_params(**{"bad_param": 424242})
+ composite.set_params(bad_param=424242)
# new example: highly nested composite with identical suffixes
non_composite1 = composite
@@ -584,10 +584,10 @@ def test_param_alias():
# trying to write to a should raise an exception
# since there are two suffix a, and a is not a full parameter string
with pytest.raises(ValueError, match=r"does not uniquely determine parameter key"):
- uber_composite.set_params(**{"a": 424242})
+ uber_composite.set_params(a=424242)
# same as above, should overwrite "bar" of uber_composite
- uber_composite.set_params(**{"bar": 424243})
+ uber_composite.set_params(bar=424243)
assert uber_composite.get_params()["bar"] == 424243
@@ -611,14 +611,14 @@ def test_nested_set_params_and_alias():
# this should write to a of foo
# potential error here is that composite does not have foo__a to start with
# so error catching or writing foo__a to early could cause an exception
- composite.set_params(**{"foo": non_composite, "foo__a": 424242})
+ composite.set_params(foo=non_composite, foo__a=424242)
assert composite.get_params()["foo__a"] == 424242
non_composite = AliasTester(a=42, bar=4242)
composite = CompositionDummy(foo=0)
# same, and recognizing that foo__a is the only matching suffix in the end state
- composite.set_params(**{"foo": non_composite, "a": 424242})
+ composite.set_params(foo=non_composite, a=424242)
assert composite.get_params()["foo__a"] == 424242
# new example: highly nested composite with identical suffixes
@@ -629,20 +629,18 @@ def test_nested_set_params_and_alias():
# trying to write to a should raise an exception
# since there are two suffix a, and a is not a full parameter string
with pytest.raises(ValueError, match=r"does not uniquely determine parameter key"):
- uber_composite.set_params(
- **{"a": 424242, "foo": non_composite1, "bar": non_composite2}
- )
+ uber_composite.set_params(a=424242, foo=non_composite1, bar=non_composite2)
uber_composite = CompositionDummy(foo=non_composite1, bar=42)
# same as above, should overwrite "bar" of uber_composite
- uber_composite.set_params(**{"bar": 424243})
+ uber_composite.set_params(bar=424243)
assert uber_composite.get_params()["bar"] == 424243
# Test parameter interface (get_params, set_params, reset and related methods)
# Some tests of get_params and set_params are adapted from sklearn tests
-def test_reset(fixture_reset_tester: Type[ResetTester]):
+def test_reset(fixture_reset_tester: type[ResetTester]):
"""Test reset method for correct behaviour, on a simple estimator.
Raises
@@ -669,7 +667,7 @@ def test_reset(fixture_reset_tester: Type[ResetTester]):
assert hasattr(x, "foo")
-def test_reset_composite(fixture_reset_tester: Type[ResetTester]):
+def test_reset_composite(fixture_reset_tester: type[ResetTester]):
"""Test reset method for correct behaviour, on a composite estimator."""
y = fixture_reset_tester(42)
x = fixture_reset_tester(a=y)
@@ -684,20 +682,20 @@ def test_reset_composite(fixture_reset_tester: Type[ResetTester]):
assert not hasattr(x.a, "d")
-def test_get_init_signature(fixture_class_parent: Type[Parent]):
+def test_get_init_signature(fixture_class_parent: type[Parent]):
"""Test error is raised when invalid init signature is used."""
init_sig = fixture_class_parent._get_init_signature()
init_sig_is_list = isinstance(init_sig, list)
init_sig_elements_are_params = all(
isinstance(p, inspect.Parameter) for p in init_sig
)
- assert (
- init_sig_is_list and init_sig_elements_are_params
- ), "`_get_init_signature` is not returning expected result."
+ assert init_sig_is_list and init_sig_elements_are_params, (
+ "`_get_init_signature` is not returning expected result."
+ )
def test_get_init_signature_raises_error_for_invalid_signature(
- fixture_invalid_init: Type[InvalidInitSignatureTester],
+ fixture_invalid_init: type[InvalidInitSignatureTester],
):
"""Test error is raised when invalid init signature is used."""
with pytest.raises(RuntimeError):
@@ -706,9 +704,9 @@ def test_get_init_signature_raises_error_for_invalid_signature(
@pytest.mark.parametrize("sort", [True, False])
def test_get_param_names(
- fixture_object: Type[BaseObject],
- fixture_class_parent: Type[Parent],
- fixture_class_parent_expected_params: Dict[str, Any],
+ fixture_object: type[BaseObject],
+ fixture_class_parent: type[Parent],
+ fixture_class_parent_expected_params: dict[str, Any],
sort: bool,
):
"""Test that get_param_names returns list of string parameter names."""
@@ -723,10 +721,10 @@ def test_get_param_names(
def test_get_params(
- fixture_class_parent: Type[Parent],
- fixture_class_parent_expected_params: Dict[str, Any],
+ fixture_class_parent: type[Parent],
+ fixture_class_parent_expected_params: dict[str, Any],
fixture_class_instance_no_param_interface: NoParamInterface,
- fixture_composition_dummy: Type[CompositionDummy],
+ fixture_composition_dummy: type[CompositionDummy],
):
"""Test get_params returns expected parameters."""
# Simple test of returned params
@@ -750,8 +748,8 @@ def test_get_params(
def test_get_params_invariance(
- fixture_class_parent: Type[Parent],
- fixture_composition_dummy: Type[CompositionDummy],
+ fixture_class_parent: type[Parent],
+ fixture_composition_dummy: type[CompositionDummy],
):
"""Test that get_params(deep=False) is subset of get_params(deep=True)."""
composite = fixture_composition_dummy(foo=fixture_class_parent(), bar=84)
@@ -760,7 +758,7 @@ def test_get_params_invariance(
assert all(item in deep_params.items() for item in shallow_params.items())
-def test_get_params_after_set_params(fixture_class_parent: Type[Parent]):
+def test_get_params_after_set_params(fixture_class_parent: type[Parent]):
"""Test that get_params returns the same thing before and after set_params.
Based on scikit-learn check in check_estimator.
@@ -780,7 +778,7 @@ def test_get_params_after_set_params(fixture_class_parent: Type[Parent]):
test_values = [-np.inf, np.inf, None]
test_params = deepcopy(orig_params)
- for param_name in orig_params.keys():
+ for param_name in orig_params:
default_value = orig_params[param_name]
for value in test_values:
test_params[param_name] = value
@@ -801,9 +799,9 @@ def test_get_params_after_set_params(fixture_class_parent: Type[Parent]):
def test_set_params(
- fixture_class_parent: Type[Parent],
- fixture_class_parent_expected_params: Dict[str, Any],
- fixture_composition_dummy: Type[CompositionDummy],
+ fixture_class_parent: type[Parent],
+ fixture_class_parent_expected_params: dict[str, Any],
+ fixture_composition_dummy: type[CompositionDummy],
):
"""Test set_params works as expected."""
# Simple case of setting a parameter
@@ -826,7 +824,7 @@ def test_set_params(
def test_set_params_raises_error_non_existent_param(
fixture_class_parent_instance: Parent,
- fixture_composition_dummy: Type[CompositionDummy],
+ fixture_composition_dummy: type[CompositionDummy],
):
"""Test set_params raises an error when passed a non-existent parameter name."""
# non-existing parameter in svc
@@ -843,7 +841,7 @@ def test_set_params_raises_error_non_existent_param(
def test_set_params_raises_error_non_interface_composite(
fixture_class_instance_no_param_interface: NoParamInterface,
- fixture_composition_dummy: Type[CompositionDummy],
+ fixture_composition_dummy: type[CompositionDummy],
):
"""Test set_params raises error when setting param of non-conforming composite."""
# When a composite is made up of a class that doesn't have the BaseObject
@@ -870,7 +868,7 @@ def __init__(self, param=5):
def test_set_params_with_no_param_to_set_returns_object(
- fixture_class_parent: Type[Parent],
+ fixture_class_parent: type[Parent],
):
"""Test set_params correctly returns self when no parameters are set."""
base_obj = fixture_class_parent()
@@ -908,19 +906,19 @@ def test_clone_2(fixture_class_parent_instance: Parent):
def test_clone_raises_error_for_nonconforming_objects(
- fixture_invalid_init: Type[InvalidInitSignatureTester],
- fixture_buggy: Type[Buggy],
- fixture_modify_param: Type[ModifyParam],
+ fixture_invalid_init: type[InvalidInitSignatureTester],
+ fixture_buggy: type[Buggy],
+ fixture_modify_param: type[ModifyParam],
):
"""Test that clone raises an error on nonconforming BaseObjects."""
buggy = fixture_buggy()
- buggy.set_config(**{"check_clone": True})
+ buggy.set_config(check_clone=True)
buggy.a = 2
with pytest.raises(RuntimeError):
buggy.clone()
varg_obj = fixture_invalid_init(a=7)
- varg_obj.set_config(**{"check_clone": True})
+ varg_obj.set_config(check_clone=True)
with pytest.raises(RuntimeError):
varg_obj.clone()
@@ -939,17 +937,17 @@ def test_config_after_clone_tags(clone_config):
"""Test clone also clones config works as expected."""
class TestClass(BaseObject):
- _tags = {"some_tag": True, "another_tag": 37}
- _config = {"check_clone": 0}
+ _tags: ClassVar[dict] = {"some_tag": True, "another_tag": 37}
+ _config: ClassVar[dict] = {"check_clone": 0}
test_obj = TestClass()
- test_obj.set_config(**{"check_clone": 42, "foo": "bar"})
+ test_obj.set_config(check_clone=42, foo="bar")
if not clone_config:
# if clone_config config is set to False:
# config key check_clone should be default, 0
# the new config key foo should not be present
- test_obj.set_config(**{"clone_config": False})
+ test_obj.set_config(clone_config=False)
expected = 0
else:
# if clone_config config is set to True:
@@ -959,14 +957,14 @@ class TestClass(BaseObject):
test_obj_clone = test_obj.clone()
- assert "check_clone" in test_obj_clone.get_config().keys()
+ assert "check_clone" in test_obj_clone.get_config()
assert test_obj_clone.get_config()["check_clone"] == expected
if clone_config:
- assert "foo" in test_obj_clone.get_config().keys()
+ assert "foo" in test_obj_clone.get_config()
assert test_obj_clone.get_config()["foo"] == "bar"
else:
- assert "foo" not in test_obj_clone.get_config().keys()
+ assert "foo" not in test_obj_clone.get_config()
@pytest.mark.parametrize("clone_config", [True, False])
@@ -974,7 +972,7 @@ def test_nested_config_after_clone_tags(clone_config):
"""Test clone also clones config of nested objects."""
class TestClass(BaseObject):
- _config = {"check_clone": 0}
+ _config: ClassVar[dict] = {"check_clone": 0}
class TestNestedClass(BaseObject):
def __init__(self, obj, obj_iterable):
@@ -982,16 +980,16 @@ def __init__(self, obj, obj_iterable):
self.obj_iterable = obj_iterable
test_obj = TestNestedClass(
- obj=TestClass().set_config(**{"check_clone": 1, "foo": "bar"}),
- obj_iterable=[TestClass().set_config(**{"check_clone": 2, "foo": "barz"})],
+ obj=TestClass().set_config(check_clone=1, foo="bar"),
+ obj_iterable=[TestClass().set_config(check_clone=2, foo="barz")],
)
if not clone_config:
# if clone_config config is set to False:
# config key check_clone should be default, 0
# the new config key foo should not be present
- test_obj.obj.set_config(**{"clone_config": False})
- test_obj.obj_iterable[0].set_config(**{"clone_config": False})
+ test_obj.obj.set_config(clone_config=False)
+ test_obj.obj_iterable[0].set_config(clone_config=False)
expected_obj = 0
expected_obj_iterable = 0
else:
@@ -1004,19 +1002,19 @@ def __init__(self, obj, obj_iterable):
test_obj_clone = test_obj.clone().obj
test_obj_iterable_clone = test_obj.clone().obj_iterable[0]
- assert "check_clone" in test_obj_clone.get_config().keys()
- assert "check_clone" in test_obj_iterable_clone.get_config().keys()
+ assert "check_clone" in test_obj_clone.get_config()
+ assert "check_clone" in test_obj_iterable_clone.get_config()
assert test_obj_clone.get_config()["check_clone"] == expected_obj
assert test_obj_iterable_clone.get_config()["check_clone"] == expected_obj_iterable
if clone_config:
- assert "foo" in test_obj_clone.get_config().keys()
- assert "foo" in test_obj_iterable_clone.get_config().keys()
+ assert "foo" in test_obj_clone.get_config()
+ assert "foo" in test_obj_iterable_clone.get_config()
assert test_obj_clone.get_config()["foo"] == "bar"
assert test_obj_iterable_clone.get_config()["foo"] == "barz"
else:
- assert "foo" not in test_obj_clone.get_config().keys()
- assert "foo" not in test_obj_iterable_clone.get_config().keys()
+ assert "foo" not in test_obj_clone.get_config()
+ assert "foo" not in test_obj_iterable_clone.get_config()
@pytest.mark.skipif(
@@ -1033,7 +1031,7 @@ def __init__(self, obj, obj_iterable):
],
)
def test_clone_none_and_empty_array_nan_sparse_matrix(
- fixture_class_parent: Type[Parent], c_value
+ fixture_class_parent: type[Parent], c_value
):
from sklearn.base import clone
@@ -1052,7 +1050,7 @@ def test_clone_none_and_empty_array_nan_sparse_matrix(
assert base_obj.c is new_base_obj2.c
-def test_clone_estimator_types(fixture_class_parent: Type[Parent]):
+def test_clone_estimator_types(fixture_class_parent: type[Parent]):
"""Test clone works for parameters that are types rather than instances."""
base_obj = fixture_class_parent(c=fixture_class_parent)
new_base_obj = base_obj.clone()
@@ -1065,7 +1063,7 @@ def test_clone_estimator_types(fixture_class_parent: Type[Parent]):
reason="skip test if sklearn is not available",
) # sklearn is part of the dev dependency set, test should be executed with that
def test_clone_class_rather_than_instance_raises_error(
- fixture_class_parent: Type[Parent],
+ fixture_class_parent: type[Parent],
):
"""Test clone raises expected error when cloning a class instead of an instance."""
from sklearn.base import clone
@@ -1109,8 +1107,8 @@ def test_clone_sklearn_composite_retains_config():
# Tests of BaseObject pretty printing representation inspired by sklearn
def test_baseobject_repr(
- fixture_class_parent: Type[Parent],
- fixture_composition_dummy: Type[CompositionDummy],
+ fixture_class_parent: type[Parent],
+ fixture_composition_dummy: type[CompositionDummy],
):
"""Test BaseObject repr works as expected."""
# Simple test where all parameters are left at defaults
@@ -1149,16 +1147,16 @@ def test_baseobject_repr(
long_base_obj_repr = fixture_class_parent(a=["long_params"] * 1000)
assert len(repr(long_base_obj_repr)) == 535
- named_objs = [(f"Step {i+1}", Child()) for i in range(25)]
+ named_objs = [(f"Step {i + 1}", Child()) for i in range(25)]
base_comp = CompositionDummy(foo=Parent(c=Child(c=named_objs)))
assert len(repr(base_comp)) == 1362
def test_baseobject_str(fixture_class_parent_instance: Parent):
"""Test BaseObject string representation works."""
- assert (
- str(fixture_class_parent_instance) == "Parent()"
- ), "String representation of instance not working."
+ assert str(fixture_class_parent_instance) == "Parent()", (
+ "String representation of instance not working."
+ )
# Check that local config works as expected
fixture_class_parent_instance.set_config(print_changed_only=False)
@@ -1200,7 +1198,7 @@ def test_get_test_params(fixture_class_parent_instance: Parent):
def test_get_test_params_raises_error_when_params_required(
- fixture_required_param: Type[RequiredParam],
+ fixture_required_param: type[RequiredParam],
):
"""Test get_test_params raises an error when parameters are required."""
with pytest.raises(ValueError):
@@ -1208,7 +1206,7 @@ def test_get_test_params_raises_error_when_params_required(
def test_create_test_instance(
- fixture_class_parent: Type[Parent], fixture_class_parent_instance: Parent
+ fixture_class_parent: type[Parent], fixture_class_parent_instance: Parent
):
"""Test first that create_test_instance logic works."""
base_obj = fixture_class_parent.create_test_instance()
@@ -1275,9 +1273,9 @@ def test_has_implementation_of(
class ConfigTester(BaseObject):
- _config = {"foo_config": 42, "bar": "a"}
+ _config: ClassVar[dict] = {"foo_config": 42, "bar": "a"}
- clsvar = 210
+ clsvar: ClassVar[int] = 210
def __init__(self, a, b=42):
self.a = a
@@ -1286,9 +1284,9 @@ def __init__(self, a, b=42):
class AnotherConfigTester(BaseObject):
- _config = {"print_changed_only": False, "bar": "a"}
+ _config: ClassVar[dict] = {"print_changed_only": False, "bar": "a"}
- clsvar = 210
+ clsvar: ClassVar[int] = 210
def __init__(self, a, b=42):
self.a = a
@@ -1382,9 +1380,9 @@ def test_get_set_config():
"""Tests get_config and set_config methods."""
class _TestConfig(BaseObject):
- _config = {"foo_config": 42, "bar": "a"}
+ _config: ClassVar[dict] = {"foo_config": 42, "bar": "a"}
- clsvar = 210
+ clsvar: ClassVar[int] = 210
def __init__(self, a, b=42):
self.a = a
diff --git a/skbase/tests/test_baseestimator.py b/skbase/tests/test_baseestimator.py
index 606b71bc..8851978b 100644
--- a/skbase/tests/test_baseestimator.py
+++ b/skbase/tests/test_baseestimator.py
@@ -49,18 +49,18 @@ def test_has_is_fitted(fixture_estimator_instance):
"""Test BaseEstimator has `is_fitted` property."""
has_private_is_fitted = hasattr(fixture_estimator_instance, "_is_fitted")
has_is_fitted = hasattr(fixture_estimator_instance, "is_fitted")
- assert (
- has_private_is_fitted and has_is_fitted
- ), "BaseEstimator does not have `is_fitted` property;"
+ assert has_private_is_fitted and has_is_fitted, (
+ "BaseEstimator does not have `is_fitted` property;"
+ )
def test_has_check_is_fitted(fixture_estimator_instance):
"""Test BaseEstimator has `check_is_fitted` method."""
has_check_is_fitted = hasattr(fixture_estimator_instance, "check_is_fitted")
is_method = inspect.ismethod(fixture_estimator_instance.check_is_fitted)
- assert (
- has_check_is_fitted and is_method
- ), "`BaseEstimator` does not have `check_is_fitted` method."
+ assert has_check_is_fitted and is_method, (
+ "`BaseEstimator` does not have `check_is_fitted` method."
+ )
def test_is_fitted(fixture_estimator_instance):
@@ -68,9 +68,9 @@ def test_is_fitted(fixture_estimator_instance):
expected_value_unfitted = (
fixture_estimator_instance.is_fitted == fixture_estimator_instance._is_fitted
)
- assert (
- expected_value_unfitted
- ), "`BaseEstimator` property `is_fitted` does not return `_is_fitted` value."
+ assert expected_value_unfitted, (
+ "`BaseEstimator` property `is_fitted` does not return `_is_fitted` value."
+ )
def test_check_is_fitted_raises_error_when_unfitted(fixture_estimator_instance):
diff --git a/skbase/tests/test_exceptions.py b/skbase/tests/test_exceptions.py
index fb83a509..db706a1c 100644
--- a/skbase/tests/test_exceptions.py
+++ b/skbase/tests/test_exceptions.py
@@ -6,13 +6,11 @@
test_exceptions_raise_error - Test that skbase exceptions raise expected error.
"""
-from typing import List
-
import pytest
from skbase._exceptions import FixtureGenerationError, NotFittedError
-__author__: List[str] = ["RNKuhns"]
+__author__: list[str] = ["RNKuhns"]
ALL_EXCEPTIONS = [FixtureGenerationError, NotFittedError]
diff --git a/skbase/tests/test_lookup.py b/skbase/tests/test_lookup.py
index bb985b10..c3de3a35 100644
--- a/skbase/tests/test_lookup.py
+++ b/skbase/tests/test_lookup.py
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
"""Tests for skbase.lookup utilities."""
import importlib
@@ -14,7 +15,7 @@ def test_all_objects_returns_class_name_for_alias(tmp_path, monkeypatch):
# create a tmp module to test all_objects behaviour
(root / "__init__.py").write_text(
- "from .module import AliasName\n" "__all__ = ['AliasName']\n"
+ "from .module import AliasName\n__all__ = ['AliasName']\n"
)
(root / "module.py").write_text(
"from skbase.base import BaseObject\n\n"
diff --git a/skbase/tests/test_lookup_metaclasses.py b/skbase/tests/test_lookup_metaclasses.py
index 8608de48..9d4adb39 100644
--- a/skbase/tests/test_lookup_metaclasses.py
+++ b/skbase/tests/test_lookup_metaclasses.py
@@ -3,13 +3,12 @@
import importlib
import inspect
-from typing import List
from skbase.lookup import get_package_metadata
from skbase.lookup._lookup import _get_members_uw
from skbase.utils.dependencies._import import CommonMagicMeta, MagicAttribute
-__author__: List[str] = ["SimonBlanke"]
+__author__: list[str] = ["SimonBlanke"]
def test_get_members_uw_discovers_metaclass_classes():
diff --git a/skbase/tests/test_meta.py b/skbase/tests/test_meta.py
index e5d1a631..19934807 100644
--- a/skbase/tests/test_meta.py
+++ b/skbase/tests/test_meta.py
@@ -100,18 +100,18 @@ def test_basemetaestimator_has_is_fitted(fixture_metaestimator_instance):
"""Test BaseEstimator has `is_fitted` property."""
has_private_is_fitted = hasattr(fixture_metaestimator_instance, "_is_fitted")
has_is_fitted = hasattr(fixture_metaestimator_instance, "is_fitted")
- assert (
- has_private_is_fitted and has_is_fitted
- ), "`BaseMetaEstimator` does not have `is_fitted` property or `_is_fitted` attr."
+ assert has_private_is_fitted and has_is_fitted, (
+ "`BaseMetaEstimator` does not have `is_fitted` property or `_is_fitted` attr."
+ )
def test_basemetaestimator_has_check_is_fitted(fixture_metaestimator_instance):
"""Test BaseEstimator has `check_is_fitted` method."""
has_check_is_fitted = hasattr(fixture_metaestimator_instance, "check_is_fitted")
is_method = inspect.ismethod(fixture_metaestimator_instance.check_is_fitted)
- assert (
- has_check_is_fitted and is_method
- ), "`BaseMetaEstimator` does not have `check_is_fitted` method."
+ assert has_check_is_fitted and is_method, (
+ "`BaseMetaEstimator` does not have `check_is_fitted` method."
+ )
@pytest.mark.parametrize("is_fitted_value", (True, False))
@@ -122,9 +122,9 @@ def test_basemetaestimator_is_fitted(fixture_metaestimator_instance, is_fitted_v
fixture_metaestimator_instance.is_fitted
== fixture_metaestimator_instance._is_fitted
)
- assert (
- expected_value_unfitted
- ), "`BaseMetaEstimator` property `is_fitted` does not return `_is_fitted` value."
+ assert expected_value_unfitted, (
+ "`BaseMetaEstimator` property `is_fitted` does not return `_is_fitted` value."
+ )
def test_basemetaestimator_check_is_fitted_raises_error_when_unfitted(
@@ -189,12 +189,12 @@ def test_set_params_resets_fitted_state():
meta_obj.set_params(steps=new_steps)
# Fitted state should be gone after set_params
- assert not hasattr(
- meta_obj, "fitted_attr_"
- ), "fitted_attr_ should be removed by reset() during set_params(steps=...)"
- assert not hasattr(
- meta_obj, "another_fitted_"
- ), "another_fitted_ should be removed by reset() during set_params(steps=...)"
+ assert not hasattr(meta_obj, "fitted_attr_"), (
+ "fitted_attr_ should be removed by reset() during set_params(steps=...)"
+ )
+ assert not hasattr(meta_obj, "another_fitted_"), (
+ "another_fitted_ should be removed by reset() during set_params(steps=...)"
+ )
# Test 2: Replacing individual step should also trigger reset
meta_obj = MetaObjectTester(steps=steps)
@@ -202,6 +202,6 @@ def test_set_params_resets_fitted_state():
meta_obj.set_params(foo=ComponentDummy(77))
- assert not hasattr(
- meta_obj, "fitted_attr_"
- ), "fitted_attr_ should be removed by reset() during set_params(foo=...)"
+ assert not hasattr(meta_obj, "fitted_attr_"), (
+ "fitted_attr_ should be removed by reset() during set_params(foo=...)"
+ )
diff --git a/skbase/tests/test_tagaliaser.py b/skbase/tests/test_tagaliaser.py
index ce3378a2..53145c03 100644
--- a/skbase/tests/test_tagaliaser.py
+++ b/skbase/tests/test_tagaliaser.py
@@ -2,6 +2,7 @@
"""Tests the aliasing logic in the Tag Aliaser."""
import re
+from typing import ClassVar
import pytest
@@ -12,19 +13,19 @@
class AliaserTestClass(_TagAliaserMixin, _BaseObject):
"""Class for testing tag aliasing logic."""
- _tags = {
+ _tags: ClassVar[dict] = {
"new_tag_1": "new_tag_1_value",
"old_tag_1": "old_tag_1_value",
"new_tag_2": "new_tag_2_value",
"old_tag_3": "old_tag_3_value",
}
- alias_dict = {
+ alias_dict: ClassVar[dict] = {
"old_tag_1": "new_tag_1",
"old_tag_2": "new_tag_2",
"old_tag_3": "new_tag_3",
}
- deprecate_dict = {
+ deprecate_dict: ClassVar[dict] = {
"old_tag_1": "42.0.0",
"old_tag_2": "84.0.0",
"old_tag_3": "126.0.0",
diff --git a/skbase/utils/__init__.py b/skbase/utils/__init__.py
index 529f53a5..5788deec 100644
--- a/skbase/utils/__init__.py
+++ b/skbase/utils/__init__.py
@@ -3,8 +3,6 @@
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
"""Utility functionality used through `skbase`."""
-from typing import List
-
from skbase.utils._iter import make_strings_unique
from skbase.utils._nested_iter import flatten, is_flat, unflat_len, unflatten
from skbase.utils._utils import subset_dict_keys
@@ -16,8 +14,8 @@
set_random_state,
)
-__author__: List[str] = ["RNKuhns", "fkiraly"]
-__all__: List[str] = [
+__author__: list[str] = ["RNKuhns", "fkiraly"]
+__all__: list[str] = [
"check_random_state",
"deep_equals",
"flatten",
diff --git a/skbase/utils/_iter.py b/skbase/utils/_iter.py
index 5d31f459..23704080 100644
--- a/skbase/utils/_iter.py
+++ b/skbase/utils/_iter.py
@@ -11,9 +11,9 @@
__author__ = ["fkiraly", "RNKuhns"]
__all__ = [
- "_scalar_to_seq",
- "_remove_type_text",
"_format_seq_to_str",
+ "_remove_type_text",
+ "_scalar_to_seq",
"make_strings_unique",
]
@@ -61,15 +61,12 @@ def _scalar_to_seq(scalar, sequence_type=None):
# We'll treat str like regular scalar and not a sequence
if isinstance(scalar, Sequence) and not isinstance(scalar, str):
return scalar
- elif sequence_type is None:
+ if sequence_type is None:
return (scalar,)
- elif issubclass(sequence_type, Sequence) and sequence_type != Sequence:
+ if issubclass(sequence_type, Sequence) and sequence_type != Sequence:
# Note calling (scalar,) is done to avoid str unpacking
return sequence_type((scalar,)) # type: ignore
- else:
- raise ValueError(
- "`sequence_type` must be a subclass of collections.abc.Sequence."
- )
+ raise ValueError("`sequence_type` must be a subclass of collections.abc.Sequence.")
def _remove_type_text(input_):
@@ -80,8 +77,7 @@ def _remove_type_text(input_):
m = re.match("^$", input_)
if m:
return m[1]
- else:
- return input_
+ return input_
def _format_seq_to_str(seq, sep=", ", last_sep=None, remove_type_text=True):
@@ -134,9 +130,9 @@ def _format_seq_to_str(seq, sep=", ", last_sep=None, remove_type_text=True):
if isinstance(seq, str):
return seq
# Allow casting of scalars to strings
- elif isinstance(seq, (int, float, bool, type)):
+ if isinstance(seq, (int, float, bool, type)):
return _remove_type_text(seq)
- elif not isinstance(seq, Sequence):
+ if not isinstance(seq, Sequence):
raise TypeError(
"`seq` must be a sequence or scalar str, int, float, bool or class."
)
diff --git a/skbase/utils/_nested_iter.py b/skbase/utils/_nested_iter.py
index 055e539f..7fb94c6d 100644
--- a/skbase/utils/_nested_iter.py
+++ b/skbase/utils/_nested_iter.py
@@ -4,13 +4,12 @@
"""Functionality for working with nested sequences."""
import collections
-from typing import List
-__author__: List[str] = ["RNKuhns", "fkiraly"]
-__all__: List[str] = [
+__author__: list[str] = ["RNKuhns", "fkiraly"]
+__all__: list[str] = [
+ "_remove_single",
"flatten",
"is_flat",
- "_remove_single",
"unflat_len",
"unflatten",
]
@@ -42,8 +41,7 @@ def _remove_single(x):
"""
if len(x) == 1:
return x[0]
- else:
- return x
+ return x
def flatten(obj):
@@ -73,8 +71,7 @@ def flatten(obj):
obj, (collections.abc.Iterable, collections.abc.Sequence)
) or isinstance(obj, str):
return [obj]
- else:
- return type(obj)([y for x in obj for y in flatten(x)])
+ return type(obj)([y for x in obj for y in flatten(x)])
def unflatten(obj, template):
@@ -109,7 +106,7 @@ def unflatten(obj, template):
ls = [unflat_len(x) for x in template]
for i in range(1, len(ls)):
ls[i] += ls[i - 1]
- ls = [0] + ls
+ ls = [0, *ls]
res = [unflatten(obj[ls[i] : ls[i + 1]], template[i]) for i in range(len(ls) - 1)]
@@ -146,8 +143,7 @@ def unflat_len(obj):
obj, (collections.abc.Iterable, collections.abc.Sequence)
) or isinstance(obj, str):
return 1
- else:
- return sum([unflat_len(x) for x in obj])
+ return sum([unflat_len(x) for x in obj])
def is_flat(obj):
diff --git a/skbase/utils/_utils.py b/skbase/utils/_utils.py
index ba4eca26..9b402bd9 100644
--- a/skbase/utils/_utils.py
+++ b/skbase/utils/_utils.py
@@ -3,16 +3,17 @@
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
"""Functionality for working with sequences."""
-from typing import Any, Iterable, List, MutableMapping, Optional, Union
+from collections.abc import Iterable, MutableMapping
+from typing import Any
-__author__: List[str] = ["RNKuhns"]
-__all__: List[str] = ["subset_dict_keys"]
+__author__: list[str] = ["RNKuhns"]
+__all__: list[str] = ["subset_dict_keys"]
def subset_dict_keys(
input_dict: MutableMapping[Any, Any],
- keys: Union[Iterable, int, float, bool, str, type],
- prefix: Optional[str] = None,
+ keys: Iterable | int | float | bool | str | type,
+ prefix: str | None = None,
remove_prefix: bool = True,
):
"""Subset dictionary so it only contains specified keys.
@@ -76,17 +77,13 @@ def rem_prefix(x):
return x[len(prefix__) :]
# The way this is used below, this else shouldn't really execute
# But its here for completeness in case something goes wrong
- else:
- return x # pragma: no cover
+ return x # pragma: no cover
# Handle passage of certain scalar values
if isinstance(keys, (str, float, int, bool, type)):
keys = [keys]
- if prefix is not None:
- keys = [f"{prefix}__{key}" for key in keys]
- else:
- keys = list(keys)
+ keys = [f"{prefix}__{key}" for key in keys] if prefix is not None else list(keys)
subsetted_dict = {rem_prefix(k): v for k, v in input_dict.items() if k in keys}
return subsetted_dict
diff --git a/skbase/utils/deep_equals/_common.py b/skbase/utils/deep_equals/_common.py
index bc486f23..ba6d1551 100644
--- a/skbase/utils/deep_equals/_common.py
+++ b/skbase/utils/deep_equals/_common.py
@@ -1,8 +1,10 @@
# -*- coding: utf-8 -*-
"""Common utility functions for skbase.utils.deep_equals."""
+from typing import Any
-def _ret(is_equal, msg="", string_arguments: list = None, return_msg=False):
+
+def _ret(is_equal, msg="", string_arguments: list[Any] | None = None, return_msg=False):
"""Return is_equal and msg, formatted with string_arguments if return_msg=True.
Parameters
@@ -28,8 +30,7 @@ def _ret(is_equal, msg="", string_arguments: list = None, return_msg=False):
elif isinstance(string_arguments, (list, tuple)) and len(string_arguments) > 0:
msg = msg.format(*string_arguments)
return is_equal, msg
- else:
- return is_equal
+ return is_equal
def _make_ret(return_msg):
diff --git a/skbase/utils/deep_equals/_deep_equals.py b/skbase/utils/deep_equals/_deep_equals.py
index cf73a5a7..8b98c922 100644
--- a/skbase/utils/deep_equals/_deep_equals.py
+++ b/skbase/utils/deep_equals/_deep_equals.py
@@ -112,8 +112,7 @@ def _is_npnan(x):
return isinstance(x, float) and np.isnan(x)
- else:
- return False
+ return False
def _coerce_list(x):
@@ -131,8 +130,7 @@ def _numpy_equals_plugin(x, y, return_msg=False, deep_equals=None):
if not numpy_available or not _is_npndarray(x):
return None
- else:
- import numpy as np
+ import numpy as np
ret = _make_ret(return_msg)
@@ -144,15 +142,14 @@ def _numpy_equals_plugin(x, y, return_msg=False, deep_equals=None):
return ret(False, f".dtype, x.dtype = {x.dtype} != y.dtype = {y.dtype}")
if x.dtype == "str":
return ret(np.array_equal(x, y), ".values")
- elif x.dtype == "object":
+ if x.dtype == "object":
x_flat = x.flatten()
y_flat = y.flatten()
for i in range(len(x_flat)):
is_equal, msg = deep_equals(x_flat[i], y_flat[i], return_msg=True)
return ret(is_equal, f"[{i}]" + msg)
return ret(True, "") # catches len(x_flat) == 0
- else:
- return ret(np.array_equal(x, y, equal_nan=True), ".values")
+ return ret(np.array_equal(x, y, equal_nan=True), ".values")
def _pandas_equals_plugin(x, y, return_msg=False, deep_equals=None):
@@ -189,9 +186,8 @@ def _pandas_equals(x, y, return_msg=False, deep_equals=None):
else:
msg = ""
return ret(index_equal and values_equal, msg)
- else:
- return ret(x.equals(y), ".series_equals, x = {} != y = {}", [x, y])
- elif isinstance(x, pd.DataFrame):
+ return ret(x.equals(y), ".series_equals, x = {} != y = {}", [x, y])
+ if isinstance(x, pd.DataFrame):
# check column names for equality
if not x.columns.equals(y.columns):
return ret(
@@ -209,48 +205,56 @@ def _pandas_equals(x, y, return_msg=False, deep_equals=None):
# and would upset the type check, e.g., RangeIndex(2) vs Index([0, 1])
xix = x.index
yix = y.index
- if hasattr(xix, "dtype") and hasattr(xix, "dtype"):
- if not xix.dtype == yix.dtype:
+ if hasattr(xix, "dtype") and hasattr(xix, "dtype") and xix.dtype != yix.dtype:
+ return ret(
+ False,
+ ".index.dtype, x.index.dtype = {} != y.index.dtype = {}",
+ [xix.dtype, yix.dtype],
+ )
+ if (
+ hasattr(xix, "dtypes")
+ and hasattr(yix, "dtypes")
+ and not x.dtypes.equals(y.dtypes)
+ ):
+ return ret(
+ False,
+ ".index.dtypes, x.dtypes = {} != y.index.dtypes = {}",
+ [xix.dtypes, yix.dtypes],
+ )
+ ix_eq = xix.equals(yix)
+ if not ix_eq:
+ if len(xix) != len(yix):
return ret(
False,
- ".index.dtype, x.index.dtype = {} != y.index.dtype = {}",
- [xix.dtype, yix.dtype],
+ ".index.len, x.index.len = {} != y.index.len = {}",
+ [len(xix), len(yix)],
)
- if hasattr(xix, "dtypes") and hasattr(yix, "dtypes"):
- if not x.dtypes.equals(y.dtypes):
+ if hasattr(xix, "name") and hasattr(yix, "name") and xix.name != yix.name:
return ret(
False,
- ".index.dtypes, x.dtypes = {} != y.index.dtypes = {}",
- [xix.dtypes, yix.dtypes],
+ ".index.name, x.index.name = {} != y.index.name = {}",
+ [xix.name, yix.name],
)
- ix_eq = xix.equals(yix)
- if not ix_eq:
- if not len(xix) == len(yix):
+ if (
+ hasattr(xix, "names")
+ and hasattr(yix, "names")
+ and len(xix.names) != len(yix.names)
+ ):
return ret(
False,
- ".index.len, x.index.len = {} != y.index.len = {}",
- [len(xix), len(yix)],
+ ".index.names, x.index.names = {} != y.index.name = {}",
+ [xix.names, yix.names],
+ )
+ if (
+ hasattr(xix, "names")
+ and hasattr(yix, "names")
+ and not np.all(xix.names == yix.names)
+ ):
+ return ret(
+ False,
+ ".index.names, x.index.names = {} != y.index.name = {}",
+ [xix.names, yix.names],
)
- if hasattr(xix, "name") and hasattr(yix, "name"):
- if not xix.name == yix.name:
- return ret(
- False,
- ".index.name, x.index.name = {} != y.index.name = {}",
- [xix.name, yix.name],
- )
- if hasattr(xix, "names") and hasattr(yix, "names"):
- if not len(xix.names) == len(yix.names):
- return ret(
- False,
- ".index.names, x.index.names = {} != y.index.name = {}",
- [xix.names, yix.names],
- )
- if not np.all(xix.names == yix.names):
- return ret(
- False,
- ".index.names, x.index.names = {} != y.index.name = {}",
- [xix.names, yix.names],
- )
elts_eq = np.all(xix == yix)
return ret(elts_eq, ".index.equals, x = {} != y = {}", [xix, yix])
# if columns, dtypes are equal and at least one is object, recurse over Series
@@ -260,24 +264,26 @@ def _pandas_equals(x, y, return_msg=False, deep_equals=None):
if not is_equal:
return ret(False, f"[{c!r}]" + msg)
return ret(True, "")
- else:
- return ret(x.equals(y), ".df_equals, x = {} != y = {}", [x, y])
- elif isinstance(x, pd.Index):
- if hasattr(x, "dtype") and hasattr(y, "dtype"):
- if not x.dtype == y.dtype:
- return ret(False, f".dtype, x.dtype = {x.dtype} != y.dtype = {y.dtype}")
- if hasattr(x, "dtypes") and hasattr(y, "dtypes"):
- if not x.dtypes.equals(y.dtypes):
- return ret(
- False, f".dtypes, x.dtypes = {x.dtypes} != y.dtypes = {y.dtypes}"
- )
+ return ret(x.equals(y), ".df_equals, x = {} != y = {}", [x, y])
+ if isinstance(x, pd.Index):
+ if hasattr(x, "dtype") and hasattr(y, "dtype") and x.dtype != y.dtype:
+ return ret(
+ False,
+ f".dtype, x.dtype = {x.dtype} != y.dtype = {y.dtype}",
+ )
+ if (
+ hasattr(x, "dtypes")
+ and hasattr(y, "dtypes")
+ and not x.dtypes.equals(y.dtypes)
+ ):
+ msg = f".dtypes, x.dtypes = {x.dtypes} != y.dtypes = {y.dtypes}"
+ return ret(False, msg)
return ret(x.equals(y), "index.equals, x = {} != y = {}", [x, y])
- else:
- raise RuntimeError(
- f"Unexpected type of pandas object in _pandas_equals: type(x)={type(x)},"
- f" type(y)={type(y)}, both should be one of "
- "pd.Series, pd.DataFrame, pd.Index"
- )
+ raise RuntimeError(
+ f"Unexpected type of pandas object in _pandas_equals: type(x)={type(x)},"
+ f" type(y)={type(y)}, both should be one of "
+ "pd.Series, pd.DataFrame, pd.Index"
+ )
def _tuple_equals(x, y, return_msg=False, deep_equals=None):
@@ -522,12 +528,12 @@ def deep_equals_curried(x, y, return_msg=False):
if isinstance(x, (list, tuple)):
dec = deep_equals_curried
return ret(*_tuple_equals(x, y, return_msg=True, deep_equals=dec))
- elif isinstance(x, dict):
+ if isinstance(x, dict):
dec = deep_equals_curried
return ret(*_dict_equals(x, y, return_msg=True, deep_equals=dec))
- elif _is_npnan(x):
+ if _is_npnan(x):
return ret(_is_npnan(y), f"type(x)={type(x)} != type(y)={type(y)}")
- elif isclass(x):
+ if isclass(x):
return ret(x == y, f".class, x={x.__name__} != y={y.__name__}")
if plugins is not None:
@@ -589,8 +595,7 @@ def _safe_any_unequal(x, y):
any_un = any(unequal)
if isinstance(any_un, bool):
return any_un
- else:
- return False
+ return False
except Exception:
return False
@@ -600,8 +605,7 @@ def _safe_any_unequal(x, y):
any_un = np.any(x != y) or np.any(_coerce_list(x != y))
if isinstance(any_un, bool) or any_un.dtype == "bool":
return any_un
- else:
- return False
+ return False
except Exception:
return False
diff --git a/skbase/utils/dependencies/__init__.py b/skbase/utils/dependencies/__init__.py
index 11bafbdc..0cfc3eed 100644
--- a/skbase/utils/dependencies/__init__.py
+++ b/skbase/utils/dependencies/__init__.py
@@ -11,8 +11,8 @@
from skbase.utils.dependencies._import import _safe_import
__all__ = [
+ "_check_estimator_deps",
"_check_python_version",
"_check_soft_dependencies",
- "_check_estimator_deps",
"_safe_import",
]
diff --git a/skbase/utils/dependencies/_dependencies.py b/skbase/utils/dependencies/_dependencies.py
index ee29200a..6e85e8e0 100644
--- a/skbase/utils/dependencies/_dependencies.py
+++ b/skbase/utils/dependencies/_dependencies.py
@@ -221,8 +221,7 @@ def _is_version_req_satisfied(pkg_env_version, pkg_version_req):
return False
if pkg_version_req != SpecifierSet(""):
return pkg_env_version in pkg_version_req
- else:
- return True
+ return True
pkg_version_reqs = []
pkg_env_versions = []
@@ -275,7 +274,7 @@ def _is_version_req_satisfied(pkg_env_version, pkg_version_req):
# now we check compatibility with the version specifier if non-empty
if not any(req_sat):
- zp = zip(package_req, pkg_names, pkg_env_versions, req_sat)
+ zp = zip(package_req, pkg_names, pkg_env_versions, req_sat, strict=False)
reqs_not_satisfied = [x for x in zp if x[3] is False]
actual_vers = [f"{x[1]} {x[2]}" for x in reqs_not_satisfied]
pkg_env_version_str = ", ".join(actual_vers)
@@ -465,10 +464,7 @@ def _check_python_version(
return True
# now we know that est_version is not compatible with sys_version
- if isclass(obj):
- class_name = obj.__name__
- else:
- class_name = type(obj).__name__
+ class_name = obj.__name__ if isclass(obj) else type(obj).__name__
if not isinstance(msg, str):
msg = (
@@ -542,10 +538,7 @@ def _check_env_marker(obj, package=None, msg=None, severity="error"):
return True
# now we know that est_marker is not compatible with the environment
- if isclass(obj):
- class_name = obj.__name__
- else:
- class_name = type(obj).__name__
+ class_name = obj.__name__ if isclass(obj) else type(obj).__name__
if not isinstance(msg, str):
msg = (
@@ -671,10 +664,7 @@ def _normalize_version(version):
"""
if version is None:
return None
- if not isinstance(version, Version):
- version_obj = Version(version)
- else:
- version_obj = version
+ version_obj = Version(version) if not isinstance(version, Version) else version
normalized_version = f"{version_obj.major}.{version_obj.minor}.{version_obj.micro}"
return normalized_version
@@ -727,13 +717,13 @@ def _raise_at_severity(
if severity == "error":
raise exception_type(msg)
- elif severity == "warning":
+ if severity == "warning":
warnings.warn(msg, category=warning_type, stacklevel=stacklevel)
elif severity == "none":
- return None
+ return
else:
raise ValueError(
f"Error in calling {caller}, severity "
f'argument must be "error", "warning", or "none", found {severity!r}.'
)
- return None
+ return
diff --git a/skbase/utils/dependencies/_import.py b/skbase/utils/dependencies/_import.py
index 9ba41f2b..7162a449 100644
--- a/skbase/utils/dependencies/_import.py
+++ b/skbase/utils/dependencies/_import.py
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
"""Import a module/class, return a Mock object or None if import fails."""
import importlib
@@ -118,19 +119,18 @@ def _safe_import(import_path, pkg_name=None, condition=True, return_object="Magi
"other than ImportError or AttributeError:"
f": {e}. ",
ImportWarning,
+ stacklevel=2,
)
- pass
if return_object == "MagicMock":
mock_obj = _create_mock_class(obj_name)
return mock_obj
- elif return_object == "None":
+ if return_object == "None":
return None
- else:
- raise RuntimeError(
- "Error in skbase _safe_import, return_object argument must be "
- f"'MagicMock' or 'None', but found {return_object}"
- )
+ raise RuntimeError(
+ "Error in skbase _safe_import, return_object argument must be "
+ f"'MagicMock' or 'None', but found {return_object}"
+ )
class CommonMagicMeta(type):
diff --git a/skbase/utils/dependencies/tests/test_check_dependencies.py b/skbase/utils/dependencies/tests/test_check_dependencies.py
index fd0d78cf..f384bdd8 100644
--- a/skbase/utils/dependencies/tests/test_check_dependencies.py
+++ b/skbase/utils/dependencies/tests/test_check_dependencies.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
"""Tests for _check_soft_dependencies utility."""
+from typing import ClassVar
from unittest.mock import patch
import pytest
@@ -57,12 +58,12 @@ def test_check_soft_deps():
def test_check_soft_dependencies_nested():
"""Test check_soft_dependencies with ."""
- ALWAYS_INSTALLED = "pytest" # noqa: N806
- ALWAYS_INSTALLED2 = "numpy" # noqa: N806
- ALWAYS_INSTALLED_W_V = "pytest>=0.5.0" # noqa: N806
- ALWAYS_INSTALLED_W_V2 = "numpy>=0.1.0" # noqa: N806
- NEVER_INSTALLED = "nonexistent__package_foo_bar" # noqa: N806
- NEVER_INSTALLED_W_V = "pytest<0.1.0" # noqa: N806
+ ALWAYS_INSTALLED = "pytest"
+ ALWAYS_INSTALLED2 = "numpy"
+ ALWAYS_INSTALLED_W_V = "pytest>=0.5.0"
+ ALWAYS_INSTALLED_W_V2 = "numpy>=0.1.0"
+ NEVER_INSTALLED = "nonexistent__package_foo_bar"
+ NEVER_INSTALLED_W_V = "pytest<0.1.0"
# Test that the function does not raise an error when all dependencies are installed
_check_soft_dependencies(ALWAYS_INSTALLED)
@@ -155,7 +156,7 @@ def test_check_python_version(
mock_sys.version = "3.8.1"
class DummyObjectClass(BaseObject):
- _tags = {
+ _tags: ClassVar[dict] = {
"python_version": ">=3.7.1", # PEP 440 version specifier, e.g., ">=3.7"
"python_dependencies": None, # PEP 440 dependency strs, e.g., "pandas>=1.0"
"env_marker": None, # PEP 508 environment marker, e.g., "os_name=='posix'"
diff --git a/skbase/utils/dependencies/tests/test_safe_import.py b/skbase/utils/dependencies/tests/test_safe_import.py
index 152fc2d6..7004f4bf 100644
--- a/skbase/utils/dependencies/tests/test_safe_import.py
+++ b/skbase/utils/dependencies/tests/test_safe_import.py
@@ -64,8 +64,8 @@ def test_import_existing_object():
def test_multiple_inheritance_from_mock():
"""Test multiple inheritance from dynamic MagicMock."""
- Class1 = _safe_import("foobar.foo.FooBar") # noqa: N806
- Class2 = _safe_import("barfoobar.BarFooBar") # noqa: N806
+ Class1 = _safe_import("foobar.foo.FooBar")
+ Class2 = _safe_import("barfoobar.BarFooBar")
class NewClass(Class1, Class2):
"""This should not trigger an error.
@@ -75,8 +75,6 @@ class NewClass(Class1, Class2):
identical to MagicMock.
"""
- pass
-
def test_soft_dependency_chains():
"""Test soft dependency chains.
diff --git a/skbase/utils/git_diff.py b/skbase/utils/git_diff.py
index f7ab1538..8b28c7bc 100644
--- a/skbase/utils/git_diff.py
+++ b/skbase/utils/git_diff.py
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
"""Git related utilities to identify changed modules."""
__author__ = ["fkiraly"]
@@ -6,7 +7,6 @@
import importlib.util
import inspect
from functools import lru_cache
-from typing import List
@lru_cache
@@ -53,7 +53,7 @@ def _get_path_from_module(module_str):
raise ImportError(f"Error finding module {module_str!r}") from e
-def _run_git_diff(cmd: List[str]) -> str:
+def _run_git_diff(cmd: list[str]) -> str:
# Safety note: cmd is always a hard-coded list constructed in this module only.
# No user input is ever injected → safe from shell injection.
result = __import__("subprocess").run( # nosec B404, B603
@@ -165,9 +165,9 @@ def _get_packages_with_changed_specs():
packages = []
for line in changed_lines:
- if line.find("'") > line.find('"') and line.find('"') != -1:
- sep = '"'
- elif line.find("'") == -1:
+ if (line.find("'") > line.find('"') and line.find('"') != -1) or line.find(
+ "'"
+ ) == -1:
sep = '"'
else:
sep = "'"
diff --git a/skbase/utils/random_state.py b/skbase/utils/random_state.py
index c655eb50..3faa7335 100644
--- a/skbase/utils/random_state.py
+++ b/skbase/utils/random_state.py
@@ -58,7 +58,7 @@ def set_random_state(estimator, random_state=None, deep=True, root_policy="copy"
keys.append(key)
seeds = sample_dependent_seed(random_state, n_seeds=len(keys))
- to_set = dict(zip(keys, seeds))
+ to_set = dict(zip(keys, seeds, strict=False))
if root_policy == "copy" and "random_state" in to_set:
to_set["random_state"] = random_state_orig
diff --git a/skbase/utils/stderr_mute.py b/skbase/utils/stderr_mute.py
index bab4b119..5626809d 100644
--- a/skbase/utils/stderr_mute.py
+++ b/skbase/utils/stderr_mute.py
@@ -35,7 +35,7 @@ def __enter__(self):
self._stderr = sys.stderr
sys.stderr = io.StringIO()
- def __exit__(self, type, value, traceback): # noqa: A002
+ def __exit__(self, type, value, traceback):
"""Context manager exit point."""
# restore stderr if active
# if not active, nothing needs to be done, since stderr was not replaced
@@ -49,7 +49,7 @@ def __exit__(self, type, value, traceback): # noqa: A002
# return statement not needed as type was None, but included for clarity
return True
- def _handle_exit_exceptions(self, type, value, traceback): # noqa: A002
+ def _handle_exit_exceptions(self, type, value, traceback):
"""Handle exceptions raised during __exit__.
Parameters
diff --git a/skbase/utils/stdout_mute.py b/skbase/utils/stdout_mute.py
index 6783b22e..02254d62 100644
--- a/skbase/utils/stdout_mute.py
+++ b/skbase/utils/stdout_mute.py
@@ -35,7 +35,7 @@ def __enter__(self):
self._stdout = sys.stdout
sys.stdout = io.StringIO()
- def __exit__(self, type, value, traceback): # noqa: A002
+ def __exit__(self, type, value, traceback):
"""Context manager exit point."""
# restore stdout if active
# if not active, nothing needs to be done, since stdout was not replaced
@@ -49,7 +49,7 @@ def __exit__(self, type, value, traceback): # noqa: A002
# return statement not needed as type was None, but included for clarity
return True
- def _handle_exit_exceptions(self, type, value, traceback): # noqa: A002
+ def _handle_exit_exceptions(self, type, value, traceback):
"""Handle exceptions raised during __exit__.
Parameters
diff --git a/skbase/utils/tests/test_deep_equals.py b/skbase/utils/tests/test_deep_equals.py
index 9ab3d168..7d3f339c 100644
--- a/skbase/utils/tests/test_deep_equals.py
+++ b/skbase/utils/tests/test_deep_equals.py
@@ -12,7 +12,7 @@
EXAMPLES = [
42,
[],
- ((((())))),
+ (()),
[[[[()]]]],
3.5,
4.2,
@@ -122,13 +122,11 @@ def copy_except_if_sklearn(obj):
"""
if not _check_soft_dependencies("scikit-learn", severity="none"):
return deepcopy(obj)
- else:
- from sklearn.base import BaseEstimator
+ from sklearn.base import BaseEstimator
- if isinstance(obj, BaseEstimator):
- return obj
- else:
- return deepcopy(obj)
+ if isinstance(obj, BaseEstimator):
+ return obj
+ return deepcopy(obj)
# Add JAX examples
diff --git a/skbase/utils/tests/test_iter.py b/skbase/utils/tests/test_iter.py
index 329d914e..850dace7 100644
--- a/skbase/utils/tests/test_iter.py
+++ b/skbase/utils/tests/test_iter.py
@@ -75,8 +75,8 @@ def test_format_seq_to_str():
def test_format_seq_to_str_raises():
"""Test _format_seq_to_str raises error when input is unexpected type."""
- with pytest.raises(TypeError, match="`seq` must be a sequence or scalar.*"):
- _format_seq_to_str((c for c in [1, 2, 3]))
+ with pytest.raises(TypeError, match=r"`seq` must be a sequence or scalar.*"):
+ _format_seq_to_str(c for c in [1, 2, 3])
def test_scalar_to_seq_expected_output():
@@ -97,13 +97,13 @@ def test_scalar_to_seq_raises():
"""Test scalar_to_seq raises error when `sequence_type` is unexpected type."""
with pytest.raises(
ValueError,
- match="`sequence_type` must be a subclass of collections.abc.Sequence.",
+ match=r"`sequence_type` must be a subclass of collections.abc.Sequence.",
):
_scalar_to_seq(7, sequence_type=int)
with pytest.raises(
ValueError,
- match="`sequence_type` must be a subclass of collections.abc.Sequence.",
+ match=r"`sequence_type` must be a subclass of collections.abc.Sequence.",
):
_scalar_to_seq(7, sequence_type=dict)
diff --git a/skbase/utils/tests/test_nested_iter.py b/skbase/utils/tests/test_nested_iter.py
index 4ff89875..91058530 100644
--- a/skbase/utils/tests/test_nested_iter.py
+++ b/skbase/utils/tests/test_nested_iter.py
@@ -70,7 +70,7 @@ def test_unflat_len():
assert unflat_len((1, 2)) == 2
assert unflat_len([1, (2, 3), 4, 5]) == 5
assert unflat_len([1, 2, (c for c in (2, 3, 4))]) == 5
- assert unflat_len((c for c in [1, 2, (c for c in (2, 3, 4))])) == 5
+ assert unflat_len(c for c in [1, 2, (c for c in (2, 3, 4))]) == 5
def test_is_flat():
@@ -78,8 +78,8 @@ def test_is_flat():
assert is_flat([1, 2, 3, 4, 5]) is True
assert is_flat([1, (2, 3), 4, 5]) is False
# Check with flat generator
- assert is_flat((c for c in [1, 2, 3])) is True
+ assert is_flat(c for c in [1, 2, 3]) is True
# Check with nested generator
assert is_flat([1, 2, (c for c in (2, 3, 4))]) is False
# Check with generator nested in a generator
- assert is_flat((c for c in [1, 2, (c for c in (2, 3, 4))])) is False
+ assert is_flat(c for c in [1, 2, (c for c in (2, 3, 4))]) is False
diff --git a/skbase/utils/tests/test_random_state.py b/skbase/utils/tests/test_random_state.py
index 844bc6de..f1adc7be 100644
--- a/skbase/utils/tests/test_random_state.py
+++ b/skbase/utils/tests/test_random_state.py
@@ -36,10 +36,7 @@ def set_seed(obj):
return set_random_state(
obj, random_state=42, deep=deep, root_policy=root_policy
)
- else:
- return obj.set_random_state(
- random_state=42, deep=deep, self_policy=root_policy
- )
+ return obj.set_random_state(random_state=42, deep=deep, self_policy=root_policy)
class DummyDummy(BaseObject):
"""Has no random_state attribute."""
@@ -47,7 +44,7 @@ class DummyDummy(BaseObject):
def __init__(self, foo):
self.foo = foo
- super(DummyDummy, self).__init__()
+ super().__init__()
class SeedCompositionDummy(BaseObject):
"""Potentially composite object, for testing."""
@@ -56,7 +53,7 @@ def __init__(self, foo, random_state=None):
self.foo = foo
self.random_state = random_state
- super(SeedCompositionDummy, self).__init__()
+ super().__init__()
simple = SeedCompositionDummy(foo=1, random_state=41)
seedless = DummyDummy(foo=42)
diff --git a/skbase/utils/tests/test_std_mute.py b/skbase/utils/tests/test_std_mute.py
index b248a35b..5665f7f3 100644
--- a/skbase/utils/tests/test_std_mute.py
+++ b/skbase/utils/tests/test_std_mute.py
@@ -23,10 +23,14 @@ def test_std_mute(mute, expected):
stdout_io = io.StringIO()
try:
- with redirect_stderr(stderr_io), redirect_stdout(stdout_io):
- with StderrMute(mute), StdoutMute(mute):
- sys.stdout.write("test stdout")
- sys.stderr.write("test sterr")
- 1 / 0
+ with (
+ redirect_stderr(stderr_io),
+ redirect_stdout(stdout_io),
+ StderrMute(mute),
+ StdoutMute(mute),
+ ):
+ sys.stdout.write("test stdout")
+ sys.stderr.write("test sterr")
+ _ = 1 / 0
except ZeroDivisionError:
assert expected == [stdout_io.getvalue(), stderr_io.getvalue()]
diff --git a/skbase/validate/__init__.py b/skbase/validate/__init__.py
index 97d5b7a5..d5989610 100644
--- a/skbase/validate/__init__.py
+++ b/skbase/validate/__init__.py
@@ -3,8 +3,6 @@
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
"""Tools for validating and comparing BaseObjects and collections of BaseObjects."""
-from typing import List
-
from skbase.validate._named_objects import (
check_sequence_named_objects,
is_named_object_tuple,
@@ -12,8 +10,8 @@
)
from skbase.validate._types import check_sequence, check_type, is_sequence
-__author__: List[str] = ["RNKuhns", "fkiraly"]
-__all__: List[str] = [
+__author__: list[str] = ["RNKuhns", "fkiraly"]
+__all__: list[str] = [
"check_sequence",
"check_sequence_named_objects",
"check_type",
diff --git a/skbase/validate/_named_objects.py b/skbase/validate/_named_objects.py
index 488a832f..3ed76031 100644
--- a/skbase/validate/_named_objects.py
+++ b/skbase/validate/_named_objects.py
@@ -84,9 +84,7 @@ def is_named_object_tuple(obj, object_type=None):
object_type = BaseObject
if not isinstance(obj, tuple) or len(obj) != 2:
return False
- if not isinstance(obj[0], str) or not isinstance(obj[1], object_type):
- return False
- return True
+ return isinstance(obj[0], str) and isinstance(obj[1], object_type)
def is_sequence_named_objects(
diff --git a/skbase/validate/_types.py b/skbase/validate/_types.py
index 21df51b5..a62568aa 100644
--- a/skbase/validate/_types.py
+++ b/skbase/validate/_types.py
@@ -5,19 +5,20 @@
import collections
import inspect
-from typing import Any, List, Optional, Sequence, Tuple, Union
+from collections.abc import Sequence
+from typing import Any
from skbase.utils._iter import _format_seq_to_str, _remove_type_text, _scalar_to_seq
-__author__: List[str] = ["RNKuhns", "fkiraly"]
-__all__: List[str] = ["check_sequence", "check_type", "is_sequence"]
+__author__: list[str] = ["RNKuhns", "fkiraly"]
+__all__: list[str] = ["check_sequence", "check_type", "is_sequence"]
def check_type(
input_: Any,
expected_type: type,
allow_none: bool = False,
- input_name: Optional[str] = None,
+ input_name: str | None = None,
use_subclass: bool = False,
) -> Any:
"""Check the input is the expected type.
@@ -90,30 +91,26 @@ def check_type(
type_check = issubclass if use_subclass else isinstance
if (allow_none and input_ is None) or type_check(input_, expected_type):
return input_
- else:
- chk_msg = "subclass type" if use_subclass else "be type"
- expected_type_str = _remove_type_text(expected_type)
- input_type_str = _remove_type_text(type(input_))
- if allow_none:
- type_msg = f"{expected_type_str} or None"
- else:
- type_msg = f"{expected_type_str}"
- raise TypeError(
- f"`{input_name}` should {chk_msg} {type_msg}, but found {input_type_str}."
- )
+ chk_msg = "subclass type" if use_subclass else "be type"
+ expected_type_str = _remove_type_text(expected_type)
+ input_type_str = _remove_type_text(type(input_))
+ type_msg = f"{expected_type_str} or None" if allow_none else f"{expected_type_str}"
+ raise TypeError(
+ f"`{input_name}` should {chk_msg} {type_msg}, but found {input_type_str}."
+ )
def _convert_scalar_seq_type_input_to_tuple(
- type_input: Optional[Union[type, Tuple[type, ...]]],
- none_default: Optional[type] = None,
- type_input_subclass: Optional[type] = None,
- input_name: str = None,
-) -> Tuple[type, ...]:
+ type_input: type | tuple[type, ...] | None,
+ none_default: type | None = None,
+ type_input_subclass: type | None = None,
+ input_name: str | None = None,
+) -> tuple[type, ...]:
"""Convert input that is scalar or sequence of types to always be a tuple."""
if none_default is None:
none_default = collections.abc.Sequence
- seq_output: Tuple[type, ...]
+ seq_output: tuple[type, ...]
if type_input is None:
seq_output = (none_default,)
# if a sequence of types received as sequence_type, convert to tuple of types
@@ -134,8 +131,8 @@ def _convert_scalar_seq_type_input_to_tuple(
def is_sequence(
input_seq: Any,
- sequence_type: Optional[Union[type, Tuple[type, ...]]] = None,
- element_type: Optional[Union[type, Tuple[type, ...]]] = None,
+ sequence_type: type | tuple[type, ...] | None = None,
+ element_type: type | tuple[type, ...] | None = None,
) -> bool:
"""Indicate if an object is a sequence with optional check of element types.
@@ -223,11 +220,11 @@ def is_sequence(
def check_sequence(
input_seq: Sequence[Any],
- sequence_type: Optional[Union[type, Tuple[type, ...]]] = None,
- element_type: Optional[Union[type, Tuple[type, ...]]] = None,
- coerce_output_type_to: type = None,
+ sequence_type: type | tuple[type, ...] | None = None,
+ element_type: type | tuple[type, ...] | None = None,
+ coerce_output_type_to: type | None = None,
coerce_scalar_input: bool = False,
- sequence_name: str = None,
+ sequence_name: str | None = None,
) -> Sequence[Any]:
"""Check whether an object is a sequence with optional check of element types.
diff --git a/skbase/validate/tests/test_type_validations.py b/skbase/validate/tests/test_type_validations.py
index 0e35b464..260346d8 100644
--- a/skbase/validate/tests/test_type_validations.py
+++ b/skbase/validate/tests/test_type_validations.py
@@ -65,7 +65,7 @@ def test_check_type_output(fixture_estimator_instance, fixture_object_instance):
with pytest.raises(TypeError, match=r"`input` should be type.*"):
check_type(BaseEstimator, expected_type=BaseObject)
- with pytest.raises(TypeError, match="^`input` should be.*"):
+ with pytest.raises(TypeError, match=r"^`input` should be.*"):
check_type("something", expected_type=int, allow_none=True)
# Verify optional use of issubclass instead of isinstance
@@ -80,13 +80,13 @@ def test_check_type_raises_error_if_expected_type_is_wrong_format():
`expected_type` should be a type or tuple of types.
"""
- with pytest.raises(TypeError, match="^`expected_type` should be.*"):
+ with pytest.raises(TypeError, match=r"^`expected_type` should be.*"):
check_type(7, expected_type=11)
- with pytest.raises(TypeError, match="^`expected_type` should be.*"):
+ with pytest.raises(TypeError, match=r"^`expected_type` should be.*"):
check_type(7, expected_type=[int])
- with pytest.raises(TypeError, match="^`expected_type` should be.*"):
+ with pytest.raises(TypeError, match=r"^`expected_type` should be.*"):
check_type(None, expected_type=[int])
@@ -100,7 +100,7 @@ def test_is_sequence_output():
# True for any sequence
assert is_sequence([1, 2, 3]) is True
# But false for generators, since they are iterable but not sequences
- assert is_sequence((c for c in [1, 2, 3])) is False
+ assert is_sequence(c for c in [1, 2, 3]) is False
# Test use of sequence_type restriction
assert is_sequence([1, 2, 3, 4], sequence_type=list) is True
@@ -171,20 +171,20 @@ def test_check_sequence_output():
# But false for generators, since they are iterable but not sequences
with pytest.raises(
TypeError,
- match="Invalid sequence: Input sequence expected to be a a sequence.",
+ match=r"Invalid sequence: Input sequence expected to be a a sequence.",
):
- assert check_sequence((c for c in [1, 2, 3]))
+ assert check_sequence(c for c in [1, 2, 3])
# Test use of sequence_type restriction
assert check_sequence([1, 2, 3, 4], sequence_type=list) == [1, 2, 3, 4]
with pytest.raises(
TypeError,
- match="Invalid sequence: Input sequence expected to be a tuple.",
+ match=r"Invalid sequence: Input sequence expected to be a tuple.",
):
check_sequence([1, 2, 3, 4], sequence_type=tuple)
with pytest.raises(
TypeError,
- match="Invalid sequence: Input sequence expected to be a list.",
+ match=r"Invalid sequence: Input sequence expected to be a list.",
):
check_sequence((1, 2, 3, 4), sequence_type=list)
assert check_sequence((1, 2, 3, 4), sequence_type=tuple) == (1, 2, 3, 4)
@@ -203,22 +203,22 @@ def test_check_sequence_output():
with pytest.raises(
TypeError,
- match="Invalid sequence: .*",
+ match=r"Invalid sequence: .*",
):
check_sequence([1, 2, 3], element_type=float)
with pytest.raises(
TypeError,
- match="Invalid sequence: .*",
+ match=r"Invalid sequence: .*",
):
check_sequence([1, 2, 3, 4], sequence_type=tuple, element_type=int)
with pytest.raises(
TypeError,
- match="Invalid sequence: .*",
+ match=r"Invalid sequence: .*",
):
check_sequence([1, 2, 3, 4], sequence_type=list, element_type=float)
with pytest.raises(
TypeError,
- match="Invalid sequence: .*",
+ match=r"Invalid sequence: .*",
):
check_sequence([1, 2, 3, 4], sequence_type=tuple, element_type=float)
@@ -231,7 +231,7 @@ def test_check_sequence_output():
assert check_sequence([1, "something", 4.5]) == [1, "something", 4.5]
with pytest.raises(
TypeError,
- match="Invalid sequence: .*",
+ match=r"Invalid sequence: .*",
):
check_sequence([1, "something", 4.5], element_type=float)
@@ -242,7 +242,7 @@ def test_check_sequence_output():
# Test with 3rd party types works in default way via exact type
with pytest.raises(
TypeError,
- match="Invalid sequence: .*",
+ match=r"Invalid sequence: .*",
):
check_sequence([1.2, 4.7], element_type=np.float64)
input_seq = [np.float64(1.2), np.float64(4.7)]
@@ -257,7 +257,7 @@ def test_check_sequence_output():
]
with pytest.raises(
TypeError,
- match="Invalid sequence: .*",
+ match=r"Invalid sequence: .*",
):
check_sequence([np.nan, 4], element_type=int)
@@ -300,7 +300,7 @@ def test_check_sequence_scalar_input_coercion():
# Still raise an error if element type is not expected
with pytest.raises(
TypeError,
- match="Invalid sequence: .*",
+ match=r"Invalid sequence: .*",
):
check_sequence(
7,
@@ -323,12 +323,12 @@ def test_check_sequence_with_seq_of_class_and_instance_input(
with pytest.raises(
TypeError,
- match="Invalid sequence: .*",
+ match=r"Invalid sequence: .*",
):
check_sequence(list(input_seq), sequence_type=tuple, element_type=BaseObject)
with pytest.raises(
TypeError,
- match="Invalid sequence: .*",
+ match=r"Invalid sequence: .*",
):
# Verify we detect when list elements are not instances of valid class type
check_sequence([1, 2, 3], element_type=BaseObject)
@@ -341,7 +341,7 @@ def test_check_sequence_with_seq_of_class_and_instance_input(
) == list(input_seq)
with pytest.raises(
TypeError,
- match="Invalid sequence: .*",
+ match=r"Invalid sequence: .*",
):
# Verify we detect when list elements are not instances of valid types
check_sequence([1, 2, 3], element_type=BaseObject)