-
Notifications
You must be signed in to change notification settings - Fork 184
Multi segment cagra search #2035
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
c6e880d
390adec
92dbcf9
82dcf71
fad76cb
22a2c8d
69c4771
49e5a14
29751c7
2fffcb8
d284bd0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,37 @@ | ||
| /* | ||
| * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
| #pragma once | ||
|
|
||
| #include <cuvs/core/c_api.h> | ||
| #include <dlpack/dlpack.h> | ||
|
|
||
| #ifdef __cplusplus | ||
| extern "C" { | ||
| #endif | ||
|
|
||
| /** | ||
| * @brief Select the k smallest values from a flat device array of n candidates. | ||
| * | ||
| * Treats `in_val` as a matrix of shape [1, n] and selects the `k` smallest | ||
| * float values. `out_idx` receives the int64 column positions of the selected | ||
| * values in [0, n), so the caller can recover per-segment identity as: | ||
| * | ||
| * segment_index = out_idx[j] / segment_k | ||
| * position_in_segment = out_idx[j] % segment_k | ||
| * | ||
| * @param[in] res cuvsResources_t handle | ||
| * @param[in] in_val DLManagedTensor* shape [1, n], float32, device memory | ||
| * @param[out] out_val DLManagedTensor* shape [1, k], float32, device memory | ||
| * @param[out] out_idx DLManagedTensor* shape [1, k], int64, device memory | ||
| * @return cuvsError_t | ||
| */ | ||
| cuvsError_t cuvsSelectK(cuvsResources_t res, | ||
| DLManagedTensor* in_val, | ||
| DLManagedTensor* out_val, | ||
| DLManagedTensor* out_idx); | ||
|
|
||
| #ifdef __cplusplus | ||
| } | ||
| #endif |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -689,6 +689,54 @@ extern "C" cuvsError_t cuvsCagraSearch(cuvsResources_t res, | |
| }); | ||
| } | ||
|
|
||
| extern "C" cuvsError_t cuvsCagraSearchMultiSegment(cuvsResources_t res, | ||
| cuvsCagraSearchParams_t params, | ||
| uint32_t num_segments, | ||
| cuvsCagraIndex_t* indices, | ||
| DLManagedTensor** queries, | ||
| DLManagedTensor** neighbors, | ||
| DLManagedTensor** distances) | ||
| { | ||
| return cuvs::core::translate_exceptions([=] { | ||
| RAFT_EXPECTS(num_segments > 0, "num_segments must be > 0"); | ||
| RAFT_EXPECTS(indices != nullptr && queries != nullptr && neighbors != nullptr && | ||
| distances != nullptr, | ||
| "All pointer arrays must be non-null"); | ||
|
|
||
| auto res_ptr = reinterpret_cast<raft::resources*>(res); | ||
| auto search_params = cuvs::neighbors::cagra::search_params(); | ||
| convert_c_search_params(*params, &search_params); | ||
|
|
||
| // Only float32 is supported for multi-segment search. | ||
| RAFT_EXPECTS( | ||
| indices[0]->dtype.code == kDLFloat && indices[0]->dtype.bits == 32, | ||
| "Multi-segment search only supports float32 indices"); | ||
|
|
||
| using T = float; | ||
| using IdxT = uint32_t; | ||
| using OutIdxT = uint32_t; | ||
| using DistanceT = float; | ||
| using IndexT = cuvs::neighbors::cagra::index<T, IdxT>; | ||
|
|
||
| std::vector<const IndexT*> idx_vec(num_segments); | ||
| std::vector<raft::device_matrix_view<const T, int64_t, raft::row_major>> q_vec(num_segments); | ||
| std::vector<raft::device_matrix_view<OutIdxT, int64_t, raft::row_major>> n_vec(num_segments); | ||
| std::vector<raft::device_matrix_view<DistanceT, int64_t, raft::row_major>> d_vec(num_segments); | ||
|
|
||
| for (uint32_t i = 0; i < num_segments; i++) { | ||
| RAFT_EXPECTS(indices[i] != nullptr && indices[i]->addr != 0, | ||
| "Index at position %u is null or not built", i); | ||
| idx_vec[i] = reinterpret_cast<const IndexT*>(indices[i]->addr); | ||
| q_vec[i] = cuvs::core::from_dlpack<std::remove_reference_t<decltype(q_vec[i])>>(queries[i]); | ||
| n_vec[i] = cuvs::core::from_dlpack<std::remove_reference_t<decltype(n_vec[i])>>(neighbors[i]); | ||
| d_vec[i] = cuvs::core::from_dlpack<std::remove_reference_t<decltype(d_vec[i])>>(distances[i]); | ||
| } | ||
|
Comment on lines
+710
to
+733
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Validate every segment before the Only 🤖 Prompt for AI Agents |
||
|
|
||
| cuvs::neighbors::cagra::search_multi_segment( | ||
| *res_ptr, search_params, idx_vec, q_vec, n_vec, d_vec); | ||
|
Comment on lines
+726
to
+736
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Reject mixed-distance metrics across segments. This API combines raw distances from all segments into one global ranking, so all indices must use the same metric. Right now nothing checks that 🤖 Prompt for AI Agents |
||
| }); | ||
| } | ||
|
|
||
| extern "C" cuvsError_t cuvsCagraMerge(cuvsResources_t res, | ||
| cuvsCagraIndexParams_t params, | ||
| cuvsCagraIndex_t* indices, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,42 @@ | ||
| /* | ||
| * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
| #include <cuvs/core/c_api.h> | ||
| #include "../core/exceptions.hpp" | ||
| #include <cuvs/selection/select_k.hpp> | ||
| #include <dlpack/dlpack.h> | ||
|
|
||
| #include <raft/core/device_mdspan.hpp> | ||
| #include <raft/core/resources.hpp> | ||
|
|
||
| extern "C" cuvsError_t cuvsSelectK(cuvsResources_t res, | ||
| DLManagedTensor* in_val, | ||
| DLManagedTensor* out_val, | ||
| DLManagedTensor* out_idx) | ||
| { | ||
| return cuvs::core::translate_exceptions([=] { | ||
| auto* res_ptr = reinterpret_cast<raft::resources*>(res); | ||
|
|
||
| int64_t n = in_val->dl_tensor.shape[1]; | ||
| int64_t k = out_val->dl_tensor.shape[1]; | ||
|
|
||
| auto in_view = raft::make_device_matrix_view<const float, int64_t, raft::row_major>( | ||
| static_cast<const float*>(in_val->dl_tensor.data), 1, n); | ||
|
|
||
| auto out_val_view = raft::make_device_matrix_view<float, int64_t, raft::row_major>( | ||
| static_cast<float*>(out_val->dl_tensor.data), 1, k); | ||
|
|
||
| auto out_idx_view = raft::make_device_matrix_view<int64_t, int64_t, raft::row_major>( | ||
| static_cast<int64_t*>(out_idx->dl_tensor.data), 1, k); | ||
|
|
||
| cuvs::selection::select_k( | ||
| *res_ptr, | ||
| in_view, | ||
| std::nullopt, // implicit positions [0, n) as in_idx | ||
| out_val_view, | ||
| out_idx_view, | ||
| true); // select_min = true (smallest distance = nearest neighbor) | ||
|
Comment on lines
+14
to
+40
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Validate the DLPack contract before dereferencing
🤖 Prompt for AI Agents |
||
| }); | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
Repository: rapidsai/cuvs
Length of output: 79
🏁 Script executed:
Repository: rapidsai/cuvs
Length of output: 1665
🏁 Script executed:
Repository: rapidsai/cuvs
Length of output: 269
🏁 Script executed:
Repository: rapidsai/cuvs
Length of output: 580
🏁 Script executed:
Repository: rapidsai/cuvs
Length of output: 2101
🏁 Script executed:
Repository: rapidsai/cuvs
Length of output: 200
🏁 Script executed:
Repository: rapidsai/cuvs
Length of output: 89
🏁 Script executed:
Repository: rapidsai/cuvs
Length of output: 2518
🏁 Script executed:
Repository: rapidsai/cuvs
Length of output: 39
🏁 Script executed:
Repository: rapidsai/cuvs
Length of output: 784
🏁 Script executed:
Repository: rapidsai/cuvs
Length of output: 331
🏁 Script executed:
Repository: rapidsai/cuvs
Length of output: 779
🏁 Script executed:
Repository: rapidsai/cuvs
Length of output: 76
🏁 Script executed:
Repository: rapidsai/cuvs
Length of output: 275
🏁 Script executed:
Repository: rapidsai/cuvs
Length of output: 665
🏁 Script executed:
Repository: rapidsai/cuvs
Length of output: 39
🏁 Script executed:
Repository: rapidsai/cuvs
Length of output: 1728
The async-memory resource owner cannot be
thread_localwhen this API changes the current resource globally.The implementation at
c/src/core/c_api.cpp:188storescuda_async_memory_resourceinthread_local async_mrand passes it tormm::mr::set_current_device_resource(), but the documentation explicitly states this function "will change the memory resource for the whole process" (line 235). This creates a critical lifetime mismatch:set_current_device_resource()is device-scoped (affecting all threads), then when the enabling thread exits, itsthread_local async_mris destroyed while still registered as the current resource, leaving RMM with a dangling pointer.The pool resource avoids this issue by passing temporary rvalues to
set_current_device_resource(), allowing RMM to manage the lifetime. Either makeasync_mrprocess/device-scoped (notthread_local), or narrow the documentation and implementation to clarify thread-local semantics.🤖 Prompt for AI Agents