diff --git a/memory/token_buffer.go b/memory/token_buffer.go index 81592ea8d..22229a2bc 100644 --- a/memory/token_buffer.go +++ b/memory/token_buffer.go @@ -90,6 +90,11 @@ func (tb *ConversationTokenBuffer) Clear(ctx context.Context) error { return tb.ConversationBuffer.Clear(ctx) } +// GetMemoryKey uses ConversationBuffer method for getting the memory key. +func (tb *ConversationTokenBuffer) GetMemoryKey(ctx context.Context) string { + return tb.ConversationBuffer.GetMemoryKey(ctx) +} + func (tb *ConversationTokenBuffer) getNumTokensFromMessages(ctx context.Context) (int, error) { messages, err := tb.ChatHistory.Messages(ctx) if err != nil { diff --git a/memory/token_buffer_test.go b/memory/token_buffer_test.go index d6bcc86ce..bd0682e56 100644 --- a/memory/token_buffer_test.go +++ b/memory/token_buffer_test.go @@ -108,3 +108,18 @@ func TestTokenBufferMemoryWithPreLoadedHistory(t *testing.T) { expected := map[string]any{"history": "Human: bar\nAI: foo"} assert.Equal(t, expected, result) } + +func TestTokenBufferMemoryGetMemoryKey(t *testing.T) { + t.Parallel() + ctx := context.Background() + + llm := newTestOpenAIClient(t) + + // Test with default memory key + m1 := NewConversationTokenBuffer(llm, 2000) + assert.Equal(t, "history", m1.GetMemoryKey(ctx)) + + // Test with custom memory key + m2 := NewConversationTokenBuffer(llm, 2000, WithMemoryKey("custom_key")) + assert.Equal(t, "custom_key", m2.GetMemoryKey(ctx)) +} diff --git a/memory/window_buffer.go b/memory/window_buffer.go index fe67baa35..7a2022d2a 100644 --- a/memory/window_buffer.go +++ b/memory/window_buffer.go @@ -100,3 +100,8 @@ func (wb *ConversationWindowBuffer) cutMessages(message []llms.ChatMessage) ([]l func (wb *ConversationWindowBuffer) Clear(ctx context.Context) error { return wb.ConversationBuffer.Clear(ctx) } + +// GetMemoryKey uses ConversationBuffer method for getting the memory key. +func (wb *ConversationWindowBuffer) GetMemoryKey(ctx context.Context) string { + return wb.ConversationBuffer.GetMemoryKey(ctx) +} diff --git a/memory/window_buffer_test.go b/memory/window_buffer_test.go index 07c2b7db0..7b6bb418a 100644 --- a/memory/window_buffer_test.go +++ b/memory/window_buffer_test.go @@ -101,6 +101,19 @@ func TestWindowBufferMemoryWithPreLoadedHistory(t *testing.T) { assert.Equal(t, expected, result) } +func TestWindowBufferMemoryGetMemoryKey(t *testing.T) { + t.Parallel() + ctx := context.Background() + + // Test with default memory key + m1 := NewConversationWindowBuffer(2) + assert.Equal(t, "history", m1.GetMemoryKey(ctx)) + + // Test with custom memory key + m2 := NewConversationWindowBuffer(2, WithMemoryKey("custom_key")) + assert.Equal(t, "custom_key", m2.GetMemoryKey(ctx)) +} + func TestConversationWindowBuffer_cutMessages(t *testing.T) { t.Parallel() type fields struct {