diff --git a/gui/src/forms/AddModelForm.dynamicFetchRace.test.tsx b/gui/src/forms/AddModelForm.dynamicFetchRace.test.tsx
new file mode 100644
index 00000000000..2dc7bc70d8f
--- /dev/null
+++ b/gui/src/forms/AddModelForm.dynamicFetchRace.test.tsx
@@ -0,0 +1,207 @@
+import { act, screen, waitFor } from "@testing-library/react";
+import { describe, expect, it, vi, beforeEach } from "vitest";
+import { AddModelForm } from "./AddModelForm";
+import { renderWithProviders } from "../util/test/render";
+
+const fetchProviderModelsMock = vi.hoisted(() => vi.fn());
+const initializeDynamicModelsMock = vi.hoisted(() => vi.fn(async () => {}));
+
+vi.mock("../pages/AddNewModel/configs/fetchProviderModels", () => ({
+ fetchProviderModels: fetchProviderModelsMock,
+ initializeDynamicModels: initializeDynamicModelsMock,
+}));
+
+vi.mock("../components/modelSelection/ModelSelectionListbox", () => ({
+ default: ({
+ selectedProvider,
+ setSelectedProvider,
+ topOptions = [],
+ otherOptions = [],
+ searchPlaceholder,
+ }: any) => (
+
+
+ {selectedProvider.title}
+
+
+ {topOptions.map((option: any) => (
+
+ ))}
+
+
+ {otherOptions.map((option: any) => (
+
+ {option.title}
+
+ ))}
+
+
+ ),
+}));
+
+function deferred() {
+ let resolve!: (value: T) => void;
+ const promise = new Promise((res) => {
+ resolve = res;
+ });
+ return { promise, resolve };
+}
+
+function makeFetchedModel(title: string) {
+ return {
+ title,
+ description: title,
+ params: {
+ title,
+ model: title.toLowerCase().replace(/\s+/g, "-"),
+ },
+ isOpenSource: false,
+ providerOptions: ["openai"],
+ };
+}
+
+describe("AddModelForm dynamic fetch race", () => {
+ beforeEach(() => {
+ fetchProviderModelsMock.mockReset();
+ initializeDynamicModelsMock.mockClear();
+ });
+
+ it("does not leak stale models from the previous provider if the earlier fetch resolves late", async () => {
+ const pendingOpenAiFetch = deferred();
+
+ fetchProviderModelsMock.mockImplementation(
+ (_messenger: unknown, provider: string) => {
+ if (provider === "openai") {
+ return pendingOpenAiFetch.promise;
+ }
+ return Promise.resolve([]);
+ },
+ );
+
+ const { user } = await renderWithProviders(
+ ,
+ );
+
+ await user.type(
+ screen.getByPlaceholderText(/Enter your OpenAI API key/i),
+ "sk-test-key",
+ );
+ await user.click(screen.getByTitle(/fetch available models/i));
+
+ await user.click(screen.getByRole("button", { name: "Anthropic" }));
+ expect(screen.getByTestId("provider-current")).toHaveTextContent(
+ "Anthropic",
+ );
+
+ await act(async () => {
+ pendingOpenAiFetch.resolve([makeFetchedModel("OpenAI Dynamic Model")]);
+ await pendingOpenAiFetch.promise;
+ });
+
+ await waitFor(() => {
+ expect(screen.getByTestId("provider-current")).toHaveTextContent(
+ "Anthropic",
+ );
+ expect(screen.getByTestId("model-listbox")).not.toHaveTextContent(
+ "OpenAI Dynamic Model",
+ );
+ });
+ });
+
+ it("clears the previous provider API key before a newly selected keyed provider can fetch models", async () => {
+ fetchProviderModelsMock.mockResolvedValue([]);
+
+ const { user } = await renderWithProviders(
+ ,
+ );
+
+ const apiKeyInput = screen.getByPlaceholderText(
+ /Enter your OpenAI API key/i,
+ );
+ await user.type(apiKeyInput, "sk-openai-secret");
+ expect(apiKeyInput).toHaveValue("sk-openai-secret");
+
+ await user.click(screen.getByRole("button", { name: "Anthropic" }));
+
+ const anthropicApiKeyInput = screen.getByPlaceholderText(
+ /Enter your Anthropic API key/i,
+ );
+ expect(anthropicApiKeyInput).toHaveValue("");
+
+ const fetchButton = screen.getByTitle(/fetch available models/i);
+ await user.click(fetchButton);
+
+ expect(fetchProviderModelsMock).not.toHaveBeenCalled();
+ });
+
+ it("releases the previous provider fetch lock so the newly selected provider can fetch immediately", async () => {
+ const pendingOpenAiFetch = deferred();
+
+ fetchProviderModelsMock.mockImplementation(
+ (_messenger: unknown, provider: string) => {
+ if (provider === "openai") {
+ return pendingOpenAiFetch.promise;
+ }
+ if (provider === "anthropic") {
+ return Promise.resolve([makeFetchedModel("Anthropic Dynamic Model")]);
+ }
+ return Promise.resolve([]);
+ },
+ );
+
+ const { user } = await renderWithProviders(
+ ,
+ );
+
+ await user.type(
+ screen.getByPlaceholderText(/Enter your OpenAI API key/i),
+ "sk-openai-secret",
+ );
+ await user.click(screen.getByTitle(/fetch available models/i));
+
+ await user.click(screen.getByRole("button", { name: "Anthropic" }));
+
+ const anthropicApiKeyInput = screen.getByPlaceholderText(
+ /Enter your Anthropic API key/i,
+ );
+ await user.type(anthropicApiKeyInput, "sk-anthropic-secret");
+
+ const fetchButton = screen.getByTitle(/fetch available models/i);
+ expect(fetchButton).not.toBeDisabled();
+ await user.click(fetchButton);
+
+ await waitFor(() => {
+ expect(fetchProviderModelsMock).toHaveBeenCalledWith(
+ expect.anything(),
+ "anthropic",
+ "sk-anthropic-secret",
+ undefined,
+ );
+ expect(screen.getByTestId("model-listbox")).toHaveTextContent(
+ "Anthropic Dynamic Model",
+ );
+ });
+
+ await act(async () => {
+ pendingOpenAiFetch.resolve([makeFetchedModel("OpenAI Dynamic Model")]);
+ await pendingOpenAiFetch.promise;
+ });
+
+ await waitFor(() => {
+ expect(screen.getByTestId("model-listbox")).toHaveTextContent(
+ "Anthropic Dynamic Model",
+ );
+ expect(screen.getByTestId("model-listbox")).not.toHaveTextContent(
+ "OpenAI Dynamic Model",
+ );
+ });
+ });
+});
diff --git a/gui/src/forms/AddModelForm.tsx b/gui/src/forms/AddModelForm.tsx
index 7041d2c0814..6007cb43a6d 100644
--- a/gui/src/forms/AddModelForm.tsx
+++ b/gui/src/forms/AddModelForm.tsx
@@ -2,7 +2,7 @@ import {
ArrowPathIcon,
ArrowTopRightOnSquareIcon,
} from "@heroicons/react/24/outline";
-import { useCallback, useContext, useEffect, useState } from "react";
+import { useCallback, useContext, useEffect, useRef, useState } from "react";
import { FormProvider, useForm } from "react-hook-form";
import { Button, Input, StyledActionButton } from "../components";
import Alert from "../components/gui/Alert";
@@ -49,6 +49,8 @@ export function AddModelForm({
[],
);
const [isFetchingModels, setIsFetchingModels] = useState(false);
+ const selectedProviderRef = useRef(selectedProvider.provider);
+ const fetchGenerationRef = useRef(0);
useEffect(() => {
void initializeDynamicModels(ideMessenger);
@@ -56,6 +58,8 @@ export function AddModelForm({
useEffect(() => {
setFetchedModelsList([]);
+ fetchGenerationRef.current += 1;
+ setIsFetchingModels(false);
}, [selectedProvider]);
const handleFetchModels = useCallback(async () => {
@@ -64,6 +68,8 @@ export function AddModelForm({
if (!apiKey) return;
const providerAtFetchTime = selectedProvider.provider;
+ const fetchGeneration = fetchGenerationRef.current + 1;
+ fetchGenerationRef.current = fetchGeneration;
setIsFetchingModels(true);
try {
const models = await fetchProviderModels(
@@ -72,13 +78,15 @@ export function AddModelForm({
apiKey,
apiBase,
);
- setFetchedModelsList((prev) =>
- selectedProvider.provider === providerAtFetchTime ? models : prev,
- );
+ if (selectedProviderRef.current === providerAtFetchTime) {
+ setFetchedModelsList(models);
+ }
} catch (error) {
console.error("Failed to fetch models:", error);
} finally {
- setIsFetchingModels(false);
+ if (fetchGenerationRef.current === fetchGeneration) {
+ setIsFetchingModels(false);
+ }
}
}, [ideMessenger, selectedProvider, formMethods]);
@@ -128,9 +136,7 @@ export function AddModelForm({
useEffect(() => {
setSelectedModel(selectedProvider.packages[0]);
- if (!selectedProvider.tags?.includes(ModelProviderTags.RequiresApiKey)) {
- formMethods.setValue("apiKey", "");
- }
+ formMethods.setValue("apiKey", "");
}, [selectedProvider]);
const requiresSkPrefix =
@@ -203,6 +209,7 @@ export function AddModelForm({
(provider) => provider.title === val.title,
);
if (match) {
+ selectedProviderRef.current = match.provider;
setSelectedProvider(match);
}
}}