Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion python/flydsl/compiler/jit_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .._mlir.passmanager import PassManager
from ..expr.meta import tracing_context
from ..expr.typing import Constexpr, Stream
from ..expr.utils.arith import fastmath as fastmath_ctx
from ..utils import env, log
from .ast_rewriter import ASTRewriter
from .backends import compile_backend_name, get_backend
Expand Down Expand Up @@ -1482,8 +1483,12 @@ def __call__(self, *args, **kwargs):
log().info(f"dsl_args={dsl_args}")
named_args = dict(zip(param_names, dsl_args))
named_args.update(constexpr_values)
fastmath_flag = CompilationContext.get_compile_hints().get("fastmath")
fastmath_scope = (
fastmath_ctx(fastmath_flag) if fastmath_flag is not None else nullcontext()
)
# Bound the call-site boundary at the jit body.
with tracing_context(self.func):
with tracing_context(self.func), fastmath_scope:
if bound_self is not None:
self.func(bound_self, **named_args)
else:
Expand Down
8 changes: 6 additions & 2 deletions python/flydsl/compiler/kernel_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@

import inspect
import threading
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from .._mlir import ir
from .._mlir.dialects import arith, gpu
from ..expr.meta import capture_user_location, file_location, tracing_context
from ..expr.typing import Constexpr
from ..expr.utils.arith import fastmath as fastmath_ctx
from .ast_rewriter import ASTRewriter
from .diagnostics import install_excepthook, warn_annotation_value_mismatch, warn_invalid_annotations
from .jit_argument import is_type_param_annotation, resolve_signature
Expand Down Expand Up @@ -540,8 +541,11 @@ def _emit_kernel(self, ctx: CompilationContext, args: Tuple, kwargs: Dict, bound
idx += n

dsl_args.update(constexpr_values)

fastmath_flag = CompilationContext.get_compile_hints().get("fastmath")
fastmath_scope = fastmath_ctx(fastmath_flag) if fastmath_flag is not None else nullcontext()
# Bound the call-site boundary at the kernel body.
with tracing_context(self._func):
with tracing_context(self._func), fastmath_scope:
if bound_self is not None:
self._func(bound_self, **dsl_args)
else:
Expand Down
54 changes: 35 additions & 19 deletions python/flydsl/expr/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,18 @@
__all__ = [
"ArithValue", # Deprecated: will be removed in a future release
"_to_raw", # Deprecated: will be removed in a future release
"FastMathFlags",
"andi",
"constant",
"constant_vector",
"fastmath",
"index", # Deprecated: will be removed in a future release
"index_cast", # Deprecated: will be removed in a future release
"int_to_fp",
"maxnumf",
"minnumf",
"maximumf",
"minimumf",
"shli",
"sitofp",
"trunc_f",
Expand All @@ -35,14 +40,15 @@
]

# Override star-import cmpi/cmpf to accept Numeric types (Int32, etc.)
from .._mlir.dialects import arith as _mlir_arith
from .._mlir.dialects import arith
from .meta import dsl_loc_tracing
from .utils.arith import ( # noqa: F401
ArithValue,
_to_raw,
andi,
constant,
constant_vector,
fastmath,
index,
index_cast,
int_to_fp,
Expand All @@ -53,9 +59,12 @@
unwrap,
xori,
)
from .math import dsl_math_wrap_result
from .typing import as_ir_value


@dsl_loc_tracing
@dsl_math_wrap_result(exemplar="lhs")
def cmpi(predicate, lhs, rhs, **kwargs):
"""Integer comparison accepting DSL numeric types (Int32, ArithValue, etc.).

Expand All @@ -65,12 +74,13 @@ def cmpi(predicate, lhs, rhs, **kwargs):
rhs: Right-hand operand.

Returns:
An ``i1`` comparison result.
A ``Boolean`` (scalar) or ``Vector(Boolean)`` comparison result.
"""
return _mlir_arith.cmpi(predicate, _to_raw(lhs), _to_raw(rhs), **kwargs)
return arith.cmpi(predicate, as_ir_value(lhs), as_ir_value(rhs), **kwargs)


@dsl_loc_tracing
@dsl_math_wrap_result(exemplar="lhs")
def cmpf(predicate, lhs, rhs, **kwargs):
"""Floating-point comparison accepting DSL numeric types.

Expand All @@ -80,24 +90,30 @@ def cmpf(predicate, lhs, rhs, **kwargs):
rhs: Right-hand operand.

Returns:
An ``i1`` comparison result.
A ``Boolean`` (scalar) or ``Vector(Boolean)`` comparison result.
"""
return _mlir_arith.cmpf(predicate, _to_raw(lhs), _to_raw(rhs), **kwargs)
return arith.cmpf(predicate, as_ir_value(lhs), as_ir_value(rhs), **kwargs)


@dsl_loc_tracing
def maxnumf(a, b, **kwargs):
"""Floating-point maximum, returning the non-NaN operand when one input is NaN (libm ``fmax``).
@dsl_math_wrap_result
def maximumf(lhs, rhs, *, fastmath=None):
return arith.maximumf(as_ir_value(lhs), as_ir_value(rhs), fastmath=fastmath)

Accepts DSL numeric types (Float32, Vector, ...) and preserves the DSL type of ``a`` so the
result can be chained with further DSL operations (e.g. ``.shuffle_xor(...)``).
"""
from .numeric import Numeric
from .typing import Vector

result = _mlir_arith.maxnumf(_to_raw(a), _to_raw(b), **kwargs)
if isinstance(a, Vector):
return Vector(result, a.shape, a.dtype)
if isinstance(a, Numeric):
return Numeric.from_ir_type(result.type)(result)
return result

@dsl_loc_tracing
@dsl_math_wrap_result
def minimumf(lhs, rhs, *, fastmath=None):
    """Emit a floating-point minimum of ``lhs`` and ``rhs``.

    Both operands are converted with ``as_ir_value`` and passed to the
    ``arith.minimumf`` builder; ``fastmath`` is forwarded unchanged.

    Args:
        lhs: Left-hand operand. It is the first positional argument, which
            ``dsl_math_wrap_result`` uses by default as the exemplar when
            shaping the DSL type of the result.
        rhs: Right-hand operand.
        fastmath: Optional fast-math flags for the emitted op. The
            ``dsl_math_wrap_result`` decorator may substitute the ambient
            fast-math setting when this is left as ``None``.

    Returns:
        The builder result after ``dsl_math_wrap_result`` has wrapped it.

    NOTE(review): per the MLIR ``arith`` dialect, ``minimumf`` propagates
    NaN whereas ``minnumf`` returns the non-NaN operand -- confirm against
    the bundled MLIR version.
    """
    build_op = arith.minimumf
    left_value = as_ir_value(lhs)
    right_value = as_ir_value(rhs)
    return build_op(left_value, right_value, fastmath=fastmath)


@dsl_loc_tracing
@dsl_math_wrap_result
def maxnumf(lhs, rhs, *, fastmath=None):
    """Emit a floating-point maximum of ``lhs`` and ``rhs`` (libm ``fmax`` style).

    When one input is NaN the non-NaN operand is returned. Both operands are
    converted with ``as_ir_value`` and passed to the ``arith.maxnumf``
    builder; ``fastmath`` is forwarded unchanged.

    Args:
        lhs: Left-hand operand. It is the first positional argument, which
            ``dsl_math_wrap_result`` uses by default as the exemplar when
            shaping the DSL type of the result.
        rhs: Right-hand operand.
        fastmath: Optional fast-math flags for the emitted op. The
            ``dsl_math_wrap_result`` decorator may substitute the ambient
            fast-math setting when this is left as ``None``.

    Returns:
        The builder result after ``dsl_math_wrap_result`` has wrapped it.
    """
    build_op = arith.maxnumf
    left_value = as_ir_value(lhs)
    right_value = as_ir_value(rhs)
    return build_op(left_value, right_value, fastmath=fastmath)


@dsl_loc_tracing
@dsl_math_wrap_result
def minnumf(lhs, rhs, *, fastmath=None):
    """Emit a floating-point minimum of ``lhs`` and ``rhs`` via ``arith.minnumf``.

    Both operands are converted with ``as_ir_value`` and passed to the
    ``arith.minnumf`` builder; ``fastmath`` is forwarded unchanged.

    Args:
        lhs: Left-hand operand. It is the first positional argument, which
            ``dsl_math_wrap_result`` uses by default as the exemplar when
            shaping the DSL type of the result.
        rhs: Right-hand operand.
        fastmath: Optional fast-math flags for the emitted op. The
            ``dsl_math_wrap_result`` decorator may substitute the ambient
            fast-math setting when this is left as ``None``.

    Returns:
        The builder result after ``dsl_math_wrap_result`` has wrapped it.

    NOTE(review): per the MLIR ``arith`` dialect, ``minnumf`` returns the
    non-NaN operand when exactly one input is NaN (libm ``fmin`` style),
    unlike ``minimumf`` which propagates NaN -- confirm against the bundled
    MLIR version.
    """
    build_op = arith.minnumf
    left_value = as_ir_value(lhs)
    right_value = as_ir_value(rhs)
    return build_op(left_value, right_value, fastmath=fastmath)
34 changes: 32 additions & 2 deletions python/flydsl/expr/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
pred = fx.isnan(x)
"""

import inspect
from functools import wraps

from .._mlir import ir
from .._mlir.dialects import math
from .meta import dsl_loc_tracing
from .numeric import Numeric
from .typing import as_ir_value
from .utils.arith import current_fastmath

__all__ = [
"absf",
Expand Down Expand Up @@ -70,12 +72,40 @@
]


def dsl_math_wrap_result(fn):
def dsl_math_wrap_result(fn=None, *, exemplar=None):
"""Wrap raw builder results back into DSL ``Numeric`` / ``Vector`` values.

The DSL type of the result is shaped after an *exemplar* operand:

- ``exemplar=None`` (default): the first positional argument is the
exemplar. This fits the ``x``-first math builders (``exp(x)``, ...).
- ``exemplar="<name>"``: the argument bound to that parameter name is the
exemplar. Use this for builders whose first argument is not the operand,
e.g. ``cmpi(predicate, lhs, rhs)`` -> ``exemplar="lhs"``.
"""
if fn is None:
return lambda f: dsl_math_wrap_result(f, exemplar=exemplar)

sig = inspect.signature(fn)
accepts_fastmath = "fastmath" in sig.parameters

@wraps(fn)
def wrapper(*args, **kwargs):
from .typing import Vector

first = args[0] if args else None
if accepts_fastmath and kwargs.get("fastmath") is None:
ambient = current_fastmath()
if ambient is not None:
kwargs["fastmath"] = ambient

if exemplar is None:
first = args[0] if args else None
else:
try:
bound = sig.bind_partial(*args, **kwargs)
first = bound.arguments.get(exemplar)
except TypeError:
first = kwargs.get(exemplar)
is_vector = isinstance(first, Vector)
is_numeric = isinstance(first, Numeric)

Expand Down
1 change: 1 addition & 0 deletions python/flydsl/expr/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1729,6 +1729,7 @@ def reduce(self, op, init_val=None, reduction_profile=None, *, fastmath=None):
kind = _resolve_combining_kind(op, is_fp, signed)
et = element_type(self.type)
kwargs = {}
fastmath = resolve_fastmath(fastmath)
if fastmath is not None:
kwargs["fastmath"] = fastmath
Comment thread
sjfeng1999 marked this conversation as resolved.
if init_val is not None:
Expand Down
Loading
Loading