/*
 * Decompiled with CFR 0.152.
 */
package org.apache.solr.ltr;

import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.Future;
import java.util.concurrent.FutureTask;
import java.util.concurrent.Semaphore;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.DisiPriorityQueue;
import org.apache.lucene.search.DisiWrapper;
import org.apache.lucene.search.DisjunctionDISIApproximation;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.solr.ltr.DocInfo;
import org.apache.solr.ltr.FeatureLogger;
import org.apache.solr.ltr.LTRThreadModule;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.model.LTRScoringModel;
import org.apache.solr.request.SolrQueryRequest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class LTRScoringQuery
extends Query
implements Accountable {
    private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
    private static final long BASE_RAM_BYTES = RamUsageEstimator.shallowSizeOfInstance(LTRScoringQuery.class);
    private final LTRScoringModel ltrScoringModel;
    private final boolean extractAllFeatures;
    private final LTRThreadModule ltrThreadMgr;
    private final Semaphore querySemaphore;
    private FeatureLogger fl;
    private final Map<String, String[]> efi;
    private Query originalQuery;
    private SolrQueryRequest request;

    /**
     * Creates a scoring query with empty external feature information, without
     * extraction of all features and without a thread module.
     *
     * @param ltrScoringModel the model used for scoring
     */
    public LTRScoringQuery(LTRScoringModel ltrScoringModel) {
        this(ltrScoringModel, Collections.<String, String[]>emptyMap(), false, null);
    }

    /**
     * Creates a scoring query with empty external feature information and
     * without a thread module.
     *
     * @param ltrScoringModel the model used for scoring
     * @param extractAllFeatures whether weights are created for all features of
     *     the model rather than only for the features the model scores with
     */
    public LTRScoringQuery(LTRScoringModel ltrScoringModel, boolean extractAllFeatures) {
        this(ltrScoringModel, Collections.<String, String[]>emptyMap(), extractAllFeatures, null);
    }

    /**
     * Creates a scoring query.
     *
     * @param ltrScoringModel the model used for scoring
     * @param externalFeatureInfo external feature information, stored as given (not copied)
     * @param extractAllFeatures whether weights are created for all features of
     *     the model rather than only for the features the model scores with
     * @param ltrThreadMgr thread module used to create feature weights in
     *     parallel; when {@code null} no query semaphore is created
     */
    public LTRScoringQuery(LTRScoringModel ltrScoringModel, Map<String, String[]> externalFeatureInfo, boolean extractAllFeatures, LTRThreadModule ltrThreadMgr) {
        this.ltrScoringModel = ltrScoringModel;
        this.efi = externalFeatureInfo;
        this.extractAllFeatures = extractAllFeatures;
        this.ltrThreadMgr = ltrThreadMgr;
        if (ltrThreadMgr == null) {
            this.querySemaphore = null;
        } else {
            // The thread module decides whether a per-query semaphore exists at all.
            this.querySemaphore = ltrThreadMgr.createQuerySemaphore();
        }
    }

    /** Returns the scoring model this query was constructed with. */
    public LTRScoringModel getScoringModel() {
        return ltrScoringModel;
    }

    /** Returns the feature logger set on this query, or {@code null} if none was set. */
    public FeatureLogger getFeatureLogger() {
        return fl;
    }

    /** Sets the feature logger associated with this query. */
    public void setFeatureLogger(FeatureLogger fl) {
        this.fl = fl;
    }

    /** Returns the original query set on this query, or {@code null} if none was set. */
    public Query getOriginalQuery() {
        return originalQuery;
    }

    /** Sets the original query; it is passed on to the features when their weights are created. */
    public void setOriginalQuery(Query originalQuery) {
        this.originalQuery = originalQuery;
    }

    /** Returns the external feature information map given at construction (the same instance, not a copy). */
    public Map<String, String[]> getExternalFeatureInfo() {
        return efi;
    }

    /** Returns the request set on this query, or {@code null} if none was set. */
    public SolrQueryRequest getRequest() {
        return request;
    }

    /** Sets the request; it is passed on to the features when their weights are created. */
    public void setRequest(SolrQueryRequest request) {
        this.request = request;
    }

    /**
     * Computes a hash code from the class hash, the scoring model, the original
     * query, the external feature information and the string form of this query.
     *
     * <p>The external feature information is folded in as a sum of per-entry
     * hashes, so the result does not depend on the iteration order of the map.
     * {@code equals} compares that map key by key regardless of order, so an
     * order-sensitive hash could give two equal queries different hash codes.
     * Null keys and null value arrays contribute a hash of 0.
     *
     * @return the hash code of this query
     */
    @Override
    public int hashCode() {
        final int prime = 31;
        int result = classHash();
        result = prime * result + (ltrScoringModel == null ? 0 : ltrScoringModel.hashCode());
        result = prime * result + (originalQuery == null ? 0 : originalQuery.hashCode());
        int efiHash = 0;
        if (efi != null) {
            for (Map.Entry<String, String[]> entry : efi.entrySet()) {
                final String key = entry.getKey();
                final int keyHash = key == null ? 0 : key.hashCode();
                // Addition is commutative, so the order of entries cannot change the sum.
                efiHash += prime * keyHash + Arrays.hashCode(entry.getValue());
            }
        }
        result = prime * result + efiHash;
        result = prime * result + toString().hashCode();
        return result;
    }

    /**
     * Compares this query with another object: the object must pass
     * {@code sameClassAs} and then match field by field in {@code equalsTo}.
     */
    @Override
    public boolean equals(Object o) {
        if (!sameClassAs(o)) {
            return false;
        }
        return equalsTo(getClass().cast(o));
    }

    /** Presents this query to the visitor as a leaf. */
    @Override
    public void visit(QueryVisitor visitor) {
        visitor.visitLeaf(this);
    }

    /**
     * Field-by-field comparison with another query of the same class: scoring
     * model, original query and external feature information must all match.
     *
     * <p>The external feature information maps are compared by key, with value
     * arrays compared element-wise. A key mapped to {@code null} matches only a
     * key that is also present and mapped to {@code null} in the other map, so
     * a query holding such an entry still equals itself.
     *
     * @param other the query to compare with; must not be {@code null}
     * @return {@code true} if all compared fields are equal
     */
    private boolean equalsTo(LTRScoringQuery other) {
        if (ltrScoringModel == null ? other.ltrScoringModel != null : !ltrScoringModel.equals(other.ltrScoringModel)) {
            return false;
        }
        if (originalQuery == null ? other.originalQuery != null : !originalQuery.equals(other.originalQuery)) {
            return false;
        }
        if (efi == null) {
            return other.efi == null;
        }
        if (other.efi == null || efi.size() != other.efi.size()) {
            return false;
        }
        for (Map.Entry<String, String[]> entry : efi.entrySet()) {
            final String key = entry.getKey();
            final String[] otherValues = other.efi.get(key);
            // get() returning null is ambiguous: distinguish "absent" from "mapped to null".
            if (otherValues == null && !other.efi.containsKey(key)) {
                return false;
            }
            // Arrays.equals treats two null arrays as equal and null vs non-null as different.
            if (!Arrays.equals(entry.getValue(), otherValues)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Builds the {@link ModelWeight} for this query.
     *
     * <p>One feature weight is created per feature: for all features of the
     * model when {@code extractAllFeatures} is set, otherwise only for the
     * features the model scores with. Weights are created sequentially when
     * there is no query semaphore and in parallel otherwise. The {@code boost}
     * argument is not used.
     *
     * @param searcher the searcher handed to every feature
     * @param scoreMode only {@code needsScores()} is read from it
     * @param boost ignored
     * @return the weight holding the model and the extracted feature weights
     * @throws IOException declared by the overridden method
     */
    @Override
    public ModelWeight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
        final List<Feature> modelFeatures = ltrScoringModel.getFeatures();
        final Collection<Feature> allFeatures = ltrScoringModel.getAllFeatures();
        final Collection<Feature> features = extractAllFeatures ? allFeatures : modelFeatures;

        final Feature.FeatureWeight[] extracted = new Feature.FeatureWeight[features.size()];
        final Feature.FeatureWeight[] modelWeights = new Feature.FeatureWeight[modelFeatures.size()];
        final List<Feature.FeatureWeight> featureWeights = new ArrayList<>(features.size());

        final boolean needsScores = scoreMode.needsScores();
        if (querySemaphore != null) {
            createWeightsParallel(searcher, needsScores, featureWeights, features);
        } else {
            createWeights(searcher, needsScores, featureWeights, features);
        }

        final int created = featureWeights.size();
        for (int k = 0; k < created; k++) {
            extracted[k] = featureWeights.get(k);
        }
        if (extractAllFeatures) {
            // The model's features are looked up among all extracted weights by feature index.
            int slot = 0;
            for (Feature f : modelFeatures) {
                modelWeights[slot++] = extracted[f.getIndex()];
            }
        } else {
            // Only the model's features were extracted, so both arrays hold the same weights.
            for (int k = 0; k < created; k++) {
                modelWeights[k] = featureWeights.get(k);
            }
        }
        return new ModelWeight(modelWeights, extracted, allFeatures.size());
    }

    /**
     * Creates the weight of every feature on the calling thread and appends it
     * to {@code featureWeights}, in the iteration order of {@code features}.
     *
     * @param searcher the searcher handed to every feature
     * @param needsScores handed to every feature
     * @param featureWeights receives the created weights
     * @param features the features to create weights for
     * @throws RuntimeException wrapping any exception raised for a feature
     */
    private void createWeights(IndexSearcher searcher, boolean needsScores, List<Feature.FeatureWeight> featureWeights, Collection<Feature> features) throws IOException {
        final SolrQueryRequest solrRequest = getRequest();
        for (final Feature feature : features) {
            try {
                featureWeights.add(feature.createWeight(searcher, needsScores, solrRequest, originalQuery, efi));
            } catch (Exception e) {
                throw new RuntimeException("Exception from createWeight for " + feature.toString() + " " + e.getMessage(), e);
            }
        }
    }

    /**
     * Creates the weight of every feature through the thread module and appends
     * the results to {@code featureWeights}, in the iteration order of
     * {@code features}.
     *
     * <p>For each feature a query permit and then an LTR permit are acquired
     * before the task is handed to the thread module; the task itself releases
     * both when it finishes. If the hand-off fails, the permits acquired for
     * that task are released here, because the task will never run to release
     * them.
     *
     * <p>If the calling thread is interrupted while waiting, its interrupt
     * status is restored before the failure is rethrown.
     *
     * @param searcher the searcher handed to every feature
     * @param needsScores handed to every feature
     * @param featureWeights receives the created weights
     * @param features the features to create weights for
     * @throws RuntimeException wrapping any failure to schedule or compute a weight
     */
    private void createWeightsParallel(IndexSearcher searcher, boolean needsScores, List<Feature.FeatureWeight> featureWeights, Collection<Feature> features) throws RuntimeException {
        final SolrQueryRequest req = getRequest();
        final List<Future<Feature.FeatureWeight>> futures = new ArrayList<>(features.size());
        try {
            for (final Feature feature : features) {
                final FutureTask<Feature.FeatureWeight> task =
                    new FutureTask<>(new CreateWeightCallable(feature, searcher, needsScores, req));
                // Query permit first, then the LTR permit (same order as the original code).
                querySemaphore.acquire();
                boolean ltrPermitHeld = false;
                boolean handedOff = false;
                try {
                    ltrThreadMgr.acquireLTRSemaphore();
                    ltrPermitHeld = true;
                    ltrThreadMgr.execute(task);
                    handedOff = true;
                } finally {
                    if (!handedOff) {
                        // The task never started, so its own finally block will not release these.
                        querySemaphore.release();
                        if (ltrPermitHeld) {
                            ltrThreadMgr.releaseLTRSemaphore();
                        }
                    }
                }
                futures.add(task);
            }
            // get() blocks until the corresponding task has completed.
            for (final Future<Feature.FeatureWeight> future : futures) {
                featureWeights.add(future.get());
            }
        } catch (InterruptedException e) {
            // Restore the flag so callers further up can still observe the interruption.
            Thread.currentThread().interrupt();
            log.info("Error while creating weights in LTR: interrupted", e);
            throw new RuntimeException("Error while creating weights in LTR: " + e.getMessage(), e);
        } catch (Exception e) {
            log.info("Error while creating weights in LTR", e);
            throw new RuntimeException("Error while creating weights in LTR: " + e.getMessage(), e);
        }
    }

    /**
     * Returns the given field name unchanged; nothing about the model or its
     * features is added to the string form.
     *
     * @param field the field name
     * @return {@code field}
     */
    @Override
    public String toString(String field) {
        return field;
    }

    /**
     * Estimates the memory used by this query: the shallow size of the instance
     * plus the estimated sizes of the external feature information, the scoring
     * model and the original query (the latter with 1024 passed as the second
     * argument of the estimator).
     */
    @Override
    public long ramBytesUsed() {
        long total = BASE_RAM_BYTES;
        total += RamUsageEstimator.sizeOfObject(efi);
        total += RamUsageEstimator.sizeOfObject(ltrScoringModel);
        total += RamUsageEstimator.sizeOfObject(originalQuery, 1024L);
        return total;
    }

    public class ModelWeight
    extends Weight {
        private final Feature.FeatureWeight[] modelFeatureWeights;
        private final float[] modelFeatureValuesNormalized;
        private final Feature.FeatureWeight[] extractedFeatureWeights;
        private final FeatureInfo[] featuresInfo;

        /**
         * Creates the weight for the enclosing query.
         *
         * @param modelFeatureWeights weights of the features the model scores with
         * @param extractedFeatureWeights weights of all features that are extracted
         * @param allFeaturesSize length of the per-feature info array, which is
         *     indexed by feature index
         */
        public ModelWeight(Feature.FeatureWeight[] modelFeatureWeights, Feature.FeatureWeight[] extractedFeatureWeights, int allFeaturesSize) {
            super(LTRScoringQuery.this);
            this.modelFeatureWeights = modelFeatureWeights;
            this.modelFeatureValuesNormalized = new float[modelFeatureWeights.length];
            this.extractedFeatureWeights = extractedFeatureWeights;
            this.featuresInfo = new FeatureInfo[allFeaturesSize];
            // Needs extractedFeatureWeights and featuresInfo, so it must run last.
            setFeaturesInfo();
        }

        /**
         * Fills the info array with one unused entry per extracted feature,
         * stored at the feature's index and holding its name and default value.
         */
        private void setFeaturesInfo() {
            for (Feature.FeatureWeight weight : extractedFeatureWeights) {
                final String featName = weight.getName();
                final int featId = weight.getIndex();
                final float defaultValue = weight.getDefaultValue();
                featuresInfo[featId] = new FeatureInfo(featName, defaultValue, false);
            }
        }

        /** Returns the per-feature info array itself (not a copy), indexed by feature index. */
        public FeatureInfo[] getFeaturesInfo() {
            return featuresInfo;
        }

        /** Returns the weights of the features the model scores with (the array itself, not a copy). */
        Feature.FeatureWeight[] getModelFeatureWeights() {
            return modelFeatureWeights;
        }

        /** Returns the array that holds the model feature values after normalization (the array itself, not a copy). */
        float[] getModelFeatureValuesNormalized() {
            return modelFeatureValuesNormalized;
        }

        /** Returns the weights of all extracted features (the array itself, not a copy). */
        Feature.FeatureWeight[] getExtractedFeatureWeights() {
            return extractedFeatureWeights;
        }

        /**
         * Collects the value of every model feature (the recorded value when the
         * feature is marked used, its default value otherwise), lets the model
         * normalize the values in place and returns the model's score for them.
         */
        private float makeNormalizedFeaturesAndScore() {
            for (int pos = 0; pos < modelFeatureWeights.length; pos++) {
                final Feature.FeatureWeight feature = modelFeatureWeights[pos];
                final FeatureInfo info = featuresInfo[feature.getIndex()];
                if (info.isUsed()) {
                    modelFeatureValuesNormalized[pos] = info.getValue();
                } else {
                    modelFeatureValuesNormalized[pos] = feature.getDefaultValue();
                }
            }
            LTRScoringQuery.this.ltrScoringModel.normalizeFeaturesInPlace(modelFeatureValuesNormalized);
            return LTRScoringQuery.this.ltrScoringModel.score(modelFeatureValuesNormalized);
        }

        /**
         * Explains the score of a document: every extracted feature is explained,
         * the explanations of the model features are passed through the model's
         * normalizer explanation, and the model combines them with the score
         * obtained by advancing a fresh scorer to the document.
         *
         * @param context the segment the document lives in
         * @param doc the document id handed to the features and the scorer
         * @return the model's explanation
         * @throws IOException if a feature or the scorer fails
         */
        @Override
        public Explanation explain(LeafReaderContext context, int doc) throws IOException {
            // Indexed by feature index, like featuresInfo.
            final Explanation[] byFeatureIndex = new Explanation[featuresInfo.length];
            for (Feature.FeatureWeight extracted : extractedFeatureWeights) {
                byFeatureIndex[extracted.getIndex()] = extracted.explain(context, doc);
            }

            final List<Explanation> featureExplanations = new ArrayList<>();
            int position = 0;
            for (Feature.FeatureWeight modelFeature : modelFeatureWeights) {
                final Explanation raw = byFeatureIndex[modelFeature.getIndex()];
                featureExplanations.add(LTRScoringQuery.this.ltrScoringModel.getNormalizerExplanation(raw, position));
                position++;
            }

            final ModelScorer modelScorer = scorer(context);
            modelScorer.iterator().advance(doc);
            final float finalScore = modelScorer.score();
            return LTRScoringQuery.this.ltrScoringModel.explain(context, doc, finalScore, featureExplanations);
        }

        /** Asks every extracted feature weight to add its terms to {@code terms}. */
        public void extractTerms(Set<Term> terms) {
            for (int i = 0; i < extractedFeatureWeights.length; i++) {
                extractedFeatureWeights[i].extractTerms(terms);
            }
        }

        /**
         * Puts the info entry of every extracted feature back to its default
         * value and marks it unused.
         */
        protected void reset() {
            for (Feature.FeatureWeight weight : extractedFeatureWeights) {
                final int featId = weight.getIndex();
                final float defaultValue = weight.getDefaultValue();
                final FeatureInfo info = featuresInfo[featId];
                info.setValue(defaultValue);
                info.setUsed(false);
            }
        }

        /**
         * Creates the model scorer for a segment from the scorers of all extracted
         * features; features that return no scorer ({@code null}) are left out.
         *
         * @param context the segment to score
         * @return a new model scorer, even when no feature produced a scorer
         * @throws IOException if a feature fails to create its scorer
         */
        @Override
        public ModelScorer scorer(LeafReaderContext context) throws IOException {
            final List<Feature.FeatureWeight.FeatureScorer> featureScorers = new ArrayList<>(extractedFeatureWeights.length);
            for (Feature.FeatureWeight featureWeight : extractedFeatureWeights) {
                final Feature.FeatureWeight.FeatureScorer featureScorer = featureWeight.scorer(context);
                if (featureScorer != null) {
                    featureScorers.add(featureScorer);
                }
            }
            return new ModelScorer(this, featureScorers);
        }

        /** Always reports this weight as not cacheable, whatever the segment. */
        @Override
        public boolean isCacheable(LeafReaderContext ctx) {
            return false;
        }

        public class ModelScorer
        extends Scorer {
            private final DocInfo docInfo;
            private final Scorer featureTraversalScorer;

            /** Returns the document info object that this scorer shares with its feature scorers. */
            public DocInfo getDocInfo() {
                return docInfo;
            }

            /**
             * Creates the scorer, hands one shared {@link DocInfo} to every feature
             * scorer and picks the traversal strategy: the dense scorer for at most
             * one feature scorer, the sparse scorer for two or more.
             *
             * @param weight the weight this scorer belongs to
             * @param featureScorers the scorers of the extracted features
             */
            public ModelScorer(Weight weight, List<Feature.FeatureWeight.FeatureScorer> featureScorers) {
                super(weight);
                this.docInfo = new DocInfo();
                for (Feature.FeatureWeight.FeatureScorer featureScorer : featureScorers) {
                    featureScorer.setDocInfo(docInfo);
                }
                if (featureScorers.size() <= 1) {
                    this.featureTraversalScorer = new DenseModelScorer(weight, featureScorers);
                } else {
                    this.featureTraversalScorer = new SparseModelScorer(weight, featureScorers);
                }
            }

            /** Returns the children reported by the traversal scorer chosen at construction. */
            @Override
            public Collection<Scorable.ChildScorable> getChildren() throws IOException {
                return featureTraversalScorer.getChildren();
            }

            /** Returns the document id reported by the traversal scorer. */
            @Override
            public int docID() {
                return featureTraversalScorer.docID();
            }

            /** Returns the score computed by the traversal scorer. */
            @Override
            public float score() throws IOException {
                return featureTraversalScorer.score();
            }

            /** Always returns positive infinity, i.e. no upper bound on the score is given. */
            @Override
            public float getMaxScore(int upTo) throws IOException {
                return Float.POSITIVE_INFINITY;
            }

            /** Returns the iterator of the traversal scorer. */
            @Override
            public DocIdSetIterator iterator() {
                return featureTraversalScorer.iterator();
            }

            private class DenseModelScorer
            extends Scorer {
                private int activeDoc;
                private int targetDoc;
                private int freq;
                private final List<Feature.FeatureWeight.FeatureScorer> featureScorers;

                private DenseModelScorer(Weight weight, List<Feature.FeatureWeight.FeatureScorer> featureScorers) {
                    super(weight);
                    this.activeDoc = -1;
                    this.targetDoc = -1;
                    this.freq = -1;
                    this.featureScorers = featureScorers;
                }

                public int docID() {
                    return this.targetDoc;
                }

                public float score() throws IOException {
                    ModelWeight.this.reset();
                    this.freq = 0;
                    if (this.targetDoc == this.activeDoc) {
                        for (Scorer scorer : this.featureScorers) {
                            if (scorer.docID() != this.activeDoc) continue;
                            ++this.freq;
                            Feature.FeatureWeight scFW = (Feature.FeatureWeight)scorer.getWeight();
                            int featureId = scFW.getIndex();
                            ModelWeight.this.featuresInfo[featureId].setValue(scorer.score());
                            ModelWeight.this.featuresInfo[featureId].setUsed(true);
                        }
                    }
                    return ModelWeight.this.makeNormalizedFeaturesAndScore();
                }

                public float getMaxScore(int upTo) throws IOException {
                    return Float.POSITIVE_INFINITY;
                }

                public final Collection<Scorable.ChildScorable> getChildren() {
                    ArrayList<Scorable.ChildScorable> children = new ArrayList<Scorable.ChildScorable>();
                    for (Scorer scorer : this.featureScorers) {
                        children.add(new Scorable.ChildScorable((Scorable)scorer, "SHOULD"));
                    }
                    return children;
                }

                public DocIdSetIterator iterator() {
                    return new DenseIterator();
                }

                private class DenseIterator
                extends DocIdSetIterator {
                    private DenseIterator() {
                    }

                    public int docID() {
                        return DenseModelScorer.this.targetDoc;
                    }

                    public int nextDoc() throws IOException {
                        if (DenseModelScorer.this.activeDoc <= DenseModelScorer.this.targetDoc) {
                            DenseModelScorer.this.activeDoc = Integer.MAX_VALUE;
                            for (Scorer scorer : DenseModelScorer.this.featureScorers) {
                                if (scorer.docID() == Integer.MAX_VALUE) continue;
                                DenseModelScorer.this.activeDoc = Math.min(DenseModelScorer.this.activeDoc, scorer.iterator().nextDoc());
                            }
                        }
                        return ++DenseModelScorer.this.targetDoc;
                    }

                    public int advance(int target) throws IOException {
                        if (DenseModelScorer.this.activeDoc < target) {
                            DenseModelScorer.this.activeDoc = Integer.MAX_VALUE;
                            for (Scorer scorer : DenseModelScorer.this.featureScorers) {
                                if (scorer.docID() == Integer.MAX_VALUE) continue;
                                DenseModelScorer.this.activeDoc = Math.min(DenseModelScorer.this.activeDoc, scorer.iterator().advance(target));
                            }
                        }
                        DenseModelScorer.this.targetDoc = target;
                        return target;
                    }

                    public long cost() {
                        long sum = 0L;
                        for (int i = 0; i < DenseModelScorer.this.featureScorers.size(); ++i) {
                            sum += ((Feature.FeatureWeight.FeatureScorer)((Object)DenseModelScorer.this.featureScorers.get(i))).iterator().cost();
                        }
                        return sum;
                    }
                }
            }

            private class SparseModelScorer
            extends Scorer {
                private final DisiPriorityQueue subScorers;
                private final ScoringQuerySparseIterator itr;
                private int targetDoc;
                private int activeDoc;

                private SparseModelScorer(Weight weight, List<Feature.FeatureWeight.FeatureScorer> featureScorers) {
                    super(weight);
                    this.targetDoc = -1;
                    this.activeDoc = -1;
                    if (featureScorers.size() <= 1) {
                        throw new IllegalArgumentException("There must be at least 2 subScorers");
                    }
                    this.subScorers = new DisiPriorityQueue(featureScorers.size());
                    for (Scorer scorer : featureScorers) {
                        DisiWrapper w = new DisiWrapper(scorer);
                        this.subScorers.add(w);
                    }
                    this.itr = new ScoringQuerySparseIterator(this.subScorers);
                }

                public int docID() {
                    return this.itr.docID();
                }

                public float score() throws IOException {
                    DisiWrapper topList = this.subScorers.topList();
                    ModelWeight.this.reset();
                    if (this.activeDoc == this.targetDoc) {
                        DisiWrapper w = topList;
                        while (w != null) {
                            Scorer subScorer = w.scorer;
                            Feature.FeatureWeight scFW = (Feature.FeatureWeight)subScorer.getWeight();
                            int featureId = scFW.getIndex();
                            ModelWeight.this.featuresInfo[featureId].setValue(subScorer.score());
                            ModelWeight.this.featuresInfo[featureId].setUsed(true);
                            w = w.next;
                        }
                    }
                    return ModelWeight.this.makeNormalizedFeaturesAndScore();
                }

                public float getMaxScore(int upTo) throws IOException {
                    return Float.POSITIVE_INFINITY;
                }

                public DocIdSetIterator iterator() {
                    return this.itr;
                }

                public final Collection<Scorable.ChildScorable> getChildren() {
                    ArrayList<Scorable.ChildScorable> children = new ArrayList<Scorable.ChildScorable>();
                    for (DisiWrapper scorer : this.subScorers) {
                        children.add(new Scorable.ChildScorable((Scorable)scorer.scorer, "SHOULD"));
                    }
                    return children;
                }

                private class ScoringQuerySparseIterator
                extends DisjunctionDISIApproximation {
                    public ScoringQuerySparseIterator(DisiPriorityQueue subIterators) {
                        super(subIterators);
                    }

                    public final int nextDoc() throws IOException {
                        if (SparseModelScorer.this.activeDoc == SparseModelScorer.this.targetDoc) {
                            SparseModelScorer.this.activeDoc = super.nextDoc();
                        } else if (SparseModelScorer.this.activeDoc < SparseModelScorer.this.targetDoc) {
                            SparseModelScorer.this.activeDoc = super.advance(SparseModelScorer.this.targetDoc + 1);
                        }
                        return ++SparseModelScorer.this.targetDoc;
                    }

                    public final int advance(int target) throws IOException {
                        if (SparseModelScorer.this.activeDoc < target) {
                            SparseModelScorer.this.activeDoc = super.advance(target);
                        }
                        SparseModelScorer.this.targetDoc = target;
                        return SparseModelScorer.this.targetDoc;
                    }
                }
            }
        }
    }

    /**
     * Holder for one feature's name, its current value and a flag telling
     * whether that value has been marked as used. The name is fixed at
     * construction; value and flag can be changed afterwards.
     */
    public static class FeatureInfo {
        private final String name;
        private float value;
        private boolean used;

        /**
         * @param name the feature name
         * @param value the initial value
         * @param used the initial state of the used flag
         */
        FeatureInfo(String name, float value, boolean used) {
            this.name = name;
            this.value = value;
            this.used = used;
        }

        /** Returns the feature name given at construction. */
        public String getName() {
            return name;
        }

        /** Returns the current value. */
        public float getValue() {
            return value;
        }

        /** Replaces the current value; the used flag is not touched. */
        public void setValue(float value) {
            this.value = value;
        }

        /** Returns the current state of the used flag. */
        public boolean isUsed() {
            return used;
        }

        /** Sets the used flag; the value is not touched. */
        public void setUsed(boolean used) {
            this.used = used;
        }
    }

    /**
     * Task that creates the weight of a single feature. Whether it succeeds or
     * fails, it releases one query permit and one LTR permit when it finishes.
     */
    private class CreateWeightCallable implements Callable<Feature.FeatureWeight> {
        private final Feature f;
        private final IndexSearcher searcher;
        private final boolean needsScores;
        private final SolrQueryRequest req;

        public CreateWeightCallable(Feature f, IndexSearcher searcher, boolean needsScores, SolrQueryRequest req) {
            this.f = f;
            this.searcher = searcher;
            this.needsScores = needsScores;
            this.req = req;
        }

        /**
         * Creates the feature's weight using the enclosing query's original
         * query and external feature information.
         *
         * @return the created weight
         * @throws RuntimeException wrapping any exception raised by the feature
         */
        @Override
        public Feature.FeatureWeight call() throws Exception {
            try {
                return f.createWeight(searcher, needsScores, req, LTRScoringQuery.this.originalQuery, LTRScoringQuery.this.efi);
            } catch (Exception e) {
                throw new RuntimeException("Exception from createWeight for " + f.toString() + " " + e.getMessage(), e);
            } finally {
                // Same release order as always: query permit first, then the LTR permit.
                LTRScoringQuery.this.querySemaphore.release();
                LTRScoringQuery.this.ltrThreadMgr.releaseLTRSemaphore();
            }
        }
    }
}

