Skip to content
Closed
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion skbase/base/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,10 @@ def _set_params(self, attr: str, **params):
Self
Instance of self.
"""
if not params:
# Simple optimization to gain speed (inspect is slow)
return self

# Ensure strict ordering of parameter setting:
# 1. All steps
if attr in params:
Expand All @@ -245,7 +249,10 @@ def _set_params(self, attr: str, **params):
for name in list(params.keys()):
if "__" not in name and name in names:
self._replace_object(attr, name, params.pop(name))
# 3. Step parameters and other initialisation arguments

# 3. Process remaining parameters and apply reset consistently with BaseObject
# Call super().set_params() which will handle reset and nested parameter
# processing
super().set_params(**params) # type: ignore
return self

Expand Down
35 changes: 12 additions & 23 deletions skbase/lookup/_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def _filter_by_tags(obj, tag_filter=None, as_dataframe=True):
Parameters
----------
obj : BaseObject, an sktime estimator
tag_filter : dict of (str or list of str), default=None
tag_filter : str, list[str] or dict of (str or list of str), default=None
subsets the returned estimators as follows:
each key/value pair is statement in "and"/conjunction

Expand All @@ -190,34 +190,23 @@ def _filter_by_tags(obj, tag_filter=None, as_dataframe=True):
if tag_filter is None:
return True

type_msg = (
"filter_tags argument of all_objects must be "
"a dict with str or re.Pattern keys, "
"str, or iterable of str, "
"but found"
)

if not isinstance(tag_filter, (str, Iterable, dict)):
raise TypeError(f"{type_msg} type {type(tag_filter)}")

if not hasattr(obj, "get_class_tag"):
return False

# case: tag_filter is string
# Handle backward compatibility - convert str/list/tuple to dict
if isinstance(tag_filter, str):
tag_filter = {tag_filter: True}

# case: tag_filter is iterable of str but not dict
# If a iterable of strings is provided, check that all are in the returned tag_dict
if isinstance(tag_filter, Iterable) and not isinstance(tag_filter, dict):
if not all(isinstance(t, str) for t in tag_filter):
raise ValueError(f"{type_msg} {tag_filter}")
elif isinstance(tag_filter, (list, tuple)):
# Check if all elements are strings (original error handling)
if not all(isinstance(tag, str) for tag in tag_filter):
raise ValueError("filter_tags")
tag_filter = dict.fromkeys(tag_filter, True)
elif not isinstance(tag_filter, dict):
raise TypeError("filter_tags")

if not hasattr(obj, "get_class_tag"):
return False

# case: tag_filter is dict
# check that all keys are str
if not all(isinstance(t, str) for t in tag_filter.keys()):
raise ValueError(f"{type_msg} {tag_filter}")
raise ValueError("filter_tags")

cond_sat = True

Expand Down
213 changes: 213 additions & 0 deletions skbase/lookup/tests/test_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,3 +1071,216 @@ def test_all_object_class_lookup_invalid_object_types_raises(
object_types=class_filter,
class_lookup=class_lookup,
)


# ==============================================================================
# ADDITIONAL TESTS FOR EDGE CASES AND ERROR HANDLING
# ==============================================================================


def test_all_objects_filter_tags_string_preprocessing():
"""Test all_objects converts string filter_tags to dict correctly."""
# Test string input conversion
objs_str = all_objects(
package_name="skbase",
return_names=True,
as_dataframe=True,
filter_tags="A",
)

objs_dict = all_objects(
package_name="skbase",
return_names=True,
as_dataframe=True,
filter_tags={"A": True},
)

# Results should be identical
assert objs_str.equals(
objs_dict
), "String and dict filter should return same results"


def test_all_objects_filter_tags_list_preprocessing():
"""Test all_objects converts list filter_tags to dict correctly."""
# Test list of strings input conversion
objs_list = all_objects(
package_name="skbase",
return_names=True,
as_dataframe=True,
filter_tags=["A", "B"],
)

objs_dict = all_objects(
package_name="skbase",
return_names=True,
as_dataframe=True,
filter_tags={"A": True, "B": True},
)

# Results should be identical
assert objs_list.equals(
objs_dict
), "List and dict filter should return same results"


def test_all_objects_filter_tags_tuple_preprocessing():
"""Test all_objects converts tuple filter_tags to dict correctly."""
# Test tuple of strings input conversion
objs_tuple = all_objects(
package_name="skbase",
return_names=True,
as_dataframe=True,
filter_tags=("A", "B"),
)

objs_dict = all_objects(
package_name="skbase",
return_names=True,
as_dataframe=True,
filter_tags={"A": True, "B": True},
)

# Results should be identical
assert objs_tuple.equals(
objs_dict
), "Tuple and dict filter should return same results"


def test_get_package_metadata_filter_tags_string_preprocessing():
"""Test get_package_metadata converts string tag_filter to dict correctly."""
result_str = get_package_metadata(
"skbase",
modules_to_ignore="skbase",
tag_filter="A",
classes_to_exclude=TagAliaserMixin,
)

result_dict = get_package_metadata(
"skbase",
modules_to_ignore="skbase",
tag_filter={"A": True},
classes_to_exclude=TagAliaserMixin,
)

# Results should be identical
assert result_str.keys() == result_dict.keys()


def test_get_package_metadata_filter_tags_list_preprocessing():
"""Test get_package_metadata converts list tag_filter to dict correctly."""
result_list = get_package_metadata(
"skbase",
modules_to_ignore="skbase",
tag_filter=["A", "B"],
classes_to_exclude=TagAliaserMixin,
)

result_dict = get_package_metadata(
"skbase",
modules_to_ignore="skbase",
tag_filter={"A": True, "B": True},
classes_to_exclude=TagAliaserMixin,
)

# Results should be identical
assert result_list.keys() == result_dict.keys()


@pytest.mark.parametrize(
"invalid_filter",
[
123, # int
12.5, # float
object(), # object
["A", 123], # list with non-string
("A", 123), # tuple with non-string
],
)
def test_all_objects_filter_tags_invalid_types_preprocessing(invalid_filter):
"""Test that invalid filter_tags types raise TypeError in all_objects."""
with pytest.raises(
TypeError, match="filter_tags must be a str, list of str, or dict"
):
all_objects(
package_name="skbase",
filter_tags=invalid_filter,
)


@pytest.mark.parametrize(
"invalid_filter",
[
123, # int
12.5, # float
object(), # object
["A", 123], # list with non-string
("A", 123), # tuple with non-string
],
)
def test_get_package_metadata_filter_tags_invalid_types_preprocessing(invalid_filter):
"""Test that invalid tag_filter types raise TypeError in get_package_metadata."""
with pytest.raises(
TypeError, match="tag_filter must be a str, list of str, or dict"
):
get_package_metadata(
"skbase",
tag_filter=invalid_filter,
)


def test_all_objects_filter_tags_empty_list():
"""Test all_objects handles empty list filter_tags correctly."""
objs_empty_list = all_objects(
package_name="skbase",
return_names=True,
as_dataframe=True,
filter_tags=[],
)

objs_empty_dict = all_objects(
package_name="skbase",
return_names=True,
as_dataframe=True,
filter_tags={},
)

# Results should be identical
assert objs_empty_list.equals(
objs_empty_dict
), "Empty list and empty dict should return same results"


def test_filter_by_tags_dict_not_modified():
"""Test that _filter_by_tags doesn't modify the original dict in place."""
original_filter = {"A": "1"}
original_copy = original_filter.copy()

