diff --git a/skbase/base/_base.py b/skbase/base/_base.py index 970bf350..fb2cd2ed 100644 --- a/skbase/base/_base.py +++ b/skbase/base/_base.py @@ -53,6 +53,7 @@ class name: BaseEstimator fitted state check - check_is_fitted (raises error if not is_fitted) """ import inspect +import re import warnings from collections import defaultdict from copy import deepcopy @@ -62,6 +63,7 @@ class name: BaseEstimator from sklearn.base import BaseEstimator as _BaseEstimator from skbase._exceptions import NotFittedError +from skbase.base._pretty_printing._object_html_repr import _object_html_repr from skbase.base._tagmanager import _FlagManager __author__: List[str] = ["mloning", "RNKuhns", "fkiraly"] @@ -74,6 +76,11 @@ class BaseObject(_FlagManager, _BaseEstimator): Extends scikit-learn's BaseEstimator to include sktime style interface for tags. """ + _config = { + "display": "diagram", + "print_changed_only": True, + } + def __init__(self): """Construct BaseObject.""" self._init_flags(flag_attr_name="_tags") @@ -682,6 +689,98 @@ def _components(self, base_class=None): return comp_dict + def __repr__(self, n_char_max: int = 700): + """Represent class as string. + + This follows the scikit-learn implementation for the string representation + of parameterized objects. + + Parameters + ---------- + n_char_max : int + Maximum (approximate) number of non-blank characters to render. This + can be useful in testing. + """ + from skbase.base._pretty_printing._pprint import _BaseObjectPrettyPrinter + + n_max_elements_to_show = 30 # number of elements to show in sequences + # use ellipsis for sequences with a lot of elements + pp = _BaseObjectPrettyPrinter( + compact=True, + indent=1, + indent_at_name=True, + n_max_elements_to_show=n_max_elements_to_show, + changed_only=self.get_config()["print_changed_only"], + ) + + repr_ = pp.pformat(self) + + # Use bruteforce ellipsis when there are a lot of non-blank characters + n_nonblank = len("".join(repr_.split())) + if n_nonblank > n_char_max: + lim = n_char_max // 2 # apprx number of chars to keep on both ends + regex = r"^(\s*\S){%d}" % lim + # The regex '^(\s*\S){%d}' matches from the start of the string + # until the nth non-blank character: + # - ^ matches the start of string + # - (pattern){n} matches n repetitions of pattern + # - \s*\S matches a non-blank char following zero or more blanks + left_match = re.match(regex, repr_) + right_match = re.match(regex, repr_[::-1]) + left_lim = left_match.end() if left_match is not None else 0 + right_lim = right_match.end() if right_match is not None else 0 + + if "\n" in repr_[left_lim:-right_lim]: + # The left side and right side aren't on the same line. + # To avoid weird cuts, e.g.: + # categoric...ore', + # we need to start the right side with an appropriate newline + # character so that it renders properly as: + # categoric... + # handle_unknown='ignore', + # so we add [^\n]*\n which matches until the next \n + regex += r"[^\n]*\n" + right_match = re.match(regex, repr_[::-1]) + right_lim = right_match.end() if right_match is not None else 0 + + ellipsis = "..." + if left_lim + len(ellipsis) < len(repr_) - right_lim: + # Only add ellipsis if it results in a shorter repr + repr_ = repr_[:left_lim] + "..." + repr_[-right_lim:] + + return repr_ + + @property + def _repr_html_(self): + """HTML representation of BaseObject. + + This is redundant with the logic of `_repr_mimebundle_`. The latter + should be favorted in the long term, `_repr_html_` is only + implemented for consumers who do not interpret `_repr_mimbundle_`. + """ + if self.get_config()["display"] != "diagram": + raise AttributeError( + "_repr_html_ is only defined when the " + "`display` configuration option is set to 'diagram'." + ) + return self._repr_html_inner + + def _repr_html_inner(self): + """Return HTML representation of class. + + This function is returned by the @property `_repr_html_` to make + `hasattr(BaseObject, "_repr_html_") return `True` or `False` depending + on `self.get_config()["display"]`. + """ + return _object_html_repr(self) + + def _repr_mimebundle_(self, **kwargs): + """Mime bundle used by jupyter kernels to display instances of BaseObject.""" + output = {"text/plain": repr(self)} + if self.get_config()["display"] == "diagram": + output["text/html"] = _object_html_repr(self) + return output + class TagAliaserMixin: """Mixin class for tag aliasing and deprecation of old tags. diff --git a/skbase/base/_pretty_printing/__init__.py b/skbase/base/_pretty_printing/__init__.py new file mode 100644 index 00000000..669c800c --- /dev/null +++ b/skbase/base/_pretty_printing/__init__.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python3 -u +# -*- coding: utf-8 -*- +# copyright: skbase developers, BSD-3-Clause License (see LICENSE file) +# Many elements of this code were developed in scikit-learn. These elements +# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For +# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING +"""Functionality for pretty printing BaseObjects.""" +from typing import List + +__author__: List[str] = ["RNKuhns"] +__all__: List[str] = [] diff --git a/skbase/base/_pretty_printing/_object_html_repr.py b/skbase/base/_pretty_printing/_object_html_repr.py new file mode 100644 index 00000000..397b289c --- /dev/null +++ b/skbase/base/_pretty_printing/_object_html_repr.py @@ -0,0 +1,392 @@ +# -*- coding: utf-8 -*- +# copyright: skbase developers, BSD-3-Clause License (see LICENSE file) +# Many elements of this code were developed in scikit-learn. These elements +# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For +# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING +"""Functionality to represent instance of BaseObject as html.""" + +import html +import uuid +from contextlib import closing, suppress +from io import StringIO +from string import Template + +__author__ = ["RNKuhns"] + + +class _VisualBlock: + """HTML Representation of BaseObject. + + Parameters + ---------- + kind : {'serial', 'parallel', 'single'} + kind of HTML block + + objs : list of BaseObjects or `_VisualBlock`s or a single BaseObject + If kind != 'single', then `objs` is a list of + BaseObjects. If kind == 'single', then `objs` is a single BaseObject. + + names : list of str, default=None + If kind != 'single', then `names` corresponds to BaseObjects. + If kind == 'single', then `names` is a single string corresponding to + the single BaseObject. + + name_details : list of str, str, or None, default=None + If kind != 'single', then `name_details` corresponds to `names`. + If kind == 'single', then `name_details` is a single string + corresponding to the single BaseObject. + + dash_wrapped : bool, default=True + If true, wrapped HTML element will be wrapped with a dashed border. + Only active when kind != 'single'. + """ + + def __init__(self, kind, objs, *, names=None, name_details=None, dash_wrapped=True): + self.kind = kind + self.objs = objs + self.dash_wrapped = dash_wrapped + + if self.kind in ("parallel", "serial"): + if names is None: + names = (None,) * len(objs) + if name_details is None: + name_details = (None,) * len(objs) + + self.names = names + self.name_details = name_details + + def _sk_visual_block_(self): + return self + + +def _write_label_html( + out, + name, + name_details, + outer_class="sk-label-container", + inner_class="sk-label", + checked=False, +): + """Write labeled html with or without a dropdown with named details.""" + out.write(f'
') + name = html.escape(name) + + if name_details is not None: + name_details = html.escape(str(name_details)) + label_class = "sk-toggleable__label sk-toggleable__label-arrow" + + checked_str = "checked" if checked else "" + est_id = uuid.uuid4() + out.write( + '' + f"" + f'
{name_details}'
+            "
" + ) + else: + out.write(f"") + out.write("
") # outer_class inner_class + + +def _get_visual_block(base_object): + """Generate information about how to display a BaseObject.""" + with suppress(AttributeError): + return base_object._sk_visual_block_() + + if isinstance(base_object, str): + return _VisualBlock( + "single", base_object, names=base_object, name_details=base_object + ) + elif base_object is None: + return _VisualBlock("single", base_object, names="None", name_details="None") + + # check if base_object looks like a meta base_object wraps base_object + if hasattr(base_object, "get_params"): + base_objects = [] + for key, value in base_object.get_params().items(): + # Only look at the BaseObjects in the first layer + if "__" not in key and hasattr(value, "get_params"): + base_objects.append(value) + if len(base_objects): + return _VisualBlock("parallel", base_objects, names=None) + + return _VisualBlock( + "single", + base_object, + names=base_object.__class__.__name__, + name_details=str(base_object), + ) + + +def _write_base_object_html( + out, base_object, base_object_label, base_object_label_details, first_call=False +): + """Write BaseObject to html in serial, parallel, or by itself (single).""" + est_block = _get_visual_block(base_object) + + if est_block.kind in ("serial", "parallel"): + dashed_wrapped = first_call or est_block.dash_wrapped + dash_cls = " sk-dashed-wrapped" if dashed_wrapped else "" + out.write(f'
') + + if base_object_label: + _write_label_html(out, base_object_label, base_object_label_details) + + kind = est_block.kind + out.write(f'
') + est_infos = zip(est_block.objs, est_block.names, est_block.name_details) + + for est, name, name_details in est_infos: + if kind == "serial": + _write_base_object_html(out, est, name, name_details) + else: # parallel + out.write('
') + # wrap element in a serial visualblock + serial_block = _VisualBlock("serial", [est], dash_wrapped=False) + _write_base_object_html(out, serial_block, name, name_details) + out.write("
") # sk-parallel-item + + out.write("
") + elif est_block.kind == "single": + _write_label_html( + out, + est_block.names, + est_block.name_details, + outer_class="sk-item", + inner_class="sk-estimator", + checked=first_call, + ) + + +_STYLE = """ +#$id { + color: black; + background-color: white; +} +#$id pre{ + padding: 0; +} +#$id div.sk-toggleable { + background-color: white; +} +#$id label.sk-toggleable__label { + cursor: pointer; + display: block; + width: 100%; + margin-bottom: 0; + padding: 0.3em; + box-sizing: border-box; + text-align: center; +} +#$id label.sk-toggleable__label-arrow:before { + content: "▸"; + float: left; + margin-right: 0.25em; + color: #696969; +} +#$id label.sk-toggleable__label-arrow:hover:before { + color: black; +} +#$id div.sk-estimator:hover label.sk-toggleable__label-arrow:before { + color: black; +} +#$id div.sk-toggleable__content { + max-height: 0; + max-width: 0; + overflow: hidden; + text-align: left; + background-color: #f0f8ff; +} +#$id div.sk-toggleable__content pre { + margin: 0.2em; + color: black; + border-radius: 0.25em; + background-color: #f0f8ff; +} +#$id input.sk-toggleable__control:checked~div.sk-toggleable__content { + max-height: 200px; + max-width: 100%; + overflow: auto; +} +#$id input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before { + content: "▾"; +} +#$id div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label { + background-color: #d4ebff; +} +#$id div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label { + background-color: #d4ebff; +} +#$id input.sk-hidden--visually { + border: 0; + clip: rect(1px 1px 1px 1px); + clip: rect(1px, 1px, 1px, 1px); + height: 1px; + margin: -1px; + overflow: hidden; + padding: 0; + position: absolute; + width: 1px; +} +#$id div.sk-estimator { + font-family: monospace; + background-color: #f0f8ff; + border: 1px dotted black; + border-radius: 0.25em; + box-sizing: border-box; + margin-bottom: 0.5em; +} +#$id div.sk-estimator:hover { + background-color: #d4ebff; +} +#$id div.sk-parallel-item::after { + content: ""; + width: 100%; + border-bottom: 1px solid gray; + flex-grow: 1; +} +#$id div.sk-label:hover label.sk-toggleable__label { + background-color: #d4ebff; +} +#$id div.sk-serial::before { + content: ""; + position: absolute; + border-left: 1px solid gray; + box-sizing: border-box; + top: 2em; + bottom: 0; + left: 50%; +} +#$id div.sk-serial { + display: flex; + flex-direction: column; + align-items: center; + background-color: white; + padding-right: 0.2em; + padding-left: 0.2em; +} +#$id div.sk-item { + z-index: 1; +} +#$id div.sk-parallel { + display: flex; + align-items: stretch; + justify-content: center; + background-color: white; +} +#$id div.sk-parallel::before { + content: ""; + position: absolute; + border-left: 1px solid gray; + box-sizing: border-box; + top: 2em; + bottom: 0; + left: 50%; +} +#$id div.sk-parallel-item { + display: flex; + flex-direction: column; + position: relative; + background-color: white; +} +#$id div.sk-parallel-item:first-child::after { + align-self: flex-end; + width: 50%; +} +#$id div.sk-parallel-item:last-child::after { + align-self: flex-start; + width: 50%; +} +#$id div.sk-parallel-item:only-child::after { + width: 0; +} +#$id div.sk-dashed-wrapped { + border: 1px dashed gray; + margin: 0 0.4em 0.5em 0.4em; + box-sizing: border-box; + padding-bottom: 0.4em; + background-color: white; + position: relative; +} +#$id div.sk-label label { + font-family: monospace; + font-weight: bold; + background-color: white; + display: inline-block; + line-height: 1.2em; +} +#$id div.sk-label-container { + position: relative; + z-index: 2; + text-align: center; +} +#$id div.sk-container { + /* jupyter's `normalize.less` sets `[hidden] { display: none; }` + but bootstrap.min.css set `[hidden] { display: none !important; }` + so we also need the `!important` here to be able to override the + default hidden behavior on the sphinx rendered scikit-learn.org. + See: https://github.com/scikit-learn/scikit-learn/issues/21755 */ + display: inline-block !important; + position: relative; +} +#$id div.sk-text-repr-fallback { + display: none; +} +""".replace( + " ", "" +).replace( + "\n", "" +) # noqa + + +def _object_html_repr(base_object): + """Build a HTML representation of a BaseObject. + + Parameters + ---------- + base_object : base object + The BaseObject or inheritting class to visualize. + + Returns + ------- + html: str + HTML representation of BaseObject. + """ + with closing(StringIO()) as out: + container_id = "sk-" + str(uuid.uuid4()) + style_template = Template(_STYLE) + style_with_id = style_template.substitute(id=container_id) + base_object_str = str(base_object) + + # The fallback message is shown by default and loading the CSS sets + # div.sk-text-repr-fallback to display: none to hide the fallback message. + # + # If the notebook is trusted, the CSS is loaded which hides the fallback + # message. If the notebook is not trusted, then the CSS is not loaded and the + # fallback message is shown by default. + # + # The reverse logic applies to HTML repr div.sk-container. + # div.sk-container is hidden by default and the loading the CSS displays it. + fallback_msg = ( + "Please rerun this cell to show the HTML repr or trust the notebook." + ) + out.write( + f"" + f'
' + '
' + f"
{html.escape(base_object_str)}
{fallback_msg}" + "
" + '
") + + html_output = out.getvalue() + return html_output diff --git a/skbase/base/_pretty_printing/_pprint.py b/skbase/base/_pretty_printing/_pprint.py new file mode 100644 index 00000000..2e8a8a80 --- /dev/null +++ b/skbase/base/_pretty_printing/_pprint.py @@ -0,0 +1,412 @@ +# -*- coding: utf-8 -*- +# copyright: skbase developers, BSD-3-Clause License (see LICENSE file) +# Many elements of this code were developed in scikit-learn. These elements +# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For +# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING +"""Utility functionality for pretty-printing objects used in BaseObject.__repr__.""" +import inspect +import pprint +from collections import OrderedDict + +from skbase.base import BaseObject + +# from skbase.config import get_config +from skbase.utils._check import _is_scalar_nan + + +class KeyValTuple(tuple): + """Dummy class for correctly rendering key-value tuples from dicts.""" + + def __repr__(self): + """Represent as string.""" + # needed for _dispatch[tuple.__repr__] not to be overridden + return super().__repr__() + + +class KeyValTupleParam(KeyValTuple): + """Dummy class for correctly rendering key-value tuples from parameters.""" + + pass + + +def _changed_params(base_object): + """Return dict (param_name: value) of parameters with non-default values.""" + params = base_object.get_params(deep=False) + init_func = getattr( + base_object.__init__, "deprecated_original", base_object.__init__ + ) + init_params = inspect.signature(init_func).parameters + init_params = {name: param.default for name, param in init_params.items()} + + def has_changed(k, v): + if k not in init_params: # happens if k is part of a **kwargs + return True + if init_params[k] == inspect._empty: # k has no default value + return True + # try to avoid calling repr on nested BaseObjects + if isinstance(v, BaseObject) and v.__class__ != init_params[k].__class__: + return True + # Use repr as a last resort. It may be expensive. + if repr(v) != repr(init_params[k]) and not ( + _is_scalar_nan(init_params[k]) and _is_scalar_nan(v) + ): + return True + return False + + return {k: v for k, v in params.items() if has_changed(k, v)} + + +class _BaseObjectPrettyPrinter(pprint.PrettyPrinter): + """Pretty Printer class for BaseObjects. + + This extends the pprint.PrettyPrinter class similar to scikit-learn's + implementation, so that: + + - BaseObjects are printed with their parameters, e.g. + BaseObject(param1=value1, ...) which is not supported by default. + - the 'compact' parameter of PrettyPrinter is ignored for dicts, which + may lead to very long representations that we want to avoid. + + Quick overview of pprint.PrettyPrinter (see also + https://stackoverflow.com/questions/49565047/pprint-with-hex-numbers): + + - the entry point is the _format() method which calls format() (overridden + here) + - format() directly calls _safe_repr() for a first try at rendering the + object + - _safe_repr formats the whole object recursively, only calling itself, + not caring about line length or anything + - back to _format(), if the output string is too long, _format() then calls + the appropriate _pprint_TYPE() method (e.g. _pprint_list()) depending on + the type of the object. This where the line length and the compact + parameters are taken into account. + - those _pprint_TYPE() methods will internally use the format() method for + rendering the nested objects of an object (e.g. the elements of a list) + + In the end, everything has to be implemented twice: in _safe_repr and in + the custom _pprint_TYPE methods. Unfortunately PrettyPrinter is really not + straightforward to extend (especially when we want a compact output), so + the code is a bit convoluted. + + This class overrides: + - format() to support the changed_only parameter + - _safe_repr to support printing of BaseObjects that fit on a single line + - _format_dict_items so that dict are correctly 'compacted' + - _format_items so that ellipsis is used on long lists and tuples + + When BaseObjects cannot be printed on a single line, the builtin _format() + will call _pprint_object() because it was registered to do so (see + _dispatch[BaseObject.__repr__] = _pprint_object). + + both _format_dict_items() and _pprint_Object() use the + _format_params_or_dict_items() method that will format parameters and + key-value pairs respecting the compact parameter. This method needs another + subroutine _pprint_key_val_tuple() used when a parameter or a key-value + pair is too long to fit on a single line. This subroutine is called in + _format() and is registered as well in the _dispatch dict (just like + _pprint_object). We had to create the two classes KeyValTuple and + KeyValTupleParam for this. + """ + + def __init__( + self, + indent=1, + width=80, + depth=None, + stream=None, + *, + compact=False, + indent_at_name=True, + n_max_elements_to_show=None, + changed_only=True, + ): + super().__init__(indent, width, depth, stream, compact=compact) + self._indent_at_name = indent_at_name + if self._indent_at_name: + self._indent_per_level = 1 # ignore indent param + self.changed_only = changed_only + # Max number of elements in a list, dict, tuple until we start using + # ellipsis. This also affects the number of arguments of a BaseObject + # (they are treated as dicts) + self.n_max_elements_to_show = n_max_elements_to_show + + def format(self, obj, context, maxlevels, level): # noqa + return _safe_repr( + obj, context, maxlevels, level, changed_only=self.changed_only + ) + + def _pprint_object(self, obj, stream, indent, allowance, context, level): + stream.write(obj.__class__.__name__ + "(") + if self._indent_at_name: + indent += len(obj.__class__.__name__) + + if self.changed_only: + params = _changed_params(obj) + else: + params = obj.get_params(deep=False) + + params = OrderedDict((name, val) for (name, val) in sorted(params.items())) + + self._format_params( + params.items(), stream, indent, allowance + 1, context, level + ) + stream.write(")") + + def _format_dict_items(self, items, stream, indent, allowance, context, level): + return self._format_params_or_dict_items( + items, stream, indent, allowance, context, level, is_dict=True + ) + + def _format_params(self, items, stream, indent, allowance, context, level): + return self._format_params_or_dict_items( + items, stream, indent, allowance, context, level, is_dict=False + ) + + def _format_params_or_dict_items( + self, obj, stream, indent, allowance, context, level, is_dict + ): + """Format dict items or parameters respecting the compact=True parameter. + + For some reason, the builtin rendering of dict items doesn't + respect compact=True and will use one line per key-value if all cannot + fit in a single line. + Dict items will be rendered as <'key': value> while params will be + rendered as . The implementation is mostly copy/pasting from + the builtin _format_items(). + This also adds ellipsis if the number of items is greater than + self.n_max_elements_to_show. + """ + write = stream.write + indent += self._indent_per_level + delimnl = ",\n" + " " * indent + delim = "" + width = max_width = self._width - indent + 1 + it = iter(obj) + try: + next_ent = next(it) + except StopIteration: + return + last = False + n_items = 0 + while not last: + if n_items == self.n_max_elements_to_show: + write(", ...") + break + n_items += 1 + ent = next_ent + try: + next_ent = next(it) + except StopIteration: + last = True + max_width -= allowance + width -= allowance + if self._compact: + k, v = ent + krepr = self._repr(k, context, level) + vrepr = self._repr(v, context, level) + if not is_dict: + krepr = krepr.strip("'") + middle = ": " if is_dict else "=" + rep = krepr + middle + vrepr + w = len(rep) + 2 + if width < w: + width = max_width + if delim: + delim = delimnl + if width >= w: + width -= w + write(delim) + delim = ", " + write(rep) + continue + write(delim) + delim = delimnl + class_ = KeyValTuple if is_dict else KeyValTupleParam + self._format( + class_(ent), stream, indent, allowance if last else 1, context, level + ) + + def _format_items(self, items, stream, indent, allowance, context, level): + """Format the items of an iterable (list, tuple...). + + Same as the built-in _format_items, with support for ellipsis if the + number of elements is greater than self.n_max_elements_to_show. + """ + write = stream.write + indent += self._indent_per_level + if self._indent_per_level > 1: + write((self._indent_per_level - 1) * " ") + delimnl = ",\n" + " " * indent + delim = "" + width = max_width = self._width - indent + 1 + it = iter(items) + try: + next_ent = next(it) + except StopIteration: + return + last = False + n_items = 0 + while not last: + if n_items == self.n_max_elements_to_show: + write(", ...") + break + n_items += 1 + ent = next_ent + try: + next_ent = next(it) + except StopIteration: + last = True + max_width -= allowance + width -= allowance + if self._compact: + rep = self._repr(ent, context, level) + w = len(rep) + 2 + if width < w: + width = max_width + if delim: + delim = delimnl + if width >= w: + width -= w + write(delim) + delim = ", " + write(rep) + continue + write(delim) + delim = delimnl + self._format(ent, stream, indent, allowance if last else 1, context, level) + + def _pprint_key_val_tuple(self, obj, stream, indent, allowance, context, level): + """Pretty printing for key-value tuples from dict or parameters.""" + k, v = obj + rep = self._repr(k, context, level) + if isinstance(obj, KeyValTupleParam): + rep = rep.strip("'") + middle = "=" + else: + middle = ": " + stream.write(rep) + stream.write(middle) + self._format( + v, stream, indent + len(rep) + len(middle), allowance, context, level + ) + + # Follow what scikit-learn did here and copy _dispatch to prevent instances + # of the builtin PrettyPrinter class to call methods of + # _BaseObjectPrettyPrinter (see scikit-learn Github issue 12906) + # mypy error: "Type[PrettyPrinter]" has no attribute "_dispatch" + _dispatch = pprint.PrettyPrinter._dispatch.copy() # type: ignore + _dispatch[BaseObject.__repr__] = _pprint_object + _dispatch[KeyValTuple.__repr__] = _pprint_key_val_tuple + + +def _safe_repr(obj, context, maxlevels, level, changed_only=False): + """Safe string representation logic. + + Same as the builtin _safe_repr, with added support for BaseObjects. + """ + typ = type(obj) + + if typ in pprint._builtin_scalars: + return repr(obj), True, False + + r = getattr(typ, "__repr__", None) + if issubclass(typ, dict) and r is dict.__repr__: + if not obj: + return "{}", True, False + objid = id(obj) + if maxlevels and level >= maxlevels: + return "{...}", False, objid in context + if objid in context: + return pprint._recursion(obj), False, True + context[objid] = 1 + readable = True + recursive = False + components = [] + append = components.append + level += 1 + saferepr = _safe_repr + items = sorted(obj.items(), key=pprint._safe_tuple) + for k, v in items: + krepr, kreadable, krecur = saferepr( + k, context, maxlevels, level, changed_only=changed_only + ) + vrepr, vreadable, vrecur = saferepr( + v, context, maxlevels, level, changed_only=changed_only + ) + append("%s: %s" % (krepr, vrepr)) + readable = readable and kreadable and vreadable + if krecur or vrecur: + recursive = True + del context[objid] + return "{%s}" % ", ".join(components), readable, recursive + + if (issubclass(typ, list) and r is list.__repr__) or ( + issubclass(typ, tuple) and r is tuple.__repr__ + ): + if issubclass(typ, list): + if not obj: + return "[]", True, False + format_ = "[%s]" + elif len(obj) == 1: + format_ = "(%s,)" + else: + if not obj: + return "()", True, False + format_ = "(%s)" + objid = id(obj) + if maxlevels and level >= maxlevels: + return format_ % "...", False, objid in context + if objid in context: + return pprint._recursion(obj), False, True + context[objid] = 1 + readable = True + recursive = False + components = [] + append = components.append + level += 1 + for o in obj: + orepr, oreadable, orecur = _safe_repr( + o, context, maxlevels, level, changed_only=changed_only + ) + append(orepr) + if not oreadable: + readable = False + if orecur: + recursive = True + del context[objid] + return format_ % ", ".join(components), readable, recursive + + if issubclass(typ, BaseObject): + objid = id(obj) + if maxlevels and level >= maxlevels: + return "{...}", False, objid in context + if objid in context: + return pprint._recursion(obj), False, True + context[objid] = 1 + readable = True + recursive = False + if changed_only: + params = _changed_params(obj) + else: + params = obj.get_params(deep=False) + components = [] + append = components.append + level += 1 + saferepr = _safe_repr + items = sorted(params.items(), key=pprint._safe_tuple) + for k, v in items: + krepr, kreadable, krecur = saferepr( + k, context, maxlevels, level, changed_only=changed_only + ) + vrepr, vreadable, vrecur = saferepr( + v, context, maxlevels, level, changed_only=changed_only + ) + append("%s=%s" % (krepr.strip("'"), vrepr)) + readable = readable and kreadable and vreadable + if krecur or vrecur: + recursive = True + del context[objid] + return ("%s(%s)" % (typ.__name__, ", ".join(components)), readable, recursive) + + rep = repr(obj) + return rep, (rep and not rep.startswith("<")), False diff --git a/skbase/tests/conftest.py b/skbase/tests/conftest.py index 0b9bd15f..7a0508ef 100644 --- a/skbase/tests/conftest.py +++ b/skbase/tests/conftest.py @@ -22,6 +22,9 @@ "skbase.base", "skbase.base._base", "skbase.base._meta", + "skbase.base._pretty_printing", + "skbase.base._pretty_printing._object_html_repr", + "skbase.base._pretty_printing._pprint", "skbase.base._tagmanager", "skbase.lookup", "skbase.lookup.tests", @@ -42,6 +45,7 @@ "skbase.tests.test_baseestimator", "skbase.tests.mock_package.test_mock_package", "skbase.utils", + "skbase.utils._check", "skbase.utils._iter", "skbase.utils._nested_iter", "skbase.utils._utils", @@ -80,6 +84,7 @@ ), "skbase.base._base": ("BaseEstimator", "BaseObject"), "skbase.base._meta": ("BaseMetaObject", "BaseMetaEstimator"), + "skbase.base._pretty_printing._pprint": ("KeyValTuple", "KeyValTupleParam"), "skbase.lookup._lookup": ("ClassInfo", "FunctionInfo", "ModuleInfo"), "skbase.testing": ("BaseFixtureGenerator", "QuickTester", "TestAllObjects"), "skbase.testing.test_all_objects": ( @@ -96,6 +101,12 @@ "BaseMetaEstimator", "_MetaObjectMixin", ), + "skbase.base._pretty_printing._object_html_repr": ("_VisualBlock",), + "skbase.base._pretty_printing._pprint": ( + "KeyValTuple", + "KeyValTupleParam", + "_BaseObjectPrettyPrinter", + ), "skbase.base._tagmanager": ("_FlagManager",), } ) @@ -140,6 +151,13 @@ SKBASE_FUNCTIONS_BY_MODULE = SKBASE_PUBLIC_FUNCTIONS_BY_MODULE.copy() SKBASE_FUNCTIONS_BY_MODULE.update( { + "skbase.base._pretty_printing._object_html_repr": ( + "_get_visual_block", + "_object_html_repr", + "_write_base_object_html", + "_write_label_html", + ), + "skbase.base._pretty_printing._pprint": ("_changed_params", "_safe_repr"), "skbase.lookup._lookup": ( "_determine_module_path", "_get_return_tags", @@ -171,6 +189,7 @@ "_coerce_list", ), "skbase.testing.utils.inspect": ("_get_args",), + "skbase.utils._check": ("_is_scalar_nan",), "skbase.utils._iter": ( "_format_seq_to_str", "_remove_type_text", diff --git a/skbase/tests/test_base.py b/skbase/tests/test_base.py index 38d2d5a9..a5dc5341 100644 --- a/skbase/tests/test_base.py +++ b/skbase/tests/test_base.py @@ -69,11 +69,11 @@ import inspect from copy import deepcopy +from typing import Any, Dict, Type import numpy as np import pytest import scipy.sparse as sp -from sklearn import config_context # TODO: Update with import of skbase clone function once implemented from sklearn.base import clone @@ -200,13 +200,13 @@ def fixture_reset_tester(): @pytest.fixture -def fixture_class_child_tags(fixture_class_child): +def fixture_class_child_tags(fixture_class_child: Type[Child]): """Pytest fixture for tags of Child.""" return fixture_class_child.get_class_tags() @pytest.fixture -def fixture_object_instance_set_tags(fixture_tag_class_object): +def fixture_object_instance_set_tags(fixture_tag_class_object: Child): """Fixture class instance to test tag setting.""" fixture_tag_set = {"A": 42424243, "E": 3} return fixture_tag_class_object.set_tags(**fixture_tag_set) @@ -266,7 +266,9 @@ def fixture_class_instance_no_param_interface(): return NoParamInterface() -def test_get_class_tags(fixture_class_child, fixture_class_child_tags): +def test_get_class_tags( + fixture_class_child: Type[Child], fixture_class_child_tags: Any +): """Test get_class_tags class method of BaseObject for correctness. Raises @@ -280,7 +282,7 @@ def test_get_class_tags(fixture_class_child, fixture_class_child_tags): assert child_tags == fixture_class_child_tags, msg -def test_get_class_tag(fixture_class_child, fixture_class_child_tags): +def test_get_class_tag(fixture_class_child: Type[Child], fixture_class_child_tags: Any): """Test get_class_tag class method of BaseObject for correctness. Raises @@ -307,7 +309,7 @@ def test_get_class_tag(fixture_class_child, fixture_class_child_tags): assert child_tag_default_none is None, msg -def test_get_tags(fixture_tag_class_object, fixture_object_tags): +def test_get_tags(fixture_tag_class_object: Child, fixture_object_tags: Dict[str, Any]): """Test get_tags method of BaseObject for correctness. Raises @@ -321,7 +323,7 @@ def test_get_tags(fixture_tag_class_object, fixture_object_tags): assert object_tags == fixture_object_tags, msg -def test_get_tag(fixture_tag_class_object, fixture_object_tags): +def test_get_tag(fixture_tag_class_object: Child, fixture_object_tags: Dict[str, Any]): """Test get_tag method of BaseObject for correctness. Raises @@ -351,7 +353,7 @@ def test_get_tag(fixture_tag_class_object, fixture_object_tags): assert object_tag_default_none is None, msg -def test_get_tag_raises(fixture_tag_class_object): +def test_get_tag_raises(fixture_tag_class_object: Child): """Test that get_tag method raises error for unknown tag. Raises @@ -363,9 +365,9 @@ def test_get_tag_raises(fixture_tag_class_object): def test_set_tags( - fixture_object_instance_set_tags, - fixture_object_set_tags, - fixture_object_dynamic_tags, + fixture_object_instance_set_tags: Any, + fixture_object_set_tags: Dict[str, Any], + fixture_object_dynamic_tags: Dict[str, int], ): """Test set_tags method of BaseObject for correctness. @@ -381,7 +383,9 @@ def test_set_tags( assert fixture_object_instance_set_tags.get_tags() == fixture_object_set_tags, msg -def test_set_tags_works_with_missing_tags_dynamic_attribute(fixture_tag_class_object): +def test_set_tags_works_with_missing_tags_dynamic_attribute( + fixture_tag_class_object: Child, +): """Test set_tags will still work if _tags_dynamic is missing.""" base_obj = deepcopy(fixture_tag_class_object) delattr(base_obj, "_tags_dynamic") @@ -460,7 +464,7 @@ class AnotherTestClass(BaseObject): assert test_obj_tags.get(tag) == another_base_obj_tags[tag] -def test_is_composite(fixture_composition_dummy): +def test_is_composite(fixture_composition_dummy: Type[CompositionDummy]): """Test is_composite tag for correctness. Raises @@ -474,7 +478,11 @@ def test_is_composite(fixture_composition_dummy): assert composite.is_composite() -def test_components(fixture_object, fixture_class_parent, fixture_composition_dummy): +def test_components( + fixture_object: Type[BaseObject], + fixture_class_parent: Type[Parent], + fixture_composition_dummy: Type[CompositionDummy], +): """Test component retrieval. Raises @@ -507,7 +515,7 @@ def test_components(fixture_object, fixture_class_parent, fixture_composition_du def test_components_raises_error_base_class_is_not_class( - fixture_object, fixture_composition_dummy + fixture_object: Type[BaseObject], fixture_composition_dummy: Type[CompositionDummy] ): """Test _component method raises error if base_class param is not class.""" non_composite = fixture_composition_dummy(foo=42) @@ -526,7 +534,7 @@ def test_components_raises_error_base_class_is_not_class( def test_components_raises_error_base_class_is_not_baseobject_subclass( - fixture_composition_dummy, + fixture_composition_dummy: Type[CompositionDummy], ): """Test _component method raises error if base_class is not BaseObject subclass.""" @@ -540,7 +548,7 @@ class SomeClass: # Test parameter interface (get_params, set_params, reset and related methods) # Some tests of get_params and set_params are adapted from sklearn tests -def test_reset(fixture_reset_tester): +def test_reset(fixture_reset_tester: Type[ResetTester]): """Test reset method for correct behaviour, on a simple estimator. Raises @@ -567,7 +575,7 @@ def test_reset(fixture_reset_tester): assert hasattr(x, "foo") -def test_reset_composite(fixture_reset_tester): +def test_reset_composite(fixture_reset_tester: Type[ResetTester]): """Test reset method for correct behaviour, on a composite estimator.""" y = fixture_reset_tester(42) x = fixture_reset_tester(a=y) @@ -582,7 +590,7 @@ def test_reset_composite(fixture_reset_tester): assert not hasattr(x.a, "d") -def test_get_init_signature(fixture_class_parent): +def test_get_init_signature(fixture_class_parent: Type[Parent]): """Test error is raised when invalid init signature is used.""" init_sig = fixture_class_parent._get_init_signature() init_sig_is_list = isinstance(init_sig, list) @@ -594,14 +602,18 @@ def test_get_init_signature(fixture_class_parent): ), "`_get_init_signature` is not returning expected result." -def test_get_init_signature_raises_error_for_invalid_signature(fixture_invalid_init): +def test_get_init_signature_raises_error_for_invalid_signature( + fixture_invalid_init: Type[InvalidInitSignatureTester], +): """Test error is raised when invalid init signature is used.""" with pytest.raises(RuntimeError): fixture_invalid_init._get_init_signature() def test_get_param_names( - fixture_object, fixture_class_parent, fixture_class_parent_expected_params + fixture_object: Type[BaseObject], + fixture_class_parent: Type[Parent], + fixture_class_parent_expected_params: Dict[str, Any], ): """Test that get_param_names returns list of string parameter names.""" param_names = fixture_class_parent.get_param_names() @@ -612,10 +624,10 @@ def test_get_param_names( def test_get_params( - fixture_class_parent, - fixture_class_parent_expected_params, - fixture_class_instance_no_param_interface, - fixture_composition_dummy, + fixture_class_parent: Type[Parent], + fixture_class_parent_expected_params: Dict[str, Any], + fixture_class_instance_no_param_interface: NoParamInterface, + fixture_composition_dummy: Type[CompositionDummy], ): """Test get_params returns expected parameters.""" # Simple test of returned params @@ -638,7 +650,10 @@ def test_get_params( assert "foo" in params and "bar" in params and len(params) == 2 -def test_get_params_invariance(fixture_class_parent, fixture_composition_dummy): +def test_get_params_invariance( + fixture_class_parent: Type[Parent], + fixture_composition_dummy: Type[CompositionDummy], +): """Test that get_params(deep=False) is subset of get_params(deep=True).""" composite = fixture_composition_dummy(foo=fixture_class_parent(), bar=84) shallow_params = composite.get_params(deep=False) @@ -646,7 +661,7 @@ def test_get_params_invariance(fixture_class_parent, fixture_composition_dummy): assert all(item in deep_params.items() for item in shallow_params.items()) -def test_get_params_after_set_params(fixture_class_parent): +def test_get_params_after_set_params(fixture_class_parent: Type[Parent]): """Test that get_params returns the same thing before and after set_params. Based on scikit-learn check in check_estimator. @@ -687,9 +702,9 @@ def test_get_params_after_set_params(fixture_class_parent): def test_set_params( - fixture_class_parent, - fixture_class_parent_expected_params, - fixture_composition_dummy, + fixture_class_parent: Type[Parent], + fixture_class_parent_expected_params: Dict[str, Any], + fixture_composition_dummy: Type[CompositionDummy], ): """Test set_params works as expected.""" # Simple case of setting a parameter @@ -711,7 +726,8 @@ def test_set_params( def test_set_params_raises_error_non_existent_param( - fixture_class_parent_instance, fixture_composition_dummy + fixture_class_parent_instance: Parent, + fixture_composition_dummy: Type[CompositionDummy], ): """Test set_params raises an error when passed a non-existent parameter name.""" # non-existing parameter in svc @@ -727,7 +743,8 @@ def test_set_params_raises_error_non_existent_param( def test_set_params_raises_error_non_interface_composite( - fixture_class_instance_no_param_interface, fixture_composition_dummy + fixture_class_instance_no_param_interface: NoParamInterface, + fixture_composition_dummy: Type[CompositionDummy], ): """Test set_params raises error when setting param of non-conforming composite.""" # When a composite is made up of a class that doesn't have the BaseObject @@ -753,7 +770,9 @@ def __init__(self, param=5): est.get_params() -def test_set_params_with_no_param_to_set_returns_object(fixture_class_parent): +def test_set_params_with_no_param_to_set_returns_object( + fixture_class_parent: Type[Parent], +): """Test set_params correctly returns self when no parameters are set.""" base_obj = fixture_class_parent() orig_params = deepcopy(base_obj.get_params()) @@ -767,7 +786,7 @@ def test_set_params_with_no_param_to_set_returns_object(fixture_class_parent): # This section tests the clone functionality # These have been adapted from sklearn's tests of clone to use the clone # method that is included as part of the BaseObject interface -def test_clone(fixture_class_parent_instance): +def test_clone(fixture_class_parent_instance: Parent): """Test that clone is making a deep copy as expected.""" # Creates a BaseObject and makes a copy of its original state # (which, in this case, is the current state of the BaseObject), @@ -777,7 +796,7 @@ def test_clone(fixture_class_parent_instance): assert fixture_class_parent_instance.get_params() == new_base_obj.get_params() -def test_clone_2(fixture_class_parent_instance): +def test_clone_2(fixture_class_parent_instance: Parent): """Test that clone does not copy attributes not set in constructor.""" # We first create an estimator, give it an own attribute, and # make a copy of its original state. Then we check that the copy doesn't @@ -790,7 +809,9 @@ def test_clone_2(fixture_class_parent_instance): def test_clone_raises_error_for_nonconforming_objects( - fixture_invalid_init, fixture_buggy, fixture_modify_param + fixture_invalid_init: Type[InvalidInitSignatureTester], + fixture_buggy: Type[Buggy], + fixture_modify_param: Type[ModifyParam], ): """Test that clone raises an error on nonconforming BaseObjects.""" buggy = fixture_buggy() @@ -807,7 +828,7 @@ def test_clone_raises_error_for_nonconforming_objects( obj_that_modifies.clone() -def test_clone_param_is_none(fixture_class_parent): +def test_clone_param_is_none(fixture_class_parent: Type[Parent]): """Test clone with keyword parameter set to None.""" base_obj = fixture_class_parent(c=None) new_base_obj = clone(base_obj) @@ -816,7 +837,7 @@ def test_clone_param_is_none(fixture_class_parent): assert base_obj.c is new_base_obj2.c -def test_clone_empty_array(fixture_class_parent): +def test_clone_empty_array(fixture_class_parent: Type[Parent]): """Test clone with keyword parameter is scipy sparse matrix. This test is based on scikit-learn regression test to make sure clone @@ -830,7 +851,7 @@ def test_clone_empty_array(fixture_class_parent): np.testing.assert_array_equal(base_obj.c, new_base_obj2.c) -def test_clone_sparse_matrix(fixture_class_parent): +def test_clone_sparse_matrix(fixture_class_parent: Type[Parent]): """Test clone with keyword parameter is scipy sparse matrix. This test is based on scikit-learn regression test to make sure clone @@ -843,7 +864,7 @@ def test_clone_sparse_matrix(fixture_class_parent): np.testing.assert_array_equal(base_obj.c, new_base_obj2.c) -def test_clone_nan(fixture_class_parent): +def test_clone_nan(fixture_class_parent: Type[Parent]): """Test clone with keyword parameter is np.nan. This test is based on scikit-learn regression test to make sure clone @@ -858,7 +879,7 @@ def test_clone_nan(fixture_class_parent): assert base_obj.c is new_base_obj2.c -def test_clone_estimator_types(fixture_class_parent): +def test_clone_estimator_types(fixture_class_parent: Type[Parent]): """Test clone works for parameters that are types rather than instances.""" base_obj = fixture_class_parent(c=fixture_class_parent) new_base_obj = base_obj.clone() @@ -866,7 +887,9 @@ def test_clone_estimator_types(fixture_class_parent): assert base_obj.c == new_base_obj.c -def test_clone_class_rather_than_instance_raises_error(fixture_class_parent): +def test_clone_class_rather_than_instance_raises_error( + fixture_class_parent: Type[Parent], +): """Test clone raises expected error when cloning a class instead of an instance.""" msg = "You should provide an instance of scikit-learn estimator" with pytest.raises(TypeError, match=msg): @@ -874,17 +897,40 @@ def test_clone_class_rather_than_instance_raises_error(fixture_class_parent): # Tests of BaseObject pretty printing representation inspired by sklearn -def test_baseobject_repr(fixture_class_parent, fixture_composition_dummy): +def test_baseobject_repr( + fixture_class_parent: Type[Parent], + fixture_composition_dummy: Type[CompositionDummy], +): """Test BaseObject repr works as expected.""" # Simple test where all parameters are left at defaults # Should not see parameters and values in printed representation + base_obj = fixture_class_parent() assert repr(base_obj) == "Parent()" - # Check that we can alter the detail about params that is printed - # using config_context with ``print_changed_only=False`` - with config_context(print_changed_only=False): - assert repr(base_obj) == "Parent(a='something', b=7, c=None)" + # Check that local config works as expected + base_obj.set_config(print_changed_only=False) + assert repr(base_obj) == "Parent(a='something', b=7, c=None)" + + # Test with dict parameter (note that dict is sorted by keys when printed) + # not printed in order it was created + base_obj = fixture_class_parent(c={"c": 1, "a": 2}) + assert repr(base_obj) == "Parent(c={'a': 2, 'c': 1})" + + # Now test when one params values are named object tuples + named_objs = [ + ("step 1", fixture_class_parent()), + ("step 2", fixture_class_parent()), + ] + base_obj = fixture_class_parent(c=named_objs) + assert repr(base_obj) == "Parent(c=[('step 1', Parent()), ('step 2', Parent())])" + + # Or when they are just lists of tuples or just tuples as param + base_obj = fixture_class_parent(c=[("one", 1), ("two", 2)]) + assert repr(base_obj) == "Parent(c=[('one', 1), ('two', 2)])" + + base_obj = fixture_class_parent(c=(1, 2, 3)) + assert repr(base_obj) == "Parent(c=(1, 2, 3))" simple_composite = fixture_composition_dummy(foo=fixture_class_parent()) assert repr(simple_composite) == "CompositionDummy(foo=Parent())" @@ -892,53 +938,67 @@ def test_baseobject_repr(fixture_class_parent, fixture_composition_dummy): long_base_obj_repr = fixture_class_parent(a=["long_params"] * 1000) assert len(repr(long_base_obj_repr)) == 535 + named_objs = [(f"Step {i+1}", Child()) for i in range(25)] + base_comp = CompositionDummy(foo=Parent(c=Child(c=named_objs))) + assert len(repr(base_comp)) == 1362 + -def test_baseobject_str(fixture_class_parent_instance): +def test_baseobject_str(fixture_class_parent_instance: Parent): """Test BaseObject string representation works.""" - str(fixture_class_parent_instance) + assert ( + str(fixture_class_parent_instance) == "Parent()" + ), "String representation of instance not working." + + # Check that local config works as expected + fixture_class_parent_instance.set_config(print_changed_only=False) + assert str(fixture_class_parent_instance) == "Parent(a='something', b=7, c=None)" -def test_baseobject_repr_mimebundle_(fixture_class_parent_instance): +def test_baseobject_repr_mimebundle_(fixture_class_parent_instance: Parent): """Test display configuration controls output.""" # Checks the display configuration flag controls the json output - with config_context(display="diagram"): - output = fixture_class_parent_instance._repr_mimebundle_() - assert "text/plain" in output - assert "text/html" in output + fixture_class_parent_instance.set_config(display="diagram") + output = fixture_class_parent_instance._repr_mimebundle_() + assert "text/plain" in output + assert "text/html" in output - with config_context(display="text"): - output = fixture_class_parent_instance._repr_mimebundle_() - assert "text/plain" in output - assert "text/html" not in output + fixture_class_parent_instance.set_config(display="text") + output = fixture_class_parent_instance._repr_mimebundle_() + assert "text/plain" in output + assert "text/html" not in output -def test_repr_html_wraps(fixture_class_parent_instance): +def test_repr_html_wraps(fixture_class_parent_instance: Parent): """Test display configuration flag controls the html output.""" - with config_context(display="diagram"): - output = fixture_class_parent_instance._repr_html_() - assert "