/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index.query.common;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.FilteredDocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.opensearch.knn.index.query.common.DocAndScoreQuery;
import org.opensearch.knn.index.query.iterators.GroupedNestedDocIdSetIterator;

/**
 * Helpers shared by k-NN query execution. They cover:
 * <ul>
 *   <li>building a {@link DocAndScoreQuery} from top hits;</li>
 *   <li>scoring every leaf of an index in parallel;</li>
 *   <li>materializing per-segment filter and sibling document sets.</li>
 * </ul>
 */
public class QueryUtils {
    /**
     * Shared instance used by callers.
     *
     * <p>NOTE(review): this field is not {@code final}, so it can be reassigned. Consider making it
     * final if nothing (e.g. tests) replaces it.
     */
    public static QueryUtils INSTANCE = new QueryUtils();

    /**
     * Builds a {@link DocAndScoreQuery} over exactly the hits in {@code topDocs}, with their recorded
     * scores.
     *
     * <p>Side effect: {@code topDocs.scoreDocs} is sorted in place by ascending doc id. The sort is
     * needed because the per-segment offsets are located with a binary search over the doc ids.
     *
     * @param reader  the top-level reader that the doc ids in {@code topDocs} refer to
     * @param topDocs the hits to wrap; its {@code scoreDocs} array is reordered
     * @return a query over the given docs and scores, tagged with {@code reader.getContext().id()}
     */
    public Query createDocAndScoreQuery(IndexReader reader, TopDocs topDocs) {
        final int len = topDocs.scoreDocs.length;
        Arrays.sort(topDocs.scoreDocs, Comparator.comparingInt(a -> a.doc));
        final int[] docs = new int[len];
        final float[] scores = new float[len];
        for (int i = 0; i < len; i++) {
            docs[i] = topDocs.scoreDocs[i].doc;
            scores[i] = topDocs.scoreDocs[i].score;
        }
        final int[] segmentStarts = findSegmentStarts(reader, docs);
        return new DocAndScoreQuery(len, docs, scores, segmentStarts, reader.getContext().id());
    }

    /**
     * Computes, for each leaf of {@code reader}, the position in {@code docs} where that leaf's
     * documents begin.
     *
     * <p>The returned array has {@code leaves().size() + 1} entries:
     * <ul>
     *   <li>The last entry is {@code docs.length}.</li>
     *   <li>Every earlier entry {@code i > 0} is the index of the first doc id
     *       {@code >= leaves().get(i).docBase}.</li>
     *   <li>Entry 0 stays 0.</li>
     * </ul>
     * Leaf {@code i} therefore owns {@code docs[starts[i] .. starts[i + 1])}.
     *
     * @param reader the reader whose leaves define the segment boundaries
     * @param docs   top-level doc ids; must be sorted ascending (required by the binary search)
     * @return segment start offsets into {@code docs}
     */
    private int[] findSegmentStarts(IndexReader reader, int[] docs) {
        final List<LeafReaderContext> leaves = reader.leaves();
        final int[] starts = new int[leaves.size() + 1];
        starts[starts.length - 1] = docs.length;
        // Each search begins where the previous one ended. This assumes leaves are ordered by
        // increasing docBase, as Lucene's leaves() are. With a single leaf the loop does not run.
        int resultIndex = 0;
        for (int i = 1; i < starts.length - 1; i++) {
            final int upper = leaves.get(i).docBase;
            resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
            if (resultIndex < 0) {
                // Not found: binarySearch returns -(insertionPoint) - 1, so recover the insertion point.
                resultIndex = -1 - resultIndex;
            }
            starts[i] = resultIndex;
        }
        return starts;
    }

    /**
     * Scores every given leaf with {@code weight}.
     *
     * <p>One task is submitted per leaf to the searcher's task executor.
     *
     * @param indexSearcher      supplies the task executor that runs the per-leaf tasks
     * @param leafReaderContexts the segments to score
     * @param weight             the weight whose scorer is evaluated on each segment
     * @return one map per leaf, from segment-local doc id to score. The list is whatever
     *         {@code invokeAll} returns, presumably in task order; confirm against Lucene's
     *         TaskExecutor.
     * @throws IOException if scoring any leaf fails
     */
    public List<Map<Integer, Float>> doSearch(IndexSearcher indexSearcher, List<LeafReaderContext> leafReaderContexts, Weight weight) throws IOException {
        final List<Callable<Map<Integer, Float>>> tasks = new ArrayList<>(leafReaderContexts.size());
        for (LeafReaderContext leafReaderContext : leafReaderContexts) {
            tasks.add(() -> searchLeaf(leafReaderContext, weight));
        }
        return indexSearcher.getTaskExecutor().invokeAll(tasks);
    }

