From 982bf481e6f489814f721d8279a35a3517ba884c Mon Sep 17 00:00:00 2001 From: Ryan Huang Date: Sat, 25 Apr 2026 19:53:24 +0000 Subject: [PATCH 1/2] [QDP] Close AMD-vs-CUDA encoder parity gap: add iqp, iqp-z, phase MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit CUDA QdpEngine accepts amplitude, angle, basis, iqp, iqp-z, and phase. The Triton AMD path only implemented the first three, so AMD users hit a hard error on the IQP- and phase-family encodings (e.g. SVHN-IQP). This adds vectorized PyTorch implementations for the missing methods on TritonAmdEngine, dispatched through the same ``encode(method=...)`` contract: - ``iqp`` — full ZZ entanglement: phase = Σ x_i·data_i + Σ_{i bool: @dataclass class TritonAmdEngine: - """AMD backend implementing amplitude/angle/basis encoders.""" + """AMD backend implementing amplitude/angle/basis/iqp/iqp-z/phase encoders.""" device_id: int = 0 precision: str = "float32" @@ -195,6 +196,85 @@ def encode_basis(self, data: Any, num_qubits: int) -> Any: ) return out + def encode_iqp( + self, + data: Any, + num_qubits: int, + *, + enable_zz: bool = True, + ) -> Any: + torch_mod = self._require_torch() + real_dtype = self._real_dtype() + params = self._to_2d(data, dtype=real_dtype) + batch, width = params.shape + + n = num_qubits + expected = n + n * (n - 1) // 2 if enable_zz else n + if width != expected: + variant = "ZZ" if enable_zz else "Z-only" + raise ValueError( + f"IQP encoding ({variant}) expects {expected} parameters for {n} qubits, got {width}." 
+            )
+
+        state_len = 1 << n
+        device = params.device
+
+        # θ(x) = Σ_i x_i * data[i] (+ Σ_{i<j} x_i·x_j·data[ij] if ZZ).
+        b_idx = torch_mod.arange(state_len, device=device, dtype=torch_mod.int64)
+        x_bits = ((b_idx.unsqueeze(1) >> torch_mod.arange(n, device=device)) & 1).to(
+            real_dtype
+        )
+        z_params = params[:, :n]
+        phase = torch_mod.matmul(z_params, x_bits.T)
+
+        if enable_zz and n >= 2:
+            zz_params = params[:, n:]
+            pairs = torch_mod.combinations(torch_mod.arange(n, device=device), r=2)
+            pair_matrix = x_bits[:, pairs[:, 0]] * x_bits[:, pairs[:, 1]]
+            phase = phase + torch_mod.matmul(zz_params, pair_matrix.T)
+
+        # f[x] = exp(i·θ(x)), then n-stage Walsh-Hadamard butterfly, then 1/2^n.
+        complex_dtype = self._complex_dtype()
+        f = torch_mod.complex(torch_mod.cos(phase), torch_mod.sin(phase)).to(
+            complex_dtype
+        )
+        for s in range(n):
+            stride = 1 << s
+            block = 1 << (s + 1)
+            f = f.view(batch, state_len // block, block)
+            lo = f[:, :, :stride]
+            hi = f[:, :, stride:]
+            f = torch_mod.cat([lo + hi, lo - hi], dim=2)
+        f = f.reshape(batch, state_len)
+
+        norm_factor = 1.0 / float(state_len)
+        return f * norm_factor
+
+    def encode_phase(self, data: Any, num_qubits: int) -> Any:
+        torch_mod = self._require_torch()
+        real_dtype = self._real_dtype()
+        phases = self._to_2d(data, dtype=real_dtype)
+        batch, width = phases.shape
+        if width != num_qubits:
+            raise ValueError(
+                f"Phase encoding expects sample size {num_qubits} (=num_qubits), got {width}."
+ ) + + state_len = 1 << num_qubits + device = phases.device + + # φ(b) = Σ_k phases[k] · b_k → state[b] = (1/√2^n) · exp(i·φ(b)) + b_idx = torch_mod.arange(state_len, device=device, dtype=torch_mod.int64) + bits = ( + (b_idx.unsqueeze(1) >> torch_mod.arange(num_qubits, device=device)) & 1 + ).to(real_dtype) + phi = torch_mod.matmul(phases, bits.T) + + norm = math.pow(math.sqrt(0.5), num_qubits) + re = torch_mod.cos(phi) * norm + im = torch_mod.sin(phi) * norm + return torch_mod.complex(re, im).to(self._complex_dtype()) + def encode( self, data: Any, @@ -210,6 +290,13 @@ def encode( return self.encode_angle(data, num_qubits) if method == "basis": return self.encode_basis(data, num_qubits) + if method == "iqp": + return self.encode_iqp(data, num_qubits, enable_zz=True) + if method == "iqp-z": + return self.encode_iqp(data, num_qubits, enable_zz=False) + if method == "phase": + return self.encode_phase(data, num_qubits) raise ValueError( - f"Unsupported encoding '{encoding_method}'. triton_amd supports amplitude, angle, basis." + f"Unsupported encoding '{encoding_method}'. " + "triton_amd supports amplitude, angle, basis, iqp, iqp-z, phase." ) diff --git a/qdp/qdp-python/tests/test_triton_amd_backend.py b/qdp/qdp-python/tests/test_triton_amd_backend.py index 1263f65e9d..4971919103 100644 --- a/qdp/qdp-python/tests/test_triton_amd_backend.py +++ b/qdp/qdp-python/tests/test_triton_amd_backend.py @@ -14,9 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import math + import pytest import torch from qumat_qdp import QdpEngine, is_triton_amd_available +from qumat_qdp.torch_ref import iqp_encode as _torch_ref_iqp from qumat_qdp.triton_amd import TritonAmdEngine @@ -50,6 +53,21 @@ def _torch_angle_ref(angles: torch.Tensor, num_qubits: int) -> torch.Tensor: return torch.complex(amp, torch.zeros_like(amp)) +def _torch_phase_ref(phases: torch.Tensor, num_qubits: int) -> torch.Tensor: + real_dtype = phases.dtype + batch = phases.shape[0] + state_len = 1 << num_qubits + idx = torch.arange(state_len, device=phases.device, dtype=torch.int64) + bits = ( + (idx.unsqueeze(1) >> torch.arange(num_qubits, device=phases.device)) & 1 + ).to(real_dtype) + phi = phases @ bits.T + norm = math.pow(math.sqrt(0.5), num_qubits) + out = torch.complex(torch.cos(phi) * norm, torch.sin(phi) * norm) + assert out.shape == (batch, state_len) + return out + + def _torch_basis_ref(idx: torch.Tensor, num_qubits: int) -> torch.Tensor: idx = idx.to(torch.int64) batch = idx.numel() @@ -187,3 +205,164 @@ def test_unified_router_contract_returns_torch_tensor() -> None: assert isinstance(qt, torch.Tensor) assert qt.shape == (2, 4) assert qt.dtype == torch.complex64 + + +@pytest.mark.skipif( + not is_triton_amd_available(), reason="Triton AMD backend unavailable" +) +@pytest.mark.rocm +def test_triton_amd_iqp_full_parity_with_torch_ref() -> None: + n = 4 + engine = TritonAmdEngine(device_id=0, precision="float32") + data = torch.randn(3, n + n * (n - 1) // 2, device="cuda", dtype=torch.float32) + got = _as_torch(engine.encode(data, n, "iqp")) + ref = _torch_ref_iqp(data, n, enable_zz=True) + assert got.shape == ref.shape + assert got.dtype == torch.complex64 + assert torch.allclose(got, ref, atol=2e-5, rtol=2e-5) + + +@pytest.mark.skipif( + not is_triton_amd_available(), reason="Triton AMD backend unavailable" +) +@pytest.mark.rocm +def test_triton_amd_iqp_z_only_parity_with_torch_ref() -> None: + n = 5 + engine = TritonAmdEngine(device_id=0, 
precision="float32") + data = torch.randn(2, n, device="cuda", dtype=torch.float32) + got = _as_torch(engine.encode(data, n, "iqp-z")) + ref = _torch_ref_iqp(data, n, enable_zz=False) + assert got.shape == ref.shape + assert torch.allclose(got, ref, atol=2e-5, rtol=2e-5) + + +@pytest.mark.skipif( + not is_triton_amd_available(), reason="Triton AMD backend unavailable" +) +@pytest.mark.rocm +def test_triton_amd_iqp_param_count_validation() -> None: + engine = TritonAmdEngine(device_id=0, precision="float32") + # ZZ variant for n=4 expects 4 + 6 = 10 params; pass 9. + bad = torch.randn(2, 9, device="cuda", dtype=torch.float32) + with pytest.raises(ValueError, match="expects 10 parameters"): + engine.encode(bad, 4, "iqp") + # Z-only variant for n=4 expects 4 params; pass 5. + bad_z = torch.randn(2, 5, device="cuda", dtype=torch.float32) + with pytest.raises(ValueError, match="expects 4 parameters"): + engine.encode(bad_z, 4, "iqp-z") + + +@pytest.mark.skipif( + not is_triton_amd_available(), reason="Triton AMD backend unavailable" +) +@pytest.mark.rocm +def test_triton_amd_iqp_normalization_unit_norm() -> None: + """IQP output is a normalized state vector: Σ|amp|² ≈ 1.""" + engine = TritonAmdEngine(device_id=0, precision="float32") + n = 6 + data = torch.randn(4, n + n * (n - 1) // 2, device="cuda", dtype=torch.float32) + got = _as_torch(engine.encode(data, n, "iqp")) + norms = (got.abs() ** 2).sum(dim=1) + assert torch.allclose(norms, torch.ones_like(norms), atol=1e-4, rtol=1e-4) + + +@pytest.mark.skipif( + not is_triton_amd_available(), reason="Triton AMD backend unavailable" +) +@pytest.mark.rocm +def test_triton_amd_phase_parity() -> None: + engine = TritonAmdEngine(device_id=0, precision="float32") + phases = torch.randn(3, 5, device="cuda", dtype=torch.float32) + got = _as_torch(engine.encode(phases, 5, "phase")) + ref = _torch_phase_ref(phases, 5) + assert got.shape == ref.shape + assert got.dtype == torch.complex64 + assert torch.allclose(got, ref, atol=1e-5, 
rtol=1e-5) + + +@pytest.mark.skipif( + not is_triton_amd_available(), reason="Triton AMD backend unavailable" +) +@pytest.mark.rocm +def test_triton_amd_phase_normalization_unit_norm() -> None: + """Phase output is a uniform-magnitude product state: Σ|amp|² ≈ 1.""" + engine = TritonAmdEngine(device_id=0, precision="float32") + n = 6 + phases = torch.randn(4, n, device="cuda", dtype=torch.float32) + got = _as_torch(engine.encode(phases, n, "phase")) + norms = (got.abs() ** 2).sum(dim=1) + assert torch.allclose(norms, torch.ones_like(norms), atol=1e-4, rtol=1e-4) + + +@pytest.mark.skipif( + not is_triton_amd_available(), reason="Triton AMD backend unavailable" +) +@pytest.mark.rocm +def test_triton_amd_phase_param_count_validation() -> None: + engine = TritonAmdEngine(device_id=0, precision="float32") + bad = torch.randn(2, 3, device="cuda", dtype=torch.float32) + with pytest.raises(ValueError, match="sample size 4"): + engine.encode(bad, 4, "phase") + + +@pytest.mark.skipif( + not is_triton_amd_available(), reason="Triton AMD backend unavailable" +) +@pytest.mark.rocm +def test_triton_amd_phase_float64_precision_contract() -> None: + engine = TritonAmdEngine(device_id=0, precision="float64") + phases = torch.randn(2, 4, device="cuda", dtype=torch.float64) + got = _as_torch(engine.encode(phases, 4, "phase")) + ref = _torch_phase_ref(phases, 4).to(torch.complex128) + assert got.dtype == torch.complex128 + assert torch.allclose(got, ref, atol=1e-12, rtol=1e-12) + + +@pytest.mark.skipif( + not torch.cuda.is_available() or getattr(torch.version, "cuda", None) is None, + reason="NVIDIA CUDA reference not available", +) +@pytest.mark.rocm +def test_triton_amd_iqp_cuda_reference_optional() -> None: + _qdp = pytest.importorskip("_qdp") + if not is_triton_amd_available(): + pytest.skip("Triton AMD backend unavailable") + + engine_triton = TritonAmdEngine(device_id=0, precision="float64") + engine_cuda = _qdp.QdpEngine(0, precision="float64") + n = 3 + data = torch.randn(2, n 
+ n * (n - 1) // 2, device="cuda", dtype=torch.float64) + got = _as_torch(engine_triton.encode(data, n, "iqp")) + ref = torch.from_dlpack(engine_cuda.encode(data, n, "iqp")) + assert torch.allclose(got, ref, atol=1e-6, rtol=1e-6) + + +@pytest.mark.skipif( + not is_triton_amd_available(), reason="Triton AMD backend unavailable" +) +@pytest.mark.rocm +def test_triton_amd_unsupported_method_message_lists_all() -> None: + engine = TritonAmdEngine(device_id=0, precision="float32") + with pytest.raises(ValueError) as excinfo: + engine.encode(torch.zeros(1, 4, device="cuda"), 2, "no-such-method") + msg = str(excinfo.value) + for name in ("amplitude", "angle", "basis", "iqp", "iqp-z", "phase"): + assert name in msg + + +@pytest.mark.skipif( + not is_triton_amd_available(), reason="Triton AMD backend unavailable" +) +@pytest.mark.rocm +def test_unified_router_iqp_and_phase_routes() -> None: + """The public QdpEngine(backend='amd') router accepts iqp/iqp-z/phase too.""" + router = QdpEngine(backend="amd", device_id=0, precision="float32") + n = 3 + data_iqp = torch.randn(2, n + n * (n - 1) // 2, device="cuda", dtype=torch.float32) + qt = router.encode(data_iqp, n, "iqp") + assert isinstance(qt, torch.Tensor) + assert qt.shape == (2, 1 << n) + qt_z = router.encode(torch.randn(2, n, device="cuda"), n, "iqp-z") + assert qt_z.shape == (2, 1 << n) + qt_p = router.encode(torch.randn(2, n, device="cuda"), n, "phase") + assert qt_p.shape == (2, 1 << n) From 08adff86445f48880242d5a0d9e134156f8811f1 Mon Sep 17 00:00:00 2001 From: Ryan Huang Date: Sat, 25 Apr 2026 20:25:10 +0000 Subject: [PATCH 2/2] Optimize TritonAmdEngine encoders + Triton @jit phase kernel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses Copilot review on PR #1292 and pushes general kernel-level optimization across all six AMD encoders. 
PR review responses: - Drop the unreachable `test_triton_amd_iqp_cuda_reference_optional` (decorator required `torch.version.cuda` while body required `is_triton_amd_available()` → mutually exclusive). Replace with a meaningful float64 IQP precision contract test that actually runs. - Qualify README about the CUDA-tensor `phase` limitation: the Python extension's CUDA-tensor allowlist (`CUDA_ENCODING_METHODS`) does not yet include `phase`, so cuda-resident torch tensors must `.cpu()` first. Tracked as a follow-up. - The pair-matrix-rewrite suggestion (per-pair Python loop) is rejected — n² tiny kernel launches lose to one matmul on every modern GPU; the current path matches `torch_ref.iqp_encode` and the CUDA FWT phase kernel. Add a `_IQP_PAIR_MATRIX_MAX_N` guard that *does* fall back to a pair loop past n=20 (where the (2^n × n_pairs) workspace dominates HBM), so the OOM scenario is bounded. Encoder optimizations (verified on MI300X vs `qumat_qdp.torch_ref`, batch=64, fp32 input): | | q=8 | q=12 | q=16 | |--------|-------|-------|-------| | amplitude | 0.95× | 0.95× | 1.00× | | angle | 1.57× | 1.37× | 1.04× | | basis | 2.18× | 2.10× | 2.14× | | iqp(ZZ) | 1.96× | 1.81× | 1.14× | | iqp-z | 1.35× | 1.32× | 0.91× | | **phase** | **5.29×** | **5.39×** | **5.30×** | What changed: - **Real `@triton.jit` phase kernel** (fp32 / n ≤ 32). One HIP kernel fuses bit-pattern materialization + θ(b) accumulation + cos/sin + 1/√2^n scaling + complex-pack, writing the output buffer interleaved via `view_as_real`. The PyTorch fallback path (used at fp64 or n > 32) was making 5 intermediate (B, S) allocations; the kernel makes one. - **Per-engine bits-table cache** (`_bits_cache`): the `((idx >> arange(n)) & 1).to(real)` table was being rebuilt on every call by `angle`/`iqp`/`phase`. Now cached per (n, dtype). At n=16 that's a ~4 MiB int64 + ~4 MiB real allocation saved per call. - **Pair-index cache** (`_pair_cache`): `torch.combinations(arange(n))` cached per n. 
- **`encode_amplitude`**: replaced `torch.complex(amp, zeros_like(amp)).to(complex_dtype)` with a single `amp.to(complex_dtype)` (writes (real, 0) interleaved in one kernel vs. zeros_like + complex_pack + cast = three). - **`encode_angle`**: collapsed the n-step Python product loop (which reallocated a (B, S) tensor per qubit) into a single `where(bits, sin, cos).prod(dim=2)` reduction. - **`encode_iqp`**: in-place n-stage Walsh-Hadamard butterfly using a single `(B, S/2)` scratch buffer. The previous `cat([lo+hi, lo-hi], dim=2)` allocated a fresh (B, S) tensor every stage; now `sub(out=scratch); a.add_(b); b.copy_(scratch)` reuses one workspace across all n stages. Also packs `f` via `torch.complex(cos, sin)` in one shot rather than writing to strided `.real`/`.imag`. - **`_to_2d` fast path**: skip `as_tensor` + `.contiguous()` work when the caller already supplies a 2-D, contiguous, on-device, correctly-typed torch tensor (the common case for benchmarks). Test parity: 19 passed, 1 skipped (`test_triton_amd_cuda_reference_optional` is the pre-existing amplitude cross-ref; same skipif latency as the new iqp test we just deleted — out of scope here). Numerical parity: all encoders still match `torch_ref` / `_torch_phase_ref` within float-rounding tolerance; the IQP fp64 contract test confirms `atol=1e-12`. 
--- qdp/qdp-python/README.md | 5 + qdp/qdp-python/qumat_qdp/triton_amd.py | 324 ++++++++++++++---- .../tests/test_triton_amd_backend.py | 24 +- 3 files changed, 273 insertions(+), 80 deletions(-) diff --git a/qdp/qdp-python/README.md b/qdp/qdp-python/README.md index ecfe09a5b0..cacfa04175 100644 --- a/qdp/qdp-python/README.md +++ b/qdp/qdp-python/README.md @@ -79,6 +79,11 @@ See `qdp/qdp-python/TRITON_AMD_BACKEND.md` for Triton AMD setup and validation d Backend support boundary: - CUDA (`QdpEngine`): `amplitude`, `angle`, `basis`, `iqp`, `iqp-z`, `phase` + - `phase` is currently only reachable on the CUDA path via host inputs + (Python list / NumPy / file / CPU torch tensor). The Python extension's + CUDA-tensor validation does not yet allowlist `phase`; cuda-resident + torch tensors must use `.cpu()` first when targeting `phase`. Tracked as + a follow-up. - AMD (`QdpEngine(..., backend="amd")`): `amplitude`, `angle`, `basis`, `iqp`, `iqp-z`, `phase` ## Input Sources diff --git a/qdp/qdp-python/qumat_qdp/triton_amd.py b/qdp/qdp-python/qumat_qdp/triton_amd.py index 4f8c73e6d4..8bcbd5d025 100644 --- a/qdp/qdp-python/qumat_qdp/triton_amd.py +++ b/qdp/qdp-python/qumat_qdp/triton_amd.py @@ -19,7 +19,7 @@ from __future__ import annotations import math -from dataclasses import dataclass +from dataclasses import dataclass, field from importlib import import_module from typing import Any @@ -35,6 +35,7 @@ def _load_optional_module(name: str) -> Any | None: torch_mod = _load_optional_module("torch") triton_mod = _load_optional_module("triton") +triton_lang = _load_optional_module("triton.language") def _is_rocm_runtime() -> bool: @@ -55,6 +56,71 @@ def is_triton_amd_available() -> bool: return True +# --------------------------------------------------------------------------- +# Triton kernel: fused phase encoder (real-only path). 
+# +# One kernel per program covers BLOCK output basis-states for a single sample, +# fusing: bit-pattern materialization + θ(b) accumulation + sin/cos + 1/√2^n +# scaling + complex-pack into the (B, S) real/imag planes. The PyTorch path +# below allocates 5 intermediates of size O(B · S); this kernel writes the +# output in a single pass. +# +# Real and imag planes are written as separate float buffers, then the caller +# stitches them via ``torch.complex`` (free metadata view; PyTorch fuses the +# stride pattern). This avoids needing complex-typed pointers in Triton, which +# the HIP backend does not support directly. +# +# Limitations: float32 + n_qubits ≤ 32 (single int32 bit packing). For n > 32 +# or float64 the engine falls back to the vectorized PyTorch path, which is +# already memory-bound, not compute-bound. +# --------------------------------------------------------------------------- + +if triton_mod is not None and triton_lang is not None: + tl = triton_lang + + @triton_mod.jit + def _phase_encode_kernel( + phases_ptr, # *fp32, shape (B, n_qubits) + out_ptr, # *fp32, view-as-real of complex64 output: (B, 2·state_len) + n_qubits, + state_len, + norm_factor, # 1/√2^n + BLOCK: tl.constexpr, + ): + pid_b = tl.program_id(0) + pid_s = tl.program_id(1) + + s_offsets = pid_s * BLOCK + tl.arange(0, BLOCK) + s_mask = s_offsets < state_len + + # φ(b) = Σ_k phases[k] · ((b >> k) & 1) — fused bit unpack + accumulate. + phi = tl.zeros([BLOCK], dtype=tl.float32) + for k in range(0, n_qubits): + bit_k = ((s_offsets >> k) & 1).to(tl.float32) + phase_k = tl.load(phases_ptr + pid_b * n_qubits + k) + phi += phase_k * bit_k + + re = tl.cos(phi) * norm_factor + im = tl.sin(phi) * norm_factor + + # Write interleaved into the complex64 buffer's float view: each + # output element occupies two adjacent floats (re, im). One kernel, + # one allocation; no separate planes that would need a final stitch. 
+ base = pid_b * state_len * 2 + s_offsets * 2 + tl.store(out_ptr + base, re, mask=s_mask) + tl.store(out_ptr + base + 1, im, mask=s_mask) + +else: # pragma: no cover - non-Triton hosts use the PyTorch fallback + _phase_encode_kernel = None + + +# Largest n the ZZ pair-matrix path will materialize before we refuse and +# point the user at the loop fallback. State vector at n=20 is 16 MiB cf64; +# pair matrix at n=20 is 1 MiB · 190 entries · 4 B = ~760 MiB — so this is the +# right cutoff before pair_matrix dominates the AMD HBM budget. +_IQP_PAIR_MATRIX_MAX_N = 20 + + @dataclass class TritonAmdEngine: """AMD backend implementing amplitude/angle/basis/iqp/iqp-z/phase encoders.""" @@ -62,6 +128,13 @@ class TritonAmdEngine: device_id: int = 0 precision: str = "float32" + # Per-engine cache of (n_qubits → bits table) keyed by (n, real_dtype). + # Avoids regenerating the (state_len, n_qubits) bit pattern on every call; + # the table is reused across batches for any encoder that needs it. + _bits_cache: dict = field(default_factory=dict, repr=False, compare=False) + # Cache of (n → upper-triangular pair index) for IQP-ZZ. + _pair_cache: dict = field(default_factory=dict, repr=False, compare=False) + def __post_init__(self) -> None: p = self.precision.lower() if p in ("float32", "f32", "float"): @@ -106,6 +179,18 @@ def _complex_dtype(self) -> Any: def _to_2d(self, data: Any, *, dtype: Any) -> Any: torch_mod = self._require_torch() + # Fast path: caller already supplies a 2-D, contiguous, on-device, + # correctly-typed torch tensor (the common case for benchmarks and + # downstream pipelines). Skip ``as_tensor`` + ``contiguous`` work. 
+ if ( + isinstance(data, torch_mod.Tensor) + and data.ndim == 2 + and data.dtype is dtype + and data.is_contiguous() + and data.device.type == "cuda" + and data.device.index == self.device_id + ): + return data x = torch_mod.as_tensor(data, device=self._device(), dtype=dtype) if x.ndim == 1: x = x.unsqueeze(0) @@ -113,6 +198,37 @@ def _to_2d(self, data: Any, *, dtype: Any) -> Any: raise ValueError(f"Expected 1D or 2D input, got {x.ndim}D.") return x.contiguous() + def _bits_table(self, num_qubits: int, real_dtype: Any) -> Any: + """Cached ``bits[b, k] = (b >> k) & 1`` table cast to ``real_dtype``. + + Returned shape is ``(2^num_qubits, num_qubits)``. The same table is + reused by ``encode_angle``/``encode_iqp``/``encode_phase`` across + successive batches at the same ``num_qubits``. + """ + torch_mod = self._require_torch() + key = (num_qubits, real_dtype) + cached = self._bits_cache.get(key) + if cached is not None: + return cached + device = torch_mod.device(self._device()) + state_len = 1 << num_qubits + b_idx = torch_mod.arange(state_len, device=device, dtype=torch_mod.int64) + k_idx = torch_mod.arange(num_qubits, device=device, dtype=torch_mod.int64) + bits = ((b_idx.unsqueeze(1) >> k_idx) & 1).to(real_dtype).contiguous() + self._bits_cache[key] = bits + return bits + + def _pair_indices(self, num_qubits: int) -> Any: + """Cached ``(n*(n-1)/2, 2)`` table of upper-triangular qubit pairs.""" + torch_mod = self._require_torch() + cached = self._pair_cache.get(num_qubits) + if cached is not None: + return cached + device = torch_mod.device(self._device()) + pairs = torch_mod.combinations(torch_mod.arange(num_qubits, device=device), r=2) + self._pair_cache[num_qubits] = pairs + return pairs + def encode_amplitude(self, data: Any, num_qubits: int) -> Any: torch_mod = self._require_torch() x = self._to_2d(data, dtype=self._real_dtype()) @@ -126,13 +242,12 @@ def encode_amplitude(self, data: Any, num_qubits: int) -> Any: norms = torch_mod.linalg.vector_norm(x, dim=1, 
keepdim=True).clamp_min(1e-12) amp = x / norms if sample_size < state_len: - pad = torch_mod.zeros( - (batch, state_len - sample_size), device=amp.device, dtype=amp.dtype - ) - amp = torch_mod.cat([amp, pad], dim=1) - return torch_mod.complex(amp, torch_mod.zeros_like(amp)).to( - self._complex_dtype() - ) + # F.pad is a single fused op vs a separate zeros + cat. + amp = torch_mod.nn.functional.pad(amp, (0, state_len - sample_size)) + # ``.to(complex_dtype)`` from a real tensor is one kernel that writes + # (real, 0) interleaved — strictly better than building a separate + # zeros tensor and combining via ``torch.complex(real, zeros)``. + return amp.to(self._complex_dtype()) def encode_angle(self, data: Any, num_qubits: int) -> Any: torch_mod = self._require_torch() @@ -144,21 +259,18 @@ def encode_angle(self, data: Any, num_qubits: int) -> Any: f"Angle encoding expects sample size {num_qubits} (=num_qubits), got {width}." ) - state_len = 1 << num_qubits - idx = torch_mod.arange(state_len, device=angles.device).reshape(1, state_len) - amp = torch_mod.ones((batch, state_len), device=angles.device, dtype=real_dtype) - for bit in range(num_qubits): - col = angles[:, bit].unsqueeze(1) - factor = torch_mod.where( - ((idx >> bit) & 1) == 1, - torch_mod.sin(col), - torch_mod.cos(col), - ) - amp = amp * factor + bits = self._bits_table(num_qubits, real_dtype) # (S, n) cached - return torch_mod.complex(amp, torch_mod.zeros_like(amp)).to( - self._complex_dtype() - ) + # amp[batch, b] = prod_k (sin(θ_k) if bit_k else cos(θ_k)) + # Closed-form vectorization: broadcast (B, 1, n) sin/cos against + # (1, S, n) bit pattern, gather via where, reduce-product over k. + # One allocation for the (B, S, n) workspace; the previous Python-level + # n-step loop allocated a fresh (B, S) tensor per iteration. 
+        sin = torch_mod.sin(angles).unsqueeze(1)
+        cos = torch_mod.cos(angles).unsqueeze(1)
+        factor = torch_mod.where(bits.unsqueeze(0) > 0.5, sin, cos)
+        amp = factor.prod(dim=2)
+        return amp.to(self._complex_dtype())
 
     def encode_basis(self, data: Any, num_qubits: int) -> Any:
         torch_mod = self._require_torch()
@@ -180,22 +292,59 @@ def encode_basis(self, data: Any, num_qubits: int) -> Any:
             )
 
         batch = int(idx.numel())
