diff --git a/skbase/base/_base.py b/skbase/base/_base.py
index 970bf350..fb2cd2ed 100644
--- a/skbase/base/_base.py
+++ b/skbase/base/_base.py
@@ -53,6 +53,7 @@ class name: BaseEstimator
fitted state check - check_is_fitted (raises error if not is_fitted)
"""
import inspect
+import re
import warnings
from collections import defaultdict
from copy import deepcopy
@@ -62,6 +63,7 @@ class name: BaseEstimator
from sklearn.base import BaseEstimator as _BaseEstimator
from skbase._exceptions import NotFittedError
+from skbase.base._pretty_printing._object_html_repr import _object_html_repr
from skbase.base._tagmanager import _FlagManager
__author__: List[str] = ["mloning", "RNKuhns", "fkiraly"]
@@ -74,6 +76,11 @@ class BaseObject(_FlagManager, _BaseEstimator):
Extends scikit-learn's BaseEstimator to include sktime style interface for tags.
"""
+ _config = {
+ "display": "diagram",
+ "print_changed_only": True,
+ }
+
def __init__(self):
"""Construct BaseObject."""
self._init_flags(flag_attr_name="_tags")
@@ -682,6 +689,98 @@ def _components(self, base_class=None):
return comp_dict
+ def __repr__(self, n_char_max: int = 700):
+ """Represent class as string.
+
+ This follows the scikit-learn implementation for the string representation
+ of parameterized objects.
+
+ Parameters
+ ----------
+ n_char_max : int
+ Maximum (approximate) number of non-blank characters to render. This
+ can be useful in testing.
+ """
+ from skbase.base._pretty_printing._pprint import _BaseObjectPrettyPrinter
+
+ n_max_elements_to_show = 30 # number of elements to show in sequences
+ # use ellipsis for sequences with a lot of elements
+ pp = _BaseObjectPrettyPrinter(
+ compact=True,
+ indent=1,
+ indent_at_name=True,
+ n_max_elements_to_show=n_max_elements_to_show,
+ changed_only=self.get_config()["print_changed_only"],
+ )
+
+ repr_ = pp.pformat(self)
+
+ # Use bruteforce ellipsis when there are a lot of non-blank characters
+ n_nonblank = len("".join(repr_.split()))
+ if n_nonblank > n_char_max:
+ lim = n_char_max // 2 # apprx number of chars to keep on both ends
+ regex = r"^(\s*\S){%d}" % lim
+ # The regex '^(\s*\S){%d}' matches from the start of the string
+ # until the nth non-blank character:
+ # - ^ matches the start of string
+ # - (pattern){n} matches n repetitions of pattern
+ # - \s*\S matches a non-blank char following zero or more blanks
+ left_match = re.match(regex, repr_)
+ right_match = re.match(regex, repr_[::-1])
+ left_lim = left_match.end() if left_match is not None else 0
+ right_lim = right_match.end() if right_match is not None else 0
+
+ if "\n" in repr_[left_lim:-right_lim]:
+ # The left side and right side aren't on the same line.
+ # To avoid weird cuts, e.g.:
+ # categoric...ore',
+ # we need to start the right side with an appropriate newline
+ # character so that it renders properly as:
+ # categoric...
+ # handle_unknown='ignore',
+ # so we add [^\n]*\n which matches until the next \n
+ regex += r"[^\n]*\n"
+ right_match = re.match(regex, repr_[::-1])
+ right_lim = right_match.end() if right_match is not None else 0
+
+ ellipsis = "..."
+ if left_lim + len(ellipsis) < len(repr_) - right_lim:
+ # Only add ellipsis if it results in a shorter repr
+ repr_ = repr_[:left_lim] + "..." + repr_[-right_lim:]
+
+ return repr_
+
+ @property
+ def _repr_html_(self):
+ """HTML representation of BaseObject.
+
+ This is redundant with the logic of `_repr_mimebundle_`. The latter
+ should be favorted in the long term, `_repr_html_` is only
+ implemented for consumers who do not interpret `_repr_mimbundle_`.
+ """
+ if self.get_config()["display"] != "diagram":
+ raise AttributeError(
+ "_repr_html_ is only defined when the "
+ "`display` configuration option is set to 'diagram'."
+ )
+ return self._repr_html_inner
+
+ def _repr_html_inner(self):
+ """Return HTML representation of class.
+
+ This function is returned by the @property `_repr_html_` to make
+ `hasattr(BaseObject, "_repr_html_") return `True` or `False` depending
+ on `self.get_config()["display"]`.
+ """
+ return _object_html_repr(self)
+
+ def _repr_mimebundle_(self, **kwargs):
+ """Mime bundle used by jupyter kernels to display instances of BaseObject."""
+ output = {"text/plain": repr(self)}
+ if self.get_config()["display"] == "diagram":
+ output["text/html"] = _object_html_repr(self)
+ return output
+
class TagAliaserMixin:
"""Mixin class for tag aliasing and deprecation of old tags.
diff --git a/skbase/base/_pretty_printing/__init__.py b/skbase/base/_pretty_printing/__init__.py
new file mode 100644
index 00000000..669c800c
--- /dev/null
+++ b/skbase/base/_pretty_printing/__init__.py
@@ -0,0 +1,11 @@
+#!/usr/bin/env python3 -u
+# -*- coding: utf-8 -*-
+# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
+# Many elements of this code were developed in scikit-learn. These elements
+# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For
+# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
+"""Functionality for pretty printing BaseObjects."""
+from typing import List
+
+__author__: List[str] = ["RNKuhns"]
+__all__: List[str] = []
diff --git a/skbase/base/_pretty_printing/_object_html_repr.py b/skbase/base/_pretty_printing/_object_html_repr.py
new file mode 100644
index 00000000..397b289c
--- /dev/null
+++ b/skbase/base/_pretty_printing/_object_html_repr.py
@@ -0,0 +1,392 @@
+# -*- coding: utf-8 -*-
+# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
+# Many elements of this code were developed in scikit-learn. These elements
+# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For
+# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
+"""Functionality to represent instance of BaseObject as html."""
+
+import html
+import uuid
+from contextlib import closing, suppress
+from io import StringIO
+from string import Template
+
+__author__ = ["RNKuhns"]
+
+
+class _VisualBlock:
+ """HTML Representation of BaseObject.
+
+ Parameters
+ ----------
+ kind : {'serial', 'parallel', 'single'}
+ kind of HTML block
+
+ objs : list of BaseObjects or `_VisualBlock`s or a single BaseObject
+ If kind != 'single', then `objs` is a list of
+ BaseObjects. If kind == 'single', then `objs` is a single BaseObject.
+
+ names : list of str, default=None
+ If kind != 'single', then `names` corresponds to BaseObjects.
+ If kind == 'single', then `names` is a single string corresponding to
+ the single BaseObject.
+
+ name_details : list of str, str, or None, default=None
+ If kind != 'single', then `name_details` corresponds to `names`.
+ If kind == 'single', then `name_details` is a single string
+ corresponding to the single BaseObject.
+
+ dash_wrapped : bool, default=True
+ If true, wrapped HTML element will be wrapped with a dashed border.
+ Only active when kind != 'single'.
+ """
+
+ def __init__(self, kind, objs, *, names=None, name_details=None, dash_wrapped=True):
+ self.kind = kind
+ self.objs = objs
+ self.dash_wrapped = dash_wrapped
+
+ if self.kind in ("parallel", "serial"):
+ if names is None:
+ names = (None,) * len(objs)
+ if name_details is None:
+ name_details = (None,) * len(objs)
+
+ self.names = names
+ self.name_details = name_details
+
+ def _sk_visual_block_(self):
+ return self
+
+
+def _write_label_html(
+ out,
+ name,
+ name_details,
+ outer_class="sk-label-container",
+ inner_class="sk-label",
+ checked=False,
+):
+ """Write labeled html with or without a dropdown with named details."""
+ out.write(f'
')
+ name = html.escape(name)
+
+ if name_details is not None:
+ name_details = html.escape(str(name_details))
+ label_class = "sk-toggleable__label sk-toggleable__label-arrow"
+
+ checked_str = "checked" if checked else ""
+ est_id = uuid.uuid4()
+ out.write(
+ ''
+ f""
+ f'
{name_details}'
+ "
"
+ )
+ else:
+ out.write(f"")
+ out.write("
") # outer_class inner_class
+
+
+def _get_visual_block(base_object):
+ """Generate information about how to display a BaseObject."""
+ with suppress(AttributeError):
+ return base_object._sk_visual_block_()
+
+ if isinstance(base_object, str):
+ return _VisualBlock(
+ "single", base_object, names=base_object, name_details=base_object
+ )
+ elif base_object is None:
+ return _VisualBlock("single", base_object, names="None", name_details="None")
+
+ # check if base_object looks like a meta base_object wraps base_object
+ if hasattr(base_object, "get_params"):
+ base_objects = []
+ for key, value in base_object.get_params().items():
+ # Only look at the BaseObjects in the first layer
+ if "__" not in key and hasattr(value, "get_params"):
+ base_objects.append(value)
+ if len(base_objects):
+ return _VisualBlock("parallel", base_objects, names=None)
+
+ return _VisualBlock(
+ "single",
+ base_object,
+ names=base_object.__class__.__name__,
+ name_details=str(base_object),
+ )
+
+
+def _write_base_object_html(
+ out, base_object, base_object_label, base_object_label_details, first_call=False
+):
+ """Write BaseObject to html in serial, parallel, or by itself (single)."""
+ est_block = _get_visual_block(base_object)
+
+ if est_block.kind in ("serial", "parallel"):
+ dashed_wrapped = first_call or est_block.dash_wrapped
+ dash_cls = " sk-dashed-wrapped" if dashed_wrapped else ""
+ out.write(f'
")
+
+ html_output = out.getvalue()
+ return html_output
diff --git a/skbase/base/_pretty_printing/_pprint.py b/skbase/base/_pretty_printing/_pprint.py
new file mode 100644
index 00000000..2e8a8a80
--- /dev/null
+++ b/skbase/base/_pretty_printing/_pprint.py
@@ -0,0 +1,412 @@
+# -*- coding: utf-8 -*-
+# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
+# Many elements of this code were developed in scikit-learn. These elements
+# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For
+# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
+"""Utility functionality for pretty-printing objects used in BaseObject.__repr__."""
+import inspect
+import pprint
+from collections import OrderedDict
+
+from skbase.base import BaseObject
+
+# from skbase.config import get_config
+from skbase.utils._check import _is_scalar_nan
+
+
+class KeyValTuple(tuple):
+ """Dummy class for correctly rendering key-value tuples from dicts."""
+
+ def __repr__(self):
+ """Represent as string."""
+ # needed for _dispatch[tuple.__repr__] not to be overridden
+ return super().__repr__()
+
+
+class KeyValTupleParam(KeyValTuple):
+ """Dummy class for correctly rendering key-value tuples from parameters."""
+
+ pass
+
+
+def _changed_params(base_object):
+ """Return dict (param_name: value) of parameters with non-default values."""
+ params = base_object.get_params(deep=False)
+ init_func = getattr(
+ base_object.__init__, "deprecated_original", base_object.__init__
+ )
+ init_params = inspect.signature(init_func).parameters
+ init_params = {name: param.default for name, param in init_params.items()}
+
+ def has_changed(k, v):
+ if k not in init_params: # happens if k is part of a **kwargs
+ return True
+ if init_params[k] == inspect._empty: # k has no default value
+ return True
+ # try to avoid calling repr on nested BaseObjects
+ if isinstance(v, BaseObject) and v.__class__ != init_params[k].__class__:
+ return True
+ # Use repr as a last resort. It may be expensive.
+ if repr(v) != repr(init_params[k]) and not (
+ _is_scalar_nan(init_params[k]) and _is_scalar_nan(v)
+ ):
+ return True
+ return False
+
+ return {k: v for k, v in params.items() if has_changed(k, v)}
+
+
+class _BaseObjectPrettyPrinter(pprint.PrettyPrinter):
+ """Pretty Printer class for BaseObjects.
+
+ This extends the pprint.PrettyPrinter class similar to scikit-learn's
+ implementation, so that:
+
+ - BaseObjects are printed with their parameters, e.g.
+ BaseObject(param1=value1, ...) which is not supported by default.
+ - the 'compact' parameter of PrettyPrinter is ignored for dicts, which
+ may lead to very long representations that we want to avoid.
+
+ Quick overview of pprint.PrettyPrinter (see also
+ https://stackoverflow.com/questions/49565047/pprint-with-hex-numbers):
+
+ - the entry point is the _format() method which calls format() (overridden
+ here)
+ - format() directly calls _safe_repr() for a first try at rendering the
+ object
+ - _safe_repr formats the whole object recursively, only calling itself,
+ not caring about line length or anything
+ - back to _format(), if the output string is too long, _format() then calls
+ the appropriate _pprint_TYPE() method (e.g. _pprint_list()) depending on
+ the type of the object. This where the line length and the compact
+ parameters are taken into account.
+ - those _pprint_TYPE() methods will internally use the format() method for
+ rendering the nested objects of an object (e.g. the elements of a list)
+
+ In the end, everything has to be implemented twice: in _safe_repr and in
+ the custom _pprint_TYPE methods. Unfortunately PrettyPrinter is really not
+ straightforward to extend (especially when we want a compact output), so
+ the code is a bit convoluted.
+
+ This class overrides:
+ - format() to support the changed_only parameter
+ - _safe_repr to support printing of BaseObjects that fit on a single line
+ - _format_dict_items so that dict are correctly 'compacted'
+ - _format_items so that ellipsis is used on long lists and tuples
+
+ When BaseObjects cannot be printed on a single line, the builtin _format()
+ will call _pprint_object() because it was registered to do so (see
+ _dispatch[BaseObject.__repr__] = _pprint_object).
+
+ both _format_dict_items() and _pprint_Object() use the
+ _format_params_or_dict_items() method that will format parameters and
+ key-value pairs respecting the compact parameter. This method needs another
+ subroutine _pprint_key_val_tuple() used when a parameter or a key-value
+ pair is too long to fit on a single line. This subroutine is called in
+ _format() and is registered as well in the _dispatch dict (just like
+ _pprint_object). We had to create the two classes KeyValTuple and
+ KeyValTupleParam for this.
+ """
+
+ def __init__(
+ self,
+ indent=1,
+ width=80,
+ depth=None,
+ stream=None,
+ *,
+ compact=False,
+ indent_at_name=True,
+ n_max_elements_to_show=None,
+ changed_only=True,
+ ):
+ super().__init__(indent, width, depth, stream, compact=compact)
+ self._indent_at_name = indent_at_name
+ if self._indent_at_name:
+ self._indent_per_level = 1 # ignore indent param
+ self.changed_only = changed_only
+ # Max number of elements in a list, dict, tuple until we start using
+ # ellipsis. This also affects the number of arguments of a BaseObject
+ # (they are treated as dicts)
+ self.n_max_elements_to_show = n_max_elements_to_show
+
+ def format(self, obj, context, maxlevels, level): # noqa
+ return _safe_repr(
+ obj, context, maxlevels, level, changed_only=self.changed_only
+ )
+
+ def _pprint_object(self, obj, stream, indent, allowance, context, level):
+ stream.write(obj.__class__.__name__ + "(")
+ if self._indent_at_name:
+ indent += len(obj.__class__.__name__)
+
+ if self.changed_only:
+ params = _changed_params(obj)
+ else:
+ params = obj.get_params(deep=False)
+
+ params = OrderedDict((name, val) for (name, val) in sorted(params.items()))
+
+ self._format_params(
+ params.items(), stream, indent, allowance + 1, context, level
+ )
+ stream.write(")")
+
+ def _format_dict_items(self, items, stream, indent, allowance, context, level):
+ return self._format_params_or_dict_items(
+ items, stream, indent, allowance, context, level, is_dict=True
+ )
+
+ def _format_params(self, items, stream, indent, allowance, context, level):
+ return self._format_params_or_dict_items(
+ items, stream, indent, allowance, context, level, is_dict=False
+ )
+
+ def _format_params_or_dict_items(
+ self, obj, stream, indent, allowance, context, level, is_dict
+ ):
+ """Format dict items or parameters respecting the compact=True parameter.
+
+ For some reason, the builtin rendering of dict items doesn't
+ respect compact=True and will use one line per key-value if all cannot
+ fit in a single line.
+ Dict items will be rendered as <'key': value> while params will be
+ rendered as . The implementation is mostly copy/pasting from
+ the builtin _format_items().
+ This also adds ellipsis if the number of items is greater than
+ self.n_max_elements_to_show.
+ """
+ write = stream.write
+ indent += self._indent_per_level
+ delimnl = ",\n" + " " * indent
+ delim = ""
+ width = max_width = self._width - indent + 1
+ it = iter(obj)
+ try:
+ next_ent = next(it)
+ except StopIteration:
+ return
+ last = False
+ n_items = 0
+ while not last:
+ if n_items == self.n_max_elements_to_show:
+ write(", ...")
+ break
+ n_items += 1
+ ent = next_ent
+ try:
+ next_ent = next(it)
+ except StopIteration:
+ last = True
+ max_width -= allowance
+ width -= allowance
+ if self._compact:
+ k, v = ent
+ krepr = self._repr(k, context, level)
+ vrepr = self._repr(v, context, level)
+ if not is_dict:
+ krepr = krepr.strip("'")
+ middle = ": " if is_dict else "="
+ rep = krepr + middle + vrepr
+ w = len(rep) + 2
+ if width < w:
+ width = max_width
+ if delim:
+ delim = delimnl
+ if width >= w:
+ width -= w
+ write(delim)
+ delim = ", "
+ write(rep)
+ continue
+ write(delim)
+ delim = delimnl
+ class_ = KeyValTuple if is_dict else KeyValTupleParam
+ self._format(
+ class_(ent), stream, indent, allowance if last else 1, context, level
+ )
+
+ def _format_items(self, items, stream, indent, allowance, context, level):
+ """Format the items of an iterable (list, tuple...).
+
+ Same as the built-in _format_items, with support for ellipsis if the
+ number of elements is greater than self.n_max_elements_to_show.
+ """
+ write = stream.write
+ indent += self._indent_per_level
+ if self._indent_per_level > 1:
+ write((self._indent_per_level - 1) * " ")
+ delimnl = ",\n" + " " * indent
+ delim = ""
+ width = max_width = self._width - indent + 1
+ it = iter(items)
+ try:
+ next_ent = next(it)
+ except StopIteration:
+ return
+ last = False
+ n_items = 0
+ while not last:
+ if n_items == self.n_max_elements_to_show:
+ write(", ...")
+ break
+ n_items += 1
+ ent = next_ent
+ try:
+ next_ent = next(it)
+ except StopIteration:
+ last = True
+ max_width -= allowance
+ width -= allowance
+ if self._compact:
+ rep = self._repr(ent, context, level)
+ w = len(rep) + 2
+ if width < w:
+ width = max_width
+ if delim:
+ delim = delimnl
+ if width >= w:
+ width -= w
+ write(delim)
+ delim = ", "
+ write(rep)
+ continue
+ write(delim)
+ delim = delimnl
+ self._format(ent, stream, indent, allowance if last else 1, context, level)
+
+ def _pprint_key_val_tuple(self, obj, stream, indent, allowance, context, level):
+ """Pretty printing for key-value tuples from dict or parameters."""
+ k, v = obj
+ rep = self._repr(k, context, level)
+ if isinstance(obj, KeyValTupleParam):
+ rep = rep.strip("'")
+ middle = "="
+ else:
+ middle = ": "
+ stream.write(rep)
+ stream.write(middle)
+ self._format(
+ v, stream, indent + len(rep) + len(middle), allowance, context, level
+ )
+
+ # Follow what scikit-learn did here and copy _dispatch to prevent instances
+ # of the builtin PrettyPrinter class to call methods of
+ # _BaseObjectPrettyPrinter (see scikit-learn Github issue 12906)
+ # mypy error: "Type[PrettyPrinter]" has no attribute "_dispatch"
+ _dispatch = pprint.PrettyPrinter._dispatch.copy() # type: ignore
+ _dispatch[BaseObject.__repr__] = _pprint_object
+ _dispatch[KeyValTuple.__repr__] = _pprint_key_val_tuple
+
+
+def _safe_repr(obj, context, maxlevels, level, changed_only=False):
+ """Safe string representation logic.
+
+ Same as the builtin _safe_repr, with added support for BaseObjects.
+ """
+ typ = type(obj)
+
+ if typ in pprint._builtin_scalars:
+ return repr(obj), True, False
+
+ r = getattr(typ, "__repr__", None)
+ if issubclass(typ, dict) and r is dict.__repr__:
+ if not obj:
+ return "{}", True, False
+ objid = id(obj)
+ if maxlevels and level >= maxlevels:
+ return "{...}", False, objid in context
+ if objid in context:
+ return pprint._recursion(obj), False, True
+ context[objid] = 1
+ readable = True
+ recursive = False
+ components = []
+ append = components.append
+ level += 1
+ saferepr = _safe_repr
+ items = sorted(obj.items(), key=pprint._safe_tuple)
+ for k, v in items:
+ krepr, kreadable, krecur = saferepr(
+ k, context, maxlevels, level, changed_only=changed_only
+ )
+ vrepr, vreadable, vrecur = saferepr(
+ v, context, maxlevels, level, changed_only=changed_only
+ )
+ append("%s: %s" % (krepr, vrepr))
+ readable = readable and kreadable and vreadable
+ if krecur or vrecur:
+ recursive = True
+ del context[objid]
+ return "{%s}" % ", ".join(components), readable, recursive
+
+ if (issubclass(typ, list) and r is list.__repr__) or (
+ issubclass(typ, tuple) and r is tuple.__repr__
+ ):
+ if issubclass(typ, list):
+ if not obj:
+ return "[]", True, False
+ format_ = "[%s]"
+ elif len(obj) == 1:
+ format_ = "(%s,)"
+ else:
+ if not obj:
+ return "()", True, False
+ format_ = "(%s)"
+ objid = id(obj)
+ if maxlevels and level >= maxlevels:
+ return format_ % "...", False, objid in context
+ if objid in context:
+ return pprint._recursion(obj), False, True
+ context[objid] = 1
+ readable = True
+ recursive = False
+ components = []
+ append = components.append
+ level += 1
+ for o in obj:
+ orepr, oreadable, orecur = _safe_repr(
+ o, context, maxlevels, level, changed_only=changed_only
+ )
+ append(orepr)
+ if not oreadable:
+ readable = False
+ if orecur:
+ recursive = True
+ del context[objid]
+ return format_ % ", ".join(components), readable, recursive
+
+ if issubclass(typ, BaseObject):
+ objid = id(obj)
+ if maxlevels and level >= maxlevels:
+ return "{...}", False, objid in context
+ if objid in context:
+ return pprint._recursion(obj), False, True
+ context[objid] = 1
+ readable = True
+ recursive = False
+ if changed_only:
+ params = _changed_params(obj)
+ else:
+ params = obj.get_params(deep=False)
+ components = []
+ append = components.append
+ level += 1
+ saferepr = _safe_repr
+ items = sorted(params.items(), key=pprint._safe_tuple)
+ for k, v in items:
+ krepr, kreadable, krecur = saferepr(
+ k, context, maxlevels, level, changed_only=changed_only
+ )
+ vrepr, vreadable, vrecur = saferepr(
+ v, context, maxlevels, level, changed_only=changed_only
+ )
+ append("%s=%s" % (krepr.strip("'"), vrepr))
+ readable = readable and kreadable and vreadable
+ if krecur or vrecur:
+ recursive = True
+ del context[objid]
+ return ("%s(%s)" % (typ.__name__, ", ".join(components)), readable, recursive)
+
+ rep = repr(obj)
+ return rep, (rep and not rep.startswith("<")), False
diff --git a/skbase/tests/conftest.py b/skbase/tests/conftest.py
index 0b9bd15f..7a0508ef 100644
--- a/skbase/tests/conftest.py
+++ b/skbase/tests/conftest.py
@@ -22,6 +22,9 @@
"skbase.base",
"skbase.base._base",
"skbase.base._meta",
+ "skbase.base._pretty_printing",
+ "skbase.base._pretty_printing._object_html_repr",
+ "skbase.base._pretty_printing._pprint",
"skbase.base._tagmanager",
"skbase.lookup",
"skbase.lookup.tests",
@@ -42,6 +45,7 @@
"skbase.tests.test_baseestimator",
"skbase.tests.mock_package.test_mock_package",
"skbase.utils",
+ "skbase.utils._check",
"skbase.utils._iter",
"skbase.utils._nested_iter",
"skbase.utils._utils",
@@ -80,6 +84,7 @@
),
"skbase.base._base": ("BaseEstimator", "BaseObject"),
"skbase.base._meta": ("BaseMetaObject", "BaseMetaEstimator"),
+ "skbase.base._pretty_printing._pprint": ("KeyValTuple", "KeyValTupleParam"),
"skbase.lookup._lookup": ("ClassInfo", "FunctionInfo", "ModuleInfo"),
"skbase.testing": ("BaseFixtureGenerator", "QuickTester", "TestAllObjects"),
"skbase.testing.test_all_objects": (
@@ -96,6 +101,12 @@
"BaseMetaEstimator",
"_MetaObjectMixin",
),
+ "skbase.base._pretty_printing._object_html_repr": ("_VisualBlock",),
+ "skbase.base._pretty_printing._pprint": (
+ "KeyValTuple",
+ "KeyValTupleParam",
+ "_BaseObjectPrettyPrinter",
+ ),
"skbase.base._tagmanager": ("_FlagManager",),
}
)
@@ -140,6 +151,13 @@
SKBASE_FUNCTIONS_BY_MODULE = SKBASE_PUBLIC_FUNCTIONS_BY_MODULE.copy()
SKBASE_FUNCTIONS_BY_MODULE.update(
{
+ "skbase.base._pretty_printing._object_html_repr": (
+ "_get_visual_block",
+ "_object_html_repr",
+ "_write_base_object_html",
+ "_write_label_html",
+ ),
+ "skbase.base._pretty_printing._pprint": ("_changed_params", "_safe_repr"),
"skbase.lookup._lookup": (
"_determine_module_path",
"_get_return_tags",
@@ -171,6 +189,7 @@
"_coerce_list",
),
"skbase.testing.utils.inspect": ("_get_args",),
+ "skbase.utils._check": ("_is_scalar_nan",),
"skbase.utils._iter": (
"_format_seq_to_str",
"_remove_type_text",
diff --git a/skbase/tests/test_base.py b/skbase/tests/test_base.py
index 38d2d5a9..a5dc5341 100644
--- a/skbase/tests/test_base.py
+++ b/skbase/tests/test_base.py
@@ -69,11 +69,11 @@
import inspect
from copy import deepcopy
+from typing import Any, Dict, Type
import numpy as np
import pytest
import scipy.sparse as sp
-from sklearn import config_context
# TODO: Update with import of skbase clone function once implemented
from sklearn.base import clone
@@ -200,13 +200,13 @@ def fixture_reset_tester():
@pytest.fixture
-def fixture_class_child_tags(fixture_class_child):
+def fixture_class_child_tags(fixture_class_child: Type[Child]):
"""Pytest fixture for tags of Child."""
return fixture_class_child.get_class_tags()
@pytest.fixture
-def fixture_object_instance_set_tags(fixture_tag_class_object):
+def fixture_object_instance_set_tags(fixture_tag_class_object: Child):
"""Fixture class instance to test tag setting."""
fixture_tag_set = {"A": 42424243, "E": 3}
return fixture_tag_class_object.set_tags(**fixture_tag_set)
@@ -266,7 +266,9 @@ def fixture_class_instance_no_param_interface():
return NoParamInterface()
-def test_get_class_tags(fixture_class_child, fixture_class_child_tags):
+def test_get_class_tags(
+ fixture_class_child: Type[Child], fixture_class_child_tags: Any
+):
"""Test get_class_tags class method of BaseObject for correctness.
Raises
@@ -280,7 +282,7 @@ def test_get_class_tags(fixture_class_child, fixture_class_child_tags):
assert child_tags == fixture_class_child_tags, msg
-def test_get_class_tag(fixture_class_child, fixture_class_child_tags):
+def test_get_class_tag(fixture_class_child: Type[Child], fixture_class_child_tags: Any):
"""Test get_class_tag class method of BaseObject for correctness.
Raises
@@ -307,7 +309,7 @@ def test_get_class_tag(fixture_class_child, fixture_class_child_tags):
assert child_tag_default_none is None, msg
-def test_get_tags(fixture_tag_class_object, fixture_object_tags):
+def test_get_tags(fixture_tag_class_object: Child, fixture_object_tags: Dict[str, Any]):
"""Test get_tags method of BaseObject for correctness.
Raises
@@ -321,7 +323,7 @@ def test_get_tags(fixture_tag_class_object, fixture_object_tags):
assert object_tags == fixture_object_tags, msg
-def test_get_tag(fixture_tag_class_object, fixture_object_tags):
+def test_get_tag(fixture_tag_class_object: Child, fixture_object_tags: Dict[str, Any]):
"""Test get_tag method of BaseObject for correctness.
Raises
@@ -351,7 +353,7 @@ def test_get_tag(fixture_tag_class_object, fixture_object_tags):
assert object_tag_default_none is None, msg
-def test_get_tag_raises(fixture_tag_class_object):
+def test_get_tag_raises(fixture_tag_class_object: Child):
"""Test that get_tag method raises error for unknown tag.
Raises
@@ -363,9 +365,9 @@ def test_get_tag_raises(fixture_tag_class_object):
def test_set_tags(
- fixture_object_instance_set_tags,
- fixture_object_set_tags,
- fixture_object_dynamic_tags,
+ fixture_object_instance_set_tags: Any,
+ fixture_object_set_tags: Dict[str, Any],
+ fixture_object_dynamic_tags: Dict[str, int],
):
"""Test set_tags method of BaseObject for correctness.
@@ -381,7 +383,9 @@ def test_set_tags(
assert fixture_object_instance_set_tags.get_tags() == fixture_object_set_tags, msg
-def test_set_tags_works_with_missing_tags_dynamic_attribute(fixture_tag_class_object):
+def test_set_tags_works_with_missing_tags_dynamic_attribute(
+ fixture_tag_class_object: Child,
+):
"""Test set_tags will still work if _tags_dynamic is missing."""
base_obj = deepcopy(fixture_tag_class_object)
delattr(base_obj, "_tags_dynamic")
@@ -460,7 +464,7 @@ class AnotherTestClass(BaseObject):
assert test_obj_tags.get(tag) == another_base_obj_tags[tag]
-def test_is_composite(fixture_composition_dummy):
+def test_is_composite(fixture_composition_dummy: Type[CompositionDummy]):
"""Test is_composite tag for correctness.
Raises
@@ -474,7 +478,11 @@ def test_is_composite(fixture_composition_dummy):
assert composite.is_composite()
-def test_components(fixture_object, fixture_class_parent, fixture_composition_dummy):
+def test_components(
+ fixture_object: Type[BaseObject],
+ fixture_class_parent: Type[Parent],
+ fixture_composition_dummy: Type[CompositionDummy],
+):
"""Test component retrieval.
Raises
@@ -507,7 +515,7 @@ def test_components(fixture_object, fixture_class_parent, fixture_composition_du
def test_components_raises_error_base_class_is_not_class(
- fixture_object, fixture_composition_dummy
+ fixture_object: Type[BaseObject], fixture_composition_dummy: Type[CompositionDummy]
):
"""Test _component method raises error if base_class param is not class."""
non_composite = fixture_composition_dummy(foo=42)
@@ -526,7 +534,7 @@ def test_components_raises_error_base_class_is_not_class(
def test_components_raises_error_base_class_is_not_baseobject_subclass(
- fixture_composition_dummy,
+ fixture_composition_dummy: Type[CompositionDummy],
):
"""Test _component method raises error if base_class is not BaseObject subclass."""
@@ -540,7 +548,7 @@ class SomeClass:
# Test parameter interface (get_params, set_params, reset and related methods)
# Some tests of get_params and set_params are adapted from sklearn tests
-def test_reset(fixture_reset_tester):
+def test_reset(fixture_reset_tester: Type[ResetTester]):
"""Test reset method for correct behaviour, on a simple estimator.
Raises
@@ -567,7 +575,7 @@ def test_reset(fixture_reset_tester):
assert hasattr(x, "foo")
-def test_reset_composite(fixture_reset_tester):
+def test_reset_composite(fixture_reset_tester: Type[ResetTester]):
"""Test reset method for correct behaviour, on a composite estimator."""
y = fixture_reset_tester(42)
x = fixture_reset_tester(a=y)
@@ -582,7 +590,7 @@ def test_reset_composite(fixture_reset_tester):
assert not hasattr(x.a, "d")
-def test_get_init_signature(fixture_class_parent):
+def test_get_init_signature(fixture_class_parent: Type[Parent]):
"""Test error is raised when invalid init signature is used."""
init_sig = fixture_class_parent._get_init_signature()
init_sig_is_list = isinstance(init_sig, list)
@@ -594,14 +602,18 @@ def test_get_init_signature(fixture_class_parent):
), "`_get_init_signature` is not returning expected result."
-def test_get_init_signature_raises_error_for_invalid_signature(fixture_invalid_init):
+def test_get_init_signature_raises_error_for_invalid_signature(
+ fixture_invalid_init: Type[InvalidInitSignatureTester],
+):
"""Test error is raised when invalid init signature is used."""
with pytest.raises(RuntimeError):
fixture_invalid_init._get_init_signature()
def test_get_param_names(
- fixture_object, fixture_class_parent, fixture_class_parent_expected_params
+ fixture_object: Type[BaseObject],
+ fixture_class_parent: Type[Parent],
+ fixture_class_parent_expected_params: Dict[str, Any],
):
"""Test that get_param_names returns list of string parameter names."""
param_names = fixture_class_parent.get_param_names()
@@ -612,10 +624,10 @@ def test_get_param_names(
def test_get_params(
- fixture_class_parent,
- fixture_class_parent_expected_params,
- fixture_class_instance_no_param_interface,
- fixture_composition_dummy,
+ fixture_class_parent: Type[Parent],
+ fixture_class_parent_expected_params: Dict[str, Any],
+ fixture_class_instance_no_param_interface: NoParamInterface,
+ fixture_composition_dummy: Type[CompositionDummy],
):
"""Test get_params returns expected parameters."""
# Simple test of returned params
@@ -638,7 +650,10 @@ def test_get_params(
assert "foo" in params and "bar" in params and len(params) == 2
-def test_get_params_invariance(fixture_class_parent, fixture_composition_dummy):
+def test_get_params_invariance(
+ fixture_class_parent: Type[Parent],
+ fixture_composition_dummy: Type[CompositionDummy],
+):
"""Test that get_params(deep=False) is subset of get_params(deep=True)."""
composite = fixture_composition_dummy(foo=fixture_class_parent(), bar=84)
shallow_params = composite.get_params(deep=False)
@@ -646,7 +661,7 @@ def test_get_params_invariance(fixture_class_parent, fixture_composition_dummy):
assert all(item in deep_params.items() for item in shallow_params.items())
-def test_get_params_after_set_params(fixture_class_parent):
+def test_get_params_after_set_params(fixture_class_parent: Type[Parent]):
"""Test that get_params returns the same thing before and after set_params.
Based on scikit-learn check in check_estimator.
@@ -687,9 +702,9 @@ def test_get_params_after_set_params(fixture_class_parent):
def test_set_params(
- fixture_class_parent,
- fixture_class_parent_expected_params,
- fixture_composition_dummy,
+ fixture_class_parent: Type[Parent],
+ fixture_class_parent_expected_params: Dict[str, Any],
+ fixture_composition_dummy: Type[CompositionDummy],
):
"""Test set_params works as expected."""
# Simple case of setting a parameter
@@ -711,7 +726,8 @@ def test_set_params(
def test_set_params_raises_error_non_existent_param(
- fixture_class_parent_instance, fixture_composition_dummy
+ fixture_class_parent_instance: Parent,
+ fixture_composition_dummy: Type[CompositionDummy],
):
"""Test set_params raises an error when passed a non-existent parameter name."""
# non-existing parameter in svc
@@ -727,7 +743,8 @@ def test_set_params_raises_error_non_existent_param(
def test_set_params_raises_error_non_interface_composite(
- fixture_class_instance_no_param_interface, fixture_composition_dummy
+ fixture_class_instance_no_param_interface: NoParamInterface,
+ fixture_composition_dummy: Type[CompositionDummy],
):
"""Test set_params raises error when setting param of non-conforming composite."""
# When a composite is made up of a class that doesn't have the BaseObject
@@ -753,7 +770,9 @@ def __init__(self, param=5):
est.get_params()
-def test_set_params_with_no_param_to_set_returns_object(fixture_class_parent):
+def test_set_params_with_no_param_to_set_returns_object(
+ fixture_class_parent: Type[Parent],
+):
"""Test set_params correctly returns self when no parameters are set."""
base_obj = fixture_class_parent()
orig_params = deepcopy(base_obj.get_params())
@@ -767,7 +786,7 @@ def test_set_params_with_no_param_to_set_returns_object(fixture_class_parent):
# This section tests the clone functionality
# These have been adapted from sklearn's tests of clone to use the clone
# method that is included as part of the BaseObject interface
-def test_clone(fixture_class_parent_instance):
+def test_clone(fixture_class_parent_instance: Parent):
"""Test that clone is making a deep copy as expected."""
# Creates a BaseObject and makes a copy of its original state
# (which, in this case, is the current state of the BaseObject),
@@ -777,7 +796,7 @@ def test_clone(fixture_class_parent_instance):
assert fixture_class_parent_instance.get_params() == new_base_obj.get_params()
-def test_clone_2(fixture_class_parent_instance):
+def test_clone_2(fixture_class_parent_instance: Parent):
"""Test that clone does not copy attributes not set in constructor."""
# We first create an estimator, give it an own attribute, and
# make a copy of its original state. Then we check that the copy doesn't
@@ -790,7 +809,9 @@ def test_clone_2(fixture_class_parent_instance):
def test_clone_raises_error_for_nonconforming_objects(
- fixture_invalid_init, fixture_buggy, fixture_modify_param
+ fixture_invalid_init: Type[InvalidInitSignatureTester],
+ fixture_buggy: Type[Buggy],
+ fixture_modify_param: Type[ModifyParam],
):
"""Test that clone raises an error on nonconforming BaseObjects."""
buggy = fixture_buggy()
@@ -807,7 +828,7 @@ def test_clone_raises_error_for_nonconforming_objects(
obj_that_modifies.clone()
-def test_clone_param_is_none(fixture_class_parent):
+def test_clone_param_is_none(fixture_class_parent: Type[Parent]):
"""Test clone with keyword parameter set to None."""
base_obj = fixture_class_parent(c=None)
new_base_obj = clone(base_obj)
@@ -816,7 +837,7 @@ def test_clone_param_is_none(fixture_class_parent):
assert base_obj.c is new_base_obj2.c
-def test_clone_empty_array(fixture_class_parent):
+def test_clone_empty_array(fixture_class_parent: Type[Parent]):
"""Test clone with keyword parameter is scipy sparse matrix.
This test is based on scikit-learn regression test to make sure clone
@@ -830,7 +851,7 @@ def test_clone_empty_array(fixture_class_parent):
np.testing.assert_array_equal(base_obj.c, new_base_obj2.c)
-def test_clone_sparse_matrix(fixture_class_parent):
+def test_clone_sparse_matrix(fixture_class_parent: Type[Parent]):
"""Test clone with keyword parameter is scipy sparse matrix.
This test is based on scikit-learn regression test to make sure clone
@@ -843,7 +864,7 @@ def test_clone_sparse_matrix(fixture_class_parent):
np.testing.assert_array_equal(base_obj.c, new_base_obj2.c)
-def test_clone_nan(fixture_class_parent):
+def test_clone_nan(fixture_class_parent: Type[Parent]):
"""Test clone with keyword parameter is np.nan.
This test is based on scikit-learn regression test to make sure clone
@@ -858,7 +879,7 @@ def test_clone_nan(fixture_class_parent):
assert base_obj.c is new_base_obj2.c
-def test_clone_estimator_types(fixture_class_parent):
+def test_clone_estimator_types(fixture_class_parent: Type[Parent]):
"""Test clone works for parameters that are types rather than instances."""
base_obj = fixture_class_parent(c=fixture_class_parent)
new_base_obj = base_obj.clone()
@@ -866,7 +887,9 @@ def test_clone_estimator_types(fixture_class_parent):
assert base_obj.c == new_base_obj.c
-def test_clone_class_rather_than_instance_raises_error(fixture_class_parent):
+def test_clone_class_rather_than_instance_raises_error(
+ fixture_class_parent: Type[Parent],
+):
"""Test clone raises expected error when cloning a class instead of an instance."""
msg = "You should provide an instance of scikit-learn estimator"
with pytest.raises(TypeError, match=msg):
@@ -874,17 +897,40 @@ def test_clone_class_rather_than_instance_raises_error(fixture_class_parent):
# Tests of BaseObject pretty printing representation inspired by sklearn
-def test_baseobject_repr(fixture_class_parent, fixture_composition_dummy):
+def test_baseobject_repr(
+ fixture_class_parent: Type[Parent],
+ fixture_composition_dummy: Type[CompositionDummy],
+):
"""Test BaseObject repr works as expected."""
# Simple test where all parameters are left at defaults
# Should not see parameters and values in printed representation
+
base_obj = fixture_class_parent()
assert repr(base_obj) == "Parent()"
- # Check that we can alter the detail about params that is printed
- # using config_context with ``print_changed_only=False``
- with config_context(print_changed_only=False):
- assert repr(base_obj) == "Parent(a='something', b=7, c=None)"
+ # Check that local config works as expected
+ base_obj.set_config(print_changed_only=False)
+ assert repr(base_obj) == "Parent(a='something', b=7, c=None)"
+
+ # Test with dict parameter (note that dict is sorted by keys when printed)
+ # not printed in order it was created
+ base_obj = fixture_class_parent(c={"c": 1, "a": 2})
+ assert repr(base_obj) == "Parent(c={'a': 2, 'c': 1})"
+
+ # Now test when one params values are named object tuples
+ named_objs = [
+ ("step 1", fixture_class_parent()),
+ ("step 2", fixture_class_parent()),
+ ]
+ base_obj = fixture_class_parent(c=named_objs)
+ assert repr(base_obj) == "Parent(c=[('step 1', Parent()), ('step 2', Parent())])"
+
+ # Or when they are just lists of tuples or just tuples as param
+ base_obj = fixture_class_parent(c=[("one", 1), ("two", 2)])
+ assert repr(base_obj) == "Parent(c=[('one', 1), ('two', 2)])"
+
+ base_obj = fixture_class_parent(c=(1, 2, 3))
+ assert repr(base_obj) == "Parent(c=(1, 2, 3))"
simple_composite = fixture_composition_dummy(foo=fixture_class_parent())
assert repr(simple_composite) == "CompositionDummy(foo=Parent())"
@@ -892,53 +938,67 @@ def test_baseobject_repr(fixture_class_parent, fixture_composition_dummy):
long_base_obj_repr = fixture_class_parent(a=["long_params"] * 1000)
assert len(repr(long_base_obj_repr)) == 535
+ named_objs = [(f"Step {i+1}", Child()) for i in range(25)]
+ base_comp = CompositionDummy(foo=Parent(c=Child(c=named_objs)))
+ assert len(repr(base_comp)) == 1362
+
-def test_baseobject_str(fixture_class_parent_instance):
+def test_baseobject_str(fixture_class_parent_instance: Parent):
"""Test BaseObject string representation works."""
- str(fixture_class_parent_instance)
+ assert (
+ str(fixture_class_parent_instance) == "Parent()"
+ ), "String representation of instance not working."
+
+ # Check that local config works as expected
+ fixture_class_parent_instance.set_config(print_changed_only=False)
+ assert str(fixture_class_parent_instance) == "Parent(a='something', b=7, c=None)"
-def test_baseobject_repr_mimebundle_(fixture_class_parent_instance):
+def test_baseobject_repr_mimebundle_(fixture_class_parent_instance: Parent):
"""Test display configuration controls output."""
# Checks the display configuration flag controls the json output
- with config_context(display="diagram"):
- output = fixture_class_parent_instance._repr_mimebundle_()
- assert "text/plain" in output
- assert "text/html" in output
+ fixture_class_parent_instance.set_config(display="diagram")
+ output = fixture_class_parent_instance._repr_mimebundle_()
+ assert "text/plain" in output
+ assert "text/html" in output
- with config_context(display="text"):
- output = fixture_class_parent_instance._repr_mimebundle_()
- assert "text/plain" in output
- assert "text/html" not in output
+ fixture_class_parent_instance.set_config(display="text")
+ output = fixture_class_parent_instance._repr_mimebundle_()
+ assert "text/plain" in output
+ assert "text/html" not in output
-def test_repr_html_wraps(fixture_class_parent_instance):
+def test_repr_html_wraps(fixture_class_parent_instance: Parent):
"""Test display configuration flag controls the html output."""
- with config_context(display="diagram"):
- output = fixture_class_parent_instance._repr_html_()
- assert "