Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 103 additions & 4 deletions gimmik/hip.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,122 @@ class HIPMatMul(MatMul):
basemeta = {'block': (128, 1, 1), 'width': 1, 'shared': 0}

def _kernel_generators(self, dtype, dsize, *, gcn_arch=None, warp_size=64):
max_block_threads = 1024
max_shared = 64*1024

def emit(name, args, meta):
block = meta.get('block', self.basemeta['block'])
shared = meta.get('shared', self.basemeta['shared'])
threads = block[0]*block[1]*block[2]

if threads <= max_block_threads and shared <= max_shared:
yield (name, args, meta)

blkx = self.basemeta['block'][0]

# B loading, C streaming kernel
yield ('cstream', {}, {})
yield from emit('cstream', {'blockx': blkx}, {})

# B streaming, C accumulation kernel
yield ('bstream', {}, {})
yield from emit('bstream', {'blockx': blkx}, {})

# Four-way m-split B streaming, C accumulation kernel
ms, bsz, blkx = 4, 24, 64
args = {'msplit': ms, 'bsz': bsz, 'blockx': blkx}
meta = {'block': (blkx, ms, 1), 'shared': 2*bsz*blkx*dsize}
yield ('bstream-msplit', args, meta)
yield from emit('bstream-msplit', args, meta)

# Two-way k-split B loading, C streaming kernel
ks, csz, blkx = 2, 24, 64
args = {'ksplit': ks, 'csz': csz, 'blockx': blkx}
meta = {'block': (blkx, ks, 1), 'shared': (ks - 1)*csz*blkx*dsize}
yield ('cstream-ksplit', args, meta)
yield from emit('cstream-ksplit', args, meta)

# Only emit tuned variants on architectures they have been validated for.
base_arch = gcn_arch.split(':', 1)[0] if gcn_arch else None
if base_arch not in {'gfx90a', 'gfx942'} or warp_size != 64:
return

# Tuned HIP variants
msplits, ksplits = [4, 8], [2, 4]
bsz, csz, blkx = 8, 8, 64
widths = [1]
if self.aligne is not None and self.aligne % 2 == 0:
widths.append(2)

for width in widths:
wargs = ({'dtype': f'{dtype}{width}', 'width': width}
if width > 1 else {})
wmeta = {'width': width} if width > 1 else {}
wpfx = f'w{width}-' if width > 1 else ''

# B loading, C streaming kernel
args = {'blockx': blkx} | wargs
meta = {'block': (blkx, 1, 1),
'desc': f'cstream/{wpfx}x{blkx}'} | wmeta
yield from emit('cstream', args, meta)

# B streaming, C accumulation kernel
meta = {'block': (blkx, 1, 1),
'desc': f'bstream/{wpfx}x{blkx}'} | wmeta
yield from emit('bstream', args, meta)

for ms in msplits:
# m-split B streaming, C accumulation kernel
args = {'msplit': ms, 'bsz': bsz, 'blockx': blkx} | wargs
shared = 2*bsz*blkx*dsize*width
meta = {
'block': (blkx, ms, 1), 'shared': shared,
'desc': f'bstream-msplit/{wpfx}m{ms}-b{bsz}-x{blkx}'
} | wmeta
yield from emit('bstream-msplit', args, meta)

for ks in ksplits:
# k-split B loading, C streaming kernel
args = {'ksplit': ks, 'csz': csz, 'blockx': blkx} | wargs
shared = (ks - 1)*csz*blkx*dsize*width
meta = {
'block': (blkx, ks, 1), 'shared': shared,
'desc': f'cstream-ksplit/{wpfx}k{ks}-c{csz}-x{blkx}'
} | wmeta
yield from emit('cstream-ksplit', args, meta)

# B loading, C preloading, C streaming kernel
args = {'blockx': blkx} | wargs
meta = {'block': (blkx, 1, 1),
'desc': f'cstream-preload-c/{wpfx}x{blkx}'} | wmeta
yield from emit('cstream-preload-c', args, meta)

