Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/user_guide/fastcache.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ Sub-functions called by the kernel are also checked — they must not capture ex

Other named constants (non-enum, non-module) captured from scope will raise a `QuadrantsCompilationError`, except for `UPPERCASE` names which emit a warning instead.

Wrapping a captured global in `qd.static(...)` does **not** exempt it from this check. `qd.static` only controls compile-time evaluation; it does not put the value into the cache key, so a `qd.static`-wrapped global is still flagged — though during the current transition period this emits a warning rather than raising. To use such a constant in a fastcache kernel, pass it as a parameter (template primitive, `@qd.data_oriented` member, or dataclass field) or make it one of the allowed captures above.

### 2. Supported parameter types

Fastcache supports the following parameter types:
Expand Down
22 changes: 16 additions & 6 deletions python/quadrants/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,18 @@ def build_Name(ctx: ASTTransformerFuncContext, node: ast.Name):
if isinstance(node, (ast.stmt, ast.expr)) and isinstance(node.ptr, Expr):
node.ptr.dbg_info = _qd_core.DebugInfo(ctx.get_pos_info(node))
node.ptr.ptr.set_dbg_info(node.ptr.dbg_info)
if ctx.is_pure and node.violates_pure and not ctx.static_scope_status.is_in_static_scope:
if isinstance(node.ptr, (float, int, Field)):
# ``qd.static`` is intentionally NOT a purity escape hatch: a captured module global is still flagged inside
# a static scope, since its value never enters the fastcache key regardless of static wrapping.
if ctx.is_pure and node.violates_pure:
# ``str`` is included alongside the numeric/``Field`` types: a captured string only affects a kernel through
# compile-time ``qd.static`` branches, and its value never enters the fastcache key, so it is cache-unsafe
# in exactly the same way as a captured int/float.
if isinstance(node.ptr, (float, int, str, Field)):
if not _is_quadrants_internal_file(ctx.file):
message = f"[PURE.VIOLATION] WARNING: Accessing global variable {node.id} {type(node.ptr)} {node.violates_pure_reason}"
if node.id.upper() == node.id:
# Transition period: violations inside a ``qd.static`` scope only warn instead of raising, giving
# downstream code time to migrate such constants to kernel params. ``UPPERCASE`` names also warn.
if node.id.upper() == node.id or ctx.is_in_static_scope():
warnings.warn(message)
else:
raise exception.QuadrantsCompilationError(message)
Expand Down Expand Up @@ -782,8 +789,10 @@ def build_Attribute(ctx: ASTTransformerFuncContext, node: ast.Attribute):
node.violates_pure = node.value.violates_pure
if node.violates_pure:
node.violates_pure_reason = node.value.violates_pure_reason
if ctx.is_pure and node.violates_pure and not ctx.static_scope_status.is_in_static_scope:
if isinstance(node.ptr, (int, float, Field)):
# ``qd.static`` is intentionally NOT a purity escape hatch (see ``build_Name``).
if ctx.is_pure and node.violates_pure:
# ``str`` included for the same reason as in ``build_Name``: a captured string is cache-unsafe.
if isinstance(node.ptr, (int, float, str, Field)):
violation = True
if violation and isinstance(node.ptr, enum.Enum):
violation = False
Expand All @@ -793,7 +802,8 @@ def build_Attribute(ctx: ASTTransformerFuncContext, node: ast.Attribute):
violation = False
if violation:
message = f"[PURE.VIOLATION] WARNING: Accessing global var {node.attr} from outside function scope within pure kernel {node.value.violates_pure_reason}"
if node.attr.upper() == node.attr:
# Transition period (see ``build_Name``): ``qd.static`` scope downgrades this to a warning.
if node.attr.upper() == node.attr or ctx.is_in_static_scope():
warnings.warn(message)
else:
raise exception.QuadrantsCompilationError(message)
Expand Down
54 changes: 54 additions & 0 deletions tests/python/quadrants/lang/fast_caching/test_pure_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,20 @@ def k2():
k2()


@test_utils.test()
def test_pure_validation_str():
# A captured ``str`` global is cache-unsafe in the same way as a captured int/float, so it must trigger a purity
# violation. Direct access (not wrapped in ``qd.static``) of a lowercase-named global raises.
s = "hello"

@qd.kernel(pure=True)
def k1():
print(s)

with pytest.raises(qd.QuadrantsCompilationError):
k1()


@test_utils.test()
def test_pure_validation_field():
a = qd.field(qd.i32, (10,))
Expand Down Expand Up @@ -282,3 +296,43 @@ def k1() -> qd.i32:

