Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions kernels/hgemm_splitk.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from flydsl._mlir.dialects import fly, llvm, memref, scf
from flydsl.compiler.kernel_function import CompilationContext
from flydsl.expr import arith, buffer_ops, const_expr, gpu, range_constexpr, rocdl, vector
from flydsl.expr.typing import T
from flydsl.expr.typing import T, as_ir_value
from flydsl.runtime.device import get_rocm_arch
from flydsl.utils.smem_allocator import SMEM_CAPACITY_MAP, SmemAllocator, SmemPtr
from kernels.tensor_shim import GTensor, STensor, _run_compiled, get_dtype_in_kernel
Expand Down Expand Up @@ -298,7 +298,7 @@ def zero_c():
# zero c if current block is the first block
is_t0_cond = arith.cmpi(arith.CmpIPredicate.eq, fx.Index(tid), fx.Index(0))
cond_ks0 = arith.cmpi(arith.CmpIPredicate.eq, ks_idx, fx.Index(0))
cond_ks0_if = scf.IfOp(cond_ks0, results_=[], has_else=False)
cond_ks0_if = scf.IfOp(as_ir_value(cond_ks0), results_=[], has_else=False)
with ir.InsertionPoint(cond_ks0_if.then_block):
zero_vec = vector.broadcast(T.vec(LDG_VEC_SIZE, dtype_), c_zero_d)
for i in range_constexpr(LDG_REG_C_COUNT):
Expand All @@ -310,7 +310,7 @@ def zero_c():
if const_expr(HAS_BIAS):
init_vec = BIAS_.vec_load((n_offset + n_local_idx,), LDG_VEC_SIZE)
cond_boundary = arith.cmpi(arith.CmpIPredicate.ult, row_idx, fx.Index(m))
cond_boundary_if = scf.IfOp(cond_boundary, results_=[], has_else=False)
cond_boundary_if = scf.IfOp(as_ir_value(cond_boundary), results_=[], has_else=False)
with ir.InsertionPoint(cond_boundary_if.then_block):
bytes_offset = C_.linear_offset((row_idx, n_offset + n_local_idx))
bytes_offset_i32 = arith.index_cast(T.i32, bytes_offset)
Expand All @@ -325,7 +325,7 @@ def zero_c():
scf.YieldOp([])
gpu.barrier()
# trigger signal when zeroc is done by the first arrived block
is_t0_cond_if = scf.IfOp(is_t0_cond, results_=[], has_else=False)
is_t0_cond_if = scf.IfOp(as_ir_value(is_t0_cond), results_=[], has_else=False)
with ir.InsertionPoint(is_t0_cond_if.then_block):
signal_ptr = get_llvm_ptr(signal, signal_idx, 4)
llvm.InlineAsmOp(
Expand All @@ -342,7 +342,7 @@ def zero_c():
def split_k_barrier():
# spin-wait until signal triggered
is_t0_cond = arith.cmpi(arith.CmpIPredicate.eq, fx.Index(tid), fx.Index(0))
is_t0_cond_if = scf.IfOp(is_t0_cond, results_=[], has_else=False)
is_t0_cond_if = scf.IfOp(as_ir_value(is_t0_cond), results_=[], has_else=False)
with ir.InsertionPoint(is_t0_cond_if.then_block):
init_cur = arith.constant(0, type=T.i32)
w = scf.WhileOp([T.i32], [init_cur])
Expand All @@ -367,7 +367,7 @@ def split_k_barrier():
rocdl.sched_barrier(0)
gpu.barrier()
# clean semaphore and signal if this is the last block within split-k group
is_t0_cond_if = scf.IfOp(is_t0_cond, results_=[], has_else=False)
is_t0_cond_if = scf.IfOp(as_ir_value(is_t0_cond), results_=[], has_else=False)
with ir.InsertionPoint(is_t0_cond_if.then_block):
semaphore_ptr = get_llvm_ptr(semaphore, signal_idx, 4)
arrive_idx = llvm.AtomicRMWOp(
Expand All @@ -379,7 +379,7 @@ def split_k_barrier():
alignment=4,
).result
cond_ksl = arith.cmpi(arith.CmpIPredicate.eq, fx.Index(arrive_idx), fx.Index(SPLIT_K - 1))
cond_ksl_if = scf.IfOp(cond_ksl, results_=[], has_else=False)
cond_ksl_if = scf.IfOp(as_ir_value(cond_ksl), results_=[], has_else=False)
with ir.InsertionPoint(cond_ksl_if.then_block):
semaphore_[signal_idx] = arith.constant(0, type=T.i32)
signal_[signal_idx] = arith.constant(0, type=T.i32)
Expand Down Expand Up @@ -682,7 +682,7 @@ def hot_loop_scheduler():
m_global_idx = m_offset + m_local_idx
n_global_idx = n_offset + n_local_idx
cond_boundary = arith.cmpi(arith.CmpIPredicate.ult, m_global_idx, fx.Index(m))
cond_boundary_if = scf.IfOp(cond_boundary, results_=[], has_else=False)
cond_boundary_if = scf.IfOp(as_ir_value(cond_boundary), results_=[], has_else=False)
with ir.InsertionPoint(cond_boundary_if.then_block):
pk_val = cs_.vec_load((0, m_local_idx, n_local_idx), LDG_VEC_SIZE)
for ksi in range_constexpr(1, BLOCK_K_WARPS):
Expand Down Expand Up @@ -713,7 +713,7 @@ def hot_loop_scheduler():
n_local_idx = fx.Index(global_tid % LDG_C_X_THREADS * LDG_VEC_SIZE)
m_global_idx = m_offset + m_local_idx
cond_boundary = arith.cmpi(arith.CmpIPredicate.ult, m_global_idx, fx.Index(m))
cond_boundary_if = scf.IfOp(cond_boundary, results_=[], has_else=False)
cond_boundary_if = scf.IfOp(as_ir_value(cond_boundary), results_=[], has_else=False)
with ir.InsertionPoint(cond_boundary_if.then_block):
vec = cs_.vec_load((0, m_local_idx, n_local_idx), LDG_VEC_SIZE)
for ksi in range_constexpr(1, BLOCK_K_WARPS):
Expand Down
41 changes: 31 additions & 10 deletions python/flydsl/compiler/ast_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import ast
import contextlib
import copy
import difflib
import functools
import inspect
Expand Down Expand Up @@ -394,24 +395,24 @@ def visit_With(self, node: ast.With):
class RewriteBoolOps(Transformer):
@staticmethod
def dsl_and_(lhs, rhs):
if hasattr(lhs, "__fly_and__"):
return lhs.__fly_and__(rhs)
if hasattr(rhs, "__fly_and__"):
return rhs.__fly_and__(lhs)
if hasattr(lhs, "__dsl_and__"):
return lhs.__dsl_and__(rhs)
if hasattr(rhs, "__dsl_and__"):
return rhs.__dsl_and__(lhs)
return lhs and rhs

@staticmethod
def dsl_or_(lhs, rhs):
if hasattr(lhs, "__fly_or__"):
return lhs.__fly_or__(rhs)
if hasattr(rhs, "__fly_or__"):
return rhs.__fly_or__(lhs)
if hasattr(lhs, "__dsl_or__"):
return lhs.__dsl_or__(rhs)
if hasattr(rhs, "__dsl_or__"):
return rhs.__dsl_or__(lhs)
return lhs or rhs

@staticmethod
def dsl_not_(x):
if hasattr(x, "__fly_not__"):
return x.__fly_not__()
if hasattr(x, "__dsl_not__"):
return x.__dsl_not__()
return not x

@classmethod
Expand All @@ -422,6 +423,26 @@ def rewrite_globals(cls):
"dsl_not_": cls.dsl_not_,
}

