diff --git a/app/components/Assistant/ChatMessage/chatMessage.tsx b/app/components/Assistant/ChatMessage/chatMessage.tsx index f2006e185..2c3e0382e 100644 --- a/app/components/Assistant/ChatMessage/chatMessage.tsx +++ b/app/components/Assistant/ChatMessage/chatMessage.tsx @@ -26,22 +26,16 @@ export const ChatMessage = ({ {!isUser && ( - AI - + /> )} {isUser ? ( diff --git a/app/components/Assistant/ChatPanel/chatPanel.tsx b/app/components/Assistant/ChatPanel/chatPanel.tsx index d4dcf3d30..3c487b544 100644 --- a/app/components/Assistant/ChatPanel/chatPanel.tsx +++ b/app/components/Assistant/ChatPanel/chatPanel.tsx @@ -19,26 +19,20 @@ interface ChatMessageDisplay { interface ChatPanelProps { error: string | null; + isRestoring?: boolean; loading: boolean; messages: ChatMessageDisplay[]; + onRetry?: () => Promise; onSend: (message: string) => void; suggestions: SuggestionChip[]; } -/** - * The main chat interface with message list, input, and suggestion chips. - * @param props - Component props - * @param props.error - Error message to display - * @param props.loading - Whether the assistant is processing - * @param props.messages - Chat message history - * @param props.onSend - Callback to send a message - * @param props.suggestions - Suggestion chips to display - * @returns Chat panel element - */ export const ChatPanel = ({ error, + isRestoring, loading, messages, + onRetry, onSend, suggestions, }: ChatPanelProps): JSX.Element => { @@ -70,10 +64,21 @@ export const ChatPanel = ({ } }; + const inputDisabled = loading || !!isRestoring; + return ( - {messages.length === 0 && ( + {isRestoring && ( + + + + Restoring conversation... + + + )} + + {!isRestoring && messages.length === 0 && ( Welcome! I can help you explore BRC Analytics data and set up @@ -96,7 +101,17 @@ export const ChatPanel = ({ )} {error && ( - + + Retry + + ) + } + severity="error" + sx={{ mx: 1 }} + > {error} )} @@ -106,13 +121,13 @@ export const ChatPanel = ({ diff --git a/app/components/Assistant/SuggestionChips/suggestionChips.styles.ts b/app/components/Assistant/SuggestionChips/suggestionChips.styles.ts index f31251c3c..8ddc0dc68 100644 --- a/app/components/Assistant/SuggestionChips/suggestionChips.styles.ts +++ b/app/components/Assistant/SuggestionChips/suggestionChips.styles.ts @@ -1,9 +1,22 @@ -import { Box } from "@mui/material"; +import { Box, Chip } from "@mui/material"; import styled from "@emotion/styled"; export const ChipsContainer = styled(Box)({ display: "flex", flexWrap: "wrap", gap: "8px", - padding: "4px 0", + padding: "8px 12px", +}); + +export const StyledSuggestionChip = styled(Chip)({ + "&:hover": { + backgroundColor: "#e3f2fd", + borderColor: "#90caf9", + }, + borderColor: "#bbdefb", + borderRadius: "16px", + fontSize: "0.8125rem", + height: "auto", + padding: "4px 2px", + transition: "background-color 0.15s, border-color 0.15s", }); diff --git a/app/components/Assistant/SuggestionChips/suggestionChips.tsx b/app/components/Assistant/SuggestionChips/suggestionChips.tsx index 5bf6c60ad..a7f1a64b4 100644 --- a/app/components/Assistant/SuggestionChips/suggestionChips.tsx +++ b/app/components/Assistant/SuggestionChips/suggestionChips.tsx @@ -1,7 +1,6 @@ import { JSX } from "react"; -import { Chip } from "@mui/material"; import { SuggestionChip } from "../../../types/api"; -import { ChipsContainer } from "./suggestionChips.styles"; +import { ChipsContainer, StyledSuggestionChip } from "./suggestionChips.styles"; interface SuggestionChipsProps { chips: SuggestionChip[]; @@ -9,14 +8,6 @@ interface SuggestionChipsProps { onSelect: (message: string) => void; } -/** - * Quick-tap suggestion chips shown below the chat input. - * @param props - Component props - * @param props.chips - Available suggestion chips - * @param props.disabled - Whether chips are disabled - * @param props.onSelect - Callback when a chip is selected - * @returns Suggestion chips row - */ export const SuggestionChips = ({ chips, disabled, @@ -27,7 +18,7 @@ export const SuggestionChips = ({ return ( {chips.map((chip) => ( - { + if (didInitRef.current || initialDataSourceView !== VIEW.UPLOAD_MY_DATA) + return; + didInitRef.current = true; + onConfigure(clearSequencingData()); + onConfigure(getUploadMyOwnSequencingData(stepKey)); + }, [initialDataSourceView, onConfigure, stepKey]); return ( {entryLabel} diff --git a/app/components/Entity/components/ConfigureWorkflowInputs/components/Main/components/Stepper/components/Step/types.ts b/app/components/Entity/components/ConfigureWorkflowInputs/components/Main/components/Stepper/components/Step/types.ts index 77eeea250..9c8c682e4 100644 --- a/app/components/Entity/components/ConfigureWorkflowInputs/components/Main/components/Stepper/components/Step/types.ts +++ b/app/components/Entity/components/ConfigureWorkflowInputs/components/Main/components/Stepper/components/Step/types.ts @@ -6,6 +6,7 @@ import { OnConfigure, } from "../../../../../../../../../../views/WorkflowInputsView/hooks/UseConfigureInputs/types"; import { Status, OnLaunchGalaxy } from "./hooks/UseLaunchGalaxy/types"; +import { VIEW } from "./SequencingStep/components/ToggleButtonGroup/types"; import { OnContinue, OnEdit } from "../../hooks/UseStepper/types"; import { Assembly } from "../../../../../../../../../../views/WorkflowInputsView/types"; @@ -27,6 +28,7 @@ export interface StepProps configuredInput: ConfiguredInput; entryLabel: string; genome?: Assembly; + initialDataSourceView?: VIEW; onConfigure: OnConfigure; onContinue: OnContinue; onEdit: OnEdit; diff --git a/app/components/Entity/components/ConfigureWorkflowInputs/components/Main/components/Stepper/hooks/UseStepper/hook.ts b/app/components/Entity/components/ConfigureWorkflowInputs/components/Main/components/Stepper/hooks/UseStepper/hook.ts index 7c9df09cc..37280aefd 100644 --- a/app/components/Entity/components/ConfigureWorkflowInputs/components/Main/components/Stepper/hooks/UseStepper/hook.ts +++ b/app/components/Entity/components/ConfigureWorkflowInputs/components/Main/components/Stepper/hooks/UseStepper/hook.ts @@ -3,9 +3,12 @@ import { UseStepper } from "./types"; import { getInitialActiveStep, getNextActiveStep } from "./utils"; import { StepConfig } from "../../components/Step/types"; -export const useStepper = (steps: StepConfig[]): UseStepper => { +export const useStepper = ( + steps: StepConfig[], + initialActiveStepOverride?: number +): UseStepper => { const [activeStep, setActiveStep] = useState( - getInitialActiveStep(steps) + initialActiveStepOverride ?? getInitialActiveStep(steps) ); const onContinue = useCallback((): void => { diff --git a/app/components/Entity/components/ConfigureWorkflowInputs/components/Main/components/Stepper/types.ts b/app/components/Entity/components/ConfigureWorkflowInputs/components/Main/components/Stepper/types.ts index ede79232d..37f340666 100644 --- a/app/components/Entity/components/ConfigureWorkflowInputs/components/Main/components/Stepper/types.ts +++ b/app/components/Entity/components/ConfigureWorkflowInputs/components/Main/components/Stepper/types.ts @@ -5,6 +5,7 @@ import { OnLaunchGalaxy, } from "./components/Step/hooks/UseLaunchGalaxy/types"; import { StepConfig } from "./components/Step/types"; +import { VIEW } from "./components/Step/SequencingStep/components/ToggleButtonGroup/types"; import { ConfiguredInput } from "../../../../../../../../views/WorkflowInputsView/hooks/UseConfigureInputs/types"; import { Assembly } from "../../../../../../../../views/WorkflowInputsView/types"; import { OnContinue, OnEdit } from "./hooks/UseStepper/types"; @@ -14,6 +15,7 @@ export interface Props { configuredInput: ConfiguredInput; configuredSteps: StepConfig[]; genome?: Assembly; + initialDataSourceView?: VIEW; onConfigure: OnConfigure; onContinue: OnContinue; onEdit: OnEdit; diff --git a/app/components/Entity/components/ConfigureWorkflowInputs/components/Main/main.tsx b/app/components/Entity/components/ConfigureWorkflowInputs/components/Main/main.tsx index f73b04828..3acfd935e 100644 --- a/app/components/Entity/components/ConfigureWorkflowInputs/components/Main/main.tsx +++ b/app/components/Entity/components/ConfigureWorkflowInputs/components/Main/main.tsx @@ -14,6 +14,7 @@ export const Main = ({ configuredInput, configuredSteps, genome, + initialDataSourceView, onConfigure, onContinue, onEdit, @@ -35,6 +36,7 @@ export const Main = ({ configuredInput={configuredInput} configuredSteps={configuredSteps} genome={genome} + initialDataSourceView={initialDataSourceView} onConfigure={onConfigure} onContinue={onContinue} onEdit={onEdit} diff --git a/app/components/Entity/components/ConfigureWorkflowInputs/components/Main/types.ts b/app/components/Entity/components/ConfigureWorkflowInputs/components/Main/types.ts index ac962c19d..f0e45e100 100644 --- a/app/components/Entity/components/ConfigureWorkflowInputs/components/Main/types.ts +++ b/app/components/Entity/components/ConfigureWorkflowInputs/components/Main/types.ts @@ -4,6 +4,7 @@ import { OnConfigure, } from "../../../../../../views/WorkflowInputsView/hooks/UseConfigureInputs/types"; import { StepConfig } from "./components/Stepper/components/Step/types"; +import { VIEW } from "./components/Stepper/components/Step/SequencingStep/components/ToggleButtonGroup/types"; import { Assembly } from "../../../../../../views/WorkflowInputsView/types"; import { OnContinue, @@ -15,6 +16,7 @@ export interface Props { configuredInput: ConfiguredInput; configuredSteps: StepConfig[]; genome?: Assembly; + initialDataSourceView?: VIEW; onConfigure: OnConfigure; onContinue: OnContinue; onEdit: OnEdit; diff --git a/app/hooks/useAssistantChat.ts b/app/hooks/useAssistantChat.ts index 7631be4ac..a94febe95 100644 --- a/app/hooks/useAssistantChat.ts +++ b/app/hooks/useAssistantChat.ts @@ -1,4 +1,4 @@ -import { useCallback, useRef, useState } from "react"; +import { useCallback, useEffect, useRef, useState } from "react"; import { llmAPIClient } from "../services/llm-api-client"; import { AnalysisSchema, @@ -6,6 +6,8 @@ import { SuggestionChip, } from "../types/api"; +const SESSION_KEY = "brc-assistant-session-id"; + interface ChatMessageDisplay { content: string; role: "user" | "assistant"; @@ -15,8 +17,11 @@ interface UseAssistantChatReturn { error: string | null; handoffUrl: string | null; isComplete: boolean; + isRestoring: boolean; loading: boolean; messages: ChatMessageDisplay[]; + onRetry?: () => Promise; + resetSession: () => void; schema: AnalysisSchema | null; sendMessage: (message: string) => Promise; suggestions: SuggestionChip[]; @@ -24,7 +29,8 @@ interface UseAssistantChatReturn { /** * Manages assistant chat state: messages, session, schema, and suggestions. - * @returns Chat state and a sendMessage function + * Persists session_id to localStorage and restores on mount. + * @returns Chat state, sendMessage, resetSession, and retry functions */ export const useAssistantChat = (): UseAssistantChatReturn => { const [messages, setMessages] = useState([]); @@ -33,14 +39,53 @@ export const useAssistantChat = (): UseAssistantChatReturn => { const [isComplete, setIsComplete] = useState(false); const [handoffUrl, setHandoffUrl] = useState(null); const [loading, setLoading] = useState(false); + const [isRestoring, setIsRestoring] = useState(false); const [error, setError] = useState(null); + const [lastFailedMessage, setLastFailedMessage] = useState( + null + ); const sessionIdRef = useRef(null); + const sendingRef = useRef(false); + + // Restore session from localStorage on mount + useEffect(() => { + const storedId = localStorage.getItem(SESSION_KEY); + if (!storedId) return; + + let cancelled = false; + setIsRestoring(true); + + llmAPIClient + .assistantRestore(storedId) + .then((restored) => { + if (cancelled) return; + sessionIdRef.current = restored.session_id; + setMessages(restored.messages); + setSchema(restored.schema_state); + setSuggestions(restored.suggestions); + setIsComplete(restored.is_complete); + setHandoffUrl(restored.handoff_url); + }) + .catch(() => { + if (cancelled) return; + localStorage.removeItem(SESSION_KEY); + }) + .finally(() => { + if (!cancelled) setIsRestoring(false); + }); + + return (): void => { + cancelled = true; + }; + }, []); const sendMessage = useCallback(async (message: string): Promise => { - if (!message.trim()) return; + if (!message.trim() || sendingRef.current) return; + sendingRef.current = true; setLoading(true); setError(null); + setLastFailedMessage(null); // Add user message immediately for responsiveness setMessages((prev) => [...prev, { content: message, role: "user" }]); @@ -52,6 +97,7 @@ export const useAssistantChat = (): UseAssistantChatReturn => { }); sessionIdRef.current = response.session_id; + localStorage.setItem(SESSION_KEY, response.session_id); // Add assistant reply setMessages((prev) => [ @@ -66,19 +112,47 @@ export const useAssistantChat = (): UseAssistantChatReturn => { } catch (err) { const errorMessage = handleChatError(err); setError(errorMessage); - // Remove the user message if the request failed entirely - setMessages((prev) => prev.slice(0, -1)); + setLastFailedMessage(message); } finally { setLoading(false); + sendingRef.current = false; } }, []); + const retry = useCallback(async (): Promise => { + if (!lastFailedMessage) return; + const msg = lastFailedMessage; + setLastFailedMessage(null); + setError(null); + setMessages((prev) => prev.slice(0, -1)); + await sendMessage(msg); + }, [lastFailedMessage, sendMessage]); + + const resetSession = useCallback((): void => { + const oldId = sessionIdRef.current; + if (oldId) { + llmAPIClient.assistantDeleteSession(oldId).catch(() => {}); + } + sessionIdRef.current = null; + localStorage.removeItem(SESSION_KEY); + setMessages([]); + setSchema(null); + setSuggestions([]); + setIsComplete(false); + setHandoffUrl(null); + setError(null); + setLastFailedMessage(null); + }, []); + return { error, handoffUrl, isComplete, + isRestoring, loading, messages, + onRetry: lastFailedMessage ? retry : undefined, + resetSession, schema, sendMessage, suggestions, diff --git a/app/services/llm-api-client.ts b/app/services/llm-api-client.ts index 89c859974..d902282eb 100644 --- a/app/services/llm-api-client.ts +++ b/app/services/llm-api-client.ts @@ -5,6 +5,7 @@ import { AssistantChatResponse, DatasetSearchRequest, DatasetSearchResponse, + SessionRestoreResponse, UnifiedSearchRequest, UnifiedSearchResponse, WorkflowSuggestionRequest, @@ -47,6 +48,25 @@ export const llmAPIClient = { .json(); }, + /** + * Delete an assistant session + * @param sessionId - Session to delete + */ + assistantDeleteSession: async (sessionId: string): Promise => { + await apiClient.delete(`assistant/session/${sessionId}`); + }, + + /** + * Restore a previous assistant session + * @param sessionId - Session to restore + * @returns Promise resolving to session state (messages, schema, suggestions) + */ + assistantRestore: async ( + sessionId: string + ): Promise => { + return apiClient.get(`assistant/session/${sessionId}`).json(); + }, + /** * Check health status of LLM service * @returns Promise resolving to health status diff --git a/app/types/api.ts b/app/types/api.ts index 09f60b683..d1c9c686c 100644 --- a/app/types/api.ts +++ b/app/types/api.ts @@ -145,6 +145,15 @@ export interface AssistantChatResponse { token_usage?: TokenUsage; } +export interface SessionRestoreResponse { + handoff_url: string | null; + is_complete: boolean; + messages: { content: string; role: "user" | "assistant" }[]; + schema_state: AnalysisSchema; + session_id: string; + suggestions: SuggestionChip[]; +} + export interface UnifiedSearchResponse { datasets?: { cached: boolean; diff --git a/app/views/AssistantView/assistantView.styles.ts b/app/views/AssistantView/assistantView.styles.ts index fecaf05f4..0bcd17a9c 100644 --- a/app/views/AssistantView/assistantView.styles.ts +++ b/app/views/AssistantView/assistantView.styles.ts @@ -1,11 +1,17 @@ import { Box } from "@mui/material"; import styled from "@emotion/styled"; import { GridPaperSection } from "@databiosphere/findable-ui/lib/components/common/Section/section.styles"; +import { sectionLayout } from "../../components/Layout/components/AppLayout/components/Section/section.styles"; export const AssistantSection = styled(GridPaperSection)` padding: 24px 0; `; +export const SectionContent = styled(Box)` + ${sectionLayout}; + padding: 0 16px; +`; + export const TwoPanelLayout = styled(Box)({ "@media (min-width: 960px)": { flexDirection: "row", @@ -15,9 +21,6 @@ export const TwoPanelLayout = styled(Box)({ display: "flex", flexDirection: "column", gap: "24px", - margin: "0 auto", - maxWidth: "1400px", - padding: "0 16px", }); export const ChatColumn = styled(Box)({ diff --git a/app/views/AssistantView/assistantView.tsx b/app/views/AssistantView/assistantView.tsx index 69074eb45..b6a8c86f7 100644 --- a/app/views/AssistantView/assistantView.tsx +++ b/app/views/AssistantView/assistantView.tsx @@ -1,5 +1,7 @@ import { Fragment, JSX } from "react"; import { useFeatureFlag } from "@databiosphere/findable-ui/lib/hooks/useFeatureFlag/useFeatureFlag"; +import { Box, Button } from "@mui/material"; +import RestartAltIcon from "@mui/icons-material/RestartAlt"; import Error from "next/error"; import { SectionHero } from "../../components/Layout/components/AppLayout/components/Section/components/SectionHero/sectionHero"; import { ChatPanel, SchemaPanel } from "../../components/Assistant"; @@ -9,6 +11,7 @@ import { AssistantSection, ChatColumn, SchemaColumn, + SectionContent, TwoPanelLayout, } from "./assistantView.styles"; @@ -17,8 +20,11 @@ export const AssistantView = (): JSX.Element => { const { error, handoffUrl, + isRestoring, loading, messages, + onRetry, + resetSession, schema, sendMessage, suggestions, @@ -26,6 +32,8 @@ export const AssistantView = (): JSX.Element => { if (!isAssistantEnabled) return ; + const showReset = messages.length > 0 || schema !== null; + return ( { subHead="Explore data and configure analyses with AI guidance" /> - - - - - - - - + + {showReset && ( + + + + )} + + + + + + + + + ); diff --git a/app/views/WorkflowInputsView/hooks/UseAssistantHandoff/types.ts b/app/views/WorkflowInputsView/hooks/UseAssistantHandoff/types.ts new file mode 100644 index 000000000..a4404968a --- /dev/null +++ b/app/views/WorkflowInputsView/hooks/UseAssistantHandoff/types.ts @@ -0,0 +1,8 @@ +export const ASSISTANT_HANDOFF_KEY = "brc-assistant-handoff"; + +export type HandoffDataSource = "ena" | "upload"; + +export interface AssistantHandoff { + dataSource: HandoffDataSource; + timestamp: number; +} diff --git a/app/views/WorkflowInputsView/hooks/UseAssistantHandoff/useAssistantHandoff.ts b/app/views/WorkflowInputsView/hooks/UseAssistantHandoff/useAssistantHandoff.ts new file mode 100644 index 000000000..7088e4496 --- /dev/null +++ b/app/views/WorkflowInputsView/hooks/UseAssistantHandoff/useAssistantHandoff.ts @@ -0,0 +1,27 @@ +import { useState } from "react"; +import { ASSISTANT_HANDOFF_KEY, AssistantHandoff } from "./types"; +const MAX_AGE_MS = 30 * 60 * 1000; // 30 minutes + +function readAndConsumeHandoff(): AssistantHandoff | null { + try { + const raw = localStorage.getItem(ASSISTANT_HANDOFF_KEY); + if (!raw) return null; + localStorage.removeItem(ASSISTANT_HANDOFF_KEY); + const parsed = JSON.parse(raw); + if ( + typeof parsed?.dataSource !== "string" || + typeof parsed?.timestamp !== "number" + ) { + return null; + } + if (Date.now() - parsed.timestamp > MAX_AGE_MS) return null; + return parsed as AssistantHandoff; + } catch { + return null; + } +} + +export const useAssistantHandoff = (): { handoff: AssistantHandoff | null } => { + const [handoff] = useState(readAndConsumeHandoff); + return { handoff }; +}; diff --git a/app/views/WorkflowInputsView/workflowInputsView.tsx b/app/views/WorkflowInputsView/workflowInputsView.tsx index 175c7ca56..1cb82b7fe 100644 --- a/app/views/WorkflowInputsView/workflowInputsView.tsx +++ b/app/views/WorkflowInputsView/workflowInputsView.tsx @@ -6,6 +6,8 @@ import { } from "@databiosphere/findable-ui/lib/components/Layout/components/BackPage/backPageView.styles"; import { JSX } from "react"; import { useStepper } from "../../components/Entity/components/ConfigureWorkflowInputs/components/Main/components/Stepper/hooks/UseStepper/hook"; +import { StepConfig } from "../../components/Entity/components/ConfigureWorkflowInputs/components/Main/components/Stepper/components/Step/types"; +import { VIEW } from "../../components/Entity/components/ConfigureWorkflowInputs/components/Main/components/Stepper/components/Step/SequencingStep/components/ToggleButtonGroup/types"; import { SEQUENCING_STEPS } from "../../components/Entity/components/ConfigureWorkflowInputs/components/Main/components/Stepper/steps/constants"; import { useConfiguredSteps } from "../../components/Entity/components/ConfigureWorkflowInputs/components/Main/components/Stepper/steps/hook"; import { augmentConfiguredSteps } from "../../components/Entity/components/ConfigureWorkflowInputs/components/Main/components/Stepper/steps/utils"; @@ -13,17 +15,41 @@ import { Main } from "../../components/Entity/components/ConfigureWorkflowInputs import { SideColumn } from "../../components/Entity/components/ConfigureWorkflowInputs/components/SideColumn/sideColumn"; import { Top } from "../../components/Entity/components/ConfigureWorkflowInputs/components/Top/top"; import { getAssembly, getWorkflow } from "../../services/workflows/entities"; +import { useAssistantHandoff } from "./hooks/UseAssistantHandoff/useAssistantHandoff"; import { useConfigureInputs } from "./hooks/UseConfigureInputs/useConfigureInputs"; import { Assembly, Props } from "./types"; import { StyledBackPageContentMainColumn } from "./workflowInputsView.styles"; +const DATA_STEP_KEYS = new Set([ + "readRunsPaired", + "readRunsSingle", + "readRunsAny", + "sampleSheet", +]); + +function findFirstDataStep(steps: StepConfig[]): number | undefined { + const idx = steps.findIndex((s) => DATA_STEP_KEYS.has(s.key)); + return idx >= 0 ? idx : undefined; +} + export const WorkflowInputsView = ({ entityId, trsId }: Props): JSX.Element => { const genome = getAssembly(entityId); const workflow = getWorkflow(trsId); + const { handoff } = useAssistantHandoff(); const { configuredInput, onConfigure } = useConfigureInputs(); const { configuredSteps } = useConfiguredSteps(workflow); - const { activeStep, onContinue, onEdit } = useStepper(configuredSteps); + + const handoffTargetStep = handoff + ? findFirstDataStep(configuredSteps) + : undefined; + const initialDataSourceView = + handoff?.dataSource === "upload" ? VIEW.UPLOAD_MY_DATA : undefined; + + const { activeStep, onContinue, onEdit } = useStepper( + configuredSteps, + handoffTargetStep + ); const { hasSidePanel } = configuredSteps[activeStep] || {}; return ( @@ -38,6 +64,7 @@ export const WorkflowInputsView = ({ entityId, trsId }: Props): JSX.Element => { configuredInput={configuredInput} configuredSteps={configuredSteps} genome={genome} + initialDataSourceView={initialDataSourceView} onConfigure={onConfigure} onContinue={onContinue} onEdit={onEdit} diff --git a/backend/api/app/api/v1/assistant.py b/backend/api/app/api/v1/assistant.py index 2af321f8f..5b68b5bb3 100644 --- a/backend/api/app/api/v1/assistant.py +++ b/backend/api/app/api/v1/assistant.py @@ -1,9 +1,9 @@ import logging -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends, HTTPException, Response from app.core.dependencies import check_rate_limit, get_assistant_agent -from app.models.assistant import ChatRequest, ChatResponse +from app.models.assistant import ChatRequest, ChatResponse, SessionRestoreResponse logger = logging.getLogger(__name__) @@ -35,3 +35,41 @@ async def assistant_chat( except Exception: logger.exception("Assistant chat error") raise HTTPException(status_code=500, detail="Internal assistant error") + + +@router.get("/session/{session_id}", response_model=SessionRestoreResponse) +async def restore_session( + session_id: str, + agent=Depends(get_assistant_agent), +): + """Restore a previous assistant session (messages, schema, suggestions).""" + try: + state = await agent.session_service.get_session(session_id) + except Exception: + logger.exception("Failed to restore session %s", session_id) + raise HTTPException(status_code=500, detail="Failed to restore session") + if state is None: + raise HTTPException(status_code=404, detail="Session not found or expired") + + is_complete, handoff_url = agent.compute_handoff(state.schema_state) + return SessionRestoreResponse( + session_id=state.session_id, + messages=state.messages, + schema_state=state.schema_state, + suggestions=state.suggestions, + is_complete=is_complete, + handoff_url=handoff_url, + ) + + +@router.delete("/session/{session_id}", status_code=204) +async def delete_session( + session_id: str, + agent=Depends(get_assistant_agent), +): + """Delete an assistant session.""" + try: + await agent.session_service.delete_session(session_id) + except Exception: + logger.exception("Failed to delete session %s", session_id) + return Response(status_code=204) diff --git a/backend/api/app/models/assistant.py b/backend/api/app/models/assistant.py index b153a5cba..b95cd32a7 100644 --- a/backend/api/app/models/assistant.py +++ b/backend/api/app/models/assistant.py @@ -122,7 +122,19 @@ class SessionState(BaseModel): session_id: str schema_state: AnalysisSchema = Field(default_factory=AnalysisSchema) messages: List[ChatMessage] = Field(default_factory=list) + suggestions: List[SuggestionChip] = Field(default_factory=list) # Raw pydantic-ai message history, serialised as JSON-safe dicts agent_message_history: List[Dict[str, Any]] = Field(default_factory=list) # Metadata accumulated during conversation (taxonomy_id, accession, etc.) metadata: Dict[str, Any] = Field(default_factory=dict) + + +class SessionRestoreResponse(BaseModel): + """Response from the session restore endpoint.""" + + session_id: str + messages: List[ChatMessage] + schema_state: AnalysisSchema + suggestions: List[SuggestionChip] = Field(default_factory=list) + is_complete: bool = False + handoff_url: Optional[str] = None diff --git a/backend/api/app/services/assistant_agent.py b/backend/api/app/services/assistant_agent.py index 109b374b9..ce3ea32ae 100644 --- a/backend/api/app/services/assistant_agent.py +++ b/backend/api/app/services/assistant_agent.py @@ -100,6 +100,9 @@ - Don't hallucinate data — if a tool returns no results, say so. - If the user asks about something outside bioinformatics/BRC Analytics, \ politely redirect. +- Each message includes the current analysis state in brackets. Use this to \ + know what's already filled and what still needs to be decided. Don't re-ask \ + about filled fields unless the user wants to change something. ## Schema updates @@ -111,13 +114,24 @@ Format: a JSON object on its own line prefixed with "SCHEMA_UPDATE:". \ Valid keys: organism, assembly, analysis_type, workflow, data_source, \ data_characteristics, gene_annotation. Each value is a string (the display \ -label). For assembly, include the accession. For workflow, include the IWC ID. +label) or null to clear a field. For assembly, include the accession. \ +For workflow, include the IWC ID. + +To clear a field (e.g., the user changed their mind), set its value to null. \ +When a user changes a high-level choice like organism or analysis_type, also \ +clear dependent downstream fields that may no longer be valid. The dependency \ +chain is: organism -> assembly -> analysis_type -> workflow -> \ +data_characteristics, gene_annotation. The data_source field is independent. Example — user said "I want to work with yeast RNA-seq": SCHEMA_UPDATE: {"organism": "Saccharomyces cerevisiae", "analysis_type": "Transcriptomics"} -Only emit this when the user has actually chosen something. If the conversation \ -is purely exploratory (listing options, answering questions), do NOT emit it. +Example — user switches from RNA-seq to variant calling: +SCHEMA_UPDATE: {"analysis_type": "Variant Calling", "workflow": null, "data_characteristics": null, "gene_annotation": null} + +Only emit this when the user has actually chosen or changed something. If the \ +conversation is purely exploratory (listing options, answering questions), do \ +NOT emit it. ## Suggestion chips @@ -243,6 +257,39 @@ def _build_model(self, model_name: str): def is_available(self) -> bool: return self.agent is not None + @staticmethod + def compute_handoff(schema_state: AnalysisSchema) -> tuple[bool, Optional[str]]: + """Compute is_complete and handoff_url from schema state.""" + handoff_url = None + if schema_state.is_complete(): + accession = schema_state.assembly.detail or "" + trs_id = schema_state.workflow.detail or "" + if accession and trs_id: + handoff_url = f"/data/assemblies/{accession}/{trs_id}" + return handoff_url is not None, handoff_url + + @staticmethod + def _build_context_prefix(schema: AnalysisSchema) -> str: + """Serialize current schema state so the LLM knows what's been decided.""" + parts = [] + for name in ( + "organism", + "assembly", + "analysis_type", + "workflow", + "data_source", + "data_characteristics", + "gene_annotation", + ): + field = getattr(schema, name) + if field.status == FieldStatus.FILLED: + parts.append(f"{name}={field.value} (filled)") + elif field.status == FieldStatus.NEEDS_ATTENTION: + parts.append(f"{name}={field.value} (needs attention)") + else: + parts.append(f"{name}=pending") + return f"[Analysis progress: {', '.join(parts)}]" + @staticmethod def _truncate_history(messages: list[ModelMessage]) -> list[ModelMessage]: """Keep history within MAX_HISTORY_MESSAGES. @@ -336,10 +383,14 @@ async def chat( ) agent_history = None + # Prepend current schema state so the LLM knows what's been decided + context_prefix = self._build_context_prefix(state.schema_state) + augmented_message = f"{context_prefix}\n\n{message}" + # Run the agent (with timeout + retry) deps = AssistantDeps(catalog=self.catalog) result = await self._run_agent_with_retry( - message, deps=deps, message_history=agent_history + augmented_message, deps=deps, message_history=agent_history ) raw_reply = result.output @@ -377,17 +428,11 @@ async def chat( # Persist updated state state.schema_state = schema_state + state.suggestions = suggestions state.agent_message_history = to_jsonable_python(result.all_messages()) await self.session_service.save_session(state) - # Only mark complete when we can actually build a handoff URL - handoff_url = None - if schema_state.is_complete(): - accession = schema_state.assembly.detail or "" - trs_id = schema_state.workflow.detail or "" - if accession and trs_id: - handoff_url = f"/data/assemblies/{accession}/{trs_id}" - is_complete = handoff_url is not None + is_complete, handoff_url = self.compute_handoff(schema_state) return ChatResponse( session_id=state.session_id, @@ -401,14 +446,14 @@ async def chat( def _parse_structured_output( self, raw_reply: str - ) -> tuple[str, List[SuggestionChip], Dict[str, str]]: + ) -> tuple[str, List[SuggestionChip], Dict[str, Optional[str]]]: """Extract SCHEMA_UPDATE and SUGGESTIONS lines from the reply. Handles common LLM formatting variations: bold markdown wrapping, mixed case, extra whitespace around the prefix. """ suggestions: List[SuggestionChip] = [] - schema_updates: Dict[str, str] = {} + schema_updates: Dict[str, Optional[str]] = {} reply_lines = [] # Match SUGGESTIONS or SCHEMA_UPDATE with optional markdown bold/italic. @@ -441,7 +486,7 @@ def _parse_structured_output( json_str = line[u_match.end() :].strip() data = json.loads(json_str) if isinstance(data, dict) and all( - isinstance(k, str) and isinstance(v, str) + isinstance(k, str) and (isinstance(v, str) or v is None) for k, v in data.items() ): schema_updates = data @@ -457,9 +502,13 @@ def _parse_structured_output( def _apply_schema_updates( self, current: AnalysisSchema, - updates: Dict[str, str], + updates: Dict[str, Optional[str]], ) -> AnalysisSchema: - """Apply LLM-emitted schema updates to the current schema.""" + """Apply LLM-emitted schema updates to the current schema. + + Values of None clear the field back to EMPTY (supports mid-conversation + corrections when the user changes their mind). + """ if not updates: return current @@ -475,18 +524,27 @@ def _apply_schema_updates( } for key, value in updates.items(): - if key not in valid_fields or not value: + if key not in valid_fields: + continue + if value is None: + setattr(schema, key, SchemaField()) + continue + if not value: continue field = SchemaField(value=str(value), status=FieldStatus.FILLED) - # For assembly, try to look up extra detail (accession) + # For assembly, try to extract accession from the value string, + # falling back to a catalog search by name if the regex misses. if key == "assembly": acc_match = re.search(r"(GC[AF]_\d{9}\.\d+)", str(value)) if acc_match: field.detail = acc_match.group(1) + else: + field.detail = self._find_assembly_accession(str(value), schema) - # For workflow, try to look up the trs_id + # For workflow, try to match iwcId in the value string, falling + # back to a case-insensitive name match against the catalog. if key == "workflow": workflow_value = str(value) for cat in self.catalog.workflows_by_category: @@ -497,7 +555,65 @@ def _apply_schema_updates( break if field.detail: break + if not field.detail: + field.detail = self._find_workflow_trs_id(workflow_value) setattr(schema, key, field) return schema + + def _find_assembly_accession( + self, value: str, schema: AnalysisSchema + ) -> Optional[str]: + """Fallback: search the catalog for an assembly matching the LLM value.""" + val_lower = value.lower() + # If we already know the organism, narrow the search + tax_id = None + if schema.organism.detail: + tax_id = schema.organism.detail + elif schema.organism.status == FieldStatus.FILLED and schema.organism.value: + for org in self.catalog.organisms: + species = (org.get("taxonomicLevelSpecies") or "").lower() + if species and species in schema.organism.value.lower(): + tax_id = str(org.get("ncbiTaxonomyId")) + break + + candidates: list[dict] = [] + for org in self.catalog.organisms: + org_tax = str(org.get("ncbiTaxonomyId", "")) + if tax_id and org_tax != tax_id: + continue + for g in org.get("genomes", []): + strain = (g.get("strainName") or "").strip().lower() + if strain and strain in val_lower: + candidates.append(g) + + # If no strain match, don't guess from species alone -- too ambiguous + # when an organism has many assemblies. + if not candidates: + logger.warning("Assembly fallback found no match for '%s'", value) + return None + + # Prefer reference assemblies when multiple candidates match + for c in candidates: + if c.get("isRef") == "Yes": + logger.info( + "Assembly fallback matched reference '%s'", c.get("accession") + ) + return c.get("accession") + + accession = candidates[0].get("accession") + logger.info("Assembly fallback matched '%s'", accession) + return accession + + def _find_workflow_trs_id(self, value: str) -> Optional[str]: + """Fallback: match a workflow by name when iwcId isn't in the value.""" + val_lower = value.lower() + for cat in self.catalog.workflows_by_category: + for wf in cat.get("workflows", []): + wf_name = (wf.get("workflowName") or "").lower() + if wf_name and (wf_name in val_lower or val_lower in wf_name): + logger.info("Workflow fallback matched name '%s'", wf_name) + return wf.get("trsId", wf.get("iwcId")) + logger.warning("Workflow fallback found no match for '%s'", value) + return None