Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion dpnp/tensor/_elementwise_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
_acceptance_fn_divide,
_acceptance_fn_negative,
_acceptance_fn_reciprocal,
_acceptance_fn_round,
_acceptance_fn_subtract,
_resolve_weak_types_all_py_ints,
)
Expand Down Expand Up @@ -1723,7 +1724,11 @@
"""

round = UnaryElementwiseFunc(
"round", ti._round_result_type, ti._round, _round_docstring
"round",
ti._round_result_type,
ti._round,
_round_docstring,
acceptance_fn=_acceptance_fn_round,
)
del _round_docstring

Expand Down
18 changes: 14 additions & 4 deletions dpnp/tensor/_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,13 @@ def _acceptance_fn_reciprocal(arg_dtype, buf_dt, res_dt, sycl_dev):
return True


def _acceptance_fn_round(arg_dtype, buf_dt, res_dt, sycl_dev):
# for boolean input, prefer floating-point output over integral
if arg_dtype.char == "?" and res_dt.kind in "biu":
Comment thread
vlad-perevezentsev marked this conversation as resolved.
Outdated
return False
return True


def _acceptance_fn_subtract(
arg1_dtype, arg2_dtype, buf1_dt, buf2_dt, res_dt, sycl_dev
):
Expand Down Expand Up @@ -188,17 +195,19 @@ def _dtype_supported_by_device_impl(


def _find_buf_dtype(arg_dtype, query_fn, sycl_dev, acceptance_fn):
_fp16 = sycl_dev.has_aspect_fp16
_fp64 = sycl_dev.has_aspect_fp64

res_dt = query_fn(arg_dtype)
Comment thread
vlad-perevezentsev marked this conversation as resolved.
Outdated
if res_dt:
return None, res_dt
if _dtype_supported_by_device_impl(res_dt, _fp16, _fp64):
return None, res_dt

_fp16 = sycl_dev.has_aspect_fp16
_fp64 = sycl_dev.has_aspect_fp64
all_dts = _all_data_types(_fp16, _fp64)
for buf_dt in all_dts:
if _can_cast(arg_dtype, buf_dt, _fp16, _fp64):
res_dt = query_fn(buf_dt)
if res_dt:
if res_dt and _dtype_supported_by_device_impl(res_dt, _fp16, _fp64):
acceptable = acceptance_fn(arg_dtype, buf_dt, res_dt, sycl_dev)
if acceptable:
return buf_dt, res_dt
Expand Down Expand Up @@ -970,6 +979,7 @@ def _default_accumulation_dtype_fp_types(inp_dt, q):
"_find_buf_dtype2",
"_to_device_supported_dtype",
"_acceptance_fn_default_unary",
"_acceptance_fn_round",
"_acceptance_fn_reciprocal",
"_acceptance_fn_default_binary",
"_acceptance_fn_divide",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ template <typename T>
struct RoundOutputType
{
using value_type = typename std::disjunction<
td_ns::TypeMapResultEntry<T, bool, sycl::half>,
td_ns::TypeMapResultEntry<T, std::uint8_t>,
Comment thread
ndgrigorian marked this conversation as resolved.
td_ns::TypeMapResultEntry<T, std::uint16_t>,
td_ns::TypeMapResultEntry<T, std::uint32_t>,
Expand Down
Loading