with pytest.warns(UserWarning, match=r"\[PURE\.VIOLATION\]"):
assert k1() == 32


# Restricted to a single (CPU) arch on purpose: the purity check is a Python-side AST analysis and is entirely
# arch-independent, and running it across multiple archs in one worker lets a fastcache hit from one arch suppress the
# warning on the next, which makes ``pytest.warns`` flaky.
@test_utils.test(arch=qd.cpu)
def test_pure_validation_static_scope_warns():
# Transition period: a captured global accessed inside a ``qd.static`` scope of a pure kernel only warns instead of
# raising, to give downstream code time to migrate such constants to kernel parameters.
assert qd.lang is not None
arch = qd.lang.impl.current_cfg().arch
qd.init(arch=arch, offline_cache=False)

use_alias = True

@qd.kernel(pure=True)
def k1() -> qd.i32:
ret = 0
if qd.static(use_alias):
ret = 1
return ret

with pytest.warns(UserWarning, match=r"\[PURE\.VIOLATION\]"):
assert k1() == 1

class Cfg:
def __init__(self) -> None:
self.flag = True

cfg = Cfg()

@qd.kernel(pure=True)
def k2() -> qd.i32:
ret = 0
if qd.static(cfg.flag):
ret = 1
return ret

with pytest.warns(UserWarning, match=r"\[PURE\.VIOLATION\]"):
assert k2() == 1
8 changes: 4 additions & 4 deletions tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,14 +294,14 @@ def src_ll_cache_has_return_child(args: list[str]) -> None:

@qd.pure
@qd.kernel
def k1(a: qd.i32, output: qd.types.NDArray[qd.i32, 1]) -> bool:
def k1(a: qd.i32, output: qd.types.NDArray[qd.i32, 1], return_something: qd.Template) -> bool:
output[0] = a
if qd.static(args_obj.return_something):
if qd.static(return_something):
return True

output = qd.ndarray(qd.i32, (10,))
if args_obj.return_something:
assert k1(3, output)
assert k1(3, output, args_obj.return_something)
# Sanity check that the kernel actually ran, and did something.
assert output[0] == 3
assert k1._primal.src_ll_cache_observations.cache_key_generated == args_obj.expect_used_src_ll_cache
Expand All @@ -314,7 +314,7 @@ def k1(a: qd.i32, output: qd.types.NDArray[qd.i32, 1]) -> bool:
with pytest.raises(
qd.QuadrantsSyntaxError, match="Kernel has a return type but does not have a return statement"
):
k1(3, output)
k1(3, output, args_obj.return_something)
print(TEST_RAN)
sys.exit(RET_SUCCESS)

