/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelResultFilter;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.engine.Predictable;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.utils.FileUtils;
import org.opensearch.ml.engine.utils.ZipUtils;

public abstract class DLModel
implements Predictable {
    @Generated
    private static final Logger log = LogManager.getLogger(DLModel.class);
    public static final String MODEL_ZIP_FILE = "model_zip_file";
    public static final String MODEL_HELPER = "model_helper";
    public static final String ML_ENGINE = "ml_engine";
    protected ModelHelper modelHelper;
    protected MLEngine mlEngine;
    protected String modelId;
    protected Predictor<Input, Output>[] predictors;
    protected ZooModel[] models;
    protected Device[] devices;
    protected AtomicInteger nextDevice = new AtomicInteger(0);

    @Override
    public MLOutput predict(MLInput mlInput, MLModel model) {
        throw new IllegalArgumentException("model not deployed");
    }

    @Override
    public MLOutput predict(MLInput mlInput) {
        if (this.modelHelper == null || this.modelId == null) {
            throw new IllegalArgumentException("model not deployed");
        }
        try {
            return (MLOutput)AccessController.doPrivileged(() -> {
                Thread.currentThread().setContextClassLoader(this.getClass().getClassLoader());
                if (!this.isModelReady()) {
                    throw new MLException("model not deployed.");
                }
                return this.predict(this.modelId, mlInput);
            });
        }
        catch (Throwable e) {
            String errorMsg = "Failed to inference " + mlInput.getAlgorithm() + " model: " + this.modelId;
            log.error(errorMsg, e);
            throw new MLException(errorMsg, e);
        }
    }

    protected Predictor<Input, Output> getPredictor() {
        int currentDevice = this.nextDevice.getAndIncrement();
        if (currentDevice > this.devices.length - 1) {
            this.nextDevice.set((currentDevice %= this.devices.length) + 1);
        }
        return this.predictors[currentDevice];
    }

    public abstract ModelTensorOutput predict(String var1, MLInput var2) throws TranslateException;

    @Override
    public void initModel(MLModel model, Map<String, Object> params, Encryptor encryptor) {
        String engine;
        switch (model.getModelFormat()) {
            case TORCH_SCRIPT: {
                engine = "PyTorch";
                break;
            }
            case ONNX: {
                engine = "OnnxRuntime";
                break;
            }
            default: {
                throw new IllegalArgumentException("unsupported engine");
            }
        }
        File modelZipFile = (File)params.get(MODEL_ZIP_FILE);
        this.modelHelper = (ModelHelper)params.get(MODEL_HELPER);
        this.mlEngine = (MLEngine)params.get(ML_ENGINE);
        if (modelZipFile == null) {
            throw new IllegalArgumentException("model file is null");
        }
        if (this.modelHelper == null) {
            throw new IllegalArgumentException("model helper is null");
        }
        if (this.mlEngine == null) {
            throw new IllegalArgumentException("ML engine is null");
        }
        this.modelId = model.getModelId();
        if (this.modelId == null) {
            throw new IllegalArgumentException("model id is null");
        }
        if (!FunctionName.isDLModel((FunctionName)model.getAlgorithm())) {
            throw new IllegalArgumentException("wrong function name");
        }
        this.loadModel(modelZipFile, this.modelId, model.getName(), model.getVersion(), model.getModelConfig(), engine);
    }

    @Override
    public void close() {
        if (this.modelHelper != null && this.modelId != null) {
            this.modelHelper.deleteFileCache(this.modelId);
            if (this.predictors != null) {
                this.closePredictors(this.predictors);
                this.predictors = null;
            }
            if (this.models != null) {
                this.closeModels(this.models);
                this.models = null;
            }
        }
    }

    @Override
    public boolean isModelReady() {
        return this.predictors != null && this.modelHelper != null && this.modelId != null;
    }

    public abstract Translator<Input, Output> getTranslator(String var1, MLModelConfig var2);

    public abstract TranslatorFactory getTranslatorFactory(String var1, MLModelConfig var2);

    public Map<String, Object> getArguments(MLModelConfig modelConfig) {
        return null;
    }

    public void warmUp(Predictor predictor, String modelId, MLModelConfig modelConfig) throws TranslateException {
    }

    protected void doLoadModel(List<Predictor<Input, Output>> predictorList, List<ZooModel<Input, Output>> modelList, String engine, Path modelPath, MLModelConfig modelConfig) throws ModelNotFoundException, MalformedModelException, IOException, TranslateException {
        this.devices = Engine.getEngine((String)engine).getDevices();
        for (int i = 0; i < this.devices.length; ++i) {
            log.debug("load model {} to device {}: {}", (Object)this.modelId, (Object)i, (Object)this.devices[i]);
            Criteria.Builder criteriaBuilder = Criteria.builder().setTypes(Input.class, Output.class).optApplication(Application.UNDEFINED).optEngine(engine).optDevice(this.devices[i]).optModelPath(modelPath);
            Translator<Input, Output> translator = this.getTranslator(engine, modelConfig);
            TranslatorFactory translatorFactory = this.getTranslatorFactory(engine, modelConfig);
            if (translatorFactory != null) {
                criteriaBuilder.optTranslatorFactory(translatorFactory);
            } else if (translator != null) {
                criteriaBuilder.optTranslator(translator);
            }
            Map<String, Object> arguments = this.getArguments(modelConfig);
            if (arguments != null && arguments.size() > 0) {
                for (Map.Entry<String, Object> entry : arguments.entrySet()) {
                    criteriaBuilder.optArgument(entry.getKey(), entry.getValue());
                }
            }
            Criteria criteria = criteriaBuilder.build();
            ZooModel model = criteria.loadModel();
            Predictor predictor = model.newPredictor();
            predictorList.add((Predictor<Input, Output>)predictor);
            modelList.add((ZooModel<Input, Output>)model);
            this.warmUp(predictor, this.modelId, modelConfig);
        }
        if (predictorList.size() > 0) {
            this.predictors = predictorList.toArray(new Predictor[0]);
            predictorList.clear();
        }
        if (modelList.size() > 0) {
            this.models = modelList.toArray(new ZooModel[0]);
            modelList.clear();
        }
        log.info("Model {} is successfully deployed on {} devices", (Object)this.modelId, (Object)this.devices.length);
    }

    protected void loadModel(File modelZipFile, String modelId, String modelName, String version, MLModelConfig modelConfig, String engine) {
        try {
            if (!"PyTorch".equals(engine) && !"OnnxRuntime".equals(engine)) {
                throw new IllegalArgumentException("unsupported engine");
            }
            ArrayList predictorList = new ArrayList();
            ArrayList modelList = new ArrayList();
            AccessController.doPrivileged(() -> {
                ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader();
                try {
                    System.setProperty("PYTORCH_PRECXX11", "true");
                    System.setProperty("DJL_CACHE_DIR", this.mlEngine.getMlCachePath().toAbsolutePath().toString());
                    System.setProperty("java.library.path", this.mlEngine.getMlCachePath().toAbsolutePath().toString());
                    System.setProperty("ai.djl.pytorch.num_interop_threads", "1");
                    System.setProperty("ai.djl.pytorch.num_threads", "1");
                    Thread.currentThread().setContextClassLoader(Model.class.getClassLoader());
                    Path modelPath = this.mlEngine.getModelCachePath(modelId, modelName, version);
                    File pathFile = new File(modelPath.toUri());
                    if (pathFile.exists()) {
                        org.apache.commons.io.FileUtils.deleteDirectory((File)pathFile);
                    }
                    ZipUtils.unzip(modelZipFile, modelPath);
                    boolean findModelFile = false;
                    for (File file : pathFile.listFiles()) {
                        String name = file.getName();
                        if (!name.endsWith(".pt") && !name.endsWith(".onnx")) continue;
                        if (findModelFile) {
                            throw new IllegalArgumentException("found multiple models");
                        }
                        findModelFile = true;
                        int dotIndex = name.lastIndexOf(".");
                        String suffix = name.substring(dotIndex);
                        String targetModelFileName = modelPath.getFileName().toString();
                        if (targetModelFileName.equals(name.substring(0, dotIndex))) continue;
                        file.renameTo(new File(modelPath.resolve(targetModelFileName + suffix).toUri()));
                    }
                    this.doLoadModel(predictorList, modelList, engine, modelPath, modelConfig);
                    File[] fileArray = null;
                    return fileArray;
                }
                catch (Throwable e) {
                    String errorMessage = "Failed to deploy model " + modelId;
                    log.error(errorMessage, e);
                    this.close();
                    if (predictorList.size() > 0) {
                        this.closePredictors(predictorList.toArray(new Predictor[0]));
                        predictorList.clear();
                    }
                    if (modelList.size() > 0) {
                        this.closeModels(modelList.toArray(new ZooModel[0]));
                        modelList.clear();
                    }
                    throw new MLException(errorMessage, e);
                }
                finally {
                    FileUtils.deleteFileQuietly(this.mlEngine.getDeployModelPath(modelId));
                    Thread.currentThread().setContextClassLoader(contextClassLoader);
                }
            });
        }
        catch (PrivilegedActionException e) {
            String errorMsg = "Failed to deploy model " + modelId;
            log.error(errorMsg, (Throwable)e);
            throw new MLException(errorMsg, (Throwable)e);
        }
    }

    protected void closePredictors(Predictor[] predictors) {
        log.debug("will close {} predictor for model {}", (Object)predictors.length, (Object)this.modelId);
        for (Predictor predictor : predictors) {
            predictor.close();
        }
    }

    protected void closeModels(ZooModel[] models) {
        log.debug("will close {} zoo model for model {}", (Object)models.length, (Object)this.modelId);
        for (ZooModel model : models) {
            model.close();
        }
    }

    public ModelTensors parseModelTensorOutput(Output output, ModelResultFilter resultFilter) {
        if (output == null) {
            throw new MLException("No output generated");
        }
        byte[] bytes = output.getData().getAsBytes();
        ModelTensors tensorOutput = ModelTensors.fromBytes((byte[])bytes);
        if (resultFilter != null) {
            tensorOutput.filter(resultFilter);
        }
        return tensorOutput;
    }
}

