Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 44 additions & 14 deletions src/quant/float16.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,36 @@ namespace ndd {
return (sign << 15) | (exp << 10) | mantissa;
}

// ── FP16 fused multiply-accumulate compatibility shims ──────────────
// vfmlalq_{low,high}_f16 require the FP16FML extension (__ARM_FEATURE_FP16FML)
// which not all AArch64 cores expose. The compat paths instead widen each
// half to fp32 with vcvt_f32_f16 and accumulate with vfmaq_f32 (both plain
// asimd, no FP16FML needed), trading one instruction for three.
#if defined(USE_NEON)
// Fallback for vfmlalq_low_f16: takes the lower four fp16 lanes of `a` and
// `b`, widens them to fp32, and returns acc + a_lo * b_lo per lane using a
// fused fp32 multiply-add.
inline float32x4_t vfmlalq_low_f16_compat(float32x4_t acc,
                                          float16x8_t a,
                                          float16x8_t b) {
    // Extract lanes 0..3 of each operand as 64-bit half-precision vectors.
    const float16x4_t lhs_half = vget_low_f16(a);
    const float16x4_t rhs_half = vget_low_f16(b);

    // Widen to single precision and accumulate in one fused step.
    return vfmaq_f32(acc, vcvt_f32_f16(lhs_half), vcvt_f32_f16(rhs_half));
}

// Fallback for vfmlalq_high_f16: takes the upper four fp16 lanes of `a` and
// `b`, widens them to fp32, and returns acc + a_hi * b_hi per lane using a
// fused fp32 multiply-add.
inline float32x4_t vfmlalq_high_f16_compat(float32x4_t acc,
                                           float16x8_t a,
                                           float16x8_t b) {
    // Extract lanes 4..7 of each operand as 64-bit half-precision vectors.
    const float16x4_t lhs_half = vget_high_f16(a);
    const float16x4_t rhs_half = vget_high_f16(b);

    // Widen to single precision and accumulate in one fused step.
    return vfmaq_f32(acc, vcvt_f32_f16(lhs_half), vcvt_f32_f16(rhs_half));
}

// FMLAL_LOW / FMLAL_HIGH resolve to the native FP16FML intrinsics when the
// compiler advertises __ARM_FEATURE_FP16FML, and to the *_compat helpers
// otherwise.
#if !defined(__ARM_FEATURE_FP16FML)
// No FP16FML: route through the widen-then-FMA helpers.
#define FMLAL_LOW(sum_v, lhs_v, rhs_v) vfmlalq_low_f16_compat(sum_v, lhs_v, rhs_v)
#define FMLAL_HIGH(sum_v, lhs_v, rhs_v) vfmlalq_high_f16_compat(sum_v, lhs_v, rhs_v)
#else
// FP16FML available: use the single-instruction intrinsics directly.
#define FMLAL_LOW(sum_v, lhs_v, rhs_v) vfmlalq_low_f16(sum_v, lhs_v, rhs_v)
#define FMLAL_HIGH(sum_v, lhs_v, rhs_v) vfmlalq_high_f16(sum_v, lhs_v, rhs_v)
#endif
#endif // USE_NEON

#if defined(USE_NEON)
// NEON optimized vector conversion FP16->FP32
inline std::vector<float>
Expand Down Expand Up @@ -419,12 +449,12 @@ namespace ndd {
float16x8_t v2_1 = vld1q_f16(reinterpret_cast<const __fp16*>(pVect2 + i + 8));

float16x8_t diff0 = vsubq_f16(v1_0, v2_0);
sum = vfmlalq_low_f16(sum, diff0, diff0);
sum = vfmlalq_high_f16(sum, diff0, diff0);
sum = FMLAL_LOW(sum, diff0, diff0);
sum = FMLAL_HIGH(sum, diff0, diff0);

float16x8_t diff1 = vsubq_f16(v1_1, v2_1);
sum = vfmlalq_low_f16(sum, diff1, diff1);
sum = vfmlalq_high_f16(sum, diff1, diff1);
sum = FMLAL_LOW(sum, diff1, diff1);
sum = FMLAL_HIGH(sum, diff1, diff1);
}

// Process remaining 8 elements
Expand All @@ -433,8 +463,8 @@ namespace ndd {
float16x8_t v2 = vld1q_f16(reinterpret_cast<const __fp16*>(pVect2 + i));

float16x8_t diff = vsubq_f16(v1, v2);
sum = vfmlalq_low_f16(sum, diff, diff);
sum = vfmlalq_high_f16(sum, diff, diff);
sum = FMLAL_LOW(sum, diff, diff);
sum = FMLAL_HIGH(sum, diff, diff);
}

// Process remaining 4 elements
Expand All @@ -443,7 +473,7 @@ namespace ndd {
float16x4_t v2 = vld1_f16(reinterpret_cast<const __fp16*>(pVect2 + i));
float16x4_t diff = vsub_f16(v1, v2);
float16x8_t diff_q = vcombine_f16(diff, vdup_n_f16(0));
sum = vfmlalq_low_f16(sum, diff_q, diff_q);
sum = FMLAL_LOW(sum, diff_q, diff_q);
}

res = vaddvq_f32(sum);
Expand Down Expand Up @@ -619,19 +649,19 @@ namespace ndd {
float16x8_t v1_1 = vld1q_f16(reinterpret_cast<const __fp16*>(pVect1 + i + 8));
float16x8_t v2_1 = vld1q_f16(reinterpret_cast<const __fp16*>(pVect2 + i + 8));

sum0 = vfmlalq_low_f16(sum0, v1_0, v2_0);
sum0 = vfmlalq_high_f16(sum0, v1_0, v2_0);
sum1 = vfmlalq_low_f16(sum1, v1_1, v2_1);
sum1 = vfmlalq_high_f16(sum1, v1_1, v2_1);
sum0 = FMLAL_LOW(sum0, v1_0, v2_0);
sum0 = FMLAL_HIGH(sum0, v1_0, v2_0);
sum1 = FMLAL_LOW(sum1, v1_1, v2_1);
sum1 = FMLAL_HIGH(sum1, v1_1, v2_1);
}

// Process remaining 8 elements
for(; i + 8 <= qty; i += 8) {
float16x8_t v1 = vld1q_f16(reinterpret_cast<const __fp16*>(pVect1 + i));
float16x8_t v2 = vld1q_f16(reinterpret_cast<const __fp16*>(pVect2 + i));

sum0 = vfmlalq_low_f16(sum0, v1, v2);
sum0 = vfmlalq_high_f16(sum0, v1, v2);
sum0 = FMLAL_LOW(sum0, v1, v2);
sum0 = FMLAL_HIGH(sum0, v1, v2);
}

// Process remaining 4 elements
Expand All @@ -640,7 +670,7 @@ namespace ndd {
float16x4_t v2 = vld1_f16(reinterpret_cast<const __fp16*>(pVect2 + i));
float16x8_t v1_q = vcombine_f16(v1, vdup_n_f16(0));
float16x8_t v2_q = vcombine_f16(v2, vdup_n_f16(0));
sum0 = vfmlalq_low_f16(sum0, v1_q, v2_q);
sum0 = FMLAL_LOW(sum0, v1_q, v2_q);
}

res = vaddvq_f32(vaddq_f32(sum0, sum1));
Expand Down
Loading