Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
402 changes: 402 additions & 0 deletions src/ninetoothed/_cache.py

Large diffs are not rendered by default.

97 changes: 32 additions & 65 deletions src/ninetoothed/auto_tuner.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
import hashlib
import json
"""Auto-tuner for ninetoothed kernels.

Migrated to the unified Cache API (ninetoothed._cache.Cache). All timings
are stored as a single JSON per (project, triton-version) directory; the
prior per-func split-file layout is gone -- users with existing caches
should run `rm -rf ~/.ninetoothed/auto_tuning/`.
"""

import os

import triton

from ninetoothed._cache import Cache, project_files_fingerprint
from ninetoothed.aot import _KernelLaunchError
from ninetoothed.generation import CACHE_DIR

Expand All @@ -16,20 +23,23 @@ def __init__(self, funcs, keys):

self._func_to_key = {func: key for func, key in zip(self._funcs, self._keys)}

self._cache_dir = (
_AUTO_TUNING_CACHE_DIR
/ f"{_project_key()}_triton_{triton.__version__.replace('.', '_')}"
)
self._cache_dir.mkdir(parents=True, exist_ok=True)
# Disk layout: <CACHE_DIR>/auto_tuning/<project_key>_triton_<ver>/
# The project_key isolates caches across ninetoothed versions.
subdir = f"{_project_key()}_triton_{triton.__version__.replace('.', '_')}"
disk_dir = CACHE_DIR / "auto_tuning" / subdir

auto_tuner_key = tuple(self._keys)
cache_key = hashlib.sha256(str(auto_tuner_key).encode("utf-8")).hexdigest()
self._cache_path = self._cache_dir / f"{cache_key}.json"
self._cache = Cache(
cache_dir=disk_dir,
suffix=".json",
max_memory=64,
)

if self._cache_path.exists():
self._timings = json.loads(self._cache_path.read_text())
else:
self._timings = {key: {} for key in self._keys}
# The full timings dict is stored under a single sentinel key.
self._disk_key = ("_all_timings_",)
loaded = self._cache.get(self._disk_key, default={})
if not loaded:
loaded = {key: {} for key in self._keys}
self._timings = loaded

self._best_func = {}

Expand All @@ -54,9 +64,7 @@ def _get_timings(self, args, kwargs):
timings = [self._get_timing(func, args, kwargs) for func in self._funcs]

self._timings[arg_key] = timings

self._cache_path.write_text(json.dumps(self._timings))

self._save()
return timings

def _get_timing(self, func, args, kwargs):
Expand All @@ -67,31 +75,18 @@ def _get_timing(self, func, args, kwargs):
if (arg_key := type(self)._make_arg_key(args, kwargs)) in data:
return data[arg_key]

cache_path = self._get_func_cache_path(func)

if cache_path.exists():
data |= json.loads(cache_path.read_text())

if arg_key in data:
return data[arg_key]

try:
timing = triton.testing.do_bench(lambda: func(*args, **kwargs))
except _KernelLaunchError:
timing = float("inf")

data[arg_key] = timing

cache_path.write_text(json.dumps(data))

self._save()
return timing

def _get_func_cache_path(self, func):
func_key = self._func_to_key[func]
cache_key = hashlib.sha256(str(func_key).encode("utf-8")).hexdigest()
cache_path = self._cache_dir / f"{cache_key}.json"

return cache_path
def _save(self):
"""Persist the full timings dict (L1 + L2)."""
self._cache.put(self._disk_key, self._timings)

@staticmethod
def _make_arg_key(args, kwargs):
Expand All @@ -118,35 +113,7 @@ def _make_tensor_key(tensor):
return f"tensor(shape={tuple(tensor.shape)}, dtype={str(tensor.dtype).split('.')[-1]})"


_AUTO_TUNING_CACHE_DIR = CACHE_DIR / "auto_tuning"

_FILE_PATH = os.path.abspath(__file__)

_PARENT_DIR = os.path.dirname(_FILE_PATH)


def _project_key():
    """Return a SHA-256 hex digest summarizing the files under ``_PARENT_DIR``.

    The directory tree is walked with directory and file names sorted in
    place, so files are visited in a deterministic order. Every regular
    file whose extension is not ``.pyc`` is hashed with
    ``_calculate_file_hash`` and the resulting hex digest is folded into
    a single consolidated hash.

    :return: The hexadecimal digest of the consolidated hash.
    """
    project_hash = hashlib.sha256()

    for root, subdirs, names in os.walk(_PARENT_DIR):
        # Sorting in place also fixes the order in which ``os.walk``
        # descends into the subdirectories.
        subdirs.sort()
        names.sort()

        for name in names:
            path = os.path.join(root, name)
            _, extension = os.path.splitext(path)

            if os.path.isfile(path) and extension != ".pyc":
                digest = _calculate_file_hash(path)
                project_hash.update(digest.encode("utf-8"))

    return project_hash.hexdigest()


def _calculate_file_hash(file_path):
with open(file_path, "rb") as f:
return hashlib.sha256(f.read()).hexdigest()
"""Fingerprint of the ninetoothed source tree, used to namespace caches
across ninetoothed installation versions."""
return project_files_fingerprint(os.path.dirname(os.path.abspath(__file__)))
2 changes: 1 addition & 1 deletion src/ninetoothed/debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _arrangement(*tensors):