# B streaming, C preloading, C accumulation kernel
meta = {'block': (blkx, 1, 1),
'desc': f'bstream-preload-c/{wpfx}x{blkx}'} | wmeta
yield from emit('bstream-preload-c', args, meta)

for ms in msplits:
# m-split B streaming, C preloading, C accumulation kernel
args = {'msplit': ms, 'bsz': bsz, 'blockx': blkx} | wargs
shared = 2*bsz*blkx*dsize*width
meta = {
'block': (blkx, ms, 1), 'shared': shared,
'desc': (
f'bstream-msplit-preload-c/'
f'{wpfx}m{ms}-b{bsz}-x{blkx}'
)
} | wmeta
yield from emit('bstream-msplit-preload-c', args, meta)

for ks in ksplits:
# k-split B loading, C preloading, C streaming kernel
args = {'ksplit': ks, 'csz': csz, 'blockx': blkx} | wargs
shared = (ks - 1)*csz*blkx*dsize*width
meta = {
'block': (blkx, ks, 1), 'shared': shared,
'desc': (
f'cstream-ksplit-preload-c/'
f'{wpfx}k{ks}-c{csz}-x{blkx}'
)
} | wmeta
yield from emit('cstream-ksplit-preload-c', args, meta)

def _process_meta(self, meta):
if self.n is not None:
Expand Down
105 changes: 105 additions & 0 deletions gimmik/kernels/hip/base.mako
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,109 @@ static inline __device__ ${dtype} make_zero()
{ return 0; }
% endif

% if width == 1:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we define overloads like the CUDA backend does:

https://github.com/PyFR/GiMMiK/blob/master/gimmik/kernels/cuda/base.mako#L13

Just keeps the code a little cleaner and more consistent.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. I updated the HIP base template to use vector operator overloads in the same style as CUDA. I omitted operator+= since HIP already provides it for vector types and defining it here caused an overload ambiguity.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Is the += a runtime issue that should be reported upstream? Seems odd to provide an overloaded for += but nothing else. (If I recall, earlier versions of ROCm provided a full suite of overloads, but that caused issues with CUDA code compatibility where none are provided, so newer versions switch to not providing overloads. Did += sneak through?)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked amd_hip_vector_types.h and found that operator+= is defined as a member function of HIP_vector_type<T, n> (L344), so defining a free function operator+= alongside it causes an overload ambiguity. By contrast, operator+ (L503) and operator* (L535) are defined as template free functions, so our definitions coexist without conflict. The omission of operator+= is therefore intentional rather than an upstream issue to report.

static inline __device__ ${dtype}
gimmik_vmul(${dtype} a, ${dtype} b)
{
return a*b;
}

static inline __device__ ${dtype}
gimmik_vadd(${dtype} a, ${dtype} b)
{
return a + b;
}

static inline __device__ ${dtype}
gimmik_vmadd(${dtype} acc, ${dtype} a, ${dtype} b)
{
// Keep the multiply-add expression visible to the compiler.
return acc + a*b;
}
% elif width == 2:
static inline __device__ ${dtype}
gimmik_vmul(${dtype[:-1]} a, ${dtype} b)
{
return make_${dtype}(a*b.x, a*b.y);
}

static inline __device__ ${dtype}
gimmik_vadd(${dtype} a, ${dtype} b)
{
return make_${dtype}(a.x + b.x, a.y + b.y);
}

static inline __device__ ${dtype}
gimmik_vmadd(${dtype} acc, ${dtype[:-1]} a, ${dtype} b)
{
// Keep the multiply-add expression visible to the compiler.
return make_${dtype}(acc.x + a*b.x, acc.y + a*b.y);
}
% elif width == 4:
static inline __device__ ${dtype}
gimmik_vmul(${dtype[:-1]} a, ${dtype} b)
{
return make_${dtype}(a*b.x, a*b.y, a*b.z, a*b.w);
}

