Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions mock/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -977,15 +977,15 @@ func (args Arguments) Diff(objects []interface{}) (string, int) {
actualFmt = "(Missing)"
} else {
actual = objects[i]
actualFmt = fmt.Sprintf("(%[1]T=%[1]v)", actual)
actualFmt = safeFormatArg(actual)
}

if len(args) <= i {
expected = "(Missing)"
expectedFmt = "(Missing)"
} else {
expected = args[i]
expectedFmt = fmt.Sprintf("(%[1]T=%[1]v)", expected)
expectedFmt = safeFormatArg(expected)
}

if matcher, ok := expected.(argumentMatcher); ok {
Expand Down Expand Up @@ -1310,3 +1310,27 @@ func isFuncSame(f1, f2 *runtime.Func) bool {

return f1File == f2File && f1Loc == f2Loc
}

// safeFormatArg formats an argument for display in diff output, handling
// concurrent modification safely. For types that contain inherent references
// (pointers, maps, slices, channels), fmt.Sprintf("%v") would traverse the
// underlying data structure, which races with concurrent writers. For maps
// this is a non-recoverable fatal error; for pointers and slices it triggers
// data races detectable by -race.
//
// To avoid this, reference types are formatted with %p (address only) instead
// of %v. All other types (structs, primitives, strings) are value types that
// cannot race, so they use the full %v representation.
//
// See https://github.com/stretchr/testify/issues/1597
func safeFormatArg(v interface{}) string {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need a testsuite for this new function.

Also, in its current implementation it is safer but no yet "safe". So the name is misleading.

if v == nil {
return fmt.Sprintf("(%[1]T=%[1]v)", v)
}
switch reflect.TypeOf(v).Kind() {
case reflect.Ptr, reflect.Map, reflect.Slice, reflect.Chan:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replace reflect.Ptr (deprecated) with reflect.Pointer.

Other types must also be handled. In particular reflect.Array, reflect.Interface. We definitely need a comprehensive testsuite.

return fmt.Sprintf("(%T=%p)", v, v)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you checked the output of safeFormatArg([]int(nil))?
We need a comprehensive test suite.

Also use [1] like in other paths.

default:
return fmt.Sprintf("(%[1]T=%[1]v)", v)
}
}
81 changes: 78 additions & 3 deletions mock/mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2044,6 +2044,81 @@ func Test_Arguments_Diff_WithArgMatcher(t *testing.T) {
assert.Contains(t, diff, `No differences.`)
}

// Test_Arguments_Diff_ConcurrentPointerModification verifies that
// Arguments.Diff does not race when a pointer argument is concurrently
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The 3 tests have much common code. Please refactor to extract a func(t *testing.T, arg interface{}, accessArg func()).

Also group them as a top level test with 3 subtests.

// modified. This is a regression test for https://github.com/stretchr/testify/issues/1597.
// Adapted from https://github.com/stretchr/testify/pull/1598.
func Test_Arguments_Diff_ConcurrentPointerModification(t *testing.T) {
t.Parallel()

type data struct {
Value string
}

arg := &data{Value: "original"}
args := Arguments([]interface{}{Anything})

done := make(chan struct{})
go func() {
defer close(done)
for i := 0; i < 1000; i++ {
arg.Value = fmt.Sprintf("modified-%d", i)
}
}()

// Without the fix, this races with the goroutine above because
// fmt.Sprintf("%v") traverses *data while the goroutine writes to it.
for i := 0; i < 100; i++ {
args.Diff([]interface{}{arg})
}
<-done
}

// Test_Arguments_Diff_ConcurrentMapModification verifies that Arguments.Diff
// does not race when a map argument is concurrently modified.
// Raised by @brackendawson in https://github.com/stretchr/testify/pull/1598#discussion_r1869482623.
func Test_Arguments_Diff_ConcurrentMapModification(t *testing.T) {
t.Parallel()

arg := map[string]string{"key": "original"}
args := Arguments([]interface{}{Anything})

done := make(chan struct{})
go func() {
defer close(done)
for i := 0; i < 1000; i++ {
arg["key"] = fmt.Sprintf("modified-%d", i)
}
}()

for i := 0; i < 100; i++ {
args.Diff([]interface{}{arg})
}
<-done
}

// Test_Arguments_Diff_ConcurrentSliceModification verifies that Arguments.Diff
// does not race when a slice argument is concurrently modified.
func Test_Arguments_Diff_ConcurrentSliceModification(t *testing.T) {
t.Parallel()

arg := []string{"original"}
args := Arguments([]interface{}{Anything})

done := make(chan struct{})
go func() {
defer close(done)
for i := 0; i < 1000; i++ {
arg[0] = fmt.Sprintf("modified-%d", i)
}
}()

for i := 0; i < 100; i++ {
args.Diff([]interface{}{arg})
}
<-done
}

func Test_Arguments_Assert(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -2271,7 +2346,7 @@ func TestArgumentMatcherToPrintMismatchWithReferenceType(t *testing.T) {
defer func() {
if r := recover(); r != nil {
matchingExp := regexp.MustCompile(
`\s+mock: Unexpected Method Call\s+-*\s+GetTimes\(\[\]int\)\s+0: \[\]int\{1\}\s+The closest call I have is:\s+GetTimes\(mock.argumentMatcher\)\s+0: mock.argumentMatcher\{.*?\}\s+Diff:.*\(\[\]int=\[1\]\) not matched by func\(\[\]int\) bool\nat: \[[^\]]+mock\/mock_test.go`)
`\s+mock: Unexpected Method Call\s+-*\s+GetTimes\(\[\]int\)\s+0: \[\]int\{1\}\s+The closest call I have is:\s+GetTimes\(mock.argumentMatcher\)\s+0: mock.argumentMatcher\{.*?\}\s+Diff:.*\(\[\]int=0x[0-9a-f]+\) not matched by func\(\[\]int\) bool\nat: \[[^\]]+mock\/mock_test.go`)
assert.Regexp(t, matchingExp, r)
}
}()
Expand Down Expand Up @@ -2306,7 +2381,7 @@ func TestClosestCallFavorsFirstMock(t *testing.T) {

defer func() {
if r := recover(); r != nil {
diffRegExp := `Difference found in argument 0:\s+--- Expected\s+\+\+\+ Actual\s+@@ -2,4 \+2,4 @@\s+\(bool\) true,\s+- \(bool\) true,\s+- \(bool\) true\s+\+ \(bool\) false,\s+\+ \(bool\) false\s+}\s+Diff: 0: FAIL: \(\[\]bool=\[(true\s?|false\s?){3}]\) != \(\[\]bool=\[(true\s?|false\s?){3}\]\)`
diffRegExp := `Difference found in argument 0:\s+--- Expected\s+\+\+\+ Actual\s+@@ -2,4 \+2,4 @@\s+\(bool\) true,\s+- \(bool\) true,\s+- \(bool\) true\s+\+ \(bool\) false,\s+\+ \(bool\) false\s+}\s+Diff: 0: FAIL: \(\[\]bool=0x[0-9a-f]+\) != \(\[\]bool=0x[0-9a-f]+\)`
matchingExp := regexp.MustCompile(unexpectedCallRegex(`TheExampleMethod7([]bool)`, `0: \[\]bool{true, false, false}`, `0: \[\]bool{true, true, true}`, diffRegExp))
assert.Regexp(t, matchingExp, r)
}
Expand All @@ -2324,7 +2399,7 @@ func TestClosestCallUsesRepeatabilityToFindClosest(t *testing.T) {

defer func() {
if r := recover(); r != nil {
diffRegExp := `Difference found in argument 0:\s+--- Expected\s+\+\+\+ Actual\s+@@ -1,4 \+1,4 @@\s+\(\[\]bool\) \(len=3\) {\s+- \(bool\) false,\s+- \(bool\) false,\s+\+ \(bool\) true,\s+\+ \(bool\) true,\s+\(bool\) false\s+Diff: 0: FAIL: \(\[\]bool=\[(true\s?|false\s?){3}]\) != \(\[\]bool=\[(true\s?|false\s?){3}\]\)`
diffRegExp := `Difference found in argument 0:\s+--- Expected\s+\+\+\+ Actual\s+@@ -1,4 \+1,4 @@\s+\(\[\]bool\) \(len=3\) {\s+- \(bool\) false,\s+- \(bool\) false,\s+\+ \(bool\) true,\s+\+ \(bool\) true,\s+\(bool\) false\s+Diff: 0: FAIL: \(\[\]bool=0x[0-9a-f]+\) != \(\[\]bool=0x[0-9a-f]+\)`
matchingExp := regexp.MustCompile(unexpectedCallRegex(`TheExampleMethod7([]bool)`, `0: \[\]bool{true, true, false}`, `0: \[\]bool{false, false, false}`, diffRegExp))
assert.Regexp(t, matchingExp, r)
}
Expand Down