Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 191 additions & 27 deletions meshmode/array_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@
_PytestPytatoPyOpenCLArrayContextFactory,
register_pytest_array_context_factory,
)
from meshmode.transform_metadata import (
FaceMassOperatorTag,
MassInverseOperatorTag,
TensorProductDOFAxisTag,
TensorProductMassInverseOperatorTag,
TensorProductOperatorAxisTag,
TensorProductOperatorTag
)
from loopy.translation_unit import for_each_kernel

from loopy.tools import memoize_on_disk
Expand Down Expand Up @@ -901,6 +909,8 @@ def fuse_same_discretization_entity_loops(knl):
# transforming it.
orig_knl = knl

import time

knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationFaceAxisTag,
"iface",
False,
Expand All @@ -915,6 +925,7 @@ def fuse_same_discretization_entity_loops(knl):
"idof",
False,
orig_knl)

knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationDimAxisTag,
"idim",
False,
Expand All @@ -924,15 +935,29 @@ def fuse_same_discretization_entity_loops(knl):
"iface",
True,
orig_knl)

knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationDOFAxisTag,
"idof",
True,
orig_knl)

knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationDimAxisTag,
"idim",
True,
orig_knl)

knl = _fuse_loops_over_a_discr_entity(knl,
TensorProductDOFAxisTag,
"idof_tp",
False,
orig_knl)

knl = _fuse_loops_over_a_discr_entity(knl,
TensorProductDOFAxisTag,
"idof_tp",
True,
orig_knl)

return knl


Expand Down Expand Up @@ -1028,7 +1053,10 @@ def _get_iel_to_idofs(kernel):
for iname in kernel.all_inames()
if (kernel
.inames[iname]
.tags_of_type(DiscretizationDOFAxisTag))
.tags_of_type(DiscretizationDOFAxisTag) or
kernel
.inames[iname]
.tags_of_type(TensorProductDOFAxisTag))
}
iface_inames = {iname
for iname in kernel.all_inames()
Expand Down Expand Up @@ -1058,6 +1086,9 @@ def _get_iel_to_idofs(kernel):
raise NotImplementedError(f"The <iel> loop {insn.within_inames}"
" does not appear as a singly nested"
" loop.")

# {{{ <iel, idof> loop (simplicial)

elif ((len(insn.within_inames) == 2)
and (len(insn.within_inames & iel_inames) == 1)
and (len(insn.within_inames & idof_inames) == 1)):
Expand All @@ -1068,9 +1099,41 @@ def _get_iel_to_idofs(kernel):
for dof_insn in kernel.iname_to_insns()[idof]):
pass
else:
for dof_insn in kernel.iname_to_insns()[idof]:
if iel not in kernel.id_to_insn[dof_insn].within_inames:
print(f"_get_iel_to_idofs: {str(kernel.id_to_insn[dof_insn])=}")
raise NotImplementedError("The <iel,idof> loop "
f"'{insn.within_inames}' has the idof-loop"
" that's not nested within the iel-loop.")
# }}}

# {{{ <iel, idof, ...> loop (tensor product)

elif ((len(insn.within_inames) > 2)
and (len(insn.within_inames & iel_inames) == 1)
and (len(insn.within_inames & idof_inames) > 1)):

iel, = insn.within_inames & iel_inames
for idof in insn.within_inames & idof_inames:
iel_to_idofs[iel].add(idof)

if all((iel in kernel.id_to_insn[dof_insn].within_inames)
for dof_insn in kernel.iname_to_insns()[idof]):
pass
else:
for dof_insn in kernel.iname_to_insns()[idof]:
if iel not in kernel.id_to_insn[dof_insn].within_inames:
print("_get_iel_to_idofs: "
f"{str(kernel.id_to_insn[dof_insn])=}")
raise NotImplementedError("The <iel,idof> loop "
f"'{insn.within_inames}' has the "
"idof-loop that's not nested "
"within the iel-loop.")

# }}}

# {{{ <iel, idof, iface> loop

