/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.searchpipelines.questionanswering.generative;

import com.google.gson.JsonArray;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.BooleanSupplier;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.ml.common.conversation.Interaction;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.search.SearchHit;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants;
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeSearchResponse;
import org.opensearch.searchpipelines.questionanswering.generative.client.ConversationalMemoryClient;
import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamUtil;
import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters;
import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionOutput;
import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm;
import org.opensearch.searchpipelines.questionanswering.generative.llm.LlmIOUtil;
import org.opensearch.searchpipelines.questionanswering.generative.llm.ModelLocator;

/**
 * Search response processor that performs retrieval-augmented generation (RAG).
 *
 * <p>For each search response it extracts the configured context fields from every hit. It sends the
 * user's question, those search results, and (when a conversation id is given) prior interactions
 * read from conversational memory to an {@link Llm}. It returns a {@link GenerativeSearchResponse}
 * that carries the generated answer alongside the original response data. When a conversation id is
 * given, the question and answer are also stored as a new interaction in conversational memory.
 */
public class GenerativeQAResponseProcessor extends AbstractProcessor implements SearchResponseProcessor {

    @Generated
    private static final Logger log = LogManager.getLogger(GenerativeQAResponseProcessor.class);

    /** Type name of this processor; also used as the processor type when reading pipeline config. */
    private static final String PROCESSOR_TYPE = "retrieval_augmented_generation";

    /** Number of prior interactions requested from conversational memory. */
    private static final int DEFAULT_CHAT_HISTORY_WINDOW = 10;

    /**
     * Prompt text stored with each interaction created in conversational memory.
     *
     * <p>NOTE(review): the pieces are joined without any separator, so the stored text reads e.g.
     * "search results- A rephrase" and "conversation historyCite". The text is kept verbatim so that
     * persisted interaction data does not change. Confirm whether newlines were intended.
     */
    private static final String INTERACTION_PROMPT_TEMPLATE =
        "Generate a concise and informative answer in less than 100 words for the given question, taking into context: "
            + "- An enumerated list of search results"
            + "- A rephrase of the question that was used to generate the search results"
            + "- The conversation history"
            + "Cite search results using [${number}] notation."
            + "Do not repeat yourself, and NEVER repeat anything in the chat history."
            + "If there are any necessary steps or procedures in your answer, enumerate them.";

    /** Pipeline-level LLM model name, used when the request does not specify one. */
    private final String llmModel;
    /** Source fields whose values are extracted from each hit and passed to the LLM as context. */
    private final List<String> contextFields;
    /** Client used to read chat history and record new interactions; replaceable via the setter. */
    private ConversationalMemoryClient memoryClient;
    /** LLM used for chat completion; replaceable via the setter. */
    private Llm llm;
    /** Feature gate checked on every request; when it returns false, processing is refused. */
    private final BooleanSupplier featureFlagSupplier;

    /**
     * Creates the processor.
     *
     * @param client OpenSearch client used to build the conversational memory client
     * @param tag processor tag
     * @param description processor description
     * @param ignoreFailure whether the pipeline should ignore failures of this processor
     * @param llm LLM used for chat completion
     * @param llmModel default LLM model name; may be overridden per request
     * @param contextFields source fields to extract from each hit
     * @param supplier feature flag supplier checked on every request
     */
    protected GenerativeQAResponseProcessor(
        Client client,
        String tag,
        String description,
        boolean ignoreFailure,
        Llm llm,
        String llmModel,
        List<String> contextFields,
        BooleanSupplier supplier
    ) {
        super(tag, description, ignoreFailure);
        this.llmModel = llmModel;
        this.contextFields = contextFields;
        this.llm = llm;
        this.memoryClient = new ConversationalMemoryClient(client);
        this.featureFlagSupplier = supplier;
    }

    /**
     * Generates an answer for the question in the request's generative QA parameters and attaches it
     * to the search response.
     *
     * @param request the search request. Its generative QA parameters supply the question, an
     *     optional LLM model override, and an optional conversation id.
     * @param response the search response whose hits provide the context for the LLM
     * @return a {@link GenerativeSearchResponse} wrapping {@code response} plus the generated answer
     * @throws MLException if the feature flag is disabled
     * @throws IllegalArgumentException if the request carries no generative QA parameters
     * @throws RuntimeException if a hit lacks one of the configured context fields
     */
    public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception {
        log.info("Entering processResponse.");
        if (!this.featureFlagSupplier.getAsBoolean()) {
            throw new MLException(GenerativeQAProcessorConstants.FEATURE_NOT_ENABLED_ERROR_MSG);
        }
        GenerativeQAParameters params = GenerativeQAParamUtil.getGenerativeQAParameters(request);
        if (params == null) {
            // Fail with a descriptive error rather than a NullPointerException on the next line.
            throw new IllegalArgumentException(
                "Search request is missing the generative QA parameters required by the " + PROCESSOR_TYPE + " processor"
            );
        }
        String llmQuestion = params.getLlmQuestion();
        // A per-request model takes precedence over the one configured on the pipeline.
        String effectiveLlmModel = params.getLlmModel() == null ? this.llmModel : params.getLlmModel();
        String conversationId = params.getConversationId();
        log.info("LLM question: {}, LLM model {}, conversation id: {}", llmQuestion, effectiveLlmModel, conversationId);

        List<Interaction> chatHistory = conversationId == null
            ? Collections.emptyList()
            : this.memoryClient.getInteractions(conversationId, DEFAULT_CHAT_HISTORY_WINDOW);
        List<String> searchResults = this.getSearchResults(response);
        ChatCompletionOutput output = this.llm.doChatCompletion(
            LlmIOUtil.createChatCompletionInput(effectiveLlmModel, llmQuestion, chatHistory, searchResults)
        );
        // NOTE(review): assumes the LLM always returns at least one answer; an empty list fails here.
        String answer = (String) output.getAnswers().get(0);

        String interactionId = null;
        if (conversationId != null) {
            interactionId = this.memoryClient.createInteraction(
                conversationId,
                llmQuestion,
                INTERACTION_PROMPT_TEMPLATE,
                answer,
                PROCESSOR_TYPE,
                jsonArrayToString(searchResults)
            );
        }
        return this.insertAnswer(response, answer, interactionId);
    }

