Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions graphiti_core/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,41 @@ def validate_group_ids(group_ids: list[str] | None) -> bool:
return True


def normalize_group_ids(group_ids: str | list[str] | None) -> list[str] | None:
"""
Normalize group_ids parameter to list[str] | None.

Accepts str | list[str] | None and returns list[str] | None.
- Returns None if input is None
- Converts string to single-element list
- Validates each group_id using validate_group_id
- Returns the normalized list

Args:
group_ids: The group_ids to normalize, can be string, list of strings, or None

Returns:
Normalized list of group_ids or None

Raises:
GroupIdValidationError: If any group_id contains invalid characters
"""
if group_ids is None:
return None

# Convert string to single-element list
if isinstance(group_ids, str):
normalized = [group_ids]
else:
normalized = group_ids

# Validate each group_id
for group_id in normalized:
validate_group_id(group_id)

return normalized


def validate_node_labels(node_labels: list[str] | None) -> bool:
"""Validate that node labels are safe to interpolate into Cypher label expressions."""

Expand Down
56 changes: 37 additions & 19 deletions mcp_server/src/graphiti_mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from dotenv import load_dotenv
from graphiti_core import Graphiti
from graphiti_core.edges import EntityEdge
from graphiti_core.helpers import normalize_group_ids
from graphiti_core.nodes import EpisodeType, EpisodicNode
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
Expand Down Expand Up @@ -407,15 +408,16 @@ async def add_memory(
@mcp.tool()
async def search_nodes(
query: str,
group_ids: list[str] | None = None,
group_ids: str | list[str] | None = None,
max_nodes: int = 10,
entity_types: list[str] | None = None,
) -> NodeSearchResponse | ErrorResponse:
"""Search for nodes in the graph memory.

Args:
query: The search query
group_ids: Optional list of group IDs to filter results
group_ids: Optional group ID or list of group IDs to filter results.
Can be a string (single group ID) or list of strings.
max_nodes: Maximum number of nodes to return (default: 10)
entity_types: Optional list of entity type names to filter by
"""
Expand All @@ -427,10 +429,13 @@ async def search_nodes(
try:
client = await graphiti_service.get_client()

# Use the provided group_ids or fall back to the default from config if none provided
# Normalize group_ids parameter (handles str, list[str], or None)
normalized_group_ids = normalize_group_ids(group_ids)

# Use the normalized group_ids or fall back to the default from config if none provided
effective_group_ids = (
group_ids
if group_ids is not None
normalized_group_ids
if normalized_group_ids is not None
else [config.graphiti.group_id]
if config.graphiti.group_id
else []
Expand Down Expand Up @@ -487,15 +492,16 @@ async def search_nodes(
@mcp.tool()
async def search_memory_facts(
query: str,
group_ids: list[str] | None = None,
group_ids: str | list[str] | None = None,
max_facts: int = 10,
center_node_uuid: str | None = None,
) -> FactSearchResponse | ErrorResponse:
"""Search the graph memory for relevant facts.

Args:
query: The search query
group_ids: Optional list of group IDs to filter results
group_ids: Optional group ID or list of group IDs to filter results.
Can be a string (single group ID) or list of strings.
max_facts: Maximum number of facts to return (default: 10)
center_node_uuid: Optional UUID of a node to center the search around
"""
Expand All @@ -511,10 +517,13 @@ async def search_memory_facts(

client = await graphiti_service.get_client()

# Use the provided group_ids or fall back to the default from config if none provided
# Normalize group_ids parameter (handles str, list[str], or None)
normalized_group_ids = normalize_group_ids(group_ids)

# Use the normalized group_ids or fall back to the default from config if none provided
effective_group_ids = (
group_ids
if group_ids is not None
normalized_group_ids
if normalized_group_ids is not None
else [config.graphiti.group_id]
if config.graphiti.group_id
else []
Expand Down Expand Up @@ -619,13 +628,14 @@ async def get_entity_edge(uuid: str) -> dict[str, Any] | ErrorResponse:

@mcp.tool()
async def get_episodes(
group_ids: list[str] | None = None,
group_ids: str | list[str] | None = None,
max_episodes: int = 10,
) -> EpisodeSearchResponse | ErrorResponse:
"""Get episodes from the graph memory.

Args:
group_ids: Optional list of group IDs to filter results
group_ids: Optional group ID or list of group IDs to filter results.
Can be a string (single group ID) or list of strings.
max_episodes: Maximum number of episodes to return (default: 10)
"""
global graphiti_service
Expand All @@ -636,10 +646,13 @@ async def get_episodes(
try:
client = await graphiti_service.get_client()

# Use the provided group_ids or fall back to the default from config if none provided
# Normalize group_ids parameter (handles str, list[str], or None)
normalized_group_ids = normalize_group_ids(group_ids)

# Use the normalized group_ids or fall back to the default from config if none provided
effective_group_ids = (
group_ids
if group_ids is not None
normalized_group_ids
if normalized_group_ids is not None
else [config.graphiti.group_id]
if config.graphiti.group_id
else []
Expand Down Expand Up @@ -686,11 +699,13 @@ async def get_episodes(


@mcp.tool()
async def clear_graph(group_ids: list[str] | None = None) -> SuccessResponse | ErrorResponse:
async def clear_graph(group_ids: str | list[str] | None = None) -> SuccessResponse | ErrorResponse:
"""Clear all data from the graph for specified group IDs.

Args:
group_ids: Optional list of group IDs to clear. If not provided, clears the default group.
group_ids: Optional group ID or list of group IDs to clear.
Can be a string (single group ID) or list of strings.
If not provided, clears the default group.
"""
global graphiti_service

Expand All @@ -700,9 +715,12 @@ async def clear_graph(group_ids: list[str] | None = None) -> SuccessResponse | E
try:
client = await graphiti_service.get_client()

# Use the provided group_ids or fall back to the default from config if none provided
# Normalize group_ids parameter (handles str, list[str], or None)
normalized_group_ids = normalize_group_ids(group_ids)

# Use the normalized group_ids or fall back to the default from config if none provided
effective_group_ids = (
group_ids or [config.graphiti.group_id] if config.graphiti.group_id else []
normalized_group_ids or [config.graphiti.group_id] if config.graphiti.group_id else []
)

if not effective_group_ids:
Expand Down
Loading
Loading