Skip to content

feat(cache): unify cache infrastructure + fix AOT cache collision#161

Open
whjthu wants to merge 6 commits into
masterfrom
cache-unify
Open

feat(cache): unify cache infrastructure + fix AOT cache collision#161
whjthu wants to merge 6 commits into
masterfrom
cache-unify

Conversation

@whjthu

@whjthu whjthu commented Jun 8, 2026

Copy link
Copy Markdown
Contributor

背景

ninetoothed 此前有 4 套独立 cache, 行为不一致, 各自维护:

  • make.py 的 JIT handle 用 in-process dict (无跨进程, 无序列化)
  • auto_tuner.py 的 timings 用 pathlib 直写 ~/.ninetoothed/
  • test_pad._kernel_cache / test_max_pool2d.max_pool2d_kernels 各自手写 dict
  • generation.cache_source 的 Triton .py 文件 cache 用 sha256(source) 当 key

调研下来发现 cache_source 这个 key 算式有 collision bug: 同一 arrangement 在不同 block_size config 下产生一模一样的 Triton source 文本 (block_size 只流到 kernel name 后缀, 不流到 source body), sha256 出同一个 digest, 多个子进程并发 race-write 同一个 .py 文件, 结果 triton.tools.compile 用 loser's --kernel-name 找 winner's def 找不到, 报 AttributeError: module '...' has no attribute 'conv2d_...'

这个 bug 之前没被报告是因为可能是因为 .so 缓存已经在 ~/.ninetoothed/ 里, build._load_cached 命中, 根本不走 aot 路径, collision 不触发。新 clone 或 PR CI 是 cold cache, 才会踩到。

本 PR 包含

新增: 统一 cache 基础设施 (src/ninetoothed/_cache.py)

  • Cache class: L1 (内存 dict, FIFO eviction) + L2 (filesystem, 可选, atomic write)
  • 线程安全 (threading.Lock + tempfile rename)
  • Content-sensitive key (inspect.getsource() + sha256, functools.partial kwargs 参与)
  • hash_function_source() / hash_tensor_signature() / project_files_fingerprint() 等辅助
  • functools.partial 一层 unwrap; fallback path 用 sha256(module.qualname + repr(partial_args)), 不依赖不稳定的 id()

迁移: JIT handle cache (src/ninetoothed/make.py)

  • 新增 _HANDLE_CACHE (memory-only, 256-entry FIFO)
  • 调用点: caller == "torch" path; AOT path 不走 cache
  • 修改 _build_cache_key 让 tensors tuple 容忍非 Tensor 元素 (slices, ints, bools), 走 repr() hash

迁移: auto_tuner cache (src/ninetoothed/auto_tuner.py)

  • 从 pathlib 直写迁移到 Cache(cache_dir=CACHE_DIR/"auto_tuning"/<project_key>triton)
  • 用 _cache.project_files_fingerprint() 隔离不同版本/项目的 cache

修复 AOT cache collision (src/ninetoothed/generation.py)

  • cache_source(source) -> cache_source(source, kernel_name)
  • sha256 digest 把 kernel_name 也混进去, collision 结构性不可能
  • debugging.py 调用点显式传 "debug" 占位符

测试清理

  • test_pad.py 删 _kernel_cache 手动 dict
  • test_max_pool2d.py 删 max_pool2d_kernels 手动 dict, 改写法 2: ceil_mode 作为 tensors tuple 第三元素 (与 test_pad pattern 一致), 不再用 functools.partial
  • test_auto_tuner.py::test_auto_tuner_persists_across_instances 加 skipif(not torch.cuda.is_available()), 跟同文件 test_auto_tuner 守护一致

新增测试

  • tests/test_cache.py: 21 unit tests (Cache class 行为) + 3 integration tests (make() cache 区分 ceil_mode, 含同 shape + 不同 ceil_mode 的关键回归)
  • tests/test_make_cache_key.py: 22 tests (_build_cache_key + hash_function_source 各种 case, 含 partial 区分)
  • tests/test_atomic_write.py: 6 tests (atomic write 行为)

验证

GPU env:

  • pytest tests/test_cache.py: 24/24 pass
  • pytest tests/test_atomic_write.py: 6/6 pass
  • pytest tests/test_make_cache_key.py: 21 pass + 1 xfail
  • pytest tests/test_auto_tuner.py: 5/5 pass
  • pytest tests/test_pad.py: 48/48 pass
  • pytest tests/test_max_pool2d.py: 2/2 pass
  • pytest tests/test_aot.py: 11 pass + 1 skip, 0 fail (本 PR 修前 2 fail)
  • pytest tests/ (全量): 125 pass + 1 xfail + 0 fail
  • ruff check .: All checks passed

