/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.backward_codecs.lucene90;

import java.io.IOException;
import java.util.Locale;
import java.util.Objects;
import java.util.SplittableRandom;
import org.apache.lucene.backward_codecs.lucene90.Lucene90BoundsChecker;
import org.apache.lucene.backward_codecs.lucene90.Lucene90NeighborArray;
import org.apache.lucene.backward_codecs.lucene90.Lucene90OnHeapHnswGraph;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.hnsw.NeighborQueue;

/**
 * Builds a {@link Lucene90OnHeapHnswGraph} one node at a time.
 *
 * <p>For every vector added, the builder searches the graph built so far for candidate
 * neighbors, links the new node to a subset of those candidates chosen by a diversity check,
 * and then adds the reverse links, trimming any neighbor list that grows past {@code maxConn}.
 *
 * <p>The builder holds two vector accessors obtained from the same producer and always reads the
 * two operands of a similarity comparison from different accessors. NOTE(review): presumably
 * this is because an accessor may reuse the array it returns - confirm against
 * {@link RandomAccessVectorValues}.
 */
public final class Lucene90HnswGraphBuilder {
    /** Initial value of {@link #randSeed}. */
    private static final long DEFAULT_RAND_SEED = 42L;

    /** Component name used for every message this builder sends to its {@link InfoStream}. */
    public static final String HNSW_COMPONENT = "HNSW";

    /**
     * Publicly mutable seed value. It is not read anywhere in this class; the constructor takes
     * its own {@code seed} argument.
     */
    public static long randSeed = DEFAULT_RAND_SEED;

    /** Maximum number of neighbors kept per node after trimming. */
    private final int maxConn;

    /** Passed to the graph search as both of its size arguments when collecting candidates. */
    private final int beamWidth;

    /** Reusable buffer that receives the candidates popped from the search result queue. */
    private final Lucene90NeighborArray scratch;

    private final VectorSimilarityFunction similarityFunction;

    /** First vector accessor: used for the graph search and for one side of each comparison. */
    private final RandomAccessVectorValues vectorValues;

    private final SplittableRandom random;

    /** Shared threshold checker; every user calls {@code set} before calling {@code check}. */
    private final Lucene90BoundsChecker bound;

    /** The graph under construction; returned by {@link #build}. */
    final Lucene90OnHeapHnswGraph hnsw;

    private InfoStream infoStream = InfoStream.getDefault();

    /** Second vector accessor: supplies the other side of each diversity comparison. */
    private RandomAccessVectorValues buildVectors;

    /**
     * Creates a builder over the vectors exposed by {@code vectors}.
     *
     * @param vectors producer from which two independent random-access views are obtained
     * @param similarityFunction function used to compare vectors; must not be {@code null}
     * @param maxConn maximum neighbors per node; must be positive
     * @param beamWidth candidate-list size used while searching; must be positive
     * @param seed seed for the random source handed to the graph search
     * @throws IOException if obtaining a random-access view fails
     * @throws NullPointerException if {@code similarityFunction} is {@code null}
     * @throws IllegalArgumentException if {@code maxConn} or {@code beamWidth} is not positive
     */
    public Lucene90HnswGraphBuilder(RandomAccessVectorValuesProducer vectors, VectorSimilarityFunction similarityFunction, int maxConn, int beamWidth, long seed) throws IOException {
        vectorValues = vectors.randomAccess();
        buildVectors = vectors.randomAccess();
        this.similarityFunction = Objects.requireNonNull(similarityFunction);
        if (maxConn <= 0) {
            throw new IllegalArgumentException("maxConn must be positive");
        }
        if (beamWidth <= 0) {
            throw new IllegalArgumentException("beamWidth must be positive");
        }
        this.maxConn = maxConn;
        this.beamWidth = beamWidth;
        hnsw = new Lucene90OnHeapHnswGraph(maxConn);
        bound = Lucene90BoundsChecker.create(false);
        random = new SplittableRandom(seed);
        // The scratch buffer must be able to hold a full search result or an over-full neighbor list.
        scratch = new Lucene90NeighborArray(Math.max(beamWidth, maxConn + 1));
    }

