Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
146 commits
Select commit Hold shift + click to select a range
e90929a
Add library for IVF-RaBitQ
jamxia155 Nov 3, 2025
4a5c6dc
Add benchmarking executables for IVF-RaBitQ
jamxia155 Nov 3, 2025
ec88c25
Add --executable-dir option
tfeher Feb 12, 2025
bcc80e2
Add documentation about how to describe new datasets
tfeher Feb 12, 2025
4db552f
fix style
tfeher Feb 12, 2025
fbfe242
update docstring
tfeher Apr 1, 2025
5c53816
Do not prompt for executable-dir
tfeher Jul 9, 2025
285a9dd
Enable IVF-RaBitQ in cuvs_bench python wrapper
jamxia155 Nov 3, 2025
9f5a3d2
Use SPDX for copyright headers
jamxia155 Nov 3, 2025
133f808
Add documentation for 3rd-party dependency
jamxia155 Nov 3, 2025
09075bc
Add FAISS CPU IVF-RaBitQ algorithm
jamxia155 Nov 4, 2025
4884e89
Enable FAISS CPU IVF-RaBitQ in cuvs_bench python wrapper
jamxia155 Nov 5, 2025
cbd2a05
Rename parameter for consistency
jamxia155 Nov 6, 2025
7249455
Fix cuVS build issues with RaBitQ (#4)
Stardust-SJF Nov 12, 2025
1a26a71
Handle host and device data in build
jamxia155 Nov 12, 2025
d49bd0b
Merge remote-tracking branch 'Stardust-SJF_fork/jamxia_cuvs_ivf_rabit…
jamxia155 Nov 13, 2025
794b421
Disable separable compilation for IVF-RaBitQ code
jamxia155 Nov 13, 2025
f1fc50b
Brev benchmark instructions (#1)
tfeher Nov 13, 2025
928945b
Remove outdated instructions
jamxia155 Nov 13, 2025
9a5c0ef
Plumbing for passing raft handle to IVF-RaBitQ
jamxia155 Nov 18, 2025
389c917
Update rotator_gpu class
jamxia155 Nov 18, 2025
114d560
Migrate RotatorGPU class to RAFT
jamxia155 Nov 19, 2025
db8a437
Remove cuBLAS from RotatorGPU class
jamxia155 Nov 19, 2025
b263628
Remove Eigen dependency in `DataQuantizerGPU`
jamxia155 Nov 20, 2025
8404b99
Remove uses of Eigen library
jamxia155 Nov 20, 2025
035e978
Remove dependency `Eigen`
jamxia155 Nov 20, 2025
c73ef60
(WIP) Add tests for IVF-RaBitQ
jamxia155 Nov 21, 2025
6cac52d
Merge remote-tracking branch 'upstream/main' into jamxia_cuvs_ivf_rabitq
jamxia155 Nov 21, 2025
f03ec47
Replace header guards with `#pragma once`
jamxia155 Nov 21, 2025
f0be854
Add namespace
jamxia155 Nov 22, 2025
1c444c7
Add tests for IVF-RaBitQ
jamxia155 Nov 24, 2025
a680ca7
Check errors on CUDA API calls and kernel launches
jamxia155 Nov 24, 2025
596997d
Rename class member cudaStream_t in RotatorGPU
jamxia155 Nov 24, 2025
1cb8299
Migrate IVFGPU to async CUDA calls and launches
jamxia155 Nov 24, 2025
ef476ba
Avoid using default stream
jamxia155 Nov 24, 2025
a844b9b
Use async CUDA calls in InitializerGPU classes
jamxia155 Nov 25, 2025
aa0a195
Use async CUDA calls for device results pool
jamxia155 Nov 25, 2025
b2f8a7b
Use async calls in DataQuantizerGPU class
jamxia155 Nov 25, 2025
e2d5d1a
Use async CUDA calls in BatchedQueryGatherer class
jamxia155 Nov 25, 2025
c698e18
Use async calls in SearcherGPU class
jamxia155 Nov 25, 2025
35e9249
Add class members for resource handle and stream
jamxia155 Nov 26, 2025
c381ab2
Clean up member ownership and access in IVFGPU
jamxia155 Nov 26, 2025
abc657b
Use RAFT containers in RotatorGPU class
jamxia155 Nov 26, 2025
0883c37
Use RAFT containers in InitializerGPU class
jamxia155 Nov 26, 2025
c7ecd1c
Change mdarray index type to int64_t
jamxia155 Nov 27, 2025
55bb4fa
Use RAFT containers in IVFGPU class
jamxia155 Nov 27, 2025
f2d159a
Use RAFT containers in DeviceResultPool struct
jamxia155 Dec 2, 2025
cc28ed8
Use RAFT containers and smart pointers in SearcherGPU class
jamxia155 Dec 3, 2025
384da65
Move IVF-RaBitQ internal headers to cpp/src
jamxia155 Dec 3, 2025
0885531
Revert "Move IVF-RaBitQ internal headers to cpp/src"
jamxia155 Dec 3, 2025
6350521
Synchronize with the updates of IVF-RaBitQ-GPU. (#6)
Stardust-SJF Dec 3, 2025
0421efa
Move IVF-RaBitQ internal headers to cpp/src
jamxia155 Dec 3, 2025
984f2df
Merge remote-tracking branch 'Stardust-SJF_fork/jamxia_cuvs_ivf_rabit…
jamxia155 Dec 3, 2025
5ac201a
Fix a bug
jamxia155 Dec 3, 2025
b10ecae
Remove debug code
jamxia155 Dec 3, 2025
f0a61d6
Remove commented-out code
jamxia155 Dec 4, 2025
f0dc124
Fix memory leaks
jamxia155 Dec 4, 2025
f587114
Initialize elements in padded queries
jamxia155 Dec 4, 2025
3214bca
Add default initializations for class members
jamxia155 Dec 4, 2025
b4970f7
Remove unused declarations
jamxia155 Dec 4, 2025
d128515
Remove unused utils code
jamxia155 Dec 5, 2025
5bf3215
Only create padded queries matrix if needed
jamxia155 Dec 5, 2025
0548b54
Merge remote-tracking branch 'upstream/main' into jamxia_cuvs_ivf_rabitq
jamxia155 Dec 5, 2025
ef94d1f
Remove unused code in memory.hpp
jamxia155 Dec 5, 2025
14c9df5
Remove unused code in InitializerGPU class
jamxia155 Dec 5, 2025
5fc2e00
Remove BatchedQueryGatherer class
jamxia155 Dec 5, 2025
1df4385
Remove unused code from IVFGPU class
jamxia155 Dec 5, 2025
7bff107
Remove unused code in DataQuantizerGPU class
jamxia155 Dec 5, 2025
9ee467e
Remove unused code in SearcherGPU class
jamxia155 Dec 5, 2025
dcd36d3
Remove pool_gpu.cu/.cuh (no longer used)
jamxia155 Dec 5, 2025
322df78
Update API for 1-bit quantization support
jamxia155 Dec 5, 2025
8a3096e
Replace cuBLAS calls with templated RAFT wrapper
jamxia155 Dec 6, 2025
443f1ca
Set max dynamic shared mem size as needed
jamxia155 Dec 8, 2025
fb839c0
Support 1-bit search for all search modes (#7)
Stardust-SJF Dec 10, 2025
6aee332
Update API for 1-bit RaBitQ
jamxia155 Dec 10, 2025
e401067
Split up SearcherGPU impl
jamxia155 Dec 10, 2025
9402e9e
Consolidate kernel parameters
jamxia155 Dec 10, 2025
0d5d678
Remove a comment
jamxia155 Dec 11, 2025
80b9fac
Enable handling of large top-k value (up to 16384).
jamxia155 Dec 11, 2025
a43fda2
Check that topk value is below max supported.
jamxia155 Dec 16, 2025
35732b2
Remove commented code
jamxia155 Dec 16, 2025
078fb06
Merge remote-tracking branch 'upstream/main' into jamxia_cuvs_ivf_rabitq
jamxia155 Dec 18, 2025
eb1efbb
Updates after merge from upstream
jamxia155 Dec 18, 2025
1298494
Remove unnecessary code
jamxia155 Dec 18, 2025
3b4e71e
Compute max_cluster_length in load_transposed
jamxia155 Dec 18, 2025
a129b78
Add condition to updating threshold
jamxia155 Dec 22, 2025
475179a
Do not use block sort for large top-k
jamxia155 Dec 23, 2025
ccbd0ff
Fix a bug in thresholding
jamxia155 Dec 24, 2025
87809cf
Remove thresholding for some search code paths
jamxia155 Dec 24, 2025
444058b
Reduce atomics and shared mem in LUT32 search mode
jamxia155 Dec 26, 2025
144a9e9
Reduce atomics and shared mem in LUT16 search mode
jamxia155 Dec 27, 2025
e3af96e
Reduce atomics and shared mem in QUANT4/8 search
jamxia155 Dec 29, 2025
6848b3a
Remove unnecessary shared memory variables
jamxia155 Dec 29, 2025
07e7a88
Replace device mem allocation with RAFT containers
jamxia155 Jan 1, 2026
f93550b
Replace some raw allocations with RAFT containers
jamxia155 Jan 4, 2026
ada5c0b
Minimize allocation for intermediate output
jamxia155 Jan 6, 2026
c056510
Merge remote-tracking branch 'upstream/main' into jamxia_cuvs_ivf_rabitq
jamxia155 Jan 12, 2026
3453328
Rename library
jamxia155 Jan 13, 2026
5ae864b
Optimize recall calculation for large k
jamxia155 Jan 20, 2026
1c87662
Add build API accepting host dataset
jamxia155 Jan 21, 2026
c5e97e8
Enable subsampling of raw dataset for clustering
jamxia155 Feb 3, 2026
03e93db
Change default kmeans_trainset_fraction for tests
jamxia155 Feb 4, 2026
70068eb
Optimize recall calculation
jamxia155 Feb 4, 2026
7634ae3
Preallocate hashset if possible
jamxia155 Feb 5, 2026
29ecc55
Add API for returning index length
jamxia155 Feb 11, 2026
f6e2a86
Improve error reporting
jamxia155 Feb 11, 2026
cbb7154
Merge remote-tracking branch 'upstream/main' into jamxia_cuvs_ivf_rabitq
jamxia155 Feb 12, 2026
0f0c444
Update member declaration order
jamxia155 Feb 17, 2026
8ade5e9
Merge remote-tracking branch 'upstream/main' into jamxia_cuvs_ivf_rabitq
jamxia155 Feb 17, 2026
dfc0708
Merge remote-tracking branch 'upstream/main' into jamxia_cuvs_ivf_rabitq
jamxia155 Feb 22, 2026
bcb230d
Fix #includes
jamxia155 Feb 22, 2026
107659d
Use safely_launch_kernel_with_smem_size
jamxia155 Feb 22, 2026
f4dd5a8
Fix linalg/gemm include
tfeher Feb 25, 2026
1553d56
Revert "Update member declaration order"
jamxia155 Feb 25, 2026
d42b61f
Revert "Preallocate hashset if possible"
jamxia155 Feb 25, 2026
2fd96a1
Revert "Optimize recall calculation"
jamxia155 Feb 25, 2026
9503329
Revert "Optimize recall calculation for large k"
jamxia155 Feb 25, 2026
3ccf526
Refactor JIT LTO kernel generation (#1812)
KyleFromNVIDIA Feb 23, 2026
3f8b76c
Merge remote-tracking branch 'upstream/main' into cuvs_ivf_rabitq
jamxia155 Mar 3, 2026
fb32dcb
Remove feature branch-specific content
jamxia155 Mar 3, 2026
443aa93
Restore cuVS Bench CMakeLists
jamxia155 Mar 3, 2026
fb26176
Fix style check failures
jamxia155 Mar 3, 2026
994e951
Merge remote-tracking branch 'upstream/main' into cuvs_ivf_rabitq
jamxia155 Mar 4, 2026
5684f7b
Remove outdated algo name
jamxia155 Mar 18, 2026
bbeb058
Minor updates based on review comments
jamxia155 Mar 18, 2026
75b7b19
Merge remote-tracking branch 'upstream/main' into cuvs_ivf_rabitq
jamxia155 Mar 18, 2026
adad8bc
Update index build parameters
jamxia155 Mar 18, 2026
dd6b4ba
Update index build parameter default
jamxia155 Mar 18, 2026
088665b
Remove outdated parameter
jamxia155 Mar 18, 2026
04ae19c
Remove unnecessary overload
jamxia155 Mar 18, 2026
4496d93
Let `build` API return the built index
jamxia155 Mar 19, 2026
bef16e9
Merge remote-tracking branch 'upstream/main' into cuvs_ivf_rabitq
jamxia155 Mar 19, 2026
485eade
Revert change to unrelated file
jamxia155 Mar 19, 2026
9761933
Use public header for kmeans clustering
jamxia155 Mar 19, 2026
09c42a5
Remove unnecessary casting
jamxia155 Mar 19, 2026
8ec06e3
Add streaming construction of IVF-RaBitQ index
jamxia155 Mar 21, 2026
f00624b
Add force_streaming parameter to IVF-RaBitQ build
jamxia155 Mar 21, 2026
e8f8984
Expose force_streaming in benchmark configuration parser
jamxia155 Mar 21, 2026
9e76b26
Clarify force_streaming only applies to datasets on host
jamxia155 Mar 21, 2026
3543caa
Add forced streaming test for IVF-RaBitQ
jamxia155 Mar 22, 2026
24e1e3d
Code cleanup
jamxia155 Mar 23, 2026
e3ef24b
Refactor IVF-RaBitQ: Remove batch_flag and improve encapsulation
jamxia155 Mar 23, 2026
52793fc
Consolidate quantizer_gpu implementation into single file
jamxia155 Mar 23, 2026
69299d1
Replace CUDA memory calls with RMM/RAFT primitives
jamxia155 Apr 23, 2026
96d9460
Merge remote-tracking branch 'upstream/main' into cuvs_ivf_rabitq
jamxia155 Apr 23, 2026
e28491b
Fixes after merging in main branch
jamxia155 Apr 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,32 @@ if(NOT BUILD_CPU_ONLY)
)
target_link_libraries(cuvs_cpp_headers INTERFACE raft::raft rmm::rmm)

# Static library with the IVF-RaBitQ GPU index implementation (internal sources,
# linked into the main cuvs target below).
add_library(
  ivf_rabitq STATIC
  src/neighbors/ivf_rabitq/gpu_index/ivf_gpu.cu
  src/neighbors/ivf_rabitq/gpu_index/initializer_gpu.cu
  src/neighbors/ivf_rabitq/gpu_index/quantizer_gpu.cu
  src/neighbors/ivf_rabitq/gpu_index/rotator_gpu.cu
  src/neighbors/ivf_rabitq/gpu_index/searcher_gpu.cu
  src/neighbors/ivf_rabitq/gpu_index/searcher_gpu_shared_mem_opt.cu
  src/neighbors/ivf_rabitq/gpu_index/searcher_gpu_quantize_query.cu
  src/neighbors/ivf_rabitq/utils/searcher_gpu_utils.cu
)

# Dependencies are PRIVATE: consumers only need the public headers exposed below.
target_link_libraries(ivf_rabitq PRIVATE OpenMP::OpenMP_CXX CUDA::cudart raft::raft rmm)

# BUILD_INTERFACE only: this internal target is not installed/exported.
target_include_directories(
  ivf_rabitq PUBLIC "$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>"
)

# CUDA: device debug info in Debug builds, extended lambdas / relaxed constexpr,
# and OpenMP for both the host compiler invoked by nvcc and plain C++ TUs.
target_compile_options(
  ivf_rabitq
  PRIVATE $<$<COMPILE_LANGUAGE:CUDA>: $<$<CONFIG:Debug>:-G;-g> --extended-lambda
  --expt-relaxed-constexpr -Xcompiler=-fopenmp > $<$<COMPILE_LANGUAGE:CXX>:-fopenmp>
)

# Selects the higher-accuracy fast-scan code path in the IVF-RaBitQ sources.
target_compile_definitions(ivf_rabitq PRIVATE HIGH_ACC_FAST_SCAN)

generate_inst_matrix(
cagra_search_inst_files
MATRIX_JSON_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/cagra_search_matrix.json"
Expand Down Expand Up @@ -897,6 +923,7 @@ if(NOT BUILD_CPU_ONLY)
src/neighbors/ivf_pq/detail/ivf_pq_process_and_fill_codes.cu
${ivf_pq_search_inst_files}
${ivf_pq_transform_inst_files}
src/neighbors/ivf_rabitq.cu
src/neighbors/knn_merge_parts.cu
src/neighbors/nn_descent.cu
${nn_descent_inst_files}
Expand Down Expand Up @@ -1040,6 +1067,7 @@ if(NOT BUILD_CPU_ONLY)
$<COMPILE_ONLY:cuco::cuco>
CUDA::nvJitLink
CUDA::nvrtc
ivf_rabitq
)
set_property(TARGET cuvs PROPERTY NO_CUDART_DEP ON)

Expand Down
18 changes: 18 additions & 0 deletions cpp/bench/ann/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ option(CUVS_ANN_BENCH_USE_FAISS_CPU_IVF_FLAT "Include faiss' cpu ivf flat algori
ON
)
option(CUVS_ANN_BENCH_USE_FAISS_CPU_IVF_PQ "Include faiss' cpu ivf pq algorithm in benchmark" ON)
option(CUVS_ANN_BENCH_USE_FAISS_CPU_IVF_RABITQ
"Include faiss' cpu ivf rabitq algorithm in benchmark" OFF
)
option(CUVS_ANN_BENCH_USE_FAISS_CPU_HNSW_FLAT "Include faiss' hnsw algorithm in benchmark" ON)
option(CUVS_ANN_BENCH_USE_CUVS_IVF_FLAT "Include cuVS ivf flat algorithm in benchmark" ON)
option(CUVS_ANN_BENCH_USE_CUVS_IVF_PQ "Include cuVS ivf pq algorithm in benchmark" ON)
Expand All @@ -45,6 +48,7 @@ option(CUVS_ANN_BENCH_SINGLE_EXE
"Make a single executable with benchmark as shared library modules" OFF
)
option(CUVS_KNN_BENCH_USE_CUVS_BRUTE_FORCE "Include cuVS brute force knn in benchmark" ON)
option(CUVS_ANN_BENCH_USE_CUVS_IVF_RABITQ "Include cuVS ivf RaBitQ algorithm in benchmark" ON)

# ##################################################################################################
# * Process options ----------------------------------------------------------
Expand Down Expand Up @@ -244,6 +248,13 @@ if(CUVS_ANN_BENCH_USE_CUVS_IVF_FLAT)
)
endif()

if(CUVS_ANN_BENCH_USE_CUVS_IVF_RABITQ)
ConfigureAnnBench(
NAME CUVS_IVF_RABITQ PATH src/cuvs/cuvs_benchmark.cu src/cuvs/cuvs_ivf_rabitq.cu LINKS cuvs
ivf_rabitq
)
endif()

if(CUVS_ANN_BENCH_USE_CUVS_BRUTE_FORCE)
ConfigureAnnBench(NAME CUVS_BRUTE_FORCE PATH src/cuvs/cuvs_benchmark.cu LINKS cuvs)
endif()
Expand Down Expand Up @@ -309,6 +320,13 @@ if(CUVS_ANN_BENCH_USE_FAISS_CPU_IVF_PQ)
)
endif()

if(CUVS_ANN_BENCH_USE_FAISS_CPU_IVF_RABITQ)
ConfigureAnnBench(
NAME FAISS_CPU_IVF_RABITQ PATH src/faiss/faiss_cpu_benchmark.cpp LINKS ${CUVS_FAISS_TARGETS}
cuvs ivf_rabitq
)
endif()

if(CUVS_ANN_BENCH_USE_FAISS_CPU_HNSW_FLAT)
ConfigureAnnBench(
NAME FAISS_CPU_HNSW_FLAT PATH src/faiss/faiss_cpu_benchmark.cpp LINKS ${CUVS_FAISS_TARGETS}
Expand Down
47 changes: 47 additions & 0 deletions cpp/bench/ann/src/cuvs/cuvs_ann_bench_param_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ extern template class cuvs::bench::cuvs_ivf_pq<float, int64_t>;
extern template class cuvs::bench::cuvs_ivf_pq<uint8_t, int64_t>;
extern template class cuvs::bench::cuvs_ivf_pq<int8_t, int64_t>;
#endif
#if defined(CUVS_ANN_BENCH_USE_CUVS_IVF_RABITQ)
#include "cuvs_ivf_rabitq_wrapper.h"
#endif
#ifdef CUVS_ANN_BENCH_USE_CUVS_IVF_RABITQ
extern template class cuvs::bench::cuvs_ivf_rabitq<float, int64_t>;
#endif
#if defined(CUVS_ANN_BENCH_USE_CUVS_CAGRA) || defined(CUVS_ANN_BENCH_USE_CUVS_CAGRA_HNSWLIB) || \
defined(CUVS_ANN_BENCH_USE_CUVS_CAGRA_DISKANN)
#include "cuvs_cagra_wrapper.h"
Expand Down Expand Up @@ -178,6 +184,47 @@ void parse_search_param(const nlohmann::json& conf,
}
#endif

#if defined(CUVS_ANN_BENCH_USE_CUVS_IVF_RABITQ)
/// Populate IVF-RaBitQ build parameters from a benchmark JSON configuration.
/// Keys absent from `conf` leave the corresponding defaults in `param` untouched;
/// present keys override them (a type mismatch throws nlohmann::json::type_error,
/// same as the previous `at()`-based parsing).
template <typename T, typename IdxT>
void parse_build_param(const nlohmann::json& conf,
                       typename cuvs::bench::cuvs_ivf_rabitq<T, IdxT>::build_param& param)
{
  param.n_lists        = conf.value("nlist", param.n_lists);
  param.kmeans_n_iters = conf.value("niter", param.kmeans_n_iters);
  param.max_train_points_per_cluster =
    conf.value("max_points_per_cluster", param.max_train_points_per_cluster);
  param.bits_per_dim       = conf.value("bits_per_dim", param.bits_per_dim);
  param.fast_quantize_flag = conf.value("fast_quantize_flag", param.fast_quantize_flag);
  param.force_streaming    = conf.value("force_streaming", param.force_streaming);
}

/// Populate IVF-RaBitQ search parameters from a benchmark JSON configuration.
/// Recognized keys: "nprobe" (number of IVF lists probed per query) and "mode"
/// (one of "lut16", "lut32", "quant4", "quant8"); an unknown mode string throws.
template <typename T, typename IdxT>
void parse_search_param(const nlohmann::json& conf,
                        typename cuvs::bench::cuvs_ivf_rabitq<T, IdxT>::search_param& param)
{
  if (conf.contains("nprobe")) { param.rabitq_param.n_probes = conf.at("nprobe"); }

  if (conf.contains("mode")) {
    const std::string mode = conf.at("mode");
    using search_mode      = cuvs::neighbors::ivf_rabitq::search_mode;
    // Map the mode string to the enum; reject anything unrecognized.
    param.rabitq_param.mode = [&mode]() -> search_mode {
      if (mode == "lut16") { return search_mode::LUT16; }
      if (mode == "lut32") { return search_mode::LUT32; }
      if (mode == "quant4") { return search_mode::QUANT4; }
      if (mode == "quant8") { return search_mode::QUANT8; }
      throw std::runtime_error("mode: '" + mode +
                               "', should be either 'lut16', 'lut32', 'quant4' or 'quant8'");
    }();
  }
}
#endif

#if defined(CUVS_ANN_BENCH_USE_CUVS_CAGRA) || defined(CUVS_ANN_BENCH_USE_CUVS_CAGRA_HNSWLIB) || \
defined(CUVS_ANN_BENCH_USE_CUVS_MG) || defined(CUVS_ANN_BENCH_USE_CUVS_CAGRA_DISKANN)
template <typename T, typename IdxT>
Expand Down
21 changes: 20 additions & 1 deletion cpp/bench/ann/src/cuvs/cuvs_benchmark.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

Expand Down Expand Up @@ -91,6 +91,15 @@ auto create_algo(const std::string& algo_name,
a = std::make_unique<cuvs::bench::cuvs_ivf_pq<T, int64_t>>(metric, dim, param);
}
#endif
#ifdef CUVS_ANN_BENCH_USE_CUVS_IVF_RABITQ
if constexpr (std::is_same_v<T, float>) {
if (algo_name == "cuvs_ivf_rabitq") {
typename cuvs::bench::cuvs_ivf_rabitq<T, int64_t>::build_param param;
parse_build_param<T, int64_t>(conf, param);
a = std::make_unique<cuvs::bench::cuvs_ivf_rabitq<T, int64_t>>(metric, dim, param);
}
}
#endif
#ifdef CUVS_ANN_BENCH_USE_CUVS_CAGRA
if (algo_name == "raft_cagra" || algo_name == "cuvs_cagra") {
typename cuvs::bench::cuvs_cagra<T, uint32_t>::build_param param;
Expand Down Expand Up @@ -158,6 +167,16 @@ auto create_search_param(const std::string& algo_name, const nlohmann::json& con
return param;
}
#endif
#ifdef CUVS_ANN_BENCH_USE_CUVS_IVF_RABITQ
if constexpr (std::is_same_v<T, float>) {
if (algo_name == "cuvs_ivf_rabitq") {
auto param =
std::make_unique<typename cuvs::bench::cuvs_ivf_rabitq<T, int64_t>::search_param>();
parse_search_param<T, int64_t>(conf, *param);
return param;
}
}
#endif
#ifdef CUVS_ANN_BENCH_USE_CUVS_CAGRA
if (algo_name == "raft_cagra" || algo_name == "cuvs_cagra") {
auto param = std::make_unique<typename cuvs::bench::cuvs_cagra<T, uint32_t>::search_param>();
Expand Down
9 changes: 9 additions & 0 deletions cpp/bench/ann/src/cuvs/cuvs_ivf_rabitq.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/
#include "cuvs_ivf_rabitq_wrapper.h"

namespace cuvs::bench {
// Explicit instantiation of the benchmark wrapper for the only supported
// data/index type combination; matches the `extern template` declaration in
// cuvs_ann_bench_param_parser.h.
template class cuvs_ivf_rabitq<float, int64_t>;
}  // namespace cuvs::bench
166 changes: 166 additions & 0 deletions cpp/bench/ann/src/cuvs/cuvs_ivf_rabitq_wrapper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once

#include "../common/ann_types.hpp"
#include "cuvs_ann_bench_utils.h"

#include <cuvs/neighbors/ivf_rabitq.hpp>

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/linalg/unary_op.cuh>
#include <raft/util/cudart_utils.hpp>
#include <rmm/cuda_stream_pool.hpp>

#include <cassert>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>

namespace cuvs::bench {

/// Benchmark wrapper adapting the cuVS IVF-RaBitQ index to the ANN-bench
/// `algo`/`algo_gpu` interfaces. Only T=float, IdxT=int64_t is instantiated.
template <typename T, typename IdxT>
class cuvs_ivf_rabitq : public algo<T>, public algo_gpu {
 public:
  using search_param_base = typename algo<T>::search_param;
  using algo<T>::dim_;

  struct search_param : public search_param_base {
    // Parameters forwarded to cuvs::neighbors::ivf_rabitq::search.
    cuvs::neighbors::ivf_rabitq::search_params rabitq_param;
    // Refinement oversampling factor; the raw dataset is only needed when > 1.
    float refine_ratio = 1.0f;
    [[nodiscard]] auto needs_dataset() const -> bool override { return refine_ratio > 1.0f; }
  };

  using build_param = cuvs::neighbors::ivf_rabitq::index_params;

  cuvs_ivf_rabitq(Metric metric, int dim, const build_param& param)
    : algo<T>(metric, dim), index_params_(param), dimension_(dim)
  {
  }

  void build(const T* dataset, size_t nrow) final;

  void set_search_param(const search_param_base& param, const void* filter_bitset) override;
  void set_search_dataset(const T* dataset, size_t nrow) override;

  void search(const T* queries,
              int batch_size,
              int k,
              algo_base::index_type* neighbors,
              float* distances) const override;

  [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override
  {
    return handle_.get_sync_stream();
  }

  // Benchmark harness memory placement: the dataset is provided in host memory
  // (the IVF-RaBitQ build accepts host pointers), queries in device memory.
  [[nodiscard]] auto get_preference() const -> algo_property override
  {
    algo_property property;
    property.dataset_memory_type = MemoryType::kHost;
    property.query_memory_type   = MemoryType::kDevice;
    return property;
  }
  void save(const std::string& file) const override;
  void load(const std::string&) override;
  std::unique_ptr<algo<T>> copy() override;

 private:
  // handle_ must go first to make sure it dies last and all memory allocated in pool
  configured_raft_resources handle_{};
  build_param index_params_;
  cuvs::neighbors::ivf_rabitq::search_params search_params_;
  // Shared so that copy() clones share the (immutable after build) index.
  std::shared_ptr<cuvs::neighbors::ivf_rabitq::index<IdxT>> index_;
  int dimension_;
  // NOTE(review): refine_ratio_ and dataset_ are stored for refinement but do
  // not appear to be consumed in search() below — confirm refinement support.
  float refine_ratio_ = 1.0;
  // Non-owning view; the caller keeps the dataset alive (set_search_dataset).
  raft::device_matrix_view<const T, IdxT> dataset_;
};

/// Serialize the built index to `file` via the cuVS IVF-RaBitQ API.
/// Precondition: build() or load() was called (index_ is non-null).
template <typename T, typename IdxT>
void cuvs_ivf_rabitq<T, IdxT>::save(const std::string& file) const
{
  cuvs::neighbors::ivf_rabitq::serialize(handle_, file, *index_);
}

/// Deserialize an index from `file`, replacing any previously held index.
template <typename T, typename IdxT>
void cuvs_ivf_rabitq<T, IdxT>::load(const std::string& file)
{
  // Construct an empty index first; deserialize fills it in place.
  index_ = std::make_shared<cuvs::neighbors::ivf_rabitq::index<IdxT>>(handle_);
  cuvs::neighbors::ivf_rabitq::deserialize(handle_, file, index_.get());
}

/// Build the IVF-RaBitQ index over `nrow` vectors of dimension dim_ starting
/// at `dataset`, replacing any previously held index.
template <typename T, typename IdxT>
void cuvs_ivf_rabitq<T, IdxT>::build(const T* dataset, size_t nrow)
{
  // Create a CUDA stream pool with 1 stream (besides main stream) for kernel/copy overlapping.
  size_t n_streams = 1;
  raft::resource::set_cuda_stream_pool(handle_, std::make_shared<rmm::cuda_stream_pool>(n_streams));
  // Note: internally the IVF-RaBitQ build works with simple pointers, and accepts both host and
  // device pointer. Therefore, although we provide here a device_mdspan, this works with host
  // pointer too.
  auto dataset_v = raft::make_device_matrix_view<const T, IdxT>(dataset, IdxT(nrow), dim_);
  // build() returns the index by value; move-construct it into a fresh
  // shared_ptr and assign (replaces the previous make_shared(...).swap(index_)
  // dance, which also applied a redundant std::move to a prvalue).
  index_ = std::make_shared<cuvs::neighbors::ivf_rabitq::index<IdxT>>(
    cuvs::neighbors::ivf_rabitq::build(handle_, index_params_, dataset_v));
}

/// Clone this wrapper for concurrent benchmark threads. The copy constructor
/// copies the shared_ptr, so all clones share the same underlying index.
template <typename T, typename IdxT>
std::unique_ptr<algo<T>> cuvs_ivf_rabitq<T, IdxT>::copy()
{
  return std::make_unique<cuvs_ivf_rabitq<T, IdxT>>(*this);  // use copy constructor
}

/// Install the search parameters for subsequent search() calls.
/// @throws std::bad_cast if `param` is not a cuvs_ivf_rabitq::search_param.
template <typename T, typename IdxT>
void cuvs_ivf_rabitq<T, IdxT>::set_search_param(const search_param_base& param, const void*)
{
  // Bind by const reference: `auto sp = dynamic_cast<const search_param&>(...)`
  // copied the whole parameter struct on every call.
  const auto& sp = dynamic_cast<const search_param&>(param);
  search_params_ = sp.rabitq_param;
  refine_ratio_  = sp.refine_ratio;
  // Probing more lists than exist is a configuration error (debug-only check).
  assert(search_params_.n_probes <= index_params_.n_lists);
}

/// Record a non-owning device view of the raw dataset (requested by the
/// harness when search_param::needs_dataset() is true, i.e. refine_ratio > 1).
/// The caller keeps `dataset` alive for the lifetime of this view.
template <typename T, typename IdxT>
void cuvs_ivf_rabitq<T, IdxT>::set_search_dataset(const T* dataset, size_t nrow)
{
  dataset_ = raft::make_device_matrix_view<const T, IdxT>(dataset, nrow, index_->dim());
}

/// Run a batched k-NN search, writing `batch_size * k` neighbor ids and
/// distances into the caller-provided device buffers. Work is enqueued on the
/// handle's stream; the harness synchronizes via get_sync_stream().
template <typename T, typename IdxT>
void cuvs_ivf_rabitq<T, IdxT>::search(
  const T* queries, int batch_size, int k, algo_base::index_type* neighbors, float* distances) const
{
  static_assert(std::is_integral_v<algo_base::index_type>);
  static_assert(std::is_integral_v<IdxT>);

  // Compute the result count once in 64 bits: `batch_size * k` as int can
  // overflow (k may be as large as 16384).
  const auto n_results = static_cast<size_t>(batch_size) * static_cast<size_t>(k);

  // If IdxT matches the harness index type in width, search directly into the
  // caller's buffer; otherwise search into scratch storage and cast afterwards.
  IdxT* neighbors_idx;
  std::optional<rmm::device_uvector<IdxT>> neighbors_storage{std::nullopt};
  if constexpr (sizeof(IdxT) == sizeof(algo_base::index_type)) {
    neighbors_idx = reinterpret_cast<IdxT*>(neighbors);
  } else {
    neighbors_storage.emplace(n_results, raft::resource::get_cuda_stream(handle_));
    neighbors_idx = neighbors_storage->data();
  }

  auto queries_view =
    raft::make_device_matrix_view<const T, int64_t>(queries, batch_size, dimension_);
  auto neighbors_view = raft::make_device_matrix_view<IdxT, int64_t>(neighbors_idx, batch_size, k);
  auto distances_view = raft::make_device_matrix_view<float, int64_t>(distances, batch_size, k);

  cuvs::neighbors::ivf_rabitq::search(
    handle_, search_params_, *index_, queries_view, neighbors_view, distances_view);

  if constexpr (sizeof(IdxT) != sizeof(algo_base::index_type)) {
    // Element-wise cast of the scratch ids into the caller's index type.
    raft::linalg::unaryOp(neighbors,
                          neighbors_idx,
                          n_results,
                          raft::cast_op<algo_base::index_type>(),
                          raft::resource::get_cuda_stream(handle_));
  }
}
} // namespace cuvs::bench
Loading
Loading