Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 59 additions & 104 deletions python/flydsl/expr/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,18 +168,11 @@ def zero(cls):

# The six rich-comparison functions from the ``operator`` module.
_CMP_OPS = frozenset(
    (
        operator.lt,
        operator.le,
        operator.gt,
        operator.ge,
        operator.eq,
        operator.ne,
    )
)


def _widen_bool_to_int32(x, widen_bool=False):
    """Return ``(value, type)`` for *x*, widening a Boolean to Int32 on request.

    Only an operand whose exact type is ``Boolean`` is converted, and only
    when *widen_bool* is true. Any other operand is handed back untouched,
    paired with its own type.

    Integer promotion is deliberately not applied here: narrow integers
    (i8/i16/u8/u16) keep their width, and same-width same-signedness operands
    keep their type. Mixed-width or mixed-sign operands are left for
    ``_coerce_operands`` to reconcile.
    """
    # Guard clause: nothing to do unless widening was requested for a Boolean.
    if not widen_bool or type(x) is not Boolean:
        return x, type(x)
    return x.to(Int32), Int32
# Arithmetic operators under which a Boolean operand is widened to Int32
# (C++-style: bool takes part in arithmetic as an int).
_WIDEN_BOOL_OPS = frozenset(
    (
        operator.add,
        operator.sub,
        operator.mul,
        operator.floordiv,
        operator.truediv,
        operator.mod,
    )
)


def _resolve_float_type(ta, tb):
Expand All @@ -203,27 +196,49 @@ def _resolve_float_type(ta, tb):
raise ValueError(f"no common float type for {ta} and {tb}; cast explicitly")


def _coerce_operands(a, b, widen_bool=False):
"""Promote *a* and *b* to a common scalar type."""
ta, tb = type(a), type(b)
a, ta = _widen_bool_to_int32(a, widen_bool=widen_bool)
b, tb = _widen_bool_to_int32(b, widen_bool=widen_bool)
def _resolve_common_type(ta, tb, op=None):
"""Resolve the common Numeric type used by scalar and vector operators."""
if op in _WIDEN_BOOL_OPS:
ta = Int32 if ta is Boolean else ta
tb = Int32 if tb is Boolean else tb

if ta is tb:
return a, b, ta
return ta

if ta.is_float or tb.is_float:
dest = _resolve_float_type(ta, tb)
return (a if type(a) is dest else a.to(dest), b if type(b) is dest else b.to(dest), dest)
return _resolve_float_type(ta, tb)

# Both integers — pick wider; on tie, prefer unsigned when mixed sign
if ta.signed == tb.signed:
wider = ta if ta.width >= tb.width else tb
return (a if type(a) is wider else a.to(wider), b if type(b) is wider else b.to(wider), wider)
return ta if ta.width >= tb.width else tb

u, s = (ta, tb) if not ta.signed else (tb, ta)
dest = u if u.width >= s.width else s
return (a if type(a) is dest else a.to(dest), b if type(b) is dest else b.to(dest), dest)
return u if u.width >= s.width else s


def _coerce_operands(a, b, op=None):
    """Convert *a* and *b* to the dtype they share under *op*.

    The target dtype comes from ``_resolve_common_type`` applied to the two
    operands' ``dtype`` attributes. An operand already of that dtype is passed
    through unchanged; otherwise it is converted with ``.to(...)``.

    :return: ``(a_converted, b_converted, common_dtype)``
    """
    common = _resolve_common_type(a.dtype, b.dtype, op=op)
    lhs = a if a.dtype is common else a.to(common)
    rhs = b if b.dtype is common else b.to(common)
    return lhs, rhs, common


def _result_type_for_op(op, promoted_type):
    """Map an operator and its promoted operand type to the result type.

    Scalar Numeric and Vector arithmetic both go through this helper so that
    they agree on result dtypes. Two operator families override the promoted
    type:

    - any comparison in ``_CMP_OPS`` yields ``Boolean``;
    - true division of integers yields a float, mirroring Python's ``/``:
      ``Float64`` when the promoted integer is wider than 32 bits,
      ``Float32`` otherwise.

    Every other operator returns *promoted_type* as given.
    """
    if op in _CMP_OPS:
        return Boolean

    int_true_div = op is operator.truediv and promoted_type.is_integer
    if not int_true_div:
        return promoted_type
    return Float32 if promoted_type.width <= 32 else Float64


def _try_coerce_rhs(rhs):
Expand All @@ -250,24 +265,17 @@ def _extract_arith(val, signed):
return v.with_signedness(signed) if isinstance(v, ArithValue) else v


def _make_binop(op, promote=True, widen_bool=False, swap=False):
def _make_binop(op, swap=False):
"""Create a binary-operator closure for Numeric subclasses."""

