diff --git a/.github/workflows/run-compaction.yml b/.github/workflows/run-compaction.yml new file mode 100644 index 000000000..42efe907d --- /dev/null +++ b/.github/workflows/run-compaction.yml @@ -0,0 +1,117 @@ +name: Run Compaction Bench + +on: + workflow_dispatch: + inputs: + dataset: + description: 'Dataset name passed to CompactorBenchmark (-p datasetNames)' + required: false + default: 'ada002-100k' + branches: + description: 'Space-separated list of branches to benchmark' + required: false + default: 'main' + pull_request: + types: [opened, synchronize, ready_for_review] + branches: + - main + paths: + - '**/src/main/java/**' + - 'pom.xml' + - '**/pom.xml' + +jobs: + # Job to generate the matrix configuration + generate-matrix: + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - name: Generate matrix + id: set-matrix + run: | + if [[ "${{ github.event_name }}" == "pull_request" ]]; then + BRANCHES='["main", "${{ github.head_ref }}"]' + elif [[ "${{ github.event_name }}" == "workflow_dispatch" && -n "${{ github.event.inputs.branches }}" ]]; then + BRANCHES_INPUT="${{ github.event.inputs.branches }}" + BRANCHES="[" + for branch in $BRANCHES_INPUT; do + if [[ "$BRANCHES" != "[" ]]; then + BRANCHES="$BRANCHES, " + fi + BRANCHES="$BRANCHES\"$branch\"" + done + BRANCHES="$BRANCHES]" + else + BRANCHES='["main"]' + fi + + echo "matrix={\"jdk\":[24],\"isa\":[\"isa-avx512f\"],\"branch\":$BRANCHES}" >> $GITHUB_OUTPUT + + test-compaction: + needs: generate-matrix + strategy: + matrix: ${{ fromJSON(needs.generate-matrix.outputs.matrix) }} + runs-on: ${{ matrix.isa }} + steps: + - name: Set up GCC + run: sudo apt install -y gcc + - uses: actions/checkout@v4 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v3 + with: + java-version: ${{ matrix.jdk }} + distribution: temurin + cache: maven + + - name: Checkout branch + uses: actions/checkout@v4 + with: + ref: ${{ matrix.branch }} + fetch-depth: 0 + + - name: 
Build branch + run: mvn -B -Punix-amd64-profile package --file pom.xml + + - name: Run CompactorBenchmark + id: run-benchmark + run: | + TOTAL_MEM_GB=$(free -g | awk '/^Mem:/ {print $2}') + if [[ -z "$TOTAL_MEM_GB" ]] || [[ "$TOTAL_MEM_GB" -le 0 ]]; then + TOTAL_MEM_GB=16 + fi + HALF_MEM_GB=$((TOTAL_MEM_GB / 2)) + if [[ "$HALF_MEM_GB" -lt 1 ]]; then + HALF_MEM_GB=1 + fi + + DATASET="${{ github.event.inputs.dataset }}" + if [[ -z "$DATASET" ]]; then + DATASET="ada002-100k" + fi + + SAFE_BRANCH=$(echo "${{ matrix.branch }}" | sed 's/[^A-Za-z0-9_-]/_/g') + echo "safe_branch=$SAFE_BRANCH" >> $GITHUB_OUTPUT + + JMH_JAR=$(ls benchmarks-jmh/target/benchmarks-jmh-*.jar | grep -Ev -- '-(javadoc|sources)\.jar$' | head -1) + echo "Using JMH jar: $JMH_JAR" + + java --enable-native-access=ALL-UNNAMED --add-modules=jdk.incubator.vector \ + -Djvector.experimental.enable_native_vectorization=true \ + -Xmx${HALF_MEM_GB}g \ + -cp "$JMH_JAR" \ + io.github.jbellis.jvector.bench.CompactorBenchmark \ + -p workloadMode=PARTITION_AND_COMPACT \ + -p datasetNames=$DATASET \ + -p numPartitions=4 \ + -p splitDistribution=FIBONACCI \ + -p indexPrecision=FUSEDPQ \ + -jvmArgsPrepend "-Xmx${HALF_MEM_GB}g" \ + -wi 0 -i 1 -f 1 + + - name: Upload compaction results + uses: actions/upload-artifact@v4 + with: + name: compaction-results-${{ matrix.isa }}-jdk${{ matrix.jdk }}-${{ steps.run-benchmark.outputs.safe_branch }} + path: target/benchmark-results/compactor-*/compactor-results.jsonl + if-no-files-found: warn diff --git a/.gitignore b/.gitignore index 4b5599f84..70cede6f9 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,10 @@ local/ dataset_ **/local_datasets/** +### Testing Results +**results**.json +**results**.jsonl + ### Bench caches pq_cache/ index_cache/ diff --git a/benchmarks-jmh/pom.xml b/benchmarks-jmh/pom.xml index c82ee2707..05fe36793 100644 --- a/benchmarks-jmh/pom.xml +++ b/benchmarks-jmh/pom.xml @@ -15,6 +15,9 @@ UTF-8 22 1.37 + 2.21.10 + + @@ -53,6 +56,11 @@ log4j-slf4j2-impl 
2.24.3 + + software.amazon.awssdk + ec2 + ${awssdk.version} + @@ -94,6 +102,35 @@ + + + org.codehaus.mojo + exec-maven-plugin + + + compactor + + exec + + + false + java + --enable-native-access=ALL-UNNAMED --add-modules=jdk.incubator.vector -Djvector.experimental.enable_native_vectorization=true -cp %classpath io.github.jbellis.jvector.bench.CompactorBenchmark ${args} + + + + analyze + + exec + + + false + java + -cp %classpath io.github.jbellis.jvector.bench.benchtools.EventLogAnalyzer ${args} + + + + - \ No newline at end of file + diff --git a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/CompactorBenchmark.java b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/CompactorBenchmark.java new file mode 100644 index 000000000..597cfe40e --- /dev/null +++ b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/CompactorBenchmark.java @@ -0,0 +1,1036 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.github.jbellis.jvector.bench; + +import io.github.jbellis.jvector.bench.benchtools.BenchmarkParamCounter; +import io.github.jbellis.jvector.disk.ReaderSupplier; +import io.github.jbellis.jvector.disk.ReaderSupplierFactory; +import io.github.jbellis.jvector.example.benchmarks.datasets.DataSet; +import io.github.jbellis.jvector.example.benchmarks.datasets.DataSetInfo; +import io.github.jbellis.jvector.example.benchmarks.datasets.DataSets; +import io.github.jbellis.jvector.example.reporting.GitInfo; +import io.github.jbellis.jvector.example.reporting.JfrRecorder; +import io.github.jbellis.jvector.example.reporting.JsonlWriter; +import io.github.jbellis.jvector.example.reporting.SystemStatsCollector; +import io.github.jbellis.jvector.example.reporting.ThreadAllocTracker; +import io.github.jbellis.jvector.example.util.AccuracyMetrics; +import io.github.jbellis.jvector.example.util.DataSetPartitioner; +import io.github.jbellis.jvector.example.util.storage.CloudStorageLayoutUtil; +import io.github.jbellis.jvector.example.yaml.TestDataPartition; +import io.github.jbellis.jvector.graph.*; +import io.github.jbellis.jvector.graph.disk.AbstractGraphIndexWriter; +import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; +import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndexCompactor; +import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndexWriter; +import io.github.jbellis.jvector.graph.disk.OnDiskParallelGraphIndexWriter; +import io.github.jbellis.jvector.graph.disk.OrdinalMapper; +import io.github.jbellis.jvector.graph.disk.feature.Feature; +import io.github.jbellis.jvector.graph.disk.feature.FeatureId; +import io.github.jbellis.jvector.graph.disk.feature.FusedPQ; +import io.github.jbellis.jvector.graph.disk.feature.InlineVectors; +import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; +import io.github.jbellis.jvector.graph.similarity.DefaultSearchScoreProvider; +import 
io.github.jbellis.jvector.graph.similarity.SearchScoreProvider; +import io.github.jbellis.jvector.quantization.PQVectors; +import io.github.jbellis.jvector.quantization.ProductQuantization; +import io.github.jbellis.jvector.util.Bits; +import io.github.jbellis.jvector.util.FixedBitSet; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.vector.VectorizationProvider; +import io.github.jbellis.jvector.vector.types.VectorFloat; +import org.openjdk.jmh.annotations.*; +import org.openjdk.jmh.infra.Blackhole; +import org.openjdk.jmh.results.format.ResultFormatType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.channels.OverlappingFileLockException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.time.Instant; +import java.util.*; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.IntFunction; + +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@State(Scope.Thread) +@Fork(1) +@Warmup(iterations = 0) +@Measurement(iterations = 1) +@Threads(1) +public class CompactorBenchmark { + + // RUN_DIR must be initialized before the Logger so log4j2's File appender + // can resolve ${sys:jvector.internal.runDir} + private static final Path RUN_DIR; + static { + String runDir = System.getProperty("jvector.internal.runDir"); + if (runDir == null) { + runDir = Path.of("target", "benchmark-results", "compactor-" + Instant.now().getEpochSecond()).toString(); + System.setProperty("jvector.internal.runDir", runDir); + } + RUN_DIR = Path.of(runDir); + try { + Files.createDirectories(RUN_DIR); + } catch (IOException e) { + throw new RuntimeException("Failed to create run 
directory: " + RUN_DIR, e); + } + } + + private static final Logger log = LoggerFactory.getLogger(CompactorBenchmark.class); + + public enum IndexPrecision { + FULLPRECISION, + FUSEDPQ + } + + public enum WorkloadMode { + /** + * Build per-source partitions and stop. (No compaction, no recall.) + */ + PARTITION_ONLY, + + /** + * Assume partitions exist on disk; compact them. + */ + COMPACT_ONLY, + + /** + * Assume partitions exist on disk; compact them, then run recall. + */ + COMPACT_AND_RECALL, + + /** + * Build a single graph for the whole dataset and write it. Then run recall. + */ + BUILD_FROM_SCRATCH, + /** + * (Default) Build partitions, compact them, then run recall. + */ + PARTITION_AND_COMPACT + } + + private static final Path RESULTS_FILE = RUN_DIR.resolve("compactor-results.jsonl"); + private static final Path JFR_DIR = RUN_DIR.resolve("jfrs"); + private static final Path SYSTEM_DIR = RUN_DIR.resolve("system"); + private static final JsonlWriter jsonlWriter = new JsonlWriter(RESULTS_FILE); + + // In the forked JVM, main() passes the computed total via this internal property + private static final int TOTAL_TESTS = Integer.getInteger( + "jvector.internal.totalTests", + BenchmarkParamCounter.computeTotalTests(CompactorBenchmark.class, null) + ); + + private static final AtomicLong LAST_TEST_ID = new AtomicLong(0); + private static final String TEST_ID = generateTestId(); + + /** + * Generates a lexicographically sortable test ID: base36-encoded milliseconds + * followed by 2 base36 suffix chars (starting at "00"). Uses an atomic counter + * so that IDs generated within the same millisecond auto-increment instead of colliding. + */ + static String generateTestId() { + long candidate = System.currentTimeMillis() * 1296; // suffix starts at 00 + long actual = LAST_TEST_ID.updateAndGet(last -> Math.max(candidate, last + 1)); + return Long.toString(actual, 36); + } + + /** + * Returns a Bits instance representing randomly selected live nodes. 
+ * + */ + private static FixedBitSet randomLiveNodes(int size, double liveRate, long seed) { + FixedBitSet live = new FixedBitSet(size); + + if (liveRate >= 1.0) { + live.set(0, size); // all nodes live + return live; + } + + var rnd = new java.util.SplittableRandom(seed); + int liveCount = 0; + + for (int i = 0; i < size; i++) { + if (rnd.nextDouble() < liveRate) { + live.set(i); + liveCount++; + } + } + + // avoid degenerate case (all dead) + if (liveCount == 0 && size > 0) { + live.set(rnd.nextInt(size)); + } + + return live; + } + + private static final Path COUNTER_FILE = RUN_DIR.resolve("completed-count"); + private static final AtomicInteger completedTests = new AtomicInteger(readCompletedCount()); + + /** + * Read the completed test count from a dedicated counter file. + * Each JMH fork is a fresh JVM, so this file provides cross-fork continuity. + * Acquires an exclusive file lock and throws if another process holds it, + * since concurrent benchmark runs against the same RUN_DIR are not supported. + */ + private static int readCompletedCount() { + if (!Files.exists(COUNTER_FILE)) { + return 0; + } + try (var ch = FileChannel.open(COUNTER_FILE, StandardOpenOption.READ)) { + var lock = ch.tryLock(0, Long.MAX_VALUE, true); + if (lock == null) { + throw new IllegalStateException( + "Counter file is locked by another process — concurrent benchmark runs sharing " + + RUN_DIR + " are not supported"); + } + try { + return Integer.parseInt(Files.readString(COUNTER_FILE).trim()); + } finally { + lock.release(); + } + } catch (OverlappingFileLockException e) { + throw new IllegalStateException( + "Counter file is locked by another thread — concurrent benchmark runs sharing " + + RUN_DIR + " are not supported", e); + } catch (IllegalStateException e) { + throw e; + } catch (Exception e) { + // Fall back to 0 for parse errors, etc. 
+ return 0; + } + } + + private static void writeCompletedCount(int count) { + try (var ch = FileChannel.open(COUNTER_FILE, + StandardOpenOption.CREATE, StandardOpenOption.WRITE, StandardOpenOption.TRUNCATE_EXISTING)) { + var lock = ch.tryLock(); + if (lock == null) { + throw new IllegalStateException( + "Counter file is locked by another process — concurrent benchmark runs sharing " + + RUN_DIR + " are not supported"); + } + try { + ch.write(ByteBuffer.wrap(String.valueOf(count).getBytes(StandardCharsets.UTF_8))); + } finally { + lock.release(); + } + } catch (OverlappingFileLockException e) { + throw new IllegalStateException( + "Counter file is locked by another thread — concurrent benchmark runs sharing " + + RUN_DIR + " are not supported", e); + } catch (IllegalStateException e) { + throw e; + } catch (IOException e) { + log.error("Failed to write completed count", e); + } + } + + private static final AtomicInteger workerCounter = new AtomicInteger(0); + + // ---------- Benchmark state ---------- + private RandomAccessVectorValues ravv; + private List> queryVectors; + private List> baseVectors; + private List> groundTruth; + private DataSet ds; + private VectorSimilarityFunction similarityFunction; + + private final List graphs = new ArrayList<>(); + private final List rss = new ArrayList<>(); + + private Path tempDir; + private List storagePaths; + private List vectorsPerSourceCount; + private String resolvedVectorizationProvider; + + // Paths used during execution + private Path partitionsBaseDir; // where per-source partitions are placed (or found) + private Path compactOutputPath; // where compacted graph is written + private Path scratchOutputPath; // where build-from-scratch graph is written + + // ---------- Params ---------- + @Param({"glove-100-angular"}) + public String datasetNames; + + @Param({"PARTITION_AND_COMPACT"}) + public WorkloadMode workloadMode; + + @Param({"4"}) // Default value, can be overridden via command line + public int 
numPartitions; + + @Param({"32"}) + public int graphDegree; + + @Param({"100"}) + public int beamWidth; + + /** + * liveNodesRate controls how many nodes are considered "live" per source partition + * when calling compactor.setLiveNodes(...). + * + * - 1.0 => all nodes live (default; behaves like no deletions) + * - 0.8 => ~80% live, ~20% deleted (randomly selected) + */ + @Param({"1.0"}) + public double liveNodesRate; + + @Param({""}) + public String storageDirectories; + + @Param({""}) + public String storageClasses; + + @Param({"UNIFORM"}) + public TestDataPartition.Distribution splitDistribution; + + @Param({"FULLPRECISION"}) + public IndexPrecision indexPrecision; + + @Param({"1"}) + public int parallelWriteThreads; + + @Param({""}) + public String vectorizationProvider; + + @Param({"1.0"}) + public double datasetPortion; + + @Param({"false"}) + public boolean jfrPartitioning; + + @Param({"true"}) + public boolean jfrCompacting; + + @Param({"false"}) + public boolean jfrObjectCount; + + @Param({"true"}) + public boolean sysStatsEnabled; + + @Param({"false"}) + public boolean threadAllocTracking; + + private final JfrRecorder jfrPartitioningRecorder = new JfrRecorder(); + private final JfrRecorder jfrCompactingRecorder = new JfrRecorder(); + private final SystemStatsCollector sysStatsCollector = new SystemStatsCollector(); + private final ThreadAllocTracker threadAllocTracker = new ThreadAllocTracker(); + + private volatile boolean resultPersisted; + + @State(Scope.Thread) + @AuxCounters(AuxCounters.Type.EVENTS) + public static class RecallResult { + public double recall; + } + + private String jfrParamSuffix() { + return String.format("%s-w%s-n%d-d%d-bw%d-%s-%s-pw%d-%s-dp%.2f-live%.2f", + datasetNames, workloadMode, numPartitions, graphDegree, beamWidth, + splitDistribution, indexPrecision, parallelWriteThreads, resolvedVectorizationProvider, datasetPortion, liveNodesRate); + } + + @Setup(Level.Iteration) + public void setup() throws Exception { + try { + 
resultPersisted = false; + Thread.currentThread().setName("compactor-" + workerCounter.incrementAndGet()); + + if (vectorizationProvider != null && !vectorizationProvider.isBlank()) { + System.setProperty("jvector.vectorization_provider", vectorizationProvider); + } + resolvedVectorizationProvider = VectorizationProvider.getInstance().getClass().getSimpleName(); + + if (sysStatsEnabled) { + String sysStatsFileName = String.format("sysstats-%s.jsonl", jfrParamSuffix()); + try { + sysStatsCollector.start(SYSTEM_DIR, sysStatsFileName); + } catch (Exception e) { + log.warn("Failed to start system stats collection", e); + } + } + + if (threadAllocTracking) { + String threadAllocFileName = String.format("threadalloc-%s.jsonl", jfrParamSuffix()); + try { + threadAllocTracker.start(SYSTEM_DIR, threadAllocFileName); + } catch (Exception e) { + log.warn("Failed to start thread allocation tracking", e); + } + } + + persistStarted(); + + validateParams(); + + int dimension; + + if (workloadMode == WorkloadMode.COMPACT_ONLY) { + ds = null; + queryVectors = null; + groundTruth = null; + ravv = null; + baseVectors = null; + dimension = -1; + + var datasetInfo = DataSets.loadDataSet(datasetNames); + similarityFunction = datasetInfo + .flatMap(DataSetInfo::similarityFunction) + .orElseGet(() -> { + log.warn("Could not determine similarity function for dataset '{}'; defaulting to COSINE", datasetNames); + return VectorSimilarityFunction.COSINE; + }); + + log.info("Skipping dataset load for COMPACT_ONLY mode without recall. 
Workload: {}, similarityFunction: {}, Live nodes rate: {}", + workloadMode, similarityFunction, liveNodesRate); + } else { + ds = DataSets.loadDataSet(datasetNames) + .orElseThrow(() -> new RuntimeException("Dataset not found: " + datasetNames)) + .getDataSet(); + + if (datasetPortion == 1.0) { + ravv = ds.getBaseRavv(); + baseVectors = ds.getBaseVectors(); + } else { + int totalVectors = ds.getBaseRavv().size(); + int portionedSize = (int) (totalVectors * datasetPortion); + if (portionedSize < Math.max(1, numPartitions)) { + throw new IllegalArgumentException( + "datasetPortion=" + datasetPortion + " yields " + portionedSize + + " vectors, fewer than numPartitions=" + numPartitions); + } + baseVectors = ds.getBaseVectors().subList(0, portionedSize); + ravv = new ListRandomAccessVectorValues(baseVectors, ds.getDimension()); + } + + queryVectors = ds.getQueryVectors(); + groundTruth = ds.getGroundTruth(); + similarityFunction = ds.getSimilarityFunction(); + dimension = ds.getDimension(); + + log.info("Dataset {} loaded with recall data. Base vectors: {} (portion {}), Query vectors: {}, Dim: {}, Similarity: {}, Workload: {}, Live nodes rate: {}", + datasetNames, ravv.size(), datasetPortion, queryVectors.size(), dimension, similarityFunction, workloadMode, liveNodesRate); + } + + // Resolve storagePaths + partitionsDir + storagePaths = resolveStoragePaths(); + partitionsBaseDir = resolvePartitionsBaseDir(storagePaths); + compactOutputPath = resolveCompactOutputPath(partitionsBaseDir); + scratchOutputPath = resolveScratchOutputPath(partitionsBaseDir); + + // Clean stale artifacts only if we're going to rebuild them. + if (workloadMode == WorkloadMode.COMPACT_ONLY || workloadMode == WorkloadMode.COMPACT_AND_RECALL) { + // For compact-only and compact-and-recall, ensure the partition files exist. 
+ verifyPartitionsExist(partitionsBaseDir, numPartitions); + } + + // Partition metadata for remapping (needed for compaction) + if (workloadMode == WorkloadMode.PARTITION_ONLY || workloadMode == WorkloadMode.PARTITION_AND_COMPACT) { + var partitionedData = DataSetPartitioner.partition(baseVectors, numPartitions, splitDistribution); + vectorsPerSourceCount = partitionedData.sizes; + } else { + vectorsPerSourceCount = null; + } + + // Build partitions during setup for SEGMENTS_* (matches original benchmark structure) + if (workloadMode == WorkloadMode.PARTITION_ONLY || workloadMode == WorkloadMode.PARTITION_AND_COMPACT) { + if (jfrPartitioning) { + jfrPartitioningRecorder.start(JFR_DIR, "partitioning-" + jfrParamSuffix() + ".jfr", jfrObjectCount); + } + buildPartitions(ds, baseVectors); + if (jfrPartitioningRecorder.isActive()) { + jfrPartitioningRecorder.stop(); + } + } + + } catch (Exception e) { + persistError(e); + throw e; + } + } + + private void validateParams() { + if (workloadMode == WorkloadMode.BUILD_FROM_SCRATCH) { + log.warn("numPartitions={} ignored in BUILD_FROM_SCRATCH mode", numPartitions); + } + else { + if (numPartitions <= 1) throw new IllegalArgumentException("numPartitions must be larger than one"); + } + if (graphDegree <= 0) throw new IllegalArgumentException("graphDegree must be positive"); + if (beamWidth <= 0) throw new IllegalArgumentException("beamWidth must be positive"); + if (datasetPortion <= 0.0 || datasetPortion > 1.0) { + throw new IllegalArgumentException("datasetPortion must be in (0.0, 1.0]"); + } + if (liveNodesRate <= 0.0 || liveNodesRate > 1.0) { + throw new IllegalArgumentException("liveNodesRate must be in (0.0, 1.0]"); + } + } + + private List resolveStoragePaths() throws IOException { + // Priority: + // 1) storageDirectories (comma-separated) + // 3) temp dir + var paths = new ArrayList(); + + if (storageDirectories != null && !storageDirectories.isBlank()) { + for (String dir : storageDirectories.split(",")) { + Path 
path = Path.of(dir.trim()); + if (!Files.exists(path)) Files.createDirectories(path); + if (!Files.isDirectory(path) || !Files.isWritable(path)) { + throw new IllegalArgumentException("Path is not a writable directory: " + dir); + } + paths.add(path); + } + } else { + tempDir = Files.createTempDirectory("compact-bench"); + paths.add(tempDir); + } + + // Handle storage class validation + if (storageClasses != null && !storageClasses.isBlank()) { + String[] classes = storageClasses.split(","); + if (classes.length != paths.size()) { + throw new IllegalArgumentException(String.format( + "Mismatch between number of storage classes (%d) and storage directories (%d). They must be pairwise 1:1.", + classes.length, paths.size())); + } + + var actualStorageClasses = CloudStorageLayoutUtil.storageClassByMountPoint(); + for (int i = 0; i < paths.size(); i++) { + Path path = paths.get(i).toAbsolutePath(); + CloudStorageLayoutUtil.StorageClass expected; + try { + expected = CloudStorageLayoutUtil.StorageClass.valueOf(classes[i].trim()); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Invalid StorageClass: " + classes[i], e); + } + + String bestMount = null; + for (String mountPoint : actualStorageClasses.keySet()) { + if (path.toString().startsWith(mountPoint)) { + if (bestMount == null || mountPoint.length() > bestMount.length()) { + bestMount = mountPoint; + } + } + } + + if (bestMount != null) { + CloudStorageLayoutUtil.StorageClass actual = actualStorageClasses.get(bestMount); + if (actual != expected) { + throw new IllegalStateException(String.format( + "Storage class mismatch for path %s: expected %s, found %s (mount: %s)", + path, expected, actual, bestMount)); + } + } else { + log.warn("Could not determine storage class for path {}. 
Skipping validation.", path); + } + } + } + + return paths; + } + + private Path resolvePartitionsBaseDir(List storagePaths) throws IOException { + Path p = storagePaths.get(0); + Files.createDirectories(p); + return p; + } + + private Path resolveCompactOutputPath(Path baseDir) { + return baseDir.resolve("compact-graph"); + } + + private Path resolveScratchOutputPath(Path baseDir) { + return baseDir.resolve("scratch-graph"); + } + + private void verifyPartitionsExist(Path partitionsDir, int numPartitions) { + for (int i = 0; i < numPartitions; i++) { + Path seg = partitionsDir.resolve("per-source-graph-" + i); + if (!Files.exists(seg)) { + throw new IllegalStateException("Missing partition file for COMPACT_ONLY or COMPACT_AND_RECALL: " + seg.toAbsolutePath()); + } + } + } + + private void buildPartitions(DataSet ds, List> baseVectors) throws Exception { + + var partitionedData = DataSetPartitioner.partition(baseVectors, numPartitions, splitDistribution); + vectorsPerSourceCount = partitionedData.sizes; + + log.info("Building {} partitions into {} (deg={}, bw={}, split={}, splitSizes={}, precision={}, pwThreads={}, vp={})", + numPartitions, partitionsBaseDir.toAbsolutePath(), graphDegree, beamWidth, splitDistribution, vectorsPerSourceCount, + indexPrecision, parallelWriteThreads, resolvedVectorizationProvider); + + int dimension = baseVectors.get(0).length(); + for (int i = 0; i < numPartitions; i++) { + List> vectorsPerSource = partitionedData.vectors.get(i); + + // Round-robin assignment of partition files to storage paths, but still keep canonical base dir name stable. 
+ Path baseDirForThisSegment = storagePaths.get(i % storagePaths.size()); + Path outputPath = baseDirForThisSegment.resolve("per-source-graph-" + i); + if (Files.exists(outputPath)) { + Files.delete(outputPath); + } + + log.info("Building partition {}/{}: vectors={} -> {}", + i + 1, numPartitions, vectorsPerSource.size(), outputPath.toAbsolutePath()); + + var ravvPerSource = new ListRandomAccessVectorValues(vectorsPerSource, dimension); + BuildScoreProvider bspPerSource; + ProductQuantization pq = null; + PQVectors pqVectors = null; + // TODO: should we build partitions by FUSEDPQ? + if (indexPrecision == IndexPrecision.FUSEDPQ) { + boolean centerData = similarityFunction == VectorSimilarityFunction.EUCLIDEAN; + pq = ProductQuantization.compute(ravvPerSource, dimension / 8, 256, centerData); + pqVectors = (PQVectors) pq.encodeAll(ravvPerSource); + bspPerSource = BuildScoreProvider.pqBuildScoreProvider(similarityFunction, pqVectors); + } + else { + bspPerSource = BuildScoreProvider.randomAccessScoreProvider(ravvPerSource, similarityFunction); + } + + var builder = new GraphIndexBuilder(bspPerSource, + dimension, + graphDegree, beamWidth, 1.2f, 1.2f, true); + var graph = builder.build(ravvPerSource); + + AbstractGraphIndexWriter.Builder writerBuilder; + if (parallelWriteThreads > 1) { + writerBuilder = new OnDiskParallelGraphIndexWriter.Builder(graph, outputPath) + .withParallelWorkerThreads(parallelWriteThreads); + } else { + writerBuilder = new OnDiskGraphIndexWriter.Builder(graph, outputPath); + } + + writerBuilder.with(new InlineVectors(dimension)); + + + if (indexPrecision == IndexPrecision.FUSEDPQ) { + writerBuilder.with(new FusedPQ(graph.maxDegree(), pq)); + } + + try (var writer = writerBuilder.build()) { + var suppliers = new EnumMap>(FeatureId.class); + suppliers.put(FeatureId.INLINE_VECTORS, ordinal -> new InlineVectors.State(ravvPerSource.getVector(ordinal))); + + if (indexPrecision == IndexPrecision.FUSEDPQ) { + var view = graph.getView(); + var 
finalPqVectors = pqVectors; + suppliers.put(FeatureId.FUSED_PQ, ordinal -> new FusedPQ.State(view, finalPqVectors, ordinal)); + } + + writer.write(suppliers); + } + } + + log.info("Done building partitions."); + } + + private long compactPartitions() throws Exception { + + // Load partitions (from round-robin storage paths, same naming) + for (int i = 0; i < numPartitions; i++) { + Path baseDir = storagePaths.get(i % storagePaths.size()); + Path segPath = baseDir.resolve("per-source-graph-" + i); + log.info("Loading partition {}/{} from {}", i + 1, numPartitions, segPath.toAbsolutePath()); + rss.add(ReaderSupplierFactory.open(segPath.toAbsolutePath())); + graphs.add(OnDiskGraphIndex.load(rss.get(i))); + } + + // Ensure output dir exists + if (compactOutputPath.getParent() != null) { + Files.createDirectories(compactOutputPath.getParent()); + } + + if (Files.exists(compactOutputPath)) { + Files.delete(compactOutputPath); + } + + log.info("Compacting {} partitions into {}", numPartitions, compactOutputPath.toAbsolutePath()); + + + List remappers = new ArrayList<>(numPartitions); + List liveNodes = new ArrayList<>(numPartitions); + // Remap ordinals: local [0..size-1] -> global increasing in partition order + int globalOrdinal = 0; + for (int n = 0; n < numPartitions; n++) { + int size = graphs.get(n).size(); + var remapper = new OrdinalMapper.OffsetMapper(globalOrdinal, size); + remappers.add(remapper); + liveNodes.add(randomLiveNodes(size, liveNodesRate, n)); + globalOrdinal += size; + } + var compactor = new OnDiskGraphIndexCompactor(graphs, liveNodes, remappers, similarityFunction, null); + + long startNanos = System.nanoTime(); + compactor.compact(compactOutputPath); + return TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNanos); + } + + private long buildFromScratch(List> baseVectors) throws Exception { + if (scratchOutputPath.getParent() != null) { + Files.createDirectories(scratchOutputPath.getParent()); + } + if (Files.exists(scratchOutputPath)) { + 
Files.delete(scratchOutputPath); + } + + int dimension = baseVectors.get(0).length(); + var full = new ListRandomAccessVectorValues(baseVectors, dimension); + ProductQuantization pq = null; + PQVectors pqVectors = null; + BuildScoreProvider bsp; + if (indexPrecision == IndexPrecision.FUSEDPQ) { + boolean centerData = similarityFunction == VectorSimilarityFunction.EUCLIDEAN; + pq = ProductQuantization.compute(full, dimension / 8, 256, centerData); + pqVectors = (PQVectors) pq.encodeAll(full); + bsp = BuildScoreProvider.pqBuildScoreProvider(similarityFunction, pqVectors); + } + else { + bsp = BuildScoreProvider.randomAccessScoreProvider(full, similarityFunction); + } + + log.info("Building from scratch: vectors={} dim={} sim={} deg={} bw={} precision={} pwThreads={} vp={} -> {}", + full.size(), dimension, similarityFunction, + graphDegree, beamWidth, indexPrecision, parallelWriteThreads, resolvedVectorizationProvider, + scratchOutputPath.toAbsolutePath()); + + var builder = new GraphIndexBuilder(bsp, dimension, graphDegree, beamWidth, 1.2f, 1.2f, true); + var graph = builder.build(full); + + AbstractGraphIndexWriter.Builder writerBuilder = + (parallelWriteThreads > 1) + ? 
new OnDiskParallelGraphIndexWriter.Builder(graph, scratchOutputPath) + .withParallelWorkerThreads(parallelWriteThreads) + : new OnDiskGraphIndexWriter.Builder(graph, scratchOutputPath); + + writerBuilder.with(new InlineVectors(dimension)); + +// ProductQuantization pq = null; +// PQVectors pqVectors = null; +// if (indexPrecision == IndexPrecision.FUSEDPQ) { +// boolean centerData = similarityFunction == VectorSimilarityFunction.EUCLIDEAN; +// pq = ProductQuantization.compute(full, dimension / 8, 256, centerData); +// pqVectors = (PQVectors) pq.encodeAll(full); +// writerBuilder.with(new FusedPQ(graph.maxDegree(), pq)); +// } + if (indexPrecision == IndexPrecision.FUSEDPQ) { + writerBuilder.with(new FusedPQ(graph.maxDegree(), pq)); + } + + long startNanos = System.nanoTime(); + try (var writer = writerBuilder.build()) { + var suppliers = new EnumMap>(FeatureId.class); + suppliers.put(FeatureId.INLINE_VECTORS, ord -> new InlineVectors.State(full.getVector(ord))); + + if (indexPrecision == IndexPrecision.FUSEDPQ) { + var view = graph.getView(); + var finalPQ = pqVectors; + suppliers.put(FeatureId.FUSED_PQ, ord -> new FusedPQ.State(view, finalPQ, ord)); + } + + writer.write(suppliers); + } + return TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNanos); + } + + @TearDown(Level.Iteration) + public void tearDown() throws IOException, InterruptedException { + if (threadAllocTracker.isActive()) { + threadAllocTracker.stop(); + } + + if (sysStatsCollector.isActive()) { + sysStatsCollector.stop(SYSTEM_DIR); + } + + if (jfrPartitioningRecorder.isActive()) { + jfrPartitioningRecorder.stop(); + } + if (jfrCompactingRecorder.isActive()) { + jfrCompactingRecorder.stop(); + } + + closeLoadedGraphs(); + } + + private void closeLoadedGraphs() { + for (var graph : graphs) { + try { + graph.close(); + } catch (Exception e) { + log.error("Failed to close graph", e); + } + } + graphs.clear(); + + for (var rs : rss) { + try { + rs.close(); + } catch (Exception e) { + 
log.error("Failed to close ReaderSupplier", e); + } + } + rss.clear(); + } + + @Benchmark + public void run(Blackhole blackhole, RecallResult recallResult) throws Exception { + long durationMs = 0; + double recall = -1; + + try { + if (jfrCompacting) { + try { + jfrCompactingRecorder.start(JFR_DIR, "workload-" + jfrParamSuffix() + ".jfr", jfrObjectCount); + } catch (Exception e) { + log.warn("Failed to start workload JFR recording", e); + } + } + + // Execute workload + switch (workloadMode) { + case PARTITION_ONLY: + break; + + case COMPACT_ONLY: + durationMs = compactPartitions(); + break; + + case COMPACT_AND_RECALL: + durationMs = compactPartitions(); + recall = runRecall(compactOutputPath); + break; + + case BUILD_FROM_SCRATCH: { + durationMs = buildFromScratch(baseVectors); + recall = runRecall(scratchOutputPath); + break; + } + + case PARTITION_AND_COMPACT: + durationMs = compactPartitions(); + recall = runRecall(compactOutputPath); + break; + + default: + throw new IllegalStateException("Unknown workloadMode: " + workloadMode); + } + + recallResult.recall = recall; + persistResult(recall, durationMs); + blackhole.consume(durationMs); + + } catch (Exception e) { + persistError(e); + throw e; + } finally { + if (jfrCompactingRecorder.isActive()) { + jfrCompactingRecorder.stop(); + } + closeLoadedGraphs(); + } + } + + private double runRecall(Path indexPath) throws Exception { + + log.info("Loading and searching index at {}", indexPath.toAbsolutePath()); + try (var rs = ReaderSupplierFactory.open(indexPath)) { + var graph = OnDiskGraphIndex.load(rs); + GraphSearcher searcher = new GraphSearcher(graph); + var view = (ImmutableGraphIndex.ScoringView) searcher.getView(); + searcher.usePruning(false); + List retrieved = new ArrayList<>(queryVectors.size()); + for (int n = 0; n < queryVectors.size(); ++n) { + SearchResult result; + if(indexPrecision == IndexPrecision.FUSEDPQ) { + var asf = view.approximateScoreFunctionFor(queryVectors.get(n), similarityFunction); + 
var rerank = view.rerankerFor(queryVectors.get(n), similarityFunction); + SearchScoreProvider ssp = new DefaultSearchScoreProvider(asf, rerank); + result = searcher.search(ssp, 10, 10, 0.0f, 0.0f, Bits.ALL); + } + else { + var ssp = DefaultSearchScoreProvider.exact(queryVectors.get(n), similarityFunction, ravv); + result = searcher.search(ssp, 10, 10, 0.0f, 0.0f, Bits.ALL); + } + retrieved.add(result); + } + + double recall = AccuracyMetrics.recallFromSearchResults(groundTruth, retrieved, 10, 10); + log.info("Recall [dataset={}, workloadMode={}, numPartitions={}, graphDegree={}, beamWidth={}, splitDistribution={}, indexPrecision={}, parallelWriteThreads={}, vectorizationProvider={}, datasetPortion={}]: {}", + datasetNames, workloadMode, numPartitions, graphDegree, beamWidth, splitDistribution, indexPrecision, parallelWriteThreads, resolvedVectorizationProvider, datasetPortion, recall); + return recall; + } + } + + // ---------- result persistence ---------- + private LinkedHashMap buildParams() { + var params = new LinkedHashMap(); + params.put("dataset", datasetNames); + params.put("workloadMode", workloadMode.name()); + params.put("numPartitions", numPartitions); + params.put("graphDegree", graphDegree); + params.put("beamWidth", beamWidth); + params.put("storageDirectories", storageDirectories); + params.put("storageClasses", storageClasses); + params.put("splitDistribution", splitDistribution.name()); + params.put("indexPrecision", indexPrecision.name()); + params.put("parallelWriteThreads", parallelWriteThreads); + params.put("vectorizationProvider", resolvedVectorizationProvider); + params.put("datasetPortion", datasetPortion); + params.put("jfrPartitioning", jfrPartitioning); + params.put("jfrCompacting", jfrCompacting); + params.put("jfrObjectCount", jfrObjectCount); + params.put("sysStatsEnabled", sysStatsEnabled); + params.put("threadAllocTracking", threadAllocTracking); + params.put("liveNodesRate", liveNodesRate); + return params; + } + + private 
LinkedHashMap baseResult(String event) { + var result = new LinkedHashMap(); + result.put("testId", TEST_ID); + result.put("gitHash", GitInfo.getShortHash()); + result.put("timestamp", Instant.now().toString()); + result.put("event", event); + result.put("benchmark", "run"); + result.put("params", buildParams()); + return result; + } + + private void persistStarted() { + var result = baseResult("started"); + result.put("completedTests", completedTests.get()); + result.put("totalTests", TOTAL_TESTS); + jsonlWriter.writeLine(result); + log.info("Starting test {}/{}", completedTests.get() + 1, TOTAL_TESTS); + } + + private void persistResult(double recall, long durationMs) { + if (resultPersisted) return; + resultPersisted = true; + + int completed = completedTests.incrementAndGet(); + writeCompletedCount(completed); + + var result = baseResult("completed"); + var results = new LinkedHashMap(); + results.put("durationMs", durationMs); + + // Only meaningful for recall-enabled workloads; else NaN + results.put("recall", recall); + + if (vectorsPerSourceCount != null) { + results.put("splitSizes", vectorsPerSourceCount.toString()); + } + if (jfrPartitioningRecorder.getFileName() != null) { + results.put("jfrPartitioningFile", jfrPartitioningRecorder.getFileName()); + } + if (jfrCompactingRecorder.getFileName() != null) { + results.put("jfrWorkloadFile", jfrCompactingRecorder.getFileName()); + } + if (sysStatsCollector.getFileName() != null) { + results.put("sysStatsFile", sysStatsCollector.getFileName()); + } + if (threadAllocTracker.getFileName() != null) { + results.put("threadAllocFile", threadAllocTracker.getFileName()); + } + + result.put("results", results); + result.put("completedTests", completed); + result.put("totalTests", TOTAL_TESTS); + + jsonlWriter.writeLine(result); + log.info("Completed test {}/{}", completed, TOTAL_TESTS); + } + + private void persistError(Exception e) { + try { + var result = baseResult("error"); + var results = new LinkedHashMap(); + 
results.put("errorMessage", e.getMessage() != null ? e.getMessage() : e.getClass().getName()); + result.put("results", results); + result.put("completedTests", completedTests.get()); + result.put("totalTests", TOTAL_TESTS); + jsonlWriter.writeLine(result); + } catch (Exception inner) { + log.error("Failed to persist error event", inner); + } + } + + public static void main(String[] args) throws Exception { + Files.createDirectories(RUN_DIR); + String jmhResultFile = RUN_DIR.resolve("compactor-jmh.json").toString(); + log.info("Benchmark run directory: {}", RUN_DIR.toAbsolutePath()); + log.info("Progressive results will be written to: {}", RESULTS_FILE.toAbsolutePath()); + log.info("JMH results will be written to: {}", Path.of(jmhResultFile).toAbsolutePath()); + + org.openjdk.jmh.runner.options.CommandLineOptions cmdOptions = new org.openjdk.jmh.runner.options.CommandLineOptions(args); + int totalTests = BenchmarkParamCounter.computeTotalTests(CompactorBenchmark.class, cmdOptions); + log.info("Total test combinations: {}", totalTests); + + // Resolve the log4j2 config so the forked JVM picks it up explicitly + var log4j2Config = CompactorBenchmark.class.getClassLoader().getResource("log4j2.xml"); + String log4j2Arg = log4j2Config != null + ? "-Dlog4j2.configurationFile=" + log4j2Config + : "-Dlog4j2.configurationFile=classpath:log4j2.xml"; + + // The forked JVM's stdout is piped through JMH, so System.console() returns null + // and Log4j2 suppresses ANSI. Propagate the parent's TTY detection to the child. + String disableAnsi = System.console() == null ? "true" : "false"; + + // Collect all JVM args for the forked process in one list, + // because jvmArgsAppend() replaces (not appends) on each call. 
+ var jvmArgs = new ArrayList(); + jvmArgs.add("-Djvector.internal.runDir=" + RUN_DIR); + jvmArgs.add("-Djvector.internal.totalTests=" + totalTests); + jvmArgs.add(log4j2Arg); + jvmArgs.add("-Dcompactor.disableAnsi=" + disableAnsi); + + // Pass the vectorization provider if specified in command line options + var vpParam = cmdOptions.getParameter("vectorizationProvider"); + if (vpParam.hasValue()) { + var vpValues = vpParam.get(); + if (!vpValues.isEmpty()) { + jvmArgs.add("-Djvector.vectorization_provider=" + vpValues.iterator().next()); + } + } + + var optBuilder = new org.openjdk.jmh.runner.options.OptionsBuilder(); + optBuilder.include(CompactorBenchmark.class.getSimpleName()) + .parent(cmdOptions) + .forks(1) + .threads(1) + .shouldFailOnError(true) + .jvmArgsAppend(jvmArgs.toArray(new String[0])) + .resultFormat(ResultFormatType.JSON) + .result(jmhResultFile); + + new org.openjdk.jmh.runner.Runner(optBuilder.build()).run(); + } + +} diff --git a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/CompactorBenchmark.md b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/CompactorBenchmark.md new file mode 100644 index 000000000..bf8355de0 --- /dev/null +++ b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/CompactorBenchmark.md @@ -0,0 +1,139 @@ + + +# CompactorBenchmark + +`CompactorBenchmark` evaluates the **performance, memory usage, and recall quality** of graph index compaction using `OnDiskGraphIndexCompactor`. + +--- + +# 1. Workload Modes + +| Mode | Description | +|------|-------------| +| `PARTITION_AND_COMPACT` | **(default)** Build partitions, compact them, then measure recall — all in one run | +| `PARTITION_ONLY` | Build N partition indexes and exit; no compaction | +| `COMPACT_ONLY` | Compact existing partitions without loading the dataset | +| `BUILD_FROM_SCRATCH` | Build a single index over the full dataset | + +--- + +# 2. 
Quick Start + +## Default: partition and compact in one run + +The default mode builds partitions and immediately compacts them. Use this when you want a single-command end-to-end result. + +```bash +java -Xmx220g --add-modules jdk.incubator.vector \ + -jar benchmarks-jmh/target/benchmarks-jmh-*.jar CompactorBenchmark \ + -p workloadMode=PARTITION_AND_COMPACT \ + -p datasetNames=glove-100-angular \ + -p numPartitions=4 \ + -p splitDistribution=FIBONACCI \ + -p indexPrecision=FUSEDPQ \ + -wi 0 -i 1 -f 1 +``` + +--- + +# 3. Measuring Peak Heap During Compaction + +The two-step workflow (`PARTITION_ONLY` → `COMPACT_ONLY`) exists to isolate compaction's true memory footprint. In `PARTITION_AND_COMPACT` mode the dataset is still resident in heap during compaction, which inflates the apparent memory cost. `COMPACT_ONLY` skips dataset loading entirely, so the heap limit applies only to the compactor itself. + +This lets you prove that compaction can run on machines with very little RAM — e.g., `-Xmx5g` is sufficient even for large datasets. + +## Step 1: Build partitions + +Run with a large heap since the full dataset must be loaded into memory. + +```bash +java -Xmx220g --add-modules jdk.incubator.vector \ + -jar benchmarks-jmh/target/benchmarks-jmh-*.jar CompactorBenchmark \ + -p workloadMode=PARTITION_ONLY \ + -p datasetNames=glove-100-angular \ + -p numPartitions=4 \ + -p splitDistribution=FIBONACCI \ + -p indexPrecision=FUSEDPQ \ + -wi 0 -i 1 -f 1 +``` + +The partition files are written to disk and reused in the next step. + +## Step 2: Compact only (low-memory run) + +The dataset is **not** loaded in this mode. Use a small `-Xmx` to measure and prove the compactor's true peak heap. 
+ +```bash +java -Xmx5g --add-modules jdk.incubator.vector \ + -jar benchmarks-jmh/target/benchmarks-jmh-*.jar CompactorBenchmark \ + -p workloadMode=COMPACT_ONLY \ + -p datasetNames=glove-100-angular \ + -p numPartitions=4 \ + -p splitDistribution=FIBONACCI \ + -p indexPrecision=FUSEDPQ \ + -wi 0 -i 1 -f 1 +``` + +`durationMs` in the output records only the `compact()` call — not JVM startup or I/O setup. + +--- + +# 4. Key Parameters + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `datasetNames` | `glove-100-angular` | Dataset name | +| `workloadMode` | `PARTITION_AND_COMPACT` | Which phase(s) to run | +| `numPartitions` | `4` | Number of source partition indexes | +| `splitDistribution` | — | Data partitioning strategy (`UNIFORM`, `FIBONACCI`, …) | +| `indexPrecision` | — | `FULLPRECISION` (inline vectors only) or `FUSEDPQ` (inline + FusedPQ) | +| `storageDirectories` | *(temp dir)* | Comma-separated list of directories where partition files are written; partitions are distributed round-robin across them. Defaults to a JVM temp directory if unset. | + +--- + +# 5. Index Precision + +`indexPrecision` controls what features are written into each partition index. + +| Value | Written features | +|-------|-----------------| +| `FULLPRECISION` | `INLINE_VECTORS` only | +| `FUSEDPQ` | `INLINE_VECTORS` + `FUSED_PQ` — required for compressed compaction | + +--- + +# 6. Results + +Results are written as JSONL to: + +``` +target/benchmark-results/compactor-/compactor-results.jsonl +``` + +Key fields: + +| Field | Description | +|-------|-------------| +| `durationMs` | Time spent in the measured phase only | +| `recall` | Recall@10 (present when workload mode includes recall, e.g. `PARTITION_AND_COMPACT`) | +| `peakHeapMb` | Peak JVM heap observed during the run | + +--- + +# 7. Memory Footprint + +All datasets in the recall table (see `docs/compaction.md`) can be run under `COMPACT_ONLY` with `-Xmx5g`. 
Compaction also successfully scales to a dataset with 2560 dimensions and 10M vectors under the same constraint. diff --git a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/benchtools/BenchmarkParamCounter.java b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/benchtools/BenchmarkParamCounter.java new file mode 100644 index 000000000..ab3035d9c --- /dev/null +++ b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/benchtools/BenchmarkParamCounter.java @@ -0,0 +1,54 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.bench.benchtools; + +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.runner.options.CommandLineOptions; + +/** + * Counts the total number of {@code @Param} combinations for a JMH benchmark class. + */ +public final class BenchmarkParamCounter { + private BenchmarkParamCounter() {} + + /** + * Computes the total number of benchmark parameter combinations as the cartesian product + * of all {@code @Param} value sets. When {@code cmdOptions} is provided, command-line + * {@code -p} overrides take precedence over the annotation defaults. 
+ * + * @param benchmarkClass the JMH benchmark class to inspect + * @param cmdOptions parsed command-line options, or {@code null} to use annotation defaults only + * @return the total number of parameter combinations + */ + public static int computeTotalTests(Class benchmarkClass, CommandLineOptions cmdOptions) { + int total = 1; + for (var field : benchmarkClass.getDeclaredFields()) { + var paramAnnotation = field.getAnnotation(Param.class); + if (paramAnnotation != null) { + if (cmdOptions != null) { + var cmdOverride = cmdOptions.getParameter(field.getName()); + if (cmdOverride.hasValue() && !cmdOverride.get().isEmpty()) { + total *= cmdOverride.get().size(); + continue; + } + } + total *= paramAnnotation.value().length; + } + } + return total; + } +} diff --git a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/benchtools/EventLogAnalyzer.java b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/benchtools/EventLogAnalyzer.java new file mode 100644 index 000000000..896a83af1 --- /dev/null +++ b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/benchtools/EventLogAnalyzer.java @@ -0,0 +1,470 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
 + */ + +package io.github.jbellis.jvector.bench.benchtools; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Stream; + +/** + * CLI utility that reads a CompactorBenchmark JSONL event log and reports + * max concurrency and the cartesian parameter matrix with testIds and results. + */ +public final class EventLogAnalyzer { + + private static final List<String> PARAM_FIELDS = List.of( + "params.dataset", "params.numPartitions", "params.graphDegree", "params.beamWidth", + "params.storageDirectories", "params.storageClasses", "params.splitDistribution", + "params.indexPrecision", "params.parallelWriteThreads", "params.vectorizationProvider", + "params.datasetPortion" + ); + + private static final List<String> RESULT_FIELDS = List.of( + "results.recall", "results.durationMs", "results.errorMessage" + ); + + private static final Path RESULTS_DIR = Path.of("target", "benchmark-results"); + private static final String RESULTS_FILENAME = "compactor-results.jsonl"; + + /** + * Finds the most recent compactor-results.jsonl under target/benchmark-results/ + * by selecting the compactor-* directory with the highest name (epoch-second suffix).
+ */ + private static Path findLatestResultsFile() { + if (!Files.isDirectory(RESULTS_DIR)) { + return null; + } + try (Stream dirs = Files.list(RESULTS_DIR)) { + return dirs.filter(Files::isDirectory) + .filter(d -> d.getFileName().toString().startsWith("compactor-")) + .sorted(Comparator.reverseOrder()) + .map(d -> d.resolve(RESULTS_FILENAME)) + .filter(Files::exists) + .findFirst() + .orElse(null); + } catch (IOException e) { + return null; + } + } + + public static void main(String[] args) throws IOException { + Path inputFile = null; + String startingAt = null, endingAt = null, startingTestId = null, endingTestId = null; + + for (int i = 0; i < args.length; i++) { + switch (args[i]) { + case "--starting-at": startingAt = args[++i]; break; + case "--ending-at": endingAt = args[++i]; break; + case "--starting-testid": startingTestId = args[++i]; break; + case "--ending-testid": endingTestId = args[++i]; break; + case "--help": + case "-h": + printUsage(); + System.exit(0); + break; + default: + if (args[i].startsWith("--")) { + System.err.println("Unknown option: " + args[i]); + System.exit(1); + } + if (inputFile != null) { + System.err.println("Multiple input files specified"); + System.exit(1); + } + inputFile = Path.of(args[i]); + } + } + + if (inputFile == null) { + inputFile = findLatestResultsFile(); + if (inputFile == null) { + System.err.println("No results file found under " + RESULTS_DIR.toAbsolutePath()); + printUsage(); + System.exit(1); + } + } + + System.err.println("Using: " + inputFile.toAbsolutePath()); + + // Read and parse all lines + List> allEvents = new ArrayList<>(); + for (String line : Files.readAllLines(inputFile)) { + line = line.trim(); + if (!line.isEmpty()) { + allEvents.add(parseJsonLine(line)); + } + } + + // Assign synthetic testIds where missing + boolean hasSyntheticIds = false; + String currentSyntheticId = null; + int syntheticCounter = 0; + for (var event : allEvents) { + if (!event.containsKey("testId") || 
event.get("testId").isEmpty()) { + hasSyntheticIds = true; + if ("started".equals(event.get("event"))) { + currentSyntheticId = String.format("#%03d", ++syntheticCounter); + } + if (currentSyntheticId != null) { + event.put("testId", currentSyntheticId); + } + } + } + + // Group all events by testId + Map>> allByTestId = new LinkedHashMap<>(); + for (var event : allEvents) { + String tid = event.get("testId"); + if (tid != null) { + allByTestId.computeIfAbsent(tid, k -> new ArrayList<>()).add(event); + } + } + + // Determine which testIds are in scope + Set inScopeTestIds = new LinkedHashSet<>(); + for (var entry : allByTestId.entrySet()) { + String tid = entry.getKey(); + + // testId range filter (skip for synthetic IDs) + if (!hasSyntheticIds) { + if (startingTestId != null && tid.compareTo(startingTestId) < 0) continue; + if (endingTestId != null && tid.compareTo(endingTestId) > 0) continue; + } + + // Timestamp range filter: test is in scope if ANY event is within range + if (startingAt != null || endingAt != null) { + boolean anyInRange = false; + for (var event : entry.getValue()) { + String ts = event.get("timestamp"); + if (ts == null) continue; + boolean afterStart = startingAt == null || ts.compareTo(startingAt) >= 0; + boolean beforeEnd = endingAt == null || ts.compareTo(endingAt) <= 0; + if (afterStart && beforeEnd) { + anyInRange = true; + break; + } + } + if (!anyInRange) continue; + } + + inScopeTestIds.add(tid); + } + + if (inScopeTestIds.isEmpty()) { + System.out.println("No matching events found."); + return; + } + + // Collect all events for in-scope testIds + Map>> byTestId = new LinkedHashMap<>(); + for (String tid : inScopeTestIds) { + byTestId.put(tid, allByTestId.get(tid)); + } + + // Compute max concurrency by sweeping start/end intervals + List intervals = new ArrayList<>(); + for (var entry : byTestId.entrySet()) { + String startTs = null, endTs = null; + for (var event : entry.getValue()) { + String ev = event.get("event"); + String ts = 
event.get("timestamp"); + if (ts == null) continue; + if ("started".equals(ev) && (startTs == null || ts.compareTo(startTs) < 0)) { + startTs = ts; + } + if (("completed".equals(ev) || "error".equals(ev)) + && (endTs == null || ts.compareTo(endTs) > 0)) { + endTs = ts; + } + } + if (startTs != null && endTs != null) { + intervals.add(new String[]{startTs, endTs}); + } + } + + System.out.println("Max concurrency: " + computeMaxConcurrency(intervals)); + System.out.println(); + + // Extract params and results for each test. + // Params come from the "started" event (which includes the "params" sub-object). + // Results come from the "completed" or "error" event (which includes the "results" sub-object). + Map> testParams = new LinkedHashMap<>(); + Map> testResults = new LinkedHashMap<>(); + Map testStatus = new LinkedHashMap<>(); + for (var entry : byTestId.entrySet()) { + String tid = entry.getKey(); + Map params = null; + Map results = null; + String status = "started"; + for (var event : entry.getValue()) { + String ev = event.get("event"); + // Take params from whichever event has them (all events include params now) + if (params == null && event.keySet().stream().anyMatch(k -> k.startsWith("params."))) { + params = event; + } + if ("completed".equals(ev)) { + status = "completed"; + results = event; + } else if ("error".equals(ev)) { + status = "error"; + results = event; + } + } + if (params != null) { + testParams.put(tid, params); + } + testResults.put(tid, results != null ? 
results : Map.of()); + testStatus.put(tid, status); + } + + // Classify parameters as static (single value) or varying (multiple values) + List varyingParams = new ArrayList<>(); + Map staticParams = new LinkedHashMap<>(); + for (String field : PARAM_FIELDS) { + Set values = new HashSet<>(); + for (var params : testParams.values()) { + values.add(params.getOrDefault(field, "")); + } + if (values.size() > 1) { + varyingParams.add(field); + } else if (values.size() == 1) { + String value = values.iterator().next(); + if (!value.isEmpty()) { + staticParams.put(stripPrefix(field), value); + } + } + } + + // Display static parameters at the top + if (!staticParams.isEmpty()) { + System.out.println("Static parameters:"); + for (var entry : staticParams.entrySet()) { + System.out.println(" " + entry.getKey() + " = " + entry.getValue()); + } + System.out.println(); + } + + if (varyingParams.isEmpty()) { + System.out.println("No varying parameters found."); + } else { + // Display without "params." 
prefix for readability + List displayNames = new ArrayList<>(); + for (String p : varyingParams) { + displayNames.add(stripPrefix(p)); + } + System.out.println("Varying parameters: " + String.join(", ", displayNames)); + } + System.out.println(); + + // Determine which result fields have any non-empty values + List activeResultFields = new ArrayList<>(); + for (String field : RESULT_FIELDS) { + for (var results : testResults.values()) { + if (!results.getOrDefault(field, "").isEmpty()) { + activeResultFields.add(field); + break; + } + } + } + + // Build and print table sorted by testId + List sortedTestIds = new ArrayList<>(testParams.keySet()); + Collections.sort(sortedTestIds); + + // Columns: testId, varying params (without prefix), status, active result fields (without prefix) + List columns = new ArrayList<>(); + columns.add("testId"); + for (String p : varyingParams) { + columns.add(stripPrefix(p)); + } + columns.add("status"); + for (String r : activeResultFields) { + columns.add(stripPrefix(r)); + } + + List> rows = new ArrayList<>(); + for (String tid : sortedTestIds) { + List row = new ArrayList<>(); + row.add(hasSyntheticIds ? "n/a" : tid); + var params = testParams.get(tid); + for (String field : varyingParams) { + row.add(params != null ? params.getOrDefault(field, "") : ""); + } + row.add(testStatus.getOrDefault(tid, "")); + var results = testResults.getOrDefault(tid, Map.of()); + for (String field : activeResultFields) { + row.add(results.getOrDefault(field, "")); + } + rows.add(row); + } + + printTable(columns, rows); + } + + private static String stripPrefix(String field) { + int dot = field.indexOf('.'); + return dot >= 0 ? 
field.substring(dot + 1) : field; + } + + private static void printTable(List<String> columns, List<List<String>> rows) { + int[] widths = new int[columns.size()]; + for (int i = 0; i < columns.size(); i++) { + widths[i] = columns.get(i).length(); + } + for (var row : rows) { + for (int i = 0; i < row.size(); i++) { + widths[i] = Math.max(widths[i], row.get(i).length()); + } + } + + StringBuilder header = new StringBuilder(); + for (int i = 0; i < columns.size(); i++) { + if (i > 0) header.append(" "); + header.append(String.format("%-" + widths[i] + "s", columns.get(i))); + } + System.out.println(header); + + for (var row : rows) { + StringBuilder line = new StringBuilder(); + for (int i = 0; i < row.size(); i++) { + if (i > 0) line.append(" "); + line.append(String.format("%-" + widths[i] + "s", row.get(i))); + } + System.out.println(line); + } + } + + private static void printUsage() { + System.err.println("Usage: EventLogAnalyzer [results.jsonl] [options]"); + System.err.println(" If no file is given, the latest results under target/benchmark-results/ are used."); + System.err.println("Options:"); + System.err.println(" --starting-at Include tests with events at or after this timestamp"); + System.err.println(" --ending-at Include tests with events at or before this timestamp"); + System.err.println(" --starting-testid Include tests with testId >= this value"); + System.err.println(" --ending-testid Include tests with testId <= this value"); + } + + private static int computeMaxConcurrency(List<String[]> intervals) { + if (intervals.isEmpty()) return 0; + + // Each sweep event is {timestamp, delta} where delta is +1 (start) or -1 (end) + List<String[]> events = new ArrayList<>(); + for (var interval : intervals) { + events.add(new String[]{interval[0], "+1"}); + events.add(new String[]{interval[1], "-1"}); + } + // Sort by timestamp, then ends before starts at the same timestamp. + // Note: lexicographically "+1" < "-1" (ASCII '+' 0x2B < '-' 0x2D), so the + // tie-break must be the REVERSED comparison to place "-1" (end) first; + // otherwise back-to-back intervals would be counted as concurrent. + events.sort((a, b) -> { + int cmp = a[0].compareTo(b[0]); + if (cmp != 0) return cmp; + return b[1].compareTo(a[1]); // reversed: "-1" (end) sorts before "+1" (start) + }); + + int max = 0, current = 0; + for (var event : events) { + current += Integer.parseInt(event[1]); + max = Math.max(max, current); + } + return max; + } + + /** + * Parses a JSON line into key-value string pairs. + * Nested objects are flattened with dot-prefixed keys (e.g., "params.dataset"). + * Handles quoted strings (with backslash escapes), numbers, booleans, and one level of nesting. + */ + private static Map<String, String> parseJsonLine(String line) { + Map<String, String> result = new LinkedHashMap<>(); + parseObject(line, new int[]{0}, "", result); + return result; + } + + private static void parseObject(String line, int[] pos, String prefix, Map<String, String> result) { + int len = line.length(); + + // Skip to opening brace + while (pos[0] < len && line.charAt(pos[0]) != '{') pos[0]++; + pos[0]++; + + while (pos[0] < len) { + // Skip whitespace and commas + while (pos[0] < len && (line.charAt(pos[0]) == ' ' || line.charAt(pos[0]) == ',' || line.charAt(pos[0]) == '\t')) pos[0]++; + if (pos[0] >= len || line.charAt(pos[0]) == '}') { + pos[0]++; // skip closing brace + break; + } + + // Parse key (quoted string) + if (line.charAt(pos[0]) != '"') break; + pos[0]++; + int keyStart = pos[0]; + while (pos[0] < len && line.charAt(pos[0]) != '"') { + if (line.charAt(pos[0]) == '\\') pos[0]++; + pos[0]++; + } + String key = line.substring(keyStart, pos[0]); + pos[0]++; // skip closing quote + + String fullKey = prefix.isEmpty() ? key : prefix + "." 
+ key; + + // Skip colon and whitespace + while (pos[0] < len && (line.charAt(pos[0]) == ':' || line.charAt(pos[0]) == ' ')) pos[0]++; + + // Parse value + if (pos[0] < len && line.charAt(pos[0]) == '{') { + // Nested object — recurse with dot-prefixed key + parseObject(line, pos, fullKey, result); + } else if (pos[0] < len && line.charAt(pos[0]) == '"') { + // String value + pos[0]++; + StringBuilder sb = new StringBuilder(); + while (pos[0] < len && line.charAt(pos[0]) != '"') { + if (line.charAt(pos[0]) == '\\' && pos[0] + 1 < len) { + pos[0]++; + switch (line.charAt(pos[0])) { + case '"': sb.append('"'); break; + case '\\': sb.append('\\'); break; + case 'n': sb.append('\n'); break; + case 't': sb.append('\t'); break; + default: sb.append(line.charAt(pos[0])); break; + } + } else { + sb.append(line.charAt(pos[0])); + } + pos[0]++; + } + pos[0]++; // skip closing quote + result.put(fullKey, sb.toString()); + } else { + // Number, boolean, or null + int valStart = pos[0]; + while (pos[0] < len && line.charAt(pos[0]) != ',' && line.charAt(pos[0]) != '}' && line.charAt(pos[0]) != ' ') pos[0]++; + result.put(fullKey, line.substring(valStart, pos[0])); + } + } + } +} diff --git a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/benchtools/package-info.java b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/benchtools/package-info.java new file mode 100644 index 000000000..1221ea481 --- /dev/null +++ b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/benchtools/package-info.java @@ -0,0 +1,24 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Reusable benchmark infrastructure utilities. + *

+ * This package provides general-purpose tools for JMH benchmarks including + * {@code @Param} combination counting ({@link io.github.jbellis.jvector.bench.benchtools.BenchmarkParamCounter}) + * and JFR event log analysis ({@link io.github.jbellis.jvector.bench.benchtools.EventLogAnalyzer}). + */ +package io.github.jbellis.jvector.bench.benchtools; diff --git a/benchmarks-jmh/src/main/resources/log4j2.xml b/benchmarks-jmh/src/main/resources/log4j2.xml index 823788261..75bb377fd 100644 --- a/benchmarks-jmh/src/main/resources/log4j2.xml +++ b/benchmarks-jmh/src/main/resources/log4j2.xml @@ -3,13 +3,19 @@ - + + + + + + \ No newline at end of file diff --git a/docs/compaction.md b/docs/compaction.md new file mode 100644 index 000000000..f63697aa9 --- /dev/null +++ b/docs/compaction.md @@ -0,0 +1,199 @@ +# Graph Index Compaction + +`OnDiskGraphIndexCompactor` merges multiple on-disk HNSW graph indexes into a single compacted index. This is useful in write-heavy workloads where data is continuously ingested into small segment indexes that accumulate over time; periodically compacting those segments into one larger index improves search throughput and recall without rebuilding from scratch. + +## Overview + +``` +source[0].index ─┐ +source[1].index ─┤──► OnDiskGraphIndexCompactor ──► compacted.index +source[N].index ─┘ +``` + +Each source is an `OnDiskGraphIndex` with an associated `FixedBitSet` marking which of its nodes are live (not deleted). The compactor merges all live nodes into a single graph, remaps ordinals so the output is contiguously numbered, and optionally retrains the Product Quantization codebook for the combined dataset. 
+ +## Usage + +```java +List sources = List.of(index0, index1, index2); + +// Mark all nodes live (no deletions) +List liveNodes = sources.stream() + .map(s -> { var bs = new FixedBitSet(s.size()); bs.set(0, s.size()); return bs; }) + .collect(toList()); + +// Sequential ordinal remapping: source[s] node i → global offset[s] + i +int offset = 0; +List remappers = new ArrayList<>(); +for (var src : sources) { + remappers.add(new OrdinalMapper.OffsetMapper(offset, src.size())); + offset += src.size(); +} + +var compactor = new OnDiskGraphIndexCompactor( + sources, liveNodes, remappers, + VectorSimilarityFunction.COSINE, + /* executor= */ null // null = create internal ForkJoinPool +); + +compactor.compact(Path.of("compacted.index")); +``` + +### Handling Deleted Nodes + +Deleted nodes are excluded from the output by marking them as `false` in the corresponding `FixedBitSet`. + +```java +// Example: every 5th node is deleted +FixedBitSet live = new FixedBitSet(source.size()); +Map oldToNew = new HashMap<>(); +int newOrd = 0; +for (int i = 0; i < source.size(); i++) { + if (i % 5 != 0) { + live.set(i); + oldToNew.put(i, newOrd++); + } +} +remappers.add(new OrdinalMapper.MapMapper(oldToNew)); +``` + +## Algorithm + +### Ordinal Remapping + +Each source assigns its own local ordinals. The compactor maps them to a new global ordinal space using user-provided `OrdinalMapper`. + + +### PQ Retraining + +If the source indexes use FusedPQ, the compactor retrains the Product Quantization codebook on the combined dataset before writing the output. This is done by `PQRetrainer`, which +performs **balanced proportional sampling** across all sources (up to `ProductQuantization.MAX_PQ_TRAINING_SET_SIZE` vectors total, at least 1000 per source). + + +### Neighbor Selection (per node) + +For each live node at each graph level, the compactor gathers a candidate neighbor pool and then applies diversity selection: + +**1. 
Gather from same source** (`gatherFromSameSource`)\ +Iterate the node's existing neighbors in its source index. Filter out deleted nodes. Score each with the similarity function. No graph search — neighbors are already precomputed. + +**2. Gather from other sources** (`gatherFromOtherSource`)\ +Run a graph search in every other source index starting from that source's entry point. If FusedPQ is available, approximate PQ scoring is used during the search and top results are rescored exactly. + +- *Level 0*: a full hierarchical graph search is used (`GraphSearcher.search()`), descending from the entry node down to level 0. +- *Level L > 0*: the compactor first descends greedily from the source's entry node through each level above L (one `searchOneLayer` call with topK=1 per level, feeding the result into the next via `setEntryPointsFromPreviousLayer()`), then performs the full beam search at level L. This mirrors standard HNSW construction and gives a much better starting point than jumping directly to level L from the global entry node. + +``` +searchTopK = max(2, ceil(degree / numSources) * 2) +beamWidth = max(degree, searchTopK) * 2 +``` + +**3. Diversity selection** (Vamana-style)\ +Candidates are sorted by score (descending). The compactor selects up to `maxDegree` diverse neighbors using an adaptive alpha: + +``` +for alpha in [1.0, 1.2]: + for each candidate c (highest score first): + if c is already selected: skip + if ∀ selected neighbor j: similarity(c, j) ≤ score(c) × alpha: + select c + if |selected| == maxDegree: stop +``` + +### Hierarchical Levels + +Level 0 (base layer) stores inline vectors, FusedPQ codes, and the neighbor list. Upper levels store only the neighbor list (plus PQ codes at level 1 for cross-level searching). + +Processing is batched per source and run in parallel across sources using a `ForkJoinPool`. A backpressure window keeps at most `taskWindowSize` batches in-flight at once, bounding memory use. 
+
+### Entry Node
+
+The entry node of the compacted graph must exist at the global maximum level (the on-disk format writes the entry node at `maxLevel`), so it is chosen as:
+1. The designated entry node of the first source whose max level equals the global max level, if that node is live.
+2. Otherwise, the first live node found at the global max level by scanning all sources in order.
+
+## Benchmarking
+
+Use `CompactorBenchmark` (in `benchmarks-jmh`) to measure compaction performance. See `benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/CompactorBenchmark.md` for full instructions.
+
+### Default: partition and compact in one run
+
+```bash
+java -Xmx220g --add-modules jdk.incubator.vector \
+  -jar benchmarks-jmh/target/benchmarks-jmh-*.jar CompactorBenchmark \
+  -p workloadMode=PARTITION_AND_COMPACT \
+  -p datasetNames=<dataset> \
+  -p numPartitions=4 \
+  -p splitDistribution=FIBONACCI \
+  -p indexPrecision=FUSEDPQ \
+  -wi 0 -i 1 -f 1
+```
+
+### Measuring peak heap during compaction
+
+To measure how little RAM compaction actually needs — without the dataset occupying heap — run the two steps separately.
+
+**Step 1: build partitions** (dataset in memory, large heap required)
+
+```bash
+java -Xmx220g --add-modules jdk.incubator.vector \
+  -jar benchmarks-jmh/target/benchmarks-jmh-*.jar CompactorBenchmark \
+  -p workloadMode=PARTITION_ONLY \
+  -p datasetNames=<dataset> \
+  -p numPartitions=4 \
+  -p splitDistribution=FIBONACCI \
+  -p indexPrecision=FUSEDPQ \
+  -wi 0 -i 1 -f 1
+```
+
+**Step 2: compact only** (dataset not loaded; use a small heap to prove low-memory operation)
+
+```bash
+java -Xmx5g --add-modules jdk.incubator.vector \
+  -jar benchmarks-jmh/target/benchmarks-jmh-*.jar CompactorBenchmark \
+  -p workloadMode=COMPACT_ONLY \
+  -p datasetNames=<dataset> \
+  -p numPartitions=4 \
+  -p splitDistribution=FIBONACCI \
+  -p indexPrecision=FUSEDPQ \
+  -wi 0 -i 1 -f 1
+```
+
+`COMPACT_ONLY` skips dataset loading entirely, so `-Xmx5g` is sufficient even for large datasets. This lets you confirm that the compactor itself — not the dataset — is the memory bottleneck. 
+
+Key `workloadMode` values:
+
+| Mode | Description |
+|---|---|
+| `PARTITION_AND_COMPACT` | **(default)** Build partitions, compact them, then measure recall |
+| `PARTITION_ONLY` | Build N partition indexes and exit; use before `COMPACT_ONLY` |
+| `COMPACT_ONLY` | Compact existing partitions without loading the dataset; `durationMs` = `compact()` time |
+| `BUILD_FROM_SCRATCH` | Build one index over the full dataset; `durationMs` = `build()` time |
+
+Results are written as JSONL to `target/benchmark-results/compactor-*/compactor-results.jsonl`. The `durationMs` field records only the target function time (not dataset loading or JVM startup).
+
+## Recall
+
+Recall comparison (results averaged over three runs):
+
+- Build from scratch: build one index over the full dataset with PQ scoring; search using FusedPQ with FP reranking.
+- Compaction: partition the dataset into 4 source indexes (Fibonacci distribution), build each with PQ scoring, then compact into one index; search using FusedPQ with FP reranking.
+
+| Dataset | Dim | Build from Scratch | Compaction | Delta |
+|----------------------|-----:|-------------------:|-----------:|-------:|
+| cap-6M | 768 | 0.626 | 0.619 | -0.007 |
+| cap-1M | 768 | 0.656 | 0.656 | 0.000 |
+| gecko-100k | 768 | 0.690 | 0.701 | +0.011 |
+| e5-small-v2-100k | 384 | 0.572 | 0.586 | +0.014 |
+| ada002-1M | 1536 | 0.687 | 0.703 | +0.016 |
+| e5-base-v2-100k | 768 | 0.676 | 0.692 | +0.016 |
+| cohere-english-v3-10M | 1024 | 0.544 | 0.561 | +0.017 |
+| e5-large-v2-100k | 1024 | 0.686 | 0.703 | +0.017 |
+| ada002-100k | 1536 | 0.751 | 0.769 | +0.018 |
+| cohere-english-v3-1M | 1024 | 0.593 | 0.612 | +0.019 |
+
+## Memory footprint
+
+All datasets above can be compacted under `COMPACT_ONLY` with `-Xmx5g`. In addition, compaction successfully scales to a dataset with 2560 dimensions and 10M vectors under the same memory constraint. 
+ diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java index 9e366676c..8135bba25 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java @@ -25,10 +25,7 @@ import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider; -import io.github.jbellis.jvector.util.Bits; -import io.github.jbellis.jvector.util.ExceptionUtils; -import io.github.jbellis.jvector.util.ExplicitThreadLocal; -import io.github.jbellis.jvector.util.PhysicalCoreExecutor; +import io.github.jbellis.jvector.util.*; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; import io.github.jbellis.jvector.vector.types.VectorFloat; import org.slf4j.Logger; @@ -57,7 +54,7 @@ * Under most conditions this is not something you need to worry about, but it does mean * that spawning a new Thread per call is not advisable. This includes virtual threads. 
*/ -public class GraphIndexBuilder implements Closeable { +public class GraphIndexBuilder implements Closeable, Accountable { private static final Logger logger = LoggerFactory.getLogger(GraphIndexBuilder.class); private final int beamWidth; @@ -848,6 +845,29 @@ public void close() throws IOException { } } + @Override + public long ramBytesUsed() { + int OH = RamUsageEstimator.NUM_BYTES_OBJECT_HEADER; + int REF = RamUsageEstimator.NUM_BYTES_OBJECT_REF; + + // Shallow size of this object: header + all fields + // Primitive fields: beamWidth(int), dimension(int), neighborOverflow(float), + // alpha(float), addHierarchy(boolean), refineFinalGraph(boolean) + // Reference fields: naturalScratch, concurrentScratch, graph, insertionsInProgress, + // scoreProvider, simdExecutor, parallelExecutor, searchers, rng + long size = OH + 9L * REF + Integer.BYTES * 2 + Float.BYTES * 2 + 2; + + // The graph is the dominant memory consumer + size += graph.ramBytesUsed(); + + // insertionsInProgress: ConcurrentSkipListSet — typically small during measurement, + // but account for object overhead plus per-entry cost + long inProgressEntrySize = OH + 2L * REF + Integer.BYTES + Integer.BYTES; // NodeAtLevel + skip list node + size += OH + REF + (long) insertionsInProgress.size() * inProgressEntrySize; + + return size; + } + private static class ExcludingBits implements Bits { private final int excluded; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java index 73cc5fbd5..4dd491913 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java @@ -51,7 +51,7 @@ public class GraphSearcher implements Closeable { // Scratch data structures that are used in each {@link #searchInternal} call. 
These can be expensive // to allocate, so they're cleared and reused across calls. private final NodeQueue candidates; - final NodeQueue approximateResults; + public final NodeQueue approximateResults; private final NodeQueue rerankedResults; private final IntHashSet visited; private final NodesUnsorted evictedResults; @@ -307,7 +307,7 @@ public SearchResult search(SearchScoreProvider scoreProvider, return search(scoreProvider, topK, 0.0f, acceptOrds); } - void setEntryPointsFromPreviousLayer() { + public void setEntryPointsFromPreviousLayer() { // push the candidates seen so far back onto the queue for the next layer // at worst we save recomputing the similarity; at best we might connect to a more distant cluster approximateResults.foreach(candidates::push); @@ -316,7 +316,7 @@ void setEntryPointsFromPreviousLayer() { approximateResults.clear(); } - void initializeInternal(SearchScoreProvider scoreProvider, NodeAtLevel entry, Bits rawAcceptOrds) { + public void initializeInternal(SearchScoreProvider scoreProvider, NodeAtLevel entry, Bits rawAcceptOrds) { // save search parameters for potential later resume initializeScoreProvider(scoreProvider); this.acceptOrds = Bits.intersectionOf(rawAcceptOrds, view.liveNodes()); @@ -384,7 +384,7 @@ private boolean stopSearch(NodeQueue localCandidates, ScoreTracker scoreTracker, // incorrect and is discarded, and there is no reason to pass a rerankFloor parameter to resume(). // // Finally: resume() also drives the use of CachingReranker. 
- void searchOneLayer(SearchScoreProvider scoreProvider, + public void searchOneLayer(SearchScoreProvider scoreProvider, int rerankK, float threshold, int level, diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/CompactWriter.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/CompactWriter.java new file mode 100644 index 000000000..5a3091686 --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/CompactWriter.java @@ -0,0 +1,241 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.github.jbellis.jvector.graph.disk; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import io.github.jbellis.jvector.disk.BufferedRandomAccessWriter; +import io.github.jbellis.jvector.disk.RandomAccessWriter; +import io.github.jbellis.jvector.disk.ByteBufferIndexWriter; +import io.github.jbellis.jvector.graph.disk.feature.Feature; +import io.github.jbellis.jvector.graph.disk.feature.FeatureId; +import io.github.jbellis.jvector.graph.disk.feature.InlineVectors; +import io.github.jbellis.jvector.graph.disk.feature.FusedPQ; +import io.github.jbellis.jvector.quantization.ProductQuantization; +import io.github.jbellis.jvector.vector.VectorizationProvider; +import io.github.jbellis.jvector.vector.types.ByteSequence; +import io.github.jbellis.jvector.vector.types.VectorFloat; +import io.github.jbellis.jvector.vector.types.VectorTypeSupport; + +import static io.github.jbellis.jvector.graph.disk.OnDiskGraphIndexCompactor.SelectedVecCache; +import static io.github.jbellis.jvector.graph.disk.OnDiskGraphIndexCompactor.WriteResult; + +final class CompactWriter implements AutoCloseable { + + private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport(); + + private static final int FOOTER_MAGIC = 0x4a564244; + private static final int FOOTER_OFFSET_SIZE = Long.BYTES; + private static final int FOOTER_MAGIC_SIZE = Integer.BYTES; + private static final int FOOTER_SIZE = FOOTER_MAGIC_SIZE + FOOTER_OFFSET_SIZE; + + private final RandomAccessWriter writer; + private final int recordSize; + private final long startOffset; + private final int headerSize; + private final Header header; + private final int version; + private final FusedPQ fusedPQFeature; + private final ProductQuantization pq; + private 
final int baseDegree; + private final int maxOrdinal; + private final ThreadLocal bufferPerThread; + private final ThreadLocal> zeroPQ; + private final boolean fusedPQEnabled; + private final Path outputPath; + private final List configuredLayerInfo; + private final List configuredLayerDegrees; + private final List level1FeatureRecords; + + CompactWriter(Path outputPath, + int maxOrdinal, + int numBaseLayerNodes, + long startOffset, + List layerInfo, + int entryNode, + int dimension, + List layerDegrees, + ProductQuantization pq, + int pqLength, + boolean fusedPQEnabled) + throws IOException { + this.fusedPQEnabled = fusedPQEnabled; + this.version = OnDiskGraphIndex.CURRENT_VERSION; + this.outputPath = outputPath; + this.writer = new BufferedRandomAccessWriter(outputPath); + this.startOffset = startOffset; + this.configuredLayerInfo = new ArrayList<>(layerInfo); + this.configuredLayerDegrees = new ArrayList<>(layerDegrees); + this.baseDegree = layerDegrees.get(0); + this.pq = pq; + this.maxOrdinal = maxOrdinal; + this.level1FeatureRecords = new ArrayList<>(); + + Map featureMap = new LinkedHashMap<>(); + InlineVectors inlineVectorFeature = new InlineVectors(dimension); + featureMap.put(FeatureId.INLINE_VECTORS, inlineVectorFeature); + if (fusedPQEnabled) { + this.fusedPQFeature = new FusedPQ(Collections.max(layerDegrees), pq); + featureMap.put(FeatureId.FUSED_PQ, this.fusedPQFeature); + } else { + this.fusedPQFeature = null; + } + + int rsize = Integer.BYTES + inlineVectorFeature.featureSize() + Integer.BYTES + baseDegree * Integer.BYTES; + if (fusedPQEnabled) { + rsize += fusedPQFeature.featureSize(); + } + this.recordSize = rsize; + + this.configuredLayerInfo.set(0, new CommonHeader.LayerInfo(numBaseLayerNodes, baseDegree)); + var commonHeader = new CommonHeader(this.version, dimension, entryNode, this.configuredLayerInfo, this.maxOrdinal + 1); + this.header = new Header(commonHeader, featureMap); + this.headerSize = header.size(); + + this.bufferPerThread = 
ThreadLocal.withInitial(() -> { + ByteBuffer buffer = ByteBuffer.allocate(recordSize); + buffer.order(ByteOrder.BIG_ENDIAN); + return buffer; + }); + this.zeroPQ = ThreadLocal.withInitial(() -> { + var vec = vectorTypeSupport.createByteSequence(pqLength > 0 ? pqLength : 1); + vec.zero(); + return vec; + }); + } + + public void writeHeader() throws IOException { + writer.seek(startOffset); + header.write(writer); + assert writer.position() == startOffset + headerSize : String.format("%d != %d", writer.position(), startOffset + headerSize); + writer.flush(); + } + + void writeFooter() throws IOException { + if (fusedPQEnabled && version == 6 && !level1FeatureRecords.isEmpty()) { + for (UpperLayerFeatureRecord record : level1FeatureRecords) { + writer.writeInt(record.ordinal); + vectorTypeSupport.writeByteSequence(writer, record.pqCode); + } + } + long headerOffset = writer.position(); + header.write(writer); + writer.writeLong(headerOffset); + writer.writeInt(FOOTER_MAGIC); + final long expectedPosition = headerOffset + headerSize + FOOTER_SIZE; + assert writer.position() == expectedPosition : String.format("%d != %d", writer.position(), expectedPosition); + } + + public void offsetAfterInline() throws IOException { + long offset = startOffset + headerSize + (long) (maxOrdinal + 1) * recordSize; + writer.seek(offset); + } + + public Path getOutputPath() { + return outputPath; + } + + public void writeUpperLayerNode(int level, int ordinal, int[] neighbors, ByteSequence level1PqCode) throws IOException { + writer.writeInt(ordinal); + writer.writeInt(neighbors.length); + int degree = configuredLayerDegrees.get(level); + int n = 0; + for (; n < neighbors.length; n++) { + writer.writeInt(neighbors[n]); + } + for (; n < degree; n++) { + writer.writeInt(-1); + } + if (fusedPQEnabled && version == 6 && level == 1 && level1PqCode != null) { + level1FeatureRecords.add(new UpperLayerFeatureRecord(ordinal, level1PqCode.copy())); + } + } + + public void close() throws IOException 
{ + final var endOfGraphPosition = writer.position(); + writer.seek(endOfGraphPosition); + writer.flush(); + } + + public WriteResult writeInlineNodeRecord(int ordinal, VectorFloat vec, SelectedVecCache selectedCache, ByteSequence pqCode) throws IOException + { + var bwriter = new ByteBufferIndexWriter(bufferPerThread.get()); + + long fileOffset = startOffset + headerSize + (long) ordinal * recordSize; + bwriter.reset(); + bwriter.writeInt(ordinal); + + for(int i = 0; i < vec.length(); ++i) { + bwriter.writeFloat(vec.get(i)); + } + + // write fused PQ + // since we build a graph in a streaming way, + // we cannot use fusedPQfeature.writeInline + if (fusedPQEnabled) { + int k = 0; + for (; k < selectedCache.size; k++) { + pqCode.zero(); + pq.encodeTo(selectedCache.vecs[k], pqCode); + vectorTypeSupport.writeByteSequence(bwriter, pqCode); + } + for (; k < baseDegree; k++) { + vectorTypeSupport.writeByteSequence(bwriter, zeroPQ.get()); + } + } + + // write neighbors list + bwriter.writeInt(selectedCache.size); + int n = 0; + for (; n < selectedCache.size; n++) { + bwriter.writeInt(selectedCache.nodes[n]); + } + + // pad out to base layer degree + for (; n < baseDegree; n++) { + bwriter.writeInt(-1); + } + + if (bwriter.bytesWritten() != recordSize) { + throw new IllegalStateException( + String.format("Record size mismatch for ordinal %d: expected %d bytes, wrote %d bytes, base degree: %d", + ordinal, recordSize, bwriter.bytesWritten(), baseDegree)); + } + + ByteBuffer dataCopy = bwriter.cloneBuffer(); + + return new WriteResult(ordinal, fileOffset, dataCopy); + } + + static final class UpperLayerFeatureRecord { + final int ordinal; + final ByteSequence pqCode; + + UpperLayerFeatureRecord(int ordinal, ByteSequence pqCode) { + this.ordinal = ordinal; + this.pqCode = pqCode; + } + } +} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexCompactor.java 
b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexCompactor.java new file mode 100644 index 000000000..9200b2763 --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexCompactor.java @@ -0,0 +1,1227 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.graph.disk; + +import java.io.FileNotFoundException; +import java.io.IOException; +import java.nio.file.Path; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.file.StandardOpenOption; +import java.util.*; +import java.util.concurrent.*; +import java.util.stream.IntStream; +import io.github.jbellis.jvector.graph.*; +import io.github.jbellis.jvector.graph.disk.feature.Feature; +import io.github.jbellis.jvector.graph.disk.feature.FeatureId; +import io.github.jbellis.jvector.graph.disk.feature.FusedPQ; +import io.github.jbellis.jvector.graph.similarity.DefaultSearchScoreProvider; +import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider; +import io.github.jbellis.jvector.util.*; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.graph.similarity.ScoreFunction; +import io.github.jbellis.jvector.vector.types.VectorFloat; +import io.github.jbellis.jvector.quantization.ProductQuantization; +import io.github.jbellis.jvector.vector.VectorizationProvider; +import 
io.github.jbellis.jvector.vector.types.VectorTypeSupport; +import io.github.jbellis.jvector.vector.types.ByteSequence; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static java.lang.Math.*; + +public final class OnDiskGraphIndexCompactor implements Accountable { + private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport(); + private static final Logger log = LoggerFactory.getLogger(OnDiskGraphIndexCompactor.class); + + // Compaction constants + private static final float DIVERSITY_ALPHA_STEP = 0.2f; + private static final int BEAM_WIDTH_MULTIPLIER = 2; + private static final int TARGET_BATCHES_PER_SOURCE = 40; + private static final int TARGET_NODES_PER_BATCH = 128; + private static final int MIN_SEARCH_TOP_K = 2; + private static final int SEARCH_TOP_K_MULTIPLIER = 2; + + private final List sources; + private final List liveNodes; + private final List numLiveNodesPerSource; + private final List remappers; + private final List maxDegrees; + + private final int dimension; + private int maxOrdinal = -1; + private int numTotalNodes = 0; + private boolean ownsExecutor = false; + private final ForkJoinPool executor; + private final int taskWindowSize; + private final VectorSimilarityFunction similarityFunction; + + /** + * Constructs a new OnDiskGraphIndexCompactor to merge multiple graph indexes. + * Initializes thread pool, validates inputs, and prepares metadata for compaction. 
+ */ + public OnDiskGraphIndexCompactor( + List sources, + List liveNodes, + List remappers, + VectorSimilarityFunction similarityFunction, + ForkJoinPool executor) { + checkBeforeCompact(sources, liveNodes, remappers); + + int threads = Runtime.getRuntime().availableProcessors(); + if (executor != null) { + this.executor = executor; + } else { + this.executor = new ForkJoinPool(threads); + this.ownsExecutor = true; + } + this.taskWindowSize = threads; + + this.sources = sources; + this.remappers = remappers; + this.liveNodes = liveNodes; + this.numLiveNodesPerSource = new ArrayList<>(this.sources.size()); + for (int s = 0; s < this.sources.size(); s++) { + int numLiveNodes = this.liveNodes.get(s).cardinality(); + this.numTotalNodes += numLiveNodes; + this.numLiveNodesPerSource.add(numLiveNodes); + } + + maxDegrees = this.sources.stream() + .max(Comparator.comparingInt(s -> s.maxDegrees().size())) + .orElseThrow() + .maxDegrees(); + dimension = this.sources.get(0).getDimension(); + for (var mapper : remappers) { + maxOrdinal = max(mapper.maxOrdinal(), maxOrdinal); + } + this.similarityFunction = similarityFunction; + } + + /** + * Validates that all source indexes have compatible configurations and required features + * before attempting compaction. Ensures consistent dimensions, max degrees, hierarchical + * settings, and feature sets across all sources. + */ + private void checkBeforeCompact( + List sources, + List liveNodes, + List remappers) { + validateInputSizes(sources, liveNodes, remappers); + validateLiveNodesBounds(sources, liveNodes); + validateGraphConfiguration(sources); + validateFeatures(sources); + } + + /** + * Validates that input lists have consistent sizes and are non-null. 
+ */ + private void validateInputSizes(List sources, + List liveNodes, + List remappers) { + if (sources.size() < 2) { + throw new IllegalArgumentException("Must have at least two sources"); + } + Objects.requireNonNull(liveNodes, "liveNodes"); + Objects.requireNonNull(remappers, "remappers"); + + if (sources.size() != liveNodes.size()) { + throw new IllegalArgumentException("sources and liveNodes must have the same size"); + } + if (sources.size() != remappers.size()) { + throw new IllegalArgumentException("sources and remappers must have the same size"); + } + } + + /** + * Validates that liveNodes bitsets match the size of their corresponding sources. + */ + private void validateLiveNodesBounds(List sources, List liveNodes) { + for (int s = 0; s < sources.size(); ++s) { + if (liveNodes.get(s).length() != sources.get(s).size(0)) { + throw new IllegalArgumentException("source " + s + " out of bounds"); + } + } + } + + /** + * Validates that all sources have consistent graph configuration (dimensions, degrees, hierarchy). 
+ */ + private void validateGraphConfiguration(List sources) { + int dimension = sources.get(0).getDimension(); + var refDegrees = sources.stream() + .max(Comparator.comparingInt(s -> s.maxDegrees().size())) + .orElseThrow() + .maxDegrees(); + var addHierarchy = sources.get(0).isHierarchical(); + + for (OnDiskGraphIndex source : sources) { + if (source.getDimension() != dimension) { + throw new IllegalArgumentException("sources must have the same dimension"); + } + int sharedLevels = Math.min(refDegrees.size(), source.maxDegrees().size()); + for (int d = 0; d < sharedLevels; d++) { + if (!Objects.equals(source.maxDegrees().get(d), refDegrees.get(d))) { + throw new IllegalArgumentException("sources must have the same max degrees"); + } + } + if (addHierarchy != source.isHierarchical()) { + throw new IllegalArgumentException("sources must have the same hierarchical setting"); + } + } + } + + /** + * Validates that all sources have compatible features for compaction. + */ + private void validateFeatures(List sources) { + Set refKeys = sources.get(0).getFeatures().keySet(); + boolean sameFeatures = sources.stream() + .skip(1) + .map(s -> s.getFeatures().keySet()) + .allMatch(refKeys::equals); + + if (!sameFeatures) { + throw new IllegalArgumentException("Each source must have the same features"); + } + if (!refKeys.contains(FeatureId.INLINE_VECTORS)) { + throw new IllegalArgumentException("Each source must have the INLINE_VECTORS feature"); + } + } + + /** + * Main compaction entry point. Merges all source indexes into a single output index at the + * specified path, handling PQ retraining if needed, and writing header, all layers, and footer. 
+ */ + public void compact(Path outputPath) throws FileNotFoundException { + boolean fusedPQEnabled = hasFusedPQ(); + boolean compressedPrecision = fusedPQEnabled; + + ProductQuantization pq; + int pqLength; + if (fusedPQEnabled) { + pq = resolvePQFromSources(similarityFunction); + pqLength = pq.compressedVectorSize(); + } else { + pq = null; + pqLength = -1; + } + + List layerInfo = computeLayerInfoFromSources(); + int entryNode = resolveEntryNode(); + + log.info("Writing compacted graph : {} total nodes, maxOrdinal={}, dimension={}, degree={}", + numTotalNodes, maxOrdinal, dimension, maxDegrees.get(0)); + try (CompactWriter writer = new CompactWriter(outputPath, maxOrdinal, numTotalNodes, 0, layerInfo, entryNode, dimension, maxDegrees, pq, pqLength, fusedPQEnabled)) { + writer.writeHeader(); + compactLevels(writer, similarityFunction, fusedPQEnabled, compressedPrecision, pq); + writer.writeFooter(); + log.info("Compaction complete: {}", outputPath); + } catch (IOException | ExecutionException | InterruptedException e) { + throw new RuntimeException(e); + } finally { + if (ownsExecutor) executor.shutdown(); + } + } + + /** + * Resolves the entry node for the compacted graph. The chosen node must exist at maxLevel + * (since the on-disk format sets entryNode.level = maxLevel). Prefers the designated entry + * node of any source whose maxLevel equals the global maxLevel; if all such entry nodes + * are deleted, falls back to the first live node at maxLevel across all sources. + */ + private int resolveEntryNode() { + int maxLevel = sources.stream().mapToInt(OnDiskGraphIndex::getMaxLevel).max().orElse(0); + + // The on-disk format sets entryNode.level = layerInfo.size() - 1 (i.e. maxLevel). + // So the chosen node must actually have neighbors written at maxLevel — meaning it + // must exist at maxLevel in its source. Prefer the designated entry node of a + // maxLevel source; fall back to any live node that is at maxLevel. 
+ for (int s = 0; s < sources.size(); s++) { + if (sources.get(s).getMaxLevel() == maxLevel) { + int originalEntry = sources.get(s).getView().entryNode().node; + if (liveNodes.get(s).get(originalEntry)) { + return remappers.get(s).oldToNew(originalEntry); + } + } + } + + // Entry nodes were all deleted: scan for any live node that exists at maxLevel. + for (int s = 0; s < sources.size(); s++) { + if (sources.get(s).getMaxLevel() < maxLevel) continue; + NodesIterator it = sources.get(s).getNodes(maxLevel); + while (it.hasNext()) { + int node = it.next(); + if (liveNodes.get(s).get(node)) { + return remappers.get(s).oldToNew(node); + } + } + } + + throw new IllegalStateException("No live nodes found at maxLevel=" + maxLevel); + } + + /** + * Compacts all hierarchical levels of the graph, processing each level in batches. + * For level 0 (base layer), writes inline vectors and neighbors. For upper layers, + * writes only graph structure and optional PQ codes. + */ + private void compactLevels(CompactWriter writer, + VectorSimilarityFunction similarityFunction, + boolean fusedPQEnabled, + boolean compressedPrecision, + ProductQuantization pq) + throws IOException, ExecutionException, InterruptedException { + + int maxUpperDegree = 0; + for (int level = 1; level < maxDegrees.size(); level++) { + maxUpperDegree = Math.max(maxUpperDegree, maxDegrees.get(level)); + } + + int baseSearchTopK = Math.max(MIN_SEARCH_TOP_K, ((maxDegrees.get(0) + sources.size() - 1) / sources.size()) * SEARCH_TOP_K_MULTIPLIER); + int baseMaxCandidateSize = baseSearchTopK * (sources.size() - 1) + maxDegrees.get(0); + int upperMaxPerSourceTopK = maxUpperDegree == 0 ? 
0 : Math.max(MIN_SEARCH_TOP_K, ((maxUpperDegree + sources.size() - 1) / sources.size()) * SEARCH_TOP_K_MULTIPLIER); + int upperMaxCandidateSize = upperMaxPerSourceTopK * sources.size(); + int maxCandidateSize = Math.max(baseMaxCandidateSize, upperMaxCandidateSize); + int scratchDegree = Math.max(maxDegrees.get(0), Math.max(1, maxUpperDegree)); + final ThreadLocal threadLocalScratch = ThreadLocal.withInitial(() -> + new Scratch(maxCandidateSize, scratchDegree, dimension, sources, pq) + ); + + for (int level = 0; level < maxDegrees.size(); level++) { + List batches = buildBatches(level); + int searchTopK = Math.max(MIN_SEARCH_TOP_K, ((maxDegrees.get(level) + sources.size() - 1) / sources.size()) * SEARCH_TOP_K_MULTIPLIER); + int beamWidth = Math.max(maxDegrees.get(level), searchTopK) * BEAM_WIDTH_MULTIPLIER; + + CompactionParams params = new CompactionParams(fusedPQEnabled, compressedPrecision, searchTopK, beamWidth, pq); + + if (level == 0) { + log.info("Compacting level 0 (base layer)"); + + ExecutorCompletionService> ecs = + new ExecutorCompletionService<>(executor); + + java.util.function.Consumer submitOne = (bs) -> { + ecs.submit(() -> { + Scratch scratch = threadLocalScratch.get(); + return computeBaseBatch(writer, bs, scratch, params); + }); + }; + + var wropts = EnumSet.of(StandardOpenOption.WRITE, StandardOpenOption.READ); + try (FileChannel fc = FileChannel.open(writer.getOutputPath(), wropts)) { + + runBatchesWithBackpressure( + batches, + ecs, + submitOne, + (results) -> { + try { + for (WriteResult r : results) { + ByteBuffer b = r.data; + long pos = r.fileOffset; + while (b.hasRemaining()) { + int n = fc.write(b, pos); + pos += n; + } + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + ); + } + + writer.offsetAfterInline(); + + } else { + final int lvl = level; + log.info("Compacting upper layer {}", level); + + ExecutorCompletionService> ecs = + new ExecutorCompletionService<>(executor); + + java.util.function.Consumer submitOne 
= (bs) -> { + ecs.submit(() -> { + Scratch scratch = threadLocalScratch.get(); + return computeUpperBatchForLevel(bs, lvl, scratch, params); + }); + }; + + runBatchesWithBackpressure( + batches, + ecs, + submitOne, + (results) -> { + try { + for (UpperLayerWriteResult r : results) { + writer.writeUpperLayerNode( + lvl, + r.ordinal, + r.neighbors, + r.pqCode + ); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + ); + } + } + + Scratch s = threadLocalScratch.get(); + s.close(); + threadLocalScratch.remove(); + } + + /** + * Divides nodes at a given level across all source indexes into processing batches + * for parallel execution. Each batch contains a subset of nodes from one source. + */ + private List buildBatches(int level) { + List batches = new ArrayList<>(); + + for (int s = 0; s < sources.size(); ++s) { + var source = sources.get(s); + if (level > source.getMaxLevel()) continue; + NodesIterator sourceNodes = source.getNodes(level); + int numNodes = sourceNodes.size(); + int[] nodes = new int[numNodes]; + int i = 0; + while (sourceNodes.hasNext()) { + nodes[i++] = sourceNodes.next(); + } + + int numBatches = max(TARGET_BATCHES_PER_SOURCE, (numNodes + TARGET_NODES_PER_BATCH - 1) / TARGET_NODES_PER_BATCH); + if (numBatches > numNodes) numBatches = numNodes; + int batchSize = (numNodes + numBatches - 1) / numBatches; + for (int b = 0; b < numBatches; ++b) { + int start = min(numNodes, batchSize * b); + int end = min(numNodes, batchSize * (b + 1)); + batches.add(new BatchSpec(s, nodes, start, end)); + } + } + + return batches; + } + + /** + * Processes a batch of base layer (level 0) nodes from one source index. For each live node, + * gathers candidates from all sources, applies diversity selection, and creates write results + * containing the full node record data. 
+ */ + private List computeBaseBatch(CompactWriter writer, + BatchSpec bs, + Scratch scratch, + CompactionParams params) throws IOException { + + List out = new ArrayList<>(bs.end - bs.start); + + for (int i = bs.start; i < bs.end; i++) { + int node = bs.nodes[i]; + if (!liveNodes.get(bs.sourceIdx).get(node)) continue; + + out.add(processBaseNode(node, bs.sourceIdx, scratch, writer, params)); + } + + return out; + } + + /** + * Processes a batch of upper layer nodes from one source index. Similar to base layer + * processing but returns only ordinal, neighbors, and optional PQ code (no inline vectors). + */ + private List computeUpperBatchForLevel( + BatchSpec bs, + int level, + Scratch scratch, + CompactionParams params + ) { + List results = + new ArrayList<>(bs.end - bs.start); + + for (int i = bs.start; i < bs.end; i++) { + int node = bs.nodes[i]; + + if (!liveNodes.get(bs.sourceIdx).get(node)) continue; + + results.add(processUpperNode(node, bs.sourceIdx, level, scratch, params)); + } + + return results; + } + + /** + * Processes a single base layer node: retrieves its vector, gathers diverse candidates from + * all sources, selects best neighbors using diversity criteria, remaps ordinals, and returns + * the complete write result for this node. 
 *
 * @param node      the node's ordinal within its source index
 * @param sourceIdx index of the source graph containing the node
 * @param scratch   per-thread reusable buffers and graph searchers
 * @param writer    destination writer for the inline node record
 * @param params    shared compaction parameters
 * @return the write result (new ordinal, file offset, serialized record data)
 * @throws IOException if reading the vector or serializing the record fails
 */
private WriteResult processBaseNode(
        int node,
        int sourceIdx,
        Scratch scratch,
        CompactWriter writer,
        CompactionParams params
) throws IOException {

    // Load the node's full-precision vector into the reusable scratch buffer.
    var sourceView = (OnDiskGraphIndex.View) scratch.gs[sourceIdx].getView();
    sourceView.getVectorInto(node, scratch.baseVec, 0);

    // Collect neighbor candidates from all sources (own neighbors + cross-source searches).
    int candSize = gatherCandidates(node, 0, sourceIdx, scratch, scratch.baseVec, params);

    // Rank candidates best-first before diversity selection.
    int[] order = IntStream.range(0, candSize).toArray();
    sortOrderByScoreDesc(order, scratch.candScore, candSize);

    var selected = scratch.selectedCache;

    // Vamana-style diversity pruning down to the base-layer degree bound.
    // Alpha 1.2f follows CompactVamanaDiversityProvider's guidance (HNSW=1.0, Vamana>=1.2).
    new CompactVamanaDiversityProvider(similarityFunction, 1.2f)
            .retainDiverse(
                    scratch.candSrc,
                    scratch.candNode,
                    scratch.candScore,
                    order,
                    candSize,
                    maxDegrees.get(0),
                    selected,
                    scratch.tmpVec,
                    scratch.gs
            );

    // Remap selected neighbors from per-source ordinals to compacted (global) ordinals.
    for (int k = 0; k < selected.size; k++) {
        selected.nodes[k] =
                remappers.get(selected.sourceIdx[k])
                        .oldToNew(selected.nodes[k]);
    }

    int newOrdinal = remappers.get(sourceIdx).oldToNew(node);

    return writer.writeInlineNodeRecord(
            newOrdinal,
            scratch.baseVec,
            selected,
            scratch.pqCode
    );
}

/**
 * Processes a single upper layer node: similar to base layer processing but only returns
 * graph structure (ordinal and neighbors) and optional PQ encoding for level 1.
 *
 * @param node      the node's ordinal within its source index
 * @param sourceIdx index of the source graph containing the node
 * @param level     the graph layer being compacted (&gt;= 1)
 * @param scratch   per-thread reusable buffers and graph searchers
 * @param params    shared compaction parameters
 * @return the remapped ordinal, selected neighbors, and optional PQ code
 */
private UpperLayerWriteResult processUpperNode(
        int node,
        int sourceIdx,
        int level,
        Scratch scratch,
        CompactionParams params
) {
    // Load the node's full-precision vector into the reusable scratch buffer.
    var sourceView = (OnDiskGraphIndex.View) scratch.gs[sourceIdx].getView();
    sourceView.getVectorInto(node, scratch.baseVec, 0);

    int candSize = gatherCandidates(node, level, sourceIdx, scratch, scratch.baseVec, params);

    // Rank candidates best-first before diversity selection.
    int[] order = IntStream.range(0, candSize).toArray();
    sortOrderByScoreDesc(order, scratch.candScore, candSize);

    var selected = scratch.selectedCache;

    // Alpha 1.2f follows CompactVamanaDiversityProvider's guidance (HNSW=1.0, Vamana>=1.2).
    new CompactVamanaDiversityProvider(similarityFunction, 1.2f)
            .retainDiverse(
                    scratch.candSrc,
                    scratch.candNode,
                    scratch.candScore,
                    order,
                    candSize,
                    maxDegrees.get(level),
                    selected,
                    scratch.tmpVec,
                    scratch.gs
            );

    // Remap selected neighbors from per-source ordinals to compacted (global) ordinals.
    for (int k = 0; k < selected.size; k++) {
        selected.nodes[k] =
                remappers.get(selected.sourceIdx[k])
                        .oldToNew(selected.nodes[k]);
    }

    int newOrdinal = remappers.get(sourceIdx).oldToNew(node);

    ByteSequence pqCode = maybeEncodePQ(level, scratch, params);

    return new UpperLayerWriteResult(newOrdinal, selected, pqCode);
}

/**
 * Encodes a vector using Product Quantization if enabled and the level is 1.
 * Returns null otherwise.
 */
private ByteSequence maybeEncodePQ(int level, Scratch scratch, CompactionParams params) {
    if (!params.fusedPQEnabled || level != 1) {
        return null;
    }

    // Encode scratch.baseVec into the reusable code buffer, then return a copy so the
    // shared scratch buffer can be reused for the next node.
    scratch.pqCode.zero();
    params.pq.encodeTo(scratch.baseVec, scratch.pqCode);
    return scratch.pqCode.copy();
}

/**
 * Collects neighbor candidates for a node from all source indexes. For the source containing
 * the node, uses existing neighbors; for other sources, performs graph search. Returns the
 * total number of candidates gathered.
 *
 * @param node      the node's ordinal within its source index
 * @param level     the graph layer being processed
 * @param sourceIdx index of the source graph containing the node
 * @param scratch   per-thread reusable buffers; candidates land in candSrc/candNode/candScore
 * @param baseVec   the node's own vector, used for scoring candidates
 * @param params    shared compaction parameters
 * @return the number of candidates written into the scratch arrays
 */
private int gatherCandidates(
        int node,
        int level,
        int sourceIdx,
        Scratch scratch,
        VectorFloat baseVec,
        CompactionParams params
) {
    int candSize = 0;

    for (int ss = 0; ss < sources.size(); ss++) {
        var searchView = (OnDiskGraphIndex.View) scratch.gs[ss].getView();
        var indexAlive = liveNodes.get(ss);

        if (ss == sourceIdx) {
            // Own source: reuse the node's existing neighbor list.
            candSize = gatherFromSameSource(node, level, ss, searchView, indexAlive,
                    baseVec, scratch, candSize);
        } else {
            // Other sources: run a graph search to find nearby nodes.
            candSize = gatherFromOtherSource(node, level, ss, searchView, indexAlive,
                    baseVec, scratch, candSize, params);
        }
    }

    return candSize;
}

/**
 * Gathers candidates from the same source index that contains the node.
 * Simply iterates through existing neighbors.
 *
 * @return the updated candidate count after appending this source's live neighbors
 */
private int gatherFromSameSource(int node, int level, int sourceIdx,
                                 OnDiskGraphIndex.View searchView, FixedBitSet indexAlive,
                                 VectorFloat baseVec, Scratch scratch, int candSize) {
    var it = searchView.getNeighborsIterator(level, node);
    while (it.hasNext()) {
        int nb = it.nextInt();
        // Deleted neighbors are dropped rather than carried into the compacted graph.
        if (!indexAlive.get(nb)) continue;

        searchView.getVectorInto(nb, scratch.tmpVec, 0);

        // Exact (full-precision) score against the base vector.
        scratch.candSrc[candSize] = sourceIdx;
        scratch.candNode[candSize] = nb;
        scratch.candScore[candSize] = similarityFunction.compare(baseVec, scratch.tmpVec);
        candSize++;
    }
    return candSize;
}

/**
 * Gathers candidates from a different source index via graph search.
+ */ + private int gatherFromOtherSource(int node, int level, int sourceIdx, + OnDiskGraphIndex.View searchView, FixedBitSet indexAlive, + VectorFloat baseVec, Scratch scratch, int candSize, + CompactionParams params) { + SearchScoreProvider ssp = buildCrossSourceScoreProvider( + params.compressedPrecision, + sources.get(sourceIdx), + searchView, + baseVec, + scratch.tmpVec, + similarityFunction + ); + + if (level == 0) { + SearchResult results = scratch.gs[sourceIdx].search( + ssp, params.searchTopK, params.beamWidth, 0f, 0f, indexAlive + ); + + for (var r : results.getNodes()) { + scratch.candSrc[candSize] = sourceIdx; + scratch.candNode[candSize] = r.node; + scratch.candScore[candSize] = + params.fusedPQEnabled + ? rescore(searchView, r.node, baseVec, scratch.tmpVec) + : r.score; + candSize++; + } + } else { + var entry = searchView.entryNode(); + if (level > entry.level) return candSize; + scratch.gs[sourceIdx].initializeInternal(ssp, entry, Bits.ALL); + + // Descend greedily through levels above the target level, so the search at + // `level` starts from the best-known region rather than the global entry node. + // This mirrors how GraphSearcher.searchInternal navigates the hierarchy. 
+ for (int l = entry.level; l > level; l--) { + scratch.gs[sourceIdx].searchOneLayer(ssp, 1, 0f, l, Bits.ALL); + scratch.gs[sourceIdx].setEntryPointsFromPreviousLayer(); + } + + scratch.gs[sourceIdx].searchOneLayer( + ssp, params.searchTopK, 0f, level, indexAlive + ); + + int prev_candSize = candSize; + candSize = appendApproximateResults( + scratch.gs[sourceIdx].approximateResults, + sourceIdx, + scratch, + candSize + ); + + if (params.fusedPQEnabled) { + for (int i = prev_candSize; i < candSize; i++) { + scratch.candScore[i] = rescore( + searchView, + scratch.candNode[i], + baseVec, + scratch.tmpVec + ); + } + } + } + + return candSize; + } + + /** + * Recomputes exact similarity score between the base vector and a node's vector, + * used to refine approximate PQ-based search results. + */ + private float rescore(OnDiskGraphIndex.View view, + int node, + VectorFloat base, + VectorFloat tmp) { + view.getVectorInto(node, tmp, 0); + return similarityFunction.compare(base, tmp); + } + + /** + * Executes batches with controlled concurrency using a sliding window approach. Prevents + * overwhelming memory by limiting the number of in-flight tasks while maintaining high + * throughput via the completion service. 
+ */ + private void runBatchesWithBackpressure( + List batches, + ExecutorCompletionService> ecs, + java.util.function.Consumer submitOne, + java.util.function.Consumer> onComplete + ) throws InterruptedException, ExecutionException { + + final int total = batches.size(); + int nextToSubmit = 0; + int inFlight = 0; + + // initial window + while (inFlight < taskWindowSize && nextToSubmit < total) { + submitOne.accept(batches.get(nextToSubmit++)); + inFlight++; + } + + int completed = 0; + while (completed < total) { + List results = ecs.take().get(); + onComplete.accept(results); + + completed++; + inFlight--; + + if (nextToSubmit < total) { + submitOne.accept(batches.get(nextToSubmit++)); + inFlight++; + } + if (completed % 10 == 0) { + log.info("Compaction I/O progress: {}/{} batches written to disk", completed, total); + } + } + } + + /** + * Appends search results from a NodeQueue to the candidate arrays, returning the updated + * candidate count. + */ + private int appendApproximateResults(NodeQueue queue, + int sourceIdx, + Scratch scratch, + int candSize) { + final int ss = sourceIdx; + final int[] idx = new int[] { candSize }; + + queue.foreach((nb, score) -> { + scratch.candSrc[idx[0]] = ss; + scratch.candNode[idx[0]] = nb; + scratch.candScore[idx[0]] = score; + idx[0]++; + }); + + return idx[0]; + } + + /** + * Computes layer metadata for the compacted graph by counting live nodes at each level + * across all source indexes. 
+ */ + private List computeLayerInfoFromSources() { + int maxLevel = sources.stream().mapToInt(OnDiskGraphIndex::getMaxLevel).max().orElse(0); + List layerInfo = new ArrayList<>(maxLevel + 1); + for (int level = 0; level <= maxLevel; level++) { + int count = 0; + for (int s = 0; s < sources.size(); s++) { + if (level > sources.get(s).getMaxLevel()) continue; + NodesIterator it = sources.get(s).getNodes(level); + FixedBitSet alive = liveNodes.get(s); + while (it.hasNext()) { + int node = it.next(); + if (alive.get(node)) count++; + } + } + layerInfo.add(new CommonHeader.LayerInfo(count, maxDegrees.get(level))); + } + return layerInfo; + } + + /** + * Trains a new Product Quantization codebook using balanced sampling across all source + * indexes. This ensures the PQ is optimized for the combined dataset. + */ + private ProductQuantization resolvePQFromSources(VectorSimilarityFunction similarityFunction) { + PQRetrainer retrainer = new PQRetrainer(sources, liveNodes, dimension); + return retrainer.retrain(similarityFunction); + } + + /** + * Checks if the source indexes have FusedPQ feature enabled. + */ + private boolean hasFusedPQ() { + return sources.get(0).getFeatures().containsKey(FeatureId.FUSED_PQ); + } + + /** + * Creates a score provider for searching across different source indexes. Uses approximate + * PQ-based scoring if compressedPrecision is enabled, otherwise uses exact scoring. 
+ */ + private SearchScoreProvider buildCrossSourceScoreProvider(boolean compressedPrecision, + OnDiskGraphIndex searchSource, + OnDiskGraphIndex.View searchView, + VectorFloat baseVec, + VectorFloat tmpVec, + VectorSimilarityFunction similarityFunction) { + if (compressedPrecision) { + ScoreFunction.ExactScoreFunction reranker = + node2 -> { + searchView.getVectorInto(node2, tmpVec, 0); + return similarityFunction.compare(baseVec, tmpVec); + }; + var asf = ((FusedPQ) searchSource.getFeatures().get(FeatureId.FUSED_PQ)).approximateScoreFunctionFor(baseVec, similarityFunction, searchView, reranker); + + return new DefaultSearchScoreProvider(asf); + } + + var sf = new ScoreFunction.ExactScoreFunction() { + @Override + public float similarityTo(int node2) { + searchView.getVectorInto(node2, tmpVec, 0); + return similarityFunction.compare(baseVec, tmpVec); + } + }; + return new DefaultSearchScoreProvider(sf); + } + + /** + * Estimates the RAM usage of this compactor instance. + * Accounts for data structures used during compaction including bitsets, remappers, + * executor overhead, and per-thread scratch space. 
+ */ + @Override + public long ramBytesUsed() { + int OH = RamUsageEstimator.NUM_BYTES_OBJECT_HEADER; + int REF = RamUsageEstimator.NUM_BYTES_OBJECT_REF; + + // Shallow size of this object (header + fields) + // Current fields: sources, liveNodes, numLiveNodesPerSource, remappers, maxDegrees, + // dimension(int), maxOrdinal(int), numTotalNodes(int), + // ownsExecutor(boolean), executor, taskWindowSize(int), similarityFunction + long size = OH + 8L * REF + Integer.BYTES * 4 + 1; + + // liveNodes: FixedBitSet per source + for (var entry : liveNodes) { + size += entry.ramBytesUsed(); + } + + // numLiveNodesPerSource: ArrayList of Integers + size += OH + REF + (long) numLiveNodesPerSource.size() * (OH + Integer.BYTES); + + // remappers: each MapMapper holds an oldToNew HashMap and newToOld Int2IntHashMap + // Estimate based on the number of mappings + for (var mapper : remappers) { + // Object overhead + two maps with int key/value pairs + // HashMap entry: ~32 bytes each; Int2IntHashMap: ~16 bytes per entry + if (mapper instanceof OrdinalMapper.MapMapper) { + // rough estimate: the mapper stores two maps over all mapped ordinals + size += OH + (long) (maxOrdinal + 1) * 48; + } + } + + // maxDegrees: small list of integers + size += OH + REF + (long) maxDegrees.size() * (OH + Integer.BYTES); + + // executor: ForkJoinPool overhead (if owned) + // Estimate based on number of threads + int numThreads = ownsExecutor ? 
Runtime.getRuntime().availableProcessors() : taskWindowSize; + if (ownsExecutor) { + size += OH + REF; + } + + // Scratch space: ThreadLocal instances (one per active thread) + // Each Scratch contains: + // - candSrc, candNode, candScore arrays + // - SelectedVecCache (with its own arrays and vector copies) + // - tmpVec, baseVec (VectorFloat instances) + // - GraphSearcher array (one per source) + // - pqCode ByteSequence + size += estimateScratchSpacePerThread() * numThreads; + + return size; + } + + /** + * Estimates the RAM usage of a single Scratch instance. + */ + private long estimateScratchSpacePerThread() { + int OH = RamUsageEstimator.NUM_BYTES_OBJECT_HEADER; + int REF = RamUsageEstimator.NUM_BYTES_OBJECT_REF; + + // Calculate maxCandidateSize and maxDegree (same logic as in compactLevels) + int maxUpperDegree = 0; + for (int level = 1; level < maxDegrees.size(); level++) { + maxUpperDegree = Math.max(maxUpperDegree, maxDegrees.get(level)); + } + int baseSearchTopK = Math.max(MIN_SEARCH_TOP_K, ((maxDegrees.get(0) + sources.size() - 1) / sources.size()) * SEARCH_TOP_K_MULTIPLIER); + int baseMaxCandidateSize = baseSearchTopK * (sources.size() - 1) + maxDegrees.get(0); + int upperMaxPerSourceTopK = maxUpperDegree == 0 ? 
0 : Math.max(MIN_SEARCH_TOP_K, ((maxUpperDegree + sources.size() - 1) / sources.size()) * SEARCH_TOP_K_MULTIPLIER); + int upperMaxCandidateSize = upperMaxPerSourceTopK * sources.size(); + int maxCandidateSize = Math.max(baseMaxCandidateSize, upperMaxCandidateSize); + int scratchDegree = Math.max(maxDegrees.get(0), Math.max(1, maxUpperDegree)); + + long scratchSize = OH + 6L * REF; + + // candSrc, candNode, candScore arrays + scratchSize += (long) maxCandidateSize * Integer.BYTES; // candSrc + scratchSize += (long) maxCandidateSize * Integer.BYTES; // candNode + scratchSize += (long) maxCandidateSize * Float.BYTES; // candScore + + // SelectedVecCache + scratchSize += OH + 5L * REF + Integer.BYTES; // SelectedVecCache object + scratchSize += (long) scratchDegree * Integer.BYTES; // sourceIdx array + scratchSize += (long) scratchDegree * REF; // views array + scratchSize += (long) scratchDegree * Integer.BYTES; // nodes array + scratchSize += (long) scratchDegree * Float.BYTES; // scores array + scratchSize += (long) scratchDegree * REF; // vecs array + scratchSize += (long) scratchDegree * (OH + dimension * Float.BYTES); // VectorFloat instances + + // tmpVec and baseVec + scratchSize += 2L * (OH + dimension * Float.BYTES); + + // GraphSearcher array (one per source) + scratchSize += (long) sources.size() * REF; + // Each GraphSearcher has internal state - rough estimate + scratchSize += (long) sources.size() * (OH + 10L * REF); + + // pqCode ByteSequence (if PQ enabled) + if (hasFusedPQ()) { + FusedPQ fpq = (FusedPQ) sources.get(0).getFeatures().get(FeatureId.FUSED_PQ); + int subspaceCount = fpq.getPQ().getSubspaceCount(); + scratchSize += OH + subspaceCount; // ByteSequence + } + + return scratchSize; + } + + /** + * Encapsulates common parameters used throughout the compaction process. 
 */
private static final class CompactionParams {
    // Whether FusedPQ approximate scoring is available for cross-source searches.
    final boolean fusedPQEnabled;
    // Whether cross-source searches score with compressed (PQ) precision plus exact rerank.
    final boolean compressedPrecision;
    // Number of results requested from each cross-source graph search.
    final int searchTopK;
    // Beam width used for base-layer cross-source searches.
    final int beamWidth;
    // Codebook used to re-encode vectors when FusedPQ is enabled; may be null otherwise.
    final ProductQuantization pq;

    CompactionParams(boolean fusedPQEnabled, boolean compressedPrecision,
                     int searchTopK, int beamWidth, ProductQuantization pq) {
        this.fusedPQEnabled = fusedPQEnabled;
        this.compressedPrecision = compressedPrecision;
        this.searchTopK = searchTopK;
        this.beamWidth = beamWidth;
        this.pq = pq;
    }
}

/**
 * Sorts an index array by descending score values using quicksort.
 * Only the first {@code size} entries of {@code order} are reordered; {@code score}
 * itself is never modified — entries of {@code order} are indices into it.
 */
private static void sortOrderByScoreDesc(int[] order, float[] score, int size) {
    quicksort(order, score, 0, size - 1);
}

/**
 * Quicksort over {@code order}, comparing {@code score[order[i]]} in descending order.
 * Note: Java does not eliminate tail calls; stack depth is bounded to O(log n) by
 * recursing into the smaller partition and looping on the larger one instead.
 */
private static void quicksort(int[] order, float[] score, int lo, int hi) {
    while (lo < hi) {
        int p = partition(order, score, lo, hi);
        // recurse smaller side first (limits stack)
        if (p - lo < hi - p) {
            quicksort(order, score, lo, p - 1);
            lo = p + 1;
        } else {
            quicksort(order, score, p + 1, hi);
            hi = p - 1;
        }
    }
}

/**
 * Partitions the order array for quicksort using descending score comparison.
+ */ + private static int partition(int[] order, float[] score, int lo, int hi) { + float pivot = score[order[hi]]; + int i = lo; + for (int j = lo; j < hi; j++) { + if (score[order[j]] > pivot) { // DESC + int t = order[i]; + order[i] = order[j]; + order[j] = t; + i++; + } + } + int t = order[i]; + order[i] = order[hi]; + order[hi] = t; + return i; + } + + static final class WriteResult { + final int newOrdinal; + final long fileOffset; + final ByteBuffer data; + + WriteResult(int newOrdinal, long fileOffset, ByteBuffer data) { + this.newOrdinal = newOrdinal; + this.fileOffset = fileOffset; + this.data = data; + } + }; + + private static final class UpperLayerWriteResult { + final int ordinal; + final int[] neighbors; + final ByteSequence pqCode; + + UpperLayerWriteResult(int ordinal, SelectedVecCache cache, ByteSequence pqCode) { + this.ordinal = ordinal; + this.neighbors = Arrays.copyOf(cache.nodes, cache.size); + this.pqCode = pqCode == null ? null : pqCode.copy(); + } + }; + + + /** + * Thread-local scratch space containing reusable buffers and search state for processing nodes. + */ + private static final class Scratch implements AutoCloseable { + + final int[] candSrc, candNode; + final float[] candScore; + final SelectedVecCache selectedCache; + final VectorFloat tmpVec, baseVec; + final GraphSearcher[] gs; + final ByteSequence pqCode; + + /** + * Constructs scratch space with buffers sized for the maximum expected candidates and degree. + */ + Scratch(int maxCandidateSize, int maxDegree, int dimension, List sources, ProductQuantization pq) { + this.candSrc = new int[maxCandidateSize]; + this.candNode = new int[maxCandidateSize]; + this.candScore = new float[maxCandidateSize]; + this.selectedCache = new SelectedVecCache(maxDegree, dimension); + this.tmpVec = vectorTypeSupport.createFloatVector(dimension); + this.baseVec = vectorTypeSupport.createFloatVector(dimension); + this.pqCode = (pq == null) ? 
null : vectorTypeSupport.createByteSequence(pq.getSubspaceCount()); + + this.gs = new GraphSearcher[sources.size()]; + for (int i = 0; i < sources.size(); i++) { + gs[i] = new GraphSearcher(sources.get(i)); + gs[i].usePruning(false); + } + } + + /** + * Closes all graph searchers and resets the cache. + */ + @Override + public void close() throws IOException { + for (var s : gs) s.close(); + selectedCache.reset(); + } + } + + /** + * Specification for a batch of nodes to be processed from one source index. + */ + private static final class BatchSpec { + final int sourceIdx; + final int[] nodes; // materialized node ids for this source + final int start; + final int end; + + BatchSpec(int sourceIdx, int[] nodes, int start, int end) { + this.sourceIdx = sourceIdx; + this.nodes = nodes; + this.start = start; + this.end = end; + } + } + + /** + * Provides Vamana-style diversity filtering for neighbor selection during compaction. + */ + private static final class CompactVamanaDiversityProvider { + /** + * the diversity threshold; 1.0 is equivalent to HNSW; Vamana uses 1.2 or more + */ + public final float alpha; + + /** + * used to compute diversity + */ + public final VectorSimilarityFunction vsf; + + /** + * Create a new diversity provider + */ + public CompactVamanaDiversityProvider(VectorSimilarityFunction vsf, float alpha) { + this.vsf = vsf; + this.alpha = alpha; + } + + /** + * Selects diverse neighbors from candidates using gradually increasing alpha threshold. + * Update `selected` with the diverse members of `neighbors`. `neighbors` is not modified + * It assumes that the i-th neighbor with 0 {@literal <=} i {@literal <} diverseBefore is already diverse. 
+ */ + public void retainDiverse(int[] candSrc, int[] candNode, float[] candScore, int[] order, int orderSize, int maxDegree, SelectedVecCache selectedCache, VectorFloat tmp, GraphSearcher[] gs) { + selectedCache.reset(); + if (orderSize == 0) return; + int nSelected = 0; + + // add diverse candidates, gradually increasing alpha to the threshold + // (so that the nearest candidates are prioritized) + float currentAlpha = 1.0f; + while (currentAlpha <= alpha + 1E-6 && nSelected < maxDegree) { + for (int i = 0; i < orderSize && nSelected < maxDegree; i++) { + int ci = order[i]; + int cSrc = candSrc[ci]; + int cNode = candNode[ci]; + float cScore = candScore[ci]; + + OnDiskGraphIndex.View cView = (OnDiskGraphIndex.View) gs[cSrc].getView(); + cView.getVectorInto(cNode, tmp, 0); + if (isDiverse(cView, cNode, tmp, cScore, currentAlpha, selectedCache)) { + selectedCache.add(cSrc, cView, cNode, cScore, tmp); + nSelected++; + } + } + + currentAlpha += DIVERSITY_ALPHA_STEP; + } + } + + /** + * Checks if a candidate is diverse enough by ensuring it's closer to the base node + * than to any already-selected neighbor (scaled by alpha threshold). + */ + private boolean isDiverse(OnDiskGraphIndex.View cView, int cNode, VectorFloat cVec, float cScore, float alpha, SelectedVecCache selectedCache) { + for (int j = 0; j < selectedCache.size; j++) { + if (selectedCache.views[j] == cView && selectedCache.nodes[j] == cNode) { + return false; // already selected; don't add a duplicate + } + if (vsf.compare(cVec, selectedCache.vecs[j]) > cScore * alpha) { + return false; + } + } + return true; + } + + } + + /** + * Cache for storing selected diverse neighbors along with their metadata and vector copies. + */ + static final class SelectedVecCache { + int[] sourceIdx; + OnDiskGraphIndex.View[] views; + int[] nodes; + float[] scores; + VectorFloat[] vecs; + int size; + + /** + * Constructs a cache with the specified capacity and vector dimension. 
+ */ + SelectedVecCache(int capacity, int dimension) { + sourceIdx = new int[capacity]; + views = new OnDiskGraphIndex.View[capacity]; + nodes = new int[capacity]; + scores = new float[capacity]; + vecs = new VectorFloat[capacity]; + for(int c = 0; c < capacity; ++c) { + vecs[c] = vectorTypeSupport.createFloatVector(dimension); + } + size = 0; + } + + /** + * Resets the cache for reuse. + */ + void reset() { + size = 0; + } + + /** + * Adds a selected neighbor to the cache, copying its vector. + */ + void add(int source, OnDiskGraphIndex.View view, int node, float score, VectorFloat vec) { + sourceIdx[size] = source; + views[size] = view; + nodes[size] = node; + scores[size] = score; + vecs[size].copyFrom(vec, 0, 0, vec.length()); + size++; + } + } + +} + diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OrdinalMapper.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OrdinalMapper.java index 526241eff..445255980 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OrdinalMapper.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OrdinalMapper.java @@ -106,4 +106,37 @@ public int newToOld(int newOrdinal) { return newToOld.get(newOrdinal); } } + + /** + * A mapper that applies a fixed offset to ordinals. + * Used for sequential mapping where local ordinal i maps to globalOffset + i. 
+ */ + class OffsetMapper implements OrdinalMapper { + private final int offset; + private final int size; + + public OffsetMapper(int offset, int size) { + this.offset = offset; + this.size = size; + } + + @Override + public int maxOrdinal() { + return offset + size - 1; + } + + @Override + public int oldToNew(int oldOrdinal) { + return oldOrdinal + offset; + } + + @Override + public int newToOld(int newOrdinal) { + int oldOrdinal = newOrdinal - offset; + if (oldOrdinal < 0 || oldOrdinal >= size) { + return OMITTED; + } + return oldOrdinal; + } + } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/PQRetrainer.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/PQRetrainer.java new file mode 100644 index 000000000..a0438168e --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/PQRetrainer.java @@ -0,0 +1,233 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.github.jbellis.jvector.graph.disk; + +import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues; +import io.github.jbellis.jvector.graph.disk.feature.FeatureId; +import io.github.jbellis.jvector.graph.disk.feature.FusedPQ; +import io.github.jbellis.jvector.quantization.ProductQuantization; +import io.github.jbellis.jvector.util.DocIdSetIterator; +import io.github.jbellis.jvector.util.FixedBitSet; +import io.github.jbellis.jvector.vector.VectorizationProvider; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.vector.types.VectorFloat; +import io.github.jbellis.jvector.vector.types.VectorTypeSupport; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; + +/** + * Handles Product Quantization retraining for graph index compaction. + * Performs balanced sampling across multiple source indexes and trains + * a new PQ codebook optimized for the combined dataset. + */ +public class PQRetrainer { + private static final Logger log = LoggerFactory.getLogger(PQRetrainer.class); + private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport(); + private static final int MIN_SAMPLES_PER_SOURCE = 1000; + // Number of consecutive nodes to read per chunk before jumping to another location. + // Keeping reads sequential within each chunk lets the OS read-ahead cover them, + // avoiding the random I/O that would happen with per-node random sampling. 
+ private static final int SAMPLE_CHUNK_SIZE = 32; + + private final List sources; + private final List liveNodes; + private final List numLiveNodesPerSource; + private final int dimension; + private final int numTotalNodes; + + public PQRetrainer(List sources, List liveNodes, int dimension) { + this.sources = sources; + this.liveNodes = liveNodes; + this.dimension = dimension; + + this.numLiveNodesPerSource = new ArrayList<>(sources.size()); + int total = 0; + for (int s = 0; s < sources.size(); s++) { + int numLiveNodes = liveNodes.get(s).cardinality(); + total += numLiveNodes; + this.numLiveNodesPerSource.add(numLiveNodes); + } + this.numTotalNodes = total; + } + + /** + * Trains a new Product Quantization codebook using balanced sampling across all source indexes. + * All sampled vectors are read into memory up front, so ProductQuantization.compute() itself + * performs no I/O. + */ + public ProductQuantization retrain(VectorSimilarityFunction similarityFunction) { + log.info("Training PQ using balanced sampling across sources"); + + List samples = sampleBalanced(ProductQuantization.MAX_PQ_TRAINING_SET_SIZE); + + // Sort by (source, node) so extractVectorsSequential reads each source's file + // in ascending order, enabling OS read-ahead instead of random page faults. + samples.sort(Comparator.comparingInt((SampleRef r) -> r.source).thenComparingInt(r -> r.node)); + + log.info("Collected {} training samples", samples.size()); + + // Extract vectors sequentially in sorted (source, node) order so disk reads are + // purely sequential and the OS read-ahead can cover them efficiently. We do this + // here rather than letting ProductQuantization.compute() drive the reads via its + // parallel stream, which would scatter page faults across a potentially very large + // file and cause I/O that scales with dataset size rather than sample count. 
+ List> trainingVectors = extractVectorsSequential(samples); + var ravv = new ListRandomAccessVectorValues(trainingVectors, dimension); + + FusedPQ fpq = (FusedPQ) sources.get(0).getFeatures().get(FeatureId.FUSED_PQ); + ProductQuantization basePQ = fpq.getPQ(); + + boolean center = similarityFunction == VectorSimilarityFunction.EUCLIDEAN; + + return ProductQuantization.compute( + ravv, + basePQ.getSubspaceCount(), + basePQ.getClusterCount(), + center + ); + } + + /** + * Performs balanced sampling across all source indexes to ensure proportional representation. + * Guarantees minimum samples per source while respecting total sample budget. + */ + private List sampleBalanced(int totalSamples) { + // If total live nodes <= totalSamples, return ALL + if (numTotalNodes <= totalSamples) { + List all = new ArrayList<>(numTotalNodes); + + for (int s = 0; s < sources.size(); s++) { + FixedBitSet live = liveNodes.get(s); + + for (int node = live.nextSetBit(0); + node != DocIdSetIterator.NO_MORE_DOCS; + node = live.nextSetBit(node + 1)) { + all.add(new SampleRef(s, node)); + } + } + + return all; + } + + final int MIN_PER_SOURCE = Math.min(MIN_SAMPLES_PER_SOURCE, totalSamples / sources.size()); + + int[] quota = new int[sources.size()]; + int assigned = 0; + + // Proportional allocation + for (int s = 0; s < sources.size(); s++) { + quota[s] = Math.max( + MIN_PER_SOURCE, + (int) ((long) totalSamples * numLiveNodesPerSource.get(s) / numTotalNodes) + ); + assigned += quota[s]; + } + + // Normalize down + while (assigned > totalSamples) { + for (int s = 0; s < sources.size() && assigned > totalSamples; s++) { + if (quota[s] > MIN_PER_SOURCE) { + quota[s]--; + assigned--; + } + } + } + + // Normalize up + while (assigned < totalSamples) { + for (int s = 0; s < sources.size() && assigned < totalSamples; s++) { + quota[s]++; + assigned++; + } + } + + List samples = new ArrayList<>(totalSamples); + ThreadLocalRandom rand = ThreadLocalRandom.current(); + + for (int s = 0; s < 
sources.size(); s++) { + FixedBitSet live = liveNodes.get(s); + int max = live.length(); + int numChunks = (max + SAMPLE_CHUNK_SIZE - 1) / SAMPLE_CHUNK_SIZE; + + // Build a shuffled chunk order so samples are representative but + // each chunk is read sequentially to minimize page faults. + // Fisher-Yates shuffle + int[] chunkOrder = new int[numChunks]; + for (int i = 0; i < numChunks; i++) chunkOrder[i] = i; + for (int i = numChunks - 1; i > 0; i--) { + int j = rand.nextInt(i + 1); + int tmp = chunkOrder[i]; + chunkOrder[i] = chunkOrder[j]; + chunkOrder[j] = tmp; + } + + int count = 0; + outer: + for (int ci = 0; ci < numChunks; ci++) { + int start = chunkOrder[ci] * SAMPLE_CHUNK_SIZE; + int end = Math.min(max, start + SAMPLE_CHUNK_SIZE); + for (int node = start; node < end; node++) { + if (live.get(node)) { + samples.add(new SampleRef(s, node)); + if (++count >= quota[s]) break outer; + } + } + } + } + + return samples; + } + + /** + * Reads sampled vectors in the order provided. The caller must pre-sort {@code samples} + * by (source, node) so reads within each source are ascending, letting the OS read-ahead + * cover them efficiently. Each source's view is opened once and reused for all its samples. + */ + private List> extractVectorsSequential(List samples) { + OnDiskGraphIndex.View[] views = new OnDiskGraphIndex.View[sources.size()]; + for (int s = 0; s < sources.size(); s++) { + views[s] = (OnDiskGraphIndex.View) sources.get(s).getView(); + } + + List> vectors = new ArrayList<>(samples.size()); + VectorFloat tmp = vectorTypeSupport.createFloatVector(dimension); + for (SampleRef ref : samples) { + views[ref.source].getVectorInto(ref.node, tmp, 0); + vectors.add(tmp.copy()); + } + return vectors; + } + + /** + * Reference to a sampled vector from a specific source index. 
+ */ + private static final class SampleRef { + final int source; + final int node; + + SampleRef(int source, int node) { + this.source = source; + this.node = node; + } + } + +} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java index 538632da0..42113b242 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java @@ -444,7 +444,7 @@ public String toString() { * This is emulative of modern Java records, but keeps to J11 standards. * This class consolidates the layout calculations for PQ data into one place */ - static class PQLayout { + public static class PQLayout { /** * total number of vectors diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorizationProvider.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorizationProvider.java index 1ec46443d..5ab5664d1 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorizationProvider.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorizationProvider.java @@ -77,6 +77,28 @@ protected VectorizationProvider() { // visible for tests static VectorizationProvider lookup(boolean testMode) { + String forcedProvider = System.getProperty("jvector.vectorization_provider"); + if (forcedProvider != null) { + switch (forcedProvider.toLowerCase(Locale.ROOT)) { + case "default": + return new DefaultVectorizationProvider(); + case "panama": + try { + return (VectorizationProvider) Class.forName("io.github.jbellis.jvector.vector.PanamaVectorizationProvider").getConstructor().newInstance(); + } catch (Throwable e) { + throw new RuntimeException("Failed to load forced PanamaVectorizationProvider", e); + } + case "native": + try { + return (VectorizationProvider) 
Class.forName("io.github.jbellis.jvector.vector.NativeVectorizationProvider").getConstructor().newInstance(); + } catch (Throwable e) { + throw new RuntimeException("Failed to load forced NativeVectorizationProvider", e); + } + default: + throw new IllegalArgumentException("Unknown vectorization provider: " + forcedProvider); + } + } + final int runtimeVersion = Runtime.version().feature(); if (runtimeVersion >= 20) { // is locale sane (only buggy in Java 20) diff --git a/jvector-examples/pom.xml b/jvector-examples/pom.xml index 9daf7b8cf..731c1d769 100644 --- a/jvector-examples/pom.xml +++ b/jvector-examples/pom.xml @@ -85,16 +85,16 @@ gson 2.10.1 - + org.slf4j slf4j-api 2.0.9 - ch.qos.logback - logback-classic - 1.4.11 + org.apache.logging.log4j + log4j-slf4j2-impl + 2.24.3 software.amazon.awssdk @@ -112,6 +112,11 @@ aws-crt-client ${awssdk.version} + + software.amazon.awssdk + ec2 + ${awssdk.version} + software.amazon.awssdk s3 diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/reporting/GitInfo.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/reporting/GitInfo.java new file mode 100644 index 000000000..92979ac5d --- /dev/null +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/reporting/GitInfo.java @@ -0,0 +1,54 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.github.jbellis.jvector.example.reporting; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Detects the current git commit hash for tagging benchmark results. + */ +public final class GitInfo { + private static final Logger log = LoggerFactory.getLogger(GitInfo.class); + + private GitInfo() {} + + // Lazy holder pattern — computed once on first access + private static class Holder { + static final String SHORT_HASH; + static { + String hash; + try { + var process = new ProcessBuilder("git", "rev-parse", "HEAD").redirectErrorStream(true).start(); + hash = new String(process.getInputStream().readAllBytes()).trim(); + process.waitFor(); + if (hash.length() >= 8) { + hash = hash.substring(0, 8); + } + } catch (Exception e) { + log.warn("Could not determine git hash", e); + hash = "unknown"; + } + SHORT_HASH = hash; + } + } + + /** Returns the first 8 characters of {@code git rev-parse HEAD} (the conventional abbreviated hash), or {@code "unknown"} on failure. */ + public static String getShortHash() { + return Holder.SHORT_HASH; + } +} diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/reporting/JfrRecorder.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/reporting/JfrRecorder.java new file mode 100644 index 000000000..ab5c4b581 --- /dev/null +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/reporting/JfrRecorder.java @@ -0,0 +1,105 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.example.reporting; + +import jdk.jfr.Configuration; +import jdk.jfr.Recording; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.text.ParseException; +import java.time.Duration; + +/** + * Manages the lifecycle of a JFR (Java Flight Recorder) recording for benchmarks. + */ +public final class JfrRecorder { + private static final Logger log = LoggerFactory.getLogger(JfrRecorder.class); + + private Recording recording; + private String fileName; + + /** + * Creates the output directory, configures a "profile" recording, starts it, and returns the absolute path. + * + * @param outputDir directory to write the JFR file into + * @param fileName name of the JFR file (e.g. {@code "compactor-foo.jfr"}) + * @return the absolute path of the recording file + * @throws IOException if the directory cannot be created + * @throws ParseException if the JFR "profile" configuration cannot be loaded + */ + public Path start(Path outputDir, String fileName) throws IOException, ParseException { + return start(outputDir, fileName, false); + } + + /** + * Creates the output directory, configures a "profile" recording, starts it, and returns the absolute path. 
+ * + * @param outputDir directory to write the JFR file into + * @param fileName name of the JFR file + * @param objectCount whether to enable periodic 'jdk.ObjectCount' events + * @return the absolute path of the recording file + */ + public Path start(Path outputDir, String fileName, boolean objectCount) throws IOException, ParseException { + Files.createDirectories(outputDir); + Path jfrPath = outputDir.resolve(fileName).toAbsolutePath(); + recording = new Recording(Configuration.getConfiguration("profile")); + recording.setToDisk(true); + recording.setDestination(jfrPath); + + // Enable heap occupancy snapshots and old object sampling + var settings = recording.getSettings(); + if (objectCount) { + settings.put("jdk.ObjectCount#enabled", "true"); + settings.put("jdk.ObjectCount#period", "10s"); // Every 10 seconds + } + settings.put("jdk.OldObjectSample#enabled", "true"); + // Flush to disk every minute so data is available for inspection during long benchmarks + settings.put("flush-interval", Duration.ofMinutes(1).toMillis() + "ms"); + recording.setSettings(settings); + recording.start(); + this.fileName = fileName; + System.out.println("JFR recording started, saving to: " + jfrPath); + log.info("JFR recording started, saving to: {}", jfrPath); + return jfrPath; + } + + /** Stops and closes the recording, logging the saved path. */ + public void stop() { + if (recording != null) { + Path jfrPath = recording.getDestination(); + recording.stop(); + recording.close(); + recording = null; + System.out.println("JFR recording saved to: " + jfrPath); + log.info("JFR recording saved to: {}", jfrPath); + } + } + + /** Returns {@code true} if a recording is currently in progress. */ + public boolean isActive() { + return recording != null; + } + + /** Returns the current file name, or {@code null} if no recording has been started. 
*/ + public String getFileName() { + return fileName; + } +} diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/reporting/JsonlWriter.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/reporting/JsonlWriter.java new file mode 100644 index 000000000..a025dd207 --- /dev/null +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/reporting/JsonlWriter.java @@ -0,0 +1,56 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.example.reporting; + +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.util.Map; + +/** + * Append-only JSONL file writer that serializes one map per line using GSON. + */ +public final class JsonlWriter { + private static final Logger log = LoggerFactory.getLogger(JsonlWriter.class); + private static final Gson GSON = new GsonBuilder() + .disableHtmlEscaping() + .serializeNulls() + .create(); // No pretty printing for JSONL + + private final Path outputFile; + + public JsonlWriter(Path outputFile) { + this.outputFile = outputFile; + } + + /** Serializes the map as a single JSON line and appends it to the output file. 
*/ + public void writeLine(Map result) { + String json = GSON.toJson(result) + "\n"; + try { + Files.writeString(outputFile, json, + StandardOpenOption.CREATE, StandardOpenOption.APPEND); + } catch (IOException e) { + log.error("Failed to persist result to {}", outputFile, e); + } + } +} diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/reporting/SystemStatsCollector.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/reporting/SystemStatsCollector.java new file mode 100644 index 000000000..571782510 --- /dev/null +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/reporting/SystemStatsCollector.java @@ -0,0 +1,196 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.example.reporting; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.BufferedWriter; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.time.Instant; +import java.time.ZoneOffset; +import java.time.format.DateTimeFormatter; +import java.util.HashSet; +import java.util.List; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.regex.Pattern; + +/** + * Background collector of {@code /proc} system metrics (CPU topology, load, memory, disk I/O). 
+ * Reads /proc files directly in Java and appends JSONL lines to a file every 30 seconds. + */ +public final class SystemStatsCollector { + private static final Logger log = LoggerFactory.getLogger(SystemStatsCollector.class); + private static final Path PROC_CPUINFO = Path.of("/proc/cpuinfo"); + private static final Path PROC_LOADAVG = Path.of("/proc/loadavg"); + private static final Path PROC_MEMINFO = Path.of("/proc/meminfo"); + private static final Path PROC_DISKSTATS = Path.of("/proc/diskstats"); + private static final Pattern DISK_DEVICE_PATTERN = Pattern.compile("sd[a-z]+|nvme[0-9]+n[0-9]+|vd[a-z]+|xvd[a-z]+"); + private static final DateTimeFormatter TS_FORMAT = DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss'Z'").withZone(ZoneOffset.UTC); + + private ScheduledExecutorService scheduler; + private BufferedWriter writer; + private String fileName; + private int cpuSockets; + private int cpuCores; + private int cpuThreads; + + public Path start(Path outputDir, String fileName) throws IOException { + if (!Files.exists(PROC_CPUINFO)) { + log.warn("/proc filesystem not available (not Linux?), system stats collection disabled"); + return null; + } + + Files.createDirectories(outputDir); + Path sysStatsPath = outputDir.resolve(fileName).toAbsolutePath(); + this.fileName = fileName; + + parseCpuTopology(); + + this.writer = Files.newBufferedWriter(sysStatsPath, + StandardOpenOption.CREATE, StandardOpenOption.APPEND); + + scheduler = Executors.newSingleThreadScheduledExecutor(r -> { + Thread t = new Thread(r, "sys-stats-collector"); + t.setDaemon(true); + return t; + }); + scheduler.scheduleAtFixedRate(() -> { + try { + String line = collectSnapshot(); + writer.write(line); + writer.newLine(); + writer.flush(); + } catch (Exception e) { + log.warn("Failed to collect system stats", e); + } + }, 0, 30, TimeUnit.SECONDS); + + log.info("System stats collection started, saving to: {}", sysStatsPath); + return sysStatsPath; + } + + public void stop(Path outputDir) 
throws InterruptedException { + if (scheduler != null) { + scheduler.shutdown(); + scheduler.awaitTermination(5, TimeUnit.SECONDS); + scheduler = null; + try { + if (writer != null) { + writer.close(); + writer = null; + } + } catch (IOException e) { + log.warn("Failed to close stats writer", e); + } + log.info("System stats collection stopped, saved to: {}", outputDir.resolve(fileName).toAbsolutePath()); + } + } + + public boolean isActive() { + return scheduler != null && !scheduler.isShutdown(); + } + + public String getFileName() { + return fileName; + } + + private void parseCpuTopology() throws IOException { + List lines = Files.readAllLines(PROC_CPUINFO); + int threads = 0; + var physicalIds = new HashSet(); + var coreKeys = new HashSet(); + String currentPhysicalId = "0"; + + for (String line : lines) { + if (line.startsWith("processor")) { + threads++; + } else if (line.startsWith("physical id")) { + currentPhysicalId = line.substring(line.indexOf(':') + 1).trim(); + physicalIds.add(currentPhysicalId); + } else if (line.startsWith("core id")) { + String coreId = line.substring(line.indexOf(':') + 1).trim(); + coreKeys.add(currentPhysicalId + "-" + coreId); + } + } + + this.cpuThreads = threads; + this.cpuSockets = physicalIds.isEmpty() ? 1 : physicalIds.size(); + this.cpuCores = coreKeys.isEmpty() ? 
cpuThreads : coreKeys.size(); + } + + private String collectSnapshot() throws IOException { + String ts = TS_FORMAT.format(Instant.now()); + + // /proc/loadavg: "0.50 0.35 0.25 2/150 12345" + String loadLine = Files.readString(PROC_LOADAVG).trim(); + String[] loadParts = loadLine.split("\\s+"); + String load1 = loadParts[0]; + String load5 = loadParts[1]; + String load15 = loadParts[2]; + String[] runProcs = loadParts[3].split("/"); + String running = runProcs[0]; + String total = runProcs[1]; + + // /proc/meminfo + long memTotal = 0, memFree = 0, memAvail = 0, buffers = 0, cached = 0, swapTotal = 0, swapFree = 0; + for (String line : Files.readAllLines(PROC_MEMINFO)) { + if (line.startsWith("MemTotal:")) memTotal = parseMemValue(line); + else if (line.startsWith("MemFree:")) memFree = parseMemValue(line); + else if (line.startsWith("MemAvailable:")) memAvail = parseMemValue(line); + else if (line.startsWith("Buffers:")) buffers = parseMemValue(line); + else if (line.startsWith("Cached:")) cached = parseMemValue(line); + else if (line.startsWith("SwapTotal:")) swapTotal = parseMemValue(line); + else if (line.startsWith("SwapFree:")) swapFree = parseMemValue(line); + } + + // /proc/diskstats + StringBuilder disks = new StringBuilder(); + for (String line : Files.readAllLines(PROC_DISKSTATS)) { + String[] f = line.trim().split("\\s+"); + if (f.length < 14) continue; + String dev = f[2]; + if (!DISK_DEVICE_PATTERN.matcher(dev).matches()) continue; + if (disks.length() > 0) disks.append(','); + disks.append(String.format( + "{\"device\":\"%s\",\"readsCompleted\":%s,\"readsMerged\":%s,\"sectorsRead\":%s,\"readTimeMs\":%s," + + "\"writesCompleted\":%s,\"writesMerged\":%s,\"sectorsWritten\":%s,\"writeTimeMs\":%s," + + "\"ioInProgress\":%s,\"ioTimeMs\":%s,\"weightedIoTimeMs\":%s}", + dev, f[3], f[4], f[5], f[6], f[7], f[8], f[9], f[10], f[11], f[12], f[13])); + } + + return String.format( + "{\"timestamp\":\"%s\",\"cpuSockets\":%d,\"cpuCores\":%d,\"cpuThreads\":%d," + + 
"\"loadAvg1\":%s,\"loadAvg5\":%s,\"loadAvg15\":%s,\"runningProcs\":%s,\"totalProcs\":%s," + + "\"memTotalKB\":%d,\"memFreeKB\":%d,\"memAvailableKB\":%d,\"buffersKB\":%d,\"cachedKB\":%d," + + "\"swapTotalKB\":%d,\"swapFreeKB\":%d,\"diskStats\":[%s]}", + ts, cpuSockets, cpuCores, cpuThreads, + load1, load5, load15, running, total, + memTotal, memFree, memAvail, buffers, cached, swapTotal, swapFree, + disks); + } + + private static long parseMemValue(String line) { + String[] parts = line.split("\\s+"); + return Long.parseLong(parts[1]); + } +} diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/reporting/ThreadAllocTracker.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/reporting/ThreadAllocTracker.java new file mode 100644 index 000000000..207d67045 --- /dev/null +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/reporting/ThreadAllocTracker.java @@ -0,0 +1,203 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.github.jbellis.jvector.example.reporting; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.BufferedWriter; +import java.io.IOException; +import java.lang.management.ManagementFactory; +import java.lang.management.ThreadInfo; +import java.nio.file.Files; +import java.nio.file.Path; +import java.time.Instant; +import java.util.HashMap; +import java.util.Map; + +/** + * Periodically samples per-thread heap allocation via + * {@link com.sun.management.ThreadMXBean#getThreadAllocatedBytes(long[])} + * and writes JSONL output with per-thread deltas and cumulative totals. + * + * Lifecycle mirrors {@link SystemStatsCollector}: {@link #start(Path, String)}, + * {@link #stop()}, {@link #isActive()}, {@link #getFileName()}. + */ +public final class ThreadAllocTracker { + private static final Logger log = LoggerFactory.getLogger(ThreadAllocTracker.class); + + private static final long DEFAULT_INTERVAL_SECONDS = 10; + + private final com.sun.management.ThreadMXBean threadMXBean; + private final long intervalSeconds; + + private volatile Thread samplerThread; + private volatile boolean running; + private String fileName; + + /// Creates a tracker with the default 10-second sampling interval. + public ThreadAllocTracker() { + this(DEFAULT_INTERVAL_SECONDS); + } + + /// Creates a tracker with a custom sampling interval. + /// + /// @param intervalSeconds seconds between each sample + public ThreadAllocTracker(long intervalSeconds) { + this.threadMXBean = (com.sun.management.ThreadMXBean) ManagementFactory.getThreadMXBean(); + this.intervalSeconds = intervalSeconds; + } + + /// Creates the output directory, enables thread allocated memory tracking, + /// and spawns a daemon thread that periodically writes JSONL samples. 
+ /// + /// @param outputDir directory to write the JSONL file into + /// @param fileName name of the output file + /// @return the absolute path of the output file + /// @throws IOException if the directory cannot be created + public Path start(Path outputDir, String fileName) throws IOException { + Files.createDirectories(outputDir); + Path outputPath = outputDir.resolve(fileName).toAbsolutePath(); + this.fileName = fileName; + + threadMXBean.setThreadAllocatedMemoryEnabled(true); + + running = true; + samplerThread = new Thread(() -> sampleLoop(outputPath), "thread-alloc-tracker"); + samplerThread.setDaemon(true); + samplerThread.start(); + + log.info("Thread allocation tracking started, saving to: {}", outputPath); + return outputPath; + } + + /// Stops the sampler thread and writes a final cumulative summary line. + public void stop() { + running = false; + if (samplerThread != null) { + samplerThread.interrupt(); + try { + samplerThread.join(5000); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + samplerThread = null; + log.info("Thread allocation tracking stopped, saved to: {}", fileName); + } + } + + /// Returns {@code true} if the sampler thread is currently running. + public boolean isActive() { + return samplerThread != null && running; + } + + /// Returns the current file name, or {@code null} if tracking has not been started. 
+ public String getFileName() { + return fileName; + } + + private void sampleLoop(Path outputPath) { + // Track cumulative allocations per thread (by id) for delta computation + var previousAllocations = new HashMap(); + + try (var writer = Files.newBufferedWriter(outputPath)) { + while (running) { + try { + Thread.sleep(intervalSeconds * 1000); + } catch (InterruptedException e) { + // On interrupt (from stop()), write final summary and exit + break; + } + writeSample(writer, previousAllocations, false); + } + // Write final summary with cumulative totals + writeSample(writer, previousAllocations, true); + } catch (IOException e) { + log.error("Failed to write thread allocation sample", e); + } + } + + private void writeSample(BufferedWriter writer, Map previousAllocations, boolean isSummary) + throws IOException { + long[] threadIds = threadMXBean.getAllThreadIds(); + long[] allocatedBytes = threadMXBean.getThreadAllocatedBytes(threadIds); + ThreadInfo[] threadInfos = threadMXBean.getThreadInfo(threadIds); + + var sb = new StringBuilder(); + sb.append("{\"timestamp\":\"").append(Instant.now().toString()).append('"'); + if (isSummary) { + sb.append(",\"event\":\"summary\""); + } + sb.append(",\"threads\":["); + + long totalAllocated = 0; + long totalDelta = 0; + boolean first = true; + + for (int i = 0; i < threadIds.length; i++) { + if (threadInfos[i] == null || allocatedBytes[i] < 0) { + continue; + } + long id = threadIds[i]; + long allocated = allocatedBytes[i]; + long previous = previousAllocations.getOrDefault(id, 0L); + long delta = allocated - previous; + previousAllocations.put(id, allocated); + + totalAllocated += allocated; + totalDelta += delta; + + if (!first) { + sb.append(','); + } + first = false; + + sb.append("{\"id\":").append(id) + .append(",\"name\":\"").append(escapeJson(threadInfos[i].getThreadName())).append('"') + .append(",\"allocatedBytes\":").append(allocated) + .append(",\"deltaBytes\":").append(delta) + .append('}'); + } + + 
sb.append("],\"totalAllocatedBytes\":").append(totalAllocated) + .append(",\"totalDeltaBytes\":").append(totalDelta) + .append('}'); + + writer.write(sb.toString()); + writer.newLine(); + writer.flush(); + } + + private static String escapeJson(String value) { + if (value == null) { + return ""; + } + var sb = new StringBuilder(value.length()); + for (int i = 0; i < value.length(); i++) { + char c = value.charAt(i); + switch (c) { + case '"': sb.append("\\\""); break; + case '\\': sb.append("\\\\"); break; + case '\n': sb.append("\\n"); break; + case '\r': sb.append("\\r"); break; + case '\t': sb.append("\\t"); break; + default: sb.append(c); + } + } + return sb.toString(); + } +} diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/DataSetPartitioner.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/DataSetPartitioner.java new file mode 100644 index 000000000..1e6a83f40 --- /dev/null +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/DataSetPartitioner.java @@ -0,0 +1,60 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.github.jbellis.jvector.example.util; + +import io.github.jbellis.jvector.example.benchmarks.datasets.DataSet; +import io.github.jbellis.jvector.example.yaml.TestDataPartition; +import io.github.jbellis.jvector.vector.types.VectorFloat; + +import java.util.ArrayList; +import java.util.List; + +/** + * Utility for partitioning a DataSet into multiple segments based on a distribution. + */ +public final class DataSetPartitioner { + private DataSetPartitioner() {} + + public static final class PartitionedData { + public final List>> vectors; + public final List sizes; + + public PartitionedData(List>> vectors, List sizes) { + this.vectors = vectors; + this.sizes = sizes; + } + } + + public static PartitionedData partition(DataSet ds, int numParts, TestDataPartition.Distribution distribution) { + return partition(ds.getBaseVectors(), numParts, distribution); + } + + public static PartitionedData partition(List> baseVectors, int numParts, TestDataPartition.Distribution distribution) { + List sizes = distribution.computeSplitSizes(baseVectors.size(), numParts); + List>> parts = new ArrayList<>(numParts); + + int runningStart = 0; + for (int size : sizes) { + int start = runningStart; + int end = start + size; + runningStart = end; + parts.add(new ArrayList<>(baseVectors.subList(start, end))); + } + + return new PartitionedData(parts, sizes); + } +} diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/storage/CloudStorageLayoutUtil.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/storage/CloudStorageLayoutUtil.java new file mode 100644 index 000000000..9530b8815 --- /dev/null +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/storage/CloudStorageLayoutUtil.java @@ -0,0 +1,331 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
package io.github.jbellis.jvector.example.util.storage;

import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;

/**
 * Cloud wrapper that chooses AWS or GCP storage inspection and maps the provider-specific classes
 * into cloud-agnostic storage tiers.
 */
public final class CloudStorageLayoutUtil {
    private CloudStorageLayoutUtil() {
    }

    /** Cloud environment the current process was detected to be running in. */
    public enum CloudProvider {
        AWS_EC2,
        GCP_GCE,
        LOCAL_OR_UNKNOWN
    }

    /** Provider-agnostic storage tiers that AWS/GCP/local classes are mapped into. */
    public enum StorageClass {
        BLOCK_HDD_COLD,
        BLOCK_HDD_THROUGHPUT,
        BLOCK_HDD_STANDARD,
        BLOCK_SSD_BALANCED,
        BLOCK_SSD_GENERAL,
        BLOCK_SSD_HIGH_IOPS,
        LOCAL_SSD,
        LOCAL_NVME,
        NETWORK_FILESYSTEM,
        MEMORY_TMPFS,
        PSEUDO_FILESYSTEM,
        UNKNOWN
    }

    /**
     * Cloud-agnostic view of the host's storage layout, retaining the provider-specific snapshot.
     *
     * @param <T> provider-specific snapshot type (AWS, GCP, or local)
     */
    public static final class StorageSnapshot<T> {
        private final T cloudSpecificSnapshot;
        private final CloudProvider provider;
        private final boolean runningInCloud;
        private final String instanceId;                 // null when not in a detected cloud
        private final String instanceTypeOrMachineType;  // EC2 instance type, GCE machine type, or OS name locally
        private final String regionOrZone;               // AWS region, GCE zone, or OS name locally
        private final Map<String, MountStorageInfo> mountsByMountPoint;

        public StorageSnapshot(T cloudSpecificSnapshot,
                               CloudProvider provider,
                               boolean runningInCloud,
                               String instanceId,
                               String instanceTypeOrMachineType,
                               String regionOrZone,
                               Map<String, MountStorageInfo> mountsByMountPoint) {
            this.cloudSpecificSnapshot = cloudSpecificSnapshot;
            this.provider = Objects.requireNonNull(provider, "provider");
            this.runningInCloud = runningInCloud;
            this.instanceId = instanceId;
            this.instanceTypeOrMachineType = instanceTypeOrMachineType;
            this.regionOrZone = regionOrZone;
            this.mountsByMountPoint = Objects.requireNonNull(mountsByMountPoint, "mountsByMountPoint");
        }

        public T cloudSpecificSnapshot() {
            return cloudSpecificSnapshot;
        }

        public CloudProvider provider() {
            return provider;
        }

        public boolean runningInCloud() {
            return runningInCloud;
        }

        public String instanceId() {
            return instanceId;
        }

        public String instanceTypeOrMachineType() {
            return instanceTypeOrMachineType;
        }

        public String regionOrZone() {
            return regionOrZone;
        }

        public Map<String, MountStorageInfo> mountsByMountPoint() {
            return mountsByMountPoint;
        }
    }

    /** Cloud-agnostic classification of a single mounted filesystem. */
    public static final class MountStorageInfo {
        private final String mountPoint;
        private final String source;
        private final String filesystemType;
        private final StorageClass storageClass;
        private final String providerSpecificClass;  // name of the provider enum constant this was mapped from

        public MountStorageInfo(String mountPoint,
                                String source,
                                String filesystemType,
                                StorageClass storageClass,
                                String providerSpecificClass) {
            this.mountPoint = mountPoint;
            this.source = source;
            this.filesystemType = filesystemType;
            this.storageClass = Objects.requireNonNull(storageClass, "storageClass");
            this.providerSpecificClass = providerSpecificClass;
        }

        public String mountPoint() {
            return mountPoint;
        }

        public String source() {
            return source;
        }

        public String filesystemType() {
            return filesystemType;
        }

        public StorageClass storageClass() {
            return storageClass;
        }

        public String providerSpecificClass() {
            return providerSpecificClass;
        }
    }

    /**
     * Inspects storage, trying AWS first, then GCP, then falling back to local OS inspection.
     * The concrete snapshot type depends on which provider was detected.
     */
    public static StorageSnapshot<?> inspectStorage() {
        var awsSnapshot = StorageLayoutUtil.inspectStorage();
        if (awsSnapshot.runningOnEc2()) {
            return fromAws(awsSnapshot, CloudProvider.AWS_EC2, true);
        }

        var gcpSnapshot = GcpStorageLayoutUtil.inspectStorage();
        if (gcpSnapshot.runningOnGcp()) {
            return fromGcp(gcpSnapshot);
        }

        // Not in a detected cloud environment. Use OS-specific local storage inspection.
        var localSnapshot = LocalStorageLayoutUtil.inspectStorage();
        return fromLocal(localSnapshot);
    }

    /** Convenience view: mount point -> cloud-agnostic storage class, in stable (insertion) order. */
    public static Map<String, StorageClass> storageClassByMountPoint() {
        var snapshot = inspectStorage();
        var byMountPoint = new LinkedHashMap<String, StorageClass>(snapshot.mountsByMountPoint().size());
        for (var entry : snapshot.mountsByMountPoint().entrySet()) {
            byMountPoint.put(entry.getKey(), entry.getValue().storageClass());
        }
        return Collections.unmodifiableMap(byMountPoint);
    }

    private static StorageSnapshot<StorageLayoutUtil.StorageSnapshot> fromAws(StorageLayoutUtil.StorageSnapshot snapshot,
                                                                              CloudProvider provider,
                                                                              boolean runningInCloud) {
        var byMountPoint = new LinkedHashMap<String, MountStorageInfo>(snapshot.mountsByMountPoint().size());
        for (var entry : snapshot.mountsByMountPoint().entrySet()) {
            var mount = entry.getValue();
            byMountPoint.put(
                    entry.getKey(),
                    new MountStorageInfo(
                            mount.mountPoint(),
                            mount.source(),
                            mount.filesystemType(),
                            mapAwsClass(mount.storageClass()),
                            mount.storageClass().name()
                    )
            );
        }

        return new StorageSnapshot<>(
                snapshot,
                provider,
                runningInCloud,
                snapshot.instanceId(),
                snapshot.instanceType(),
                snapshot.region(),
                Collections.unmodifiableMap(byMountPoint)
        );
    }

    private static StorageSnapshot<GcpStorageLayoutUtil.StorageSnapshot> fromGcp(GcpStorageLayoutUtil.StorageSnapshot snapshot) {
        var byMountPoint = new LinkedHashMap<String, MountStorageInfo>(snapshot.mountsByMountPoint().size());
        for (var entry : snapshot.mountsByMountPoint().entrySet()) {
            var mount = entry.getValue();
            byMountPoint.put(
                    entry.getKey(),
                    new MountStorageInfo(
                            mount.mountPoint(),
                            mount.source(),
                            mount.filesystemType(),
                            mapGcpClass(mount.storageClass()),
                            mount.storageClass().name()
                    )
            );
        }

        return new StorageSnapshot<>(
                snapshot,
                CloudProvider.GCP_GCE,
                true,
                snapshot.instanceId(),
                snapshot.machineType(),
                snapshot.zone(),
                Collections.unmodifiableMap(byMountPoint)
        );
    }

    private static StorageSnapshot<LocalStorageLayoutUtil.StorageSnapshot> fromLocal(LocalStorageLayoutUtil.StorageSnapshot snapshot) {
        var byMountPoint = new LinkedHashMap<String, MountStorageInfo>(snapshot.mountsByMountPoint().size());
        for (var entry : snapshot.mountsByMountPoint().entrySet()) {
            var mount = entry.getValue();
            byMountPoint.put(
                    entry.getKey(),
                    new MountStorageInfo(
                            mount.mountPoint(),
                            mount.source(),
                            mount.filesystemType(),
                            mapLocalClass(mount.storageClass()),
                            mount.storageClass().name()
                    )
            );
        }

        // No instance identity outside a cloud; the OS name stands in for both machine type and zone.
        return new StorageSnapshot<>(
                snapshot,
                CloudProvider.LOCAL_OR_UNKNOWN,
                false,
                null,
                snapshot.osName(),
                snapshot.osName(),
                Collections.unmodifiableMap(byMountPoint)
        );
    }

    /** Maps AWS/EBS-specific storage classes to the cloud-agnostic tiers. */
    private static StorageClass mapAwsClass(StorageLayoutUtil.StorageClass storageClass) {
        switch (storageClass) {
            case EBS_COLD_HDD:
                return StorageClass.BLOCK_HDD_COLD;
            case EBS_THROUGHPUT_HDD:
                return StorageClass.BLOCK_HDD_THROUGHPUT;
            case EBS_MAGNETIC:
                return StorageClass.BLOCK_HDD_STANDARD;
            case EBS_GP2:
                return StorageClass.BLOCK_SSD_BALANCED;
            case EBS_GP3:
                return StorageClass.BLOCK_SSD_GENERAL;
            case EBS_PROVISIONED_IOPS_SSD:
                return StorageClass.BLOCK_SSD_HIGH_IOPS;
            case INSTANCE_STORE_SSD:
                return StorageClass.LOCAL_SSD;
            case INSTANCE_STORE_NVME:
                return StorageClass.LOCAL_NVME;
            case NETWORK_FILESYSTEM:
                return StorageClass.NETWORK_FILESYSTEM;
            case MEMORY_TMPFS:
                return StorageClass.MEMORY_TMPFS;
            case PSEUDO_FILESYSTEM:
                return StorageClass.PSEUDO_FILESYSTEM;
            case UNKNOWN:
            default:
                return StorageClass.UNKNOWN;
        }
    }

    /** Maps GCP persistent-disk/local-SSD storage classes to the cloud-agnostic tiers. */
    private static StorageClass mapGcpClass(GcpStorageLayoutUtil.StorageClass storageClass) {
        switch (storageClass) {
            case PD_STANDARD_HDD:
                return StorageClass.BLOCK_HDD_STANDARD;
            case PD_THROUGHPUT_OPTIMIZED:
                return StorageClass.BLOCK_HDD_THROUGHPUT;
            case PD_BALANCED_SSD:
                return StorageClass.BLOCK_SSD_BALANCED;
            case PD_SSD:
                return StorageClass.BLOCK_SSD_GENERAL;
            case PD_EXTREME_SSD:
                return StorageClass.BLOCK_SSD_HIGH_IOPS;
            case LOCAL_SSD:
                return StorageClass.LOCAL_SSD;
            case LOCAL_NVME:
                return StorageClass.LOCAL_NVME;
            case NETWORK_FILESYSTEM:
                return StorageClass.NETWORK_FILESYSTEM;
            case MEMORY_TMPFS:
                return StorageClass.MEMORY_TMPFS;
            case PSEUDO_FILESYSTEM:
                return StorageClass.PSEUDO_FILESYSTEM;
            case UNKNOWN:
            default:
                return StorageClass.UNKNOWN;
        }
    }

    /** Maps local (non-cloud) storage classes to the cloud-agnostic tiers. */
    private static StorageClass mapLocalClass(LocalStorageLayoutUtil.StorageClass storageClass) {
        switch (storageClass) {
            case LOCAL_HDD:
                return StorageClass.BLOCK_HDD_STANDARD;
            case LOCAL_SSD:
                return StorageClass.LOCAL_SSD;
            case LOCAL_NVME:
                return StorageClass.LOCAL_NVME;
            case NETWORK_FILESYSTEM:
                return StorageClass.NETWORK_FILESYSTEM;
            case MEMORY_TMPFS:
                return StorageClass.MEMORY_TMPFS;
            case PSEUDO_FILESYSTEM:
                return StorageClass.PSEUDO_FILESYSTEM;
            case UNKNOWN:
            default:
                return StorageClass.UNKNOWN;
        }
    }
}
package io.github.jbellis.jvector.example.util.storage;

import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.stream.Stream;

/**
 * Detects GCE runtime context via the metadata service and classifies storage for each mounted
 * filesystem. All probes are best-effort: metadata calls use a short timeout, and any I/O failure
 * degrades to "not on GCP" / UNKNOWN rather than throwing.
 */
public final class GcpStorageLayoutUtil {
    private static final String GCE_METADATA_HOST_ENV = "GCE_METADATA_HOST";
    private static final String METADATA_HOST_DEFAULT = "metadata.google.internal";
    private static final String METADATA_PREFIX = "/computeMetadata/v1/";
    private static final String METADATA_FLAVOR_HEADER = "Metadata-Flavor";
    private static final String METADATA_FLAVOR_VALUE = "Google";
    // Short enough that a non-GCE host fails fast instead of hanging startup.
    private static final Duration METADATA_TIMEOUT = Duration.ofMillis(300);

    private static final Pattern NVME_PARTITION_SUFFIX = Pattern.compile("p\\d+$");
    private static final Pattern GENERIC_PARTITION_SUFFIX = Pattern.compile("\\d+$");
    private static final Set<String> NETWORK_FILESYSTEM_TYPES = Set.of("nfs", "nfs4", "efs", "cifs", "smbfs", "fuse.sshfs");

    private GcpStorageLayoutUtil() {
    }

    /** GCP-specific storage classes; mapped to cloud-agnostic tiers by CloudStorageLayoutUtil. */
    public enum StorageClass {
        PD_STANDARD_HDD,
        PD_THROUGHPUT_OPTIMIZED,
        PD_BALANCED_SSD,
        PD_SSD,
        PD_EXTREME_SSD,
        LOCAL_SSD,
        LOCAL_NVME,
        NETWORK_FILESYSTEM,
        MEMORY_TMPFS,
        PSEUDO_FILESYSTEM,
        UNKNOWN
    }

    /** Result of a full inspection: GCE identity (if detected) plus per-mount classification. */
    public static final class StorageSnapshot {
        private final boolean runningOnGcp;
        private final String instanceId;   // null when not on GCP
        private final String machineType;  // leaf of the metadata machine-type resource, e.g. "n2-standard-8"
        private final String zone;         // leaf of the metadata zone resource, e.g. "us-central1-a"
        private final Map<String, MountStorageInfo> mountsByMountPoint;

        public StorageSnapshot(boolean runningOnGcp,
                               String instanceId,
                               String machineType,
                               String zone,
                               Map<String, MountStorageInfo> mountsByMountPoint) {
            this.runningOnGcp = runningOnGcp;
            this.instanceId = instanceId;
            this.machineType = machineType;
            this.zone = zone;
            this.mountsByMountPoint = Objects.requireNonNull(mountsByMountPoint, "mountsByMountPoint");
        }

        public boolean runningOnGcp() {
            return runningOnGcp;
        }

        public String instanceId() {
            return instanceId;
        }

        public String machineType() {
            return machineType;
        }

        public String zone() {
            return zone;
        }

        public Map<String, MountStorageInfo> mountsByMountPoint() {
            return mountsByMountPoint;
        }
    }

    /** Classification of a single mounted filesystem plus the GCP disk attributes it resolved to. */
    public static final class MountStorageInfo {
        private final String mountPoint;
        private final String source;
        private final String filesystemType;
        private final StorageClass storageClass;
        private final String deviceName;     // GCE disk device-name from metadata, if resolved
        private final String diskKind;       // lowercased metadata "type", e.g. "persistent" or "scratch"
        private final String interfaceType;  // uppercased metadata "interface", e.g. "NVME" or "SCSI"

        public MountStorageInfo(String mountPoint,
                                String source,
                                String filesystemType,
                                StorageClass storageClass,
                                String deviceName,
                                String diskKind,
                                String interfaceType) {
            this.mountPoint = mountPoint;
            this.source = source;
            this.filesystemType = filesystemType;
            this.storageClass = Objects.requireNonNull(storageClass, "storageClass");
            this.deviceName = deviceName;
            this.diskKind = diskKind;
            this.interfaceType = interfaceType;
        }

        public String mountPoint() {
            return mountPoint;
        }

        public String source() {
            return source;
        }

        public String filesystemType() {
            return filesystemType;
        }

        public StorageClass storageClass() {
            return storageClass;
        }

        public String deviceName() {
            return deviceName;
        }

        public String diskKind() {
            return diskKind;
        }

        public String interfaceType() {
            return interfaceType;
        }
    }

    /**
     * Probes the GCE metadata service and the local mount table, classifying each mount.
     * Disk metadata is only fetched when a GCE identity was confirmed.
     */
    public static StorageSnapshot inspectStorage() {
        var identity = fetchGcpIdentity();
        var mounts = readMountEntries();
        var diskData = identity.map(GcpStorageLayoutUtil::fetchGcpDiskData).orElse(GcpDiskData.empty());

        mounts.sort(Comparator.comparing(MountEntry::mountPoint));
        var byMountPoint = new LinkedHashMap<String, MountStorageInfo>(mounts.size());
        for (var mount : mounts) {
            var diskResolution = resolveDisk(mount.source(), diskData);
            var storageClass = classify(mount, diskResolution);
            byMountPoint.put(
                    mount.mountPoint(),
                    new MountStorageInfo(
                            mount.mountPoint(),
                            mount.source(),
                            mount.filesystemType(),
                            storageClass,
                            diskResolution.deviceName(),
                            diskResolution.diskKind(),
                            diskResolution.interfaceType()
                    )
            );
        }

        return new StorageSnapshot(
                identity.isPresent(),
                identity.map(GcpIdentity::instanceId).orElse(null),
                identity.map(GcpIdentity::machineType).orElse(null),
                identity.map(GcpIdentity::zone).orElse(null),
                Collections.unmodifiableMap(byMountPoint)
        );
    }

    /** Convenience view: mount point -> GCP storage class, in stable (sorted) order. */
    public static Map<String, StorageClass> storageClassByMountPoint() {
        var snapshot = inspectStorage();
        var byMountPoint = new LinkedHashMap<String, StorageClass>(snapshot.mountsByMountPoint().size());
        for (var entry : snapshot.mountsByMountPoint().entrySet()) {
            byMountPoint.put(entry.getKey(), entry.getValue().storageClass());
        }
        return Collections.unmodifiableMap(byMountPoint);
    }

    /** Returns the GCE identity, or empty if the metadata service is unreachable / not Google. */
    private static Optional<GcpIdentity> fetchGcpIdentity() {
        var client = HttpClient.newBuilder()
                .connectTimeout(METADATA_TIMEOUT)
                .build();

        var instanceId = readMetadata(client, "instance/id");
        if (instanceId == null || instanceId.isBlank()) {
            return Optional.empty();
        }

        // machine-type and zone come back as full resource paths; keep only the leaf segment.
        var machineType = parseLeafResource(readMetadata(client, "instance/machine-type"));
        var zone = parseLeafResource(readMetadata(client, "instance/zone"));
        return Optional.of(new GcpIdentity(instanceId.trim(), machineType, zone));
    }

    private static GcpDiskData fetchGcpDiskData(GcpIdentity ignoredIdentity) {
        var byDeviceName = fetchDisksByDeviceNameFromMetadata();
        var aliasesByNormalizedDevice = mapGoogleAliasesByNormalizedDevice();
        return new GcpDiskData(byDeviceName, aliasesByNormalizedDevice);
    }

    /** Enumerates instance/disks/ from the metadata service into device-name -> disk attributes. */
    private static Map<String, GcpDiskInfo> fetchDisksByDeviceNameFromMetadata() {
        var client = HttpClient.newBuilder()
                .connectTimeout(METADATA_TIMEOUT)
                .build();

        var indexListing = readMetadata(client, "instance/disks/");
        if (indexListing == null || indexListing.isBlank()) {
            return Map.of();
        }

        var byDeviceName = new LinkedHashMap<String, GcpDiskInfo>();
        for (var rawLine : indexListing.split("\n")) {
            var line = rawLine.trim();
            if (line.isEmpty()) {
                continue;
            }
            // Directory listings return entries like "0/"; strip the trailing slash to get the index.
            var index = line.endsWith("/") ? line.substring(0, line.length() - 1) : line;
            var deviceName = readMetadata(client, "instance/disks/" + index + "/device-name");
            if (deviceName == null || deviceName.isBlank()) {
                continue;
            }

            var diskKind = safeLower(readMetadata(client, "instance/disks/" + index + "/type"));
            var interfaceType = safeUpper(readMetadata(client, "instance/disks/" + index + "/interface"));
            var diskTypeHint = readMetadata(client, "instance/disks/" + index + "/disk-type");
            byDeviceName.put(deviceName.trim(), new GcpDiskInfo(deviceName.trim(), diskKind, interfaceType, safeLower(diskTypeHint)));
        }
        return byDeviceName;
    }

    /**
     * Maps normalized device paths (e.g. /dev/sdb) to their "google-*" aliases from /dev/disk/by-id,
     * which encode the GCE device-name. Aliases are sorted for deterministic resolution.
     */
    private static Map<String, List<String>> mapGoogleAliasesByNormalizedDevice() {
        var byIdDir = Path.of("/dev/disk/by-id");
        if (!Files.isDirectory(byIdDir)) {
            return Map.of();
        }

        var aliasesByDevice = new LinkedHashMap<String, List<String>>();
        try (Stream<Path> entries = Files.list(byIdDir)) {
            entries.filter(Files::isSymbolicLink).forEach(link -> {
                var alias = link.getFileName().toString();
                if (!alias.startsWith("google-")) {
                    return;
                }
                try {
                    var target = normalizeDevice(link.toRealPath().toString());
                    aliasesByDevice.computeIfAbsent(target, unused -> new ArrayList<>()).add(alias);
                } catch (IOException ignored) {
                    // continue
                }
            });
        } catch (IOException ignored) {
            return Map.of();
        }

        for (var aliases : aliasesByDevice.values()) {
            aliases.sort(String::compareTo);
        }
        return aliasesByDevice;
    }

    /** Resolves a mount source device to its GCE disk attributes via the google-* by-id aliases. */
    private static DiskResolution resolveDisk(String mountSource, GcpDiskData diskData) {
        if (mountSource == null || !mountSource.startsWith("/dev/")) {
            return DiskResolution.empty();
        }

        var normalized = normalizeDevice(mountSource);
        var aliases = diskData.aliasesByNormalizedDevice().getOrDefault(normalized, List.of());
        var primaryAlias = aliases.isEmpty() ? null : aliases.get(0);
        var inferredDeviceName = primaryAlias == null ? null : stripGooglePrefix(primaryAlias);
        GcpDiskInfo info = inferredDeviceName == null ? null : diskData.byDeviceName().get(inferredDeviceName);

        // Try all aliases in case the first one doesn't match a metadata device-name.
        if (info == null) {
            for (var alias : aliases) {
                var candidate = stripGooglePrefix(alias);
                if (candidate == null) {
                    continue;
                }
                info = diskData.byDeviceName().get(candidate);
                if (info != null) {
                    inferredDeviceName = candidate;
                    break;
                }
            }
        }

        var rotational = readRotationalFlag(normalized);
        if (info == null) {
            return new DiskResolution(normalized, inferredDeviceName, null, null, null, rotational);
        }
        return new DiskResolution(
                normalized,
                inferredDeviceName,
                info.diskKind(),
                info.interfaceType(),
                info.diskTypeHint(),
                rotational
        );
    }

    /** Classifies a mount from filesystem type, GCE disk attributes, and local block-device hints. */
    private static StorageClass classify(MountEntry mount, DiskResolution diskResolution) {
        var fsType = safeLower(mount.filesystemType());
        var source = mount.source();
        var sourceLower = safeLower(source);

        if ("tmpfs".equals(fsType)) {
            return StorageClass.MEMORY_TMPFS;
        }
        if (NETWORK_FILESYSTEM_TYPES.contains(fsType)) {
            return StorageClass.NETWORK_FILESYSTEM;
        }
        if (isPseudoFileSystem(fsType, sourceLower)) {
            return StorageClass.PSEUDO_FILESYSTEM;
        }

        // "scratch" is GCE's kind for local SSD; "persistent" for PD volumes.
        if ("scratch".equals(diskResolution.diskKind())) {
            if ("NVME".equals(diskResolution.interfaceType()) || sourceLower.contains("nvme")) {
                return StorageClass.LOCAL_NVME;
            }
            return StorageClass.LOCAL_SSD;
        }
        if ("persistent".equals(diskResolution.diskKind())) {
            return classifyPersistentDisk(diskResolution);
        }

        // Best-effort fallback based on device name hints and local block characteristics.
        var hints = safeLower(diskResolution.deviceName()) + " "
                + safeLower(diskResolution.diskTypeHint()) + " "
                + sourceLower;
        if (hints.contains("local-ssd")) {
            return sourceLower.contains("nvme") ? StorageClass.LOCAL_NVME : StorageClass.LOCAL_SSD;
        }
        if (source != null && source.startsWith("/dev/")) {
            if (sourceLower.contains("nvme")) {
                return StorageClass.LOCAL_NVME;
            }
            if (Boolean.TRUE.equals(diskResolution.rotational())) {
                return StorageClass.PD_STANDARD_HDD;
            }
            return StorageClass.LOCAL_SSD;
        }
        return StorageClass.UNKNOWN;
    }

    /** Picks a PD tier from name/type hints, falling back to the sysfs rotational flag. */
    private static StorageClass classifyPersistentDisk(DiskResolution diskResolution) {
        var hints = safeLower(diskResolution.deviceName()) + " " + safeLower(diskResolution.diskTypeHint());
        if (hints.contains("extreme")) {
            return StorageClass.PD_EXTREME_SSD;
        }
        if (hints.contains("throughput")) {
            return StorageClass.PD_THROUGHPUT_OPTIMIZED;
        }
        if (hints.contains("balanced")) {
            return StorageClass.PD_BALANCED_SSD;
        }
        if (hints.contains("pd-ssd") || hints.contains("ssd")) {
            return StorageClass.PD_SSD;
        }
        if (hints.contains("standard")) {
            return StorageClass.PD_STANDARD_HDD;
        }

        if (Boolean.TRUE.equals(diskResolution.rotational())) {
            return StorageClass.PD_STANDARD_HDD;
        }
        return StorageClass.PD_BALANCED_SSD;
    }

    /**
     * GETs a metadata path, returning the body or null on any failure. Requires the
     * Metadata-Flavor: Google response header so a captive portal cannot spoof the service.
     */
    private static String readMetadata(HttpClient client, String relativePath) {
        var host = Optional.ofNullable(System.getenv(GCE_METADATA_HOST_ENV)).orElse(METADATA_HOST_DEFAULT);
        var uri = URI.create("http://" + host + METADATA_PREFIX + relativePath);
        try {
            var request = HttpRequest.newBuilder(uri)
                    .timeout(METADATA_TIMEOUT)
                    .header(METADATA_FLAVOR_HEADER, METADATA_FLAVOR_VALUE)
                    .GET()
                    .build();
            var response = client.send(request, HttpResponse.BodyHandlers.ofString());
            if (response.statusCode() != 200) {
                return null;
            }
            var flavorHeader = response.headers().firstValue(METADATA_FLAVOR_HEADER).orElse("");
            if (!METADATA_FLAVOR_VALUE.equalsIgnoreCase(flavorHeader)) {
                return null;
            }
            return response.body();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return null;
        } catch (IOException e) {
            return null;
        }
    }

    /** Returns the segment after the last '/', or the whole trimmed value if there is none. */
    private static String parseLeafResource(String value) {
        if (value == null) {
            return null;
        }
        var trimmed = value.trim();
        if (trimmed.isEmpty()) {
            return null;
        }
        var idx = trimmed.lastIndexOf('/');
        if (idx < 0 || idx == trimmed.length() - 1) {
            return trimmed;
        }
        return trimmed.substring(idx + 1);
    }

    /** Reads /proc/self/mounts (or /proc/mounts) into (source, mount point, fs type) entries. */
    private static List<MountEntry> readMountEntries() {
        var mountsPath = Files.isReadable(Path.of("/proc/self/mounts"))
                ? Path.of("/proc/self/mounts")
                : Path.of("/proc/mounts");

        if (!Files.isReadable(mountsPath)) {
            return new ArrayList<>();
        }

        var entries = new ArrayList<MountEntry>();
        try (Stream<String> lines = Files.lines(mountsPath)) {
            lines.forEach(line -> {
                var parts = line.split(" ");
                if (parts.length < 3) {
                    return;
                }
                var source = decodeMountToken(parts[0]);
                var mountPoint = decodeMountToken(parts[1]);
                var filesystemType = decodeMountToken(parts[2]);
                entries.add(new MountEntry(source, mountPoint, filesystemType));
            });
        } catch (IOException ignored) {
            return new ArrayList<>();
        }
        return entries;
    }

    /** Reads the sysfs rotational flag: TRUE for spinning media, FALSE for SSD, null if unknown. */
    private static Boolean readRotationalFlag(String normalizedDevice) {
        if (normalizedDevice == null || !normalizedDevice.startsWith("/dev/")) {
            return null;
        }
        var blockName = normalizedDevice.substring("/dev/".length());
        var rotaPath = Path.of("/sys/class/block", blockName, "queue", "rotational");
        if (!Files.isReadable(rotaPath)) {
            return null;
        }
        try {
            var value = Files.readString(rotaPath).trim();
            if ("1".equals(value)) {
                return Boolean.TRUE;
            }
            if ("0".equals(value)) {
                return Boolean.FALSE;
            }
        } catch (IOException ignored) {
            return null;
        }
        return null;
    }

    private static boolean isPseudoFileSystem(String fsType, String sourceLower) {
        return fsType.equals("proc")
                || fsType.equals("sysfs")
                || fsType.equals("devpts")
                || fsType.equals("devtmpfs")
                || fsType.equals("cgroup")
                || fsType.equals("cgroup2")
                || fsType.equals("autofs")
                || fsType.equals("mqueue")
                || fsType.equals("tracefs")
                || fsType.equals("pstore")
                || fsType.equals("securityfs")
                || fsType.equals("debugfs")
                || fsType.equals("configfs")
                || fsType.equals("fusectl")
                || fsType.equals("binfmt_misc")
                || fsType.equals("rpc_pipefs")
                || sourceLower.equals("proc")
                || sourceLower.equals("sysfs")
                || sourceLower.equals("tmpfs");
    }

    /** Strips partition suffixes: /dev/nvme0n1p2 -> /dev/nvme0n1, /dev/sda1 -> /dev/sda. */
    private static String normalizeDevice(String device) {
        if (device == null) {
            return null;
        }
        if (!device.startsWith("/dev/")) {
            return device;
        }
        if (device.startsWith("/dev/nvme")) {
            return NVME_PARTITION_SUFFIX.matcher(device).replaceAll("");
        }
        return GENERIC_PARTITION_SUFFIX.matcher(device).replaceAll("");
    }

    /**
     * Decodes the octal escapes /proc/mounts uses for whitespace.
     * \134 (backslash) must be decoded last so it cannot create new escape sequences.
     */
    private static String decodeMountToken(String token) {
        return token
                .replace("\\040", " ")
                .replace("\\011", "\t")
                .replace("\\012", "\n")
                .replace("\\134", "\\");
    }

    /** Returns the device-name encoded in a "google-" alias, or null if it is not one. */
    private static String stripGooglePrefix(String alias) {
        if (alias == null || !alias.startsWith("google-") || alias.length() <= "google-".length()) {
            return null;
        }
        return alias.substring("google-".length());
    }

    // Null-safe case folding; safeLower maps null to "" (callers concatenate it into hint strings).
    private static String safeLower(String value) {
        return value == null ? "" : value.toLowerCase(Locale.ROOT);
    }

    private static String safeUpper(String value) {
        return value == null ? null : value.trim().toUpperCase(Locale.ROOT);
    }

    /** One line of the mount table. */
    private static final class MountEntry {
        private final String source;
        private final String mountPoint;
        private final String filesystemType;

        private MountEntry(String source, String mountPoint, String filesystemType) {
            this.source = source;
            this.mountPoint = mountPoint;
            this.filesystemType = filesystemType;
        }

        private String source() {
            return source;
        }

        private String mountPoint() {
            return mountPoint;
        }

        private String filesystemType() {
            return filesystemType;
        }
    }

    /** GCE identity fields read from the metadata service. */
    private static final class GcpIdentity {
        private final String instanceId;
        private final String machineType;
        private final String zone;

        private GcpIdentity(String instanceId, String machineType, String zone) {
            this.instanceId = instanceId;
            this.machineType = machineType;
            this.zone = zone;
        }

        private String instanceId() {
            return instanceId;
        }

        private String machineType() {
            return machineType;
        }

        private String zone() {
            return zone;
        }
    }

    /** Attributes of one attached disk, as reported by instance/disks/ metadata. */
    private static final class GcpDiskInfo {
        private final String deviceName;
        private final String diskKind;
        private final String interfaceType;
        private final String diskTypeHint;

        private GcpDiskInfo(String deviceName, String diskKind, String interfaceType, String diskTypeHint) {
            this.deviceName = deviceName;
            this.diskKind = diskKind;
            this.interfaceType = interfaceType;
            this.diskTypeHint = diskTypeHint;
        }

        private String deviceName() {
            return deviceName;
        }

        private String diskKind() {
            return diskKind;
        }

        private String interfaceType() {
            return interfaceType;
        }

        private String diskTypeHint() {
            return diskTypeHint;
        }
    }

    /** Aggregated disk metadata plus the local by-id alias map used to resolve devices. */
    private static final class GcpDiskData {
        private final Map<String, GcpDiskInfo> byDeviceName;
        private final Map<String, List<String>> aliasesByNormalizedDevice;

        private GcpDiskData(Map<String, GcpDiskInfo> byDeviceName, Map<String, List<String>> aliasesByNormalizedDevice) {
            this.byDeviceName = Objects.requireNonNull(byDeviceName, "byDeviceName");
            this.aliasesByNormalizedDevice = Objects.requireNonNull(aliasesByNormalizedDevice, "aliasesByNormalizedDevice");
        }

        private Map<String, GcpDiskInfo> byDeviceName() {
            return byDeviceName;
        }

        private Map<String, List<String>> aliasesByNormalizedDevice() {
            return aliasesByNormalizedDevice;
        }

        private static GcpDiskData empty() {
            return new GcpDiskData(Map.of(), Map.of());
        }
    }

    /** What a mount source resolved to: device, GCE disk attributes, and rotational flag. */
    private static final class DiskResolution {
        private final String normalizedDevice;
        private final String deviceName;
        private final String diskKind;
        private final String interfaceType;
        private final String diskTypeHint;
        private final Boolean rotational;  // tri-state: TRUE / FALSE / null (unknown)

        private DiskResolution(String normalizedDevice,
                               String deviceName,
                               String diskKind,
                               String interfaceType,
                               String diskTypeHint,
                               Boolean rotational) {
            this.normalizedDevice = normalizedDevice;
            this.deviceName = deviceName;
            this.diskKind = diskKind;
            this.interfaceType = interfaceType;
            this.diskTypeHint = diskTypeHint;
            this.rotational = rotational;
        }

        private static DiskResolution empty() {
            return new DiskResolution(null, null, null, null, null, null);
        }

        private String normalizedDevice() {
            return normalizedDevice;
        }

        private String deviceName() {
            return deviceName;
        }

        private String diskKind() {
            return diskKind;
        }

        private String interfaceType() {
            return interfaceType;
        }

        private String diskTypeHint() {
            return diskTypeHint;
        }

        private Boolean rotational() {
            return rotational;
        }
    }
}
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.github.jbellis.jvector.example.util.storage; + +import java.io.BufferedReader; +import java.io.File; +import java.io.IOException; +import java.io.InputStreamReader; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.regex.Pattern; +import java.util.stream.Stream; + +/** + * Best-effort storage inspection utility for non-cloud environments. + * Supports Linux, macOS, and Windows using local OS signals and common mount metadata. 
+ */ +public final class LocalStorageLayoutUtil { + private static final Pattern LINUX_NVME_PARTITION_SUFFIX = Pattern.compile("p\\d+$"); + private static final Pattern GENERIC_PARTITION_SUFFIX = Pattern.compile("\\d+$"); + private static final Pattern MAC_MOUNT_PATTERN = Pattern.compile("^(.+) on (.+) \\((.+)\\)$"); + private static final Pattern MAC_DISK_SLICE_SUFFIX = Pattern.compile("s\\d+$"); + private static final Set NETWORK_FILESYSTEM_TYPES = + Set.of("nfs", "nfs4", "efs", "cifs", "smbfs", "fuse.sshfs", "afpfs", "webdav", "davfs"); + + private LocalStorageLayoutUtil() { + } + + public enum StorageClass { + LOCAL_HDD, + LOCAL_SSD, + LOCAL_NVME, + NETWORK_FILESYSTEM, + MEMORY_TMPFS, + PSEUDO_FILESYSTEM, + UNKNOWN + } + + public static final class StorageSnapshot { + private final String osName; + private final Map mountsByMountPoint; + + public StorageSnapshot(String osName, Map mountsByMountPoint) { + this.osName = osName; + this.mountsByMountPoint = Objects.requireNonNull(mountsByMountPoint, "mountsByMountPoint"); + } + + public String osName() { + return osName; + } + + public Map mountsByMountPoint() { + return mountsByMountPoint; + } + } + + public static final class MountStorageInfo { + private final String mountPoint; + private final String source; + private final String filesystemType; + private final StorageClass storageClass; + private final String osHint; + + public MountStorageInfo(String mountPoint, + String source, + String filesystemType, + StorageClass storageClass, + String osHint) { + this.mountPoint = mountPoint; + this.source = source; + this.filesystemType = filesystemType; + this.storageClass = Objects.requireNonNull(storageClass, "storageClass"); + this.osHint = osHint; + } + + public String mountPoint() { + return mountPoint; + } + + public String source() { + return source; + } + + public String filesystemType() { + return filesystemType; + } + + public StorageClass storageClass() { + return storageClass; + } + + public String 
osHint() { + return osHint; + } + } + + public static StorageSnapshot inspectStorage() { + var os = safeLower(System.getProperty("os.name")); + List mounts; + if (isLinux(os)) { + mounts = readLinuxMountEntries(); + } else if (isMac(os)) { + mounts = readMacMountEntries(); + } else if (isWindows(os)) { + mounts = readWindowsMountEntries(); + } else { + mounts = readGenericMountEntries(); + } + + mounts.sort(Comparator.comparing(MountEntry::mountPoint)); + var byMountPoint = new LinkedHashMap(mounts.size()); + for (var mount : mounts) { + StorageClass storageClass; + String osHint; + if (isLinux(os)) { + storageClass = classifyLinux(mount); + osHint = "linux"; + } else if (isMac(os)) { + storageClass = classifyMac(mount); + osHint = "macos"; + } else if (isWindows(os)) { + storageClass = classifyWindows(mount); + osHint = "windows"; + } else { + storageClass = classifyGeneric(mount); + osHint = "generic"; + } + + byMountPoint.put( + mount.mountPoint(), + new MountStorageInfo( + mount.mountPoint(), + mount.source(), + mount.filesystemType(), + storageClass, + osHint + ) + ); + } + + return new StorageSnapshot( + System.getProperty("os.name"), + Collections.unmodifiableMap(byMountPoint) + ); + } + + public static Map storageClassByMountPoint() { + var snapshot = inspectStorage(); + var byMountPoint = new LinkedHashMap(snapshot.mountsByMountPoint().size()); + for (var entry : snapshot.mountsByMountPoint().entrySet()) { + byMountPoint.put(entry.getKey(), entry.getValue().storageClass()); + } + return Collections.unmodifiableMap(byMountPoint); + } + + private static List readLinuxMountEntries() { + var mountsPath = Files.isReadable(Path.of("/proc/self/mounts")) + ? 
Path.of("/proc/self/mounts") + : Path.of("/proc/mounts"); + if (!Files.isReadable(mountsPath)) { + return new ArrayList<>(); + } + + var entries = new ArrayList(); + try (Stream lines = Files.lines(mountsPath)) { + lines.forEach(line -> { + var parts = line.split(" "); + if (parts.length < 3) { + return; + } + entries.add(new MountEntry( + decodeMountToken(parts[0]), + decodeMountToken(parts[1]), + decodeMountToken(parts[2]) + )); + }); + } catch (IOException ignored) { + return new ArrayList<>(); + } + return entries; + } + + private static List readMacMountEntries() { + var entries = new ArrayList(); + for (String line : runCommandLines("mount")) { + var matcher = MAC_MOUNT_PATTERN.matcher(line); + if (!matcher.matches()) { + continue; + } + var source = matcher.group(1).trim(); + var mountPoint = matcher.group(2).trim(); + var options = matcher.group(3).trim(); + var fsType = options.split(",")[0].trim(); + entries.add(new MountEntry(source, mountPoint, fsType)); + } + if (entries.isEmpty()) { + return readGenericMountEntries(); + } + return entries; + } + + private static List readWindowsMountEntries() { + var entries = new ArrayList(); + var roots = File.listRoots(); + if (roots == null) { + return entries; + } + for (var root : roots) { + if (root == null) { + continue; + } + var path = root.toPath(); + String fsType = "unknown"; + try { + fsType = Files.getFileStore(path).type(); + } catch (IOException ignored) { + // keep default + } + entries.add(new MountEntry(root.getPath(), root.getPath(), fsType)); + } + return entries; + } + + private static List readGenericMountEntries() { + var entries = new ArrayList(); + var roots = File.listRoots(); + if (roots == null) { + return entries; + } + for (var root : roots) { + if (root == null) { + continue; + } + String fsType = "unknown"; + try { + fsType = Files.getFileStore(root.toPath()).type(); + } catch (IOException ignored) { + // keep default + } + entries.add(new MountEntry(root.getPath(), root.getPath(), 
fsType)); + } + return entries; + } + + private static StorageClass classifyLinux(MountEntry mount) { + var fsType = safeLower(mount.filesystemType()); + var source = mount.source(); + var sourceLower = safeLower(source); + + if ("tmpfs".equals(fsType) || "ramfs".equals(fsType)) { + return StorageClass.MEMORY_TMPFS; + } + if (NETWORK_FILESYSTEM_TYPES.contains(fsType) || sourceLower.startsWith("//")) { + return StorageClass.NETWORK_FILESYSTEM; + } + if (isPseudoFileSystem(fsType, sourceLower)) { + return StorageClass.PSEUDO_FILESYSTEM; + } + + if (source != null && source.startsWith("/dev/")) { + var normalized = normalizeLinuxDevice(sourceLower); + if (normalized.contains("nvme")) { + return StorageClass.LOCAL_NVME; + } + + Boolean rotational = readLinuxRotationalFlag(normalized); + if (Boolean.TRUE.equals(rotational)) { + return StorageClass.LOCAL_HDD; + } + if (Boolean.FALSE.equals(rotational)) { + return StorageClass.LOCAL_SSD; + } + return StorageClass.UNKNOWN; + } + return StorageClass.UNKNOWN; + } + + private static StorageClass classifyMac(MountEntry mount) { + var fsType = safeLower(mount.filesystemType()); + var source = mount.source(); + var sourceLower = safeLower(source); + + if ("devfs".equals(fsType) || "autofs".equals(fsType) || "procfs".equals(fsType)) { + return StorageClass.PSEUDO_FILESYSTEM; + } + if ("tmpfs".equals(fsType) || "ramfs".equals(fsType)) { + return StorageClass.MEMORY_TMPFS; + } + if (NETWORK_FILESYSTEM_TYPES.contains(fsType) || sourceLower.startsWith("//")) { + return StorageClass.NETWORK_FILESYSTEM; + } + + if (source != null && source.startsWith("/dev/")) { + var diskInfo = readMacDiskInfo(source); + if (diskInfo.protocolNvme) { + return StorageClass.LOCAL_NVME; + } + if (diskInfo.solidState != null) { + return diskInfo.solidState ? 
StorageClass.LOCAL_SSD : StorageClass.LOCAL_HDD; + } + if (sourceLower.contains("nvme")) { + return StorageClass.LOCAL_NVME; + } + return StorageClass.UNKNOWN; + } + return StorageClass.UNKNOWN; + } + + private static StorageClass classifyWindows(MountEntry mount) { + var fsType = safeLower(mount.filesystemType()); + var source = mount.source(); + var sourceLower = safeLower(source); + + if (NETWORK_FILESYSTEM_TYPES.contains(fsType) + || fsType.contains("smb") + || fsType.contains("cifs") + || sourceLower.startsWith("\\\\")) { + return StorageClass.NETWORK_FILESYSTEM; + } + if (fsType.contains("tmp") || fsType.contains("ram")) { + return StorageClass.MEMORY_TMPFS; + } + + // Generic stub: fixed drives are treated as local SSD class when media specifics are unavailable. + if (source != null && source.matches("^[A-Za-z]:\\\\.*")) { + return StorageClass.LOCAL_SSD; + } + return StorageClass.UNKNOWN; + } + + private static StorageClass classifyGeneric(MountEntry mount) { + var fsType = safeLower(mount.filesystemType()); + if ("tmpfs".equals(fsType) || "ramfs".equals(fsType)) { + return StorageClass.MEMORY_TMPFS; + } + if (NETWORK_FILESYSTEM_TYPES.contains(fsType)) { + return StorageClass.NETWORK_FILESYSTEM; + } + return StorageClass.UNKNOWN; + } + + private static boolean isPseudoFileSystem(String fsType, String sourceLower) { + return fsType.equals("proc") + || fsType.equals("sysfs") + || fsType.equals("devpts") + || fsType.equals("devtmpfs") + || fsType.equals("cgroup") + || fsType.equals("cgroup2") + || fsType.equals("autofs") + || fsType.equals("mqueue") + || fsType.equals("tracefs") + || fsType.equals("pstore") + || fsType.equals("securityfs") + || fsType.equals("debugfs") + || fsType.equals("configfs") + || fsType.equals("fusectl") + || fsType.equals("binfmt_misc") + || fsType.equals("rpc_pipefs") + || sourceLower.equals("proc") + || sourceLower.equals("sysfs") + || sourceLower.equals("tmpfs"); + } + + private static String normalizeLinuxDevice(String device) { + 
if (!device.startsWith("/dev/")) { + return device; + } + if (device.startsWith("/dev/nvme")) { + return LINUX_NVME_PARTITION_SUFFIX.matcher(device).replaceAll(""); + } + return GENERIC_PARTITION_SUFFIX.matcher(device).replaceAll(""); + } + + private static Boolean readLinuxRotationalFlag(String normalizedDevice) { + if (normalizedDevice == null || !normalizedDevice.startsWith("/dev/")) { + return null; + } + var blockName = normalizedDevice.substring("/dev/".length()); + var rotaPath = Path.of("/sys/class/block", blockName, "queue", "rotational"); + if (!Files.isReadable(rotaPath)) { + return null; + } + try { + var value = Files.readString(rotaPath).trim(); + if ("1".equals(value)) { + return Boolean.TRUE; + } + if ("0".equals(value)) { + return Boolean.FALSE; + } + } catch (IOException ignored) { + return null; + } + return null; + } + + private static MacDiskInfo readMacDiskInfo(String sourceDevice) { + var base = sourceDevice; + var slash = sourceDevice.lastIndexOf('/'); + if (slash >= 0 && slash + 1 < sourceDevice.length()) { + var leaf = sourceDevice.substring(slash + 1); + if (leaf.startsWith("disk")) { + leaf = MAC_DISK_SLICE_SUFFIX.matcher(leaf).replaceAll(""); + base = "/dev/" + leaf; + } + } + + Boolean solidState = null; + boolean protocolNvme = false; + for (String line : runCommandLines("diskutil", "info", base)) { + var trimmed = line.trim(); + var lower = safeLower(trimmed); + if (lower.startsWith("solid state:")) { + solidState = lower.endsWith("yes"); + } else if (lower.startsWith("protocol:")) { + protocolNvme = lower.contains("nvme"); + } else if (lower.startsWith("device / media name:") && lower.contains("nvme")) { + protocolNvme = true; + } + } + return new MacDiskInfo(solidState, protocolNvme); + } + + private static List runCommandLines(String... 
command) { + var lines = new ArrayList(); + var pb = new ProcessBuilder(command); + pb.redirectErrorStream(true); + try { + var process = pb.start(); + try (var reader = new BufferedReader(new InputStreamReader(process.getInputStream()))) { + String line; + while ((line = reader.readLine()) != null) { + lines.add(line); + } + } + process.waitFor(); + } catch (IOException | InterruptedException ignored) { + if (ignored instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + } + return lines; + } + + private static String decodeMountToken(String token) { + return token + .replace("\\040", " ") + .replace("\\011", "\t") + .replace("\\012", "\n") + .replace("\\134", "\\"); + } + + private static boolean isLinux(String osNameLower) { + return osNameLower.contains("linux"); + } + + private static boolean isMac(String osNameLower) { + return osNameLower.contains("mac") || osNameLower.contains("darwin"); + } + + private static boolean isWindows(String osNameLower) { + return osNameLower.contains("win"); + } + + private static String safeLower(String value) { + return value == null ? 
"" : value.toLowerCase(Locale.ROOT); + } + + private static final class MountEntry { + private final String source; + private final String mountPoint; + private final String filesystemType; + + private MountEntry(String source, String mountPoint, String filesystemType) { + this.source = source; + this.mountPoint = mountPoint; + this.filesystemType = filesystemType; + } + + private String source() { + return source; + } + + private String mountPoint() { + return mountPoint; + } + + private String filesystemType() { + return filesystemType; + } + } + + private static final class MacDiskInfo { + private final Boolean solidState; + private final boolean protocolNvme; + + private MacDiskInfo(Boolean solidState, boolean protocolNvme) { + this.solidState = solidState; + this.protocolNvme = protocolNvme; + } + } +} diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/storage/StorageLayoutUtil.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/storage/StorageLayoutUtil.java new file mode 100644 index 000000000..bfbaed234 --- /dev/null +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/storage/StorageLayoutUtil.java @@ -0,0 +1,584 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.github.jbellis.jvector.example.util.storage; + +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.ec2.Ec2Client; +import software.amazon.awssdk.services.ec2.model.DescribeInstancesRequest; +import software.amazon.awssdk.services.ec2.model.DescribeVolumesRequest; +import software.amazon.awssdk.services.ec2.model.InstanceBlockDeviceMapping; +import software.amazon.awssdk.services.ec2.model.Volume; + +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.file.Files; +import java.nio.file.Path; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.regex.Pattern; +import java.util.stream.Stream; + +/** + * Detects EC2 runtime context via IMDSv2 and classifies storage for each mounted filesystem. 
+ */ +public final class StorageLayoutUtil { + private static final String AWS_EC2_METADATA_DISABLED = "AWS_EC2_METADATA_DISABLED"; + private static final URI IMDS_TOKEN_URI = URI.create("http://169.254.169.254/latest/api/token"); + private static final URI IMDS_IDENTITY_URI = URI.create("http://169.254.169.254/latest/dynamic/instance-identity/document"); + private static final Duration IMDS_TIMEOUT = Duration.ofMillis(300); + private static final String IMDS_TOKEN_HEADER = "X-aws-ec2-metadata-token"; + private static final String IMDS_TOKEN_TTL_HEADER = "X-aws-ec2-metadata-token-ttl-seconds"; + + private static final Pattern JSON_FIELD_PATTERN = Pattern.compile("\"([^\"]+)\"\\s*:\\s*\"([^\"]+)\""); + private static final Pattern VOL_ID_PATTERN = Pattern.compile("vol-?[0-9a-fA-F]+"); + private static final Pattern NVME_PARTITION_SUFFIX = Pattern.compile("p\\d+$"); + private static final Pattern GENERIC_PARTITION_SUFFIX = Pattern.compile("\\d+$"); + private static final Set NETWORK_FILESYSTEM_TYPES = Set.of("nfs", "nfs4", "efs", "cifs", "smbfs", "fuse.sshfs"); + + private StorageLayoutUtil() { + } + + public enum StorageClass { + // Slowest EBS tiers + EBS_COLD_HDD, + EBS_THROUGHPUT_HDD, + EBS_MAGNETIC, + + // Faster EBS SSD tiers + EBS_GP2, + EBS_GP3, + EBS_PROVISIONED_IOPS_SSD, + + // Local instance storage + INSTANCE_STORE_SSD, + INSTANCE_STORE_NVME, + + // Non-block storage + NETWORK_FILESYSTEM, + MEMORY_TMPFS, + PSEUDO_FILESYSTEM, + UNKNOWN + } + + public static final class StorageSnapshot { + private final boolean runningOnEc2; + private final String instanceId; + private final String instanceType; + private final String region; + private final Map mountsByMountPoint; + + public StorageSnapshot(boolean runningOnEc2, + String instanceId, + String instanceType, + String region, + Map mountsByMountPoint) { + this.runningOnEc2 = runningOnEc2; + this.instanceId = instanceId; + this.instanceType = instanceType; + this.region = region; + this.mountsByMountPoint = 
Objects.requireNonNull(mountsByMountPoint, "mountsByMountPoint"); + } + + public boolean runningOnEc2() { + return runningOnEc2; + } + + public String instanceId() { + return instanceId; + } + + public String instanceType() { + return instanceType; + } + + public String region() { + return region; + } + + public Map mountsByMountPoint() { + return mountsByMountPoint; + } + } + + public static final class MountStorageInfo { + private final String mountPoint; + private final String source; + private final String filesystemType; + private final StorageClass storageClass; + private final String volumeId; + private final String volumeType; + + public MountStorageInfo(String mountPoint, + String source, + String filesystemType, + StorageClass storageClass, + String volumeId, + String volumeType) { + this.mountPoint = mountPoint; + this.source = source; + this.filesystemType = filesystemType; + this.storageClass = Objects.requireNonNull(storageClass, "storageClass"); + this.volumeId = volumeId; + this.volumeType = volumeType; + } + + public String mountPoint() { + return mountPoint; + } + + public String source() { + return source; + } + + public String filesystemType() { + return filesystemType; + } + + public StorageClass storageClass() { + return storageClass; + } + + public String volumeId() { + return volumeId; + } + + public String volumeType() { + return volumeType; + } + } + + public static StorageSnapshot inspectStorage() { + var identity = fetchEc2Identity(); + var mounts = readMountEntries(); + var ec2Data = identity.map(StorageLayoutUtil::fetchEc2VolumeData).orElse(Ec2VolumeData.empty()); + + mounts.sort(Comparator.comparing(MountEntry::mountPoint)); + var byMountPoint = new LinkedHashMap(mounts.size()); + for (var mount : mounts) { + var resolvedVolumeId = resolveVolumeId(mount.source(), ec2Data); + var volumeType = resolvedVolumeId == null ? 
null : ec2Data.volumeTypeById().get(resolvedVolumeId); + var storageClass = classify(mount, resolvedVolumeId, volumeType); + byMountPoint.put( + mount.mountPoint(), + new MountStorageInfo( + mount.mountPoint(), + mount.source(), + mount.filesystemType(), + storageClass, + resolvedVolumeId, + volumeType + ) + ); + } + + return new StorageSnapshot( + identity.isPresent(), + identity.map(Ec2Identity::instanceId).orElse(null), + identity.map(Ec2Identity::instanceType).orElse(null), + identity.map(Ec2Identity::region).orElse(null), + Collections.unmodifiableMap(byMountPoint) + ); + } + + public static Map storageClassByMountPoint() { + var snapshot = inspectStorage(); + var byMountPoint = new LinkedHashMap(snapshot.mountsByMountPoint().size()); + for (var entry : snapshot.mountsByMountPoint().entrySet()) { + byMountPoint.put(entry.getKey(), entry.getValue().storageClass()); + } + return Collections.unmodifiableMap(byMountPoint); + } + + private static Optional fetchEc2Identity() { + var imdsDisabled = System.getenv(AWS_EC2_METADATA_DISABLED); + if (imdsDisabled != null && "true".equalsIgnoreCase(imdsDisabled)) { + return Optional.empty(); + } + + var client = HttpClient.newBuilder() + .connectTimeout(IMDS_TIMEOUT) + .build(); + try { + var tokenRequest = HttpRequest.newBuilder(IMDS_TOKEN_URI) + .timeout(IMDS_TIMEOUT) + .header(IMDS_TOKEN_TTL_HEADER, "60") + .method("PUT", HttpRequest.BodyPublishers.noBody()) + .build(); + var tokenResponse = client.send(tokenRequest, HttpResponse.BodyHandlers.ofString()); + if (tokenResponse.statusCode() != 200) { + return Optional.empty(); + } + + var token = tokenResponse.body(); + if (token == null || token.isBlank()) { + return Optional.empty(); + } + + var identityRequest = HttpRequest.newBuilder(IMDS_IDENTITY_URI) + .timeout(IMDS_TIMEOUT) + .header(IMDS_TOKEN_HEADER, token) + .GET() + .build(); + var identityResponse = client.send(identityRequest, HttpResponse.BodyHandlers.ofString()); + if (identityResponse.statusCode() != 200) { 
+ return Optional.empty(); + } + + return parseIdentity(identityResponse.body()); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return Optional.empty(); + } catch (IOException e) { + return Optional.empty(); + } + } + + private static Optional parseIdentity(String json) { + if (json == null || json.isBlank()) { + return Optional.empty(); + } + var values = new LinkedHashMap(); + var matcher = JSON_FIELD_PATTERN.matcher(json); + while (matcher.find()) { + values.put(matcher.group(1), matcher.group(2)); + } + + var instanceId = values.get("instanceId"); + var instanceType = values.get("instanceType"); + var region = values.get("region"); + if (instanceId == null || instanceType == null || region == null) { + return Optional.empty(); + } + return Optional.of(new Ec2Identity(instanceId, instanceType, region)); + } + + private static Ec2VolumeData fetchEc2VolumeData(Ec2Identity identity) { + var deviceNameToVolumeId = new LinkedHashMap(); + var volumeTypeById = new LinkedHashMap(); + var nvmeDeviceToVolumeId = mapNvmeDevicesToVolumeIds(); + + try (var ec2 = Ec2Client.builder().region(Region.of(identity.region())).build()) { + var instanceRequest = DescribeInstancesRequest.builder() + .instanceIds(identity.instanceId()) + .build(); + var instanceResponse = ec2.describeInstances(instanceRequest); + var reservations = instanceResponse.reservations(); + if (reservations != null) { + for (var reservation : reservations) { + for (var instance : reservation.instances()) { + for (InstanceBlockDeviceMapping mapping : instance.blockDeviceMappings()) { + if (mapping.ebs() == null || mapping.ebs().volumeId() == null || mapping.deviceName() == null) { + continue; + } + deviceNameToVolumeId.put(normalizeDevice(mapping.deviceName()), mapping.ebs().volumeId()); + } + } + } + } + + if (!deviceNameToVolumeId.isEmpty()) { + var volumeResponse = ec2.describeVolumes(DescribeVolumesRequest.builder() + .volumeIds(deviceNameToVolumeId.values()) + .build()); + for 
(Volume volume : volumeResponse.volumes()) { + if (volume.volumeId() != null && volume.volumeType() != null) { + volumeTypeById.put(volume.volumeId(), volume.volumeTypeAsString()); + } + } + } + } catch (RuntimeException ignored) { + // If IAM permissions or service calls fail, we still return mount classifications. + } + + return new Ec2VolumeData(deviceNameToVolumeId, nvmeDeviceToVolumeId, volumeTypeById); + } + + private static List readMountEntries() { + var mountsPath = Files.isReadable(Path.of("/proc/self/mounts")) + ? Path.of("/proc/self/mounts") + : Path.of("/proc/mounts"); + + if (!Files.isReadable(mountsPath)) { + return new ArrayList<>(); + } + + var entries = new ArrayList(); + try (Stream lines = Files.lines(mountsPath)) { + lines.forEach(line -> { + var parts = line.split(" "); + if (parts.length < 3) { + return; + } + var source = decodeMountToken(parts[0]); + var mountPoint = decodeMountToken(parts[1]); + var filesystemType = decodeMountToken(parts[2]); + entries.add(new MountEntry(source, mountPoint, filesystemType)); + }); + } catch (IOException ignored) { + return new ArrayList<>(); + } + return entries; + } + + private static Map mapNvmeDevicesToVolumeIds() { + var byIdDir = Path.of("/dev/disk/by-id"); + if (!Files.isDirectory(byIdDir)) { + return Map.of(); + } + + var mapping = new LinkedHashMap(); + try (Stream entries = Files.list(byIdDir)) { + entries.filter(Files::isSymbolicLink).forEach(link -> { + var name = link.getFileName().toString(); + if (!name.startsWith("nvme-Amazon_Elastic_Block_Store_")) { + return; + } + var volumeId = extractVolumeId(name); + if (volumeId == null) { + return; + } + + try { + var target = normalizeDevice(link.toRealPath().toString()); + mapping.put(target, volumeId); + } catch (IOException ignored) { + // continue + } + }); + } catch (IOException ignored) { + return Map.of(); + } + return mapping; + } + + private static String resolveVolumeId(String mountSource, Ec2VolumeData ec2Data) { + if (mountSource == 
null || !mountSource.startsWith("/dev/")) { + return null; + } + + var normalized = normalizeDevice(mountSource); + var byNvme = ec2Data.nvmeDeviceToVolumeId().get(normalized); + if (byNvme != null) { + return byNvme; + } + return ec2Data.deviceNameToVolumeId().get(normalized); + } + + private static StorageClass classify(MountEntry mount, String volumeId, String volumeType) { + var fsType = safeLower(mount.filesystemType()); + var source = mount.source(); + var sourceLower = safeLower(source); + + if ("tmpfs".equals(fsType)) { + return StorageClass.MEMORY_TMPFS; + } + if (NETWORK_FILESYSTEM_TYPES.contains(fsType)) { + return StorageClass.NETWORK_FILESYSTEM; + } + if (isPseudoFileSystem(fsType, sourceLower)) { + return StorageClass.PSEUDO_FILESYSTEM; + } + + if (volumeId != null) { + return mapEbsVolumeType(volumeType); + } + + if (source != null && source.startsWith("/dev/")) { + if (sourceLower.contains("nvme")) { + return StorageClass.INSTANCE_STORE_NVME; + } + return StorageClass.INSTANCE_STORE_SSD; + } + return StorageClass.UNKNOWN; + } + + private static StorageClass mapEbsVolumeType(String volumeType) { + if (volumeType == null) { + return StorageClass.EBS_GP3; + } + + switch (safeLower(volumeType)) { + case "sc1": + return StorageClass.EBS_COLD_HDD; + case "st1": + return StorageClass.EBS_THROUGHPUT_HDD; + case "standard": + return StorageClass.EBS_MAGNETIC; + case "io1": + case "io2": + return StorageClass.EBS_PROVISIONED_IOPS_SSD; + case "gp2": + return StorageClass.EBS_GP2; + case "gp3": + return StorageClass.EBS_GP3; + default: + return StorageClass.EBS_GP3; + } + } + + private static boolean isPseudoFileSystem(String fsType, String sourceLower) { + return fsType.equals("proc") + || fsType.equals("sysfs") + || fsType.equals("devpts") + || fsType.equals("devtmpfs") + || fsType.equals("cgroup") + || fsType.equals("cgroup2") + || fsType.equals("autofs") + || fsType.equals("mqueue") + || fsType.equals("tracefs") + || fsType.equals("pstore") + || 
fsType.equals("securityfs") + || fsType.equals("debugfs") + || fsType.equals("configfs") + || fsType.equals("fusectl") + || fsType.equals("binfmt_misc") + || fsType.equals("rpc_pipefs") + || sourceLower.equals("proc") + || sourceLower.equals("sysfs") + || sourceLower.equals("tmpfs"); + } + + private static String decodeMountToken(String token) { + return token + .replace("\\040", " ") + .replace("\\011", "\t") + .replace("\\012", "\n") + .replace("\\134", "\\"); + } + + private static String extractVolumeId(String value) { + var matcher = VOL_ID_PATTERN.matcher(value); + if (!matcher.find()) { + return null; + } + var raw = matcher.group(); + if (raw.startsWith("vol-")) { + return raw.toLowerCase(Locale.ROOT); + } + return "vol-" + raw.substring(3).toLowerCase(Locale.ROOT); + } + + private static String normalizeDevice(String device) { + if (device == null) { + return null; + } + if (!device.startsWith("/dev/")) { + return device; + } + + if (device.startsWith("/dev/nvme")) { + return NVME_PARTITION_SUFFIX.matcher(device).replaceAll(""); + } + return GENERIC_PARTITION_SUFFIX.matcher(device).replaceAll(""); + } + + private static String safeLower(String value) { + return value == null ? 
"" : value.toLowerCase(Locale.ROOT); + } + + private static final class MountEntry { + private final String source; + private final String mountPoint; + private final String filesystemType; + + private MountEntry(String source, String mountPoint, String filesystemType) { + this.source = source; + this.mountPoint = mountPoint; + this.filesystemType = filesystemType; + } + + private String source() { + return source; + } + + private String mountPoint() { + return mountPoint; + } + + private String filesystemType() { + return filesystemType; + } + } + + private static final class Ec2Identity { + private final String instanceId; + private final String instanceType; + private final String region; + + private Ec2Identity(String instanceId, String instanceType, String region) { + this.instanceId = instanceId; + this.instanceType = instanceType; + this.region = region; + } + + private String instanceId() { + return instanceId; + } + + private String instanceType() { + return instanceType; + } + + private String region() { + return region; + } + } + + private static final class Ec2VolumeData { + private final Map deviceNameToVolumeId; + private final Map nvmeDeviceToVolumeId; + private final Map volumeTypeById; + + private Ec2VolumeData(Map deviceNameToVolumeId, + Map nvmeDeviceToVolumeId, + Map volumeTypeById) { + Objects.requireNonNull(deviceNameToVolumeId, "deviceNameToVolumeId"); + Objects.requireNonNull(nvmeDeviceToVolumeId, "nvmeDeviceToVolumeId"); + Objects.requireNonNull(volumeTypeById, "volumeTypeById"); + this.deviceNameToVolumeId = deviceNameToVolumeId; + this.nvmeDeviceToVolumeId = nvmeDeviceToVolumeId; + this.volumeTypeById = volumeTypeById; + } + + private Map deviceNameToVolumeId() { + return deviceNameToVolumeId; + } + + private Map nvmeDeviceToVolumeId() { + return nvmeDeviceToVolumeId; + } + + private Map volumeTypeById() { + return volumeTypeById; + } + + private static Ec2VolumeData empty() { + return new Ec2VolumeData(Map.of(), Map.of(), Map.of()); + } 
+ } +} diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/storage/package-info.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/storage/package-info.java new file mode 100644 index 000000000..a553c8b23 --- /dev/null +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/storage/package-info.java @@ -0,0 +1,28 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Provides utilities for characterizing the underlying storage hardware and layout. + *

+ * This package contains logic to detect and classify storage tiers (e.g., local SSD/NVMe/HDD, + * EBS volume types, network filesystems) across different environments including + * AWS EC2 and local development machines (Linux, macOS, Windows). + *

+ * The primary entry point is {@link io.github.jbellis.jvector.example.util.storage.StorageLayoutUtil}, + * which provides a unified view of the system's mount points and their corresponding + * {@link io.github.jbellis.jvector.example.util.storage.StorageLayoutUtil.StorageClass}. + */ +package io.github.jbellis.jvector.example.util.storage; diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/TestDataPartition.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/TestDataPartition.java new file mode 100644 index 000000000..b7592f042 --- /dev/null +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/TestDataPartition.java @@ -0,0 +1,84 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.example.yaml; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * Specifically for defining how data is partitioned for testing compaction. 
+ */ +public class TestDataPartition { + public List numSplits; + public List splitDistribution; + + public TestDataPartition() { + this.numSplits = Collections.singletonList(1); + this.splitDistribution = Collections.singletonList(Distribution.UNIFORM); + } + + public TestDataPartition(int numSplits) { + this.numSplits = Collections.singletonList(numSplits); + this.splitDistribution = Collections.singletonList(Distribution.UNIFORM); + } + + public enum Distribution { + UNIFORM, + FIBONACCI, + LOG2N; + + public List computeSplitSizes(int total, int numSplits) { + int[] weights = new int[numSplits]; + switch (this) { + case UNIFORM: + for (int i = 0; i < numSplits; i++) weights[i] = 1; + break; + case FIBONACCI: + int a = 1, b = 2; + weights[0] = 1; + for (int i = 1; i < numSplits; i++) { + weights[i] = b; + int next = a + b; + a = b; + b = next; + } + break; + case LOG2N: + for (int i = 0; i < numSplits; i++) weights[i] = 1 << i; + break; + } + + long weightSum = 0; + for (int w : weights) weightSum += w; + + List sizes = new ArrayList<>(numSplits); + int assigned = 0; + for (int i = 0; i < numSplits; i++) { + int size; + if (i == numSplits - 1) { + size = total - assigned; + } else { + size = (int) (((long) weights[i] * total) / weightSum); + } + sizes.add(size); + assigned += size; + } + return sizes; + } + } +} diff --git a/jvector-examples/src/main/resources/log4j2.xml b/jvector-examples/src/main/resources/log4j2.xml new file mode 100644 index 000000000..83c77bced --- /dev/null +++ b/jvector-examples/src/main/resources/log4j2.xml @@ -0,0 +1,14 @@ + + + + + + + + + + + + + + diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndexCompactor.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndexCompactor.java new file mode 100644 index 000000000..410b96d0e --- /dev/null +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndexCompactor.java @@ -0,0 
+1,774 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.graph.disk; + +import com.carrotsearch.randomizedtesting.RandomizedTest; +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import io.github.jbellis.jvector.TestUtil; +import io.github.jbellis.jvector.disk.ReaderSupplier; +import io.github.jbellis.jvector.disk.ReaderSupplierFactory; +import io.github.jbellis.jvector.disk.SimpleMappedReader; +import io.github.jbellis.jvector.example.util.AccuracyMetrics; +import io.github.jbellis.jvector.graph.*; +import io.github.jbellis.jvector.graph.disk.feature.Feature; +import io.github.jbellis.jvector.graph.disk.feature.FeatureId; +import io.github.jbellis.jvector.graph.disk.feature.FusedPQ; +import io.github.jbellis.jvector.graph.disk.feature.InlineVectors; +import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; +import io.github.jbellis.jvector.graph.similarity.DefaultSearchScoreProvider; +import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider; +import io.github.jbellis.jvector.quantization.PQVectors; +import io.github.jbellis.jvector.quantization.ProductQuantization; +import io.github.jbellis.jvector.util.Bits; +import io.github.jbellis.jvector.util.BoundedLongHeap; +import io.github.jbellis.jvector.util.FixedBitSet; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import 
io.github.jbellis.jvector.vector.VectorizationProvider; +import io.github.jbellis.jvector.vector.types.VectorFloat; +import io.github.jbellis.jvector.vector.types.VectorTypeSupport; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.*; +import java.util.concurrent.ForkJoinPool; +import java.util.function.IntFunction; + +import static io.github.jbellis.jvector.TestUtil.createRandomVectors; +import static io.github.jbellis.jvector.quantization.KMeansPlusPlusClusterer.UNWEIGHTED; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertFalse; + +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +public class TestOnDiskGraphIndexCompactor extends RandomizedTest { + private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport(); + + private ImmutableGraphIndex golden; + private Path testDirectory; + List> allVecs = new ArrayList<>(); + int dimension = 32; + int numVectorsPerGraph = 256; + int numSources = 3; + int numQueries = 20; + VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.COSINE; + RandomAccessVectorValues allravv; + private final ForkJoinPool simdExecutor = ForkJoinPool.commonPool(); + private final ForkJoinPool parallelExecutor = ForkJoinPool.commonPool(); + + @Before + public void setup() throws IOException { + testDirectory = Files.createTempDirectory("jvector_test"); + buildFusedPQ(); + buildGoldenPQ(); + } + + /** + * Builds source graphs with FusedPQ feature enabled. + * Uses random vectors with COSINE similarity. 
+ */ + void buildFusedPQ() throws IOException { + for(int i = 0; i < numSources; ++i) { + List> vecs = createRandomVectors(numVectorsPerGraph, dimension); + + RandomAccessVectorValues ravv = new ListRandomAccessVectorValues(vecs, dimension); + ProductQuantization pq = ProductQuantization.compute(ravv, 8, 256, true, UNWEIGHTED, simdExecutor, parallelExecutor); + PQVectors pqv = (PQVectors) pq.encodeAll(ravv, simdExecutor); + var bsp = BuildScoreProvider.pqBuildScoreProvider(similarityFunction, pqv); + var builder = new GraphIndexBuilder(bsp, dimension, 16, 100, 1.2f, 1.2f, false, true, simdExecutor, parallelExecutor); + var graph = builder.getGraph(); + + var outputPath = testDirectory.resolve("test_graph_" + i); + Map> writeSuppliers = new EnumMap<>(FeatureId.class); + writeSuppliers.put(FeatureId.INLINE_VECTORS, ordinal -> new InlineVectors.State(ravv.getVector(ordinal))); + + var identityMapper = new OrdinalMapper.IdentityMapper(ravv.size() - 1); + var writerBuilder = new OnDiskGraphIndexWriter.Builder(graph, outputPath); + writerBuilder.withMapper(identityMapper); + writerBuilder.with(new InlineVectors(dimension)); + writerBuilder.with(new FusedPQ(graph.maxDegree(), pq)); + var writer = writerBuilder.build(); + + for (var node = 0; node < ravv.size(); node++) { + var stateMap = new EnumMap(FeatureId.class); + stateMap.put(FeatureId.INLINE_VECTORS, writeSuppliers.get(FeatureId.INLINE_VECTORS).apply(node)); + writer.writeInline(node, stateMap); + builder.addGraphNode(node, ravv.getVector(node)); + } + builder.cleanup(); + + writeSuppliers.put(FeatureId.FUSED_PQ, ordinal -> new FusedPQ.State(graph.getView(), pqv, ordinal)); + writer.write(writeSuppliers); + allVecs.addAll(vecs); + } + } + + /** + * Builds the golden graph from all vectors combined. + * This represents the ideal case of building from scratch. 
+ */ + void buildGoldenPQ() throws IOException { + allravv = new ListRandomAccessVectorValues(allVecs, dimension); + + ProductQuantization pq = ProductQuantization.compute(allravv, 8, 256, true, UNWEIGHTED, simdExecutor, parallelExecutor); + PQVectors pqv = (PQVectors) pq.encodeAll(allravv, simdExecutor); + var bsp = BuildScoreProvider.pqBuildScoreProvider(similarityFunction, pqv); + var builder = new GraphIndexBuilder(bsp, dimension, 16, 100, 1.2f, 1.2f, false, true, simdExecutor, parallelExecutor); + for (var i = 0; i < allravv.size(); i++) { + builder.addGraphNode(i, allravv.getVector(i)); + } + builder.cleanup(); + golden = builder.getGraph(); + } + List searchFromAll(List> queries, int topK) { + List srs = new ArrayList<>(); + try (GraphSearcher searcher = new GraphSearcher(golden)) { + for(VectorFloat q: queries) { + var row = new ArrayList(); + SearchScoreProvider ssp = DefaultSearchScoreProvider.exact(q, similarityFunction, allravv); + SearchResult sr = searcher.search(ssp, topK, Bits.ALL); + srs.add(sr); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + return srs; + } + List> buildGT(List> queries, int topK) { + List> rows = new ArrayList<>(); + + for(int i = 0; i < queries.size(); ++i) { + NodeQueue expected = new NodeQueue(new BoundedLongHeap(topK), NodeQueue.Order.MIN_HEAP); + for (int j = 0; j < allVecs.size(); j++) { + expected.push(j, similarityFunction.compare(queries.get(i), allVecs.get(j))); + } + + var row = new ArrayList(); + for(int k = 0; k < topK; ++k) { + row.add(expected.pop()); + } + rows.add(row); + } + return rows; + } + + @After + public void tearDown() { + TestUtil.deleteQuietly(testDirectory); + } + + /** + * Builds a small source graph with InlineVectors only (no FusedPQ), using exact scoring. + * Returns the path to the written graph file. 
+ */ + private Path buildSimpleSourceGraph(List> vecs, int dim, VectorSimilarityFunction vsf, String name) throws IOException { + RandomAccessVectorValues ravv = new ListRandomAccessVectorValues(vecs, dim); + var bsp = BuildScoreProvider.randomAccessScoreProvider(ravv, vsf); + var builder = new GraphIndexBuilder(bsp, dim, 4, 20, 1.2f, 1.2f, false, true, simdExecutor, parallelExecutor); + for (int i = 0; i < vecs.size(); i++) { + builder.addGraphNode(i, vecs.get(i)); + } + builder.cleanup(); + var graph = builder.getGraph(); + + var outputPath = testDirectory.resolve(name); + var identityMapper = new OrdinalMapper.IdentityMapper(vecs.size() - 1); + var writerBuilder = new OnDiskGraphIndexWriter.Builder(graph, outputPath); + writerBuilder.withMapper(identityMapper); + writerBuilder.with(new InlineVectors(dim)); + var writer = writerBuilder.build(); + + Map> writeSuppliers = new EnumMap<>(FeatureId.class); + writeSuppliers.put(FeatureId.INLINE_VECTORS, ordinal -> new InlineVectors.State(ravv.getVector(ordinal))); + + for (int node = 0; node < vecs.size(); node++) { + var stateMap = new EnumMap(FeatureId.class); + stateMap.put(FeatureId.INLINE_VECTORS, writeSuppliers.get(FeatureId.INLINE_VECTORS).apply(node)); + writer.writeInline(node, stateMap); + } + writer.write(writeSuppliers); + return outputPath; + } + + /** Creates a vector of the given dimension with value at index {@code hot} set to {@code val}, rest 0. */ + private VectorFloat makeVec(int dim, int hot, float val) { + VectorFloat v = vectorTypeSupport.createFloatVector(dim); + for (int d = 0; d < dim; d++) { + v.set(d, d == hot ? 
val : 0.0f); + } + return v; + } + + private void assertVecEquals(VectorFloat expected, VectorFloat actual, int ordinal) { + int dim = expected.length(); + assertEquals("dimension mismatch at ordinal " + ordinal, dim, actual.length()); + for (int d = 0; d < dim; d++) { + assertEquals(String.format("vector[%d] dim %d mismatch", ordinal, d), expected.get(d), actual.get(d), 0.0f); + } + } + + /** + * Tests that vectors are stored exactly at the expected global ordinals after compaction. + * Uses two small sources with simple, known float values and identity mapping. + */ + @Test + public void testExactVectorValuesAfterCompaction() throws Exception { + int dim = 4; + int n = 6; // nodes per source + VectorSimilarityFunction vsf = VectorSimilarityFunction.EUCLIDEAN; + + // Source 0: vectors with first dim varying by index + List> vecs0 = new ArrayList<>(); + for (int i = 0; i < n; i++) { + vecs0.add(makeVec(dim, 0, (float)(i + 1))); + } + // Source 1: vectors with second dim varying by index + List> vecs1 = new ArrayList<>(); + for (int i = 0; i < n; i++) { + vecs1.add(makeVec(dim, 1, (float)(i + 10))); + } + + Path path0 = buildSimpleSourceGraph(vecs0, dim, vsf, "simple_src_0"); + Path path1 = buildSimpleSourceGraph(vecs1, dim, vsf, "simple_src_1"); + + ReaderSupplier rs0 = ReaderSupplierFactory.open(path0); + ReaderSupplier rs1 = ReaderSupplierFactory.open(path1); + OnDiskGraphIndex g0 = OnDiskGraphIndex.load(rs0); + OnDiskGraphIndex g1 = OnDiskGraphIndex.load(rs1); + + // Identity remapping: source i -> global ordinals [i*n, (i+1)*n) + Map map0 = new HashMap<>(); + Map map1 = new HashMap<>(); + for (int i = 0; i < n; i++) { + map0.put(i, i); + map1.put(i, n + i); + } + + FixedBitSet live0 = new FixedBitSet(n); + live0.set(0, n); + FixedBitSet live1 = new FixedBitSet(n); + live1.set(0, n); + + var compactor = new OnDiskGraphIndexCompactor( + List.of(g0, g1), + List.of(live0, live1), + List.of(new OrdinalMapper.MapMapper(map0), new OrdinalMapper.MapMapper(map1)), + 
vsf, null); + + Path outPath = testDirectory.resolve("simple_compact_out"); + compactor.compact(outPath); + + ReaderSupplier rsOut = ReaderSupplierFactory.open(outPath); + OnDiskGraphIndex compacted = OnDiskGraphIndex.load(rsOut); + assertEquals(2 * n, compacted.size(0)); + + var view = compacted.getView(); + VectorFloat buf = vectorTypeSupport.createFloatVector(dim); + + // Source 0 vectors must be at ordinals 0..n-1 + for (int i = 0; i < n; i++) { + view.getVectorInto(i, buf, 0); + assertVecEquals(vecs0.get(i), buf, i); + } + // Source 1 vectors must be at ordinals n..2n-1 + for (int i = 0; i < n; i++) { + view.getVectorInto(n + i, buf, 0); + assertVecEquals(vecs1.get(i), buf, n + i); + } + } + + /** + * Tests that only live vectors appear after compaction, placed at the correct remapped ordinals. + * Deletes every other node from each source and verifies the compacted output exactly. + */ + @Test + public void testExactVectorValuesWithDeletions() throws Exception { + int dim = 4; + int n = 8; // nodes per source + VectorSimilarityFunction vsf = VectorSimilarityFunction.EUCLIDEAN; + + // Source 0: vectors [1,0,0,0] through [8,0,0,0] + List> vecs0 = new ArrayList<>(); + for (int i = 0; i < n; i++) { + vecs0.add(makeVec(dim, 0, (float)(i + 1))); + } + // Source 1: vectors [0,10,0,0] through [0,170,0,0] + List> vecs1 = new ArrayList<>(); + for (int i = 0; i < n; i++) { + vecs1.add(makeVec(dim, 1, (float)((i + 1) * 10))); + } + + Path path0 = buildSimpleSourceGraph(vecs0, dim, vsf, "del_src_0"); + Path path1 = buildSimpleSourceGraph(vecs1, dim, vsf, "del_src_1"); + + ReaderSupplier rs0 = ReaderSupplierFactory.open(path0); + ReaderSupplier rs1 = ReaderSupplierFactory.open(path1); + OnDiskGraphIndex g0 = OnDiskGraphIndex.load(rs0); + OnDiskGraphIndex g1 = OnDiskGraphIndex.load(rs1); + + // Keep only even-indexed nodes (0, 2, 4, 6) in both sources + FixedBitSet live0 = new FixedBitSet(n); + FixedBitSet live1 = new FixedBitSet(n); + Map map0 = new HashMap<>(); + Map map1 
= new HashMap<>(); + int globalOrdinal = 0; + for (int i = 0; i < n; i++) { + if (i % 2 == 0) { + live0.set(i); + map0.put(i, globalOrdinal++); + } + } + for (int i = 0; i < n; i++) { + if (i % 2 == 0) { + live1.set(i); + map1.put(i, globalOrdinal++); + } + } + int expectedTotal = globalOrdinal; + + var compactor = new OnDiskGraphIndexCompactor( + List.of(g0, g1), + List.of(live0, live1), + List.of(new OrdinalMapper.MapMapper(map0), new OrdinalMapper.MapMapper(map1)), + vsf, null); + + Path outPath = testDirectory.resolve("del_compact_out"); + compactor.compact(outPath); + + ReaderSupplier rsOut = ReaderSupplierFactory.open(outPath); + OnDiskGraphIndex compacted = OnDiskGraphIndex.load(rsOut); + assertEquals(expectedTotal, compacted.size(0)); + + var view = compacted.getView(); + VectorFloat buf = vectorTypeSupport.createFloatVector(dim); + + // Verify source 0 live nodes at their mapped ordinals + for (int i = 0; i < n; i++) { + if (i % 2 == 0) { + int ord = map0.get(i); + view.getVectorInto(ord, buf, 0); + assertVecEquals(vecs0.get(i), buf, ord); + } + } + // Verify source 1 live nodes at their mapped ordinals + for (int i = 0; i < n; i++) { + if (i % 2 == 0) { + int ord = map1.get(i); + view.getVectorInto(ord, buf, 0); + assertVecEquals(vecs1.get(i), buf, ord); + } + } + } + + /** + * Tests that vectors end up at the correct ordinals when a non-sequential remapping is used. + * Source 0 is mapped in reverse order; source 1 is mapped in forward order. + * Verifies exact vector values at every remapped position. 
+ */ + @Test + public void testExactVectorValuesWithCustomRemapping() throws Exception { + int dim = 4; + int n = 6; + VectorSimilarityFunction vsf = VectorSimilarityFunction.EUCLIDEAN; + + List> vecs0 = new ArrayList<>(); + for (int i = 0; i < n; i++) { + vecs0.add(makeVec(dim, 2, (float)(i + 1))); + } + List> vecs1 = new ArrayList<>(); + for (int i = 0; i < n; i++) { + vecs1.add(makeVec(dim, 3, (float)(i + 100))); + } + + Path path0 = buildSimpleSourceGraph(vecs0, dim, vsf, "remap_src_0"); + Path path1 = buildSimpleSourceGraph(vecs1, dim, vsf, "remap_src_1"); + + ReaderSupplier rs0 = ReaderSupplierFactory.open(path0); + ReaderSupplier rs1 = ReaderSupplierFactory.open(path1); + OnDiskGraphIndex g0 = OnDiskGraphIndex.load(rs0); + OnDiskGraphIndex g1 = OnDiskGraphIndex.load(rs1); + + // Source 0: reverse mapping (local 0 -> global n-1, local 1 -> global n-2, ...) + Map map0 = new HashMap<>(); + for (int i = 0; i < n; i++) { + map0.put(i, n - 1 - i); + } + // Source 1: forward mapping (local 0 -> global n, local 1 -> global n+1, ...) 
+ Map map1 = new HashMap<>(); + for (int i = 0; i < n; i++) { + map1.put(i, n + i); + } + + FixedBitSet live0 = new FixedBitSet(n); + live0.set(0, n); + FixedBitSet live1 = new FixedBitSet(n); + live1.set(0, n); + + var compactor = new OnDiskGraphIndexCompactor( + List.of(g0, g1), + List.of(live0, live1), + List.of(new OrdinalMapper.MapMapper(map0), new OrdinalMapper.MapMapper(map1)), + vsf, null); + + Path outPath = testDirectory.resolve("remap_compact_out"); + compactor.compact(outPath); + + ReaderSupplier rsOut = ReaderSupplierFactory.open(outPath); + OnDiskGraphIndex compacted = OnDiskGraphIndex.load(rsOut); + assertEquals(2 * n, compacted.size(0)); + + var view = compacted.getView(); + VectorFloat buf = vectorTypeSupport.createFloatVector(dim); + + for (int i = 0; i < n; i++) { + int ord = map0.get(i); + view.getVectorInto(ord, buf, 0); + assertVecEquals(vecs0.get(i), buf, ord); + } + for (int i = 0; i < n; i++) { + int ord = map1.get(i); + view.getVectorInto(ord, buf, 0); + assertVecEquals(vecs1.get(i), buf, ord); + } + } + + /** + * Tests basic compaction: merging multiple graphs without deletions. + * Verifies that compacted graph recall is comparable to golden graph. 
+ */ + @Test + public void testCompact() throws Exception { + List graphs = new ArrayList<>(); + List rss = new ArrayList<>(); + List liveNodes = new ArrayList<>(); + List remappers = new ArrayList<>(); + + // Load all source graphs + for(int i = 0; i < numSources; ++i) { + var outputPath = testDirectory.resolve("test_graph_" + i); + rss.add(ReaderSupplierFactory.open(outputPath.toAbsolutePath())); + var onDiskGraph = OnDiskGraphIndex.load(rss.get(i)); + graphs.add(onDiskGraph); + } + + // Create identity mapping and all nodes live + int globalOrdinal = 0; + for (int n = 0; n < numSources; n++) { + Map map = new HashMap<>(numVectorsPerGraph); + for (int i = 0; i < numVectorsPerGraph; i++) { + map.put(i, globalOrdinal++); + } + remappers.add(new OrdinalMapper.MapMapper(map)); + + var lives = new FixedBitSet(numVectorsPerGraph); + lives.set(0, numVectorsPerGraph); + liveNodes.add(lives); + } + + var compactor = new OnDiskGraphIndexCompactor(graphs, liveNodes, remappers, similarityFunction, null); + int topK = 10; + + // Select query vectors from the dataset + var outputPath = testDirectory.resolve("test_compact_graph_"); + List> queries = new ArrayList<>(); + for(int i = 0; i < numQueries; ++i) { + queries.add(allVecs.get(randomIntBetween(0, allVecs.size() - 1))); + } + + // Get golden results and ground truth + List goldenResults = searchFromAll(queries, topK); + List> groundTruth = buildGT(queries, topK); + + // Compact and test + compactor.compact(outputPath); + + ReaderSupplier rs = ReaderSupplierFactory.open(outputPath); + var compactGraph = OnDiskGraphIndex.load(rs); + + // Verify basic properties + assertEquals("Compacted graph should have all nodes", numSources * numVectorsPerGraph, compactGraph.size(0)); + + GraphSearcher searcher = new GraphSearcher(compactGraph); + List compactResults = new ArrayList<>(); + for(VectorFloat q: queries) { + SearchScoreProvider ssp = DefaultSearchScoreProvider.exact(q, similarityFunction, allravv); + 
compactResults.add(searcher.search(ssp, topK, Bits.ALL)); + } + + // Calculate recalls + double goldenRecall = AccuracyMetrics.recallFromSearchResults(groundTruth, goldenResults, topK, topK); + double compactRecall = AccuracyMetrics.recallFromSearchResults(groundTruth, compactResults, topK, topK); + + System.out.printf("Golden (built from scratch) Recall: %.4f%n", goldenRecall); + System.out.printf("Compacted Recall: %.4f%n", compactRecall); + System.out.printf("Recall difference: %.4f%n", Math.abs(goldenRecall - compactRecall)); + + // For random vectors with COSINE, both golden and compact should have similar recall + // The key is that they're comparable to each other, showing compaction preserves graph quality + double recallDifference = Math.abs(goldenRecall - compactRecall); + assertTrue(String.format("Compacted recall (%.4f) should be comparable to golden recall (%.4f), difference: %.4f", + compactRecall, goldenRecall, recallDifference), + recallDifference < 0.2); // Allow up to 20% difference for random vectors + + // Verify both are reasonable (not completely broken) + assertTrue(String.format("Golden recall should be at least 0.2, got %.4f", goldenRecall), + goldenRecall >= 0.2); + assertTrue(String.format("Compacted recall should be at least 0.2, got %.4f", compactRecall), + compactRecall >= 0.2); + + searcher.close(); + } + + /** + * Tests compaction with deleted nodes. + * Verifies that deleted nodes are properly excluded from the compacted graph. 
+ */ + @Test + public void testCompactWithDeletions() throws Exception { + List graphs = new ArrayList<>(); + List rss = new ArrayList<>(); + List liveNodes = new ArrayList<>(); + List remappers = new ArrayList<>(); + + for(int i = 0; i < numSources; ++i) { + var outputPath = testDirectory.resolve("test_graph_" + i); + rss.add(ReaderSupplierFactory.open(outputPath.toAbsolutePath())); + var onDiskGraph = OnDiskGraphIndex.load(rss.get(i)); + graphs.add(onDiskGraph); + } + + // Mark some nodes as deleted (not live) + int globalOrdinal = 0; + int totalLiveNodes = 0; + Set deletedGlobalOrdinals = new HashSet<>(); + + for (int n = 0; n < numSources; n++) { + Map map = new HashMap<>(); + var lives = new FixedBitSet(numVectorsPerGraph); + + // Delete every 5th node + for (int i = 0; i < numVectorsPerGraph; i++) { + int originalGlobalOrdinal = n * numVectorsPerGraph + i; + if (i % 5 != 0) { + lives.set(i); + map.put(i, globalOrdinal++); + totalLiveNodes++; + } else { + deletedGlobalOrdinals.add(originalGlobalOrdinal); + } + } + + remappers.add(new OrdinalMapper.MapMapper(map)); + liveNodes.add(lives); + } + + var compactor = new OnDiskGraphIndexCompactor(graphs, liveNodes, remappers, similarityFunction, null); + var outputPath = testDirectory.resolve("test_compact_with_deletions"); + + compactor.compact(outputPath); + + ReaderSupplier rs = ReaderSupplierFactory.open(outputPath); + var compactGraph = OnDiskGraphIndex.load(rs); + + // Verify the compacted graph has the correct size (excluding deleted nodes) + assertEquals("Compacted graph size should equal live nodes", totalLiveNodes, compactGraph.size(0)); + + // Verify search functionality still works + GraphSearcher searcher = new GraphSearcher(compactGraph); + var query = allVecs.get(randomIntBetween(0, allVecs.size() - 1)); + SearchScoreProvider ssp = DefaultSearchScoreProvider.exact(query, similarityFunction, allravv); + SearchResult result = searcher.search(ssp, 10, Bits.ALL); + + // Verify we get results and they're 
all valid + assertTrue("Should return some results", result.getNodes().length > 0); + + searcher.close(); + } + + /** + * Tests compaction with custom ordinal mappings. + * Verifies that vectors are correctly placed at their mapped ordinals. + */ + @Test + public void testOrdinalMapping() throws Exception { + List graphs = new ArrayList<>(); + List rss = new ArrayList<>(); + List liveNodes = new ArrayList<>(); + List remappers = new ArrayList<>(); + + for(int i = 0; i < numSources; ++i) { + var outputPath = testDirectory.resolve("test_graph_" + i); + rss.add(ReaderSupplierFactory.open(outputPath.toAbsolutePath())); + var onDiskGraph = OnDiskGraphIndex.load(rss.get(i)); + graphs.add(onDiskGraph); + } + + // Create custom ordinal mappings (non-sequential) + int globalOrdinal = 0; + List> mappingList = new ArrayList<>(); + + for (int n = 0; n < numSources; n++) { + Map map = new HashMap<>(); + // Use a custom mapping: reverse order for even sources, normal order for odd + if (n % 2 == 0) { + for (int i = 0; i < numVectorsPerGraph; i++) { + int newOrdinal = globalOrdinal + (numVectorsPerGraph - 1 - i); + map.put(i, newOrdinal); + } + globalOrdinal += numVectorsPerGraph; + } else { + for (int i = 0; i < numVectorsPerGraph; i++) { + map.put(i, globalOrdinal++); + } + } + mappingList.add(map); + remappers.add(new OrdinalMapper.MapMapper(map)); + + var lives = new FixedBitSet(numVectorsPerGraph); + lives.set(0, numVectorsPerGraph); + liveNodes.add(lives); + } + + var compactor = new OnDiskGraphIndexCompactor(graphs, liveNodes, remappers, similarityFunction, null); + var outputPath = testDirectory.resolve("test_compact_with_ordinal_mapping"); + + compactor.compact(outputPath); + + ReaderSupplier rs = ReaderSupplierFactory.open(outputPath); + var compactGraph = OnDiskGraphIndex.load(rs); + + // Verify the graph was created with correct ordinal mapping + assertEquals("Compacted graph should have all nodes", numSources * numVectorsPerGraph, compactGraph.size(0)); + + // Verify 
that the vectors are correctly mapped in the compacted graph + var compactView = compactGraph.getView(); + + // Check a few vectors to ensure they're at the correct ordinals + for (int sourceIdx = 0; sourceIdx < numSources; sourceIdx++) { + Map mapping = mappingList.get(sourceIdx); + // Check first, middle, and last nodes + int[] testIndices = {0, numVectorsPerGraph / 2, numVectorsPerGraph - 1}; + + for (int localIdx : testIndices) { + int expectedGlobalOrdinal = mapping.get(localIdx); + int originalVectorIdx = sourceIdx * numVectorsPerGraph + localIdx; + + VectorFloat originalVec = allVecs.get(originalVectorIdx); + VectorFloat compactVec = vectorTypeSupport.createFloatVector(dimension); + compactView.getVectorInto(expectedGlobalOrdinal, compactVec, 0); + + // Verify the vectors match (use similarity for normalized vectors) + float similarity = similarityFunction.compare(originalVec, compactVec); + assertTrue(String.format("Vector at ordinal %d should match (similarity=%.4f)", + expectedGlobalOrdinal, similarity), + similarity > 0.9999f); + } + } + } + + /** + * Tests compaction with both deletions and custom ordinal mappings combined. + * Verifies that both features work correctly together. 
+ */ + @Test + public void testDeletionsAndOrdinalMapping() throws Exception { + List graphs = new ArrayList<>(); + List rss = new ArrayList<>(); + List liveNodes = new ArrayList<>(); + List remappers = new ArrayList<>(); + + for(int i = 0; i < numSources; ++i) { + var outputPath = testDirectory.resolve("test_graph_" + i); + rss.add(ReaderSupplierFactory.open(outputPath.toAbsolutePath())); + var onDiskGraph = OnDiskGraphIndex.load(rss.get(i)); + graphs.add(onDiskGraph); + } + + // Combine deletions with custom ordinal mapping + int globalOrdinal = 0; + int totalLiveNodes = 0; + List> mappingList = new ArrayList<>(); + + for (int n = 0; n < numSources; n++) { + Map map = new HashMap<>(); + var lives = new FixedBitSet(numVectorsPerGraph); + + // Delete every 4th node + for (int i = 0; i < numVectorsPerGraph; i++) { + if (i % 4 != 0) { + lives.set(i); + map.put(i, globalOrdinal++); + totalLiveNodes++; + } + } + + mappingList.add(map); + remappers.add(new OrdinalMapper.MapMapper(map)); + liveNodes.add(lives); + } + + var compactor = new OnDiskGraphIndexCompactor(graphs, liveNodes, remappers, similarityFunction, null); + var outputPath = testDirectory.resolve("test_compact_deletions_and_mapping"); + + compactor.compact(outputPath); + + ReaderSupplier rs = ReaderSupplierFactory.open(outputPath); + var compactGraph = OnDiskGraphIndex.load(rs); + + // Verify correct size + assertEquals("Compacted graph should only contain live nodes", totalLiveNodes, compactGraph.size(0)); + + // Verify a sample of vectors are at correct ordinals + var compactView = compactGraph.getView(); + int samplesVerified = 0; + for (int sourceIdx = 0; sourceIdx < numSources; sourceIdx++) { + Map mapping = mappingList.get(sourceIdx); + + // Check a few live nodes per source + for (int localIdx = 1; localIdx < numVectorsPerGraph && samplesVerified < 20; localIdx++) { + if (localIdx % 4 == 0) continue; // Skip deleted nodes + + int expectedGlobalOrdinal = mapping.get(localIdx); + int originalVectorIdx 
= sourceIdx * numVectorsPerGraph + localIdx; + + VectorFloat originalVec = allVecs.get(originalVectorIdx); + VectorFloat compactVec = vectorTypeSupport.createFloatVector(dimension); + compactView.getVectorInto(expectedGlobalOrdinal, compactVec, 0); + + // Verify the vectors match using similarity + float similarity = similarityFunction.compare(originalVec, compactVec); + assertTrue(String.format("Vector at ordinal %d should match (similarity=%.4f)", + expectedGlobalOrdinal, similarity), + similarity > 0.9999f); + samplesVerified++; + } + } + + // Verify search functionality + GraphSearcher searcher = new GraphSearcher(compactGraph); + var query = allVecs.get(randomIntBetween(0, allVecs.size() - 1)); + SearchScoreProvider ssp = DefaultSearchScoreProvider.exact(query, similarityFunction, allravv); + SearchResult result = searcher.search(ssp, 10, Bits.ALL); + + assertTrue("Search should return results", result.getNodes().length > 0); + + searcher.close(); + } +} diff --git a/rat-excludes.txt b/rat-excludes.txt index 436c97822..ccb3b13e5 100644 --- a/rat-excludes.txt +++ b/rat-excludes.txt @@ -7,6 +7,7 @@ CONTRIBUTIONS.md package.json .github/workflows/tag-release.yml .github/workflows/run-bench.yml +.github/workflows/run-compaction.yml .mvn/wrapper/maven-wrapper.properties .mvn/jvm.config README.md @@ -26,9 +27,9 @@ scripts/test_node_setup.sh scripts/jmh_results_formatter.py yaml-configs/**/*.yaml yaml-configs/**/*.yml +yaml-configs/**/.catalog-cache/** src/main/resources/logback.xml docs/**/*.md yaml-configs/**/*.md local_datasets/** **/datasets/** -