diff --git a/skbase/base/_meta.py b/skbase/base/_meta.py index c87f557e..5d449e21 100644 --- a/skbase/base/_meta.py +++ b/skbase/base/_meta.py @@ -233,10 +233,21 @@ def _set_params(self, attr: str, **params): Self Instance of self. """ + if not params: + return self + + # Track whether we handle any params locally. + # This is needed because we pop params before calling super(), + # and if we pop ALL params, super() will return early without + # calling reset(), which would be inconsistent with BaseObject behavior. + # See issue #412. + params_handled_locally = False + # Ensure strict ordering of parameter setting: # 1. All steps if attr in params: setattr(self, attr, params.pop(attr)) + params_handled_locally = True # 2. Step replacement items = getattr(self, attr) names = [] @@ -245,8 +256,16 @@ def _set_params(self, attr: str, **params): for name in list(params.keys()): if "__" not in name and name in names: self._replace_object(attr, name, params.pop(name)) + params_handled_locally = True # 3. Step parameters and other initialisation arguments super().set_params(**params) # type: ignore + + # If we handled params locally and super() got nothing (empty dict), + # it would have returned early without calling reset(). + # To maintain consistency with BaseObject, we must call reset() ourselves. + if params_handled_locally and not params: + self.reset() + return self def _replace_object(self, attr, name, new_val) -> None: diff --git a/skbase/tests/test_meta.py b/skbase/tests/test_meta.py index df672027..e5d1a631 100644 --- a/skbase/tests/test_meta.py +++ b/skbase/tests/test_meta.py @@ -168,3 +168,40 @@ def test_metaestimator_composite(long_steps): meta_est.set_params(bar__b="something else") assert meta_est.get_params()["bar__b"] == "something else" + + +def test_set_params_resets_fitted_state(): + """Test that set_params calls reset, removing fitted state. + + Regression test for issue #412. + BaseMetaObject should call reset() during set_params() to clear fitted state, + maintaining consistency with BaseObject behavior. + """ + steps = [("foo", ComponentDummy(42)), ("bar", ComponentDummy(24))] + meta_obj = MetaObjectTester(steps=steps) + + # Add some fitted state (simulating a fit operation) + meta_obj.fitted_attr_ = "should be removed after set_params" + meta_obj.another_fitted_ = 123 + + # Test 1: Setting the named object parameter should trigger reset + new_steps = [("foo", ComponentDummy(99))] + meta_obj.set_params(steps=new_steps) + + # Fitted state should be gone after set_params + assert not hasattr( + meta_obj, "fitted_attr_" + ), "fitted_attr_ should be removed by reset() during set_params(steps=...)" + assert not hasattr( + meta_obj, "another_fitted_" + ), "another_fitted_ should be removed by reset() during set_params(steps=...)" + + # Test 2: Replacing individual step should also trigger reset + meta_obj = MetaObjectTester(steps=steps) + meta_obj.fitted_attr_ = "should be removed" + + meta_obj.set_params(foo=ComponentDummy(77)) + + assert not hasattr( + meta_obj, "fitted_attr_" + ), "fitted_attr_ should be removed by reset() during set_params(foo=...)"