/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.ml;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import lombok.Generated;
import lombok.NonNull;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.QuestionAnsweringInputDataSet;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelResultFilter;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.neuralsearch.processor.MapInferenceRequest;
import org.opensearch.neuralsearch.processor.SimilarityInferenceRequest;
import org.opensearch.neuralsearch.processor.TextInferenceRequest;
import org.opensearch.neuralsearch.processor.highlight.SentenceHighlightingRequest;
import org.opensearch.neuralsearch.util.RetryUtil;

/**
 * Asynchronous facade over {@link MachineLearningNodeClient} used to run inference requests and to
 * look up model metadata.
 *
 * <p>Every public method reports its outcome through the supplied callback instead of returning a
 * value. When the ML client reports a failure, the exception is handed to
 * {@link RetryUtil#handleRetryOrFailure} together with an action that re-issues the same call with
 * the retry counter incremented; whether a retry actually happens is decided by that helper.
 */
public class MLCommonsClientAccessor {
    @Generated
    private static final Logger log = LogManager.getLogger(MLCommonsClientAccessor.class);
    private final MachineLearningNodeClient mlClient;

    /**
     * Creates an accessor that sends all requests through the given client.
     *
     * @param mlClient client used for every predict and get-model call
     */
    @Generated
    public MLCommonsClientAccessor(MachineLearningNodeClient mlClient) {
        this.mlClient = mlClient;
    }

    /**
     * Runs inference for a single text and delivers the single resulting vector.
     *
     * <p>The listener fails with {@link IllegalStateException} when the model does not return
     * exactly one vector.
     *
     * @param modelId   id of the model to invoke
     * @param inputText text to run inference on
     * @param listener  receives the vector, or the failure
     */
    public void inferenceSentence(@NonNull String modelId, @NonNull String inputText, @NonNull ActionListener<List<Number>> listener) {
        Objects.requireNonNull(modelId, "modelId is marked non-null but is null");
        Objects.requireNonNull(inputText, "inputText is marked non-null but is null");
        Objects.requireNonNull(listener, "listener is marked non-null but is null");
        TextInferenceRequest request = TextInferenceRequest.builder().modelId(modelId).inputTexts(List.of(inputText)).build();
        this.inferenceSentences(request, ActionListener.wrap(response -> {
            if (response.size() != 1) {
                listener.onFailure(
                    new IllegalStateException(
                        "Unexpected number of vectors produced. Expected 1 vector to be returned, but got [" + response.size() + "]"
                    )
                );
                return;
            }
            listener.onResponse(response.getFirst());
        }, listener::onFailure));
    }

    /**
     * Runs inference for the texts of the request and delivers one vector per returned model tensor.
     *
     * @param inferenceRequest model id, input texts and target response filters
     * @param listener         receives the list of vectors, or the failure
     */
    public void inferenceSentences(@NonNull TextInferenceRequest inferenceRequest, @NonNull ActionListener<List<List<Number>>> listener) {
        Objects.requireNonNull(inferenceRequest, "inferenceRequest is marked non-null but is null");
        Objects.requireNonNull(listener, "listener is marked non-null but is null");
        this.retryableInferenceSentencesWithVectorResult(inferenceRequest, 0, listener);
    }

    /**
     * Runs inference for the texts of the request and delivers each returned tensor as a map.
     *
     * @param inferenceRequest model id and input texts
     * @param listener         receives one map per model tensor, or the failure
     */
    public void inferenceSentencesWithMapResult(@NonNull TextInferenceRequest inferenceRequest, @NonNull ActionListener<List<Map<String, ?>>> listener) {
        Objects.requireNonNull(inferenceRequest, "inferenceRequest is marked non-null but is null");
        Objects.requireNonNull(listener, "listener is marked non-null but is null");
        this.retryableInferenceSentencesWithMapResult(inferenceRequest, 0, listener);
    }

    /**
     * Runs inference for a map-shaped input (keys {@code inputText} and optionally
     * {@code inputImage}) and delivers the first resulting vector.
     *
     * @param inferenceRequest model id, input objects and target response filters
     * @param listener         receives the first vector (empty when none was produced), or the failure
     */
    public void inferenceSentencesMap(@NonNull MapInferenceRequest inferenceRequest, @NonNull ActionListener<List<Number>> listener) {
        Objects.requireNonNull(inferenceRequest, "inferenceRequest is marked non-null but is null");
        Objects.requireNonNull(listener, "listener is marked non-null but is null");
        this.retryableInferenceSentencesWithSingleVectorResult(inferenceRequest, 0, listener);
    }

    /**
     * Scores each input text of the request against its query text.
     *
     * @param inferenceRequest model id, query text and input texts
     * @param listener         receives one score per returned model tensor, or the failure
     */
    public void inferenceSimilarity(@NonNull SimilarityInferenceRequest inferenceRequest, @NonNull ActionListener<List<Float>> listener) {
        Objects.requireNonNull(inferenceRequest, "inferenceRequest is marked non-null but is null");
        Objects.requireNonNull(listener, "listener is marked non-null but is null");
        this.retryableInferenceSimilarityWithVectorResult(inferenceRequest, 0, listener);
    }

    /**
     * Runs sentence highlighting (question answering) for the request.
     *
     * @param inferenceRequest model id, question and context
     * @param listener         receives the non-empty data maps of the returned tensors, or the failure
     */
    public void inferenceSentenceHighlighting(@NonNull SentenceHighlightingRequest inferenceRequest, @NonNull ActionListener<List<Map<String, Object>>> listener) {
        Objects.requireNonNull(inferenceRequest, "inferenceRequest is marked non-null but is null");
        Objects.requireNonNull(listener, "listener is marked non-null but is null");
        this.retryableInferenceSentenceHighlighting(inferenceRequest, 0, listener);
    }

    /**
     * Fetches the metadata of a single model.
     *
     * @param modelId  id of the model to fetch
     * @param listener receives the model, or the failure
     */
    public void getModel(@NonNull String modelId, @NonNull ActionListener<MLModel> listener) {
        Objects.requireNonNull(modelId, "modelId is marked non-null but is null");
        Objects.requireNonNull(listener, "listener is marked non-null but is null");
        this.retryableGetModel(modelId, 0, listener);
    }

    /**
     * Fetches several models and reports once, after every lookup has completed.
     *
     * <p>{@code onSuccess} receives a map from model id to model when all lookups succeeded. If any
     * lookup failed, or {@code onSuccess} itself throws, {@code onFailure} is invoked instead; for
     * failed lookups the exception message joins the per-model error messages with {@code ;}.
     * An empty id set completes immediately with an empty map.
     *
     * @param modelIds  ids of the models to fetch
     * @param onSuccess invoked with the fetched models when every lookup succeeded
     * @param onFailure invoked with the aggregated error otherwise
     */
    public void getModels(@NonNull Set<String> modelIds, @NonNull Consumer<Map<String, MLModel>> onSuccess, @NonNull Consumer<Exception> onFailure) {
        Objects.requireNonNull(modelIds, "modelIds is marked non-null but is null");
        Objects.requireNonNull(onSuccess, "onSuccess is marked non-null but is null");
        Objects.requireNonNull(onFailure, "onFailure is marked non-null but is null");
        if (modelIds.isEmpty()) {
            try {
                onSuccess.accept(Collections.emptyMap());
            } catch (Exception e) {
                onFailure.accept(e);
            }
            return;
        }
        Map<String, MLModel> modelMap = new ConcurrentHashMap<>();
        // Counts outstanding lookups; the callback that brings it to zero reports the overall result.
        AtomicInteger counter = new AtomicInteger(modelIds.size());
        AtomicBoolean hasError = new AtomicBoolean(false);
        List<String> errors = Collections.synchronizedList(new ArrayList<>());
        for (String modelId : modelIds) {
            try {
                this.getModel(modelId, ActionListener.wrap(model -> {
                    modelMap.put(modelId, model);
                    if (counter.decrementAndGet() == 0) {
                        if (hasError.get()) {
                            onFailure.accept(new RuntimeException(String.join(";", errors)));
                        } else {
                            try {
                                onSuccess.accept(modelMap);
                            } catch (Exception callbackFailure) {
                                onFailure.accept(callbackFailure);
                            }
                        }
                    }
                }, e -> this.handleGetModelException(hasError, errors, modelId, e, counter, onFailure)));
            } catch (Exception e) {
                // A lookup that throws synchronously is accounted for like an asynchronous failure.
                this.handleGetModelException(hasError, errors, modelId, e, counter, onFailure);
            }
        }
    }

    /**
     * Records a failed model lookup and, if it was the last outstanding one, reports the aggregated
     * error through {@code onFailure}.
     */
    private void handleGetModelException(AtomicBoolean hasError, List<String> errors, String modelId, Exception e, AtomicInteger counter, @NonNull Consumer<Exception> onFailure) {
        Objects.requireNonNull(onFailure, "onFailure is marked non-null but is null");
        hasError.set(true);
        errors.add("Failed to fetch model [" + modelId + "]: " + e.getMessage());
        if (counter.decrementAndGet() == 0) {
            onFailure.accept(new RuntimeException(String.join(";", errors)));
        }
    }

