Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions skbase/base/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,21 @@ def _set_params(self, attr: str, **params):
Self
Instance of self.
"""
if not params:
Comment thread
fkiraly marked this conversation as resolved.
return self

# Track whether we handle any params locally.
# This is needed because we pop params before calling super(),
# and if we pop ALL params, super() will return early without
# calling reset(), which would be inconsistent with BaseObject behavior.
# See issue #412.
params_handled_locally = False

# Ensure strict ordering of parameter setting:
# 1. All steps
if attr in params:
setattr(self, attr, params.pop(attr))
params_handled_locally = True
# 2. Step replacement
items = getattr(self, attr)
names = []
Expand All @@ -245,8 +256,16 @@ def _set_params(self, attr: str, **params):
for name in list(params.keys()):
if "__" not in name and name in names:
self._replace_object(attr, name, params.pop(name))
params_handled_locally = True
# 3. Step parameters and other initialisation arguments
super().set_params(**params) # type: ignore

# If we handled params locally and super() got nothing (empty dict),
# it would have returned early without calling reset().
# To maintain consistency with BaseObject, we must call reset() ourselves.
if params_handled_locally and not params:
self.reset()

return self

def _replace_object(self, attr, name, new_val) -> None:
Expand Down
37 changes: 37 additions & 0 deletions skbase/tests/test_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,40 @@ def test_metaestimator_composite(long_steps):

meta_est.set_params(bar__b="something else")
assert meta_est.get_params()["bar__b"] == "something else"


def test_set_params_resets_fitted_state():
"""Test that set_params calls reset, removing fitted state.

Regression test for issue #412.
BaseMetaObject should call reset() during set_params() to clear fitted state,
maintaining consistency with BaseObject behavior.
"""
steps = [("foo", ComponentDummy(42)), ("bar", ComponentDummy(24))]
meta_obj = MetaObjectTester(steps=steps)

# Add some fitted state (simulating a fit operation)
meta_obj.fitted_attr_ = "should be removed after set_params"
meta_obj.another_fitted_ = 123

# Test 1: Setting the named object parameter should trigger reset
new_steps = [("foo", ComponentDummy(99))]
meta_obj.set_params(steps=new_steps)

# Fitted state should be gone after set_params
assert not hasattr(
meta_obj, "fitted_attr_"
), "fitted_attr_ should be removed by reset() during set_params(steps=...)"
assert not hasattr(
meta_obj, "another_fitted_"
), "another_fitted_ should be removed by reset() during set_params(steps=...)"

# Test 2: Replacing individual step should also trigger reset
meta_obj = MetaObjectTester(steps=steps)
meta_obj.fitted_attr_ = "should be removed"

meta_obj.set_params(foo=ComponentDummy(77))

assert not hasattr(
meta_obj, "fitted_attr_"
), "fitted_attr_ should be removed by reset() during set_params(foo=...)"