Skip to content
60 changes: 39 additions & 21 deletions pyop2/local_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from typing import Union

import coffee
from coffee.visitors import EstimateFlops
import loopy as lp
from loopy.tools import LoopyKeyBuilder
import numpy as np

from pyop2 import version
from pyop2.configuration import configuration
from pyop2.datatypes import ScalarType
from pyop2.exceptions import NameTypeError
from pyop2.types import Access
Expand Down Expand Up @@ -152,28 +152,9 @@ def arguments(self):
for acc, dtype in zip(self.accesses, self.dtypes))

@cached_property
@abc.abstractmethod
def num_flops(self):
Comment thread
connorjward marked this conversation as resolved.
"""Compute the numbers of FLOPs if not already known."""
if self.flop_count is not None:
return self.flop_count

if not configuration["compute_kernel_flops"]:
return 0

if isinstance(self.code, coffee.base.Node):
v = coffee.visitors.EstimateFlops()
return v.visit(self.code)
elif isinstance(self.code, lp.TranslationUnit):
op_map = lp.get_op_map(
self.code.copy(options=lp.Options(ignore_boostable_into=True),
silenced_warnings=['insn_count_subgroups_upper_bound',
'get_x_map_guessing_subgroup_size',
'summing_if_branches_ops']),
subgroup_size='guess')
return op_map.filter_by(name=['add', 'sub', 'mul', 'div'],
dtype=[ScalarType]).eval_and_sum({})
else:
return 0

def __eq__(self, other):
if not isinstance(other, LocalKernel):
Expand Down Expand Up @@ -214,6 +195,13 @@ def dtypes(self):
def dtypes(self, dtypes):
self._dtypes = dtypes

@cached_property
def num_flops(self):
if self.flop_count is not None:
return self.flop_count
else:
return 0


class CoffeeLocalKernel(LocalKernel):
""":class:`LocalKernel` class where `code` has type :class:`coffee.base.Node`."""
Expand All @@ -231,6 +219,14 @@ def dtypes(self):
def dtypes(self, dtypes):
self._dtypes = dtypes

@cached_property
def num_flops(self):
if self.flop_count is not None:
return self.flop_count
else:
v = EstimateFlops()
return v.visit(self.code)


class LoopyLocalKernel(LocalKernel):
""":class:`LocalKernel` class where `code` has type :class:`loopy.LoopKernel`
Expand All @@ -250,3 +246,25 @@ def _loopy_arguments(self):
"""Return the loopy arguments associated with the kernel."""
return tuple(a for a in self.code.callables_table[self.name].subkernel.args
if isinstance(a, lp.ArrayArg))

@cached_property
def num_flops(self):
if self.flop_count is not None:
return self.flop_count
else:
if isinstance(self.code, lp.TranslationUnit):
prog = self.code.with_entrypoints(self.name)
knl = prog.default_entrypoint
warnings = list(knl.silenced_warnings)
warnings.extend(['insn_count_subgroups_upper_bound',
'get_x_map_guessing_subgroup_size',
'summing_if_branches_ops'])
knl = knl.copy(silenced_warnings=warnings,
options=lp.Options(ignore_boostable_into=True))
knl = lp.fix_parameters(knl, layer=1)
prog = prog.with_kernel(knl)
else:
prog = self.code
Comment thread
connorjward marked this conversation as resolved.
Outdated
op_map = lp.get_op_map(prog, subgroup_size=1)
return op_map.filter_by(name=['add', 'sub', 'mul', 'div'],
dtype=[ScalarType]).eval_and_sum({})
3 changes: 2 additions & 1 deletion pyop2/parloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,8 @@ def _compute(self, part):
:arg part: The :class:`SetPartition` to compute over.
"""
with self._compute_event():
PETSc.Log.logFlops(part.size*self.num_flops)
if configuration["compute_kernel_flops"]:
PETSc.Log.logFlops(part.size*self.num_flops)
self.global_kernel(self.comm, part.offset, part.offset+part.size, *self.arglist)

@cached_property
Expand Down