Skip to content

flash_attn_generic: replace raw arith.* FP ops with FlyDSL-typed fast…#764

Open
xudoyuan wants to merge 4 commits into
mainfrom
cleanup/flash-attn-generic-arith-to-flydsl
Open

flash_attn_generic: replace raw arith.* FP ops with FlyDSL-typed fast…#764
xudoyuan wants to merge 4 commits into
mainfrom
cleanup/flash-attn-generic-arith-to-flydsl

Conversation

@xudoyuan

@xudoyuan xudoyuan commented Jun 29, 2026

Copy link
Copy Markdown
Collaborator

…math proxies

The QK/softmax helpers used raw MLIR ops (arith.addf/subf/mulf/MaxNumFOp) purely to attach op-level fastmath, because the DSL float types only exposed addf(fastmath=) and maximumf (no fastmath). This dropped to the MLIR layer inside the kernel, which we want to avoid.

Add the missing fastmath-capable proxies to the DSL so kernels can stay FlyDSL-typed without losing op-level fastmath:

  • ArithValue: subf / mulf / maxnumf (python/flydsl/expr/utils/arith.py)
  • Numeric proxies: subf / mulf / maxnumf (python/flydsl/expr/numeric.py)

Convert flash_attn_generic _fadd/_fsub/_fmul/_fmax to use Float32.addf/subf/mulf/maxnumf(fastmath=fm_fast); maxnumf preserves the original NaN-non-propagating maxnum semantics. Drop the now-unused arith import. Helpers still return a raw MLIR value so downstream call sites are unchanged.

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

Comment thread python/flydsl/expr/numeric.py Outdated
@xudoyuan xudoyuan linked an issue Jun 30, 2026 that may be closed by this pull request
@xudoyuan xudoyuan force-pushed the cleanup/flash-attn-generic-arith-to-flydsl branch 4 times, most recently from eb5bf5a to 70f447c Compare June 30, 2026 06:13
@xudoyuan xudoyuan requested a review from sjfeng1999 July 1, 2026 06:09
@xudoyuan xudoyuan marked this pull request as ready for review July 1, 2026 06:10
xudoyuan and others added 3 commits July 1, 2026 06:32
…x in FlyDSL types

Remove the raw-arith FP helper wrappers and write the math directly with FlyDSL-typed
ops: fx.Float32(a) +/-/* fx.Float32(b) and fx.Float32(a).maximumf(b), so values stay
DSL-typed through the online softmax without .ir_value() round-trips. Drop the now-unused
arith import. Base numeric/ArithValue classes are not modified.

Note: DSL operators do not by themselves carry op-level fastmath (that is added separately,
gated on the fast_fp_math hint), so on its own this is slightly slower than the original
arith.*(fastmath) form. Numerics are bit-identical and tests pass.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Make every DSL float op (Numeric/ArithValue +,-,*,/,//,%, unary neg, maximumf) attach
op-level fastmath=fast only when compiled with the fast_fp_math hint (default off / strict
IEEE), via _default_fastmath() reading CompilationContext.get_compile_hints(). This is the
op-level analogue of -ffast-math and reuses the same fast_fp_math switch that already drives
the function/target-level fast attr in the ROCm backend, so one flag turns on both.

Verified: with fast_fp_math the DSL-operator kernel matches the original arith.*(fastmath)
kernel; the strict default is slightly slower; numerics are bit-identical. Framework change
-> must be synced into build-fly/python_packages at runtime.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…enter the cache key

The fmha compile hints were set only through the CompilationContext.compile_hints context
manager, which populates the tracing/backend context but not JitFunction.compile_hints --
and the disk cache key is built from self.compile_hints. So compiling with fast_fp_math on
vs off collided on the same cache entry (stale binary reused).

Attach the hints to the jit function and compile via flyc.compile[_fmha_compile_hints](...),
so the hints are part of the cache key and JitFunction.__call__ still auto-pushes them into
the context during tracing/codegen. The manual `with` wrappers are removed.

Verified: toggling fast_fp_math recompiles to distinct cache entries instead of reusing a
stale binary; numerics are bit-identical.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@xudoyuan xudoyuan force-pushed the cleanup/flash-attn-generic-arith-to-flydsl branch from 1ee2c83 to 5f41e32 Compare July 1, 2026 06:32
# All FP operations use aggressive fast-math (no NaN/Inf checks, reassociation).
# The unsafe_fp_math/fast_fp_math builder params control LLVM-level attributes only.
fm_fast = fx.arith.FastMathFlags.fast
v4f16_type = Vec.make_type(4, elem_dtype)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update to Float16x4

Add a single type-generic max()/min() in expr/arith.py, modeled on
cutlass.max/min. Accepts any DSL numeric type (Float32/Int32/Int64/unsigned
...) and Python scalars, any number of args (max(a,b), max(a,b,c),
max([a,b,...]), max(a,[x,y])). Return type follows the operands' static types.

Dispatch reuses the shared numeric coercion (as_numeric + _coerce_operands)
and the resulting type:
  float         -> maximumf / minimumf  (NaN-propagating, matches cutlass)
  int signed    -> maxsi / minsi
  int unsigned  -> maxui / minui
Float fastmath follows the fast_fp_math compile hint. The maximum-vs-maxnum
choice is not exposed (matches cutlass). Registered in __all__ (fx.max/fx.min);
decorated with dsl_loc_tracing like the other builders.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@xudoyuan xudoyuan force-pushed the cleanup/flash-attn-generic-arith-to-flydsl branch from f70ba8a to e389707 Compare July 1, 2026 11:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Feature]: Type closure use case refactoring

2 participants