elif ((len(insn.within_inames) > 2)
and (len(insn.within_inames & iel_inames) == 1)
and (len(insn.within_inames & idof_inames) == 1)
Expand All @@ -1086,7 +1149,11 @@ def _get_iel_to_idofs(kernel):
else:
raise NotImplementedError("Could not fit into <iel,idof,iface>"
" loop nest pattern.")

# }}}

else:
print(f"_get_iel_to_idofs: {str(insn)=}")
raise NotImplementedError(f"Cannot fit loop nest '{insn.within_inames}'"
" into known set of loop-nest patterns.")

Expand Down Expand Up @@ -1126,8 +1193,10 @@ class EinsumTag(UniqueTag):


def _prepare_kernel_for_parallelization(kernel):
from meshmode.transform_metadata import TensorProductDOFAxisTag
discr_tag_to_prefix = {DiscretizationElementAxisTag: "iel",
DiscretizationDOFAxisTag: "idof",
TensorProductDOFAxisTag: "idof_tp",
DiscretizationDimAxisTag: "idim",
DiscretizationAmbientDimAxisTag: "idim",
DiscretizationTopologicalDimAxisTag: "idim",
Expand Down Expand Up @@ -1252,11 +1321,13 @@ def _combine_einsum_domains(knl):


from pytools.persistent_dict import WriteOncePersistentDict
from pytato.analysis import PytatoKeyBuilder
from pytato.analysis import PytatoKeyBuilder, get_num_nodes

class FusionContractorArrayContext(
SingleGridWorkBalancingPytatoArrayContext):

t_units = []

def __init__(
self, queue: "cl.CommandQueue", allocator=None, *,
use_memory_pool: Optional[bool] = None,
Expand All @@ -1282,6 +1353,8 @@ def __init__(
def transform_dag(self, dag):
import pytato as pt

initial_node_count = get_num_nodes(dag)

# {{{ Remove FEMEinsumTags that might have been propagated

# TODO: Is this too hacky?
Expand Down Expand Up @@ -1355,6 +1428,28 @@ def tag_indices_as_non_negative(ary):
with ProcessLogger(logger, "transform_dag.deduplicate_data_wrappers"):
dag = pt.transform.deduplicate_data_wrappers(dag)

# {{{ freeze and thaw tensor product operators

# FIXME: this is a hack
def thaw_freeze_tp_operators(expr):
if isinstance(expr, pt.Einsum) and \
expr.tags_of_type(TensorProductOperatorTag):
ref_mass_inv, stiff_t = expr.args
data = self.to_numpy(ref_mass_inv) @ self.to_numpy(stiff_t)
axis_tags = (TensorProductOperatorAxisTag(),)
return (self.from_numpy(data).copy(
axes=(
pt.Axis(tags=frozenset(axis_tags)),
pt.Axis(tags=frozenset(axis_tags))
)
).tagged(TensorProductOperatorTag())
.tagged(pt.tags.PrefixNamed("diff_op")))
return expr

dag = pt.transform.map_and_copy(dag, thaw_freeze_tp_operators)

# }}}

# {{{ get rid of copies for different views of a cl-array

def eliminate_reshapes_of_data_wrappers(ary):
Expand All @@ -1379,12 +1474,24 @@ def materialize_face_mass_input_and_output(expr):
expr,
"ifj,fej,fej->ei")):
mat, jac, vec = expr.args
return (pt.einsum("ifj,fej,fej->ei",
mat,
jac,
vec.tagged(pt.tags.ImplStored()))
.tagged((pt.tags.ImplStored(),
pt.tags.PrefixNamed("face_mass"))))
if mat.tags_of_type(FaceMassOperatorTag):
return (pt.einsum("ifj,fej,fej->ei",
mat,
jac,
vec.tagged(pt.tags.ImplStored()))
.tagged((pt.tags.ImplStored(),
pt.tags.PrefixNamed("face_mass_result"))))
elif (isinstance(expr, pt.Einsum)
and pt.analysis.is_einsum_similar_to_subscript(
expr,
"ifj,fej->ei")):
mat, vec = expr.args
if mat.tags_of_type(FaceMassOperatorTag):
return (pt.einsum("ifj,fej->ei",
mat,
vec.tagged(pt.tags.ImplStored()))
.tagged((pt.tags.ImplStored(),
pt.tags.PrefixNamed("face_mass_result"))))
else:
return expr

Expand All @@ -1398,17 +1505,41 @@ def materialize_face_mass_input_and_output(expr):
# {{{ materialize inverse mass inputs

def materialize_inverse_mass_inputs(expr):
def is_tp_einsum(expr):
if pt.analysis.is_einsum_similar_to_subscript(
expr, "il,eljk->eijk"):
return True
elif pt.analysis.is_einsum_similar_to_subscript(
expr, "jl,eilk->eijk"):
return True
elif pt.analysis.is_einsum_similar_to_subscript(
expr, "kl,eijl->eijk"):
return True
return False

if (isinstance(expr, pt.Einsum)
and pt.analysis.is_einsum_similar_to_subscript(
expr,
"ei,ij,ej->ei")):
arg1, arg2, arg3 = expr.args
if not arg3.tags_of_type(pt.tags.PrefixNamed):
arg3 = arg3.tagged(pt.tags.PrefixNamed("mass_inv_inp"))
if not arg3.tags_of_type(pt.tags.ImplStored):
arg3 = arg3.tagged(pt.tags.ImplStored())

return expr.copy(args=(arg1, arg2, arg3))
"ij,ej->ei")):
mat, vec = expr.args
if mat.tags_of_type(MassInverseOperatorTag):
if not vec.tags_of_type(pt.tags.PrefixNamed):
vec = vec.tagged(pt.tags.PrefixNamed("input_vec"))
if not vec.tags_of_type(pt.tags.ImplStored):
vec = vec.tagged(pt.tags.ImplStored())