CPU env (CPU-only, 无 CUDA driver):

  • pytest tests/test_cache.py: 21 pass + 3 skip
  • pytest tests/test_atomic_write.py: 6/6 pass
  • pytest tests/test_make_cache_key.py: 21 pass + 1 xfail
  • pytest tests/test_auto_tuner.py: 4 skip + 1 skip (本 PR 新加)
  • pytest tests/test_pad.py: 48 skip (GPU-only)
  • pytest tests/test_max_pool2d.py: 2 skip (GPU-only)
  • pytest tests/ (全量): 75 pass + 193 skip + 1 xfail + 0 fail

向后兼容 / 迁移注意

  • ~/.ninetoothed/.py 全部失效: 本 PR 修改了 cache_source digest 算法, 旧 .py 文件 (约 1500 个) 不会被新算法命中。一次性重生成成本, 不会再发生。
  • auto_tuning/<project_key>_*/ 旧版本 cache: auto_tuner 已迁移到 Cache class, _disk_key 算法变了, 旧 timings 不会命中。首次跑会重新 benchmark, 之后正常 cache hit。
  • _HANDLE_CACHE 是 memory-only (handle 不可序列化), 不影响跨进程 JIT 行为, 仍然走 aot 路径的 disk cache。

Checklist

  • 现有测试全部通过 (GPU env: 125 pass + 1 xfail + 0 fail, CPU env: 75 pass + 193 skip + 1 xfail + 0 fail)
  • ruff 0 错误
  • 新增 50+ 测试覆盖 cache 行为 (unit + integration)
  • 修复 AOT cache key collision (修前 2 fail -> 修后 0 fail)
  • 不引入对外部依赖 (只用 stdlib + 已有的 hashlib, pathlib, threading)
  • thread-safe (Cache class 走 threading.Lock + atomic rename)

pytest output:

============================= test session starts ==============================
platform linux -- Python 3.12.3, pytest-9.0.3, pluggy-1.6.0
rootdir: /home/haojie/minimax-work/ninetoothed
configfile: pyproject.toml                                                                       
plugins: xdist-3.8.0, anyio-4.13.0, cov-7.1.0
collected 269 items

tests/test_add.py .                                                      [  0%]                                                                                                                   
tests/test_addmm.py ..                                                   [  1%]  
tests/test_aot.py .s..........                                           [  5%]
tests/test_aot_auto_tuning.py ....                                       [  7%]
tests/test_atomic_write.py ......                                        [  9%]
tests/test_attention.py ........                                         [ 12%]
tests/test_auto_tuner.py .....                                           [ 14%]
tests/test_cache.py ........................                             [ 23%]
tests/test_clone.py ....                                                 [ 24%]
tests/test_conv2d.py ....                                                [ 26%]
tests/test_data_ptr.py .                                                 [ 26%]
tests/test_debugging.py .                                                [ 26%]
tests/test_dropout.py .                                                  [ 27%]
tests/test_eval.py ........                                              [ 30%]
tests/test_expand.py .                                                   [ 30%]
tests/test_generation.py ............................................... [ 47%]
.............................                                            [ 58%]
tests/test_getitem.py ..........                                         [ 62%]
tests/test_ipynb.py .                                                    [ 62%]
tests/test_jagged.py ................                                    [ 68%]
tests/test_make_cache_key.py ......x...............                      [ 76%]
tests/test_matmul.py ..                                                  [ 77%]
tests/test_max_pool2d.py ..                                              [ 78%]
tests/test_naming.py .......                                             [ 81%]
tests/test_pad.py ................................................       [ 98%]
tests/test_pow.py .                                                      [ 99%]
tests/test_softmax.py .                                                  [ 99%]
tests/test_unsqueeze.py .                                                [100%]

============ 267 passed, 1 skipped, 1 xfailed in 455.14s (0:07:35) =============

@whjthu whjthu requested a review from voltjia June 8, 2026 17:04
@whjthu

whjthu commented Jun 10, 2026

Copy link
Copy Markdown
Contributor Author

Update

PR 概览

本 PR 相比 master 引入了一套统一的 cache 基础设施,并将 ninetoothed 中已有的多个 cache 路径迁移到统一机制上。整体目标是让 cache 行为更加一致、更安全,并提升重复 JIT/AOT 工作流下的性能和稳定性。

相比 master 的主要修改

1. 新增统一 cache 基础设施

新增 src/ninetoothed/_cache.py,提供:

  • Cache:统一的 L1/L2 cache 抽象
    • L1:进程内 memory cache
    • L2:可选的 filesystem-backed cache
    • FIFO eviction
    • thread-safe 访问
    • atomic disk write
  • 稳定的 cache key 辅助函数:
    • hash_function_source
    • hash_tensor_signature
    • hash_value
    • project_files_fingerprint