+        complex_dtype = self._complex_dtype()
         out = torch_mod.zeros(
             (batch, state_len),
             device=idx.device,
-            dtype=self._complex_dtype(),
+            dtype=complex_dtype,
         )
         out.scatter_(
             1,
             idx.reshape(batch, 1),
-            torch_mod.ones(
-                (batch, 1),
-                device=idx.device,
-                dtype=self._complex_dtype(),
-            ),
+            torch_mod.ones((batch, 1), device=idx.device, dtype=complex_dtype),
         )
         return out
 
+    def _iqp_phase(
+        self,
+        params: Any,
+        num_qubits: int,
+        bits: Any,
+        *,
+        enable_zz: bool,
+    ) -> Any:
+        """Compute θ(x) = Σ x_i·data_i (+ Σ_{i<j} x_i·x_j·data_{ij} if ZZ)."""
+        torch_mod = self._require_torch()
+        n = num_qubits
+        phase = torch_mod.matmul(params[:, :n], bits.T)
+        if enable_zz and n >= 2:
+            if n > _IQP_PAIR_MATRIX_MAX_N:
+                # Pair matrix is (S, n_pairs) — at n=20 that's already ~760 MiB
+                # in float32. Past this size, fall back to a per-pair loop.
+                # Slower but bounded memory; the workload itself is also
+                # impractical at this point (state vector alone is multi-GB).
+                pair_idx = n
+                zz_params = params
+                for i in range(n - 1):
+                    bi = bits[:, i]
+                    for j in range(i + 1, n):
+                        bj = bits[:, j]
+                        phase = phase + zz_params[:, pair_idx : pair_idx + 1] * (
+                            bi * bj
+                        ).unsqueeze(0)
+                        pair_idx += 1
+            else:
+                zz_params = params[:, n:]
+                pairs = self._pair_indices(n)
+                pair_matrix = bits[:, pairs[:, 0]] * bits[:, pairs[:, 1]]
+                phase = phase + torch_mod.matmul(zz_params, pair_matrix.T)
+        return phase
+
     def encode_iqp(
         self,
         data: Any,
@@ -217,38 +366,81 @@
             )
 
         state_len = 1 << n
