diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5fa15b0d93..0928e83b73 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -99,6 +99,7 @@
 option(USE_AVX2 "Enable AVX2 (FMA, F16C)" OFF)
 option(USE_SVE2 "Enable SVE2 (INT8/16, FP16)" OFF)
 option(USE_NEON "Enable NEON (FP16, DotProd)" OFF)
 option(NDD_INV_IDX_STORE_FLOATS "Store raw float 32 values in sparse index (no quantization)" OFF)
+set(NDD_RUNTIME_X86_DISPATCH OFF)
 
 # Check if any SIMD option is selected
 if(NOT USE_AVX512 AND NOT USE_AVX2 AND NOT USE_SVE2 AND NOT USE_NEON)
@@ -108,10 +109,9 @@ if(NOT USE_AVX512 AND NOT USE_AVX2 AND NOT USE_SVE2 AND NOT USE_NEON)
             "  -DUSE_SVE2=ON : For ARMv9/SVE2 capable processors (requires SVE2, FP16)\n"
             "  -DUSE_NEON=ON : For standard ARMv8/NEON processors (requires FP16, DotProd)")
   else()
-    message(FATAL_ERROR "x86 architecture detected but no SIMD option selected.\n"
-            "Please specify one of the following flags:\n"
-            "  -DUSE_AVX512=ON : For processors with AVX512F, BW, VNNI, FP16\n"
-            "  -DUSE_AVX2=ON : For processors with AVX2, FMA, F16C")
+    set(NDD_RUNTIME_X86_DISPATCH ON)
+    message(STATUS "x86 architecture detected with no explicit SIMD mode; enabling runtime x86 "
+            "dispatch with an AVX2 baseline and optional AVX512 variants")
   endif()
 endif()
@@ -253,6 +253,7 @@ message(STATUS "Binary name: ${NDD_BINARY_NAME}")
 
 # Add new src/*.cpp files here when they should be compiled into ndd.
 set(NDD_CORE_SOURCES
     src/sparse/inverted_index.cpp
+    src/utils/cpu_compat_check/cpu_runtime_dispatch.cpp
     src/utils/system_sanity/system_sanity.cpp
 )
@@ -341,6 +342,12 @@ elseif(USE_NEON)
   endif()
   target_compile_definitions(ndd_core PRIVATE USE_NEON)
   target_compile_definitions(${NDD_BINARY_NAME} PRIVATE USE_NEON)
+elseif(NDD_RUNTIME_X86_DISPATCH)
+  message(STATUS "SIMD: Runtime x86 dispatch enabled (AVX2 baseline + optional AVX512 variants)")
+  target_compile_options(ndd_core PRIVATE -mavx2 -mfma -mf16c)
+  target_compile_definitions(ndd_core PRIVATE USE_AVX2 NDD_RUNTIME_X86_DISPATCH NDD_COMPILE_AVX512_VARIANTS)
+  target_compile_options(${NDD_BINARY_NAME} PRIVATE -mavx2 -mfma -mf16c)
+  target_compile_definitions(${NDD_BINARY_NAME} PRIVATE USE_AVX2 NDD_RUNTIME_X86_DISPATCH NDD_COMPILE_AVX512_VARIANTS)
 endif()
 
 if(NDD_INV_IDX_STORE_FLOATS)
@@ -393,15 +400,20 @@ elseif(USE_SVE2)
   message(STATUS "SIMD Mode: SVE2")
 elseif(USE_NEON)
   message(STATUS "SIMD Mode: NEON")
+elseif(NDD_RUNTIME_X86_DISPATCH)
+  message(STATUS "SIMD Mode: Runtime x86 dispatch")
 endif()
 message(STATUS "ASIO include dir: ${ASIO_INCLUDE_DIR}")
 message(STATUS "LMDB include dir: ${LMDB_INCLUDE_DIR}")
 message(STATUS "OpenSSL include dir: ${OPENSSL_INCLUDE_DIR}")
 
 # Create a symbolic link named 'ndd' pointing to the architecture-specific binary
-add_custom_command(TARGET ${NDD_BINARY_NAME} POST_BUILD
-    COMMAND ${CMAKE_COMMAND} -E create_symlink
-    $<TARGET_FILE_NAME:${NDD_BINARY_NAME}>
-    ${CMAKE_CURRENT_BINARY_DIR}/ndd
-    COMMENT "Creating softlink 'ndd' -> ${NDD_BINARY_NAME}"
-)
+# (skipped when binary is already named 'ndd' to avoid a self-referential symlink)
+if(NOT NDD_BINARY_NAME STREQUAL "ndd")
+  add_custom_command(TARGET ${NDD_BINARY_NAME} POST_BUILD
+      COMMAND ${CMAKE_COMMAND} -E create_symlink
+      $<TARGET_FILE_NAME:${NDD_BINARY_NAME}>
+      ${CMAKE_CURRENT_BINARY_DIR}/ndd
+      COMMENT "Creating softlink 'ndd' -> ${NDD_BINARY_NAME}"
+  )
+endif()
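Note: the build change above leans on per-function targeting rather than per-binary flags. The whole target is compiled against an AVX2 baseline (-mavx2 -mfma -mf16c), while individual AVX512 kernels are marked with the GCC/Clang target attribute and selected behind a runtime CPU check. A minimal, self-contained sketch of that idiom follows (illustrative names, not code from this patch):

    // dispatch_sketch.cpp -- compile with: g++ -O2 -mavx2 dispatch_sketch.cpp
    #include <immintrin.h>
    #include <cstddef>

    // Only this function may use AVX512 instructions; the attribute lets the
    // compiler emit them even though the TU baseline is AVX2.
    __attribute__((target("avx512f"))) static float sum_avx512(const float* x, size_t n) {
        __m512 acc = _mm512_setzero_ps();
        size_t i = 0;
        for(; i + 16 <= n; i += 16) {
            acc = _mm512_add_ps(acc, _mm512_loadu_ps(x + i));
        }
        float s = _mm512_reduce_add_ps(acc);
        for(; i < n; ++i) s += x[i];
        return s;
    }

    static float sum_baseline(const float* x, size_t n) { // AVX2-baseline path
        float s = 0.0f;
        for(size_t i = 0; i < n; ++i) s += x[i];
        return s;
    }

    float sum(const float* x, size_t n) {
        // The CPU probe must run before any AVX512 instruction executes.
        return __builtin_cpu_supports("avx512f") ? sum_avx512(x, n) : sum_baseline(x, n);
    }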
"system_sanity/system_sanity.hpp" using ndd::quant::quantLevelToString; @@ -257,9 +258,14 @@ int main(int argc, char** argv) { // Health check endpoint (no auth required) // CROW_ROUTE(app, "/api/v1/health").methods("GET"_method)([](const crow::request& req) { CROW_ROUTE(app, "/api/v1/health").methods("GET"_method)([]() { + crow::json::wvalue::list cpu_flags; + for(const auto& flag : ndd::cpu::get_active_cpu_flags()) { + cpu_flags.emplace_back(flag); + } crow::json::wvalue response( {{"status", "ok"}, - {"timestamp", (std::int64_t)std::chrono::system_clock::now().time_since_epoch().count()}}); + {"timestamp", (std::int64_t)std::chrono::system_clock::now().time_since_epoch().count()}, + {"cpu_flags", cpu_flags}}); PRINT_LOG_TIME(); ndd::printSparseSearchDebugStats(); ndd::printSparseUpdateDebugStats(); diff --git a/src/quant/binary.hpp b/src/quant/binary.hpp index 9ec11c59d3..f735fc8141 100644 --- a/src/quant/binary.hpp +++ b/src/quant/binary.hpp @@ -32,8 +32,9 @@ namespace ndd { return 1.0f; } -#if defined(USE_AVX512) - inline std::vector quantize_avx512(const std::vector& input) { +#if NDD_HAS_AVX512_VARIANTS + NDD_TARGET_AVX512F inline std::vector + quantize_avx512(const std::vector& input) { if(input.empty()) { return std::vector(); } @@ -182,7 +183,12 @@ namespace ndd { // Quantize FP32 vector to Binary (packed bits) inline std::vector quantize(const std::vector& input) { -#if defined(USE_AVX512) +#if defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__x86_64__) || defined(_M_X64)) + if(ndd::cpu::use_avx512f()) { + return quantize_avx512(input); + } + return quantize_avx2(input); +#elif defined(USE_AVX512) return quantize_avx512(input); #elif defined(USE_AVX2) return quantize_avx2(input); @@ -213,8 +219,9 @@ namespace ndd { #endif } -#if defined(USE_AVX512) - inline std::vector dequantize_avx512(const uint8_t* buffer, size_t dimension) { +#if NDD_HAS_AVX512_VARIANTS + NDD_TARGET_AVX512F inline std::vector dequantize_avx512(const uint8_t* buffer, + size_t dimension) { std::vector output(dimension); size_t i = 0; @@ -370,7 +377,12 @@ namespace ndd { // Dequantize Binary to FP32 inline std::vector dequantize(const uint8_t* buffer, size_t dimension) { -#if defined(USE_AVX512) +#if defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__x86_64__) || defined(_M_X64)) + if(ndd::cpu::use_avx512f()) { + return dequantize_avx512(buffer, dimension); + } + return dequantize_avx2(buffer, dimension); +#elif defined(USE_AVX512) return dequantize_avx512(buffer, dimension); #elif defined(USE_AVX2) return dequantize_avx2(buffer, dimension); @@ -397,7 +409,44 @@ namespace ndd { } // Hamming distance implementation - inline float Hamming(const void* v1, const void* v2, const void* params) { +#if NDD_HAS_AVX512_VARIANTS + NDD_TARGET_AVX512VPOPCNTDQ inline float HammingAVX512(const void* v1, + const void* v2, + const void* params) { + const size_t dim = *static_cast(params); + const uint64_t* p1 = static_cast(v1); + const uint64_t* p2 = static_cast(v2); + + size_t num_uint64 = (dim + 63) / 64; + float dist = 0; + size_t i = 0; + + __m512i acc = _mm512_setzero_si512(); + + for(; i + 8 <= num_uint64; i += 8) { + __m512i d1 = _mm512_loadu_si512((const __m512i*)&p1[i]); + __m512i d2 = _mm512_loadu_si512((const __m512i*)&p2[i]); + __m512i x = _mm512_xor_si512(d1, d2); + __m512i p = _mm512_popcnt_epi64(x); + acc = _mm512_add_epi64(acc, p); + } + + if(i < num_uint64) { + __mmask8 mask = (__mmask8)((1 << (num_uint64 - i)) - 1); + __m512i d1 = _mm512_maskz_loadu_epi64(mask, &p1[i]); + __m512i d2 = _mm512_maskz_loadu_epi64(mask, 
&p2[i]); + __m512i x = _mm512_xor_si512(d1, d2); + __m512i p = _mm512_popcnt_epi64(x); + acc = _mm512_add_epi64(acc, p); + i = num_uint64; + } + + dist += _mm512_reduce_add_epi64(acc); + return dist; + } +#endif + + inline float HammingBaseline(const void* v1, const void* v2, const void* params) { // params is expected to be a pointer to a struct where the first member is size_t // dim e.g. hnswlib::DistParams const size_t dim = *static_cast(params); @@ -612,6 +661,15 @@ namespace ndd { return dist; } + inline float Hamming(const void* v1, const void* v2, const void* params) { +#if defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__x86_64__) || defined(_M_X64)) + if(ndd::cpu::use_avx512vpopcntdq()) { + return HammingAVX512(v1, v2, params); + } +#endif + return HammingBaseline(v1, v2, params); + } + // Wrappers inline float L2Sqr(const void* v1, const void* v2, const void* params) { return Hamming(v1, v2, params); diff --git a/src/quant/common.hpp b/src/quant/common.hpp index 5e9cf4949c..1e6f8d9f94 100644 --- a/src/quant/common.hpp +++ b/src/quant/common.hpp @@ -8,6 +8,8 @@ #include #include +#include "../utils/cpu_compat_check/cpu_runtime_dispatch.hpp" + #if defined(USE_AVX512) || defined(USE_AVX2) # include #endif @@ -191,7 +193,7 @@ namespace ndd { // Forward declarations for SIMD implementations inline float find_abs_max_scalar(const float* data, size_t size); -#if defined(USE_AVX512) +#if NDD_HAS_AVX512_VARIANTS inline float find_abs_max_avx512(const float* data, size_t size); #endif #if defined(USE_AVX2) @@ -206,7 +208,12 @@ namespace ndd { // Find absolute maximum value in a vector (for scaling) inline float find_abs_max(const float* data, size_t size) { -#if defined(USE_AVX512) +#if defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__x86_64__) || defined(_M_X64)) + if(ndd::cpu::use_avx512f()) { + return find_abs_max_avx512(data, size); + } + return find_abs_max_avx2(data, size); +#elif defined(USE_AVX512) return find_abs_max_avx512(data, size); #elif defined(USE_SVE2) return find_abs_max_sve(data, size); @@ -228,9 +235,9 @@ namespace ndd { return abs_max; } -#if defined(USE_AVX512) +#if NDD_HAS_AVX512_VARIANTS // AVX512 optimized absolute maximum finding - MAXIMUM register utilization - inline float find_abs_max_avx512(const float* data, size_t size) { + NDD_TARGET_AVX512F inline float find_abs_max_avx512(const float* data, size_t size) { if(size == 0) { return 0.0f; } diff --git a/src/quant/float16.hpp b/src/quant/float16.hpp index 5ff1c6b440..f95325337e 100644 --- a/src/quant/float16.hpp +++ b/src/quant/float16.hpp @@ -176,9 +176,9 @@ namespace ndd { } #endif -#if defined(USE_AVX512) +#if NDD_HAS_AVX512_VARIANTS // AVX512 optimized vector conversion FP16->FP32 - inline std::vector + NDD_TARGET_AVX512FP16 inline std::vector convert_vector_f16_f32_avx512(const std::vector& input) { std::vector output; output.resize(input.size()); @@ -225,7 +225,7 @@ namespace ndd { } // AVX512 optimized vector conversion FP32->FP16 - inline std::vector + NDD_TARGET_AVX512FP16 inline std::vector convert_vector_f32_f16_avx512(const std::vector& input) { std::vector output; output.resize(input.size()); @@ -278,10 +278,28 @@ namespace ndd { return convert_vector_f16_f32_neon(input); #endif -#if defined(USE_AVX512) +#if defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__x86_64__) || defined(_M_X64)) + if(ndd::cpu::use_avx512fp16()) { + return convert_vector_f16_f32_avx512(input); + } +#elif defined(USE_AVX512) return convert_vector_f16_f32_avx512(input); #endif +#if defined(USE_AVX2) + { + std::vector 
diff --git a/src/quant/float16.hpp b/src/quant/float16.hpp
index 5ff1c6b440..f95325337e 100644
--- a/src/quant/float16.hpp
+++ b/src/quant/float16.hpp
@@ -176,9 +176,9 @@
     }
 #endif
 
-#if defined(USE_AVX512)
+#if NDD_HAS_AVX512_VARIANTS
     // AVX512 optimized vector conversion FP16->FP32
-    inline std::vector<float>
+    NDD_TARGET_AVX512FP16 inline std::vector<float>
     convert_vector_f16_f32_avx512(const std::vector<uint16_t>& input) {
         std::vector<float> output;
         output.resize(input.size());
@@ -225,7 +225,7 @@
     }
 
     // AVX512 optimized vector conversion FP32->FP16
-    inline std::vector<uint16_t>
+    NDD_TARGET_AVX512FP16 inline std::vector<uint16_t>
     convert_vector_f32_f16_avx512(const std::vector<float>& input) {
         std::vector<uint16_t> output;
         output.resize(input.size());
@@ -278,10 +278,28 @@
         return convert_vector_f16_f32_neon(input);
 #endif
 
-#if defined(USE_AVX512)
+#if defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__x86_64__) || defined(_M_X64))
+        if(ndd::cpu::use_avx512fp16()) {
+            return convert_vector_f16_f32_avx512(input);
+        }
+#elif defined(USE_AVX512)
         return convert_vector_f16_f32_avx512(input);
 #endif
 
+#if defined(USE_AVX2)
+        {
+            std::vector<float> output(input.size());
+            size_t i = 0;
+            const size_t vec_size = (input.size() / 8) * 8;
+            for(; i < vec_size; i += 8) {
+                __m128i v = _mm_loadu_si128(reinterpret_cast<const __m128i*>(&input[i]));
+                _mm256_storeu_ps(&output[i], _mm256_cvtph_ps(v));
+            }
+            for(; i < input.size(); i++) output[i] = fp16_to_fp32(input[i]);
+            return output;
+        }
+#endif
+
         // Fallback scalar implementation
         std::vector<float> output;
         output.resize(input.size());
@@ -296,10 +314,29 @@
         return convert_vector_f32_f16_neon(input);
 #endif
 
-#if defined(USE_AVX512)
+#if defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__x86_64__) || defined(_M_X64))
+        if(ndd::cpu::use_avx512fp16()) {
+            return convert_vector_f32_f16_avx512(input);
+        }
+#elif defined(USE_AVX512)
         return convert_vector_f32_f16_avx512(input);
 #endif
 
+#if defined(USE_AVX2)
+        {
+            std::vector<uint16_t> output(input.size());
+            size_t i = 0;
+            const size_t vec_size = (input.size() / 8) * 8;
+            for(; i < vec_size; i += 8) {
+                __m256 v = _mm256_loadu_ps(&input[i]);
+                __m128i h = _mm256_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT);
+                _mm_storeu_si128(reinterpret_cast<__m128i*>(&output[i]), h);
+            }
+            for(; i < input.size(); i++) output[i] = fp32_to_fp16(input[i]);
+            return output;
+        }
+#endif
+
         // Fallback scalar implementation
         std::vector<uint16_t> output;
         output.resize(input.size());
@@ -346,37 +383,65 @@
             float16x4_t out = vcvt_f16_f32(in);
             vst1_f16(reinterpret_cast<__fp16*>(&output[i]), out);
         }
+#else
+#if defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__x86_64__) || defined(_M_X64))
+        if(ndd::cpu::use_avx512fp16()) {
+            const __m512 s512 = _mm512_set1_ps(scale);
+            const size_t vec_size512 = (input.size() / 64) * 64;
+            for(; i < vec_size512; i += 64) {
+                __m512 in0 = _mm512_mul_ps(_mm512_loadu_ps(&input[i]), s512);
+                __m512 in1 = _mm512_mul_ps(_mm512_loadu_ps(&input[i + 16]), s512);
+                __m512 in2 = _mm512_mul_ps(_mm512_loadu_ps(&input[i + 32]), s512);
+                __m512 in3 = _mm512_mul_ps(_mm512_loadu_ps(&input[i + 48]), s512);
+
+                _mm256_storeu_ph(&output[i],
+                                 _mm512_cvtps_ph(in0, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+                _mm256_storeu_ph(&output[i + 16],
+                                 _mm512_cvtps_ph(in1, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+                _mm256_storeu_ph(&output[i + 32],
+                                 _mm512_cvtps_ph(in2, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+                _mm256_storeu_ph(&output[i + 48],
+                                 _mm512_cvtps_ph(in3, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+            }
+        } else
 #elif defined(USE_AVX512)
-        const __m512 s = _mm512_set1_ps(scale);
-        size_t vec_size = (input.size() / 64) * 64;  // 64 floats per iteration (4x unroll)
-        for(; i < vec_size; i += 64) {
-            __m512 in0 = _mm512_mul_ps(_mm512_loadu_ps(&input[i]), s);
-            __m512 in1 = _mm512_mul_ps(_mm512_loadu_ps(&input[i + 16]), s);
-            __m512 in2 = _mm512_mul_ps(_mm512_loadu_ps(&input[i + 32]), s);
-            __m512 in3 = _mm512_mul_ps(_mm512_loadu_ps(&input[i + 48]), s);
-
-            __m256h out0 =
-                _mm512_cvtps_ph(in0, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
-            __m256h out1 =
-                _mm512_cvtps_ph(in1, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
-            __m256h out2 =
-                _mm512_cvtps_ph(in2, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
-            __m256h out3 =
-                _mm512_cvtps_ph(in3, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
-
-            _mm256_storeu_ph(&output[i], out0);
-            _mm256_storeu_ph(&output[i + 16], out1);
-            _mm256_storeu_ph(&output[i + 32], out2);
-            _mm256_storeu_ph(&output[i + 48], out3);
+        {
+            const __m512 s512 = _mm512_set1_ps(scale);
+            const size_t vec_size512 = (input.size() / 64) * 64;
+            for(; i < vec_size512; i += 64) {
+                __m512 in0 = _mm512_mul_ps(_mm512_loadu_ps(&input[i]), s512);
+                __m512 in1 = _mm512_mul_ps(_mm512_loadu_ps(&input[i + 16]), s512);
+                __m512 in2 = _mm512_mul_ps(_mm512_loadu_ps(&input[i + 32]), s512);
+                __m512 in3 = _mm512_mul_ps(_mm512_loadu_ps(&input[i + 48]), s512);
+
+                _mm256_storeu_ph(&output[i],
+                                 _mm512_cvtps_ph(in0, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+                _mm256_storeu_ph(&output[i + 16],
+                                 _mm512_cvtps_ph(in1, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+                _mm256_storeu_ph(&output[i + 32],
+                                 _mm512_cvtps_ph(in2, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+                _mm256_storeu_ph(&output[i + 48],
+                                 _mm512_cvtps_ph(in3, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+            }
+            const size_t remaining_vec_size = (input.size() / 16) * 16;
+            for(; i < remaining_vec_size; i += 16) {
+                __m512 in = _mm512_mul_ps(_mm512_loadu_ps(&input[i]), s512);
+                _mm256_storeu_ph(&output[i],
+                                 _mm512_cvtps_ph(in, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+            }
         }
-
-        size_t remaining_vec_size = (input.size() / 16) * 16;
-        for(; i < remaining_vec_size; i += 16) {
-            __m512 in = _mm512_mul_ps(_mm512_loadu_ps(&input[i]), s);
-            __m256h out =
-                _mm512_cvtps_ph(in, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
-            _mm256_storeu_ph(&output[i], out);
-        }
+#endif
+#if defined(USE_AVX2)
+        {
+            const __m256 s256 = _mm256_set1_ps(scale);
+            const size_t vec_size256 = (input.size() / 8) * 8;
+            for(; i < vec_size256; i += 8) {
+                __m256 v = _mm256_mul_ps(_mm256_loadu_ps(&input[i]), s256);
+                __m128i h = _mm256_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT);
+                _mm_storeu_si128(reinterpret_cast<__m128i*>(&output[i]), h);
+            }
+        }
+#endif
 #endif
 
         for(; i < input.size(); i++) {
_mm512_mul_ps(_mm512_loadu_ps(&input[i]), s512); + __m512 in1 = _mm512_mul_ps(_mm512_loadu_ps(&input[i + 16]), s512); + __m512 in2 = _mm512_mul_ps(_mm512_loadu_ps(&input[i + 32]), s512); + __m512 in3 = _mm512_mul_ps(_mm512_loadu_ps(&input[i + 48]), s512); + + _mm256_storeu_ph(&output[i], + _mm512_cvtps_ph(in0, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + _mm256_storeu_ph(&output[i + 16], + _mm512_cvtps_ph(in1, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + _mm256_storeu_ph(&output[i + 32], + _mm512_cvtps_ph(in2, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + _mm256_storeu_ph(&output[i + 48], + _mm512_cvtps_ph(in3, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + const size_t remaining_vec_size = (input.size() / 16) * 16; + for(; i < remaining_vec_size; i += 16) { + __m512 in = _mm512_mul_ps(_mm512_loadu_ps(&input[i]), s512); + _mm256_storeu_ph(&output[i], + _mm512_cvtps_ph(in, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } } - - size_t remaining_vec_size = (input.size() / 16) * 16; - for(; i < remaining_vec_size; i += 16) { - __m512 in = _mm512_mul_ps(_mm512_loadu_ps(&input[i]), s); - __m256h out = - _mm512_cvtps_ph(in, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); - _mm256_storeu_ph(&output[i], out); +#endif +#if defined(USE_AVX2) + { + const __m256 s256 = _mm256_set1_ps(scale); + const size_t vec_size256 = (input.size() / 8) * 8; + for(; i < vec_size256; i += 8) { + __m256 v = _mm256_mul_ps(_mm256_loadu_ps(&input[i]), s256); + __m128i h = _mm256_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT); + _mm_storeu_si128(reinterpret_cast<__m128i*>(&output[i]), h); + } } +#endif #endif for(; i < input.size(); i++) { diff --git a/src/quant/float32.hpp b/src/quant/float32.hpp index a0553380a8..fb4d3861d8 100644 --- a/src/quant/float32.hpp +++ b/src/quant/float32.hpp @@ -17,17 +17,7 @@ namespace hnswlib { static std::vector quantize_to_int8(const void* in, size_t dim) { const float* f_in = static_cast(in); std::vector input(f_in, f_in + dim); -#if defined(USE_SVE2) - return ndd::quant::int8::quantize_vector_fp32_to_int8_buffer_sve(input); -#elif defined(USE_AVX512) - return ndd::quant::int8::quantize_vector_fp32_to_int8_buffer_avx512(input); -#elif defined(USE_AVX2) - return ndd::quant::int8::quantize_vector_fp32_to_int8_buffer_avx2(input); -#elif defined(USE_NEON) - return ndd::quant::int8::quantize_vector_fp32_to_int8_buffer_neon(input); -#else - return ndd::quant::int8::quantize_vector_fp32_to_int8_buffer(input); -#endif + return ndd::quant::int8::quantize_vector_fp32_to_int8_buffer_auto(input); } inline std::vector quantize(const std::vector& input) { @@ -168,8 +158,10 @@ namespace hnswlib { } #endif -#if defined(USE_AVX512) - static float L2SqrAVX512(const void* pVect1, const void* pVect2, size_t qty) { +#if NDD_HAS_AVX512_VARIANTS + NDD_TARGET_AVX512F static float L2SqrAVX512(const void* pVect1, + const void* pVect2, + size_t qty) { const float* vec1 = reinterpret_cast(pVect1); const float* vec2 = reinterpret_cast(pVect2); @@ -198,7 +190,9 @@ namespace hnswlib { return res; } - static float InnerProductAVX512(const void* pVect1, const void* pVect2, size_t qty) { + NDD_TARGET_AVX512F static float InnerProductAVX512(const void* pVect1, + const void* pVect2, + size_t qty) { const float* vec1 = reinterpret_cast(pVect1); const float* vec2 = reinterpret_cast(pVect2); @@ -418,7 +412,12 @@ namespace hnswlib { #endif static float L2Sqr(const void* pVect1, const void* pVect2, size_t qty) { -#if defined(USE_AVX512) +#if defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__x86_64__) || 
@@ -418,7 +412,12 @@
 #endif
 
     static float L2Sqr(const void* pVect1, const void* pVect2, size_t qty) {
-#if defined(USE_AVX512)
+#if defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__x86_64__) || defined(_M_X64))
+        if(ndd::cpu::use_avx512f()) {
+            return L2SqrAVX512(pVect1, pVect2, qty);
+        }
+        return L2SqrAVX2(pVect1, pVect2, qty);
+#elif defined(USE_AVX512)
         return L2SqrAVX512(pVect1, pVect2, qty);
 #elif defined(USE_SVE2)
         return L2SqrSVE(pVect1, pVect2, qty);
@@ -432,7 +431,12 @@
     }
 
     static float InnerProduct(const void* pVect1, const void* pVect2, size_t qty) {
-#if defined(USE_AVX512)
+#if defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__x86_64__) || defined(_M_X64))
+        if(ndd::cpu::use_avx512f()) {
+            return InnerProductAVX512(pVect1, pVect2, qty);
+        }
+        return InnerProductAVX2(pVect1, pVect2, qty);
+#elif defined(USE_AVX512)
         return InnerProductAVX512(pVect1, pVect2, qty);
 #elif defined(USE_SVE2)
         return InnerProductSVE(pVect1, pVect2, qty);
@@ -478,12 +482,96 @@
         128;
 #endif
 
-    static void SimilarityBatchTiled(const void* query,
-                                     const void* const* vectors,
-                                     size_t count,
-                                     const void* params,
-                                     float* out,
-                                     bool l2_metric) {
+    static void SimilarityBatchTiledBaseline(const void* query,
+                                             const void* const* vectors,
+                                             size_t count,
+                                             const void* params,
+                                             float* out,
+                                             bool l2_metric);
+
+#if NDD_HAS_AVX512_VARIANTS
+    NDD_TARGET_AVX512F static void SimilarityBatchTiledAVX512(const void* query,
+                                                              const void* const* vectors,
+                                                              size_t count,
+                                                              const void* params,
+                                                              float* out,
+                                                              bool l2_metric) {
+        if(count == 0) {
+            return;
+        }
+
+        const DistParams* dist_params = reinterpret_cast<const DistParams*>(params);
+        const size_t dim = dist_params->dim;
+        const float* query_vec = reinterpret_cast<const float*>(query);
+
+        std::vector<float> dot_acc(count, 0.0f);
+        std::vector<float> vec_sq_acc;
+        if(l2_metric) {
+            vec_sq_acc.assign(count, 0.0f);
+        }
+
+        float query_sq_acc = 0.0f;
+        const size_t tile = std::min(dim, static_cast<size_t>(1024));
+
+        for(size_t block_start = 0; block_start < dim; block_start += tile) {
+            const size_t block_len = std::min(tile, dim - block_start);
+            const float* q_ptr = query_vec + block_start;
+
+            for(size_t d = 0; d < block_len; ++d) {
+                query_sq_acc += q_ptr[d] * q_ptr[d];
+            }
+
+            for(size_t i = 0; i < count; ++i) {
+                const float* v_ptr = reinterpret_cast<const float*>(vectors[i]) + block_start;
+                float dot = dot_acc[i];
+                float vec_sq = l2_metric ? vec_sq_acc[i] : 0.0f;
+
+                size_t d = 0;
+                __m512 dot_vec = _mm512_setzero_ps();
+                __m512 sq_vec = _mm512_setzero_ps();
+                for(; d + 16 <= block_len; d += 16) {
+                    __m512 qv = _mm512_loadu_ps(q_ptr + d);
+                    __m512 vv = _mm512_loadu_ps(v_ptr + d);
+                    dot_vec = _mm512_fmadd_ps(qv, vv, dot_vec);
+                    if(l2_metric) {
+                        sq_vec = _mm512_fmadd_ps(vv, vv, sq_vec);
+                    }
+                }
+                dot += _mm512_reduce_add_ps(dot_vec);
+                if(l2_metric) {
+                    vec_sq += _mm512_reduce_add_ps(sq_vec);
+                }
+
+                for(; d < block_len; ++d) {
+                    dot += q_ptr[d] * v_ptr[d];
+                    if(l2_metric) {
+                        vec_sq += v_ptr[d] * v_ptr[d];
+                    }
+                }
+
+                dot_acc[i] = dot;
+                if(l2_metric) {
+                    vec_sq_acc[i] = vec_sq;
+                }
+            }
+        }
+
+        for(size_t i = 0; i < count; ++i) {
+            if(l2_metric) {
+                out[i] = -(query_sq_acc + vec_sq_acc[i] - 2.0f * dot_acc[i]);
+            } else {
+                out[i] = dot_acc[i];
+            }
+        }
+    }
+#endif
+
+    static void SimilarityBatchTiledBaseline(const void* query,
+                                             const void* const* vectors,
+                                             size_t count,
+                                             const void* params,
+                                             float* out,
+                                             bool l2_metric) {
         if(count == 0) {
             return;
         }
@@ -615,6 +703,21 @@
         }
     }
 
+    static void SimilarityBatchTiled(const void* query,
+                                     const void* const* vectors,
+                                     size_t count,
+                                     const void* params,
+                                     float* out,
+                                     bool l2_metric) {
+#if defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__x86_64__) || defined(_M_X64))
+        if(ndd::cpu::use_avx512f()) {
+            SimilarityBatchTiledAVX512(query, vectors, count, params, out, l2_metric);
+            return;
+        }
+#endif
+        SimilarityBatchTiledBaseline(query, vectors, count, params, out, l2_metric);
+    }
+
     static void L2SqrSimBatch(const void* query,
                               const void* const* vectors,
                               size_t count,
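SimilarityBatchTiledAVX512 never forms q - v: per tile it accumulates only dot products and squared norms, then recovers the (negated) squared L2 distance at the end via the identity ||q - v||^2 = ||q||^2 + ||v||^2 - 2*dot(q, v), which is exactly the -(query_sq_acc + vec_sq_acc[i] - 2*dot_acc[i]) line. A scalar restatement of the identity, for reference:

    #include <cstddef>

    // Equals sum((q[d] - v[d])^2) up to floating-point rounding.
    float l2sqr_via_dot(const float* q, const float* v, size_t dim) {
        float qq = 0.0f, vv = 0.0f, qv = 0.0f;
        for(size_t d = 0; d < dim; ++d) {
            qq += q[d] * q[d];
            vv += v[d] * v[d];
            qv += q[d] * v[d];
        }
        return qq + vv - 2.0f * qv;
    }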
diff --git a/src/quant/int16.hpp b/src/quant/int16.hpp
index a80cf8342f..dd02380aff 100644
--- a/src/quant/int16.hpp
+++ b/src/quant/int16.hpp
@@ -55,8 +55,8 @@
         return buffer;
     }
 
-#if defined(USE_AVX512)
-    inline std::vector<uint8_t>
+#if NDD_HAS_AVX512_VARIANTS
+    NDD_TARGET_AVX512BW inline std::vector<uint8_t>
     quantize_vector_fp32_to_int16_buffer_avx512(const std::vector<float>& input) {
         if(input.empty()) {
             return std::vector<uint8_t>();
         }
@@ -315,7 +320,12 @@
     // Auto-select best quantization implementation for INT16 -> uint8_t buffer
     inline std::vector<uint8_t>
     quantize_vector_fp32_to_int16_buffer_auto(const std::vector<float>& input) {
-#if defined(USE_AVX512)
+#if defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__x86_64__) || defined(_M_X64))
+        if(ndd::cpu::use_avx512bw()) {
+            return quantize_vector_fp32_to_int16_buffer_avx512(input);
+        }
+        return quantize_vector_fp32_to_int16_buffer(input);
+#elif defined(USE_AVX512)
         return quantize_vector_fp32_to_int16_buffer_avx512(input);
 #elif defined(USE_SVE2)
         return quantize_vector_fp32_to_int16_buffer_sve(input);
@@ -326,10 +331,10 @@
 #endif
     }
 
-#if defined(USE_AVX512)
+#if NDD_HAS_AVX512_VARIANTS
     // AVX512 optimized dequantization INT16 buffer -> FP32 vector
-    inline std::vector<float> dequantize_int16_buffer_to_fp32_avx512(const uint8_t* buffer,
-                                                                     size_t dimension) {
+    NDD_TARGET_AVX512BW inline std::vector<float>
+    dequantize_int16_buffer_to_fp32_avx512(const uint8_t* buffer, size_t dimension) {
         std::vector<float> output(dimension);
         const int16_t* data_ptr = reinterpret_cast<const int16_t*>(buffer);
         float scale = extract_scale(buffer, dimension);
@@ -468,7 +473,19 @@
     // Auto-select best dequantization implementation for INT16 buffer -> FP32 vector
     inline std::vector<float> dequantize_int16_buffer_to_fp32(const uint8_t* buffer,
                                                               size_t dimension) {
-#if defined(USE_AVX512)
+#if defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__x86_64__) || defined(_M_X64))
+        if(ndd::cpu::use_avx512bw()) {
+            return dequantize_int16_buffer_to_fp32_avx512(buffer, dimension);
+        }
+        std::vector<float> output(dimension);
+        const int16_t* data_ptr = reinterpret_cast<const int16_t*>(buffer);
+        float scale = extract_scale(buffer, dimension);
+
+        for(size_t i = 0; i < dimension; ++i) {
+            output[i] = static_cast<float>(data_ptr[i]) * scale;
+        }
+        return output;
+#elif defined(USE_AVX512)
         return dequantize_int16_buffer_to_fp32_avx512(buffer, dimension);
 #elif defined(USE_SVE2)
         return dequantize_int16_buffer_to_fp32_sve(buffer, dimension);
@@ -793,8 +810,94 @@
         return -L2Sqr(pVect1v, pVect2v, qty_ptr);
     }
 
+#if NDD_HAS_AVX512_VARIANTS
+    NDD_TARGET_AVX512VNNI static float InnerProductSimAVX512VNNI(const void* pVect1v,
+                                                                 const void* pVect2v,
+                                                                 const void* qty_ptr) {
+        const int16_t* pVect1 = (const int16_t*)pVect1v;
+        const int16_t* pVect2 = (const int16_t*)pVect2v;
+        const auto* params = static_cast<const hnswlib::DistParams*>(qty_ptr);
+        size_t qty = params->dim;
+
+        float scale1 = extract_scale((const uint8_t*)pVect1, qty);
+        float scale2 = extract_scale((const uint8_t*)pVect2, qty);
+
+        int64_t sum = 0;
+        size_t i = 0;
+        __m512i sum_vec0 = _mm512_setzero_si512();
+        __m512i sum_vec1 = _mm512_setzero_si512();
+
+        for(; i + 64 <= qty; i += 64) {
+            __m512i v1_0 = _mm512_loadu_si512((const __m512i*)(pVect1 + i));
+            __m512i v2_0 = _mm512_loadu_si512((const __m512i*)(pVect2 + i));
+            __m512i v1_1 = _mm512_loadu_si512((const __m512i*)(pVect1 + i + 32));
+            __m512i v2_1 = _mm512_loadu_si512((const __m512i*)(pVect2 + i + 32));
+
+            __m512i prod0 = _mm512_dpwssd_epi32(_mm512_setzero_si512(), v1_0, v2_0);
+            __m512i prod1 = _mm512_dpwssd_epi32(_mm512_setzero_si512(), v1_1, v2_1);
+
+            __m512i prod0_lo = _mm512_cvtepi32_epi64(_mm512_castsi512_si256(prod0));
+            __m512i prod0_hi = _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(prod0, 1));
+            __m512i prod1_lo = _mm512_cvtepi32_epi64(_mm512_castsi512_si256(prod1));
+            __m512i prod1_hi = _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(prod1, 1));
+
+            sum_vec0 = _mm512_add_epi64(sum_vec0, prod0_lo);
+            sum_vec1 = _mm512_add_epi64(sum_vec1, prod0_hi);
+            sum_vec0 = _mm512_add_epi64(sum_vec0, prod1_lo);
+            sum_vec1 = _mm512_add_epi64(sum_vec1, prod1_hi);
+        }
+
+        for(; i + 32 <= qty; i += 32) {
+            __m512i v1 = _mm512_loadu_si512((const __m512i*)(pVect1 + i));
+            __m512i v2 = _mm512_loadu_si512((const __m512i*)(pVect2 + i));
+
+            __m512i prod = _mm512_dpwssd_epi32(_mm512_setzero_si512(), v1, v2);
+            __m512i prod_lo = _mm512_cvtepi32_epi64(_mm512_castsi512_si256(prod));
+            __m512i prod_hi = _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(prod, 1));
+
+            sum_vec0 = _mm512_add_epi64(sum_vec0, prod_lo);
+            sum_vec1 = _mm512_add_epi64(sum_vec1, prod_hi);
+        }
+
+        sum = _mm512_reduce_add_epi64(sum_vec0) + _mm512_reduce_add_epi64(sum_vec1);
+
+        for(; i + 16 <= qty; i += 16) {
+            __m256i v1 = _mm256_loadu_si256((const __m256i*)(pVect1 + i));
+            __m256i v2 = _mm256_loadu_si256((const __m256i*)(pVect2 + i));
+            __m256i prod = _mm256_madd_epi16(v1, v2);
+
+            __m256i prod_lo = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(prod));
+            __m256i prod_hi = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(prod, 1));
+
+            __m128i sum_128 = _mm_add_epi64(_mm256_castsi256_si128(prod_lo),
+                                            _mm256_extracti128_si256(prod_lo, 1));
+            __m128i high64 = _mm_unpackhi_epi64(sum_128, sum_128);
+            sum_128 = _mm_add_epi64(sum_128, high64);
+            sum += _mm_cvtsi128_si64(sum_128);
+
+            sum_128 = _mm_add_epi64(_mm256_castsi256_si128(prod_hi),
+                                    _mm256_extracti128_si256(prod_hi, 1));
+            high64 = _mm_unpackhi_epi64(sum_128, sum_128);
+            sum_128 = _mm_add_epi64(sum_128, high64);
+            sum += _mm_cvtsi128_si64(sum_128);
+        }
+
+        for(; i < qty; i++) {
+            sum += static_cast<int64_t>(pVect1[i]) * static_cast<int64_t>(pVect2[i]);
+        }
+
+        float combined_scale = scale1 * scale2;
+        return static_cast<float>(sum) * combined_scale;
+    }
+#endif
+
     static float InnerProductSim(const void* pVect1v, const void* pVect2v, const void* qty_ptr) {
+#if defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__x86_64__) || defined(_M_X64))
+        if(ndd::cpu::use_avx512vnni()) {
+            return InnerProductSimAVX512VNNI(pVect1v, pVect2v, qty_ptr);
+        }
+#endif
         const int16_t* pVect1 = (const int16_t*)pVect1v;
         const int16_t* pVect2 = (const int16_t*)pVect2v;
         const auto* params = static_cast<const hnswlib::DistParams*>(qty_ptr);
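The int16 VNNI path widens aggressively for a reason: _mm512_dpwssd_epi32 sums two int16*int16 products into each int32 lane, and one such partial can already reach 2 * 32767^2 = 2147352578, within 0.01% of INT32_MAX, so a second accumulation into the same 32-bit lane could wrap. That is why every dpwssd result is widened to int64 before it joins the running sum. A small standalone arithmetic check of that bound (illustrative):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t lane_max = 2LL * 32767 * 32767; // one dpwssd lane: a0*b0 + a1*b1
        std::printf("worst-case lane partial: %lld, INT32_MAX: %d\n",
                    (long long)lane_max, INT32_MAX);  // 2147352578 vs 2147483647
        return 0;
    }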
diff --git a/src/quant/int8.hpp b/src/quant/int8.hpp
index f0ba9733b7..70562b76a4 100644
--- a/src/quant/int8.hpp
+++ b/src/quant/int8.hpp
@@ -54,8 +54,8 @@
         return buffer;
     }
 
-#if defined(USE_AVX512)
-    inline std::vector<uint8_t>
+#if NDD_HAS_AVX512_VARIANTS
+    NDD_TARGET_AVX512BW inline std::vector<uint8_t>
     quantize_vector_fp32_to_int8_buffer_avx512(const std::vector<float>& input) {
         if(input.empty()) {
             return std::vector<uint8_t>();
         }
@@ -398,7 +398,12 @@
     // Auto-select best quantization implementation for INT8 -> uint8_t buffer
     inline std::vector<uint8_t>
     quantize_vector_fp32_to_int8_buffer_auto(const std::vector<float>& input) {
-#if defined(USE_AVX512)
+#if defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__x86_64__) || defined(_M_X64))
+        if(ndd::cpu::use_avx512bw()) {
+            return quantize_vector_fp32_to_int8_buffer_avx512(input);
+        }
+        return quantize_vector_fp32_to_int8_buffer_avx2(input);
+#elif defined(USE_AVX512)
         return quantize_vector_fp32_to_int8_buffer_avx512(input);
 #elif defined(USE_SVE2)
         return quantize_vector_fp32_to_int8_buffer_sve(input);
@@ -411,10 +416,10 @@
 #endif
     }
 
-#if defined(USE_AVX512)
+#if NDD_HAS_AVX512_VARIANTS
     // AVX512 optimized dequantization INT8 buffer -> FP32 vector
-    inline std::vector<float> dequantize_int8_buffer_to_fp32_avx512(const uint8_t* buffer,
-                                                                    size_t dimension) {
+    NDD_TARGET_AVX512BW inline std::vector<float>
+    dequantize_int8_buffer_to_fp32_avx512(const uint8_t* buffer, size_t dimension) {
         std::vector<float> output(dimension);
         const int8_t* data_ptr = reinterpret_cast<const int8_t*>(buffer);
         float scale = extract_scale(buffer, dimension);
@@ -558,7 +563,19 @@
     // Auto-select best dequantization implementation for INT8 buffer -> FP32 vector
     inline std::vector<float> dequantize_int8_buffer_to_fp32(const uint8_t* buffer,
                                                              size_t dimension) {
-#if defined(USE_AVX512)
+#if defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__x86_64__) || defined(_M_X64))
+        if(ndd::cpu::use_avx512bw()) {
+            return dequantize_int8_buffer_to_fp32_avx512(buffer, dimension);
+        }
+        std::vector<float> output(dimension);
+        const int8_t* data_ptr = reinterpret_cast<const int8_t*>(buffer);
+        float scale = extract_scale(buffer, dimension);
+
+        for(size_t i = 0; i < dimension; ++i) {
+            output[i] = static_cast<float>(data_ptr[i]) * scale;
+        }
+        return output;
+#elif defined(USE_AVX512)
         return dequantize_int8_buffer_to_fp32_avx512(buffer, dimension);
 #elif defined(USE_SVE2)
         return dequantize_int8_buffer_to_fp32_sve(buffer, dimension);
@@ -762,8 +779,90 @@
         return -L2Sqr(pVect1v, pVect2v, qty_ptr);
     }
 
+#if NDD_HAS_AVX512_VARIANTS
+    NDD_TARGET_AVX512VNNI static float InnerProductSimAVX512VNNI(const void* pVect1v,
+                                                                 const void* pVect2v,
+                                                                 const void* qty_ptr) {
+        const int8_t* pVect1 = (const int8_t*)pVect1v;
+        const int8_t* pVect2 = (const int8_t*)pVect2v;
+        const auto* params = static_cast<const hnswlib::DistParams*>(qty_ptr);
+        size_t qty = params->dim;
+
+        float scale1 = extract_scale((const uint8_t*)pVect1, qty);
+        float scale2 = extract_scale((const uint8_t*)pVect2, qty);
+
+        int32_t sum = 0;
+        size_t i = 0;
+        __m512i dot_acc0 = _mm512_setzero_si512();
+        __m512i dot_acc1 = _mm512_setzero_si512();
+        __m512i dot_acc2 = _mm512_setzero_si512();
+        __m512i dot_acc3 = _mm512_setzero_si512();
+        __m512i sum2_acc0 = _mm512_setzero_si512();
+        __m512i sum2_acc1 = _mm512_setzero_si512();
+        __m512i sum2_acc2 = _mm512_setzero_si512();
+        __m512i sum2_acc3 = _mm512_setzero_si512();
+        const __m512i sign_flip = _mm512_set1_epi8(static_cast<char>(0x80));
+        const __m512i ones_u8 = _mm512_set1_epi8(static_cast<char>(0x01));
+
+        for(; i + 256 <= qty; i += 256) {
+            __m512i v1_0 = _mm512_loadu_si512((const __m512i*)(pVect1 + i));
+            __m512i v2_0 = _mm512_loadu_si512((const __m512i*)(pVect2 + i));
+            __m512i v1_1 = _mm512_loadu_si512((const __m512i*)(pVect1 + i + 64));
+            __m512i v2_1 = _mm512_loadu_si512((const __m512i*)(pVect2 + i + 64));
+            __m512i v1_2 = _mm512_loadu_si512((const __m512i*)(pVect1 + i + 128));
+            __m512i v2_2 = _mm512_loadu_si512((const __m512i*)(pVect2 + i + 128));
+            __m512i v1_3 = _mm512_loadu_si512((const __m512i*)(pVect1 + i + 192));
+            __m512i v2_3 = _mm512_loadu_si512((const __m512i*)(pVect2 + i + 192));
+
+            __m512i v1_u8_0 = _mm512_xor_si512(v1_0, sign_flip);
+            __m512i v1_u8_1 = _mm512_xor_si512(v1_1, sign_flip);
+            __m512i v1_u8_2 = _mm512_xor_si512(v1_2, sign_flip);
+            __m512i v1_u8_3 = _mm512_xor_si512(v1_3, sign_flip);
+
+            dot_acc0 = _mm512_dpbusd_epi32(dot_acc0, v1_u8_0, v2_0);
+            dot_acc1 = _mm512_dpbusd_epi32(dot_acc1, v1_u8_1, v2_1);
+            dot_acc2 = _mm512_dpbusd_epi32(dot_acc2, v1_u8_2, v2_2);
+            dot_acc3 = _mm512_dpbusd_epi32(dot_acc3, v1_u8_3, v2_3);
+
+            sum2_acc0 = _mm512_dpbusd_epi32(sum2_acc0, ones_u8, v2_0);
+            sum2_acc1 = _mm512_dpbusd_epi32(sum2_acc1, ones_u8, v2_1);
+            sum2_acc2 = _mm512_dpbusd_epi32(sum2_acc2, ones_u8, v2_2);
+            sum2_acc3 = _mm512_dpbusd_epi32(sum2_acc3, ones_u8, v2_3);
+        }
+
+        __m512i dot_acc = _mm512_add_epi32(_mm512_add_epi32(dot_acc0, dot_acc1),
+                                           _mm512_add_epi32(dot_acc2, dot_acc3));
+        __m512i sum2_acc = _mm512_add_epi32(_mm512_add_epi32(sum2_acc0, sum2_acc1),
+                                            _mm512_add_epi32(sum2_acc2, sum2_acc3));
+
+        for(; i + 64 <= qty; i += 64) {
+            __m512i v1 = _mm512_loadu_si512((const __m512i*)(pVect1 + i));
+            __m512i v2 = _mm512_loadu_si512((const __m512i*)(pVect2 + i));
+
+            __m512i v1_u8 = _mm512_xor_si512(v1, sign_flip);
+            dot_acc = _mm512_dpbusd_epi32(dot_acc, v1_u8, v2);
+            sum2_acc = _mm512_dpbusd_epi32(sum2_acc, ones_u8, v2);
+        }
+
+        int32_t dot_u = _mm512_reduce_add_epi32(dot_acc);
+        int32_t sum2 = _mm512_reduce_add_epi32(sum2_acc);
+        sum = dot_u - 128 * sum2;
+
+        for(; i < qty; i++) {
+            sum += static_cast<int32_t>(pVect1[i]) * static_cast<int32_t>(pVect2[i]);
+        }
+
+        return (static_cast<float>(sum) * scale1) * scale2;
+    }
+#endif
+
     static float InnerProductSim(const void* pVect1v, const void* pVect2v, const void* qty_ptr) {
+#if defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__x86_64__) || defined(_M_X64))
+        if(ndd::cpu::use_avx512vnni()) {
+            return InnerProductSimAVX512VNNI(pVect1v, pVect2v, qty_ptr);
+        }
+#endif
         const int8_t* pVect1 = (const int8_t*)pVect1v;
         const int8_t* pVect2 = (const int8_t*)pVect2v;
         const auto* params = static_cast<const hnswlib::DistParams*>(qty_ptr);
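AVX512-VNNI has no signed*signed byte dot product, so the int8 kernel above rebuilds one from _mm512_dpbusd_epi32 (unsigned * signed): XOR-ing one operand with 0x80 adds 128 to every byte, giving dot(a + 128, b) = dot(a, b) + 128 * sum(b); the second dpbusd against the all-ones vector computes sum(b), and the `sum = dot_u - 128 * sum2` line undoes the bias. A scalar sketch of the same identity (illustrative helper, not part of the patch):

    #include <cstdint>
    #include <cstddef>

    // Computes sum(a[i] * b[i]) the way the dpbusd path does.
    int32_t dot_s8_via_biased_u8(const int8_t* a, const int8_t* b, size_t n) {
        int32_t dot_u = 0, sum_b = 0;
        for(size_t i = 0; i < n; ++i) {
            uint8_t u = (uint8_t)(a[i] ^ 0x80);    // a[i] + 128 as unsigned
            dot_u += (int32_t)u * (int32_t)b[i];   // what dpbusd accumulates
            sum_b += b[i];                         // what the ones_u8 dpbusd yields
        }
        return dot_u - 128 * sum_b;                // bias term cancels exactly
    }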
diff --git a/src/sparse/inverted_index.cpp b/src/sparse/inverted_index.cpp
index e0394f8510..43350f8488 100644
--- a/src/sparse/inverted_index.cpp
+++ b/src/sparse/inverted_index.cpp
@@ -13,6 +13,7 @@
  */
 
 #include "inverted_index.hpp"
+#include "utils/cpu_compat_check/cpu_runtime_dispatch.hpp"
 
 #include
 #include
@@ -793,11 +794,70 @@
     // =========================================================================
     // SIMD helpers
     // =========================================================================
 
+#if NDD_HAS_AVX512_VARIANTS
+    namespace {
+        NDD_TARGET_AVX512F size_t find_doc_id_simd_avx512(const uint32_t* doc_ids,
+                                                          size_t size,
+                                                          size_t start_idx,
+                                                          uint32_t target) {
+            size_t idx = start_idx;
+            const size_t simd_width = 16;
+            __m512i target_vec = _mm512_set1_epi32((int)target);
+
+            while(idx + simd_width <= size) {
+                __m512i data_vec = _mm512_loadu_si512(doc_ids + idx);
+                __mmask16 mask = _mm512_cmpge_epu32_mask(data_vec, target_vec);
+
+                if(mask != 0) {
+                    return idx + __builtin_ctz(mask);
+                }
+                idx += simd_width;
+            }
+
+            while(idx < size && doc_ids[idx] < target) {
+                ++idx;
+            }
+            return idx;
+        }
+
+        NDD_TARGET_AVX512BW size_t find_next_live_simd_avx512(const uint8_t* values,
+                                                              size_t size,
+                                                              size_t start_idx) {
+            size_t idx = start_idx;
+            const size_t simd_width = 64;
+            __m512i zero_vec = _mm512_setzero_si512();
+
+            while(idx + simd_width <= size) {
+                __m512i data_vec = _mm512_loadu_si512(values + idx);
+                __mmask64 mask = _mm512_cmpneq_epu8_mask(data_vec, zero_vec);
+
+                if(mask != 0) {
+                    return idx + __builtin_ctzll(mask);
+                }
+                idx += simd_width;
+            }
+
+            while(idx < size) {
+                if(values[idx] != 0) {
+                    return idx;
+                }
+                ++idx;
+            }
+            return idx;
+        }
+    } // namespace
+#endif
+
     size_t InvertedIndex::findDocIdSIMD(const uint32_t* doc_ids,
                                         size_t size,
                                         size_t start_idx,
                                         uint32_t target) const {
+#if defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__x86_64__) || defined(_M_X64))
+        if(ndd::cpu::use_avx512f()) {
+            return find_doc_id_simd_avx512(doc_ids, size, start_idx, target);
+        }
+#endif
         size_t idx = start_idx;
 
 #if defined(USE_AVX512)
@@ -881,6 +941,11 @@
                                          size_t size,
                                          size_t start_idx) const {
+#if defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__x86_64__) || defined(_M_X64))
+        if(ndd::cpu::use_avx512bw()) {
+            return find_next_live_simd_avx512(values, size, start_idx);
+        }
+#endif
         size_t idx = start_idx;
 
 #if defined(USE_AVX512)
diff --git a/src/utils/cpu_compat_check/check_avx_compat.hpp b/src/utils/cpu_compat_check/check_avx_compat.hpp
index 8a8dbd868e..3363549054 100644
--- a/src/utils/cpu_compat_check/check_avx_compat.hpp
+++ b/src/utils/cpu_compat_check/check_avx_compat.hpp
@@ -13,6 +13,7 @@
 static const uint32_t CPUID_SUBLEAF_0 = 0;
 static const uint32_t EBX_AVX2_BIT = 5;
 static const uint32_t EBX_AVX512F_BIT = 16;
+static const uint32_t EBX_AVX512DQ_BIT = 17;
 static const uint32_t EBX_AVX512BW_BIT = 30;
 static const uint32_t ECX_AVX512VNNI_BIT = 11;
 static const uint32_t ECX_AVX512VPOPCNTDQ_BIT = 14;
@@ -24,32 +25,32 @@
 static const uint32_t EDX_AVX512FP16_BIT = 23;
 
 /**
  * Always return false if these functions are called
  */
 
-int check_avx2_support(void) {
+inline int check_avx2_support(void) {
     LOG_ERROR(1710, "Unexpected AVX compatibility probe call to " << __func__);
     return false;
 }
 
-int check_avx512_support(void) {
+inline int check_avx512_support(void) {
     LOG_ERROR(1711, "Unexpected AVX compatibility probe call to " << __func__);
     return false;
 }
 
-int check_avx512_fp16_support(void) {
+inline int check_avx512_fp16_support(void) {
     LOG_ERROR(1712, "Unexpected AVX compatibility probe call to " << __func__);
     return false;
 }
 
-int check_avx512_vnni_support(void) {
+inline int check_avx512_vnni_support(void) {
     LOG_ERROR(1713, "Unexpected AVX compatibility probe call to " << __func__);
     return false;
 }
 
-int check_avx512_bw_support(void) {
+inline int check_avx512_bw_support(void) {
     LOG_ERROR(1714, "Unexpected AVX compatibility probe call to " << __func__);
     return false;
 }
 
-int check_avx512_vpopcntdq_support(void) {
+inline int check_avx512_vpopcntdq_support(void) {
     LOG_ERROR(1715, "Unexpected AVX compatibility probe call to " << __func__);
     return false;
 }
@@ -172,6 +173,12 @@ static int cpu_has_avx512f(void) {
     return ((ebx >> EBX_AVX512F_BIT) & 1);
 }
 
+static int cpu_has_avx512dq(void) {
+    uint32_t eax, ebx, ecx, edx;
+    cpuid_ex(CPUID_EXT_FEATURES_LEAF, CPUID_SUBLEAF_0, &eax, &ebx, &ecx, &edx);
+    return ((ebx >> EBX_AVX512DQ_BIT) & 1);
+}
+
 /**
  * True if CPU has AVX512f and fp16
  */
@@ -290,7 +297,7 @@ static void run_one_avx512vpopcntdq_instruction(void) {
  * /////////////////////////////////////////////////////////////////
  */
 
-int check_avx2_support(void) {
+inline int check_avx2_support(void) {
     int ret = false;
 
     if(!cpu_has_avx2()) {
@@ -316,7 +323,7 @@
  * Returns true if AVX-512 is supported and usable (AVX-512F + OS state).
  * Should PASS on CPUs with AVX-512 but WITHOUT AVX512_FP16.
  */
-int check_avx512_support(void) {
+inline int check_avx512_support(void) {
     int ret = false;
 
     if(!cpu_has_avx512f()) {
@@ -338,7 +345,7 @@
     return ret;
 }
 
-int check_avx512_fp16_support(void) {
+inline int check_avx512_fp16_support(void) {
     int ret = false;
 
     if(!is_intel_cpu()) {
@@ -365,7 +372,7 @@
     return ret;
 }
 
-int check_avx512_vnni_support(void) {
+inline int check_avx512_vnni_support(void) {
     int ret = false;
 
     if(!cpu_has_avx512f()) {
@@ -392,7 +399,7 @@
     return ret;
 }
 
-int check_avx512_bw_support(void) {
+inline int check_avx512_bw_support(void) {
     int ret = false;
 
     if(!cpu_has_avx512f()) {
@@ -419,7 +426,7 @@
     return ret;
 }
 
-int check_avx512_vpopcntdq_support(void) {
+inline int check_avx512_vpopcntdq_support(void) {
     int ret = false;
 
     if(!cpu_has_avx512f()) {
@@ -453,11 +460,11 @@
  * (P-cores), and AMD Zen 5.
  */
 
-bool is_avx2_compatible() {
+inline bool is_avx2_compatible() {
     return check_avx2_support();
 }
 
-bool is_avx512_compatible() {
+inline bool is_avx512_compatible() {
     return check_avx2_support() && check_avx512_support() && check_avx512_fp16_support() &&
            check_avx512_vnni_support() && check_avx512_bw_support() &&
            check_avx512_vpopcntdq_support();
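The `inline` additions in check_avx_compat.hpp are a correctness fix rather than style: with cpu_runtime_dispatch.cpp now including this header alongside system_sanity.cpp, every non-inline function defined here would be emitted in both translation units and fail the link with duplicate-symbol errors (a one-definition-rule violation). A minimal reproduction of the failure mode, assuming two TUs include the same header:

    // probe.hpp
    int probe() { return 42; }            // defined in every TU that includes this:
                                          // duplicate-definition link error with 2+ TUs
    // inline int probe() { return 42; }  // fix: inline permits repeated definitions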
diff --git a/src/utils/cpu_compat_check/cpu_runtime_dispatch.cpp b/src/utils/cpu_compat_check/cpu_runtime_dispatch.cpp
new file mode 100644
index 0000000000..0f24a1f345
--- /dev/null
+++ b/src/utils/cpu_compat_check/cpu_runtime_dispatch.cpp
@@ -0,0 +1,251 @@
+#include "cpu_runtime_dispatch.hpp"
+
+#include <mutex>
+#include <sstream>
+
+#include "check_avx_compat.hpp"
+#include "../log.hpp"
+
+namespace {
+std::once_flag g_dispatch_init_once;
+ndd::cpu::X86SimdCaps g_active_x86_caps;
+bool g_dispatch_initialized = false;
+bool g_runtime_x86_compatible = false;
+} // namespace
+
+namespace ndd::cpu {
+
+X86SimdCaps probe_x86_simd_caps() {
+    X86SimdCaps caps;
+
+#if defined(__x86_64__) || defined(_M_X64)
+    caps.os_avx = os_supports_avx();
+    caps.os_avx512_state = os_supports_avx512_state();
+    caps.avx2 = cpu_has_avx2();
+    caps.avx512f = cpu_has_avx512f();
+    caps.avx512dq = cpu_has_avx512dq();
+    caps.avx512bw = cpu_has_avx512bw();
+    caps.avx512vnni = cpu_has_avx512vnni();
+    caps.avx512fp16 = cpu_has_avx512f_and_fp16();
+    caps.avx512vpopcntdq = cpu_has_avx512vpopcntdq();
+#endif
+
+    return caps;
+}
+
+X86SimdCaps compute_active_x86_simd_caps(const X86SimdCaps& detected) {
+    X86SimdCaps active;
+    active.os_avx = detected.os_avx;
+    active.os_avx512_state = detected.os_avx512_state;
+
+#if defined(__x86_64__) || defined(_M_X64)
+    if(!(detected.avx2 && detected.os_avx)) {
+        return active;
+    }
+
+    active.avx2 = true;
+
+    if(!(detected.avx512f && detected.avx512dq && detected.os_avx512_state)) {
+        return active;
+    }
+
+    active.avx512f = true;
+    active.avx512dq = true;
+    active.avx512bw = detected.avx512bw;
+    active.avx512vnni = detected.avx512vnni;
+    active.avx512fp16 = detected.avx512fp16;
+    active.avx512vpopcntdq = detected.avx512vpopcntdq;
+#else
+    (void)detected;
+#endif
+
+    return active;
+}
+
+void bind_x86_dispatch(const X86SimdCaps& detected) {
+    g_active_x86_caps = compute_active_x86_simd_caps(detected);
+    g_dispatch_initialized = true;
+}
+
+static std::string join_flags(const std::vector<std::string>& flags) {
+    std::ostringstream oss;
+    for(size_t i = 0; i < flags.size(); ++i) {
+        if(i != 0) {
+            oss << ", ";
+        }
+        oss << flags[i];
+    }
+    return oss.str();
+}
+
+bool initialize_cpu_dispatch() {
+#if defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__x86_64__) || defined(_M_X64))
+    std::call_once(g_dispatch_init_once, []() {
+        const X86SimdCaps detected = probe_x86_simd_caps();
+        bind_x86_dispatch(detected);
+
+        if(!g_active_x86_caps.avx2) {
+            if(!detected.avx2) {
+                LOG_ERROR("Runtime x86 dispatch requires AVX2 support on the host CPU");
+            } else if(!detected.os_avx) {
+                LOG_ERROR("Runtime x86 dispatch requires AVX state support from the OS");
+            } else {
+                LOG_ERROR("Runtime x86 dispatch could not enable the AVX2 baseline");
+            }
+            g_runtime_x86_compatible = false;
+            return;
+        }
+
+        g_runtime_x86_compatible = true;
+        LOG_INFO("Runtime x86 dispatch active with CPU flags: "
+                 << join_flags(serialize_active_cpu_flags(g_active_x86_caps)));
+
+        if(detected.avx512f && detected.os_avx512_state) {
+            std::vector<std::string> downgraded;
+            if(!g_active_x86_caps.avx512dq) {
+                downgraded.push_back("avx512dq");
+            }
+            if(!g_active_x86_caps.avx512bw) {
+                downgraded.push_back("avx512bw");
+            }
+            if(!g_active_x86_caps.avx512vnni) {
+                downgraded.push_back("avx512vnni");
+            }
+            if(!g_active_x86_caps.avx512fp16) {
+                downgraded.push_back("avx512fp16");
+            }
+            if(!g_active_x86_caps.avx512vpopcntdq) {
+                downgraded.push_back("avx512vpopcntdq");
+            }
+            if(!downgraded.empty()) {
+                LOG_WARN("Runtime x86 dispatch is falling back to AVX2 for unsupported AVX512 "
+                         "subextensions: "
+                         << join_flags(downgraded));
+            }
+        } else {
+            LOG_INFO("Runtime x86 dispatch is using the AVX2 baseline only");
+        }
+    });
+
+    return g_runtime_x86_compatible;
+#else
+    return true;
+#endif
+}
+
+const X86SimdCaps& get_active_x86_simd_caps() {
+#if defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__x86_64__) || defined(_M_X64))
+    initialize_cpu_dispatch();
+#endif
+    return g_active_x86_caps;
+}
+
+std::vector<std::string> serialize_active_cpu_flags(const X86SimdCaps& caps) {
+    std::vector<std::string> flags;
+
+    if(caps.avx2) {
+        flags.push_back("avx2");
+    }
+    if(caps.avx512f) {
+        flags.push_back("avx512f");
+    }
+    if(caps.avx512dq) {
+        flags.push_back("avx512dq");
+    }
+    if(caps.avx512bw) {
+        flags.push_back("avx512bw");
+    }
+    if(caps.avx512vnni) {
+        flags.push_back("avx512vnni");
+    }
+    if(caps.avx512fp16) {
+        flags.push_back("avx512fp16");
+    }
+    if(caps.avx512vpopcntdq) {
+        flags.push_back("avx512vpopcntdq");
+    }
+
+    return flags;
+}
+
+std::vector<std::string> get_default_active_cpu_flags() {
+#if defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__x86_64__) || defined(_M_X64))
+    return serialize_active_cpu_flags(get_active_x86_simd_caps());
+#elif defined(USE_AVX512)
+    return {"avx2",
+            "avx512f",
+            "avx512dq",
+            "avx512bw",
+            "avx512vnni",
+            "avx512fp16",
+            "avx512vpopcntdq"};
+#elif defined(USE_AVX2)
+    return {"avx2"};
+#elif defined(USE_SVE2)
+    return {"sve2"};
+#elif defined(USE_NEON)
+    return {"neon"};
+#else
+    return {};
+#endif
+}
+
+std::vector<std::string> get_active_cpu_flags() { return get_default_active_cpu_flags(); }
+
+bool use_avx2() {
+#if defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__x86_64__) || defined(_M_X64))
+    return get_active_x86_simd_caps().avx2;
+#else
+    return false;
+#endif
+}
+
+bool use_avx512f() {
+#if defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__x86_64__) || defined(_M_X64))
+    return get_active_x86_simd_caps().avx512f;
+#else
+    return false;
+#endif
+}
+
+bool use_avx512dq() {
+#if defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__x86_64__) || defined(_M_X64))
+    return get_active_x86_simd_caps().avx512dq;
+#else
+    return false;
+#endif
+}
+
+bool use_avx512bw() {
+#if defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__x86_64__) || defined(_M_X64))
+    return get_active_x86_simd_caps().avx512bw;
+#else
+    return false;
+#endif
+}
+
+bool use_avx512vnni() {
+#if defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__x86_64__) || defined(_M_X64))
+    return get_active_x86_simd_caps().avx512vnni;
+#else
+    return false;
+#endif
+}
+
+bool use_avx512fp16() {
+#if defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__x86_64__) || defined(_M_X64))
+    return get_active_x86_simd_caps().avx512fp16;
+#else
+    return false;
+#endif
+}
+
+bool use_avx512vpopcntdq() {
+#if defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__x86_64__) || defined(_M_X64))
+    return get_active_x86_simd_caps().avx512vpopcntdq;
+#else
+    return false;
+#endif
+}
+
+} // namespace ndd::cpu
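The module is intended to be initialized once at startup (system_sanity drives it below), but it can also be queried directly; a hedged usage sketch against the header's public API:

    #include <cstdio>
    #include "utils/cpu_compat_check/cpu_runtime_dispatch.hpp"

    int main() {
        // Probes CPUID once; returns false if even the AVX2 baseline is missing.
        if(!ndd::cpu::initialize_cpu_dispatch()) {
            std::fprintf(stderr, "CPU lacks the AVX2 baseline; refusing to start\n");
            return 1;
        }
        for(const auto& flag : ndd::cpu::get_active_cpu_flags()) {
            std::printf("active: %s\n", flag.c_str());
        }
        std::printf("AVX512F kernels: %s\n", ndd::cpu::use_avx512f() ? "on" : "off");
        return 0;
    }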
diff --git a/src/utils/cpu_compat_check/cpu_runtime_dispatch.hpp b/src/utils/cpu_compat_check/cpu_runtime_dispatch.hpp
new file mode 100644
index 0000000000..2af84e03cc
--- /dev/null
+++ b/src/utils/cpu_compat_check/cpu_runtime_dispatch.hpp
@@ -0,0 +1,57 @@
+#pragma once
+
+#include <string>
+#include <vector>
+
+namespace ndd::cpu {
+
+struct X86SimdCaps {
+    bool avx2 = false;
+    bool avx512f = false;
+    bool avx512dq = false;
+    bool avx512bw = false;
+    bool avx512vnni = false;
+    bool avx512fp16 = false;
+    bool avx512vpopcntdq = false;
+    bool os_avx = false;
+    bool os_avx512_state = false;
+};
+
+X86SimdCaps probe_x86_simd_caps();
+X86SimdCaps compute_active_x86_simd_caps(const X86SimdCaps& detected);
+void bind_x86_dispatch(const X86SimdCaps& detected);
+bool initialize_cpu_dispatch();
+
+const X86SimdCaps& get_active_x86_simd_caps();
+std::vector<std::string> get_active_cpu_flags();
+std::vector<std::string> serialize_active_cpu_flags(const X86SimdCaps& caps);
+std::vector<std::string> get_default_active_cpu_flags();
+
+bool use_avx2();
+bool use_avx512f();
+bool use_avx512dq();
+bool use_avx512bw();
+bool use_avx512vnni();
+bool use_avx512fp16();
+bool use_avx512vpopcntdq();
+
+} // namespace ndd::cpu
+
+#if defined(USE_AVX512) || defined(NDD_COMPILE_AVX512_VARIANTS)
+# define NDD_HAS_AVX512_VARIANTS 1
+#else
+# define NDD_HAS_AVX512_VARIANTS 0
+#endif
+
+#if defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__GNUC__) || defined(__clang__))
+# define NDD_TARGET_ATTR(features) __attribute__((target(features), noinline))
+#else
+# define NDD_TARGET_ATTR(features)
+#endif
+
+#define NDD_TARGET_AVX512F NDD_TARGET_ATTR("avx512f,avx512dq")
+#define NDD_TARGET_AVX512BW NDD_TARGET_ATTR("avx512f,avx512dq,avx512bw")
+#define NDD_TARGET_AVX512VNNI NDD_TARGET_ATTR("avx512f,avx512dq,avx512vnni")
+#define NDD_TARGET_AVX512BW_VNNI NDD_TARGET_ATTR("avx512f,avx512dq,avx512bw,avx512vnni")
+#define NDD_TARGET_AVX512FP16 NDD_TARGET_ATTR("avx512f,avx512dq,avx512fp16")
+#define NDD_TARGET_AVX512VPOPCNTDQ NDD_TARGET_ATTR("avx512f,avx512dq,avx512vpopcntdq")
diff --git a/src/utils/system_sanity/system_sanity.cpp b/src/utils/system_sanity/system_sanity.cpp
index 9a1247115f..62cc671bfd 100644
--- a/src/utils/system_sanity/system_sanity.cpp
+++ b/src/utils/system_sanity/system_sanity.cpp
@@ -19,6 +19,7 @@
 
 #include "utils/settings.hpp"
 #include "utils/log.hpp"
+#include "utils/cpu_compat_check/cpu_runtime_dispatch.hpp"
 #include "utils/cpu_compat_check/check_avx_compat.hpp"
 #include "utils/cpu_compat_check/check_arm_compat.hpp"
 #include "utils/system_sanity/system_sanity.hpp"
@@ -26,7 +27,11 @@
 static bool is_cpu_compatible() {
     bool ret = true;
 
-#if defined(USE_AVX2) && (defined(__x86_64__) || defined(_M_X64))
+#if defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__x86_64__) || defined(_M_X64))
+    ret &= ndd::cpu::initialize_cpu_dispatch();
+#endif
+
+#if defined(USE_AVX2) && !defined(NDD_RUNTIME_X86_DISPATCH) && (defined(__x86_64__) || defined(_M_X64))
     ret &= is_avx2_compatible();
 #endif //AVX2 checks