/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.searchpipelines.questionanswering.generative.client;

import com.google.common.annotations.VisibleForTesting;
import java.util.function.Function;
import lombok.Generated;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.client.Client;
import org.opensearch.common.action.ActionFuture;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;

/**
 * Thin wrapper around an OpenSearch {@link Client} that submits ML Commons prediction
 * requests ({@link MLPredictionTaskAction}) and hands back the resulting {@link MLOutput}.
 */
public class MachineLearningInternalClient {
    private final Client client;

    /**
     * Runs a prediction against {@code modelId} and returns a future for its output.
     *
     * @param modelId id of the model to run the prediction with
     * @param mlInput prediction input; must be non-null and carry a non-null input dataset
     * @return a future that receives the prediction output (or the failure) once the
     *     underlying transport action completes
     * @throws IllegalArgumentException if {@code mlInput} or its input dataset is null
     */
    public ActionFuture<MLOutput> predict(String modelId, MLInput mlInput) {
        PlainActionFuture<MLOutput> actionFuture = PlainActionFuture.newFuture();
        // PlainActionFuture is both an ActionListener and an ActionFuture, so it can be
        // passed as the listener and returned to the caller without any cast.
        predict(modelId, mlInput, actionFuture);
        return actionFuture;
    }

    /**
     * Validates {@code mlInput}, builds an {@link MLPredictionTaskRequest} and executes it,
     * forwarding the prediction output (or failure) to {@code listener}.
     *
     * @param modelId id of the model to run the prediction with
     * @param mlInput prediction input; must be non-null and carry a non-null input dataset
     * @param listener receives {@link MLTaskResponse#getOutput()} on success, or the failure
     * @throws IllegalArgumentException if {@code mlInput} or its input dataset is null
     */
    @VisibleForTesting
    void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
        validateMLInput(mlInput, true);
        MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest.builder()
            .mlInput(mlInput)
            .modelId(modelId)
            .dispatchTask(true)
            .build();
        client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, getMlPredictionTaskResponseActionListener(listener));
    }

    /**
     * Adapts an {@code ActionListener<MLOutput>} into the {@code ActionListener<MLTaskResponse>}
     * expected by {@link MLPredictionTaskAction}.
     *
     * @param listener downstream listener that receives the prediction output
     * @return a listener that re-creates the transport response as an {@link MLTaskResponse},
     *     extracts its output and forwards it (or any failure) to {@code listener}
     */
    private ActionListener<MLTaskResponse> getMlPredictionTaskResponseActionListener(ActionListener<MLOutput> listener) {
        ActionListener<MLTaskResponse> internalListener = ActionListener.wrap(
            predictionResponse -> listener.onResponse(predictionResponse.getOutput()),
            listener::onFailure
        );
        // NOTE(review): the response is rebuilt via fromActionResponse, presumably because the
        // instance delivered by the transport layer may not be this plugin's MLTaskResponse
        // class (e.g. a different classloader) — confirm.
        return wrapActionListener(internalListener, MLTaskResponse::fromActionResponse);
    }

    /**
     * Wraps {@code listener} so that each incoming response is first passed through
     * {@code recreate} before being forwarded; failures are forwarded unchanged.
     *
     * @param listener downstream listener
     * @param recreate converts the raw response into the type {@code listener} expects
     * @param <T> response type expected by {@code listener}
     * @return the wrapping listener
     */
    private <T extends ActionResponse> ActionListener<T> wrapActionListener(
        final ActionListener<T> listener,
        final Function<ActionResponse, T> recreate
    ) {
        return ActionListener.wrap(r -> listener.onResponse(recreate.apply(r)), listener::onFailure);
    }

    /**
     * Checks that the prediction input is usable.
     *
     * @param mlInput input to validate
     * @param requireInput when true, the input dataset must also be non-null
     * @throws IllegalArgumentException if {@code mlInput} is null, or if {@code requireInput}
     *     is true and its input dataset is null
     */
    private void validateMLInput(MLInput mlInput, boolean requireInput) {
        if (mlInput == null) {
            throw new IllegalArgumentException("ML Input can't be null");
        }
        if (requireInput && mlInput.getInputDataset() == null) {
            throw new IllegalArgumentException("input data set can't be null");
        }
    }

    /**
     * @param client OpenSearch client used to execute the prediction transport action
     */
    @Generated
    public MachineLearningInternalClient(Client client) {
        this.client = client;
    }
}

