[Lang] Add qd.math.fma(...) single-rounding fused multiply-add #478

Open
duburcqa wants to merge 39 commits into duburcqa/qd_precise from duburcqa/qd_math_fma

b66c9d0
Claude / Claude Code Review completed Apr 21, 2026 in 14m 30s

Code review found 3 important issues

Found 6 candidates, confirmed 4. See review comments for details.

Details

| Severity | Count |
| --- | --- |
| 🔴 Important | 3 |
| 🟡 Nit | 0 |
| 🟣 Pre-existing | 0 |

| Severity | File:Line | Issue |
| --- | --- | --- |
| 🔴 Important | quadrants/ir/stmt_op_types.h:61 | qd.math.fma crashes autodiff: auto_diff.cpp asserts op_type == select |
| 🔴 Important | quadrants/transforms/scalarize.cpp:272-280 | scalarize.cpp: precise flag silently dropped when scalarizing tensor-typed fma |
| 🔴 Important | quadrants/ir/statements.h:296 | TernaryOpStmt QD_STMT_DEF_FIELDS missing op_type: CSE silently conflates fma and select |

Annotations

Check failure on line 61 in quadrants/ir/stmt_op_types.h

@claude claude / Claude Code Review

qd.math.fma crashes autodiff: auto_diff.cpp asserts op_type == select

The PR adds `TernaryOpType::fma` to the enum but never updates `auto_diff.cpp`, which contains `QD_ASSERT(stmt->op_type == TernaryOpType::select)` in both the reverse-mode visitor (line 1326) and the forward-mode visitor (line 1916) — causing a hard process abort whenever an `fma` node is encountered during gradient lowering. Any kernel that calls `qd.math.fma(...)` inside `qd.ad.Tape()` or `forward_grad()` will crash with a failed assertion rather than computing gradients; additionally, `NonLin
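
As a sketch of the gradient rule that the autodiff visitors would need, the reverse-mode derivative of `fma(a, b, c) = a * b + c` is the textbook one: the incoming gradient is scaled by `b` for `a`, by `a` for `b`, and passed through unchanged for `c`. The Python below is a toy model only, not the code in `auto_diff.cpp`; the function `ternary_backward` and the string op names are hypothetical.

```python
def ternary_backward(op_type, a, b, c, grad_out):
    """Toy reverse-mode rule for a ternary op (hypothetical helper).

    Returns the gradients with respect to (a, b, c) given the gradient
    of the output, grad_out.
    """
    if op_type == "select":
        # select(cond, x, y): the condition gets no gradient, and the
        # output gradient flows only to the branch that was chosen.
        return (0.0, grad_out if a else 0.0, 0.0 if a else grad_out)
    if op_type == "fma":
        # fma(a, b, c) = a * b + c
        # d/da = b, d/db = a, d/dc = 1
        return (grad_out * b, grad_out * a, grad_out)
    # Fail with a clear error for ops without a rule, rather than
    # assuming every ternary op is a select.
    raise NotImplementedError(f"no gradient rule for ternary op {op_type!r}")


print(ternary_backward("fma", 2.0, 3.0, 5.0, 1.0))
```

Dispatching on the op type and raising for unknown ops is one possible design; it turns the unhandled case into a reportable error instead of a process abort.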

Check failure on line 280 in quadrants/transforms/scalarize.cpp

@claude claude / Claude Code Review

scalarize.cpp: precise flag silently dropped when scalarizing tensor-typed fma

The `precise` flag is silently dropped when scalarizing a tensor-typed `TernaryOpStmt` (fma): both loops in `scalarize.cpp` at lines 619 and 655 create scalar `TernaryOpStmt` nodes but never copy `stmt->precise` onto them. A `qd.precise(fma(a, b, c))` expression over a `TensorType` is therefore scalarized into N elements each with `precise=false`, causing the LLVM backend to emit `fmuladd` (optimizable to separate mul+add) instead of the strict `llvm.fma` intrinsic, silently breaking the single-
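
To illustrate why the distinction matters numerically, the sketch below compares a separate multiply and add (two roundings) with a fused multiply-add (one rounding) in double precision. It does not use Quadrants or LLVM; the single-rounding result is emulated with exact rational arithmetic from Python's standard `fractions` module, and the helper name `fma_single_rounding` is hypothetical.

```python
from fractions import Fraction


def fma_single_rounding(a, b, c):
    """Emulate a fused multiply-add on Python floats (hypothetical helper).

    a * b + c is computed exactly as a rational number and rounded to a
    double only once, at the end.
    """
    return float(Fraction(a) * Fraction(b) + Fraction(c))


a = 1.0 + 2.0**-30
b = 1.0 - 2.0**-30
c = -1.0

# Separate mul + add: the exact product 1 - 2**-60 rounds to 1.0 first,
# so the small term is lost before the addition.
unfused = a * b + c

# Fused: the exact value -2**-60 is representable and survives.
fused = fma_single_rounding(a, b, c)

print(unfused, fused)
```

With these inputs the unfused result is exactly zero while the fused result is `-2**-60`, so code that relies on the fused result (for example in error-compensated arithmetic) changes behavior if the operation is split.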

Check failure on line 296 in quadrants/ir/statements.h

@claude claude / Claude Code Review

TernaryOpStmt QD_STMT_DEF_FIELDS missing op_type: CSE silently conflates fma and select

`TernaryOpStmt`'s `QD_STMT_DEF_FIELDS` is updated to `(ret_type, op1, op2, op3, precise)` but `op_type` is still absent, while `BinaryOpStmt` correctly includes `op_type` in the same PR. CSE equality checks `typeid` (both `fma` and `select` are `TernaryOpStmt`), then `field_manager.equal()` (only `ret_type` and `precise`), then operand pointers, so `fma(a, b, c)` and `select(a, b, c)` with identical `Stmt*` operands and the same `ret_type` compare as equal and are silently collapsed into one, producing incorrect output with no
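
The conflation described above can be reproduced in a small toy model. The Python below is not the Quadrants IR: `TernaryOp` and `cse_key` are hypothetical stand-ins that mimic the three comparison stages named in the annotation (statement class, registered fields, operands), with operands represented by plain strings instead of `Stmt*` pointers.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TernaryOp:
    """Toy stand-in for a ternary IR statement (hypothetical)."""
    op_type: str       # "fma" or "select"
    ret_type: str
    operands: tuple    # stands in for the three operand pointers
    precise: bool = False


def cse_key(stmt, include_op_type):
    """Build the key a toy CSE pass would use to detect duplicates."""
    fields = (stmt.ret_type, stmt.precise)
    if include_op_type:
        # Registering op_type as a compared field keeps different
        # operations apart.
        fields = (stmt.op_type,) + fields
    return (type(stmt).__name__, fields, stmt.operands)


fma = TernaryOp("fma", "f32", ("a", "b", "c"))
sel = TernaryOp("select", "f32", ("a", "b", "c"))

# Without op_type the two statements look identical and would be merged.
print(cse_key(fma, include_op_type=False) == cse_key(sel, include_op_type=False))
# With op_type they are distinguished.
print(cse_key(fma, include_op_type=True) == cse_key(sel, include_op_type=True))
```

In this model the first comparison reports the statements as equal and the second does not, which mirrors the effect of adding `op_type` to the compared fields.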