application_source = _generate_debug_application_source(tensors, debug_tensors)

source_file = str(cache_source(application_source))
source_file = str(cache_source(application_source, "_debug_"))

module = import_from_path(source_file, source_file)
module_vars = vars(module)
Expand Down
15 changes: 12 additions & 3 deletions src/ninetoothed/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def _find_dependencies(func):
["ruff", "format", "-"], input=source, encoding="utf-8"
)

cache_file = cache_source(source)
cache_file = cache_source(source, kernel_name)

self.tensors = self._args
self.kernel_func = self._func_def
Expand Down Expand Up @@ -870,8 +870,17 @@ def visit_Call(self, node):
return node


def cache_source(source):
digest = hashlib.sha256(source.encode("utf-8")).hexdigest()
def cache_source(source, kernel_name):
# Mix kernel_name into the digest so different kernels derived from
# the same source text (e.g. two block_size configs of the same
# arrangement) do not collide on a single .py file. Without this,
# concurrent AOT compilations can race-write the same cache file,
# leaving triton.tools.compile unable to find the named kernel.
hasher = hashlib.sha256()
hasher.update(source.encode("utf-8"))
hasher.update(b"\0")
hasher.update(kernel_name.encode("utf-8"))
digest = hasher.hexdigest()
cache_file = CACHE_DIR / f"{digest}.py"

if not cache_file.exists():
Expand Down
70 changes: 68 additions & 2 deletions src/ninetoothed/make.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,53 @@
"""Public entry point: ninetoothed.make(), with content-sensitive handle cache.

The handle cache (L1, in-process, FIFO) is keyed by a content hash of the
arrangement + application source code, tensor structural signatures, and
compilation parameters. Editing the user-facing functions invalidates the
cache; editing unrelated code does not.
"""

import inspect

from ninetoothed._cache import Cache, hash_function_source, hash_tensor_signature
from ninetoothed.aot import aot
from ninetoothed.jit import jit
from ninetoothed.tensor import Tensor


def _build_cache_key(
    arrangement,
    application,
    tensors,
    caller,
    kernel_name,
    num_warps,
    num_stages,
    max_num_configs,
):
    """Assemble the tuple used as the handle-cache key for ``make``.

    :param arrangement: Function passed to ``hash_function_source``.
    :param application: Function passed to ``hash_function_source``.
    :param tensors: Iterable of arrangement arguments. ``Tensor``
        instances are reduced with ``hash_tensor_signature``; anything
        else contributes ``("__raw__", repr(item))``.
    :param caller: Included in the key unchanged.
    :param kernel_name: Included in the key unchanged.
    :param num_warps: Included in the key unchanged.
    :param num_stages: Included in the key unchanged.
    :param max_num_configs: Included in the key unchanged.
    :return: An 8-tuple: the two function hashes, a tuple with one entry
        per element of ``tensors``, then the remaining parameters in
        signature order.
    """
    arrangement_part = hash_function_source(arrangement)
    application_part = hash_function_source(application)

    tensor_parts = []

    for item in tensors:
        if isinstance(item, Tensor):
            # Tensor instances are reduced by the dedicated signature hash.
            tensor_parts.append(hash_tensor_signature(item))
        else:
            # Non-Tensor elements take part in the key through their
            # repr(), tagged so they are distinguishable from Tensor
            # entries.
            tensor_parts.append(("__raw__", repr(item)))

    return (
        arrangement_part,
        application_part,
        tuple(tensor_parts),
        caller,
        kernel_name,
        num_warps,
        num_stages,
        max_num_configs,
    )


# Per-process L1 cache for JIT handles. Not shared across processes
# (handles are not serializable). 256-entry FIFO matches prior behavior.
_HANDLE_CACHE = Cache(max_memory=256)


def make(
Expand All @@ -24,27 +70,47 @@ def make(
:param kernel_name: The name for the generated kernel.
:param output_dir: The directory to store the generated files.
:param num_warps: The number of warps to use.
:param num_stages: The number of pipeline stages.
:param num_stages: The number of stages to use.
:param max_num_configs: The maximum number of auto-tuning
configurations to use.
:return: A handle to the compute kernel.
"""

# Cache only the JIT ("torch") path. The AOT path produces on-disk
# build artifacts (.so, .csv, .fingerprint) that are managed by
# build.py's own cache.
if caller == "torch":
key = _build_cache_key(
arrangement,
application,
tensors,
caller,
kernel_name,
num_warps,
num_stages,
max_num_configs,
)
cached = _HANDLE_CACHE.get(key)
if cached is not None:
return cached

params = inspect.signature(application).parameters
types = arrangement(*tensors)
types = types if isinstance(types, tuple) else (types,)
annotations = {param: type for param, type in zip(params, types)}
application.__annotations__ = annotations

if caller == "torch":
return jit(
handle = jit(
application,
caller=caller,
kernel_name=kernel_name,
num_warps=num_warps,
num_stages=num_stages,
max_num_configs=max_num_configs,
)
_HANDLE_CACHE.put(key, handle)
return handle

return aot(
application,
Expand Down
Loading
Loading