diff --git a/.env.sample b/.env.sample index d75f47a..13f5555 100644 --- a/.env.sample +++ b/.env.sample @@ -4,3 +4,4 @@ PORT=3001 # Upstash Context7 API Key for documentation lookup # Required only if you want context7 plugin to query documentation libraries CONTEXT7_API_KEY=your-upstash-context7-api-key-here +GEMINI_API_KEY= \ No newline at end of file diff --git a/apps/api/mcp/plugins/context7/manifest.ts b/apps/api/mcp/plugins/context7/manifest.ts index 0c38ec6..3807966 100644 --- a/apps/api/mcp/plugins/context7/manifest.ts +++ b/apps/api/mcp/plugins/context7/manifest.ts @@ -1,3 +1,4 @@ +import "../../../../src/utils/env.js"; import { StdioMCPServer } from "../../stdio-server.js"; const context7Env: Record = {}; diff --git a/apps/api/package.json b/apps/api/package.json index f2d6b48..e05c4d5 100644 --- a/apps/api/package.json +++ b/apps/api/package.json @@ -9,7 +9,8 @@ "start": "node dist/index.js", "lint": "eslint . --max-warnings 0", "check-types": "tsc --noEmit", - "test": "vitest run" + "test": "vitest run", + "cli": "tsx src/agent/cli.ts" }, "dependencies": { "@modelcontextprotocol/sdk": "^1.4.0", diff --git a/apps/api/src/agent/agent.test.ts b/apps/api/src/agent/agent.test.ts new file mode 100644 index 0000000..823bfe1 --- /dev/null +++ b/apps/api/src/agent/agent.test.ts @@ -0,0 +1,473 @@ +import { vi, describe, it, expect, beforeEach, afterEach } from "vitest"; +import { runAgent } from "./loop.js"; +import { createMemory } from "./memory.js"; +import { llmClient } from "./llm.js"; + +// Mock @repo/db +vi.mock("@repo/db", () => { + const ApprovalStatus = { + PENDING: "PENDING", + APPROVED: "APPROVED", + REJECTED: "REJECTED", + }; + return { + ApprovalStatus, + db: { + approval: { + findUnique: vi.fn(), + }, + conversation: { + findUnique: vi.fn(), + update: vi.fn(), + upsert: vi.fn(), + }, + }, + }; +}); + +// Import mocked db +import { db } from "@repo/db"; + +// Mock decision engine +vi.mock("../policy/decision.js", () => { + return { + decide: 
/**
 * Behavioural tests for the agent loop (`runAgent`) and the Gemini client.
 *
 * `@repo/db`, the policy `decide()` function and the MCP bootstrap module are
 * mocked at module level, so every scenario only scripts the LLM output and
 * the policy decision and then asserts on the result returned by `runAgent`.
 */
