Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion gimmik/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from gimmik.hip import HIPMatMul
from gimmik.metal import MetalMatMul
from gimmik.opencl import OpenCLMatMul
from gimmik.ptx import PTXMatMul


def generate_mm(mat, dtype, platform, alpha=1.0, beta=0.0, funcn='gimmik_mm',
Expand All @@ -22,7 +23,8 @@ def generate_mm(mat, dtype, platform, alpha=1.0, beta=0.0, funcn='gimmik_mm',
'cuda': CUDAMatMul,
'ispc': ISPCMatMul,
'hip': HIPMatMul,
'opencl': OpenCLMatMul
'opencl': OpenCLMatMul,
'ptx': PTXMatMul
}

mm = platmap[platform](alpha*mat, beta, None, n, ldb, ldc)
Expand Down
3 changes: 2 additions & 1 deletion gimmik/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ def _render_kernel(self, dtype, tplname, tplargs):
src = tpl.render(**tplargs)

# At single precision suffix all floating point constants by 'f'
if dtype == 'float':
# (PTX doesn't use an 'f' suffix for FP literals)
if dtype == 'float' and self.platform != 'ptx':

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a class attribute such as _needs_fp32_suffix = True|False so that this code does not need a PTX-specific check.

src = re.sub(r'(?=\d*[.eE])(?=\.?\d)\d*\.?\d*(?:[eE][+-]?\d+)?',
r'\g<0>f', src)

Expand Down
4 changes: 4 additions & 0 deletions gimmik/kernels/ptx/base.mako
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
## PTX module preamble.  The target name is built from cc[0] and cc[1],
## with an 'a' suffix appended when cc[0] >= 9.
## NOTE(review): cc is indexed unconditionally here, whereas
## bstream-msplit.mako tests `cc is not None` -- confirm cc is never None.
.version 8.6
.target sm_${cc[0]}${cc[1]}${'a' if cc[0] >= 9 else ''}
.address_size 64
## The body of the inheriting template follows the preamble
${next.body()}
261 changes: 261 additions & 0 deletions gimmik/kernels/ptx/bstream-msplit.mako
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
<%inherit file='base'/>

<%
# Row groups of A (presumably msplit of them -- see partition) and the
# entries of bix_list taken bsz at a time
mx = partition(A, into=msplit, by='rows')
bchunks = chunk(bix_list, bsz)

# Largest number of rows held by any one group; sizes the csub registers
m_per_group = max(map(len, mx))

def bsub_off(buf, idx):
    # Byte offset of slot idx within buffer buf: every slot spans
    # blockx items of dwidth_i bytes and every buffer holds bsz slots
    return blockx * dwidth_i * (buf * bsz + idx)

# Two buffers; the offset of buffer 1 is the size of a single buffer
bsub_bytes = 2 * bsub_off(1, 0)

# The cp.async path needs a known cc of at least (8, 0) and an item
# width of 4 or 8
use_cpasync = False
if cc is not None:
    use_cpasync = (cc[0], cc[1]) >= (8, 0) and dwidth_i in (4, 8)
Comment thread
WillTrojak marked this conversation as resolved.
Outdated
%>

## Kernel entry point.  When n is None at generation time, n and the row
## strides ldb/ldc are taken as kernel parameters; otherwise only the B and
## C pointers are passed and n, ldb and ldc are substituted as constants.
% if n is None:
.visible .entry ${kname}(.param .u32 _n,
.param .u64 _b,
.param .u32 _ldb,
.param .u64 _c,
.param .u32 _ldc)
{
.reg .u32 ldb, ldc;
ld.param.u32 ldb, [_ldb];
ld.param.u32 ldc, [_ldc];
% else:
.visible .entry ${kname}(.param .u64 _b,
.param .u64 _c)
{
% endif
## n, the global index id of this thread, and its x/y thread indices
.reg .u32 n, id, tid_x, tid_y;

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ensure we raise an error higher up if n is too big.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checking here

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't handle n being too large in any of the other backends.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://github.com/PyFR/GiMMiK/blob/master/gimmik/kernels/cuda/cstream.mako#L20 in the embedded case we do (argument case doesn't but that is not currently used for CUDA).

## Pointer registers: b/c hold the kernel arguments, b_base/c_base the
## per-thread bases and bsub_thread this thread's address within _bsub
.reg .u64 b, c, b_base, c_base, bsub_thread;
% if use_cpasync:
## 32-bit shared address used as the cp.async destination
.reg .u32 bsub_sm_thread;
% endif
## bv holds the current B value; csub holds one accumulator per row of a group
.reg .${pftype} bv, csub<${m_per_group}>;
.reg .pred p1, p_skip;
## Shared staging area for B, bsub_bytes bytes in size
.shared .align 8 .b8 _bsub[${bsub_bytes}];

## n is either read from the parameter list or emitted as an immediate
% if n is None:
ld.param.u32 n, [_n];
% else:
mov.u32 n, ${n};
% endif
ld.param.u64 b, [_b];
ld.param.u64 c, [_c];

## id = ctaid.x * blockx + tid.x
{
.reg .u32 _ctaid_x;
mov.u32 _ctaid_x, %ctaid.x;
mov.u32 tid_x, %tid.x;
mov.u32 tid_y, %tid.y;
mad.lo.u32 id, _ctaid_x, ${blockx}, tid_x;
}

## Threads with id >= n branch straight to the exit label
setp.ge.u32 p1, id, n;
@p1 bra $L_EXIT;

## Convert the B and C pointers to global-space addresses
cvta.to.global.u64 b, b;
cvta.to.global.u64 c, c;

## b_base = b + id*dwidth_i and c_base = c + id*dwidth_i
{
.reg .u64 _id64;
cvt.u64.u32 _id64, id;
mad.lo.u64 b_base, _id64, ${dwidth_i}, b;
mad.lo.u64 c_base, _id64, ${dwidth_i}, c;
}

## bsub_thread = address of _bsub + tid_x*dwidth_i
{
.reg .u64 _tx_off;
mul.wide.u32 _tx_off, tid_x, ${dwidth_i};
mov.u64 bsub_thread, _bsub;
add.u64 bsub_thread, bsub_thread, _tx_off;
}
% if use_cpasync:
## bsub_thread was formed by a mov of the .shared symbol _bsub and is used
## directly by ld.shared/st.shared, so it already is a shared-space
## address.  cvta.to.shared converts a generic address and must not be
## applied to it; narrowing to the 32-bit register is all that is needed.
cvt.u32.u64 bsub_sm_thread, bsub_thread;
% endif

% for cid, mcx in enumerate(mx):
## cid = ${cid}, rows ${mcx}
## Only threads with tid_y == cid run this group's code; all others jump
## to the group's end label
setp.ne.u32 p_skip, tid_y, ${cid};
@p_skip bra $L_END_CID_${cid};

## Fill buffer 0 with chunk 0 of B.  Entry idx of the chunk is loaded by
## the group for which idx % msplit == cid, either with cp.async or with a
## load to a register followed by a shared store.
% for idx, kx in [(i, k) for i, k in enumerate(bchunks[0]) if i % msplit == cid]:
% if use_cpasync:
% if n is None:
{
    .reg .u64 _bptr;
    ## _bptr = b_base + ldb*kx*dwidth_i.  The widening multiply-add forms
    ## the whole product in 64 bits; a 32-bit ldb*kx would wrap silently
    ## for large ldb.
    mad.wide.u32 _bptr, ldb, ${kx*dwidth_i}, b_base;
    cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(0, idx)}], [_bptr], ${dwidth_i};
}
% else:
cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(0, idx)}], [b_base + ${ldb*kx*dwidth_i}], ${dwidth_i};
% endif
% else:
{
    .reg .${pftype} _bv;
% if n is None:
    .reg .u64 _bptr;
    ## Same 64-bit address computation as on the cp.async path
    mad.wide.u32 _bptr, ldb, ${kx*dwidth_i}, b_base;
    ld.weak.global.cg.${pftype} _bv, [_bptr];
% else:
    ld.weak.global.cg.${pftype} _bv, [b_base + ${ldb*kx*dwidth_i}];
% endif
    st.shared.${pftype} [bsub_thread + ${bsub_off(0, idx)}], _bv;
}
% endif
% endfor
% if use_cpasync:
## Wait for this thread's copies before meeting the other threads
cp.async.commit_group;
cp.async.wait_all;
% endif
bar.sync 0;

