Skip to content

[Tile] Use unpacked vector field for Tile16x16/Tile32x32 register storage#722

Open
hughperkins wants to merge 5 commits into
mainfrom
hp/tiles-use-unpacked-vector
Open

[Tile] Use unpacked vector field for Tile16x16/Tile32x32 register storage#722
hughperkins wants to merge 5 commits into
mainfrom
hp/tiles-use-unpacked-vector

[Tile] Fix two perf/correctness regressions introduced during the unp…

c705326
Select commit
Loading
Failed to load commit list.
Sign in for the full log view
GitHub Actions / Coverage Report succeeded Jun 29, 2026 in 0s

Diff Coverage Report

See details below for per-line coverage annotations.

Details

Coverage Report (c70532649)

Metric Value
Diff coverage (changed lines only) 58%
Overall project coverage 74%

Total: 276 lines, 115 missing, 58% covered

🟢 python/quadrants/__init__.py (100%)
🟢   59          from quadrants.lang.simt._tile import outer  # noqa: I001  # pylint: disable=import-outside-toplevel
🟢 python/quadrants/lang/simt/__init__.py (100%)
      6      from quadrants.lang.simt._tile import Tile16x16Proxy as Tile16x16
      7      from quadrants.lang.simt._tile import Tile32x32Proxy as Tile32x32