def visit_Compare(self, node: ast.Compare):
    """Lower a chained comparison into an ``and`` of pairwise comparisons.

    A comparison with a single operator is returned unchanged (after its
    children have been visited). A chain such as ``a < b < c`` is rebuilt as
    ``(a < b) and (b < c)`` and handed to ``visit_BoolOp`` so it goes through
    the same lowering as a hand-written ``and``.

    Args:
        node: The ``ast.Compare`` node being visited.

    Returns:
        ``node`` itself for a single-operator comparison, otherwise whatever
        ``visit_BoolOp`` returns for the synthesized ``ast.BoolOp``.
    """
    node = self.generic_visit(node)
    if len(node.ops) == 1:
        return node

    # Every interior operand appears in two of the pairwise comparisons, so it
    # is evaluated twice -- the same repeated-evaluation convention that
    # visit_BoolOp already follows. Each use gets its own deep copy so no AST
    # node object is shared between the generated comparisons.
    terms = [node.left, *node.comparators]
    pairwise = [
        ast.copy_location(
            ast.Compare(
                left=copy.deepcopy(terms[pos]),
                ops=[copy.deepcopy(cmp_op)],
                comparators=[copy.deepcopy(terms[pos + 1])],
            ),
            node,
        )
        for pos, cmp_op in enumerate(node.ops)
    ]

    conjunction = ast.BoolOp(op=ast.And(), values=pairwise)
    ast.copy_location(conjunction, node)
    ast.fix_missing_locations(conjunction)
    return self.visit_BoolOp(conjunction)

