[1/5] autotune: harden cache key + add restore_value (#770) #783
Open
jhinpan wants to merge 3 commits into
Pull request overview
This PR strengthens FlyDSL’s Python autotuner (python/flydsl/autotune.py) so cached tuned configs are not incorrectly reused across differing compilation/device/env contexts, and adds a restore_value mechanism to make benchmarking correct for in-place kernels. It also introduces GPU-free unit tests to validate serialization, key construction, and restore/reset semantics without requiring torch or compiled bindings.
Changes:
- Harden autotune cache keys by adding stride-pattern normalization, device arch fingerprint, toolchain fingerprint, and cache-invalidating env values.
- Add restore_value snapshot/restore support and defer the CompilationContext import to keep the autotuner core import-light.
- Add a new tests/unit/test_autotune.py suite with GPU-free coverage for key axes, restore/reset behavior, pruning, and disk cache.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| python/flydsl/autotune.py | Extends cache key axes; adds env/toolchain/device fingerprinting, stride normalization, and restore_value snapshot/restore; defers compiler import for testability. |
| tests/unit/test_autotune.py | Adds GPU-free unit tests validating config round-trip, cache-key axes, restore/reset semantics, pruning, and disk cache behavior. |
Comment on lines +240 to +249

    # Dtypes + normalized strides of tensor args for type/layout specialization
    dtype_parts = []
    stride_parts = []
    for name, val in sig_args.items():
        if hasattr(val, "dtype"):
            dtype_parts.append(f"{name}:{val.dtype}")
        if hasattr(val, "shape") and hasattr(val, "stride"):
            stride_parts.append(f"{name}:{_normalize_strides(val)}")
    key_vals.append(tuple(dtype_parts))
    key_vals.append(tuple(stride_parts))
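
The body of _normalize_strides is not shown in this excerpt. The following is only a minimal sketch of the {0, 1, other} classification the PR describes, assuming the argument exposes a torch-style stride() method. FakeTensor and normalize_strides are hypothetical names used for illustration, not the PR's code.

```python
class FakeTensor:
    """Hypothetical stand-in exposing a torch-style shape and stride()."""

    def __init__(self, shape, strides):
        self.shape = tuple(shape)
        self._strides = tuple(strides)

    def stride(self):
        return self._strides


def normalize_strides(t):
    """Map each stride to 0 (broadcast), 1 (contiguous), or 'other' (strided)."""
    return tuple(s if s in (0, 1) else "other" for s in t.stride())


row_major = FakeTensor((4, 8), (8, 1))
bigger = FakeTensor((16, 32), (32, 1))
broadcast = FakeTensor((4, 8), (0, 1))

# The same layout pattern at different sizes yields the same key component.
same_pattern = normalize_strides(row_major) == normalize_strides(bigger)
# A broadcast dimension changes the pattern, so the key changes too.
differs = normalize_strides(row_major) != normalize_strides(broadcast)
```

Keying on the pattern rather than the raw stride values means a config tuned for one row-major tensor can be shared with another row-major tensor of a different size, while a broadcast or strided view gets its own entry.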
Comment on lines +325 to +331

    try:
        return self._do_bench(kernel_call, warmup=self.warmup, rep=self.rep)
    finally:
        # Leave the caller's tensors as the kernel would have left them on a
        # single clean run: restore inputs, then run once more.
        if snapshot:
            self._restore_tensors(snapshot)
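
One reason to place the restore in a finally block is that it also runs when the benchmark raises. A minimal sketch of that behavior, using plain Python lists in place of tensors; bench_with_restore and failing_kernel are hypothetical names, and the real _do_bench and _restore_tensors are not reproduced here:

```python
import copy


def bench_with_restore(kernel_call, tensors, names, reps=3):
    """Run kernel_call `reps` times, then put the named buffers back."""
    # Snapshot only the buffers the caller named.
    snapshot = {n: copy.deepcopy(tensors[n]) for n in names}
    try:
        for _ in range(reps):
            kernel_call()
    finally:
        # Runs on success and on failure, so callers never see rep state.
        for n, saved in snapshot.items():
            tensors[n][:] = saved


buffers = {"x": [1.0, 2.0]}


def failing_kernel():
    buffers["x"][0] += 1.0  # in-place update
    raise RuntimeError("launch failed")  # simulated failure


try:
    bench_with_restore(failing_kernel, buffers, ["x"])
except RuntimeError:
    pass  # buffers["x"] is back to its pre-benchmark contents here
```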
This was referenced Jul 1, 2026
FlyDSL's autotuner exists but nothing uses it, and two gaps block real
adoption. This is the first of a series making it a correct, adopted path.
Cache key (_make_key) previously specialized on shape/dtype only. A config
tuned under one compiler build, GPU arch, or memory layout would be silently
reused under another. Fold in the axes Triton/quack rely on:
- normalized stride pattern ({0,1,other}: broadcast vs contiguous vs strided)
- device arch (get_rocm_arch)
- toolchain fingerprint (reuses jit_function._flydsl_key)
- cache-invalidating env vars (reuses _cache_invalidating_env_values)
The dtype/stride axes are sorted by arg name so a call is keyed identically
regardless of kwarg order (no duplicate tuning / cache files).
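
A rough sketch of how sorting by argument name makes the key independent of kwarg order, and how an extra axis such as the device arch separates cache entries. The make_key function, its parameters, and the example values (including the EXAMPLE_FLAG variable) are assumptions for illustration; the real _make_key is not shown in full here.

```python
def make_key(tensor_dtypes, arch, toolchain, env_values):
    """Build a hashable cache key; all parameter names are illustrative."""
    # Sort by argument name so kwarg order cannot change the key.
    dtype_parts = tuple(f"{n}:{d}" for n, d in sorted(tensor_dtypes.items()))
    env_parts = tuple(sorted(env_values.items()))
    return (dtype_parts, arch, toolchain, env_parts)


env = {"EXAMPLE_FLAG": "0"}  # hypothetical cache-invalidating variable

a = make_key({"x": "f16", "w": "f32"}, "gfx942", "abc123", env)
b = make_key({"w": "f32", "x": "f16"}, "gfx942", "abc123", env)
c = make_key({"x": "f16", "w": "f32"}, "gfx950", "abc123", env)

# a == b: same call, different kwarg order, one cache entry.
# a != c: a different device arch gets its own entry.
```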
restore_value (new) keeps benchmarking correct for in-place kernels:
benchmarking runs the same kernel dozens of times, so an in-place /
accumulating kernel (e.g. fused-add rmsnorm) corrupts its own inputs and
picks a config on garbage. Snapshot the named tensors once and restore
before every rep.
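
The effect can be illustrated with a toy accumulating kernel. This is a sketch with plain lists instead of tensors and with timing omitted; fused_add and bench are hypothetical names, not the autotuner's API.

```python
import copy


def fused_add(x, residual):
    """In-place toy kernel: residual accumulates x on every call."""
    for i, v in enumerate(x):
        residual[i] += v


def bench(kernel, args, restore_names, reps):
    # Snapshot the named buffers once, before the first rep.
    snapshot = {n: copy.deepcopy(args[n]) for n in restore_names}
    seen = []
    for _ in range(reps):
        # Restore before every rep so each rep sees the original inputs.
        for n, saved in snapshot.items():
            args[n][:] = saved
        kernel(**args)
        seen.append(list(args["residual"]))
    return seen


args = {"x": [1.0, 1.0], "residual": [0.0, 0.0]}
with_restore = bench(fused_add, args, ["residual"], reps=3)

args = {"x": [1.0, 1.0], "residual": [0.0, 0.0]}
without_restore = bench(fused_add, args, [], reps=3)
```

With the restore, every rep observes the same inputs; without it, the residual drifts from rep to rep, so later reps time the kernel on data the caller never supplied.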
reset_to_zero is now also re-applied on the real (non-benchmark) call — both
the post-tune run and cache hits — via a shared _run_config, so an
accumulate-into-zero kernel returns the single-clean-run result instead of
carrying benchmark-rep state. (Was applied only inside the bench loop.)
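
A sketch of the shared-helper idea, assuming reset_to_zero is a list of buffer names. The run_config function below is a hypothetical simplification of _run_config, with lists standing in for tensors.

```python
def run_config(kernel, args, reset_to_zero):
    """Zero the named buffers, then launch once (hypothetical helper)."""
    for name in reset_to_zero:
        buf = args[name]
        buf[:] = [0.0] * len(buf)
    kernel(**args)


def accumulate(x, out):
    """Toy accumulate-into-zero kernel."""
    for i, v in enumerate(x):
        out[i] += v


args = {"x": [2.0, 3.0], "out": [0.0, 0.0]}

# Benchmark reps leave accumulated state in `out`.
for _ in range(5):
    accumulate(**args)
stale = list(args["out"])

# The real call goes through the same helper, so `out` starts from zero.
run_config(accumulate, args, reset_to_zero=["out"])
clean = list(args["out"])
```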
Also defer the CompilationContext import so the autotuner core stays
importable and unit-testable without the compiled flydsl._mlir bindings.
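
The deferred-import pattern in general looks like the sketch below. The standard-library json module stands in for the compiled bindings purely so the example runs anywhere; get_compilation_context and pure_python_helper are hypothetical names.

```python
import importlib


def get_compilation_context():
    """Import the heavy module on first use rather than at module import."""
    # 'json' is a placeholder for the compiled bindings in this sketch.
    return importlib.import_module("json")


def pure_python_helper(values):
    # Usable without ever triggering the deferred import.
    return sorted(values)
```

Code paths that only touch pure-Python helpers can then be imported and tested on a machine where the heavy dependency is not installed.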
Adds tests/unit/test_autotune.py: 19 GPU-free tests covering Config
serialization, every cache-key axis (incl. env-fingerprint change and
kwarg-order insensitivity), restore_value/reset_to_zero semantics (incl. the
final-run and cache-hit reset), pruning, and disk-cache round-trip.
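
As an illustration of what a GPU-free disk-cache round-trip check can look like, here is a sketch using a toy dataclass and a JSON file. ToyConfig, its fields, and the save/load helpers are assumptions; the real Config class and cache format are not shown in this PR excerpt.

```python
import json
import os
import tempfile
from dataclasses import asdict, dataclass, field


@dataclass
class ToyConfig:
    """Hypothetical config; the real Config fields are not shown here."""

    kwargs: dict = field(default_factory=dict)
    num_warps: int = 4


def save(path, key, cfg):
    # Store one key -> config mapping as JSON.
    with open(path, "w") as f:
        json.dump({key: asdict(cfg)}, f)


def load(path, key):
    # Rebuild the config from the stored dictionary.
    with open(path) as f:
        return ToyConfig(**json.load(f)[key])


with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "cache.json")
    original = ToyConfig(kwargs={"BLOCK": 128}, num_warps=8)
    save(path, "k0", original)
    restored = load(path, "k0")
```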
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
85e4161 to 6f3351f
58e9931 to f8e5bc9
Comment-only cleanup of the PR1 additions: keep the one key fact per helper, drop the Triton/quack background, redundant restatements, and by-example prose. No logic change; 19 unit tests still pass.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
f8e5bc9 to bff0d76
First of a planned series making FlyDSL's autotuner (python/flydsl/autotune.py) a correct, adopted tuning path per #770. Today it is well-built but unused, with two gaps that must be closed before any kernel can safely adopt it.

What this PR does

Harden the cache key (_make_key). It previously specialized on shape/dtype only, so a config tuned under one compiler build / GPU arch / memory layout would be silently reused under another. This folds in the axes Triton and quack rely on:
- normalized stride pattern ({0, 1, other}: broadcast vs contiguous vs strided; the exact numbers don't matter, the pattern does)
- device arch (get_rocm_arch)
- toolchain fingerprint (jit_function._flydsl_key(), which hashes compiler source, native libs, and version)
- cache-invalidating env vars (_cache_invalidating_env_values())

Add restore_value, which keeps benchmarking correct for kernels that mutate their inputs. Benchmarking runs the same kernel dozens of times; an in-place / accumulating kernel (e.g. fused-add rmsnorm, where output overlaps the residual/input buffers) corrupts its own inputs across reps and picks a config on garbage. We snapshot the named tensors once and restore before every rep. Exposed on @autotune alongside the existing reset_to_zero.

Keep the core import-light. Defer the CompilationContext import so the autotuner core (Config, key, restore) stays importable and unit-testable without the compiled flydsl._mlir bindings.

Tests

Adds tests/unit/test_autotune.py: 16 GPU-free unit tests (no torch, no compiled bindings) covering Config serialization, every cache-key axis, restore_value/reset_to_zero semantics, config pruning, and disk-cache round-trip. ruff + black clean.

Series roadmap (#770)
- … restore_value + GPU-free unit tests
- … get_default heuristic + exhaustive) + first real adopter (rmsnorm / fused-add rmsnorm)

Refs #770. No behavior change for existing code: the autotuner has no current callers.
🤖 Generated with Claude Code