🟢   13      if name in ("Tile16x16", "Tile32x32"):
🟢   14          from quadrants.lang.simt._tile import (  # pylint: disable=import-outside-toplevel
🟢   19          proxy = Tile16x16Proxy if name == "Tile16x16" else Tile32x32Proxy
🟢   20          globals()[name] = proxy
🟢   21          return proxy
🔴 python/quadrants/lang/simt/_tile.py (56%)
      1  # pyright: reportInvalidTypeForm=false
      2  
      3  """
      4  Register-resident NxN tile operations.
      5  
      6  Each tile is an NxN matrix distributed across N threads in a subgroup, one row per thread, with each row stored in N
      7  scalar registers held in an unpacked vector field (``self.r``).  Cross-thread communication uses subgroup shuffles --
      8  no shared memory needed.
      9  
     10  A single factory ``_make_tile_class(N, dtype)`` builds the tile dataclass for both supported tile sizes (N == 16 and
     11  N == 32).  The user-facing entry points are the proxies ``qd.simt.Tile16x16`` and ``qd.simt.Tile32x32``, which defer
     12  dtype resolution to kernel compile time (defaulting to the runtime ``default_fp``).
     13  
     14  The thread's lane index (tid) is obtained internally via ``subgroup.invocation_id()``, so callers never need to pass
     15  it.  See docs/source/user_guide/tile.md for usage documentation.
     16  """
     17  
🟢   18  from typing import TYPE_CHECKING as _TYPE_CHECKING
🟢   19  from typing import Any, NoReturn
     20  
🟢   21  import quadrants as qd
     22  
🟢   23  if _TYPE_CHECKING:
     24  
🔴   25      class _TileProto:  # noqa: E303
     26          """Static type stub so pyright sees TileNxN methods correctly (shared by Tile16x16 and Tile32x32)."""
     27  
🔴   28          SIZE: int
     29  
     30          def __init__(self, *args: Any, **kwargs: Any) -> None: ...  # noqa: E704
     31          @classmethod
     32          def zeros(cls) -> "_TileProto": ...  # noqa: E704
     33          @classmethod
     34          def eye(cls) -> "_TileProto": ...  # noqa: E704
     35          def eye_(self) -> None: ...  # noqa: E704
     36          def cholesky_(self, eps: Any) -> None: ...  # noqa: E704
     37          def solve_triangular_(self, B: "_TileProto", lower: bool = True) -> None: ...  # noqa: E704
     38          def _get_col(self, k: Any) -> Any: ...  # noqa: E704
     39          def _set_col(self, k: Any, val: Any) -> None: ...  # noqa: E704
     40          def _load(self, arr: Any, row_start: Any, row_end: Any, col_start: Any, col_end: Any) -> None: ...  # noqa: E704
     41          def _store(
     42              self, arr: Any, row_start: Any, row_end: Any, col_start: Any, col_end: Any
     43          ) -> None: ...  # noqa: E704
     44          def _load3d(
     45              self, arr: Any, batch: Any, row_start: Any, row_end: Any, col_start: Any, col_end: Any
     46          ) -> None: ...  # noqa: E704
     47          def _store3d(
     48              self, arr: Any, batch: Any, row_start: Any, row_end: Any, col_start: Any, col_end: Any
     49          ) -> None: ...  # noqa: E704
     50          def _ger_sub(self, a: Any, b: Any) -> None: ...  # noqa: E704
     51          def _trsm(self, L: "_TileProto") -> None: ...  # noqa: E704
     52          def __isub__(self, other: Any) -> "_TileProto": ...  # noqa: E704
     53          def __getitem__(self, key: Any) -> Any: ...  # noqa: E704
     54          def __setitem__(self, key: Any, value: Any) -> None: ...  # noqa: E704
     55  
     56  
🟢   57  class _OuterProduct:
     58      """Deferred outer product proxy for use with augmented assignment on a Tile.
     59  
     60      Created by qd.outer(a, b). Not a quadrants expression -- only valid as the RHS of ``tile -= qd.outer(a, b)``.
     61      """
     62  
🟢   63      _qd_is_deferred = True
     64  
🟢   65      def __init__(self, a: Any, b: Any) -> None:
🟢   66          self.a = a
🟢   67          self.b = b
     68  
🟢   69      def __add__(self, other: Any) -> NoReturn:
🟢   70          raise TypeError("OuterProduct does not support composition; apply each update separately")
     71  
🟢   72      def __radd__(self, other: Any) -> NoReturn:
🔴   73          raise TypeError("OuterProduct does not support composition; apply each update separately")
     74  
     75  
🟢   76  def outer(a: Any, b: Any) -> _OuterProduct:
     77      """Create a deferred outer product for use with Tile augmented assignment.
     78  
     79      Usage::
     80  
     81          t -= qd.outer(a, b)   # equivalent to t._ger_sub(a, b)
     82          t -= qd.outer(v, v)   # symmetric case (a == b)
     83      """
🟢   84      return _OuterProduct(a, b)
     85  
     86  
🟢   87  class _DeferredProxyMixin:
     88      """Raises clear errors if a deferred tile proxy is accidentally used as a value."""
     89  
🟢   90      _proxy_description = "Tile proxy"
     91  
🟢   92      def _misuse(self, op: str = "used") -> NoReturn:
🟢   93          raise TypeError(
     94              f"{self._proxy_description} was {op}, but it is only valid in tile operations (tile[:] = ..., ... = tile, qd.outer(...))"
     95          )
     96  
🟢   97      def __add__(self, other: Any) -> NoReturn:
🟢   98          self._misuse("added")
     99  
🟢  100      def __radd__(self, other: Any) -> NoReturn:
🟢  101          self._misuse("added")
    102  
🟢  103      def __sub__(self, other: Any) -> NoReturn:
🟢  104          self._misuse("subtracted")
    105  
🟢  106      def __mul__(self, other: Any) -> NoReturn:
🟢  107          self._misuse("multiplied")
    108  
🟢  109      def __getitem__(self, key: Any) -> NoReturn:
🟢  110          self._misuse("subscripted")
    111  
🟢  112      def __repr__(self) -> str:
🟢  113          return f"<{self._proxy_description} — not a value; use with tile[:] = ... or qd.outer(...)>"
    114  
    115  
🟢  116  class _TileSliceProxy(_DeferredProxyMixin):
    117      """Deferred 2D/3D array slice for tile load/store.
    118  
    119      Created by subscripting a Field or ndarray with 2D slices, e.g. ``arr[row_start:row_stop, col_start:col_stop]``.
    120      Not a quadrants expression -- only valid as the RHS of a tile assignment (load) or as the LHS target (store).
    121      """
    122  
🟢  123      _qd_is_deferred = True
🟢  124      _proxy_description = "Array slice proxy (arr[r0:r1, c0:c1])"
    125  
🟢  126      def __init__(
    127          self, arr: Any, row_start: Any, row_stop: Any, col_start: Any, col_stop: Any, batch_idx: Any = None
    128      ) -> None:
🟢  129          self.arr = arr
🟢  130          self.row_start = row_start
🟢  131          self.row_stop = row_stop
🟢  132          self.col_start = col_start
🟢  133          self.col_stop = col_stop
🟢  134          self.batch_idx = batch_idx
    135  
🟢  136      def _assign(self, tile: Any) -> None:
    137          """Store path: arr[r:r+n_rows, c:c+n_cols] = tile."""
🟢  138          if self.batch_idx is not None:
🟢  139              tile._store3d(self.arr, self.batch_idx, self.row_start, self.row_stop, self.col_start, self.col_stop)
    140          else:
🟢  141              tile._store(self.arr, self.row_start, self.row_stop, self.col_start, self.col_stop)
    142  
    143  
🟢  144  class _VecSliceProxy(_DeferredProxyMixin):
    145      """Deferred column-vector load from a 2D/3D array.
    146  
    147      Created by ``arr[row_start:row_stop, col]`` or ``arr[batch_idx, row_start:row_stop, col]``.
    148      Each subgroup thread loads one element; out-of-range threads get 0.
    149      Only valid as an argument to ``qd.outer()`` in tile augmented assignment.
    150      """
    151  
🟢  152      _qd_is_deferred = True
🟢  153      _proxy_description = "Vec slice proxy (arr[r0:r1, col])"
    154  
🟢  155      def __init__(self, arr: Any, row_start: Any, row_stop: Any, col: Any, batch_idx: Any = None) -> None:
🟢  156          self.arr = arr
🟢  157          self.row_start = row_start
🟢  158          self.row_stop = row_stop
🟢  159          self.col = col
🟢  160          self.batch_idx = batch_idx
    161  
    162  
🟢  163  class _TileRefProxy:
    164      """Proxy returned by tile[:] for the LHS of a load assignment.
    165  
    166      Enables ``tile[:] = arr[r:r+N, c:n]``.  The ``[:]`` is required to distinguish in-place tile loads from
    167      variable rebinding.
    168      """
    169  
🟢  170      _qd_is_deferred = True
    171  
🟢  172      def __init__(self, tile: Any) -> None:
🟢  173          self.tile = tile
    174  
🟢  175      def _assign(self, value: Any) -> None:
    176          """Load path: tile[:] = arr[r:r+n, c:c+n]. Dispatches to _load or _load3d."""
🟢  177          if isinstance(value, _TileSliceProxy):
🟢  178              if value.batch_idx is not None:
🟢  179                  self.tile._load3d(
    180                      value.arr, value.batch_idx, value.row_start, value.row_stop, value.col_start, value.col_stop
    181                  )
    182              else:
🟢  183                  self.tile._load(value.arr, value.row_start, value.row_stop, value.col_start, value.col_stop)
    184          else:
🔴  185              raise TypeError(f"Tile[:] can only be assigned from an array slice, got {type(value)}")
    186  
    187  
🟢  188  _tile_cache: dict = {}
    189  
    190  
🟢  191  def _make_tile(N: int, dtype=None) -> "type[_TileProto]":
    192      """Create a TileNxN dataclass whose registers use the given scalar dtype (qd.f32 or qd.f64).
    193  
    194      This is an internal factory.  Use ``qd.simt.Tile16x16`` / ``qd.simt.Tile32x32`` (the proxies) instead.
    195      """
🟢  196      if dtype is None:
🔴  197          dtype = qd.f32
🟢  198      key = (N, dtype)
🟢  199      if key in _tile_cache:
🟢  200          return _tile_cache[key]  # pyright: ignore[reportReturnType]
🟢  201      cls = _make_tile_class(N, dtype)
🟢  202      _tile_cache[key] = cls
🟢  203      return cls  # pyright: ignore[reportReturnType]
    204  
    205  
🟢  206  def _make_tile_class(N: int, dtype):
🟢  207      name = f"Tile{N}x{N}"
    208  
🟢  209      class _Tile:
    210          """An NxN tile distributed one row per subgroup thread, with each row held in N scalar registers via an
    211          unpacked vector field.  ``TileNxN()`` creates a zero tile."""
    212  
🟢  213          r: qd.types.vector(N, dtype, unpacked=True)
    214  
🟢  215          @qd.func
🟢  216          def _load(self, arr: qd.template(), row_start, row_stop, col_start, col_stop):
    217              """Load from a 2D array within [row_start, row_stop) x [col_start, col_stop).
    218  
    219              Each thread loads arr[row_start + tid, col_start:col_stop].  Threads where row_start + tid >= row_stop
    220              skip the load (tile row unchanged).
    221              """
🔴  222              arr_row_stop = arr.shape[0]
🔴  223              if arr_row_stop < row_stop:
🔴  224                  row_stop = arr_row_stop
🔴  225              row = row_start + qd.simt.subgroup.invocation_id()
🔴  226              if row < row_stop:
🔴  227                  arr_col_stop = arr.shape[1]
🔴  228                  if arr_col_stop < col_stop:
🔴  229                      col_stop = arr_col_stop
🔴  230                  for j in qd.static(range(N)):
🔴  231                      if col_start + j < col_stop:
🔴  232                          self.r[j] = arr[row, col_start + j]
    233  
🟢  234          @qd.func
🟢  235          def _load3d(self, arr: qd.template(), batch, row_start, row_stop, col_start, col_stop):
    236              """Load from a 3D array within [row_start, row_stop) x [col_start, col_stop).
    237  
    238              Each thread loads arr[batch, row_start+tid, col_start:col_stop].  Threads where row_start + tid >=
    239              row_stop skip the load (tile row unchanged).
    240              """
🔴  241              arr_row_stop = arr.shape[1]
🔴  242              if arr_row_stop < row_stop:
🔴  243                  row_stop = arr_row_stop
🔴  244              row = row_start + qd.simt.subgroup.invocation_id()
🔴  245              if row < row_stop:
🔴  246                  arr_col_stop = arr.shape[2]
🔴  247                  if arr_col_stop < col_stop:
🔴  248                      col_stop = arr_col_stop
🔴  249                  for j in qd.static(range(N)):
🔴  250                      if col_start + j < col_stop:
🔴  251                          self.r[j] = arr[batch, row, col_start + j]
    252  
🟢  253          @qd.func
🟢  254          def _store(self, arr: qd.template(), row_start, row_stop, col_start, col_stop):
    255              """Store to a 2D array within [row_start, row_stop) x [col_start, col_stop).
    256  
    257              Each thread stores to arr[row_start + tid, col_start:col_stop].  Threads where row_start + tid >=
    258              row_stop skip the store.
    259              """
🔴  260              arr_row_stop = arr.shape[0]
🔴  261              if arr_row_stop < row_stop:
🔴  262                  row_stop = arr_row_stop
🔴  263              row = row_start + qd.simt.subgroup.invocation_id()
🔴  264              if row < row_stop:
🔴  265                  arr_col_stop = arr.shape[1]
🔴  266                  if arr_col_stop < col_stop:
🔴  267                      col_stop = arr_col_stop
🔴  268                  for j in qd.static(range(N)):
🔴  269                      if col_start + j < col_stop:
🔴  270                          arr[row, col_start + j] = self.r[j]
    271  
🟢  272          @qd.func
🟢  273          def _store3d(self, arr: qd.template(), batch, row_start, row_stop, col_start, col_stop):
    274              """Store to a 3D array within [row_start, row_stop) x [col_start, col_stop).
    275  
    276              Each thread stores to arr[batch, row_start+tid, col_start:col_stop].  Threads where row_start + tid >=
    277              row_stop skip the store.
    278              """
🔴  279              arr_row_stop = arr.shape[1]
🔴  280              if arr_row_stop < row_stop:
🔴  281                  row_stop = arr_row_stop
🔴  282              row = row_start + qd.simt.subgroup.invocation_id()
🔴  283              if row < row_stop:
🔴  284                  arr_col_stop = arr.shape[2]
🔴  285                  if arr_col_stop < col_stop:
🔴  286                      col_stop = arr_col_stop
🔴  287                  for j in qd.static(range(N)):
🔴  288                      if col_start + j < col_stop:
🔴  289                          arr[batch, row, col_start + j] = self.r[j]
    290  
🟢  291          @qd.func
🟢  292          def eye_(self):
    293              """Set this tile to the NxN identity matrix.  Each thread sets its diagonal element to 1.0 and all
    294              others to 0.0."""
🔴  295              tid = qd.simt.subgroup.invocation_id()
🔴  296              for j in qd.static(range(N)):
🔴  297                  self.r[j] = 1.0 if tid == j else 0.0
    298  
🟢  299          @qd.func
🟢  300          def _ger_sub(self, a, b):
    301              """General rank-1 subtract in-place: self -= a @ b^T."""
🔴  302              for j in qd.static(range(N)):
🔴  303                  bc = qd.simt.subgroup.shuffle(b, qd.u32(j))
🔴  304                  self.r[j] = self.r[j] - a * bc
    305  
🟢  306          @qd.func
🟢  307          def cholesky_(self, eps):
    308              """In-place NxN Cholesky factorization via subgroup shuffles.
    309  
    310              On return, the lower triangle holds L such that A = L @ L^T.  Diagonal clamped to
    311              sqrt(max(value, eps)) for numerical stability.
    312              """
    313              # ``k`` and ``j`` are wrapped in qd.static so the ``if k > j`` predicate folds at compile time and the
    314              # ``self.r[k]`` / ``self.r[j]`` accesses resolve to a single unpacked-register slot per use (no runtime
    315              # cascade).  The per-lane row-norm used for the diagonal update is carried in ``my_norm_sq``, so each
    316              # diagonal step is O(1) rather than O(k).  The off-diagonal ``dot`` is split into two interleaved partial
    317              # sums (``dot0`` / ``dot1``) so the back-to-back FMA dependency chain is cut in half, exposing more
    318              # instruction-level parallelism.
🔴  319              tid = qd.i32(qd.simt.subgroup.invocation_id())
🔴  320              my_norm_sq = qd.cast(0.0, dtype)
🔴  321              for k in qd.static(range(N)):
🔴  322                  diag_val = qd.cast(0.0, dtype)
🔴  323                  if tid == k:
🔴  324                      diag_val = qd.sqrt(qd.max(self.r[k] - my_norm_sq, eps))
🔴  325                      self.r[k] = diag_val
    326  
🔴  327                  diag_k = qd.simt.subgroup.shuffle(diag_val, qd.u32(k))
    328  
🔴  329                  dot0 = qd.cast(0.0, dtype)
🔴  330                  dot1 = qd.cast(0.0, dtype)
🔴  331                  for j in qd.static(range(N)):
🔴  332                      if k > j:
🔴  333                          my_col = self.r[j]
🔴  334                          Lkj = qd.simt.subgroup.shuffle(my_col, qd.u32(k))
🔴  335                          if j % 2 == 0:
🔴  336                              dot0 += Lkj * my_col  # type: ignore[reportOperatorIssue]
    337                          else:
🔴  338                              dot1 += Lkj * my_col  # type: ignore[reportOperatorIssue]
🔴  339                  dot = dot0 + dot1
    340  
🔴  341                  new_val = qd.cast(0.0, dtype)
🔴  342                  if tid > k:  # type: ignore[reportOperatorIssue]
🔴  343                      new_val = (self.r[k] - dot) / diag_k  # type: ignore[reportOperatorIssue]
🔴  344                      self.r[k] = new_val
🔴  345                  if tid > k:  # type: ignore[reportOperatorIssue]
🔴  346                      my_norm_sq += new_val * new_val
    347  
🟢  348          @qd.func
🟢  349          def _get_col(self, k):
    350              """Read register column ``k`` at runtime via a static-unrolled cascade.
    351  
    352              The unpacked vector field rejects runtime indices, so the cascade is emitted explicitly. With ``k`` a
    353              runtime int and ``kk`` a python-int from ``qd.static``, the body of each iteration becomes a guarded
    354              single-slot read; LLVM later selects on ``k`` to pick the matching slot.  Used by ``_trsm`` so the outer
    355              loop can be a runtime ``range(N)`` (LLVM picks the unroll factor) rather than the fully-unrolled
    356              ``qd.static(range(N))`` that spikes register pressure.
    357              """
🔴  358              val = qd.cast(0.0, dtype)
🔴  359              for kk in qd.static(range(N)):
🔴  360                  if k == kk:
🔴  361                      val = self.r[kk]
🔴  362              return val
    363  
🟢  364          @qd.func
🟢  365          def _set_col(self, k, val):
    366              """Write register column ``k`` at runtime via a static-unrolled cascade.  See ``_get_col`` for rationale."""
🔴  367              for kk in qd.static(range(N)):
🔴  368                  if k == kk:
🔴  369                      self.r[kk] = val
    370  
🟢  371          @qd.func
🟢  372          def _trsm(self, L):
    373              """In-place triangular solve: solve self @ L^T = B (original self).
    374  
    375              L is a TileNxN holding the lower-triangular Cholesky factor (from cholesky_).  On return, self holds the
    376              solution X.
    377  
    378              The outer loop uses ``range(N)`` (runtime), not ``qd.static(range(N))``, so LLVM can pick the unroll
    379              factor: fully unrolling the N*N body fully explodes the live set and pushes ~37% more registers into
    380              the kernel, causing measurable perf loss on the blocked Cholesky benchmark (e.g. ~9% slower on
    381              ``misc/demos/cholesky_blocked.py`` for N=92).  The inner ``j`` loop is also ``range(N)`` for the same
    382              reason.  Runtime access into the unpacked-vector field goes through ``_get_col`` / ``_set_col`` which
    383              emit explicit cascades over ``self.r[kk]`` for static ``kk``.
    384              """
🔴  385              for c in range(N):
🔴  386                  dot = qd.cast(0.0, dtype)
🔴  387                  for j in range(N):
🔴  388                      if c > j:
🔴  389                          Lkj = qd.simt.subgroup.shuffle(L._get_col(j), qd.u32(c))
🔴  390                          dot += self._get_col(j) * Lkj  # type: ignore[reportOperatorIssue]
    391  
🔴  392                  diag_c = qd.simt.subgroup.shuffle(L._get_col(c), qd.u32(c))
🔴  393                  new_val = (self._get_col(c) - dot) / diag_c  # type: ignore[reportOperatorIssue]
🔴  394                  self._set_col(c, new_val)
    395  
🟢  396          def solve_triangular_(self, B: Any, lower: bool = True) -> None:
    397              """Triangular solve: X @ self^T = B, storing result X in B in-place.
    398  
    399              self must be lower-triangular and non-singular (all diagonal elements non-zero).  Passing a singular
    400              matrix causes division by zero, producing inf/NaN without warning.  Only lower=True is supported.
    401              """
🟢  402              if not lower:
🟢  403                  raise TypeError(f"{name}.solve_triangular_: only lower=True is supported")
🟢  404              B._trsm(self)
    405  
🟢  406          @qd.func
🟢  407          def _resolve_vec2d(self, arr: qd.template(), row_start, row_stop, col):
    408              """Load one scalar per thread from a 2D array column, clamped to array bounds."""
🔴  409              tid = qd.i32(qd.simt.subgroup.invocation_id())
🔴  410              arr_row_stop = arr.shape[0]
🔴  411              if arr_row_stop < row_stop:
🔴  412                  row_stop = arr_row_stop
    413              # Use qd.cast, not dtype(0.0): the AST transformer only treats a call as a type construction when
    414              # id(dtype) is in primitive_types.type_ids, but a dtype resolved from a deep-copied default_fp (e.g.
    415              # after qd.init(default_fp=qd.f32)) has a different id and falls through to a raw call, raising
    416              # "Quadrants data types cannot be called outside Quadrants kernels".  qd.cast is identity-independent
    417              # and folds to the same typed constant.
🔴  418              v = qd.cast(0.0, dtype)
🔴  419              if row_start + tid < row_stop:
🔴  420                  v = arr[row_start + tid, col]
🔴  421              return v
    422  
🟢  423          @qd.func
🟢  424          def _resolve_vec3d(self, arr: qd.template(), batch, row_start, row_stop, col):
    425              """Load one scalar per thread from a 3D array column, clamped to array bounds."""
🔴  426              tid = qd.i32(qd.simt.subgroup.invocation_id())
🔴  427              arr_row_stop = arr.shape[1]
🔴  428              if arr_row_stop < row_stop:
🔴  429                  row_stop = arr_row_stop
🔴  430              v = qd.cast(0.0, dtype)  # see _resolve_vec2d for why qd.cast (not dtype(0.0))
🔴  431              if row_start + tid < row_stop:
🔴  432                  v = arr[batch, row_start + tid, col]
🔴  433              return v
    434  
🟢  435          def _resolve_vec_proxy(self, proxy: _VecSliceProxy) -> Any:
    436              """Materialize a _VecSliceProxy into a scalar by dispatching to _resolve_vec2d or _resolve_vec3d."""
🟢  437              if proxy.batch_idx is not None:
🟢  438                  return self._resolve_vec3d(proxy.arr, proxy.batch_idx, proxy.row_start, proxy.row_stop, proxy.col)
🟢  439              return self._resolve_vec2d(proxy.arr, proxy.row_start, proxy.row_stop, proxy.col)
    440  
🟢  441          def _augassign(self, other: Any, op: str) -> None:
    442              """Handle augmented assignment (e.g. tile -= qd.outer(a, b)).
    443  
    444              Resolves _VecSliceProxy arguments and dispatches to _ger_sub.  Only 'Sub' is supported.
    445              """
🟢  446              if isinstance(other, _OuterProduct):
🟢  447                  if op == "Sub":
🟢  448                      a_orig = other.a
🟢  449                      b_orig = other.b
🟢  450                      a = self._resolve_vec_proxy(a_orig) if isinstance(a_orig, _VecSliceProxy) else a_orig
🟢  451                      b = (
    452                          a
    453                          if (b_orig is a_orig)
    454                          else (self._resolve_vec_proxy(b_orig) if isinstance(b_orig, _VecSliceProxy) else b_orig)
    455                      )
🟢  456                      self._ger_sub(a, b)
    457                  else:
🟢  458                      raise TypeError(f"{name}: unsupported augmented assignment op '{op}' with outer product")
    459              else:
🟢  460                  raise TypeError(f"{name}: unsupported augmented assignment with {type(other)}")
    461  
🟢  462      _Tile.__name__ = f"_{name}"
🟢  463      _Tile.__qualname__ = f"_make_tile_class.<locals>._{name}"
    464  
    465      # StructType.__call__ already defaults missing args to 0, so Tile() produces a zero-initialized tile
    466      # without needing default values in the class definition (which @qd.dataclass doesn't support).
🟢  467      result = qd.dataclass(_Tile)
🟢  468      result.SIZE = N  # type: ignore[reportAttributeAccessIssue]
🟢  469      result.zeros = result  # type: ignore[reportAttributeAccessIssue]
    470  
🟢  471      @qd.func
🟢  472      def _eye():
🔴  473          t = result()
🔴  474          t.eye_()  # type: ignore[reportAttributeAccessIssue]
🔴  475          return t
    476  
🟢  477      result.eye = _eye  # type: ignore[reportAttributeAccessIssue]
🟢  478      return result
    479  
    480  
🟢  481  class _TileProxy:
    482      """Proxy for dtype-at-point-of-use tile creation.
    483  
    484      Use as ``qd.simt.Tile16x16.zeros(dtype=qd.f32)`` or ``qd.simt.Tile32x32.zeros(dtype=qd.f32)`` inside a kernel.
    485      The dtype is resolved at kernel compilation time, defaulting to the compile config's ``default_fp`` if omitted.
    486      """
    487  
🟢  488      def __init__(self, N: int) -> None:
🟢  489          self._N = N
🟢  490          self.SIZE = N
    491  
🟢  492      def _resolve(self, dtype):
🟢  493          from quadrants.lang import impl  # pylint: disable=import-outside-toplevel
🟢  494          from quadrants.lang.exception import (  # pylint: disable=import-outside-toplevel
    495              QuadrantsSyntaxError,
    496          )
    497  
🟢  498          arch = impl.current_cfg().arch
🟢  499          if arch in (qd.cpu, qd.x64, getattr(qd, "arm64", None)):
🟢  500              raise QuadrantsSyntaxError(
    501                  f"Tile{self._N}x{self._N} requires a GPU backend (cuda, metal, vulkan, amdgpu). "
    502                  f"Current arch is {arch}."
    503              )
🟢  504          if dtype is None:
🟢  505              dtype = impl.get_runtime().default_fp
🟢  506          return _make_tile(self._N, dtype)
    507  
🟢  508      def zeros(self, *, dtype=None):
    509          """Zero-initialized tile."""
🟢  510          return self._resolve(dtype)()
    511  
🟢  512      def eye(self, *, dtype=None):
    513          """Identity tile (diagonal = 1, rest = 0)."""
🟢  514          return self._resolve(dtype).eye()
    515  
    516  
🟢  517  Tile16x16Proxy = _TileProxy(16)
🟢  518  Tile32x32Proxy = _TileProxy(32)
🟢 python/quadrants/lang/simt/tile_slicing.py (100%)
🟢    9  from quadrants.lang.simt._tile import (
     10      _tile_cache,
🟢   20      return any(isinstance(value, t) for t in _tile_cache.values())
🟢   25      return bool(_tile_cache)
🟢 tests/python/test_tile.py (100%)
🟢   12  from quadrants.lang.simt import _tile
🟢   13  from quadrants.lang.simt._tile import (
     14      _make_tile,
🟢   30  import functools as _functools  # noqa: E402
     35          _types.SimpleNamespace(
     36              proxy=qd.simt.Tile16x16, make=_functools.partial(_make_tile, 16), size=16, m_size=40, name="tile16"
     37          ),
     41          _types.SimpleNamespace(
     42              proxy=qd.simt.Tile32x32, make=_functools.partial(_make_tile, 32), size=32, m_size=80, name="tile32"
     43          ),
    431      """_make_tile must return the same object for the same (N, dtype)."""
   1550      # The tile-class cache is process-global and keyed by (N, dtype) value, so a previously cached identity-dtype
   1551      # class would mask the regression.  Drop the entry so the class is rebuilt capturing this non-identity dtype,
   1552      # and restore the cache afterwards so we don't leak a deepcopy-keyed entry into other tests.
🟢 1553      cache = _tile._tile_cache
🟢 1554      cache_key = (tdim, nonid_dtype)
🟢 1555      cache.pop(cache_key, None)
🟢 1606          cache.pop(cache_key, None)