diff --git a/include/flydsl/Dialect/Fly/Utils/IntTupleUtils.h b/include/flydsl/Dialect/Fly/Utils/IntTupleUtils.h index 8cd358dfe..1ae22c01d 100644 --- a/include/flydsl/Dialect/Fly/Utils/IntTupleUtils.h +++ b/include/flydsl/Dialect/Fly/Utils/IntTupleUtils.h @@ -846,17 +846,12 @@ namespace detail { template std::pair intTupleZip2ByImpl(const IntTupleBuilder &builder, - IntTuple t, IntTupleAttr guide) { + IntTuple t, IntTupleAttr guide, int noneValue) { using Collector = typename IntTupleBuilder::ElemCollector; if (guide.isLeaf()) { assert(t.rank() == 2 && "intTupleZip2By expects rank-2 tuple at terminal"); return {builder.at(t, 0), builder.at(t, 1)}; } - // Canonicalize singleton guide wrappers so 1D profiles behave as leaf guides. - // This keeps zip2By robust after singleton unwrapping in product/divide type canonicalization. - if (guide.rank() == 1) { - return intTupleZip2ByImpl(builder, t, guide.at(0)); - } Collector firsts; Collector seconds; @@ -864,7 +859,14 @@ std::pair intTupleZip2ByImpl(const IntTupleBuilder int32_t tRank = t.rank(); assert(tRank >= guideRank && "Mismatched ranks in intTupleZip2By"); for (int i = 0; i < guideRank; ++i) { - auto [first, second] = intTupleZip2ByImpl(builder, builder.at(t, i), guide.at(i)); + if (guide.at(i).isLeafNone()) { + // i'th guide is None, implies view i'th mode s:d as (1,s):(0,d) for zip + // here first is either 1 or 0 depending on whether it's shape or stride + firsts.push_back(builder.materializeConstantLeaf(noneValue)); + seconds.push_back(builder.at(t, i)); + continue; + } + auto [first, second] = intTupleZip2ByImpl(builder, builder.at(t, i), guide.at(i), noneValue); firsts.push_back(first); seconds.push_back(second); } @@ -877,13 +879,14 @@ std::pair intTupleZip2ByImpl(const IntTupleBuilder } // namespace detail template -IntTuple intTupleZip2By(const IntTupleBuilder &builder, IntTuple t, IntTupleAttr guide) { +IntTuple intTupleZip2By(const IntTupleBuilder &builder, IntTuple t, IntTupleAttr guide, + int noneValue = 0) { if (guide.isLeaf()) { assert(t.rank() == 2 && "intTupleZip2By expects rank-2 tuple at terminal"); return t; } else { using Collector = typename IntTupleBuilder::ElemCollector; - auto [first, second] = detail::intTupleZip2ByImpl(builder, t, guide); + auto [first, second] = detail::intTupleZip2ByImpl(builder, t, guide, noneValue); Collector collector; collector.push_back(first); collector.push_back(second); diff --git a/include/flydsl/Dialect/Fly/Utils/LayoutUtils.h b/include/flydsl/Dialect/Fly/Utils/LayoutUtils.h index d1148aa82..b04c5d9f5 100644 --- a/include/flydsl/Dialect/Fly/Utils/LayoutUtils.h +++ b/include/flydsl/Dialect/Fly/Utils/LayoutUtils.h @@ -1365,11 +1365,15 @@ Layout layoutZippedDivide(LayoutBuilder &builder, Layout layout, TileAtt SmallVector guideElems; for (int i = 0; i < divisorTile.rank(); ++i) { - guideElems.push_back(IntTupleAttr::getLeafNone(ctx)); + if (!divisorTile.isNoneMode(i)) { + guideElems.push_back(IntTupleAttr::getLeafStatic(ctx, 1)); + } else { + guideElems.push_back(IntTupleAttr::getLeafNone(ctx)); + } } IntTupleAttr guide = IntTupleAttr::get(ArrayAttr::get(ctx, guideElems)); - IntTuple retShape = intTupleZip2By(builder, builder.getShape(logicalDiv), guide); - IntTuple retStride = intTupleZip2By(builder, builder.getStride(logicalDiv), guide); + IntTuple retShape = intTupleZip2By(builder, builder.getShape(logicalDiv), guide, 1); + IntTuple retStride = intTupleZip2By(builder, builder.getStride(logicalDiv), guide, 0); return builder.makeLayout(retShape, retStride); } diff --git a/python/flydsl/expr/primitive.py b/python/flydsl/expr/primitive.py index a25b07c3d..329547971 100644 --- a/python/flydsl/expr/primitive.py +++ b/python/flydsl/expr/primitive.py @@ -910,28 +910,28 @@ def left_inverse(layout): @dsl_loc_tracing def logical_divide(layout, divisor): if not isinstance(divisor, ir.Value): - divisor = make_tile(*divisor) + divisor = _make_tile_or_layout(divisor) return fly.logical_divide(layout, divisor) @dsl_loc_tracing def zipped_divide(layout, divisor): if not isinstance(divisor, ir.Value): - divisor = make_tile(*divisor) + divisor = _make_tile_or_layout(divisor) return fly.zipped_divide(layout, divisor) @dsl_loc_tracing def tiled_divide(layout, divisor): if not isinstance(divisor, ir.Value): - divisor = make_tile(*divisor) + divisor = _make_tile_or_layout(divisor) return fly.tiled_divide(layout, divisor) @dsl_loc_tracing def flat_divide(layout, divisor): if not isinstance(divisor, ir.Value): - divisor = make_tile(*divisor) + divisor = _make_tile_or_layout(divisor) return fly.flat_divide(layout, divisor) @@ -1366,8 +1366,15 @@ def _resolve(m): raise ValueError(f"make_tile: expected int, None, tuple, or Layout, got {type(m)}") resolved = [_resolve(m) for m in args] - if len(resolved) == 1: - tile_type = TileType.get(resolved[0]) - else: - tile_type = TileType.get(resolved) + tile_type = TileType.get(resolved) return static(tile_type) + + +def _make_tile_or_layout(arg): + if isinstance(arg, ir.Value): + return arg + if isinstance(arg, int): + return make_layout(arg, 1) + if isinstance(arg, (tuple, list)): + return make_tile(*arg) + raise ValueError(f"_make_tile_or_layout: expected int, tuple/list, or Layout, got {type(arg)}") diff --git a/tests/unit/test_layout_algebra.py b/tests/unit/test_layout_algebra.py index df4cbbf06..4d13f7bfb 100644 --- a/tests/unit/test_layout_algebra.py +++ b/tests/unit/test_layout_algebra.py @@ -323,6 +323,130 @@ def build(): build() +def generate_pycute_divide_tests(): + """generate pycute and flydsl divide tests for various shapes, strides, and tiles""" + import os + + import pycute + + import flydsl.compiler as flyc + import flydsl.expr as fx + + os.environ.setdefault("FLYDSL_RUNTIME_ENABLE_CACHE", "0") + + def is_same(a, b): + return a.shape.to_py_value() == b.shape and a.stride.to_py_value() == b.stride + + def test_div(shape, stride, tile): + @flyc.jit + def test(): + a0 = pycute.layout.Layout(shape, stride) + a1 = fx.make_layout(shape, stride) + print(f"input & tile: {a0} {tile}") + + ref = pycute.logical_divide(a0, tile) + ret = fx.logical_divide(a1, tile) + assert is_same(ret, ref), f"fx.logical_divide(a, {tile}) = {ret}, expected {ref}" + print("\t logical_divide: ", ret, f"fx.make_layout({ref.shape},{ref.stride})") + + ref = pycute.zipped_divide(a0, tile) + ret = fx.zipped_divide(a1, tile) + assert is_same(ret, ref), f"fx.zipped_divide(a, {tile}) = {ret}, expected {ref}" + print("\t zipped_divide: ", ret, f"fx.make_layout({ref.shape},{ref.stride})") + + ref = pycute.tiled_divide(a0, tile) + ret = fx.tiled_divide(a1, tile) + assert is_same(ret, ref), f"fx.tiled_divide(a, {tile}) = {ret}, expected {ref}" + print("\t tiled_divide: ", ret, f"fx.make_layout({ref.shape},{ref.stride})") + + print("\t flat_divide: ", fx.flat_divide(a1, tile)) + + test() + + for tile in [32, (32,)]: + test_div(64, 1, tile) + + for tile in [32, (32,), (32, None), (None, 25)]: + test_div((64, 50), (100, 1), tile) + + for tile in [32, (32,), (32, None, None), (32, None, 40)]: + test_div((64, 50, 80), (16000, 160, 1), tile) + + +def test_divide_bymode(frontend_only_jit): + """these reference answers are generated by generate_pycute_divide_tests()""" + + @flyc.jit + def build(): + a = fx.make_layout(64, 1) + assert str(fx.logical_divide(a, 32)) == str(fx.make_layout((32, 2), (1, 32))) + assert str(fx.zipped_divide(a, 32)) == str(fx.make_layout((32, 2), (1, 32))) + assert str(fx.tiled_divide(a, 32)) == str(fx.make_layout((32, 2), (1, 32))) + assert str(fx.flat_divide(a, 32)) == str(fx.make_layout((32, 2), (1, 32))) + + a = fx.make_layout((64, 50), (100, 1)) + assert str(fx.logical_divide(a, 32)) == str(fx.make_layout((32, (2, 50)), (100, (3200, 1)))) + assert str(fx.zipped_divide(a, 32)) == str(fx.make_layout((32, (2, 50)), (100, (3200, 1)))) + assert str(fx.tiled_divide(a, 32)) == str(fx.make_layout((32, 2, 50), (100, 3200, 1))) + assert str(fx.flat_divide(a, 32)) == str(fx.make_layout((32, 2, 50), (100, 3200, 1))) + + assert str(fx.logical_divide(a, (32,))) == str(fx.make_layout(((32, 2), 50), ((100, 3200), 1))) + assert str(fx.zipped_divide(a, (32,))) == str(fx.make_layout(((32,), (2, 50)), ((100,), (3200, 1)))) + assert str(fx.tiled_divide(a, (32,))) == str(fx.make_layout(((32,), 2, 50), ((100,), 3200, 1))) + assert str(fx.flat_divide(a, (32,))) == str(fx.make_layout((32, 2, 50), (100, 3200, 1))) + + assert str(fx.logical_divide(a, (32, None))) == str(fx.make_layout(((32, 2), 50), ((100, 3200), 1))) + assert str(fx.zipped_divide(a, (32, None))) == str(fx.make_layout(((32, 1), (2, 50)), ((100, 0), (3200, 1)))) + assert str(fx.tiled_divide(a, (32, None))) == str(fx.make_layout(((32, 1), 2, 50), ((100, 0), 3200, 1))) + assert str(fx.flat_divide(a, (32, None))) == str(fx.make_layout((32, 1, 2, 50), (100, 0, 3200, 1))) + + assert str(fx.logical_divide(a, (None, 25))) == str(fx.make_layout((64, (25, 2)), (100, (1, 25)))) + assert str(fx.zipped_divide(a, (None, 25))) == str(fx.make_layout(((1, 25), (64, 2)), ((0, 1), (100, 25)))) + assert str(fx.tiled_divide(a, (None, 25))) == str(fx.make_layout(((1, 25), 64, 2), ((0, 1), 100, 25))) + assert str(fx.flat_divide(a, (None, 25))) == str(fx.make_layout((1, 25, 64, 2), (0, 1, 100, 25))) + + a = fx.make_layout((64, 50, 80), (16000, 160, 1)) + assert str(fx.logical_divide(a, 32)) == str(fx.make_layout((32, (2, 50, 80)), (16000, (512000, 160, 1)))) + assert str(fx.zipped_divide(a, 32)) == str(fx.make_layout((32, (2, 50, 80)), (16000, (512000, 160, 1)))) + assert str(fx.tiled_divide(a, 32)) == str(fx.make_layout((32, 2, 50, 80), (16000, 512000, 160, 1))) + assert str(fx.flat_divide(a, 32)) == str(fx.make_layout((32, 2, 50, 80), (16000, 512000, 160, 1))) + + assert str(fx.logical_divide(a, (32,))) == str(fx.make_layout(((32, 2), 50, 80), ((16000, 512000), 160, 1))) + assert str(fx.zipped_divide(a, (32,))) == str( + fx.make_layout(((32,), (2, 50, 80)), ((16000,), (512000, 160, 1))) + ) + assert str(fx.tiled_divide(a, (32,))) == str(fx.make_layout(((32,), 2, 50, 80), ((16000,), 512000, 160, 1))) + assert str(fx.flat_divide(a, (32,))) == str(fx.make_layout((32, 2, 50, 80), (16000, 512000, 160, 1))) + + assert str(fx.logical_divide(a, (32, None, None))) == str( + fx.make_layout(((32, 2), 50, 80), ((16000, 512000), 160, 1)) + ) + assert str(fx.zipped_divide(a, (32, None, None))) == str( + fx.make_layout(((32, 1, 1), (2, 50, 80)), ((16000, 0, 0), (512000, 160, 1))) + ) + assert str(fx.tiled_divide(a, (32, None, None))) == str( + fx.make_layout(((32, 1, 1), 2, 50, 80), ((16000, 0, 0), 512000, 160, 1)) + ) + assert str(fx.flat_divide(a, (32, None, None))) == str( + fx.make_layout((32, 1, 1, 2, 50, 80), (16000, 0, 0, 512000, 160, 1)) + ) + + assert str(fx.logical_divide(a, (32, None, 40))) == str( + fx.make_layout(((32, 2), 50, (40, 2)), ((16000, 512000), 160, (1, 40))) + ) + assert str(fx.zipped_divide(a, (32, None, 40))) == str( + fx.make_layout(((32, 1, 40), (2, 50, 2)), ((16000, 0, 1), (512000, 160, 40))) + ) + assert str(fx.tiled_divide(a, (32, None, 40))) == str( + fx.make_layout(((32, 1, 40), 2, 50, 2), ((16000, 0, 1), 512000, 160, 40)) + ) + assert str(fx.flat_divide(a, (32, None, 40))) == str( + fx.make_layout((32, 1, 40, 2, 50, 2), (16000, 0, 1, 512000, 160, 40)) + ) + + build() + + # ============================================================================== # 6. Product Operations (Cells 25, 27, 29) # ==============================================================================