Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion skbase/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class name: BaseEstimator
from skbase.base._clone_base import _check_clone, _clone
from skbase.base._pretty_printing._object_html_repr import _object_html_repr
from skbase.base._tagmanager import _FlagManager
from skbase.config import get_config as get_global_config

__author__: List[str] = ["fkiraly", "mloning", "RNKuhns", "tpvasconcelos"]
__all__: List[str] = ["BaseEstimator", "BaseObject"]
Expand Down Expand Up @@ -751,7 +752,23 @@ def get_config(self):
class attribute via nested inheritance and then any overrides
and new tags from _onfig_dynamic object attribute.
"""
return self._get_flags(flag_attr_name="_config")
config = self._get_class_flags(flag_attr_name="_config")

# Update with global config
global_config = get_global_config()
config.update(global_config)

# Update with extension config if available
if hasattr(self, "__skbase_get_config__"):
extension_config = self.__skbase_get_config__()
if isinstance(extension_config, dict):
config.update(extension_config)

# Update with local config overrides (highest priority)
if hasattr(self, "_config_dynamic"):
config.update(self._config_dynamic)

return config

def set_config(self, **config_dict):
"""Set config flags to given values.
Expand Down
18 changes: 18 additions & 0 deletions skbase/config/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-
"""Global configuration management for skbase."""

from skbase.config._config import (
config_context,
get_config,
get_default_config,
reset_config,
set_config,
)

__all__ = [
"config_context",
"get_config",
"get_default_config",
"reset_config",
"set_config",
]
88 changes: 88 additions & 0 deletions skbase/config/_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# -*- coding: utf-8 -*-
"""Global configuration management for skbase."""

from contextlib import contextmanager
from copy import deepcopy

__author__ = ["RNKuhns"]
__all__ = [
"config_context",
"get_config",
"get_default_config",
"reset_config",
"set_config",
]

# Global configuration defaults
_DEFAULT_CONFIG = {
"display": "diagram",
"print_changed_only": True,
"check_clone": False,
"clone_config": True,
}

# Global config storage
_global_config = deepcopy(_DEFAULT_CONFIG)


def _get_global_config():
"""Get the global config dict."""
return _global_config


def get_default_config():
"""Get the default global config.

Returns
-------
dict
Default global config.
"""
return deepcopy(_DEFAULT_CONFIG)


def get_config():
"""Get current global config.

Returns
-------
dict
Current global config.
"""
return deepcopy(_get_global_config())


def set_config(**config_dict):
"""Set global config values.

Parameters
----------
**config_dict : dict
Config key-value pairs to set globally.
"""
global_config = _get_global_config()
global_config.update(config_dict)


def reset_config():
"""Reset global config to defaults."""
global _global_config
_global_config = deepcopy(_DEFAULT_CONFIG)


@contextmanager
def config_context(**config_dict):
"""Context manager for temporary config changes.

Parameters
----------
**config_dict : dict
Config key-value pairs to set temporarily.
"""
old_config = get_config()
set_config(**config_dict)
try:
yield
finally:
global _global_config
_global_config = old_config
2 changes: 2 additions & 0 deletions skbase/config/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# -*- coding: utf-8 -*-
"""Tests for skbase global config functionality."""
77 changes: 77 additions & 0 deletions skbase/config/tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# -*- coding: utf-8 -*-
"""Tests for global config functionality."""

from skbase.config import (
config_context,
get_config,
get_default_config,
reset_config,
set_config,
)


def test_get_default_config():
"""Test get_default_config returns correct defaults."""
defaults = get_default_config()
expected = {
"display": "diagram",
"print_changed_only": True,
"check_clone": False,
"clone_config": True,
}
assert defaults == expected


def test_get_config():
"""Test get_config returns current global config."""
reset_config()
config = get_config()
assert config == get_default_config()


def test_set_config():
"""Test set_config updates global config."""
reset_config()
set_config(display="text", check_clone=True)
config = get_config()
assert config["display"] == "text"
assert config["check_clone"] is True
assert config["print_changed_only"] is True # unchanged


def test_reset_config():
"""Test reset_config resets to defaults."""
set_config(display="text")
reset_config()
config = get_config()
assert config == get_default_config()


def test_config_context():
"""Test config_context temporarily changes config."""
reset_config()
original_config = get_config()

with config_context(display="text", check_clone=True):
inner_config = get_config()
assert inner_config["display"] == "text"
assert inner_config["check_clone"] is True

# Should be back to original
final_config = get_config()
assert final_config == original_config


