diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py
index 8c4943e89..975c46f42 100644
--- a/graphiti_core/graphiti.py
+++ b/graphiti_core/graphiti.py
@@ -1417,7 +1417,10 @@ async def add_episode_bulk(
 
     @handle_multiple_group_ids
     async def build_communities(
-        self, group_ids: list[str] | None = None, driver: GraphDriver | None = None
+        self,
+        group_ids: list[str] | None = None,
+        driver: GraphDriver | None = None,
+        sample_size: int | None = None,
     ) -> tuple[list[CommunityNode], list[CommunityEdge]]:
         """
         Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
@@ -1425,6 +1428,13 @@ async def build_communities(
         ----------
         group_ids : list[str] | None
             Optional. Create communities only for the listed group_ids. If blank the entire graph will be used.
+        sample_size : int | None
+            Optional. If set, each community's LLM summary is built from at most
+            `sample_size` representative members (highest in-community weighted
+            degree, then longest summary); all members still receive community
+            edges. Dramatically reduces LLM cost on large graphs — without
+            sampling, summary cost grows with total node count; with sampling it
+            grows with the number of communities. Recommended for graphs >10k nodes.
         """
         if driver is None:
             driver = self.clients.driver
@@ -1433,7 +1443,7 @@ async def build_communities(
         await remove_communities(driver)
 
         community_nodes, community_edges = await build_communities(
-            driver, self.llm_client, group_ids
+            driver, self.llm_client, group_ids, sample_size=sample_size
         )
 
         await semaphore_gather(
diff --git a/graphiti_core/utils/maintenance/community_operations.py b/graphiti_core/utils/maintenance/community_operations.py
index 8c96bd79f..2db89356b 100644
--- a/graphiti_core/utils/maintenance/community_operations.py
+++ b/graphiti_core/utils/maintenance/community_operations.py
@@ -27,16 +27,75 @@ class Neighbor(BaseModel):
     edge_count: int
 
 
+async def _build_group_projection(
+    driver: GraphDriver, group_id: str
+) -> dict[str, list[Neighbor]]:
+    """Fetch the RELATES_TO projection for all entities in a group.
+
+    Returns a mapping from each node's uuid to its list of in-group neighbors
+    with edge counts, issuing one query per entity in the group. Used by label
+    propagation and by the in-community degree computation for sampling.
+    """
+    projection: dict[str, list[Neighbor]] = {}
+    nodes = await EntityNode.get_by_group_ids(driver, [group_id])
+    for node in nodes:  # one neighbor-count query per entity in the group
+        match_query = """
+            MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[e:RELATES_TO]-(m: Entity {group_id: $group_id})
+        """
+        if driver.provider == GraphProvider.KUZU:  # Kuzu pattern goes through a RelatesToNode_ node
+            match_query = """
+            MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(m: Entity {group_id: $group_id})
+            """
+        records, _, _ = await driver.execute_query(
+            match_query
+            + """
+            WITH count(e) AS count, m.uuid AS uuid
+            RETURN
+                uuid,
+                count
+            """,
+            uuid=node.uuid,
+            group_id=group_id,
+        )
+
+        projection[node.uuid] = [
+            Neighbor(node_uuid=record['uuid'], edge_count=record['count']) for record in records
+        ]
+    return projection
+
+
 async def get_community_clusters(
-    driver: GraphDriver, group_ids: list[str] | None
-) -> list[list[EntityNode]]:
+    driver: GraphDriver,
+    group_ids: list[str] | None,
+    return_projection: bool = False,
+) -> list[list[EntityNode]] | tuple[list[list[EntityNode]], dict[str, list[Neighbor]]]:
+    """Compute community clusters via label propagation.
+
+    Args:
+        driver: Graph driver.
+        group_ids: Optional list of group ids to scope clustering. If None,
+            all groups are used.
+        return_projection: When True, also return the combined projection
+            (uuid → neighbors with edge counts) so callers can compute
+            in-community degrees without a second pass over the graph.
+
+    Returns:
+        The list of clusters (each a list of EntityNode), or a (clusters, projection)
+        tuple when return_projection=True ({} if graph_operations_interface clustered).
+    """
     if driver.graph_operations_interface:
         try:
-            return await driver.graph_operations_interface.get_community_clusters(driver, group_ids)
+            clusters = await driver.graph_operations_interface.get_community_clusters(
+                driver, group_ids
+            )
+            if return_projection:
+                return clusters, {}
+            return clusters
         except NotImplementedError:
             pass
 
     community_clusters: list[list[EntityNode]] = []
+    combined_projection: dict[str, list[Neighbor]] = {}
 
     if group_ids is None:
         group_id_values, _, _ = await driver.execute_query(
@@ -51,31 +110,9 @@ async def get_community_clusters(
         group_ids = group_id_values[0]['group_ids'] if group_id_values else []
 
     for group_id in group_ids:
-        projection: dict[str, list[Neighbor]] = {}
-        nodes = await EntityNode.get_by_group_ids(driver, [group_id])
-        for node in nodes:
-            match_query = """
-                MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[e:RELATES_TO]-(m: Entity {group_id: $group_id})
-            """
-            if driver.provider == GraphProvider.KUZU:
-                match_query = """
-                MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(m: Entity {group_id: $group_id})
-                """
-            records, _, _ = await driver.execute_query(
-                match_query
-                + """
-                WITH count(e) AS count, m.uuid AS uuid
-                RETURN
-                    uuid,
-                    count
-                """,
-                uuid=node.uuid,
-                group_id=group_id,
-            )
-
-            projection[node.uuid] = [
-                Neighbor(node_uuid=record['uuid'], edge_count=record['count']) for record in records
-            ]
+        projection = await _build_group_projection(driver, group_id)
+        if return_projection:
+            combined_projection.update(projection)
 
         cluster_uuids = label_propagation(projection)
 
@@ -87,6 +124,8 @@ async def get_community_clusters(
             )
         )
 
+    if return_projection:
+        return community_clusters, combined_projection
     return community_clusters
 
 
@@ -171,10 +210,74 @@ async def generate_summary_description(llm_client: LLMClient, summary: str) -> s
     return description
 
 
+def _select_representative_members(
+    community_cluster: list[EntityNode],
+    projection: dict[str, list[Neighbor]] | None,
+    sample_size: int,
+) -> list[EntityNode]:
+    """Pick the top-K members most likely to characterize the community.
+
+    Scoring key (descending): in-community weighted degree, then summary
+    length, then name for deterministic ties. In-community degree uses the
+    projection we already computed during clustering — no extra queries.
+
+    When no projection is available (e.g. the graph_operations_interface
+    returned clusters directly), falls back to summary length, then name.
+
+    Raises:
+        ValueError: If sample_size is negative; a negative slice bound would
+            otherwise silently drop members from the end of the ranking.
+    """
+    if sample_size < 0:
+        raise ValueError(f'sample_size must be non-negative, got {sample_size}')
+    if len(community_cluster) <= sample_size:
+        return community_cluster
+
+    member_uuids = {m.uuid for m in community_cluster}
+
+    def in_community_degree(entity: EntityNode) -> int:
+        if not projection:
+            return 0
+        neighbors = projection.get(entity.uuid, [])
+        return sum(n.edge_count for n in neighbors if n.node_uuid in member_uuids)
+
+    scored = sorted(
+        community_cluster,
+        key=lambda e: (in_community_degree(e), len(e.summary or ''), e.name),
+        reverse=True,
+    )
+    return scored[:sample_size]
+
+
 async def build_community(
-    llm_client: LLMClient, community_cluster: list[EntityNode]
+    llm_client: LLMClient,
+    community_cluster: list[EntityNode],
+    *,
+    projection: dict[str, list[Neighbor]] | None = None,
+    sample_size: int | None = None,
 ) -> tuple[CommunityNode, list[CommunityEdge]]:
-    summaries = [entity.summary for entity in community_cluster]
+    """Build a community node from its member entities.
+
+    Args:
+        llm_client: LLM used to summarize pairs and generate the final name.
+        community_cluster: Full list of member entities.
+        projection: Optional {uuid -> neighbors} projection from the clustering
+            step. Used to rank members by in-community weighted degree when
+            sampling.
+        sample_size: If set (must be >= 0), only the top-K most representative
+            members participate in the binary summary merge. The community still
+            contains all members in its HAS_MEMBER edges — sampling only
+            affects which summaries are fed into the LLM pipeline. This cuts
+            LLM cost from O(N) per community to O(sample_size) and typically
+            improves quality because hub nodes carry the community's signal.
+    """
+    summary_members = (
+        _select_representative_members(community_cluster, projection, sample_size)
+        if sample_size is not None
+        else community_cluster
+    )
+
+    summaries = [entity.summary for entity in summary_members]
     length = len(summaries)
     while length > 1:
         odd_one_out: str | None = None
@@ -196,8 +299,10 @@ async def build_community(
         summaries = new_summaries
         length = len(summaries)
 
-    summary = truncate_at_sentence(summaries[0], MAX_SUMMARY_CHARS)
-    name = await generate_summary_description(llm_client, summary)
+    summary = truncate_at_sentence(summaries[0], MAX_SUMMARY_CHARS) if summaries else ''
+    name = (
+        await generate_summary_description(llm_client, summary) if summary else 'community'
+    )
     now = utc_now()
     community_node = CommunityNode(
         name=name,
@@ -208,7 +313,13 @@ async def build_community(
     )
     community_edges = build_community_edges(community_cluster, community_node, now)
 
-    logger.debug(f'Built community {community_node.uuid} with {len(community_edges)} edges')
+    logger.debug(
+        'Built community %s with %d member edges (summary from %d/%d members)',
+        community_node.uuid,
+        len(community_edges),
+        len(summary_members),
+        len(community_cluster),
+    )
 
     return community_node, community_edges
 
@@ -217,14 +328,35 @@ async def build_communities(
     driver: GraphDriver,
     llm_client: LLMClient,
     group_ids: list[str] | None,
+    *,
+    sample_size: int | None = None,
 ) -> tuple[list[CommunityNode], list[CommunityEdge]]:
-    community_clusters = await get_community_clusters(driver, group_ids)
+    """Cluster entities into communities and build a summary node for each.
+
+    Args:
+        driver: Graph driver.
+        llm_client: LLM client for community summarization.
+        group_ids: Scope clustering to these group ids (or all if None).
+        sample_size: If set, each community's summary is built from only
+            the top-K (K = sample_size) most representative members, ranked
+            by in-community weighted degree, then summary length. Reduces
+            LLM cost from O(total nodes) to O(num_communities * sample_size).
+            Recommended for graphs >10k nodes.
+    """
+    clusters_result = await get_community_clusters(driver, group_ids, return_projection=True)
+    assert isinstance(clusters_result, tuple)  # narrow the union return type for type checkers
+    community_clusters, projection = clusters_result
 
     semaphore = asyncio.Semaphore(MAX_COMMUNITY_BUILD_CONCURRENCY)
 
     async def limited_build_community(cluster):
         async with semaphore:
-            return await build_community(llm_client, cluster)
+            return await build_community(
+                llm_client,
+                cluster,
+                projection=projection,
+                sample_size=sample_size,
+            )
 
     communities: list[tuple[CommunityNode, list[CommunityEdge]]] = list(
         await semaphore_gather(
diff --git a/tests/utils/maintenance/test_community_operations.py b/tests/utils/maintenance/test_community_operations.py
new file mode 100644
index 000000000..1d1f0da4e
--- /dev/null
+++ b/tests/utils/maintenance/test_community_operations.py
@@ -0,0 +1,150 @@
+"""Tests for community summary member sampling.
+
+The `sample_size` parameter on `build_community` (and `build_communities`)
+limits the number of members whose summaries feed the binary-merge
+summarization tree. This bounds LLM cost on large graphs:
+
+- Without sampling, summary cost grows as O(total_nodes) — every entity's
+  summary participates in the merge tree.
+- With sampling, cost grows as O(num_communities * sample_size) — only the
+  top-K most representative members per community participate.
+
+These tests focus on the `_select_representative_members` helper that
+implements the ranking. End-to-end tests of `build_communities` with a
+real LLM are out of scope here — see the existing integration tests.
+"""
+
+from __future__ import annotations
+
+from graphiti_core.nodes import EntityNode
+from graphiti_core.utils.maintenance.community_operations import (
+    Neighbor,
+    _select_representative_members,
+)
+
+
+def _make_entity(uuid: str, name: str = '', summary: str = '') -> EntityNode:
+    """Build a minimal EntityNode for sampling tests."""
+    return EntityNode(uuid=uuid, name=name or uuid, group_id='g', summary=summary)
+
+
+def test_returns_all_members_when_cluster_smaller_than_k():
+    members = [_make_entity(f'e{i}') for i in range(5)]
+    sampled = _select_representative_members(members, projection=None, sample_size=10)
+    assert sampled == members
+
+
+def test_returns_all_members_when_cluster_equal_to_k():
+    members = [_make_entity(f'e{i}') for i in range(5)]
+    sampled = _select_representative_members(members, projection=None, sample_size=5)
+    assert sampled == members
+
+
+def test_prefers_higher_in_community_degree():
+    """A node with many in-community neighbors outranks isolated nodes."""
+    # e0 is a hub: 3 in-community neighbors at 5 edges each (weighted degree 15).
+    # e1..e3 each have a single neighbor, e0 (weighted degree 5).
+    # e4 has no in-community edges in this projection.
+    members = [_make_entity(f'e{i}') for i in range(5)]
+    projection: dict[str, list[Neighbor]] = {
+        'e0': [
+            Neighbor(node_uuid='e1', edge_count=5),
+            Neighbor(node_uuid='e2', edge_count=5),
+            Neighbor(node_uuid='e3', edge_count=5),
+        ],
+        'e1': [Neighbor(node_uuid='e0', edge_count=5)],
+        'e2': [Neighbor(node_uuid='e0', edge_count=5)],
+        'e3': [Neighbor(node_uuid='e0', edge_count=5)],
+        'e4': [],
+    }
+    sampled = _select_representative_members(members, projection, sample_size=2)
+    assert len(sampled) == 2
+    # Hub must be picked first
+    assert sampled[0].uuid == 'e0'
+
+
+def test_falls_back_to_summary_length_without_projection():
+    """When no projection is available, longer summaries win."""
+    members = [
+        _make_entity('short', summary='x'),
+        _make_entity('medium', summary='x' * 50),
+        _make_entity('long', summary='x' * 200),
+    ]
+    sampled = _select_representative_members(members, projection=None, sample_size=2)
+    assert sampled[0].uuid == 'long'
+    assert sampled[1].uuid == 'medium'
+
+
+def test_falls_back_to_summary_length_with_empty_projection():
+    """An empty projection (e.g., from a graph_operations_interface that
+    does not expose projections) is treated like no projection at all."""
+    members = [
+        _make_entity('a', summary='short'),
+        _make_entity('b', summary='x' * 100),
+    ]
+    sampled = _select_representative_members(members, projection={}, sample_size=1)
+    assert sampled[0].uuid == 'b'
+
+
+def test_deterministic_on_ties():
+    """Degree and summary-length ties break by name (descending), identically on every run."""
+    members = [_make_entity(f'e{i}') for i in range(5)]
+    projection: dict[str, list[Neighbor]] = {
+        'e0': [Neighbor(node_uuid='e1', edge_count=1)],
+        'e1': [
+            Neighbor(node_uuid='e0', edge_count=1),
+            Neighbor(node_uuid='e2', edge_count=1),
+        ],
+        'e2': [
+            Neighbor(node_uuid='e1', edge_count=1),
+            Neighbor(node_uuid='e3', edge_count=1),
+        ],
+        'e3': [
+            Neighbor(node_uuid='e2', edge_count=1),
+            Neighbor(node_uuid='e4', edge_count=1),
+        ],
+        'e4': [Neighbor(node_uuid='e3', edge_count=1)],
+    }
+    first = _select_representative_members(members, projection, sample_size=2)
+    second = _select_representative_members(members, projection, sample_size=2)
+    assert [m.uuid for m in first] == [m.uuid for m in second] == ['e3', 'e2']
+
+
+def test_only_counts_in_community_edges():
+    """Edges to entities outside the community must be ignored.
+
+    A node with many out-of-community connections but only a few in-community
+    edges should not outrank an in-community-focused node.
+    """
+    members = [_make_entity('insider'), _make_entity('insider2')]
+    projection: dict[str, list[Neighbor]] = {
+        'insider': [
+            # Many heavy edges to entities NOT in the cluster
+            Neighbor(node_uuid='outsider_a', edge_count=100),
+            Neighbor(node_uuid='outsider_b', edge_count=100),
+            # One light edge inside
+            Neighbor(node_uuid='insider2', edge_count=1),
+        ],
+        'insider2': [
+            Neighbor(node_uuid='insider', edge_count=1),
+        ],
+    }
+    sampled = _select_representative_members(members, projection, sample_size=1)
+    # Both have in-community degree 1; tie-broken by name desc → 'insider2' wins
+    assert sampled[0].uuid == 'insider2'
+
+
+def test_summary_length_breaks_degree_ties():
+    """When two nodes have the same in-community degree, the one with the
+    richer summary wins (since richer summaries contribute more to the
+    binary merge)."""
+    members = [
+        _make_entity('a', summary='x' * 10),
+        _make_entity('b', summary='x' * 200),
+    ]
+    projection: dict[str, list[Neighbor]] = {
+        'a': [Neighbor(node_uuid='b', edge_count=1)],
+        'b': [Neighbor(node_uuid='a', edge_count=1)],
+    }
+    sampled = _select_representative_members(members, projection, sample_size=1)
+    assert sampled[0].uuid == 'b'