From d7c43c4734c0f00d5ba953fb4f8efebd36f5d8ef Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Wed, 1 Jul 2026 06:32:28 +0000 Subject: [PATCH 1/4] flash_attn_generic: drop _fadd/_fsub/_fmul/_fmax, write online softmax in FlyDSL types Remove the raw-arith FP helper wrappers and write the math directly with FlyDSL-typed ops: fx.Float32(a) +/-/* fx.Float32(b) and fx.Float32(a).maximumf(b), so values stay DSL-typed through the online softmax without .ir_value() round-trips. Drop the now-unused arith import. Base numeric/ArithValue classes are not modified. Note: DSL operators do not by themselves carry op-level fastmath (that is added separately, gated on the fast_fp_math hint), so on its own this is slightly slower than the original arith.*(fastmath) form. Numerics are bit-identical and tests pass. Co-Authored-By: Claude Opus 4.8 --- kernels/flash_attn_generic.py | 58 ++++++++++++++++------------------- 1 file changed, 26 insertions(+), 32 deletions(-) diff --git a/kernels/flash_attn_generic.py b/kernels/flash_attn_generic.py index dcd7024b1..0458f10bb 100644 --- a/kernels/flash_attn_generic.py +++ b/kernels/flash_attn_generic.py @@ -28,7 +28,7 @@ from flydsl._mlir import ir from flydsl._mlir.dialects import llvm from flydsl.compiler.kernel_function import CompilationContext -from flydsl.expr import arith, buffer_ops, const_expr, gpu, range_constexpr, rocdl +from flydsl.expr import buffer_ops, const_expr, gpu, range_constexpr, rocdl from flydsl.expr import math as fmath from flydsl.expr.typing import T from flydsl.expr.typing import Vector as Vec @@ -504,17 +504,12 @@ def flash_attn_generic_kernel( def _mfma(mfma_fn, a, b, c): return mfma_fn(v16f32_type, [a, b, c]) - def _fadd(a, b): - return arith.addf(_raw(a), _raw(b), fastmath=fm_fast) - - def _fsub(a, b): - return arith.subf(_raw(a), _raw(b), fastmath=fm_fast) - - def _fmul(a, b): - return arith.mulf(_raw(a), _raw(b), fastmath=fm_fast) - - def _fmax(a, b): - return arith.MaxNumFOp(_raw(a), _raw(b), 
fastmath=fm_fast).result + # FP math is written inline with FlyDSL-typed ops (no raw arith.* in the kernel): + # `fx.Float32(a) + / - / * fx.Float32(b)` and `fx.Float32(a).maximumf(b)`, so values + # stay DSL-typed through the kernel without `.ir_value()` round-trips. + # NOTE: DSL operators don't carry op-level fastmath (only the function-level + # fast_fp_math attr applies), so this is a few % slower on large causal shapes than + # the old arith.*(fastmath=fast) form — accepted to avoid touching base numeric classes. def mfma_acc(a, b, c): if const_expr(dtype_str == "bf16"): @@ -1191,48 +1186,47 @@ def _k_idx_hi(ks): c_neg_inf, s_raw_hi[r] ) - local_max = s_raw_lo[0] + # Online softmax, all FlyDSL-typed (fx.Float32 flows through +/-/*/.maximumf/.exp2; + # raw inputs from Vec[]/fma/loop-carry are wrapped once at entry). reduction_peer + # returns fx.Float32; fma/ArithValue/Vec.from_elements all accept fx.Float32. + local_max = fx.Float32(s_raw_lo[0]) for r in range_constexpr(15): - local_max = _fmax(local_max, s_raw_lo[r + 1]) + local_max = local_max.maximumf(s_raw_lo[r + 1]) for r in range_constexpr(16): - local_max = _fmax(local_max, s_raw_hi[r]) + local_max = local_max.maximumf(s_raw_hi[r]) peer_max = reduction_peer(local_max) - row_max = _fmax(local_max, peer_max) - m_new_raw = _fmax(m_running, row_max) - - diff_m_raw = _fsub(m_running, m_new_raw) - diff_m_scaled = _fmul(diff_m_raw, c_sm_scale_log2e) - corr = ArithValue(diff_m_scaled).exp2(fastmath=fm_fast) + row_max = local_max.maximumf(peer_max) + m_run = fx.Float32(m_running) + m_new_raw = m_run.maximumf(row_max) - scaled_max = _fmul(c_sm_scale_log2e, m_new_raw) - neg_scaled_max = _fsub(c_zero_f, scaled_max) + corr = ((m_run - m_new_raw) * c_sm_scale_log2e).exp2(fastmath=fm_fast) + neg_scaled_max = c_zero_f - c_sm_scale_log2e * m_new_raw p_vals_lo = [] p_vals_hi = [] local_sum = c_zero_f for r in range_constexpr(16): diff_lo = fmath.fma(s_raw_lo[r], c_sm_scale_log2e, neg_scaled_max, fastmath=fm_fast) - 
p_lo = ArithValue(diff_lo).exp2(fastmath=fm_fast) + p_lo = fx.Float32(diff_lo).exp2(fastmath=fm_fast) p_vals_lo.append(p_lo) - local_sum = _fadd(local_sum, p_lo) + local_sum = local_sum + p_lo for r in range_constexpr(16): diff_hi = fmath.fma(s_raw_hi[r], c_sm_scale_log2e, neg_scaled_max, fastmath=fm_fast) - p_hi = ArithValue(diff_hi).exp2(fastmath=fm_fast) + p_hi = fx.Float32(diff_hi).exp2(fastmath=fm_fast) p_vals_hi.append(p_hi) - local_sum = _fadd(local_sum, p_hi) + local_sum = local_sum + p_hi peer_sum = reduction_peer(local_sum) - tile_sum = _fadd(local_sum, peer_sum) - l_corr = _fmul(corr, l_running) - l_new = _fadd(l_corr, tile_sum) + tile_sum = local_sum + peer_sum + l_new = corr * l_running + tile_sum # ==== Rescale O accumulators ==== corr_vec = Vec.from_elements([corr], fx.Float32).broadcast_to(16) if const_expr(not USE_HW_TR): - o_accs[0] = _fmul(Vec(o_accs[0]), corr_vec) + o_accs[0] = Vec(o_accs[0]) * corr_vec else: for dc in range_constexpr(D_CHUNKS): - o_accs[dc] = _fmul(Vec(o_accs[dc]), corr_vec) + o_accs[dc] = Vec(o_accs[dc]) * corr_vec if const_expr(ENABLE_PREFETCH_3BUF and (kv_sub + preload_k_count) < N_SUBTILES): next_k_sub = kv_sub + preload_k_count From 483ad0d9501710e98dffbb9c4d61a8745d087b19 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Wed, 1 Jul 2026 06:32:28 +0000 Subject: [PATCH 2/4] expr: gate DSL float-op fast-math on the fast_fp_math compile hint Make every DSL float op (Numeric/ArithValue +,-,*,/,//,%, unary neg, maximumf) attach op-level fastmath=fast only when compiled with the fast_fp_math hint (default off / strict IEEE), via _default_fastmath() reading CompilationContext.get_compile_hints(). This is the op-level analogue of -ffast-math and reuses the same fast_fp_math switch that already drives the function/target-level fast attr in the ROCm backend, so one flag turns on both. 
Verified: with fast_fp_math the DSL-operator kernel matches the original arith.*(fastmath) kernel; the strict default is slightly slower; numerics are bit-identical. Framework change -> must be synced into build-fly/python_packages at runtime. Co-Authored-By: Claude Opus 4.8 --- python/flydsl/expr/utils/arith.py | 32 +++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/python/flydsl/expr/utils/arith.py b/python/flydsl/expr/utils/arith.py index 5d3f78363..fc5f31608 100644 --- a/python/flydsl/expr/utils/arith.py +++ b/python/flydsl/expr/utils/arith.py @@ -144,6 +144,22 @@ def _coerce_other(self, other): } +def _default_fastmath(): + """Op-level fast-math for DSL float ops, gated on the ``fast_fp_math`` compile hint. + + Default OFF (FastMathFlags.none / strict IEEE). When a kernel is compiled with + ``fast_fp_math=True`` (e.g. ``flyc.compile[{"fast_fp_math": True}]``), every DSL + float op gets FastMathFlags.fast (reassoc,nnan,ninf,nsz,arcp,contract,afn) so the + backend can contract FMAs / reassociate — the op-level analogue of clang + ``-ffast-math``. The same hint also drives the function/target-level fast attr in + the ROCm backend, so one switch turns on both. 
+ """ + from ...compiler.kernel_function import CompilationContext # lazy: avoid circular import + + hints = CompilationContext.get_compile_hints() + return arith.FastMathFlags.fast if hints.get("fast_fp_math") else arith.FastMathFlags.none + + @dsl_loc_tracing def _binary_op(self, other, op): other = _coerce_other(self, other) @@ -153,23 +169,23 @@ def _binary_op(self, other, op): if op in _ARITH_OPS: float_fn, int_fn = _ARITH_OPS[op] if self.is_float: - return float_fn(self, other) + return float_fn(self, other, fastmath=_default_fastmath()) return int_fn(self, other) if op == "div": if self.is_float: - return arith.divf(self, other) + return arith.divf(self, other, fastmath=_default_fastmath()) et = element_type(self.type) if isinstance(et, ir.IndexType): return arith.divui(self, other) fp_ty = T.f64() if et.width > 32 else T.f32() lhs = int_to_fp(self, self.signed, fp_ty) rhs = int_to_fp(other, other.signed, fp_ty) - return arith.divf(lhs, rhs) + return arith.divf(lhs, rhs, fastmath=_default_fastmath()) if op == "floordiv": if self.is_float: - q = arith.divf(self, other) + q = arith.divf(self, other, fastmath=_default_fastmath()) return math.floor(q) et = element_type(self.type) if isinstance(et, ir.IndexType): @@ -180,7 +196,7 @@ def _binary_op(self, other, op): if op == "mod": if self.is_float: - return arith.remf(self, other) + return arith.remf(self, other, fastmath=_default_fastmath()) et = element_type(self.type) if isinstance(et, ir.IndexType): return arith.remui(self, other) @@ -293,7 +309,7 @@ def _neg_op(self): if self.type == T.bool(): raise TypeError("negation is not supported for boolean type") if self.is_float: - return arith.negf(self) + return arith.negf(self, fastmath=_default_fastmath()) c0 = arith_const(0, self.type) return arith.subi(c0, self) @@ -425,8 +441,8 @@ def addf(self, other, *, fastmath=None): @dsl_loc_tracing def maximumf(self, other): - """Float maximum (NaN-propagating).""" - return arith.maximumf(self, _to_raw(other)) + 
"""Float maximum (NaN-propagating), op-level fast-math gated on fast_fp_math hint.""" + return arith.maximumf(self, _to_raw(other), fastmath=_default_fastmath()) @dsl_loc_tracing def rsqrt(self, *, fastmath=None): From 5f41e321b3c695c9398569dfbe4962a009e5c583 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Wed, 1 Jul 2026 06:32:28 +0000 Subject: [PATCH 3/4] flash_attn_generic: pass compile hints via flyc.compile[...] so they enter the cache key The fmha compile hints were set only through the CompilationContext.compile_hints context manager, which populates the tracing/backend context but not JitFunction.compile_hints -- and the disk cache key is built from self.compile_hints. So compiling with fast_fp_math on vs off collided on the same cache entry (stale binary reused). Attach the hints to the jit function and compile via flyc.compile[_fmha_compile_hints](...), so the hints are part of the cache key and JitFunction.__call__ still auto-pushes them into the context during tracing/codegen. The manual `with` wrappers are removed. Verified: toggling fast_fp_math recompiles to distinct cache entries instead of reusing a stale binary; numerics are bit-identical. Co-Authored-By: Claude Opus 4.8 --- kernels/flash_attn_generic.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/kernels/flash_attn_generic.py b/kernels/flash_attn_generic.py index 0458f10bb..23fd9629d 100644 --- a/kernels/flash_attn_generic.py +++ b/kernels/flash_attn_generic.py @@ -1526,13 +1526,22 @@ def launch_flash_attn_generic( }, } + # Attach hints to the jit function itself so they are part of the disk cache key + # (jit_function: cache key includes self.compile_hints) and are auto-pushed into the + # compilation context during tracing/codegen. This is what makes fast_fp_math=True/False + # produce distinct cache entries instead of colliding. 
+ launch_flash_attn_generic.compile_hints = { + **launch_flash_attn_generic.compile_hints, + **_fmha_compile_hints, + } + def _launch(*args, **kwargs): - with CompilationContext.compile_hints(_fmha_compile_hints): - return launch_flash_attn_generic(*args, **kwargs) + return launch_flash_attn_generic(*args, **kwargs) def _compile(Q, K, V, O, batch_size, seq_len, stream=None): # noqa: E741 - with CompilationContext.compile_hints(_fmha_compile_hints): - return flyc.compile(launch_flash_attn_generic, Q, K, V, O, batch_size, seq_len, fx.Stream(stream)) + return flyc.compile[_fmha_compile_hints]( + launch_flash_attn_generic, Q, K, V, O, batch_size, seq_len, fx.Stream(stream) + ) _launch.compile = _compile From e38970754676666fbb9122d70dc920dd86672d12 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Wed, 1 Jul 2026 09:35:24 +0000 Subject: [PATCH 4/4] expr/arith: add type-generic max()/min() Add a single type-generic max()/min() in expr/arith.py, modeled on cutlass.max/min. Accepts any DSL numeric type (Float32/Int32/Int64/unsigned ...) and Python scalars, any number of args (max(a,b), max(a,b,c), max([a,b,...]), max(a,[x,y])). Return type follows the operands' static types. Dispatch reuses the shared numeric coercion (as_numeric + _coerce_operands) and the resulting type: float -> maximumf / minimumf (NaN-propagating, matches cutlass) int signed -> maxsi / minsi int unsigned -> maxui / minui Float fastmath follows the fast_fp_math compile hint. The maximum-vs-maxnum choice is not exposed (matches cutlass). Registered in __all__ (fx.max/fx.min); decorated with dsl_loc_tracing like the other builders. 
Co-Authored-By: Claude Opus 4.8 --- python/flydsl/expr/arith.py | 63 +++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/python/flydsl/expr/arith.py b/python/flydsl/expr/arith.py index 832ec28a0..d3a61cddc 100644 --- a/python/flydsl/expr/arith.py +++ b/python/flydsl/expr/arith.py @@ -31,6 +31,8 @@ "xori", "cmpi", "cmpf", + "max", + "min", ] # Override star-import cmpi/cmpf to accept Numeric types (Int32, etc.) @@ -38,6 +40,7 @@ from .meta import dsl_loc_tracing from .utils.arith import ( # noqa: F401 ArithValue, + _default_fastmath, _to_raw, andi, constant, @@ -82,3 +85,63 @@ def cmpf(predicate, lhs, rhs, **kwargs): An ``i1`` comparison result. """ return _mlir_arith.cmpf(predicate, _to_raw(lhs), _to_raw(rhs), **kwargs) + + +# ── Type-generic max / min ────────────────────────────────────────────────── +# One entry point for any DSL numeric type (Float32/Int32/Int64/unsigned/...) and +# Python scalars, any number of args. Reuses the shared numeric coercion +# (as_numeric + _coerce_operands) and dispatches by the resulting type: +# float -> maximumf / minimumf (NaN-propagating unless fast_fp_math sets nnan; matches cutlass.max) +# int, signed -> maxsi / minsi +# int, unsigned -> maxui / minui +# The maximum-vs-maxnum choice is NOT exposed (matches cutlass). Return type +# follows the operands' static type. 
+ + +def _minmax_pair(is_max, a, b): + from .numeric import _coerce_operands, as_numeric + + a, b, out_ty = _coerce_operands(as_numeric(a), as_numeric(b)) + lv, rv = a.ir_value(), b.ir_value() + if out_ty.is_float: + fn = _mlir_arith.maximumf if is_max else _mlir_arith.minimumf + res = fn(lv, rv, fastmath=_default_fastmath()) + elif out_ty.signed: + fn = _mlir_arith.maxsi if is_max else _mlir_arith.minsi + res = fn(lv, rv) + else: + fn = _mlir_arith.maxui if is_max else _mlir_arith.minui + res = fn(lv, rv) + return out_ty(res) + + +def _minmax(is_max, args): + flat = [] + for a in args: + if isinstance(a, (list, tuple)): + flat.extend(a) + else: + flat.append(a) + if not flat: + raise ValueError("max()/min() requires at least one argument") + acc = flat[0] + for x in flat[1:]: + acc = _minmax_pair(is_max, acc, x) + return acc + + +@dsl_loc_tracing +def max(*args): + """Type-generic maximum over any number of DSL numeric args (and Python scalars). + + Return type follows the operands' static types (not values). Accepts + ``max(a, b)``, ``max(a, b, c, ...)``, ``max([a, b, ...])``, ``max(a, [x, y])``. + Float uses NaN-propagating ``maximumf``; signed/unsigned int use ``maxsi``/``maxui``. + """ + return _minmax(True, args) + + +@dsl_loc_tracing +def min(*args): + """Type-generic minimum. See :func:`max`.""" + return _minmax(False, args)