Skip to content
63 changes: 43 additions & 20 deletions pyop2/local_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Union

import coffee
from coffee.visitors import EstimateFlops
import loopy as lp
from loopy.tools import LoopyKeyBuilder
import numpy as np
Expand Down Expand Up @@ -152,28 +153,9 @@ def arguments(self):
for acc, dtype in zip(self.accesses, self.dtypes))

@cached_property
@abc.abstractmethod
def num_flops(self):
    """Return the number of FLOPs performed by the kernel.

    An explicitly provided ``flop_count`` always takes precedence.
    Otherwise the count is estimated from ``self.code``, but only when the
    ``"compute_kernel_flops"`` configuration option is enabled; if it is
    disabled, or the code is of an unrecognised type, ``0`` is returned.
    """
    known = self.flop_count
    if known is not None:
        return known

    if not configuration["compute_kernel_flops"]:
        return 0

    code = self.code
    if isinstance(code, coffee.base.Node):
        # COFFEE AST: let the COFFEE visitor do the counting.
        return coffee.visitors.EstimateFlops().visit(code)

    if not isinstance(code, lp.TranslationUnit):
        # Neither a COFFEE node nor a loopy translation unit.
        return 0

    # These loopy warnings are silenced on the copy handed to get_op_map.
    quiet = ['insn_count_subgroups_upper_bound',
             'get_x_map_guessing_subgroup_size',
             'summing_if_branches_ops']
    program = code.copy(options=lp.Options(ignore_boostable_into=True),
                        silenced_warnings=quiet)
    ops = lp.get_op_map(program, subgroup_size='guess')

    # Only add/sub/mul/div on the scalar type are counted.
    arithmetic = ops.filter_by(name=['add', 'sub', 'mul', 'div'],
                               dtype=[ScalarType])
    return arithmetic.eval_and_sum({})

def __eq__(self, other):
if not isinstance(other, LocalKernel):
Expand Down Expand Up @@ -214,6 +196,13 @@ def dtypes(self):
def dtypes(self, dtypes):
self._dtypes = dtypes

@cached_property
def num_flops(self):
    """Return the user-supplied ``flop_count``, or ``0`` if none was given."""
    return 0 if self.flop_count is None else self.flop_count


class CoffeeLocalKernel(LocalKernel):
""":class:`LocalKernel` class where `code` has type :class:`coffee.base.Node`."""
Expand All @@ -231,6 +220,16 @@ def dtypes(self):
def dtypes(self, dtypes):
self._dtypes = dtypes

@cached_property
def num_flops(self):
    """Return the number of FLOPs performed by the kernel.

    A user-supplied ``flop_count`` is returned unchanged. Without one, the
    count is estimated by running an :class:`EstimateFlops` visitor over
    ``self.code`` if the ``"compute_kernel_flops"`` configuration option
    is enabled, and is ``0`` otherwise.
    """
    if self.flop_count is None:
        if configuration["compute_kernel_flops"]:
            return EstimateFlops().visit(self.code)
        return 0
    return self.flop_count


class LoopyLocalKernel(LocalKernel):
""":class:`LocalKernel` class where `code` has type :class:`loopy.LoopKernel`
Expand All @@ -250,3 +249,27 @@ def _loopy_arguments(self):
"""Return the loopy arguments associated with the kernel."""
return tuple(a for a in self.code.callables_table[self.name].subkernel.args
if isinstance(a, lp.ArrayArg))

@cached_property
def num_flops(self):
    """Return the number of FLOPs performed by the kernel.

    A user-supplied ``flop_count`` is returned unchanged. Without one, and
    if the ``"compute_kernel_flops"`` configuration option is enabled, the
    count is obtained from :func:`loopy.get_op_map`, summing the
    add/sub/mul/div operations on ``ScalarType``. If the option is
    disabled the result is ``0``.
    """
    if self.flop_count is not None:
        return self.flop_count
    if not configuration["compute_kernel_flops"]:
        return 0

    program = self.code
    if isinstance(program, lp.TranslationUnit):
        # Select this kernel as the entrypoint and prepare it for counting.
        program = program.with_entrypoints(self.name)
        entry = program.default_entrypoint

        # Keep the warnings that were already silenced and add the ones
        # to be suppressed for the operation count.
        silenced = [*entry.silenced_warnings,
                    'insn_count_subgroups_upper_bound',
                    'get_x_map_guessing_subgroup_size',
                    'summing_if_branches_ops']
        entry = entry.copy(silenced_warnings=silenced,
                           options=lp.Options(ignore_boostable_into=True))
        # Fix the ``layer`` parameter to 1 before counting.
        entry = lp.fix_parameters(entry, layer=1)
        program = program.with_kernel(entry)

    counts = lp.get_op_map(program, subgroup_size=1)
    arithmetic = counts.filter_by(name=['add', 'sub', 'mul', 'div'],
                                  dtype=[ScalarType])
    return arithmetic.eval_and_sum({})