diff --git a/README.md b/README.md index 293e889..e7f9207 100644 --- a/README.md +++ b/README.md @@ -205,9 +205,10 @@ expect(mock()).toBe(undefined) import type { WhenOptions } from 'vitest-when' ``` -| option | default | type | description | -| ------- | ------- | ------- | -------------------------------------------------- | -| `times` | N/A | integer | Only trigger configured behavior a number of times | +| option | default | type | description | +| ----------------- | ------- | ------- | -------------------------------------------------- | +| `ignoreExtraArgs` | `false` | boolean | Ignore extra arguments when matching arguments | +| `times` | N/A | integer | Only trigger configured behavior a number of times | ### `.calledWith(...args: Parameters): Stub` @@ -306,6 +307,17 @@ expect(mock('hello')).toEqual('sup?') expect(mock('hello')).toEqual('sup?') ``` +You can also ignore extra arguments when matching arguments. + +```ts +const mock = when(vi.fn(), { ignoreExtraArgs: true }) + .calledWith('hello') + .thenReturn('world') + +expect(mock('hello')).toEqual('world') +expect(mock('hello', 'jello')).toEqual('world') +``` + ### `.thenResolve(value: TReturn) -> Mock` When the stubbing is satisfied, resolve a `Promise` with `value` diff --git a/src/behaviors.ts b/src/behaviors.ts index e6526a7..49b6ec9 100644 --- a/src/behaviors.ts +++ b/src/behaviors.ts @@ -6,10 +6,10 @@ import type { AsFunction, ParametersOf, ReturnTypeOf, - WithMatchers, } from './types.ts' export interface WhenOptions { + ignoreExtraArgs?: boolean times?: number } @@ -22,10 +22,7 @@ export interface BehaviorStack { getUnmatchedCalls: () => readonly ParametersOf[] - bindArgs: ( - args: WithMatchers>, - options: WhenOptions, - ) => BoundBehaviorStack + bindArgs: (args: unknown[], options: WhenOptions) => BoundBehaviorStack } export interface BoundBehaviorStack { @@ -37,9 +34,10 @@ export interface BoundBehaviorStack { } export interface BehaviorEntry { - args: WithMatchers + args: unknown[] behavior: Behavior calls: TArgs[] + ignoreExtraArgs: boolean maxCallCount?: number | undefined } @@ -60,6 +58,7 @@ export type Behavior = export interface BehaviorOptions { value: TValue + ignoreExtraArgs: boolean maxCallCount: number | undefined } @@ -76,7 +75,7 @@ export const createBehaviorStack = < use: (args) => { const behavior = behaviors - .filter((b) => behaviorAvailable(b)) + .filter(behaviorAvailable) .find(behaviorMatches(args)) if (!behavior) { @@ -92,8 +91,9 @@ export const createBehaviorStack = < addReturn: (values) => { behaviors.unshift( ...getBehaviorOptions(values, options).map( - ({ value, maxCallCount }) => ({ + ({ value, ignoreExtraArgs, maxCallCount }) => ({ args, + ignoreExtraArgs, maxCallCount, behavior: { type: BehaviorType.RETURN, value }, calls: [], @@ -104,8 +104,9 @@ export const createBehaviorStack = < addResolve: (values) => { behaviors.unshift( ...getBehaviorOptions(values, options).map( - ({ value, maxCallCount }) => ({ + ({ value, ignoreExtraArgs, maxCallCount }) => ({ args, + ignoreExtraArgs, maxCallCount, behavior: { type: BehaviorType.RESOLVE, value }, calls: [], @@ -116,8 +117,9 @@ export const createBehaviorStack = < addThrow: (values) => { behaviors.unshift( ...getBehaviorOptions(values, options).map( - ({ value, maxCallCount }) => ({ + ({ value, ignoreExtraArgs, maxCallCount }) => ({ args, + ignoreExtraArgs, maxCallCount, behavior: { type: BehaviorType.THROW, error: value }, calls: [], @@ -128,8 +130,9 @@ export const createBehaviorStack = < addReject: (values) => { behaviors.unshift( ...getBehaviorOptions(values, options).map( - ({ value, maxCallCount }) => ({ + ({ value, ignoreExtraArgs, maxCallCount }) => ({ args, + ignoreExtraArgs, maxCallCount, behavior: { type: BehaviorType.REJECT, error: value }, calls: [], @@ -140,8 +143,9 @@ export const createBehaviorStack = < addDo: (values) => { behaviors.unshift( ...getBehaviorOptions(values, options).map( - ({ value, maxCallCount }) => ({ + ({ value, ignoreExtraArgs, maxCallCount }) => ({ args, + ignoreExtraArgs, maxCallCount, behavior: { type: BehaviorType.DO, @@ -158,7 +162,7 @@ export const createBehaviorStack = < const getBehaviorOptions = ( values: TValue[], - { times }: WhenOptions, + { ignoreExtraArgs, times }: WhenOptions, ): BehaviorOptions[] => { if (values.length === 0) { values = [undefined as TValue] @@ -166,6 +170,7 @@ const getBehaviorOptions = ( return values.map((value, index) => ({ value, + ignoreExtraArgs: ignoreExtraArgs ?? false, maxCallCount: times ?? (index < values.length - 1 ? 1 : undefined), })) } @@ -179,18 +184,19 @@ const behaviorAvailable = ( ) } -const behaviorMatches = (args: TArgs) => { - return (behavior: BehaviorEntry): boolean => { - let index = 0 +const behaviorMatches = (actualArgs: TArgs) => { + return (behaviorEntry: BehaviorEntry): boolean => { + const { args: expectedArgs, ignoreExtraArgs } = behaviorEntry + const isArgsLengthMatch = ignoreExtraArgs + ? expectedArgs.length <= actualArgs.length + : expectedArgs.length === actualArgs.length - while (index < args.length || index < behavior.args.length) { - if (!equals(args[index], behavior.args[index])) { - return false - } - - index += 1 + if (!isArgsLengthMatch) { + return false } - return true + return expectedArgs.every((expected, index) => + equals(actualArgs[index], expected), + ) } } diff --git a/src/types.ts b/src/types.ts index 916c8dc..4f679d0 100644 --- a/src/types.ts +++ b/src/types.ts @@ -49,6 +49,16 @@ export type ParametersOf = ? Parameters : never +/** An arguments list, optionally without every argument specified */ +export type ArgumentsSpec< + TArgs extends any[], + TOptions extends { ignoreExtraArgs?: boolean } | undefined, +> = TOptions extends { ignoreExtraArgs: true } + ? TArgs extends [infer Head, ...infer Tail] + ? [] | [Head] | [Head, ...ArgumentsSpec] + : TArgs + : TArgs + /** Extract return type from either a function or constructor */ export type ReturnTypeOf = TFunc extends AnyConstructor diff --git a/src/vitest-when.ts b/src/vitest-when.ts index ad35033..0a6f549 100644 --- a/src/vitest-when.ts +++ b/src/vitest-when.ts @@ -3,6 +3,7 @@ import { type DebugResult, getDebug } from './debug.ts' import { asMock, configureMock, validateMock } from './stubs.ts' import type { AnyMockable, + ArgumentsSpec, AsFunction, Mock, MockInstance, @@ -16,8 +17,11 @@ export { type Behavior, BehaviorType, type WhenOptions } from './behaviors.ts' export type { DebugResult, Stubbing } from './debug.ts' export * from './errors.ts' -export interface StubWrapper { - calledWith>( +export interface StubWrapper< + TFunc extends AnyMockable, + TOptions extends WhenOptions | undefined, +> { + calledWith, TOptions>>( ...args: WithMatchers ): Stub } @@ -30,17 +34,20 @@ export interface Stub { thenDo: (...callbacks: AsFunction[]) => Mock } -export const when = ( +export const when = < + TFunc extends AnyMockable, + TOptions extends WhenOptions | undefined = undefined, +>( mock: TFunc | MockInstance, - options: WhenOptions = {}, -): StubWrapper> => { + options?: TOptions, +): StubWrapper, TOptions> => { const validatedMock = validateMock(mock) const behaviorStack = configureMock(validatedMock) const result = asMock(validatedMock) return { calledWith: (...args) => { - const behaviors = behaviorStack.bindArgs(args, options) + const behaviors = behaviorStack.bindArgs(args as unknown[], options ?? {}) return { thenReturn: (...values) => { diff --git a/test/fixtures.ts b/test/fixtures.ts index d5e7cbd..84f44ae 100644 --- a/test/fixtures.ts +++ b/test/fixtures.ts @@ -17,6 +17,10 @@ export async function simpleAsync(input: number): Promise { throw new Error(`simpleAsync(${input})`) } +export function multipleArgs(a: number, b: string, c: boolean): string { + throw new Error(`multipleArgs(${a}, ${b}, ${c})`) +} + export function complex(input: { a: number; b: string }): string { throw new Error(`simple({ a: ${input.a}, b: ${input.b} })`) } diff --git a/test/typing.test-d.ts b/test/typing.test-d.ts index 493b893..bbc94ae 100644 --- a/test/typing.test-d.ts +++ b/test/typing.test-d.ts @@ -17,6 +17,7 @@ import * as subject from '../src/vitest-when.ts' import { complex, generic, + multipleArgs, overloaded, simple, simpleAsync, @@ -45,6 +46,37 @@ describe('vitest-when type signatures', () => { >() }) + it('should handle fewer than required arguments', () => { + subject.when(multipleArgs, { ignoreExtraArgs: true }).calledWith(42) + + subject + .when(multipleArgs, { ignoreExtraArgs: true }) + .calledWith(42, 'hello') + + subject + .when(multipleArgs, { ignoreExtraArgs: true }) + .calledWith(42, 'hello', true) + + subject + .when(multipleArgs, { ignoreExtraArgs: true }) + // @ts-expect-error: too many arguments + .calledWith(42, 'hello', true, 'oh no') + }) + + it('supports using matchers with ignoreExtraArgs', () => { + subject + .when(multipleArgs, { ignoreExtraArgs: true }) + .calledWith(expect.any(Number)) + + subject + .when(multipleArgs, { ignoreExtraArgs: true }) + .calledWith(expect.any(Number), expect.any(String)) + + subject + .when(multipleArgs, { ignoreExtraArgs: true }) + .calledWith(expect.any(Number), expect.any(String), expect.any(Boolean)) + }) + it('returns mock type for then resolve', () => { const result = subject.when(simpleAsync).calledWith(1).thenResolve('hello') diff --git a/test/vitest-when.test.ts b/test/vitest-when.test.ts index c04758d..92e49ac 100644 --- a/test/vitest-when.test.ts +++ b/test/vitest-when.test.ts @@ -279,4 +279,38 @@ describe('vitest-when', () => { // intentionally do not call the spy expect(true).toBe(true) }) + + it.each([ + { stubArgs: [] as unknown[], callArgs: [] as unknown[] }, + { stubArgs: [], callArgs: ['a'] }, + { stubArgs: [], callArgs: ['a', 'b'] }, + { stubArgs: ['a'], callArgs: ['a'] }, + { stubArgs: ['a'], callArgs: ['a', 'b'] }, + { stubArgs: ['a', 'b'], callArgs: ['a', 'b'] }, + ])( + 'matches call $callArgs against stub $stubArgs args with ignoreExtraArgs', + ({ stubArgs, callArgs }) => { + const spy = subject + .when(vi.fn().mockReturnValue('failure'), { ignoreExtraArgs: true }) + .calledWith(...stubArgs) + .thenReturn('success') + + expect(spy(...callArgs)).toEqual('success') + }, + ) + + it.each([ + { stubArgs: ['a'] as unknown[], callArgs: ['b'] as unknown[] }, + { stubArgs: [undefined], callArgs: [] }, + ])( + 'does not match call $callArgs against stub $stubArgs with ignoreExtraArgs', + ({ stubArgs, callArgs }) => { + const spy = subject + .when(vi.fn().mockReturnValue('success'), { ignoreExtraArgs: true }) + .calledWith(...stubArgs) + .thenReturn('failure') + + expect(spy(...callArgs)).toBe('success') + }, + ) })