From 6fe20a8e0be8f5bdcdcbbd120282b0e0fccb99a6 Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Fri, 3 Jul 2026 09:44:47 +0000 Subject: [PATCH] [Fix] Consistent promotion rules for Vector and Numeric --- python/flydsl/expr/numeric.py | 163 ++++++++++----------------- python/flydsl/expr/typing.py | 24 ++-- tests/unit/test_numeric_promotion.py | 91 +++++++++++---- tests/unit/test_vector.py | 66 ++++------- 4 files changed, 159 insertions(+), 185 deletions(-) diff --git a/python/flydsl/expr/numeric.py b/python/flydsl/expr/numeric.py index ef5959a82..16c091ee7 100644 --- a/python/flydsl/expr/numeric.py +++ b/python/flydsl/expr/numeric.py @@ -168,18 +168,11 @@ def zero(cls): _CMP_OPS = frozenset({operator.lt, operator.le, operator.gt, operator.ge, operator.eq, operator.ne}) - -def _widen_bool_to_int32(x, widen_bool=False): - """Promote Boolean to Int32 for arithmetic when widen_bool=True. - - Per C++-style usual arithmetic conversions, we deliberately do NOT apply - integer promotion: i8/i16/u8/u16 stay at their narrow width. - Same-width same-signedness operands keep their type; cross-width or - cross-sign mixing is resolved by ``_coerce_operands``. - """ - if widen_bool and type(x) is Boolean: - return x.to(Int32), Int32 - return x, type(x) +# Operators for which Boolean operands widen to Int32 (C++-style: bool +# participates in arithmetic as int). +_WIDEN_BOOL_OPS = frozenset( + {operator.add, operator.sub, operator.mul, operator.floordiv, operator.truediv, operator.mod} +) def _resolve_float_type(ta, tb): @@ -203,27 +196,49 @@ def _resolve_float_type(ta, tb): raise ValueError(f"no common float type for {ta} and {tb}; cast explicitly") -def _coerce_operands(a, b, widen_bool=False): - """Promote *a* and *b* to a common scalar type.""" - ta, tb = type(a), type(b) - a, ta = _widen_bool_to_int32(a, widen_bool=widen_bool) - b, tb = _widen_bool_to_int32(b, widen_bool=widen_bool) +def _resolve_common_type(ta, tb, op=None): + """Resolve the common Numeric type used by scalar and vector operators.""" + if op in _WIDEN_BOOL_OPS: + ta = Int32 if ta is Boolean else ta + tb = Int32 if tb is Boolean else tb if ta is tb: - return a, b, ta + return ta if ta.is_float or tb.is_float: - dest = _resolve_float_type(ta, tb) - return (a if type(a) is dest else a.to(dest), b if type(b) is dest else b.to(dest), dest) + return _resolve_float_type(ta, tb) # Both integers — pick wider; on tie, prefer unsigned when mixed sign if ta.signed == tb.signed: - wider = ta if ta.width >= tb.width else tb - return (a if type(a) is wider else a.to(wider), b if type(b) is wider else b.to(wider), wider) + return ta if ta.width >= tb.width else tb u, s = (ta, tb) if not ta.signed else (tb, ta) - dest = u if u.width >= s.width else s - return (a if type(a) is dest else a.to(dest), b if type(b) is dest else b.to(dest), dest) + return u if u.width >= s.width else s + + +def _coerce_operands(a, b, op=None): + """Promote *a* and *b* to the common dtype for *op*.""" + dest = _resolve_common_type(a.dtype, b.dtype, op=op) + return (a if a.dtype is dest else a.to(dest), b if b.dtype is dest else b.to(dest), dest) + + +def _result_type_for_op(op, promoted_type): + """Apply op-specific result-type overrides on top of *promoted_type*. + + Shared by scalar Numeric and Vector arithmetic so both resolve identical + result dtypes: + + - comparisons -> ``Boolean`` + - integer true-division -> float (Python ``/`` lifts int/int to float; + width > 32 uses ``Float64``, otherwise ``Float32``) + + All other operators keep the promoted type unchanged. + """ + if op in _CMP_OPS: + return Boolean + if op is operator.truediv and promoted_type.is_integer: + return Float64 if promoted_type.width > 32 else Float32 + return promoted_type def _try_coerce_rhs(rhs): @@ -250,7 +265,7 @@ def _extract_arith(val, signed): return v.with_signedness(signed) if isinstance(v, ArithValue) else v -def _make_binop(op, promote=True, widen_bool=False, swap=False): +def _make_binop(op, swap=False): """Create a binary-operator closure for Numeric subclasses.""" def _apply(lhs, rhs): @@ -258,16 +273,9 @@ def _apply(lhs, rhs): if rhs is None: return NotImplemented - out_type = type(lhs) - if promote: - lhs, rhs, out_type = _coerce_operands(lhs, rhs, widen_bool) - else: - rhs = type(lhs)(rhs) + lhs, rhs, out_type = _coerce_operands(lhs, rhs, op=op) - if op in _CMP_OPS: - out_type = Boolean - elif op is operator.truediv and isinstance(lhs, Integer): - out_type = Float64 if out_type.width > 32 else Float32 + out_type = _result_type_for_op(op, out_type) lv, rv = _extract_arith(lhs, lhs.signed), _extract_arith(rhs, rhs.signed) if swap: @@ -438,40 +446,40 @@ def from_ir_type(ir_type): return ir2dsl_map[ir_type] def __add__(self, other): - return _make_binop(operator.add, widen_bool=True)(self, other) + return _make_binop(operator.add)(self, other) def __sub__(self, other): - return _make_binop(operator.sub, widen_bool=True)(self, other) + return _make_binop(operator.sub)(self, other) def __mul__(self, other): - return _make_binop(operator.mul, widen_bool=True)(self, other) + return _make_binop(operator.mul)(self, other) def __floordiv__(self, other): - return _make_binop(operator.floordiv, widen_bool=True)(self, other) + return _make_binop(operator.floordiv)(self, other) def __truediv__(self, other): - return _make_binop(operator.truediv, widen_bool=True)(self, other) + return _make_binop(operator.truediv)(self, other) def __mod__(self, other): - return _make_binop(operator.mod, widen_bool=True)(self, other) + return _make_binop(operator.mod)(self, other) def __radd__(self, other): return self.__add__(other) def __rsub__(self, other): - return _make_binop(operator.sub, widen_bool=True, swap=True)(self, other) + return _make_binop(operator.sub, swap=True)(self, other) def __rmul__(self, other): return self.__mul__(other) def __rfloordiv__(self, other): - return _make_binop(operator.floordiv, widen_bool=True, swap=True)(self, other) + return _make_binop(operator.floordiv, swap=True)(self, other) def __rtruediv__(self, other): - return _make_binop(operator.truediv, widen_bool=True, swap=True)(self, other) + return _make_binop(operator.truediv, swap=True)(self, other) def __rmod__(self, other): - return _make_binop(operator.mod, widen_bool=True, swap=True)(self, other) + return _make_binop(operator.mod, swap=True)(self, other) def __pow__(self, other): return _make_binop(operator.pow)(self, other) @@ -483,13 +491,15 @@ def __ne__(self, other): return _make_binop(operator.ne)(self, other) # ── Proxy methods: delegate ArithValue-specific ops via ir_value() ── - def maximumf(self, other): + def maximumf(self, other, *, fastmath=None): """Float maximum — delegates to ArithValue.maximumf.""" - return type(self)(self.ir_value().maximumf(_to_raw(other))) + return type(self)(self.ir_value().maximumf(_to_raw(other), fastmath=fastmath)) + + def minimumf(self, other, *, fastmath=None): + """Float minimum — delegates to expr.arith.minimumf.""" + from .arith import minimumf as _minimumf - def minimumf(self, other): - """Float minimum — delegates to ArithValue.minimumf.""" - return type(self)(self.ir_value().minimumf(_to_raw(other))) + return _minimumf(self, other, fastmath=fastmath) def exp2(self, *, fastmath=None): """Base-2 exponential — delegates to ArithValue.exp2.""" @@ -785,61 +795,6 @@ class Float4E2M1FN(Float, metaclass=NumericMeta, width=4, ir_type=T.f4E2M1FN): . # Float type rank for promotion (must be after class definitions) _FLOAT_RANK = {Float64: 3, Float32: 2, Float16: 1, BFloat16: 1} -# ── Type promotion (added to Numeric after all subclasses exist) ────── - -_FLOAT_BY_MIN_WIDTH = {16: Float16, 32: Float32, 64: Float64} - - -def _widen_float(float_type, min_width): - """Return the narrowest standard float type with width >= *min_width*.""" - if float_type.width >= min_width: - return float_type - for w in (32, 64): - if w >= min_width: - return _FLOAT_BY_MIN_WIDTH[w] - return Float64 - - -@classmethod -def _promote(cls, a_type, b_type): - """Resolve the promoted result type for two Numeric types. - - :param a_type: Left Numeric class (e.g. Float16) - :param b_type: Right Numeric class (e.g. Float32) - :return: The common Numeric class both can be safely promoted to - """ - if a_type is b_type: - return a_type - - a_float = a_type.is_float - b_float = b_type.is_float - - if a_float and not b_float: - return _widen_float(a_type, b_type.width) - if b_float and not a_float: - return _widen_float(b_type, a_type.width) - - if a_float and b_float: - aw, bw = a_type.width, b_type.width - if aw > bw and aw >= 16: - return a_type - if bw > aw and bw >= 16: - return b_type - if aw == bw: - ra = _FLOAT_RANK.get(a_type, 0) - rb = _FLOAT_RANK.get(b_type, 0) - return a_type if ra >= rb else b_type - raise ValueError(f"cannot promote {a_type} and {b_type}; cast explicitly") - - # Both integers - if a_type.signed == b_type.signed: - return a_type if a_type.width >= b_type.width else b_type - u, s = (a_type, b_type) if not a_type.signed else (b_type, a_type) - return u if u.width >= s.width else s - - -Numeric.promote = _promote - class Index(Integer, metaclass=NumericMeta, width=64, signed=False, ir_type=lambda: ir.IndexType.get()): """DSL Numeric for MLIR index type. Replaces arith.index(N). diff --git a/python/flydsl/expr/typing.py b/python/flydsl/expr/typing.py index 7b0c87ce7..15fba5408 100644 --- a/python/flydsl/expr/typing.py +++ b/python/flydsl/expr/typing.py @@ -45,6 +45,9 @@ Uint32, Uint64, Uint128, + _coerce_operands, + _result_type_for_op, + _try_coerce_rhs, as_numeric, ) from .primitive import * @@ -1593,28 +1596,30 @@ def ir_value(self): def with_signedness(self, signed): return ArithValue(self, signed) - def _wrap_op_result(self, result, shape): + def _wrap_op_result(self, result, shape, result_dtype): if isinstance(result, ir.Value) and isinstance(result.type, ir.VectorType): - return Vector(result, shape, self._result_dtype(result.type.element_type)) + return Vector(result, shape, result_dtype) if isinstance(result, Numeric): return result if isinstance(result, ir.Value): - return self._result_dtype(result.type)(result) + return result_dtype(result) return result - def _result_dtype(self, elem_type) -> Type[Numeric]: - if self._dtype.ir_type == elem_type: - return self._dtype - return Numeric.from_ir_type(elem_type) - def _apply_op(self, method_name, op, other, flip=False): lhs = self rhs = other shape = self.shape + promoted_dtype = self.dtype if isinstance(other, Vector): shape = self._infer_broadcast_shape(self.shape, other.shape) lhs = self.broadcast_to(shape) rhs = other.broadcast_to(shape) + else: + rhs_numeric = _try_coerce_rhs(rhs) + if rhs_numeric is None: + return NotImplemented + rhs = rhs_numeric + lhs, rhs, promoted_dtype = _coerce_operands(lhs, rhs, op=op) method = getattr(ArithValue, method_name) if flip: if isinstance(rhs, Vector): @@ -1624,7 +1629,8 @@ def _apply_op(self, method_name, op, other, flip=False): result = getattr(ArithValue, reverse_name)(lhs, rhs) else: result = method(lhs, rhs) - return self._wrap_op_result(result, shape) + result_dtype = _result_type_for_op(op, promoted_dtype) + return self._wrap_op_result(result, shape, result_dtype) def apply_op(self, op, other, flip=False): method_name = _VECTOR_OP_METHODS.get(op) diff --git a/tests/unit/test_numeric_promotion.py b/tests/unit/test_numeric_promotion.py index c4aa7a68a..e7079206f 100644 --- a/tests/unit/test_numeric_promotion.py +++ b/tests/unit/test_numeric_promotion.py @@ -12,6 +12,8 @@ mixed-sign mixed-width). """ +import operator + import pytest import flydsl.expr as fx @@ -21,29 +23,35 @@ def _binop(lhs_ty, rhs_ty, op): - """Build two block-arg values of the requested DSL types and apply `op`. - Returns the resulting Numeric. We use block args so the operands are - genuinely dynamic ir.Values (not Python literals), which is the path - most kernel code hits. - """ with Context() as ctx: ctx.allow_unregistered_dialects = True with Location.unknown(ctx): module = Module.create() from flydsl._mlir.dialects import func - from flydsl._mlir.ir import FunctionType + from flydsl._mlir.ir import FunctionType, VectorType + + def _vec(t): + return VectorType.get([4], t.ir_type) with InsertionPoint(module.body): - f = func.FuncOp("k", FunctionType.get([lhs_ty.ir_type, rhs_ty.ir_type], [])) + ftype = FunctionType.get([lhs_ty.ir_type, rhs_ty.ir_type, _vec(lhs_ty), _vec(rhs_ty)], []) + f = func.FuncOp("k", ftype) entry = f.add_entry_block() with InsertionPoint(entry): a = lhs_ty(entry.arguments[0]) b = rhs_ty(entry.arguments[1]) - result = op(a, b) + va = fx.Vector(entry.arguments[2], 4, lhs_ty) + vb = fx.Vector(entry.arguments[3], 4, rhs_ty) + scalar = op(a, b) + vector = op(va, vb) func.ReturnOp([]) assert module.operation.verify() - return result + assert vector.dtype is scalar.dtype, ( + f"vector/scalar dtype drift for {lhs_ty.__name__} {op.__name__} {rhs_ty.__name__}: " + f"vector -> {vector.dtype.__name__}, scalar -> {scalar.dtype.__name__}" + ) + return scalar # Same-sign / same-width: must stay narrow (no auto-int32 promotion). @@ -148,22 +156,7 @@ def test_float_wider_wins(a, b, expected): # Boolean arithmetic: bool + bool → Int32 (matches C++ "bool participates as int"). def test_bool_plus_bool_widens_to_int32(): - with Context() as ctx, Location.unknown(ctx): - ctx.allow_unregistered_dialects = True - module = Module.create() - from flydsl._mlir.dialects import func - from flydsl._mlir.ir import FunctionType - - with InsertionPoint(module.body): - f = func.FuncOp("k", FunctionType.get([fx.Boolean.ir_type, fx.Boolean.ir_type], [])) - entry = f.add_entry_block() - with InsertionPoint(entry): - a = fx.Boolean(entry.arguments[0]) - b = fx.Boolean(entry.arguments[1]) - r = a + b - func.ReturnOp([]) - assert module.operation.verify() - assert r.dtype is fx.Int32 + assert _binop(fx.Boolean, fx.Boolean, operator.add).dtype is fx.Int32 # True division on integers: Python `/` lifts int/int to float. @@ -184,3 +177,51 @@ def test_truediv_int_lifts_to_float(ty, expected): @pytest.mark.parametrize("ty", [fx.Int8, fx.Int32, fx.Int64, fx.Uint32, fx.Int128]) def test_floordiv_int_stays_int(ty): assert _binop(ty, ty, lambda x, y: x // y).dtype is ty + + +# --------------------------------------------------------------------------- +# Broader operator coverage. Every case runs through `_binop`, which builds a +# scalar pair and a vector pair and asserts they promote identically — so these +# double as the Vector/Numeric result-type consistency checks. +# --------------------------------------------------------------------------- + +# Representative pairs: same-sign, mixed-sign, and cross-kind mixing. +_MIXED_PAIRS = [ + (fx.Int8, fx.Int16), + (fx.Uint8, fx.Uint16), + (fx.Int32, fx.Uint32), + (fx.Int16, fx.Uint8), + (fx.Int64, fx.Uint32), + (fx.Int32, fx.Float32), + (fx.Float16, fx.Float32), +] + +_MIXED_INT_PAIRS = [ + (fx.Int8, fx.Int16), + (fx.Uint8, fx.Uint16), + (fx.Uint16, fx.Uint64), + (fx.Int32, fx.Uint32), + (fx.Int16, fx.Uint8), + (fx.Int64, fx.Uint32), +] + + +# Subtraction / multiplication promote exactly like addition (no override). +@pytest.mark.parametrize("op", [operator.sub, operator.mul]) +@pytest.mark.parametrize("a,b", _MIXED_PAIRS) +def test_sub_mul_promote_like_add(op, a, b): + assert _binop(a, b, op).dtype is _binop(a, b, operator.add).dtype + + +# Integer mod / bitwise keep the usual integer promotion (no override). +@pytest.mark.parametrize("op", [operator.mod, operator.and_, operator.or_, operator.xor]) +@pytest.mark.parametrize("a,b", _MIXED_INT_PAIRS) +def test_int_mod_bitwise_promote_usual(op, a, b): + assert _binop(a, b, op).dtype is _binop(a, b, operator.add).dtype + + +# Comparisons yield Boolean regardless of operand types. +@pytest.mark.parametrize("op", [operator.lt, operator.le, operator.gt, operator.ge, operator.eq, operator.ne]) +@pytest.mark.parametrize("a,b", _MIXED_PAIRS) +def test_comparison_yields_boolean(op, a, b): + assert _binop(a, b, op).dtype is fx.Boolean diff --git a/tests/unit/test_vector.py b/tests/unit/test_vector.py index 436d4b55c..73c492b5d 100644 --- a/tests/unit/test_vector.py +++ b/tests/unit/test_vector.py @@ -19,10 +19,10 @@ Boolean, Float16, Float32, - Float64, Int8, Int16, Int32, + Int64, Numeric, Uint32, ) @@ -315,73 +315,45 @@ def build(a, b): class TestTypePromotion: - - def test_same_type(self): - assert Numeric.promote(Float32, Float32) is Float32 - - def test_f16_f32(self): - assert Numeric.promote(Float16, Float32) is Float32 - - def test_bf16_f32(self): - assert Numeric.promote(BFloat16, Float32) is Float32 - - def test_int_float(self): - """Int32 + Float32 → Float32.""" - assert Numeric.promote(Int32, Float32) is Float32 - - def test_int_wider_than_float(self): - """Float16 + Int32 → Float32 (int width 32 > float width 16).""" - assert Numeric.promote(Float16, Int32) is Float32 - - def test_int_same_width_as_float(self): - """Float32 + Int32 → Float32 (same width, float wins).""" - assert Numeric.promote(Float32, Int32) is Float32 - - def test_int_narrower_than_float(self): - """Float32 + Int16 → Float32 (int is narrower).""" - assert Numeric.promote(Float32, Int16) is Float32 - - def test_int64_with_float32(self): - """Float32 + Int64 → Float64 (int width 64 > float width 32).""" - from flydsl.expr.numeric import Int64 - - assert Numeric.promote(Float32, Int64) is Float64 - - def test_f16_f64(self): - assert Numeric.promote(Float16, Float64) is Float64 - def test_promote_in_operator(self): - """Mixed-type vector ops require explicit .to() conversion (no auto-promote).""" + """Mixed-type vector ops auto-promote to a common dtype.""" def build(a, b): ta = Vector(a, 8, Float16) tb = Vector(b, 8, Float32) - ta_f32 = ta.to(Float32) - result = ta_f32 + tb + result = ta + tb assert result.dtype is Float32 ir_text = _build_module(build, [_vec_f16, _vec_f32]) assert "arith.extf" in ir_text assert "arith.addf" in ir_text - def test_mixed_signed_unsigned_int(self): - """Int32 + Uint32 → Uint32 (unsigned wins at same width).""" - assert Numeric.promote(Int32, Uint32) is Uint32 - assert Numeric.promote(Uint32, Int32) is Uint32 - def test_promote_bf16_scalar(self): - """BFloat16 tensor + scalar → explicit .to() needed for mixed-type ops.""" + """BFloat16 tensor + scalar auto-promotes like scalar Numeric arithmetic.""" def build(a): ta = Vector(a, 8, BFloat16) - ta_f32 = ta.to(Float32) - result = ta_f32 + 1.0 + result = ta + 1.0 assert result.dtype is Float32 ir_text = _build_module(build, [_vec_bf16]) assert "arith.extf" in ir_text assert "arith.addf" in ir_text + def test_promote_int32x4_int64x4(self): + """Regression for Int32x4 + Int64x4 producing mixed arith.addi operands.""" + + def build(): + lhs = Vector.from_elements([Int32(1), Int32(2), Int32(3), Int32(4)]) + rhs = Vector.from_elements([Int64(5), Int64(6), Int64(7), Int64(8)]) + result = lhs + rhs + assert result.dtype is Int64 + + ir_text = _build_module(build, []) + assert "arith.extsi" in ir_text + assert "arith.addi" in ir_text + assert "vector<4xi64>" in ir_text + # =========================================================================== # D. Type conversion (.to())