    /**
     * Returns the processor type name.
     *
     * @return {@code "retrieval_augmented_generation"}
     */
    public String getType() {
        return PROCESSOR_TYPE;
    }

    /**
     * Copies {@code response} into a {@link GenerativeSearchResponse} that also carries {@code answer}.
     *
     * <p>NOTE(review): {@code interactionId} is currently not propagated into the response.
     */
    private SearchResponse insertAnswer(SearchResponse response, String answer, String interactionId) {
        return new GenerativeSearchResponse(
            answer,
            response.getInternalResponse(),
            response.getScrollId(),
            response.getTotalShards(),
            response.getSuccessfulShards(),
            response.getSkippedShards(),
            // Elapsed time in millis. This assumes GenerativeSearchResponse mirrors SearchResponse's
            // constructor order (..., skippedShards, tookInMillis, shardFailures, clusters).
            response.getTook().millis(),
            response.getShardFailures(),
            response.getClusters()
        );
    }

    /**
     * Extracts the configured context fields from every hit, in hit order and then field order.
     *
     * @param response the search response to read hits from
     * @return one string (the field value's {@code toString()}) per hit per configured context field
     * @throws RuntimeException if a hit has no source or lacks one of the configured fields
     */
    private List<String> getSearchResults(SearchResponse response) {
        List<String> searchResults = new ArrayList<>();
        for (SearchHit hit : response.getHits().getHits()) {
            Map<String, Object> docSourceMap = hit.getSourceAsMap();
            for (String contextField : this.contextFields) {
                // A hit without _source is treated the same as a hit missing the field.
                Object context = docSourceMap == null ? null : docSourceMap.get(contextField);
                if (context == null) {
                    log.error("Context " + contextField + " not found in search hit " + hit);
                    throw new RuntimeException("Context field [" + contextField + "] not found in search hit [" + hit.getId() + "]");
                }
                searchResults.add(context.toString());
            }
        }
        return searchResults;
    }

    /** Serializes {@code listOfStrings} as a JSON array of strings. */
    private static String jsonArrayToString(List<String> listOfStrings) {
        JsonArray array = new JsonArray(listOfStrings.size());
        listOfStrings.forEach(array::add);
        return array.toString();
    }

    /** Replaces the conversational memory client. */
    @Generated
    public void setMemoryClient(ConversationalMemoryClient memoryClient) {
        this.memoryClient = memoryClient;
    }

    /** Returns the LLM used for chat completion. */
    @Generated
    public Llm getLlm() {
        return this.llm;
    }

    /** Replaces the LLM used for chat completion. */
    @Generated
    public void setLlm(Llm llm) {
        this.llm = llm;
    }

    /** Creates {@link GenerativeQAResponseProcessor} instances from search pipeline configuration. */
    public static final class Factory implements Processor.Factory<SearchResponseProcessor> {
        private final Client client;
        private final BooleanSupplier featureFlagSupplier;

        /**
         * @param client OpenSearch client handed to created processors and to the model locator
         * @param supplier feature flag supplier checked at creation time and on every request
         */
        public Factory(Client client, BooleanSupplier supplier) {
            this.client = client;
            this.featureFlagSupplier = supplier;
        }

        /**
         * Builds a processor from {@code config}.
         *
         * <p>Reads the optional {@code model_id} and {@code llm_model} properties and the
         * {@code context_field_list} property, which must not be empty.
         *
         * @throws MLException if the feature flag is disabled
         */
        public SearchResponseProcessor create(
            Map<String, Processor.Factory<SearchResponseProcessor>> processorFactories,
            String tag,
            String description,
            boolean ignoreFailure,
            Map<String, Object> config,
            Processor.PipelineContext pipelineContext
        ) throws Exception {
            if (!this.featureFlagSupplier.getAsBoolean()) {
                throw new MLException(GenerativeQAProcessorConstants.FEATURE_NOT_ENABLED_ERROR_MSG);
            }
            String modelId = ConfigurationUtils.readOptionalStringProperty(PROCESSOR_TYPE, tag, config, "model_id");
            String llmModel = ConfigurationUtils.readOptionalStringProperty(PROCESSOR_TYPE, tag, config, "llm_model");
            List<String> contextFields = ConfigurationUtils.readList(PROCESSOR_TYPE, tag, config, "context_field_list");
            if (contextFields.isEmpty()) {
                throw ConfigurationUtils.newConfigurationException(
                    PROCESSOR_TYPE,
                    tag,
                    "context_field_list",
                    "required property can't be empty."
                );
            }
            log.info("model_id {}, llm_model {}, context_field_list {}", modelId, llmModel, contextFields);
            return new GenerativeQAResponseProcessor(
                this.client,
                tag,
                description,
                ignoreFailure,
                ModelLocator.getLlm(modelId, this.client),
                llmModel,
                contextFields,
                this.featureFlagSupplier
            );
        }
    }
}