describe("Agent Module & Execution Loop", () => {
  const mockTool = {
    name: "test_tool",
    description: "A test tool description",
    inputSchema: {
      type: "object",
      properties: {
        arg1: { type: "string" },
      },
      required: ["arg1"],
    },
    execute: vi.fn(),
  };

  beforeEach(() => {
    vi.clearAllMocks();

    vi.mocked(db.conversation.findUnique).mockResolvedValue({
      id: "conv-1",
      tokens_used: 0,
      budget_limit: 1000,
      budget_reset_at: new Date(),
      createdAt: new Date(),
    } as any);
    vi.mocked(db.conversation.update).mockResolvedValue({} as any);
    vi.mocked(db.conversation.upsert).mockResolvedValue({} as any);

    // Default discovery stub returning the test tool
    const mockToolsMap = new Map();
    mockToolsMap.set("test_tool", {
      server: { name: "test_server" },
      tool: mockTool,
    });
    vi.mocked(mcpDiscovery.discoverTools).mockResolvedValue(mockToolsMap);
  });

  afterEach(() => {
    vi.restoreAllMocks();
  });

  // 1) tool call - LLM requests a tool call
  it("scenario 1: tool call gets evaluated and mapped properly in the loop", async () => {
    vi.spyOn(llmClient, "callModel").mockResolvedValue(
      JSON.stringify({
        type: "tool_call",
        tool_name: "test_tool",
        arguments: { arg1: "hello" },
      })
    );

    vi.mocked(decide).mockResolvedValue({
      decision: "PENDING",
      reason: "approval-uuid-1",
    });

    const result = await runAgent("Perform task", "conv-1");
    expect(result.status).toBe("PENDING");
    expect(result.approvalId).toBe("approval-uuid-1");
    expect(decide).toHaveBeenCalledWith(
      expect.objectContaining({
        tool_name: "test_tool",
        arguments: { arg1: "hello" },
      }),
      { conversationId: "conv-1", token: expect.any(Number) }
    );
  });

  // 2) final answer - LLM returns a final answer
  it("scenario 2: final answer stops execution and returns success", async () => {
    vi.spyOn(llmClient, "callModel").mockResolvedValue(
      JSON.stringify({
        type: "final_answer",
        answer: "Task completed successfully.",
      })
    );

    const result = await runAgent("Perform task", "conv-1");
    expect(result.status).toBe("SUCCESS");
    expect(result.answer).toBe("Task completed successfully.");
    expect(result.memory.messages).toContainEqual({
      role: "assistant",
      content: "Task completed successfully.",
    });
  });

  // 3) approval pending - tool call requires approval, decide() returns PENDING
  it("scenario 3: decision PENDING saves approvalId and returns PENDING status", async () => {
    vi.spyOn(llmClient, "callModel").mockResolvedValue(
      JSON.stringify({
        type: "tool_call",
        tool_name: "test_tool",
        arguments: { arg1: "value" },
      })
    );

    vi.mocked(decide).mockResolvedValue({
      decision: "PENDING",
      reason: "pending-approval-id",
    });

    const result = await runAgent("Start workflow", "conv-2");
    expect(result.status).toBe("PENDING");
    expect(result.approvalId).toBe("pending-approval-id");
    expect(result.memory.approvalId).toBe("pending-approval-id");
  });

  // 4) denied tool - tool call is denied, decide() returns DENY
  it("scenario 4: decision DENY stops execution and returns DENY status", async () => {
    vi.spyOn(llmClient, "callModel").mockResolvedValue(
      JSON.stringify({
        type: "tool_call",
        tool_name: "test_tool",
        arguments: { arg1: "forbidden" },
      })
    );

    vi.mocked(decide).mockResolvedValue({
      decision: "DENY",
      reason: "Tool execution blocked by policy",
    });

    const result = await runAgent("Run forbidden action", "conv-3");
    expect(result.status).toBe("DENY");
    expect(result.reason).toBe("Tool execution blocked by policy");
  });

  // 5) successful execution - tool call is allowed and executes successfully
  it("scenario 5: allowed tool call executes successfully, records result, and requests next step", async () => {
    // 1st call: request tool
    // 2nd call: return final answer
    let callCount = 0;
    vi.spyOn(llmClient, "callModel").mockImplementation(async () => {
      callCount++;
      if (callCount === 1) {
        return JSON.stringify({
          type: "tool_call",
          tool_name: "test_tool",
          arguments: { arg1: "valid-input" },
        });
      }
      return JSON.stringify({
        type: "final_answer",
        answer: "Execution completed successfully.",
      });
    });

    vi.mocked(decide).mockResolvedValue({
      decision: "ALLOW",
    });

    vi.mocked(mcpExecutor.execute).mockResolvedValue("Success output");

    const result = await runAgent("Run action", "conv-4");
    expect(result.status).toBe("SUCCESS");
    expect(result.answer).toBe("Execution completed successfully.");
    expect(mcpExecutor.execute).toHaveBeenCalledWith(
      "test_tool",
      { arg1: "valid-input" },
      { conversationId: "conv-4", decision: "ALLOW" }
    );
    expect(result.memory.toolResults).toContain("Success output");
  });

  // 6) invalid tool arguments - well-formed JSON whose arguments violate the tool's input schema
  it("scenario 6: invalid argument type fails schema validation and throws error", async () => {
    vi.spyOn(llmClient, "callModel").mockResolvedValue(
      JSON.stringify({
        type: "tool_call",
        tool_name: "test_tool",
        arguments: { arg1: 12345 }, // arg1 must be string
      })
    );

    await expect(runAgent("Run action", "conv-5")).rejects.toThrow(
      "Invalid arguments for tool test_tool"
    );
  });

  // 6b) unknown tool - the LLM names a tool that discovery did not return
  it("scenario 6b: unknown tool rejection", async () => {
    vi.spyOn(llmClient, "callModel").mockResolvedValue(
      JSON.stringify({
        type: "tool_call",
        tool_name: "unknown_tool",
        arguments: {},
      })
    );

    await expect(runAgent("Run action", "conv-5")).rejects.toThrow(
      "Unknown tool: unknown_tool"
    );
  });

  // 7) executor throws - MCP executor throws an error
  it("scenario 7: executor exception throws an error and fails closed", async () => {
    vi.spyOn(llmClient, "callModel").mockResolvedValue(
      JSON.stringify({
        type: "tool_call",
        tool_name: "test_tool",
        arguments: { arg1: "trigger-fail" },
      })
    );

    vi.mocked(decide).mockResolvedValue({
      decision: "ALLOW",
    });

    vi.mocked(mcpExecutor.execute).mockRejectedValue(new Error("Executor crash"));

    await expect(runAgent("Fail task", "conv-6")).rejects.toThrow(
      "Tool execution failed: Executor crash"
    );
  });

  // 8) malformed json - LLM output is not valid JSON
  it("scenario 8: malformed json from LLM throws error", async () => {
    vi.spyOn(llmClient, "callModel").mockResolvedValue("not-json-format");

    await expect(runAgent("Fail task", "conv-7")).rejects.toThrow(
      "Malformed JSON from LLM response"
    );
  });

  // 9) approval resumes execution - agent is resumed with an approvalId and continues
  it("scenario 9: agent loop resumes from approval ID, skips nextStep for the first call, and proceeds", async () => {
    // Mock db.approval.findUnique to return the original tool call parameters
    vi.mocked(db.approval.findUnique).mockResolvedValue({
      id: "approval-999",
      tool_name: "test_tool",
      arguments: { arg1: "resumed-val" },
      status: "APPROVED" as any,
      createdAt: new Date(),
      updatedAt: new Date(),
    });

    // decision of ALLOW when decisionContext includes the approved approvalId
    vi.mocked(decide).mockResolvedValue({
      decision: "ALLOW",
    });

    vi.mocked(mcpExecutor.execute).mockResolvedValue("Resumed execution success");

    // The model is only called once after the executor finishes to retrieve the final answer
    vi.spyOn(llmClient, "callModel").mockResolvedValue(
      JSON.stringify({
        type: "final_answer",
        answer: "Completed resumed action.",
      })
    );

    const memory = createMemory();
    memory.addMessage("user", "Run step 1");
    // Resume agent with the approval ID
    const result = await runAgent(null, "conv-8", {
      memory,
      approvalId: "approval-999",
    });

    expect(result.status).toBe("SUCCESS");
    expect(result.answer).toBe("Completed resumed action.");
    expect(db.approval.findUnique).toHaveBeenCalledWith({
      where: { id: "approval-999" },
    });
    expect(decide).toHaveBeenCalledWith(
      expect.objectContaining({
        tool_name: "test_tool",
        arguments: { arg1: "resumed-val" },
        approvalId: "approval-999",
      }),
      { conversationId: "conv-8", token: 0 }
    );
    expect(mcpExecutor.execute).toHaveBeenCalledWith(
      "test_tool",
      { arg1: "resumed-val" },
      { conversationId: "conv-8", decision: "ALLOW" }
    );
    expect(result.memory.toolResults).toContain("Resumed execution success");
  });

  // 10) iteration limit - LLM repeats tool calls excessively
  it("scenario 10: agent loop terminates and throws error if iteration limit is exceeded", async () => {
    // Return a tool call every time so the agent loops continuously
    vi.spyOn(llmClient, "callModel").mockResolvedValue(
      JSON.stringify({
        type: "tool_call",
        tool_name: "test_tool",
        arguments: { arg1: "looping" },
      })
    );

    vi.mocked(decide).mockResolvedValue({
      decision: "ALLOW",
    });

    vi.mocked(mcpExecutor.execute).mockResolvedValue("Executed ok");

    await expect(runAgent("Loop forever", "conv-9")).rejects.toThrow(
      "Agent loop iteration limit exceeded"
    );
  });

  // 11) budget reset logic - automatically resets if 3 minutes have passed since budget_reset_at
  it("scenario 11: agent loop resets budget when elapsed time since budget_reset_at is > 3 minutes", async () => {
    vi.spyOn(llmClient, "callModel").mockResolvedValue(
      JSON.stringify({
        type: "final_answer",
        answer: "Finished",
      })
    );

    // Mock upsert to return a conversation that was reset 4 minutes ago
    const expiredResetAt = new Date(Date.now() - 4 * 60 * 1000);
    vi.mocked(db.conversation.upsert).mockResolvedValue({
      id: "conv-10",
      tokens_used: 15000,
      budget_limit: 20000,
      budget_reset_at: expiredResetAt,
      createdAt: new Date(Date.now() - 10 * 60 * 1000),
    } as any);

    await runAgent("Reset check", "conv-10");

    expect(db.conversation.update).toHaveBeenCalledWith(
      expect.objectContaining({
        where: { id: "conv-10" },
        data: expect.objectContaining({
          tokens_used: 0,
          budget_reset_at: expect.any(Date),
        }),
      })
    );
  });

  // 12) resume pending - approval status is PENDING
  it("scenario 12: agent loop returns PENDING if resumed approval is in PENDING status", async () => {
    // Mock db.approval.findUnique to return a PENDING record
    vi.mocked(db.approval.findUnique).mockResolvedValue({
      id: "approval-998",
      tool_name: "test_tool",
      arguments: { arg1: "resumed-val" },
      status: "PENDING" as any,
      createdAt: new Date(),
      updatedAt: new Date(),
    });

    const memory = createMemory();
    const result = await runAgent(null, "conv-11", {
      memory,
      approvalId: "approval-998",
    });

    expect(result.status).toBe("PENDING");
    expect(result.approvalId).toBe("approval-998");
  });

  // 13) resume rejected - approval status is REJECTED
  it("scenario 13: agent loop denies execution if resumed approval is in REJECTED status", async () => {
    // Mock db.approval.findUnique to return a REJECTED record
    vi.mocked(db.approval.findUnique).mockResolvedValue({
      id: "approval-997",
      tool_name: "test_tool",
      arguments: { arg1: "resumed-val" },
      status: "REJECTED" as any,
      createdAt: new Date(),
      updatedAt: new Date(),
    });

    const memory = createMemory();
    const result = await runAgent(null, "conv-12", {
      memory,
      approvalId: "approval-997",
    });

    expect(result.status).toBe("DENY");
    expect(result.reason).toBe("Approval not approved");
  });

  describe("Gemini API Client Timeout", () => {
    // Environment variables these tests overwrite. They are snapshotted before
    // each test and put back afterwards so a dummy API key or a short timeout
    // never leaks into other tests running in the same worker.
    const managedEnvKeys = ["GEMINI_API_KEY", "GEMINI_TIMEOUT_MS"] as const;
    const envSnapshot = new Map<string, string | undefined>();

    beforeEach(() => {
      for (const key of managedEnvKeys) {
        envSnapshot.set(key, process.env[key]);
      }
    });

    afterEach(() => {
      vi.unstubAllGlobals();
      for (const [key, value] of envSnapshot) {
        if (value === undefined) {
          delete process.env[key];
        } else {
          process.env[key] = value;
        }
      }
    });

    it("should abort the fetch request if it exceeds the timeout limit", async () => {
      // The stub never resolves on its own: it only settles (by rejecting) once
      // the caller's AbortSignal fires, mimicking a hung network request.
      vi.stubGlobal("fetch", async (_url: string, init?: RequestInit) => {
        expect(init?.signal).toBeDefined();
        const signal = init?.signal;
        return new Promise((_resolve, reject) => {
          const rejectAsAborted = () => {
            reject(new DOMException("The operation was aborted.", "AbortError"));
          };
          // Cover the case where the signal fired before the listener is attached.
          if (signal?.aborted) {
            rejectAsAborted();
            return;
          }
          signal?.addEventListener("abort", rejectAsAborted, { once: true });
        });
      });

      process.env.GEMINI_API_KEY = "dummy-key";
      process.env.GEMINI_TIMEOUT_MS = "50"; // 50ms timeout for test speed

      await expect(llmClient.callModel("Hello")).rejects.toThrow("The operation was aborted");
    });

    it("should fall back to default timeout and execute normally if env timeout is invalid", async () => {
      vi.stubGlobal("fetch", async (_url: string, init?: RequestInit) => {
        expect(init?.signal).toBeDefined();
        return {
          ok: true,
          json: async () => ({
            candidates: [{
              content: { parts: [{ text: "response-ok" }] }
            }]
          })
        } as any;
      });

      process.env.GEMINI_API_KEY = "dummy-key";
      process.env.GEMINI_TIMEOUT_MS = "invalid-value"; // invalid non-number value

      const res = await llmClient.callModel("Hello");
      expect(res).toBe("response-ok");
    });
  });
});
/** One entry of the conversation history exchanged with the agent API. */
interface ChatMessage {
  role: string;
  content: string;
}

