Skip to content
13 changes: 13 additions & 0 deletions skbase/tests/test_deep_equals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
def test_deep_equals_numpy_string_array():
"""Test that deep_equals works on string-dtype numpy arrays. GitHub issue #517."""
import numpy as np

from skbase.utils.deep_equals import deep_equals

s = np.array(["1", "2", "1"])
# Same arrays should be equal
assert deep_equals(s, s) is True

# Different arrays should not be equal
t = np.array(["1", "2", "3"])
assert deep_equals(s, t) is False
2 changes: 1 addition & 1 deletion skbase/utils/deep_equals/_deep_equals.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def _numpy_equals_plugin(x, y, return_msg=False, deep_equals=None):
return ret(False, f".shape, x.shape = {x.shape} != y.shape = {y.shape}")
if x.dtype != y.dtype:
return ret(False, f".dtype, x.dtype = {x.dtype} != y.dtype = {y.dtype}")
if x.dtype == "str":
if x.dtype == "str" or np.issubdtype(x.dtype, np.character):
return ret(np.array_equal(x, y), ".values")
elif x.dtype == "object":
x_flat = x.flatten()
Expand Down
Loading