Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions graphiti_core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -1417,14 +1417,24 @@ async def add_episode_bulk(

@handle_multiple_group_ids
async def build_communities(
self, group_ids: list[str] | None = None, driver: GraphDriver | None = None
self,
group_ids: list[str] | None = None,
driver: GraphDriver | None = None,
sample_size: int | None = None,
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
"""
Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
the content of these communities.
----------
group_ids : list[str] | None
Optional. Create communities only for the listed group_ids. If blank the entire graph will be used.
sample_size : int | None
Optional. If set, each community's LLM summary is built from only
the top-K most representative members (highest in-community
weighted degree, then longest summary). Dramatically reduces LLM
cost on large graphs — without sampling, summary cost grows with
total node count; with sampling it grows with the number of
communities. Recommended for graphs >10k nodes.
"""
if driver is None:
driver = self.clients.driver
Expand All @@ -1433,7 +1443,7 @@ async def build_communities(
await remove_communities(driver)

community_nodes, community_edges = await build_communities(
driver, self.llm_client, group_ids
driver, self.llm_client, group_ids, sample_size=sample_size
)

await semaphore_gather(
Expand Down
196 changes: 161 additions & 35 deletions graphiti_core/utils/maintenance/community_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,75 @@ class Neighbor(BaseModel):
edge_count: int


async def _build_group_projection(
    driver: GraphDriver, group_id: str
) -> dict[str, list[Neighbor]]:
    """Build the RELATES_TO neighbor map for every entity in one group.

    One query is issued per entity in the group. Each query counts the
    RELATES_TO connections between that entity and every other entity that
    carries the same ``group_id``.

    Args:
        driver: Graph driver used to run the queries. Its ``provider``
            decides which MATCH pattern is used.
        group_id: Group whose entities are projected.

    Returns:
        A mapping from each entity uuid to a list of ``Neighbor`` records
        (neighbor uuid plus edge count). An entity for which the query
        returns no rows maps to an empty list.
    """
    entities = await EntityNode.get_by_group_ids(driver, [group_id])

    # The query text does not depend on the entity being processed, so it is
    # assembled once here rather than once per loop iteration.
    if driver.provider == GraphProvider.KUZU:
        # Kuzu pattern: the relationship is reached through an intermediate
        # RelatesToNode_ node, which is what gets bound to `e` and counted.
        match_clause = (
            '\n'
            'MATCH (n:Entity {group_id: $group_id, uuid: $uuid})'
            '-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-'
            '(m: Entity {group_id: $group_id})'
            '\n'
        )
    else:
        match_clause = (
            '\n'
            'MATCH (n:Entity {group_id: $group_id, uuid: $uuid})'
            '-[e:RELATES_TO]-'
            '(m: Entity {group_id: $group_id})'
            '\n'
        )
    aggregate_clause = (
        '\n'
        'WITH count(e) AS count, m.uuid AS uuid\n'
        'RETURN\n'
        'uuid,\n'
        'count\n'
    )
    neighbor_query = match_clause + aggregate_clause

    neighbor_map: dict[str, list[Neighbor]] = {}
    for entity in entities:
        rows, _, _ = await driver.execute_query(
            neighbor_query,
            uuid=entity.uuid,
            group_id=group_id,
        )
        neighbor_map[entity.uuid] = [
            Neighbor(node_uuid=row['uuid'], edge_count=row['count']) for row in rows
        ]

    return neighbor_map


async def get_community_clusters(
driver: GraphDriver, group_ids: list[str] | None
) -> list[list[EntityNode]]:
driver: GraphDriver,
group_ids: list[str] | None,
return_projection: bool = False,
) -> list[list[EntityNode]] | tuple[list[list[EntityNode]], dict[str, list[Neighbor]]]:
"""Compute community clusters via label propagation.

Args:
driver: Graph driver.
group_ids: Optional list of group ids to scope clustering. If None,
all groups are used.
return_projection: When True, also return the combined projection
(uuid → neighbors with edge counts) so callers can compute
in-community degrees without a second pass over the graph.

Returns:
By default, just the list of clusters (each a list of EntityNode).
When return_projection=True, returns (clusters, projection) tuple.
"""
if driver.graph_operations_interface:
try:
return await driver.graph_operations_interface.get_community_clusters(driver, group_ids)
clusters = await driver.graph_operations_interface.get_community_clusters(
driver, group_ids
)
if return_projection:
return clusters, {}
return clusters
except NotImplementedError:
pass

community_clusters: list[list[EntityNode]] = []
combined_projection: dict[str, list[Neighbor]] = {}

if group_ids is None:
group_id_values, _, _ = await driver.execute_query(
Expand All @@ -51,31 +110,9 @@ async def get_community_clusters(
group_ids = group_id_values[0]['group_ids'] if group_id_values else []

for group_id in group_ids:
projection: dict[str, list[Neighbor]] = {}
nodes = await EntityNode.get_by_group_ids(driver, [group_id])
for node in nodes:
match_query = """
MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[e:RELATES_TO]-(m: Entity {group_id: $group_id})
"""
if driver.provider == GraphProvider.KUZU:
match_query = """
MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(m: Entity {group_id: $group_id})
"""
records, _, _ = await driver.execute_query(
match_query
+ """
WITH count(e) AS count, m.uuid AS uuid
RETURN
uuid,
count
""",
uuid=node.uuid,
group_id=group_id,
)

projection[node.uuid] = [
Neighbor(node_uuid=record['uuid'], edge_count=record['count']) for record in records
]
projection = await _build_group_projection(driver, group_id)
if return_projection:
combined_projection.update(projection)

cluster_uuids = label_propagation(projection)

Expand All @@ -87,6 +124,8 @@ async def get_community_clusters(
)
)

if return_projection:
return community_clusters, combined_projection
return community_clusters


Expand Down Expand Up @@ -171,10 +210,68 @@ async def generate_summary_description(llm_client: LLMClient, summary: str) -> s
return description


def _select_representative_members(
    community_cluster: list[EntityNode],
    projection: dict[str, list[Neighbor]] | None,
    sample_size: int,
) -> list[EntityNode]:
    """Pick the top-K members most likely to characterize the community.

    Members are ranked in descending order by:

    1. In-community weighted degree: the sum of ``edge_count`` over the
       member's neighbors in ``projection`` that are themselves part of
       ``community_cluster``. Edges leaving the cluster are ignored.
    2. Summary length (a missing or empty summary counts as 0).
    3. Name, so that ties resolve deterministically.

    This function only reads its arguments; it issues no queries.

    Args:
        community_cluster: All member entities of the community.
        projection: Mapping of entity uuid to its neighbors with edge
            counts. When ``None`` or empty, every member gets degree 0 and
            the ranking falls back to summary length, then name.
        sample_size: Maximum number of members to return (K).

    Returns:
        ``community_cluster`` itself (the same list object, original order)
        when it already has at most ``sample_size`` members; otherwise a new
        list with the ``sample_size`` highest-ranked members, best first.
        A non-positive ``sample_size`` selects nothing and yields ``[]``.
    """
    if len(community_cluster) <= sample_size:
        return community_cluster

    if sample_size <= 0:
        # Guard the slice below: a negative bound would not mean "top K" but
        # "everything except the last |K|", silently returning a wrong sample.
        return []

    member_uuids = {member.uuid for member in community_cluster}
    # Resolve the "no projection" case once instead of once per entity.
    neighbors_by_uuid = projection or {}

    def ranking_key(entity: EntityNode) -> tuple[int, int, str]:
        degree = sum(
            neighbor.edge_count
            for neighbor in neighbors_by_uuid.get(entity.uuid, ())
            if neighbor.node_uuid in member_uuids
        )
        return degree, len(entity.summary or ''), entity.name

    ranked = sorted(community_cluster, key=ranking_key, reverse=True)
    return ranked[:sample_size]


async def build_community(
llm_client: LLMClient, community_cluster: list[EntityNode]
llm_client: LLMClient,
community_cluster: list[EntityNode],
*,
projection: dict[str, list[Neighbor]] | None = None,
sample_size: int | None = None,
) -> tuple[CommunityNode, list[CommunityEdge]]:
summaries = [entity.summary for entity in community_cluster]
"""Build a community node from its member entities.

Args:
llm_client: LLM used to summarize pairs and generate the final name.
community_cluster: Full list of member entities.
projection: Optional {uuid -> neighbors} projection from the clustering
step. Used to rank members by in-community weighted degree when
sampling.
sample_size: If set, only the top-K most representative members
participate in the binary summary merge. The community still
contains all members in its HAS_MEMBER edges — sampling only
affects which summaries are fed into the LLM pipeline. This cuts
LLM cost from O(N) per community to O(sample_size) and typically
improves quality because hub nodes carry the community's signal.
"""
summary_members = (
_select_representative_members(community_cluster, projection, sample_size)
if sample_size is not None
else community_cluster
)

summaries = [entity.summary for entity in summary_members]
length = len(summaries)
while length > 1:
odd_one_out: str | None = None
Expand All @@ -196,8 +293,10 @@ async def build_community(
summaries = new_summaries
length = len(summaries)

summary = truncate_at_sentence(summaries[0], MAX_SUMMARY_CHARS)
name = await generate_summary_description(llm_client, summary)
summary = truncate_at_sentence(summaries[0], MAX_SUMMARY_CHARS) if summaries else ''
name = (
await generate_summary_description(llm_client, summary) if summary else 'community'
)
now = utc_now()
community_node = CommunityNode(
name=name,
Expand All @@ -208,7 +307,13 @@ async def build_community(
)
community_edges = build_community_edges(community_cluster, community_node, now)

logger.debug(f'Built community {community_node.uuid} with {len(community_edges)} edges')
logger.debug(
'Built community %s with %d member edges (summary from %d/%d members)',
community_node.uuid,
len(community_edges),
len(summary_members),
len(community_cluster),
)

return community_node, community_edges

Expand All @@ -217,14 +322,35 @@ async def build_communities(
driver: GraphDriver,
llm_client: LLMClient,
group_ids: list[str] | None,
*,
sample_size: int | None = None,
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
community_clusters = await get_community_clusters(driver, group_ids)
"""Cluster entities into communities and build a summary node for each.

Args:
driver: Graph driver.
llm_client: LLM client for community summarization.
group_ids: Scope clustering to these group ids (or all if None).
sample_size: If set, each community's summary is built from only
the top-K most representative members (by in-community weighted
degree, then summary length). Reduces LLM cost from O(total nodes)
to O(num_communities * sample_size). Recommended for graphs
>10k nodes.
"""
clusters_result = await get_community_clusters(driver, group_ids, return_projection=True)
assert isinstance(clusters_result, tuple)
community_clusters, projection = clusters_result

semaphore = asyncio.Semaphore(MAX_COMMUNITY_BUILD_CONCURRENCY)

async def limited_build_community(cluster):
async with semaphore:
return await build_community(llm_client, cluster)
return await build_community(
llm_client,
cluster,
projection=projection,
sample_size=sample_size,
)

communities: list[tuple[CommunityNode, list[CommunityEdge]]] = list(
await semaphore_gather(
Expand Down
Loading
Loading