Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import static com.nvidia.cuvs.lucene.ThreadLocalCuVSResourcesProvider.assertIsSupported;

import com.nvidia.cuvs.LibraryException;
import com.nvidia.cuvs.spi.CuVSProvider;
import java.io.IOException;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
Expand Down Expand Up @@ -39,6 +40,7 @@ public class CuVS2510GPUVectorsFormat extends KnnVectorsFormat {

static {
try {
CuVSProvider.provider().enableRMMAsyncMemory();
LUCENE_PROVIDER = LuceneProvider.getInstance("99");
FLAT_VECTORS_FORMAT =
LUCENE_PROVIDER.getLuceneFlatVectorsFormatInstance(DefaultFlatVectorScorer.INSTANCE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,8 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
new CagraSearchParams.Builder()
.withItopkSize(Math.max(collector.getiTopK(), topK))
.withSearchWidth(collector.getSearchWidth())
.withThreadBlockSize(collector.getThreadBlockSize())
.withAlgo(collector.getSearchAlgo())
.build();
} else {
// Setting itopK as topK because in any case iTopK should be ATLEAST equal to topK
Expand Down Expand Up @@ -600,6 +602,37 @@ private static void checkVersion(int versionMeta, int versionVectorData, IndexIn
}
}

/**
* Maps a local segment ordinal to a Lucene doc ID within this segment.
*
* <p>Used by {@link GPUKnnFloatVectorQuery} after a multi-segment GPU search to convert
* select_k result ordinals to doc IDs before adding {@code docBase}.
*
* @param field the vector field name
* @param ordinal the local ordinal returned by CAGRA
* @return the Lucene doc ID within this segment
* @throws IOException if the vector values cannot be read
*/
public int ordToDoc(String field, int ordinal) throws IOException {
return flatVectorsReader.getFloatVectorValues(field).ordToDoc(ordinal);
}

/**
* Returns the {@link CagraIndex} for the given field, or {@code null} if unavailable
* (e.g., during a merge or when the field is missing).
*
* @param field the vector field name
* @return the CAGRA index, or {@code null}
*/
public CagraIndex getCagraIndexForField(String field) {
if (cuvsIndices == null) return null;
FieldInfo info = fieldInfos.fieldInfo(field);
if (info == null) return null;
GPUIndex gpuIndex = cuvsIndices.get(info.number);
if (gpuIndex == null) return null;
return gpuIndex.getCagraIndex();
}

/**
* Gets the instance of FieldInfos.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,11 @@ public HnswIndex hnswIndexFromCagra(HnswIndexParams arg0, CagraIndex arg1) throw
return delegate.hnswIndexFromCagra(arg0, arg1);
}

@Override
public void enableRMMAsyncMemory() {
delegate.enableRMMAsyncMemory();
}

@Override
public void enableRMMManagedPooledMemory(int arg0, int arg1) {
delegate.enableRMMManagedPooledMemory(arg0, arg1);
Expand Down
Loading