/** Request body accepted by `POST /agent/run`. */
interface AgentRunPayload {
  message: string | null;
  conversationId: string;
  approvalId?: string;
  history: ChatMessage[];
}

/** Fields of the `POST /agent/run` response that this client reads. */
interface AgentRunResponse {
  status?: string;
  answer?: string;
  approvalId?: string;
  reason?: string;
  history?: ChatMessage[];
}

const rl = readline.createInterface({
  input: process.stdin,
  output: process.stdout,
});

const API_URL = "http://localhost:3001";
// Conversation transcript as last returned by the server; sent back on every request.
let history: ChatMessage[] = [];
const conversationId = crypto.randomUUID();

/** Prompts on stdout and resolves with the line the user typed. */
function askQuestion(query: string): Promise<string> {
  return new Promise((resolve) => rl.question(query, resolve));
}

/** True for assistant messages that record a tool invocation. */
function isToolCallMessage(msg: ChatMessage): boolean {
  return msg.role === "assistant" && msg.content.includes("Call tool");
}

/**
 * Sends one request to the agent endpoint and renders the outcome.
 *
 * - PENDING: asks the user to approve or reject, posts that decision to the
 *   approvals endpoint, then resumes the run with the same approval id.
 * - DENY: prints the reason.
 * - SUCCESS: prints the tool trace (if any) followed by the final answer.
 *
 * Network and HTTP errors are reported on stderr and never thrown, so the
 * interactive loop in `main` keeps running.
 */
async function sendRequest(payload: AgentRunPayload): Promise<void> {
  try {
    const response = await fetch(`${API_URL}/agent/run`, {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify(payload),
    });

    if (!response.ok) {
      const errorText = await response.text();
      console.error(`\n❌ API Error (${response.status}):`, errorText);
      return;
    }

    const data = (await response.json()) as AgentRunResponse;
    history = data.history || [];

    if (data.status === "PENDING") {
      if (!data.approvalId) {
        console.error("\n❌ Error: PENDING status response is missing approvalId");
        return;
      }
      // The last message is the assistant's description of the tool call awaiting approval
      const lastMsg = history[history.length - 1];
      console.log(`\n⚠️ [PENDING APPROVAL] ${lastMsg ? lastMsg.content : "A tool execution requires human approval."}`);

      const answer = (await askQuestion(`👉 Approve this action? (y/n): `)).trim().toLowerCase();
      const approved = answer === "y" || answer === "yes";

      const action = approved ? "approve" : "reject";
      const approvalResponse = await fetch(`${API_URL}/policies/approvals/${data.approvalId}/${action}`, {
        method: "POST"
      });

      if (!approvalResponse.ok) {
        console.error(`\n❌ Failed to ${action} approval:`, await approvalResponse.text());
        return;
      }

      console.log(`\n✅ Action ${approved ? "approved" : "rejected"}. Resuming agent loop...`);

      // Resume agent execution; the server re-reads the approval status itself
      await sendRequest({
        message: null,
        conversationId,
        approvalId: data.approvalId,
        history,
      });
    } else if (data.status === "DENY") {
      console.log(`\n🚫 [DENIED] Execution blocked: ${data.reason || "Blocked by policy"}`);
    } else if (data.status === "SUCCESS") {
      // Print execution log of tools used during the run
      if (history.some(isToolCallMessage)) {
        console.log("\n🛠️ [TOOL EXECUTION TRACE]");
        history.forEach((msg, index) => {
          if (!isToolCallMessage(msg)) return;
          console.log(` • ${msg.content}`);
          // A tool result, when present, directly follows the call that produced it
          const nextMsg = history[index + 1];
          if (nextMsg && nextMsg.role === "tool") {
            console.log(` ↳ Output: ${nextMsg.content}`);
          }
        });
      }
      console.log(`\n🤖 [AGENT] ${data.answer}`);
    }
  } catch (error: unknown) {
    const detail = error instanceof Error && error.message ? error.message : error;
    console.error("\n❌ Error communicating with agent:", detail);
  }
}