def _apply(lhs, rhs):
rhs = _try_coerce_rhs(rhs)
if rhs is None:
return NotImplemented

out_type = type(lhs)
if promote:
lhs, rhs, out_type = _coerce_operands(lhs, rhs, widen_bool)
else:
rhs = type(lhs)(rhs)
lhs, rhs, out_type = _coerce_operands(lhs, rhs, op=op)

if op in _CMP_OPS:
out_type = Boolean
elif op is operator.truediv and isinstance(lhs, Integer):
out_type = Float64 if out_type.width > 32 else Float32
out_type = _result_type_for_op(op, out_type)

lv, rv = _extract_arith(lhs, lhs.signed), _extract_arith(rhs, rhs.signed)
if swap:
Expand Down Expand Up @@ -438,40 +446,40 @@ def from_ir_type(ir_type):
return ir2dsl_map[ir_type]

def __add__(self, other):
return _make_binop(operator.add, widen_bool=True)(self, other)
return _make_binop(operator.add)(self, other)

def __sub__(self, other):
return _make_binop(operator.sub, widen_bool=True)(self, other)
return _make_binop(operator.sub)(self, other)

def __mul__(self, other):
return _make_binop(operator.mul, widen_bool=True)(self, other)
return _make_binop(operator.mul)(self, other)

def __floordiv__(self, other):
return _make_binop(operator.floordiv, widen_bool=True)(self, other)
return _make_binop(operator.floordiv)(self, other)

def __truediv__(self, other):
return _make_binop(operator.truediv, widen_bool=True)(self, other)
return _make_binop(operator.truediv)(self, other)

def __mod__(self, other):
return _make_binop(operator.mod, widen_bool=True)(self, other)
return _make_binop(operator.mod)(self, other)

def __radd__(self, other):
return self.__add__(other)

def __rsub__(self, other):
return _make_binop(operator.sub, widen_bool=True, swap=True)(self, other)
return _make_binop(operator.sub, swap=True)(self, other)

def __rmul__(self, other):
return self.__mul__(other)

def __rfloordiv__(self, other):
return _make_binop(operator.floordiv, widen_bool=True, swap=True)(self, other)
return _make_binop(operator.floordiv, swap=True)(self, other)

def __rtruediv__(self, other):
return _make_binop(operator.truediv, widen_bool=True, swap=True)(self, other)
return _make_binop(operator.truediv, swap=True)(self, other)

def __rmod__(self, other):
return _make_binop(operator.mod, widen_bool=True, swap=True)(self, other)
return _make_binop(operator.mod, swap=True)(self, other)

def __pow__(self, other):
return _make_binop(operator.pow)(self, other)
Expand All @@ -483,13 +491,15 @@ def __ne__(self, other):
return _make_binop(operator.ne)(self, other)

# ── Proxy methods: delegate ArithValue-specific ops via ir_value() ──
def maximumf(self, other):
def maximumf(self, other, *, fastmath=None):
"""Float maximum — delegates to ArithValue.maximumf."""
return type(self)(self.ir_value().maximumf(_to_raw(other)))
return type(self)(self.ir_value().maximumf(_to_raw(other), fastmath=fastmath))
Comment thread
sjfeng1999 marked this conversation as resolved.

def minimumf(self, other, *, fastmath=None):
"""Float minimum — delegates to expr.arith.minimumf."""
from .arith import minimumf as _minimumf

def minimumf(self, other):
"""Float minimum — delegates to ArithValue.minimumf."""
return type(self)(self.ir_value().minimumf(_to_raw(other)))
return _minimumf(self, other, fastmath=fastmath)

def exp2(self, *, fastmath=None):
"""Base-2 exponential — delegates to ArithValue.exp2."""
Expand Down Expand Up @@ -785,61 +795,6 @@ class Float4E2M1FN(Float, metaclass=NumericMeta, width=4, ir_type=T.f4E2M1FN): .
# Float type rank for promotion (must be after class definitions)
_FLOAT_RANK = {Float64: 3, Float32: 2, Float16: 1, BFloat16: 1}

# ── Type promotion (added to Numeric after all subclasses exist) ──────

_FLOAT_BY_MIN_WIDTH = {16: Float16, 32: Float32, 64: Float64}


def _widen_float(float_type, min_width):
    """Pick a float type at least *min_width* bits wide, starting from *float_type*.

    *float_type* is returned unchanged when it is already wide enough.
    Otherwise the first of the 32- and 64-bit entries in
    ``_FLOAT_BY_MIN_WIDTH`` that satisfies *min_width* is used. When
    *min_width* exceeds 64 the result falls back to ``Float64``.
    """
    if float_type.width >= min_width:
        return float_type

    # Only 32- and 64-bit candidates are considered when widening.
    chosen = next((w for w in (32, 64) if w >= min_width), None)
    if chosen is None:
        return Float64
    return _FLOAT_BY_MIN_WIDTH[chosen]