    /**
     * Collects the score of every document that {@code weight} matches in one segment.
     *
     * @param ctx    the segment to score
     * @param weight the weight to evaluate
     * @return segment-local doc id to score; an empty immutable map when the weight has no scorer
     *         for this segment
     * @throws IOException if the scorer fails
     */
    private Map<Integer, Float> searchLeaf(LeafReaderContext ctx, Weight weight) throws IOException {
        final Scorer scorer = weight.scorer(ctx);
        if (scorer == null) {
            // Lucene returns a null scorer when nothing in this segment can match.
            return Collections.emptyMap();
        }
        final Map<Integer, Float> leafDocScores = new HashMap<>();
        final DocIdSetIterator iterator = scorer.iterator();
        while (iterator.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
            leafDocScores.put(scorer.docID(), scorer.score());
        }
        return leafDocScores;
    }

    /**
     * Returns an iterator built from {@code docIds} and the parent documents of this segment.
     *
     * <p>Presumably, going by the name, it iterates the given docs together with their sibling
     * nested docs, restricted by {@code queryFilter}. The exact semantics live in
     * {@link GroupedNestedDocIdSetIterator}; confirm there.
     *
     * @param leafReaderContext the segment being searched
     * @param docIds            segment-local doc ids to expand
     * @param parentsFilter     produces the parent-document bitset for the segment
     * @param queryFilter       passed through to {@link GroupedNestedDocIdSetIterator}
     * @return an empty iterator when {@code docIds} is empty, otherwise a
     *         {@link GroupedNestedDocIdSetIterator}
     * @throws IOException if the parent bitset cannot be produced
     */
    public DocIdSetIterator getAllSiblings(LeafReaderContext leafReaderContext, Set<Integer> docIds, BitSetProducer parentsFilter, Bits queryFilter) throws IOException {
        if (docIds.isEmpty()) {
            return DocIdSetIterator.empty();
        }
        // NOTE(review): BitSetProducer#getBitSet may return null when no parent docs exist in this
        // segment; confirm GroupedNestedDocIdSetIterator tolerates that.
        final BitSet parentBitSet = parentsFilter.getBitSet(leafReaderContext);
        return new GroupedNestedDocIdSetIterator(parentBitSet, docIds, queryFilter);
    }

    /**
     * Materializes the documents matched by {@code filterWeight} in one segment as {@link Bits}.
     *
     * <ul>
     *   <li>No filter weight: returns {@code Bits.MatchAllBits(0)}.</li>
     *   <li>No scorer (nothing matches): returns {@code Bits.MatchNoBits(0)}.</li>
     * </ul>
     * Both report a {@code length()} of 0.
     *
     * <p>Fast path: when the segment has no deletions and the filter iterator is already backed by
     * a {@link BitSet}, that bitset is returned directly. It is not copied, so callers must not
     * mutate it.
     *
     * <p>Otherwise the filter matches, minus deleted documents, are copied into a new bitset sized
     * to the segment's {@code maxDoc}.
     *
     * @param leafReaderContext the segment to evaluate
     * @param filterWeight      the filter's weight, or {@code null} for "no filter"
     * @return the matching documents of the segment
     * @throws IOException if the filter scorer fails
     */
    public Bits createBits(LeafReaderContext leafReaderContext, Weight filterWeight) throws IOException {
        if (filterWeight == null) {
            return new Bits.MatchAllBits(0);
        }
        final Scorer scorer = filterWeight.scorer(leafReaderContext);
        if (scorer == null) {
            return new Bits.MatchNoBits(0);
        }
        final Bits liveDocs = leafReaderContext.reader().getLiveDocs();
        final int maxDoc = leafReaderContext.reader().maxDoc();
        final DocIdSetIterator filteredDocIdsIterator = scorer.iterator();
        if (liveDocs == null && filteredDocIdsIterator instanceof BitSetIterator) {
            return ((BitSetIterator) filteredDocIdsIterator).getBitSet();
        }
        // Drop deleted documents; a null liveDocs means the segment has no deletions.
        final FilteredDocIdSetIterator filterIterator = new FilteredDocIdSetIterator(filteredDocIdsIterator) {
            @Override
            protected boolean match(int doc) {
                return liveDocs == null || liveDocs.get(doc);
            }
        };
        return BitSet.of(filterIterator, maxDoc);
    }
}