## Main loop over B-chunks (double-buffered)
% for bb in range(len(bchunks)):
<%
    # Chunk bb is read from buffer bb mod 2; the other buffer is the
    # prefetch target
    buf_cur = bb & 1
    buf_next = 1 - buf_cur
%>
## Unless this is the last chunk, start loading chunk bb + 1 into buffer
## buf_next.  As for chunk 0, entry idx is loaded by the group for which
## idx % msplit == cid.
% if not loop.last:
% for idx, kx in [(i, k) for i, k in enumerate(bchunks[bb + 1]) if i % msplit == cid]:
% if use_cpasync:
% if n is None:
{
    .reg .u64 _bptr;
    ## _bptr = b_base + ldb*kx*dwidth_i, formed with a widening
    ## multiply-add so that the row offset cannot wrap at 32 bits
    mad.wide.u32 _bptr, ldb, ${kx*dwidth_i}, b_base;
    cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(buf_next, idx)}], [_bptr], ${dwidth_i};
}
% else:
cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(buf_next, idx)}], [b_base + ${ldb*kx*dwidth_i}], ${dwidth_i};
% endif
% else:
{
    .reg .${pftype} _bv;
% if n is None:
    .reg .u64 _bptr;
    ## Same 64-bit address computation as on the cp.async path
    mad.wide.u32 _bptr, ldb, ${kx*dwidth_i}, b_base;
    ld.weak.global.cg.${pftype} _bv, [_bptr];
% else:
    ld.weak.global.cg.${pftype} _bv, [b_base + ${ldb*kx*dwidth_i}];
% endif
    st.shared.${pftype} [bsub_thread + ${bsub_off(buf_next, idx)}], _bv;
}
% endif
% endfor
% if use_cpasync:
cp.async.commit_group;
% endif
% endif

