[Tile] Use unpacked vector field for Tile16x16/Tile32x32 register storage #722
Open
hughperkins wants to merge 5 commits into
Open
[Tile] Use unpacked vector field for Tile16x16/Tile32x32 register storage #722 — hughperkins wants to merge 5 commits into
hughperkins wants to merge 5 commits into
GitHub Actions / Coverage Report
succeeded
Jun 29, 2026 in 0s
Diff Coverage Report
See details below for per-line coverage annotations.
Details
Coverage Report (c70532649)
| Metric | Value |
|---|---|
| Diff coverage (changed lines only) | 58% |
| Overall project coverage | 74% |
Total: 276 lines, 115 missing, 58% covered
🟢 python/quadrants/__init__.py (100%)
🟢 59 from quadrants.lang.simt._tile import outer # noqa: I001 # pylint: disable=import-outside-toplevel
🟢 python/quadrants/lang/simt/__init__.py (100%)
6 from quadrants.lang.simt._tile import Tile16x16Proxy as Tile16x16
7 from quadrants.lang.simt._tile import Tile32x32Proxy as Tile32x32
🟢 13 if name in ("Tile16x16", "Tile32x32"):
🟢 14 from quadrants.lang.simt._tile import ( # pylint: disable=import-outside-toplevel
🟢 19 proxy = Tile16x16Proxy if name == "Tile16x16" else Tile32x32Proxy
🟢 20 globals()[name] = proxy
🟢 21 return proxy
🔴 python/quadrants/lang/simt/_tile.py (56%)
1 # pyright: reportInvalidTypeForm=false
2
3 """
4 Register-resident NxN tile operations.
5
6 Each tile is an NxN matrix distributed across N threads in a subgroup, one row per thread, with each row stored in N
7 scalar registers held in an unpacked vector field (``self.r``). Cross-thread communication uses subgroup shuffles --
8 no shared memory needed.
9
10 A single factory ``_make_tile_class(N, dtype)`` builds the tile dataclass for both supported tile sizes (N == 16 and
11 N == 32). The user-facing entry points are the proxies ``qd.simt.Tile16x16`` and ``qd.simt.Tile32x32``, which defer
12 dtype resolution to kernel compile time (defaulting to the runtime ``default_fp``).
13
14 The thread's lane index (tid) is obtained internally via ``subgroup.invocation_id()``, so callers never need to pass
15 it. See docs/source/user_guide/tile.md for usage documentation.
16 """
17
🟢 18 from typing import TYPE_CHECKING as _TYPE_CHECKING
🟢 19 from typing import Any, NoReturn
20
🟢 21 import quadrants as qd
22
🟢 23 if _TYPE_CHECKING:
24
🔴 25 class _TileProto: # noqa: E303
26 """Static type stub so pyright sees TileNxN methods correctly (shared by Tile16x16 and Tile32x32)."""
27
🔴 28 SIZE: int
29
30 def __init__(self, *args: Any, **kwargs: Any) -> None: ... # noqa: E704
31 @classmethod
32 def zeros(cls) -> "_TileProto": ... # noqa: E704
33 @classmethod
34 def eye(cls) -> "_TileProto": ... # noqa: E704
35 def eye_(self) -> None: ... # noqa: E704
36 def cholesky_(self, eps: Any) -> None: ... # noqa: E704
37 def solve_triangular_(self, B: "_TileProto", lower: bool = True) -> None: ... # noqa: E704
38 def _get_col(self, k: Any) -> Any: ... # noqa: E704
39 def _set_col(self, k: Any, val: Any) -> None: ... # noqa: E704
40 def _load(self, arr: Any, row_start: Any, row_end: Any, col_start: Any, col_end: Any) -> None: ... # noqa: E704
41 def _store(
42 self, arr: Any, row_start: Any, row_end: Any, col_start: Any, col_end: Any
43 ) -> None: ... # noqa: E704
44 def _load3d(
45 self, arr: Any, batch: Any, row_start: Any, row_end: Any, col_start: Any, col_end: Any
46 ) -> None: ... # noqa: E704
47 def _store3d(
48 self, arr: Any, batch: Any, row_start: Any, row_end: Any, col_start: Any, col_end: Any
49 ) -> None: ... # noqa: E704
50 def _ger_sub(self, a: Any, b: Any) -> None: ... # noqa: E704
51 def _trsm(self, L: "_TileProto") -> None: ... # noqa: E704
52 def __isub__(self, other: Any) -> "_TileProto": ... # noqa: E704
53 def __getitem__(self, key: Any) -> Any: ... # noqa: E704
54 def __setitem__(self, key: Any, value: Any) -> None: ... # noqa: E704
55
56
🟢 57 class _OuterProduct:
58 """Deferred outer product proxy for use with augmented assignment on a Tile.
59
60 Created by qd.outer(a, b). Not a quadrants expression -- only valid as the RHS of ``tile -= qd.outer(a, b)``.
61 """
62
🟢 63 _qd_is_deferred = True
64
🟢 65 def __init__(self, a: Any, b: Any) -> None:
🟢 66 self.a = a
🟢 67 self.b = b
68
🟢 69 def __add__(self, other: Any) -> NoReturn:
🟢 70 raise TypeError("OuterProduct does not support composition; apply each update separately")
71
🟢 72 def __radd__(self, other: Any) -> NoReturn:
🔴 73 raise TypeError("OuterProduct does not support composition; apply each update separately")
74
75
🟢 76 def outer(a: Any, b: Any) -> _OuterProduct:
77 """Create a deferred outer product for use with Tile augmented assignment.
78
79 Usage::
80
81 t -= qd.outer(a, b) # equivalent to t._ger_sub(a, b)
82 t -= qd.outer(v, v) # symmetric case (a == b)
83 """
🟢 84 return _OuterProduct(a, b)
85
86
🟢 87 class _DeferredProxyMixin:
88 """Raises clear errors if a deferred tile proxy is accidentally used as a value."""
89
🟢 90 _proxy_description = "Tile proxy"
91
🟢 92 def _misuse(self, op: str = "used") -> NoReturn:
🟢 93 raise TypeError(
94 f"{self._proxy_description} was {op}, but it is only valid in tile operations (tile[:] = ..., ... = tile, qd.outer(...))"
95 )
96
🟢 97 def __add__(self, other: Any) -> NoReturn:
🟢 98 self._misuse("added")
99
🟢 100 def __radd__(self, other: Any) -> NoReturn:
🟢 101 self._misuse("added")
102
🟢 103 def __sub__(self, other: Any) -> NoReturn:
🟢 104 self._misuse("subtracted")
105
🟢 106 def __mul__(self, other: Any) -> NoReturn:
🟢 107 self._misuse("multiplied")
108
🟢 109 def __getitem__(self, key: Any) -> NoReturn:
🟢 110 self._misuse("subscripted")
111
🟢 112 def __repr__(self) -> str:
🟢 113 return f"<{self._proxy_description} — not a value; use with tile[:] = ... or qd.outer(...)>"
114
115
🟢 116 class _TileSliceProxy(_DeferredProxyMixin):
117 """Deferred 2D/3D array slice for tile load/store.
118
119 Created by subscripting a Field or ndarray with 2D slices, e.g. ``arr[row_start:row_stop, col_start:col_stop]``.
120 Not a quadrants expression -- only valid as the RHS of a tile assignment (load) or as the LHS target (store).
121 """
122
🟢 123 _qd_is_deferred = True
🟢 124 _proxy_description = "Array slice proxy (arr[r0:r1, c0:c1])"
125
🟢 126 def __init__(
127 self, arr: Any, row_start: Any, row_stop: Any, col_start: Any, col_stop: Any, batch_idx: Any = None
128 ) -> None:
🟢 129 self.arr = arr
🟢 130 self.row_start = row_start
🟢 131 self.row_stop = row_stop
🟢 132 self.col_start = col_start
🟢 133 self.col_stop = col_stop
🟢 134 self.batch_idx = batch_idx
135
🟢 136 def _assign(self, tile: Any) -> None:
137 """Store path: arr[r:r+n_rows, c:c+n_cols] = tile."""
🟢 138 if self.batch_idx is not None:
🟢 139 tile._store3d(self.arr, self.batch_idx, self.row_start, self.row_stop, self.col_start, self.col_stop)
140 else:
🟢 141 tile._store(self.arr, self.row_start, self.row_stop, self.col_start, self.col_stop)
142
143
🟢 144 class _VecSliceProxy(_DeferredProxyMixin):
145 """Deferred column-vector load from a 2D/3D array.
146
147 Created by ``arr[row_start:row_stop, col]`` or ``arr[batch_idx, row_start:row_stop, col]``.
148 Each subgroup thread loads one element; out-of-range threads get 0.
149 Only valid as an argument to ``qd.outer()`` in tile augmented assignment.
150 """
151
🟢 152 _qd_is_deferred = True
🟢 153 _proxy_description = "Vec slice proxy (arr[r0:r1, col])"
154
🟢 155 def __init__(self, arr: Any, row_start: Any, row_stop: Any, col: Any, batch_idx: Any = None) -> None:
🟢 156 self.arr = arr
🟢 157 self.row_start = row_start
🟢 158 self.row_stop = row_stop
🟢 159 self.col = col
🟢 160 self.batch_idx = batch_idx
161
162
🟢 163 class _TileRefProxy:
164 """Proxy returned by tile[:] for the LHS of a load assignment.
165
166 Enables ``tile[:] = arr[r:r+N, c:c+N]``. The ``[:]`` is required to distinguish in-place tile loads from
167 variable rebinding.
168 """
169
🟢 170 _qd_is_deferred = True
171
🟢 172 def __init__(self, tile: Any) -> None:
🟢 173 self.tile = tile
174
🟢 175 def _assign(self, value: Any) -> None:
176 """Load path: tile[:] = arr[r:r+n, c:c+n]. Dispatches to _load or _load3d."""
🟢 177 if isinstance(value, _TileSliceProxy):
🟢 178 if value.batch_idx is not None:
🟢 179 self.tile._load3d(
180 value.arr, value.batch_idx, value.row_start, value.row_stop, value.col_start, value.col_stop
181 )
182 else:
🟢 183 self.tile._load(value.arr, value.row_start, value.row_stop, value.col_start, value.col_stop)
184 else:
🔴 185 raise TypeError(f"Tile[:] can only be assigned from an array slice, got {type(value)}")
186
187
🟢 188 _tile_cache: dict = {}
189
190
🟢 191 def _make_tile(N: int, dtype=None) -> "type[_TileProto]":
192 """Create a TileNxN dataclass whose registers use the given scalar dtype (qd.f32 or qd.f64).
193
194 This is an internal factory. Use ``qd.simt.Tile16x16`` / ``qd.simt.Tile32x32`` (the proxies) instead.
195 """
🟢 196 if dtype is None:
🔴 197 dtype = qd.f32
🟢 198 key = (N, dtype)
🟢 199 if key in _tile_cache:
🟢 200 return _tile_cache[key] # pyright: ignore[reportReturnType]
🟢 201 cls = _make_tile_class(N, dtype)
🟢 202 _tile_cache[key] = cls
🟢 203 return cls # pyright: ignore[reportReturnType]
204
205
🟢 206 def _make_tile_class(N: int, dtype):
🟢 207 name = f"Tile{N}x{N}"
208
🟢 209 class _Tile:
210 """An NxN tile distributed one row per subgroup thread, with each row held in N scalar registers via an
211 unpacked vector field. ``TileNxN()`` creates a zero tile."""
212
🟢 213 r: qd.types.vector(N, dtype, unpacked=True)
214
🟢 215 @qd.func
🟢 216 def _load(self, arr: qd.template(), row_start, row_stop, col_start, col_stop):
217 """Load from a 2D array within [row_start, row_stop) x [col_start, col_stop).
218
219 Each thread loads arr[row_start + tid, col_start:col_stop]. Threads where row_start + tid >= row_stop
220 skip the load (tile row unchanged).
221 """
🔴 222 arr_row_stop = arr.shape[0]
🔴 223 if arr_row_stop < row_stop:
🔴 224 row_stop = arr_row_stop
🔴 225 row = row_start + qd.simt.subgroup.invocation_id()
🔴 226 if row < row_stop:
🔴 227 arr_col_stop = arr.shape[1]
🔴 228 if arr_col_stop < col_stop:
🔴 229 col_stop = arr_col_stop
🔴 230 for j in qd.static(range(N)):
🔴 231 if col_start + j < col_stop:
🔴 232 self.r[j] = arr[row, col_start + j]
233
🟢 234 @qd.func
🟢 235 def _load3d(self, arr: qd.template(), batch, row_start, row_stop, col_start, col_stop):
236 """Load from a 3D array within [row_start, row_stop) x [col_start, col_stop).
237
238 Each thread loads arr[batch, row_start+tid, col_start:col_stop]. Threads where row_start + tid >=
239 row_stop skip the load (tile row unchanged).
240 """
🔴 241 arr_row_stop = arr.shape[1]
🔴 242 if arr_row_stop < row_stop:
🔴 243 row_stop = arr_row_stop
🔴 244 row = row_start + qd.simt.subgroup.invocation_id()
🔴 245 if row < row_stop:
🔴 246 arr_col_stop = arr.shape[2]
🔴 247 if arr_col_stop < col_stop:
🔴 248 col_stop = arr_col_stop
🔴 249 for j in qd.static(range(N)):
🔴 250 if col_start + j < col_stop:
🔴 251 self.r[j] = arr[batch, row, col_start + j]
252
🟢 253 @qd.func
🟢 254 def _store(self, arr: qd.template(), row_start, row_stop, col_start, col_stop):
255 """Store to a 2D array within [row_start, row_stop) x [col_start, col_stop).
256
257 Each thread stores to arr[row_start + tid, col_start:col_stop]. Threads where row_start + tid >=
258 row_stop skip the store.
259 """
🔴 260 arr_row_stop = arr.shape[0]
🔴 261 if arr_row_stop < row_stop:
🔴 262 row_stop = arr_row_stop
🔴 263 row = row_start + qd.simt.subgroup.invocation_id()
🔴 264 if row < row_stop:
🔴 265 arr_col_stop = arr.shape[1]
🔴 266 if arr_col_stop < col_stop:
🔴 267 col_stop = arr_col_stop
🔴 268 for j in qd.static(range(N)):
🔴 269 if col_start + j < col_stop:
🔴 270 arr[row, col_start + j] = self.r[j]
271
🟢 272 @qd.func
🟢 273 def _store3d(self, arr: qd.template(), batch, row_start, row_stop, col_start, col_stop):
274 """Store to a 3D array within [row_start, row_stop) x [col_start, col_stop).
275
276 Each thread stores to arr[batch, row_start+tid, col_start:col_stop]. Threads where row_start + tid >=
277 row_stop skip the store.
278 """
🔴 279 arr_row_stop = arr.shape[1]
🔴 280 if arr_row_stop < row_stop:
🔴 281 row_stop = arr_row_stop
🔴 282 row = row_start + qd.simt.subgroup.invocation_id()
🔴 283 if row < row_stop:
🔴 284 arr_col_stop = arr.shape[2]
🔴 285 if arr_col_stop < col_stop:
🔴 286 col_stop = arr_col_stop
🔴 287 for j in qd.static(range(N)):
🔴 288 if col_start + j < col_stop:
🔴 289 arr[batch, row, col_start + j] = self.r[j]
290
🟢 291 @qd.func
🟢 292 def eye_(self):
293 """Set this tile to the NxN identity matrix. Each thread sets its diagonal element to 1.0 and all
294 others to 0.0."""
🔴 295 tid = qd.simt.subgroup.invocation_id()
🔴 296 for j in qd.static(range(N)):
🔴 297 self.r[j] = 1.0 if tid == j else 0.0
298
🟢 299 @qd.func
🟢 300 def _ger_sub(self, a, b):
301 """General rank-1 subtract in-place: self -= a @ b^T."""
🔴 302 for j in qd.static(range(N)):
🔴 303 bc = qd.simt.subgroup.shuffle(b, qd.u32(j))
🔴 304 self.r[j] = self.r[j] - a * bc
305
🟢 306 @qd.func
🟢 307 def cholesky_(self, eps):
308 """In-place NxN Cholesky factorization via subgroup shuffles.
309
310 On return, the lower triangle holds L such that A = L @ L^T. Diagonal clamped to
311 sqrt(max(value, eps)) for numerical stability.
312 """
313 # ``k`` and ``j`` are wrapped in qd.static so the ``if k > j`` predicate folds at compile time and the
314 # ``self.r[k]`` / ``self.r[j]`` accesses resolve to a single unpacked-register slot per use (no runtime
315 # cascade). The per-lane row-norm used for the diagonal update is carried in ``my_norm_sq``, so each
316 # diagonal step is O(1) rather than O(k). The off-diagonal ``dot`` is split into two interleaved partial
317 # sums (``dot0`` / ``dot1``) so the back-to-back FMA dependency chain is cut in half, exposing more
318 # instruction-level parallelism.
🔴 319 tid = qd.i32(qd.simt.subgroup.invocation_id())
🔴 320 my_norm_sq = qd.cast(0.0, dtype)
🔴 321 for k in qd.static(range(N)):
🔴 322 diag_val = qd.cast(0.0, dtype)
🔴 323 if tid == k:
🔴 324 diag_val = qd.sqrt(qd.max(self.r[k] - my_norm_sq, eps))
🔴 325 self.r[k] = diag_val
326
🔴 327 diag_k = qd.simt.subgroup.shuffle(diag_val, qd.u32(k))
328
🔴 329 dot0 = qd.cast(0.0, dtype)
🔴 330 dot1 = qd.cast(0.0, dtype)
🔴 331 for j in qd.static(range(N)):
🔴 332 if k > j:
🔴 333 my_col = self.r[j]
🔴 334 Lkj = qd.simt.subgroup.shuffle(my_col, qd.u32(k))
🔴 335 if j % 2 == 0:
🔴 336 dot0 += Lkj * my_col # type: ignore[reportOperatorIssue]
337 else:
🔴 338 dot1 += Lkj * my_col # type: ignore[reportOperatorIssue]
🔴 339 dot = dot0 + dot1
340
🔴 341 new_val = qd.cast(0.0, dtype)
🔴 342 if tid > k: # type: ignore[reportOperatorIssue]
🔴 343 new_val = (self.r[k] - dot) / diag_k # type: ignore[reportOperatorIssue]
🔴 344 self.r[k] = new_val
🔴 345 if tid > k: # type: ignore[reportOperatorIssue]
🔴 346 my_norm_sq += new_val * new_val
347
🟢 348 @qd.func
🟢 349 def _get_col(self, k):
350 """Read register column ``k`` at runtime via a static-unrolled cascade.
351
352 The unpacked vector field rejects runtime indices, so the cascade is emitted explicitly. With ``k`` a
353 runtime int and ``kk`` a python-int from ``qd.static``, the body of each iteration becomes a guarded
354 single-slot read; LLVM later selects on ``k`` to pick the matching slot. Used by ``_trsm`` so the outer
355 loop can be a runtime ``range(N)`` (LLVM picks the unroll factor) rather than the fully-unrolled
356 ``qd.static(range(N))`` that spikes register pressure.
357 """
🔴 358 val = qd.cast(0.0, dtype)
🔴 359 for kk in qd.static(range(N)):
🔴 360 if k == kk:
🔴 361 val = self.r[kk]
🔴 362 return val
363
🟢 364 @qd.func
🟢 365 def _set_col(self, k, val):
366 """Write register column ``k`` at runtime via a static-unrolled cascade. See ``_get_col`` for rationale."""
🔴 367 for kk in qd.static(range(N)):
🔴 368 if k == kk:
🔴 369 self.r[kk] = val
370
🟢 371 @qd.func
🟢 372 def _trsm(self, L):
373 In-place triangular solve: solve X @ L^T = B, where B is the original contents of self.
374
375 L is a TileNxN holding the lower-triangular Cholesky factor (from cholesky_). On return, self holds the
376 solution X.
377
378 The outer loop uses ``range(N)`` (runtime), not ``qd.static(range(N))``, so LLVM can pick the unroll
379 factor: fully unrolling the N*N body explodes the live set and pushes ~37% more registers into
380 the kernel, causing measurable perf loss on the blocked Cholesky benchmark (e.g. ~9% slower on
381 ``misc/demos/cholesky_blocked.py`` for N=92). The inner ``j`` loop is also ``range(N)`` for the same
382 reason. Runtime access into the unpacked-vector field goes through ``_get_col`` / ``_set_col`` which
383 emit explicit cascades over ``self.r[kk]`` for static ``kk``.
384 """
🔴 385 for c in range(N):
🔴 386 dot = qd.cast(0.0, dtype)
🔴 387 for j in range(N):
🔴 388 if c > j:
🔴 389 Lkj = qd.simt.subgroup.shuffle(L._get_col(j), qd.u32(c))
🔴 390 dot += self._get_col(j) * Lkj # type: ignore[reportOperatorIssue]
391
🔴 392 diag_c = qd.simt.subgroup.shuffle(L._get_col(c), qd.u32(c))
🔴 393 new_val = (self._get_col(c) - dot) / diag_c # type: ignore[reportOperatorIssue]
🔴 394 self._set_col(c, new_val)
395
🟢 396 def solve_triangular_(self, B: Any, lower: bool = True) -> None:
397 """Triangular solve: X @ self^T = B, storing result X in B in-place.
398
399 self must be lower-triangular and non-singular (all diagonal elements non-zero). Passing a singular
400 matrix causes division by zero, producing inf/NaN without warning. Only lower=True is supported.
401 """
🟢 402 if not lower:
🟢 403 raise TypeError(f"{name}.solve_triangular_: only lower=True is supported")
🟢 404 B._trsm(self)
405
🟢 406 @qd.func
🟢 407 def _resolve_vec2d(self, arr: qd.template(), row_start, row_stop, col):
408 """Load one scalar per thread from a 2D array column, clamped to array bounds."""
🔴 409 tid = qd.i32(qd.simt.subgroup.invocation_id())
🔴 410 arr_row_stop = arr.shape[0]
🔴 411 if arr_row_stop < row_stop:
🔴 412 row_stop = arr_row_stop
413 # Use qd.cast, not dtype(0.0): the AST transformer only treats a call as a type construction when
414 # id(dtype) is in primitive_types.type_ids, but a dtype resolved from a deep-copied default_fp (e.g.
415 # after qd.init(default_fp=qd.f32)) has a different id and falls through to a raw call, raising
416 # "Quadrants data types cannot be called outside Quadrants kernels". qd.cast is identity-independent
417 # and folds to the same typed constant.
🔴 418 v = qd.cast(0.0, dtype)
🔴 419 if row_start + tid < row_stop:
🔴 420 v = arr[row_start + tid, col]
🔴 421 return v
422
🟢 423 @qd.func
🟢 424 def _resolve_vec3d(self, arr: qd.template(), batch, row_start, row_stop, col):
425 """Load one scalar per thread from a 3D array column, clamped to array bounds."""
🔴 426 tid = qd.i32(qd.simt.subgroup.invocation_id())
🔴 427 arr_row_stop = arr.shape[1]
🔴 428 if arr_row_stop < row_stop:
🔴 429 row_stop = arr_row_stop
🔴 430 v = qd.cast(0.0, dtype) # see _resolve_vec2d for why qd.cast (not dtype(0.0))
🔴 431 if row_start + tid < row_stop:
🔴 432 v = arr[batch, row_start + tid, col]
🔴 433 return v
434
🟢 435 def _resolve_vec_proxy(self, proxy: _VecSliceProxy) -> Any:
436 """Materialize a _VecSliceProxy into a scalar by dispatching to _resolve_vec2d or _resolve_vec3d."""
🟢 437 if proxy.batch_idx is not None:
🟢 438 return self._resolve_vec3d(proxy.arr, proxy.batch_idx, proxy.row_start, proxy.row_stop, proxy.col)
🟢 439 return self._resolve_vec2d(proxy.arr, proxy.row_start, proxy.row_stop, proxy.col)
440
🟢 441 def _augassign(self, other: Any, op: str) -> None:
442 """Handle augmented assignment (e.g. tile -= qd.outer(a, b)).
443
444 Resolves _VecSliceProxy arguments and dispatches to _ger_sub. Only 'Sub' is supported.
445 """
🟢 446 if isinstance(other, _OuterProduct):
🟢 447 if op == "Sub":
🟢 448 a_orig = other.a
🟢 449 b_orig = other.b
🟢 450 a = self._resolve_vec_proxy(a_orig) if isinstance(a_orig, _VecSliceProxy) else a_orig
🟢 451 b = (
452 a
453 if (b_orig is a_orig)
454 else (self._resolve_vec_proxy(b_orig) if isinstance(b_orig, _VecSliceProxy) else b_orig)
455 )
🟢 456 self._ger_sub(a, b)
457 else:
🟢 458 raise TypeError(f"{name}: unsupported augmented assignment op '{op}' with outer product")
459 else:
🟢 460 raise TypeError(f"{name}: unsupported augmented assignment with {type(other)}")
461
🟢 462 _Tile.__name__ = f"_{name}"
🟢 463 _Tile.__qualname__ = f"_make_tile_class.<locals>._{name}"
464
465 # StructType.__call__ already defaults missing args to 0, so Tile() produces a zero-initialized tile
466 # without needing default values in the class definition (which @qd.dataclass doesn't support).
🟢 467 result = qd.dataclass(_Tile)
🟢 468 result.SIZE = N # type: ignore[reportAttributeAccessIssue]
🟢 469 result.zeros = result # type: ignore[reportAttributeAccessIssue]
470
🟢 471 @qd.func
🟢 472 def _eye():
🔴 473 t = result()
🔴 474 t.eye_() # type: ignore[reportAttributeAccessIssue]
🔴 475 return t
476
🟢 477 result.eye = _eye # type: ignore[reportAttributeAccessIssue]
🟢 478 return result
479
480
🟢 481 class _TileProxy:
482 """Proxy for dtype-at-point-of-use tile creation.
483
484 Use as ``qd.simt.Tile16x16.zeros(dtype=qd.f32)`` or ``qd.simt.Tile32x32.zeros(dtype=qd.f32)`` inside a kernel.
485 The dtype is resolved at kernel compilation time, defaulting to the compile config's ``default_fp`` if omitted.
486 """
487
🟢 488 def __init__(self, N: int) -> None:
🟢 489 self._N = N
🟢 490 self.SIZE = N
491
🟢 492 def _resolve(self, dtype):
🟢 493 from quadrants.lang import impl # pylint: disable=import-outside-toplevel
🟢 494 from quadrants.lang.exception import ( # pylint: disable=import-outside-toplevel
495 QuadrantsSyntaxError,
496 )
497
🟢 498 arch = impl.current_cfg().arch
🟢 499 if arch in (qd.cpu, qd.x64, getattr(qd, "arm64", None)):
🟢 500 raise QuadrantsSyntaxError(
501 f"Tile{self._N}x{self._N} requires a GPU backend (cuda, metal, vulkan, amdgpu). "
502 f"Current arch is {arch}."
503 )
🟢 504 if dtype is None:
🟢 505 dtype = impl.get_runtime().default_fp
🟢 506 return _make_tile(self._N, dtype)
507
🟢 508 def zeros(self, *, dtype=None):
509 """Zero-initialized tile."""
🟢 510 return self._resolve(dtype)()
511
🟢 512 def eye(self, *, dtype=None):
513 """Identity tile (diagonal = 1, rest = 0)."""
🟢 514 return self._resolve(dtype).eye()
515
516
🟢 517 Tile16x16Proxy = _TileProxy(16)
🟢 518 Tile32x32Proxy = _TileProxy(32)
🟢 python/quadrants/lang/simt/tile_slicing.py (100%)
🟢 9 from quadrants.lang.simt._tile import (
10 _tile_cache,
🟢 20 return any(isinstance(value, t) for t in _tile_cache.values())
🟢 25 return bool(_tile_cache)
🟢 tests/python/test_tile.py (100%)
🟢 12 from quadrants.lang.simt import _tile
🟢 13 from quadrants.lang.simt._tile import (
14 _make_tile,
🟢 30 import functools as _functools # noqa: E402
35 _types.SimpleNamespace(
36 proxy=qd.simt.Tile16x16, make=_functools.partial(_make_tile, 16), size=16, m_size=40, name="tile16"
37 ),
41 _types.SimpleNamespace(
42 proxy=qd.simt.Tile32x32, make=_functools.partial(_make_tile, 32), size=32, m_size=80, name="tile32"
43 ),
431 """_make_tile must return the same object for the same (N, dtype)."""
1550 # The tile-class cache is process-global and keyed by (N, dtype) value, so a previously cached identity-dtype
1551 # class would mask the regression. Drop the entry so the class is rebuilt capturing this non-identity dtype,
1552 # and restore the cache afterwards so we don't leak a deepcopy-keyed entry into other tests.
🟢 1553 cache = _tile._tile_cache
🟢 1554 cache_key = (tdim, nonid_dtype)
🟢 1555 cache.pop(cache_key, None)
🟢 1606 cache.pop(cache_key, None)
Loading