diff --git a/ggml/src/ggml-threading.cpp b/ggml/src/ggml-threading.cpp index 25a19eedb90..893aa9329ba 100644 --- a/ggml/src/ggml-threading.cpp +++ b/ggml/src/ggml-threading.cpp @@ -1,5 +1,8 @@ #include "ggml-threading.h" +#include #include +#include +#include std::mutex ggml_critical_section_mutex; @@ -10,3 +13,45 @@ void ggml_critical_section_start() { void ggml_critical_section_end(void) { ggml_critical_section_mutex.unlock(); } + +size_t ggml_quantize_chunk_mt( + enum ggml_type type, + const float * src, + void * dst, + int64_t start, + int64_t nrows, + int64_t n_per_row, + const float * imatrix, + int n_threads) { + if (n_threads <= 1 || nrows <= 1) { + return ggml_quantize_chunk(type, src, dst, start, nrows, n_per_row, imatrix); + } + + const int n_t = std::min((int64_t) n_threads, nrows); + const int64_t chunk = (nrows + n_t - 1) / n_t; + + std::vector results(n_t, 0); + std::vector threads; + threads.reserve(n_t - 1); + + auto worker = [&](int t) { + const int64_t r0 = (int64_t)t * chunk; + const int64_t r1 = std::min(r0 + chunk, nrows); + const int64_t nrows_t = r1 - r0; + results[t] = ggml_quantize_chunk(type, src, dst, start + r0 * n_per_row, nrows_t, n_per_row, imatrix); + }; + + for (int t = 1; t < n_t; ++t) { + threads.emplace_back(worker, t); + } + worker(0); + for (auto & th : threads) { + th.join(); + } + + size_t total = 0; + for (int t = 0; t < n_t; ++t) { + total += results[t]; + } + return total; +} diff --git a/ggml/src/ggml-threading.h b/ggml/src/ggml-threading.h index dec2c8840aa..79c88aef750 100644 --- a/ggml/src/ggml-threading.h +++ b/ggml/src/ggml-threading.h @@ -9,6 +9,29 @@ extern "C" { GGML_API void ggml_critical_section_start(void); GGML_API void ggml_critical_section_end(void); +// Parallel version of ggml_quantize_chunk. +// +// Splits [start, start + nrows * n_per_row) across n_threads worker threads, +// each calling ggml_quantize_chunk on its row range. Threads write to +// non-overlapping regions of dst, so no locking is required. +// +// Falls back to the single-threaded ggml_quantize_chunk when n_threads <= 1 +// or nrows <= 1. The primary motivation is iq4_nl, whose per-block NL search +// makes single-threaded throughput ~95x slower than other 4-bit types; this +// function recovers near-linear scaling with thread count. +// +// imatrix may be NULL for types that do not require it. +// Returns total bytes written (same contract as ggml_quantize_chunk). +GGML_API size_t ggml_quantize_chunk_mt( + enum ggml_type type, + const float * src, + void * dst, + int64_t start, + int64_t nrows, + int64_t n_per_row, + const float * imatrix, + int n_threads); + #ifdef __cplusplus } #endif