-
Notifications
You must be signed in to change notification settings - Fork 16
Add tuned HIP GiMMiK preload-C and width variants with non-temporal loads and stores #19
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 8 commits
8e23d63
8f4d03e
96671a6
739a82e
0633539
7b59fb0
e9b921a
2aa2577
be1c1db
280e948
c06216d
e014e4d
f6bc308
a3aee45
2c7af9b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,104 @@ | ||
| <%inherit file='base'/> | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we merge this into the msplit kernel which preload as an option using % if/else as appropriate to switch between the two?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. I merged the preload-C path into the existing msplit template behind a |
||
|
|
||
| <% | ||
| mx = partition(A, into=msplit, by='rows') | ||
| bchunks = chunk(bix, bsz) | ||
| %> | ||
|
|
||
| __global__ __launch_bounds__(${blockx*msplit}) void | ||
| % if n is None: | ||
| ${kname}(int n, | ||
| const ${dtype}* __restrict__ b, int ldb, | ||
| ${dtype}* __restrict__ c, int ldc) | ||
| { | ||
| % if width > 1: | ||
| n = (n + ${width} - 1) / ${width}; | ||
| ldb /= ${width}; | ||
| ldc /= ${width}; | ||
| % endif | ||
| % else: | ||
| ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) | ||
| { | ||
| const int n = ${-(-n // width)}; | ||
| const ${'long long' if k*ldb >= width*2**31 else 'int'} ldb = ${ldb // width}; | ||
| const ${'long long' if m*ldc >= width*2**31 else 'int'} ldc = ${ldc // width}; | ||
| % endif | ||
| int i = blockDim.x*blockIdx.x + threadIdx.x; | ||
|
|
||
| ${dtype} bv, csub[${-(-m // msplit)}]; | ||
| __shared__ ${dtype} bsub[2][${bsz}][${blockx}]; | ||
|
|
||
| ## Fill the initial shared memory block | ||
| % for cid in range(msplit): | ||
| if (i < n && threadIdx.y == ${cid}) | ||
| { | ||
| % for kx in bchunks[0]: | ||
| % if loop.index % msplit == cid: | ||
| bsub[0][${loop.index}][threadIdx.x] = b[i + ${kx}*ldb]; | ||
| % endif | ||
| % endfor | ||
|
|
||
| % if beta != 0: | ||
| ## Preload C values for active rows owned by this m-split lane | ||
| % for j, jx in enumerate(mx[cid]): | ||
| % if afix[jx] != -1: | ||
| % if beta == 1: | ||
| csub[${j}] = load_c(&c[i + ${jx}*ldc]); | ||
| % else: | ||
| csub[${j}] = gimmik_vmul(${beta}, load_c(&c[i + ${jx}*ldc])); | ||
| % endif | ||
| % endif | ||
| % endfor | ||
| % endif | ||
| } | ||
| % endfor | ||
| __syncthreads(); | ||
|
|
||
| ## Iterate over each row-chunk of B | ||
| % for bb in range(len(bchunks)): | ||
| ## Iterate over each row-chunk of C | ||
| % for cid, mcx in enumerate(mx): | ||
| if (i < n && threadIdx.y == ${cid}) | ||
| { | ||
| ## Start filling the next shared memory block | ||
| % if not loop.parent.last: | ||
| % for kx in bchunks[bb + 1]: | ||
| % if loop.index % msplit == cid: | ||
| bsub[${(bb + 1) % 2}][${loop.index}][threadIdx.x] = b[i + ${kx}*ldb]; | ||
| % endif | ||
| % endfor | ||
| % endif | ||
| ## Accumulate our dot products | ||
| % for kx in bchunks[bb]: | ||
| bv = bsub[${bb % 2}][${loop.index}][threadIdx.x]; | ||
| % for j, jx in enumerate(A[mcx, kx]): | ||
| % if beta == 0: | ||
| % if jx != 0 and kx == afix[mcx[j]]: | ||
| csub[${j}] = gimmik_vmul(${jx}, bv); | ||
| % elif jx != 0: | ||
| csub[${j}] = gimmik_vmadd(csub[${j}], ${jx}, bv); | ||
| % endif | ||
| % elif jx != 0: | ||
| csub[${j}] = gimmik_vmadd(csub[${j}], ${jx}, bv); | ||
| % endif | ||
| ## If we're done with this dot product then store to global | ||
| % if kx == alix[mcx[j]]: | ||
| store_c(&c[i + ${mcx[j]}*ldc], csub[${j}]); | ||
| % endif | ||
| % endfor | ||
| % endfor | ||
| ## Handle rows of A which are all zero | ||
| % if loop.parent.last: | ||
| % for j, jx in enumerate(afix): | ||
| % if jx == -1 and j % msplit == cid and beta == 0: | ||
| store_c(&c[i + ${j}*ldc], make_zero()); | ||
| % elif jx == -1 and j % msplit == cid and beta != 1: | ||
| store_c(&c[i + ${j}*ldc], gimmik_vmul(${beta}, load_c(&c[i + ${j}*ldc]))); | ||
| % endif | ||
| % endfor | ||
| % endif | ||
| } | ||
| % endfor | ||
| __syncthreads(); | ||
| % endfor | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we define overloads like the CUDA backend does:
https://github.com/PyFR/GiMMiK/blob/master/gimmik/kernels/cuda/base.mako#L13
Just keeps the code a little cleaner and more consistent.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. I updated the HIP base template to use vector operator overloads in the same style as CUDA. I omitted
operator+=since HIP already provides it for vector types and defining it here caused an overload ambiguity.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Is the
+=a runtime issue that should be reported upstream? Seems odd to provide an overloaded for+=but nothing else. (If I recall, earlier versions of ROCm provided a full suite of overloads, but that caused issues with CUDA code compatibility where none are provided, so newer versions switch to not providing overloads. Did+=sneak through?)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I checked amd_hip_vector_types.h and found that operator+= is defined as a member function of HIP_vector_type<T, n> (L344), so defining a free function operator+= alongside it causes an overload ambiguity. By contrast, operator+ (L503) and operator* (L535) are defined as template free functions, so our definitions coexist without conflict. The omission of operator+= is therefore intentional rather than an upstream issue to report.