Expand Down
33 changes: 20 additions & 13 deletions tests/python/test_tile.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,18 +88,18 @@ def test_zeros(TILE, make_tile, tdim, m_size, tensor_type, qd_dtype, use_zeros_a
Ann = _ann(tensor_type, qd_dtype, 2)

@qd.kernel(fastcache=True)
def k1(dst_arr: Ann, N: qd.Template):
def k1(dst_arr: Ann, N: qd.Template, use_alias: qd.Template):
qd.loop_config(block_dim=N)
tile_size = N
for _ in range(tile_size):
if qd.static(use_zeros_alias):
if qd.static(use_alias):
t = Tile.zeros()
t._store(dst_arr, 0, tile_size, 0, tile_size)
else:
t = Tile()
t._store(dst_arr, 0, tile_size, 0, tile_size)

k1(dst, tdim)
k1(dst, tdim, use_zeros_alias)
np.testing.assert_allclose(dst.to_numpy(), np.zeros((tdim, tdim), dtype=np_dtype))


Expand All @@ -117,7 +117,7 @@ def test_eye(TILE, make_tile, tdim, m_size, tensor_type, qd_dtype, inplace):
Ann = _ann(tensor_type, qd_dtype, 2)

@qd.kernel(fastcache=True)
def k1(src_arr: Ann, dst_arr: Ann, N: qd.Template):
def k1(src_arr: Ann, dst_arr: Ann, N: qd.Template, inplace: qd.Template):
qd.loop_config(block_dim=N)
tile_size = N
for _ in range(tile_size):
Expand All @@ -132,7 +132,7 @@ def k1(src_arr: Ann, dst_arr: Ann, N: qd.Template):

data = np.arange(tdim * tdim, dtype=np_dtype).reshape(tdim, tdim) + 100.0
src.from_numpy(data)
k1(src, dst, tdim)
k1(src, dst, tdim, inplace)
np.testing.assert_allclose(dst.to_numpy(), np.eye(tdim, dtype=np_dtype))


Expand Down Expand Up @@ -907,7 +907,7 @@ def test_load_slice_errors(TILE, make_tile, tdim, m_size, bad_slice, match):
dst = qd.ndarray(qd.f32, (tdim, tdim))

@qd.kernel(fastcache=True)
def k1(s: qd.types.NDArray[qd.f32, 2], d: qd.types.NDArray[qd.f32, 2], N: qd.Template):
def k1(s: qd.types.NDArray[qd.f32, 2], d: qd.types.NDArray[qd.f32, 2], N: qd.Template, bad_slice: qd.Template):
qd.loop_config(block_dim=N)
tile_size = N
for _ in range(tile_size):
Expand All @@ -923,7 +923,7 @@ def k1(s: qd.types.NDArray[qd.f32, 2], d: qd.types.NDArray[qd.f32, 2], N: qd.Tem
d[0:tile_size, 0:tile_size] = t

with pytest.raises(QuadrantsSyntaxError, match=match):
k1(src, dst, tdim)
k1(src, dst, tdim, bad_slice)


@pytest.mark.parametrize(
Expand All @@ -942,7 +942,7 @@ def test_store_slice_errors(TILE, make_tile, tdim, m_size, bad_slice, match):
dst = qd.ndarray(qd.f32, (tdim, tdim))

@qd.kernel(fastcache=True)
def k1(s: qd.types.NDArray[qd.f32, 2], d: qd.types.NDArray[qd.f32, 2], N: qd.Template):
def k1(s: qd.types.NDArray[qd.f32, 2], d: qd.types.NDArray[qd.f32, 2], N: qd.Template, bad_slice: qd.Template):
qd.loop_config(block_dim=N)
tile_size = N
for _ in range(tile_size):
Expand All @@ -958,7 +958,7 @@ def k1(s: qd.types.NDArray[qd.f32, 2], d: qd.types.NDArray[qd.f32, 2], N: qd.Tem
d[0:, 0:tile_size] = t

with pytest.raises(QuadrantsSyntaxError, match=match):
k1(src, dst, tdim)
k1(src, dst, tdim, bad_slice)


@test_utils.test(arch=qd.gpu)
Expand Down Expand Up @@ -1104,7 +1104,7 @@ def test_vec_slice_errors(TILE, make_tile, tdim, m_size, bad_slice):
dst = qd.ndarray(qd.f32, (tdim, tdim))

@qd.kernel(fastcache=True)
def k1(s: qd.types.NDArray[qd.f32, 2], d: qd.types.NDArray[qd.f32, 2], N: qd.Template):
def k1(s: qd.types.NDArray[qd.f32, 2], d: qd.types.NDArray[qd.f32, 2], N: qd.Template, bad_slice: qd.Template):
qd.loop_config(block_dim=N)
tile_size = N
for _ in range(tile_size):
Expand All @@ -1117,7 +1117,7 @@ def k1(s: qd.types.NDArray[qd.f32, 2], d: qd.types.NDArray[qd.f32, 2], N: qd.Tem
d[0:tile_size, 0:tile_size] = t

with pytest.raises(QuadrantsSyntaxError, match="both start and stop"):
k1(src, dst, tdim)
k1(src, dst, tdim, bad_slice)


# =============================================================================
Expand Down Expand Up @@ -1328,7 +1328,14 @@ def test_shared_array_partial_cols(TILE, make_tile, tdim, m_size, partial_store,
dst = qd.field(dtype=qd.f32, shape=(tdim, tdim))

@qd.kernel(fastcache=True)
def k1(src_f: qd.Template, dst_f: qd.Template, NCOLS: qd.i32, N: qd.Template):
def k1(
src_f: qd.Template,
dst_f: qd.Template,
NCOLS: qd.i32,
N: qd.Template,
partial_store: qd.Template,
partial_load: qd.Template,
):
qd.loop_config(block_dim=N)
tile_size = N
for _ in range(tile_size):
Expand Down Expand Up @@ -1356,7 +1363,7 @@ def k1(src_f: qd.Template, dst_f: qd.Template, NCOLS: qd.i32, N: qd.Template):

data = np.arange(tdim * tdim, dtype=np.float32).reshape(tdim, tdim) + 1.0
src.from_numpy(data)
k1(src, dst, NCOLS, tdim)
k1(src, dst, NCOLS, tdim, partial_store, partial_load)
result = dst.to_numpy()
np.testing.assert_allclose(result[:, :NCOLS], data[:, :NCOLS])
if partial_load:
Expand Down
Loading