/** Interactive read-eval loop: reads a user line, forwards it to the agent, repeats. */
async function main(): Promise<void> {
  console.clear();
  console.log("==================================================");
  console.log("🤖 Interactive Gate-Keeper Agent CLI Client");
  console.log(`📡 Backend URL: ${API_URL}`);
  console.log(`💬 Session ID: ${conversationId}`);
  console.log("==================================================\n");

  while (true) {
    const input = await askQuestion("\n👤 User: ");
    if (!input || !input.trim()) continue;

    const trimmed = input.trim();
    if (trimmed === "exit" || trimmed === "quit") {
      console.log("Goodbye!");
      rl.close();
      process.exit(0);
    }

    await sendRequest({
      message: trimmed,
      conversationId,
      history,
    });
  }
}

// Surface unexpected failures instead of leaving a floating promise behind.
main().catch((error: unknown) => {
  console.error("Fatal CLI error:", error);
  rl.close();
  process.exit(1);
});
const GEMINI_ENDPOINT =
  "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent";
/** Timeout used when GEMINI_TIMEOUT_MS is unset or invalid. */
const DEFAULT_GEMINI_TIMEOUT_MS = 30000;
/** Largest timeout accepted from GEMINI_TIMEOUT_MS. */
const MAX_GEMINI_TIMEOUT_MS = 300000;

/** The subset of a Gemini content part that this client reads. */
interface GeminiPart {
  text?: unknown;
  thought?: unknown;
}

/** The subset of the generateContent response that this client reads. */
interface GeminiResponse {
  candidates?: Array<{ content?: { parts?: unknown } }>;
}

/**
 * Parses GEMINI_TIMEOUT_MS. Only a plain decimal integer between 1 and
 * MAX_GEMINI_TIMEOUT_MS is accepted; anything else (unset, "abc", "50abc",
 * "1.5", out of range) yields the default.
 */
function resolveGeminiTimeoutMs(raw: string | undefined): number {
  if (raw === undefined) return DEFAULT_GEMINI_TIMEOUT_MS;
  const trimmed = raw.trim();
  // parseInt would silently accept trailing garbage such as "50abc"
  if (!/^\d+$/.test(trimmed)) return DEFAULT_GEMINI_TIMEOUT_MS;
  const parsed = Number(trimmed);
  return parsed >= 1 && parsed <= MAX_GEMINI_TIMEOUT_MS ? parsed : DEFAULT_GEMINI_TIMEOUT_MS;
}

/** Concatenates the text of every non-thought part of the first candidate. */
function extractGeminiText(payload: GeminiResponse | null | undefined): string {
  const parts = payload?.candidates?.[0]?.content?.parts;
  if (!Array.isArray(parts)) return "";
  let text = "";
  for (const part of parts as GeminiPart[]) {
    if (part && typeof part.text === "string" && part.thought !== true) {
      text += part.text;
    }
  }
  return text;
}

export const llmClient = {
  /**
   * Sends `prompt` to Gemini (JSON response mode) and returns the raw text of
   * the first candidate.
   *
   * The request is aborted after GEMINI_TIMEOUT_MS milliseconds (default
   * 30000, accepted range 1..300000).
   *
   * @throws Error when GEMINI_API_KEY is not set, when the API responds with a
   *   non-2xx status, or when the response carries no text. Rejections from
   *   `fetch` itself (including the timeout abort) propagate unchanged.
   */
  async callModel(prompt: string): Promise<string> {
    const apiKey = process.env.GEMINI_API_KEY;
    if (!apiKey) {
      throw new Error("GEMINI_API_KEY environment variable is not defined");
    }

    const timeoutMs = resolveGeminiTimeoutMs(process.env.GEMINI_TIMEOUT_MS);

    const response = await fetch(GEMINI_ENDPOINT, {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        "x-goog-api-key": apiKey,
      },
      body: JSON.stringify({
        contents: [{
          parts: [{ text: prompt }]
        }],
        generationConfig: {
          responseMimeType: "application/json"
        }
      }),
      signal: AbortSignal.timeout(timeoutMs)
    });

    if (!response.ok) {
      const errorText = await response.text();
      throw new Error(`Gemini API request failed with status ${response.status}: ${errorText}`);
    }

    // A candidate may be split across several parts; reading only parts[0]
    // would hand a truncated JSON document to the caller.
    const text = extractGeminiText((await response.json()) as GeminiResponse);
    if (!text) {
      throw new Error("Invalid response received from Gemini API");
    }

    return text;
  }
};
/** Own-property check that ignores anything inherited from the prototype chain. */
function hasOwnKey(target: object, key: PropertyKey): boolean {
  return Object.prototype.hasOwnProperty.call(target, key);
}

/** Validates `data` against a schema whose `type` is "object". */
function validateObjectSchema(schema: any, data: any): boolean {
  if (typeof data !== "object" || data === null || Array.isArray(data)) {
    return false;
  }

  if (Array.isArray(schema.required)) {
    for (const requiredKey of schema.required) {
      // `in` would also match inherited names such as "toString"
      if (!hasOwnKey(data, requiredKey)) {
        return false;
      }
    }
  }

  const declared =
    schema.properties && typeof schema.properties === "object" ? schema.properties : undefined;

  for (const key of Object.keys(data)) {
    // Only own entries of `properties` count as declared, so keys such as
    // "constructor" cannot pick up a value from Object.prototype.
    const propSchema = declared !== undefined && hasOwnKey(declared, key) ? declared[key] : undefined;
    if (propSchema) {
      if (!validateSchema(propSchema, data[key])) {
        return false;
      }
    } else if (schema.additionalProperties === false) {
      // Enforced whether or not the schema declares any properties
      return false;
    }
  }
  return true;
}

