diff --git a/gui/src/forms/AddModelForm.dynamicFetchRace.test.tsx b/gui/src/forms/AddModelForm.dynamicFetchRace.test.tsx new file mode 100644 index 00000000000..2dc7bc70d8f --- /dev/null +++ b/gui/src/forms/AddModelForm.dynamicFetchRace.test.tsx @@ -0,0 +1,207 @@ +import { act, screen, waitFor } from "@testing-library/react"; +import { describe, expect, it, vi, beforeEach } from "vitest"; +import { AddModelForm } from "./AddModelForm"; +import { renderWithProviders } from "../util/test/render"; + +const fetchProviderModelsMock = vi.hoisted(() => vi.fn()); +const initializeDynamicModelsMock = vi.hoisted(() => vi.fn(async () => {})); + +vi.mock("../pages/AddNewModel/configs/fetchProviderModels", () => ({ + fetchProviderModels: fetchProviderModelsMock, + initializeDynamicModels: initializeDynamicModelsMock, +})); + +vi.mock("../components/modelSelection/ModelSelectionListbox", () => ({ + default: ({ + selectedProvider, + setSelectedProvider, + topOptions = [], + otherOptions = [], + searchPlaceholder, + }: any) => ( +
+
+ {selectedProvider.title} +
+
+ {topOptions.map((option: any) => ( + + ))} +
+
+ {otherOptions.map((option: any) => ( +
+ {option.title} +
+ ))} +
+
+ ), +})); + +function deferred() { + let resolve!: (value: T) => void; + const promise = new Promise((res) => { + resolve = res; + }); + return { promise, resolve }; +} + +function makeFetchedModel(title: string) { + return { + title, + description: title, + params: { + title, + model: title.toLowerCase().replace(/\s+/g, "-"), + }, + isOpenSource: false, + providerOptions: ["openai"], + }; +} + +describe("AddModelForm dynamic fetch race", () => { + beforeEach(() => { + fetchProviderModelsMock.mockReset(); + initializeDynamicModelsMock.mockClear(); + }); + + it("does not leak stale models from the previous provider if the earlier fetch resolves late", async () => { + const pendingOpenAiFetch = deferred(); + + fetchProviderModelsMock.mockImplementation( + (_messenger: unknown, provider: string) => { + if (provider === "openai") { + return pendingOpenAiFetch.promise; + } + return Promise.resolve([]); + }, + ); + + const { user } = await renderWithProviders( + , + ); + + await user.type( + screen.getByPlaceholderText(/Enter your OpenAI API key/i), + "sk-test-key", + ); + await user.click(screen.getByTitle(/fetch available models/i)); + + await user.click(screen.getByRole("button", { name: "Anthropic" })); + expect(screen.getByTestId("provider-current")).toHaveTextContent( + "Anthropic", + ); + + await act(async () => { + pendingOpenAiFetch.resolve([makeFetchedModel("OpenAI Dynamic Model")]); + await pendingOpenAiFetch.promise; + }); + + await waitFor(() => { + expect(screen.getByTestId("provider-current")).toHaveTextContent( + "Anthropic", + ); + expect(screen.getByTestId("model-listbox")).not.toHaveTextContent( + "OpenAI Dynamic Model", + ); + }); + }); + + it("clears the previous provider API key before a newly selected keyed provider can fetch models", async () => { + fetchProviderModelsMock.mockResolvedValue([]); + + const { user } = await renderWithProviders( + , + ); + + const apiKeyInput = screen.getByPlaceholderText( + /Enter your OpenAI API key/i, + ); + await user.type(apiKeyInput, "sk-openai-secret"); + expect(apiKeyInput).toHaveValue("sk-openai-secret"); + + await user.click(screen.getByRole("button", { name: "Anthropic" })); + + const anthropicApiKeyInput = screen.getByPlaceholderText( + /Enter your Anthropic API key/i, + ); + expect(anthropicApiKeyInput).toHaveValue(""); + + const fetchButton = screen.getByTitle(/fetch available models/i); + await user.click(fetchButton); + + expect(fetchProviderModelsMock).not.toHaveBeenCalled(); + }); + + it("releases the previous provider fetch lock so the newly selected provider can fetch immediately", async () => { + const pendingOpenAiFetch = deferred(); + + fetchProviderModelsMock.mockImplementation( + (_messenger: unknown, provider: string) => { + if (provider === "openai") { + return pendingOpenAiFetch.promise; + } + if (provider === "anthropic") { + return Promise.resolve([makeFetchedModel("Anthropic Dynamic Model")]); + } + return Promise.resolve([]); + }, + ); + + const { user } = await renderWithProviders( + , + ); + + await user.type( + screen.getByPlaceholderText(/Enter your OpenAI API key/i), + "sk-openai-secret", + ); + await user.click(screen.getByTitle(/fetch available models/i)); + + await user.click(screen.getByRole("button", { name: "Anthropic" })); + + const anthropicApiKeyInput = screen.getByPlaceholderText( + /Enter your Anthropic API key/i, + ); + await user.type(anthropicApiKeyInput, "sk-anthropic-secret"); + + const fetchButton = screen.getByTitle(/fetch available models/i); + expect(fetchButton).not.toBeDisabled(); + await user.click(fetchButton); + + await waitFor(() => { + expect(fetchProviderModelsMock).toHaveBeenCalledWith( + expect.anything(), + "anthropic", + "sk-anthropic-secret", + undefined, + ); + expect(screen.getByTestId("model-listbox")).toHaveTextContent( + "Anthropic Dynamic Model", + ); + }); + + await act(async () => { + pendingOpenAiFetch.resolve([makeFetchedModel("OpenAI Dynamic Model")]); + await pendingOpenAiFetch.promise; + }); + + await waitFor(() => { + expect(screen.getByTestId("model-listbox")).toHaveTextContent( + "Anthropic Dynamic Model", + ); + expect(screen.getByTestId("model-listbox")).not.toHaveTextContent( + "OpenAI Dynamic Model", + ); + }); + }); +}); diff --git a/gui/src/forms/AddModelForm.tsx b/gui/src/forms/AddModelForm.tsx index 7041d2c0814..6007cb43a6d 100644 --- a/gui/src/forms/AddModelForm.tsx +++ b/gui/src/forms/AddModelForm.tsx @@ -2,7 +2,7 @@ import { ArrowPathIcon, ArrowTopRightOnSquareIcon, } from "@heroicons/react/24/outline"; -import { useCallback, useContext, useEffect, useState } from "react"; +import { useCallback, useContext, useEffect, useRef, useState } from "react"; import { FormProvider, useForm } from "react-hook-form"; import { Button, Input, StyledActionButton } from "../components"; import Alert from "../components/gui/Alert"; @@ -49,6 +49,8 @@ export function AddModelForm({ [], ); const [isFetchingModels, setIsFetchingModels] = useState(false); + const selectedProviderRef = useRef(selectedProvider.provider); + const fetchGenerationRef = useRef(0); useEffect(() => { void initializeDynamicModels(ideMessenger); @@ -56,6 +58,8 @@ export function AddModelForm({ useEffect(() => { setFetchedModelsList([]); + fetchGenerationRef.current += 1; + setIsFetchingModels(false); }, [selectedProvider]); const handleFetchModels = useCallback(async () => { @@ -64,6 +68,8 @@ export function AddModelForm({ if (!apiKey) return; const providerAtFetchTime = selectedProvider.provider; + const fetchGeneration = fetchGenerationRef.current + 1; + fetchGenerationRef.current = fetchGeneration; setIsFetchingModels(true); try { const models = await fetchProviderModels( @@ -72,13 +78,15 @@ export function AddModelForm({ apiKey, apiBase, ); - setFetchedModelsList((prev) => - selectedProvider.provider === providerAtFetchTime ? models : prev, - ); + if (selectedProviderRef.current === providerAtFetchTime) { + setFetchedModelsList(models); + } } catch (error) { console.error("Failed to fetch models:", error); } finally { - setIsFetchingModels(false); + if (fetchGenerationRef.current === fetchGeneration) { + setIsFetchingModels(false); + } } }, [ideMessenger, selectedProvider, formMethods]); @@ -128,9 +136,7 @@ export function AddModelForm({ useEffect(() => { setSelectedModel(selectedProvider.packages[0]); - if (!selectedProvider.tags?.includes(ModelProviderTags.RequiresApiKey)) { - formMethods.setValue("apiKey", ""); - } + formMethods.setValue("apiKey", ""); }, [selectedProvider]); const requiresSkPrefix = @@ -203,6 +209,7 @@ export function AddModelForm({ (provider) => provider.title === val.title, ); if (match) { + selectedProviderRef.current = match.provider; setSelectedProvider(match); } }}