diff --git a/examples/cuda_tile/tile_example.spy b/examples/cuda_tile/tile_example.spy index ff6ab1c..650aeaf 100644 --- a/examples/cuda_tile/tile_example.spy +++ b/examples/cuda_tile/tile_example.spy @@ -96,6 +96,25 @@ def exported() -> None: store_view_tko(added, pviewA, unpack_blockIdx_x(blockidxs), token) + I1 = MLIR_Type("i1") + TileI1 = MLIR_Type("!cuda_tile.tile<{}>", I1) + # equal = MLIR_asm("cuda_tile.cmpi {comparison_predicate=#cuda_tile.comparison_predicate, signedness=#cuda_tile.signedness}", Tile128xI1, (Tile128xI32, Tile128xI32)) + # "cuda_tile.constant"() <{value = dense<[1, 2, 3, 4]> : !cuda_tile.tile<4xi32>}> : () -> !cuda_tile.tile<4xi32> + const_one = MLIR_asm("cuda_tile.constant {value = dense : !cuda_tile.tile}", bool, ()) + + if_op = MLIR_op("cuda_tile.if", bool, (TileI1,)) + + def export_ifelse(a: TilePtrF64) -> None: + tid = iota() + ptrs = offset(broadcast(reshape(a)), tid) + tko = load_ptr_tko(ptrs) # load pointers + v = unpack_0(tko) # MLIR details to get the tile data + + if const_one(): + printfn(v) # print the tile data + + return + _ = exported() diff --git a/nbcc/cutile_backend/backend.py b/nbcc/cutile_backend/backend.py index 2f4f892..f6201da 100644 --- a/nbcc/cutile_backend/backend.py +++ b/nbcc/cutile_backend/backend.py @@ -18,11 +18,9 @@ from cuda_tile._mlir.extras import types as _tile_types import cuda_tile._mlir.ir as ir # Context, Location, Module, Type - from nbcc.developer import TODO from nbcc.mlir_lowering import LowerStates - def entry( sym_name, function_type, @@ -87,7 +85,8 @@ def __init__(self, tu: TranslationUnit): self._i32 = ir.IntegerType.get_signless(32) self._i64 = ir.IntegerType.get_signless(64) self._boolean = ir.IntegerType.get_signless(1) - self._io_type = ir.IntegerType.get_signless(1) + self._io_bitype = ir.IntegerType.get_signless(1) + self._io_type = ir.Type.parse("!cuda_tile.tile") self._none_type = ir.IntegerType.get_signless(1) @classmethod @@ -165,7 +164,7 @@ def create_constant(self, value, type): return op def initialize_io(self): - return self.create_constant(0, self.io_type) + return self.create_constant(0, self._io_bitype) def create_none(self): return self.create_constant(0, self.none_type) @@ -275,13 +274,15 @@ def wrap(self, fqn, args): def _handle_builtins_i32(self, fqn: FQN, args: tuple): return (self.i32,) + @disp.case(by_typename("builtins::bool")) + def _handle_builtins_bool(self, fqn: FQN, args: tuple): + boolty = ir.Type.parse("!cuda_tile.tile") + return (boolty,) + @disp.case(by_typename("types::NoneType")) def _handle_none(self, fqn: FQN, args: tuple): return () - def get_ll_type(self, expr, mdmap) -> ir.Type | None: - """Get backend type for expression with metadata.""" - raise NotImplementedError def handle_builtin_op( self, op_name: str, args, state, lowering_instance=None @@ -312,17 +313,19 @@ def create_constant_boolean(self, value: bool): return self.create_constant(int(value), "i1") # Control flow methods - def create_if_op(self, condition, result_types, has_else=True): + def create_if_op(self, condition, result_types, operands, has_else=True, ): """Create if-else control flow operation - may not be supported in CuTile.""" - raise UnsupportedError( - "SCF if operations not supported in CuTile backend" - ) + from types import SimpleNamespace + ifop = _cuda_tile.IfOp(results_=result_types, condition=condition) + ns = SimpleNamespace() + ns.then_block = ifop.thenRegion.blocks.append() + ns.else_block = ifop.elseRegion.blocks.append() + ns.results = ifop.results_ + return ns def create_yield_op(self, operands): """Create yield operation - may not be supported in CuTile.""" - raise UnsupportedError( - "SCF yield operations not supported in CuTile backend" - ) + return _cuda_tile.yield_(operands) def create_while_op(self, result_types, init_args): """Create while loop operation - may not be supported in CuTile.""" diff --git a/nbcc/egraph/region_port_pruning.py b/nbcc/egraph/region_port_pruning.py new file mode 100644 index 0000000..ff6868e --- /dev/null +++ b/nbcc/egraph/region_port_pruning.py @@ -0,0 +1,545 @@ +# mypy: disable-error-code="empty-body" +"""EGraph-based optimization for eliminating redundant passthrough ports in +IfElse constructs. + +This module implements an optimization that identifies and eliminates +"passthrough ports" in IfElse regions. A passthrough port is an output port +whose value is directly connected to an input operand of the region (no +transformation). When both then/else branches of an IfElse have the same +passthrough ports, the optimization eliminates those ports and replaces their +usage with direct operand references. + +The optimization works by: 1. Analyzing both branches to identify passthrough +mappings (output index -> operand index) 2. Finding common passthroughs between +branches via set intersection 3. Creating pruned port lists that exclude the +common passthroughs 4. Redirecting usage of eliminated ports to reference the +original operands directly + +This reduces the output arity of IfElse constructs and simplifies the IR when +branches have redundant passthrough outputs. +""" +from __future__ import annotations + +from egglog import Expr, Set, function, i64 + + +class PassthroughAnalysisResult(Expr): + """Marker class for tracking passthrough analysis completion.""" + + def __init__(self): ... + + +class PassthroughMapping(Expr): + """Represents a mapping from an output port index to its source operand + index. + + Used to track which output ports are "passthroughs" - ports whose values are + directly forwarded from input operands without transformation. + + Args: + out_idx: Index of the output port + src_idx: Index of the input operand that this port forwards + """ + + def __init__(self, out_idx: i64, src_idx: i64): ... + + +@function +def record_passthrough_analysis( + mappings: Set[PassthroughMapping], +) -> PassthroughAnalysisResult: + """Records the completion of passthrough analysis for debugging purposes.""" + ... + + +def define_rule(): + from egglog import (Bool, Expr, PyObject, Set, String, Unit, Vec, delete, + function, i64, i64Like, method, panic, rewrite, rule, + ruleset, set_, subsume, union) + from sealir.eqsat.rvsdg_eqsat import Port, PortList, Region, Term, TermList + + class _PruningAction(Expr): + """Internal action marker for port pruning operations.""" + + ... + + @function(merge=lambda x, y: x | y) + def collect_passthrough_mappings( + region: Region, ports: Vec[Port], current_index: i64Like + ) -> Set[PassthroughMapping]: + """Recursively collects passthrough mappings for a region's output + ports. + + Analyzes each output port to determine if it directly forwards an input + operand. If so, creates a PassthroughMapping recording the output-to- + operand relationship. + + Args: + region: The region containing the ports + ports: Vector of output ports to analyze + current_index: Current position in the ports vector + + Returns: + Set of PassthroughMapping objects for passthrough ports found + """ + ... + + @function + def initiate_passthrough_elimination( + ifelse: Term, + then_region: Region, + then_ports: PortList, + else_region: Region, + else_ports: PortList, + elimination_mappings: Set[PassthroughMapping], + ) -> _PruningAction: + """Initiates the elimination of common passthrough ports from an IfElse + construct. + + Takes the intersection of passthrough mappings from both branches and + begins the process of pruning those common passthroughs from the output + port lists. + + Args: + ifelse: The IfElse term being optimized + then_region: The then branch region + then_ports: Output ports of the then branch + else_region: The else branch region + else_ports: Output ports of the else branch + elimination_mappings: Common passthrough mappings to eliminate + + Returns: Action marker to trigger port pruning rules + """ + ... + + class EliminationMask(Expr): + """Bitmap-style data structure tracking which output ports should be + eliminated. + + Stores a set of PassthroughMapping objects along with the total port + count, and provides efficient lookup to check if a specific port index + should be eliminated during pruning. + + Args: + mappings: Set of passthrough mappings indicating ports to + eliminate port_count: Total number of output ports in the original + list + """ + + def __init__( + self, mappings: Set[PassthroughMapping], port_count: i64 + ): ... + + @method(merge=lambda a, b: a | b) + def should_eliminate(self, port_index: i64Like) -> Bool: + """Returns True if the port at the given index should be + eliminated.""" + ... + + @function + def create_pruned_port_list( + region: Region, ports: PortList, elimination_mask: EliminationMask + ) -> PortList: + """Creates a new port list with eliminated passthrough ports removed. + + Args: + region: The region containing the ports + ports: Original list of output ports + elimination_mask: Mask indicating which ports to eliminate + Returns: + New PortList with passthrough ports removed + """ + ... + + @function + def _build_pruned_port_list( + region: Region, + ports: Vec[Port], + current_pos: i64, + elimination_mask: EliminationMask, + accumulated_ports: Vec[Port], + ) -> _PruningAction: + """Helper function that recursively builds the pruned port list. + + Iterates through the original ports vector, including only those ports + that are not marked for elimination in the mask. + + Args: + region: The region containing the ports + ports: Original vector of ports + current_pos: Current position being processed + elimination_mask: Mask indicating which ports to skip + accumulated_ports: Growing list of ports to keep + + Returns: + Action marker for rule scheduling + """ + ... + + @function + def redirect_port_usage_to_operands( + ifelse_term: Term, + elimination_mappings: Set[PassthroughMapping], + ) -> _PruningAction: + """Redirects usage of eliminated ports to reference the original + operands. + + For each eliminated passthrough port, replaces references to that port + with direct references to the corresponding input operand. + + Args: + ifelse_term: The IfElse term being optimized + elimination_mappings: Mappings from eliminated port indices to operand + indices + Returns: + Action marker to trigger usage redirection rules + """ + ... + + @function + def initialize_elimination_mask( + elimination_mask: EliminationMask, current_pos: i64 + ) -> _PruningAction: + """Initializes the elimination mask by setting all positions to False + initially. + + This function recursively processes each position in the mask, marking + positions as True only if they correspond to ports that should be + eliminated. + + Args: + elimination_mask: The mask being initialized + current_pos: Current position being processed + Returns: + Action marker for rule scheduling + """ + ... + + @ruleset + def detect_ifelse_passthrough_candidates( + cond_term: Term, + then_term: Term, + else_term: Term, + operand_terms: TermList, + then_region: Region, + then_out_ports: Vec[Port], + else_region: Region, + else_out_ports: Vec[Port], + ): + """ + Detects IfElse constructs and initiates passthrough analysis for both + branches. + + This ruleset identifies IfElse terms and triggers the collection of + passthrough mappings for both the then and else regions by starting the + recursive analysis at position 0 for each branch's output ports. + """ + yield rule( + Term.IfElse( + cond=cond_term, + then=then_term, + orelse=else_term, + operands=operand_terms, + ), + then_term + == Term.RegionEnd( + region=then_region, ports=PortList(ports=then_out_ports) + ), + else_term + == Term.RegionEnd( + region=else_region, ports=PortList(ports=else_out_ports) + ), + ).then( + set_( + collect_passthrough_mappings(then_region, then_out_ports, 0) + ).to(Set[PassthroughMapping].empty()), + set_( + collect_passthrough_mappings(else_region, else_out_ports, 0) + ).to(Set[PassthroughMapping].empty()), + ) + + @ruleset + def analyze_passthrough_mappings( + i: i64, + j: i64, + ports: Vec[Port], + region: Region, + wc_name: String, + mappings: Set[PassthroughMapping], + ): + """ + Analyzes individual ports to identify passthrough mappings and continues iteration. + """ + # 1. Identifies when an output port directly references a region input + yield rule( + mappings == collect_passthrough_mappings(region, ports, i), + ports[i] == Port(name=wc_name, term=region.get(j)), + ).then( + set_(collect_passthrough_mappings(region, ports, i)).to( + {PassthroughMapping(i, j)} + ), + ) + # 2. Advances the analysis to the next port position when current + # position is within bounds + yield rule( + mappings == collect_passthrough_mappings(region, ports, i), + i < ports.length(), + ).then( + set_(collect_passthrough_mappings(region, ports, i + 1)).to( + mappings + ) + ) + + @ruleset + def eliminate_common_passthroughs( + cond_term: Term, + then_term: Term, + else_term: Term, + operand_terms: TermList, + then_region: Region, + then_out_ports: Vec[Port], + else_region: Region, + else_out_ports: Vec[Port], + then_mappings: Set[PassthroughMapping], + else_mappings: Set[PassthroughMapping], + elimination_mappings: Set[PassthroughMapping], + ifelse_term: Term, + then_ports: PortList, + else_ports: PortList, + i: i64, + j: i64, + nelem: i64, + ): + """ + Performs the core elimination of common passthrough ports between branches. + + 3. Redirect port usage to point to original operands instead of eliminated ports + 4. Initialize and populate the elimination mask to track which ports should be removed + """ + # 1. Find the intersection of passthrough mappings from both branches + # and initiate elimination + yield rule( + ifelse_term + == Term.IfElse( + cond=cond_term, + then=then_term, + orelse=else_term, + operands=operand_terms, + ), + then_term + == Term.RegionEnd( + region=then_region, ports=PortList(ports=then_out_ports) + ), + else_term + == Term.RegionEnd( + region=else_region, ports=PortList(ports=else_out_ports) + ), + then_mappings + == collect_passthrough_mappings( + then_region, then_out_ports, then_out_ports.length() + ), + else_mappings + == collect_passthrough_mappings( + else_region, else_out_ports, else_out_ports.length() + ), + ).then( + # The intersection finds common passthrough ports between branches + union(PassthroughAnalysisResult()).with_( + record_passthrough_analysis(then_mappings) + ), # TODO: REMOVE ME + initiate_passthrough_elimination( + ifelse_term, + then_region, + PortList(ports=then_out_ports), + else_region, + PortList(ports=else_out_ports), + then_mappings & else_mappings, # Intersection + ), + ) + # 2. Create pruned port lists for both then and else branches, + # removing common passthrough ports + # 3. Redirect port usage to point to original operands instead of + # eliminated ports + yield rule( + _del1 := initiate_passthrough_elimination( + ifelse_term, + then_region, + then_ports, + else_region, + else_ports, + elimination_mappings, + ), + then_ports == PortList(ports=then_out_ports), + else_ports == PortList(ports=else_out_ports), + ).then( + union(then_ports).with_( + create_pruned_port_list( + then_region, + then_ports, + EliminationMask( + elimination_mappings, then_out_ports.length() + ), + ) + ), + union(else_ports).with_( + create_pruned_port_list( + else_region, + else_ports, + EliminationMask( + elimination_mappings, else_out_ports.length() + ), + ) + ), + redirect_port_usage_to_operands(ifelse_term, elimination_mappings), + delete(_del1), + ) + # 4. Initialize and populate the elimination mask to track which ports + # should be removed + yield rule( + EliminationMask(elimination_mappings, nelem), + elimination_mappings.contains(PassthroughMapping(i, j)), + ).then( + set_( + EliminationMask(elimination_mappings, nelem).should_eliminate( + i + ) + ).to(True) + ) + + yield rule( + EliminationMask(elimination_mappings, nelem), + ).then( + initialize_elimination_mask( + EliminationMask(elimination_mappings, nelem), 0 + ) + ) + yield rule( + initialize_elimination_mask( + EliminationMask(elimination_mappings, nelem), i + ), + i + 1 < nelem, + ).then( + initialize_elimination_mask( + EliminationMask(elimination_mappings, nelem), i + 1 + ) + ) + yield rule( + del1 := initialize_elimination_mask( + EliminationMask(elimination_mappings, nelem), i + ), + i < nelem, + ).then( + set_( + EliminationMask(elimination_mappings, nelem).should_eliminate( + i + ) + ).to(False), + delete(del1), + ) + + @ruleset + def construct_pruned_ports( + ports: Vec[Port], + new_ports: Vec[Port], + region: Region, + i: i64, + elimination_mask: EliminationMask, + ): + """ + Constructs the actual pruned port lists by iterating through original ports. + """ + + # 1. Initiate the pruning process by starting the recursive port list + # construction + yield rule( + create_pruned_port_list(region, PortList(ports), elimination_mask), + ).then( + _build_pruned_port_list( + region, ports, 0, elimination_mask, Vec[Port].empty() + ) + ) + # 2. Include ports that should NOT be eliminated (mask returns False) + yield rule( + del1 := _build_pruned_port_list( + region, ports, i, elimination_mask, new_ports + ), + ports[i], + i < ports.length(), + Bool(False) == (elimination_mask.should_eliminate(i)), + ).then( + _build_pruned_port_list( + region, + ports, + i + 1, + elimination_mask, + new_ports.push(ports[i]), + ), + delete(del1), + ) + # 3. Skip ports that SHOULD be eliminated (mask returns True) + yield rule( + del1 := _build_pruned_port_list( + region, ports, i, elimination_mask, new_ports + ), + ports[i], + i < ports.length(), + Bool(True) == (elimination_mask.should_eliminate(i)), + ).then( + _build_pruned_port_list( + region, ports, i + 1, elimination_mask, new_ports + ), + delete(del1), + ) + # 4. Finalize the pruned port list when all positions have been processed + yield rule( + del1 := create_pruned_port_list( + region, PortList(ports), elimination_mask + ), + del2 := _build_pruned_port_list( + region, ports, ports.length(), elimination_mask, new_ports + ), + ).then( + union(del1).with_(PortList(new_ports)), delete(del1), delete(del2) + ) + + @ruleset + def redirect_eliminated_port_usage( + ifelse: Term, + elimination_mappings: Set[PassthroughMapping], + i: i64, + j: i64, + operands: TermList, + wc_cond: Term, + wc_then: Term, + wc_orelse: Term, + ): + """ + Redirects references to eliminated ports to point directly to the + original operands. + + - For each eliminated port (found in elimination_mappings), replaces + references to that port with direct references to the corresponding + input operand + - This completes the optimization by removing the indirection through + eliminated ports + """ + yield rule( + redirect_port_usage_to_operands(ifelse, elimination_mappings), + ifelse.getPort(i), + elimination_mappings.contains(PassthroughMapping(i, j)), + ifelse + == Term.IfElse( + cond=wc_cond, then=wc_then, orelse=wc_orelse, operands=operands + ), + ).then(union(ifelse.getPort(i)).with_(operands[j])) + + schedule = ( + ( + detect_ifelse_passthrough_candidates | analyze_passthrough_mappings + ).saturate() + + eliminate_common_passthroughs.saturate() + + construct_pruned_ports.saturate() + + redirect_eliminated_port_usage.saturate() + ) + + return schedule diff --git a/nbcc/mlir_backend/backend.py b/nbcc/mlir_backend/backend.py index e84ae92..a8584c5 100644 --- a/nbcc/mlir_backend/backend.py +++ b/nbcc/mlir_backend/backend.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Sequence, cast +from typing import Sequence import mlir.dialects.arith as arith import mlir.dialects.cf as cf @@ -9,13 +9,12 @@ from mlir.dialects import llvm from mlir.dialects.transform.interpreter import apply_named_sequence from mlir.ir import _GlobalDebug -from sealir import ase from sealir.dispatchtable import DispatchTableBuilder, dispatchtable from spy.fqn import FQN from nbcc.developer import TODO from nbcc.mlir_utils import decode_type_name, decode_asm_operation -from nbcc.mlir_lowering import BackendInterface, MDMap, LowerStates +from nbcc.mlir_lowering import BackendInterface, LowerStates from ..frontend import grammar as sg, TranslationUnit from .mlir_passes import PassManager @@ -147,7 +146,7 @@ def create_constant_boolean(self, value: bool): return arith.constant(self.boolean, value) # Control flow methods - def create_if_op(self, condition, result_types, has_else=True): + def create_if_op(self, condition, result_types, operands, has_else=True): from mlir.dialects import scf return scf.IfOp( @@ -648,13 +647,6 @@ def _handle_reduce_max_inner_keepdims(self, mlir_op: str, resty, args): assert bc.verify() return bc - def get_ll_type(self, expr: ase.SExpr, mdmap: MDMap) -> ir.Type: - mds = mdmap.lookup_typeinfo(expr) - if not mds: - return None - [ty] = mds - [llty] = self.lower_type(cast(sg.TypeExpr, ty.type_expr)) - return llty def make_module(self, module_name: str) -> ir.Module: with self.context: diff --git a/nbcc/mlir_lowering.py b/nbcc/mlir_lowering.py index 3d38a70..01d789c 100644 --- a/nbcc/mlir_lowering.py +++ b/nbcc/mlir_lowering.py @@ -141,9 +141,6 @@ def lower_type(self, ty) -> tuple["IRType", ...]: tuple of all component types. """ - @abstractmethod - def get_ll_type(self, expr, mdmap) -> "IRType | None": - """Get backend type for expression with metadata.""" @abstractmethod def handle_builtin_op( @@ -199,7 +196,8 @@ def create_constant_boolean(self, value: bool) -> Any: # Control flow methods @abstractmethod def create_if_op( - self, condition: Any, result_types: list, has_else: bool = True + self, condition: Any, result_types: list, operands: list, + has_else: bool = True, ) -> Any: """Create if-else control flow operation.""" @@ -293,6 +291,8 @@ def lower(self, root: rg.Func) -> Any: Lower RVSDG expressions to MLIR operations, handling control flow and data flow constructs. """ + from sealir.rvsdg import format_rvsdg + # print(format_rvsdg(root)) context = self.be.context self.loc = loc = self.be.Location.name( f"{self.__class__.__name__}.lower()", context=context @@ -498,6 +498,7 @@ def lower_expr(self, expr: SExpr, state: LowerStates): operand_vals.append((yield op)) result_tys: list[Any] = [] + packed_result_tys: list[tuple] = [] # determine result types assert isinstance(body, rg.RegionEnd) @@ -505,33 +506,31 @@ def lower_expr(self, expr: SExpr, state: LowerStates): for left_port, right_port in zip( body.ports, orelse.ports, strict=True ): - left_ty = self.get_port_type(left_port) - right_ty = self.get_port_type(right_port) - if left_ty is None: - ty = right_ty - elif right_ty is None: - ty = left_ty - else: - assert left_ty == right_ty, f"{left_ty} != {right_ty}" - ty = left_ty - result_tys.append(ty) + left_tys = self.get_port_type(left_port) + right_tys = self.get_port_type(right_port) + # Both branches should always return the same types + assert left_tys == right_tys, f"{left_tys} != {right_tys}" + result_tys.extend(left_tys) + packed_result_tys.append(left_tys) # Build the MLIR If-else if_op = self.be.create_if_op( - condition=condval, result_types=result_tys, has_else=True + condition=condval, result_types=result_tys, has_else=True, + operands=self.flatten_result_list(operand_vals), ) with state.push(operand_vals): # Make a detached module to temporarily house the blocks with self.be.InsertionPoint(if_op.then_block): value_body = yield body - self.be.create_yield_op([x for x in value_body]) + self.be.create_yield_op(self.flatten_result_list([x for x in value_body])) with self.be.InsertionPoint(if_op.else_block): value_else = yield orelse - self.be.create_yield_op([x for x in value_else]) + self.be.create_yield_op(self.flatten_result_list([x for x in value_else])) - return if_op.results + # repack + return self.repack_result_list(packed_result_tys, list(if_op.results)) case rg.Loop(body=rg.RegionEnd() as body, operands=operands): # Cast operands to expected type from pattern match @@ -575,18 +574,19 @@ def lower_expr(self, expr: SExpr, state: LowerStates): fqn=sg.FQN() as callee_fqn, io=io_val, args=args_vals ): mdmap = self.mdmap + io_val = (yield io_val) [callee_ti] = mdmap.lookup_typeinfo(callee_fqn) type_expr: sg.TypeExpr = cast(sg.TypeExpr, callee_ti.type_expr) - argtys: list[Any] = [] - for arg in args_vals: - [ti] = mdmap.lookup_typeinfo(arg) - arg_types = self.be.lower_type( - cast(sg.TypeExpr, ti.type_expr) - ) - argtys.extend(arg_types) + # argtys: list[Any] = [] + # for arg in args_vals: + # [ti] = mdmap.lookup_typeinfo(arg) + # arg_types = self.be.lower_type( + # cast(sg.TypeExpr, ti.type_expr) + # ) + # argtys.extend(arg_types) lowered_args = [] for arg in args_vals: @@ -611,7 +611,6 @@ def lower_expr(self, expr: SExpr, state: LowerStates): if owner is not None: assert owner.verify() return [io_val, res] - # self.declare_builtins(c_name, argtys, [resty]) elif callee_fqn_obj.namespace.fullname == "mlir::asm": result_types = list( self.be.lower_type( @@ -670,12 +669,18 @@ def _handle_mlir_asm(self, mlir_op: str, result_types, args): opname, _, attr = mlir_op.partition(" ") return self.be.create_mlir_asm(opname, attr, result_types, args) - def get_port_type(self, port) -> Any: + def get_ll_type(self, expr) -> tuple[Any, ...]: + """Get backend types for expression with metadata.""" + mds = self.mdmap.lookup_typeinfo(expr) + [ty] = mds + lltys = self.be.lower_type(cast(sg.TypeExpr, ty.type_expr)) + return lltys + + def get_port_type(self, port) -> tuple[Any, ...]: if port.name == internal_prefix("io"): - ty = self.be.io_type + return (self.be.io_type,) else: - ty = self.be.get_ll_type(port.value, self.mdmap) - return ty + return self.get_ll_type(port.value) def declare_builtins(self, sym_name, argtypes, restypes): if sym_name in self._declared: @@ -691,3 +696,40 @@ def declare_builtins(self, sym_name, argtypes, restypes): ) ) return ret + + def flatten_result_list(self, values): + flattened = [] + for val in values: + if type(val).__name__ == "OpResultList": + flattened.extend(val) + else: + flattened.append(val) + return flattened + + def repack_result_list(self, types, values): + i = 0 + buf = [] + for ty in types: + group = [] + for j in range(len(ty)): + group.append(values[i + j]) + if len(group) > 1: + buf.append(OpResultList(group)) + else: + buf.extend(group) + i += len(group) + return buf + + +class OpResultList: + """Fake OpResultList""" + def __init__(self, values): + assert len(values) > 1 + self._values = tuple(values) + + def __getitem__(self, idx): + return self._values[idx] + + def __len__(self): + return len(self._values) + diff --git a/nbcc/tests/test_egraph_region_outport_pruning.py b/nbcc/tests/test_egraph_region_outport_pruning.py new file mode 100644 index 0000000..ace2244 --- /dev/null +++ b/nbcc/tests/test_egraph_region_outport_pruning.py @@ -0,0 +1,261 @@ +from __future__ import annotations + +from egglog import eq, expr_parts +from sealir import ase +from sealir.eqsat.py_eqsat import Py_Tuple +from sealir.eqsat.rvsdg_eqsat import Term, termlist +from sealir.rvsdg import format_rvsdg +from sealir.rvsdg import grammar as rg + +from nbcc.compiler import EGraph, egraph_conversion +from nbcc.egraph.region_port_pruning import define_rule + + +def make_example(*, then_ports, else_ports): + def make_ports(g, port_args): + portvalues = [] + for arg in port_args: + if isinstance(arg, int): + portvalues.append(g.write(rg.Unpack(val=rb, idx=arg))) + else: + portvalues.append(arg) + ports = [] + for i, v in enumerate(portvalues): + ports.append(g.write(rg.Port(name=f"p{i}", value=v))) + return tuple(ports) + + with ase.Tape() as tape: + g = rg.Grammar(tape=tape) + predicate = g.write(rg.PyInt(1)) + attrs = g.write(rg.Attrs(())) + + value_a = g.write(rg.PyInt(1)) + value_b = g.write(rg.PyInt(2)) + value_c = g.write(rg.PyInt(3)) + value_d = g.write(rg.PyInt(4)) + constants = [value_a, value_b, value_c, value_d] + then_ports = [a if b is None else b + for a, b, in zip(constants, then_ports, strict=True)] + else_ports = [a if b is None else b + for a, b, in zip(constants, else_ports, strict=True)] + ifelse = g.write( + rg.IfElse( + cond=predicate, + body=g.write( + rg.RegionEnd( + begin=( + rb := g.write( + rg.RegionBegin( + attrs=attrs, inports=("a", "b", "c") + ) + ) + ), + ports=make_ports(g, then_ports), + ), + ), + orelse=g.write( + rg.RegionEnd( + begin=( + rb := g.write( + rg.RegionBegin( + attrs=attrs, inports=("a", "b", "c") + ) + ) + ), + ports=make_ports(g, else_ports), + ) + ), + operands=(value_a, value_b, value_c), + ) + ) + + outs = [ + g.write(rg.Unpack(val=ifelse, idx=i)) + for i in range(len(ifelse.body.ports)) + ] + + root = g.write(rg.PyTuple(elems=tuple(outs))) + return root, ifelse + + +def _extra_rules_to_equate_termlist__getitem__(): + ###### extra + from egglog import Vec, i64, rewrite, ruleset + from sealir.eqsat.rvsdg_eqsat import Term, TermList + + @ruleset + def extras(termlist: TermList, terms: Vec[Term], i: i64): + yield rewrite(termlist[i]).to( + terms[i], + termlist == TermList(terms), + ) + + return extras + + +def _run_output_port_pruning_test(then_ports, else_ports, expected_terms_builder): + """Helper function to run output port pruning tests with given parameters. + + Args: + then_ports: Port configuration for the then branch + else_ports: Port configuration for the else branch + expected_terms_builder: Function that takes ifelse_enode and returns expected terms list + """ + root, ifelse = make_example( + then_ports=then_ports, + else_ports=else_ports, + ) + print(format_rvsdg(root)) + + memo = egraph_conversion(root) + egraph = EGraph() + root_enode = egraph.let("root", memo[root]) + + egraph.run(define_rule() + _extra_rules_to_equate_termlist__getitem__()) + # egraph.display(n_inline_leaves=1) + print(out := egraph.extract(root_enode)) + + ifelse_enode = memo[ifelse].term + + expected_terms = expected_terms_builder(ifelse_enode) + egraph.check( + eq(out).to( + Py_Tuple( + termlist(*expected_terms) + ) + ) + ) + + +def test_output_port_pruning_1(): + def build_expected_terms(ifelse_enode): + return [ + ifelse_enode.getPort(0), + ifelse_enode.getPort(1), + Term.LiteralI64(1), + ifelse_enode.getPort(3), + ] + + _run_output_port_pruning_test( + then_ports=[None, None, 0, 1], + else_ports=[None, None, 0, 2], + expected_terms_builder=build_expected_terms + ) + + +def test_output_port_pruning_2(): + def build_expected_terms(ifelse_enode): + return [ + ifelse_enode.getPort(0), + ifelse_enode.getPort(1), + Term.LiteralI64(1), + Term.LiteralI64(2), + ] + + _run_output_port_pruning_test( + then_ports=[None, None, 0, 1], + else_ports=[None, None, 0, 1], + expected_terms_builder=build_expected_terms + ) + + + +def test_output_port_pruning_3(): + def build_expected_terms(ifelse_enode): + return [ + Term.LiteralI64(3), + ifelse_enode.getPort(1), + Term.LiteralI64(1), + Term.LiteralI64(2), + ] + + _run_output_port_pruning_test( + then_ports=[2, None, 0, 1], + else_ports=[2, None, 0, 1], + expected_terms_builder=build_expected_terms + ) + + +def test_output_port_pruning_4(): + """Test with all constant values in both branches.""" + def build_expected_terms(ifelse_enode): + return [ + ifelse_enode.getPort(0), + ifelse_enode.getPort(1), + ifelse_enode.getPort(2), + ifelse_enode.getPort(3), + ] + + _run_output_port_pruning_test( + then_ports=[None, None, None, None], + else_ports=[None, None, None, None], + expected_terms_builder=build_expected_terms + ) + + +def test_output_port_pruning_5(): + """Test with mixed constants and port references, different in each branch.""" + def build_expected_terms(ifelse_enode): + return [ + Term.LiteralI64(3), + ifelse_enode.getPort(1), + ifelse_enode.getPort(2), + Term.LiteralI64(2), + ] + + _run_output_port_pruning_test( + then_ports=[2, None, None, 1], + else_ports=[2, None, None, 1], + expected_terms_builder=build_expected_terms + ) + + +def test_output_port_pruning_6(): + """Test with all port references, same pattern.""" + def build_expected_terms(ifelse_enode): + return [ + Term.LiteralI64(1), + Term.LiteralI64(2), + Term.LiteralI64(3), + ifelse_enode.getPort(3), + ] + + _run_output_port_pruning_test( + then_ports=[0, 1, 2, None], + else_ports=[0, 1, 2, None], + expected_terms_builder=build_expected_terms + ) + + +def test_output_port_pruning_7(): + """Test with alternating constants and port references.""" + def build_expected_terms(ifelse_enode): + return [ + ifelse_enode.getPort(0), + Term.LiteralI64(1), + ifelse_enode.getPort(2), + Term.LiteralI64(2), + ] + + _run_output_port_pruning_test( + then_ports=[None, 0, None, 1], + else_ports=[None, 0, None, 1], + expected_terms_builder=build_expected_terms + ) + + +def test_output_port_pruning_8(): + """Test with different port orderings between branches.""" + def build_expected_terms(ifelse_enode): + return [ + ifelse_enode.getPort(0), + ifelse_enode.getPort(1), + ifelse_enode.getPort(2), + ifelse_enode.getPort(3), + ] + + _run_output_port_pruning_test( + then_ports=[0, 1, 2, 0], + else_ports=[1, 0, 2, 1], + expected_terms_builder=build_expected_terms + )