diff --git a/graphiti_core/helpers.py b/graphiti_core/helpers.py index 19fdf2bcb..e33c8e7c0 100644 --- a/graphiti_core/helpers.py +++ b/graphiti_core/helpers.py @@ -171,6 +171,41 @@ def validate_group_ids(group_ids: list[str] | None) -> bool: return True +def normalize_group_ids(group_ids: str | list[str] | None) -> list[str] | None: + """ + Normalize group_ids parameter to list[str] | None. + + Accepts str | list[str] | None and returns list[str] | None. + - Returns None if input is None + - Converts string to single-element list + - Validates each group_id using validate_group_id + - Returns the normalized list + + Args: + group_ids: The group_ids to normalize, can be string, list of strings, or None + + Returns: + Normalized list of group_ids or None + + Raises: + GroupIdValidationError: If any group_id contains invalid characters + """ + if group_ids is None: + return None + + # Convert string to single-element list + if isinstance(group_ids, str): + normalized = [group_ids] + else: + normalized = group_ids + + # Validate each group_id + for group_id in normalized: + validate_group_id(group_id) + + return normalized + + def validate_node_labels(node_labels: list[str] | None) -> bool: """Validate that node labels are safe to interpolate into Cypher label expressions.""" diff --git a/mcp_server/src/graphiti_mcp_server.py b/mcp_server/src/graphiti_mcp_server.py index 833bc5d93..f6b23afaf 100644 --- a/mcp_server/src/graphiti_mcp_server.py +++ b/mcp_server/src/graphiti_mcp_server.py @@ -14,6 +14,7 @@ from dotenv import load_dotenv from graphiti_core import Graphiti from graphiti_core.edges import EntityEdge +from graphiti_core.helpers import normalize_group_ids from graphiti_core.nodes import EpisodeType, EpisodicNode from graphiti_core.search.search_filters import SearchFilters from graphiti_core.utils.maintenance.graph_data_operations import clear_data @@ -407,7 +408,7 @@ async def add_memory( @mcp.tool() async def search_nodes( query: str, - group_ids: list[str] 
| None = None, + group_ids: str | list[str] | None = None, max_nodes: int = 10, entity_types: list[str] | None = None, ) -> NodeSearchResponse | ErrorResponse: @@ -415,7 +416,8 @@ async def search_nodes( Args: query: The search query - group_ids: Optional list of group IDs to filter results + group_ids: Optional group ID or list of group IDs to filter results. + Can be a string (single group ID) or list of strings. max_nodes: Maximum number of nodes to return (default: 10) entity_types: Optional list of entity type names to filter by """ @@ -427,10 +429,13 @@ async def search_nodes( try: client = await graphiti_service.get_client() - # Use the provided group_ids or fall back to the default from config if none provided + # Normalize group_ids parameter (handles str, list[str], or None) + normalized_group_ids = normalize_group_ids(group_ids) + + # Use the normalized group_ids or fall back to the default from config if none provided effective_group_ids = ( - group_ids - if group_ids is not None + normalized_group_ids + if normalized_group_ids is not None else [config.graphiti.group_id] if config.graphiti.group_id else [] @@ -487,7 +492,7 @@ async def search_nodes( @mcp.tool() async def search_memory_facts( query: str, - group_ids: list[str] | None = None, + group_ids: str | list[str] | None = None, max_facts: int = 10, center_node_uuid: str | None = None, ) -> FactSearchResponse | ErrorResponse: @@ -495,7 +500,8 @@ async def search_memory_facts( Args: query: The search query - group_ids: Optional list of group IDs to filter results + group_ids: Optional group ID or list of group IDs to filter results. + Can be a string (single group ID) or list of strings. 
max_facts: Maximum number of facts to return (default: 10) center_node_uuid: Optional UUID of a node to center the search around """ @@ -511,10 +517,13 @@ async def search_memory_facts( client = await graphiti_service.get_client() - # Use the provided group_ids or fall back to the default from config if none provided + # Normalize group_ids parameter (handles str, list[str], or None) + normalized_group_ids = normalize_group_ids(group_ids) + + # Use the normalized group_ids or fall back to the default from config if none provided effective_group_ids = ( - group_ids - if group_ids is not None + normalized_group_ids + if normalized_group_ids is not None else [config.graphiti.group_id] if config.graphiti.group_id else [] @@ -619,13 +628,14 @@ async def get_entity_edge(uuid: str) -> dict[str, Any] | ErrorResponse: @mcp.tool() async def get_episodes( - group_ids: list[str] | None = None, + group_ids: str | list[str] | None = None, max_episodes: int = 10, ) -> EpisodeSearchResponse | ErrorResponse: """Get episodes from the graph memory. Args: - group_ids: Optional list of group IDs to filter results + group_ids: Optional group ID or list of group IDs to filter results. + Can be a string (single group ID) or list of strings. 
max_episodes: Maximum number of episodes to return (default: 10) """ global graphiti_service @@ -636,10 +646,13 @@ async def get_episodes( try: client = await graphiti_service.get_client() - # Use the provided group_ids or fall back to the default from config if none provided + # Normalize group_ids parameter (handles str, list[str], or None) + normalized_group_ids = normalize_group_ids(group_ids) + + # Use the normalized group_ids or fall back to the default from config if none provided effective_group_ids = ( - group_ids - if group_ids is not None + normalized_group_ids + if normalized_group_ids is not None else [config.graphiti.group_id] if config.graphiti.group_id else [] @@ -686,11 +699,13 @@ async def get_episodes( @mcp.tool() -async def clear_graph(group_ids: list[str] | None = None) -> SuccessResponse | ErrorResponse: +async def clear_graph(group_ids: str | list[str] | None = None) -> SuccessResponse | ErrorResponse: """Clear all data from the graph for specified group IDs. Args: - group_ids: Optional list of group IDs to clear. If not provided, clears the default group. + group_ids: Optional group ID or list of group IDs to clear. + Can be a string (single group ID) or list of strings. + If not provided, clears the default group. 
""" global graphiti_service @@ -700,9 +715,12 @@ async def clear_graph(group_ids: list[str] | None = None) -> SuccessResponse | E try: client = await graphiti_service.get_client() - # Use the provided group_ids or fall back to the default from config if none provided + # Normalize group_ids parameter (handles str, list[str], or None) + normalized_group_ids = normalize_group_ids(group_ids) + + # Use the normalized group_ids or fall back to the default from config if none provided effective_group_ids = ( - group_ids or [config.graphiti.group_id] if config.graphiti.group_id else [] + normalized_group_ids or [config.graphiti.group_id] if config.graphiti.group_id else [] ) if not effective_group_ids: diff --git a/mcp_server/tests/test_comprehensive_integration.py b/mcp_server/tests/test_comprehensive_integration.py index da1ed7e67..e95f6ffaf 100644 --- a/mcp_server/tests/test_comprehensive_integration.py +++ b/mcp_server/tests/test_comprehensive_integration.py @@ -340,6 +340,80 @@ async def test_hybrid_search(self): assert metric.success + @pytest.mark.asyncio + async def test_scalar_group_ids_acceptance(self): + """Test that search tools accept both scalar string and list inputs for group_ids.""" + async with GraphitiTestClient() as client: + # First add some test data + await client.call_tool_with_metrics( + 'add_memory', + { + 'name': 'Test Data for Group IDs', + 'episode_body': 'This is test data for verifying scalar group_ids acceptance.', + 'source': 'text', + 'source_description': 'test', + 'group_id': client.test_group_id, + }, + ) + + # Wait for processing + await client.wait_for_episode_processing() + + # Test 1: search_memory_facts with scalar string group_ids + result1, metric1 = await client.call_tool_with_metrics( + 'search_memory_facts', + { + 'query': 'test data', + 'group_ids': client.test_group_id, # Scalar string + 'max_facts': 5, + }, + ) + assert metric1.success, f"search_memory_facts failed with scalar group_ids: {metric1.details}" + + # Test 2: 
search_memory_facts with list group_ids + result2, metric2 = await client.call_tool_with_metrics( + 'search_memory_facts', + { + 'query': 'test data', + 'group_ids': [client.test_group_id], # List + 'max_facts': 5, + }, + ) + assert metric2.success, f"search_memory_facts failed with list group_ids: {metric2.details}" + + # Test 3: search_memory_facts with None group_ids + result3, metric3 = await client.call_tool_with_metrics( + 'search_memory_facts', + { + 'query': 'test data', + 'group_ids': None, # None + 'max_facts': 5, + }, + ) + assert metric3.success, f"search_memory_facts failed with None group_ids: {metric3.details}" + + # Test 4: search_memory_nodes with scalar string group_ids + result4, metric4 = await client.call_tool_with_metrics( + 'search_memory_nodes', + { + 'query': 'test data', + 'group_ids': client.test_group_id, # Scalar string + 'max_nodes': 5, + }, + ) + assert metric4.success, f"search_memory_nodes failed with scalar group_ids: {metric4.details}" + + # Test 5: search_memory_nodes with list group_ids + result5, metric5 = await client.call_tool_with_metrics( + 'search_memory_nodes', + { + 'query': 'test data', + 'group_ids': [client.test_group_id], # List + 'max_nodes': 5, + }, + ) + assert metric5.success, f"search_memory_nodes failed with list group_ids: {metric5.details}" + class TestEpisodeManagement: """Test episode lifecycle operations.""" @@ -372,6 +446,57 @@ async def test_get_episodes_pagination(self): episodes = json.loads(result) if isinstance(result, str) else result assert len(episodes.get('episodes', [])) <= 3 + @pytest.mark.asyncio + async def test_get_episodes_scalar_group_ids(self): + """Test that get_episodes accepts both scalar string and list inputs for group_ids.""" + async with GraphitiTestClient() as client: + # Add test episodes + for i in range(3): + await client.call_tool_with_metrics( + 'add_memory', + { + 'name': f'Episode for Group IDs Test {i}', + 'episode_body': f'This is episode {i} for testing scalar 
group_ids.', + 'source': 'text', + 'source_description': 'test', + 'group_id': client.test_group_id, + }, + ) + + await client.wait_for_episode_processing(expected_count=3) + + # Test 1: get_episodes with scalar string group_ids + result1, metric1 = await client.call_tool_with_metrics( + 'get_episodes', + { + 'group_ids': client.test_group_id, # Scalar string + 'max_episodes': 5, + }, + ) + assert metric1.success, f"get_episodes failed with scalar group_ids: {metric1.details}" + + # Test 2: get_episodes with list group_ids + result2, metric2 = await client.call_tool_with_metrics( + 'get_episodes', + { + 'group_ids': [client.test_group_id], # List + 'max_episodes': 5, + }, + ) + assert metric2.success, f"get_episodes failed with list group_ids: {metric2.details}" + + # Test 3: get_episodes with None group_ids + result3, metric3 = await client.call_tool_with_metrics( + 'get_episodes', + { + 'group_ids': None, # None + 'max_episodes': 5, + }, + ) + # Note: get_episodes with None group_ids might return empty results or use default + # We just verify it doesn't crash + assert metric3.success, f"get_episodes failed with None group_ids: {metric3.details}" + @pytest.mark.asyncio async def test_delete_episode(self): """Test deleting specific episodes.""" @@ -437,6 +562,68 @@ async def test_get_entity_edge(self): # Note: This test assumes edges are created between entities # Actual edge retrieval would require valid edge UUIDs + @pytest.mark.asyncio + async def test_clear_graph_scalar_group_ids(self): + """Test that clear_graph accepts both scalar string and list inputs for group_ids.""" + async with GraphitiTestClient() as client: + # First add some test data to clear + await client.call_tool_with_metrics( + 'add_memory', + { + 'name': 'Data to be cleared', + 'episode_body': 'This data will be cleared by the clear_graph test.', + 'source': 'text', + 'source_description': 'test', + 'group_id': client.test_group_id, + }, + ) + + await client.wait_for_episode_processing() + + 
# Test 1: clear_graph with scalar string group_ids + result1, metric1 = await client.call_tool_with_metrics( + 'clear_graph', + { + 'group_ids': client.test_group_id, # Scalar string + }, + ) + assert metric1.success, f"clear_graph failed with scalar group_ids: {metric1.details}" + + # Test 2: clear_graph with list group_ids + # First add more data since we just cleared it + await client.call_tool_with_metrics( + 'add_memory', + { + 'name': 'More data to be cleared', + 'episode_body': 'This is more data for the list group_ids test.', + 'source': 'text', + 'source_description': 'test', + 'group_id': client.test_group_id, + }, + ) + + await client.wait_for_episode_processing() + + result2, metric2 = await client.call_tool_with_metrics( + 'clear_graph', + { + 'group_ids': [client.test_group_id], # List + }, + ) + assert metric2.success, f"clear_graph failed with list group_ids: {metric2.details}" + + # Test 3: clear_graph with None group_ids + # This should clear the default group or return an error + result3, metric3 = await client.call_tool_with_metrics( + 'clear_graph', + { + 'group_ids': None, # None + }, + ) + # Note: clear_graph with None group_ids might clear default group or return error + # We just verify it doesn't crash + assert metric3.success, f"clear_graph failed with None group_ids: {metric3.details}" + @pytest.mark.asyncio async def test_delete_entity_edge(self): """Test deleting entity edges.""" @@ -460,6 +647,58 @@ async def test_invalid_tool_arguments(self): assert not metric.success assert 'error' in str(metric.details).lower() + @pytest.mark.asyncio + async def test_invalid_group_ids_validation(self): + """Test that invalid group_ids still raise appropriate validation errors.""" + async with GraphitiTestClient() as client: + # Test invalid group_ids with special characters (scalar) + result1, metric1 = await client.call_tool_with_metrics( + 'search_memory_facts', + { + 'query': 'test', + 'group_ids': 'invalid@group!id', # Invalid scalar with 
special chars + 'max_facts': 5, + }, + ) + # This should fail with validation error + assert not metric1.success, "search_memory_facts should fail with invalid scalar group_ids" + assert 'invalid' in str(metric1.details).lower() or 'validation' in str(metric1.details).lower() + + # Test invalid group_ids with special characters (in list) + result2, metric2 = await client.call_tool_with_metrics( + 'search_memory_nodes', + { + 'query': 'test', + 'group_ids': ['valid_group', 'invalid@group!id'], # List with one invalid + 'max_nodes': 5, + }, + ) + # This should also fail with validation error + assert not metric2.success, "search_memory_nodes should fail with invalid group_ids in list" + assert 'invalid' in str(metric2.details).lower() or 'validation' in str(metric2.details).lower() + + # Test empty string group_ids (scalar) - should be valid based on validate_group_id + result3, metric3 = await client.call_tool_with_metrics( + 'get_episodes', + { + 'group_ids': '', # Empty string - should be valid + 'max_episodes': 5, + }, + ) + # Empty string should be valid (validate_group_id returns True for empty string) + # We just verify it doesn't crash with validation error + + # Test invalid group_ids with clear_graph + result4, metric4 = await client.call_tool_with_metrics( + 'clear_graph', + { + 'group_ids': 'group#with#hash', # Invalid scalar + }, + ) + # This should fail with validation error + assert not metric4.success, "clear_graph should fail with invalid scalar group_ids" + assert 'invalid' in str(metric4.details).lower() or 'validation' in str(metric4.details).lower() + @pytest.mark.asyncio async def test_timeout_handling(self): """Test timeout handling for long operations.""" diff --git a/server/graph_service/dto/retrieve.py b/server/graph_service/dto/retrieve.py index b75c48c9f..f5bb15fb5 100644 --- a/server/graph_service/dto/retrieve.py +++ b/server/graph_service/dto/retrieve.py @@ -1,16 +1,36 @@ from datetime import datetime, timezone -from pydantic import 
# Request body for a memory search.  `group_ids` accepts either one group id
# or a list of them; the validator below folds both spellings into
# list[str] | None before any handler sees the value.
class SearchQuery(BaseModel):
    group_ids: str | list[str] | None = Field(
        None, description='The group id or list of group ids for the memories to search. Can be a string (single group ID) or list of strings.'
    )
    query: str
    max_facts: int = Field(default=10, description='The maximum number of facts to retrieve')

    @field_validator('group_ids')
    @classmethod
    def normalize_group_ids(cls, v: str | list[str] | None) -> list[str] | None:
        """Fold a scalar group id into a one-element list.

        None passes through untouched; a list is returned as-is.
        """
        if v is None:
            return None
        return [v] if isinstance(v, str) else v