2. 为 ninetoothed.make 增加 JIT handle cache

ninetoothed.make(..., caller="torch") 现在会使用进程内 L1 cache 缓存 JIT handle。

cache key 覆盖:

  • arrangement function
  • application function
  • tensor structural signature
  • tensors 中的非 Tensor 参数
  • caller
  • kernel name
  • num_warps / num_stages / max_num_configs 等编译参数

因此,重复构建同一个 JIT kernel 时,可以直接复用已有 handle,避免重复走 symbolic arrangement 和 JIT handle 构建流程。

3. 改进 cache key 正确性

相比原有实现,本 PR 让 cache key 更完整、更稳健。

function hash 现在会考虑:

  • nested functools.partial
  • bound positional / keyword arguments
  • closure values
  • 被引用的 helper function
  • source code 不可用时的 fallback hash

tensor signature 现在会考虑:

  • tensor ndim
  • tensor shape
  • dtype
  • jagged dimension
  • padding / other value
  • constexpr flag
  • constexpr value

这可以减少错误 miss,也可以避免不安全的错误 hit。

4. 迁移 auto-tuning cache

auto_tuner.py 从原来的手写 filesystem cache 迁移到统一 Cache 机制。

auto-tuning cache 现在会结合 project fingerprint 和 Triton version 做隔离,降低跨项目或旧版本 cache 被错误复用的风险。

5. 修复 AOT source cache collision

generation.cache_source 现在会把 kernel_name 纳入 cache digest。

这修复了一个 AOT cold-cache 场景下的 collision 问题:不同 kernel 可能生成相同 source text,但需要不同 kernel symbol。此前在 fresh clone 或 CI cold cache 中,可能会加载到错误的 cached .py 文件,导致 Triton compile 阶段找不到对应 kernel symbol。

6. 移除测试中的重复手写 cache

部分测试中原有的手写 cache 已移除或改为走统一机制,包括:

  • test_pad.py
  • test_max_pool2d.py

这样测试路径和真实用户路径更加一致。

7. 增加回归测试

本 PR 新增和更新了多组测试,覆盖:

  • memory-only cache
  • disk-backed cache
  • FIFO eviction
  • L2 到 L1 的 promotion
  • thread-safe put/get
  • process-safe atomic writes
  • cache key 稳定性
  • nested partial hash
  • closure / global helper hash
  • tensor shape / dtype / constexpr 对 cache key 的影响
  • make() 对非 Tensor 参数的 cache key 区分
  • AOT cache collision 回归

新增能力

合入本 PR 后,ninetoothed 将拥有一套可复用的统一 cache 层。

这套 cache 能力包括:

  • 统一的 memory / disk cache 行为
  • content-sensitive cache key
  • 更安全的并发 disk write
  • 可复用的 function / tensor / project hash 工具
  • 更快的重复 JIT handle 构建
  • 更可靠的 AOT source cache
  • 为后续更多模块迁移到统一 cache 提供基础设施

性能对比

Benchmark 在 GPU 服务器上运行:

  • base:master
  • head:本 PR 分支
  • Python env:/home/haojie/codex-work/vevn
  • CUDA nvcc/usr/local/cuda/bin/nvcc

测试内容是重复调用同一个 ninetoothed.make(..., caller="torch"),衡量 JIT handle cache 能避免多少重复构建开销。

每个 case 运行 3 组,每组 100 次 make() 调用。下表统计首次调用之后 repeated make() 的中位耗时。

Case master 本 PR 提升
pad / slice-copy style kernel 7.087 ms 0.174 ms 约 40.7x
matmul-style kernel 45.628 ms 0.373 ms 约 122.2x

平均 repeated make() 耗时:

Case master mean 本 PR mean 提升
pad / slice-copy style kernel 7.153 ms 0.192 ms 约 37.3x
matmul-style kernel 48.438 ms 0.385 ms 约 125.8x

JIT handle 复用情况:

Case master handle 复用 本 PR handle 复用
pad / slice-copy style kernel 0 / 99 99 / 99
matmul-style kernel 0 / 99 99 / 99

性能结果解读

本 PR 的主要性能收益来自重复 JIT handle 构建路径,而不是已经构建完成后的 GPU kernel runtime。

这类收益会体现在:

  • 测试
  • benchmark
  • 交互式开发
  • 重复模型 / 算子初始化
  • 多次创建相同 shape / config 算子的场景

对于简单 kernel,repeated make() 开销从约 7 ms 降到约 0.17 ms

对于 matmul 这类更复杂的 symbolic arrangement,repeated make() 开销从约 45 ms 降到约 0.37 ms

验证情况

本 PR 已在服务器环境中按 workflow 等价命令完成验证。

Pytest

275 passed, 1 skipped in 1038.83s

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant