Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 111 additions & 4 deletions gimmik/hip.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,130 @@ class HIPMatMul(MatMul):
basemeta = {'block': (128, 1, 1), 'width': 1, 'shared': 0}

def _kernel_generators(self, dtype, dsize, *, gcn_arch=None, warp_size=64):
    """Yield ``(name, args, meta)`` tuples for every candidate HIP kernel.

    Each candidate is passed through a local filter which drops any
    variant whose thread block exceeds 1024 threads or whose shared
    memory request exceeds 64 KiB, so nothing is yielded unfiltered.

    Args:
        dtype: Name of the scalar element type; the width variants use
            ``f'{dtype}{width}'`` as their element type.
        dsize: Size of one ``dtype`` element, used to compute the shared
            memory request of the split kernels.
        gcn_arch: Accepted for interface compatibility; not used here.
        warp_size: Accepted for interface compatibility; not used here.

    Yields:
        Tuples of template name, template argument dict and meta dict.
    """
    max_block_threads = 1024
    max_shared = 64*1024

    def emit(name, args, meta):
        # Fall back to the class-wide defaults for anything meta omits
        block = meta.get('block', self.basemeta['block'])
        shared = meta.get('shared', self.basemeta['shared'])
        threads = block[0]*block[1]*block[2]

        if threads <= max_block_threads and shared <= max_shared:
            yield (name, args, meta)

    blkx = self.basemeta['block'][0]

    # B loading, C streaming kernel
    yield from emit('cstream', {'blockx': blkx}, {})

    # B streaming, C accumulation kernel
    yield from emit('bstream', {'blockx': blkx}, {})

    # Four-way m-split B streaming, C accumulation kernel
    ms, bsz, blkx = 4, 24, 64
    args = {'msplit': ms, 'bsz': bsz, 'blockx': blkx}
    meta = {'block': (blkx, ms, 1), 'shared': 2*bsz*blkx*dsize}
    yield from emit('bstream-msplit', args, meta)

    # Two-way k-split B loading, C streaming kernel
    ks, csz, blkx = 2, 24, 64
    args = {'ksplit': ks, 'csz': csz, 'blockx': blkx}
    meta = {'block': (blkx, ks, 1), 'shared': (ks - 1)*csz*blkx*dsize}
    yield from emit('cstream-ksplit', args, meta)

    # Tuned HIP variants
    msplits, ksplits = [4, 8], [2, 4]
    bsz, csz, blkx = 8, 8, 64

    # Two-wide vector variants are only generated for an even alignment
    width = 2 if self.aligne is not None and self.aligne % 2 == 0 else 1

    # B loading, C streaming kernel
    args = {'blockx': blkx}
    meta = {'block': (blkx, 1, 1), 'desc': f'cstream/x{blkx}'}
    yield from emit('cstream', args, meta)

    # B streaming, C accumulation kernel
    meta = {'block': (blkx, 1, 1), 'desc': f'bstream/x{blkx}'}
    yield from emit('bstream', args, meta)

    for ms in msplits:
        # m-split B streaming, C accumulation kernel
        args = {'msplit': ms, 'bsz': bsz, 'blockx': blkx}
        shared = 2*bsz*blkx*dsize
        meta = {'block': (blkx, ms, 1), 'shared': shared,
                'desc': f'bstream-msplit/m{ms}-b{bsz}-x{blkx}'}
        yield from emit('bstream-msplit', args, meta)

    for ks in ksplits:
        # k-split B loading, C streaming kernel
        args = {'ksplit': ks, 'csz': csz, 'blockx': blkx}
        shared = (ks - 1)*csz*blkx*dsize
        meta = {'block': (blkx, ks, 1), 'shared': shared,
                'desc': f'cstream-ksplit/k{ks}-c{csz}-x{blkx}'}
        yield from emit('cstream-ksplit', args, meta)

    # B loading, C preloading, C streaming kernel
    args = {'blockx': blkx}
    meta = {'block': (blkx, 1, 1), 'desc': f'cstream-preload-c/x{blkx}'}
    yield from emit('cstream-preload-c', args, meta)

    # B streaming, C preloading, C accumulation kernel
    meta = {'block': (blkx, 1, 1), 'desc': f'bstream-preload-c/x{blkx}'}
    yield from emit('bstream-preload-c', args, meta)

    if width > 1:
        args = {'dtype': f'{dtype}{width}', 'width': width,
                'blockx': blkx}
        meta = {'block': (blkx, 1, 1), 'width': width,
                'desc': f'cstream-width-preload-c/w{width}-x{blkx}'}
        yield from emit('cstream-width-preload-c', args, meta)

        meta = {'block': (blkx, 1, 1), 'width': width,
                'desc': f'bstream-width-preload-c/w{width}-x{blkx}'}
        yield from emit('bstream-width-preload-c', args, meta)

    for ms in msplits:
        # m-split B streaming, C preloading, C accumulation kernel
        args = {'msplit': ms, 'bsz': bsz, 'blockx': blkx}
        shared = 2*bsz*blkx*dsize
        meta = {'block': (blkx, ms, 1), 'shared': shared,
                'desc': f'bstream-msplit-preload-c/m{ms}-b{bsz}-x{blkx}'}
        yield from emit('bstream-msplit-preload-c', args, meta)

        if width > 1:
            # Each element is now `width` scalars wide, so the shared
            # memory request is scaled by the same factor
            args = {'msplit': ms, 'bsz': bsz, 'blockx': blkx,
                    'dtype': f'{dtype}{width}', 'width': width}
            meta = {
                'block': (blkx, ms, 1), 'shared': shared*width,
                'width': width,
                'desc': (
                    f'bstream-msplit-width-preload-c/w{width}-'
                    f'm{ms}-b{bsz}-x{blkx}'
                )
            }
            yield from emit('bstream-msplit-width-preload-c', args, meta)

    for ks in ksplits:
        # k-split B loading, C preloading, C streaming kernel
        args = {'ksplit': ks, 'csz': csz, 'blockx': blkx}
        shared = (ks - 1)*csz*blkx*dsize
        meta = {
            'block': (blkx, ks, 1), 'shared': shared,
            'desc': f'cstream-ksplit-preload-c/k{ks}-c{csz}-x{blkx}'
        }
        yield from emit('cstream-ksplit-preload-c', args, meta)

        if width > 1:
            # Shared memory scales with the vector width, as above
            args = {'ksplit': ks, 'csz': csz, 'blockx': blkx,
                    'dtype': f'{dtype}{width}', 'width': width}
            meta = {
                'block': (blkx, ks, 1), 'shared': shared*width,
                'width': width,
                'desc': (
                    f'cstream-ksplit-width-preload-c/w{width}-'
                    f'k{ks}-c{csz}-x{blkx}'
                )
            }
            yield from emit('cstream-ksplit-width-preload-c', args, meta)

def _process_meta(self, meta):
if self.n is not None:
Expand Down
32 changes: 32 additions & 0 deletions gimmik/kernels/hip/base.mako
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,36 @@ static inline __device__ ${dtype} make_zero()
{ return 0; }
% endif

## nt_store_c: write v to *p through __builtin_nontemporal_store.
## When the dtype name ends in '4' or '2' the store is issued once per
## component (.x/.y/.z/.w or .x/.y); otherwise a single store of v is
## emitted.
## NOTE(review): assumes a dtype ending in '2'/'4' is a vector type that
## exposes those members (e.g. float2/double4) -- confirm against callers.
static inline __device__ void
nt_store_c(${dtype}* p, ${dtype} v)
{
% if dtype.endswith('4'):
__builtin_nontemporal_store(v.x, &p->x);
__builtin_nontemporal_store(v.y, &p->y);
__builtin_nontemporal_store(v.z, &p->z);
__builtin_nontemporal_store(v.w, &p->w);
% elif dtype.endswith('2'):
__builtin_nontemporal_store(v.x, &p->x);
__builtin_nontemporal_store(v.y, &p->y);
% else:
__builtin_nontemporal_store(v, p);
% endif
}

## nt_load_c: read *p through __builtin_nontemporal_load and return it.
## When the dtype name ends in '4' or '2' each component is loaded
## separately and the result is rebuilt with make_<dtype>(...); otherwise
## a single load of *p is emitted.
## NOTE(review): relies on a make_<dtype> constructor being available for
## the two- and four-component dtypes -- confirm for every dtype used.
static inline __device__ ${dtype}
nt_load_c(const ${dtype}* p)
{
% if dtype.endswith('4'):
return make_${dtype}(__builtin_nontemporal_load(&p->x),
__builtin_nontemporal_load(&p->y),
__builtin_nontemporal_load(&p->z),
__builtin_nontemporal_load(&p->w));
% elif dtype.endswith('2'):
return make_${dtype}(__builtin_nontemporal_load(&p->x),
__builtin_nontemporal_load(&p->y));
% else:
return __builtin_nontemporal_load(p);
% endif
}

${next.body()}
98 changes: 98 additions & 0 deletions gimmik/kernels/hip/bstream-msplit-preload-c.mako
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
<%inherit file='base'/>

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we merge this into the msplit kernel, with preload as an option, using % if/else as appropriate to switch between the two?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. I merged the preload-C path into the existing msplit template behind a preload option, and removed the separate preload-C template.