def visit_BoolOp(self, node: ast.BoolOp):
node = self.generic_visit(node)

Expand Down
8 changes: 7 additions & 1 deletion python/flydsl/compiler/jit_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .._mlir.passmanager import PassManager
from ..expr.meta import tracing_context
from ..expr.typing import Constexpr, Stream
from ..expr.utils.arith import fastmath as fastmath_ctx
from ..utils import env, log
from .ast_rewriter import ASTRewriter
from .backends import compile_backend_name, get_backend
Expand All @@ -40,6 +41,7 @@
CompilationContext,
KernelFunction,
create_gpu_module,
effective_fastmath_hint,
func_def_location,
get_gpu_module_body,
)
Expand Down Expand Up @@ -1482,8 +1484,12 @@ def __call__(self, *args, **kwargs):
log().info(f"dsl_args={dsl_args}")
named_args = dict(zip(param_names, dsl_args))
named_args.update(constexpr_values)
fastmath_flag = effective_fastmath_hint(CompilationContext.get_compile_hints())
fastmath_scope = (
fastmath_ctx(fastmath_flag) if fastmath_flag is not None else nullcontext()
)
# Bound the call-site boundary at the jit body.
with tracing_context(self.func):
with tracing_context(self.func), fastmath_scope:
if bound_self is not None:
self.func(bound_self, **named_args)
else:
Expand Down
17 changes: 15 additions & 2 deletions python/flydsl/compiler/kernel_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@

import inspect
import threading
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from .._mlir import ir
from .._mlir.dialects import arith, gpu
from ..expr.meta import capture_user_location, file_location, tracing_context
from ..expr.typing import Constexpr
from ..expr.utils.arith import fastmath as fastmath_ctx
from .ast_rewriter import ASTRewriter
from .diagnostics import install_excepthook, warn_annotation_value_mismatch, warn_invalid_annotations
from .jit_argument import is_type_param_annotation, resolve_signature
Expand Down Expand Up @@ -243,6 +244,15 @@ def next_kernel_id(self) -> int:
return kid


def effective_fastmath_hint(hints: dict):
    """Resolve the ambient fastmath hint for traced JIT/kernel bodies.

    Args:
        hints: Compile-hint mapping. Two keys are consulted: ``"fastmath"``
            and ``"fast_fp_math"``.

    Returns:
        The value stored under ``"fastmath"`` whenever that key is present
        (even if the value is ``None`` or otherwise falsy); else ``"fast"``
        when ``"fast_fp_math"`` is truthy; else ``None``.
    """
    if "fastmath" not in hints:
        # No explicit flag: fall back to the boolean-style legacy switch.
        return "fast" if hints.get("fast_fp_math") else None
    # An explicit entry always wins over "fast_fp_math".
    return hints["fastmath"]


