diff --git a/meshmode/array_context.py b/meshmode/array_context.py index 440ae127..09174d5b 100644 --- a/meshmode/array_context.py +++ b/meshmode/array_context.py @@ -47,6 +47,14 @@ _PytestPytatoPyOpenCLArrayContextFactory, register_pytest_array_context_factory, ) +from meshmode.transform_metadata import ( + FaceMassOperatorTag, + MassInverseOperatorTag, + TensorProductDOFAxisTag, + TensorProductMassInverseOperatorTag, + TensorProductOperatorAxisTag, + TensorProductOperatorTag +) from loopy.translation_unit import for_each_kernel from loopy.tools import memoize_on_disk @@ -901,6 +909,8 @@ def fuse_same_discretization_entity_loops(knl): # transforming it. orig_knl = knl + import time + knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationFaceAxisTag, "iface", False, @@ -915,6 +925,7 @@ def fuse_same_discretization_entity_loops(knl): "idof", False, orig_knl) + knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationDimAxisTag, "idim", False, @@ -924,15 +935,29 @@ def fuse_same_discretization_entity_loops(knl): "iface", True, orig_knl) + knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationDOFAxisTag, "idof", True, orig_knl) + knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationDimAxisTag, "idim", True, orig_knl) + knl = _fuse_loops_over_a_discr_entity(knl, + TensorProductDOFAxisTag, + "idof_tp", + False, + orig_knl) + + knl = _fuse_loops_over_a_discr_entity(knl, + TensorProductDOFAxisTag, + "idof_tp", + True, + orig_knl) + return knl @@ -1028,7 +1053,10 @@ def _get_iel_to_idofs(kernel): for iname in kernel.all_inames() if (kernel .inames[iname] - .tags_of_type(DiscretizationDOFAxisTag)) + .tags_of_type(DiscretizationDOFAxisTag) or + kernel + .inames[iname] + .tags_of_type(TensorProductDOFAxisTag)) } iface_inames = {iname for iname in kernel.all_inames() @@ -1058,6 +1086,9 @@ def _get_iel_to_idofs(kernel): raise NotImplementedError(f"The loop {insn.within_inames}" " does not appear as a singly nested" " loop.") + + # {{{ loop (simplicial) + elif ((len(insn.within_inames) == 2) and (len(insn.within_inames & iel_inames) == 1) and (len(insn.within_inames & idof_inames) == 1)): @@ -1068,9 +1099,41 @@ def _get_iel_to_idofs(kernel): for dof_insn in kernel.iname_to_insns()[idof]): pass else: + for dof_insn in kernel.iname_to_insns()[idof]: + if iel not in kernel.id_to_insn[dof_insn].within_inames: + print(f"_get_iel_to_idofs: {str(kernel.id_to_insn[dof_insn])=}") raise NotImplementedError("The loop " f"'{insn.within_inames}' has the idof-loop" " that's not nested within the iel-loop.") + # }}} + + # {{{ loop (tensor product) + + elif ((len(insn.within_inames) > 2) + and (len(insn.within_inames & iel_inames) == 1) + and (len(insn.within_inames & idof_inames) > 1)): + + iel, = insn.within_inames & iel_inames + for idof in insn.within_inames & idof_inames: + iel_to_idofs[iel].add(idof) + + if all((iel in kernel.id_to_insn[dof_insn].within_inames) + for dof_insn in kernel.iname_to_insns()[idof]): + pass + else: + for dof_insn in kernel.iname_to_insns()[idof]: + if iel not in kernel.id_to_insn[dof_insn].within_inames: + print("_get_iel_to_idofs: " + f"{str(kernel.id_to_insn[dof_insn])=}") + raise NotImplementedError("The loop " + f"'{insn.within_inames}' has the " + "idof-loop that's not nested " + "within the iel-loop.") + + # }}} + + # {{{ loop + elif ((len(insn.within_inames) > 2) and (len(insn.within_inames & iel_inames) == 1) and (len(insn.within_inames & idof_inames) == 1) @@ -1086,7 +1149,11 @@ def _get_iel_to_idofs(kernel): else: raise NotImplementedError("Could not fit into " " loop nest pattern.") + + # }}} + else: + print(f"_get_iel_to_idofs: {str(insn)=}") raise NotImplementedError(f"Cannot fit loop nest '{insn.within_inames}'" " into known set of loop-nest patterns.") @@ -1126,8 +1193,10 @@ class EinsumTag(UniqueTag): def _prepare_kernel_for_parallelization(kernel): + from meshmode.transform_metadata import TensorProductDOFAxisTag discr_tag_to_prefix = {DiscretizationElementAxisTag: "iel", DiscretizationDOFAxisTag: "idof", + TensorProductDOFAxisTag: "idof_tp", DiscretizationDimAxisTag: "idim", DiscretizationAmbientDimAxisTag: "idim", DiscretizationTopologicalDimAxisTag: "idim", @@ -1252,11 +1321,13 @@ def _combine_einsum_domains(knl): from pytools.persistent_dict import WriteOncePersistentDict -from pytato.analysis import PytatoKeyBuilder +from pytato.analysis import PytatoKeyBuilder, get_num_nodes class FusionContractorArrayContext( SingleGridWorkBalancingPytatoArrayContext): + t_units = [] + def __init__( self, queue: "cl.CommandQueue", allocator=None, *, use_memory_pool: Optional[bool] = None, @@ -1282,6 +1353,8 @@ def __init__( def transform_dag(self, dag): import pytato as pt + initial_node_count = get_num_nodes(dag) + # {{{ Remove FEMEinsumTags that might have been propagated # TODO: Is this too hacky? @@ -1355,6 +1428,28 @@ def tag_indices_as_non_negative(ary): with ProcessLogger(logger, "transform_dag.deduplicate_data_wrappers"): dag = pt.transform.deduplicate_data_wrappers(dag) + # {{{ freeze and thaw tensor product operators + + # FIXME: this is a hack + def thaw_freeze_tp_operators(expr): + if isinstance(expr, pt.Einsum) and \ + expr.tags_of_type(TensorProductOperatorTag): + ref_mass_inv, stiff_t = expr.args + data = self.to_numpy(ref_mass_inv) @ self.to_numpy(stiff_t) + axis_tags = (TensorProductOperatorAxisTag(),) + return (self.from_numpy(data).copy( + axes=( + pt.Axis(tags=frozenset(axis_tags)), + pt.Axis(tags=frozenset(axis_tags)) + ) + ).tagged(TensorProductOperatorTag()) + .tagged(pt.tags.PrefixNamed("diff_op"))) + return expr + + dag = pt.transform.map_and_copy(dag, thaw_freeze_tp_operators) + + # }}} + # {{{ get rid of copies for different views of a cl-array def eliminate_reshapes_of_data_wrappers(ary): @@ -1379,12 +1474,24 @@ def materialize_face_mass_input_and_output(expr): expr, "ifj,fej,fej->ei")): mat, jac, vec = expr.args - return (pt.einsum("ifj,fej,fej->ei", - mat, - jac, - vec.tagged(pt.tags.ImplStored())) - .tagged((pt.tags.ImplStored(), - pt.tags.PrefixNamed("face_mass")))) + if mat.tags_of_type(FaceMassOperatorTag): + return (pt.einsum("ifj,fej,fej->ei", + mat, + jac, + vec.tagged(pt.tags.ImplStored())) + .tagged((pt.tags.ImplStored(), + pt.tags.PrefixNamed("face_mass_result")))) + elif (isinstance(expr, pt.Einsum) + and pt.analysis.is_einsum_similar_to_subscript( + expr, + "ifj,fej->ei")): + mat, vec = expr.args + if mat.tags_of_type(FaceMassOperatorTag): + return (pt.einsum("ifj,fej->ei", + mat, + vec.tagged(pt.tags.ImplStored())) + .tagged((pt.tags.ImplStored(), + pt.tags.PrefixNamed("face_mass_result")))) else: return expr @@ -1398,17 +1505,41 @@ def materialize_face_mass_input_and_output(expr): # {{{ materialize inverse mass inputs def materialize_inverse_mass_inputs(expr): + def is_tp_einsum(expr): + if pt.analysis.is_einsum_similar_to_subscript( + expr, "il,eljk->eijk"): + return True + elif pt.analysis.is_einsum_similar_to_subscript( + expr, "jl,eilk->eijk"): + return True + elif pt.analysis.is_einsum_similar_to_subscript( + expr, "kl,eijl->eijk"): + return True + return False + if (isinstance(expr, pt.Einsum) and pt.analysis.is_einsum_similar_to_subscript( expr, - "ei,ij,ej->ei")): - arg1, arg2, arg3 = expr.args - if not arg3.tags_of_type(pt.tags.PrefixNamed): - arg3 = arg3.tagged(pt.tags.PrefixNamed("mass_inv_inp")) - if not arg3.tags_of_type(pt.tags.ImplStored): - arg3 = arg3.tagged(pt.tags.ImplStored()) - - return expr.copy(args=(arg1, arg2, arg3)) + "ij,ej->ei")): + mat, vec = expr.args + if mat.tags_of_type(MassInverseOperatorTag): + if not vec.tags_of_type(pt.tags.PrefixNamed): + vec = vec.tagged(pt.tags.PrefixNamed("input_vec")) + if not vec.tags_of_type(pt.tags.ImplStored): + vec = vec.tagged(pt.tags.ImplStored()) + + return expr.copy(args=(mat, vec)) + + elif (isinstance(expr, pt.Einsum) and is_tp_einsum(expr)): + mat, vec = expr.args + if mat.tags_of_type(TensorProductMassInverseOperatorTag): + if not vec.tags_of_type(pt.tags.PrefixNamed): + vec = vec.tagged(pt.tags.PrefixNamed("input_vec_tp")) + if not vec.tags_of_type(pt.tags.ImplStored): + vec = vec.tagged(pt.tags.ImplStored()) + + return expr.copy(args=(mat, vec)) + else: return expr @@ -1638,6 +1769,15 @@ def _untag_impl_stored(expr): # }}} + final_node_count = get_num_nodes(dag) + with ProcessLogger(logger, "final node count"): + logger.info( + "Final DAG size: %d nodes, started with %d nodes, %s %d nodes", + final_node_count, initial_node_count, + ("added" if initial_node_count < final_node_count else + "removed"), abs(initial_node_count - final_node_count) + ) + return dag def transform_loopy_program(self, t_unit): @@ -1660,6 +1800,7 @@ def transform_loopy_program(self, t_unit): # from loopy.transform.instruction import simplify_indices # t_unit = simplify_indices(t_unit) + knl = t_unit.default_entrypoint logger.info(f"Transforming kernel '{knl.name}' with {len(knl.instructions)} statements.") @@ -1820,7 +1961,8 @@ def transform_loopy_program(self, t_unit): t_unit = t_unit.with_kernel(knl) del knl - if False and t_unit.default_entrypoint.tags_of_type(FromArrayContextCompile): + if False and t_unit.default_entrypoint.tags_of_type( + FromArrayContextCompile): # FIXME: Enable this branch, WIP for now and hence disabled it. from loopy.match import ObjTagged import feinsum as fnsm @@ -1863,17 +2005,40 @@ def transform_loopy_program(self, t_unit): knl = t_unit.default_entrypoint for iel, idofs in sorted(iel_to_idofs.items()): if idofs: - nunit_dofs = {knl.get_constant_iname_length(idof) - for idof in idofs} - idof, = idofs + if len(idofs) == 1: + nunit_dofs = { + knl.get_constant_iname_length(idof) + for idof in idofs + } + l_one, l_zero = _get_group_size_for_dof_array_loop( + nunit_dofs) + + idof, = idofs + + knl = lp.split_iname(knl, iel, l_one, + inner_tag="l.1", + outer_tag="g.0") + knl = lp.split_iname(knl, idof, l_zero, + inner_tag="for", + outer_tag="l.0") - l_one_size, l_zero_size = _get_group_size_for_dof_array_loop( - nunit_dofs) + else: + def idof_tp_sort_key(idof): + tag, = knl.inames[idof].tags_of_type( + TensorProductDOFAxisTag + ) + return tag.iaxis + + inames_to_tags = {iel: "g.0"} + inames_to_tags.update({ + idof: f"l.{i}" + for i, idof in enumerate( + sorted(idofs, key=idof_tp_sort_key)[:-1] + ) + }) + + knl = lp.tag_inames(knl, inames_to_tags) - knl = lp.split_iname(knl, iel, l_one_size, - inner_tag="l.1", outer_tag="g.0") - knl = lp.split_iname(knl, idof, l_zero_size, - inner_tag="l.0", outer_tag="unr") else: knl = lp.split_iname(knl, iel, 32, outer_tag="g.0", inner_tag="l.0") @@ -1886,5 +2051,4 @@ def transform_loopy_program(self, t_unit): return t_unit - # vim: foldmethod=marker diff --git a/meshmode/transform_metadata.py b/meshmode/transform_metadata.py index 148ea577..e9fcaced 100644 --- a/meshmode/transform_metadata.py +++ b/meshmode/transform_metadata.py @@ -36,6 +36,7 @@ """ from pytools.tag import Tag, UniqueTag, tag_dataclass +from pytato.transform.metadata import AxisIgnoredForPropagationTag class FirstAxisIsElementsTag(Tag): @@ -131,3 +132,103 @@ class DiscretizationDOFPickListAxisTag(DiscretizationEntityAxisTag): DOF pick lists. See :mod:`meshmode.discretization.connection.direct` for details. """ + +# {{{ tensor-product and operator metadata + +class OperatorTag(Tag): + """ + Used to signify that an array is an operator. + """ + + +class FaceMassOperatorTag(OperatorTag): + """ + Used to signify than an array is a face mass operator. + """ + + +class MassOperatorTag(OperatorTag): + """ + Used to signify that an array is a mass operator. + """ + + +class MassInverseOperatorTag(OperatorTag): + """ + Used to signify that an array is an inverse mass operator. + """ + + +class DifferentiationOperatorTag(OperatorTag): + """ + Used to signify that an array is a *strong* differentiation operator. + """ + + +class StiffnessOperatorTag(OperatorTag): + """ + Used to signify that an array is a *weak* differentiation operator. + """ + + +@tag_dataclass +class TensorProductDOFAxisTag(DiscretizationEntityAxisTag): + """ + Signify an axis as containing the DOFs of a tensor product discretization. + `iaxis` is later interpreted to determine the relative update speed (i.e. + the stride) of each axis. + """ + iaxis: int + + +class TensorProductOperatorAxisTag( + DiscretizationEntityAxisTag, + AxisIgnoredForPropagationTag + ): + """ + Signify an axis is part of a 1D operator applied to a tensor product + discretization. No tags will be propagated to or along axes containing this + tag. + """ + pass + + +class TensorProductOperatorTag(Tag): + """ + Used to tag an operator as one that acts on DOFs from a tensor-product + discretization. Used to make decisions about how to handle prefetching and + precomputing these operators. + """ + pass + + +class TensorProductMassOperatorTag(TensorProductOperatorTag): + """ + Tag an operator as being a reference mass operator. Used to realize an + algebraic simplification of redundant mass-times-mass-inverse operations + when using a tensor product discretization. + """ + pass + + +class TensorProductMassInverseOperatorTag(TensorProductOperatorTag): + """ + See `TensorProductMassOperatorTag`. + """ + pass + + +class TensorProductDifferentiationOperatorTag(OperatorTag): + """ + See `DifferentiationOperatorTag`. + """ + + +class TensorProductStiffnessOperatorTag(TensorProductOperatorTag): + """ + Similar to `TensorProductMassOperatorTag`. Used to implement an + associativity DAG transformation. + """ + pass + +# }}}