Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 39 additions & 36 deletions kernels/flash_attn_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from flydsl._mlir import ir
from flydsl._mlir.dialects import llvm
from flydsl.compiler.kernel_function import CompilationContext
from flydsl.expr import arith, buffer_ops, const_expr, gpu, range_constexpr, rocdl
from flydsl.expr import buffer_ops, const_expr, gpu, range_constexpr, rocdl
from flydsl.expr import math as fmath
from flydsl.expr.typing import T
from flydsl.expr.typing import Vector as Vec
Expand Down Expand Up @@ -504,17 +504,12 @@ def flash_attn_generic_kernel(
def _mfma(mfma_fn, a, b, c):
return mfma_fn(v16f32_type, [a, b, c])

def _fadd(a, b):
return arith.addf(_raw(a), _raw(b), fastmath=fm_fast)

def _fsub(a, b):
return arith.subf(_raw(a), _raw(b), fastmath=fm_fast)

def _fmul(a, b):
return arith.mulf(_raw(a), _raw(b), fastmath=fm_fast)

def _fmax(a, b):
return arith.MaxNumFOp(_raw(a), _raw(b), fastmath=fm_fast).result
# FP math is written inline with FlyDSL-typed ops (no raw arith.* in the kernel):
# `fx.Float32(a) + / - / * fx.Float32(b)` and `fx.Float32(a).maximumf(b)`, each
# `.ir_value()`-unwrapped so the surrounding raw-MLIR call sites stay unchanged.
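# NOTE: DSL operators take no per-call fastmath argument, unlike the old
# arith.*(fastmath=fast) helpers. With only the function-level fast_fp_math attr applying,
# this was a few % slower on large causal shapes than the old form — accepted to avoid
# touching base numeric classes.
# NOTE(review): expr/utils/arith.py derives op-level flags from the fast_fp_math compile
# hint (_default_fastmath), so the "no op-level fastmath" statement above may be stale —
# confirm which path fx.Float32 operators take and re-measure.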

def mfma_acc(a, b, c):
if const_expr(dtype_str == "bf16"):
Expand Down Expand Up @@ -1191,48 +1186,47 @@ def _k_idx_hi(ks):
c_neg_inf, s_raw_hi[r]
)

local_max = s_raw_lo[0]
# Online softmax, all FlyDSL-typed (fx.Float32 flows through +/-/*/.maximumf/.exp2;
# raw inputs from Vec[]/fma/loop-carry are wrapped once at entry). reduction_peer
# returns fx.Float32; fma/ArithValue/Vec.from_elements all accept fx.Float32.
local_max = fx.Float32(s_raw_lo[0])
for r in range_constexpr(15):
local_max = _fmax(local_max, s_raw_lo[r + 1])
local_max = local_max.maximumf(s_raw_lo[r + 1])
for r in range_constexpr(16):
local_max = _fmax(local_max, s_raw_hi[r])
local_max = local_max.maximumf(s_raw_hi[r])
peer_max = reduction_peer(local_max)
row_max = _fmax(local_max, peer_max)
m_new_raw = _fmax(m_running, row_max)

diff_m_raw = _fsub(m_running, m_new_raw)
diff_m_scaled = _fmul(diff_m_raw, c_sm_scale_log2e)
corr = ArithValue(diff_m_scaled).exp2(fastmath=fm_fast)
row_max = local_max.maximumf(peer_max)
m_run = fx.Float32(m_running)
m_new_raw = m_run.maximumf(row_max)

scaled_max = _fmul(c_sm_scale_log2e, m_new_raw)
neg_scaled_max = _fsub(c_zero_f, scaled_max)
corr = ((m_run - m_new_raw) * c_sm_scale_log2e).exp2(fastmath=fm_fast)
neg_scaled_max = c_zero_f - c_sm_scale_log2e * m_new_raw

p_vals_lo = []
p_vals_hi = []
local_sum = c_zero_f
for r in range_constexpr(16):
diff_lo = fmath.fma(s_raw_lo[r], c_sm_scale_log2e, neg_scaled_max, fastmath=fm_fast)
p_lo = ArithValue(diff_lo).exp2(fastmath=fm_fast)
p_lo = fx.Float32(diff_lo).exp2(fastmath=fm_fast)
p_vals_lo.append(p_lo)
local_sum = _fadd(local_sum, p_lo)
local_sum = local_sum + p_lo
for r in range_constexpr(16):
diff_hi = fmath.fma(s_raw_hi[r], c_sm_scale_log2e, neg_scaled_max, fastmath=fm_fast)
p_hi = ArithValue(diff_hi).exp2(fastmath=fm_fast)
p_hi = fx.Float32(diff_hi).exp2(fastmath=fm_fast)
p_vals_hi.append(p_hi)
local_sum = _fadd(local_sum, p_hi)
local_sum = local_sum + p_hi

peer_sum = reduction_peer(local_sum)
tile_sum = _fadd(local_sum, peer_sum)
l_corr = _fmul(corr, l_running)
l_new = _fadd(l_corr, tile_sum)
tile_sum = local_sum + peer_sum
l_new = corr * l_running + tile_sum

# ==== Rescale O accumulators ====
corr_vec = Vec.from_elements([corr], fx.Float32).broadcast_to(16)
if const_expr(not USE_HW_TR):
o_accs[0] = _fmul(Vec(o_accs[0]), corr_vec)
o_accs[0] = Vec(o_accs[0]) * corr_vec
else:
for dc in range_constexpr(D_CHUNKS):
o_accs[dc] = _fmul(Vec(o_accs[dc]), corr_vec)
o_accs[dc] = Vec(o_accs[dc]) * corr_vec

if const_expr(ENABLE_PREFETCH_3BUF and (kv_sub + preload_k_count) < N_SUBTILES):
next_k_sub = kv_sub + preload_k_count
Expand Down Expand Up @@ -1532,13 +1526,22 @@ def launch_flash_attn_generic(
},
}

# Attach hints to the jit function itself so they are part of the disk cache key
# (jit_function: cache key includes self.compile_hints) and are auto-pushed into the
# compilation context during tracing/codegen. This is what makes fast_fp_math=True/False
# produce distinct cache entries instead of colliding.
launch_flash_attn_generic.compile_hints = {
**launch_flash_attn_generic.compile_hints,
**_fmha_compile_hints,
}

def _launch(*args, **kwargs):
with CompilationContext.compile_hints(_fmha_compile_hints):
return launch_flash_attn_generic(*args, **kwargs)
return launch_flash_attn_generic(*args, **kwargs)

def _compile(Q, K, V, O, batch_size, seq_len, stream=None): # noqa: E741
with CompilationContext.compile_hints(_fmha_compile_hints):
return flyc.compile(launch_flash_attn_generic, Q, K, V, O, batch_size, seq_len, fx.Stream(stream))
return flyc.compile[_fmha_compile_hints](
launch_flash_attn_generic, Q, K, V, O, batch_size, seq_len, fx.Stream(stream)
)

_launch.compile = _compile

Expand Down
63 changes: 63 additions & 0 deletions python/flydsl/expr/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,16 @@
"xori",
"cmpi",
"cmpf",
"max",
"min",
]

# Override star-import cmpi/cmpf to accept Numeric types (Int32, etc.)
from .._mlir.dialects import arith as _mlir_arith
from .meta import dsl_loc_tracing
from .utils.arith import ( # noqa: F401
ArithValue,
_default_fastmath,
_to_raw,
andi,
constant,
Expand Down Expand Up @@ -82,3 +85,63 @@ def cmpf(predicate, lhs, rhs, **kwargs):
An ``i1`` comparison result.
"""
return _mlir_arith.cmpf(predicate, _to_raw(lhs), _to_raw(rhs), **kwargs)


# ── Type-generic max / min ──────────────────────────────────────────────────
# One entry point for any DSL numeric type (Float32/Int32/Int64/unsigned/...) and
# Python scalars, any number of args. Reuses the shared numeric coercion
# (as_numeric + _coerce_operands) and dispatches by the resulting type:
# float -> maximumf / minimumf (NaN-propagating, matches cutlass.max)
# int, signed -> maxsi / minsi
# int, unsigned -> maxui / minui
# The maximum-vs-maxnum choice is NOT exposed (matches cutlass). Return type
# follows the operands' static type.


def _minmax_pair(is_max, a, b):
    """Return the maximum (``is_max`` true) or minimum of two operands.

    Both operands are passed through ``as_numeric`` and then ``_coerce_operands``,
    which yields the coerced pair and the result type. The MLIR op is selected
    from that result type:

    * float            -> ``maximumf`` / ``minimumf`` with ``_default_fastmath()``
    * signed integer   -> ``maxsi`` / ``minsi``
    * anything else    -> ``maxui`` / ``minui``

    The raw MLIR result is wrapped back into the result type before returning.
    """
    from .numeric import _coerce_operands, as_numeric

    lhs, rhs, result_ty = _coerce_operands(as_numeric(a), as_numeric(b))
    lhs_ir = lhs.ir_value()
    rhs_ir = rhs.ir_value()

    if result_ty.is_float:
        float_op = _mlir_arith.maximumf if is_max else _mlir_arith.minimumf
        return result_ty(float_op(lhs_ir, rhs_ir, fastmath=_default_fastmath()))

    # Integer path: signedness of the result type decides the op family.
    if result_ty.signed:
        int_op = _mlir_arith.maxsi if is_max else _mlir_arith.minsi
    else:
        int_op = _mlir_arith.maxui if is_max else _mlir_arith.minui
    return result_ty(int_op(lhs_ir, rhs_ir))


def _minmax(is_max, args):
flat = []
for a in args:
if isinstance(a, (list, tuple)):
flat.extend(a)
else:
flat.append(a)
if not flat:
raise ValueError("max()/min() requires at least one argument")
acc = flat[0]
for x in flat[1:]:
acc = _minmax_pair(is_max, acc, x)
return acc


@dsl_loc_tracing
def max(*args):
    """Return the maximum of any number of DSL numeric values or Python scalars.

    Call forms: ``max(a, b)``, ``max(a, b, c, ...)``, ``max([a, b, ...])`` and
    ``max(a, [x, y])`` — lists and tuples among the arguments are expanded one
    level. The result type is determined by the operands' static types, not by
    their values. Floats lower to NaN-propagating ``maximumf``; signed and
    unsigned integers lower to ``maxsi`` and ``maxui`` respectively.
    """
    return _minmax(is_max=True, args=args)


@dsl_loc_tracing
def min(*args):
    """Return the minimum of any number of DSL numeric values or Python scalars.

    Mirror image of :func:`max`: same accepted call forms and the same
    type-driven dispatch, using the minimum ops instead.
    """
    return _minmax(is_max=False, args=args)
32 changes: 24 additions & 8 deletions python/flydsl/expr/utils/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,22 @@ def _coerce_other(self, other):
}


def _default_fastmath():
    """Pick the op-level fast-math flags for DSL float ops from the compile hints.

    Reads the active compile hints from ``CompilationContext`` and returns
    ``FastMathFlags.fast`` (reassoc,nnan,ninf,nsz,arcp,contract,afn) when the
    ``fast_fp_math`` hint is truthy, e.g. after
    ``flyc.compile[{"fast_fp_math": True}]``. Otherwise it returns
    ``FastMathFlags.none`` (strict IEEE), which is the default.

    This is the op-level counterpart of clang ``-ffast-math``: with ``fast`` the
    backend may contract FMAs and reassociate. The same hint is also meant to
    drive the function/target-level fast attribute in the ROCm backend, so a
    single switch enables both.
    """
    from ...compiler.kernel_function import CompilationContext  # lazy: avoid circular import

    if CompilationContext.get_compile_hints().get("fast_fp_math"):
        return arith.FastMathFlags.fast
    return arith.FastMathFlags.none


@dsl_loc_tracing
def _binary_op(self, other, op):
other = _coerce_other(self, other)
Expand All @@ -153,23 +169,23 @@ def _binary_op(self, other, op):
if op in _ARITH_OPS:
float_fn, int_fn = _ARITH_OPS[op]
if self.is_float:
return float_fn(self, other)
return float_fn(self, other, fastmath=_default_fastmath())
return int_fn(self, other)

if op == "div":
if self.is_float:
return arith.divf(self, other)
return arith.divf(self, other, fastmath=_default_fastmath())
et = element_type(self.type)
if isinstance(et, ir.IndexType):
return arith.divui(self, other)
fp_ty = T.f64() if et.width > 32 else T.f32()
lhs = int_to_fp(self, self.signed, fp_ty)
rhs = int_to_fp(other, other.signed, fp_ty)
return arith.divf(lhs, rhs)
return arith.divf(lhs, rhs, fastmath=_default_fastmath())

if op == "floordiv":
if self.is_float:
q = arith.divf(self, other)
q = arith.divf(self, other, fastmath=_default_fastmath())
return math.floor(q)
et = element_type(self.type)
if isinstance(et, ir.IndexType):
Expand All @@ -180,7 +196,7 @@ def _binary_op(self, other, op):

if op == "mod":
if self.is_float:
return arith.remf(self, other)
return arith.remf(self, other, fastmath=_default_fastmath())
et = element_type(self.type)
if isinstance(et, ir.IndexType):
return arith.remui(self, other)
Expand Down Expand Up @@ -293,7 +309,7 @@ def _neg_op(self):
if self.type == T.bool():
raise TypeError("negation is not supported for boolean type")
if self.is_float:
return arith.negf(self)
return arith.negf(self, fastmath=_default_fastmath())
c0 = arith_const(0, self.type)
return arith.subi(c0, self)

Expand Down Expand Up @@ -425,8 +441,8 @@ def addf(self, other, *, fastmath=None):

@dsl_loc_tracing
def maximumf(self, other):
"""Float maximum (NaN-propagating)."""
return arith.maximumf(self, _to_raw(other))
"""Float maximum (NaN-propagating), op-level fast-math gated on fast_fp_math hint."""
return arith.maximumf(self, _to_raw(other), fastmath=_default_fastmath())

@dsl_loc_tracing
def rsqrt(self, *, fastmath=None):
Expand Down