Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions include/flydsl/Dialect/Fly/Utils/IntTupleUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -846,16 +846,16 @@ namespace detail {

template <class IntTuple>
std::pair<IntTuple, IntTuple> intTupleZip2ByImpl(const IntTupleBuilder<IntTuple> &builder,
IntTuple t, IntTupleAttr guide) {
IntTuple t, IntTupleAttr guide, int noneValue) {
using Collector = typename IntTupleBuilder<IntTuple>::ElemCollector;
if (guide.isLeaf()) {
assert(t.rank() == 2 && "intTupleZip2By expects rank-2 tuple at terminal");
return {builder.at(t, 0), builder.at(t, 1)};
}
// Canonicalize singleton guide wrappers so 1D profiles behave as leaf guides.
// This keeps zip2By robust after singleton unwrapping in product/divide type canonicalization.
if (guide.rank() == 1) {
return intTupleZip2ByImpl(builder, t, guide.at(0));
if (guide.rank() == 1 && t.rank() == 2) {
return intTupleZip2ByImpl(builder, t, guide.at(0), noneValue);
}
Collector firsts;
Collector seconds;
Expand All @@ -864,7 +864,14 @@ std::pair<IntTuple, IntTuple> intTupleZip2ByImpl(const IntTupleBuilder<IntTuple>
int32_t tRank = t.rank();
assert(tRank >= guideRank && "Mismatched ranks in intTupleZip2By");
for (int i = 0; i < guideRank; ++i) {
auto [first, second] = intTupleZip2ByImpl(builder, builder.at(t, i), guide.at(i));
if (guide.at(i).isLeafNone()) {
// i'th guide is None, implies view i'th mode s:d as (1,s):(0,d) for zip
// here first is either 1 or 0 depending on whether it's shape or stride
firsts.push_back(builder.materializeConstantLeaf(noneValue));
seconds.push_back(builder.at(t, i));
continue;
}
auto [first, second] = intTupleZip2ByImpl(builder, builder.at(t, i), guide.at(i), noneValue);
firsts.push_back(first);
seconds.push_back(second);
}
Expand All @@ -877,13 +884,14 @@ std::pair<IntTuple, IntTuple> intTupleZip2ByImpl(const IntTupleBuilder<IntTuple>
} // namespace detail

template <class IntTuple>
IntTuple intTupleZip2By(const IntTupleBuilder<IntTuple> &builder, IntTuple t, IntTupleAttr guide) {
IntTuple intTupleZip2By(const IntTupleBuilder<IntTuple> &builder, IntTuple t, IntTupleAttr guide,
int noneValue = 0) {
if (guide.isLeaf()) {
assert(t.rank() == 2 && "intTupleZip2By expects rank-2 tuple at terminal");
return t;
} else {
using Collector = typename IntTupleBuilder<IntTuple>::ElemCollector;
auto [first, second] = detail::intTupleZip2ByImpl(builder, t, guide);
auto [first, second] = detail::intTupleZip2ByImpl(builder, t, guide, noneValue);
Collector collector;
collector.push_back(first);
collector.push_back(second);
Expand Down
17 changes: 11 additions & 6 deletions include/flydsl/Dialect/Fly/Utils/LayoutUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -568,10 +568,11 @@ typename LayoutBuilder<Layout>::IntTuple layoutCoshape(LayoutBuilder<Layout> &bu
IntTuple stride = builder.getStride(layout);
IntTuple one = builder.materializeConstantLeaf(1);

IntTuple m1Shapes =
intTupleTransformLeaf(builder, [&](IntTuple s) { return builder.sub(s, one); }, shape);
IntTuple m1Shapes = intTupleTransformLeaf(
builder, [&](IntTuple s) { return builder.sub(s, one); }, shape);
IntTuple coCoord = intTupleInnerProduct(builder, m1Shapes, stride);
return intTupleTransformLeaf(builder, [&](IntTuple c) { return builder.add(c, one); }, coCoord);
return intTupleTransformLeaf(
builder, [&](IntTuple c) { return builder.add(c, one); }, coCoord);
}

template <class Layout>
Expand Down Expand Up @@ -1365,11 +1366,15 @@ Layout layoutZippedDivide(LayoutBuilder<Layout> &builder, Layout layout, TileAtt

SmallVector<Attribute> guideElems;
for (int i = 0; i < divisorTile.rank(); ++i) {
guideElems.push_back(IntTupleAttr::getLeafNone(ctx));
if (!divisorTile.isNoneMode(i)) {
guideElems.push_back(IntTupleAttr::getLeafStatic(ctx, 1));
} else {
guideElems.push_back(IntTupleAttr::getLeafNone(ctx));
}
}
IntTupleAttr guide = IntTupleAttr::get(ArrayAttr::get(ctx, guideElems));
IntTuple retShape = intTupleZip2By(builder, builder.getShape(logicalDiv), guide);
IntTuple retStride = intTupleZip2By(builder, builder.getStride(logicalDiv), guide);
IntTuple retShape = intTupleZip2By(builder, builder.getShape(logicalDiv), guide, 1);
IntTuple retStride = intTupleZip2By(builder, builder.getStride(logicalDiv), guide, 0);
return builder.makeLayout(retShape, retStride);
}

Expand Down
23 changes: 15 additions & 8 deletions python/flydsl/expr/primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,28 +910,28 @@ def left_inverse(layout):
@dsl_loc_tracing
def logical_divide(layout, divisor):
if not isinstance(divisor, ir.Value):
divisor = make_tile(*divisor)
divisor = _make_tile_or_layout(divisor)
return fly.logical_divide(layout, divisor)


@dsl_loc_tracing
def zipped_divide(layout, divisor):
if not isinstance(divisor, ir.Value):
divisor = make_tile(*divisor)
divisor = _make_tile_or_layout(divisor)
return fly.zipped_divide(layout, divisor)


@dsl_loc_tracing
def tiled_divide(layout, divisor):
if not isinstance(divisor, ir.Value):
divisor = make_tile(*divisor)
divisor = _make_tile_or_layout(divisor)
return fly.tiled_divide(layout, divisor)


@dsl_loc_tracing
def flat_divide(layout, divisor):
if not isinstance(divisor, ir.Value):
divisor = make_tile(*divisor)
divisor = _make_tile_or_layout(divisor)
return fly.flat_divide(layout, divisor)


Expand Down Expand Up @@ -1366,8 +1366,15 @@ def _resolve(m):
raise ValueError(f"make_tile: expected int, None, tuple, or Layout, got {type(m)}")

resolved = [_resolve(m) for m in args]
if len(resolved) == 1:
tile_type = TileType.get(resolved[0])
else:
tile_type = TileType.get(resolved)
tile_type = TileType.get(resolved)
return static(tile_type)


def _make_tile_or_layout(args):
if isinstance(args, ir.Value):
return args
if isinstance(args, int):
return make_layout(args, 1)
if isinstance(args, tuple):
return make_tile(*args)
raise ValueError(f"make_tile_or_layout: expected int, tuple, or Layout, got {type(args)}")
Comment thread
tingqli marked this conversation as resolved.
Outdated
46 changes: 46 additions & 0 deletions tests/unit/test_layout_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,52 @@ def build():
build()


def test_divide_bymode(frontend_only_jit):

@flyc.jit
def build():
a = fx.make_layout((64, 50, 80), (16000, 160, 1))
assert str(fx.logical_divide(a, 32)) == str(fx.make_layout((32, (2, 50, 80)), (16000, (512000, 160, 1))))
assert str(fx.zipped_divide(a, 32)) == str(fx.make_layout((32, (2, 50, 80)), (16000, (512000, 160, 1))))
assert str(fx.tiled_divide(a, 32)) == str(fx.make_layout((32, 2, 50, 80), (16000, 512000, 160, 1)))
assert str(fx.flat_divide(a, 32)) == str(fx.make_layout((32, 2, 50, 80), (16000, 512000, 160, 1)))

assert str(fx.logical_divide(a, (32,))) == str(fx.make_layout(((32, 2), 50, 80), ((16000, 512000), 160, 1)))
assert str(fx.zipped_divide(a, (32,))) == str(
fx.make_layout(((32,), (2, 50, 80)), ((16000,), (512000, 160, 1)))
)
assert str(fx.tiled_divide(a, (32,))) == str(fx.make_layout(((32,), 2, 50, 80), ((16000,), 512000, 160, 1)))
assert str(fx.flat_divide(a, (32,))) == str(fx.make_layout((32, 2, 50, 80), (16000, 512000, 160, 1)))

assert str(fx.logical_divide(a, (32, None, None))) == str(
fx.make_layout(((32, 2), 50, 80), ((16000, 512000), 160, 1))
)
assert str(fx.zipped_divide(a, (32, None, None))) == str(
fx.make_layout(((32, 1, 1), (2, 50, 80)), ((16000, 0, 0), (512000, 160, 1)))
)
assert str(fx.tiled_divide(a, (32, None, None))) == str(
fx.make_layout(((32, 1, 1), 2, 50, 80), ((16000, 0, 0), 512000, 160, 1))
)
assert str(fx.flat_divide(a, (32, None, None))) == str(
fx.make_layout((32, 1, 1, 2, 50, 80), (16000, 0, 0, 512000, 160, 1))
)

assert str(fx.logical_divide(a, (32, None, 40))) == str(
fx.make_layout(((32, 2), 50, (40, 2)), ((16000, 512000), 160, (1, 40)))
)
assert str(fx.zipped_divide(a, (32, None, 40))) == str(
fx.make_layout(((32, 1, 40), (2, 50, 2)), ((16000, 0, 1), (512000, 160, 40)))
)
assert str(fx.tiled_divide(a, (32, None, 40))) == str(
fx.make_layout(((32, 1, 40), 2, 50, 2), ((16000, 0, 1), 512000, 160, 40))
)
assert str(fx.flat_divide(a, (32, None, 40))) == str(
fx.make_layout((32, 1, 40, 2, 50, 2), (16000, 0, 1, 512000, 160, 40))
)

build()


# ==============================================================================
# 6. Product Operations (Cells 25, 27, 29)
# ==============================================================================
Expand Down
Loading