From 6fe63eb8f4092b1908336965de39e1350eb9d800 Mon Sep 17 00:00:00 2001 From: Harish-endee Date: Tue, 7 Apr 2026 11:08:37 +0530 Subject: [PATCH] Fix FP16 NEON build on AArch64 CPUs without FP16FML support Some AArch64 CPUs don't support the FP16FML extension, which causes builds to fail due to missing vfmlalq_low_f16 and vfmlalq_high_f16 intrinsics. As a fallback, this adds compatibility shims built from NEON instructions (vcvt_f32_f16 + vfmaq_f32) that need only baseline FP16 vector support (asimdhp) — a requirement the surrounding FP16 kernels already impose — with automatic compile-time dispatch so CPUs with FP16FML still use the native single-instruction path. --- src/quant/float16.hpp | 58 ++++++++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 44 insertions(+), 14 deletions(-) diff --git a/src/quant/float16.hpp b/src/quant/float16.hpp index 5ff1c6b440..ea2614cf4b 100644 --- a/src/quant/float16.hpp +++ b/src/quant/float16.hpp @@ -80,6 +80,36 @@ namespace ndd { return (sign << 15) | (exp << 10) | mantissa; } +// ── FP16 fused multiply-accumulate compatibility shims ────────────── +// vfmlalq_{low,high}_f16 require the FP16FML extension (__ARM_FEATURE_FP16FML) +// which not all AArch64 cores expose. The compat paths use only asimdhp +// (vcvt_f32_f16) + asimd (vfmaq_f32), trading one instruction for three. 
+#if defined(USE_NEON) + inline float32x4_t vfmlalq_low_f16_compat(float32x4_t acc, + float16x8_t a, + float16x8_t b) { + float32x4_t a_f32 = vcvt_f32_f16(vget_low_f16(a)); + float32x4_t b_f32 = vcvt_f32_f16(vget_low_f16(b)); + return vfmaq_f32(acc, a_f32, b_f32); + } + + inline float32x4_t vfmlalq_high_f16_compat(float32x4_t acc, + float16x8_t a, + float16x8_t b) { + float32x4_t a_f32 = vcvt_f32_f16(vget_high_f16(a)); + float32x4_t b_f32 = vcvt_f32_f16(vget_high_f16(b)); + return vfmaq_f32(acc, a_f32, b_f32); + } + +#if defined(__ARM_FEATURE_FP16FML) +#define FMLAL_LOW(acc, a, b) vfmlalq_low_f16(acc, a, b) +#define FMLAL_HIGH(acc, a, b) vfmlalq_high_f16(acc, a, b) +#else +#define FMLAL_LOW(acc, a, b) vfmlalq_low_f16_compat(acc, a, b) +#define FMLAL_HIGH(acc, a, b) vfmlalq_high_f16_compat(acc, a, b) +#endif +#endif // USE_NEON + #if defined(USE_NEON) // NEON optimized vector conversion FP16->FP32 inline std::vector @@ -419,12 +449,12 @@ namespace ndd { float16x8_t v2_1 = vld1q_f16(reinterpret_cast(pVect2 + i + 8)); float16x8_t diff0 = vsubq_f16(v1_0, v2_0); - sum = vfmlalq_low_f16(sum, diff0, diff0); - sum = vfmlalq_high_f16(sum, diff0, diff0); + sum = FMLAL_LOW(sum, diff0, diff0); + sum = FMLAL_HIGH(sum, diff0, diff0); float16x8_t diff1 = vsubq_f16(v1_1, v2_1); - sum = vfmlalq_low_f16(sum, diff1, diff1); - sum = vfmlalq_high_f16(sum, diff1, diff1); + sum = FMLAL_LOW(sum, diff1, diff1); + sum = FMLAL_HIGH(sum, diff1, diff1); } // Process remaining 8 elements @@ -433,8 +463,8 @@ namespace ndd { float16x8_t v2 = vld1q_f16(reinterpret_cast(pVect2 + i)); float16x8_t diff = vsubq_f16(v1, v2); - sum = vfmlalq_low_f16(sum, diff, diff); - sum = vfmlalq_high_f16(sum, diff, diff); + sum = FMLAL_LOW(sum, diff, diff); + sum = FMLAL_HIGH(sum, diff, diff); } // Process remaining 4 elements @@ -443,7 +473,7 @@ namespace ndd { float16x4_t v2 = vld1_f16(reinterpret_cast(pVect2 + i)); float16x4_t diff = vsub_f16(v1, v2); float16x8_t diff_q = vcombine_f16(diff, vdup_n_f16(0)); - sum 
= vfmlalq_low_f16(sum, diff_q, diff_q); + sum = FMLAL_LOW(sum, diff_q, diff_q); } res = vaddvq_f32(sum); @@ -619,10 +649,10 @@ namespace ndd { float16x8_t v1_1 = vld1q_f16(reinterpret_cast(pVect1 + i + 8)); float16x8_t v2_1 = vld1q_f16(reinterpret_cast(pVect2 + i + 8)); - sum0 = vfmlalq_low_f16(sum0, v1_0, v2_0); - sum0 = vfmlalq_high_f16(sum0, v1_0, v2_0); - sum1 = vfmlalq_low_f16(sum1, v1_1, v2_1); - sum1 = vfmlalq_high_f16(sum1, v1_1, v2_1); + sum0 = FMLAL_LOW(sum0, v1_0, v2_0); + sum0 = FMLAL_HIGH(sum0, v1_0, v2_0); + sum1 = FMLAL_LOW(sum1, v1_1, v2_1); + sum1 = FMLAL_HIGH(sum1, v1_1, v2_1); } // Process remaining 8 elements @@ -630,8 +660,8 @@ namespace ndd { float16x8_t v1 = vld1q_f16(reinterpret_cast(pVect1 + i)); float16x8_t v2 = vld1q_f16(reinterpret_cast(pVect2 + i)); - sum0 = vfmlalq_low_f16(sum0, v1, v2); - sum0 = vfmlalq_high_f16(sum0, v1, v2); + sum0 = FMLAL_LOW(sum0, v1, v2); + sum0 = FMLAL_HIGH(sum0, v1, v2); } // Process remaining 4 elements @@ -640,7 +670,7 @@ namespace ndd { float16x4_t v2 = vld1_f16(reinterpret_cast(pVect2 + i)); float16x8_t v1_q = vcombine_f16(v1, vdup_n_f16(0)); float16x8_t v2_q = vcombine_f16(v2, vdup_n_f16(0)); - sum0 = vfmlalq_low_f16(sum0, v1_q, v2_q); + sum0 = FMLAL_LOW(sum0, v1_q, v2_q); } res = vaddvq_f32(vaddq_f32(sum0, sum1));