Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions skbase/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,9 +361,6 @@ def set_params(self, **params):
-------
self : reference to self (after parameters have been set)
"""
if not params:
# Simple optimization to gain speed (inspect is slow)
return self
valid_params = self.get_params(deep=True)

unmatched_keys = []
Expand Down
19 changes: 2 additions & 17 deletions skbase/base/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,21 +234,10 @@ def _set_params(self, attr: str, **params):
Self
Instance of self.
"""
if not params:
return self

# Track whether we handle any params locally.
# This is needed because we pop params before calling super(),
# and if we pop ALL params, super() will return early without
# calling reset(), which would be inconsistent with BaseObject behavior.
# See issue #412.
params_handled_locally = False

# Ensure strict ordering of parameter setting:
# 1. All steps
if attr in params:
setattr(self, attr, params.pop(attr))
params_handled_locally = True
# 2. Step replacement
items = getattr(self, attr)
names = []
Expand All @@ -257,15 +246,11 @@ def _set_params(self, attr: str, **params):
for name in list(params.keys()):
if "__" not in name and name in names:
self._replace_object(attr, name, params.pop(name))
params_handled_locally = True
# 3. Step parameters and other initialisation arguments
super().set_params(**params) # type: ignore

# If we handled params locally and super() got nothing (empty dict),
# it would have returned early without calling reset().
# To maintain consistency with BaseObject, we must call reset() ourselves.
if params_handled_locally and not params:
self.reset()
# After the change to BaseObject.set_params, super().set_params will
# always perform reset when appropriate.

return self

Expand Down
12 changes: 12 additions & 0 deletions skbase/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,18 @@ def test_set_params_with_no_param_to_set_returns_object(
)


def test_set_params_with_no_param_resets_fitted_state(
fixture_class_parent: Type[Parent],
):
"""Test that calling set_params() with no args resets fitted attributes."""
base_obj = fixture_class_parent()
base_obj.fitted_attr_ = "should be removed"

base_obj.set_params()

assert not hasattr(base_obj, "fitted_attr_")


# This section tests the clone functionality
# These have been adapted from sklearn's tests of clone to use the clone
# method that is included as part of the BaseObject interface
Expand Down