Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions mock/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,30 @@ func AnythingOfType(t string) AnythingOfTypeArgument {
return anythingOfTypeArgument(t)
}

type anythingImplementing struct {
interfaceType reflect.Type
}

func (a *anythingImplementing) isImplementedBy(val interface{}) bool {
t2 := reflect.TypeOf(val)

return t2.Implements(a.interfaceType)
}

// AnythingImplementing is just like AnythingOfType, but instead of checking against a concrete type, it is used to check if a value is of a type that implements a given interface
//
// For example, for checking if a value implements the context.Context interface:
//
// var args = Arguments([]interface{}{AnythingImplementing((*context.Context)(nil))})
// args.Assert(t, AnythingImplementing(context.Background())
func AnythingImplementing(val interface{}) anythingImplementing {
// Get the dynamic type
t := reflect.TypeOf(val)
interfaceType := t.Elem()

return anythingImplementing{interfaceType: interfaceType}
}

// IsTypeArgument is a struct that contains the type of an argument
// for use when type checking. This is an alternative to [AnythingOfType].
// Used in [Arguments.Diff] and [Arguments.Assert].
Expand Down Expand Up @@ -1013,6 +1037,12 @@ func (args Arguments) Diff(objects []interface{}) (string, int) {
differences++
output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, expected, reflect.TypeOf(actual).Name(), actualFmt)
}
case anythingImplementing:
expectedToImplement := expected.interfaceType
if !expected.isImplementedBy(actual) {
differences++
output = fmt.Sprintf("%s\t%d: FAIL: value of type %T does not implement interface %s\n", output, i, actual, expectedToImplement)
}
case *IsTypeArgument:
actualT := reflect.TypeOf(actual)
if actualT != expected.t {
Expand Down
41 changes: 41 additions & 0 deletions mock/mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1772,6 +1772,18 @@ func Test_Mock_AssertCalled_WithAnythingOfTypeArgument(t *testing.T) {

}

func Test_Mock_AssertCalled_WithAnythingImplementingArgument(t *testing.T) {
t.Parallel()

var mockedService = new(TestExampleImplementation)

mockedService.
On("TheExampleMethod4", AnythingImplementing((*ExampleInterface)(nil))).
Return(nil)

mockedService.TheExampleMethod4(mockedService)
}

func Test_Mock_AssertCalled_WithArguments(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -1978,6 +1990,35 @@ func Test_Arguments_Diff_WithAnythingOfTypeArgument_Failing(t *testing.T) {

}

func Test_Arguments_Diff_WithAnythingImplementingArgument(t *testing.T) {
t.Parallel()

var args = Arguments([]interface{}{AnythingImplementing((*ExampleInterface)(nil))})

var mockedService = new(TestExampleImplementation)
var count int
_, count = args.Diff([]interface{}{mockedService})

assert.True(t, args.Assert(t, mockedService))
assert.Equal(t, 0, count)
}

func Test_Arguments_Diff_WithAnythingImplementingArgument_Failing(t *testing.T) {
t.Parallel()

var args = Arguments([]interface{}{
AnythingImplementing((*ExampleInterface)(nil)),
})
var count int
var diff string
intVal := 123
diff, count = args.Diff([]interface{}{intVal})

assert.Equal(t, 1, count)
assert.Contains(t, diff, `value of type int does not implement interface mock.ExampleInterface`)

}

func Test_Arguments_Diff_WithIsTypeArgument(t *testing.T) {
t.Parallel()

Expand Down