static inline __device__ ${dtype}
gimmik_vadd(${dtype} a, ${dtype} b)
{
return make_${dtype}(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
}

static inline __device__ ${dtype}
gimmik_vmadd(${dtype} acc, ${dtype[:-1]} a, ${dtype} b)
{
// Keep the multiply-add expression visible to the compiler.
return make_${dtype}(acc.x + a*b.x, acc.y + a*b.y, acc.z + a*b.z, acc.w + a*b.w);
}
% else:
#error "HIP vector helpers only support width=2 or width=4"
% endif

static inline __device__ void
nt_store_c(${dtype}* p, ${dtype} v)
{
% if dtype.endswith('4'):
__builtin_nontemporal_store(v.x, &p->x);
__builtin_nontemporal_store(v.y, &p->y);
__builtin_nontemporal_store(v.z, &p->z);
__builtin_nontemporal_store(v.w, &p->w);
% elif dtype.endswith('2'):
__builtin_nontemporal_store(v.x, &p->x);
__builtin_nontemporal_store(v.y, &p->y);
% else:
__builtin_nontemporal_store(v, p);
% endif
}

static inline __device__ ${dtype}
nt_load_c(const ${dtype}* p)
{
% if dtype.endswith('4'):
return make_${dtype}(__builtin_nontemporal_load(&p->x),
__builtin_nontemporal_load(&p->y),
__builtin_nontemporal_load(&p->z),
__builtin_nontemporal_load(&p->w));
% elif dtype.endswith('2'):
return make_${dtype}(__builtin_nontemporal_load(&p->x),
__builtin_nontemporal_load(&p->y));
% else:
return __builtin_nontemporal_load(p);
% endif
}

static inline __device__ void
store_c(${dtype}* p, ${dtype} v)
{
nt_store_c(p, v);
}

static inline __device__ ${dtype}
load_c(const ${dtype}* p)
{
return nt_load_c(p);
}

${next.body()}
104 changes: 104 additions & 0 deletions gimmik/kernels/hip/bstream-msplit-preload-c.mako
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
<%inherit file='base'/>

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we merge this into the msplit kernel which preload as an option using % if/else as appropriate to switch between the two?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. I merged the preload-C path into the existing msplit template behind a preload option, and removed the separate preload-C template.


<%
mx = partition(A, into=msplit, by='rows')
bchunks = chunk(bix, bsz)
%>

__global__ __launch_bounds__(${blockx*msplit}) void
% if n is None:
${kname}(int n,
const ${dtype}* __restrict__ b, int ldb,
${dtype}* __restrict__ c, int ldc)
{
% if width > 1:
n = (n + ${width} - 1) / ${width};
ldb /= ${width};
ldc /= ${width};
% endif
% else:
${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c)
{
const int n = ${-(-n // width)};
const ${'long long' if k*ldb >= width*2**31 else 'int'} ldb = ${ldb // width};
const ${'long long' if m*ldc >= width*2**31 else 'int'} ldc = ${ldc // width};
% endif
int i = blockDim.x*blockIdx.x + threadIdx.x;

${dtype} bv, csub[${-(-m // msplit)}];
__shared__ ${dtype} bsub[2][${bsz}][${blockx}];

## Fill the initial shared memory block
% for cid in range(msplit):
if (i < n && threadIdx.y == ${cid})
{
% for kx in bchunks[0]:
% if loop.index % msplit == cid:
bsub[0][${loop.index}][threadIdx.x] = b[i + ${kx}*ldb];
% endif
% endfor

% if beta != 0:
## Preload C values for active rows owned by this m-split lane
% for j, jx in enumerate(mx[cid]):
% if afix[jx] != -1:
% if beta == 1:
csub[${j}] = load_c(&c[i + ${jx}*ldc]);
% else:
csub[${j}] = gimmik_vmul(${beta}, load_c(&c[i + ${jx}*ldc]));
% endif
% endif
% endfor
% endif
}
% endfor
__syncthreads();

## Iterate over each row-chunk of B
% for bb in range(len(bchunks)):
## Iterate over each row-chunk of C
% for cid, mcx in enumerate(mx):
if (i < n && threadIdx.y == ${cid})
{
## Start filling the next shared memory block
% if not loop.parent.last:
% for kx in bchunks[bb + 1]:
% if loop.index % msplit == cid:
bsub[${(bb + 1) % 2}][${loop.index}][threadIdx.x] = b[i + ${kx}*ldb];
% endif
% endfor
% endif
## Accumulate our dot products
% for kx in bchunks[bb]:
bv = bsub[${bb % 2}][${loop.index}][threadIdx.x];
% for j, jx in enumerate(A[mcx, kx]):
% if beta == 0:
% if jx != 0 and kx == afix[mcx[j]]:
csub[${j}] = gimmik_vmul(${jx}, bv);
% elif jx != 0:
csub[${j}] = gimmik_vmadd(csub[${j}], ${jx}, bv);
% endif
% elif jx != 0:
csub[${j}] = gimmik_vmadd(csub[${j}], ${jx}, bv);
% endif
## If we're done with this dot product then store to global
% if kx == alix[mcx[j]]:
store_c(&c[i + ${mcx[j]}*ldc], csub[${j}]);
% endif
% endfor
% endfor
## Handle rows of A which are all zero
% if loop.parent.last:
% for j, jx in enumerate(afix):
% if jx == -1 and j % msplit == cid and beta == 0:
store_c(&c[i + ${j}*ldc], make_zero());
% elif jx == -1 and j % msplit == cid and beta != 1:
store_c(&c[i + ${j}*ldc], gimmik_vmul(${beta}, load_c(&c[i + ${j}*ldc])));
% endif
% endfor
% endif
}
% endfor
__syncthreads();
% endfor
}
16 changes: 8 additions & 8 deletions gimmik/kernels/hip/bstream-msplit.mako
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ ${kname}(int n,
${dtype}* __restrict__ c, int ldc)
{
% if width > 1:
n = ((n + ${width} - 1) / ${width}) * ${width};
n = (n + ${width} - 1) / ${width};
ldb /= ${width};
ldc /= ${width};
% endif
Expand Down Expand Up @@ -60,27 +60,27 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c)
bv = bsub[${bb % 2}][${loop.index}][threadIdx.x];
% for j, jx in enumerate(A[mcx, kx]):
% if jx != 0 and kx == afix[mcx[j]]:
csub[${j}] = ${jx}*bv;
csub[${j}] = gimmik_vmul(${jx}, bv);
% elif jx != 0:
csub[${j}] += ${jx}*bv;
csub[${j}] = gimmik_vmadd(csub[${j}], ${jx}, bv);
% endif
## If we're done with this dot product then store to global
% if kx == alix[mcx[j]] and beta == 0:
c[i + ${mcx[j]}*ldc] = csub[${j}];
store_c(&c[i + ${mcx[j]}*ldc], csub[${j}]);
% elif kx == alix[mcx[j]] and beta == 1:
c[i + ${mcx[j]}*ldc] += csub[${j}];
store_c(&c[i + ${mcx[j]}*ldc], gimmik_vadd(load_c(&c[i + ${mcx[j]}*ldc]), csub[${j}]));
% elif kx == alix[mcx[j]]:
c[i + ${mcx[j]}*ldc] = csub[${j}] + ${beta}*c[i + ${mcx[j]}*ldc];
store_c(&c[i + ${mcx[j]}*ldc], gimmik_vadd(csub[${j}], gimmik_vmul(${beta}, load_c(&c[i + ${mcx[j]}*ldc]))));
% endif
% endfor
% endfor
## Handle rows of A which are all zero
% if loop.parent.last:
% for j, jx in enumerate(afix):
% if jx == -1 and j % msplit == cid and beta == 0:
c[i + ${j}*ldc] = make_zero();
store_c(&c[i + ${j}*ldc], make_zero());
% elif jx == -1 and j % msplit == cid and beta != 1:
c[i + ${j}*ldc] *= ${beta};
store_c(&c[i + ${j}*ldc], gimmik_vmul(${beta}, load_c(&c[i + ${j}*ldc])));
% endif
% endfor
% endif
Expand Down
Loading