## m-split B streaming kernel with C preloading.  Each threadIdx.y lane
## owns one group of rows (mx[cid]); the C accumulators for those rows are
## seeded up front (zero, C, or beta*C) and the finished dot products are
## written back with nt_store_c.  B is staged through a two-slot shared
## memory buffer so the next chunk is loaded while the current one is used.
<%
# mx[cid] is iterated below as the row indices handled by lane cid
# NOTE(review): partition() is defined elsewhere; presumably it deals the
# rows of A out across `msplit` groups -- confirm the exact distribution
mx = partition(A, into=msplit, by='rows')
# bchunks[bb] is iterated below as the rows of B staged in step bb
# NOTE(review): chunk() is defined elsewhere; presumably it splits bix
# into pieces of at most bsz entries -- confirm
bchunks = chunk(bix, bsz)
%>

__global__ __launch_bounds__(${blockx*msplit}) void
% if n is None:
${kname}(int n,
const ${dtype}* __restrict__ b, int ldb,
${dtype}* __restrict__ c, int ldc)
{
## NOTE(review): this runtime path rounds n UP to a multiple of width,
## whereas the compile-time path below uses ceil(n / width); confirm
## which unit the `i < n` guards are meant to compare against
% if width > 1:
n = ((n + ${width} - 1) / ${width}) * ${width};
ldb /= ${width};
ldc /= ${width};
% endif
% else:
${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c)
{
## Compile-time sizes: n is ceil(n / width); the leading dimensions are
## declared long long once k*ldb (or m*ldc) reaches width*2**31
const int n = ${-(-n // width)};
const ${'long long' if k*ldb >= width*2**31 else 'int'} ldb = ${ldb // width};
const ${'long long' if m*ldc >= width*2**31 else 'int'} ldc = ${ldc // width};
% endif
int i = blockDim.x*blockIdx.x + threadIdx.x;

## csub holds one accumulator per row of the largest row group
${dtype} bv, csub[${-(-m // msplit)}];
__shared__ ${dtype} bsub[2][${bsz}][${blockx}];

## Fill the initial shared memory block
% for cid in range(msplit):
if (i < n && threadIdx.y == ${cid})
{
## Entries of the first chunk are shared out between the lanes by
## position, so lane cid loads those with index % msplit == cid
% for kx in bchunks[0]:
% if loop.index % msplit == cid:
bsub[0][${loop.index}][threadIdx.x] = b[i + ${kx}*ldb];
% endif
% endfor

## Preload C values for active rows owned by this m-split lane
% for j, jx in enumerate(mx[cid]):
## Rows with afix == -1 are skipped here and handled after the last chunk
% if afix[jx] != -1:
% if beta == 0:
csub[${j}] = make_zero();
% elif beta == 1:
csub[${j}] = nt_load_c(&c[i + ${jx}*ldc]);
% else:
csub[${j}] = ${beta}*nt_load_c(&c[i + ${jx}*ldc]);
% endif
% endif
% endfor
}
% endfor
__syncthreads();

## Iterate over each row-chunk of B
% for bb in range(len(bchunks)):
## Iterate over each row-chunk of C
% for cid, mcx in enumerate(mx):
if (i < n && threadIdx.y == ${cid})
{
## Start filling the next shared memory block
## (loop.parent is the enclosing chunk loop, so this is skipped on the
## final chunk; writes go to the buffer slot not currently being read)
% if not loop.parent.last:
% for kx in bchunks[bb + 1]:
% if loop.index % msplit == cid:
bsub[${(bb + 1) % 2}][${loop.index}][threadIdx.x] = b[i + ${kx}*ldb];
% endif
% endfor
% endif
## Accumulate our dot products
% for kx in bchunks[bb]:
bv = bsub[${bb % 2}][${loop.index}][threadIdx.x];
% for j, jx in enumerate(A[mcx, kx]):
## Zero entries of A generate no code; csub was seeded during the preload
## so every contribution is a plain +=
% if jx != 0:
csub[${j}] += ${jx}*bv;
% endif
## If we're done with this dot product then store to global
% if kx == alix[mcx[j]]:
nt_store_c(&c[i + ${mcx[j]}*ldc], csub[${j}]);
% endif
% endfor
% endfor
## Handle rows of A which are all zero
## (emitted once, on the final chunk; with beta == 1 such rows need no
## update and so generate no code)
## NOTE(review): ownership is decided by j % msplit, which presumes that
## partition() assigns row j to group j % msplit -- confirm
% if loop.parent.last:
% for j, jx in enumerate(afix):
% if jx == -1 and j % msplit == cid and beta == 0:
nt_store_c(&c[i + ${j}*ldc], make_zero());
% elif jx == -1 and j % msplit == cid and beta != 1:
nt_store_c(&c[i + ${j}*ldc], ${beta}*nt_load_c(&c[i + ${j}*ldc]));
% endif
% endfor
% endif
}
% endfor
## Barrier before the next step overwrites the buffer slot just read
__syncthreads();
% endfor
}
Loading