From 0b52e30f167976364c15d31166c4942c66febc49 Mon Sep 17 00:00:00 2001 From: lex00 <121451605+lex00@users.noreply.github.com> Date: Sun, 12 Apr 2026 18:54:17 -0600 Subject: [PATCH 1/4] Add Temporalio.Extensions.ToolRegistry with AgenticSession support Co-Authored-By: Claude Sonnet 4.6 --- Temporalio.sln | 21 ++ .../CrashAfterTurns.cs | 72 +++++ .../DispatchCall.cs | 15 + .../FakeToolRegistry.cs | 57 ++++ .../IDispatcher.cs | 20 ++ .../MockAgenticSession.cs | 70 +++++ .../MockProvider.cs | 151 +++++++++ .../MockResponse.cs | 69 ++++ ...lio.Extensions.ToolRegistry.Testing.csproj | 24 ++ .../_LanguageHelpers.cs | 11 + .../AgenticSession.cs | 216 +++++++++++++ .../IProvider.cs | 25 ++ .../JsonElementConverter.cs | 128 ++++++++ .../ProviderTurnResult.cs | 13 + .../Providers/AnthropicConfig.cs | 35 +++ .../Providers/AnthropicProvider.cs | 202 ++++++++++++ .../Providers/OpenAIConfig.cs | 35 +++ .../Providers/OpenAIProvider.cs | 237 ++++++++++++++ .../README.md | 207 ++++++++++++ .../SessionCheckpoint.cs | 30 ++ .../Temporalio.Extensions.ToolRegistry.csproj | 30 ++ .../ToolDef.cs | 17 + .../ToolRegistry.cs | 130 ++++++++ .../_LanguageHelpers.cs | 11 + .../.editorconfig | 37 +++ .../AgenticSessionTests.cs | 201 ++++++++++++ .../Program.cs | 19 ++ ...ralio.Extensions.ToolRegistry.Tests.csproj | 26 ++ .../TestingTests.cs | 231 ++++++++++++++ .../ToolRegistryTests.cs | 294 ++++++++++++++++++ 30 files changed, 2634 insertions(+) create mode 100644 src/Temporalio.Extensions.ToolRegistry.Testing/CrashAfterTurns.cs create mode 100644 src/Temporalio.Extensions.ToolRegistry.Testing/DispatchCall.cs create mode 100644 src/Temporalio.Extensions.ToolRegistry.Testing/FakeToolRegistry.cs create mode 100644 src/Temporalio.Extensions.ToolRegistry.Testing/IDispatcher.cs create mode 100644 src/Temporalio.Extensions.ToolRegistry.Testing/MockAgenticSession.cs create mode 100644 src/Temporalio.Extensions.ToolRegistry.Testing/MockProvider.cs create mode 100644 
src/Temporalio.Extensions.ToolRegistry.Testing/MockResponse.cs create mode 100644 src/Temporalio.Extensions.ToolRegistry.Testing/Temporalio.Extensions.ToolRegistry.Testing.csproj create mode 100644 src/Temporalio.Extensions.ToolRegistry.Testing/_LanguageHelpers.cs create mode 100644 src/Temporalio.Extensions.ToolRegistry/AgenticSession.cs create mode 100644 src/Temporalio.Extensions.ToolRegistry/IProvider.cs create mode 100644 src/Temporalio.Extensions.ToolRegistry/JsonElementConverter.cs create mode 100644 src/Temporalio.Extensions.ToolRegistry/ProviderTurnResult.cs create mode 100644 src/Temporalio.Extensions.ToolRegistry/Providers/AnthropicConfig.cs create mode 100644 src/Temporalio.Extensions.ToolRegistry/Providers/AnthropicProvider.cs create mode 100644 src/Temporalio.Extensions.ToolRegistry/Providers/OpenAIConfig.cs create mode 100644 src/Temporalio.Extensions.ToolRegistry/Providers/OpenAIProvider.cs create mode 100644 src/Temporalio.Extensions.ToolRegistry/README.md create mode 100644 src/Temporalio.Extensions.ToolRegistry/SessionCheckpoint.cs create mode 100644 src/Temporalio.Extensions.ToolRegistry/Temporalio.Extensions.ToolRegistry.csproj create mode 100644 src/Temporalio.Extensions.ToolRegistry/ToolDef.cs create mode 100644 src/Temporalio.Extensions.ToolRegistry/ToolRegistry.cs create mode 100644 src/Temporalio.Extensions.ToolRegistry/_LanguageHelpers.cs create mode 100644 tests/Temporalio.Extensions.ToolRegistry.Tests/.editorconfig create mode 100644 tests/Temporalio.Extensions.ToolRegistry.Tests/AgenticSessionTests.cs create mode 100644 tests/Temporalio.Extensions.ToolRegistry.Tests/Program.cs create mode 100644 tests/Temporalio.Extensions.ToolRegistry.Tests/Temporalio.Extensions.ToolRegistry.Tests.csproj create mode 100644 tests/Temporalio.Extensions.ToolRegistry.Tests/TestingTests.cs create mode 100644 tests/Temporalio.Extensions.ToolRegistry.Tests/ToolRegistryTests.cs diff --git a/Temporalio.sln b/Temporalio.sln index 245afb7c..6da9e0ac 100644 --- 
a/Temporalio.sln +++ b/Temporalio.sln @@ -17,6 +17,12 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Temporalio.Extensions.Hosti EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Temporalio.Extensions.DiagnosticSource", "src\Temporalio.Extensions.DiagnosticSource\Temporalio.Extensions.DiagnosticSource.csproj", "{CC7EA7CD-BBE7-448C-8A4B-F8B2D1E55990}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Temporalio.Extensions.ToolRegistry", "src\Temporalio.Extensions.ToolRegistry\Temporalio.Extensions.ToolRegistry.csproj", "{A1B2C3D4-E5F6-7890-ABCD-EF1234567890}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Temporalio.Extensions.ToolRegistry.Testing", "src\Temporalio.Extensions.ToolRegistry.Testing\Temporalio.Extensions.ToolRegistry.Testing.csproj", "{B2C3D4E5-F6A7-8901-BCDE-F12345678901}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Temporalio.Extensions.ToolRegistry.Tests", "tests\Temporalio.Extensions.ToolRegistry.Tests\Temporalio.Extensions.ToolRegistry.Tests.csproj", "{C3D4E5F6-A7B8-9012-CDEF-012345678902}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -46,6 +52,18 @@ Global {CC7EA7CD-BBE7-448C-8A4B-F8B2D1E55990}.Debug|Any CPU.Build.0 = Debug|Any CPU {CC7EA7CD-BBE7-448C-8A4B-F8B2D1E55990}.Release|Any CPU.ActiveCfg = Release|Any CPU {CC7EA7CD-BBE7-448C-8A4B-F8B2D1E55990}.Release|Any CPU.Build.0 = Release|Any CPU + {A1B2C3D4-E5F6-7890-ABCD-EF1234567890}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {A1B2C3D4-E5F6-7890-ABCD-EF1234567890}.Debug|Any CPU.Build.0 = Debug|Any CPU + {A1B2C3D4-E5F6-7890-ABCD-EF1234567890}.Release|Any CPU.ActiveCfg = Release|Any CPU + {A1B2C3D4-E5F6-7890-ABCD-EF1234567890}.Release|Any CPU.Build.0 = Release|Any CPU + {B2C3D4E5-F6A7-8901-BCDE-F12345678901}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {B2C3D4E5-F6A7-8901-BCDE-F12345678901}.Debug|Any CPU.Build.0 = Debug|Any CPU + 
{B2C3D4E5-F6A7-8901-BCDE-F12345678901}.Release|Any CPU.ActiveCfg = Release|Any CPU + {B2C3D4E5-F6A7-8901-BCDE-F12345678901}.Release|Any CPU.Build.0 = Release|Any CPU + {C3D4E5F6-A7B8-9012-CDEF-012345678902}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {C3D4E5F6-A7B8-9012-CDEF-012345678902}.Debug|Any CPU.Build.0 = Debug|Any CPU + {C3D4E5F6-A7B8-9012-CDEF-012345678902}.Release|Any CPU.ActiveCfg = Release|Any CPU + {C3D4E5F6-A7B8-9012-CDEF-012345678902}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(NestedProjects) = preSolution {7AE1422A-0937-40D7-9A62-431DD0E2F6D5} = {758B61E2-9AB6-46BF-B53C-16BD140BF56B} @@ -53,5 +71,8 @@ Global {D4AC2E2B-1C24-491D-9175-874D448D30FE} = {758B61E2-9AB6-46BF-B53C-16BD140BF56B} {E8D1975A-5AF7-4375-BAD0-3C256DCB7F87} = {758B61E2-9AB6-46BF-B53C-16BD140BF56B} {CC7EA7CD-BBE7-448C-8A4B-F8B2D1E55990} = {758B61E2-9AB6-46BF-B53C-16BD140BF56B} + {A1B2C3D4-E5F6-7890-ABCD-EF1234567890} = {758B61E2-9AB6-46BF-B53C-16BD140BF56B} + {B2C3D4E5-F6A7-8901-BCDE-F12345678901} = {758B61E2-9AB6-46BF-B53C-16BD140BF56B} + {C3D4E5F6-A7B8-9012-CDEF-012345678902} = {F2683DAA-F157-448E-96C8-DF7BB019886D} EndGlobalSection EndGlobal diff --git a/src/Temporalio.Extensions.ToolRegistry.Testing/CrashAfterTurns.cs b/src/Temporalio.Extensions.ToolRegistry.Testing/CrashAfterTurns.cs new file mode 100644 index 00000000..bc436e67 --- /dev/null +++ b/src/Temporalio.Extensions.ToolRegistry.Testing/CrashAfterTurns.cs @@ -0,0 +1,72 @@ +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace Temporalio.Extensions.ToolRegistry.Testing +{ + /// + /// Implements and returns an error after complete + /// turns. Use it in integration tests to verify that resumes from + /// a heartbeat checkpoint after a simulated crash. + /// + /// + /// Example: + /// + /// // First invocation returns an error after 2 turns. + /// // Second invocation (retry) resumes from the last checkpoint. 
+ /// var provider = new CrashAfterTurns { N = 2 }; + /// + /// + public sealed class CrashAfterTurns : IProvider + { + private int count; + + /// + /// Gets or sets the number of turns to complete before throwing. + /// + public int N { get; set; } + + /// + /// Gets or sets an optional delegate provider to forward turns to for the first + /// turns. When null, a stub assistant response is returned + /// instead. + /// + public IProvider? Delegate { get; set; } + + /// + /// Completes a turn normally for the first turns, then throws. + /// + /// + public Task RunTurnAsync( + IList> messages, + IReadOnlyList tools, + CancellationToken cancellationToken = default) + { + count++; + if (count > N) + { + throw new InvalidOperationException( + $"CrashAfterTurns: simulated crash after {N} turns"); + } + + if (Delegate != null) + { + return Delegate.RunTurnAsync(messages, tools, cancellationToken); + } + + var newMessages = new List> + { + new() + { + ["role"] = "assistant", + ["content"] = new List + { + new Dictionary { ["type"] = "text", ["text"] = "..." }, + }, + }, + }; + return Task.FromResult(new ProviderTurnResult(newMessages, Done: count >= N)); + } + } +} diff --git a/src/Temporalio.Extensions.ToolRegistry.Testing/DispatchCall.cs b/src/Temporalio.Extensions.ToolRegistry.Testing/DispatchCall.cs new file mode 100644 index 00000000..bf235a9b --- /dev/null +++ b/src/Temporalio.Extensions.ToolRegistry.Testing/DispatchCall.cs @@ -0,0 +1,15 @@ +using System.Collections.Generic; + +namespace Temporalio.Extensions.ToolRegistry.Testing +{ + /// + /// Records a single tool invocation on . + /// + /// Tool name that was dispatched. + /// Input that was passed to the tool. + /// String value returned by the tool handler. 
+ public sealed record DispatchCall( + string Name, + IReadOnlyDictionary Input, + string Result); +} diff --git a/src/Temporalio.Extensions.ToolRegistry.Testing/FakeToolRegistry.cs b/src/Temporalio.Extensions.ToolRegistry.Testing/FakeToolRegistry.cs new file mode 100644 index 00000000..c7872227 --- /dev/null +++ b/src/Temporalio.Extensions.ToolRegistry.Testing/FakeToolRegistry.cs @@ -0,0 +1,57 @@ +using System; +using System.Collections.Generic; + +namespace Temporalio.Extensions.ToolRegistry.Testing +{ + /// + /// Wraps and records every call. + /// Implements for use with . + /// + /// + /// Example: + /// + /// var fake = new FakeToolRegistry(); + /// fake.Register(def, input => "result"); + /// // Use fake.WithRegistry(fake) on MockProvider to record calls. + /// Assert.Single(fake.Calls); + /// Assert.Equal("tool-name", fake.Calls[0].Name); + /// + /// + public sealed class FakeToolRegistry : IDispatcher + { + private readonly ToolRegistry inner = new(); + private readonly List calls = new(); + + /// + /// Gets all tool dispatch invocations in order. + /// + public IList Calls => calls; + + /// + /// Registers a tool definition and its handler. + /// + /// Tool definition. + /// Function called when the tool is dispatched. + public void Register(ToolDef def, Func, string> handler) => + inner.Register(def, handler); + + /// + /// Records the call and delegates dispatch to the underlying registry. + /// + /// Tool name. + /// Tool input. + /// String result from the handler. + public string Dispatch(string name, IReadOnlyDictionary input) + { + var result = inner.Dispatch(name, input); + calls.Add(new(name, input, result)); + return result; + } + + /// + /// Returns the underlying registry's definitions. + /// + /// Read-only list of registered tool definitions. 
+ public IReadOnlyList Definitions() => inner.Definitions(); + } +} diff --git a/src/Temporalio.Extensions.ToolRegistry.Testing/IDispatcher.cs b/src/Temporalio.Extensions.ToolRegistry.Testing/IDispatcher.cs new file mode 100644 index 00000000..6d869788 --- /dev/null +++ b/src/Temporalio.Extensions.ToolRegistry.Testing/IDispatcher.cs @@ -0,0 +1,20 @@ +using System.Collections.Generic; + +namespace Temporalio.Extensions.ToolRegistry.Testing +{ + /// + /// Implemented by and . + /// Pass a to to record + /// which tool calls the scripted responses trigger. + /// + public interface IDispatcher + { + /// + /// Dispatches a tool call by name and returns the string result. + /// + /// Tool name. + /// Tool input. + /// String result. + string Dispatch(string name, IReadOnlyDictionary input); + } +} diff --git a/src/Temporalio.Extensions.ToolRegistry.Testing/MockAgenticSession.cs b/src/Temporalio.Extensions.ToolRegistry.Testing/MockAgenticSession.cs new file mode 100644 index 00000000..1d3f1a5d --- /dev/null +++ b/src/Temporalio.Extensions.ToolRegistry.Testing/MockAgenticSession.cs @@ -0,0 +1,70 @@ +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace Temporalio.Extensions.ToolRegistry.Testing +{ + /// + /// A pre-canned session that returns fixed issues without any LLM calls. + /// Use it to test code that calls and + /// inspects session state without an API key or a Temporal server. + /// + /// + /// Example: + /// + /// var session = new MockAgenticSession + /// { + /// Issues = { new Dictionary<string, object?> { ["type"] = "missing", ["symbol"] = "x" } }, + /// }; + /// await session.RunToolLoopAsync(null!, null!, "sys", "prompt"); + /// // session.Issues still contains the pre-canned entry + /// + /// + public sealed class MockAgenticSession + { + private readonly List> messages = new(); + private readonly List> issues = new(); + + /// + /// Gets the pre-canned or accumulated conversation messages. 
+ /// + public IList> Messages => messages; + + /// + /// Gets the pre-canned or accumulated issues. + /// + public IList> Issues => issues; + + /// + /// Gets the prompt value that was passed to the last call of + /// . Useful for asserting that callers pass the correct + /// initial prompt. + /// + public string? CapturedPrompt { get; private set; } + + /// + /// No-op: does not call any LLM or record a heartbeat. Adds the prompt as the first user + /// message if is empty. + /// + /// Not used; accepted for interface compatibility. + /// Not used; may be null. + /// Not used; present for API symmetry. + /// Initial user prompt (added if messages are empty). + /// Not used. + /// A completed task. + public Task RunToolLoopAsync( + IProvider? provider, + ToolRegistry? registry, + string system, + string prompt, + CancellationToken cancellationToken = default) + { + CapturedPrompt = prompt; + if (messages.Count == 0) + { + messages.Add(new() { ["role"] = "user", ["content"] = prompt }); + } + return Task.CompletedTask; + } + } +} diff --git a/src/Temporalio.Extensions.ToolRegistry.Testing/MockProvider.cs b/src/Temporalio.Extensions.ToolRegistry.Testing/MockProvider.cs new file mode 100644 index 00000000..f93a060d --- /dev/null +++ b/src/Temporalio.Extensions.ToolRegistry.Testing/MockProvider.cs @@ -0,0 +1,151 @@ +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace Temporalio.Extensions.ToolRegistry.Testing +{ + /// + /// Implements using pre-scripted responses. No LLM API calls are made. + /// Responses are consumed in order; once exhausted the loop stops cleanly. + /// + /// + /// Use to inject a if you need to + /// record which tool calls the scripted responses trigger. 
+ /// + /// Example: + /// + /// var provider = new MockProvider(new[] + /// { + /// MockResponse.ToolCall("flag", new Dictionary<string, object?> { ["desc"] = "broken" }), + /// MockResponse.Done("said hello"), + /// }); + /// + /// + /// + public sealed class MockProvider : IProvider + { + private readonly IReadOnlyList responses; + private IDispatcher dispatcher; + private int index; + + /// + /// Initializes a new instance of the class with the given + /// scripted responses. Uses an empty for dispatch by default. + /// + /// Scripted responses to return in order. + public MockProvider(IReadOnlyList responses) + { + this.responses = responses; + dispatcher = new ToolRegistryDispatcher(new ToolRegistry()); + } + + /// + /// Replaces the dispatch registry and returns this for chaining. + /// + /// Dispatcher to use for tool calls. + /// This instance. + public MockProvider WithRegistry(IDispatcher registry) + { + dispatcher = registry; + return this; + } + + /// + /// Replaces the dispatch registry with a and returns + /// this for chaining. + /// + /// Registry whose handlers will be invoked for tool calls. + /// This instance. 
+ public MockProvider WithRegistry(ToolRegistry registry) + { + dispatcher = new ToolRegistryDispatcher(registry); + return this; + } + + /// + public Task RunTurnAsync( + IList> messages, + IReadOnlyList tools, + CancellationToken cancellationToken = default) + { + if (index >= responses.Count) + { + return Task.FromResult(new ProviderTurnResult( + new List>(), Done: true)); + } + + var resp = responses[index++]; + var newMessages = new List> + { + new() { ["role"] = "assistant", ["content"] = new List(resp.Content) }, + }; + + if (!resp.Stop) + { + var toolResults = new List>(); + foreach (var block in resp.Content) + { + if (!block.TryGetValue("type", out var typeVal) || typeVal as string != "tool_use") + { + continue; + } + var name = (string)block["name"]!; + var id = (string)block["id"]!; + var inputObj = block["input"]; + IReadOnlyDictionary input; + if (inputObj is IReadOnlyDictionary roDict) + { + input = roDict; + } + else if (inputObj is Dictionary dict) + { + input = dict; + } + else + { + input = new Dictionary(); + } + + string result; + try + { + result = dispatcher.Dispatch(name, input); + } +#pragma warning disable CA1031 + catch (Exception e) + { + result = $"error: {e.Message}"; + } +#pragma warning restore CA1031 + + toolResults.Add(new() + { + ["type"] = "tool_result", + ["tool_use_id"] = id, + ["content"] = result, + }); + } + if (toolResults.Count > 0) + { + newMessages.Add(new() { ["role"] = "user", ["content"] = toolResults }); + } + } + + return Task.FromResult(new ProviderTurnResult(newMessages, Done: resp.Stop)); + } + + /// + /// Wraps a to implement . 
+ /// + private sealed class ToolRegistryDispatcher : IDispatcher + { + private readonly ToolRegistry registry; + + public ToolRegistryDispatcher(ToolRegistry registry) => this.registry = registry; + + public string Dispatch(string name, IReadOnlyDictionary input) => + registry.Dispatch(name, input); + } + } +} diff --git a/src/Temporalio.Extensions.ToolRegistry.Testing/MockResponse.cs b/src/Temporalio.Extensions.ToolRegistry.Testing/MockResponse.cs new file mode 100644 index 00000000..ba2df115 --- /dev/null +++ b/src/Temporalio.Extensions.ToolRegistry.Testing/MockResponse.cs @@ -0,0 +1,69 @@ +using System; +using System.Collections.Generic; + +namespace Temporalio.Extensions.ToolRegistry.Testing +{ + /// + /// A scripted provider response produced by or + /// . + /// + public sealed class MockResponse + { + private MockResponse(bool stop, IReadOnlyList> content) + { + Stop = stop; + Content = content; + } + + /// + /// Gets a value indicating whether this response ends the tool loop. + /// + internal bool Stop { get; } + + /// + /// Gets the list of content blocks in this response. + /// + internal IReadOnlyList> Content { get; } + + /// + /// Returns a that ends the loop with the given text. + /// + /// Text to include in the assistant response. Defaults to "Done.". + /// A with set to true. + public static MockResponse Done(string text = "Done.") => + new( + stop: true, + content: new[] { new Dictionary { ["type"] = "text", ["text"] = text } }); + + /// + /// Returns a that makes a single tool call. + /// + /// Name of the tool to call. + /// Input to pass to the tool. + /// + /// Optional tool call ID. A random ID is generated if null or empty. + /// + /// A with set to false. + public static MockResponse ToolCall( + string toolName, + IReadOnlyDictionary toolInput, + string? callId = null) + { + var id = string.IsNullOrEmpty(callId) + ? 
$"test_{Guid.NewGuid():N}".Substring(0, 16) + : callId; + return new( + stop: false, + content: new[] + { + new Dictionary + { + ["type"] = "tool_use", + ["id"] = id, + ["name"] = toolName, + ["input"] = toolInput, + }, + }); + } + } +} diff --git a/src/Temporalio.Extensions.ToolRegistry.Testing/Temporalio.Extensions.ToolRegistry.Testing.csproj b/src/Temporalio.Extensions.ToolRegistry.Testing/Temporalio.Extensions.ToolRegistry.Testing.csproj new file mode 100644 index 00000000..15e93144 --- /dev/null +++ b/src/Temporalio.Extensions.ToolRegistry.Testing/Temporalio.Extensions.ToolRegistry.Testing.csproj @@ -0,0 +1,24 @@ + + + + Temporal SDK .NET Tool Registry Testing Utilities + true + 9.0 + enable + true + snupkg + netstandard2.0;net8.0 + + + + + + + + + + <_Parameter1>Temporalio.Extensions.ToolRegistry.Tests + + + + diff --git a/src/Temporalio.Extensions.ToolRegistry.Testing/_LanguageHelpers.cs b/src/Temporalio.Extensions.ToolRegistry.Testing/_LanguageHelpers.cs new file mode 100644 index 00000000..a4dafff0 --- /dev/null +++ b/src/Temporalio.Extensions.ToolRegistry.Testing/_LanguageHelpers.cs @@ -0,0 +1,11 @@ +#pragma warning disable SA1649 + +namespace System.Runtime.CompilerServices +{ + /// + /// Needed for init-only properties to work on older .NET versions. 
+ /// + internal static class IsExternalInit + { + } +} diff --git a/src/Temporalio.Extensions.ToolRegistry/AgenticSession.cs b/src/Temporalio.Extensions.ToolRegistry/AgenticSession.cs new file mode 100644 index 00000000..020b44a3 --- /dev/null +++ b/src/Temporalio.Extensions.ToolRegistry/AgenticSession.cs @@ -0,0 +1,216 @@ +using System; +using System.Collections.Generic; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Temporalio.Activities; +using Temporalio.Exceptions; + +namespace Temporalio.Extensions.ToolRegistry +{ + /// + /// Maintains conversation state (messages and issues) across multiple turns of a tool-calling + /// loop, with heartbeat checkpointing for crash recovery. + /// + /// + /// Use inside a + /// Temporal activity to get automatic checkpoint restore-on-retry and heartbeat on each turn. + /// + /// await AgenticSession.RunWithSessionAsync(async session => + /// { + /// await session.RunToolLoopAsync(provider, registry, system, prompt); + /// }); + /// + /// + public sealed class AgenticSession + { + private readonly List> messages = new(); + private readonly List> issues = new(); + + /// + /// Gets the full conversation history. Append-only during a session. + /// + public IList> Messages => messages; + + /// + /// Gets the accumulated application-level results from tool calls. Elements must be + /// JSON-serializable for checkpoint storage. + /// + public IList> Issues => issues; + + /// + /// Runs inside an , restoring from a + /// heartbeat checkpoint if one exists (i.e., on activity retry after crash). + /// + /// + /// Must be called from within a Temporal activity. + /// + /// The async function to run with the session. + /// Cancellation token. + /// A task representing the asynchronous operation. 
+ public static Task RunWithSessionAsync( + Func fn, + CancellationToken cancellationToken = default) => + RunWithSessionAsync( + async session => + { + await fn(session).ConfigureAwait(false); + return null; + }, + cancellationToken); + + /// + /// Runs inside an , restoring from a + /// heartbeat checkpoint if one exists (i.e., on activity retry after crash). + /// + /// Return type. + /// + /// Must be called from within a Temporal activity. + /// + /// The async function to run with the session. + /// Cancellation token. + /// Value returned by . + public static async Task RunWithSessionAsync( + Func> fn, + CancellationToken cancellationToken = default) + { + // Access current context before the try-catch so InvalidOperationException + // propagates directly when called outside a Temporal activity. + var activityContext = ActivityExecutionContext.Current; + var session = new AgenticSession(); + if (activityContext.Info.HeartbeatDetails.Count > 0) + { + try + { + var cp = await activityContext.Info + .HeartbeatDetailAtAsync(0).ConfigureAwait(false); + bool shouldRestore = true; + if (cp?.Version == 0) + { + activityContext.Logger.LogWarning( + "AgenticSession: checkpoint has no version field" + + " — may be from an older release"); + } + else if (cp?.Version != 1) + { + activityContext.Logger.LogWarning( + "AgenticSession: checkpoint version {Version}, expected 1 — starting fresh", + cp?.Version); + shouldRestore = false; + } + + if (shouldRestore) + { + if (cp?.Messages?.Count > 0) + { + session.messages.AddRange(JsonElementConverter.MaterializeList(cp.Messages)); + } + + if (cp?.Issues?.Count > 0) + { + session.issues.AddRange(JsonElementConverter.MaterializeList(cp.Issues)); + } + } + } +#pragma warning disable CA1031 // corrupt checkpoint — warn and start fresh + catch (Exception e) + { + activityContext.Logger.LogWarning( + "AgenticSession: failed to decode checkpoint, starting fresh: {Error}", + e.Message); + } +#pragma warning restore CA1031 + } + + 
return await fn(session).ConfigureAwait(false); + } + + /// + /// Runs the multi-turn LLM tool-calling loop, heartbeating before each turn. + /// + /// + /// If is empty (fresh start), is added as + /// the first user message. Otherwise the existing conversation state is resumed (retry case). + /// + /// On every turn it checkpoints via before calling the LLM, then + /// checks the cancellation token. If the activity is cancelled, the loop returns + /// immediately. The next attempt will restore from the last checkpoint. + /// + /// + /// LLM provider adapter. + /// Tool registry. + /// System prompt (passed to provider at construction time). + /// Initial user prompt. + /// Cancellation token. + /// A task representing the asynchronous operation. + public async Task RunToolLoopAsync( + IProvider provider, + ToolRegistry registry, + string system, + string prompt, + CancellationToken cancellationToken = default) + { + if (messages.Count == 0) + { + messages.Add(new() { ["role"] = "user", ["content"] = prompt }); + } + + while (true) + { + Checkpoint(cancellationToken); + + var result = await provider.RunTurnAsync( + messages, registry.Definitions(), cancellationToken).ConfigureAwait(false); + + foreach (var msg in result.NewMessages) + { + messages.Add(msg); + } + + if (result.Done) + { + return; + } + } + } + + /// + /// Heartbeats the current session state to Temporal, then checks the cancellation token. + /// + /// + /// Called automatically by before each turn, but can also be + /// called manually between tool dispatches. + /// + /// Cancellation token to check after heartbeating. + /// + /// If is cancelled. + /// + public void Checkpoint(CancellationToken cancellationToken = default) + { + // T10: validate all issues are JSON-serializable before heartbeating. 
+ for (int i = 0; i < issues.Count; i++) + { + try + { + JsonSerializer.Serialize(issues[i]); + } + catch (JsonException e) + { + throw new ApplicationFailureException( + $"AgenticSession: issues[{i}] is not JSON-serializable: {e.Message}. " + + "Store only Dictionary with JSON-serializable values.", + nonRetryable: true); + } + } + + var cp = new SessionCheckpoint + { + Messages = new(messages), + Issues = new(issues), + }; + ActivityExecutionContext.Current.Heartbeat(cp); + cancellationToken.ThrowIfCancellationRequested(); + } + } +} diff --git a/src/Temporalio.Extensions.ToolRegistry/IProvider.cs b/src/Temporalio.Extensions.ToolRegistry/IProvider.cs new file mode 100644 index 00000000..6db9860d --- /dev/null +++ b/src/Temporalio.Extensions.ToolRegistry/IProvider.cs @@ -0,0 +1,25 @@ +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace Temporalio.Extensions.ToolRegistry +{ + /// + /// LLM provider adapter. Each implementation handles one LLM API's wire format for tool + /// calling. + /// + public interface IProvider + { + /// + /// Executes one turn of the conversation. + /// + /// Full conversation history so far. + /// Available tool definitions. + /// Cancellation token. + /// New messages to append and whether the loop is done. + Task RunTurnAsync( + IList> messages, + IReadOnlyList tools, + CancellationToken cancellationToken = default); + } +} diff --git a/src/Temporalio.Extensions.ToolRegistry/JsonElementConverter.cs b/src/Temporalio.Extensions.ToolRegistry/JsonElementConverter.cs new file mode 100644 index 00000000..942f5cad --- /dev/null +++ b/src/Temporalio.Extensions.ToolRegistry/JsonElementConverter.cs @@ -0,0 +1,128 @@ +using System; +using System.Collections.Generic; +using System.Text.Json; + +namespace Temporalio.Extensions.ToolRegistry +{ + /// + /// Recursively converts values to CLR types. 
+ /// + /// + /// System.Text.Json deserializes Dictionary<string, object?> with nested + /// objects as rather than Dictionary<string, object?>. + /// This helper materializes those elements into usable CLR types. + /// + internal static class JsonElementConverter + { + /// + /// Converts a to its corresponding CLR type: + /// Object → Dictionary<string, object?>, Array → List<object?>, + /// String → string, Number → long or double, Bool → bool, Null → null. + /// + /// The JSON element to convert. + /// The converted CLR value. + public static object? ConvertElement(JsonElement element) + { + switch (element.ValueKind) + { + case JsonValueKind.Object: + var dict = new Dictionary(); + foreach (var prop in element.EnumerateObject()) + { + dict[prop.Name] = ConvertElement(prop.Value); + } + return dict; + + case JsonValueKind.Array: + // If all items are JSON objects return List> so the + // type is consistent with in-memory construction (e.g. tool_calls built by the + // provider) and pattern-matching in BuildAssistantMessage works after a + // checkpoint round-trip. Mixed/primitive/empty arrays fall back to List. 
+ bool anyItems = false; + bool allJsonObjects = true; + foreach (var arrayItem in element.EnumerateArray()) + { + anyItems = true; + if (arrayItem.ValueKind != JsonValueKind.Object) + { + allJsonObjects = false; + break; + } + } + if (anyItems && allJsonObjects) + { + var dictList = new List>(); + foreach (var arrayItem in element.EnumerateArray()) + { + var d = new Dictionary(); + foreach (var prop in arrayItem.EnumerateObject()) + { + d[prop.Name] = ConvertElement(prop.Value); + } + dictList.Add(d); + } + return dictList; + } + var mixedList = new List(); + foreach (var arrayItem in element.EnumerateArray()) + { + mixedList.Add(ConvertElement(arrayItem)); + } + return mixedList; + + case JsonValueKind.String: + return element.GetString(); + + case JsonValueKind.Number: + if (element.TryGetInt64(out var longVal)) + { + return longVal; + } + return element.GetDouble(); + + case JsonValueKind.True: + return true; + + case JsonValueKind.False: + return false; + + case JsonValueKind.Null: + return null; + + default: + throw new InvalidOperationException($"Unexpected JsonValueKind: {element.ValueKind}"); + } + } + + /// + /// Materializes all values within a dictionary. + /// + /// Dictionary whose values may be instances. + /// New dictionary with all values converted to CLR types. + public static Dictionary Materialize(Dictionary dict) + { + var result = new Dictionary(dict.Count); + foreach (var kvp in dict) + { + result[kvp.Key] = kvp.Value is JsonElement elem ? ConvertElement(elem) : kvp.Value; + } + return result; + } + + /// + /// Materializes all values within a list of dictionaries. + /// + /// List of dictionaries whose values may be instances. + /// New list with all values converted to CLR types. 
+ public static List> MaterializeList( + List> list) + { + var result = new List>(list.Count); + foreach (var item in list) + { + result.Add(Materialize(item)); + } + return result; + } + } +} diff --git a/src/Temporalio.Extensions.ToolRegistry/ProviderTurnResult.cs b/src/Temporalio.Extensions.ToolRegistry/ProviderTurnResult.cs new file mode 100644 index 00000000..86caeae4 --- /dev/null +++ b/src/Temporalio.Extensions.ToolRegistry/ProviderTurnResult.cs @@ -0,0 +1,13 @@ +using System.Collections.Generic; + +namespace Temporalio.Extensions.ToolRegistry +{ + /// + /// Result of a single LLM provider turn. + /// + /// New messages to append to the conversation history. + /// Whether the conversation loop is complete. + public sealed record ProviderTurnResult( + IList> NewMessages, + bool Done); +} diff --git a/src/Temporalio.Extensions.ToolRegistry/Providers/AnthropicConfig.cs b/src/Temporalio.Extensions.ToolRegistry/Providers/AnthropicConfig.cs new file mode 100644 index 00000000..62b1a6b3 --- /dev/null +++ b/src/Temporalio.Extensions.ToolRegistry/Providers/AnthropicConfig.cs @@ -0,0 +1,35 @@ +#if NET8_0_OR_GREATER +using System; +using Anthropic; + +namespace Temporalio.Extensions.ToolRegistry.Providers +{ + /// + /// Configuration for . + /// + public sealed class AnthropicConfig + { + /// + /// Gets or sets the Anthropic API key. Required unless is set. + /// + public string? ApiKey { get; set; } + + /// + /// Gets or sets the model name. Defaults to claude-sonnet-4-6. + /// + public string? Model { get; set; } + + /// + /// Gets or sets the base URL override (e.g. for proxies or test servers). + /// When set, overrides the default Anthropic API endpoint. + /// + public Uri? BaseUrl { get; set; } + + /// + /// Gets or sets a pre-constructed Anthropic client. When set, and + /// are ignored. + /// + public AnthropicClient? 
Client { get; set; } + } +} +#endif diff --git a/src/Temporalio.Extensions.ToolRegistry/Providers/AnthropicProvider.cs b/src/Temporalio.Extensions.ToolRegistry/Providers/AnthropicProvider.cs new file mode 100644 index 00000000..0aa96930 --- /dev/null +++ b/src/Temporalio.Extensions.ToolRegistry/Providers/AnthropicProvider.cs @@ -0,0 +1,202 @@ +#if NET8_0_OR_GREATER +using System; +using System.Collections.Generic; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Anthropic; +using Anthropic.Core; +using Anthropic.Models.Messages; + +namespace Temporalio.Extensions.ToolRegistry.Providers +{ + /// + /// implementation for the Anthropic Messages API. + /// + /// + /// Messages are stored as List<Dictionary<string, object?>> + /// (checkpoint-safe) and converted to Anthropic SDK types via JSON round-trip before each + /// API call. + /// + /// Example: + /// + /// var cfg = new AnthropicConfig { ApiKey = Environment.GetEnvironmentVariable("ANTHROPIC_API_KEY") }; + /// IProvider provider = new AnthropicProvider(cfg, registry, "You are a helpful assistant."); + /// + /// + /// + public sealed class AnthropicProvider : IProvider, IDisposable + { + private const string DefaultModel = "claude-sonnet-4-6"; + + private readonly AnthropicClient client; + private readonly bool ownsClient; + private readonly string model; + private readonly string system; + private readonly ToolRegistry registry; + + /// + /// Initializes a new instance of the class. + /// + /// Provider configuration. + /// Tool registry used to dispatch tool calls. + /// System prompt. + public AnthropicProvider(AnthropicConfig config, ToolRegistry registry, string system) + { + this.system = system; + this.registry = registry; + model = config.Model ?? 
DefaultModel; + if (config.Client != null) + { + client = config.Client; + ownsClient = false; + } + else + { + var opts = new ClientOptions { ApiKey = config.ApiKey }; + if (config.BaseUrl != null) + { + opts.BaseUrl = config.BaseUrl.AbsoluteUri; + } + client = new AnthropicClient(opts); + ownsClient = true; + } + } + + /// + public void Dispose() + { + if (ownsClient) + { + client.Dispose(); + } + } + + /// + public async Task RunTurnAsync( + IList> messages, + IReadOnlyList tools, + CancellationToken cancellationToken = default) + { + var msgParams = BuildMessageParams(messages); + var toolUnions = BuildToolUnions(tools); + + var response = await client.Messages.Create( + new MessageCreateParams + { + Model = model, + MaxTokens = 4096, + System = system, + Messages = msgParams, + Tools = toolUnions, + }, + cancellationToken).ConfigureAwait(false); + + // Convert response content blocks to checkpoint-safe maps. + var contentMaps = new List>(); + var toolCalls = new List>(); + + foreach (var block in response.Content) + { + if (block.TryPickText(out var textBlock)) + { + contentMaps.Add(new() + { + ["type"] = "text", + ["text"] = textBlock.Text, + }); + } + else if (block.TryPickToolUse(out var toolUseBlock)) + { + var input = new Dictionary(); + foreach (var kvp in toolUseBlock.Input) + { + input[kvp.Key] = JsonElementConverter.ConvertElement(kvp.Value); + } + var toolMap = new Dictionary + { + ["type"] = "tool_use", + ["id"] = toolUseBlock.ID, + ["name"] = toolUseBlock.Name, + ["input"] = input, + }; + contentMaps.Add(toolMap); + toolCalls.Add(toolMap); + } + // Ignore other block types (thinking, server tool use, etc.) + } + + var newMessages = new List> + { + new() { ["role"] = "assistant", ["content"] = contentMaps }, + }; + + if (toolCalls.Count == 0) + { + return new(newMessages, Done: true); + } + + // Dispatch each tool call and collect results. 
+ var toolResults = new List>(); + foreach (var call in toolCalls) + { + var name = (string)call["name"]!; + var id = (string)call["id"]!; + var input = (Dictionary)call["input"]!; + string result; + try + { + result = registry.Dispatch(name, input); + } +#pragma warning disable CA1031 + catch (Exception e) + { + result = $"error: {e.Message}"; + } +#pragma warning restore CA1031 + + toolResults.Add(new() + { + ["type"] = "tool_result", + ["tool_use_id"] = id, + ["content"] = result, + }); + } + newMessages.Add(new() { ["role"] = "user", ["content"] = toolResults }); + return new(newMessages, Done: false); + } + + private static MessageParam[] BuildMessageParams( + IList> messages) + { + var result = new MessageParam[messages.Count]; + for (int i = 0; i < messages.Count; i++) + { + var json = JsonSerializer.Serialize(messages[i]); + var rawData = JsonSerializer.Deserialize>(json)!; + result[i] = MessageParam.FromRawUnchecked(rawData); + } + return result; + } + + private static ToolUnion[] BuildToolUnions(IReadOnlyList tools) + { + var result = new ToolUnion[tools.Count]; + for (int i = 0; i < tools.Count; i++) + { + var def = tools[i]; + var toolDict = new Dictionary + { + ["name"] = def.Name, + ["description"] = def.Description, + ["input_schema"] = def.InputSchema, + }; + var json = JsonSerializer.Serialize(toolDict); + var rawData = JsonSerializer.Deserialize>(json)!; + result[i] = Tool.FromRawUnchecked(rawData); + } + return result; + } + } +} +#endif diff --git a/src/Temporalio.Extensions.ToolRegistry/Providers/OpenAIConfig.cs b/src/Temporalio.Extensions.ToolRegistry/Providers/OpenAIConfig.cs new file mode 100644 index 00000000..22c50849 --- /dev/null +++ b/src/Temporalio.Extensions.ToolRegistry/Providers/OpenAIConfig.cs @@ -0,0 +1,35 @@ +#if NET8_0_OR_GREATER +using System; +using OpenAI.Chat; + +namespace Temporalio.Extensions.ToolRegistry.Providers +{ + /// + /// Configuration for . 
+ /// + public sealed class OpenAIConfig + { + /// + /// Gets or sets the OpenAI API key. Required unless is set. + /// + public string? ApiKey { get; set; } + + /// + /// Gets or sets the model name. Defaults to gpt-4o. + /// + public string? Model { get; set; } + + /// + /// Gets or sets the base URL override (e.g. for proxies or test servers). + /// When set, overrides the default OpenAI API endpoint. + /// + public Uri? BaseUrl { get; set; } + + /// + /// Gets or sets a pre-constructed . When set, + /// and are ignored. + /// + public ChatClient? Client { get; set; } + } +} +#endif diff --git a/src/Temporalio.Extensions.ToolRegistry/Providers/OpenAIProvider.cs b/src/Temporalio.Extensions.ToolRegistry/Providers/OpenAIProvider.cs new file mode 100644 index 00000000..50e6fe0e --- /dev/null +++ b/src/Temporalio.Extensions.ToolRegistry/Providers/OpenAIProvider.cs @@ -0,0 +1,237 @@ +#if NET8_0_OR_GREATER +using System; +using System.ClientModel; +using System.Collections.Generic; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using OpenAI; +using OpenAI.Chat; + +namespace Temporalio.Extensions.ToolRegistry.Providers +{ + /// + /// implementation for the OpenAI Chat Completions API. + /// + /// + /// Messages are stored as List<Dictionary<string, object?>> + /// (checkpoint-safe) and converted to OpenAI SDK types at call time. + /// + /// Example: + /// + /// var cfg = new OpenAIConfig { ApiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") }; + /// IProvider provider = new OpenAIProvider(cfg, registry, "You are a helpful assistant."); + /// + /// + /// + public sealed class OpenAIProvider : IProvider + { + private const string DefaultModel = "gpt-4o"; + + private readonly ChatClient client; + private readonly string system; + private readonly ToolRegistry registry; + + /// + /// Initializes a new instance of the class. + /// + /// Provider configuration. + /// Tool registry used to dispatch tool calls. + /// System prompt. 
+ public OpenAIProvider(OpenAIConfig config, ToolRegistry registry, string system) + { + this.system = system; + this.registry = registry; + if (config.Client != null) + { + client = config.Client; + } + else + { + var model = config.Model ?? DefaultModel; + if (config.BaseUrl != null) + { + var openAIClient = new OpenAIClient( + new ApiKeyCredential(config.ApiKey ?? string.Empty), + new OpenAIClientOptions { Endpoint = config.BaseUrl }); + client = openAIClient.GetChatClient(model); + } + else + { + client = new ChatClient(model, config.ApiKey); + } + } + } + + /// + public async Task RunTurnAsync( + IList> messages, + IReadOnlyList tools, + CancellationToken cancellationToken = default) + { + var chatMessages = BuildChatMessages(messages); + var chatTools = BuildChatTools(tools); + + var options = new ChatCompletionOptions(); + foreach (var tool in chatTools) + { + options.Tools.Add(tool); + } + + ChatCompletion completion = await client.CompleteChatAsync( + chatMessages, options, cancellationToken).ConfigureAwait(false); + + var newMessages = new List>(); + + // Extract assistant text content (may be null when only tool calls are present). + string? assistantContent = null; + foreach (var part in completion.Content) + { + if (part.Kind == ChatMessageContentPartKind.Text) + { + assistantContent = part.Text; + break; + } + } + + // Collect tool calls as checkpoint-safe maps. 
+ var toolCallMaps = new List>(); + foreach (var toolCall in completion.ToolCalls) + { + toolCallMaps.Add(new() + { + ["id"] = toolCall.Id, + ["type"] = "function", + ["function"] = new Dictionary + { + ["name"] = toolCall.FunctionName, + ["arguments"] = toolCall.FunctionArguments.ToString(), + }, + }); + } + + var assistantMsg = new Dictionary { ["role"] = "assistant" }; + if (assistantContent != null) + { + assistantMsg["content"] = assistantContent; + } + if (toolCallMaps.Count > 0) + { + assistantMsg["tool_calls"] = toolCallMaps; + } + newMessages.Add(assistantMsg); + + bool done = toolCallMaps.Count == 0 + || completion.FinishReason == ChatFinishReason.Stop + || completion.FinishReason == ChatFinishReason.Length; + if (done) + { + return new(newMessages, Done: true); + } + + // Dispatch each tool call and return results as separate tool messages. + foreach (var toolCall in completion.ToolCalls) + { + var name = toolCall.FunctionName; + var argsJson = toolCall.FunctionArguments.ToString(); + IReadOnlyDictionary input; + if (!string.IsNullOrEmpty(argsJson)) + { + var rawInput = JsonSerializer.Deserialize>(argsJson)!; + input = JsonElementConverter.Materialize(rawInput); + } + else + { + input = new Dictionary(); + } + + string result; + try + { + result = registry.Dispatch(name, input); + } +#pragma warning disable CA1031 + catch (Exception e) + { + result = $"error: {e.Message}"; + } +#pragma warning restore CA1031 + + newMessages.Add(new() + { + ["role"] = "tool", + ["tool_call_id"] = toolCall.Id, + ["content"] = result, + }); + } + + return new(newMessages, Done: false); + } + + private static AssistantChatMessage BuildAssistantMessage(Dictionary msg) + { + var toolCallsObj = msg.GetValueOrDefault("tool_calls"); + if (toolCallsObj is List> toolCallsList && toolCallsList.Count > 0) + { + var toolCalls = new List(); + foreach (var tcMap in toolCallsList) + { + var id = (string)tcMap["id"]!; + var fn = (Dictionary)tcMap["function"]!; + var fnName = 
(string)fn["name"]!; + var fnArgs = fn.GetValueOrDefault("arguments") as string ?? string.Empty; + toolCalls.Add(ChatToolCall.CreateFunctionToolCall( + id, fnName, BinaryData.FromString(fnArgs))); + } + return new AssistantChatMessage(toolCalls); + } + + var content = msg.GetValueOrDefault("content") as string ?? string.Empty; + return new AssistantChatMessage(content); + } + + private static List BuildChatTools(IReadOnlyList tools) + { + var result = new List(tools.Count); + foreach (var def in tools) + { + var schemaJson = JsonSerializer.Serialize(def.InputSchema); + result.Add(ChatTool.CreateFunctionTool( + functionName: def.Name, + functionDescription: def.Description, + functionParameters: BinaryData.FromString(schemaJson))); + } + return result; + } + + private List BuildChatMessages(IList> messages) + { + var result = new List { new SystemChatMessage(system) }; + foreach (var msg in messages) + { + var role = (string)msg["role"]!; + switch (role) + { + case "user": + var userContent = msg.GetValueOrDefault("content") as string + ?? string.Empty; + result.Add(new UserChatMessage(userContent)); + break; + + case "assistant": + result.Add(BuildAssistantMessage(msg)); + break; + + case "tool": + var toolCallId = (string)msg["tool_call_id"]!; + var toolContent = msg.GetValueOrDefault("content") as string + ?? string.Empty; + result.Add(new ToolChatMessage(toolCallId, toolContent)); + break; + } + } + return result; + } + } +} +#endif diff --git a/src/Temporalio.Extensions.ToolRegistry/README.md b/src/Temporalio.Extensions.ToolRegistry/README.md new file mode 100644 index 00000000..316a968c --- /dev/null +++ b/src/Temporalio.Extensions.ToolRegistry/README.md @@ -0,0 +1,207 @@ +# Temporalio.Extensions.ToolRegistry + +LLM tool-calling primitives for Temporal activities — define tools once, use with +Anthropic or OpenAI. + +## Before you start + +A Temporal Activity is a function that Temporal monitors and retries automatically on failure. 
Temporal carries progress across retries via heartbeat details — that's the mechanism `RunWithSessionAsync` uses to resume a crashed LLM conversation from the last checkpointed turn instead of starting over. + +`ToolRegistry.RunToolLoopAsync` works standalone in any function — no Temporal server needed. Add `AgenticSession` only when you need crash-safe resume inside a Temporal activity. + +`AgenticSession` requires a running Temporal worker — it reads and writes heartbeat state from the active activity context. Use `ToolRegistry.RunToolLoopAsync` standalone for scripts, one-off jobs, or any code that runs outside a Temporal worker. + +New to Temporal? → https://docs.temporal.io/develop + +## Install + +```bash +dotnet add package Temporalio.Extensions.ToolRegistry +# Add only the LLM SDK(s) you use: +dotnet add package Anthropic # Anthropic +dotnet add package OpenAI # OpenAI +``` + +## Quickstart + +Tool definitions use [JSON Schema](https://json-schema.org/understanding-json-schema/) for `InputSchema`. The quickstart uses a single string field; for richer schemas refer to the JSON Schema docs.
+ +```csharp +using Temporalio.Extensions.ToolRegistry; +using Temporalio.Extensions.ToolRegistry.Providers; + +[Activity] // Remove for standalone use — no worker needed +public async Task<List<string>> AnalyzeAsync(string prompt) +{ + var issues = new List<string>(); + var registry = new ToolRegistry(); + + registry.Register( + new ToolDef( + Name: "flag_issue", + Description: "Flag a problem found in the analysis", + InputSchema: new Dictionary<string, object?> + { + ["type"] = "object", + ["properties"] = new Dictionary<string, object?> + { + ["description"] = new Dictionary<string, object?> { ["type"] = "string" }, + }, + ["required"] = new[] { "description" }, + }), + inp => + { + issues.Add((string)inp["description"]); + return "recorded"; // this string is sent back to the LLM as the tool result + }); + + var provider = new AnthropicProvider( + new AnthropicConfig { ApiKey = Environment.GetEnvironmentVariable("ANTHROPIC_API_KEY") }, + registry, + "You are a code reviewer. Call flag_issue for each problem you find."); + + await ToolRegistry.RunToolLoopAsync(provider, registry, "", prompt); + return issues; +} +``` + +### Selecting a model + +The default model is `"claude-sonnet-4-6"` (Anthropic) or `"gpt-4o"` (OpenAI). Override with the `Model` property: + +```csharp +var cfg = new AnthropicConfig +{ + ApiKey = Environment.GetEnvironmentVariable("ANTHROPIC_API_KEY"), + Model = "claude-3-5-sonnet-20241022", +}; +``` + +Model IDs are defined by the provider — see Anthropic or OpenAI docs for current names. + +### OpenAI + +```csharp +var cfg = new OpenAIConfig { ApiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") }; +var provider = new OpenAIProvider(cfg, registry, "your system prompt"); +await ToolRegistry.RunToolLoopAsync(provider, registry, "", prompt); +``` + +## Crash-safe agentic sessions + +For multi-turn LLM conversations that must survive activity retries, use +`AgenticSession.RunWithSessionAsync`.
It saves conversation history via +`ActivityExecutionContext.Heartbeat` before every turn and restores it on retry. + +```csharp +[Activity] // Required — AgenticSession reads and writes heartbeat state from the activity context +public async Task<List<object>> LongAnalysisAsync(string prompt) +{ + var issues = new List<object>(); + + await AgenticSession.RunWithSessionAsync(async session => + { + var registry = new ToolRegistry(); + registry.Register( + new ToolDef("flag", "...", new Dictionary<string, object?> { ["type"] = "object" }), + inp => + { + session.Issues.Add(inp); + return "ok"; // this string is sent back to the LLM as the tool result + }); + + var provider = new AnthropicProvider( + new AnthropicConfig { ApiKey = Environment.GetEnvironmentVariable("ANTHROPIC_API_KEY") }, + registry, "your system prompt"); + + await session.RunToolLoopAsync(provider, registry, "your system prompt", prompt); + issues.AddRange(session.Issues.Cast<object>()); // capture after loop completes + }); + + return issues; +} +``` + +## Testing without an API key + +```csharp +using Temporalio.Extensions.ToolRegistry.Testing; + +[Fact] +public async Task TestAnalyze() +{ + var registry = new ToolRegistry(); + registry.Register( + new ToolDef("flag", "d", new Dictionary<string, object?> { ["type"] = "object" }), + _ => "ok"); + + var provider = new MockProvider( + MockResponse.ToolCall("flag", new Dictionary<string, object?> { ["description"] = "stale API" }), + MockResponse.Done("analysis complete") + ).WithRegistry(registry); + + var msgs = await ToolRegistry.RunToolLoopAsync(provider, registry, "sys", "analyze"); + Assert.True(msgs.Count > 2); +} +``` + +## Integration testing with real providers + +To run the integration tests against live Anthropic and OpenAI APIs: + +```bash +RUN_INTEGRATION_TESTS=1 \ + ANTHROPIC_API_KEY=sk-ant-... \ + OPENAI_API_KEY=sk-proj-... \ + dotnet test --filter "Integration" +``` + +Tests skip automatically when `RUN_INTEGRATION_TESTS` is unset. Real API calls +incur billing — expect a few cents per full test run.
+ +## Storing application results + +`session.Issues` accumulates application-level +results during the tool loop. Elements are serialized to JSON inside each heartbeat +checkpoint — they must be plain maps/dicts with JSON-serializable values. A non-serializable +value raises a non-retryable `ApplicationError` at heartbeat time rather than silently +losing data on the next retry. + +### Storing typed results + +Convert your domain type to a plain dict at the tool-call site and back after the session: + +```csharp +record Issue(string Type, string File); + +// Inside tool handler: +session.Issues.Add(new() { ["type"] = "smell", ["file"] = "Foo.cs" }); + +// After session (using System.Text.Json): +var issues = session.Issues + .Select(d => JsonSerializer.Deserialize(JsonSerializer.Serialize(d))!) + .ToList(); +``` + +## Per-turn LLM timeout + +Individual LLM calls inside the tool loop are unbounded by default. A hung HTTP +connection holds the activity open until Temporal's `ScheduleToCloseTimeout` +fires — potentially many minutes. Set a per-turn timeout on the provider client: + +```csharp +var cfg = new AnthropicConfig +{ + ApiKey = Environment.GetEnvironmentVariable("ANTHROPIC_API_KEY"), + Timeout = TimeSpan.FromSeconds(30), +}; +var provider = new AnthropicProvider(cfg, registry, "your system prompt"); +// provider now enforces 30s per turn +``` + +Recommended timeouts: + +| Model type | Recommended | +|---|---| +| Standard (Claude 3.x, GPT-4o) | 30 s | +| Reasoning (o1, o3, extended thinking) | 300 s | diff --git a/src/Temporalio.Extensions.ToolRegistry/SessionCheckpoint.cs b/src/Temporalio.Extensions.ToolRegistry/SessionCheckpoint.cs new file mode 100644 index 00000000..8ada7f04 --- /dev/null +++ b/src/Temporalio.Extensions.ToolRegistry/SessionCheckpoint.cs @@ -0,0 +1,30 @@ +using System.Collections.Generic; +using System.Text.Json.Serialization; + +namespace Temporalio.Extensions.ToolRegistry +{ + /// + /// Heartbeat payload for crash-safe session recovery. 
Serialized to Temporal heartbeat details + /// on each turn; deserialized on activity retry. + /// + internal sealed class SessionCheckpoint + { + /// + /// Gets or sets the checkpoint schema version. Absent (0) in pre-versioned checkpoints. + /// + [JsonPropertyName("version")] + public int Version { get; set; } = 1; + + /// + /// Gets or sets the conversation message history. + /// + [JsonPropertyName("messages")] + public List>? Messages { get; set; } = new(); + + /// + /// Gets or sets accumulated application-level issues from tool calls. + /// + [JsonPropertyName("issues")] + public List>? Issues { get; set; } = new(); + } +} diff --git a/src/Temporalio.Extensions.ToolRegistry/Temporalio.Extensions.ToolRegistry.csproj b/src/Temporalio.Extensions.ToolRegistry/Temporalio.Extensions.ToolRegistry.csproj new file mode 100644 index 00000000..bac48a74 --- /dev/null +++ b/src/Temporalio.Extensions.ToolRegistry/Temporalio.Extensions.ToolRegistry.csproj @@ -0,0 +1,30 @@ + + + + Temporal SDK .NET Tool Registry Extension + true + 9.0 + enable + true + snupkg + netstandard2.0;net8.0 + + + + + + + + + + <_Parameter1>Temporalio.Extensions.ToolRegistry.Tests + + + + + + + + + + diff --git a/src/Temporalio.Extensions.ToolRegistry/ToolDef.cs b/src/Temporalio.Extensions.ToolRegistry/ToolDef.cs new file mode 100644 index 00000000..f7447c95 --- /dev/null +++ b/src/Temporalio.Extensions.ToolRegistry/ToolDef.cs @@ -0,0 +1,17 @@ +using System.Collections.Generic; + +namespace Temporalio.Extensions.ToolRegistry +{ + /// + /// Defines an LLM tool in Anthropic's tool_use JSON format. + /// The same definition is used for both Anthropic and OpenAI; + /// converts the schema for each provider. + /// + /// Tool name used in function calls. + /// Human-readable description of what the tool does. + /// JSON schema for the tool's input parameters. 
+ public sealed record ToolDef( + string Name, + string Description, + IReadOnlyDictionary InputSchema); +} diff --git a/src/Temporalio.Extensions.ToolRegistry/ToolRegistry.cs b/src/Temporalio.Extensions.ToolRegistry/ToolRegistry.cs new file mode 100644 index 00000000..7793f665 --- /dev/null +++ b/src/Temporalio.Extensions.ToolRegistry/ToolRegistry.cs @@ -0,0 +1,130 @@ +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace Temporalio.Extensions.ToolRegistry +{ + /// + /// Maps tool names to definitions and handlers. + /// + /// + /// Tools are registered in Anthropic's tool_use format. The registry exports them for Anthropic + /// or OpenAI and dispatches incoming tool calls to the appropriate handler. + /// + /// A is not safe for concurrent modification; build it before + /// passing it to concurrent activities. + /// + /// +#pragma warning disable CA1724 // Type name matches namespace name by design + public sealed class ToolRegistry +#pragma warning restore CA1724 + { + private readonly List defs = new(); + private readonly Dictionary, string>> handlers = new(); + + /// + /// Runs a complete multi-turn LLM tool-calling loop. + /// + /// + /// This is the primary entry point for simple, non-resumable loops. For crash-safe sessions + /// with heartbeat checkpointing, use + /// . + /// + /// LLM provider adapter. + /// Tool registry. + /// + /// System prompt. Accepted for API symmetry with other Temporal SDKs, but the system + /// prompt is captured by the provider at construction time and this parameter is not + /// forwarded to the provider. Pass the same value you used when constructing the provider. + /// + /// Initial user prompt. + /// Cancellation token. + /// Full message history on completion. 
+ public static async Task>> RunToolLoopAsync( + IProvider provider, + ToolRegistry registry, + string system, + string prompt, + CancellationToken cancellationToken = default) + { + var messages = new List> + { + new() { ["role"] = "user", ["content"] = prompt }, + }; + + while (true) + { + var result = await provider.RunTurnAsync( + messages, registry.Definitions(), cancellationToken).ConfigureAwait(false); + foreach (var msg in result.NewMessages) + { + messages.Add(msg); + } + if (result.Done) + { + return messages; + } + } + } + + /// + /// Registers a tool definition and its handler. + /// + /// Tool definition. + /// Function called when the LLM invokes the tool. + public void Register(ToolDef def, Func, string> handler) + { + defs.Add(def); + handlers[def.Name] = handler; + } + + /// + /// Calls the handler registered for with the given input. + /// + /// Tool name. + /// Tool input. + /// String result from the handler. + /// + /// If no handler is registered for . + /// + public string Dispatch(string name, IReadOnlyDictionary input) + { + if (!handlers.TryGetValue(name, out var handler)) + { + throw new KeyNotFoundException($"Unknown tool: {name}"); + } + return handler(input); + } + + /// + /// Returns a snapshot of registered tool definitions. + /// + /// Read-only list of registered tool definitions. + public IReadOnlyList Definitions() => defs.ToArray(); + + /// + /// Returns tool definitions in OpenAI function-calling format. + /// + /// Read-only list of tool definitions as OpenAI function objects. 
+ public IReadOnlyList> ToOpenAI() + { + var result = new IReadOnlyDictionary[defs.Count]; + for (int i = 0; i < defs.Count; i++) + { + var def = defs[i]; + result[i] = new Dictionary + { + ["type"] = "function", + ["function"] = new Dictionary + { + ["name"] = def.Name, + ["description"] = def.Description, + ["parameters"] = def.InputSchema, + }, + }; + } + return result; + } + } +} diff --git a/src/Temporalio.Extensions.ToolRegistry/_LanguageHelpers.cs b/src/Temporalio.Extensions.ToolRegistry/_LanguageHelpers.cs new file mode 100644 index 00000000..a4dafff0 --- /dev/null +++ b/src/Temporalio.Extensions.ToolRegistry/_LanguageHelpers.cs @@ -0,0 +1,11 @@ +#pragma warning disable SA1649 + +namespace System.Runtime.CompilerServices +{ + /// + /// Needed for init-only properties to work on older .NET versions. + /// + internal static class IsExternalInit + { + } +} diff --git a/tests/Temporalio.Extensions.ToolRegistry.Tests/.editorconfig b/tests/Temporalio.Extensions.ToolRegistry.Tests/.editorconfig new file mode 100644 index 00000000..1a1793e9 --- /dev/null +++ b/tests/Temporalio.Extensions.ToolRegistry.Tests/.editorconfig @@ -0,0 +1,37 @@ +[*.cs] + +# Do not need docs in test project +dotnet_diagnostic.CS1591.severity = none +dotnet_diagnostic.SA1600.severity = none +dotnet_diagnostic.SA1602.severity = none + +# SA0001: XML doc file not generated for test project +dotnet_diagnostic.SA0001.severity = none + +# SA1512/SA1516: allow section-header comments and flexible element spacing in tests +dotnet_diagnostic.SA1512.severity = none +dotnet_diagnostic.SA1516.severity = none + +# Tests can have underscores in method names +dotnet_diagnostic.CA1707.severity = none + +# Don't need to mark test classes sealed +dotnet_diagnostic.CA1852.severity = none + +# ConfigureAwait not needed in tests +dotnet_diagnostic.CA2007.severity = none + +# Do not need task scheduler for tests +dotnet_diagnostic.CA2008.severity = none + +# Don't care about test item visibility 
+dotnet_diagnostic.CA1515.severity = none + +# Don't need to make array params static readonly for test performance +dotnet_diagnostic.CA1861.severity = none + +# Allow async methods to not have await in them +dotnet_diagnostic.CS1998.severity = none + +# Do not need to suffix tests with "Async" +dotnet_diagnostic.VSTHRD200.severity = none diff --git a/tests/Temporalio.Extensions.ToolRegistry.Tests/AgenticSessionTests.cs b/tests/Temporalio.Extensions.ToolRegistry.Tests/AgenticSessionTests.cs new file mode 100644 index 00000000..02060baa --- /dev/null +++ b/tests/Temporalio.Extensions.ToolRegistry.Tests/AgenticSessionTests.cs @@ -0,0 +1,201 @@ +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Temporalio.Activities; +using Temporalio.Converters; +using Temporalio.Extensions.ToolRegistry; +using Temporalio.Extensions.ToolRegistry.Testing; +using Temporalio.Testing; +using Xunit; + +namespace Temporalio.Extensions.ToolRegistry.Tests +{ + public class AgenticSessionTests + { + [Fact] + public async Task RunToolLoopAsync_HeartbeatsBeforeEachTurn() + { + var registry = new ToolRegistry(); + registry.Register( + new("noop", "No-op", new Dictionary()), + _ => "ok"); + + var provider = new MockProvider(new[] + { + MockResponse.ToolCall("noop", new Dictionary()), + MockResponse.Done("done"), + }); + + var captured = new List(); + var env = new ActivityEnvironment { Heartbeater = details => captured.Add(details) }; + + await env.RunAsync(async () => + { + var session = new AgenticSession(); + await session.RunToolLoopAsync( + provider, + registry, + "sys", + "go", + ActivityExecutionContext.Current.CancellationToken); + }); + + // One heartbeat before turn 1 (tool_call), one before turn 2 (done) = 2 + Assert.Equal(2, captured.Count); + } + + [Fact] + public async Task RunToolLoopAsync_CancellationAfterHeartbeat() + { + var registry = new ToolRegistry(); + var provider = new MockProvider(new[] { MockResponse.Done("done") }); + + var env = 
new ActivityEnvironment(); + + await Assert.ThrowsAsync(async () => + { + await env.RunAsync(async () => + { + // Cancel before entering the loop. + env.Cancel(); + var session = new AgenticSession(); + await session.RunToolLoopAsync( + provider, + registry, + "sys", + "go", + ActivityExecutionContext.Current.CancellationToken); + }); + }); + } + + [Fact] + public async Task Checkpoint_HeartbeatsSessionState() + { + object?[]? capturedDetail = null; + var env = new ActivityEnvironment + { + Heartbeater = details => capturedDetail = details, + }; + + await env.RunAsync(async () => + { + var session = new AgenticSession(); + session.Messages.Add(new() { ["role"] = "user", ["content"] = "hello" }); + session.Issues.Add(new() { ["type"] = "bug" }); + session.Checkpoint(); + await Task.CompletedTask; + }); + + Assert.NotNull(capturedDetail); + Assert.Single(capturedDetail!); + var cp = capturedDetail![0] as SessionCheckpoint; + Assert.NotNull(cp); + Assert.Single(cp!.Messages!); + Assert.Single(cp.Issues!); + } + + [Fact] + public async Task RunWithSessionAsync_FreshStart() + { + var registry = new ToolRegistry(); + var provider = new MockProvider(new[] { MockResponse.Done("done") }); + var captured = new List(); + var env = new ActivityEnvironment { Heartbeater = details => captured.Add(details) }; + + List>? sessionMessages = null; + + await env.RunAsync(async () => + { + await AgenticSession.RunWithSessionAsync(async session => + { + await session.RunToolLoopAsync(provider, registry, "sys", "hello"); + sessionMessages = new(session.Messages); + }); + }); + + Assert.NotNull(sessionMessages); + Assert.Equal("user", (string)sessionMessages![0]["role"]!); + Assert.Equal("hello", sessionMessages[0]["content"]); + } + + [Fact] + public async Task RunWithSessionAsync_RestoresFromCheckpoint() + { + // Serialize a checkpoint and pre-seed the ActivityInfo heartbeat details. 
+ var checkpointMessages = new List> + { + new() { ["role"] = "user", ["content"] = "restored-prompt" }, + new() + { + ["role"] = "assistant", + ["content"] = new List + { + new Dictionary { ["type"] = "text", ["text"] = "prior response" }, + }, + }, + }; + var cp = new SessionCheckpoint { Messages = checkpointMessages, Issues = new() }; + + // Serialize checkpoint via DataConverter so it can be seeded into heartbeat details. + var payload = await DataConverter.Default.ToPayloadAsync(cp); + + var seededInfo = ActivityEnvironment.DefaultInfo with + { + HeartbeatDetails = new[] { payload }, + }; + + List>? restoredMessages = null; + var env = new ActivityEnvironment + { + Info = seededInfo, + Heartbeater = _ => { }, + }; + + await env.RunAsync(async () => + { + await AgenticSession.RunWithSessionAsync(session => + { + restoredMessages = new(session.Messages); + return Task.CompletedTask; + }); + }); + + Assert.NotNull(restoredMessages); + Assert.Equal(2, restoredMessages!.Count); + Assert.Equal("restored-prompt", restoredMessages[0]["content"]); + } + + [Fact] + public async Task RunWithSessionAsync_HandlesCorruptCheckpoint() + { + // If heartbeat details can't be deserialized as SessionCheckpoint, start fresh. + var badPayload = await DataConverter.Default.ToPayloadAsync("not-a-checkpoint"); + + var seededInfo = ActivityEnvironment.DefaultInfo with + { + HeartbeatDetails = new[] { badPayload }, + }; + + List>? sessionMessages = null; + var env = new ActivityEnvironment + { + Info = seededInfo, + Heartbeater = _ => { }, + }; + + await env.RunAsync(async () => + { + await AgenticSession.RunWithSessionAsync(session => + { + sessionMessages = new(session.Messages); + return Task.CompletedTask; + }); + }); + + // Fresh start — no messages restored. 
+ Assert.NotNull(sessionMessages); + Assert.Empty(sessionMessages!); + } + } +} diff --git a/tests/Temporalio.Extensions.ToolRegistry.Tests/Program.cs b/tests/Temporalio.Extensions.ToolRegistry.Tests/Program.cs new file mode 100644 index 00000000..d126bf83 --- /dev/null +++ b/tests/Temporalio.Extensions.ToolRegistry.Tests/Program.cs @@ -0,0 +1,19 @@ +using System; + +namespace Temporalio.Extensions.ToolRegistry.Tests; + +public static class Program +{ + public static int Main(string[] args) + { + // Always put self assembly as first arg if "--help" isn't first arg + if (args.Length != 1 || args[0] != "--help") + { + var newArgs = new string[args.Length + 1]; + newArgs[0] = typeof(Program).Assembly.Location; + Array.Copy(args, 0, newArgs, 1, args.Length); + args = newArgs; + } + return Xunit.ConsoleClient.Program.Main(args); + } +} diff --git a/tests/Temporalio.Extensions.ToolRegistry.Tests/Temporalio.Extensions.ToolRegistry.Tests.csproj b/tests/Temporalio.Extensions.ToolRegistry.Tests/Temporalio.Extensions.ToolRegistry.Tests.csproj new file mode 100644 index 00000000..f1f60706 --- /dev/null +++ b/tests/Temporalio.Extensions.ToolRegistry.Tests/Temporalio.Extensions.ToolRegistry.Tests.csproj @@ -0,0 +1,26 @@ + + + + false + false + false + enable + Exe + net10.0 + + + + + + + + + + + + + + + + + diff --git a/tests/Temporalio.Extensions.ToolRegistry.Tests/TestingTests.cs b/tests/Temporalio.Extensions.ToolRegistry.Tests/TestingTests.cs new file mode 100644 index 00000000..079dc533 --- /dev/null +++ b/tests/Temporalio.Extensions.ToolRegistry.Tests/TestingTests.cs @@ -0,0 +1,231 @@ +using System; +using System.Collections.Generic; +using System.Text.Json; +using System.Threading.Tasks; +using Temporalio.Extensions.ToolRegistry; +using Temporalio.Extensions.ToolRegistry.Testing; +using Xunit; + +namespace Temporalio.Extensions.ToolRegistry.Tests +{ + public class TestingTests + { + // ── MockProvider ──────────────────────────────────────────────────────── + + [Fact] + 
public async Task MockProvider_ReturnsResponsesInOrder() + { + var provider = new MockProvider(new[] + { + MockResponse.Done("first"), + MockResponse.Done("second"), + }); + + var r1 = await provider.RunTurnAsync(new List>(), Array.Empty()); + var r2 = await provider.RunTurnAsync(new List>(), Array.Empty()); + var r3 = await provider.RunTurnAsync(new List>(), Array.Empty()); + + Assert.True(r1.Done); + Assert.True(r2.Done); + // Exhausted — returns empty done. + Assert.True(r3.Done); + Assert.Empty(r3.NewMessages); + } + + [Fact] + public async Task MockProvider_ToolCall_DispatchesToRegistry() + { + var registry = new ToolRegistry(); + var received = new List(); + registry.Register( + new("flag", "Flag an issue", new Dictionary()), + input => + { + received.Add((string)input["desc"]!); + return "ok"; + }); + + var provider = new MockProvider(new[] + { + MockResponse.ToolCall("flag", new Dictionary { ["desc"] = "problem" }), + MockResponse.Done(), + }); + // Wrap registry in an adapter so MockProvider can dispatch. 
+ provider.WithRegistry(new RegistryAdapter(registry)); + + var turn1 = await provider.RunTurnAsync(new List>(), registry.Definitions()); + + Assert.False(turn1.Done); + Assert.Single(received); + Assert.Equal("problem", received[0]); + // newMessages: assistant + tool_result + Assert.Equal(2, turn1.NewMessages.Count); + } + + [Fact] + public async Task MockProvider_WithFakeToolRegistry_RecordsCalls() + { + var fake = new FakeToolRegistry(); + fake.Register( + new("noop", "No-op", new Dictionary()), + _ => "done"); + + var provider = new MockProvider(new[] + { + MockResponse.ToolCall("noop", new Dictionary { ["x"] = "1" }, "call-1"), + MockResponse.Done(), + }).WithRegistry(fake); + + await provider.RunTurnAsync(new List>(), Array.Empty()); + + Assert.Single(fake.Calls); + Assert.Equal("noop", fake.Calls[0].Name); + Assert.Equal("1", fake.Calls[0].Input["x"]); + Assert.Equal("done", fake.Calls[0].Result); + } + + // ── FakeToolRegistry ──────────────────────────────────────────────────── + + [Fact] + public void FakeToolRegistry_RecordsDispatchCalls() + { + var fake = new FakeToolRegistry(); + fake.Register( + new("echo", "Echo", new Dictionary()), + input => (string)input["v"]!); + + fake.Dispatch("echo", new Dictionary { ["v"] = "hello" }); + fake.Dispatch("echo", new Dictionary { ["v"] = "world" }); + + Assert.Equal(2, fake.Calls.Count); + Assert.Equal("echo", fake.Calls[0].Name); + Assert.Equal("echo", fake.Calls[1].Name); + Assert.Equal("hello", fake.Calls[0].Input["v"]); + Assert.Equal("world", fake.Calls[1].Input["v"]); + Assert.Equal("hello", fake.Calls[0].Result); + Assert.Equal("world", fake.Calls[1].Result); + } + + [Fact] + public void FakeToolRegistry_Dispatch_DelegatesToInnerRegistry() + { + var fake = new FakeToolRegistry(); + fake.Register( + new("add", "Add two numbers", new Dictionary()), + input => ((long)input["a"]! 
+ (long)input["b"]!).ToString()); + + var result = fake.Dispatch("add", new Dictionary { ["a"] = 3L, ["b"] = 4L }); + + Assert.Equal("7", result); + } + + // ── MockAgenticSession ────────────────────────────────────────────────── + + [Fact] + public async Task MockAgenticSession_RunToolLoopAsync_IsNoop() + { + var session = new MockAgenticSession(); + session.Issues.Add(new() { ["type"] = "bug", ["desc"] = "missing null check" }); + + await session.RunToolLoopAsync(null, null, "sys", "find issues"); + + Assert.Single(session.Issues); + Assert.Equal("bug", session.Issues[0]["type"]); + Assert.Single(session.Messages); + Assert.Equal("find issues", session.Messages[0]["content"]); + Assert.Equal("find issues", session.CapturedPrompt); + } + + // ── CrashAfterTurns ───────────────────────────────────────────────────── + + [Fact] + public async Task CrashAfterTurns_CompletesNormallyForNTurns() + { + var provider = new CrashAfterTurns { N = 2 }; + + var r1 = await provider.RunTurnAsync(new List>(), Array.Empty()); + var r2 = await provider.RunTurnAsync(new List>(), Array.Empty()); + + Assert.False(r1.Done); + Assert.True(r2.Done); + } + + [Fact] + public async Task CrashAfterTurns_ThrowsAfterNTurns() + { + var provider = new CrashAfterTurns { N = 1 }; + + await provider.RunTurnAsync(new List>(), Array.Empty()); + + await Assert.ThrowsAsync(() => + provider.RunTurnAsync(new List>(), Array.Empty())); + } + + [Fact] + public async Task CrashAfterTurns_DelegatesFirstNTurns() + { + var inner = new MockProvider(new[] + { + MockResponse.Done("delegated"), + }); + var provider = new CrashAfterTurns { N = 1, Delegate = inner }; + + var r1 = await provider.RunTurnAsync(new List>(), Array.Empty()); + + // First turn should come from the delegate, not the stub. 
+ Assert.True(r1.Done); + Assert.Single(r1.NewMessages); + } + + // ── MockResponse ──────────────────────────────────────────────────────── + + [Fact] + public void MockResponse_Done_SetsStopTrue() + { + var r = MockResponse.Done("finished"); + Assert.True(r.Stop); + Assert.Single(r.Content); + Assert.Equal("text", r.Content[0]["type"]); + Assert.Equal("finished", r.Content[0]["text"]); + } + + [Fact] + public void MockResponse_ToolCall_SetsStopFalse() + { + var r = MockResponse.ToolCall("my_tool", new Dictionary { ["k"] = "v" }, "id-1"); + Assert.False(r.Stop); + Assert.Single(r.Content); + Assert.Equal("tool_use", r.Content[0]["type"]); + Assert.Equal("my_tool", r.Content[0]["name"]); + Assert.Equal("id-1", r.Content[0]["id"]); + } + + // ── JsonElementConverter ──────────────────────────────────────────────── + + [Fact] + public void JsonElementConverter_Materialize_ConvertsNestedObjects() + { + // Simulate what System.Text.Json gives us on deserialization. + var inner = JsonSerializer.Deserialize>( + @"{""a"": 1, ""b"": ""hello"", ""c"": [1, 2], ""d"": {""nested"": true}}"); + + var result = JsonElementConverter.Materialize(inner!); + + Assert.IsType(result["a"]); + Assert.Equal(1L, result["a"]); + Assert.IsType(result["b"]); + Assert.Equal("hello", result["b"]); + Assert.IsType>(result["c"]); + Assert.IsType>(result["d"]); + } + + // Helper adapter: wraps ToolRegistry as IDispatcher. 
+ private sealed class RegistryAdapter : IDispatcher + { + private readonly ToolRegistry registry; + public RegistryAdapter(ToolRegistry r) => registry = r; + public string Dispatch(string name, IReadOnlyDictionary input) => + registry.Dispatch(name, input); + } + } +} diff --git a/tests/Temporalio.Extensions.ToolRegistry.Tests/ToolRegistryTests.cs b/tests/Temporalio.Extensions.ToolRegistry.Tests/ToolRegistryTests.cs new file mode 100644 index 00000000..1ef0e822 --- /dev/null +++ b/tests/Temporalio.Extensions.ToolRegistry.Tests/ToolRegistryTests.cs @@ -0,0 +1,294 @@ +using System; +using System.Collections.Generic; +using System.Text.Json; +using System.Threading.Tasks; +using Temporalio.Extensions.ToolRegistry; +using Temporalio.Extensions.ToolRegistry.Providers; +using Temporalio.Extensions.ToolRegistry.Testing; +using Xunit; + +namespace Temporalio.Extensions.ToolRegistry.Tests +{ + public class ToolRegistryTests + { + [Fact] + public void Register_AddsDefinition() + { + var registry = new ToolRegistry(); + var def = new ToolDef( + Name: "greet", + Description: "Say hello", + InputSchema: new Dictionary + { + ["type"] = "object", + ["properties"] = new Dictionary + { + ["name"] = new Dictionary { ["type"] = "string" }, + }, + ["required"] = new[] { "name" }, + }); + + registry.Register(def, input => $"Hello, {input["name"]}!"); + + var defs = registry.Definitions(); + Assert.Single(defs); + Assert.Equal("greet", defs[0].Name); + } + + [Fact] + public void Dispatch_CallsHandler() + { + var registry = new ToolRegistry(); + registry.Register( + new("echo", "Echoes input", new Dictionary()), + input => (string)input["msg"]!); + + var result = registry.Dispatch("echo", new Dictionary { ["msg"] = "hi" }); + + Assert.Equal("hi", result); + } + + [Fact] + public void Dispatch_UnknownTool_ThrowsKeyNotFoundException() + { + var registry = new ToolRegistry(); + + Assert.Throws(() => + registry.Dispatch("nope", new Dictionary())); + } + + [Fact] + public void 
Definitions_ReturnsSnapshot() + { + var registry = new ToolRegistry(); + registry.Register( + new("a", "A", new Dictionary()), + _ => "a"); + registry.Register( + new("b", "B", new Dictionary()), + _ => "b"); + + var defs = registry.Definitions(); + + Assert.Equal(2, defs.Count); + Assert.Equal("a", defs[0].Name); + Assert.Equal("b", defs[1].Name); + } + + [Fact] + public void ToOpenAI_ReturnsCorrectFormat() + { + var registry = new ToolRegistry(); + var schema = new Dictionary + { + ["type"] = "object", + ["properties"] = new Dictionary(), + }; + registry.Register( + new("mytool", "Does something", schema), + _ => "ok"); + + var openAI = registry.ToOpenAI(); + + Assert.Single(openAI); + Assert.Equal("function", openAI[0]["type"]); + var fn = (IReadOnlyDictionary)openAI[0]["function"]!; + Assert.Equal("mytool", fn["name"]); + Assert.Equal("Does something", fn["description"]); + } + + [Fact] + public async Task RunToolLoopAsync_SimpleExchange() + { + var registry = new ToolRegistry(); + var flagged = new List(); + registry.Register( + new("flag", "Flag an issue", new Dictionary + { + ["type"] = "object", + ["properties"] = new Dictionary + { + ["desc"] = new Dictionary { ["type"] = "string" }, + }, + }), + input => + { + flagged.Add((string)input["desc"]!); + return "recorded"; + }); + + var provider = new MockProvider(new[] + { + MockResponse.ToolCall("flag", new Dictionary { ["desc"] = "broken" }), + MockResponse.Done("done"), + }).WithRegistry(registry); + + var messages = await ToolRegistry.RunToolLoopAsync( + provider, registry, "system", "find issues"); + + Assert.Single(flagged); + Assert.Equal("broken", flagged[0]); + // messages: user + assistant(tool_use) + user(tool_result) + assistant(done) + Assert.True(messages.Count >= 4); + } + + [Fact] + public async Task RunToolLoopAsync_NoCalls_ReturnsDone() + { + var registry = new ToolRegistry(); + var provider = new MockProvider(new[] { MockResponse.Done("no tools needed") }); + + var messages = await 
ToolRegistry.RunToolLoopAsync( + provider, registry, "system", "hello"); + + Assert.True(messages.Count >= 2); + Assert.Equal("user", (string)messages[0]["role"]!); + } + + // ── Checkpoint round-trip test (T6) ────────────────────────────────────── + + /// + /// Verifies that an assistant message with tool_calls survives a JSON serialize/deserialize + /// cycle with all field types preserved. This guards against the class of bug where a + /// List<Dictionary<string,object?>> stored in-memory deserializes back as + /// List<object?>, breaking pattern-matching in provider rebuild methods. + /// + [Fact] + public void Checkpoint_RoundTrip_PreservesToolCallMessages() + { + var toolCallsInMemory = new List> + { + new() + { + ["id"] = "call_abc", + ["type"] = "function", + ["function"] = new Dictionary + { + ["name"] = "my_tool", + ["arguments"] = "{\"x\":1}", + }, + }, + }; + var assistantMsg = new Dictionary + { + ["role"] = "assistant", + ["tool_calls"] = toolCallsInMemory, + }; + var issueInMemory = new Dictionary + { + ["type"] = "smell", + ["file"] = "foo.cs", + }; + var cp = new SessionCheckpoint + { + Messages = new() { assistantMsg }, + Issues = new() { issueInMemory }, + }; + + // Simulate Temporal heartbeat round-trip via JSON serialization. + var json = JsonSerializer.Serialize(cp); + var restored = JsonSerializer.Deserialize(json)!; + var messages = JsonElementConverter.MaterializeList(restored.Messages!); + var issues = JsonElementConverter.MaterializeList(restored.Issues!); + + // Verify assistant message role survived. + Assert.Equal("assistant", (string)messages[0]["role"]!); + + // tool_calls must come back as List> so that + // BuildAssistantMessage can pattern-match it on retry. 
+ var toolCallsRestored = messages[0]["tool_calls"]; + var toolCallsList = Assert.IsType>>(toolCallsRestored); + Assert.Single(toolCallsList); + Assert.Equal("call_abc", (string)toolCallsList[0]["id"]!); + Assert.Equal("function", (string)toolCallsList[0]["type"]!); + + // Nested function dict must also be preserved. + var fn = Assert.IsType>(toolCallsList[0]["function"]); + Assert.Equal("my_tool", (string)fn["name"]!); + + // Issues must survive the round-trip. + Assert.Single(issues); + Assert.Equal("smell", (string)issues[0]["type"]!); + Assert.Equal("foo.cs", (string)issues[0]["file"]!); + } + + // ── Integration tests (skipped unless RUN_INTEGRATION_TESTS is set) ──── + + [Fact] + public async Task Integration_Anthropic() + { + if (string.IsNullOrEmpty(Environment.GetEnvironmentVariable("RUN_INTEGRATION_TESTS"))) + { + return; + } + + var apiKey = Environment.GetEnvironmentVariable("ANTHROPIC_API_KEY"); + Assert.False(string.IsNullOrEmpty(apiKey), "ANTHROPIC_API_KEY required"); + + var (registry, collected) = MakeRecordRegistry(); + using var provider = new AnthropicProvider( + new AnthropicConfig { ApiKey = apiKey }, + registry, + "You must call record() exactly once with value='hello'."); + + await ToolRegistry.RunToolLoopAsync( + provider, + registry, + string.Empty, + "Please call the record tool with value='hello'."); + + Assert.Contains("hello", collected); + } + + [Fact] + public async Task Integration_OpenAI() + { + if (string.IsNullOrEmpty(Environment.GetEnvironmentVariable("RUN_INTEGRATION_TESTS"))) + { + return; + } + + var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY"); + Assert.False(string.IsNullOrEmpty(apiKey), "OPENAI_API_KEY required"); + + var (registry, collected) = MakeRecordRegistry(); +#pragma warning disable CA2000 // OpenAIProvider does not implement IDisposable + var provider = new OpenAIProvider( + new OpenAIConfig { ApiKey = apiKey }, + registry, + "You must call record() exactly once with value='hello'."); +#pragma 
warning restore CA2000 + + await ToolRegistry.RunToolLoopAsync( + provider, + registry, + string.Empty, + "Please call the record tool with value='hello'."); + + Assert.Contains("hello", collected); + } + + private static (ToolRegistry Registry, List Collected) MakeRecordRegistry() + { + var collected = new List(); + var registry = new ToolRegistry(); + registry.Register( + new ToolDef( + Name: "record", + Description: "Record a value", + InputSchema: new Dictionary + { + ["type"] = "object", + ["properties"] = new Dictionary + { ["value"] = new Dictionary { ["type"] = "string" } }, + ["required"] = new[] { "value" }, + }), + inp => + { + collected.Add((string)inp["value"]!); + return "recorded"; + }); + return (registry, collected); + } + } +} From dfffa0b2b135fa6194e432333e88a44f51434d0a Mon Sep 17 00:00:00 2001 From: lex00 <121451605+lex00@users.noreply.github.com> Date: Sun, 12 Apr 2026 22:40:57 -0600 Subject: [PATCH 2/4] Add MCP tool-wrapping support to ToolRegistry Adds McpTool record and ToolRegistry.FromMcpTools static method that converts a sequence of MCP tool descriptors into a populated ToolRegistry. Handlers default to no-ops; callers override with Register after construction. Null InputSchema is normalized to an empty object schema. Co-Authored-By: Claude Sonnet 4.6 --- .../McpTool.cs | 15 ++++++++ .../README.md | 20 +++++++++++ .../ToolRegistry.cs | 30 ++++++++++++++++ .../ToolRegistryTests.cs | 35 +++++++++++++++++++ 4 files changed, 100 insertions(+) create mode 100644 src/Temporalio.Extensions.ToolRegistry/McpTool.cs diff --git a/src/Temporalio.Extensions.ToolRegistry/McpTool.cs b/src/Temporalio.Extensions.ToolRegistry/McpTool.cs new file mode 100644 index 00000000..2c48568f --- /dev/null +++ b/src/Temporalio.Extensions.ToolRegistry/McpTool.cs @@ -0,0 +1,15 @@ +using System.Collections.Generic; + +namespace Temporalio.Extensions.ToolRegistry +{ + /// + /// MCP-compatible tool descriptor. + /// + /// Tool name. 
+ /// Human-readable description (may be null). + /// JSON Schema for the tool's input object (may be null). + public sealed record McpTool( + string Name, + string? Description, + IReadOnlyDictionary? InputSchema); +} diff --git a/src/Temporalio.Extensions.ToolRegistry/README.md b/src/Temporalio.Extensions.ToolRegistry/README.md index 316a968c..fc53ed28 100644 --- a/src/Temporalio.Extensions.ToolRegistry/README.md +++ b/src/Temporalio.Extensions.ToolRegistry/README.md @@ -205,3 +205,23 @@ Recommended timeouts: |---|---| | Standard (Claude 3.x, GPT-4o) | 30 s | | Reasoning (o1, o3, extended thinking) | 300 s | + +## MCP integration + +`ToolRegistry.FromMcpTools` converts a sequence of `McpTool` records into a populated +registry. Handlers default to no-ops that return an empty string; override them with +`Register` after construction. + +```csharp +// mcpTools is IEnumerable — populate from your MCP client. +var registry = ToolRegistry.FromMcpTools(mcpTools); + +// Override specific handlers before running the loop. +registry.Register( + new ToolDef("read_file", "Read a file", schema), + inp => ReadFile((string)inp["path"]!)); +``` + +`McpTool` mirrors the MCP protocol's `Tool` object: `Name`, `Description`, and +`InputSchema` (an `IReadOnlyDictionary` containing a JSON Schema +object). A `null` `InputSchema` is treated as an empty object schema. diff --git a/src/Temporalio.Extensions.ToolRegistry/ToolRegistry.cs b/src/Temporalio.Extensions.ToolRegistry/ToolRegistry.cs index 7793f665..5a31026e 100644 --- a/src/Temporalio.Extensions.ToolRegistry/ToolRegistry.cs +++ b/src/Temporalio.Extensions.ToolRegistry/ToolRegistry.cs @@ -68,6 +68,36 @@ public sealed class ToolRegistry } } + /// + /// Creates a from a list of MCP tool descriptors. + /// + /// + /// Each tool is registered with a no-op handler (returning an empty string). + /// Override handlers by calling with the same name after + /// construction. + /// + /// MCP tool descriptors. 
+ /// A new registry populated from the MCP tool list. + public static ToolRegistry FromMcpTools(IEnumerable tools) + { + IReadOnlyDictionary emptySchema = new Dictionary + { + ["type"] = "object", + ["properties"] = new Dictionary(), + }; + var registry = new ToolRegistry(); + foreach (var tool in tools) + { + registry.Register( + new ToolDef( + Name: tool.Name, + Description: tool.Description ?? string.Empty, + InputSchema: tool.InputSchema ?? emptySchema), + _ => string.Empty); + } + return registry; + } + /// /// Registers a tool definition and its handler. /// diff --git a/tests/Temporalio.Extensions.ToolRegistry.Tests/ToolRegistryTests.cs b/tests/Temporalio.Extensions.ToolRegistry.Tests/ToolRegistryTests.cs index 1ef0e822..4f938a96 100644 --- a/tests/Temporalio.Extensions.ToolRegistry.Tests/ToolRegistryTests.cs +++ b/tests/Temporalio.Extensions.ToolRegistry.Tests/ToolRegistryTests.cs @@ -212,6 +212,41 @@ public void Checkpoint_RoundTrip_PreservesToolCallMessages() Assert.Equal("foo.cs", (string)issues[0]["file"]!); } + // ── FromMcpTools ───────────────────────────────────────────────────────── + + [Fact] + public void FromMcpTools_PopulatesRegistry() + { + var tools = new[] + { + new McpTool( + "read_file", + "Read a file", + new Dictionary + { + ["type"] = "object", + ["properties"] = new Dictionary + { + ["path"] = new Dictionary { ["type"] = "string" }, + }, + }), + new McpTool("list_dir", null, null), // null schema → empty object schema + }; + + var reg = ToolRegistry.FromMcpTools(tools); + var defs = reg.Definitions(); + + Assert.Equal(2, defs.Count); + Assert.Equal("read_file", defs[0].Name); + Assert.Equal("Read a file", defs[0].Description); + Assert.Equal("list_dir", defs[1].Name); + Assert.Equal("object", defs[1].InputSchema["type"]); // null schema defaulted + // no-op handler returns empty string + Assert.Equal( + string.Empty, + reg.Dispatch("read_file", new Dictionary { ["path"] = "/etc/hosts" })); + } + // ── Integration tests (skipped unless 
RUN_INTEGRATION_TESTS is set) ──── [Fact] From ff668c01b736d28242b7393eddb66b0cc3d3d959 Mon Sep 17 00:00:00 2001 From: lex00 <121451605+lex00@users.noreply.github.com> Date: Sun, 12 Apr 2026 23:23:33 -0600 Subject: [PATCH 3/4] fix: add is_error handling to AnthropicProvider, add provider error test - Set is_error=true on Anthropic tool result maps when a handler throws, matching the Anthropic API spec; OpenAI has no equivalent field - Add AnthropicProvider_HandlerError_SetsIsError test using an HttpListener- based mock server to verify is_error propagation without a real API key - Clarify RunToolLoopAsync system param docstring - Update README to clarify positioning vs Python/TypeScript framework plugins Co-Authored-By: Claude Sonnet 4.6 --- .../Providers/AnthropicProvider.cs | 11 ++- .../README.md | 2 + .../ToolRegistry.cs | 6 +- .../ToolRegistryTests.cs | 73 +++++++++++++++++++ 4 files changed, 87 insertions(+), 5 deletions(-) diff --git a/src/Temporalio.Extensions.ToolRegistry/Providers/AnthropicProvider.cs b/src/Temporalio.Extensions.ToolRegistry/Providers/AnthropicProvider.cs index 0aa96930..3277caea 100644 --- a/src/Temporalio.Extensions.ToolRegistry/Providers/AnthropicProvider.cs +++ b/src/Temporalio.Extensions.ToolRegistry/Providers/AnthropicProvider.cs @@ -144,6 +144,7 @@ public async Task RunTurnAsync( var id = (string)call["id"]!; var input = (Dictionary)call["input"]!; string result; + bool isError = false; try { result = registry.Dispatch(name, input); @@ -152,15 +153,21 @@ public async Task RunTurnAsync( catch (Exception e) { result = $"error: {e.Message}"; + isError = true; } #pragma warning restore CA1031 - toolResults.Add(new() + var toolResult = new Dictionary { ["type"] = "tool_result", ["tool_use_id"] = id, ["content"] = result, - }); + }; + if (isError) + { + toolResult["is_error"] = true; + } + toolResults.Add(toolResult); } newMessages.Add(new() { ["role"] = "user", ["content"] = toolResults }); return new(newMessages, Done: false); diff 
--git a/src/Temporalio.Extensions.ToolRegistry/README.md b/src/Temporalio.Extensions.ToolRegistry/README.md index fc53ed28..2a187890 100644 --- a/src/Temporalio.Extensions.ToolRegistry/README.md +++ b/src/Temporalio.Extensions.ToolRegistry/README.md @@ -13,6 +13,8 @@ A Temporal Activity is a function that Temporal monitors and retries automatical New to Temporal? → https://docs.temporal.io/develop +**Python or TypeScript user?** Those SDKs also ship framework-level integrations (`openai_agents`, `google_adk_agents`, `langgraph`, `@temporalio/ai-sdk`) for teams already using a specific agent framework. ToolRegistry is the equivalent story for direct Anthropic/OpenAI calls, and shares the same API surface across all six Temporal SDKs. + ## Install ```bash diff --git a/src/Temporalio.Extensions.ToolRegistry/ToolRegistry.cs b/src/Temporalio.Extensions.ToolRegistry/ToolRegistry.cs index 5a31026e..39bdcc68 100644 --- a/src/Temporalio.Extensions.ToolRegistry/ToolRegistry.cs +++ b/src/Temporalio.Extensions.ToolRegistry/ToolRegistry.cs @@ -34,9 +34,9 @@ public sealed class ToolRegistry /// LLM provider adapter. /// Tool registry. /// - /// System prompt. Accepted for API symmetry with other Temporal SDKs, but the system - /// prompt is captured by the provider at construction time and this parameter is not - /// forwarded to the provider. Pass the same value you used when constructing the provider. + /// System prompt. For API symmetry with other Temporal SDKs. The system prompt is + /// captured by the provider at construction time; this parameter is not forwarded. Pass + /// the same value you used when constructing the provider, or an empty string. /// /// Initial user prompt. /// Cancellation token. 
diff --git a/tests/Temporalio.Extensions.ToolRegistry.Tests/ToolRegistryTests.cs b/tests/Temporalio.Extensions.ToolRegistry.Tests/ToolRegistryTests.cs index 4f938a96..3ea7d152 100644 --- a/tests/Temporalio.Extensions.ToolRegistry.Tests/ToolRegistryTests.cs +++ b/tests/Temporalio.Extensions.ToolRegistry.Tests/ToolRegistryTests.cs @@ -1,5 +1,8 @@ using System; using System.Collections.Generic; +using System.Net; +using System.Net.Sockets; +using System.Text; using System.Text.Json; using System.Threading.Tasks; using Temporalio.Extensions.ToolRegistry; @@ -247,6 +250,76 @@ public void FromMcpTools_PopulatesRegistry() reg.Dispatch("read_file", new Dictionary { ["path"] = "/etc/hosts" })); } + // ── AnthropicProvider is_error / handler error tests ──────────────────── + + /// + /// Verifies that when a tool handler throws, the Anthropic tool result carries + /// is_error = true and the turn does not propagate the exception. + /// + [Fact] + public async Task AnthropicProvider_HandlerError_SetsIsError() + { + // Find a free port, then start an HttpListener on it. + int port; + using (var tmp = new TcpListener(IPAddress.Loopback, 0)) + { + tmp.Start(); + port = ((IPEndPoint)tmp.LocalEndpoint).Port; + } + + var prefix = $"http://localhost:{port}/"; + using var listener = new HttpListener(); + listener.Prefixes.Add(prefix); + listener.Start(); + + // Serve one request: the tool_use response. 
+ var serverTask = Task.Run(async () => + { + var ctx = await listener.GetContextAsync().ConfigureAwait(false); + ctx.Response.ContentType = "application/json"; + const string body = + "{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\"," + + "\"content\":[{\"type\":\"tool_use\",\"id\":\"c1\"," + + "\"name\":\"boom\",\"input\":{}}]," + + "\"model\":\"claude-sonnet-4-6\",\"stop_reason\":\"tool_use\"," + + "\"usage\":{\"input_tokens\":10,\"output_tokens\":5}}"; + var bytes = Encoding.UTF8.GetBytes(body); + ctx.Response.ContentLength64 = bytes.Length; + await ctx.Response.OutputStream.WriteAsync(bytes).ConfigureAwait(false); + ctx.Response.Close(); + }); + + var registry = new ToolRegistry(); + registry.Register( + new ToolDef("boom", "d", new Dictionary()), + _ => throw new InvalidOperationException("intentional failure")); + + using var provider = new AnthropicProvider( + new AnthropicConfig { ApiKey = "test-key", BaseUrl = new Uri(prefix) }, + registry, + "sys"); + + var messages = new List> + { + new() { ["role"] = "user", ["content"] = "go" }, + }; + + var result = await provider.RunTurnAsync(messages, registry.Definitions()); + await serverTask.ConfigureAwait(false); + listener.Stop(); + + Assert.False(result.Done); + Assert.Equal(2, result.NewMessages.Count); + + var toolResultMsg = result.NewMessages[1]; + Assert.Equal("user", (string)toolResultMsg["role"]!); + var toolResults = Assert.IsType>>(toolResultMsg["content"]); + Assert.Single(toolResults); + Assert.Equal("tool_result", (string)toolResults[0]["type"]!); + Assert.True(toolResults[0]["is_error"] is bool b && b, "is_error should be true"); + Assert.Contains("intentional failure", (string)toolResults[0]["content"]!); + } + // ── Integration tests (skipped unless RUN_INTEGRATION_TESTS is set) ──── [Fact] From 7d49364e0d8be306af972f7c07be253bfc0e3781 Mon Sep 17 00:00:00 2001 From: lex00 <121451605+lex00@users.noreply.github.com> Date: Mon, 13 Apr 2026 00:14:54 -0600 Subject: [PATCH 4/4] 
=?UTF-8?q?feat(ToolRegistry):=20rename=20Issues?= =?UTF-8?q?=E2=86=92Results,=20remove=20system=20param,=20timeout=20docs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rename AgenticSession.Issues → Results; update SessionCheckpoint JSON key 'issues' → 'results' and property Issues → Results - Remove unused system parameter from AgenticSession.RunToolLoopAsync and ToolRegistry.RunToolLoopAsync - Add ScheduleToCloseTimeout guidance to README - Update all test call sites (AgenticSessionTests, ToolRegistryTests) Co-Authored-By: Claude Sonnet 4.6 --- .../AgenticSession.cs | 22 ++++---- .../README.md | 52 ++++++++++++------- .../SessionCheckpoint.cs | 6 +-- .../ToolRegistry.cs | 6 --- .../AgenticSessionTests.cs | 10 ++-- .../ToolRegistryTests.cs | 20 ++++--- 6 files changed, 59 insertions(+), 57 deletions(-) diff --git a/src/Temporalio.Extensions.ToolRegistry/AgenticSession.cs b/src/Temporalio.Extensions.ToolRegistry/AgenticSession.cs index 020b44a3..d1184384 100644 --- a/src/Temporalio.Extensions.ToolRegistry/AgenticSession.cs +++ b/src/Temporalio.Extensions.ToolRegistry/AgenticSession.cs @@ -10,7 +10,7 @@ namespace Temporalio.Extensions.ToolRegistry { /// - /// Maintains conversation state (messages and issues) across multiple turns of a tool-calling + /// Maintains conversation state (messages and results) across multiple turns of a tool-calling /// loop, with heartbeat checkpointing for crash recovery. /// /// @@ -26,7 +26,7 @@ namespace Temporalio.Extensions.ToolRegistry public sealed class AgenticSession { private readonly List> messages = new(); - private readonly List> issues = new(); + private readonly List> results = new(); /// /// Gets the full conversation history. Append-only during a session. @@ -37,7 +37,7 @@ public sealed class AgenticSession /// Gets the accumulated application-level results from tool calls. Elements must be /// JSON-serializable for checkpoint storage. 
/// - public IList> Issues => issues; + public IList> Results => results; /// /// Runs inside an , restoring from a @@ -107,9 +107,9 @@ public static async Task RunWithSessionAsync( session.messages.AddRange(JsonElementConverter.MaterializeList(cp.Messages)); } - if (cp?.Issues?.Count > 0) + if (cp?.Results?.Count > 0) { - session.issues.AddRange(JsonElementConverter.MaterializeList(cp.Issues)); + session.results.AddRange(JsonElementConverter.MaterializeList(cp.Results)); } } } @@ -140,14 +140,12 @@ public static async Task RunWithSessionAsync( /// /// LLM provider adapter. /// Tool registry. - /// System prompt (passed to provider at construction time). /// Initial user prompt. /// Cancellation token. /// A task representing the asynchronous operation. public async Task RunToolLoopAsync( IProvider provider, ToolRegistry registry, - string system, string prompt, CancellationToken cancellationToken = default) { @@ -188,17 +186,17 @@ public async Task RunToolLoopAsync( /// public void Checkpoint(CancellationToken cancellationToken = default) { - // T10: validate all issues are JSON-serializable before heartbeating. - for (int i = 0; i < issues.Count; i++) + // Validate all results are JSON-serializable before heartbeating. + for (int i = 0; i < results.Count; i++) { try { - JsonSerializer.Serialize(issues[i]); + JsonSerializer.Serialize(results[i]); } catch (JsonException e) { throw new ApplicationFailureException( - $"AgenticSession: issues[{i}] is not JSON-serializable: {e.Message}. " + + $"AgenticSession: results[{i}] is not JSON-serializable: {e.Message}. 
" + "Store only Dictionary with JSON-serializable values.", nonRetryable: true); } @@ -207,7 +205,7 @@ public void Checkpoint(CancellationToken cancellationToken = default) var cp = new SessionCheckpoint { Messages = new(messages), - Issues = new(issues), + Results = new(results), }; ActivityExecutionContext.Current.Heartbeat(cp); cancellationToken.ThrowIfCancellationRequested(); diff --git a/src/Temporalio.Extensions.ToolRegistry/README.md b/src/Temporalio.Extensions.ToolRegistry/README.md index 2a187890..6d46acf1 100644 --- a/src/Temporalio.Extensions.ToolRegistry/README.md +++ b/src/Temporalio.Extensions.ToolRegistry/README.md @@ -35,7 +35,7 @@ using Temporalio.Extensions.ToolRegistry.Providers; [Activity] // Remove for standalone use — no worker needed public async Task> AnalyzeAsync(string prompt) { - var issues = new List(); + var results = new List(); var registry = new ToolRegistry(); registry.Register( @@ -53,7 +53,7 @@ public async Task> AnalyzeAsync(string prompt) }), inp => { - issues.Add((string)inp["description"]); + results.Add((string)inp["description"]); return Task.FromResult("recorded"); // this string is sent back to the LLM as the tool result }); @@ -62,8 +62,8 @@ public async Task> AnalyzeAsync(string prompt) registry, "You are a code reviewer. 
Call flag_issue for each problem you find."); - await ToolRegistry.RunToolLoopAsync(provider, registry, "", prompt); - return issues; + await ToolRegistry.RunToolLoopAsync(provider, registry, prompt); + return results; } ``` @@ -86,7 +86,7 @@ Model IDs are defined by the provider — see Anthropic or OpenAI docs for curre ```csharp var cfg = new OpenAIConfig { ApiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") }; var provider = new OpenAIProvider(cfg, registry, "your system prompt"); -await ToolRegistry.RunToolLoopAsync(provider, registry, "", prompt); +await ToolRegistry.RunToolLoopAsync(provider, registry, prompt); ``` ## Crash-safe agentic sessions @@ -99,7 +99,7 @@ For multi-turn LLM conversations that must survive activity retries, use [Activity] // Remove for standalone use — no worker needed public async Task> LongAnalysisAsync(string prompt) { - var issues = new List(); + var results = new List(); await AgenticSession.RunWithSessionAsync(async session => { @@ -108,7 +108,7 @@ public async Task> LongAnalysisAsync(string prompt) new ToolDefinition("flag", "...", new Dictionary { ["type"] = "object" }), inp => { - session.Issues.Add(inp); + session.Results.Add(inp); return Task.FromResult("ok"); // this string is sent back to the LLM as the tool result }); @@ -116,11 +116,11 @@ public async Task> LongAnalysisAsync(string prompt) new AnthropicConfig { ApiKey = Environment.GetEnvironmentVariable("ANTHROPIC_API_KEY") }, registry, "your system prompt"); - await session.RunToolLoopAsync(provider, registry, "your system prompt", prompt); - issues.AddRange(session.Issues.Cast()); // capture after loop completes + await session.RunToolLoopAsync(provider, registry, prompt); + results.AddRange(session.Results.Cast()); // capture after loop completes }); - return issues; + return results; } ``` @@ -142,7 +142,7 @@ public async Task TestAnalyze() MockResponse.Done("analysis complete") ).WithRegistry(registry); - var msgs = await 
ToolRegistry.RunToolLoopAsync(provider, registry, "sys", "analyze"); + var msgs = await ToolRegistry.RunToolLoopAsync(provider, registry, "analyze"); Assert.True(msgs.Count > 2); } ``` @@ -163,25 +163,25 @@ incur billing — expect a few cents per full test run. ## Storing application results -`session.Issues` accumulates application-level +`session.Results` accumulates application-level results during the tool loop. Elements are serialized to JSON inside each heartbeat -checkpoint — they must be plain maps/dicts with JSON-serializable values. A non-serializable -value raises a non-retryable `ApplicationError` at heartbeat time rather than silently -losing data on the next retry. +checkpoint — they must be plain dicts with JSON-serializable values. A non-serializable +value raises a non-retryable `ApplicationFailureException` at heartbeat time rather than +silently losing data on the next retry. ### Storing typed results Convert your domain type to a plain dict at the tool-call site and back after the session: ```csharp -record Issue(string Type, string File); +record Finding(string Type, string File); // Inside tool handler: -session.Issues.Add(new() { ["type"] = "smell", ["file"] = "Foo.cs" }); +session.Results.Add(new() { ["type"] = "smell", ["file"] = "Foo.cs" }); // After session (using System.Text.Json): -var issues = session.Issues - .Select(d => JsonSerializer.Deserialize<Issue>(JsonSerializer.Serialize(d))!) +var findings = session.Results + .Select(d => JsonSerializer.Deserialize<Finding>(JsonSerializer.Serialize(d))!)
.ToList(); ``` @@ -208,6 +208,20 @@ Recommended timeouts: | Standard (Claude 3.x, GPT-4o) | 30 s | | Reasoning (o1, o3, extended thinking) | 300 s | +### Activity-level timeout + +Set `ScheduleToCloseTimeout` on the activity options to bound the entire conversation: + +```csharp +await Workflow.ExecuteActivityAsync( + (MyActivities a) => a.LongAnalysisAsync(prompt), + new ActivityOptions { ScheduleToCloseTimeout = TimeSpan.FromMinutes(10) }); +``` + +The per-turn client timeout and `ScheduleToCloseTimeout` are complementary: +- Per-turn timeout fires if one LLM call hangs (protects against a single stuck turn) +- `ScheduleToCloseTimeout` bounds the entire conversation including all retries (protects against runaway multi-turn loops) + ## MCP integration `ToolRegistry.FromMcpTools` converts a sequence of `McpTool` records into a populated diff --git a/src/Temporalio.Extensions.ToolRegistry/SessionCheckpoint.cs b/src/Temporalio.Extensions.ToolRegistry/SessionCheckpoint.cs index 8ada7f04..76000c0b 100644 --- a/src/Temporalio.Extensions.ToolRegistry/SessionCheckpoint.cs +++ b/src/Temporalio.Extensions.ToolRegistry/SessionCheckpoint.cs @@ -22,9 +22,9 @@ internal sealed class SessionCheckpoint public List>? Messages { get; set; } = new(); /// - /// Gets or sets accumulated application-level issues from tool calls. + /// Gets or sets accumulated application-level results from tool calls. /// - [JsonPropertyName("issues")] - public List>? Issues { get; set; } = new(); + [JsonPropertyName("results")] + public List>? Results { get; set; } = new(); } } diff --git a/src/Temporalio.Extensions.ToolRegistry/ToolRegistry.cs b/src/Temporalio.Extensions.ToolRegistry/ToolRegistry.cs index 39bdcc68..2469d79b 100644 --- a/src/Temporalio.Extensions.ToolRegistry/ToolRegistry.cs +++ b/src/Temporalio.Extensions.ToolRegistry/ToolRegistry.cs @@ -33,18 +33,12 @@ public sealed class ToolRegistry /// /// LLM provider adapter. /// Tool registry. - /// - /// System prompt. 
For API symmetry with other Temporal SDKs. The system prompt is - /// captured by the provider at construction time; this parameter is not forwarded. Pass - /// the same value you used when constructing the provider, or an empty string. - /// /// Initial user prompt. /// Cancellation token. /// Full message history on completion. public static async Task>> RunToolLoopAsync( IProvider provider, ToolRegistry registry, - string system, string prompt, CancellationToken cancellationToken = default) { diff --git a/tests/Temporalio.Extensions.ToolRegistry.Tests/AgenticSessionTests.cs b/tests/Temporalio.Extensions.ToolRegistry.Tests/AgenticSessionTests.cs index 02060baa..5adb3bfb 100644 --- a/tests/Temporalio.Extensions.ToolRegistry.Tests/AgenticSessionTests.cs +++ b/tests/Temporalio.Extensions.ToolRegistry.Tests/AgenticSessionTests.cs @@ -35,7 +35,6 @@ await env.RunAsync(async () => await session.RunToolLoopAsync( provider, registry, - "sys", "go", ActivityExecutionContext.Current.CancellationToken); }); @@ -62,7 +61,6 @@ await env.RunAsync(async () => await session.RunToolLoopAsync( provider, registry, - "sys", "go", ActivityExecutionContext.Current.CancellationToken); }); @@ -82,7 +80,7 @@ await env.RunAsync(async () => { var session = new AgenticSession(); session.Messages.Add(new() { ["role"] = "user", ["content"] = "hello" }); - session.Issues.Add(new() { ["type"] = "bug" }); + session.Results.Add(new() { ["type"] = "bug" }); session.Checkpoint(); await Task.CompletedTask; }); @@ -92,7 +90,7 @@ await env.RunAsync(async () => var cp = capturedDetail![0] as SessionCheckpoint; Assert.NotNull(cp); Assert.Single(cp!.Messages!); - Assert.Single(cp.Issues!); + Assert.Single(cp.Results!); } [Fact] @@ -109,7 +107,7 @@ await env.RunAsync(async () => { await AgenticSession.RunWithSessionAsync(async session => { - await session.RunToolLoopAsync(provider, registry, "sys", "hello"); + await session.RunToolLoopAsync(provider, registry, "hello"); sessionMessages = 
new(session.Messages); }); }); @@ -135,7 +133,7 @@ public async Task RunWithSessionAsync_RestoresFromCheckpoint() }, }, }; - var cp = new SessionCheckpoint { Messages = checkpointMessages, Issues = new() }; + var cp = new SessionCheckpoint { Messages = checkpointMessages, Results = new() }; // Serialize checkpoint via DataConverter so it can be seeded into heartbeat details. var payload = await DataConverter.Default.ToPayloadAsync(cp); diff --git a/tests/Temporalio.Extensions.ToolRegistry.Tests/ToolRegistryTests.cs b/tests/Temporalio.Extensions.ToolRegistry.Tests/ToolRegistryTests.cs index 3ea7d152..f5772988 100644 --- a/tests/Temporalio.Extensions.ToolRegistry.Tests/ToolRegistryTests.cs +++ b/tests/Temporalio.Extensions.ToolRegistry.Tests/ToolRegistryTests.cs @@ -127,7 +127,7 @@ public async Task RunToolLoopAsync_SimpleExchange() }).WithRegistry(registry); var messages = await ToolRegistry.RunToolLoopAsync( - provider, registry, "system", "find issues"); + provider, registry, "find issues"); Assert.Single(flagged); Assert.Equal("broken", flagged[0]); @@ -142,7 +142,7 @@ public async Task RunToolLoopAsync_NoCalls_ReturnsDone() var provider = new MockProvider(new[] { MockResponse.Done("no tools needed") }); var messages = await ToolRegistry.RunToolLoopAsync( - provider, registry, "system", "hello"); + provider, registry, "hello"); Assert.True(messages.Count >= 2); Assert.Equal("user", (string)messages[0]["role"]!); @@ -177,7 +177,7 @@ public void Checkpoint_RoundTrip_PreservesToolCallMessages() ["role"] = "assistant", ["tool_calls"] = toolCallsInMemory, }; - var issueInMemory = new Dictionary + var resultInMemory = new Dictionary { ["type"] = "smell", ["file"] = "foo.cs", @@ -185,14 +185,14 @@ public void Checkpoint_RoundTrip_PreservesToolCallMessages() var cp = new SessionCheckpoint { Messages = new() { assistantMsg }, - Issues = new() { issueInMemory }, + Results = new() { resultInMemory }, }; // Simulate Temporal heartbeat round-trip via JSON serialization. 
var json = JsonSerializer.Serialize(cp); var restored = JsonSerializer.Deserialize(json)!; var messages = JsonElementConverter.MaterializeList(restored.Messages!); - var issues = JsonElementConverter.MaterializeList(restored.Issues!); + var results = JsonElementConverter.MaterializeList(restored.Results!); // Verify assistant message role survived. Assert.Equal("assistant", (string)messages[0]["role"]!); @@ -209,10 +209,10 @@ public void Checkpoint_RoundTrip_PreservesToolCallMessages() var fn = Assert.IsType>(toolCallsList[0]["function"]); Assert.Equal("my_tool", (string)fn["name"]!); - // Issues must survive the round-trip. - Assert.Single(issues); - Assert.Equal("smell", (string)issues[0]["type"]!); - Assert.Equal("foo.cs", (string)issues[0]["file"]!); + // Results must survive the round-trip. + Assert.Single(results); + Assert.Equal("smell", (string)results[0]["type"]!); + Assert.Equal("foo.cs", (string)results[0]["file"]!); } // ── FromMcpTools ───────────────────────────────────────────────────────── @@ -342,7 +342,6 @@ public async Task Integration_Anthropic() await ToolRegistry.RunToolLoopAsync( provider, registry, - string.Empty, "Please call the record tool with value='hello'."); Assert.Contains("hello", collected); @@ -370,7 +369,6 @@ public async Task Integration_OpenAI() await ToolRegistry.RunToolLoopAsync( provider, registry, - string.Empty, "Please call the record tool with value='hello'."); Assert.Contains("hello", collected);