Skip to content
Open
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
cbdb757
add localized allocation and deallocation
dsding2 Jun 2, 2025
2fee158
delete commented out code
dsding2 Jun 2, 2025
8ace895
deal with base storage
dsding2 Jun 4, 2025
c4e635c
ruff check fixes
dsding2 Jun 5, 2025
24b1a47
rework to push allocations outside of loops
dsding2 Jun 8, 2025
be78797
add types, fix ruff
dsding2 Jun 9, 2025
461558d
Merge remote-tracking branch 'upstream/main' into opencl_allocation
dsding2 Jun 13, 2025
0bcf4df
Merge branch 'main' into opencl_allocation
dsding2 Jun 17, 2025
0b6abdd
refactor to make more target-generic
dsding2 Jun 17, 2025
4f95a6b
resolve lingering merge issues
dsding2 Jun 17, 2025
bd98636
fix to only allocate global temporaries
dsding2 Jun 17, 2025
47dda68
move temp declarations to ASTBuilder
dsding2 Jun 19, 2025
e494a3b
Merge branch 'main' into opencl_allocation
dsding2 Jun 19, 2025
88c436f
fix typing
dsding2 Jun 19, 2025
dae91e2
fix typing hopefully
dsding2 Jun 23, 2025
1cfe83a
add basic test
dsding2 Jun 23, 2025
3c3bb78
Merge branch 'main' into opencl_allocation
dsding2 Jun 30, 2025
3ef324c
more typing/ruff fixes
dsding2 Jun 30, 2025
f708b66
fix tutorial.rst and add to baseline
dsding2 Jun 30, 2025
a0a8365
Merge branch 'main' into opencl_allocation
inducer Jul 5, 2025
f12ce9f
Merge branch 'main' into opencl_allocation
inducer Jul 10, 2025
452be6b
Merge branch 'main' into opencl_allocation
inducer Jul 10, 2025
3985576
Update loopy/schedule/tools.py
dsding2 Jul 11, 2025
95e119e
Apply suggested test changes
dsding2 Jul 11, 2025
5cbfbf1
implement rename and documentation suggestions
dsding2 Jul 11, 2025
612b238
ruff fixes, revert broken change
dsding2 Jul 12, 2025
4b0d754
Merge branch 'main' into opencl_allocation
inducer Jul 28, 2025
e591ae6
Merge branch 'main' into opencl_allocation
inducer Jul 31, 2025
1290c64
Merge branch 'main' into opencl_allocation
inducer Aug 28, 2025
b24fe99
Improvements
inducer Aug 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions .basedpyright/baseline.json
Original file line number Diff line number Diff line change
Expand Up @@ -8269,8 +8269,8 @@
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 56,
"endColumn": 17,
"lineCount": 8
"endColumn": 63,
"lineCount": 1
}
},
{
Expand All @@ -8281,6 +8281,14 @@
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 52,
"endColumn": 59,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
Expand Down
26 changes: 21 additions & 5 deletions loopy/codegen/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,23 @@ def generate_code_for_sched_index(
glob_grid, loc_grid = kernel.get_grid_sizes_for_insn_ids_as_exprs(
get_insn_ids_for_block_at(kernel.linearization, sched_index),
codegen_state.callables_table)
return merge_codegen_results(codegen_state, [
codegen_result,

prefixes, suffixes = (
codegen_state.ast_builder.get_temporary_decl_at_index(
codegen_state, sched_index
)
)
results = [
prefixes,
codegen_result,
codegen_state.ast_builder.get_kernel_call(
codegen_state,
sched_item.kernel_name,
glob_grid, loc_grid)
])
glob_grid, loc_grid),
suffixes
]
results = [r for r in results if r is not None]
return merge_codegen_results(codegen_state, results)
else:
# do not generate host code for non-entrypoint kernels
return codegen_result
Expand Down Expand Up @@ -136,7 +145,14 @@ def generate_code_for_sched_index(
"for '%s', tagged '%s'"
% (sched_item.iname, ", ".join(str(tag) for tag in tags)))

return func(codegen_state, sched_index)
prefixes, suffixes = (
codegen_state.ast_builder.get_temporary_decl_at_index(
codegen_state, sched_index
)
)
results = [prefixes, func(codegen_state, sched_index), suffixes]
results = [r for r in results if r is not None]
return merge_codegen_results(codegen_state, results)

elif isinstance(sched_item, Barrier):
# {{{ emit barrier code
Expand Down
103 changes: 103 additions & 0 deletions loopy/schedule/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,109 @@ def supporting_temporary_names(

return frozenset(result)


def get_temporary_decl_blocks(
Comment thread
inducer marked this conversation as resolved.
Outdated
kernel: LoopKernel
) -> tuple[dict[int, frozenset[str]], dict[int, frozenset[str]]]:
Comment thread
inducer marked this conversation as resolved.
Outdated
from loopy.kernel.data import AddressSpace
from loopy.schedule import CallKernel, EnterLoop

assert kernel.linearization is not None

global_temporaries = frozenset(
tv.name for tv in kernel.temporary_variables.values()
if tv.address_space == AddressSpace.GLOBAL
)

# Collapse into blocks
def get_temporaries_in_bounds(
linearization: Sequence[ScheduleItem],
lower_bound: int,
upper_bound: int
) -> frozenset[str]:
temporaries: frozenset[str] = frozenset()
for sched_index in range(lower_bound, upper_bound+1):
sched_item = linearization[sched_index]
if isinstance(sched_item, CallKernel):
temporaries = (
temporaries_written_in_subkernel(kernel, sched_item.kernel_name)
| temporaries_read_in_subkernel(
kernel, sched_item.kernel_name
)
| (temporaries)
)
return temporaries & global_temporaries

block_boundaries = get_block_boundaries(kernel.linearization)

bounds: dict[int, frozenset[str]] = {}
sched_index = 0
while sched_index < len(kernel.linearization):
sched_item = kernel.linearization[sched_index]
if isinstance(sched_item, EnterLoop) or isinstance(sched_item, CallKernel):
if isinstance(sched_item, CallKernel):
block_end = block_boundaries[sched_index]
accessed_temporaries = (
temporaries_written_in_subkernel(kernel, sched_item.kernel_name)
| temporaries_read_in_subkernel(
kernel, sched_item.kernel_name
)
)
else:
block_end = block_boundaries[sched_index]
accessed_temporaries = get_temporaries_in_bounds(
kernel.linearization, sched_index, block_end
)
bounds[sched_index] = accessed_temporaries
sched_index = block_end + 1
else:
sched_index += 1

def update_seen_storage_vars(
seen_sv: set[str],
new_temp_variables: frozenset[str]
) -> frozenset[str]:
new_storage_variables: set[str] = set()
past_sv = frozenset(seen_sv)
for new_tv_name in new_temp_variables:
new_tv = kernel.temporary_variables[new_tv_name]
if new_tv.base_storage is None:
storage_var = new_tv_name
else:
storage_var = new_tv.base_storage
new_storage_variables.add(storage_var)
seen_sv.add(storage_var)
new_sv = frozenset(new_storage_variables)
return new_sv - past_sv
# forward pass for first accesses
first_accesses: dict[int, frozenset[str]] = {}
seen_storage_variables: set[str] = set()
for sched_index in range(0, len(kernel.linearization)):
if (sched_index not in bounds):
continue
sched_item = kernel.linearization[sched_index]
new_temporary_variables = bounds[sched_index]
new_storage_variables = update_seen_storage_vars(
seen_storage_variables, new_temporary_variables
)

if (len(new_storage_variables) > 0):
first_accesses[sched_index] = new_storage_variables

last_accesses: dict[int, frozenset[str]] = {}
seen_storage_variables.clear()
for sched_index in range(len(kernel.linearization)-1, -1, -1):
if (sched_index not in bounds):
continue
sched_item = kernel.linearization[sched_index]
new_temporary_variables = bounds[sched_index]
new_storage_variables = update_seen_storage_vars(
seen_storage_variables, new_temporary_variables
)

if (len(new_storage_variables) > 0):
Comment thread
dsding2 marked this conversation as resolved.
Outdated
last_accesses[sched_index] = new_storage_variables
return (first_accesses, last_accesses)
# }}}


Expand Down
37 changes: 37 additions & 0 deletions loopy/target/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
from loopy.codegen import CodeGenerationState, PreambleInfo
from loopy.codegen.result import CodeGenerationResult
from loopy.kernel import LoopKernel
from loopy.kernel.data import TemporaryVariable
from loopy.target.c import DTypeRegistry
from loopy.target.execution import ExecutorBase
from loopy.translation_unit import CallableId, CallablesTable, TranslationUnit
Expand Down Expand Up @@ -251,6 +252,21 @@ def get_temporary_decls(self, codegen_state: CodeGenerationState,
schedule_index: int) -> ASTType:
raise NotImplementedError

def get_temporary_var_declarator(self,
codegen_state: CodeGenerationState,
temp_var: TemporaryVariable) -> ASTType:
raise NotImplementedError()

def get_temporary_var_deallocator(self,
codegen_state: CodeGenerationState,
temp_var: TemporaryVariable) -> ASTType:
raise NotImplementedError()

def get_temporary_decl_at_index(
self, codegen_state: CodeGenerationState,
sched_index: int) -> tuple[ASTType | None, ASTType | None]:
raise NotImplementedError()

def get_kernel_call(self, codegen_state: CodeGenerationState,
subkernel_name: str,
gsize: tuple[Expression, ...],
Expand Down Expand Up @@ -365,6 +381,27 @@ def get_expression_to_code_mapper(self, codegen_state):
def get_kernel_call(self, codegen_state, name, gsize, lsize):
return None

@override
def get_temporary_var_declarator(
self, codegen_state: CodeGenerationState,
temp_var: TemporaryVariable
) -> None:
return None

@override
def get_temporary_var_deallocator(
self, codegen_state: CodeGenerationState,
temp_var: TemporaryVariable
) -> None:
return None

@override
def get_temporary_decl_at_index(
self, codegen_state: CodeGenerationState,
sched_index: int
) -> tuple[None, None]:
return (None, None)

@property
def ast_block_class(self):
return _DummyASTBlock
Expand Down
14 changes: 14 additions & 0 deletions loopy/target/c/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from cgen import (
Block,
Collection,
Comment,
Const,
Declarator,
Generable,
Expand Down Expand Up @@ -1107,6 +1108,12 @@ def get_temporary_decls(self, codegen_state, schedule_index):

return result

@override
def get_temporary_decl_at_index(
self, codegen_state: CodeGenerationState, sched_index: int
) -> tuple[Generable | None, Generable | None]:
return (None, None)

@property
@override
def ast_block_class(self):
Expand Down Expand Up @@ -1240,6 +1247,7 @@ def arg_to_cgen_declarator(
raise ValueError(f"unexpected type of argument '{passed_name}': "
f"'{type(var_descr)}'")

@override
def get_temporary_var_declarator(self,
codegen_state: CodeGenerationState,
temp_var: TemporaryVariable) -> Declarator:
Expand Down Expand Up @@ -1272,6 +1280,12 @@ def get_temporary_var_declarator(self,
return self.wrap_decl_for_address_space(temp_var_decl,
temp_var.address_space)

@override
def get_temporary_var_deallocator(self,
codegen_state: CodeGenerationState,
temp_var: TemporaryVariable
) -> Generable:
return Comment("Dynamic freeing of temp vars not supported")
# }}}

@override
Expand Down
Loading
Loading