flash_attn_generic: replace raw arith.* FP ops with FlyDSL-typed fast…#764
Open
xudoyuan wants to merge 4 commits into
Open
flash_attn_generic: replace raw arith.* FP ops with FlyDSL-typed fast…#764xudoyuan wants to merge 4 commits into
xudoyuan wants to merge 4 commits into
Conversation
sjfeng1999
reviewed
Jun 29, 2026
eb5bf5a to
70f447c
Compare
…x in FlyDSL types Remove the raw-arith FP helper wrappers and write the math directly with FlyDSL-typed ops: fx.Float32(a) +/-/* fx.Float32(b) and fx.Float32(a).maximumf(b), so values stay DSL-typed through the online softmax without .ir_value() round-trips. Drop the now-unused arith import. Base numeric/ArithValue classes are not modified. Note: DSL operators do not by themselves carry op-level fastmath (that is added separately, gated on the fast_fp_math hint), so on its own this is slightly slower than the original arith.*(fastmath) form. Numerics are bit-identical and tests pass. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Make every DSL float op (Numeric/ArithValue +,-,*,/,//,%, unary neg, maximumf) attach op-level fastmath=fast only when compiled with the fast_fp_math hint (default off / strict IEEE), via _default_fastmath() reading CompilationContext.get_compile_hints(). This is the op-level analogue of -ffast-math and reuses the same fast_fp_math switch that already drives the function/target-level fast attr in the ROCm backend, so one flag turns on both. Verified: with fast_fp_math the DSL-operator kernel matches the original arith.*(fastmath) kernel; the strict default is slightly slower; numerics are bit-identical. Framework change -> must be synced into build-fly/python_packages at runtime. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…enter the cache key The fmha compile hints were set only through the CompilationContext.compile_hints context manager, which populates the tracing/backend context but not JitFunction.compile_hints -- and the disk cache key is built from self.compile_hints. So compiling with fast_fp_math on vs off collided on the same cache entry (stale binary reused). Attach the hints to the jit function and compile via flyc.compile[_fmha_compile_hints](...), so the hints are part of the cache key and JitFunction.__call__ still auto-pushes them into the context during tracing/codegen. The manual `with` wrappers are removed. Verified: toggling fast_fp_math recompiles to distinct cache entries instead of reusing a stale binary; numerics are bit-identical. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
1ee2c83 to
5f41e32
Compare
sjfeng1999
reviewed
Jul 1, 2026
| # All FP operations use aggressive fast-math (no NaN/Inf checks, reassociation). | ||
| # The unsafe_fp_math/fast_fp_math builder params control LLVM-level attributes only. | ||
| fm_fast = fx.arith.FastMathFlags.fast | ||
| v4f16_type = Vec.make_type(4, elem_dtype) |
Add a single type-generic max()/min() in expr/arith.py, modeled on cutlass.max/min. Accepts any DSL numeric type (Float32/Int32/Int64/unsigned ...) and Python scalars, any number of args (max(a,b), max(a,b,c), max([a,b,...]), max(a,[x,y])). Return type follows the operands' static types. Dispatch reuses the shared numeric coercion (as_numeric + _coerce_operands) and the resulting type: float -> maximumf / minimumf (NaN-propagating, matches cutlass) int signed -> maxsi / minsi int unsigned -> maxui / minui Float fastmath follows the fast_fp_math compile hint. The maximum-vs-maxnum choice is not exposed (matches cutlass). Registered in __all__ (fx.max/fx.min); decorated with dsl_loc_tracing like the other builders. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
f70ba8a to
e389707
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
…math proxies
The QK/softmax helpers used raw MLIR ops (arith.addf/subf/mulf/MaxNumFOp) purely to attach op-level fastmath, because the DSL float types only exposed addf(fastmath=) and maximumf (no fastmath). This dropped to the MLIR layer inside the kernel, which we want to avoid.
Add the missing fastmath-capable proxies to the DSL so kernels can stay FlyDSL-typed without losing op-level fastmath:
Convert flash_attn_generic _fadd/_fsub/_fmul/_fmax to use Float32.addf/subf/mulf/maxnumf(fastmath=fm_fast); maxnumf preserves the original NaN-non-propagating maxnum semantics. Drop the now-unused
arithimport. Helpers still return a raw MLIR value so downstream call sites are unchanged.Motivation
Technical Details
Test Plan
Test Result
Submission Checklist