Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 102 additions & 15 deletions skbase/base/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,8 @@ def _get_params(self, attr, deep=True, fitted=False):
def _set_params(self, attr: str, **params):
"""Logic for setting parameters on meta objects/estimators.

Separates out logic for parameter setting on meta objects from public API point.
Optimized for performance, memory, maintainability, and robustness.
Uses single-pass processing with lazy evaluation and minimal allocations.

Parameters
----------
Expand All @@ -233,22 +234,108 @@ def _set_params(self, attr: str, **params):
Self
Instance of self.
"""
# Ensure strict ordering of parameter setting:
# 1. All steps
if attr in params:
setattr(self, attr, params.pop(attr))
# 2. Step replacement
items = getattr(self, attr)
names = []
if items and isinstance(items, (list, tuple)):
names = list(zip(*items))[0]
for name in list(params.keys()):
if "__" not in name and name in names:
self._replace_object(attr, name, params.pop(name))
# 3. Step parameters and other initialisation arguments
super().set_params(**params) # type: ignore
if not params:
return self

items = getattr(self, attr, None)
current_names = None
reset_params = {}
deferred_ops = []

for param_name, param_value in params.items():
if param_name == attr:
deferred_ops.append(
(1, param_name, param_value)
) # op_type=1: container
elif "__" in param_name:
deferred_ops.append((3, param_name, param_value)) # op_type=3: nested
else:
if current_names is None and items and isinstance(items, (list, tuple)):
current_names = [
(
item[0]
if isinstance(item, tuple) and len(item) >= 1
else str(item)
)
for item in items
]

if current_names and param_name in current_names:
deferred_ops.append(
(2, param_name, param_value)
) # op_type=2: replacement
else:
reset_params[param_name] = param_value

if reset_params:
super().set_params(**reset_params) # type: ignore

if deferred_ops:
self._execute_deferred_ops_optimized(attr, deferred_ops)

return self

def _execute_deferred_ops_optimized(self, attr: str, ops):
"""Execute deferred operations in optimal order with minimal overhead.

Parameters
----------
attr : str
Named object attribute name
ops : list
List of (op_type, name, value) tuples where:
- op_type=1: container operations
- op_type=2: replacement operations
- op_type=3: nested operations
"""
ops.sort(key=lambda x: x[0])

cached_names = None
nested_ops = {}

for op_type, name, value in ops:
if op_type == 1:
setattr(self, name, value)
cached_names = None

elif op_type == 2: # Replacement operations
if cached_names is None:
items = getattr(self, attr, None)
if items and isinstance(items, (list, tuple)):
cached_names = [
(
item[0]
if isinstance(item, tuple) and len(item) >= 1
else str(item)
)
for item in items
]
else:
cached_names = []

if name in cached_names:
self._replace_object(attr, name, value)

elif op_type == 3: # Nested operations
component_name, _, sub_key = name.partition("__")
if sub_key: # Valid nested parameter
if component_name not in nested_ops:
nested_ops[component_name] = {}
nested_ops[component_name][sub_key] = value

if nested_ops:
items = getattr(self, attr, None)
if items and isinstance(items, (list, tuple)):
component_lookup = {}
for _i, item in enumerate(items):
if isinstance(item, tuple) and len(item) >= 2:
component_lookup[item[0]] = item[1]

for component_name, component_params in nested_ops.items():
component = component_lookup.get(component_name)
if component and hasattr(component, "set_params"):
component.set_params(**component_params)

def _replace_object(self, attr, name, new_val) -> None:
"""Replace an object in attribute that contains named objects.

Expand Down
78 changes: 78 additions & 0 deletions skbase/tests/test_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,81 @@ def test_metaestimator_composite(long_steps):

meta_est.set_params(bar__b="something else")
assert meta_est.get_params()["bar__b"] == "something else"


def test_meta_object_reset_consistency():
    """Check that ``set_params`` resets a BaseMetaObject like a BaseObject.

    Attributes attached to the instance after construction must be gone once
    ``set_params`` has been called, and the new parameter value must be set.
    """
    tester = MetaObjectTester(a=1, b="test", steps=[])

    # Attributes that are not init parameters; a reset is expected to drop them.
    stale = {"some_attribute": "test", "fitted_attribute_": "fitted"}
    for attr_name, attr_value in stale.items():
        setattr(tester, attr_name, attr_value)
    for attr_name in stale:
        assert hasattr(tester, attr_name)

    # Changing an ordinary parameter is what should trigger the reset.
    tester.set_params(a=3)

    for attr_name in stale:
        assert not hasattr(tester, attr_name)
    assert tester.a == 3


def test_meta_object_reset_with_steps():
    """Check reset behaviour when steps and step parameters are set together.

    A single ``set_params`` call changes an ordinary parameter, swaps the whole
    ``steps`` container and sets a nested parameter on one of the new steps.
    """
    first = ComponentDummy(a=100, b="step1")
    second = ComponentDummy(a=300, b="step2")
    initial_steps = [("old_step", ComponentDummy(a=999, b="old"))]
    tester = MetaObjectTester(a=1, b="main", steps=initial_steps)

    # Attributes that are not init parameters; a reset is expected to drop them.
    stale = {"some_attribute": "test", "fitted_attribute_": "fitted"}
    for attr_name, attr_value in stale.items():
        setattr(tester, attr_name, attr_value)
    for attr_name in stale:
        assert hasattr(tester, attr_name)

    # Ordinary parameter, container replacement and nested parameter at once.
    tester.set_params(
        a=42,
        steps=[("new1", first), ("new2", second)],
        new1__a=500,
    )

    # The reset must have happened because ``a`` is an ordinary parameter.
    for attr_name in stale:
        assert not hasattr(tester, attr_name)

    assert tester.a == 42
    assert len(tester.steps) == 2

    # ``new1`` had ``a`` overridden via ``new1__a``; ``new2`` keeps its value.
    expected = [("new1", 500), ("new2", 300)]
    for (expected_name, expected_a), step in zip(expected, tester.steps):
        assert step[0] == expected_name
        assert step[1].a == expected_a


def test_meta_estimator_reset_consistency():
    """Check that ``set_params`` resets a BaseMetaEstimator like a BaseObject.

    Attributes attached to the instance after construction must be gone once
    ``set_params`` has been called, and the new parameter value must be set.
    """
    estimator = MetaEstimatorTester(a=1, b="test", steps=[])

    # Attributes that are not init parameters; a reset is expected to drop them.
    stale = {"some_attribute": "test", "fitted_attribute_": "fitted"}
    for attr_name, attr_value in stale.items():
        setattr(estimator, attr_name, attr_value)
    for attr_name in stale:
        assert hasattr(estimator, attr_name)

    # Changing an ordinary parameter is what should trigger the reset.
    estimator.set_params(a=3)

    for attr_name in stale:
        assert not hasattr(estimator, attr_name)
    assert estimator.a == 3