def _set_params(self, attr: str, **params):
    """Logic for setting parameters on meta objects/estimators.

    Separates out logic for parameter setting on meta objects from public API point.

    Parameters are applied in a strict order, so that each stage sees the
    result of the previous one:

    1. The whole container of named objects stored under ``attr``.
    2. Replacement of individual named objects, addressed by their name.
       Names are resolved against the container as it will be *after* stage 1,
       so one call may install a new container and replace one of its members.
    3. Everything else (``<name>__<param>`` keys and any remaining arguments)
       is passed unchanged to ``super().set_params``.

    Parameters
    ----------
    attr : str
        Named object attribute name.
    **params : dict
        Parameters to set, keyed by parameter name.

    Returns
    -------
    Self
        Instance of self.
    """
    if not params:
        return self

    ops = []

    # Stage 1: the container itself. When it is part of this call, the names
    # used in stage 2 must come from the incoming value, not the stale one.
    if attr in params:
        items = params.pop(attr)
        ops.append((1, attr, items))
    else:
        items = getattr(self, attr, None)

    names = []
    if items and isinstance(items, (list, tuple)):
        names = [
            item[0] for item in items if isinstance(item, (list, tuple)) and item
        ]

    # Stage 2: a key without "__" that matches a name in the container
    # replaces that named object as a whole.
    for name in list(params):
        if "__" not in name and name in names:
            ops.append((2, name, params.pop(name)))

    if ops:
        self._execute_deferred_ops_optimized(attr, ops)

    # Stage 3: nothing is handled (or dropped) here; whatever is left is the
    # parent implementation's responsibility.
    # NOTE(review): the parent ``set_params`` is presumably what validates
    # unknown keys and resets the object - confirm against the base class.
    if params:
        super().set_params(**params)  # type: ignore

    return self


def _execute_deferred_ops_optimized(self, attr: str, ops):
    """Apply queued parameter operations in a fixed order.

    Operations run sorted by ``op_type``; operations of the same type keep
    their relative order. The list passed in is not modified.

    Parameters
    ----------
    attr : str
        Named object attribute name.
    ops : list
        List of (op_type, name, value) tuples where:
        - op_type=1: container operation, ``setattr(self, name, value)``
        - op_type=2: replacement of the named object ``name`` by ``value``
          through ``self._replace_object``
        - op_type=3: nested operation, ``name`` is ``<component>__<param>``
          and ``value`` is passed to that component's ``set_params``

    Raises
    ------
    ValueError
        If ``op_type`` is unknown, if a replacement or nested operation names
        an object that is not in the container, or if a nested key has no
        parameter part after ``__``.
    """

    def named_pairs():
        # (name, object) pairs currently stored under ``attr``; read on every
        # use so that a preceding container operation is taken into account.
        items = getattr(self, attr, None)
        if not items or not isinstance(items, (list, tuple)):
            return []
        return [
            (item[0], item[1])
            for item in items
            if isinstance(item, (list, tuple)) and len(item) >= 2
        ]

    nested_ops = {}

    for op_type, name, value in sorted(ops, key=lambda op: op[0]):
        if op_type == 1:
            setattr(self, name, value)

        elif op_type == 2:
            if name not in [pair[0] for pair in named_pairs()]:
                raise ValueError(
                    f"Cannot replace {name!r}: no object of that name in {attr!r}."
                )
            self._replace_object(attr, name, value)

        elif op_type == 3:
            component_name, _, sub_key = name.partition("__")
            if not sub_key:
                raise ValueError(f"Invalid nested parameter name {name!r}.")
            nested_ops.setdefault(component_name, {})[sub_key] = value

        else:
            raise ValueError(f"Unknown operation type {op_type!r}.")

    if nested_ops:
        lookup = dict(named_pairs())
        for component_name, component_params in nested_ops.items():
            # Membership test, not truthiness: a falsy component is still a
            # valid target, and an unknown one must not be skipped silently.
            if component_name not in lookup:
                raise ValueError(
                    f"Cannot set parameters of {component_name!r}: "
                    f"no object of that name in {attr!r}."
                )
            lookup[component_name].set_params(**component_params)
def test_meta_object_reset_consistency():
    """Check that set_params on a BaseMetaObject discards post-init attributes."""
    obj = MetaObjectTester(a=1, b="test", steps=[])

    # Attributes attached after construction; set_params is expected to drop them.
    transient = {"some_attribute": "test", "fitted_attribute_": "fitted"}
    for attr_name, attr_value in transient.items():
        setattr(obj, attr_name, attr_value)
    for attr_name in transient:
        assert hasattr(obj, attr_name)

    obj.set_params(a=3)

    for attr_name in transient:
        assert not hasattr(obj, attr_name)
    assert obj.a == 3


def test_meta_object_reset_with_steps():
    """Check reset when steps, a step parameter and an own parameter are set."""
    first = ComponentDummy(a=100, b="step1")
    second = ComponentDummy(a=300, b="step2")
    initial_steps = [("old_step", ComponentDummy(a=999, b="old"))]
    obj = MetaObjectTester(a=1, b="main", steps=initial_steps)

    # Attributes attached after construction; set_params is expected to drop them.
    transient = {"some_attribute": "test", "fitted_attribute_": "fitted"}
    for attr_name, attr_value in transient.items():
        setattr(obj, attr_name, attr_value)
    for attr_name in transient:
        assert hasattr(obj, attr_name)

    # One call mixing an own parameter, the whole container and a nested key.
    obj.set_params(a=42, steps=[("new1", first), ("new2", second)], new1__a=500)

    # The own parameter "a" is part of the call, so the object must have reset.
    for attr_name in transient:
        assert not hasattr(obj, attr_name)

    assert obj.a == 42
    assert len(obj.steps) == 2

    (name_one, comp_one), (name_two, comp_two) = obj.steps[0], obj.steps[1]
    assert name_one == "new1"
    assert comp_one.a == 500  # changed through new1__a
    assert name_two == "new2"
    assert comp_two.a == 300  # untouched


def test_meta_estimator_reset_consistency():
    """Check that set_params on a BaseMetaEstimator discards post-init attributes."""
    est = MetaEstimatorTester(a=1, b="test", steps=[])

    # Attributes attached after construction; set_params is expected to drop them.
    transient = {"some_attribute": "test", "fitted_attribute_": "fitted"}
    for attr_name, attr_value in transient.items():
        setattr(est, attr_name, attr_value)
    for attr_name in transient:
        assert hasattr(est, attr_name)

    est.set_params(a=3)

    for attr_name in transient:
        assert not hasattr(est, attr_name)
    assert est.a == 3