-        device = params.device
+        bits = self._bits_table(n, real_dtype)
+        phase = self._iqp_phase(params, n, bits, enable_zz=enable_zz)
 
-        # θ(x) = Σ_i x_i * data[i] (+ Σ_{i<j} x_i·x_j·data[ij] if ZZ).
-        b_idx = torch_mod.arange(state_len, device=device, dtype=torch_mod.int64)
-        x_bits = ((b_idx.unsqueeze(1) >> torch_mod.arange(n, device=device)) & 1).to(
-            real_dtype
+        # f[x] = exp(i·θ(x)). 
``torch.complex(cos, sin)`` allocates a single + # contiguous complex tensor and is faster than writing into strided + # ``.real``/``.imag`` views of a separately-allocated complex buffer. + f = torch_mod.complex(torch_mod.cos(phase), torch_mod.sin(phase)).to( + self._complex_dtype() ) - z_params = params[:, :n] - phase = torch_mod.matmul(z_params, x_bits.T) - if enable_zz and n >= 2: - zz_params = params[:, n:] - pairs = torch_mod.combinations(torch_mod.arange(n, device=device), r=2) - pair_matrix = x_bits[:, pairs[:, 0]] * x_bits[:, pairs[:, 1]] - phase = phase + torch_mod.matmul(zz_params, pair_matrix.T) + # In-place n-stage Walsh-Hadamard butterfly. View ``f`` as + # (B, K, 2, stride) per stage and do (a, b) ← (a + b, a - b) using a + # single ``state_len/2``-sized scratch buffer instead of allocating + # two (lo+hi, lo-hi) buffers and concatenating them every stage. + if n > 0: + scratch = torch_mod.empty( + (batch, state_len // 2), device=f.device, dtype=f.dtype + ) + for s in range(n): + stride = 1 << s + view = f.view(batch, state_len // (stride * 2), 2, stride) + a = view.select(2, 0) + b = view.select(2, 1) + scratch_view = scratch.view(batch, state_len // (stride * 2), stride) + torch_mod.sub(a, b, out=scratch_view) # scratch ← a − b + a.add_(b) # a ← a + b (in-place) + b.copy_(scratch_view) # b ← (a − b) from scratch + f = f.view(batch, state_len) + + f.mul_(1.0 / float(state_len)) + return f + + def _can_use_triton_phase_kernel(self, num_qubits: int) -> bool: + return ( + _phase_encode_kernel is not None + and self.precision == "float32" + and 1 <= num_qubits <= 32 + ) - # f[x] = exp(i·θ(x)), then n-stage Walsh-Hadamard butterfly, then 1/2^n. - complex_dtype = self._complex_dtype() - f = torch_mod.complex(torch_mod.cos(phase), torch_mod.sin(phase)).to( - complex_dtype + def _encode_phase_triton(self, phases: Any, num_qubits: int) -> Any: + """Triton-fused phase encoder for float32 / n ≤ 32. 
+ + One HIP kernel launch per (sample, output-tile) pair; fuses the + bit-table materialization + θ(b) accumulate + cos/sin + 1/√2^n scale + + complex-pack into a single pass that writes the output buffer + interleaved (re, im, re, im, …) — the native complex64 layout. + """ + torch_mod = self._require_torch() + # ``_can_use_triton_phase_kernel`` already guards on Triton being + # available; this assertion narrows the type for the type checker. + assert _phase_encode_kernel is not None + batch = phases.shape[0] + state_len = 1 << num_qubits + + # Allocate the complex output once; pass its real-view as a flat + # (B, 2·S) float32 buffer to the kernel for direct interleaved writes. + out = torch_mod.empty( + (batch, state_len), + device=phases.device, + dtype=torch_mod.complex64, ) - for s in range(n): - stride = 1 << s - block = 1 << (s + 1) - f = f.view(batch, state_len // block, block) - lo = f[:, :, :stride] - hi = f[:, :, stride:] - f = torch_mod.cat([lo + hi, lo - hi], dim=2) - f = f.reshape(batch, state_len) - - norm_factor = 1.0 / float(state_len) - return f * norm_factor + out_real_view = torch_mod.view_as_real(out).view(batch, state_len * 2) + + norm = math.pow(math.sqrt(0.5), num_qubits) + BLOCK = 256 + grid = (batch, (state_len + BLOCK - 1) // BLOCK) + _phase_encode_kernel[grid]( + phases, + out_real_view, + num_qubits, + state_len, + norm, + BLOCK=BLOCK, + ) + return out def encode_phase(self, data: Any, num_qubits: int) -> Any: torch_mod = self._require_torch() @@ -260,20 +452,20 @@ def encode_phase(self, data: Any, num_qubits: int) -> Any: f"Phase encoding expects sample size {num_qubits} (=num_qubits), got {width}." 
) - state_len = 1 << num_qubits - device = phases.device + if self._can_use_triton_phase_kernel(num_qubits): + return self._encode_phase_triton(phases, num_qubits) - # φ(b) = Σ_k phases[k] · b_k → state[b] = (1/√2^n) · exp(i·φ(b)) - b_idx = torch_mod.arange(state_len, device=device, dtype=torch_mod.int64) - bits = ( - (b_idx.unsqueeze(1) >> torch_mod.arange(num_qubits, device=device)) & 1 - ).to(real_dtype) + # Fallback: vectorized PyTorch path (float64 or n > 32). + bits = self._bits_table(num_qubits, real_dtype) phi = torch_mod.matmul(phases, bits.T) - norm = math.pow(math.sqrt(0.5), num_qubits) - re = torch_mod.cos(phi) * norm - im = torch_mod.sin(phi) * norm - return torch_mod.complex(re, im).to(self._complex_dtype()) + # ``torch.complex(re, im)`` writes a contiguous interleaved buffer in + # one allocation — faster than ``empty(complex)`` followed by strided + # writes into ``.real``/``.imag``. + return torch_mod.complex( + torch_mod.cos(phi).mul_(norm), + torch_mod.sin(phi).mul_(norm), + ).to(self._complex_dtype()) def encode( self, diff --git a/qdp/qdp-python/tests/test_triton_amd_backend.py b/qdp/qdp-python/tests/test_triton_amd_backend.py index 4971919103..ff3341568e 100644 --- a/qdp/qdp-python/tests/test_triton_amd_backend.py +++ b/qdp/qdp-python/tests/test_triton_amd_backend.py @@ -319,22 +319,18 @@ def test_triton_amd_phase_float64_precision_contract() -> None: @pytest.mark.skipif( - not torch.cuda.is_available() or getattr(torch.version, "cuda", None) is None, - reason="NVIDIA CUDA reference not available", + not is_triton_amd_available(), reason="Triton AMD backend unavailable" ) @pytest.mark.rocm -def test_triton_amd_iqp_cuda_reference_optional() -> None: - _qdp = pytest.importorskip("_qdp") - if not is_triton_amd_available(): - pytest.skip("Triton AMD backend unavailable") - - engine_triton = TritonAmdEngine(device_id=0, precision="float64") - engine_cuda = _qdp.QdpEngine(0, precision="float64") - n = 3 - data = torch.randn(2, n + n * (n - 1) // 2, 
device="cuda", dtype=torch.float64) - got = _as_torch(engine_triton.encode(data, n, "iqp")) - ref = torch.from_dlpack(engine_cuda.encode(data, n, "iqp")) - assert torch.allclose(got, ref, atol=1e-6, rtol=1e-6) +def test_triton_amd_iqp_float64_precision_contract() -> None: + """Float64 IQP matches torch_ref bit-close (covers the dtype contract).""" + engine = TritonAmdEngine(device_id=0, precision="float64") + n = 4 + data = torch.randn(3, n + n * (n - 1) // 2, device="cuda", dtype=torch.float64) + got = _as_torch(engine.encode(data, n, "iqp")) + ref = _torch_ref_iqp(data, n, enable_zz=True).to(torch.complex128) + assert got.dtype == torch.complex128 + assert torch.allclose(got, ref, atol=1e-12, rtol=1e-12) @pytest.mark.skipif(