-
Notifications
You must be signed in to change notification settings - Fork 81
Emit vector constant for vector index #921
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 11 commits
9874d9c
2bf9c8b
133055a
a1f9606
65766ba
3061cc6
af7d947
141d9d7
7d95c97
144b3df
fcf2487
e261f65
89474ea
1b9334d
337f5e4
a52ea24
420d16f
a624ffd
34e50f1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -46,6 +46,7 @@ | |
|
|
||
| from loopy.codegen import CodeGenerationState | ||
| from loopy.codegen.result import CodeGenerationResult | ||
| from loopy.kernel import LoopKernel | ||
|
|
||
|
|
||
| # {{{ dtype registry wrappers | ||
|
|
@@ -456,7 +457,8 @@ def get_opencl_callables(): | |
|
|
||
| # {{{ symbol mangler | ||
|
|
||
| def opencl_symbol_mangler(kernel, name): | ||
| def opencl_symbol_mangler(kernel: LoopKernel, | ||
| name: str) -> tuple[NumpyType, str] | None: | ||
| # FIXME: should be more picky about exact names | ||
| if name.startswith("FLT_"): | ||
| return NumpyType(np.dtype(np.float32)), name | ||
|
|
@@ -545,6 +547,21 @@ def wrap_in_typecast(self, actual_type, needed_dtype, s): | |
| from pymbolic.primitives import Comparison | ||
| return Comparison(s, "!=", 0) | ||
|
|
||
| if needed_dtype == actual_type: | ||
| return s | ||
|
|
||
| registry = self.codegen_state.ast_builder.target.get_dtype_registry() | ||
| if self.codegen_state.target.is_vector_dtype(needed_dtype): | ||
| # OpenCL does not allow explicit casts between vector types. Instead, | ||
| # call the built-in conversion function, which has the form | ||
| # <desttype><n> convert_<desttype><n>(src), where n is the number of | ||
| # elements in the vector and must match the element count of src. | ||
| # https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_C.html#explicit-casts | ||
| if self.codegen_state.target.is_vector_dtype(actual_type) or \ | ||
| actual_type.dtype.kind == "b": | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What is this
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The type, |
||
| cast = var("convert_%s" % registry.dtype_to_ctype(needed_dtype)) | ||
| return cast(s) | ||
|
|
||
| return super().wrap_in_typecast(actual_type, needed_dtype, s) | ||
|
|
||
| def map_group_hw_index(self, expr, type_context): | ||
|
|
@@ -553,6 +570,69 @@ def map_group_hw_index(self, expr, type_context): | |
| def map_local_hw_index(self, expr, type_context): | ||
| return var("lid")(expr.axis) | ||
|
|
||
| def map_variable(self, expr, type_context): | ||
|
|
||
| if self.codegen_state.vectorization_info: | ||
| if self.codegen_state.vectorization_info.iname == expr.name: | ||
| # This needs to be converted into a vector literal. | ||
| from loopy.symbolic import Literal | ||
| vector_length = self.codegen_state.vectorization_info.length | ||
| index_type = self.codegen_state.kernel.index_dtype | ||
| vector_type = self.codegen_state.target.vector_dtype(index_type, | ||
| vector_length) | ||
| typecast = self.codegen_state.target.dtype_to_typename(vector_type) | ||
|
nkoskelo marked this conversation as resolved.
Outdated
|
||
| vector_literal = f"(({typecast})" + " (" + \ | ||
| ",".join([f"{i}" for i in range(vector_length)]) + "))" | ||
| return Literal(vector_literal) | ||
| return super().map_variable(expr, type_context) | ||
|
|
||
| def map_if(self, expr, type_context): | ||
| from loopy.types import to_loopy_type | ||
| result_type = self.infer_type(expr) | ||
| conditional_needed_loopy_type = to_loopy_type(np.bool_) | ||
| if self.codegen_state.vectorization_info: | ||
| from loopy.codegen import UnvectorizableError | ||
| from loopy.expression import VectorizabilityChecker | ||
| checker = VectorizabilityChecker(self.codegen_state.kernel, | ||
| self.codegen_state.vectorization_info.iname, | ||
| self.codegen_state.vectorization_info.length) | ||
|
|
||
| try: | ||
| is_vector = checker(expr) | ||
|
|
||
| if is_vector: | ||
| """ | ||
| We could have a vector literal here which may need to be | ||
| converted to an appropriate size. The OpenCL specification states | ||
| that for ( c ? a : b) a, b, and c must have the same | ||
| number of elements and bits and that c must be an integral type. | ||
| https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_C.html#table-builtin-relational | ||
| """ | ||
| index_type = to_loopy_type(self.codegen_state.kernel.index_dtype) | ||
| types = {8: to_loopy_type(np.int64), 4: to_loopy_type(np.int32), | ||
| 2: to_loopy_type(np.int16), 1: to_loopy_type(np.int8)} | ||
| length = self.codegen_state.vectorization_info.length | ||
| if index_type.itemsize != result_type.itemsize and \ | ||
| result_type.itemsize in types.keys(): | ||
|
nkoskelo marked this conversation as resolved.
Outdated
|
||
| # Need to convert index type into result type size. | ||
| # Item size is measured in bytes. | ||
| index_type = types[result_type.itemsize] | ||
| elif index_type.itemsize * length != result_type.itemsize and \ | ||
| (result_type.itemsize // length) in types.keys(): | ||
|
nkoskelo marked this conversation as resolved.
Outdated
|
||
|
|
||
| index_type = types[result_type.itemsize // length] | ||
| vector_type = self.codegen_state.target.vector_dtype(index_type, | ||
| length) | ||
| conditional_needed_loopy_type = to_loopy_type(vector_type) | ||
| except UnvectorizableError: | ||
|
nkoskelo marked this conversation as resolved.
Outdated
Owner
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why are we suppressing exceptions here in the first place?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We know that
nkoskelo marked this conversation as resolved.
Outdated
|
||
| pass | ||
|
|
||
| return type(expr)( | ||
| self.rec(expr.condition, type_context, | ||
| conditional_needed_loopy_type), | ||
| self.rec(expr.then, type_context, result_type), | ||
| self.rec(expr.else_, type_context, result_type), | ||
| ) | ||
| # }}} | ||
|
|
||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.