diff --git a/src/IRMatch.h b/src/IRMatch.h index 848172d435d1..d7d9b1281362 100644 --- a/src/IRMatch.h +++ b/src/IRMatch.h @@ -83,10 +83,6 @@ struct MatcherState { const BaseExprNode *bindings[max_wild]; halide_scalar_value_t bound_const[max_wild]; - // values of the lanes field with special meaning. - static constexpr uint16_t signed_integer_overflow = 0x8000; - static constexpr uint16_t special_values_mask = 0x8000; // currently only one - halide_type_t bound_const_type[max_wild]; HALIDE_ALWAYS_INLINE @@ -146,23 +142,9 @@ struct bindings { constexpr static uint32_t mask = std::remove_reference::type::binds; }; -inline HALIDE_NEVER_INLINE Expr make_const_special_expr(halide_type_t ty) { - const uint16_t flags = ty.lanes & MatcherState::special_values_mask; - ty.lanes &= ~MatcherState::special_values_mask; - if (flags & MatcherState::signed_integer_overflow) { - return make_signed_integer_overflow(ty); - } - // unreachable - return Expr(); -} - HALIDE_ALWAYS_INLINE Expr make_const_expr(halide_scalar_value_t val, halide_type_t ty) { halide_type_t scalar_type = ty; - if (scalar_type.lanes & MatcherState::special_values_mask) { - return make_const_special_expr(scalar_type); - } - const int lanes = scalar_type.lanes; scalar_type.lanes = 1; @@ -273,9 +255,9 @@ struct WildConstInt { constexpr static bool foldable = true; - HALIDE_ALWAYS_INLINE - void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const { + [[nodiscard]] HALIDE_ALWAYS_INLINE bool make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const { state.get_bound_const(i, val, ty); + return false; } }; @@ -326,9 +308,9 @@ struct WildConstUInt { constexpr static bool foldable = true; - HALIDE_ALWAYS_INLINE - void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { + [[nodiscard]] HALIDE_ALWAYS_INLINE bool make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { state.get_bound_const(i, val, ty); + return false; } }; @@ -379,9 +361,9 @@ struct WildConstFloat { constexpr static bool foldable = true; - HALIDE_ALWAYS_INLINE - void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { + [[nodiscard]] HALIDE_ALWAYS_INLINE bool make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { state.get_bound_const(i, val, ty); + return false; } }; @@ -437,9 +419,9 @@ struct WildConst { constexpr static bool foldable = true; - HALIDE_ALWAYS_INLINE - void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { + [[nodiscard]] HALIDE_ALWAYS_INLINE bool make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { state.get_bound_const(i, val, ty); + return false; } }; @@ -540,8 +522,7 @@ struct IntLiteral { constexpr static bool foldable = true; - HALIDE_ALWAYS_INLINE - void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { + [[nodiscard]] HALIDE_ALWAYS_INLINE bool make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { // Assume type is already correct switch (ty.code) { case halide_type_int: @@ -558,6 +539,7 @@ struct IntLiteral { // Unreachable ; } + return false; } }; @@ -609,13 +591,13 @@ inline std::ostream &operator<<(std::ostream &s, const IntLiteral &op) { } template -int64_t constant_fold_bin_op(halide_type_t &, int64_t, int64_t) noexcept; +int64_t constant_fold_bin_op(halide_type_t &, int64_t, int64_t, bool &) noexcept; template -uint64_t constant_fold_bin_op(halide_type_t &, uint64_t, uint64_t) noexcept; +uint64_t constant_fold_bin_op(halide_type_t &, uint64_t, uint64_t, bool &) noexcept; template -double constant_fold_bin_op(halide_type_t &, double, double) noexcept; +double constant_fold_bin_op(halide_type_t &, double, double, bool &) noexcept; constexpr bool commutative(IRNodeType t) { return (t == IRNodeType::Add || @@ -665,47 +647,44 @@ struct BinOp { constexpr static bool foldable = A::foldable && B::foldable; - HALIDE_ALWAYS_INLINE - void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { + [[nodiscard]] HALIDE_ALWAYS_INLINE bool make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { + bool overflow = false; halide_scalar_value_t val_a, val_b; if (std::is_same_v) { - b.make_folded_const(val_b, ty, state); + overflow |= b.make_folded_const(val_b, ty, state); if ((std::is_same_v && val_b.u.u64 == 0) || (std::is_same_v && val_b.u.u64 == 1)) { // Short circuit val = val_b; - return; + return overflow; } - const uint16_t l = ty.lanes; - a.make_folded_const(val_a, ty, state); - ty.lanes |= l; // Make sure the overflow bits are sticky + overflow |= a.make_folded_const(val_a, ty, state); } else { - a.make_folded_const(val_a, ty, state); + overflow |= a.make_folded_const(val_a, ty, state); if ((std::is_same_v && val_a.u.u64 == 0) || (std::is_same_v && val_a.u.u64 == 1)) { // Short circuit val = val_a; - return; + return overflow; } - const uint16_t l = ty.lanes; - b.make_folded_const(val_b, ty, state); - ty.lanes |= l; + overflow |= b.make_folded_const(val_b, ty, state); } switch (ty.code) { case halide_type_int: - val.u.i64 = constant_fold_bin_op(ty, val_a.u.i64, val_b.u.i64); + val.u.i64 = constant_fold_bin_op(ty, val_a.u.i64, val_b.u.i64, overflow); break; case halide_type_uint: - val.u.u64 = constant_fold_bin_op(ty, val_a.u.u64, val_b.u.u64); + val.u.u64 = constant_fold_bin_op(ty, val_a.u.u64, val_b.u.u64, overflow); break; case halide_type_float: case halide_type_bfloat: - val.u.f64 = constant_fold_bin_op(ty, val_a.u.f64, val_b.u.f64); + val.u.f64 = constant_fold_bin_op(ty, val_a.u.f64, val_b.u.f64, overflow); break; default: // unreachable ; } + return overflow; } HALIDE_ALWAYS_INLINE @@ -767,20 +746,16 @@ struct CmpOp { constexpr static bool foldable = A::foldable && B::foldable; - HALIDE_ALWAYS_INLINE - void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { + [[nodiscard]] HALIDE_ALWAYS_INLINE bool make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { + bool overflow = false; halide_scalar_value_t val_a, val_b; // If one side is an untyped const, evaluate the other side first to get a type hint. if (std::is_same_v) { - b.make_folded_const(val_b, ty, state); - const uint16_t l = ty.lanes; - a.make_folded_const(val_a, ty, state); - ty.lanes |= l; + overflow |= b.make_folded_const(val_b, ty, state); + overflow |= a.make_folded_const(val_a, ty, state); } else { - a.make_folded_const(val_a, ty, state); - const uint16_t l = ty.lanes; - b.make_folded_const(val_b, ty, state); - ty.lanes |= l; + overflow |= a.make_folded_const(val_a, ty, state); + overflow |= b.make_folded_const(val_b, ty, state); } switch (ty.code) { @@ -800,6 +775,7 @@ struct CmpOp { } ty.code = halide_type_uint; ty.bits = 1; + return overflow; } HALIDE_ALWAYS_INLINE @@ -922,21 +898,21 @@ HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, } template<> -HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op(halide_type_t &t, int64_t a, int64_t b) noexcept { - t.lanes |= ((t.bits >= 32) && add_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0; +HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op(halide_type_t &t, int64_t a, int64_t b, bool &overflow) noexcept { + overflow |= (t.bits >= 32) && add_would_overflow(t.bits, a, b); int dead_bits = 64 - t.bits; // Drop the high bits then sign-extend them back return int64_t((uint64_t(a) + uint64_t(b)) << dead_bits) >> dead_bits; } template<> -HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op(halide_type_t &t, uint64_t a, uint64_t b) noexcept { +HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op(halide_type_t &t, uint64_t a, uint64_t b, bool &overflow) noexcept { uint64_t ones = (uint64_t)(-1); return (a + b) & (ones >> (64 - t.bits)); } template<> -HALIDE_ALWAYS_INLINE double constant_fold_bin_op(halide_type_t &t, double a, double b) noexcept { +HALIDE_ALWAYS_INLINE double constant_fold_bin_op(halide_type_t &t, double a, double b, bool &overflow) noexcept { return a + b; } @@ -955,21 +931,21 @@ HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, } template<> -HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op(halide_type_t &t, int64_t a, int64_t b) noexcept { - t.lanes |= ((t.bits >= 32) && sub_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0; +HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op(halide_type_t &t, int64_t a, int64_t b, bool &overflow) noexcept { + overflow |= (t.bits >= 32) && sub_would_overflow(t.bits, a, b); // Drop the high bits then sign-extend them back int dead_bits = 64 - t.bits; return int64_t((uint64_t(a) - uint64_t(b)) << dead_bits) >> dead_bits; } template<> -HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op(halide_type_t &t, uint64_t a, uint64_t b) noexcept { +HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op(halide_type_t &t, uint64_t a, uint64_t b, bool &overflow) noexcept { uint64_t ones = (uint64_t)(-1); return (a - b) & (ones >> (64 - t.bits)); } template<> -HALIDE_ALWAYS_INLINE double constant_fold_bin_op(halide_type_t &t, double a, double b) noexcept { +HALIDE_ALWAYS_INLINE double constant_fold_bin_op(halide_type_t &t, double a, double b, bool &overflow) noexcept { return a - b; } @@ -988,21 +964,21 @@ HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, } template<> -HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op(halide_type_t &t, int64_t a, int64_t b) noexcept { - t.lanes |= ((t.bits >= 32) && mul_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0; +HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op(halide_type_t &t, int64_t a, int64_t b, bool &overflow) noexcept { + overflow |= (t.bits >= 32) && mul_would_overflow(t.bits, a, b); int dead_bits = 64 - t.bits; // Drop the high bits then sign-extend them back return int64_t((uint64_t(a) * uint64_t(b)) << dead_bits) >> dead_bits; } template<> -HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op(halide_type_t &t, uint64_t a, uint64_t b) noexcept { +HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op(halide_type_t &t, uint64_t a, uint64_t b, bool &overflow) noexcept { uint64_t ones = (uint64_t)(-1); return (a * b) & (ones >> (64 - t.bits)); } template<> -HALIDE_ALWAYS_INLINE double constant_fold_bin_op(halide_type_t &t, double a, double b) noexcept { +HALIDE_ALWAYS_INLINE double constant_fold_bin_op(halide_type_t &t, double a, double b, bool &overflow) noexcept { return a * b; } @@ -1019,17 +995,17 @@ HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, } template<> -HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op
(halide_type_t &t, int64_t a, int64_t b) noexcept { +HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op
(halide_type_t &t, int64_t a, int64_t b, bool &overflow) noexcept { return div_imp(a, b); } template<> -HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op
(halide_type_t &t, uint64_t a, uint64_t b) noexcept { +HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op
(halide_type_t &t, uint64_t a, uint64_t b, bool &overflow) noexcept { return div_imp(a, b); } template<> -HALIDE_ALWAYS_INLINE double constant_fold_bin_op
(halide_type_t &t, double a, double b) noexcept { +HALIDE_ALWAYS_INLINE double constant_fold_bin_op
(halide_type_t &t, double a, double b, bool &overflow) noexcept { return div_imp(a, b); } @@ -1048,17 +1024,17 @@ HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, } template<> -HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op(halide_type_t &t, int64_t a, int64_t b) noexcept { +HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op(halide_type_t &t, int64_t a, int64_t b, bool &overflow) noexcept { return mod_imp(a, b); } template<> -HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op(halide_type_t &t, uint64_t a, uint64_t b) noexcept { +HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op(halide_type_t &t, uint64_t a, uint64_t b, bool &overflow) noexcept { return mod_imp(a, b); } template<> -HALIDE_ALWAYS_INLINE double constant_fold_bin_op(halide_type_t &t, double a, double b) noexcept { +HALIDE_ALWAYS_INLINE double constant_fold_bin_op(halide_type_t &t, double a, double b, bool &overflow) noexcept { return mod_imp(a, b); } @@ -1070,17 +1046,17 @@ HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp -HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op(halide_type_t &t, int64_t a, int64_t b) noexcept { +HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op(halide_type_t &t, int64_t a, int64_t b, bool &overflow) noexcept { return std::min(a, b); } template<> -HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op(halide_type_t &t, uint64_t a, uint64_t b) noexcept { +HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op(halide_type_t &t, uint64_t a, uint64_t b, bool &overflow) noexcept { return std::min(a, b); } template<> -HALIDE_ALWAYS_INLINE double constant_fold_bin_op(halide_type_t &t, double a, double b) noexcept { +HALIDE_ALWAYS_INLINE double constant_fold_bin_op(halide_type_t &t, double a, double b, bool &overflow) noexcept { return std::min(a, b); } @@ -1092,17 +1068,17 @@ HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp -HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op(halide_type_t &t, int64_t a, int64_t b) noexcept { +HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op(halide_type_t &t, int64_t a, int64_t b, bool &overflow) noexcept { return std::max(a, b); } template<> -HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op(halide_type_t &t, uint64_t a, uint64_t b) noexcept { +HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op(halide_type_t &t, uint64_t a, uint64_t b, bool &overflow) noexcept { return std::max(a, b); } template<> -HALIDE_ALWAYS_INLINE double constant_fold_bin_op(halide_type_t &t, double a, double b) noexcept { +HALIDE_ALWAYS_INLINE double constant_fold_bin_op(halide_type_t &t, double a, double b, bool &overflow) noexcept { return std::max(a, b); } @@ -1267,17 +1243,17 @@ HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||( } template<> -HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op(halide_type_t &t, int64_t a, int64_t b) noexcept { +HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op(halide_type_t &t, int64_t a, int64_t b, bool &overflow) noexcept { return (a | b) & 1; } template<> -HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op(halide_type_t &t, uint64_t a, uint64_t b) noexcept { +HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op(halide_type_t &t, uint64_t a, uint64_t b, bool &overflow) noexcept { return (a | b) & 1; } template<> -HALIDE_ALWAYS_INLINE double constant_fold_bin_op(halide_type_t &t, double a, double b) noexcept { +HALIDE_ALWAYS_INLINE double constant_fold_bin_op(halide_type_t &t, double a, double b, bool &overflow) noexcept { // Unreachable, as it would be a type mismatch. return 0; } @@ -1293,17 +1269,17 @@ HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&& } template<> -HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op(halide_type_t &t, int64_t a, int64_t b) noexcept { +HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op(halide_type_t &t, int64_t a, int64_t b, bool &overflow) noexcept { return a & b & 1; } template<> -HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op(halide_type_t &t, uint64_t a, uint64_t b) noexcept { +HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op(halide_type_t &t, uint64_t a, uint64_t b, bool &overflow) noexcept { return a & b & 1; } template<> -HALIDE_ALWAYS_INLINE double constant_fold_bin_op(halide_type_t &t, double a, double b) noexcept { +HALIDE_ALWAYS_INLINE double constant_fold_bin_op(halide_type_t &t, double a, double b, bool &overflow) noexcept { // Unreachable return 0; } @@ -1462,18 +1438,19 @@ struct Intrin { constexpr static bool foldable = true; - HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { + [[nodiscard]] HALIDE_ALWAYS_INLINE bool make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { + bool overflow = false; halide_scalar_value_t arg1; // Assuming the args have the same type as the intrinsic is incorrect in // general. But for the intrinsics we can fold (just shifts), the LHS // has the same type as the intrinsic, and we can always treat the RHS // as a signed int, because we're using 64 bits for it. - std::get<0>(args).make_folded_const(val, ty, state); + overflow |= std::get<0>(args).make_folded_const(val, ty, state); halide_type_t signed_ty = ty; signed_ty.code = halide_type_int; // We can just directly get the second arg here, because we only want to // instantiate this method for shifts, which have two args. - std::get<1>(args).make_folded_const(arg1, signed_ty, state); + overflow |= std::get<1>(args).make_folded_const(arg1, signed_ty, state); if (intrin == Call::shift_left) { if (arg1.u.i64 < 0) { @@ -1502,6 +1479,7 @@ struct Intrin { } else { internal_error << "Folding not implemented for intrinsic: " << intrin; } + return overflow; } HALIDE_ALWAYS_INLINE @@ -1647,10 +1625,11 @@ struct NotOp { constexpr static bool foldable = A::foldable; template - HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { - a.make_folded_const(val, ty, state); + [[nodiscard]] HALIDE_ALWAYS_INLINE bool make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { + bool overflow = a.make_folded_const(val, ty, state); val.u.u64 = ~val.u.u64; val.u.u64 &= 1; + return overflow; } }; @@ -1771,16 +1750,17 @@ struct SelectOp { constexpr static bool foldable = C::foldable && T::foldable && F::foldable; template - HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { + [[nodiscard]] HALIDE_ALWAYS_INLINE bool make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { + bool overflow = false; halide_scalar_value_t c_val, t_val, f_val; halide_type_t c_ty; - c.make_folded_const(c_val, c_ty, state); + overflow |= c.make_folded_const(c_val, c_ty, state); if ((c_val.u.u64 & 1) == 1) { - t.make_folded_const(val, ty, state); + overflow |= t.make_folded_const(val, ty, state); } else { - f.make_folded_const(val, ty, state); + overflow |= f.make_folded_const(val, ty, state); } - ty.lanes |= c_ty.lanes & MatcherState::special_values_mask; + return overflow; } }; @@ -1833,7 +1813,9 @@ struct BroadcastOp { Expr make(MatcherState &state, halide_type_t type_hint) const { halide_scalar_value_t lanes_val; halide_type_t ty; - lanes.make_folded_const(lanes_val, ty, state); + bool overflow = false; + overflow |= lanes.make_folded_const(lanes_val, ty, state); + internal_assert(!overflow) << "Overflow occurred computing the lanes field of a Broadcast node"; int32_t l = (int32_t)lanes_val.u.i64; type_hint.lanes /= l; Expr val = a.make(state, type_hint); @@ -1847,13 +1829,15 @@ struct BroadcastOp { constexpr static bool foldable = false; template - HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { + [[nodiscard]] HALIDE_ALWAYS_INLINE bool make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { + bool overflow = false; halide_scalar_value_t lanes_val; halide_type_t lanes_ty; - lanes.make_folded_const(lanes_val, lanes_ty, state); + overflow |= lanes.make_folded_const(lanes_val, lanes_ty, state); uint16_t l = (uint16_t)lanes_val.u.i64; - a.make_folded_const(val, ty, state); - ty.lanes = l | (ty.lanes & MatcherState::special_values_mask); + overflow |= a.make_folded_const(val, ty, state); + ty.lanes = l; + return overflow; } }; @@ -1909,7 +1893,10 @@ struct RampOp { Expr make(MatcherState &state, halide_type_t type_hint) const { halide_scalar_value_t lanes_val; halide_type_t ty; - lanes.make_folded_const(lanes_val, ty, state); + bool overflow = false; + overflow |= lanes.make_folded_const(lanes_val, ty, state); + internal_assert(!overflow) + << "Overflow occurred computing the lanes field of a Ramp node in a rewriter rule."; int32_t l = (int32_t)lanes_val.u.i64; type_hint.lanes /= l; Expr ea, eb; @@ -1971,7 +1958,10 @@ struct VectorReduceOp { Expr make(MatcherState &state, halide_type_t type_hint) const { halide_scalar_value_t lanes_val; halide_type_t ty; - lanes.make_folded_const(lanes_val, ty, state); + bool overflow = false; + overflow |= lanes.make_folded_const(lanes_val, ty, state); + internal_assert(!overflow) + << "Overflow occurred computing the lanes of a VectorReduce node in a rewriter rule."; int l = (int)lanes_val.u.i64; return VectorReduce::make(reduce_op, a.make(state, type_hint), l); } @@ -2065,14 +2055,14 @@ struct NegateOp { constexpr static bool foldable = A::foldable; template - HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { - a.make_folded_const(val, ty, state); + [[nodiscard]] HALIDE_ALWAYS_INLINE bool make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { + bool overflow = a.make_folded_const(val, ty, state); int dead_bits = 64 - ty.bits; switch (ty.code) { case halide_type_int: if (ty.bits >= 32 && val.u.u64 && (val.u.u64 << (65 - ty.bits)) == 0) { // Trying to negate the most negative signed int for a no-overflow type. - ty.lanes |= MatcherState::signed_integer_overflow; + overflow = true; } else { // Negate, drop the high bits, and then sign-extend them back val.u.i64 = int64_t(uint64_t(-val.u.i64) << dead_bits) >> dead_bits; @@ -2089,6 +2079,7 @@ struct NegateOp { // unreachable ; } + return overflow; } }; @@ -2235,12 +2226,16 @@ struct SliceOp { Expr make(MatcherState &state, halide_type_t type_hint) const { halide_scalar_value_t base_val, stride_val, lanes_val; halide_type_t ty; - base.make_folded_const(base_val, ty, state); + bool overflow = false; + overflow |= base.make_folded_const(base_val, ty, state); int b = (int)base_val.u.i64; - stride.make_folded_const(stride_val, ty, state); + overflow |= stride.make_folded_const(stride_val, ty, state); int s = (int)stride_val.u.i64; - lanes.make_folded_const(lanes_val, ty, state); + overflow |= lanes.make_folded_const(lanes_val, ty, state); int l = (int)lanes_val.u.i64; + internal_assert(!overflow) + << "Overflow occurred computing the parameters of a slice operation in a rewriter rule."; + return Shuffle::make_slice(vec.make(state, type_hint), b, s, l); } @@ -2295,7 +2290,10 @@ struct TransposeOp { Expr make(MatcherState &state, halide_type_t type_hint) const { halide_scalar_value_t factor_val; halide_type_t ty; - factor.make_folded_const(factor_val, ty, state); + bool overflow = false; + overflow |= factor.make_folded_const(factor_val, ty, state); + internal_assert(!overflow) + << "Overflow occurred computing the parameters of a transpose operation in a rewriter rule."; int f = (int)factor_val.u.i64; return Shuffle::make_transpose(vec.make(state, type_hint), f); } @@ -2336,7 +2334,11 @@ struct Fold { Expr make(MatcherState &state, halide_type_t type_hint) const noexcept { halide_scalar_value_t c; halide_type_t ty = type_hint; - a.make_folded_const(c, ty, state); + bool overflow = false; + overflow |= a.make_folded_const(c, ty, state); + if (overflow) { + return make_signed_integer_overflow(ty); + } // The result of the fold may have an underspecified type // (e.g. because it's from an int literal). Make the type code @@ -2358,8 +2360,8 @@ struct Fold { constexpr static bool foldable = A::foldable; template - HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { - a.make_folded_const(val, ty, state); + [[nodiscard]] HALIDE_ALWAYS_INLINE bool make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { + return a.make_folded_const(val, ty, state); } }; @@ -2391,12 +2393,16 @@ struct Overflows { constexpr static bool foldable = A::foldable; template - HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { - a.make_folded_const(val, ty, state); + [[nodiscard]] HALIDE_ALWAYS_INLINE bool make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { + // Run the inner fold and consume its overflow flag (the whole + // point of this predicate is to ask about it without letting it + // taint the surrounding evaluation). + const bool inner_overflow = a.make_folded_const(val, ty, state); + val.u.u64 = inner_overflow ? 1 : 0; ty.code = halide_type_uint; ty.bits = 64; - val.u.u64 = (ty.lanes & MatcherState::special_values_mask) != 0; ty.lanes = 1; + return false; } }; @@ -2433,16 +2439,14 @@ struct Overflow { HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const { - type_hint.lanes |= MatcherState::signed_integer_overflow; - return make_const_special_expr(type_hint); + return make_signed_integer_overflow(type_hint); } constexpr static bool foldable = true; - HALIDE_ALWAYS_INLINE - void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { + [[nodiscard]] HALIDE_ALWAYS_INLINE bool make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { val.u.u64 = 0; - ty.lanes |= MatcherState::signed_integer_overflow; + return true; } }; @@ -2469,7 +2473,7 @@ struct IsConst { constexpr static bool foldable = true; template - HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { + [[nodiscard]] HALIDE_ALWAYS_INLINE bool make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { Expr e = a.make(state, {}); ty.code = halide_type_uint; ty.bits = 64; @@ -2479,6 +2483,7 @@ struct IsConst { } else { val.u.u64 = ::Halide::Internal::is_const(e) ? 1 : 0; } + return false; } }; @@ -2520,13 +2525,14 @@ struct CanProve { constexpr static bool foldable = true; // Includes a raw call to an inlined make method, so don't inline. - HALIDE_NEVER_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const { + [[nodiscard]] HALIDE_NEVER_INLINE bool make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const { Expr condition = a.make(state, {}); condition = prover->mutate(condition, nullptr); val.u.u64 = is_const_one(condition); ty.code = halide_type_uint; ty.bits = 1; ty.lanes = condition.type().lanes(); + return false; } }; @@ -2556,14 +2562,14 @@ struct IsFloat { constexpr static bool foldable = true; - HALIDE_ALWAYS_INLINE - void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const { + [[nodiscard]] HALIDE_ALWAYS_INLINE bool make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const { // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method. Type t = a.make(state, {}).type(); val.u.u64 = t.is_float(); ty.code = halide_type_uint; ty.bits = 1; ty.lanes = t.lanes(); + return false; } }; @@ -2595,14 +2601,14 @@ struct IsInt { constexpr static bool foldable = true; - HALIDE_ALWAYS_INLINE - void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const { + [[nodiscard]] HALIDE_ALWAYS_INLINE bool make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const { // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method. Type t = a.make(state, {}).type(); val.u.u64 = t.is_int() && (bits == 0 || t.bits() == bits) && (lanes == 0 || t.lanes() == lanes); ty.code = halide_type_uint; ty.bits = 1; ty.lanes = t.lanes(); + return false; } }; @@ -2641,14 +2647,14 @@ struct IsUInt { constexpr static bool foldable = true; - HALIDE_ALWAYS_INLINE - void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const { + [[nodiscard]] HALIDE_ALWAYS_INLINE bool make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const { // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method. Type t = a.make(state, {}).type(); val.u.u64 = t.is_uint() && (bits == 0 || t.bits() == bits) && (lanes == 0 || t.lanes() == lanes); ty.code = halide_type_uint; ty.bits = 1; ty.lanes = t.lanes(); + return false; } }; @@ -2685,14 +2691,14 @@ struct IsScalar { constexpr static bool foldable = true; - HALIDE_ALWAYS_INLINE - void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const { + [[nodiscard]] HALIDE_ALWAYS_INLINE bool make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const { // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method. Type t = a.make(state, {}).type(); val.u.u64 = t.is_scalar(); ty.code = halide_type_uint; ty.bits = 1; ty.lanes = t.lanes(); + return false; } }; @@ -2722,10 +2728,9 @@ struct IsMaxValue { constexpr static bool foldable = true; - HALIDE_ALWAYS_INLINE - void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const { + [[nodiscard]] HALIDE_ALWAYS_INLINE bool make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const { // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method. - a.make_folded_const(val, ty, state); + const bool overflow = a.make_folded_const(val, ty, state); const uint64_t max_bits = (uint64_t)(-1) >> (64 - ty.bits + (ty.code == halide_type_int)); if (ty.code == halide_type_uint || ty.code == halide_type_int) { val.u.u64 = (val.u.u64 == max_bits); @@ -2734,6 +2739,7 @@ struct IsMaxValue { } ty.code = halide_type_uint; ty.bits = 1; + return overflow; } }; @@ -2763,10 +2769,9 @@ struct IsMinValue { constexpr static bool foldable = true; - HALIDE_ALWAYS_INLINE - void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const { + [[nodiscard]] HALIDE_ALWAYS_INLINE bool make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const { // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method. - a.make_folded_const(val, ty, state); + const bool overflow = a.make_folded_const(val, ty, state); if (ty.code == halide_type_int) { const uint64_t min_bits = (uint64_t)(-1) << (ty.bits - 1); val.u.u64 = (val.u.u64 == min_bits); @@ -2777,6 +2782,7 @@ struct IsMinValue { } ty.code = halide_type_uint; ty.bits = 1; + return overflow; } }; @@ -2806,14 +2812,14 @@ struct LanesOf { constexpr static bool foldable = true; - HALIDE_ALWAYS_INLINE - void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const { + [[nodiscard]] HALIDE_ALWAYS_INLINE bool make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const { // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method. Type t = a.make(state, {}).type(); val.u.u64 = t.lanes(); ty.code = halide_type_uint; ty.bits = 32; ty.lanes = 1; + return false; } }; @@ -2858,6 +2864,7 @@ HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicat MatcherState state; Expr exprs[max_wild]; + bool overflow = false; // for the constant_fold_bin_op normalizations below for (int trials = 0; trials < 100; trials++) { // We want to test small constants more frequently than @@ -2870,16 +2877,16 @@ HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicat switch (wildcard_type.code) { case halide_type_uint: { // Normalize to the type's range by adding zero - uint64_t val = constant_fold_bin_op(wildcard_type, (uint64_t)rng() >> shift, 0); + uint64_t val = constant_fold_bin_op(wildcard_type, (uint64_t)rng() >> shift, (uint64_t)0, overflow); state.set_bound_const(i, val, wildcard_type); - val = constant_fold_bin_op(wildcard_type, (uint64_t)rng() >> shift, 0); + val = constant_fold_bin_op(wildcard_type, (uint64_t)rng() >> shift, (uint64_t)0, overflow); exprs[i] = make_const(wildcard_type, val); state.set_binding(i, *exprs[i].get()); } break; case halide_type_int: { - int64_t val = constant_fold_bin_op(wildcard_type, (int64_t)rng() >> shift, 0); + int64_t val = constant_fold_bin_op(wildcard_type, (int64_t)rng() >> shift, (int64_t)0, overflow); state.set_bound_const(i, val, wildcard_type); - val = constant_fold_bin_op(wildcard_type, (int64_t)rng() >> shift, 0); + val = constant_fold_bin_op(wildcard_type, (int64_t)rng() >> shift, (int64_t)0, overflow); exprs[i] = make_const(wildcard_type, val); } break; case halide_type_float: @@ -2903,12 +2910,9 @@ HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicat if (!evaluate_predicate(pred, state)) { continue; } - before.make_folded_const(val_before, type, state); - uint16_t lanes = type.lanes; - after.make_folded_const(val_after, type, state); - lanes |= type.lanes; - - if (lanes & MatcherState::special_values_mask) { + bool fold_overflow = before.make_folded_const(val_before, type, state); + fold_overflow |= after.make_folded_const(val_after, type, state); + if (fold_overflow) { continue; } @@ -2916,12 +2920,12 @@ HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicat switch (output_type.code) { case halide_type_uint: // Compare normalized representations - ok &= (constant_fold_bin_op(output_type, val_before.u.u64, 0) == - constant_fold_bin_op(output_type, val_after.u.u64, 0)); + ok &= (constant_fold_bin_op(output_type, val_before.u.u64, (uint64_t)0, overflow) == + constant_fold_bin_op(output_type, val_after.u.u64, (uint64_t)0, overflow)); break; case halide_type_int: - ok &= (constant_fold_bin_op(output_type, val_before.u.i64, 0) == - constant_fold_bin_op(output_type, val_after.u.i64, 0)); + ok &= (constant_fold_bin_op(output_type, val_before.u.i64, (int64_t)0, overflow) == + constant_fold_bin_op(output_type, val_after.u.i64, (int64_t)0, overflow)); break; case halide_type_float: case halide_type_bfloat: { @@ -2975,9 +2979,9 @@ template(); - p.make_folded_const(c, ty, state); - // Overflow counts as a failed predicate - return (c.u.u64 != 0) && ((ty.lanes & MatcherState::special_values_mask) == 0); + const bool overflow = p.make_folded_const(c, ty, state); + // Overflow counts as a failed predicate. + return (c.u.u64 != 0) && !overflow; } // #defines for testing