/**
 * Minimal JSON-Schema-style validator for tool arguments.
 *
 * Supported keywords: `type` ("object", "string", "number", "integer",
 * "boolean", "array", "null", or an array of those meaning "any of"),
 * `required`, `properties`, `additionalProperties: false` and `items`.
 * Every other keyword is ignored.
 *
 * @param schema Schema to validate against. A missing/falsy schema, or one
 *   with an unrecognised `type`, accepts everything.
 * @param data Value to validate.
 * @returns `true` when `data` satisfies the supported constraints.
 */
export function validateSchema(schema: any, data: any): boolean {
  if (!schema) return true;

  const declaredType = schema.type;

  // Union types, e.g. ["string", "null"]: valid when any member matches.
  if (Array.isArray(declaredType)) {
    return (
      declaredType.length === 0 ||
      declaredType.some((member: unknown) => validateSchema({ ...schema, type: member }, data))
    );
  }

  switch (declaredType) {
    case "object":
      return validateObjectSchema(schema, data);
    case "string":
      return typeof data === "string";
    case "number":
      return typeof data === "number" && !Number.isNaN(data);
    case "integer":
      return typeof data === "number" && Number.isInteger(data);
    case "boolean":
      return typeof data === "boolean";
    case "null":
      return data === null;
    case "array": {
      if (!Array.isArray(data)) return false;
      if (schema.items) {
        for (const item of data) {
          if (!validateSchema(schema.items, item)) {
            return false;
          }
        }
      }
      return true;
    }
    default:
      // Unknown or absent type: no constraint this validator can check
      return true;
  }
}
/**
 * Asks the LLM for the agent's next step and validates the reply.
 *
 * The prompt lists every available tool (name, description, input schema) and
 * the conversation so far, and instructs the model to answer with a single
 * JSON object describing either a tool call or a final answer.
 *
 * @param memory Conversation state; only `memory.messages` is read.
 * @param tools Tools the model may call.
 * @returns The validated step plus an estimated token count for the exchange
 *   (prompt and response lengths, each divided by four and rounded up).
 * @throws Error when the reply is not JSON, has an unexpected shape, names an
 *   unknown tool, or carries arguments that fail the tool's input schema.
 */
export async function nextStep(memory: Memory, tools: Tool[]): Promise<{ step: AgentStep; tokens: number }> {
  const toolLines: string[] = [];
  for (const tool of tools) {
    toolLines.push(
      `- Name: ${tool.name}\n Description: ${tool.description}\n Schema: ${JSON.stringify(tool.inputSchema)}`
    );
  }
  const toolCatalogue = toolLines.join("\n\n");

  const transcriptLines: string[] = [];
  for (const entry of memory.messages) {
    transcriptLines.push(`${entry.role.toUpperCase()}: ${entry.content}`);
  }
  const transcript = transcriptLines.join("\n");

  const prompt = `
You are an agent with access to the following tools:
${toolCatalogue}

Conversation history:
${transcript}

Output your next step as a single JSON object. Do not include any other text, markdown formatting, or code blocks.
If you need to call a tool, output:
{
  "type": "tool_call",
  "tool_name": "name_of_tool",
  "arguments": { ... }
}

If you are done and have a final answer, output:
{
  "type": "final_answer",
  "answer": "your final response"
}
`;

  const reply = await llmClient.callModel(prompt);

  // Rough usage estimate: about four characters per token, counted separately for each side
  const tokens = Math.ceil(prompt.length / 4) + Math.ceil(reply.length / 4);

  let candidate: any;
  try {
    candidate = JSON.parse(reply.trim());
  } catch (err) {
    throw new Error("Malformed JSON from LLM response");
  }

  if (!candidate || typeof candidate !== "object") {
    throw new Error("Invalid LLM output structure");
  }

  switch (candidate.type) {
    case "tool_call": {
      const tool_name = candidate.tool_name;
      const args = candidate.arguments;
      const argsArePlainObject = Boolean(args) && typeof args === "object" && !Array.isArray(args);
      if (typeof tool_name !== "string" || !argsArePlainObject) {
        throw new Error("Invalid LLM output structure for tool call");
      }

      const matchingTool = tools.find((tool) => tool.name === tool_name);
      if (!matchingTool) {
        throw new Error(`Unknown tool: ${tool_name}`);
      }

      if (!validateSchema(matchingTool.inputSchema, args)) {
        throw new Error(`Invalid arguments for tool ${tool_name}`);
      }

      return {
        step: { type: "tool_call", tool_name, arguments: args },
        tokens
      };
    }

    case "final_answer": {
      const answer = candidate.answer;
      if (typeof answer !== "string") {
        throw new Error("Invalid LLM output structure for final answer");
      }

      return {
        step: { type: "final_answer", answer },
        tokens
      };
    }

    default:
      throw new Error("Invalid LLM output structure: missing or invalid type");
  }
}
/** Upper bound on loop iterations (LLM consultations or resumed steps) per run. */
const MAX_AGENT_ITERATIONS = 30;
/** Token budget assigned to a conversation created on first use. */
const DEFAULT_BUDGET_LIMIT = 20000;
/** Length of the budget window; usage is reset once it has elapsed. */
const BUDGET_WINDOW_MS = 3 * 60 * 1000;

/**
 * Runs the agent loop for one conversation.
 *
 * Each iteration obtains a step, either the tool call stored on a previously
 * created approval (resume) or the LLM's next step, and then:
 * - a final answer ends the run with status "SUCCESS";
 * - a tool call is submitted to `decide()`: DENY and PENDING end the run with
 *   that status, ALLOW executes the tool through the MCP executor, records the
 *   result in memory and continues with the next iteration.
 *
 * Estimated token usage is accumulated in memory and written to the
 * conversation row whenever the run ends, including on failure.
 *
 * @param userMessage New user input, or null when resuming after an approval.
 * @param conversationId Conversation row id; the row is created if missing.
 * @param options.memory Existing conversation memory to continue from.
 * @param options.approvalId Approval to resume; its stored tool call is used
 *   as the first step instead of consulting the LLM.
 * @returns The terminal status together with the memory it ended with.
 * @throws Error when the iteration limit is exceeded, the LLM output is
 *   invalid, the policy decision is not recognised, or tool execution fails.
 */
