diff --git a/src/quant/float16.hpp b/src/quant/float16.hpp index 5ff1c6b440..ea2614cf4b 100644 --- a/src/quant/float16.hpp +++ b/src/quant/float16.hpp @@ -80,6 +80,36 @@ namespace ndd { return (sign << 15) | (exp << 10) | mantissa; } +// ── FP16 fused multiply-accumulate compatibility shims ────────────── +// vfmlalq_{low,high}_f16 require the FP16FML extension (__ARM_FEATURE_FP16FML) +// which not all AArch64 cores expose. The compat paths use only asimdhp +// (vcvt_f32_f16) + asimd (vfmaq_f32), trading one instruction for three. +#if defined(USE_NEON) + inline float32x4_t vfmlalq_low_f16_compat(float32x4_t acc, + float16x8_t a, + float16x8_t b) { + float32x4_t a_f32 = vcvt_f32_f16(vget_low_f16(a)); + float32x4_t b_f32 = vcvt_f32_f16(vget_low_f16(b)); + return vfmaq_f32(acc, a_f32, b_f32); + } + + inline float32x4_t vfmlalq_high_f16_compat(float32x4_t acc, + float16x8_t a, + float16x8_t b) { + float32x4_t a_f32 = vcvt_f32_f16(vget_high_f16(a)); + float32x4_t b_f32 = vcvt_f32_f16(vget_high_f16(b)); + return vfmaq_f32(acc, a_f32, b_f32); + } + +#if defined(__ARM_FEATURE_FP16FML) +#define FMLAL_LOW(acc, a, b) vfmlalq_low_f16(acc, a, b) +#define FMLAL_HIGH(acc, a, b) vfmlalq_high_f16(acc, a, b) +#else +#define FMLAL_LOW(acc, a, b) vfmlalq_low_f16_compat(acc, a, b) +#define FMLAL_HIGH(acc, a, b) vfmlalq_high_f16_compat(acc, a, b) +#endif +#endif // USE_NEON + #if defined(USE_NEON) // NEON optimized vector conversion FP16->FP32 inline std::vector @@ -419,12 +449,12 @@ namespace ndd { float16x8_t v2_1 = vld1q_f16(reinterpret_cast<const float16_t *>(pVect2 + i + 8)); float16x8_t diff0 = vsubq_f16(v1_0, v2_0); - sum = vfmlalq_low_f16(sum, diff0, diff0); - sum = vfmlalq_high_f16(sum, diff0, diff0); + sum = FMLAL_LOW(sum, diff0, diff0); + sum = FMLAL_HIGH(sum, diff0, diff0); float16x8_t diff1 = vsubq_f16(v1_1, v2_1); - sum = vfmlalq_low_f16(sum, diff1, diff1); - sum = vfmlalq_high_f16(sum, diff1, diff1); + sum = FMLAL_LOW(sum, diff1, diff1); + sum = FMLAL_HIGH(sum, diff1, diff1); } // Process 
remaining 8 elements @@ -433,8 +463,8 @@ namespace ndd { float16x8_t v2 = vld1q_f16(reinterpret_cast<const float16_t *>(pVect2 + i)); float16x8_t diff = vsubq_f16(v1, v2); - sum = vfmlalq_low_f16(sum, diff, diff); - sum = vfmlalq_high_f16(sum, diff, diff); + sum = FMLAL_LOW(sum, diff, diff); + sum = FMLAL_HIGH(sum, diff, diff); } // Process remaining 4 elements @@ -443,7 +473,7 @@ namespace ndd { float16x4_t v2 = vld1_f16(reinterpret_cast<const float16_t *>(pVect2 + i)); float16x4_t diff = vsub_f16(v1, v2); float16x8_t diff_q = vcombine_f16(diff, vdup_n_f16(0)); - sum = vfmlalq_low_f16(sum, diff_q, diff_q); + sum = FMLAL_LOW(sum, diff_q, diff_q); } res = vaddvq_f32(sum); @@ -619,10 +649,10 @@ namespace ndd { float16x8_t v1_1 = vld1q_f16(reinterpret_cast<const float16_t *>(pVect1 + i + 8)); float16x8_t v2_1 = vld1q_f16(reinterpret_cast<const float16_t *>(pVect2 + i + 8)); - sum0 = vfmlalq_low_f16(sum0, v1_0, v2_0); - sum0 = vfmlalq_high_f16(sum0, v1_0, v2_0); - sum1 = vfmlalq_low_f16(sum1, v1_1, v2_1); - sum1 = vfmlalq_high_f16(sum1, v1_1, v2_1); + sum0 = FMLAL_LOW(sum0, v1_0, v2_0); + sum0 = FMLAL_HIGH(sum0, v1_0, v2_0); + sum1 = FMLAL_LOW(sum1, v1_1, v2_1); + sum1 = FMLAL_HIGH(sum1, v1_1, v2_1); } // Process remaining 8 elements @@ -630,8 +660,8 @@ namespace ndd { float16x8_t v1 = vld1q_f16(reinterpret_cast<const float16_t *>(pVect1 + i)); float16x8_t v2 = vld1q_f16(reinterpret_cast<const float16_t *>(pVect2 + i)); - sum0 = vfmlalq_low_f16(sum0, v1, v2); - sum0 = vfmlalq_high_f16(sum0, v1, v2); + sum0 = FMLAL_LOW(sum0, v1, v2); + sum0 = FMLAL_HIGH(sum0, v1, v2); } // Process remaining 4 elements @@ -640,7 +670,7 @@ namespace ndd { float16x4_t v2 = vld1_f16(reinterpret_cast<const float16_t *>(pVect2 + i)); float16x8_t v1_q = vcombine_f16(v1, vdup_n_f16(0)); float16x8_t v2_q = vcombine_f16(v2, vdup_n_f16(0)); - sum0 = vfmlalq_low_f16(sum0, v1_q, v2_q); + sum0 = FMLAL_LOW(sum0, v1_q, v2_q); } res = vaddvq_f32(vaddq_f32(sum0, sum1));