# Call all_objects with the filter
all_objects(
package_name="skbase",
filter_tags=original_filter,
)

# Original dict should be unchanged
assert (
original_filter == original_copy
), "Original filter_tags dict should not be modified"


def test_get_package_metadata_filter_tags_dict_copy_behavior():
"""Test that tag_filter dict is copied and not modified in place."""
original_filter = {"A": "1"}
original_copy = original_filter.copy()

# Call get_package_metadata with the filter
get_package_metadata(
"skbase",
tag_filter=original_filter,
classes_to_exclude=TagAliaserMixin,
)

# Original dict should be unchanged
assert (
original_filter == original_copy
), "Original tag_filter dict should not be modified"
78 changes: 78 additions & 0 deletions skbase/tests/test_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,81 @@ def test_metaestimator_composite(long_steps):

meta_est.set_params(bar__b="something else")
assert meta_est.get_params()["bar__b"] == "something else"


def test_meta_object_reset_consistency():
"""Test that BaseMetaObject resets
consistently with BaseObject during set_params."""
# Test that BaseMetaObject resets on set_params call like BaseObject
meta_obj = MetaObjectTester(a=1, b="test", steps=[])

# Add attributes that should be removed by reset
meta_obj.some_attribute = "test"
meta_obj.fitted_attribute_ = "fitted"

assert hasattr(meta_obj, "some_attribute")
assert hasattr(meta_obj, "fitted_attribute_")

# Set parameters - this should trigger reset
meta_obj.set_params(a=3)

# Check that reset occurred
assert not hasattr(meta_obj, "some_attribute")
assert not hasattr(meta_obj, "fitted_attribute_")
assert meta_obj.a == 3


def test_meta_object_reset_with_steps():
"""Test that BaseMetaObject resets correctly
when setting steps and step parameters."""
step1 = ComponentDummy(a=100, b="step1")
step2 = ComponentDummy(a=300, b="step2")

meta_obj = MetaObjectTester(
a=1, b="main", steps=[("old_step", ComponentDummy(a=999, b="old"))]
)

# Add attributes that should be removed by reset
meta_obj.some_attribute = "test"
meta_obj.fitted_attribute_ = "fitted"

assert hasattr(meta_obj, "some_attribute")
assert hasattr(meta_obj, "fitted_attribute_")

# Set both steps parameter and step-specific parameters
new_steps = [("step1", step1), ("step2", step2)]
meta_obj.set_params(steps=new_steps, step1__a=500)

# Check that reset occurred
assert not hasattr(meta_obj, "some_attribute")
assert not hasattr(meta_obj, "fitted_attribute_")

# Check that parameters were set correctly
assert meta_obj.a == 1 # Should remain unchanged
assert len(meta_obj.steps) == 2
assert meta_obj.steps[0][0] == "step1"
assert meta_obj.steps[0][1].a == 500 # Should be modified by step1__a parameter
assert meta_obj.steps[1][0] == "step2"
assert meta_obj.steps[1][1].a == 300 # Should remain unchanged


def test_meta_estimator_reset_consistency():
"""Test that BaseMetaEstimator resets
consistently with BaseObject during set_params."""
# Test that BaseMetaEstimator resets on set_params call like BaseObject
meta_est = MetaEstimatorTester(a=1, b="test", steps=[])

# Add attributes that should be removed by reset
meta_est.some_attribute = "test"
meta_est.fitted_attribute_ = "fitted"

assert hasattr(meta_est, "some_attribute")
assert hasattr(meta_est, "fitted_attribute_")

# Set parameters - this should trigger reset
meta_est.set_params(a=3)

# Check that reset occurred
assert not hasattr(meta_est, "some_attribute")
assert not hasattr(meta_est, "fitted_attribute_")
assert meta_est.a == 3