diff --git a/Cargo.lock b/Cargo.lock index 17f43cd917..8cdf88cd36 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6569,6 +6569,7 @@ dependencies = [ "arrow-array", "async-graphql", "async-graphql-poem", + "async-graphql-value", "base64 0.22.1", "base64-compat", "bigdecimal", diff --git a/Cargo.toml b/Cargo.toml index a12139ed49..87ee16d5c4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -67,6 +67,7 @@ raphtory-itertools = { version = "0.18.0", path = "raphtory-itertools" } clam-core = { version = "0.18.0", path = "clam-core" } optd-core = { version = "0.18.0", path = "optd/optd/core" } async-graphql = { version = "7.2.1", features = ["dynamic-schema"] } +async-graphql-value = "7.2.1" bincode = { version = "2", features = ["serde"] } async-graphql-poem = "7.2.1" dynamic-graphql = "0.10.1" diff --git a/python/tests/test_base_install/test_graphql/test_server_flags.py b/python/tests/test_base_install/test_graphql/test_server_flags.py index 7019bf5688..3ca466981b 100644 --- a/python/tests/test_base_install/test_graphql/test_server_flags.py +++ b/python/tests/test_base_install/test_graphql/test_server_flags.py @@ -262,15 +262,18 @@ def test_disable_lists_all_resolvers(): def test_disable_lists_page_still_works(): - """Even with `disable_lists=True`, `page` queries still succeed.""" + """Even with `disable_lists=True`, every `page` resolver still succeeds.""" work_dir = tempfile.mkdtemp() with GraphServer(work_dir, disable_lists=True).start(): client = RaphtoryClient(SERVER_URL) make_graph(client) - result = client.query( - '{ graph(path: "g") { nodes { page(limit: 10) { name } } } }' - ) - assert len(result["graph"]["nodes"]["page"]) == 3 + for name, query in PAGE_QUERIES: + try: + client.query(query) + except Exception as e: + raise AssertionError( + f"{name} unexpectedly failed while lists are disabled: {e}" + ) def test_max_page_size_all_resolvers(): @@ -289,15 +292,18 @@ def test_max_page_size_all_resolvers(): def test_max_page_size_under_cap_works(): - """Pages at or below 
max_page_size still succeed.""" + """With max_page_size=51, the same PAGE_QUERIES (all using limit=50) all succeed.""" work_dir = tempfile.mkdtemp() - with GraphServer(work_dir, max_page_size=2).start(): + with GraphServer(work_dir, max_page_size=51).start(): client = RaphtoryClient(SERVER_URL) make_graph(client) - result = client.query( - '{ graph(path: "g") { nodes { page(limit: 2) { name } } } }' - ) - assert len(result["graph"]["nodes"]["page"]) == 2 + for name, query in PAGE_QUERIES: + try: + client.query(query) + except Exception as e: + raise AssertionError( + f"{name} unexpectedly failed under max_page_size=51: {e}" + ) def test_disable_batching(): diff --git a/raphtory-graphql/Cargo.toml b/raphtory-graphql/Cargo.toml index 081ea82f76..bedfaba452 100644 --- a/raphtory-graphql/Cargo.toml +++ b/raphtory-graphql/Cargo.toml @@ -31,6 +31,7 @@ once_cell = { workspace = true } poem = { workspace = true } tokio = { workspace = true } async-graphql = { workspace = true, features = ["apollo_tracing"] } +async-graphql-value = { workspace = true } dynamic-graphql = { workspace = true } async-graphql-poem = { workspace = true } futures-util = { workspace = true } diff --git a/raphtory-graphql/src/collection_guard.rs b/raphtory-graphql/src/collection_guard.rs new file mode 100644 index 0000000000..d5dd9cc6e1 --- /dev/null +++ b/raphtory-graphql/src/collection_guard.rs @@ -0,0 +1,313 @@ +use crate::config::concurrency_config::ConcurrencyConfig; +use async_graphql::{ + async_trait, + extensions::{Extension, ExtensionContext, ExtensionFactory, NextParseQuery}, + parser::types::{ExecutableDocument, Field, Selection, SelectionSet, VariableDefinition}, + Name, Positioned, ServerError, ServerResult, Variables, +}; +use async_graphql_value::{ConstValue, Value}; +use std::{collections::HashSet, sync::Arc}; + +const LIST_DISABLED_ERROR: &str = + "Bulk list endpoints are disabled on this server. 
Use `page` instead."; + +/// Enforces `concurrency.disable_lists` and `concurrency.max_page_size` at parse time +/// by walking the `ExecutableDocument` and rejecting any `list`/`listRev` field (when +/// lists are disabled) or any `page`/`pageRev` field whose `limit` argument exceeds +/// the configured maximum. +pub struct CollectionGuard { + disable_lists: bool, + max_page_size: Option<usize>, +} + +impl CollectionGuard { + /// Returns `None` when neither guard is active — avoids installing an extension + /// that would only no-op. + pub fn from_config(config: &ConcurrencyConfig) -> Option<Self> { + if !config.disable_lists && config.max_page_size.is_none() { + return None; + } + Some(Self { + disable_lists: config.disable_lists, + max_page_size: config.max_page_size, + }) + } +} + +impl ExtensionFactory for CollectionGuard { + fn create(&self) -> Arc<dyn Extension> { + Arc::new(CollectionGuardExtension { + disable_lists: self.disable_lists, + max_page_size: self.max_page_size, + }) + } +} + +struct CollectionGuardExtension { + disable_lists: bool, + max_page_size: Option<usize>, +} + +#[async_trait::async_trait] +impl Extension for CollectionGuardExtension { + async fn parse_query( + &self, + ctx: &ExtensionContext<'_>, + query: &str, + variables: &Variables, + next: NextParseQuery<'_>, + ) -> ServerResult<ExecutableDocument> { + let doc = next.run(ctx, query, variables).await?; + for (_, op) in doc.operations.iter() { + let resolver = VariableResolver::new(&op.node.variable_definitions, variables); + let mut visited = HashSet::new(); + self.walk(&op.node.selection_set.node, &doc, &resolver, &mut visited)?; + } + Ok(doc) + } +} + +impl CollectionGuardExtension { + fn walk<'a>( + &self, + set: &'a SelectionSet, + doc: &'a ExecutableDocument, + resolver: &VariableResolver<'_>, + visited: &mut HashSet<&'a str>, + ) -> ServerResult<()> { + for item in &set.items { + match &item.node { + Selection::Field(field) => { + let field_node = &field.node; + let name = field_node.name.node.as_str(); + let pos = field.pos; 
match name { + "list" | "listRev" if self.disable_lists => { + return Err(ServerError::new(LIST_DISABLED_ERROR, Some(pos))); + } + "page" | "pageRev" => { + if let Some(max) = self.max_page_size { + if let Some(limit) = field_limit(field_node, resolver) { + if limit > max { + return Err(ServerError::new( + format!( + "page limit {limit} exceeds the maximum allowed page size {max}" + ), + Some(pos), + )); + } + } + } + } + _ => {} + } + self.walk(&field_node.selection_set.node, doc, resolver, visited)?; + } + Selection::InlineFragment(frag) => { + self.walk(&frag.node.selection_set.node, doc, resolver, visited)?; + } + Selection::FragmentSpread(spread) => { + let fragment_name = spread.node.fragment_name.node.as_str(); + if !visited.insert(fragment_name) { + continue; + } + if let Some(def) = doc.fragments.get(&spread.node.fragment_name.node) { + self.walk(&def.node.selection_set.node, doc, resolver, visited)?; + } + } + } + } + Ok(()) + } +} + +fn field_limit(field: &Field, resolver: &VariableResolver<'_>) -> Option<usize> { + let (_, value) = field + .arguments + .iter() + .find(|(n, _)| n.node.as_str() == "limit")?; + match &value.node { + Value::Number(n) => n.as_u64().map(|v| v as usize), + Value::Variable(name) => match resolver.resolve(name)? { + ConstValue::Number(n) => n.as_u64().map(|v| v as usize), + _ => None, + }, + _ => None, + } +} + +/// Resolves a variable by name, falling back to the operation's declared default value +/// when the client omitted it. Stays scoped to a single operation because defaults are +/// per-operation. 
+struct VariableResolver<'a> { + variables: &'a Variables, + defaults: Vec<(&'a Name, &'a ConstValue)>, +} + +impl<'a> VariableResolver<'a> { + fn new(definitions: &'a [Positioned<VariableDefinition>], variables: &'a Variables) -> Self { + let defaults = definitions + .iter() + .filter_map(|def| { + def.node + .default_value + .as_ref() + .map(|v| (&def.node.name.node, &v.node)) + }) + .collect(); + Self { + variables, + defaults, + } + } + + fn resolve(&self, name: &Name) -> Option<&ConstValue> { + if let Some(value) = self.variables.get(name) { + return Some(value); + } + self.defaults + .iter() + .find_map(|(n, v)| (*n == name).then_some(*v)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use async_graphql::parser::parse_query; + + fn run( + disable_lists: bool, + max_page_size: Option<usize>, + query: &str, + variables: Variables, + ) -> Result<(), String> { + let ext = CollectionGuardExtension { + disable_lists, + max_page_size, + }; + let doc = parse_query(query).map_err(|e| e.to_string())?; + for (_, op) in doc.operations.iter() { + let resolver = VariableResolver::new(&op.node.variable_definitions, &variables); + let mut visited = HashSet::new(); + ext.walk(&op.node.selection_set.node, &doc, &resolver, &mut visited) + .map_err(|e| e.message)?; + } + Ok(()) + } + + #[test] + fn rejects_list_when_disabled() { + let err = run(true, None, "{ foo { list { bar } } }", Variables::default()).unwrap_err(); + assert!(err.contains("Bulk list endpoints are disabled")); + } + + #[test] + fn rejects_list_rev_when_disabled() { + let err = run( + true, + None, + "{ foo { listRev { bar } } }", + Variables::default(), + ) + .unwrap_err(); + assert!(err.contains("Bulk list endpoints are disabled")); + } + + #[test] + fn allows_list_when_not_disabled() { + run( + false, + None, + "{ foo { list { bar } } }", + Variables::default(), + ) + .unwrap(); + } + + #[test] + fn rejects_page_over_max() { + let err = run( + false, + Some(10), + "{ foo { page(limit: 50) { bar } } }", + Variables::default(), 
) + .unwrap_err(); + assert!(err.contains("page limit 50 exceeds the maximum allowed page size 10")); + } + + #[test] + fn allows_page_under_max() { + run( + false, + Some(50), + "{ foo { page(limit: 10) { bar } } }", + Variables::default(), + ) + .unwrap(); + } + + #[test] + fn resolves_limit_from_provided_variable() { + let vars = Variables::from_json(serde_json::json!({ "n": 100 })); + let err = run( + false, + Some(10), + "query ($n: Int!) { foo { page(limit: $n) { bar } } }", + vars, + ) + .unwrap_err(); + assert!(err.contains("page limit 100 exceeds")); + } + + #[test] + fn resolves_limit_from_variable_default() { + let err = run( + false, + Some(10), + "query ($n: Int = 100) { foo { page(limit: $n) { bar } } }", + Variables::default(), + ) + .unwrap_err(); + assert!(err.contains("page limit 100 exceeds")); + } + + #[test] + fn walks_into_inline_fragments() { + let err = run( + true, + None, + "{ foo { ... on Foo { list { bar } } } }", + Variables::default(), + ) + .unwrap_err(); + assert!(err.contains("Bulk list endpoints are disabled")); + } + + #[test] + fn walks_into_fragment_spreads() { + let err = run( + true, + None, + "fragment F on Foo { list { bar } } { foo { ...F } }", + Variables::default(), + ) + .unwrap_err(); + assert!(err.contains("Bulk list endpoints are disabled")); + } + + #[test] + fn handles_cyclic_fragments_without_looping() { + // Cycle is spec-invalid but arrives here before async-graphql's validation; the + // visited set must prevent infinite recursion. 
+ let err = run( + true, + None, + "fragment A on T { ...B list { x } } fragment B on T { ...A } { root { ...A } }", + Variables::default(), + ) + .unwrap_err(); + assert!(err.contains("Bulk list endpoints are disabled")); + } +} diff --git a/raphtory-graphql/src/lib.rs b/raphtory-graphql/src/lib.rs index e4b50b87bb..ecc6691d8a 100644 --- a/raphtory-graphql/src/lib.rs +++ b/raphtory-graphql/src/lib.rs @@ -5,6 +5,7 @@ use std::sync::Arc; mod auth; pub mod client; +mod collection_guard; pub mod data; mod graph; pub mod model; diff --git a/raphtory-graphql/src/model/graph/windowset.rs b/raphtory-graphql/src/model/graph/windowset.rs index da1fd6311e..32b1501029 100644 --- a/raphtory-graphql/src/model/graph/windowset.rs +++ b/raphtory-graphql/src/model/graph/windowset.rs @@ -1,17 +1,11 @@ use crate::{ model::graph::{ - collection::{check_list_allowed, check_page_limit}, - edge::GqlEdge, - edges::GqlEdges, - graph::GqlGraph, - node::GqlNode, - nodes::GqlNodes, + edge::GqlEdge, edges::GqlEdges, graph::GqlGraph, node::GqlNode, nodes::GqlNodes, path_from_node::GqlPathFromNode, }, paths::ExistingGraphFolder, rayon::blocking_compute, }; -use async_graphql::Context; use dynamic_graphql::{ResolvedObject, ResolvedObjectFields}; use raphtory::db::{ api::{ @@ -48,14 +42,12 @@ impl GqlGraphWindowSet { /// will be returned. 
async fn page( &self, - ctx: &Context<'_>, limit: usize, offset: Option<usize>, page_index: Option<usize>, - ) -> async_graphql::Result<Vec<GqlGraph>> { - check_page_limit(ctx, limit)?; + ) -> Vec<GqlGraph> { let self_clone = self.clone(); - Ok(blocking_compute(move || { + blocking_compute(move || { let start = page_index.unwrap_or(0) * limit + offset.unwrap_or(0); self_clone .ws @@ -65,20 +57,19 @@ impl GqlGraphWindowSet { .map(|g| GqlGraph::new(self_clone.path.clone(), g)) .collect() }) - .await) + .await } - async fn list(&self, ctx: &Context<'_>) -> async_graphql::Result<Vec<GqlGraph>> { - check_list_allowed(ctx)?; + async fn list(&self) -> Vec<GqlGraph> { let self_clone = self.clone(); - Ok(blocking_compute(move || { + blocking_compute(move || { self_clone .ws .clone() .map(|g| GqlGraph::new(self_clone.path.clone(), g)) .collect() }) - .await) + .await } } @@ -107,14 +98,12 @@ impl GqlNodeWindowSet { /// will be returned. async fn page( &self, - ctx: &Context<'_>, limit: usize, offset: Option<usize>, page_index: Option<usize>, - ) -> async_graphql::Result<Vec<GqlNode>> { - check_page_limit(ctx, limit)?; + ) -> Vec<GqlNode> { let self_clone = self.clone(); - Ok(blocking_compute(move || { + blocking_compute(move || { let start = page_index.unwrap_or(0) * limit + offset.unwrap_or(0); self_clone .ws @@ -124,13 +113,12 @@ impl GqlNodeWindowSet { .map(|n| n.into()) .collect() }) - .await) + .await } - async fn list(&self, ctx: &Context<'_>) -> async_graphql::Result<Vec<GqlNode>> { - check_list_allowed(ctx)?; + async fn list(&self) -> Vec<GqlNode> { let self_clone = self.clone(); - Ok(blocking_compute(move || self_clone.ws.clone().map(|n| n.into()).collect()).await) + blocking_compute(move || self_clone.ws.clone().map(|n| n.into()).collect()).await } } @@ -161,14 +149,12 @@ impl GqlNodesWindowSet { /// will be returned. 
async fn page( &self, - ctx: &Context<'_>, limit: usize, offset: Option<usize>, page_index: Option<usize>, - ) -> async_graphql::Result<Vec<GqlNodes>> { - check_page_limit(ctx, limit)?; + ) -> Vec<GqlNodes> { let self_clone = self.clone(); - Ok(blocking_compute(move || { + blocking_compute(move || { let start = page_index.unwrap_or(0) * limit + offset.unwrap_or(0); self_clone .ws @@ -178,16 +164,12 @@ impl GqlNodesWindowSet { .map(|n| GqlNodes::new(n)) .collect() }) - .await) + .await } - async fn list(&self, ctx: &Context<'_>) -> async_graphql::Result<Vec<GqlNodes>> { - check_list_allowed(ctx)?; + async fn list(&self) -> Vec<GqlNodes> { let self_clone = self.clone(); - Ok( - blocking_compute(move || self_clone.ws.clone().map(|n| GqlNodes::new(n)).collect()) - .await, - ) + blocking_compute(move || self_clone.ws.clone().map(|n| GqlNodes::new(n)).collect()).await } } @@ -216,14 +198,12 @@ impl GqlPathFromNodeWindowSet { /// will be returned. async fn page( &self, - ctx: &Context<'_>, limit: usize, offset: Option<usize>, page_index: Option<usize>, - ) -> async_graphql::Result<Vec<GqlPathFromNode>> { - check_page_limit(ctx, limit)?; + ) -> Vec<GqlPathFromNode> { let self_clone = self.clone(); - Ok(blocking_compute(move || { + blocking_compute(move || { let start = page_index.unwrap_or(0) * limit + offset.unwrap_or(0); self_clone .ws @@ -233,20 +213,19 @@ impl GqlPathFromNodeWindowSet { .map(|n| GqlPathFromNode::new(n)) .collect() }) - .await) + .await } - async fn list(&self, ctx: &Context<'_>) -> async_graphql::Result<Vec<GqlPathFromNode>> { - check_list_allowed(ctx)?; + async fn list(&self) -> Vec<GqlPathFromNode> { let self_clone = self.clone(); - Ok(blocking_compute(move || { + blocking_compute(move || { self_clone .ws .clone() .map(|n| GqlPathFromNode::new(n)) .collect() }) - .await) + .await } } @@ -275,14 +254,12 @@ impl GqlEdgeWindowSet { /// will be returned. 
async fn page( &self, - ctx: &Context<'_>, limit: usize, offset: Option<usize>, page_index: Option<usize>, - ) -> async_graphql::Result<Vec<GqlEdge>> { - check_page_limit(ctx, limit)?; + ) -> Vec<GqlEdge> { let self_clone = self.clone(); - Ok(blocking_compute(move || { + blocking_compute(move || { let start = page_index.unwrap_or(0) * limit + offset.unwrap_or(0); self_clone .ws @@ -292,13 +269,12 @@ impl GqlEdgeWindowSet { .map(|e| e.into()) .collect() }) - .await) + .await } - async fn list(&self, ctx: &Context<'_>) -> async_graphql::Result<Vec<GqlEdge>> { - check_list_allowed(ctx)?; + async fn list(&self) -> Vec<GqlEdge> { let self_clone = self.clone(); - Ok(blocking_compute(move || self_clone.ws.clone().map(|e| e.into()).collect()).await) + blocking_compute(move || self_clone.ws.clone().map(|e| e.into()).collect()).await } } @@ -327,14 +303,12 @@ impl GqlEdgesWindowSet { /// will be returned. async fn page( &self, - ctx: &Context<'_>, limit: usize, offset: Option<usize>, page_index: Option<usize>, - ) -> async_graphql::Result<Vec<GqlEdges>> { - check_page_limit(ctx, limit)?; + ) -> Vec<GqlEdges> { let self_clone = self.clone(); - Ok(blocking_compute(move || { + blocking_compute(move || { let start = page_index.unwrap_or(0) * limit + offset.unwrap_or(0); self_clone .ws @@ -344,15 +318,11 @@ impl GqlEdgesWindowSet { .map(|e| GqlEdges::new(e)) .collect() }) - .await) + .await } - async fn list(&self, ctx: &Context<'_>) -> async_graphql::Result<Vec<GqlEdges>> { - check_list_allowed(ctx)?; + async fn list(&self) -> Vec<GqlEdges> { let self_clone = self.clone(); - Ok( - blocking_compute(move || self_clone.ws.clone().map(|e| GqlEdges::new(e)).collect()) - .await, - ) + blocking_compute(move || self_clone.ws.clone().map(|e| GqlEdges::new(e)).collect()).await } } diff --git a/raphtory-graphql/src/server.rs b/raphtory-graphql/src/server.rs index 17b65732be..cddd660ecd 100644 --- a/raphtory-graphql/src/server.rs +++ b/raphtory-graphql/src/server.rs @@ -1,5 +1,6 @@ use crate::{ auth::{AuthenticatedGraphQL, MutationAuth}, + collection_guard::CollectionGuard, config::app_config::{load_config, 
AppConfig}, data::Data, model::{ @@ -250,8 +251,10 @@ impl GraphServer { let schema_cfg = &self.config.schema; let mut schema_builder = App::create_schema() .data(self.data.clone()) - .data(self.config.concurrency.clone()) .extension(MutationAuth); + if let Some(guard) = CollectionGuard::from_config(&self.config.concurrency) { + schema_builder = schema_builder.extension(guard); + } if let Some(depth) = schema_cfg.max_query_depth { schema_builder = schema_builder.limit_depth(depth); }