def test_config_context_nested():
"""Test nested config_context."""
reset_config()

with config_context(display="text"):
assert get_config()["display"] == "text"

with config_context(print_changed_only=False):
assert get_config()["display"] == "text"
assert get_config()["print_changed_only"] is False

assert get_config()["display"] == "text"
assert get_config()["print_changed_only"] is True # back to default
29 changes: 29 additions & 0 deletions skbase/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@
"skbase.base._pretty_printing._object_html_repr",
"skbase.base._pretty_printing._pprint",
"skbase.base._tagmanager",
"skbase.config",
"skbase.config._config",
"skbase.config.tests",
"skbase.config.tests.test_config",
"skbase.lookup",
"skbase.lookup.tests",
"skbase.lookup.tests.test_lookup",
Expand Down Expand Up @@ -72,6 +76,9 @@
SKBASE_PUBLIC_MODULES = (
"skbase",
"skbase.base",
"skbase.config",
"skbase.config.tests",
"skbase.config.tests.test_config",
"skbase.lookup",
"skbase.lookup.tests",
"skbase.lookup.tests.test_lookup",
Expand Down Expand Up @@ -157,6 +164,20 @@
}
)
SKBASE_PUBLIC_FUNCTIONS_BY_MODULE = {
"skbase.config": (
"config_context",
"get_config",
"get_default_config",
"reset_config",
"set_config",
),
"skbase.config._config": (
"config_context",
"get_config",
"get_default_config",
"reset_config",
"set_config",
),
"skbase.lookup": ("all_objects", "get_package_metadata"),
"skbase.lookup._lookup": ("all_objects", "get_package_metadata"),
"skbase.testing.utils._conditional_fixtures": (
Expand Down Expand Up @@ -223,6 +244,14 @@
"_write_label_html",
),
"skbase.base._pretty_printing._pprint": ("_changed_params", "_safe_repr"),
"skbase.config._config": (
"_get_global_config",
"config_context",
"get_config",
"get_default_config",
"reset_config",
"set_config",
),
"skbase.lookup._lookup": (
"all_objects",
"get_package_metadata",
Expand Down
38 changes: 36 additions & 2 deletions skbase/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@
import scipy.sparse as sp

from skbase.base import BaseEstimator, BaseObject
from skbase.config import get_default_config
from skbase.config import reset_config as reset_global_config
from skbase.tests.conftest import Child, Parent
from skbase.tests.mock_package.test_mock_package import CompositionDummy
from skbase.utils.dependencies import _check_soft_dependencies
Expand Down Expand Up @@ -1380,6 +1382,7 @@ def __init__(self, foo, bar=84):

def test_get_set_config():
"""Tests get_config and set_config methods."""
reset_global_config() # Reset global config to defaults

class _TestConfig(BaseObject):
_config = {"foo_config": 42, "bar": "a"}
Expand All @@ -1394,13 +1397,15 @@ def __init__(self, a, b=42):
test_obj = _TestConfig(7)

expected_config_orig = BaseObject._config.copy()
expected_config_orig.update({"foo_config": 42, "bar": "a"})
expected_config_orig.update(get_default_config()) # global defaults
expected_config_orig.update({"foo_config": 42, "bar": "a"}) # class

# Test get_config
assert test_obj.get_config() == expected_config_orig

expected_config = BaseObject._config.copy()
expected_config.update({"foo_config": 37, "bar": "a"})
expected_config.update(get_default_config()) # global
expected_config.update({"foo_config": 37, "bar": "a"}) # local override

# Test set_config
test_obj.set_config(foo_config=37)
Expand All @@ -1412,6 +1417,35 @@ def __init__(self, a, b=42):
assert test_obj.get_config() == expected_config


def test_global_config_integration():
"""Test that global config is integrated into BaseObject.get_config."""
from skbase.config import reset_config as reset_global_config
from skbase.config import set_config as set_global_config

reset_global_config()

class _TestGlobalConfig(BaseObject):
_config = {"local_config": "class_value"}

test_obj = _TestGlobalConfig()

# Initially, should have class + global defaults
config = test_obj.get_config()
assert config["local_config"] == "class_value"
assert config["display"] == "diagram" # global default

# Set global config
set_global_config(display="text")

config = test_obj.get_config()
assert config["display"] == "text" # global override

# Local override should take precedence
test_obj.set_config(display="diagram")
config = test_obj.get_config()
assert config["display"] == "diagram" # local override


def test_clone_with_custom_plugins():
"""Test that cloning works when custom clone_plugins are provided.

Expand Down