@classmethod
def _promote(cls, a_type, b_type):
    """Work out the common Numeric type for a pair of Numeric types.

    :param a_type: Left Numeric class (e.g. Float16)
    :param b_type: Right Numeric class (e.g. Float32)
    :return: The Numeric class chosen as the promotion target for both
    :raises ValueError: for two floats of different width when the wider one
        is narrower than 16 bits
    """
    if a_type is b_type:
        return a_type

    a_is_float, b_is_float = a_type.is_float, b_type.is_float

    if a_is_float and b_is_float:
        width_a, width_b = a_type.width, b_type.width
        if width_a == width_b:
            # Same width: break the tie by rank; unranked types count as 0,
            # and the left operand wins when ranks are equal.
            rank_a = _FLOAT_RANK.get(a_type, 0)
            rank_b = _FLOAT_RANK.get(b_type, 0)
            return a_type if rank_a >= rank_b else b_type
        # Different widths: the wider float wins, provided it has >= 16 bits.
        if max(width_a, width_b) >= 16:
            return a_type if width_a > width_b else b_type
        raise ValueError(f"cannot promote {a_type} and {b_type}; cast explicitly")

    # Exactly one float: keep it, widened to cover the integer's width.
    if a_is_float:
        return _widen_float(a_type, b_type.width)
    if b_is_float:
        return _widen_float(b_type, a_type.width)

    # Both integers.
    if a_type.signed != b_type.signed:
        # Mixed signedness: the unsigned type wins unless the signed one is
        # strictly wider.
        unsigned, signed = (a_type, b_type) if not a_type.signed else (b_type, a_type)
        return unsigned if unsigned.width >= signed.width else signed
    return a_type if a_type.width >= b_type.width else b_type


Numeric.promote = _promote


class Index(Integer, metaclass=NumericMeta, width=64, signed=False, ir_type=lambda: ir.IndexType.get()):
"""DSL Numeric for MLIR index type. Replaces arith.index(N).
Expand Down
24 changes: 15 additions & 9 deletions python/flydsl/expr/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@
Uint32,
Uint64,
Uint128,
_coerce_operands,
_result_type_for_op,
_try_coerce_rhs,
as_numeric,
)
from .primitive import *
Expand Down Expand Up @@ -1593,28 +1596,30 @@ def ir_value(self):
def with_signedness(self, signed):
    """Wrap this value in an ``ArithValue`` built with the given *signed* flag."""
    wrapped = ArithValue(self, signed)
    return wrapped

def _wrap_op_result(self, result, shape):
def _wrap_op_result(self, result, shape, result_dtype):
if isinstance(result, ir.Value) and isinstance(result.type, ir.VectorType):
return Vector(result, shape, self._result_dtype(result.type.element_type))
return Vector(result, shape, result_dtype)
if isinstance(result, Numeric):
return result
if isinstance(result, ir.Value):
return self._result_dtype(result.type)(result)
return result_dtype(result)
return result

def _result_dtype(self, elem_type) -> Type[Numeric]:
    """Return the Numeric class that corresponds to the IR type *elem_type*.

    This object's own ``_dtype`` is preferred when its ``ir_type`` equals
    *elem_type*; otherwise the class is looked up with
    ``Numeric.from_ir_type``.
    """
    own_dtype_matches = self._dtype.ir_type == elem_type
    return self._dtype if own_dtype_matches else Numeric.from_ir_type(elem_type)

def _apply_op(self, method_name, op, other, flip=False):
lhs = self
rhs = other
shape = self.shape
promoted_dtype = self.dtype
if isinstance(other, Vector):
shape = self._infer_broadcast_shape(self.shape, other.shape)
lhs = self.broadcast_to(shape)
rhs = other.broadcast_to(shape)
else:
rhs_numeric = _try_coerce_rhs(rhs)
if rhs_numeric is None:
return NotImplemented
rhs = rhs_numeric
lhs, rhs, promoted_dtype = _coerce_operands(lhs, rhs, op=op)
method = getattr(ArithValue, method_name)
if flip:
if isinstance(rhs, Vector):
Expand All @@ -1624,7 +1629,8 @@ def _apply_op(self, method_name, op, other, flip=False):
result = getattr(ArithValue, reverse_name)(lhs, rhs)
else:
result = method(lhs, rhs)
return self._wrap_op_result(result, shape)
result_dtype = _result_type_for_op(op, promoted_dtype)
return self._wrap_op_result(result, shape, result_dtype)

def apply_op(self, op, other, flip=False):
method_name = _VECTOR_OP_METHODS.get(op)
Expand Down
Loading
Loading