diff --git a/skbase/base/_base.py b/skbase/base/_base.py index c6a0d16a..9ac9840b 100644 --- a/skbase/base/_base.py +++ b/skbase/base/_base.py @@ -64,6 +64,7 @@ class name: BaseEstimator from skbase.base._clone_base import _check_clone, _clone from skbase.base._pretty_printing._object_html_repr import _object_html_repr from skbase.base._tagmanager import _FlagManager +from skbase.config import get_config as get_global_config __author__: List[str] = ["fkiraly", "mloning", "RNKuhns", "tpvasconcelos"] __all__: List[str] = ["BaseEstimator", "BaseObject"] @@ -751,7 +752,23 @@ def get_config(self): class attribute via nested inheritance and then any overrides and new tags from _onfig_dynamic object attribute. """ - return self._get_flags(flag_attr_name="_config") + config = self._get_class_flags(flag_attr_name="_config") + + # Update with global config + global_config = get_global_config() + config.update(global_config) + + # Update with extension config if available + if hasattr(self, "__skbase_get_config__"): + extension_config = self.__skbase_get_config__() + if isinstance(extension_config, dict): + config.update(extension_config) + + # Update with local config overrides (highest priority) + if hasattr(self, "_config_dynamic"): + config.update(self._config_dynamic) + + return config def set_config(self, **config_dict): """Set config flags to given values. diff --git a/skbase/config/__init__.py b/skbase/config/__init__.py new file mode 100644 index 00000000..10d397d5 --- /dev/null +++ b/skbase/config/__init__.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +"""Global configuration management for skbase.""" + +from skbase.config._config import ( + config_context, + get_config, + get_default_config, + reset_config, + set_config, +) + +__all__ = [ + "config_context", + "get_config", + "get_default_config", + "reset_config", + "set_config", +] diff --git a/skbase/config/_config.py b/skbase/config/_config.py new file mode 100644 index 00000000..c4a7bb87 --- /dev/null +++ b/skbase/config/_config.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +"""Global configuration management for skbase.""" + +from contextlib import contextmanager +from copy import deepcopy + +__author__ = ["RNKuhns"] +__all__ = [ + "config_context", + "get_config", + "get_default_config", + "reset_config", + "set_config", +] + +# Global configuration defaults +_DEFAULT_CONFIG = { + "display": "diagram", + "print_changed_only": True, + "check_clone": False, + "clone_config": True, +} + +# Global config storage +_global_config = deepcopy(_DEFAULT_CONFIG) + + +def _get_global_config(): + """Get the global config dict.""" + return _global_config + + +def get_default_config(): + """Get the default global config. + + Returns + ------- + dict + Default global config. + """ + return deepcopy(_DEFAULT_CONFIG) + + +def get_config(): + """Get current global config. + + Returns + ------- + dict + Current global config. + """ + return deepcopy(_get_global_config()) + + +def set_config(**config_dict): + """Set global config values. + + Parameters + ---------- + **config_dict : dict + Config key-value pairs to set globally. + """ + global_config = _get_global_config() + global_config.update(config_dict) + + +def reset_config(): + """Reset global config to defaults.""" + global _global_config + _global_config = deepcopy(_DEFAULT_CONFIG) + + +@contextmanager +def config_context(**config_dict): + """Context manager for temporary config changes. + + Parameters + ---------- + **config_dict : dict + Config key-value pairs to set temporarily. + """ + old_config = get_config() + set_config(**config_dict) + try: + yield + finally: + global _global_config + _global_config = old_config diff --git a/skbase/config/tests/__init__.py b/skbase/config/tests/__init__.py new file mode 100644 index 00000000..ea4f9ee6 --- /dev/null +++ b/skbase/config/tests/__init__.py @@ -0,0 +1,2 @@ +# -*- coding: utf-8 -*- +"""Tests for skbase global config functionality.""" diff --git a/skbase/config/tests/test_config.py b/skbase/config/tests/test_config.py new file mode 100644 index 00000000..06e8f0ca --- /dev/null +++ b/skbase/config/tests/test_config.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- +"""Tests for global config functionality.""" + +from skbase.config import ( + config_context, + get_config, + get_default_config, + reset_config, + set_config, +) + + +def test_get_default_config(): + """Test get_default_config returns correct defaults.""" + defaults = get_default_config() + expected = { + "display": "diagram", + "print_changed_only": True, + "check_clone": False, + "clone_config": True, + } + assert defaults == expected + + +def test_get_config(): + """Test get_config returns current global config.""" + reset_config() + config = get_config() + assert config == get_default_config() + + +def test_set_config(): + """Test set_config updates global config.""" + reset_config() + set_config(display="text", check_clone=True) + config = get_config() + assert config["display"] == "text" + assert config["check_clone"] is True + assert config["print_changed_only"] is True # unchanged + + +def test_reset_config(): + """Test reset_config resets to defaults.""" + set_config(display="text") + reset_config() + config = get_config() + assert config == get_default_config() + + +def test_config_context(): + """Test config_context temporarily changes config.""" + reset_config() + original_config = get_config() + + with config_context(display="text", check_clone=True): + inner_config = get_config() + assert inner_config["display"] == "text" + assert inner_config["check_clone"] is True + + # Should be back to original + final_config = get_config() + assert final_config == original_config + + +def test_config_context_nested(): + """Test nested config_context.""" + reset_config() + + with config_context(display="text"): + assert get_config()["display"] == "text" + + with config_context(print_changed_only=False): + assert get_config()["display"] == "text" + assert get_config()["print_changed_only"] is False + + assert get_config()["display"] == "text" + assert get_config()["print_changed_only"] is True # back to default diff --git a/skbase/tests/conftest.py b/skbase/tests/conftest.py index dec1b79b..56b3deef 100644 --- a/skbase/tests/conftest.py +++ b/skbase/tests/conftest.py @@ -33,6 +33,10 @@ "skbase.base._pretty_printing._object_html_repr", "skbase.base._pretty_printing._pprint", "skbase.base._tagmanager", + "skbase.config", + "skbase.config._config", + "skbase.config.tests", + "skbase.config.tests.test_config", "skbase.lookup", "skbase.lookup.tests", "skbase.lookup.tests.test_lookup", @@ -72,6 +76,9 @@ SKBASE_PUBLIC_MODULES = ( "skbase", "skbase.base", + "skbase.config", + "skbase.config.tests", + "skbase.config.tests.test_config", "skbase.lookup", "skbase.lookup.tests", "skbase.lookup.tests.test_lookup", @@ -157,6 +164,20 @@ } ) SKBASE_PUBLIC_FUNCTIONS_BY_MODULE = { + "skbase.config": ( + "config_context", + "get_config", + "get_default_config", + "reset_config", + "set_config", + ), + "skbase.config._config": ( + "config_context", + "get_config", + "get_default_config", + "reset_config", + "set_config", + ), "skbase.lookup": ("all_objects", "get_package_metadata"), "skbase.lookup._lookup": ("all_objects", "get_package_metadata"), "skbase.testing.utils._conditional_fixtures": ( @@ -223,6 +244,14 @@ "_write_label_html", ), "skbase.base._pretty_printing._pprint": ("_changed_params", "_safe_repr"), + "skbase.config._config": ( + "_get_global_config", + "config_context", + "get_config", + "get_default_config", + "reset_config", + "set_config", + ), "skbase.lookup._lookup": ( "all_objects", "get_package_metadata", diff --git a/skbase/tests/test_base.py b/skbase/tests/test_base.py index 23b4582c..d6b37d05 100644 --- a/skbase/tests/test_base.py +++ b/skbase/tests/test_base.py @@ -76,6 +76,8 @@ import scipy.sparse as sp from skbase.base import BaseEstimator, BaseObject +from skbase.config import get_default_config +from skbase.config import reset_config as reset_global_config from skbase.tests.conftest import Child, Parent from skbase.tests.mock_package.test_mock_package import CompositionDummy from skbase.utils.dependencies import _check_soft_dependencies @@ -1380,6 +1382,7 @@ def __init__(self, foo, bar=84): def test_get_set_config(): """Tests get_config and set_config methods.""" + reset_global_config() # Reset global config to defaults class _TestConfig(BaseObject): _config = {"foo_config": 42, "bar": "a"} @@ -1394,13 +1397,15 @@ def __init__(self, a, b=42): test_obj = _TestConfig(7) expected_config_orig = BaseObject._config.copy() - expected_config_orig.update({"foo_config": 42, "bar": "a"}) + expected_config_orig.update(get_default_config()) # global defaults + expected_config_orig.update({"foo_config": 42, "bar": "a"}) # class # Test get_config assert test_obj.get_config() == expected_config_orig expected_config = BaseObject._config.copy() - expected_config.update({"foo_config": 37, "bar": "a"}) + expected_config.update(get_default_config()) # global + expected_config.update({"foo_config": 37, "bar": "a"}) # local override # Test set_config test_obj.set_config(foo_config=37) @@ -1412,6 +1417,35 @@ def __init__(self, a, b=42): assert test_obj.get_config() == expected_config +def test_global_config_integration(): + """Test that global config is integrated into BaseObject.get_config.""" + from skbase.config import reset_config as reset_global_config + from skbase.config import set_config as set_global_config + + reset_global_config() + + class _TestGlobalConfig(BaseObject): + _config = {"local_config": "class_value"} + + test_obj = _TestGlobalConfig() + + # Initially, should have class + global defaults + config = test_obj.get_config() + assert config["local_config"] == "class_value" + assert config["display"] == "diagram" # global default + + # Set global config + set_global_config(display="text") + + config = test_obj.get_config() + assert config["display"] == "text" # global override + + # Local override should take precedence + test_obj.set_config(display="diagram") + config = test_obj.get_config() + assert config["display"] == "diagram" # local override + + def test_clone_with_custom_plugins(): """Test that cloning works when custom clone_plugins are provided.