diff --git a/skbase/base/_base.py b/skbase/base/_base.py index dd385372..e0fc635d 100644 --- a/skbase/base/_base.py +++ b/skbase/base/_base.py @@ -361,9 +361,6 @@ def set_params(self, **params): ------- self : reference to self (after parameters have been set) """ - if not params: - # Simple optimization to gain speed (inspect is slow) - return self valid_params = self.get_params(deep=True) unmatched_keys = [] diff --git a/skbase/base/_meta.py b/skbase/base/_meta.py index 3a4e70db..3f4f4778 100644 --- a/skbase/base/_meta.py +++ b/skbase/base/_meta.py @@ -234,21 +234,10 @@ def _set_params(self, attr: str, **params): Self Instance of self. """ - if not params: - return self - - # Track whether we handle any params locally. - # This is needed because we pop params before calling super(), - # and if we pop ALL params, super() will return early without - # calling reset(), which would be inconsistent with BaseObject behavior. - # See issue #412. - params_handled_locally = False - # Ensure strict ordering of parameter setting: # 1. All steps if attr in params: setattr(self, attr, params.pop(attr)) - params_handled_locally = True # 2. Step replacement items = getattr(self, attr) names = [] @@ -257,15 +246,11 @@ def _set_params(self, attr: str, **params): for name in list(params.keys()): if "__" not in name and name in names: self._replace_object(attr, name, params.pop(name)) - params_handled_locally = True # 3. Step parameters and other initialisation arguments super().set_params(**params) # type: ignore - # If we handled params locally and super() got nothing (empty dict), - # it would have returned early without calling reset(). - # To maintain consistency with BaseObject, we must call reset() ourselves. - if params_handled_locally and not params: - self.reset() + # After the change to BaseObject.set_params, super().set_params will + # always perform reset when appropriate. return self diff --git a/skbase/tests/test_base.py b/skbase/tests/test_base.py index a924ff07..ce17a387 100644 --- a/skbase/tests/test_base.py +++ b/skbase/tests/test_base.py @@ -882,6 +882,18 @@ def test_set_params_with_no_param_to_set_returns_object( ) +def test_set_params_with_no_param_resets_fitted_state( + fixture_class_parent: Type[Parent], +): + """Test that calling set_params() with no args resets fitted attributes.""" + base_obj = fixture_class_parent() + base_obj.fitted_attr_ = "should be removed" + + base_obj.set_params() + + assert not hasattr(base_obj, "fitted_attr_") + + # This section tests the clone functionality # These have been adapted from sklearn's tests of clone to use the clone # method that is included as part of the BaseObject interface