diff --git a/skbase/tests/test_deep_equals.py b/skbase/tests/test_deep_equals.py new file mode 100644 index 00000000..b361c709 --- /dev/null +++ b/skbase/tests/test_deep_equals.py @@ -0,0 +1,13 @@ +def test_deep_equals_numpy_string_array(): + """Test that deep_equals works on string-dtype numpy arrays. GitHub issue #517.""" + import numpy as np + + from skbase.utils.deep_equals import deep_equals + + s = np.array(["1", "2", "1"]) + # Same arrays should be equal + assert deep_equals(s, s) is True + + # Different arrays should not be equal + t = np.array(["1", "2", "3"]) + assert deep_equals(s, t) is False diff --git a/skbase/utils/deep_equals/_deep_equals.py b/skbase/utils/deep_equals/_deep_equals.py index cf73a5a7..ab72dfb9 100644 --- a/skbase/utils/deep_equals/_deep_equals.py +++ b/skbase/utils/deep_equals/_deep_equals.py @@ -142,7 +142,7 @@ def _numpy_equals_plugin(x, y, return_msg=False, deep_equals=None): return ret(False, f".shape, x.shape = {x.shape} != y.shape = {y.shape}") if x.dtype != y.dtype: return ret(False, f".dtype, x.dtype = {x.dtype} != y.dtype = {y.dtype}") - if x.dtype == "str": + if x.dtype == "str" or np.issubdtype(x.dtype, np.character): return ret(np.array_equal(x, y), ".values") elif x.dtype == "object": x_flat = x.flatten()