export async function runAgent(
  userMessage: string | null,
  conversationId: string,
  options?: {
    memory?: Memory;
    approvalId?: string;
  }
): Promise<AgentResult> {
  // Ensure conversation exists in DB before execution to prevent record missing errors during manual testing
  let conversation = await db.conversation.upsert({
    where: { id: conversationId },
    update: {},
    create: {
      id: conversationId,
      budget_limit: DEFAULT_BUDGET_LIMIT,
      tokens_used: 0,
    },
  });

  // Automatically reset the token budget once the budget window has passed since budget_reset_at
  const elapsed = Date.now() - new Date(conversation.budget_reset_at).getTime();
  if (elapsed > BUDGET_WINDOW_MS) {
    logger.info("Resetting conversation budget limit (3-minute window expired)", { conversation_id: conversationId });
    conversation = await db.conversation.update({
      where: { id: conversationId },
      data: {
        tokens_used: 0,
        budget_reset_at: new Date(),
      },
    });
  }

  logger.info("Agent run started", { conversation_id: conversationId, is_resume: !!options?.approvalId });

  const memory = options?.memory || createMemory();
  if (userMessage) {
    memory.addMessage("user", userMessage);
  }

  let activeApprovalId = options?.approvalId;
  let accumulatedTokens = 0;

  // Retrieve tools list from MCP discovery map
  const toolsMap = await mcpDiscovery.discoverTools();
  const tools: Tool[] = Array.from(toolsMap.values()).map(vt => vt.tool);

  // Helper to persist accumulated tokens to the database
  const updateTokens = async () => {
    if (accumulatedTokens > 0) {
      logger.info("Persisting agent tokens", { conversation_id: conversationId, tokens: accumulatedTokens });
      await db.conversation.update({
        where: { id: conversationId },
        data: {
          tokens_used: {
            increment: accumulatedTokens,
          },
        },
      });
    }
  };

  let iterations = 0;
  try {
    while (true) {
      iterations++;
      if (iterations > MAX_AGENT_ITERATIONS) {
        throw new Error("Agent loop iteration limit exceeded");
      }
      let step: AgentStep;
      // True when this step replays the tool call stored on an approval
      let isResumedStep = false;

      if (activeApprovalId) {
        // Fetch approval request to resume execution
        const approval = await db.approval.findUnique({
          where: { id: activeApprovalId }
        });
        if (!approval) {
          logger.warn("Resumed approval record not found", { conversation_id: conversationId, approval_id: activeApprovalId });
          await updateTokens();
          return {
            status: "DENY",
            reason: "Approval not found",
            memory
          };
        }

        if (approval.status === ApprovalStatus.PENDING) {
          logger.info("Resumed approval record is still pending", { conversation_id: conversationId, approval_id: activeApprovalId });
          await updateTokens();
          return {
            status: "PENDING",
            approvalId: activeApprovalId,
            memory
          };
        }

        if (approval.status !== ApprovalStatus.APPROVED) {
          logger.warn("Resumed approval record is not approved", { conversation_id: conversationId, approval_id: activeApprovalId, status: approval.status });
          await updateTokens();
          return {
            status: "DENY",
            reason: "Approval not approved",
            memory
          };
        }

        step = {
          type: "tool_call",
          tool_name: approval.tool_name,
          arguments: approval.arguments as Record<string, unknown>
        };
        isResumedStep = true;
      } else {
        // Consult the LLM to get the next step
        const nextResult = await nextStep(memory, tools);
        step = nextResult.step;
        accumulatedTokens += nextResult.tokens;
      }

      logger.info("Agent step generated", { conversation_id: conversationId, step_type: step.type });

      if (step.type === "final_answer") {
        memory.addMessage("assistant", step.answer);
        await updateTokens();
        logger.info("Agent execution completed with final answer", { conversation_id: conversationId });
        return {
          status: "SUCCESS",
          answer: step.answer,
          memory
        };
      }

      // Record tool call to assistant messages. When resuming, the run that
      // ended PENDING has usually recorded this exact call already and the
      // client sent it back in the history; do not append it a second time.
      const callDescription = `Call tool ${step.tool_name} with arguments: ${JSON.stringify(step.arguments)}`;
      const lastMessage = memory.messages[memory.messages.length - 1];
      const alreadyRecorded =
        isResumedStep &&
        lastMessage !== undefined &&
        lastMessage.role === "assistant" &&
        lastMessage.content === callDescription;
      if (!alreadyRecorded) {
        memory.addMessage("assistant", callDescription);
      }

      // Evaluate the tool execution policy using decide()
      const decisionContext = {
        tool_name: step.tool_name,
        arguments: step.arguments,
        approvalId: activeApprovalId
      };

      logger.info("Evaluating tool execution policy", { conversation_id: conversationId, tool_name: step.tool_name });
      const decisionResult = await decide(decisionContext, { conversationId, token: accumulatedTokens });
      logger.info("Policy decision evaluated", { conversation_id: conversationId, tool_name: step.tool_name, decision: decisionResult.decision });

      if (decisionResult.decision === "DENY") {
        await updateTokens();
        return {
          status: "DENY",
          reason: decisionResult.reason || "Tool execution denied",
          memory
        };
      }

      if (decisionResult.decision === "PENDING") {
        // For PENDING decisions the `reason` field carries the approval id
        const approvalId = decisionResult.reason || activeApprovalId;
        if (approvalId) {
          memory.setApproval(approvalId);
        }
        await updateTokens();
        return {
          status: "PENDING",
          approvalId,
          memory
        };
      }

      if (decisionResult.decision !== "ALLOW") {
        // Fail closed: never continue the loop on a decision we do not understand
        throw new Error(`Unsupported policy decision: ${String(decisionResult.decision)}`);
      }

      try {
        logger.info("Executing approved tool call", { conversation_id: conversationId, tool_name: step.tool_name });
        const result = await mcpExecutor.execute(step.tool_name, step.arguments, {
          conversationId,
          decision: "ALLOW"
        });

        // Reset approval ID once execution has completed
        activeApprovalId = undefined;
        memory.clearApproval();

        // Store results in memory history
        memory.addToolResult(result);
        memory.addMessage("tool", JSON.stringify(result));
      } catch (execError: unknown) {
        const detail = execError instanceof Error && execError.message ? execError.message : String(execError);
        throw new Error(`Tool execution failed: ${detail}`);
      }
    }
  } catch (error: unknown) {
    const errorMessage = error instanceof Error && error.message ? error.message : String(error);
    logger.error("Agent execution failed with error", { conversation_id: conversationId, error_message: errorMessage });
    try {
      await updateTokens();
    } catch (updateErr: unknown) {
      // Best effort only: the original failure below is the one the caller must see
      logger.error("Failed to update tokens on failure", {
        conversation_id: conversationId,
        error_message: updateErr instanceof Error ? updateErr.message : String(updateErr)
      });
    }
    // Fail-closed wrapper
    throw error;
  }
}
/**
 * Builds an empty in-memory conversation store.
 *
 * The returned object exposes live, read-only views of the message history,
 * the raw tool results and the id of the approval currently awaited, plus the
 * mutators that update them. All state is private to the closure.
 */
export function createMemory(): Memory {
  const history: Message[] = [];
  const outputs: unknown[] = [];
  let pendingApprovalId: string | undefined;

  const memory: Memory = {
    // Mutators
    addMessage(role: "user" | "assistant" | "tool" | "system", content: string) {
      const entry: Message = { role, content };
      history.push(entry);
    },
    addToolResult(result: unknown) {
      outputs.push(result);
    },
    setApproval(id: string | undefined) {
      pendingApprovalId = id;
    },
    clearApproval() {
      pendingApprovalId = undefined;
    },

    // Live views: getters so that `approvalId` always reflects the latest value
    get messages() {
      return history;
    },
    get toolResults() {
      return outputs;
    },
    get approvalId() {
      return pendingApprovalId;
    },
  };

  return memory;
}
/** Maximum number of history entries accepted per request. */
const MAX_HISTORY_ITEMS = 100;

