diff --git a/docs/source/api_reference.rst b/docs/source/api_reference.rst index 03026bdf..9508d39e 100644 --- a/docs/source/api_reference.rst +++ b/docs/source/api_reference.rst @@ -28,6 +28,27 @@ Base Classes BaseObject BaseEstimator +.. _global_config: + +Configure ``skbase`` +==================== + +.. automodule:: skbase.config + :no-members: + :no-inherited-members: + +.. currentmodule:: skbase.config + +.. autosummary:: + :toctree: api_reference/auto_generated/ + :template: function.rst + + get_config + get_default_config + set_config + reset_config + config_context + .. _obj_retrieval: Object Retrieval diff --git a/docs/source/user_documentation/user_guide.rst b/docs/source/user_documentation/user_guide.rst index c8febd31..4e15f3ec 100644 --- a/docs/source/user_documentation/user_guide.rst +++ b/docs/source/user_documentation/user_guide.rst @@ -29,6 +29,7 @@ that ``skbase`` provides, see the :ref:`api_ref`. user_guide/lookup user_guide/validate user_guide/testing + user_guide/configuration .. grid:: 1 2 2 2 @@ -103,3 +104,17 @@ that ``skbase`` provides, see the :ref:`api_ref`. :expand: Testing + + .. grid-item-card:: Configuration + :text-align: center + + Configure ``skbase``. + + +++ + + .. button-ref:: user_guide/configuration + :color: primary + :click-parent: + :expand: + + Configuration diff --git a/docs/source/user_documentation/user_guide/configuration.rst b/docs/source/user_documentation/user_guide/configuration.rst new file mode 100644 index 00000000..05b76843 --- /dev/null +++ b/docs/source/user_documentation/user_guide/configuration.rst @@ -0,0 +1,11 @@ +.. _user_guide_global_config: + +==================== +Configure ``skbase`` +==================== + +.. note:: + + The user guide is under development. We have created a basic + structure and are looking for contributions to develop the user guide + further. 
diff --git a/skbase/base/_base.py b/skbase/base/_base.py index 970bf350..1393e267 100644 --- a/skbase/base/_base.py +++ b/skbase/base/_base.py @@ -53,6 +53,7 @@ class name: BaseEstimator fitted state check - check_is_fitted (raises error if not is_fitted) """ import inspect +import re import warnings from collections import defaultdict from copy import deepcopy @@ -62,7 +63,10 @@ class name: BaseEstimator from sklearn.base import BaseEstimator as _BaseEstimator from skbase._exceptions import NotFittedError +from skbase.base._pretty_printing._object_html_repr import _object_html_repr from skbase.base._tagmanager import _FlagManager +from skbase.config import get_config +from skbase.config._config import _CONFIG_REGISTRY __author__: List[str] = ["mloning", "RNKuhns", "fkiraly"] __all__: List[str] = ["BaseEstimator", "BaseObject"] @@ -446,7 +450,39 @@ def get_config(self): class attribute via nested inheritance and then any overrides and new tags from _onfig_dynamic object attribute. """ - return self._get_flags(flag_attr_name="_config") + config = get_config().copy() + + # Get any extension configuration interface defined in the class + # for example if downstream package wants to extend skbase to retrieve + # their own config + if hasattr(self, "__skbase_get_config__") and callable( + self.__skbase_get_config__ + ): + skbase_get_config_extension_dict = self.__skbase_get_config__() + else: + skbase_get_config_extension_dict = {} + if isinstance(skbase_get_config_extension_dict, dict): + config.update(skbase_get_config_extension_dict) + else: + msg = "Use of `__skbase_get_config__` to extend the interface for local " + msg += "overrides of the global configuration must return a dictionary.\n" + msg += f"But a {type(skbase_get_config_extension_dict)} was found." 
+ warnings.warn(msg, UserWarning, stacklevel=2) + local_config = self._get_flags(flag_attr_name="_config").copy() + # Validate local overrides of global params; fall back to the global value + for config_param, config_value in local_config.items(): + if config_param in _CONFIG_REGISTRY: + msg = "Invalid value encountered for global configuration parameter " + msg += f"{config_param}. Using global parameter configuration value.\n" + config_value = _CONFIG_REGISTRY[ + config_param + ].get_valid_param_or_default( + config_value, default_value=config[config_param] + ) + local_config[config_param] = config_value + config.update(local_config) + + return config def set_config(self, **config_dict): """Set config flags to given values. @@ -682,6 +718,98 @@ def _components(self, base_class=None): return comp_dict + def __repr__(self, n_char_max: int = 700): + """Represent class as string. + + This follows the scikit-learn implementation for the string representation + of parameterized objects. + + Parameters + ---------- + n_char_max : int + Maximum (approximate) number of non-blank characters to render. This + can be useful in testing. 
+ """ + from skbase.base._pretty_printing._pprint import _BaseObjectPrettyPrinter + + n_max_elements_to_show = 30 # number of elements to show in sequences + # use ellipsis for sequences with a lot of elements + pp = _BaseObjectPrettyPrinter( + compact=True, + indent=1, + indent_at_name=True, + n_max_elements_to_show=n_max_elements_to_show, + changed_only=self.get_config()["print_changed_only"], + ) + + repr_ = pp.pformat(self) + + # Use bruteforce ellipsis when there are a lot of non-blank characters + n_nonblank = len("".join(repr_.split())) + if n_nonblank > n_char_max: + lim = n_char_max // 2 # apprx number of chars to keep on both ends + regex = r"^(\s*\S){%d}" % lim + # The regex '^(\s*\S){%d}' matches from the start of the string + # until the nth non-blank character: + # - ^ matches the start of string + # - (pattern){n} matches n repetitions of pattern + # - \s*\S matches a non-blank char following zero or more blanks + left_match = re.match(regex, repr_) + right_match = re.match(regex, repr_[::-1]) + left_lim = left_match.end() if left_match is not None else 0 + right_lim = right_match.end() if right_match is not None else 0 + + if "\n" in repr_[left_lim:-right_lim]: + # The left side and right side aren't on the same line. + # To avoid weird cuts, e.g.: + # categoric...ore', + # we need to start the right side with an appropriate newline + # character so that it renders properly as: + # categoric... + # handle_unknown='ignore', + # so we add [^\n]*\n which matches until the next \n + regex += r"[^\n]*\n" + right_match = re.match(regex, repr_[::-1]) + right_lim = right_match.end() if right_match is not None else 0 + + ellipsis = "..." + if left_lim + len(ellipsis) < len(repr_) - right_lim: + # Only add ellipsis if it results in a shorter repr + repr_ = repr_[:left_lim] + "..." + repr_[-right_lim:] + + return repr_ + + @property + def _repr_html_(self): + """HTML representation of BaseObject. + + This is redundant with the logic of `_repr_mimebundle_`. 
The latter + should be favored in the long term, `_repr_html_` is only + implemented for consumers who do not interpret `_repr_mimebundle_`. + """ + if self.get_config()["display"] != "diagram": + raise AttributeError( + "_repr_html_ is only defined when the " + "`display` configuration option is set to 'diagram'." + ) + return self._repr_html_inner + + def _repr_html_inner(self): + """Return HTML representation of class. + + This function is returned by the @property `_repr_html_` to make + `hasattr(BaseObject, "_repr_html_")` return `True` or `False` depending + on `self.get_config()["display"]`. + """ + return _object_html_repr(self) + + def _repr_mimebundle_(self, **kwargs): + """Mime bundle used by jupyter kernels to display instances of BaseObject.""" + output = {"text/plain": repr(self)} + if self.get_config()["display"] == "diagram": + output["text/html"] = _object_html_repr(self) + return output + class TagAliaserMixin: """Mixin class for tag aliasing and deprecation of old tags. diff --git a/skbase/base/_pretty_printing/__init__.py b/skbase/base/_pretty_printing/__init__.py new file mode 100644 index 00000000..669c800c --- /dev/null +++ b/skbase/base/_pretty_printing/__init__.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python3 -u +# -*- coding: utf-8 -*- +# copyright: skbase developers, BSD-3-Clause License (see LICENSE file) +# Many elements of this code were developed in scikit-learn. These elements +# are copyrighted by the scikit-learn developers, BSD-3-Clause License. 
For +# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING +"""Functionality for pretty printing BaseObjects.""" +from typing import List + +__author__: List[str] = ["RNKuhns"] +__all__: List[str] = [] diff --git a/skbase/base/_pretty_printing/_object_html_repr.py b/skbase/base/_pretty_printing/_object_html_repr.py new file mode 100644 index 00000000..16515b9d --- /dev/null +++ b/skbase/base/_pretty_printing/_object_html_repr.py @@ -0,0 +1,400 @@ +# -*- coding: utf-8 -*- +# copyright: skbase developers, BSD-3-Clause License (see LICENSE file) +# Many elements of this code were developed in scikit-learn. These elements +# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For +# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING +"""Functionality to represent instance of BaseObject as html.""" + +import html +import uuid +from contextlib import closing, suppress +from io import StringIO +from string import Template + +from skbase.config import config_context # type: ignore + +__author__ = ["RNKuhns"] + + +class _VisualBlock: + """HTML Representation of BaseObject. + + Parameters + ---------- + kind : {'serial', 'parallel', 'single'} + kind of HTML block + + objs : list of BaseObjects or `_VisualBlock`s or a single BaseObject + If kind != 'single', then `objs` is a list of + BaseObjects. If kind == 'single', then `objs` is a single BaseObject. + + names : list of str, default=None + If kind != 'single', then `names` corresponds to BaseObjects. + If kind == 'single', then `names` is a single string corresponding to + the single BaseObject. + + name_details : list of str, str, or None, default=None + If kind != 'single', then `name_details` corresponds to `names`. + If kind == 'single', then `name_details` is a single string + corresponding to the single BaseObject. + + dash_wrapped : bool, default=True + If true, wrapped HTML element will be wrapped with a dashed border. 
+ Only active when kind != 'single'. + """ + + def __init__(self, kind, objs, *, names=None, name_details=None, dash_wrapped=True): + self.kind = kind + self.objs = objs + self.dash_wrapped = dash_wrapped + + if self.kind in ("parallel", "serial"): + if names is None: + names = (None,) * len(objs) + if name_details is None: + name_details = (None,) * len(objs) + + self.names = names + self.name_details = name_details + + def _sk_visual_block_(self): + return self + + +def _write_label_html( + out, + name, + name_details, + outer_class="sk-label-container", + inner_class="sk-label", + checked=False, +): + """Write labeled html with or without a dropdown with named details.""" + out.write(f'
') + name = html.escape(name) + + if name_details is not None: + name_details = html.escape(str(name_details)) + label_class = "sk-toggleable__label sk-toggleable__label-arrow" + + checked_str = "checked" if checked else "" + est_id = uuid.uuid4() + out.write( + '' + f"" + f'
{name_details}'
+            "
" + ) + else: + out.write(f"") + out.write("
") # outer_class inner_class + + +def _get_visual_block(base_object): + """Generate information about how to display a BaseObject.""" + with suppress(AttributeError): + return base_object._sk_visual_block_() + + if isinstance(base_object, str): + return _VisualBlock( + "single", base_object, names=base_object, name_details=base_object + ) + elif base_object is None: + return _VisualBlock("single", base_object, names="None", name_details="None") + + # check if base_object looks like a meta base_object wraps base_object + if hasattr(base_object, "get_params"): + base_objects = [] + for key, value in base_object.get_params().items(): + # Only look at the BaseObjects in the first layer + if "__" not in key and hasattr(value, "get_params"): + base_objects.append(value) + if len(base_objects): + return _VisualBlock("parallel", base_objects, names=None) + + return _VisualBlock( + "single", + base_object, + names=base_object.__class__.__name__, + name_details=str(base_object), + ) + + +def _write_base_object_html( + out, base_object, base_object_label, base_object_label_details, first_call=False +): + """Write BaseObject to html in serial, parallel, or by itself (single).""" + if first_call: + est_block = _get_visual_block(base_object) + else: + # So it is easier to read, always use print_changed_only==True + # regardless of configuration + with config_context(print_changed_only=True): + est_block = _get_visual_block(base_object) + + if est_block.kind in ("serial", "parallel"): + dashed_wrapped = first_call or est_block.dash_wrapped + dash_cls = " sk-dashed-wrapped" if dashed_wrapped else "" + out.write(f'
') + + if base_object_label: + _write_label_html(out, base_object_label, base_object_label_details) + + kind = est_block.kind + out.write(f'
') + est_infos = zip(est_block.objs, est_block.names, est_block.name_details) + + for est, name, name_details in est_infos: + if kind == "serial": + _write_base_object_html(out, est, name, name_details) + else: # parallel + out.write('
') + # wrap element in a serial visualblock + serial_block = _VisualBlock("serial", [est], dash_wrapped=False) + _write_base_object_html(out, serial_block, name, name_details) + out.write("
") # sk-parallel-item + + out.write("
") + elif est_block.kind == "single": + _write_label_html( + out, + est_block.names, + est_block.name_details, + outer_class="sk-item", + inner_class="sk-estimator", + checked=first_call, + ) + + +_STYLE = """ +#$id { + color: black; + background-color: white; +} +#$id pre{ + padding: 0; +} +#$id div.sk-toggleable { + background-color: white; +} +#$id label.sk-toggleable__label { + cursor: pointer; + display: block; + width: 100%; + margin-bottom: 0; + padding: 0.3em; + box-sizing: border-box; + text-align: center; +} +#$id label.sk-toggleable__label-arrow:before { + content: "▸"; + float: left; + margin-right: 0.25em; + color: #696969; +} +#$id label.sk-toggleable__label-arrow:hover:before { + color: black; +} +#$id div.sk-estimator:hover label.sk-toggleable__label-arrow:before { + color: black; +} +#$id div.sk-toggleable__content { + max-height: 0; + max-width: 0; + overflow: hidden; + text-align: left; + background-color: #f0f8ff; +} +#$id div.sk-toggleable__content pre { + margin: 0.2em; + color: black; + border-radius: 0.25em; + background-color: #f0f8ff; +} +#$id input.sk-toggleable__control:checked~div.sk-toggleable__content { + max-height: 200px; + max-width: 100%; + overflow: auto; +} +#$id input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before { + content: "▾"; +} +#$id div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label { + background-color: #d4ebff; +} +#$id div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label { + background-color: #d4ebff; +} +#$id input.sk-hidden--visually { + border: 0; + clip: rect(1px 1px 1px 1px); + clip: rect(1px, 1px, 1px, 1px); + height: 1px; + margin: -1px; + overflow: hidden; + padding: 0; + position: absolute; + width: 1px; +} +#$id div.sk-estimator { + font-family: monospace; + background-color: #f0f8ff; + border: 1px dotted black; + border-radius: 0.25em; + box-sizing: border-box; + margin-bottom: 0.5em; +} +#$id div.sk-estimator:hover { + 
background-color: #d4ebff; +} +#$id div.sk-parallel-item::after { + content: ""; + width: 100%; + border-bottom: 1px solid gray; + flex-grow: 1; +} +#$id div.sk-label:hover label.sk-toggleable__label { + background-color: #d4ebff; +} +#$id div.sk-serial::before { + content: ""; + position: absolute; + border-left: 1px solid gray; + box-sizing: border-box; + top: 2em; + bottom: 0; + left: 50%; +} +#$id div.sk-serial { + display: flex; + flex-direction: column; + align-items: center; + background-color: white; + padding-right: 0.2em; + padding-left: 0.2em; +} +#$id div.sk-item { + z-index: 1; +} +#$id div.sk-parallel { + display: flex; + align-items: stretch; + justify-content: center; + background-color: white; +} +#$id div.sk-parallel::before { + content: ""; + position: absolute; + border-left: 1px solid gray; + box-sizing: border-box; + top: 2em; + bottom: 0; + left: 50%; +} +#$id div.sk-parallel-item { + display: flex; + flex-direction: column; + position: relative; + background-color: white; +} +#$id div.sk-parallel-item:first-child::after { + align-self: flex-end; + width: 50%; +} +#$id div.sk-parallel-item:last-child::after { + align-self: flex-start; + width: 50%; +} +#$id div.sk-parallel-item:only-child::after { + width: 0; +} +#$id div.sk-dashed-wrapped { + border: 1px dashed gray; + margin: 0 0.4em 0.5em 0.4em; + box-sizing: border-box; + padding-bottom: 0.4em; + background-color: white; + position: relative; +} +#$id div.sk-label label { + font-family: monospace; + font-weight: bold; + background-color: white; + display: inline-block; + line-height: 1.2em; +} +#$id div.sk-label-container { + position: relative; + z-index: 2; + text-align: center; +} +#$id div.sk-container { + /* jupyter's `normalize.less` sets `[hidden] { display: none; }` + but bootstrap.min.css set `[hidden] { display: none !important; }` + so we also need the `!important` here to be able to override the + default hidden behavior on the sphinx rendered scikit-learn.org. 
+ See: https://github.com/scikit-learn/scikit-learn/issues/21755 */ + display: inline-block !important; + position: relative; +} +#$id div.sk-text-repr-fallback { + display: none; +} +""".replace( + " ", "" +).replace( + "\n", "" +) # noqa + + +def _object_html_repr(base_object): + """Build an HTML representation of a BaseObject. + + Parameters + ---------- + base_object : base object + The BaseObject or inheriting class to visualize. + + Returns + ------- + html: str + HTML representation of BaseObject. + """ + with closing(StringIO()) as out: + container_id = "sk-" + str(uuid.uuid4()) + style_template = Template(_STYLE) + style_with_id = style_template.substitute(id=container_id) + base_object_str = str(base_object) + + # The fallback message is shown by default and loading the CSS sets + # div.sk-text-repr-fallback to display: none to hide the fallback message. + # + # If the notebook is trusted, the CSS is loaded which hides the fallback + # message. If the notebook is not trusted, then the CSS is not loaded and the + # fallback message is shown by default. + # + # The reverse logic applies to HTML repr div.sk-container. + # div.sk-container is hidden by default and loading the CSS displays it. + fallback_msg = ( + "Please rerun this cell to show the HTML repr or trust the notebook." + ) + out.write( + f"" + f'
' + '
' + f"
{html.escape(base_object_str)}
{fallback_msg}" + "
" + '
") + + html_output = out.getvalue() + return html_output diff --git a/skbase/base/_pretty_printing/_pprint.py b/skbase/base/_pretty_printing/_pprint.py new file mode 100644 index 00000000..2e8a8a80 --- /dev/null +++ b/skbase/base/_pretty_printing/_pprint.py @@ -0,0 +1,412 @@ +# -*- coding: utf-8 -*- +# copyright: skbase developers, BSD-3-Clause License (see LICENSE file) +# Many elements of this code were developed in scikit-learn. These elements +# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For +# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING +"""Utility functionality for pretty-printing objects used in BaseObject.__repr__.""" +import inspect +import pprint +from collections import OrderedDict + +from skbase.base import BaseObject + +# from skbase.config import get_config +from skbase.utils._check import _is_scalar_nan + + +class KeyValTuple(tuple): + """Dummy class for correctly rendering key-value tuples from dicts.""" + + def __repr__(self): + """Represent as string.""" + # needed for _dispatch[tuple.__repr__] not to be overridden + return super().__repr__() + + +class KeyValTupleParam(KeyValTuple): + """Dummy class for correctly rendering key-value tuples from parameters.""" + + pass + + +def _changed_params(base_object): + """Return dict (param_name: value) of parameters with non-default values.""" + params = base_object.get_params(deep=False) + init_func = getattr( + base_object.__init__, "deprecated_original", base_object.__init__ + ) + init_params = inspect.signature(init_func).parameters + init_params = {name: param.default for name, param in init_params.items()} + + def has_changed(k, v): + if k not in init_params: # happens if k is part of a **kwargs + return True + if init_params[k] == inspect._empty: # k has no default value + return True + # try to avoid calling repr on nested BaseObjects + if isinstance(v, BaseObject) and v.__class__ != init_params[k].__class__: + return True + # Use repr as 
a last resort. It may be expensive. + if repr(v) != repr(init_params[k]) and not ( + _is_scalar_nan(init_params[k]) and _is_scalar_nan(v) + ): + return True + return False + + return {k: v for k, v in params.items() if has_changed(k, v)} + + +class _BaseObjectPrettyPrinter(pprint.PrettyPrinter): + """Pretty Printer class for BaseObjects. + + This extends the pprint.PrettyPrinter class similar to scikit-learn's + implementation, so that: + + - BaseObjects are printed with their parameters, e.g. + BaseObject(param1=value1, ...) which is not supported by default. + - the 'compact' parameter of PrettyPrinter is ignored for dicts, which + may lead to very long representations that we want to avoid. + + Quick overview of pprint.PrettyPrinter (see also + https://stackoverflow.com/questions/49565047/pprint-with-hex-numbers): + + - the entry point is the _format() method which calls format() (overridden + here) + - format() directly calls _safe_repr() for a first try at rendering the + object + - _safe_repr formats the whole object recursively, only calling itself, + not caring about line length or anything + - back to _format(), if the output string is too long, _format() then calls + the appropriate _pprint_TYPE() method (e.g. _pprint_list()) depending on + the type of the object. This is where the line length and the compact + parameters are taken into account. + - those _pprint_TYPE() methods will internally use the format() method for + rendering the nested objects of an object (e.g. the elements of a list) + + In the end, everything has to be implemented twice: in _safe_repr and in + the custom _pprint_TYPE methods. Unfortunately PrettyPrinter is really not + straightforward to extend (especially when we want a compact output), so + the code is a bit convoluted. 
+ + This class overrides: + - format() to support the changed_only parameter + - _safe_repr to support printing of BaseObjects that fit on a single line + - _format_dict_items so that dicts are correctly 'compacted' + - _format_items so that ellipsis is used on long lists and tuples + + When BaseObjects cannot be printed on a single line, the builtin _format() + will call _pprint_object() because it was registered to do so (see + _dispatch[BaseObject.__repr__] = _pprint_object). + + Both _format_dict_items() and _pprint_object() use the + _format_params_or_dict_items() method that will format parameters and + key-value pairs respecting the compact parameter. This method needs another + subroutine _pprint_key_val_tuple() used when a parameter or a key-value + pair is too long to fit on a single line. This subroutine is called in + _format() and is registered as well in the _dispatch dict (just like + _pprint_object). We had to create the two classes KeyValTuple and + KeyValTupleParam for this. + """ + + def __init__( + self, + indent=1, + width=80, + depth=None, + stream=None, + *, + compact=False, + indent_at_name=True, + n_max_elements_to_show=None, + changed_only=True, + ): + super().__init__(indent, width, depth, stream, compact=compact) + self._indent_at_name = indent_at_name + if self._indent_at_name: + self._indent_per_level = 1 # ignore indent param + self.changed_only = changed_only + # Max number of elements in a list, dict, tuple until we start using + # ellipsis. 
This also affects the number of arguments of a BaseObject + # (they are treated as dicts) + self.n_max_elements_to_show = n_max_elements_to_show + + def format(self, obj, context, maxlevels, level): # noqa + return _safe_repr( + obj, context, maxlevels, level, changed_only=self.changed_only + ) + + def _pprint_object(self, obj, stream, indent, allowance, context, level): + stream.write(obj.__class__.__name__ + "(") + if self._indent_at_name: + indent += len(obj.__class__.__name__) + + if self.changed_only: + params = _changed_params(obj) + else: + params = obj.get_params(deep=False) + + params = OrderedDict((name, val) for (name, val) in sorted(params.items())) + + self._format_params( + params.items(), stream, indent, allowance + 1, context, level + ) + stream.write(")") + + def _format_dict_items(self, items, stream, indent, allowance, context, level): + return self._format_params_or_dict_items( + items, stream, indent, allowance, context, level, is_dict=True + ) + + def _format_params(self, items, stream, indent, allowance, context, level): + return self._format_params_or_dict_items( + items, stream, indent, allowance, context, level, is_dict=False + ) + + def _format_params_or_dict_items( + self, obj, stream, indent, allowance, context, level, is_dict + ): + """Format dict items or parameters respecting the compact=True parameter. + + For some reason, the builtin rendering of dict items doesn't + respect compact=True and will use one line per key-value if all cannot + fit in a single line. + Dict items will be rendered as <'key': value> while params will be + rendered as . The implementation is mostly copy/pasting from + the builtin _format_items(). + This also adds ellipsis if the number of items is greater than + self.n_max_elements_to_show. 
+ """ + write = stream.write + indent += self._indent_per_level + delimnl = ",\n" + " " * indent + delim = "" + width = max_width = self._width - indent + 1 + it = iter(obj) + try: + next_ent = next(it) + except StopIteration: + return + last = False + n_items = 0 + while not last: + if n_items == self.n_max_elements_to_show: + write(", ...") + break + n_items += 1 + ent = next_ent + try: + next_ent = next(it) + except StopIteration: + last = True + max_width -= allowance + width -= allowance + if self._compact: + k, v = ent + krepr = self._repr(k, context, level) + vrepr = self._repr(v, context, level) + if not is_dict: + krepr = krepr.strip("'") + middle = ": " if is_dict else "=" + rep = krepr + middle + vrepr + w = len(rep) + 2 + if width < w: + width = max_width + if delim: + delim = delimnl + if width >= w: + width -= w + write(delim) + delim = ", " + write(rep) + continue + write(delim) + delim = delimnl + class_ = KeyValTuple if is_dict else KeyValTupleParam + self._format( + class_(ent), stream, indent, allowance if last else 1, context, level + ) + + def _format_items(self, items, stream, indent, allowance, context, level): + """Format the items of an iterable (list, tuple...). + + Same as the built-in _format_items, with support for ellipsis if the + number of elements is greater than self.n_max_elements_to_show. 
+ """ + write = stream.write + indent += self._indent_per_level + if self._indent_per_level > 1: + write((self._indent_per_level - 1) * " ") + delimnl = ",\n" + " " * indent + delim = "" + width = max_width = self._width - indent + 1 + it = iter(items) + try: + next_ent = next(it) + except StopIteration: + return + last = False + n_items = 0 + while not last: + if n_items == self.n_max_elements_to_show: + write(", ...") + break + n_items += 1 + ent = next_ent + try: + next_ent = next(it) + except StopIteration: + last = True + max_width -= allowance + width -= allowance + if self._compact: + rep = self._repr(ent, context, level) + w = len(rep) + 2 + if width < w: + width = max_width + if delim: + delim = delimnl + if width >= w: + width -= w + write(delim) + delim = ", " + write(rep) + continue + write(delim) + delim = delimnl + self._format(ent, stream, indent, allowance if last else 1, context, level) + + def _pprint_key_val_tuple(self, obj, stream, indent, allowance, context, level): + """Pretty printing for key-value tuples from dict or parameters.""" + k, v = obj + rep = self._repr(k, context, level) + if isinstance(obj, KeyValTupleParam): + rep = rep.strip("'") + middle = "=" + else: + middle = ": " + stream.write(rep) + stream.write(middle) + self._format( + v, stream, indent + len(rep) + len(middle), allowance, context, level + ) + + # Follow what scikit-learn did here and copy _dispatch to prevent instances + # of the builtin PrettyPrinter class to call methods of + # _BaseObjectPrettyPrinter (see scikit-learn Github issue 12906) + # mypy error: "Type[PrettyPrinter]" has no attribute "_dispatch" + _dispatch = pprint.PrettyPrinter._dispatch.copy() # type: ignore + _dispatch[BaseObject.__repr__] = _pprint_object + _dispatch[KeyValTuple.__repr__] = _pprint_key_val_tuple + + +def _safe_repr(obj, context, maxlevels, level, changed_only=False): + """Safe string representation logic. + + Same as the builtin _safe_repr, with added support for BaseObjects. 
+ """ + typ = type(obj) + + if typ in pprint._builtin_scalars: + return repr(obj), True, False + + r = getattr(typ, "__repr__", None) + if issubclass(typ, dict) and r is dict.__repr__: + if not obj: + return "{}", True, False + objid = id(obj) + if maxlevels and level >= maxlevels: + return "{...}", False, objid in context + if objid in context: + return pprint._recursion(obj), False, True + context[objid] = 1 + readable = True + recursive = False + components = [] + append = components.append + level += 1 + saferepr = _safe_repr + items = sorted(obj.items(), key=pprint._safe_tuple) + for k, v in items: + krepr, kreadable, krecur = saferepr( + k, context, maxlevels, level, changed_only=changed_only + ) + vrepr, vreadable, vrecur = saferepr( + v, context, maxlevels, level, changed_only=changed_only + ) + append("%s: %s" % (krepr, vrepr)) + readable = readable and kreadable and vreadable + if krecur or vrecur: + recursive = True + del context[objid] + return "{%s}" % ", ".join(components), readable, recursive + + if (issubclass(typ, list) and r is list.__repr__) or ( + issubclass(typ, tuple) and r is tuple.__repr__ + ): + if issubclass(typ, list): + if not obj: + return "[]", True, False + format_ = "[%s]" + elif len(obj) == 1: + format_ = "(%s,)" + else: + if not obj: + return "()", True, False + format_ = "(%s)" + objid = id(obj) + if maxlevels and level >= maxlevels: + return format_ % "...", False, objid in context + if objid in context: + return pprint._recursion(obj), False, True + context[objid] = 1 + readable = True + recursive = False + components = [] + append = components.append + level += 1 + for o in obj: + orepr, oreadable, orecur = _safe_repr( + o, context, maxlevels, level, changed_only=changed_only + ) + append(orepr) + if not oreadable: + readable = False + if orecur: + recursive = True + del context[objid] + return format_ % ", ".join(components), readable, recursive + + if issubclass(typ, BaseObject): + objid = id(obj) + if maxlevels and level >= 
maxlevels: + return "{...}", False, objid in context + if objid in context: + return pprint._recursion(obj), False, True + context[objid] = 1 + readable = True + recursive = False + if changed_only: + params = _changed_params(obj) + else: + params = obj.get_params(deep=False) + components = [] + append = components.append + level += 1 + saferepr = _safe_repr + items = sorted(params.items(), key=pprint._safe_tuple) + for k, v in items: + krepr, kreadable, krecur = saferepr( + k, context, maxlevels, level, changed_only=changed_only + ) + vrepr, vreadable, vrecur = saferepr( + v, context, maxlevels, level, changed_only=changed_only + ) + append("%s=%s" % (krepr.strip("'"), vrepr)) + readable = readable and kreadable and vreadable + if krecur or vrecur: + recursive = True + del context[objid] + return ("%s(%s)" % (typ.__name__, ", ".join(components)), readable, recursive) + + rep = repr(obj) + return rep, (rep and not rep.startswith("<")), False diff --git a/skbase/config/__init__.py b/skbase/config/__init__.py new file mode 100644 index 00000000..fd8a750a --- /dev/null +++ b/skbase/config/__init__.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +""":mod:`skbase.config` provides tools for the global configuration of ``skbase``. + +For more information on configuration usage patterns see the +:ref:`user guide `. +""" +# -*- coding: utf-8 -*- +# copyright: skbase developers, BSD-3-Clause License (see LICENSE file) +# Includes functionality like get_config, set_config, and config_context +# that is similar to scikit-learn. These elements are copyrighted by the +# scikit-learn developers, BSD-3-Clause License. 
For conditions see +# https://github.com/scikit-learn/scikit-learn/blob/main/COPYING +from typing import List + +from skbase.config._config import ( + GlobalConfigParamSetting, + config_context, + get_config, + get_default_config, + reset_config, + set_config, +) + +__author__: List[str] = ["RNKuhns"] +__all__: List[str] = [ + "GlobalConfigParamSetting", + "get_default_config", + "get_config", + "set_config", + "reset_config", + "config_context", +] diff --git a/skbase/config/_config.py b/skbase/config/_config.py new file mode 100644 index 00000000..ec2fc13c --- /dev/null +++ b/skbase/config/_config.py @@ -0,0 +1,397 @@ +# -*- coding: utf-8 -*- +"""Implement logic for global configuration of skbase.""" +# -*- coding: utf-8 -*- +# copyright: skbase developers, BSD-3-Clause License (see LICENSE file) +# Includes functionality like get_config, set_config, and config_context +# that is similar to scikit-learn. These elements are copyrighted by the +# scikit-learn developers, BSD-3-Clause License. 
@dataclass
class GlobalConfigParamSetting:
    """Define types of the setting information for a given config parameter.

    Parameters
    ----------
    name : str
        Name of the configuration parameter. Used in warning messages.
    os_environ_name : str
        Name of an associated environment variable. The value is only stored;
        none of the methods of this class read it.
    expected_type : type or tuple of type
        Type(s) a value must be an instance of to be considered valid.
    allowed_values : tuple, list or None
        Values the parameter is allowed to take. If None, any value of
        `expected_type` is considered valid.
    default_value : Any
        Value returned by :meth:`get_valid_param_or_default` when an invalid
        value is passed and no other default is supplied.
    """

    name: str
    os_environ_name: str
    expected_type: Union[type, Tuple[type]]
    allowed_values: Optional[Union[Tuple[Any, ...], List[Any]]]
    default_value: Any

    def get_allowed_values(self) -> List[Any]:
        """Get `allowed_values` as a list, or an empty list if it is None.

        Returns
        -------
        list
            Allowable values if any. A list is returned as-is, any other
            non-string iterable is converted to a list and any other object
            is wrapped in a single element list.
        """
        if self.allowed_values is None:
            return []
        if isinstance(self.allowed_values, list):
            return self.allowed_values
        # A string is iterable, but it represents one allowed value, not
        # a collection of single character values.
        if isinstance(
            self.allowed_values, collections.abc.Iterable
        ) and not isinstance(self.allowed_values, str):
            return list(self.allowed_values)
        return [self.allowed_values]

    def is_valid_param_value(self, value):
        """Validate that a global configuration value is valid.

        Verifies that the value set for a global configuration parameter is valid
        based on its configuration settings.

        Parameters
        ----------
        value : Any
            The value to validate.

        Returns
        -------
        bool
            Whether a parameter value is valid. A value is valid if it is an
            instance of `expected_type` and, when `allowed_values` is not None,
            is also one of the allowed values.
        """
        if not isinstance(value, self.expected_type):
            return False
        # ``get_allowed_values`` converts None to an empty list, so the raw
        # attribute must be checked to tell "no restriction" from "nothing allowed".
        if self.allowed_values is None:
            return True
        return value in self.get_allowed_values()

    def get_valid_param_or_default(self, value, default_value=None, msg=None):
        """Validate `value` and return default if it is not valid.

        A ``UserWarning`` is issued when `value` is not valid.

        Parameters
        ----------
        value : Any
            The configuration parameter value to set.
        default_value : Any, default=None
            An optional default value to use to set the configuration parameter
            if `value` is not valid based on defined expected type and allowed
            values. If None, and `value` is invalid then the classes `default_value`
            parameter is used.
        msg : str, default=None
            An optional message to be used as start of the UserWarning message.

        Returns
        -------
        Any
            `value` if it is valid, otherwise `default_value` if it is not None,
            otherwise the instance's `default_value`.
        """
        if self.is_valid_param_value(value):
            return value

        if msg is None:
            msg = ""
        msg += f"When setting global config values for `{self.name}`, the values "
        msg += f"should be of type {self.expected_type}.\n"
        if self.allowed_values is not None:
            values_str = _format_seq_to_str(
                self.get_allowed_values(), last_sep="or", remove_type_text=True
            )
            msg += f"Allowed values are one of {values_str}. "
        msg += f"But found {value}."
        # stacklevel=2 attributes the warning to the caller of this method
        warnings.warn(msg, UserWarning, stacklevel=2)
        return default_value if default_value is not None else self.default_value
def set_config(
    print_changed_only: Optional[bool] = None,
    display: Literal["text", "diagram"] = None,
    local_threadsafe: bool = False,
) -> None:
    """Set global configuration.

    Updates the configuration of the calling thread and, unless
    `local_threadsafe` is True, copies the result into the module level
    configuration as well.

    Parameters
    ----------
    print_changed_only : bool, default=None
        If True, only the parameters that were set to non-default
        values will be printed when printing a BaseObject instance. For example,
        ``print(SVC())`` while True will only print 'SVC()', but would print
        'SVC(C=1.0, cache_size=200, ...)' with all the non-changed parameters
        when False. If None, the existing value won't change.
    display : {'text', 'diagram'}, default=None
        If 'diagram', instances inheriting from BaseObject will be displayed
        as a diagram in a Jupyter lab or notebook context. If 'text', instances
        inheriting from BaseObject will be displayed as text. If None, the
        existing value won't change.
    local_threadsafe : bool, default=False
        If False, the updated configuration is also written to the module
        level configuration, which threads copy the first time they access
        their configuration. If True, only the configuration of the calling
        thread is changed.

    Returns
    -------
    None
        No output returned.

    Warns
    -----
    UserWarning
        If a passed value is not valid for its parameter. The parameter then
        keeps its current value.

    See Also
    --------
    config_context :
        Configuration context manager.
    get_default_config :
        Retrieve ``skbase``'s default configuration.
    get_config :
        Retrieve current global configuration values.
    reset_config :
        Reset configuration to default.

    Examples
    --------
    >>> from skbase.config import get_config, set_config
    >>> get_config()
    {'print_changed_only': True, 'display': 'text'}
    >>> set_config(display='diagram')
    >>> get_config()
    {'print_changed_only': True, 'display': 'diagram'}
    """
    thread_config = _get_threadlocal_config()

    warning_prefix = (
        "Attempting to set an invalid value for a global configuration.\n"
        "Using current configuration value of parameter as a result.\n"
    )

    # Process parameters in a fixed order; None means "leave unchanged"
    requested = (("print_changed_only", print_changed_only), ("display", display))
    for param_name, param_value in requested:
        if param_value is None:
            continue
        setting = _CONFIG_REGISTRY[param_name]
        # An invalid value falls back to the value currently configured
        thread_config[param_name] = setting.get_valid_param_or_default(
            param_value,
            default_value=thread_config[param_name],
            msg=warning_prefix,
        )

    if local_threadsafe:
        return None

    global_config.update(thread_config)
    return None
@contextmanager
def config_context(
    print_changed_only: Optional[bool] = None,
    display: Optional[Literal["text", "diagram"]] = None,
    local_threadsafe: bool = False,
) -> Iterator[None]:
    """Context manager for global configuration.

    Provides the ability to run code using different configuration without
    having to update the global config.

    Parameters
    ----------
    print_changed_only : bool, default=None
        If True, only the parameters that were set to non-default
        values will be printed when printing a BaseObject instance. For example,
        ``print(SVC())`` while True will only print 'SVC()', but would print
        'SVC(C=1.0, cache_size=200, ...)' with all the non-changed parameters
        when False. If None, the existing value won't change.
    display : {'text', 'diagram'}, default=None
        If 'diagram', instances inheriting from BaseObject will be displayed
        as a diagram in a Jupyter lab or notebook context. If 'text', instances
        inheriting from BaseObject will be displayed as text. If None, the
        existing value won't change.
    local_threadsafe : bool, default=False
        Passed to :func:`set_config` both when entering and when leaving the
        context. If False, set the config as default for all threads.

    Yields
    ------
    None

    See Also
    --------
    get_default_config :
        Retrieve ``skbase``'s default configuration.
    get_config :
        Retrieve current values of the global configuration.
    set_config :
        Set global configuration.
    reset_config :
        Reset configuration to ``skbase`` default.

    Notes
    -----
    All settings, not just those presently modified, will be returned to
    their previous values when the context manager is exited.

    Examples
    --------
    >>> from skbase.config import config_context
    >>> with config_context(display='diagram'):
    ...     pass
    """
    # Snapshot of the calling thread's configuration, restored on exit
    old_config = get_config()
    set_config(
        print_changed_only=print_changed_only,
        display=display,
        local_threadsafe=local_threadsafe,
    )

    try:
        yield
    finally:
        # Restore with the same scope that was used on entry. Omitting
        # ``local_threadsafe`` here would write this thread's configuration
        # into the configuration shared with other threads.
        set_config(local_threadsafe=local_threadsafe, **old_config)
def test_get_default_config(global_config_default):
    """Test get_default_config always returns the default config."""
    assert get_default_config() == global_config_default
    try:
        # Changing the active configuration must not alter the defaults
        set_config(print_changed_only=False)
        assert get_default_config() == global_config_default
    finally:
        # Restore the defaults so the changed setting does not leak into
        # tests that run afterwards.
        reset_config()
@pytest.mark.parametrize("print_changed_only", PRINT_CHANGE_ONLY_VALUES)
@pytest.mark.parametrize("display", DISPLAY_VALUES)
def test_config_context(print_changed_only, display):
    """Verify that config_context affects context but not overall configuration."""
    # Make sure config is reset to default values then retrieve it
    reset_config()
    retrieved_config = get_config()

    # Within the context manager the configuration should reflect the values
    # passed to config_context. The parametrized values are used directly; an
    # inner loop over print_changed_only would shadow the parametrized argument.
    with config_context(print_changed_only=print_changed_only, display=display):
        retrieved_context_config = get_config()
    expected_config = {"print_changed_only": print_changed_only, "display": display}
    msg = "`get_config` does not return expected values within `config_context`.\n"
    msg += "`get_config` should return config defined by `config_context`.\n"
    msg += f"Expected {expected_config}, but returned {retrieved_context_config}."
    assert retrieved_context_config == expected_config, msg

    # Outside of the config_context we should have not affected the retrieved config
    # set by call to reset_config()
    config_post_config_context = get_config()
    msg = "`get_config` does not return expected values after `config_context`.\n"
    msg += "`config_context` should not affect configuration outside its context.\n"
    msg += f"Expected {retrieved_config}, but returned {config_post_config_context}."
    assert retrieved_config == config_post_config_context, msg
    reset_config()
"skbase.lookup.tests.test_lookup", @@ -74,6 +81,9 @@ "skbase.base": ("BaseEstimator", "BaseMetaEstimator", "BaseObject"), "skbase.base._base": ("BaseEstimator", "BaseObject"), "skbase.base._meta": ("BaseMetaEstimator",), + "skbase.base._pretty_printing._pprint": ("KeyValTuple", "KeyValTupleParam"), + "skbase.config": ("GlobalConfigParamSetting",), + "skbase.config._config": ("GlobalConfigParamSetting",), "skbase.lookup._lookup": ("ClassInfo", "FunctionInfo", "ModuleInfo"), "skbase.testing": ("BaseFixtureGenerator", "QuickTester", "TestAllObjects"), "skbase.testing.test_all_objects": ( @@ -85,11 +95,30 @@ SKBASE_CLASSES_BY_MODULE = SKBASE_PUBLIC_CLASSES_BY_MODULE.copy() SKBASE_CLASSES_BY_MODULE.update( { - "skbase.base._meta": ("BaseMetaEstimator",), + "skbase.base._pretty_printing._object_html_repr": ("_VisualBlock",), + "skbase.base._pretty_printing._pprint": ( + "KeyValTuple", + "KeyValTupleParam", + "_BaseObjectPrettyPrinter", + ), "skbase.base._tagmanager": ("_FlagManager",), } ) SKBASE_PUBLIC_FUNCTIONS_BY_MODULE = { + "skbase.config": ( + "get_config", + "get_default_config", + "set_config", + "reset_config", + "config_context", + ), + "skbase.config._config": ( + "get_config", + "get_default_config", + "set_config", + "reset_config", + "config_context", + ), "skbase.lookup": ("all_objects", "get_package_metadata"), "skbase.lookup._lookup": ("all_objects", "get_package_metadata"), "skbase.testing.utils._conditional_fixtures": ( @@ -124,6 +153,21 @@ SKBASE_FUNCTIONS_BY_MODULE = SKBASE_PUBLIC_FUNCTIONS_BY_MODULE.copy() SKBASE_FUNCTIONS_BY_MODULE.update( { + "skbase.base._pretty_printing._object_html_repr": ( + "_get_visual_block", + "_object_html_repr", + "_write_base_object_html", + "_write_label_html", + ), + "skbase.base._pretty_printing._pprint": ("_changed_params", "_safe_repr"), + "skbase.config._config": ( + "_get_threadlocal_config", + "get_config", + "get_default_config", + "set_config", + "reset_config", + "config_context", + ), 
"skbase.lookup._lookup": ( "_determine_module_path", "_get_return_tags", @@ -155,6 +199,7 @@ "_coerce_list", ), "skbase.testing.utils.inspect": ("_get_args",), + "skbase.utils._check": ("_is_scalar_nan",), "skbase.utils._iter": ( "_format_seq_to_str", "_remove_type_text", diff --git a/skbase/tests/test_base.py b/skbase/tests/test_base.py index 38d2d5a9..28d54d5a 100644 --- a/skbase/tests/test_base.py +++ b/skbase/tests/test_base.py @@ -69,16 +69,17 @@ import inspect from copy import deepcopy +from typing import Any, Dict, Type import numpy as np import pytest import scipy.sparse as sp -from sklearn import config_context # TODO: Update with import of skbase clone function once implemented from sklearn.base import clone from skbase.base import BaseEstimator, BaseObject +from skbase.config import config_context, get_config from skbase.tests.conftest import Child, Parent from skbase.tests.mock_package.test_mock_package import CompositionDummy @@ -200,13 +201,13 @@ def fixture_reset_tester(): @pytest.fixture -def fixture_class_child_tags(fixture_class_child): +def fixture_class_child_tags(fixture_class_child: Type[Child]): """Pytest fixture for tags of Child.""" return fixture_class_child.get_class_tags() @pytest.fixture -def fixture_object_instance_set_tags(fixture_tag_class_object): +def fixture_object_instance_set_tags(fixture_tag_class_object: Child): """Fixture class instance to test tag setting.""" fixture_tag_set = {"A": 42424243, "E": 3} return fixture_tag_class_object.set_tags(**fixture_tag_set) @@ -266,7 +267,9 @@ def fixture_class_instance_no_param_interface(): return NoParamInterface() -def test_get_class_tags(fixture_class_child, fixture_class_child_tags): +def test_get_class_tags( + fixture_class_child: Type[Child], fixture_class_child_tags: Any +): """Test get_class_tags class method of BaseObject for correctness. 
Raises @@ -280,7 +283,7 @@ def test_get_class_tags(fixture_class_child, fixture_class_child_tags): assert child_tags == fixture_class_child_tags, msg -def test_get_class_tag(fixture_class_child, fixture_class_child_tags): +def test_get_class_tag(fixture_class_child: Type[Child], fixture_class_child_tags: Any): """Test get_class_tag class method of BaseObject for correctness. Raises @@ -307,7 +310,7 @@ def test_get_class_tag(fixture_class_child, fixture_class_child_tags): assert child_tag_default_none is None, msg -def test_get_tags(fixture_tag_class_object, fixture_object_tags): +def test_get_tags(fixture_tag_class_object: Child, fixture_object_tags: Dict[str, Any]): """Test get_tags method of BaseObject for correctness. Raises @@ -321,7 +324,7 @@ def test_get_tags(fixture_tag_class_object, fixture_object_tags): assert object_tags == fixture_object_tags, msg -def test_get_tag(fixture_tag_class_object, fixture_object_tags): +def test_get_tag(fixture_tag_class_object: Child, fixture_object_tags: Dict[str, Any]): """Test get_tag method of BaseObject for correctness. Raises @@ -351,7 +354,7 @@ def test_get_tag(fixture_tag_class_object, fixture_object_tags): assert object_tag_default_none is None, msg -def test_get_tag_raises(fixture_tag_class_object): +def test_get_tag_raises(fixture_tag_class_object: Child): """Test that get_tag method raises error for unknown tag. Raises @@ -363,9 +366,9 @@ def test_get_tag_raises(fixture_tag_class_object): def test_set_tags( - fixture_object_instance_set_tags, - fixture_object_set_tags, - fixture_object_dynamic_tags, + fixture_object_instance_set_tags: Any, + fixture_object_set_tags: Dict[str, Any], + fixture_object_dynamic_tags: Dict[str, int], ): """Test set_tags method of BaseObject for correctness. 
@@ -381,7 +384,9 @@ def test_set_tags( assert fixture_object_instance_set_tags.get_tags() == fixture_object_set_tags, msg -def test_set_tags_works_with_missing_tags_dynamic_attribute(fixture_tag_class_object): +def test_set_tags_works_with_missing_tags_dynamic_attribute( + fixture_tag_class_object: Child, +): """Test set_tags will still work if _tags_dynamic is missing.""" base_obj = deepcopy(fixture_tag_class_object) delattr(base_obj, "_tags_dynamic") @@ -460,7 +465,7 @@ class AnotherTestClass(BaseObject): assert test_obj_tags.get(tag) == another_base_obj_tags[tag] -def test_is_composite(fixture_composition_dummy): +def test_is_composite(fixture_composition_dummy: Type[CompositionDummy]): """Test is_composite tag for correctness. Raises @@ -474,7 +479,11 @@ def test_is_composite(fixture_composition_dummy): assert composite.is_composite() -def test_components(fixture_object, fixture_class_parent, fixture_composition_dummy): +def test_components( + fixture_object: Type[BaseObject], + fixture_class_parent: Type[Parent], + fixture_composition_dummy: Type[CompositionDummy], +): """Test component retrieval. 
Raises @@ -507,7 +516,7 @@ def test_components(fixture_object, fixture_class_parent, fixture_composition_du def test_components_raises_error_base_class_is_not_class( - fixture_object, fixture_composition_dummy + fixture_object: Type[BaseObject], fixture_composition_dummy: Type[CompositionDummy] ): """Test _component method raises error if base_class param is not class.""" non_composite = fixture_composition_dummy(foo=42) @@ -526,7 +535,7 @@ def test_components_raises_error_base_class_is_not_class( def test_components_raises_error_base_class_is_not_baseobject_subclass( - fixture_composition_dummy, + fixture_composition_dummy: Type[CompositionDummy], ): """Test _component method raises error if base_class is not BaseObject subclass.""" @@ -540,7 +549,7 @@ class SomeClass: # Test parameter interface (get_params, set_params, reset and related methods) # Some tests of get_params and set_params are adapted from sklearn tests -def test_reset(fixture_reset_tester): +def test_reset(fixture_reset_tester: Type[ResetTester]): """Test reset method for correct behaviour, on a simple estimator. Raises @@ -567,7 +576,7 @@ def test_reset(fixture_reset_tester): assert hasattr(x, "foo") -def test_reset_composite(fixture_reset_tester): +def test_reset_composite(fixture_reset_tester: Type[ResetTester]): """Test reset method for correct behaviour, on a composite estimator.""" y = fixture_reset_tester(42) x = fixture_reset_tester(a=y) @@ -582,7 +591,7 @@ def test_reset_composite(fixture_reset_tester): assert not hasattr(x.a, "d") -def test_get_init_signature(fixture_class_parent): +def test_get_init_signature(fixture_class_parent: Type[Parent]): """Test error is raised when invalid init signature is used.""" init_sig = fixture_class_parent._get_init_signature() init_sig_is_list = isinstance(init_sig, list) @@ -594,14 +603,18 @@ def test_get_init_signature(fixture_class_parent): ), "`_get_init_signature` is not returning expected result." 
-def test_get_init_signature_raises_error_for_invalid_signature(fixture_invalid_init): +def test_get_init_signature_raises_error_for_invalid_signature( + fixture_invalid_init: Type[InvalidInitSignatureTester], +): """Test error is raised when invalid init signature is used.""" with pytest.raises(RuntimeError): fixture_invalid_init._get_init_signature() def test_get_param_names( - fixture_object, fixture_class_parent, fixture_class_parent_expected_params + fixture_object: Type[BaseObject], + fixture_class_parent: Type[Parent], + fixture_class_parent_expected_params: Dict[str, Any], ): """Test that get_param_names returns list of string parameter names.""" param_names = fixture_class_parent.get_param_names() @@ -612,10 +625,10 @@ def test_get_param_names( def test_get_params( - fixture_class_parent, - fixture_class_parent_expected_params, - fixture_class_instance_no_param_interface, - fixture_composition_dummy, + fixture_class_parent: Type[Parent], + fixture_class_parent_expected_params: Dict[str, Any], + fixture_class_instance_no_param_interface: NoParamInterface, + fixture_composition_dummy: Type[CompositionDummy], ): """Test get_params returns expected parameters.""" # Simple test of returned params @@ -638,7 +651,10 @@ def test_get_params( assert "foo" in params and "bar" in params and len(params) == 2 -def test_get_params_invariance(fixture_class_parent, fixture_composition_dummy): +def test_get_params_invariance( + fixture_class_parent: Type[Parent], + fixture_composition_dummy: Type[CompositionDummy], +): """Test that get_params(deep=False) is subset of get_params(deep=True).""" composite = fixture_composition_dummy(foo=fixture_class_parent(), bar=84) shallow_params = composite.get_params(deep=False) @@ -646,7 +662,7 @@ def test_get_params_invariance(fixture_class_parent, fixture_composition_dummy): assert all(item in deep_params.items() for item in shallow_params.items()) -def test_get_params_after_set_params(fixture_class_parent): +def 
test_get_params_after_set_params(fixture_class_parent: Type[Parent]): """Test that get_params returns the same thing before and after set_params. Based on scikit-learn check in check_estimator. @@ -687,9 +703,9 @@ def test_get_params_after_set_params(fixture_class_parent): def test_set_params( - fixture_class_parent, - fixture_class_parent_expected_params, - fixture_composition_dummy, + fixture_class_parent: Type[Parent], + fixture_class_parent_expected_params: Dict[str, Any], + fixture_composition_dummy: Type[CompositionDummy], ): """Test set_params works as expected.""" # Simple case of setting a parameter @@ -711,7 +727,8 @@ def test_set_params( def test_set_params_raises_error_non_existent_param( - fixture_class_parent_instance, fixture_composition_dummy + fixture_class_parent_instance: Parent, + fixture_composition_dummy: Type[CompositionDummy], ): """Test set_params raises an error when passed a non-existent parameter name.""" # non-existing parameter in svc @@ -727,7 +744,8 @@ def test_set_params_raises_error_non_existent_param( def test_set_params_raises_error_non_interface_composite( - fixture_class_instance_no_param_interface, fixture_composition_dummy + fixture_class_instance_no_param_interface: NoParamInterface, + fixture_composition_dummy: Type[CompositionDummy], ): """Test set_params raises error when setting param of non-conforming composite.""" # When a composite is made up of a class that doesn't have the BaseObject @@ -753,7 +771,9 @@ def __init__(self, param=5): est.get_params() -def test_set_params_with_no_param_to_set_returns_object(fixture_class_parent): +def test_set_params_with_no_param_to_set_returns_object( + fixture_class_parent: Type[Parent], +): """Test set_params correctly returns self when no parameters are set.""" base_obj = fixture_class_parent() orig_params = deepcopy(base_obj.get_params()) @@ -767,7 +787,7 @@ def test_set_params_with_no_param_to_set_returns_object(fixture_class_parent): # This section tests the clone functionality 
# These have been adapted from sklearn's tests of clone to use the clone # method that is included as part of the BaseObject interface -def test_clone(fixture_class_parent_instance): +def test_clone(fixture_class_parent_instance: Parent): """Test that clone is making a deep copy as expected.""" # Creates a BaseObject and makes a copy of its original state # (which, in this case, is the current state of the BaseObject), @@ -777,7 +797,7 @@ def test_clone(fixture_class_parent_instance): assert fixture_class_parent_instance.get_params() == new_base_obj.get_params() -def test_clone_2(fixture_class_parent_instance): +def test_clone_2(fixture_class_parent_instance: Parent): """Test that clone does not copy attributes not set in constructor.""" # We first create an estimator, give it an own attribute, and # make a copy of its original state. Then we check that the copy doesn't @@ -790,7 +810,9 @@ def test_clone_2(fixture_class_parent_instance): def test_clone_raises_error_for_nonconforming_objects( - fixture_invalid_init, fixture_buggy, fixture_modify_param + fixture_invalid_init: Type[InvalidInitSignatureTester], + fixture_buggy: Type[Buggy], + fixture_modify_param: Type[ModifyParam], ): """Test that clone raises an error on nonconforming BaseObjects.""" buggy = fixture_buggy() @@ -807,7 +829,7 @@ def test_clone_raises_error_for_nonconforming_objects( obj_that_modifies.clone() -def test_clone_param_is_none(fixture_class_parent): +def test_clone_param_is_none(fixture_class_parent: Type[Parent]): """Test clone with keyword parameter set to None.""" base_obj = fixture_class_parent(c=None) new_base_obj = clone(base_obj) @@ -816,7 +838,7 @@ def test_clone_param_is_none(fixture_class_parent): assert base_obj.c is new_base_obj2.c -def test_clone_empty_array(fixture_class_parent): +def test_clone_empty_array(fixture_class_parent: Type[Parent]): """Test clone with keyword parameter is scipy sparse matrix. 
This test is based on scikit-learn regression test to make sure clone @@ -830,7 +852,7 @@ def test_clone_empty_array(fixture_class_parent): np.testing.assert_array_equal(base_obj.c, new_base_obj2.c) -def test_clone_sparse_matrix(fixture_class_parent): +def test_clone_sparse_matrix(fixture_class_parent: Type[Parent]): """Test clone with keyword parameter is scipy sparse matrix. This test is based on scikit-learn regression test to make sure clone @@ -843,7 +865,7 @@ def test_clone_sparse_matrix(fixture_class_parent): np.testing.assert_array_equal(base_obj.c, new_base_obj2.c) -def test_clone_nan(fixture_class_parent): +def test_clone_nan(fixture_class_parent: Type[Parent]): """Test clone with keyword parameter is np.nan. This test is based on scikit-learn regression test to make sure clone @@ -858,7 +880,7 @@ def test_clone_nan(fixture_class_parent): assert base_obj.c is new_base_obj2.c -def test_clone_estimator_types(fixture_class_parent): +def test_clone_estimator_types(fixture_class_parent: Type[Parent]): """Test clone works for parameters that are types rather than instances.""" base_obj = fixture_class_parent(c=fixture_class_parent) new_base_obj = base_obj.clone() @@ -866,7 +888,9 @@ def test_clone_estimator_types(fixture_class_parent): assert base_obj.c == new_base_obj.c -def test_clone_class_rather_than_instance_raises_error(fixture_class_parent): +def test_clone_class_rather_than_instance_raises_error( + fixture_class_parent: Type[Parent], +): """Test clone raises expected error when cloning a class instead of an instance.""" msg = "You should provide an instance of scikit-learn estimator" with pytest.raises(TypeError, match=msg): @@ -874,31 +898,77 @@ def test_clone_class_rather_than_instance_raises_error(fixture_class_parent): # Tests of BaseObject pretty printing representation inspired by sklearn -def test_baseobject_repr(fixture_class_parent, fixture_composition_dummy): +def test_baseobject_repr( + fixture_class_parent: Type[Parent], + 
fixture_composition_dummy: Type[CompositionDummy], +): """Test BaseObject repr works as expected.""" # Simple test where all parameters are left at defaults # Should not see parameters and values in printed representation + base_obj = fixture_class_parent() assert repr(base_obj) == "Parent()" # Check that we can alter the detail about params that is printed # using config_context with ``print_changed_only=False`` - with config_context(print_changed_only=False): + with config_context(print_changed_only=False, display="text"): assert repr(base_obj) == "Parent(a='something', b=7, c=None)" + # Check that local config works as expected + base_obj.set_config(print_changed_only=False) + assert repr(base_obj) == "Parent(a='something', b=7, c=None)" + + # Test with dict parameter (note that dict is sorted by keys when printed) + # not printed in order it was created + base_obj = fixture_class_parent(c={"c": 1, "a": 2}) + assert repr(base_obj) == "Parent(c={'a': 2, 'c': 1})" + + # Now test when one param's values are named object tuples + named_objs = [ + ("step 1", fixture_class_parent()), + ("step 2", fixture_class_parent()), + ] + base_obj = fixture_class_parent(c=named_objs) + assert repr(base_obj) == "Parent(c=[('step 1', Parent()), ('step 2', Parent())])" + + # Or when they are just lists of tuples or just tuples as param + base_obj = fixture_class_parent(c=[("one", 1), ("two", 2)]) + assert repr(base_obj) == "Parent(c=[('one', 1), ('two', 2)])" + + base_obj = fixture_class_parent(c=(1, 2, 3)) + assert repr(base_obj) == "Parent(c=(1, 2, 3))" + simple_composite = fixture_composition_dummy(foo=fixture_class_parent()) assert repr(simple_composite) == "CompositionDummy(foo=Parent())" long_base_obj_repr = fixture_class_parent(a=["long_params"] * 1000) assert len(repr(long_base_obj_repr)) == 535 + named_objs = [(f"Step {i+1}", Child()) for i in range(25)] + base_comp = CompositionDummy(foo=Parent(c=Child(c=named_objs))) + assert len(repr(base_comp)) == 1362 + with 
config_context(print_changed_only=False): + assert "..." in repr(base_comp) + -def test_baseobject_str(fixture_class_parent_instance): +def test_baseobject_str(fixture_class_parent_instance: Parent): """Test BaseObject string representation works.""" - str(fixture_class_parent_instance) + assert ( + str(fixture_class_parent_instance) == "Parent()" + ), "String representation of instance not working." + # Check that we can alter the detail about params that is printed + # using config_context with ``print_changed_only=False`` + with config_context(print_changed_only=False, display="text"): + assert ( + str(fixture_class_parent_instance) == "Parent(a='something', b=7, c=None)" + ) + # Check that local config works as expected + fixture_class_parent_instance.set_config(print_changed_only=False) + assert str(fixture_class_parent_instance) == "Parent(a='something', b=7, c=None)" -def test_baseobject_repr_mimebundle_(fixture_class_parent_instance): + +def test_baseobject_repr_mimebundle_(fixture_class_parent_instance: Parent): """Test display configuration controls output.""" # Checks the display configuration flag controls the json output with config_context(display="diagram"): @@ -912,7 +982,7 @@ def test_baseobject_repr_mimebundle_(fixture_class_parent_instance): assert "text/html" not in output -def test_repr_html_wraps(fixture_class_parent_instance): +def test_repr_html_wraps(fixture_class_parent_instance: Parent): """Test display configuration flag controls the html output.""" with config_context(display="diagram"): output = fixture_class_parent_instance._repr_html_() @@ -925,20 +995,24 @@ def test_repr_html_wraps(fixture_class_parent_instance): # Test BaseObject's ability to generate test instances -def test_get_test_params(fixture_class_parent_instance): +def test_get_test_params(fixture_class_parent_instance: Parent): """Test get_test_params returns empty dictionary.""" base_obj = fixture_class_parent_instance test_params = base_obj.get_test_params() assert 
isinstance(test_params, dict) and len(test_params) == 0 -def test_get_test_params_raises_error_when_params_required(fixture_required_param): +def test_get_test_params_raises_error_when_params_required( + fixture_required_param: Type[RequiredParam], +): """Test get_test_params raises an error when parameters are required.""" with pytest.raises(ValueError): fixture_required_param(7).get_test_params() -def test_create_test_instance(fixture_class_parent, fixture_class_parent_instance): +def test_create_test_instance( + fixture_class_parent: Type[Parent], fixture_class_parent_instance: Parent +): """Test first that create_test_instance logic works.""" base_obj = fixture_class_parent.create_test_instance() @@ -957,7 +1031,7 @@ def test_create_test_instance(fixture_class_parent, fixture_class_parent_instanc assert hasattr(base_obj, "_tags_dynamic"), msg -def test_create_test_instances_and_names(fixture_class_parent_instance): +def test_create_test_instances_and_names(fixture_class_parent_instance: Parent): """Test that create_test_instances_and_names works.""" base_objs, names = fixture_class_parent_instance.create_test_instances_and_names() @@ -990,7 +1064,7 @@ def test_create_test_instances_and_names(fixture_class_parent_instance): # Tests _has_implementation_of interface def test_has_implementation_of( - fixture_class_parent_instance, fixture_class_child_instance + fixture_class_parent_instance: Parent, fixture_class_child_instance: Child ): """Test _has_implementation_of detects methods in class with overrides in mro.""" # When the class overrides a parent classes method should return True @@ -1014,33 +1088,182 @@ def __init__(self, a, b=42): self.c = 84 -def test_set_get_config(): - """Test logic behind get_config, set_config. 
+class AnotherConfigTester(BaseObject): + _config = {"print_changed_only": False, "bar": "a"} - Raises - ------ - AssertionError if logic behind get_config, set_config is incorrect, logic tested: - calling get_fitted_params on a non-composite fittable returns the fitted param - calling get_fitted_params on a composite returns all nested params + clsvar = 210 + + def __init__(self, a, b=42): + self.a = a + self.b = b + self.c = 84 + + +class ConfigExtensionInterfaceTester(BaseObject): + _config = {"print_changed_only": False, "bar": "a"} + + clsvar = 210 + + def __init__(self, a, b=42): + self.a = a + self.b = b + self.c = 84 + + def __skbase_get_config__(self): + """Return get_config extension.""" + return {"print_changed_only": True, "some_other_config": 70} + + +def test_local_config_without_use_of_extension_interface(): + """Test ``BaseObject().get_config`` and ``BaseObject().set_config``. + + ``BaseObject.get_config()`` should return the global dict updated with any local + configs defined in ``BaseObject._config`` or set on the instance with + ``BaseObject().set_config()``. 
""" - obj = ConfigTester(4242) + # Initially test that we can retrieve the local config with local configs + # as defined in BaseObject._config - config_start = obj.get_config() - assert isinstance(config_start, dict) - assert set(config_start.keys()) == {"foo_config", "bar"} - assert config_start["foo_config"] == 42 - assert config_start["bar"] == "a" + # Case 1: local configs are not part of global config param names + obj = ConfigTester(4242) + obj_config = obj.get_config() + current_global_config = get_config().copy() + expected_global_config = set(current_global_config.keys()) + assert isinstance(obj_config, dict) + assert set(obj_config.keys()) == expected_global_config | {"foo_config", "bar"} + assert obj_config["foo_config"] == 42 + assert obj_config["bar"] == "a" + for param_name in current_global_config: + assert obj_config[param_name] == current_global_config[param_name] + + # Case 2: local configs overlap with (will override) global params + obj = AnotherConfigTester(4242) + obj_config = obj.get_config() + current_global_config = get_config().copy() + expected_global_config = set(current_global_config.keys()) + assert isinstance(obj_config, dict) + assert set(obj_config.keys()) == expected_global_config | {"bar"} + assert obj_config["bar"] == "a" + # Should have overridden global config value which is set to True + assert obj_config["print_changed_only"] is False + for param_name in current_global_config: + if param_name != "print_changed_only": + assert obj_config[param_name] == current_global_config[param_name] + + # Case 3: local configs are not part of global config param names and we also + # make use of dynamic BaseObject.set_config() + obj = ConfigTester(4242) + current_global_config = get_config().copy() + expected_global_config = set(current_global_config.keys()) + # Verify set config returns the original object setconfig_return = obj.set_config(foobar=126) assert obj is setconfig_return obj.set_config(**{"bar": "b"}) - config_end = 
obj.get_config() - assert isinstance(config_end, dict) - assert set(config_end.keys()) == {"foo_config", "bar", "foobar"} - assert config_end["foo_config"] == 42 - assert config_end["bar"] == "b" - assert config_end["foobar"] == 126 + updated_obj_config = obj.get_config() + assert isinstance(updated_obj_config, dict) + assert set(updated_obj_config.keys()) == ( + expected_global_config | {"foo_config", "bar", "foobar"} + ) + assert updated_obj_config["foo_config"] == 42 + assert updated_obj_config["bar"] == "b" + assert updated_obj_config["foobar"] == 126 + + # Case 4: local configs are not part of global config param names and we also + # make use of dynamic BaseObject.set_config() to update a config that is also + # part of global config + obj = ConfigTester(4242) + current_global_config = get_config().copy() + expected_global_config = set(current_global_config.keys()) + + # Verify set config returns the original object + setconfig_return = obj.set_config(print_changed_only=False) + assert obj is setconfig_return + + updated_obj_config = obj.get_config() + assert isinstance(updated_obj_config, dict) + assert set(updated_obj_config.keys()) == ( + expected_global_config | {"foo_config", "bar"} + ) + assert updated_obj_config["foo_config"] == 42 + assert updated_obj_config["bar"] == "a" + assert updated_obj_config["print_changed_only"] is False + for param_name in current_global_config: + if param_name != "print_changed_only": + assert updated_obj_config[param_name] == current_global_config[param_name] + + # Case 5: local configs overlap with (will override) global params + # Then the local config defined in AnotherConfigTester._config is overridden again + # by calling AnotherConfigTester().set_config() + obj = AnotherConfigTester(4242) + obj.set_config(print_changed_only=True) + obj_config = obj.get_config() + current_global_config = get_config().copy() + expected_global_config = set(current_global_config.keys()) + assert isinstance(obj_config, dict) + assert 
set(obj_config.keys()) == expected_global_config | {"bar"} + assert obj_config["bar"] == "a" + # Should have overridden global config value which is set to True + assert obj_config["print_changed_only"] is True + for param_name in current_global_config: + if param_name != "print_changed_only": + assert obj_config[param_name] == current_global_config[param_name] + + +def test_local_config_with_use_of_extension_interface(): + """Test BaseObject local config interface when ``__skbase_get_config__`` defined. + + BaseObject.get_config() should return the global dict updated in the following + order: + + - Any config returned by ``BaseObject.__skbase_get_config__``. + - Any configs defined in ``BaseObject._config`` or set on the instance with + ``BaseObject().set_config()``. + """ + current_global_config = get_config().copy() + obj = ConfigExtensionInterfaceTester(4242) + obj_config = obj.get_config() + assert "some_other_config" in obj_config + assert obj_config["print_changed_only"] is False + + expected_global_config = set(current_global_config.keys()) + assert isinstance(obj_config, dict) + assert set(obj_config.keys()) == expected_global_config | { + "some_other_config", + "bar", + } + assert obj_config["bar"] == "a" + # Should have overridden global config value which is set to True + + for param_name in current_global_config: + if param_name != "print_changed_only": + assert obj_config[param_name] == current_global_config[param_name] + + # Now let's verify we can override the config items only returned by + # __skbase_get_config__ extension interface + obj.set_config(some_other_config=22) + obj_config_updated = obj.get_config() + assert obj_config_updated["some_other_config"] == 22 + for param_name in obj_config: + if param_name != "some_other_config": + assert obj_config_updated[param_name] == obj_config[param_name] + + +def local_config_interface_does_not_affect_global_config_interface(): + """Test that calls to instance config interface don't impact global 
config.""" + from skbase.config import get_default_config, global_config + + obj = AnotherConfigTester(4242) + _global_config_before = global_config.copy() + _default_config_before = get_default_config() + global_config_before = get_config().copy() + obj.set_config(print_changed_only=False, some_other_param=7) + global_config_after = get_config().copy() + assert global_config_before == global_config_after + assert "some_other_param" not in global_config_after + assert _global_config_before == global_config + assert _default_config_before == get_default_config() class FittableCompositionDummy(BaseEstimator): diff --git a/skbase/utils/_check.py b/skbase/utils/_check.py new file mode 100644 index 00000000..7a6830c4 --- /dev/null +++ b/skbase/utils/_check.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# copyright: skbase developers, BSD-3-Clause License (see LICENSE file) +# Elements of _is_scalar_nan re-use code developed in scikit-learn. These elements +# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For +# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING + +"""Utility functions to perform various types of checks.""" +from __future__ import annotations + +import math +import numbers +from typing import Any + +__all__ = ["_is_scalar_nan"] +__author__ = ["RNKuhns"] + + +def _is_scalar_nan(x: Any) -> bool: + """Test if x is NaN. + + This function is meant to overcome the issue that np.isnan does not allow + non-numerical types as input, and that np.nan is not float('nan'). + + Parameters + ---------- + x : Any + The item to be checked to determine if it is a scalar nan value. + + Returns + ------- + bool + True if `x` is a scalar nan value + + Notes + ----- + This code follows scikit-learn's implementation. 
+ + Examples + -------- + >>> import numpy as np + >>> from skbase.utils._check import _is_scalar_nan + >>> _is_scalar_nan(np.nan) + True + >>> _is_scalar_nan(float("nan")) + True + >>> _is_scalar_nan(None) + False + >>> _is_scalar_nan("") + False + >>> _is_scalar_nan([np.nan]) + False + """ + return isinstance(x, numbers.Real) and math.isnan(x) diff --git a/skbase/utils/tests/test_check.py b/skbase/utils/tests/test_check.py new file mode 100644 index 00000000..f998bae4 --- /dev/null +++ b/skbase/utils/tests/test_check.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python3 -u +# -*- coding: utf-8 -*- +# copyright: skbase developers, BSD-3-Clause License (see LICENSE file) +"""Tests of the utility functionality for performing various checks. + +tests in this module include: + +- test_is_scalar_nan_output to verify _is_scalar_nan outputs expected value for + different inputs. +""" +import numpy as np + +from skbase.utils._check import _is_scalar_nan + +__author__ = ["RNKuhns"] + + +def test_is_scalar_nan_output(): + """Test that _is_scalar_nan outputs expected value for different inputs.""" + assert _is_scalar_nan(np.nan) is True + assert _is_scalar_nan(float("nan")) is True + assert _is_scalar_nan(None) is False + assert _is_scalar_nan("") is False + assert _is_scalar_nan([np.nan]) is False