diff --git a/kernels/hgemm_splitk.py b/kernels/hgemm_splitk.py index 87ec2ff4d..d68323f17 100644 --- a/kernels/hgemm_splitk.py +++ b/kernels/hgemm_splitk.py @@ -13,7 +13,7 @@ from flydsl._mlir.dialects import fly, llvm, memref, scf from flydsl.compiler.kernel_function import CompilationContext from flydsl.expr import arith, buffer_ops, const_expr, gpu, range_constexpr, rocdl, vector -from flydsl.expr.typing import T +from flydsl.expr.typing import T, as_ir_value from flydsl.runtime.device import get_rocm_arch from flydsl.utils.smem_allocator import SMEM_CAPACITY_MAP, SmemAllocator, SmemPtr from kernels.tensor_shim import GTensor, STensor, _run_compiled, get_dtype_in_kernel @@ -298,7 +298,7 @@ def zero_c(): # zero c if current block is the first block is_t0_cond = arith.cmpi(arith.CmpIPredicate.eq, fx.Index(tid), fx.Index(0)) cond_ks0 = arith.cmpi(arith.CmpIPredicate.eq, ks_idx, fx.Index(0)) - cond_ks0_if = scf.IfOp(cond_ks0, results_=[], has_else=False) + cond_ks0_if = scf.IfOp(as_ir_value(cond_ks0), results_=[], has_else=False) with ir.InsertionPoint(cond_ks0_if.then_block): zero_vec = vector.broadcast(T.vec(LDG_VEC_SIZE, dtype_), c_zero_d) for i in range_constexpr(LDG_REG_C_COUNT): @@ -310,7 +310,7 @@ def zero_c(): if const_expr(HAS_BIAS): init_vec = BIAS_.vec_load((n_offset + n_local_idx,), LDG_VEC_SIZE) cond_boundary = arith.cmpi(arith.CmpIPredicate.ult, row_idx, fx.Index(m)) - cond_boundary_if = scf.IfOp(cond_boundary, results_=[], has_else=False) + cond_boundary_if = scf.IfOp(as_ir_value(cond_boundary), results_=[], has_else=False) with ir.InsertionPoint(cond_boundary_if.then_block): bytes_offset = C_.linear_offset((row_idx, n_offset + n_local_idx)) bytes_offset_i32 = arith.index_cast(T.i32, bytes_offset) @@ -325,7 +325,7 @@ def zero_c(): scf.YieldOp([]) gpu.barrier() # trigger signal when zeroc is done by the first arrived block - is_t0_cond_if = scf.IfOp(is_t0_cond, results_=[], has_else=False) + is_t0_cond_if = scf.IfOp(as_ir_value(is_t0_cond), 
results_=[], has_else=False) with ir.InsertionPoint(is_t0_cond_if.then_block): signal_ptr = get_llvm_ptr(signal, signal_idx, 4) llvm.InlineAsmOp( @@ -342,7 +342,7 @@ def zero_c(): def split_k_barrier(): # spin-wait until signal triggered is_t0_cond = arith.cmpi(arith.CmpIPredicate.eq, fx.Index(tid), fx.Index(0)) - is_t0_cond_if = scf.IfOp(is_t0_cond, results_=[], has_else=False) + is_t0_cond_if = scf.IfOp(as_ir_value(is_t0_cond), results_=[], has_else=False) with ir.InsertionPoint(is_t0_cond_if.then_block): init_cur = arith.constant(0, type=T.i32) w = scf.WhileOp([T.i32], [init_cur]) @@ -367,7 +367,7 @@ def split_k_barrier(): rocdl.sched_barrier(0) gpu.barrier() # clean semaphore and signal if this is the last block within split-k group - is_t0_cond_if = scf.IfOp(is_t0_cond, results_=[], has_else=False) + is_t0_cond_if = scf.IfOp(as_ir_value(is_t0_cond), results_=[], has_else=False) with ir.InsertionPoint(is_t0_cond_if.then_block): semaphore_ptr = get_llvm_ptr(semaphore, signal_idx, 4) arrive_idx = llvm.AtomicRMWOp( @@ -379,7 +379,7 @@ def split_k_barrier(): alignment=4, ).result cond_ksl = arith.cmpi(arith.CmpIPredicate.eq, fx.Index(arrive_idx), fx.Index(SPLIT_K - 1)) - cond_ksl_if = scf.IfOp(cond_ksl, results_=[], has_else=False) + cond_ksl_if = scf.IfOp(as_ir_value(cond_ksl), results_=[], has_else=False) with ir.InsertionPoint(cond_ksl_if.then_block): semaphore_[signal_idx] = arith.constant(0, type=T.i32) signal_[signal_idx] = arith.constant(0, type=T.i32) @@ -682,7 +682,7 @@ def hot_loop_scheduler(): m_global_idx = m_offset + m_local_idx n_global_idx = n_offset + n_local_idx cond_boundary = arith.cmpi(arith.CmpIPredicate.ult, m_global_idx, fx.Index(m)) - cond_boundary_if = scf.IfOp(cond_boundary, results_=[], has_else=False) + cond_boundary_if = scf.IfOp(as_ir_value(cond_boundary), results_=[], has_else=False) with ir.InsertionPoint(cond_boundary_if.then_block): pk_val = cs_.vec_load((0, m_local_idx, n_local_idx), LDG_VEC_SIZE) for ksi in range_constexpr(1, 
BLOCK_K_WARPS): @@ -713,7 +713,7 @@ def hot_loop_scheduler(): n_local_idx = fx.Index(global_tid % LDG_C_X_THREADS * LDG_VEC_SIZE) m_global_idx = m_offset + m_local_idx cond_boundary = arith.cmpi(arith.CmpIPredicate.ult, m_global_idx, fx.Index(m)) - cond_boundary_if = scf.IfOp(cond_boundary, results_=[], has_else=False) + cond_boundary_if = scf.IfOp(as_ir_value(cond_boundary), results_=[], has_else=False) with ir.InsertionPoint(cond_boundary_if.then_block): vec = cs_.vec_load((0, m_local_idx, n_local_idx), LDG_VEC_SIZE) for ksi in range_constexpr(1, BLOCK_K_WARPS): diff --git a/python/flydsl/compiler/ast_rewriter.py b/python/flydsl/compiler/ast_rewriter.py index a5fb12706..dddd24b9b 100644 --- a/python/flydsl/compiler/ast_rewriter.py +++ b/python/flydsl/compiler/ast_rewriter.py @@ -3,6 +3,7 @@ import ast import contextlib +import copy import difflib import functools import inspect @@ -394,24 +395,24 @@ def visit_With(self, node: ast.With): class RewriteBoolOps(Transformer): @staticmethod def dsl_and_(lhs, rhs): - if hasattr(lhs, "__fly_and__"): - return lhs.__fly_and__(rhs) - if hasattr(rhs, "__fly_and__"): - return rhs.__fly_and__(lhs) + if hasattr(lhs, "__dsl_and__"): + return lhs.__dsl_and__(rhs) + if hasattr(rhs, "__dsl_and__"): + return rhs.__dsl_and__(lhs) return lhs and rhs @staticmethod def dsl_or_(lhs, rhs): - if hasattr(lhs, "__fly_or__"): - return lhs.__fly_or__(rhs) - if hasattr(rhs, "__fly_or__"): - return rhs.__fly_or__(lhs) + if hasattr(lhs, "__dsl_or__"): + return lhs.__dsl_or__(rhs) + if hasattr(rhs, "__dsl_or__"): + return rhs.__dsl_or__(lhs) return lhs or rhs @staticmethod def dsl_not_(x): - if hasattr(x, "__fly_not__"): - return x.__fly_not__() + if hasattr(x, "__dsl_not__"): + return x.__dsl_not__() return not x @classmethod @@ -422,6 +423,26 @@ def rewrite_globals(cls): "dsl_not_": cls.dsl_not_, } + def visit_Compare(self, node: ast.Compare): + node = self.generic_visit(node) + if len(node.ops) == 1: + return node + # Chained comparison `a < b 
< c` -> `(a < b) and (b < c)`, then reuse the + # `and` lowering. Middle operands are referenced by two comparisons, matching + # the existing repeated-evaluation convention in visit_BoolOp. + operands = [node.left] + node.comparators + comparisons = [] + for i, op in enumerate(node.ops): + comparison = ast.Compare( + left=copy.deepcopy(operands[i]), + ops=[copy.deepcopy(op)], + comparators=[copy.deepcopy(operands[i + 1])], + ) + comparisons.append(ast.copy_location(comparison, node)) + chained = ast.copy_location(ast.BoolOp(op=ast.And(), values=comparisons), node) + chained = ast.fix_missing_locations(chained) + return self.visit_BoolOp(chained) + def visit_BoolOp(self, node: ast.BoolOp): node = self.generic_visit(node) diff --git a/python/flydsl/compiler/jit_function.py b/python/flydsl/compiler/jit_function.py index faf895c12..5a54b52ed 100644 --- a/python/flydsl/compiler/jit_function.py +++ b/python/flydsl/compiler/jit_function.py @@ -23,6 +23,7 @@ from .._mlir.passmanager import PassManager from ..expr.meta import tracing_context from ..expr.typing import Constexpr, Stream +from ..expr.utils.arith import fastmath as fastmath_ctx from ..utils import env, log from .ast_rewriter import ASTRewriter from .backends import compile_backend_name, get_backend @@ -40,6 +41,7 @@ CompilationContext, KernelFunction, create_gpu_module, + effective_fastmath_hint, func_def_location, get_gpu_module_body, ) @@ -1482,8 +1484,12 @@ def __call__(self, *args, **kwargs): log().info(f"dsl_args={dsl_args}") named_args = dict(zip(param_names, dsl_args)) named_args.update(constexpr_values) + fastmath_flag = effective_fastmath_hint(CompilationContext.get_compile_hints()) + fastmath_scope = ( + fastmath_ctx(fastmath_flag) if fastmath_flag is not None else nullcontext() + ) # Bound the call-site boundary at the jit body. 
- with tracing_context(self.func): + with tracing_context(self.func), fastmath_scope: if bound_self is not None: self.func(bound_self, **named_args) else: diff --git a/python/flydsl/compiler/kernel_function.py b/python/flydsl/compiler/kernel_function.py index e30c87529..ce96d77c2 100644 --- a/python/flydsl/compiler/kernel_function.py +++ b/python/flydsl/compiler/kernel_function.py @@ -3,7 +3,7 @@ import inspect import threading -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from functools import partial from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -11,6 +11,7 @@ from .._mlir.dialects import arith, gpu from ..expr.meta import capture_user_location, file_location, tracing_context from ..expr.typing import Constexpr +from ..expr.utils.arith import fastmath as fastmath_ctx from .ast_rewriter import ASTRewriter from .diagnostics import install_excepthook, warn_annotation_value_mismatch, warn_invalid_annotations from .jit_argument import is_type_param_annotation, resolve_signature @@ -243,6 +244,15 @@ def next_kernel_id(self) -> int: return kid +def effective_fastmath_hint(hints: dict): + """Resolve the ambient fastmath hint for traced JIT/kernel bodies.""" + if "fastmath" in hints: + return hints["fastmath"] + if hints.get("fast_fp_math"): + return "fast" + return None + + # ============================================================================= # Kernel Launcher # ============================================================================= @@ -540,8 +550,11 @@ def _emit_kernel(self, ctx: CompilationContext, args: Tuple, kwargs: Dict, bound idx += n dsl_args.update(constexpr_values) + + fastmath_flag = effective_fastmath_hint(CompilationContext.get_compile_hints()) + fastmath_scope = fastmath_ctx(fastmath_flag) if fastmath_flag is not None else nullcontext() # Bound the call-site boundary at the kernel body. 
- with tracing_context(self._func): + with tracing_context(self._func), fastmath_scope: if bound_self is not None: self._func(bound_self, **dsl_args) else: diff --git a/python/flydsl/expr/arith.py b/python/flydsl/expr/arith.py index 3b5eb0fd0..8e5385916 100644 --- a/python/flydsl/expr/arith.py +++ b/python/flydsl/expr/arith.py @@ -18,13 +18,18 @@ __all__ = [ "ArithValue", # Deprecated: will be removed in a future release "_to_raw", # Deprecated: will be removed in a future release + "FastMathFlags", "andi", "constant", "constant_vector", + "fastmath", "index", # Deprecated: will be removed in a future release "index_cast", # Deprecated: will be removed in a future release "int_to_fp", "maxnumf", + "minnumf", + "maximumf", + "minimumf", "shli", "sitofp", "trunc_f", @@ -35,7 +40,7 @@ ] # Override star-import cmpi/cmpf to accept Numeric types (Int32, etc.) -from .._mlir.dialects import arith as _mlir_arith +from .._mlir.dialects import arith from .meta import dsl_loc_tracing from .utils.arith import ( # noqa: F401 ArithValue, @@ -43,6 +48,7 @@ andi, constant, constant_vector, + fastmath, index, index_cast, int_to_fp, @@ -53,9 +59,12 @@ unwrap, xori, ) +from .math import dsl_math_wrap_result +from .typing import as_ir_value @dsl_loc_tracing +@dsl_math_wrap_result(exemplar="lhs") def cmpi(predicate, lhs, rhs, **kwargs): """Integer comparison accepting DSL numeric types (Int32, ArithValue, etc.). @@ -65,12 +74,13 @@ def cmpi(predicate, lhs, rhs, **kwargs): rhs: Right-hand operand. Returns: - An ``i1`` comparison result. + A ``Boolean`` (scalar) or ``Vector(Boolean)`` comparison result. """ - return _mlir_arith.cmpi(predicate, _to_raw(lhs), _to_raw(rhs), **kwargs) + return arith.cmpi(predicate, as_ir_value(lhs), as_ir_value(rhs), **kwargs) @dsl_loc_tracing +@dsl_math_wrap_result(exemplar="lhs") def cmpf(predicate, lhs, rhs, **kwargs): """Floating-point comparison accepting DSL numeric types. 
@@ -80,24 +90,30 @@ def cmpf(predicate, lhs, rhs, **kwargs): rhs: Right-hand operand. Returns: - An ``i1`` comparison result. + A ``Boolean`` (scalar) or ``Vector(Boolean)`` comparison result. """ - return _mlir_arith.cmpf(predicate, _to_raw(lhs), _to_raw(rhs), **kwargs) + return arith.cmpf(predicate, as_ir_value(lhs), as_ir_value(rhs), **kwargs) @dsl_loc_tracing -def maxnumf(a, b, **kwargs): - """Floating-point maximum, returning the non-NaN operand when one input is NaN (libm ``fmax``). +@dsl_math_wrap_result +def maximumf(lhs, rhs, *, fastmath=None): + return arith.maximumf(as_ir_value(lhs), as_ir_value(rhs), fastmath=fastmath) - Accepts DSL numeric types (Float32, Vector, ...) and preserves the DSL type of ``a`` so the - result can be chained with further DSL operations (e.g. ``.shuffle_xor(...)``). - """ - from .numeric import Numeric - from .typing import Vector - - result = _mlir_arith.maxnumf(_to_raw(a), _to_raw(b), **kwargs) - if isinstance(a, Vector): - return Vector(result, a.shape, a.dtype) - if isinstance(a, Numeric): - return Numeric.from_ir_type(result.type)(result) - return result + +@dsl_loc_tracing +@dsl_math_wrap_result +def minimumf(lhs, rhs, *, fastmath=None): + return arith.minimumf(as_ir_value(lhs), as_ir_value(rhs), fastmath=fastmath) + + +@dsl_loc_tracing +@dsl_math_wrap_result +def maxnumf(lhs, rhs, *, fastmath=None): + return arith.maxnumf(as_ir_value(lhs), as_ir_value(rhs), fastmath=fastmath) + + +@dsl_loc_tracing +@dsl_math_wrap_result +def minnumf(lhs, rhs, *, fastmath=None): + return arith.minnumf(as_ir_value(lhs), as_ir_value(rhs), fastmath=fastmath) diff --git a/python/flydsl/expr/math.py b/python/flydsl/expr/math.py index e62613fb6..f340d975e 100644 --- a/python/flydsl/expr/math.py +++ b/python/flydsl/expr/math.py @@ -12,6 +12,7 @@ pred = fx.isnan(x) """ +import inspect from functools import wraps from .._mlir import ir @@ -19,6 +20,7 @@ from .meta import dsl_loc_tracing from .numeric import Numeric from .typing import as_ir_value 
+from .utils.arith import current_fastmath __all__ = [ "absf", @@ -70,12 +72,40 @@ ] -def dsl_math_wrap_result(fn): +def dsl_math_wrap_result(fn=None, *, exemplar=None): + """Wrap raw builder results back into DSL ``Numeric`` / ``Vector`` values. + + The DSL type of the result is shaped after an *exemplar* operand: + + - ``exemplar=None`` (default): the first positional argument is the + exemplar. This fits the ``x``-first math builders (``exp(x)``, ...). + - ``exemplar="NAME"``: the argument bound to that parameter name is the + exemplar. Use this for builders whose first argument is not the operand, + e.g. ``cmpi(predicate, lhs, rhs)`` -> ``exemplar="lhs"``. + """ + if fn is None: + return lambda f: dsl_math_wrap_result(f, exemplar=exemplar) + + sig = inspect.signature(fn) + accepts_fastmath = "fastmath" in sig.parameters + @wraps(fn) def wrapper(*args, **kwargs): from .typing import Vector - first = args[0] if args else None + if accepts_fastmath and kwargs.get("fastmath") is None: + ambient = current_fastmath() + if ambient is not None: + kwargs["fastmath"] = ambient + + if exemplar is None: + first = args[0] if args else None + else: + try: + bound = sig.bind_partial(*args, **kwargs) + first = bound.arguments.get(exemplar) + except TypeError: + first = kwargs.get(exemplar) is_vector = isinstance(first, Vector) is_numeric = isinstance(first, Numeric) diff --git a/python/flydsl/expr/numeric.py b/python/flydsl/expr/numeric.py index ef5959a82..00c636034 100644 --- a/python/flydsl/expr/numeric.py +++ b/python/flydsl/expr/numeric.py @@ -168,18 +168,11 @@ def zero(cls): _CMP_OPS = frozenset({operator.lt, operator.le, operator.gt, operator.ge, operator.eq, operator.ne}) - -def _widen_bool_to_int32(x, widen_bool=False): - """Promote Boolean to Int32 for arithmetic when widen_bool=True. - - Per C++-style usual arithmetic conversions, we deliberately do NOT apply - integer promotion: i8/i16/u8/u16 stay at their narrow width.
- Same-width same-signedness operands keep their type; cross-width or - cross-sign mixing is resolved by ``_coerce_operands``. - """ - if widen_bool and type(x) is Boolean: - return x.to(Int32), Int32 - return x, type(x) +# Operators for which Boolean operands widen to Int32 (C++-style: bool +# participates in arithmetic as int). +_WIDEN_BOOL_OPS = frozenset( + {operator.add, operator.sub, operator.mul, operator.floordiv, operator.truediv, operator.mod} +) def _resolve_float_type(ta, tb): @@ -203,27 +196,49 @@ def _resolve_float_type(ta, tb): raise ValueError(f"no common float type for {ta} and {tb}; cast explicitly") -def _coerce_operands(a, b, widen_bool=False): - """Promote *a* and *b* to a common scalar type.""" - ta, tb = type(a), type(b) - a, ta = _widen_bool_to_int32(a, widen_bool=widen_bool) - b, tb = _widen_bool_to_int32(b, widen_bool=widen_bool) +def _resolve_common_type(ta, tb, op=None): + """Resolve the common Numeric type used by scalar and vector operators.""" + if op in _WIDEN_BOOL_OPS: + ta = Int32 if ta is Boolean else ta + tb = Int32 if tb is Boolean else tb if ta is tb: - return a, b, ta + return ta if ta.is_float or tb.is_float: - dest = _resolve_float_type(ta, tb) - return (a if type(a) is dest else a.to(dest), b if type(b) is dest else b.to(dest), dest) + return _resolve_float_type(ta, tb) # Both integers — pick wider; on tie, prefer unsigned when mixed sign if ta.signed == tb.signed: - wider = ta if ta.width >= tb.width else tb - return (a if type(a) is wider else a.to(wider), b if type(b) is wider else b.to(wider), wider) + return ta if ta.width >= tb.width else tb u, s = (ta, tb) if not ta.signed else (tb, ta) - dest = u if u.width >= s.width else s - return (a if type(a) is dest else a.to(dest), b if type(b) is dest else b.to(dest), dest) + return u if u.width >= s.width else s + + +def _coerce_operands(a, b, op=None): + """Promote *a* and *b* to the common dtype for *op*.""" + dest = _resolve_common_type(a.dtype, b.dtype, op=op) + return 
(a if a.dtype is dest else a.to(dest), b if b.dtype is dest else b.to(dest), dest) + + +def _result_type_for_op(op, promoted_type): + """Apply op-specific result-type overrides on top of *promoted_type*. + + Shared by scalar Numeric and Vector arithmetic so both resolve identical + result dtypes: + + - comparisons -> ``Boolean`` + - integer true-division -> float (Python ``/`` lifts int/int to float; + width > 32 uses ``Float64``, otherwise ``Float32``) + + All other operators keep the promoted type unchanged. + """ + if op in _CMP_OPS: + return Boolean + if op is operator.truediv and promoted_type.is_integer: + return Float64 if promoted_type.width > 32 else Float32 + return promoted_type def _try_coerce_rhs(rhs): @@ -250,7 +265,7 @@ def _extract_arith(val, signed): return v.with_signedness(signed) if isinstance(v, ArithValue) else v -def _make_binop(op, promote=True, widen_bool=False, swap=False): +def _make_binop(op, swap=False): """Create a binary-operator closure for Numeric subclasses.""" def _apply(lhs, rhs): @@ -258,16 +273,9 @@ def _apply(lhs, rhs): if rhs is None: return NotImplemented - out_type = type(lhs) - if promote: - lhs, rhs, out_type = _coerce_operands(lhs, rhs, widen_bool) - else: - rhs = type(lhs)(rhs) + lhs, rhs, out_type = _coerce_operands(lhs, rhs, op=op) - if op in _CMP_OPS: - out_type = Boolean - elif op is operator.truediv and isinstance(lhs, Integer): - out_type = Float64 if out_type.width > 32 else Float32 + out_type = _result_type_for_op(op, out_type) lv, rv = _extract_arith(lhs, lhs.signed), _extract_arith(rhs, rhs.signed) if swap: @@ -346,29 +354,29 @@ def __neg__(self): return type(self)(-self.value) return type(self)(-self.value) - def __fly_bool__(self): + def __dsl_bool__(self): if isinstance(self.value, (int, float, bool)): return Boolean(bool(self.value)) zero = arith_const(type(self).zero, type(self).ir_type) return self.__ne__(type(self)(zero)) - def __fly_not__(self): - b = self.__fly_bool__() + def __dsl_not__(self): + b = 
self.__dsl_bool__() if isinstance(b.value, bool): return Boolean(not b.value) zero = arith_const(0, T.bool()) return Boolean(b.ir_value().__eq__(zero)) - def __fly_and__(self, other): - lhs = self.__fly_bool__() - rhs = as_numeric(other).__fly_bool__() + def __dsl_and__(self, other): + lhs = self.__dsl_bool__() + rhs = as_numeric(other).__dsl_bool__() if isinstance(lhs.value, bool) and isinstance(rhs.value, bool): return Boolean(lhs.value and rhs.value) return Boolean(lhs.ir_value().__and__(rhs.ir_value())) - def __fly_or__(self, other): - lhs = self.__fly_bool__() - rhs = as_numeric(other).__fly_bool__() + def __dsl_or__(self, other): + lhs = self.__dsl_bool__() + rhs = as_numeric(other).__dsl_bool__() if isinstance(lhs.value, bool) and isinstance(rhs.value, bool): return Boolean(lhs.value or rhs.value) return Boolean(lhs.ir_value().__or__(rhs.ir_value())) @@ -438,40 +446,40 @@ def from_ir_type(ir_type): return ir2dsl_map[ir_type] def __add__(self, other): - return _make_binop(operator.add, widen_bool=True)(self, other) + return _make_binop(operator.add)(self, other) def __sub__(self, other): - return _make_binop(operator.sub, widen_bool=True)(self, other) + return _make_binop(operator.sub)(self, other) def __mul__(self, other): - return _make_binop(operator.mul, widen_bool=True)(self, other) + return _make_binop(operator.mul)(self, other) def __floordiv__(self, other): - return _make_binop(operator.floordiv, widen_bool=True)(self, other) + return _make_binop(operator.floordiv)(self, other) def __truediv__(self, other): - return _make_binop(operator.truediv, widen_bool=True)(self, other) + return _make_binop(operator.truediv)(self, other) def __mod__(self, other): - return _make_binop(operator.mod, widen_bool=True)(self, other) + return _make_binop(operator.mod)(self, other) def __radd__(self, other): return self.__add__(other) def __rsub__(self, other): - return _make_binop(operator.sub, widen_bool=True, swap=True)(self, other) + return _make_binop(operator.sub, 
swap=True)(self, other) def __rmul__(self, other): return self.__mul__(other) def __rfloordiv__(self, other): - return _make_binop(operator.floordiv, widen_bool=True, swap=True)(self, other) + return _make_binop(operator.floordiv, swap=True)(self, other) def __rtruediv__(self, other): - return _make_binop(operator.truediv, widen_bool=True, swap=True)(self, other) + return _make_binop(operator.truediv, swap=True)(self, other) def __rmod__(self, other): - return _make_binop(operator.mod, widen_bool=True, swap=True)(self, other) + return _make_binop(operator.mod, swap=True)(self, other) def __pow__(self, other): return _make_binop(operator.pow)(self, other) @@ -483,13 +491,15 @@ def __ne__(self, other): return _make_binop(operator.ne)(self, other) # ── Proxy methods: delegate ArithValue-specific ops via ir_value() ── - def maximumf(self, other): + def maximumf(self, other, *, fastmath=None): """Float maximum — delegates to ArithValue.maximumf.""" - return type(self)(self.ir_value().maximumf(_to_raw(other))) + return type(self)(self.ir_value().maximumf(_to_raw(other), fastmath=fastmath)) + + def minimumf(self, other, *, fastmath=None): + """Float minimum — delegates to expr.arith.minimumf.""" + from .arith import minimumf as _minimumf - def minimumf(self, other): - """Float minimum — delegates to ArithValue.minimumf.""" - return type(self)(self.ir_value().minimumf(_to_raw(other))) + return _minimumf(self, other, fastmath=fastmath) def exp2(self, *, fastmath=None): """Base-2 exponential — delegates to ArithValue.exp2.""" @@ -785,61 +795,6 @@ class Float4E2M1FN(Float, metaclass=NumericMeta, width=4, ir_type=T.f4E2M1FN): . 
# Float type rank for promotion (must be after class definitions) _FLOAT_RANK = {Float64: 3, Float32: 2, Float16: 1, BFloat16: 1} -# ── Type promotion (added to Numeric after all subclasses exist) ────── - -_FLOAT_BY_MIN_WIDTH = {16: Float16, 32: Float32, 64: Float64} - - -def _widen_float(float_type, min_width): - """Return the narrowest standard float type with width >= *min_width*.""" - if float_type.width >= min_width: - return float_type - for w in (32, 64): - if w >= min_width: - return _FLOAT_BY_MIN_WIDTH[w] - return Float64 - - -@classmethod -def _promote(cls, a_type, b_type): - """Resolve the promoted result type for two Numeric types. - - :param a_type: Left Numeric class (e.g. Float16) - :param b_type: Right Numeric class (e.g. Float32) - :return: The common Numeric class both can be safely promoted to - """ - if a_type is b_type: - return a_type - - a_float = a_type.is_float - b_float = b_type.is_float - - if a_float and not b_float: - return _widen_float(a_type, b_type.width) - if b_float and not a_float: - return _widen_float(b_type, a_type.width) - - if a_float and b_float: - aw, bw = a_type.width, b_type.width - if aw > bw and aw >= 16: - return a_type - if bw > aw and bw >= 16: - return b_type - if aw == bw: - ra = _FLOAT_RANK.get(a_type, 0) - rb = _FLOAT_RANK.get(b_type, 0) - return a_type if ra >= rb else b_type - raise ValueError(f"cannot promote {a_type} and {b_type}; cast explicitly") - - # Both integers - if a_type.signed == b_type.signed: - return a_type if a_type.width >= b_type.width else b_type - u, s = (a_type, b_type) if not a_type.signed else (b_type, a_type) - return u if u.width >= s.width else s - - -Numeric.promote = _promote - class Index(Integer, metaclass=NumericMeta, width=64, signed=False, ir_type=lambda: ir.IndexType.get()): """DSL Numeric for MLIR index type. Replaces arith.index(N). 
diff --git a/python/flydsl/expr/typing.py b/python/flydsl/expr/typing.py index 7b0c87ce7..295562262 100644 --- a/python/flydsl/expr/typing.py +++ b/python/flydsl/expr/typing.py @@ -45,6 +45,9 @@ Uint32, Uint64, Uint128, + _coerce_operands, + _result_type_for_op, + _try_coerce_rhs, as_numeric, ) from .primitive import * @@ -57,6 +60,7 @@ fp_to_int, int_to_fp, int_to_int, + resolve_fastmath, ) @@ -1593,28 +1597,30 @@ def ir_value(self): def with_signedness(self, signed): return ArithValue(self, signed) - def _wrap_op_result(self, result, shape): + def _wrap_op_result(self, result, shape, result_dtype): if isinstance(result, ir.Value) and isinstance(result.type, ir.VectorType): - return Vector(result, shape, self._result_dtype(result.type.element_type)) + return Vector(result, shape, result_dtype) if isinstance(result, Numeric): return result if isinstance(result, ir.Value): - return self._result_dtype(result.type)(result) + return result_dtype(result) return result - def _result_dtype(self, elem_type) -> Type[Numeric]: - if self._dtype.ir_type == elem_type: - return self._dtype - return Numeric.from_ir_type(elem_type) - def _apply_op(self, method_name, op, other, flip=False): lhs = self rhs = other shape = self.shape + promoted_dtype = self.dtype if isinstance(other, Vector): shape = self._infer_broadcast_shape(self.shape, other.shape) lhs = self.broadcast_to(shape) rhs = other.broadcast_to(shape) + else: + rhs_numeric = _try_coerce_rhs(rhs) + if rhs_numeric is None: + return NotImplemented + rhs = rhs_numeric + lhs, rhs, promoted_dtype = _coerce_operands(lhs, rhs, op=op) method = getattr(ArithValue, method_name) if flip: if isinstance(rhs, Vector): @@ -1624,7 +1630,8 @@ def _apply_op(self, method_name, op, other, flip=False): result = getattr(ArithValue, reverse_name)(lhs, rhs) else: result = method(lhs, rhs) - return self._wrap_op_result(result, shape) + result_dtype = _result_type_for_op(op, promoted_dtype) + return self._wrap_op_result(result, shape, 
result_dtype) def apply_op(self, op, other, flip=False): method_name = _VECTOR_OP_METHODS.get(op) @@ -1729,6 +1736,7 @@ def reduce(self, op, init_val=None, reduction_profile=None, *, fastmath=None): kind = _resolve_combining_kind(op, is_fp, signed) et = element_type(self.type) kwargs = {} + fastmath = resolve_fastmath(fastmath) if fastmath is not None: kwargs["fastmath"] = fastmath if init_val is not None: diff --git a/python/flydsl/expr/utils/arith.py b/python/flydsl/expr/utils/arith.py index 5d3f78363..ee66b1e2f 100644 --- a/python/flydsl/expr/utils/arith.py +++ b/python/flydsl/expr/utils/arith.py @@ -2,6 +2,8 @@ # Copyright (c) 2025 FlyDSL Project Contributors import builtins +import contextlib +import threading from functools import partialmethod from ..._mlir import ir @@ -9,6 +11,50 @@ from ..._mlir.extras import types as T from ..meta import dsl_loc_tracing +# --------------------------------------------------------------------------- # +# Ambient fastmath context (thread-local) +# --------------------------------------------------------------------------- # +_fm_tls = threading.local() + + +def _normalize_fastmath(flags): + """Normalize a fastmath spec to a value the MLIR ``fastmath=`` arg accepts. + + Accepts a single flag (``arith.FastMathFlags``, ``str``), a combined + ``FastMathFlags`` value (via ``|``), an iterable of flags (combined + comma-separated), or ``None``. 
+ """ + if flags is None: + return None + if isinstance(flags, str): + return flags + if isinstance(flags, (set, frozenset)): + return ",".join(sorted(str(f) for f in flags)) + if isinstance(flags, (list, tuple)): + return ",".join(str(f) for f in flags) + return str(flags) + + +def current_fastmath(): + """Return the ambient fastmath flags set by ``fastmath(...)``, or ``None``.""" + return getattr(_fm_tls, "value", None) + + +def resolve_fastmath(explicit): + """Pick the effective fastmath: explicit arg wins over ambient context.""" + return explicit if explicit is not None else current_fastmath() + + +@contextlib.contextmanager +def fastmath(flags): + """Apply *flags* to floating-point ops built inside the ``with`` block.""" + prev = getattr(_fm_tls, "value", None) + _fm_tls.value = _normalize_fastmath(flags) + try: + yield + finally: + _fm_tls.value = prev + def element_type(ty) -> ir.Type: if isinstance(ty, ir.VectorType): @@ -77,8 +123,8 @@ def fp_to_fp(src, res_elem_type): return src res_type = recast_type(src.type, res_elem_type) if res_elem_type.width > src_elem_type.width: - return arith.extf(res_type, src) - return arith.truncf(res_type, src) + return arith.extf(res_type, src, fastmath=current_fastmath()) + return arith.truncf(res_type, src, fastmath=current_fastmath()) @dsl_loc_tracing @@ -150,27 +196,29 @@ def _binary_op(self, other, op): if other is NotImplemented: return NotImplemented + fm = current_fastmath() + if op in _ARITH_OPS: float_fn, int_fn = _ARITH_OPS[op] if self.is_float: - return float_fn(self, other) + return float_fn(self, other, fastmath=fm) return int_fn(self, other) if op == "div": if self.is_float: - return arith.divf(self, other) + return arith.divf(self, other, fastmath=fm) et = element_type(self.type) if isinstance(et, ir.IndexType): return arith.divui(self, other) fp_ty = T.f64() if et.width > 32 else T.f32() lhs = int_to_fp(self, self.signed, fp_ty) rhs = int_to_fp(other, other.signed, fp_ty) - return arith.divf(lhs, rhs) + 
return arith.divf(lhs, rhs, fastmath=fm) if op == "floordiv": if self.is_float: - q = arith.divf(self, other) - return math.floor(q) + q = arith.divf(self, other, fastmath=fm) + return math.floor(q, fastmath=fm) et = element_type(self.type) if isinstance(et, ir.IndexType): return arith.divui(self, other) @@ -180,7 +228,7 @@ def _binary_op(self, other, op): if op == "mod": if self.is_float: - return arith.remf(self, other) + return arith.remf(self, other, fastmath=fm) et = element_type(self.type) if isinstance(et, ir.IndexType): return arith.remui(self, other) @@ -232,7 +280,7 @@ def _comparison_op(self, other, predicate): return NotImplemented if self.is_float: - return arith.cmpf(_CMP_FLOAT_PRED[predicate], self, other) + return arith.cmpf(_CMP_FLOAT_PRED[predicate], self, other, fastmath=current_fastmath()) if self.signed is not False: return arith.cmpi(_CMP_INT_SIGNED[predicate], self, other) return arith.cmpi(_CMP_INT_UNSIGNED[predicate], self, other) @@ -277,14 +325,15 @@ def _pow_op(self, other, reverse=False): return NotImplemented if reverse: self, other = other, self + fm = current_fastmath() if self.is_float and other.is_float: - return math.powf(self, other) + return math.powf(self, other, fastmath=fm) if self.is_float and not other.is_float: - return math.fpowi(self, other) + return math.fpowi(self, other, fastmath=fm) if not self.is_float and other.is_float: fp_ty = element_type(other.type) lhs = int_to_fp(self, self.signed, fp_ty) - return math.powf(lhs, other) + return math.powf(lhs, other, fastmath=fm) return math.ipowi(self, other) @@ -293,7 +342,7 @@ def _neg_op(self): if self.type == T.bool(): raise TypeError("negation is not supported for boolean type") if self.is_float: - return arith.negf(self) + return arith.negf(self, fastmath=current_fastmath()) c0 = arith_const(0, self.type) return arith.subi(c0, self) @@ -384,14 +433,14 @@ def select(self, true_value, false_value): return arith.SelectOp(_to_raw(self), true_value, false_value).result 
@dsl_loc_tracing - def extf(self, target_type): + def extf(self, target_type, *, fastmath=None): """Extend float precision (e.g. bf16 → f32).""" - return arith.ExtFOp(target_type, self).result + return arith.extf(target_type, self, fastmath=resolve_fastmath(fastmath)) @dsl_loc_tracing - def truncf(self, target_type): + def truncf(self, target_type, *, fastmath=None): """Truncate float precision (e.g. f32 → bf16).""" - return arith.TruncFOp(target_type, self).result + return arith.truncf(target_type, self, fastmath=resolve_fastmath(fastmath)) @dsl_loc_tracing def extui(self, target_type): @@ -421,26 +470,26 @@ def shrui(self, amount): @dsl_loc_tracing def addf(self, other, *, fastmath=None): """Float add with optional fastmath flags.""" - return arith.addf(self, _to_raw(other), fastmath=fastmath) + return arith.addf(self, _to_raw(other), fastmath=resolve_fastmath(fastmath)) @dsl_loc_tracing - def maximumf(self, other): + def maximumf(self, other, *, fastmath=None): """Float maximum (NaN-propagating).""" - return arith.maximumf(self, _to_raw(other)) + return arith.maximumf(self, _to_raw(other), fastmath=resolve_fastmath(fastmath)) @dsl_loc_tracing def rsqrt(self, *, fastmath=None): """Reciprocal square root: 1/sqrt(self).""" from ..._mlir.dialects import math as _math - return _math.rsqrt(self, fastmath=fastmath) + return _math.rsqrt(self, fastmath=resolve_fastmath(fastmath)) @dsl_loc_tracing def exp2(self, *, fastmath=None): """Base-2 exponential: 2^self.""" from ..._mlir.dialects import math as _math - return _math.exp2(self, fastmath=fastmath) + return _math.exp2(self, fastmath=resolve_fastmath(fastmath)) @dsl_loc_tracing def shuffle_xor(self, offset, width): diff --git a/tests/unit/test_constexpr.py b/tests/unit/test_constexpr.py index bec2d17b2..1fa7fa15e 100644 --- a/tests/unit/test_constexpr.py +++ b/tests/unit/test_constexpr.py @@ -53,3 +53,12 @@ def build(shape: fx.Constexpr[Tuple[int, Tuple[bool, float]]]): assert shape == (16, (True, 2.5)) build((16, 
(True, 2.5))) + + +def test_jit_accepts_chained_compare_constexpr(frontend_only_jit): + @flyc.jit + def build(value: fx.Constexpr[int], upper: fx.Constexpr[int], expected: fx.Constexpr[bool]): + assert fx.const_expr(0 < value < upper) == expected + + build(3, 8, True) + build(0, 8, False) diff --git a/tests/unit/test_fastmath_context.py b/tests/unit/test_fastmath_context.py new file mode 100644 index 000000000..fc7001827 --- /dev/null +++ b/tests/unit/test_fastmath_context.py @@ -0,0 +1,303 @@ +#!/usr/bin/env python3 + +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 FlyDSL Project Contributors + +"""Tests for the ambient ``fx.fastmath(...)`` context manager. + +Verifies that: +1. Float operators (+, -, *, /, %) inside the block emit fastmath flags. +2. ``math`` functions inherit the ambient flags. +3. An explicit ``fastmath=`` argument overrides the ambient context. +4. Flags are restored on block exit (including nested blocks). +5. Integer operators are unaffected. +""" + +import pytest + +import flydsl.compiler as flyc +import flydsl.expr as fx +from flydsl._mlir import ir +from flydsl._mlir.dialects import func +from flydsl.expr.numeric import Float32, Int32 + +try: + import torch +except ImportError: + torch = None + +# GPU gating is applied per-test via ``requires_gpu``; the ``l0_backend_agnostic`` +# IR-string tests below build IR without a device and must run on non-GPU runners. 
+requires_gpu = pytest.mark.skipif( + torch is None or not torch.cuda.is_available(), + reason="CUDA/ROCm not available", +) + + +def _build(build_fn, arg_types): + with ir.Context() as ctx: + ctx.allow_unregistered_dialects = True + with ir.Location.unknown(ctx): + module = ir.Module.create() + with ir.InsertionPoint(module.body): + ftype = ir.FunctionType.get([t() for t in arg_types], []) + f = func.FuncOp("test", ftype) + with ir.InsertionPoint(f.add_entry_block()): + build_fn(*f.entry_block.arguments) + func.ReturnOp([]) + return str(module) + + +@pytest.mark.l0_backend_agnostic +def test_operators_pick_up_ambient_fastmath(): + def build(a, b): + fa, fb = Float32(a), Float32(b) + with fx.fastmath(fx.FastMathFlags.fast): + _ = fa + fb + _ = fa - fb + _ = fa * fb + _ = fa / fb + + ir_text = _build(build, [ir.F32Type.get, ir.F32Type.get]) + assert "arith.addf" in ir_text and "fastmath" in ir_text + assert ir_text.count("fastmath") >= 4 + + +@pytest.mark.l0_backend_agnostic +def test_no_flag_outside_block(): + def build(a, b): + fa, fb = Float32(a), Float32(b) + _ = fa + fb + + ir_text = _build(build, [ir.F32Type.get, ir.F32Type.get]) + assert "fastmath" not in ir_text + + +@pytest.mark.l0_backend_agnostic +def test_math_function_inherits_ambient(): + def build(a): + fa = Float32(a) + with fx.fastmath(fx.FastMathFlags.fast): + fx.exp(fa) + + ir_text = _build(build, [ir.F32Type.get]) + assert "math.exp" in ir_text and "fastmath" in ir_text + + +@pytest.mark.l0_backend_agnostic +def test_explicit_arg_overrides_ambient(): + def build(a): + fa = Float32(a) + with fx.fastmath(fx.FastMathFlags.fast): + fx.exp(fa, fastmath="contract") + + ir_text = _build(build, [ir.F32Type.get]) + assert "fastmath<contract>" in ir_text + assert "fastmath<fast>" not in ir_text + + +@pytest.mark.l0_backend_agnostic +def test_combined_flags(): + def build(a, b): + fa, fb = Float32(a), Float32(b) + with fx.fastmath(fx.FastMathFlags.reassoc | fx.FastMathFlags.contract): + _ = fa + fb + + ir_text = 
_build(build, [ir.F32Type.get, ir.F32Type.get]) + assert "fastmath" in ir_text + + +@pytest.mark.l0_backend_agnostic +def test_nested_blocks_restore(): + def build(a, b): + fa, fb = Float32(a), Float32(b) + with fx.fastmath(fx.FastMathFlags.fast): + with fx.fastmath(fx.FastMathFlags.contract): + _ = fa * fb # contract + _ = fa + fb # back to fast + _ = fa - fb # no flag + + ir_text = _build(build, [ir.F32Type.get, ir.F32Type.get]) + lines = [ln for ln in ir_text.splitlines() if "arith." in ln] + mul = next(ln for ln in lines if "mulf" in ln) + add = next(ln for ln in lines if "addf" in ln) + sub = next(ln for ln in lines if "subf" in ln) + assert "fastmath" in mul + assert "fastmath" in add + assert "fastmath" not in sub + + +@pytest.mark.l0_backend_agnostic +def test_neg_and_pow_operators_pick_up_ambient(): + """Unary negation (negf) and power (**, math.powf) also inherit ambient.""" + + def build(a, b): + fa, fb = Float32(a), Float32(b) + _ = -fa # outside → no flag + _ = fa**fb # outside → no flag + with fx.fastmath(fx.FastMathFlags.fast): + _ = -fa # negf inherits + _ = fa**fb # powf inherits + + ir_text = _build(build, [ir.F32Type.get, ir.F32Type.get]) + negs = [ln for ln in ir_text.splitlines() if "arith.negf" in ln] + pows = [ln for ln in ir_text.splitlines() if "math.powf" in ln] + assert "fastmath" not in negs[0] and "fastmath" in negs[1] + assert "fastmath" not in pows[0] and "fastmath" in pows[1] + + +@pytest.mark.l0_backend_agnostic +def test_named_methods_inherit_ambient(): + """ArithValue/Numeric named fastmath methods (addf/exp2) inherit ambient.""" + + def build(a, b): + fa, fb = Float32(a), Float32(b) + _ = fa + fb # outside → no flag + with fx.fastmath(fx.FastMathFlags.fast): + _ = fa + fb # inherits ambient + _ = fa.exp2() # inherits ambient + _ = fa.addf(fb, fastmath="none") # explicit overrides + + ir_text = _build(build, [ir.F32Type.get, ir.F32Type.get]) + addf_lines = [ln for ln in ir_text.splitlines() if "arith.addf" in ln] + assert 
"fastmath" not in addf_lines[0] # outside block + assert "fastmath" in addf_lines[1] # inherited + assert "fastmath" not in addf_lines[2] # explicit "none" overrides → default (omitted) + assert any("math.exp2" in ln and "fastmath" in ln for ln in ir_text.splitlines()) + + +@pytest.mark.l0_backend_agnostic +def test_integer_ops_unaffected(): + def build(a, b): + ia, ib = Int32(a), Int32(b) + with fx.fastmath(fx.FastMathFlags.fast): + _ = ia + ib + _ = ia * ib + + def i32(): + return ir.IntegerType.get_signless(32) + + ir_text = _build(build, [i32, i32]) + assert "arith.addi" in ir_text and "arith.muli" in ir_text + assert "fastmath" not in ir_text + + +@flyc.kernel +def _fm_kernel(): + tid = fx.thread_idx.x + x = fx.Float32(tid) + z = x * x + x # operators → kernel-level flag + _ = fx.exp(z) # math fn → inherits kernel-level flag + + with fx.fastmath(fx.FastMathFlags.contract): + _ = x * z # block overrides kernel | + _ = fx.exp(x, fastmath="none") # explicit op-level override + + +@flyc.jit +def _fm_launch_plain(stream: fx.Stream = fx.Stream(None)): + _fm_kernel().launch(grid=(1, 1, 1), block=(32, 1, 1), stream=stream) + + +@flyc.jit +def _fm_launch_hinted(stream: fx.Stream = fx.Stream(None)): + _fm_kernel().launch(grid=(1, 1, 1), block=(32, 1, 1), stream=stream) + + +@flyc.jit +def _fm_launch_fast_fp(stream: fx.Stream = fx.Stream(None)): + _fm_kernel().launch(grid=(1, 1, 1), block=(32, 1, 1), stream=stream) + + +@flyc.jit +def _fm_launch_fast_fp_explicit_fastmath(stream: fx.Stream = fx.Stream(None)): + _fm_kernel().launch(grid=(1, 1, 1), block=(32, 1, 1), stream=stream) + + +def _source_ir(launch_fn): + launch_fn(stream=torch.cuda.current_stream()) + assert launch_fn._mem_cache, "expected at least one cached compilation" + return next(iter(launch_fn._mem_cache.values())).source_ir + + +def _arith_lines(ir_text, needle): + return [ln.strip() for ln in ir_text.splitlines() if needle in ln and ln.strip().startswith("%")] + + +@pytest.mark.l2_device 
+@pytest.mark.rocm_lower +@requires_gpu +def test_no_hint_only_block_scope_applies(): + """Without a kernel hint, kernel-level ops carry no flag, but an inner + ``with fx.fastmath`` block still applies (block is independent of kernel).""" + ir_text = _source_ir(_fm_launch_plain) + muls = _arith_lines(ir_text, "arith.mulf") # [0]=x*x, [1]=x*z(block) + assert "fastmath" not in muls[0] # kernel-level op, no hint → plain + assert "fastmath" in muls[1] # inner block still applies + assert all("fastmath" not in ln for ln in _arith_lines(ir_text, "arith.addf")) + + +@pytest.mark.l2_device +@pytest.mark.rocm_lower +@requires_gpu +def test_kernel_level_fastmath_and_scope_overrides(): + hinted = flyc.compile[{"fastmath": "fast"}](_fm_launch_hinted) + ir_text = _source_ir(hinted) + + muls = _arith_lines(ir_text, "arith.mulf") # [0]=x*x, [1]=x*z(block) + assert "fastmath" in muls[0] # kernel-level fast + assert "fastmath" in muls[1] # block overrides kernel + # addf x*x+x → kernel-level fast + assert all("fastmath" in ln for ln in _arith_lines(ir_text, "arith.addf")) + # math.exp(z) inherits fast; math.exp(x, "none") explicitly opts out + exps = _arith_lines(ir_text, "math.exp") # [0]=exp(z), [1]=exp(x,none) + assert "fastmath" in exps[0] # inherits kernel-level + assert "fastmath" not in exps[1] # explicit op-level "none" override + + +@pytest.mark.l2_device +@pytest.mark.rocm_lower +@requires_gpu +def test_fast_fp_math_defaults_kernel_fastmath_context(): + hinted = flyc.compile[{"fast_fp_math": True}](_fm_launch_fast_fp) + ir_text = _source_ir(hinted) + + muls = _arith_lines(ir_text, "arith.mulf") + assert "fastmath" in muls[0] + assert "fastmath" in muls[1] + assert all("fastmath" in ln for ln in _arith_lines(ir_text, "arith.addf")) + + +@pytest.mark.l2_device +@pytest.mark.rocm_lower +@requires_gpu +def test_explicit_fastmath_hint_overrides_fast_fp_math_default(): + hinted = flyc.compile[{"fast_fp_math": True, "fastmath": "contract"}](_fm_launch_fast_fp_explicit_fastmath) 
+ ir_text = _source_ir(hinted) + + muls = _arith_lines(ir_text, "arith.mulf") + assert "fastmath<contract>" in muls[0] + assert "fastmath<contract>" in muls[1] + assert all("fastmath<contract>" in ln for ln in _arith_lines(ir_text, "arith.addf")) + assert "fastmath<fast>" not in ir_text + + +@pytest.mark.l2_device +@pytest.mark.rocm_lower +@requires_gpu +def test_hint_changes_cache_key(): + """The fastmath hint must be part of the cache key (rides _hints_).""" + _fm_launch_hinted._ensure_sig() + sig = _fm_launch_hinted._sig + bound = sig.bind() + bound.apply_defaults() + + _fm_launch_hinted.compile_hints = {} + key_none = _fm_launch_hinted._resolve_and_make_cache_key(bound.arguments) + _fm_launch_hinted.compile_hints = {"fastmath": "fast"} + key_fast = _fm_launch_hinted._resolve_and_make_cache_key(bound.arguments) + _fm_launch_hinted.compile_hints = {"fastmath": "contract"} + key_contract = _fm_launch_hinted._resolve_and_make_cache_key(bound.arguments) + + assert key_none != key_fast + assert key_fast != key_contract diff --git a/tests/unit/test_if_dispatch_paths.py b/tests/unit/test_if_dispatch_paths.py index b13cfdfcc..ee58af1ae 100644 --- a/tests/unit/test_if_dispatch_paths.py +++ b/tests/unit/test_if_dispatch_paths.py @@ -4,6 +4,7 @@ # Copyright (c) 2025 FlyDSL Project Contributors import ast +import types import pytest @@ -13,6 +14,10 @@ from flydsl.expr.numeric import Int32 +def _dynamic_chained_compare(x): + return Int32(0) <= x < Int32(8) + + def test_collect_assigned_vars_supports_tuple_and_augassign(): code = """ a, (b, c) = foo() @@ -164,6 +169,31 @@ def else_fn(x): assert isinstance(out, Int32) +def test_ast_rewrite_lowers_dynamic_chained_compare(): + rewritten = types.FunctionType( + _dynamic_chained_compare.__code__, + dict(_dynamic_chained_compare.__globals__), + _dynamic_chained_compare.__name__, + ) + ASTRewriter.transform(rewritten) + + with Context(), Location.unknown(): + module = Module.create() + i32 = IntegerType.get_signless(32) + i1 = IntegerType.get_signless(1) + with 
InsertionPoint(module.body): + f = func.FuncOp("test_dynamic_chained_compare", FunctionType.get([i32], [i1])) + entry = f.add_entry_block() + with InsertionPoint(entry): + out = rewritten(Int32(entry.arguments[0])) + func.ReturnOp([out.ir_value()]) + + assert module.operation.verify() + ir_text = str(module) + assert ir_text.count("arith.cmpi") >= 2 + assert "arith.andi" in ir_text + + def test_ast_rewrite_keeps_semantics_for_static_bool(): called = {"n": 0} diff --git a/tests/unit/test_numeric_promotion.py b/tests/unit/test_numeric_promotion.py index c4aa7a68a..e7079206f 100644 --- a/tests/unit/test_numeric_promotion.py +++ b/tests/unit/test_numeric_promotion.py @@ -12,6 +12,8 @@ mixed-sign mixed-width). """ +import operator + import pytest import flydsl.expr as fx @@ -21,29 +23,35 @@ def _binop(lhs_ty, rhs_ty, op): - """Build two block-arg values of the requested DSL types and apply `op`. - Returns the resulting Numeric. We use block args so the operands are - genuinely dynamic ir.Values (not Python literals), which is the path - most kernel code hits. 
- """ with Context() as ctx: ctx.allow_unregistered_dialects = True with Location.unknown(ctx): module = Module.create() from flydsl._mlir.dialects import func - from flydsl._mlir.ir import FunctionType + from flydsl._mlir.ir import FunctionType, VectorType + + def _vec(t): + return VectorType.get([4], t.ir_type) with InsertionPoint(module.body): - f = func.FuncOp("k", FunctionType.get([lhs_ty.ir_type, rhs_ty.ir_type], [])) + ftype = FunctionType.get([lhs_ty.ir_type, rhs_ty.ir_type, _vec(lhs_ty), _vec(rhs_ty)], []) + f = func.FuncOp("k", ftype) entry = f.add_entry_block() with InsertionPoint(entry): a = lhs_ty(entry.arguments[0]) b = rhs_ty(entry.arguments[1]) - result = op(a, b) + va = fx.Vector(entry.arguments[2], 4, lhs_ty) + vb = fx.Vector(entry.arguments[3], 4, rhs_ty) + scalar = op(a, b) + vector = op(va, vb) func.ReturnOp([]) assert module.operation.verify() - return result + assert vector.dtype is scalar.dtype, ( + f"vector/scalar dtype drift for {lhs_ty.__name__} {op.__name__} {rhs_ty.__name__}: " + f"vector -> {vector.dtype.__name__}, scalar -> {scalar.dtype.__name__}" + ) + return scalar # Same-sign / same-width: must stay narrow (no auto-int32 promotion). @@ -148,22 +156,7 @@ def test_float_wider_wins(a, b, expected): # Boolean arithmetic: bool + bool → Int32 (matches C++ "bool participates as int"). 
def test_bool_plus_bool_widens_to_int32(): - with Context() as ctx, Location.unknown(ctx): - ctx.allow_unregistered_dialects = True - module = Module.create() - from flydsl._mlir.dialects import func - from flydsl._mlir.ir import FunctionType - - with InsertionPoint(module.body): - f = func.FuncOp("k", FunctionType.get([fx.Boolean.ir_type, fx.Boolean.ir_type], [])) - entry = f.add_entry_block() - with InsertionPoint(entry): - a = fx.Boolean(entry.arguments[0]) - b = fx.Boolean(entry.arguments[1]) - r = a + b - func.ReturnOp([]) - assert module.operation.verify() - assert r.dtype is fx.Int32 + assert _binop(fx.Boolean, fx.Boolean, operator.add).dtype is fx.Int32 # True division on integers: Python `/` lifts int/int to float. @@ -184,3 +177,51 @@ def test_truediv_int_lifts_to_float(ty, expected): @pytest.mark.parametrize("ty", [fx.Int8, fx.Int32, fx.Int64, fx.Uint32, fx.Int128]) def test_floordiv_int_stays_int(ty): assert _binop(ty, ty, lambda x, y: x // y).dtype is ty + + +# --------------------------------------------------------------------------- +# Broader operator coverage. Every case runs through `_binop`, which builds a +# scalar pair and a vector pair and asserts they promote identically — so these +# double as the Vector/Numeric result-type consistency checks. +# --------------------------------------------------------------------------- + +# Representative pairs: same-sign, mixed-sign, and cross-kind mixing. +_MIXED_PAIRS = [ + (fx.Int8, fx.Int16), + (fx.Uint8, fx.Uint16), + (fx.Int32, fx.Uint32), + (fx.Int16, fx.Uint8), + (fx.Int64, fx.Uint32), + (fx.Int32, fx.Float32), + (fx.Float16, fx.Float32), +] + +_MIXED_INT_PAIRS = [ + (fx.Int8, fx.Int16), + (fx.Uint8, fx.Uint16), + (fx.Uint16, fx.Uint64), + (fx.Int32, fx.Uint32), + (fx.Int16, fx.Uint8), + (fx.Int64, fx.Uint32), +] + + +# Subtraction / multiplication promote exactly like addition (no override). 
+@pytest.mark.parametrize("op", [operator.sub, operator.mul]) +@pytest.mark.parametrize("a,b", _MIXED_PAIRS) +def test_sub_mul_promote_like_add(op, a, b): + assert _binop(a, b, op).dtype is _binop(a, b, operator.add).dtype + + +# Integer mod / bitwise keep the usual integer promotion (no override). +@pytest.mark.parametrize("op", [operator.mod, operator.and_, operator.or_, operator.xor]) +@pytest.mark.parametrize("a,b", _MIXED_INT_PAIRS) +def test_int_mod_bitwise_promote_usual(op, a, b): + assert _binop(a, b, op).dtype is _binop(a, b, operator.add).dtype + + +# Comparisons yield Boolean regardless of operand types. +@pytest.mark.parametrize("op", [operator.lt, operator.le, operator.gt, operator.ge, operator.eq, operator.ne]) +@pytest.mark.parametrize("a,b", _MIXED_PAIRS) +def test_comparison_yields_boolean(op, a, b): + assert _binop(a, b, op).dtype is fx.Boolean diff --git a/tests/unit/test_vector.py b/tests/unit/test_vector.py index 436d4b55c..73c492b5d 100644 --- a/tests/unit/test_vector.py +++ b/tests/unit/test_vector.py @@ -19,10 +19,10 @@ Boolean, Float16, Float32, - Float64, Int8, Int16, Int32, + Int64, Numeric, Uint32, ) @@ -315,73 +315,45 @@ def build(a, b): class TestTypePromotion: - - def test_same_type(self): - assert Numeric.promote(Float32, Float32) is Float32 - - def test_f16_f32(self): - assert Numeric.promote(Float16, Float32) is Float32 - - def test_bf16_f32(self): - assert Numeric.promote(BFloat16, Float32) is Float32 - - def test_int_float(self): - """Int32 + Float32 → Float32.""" - assert Numeric.promote(Int32, Float32) is Float32 - - def test_int_wider_than_float(self): - """Float16 + Int32 → Float32 (int width 32 > float width 16).""" - assert Numeric.promote(Float16, Int32) is Float32 - - def test_int_same_width_as_float(self): - """Float32 + Int32 → Float32 (same width, float wins).""" - assert Numeric.promote(Float32, Int32) is Float32 - - def test_int_narrower_than_float(self): - """Float32 + Int16 → Float32 (int is narrower).""" - assert 
Numeric.promote(Float32, Int16) is Float32 - - def test_int64_with_float32(self): - """Float32 + Int64 → Float64 (int width 64 > float width 32).""" - from flydsl.expr.numeric import Int64 - - assert Numeric.promote(Float32, Int64) is Float64 - - def test_f16_f64(self): - assert Numeric.promote(Float16, Float64) is Float64 - def test_promote_in_operator(self): - """Mixed-type vector ops require explicit .to() conversion (no auto-promote).""" + """Mixed-type vector ops auto-promote to a common dtype.""" def build(a, b): ta = Vector(a, 8, Float16) tb = Vector(b, 8, Float32) - ta_f32 = ta.to(Float32) - result = ta_f32 + tb + result = ta + tb assert result.dtype is Float32 ir_text = _build_module(build, [_vec_f16, _vec_f32]) assert "arith.extf" in ir_text assert "arith.addf" in ir_text - def test_mixed_signed_unsigned_int(self): - """Int32 + Uint32 → Uint32 (unsigned wins at same width).""" - assert Numeric.promote(Int32, Uint32) is Uint32 - assert Numeric.promote(Uint32, Int32) is Uint32 - def test_promote_bf16_scalar(self): - """BFloat16 tensor + scalar → explicit .to() needed for mixed-type ops.""" + """BFloat16 tensor + scalar auto-promotes like scalar Numeric arithmetic.""" def build(a): ta = Vector(a, 8, BFloat16) - ta_f32 = ta.to(Float32) - result = ta_f32 + 1.0 + result = ta + 1.0 assert result.dtype is Float32 ir_text = _build_module(build, [_vec_bf16]) assert "arith.extf" in ir_text assert "arith.addf" in ir_text + def test_promote_int32x4_int64x4(self): + """Regression for Int32x4 + Int64x4 producing mixed arith.addi operands.""" + + def build(): + lhs = Vector.from_elements([Int32(1), Int32(2), Int32(3), Int32(4)]) + rhs = Vector.from_elements([Int64(5), Int64(6), Int64(7), Int64(8)]) + result = lhs + rhs + assert result.dtype is Int64 + + ir_text = _build_module(build, []) + assert "arith.extsi" in ir_text + assert "arith.addi" in ir_text + assert "vector<4xi64>" in ir_text + # =========================================================================== # 
D. Type conversion (.to())