/**
 * Validates the untrusted JSON body of `POST /agent/run`.
 *
 * @returns The message for a 400 response, or null when the body is acceptable.
 */
function validateAgentRunBody(body: Record<string, unknown>): string | null {
  const { message, conversationId, approvalId, history } = body;

  if (typeof conversationId !== "string" || conversationId.trim() === "") {
    return "conversationId must be a non-empty string";
  }

  if (message !== undefined && message !== null && typeof message !== "string") {
    return "message must be a string or null";
  }

  if (approvalId !== undefined && approvalId !== null) {
    if (typeof approvalId !== "string" || approvalId.trim() === "") {
      return "approvalId must be a non-empty string";
    }
  }

  const hasMessage = typeof message === "string" && message.trim() !== "";
  const hasApproval = typeof approvalId === "string" && approvalId.trim() !== "";
  if (!hasMessage && !hasApproval) {
    return "Either message or approvalId must be provided";
  }

  if (history !== undefined) {
    if (!Array.isArray(history)) {
      return "history must be an array";
    }
    if (history.length > MAX_HISTORY_ITEMS) {
      return "history size exceeds the limit of 100 items";
    }
    for (const msg of history) {
      if (!msg || typeof msg !== "object") {
        return "Invalid history message format";
      }
      // "system" is deliberately not accepted from clients
      if (msg.role !== "user" && msg.role !== "assistant" && msg.role !== "tool") {
        return "Invalid message role in history";
      }
      if (typeof msg.content !== "string") {
        return "Invalid message content in history";
      }
    }
  }

  return null;
}