    /**
     * Adds the vectors at ordinals {@code 1 .. vectors.size() - 1} to the graph, in order, and
     * returns the graph.
     *
     * <p>NOTE(review): iteration starts at ordinal 1; presumably ordinal 0 is already present in
     * a freshly created {@link Lucene90OnHeapHnswGraph} - confirm against that class.
     *
     * <p>When the info stream is enabled for {@link #HNSW_COMPONENT}, a start message is emitted
     * and a progress message follows every 10000 nodes, reporting milliseconds since the previous
     * report and since the start.
     *
     * @param vectors the vectors to insert; must not be the same instance as the accessor this
     *     builder searches with
     * @return the graph held by this builder
     * @throws IllegalArgumentException if {@code vectors} is this builder's own search accessor
     * @throws IOException if reading a vector fails
     */
    public Lucene90OnHeapHnswGraph build(RandomAccessVectorValues vectors) throws IOException {
        if (vectors == vectorValues) {
            throw new IllegalArgumentException("Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
        }
        if (infoStream.isEnabled(HNSW_COMPONENT)) {
            infoStream.message(HNSW_COMPONENT, "build graph from " + vectors.size() + " vectors");
        }
        final long buildStartNanos = System.nanoTime();
        long lastReportNanos = buildStartNanos;
        int ord = 1;
        while (ord < vectors.size()) {
            addGraphNode(vectors.vectorValue(ord));
            if (ord % 10000 == 0 && infoStream.isEnabled(HNSW_COMPONENT)) {
                long nowNanos = System.nanoTime();
                long sinceLastReportMs = (nowNanos - lastReportNanos) / 1_000_000L;
                long sinceStartMs = (nowNanos - buildStartNanos) / 1_000_000L;
                infoStream.message(HNSW_COMPONENT, String.format(Locale.ROOT, "built %d in %d/%d ms", ord, sinceLastReportMs, sinceStartMs));
                lastReportNanos = nowNanos;
            }
            ord++;
        }
        return hnsw;
    }

    /**
     * Replaces the info stream used for progress messages.
     *
     * @param infoStream the stream that subsequent {@link #build} calls report to
     */
    public void setInfoStream(InfoStream infoStream) {
        this.infoStream = infoStream;
    }

    /**
     * Inserts one vector: searches the current graph for candidates, appends a new node, and
     * links it to a diverse subset of those candidates.
     *
     * @param value the vector of the node being added
     * @throws IOException if reading a vector fails
     */
    void addGraphNode(float[] value) throws IOException {
        // The search runs before the node is appended, so the new node is never its own candidate.
        NeighborQueue found = Lucene90OnHeapHnswGraph.search(value, beamWidth, beamWidth, vectorValues, similarityFunction, hnsw, null, Integer.MAX_VALUE, random);
        int newNode = hnsw.addNode();
        addDiverseNeighbors(newNode, found);
    }

    /**
     * Fills the (empty) neighbor list of {@code node} from {@code candidates}, then adds
     * {@code node} to each selected neighbor's own list, trimming that list when it exceeds
     * {@code maxConn}.
     */
    private void addDiverseNeighbors(int node, NeighborQueue candidates) throws IOException {
        Lucene90NeighborArray selected = hnsw.getNeighbors(node);
        assert selected.size() == 0;
        popToScratch(candidates);
        selectDiverse(selected, scratch);
        // Add the back-links, reusing the score recorded for the forward link.
        int selectedCount = selected.size();
        for (int slot = 0; slot < selectedCount; slot++) {
            int neighborNode = selected.node()[slot];
            Lucene90NeighborArray backLinks = hnsw.getNeighbors(neighborNode);
            backLinks.add(node, selected.score()[slot]);
            if (backLinks.size() > maxConn) {
                diversityUpdate(backLinks);
            }
        }
    }

    /**
     * Walks {@code candidates} from its last slot to its first, adding each candidate that passes
     * {@link #diversityCheck} to {@code neighbors}, until {@code neighbors} holds
     * {@code maxConn} entries or the candidates are exhausted.
     */
    private void selectDiverse(Lucene90NeighborArray neighbors, Lucene90NeighborArray candidates) throws IOException {
        int slot = candidates.size() - 1;
        while (neighbors.size() < maxConn && slot >= 0) {
            int candidateNode = candidates.node()[slot];
            float candidateScore = candidates.score()[slot];
            assert candidateNode < hnsw.size();
            if (diversityCheck(vectorValues.vectorValue(candidateNode), candidateScore, neighbors, buildVectors)) {
                neighbors.add(candidateNode, candidateScore);
            }
            slot--;
        }
    }

    /**
     * Empties {@code candidates} into the scratch buffer in pop order, pairing every popped node
     * with the top score read immediately before the pop.
     */
    private void popToScratch(NeighborQueue candidates) {
        scratch.clear();
        // The count is captured up front, exactly as many pops as the queue held on entry.
        for (int remaining = candidates.size(); remaining > 0; remaining--) {
            float topScore = candidates.topScore();
            scratch.add(candidates.pop(), topScore);
        }
    }

    /**
     * Tests a candidate against every neighbor already chosen.
     *
     * @param candidate vector of the candidate node
     * @param score the candidate's score, used as the threshold for the bounds checker
     * @param neighbors the neighbors selected so far
     * @param otherVectors accessor used to read the neighbors' vectors
     * @return {@code false} as soon as the bounds checker rejects the similarity between the
     *     candidate and one of the neighbors, {@code true} if none is rejected
     */
    private boolean diversityCheck(float[] candidate, float score, Lucene90NeighborArray neighbors, RandomAccessVectorValues otherVectors) throws IOException {
        bound.set(score);
        for (int slot = 0; slot < neighbors.size(); slot++) {
            float[] neighborVector = otherVectors.vectorValue(neighbors.node()[slot]);
            float similarity = similarityFunction.compare(candidate, neighborVector);
            if (!bound.check(similarity)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Shrinks a neighbor list holding {@code maxConn + 1} entries back to {@code maxConn}.
     *
     * <p>If {@link #findNonDiverse} names a slot, the last entry overwrites that slot. Otherwise
     * the last entry's score is checked against a bound set from slot 0's score: when the check
     * passes the last entry is simply dropped, and when it fails the last entry overwrites
     * slot 0. In every case the last slot is removed afterwards.
     */
    private void diversityUpdate(Lucene90NeighborArray neighbors) throws IOException {
        assert neighbors.size() == maxConn + 1;
        int victim = findNonDiverse(neighbors);
        if (victim == -1) {
            bound.set(neighbors.score()[0]);
            if (bound.check(neighbors.score()[maxConn])) {
                neighbors.removeLast();
                return;
            }
            victim = 0;
        }
        // Move the last entry into the vacated slot, then drop the now-duplicated tail.
        neighbors.node()[victim] = neighbors.node()[maxConn];
        neighbors.score()[victim] = neighbors.score()[maxConn];
        neighbors.removeLast();
    }

    /**
     * Scans the neighbor list from its last slot to its first. For slot {@code i} the bound is
     * set from that neighbor's stored score and the neighbor is compared with every slot
     * {@code j} where {@code i < j <= maxConn}, highest {@code j} first.
     *
     * @return the first {@code i} for which the bounds checker rejects one of those
     *     similarities, or {@code -1} if no slot is rejected
     */
    private int findNonDiverse(Lucene90NeighborArray neighbors) throws IOException {
        for (int lower = neighbors.size() - 1; lower >= 0; lower--) {
            int lowerNode = neighbors.node()[lower];
            bound.set(neighbors.score()[lower]);
            float[] lowerVector = vectorValues.vectorValue(lowerNode);
            for (int upper = maxConn; upper > lower; upper--) {
                float[] upperVector = buildVectors.vectorValue(neighbors.node()[upper]);
                float similarity = similarityFunction.compare(lowerVector, upperVector);
                if (!bound.check(similarity)) {
                    return lower;
                }
            }
        }
        return -1;
    }
}