# =============================================================================
# Kernel Launcher
# =============================================================================
Expand Down Expand Up @@ -540,8 +550,11 @@ def _emit_kernel(self, ctx: CompilationContext, args: Tuple, kwargs: Dict, bound
idx += n

dsl_args.update(constexpr_values)

fastmath_flag = effective_fastmath_hint(CompilationContext.get_compile_hints())
fastmath_scope = fastmath_ctx(fastmath_flag) if fastmath_flag is not None else nullcontext()
# Bound the call-site boundary at the kernel body.
with tracing_context(self._func):
with tracing_context(self._func), fastmath_scope:
if bound_self is not None:
self._func(bound_self, **dsl_args)
else:
Expand Down
54 changes: 35 additions & 19 deletions python/flydsl/expr/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,18 @@
__all__ = [
"ArithValue", # Deprecated: will be removed in a future release
"_to_raw", # Deprecated: will be removed in a future release
"FastMathFlags",
"andi",
"constant",
"constant_vector",
"fastmath",
"index", # Deprecated: will be removed in a future release
"index_cast", # Deprecated: will be removed in a future release
"int_to_fp",
"maxnumf",
"minnumf",
"maximumf",
"minimumf",
"shli",
"sitofp",
"trunc_f",
Expand All @@ -35,14 +40,15 @@
]

# Override star-import cmpi/cmpf to accept Numeric types (Int32, etc.)
from .._mlir.dialects import arith as _mlir_arith
from .._mlir.dialects import arith
from .meta import dsl_loc_tracing
from .utils.arith import ( # noqa: F401
ArithValue,
_to_raw,
andi,
constant,
constant_vector,
fastmath,
index,
index_cast,
int_to_fp,
Expand All @@ -53,9 +59,12 @@
unwrap,
xori,
)
from .math import dsl_math_wrap_result
from .typing import as_ir_value


@dsl_loc_tracing
@dsl_math_wrap_result(exemplar="lhs")
def cmpi(predicate, lhs, rhs, **kwargs):
"""Integer comparison accepting DSL numeric types (Int32, ArithValue, etc.).

Expand All @@ -65,12 +74,13 @@ def cmpi(predicate, lhs, rhs, **kwargs):
rhs: Right-hand operand.

Returns:
An ``i1`` comparison result.
A ``Boolean`` (scalar) or ``Vector(Boolean)`` comparison result.
"""
return _mlir_arith.cmpi(predicate, _to_raw(lhs), _to_raw(rhs), **kwargs)
return arith.cmpi(predicate, as_ir_value(lhs), as_ir_value(rhs), **kwargs)


@dsl_loc_tracing
@dsl_math_wrap_result(exemplar="lhs")
def cmpf(predicate, lhs, rhs, **kwargs):
"""Floating-point comparison accepting DSL numeric types.

Expand All @@ -80,24 +90,30 @@ def cmpf(predicate, lhs, rhs, **kwargs):
rhs: Right-hand operand.

Returns:
An ``i1`` comparison result.
A ``Boolean`` (scalar) or ``Vector(Boolean)`` comparison result.
"""
return _mlir_arith.cmpf(predicate, _to_raw(lhs), _to_raw(rhs), **kwargs)
return arith.cmpf(predicate, as_ir_value(lhs), as_ir_value(rhs), **kwargs)


@dsl_loc_tracing
def maxnumf(a, b, **kwargs):
"""Floating-point maximum, returning the non-NaN operand when one input is NaN (libm ``fmax``).
@dsl_math_wrap_result
def maximumf(lhs, rhs, *, fastmath=None):
return arith.maximumf(as_ir_value(lhs), as_ir_value(rhs), fastmath=fastmath)

Accepts DSL numeric types (Float32, Vector, ...) and preserves the DSL type of ``a`` so the
result can be chained with further DSL operations (e.g. ``.shuffle_xor(...)``).
"""
from .numeric import Numeric
from .typing import Vector

result = _mlir_arith.maxnumf(_to_raw(a), _to_raw(b), **kwargs)
if isinstance(a, Vector):
return Vector(result, a.shape, a.dtype)
if isinstance(a, Numeric):
return Numeric.from_ir_type(result.type)(result)
return result

@dsl_loc_tracing
@dsl_math_wrap_result
def minimumf(lhs, rhs, *, fastmath=None):
    """Emit ``arith.minimumf`` on two operands.

    Both operands are converted with ``as_ir_value`` before the builder is
    called; the numeric semantics are those of the underlying
    ``arith.minimumf`` op.

    Args:
        lhs: Left-hand operand.
        rhs: Right-hand operand.
        fastmath: Optional fastmath flags forwarded to the builder.

    Returns:
        The builder result, as post-processed by ``dsl_math_wrap_result``.
    """
    lhs_ir = as_ir_value(lhs)
    rhs_ir = as_ir_value(rhs)
    return arith.minimumf(lhs_ir, rhs_ir, fastmath=fastmath)


@dsl_loc_tracing
@dsl_math_wrap_result
def maxnumf(lhs, rhs, *, fastmath=None):
    """Emit ``arith.maxnumf`` on two operands.

    Both operands are converted with ``as_ir_value`` before the builder is
    called; the numeric semantics are those of the underlying
    ``arith.maxnumf`` op.

    Args:
        lhs: Left-hand operand.
        rhs: Right-hand operand.
        fastmath: Optional fastmath flags forwarded to the builder.

    Returns:
        The builder result, as post-processed by ``dsl_math_wrap_result``.
    """
    lhs_ir = as_ir_value(lhs)
    rhs_ir = as_ir_value(rhs)
    return arith.maxnumf(lhs_ir, rhs_ir, fastmath=fastmath)


@dsl_loc_tracing
@dsl_math_wrap_result
def minnumf(lhs, rhs, *, fastmath=None):
    """Emit ``arith.minnumf`` on two operands.

    Both operands are converted with ``as_ir_value`` before the builder is
    called; the numeric semantics are those of the underlying
    ``arith.minnumf`` op.

    Args:
        lhs: Left-hand operand.
        rhs: Right-hand operand.
        fastmath: Optional fastmath flags forwarded to the builder.

    Returns:
        The builder result, as post-processed by ``dsl_math_wrap_result``.
    """
    lhs_ir = as_ir_value(lhs)
    rhs_ir = as_ir_value(rhs)
    return arith.minnumf(lhs_ir, rhs_ir, fastmath=fastmath)
34 changes: 32 additions & 2 deletions python/flydsl/expr/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
pred = fx.isnan(x)
"""

import inspect
from functools import wraps

from .._mlir import ir
from .._mlir.dialects import math
from .meta import dsl_loc_tracing
from .numeric import Numeric
from .typing import as_ir_value
from .utils.arith import current_fastmath

__all__ = [
"absf",
Expand Down Expand Up @@ -70,12 +72,40 @@
]


def dsl_math_wrap_result(fn):
def dsl_math_wrap_result(fn=None, *, exemplar=None):
"""Wrap raw builder results back into DSL ``Numeric`` / ``Vector`` values.

The DSL type of the result is shaped after an *exemplar* operand:

- ``exemplar=None`` (default): the first positional argument is the
exemplar. This fits the ``x``-first math builders (``exp(x)``, ...).
- ``exemplar="<name>"``: the argument bound to that parameter name is the
exemplar. Use this for builders whose first argument is not the operand,
e.g. ``cmpi(predicate, lhs, rhs)`` -> ``exemplar="lhs"``.
"""
if fn is None:
return lambda f: dsl_math_wrap_result(f, exemplar=exemplar)

sig = inspect.signature(fn)
accepts_fastmath = "fastmath" in sig.parameters

@wraps(fn)
def wrapper(*args, **kwargs):
from .typing import Vector

first = args[0] if args else None
if accepts_fastmath and kwargs.get("fastmath") is None:
ambient = current_fastmath()
if ambient is not None:
kwargs["fastmath"] = ambient

if exemplar is None:
first = args[0] if args else None
else:
try:
bound = sig.bind_partial(*args, **kwargs)
first = bound.arguments.get(exemplar)
except TypeError:
first = kwargs.get(exemplar)
is_vector = isinstance(first, Vector)
is_numeric = isinstance(first, Numeric)

Expand Down
Loading
Loading