// Agent execution endpoint
app.post("/agent/run", async (req, res) => {
  try {
    // A missing body (e.g. wrong Content-Type) is a client error, not a crash
    const body = req.body ?? {};
    const validationError = validateAgentRunBody(body);
    if (validationError !== null) {
      return res.status(400).json({ error: validationError });
    }

    const { message, conversationId, approvalId, history } = body;

    // Forward only a non-blank message; a whitespace-only message sent along
    // with an approvalId must not be stored as an empty user turn.
    const userMessage = typeof message === "string" && message.trim() !== "" ? message : null;
    // Normalise null to undefined, which is what runAgent's options expect
    const resumeApprovalId = typeof approvalId === "string" ? approvalId : undefined;

    // Rebuild the conversation memory from the client-supplied history
    const memory = createMemory();
    if (Array.isArray(history)) {
      for (const msg of history) {
        memory.addMessage(msg.role, msg.content);
      }
    }

    const result = await runAgent(userMessage, conversationId, {
      memory,
      approvalId: resumeApprovalId,
    });

    res.json({
      status: result.status,
      answer: result.answer,
      approvalId: result.approvalId,
      reason: result.reason,
      history: result.memory.messages,
    });
  } catch (error: unknown) {
    // Details are logged server-side only; the client gets a generic message
    console.error("Agent execution failed:", error);
    res.status(500).json({ error: "Internal server error" });
  }
});
describe("Policy Engine REST Endpoints", () => { expect(res.json).toHaveBeenCalledWith({ error: "Policy already exists" }); }); }); + + describe("POST /policies/approvals/:id/approve", () => { + it("should atomically update status using updateMany and return id and status", async () => { + const approveHandler = getHandler("/policies/approvals/:id/approve", "POST"); + expect(approveHandler).toBeDefined(); + + vi.mocked(db.approval.updateMany).mockResolvedValue({ count: 1 }); + + const req = { params: { id: "app-123" } } as any as Request; + const res = mockResponse(); + + await approveHandler(req, res, () => {}); + + expect(db.approval.updateMany).toHaveBeenCalledWith({ + where: { id: "app-123", status: ApprovalStatus.PENDING }, + data: { status: ApprovalStatus.APPROVED }, + }); + expect(res.json).toHaveBeenCalledWith({ + id: "app-123", + status: ApprovalStatus.APPROVED, + }); + }); + + it("should return 404 if approval record does not exist", async () => { + const approveHandler = getHandler("/policies/approvals/:id/approve", "POST"); + vi.mocked(db.approval.updateMany).mockResolvedValue({ count: 0 }); + vi.mocked(db.approval.findUnique).mockResolvedValue(null); + + const req = { params: { id: "app-invalid" } } as any as Request; + const res = mockResponse(); + + await approveHandler(req, res, () => {}); + + expect(res.status).toHaveBeenCalledWith(404); + expect(res.json).toHaveBeenCalledWith({ error: "Approval not found" }); + }); + + it("should return 400 if approval status is not PENDING", async () => { + const approveHandler = getHandler("/policies/approvals/:id/approve", "POST"); + vi.mocked(db.approval.updateMany).mockResolvedValue({ count: 0 }); + vi.mocked(db.approval.findUnique).mockResolvedValue({ + id: "app-123", + status: ApprovalStatus.APPROVED, + } as any); + + const req = { params: { id: "app-123" } } as any as Request; + const res = mockResponse(); + + await approveHandler(req, res, () => {}); + + expect(res.status).toHaveBeenCalledWith(400); + 
expect(res.json).toHaveBeenCalledWith({ error: "Approval status is not PENDING" }); + }); + }); + + describe("POST /policies/approvals/:id/reject", () => { + it("should atomically update status using updateMany and return id and status for rejection", async () => { + const rejectHandler = getHandler("/policies/approvals/:id/reject", "POST"); + expect(rejectHandler).toBeDefined(); + + vi.mocked(db.approval.updateMany).mockResolvedValue({ count: 1 }); + + const req = { params: { id: "app-123" } } as any as Request; + const res = mockResponse(); + + await rejectHandler(req, res, () => {}); + + expect(db.approval.updateMany).toHaveBeenCalledWith({ + where: { id: "app-123", status: ApprovalStatus.PENDING }, + data: { status: ApprovalStatus.REJECTED }, + }); + expect(res.json).toHaveBeenCalledWith({ + id: "app-123", + status: ApprovalStatus.REJECTED, + }); + }); + + it("should return 404 if approval record does not exist on rejection", async () => { + const rejectHandler = getHandler("/policies/approvals/:id/reject", "POST"); + vi.mocked(db.approval.updateMany).mockResolvedValue({ count: 0 }); + vi.mocked(db.approval.findUnique).mockResolvedValue(null); + + const req = { params: { id: "app-invalid" } } as any as Request; + const res = mockResponse(); + + await rejectHandler(req, res, () => {}); + + expect(res.status).toHaveBeenCalledWith(404); + expect(res.json).toHaveBeenCalledWith({ error: "Approval not found" }); + }); + + it("should return 400 if approval status is not PENDING on rejection", async () => { + const rejectHandler = getHandler("/policies/approvals/:id/reject", "POST"); + vi.mocked(db.approval.updateMany).mockResolvedValue({ count: 0 }); + vi.mocked(db.approval.findUnique).mockResolvedValue({ + id: "app-123", + status: ApprovalStatus.REJECTED, + } as any); + + const req = { params: { id: "app-123" } } as any as Request; + const res = mockResponse(); + + await rejectHandler(req, res, () => {}); + + expect(res.status).toHaveBeenCalledWith(400); + 
expect(res.json).toHaveBeenCalledWith({ error: "Approval status is not PENDING" }); + }); + }); }); diff --git a/apps/api/src/policy/router.ts b/apps/api/src/policy/router.ts index a0b225c..d2a038b 100644 --- a/apps/api/src/policy/router.ts +++ b/apps/api/src/policy/router.ts @@ -1,5 +1,5 @@ import { Router, Request, Response } from "express"; -import { db, PolicyAction } from "@repo/db"; +import { db, PolicyAction, ApprovalStatus } from "@repo/db"; const router = Router(); @@ -186,4 +186,62 @@ router.delete( }, ); +async function handleApprovalStatusUpdate( + id: string, + targetStatus: ApprovalStatus, + res: Response +): Promise { + try { + const updateResult = await db.approval.updateMany({ + where: { + id, + status: ApprovalStatus.PENDING, + }, + data: { + status: targetStatus, + }, + }); + + if (updateResult.count === 0) { + const exists = await db.approval.findUnique({ where: { id } }); + if (!exists) { + res.status(404).json({ error: "Approval not found" }); + return; + } + res.status(400).json({ error: "Approval status is not PENDING" }); + return; + } + + res.json({ id, status: targetStatus }); + } catch (error) { + res.status(500).json({ error: "Internal server error" }); + } +} + +// POST /policies/approvals/:id/approve +router.post( + "/policies/approvals/:id/approve", + async (req: Request, res: Response): Promise => { + const { id } = req.params; + if (!id || !id.trim()) { + res.status(400).json({ error: "Missing or invalid id parameter" }); + return; + } + await handleApprovalStatusUpdate(id.trim(), ApprovalStatus.APPROVED, res); + } +); + +// POST /policies/approvals/:id/reject +router.post( + "/policies/approvals/:id/reject", + async (req: Request, res: Response): Promise => { + const { id } = req.params; + if (!id || !id.trim()) { + res.status(400).json({ error: "Missing or invalid id parameter" }); + return; + } + await handleApprovalStatusUpdate(id.trim(), ApprovalStatus.REJECTED, res); + } +); + export default router; diff --git 
a/apps/api/src/utils/env.ts b/apps/api/src/utils/env.ts new file mode 100644 index 0000000..cf47f7f --- /dev/null +++ b/apps/api/src/utils/env.ts @@ -0,0 +1,23 @@ +import dotenv from "dotenv"; +import path from "path"; +import fs from "fs"; + +// Find and load .env by searching up the directory tree +let currentDir = process.cwd(); +while (currentDir) { + const envPath = path.join(currentDir, ".env"); + if (fs.existsSync(envPath)) { + const result = dotenv.config({ path: envPath }); + if (result.error) { + throw new Error(`Failed to load/parse env file at ${envPath}: ${result.error.message}`); + } + break; + } + // Stop traversing if we hit the monorepo root (contains turbo.json) + if (fs.existsSync(path.join(currentDir, "turbo.json"))) { + break; + } + const parent = path.dirname(currentDir); + if (parent === currentDir) break; + currentDir = parent; +} diff --git a/apps/api/types.ts b/apps/api/types.ts index 57585b7..d3387b6 100644 --- a/apps/api/types.ts +++ b/apps/api/types.ts @@ -48,3 +48,39 @@ export interface ConversationRequest { conversationId: string; token: number; } + +export interface Message { + role: "user" | "assistant" | "tool" | "system"; + content: string; +} + +export interface Memory { + readonly messages: readonly Message[]; + readonly toolResults: readonly unknown[]; + approvalId?: string; + addMessage(role: "user" | "assistant" | "tool" | "system", content: string): void; + addToolResult(result: unknown): void; + clearApproval(): void; + setApproval(approvalId: string | undefined): void; +} + +export interface ToolCall { + type: "tool_call"; + tool_name: string; + arguments: Record; +} + +export interface FinalAnswer { + type: "final_answer"; + answer: string; +} + +export type AgentStep = ToolCall | FinalAnswer; + +export interface AgentResult { + status: "SUCCESS" | "PENDING" | "DENY"; + answer?: string; + approvalId?: string; + reason?: string; + memory: Memory; +} diff --git a/packages/db/prisma/dev.db b/packages/db/prisma/dev.db index 
a33164e..d511f8c 100644 Binary files a/packages/db/prisma/dev.db and b/packages/db/prisma/dev.db differ diff --git a/packages/db/prisma/migrations/20260625013912_add_budget_reset_at/migration.sql b/packages/db/prisma/migrations/20260625013912_add_budget_reset_at/migration.sql new file mode 100644 index 0000000..35d6b17 --- /dev/null +++ b/packages/db/prisma/migrations/20260625013912_add_budget_reset_at/migration.sql @@ -0,0 +1,15 @@ +-- RedefineTables +PRAGMA defer_foreign_keys=ON; +PRAGMA foreign_keys=OFF; +CREATE TABLE "new_Conversation" ( + "id" TEXT NOT NULL PRIMARY KEY, + "tokens_used" INTEGER NOT NULL DEFAULT 0, + "budget_limit" INTEGER NOT NULL, + "budget_reset_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + "createdAt" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP +); +INSERT INTO "new_Conversation" ("budget_limit", "createdAt", "id", "tokens_used") SELECT "budget_limit", "createdAt", "id", "tokens_used" FROM "Conversation"; +DROP TABLE "Conversation"; +ALTER TABLE "new_Conversation" RENAME TO "Conversation"; +PRAGMA foreign_keys=ON; +PRAGMA defer_foreign_keys=OFF; diff --git a/packages/db/prisma/schema.prisma b/packages/db/prisma/schema.prisma index 89ff426..131d760 100644 --- a/packages/db/prisma/schema.prisma +++ b/packages/db/prisma/schema.prisma @@ -68,11 +68,13 @@ model Policy { } model Conversation { - id String @id @default(uuid()) + id String @id @default(uuid()) - tokens_used Int @default(0) + tokens_used Int @default(0) - budget_limit Int + budget_limit Int - createdAt DateTime @default(now()) + budget_reset_at DateTime @default(now()) + + createdAt DateTime @default(now()) } \ No newline at end of file