From cd6d43ff21f424391449c9294452967c1756152a Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 19 Sep 2024 18:07:40 -0500 Subject: [PATCH] Fusion actx: cache transform_dag --- meshmode/array_context.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/meshmode/array_context.py b/meshmode/array_context.py index 94c181f0..5b727beb 100644 --- a/meshmode/array_context.py +++ b/meshmode/array_context.py @@ -1275,13 +1275,29 @@ def __init__( self.use_axis_tag_inference_fallback = use_axis_tag_inference_fallback self.use_einsum_inference_fallback = use_einsum_inference_fallback - self.transform_loopy_cache = WriteOncePersistentDict("meshmode-fusion_actx_transform_loopy_cache-v1", + self.transform_loopy_cache = \ + WriteOncePersistentDict("meshmode-fusion_actx_transform_loopy_cache-v1", + key_builder=PytatoKeyBuilder(), + safe_sync=False) + self.transform_dag_cache = \ + WriteOncePersistentDict("meshmode-fusion_actx_transform_dag_cache-v1", key_builder=PytatoKeyBuilder(), safe_sync=False) def transform_dag(self, dag): + from pyopencl.array import queue_for_pickling import pytato as pt + orig_dag = pt.transform.map_and_copy(dag, lambda x: x) + + try: + with queue_for_pickling(self.queue, self.allocator): + r = self.transform_dag_cache[dag] + logger.info(f"FusionContractorArrayContext.transform_dag: cache hit") + return r + except KeyError: + logger.debug(f"FusionContractorArrayContext.transform_dag: cache miss") + # {{{ Remove FEMEinsumTags that might have been propagated # TODO: Is this too hacky? @@ -1638,6 +1654,9 @@ def _untag_impl_stored(expr): # }}} + with queue_for_pickling(self.queue, self.allocator): + self.transform_dag_cache.store_if_not_present(orig_dag, dag) + return dag def transform_loopy_program(self, t_unit):