    /** Issues the get-model call; {@code retryTime} is the number of attempts already made. */
    private void retryableGetModel(@NonNull String modelId, int retryTime, @NonNull ActionListener<MLModel> listener) {
        Objects.requireNonNull(modelId, "modelId is marked non-null but is null");
        Objects.requireNonNull(listener, "listener is marked non-null but is null");
        this.mlClient.getModel(
            modelId,
            null,
            ActionListener.wrap(
                listener::onResponse,
                e -> RetryUtil.handleRetryOrFailure(e, retryTime, () -> this.retryableGetModel(modelId, retryTime + 1, listener), listener)
            )
        );
    }

    /** Issues a text-embedding predict call and converts the response into per-tensor maps. */
    private void retryableInferenceSentencesWithMapResult(TextInferenceRequest inferenceRequest, int retryTime, ActionListener<List<Map<String, ?>>> listener) {
        // No target response filters are applied for map results.
        MLInput mlInput = this.createMLTextInput(null, inferenceRequest.getInputTexts());
        this.mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
            List<Map<String, ?>> result = this.buildMapResultFromResponse(mlOutput);
            listener.onResponse(result);
        }, e -> RetryUtil.handleRetryOrFailure(e, retryTime, () -> this.retryableInferenceSentencesWithMapResult(inferenceRequest, retryTime + 1, listener), listener)));
    }

    /** Issues a text-embedding predict call and converts the response into vectors. */
    private void retryableInferenceSentencesWithVectorResult(TextInferenceRequest inferenceRequest, int retryTime, ActionListener<List<List<Number>>> listener) {
        MLInput mlInput = this.createMLTextInput(inferenceRequest.getTargetResponseFilters(), inferenceRequest.getInputTexts());
        this.mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
            List<List<Number>> vector = this.buildVectorFromResponse(mlOutput);
            listener.onResponse(vector);
        }, e -> RetryUtil.handleRetryOrFailure(e, retryTime, () -> this.retryableInferenceSentencesWithVectorResult(inferenceRequest, retryTime + 1, listener), listener)));
    }

    /** Issues a text-similarity predict call; the first value of each tensor becomes its score. */
    private void retryableInferenceSimilarityWithVectorResult(SimilarityInferenceRequest inferenceRequest, int retryTime, ActionListener<List<Float>> listener) {
        MLInput mlInput = this.createMLTextPairsInput(inferenceRequest.getQueryText(), inferenceRequest.getInputTexts());
        this.mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
            List<List<Number>> vectors = this.buildVectorFromResponse(mlOutput);
            List<Float> scores = vectors.stream().map(v -> v.getFirst().floatValue()).collect(Collectors.toList());
            listener.onResponse(scores);
        }, e -> RetryUtil.handleRetryOrFailure(e, retryTime, () -> this.retryableInferenceSimilarityWithVectorResult(inferenceRequest, retryTime + 1, listener), listener)));
    }

    /** Issues a text-embedding predict call for a map-shaped input and delivers the first vector. */
    private void retryableInferenceSentencesWithSingleVectorResult(MapInferenceRequest inferenceRequest, int retryTime, ActionListener<List<Number>> listener) {
        MLInput mlInput = this.createMLMultimodalInput(inferenceRequest.getTargetResponseFilters(), inferenceRequest.getInputObjects());
        this.mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
            List<Number> vector = this.buildSingleVectorFromResponse(mlOutput);
            log.debug("Inference Response for input sentence is : {} ", vector);
            listener.onResponse(vector);
        }, e -> RetryUtil.handleRetryOrFailure(e, retryTime, () -> this.retryableInferenceSentencesWithSingleVectorResult(inferenceRequest, retryTime + 1, listener), listener)));
    }

    /**
     * Issues a question-answering predict call for sentence highlighting. Any exception thrown while
     * building the request or processing the response is routed to the listener.
     */
    private void retryableInferenceSentenceHighlighting(SentenceHighlightingRequest inferenceRequest, int retryTime, ActionListener<List<Map<String, Object>>> listener) {
        try {
            MLInputDataset inputDataset = new QuestionAnsweringInputDataSet(inferenceRequest.getQuestion(), inferenceRequest.getContext());
            MLInput mlInput = new MLInput(FunctionName.QUESTION_ANSWERING, null, inputDataset);
            this.mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
                try {
                    List<Map<String, Object>> result = this.processHighlightingOutput((ModelTensorOutput) mlOutput);
                    listener.onResponse(result);
                } catch (Exception processingFailure) {
                    listener.onFailure(processingFailure);
                }
            }, e -> RetryUtil.handleRetryOrFailure(e, retryTime, () -> this.retryableInferenceSentenceHighlighting(inferenceRequest, retryTime + 1, listener), listener)));
        } catch (Exception e) {
            listener.onFailure(e);
        }
    }

    /** Builds a {@code TEXT_EMBEDDING} input for the given texts and target response filters. */
    private MLInput createMLTextInput(List<String> targetResponseFilters, List<String> inputText) {
        ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null);
        MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter);
        return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset);
    }

    /** Builds a {@code TEXT_SIMILARITY} input pairing the query with each input text. */
    private MLInput createMLTextPairsInput(String query, List<String> inputText) {
        MLInputDataset inputDataset = new TextSimilarityInputDataSet(query, inputText);
        return new MLInput(FunctionName.TEXT_SIMILARITY, null, inputDataset);
    }

    /**
     * Builds a {@code TEXT_EMBEDDING} input from a map-shaped request: the {@code inputText} value
     * first, followed by the {@code inputImage} value when that key is present.
     */
    private MLInput createMLMultimodalInput(List<String> targetResponseFilters, Map<String, String> input) {
        List<String> inputText = new ArrayList<>();
        // NOTE(review): a missing "inputText" key adds a null element here — confirm callers always set it.
        inputText.add(input.get("inputText"));
        if (input.containsKey("inputImage")) {
            inputText.add(input.get("inputImage"));
        }
        return this.createMLTextInput(targetResponseFilters, inputText);
    }

    /**
     * Flattens a model response into one vector per model tensor, in response order.
     *
     * @param mlOutput response that must be a {@link ModelTensorOutput}
     * @return one list of tensor data values per model tensor
     */
    private <T extends Number> List<List<T>> buildVectorFromResponse(MLOutput mlOutput) {
        List<List<T>> vector = new ArrayList<>();
        ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;
        List<ModelTensors> tensorOutputList = modelTensorOutput.getMlModelOutputs();
        for (ModelTensors tensors : tensorOutputList) {
            List<ModelTensor> tensorsList = tensors.getMlModelTensors();
            for (ModelTensor tensor : tensorsList) {
                // The element type is chosen by the caller; the tensor only exposes its data as Number.
                @SuppressWarnings("unchecked")
                List<T> values = Arrays.stream(tensor.getData()).map(value -> (T) value).collect(Collectors.toList());
                vector.add(values);
            }
        }
        return vector;
    }

    /**
     * Collects the data map of every model tensor in the response.
     *
     * @throws IllegalStateException when the response has no outputs or its first output has no tensors
     */
    private List<Map<String, ?>> buildMapResultFromResponse(MLOutput mlOutput) {
        ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;
        List<ModelTensors> tensorOutputList = modelTensorOutput.getMlModelOutputs();
        if (CollectionUtils.isEmpty(tensorOutputList) || CollectionUtils.isEmpty(tensorOutputList.get(0).getMlModelTensors())) {
            throw new IllegalStateException("Empty model result produced. Expected at least [1] tensor output and [1] model tensor, but got [0]");
        }
        List<Map<String, ?>> resultMaps = new ArrayList<>();
        for (ModelTensors tensors : tensorOutputList) {
            List<ModelTensor> tensorList = tensors.getMlModelTensors();
            for (ModelTensor tensor : tensorList) {
                resultMaps.add(tensor.getDataAsMap());
            }
        }
        return resultMaps;
    }

    /** Returns the first vector of the response, or a new empty list when there is none. */
    private <T extends Number> List<T> buildSingleVectorFromResponse(MLOutput mlOutput) {
        List<List<T>> vector = this.buildVectorFromResponse(mlOutput);
        return vector.isEmpty() ? new ArrayList<>() : vector.get(0);
    }

    /**
     * Collects the non-empty data maps of every tensor in a highlighting response.
     *
     * <p>A response without outputs yields an empty list; a response whose tensors carry no data
     * yields a list holding a single empty map.
     *
     * @throws IllegalStateException wrapping any exception raised while reading the response
     */
    private List<Map<String, Object>> processHighlightingOutput(ModelTensorOutput modelTensorOutput) {
        List<Map<String, Object>> results = new ArrayList<>();
        try {
            List<ModelTensors> tensorOutputList = modelTensorOutput.getMlModelOutputs();
            if (CollectionUtils.isEmpty(tensorOutputList)) {
                return results;
            }
            for (ModelTensors tensors : tensorOutputList) {
                List<ModelTensor> tensorsList = tensors.getMlModelTensors();
                if (CollectionUtils.isEmpty(tensorsList)) {
                    log.warn("No tensors in model output");
                    continue;
                }
                for (ModelTensor tensor : tensorsList) {
                    Map<String, ?> dataMap = tensor.getDataAsMap();
                    if (dataMap == null || dataMap.isEmpty()) {
                        continue;
                    }
                    // Only the declared value type changes; the map instance is passed through as-is.
                    @SuppressWarnings("unchecked")
                    Map<String, Object> typedDataMap = (Map<String, Object>) dataMap;
                    results.add(typedDataMap);
                }
            }
            if (results.isEmpty()) {
                results.add(Collections.emptyMap());
            }
            return results;
        } catch (Exception e) {
            throw new IllegalStateException("Error processing sentence highlighting output", e);
        }
    }
}