return expr.copy(args=(mat, vec))

elif (isinstance(expr, pt.Einsum) and is_tp_einsum(expr)):
mat, vec = expr.args
if mat.tags_of_type(TensorProductMassInverseOperatorTag):
if not vec.tags_of_type(pt.tags.PrefixNamed):
vec = vec.tagged(pt.tags.PrefixNamed("input_vec_tp"))
if not vec.tags_of_type(pt.tags.ImplStored):
vec = vec.tagged(pt.tags.ImplStored())

return expr.copy(args=(mat, vec))

else:
return expr

Expand Down Expand Up @@ -1638,6 +1769,15 @@ def _untag_impl_stored(expr):

# }}}

final_node_count = get_num_nodes(dag)
with ProcessLogger(logger, "final node count"):
logger.info(
"Final DAG size: %d nodes, started with %d nodes, %s %d nodes",
final_node_count, initial_node_count,
("added" if initial_node_count < final_node_count else
"removed"), abs(initial_node_count - final_node_count)
)

return dag

def transform_loopy_program(self, t_unit):
Expand All @@ -1660,6 +1800,7 @@ def transform_loopy_program(self, t_unit):
# from loopy.transform.instruction import simplify_indices
# t_unit = simplify_indices(t_unit)

knl = t_unit.default_entrypoint

logger.info(f"Transforming kernel '{knl.name}' with {len(knl.instructions)} statements.")

Expand Down Expand Up @@ -1820,7 +1961,8 @@ def transform_loopy_program(self, t_unit):
t_unit = t_unit.with_kernel(knl)
del knl

if False and t_unit.default_entrypoint.tags_of_type(FromArrayContextCompile):
if False and t_unit.default_entrypoint.tags_of_type(
FromArrayContextCompile):
# FIXME: Enable this branch, WIP for now and hence disabled it.
from loopy.match import ObjTagged
import feinsum as fnsm
Expand Down Expand Up @@ -1863,17 +2005,40 @@ def transform_loopy_program(self, t_unit):
knl = t_unit.default_entrypoint
for iel, idofs in sorted(iel_to_idofs.items()):
if idofs:
nunit_dofs = {knl.get_constant_iname_length(idof)
for idof in idofs}
idof, = idofs
if len(idofs) == 1:
nunit_dofs = {
knl.get_constant_iname_length(idof)
for idof in idofs
}
l_one, l_zero = _get_group_size_for_dof_array_loop(
nunit_dofs)

idof, = idofs

knl = lp.split_iname(knl, iel, l_one,
inner_tag="l.1",
outer_tag="g.0")
knl = lp.split_iname(knl, idof, l_zero,
inner_tag="for",
outer_tag="l.0")

l_one_size, l_zero_size = _get_group_size_for_dof_array_loop(
nunit_dofs)
else:
def idof_tp_sort_key(idof):
tag, = knl.inames[idof].tags_of_type(
TensorProductDOFAxisTag
)
return tag.iaxis

inames_to_tags = {iel: "g.0"}
inames_to_tags.update({
idof: f"l.{i}"
for i, idof in enumerate(
sorted(idofs, key=idof_tp_sort_key)[:-1]
)
})

knl = lp.tag_inames(knl, inames_to_tags)

knl = lp.split_iname(knl, iel, l_one_size,
inner_tag="l.1", outer_tag="g.0")
knl = lp.split_iname(knl, idof, l_zero_size,
inner_tag="l.0", outer_tag="unr")
else:
knl = lp.split_iname(knl, iel, 32,
outer_tag="g.0", inner_tag="l.0")
Expand All @@ -1886,5 +2051,4 @@ def transform_loopy_program(self, t_unit):

return t_unit


# vim: foldmethod=marker
Loading