% for idx, kx in enumerate(bchunks[bb]):
ld.shared.${pftype} bv, [bsub_thread + ${bsub_off(buf_cur, idx)}];
% for j, row_j in enumerate(mcx):
<% jx = A[row_j, kx] %>

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See if NumPy fancy indexing (A[mcx, kx]) can be used in the for loop.

## Zero entries of A emit nothing.  At column afix[row_j] the accumulator is
## initialised with a plain product; any other non-zero entry is folded in
## with a fused multiply-add.
% if jx != 0:
% if kx == afix[row_j]:
mul.${pftype} csub${j}, bv, ${jx};
% else:
fma.rn.${pftype} csub${j}, bv, ${jx}, csub${j};
% endif
% endif
## At column alix[row_j] the accumulator for row_j is written out to C.
## With beta_zero it is stored as is; otherwise the existing C value is
## loaded, scaled by beta and added to the accumulator.
% if kx == alix[row_j]:
% if beta_zero:
% if n is None:
{
    .reg .u64 _cptr;
    ## _cptr = c_base + ldc*row_j*dwidth_i, formed with a widening
    ## multiply-add so that ldc*row_j cannot wrap at 32 bits
    mad.wide.u32 _cptr, ldc, ${row_j*dwidth_i}, c_base;
    st.weak.global.cg.${pftype} [_cptr], csub${j};
}
% else:
st.weak.global.cg.${pftype} [c_base + ${ldc*row_j*dwidth_i}], csub${j};
% endif
% else:
{
    .reg .${pftype} _ctmp;
% if n is None:
    .reg .u64 _cptr;
    ## Same 64-bit address computation as on the beta_zero path
    mad.wide.u32 _cptr, ldc, ${row_j*dwidth_i}, c_base;
    ld.weak.global.cg.${pftype} _ctmp, [_cptr];
    fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, csub${j};
    st.weak.global.${pftype} [_cptr], _ctmp;
% else:
    ld.weak.global.cg.${pftype} _ctmp, [c_base + ${ldc*row_j*dwidth_i}];
    fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, csub${j};
    st.weak.global.${pftype} [c_base + ${ldc*row_j*dwidth_i}], _ctmp;
% endif
}
% endif
% endif
% endfor
% endfor
## On the cp.async path, wait for the outstanding prefetch (none is issued
## for the last chunk) before the barrier that ends this chunk
% if use_cpasync and not loop.last:
cp.async.wait_all;
% endif
bar.sync 0;
% endfor
## End of Main loop over B-chunks

## Handle zero rows in this cid's group
## Rows with afix[row_j] == -1 are handled here rather than in the main
## loop.  With beta_zero the C entry is set to fzero; with any beta other
## than 1 it is scaled by beta in place; with beta equal to 1 nothing is
## emitted.
% if has_zero_rows:
% for row_j in mcx:
% if afix[row_j] == -1:
% if beta_zero:
{
    .reg .${pftype} _tmp;
% if n is None:
    .reg .u64 _cptr;
% endif
    mov.${pftype} _tmp, ${fzero};
% if n is None:
    ## _cptr = c_base + ldc*row_j*dwidth_i, formed with a widening
    ## multiply-add so that ldc*row_j cannot wrap at 32 bits
    mad.wide.u32 _cptr, ldc, ${row_j*dwidth_i}, c_base;
    st.weak.global.cg.${pftype} [_cptr], _tmp;
% else:
    st.weak.global.cg.${pftype} [c_base + ${ldc*row_j*dwidth_i}], _tmp;
% endif
}
% elif beta != 1:
{
    .reg .${pftype} _tmp;
% if n is None:
    .reg .u64 _cptr;
    ## Same 64-bit address computation as on the beta_zero path
    mad.wide.u32 _cptr, ldc, ${row_j*dwidth_i}, c_base;
    ld.weak.global.cg.${pftype} _tmp, [_cptr];
    mul.${pftype} _tmp, _tmp, ${float(beta)};
    st.weak.global.${pftype} [_cptr], _tmp;
% else:
    ld.weak.global.cg.${pftype} _tmp, [c_base + ${ldc*row_j*dwidth_i}];
    mul.${pftype} _tmp, _tmp, ${float(beta)};
    st.weak.global.${pftype} [c_base + ${ldc*row_j*dwidth_i}], _tmp;
% endif
}
% endif
% endif
% endfor
% endif

$L_END_CID_${cid}:
% endfor

## Common exit: reached by fall-through after the last group and by the
## early branch taken when id >= n
$L_EXIT:
ret;
}
Loading