Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions include/flydsl/Dialect/Fly/Utils/IntTupleUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -846,25 +846,27 @@ namespace detail {

template <class IntTuple>
std::pair<IntTuple, IntTuple> intTupleZip2ByImpl(const IntTupleBuilder<IntTuple> &builder,
IntTuple t, IntTupleAttr guide) {
IntTuple t, IntTupleAttr guide, int noneValue) {
using Collector = typename IntTupleBuilder<IntTuple>::ElemCollector;
if (guide.isLeaf()) {
assert(t.rank() == 2 && "intTupleZip2By expects rank-2 tuple at terminal");
return {builder.at(t, 0), builder.at(t, 1)};
}
// Canonicalize singleton guide wrappers so 1D profiles behave as leaf guides.
// This keeps zip2By robust after singleton unwrapping in product/divide type canonicalization.
if (guide.rank() == 1) {
return intTupleZip2ByImpl(builder, t, guide.at(0));
}
Collector firsts;
Collector seconds;

int32_t guideRank = guide.rank();
int32_t tRank = t.rank();
assert(tRank >= guideRank && "Mismatched ranks in intTupleZip2By");
for (int i = 0; i < guideRank; ++i) {
auto [first, second] = intTupleZip2ByImpl(builder, builder.at(t, i), guide.at(i));
if (guide.at(i).isLeafNone()) {
// i'th guide is None, implies view i'th mode s:d as (1,s):(0,d) for zip
// here first is either 1 or 0 depending on whether it's shape or stride
firsts.push_back(builder.materializeConstantLeaf(noneValue));
seconds.push_back(builder.at(t, i));
continue;
}
auto [first, second] = intTupleZip2ByImpl(builder, builder.at(t, i), guide.at(i), noneValue);
firsts.push_back(first);
seconds.push_back(second);
}
Expand All @@ -877,13 +879,14 @@ std::pair<IntTuple, IntTuple> intTupleZip2ByImpl(const IntTupleBuilder<IntTuple>
} // namespace detail

template <class IntTuple>
IntTuple intTupleZip2By(const IntTupleBuilder<IntTuple> &builder, IntTuple t, IntTupleAttr guide) {
IntTuple intTupleZip2By(const IntTupleBuilder<IntTuple> &builder, IntTuple t, IntTupleAttr guide,
int noneValue = 0) {
if (guide.isLeaf()) {
assert(t.rank() == 2 && "intTupleZip2By expects rank-2 tuple at terminal");
return t;
} else {
using Collector = typename IntTupleBuilder<IntTuple>::ElemCollector;
auto [first, second] = detail::intTupleZip2ByImpl(builder, t, guide);
auto [first, second] = detail::intTupleZip2ByImpl(builder, t, guide, noneValue);
Collector collector;
collector.push_back(first);
collector.push_back(second);
Expand Down
10 changes: 7 additions & 3 deletions include/flydsl/Dialect/Fly/Utils/LayoutUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1365,11 +1365,15 @@ Layout layoutZippedDivide(LayoutBuilder<Layout> &builder, Layout layout, TileAtt

SmallVector<Attribute> guideElems;
for (int i = 0; i < divisorTile.rank(); ++i) {
guideElems.push_back(IntTupleAttr::getLeafNone(ctx));
if (!divisorTile.isNoneMode(i)) {
guideElems.push_back(IntTupleAttr::getLeafStatic(ctx, 1));
} else {
guideElems.push_back(IntTupleAttr::getLeafNone(ctx));
}
}
IntTupleAttr guide = IntTupleAttr::get(ArrayAttr::get(ctx, guideElems));
IntTuple retShape = intTupleZip2By(builder, builder.getShape(logicalDiv), guide);
IntTuple retStride = intTupleZip2By(builder, builder.getStride(logicalDiv), guide);
IntTuple retShape = intTupleZip2By(builder, builder.getShape(logicalDiv), guide, 1);
IntTuple retStride = intTupleZip2By(builder, builder.getStride(logicalDiv), guide, 0);
return builder.makeLayout(retShape, retStride);
}

Expand Down
23 changes: 15 additions & 8 deletions python/flydsl/expr/primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,28 +910,28 @@ def left_inverse(layout):
@dsl_loc_tracing
def logical_divide(layout, divisor):
if not isinstance(divisor, ir.Value):
divisor = make_tile(*divisor)
divisor = _make_tile_or_layout(divisor)
return fly.logical_divide(layout, divisor)


@dsl_loc_tracing
def zipped_divide(layout, divisor):
if not isinstance(divisor, ir.Value):
divisor = make_tile(*divisor)
divisor = _make_tile_or_layout(divisor)
return fly.zipped_divide(layout, divisor)


@dsl_loc_tracing
def tiled_divide(layout, divisor):
if not isinstance(divisor, ir.Value):
divisor = make_tile(*divisor)
divisor = _make_tile_or_layout(divisor)
return fly.tiled_divide(layout, divisor)


@dsl_loc_tracing
def flat_divide(layout, divisor):
if not isinstance(divisor, ir.Value):
divisor = make_tile(*divisor)
divisor = _make_tile_or_layout(divisor)
return fly.flat_divide(layout, divisor)


Expand Down Expand Up @@ -1366,8 +1366,15 @@ def _resolve(m):
raise ValueError(f"make_tile: expected int, None, tuple, or Layout, got {type(m)}")

resolved = [_resolve(m) for m in args]
if len(resolved) == 1:
tile_type = TileType.get(resolved[0])
else:
tile_type = TileType.get(resolved)
tile_type = TileType.get(resolved)
return static(tile_type)


def _make_tile_or_layout(arg):
if isinstance(arg, ir.Value):
return arg
if isinstance(arg, int):
return make_layout(arg, 1)
if isinstance(arg, (tuple, list)):
return make_tile(*arg)
raise ValueError(f"_make_tile_or_layout: expected int, tuple/list, or Layout, got {type(arg)}")
124 changes: 124 additions & 0 deletions tests/unit/test_layout_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,130 @@ def build():
build()


def generate_pycute_divide_tests():
"""generate pycute and flydsl divide tests for various shapes, strides, and tiles"""
import os

import pycute

import flydsl.compiler as flyc
import flydsl.expr as fx

os.environ.setdefault("FLYDSL_RUNTIME_ENABLE_CACHE", "0")

def is_same(a, b):
return a.shape.to_py_value() == b.shape and a.stride.to_py_value() == b.stride

def test_div(shape, stride, tile):
@flyc.jit
def test():
a0 = pycute.layout.Layout(shape, stride)
a1 = fx.make_layout(shape, stride)
print(f"input & tile: {a0} {tile}")

ref = pycute.logical_divide(a0, tile)
ret = fx.logical_divide(a1, tile)
assert is_same(ret, ref), f"fx.logical_divide(a, {tile}) = {ret}, expected {ref}"
print("\t logical_divide: ", ret, f"fx.make_layout({ref.shape},{ref.stride})")

ref = pycute.zipped_divide(a0, tile)
ret = fx.zipped_divide(a1, tile)
assert is_same(ret, ref), f"fx.zipped_divide(a, {tile}) = {ret}, expected {ref}"
print("\t zipped_divide: ", ret, f"fx.make_layout({ref.shape},{ref.stride})")

ref = pycute.tiled_divide(a0, tile)
ret = fx.tiled_divide(a1, tile)
assert is_same(ret, ref), f"fx.tiled_divide(a, {tile}) = {ret}, expected {ref}"
print("\t tiled_divide: ", ret, f"fx.make_layout({ref.shape},{ref.stride})")

print("\t flat_divide: ", fx.flat_divide(a1, tile))

test()

for tile in [32, (32,)]:
test_div(64, 1, tile)

for tile in [32, (32,), (32, None), (None, 25)]:
test_div((64, 50), (100, 1), tile)

for tile in [32, (32,), (32, None, None), (32, None, 40)]:
test_div((64, 50, 80), (16000, 160, 1), tile)


def test_divide_bymode(frontend_only_jit):
"""these reference answers are generated by generate_pycute_divide_tests()"""

@flyc.jit
def build():
a = fx.make_layout(64, 1)
assert str(fx.logical_divide(a, 32)) == str(fx.make_layout((32, 2), (1, 32)))
assert str(fx.zipped_divide(a, 32)) == str(fx.make_layout((32, 2), (1, 32)))
assert str(fx.tiled_divide(a, 32)) == str(fx.make_layout((32, 2), (1, 32)))
assert str(fx.flat_divide(a, 32)) == str(fx.make_layout((32, 2), (1, 32)))

a = fx.make_layout((64, 50), (100, 1))
assert str(fx.logical_divide(a, 32)) == str(fx.make_layout((32, (2, 50)), (100, (3200, 1))))
assert str(fx.zipped_divide(a, 32)) == str(fx.make_layout((32, (2, 50)), (100, (3200, 1))))
assert str(fx.tiled_divide(a, 32)) == str(fx.make_layout((32, 2, 50), (100, 3200, 1)))
assert str(fx.flat_divide(a, 32)) == str(fx.make_layout((32, 2, 50), (100, 3200, 1)))

assert str(fx.logical_divide(a, (32,))) == str(fx.make_layout(((32, 2), 50), ((100, 3200), 1)))
assert str(fx.zipped_divide(a, (32,))) == str(fx.make_layout(((32,), (2, 50)), ((100,), (3200, 1))))
assert str(fx.tiled_divide(a, (32,))) == str(fx.make_layout(((32,), 2, 50), ((100,), 3200, 1)))
assert str(fx.flat_divide(a, (32,))) == str(fx.make_layout((32, 2, 50), (100, 3200, 1)))

assert str(fx.logical_divide(a, (32, None))) == str(fx.make_layout(((32, 2), 50), ((100, 3200), 1)))
assert str(fx.zipped_divide(a, (32, None))) == str(fx.make_layout(((32, 1), (2, 50)), ((100, 0), (3200, 1))))
assert str(fx.tiled_divide(a, (32, None))) == str(fx.make_layout(((32, 1), 2, 50), ((100, 0), 3200, 1)))
assert str(fx.flat_divide(a, (32, None))) == str(fx.make_layout((32, 1, 2, 50), (100, 0, 3200, 1)))

assert str(fx.logical_divide(a, (None, 25))) == str(fx.make_layout((64, (25, 2)), (100, (1, 25))))
assert str(fx.zipped_divide(a, (None, 25))) == str(fx.make_layout(((1, 25), (64, 2)), ((0, 1), (100, 25))))
assert str(fx.tiled_divide(a, (None, 25))) == str(fx.make_layout(((1, 25), 64, 2), ((0, 1), 100, 25)))
assert str(fx.flat_divide(a, (None, 25))) == str(fx.make_layout((1, 25, 64, 2), (0, 1, 100, 25)))

a = fx.make_layout((64, 50, 80), (16000, 160, 1))
assert str(fx.logical_divide(a, 32)) == str(fx.make_layout((32, (2, 50, 80)), (16000, (512000, 160, 1))))
assert str(fx.zipped_divide(a, 32)) == str(fx.make_layout((32, (2, 50, 80)), (16000, (512000, 160, 1))))
assert str(fx.tiled_divide(a, 32)) == str(fx.make_layout((32, 2, 50, 80), (16000, 512000, 160, 1)))
assert str(fx.flat_divide(a, 32)) == str(fx.make_layout((32, 2, 50, 80), (16000, 512000, 160, 1)))

assert str(fx.logical_divide(a, (32,))) == str(fx.make_layout(((32, 2), 50, 80), ((16000, 512000), 160, 1)))
assert str(fx.zipped_divide(a, (32,))) == str(
fx.make_layout(((32,), (2, 50, 80)), ((16000,), (512000, 160, 1)))
)
assert str(fx.tiled_divide(a, (32,))) == str(fx.make_layout(((32,), 2, 50, 80), ((16000,), 512000, 160, 1)))
assert str(fx.flat_divide(a, (32,))) == str(fx.make_layout((32, 2, 50, 80), (16000, 512000, 160, 1)))

assert str(fx.logical_divide(a, (32, None, None))) == str(
fx.make_layout(((32, 2), 50, 80), ((16000, 512000), 160, 1))
)
assert str(fx.zipped_divide(a, (32, None, None))) == str(
fx.make_layout(((32, 1, 1), (2, 50, 80)), ((16000, 0, 0), (512000, 160, 1)))
)
assert str(fx.tiled_divide(a, (32, None, None))) == str(
fx.make_layout(((32, 1, 1), 2, 50, 80), ((16000, 0, 0), 512000, 160, 1))
)
assert str(fx.flat_divide(a, (32, None, None))) == str(
fx.make_layout((32, 1, 1, 2, 50, 80), (16000, 0, 0, 512000, 160, 1))
)

assert str(fx.logical_divide(a, (32, None, 40))) == str(
fx.make_layout(((32, 2), 50, (40, 2)), ((16000, 512000), 160, (1, 40)))
)
assert str(fx.zipped_divide(a, (32, None, 40))) == str(
fx.make_layout(((32, 1, 40), (2, 50, 2)), ((16000, 0, 1), (512000, 160, 40)))
)
assert str(fx.tiled_divide(a, (32, None, 40))) == str(
fx.make_layout(((32, 1, 40), 2, 50, 2), ((16000, 0, 1), 512000, 160, 40))
)
assert str(fx.flat_divide(a, (32, None, 40))) == str(
fx.make_layout((32, 1, 40, 2, 50, 2), (16000, 0, 1, 512000, 160, 40))
)

build()


# ==============================================================================
# 6. Product Operations (Cells 25, 27, 29)
# ==============================================================================
Expand Down
Loading