[Lang] Add qd.math.fma(...) single-rounding fused multiply-add#478
Open
duburcqa wants to merge 39 commits into
Open
[Lang] Add qd.math.fma(...) single-rounding fused multiply-add#478duburcqa wants to merge 39 commits into
duburcqa wants to merge 39 commits into
Claude / Claude Code Review
completed
Apr 21, 2026 in 14m 30s
Code review found 3 important issues
Found 6 candidates, confirmed 4. See review comments for details.
Details
| Severity | Count |
|---|---|
| 🔴 Important | 3 |
| 🟡 Nit | 0 |
| 🟣 Pre-existing | 0 |
| Severity | File:Line | Issue |
|---|---|---|
| 🔴 Important | quadrants/ir/stmt_op_types.h:61 |
qd.math.fma crashes autodiff: auto_diff.cpp asserts op_type == select |
| 🔴 Important | quadrants/transforms/scalarize.cpp:272-280 |
scalarize.cpp: precise flag silently dropped when scalarizing tensor-typed fma |
| 🔴 Important | quadrants/ir/statements.h:296 |
TernaryOpStmt QD_STMT_DEF_FIELDS missing op_type: CSE silently conflates fma and select |
Annotations
Check failure on line 61 in quadrants/ir/stmt_op_types.h
claude / Claude Code Review
qd.math.fma crashes autodiff: auto_diff.cpp asserts op_type == select
The PR adds `TernaryOpType::fma` to the enum but never updates `auto_diff.cpp`, which contains `QD_ASSERT(stmt->op_type == TernaryOpType::select)` in both the reverse-mode visitor (line 1326) and the forward-mode visitor (line 1916) — causing a hard process abort whenever an `fma` node is encountered during gradient lowering. Any kernel that calls `qd.math.fma(...)` inside `qd.ad.Tape()` or `forward_grad()` will crash with a failed assertion rather than computing gradients; additionally, `NonLin
Check failure on line 280 in quadrants/transforms/scalarize.cpp
claude / Claude Code Review
scalarize.cpp: precise flag silently dropped when scalarizing tensor-typed fma
The `precise` flag is silently dropped when scalarizing a tensor-typed `TernaryOpStmt` (fma): both loops in `scalarize.cpp` at lines 619 and 655 create scalar `TernaryOpStmt` nodes but never copy `stmt->precise` onto them. A `qd.precise(fma(a, b, c))` expression over a `TensorType` is therefore scalarized into N elements each with `precise=false`, causing the LLVM backend to emit `fmuladd` (optimizable to separate mul+add) instead of the strict `llvm.fma` intrinsic, silently breaking the single-
Check failure on line 296 in quadrants/ir/statements.h
claude / Claude Code Review
TernaryOpStmt QD_STMT_DEF_FIELDS missing op_type: CSE silently conflates fma and select
TernaryOpStmt's QD_STMT_DEF_FIELDS is updated to (ret_type, op1, op2, op3, precise) but op_type is still absent, while BinaryOpStmt correctly includes op_type in the same PR. CSE equality checks typeid (both fma and select are TernaryOpStmt), then field_manager.equal() (only ret_type and precise), then operand pointers — so fma(a, b, c) and select(a, b, c) with identical Stmt* operands and the same ret_type compare as equal and are silently collapsed into one, producing incorrect output with no
Loading