diff --git a/docs/source/api_reference.rst b/docs/source/api_reference.rst
index 03026bdf..9508d39e 100644
--- a/docs/source/api_reference.rst
+++ b/docs/source/api_reference.rst
@@ -28,6 +28,27 @@ Base Classes
BaseObject
BaseEstimator
+.. _global_config:
+
+Configure ``skbase``
+====================
+
+.. automodule:: skbase.config
+ :no-members:
+ :no-inherited-members:
+
+.. currentmodule:: skbase.config
+
+.. autosummary::
+ :toctree: api_reference/auto_generated/
+ :template: function.rst
+
+ get_config
+ get_default_config
+ set_config
+ reset_config
+ config_context
+
.. _obj_retrieval:
Object Retrieval
diff --git a/docs/source/user_documentation/user_guide.rst b/docs/source/user_documentation/user_guide.rst
index c8febd31..4e15f3ec 100644
--- a/docs/source/user_documentation/user_guide.rst
+++ b/docs/source/user_documentation/user_guide.rst
@@ -29,6 +29,7 @@ that ``skbase`` provides, see the :ref:`api_ref`.
user_guide/lookup
user_guide/validate
user_guide/testing
+ user_guide/configuration
.. grid:: 1 2 2 2
@@ -103,3 +104,17 @@ that ``skbase`` provides, see the :ref:`api_ref`.
:expand:
Testing
+
+ .. grid-item-card:: Configuration
+ :text-align: center
+
+ Configure ``skbase``.
+
+ +++
+
+ .. button-ref:: user_guide/configuration
+ :color: primary
+ :click-parent:
+ :expand:
+
+ Configuration
diff --git a/docs/source/user_documentation/user_guide/configuration.rst b/docs/source/user_documentation/user_guide/configuration.rst
new file mode 100644
index 00000000..05b76843
--- /dev/null
+++ b/docs/source/user_documentation/user_guide/configuration.rst
@@ -0,0 +1,11 @@
+.. _user_guide_global_config:
+
+====================
+Configure ``skbase``
+====================
+
+.. note::
+
+ The user guide is under development. We have created a basic
+ structure and are looking for contributions to develop the user guide
+ further.
diff --git a/skbase/base/_base.py b/skbase/base/_base.py
index 970bf350..1393e267 100644
--- a/skbase/base/_base.py
+++ b/skbase/base/_base.py
@@ -53,6 +53,7 @@ class name: BaseEstimator
fitted state check - check_is_fitted (raises error if not is_fitted)
"""
import inspect
+import re
import warnings
from collections import defaultdict
from copy import deepcopy
@@ -62,7 +63,10 @@ class name: BaseEstimator
from sklearn.base import BaseEstimator as _BaseEstimator
from skbase._exceptions import NotFittedError
+from skbase.base._pretty_printing._object_html_repr import _object_html_repr
from skbase.base._tagmanager import _FlagManager
+from skbase.config import get_config
+from skbase.config._config import _CONFIG_REGISTRY
__author__: List[str] = ["mloning", "RNKuhns", "fkiraly"]
__all__: List[str] = ["BaseEstimator", "BaseObject"]
@@ -446,7 +450,39 @@ def get_config(self):
class attribute via nested inheritance and then any overrides
and new tags from _onfig_dynamic object attribute.
"""
- return self._get_flags(flag_attr_name="_config")
+ config = get_config().copy()
+
+ # Get any extension configuration interface defined in the class
+ # for example if downstream package wants to extend skbase to retrieve
+ # their own config
+ if hasattr(self, "__skbase_get_config__") and callable(
+ self.__skbase_get_config__
+ ):
+ skbase_get_config_extension_dict = self.__skbase_get_config__()
+ else:
+ skbase_get_config_extension_dict = {}
+ if isinstance(skbase_get_config_extension_dict, dict):
+ config.update(skbase_get_config_extension_dict)
+ else:
+ msg = "Use of `__skbase_get_config__` to extend the interface for local "
+ msg += "overrides of the global configuration must return a dictionary.\n"
+ msg += f"But a {type(skbase_get_config_extension_dict)} was found."
+ warnings.warn(msg, UserWarning, stacklevel=2)
+ local_config = self._get_flags(flag_attr_name="_config").copy()
+ # If the local config sets a registered global parameter, validate its value
+ for config_param, config_value in local_config.items():
+ if config_param in _CONFIG_REGISTRY:
+ msg = "Invalid value encountered for global configuration parameter "
+ msg += f"{config_param}. Using global parameter configuration value.\n"
+ config_value = _CONFIG_REGISTRY[
+ config_param
+ ].get_valid_param_or_default(
+ config_value, default_value=config[config_param]
+ )
+ local_config[config_param] = config_value
+ config.update(local_config)
+
+ return config
def set_config(self, **config_dict):
"""Set config flags to given values.
@@ -682,6 +718,98 @@ def _components(self, base_class=None):
return comp_dict
+ def __repr__(self, n_char_max: int = 700):
+ """Represent class as string.
+
+ This follows the scikit-learn implementation for the string representation
+ of parameterized objects.
+
+ Parameters
+ ----------
+ n_char_max : int
+ Maximum (approximate) number of non-blank characters to render. This
+ can be useful in testing.
+ """
+ from skbase.base._pretty_printing._pprint import _BaseObjectPrettyPrinter
+
+ n_max_elements_to_show = 30 # number of elements to show in sequences
+ # use ellipsis for sequences with a lot of elements
+ pp = _BaseObjectPrettyPrinter(
+ compact=True,
+ indent=1,
+ indent_at_name=True,
+ n_max_elements_to_show=n_max_elements_to_show,
+ changed_only=self.get_config()["print_changed_only"],
+ )
+
+ repr_ = pp.pformat(self)
+
+ # Use bruteforce ellipsis when there are a lot of non-blank characters
+ n_nonblank = len("".join(repr_.split()))
+ if n_nonblank > n_char_max:
+ lim = n_char_max // 2 # approx. number of chars to keep on both ends
+ regex = r"^(\s*\S){%d}" % lim
+ # The regex '^(\s*\S){%d}' matches from the start of the string
+ # until the nth non-blank character:
+ # - ^ matches the start of string
+ # - (pattern){n} matches n repetitions of pattern
+ # - \s*\S matches a non-blank char following zero or more blanks
+ left_match = re.match(regex, repr_)
+ right_match = re.match(regex, repr_[::-1])
+ left_lim = left_match.end() if left_match is not None else 0
+ right_lim = right_match.end() if right_match is not None else 0
+
+ if "\n" in repr_[left_lim:-right_lim]:
+ # The left side and right side aren't on the same line.
+ # To avoid weird cuts, e.g.:
+ # categoric...ore',
+ # we need to start the right side with an appropriate newline
+ # character so that it renders properly as:
+ # categoric...
+ # handle_unknown='ignore',
+ # so we add [^\n]*\n which matches until the next \n
+ regex += r"[^\n]*\n"
+ right_match = re.match(regex, repr_[::-1])
+ right_lim = right_match.end() if right_match is not None else 0
+
+ ellipsis = "..."
+ if left_lim + len(ellipsis) < len(repr_) - right_lim:
+ # Only add ellipsis if it results in a shorter repr
+ repr_ = repr_[:left_lim] + "..." + repr_[-right_lim:]
+
+ return repr_
+
+ @property
+ def _repr_html_(self):
+ """HTML representation of BaseObject.
+
+ This is redundant with the logic of `_repr_mimebundle_`. The latter
+ should be favored in the long term, `_repr_html_` is only
+ implemented for consumers who do not interpret `_repr_mimebundle_`.
+ """
+ if self.get_config()["display"] != "diagram":
+ raise AttributeError(
+ "_repr_html_ is only defined when the "
+ "`display` configuration option is set to 'diagram'."
+ )
+ return self._repr_html_inner
+
+ def _repr_html_inner(self):
+ """Return HTML representation of class.
+
+ This function is returned by the @property `_repr_html_` to make
+ `hasattr(BaseObject, "_repr_html_")` return `True` or `False` depending
+ on `self.get_config()["display"]`.
+ """
+ return _object_html_repr(self)
+
+ def _repr_mimebundle_(self, **kwargs):
+ """Mime bundle used by jupyter kernels to display instances of BaseObject."""
+ output = {"text/plain": repr(self)}
+ if self.get_config()["display"] == "diagram":
+ output["text/html"] = _object_html_repr(self)
+ return output
+
class TagAliaserMixin:
"""Mixin class for tag aliasing and deprecation of old tags.
diff --git a/skbase/base/_pretty_printing/__init__.py b/skbase/base/_pretty_printing/__init__.py
new file mode 100644
index 00000000..669c800c
--- /dev/null
+++ b/skbase/base/_pretty_printing/__init__.py
@@ -0,0 +1,11 @@
+#!/usr/bin/env python3 -u
+# -*- coding: utf-8 -*-
+# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
+# Many elements of this code were developed in scikit-learn. These elements
+# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For
+# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
+"""Functionality for pretty printing BaseObjects."""
+from typing import List
+
+__author__: List[str] = ["RNKuhns"]
+__all__: List[str] = []
diff --git a/skbase/base/_pretty_printing/_object_html_repr.py b/skbase/base/_pretty_printing/_object_html_repr.py
new file mode 100644
index 00000000..16515b9d
--- /dev/null
+++ b/skbase/base/_pretty_printing/_object_html_repr.py
@@ -0,0 +1,400 @@
+# -*- coding: utf-8 -*-
+# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
+# Many elements of this code were developed in scikit-learn. These elements
+# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For
+# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
+"""Functionality to represent instance of BaseObject as html."""
+
+import html
+import uuid
+from contextlib import closing, suppress
+from io import StringIO
+from string import Template
+
+from skbase.config import config_context # type: ignore
+
+__author__ = ["RNKuhns"]
+
+
+class _VisualBlock:
+ """HTML Representation of BaseObject.
+
+ Parameters
+ ----------
+ kind : {'serial', 'parallel', 'single'}
+ kind of HTML block
+
+ objs : list of BaseObjects or `_VisualBlock`s or a single BaseObject
+ If kind != 'single', then `objs` is a list of
+ BaseObjects. If kind == 'single', then `objs` is a single BaseObject.
+
+ names : list of str, default=None
+ If kind != 'single', then `names` corresponds to BaseObjects.
+ If kind == 'single', then `names` is a single string corresponding to
+ the single BaseObject.
+
+ name_details : list of str, str, or None, default=None
+ If kind != 'single', then `name_details` corresponds to `names`.
+ If kind == 'single', then `name_details` is a single string
+ corresponding to the single BaseObject.
+
+ dash_wrapped : bool, default=True
+ If true, wrapped HTML element will be wrapped with a dashed border.
+ Only active when kind != 'single'.
+ """
+
+ def __init__(self, kind, objs, *, names=None, name_details=None, dash_wrapped=True):
+ self.kind = kind
+ self.objs = objs
+ self.dash_wrapped = dash_wrapped
+
+ if self.kind in ("parallel", "serial"):
+ if names is None:
+ names = (None,) * len(objs)
+ if name_details is None:
+ name_details = (None,) * len(objs)
+
+ self.names = names
+ self.name_details = name_details
+
+ def _sk_visual_block_(self):
+ return self
+
+
+def _write_label_html(
+ out,
+ name,
+ name_details,
+ outer_class="sk-label-container",
+ inner_class="sk-label",
+ checked=False,
+):
+ """Write labeled html with or without a dropdown with named details."""
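+ out.write(f'<div class="{outer_class}"><div class="{inner_class} sk-toggleable">')
+ name = html.escape(name)
+
+ if name_details is not None:
+ name_details = html.escape(str(name_details))
+ label_class = "sk-toggleable__label sk-toggleable__label-arrow"
+
+ checked_str = "checked" if checked else ""
+ est_id = uuid.uuid4()
+ out.write(
+ '<input class="sk-toggleable__control sk-hidden--visually" '
+ f'id="{est_id}" type="checkbox" {checked_str}>'
+ f'<label for="{est_id}" class="{label_class}">{name}</label>'
+ f'<div class="sk-toggleable__content"><pre>{name_details}'
+ "</pre></div>"
+ )
+ else:
+ out.write(f"<label>{name}</label>")
+ out.write("</div></div>") # outer_class inner_class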
+
+
+def _get_visual_block(base_object):
+ """Generate information about how to display a BaseObject."""
+ with suppress(AttributeError):
+ return base_object._sk_visual_block_()
+
+ if isinstance(base_object, str):
+ return _VisualBlock(
+ "single", base_object, names=base_object, name_details=base_object
+ )
+ elif base_object is None:
+ return _VisualBlock("single", base_object, names="None", name_details="None")
+
+ # check if base_object looks like a meta object that wraps other base objects
+ if hasattr(base_object, "get_params"):
+ base_objects = []
+ for key, value in base_object.get_params().items():
+ # Only look at the BaseObjects in the first layer
+ if "__" not in key and hasattr(value, "get_params"):
+ base_objects.append(value)
+ if len(base_objects):
+ return _VisualBlock("parallel", base_objects, names=None)
+
+ return _VisualBlock(
+ "single",
+ base_object,
+ names=base_object.__class__.__name__,
+ name_details=str(base_object),
+ )
+
+
+def _write_base_object_html(
+ out, base_object, base_object_label, base_object_label_details, first_call=False
+):
+ """Write BaseObject to html in serial, parallel, or by itself (single)."""
+ if first_call:
+ est_block = _get_visual_block(base_object)
+ else:
+ # So it is easier to read, always use print_changed_only==True
+ # regardless of configuration
+ with config_context(print_changed_only=True):
+ est_block = _get_visual_block(base_object)
+
+ if est_block.kind in ("serial", "parallel"):
+ dashed_wrapped = first_call or est_block.dash_wrapped
+ dash_cls = " sk-dashed-wrapped" if dashed_wrapped else ""
+ out.write(f'
")
+
+ html_output = out.getvalue()
+ return html_output
diff --git a/skbase/base/_pretty_printing/_pprint.py b/skbase/base/_pretty_printing/_pprint.py
new file mode 100644
index 00000000..2e8a8a80
--- /dev/null
+++ b/skbase/base/_pretty_printing/_pprint.py
@@ -0,0 +1,412 @@
+# -*- coding: utf-8 -*-
+# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
+# Many elements of this code were developed in scikit-learn. These elements
+# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For
+# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
+"""Utility functionality for pretty-printing objects used in BaseObject.__repr__."""
+import inspect
+import pprint
+from collections import OrderedDict
+
+from skbase.base import BaseObject
+
+# from skbase.config import get_config
+from skbase.utils._check import _is_scalar_nan
+
+
+class KeyValTuple(tuple):
+ """Dummy class for correctly rendering key-value tuples from dicts."""
+
+ def __repr__(self):
+ """Represent as string."""
+ # needed for _dispatch[tuple.__repr__] not to be overridden
+ return super().__repr__()
+
+
+class KeyValTupleParam(KeyValTuple):
+ """Dummy class for correctly rendering key-value tuples from parameters."""
+
+ pass
+
+
+def _changed_params(base_object):
+ """Return dict (param_name: value) of parameters with non-default values."""
+ params = base_object.get_params(deep=False)
+ init_func = getattr(
+ base_object.__init__, "deprecated_original", base_object.__init__
+ )
+ init_params = inspect.signature(init_func).parameters
+ init_params = {name: param.default for name, param in init_params.items()}
+
+ def has_changed(k, v):
+ if k not in init_params: # happens if k is part of a **kwargs
+ return True
+ if init_params[k] == inspect._empty: # k has no default value
+ return True
+ # try to avoid calling repr on nested BaseObjects
+ if isinstance(v, BaseObject) and v.__class__ != init_params[k].__class__:
+ return True
+ # Use repr as a last resort. It may be expensive.
+ if repr(v) != repr(init_params[k]) and not (
+ _is_scalar_nan(init_params[k]) and _is_scalar_nan(v)
+ ):
+ return True
+ return False
+
+ return {k: v for k, v in params.items() if has_changed(k, v)}
+
+
+class _BaseObjectPrettyPrinter(pprint.PrettyPrinter):
+ """Pretty Printer class for BaseObjects.
+
+ This extends the pprint.PrettyPrinter class similar to scikit-learn's
+ implementation, so that:
+
+ - BaseObjects are printed with their parameters, e.g.
+ BaseObject(param1=value1, ...) which is not supported by default.
+ - the 'compact' parameter of PrettyPrinter is ignored for dicts, which
+ may lead to very long representations that we want to avoid.
+
+ Quick overview of pprint.PrettyPrinter (see also
+ https://stackoverflow.com/questions/49565047/pprint-with-hex-numbers):
+
+ - the entry point is the _format() method which calls format() (overridden
+ here)
+ - format() directly calls _safe_repr() for a first try at rendering the
+ object
+ - _safe_repr formats the whole object recursively, only calling itself,
+ not caring about line length or anything
+ - back to _format(), if the output string is too long, _format() then calls
+ the appropriate _pprint_TYPE() method (e.g. _pprint_list()) depending on
+ the type of the object. This is where the line length and the compact
+ parameters are taken into account.
+ - those _pprint_TYPE() methods will internally use the format() method for
+ rendering the nested objects of an object (e.g. the elements of a list)
+
+ In the end, everything has to be implemented twice: in _safe_repr and in
+ the custom _pprint_TYPE methods. Unfortunately PrettyPrinter is really not
+ straightforward to extend (especially when we want a compact output), so
+ the code is a bit convoluted.
+
+ This class overrides:
+ - format() to support the changed_only parameter
+ - _safe_repr to support printing of BaseObjects that fit on a single line
+ - _format_dict_items so that dict are correctly 'compacted'
+ - _format_items so that ellipsis is used on long lists and tuples
+
+ When BaseObjects cannot be printed on a single line, the builtin _format()
+ will call _pprint_object() because it was registered to do so (see
+ _dispatch[BaseObject.__repr__] = _pprint_object).
+
+ Both _format_dict_items() and _pprint_object() use the
+ _format_params_or_dict_items() method that will format parameters and
+ key-value pairs respecting the compact parameter. This method needs another
+ subroutine _pprint_key_val_tuple() used when a parameter or a key-value
+ pair is too long to fit on a single line. This subroutine is called in
+ _format() and is registered as well in the _dispatch dict (just like
+ _pprint_object). We had to create the two classes KeyValTuple and
+ KeyValTupleParam for this.
+ """
+
+ def __init__(
+ self,
+ indent=1,
+ width=80,
+ depth=None,
+ stream=None,
+ *,
+ compact=False,
+ indent_at_name=True,
+ n_max_elements_to_show=None,
+ changed_only=True,
+ ):
+ super().__init__(indent, width, depth, stream, compact=compact)
+ self._indent_at_name = indent_at_name
+ if self._indent_at_name:
+ self._indent_per_level = 1 # ignore indent param
+ self.changed_only = changed_only
+ # Max number of elements in a list, dict, tuple until we start using
+ # ellipsis. This also affects the number of arguments of a BaseObject
+ # (they are treated as dicts)
+ self.n_max_elements_to_show = n_max_elements_to_show
+
+ def format(self, obj, context, maxlevels, level): # noqa
+ return _safe_repr(
+ obj, context, maxlevels, level, changed_only=self.changed_only
+ )
+
+ def _pprint_object(self, obj, stream, indent, allowance, context, level):
+ stream.write(obj.__class__.__name__ + "(")
+ if self._indent_at_name:
+ indent += len(obj.__class__.__name__)
+
+ if self.changed_only:
+ params = _changed_params(obj)
+ else:
+ params = obj.get_params(deep=False)
+
+ params = OrderedDict((name, val) for (name, val) in sorted(params.items()))
+
+ self._format_params(
+ params.items(), stream, indent, allowance + 1, context, level
+ )
+ stream.write(")")
+
+ def _format_dict_items(self, items, stream, indent, allowance, context, level):
+ return self._format_params_or_dict_items(
+ items, stream, indent, allowance, context, level, is_dict=True
+ )
+
+ def _format_params(self, items, stream, indent, allowance, context, level):
+ return self._format_params_or_dict_items(
+ items, stream, indent, allowance, context, level, is_dict=False
+ )
+
+ def _format_params_or_dict_items(
+ self, obj, stream, indent, allowance, context, level, is_dict
+ ):
+ """Format dict items or parameters respecting the compact=True parameter.
+
+ For some reason, the builtin rendering of dict items doesn't
+ respect compact=True and will use one line per key-value if all cannot
+ fit in a single line.
+ Dict items will be rendered as <'key': value> while params will be
+ rendered as <key=value>. The implementation is mostly copy/pasting from
+ the builtin _format_items().
+ This also adds ellipsis if the number of items is greater than
+ self.n_max_elements_to_show.
+ """
+ write = stream.write
+ indent += self._indent_per_level
+ delimnl = ",\n" + " " * indent
+ delim = ""
+ width = max_width = self._width - indent + 1
+ it = iter(obj)
+ try:
+ next_ent = next(it)
+ except StopIteration:
+ return
+ last = False
+ n_items = 0
+ while not last:
+ if n_items == self.n_max_elements_to_show:
+ write(", ...")
+ break
+ n_items += 1
+ ent = next_ent
+ try:
+ next_ent = next(it)
+ except StopIteration:
+ last = True
+ max_width -= allowance
+ width -= allowance
+ if self._compact:
+ k, v = ent
+ krepr = self._repr(k, context, level)
+ vrepr = self._repr(v, context, level)
+ if not is_dict:
+ krepr = krepr.strip("'")
+ middle = ": " if is_dict else "="
+ rep = krepr + middle + vrepr
+ w = len(rep) + 2
+ if width < w:
+ width = max_width
+ if delim:
+ delim = delimnl
+ if width >= w:
+ width -= w
+ write(delim)
+ delim = ", "
+ write(rep)
+ continue
+ write(delim)
+ delim = delimnl
+ class_ = KeyValTuple if is_dict else KeyValTupleParam
+ self._format(
+ class_(ent), stream, indent, allowance if last else 1, context, level
+ )
+
+ def _format_items(self, items, stream, indent, allowance, context, level):
+ """Format the items of an iterable (list, tuple...).
+
+ Same as the built-in _format_items, with support for ellipsis if the
+ number of elements is greater than self.n_max_elements_to_show.
+ """
+ write = stream.write
+ indent += self._indent_per_level
+ if self._indent_per_level > 1:
+ write((self._indent_per_level - 1) * " ")
+ delimnl = ",\n" + " " * indent
+ delim = ""
+ width = max_width = self._width - indent + 1
+ it = iter(items)
+ try:
+ next_ent = next(it)
+ except StopIteration:
+ return
+ last = False
+ n_items = 0
+ while not last:
+ if n_items == self.n_max_elements_to_show:
+ write(", ...")
+ break
+ n_items += 1
+ ent = next_ent
+ try:
+ next_ent = next(it)
+ except StopIteration:
+ last = True
+ max_width -= allowance
+ width -= allowance
+ if self._compact:
+ rep = self._repr(ent, context, level)
+ w = len(rep) + 2
+ if width < w:
+ width = max_width
+ if delim:
+ delim = delimnl
+ if width >= w:
+ width -= w
+ write(delim)
+ delim = ", "
+ write(rep)
+ continue
+ write(delim)
+ delim = delimnl
+ self._format(ent, stream, indent, allowance if last else 1, context, level)
+
+ def _pprint_key_val_tuple(self, obj, stream, indent, allowance, context, level):
+ """Pretty printing for key-value tuples from dict or parameters."""
+ k, v = obj
+ rep = self._repr(k, context, level)
+ if isinstance(obj, KeyValTupleParam):
+ rep = rep.strip("'")
+ middle = "="
+ else:
+ middle = ": "
+ stream.write(rep)
+ stream.write(middle)
+ self._format(
+ v, stream, indent + len(rep) + len(middle), allowance, context, level
+ )
+
+ # Follow what scikit-learn did here and copy _dispatch to prevent instances
+ # of the builtin PrettyPrinter class to call methods of
+ # _BaseObjectPrettyPrinter (see scikit-learn Github issue 12906)
+ # mypy error: "Type[PrettyPrinter]" has no attribute "_dispatch"
+ _dispatch = pprint.PrettyPrinter._dispatch.copy() # type: ignore
+ _dispatch[BaseObject.__repr__] = _pprint_object
+ _dispatch[KeyValTuple.__repr__] = _pprint_key_val_tuple
+
+
+def _safe_repr(obj, context, maxlevels, level, changed_only=False):
+ """Safe string representation logic.
+
+ Same as the builtin _safe_repr, with added support for BaseObjects.
+ """
+ typ = type(obj)
+
+ if typ in pprint._builtin_scalars:
+ return repr(obj), True, False
+
+ r = getattr(typ, "__repr__", None)
+ if issubclass(typ, dict) and r is dict.__repr__:
+ if not obj:
+ return "{}", True, False
+ objid = id(obj)
+ if maxlevels and level >= maxlevels:
+ return "{...}", False, objid in context
+ if objid in context:
+ return pprint._recursion(obj), False, True
+ context[objid] = 1
+ readable = True
+ recursive = False
+ components = []
+ append = components.append
+ level += 1
+ saferepr = _safe_repr
+ items = sorted(obj.items(), key=pprint._safe_tuple)
+ for k, v in items:
+ krepr, kreadable, krecur = saferepr(
+ k, context, maxlevels, level, changed_only=changed_only
+ )
+ vrepr, vreadable, vrecur = saferepr(
+ v, context, maxlevels, level, changed_only=changed_only
+ )
+ append("%s: %s" % (krepr, vrepr))
+ readable = readable and kreadable and vreadable
+ if krecur or vrecur:
+ recursive = True
+ del context[objid]
+ return "{%s}" % ", ".join(components), readable, recursive
+
+ if (issubclass(typ, list) and r is list.__repr__) or (
+ issubclass(typ, tuple) and r is tuple.__repr__
+ ):
+ if issubclass(typ, list):
+ if not obj:
+ return "[]", True, False
+ format_ = "[%s]"
+ elif len(obj) == 1:
+ format_ = "(%s,)"
+ else:
+ if not obj:
+ return "()", True, False
+ format_ = "(%s)"
+ objid = id(obj)
+ if maxlevels and level >= maxlevels:
+ return format_ % "...", False, objid in context
+ if objid in context:
+ return pprint._recursion(obj), False, True
+ context[objid] = 1
+ readable = True
+ recursive = False
+ components = []
+ append = components.append
+ level += 1
+ for o in obj:
+ orepr, oreadable, orecur = _safe_repr(
+ o, context, maxlevels, level, changed_only=changed_only
+ )
+ append(orepr)
+ if not oreadable:
+ readable = False
+ if orecur:
+ recursive = True
+ del context[objid]
+ return format_ % ", ".join(components), readable, recursive
+
+ if issubclass(typ, BaseObject):
+ objid = id(obj)
+ if maxlevels and level >= maxlevels:
+ return "{...}", False, objid in context
+ if objid in context:
+ return pprint._recursion(obj), False, True
+ context[objid] = 1
+ readable = True
+ recursive = False
+ if changed_only:
+ params = _changed_params(obj)
+ else:
+ params = obj.get_params(deep=False)
+ components = []
+ append = components.append
+ level += 1
+ saferepr = _safe_repr
+ items = sorted(params.items(), key=pprint._safe_tuple)
+ for k, v in items:
+ krepr, kreadable, krecur = saferepr(
+ k, context, maxlevels, level, changed_only=changed_only
+ )
+ vrepr, vreadable, vrecur = saferepr(
+ v, context, maxlevels, level, changed_only=changed_only
+ )
+ append("%s=%s" % (krepr.strip("'"), vrepr))
+ readable = readable and kreadable and vreadable
+ if krecur or vrecur:
+ recursive = True
+ del context[objid]
+ return ("%s(%s)" % (typ.__name__, ", ".join(components)), readable, recursive)
+
+ rep = repr(obj)
+ return rep, (rep and not rep.startswith("<")), False
diff --git a/skbase/config/__init__.py b/skbase/config/__init__.py
new file mode 100644
index 00000000..fd8a750a
--- /dev/null
+++ b/skbase/config/__init__.py
@@ -0,0 +1,32 @@
+# -*- coding: utf-8 -*-
+""":mod:`skbase.config` provides tools for the global configuration of ``skbase``.
+
+For more information on configuration usage patterns see the
+ :ref:`user guide <user_guide_global_config>`.
+"""
+# -*- coding: utf-8 -*-
+# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
+# Includes functionality like get_config, set_config, and config_context
+# that is similar to scikit-learn. These elements are copyrighted by the
+# scikit-learn developers, BSD-3-Clause License. For conditions see
+# https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
+from typing import List
+
+from skbase.config._config import (
+ GlobalConfigParamSetting,
+ config_context,
+ get_config,
+ get_default_config,
+ reset_config,
+ set_config,
+)
+
+__author__: List[str] = ["RNKuhns"]
+__all__: List[str] = [
+ "GlobalConfigParamSetting",
+ "get_default_config",
+ "get_config",
+ "set_config",
+ "reset_config",
+ "config_context",
+]
diff --git a/skbase/config/_config.py b/skbase/config/_config.py
new file mode 100644
index 00000000..ec2fc13c
--- /dev/null
+++ b/skbase/config/_config.py
@@ -0,0 +1,397 @@
+# -*- coding: utf-8 -*-
+"""Implement logic for global configuration of skbase."""
+# -*- coding: utf-8 -*-
+# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
+# Includes functionality like get_config, set_config, and config_context
+# that is similar to scikit-learn. These elements are copyrighted by the
+# scikit-learn developers, BSD-3-Clause License. For conditions see
+# https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
+import collections
+import sys
+import threading
+import warnings
+from contextlib import contextmanager
+from dataclasses import dataclass
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+
+if sys.version_info >= (3, 8):
+ from typing import Literal
+else:
+ from typing_extensions import Literal # type: ignore
+
+from skbase.utils._iter import _format_seq_to_str
+
+__author__: List[str] = ["RNKuhns"]
+__all__: List[str] = [
+ "GlobalConfigParamSetting",
+ "get_default_config",
+ "get_config",
+ "set_config",
+ "reset_config",
+ "config_context",
+]
+
+
+@dataclass
+class GlobalConfigParamSetting:
+    """Define types of the setting information for a given config parameter."""
+
+    name: str
+    os_environ_name: str
+    expected_type: Union[type, Tuple[type, ...]]
+    allowed_values: Optional[Union[Tuple[Any, ...], List[Any]]]
+    default_value: Any
+
+    def get_allowed_values(self) -> List[Any]:
+        """Get `allowed_values` or empty list if `allowed_values` is None.
+
+        Returns
+        -------
+        list
+            Allowable values if any.
+        """
+        if self.allowed_values is None:
+            return []
+        elif isinstance(self.allowed_values, list):
+            return self.allowed_values
+        elif isinstance(
+            self.allowed_values, collections.abc.Iterable
+        ) and not isinstance(self.allowed_values, str):
+            return list(self.allowed_values)
+        else:
+            return [self.allowed_values]
+
+    def is_valid_param_value(self, value) -> bool:
+        """Validate that a global configuration value is valid.
+
+        Verifies that the value set for a global configuration parameter is valid
+        based on its configuration settings.
+
+        Returns
+        -------
+        bool
+            Whether a parameter value is valid.
+        """
+        # None means "unrestricted"; check the raw attribute because
+        # get_allowed_values() maps None to [], which would reject every value.
+        allowed_values = self.get_allowed_values()
+        if not isinstance(value, self.expected_type):
+            valid_param = False
+        elif self.allowed_values is not None and value not in allowed_values:
+            valid_param = False
+        else:
+            valid_param = True
+        return valid_param
+
+    def get_valid_param_or_default(self, value, default_value=None, msg=None):
+        """Validate `value` and return default if it is not valid.
+
+        Parameters
+        ----------
+        value : Any
+            The configuration parameter value to set.
+        default_value : Any, default=None
+            An optional default value to use to set the configuration parameter
+            if `value` is not valid based on defined expected type and allowed
+            values. If None, and `value` is invalid then the class's `default_value`
+            parameter is used.
+        msg : str, default=None
+            An optional message to be used as start of the UserWarning message.
+        """
+        if self.is_valid_param_value(value):
+            return value
+        else:
+            if msg is None:
+                msg = ""
+            msg += f"When setting global config values for `{self.name}`, the values "
+            msg += f"should be of type {self.expected_type}.\n"
+            if self.allowed_values is not None:
+                values_str = _format_seq_to_str(
+                    self.get_allowed_values(), last_sep="or", remove_type_text=True
+                )
+                msg += f"Allowed values are one of {values_str}. "
+            msg += f"But found {value}."
+            warnings.warn(msg, UserWarning, stacklevel=2)
+            return default_value if default_value is not None else self.default_value
+
+
+_CONFIG_REGISTRY: Dict[str, GlobalConfigParamSetting] = {
+ "print_changed_only": GlobalConfigParamSetting(
+ name="print_changed_only",
+ os_environ_name="SKBASE_PRINT_CHANGED_ONLY",
+ expected_type=bool,
+ allowed_values=(True, False),
+ default_value=True,
+ ),
+ "display": GlobalConfigParamSetting(
+ name="display",
+ os_environ_name="SKBASE_OBJECT_DISPLAY",
+ expected_type=str,
+ allowed_values=("text", "diagram"),
+ default_value="text",
+ ),
+}
+
+_DEFAULT_GLOBAL_CONFIG: Dict[str, Any] = {
+ config_name: config_info.default_value
+ for config_name, config_info in _CONFIG_REGISTRY.items()
+}
+
+global_config = _DEFAULT_GLOBAL_CONFIG.copy()
+_THREAD_LOCAL_DATA = threading.local()
+
+
+def _get_threadlocal_config() -> Dict[str, Any]:
+    """Get a threadlocal **mutable** configuration.
+
+    If the thread has no configuration yet, copy the module-level ``global_config``.
+
+    Returns
+    -------
+    dict
+        Threadlocal config, seeded on first access from ``global_config``.
+    """
+    if not hasattr(_THREAD_LOCAL_DATA, "global_config"):
+        _THREAD_LOCAL_DATA.global_config = global_config.copy()
+    return _THREAD_LOCAL_DATA.global_config
+
+
+def get_default_config() -> Dict[str, Any]:
+    """Retrieve the default global configuration.
+
+    Returns a new copy on every call; mutating it does not change the defaults.
+
+    Returns
+    -------
+    config : dict
+        The default configurable settings (keys) and their default values (values).
+
+    See Also
+    --------
+    config_context :
+        Configuration context manager.
+    get_config :
+        Retrieve current global configuration values.
+    set_config :
+        Set global configuration.
+    reset_config :
+        Reset configuration to ``skbase`` default.
+
+    Examples
+    --------
+    >>> from skbase.config import get_default_config
+    >>> get_default_config()
+    {'print_changed_only': True, 'display': 'text'}
+    """
+    return _DEFAULT_GLOBAL_CONFIG.copy()
+
+
+def get_config() -> Dict[str, Any]:
+    """Retrieve current values for configuration set by :meth:`set_config`.
+
+    Will return the default configuration if no updated configuration has
+    been set by :meth:`set_config`.
+
+    Returns
+    -------
+    config : dict
+        The configurable settings (keys) and their current values (values).
+
+    See Also
+    --------
+    config_context :
+        Configuration context manager.
+    get_default_config :
+        Retrieve ``skbase``'s default configuration.
+    set_config :
+        Set global configuration.
+    reset_config :
+        Reset configuration to ``skbase`` default.
+
+    Examples
+    --------
+    >>> from skbase.config import get_config
+    >>> get_config()
+    {'print_changed_only': True, 'display': 'text'}
+    """
+    return _get_threadlocal_config().copy()
+
+
+def set_config(
+    print_changed_only: Optional[bool] = None,
+    display: Optional[Literal["text", "diagram"]] = None,
+    local_threadsafe: bool = False,
+) -> None:
+    """Set global configuration.
+
+    Allows the ``skbase`` global configuration to be updated.
+
+    Parameters
+    ----------
+    print_changed_only : bool, default=None
+        If True, only the parameters that were set to non-default
+        values will be printed when printing a BaseObject instance. For example,
+        ``print(SVC())`` while True will only print 'SVC()', but would print
+        'SVC(C=1.0, cache_size=200, ...)' with all the non-changed parameters
+        when False. If None, the existing value won't change.
+    display : {'text', 'diagram'}, default=None
+        If 'diagram', instances inheriting from BaseObject will be displayed
+        as a diagram in a Jupyter lab or notebook context. If 'text', instances
+        inheriting from BaseObject will be displayed as text. If None, the
+        existing value won't change.
+    local_threadsafe : bool, default=False
+        If False, set the config as default for all threads.
+
+    Returns
+    -------
+    None
+        No output returned.
+
+    See Also
+    --------
+    config_context :
+        Configuration context manager.
+    get_default_config :
+        Retrieve ``skbase``'s default configuration.
+    get_config :
+        Retrieve current global configuration values.
+    reset_config :
+        Reset configuration to default.
+
+    Examples
+    --------
+    >>> from skbase.config import get_config, set_config
+    >>> get_config()
+    {'print_changed_only': True, 'display': 'text'}
+    >>> set_config(display='diagram')
+    >>> get_config()
+    {'print_changed_only': True, 'display': 'diagram'}
+    """
+    local_config = _get_threadlocal_config()
+
+    msg = "Attempting to set an invalid value for a global configuration.\n"
+    msg += "Using current configuration value of parameter as a result.\n"
+    if print_changed_only is not None:
+        local_config["print_changed_only"] = _CONFIG_REGISTRY[
+            "print_changed_only"
+        ].get_valid_param_or_default(
+            print_changed_only,
+            default_value=local_config["print_changed_only"],
+            msg=msg,
+        )
+    if display is not None:
+        local_config["display"] = _CONFIG_REGISTRY[
+            "display"
+        ].get_valid_param_or_default(
+            display, default_value=local_config["display"], msg=msg
+        )
+    # Publish to the module-level config, which seeds threads with no config yet.
+    if not local_threadsafe:
+        global_config.update(local_config)
+
+    return None
+
+
+def reset_config() -> None:
+    """Reset the global configuration to the default.
+
+    Overwrites every setting with its ``skbase`` default by passing the defaults to
+    :func:`set_config` (with its default ``local_threadsafe=False``).
+
+    Returns
+    -------
+    None
+        No output returned.
+
+    See Also
+    --------
+    config_context :
+        Configuration context manager.
+    get_default_config :
+        Retrieve ``skbase``'s default configuration.
+    get_config :
+        Retrieve current global configuration values.
+    set_config :
+        Set global configuration.
+
+    Examples
+    --------
+    >>> from skbase.config import get_config, set_config, reset_config
+    >>> get_config()
+    {'print_changed_only': True, 'display': 'text'}
+    >>> set_config(display='diagram')
+    >>> get_config()
+    {'print_changed_only': True, 'display': 'diagram'}
+    >>> reset_config()
+    >>> get_config()
+    {'print_changed_only': True, 'display': 'text'}
+    """
+    default_config = get_default_config()
+    set_config(**default_config)
+    return None
+
+
+@contextmanager
+def config_context(
+    print_changed_only: Optional[bool] = None,
+    display: Optional[Literal["text", "diagram"]] = None,
+    local_threadsafe: bool = False,
+) -> Iterator[None]:
+    """Context manager for global configuration.
+
+    Provides the ability to run code using different configuration without
+    having to update the global config.
+
+    Parameters
+    ----------
+    print_changed_only : bool, default=None
+        If True, only the parameters that were set to non-default
+        values will be printed when printing a BaseObject instance. For example,
+        ``print(SVC())`` while True will only print 'SVC()', but would print
+        'SVC(C=1.0, cache_size=200, ...)' with all the non-changed parameters
+        when False. If None, the existing value won't change.
+    display : {'text', 'diagram'}, default=None
+        If 'diagram', instances inheriting from BaseObject will be displayed
+        as a diagram in a Jupyter lab or notebook context. If 'text', instances
+        inheriting from BaseObject will be displayed as text. If None, the
+        existing value won't change.
+    local_threadsafe : bool, default=False
+        If False, set the config as default for all threads (restored on exit).
+
+    Yields
+    ------
+    None
+
+    See Also
+    --------
+    get_default_config :
+        Retrieve ``skbase``'s default configuration.
+    get_config :
+        Retrieve current values of the global configuration.
+    set_config :
+        Set global configuration.
+    reset_config :
+        Reset configuration to ``skbase`` default.
+
+    Notes
+    -----
+    All settings, not just those presently modified, will be returned to
+    their previous values when the context manager is exited.
+
+    Examples
+    --------
+    >>> from skbase.config import config_context
+    >>> with config_context(display='diagram'):
+    ...     pass
+    """
+    old_config = get_config()
+    set_config(
+        print_changed_only=print_changed_only,
+        display=display,
+        local_threadsafe=local_threadsafe,
+    )
+    # Restore in the same scope (thread-local or global) that was changed on entry.
+    try:
+        yield
+    finally:
+        set_config(**old_config, local_threadsafe=local_threadsafe)
diff --git a/skbase/config/tests/__init__.py b/skbase/config/tests/__init__.py
new file mode 100644
index 00000000..ac71ad0c
--- /dev/null
+++ b/skbase/config/tests/__init__.py
@@ -0,0 +1,4 @@
+#!/usr/bin/env python3 -u
+# -*- coding: utf-8 -*-
+# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
+"""Test functionality of :mod:`skbase.config`."""
diff --git a/skbase/config/tests/test_config.py b/skbase/config/tests/test_config.py
new file mode 100644
index 00000000..1af163e4
--- /dev/null
+++ b/skbase/config/tests/test_config.py
@@ -0,0 +1,143 @@
+# -*- coding: utf-8 -*-
+"""Test configuration functionality."""
+import pytest
+
+from skbase.config import (
+ GlobalConfigParamSetting,
+ config_context,
+ get_config,
+ get_default_config,
+ reset_config,
+ set_config,
+)
+from skbase.config._config import _CONFIG_REGISTRY, _DEFAULT_GLOBAL_CONFIG
+
+PRINT_CHANGE_ONLY_VALUES = _CONFIG_REGISTRY["print_changed_only"].get_allowed_values()
+DISPLAY_VALUES = _CONFIG_REGISTRY["display"].get_allowed_values()
+
+
+@pytest.fixture
+def config_registry():
+ """Config registry fixture."""
+ return _CONFIG_REGISTRY
+
+
+@pytest.fixture
+def global_config_default():
+    """Default global config fixture."""
+    return _DEFAULT_GLOBAL_CONFIG
+
+
+@pytest.mark.parametrize("allowed_values", (None, (), "something", range(1, 8)))
+def test_global_config_param_get_allowed_values(allowed_values):
+    """Test GlobalConfigParamSetting behavior works as expected."""
+    some_config_param = GlobalConfigParamSetting(
+        name="some_param",
+        os_environ_name="SKBASE_OBJECT_DISPLAY",
+        expected_type=str,
+        allowed_values=allowed_values,
+        default_value="text",
+    )
+    # Verify we always coerce output of get_allowed_values to list
+    values = some_config_param.get_allowed_values()
+    assert isinstance(values, list)
+
+
+@pytest.mark.parametrize("value", (None, (), "wrong_string", "text", range(1, 8)))
+def test_global_config_param_is_valid_param_value(value):
+ """Test GlobalConfigParamSetting behavior works as expected."""
+ some_config_param = GlobalConfigParamSetting(
+ name="some_param",
+ os_environ_name="SKBASE_OBJECT_DISPLAY",
+ expected_type=str,
+ allowed_values=("text", "diagram"),
+ default_value="text",
+ )
+ # Verify we correctly identify invalid parameters
+ if value in ("text", "diagram"):
+ expected_valid = True
+ else:
+ expected_valid = False
+ assert some_config_param.is_valid_param_value(value) == expected_valid
+
+
+def test_get_default_config(global_config_default):
+    """Test get_default_config always returns the default config."""
+    assert get_default_config() == global_config_default
+    with config_context(print_changed_only=False):
+        assert get_default_config() == global_config_default
+
+
+@pytest.mark.parametrize("print_changed_only", PRINT_CHANGE_ONLY_VALUES)
+@pytest.mark.parametrize("display", DISPLAY_VALUES)
+def test_set_config_then_get_config_returns_expected_value(print_changed_only, display):
+ """Verify that get_config returns set config values if set_config run."""
+ set_config(print_changed_only=print_changed_only, display=display)
+ retrieved_default = get_config()
+ expected_config = {"print_changed_only": print_changed_only, "display": display}
+ msg = "`get_config` used after `set_config` does not return expected values.\n"
+ msg += "After set_config is run, get_config should return the set values.\n "
+ msg += f"Expected {expected_config}, but returned {retrieved_default}."
+ assert retrieved_default == expected_config, msg
+
+
+@pytest.mark.parametrize("print_changed_only", PRINT_CHANGE_ONLY_VALUES)
+@pytest.mark.parametrize("display", DISPLAY_VALUES)
+def test_reset_config_resets_the_config(
+ print_changed_only, display, global_config_default
+):
+ """Verify that get_config returns default config if reset_config run."""
+ set_config(print_changed_only=print_changed_only, display=display)
+ reset_config()
+ retrieved_config = get_config()
+
+ msg = "`get_config` does not return expected values after `reset_config`.\n"
+ msg += "`After reset_config is run, get_config` should return defaults.\n"
+ msg += f"Expected {global_config_default}, but returned {retrieved_config}."
+ assert retrieved_config == global_config_default, msg
+ reset_config()
+
+
+@pytest.mark.parametrize("print_changed_only", PRINT_CHANGE_ONLY_VALUES)
+@pytest.mark.parametrize("display", DISPLAY_VALUES)
+def test_config_context(print_changed_only, display):
+    """Verify that config_context affects context but not overall configuration."""
+    # Make sure config is reset to default values then retrieve it
+    reset_config()
+    retrieved_config = get_config()
+    # Now lets make sure the config_context is changing the context of those values
+    # within the scope of the context manager as expected
+    # (no inner loop: it would shadow the parametrized ``print_changed_only``)
+    with config_context(print_changed_only=print_changed_only, display=display):
+        retrieved_context_config = get_config()
+    expected_config = {"print_changed_only": print_changed_only, "display": display}
+    msg = "`get_config` does not return expected values within `config_context`.\n"
+    msg += "`get_config` should return config defined by `config_context`.\n"
+    msg += f"Expected {expected_config}, but returned {retrieved_context_config}."
+    assert retrieved_context_config == expected_config, msg
+
+    # Outside of the config_context we should have not affected the retrieved config
+    # set by call to reset_config()
+    config_post_config_context = get_config()
+    msg = "`get_config` does not return expected values after `config_context`.\n"
+    msg += "`config_context` should not affect configuration outside its context.\n"
+    msg += f"Expected {retrieved_config}, but returned {config_post_config_context}."
+    assert retrieved_config == config_post_config_context, msg
+    reset_config()
+
+
+def test_set_config_behavior_invalid_value():
+    """Test set_config keeps current value and warns when setting invalid value."""
+    reset_config()
+    original_config = get_config().copy()
+    with pytest.warns(UserWarning, match=r"Attempting to set an invalid value.*"):
+        set_config(print_changed_only="False")
+
+    assert get_config() == original_config
+
+    original_config = get_config().copy()
+    with pytest.warns(UserWarning, match=r"Attempting to set an invalid value.*"):
+        set_config(print_changed_only=7)
+
+    assert get_config() == original_config
+    reset_config()
diff --git a/skbase/tests/conftest.py b/skbase/tests/conftest.py
index 4f36034a..f0cf89b1 100644
--- a/skbase/tests/conftest.py
+++ b/skbase/tests/conftest.py
@@ -22,7 +22,12 @@
"skbase.base",
"skbase.base._base",
"skbase.base._meta",
+ "skbase.base._pretty_printing",
+ "skbase.base._pretty_printing._object_html_repr",
+ "skbase.base._pretty_printing._pprint",
"skbase.base._tagmanager",
+ "skbase.config",
+ "skbase.config._config",
"skbase.lookup",
"skbase.lookup.tests",
"skbase.lookup.tests.test_lookup",
@@ -42,6 +47,7 @@
"skbase.tests.test_baseestimator",
"skbase.tests.mock_package.test_mock_package",
"skbase.utils",
+ "skbase.utils._check",
"skbase.utils._iter",
"skbase.utils._nested_iter",
"skbase.validate",
@@ -51,6 +57,7 @@
SKBASE_PUBLIC_MODULES = (
"skbase",
"skbase.base",
+ "skbase.config",
"skbase.lookup",
"skbase.lookup.tests",
"skbase.lookup.tests.test_lookup",
@@ -74,6 +81,9 @@
"skbase.base": ("BaseEstimator", "BaseMetaEstimator", "BaseObject"),
"skbase.base._base": ("BaseEstimator", "BaseObject"),
"skbase.base._meta": ("BaseMetaEstimator",),
+ "skbase.base._pretty_printing._pprint": ("KeyValTuple", "KeyValTupleParam"),
+ "skbase.config": ("GlobalConfigParamSetting",),
+ "skbase.config._config": ("GlobalConfigParamSetting",),
"skbase.lookup._lookup": ("ClassInfo", "FunctionInfo", "ModuleInfo"),
"skbase.testing": ("BaseFixtureGenerator", "QuickTester", "TestAllObjects"),
"skbase.testing.test_all_objects": (
@@ -85,11 +95,30 @@
SKBASE_CLASSES_BY_MODULE = SKBASE_PUBLIC_CLASSES_BY_MODULE.copy()
SKBASE_CLASSES_BY_MODULE.update(
{
- "skbase.base._meta": ("BaseMetaEstimator",),
+ "skbase.base._pretty_printing._object_html_repr": ("_VisualBlock",),
+ "skbase.base._pretty_printing._pprint": (
+ "KeyValTuple",
+ "KeyValTupleParam",
+ "_BaseObjectPrettyPrinter",
+ ),
"skbase.base._tagmanager": ("_FlagManager",),
}
)
SKBASE_PUBLIC_FUNCTIONS_BY_MODULE = {
+ "skbase.config": (
+ "get_config",
+ "get_default_config",
+ "set_config",
+ "reset_config",
+ "config_context",
+ ),
+ "skbase.config._config": (
+ "get_config",
+ "get_default_config",
+ "set_config",
+ "reset_config",
+ "config_context",
+ ),
"skbase.lookup": ("all_objects", "get_package_metadata"),
"skbase.lookup._lookup": ("all_objects", "get_package_metadata"),
"skbase.testing.utils._conditional_fixtures": (
@@ -124,6 +153,21 @@
SKBASE_FUNCTIONS_BY_MODULE = SKBASE_PUBLIC_FUNCTIONS_BY_MODULE.copy()
SKBASE_FUNCTIONS_BY_MODULE.update(
{
+ "skbase.base._pretty_printing._object_html_repr": (
+ "_get_visual_block",
+ "_object_html_repr",
+ "_write_base_object_html",
+ "_write_label_html",
+ ),
+ "skbase.base._pretty_printing._pprint": ("_changed_params", "_safe_repr"),
+ "skbase.config._config": (
+ "_get_threadlocal_config",
+ "get_config",
+ "get_default_config",
+ "set_config",
+ "reset_config",
+ "config_context",
+ ),
"skbase.lookup._lookup": (
"_determine_module_path",
"_get_return_tags",
@@ -155,6 +199,7 @@
"_coerce_list",
),
"skbase.testing.utils.inspect": ("_get_args",),
+ "skbase.utils._check": ("_is_scalar_nan",),
"skbase.utils._iter": (
"_format_seq_to_str",
"_remove_type_text",
diff --git a/skbase/tests/test_base.py b/skbase/tests/test_base.py
index 38d2d5a9..28d54d5a 100644
--- a/skbase/tests/test_base.py
+++ b/skbase/tests/test_base.py
@@ -69,16 +69,17 @@
import inspect
from copy import deepcopy
+from typing import Any, Dict, Type
import numpy as np
import pytest
import scipy.sparse as sp
-from sklearn import config_context
# TODO: Update with import of skbase clone function once implemented
from sklearn.base import clone
from skbase.base import BaseEstimator, BaseObject
+from skbase.config import config_context, get_config
from skbase.tests.conftest import Child, Parent
from skbase.tests.mock_package.test_mock_package import CompositionDummy
@@ -200,13 +201,13 @@ def fixture_reset_tester():
@pytest.fixture
-def fixture_class_child_tags(fixture_class_child):
+def fixture_class_child_tags(fixture_class_child: Type[Child]):
"""Pytest fixture for tags of Child."""
return fixture_class_child.get_class_tags()
@pytest.fixture
-def fixture_object_instance_set_tags(fixture_tag_class_object):
+def fixture_object_instance_set_tags(fixture_tag_class_object: Child):
"""Fixture class instance to test tag setting."""
fixture_tag_set = {"A": 42424243, "E": 3}
return fixture_tag_class_object.set_tags(**fixture_tag_set)
@@ -266,7 +267,9 @@ def fixture_class_instance_no_param_interface():
return NoParamInterface()
-def test_get_class_tags(fixture_class_child, fixture_class_child_tags):
+def test_get_class_tags(
+ fixture_class_child: Type[Child], fixture_class_child_tags: Any
+):
"""Test get_class_tags class method of BaseObject for correctness.
Raises
@@ -280,7 +283,7 @@ def test_get_class_tags(fixture_class_child, fixture_class_child_tags):
assert child_tags == fixture_class_child_tags, msg
-def test_get_class_tag(fixture_class_child, fixture_class_child_tags):
+def test_get_class_tag(fixture_class_child: Type[Child], fixture_class_child_tags: Any):
"""Test get_class_tag class method of BaseObject for correctness.
Raises
@@ -307,7 +310,7 @@ def test_get_class_tag(fixture_class_child, fixture_class_child_tags):
assert child_tag_default_none is None, msg
-def test_get_tags(fixture_tag_class_object, fixture_object_tags):
+def test_get_tags(fixture_tag_class_object: Child, fixture_object_tags: Dict[str, Any]):
"""Test get_tags method of BaseObject for correctness.
Raises
@@ -321,7 +324,7 @@ def test_get_tags(fixture_tag_class_object, fixture_object_tags):
assert object_tags == fixture_object_tags, msg
-def test_get_tag(fixture_tag_class_object, fixture_object_tags):
+def test_get_tag(fixture_tag_class_object: Child, fixture_object_tags: Dict[str, Any]):
"""Test get_tag method of BaseObject for correctness.
Raises
@@ -351,7 +354,7 @@ def test_get_tag(fixture_tag_class_object, fixture_object_tags):
assert object_tag_default_none is None, msg
-def test_get_tag_raises(fixture_tag_class_object):
+def test_get_tag_raises(fixture_tag_class_object: Child):
"""Test that get_tag method raises error for unknown tag.
Raises
@@ -363,9 +366,9 @@ def test_get_tag_raises(fixture_tag_class_object):
def test_set_tags(
- fixture_object_instance_set_tags,
- fixture_object_set_tags,
- fixture_object_dynamic_tags,
+ fixture_object_instance_set_tags: Any,
+ fixture_object_set_tags: Dict[str, Any],
+ fixture_object_dynamic_tags: Dict[str, int],
):
"""Test set_tags method of BaseObject for correctness.
@@ -381,7 +384,9 @@ def test_set_tags(
assert fixture_object_instance_set_tags.get_tags() == fixture_object_set_tags, msg
-def test_set_tags_works_with_missing_tags_dynamic_attribute(fixture_tag_class_object):
+def test_set_tags_works_with_missing_tags_dynamic_attribute(
+ fixture_tag_class_object: Child,
+):
"""Test set_tags will still work if _tags_dynamic is missing."""
base_obj = deepcopy(fixture_tag_class_object)
delattr(base_obj, "_tags_dynamic")
@@ -460,7 +465,7 @@ class AnotherTestClass(BaseObject):
assert test_obj_tags.get(tag) == another_base_obj_tags[tag]
-def test_is_composite(fixture_composition_dummy):
+def test_is_composite(fixture_composition_dummy: Type[CompositionDummy]):
"""Test is_composite tag for correctness.
Raises
@@ -474,7 +479,11 @@ def test_is_composite(fixture_composition_dummy):
assert composite.is_composite()
-def test_components(fixture_object, fixture_class_parent, fixture_composition_dummy):
+def test_components(
+ fixture_object: Type[BaseObject],
+ fixture_class_parent: Type[Parent],
+ fixture_composition_dummy: Type[CompositionDummy],
+):
"""Test component retrieval.
Raises
@@ -507,7 +516,7 @@ def test_components(fixture_object, fixture_class_parent, fixture_composition_du
def test_components_raises_error_base_class_is_not_class(
- fixture_object, fixture_composition_dummy
+ fixture_object: Type[BaseObject], fixture_composition_dummy: Type[CompositionDummy]
):
"""Test _component method raises error if base_class param is not class."""
non_composite = fixture_composition_dummy(foo=42)
@@ -526,7 +535,7 @@ def test_components_raises_error_base_class_is_not_class(
def test_components_raises_error_base_class_is_not_baseobject_subclass(
- fixture_composition_dummy,
+ fixture_composition_dummy: Type[CompositionDummy],
):
"""Test _component method raises error if base_class is not BaseObject subclass."""
@@ -540,7 +549,7 @@ class SomeClass:
# Test parameter interface (get_params, set_params, reset and related methods)
# Some tests of get_params and set_params are adapted from sklearn tests
-def test_reset(fixture_reset_tester):
+def test_reset(fixture_reset_tester: Type[ResetTester]):
"""Test reset method for correct behaviour, on a simple estimator.
Raises
@@ -567,7 +576,7 @@ def test_reset(fixture_reset_tester):
assert hasattr(x, "foo")
-def test_reset_composite(fixture_reset_tester):
+def test_reset_composite(fixture_reset_tester: Type[ResetTester]):
"""Test reset method for correct behaviour, on a composite estimator."""
y = fixture_reset_tester(42)
x = fixture_reset_tester(a=y)
@@ -582,7 +591,7 @@ def test_reset_composite(fixture_reset_tester):
assert not hasattr(x.a, "d")
-def test_get_init_signature(fixture_class_parent):
+def test_get_init_signature(fixture_class_parent: Type[Parent]):
"""Test error is raised when invalid init signature is used."""
init_sig = fixture_class_parent._get_init_signature()
init_sig_is_list = isinstance(init_sig, list)
@@ -594,14 +603,18 @@ def test_get_init_signature(fixture_class_parent):
), "`_get_init_signature` is not returning expected result."
-def test_get_init_signature_raises_error_for_invalid_signature(fixture_invalid_init):
+def test_get_init_signature_raises_error_for_invalid_signature(
+ fixture_invalid_init: Type[InvalidInitSignatureTester],
+):
"""Test error is raised when invalid init signature is used."""
with pytest.raises(RuntimeError):
fixture_invalid_init._get_init_signature()
def test_get_param_names(
- fixture_object, fixture_class_parent, fixture_class_parent_expected_params
+ fixture_object: Type[BaseObject],
+ fixture_class_parent: Type[Parent],
+ fixture_class_parent_expected_params: Dict[str, Any],
):
"""Test that get_param_names returns list of string parameter names."""
param_names = fixture_class_parent.get_param_names()
@@ -612,10 +625,10 @@ def test_get_param_names(
def test_get_params(
- fixture_class_parent,
- fixture_class_parent_expected_params,
- fixture_class_instance_no_param_interface,
- fixture_composition_dummy,
+ fixture_class_parent: Type[Parent],
+ fixture_class_parent_expected_params: Dict[str, Any],
+ fixture_class_instance_no_param_interface: NoParamInterface,
+ fixture_composition_dummy: Type[CompositionDummy],
):
"""Test get_params returns expected parameters."""
# Simple test of returned params
@@ -638,7 +651,10 @@ def test_get_params(
assert "foo" in params and "bar" in params and len(params) == 2
-def test_get_params_invariance(fixture_class_parent, fixture_composition_dummy):
+def test_get_params_invariance(
+ fixture_class_parent: Type[Parent],
+ fixture_composition_dummy: Type[CompositionDummy],
+):
"""Test that get_params(deep=False) is subset of get_params(deep=True)."""
composite = fixture_composition_dummy(foo=fixture_class_parent(), bar=84)
shallow_params = composite.get_params(deep=False)
@@ -646,7 +662,7 @@ def test_get_params_invariance(fixture_class_parent, fixture_composition_dummy):
assert all(item in deep_params.items() for item in shallow_params.items())
-def test_get_params_after_set_params(fixture_class_parent):
+def test_get_params_after_set_params(fixture_class_parent: Type[Parent]):
"""Test that get_params returns the same thing before and after set_params.
Based on scikit-learn check in check_estimator.
@@ -687,9 +703,9 @@ def test_get_params_after_set_params(fixture_class_parent):
def test_set_params(
- fixture_class_parent,
- fixture_class_parent_expected_params,
- fixture_composition_dummy,
+ fixture_class_parent: Type[Parent],
+ fixture_class_parent_expected_params: Dict[str, Any],
+ fixture_composition_dummy: Type[CompositionDummy],
):
"""Test set_params works as expected."""
# Simple case of setting a parameter
@@ -711,7 +727,8 @@ def test_set_params(
def test_set_params_raises_error_non_existent_param(
- fixture_class_parent_instance, fixture_composition_dummy
+ fixture_class_parent_instance: Parent,
+ fixture_composition_dummy: Type[CompositionDummy],
):
"""Test set_params raises an error when passed a non-existent parameter name."""
# non-existing parameter in svc
@@ -727,7 +744,8 @@ def test_set_params_raises_error_non_existent_param(
def test_set_params_raises_error_non_interface_composite(
- fixture_class_instance_no_param_interface, fixture_composition_dummy
+ fixture_class_instance_no_param_interface: NoParamInterface,
+ fixture_composition_dummy: Type[CompositionDummy],
):
"""Test set_params raises error when setting param of non-conforming composite."""
# When a composite is made up of a class that doesn't have the BaseObject
@@ -753,7 +771,9 @@ def __init__(self, param=5):
est.get_params()
-def test_set_params_with_no_param_to_set_returns_object(fixture_class_parent):
+def test_set_params_with_no_param_to_set_returns_object(
+ fixture_class_parent: Type[Parent],
+):
"""Test set_params correctly returns self when no parameters are set."""
base_obj = fixture_class_parent()
orig_params = deepcopy(base_obj.get_params())
@@ -767,7 +787,7 @@ def test_set_params_with_no_param_to_set_returns_object(fixture_class_parent):
# This section tests the clone functionality
# These have been adapted from sklearn's tests of clone to use the clone
# method that is included as part of the BaseObject interface
-def test_clone(fixture_class_parent_instance):
+def test_clone(fixture_class_parent_instance: Parent):
"""Test that clone is making a deep copy as expected."""
# Creates a BaseObject and makes a copy of its original state
# (which, in this case, is the current state of the BaseObject),
@@ -777,7 +797,7 @@ def test_clone(fixture_class_parent_instance):
assert fixture_class_parent_instance.get_params() == new_base_obj.get_params()
-def test_clone_2(fixture_class_parent_instance):
+def test_clone_2(fixture_class_parent_instance: Parent):
"""Test that clone does not copy attributes not set in constructor."""
# We first create an estimator, give it an own attribute, and
# make a copy of its original state. Then we check that the copy doesn't
@@ -790,7 +810,9 @@ def test_clone_2(fixture_class_parent_instance):
def test_clone_raises_error_for_nonconforming_objects(
- fixture_invalid_init, fixture_buggy, fixture_modify_param
+ fixture_invalid_init: Type[InvalidInitSignatureTester],
+ fixture_buggy: Type[Buggy],
+ fixture_modify_param: Type[ModifyParam],
):
"""Test that clone raises an error on nonconforming BaseObjects."""
buggy = fixture_buggy()
@@ -807,7 +829,7 @@ def test_clone_raises_error_for_nonconforming_objects(
obj_that_modifies.clone()
-def test_clone_param_is_none(fixture_class_parent):
+def test_clone_param_is_none(fixture_class_parent: Type[Parent]):
"""Test clone with keyword parameter set to None."""
base_obj = fixture_class_parent(c=None)
new_base_obj = clone(base_obj)
@@ -816,7 +838,7 @@ def test_clone_param_is_none(fixture_class_parent):
assert base_obj.c is new_base_obj2.c
-def test_clone_empty_array(fixture_class_parent):
+def test_clone_empty_array(fixture_class_parent: Type[Parent]):
"""Test clone with keyword parameter is scipy sparse matrix.
This test is based on scikit-learn regression test to make sure clone
@@ -830,7 +852,7 @@ def test_clone_empty_array(fixture_class_parent):
np.testing.assert_array_equal(base_obj.c, new_base_obj2.c)
-def test_clone_sparse_matrix(fixture_class_parent):
+def test_clone_sparse_matrix(fixture_class_parent: Type[Parent]):
"""Test clone with keyword parameter is scipy sparse matrix.
This test is based on scikit-learn regression test to make sure clone
@@ -843,7 +865,7 @@ def test_clone_sparse_matrix(fixture_class_parent):
np.testing.assert_array_equal(base_obj.c, new_base_obj2.c)
-def test_clone_nan(fixture_class_parent):
+def test_clone_nan(fixture_class_parent: Type[Parent]):
"""Test clone with keyword parameter is np.nan.
This test is based on scikit-learn regression test to make sure clone
@@ -858,7 +880,7 @@ def test_clone_nan(fixture_class_parent):
assert base_obj.c is new_base_obj2.c
-def test_clone_estimator_types(fixture_class_parent):
+def test_clone_estimator_types(fixture_class_parent: Type[Parent]):
"""Test clone works for parameters that are types rather than instances."""
base_obj = fixture_class_parent(c=fixture_class_parent)
new_base_obj = base_obj.clone()
@@ -866,7 +888,9 @@ def test_clone_estimator_types(fixture_class_parent):
assert base_obj.c == new_base_obj.c
-def test_clone_class_rather_than_instance_raises_error(fixture_class_parent):
+def test_clone_class_rather_than_instance_raises_error(
+ fixture_class_parent: Type[Parent],
+):
"""Test clone raises expected error when cloning a class instead of an instance."""
msg = "You should provide an instance of scikit-learn estimator"
with pytest.raises(TypeError, match=msg):
@@ -874,31 +898,77 @@ def test_clone_class_rather_than_instance_raises_error(fixture_class_parent):
# Tests of BaseObject pretty printing representation inspired by sklearn
-def test_baseobject_repr(fixture_class_parent, fixture_composition_dummy):
+def test_baseobject_repr(
+ fixture_class_parent: Type[Parent],
+ fixture_composition_dummy: Type[CompositionDummy],
+):
"""Test BaseObject repr works as expected."""
# Simple test where all parameters are left at defaults
# Should not see parameters and values in printed representation
+
base_obj = fixture_class_parent()
assert repr(base_obj) == "Parent()"
# Check that we can alter the detail about params that is printed
# using config_context with ``print_changed_only=False``
- with config_context(print_changed_only=False):
+ with config_context(print_changed_only=False, display="text"):
assert repr(base_obj) == "Parent(a='something', b=7, c=None)"
+ # Check that local config works as expected
+ base_obj.set_config(print_changed_only=False)
+ assert repr(base_obj) == "Parent(a='something', b=7, c=None)"
+
+ # Test with dict parameter (note that the dict is printed sorted by key,
+ # not in the order it was created)
+ base_obj = fixture_class_parent(c={"c": 1, "a": 2})
+ assert repr(base_obj) == "Parent(c={'a': 2, 'c': 1})"
+
+ # Now test when one param's values are named object tuples
+ named_objs = [
+ ("step 1", fixture_class_parent()),
+ ("step 2", fixture_class_parent()),
+ ]
+ base_obj = fixture_class_parent(c=named_objs)
+ assert repr(base_obj) == "Parent(c=[('step 1', Parent()), ('step 2', Parent())])"
+
+ # Or when they are just lists of tuples or just tuples as param
+ base_obj = fixture_class_parent(c=[("one", 1), ("two", 2)])
+ assert repr(base_obj) == "Parent(c=[('one', 1), ('two', 2)])"
+
+ base_obj = fixture_class_parent(c=(1, 2, 3))
+ assert repr(base_obj) == "Parent(c=(1, 2, 3))"
+
simple_composite = fixture_composition_dummy(foo=fixture_class_parent())
assert repr(simple_composite) == "CompositionDummy(foo=Parent())"
long_base_obj_repr = fixture_class_parent(a=["long_params"] * 1000)
assert len(repr(long_base_obj_repr)) == 535
+ named_objs = [(f"Step {i+1}", Child()) for i in range(25)]
+ base_comp = CompositionDummy(foo=Parent(c=Child(c=named_objs)))
+ assert len(repr(base_comp)) == 1362
+ with config_context(print_changed_only=False):
+ assert "..." in repr(base_comp)
+
-def test_baseobject_str(fixture_class_parent_instance):
+def test_baseobject_str(fixture_class_parent_instance: Parent):
"""Test BaseObject string representation works."""
- str(fixture_class_parent_instance)
+ assert (
+ str(fixture_class_parent_instance) == "Parent()"
+ ), "String representation of instance not working."
+ # Check that we can alter the detail about params that is printed
+ # using config_context with ``print_changed_only=False``
+ with config_context(print_changed_only=False, display="text"):
+ assert (
+ str(fixture_class_parent_instance) == "Parent(a='something', b=7, c=None)"
+ )
+ # Check that local config works as expected
+ fixture_class_parent_instance.set_config(print_changed_only=False)
+ assert str(fixture_class_parent_instance) == "Parent(a='something', b=7, c=None)"
-def test_baseobject_repr_mimebundle_(fixture_class_parent_instance):
+
+def test_baseobject_repr_mimebundle_(fixture_class_parent_instance: Parent):
"""Test display configuration controls output."""
# Checks the display configuration flag controls the json output
with config_context(display="diagram"):
@@ -912,7 +982,7 @@ def test_baseobject_repr_mimebundle_(fixture_class_parent_instance):
assert "text/html" not in output
-def test_repr_html_wraps(fixture_class_parent_instance):
+def test_repr_html_wraps(fixture_class_parent_instance: Parent):
"""Test display configuration flag controls the html output."""
with config_context(display="diagram"):
output = fixture_class_parent_instance._repr_html_()
@@ -925,20 +995,24 @@ def test_repr_html_wraps(fixture_class_parent_instance):
# Test BaseObject's ability to generate test instances
-def test_get_test_params(fixture_class_parent_instance):
+def test_get_test_params(fixture_class_parent_instance: Parent):
"""Test get_test_params returns empty dictionary."""
base_obj = fixture_class_parent_instance
test_params = base_obj.get_test_params()
assert isinstance(test_params, dict) and len(test_params) == 0
-def test_get_test_params_raises_error_when_params_required(fixture_required_param):
+def test_get_test_params_raises_error_when_params_required(
+ fixture_required_param: Type[RequiredParam],
+):
"""Test get_test_params raises an error when parameters are required."""
with pytest.raises(ValueError):
fixture_required_param(7).get_test_params()
-def test_create_test_instance(fixture_class_parent, fixture_class_parent_instance):
+def test_create_test_instance(
+ fixture_class_parent: Type[Parent], fixture_class_parent_instance: Parent
+):
"""Test first that create_test_instance logic works."""
base_obj = fixture_class_parent.create_test_instance()
@@ -957,7 +1031,7 @@ def test_create_test_instance(fixture_class_parent, fixture_class_parent_instanc
assert hasattr(base_obj, "_tags_dynamic"), msg
-def test_create_test_instances_and_names(fixture_class_parent_instance):
+def test_create_test_instances_and_names(fixture_class_parent_instance: Parent):
"""Test that create_test_instances_and_names works."""
base_objs, names = fixture_class_parent_instance.create_test_instances_and_names()
@@ -990,7 +1064,7 @@ def test_create_test_instances_and_names(fixture_class_parent_instance):
# Tests _has_implementation_of interface
def test_has_implementation_of(
- fixture_class_parent_instance, fixture_class_child_instance
+ fixture_class_parent_instance: Parent, fixture_class_child_instance: Child
):
"""Test _has_implementation_of detects methods in class with overrides in mro."""
# When the class overrides a parent classes method should return True
@@ -1014,33 +1088,182 @@ def __init__(self, a, b=42):
self.c = 84
-def test_set_get_config():
- """Test logic behind get_config, set_config.
+class AnotherConfigTester(BaseObject):
+ _config = {"print_changed_only": False, "bar": "a"}
- Raises
- ------
- AssertionError if logic behind get_config, set_config is incorrect, logic tested:
- calling get_fitted_params on a non-composite fittable returns the fitted param
- calling get_fitted_params on a composite returns all nested params
+ clsvar = 210
+
+ def __init__(self, a, b=42):
+ self.a = a
+ self.b = b
+ self.c = 84
+
+
+class ConfigExtensionInterfaceTester(BaseObject):
+ _config = {"print_changed_only": False, "bar": "a"}
+
+ clsvar = 210
+
+ def __init__(self, a, b=42):
+ self.a = a
+ self.b = b
+ self.c = 84
+
+ def __skbase_get_config__(self):
+ """Return get_config extension."""
+ return {"print_changed_only": True, "some_other_config": 70}
+
+
+def test_local_config_without_use_of_extension_interface():
+ """Test ``BaseObject().get_config`` and ``BaseObject().set_config``.
+
+ ``BaseObject.get_config()`` should return the global dict updated with any local
+ configs defined in ``BaseObject._config`` or set on the instance with
+ ``BaseObject().set_config()``.
"""
- obj = ConfigTester(4242)
+ # Initially test that we can retrieve the local config with local configs
+ # as defined in BaseObject._config
- config_start = obj.get_config()
- assert isinstance(config_start, dict)
- assert set(config_start.keys()) == {"foo_config", "bar"}
- assert config_start["foo_config"] == 42
- assert config_start["bar"] == "a"
+ # Case 1: local configs are not part of global config param names
+ obj = ConfigTester(4242)
+ obj_config = obj.get_config()
+ current_global_config = get_config().copy()
+ expected_global_config = set(current_global_config.keys())
+ assert isinstance(obj_config, dict)
+ assert set(obj_config.keys()) == expected_global_config | {"foo_config", "bar"}
+ assert obj_config["foo_config"] == 42
+ assert obj_config["bar"] == "a"
+ for param_name in current_global_config:
+ assert obj_config[param_name] == current_global_config[param_name]
+
+ # Case 2: local configs overlap with (will override) global params
+ obj = AnotherConfigTester(4242)
+ obj_config = obj.get_config()
+ current_global_config = get_config().copy()
+ expected_global_config = set(current_global_config.keys())
+ assert isinstance(obj_config, dict)
+ assert set(obj_config.keys()) == expected_global_config | {"bar"}
+ assert obj_config["bar"] == "a"
+ # Should have overridden global config value which is set to True
+ assert obj_config["print_changed_only"] is False
+ for param_name in current_global_config:
+ if param_name != "print_changed_only":
+ assert obj_config[param_name] == current_global_config[param_name]
+
+ # Case 3: local configs are not part of global config param names and we also
+ # make use of dynamic BaseObject.set_config()
+ obj = ConfigTester(4242)
+ current_global_config = get_config().copy()
+ expected_global_config = set(current_global_config.keys())
+ # Verify set config returns the original object
setconfig_return = obj.set_config(foobar=126)
assert obj is setconfig_return
obj.set_config(**{"bar": "b"})
- config_end = obj.get_config()
- assert isinstance(config_end, dict)
- assert set(config_end.keys()) == {"foo_config", "bar", "foobar"}
- assert config_end["foo_config"] == 42
- assert config_end["bar"] == "b"
- assert config_end["foobar"] == 126
+ updated_obj_config = obj.get_config()
+ assert isinstance(updated_obj_config, dict)
+ assert set(updated_obj_config.keys()) == (
+ expected_global_config | {"foo_config", "bar", "foobar"}
+ )
+ assert updated_obj_config["foo_config"] == 42
+ assert updated_obj_config["bar"] == "b"
+ assert updated_obj_config["foobar"] == 126
+
+ # Case 4: local configs are not part of global config param names and we also
+ # make use of dynamic BaseObject.set_config() to update a config that is also
+ # part of global config
+ obj = ConfigTester(4242)
+ current_global_config = get_config().copy()
+ expected_global_config = set(current_global_config.keys())
+
+ # Verify set config returns the original object
+ setconfig_return = obj.set_config(print_changed_only=False)
+ assert obj is setconfig_return
+
+ updated_obj_config = obj.get_config()
+ assert isinstance(updated_obj_config, dict)
+ assert set(updated_obj_config.keys()) == (
+ expected_global_config | {"foo_config", "bar"}
+ )
+ assert updated_obj_config["foo_config"] == 42
+ assert updated_obj_config["bar"] == "a"
+ assert updated_obj_config["print_changed_only"] is False
+ for param_name in current_global_config:
+ if param_name != "print_changed_only":
+ assert updated_obj_config[param_name] == current_global_config[param_name]
+
+ # Case 5: local configs overlap with (will override) global params
+ # Then the local config defined in AnotherConfigTester._config is overridden again
+ # by calling AnotherConfigTester().set_config()
+ obj = AnotherConfigTester(4242)
+ obj.set_config(print_changed_only=True)
+ obj_config = obj.get_config()
+ current_global_config = get_config().copy()
+ expected_global_config = set(current_global_config.keys())
+ assert isinstance(obj_config, dict)
+ assert set(obj_config.keys()) == expected_global_config | {"bar"}
+ assert obj_config["bar"] == "a"
+ # set_config should have overridden the class-level _config value of False
+ assert obj_config["print_changed_only"] is True
+ for param_name in current_global_config:
+ if param_name != "print_changed_only":
+ assert obj_config[param_name] == current_global_config[param_name]
+
+
+def test_local_config_with_use_of_extension_interface():
+ """Test BaseObject local config interface when ``__skbase_get_config__`` defined.
+
+ BaseObject.get_config() should return the global dict updated in the following
+ order:
+
+ - Any config returned by ``BaseObject.__skbase_get_config__``.
+ - Any configs defined in ``BaseObject._config`` or set on the instance with
+ ``BaseObject().set_config()``.
+ """
+ current_global_config = get_config().copy()
+ obj = ConfigExtensionInterfaceTester(4242)
+ obj_config = obj.get_config()
+ assert "some_other_config" in obj_config
+ assert obj_config["print_changed_only"] is False
+
+ expected_global_config = set(current_global_config.keys())
+ assert isinstance(obj_config, dict)
+ assert set(obj_config.keys()) == expected_global_config | {
+ "some_other_config",
+ "bar",
+ }
+ assert obj_config["bar"] == "a"
+ # print_changed_only is overridden locally (asserted above), so skip it below
+
+ for param_name in current_global_config:
+ if param_name != "print_changed_only":
+ assert obj_config[param_name] == current_global_config[param_name]
+
+ # Now let's verify we can override the config items only returned by
+ # __skbase_get_config__ extension interface
+ obj.set_config(some_other_config=22)
+ obj_config_updated = obj.get_config()
+ assert obj_config_updated["some_other_config"] == 22
+ for param_name in obj_config:
+ if param_name != "some_other_config":
+ assert obj_config_updated[param_name] == obj_config[param_name]
+
+
+def test_local_config_interface_does_not_affect_global_config_interface():
+ """Test that calls to instance config interface don't impact global config."""
+ from skbase.config import get_default_config, global_config
+
+ obj = AnotherConfigTester(4242)
+ _global_config_before = global_config.copy()
+ _default_config_before = get_default_config()
+ global_config_before = get_config().copy()
+ obj.set_config(print_changed_only=False, some_other_param=7)
+ global_config_after = get_config().copy()
+ assert global_config_before == global_config_after
+ assert "some_other_param" not in global_config_after
+ assert _global_config_before == global_config
+ assert _default_config_before == get_default_config()
class FittableCompositionDummy(BaseEstimator):
diff --git a/skbase/utils/_check.py b/skbase/utils/_check.py
new file mode 100644
index 00000000..7a6830c4
--- /dev/null
+++ b/skbase/utils/_check.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
+# Elements of _is_scalar_nan re-use code developed in scikit-learn. These elements
+# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For
+# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
+
+"""Utility functions to perform various types of checks."""
+from __future__ import annotations
+
+import math
+import numbers
+from typing import Any
+
+__all__ = ["_is_scalar_nan"]
+__author__ = ["RNKuhns"]
+
+
+def _is_scalar_nan(x: Any) -> bool:
+ """Test if x is NaN.
+
+ This function is meant to overcome the issue that np.isnan does not allow
+ non-numerical types as input, and that np.nan is not float('nan').
+
+ Parameters
+ ----------
+ x : Any
+ The item to be checked to determine if it is a scalar nan value.
+
+ Returns
+ -------
+ bool
+ True if `x` is a scalar nan value, False otherwise.
+
+ Notes
+ -----
+ This code follows scikit-learn's implementation.
+
+ Examples
+ --------
+ >>> import numpy as np
+ >>> from skbase.utils._check import _is_scalar_nan
+ >>> _is_scalar_nan(np.nan)
+ True
+ >>> _is_scalar_nan(float("nan"))
+ True
+ >>> _is_scalar_nan(None)
+ False
+ >>> _is_scalar_nan("")
+ False
+ >>> _is_scalar_nan([np.nan])
+ False
+ """
+ return isinstance(x, numbers.Real) and math.isnan(x)
diff --git a/skbase/utils/tests/test_check.py b/skbase/utils/tests/test_check.py
new file mode 100644
index 00000000..f998bae4
--- /dev/null
+++ b/skbase/utils/tests/test_check.py
@@ -0,0 +1,24 @@
+#!/usr/bin/env python3 -u
+# -*- coding: utf-8 -*-
+# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
+"""Tests of the utility functionality for performing various checks.
+
+Tests in this module include:
+
+- test_is_scalar_nan_output to verify _is_scalar_nan outputs expected value for
+ different inputs.
+"""
+import numpy as np
+
+from skbase.utils._check import _is_scalar_nan
+
+__author__ = ["RNKuhns"]
+
+
+def test_is_scalar_nan_output():
+ """Test that _is_scalar_nan outputs expected value for different inputs."""
+ assert _is_scalar_nan(np.nan) is True
+ assert _is_scalar_nan(float("nan")) is True
+ assert _is_scalar_nan(None) is False
+ assert _is_scalar_nan("") is False
+ assert _is_scalar_nan([np.nan]) is False