diff --git a/README.md b/README.md index 84aa4ac..efeb5e5 100644 --- a/README.md +++ b/README.md @@ -5,15 +5,15 @@ A Rust library for formatting PostgreSQL SQL and PL/pgSQL, powered by Supports 7 formatting styles based on popular SQL style guides: -| Style | Description | -|-------|-------------| -| **river** (default) | Keywords right-aligned to form a visual "river" ([sqlstyle.guide](https://www.sqlstyle.guide/)) | -| **mozilla** | Keywords left-aligned, content indented 4 spaces | -| **aweber** | River style with JOINs participating in keyword alignment | -| **dbt** | Lowercase keywords, blank lines between clauses | -| **gitlab** | 2-space indent, uppercase keywords | -| **kickstarter** | 2-space indent, compact JOIN...ON on same line | -| **mattmc3** | Lowercase river with leading commas | +| Style | Description | +| ------------------- | -------------------------------------------------- | +| [**aweber**](https://gist.github.com/gmr/2cceb85bb37be96bc96f05c5b8de9e1b) (default) | River style with JOINs participating in keyword alignment | +| [**dbt**](https://docs.getdbt.com/best-practices/how-we-style/2-how-we-style-our-sql) | Lowercase keywords, blank lines between clauses | +| [**gitlab**](https://handbook.gitlab.com/handbook/enterprise-data/platform/sql-style-guide/) | 2-space indent, uppercase keywords | +| [**kickstarter**](https://gist.github.com/fredbenenson/7bb92718e19138c20591) | 2-space indent, compact JOIN...ON on same line | +| [**mattmc3**](https://gist.github.com/mattmc3/38a85e6a4ca1093816c08d4815fbebfb) | Lowercase river with leading commas | +| [**mozilla**](https://docs.telemetry.mozilla.org/concepts/sql_style.html) | Keywords left-aligned, content indented 4 spaces | +| [**river**](https://www.sqlstyle.guide/) | Keywords right-aligned to form a visual "river" | ## Usage @@ -98,7 +98,8 @@ match format("SELECT * FORM broken", Style::River) { Given: `SELECT file_hash FROM file_system WHERE file_name = '.vimrc'` -**River** (default): +**River**: + ```sql SELECT file_hash FROM file_system @@ -106,6 +107,7 @@ SELECT file_hash ``` **Mozilla**: + ```sql SELECT file_hash FROM file_system @@ -114,6 +116,7 @@ WHERE ``` **dbt**: + ```sql select file_hash @@ -124,6 +127,7 @@ where ``` **mattmc3** (leading commas): + ```sql select file_hash from file_system diff --git a/examples/dump_tree.rs b/examples/dump_tree.rs new file mode 100644 index 0000000..1eb6304 --- /dev/null +++ b/examples/dump_tree.rs @@ -0,0 +1,35 @@ +use tree_sitter::Parser; +use tree_sitter_postgres::LANGUAGE; + +fn print_tree(node: tree_sitter::Node, source: &str, indent: usize) { + let kind = node.kind(); + let text = &source[node.byte_range()]; + let short = if text.len() > 80 { &text[..80] } else { text }; + let short = short.replace('\n', "\\n"); + let pad = " ".repeat(indent); + if node.is_named() { + println!("{pad}{kind}: {short:?}"); + } else { + println!("{pad}[{kind}]: {short:?}"); + } + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + print_tree(child, source, indent + 1); + } +} + +fn main() { + let path = std::env::args() + .nth(1) + .expect("usage: dump_tree "); + let sql = std::fs::read_to_string(&path).unwrap(); + let input = if sql.trim().ends_with(';') { + sql.trim().to_string() + } else { + format!("{};", sql.trim()) + }; + let mut parser = Parser::new(); + parser.set_language(&LANGUAGE.into()).unwrap(); + let tree = parser.parse(&input, None).unwrap(); + print_tree(tree.root_node(), &input, 0); +} diff --git a/examples/format_test.rs b/examples/format_test.rs new file mode 100644 index 0000000..0c5d24b --- /dev/null +++ b/examples/format_test.rs @@ -0,0 +1,16 @@ +use libpgfmt::{format, style::Style}; +fn main() { + let sql = std::fs::read_to_string(std::env::args().nth(1).unwrap()).unwrap(); + let style: Style = std::env::args() + .nth(2) + .unwrap_or("aweber".to_string()) + .parse() + .unwrap(); + match format(sql.trim(), style) { + Ok(f) => println!("{f}"), + Err(e) => { + eprintln!("Error: {e}"); + std::process::exit(1); + } + } +} diff --git a/src/formatter/expr.rs b/src/formatter/expr.rs index bb99136..ba154b4 100644 --- a/src/formatter/expr.rs +++ b/src/formatter/expr.rs @@ -164,7 +164,7 @@ impl<'a> Formatter<'a> { /// Format any expression node into inline SQL text. pub(crate) fn format_expr(&self, node: Node<'a>) -> String { match node.kind() { - "a_expr" => self.format_a_expr(node), + "a_expr" | "b_expr" => self.format_a_expr(node), "a_expr_prec" => self.format_a_expr_prec(node), "c_expr" => self.format_c_expr(node), "columnref" => self.format_columnref(node), @@ -183,7 +183,7 @@ impl<'a> Formatter<'a> { "type_function_name" => self.format_first_named_child(node), "ColId" => self.format_col_id(node), "ColLabel" => self.format_first_named_child(node), - "qualified_name" => self.format_qualified_name(node), + "qualified_name" | "any_name" => self.format_qualified_name(node), "indirection" => self.format_indirection(node), "indirection_el" => self.format_indirection_el(node), "attr_name" => self.format_first_named_child(node), @@ -198,6 +198,7 @@ impl<'a> Formatter<'a> { formatted.join(", ") } "func_arg_expr" => self.format_first_named_child(node), + "array_expr" => self.format_array_expr(node), "opt_alias_clause" | "alias_clause" => self.format_alias(node), "group_by_item" => self.format_first_named_child(node), "ERROR" => self.text(node).to_string(), @@ -215,13 +216,25 @@ impl<'a> Formatter<'a> { /// Format an a_expr node (the main expression type with operators). fn format_a_expr(&self, node: Node<'a>) -> String { - let mut parts = Vec::new(); + let mut parts: Vec = Vec::new(); let mut cursor = node.walk(); // Check if this a_expr contains an inline expr_list (e.g., IN (...)). // If so, skip unnamed parens since we format them with the expr_list. let has_expr_list = node.find_child("expr_list").is_some(); + let mut pending_cast = false; for child in node.children(&mut cursor) { if child.is_named() { + // After ::, the next named child is a Typename — append directly + // to the previous part without spaces. + if pending_cast { + pending_cast = false; + let typename = self.format_expr(child); + if let Some(last) = parts.last_mut() { + last.push_str("::"); + last.push_str(&typename); + } + continue; + } match child.kind() { "a_expr_prec" | "a_expr" | "c_expr" => { parts.push(self.format_expr(child)); @@ -269,6 +282,11 @@ impl<'a> Formatter<'a> { // Unnamed children are operators like =, <, >, !=, etc. let text = self.text(child).trim(); if !text.is_empty() { + // Typecast operator :: — defer and attach to next Typename. + if text == "::" { + pending_cast = true; + continue; + } // Skip parens that surround an expr_list (handled inline). if has_expr_list && (text == "(" || text == ")") { continue; @@ -325,14 +343,28 @@ impl<'a> Formatter<'a> { } fn format_a_expr_prec(&self, node: Node<'a>) -> String { - let mut parts = Vec::new(); + let mut parts: Vec = Vec::new(); let mut cursor = node.walk(); + let mut pending_cast = false; for child in node.children(&mut cursor) { if child.is_named() { + if pending_cast { + pending_cast = false; + let typename = self.format_expr(child); + if let Some(last) = parts.last_mut() { + last.push_str("::"); + last.push_str(&typename); + } + continue; + } parts.push(self.format_expr(child)); } else { let text = self.text(child).trim(); if !text.is_empty() { + if text == "::" { + pending_cast = true; + continue; + } let op = if text == "!=" { "<>" } else { text }; parts.push(op.to_string()); } @@ -363,14 +395,25 @@ impl<'a> Formatter<'a> { } "kw_exists" => self.kw("EXISTS"), "kw_row" => self.kw("ROW"), + "kw_array" => self.kw("ARRAY"), + "array_expr" => self.format_array_expr(child), _ if child.kind().starts_with("kw_") => self.format_keyword_node(child), _ => self.format_expr(child), }; - if paren_depth > 0 { - paren_parts.push(formatted); + // Merge ARRAY with following [...] bracket expression. + let target = if paren_depth > 0 { + &mut paren_parts } else { - parts.push(formatted); + &mut parts + }; + if formatted.starts_with('[') + && let Some(last) = target.last_mut() + && (*last == "ARRAY" || *last == "array") + { + last.push_str(&formatted); + continue; } + target.push(formatted); } else { let text = self.text(child).trim(); if text == "(" { @@ -384,7 +427,14 @@ impl<'a> Formatter<'a> { if paren_depth == 0 { // Close outermost paren group. let inner = paren_parts.join(" "); - parts.push(format!("({inner})")); + // Strip redundant parens around a single simple + // expression (column ref, literal) that doesn't + // contain operators or keywords. + if paren_parts.len() == 1 && !inner.contains(' ') && !inner.contains('\n') { + parts.push(inner); + } else { + parts.push(format!("({inner})")); + } paren_parts.clear(); } else { // Closing a nested paren. @@ -444,7 +494,18 @@ impl<'a> Formatter<'a> { let mut cursor = node.walk(); if let Some(child) = node.named_children(&mut cursor).next() { return match child.kind() { - "identifier" | "unreserved_keyword" => self.text(child).to_string(), + "identifier" => self.text(child).to_string(), + "unreserved_keyword" => { + // Special pseudo-variables like VALUE in domain CHECK + // constraints are conventionally lowercased regardless + // of keyword casing. + if let Some(kw) = child.named_children(&mut child.walk()).next() + && kw.kind() == "kw_value" + { + return "value".to_string(); + } + self.text(child).to_string() + } _ => self.format_expr(child), }; } @@ -506,7 +567,14 @@ impl<'a> Formatter<'a> { match node.kind() { "func_expr" => { if let Some(app) = node.find_child("func_application") { - return self.format_func(app); + let mut result = self.format_func(app); + // Check for OVER clause at the func_expr level + // (window functions like RANK() OVER (...)). + if let Some(over) = node.find_child("over_clause") { + result.push(' '); + result.push_str(&self.format_over_clause(over)); + } + return result; } // func_expr_common_subexpr or other variants. self.format_func_expr_common(node) @@ -568,7 +636,15 @@ impl<'a> Formatter<'a> { args }; - let mut result = format!("{cased_name}({inner})"); + // ANY, ALL, SOME are special SQL constructs that conventionally + // have a space before the opening paren. + let lower = cased_name.to_lowercase(); + let space = if lower == "any" || lower == "all" || lower == "some" { + " " + } else { + "" + }; + let mut result = format!("{cased_name}{space}({inner})"); if let Some(over) = over_clause { result.push(' '); @@ -579,10 +655,68 @@ impl<'a> Formatter<'a> { } fn format_func_expr_common(&self, node: Node<'a>) -> String { - // Handle COALESCE, GREATEST, LEAST, NULLIF, CURRENT_TIMESTAMP, etc. - let mut cursor = node.walk(); + // Check for func_expr_common_subexpr children. + let subexpr = node.find_child("func_expr_common_subexpr").unwrap_or(node); + + // CAST(expr AS type) → expr::type + // Parenthesize when the formatted operand contains spaces that indicate + // a compound expression (e.g. "a + b"), because the :: typecast operator + // has higher precedence than arithmetic operators in PostgreSQL. + // Simple expressions (column refs, literals, function calls, already- + // parenthesized expressions) do not need extra parens. + if subexpr.has_child("kw_cast") + && let Some(expr) = subexpr.find_child_any(&["a_expr", "c_expr"]) + && let Some(typename) = subexpr.find_child("Typename") + { + let formatted = self.format_expr(expr); + // Parenthesized expressions, function calls like fn(...), and + // expressions without spaces (simple identifiers/literals) are safe. + // Anything else (e.g. "a + b", "x IS NOT NULL") needs wrapping. + let needs_parens = + formatted.contains(' ') && !formatted.starts_with('(') && !formatted.contains('('); + return if needs_parens { + format!("({formatted})::{}", self.format_typename(typename)) + } else { + format!("{formatted}::{}", self.format_typename(typename)) + }; + } + + // Handle COALESCE, GREATEST, LEAST, NULLIF, etc. + // These are function-like: KEYWORD(args) + if let Some(expr_list) = subexpr.find_child("expr_list") { + let items = flatten_list(expr_list, "expr_list"); + let mut formatted: Vec = items.iter().map(|i| self.format_expr(*i)).collect(); + // Merge decimal fragments split by tree-sitter ERROR nodes + // (e.g., "0" + ".00" → "0.00"). + let mut i = 0; + while i + 1 < formatted.len() { + if formatted[i].chars().all(|c| c.is_ascii_digit()) + && formatted[i + 1].starts_with('.') + && formatted[i + 1][1..].chars().all(|c| c.is_ascii_digit()) + { + let merged = format!("{}{}", formatted[i], formatted[i + 1]); + formatted[i] = merged; + formatted.remove(i + 1); + } else { + i += 1; + } + } + // Find the keyword name. + let mut name = String::new(); + let mut cursor = subexpr.walk(); + for child in subexpr.named_children(&mut cursor) { + if child.kind().starts_with("kw_") { + name = self.kw(self.text(child)); + break; + } + } + return format!("{name}({})", formatted.join(", ")); + } + + // Other forms (CURRENT_TIMESTAMP, etc.). + let mut cursor = subexpr.walk(); let mut parts = Vec::new(); - for child in node.children(&mut cursor) { + for child in subexpr.children(&mut cursor) { if child.is_named() { match child.kind() { "func_application" => return self.format_func(child), @@ -617,11 +751,14 @@ impl<'a> Formatter<'a> { fn format_over_clause(&self, node: Node<'a>) -> String { let mut parts = vec![self.kw("OVER")]; - parts.push("(".to_string()); + parts.push(" (".to_string()); + + // The window_specification contains the actual partition/order clauses. + let spec = node.find_child("window_specification").unwrap_or(node); let mut inner = Vec::new(); - let mut cursor = node.walk(); - for child in node.named_children(&mut cursor) { + let mut cursor = spec.walk(); + for child in spec.named_children(&mut cursor) { match child.kind() { "opt_partition_clause" => { inner.push(self.format_partition_clause(child)); @@ -689,33 +826,85 @@ impl<'a> Formatter<'a> { } fn format_case_expr(&self, node: Node<'a>) -> String { - let mut parts = vec![self.kw("CASE")]; + let case_kw = self.kw("CASE"); + let end_kw = self.kw("END"); + let mut case_arg: Option = None; + let mut when_clauses: Vec = Vec::new(); + let mut else_parts: Option = None; + let mut cursor = node.walk(); for child in node.named_children(&mut cursor) { match child.kind() { "kw_case" | "kw_end" => {} "case_arg" => { if let Some(expr) = child.find_child_any(&["a_expr", "c_expr"]) { - parts.push(self.format_expr(expr)); + case_arg = Some(self.format_expr(expr)); } } "when_clause_list" => { let clauses = flatten_list(child, "when_clause_list"); for clause in clauses { - parts.push(self.format_when_clause(clause)); + when_clauses.push(self.format_when_clause(clause)); } } "case_default" => { if let Some(expr) = child.find_child_any(&["a_expr", "c_expr", "a_expr_prec"]) { - parts.push(self.kw("ELSE")); - parts.push(self.format_expr(expr)); + else_parts = + Some(format!("{} {}", self.kw("ELSE"), self.format_expr(expr))); } } _ => {} } } - parts.push(self.kw("END")); - parts.join(" ") + + // Build the single-line version first. + let mut inline_parts = vec![case_kw.clone()]; + if let Some(ref arg) = case_arg { + inline_parts.push(arg.clone()); + } + for wc in &when_clauses { + inline_parts.push(wc.clone()); + } + if let Some(ref ep) = else_parts { + inline_parts.push(ep.clone()); + } + inline_parts.push(end_kw.clone()); + let single_line = inline_parts.join(" "); + + // Wrap CASE when the style wraps CASE+ELSE and there's an ELSE clause, + // or when the single-line version is excessively long. + let should_wrap = if self.config.wrap_case_else && else_parts.is_some() { + true + } else { + single_line.len() > 120 + }; + if !should_wrap { + return single_line; + } + + // Multi-line: align WHEN/ELSE under the first WHEN, END indented 1 space. + // First line: "CASE [arg] WHEN ..." + // Continuation: " WHEN ..." (indent = len("CASE ") + len(arg + " ") if present) + let prefix = match &case_arg { + Some(arg) => format!("{case_kw} {arg} "), + None => format!("{case_kw} "), + }; + let when_indent = " ".repeat(prefix.len()); + let end_indent = " "; + + let mut lines = Vec::new(); + for (i, wc) in when_clauses.iter().enumerate() { + if i == 0 { + lines.push(format!("{prefix}{wc}")); + } else { + lines.push(format!("{when_indent}{wc}")); + } + } + if let Some(ep) = &else_parts { + lines.push(format!("{when_indent}{ep}")); + } + lines.push(format!("{end_indent}{end_kw}")); + lines.join("\n") } fn format_when_clause(&self, node: Node<'a>) -> String { @@ -734,6 +923,16 @@ impl<'a> Formatter<'a> { parts.join(" ") } + fn format_array_expr(&self, node: Node<'a>) -> String { + // array_expr: [ expr_list ] + if let Some(expr_list) = node.find_child("expr_list") { + let items = flatten_list(expr_list, "expr_list"); + let formatted: Vec<_> = items.iter().map(|i| self.format_expr(*i)).collect(); + return format!("[{}]", formatted.join(", ")); + } + self.text(node).to_string() + } + fn format_in_expr(&self, node: Node<'a>) -> String { let mut cursor = node.walk(); for child in node.named_children(&mut cursor) { @@ -861,13 +1060,14 @@ impl<'a> Formatter<'a> { let mut base = String::new(); let mut modifiers = String::new(); let mut extra_keywords = Vec::new(); + let mut timezone_keywords = Vec::new(); let mut cursor = node.walk(); for child in node.children(&mut cursor) { if child.is_named() { match child.kind() { "kw_integer" | "kw_int" | "kw_smallint" | "kw_bigint" | "kw_real" - | "kw_boolean" | "kw_float" | "kw_decimal" => { + | "kw_boolean" | "kw_float" | "kw_decimal" | "kw_numeric" => { base = self.map_type_name(&self.text(child).to_lowercase()); } "kw_double" => base = "DOUBLE".to_string(), @@ -879,16 +1079,31 @@ impl<'a> Formatter<'a> { } } "kw_varying" => extra_keywords.push("VARYING".to_string()), - "kw_with" => extra_keywords.push(self.kw("WITH")), - "kw_without" => extra_keywords.push(self.kw("WITHOUT")), + "opt_timezone" => { + // WITH/WITHOUT TIME ZONE — must appear after modifiers + // so TIMESTAMP(6) WITH TIME ZONE is correct, not + // TIMESTAMP WITH TIME ZONE(6). + let mut tz_cursor = child.walk(); + for tz_child in child.named_children(&mut tz_cursor) { + match tz_child.kind() { + "kw_with" => timezone_keywords.push(self.kw("WITH")), + "kw_without" => timezone_keywords.push(self.kw("WITHOUT")), + "kw_time" => timezone_keywords.push(self.kw("TIME")), + "kw_zone" => timezone_keywords.push(self.kw("ZONE")), + _ => {} + } + } + } + "kw_with" => timezone_keywords.push(self.kw("WITH")), + "kw_without" => timezone_keywords.push(self.kw("WITHOUT")), "kw_time" => { if base.is_empty() { base = self.kw("TIME"); } else { - extra_keywords.push(self.kw("TIME")); + timezone_keywords.push(self.kw("TIME")); } } - "kw_zone" => extra_keywords.push(self.kw("ZONE")), + "kw_zone" => timezone_keywords.push(self.kw("ZONE")), "kw_timestamp" => base = self.kw("TIMESTAMP"), "type_function_name" | "unreserved_keyword" => { let name = self.format_first_named_child(child); @@ -932,12 +1147,23 @@ impl<'a> Formatter<'a> { }; if !extra_keywords.is_empty() { - result.push(' '); + if !result.is_empty() { + result.push(' '); + } result.push_str(&extra_keywords.join(" ")); } if !modifiers.is_empty() { result.push_str(&modifiers); } + // Timezone qualifiers (WITH/WITHOUT TIME ZONE) must follow modifiers + // so that TIMESTAMP(6) WITH TIME ZONE is produced, not + // TIMESTAMP WITH TIME ZONE(6). + if !timezone_keywords.is_empty() { + if !result.is_empty() { + result.push(' '); + } + result.push_str(&timezone_keywords.join(" ")); + } result } @@ -997,6 +1223,7 @@ impl<'a> Formatter<'a> { match child.kind() { "ColId" => parts.push(self.format_col_id(child)), "indirection" => parts.push(self.format_indirection(child)), + "attrs" => parts.push(self.format_attrs(child)), "attr_name" => { // Schema-qualified: schema.name parts.push(format!(".{}", self.format_expr(child))); @@ -1029,12 +1256,22 @@ impl<'a> Formatter<'a> { { return self.format_alias(ac); } + let mut has_as = false; let mut parts = Vec::new(); let mut cursor = node.walk(); for child in node.named_children(&mut cursor) { match child.kind() { - "kw_as" => parts.push(self.kw("AS")), - "ColId" => parts.push(self.format_col_id(child)), + "kw_as" => { + has_as = true; + parts.push(self.kw("AS")); + } + "ColId" => { + // Always add AS keyword for bare aliases. + if !has_as { + parts.push(self.kw("AS")); + } + parts.push(self.format_col_id(child)); + } _ => parts.push(self.format_expr(child)), } } diff --git a/src/formatter/mod.rs b/src/formatter/mod.rs index 1b81e0c..70173fc 100644 --- a/src/formatter/mod.rs +++ b/src/formatter/mod.rs @@ -33,6 +33,8 @@ pub(crate) struct StyleConfig { pub blank_lines_in_ctes: bool, /// Strip INNER keyword from INNER JOIN (mattmc3: use plain JOIN). pub strip_inner_join: bool, + /// Wrap CASE expressions when ELSE is present (AWeber, mattmc3). + pub wrap_case_else: bool, } impl StyleConfig { @@ -50,6 +52,7 @@ impl StyleConfig { join_on_same_line: false, blank_lines_in_ctes: false, strip_inner_join: false, + wrap_case_else: false, }, Style::Mozilla => Self { upper_keywords: true, @@ -63,6 +66,7 @@ impl StyleConfig { join_on_same_line: false, blank_lines_in_ctes: false, strip_inner_join: false, + wrap_case_else: false, }, Style::Aweber => Self { upper_keywords: true, @@ -76,6 +80,7 @@ impl StyleConfig { join_on_same_line: false, blank_lines_in_ctes: false, strip_inner_join: false, + wrap_case_else: true, }, Style::Dbt => Self { upper_keywords: false, @@ -89,6 +94,7 @@ impl StyleConfig { join_on_same_line: false, blank_lines_in_ctes: false, strip_inner_join: false, + wrap_case_else: false, }, Style::Gitlab => Self { upper_keywords: true, @@ -102,6 +108,7 @@ impl StyleConfig { join_on_same_line: false, blank_lines_in_ctes: true, strip_inner_join: false, + wrap_case_else: false, }, Style::Kickstarter => Self { upper_keywords: true, @@ -115,6 +122,7 @@ impl StyleConfig { join_on_same_line: true, blank_lines_in_ctes: false, strip_inner_join: false, + wrap_case_else: false, }, Style::Mattmc3 => Self { upper_keywords: false, @@ -128,6 +136,7 @@ impl StyleConfig { join_on_same_line: false, blank_lines_in_ctes: false, strip_inner_join: true, + wrap_case_else: true, }, } } diff --git a/src/formatter/select.rs b/src/formatter/select.rs index 77a0fa4..a02a4a4 100644 --- a/src/formatter/select.rs +++ b/src/formatter/select.rs @@ -40,12 +40,20 @@ impl<'a> Formatter<'a> { /// Format a select_no_parens node. pub(crate) fn format_select_no_parens(&self, node: Node<'a>) -> String { + self.format_select_no_parens_with_min_width(node, 0) + } + + fn format_select_no_parens_with_min_width( + &self, + node: Node<'a>, + min_river_width: usize, + ) -> String { let clauses = self.collect_select_clauses(node); if clauses.values_clause.is_some() { return self.format_values_only(&clauses); } if self.config.river { - self.format_select_river(&clauses) + self.format_select_river_with_min_width(&clauses, min_river_width) } else { self.format_select_left_aligned(&clauses) } @@ -247,28 +255,59 @@ impl<'a> Formatter<'a> { // ── River-style SELECT ────────────────────────────────────────────── fn format_select_river(&self, clauses: &SelectClauses<'a>) -> String { + self.format_select_river_with_min_width(clauses, 0) + } + + fn format_select_river_with_min_width( + &self, + clauses: &SelectClauses<'a>, + min_width: usize, + ) -> String { let mut lines = Vec::new(); // Calculate river width from all keywords that will appear. let keywords = self.collect_river_keywords(clauses); - let river_width = keywords.iter().map(|k| k.len()).max().unwrap_or(6); - - // WITH clause. + // Don't apply min_width to set operations (UNION/INTERSECT/EXCEPT) + // as they format their own halves independently. + let effective_min = if clauses.set_op.is_some() { + 0 + } else { + min_width + }; + let river_width = keywords + .iter() + .map(|k| k.len()) + .max() + .unwrap_or(6) + .max(effective_min); + + // WITH clause. When inside a CTE body (min_width > 0), river-align + // the WITH keyword; at the top level, WITH starts at column 0. if let Some(with) = clauses.with_clause { - lines.push(self.format_with_clause_river(with, river_width)); + lines.push(self.format_with_clause_river_inner(with, river_width, min_width > 0)); } // SELECT [DISTINCT] targets. - let select_kw = if clauses.distinct.is_some() { + // For river alignment, use just SELECT as the keyword and prepend + // DISTINCT/DISTINCT ON to the target content so that river width + // is based on SELECT alone. + let select_kw = self.kw("SELECT"); + let distinct_prefix = if clauses.distinct.is_some() { let distinct_text = clauses .distinct .map(|d| self.format_distinct(d)) .unwrap_or_else(|| self.kw("DISTINCT")); - format!("{} {}", self.kw("SELECT"), distinct_text) + Some(distinct_text) } else { - self.kw("SELECT") + None }; - self.append_river_targets(&select_kw, &clauses.targets, river_width, &mut lines); + self.append_river_targets_with_prefix( + &select_kw, + distinct_prefix.as_deref(), + &clauses.targets, + river_width, + &mut lines, + ); // FROM clause with JOINs. if let Some(from) = clauses.from { @@ -459,20 +498,30 @@ impl<'a> Formatter<'a> { } } - /// Append target list items in river style. - fn append_river_targets( + /// Append target list items in river style with an optional prefix + /// (e.g. DISTINCT ON (...)) that goes before the first target but after + /// the river keyword. + fn append_river_targets_with_prefix( &self, select_kw: &str, + prefix: Option<&str>, targets: &[Node<'a>], width: usize, lines: &mut Vec, ) { + let prepend = |s: &str| -> String { + match prefix { + Some(p) => format!("{p} {s}"), + None => s.to_string(), + } + }; + if targets.is_empty() { - lines.push(self.river_line(select_kw, "*", width)); + lines.push(self.river_line(select_kw, &prepend("*"), width)); return; } - let first = self.format_target_el(targets[0]); + let first = prepend(&self.format_target_el(targets[0])); if targets.len() == 1 { lines.push(self.river_line(select_kw, &first, width)); return; @@ -1417,14 +1466,28 @@ impl<'a> Formatter<'a> { // ── WITH / CTE formatting ─────────────────────────────────────────── - fn format_with_clause_river(&self, node: Node<'a>, river_width: usize) -> String { + fn format_with_clause_river_inner( + &self, + node: Node<'a>, + river_width: usize, + river_align_with: bool, + ) -> String { let mut lines = Vec::new(); if let Some(cte_list) = node.find_child("cte_list") { let ctes = flatten_list(cte_list, "cte_list"); + // When WITH is river-aligned, river_line handles continuation + // indentation, so CTE bodies use their own width (pass 0). + // When WITH is at column 0 (top-level), CTE bodies inherit + // the outer river_width as a minimum. + let body_min_width = if river_align_with { 0 } else { river_width }; for (i, cte) in ctes.iter().enumerate() { - let cte_text = self.format_cte_river(*cte, river_width); + let cte_text = self.format_cte_river(*cte, body_min_width); if i == 0 { - lines.push(format!("{} {cte_text}", self.kw("WITH"))); + if river_align_with { + lines.push(self.river_line(&self.kw("WITH"), &cte_text, river_width)); + } else { + lines.push(format!("{} {cte_text}", self.kw("WITH"))); + } } else { lines.push(cte_text); } @@ -1433,23 +1496,24 @@ impl<'a> Formatter<'a> { lines.join(",\n") } - fn format_cte_river(&self, node: Node<'a>, _river_width: usize) -> String { + fn format_cte_river(&self, node: Node<'a>, river_width: usize) -> String { let name = node .find_child("name") .map(|n| self.format_expr(n)) .unwrap_or_default(); - let body = self.format_cte_body(node); + let body = self.format_cte_body(node, river_width); format!("{name} {} (\n{body}\n)", self.kw("AS")) } /// Extract and format the body of a CTE, handling SELECT, INSERT, UPDATE, /// DELETE, and any other PreparableStmt type. - fn format_cte_body(&self, node: Node<'a>) -> String { + fn format_cte_body(&self, node: Node<'a>, min_river_width: usize) -> String { if let Some(prep) = node.find_child("PreparableStmt") { if let Some(select) = prep.find_child("SelectStmt") { - return self.format_select_stmt(select); + let snp = select.find_child("select_no_parens").unwrap_or(select); + return self.format_select_no_parens_with_min_width(snp, min_river_width); } if let Some(insert) = prep.find_child("InsertStmt") { return self.format_insert_stmt(insert); @@ -1485,7 +1549,7 @@ impl<'a> Formatter<'a> { .map(|n| self.format_expr(n)) .unwrap_or_default(); - let body = self.format_cte_body(*cte); + let body = self.format_cte_body(*cte, 0); let indented_body = body .lines() diff --git a/src/formatter/stmt.rs b/src/formatter/stmt.rs index 3197a0f..0ffe5c1 100644 --- a/src/formatter/stmt.rs +++ b/src/formatter/stmt.rs @@ -386,10 +386,11 @@ impl<'a> Formatter<'a> { lines.push(")".to_string()); // WITH clause for storage parameters. + // OptWith already contains the WITH keyword, so just normalize. if let Some(with) = node.find_child("OptWith") { - let text = self.text(with); - if !text.trim().is_empty() { - lines.push(format!("{} {}", self.kw("WITH"), text.trim())); + let text = normalize_whitespace(self.text(with)); + if !text.is_empty() { + lines.push(text); } } @@ -411,8 +412,8 @@ impl<'a> Formatter<'a> { .unwrap_or_default(); let mut constraint_parts = Vec::new(); if let Some(qual_list) = col.find_child("ColQualList") { - let mut cursor = qual_list.walk(); - for child in qual_list.named_children(&mut cursor) { + let constraints = flatten_list(qual_list, "ColQualList"); + for child in constraints { if child.kind() == "ColConstraint" { constraint_parts.push(self.format_col_constraint(child)); } @@ -490,8 +491,8 @@ impl<'a> Formatter<'a> { // Column constraints. if let Some(qual_list) = node.find_child("ColQualList") { - let mut cursor = qual_list.walk(); - for child in qual_list.named_children(&mut cursor) { + let constraints = flatten_list(qual_list, "ColQualList"); + for child in constraints { if child.kind() == "ColConstraint" { parts.push(self.format_col_constraint(child)); } @@ -610,7 +611,7 @@ impl<'a> Formatter<'a> { .or_else(|| node.find_child("view_name")) .map(|n| self.format_qualified_name(n)) .unwrap_or_default(); - prefix = format!("{prefix} {name} {} ", self.kw("AS")); + prefix = format!("{prefix} {name} {}", self.kw("AS")); // The SELECT body. if let Some(select) = node.find_child("SelectStmt") { @@ -763,7 +764,24 @@ impl<'a> Formatter<'a> { } "func_as" => { // AS $$ ... $$ - parts.push(format!("{}\n{}", self.kw("AS"), self.text(child))); + // Preserve original line breaks in the body — collapsing + // newlines to spaces would break line-comment (--) semantics. + // For single-line bodies, normalize whitespace (safe since + // there are no newlines that could affect -- comments). + // For multi-line bodies, re-indent each line to preserve + // the original structure. + let text = self.text(child).trim(); + if let Some((tag, body)) = parse_dollar_quoted(text) { + if body.contains('\n') { + let body = reindent_body(body, " "); + parts.push(format!("{} {tag}\n{body}\n{tag}", self.kw("AS"))); + } else { + let body = normalize_whitespace(body); + parts.push(format!("{} {tag}\n {body}\n{tag}", self.kw("AS"))); + } + } else { + parts.push(format!("{} {text}", self.kw("AS"))); + } } _ => {} } @@ -807,9 +825,136 @@ impl<'a> Formatter<'a> { // ── CREATE FOREIGN TABLE ──────────────────────────────────────────── fn format_create_foreign_table_stmt(&self, node: Node<'a>) -> String { - // Similar to CREATE TABLE but with SERVER and OPTIONS. - let text = self.text(node); - normalize_whitespace(text) + let table_name = node + .find_child("qualified_name") + .map(|n| self.format_qualified_name(n)) + .unwrap_or_default(); + + let mut lines = Vec::new(); + lines.push(format!( + "{} {} {} {table_name} (", + self.kw("CREATE"), + self.kw("FOREIGN"), + self.kw("TABLE") + )); + + // Column definitions (same as CREATE TABLE). + if let Some(elem_list) = node + .find_child("OptTableElementList") + .and_then(|n| n.find_child("TableElementList")) + { + let elements = flatten_list(elem_list, "TableElementList"); + let indent = self.config.indent; + + if self.config.river { + // Classify all elements, keeping their original order. + let classified: Vec<_> = elements + .iter() + .map(|e| self.classify_table_element(*e)) + .collect(); + + // Compute column-alignment widths from Column elements only. + let max_name_len = classified + .iter() + .filter_map(|e| { + if let TableElementKind::Column(n, _, _) = e { + Some(n.len()) + } else { + None + } + }) + .max() + .unwrap_or(0); + let max_type_len = classified + .iter() + .filter_map(|e| { + if let TableElementKind::Column(_, t, _) = e { + Some(t.len()) + } else { + None + } + }) + .max() + .unwrap_or(0); + + let total = classified.len(); + for (i, elem) in classified.iter().enumerate() { + let comma = if i < total - 1 { "," } else { "" }; + match elem { + TableElementKind::Column(name, typename, constraints) => { + let padded_name = format!("{:width$}", name, width = max_name_len); + let padded_type = format!("{:width$}", typename, width = max_type_len); + let mut item = format!("{padded_name} {padded_type}"); + if !constraints.is_empty() { + item = format!("{item} {constraints}"); + } + lines.push(format!("{indent}{item}{comma}")); + } + TableElementKind::PrimaryKey(text) + | TableElementKind::Constraint(_, text) => { + let text = match elem { + TableElementKind::Constraint(Some(name), body) => { + format!("{} {name} {body}", self.kw("CONSTRAINT")) + } + _ => text.clone(), + }; + lines.push(format!("{indent}{text}{comma}")); + } + } + } + } else { + let formatted: Vec<_> = elements + .iter() + .map(|e| self.format_table_element(*e)) + .collect(); + for (i, elem) in formatted.iter().enumerate() { + let comma = if i < formatted.len() - 1 { "," } else { "" }; + lines.push(format!("{indent}{elem}{comma}")); + } + } + } + + lines.push(")".to_string()); + + // SERVER name. + if let Some(server_name) = node.find_child("name") { + lines.push(format!( + "{} {}", + self.kw("SERVER"), + self.format_expr(server_name) + )); + } + + // OPTIONS (...). + if let Some(opts) = node.find_child("create_generic_options") { + self.format_generic_options(opts, &mut lines); + } + + lines.join("\n") + } + + fn format_generic_options(&self, node: Node<'a>, lines: &mut Vec) { + if let Some(opt_list) = node.find_child("generic_option_list") { + let items = flatten_list(opt_list, "generic_option_list"); + let indent = self.config.indent; + + lines.push(format!("{} (", self.kw("OPTIONS"))); + for (i, item) in items.iter().enumerate() { + let formatted = self.format_generic_option(*item); + let comma = if i < items.len() - 1 { "," } else { "" }; + lines.push(format!("{indent}{formatted}{comma}")); + } + lines.push(")".to_string()); + } + } + + fn format_generic_option(&self, node: Node<'a>) -> String { + let mut parts = Vec::new(); + let mut cursor = node.walk(); + for child in node.named_children(&mut cursor) { + parts.push(self.format_expr(child)); + } + parts.join(" ") } // ── Helpers ───────────────────────────────────────────────────────── @@ -874,6 +1019,60 @@ impl<'a> Formatter<'a> { // format_where_river and format_where_left_aligned are defined in select.rs } +/// Parse a dollar-quoted string into (tag, body). +/// E.g., `$$ body $$` → Some(("$$", " body ")) +/// E.g., `$fn$ body $fn$` → Some(("$fn$", " body ")) +fn parse_dollar_quoted(s: &str) -> Option<(&str, &str)> { + if !s.starts_with('$') { + return None; + } + // Find the end of the opening tag. + let tag_end = s[1..].find('$')? + 2; // +1 for the inner offset, +1 for the closing $ + let tag = &s[..tag_end]; + let rest = &s[tag_end..]; + // Find the closing tag. + let body_end = rest.rfind(tag)?; + let body = &rest[..body_end]; + Some((tag, body)) +} + +/// Re-indent a multi-line body (e.g., a dollar-quoted function body) so that +/// each non-empty line starts with the given `indent` prefix. Strips leading +/// and trailing blank lines, and removes the common leading whitespace from all +/// lines before applying the new indent. +fn reindent_body(s: &str, indent: &str) -> String { + let lines: Vec<&str> = s.lines().collect(); + // Skip leading/trailing empty lines. + let start = lines.iter().position(|l| !l.trim().is_empty()).unwrap_or(0); + let end = lines + .iter() + .rposition(|l| !l.trim().is_empty()) + .map(|i| i + 1) + .unwrap_or(lines.len()); + let body_lines = &lines[start..end]; + if body_lines.is_empty() { + return String::new(); + } + // Determine the minimum leading whitespace across non-empty lines. + let min_indent = body_lines + .iter() + .filter(|l| !l.trim().is_empty()) + .map(|l| l.len() - l.trim_start().len()) + .min() + .unwrap_or(0); + body_lines + .iter() + .map(|line| { + if line.trim().is_empty() { + String::new() + } else { + format!("{indent}{}", &line[min_indent..]) + } + }) + .collect::>() + .join("\n") +} + /// Collapse runs of whitespace to single spaces, but preserve whitespace /// inside single-quoted strings, double-quoted identifiers, and dollar-quoted /// strings so that literal content is not altered. diff --git a/src/lib.rs b/src/lib.rs index eed316e..9b2f468 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,9 +17,9 @@ //! //! The [`Style`] enum provides 7 formatting variants: //! -//! - **River** (default) — keywords right-aligned to form a visual river +//! - **River** — keywords right-aligned to form a visual river //! - **Mozilla** — keywords left-aligned, content indented 4 spaces -//! - **Aweber** — river with JOINs in keyword alignment +//! - **AWeber** (default) — river with JOINs in keyword alignment //! - **Dbt** — lowercase keywords, blank lines between clauses //! - **Gitlab** — 2-space indent, uppercase keywords //! - **Kickstarter** — 2-space indent, compact JOINs @@ -68,7 +68,7 @@ use tree_sitter_postgres::{LANGUAGE, LANGUAGE_PLPGSQL}; /// ``` /// use libpgfmt::{format, style::Style}; /// -/// // River style (default) +/// // River style /// let result = format("SELECT id FROM users WHERE active = TRUE", Style::River).unwrap(); /// assert_eq!(result, "SELECT id\n FROM users\n WHERE active = TRUE;"); /// diff --git a/src/style.rs b/src/style.rs index e349ae9..f2f2076 100644 --- a/src/style.rs +++ b/src/style.rs @@ -4,7 +4,7 @@ use std::str::FromStr; /// SQL formatting style. /// /// Each variant implements a different layout strategy for SQL statements. -/// Use [`Style::default()`] for the river style, or parse from a string: +/// Use [`Style::default()`] for the [AWeber style](https://gist.github.com/gmr/2cceb85bb37be96bc96f05c5b8de9e1b), or parse from a string: /// /// ``` /// use libpgfmt::style::Style; @@ -20,11 +20,11 @@ use std::str::FromStr; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] pub enum Style { /// Simon Holywell's river style — keywords right-aligned to form a visual river. - #[default] River, /// Mozilla style — keywords left-aligned, content indented 4 spaces. Mozilla, /// AWeber style — river style with JOINs participating in keyword alignment. + #[default] Aweber, /// dbt style — Mozilla-like with lowercase keywords and blank lines between clauses. Dbt, diff --git a/tests/all_fixtures.rs b/tests/all_fixtures.rs index 852b5e4..cbb1c8b 100644 --- a/tests/all_fixtures.rs +++ b/tests/all_fixtures.rs @@ -29,20 +29,6 @@ fn run_fixture(style: Style, style_name: &str, name: &str) { } } -/// Known fixtures that don't match expected output yet due to grammar -/// limitations or incomplete formatting support. These parse successfully -/// but produce different output than the pgfmt reference. -const KNOWN_FAILING: &[&str] = &[ - "river/create_domain", - "river/create_foreign_table", - "river/create_function", - "river/create_matview", - "river/create_table_with", - "river/create_view_cte", - "aweber/select_case_join", - "aweber/select_cte_nested", -]; - /// Discover all .sql files in each style directory and run them. #[test] fn all_fixture_pairs() { @@ -81,7 +67,6 @@ fn all_fixture_pairs() { continue; } let fixture_key = format!("{style_name}/{stem}"); - let is_known_failing = KNOWN_FAILING.contains(&fixture_key.as_str()); total += 1; let result = std::panic::catch_unwind(|| { run_fixture(*style, style_name, &stem); @@ -89,30 +74,22 @@ fn all_fixture_pairs() { match result { Ok(()) => { passed += 1; - if is_known_failing { - eprintln!("UNEXPECTED PASS {fixture_key}: remove from KNOWN_FAILING"); - } } Err(e) => { - if is_known_failing { - eprintln!("EXPECTED FAIL {fixture_key}"); - passed += 1; // Don't count as failure. + let msg = if let Some(s) = e.downcast_ref::() { + s.clone() + } else if let Some(s) = e.downcast_ref::<&str>() { + s.to_string() + } else { + "unknown panic".to_string() + }; + let short = if msg.chars().count() > 200 { + let truncated: String = msg.chars().take(200).collect(); + format!("{truncated}...") } else { - let msg = if let Some(s) = e.downcast_ref::() { - s.clone() - } else if let Some(s) = e.downcast_ref::<&str>() { - s.to_string() - } else { - "unknown panic".to_string() - }; - let short = if msg.chars().count() > 200 { - let truncated: String = msg.chars().take(200).collect(); - format!("{truncated}...") - } else { - msg - }; - failures.push(format!("{fixture_key}: {short}")); - } + msg + }; + failures.push(format!("{fixture_key}: {short}")); } } } diff --git a/tests/fixtures/river/create_matview.expected b/tests/fixtures/river/create_matview.expected index a419701..7f1a7df 100644 --- a/tests/fixtures/river/create_matview.expected +++ b/tests/fixtures/river/create_matview.expected @@ -11,19 +11,15 @@ CREATE MATERIALIZED VIEW report.service_subscription_info AS FROM public.subscriptions AS sp JOIN public.plans AS p USING (plan_id) - - INNER JOIN public.plan_details AS pd - USING (plan_id) - - INNER JOIN public.pricing_fees AS pf - USING (plan_detail_id) - - INNER JOIN public.features AS f - USING (feature_id) - - INNER JOIN public.billing_terms AS bt - USING (billing_term_id) - WHERE sp.cancelled_at IS NULL OR sp.cancelled_at > CURRENT_TIMESTAMP + JOIN public.plan_details AS pd + USING (plan_id) + JOIN public.pricing_fees AS pf + USING (plan_detail_id) + JOIN public.features AS f + USING (feature_id) + JOIN public.billing_terms AS bt + USING (billing_term_id) + WHERE (sp.cancelled_at IS NULL OR sp.cancelled_at > CURRENT_TIMESTAMP) AND f.feature_name = 'notifications'::TEXT ORDER BY sp.cancelled_at DESC, sp.started_at WITH NO DATA;