/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.text_embedding;

import ai.djl.inference.Predictor;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.output.model.ModelResultFilter;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.algorithms.DLModel;
import org.opensearch.ml.engine.algorithms.text_embedding.HuggingfaceTextEmbeddingTranslatorFactory;
import org.opensearch.ml.engine.algorithms.text_embedding.ONNXSentenceTransformerTextEmbeddingTranslator;
import org.opensearch.ml.engine.algorithms.text_embedding.SentenceTransformerTextEmbeddingTranslator;
import org.opensearch.ml.engine.annotation.Function;

/**
 * Deep-learning model wrapper registered for {@link FunctionName#TEXT_EMBEDDING}.
 *
 * <p>Each input document is sent through the predictor on its own, and the
 * resulting tensors are collected into a single {@link ModelTensorOutput}.
 * The translator / translator factory used to talk to the underlying engine is
 * selected from the engine name and the {@link TextEmbeddingModelConfig}.
 */
@Function(value=FunctionName.TEXT_EMBEDDING)
public class TextEmbeddingModel
extends DLModel {
    @Generated
    private static final Logger log = LogManager.getLogger(TextEmbeddingModel.class);
    public static final String SENTENCE_EMBEDDING = "sentence_embedding";

    /**
     * Runs inference for every document in the input data set.
     *
     * @param modelId id of the model; not used by this implementation
     * @param mlInput input whose data set must be a {@link TextDocsInputDataSet}
     * @return one {@link ModelTensors} entry per input document, in input order
     * @throws TranslateException if the predictor fails to translate input or output
     * @throws ClassCastException if the input data set is not a {@link TextDocsInputDataSet}
     */
    @Override
    public ModelTensorOutput predict(String modelId, MLInput mlInput) throws TranslateException {
        MLInputDataset inputDataSet = mlInput.getInputDataset();
        TextDocsInputDataSet textDocsInput = (TextDocsInputDataSet)inputDataSet;
        ModelResultFilter resultFilter = textDocsInput.getResultFilter();
        ArrayList<ModelTensors> tensorOutputs = new ArrayList<ModelTensors>();
        for (String doc : textDocsInput.getDocs()) {
            Input input = new Input();
            input.add(doc);
            // getPredictor() is intentionally called once per document rather than hoisted:
            // it is defined outside this class and may hand out a different predictor per call.
            Output output = (Output)this.getPredictor().predict((Object)input);
            tensorOutputs.add(this.parseModelTensorOutput(output, resultFilter));
        }
        return new ModelTensorOutput(tensorOutputs);
    }

    /**
     * Selects the translator for the given engine and model configuration.
     *
     * @param engine engine name; {@code "OnnxRuntime"} selects the ONNX translator
     * @param modelConfig must be a {@link TextEmbeddingModelConfig}
     * @return the ONNX translator for the ONNX engine, the sentence-transformers translator
     *         when the framework type is {@code SENTENCE_TRANSFORMERS}, otherwise {@code null}
     *         (NOTE(review): presumably the caller then falls back to
     *         {@link #getTranslatorFactory} - confirm against DLModel)
     */
    @Override
    public Translator<Input, Output> getTranslator(String engine, MLModelConfig modelConfig) {
        TextEmbeddingModelConfig textEmbeddingModelConfig = (TextEmbeddingModelConfig)modelConfig;
        TextEmbeddingModelConfig.FrameworkType transformersType = textEmbeddingModelConfig.getFrameworkType();
        String modelType = textEmbeddingModelConfig.getModelType();
        TextEmbeddingModelConfig.PoolingMode poolingMode = textEmbeddingModelConfig.getPoolingMode();
        boolean normalizeResult = textEmbeddingModelConfig.isNormalizeResult();
        if ("OnnxRuntime".equals(engine)) {
            return new ONNXSentenceTransformerTextEmbeddingTranslator(poolingMode, normalizeResult, modelType);
        }
        if (transformersType == TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) {
            return new SentenceTransformerTextEmbeddingTranslator();
        }
        return null;
    }

    /**
     * Selects the translator factory for the given engine and model configuration.
     *
     * @param engine engine name; only {@code "PyTorch"} yields a factory
     * @param modelConfig must be a {@link TextEmbeddingModelConfig}
     * @return a Hugging Face translator factory for PyTorch models whose framework type is
     *         not {@code SENTENCE_TRANSFORMERS}; {@code null} in every other case
     */
    @Override
    public TranslatorFactory getTranslatorFactory(String engine, MLModelConfig modelConfig) {
        TextEmbeddingModelConfig textEmbeddingModelConfig = (TextEmbeddingModelConfig)modelConfig;
        TextEmbeddingModelConfig.FrameworkType transformersType = textEmbeddingModelConfig.getFrameworkType();
        String modelType = textEmbeddingModelConfig.getModelType();
        TextEmbeddingModelConfig.PoolingMode poolingMode = textEmbeddingModelConfig.getPoolingMode();
        boolean normalizeResult = textEmbeddingModelConfig.isNormalizeResult();
        if ("PyTorch".equals(engine) && transformersType != TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) {
            // A missing framework type passes the check above, so guard before calling name();
            // it is treated as a non-Neuron model instead of failing with a NullPointerException.
            boolean neuron = transformersType != null && transformersType.name().endsWith("_NEURON");
            return new HuggingfaceTextEmbeddingTranslatorFactory(poolingMode, normalizeResult, modelType, neuron);
        }
        return null;
    }

    /**
     * Builds the argument map derived from the model configuration.
     *
     * @param modelConfig a {@link TextEmbeddingModelConfig}, or {@code null}
     * @return a mutable map holding {@code "modelMaxLength"} when the configuration defines
     *         it; an empty map when the configuration or the value is absent
     */
    @Override
    public Map<String, Object> getArguments(MLModelConfig modelConfig) {
        Map<String, Object> arguments = new HashMap<String, Object>();
        if (modelConfig == null) {
            return arguments;
        }
        TextEmbeddingModelConfig textEmbeddingModelConfig = (TextEmbeddingModelConfig)modelConfig;
        Integer modelMaxLength = textEmbeddingModelConfig.getModelMaxLength();
        if (modelMaxLength != null) {
            arguments.put("modelMaxLength", modelMaxLength);
        }
        return arguments;
    }

    /**
     * Sends a single synthetic sentence through the predictor to warm it up.
     *
     * <p>When the configuration defines a positive {@code modelMaxLength}, the sentence is
     * {@code "sentence "} repeated that many times; otherwise a short fixed sentence is used.
     *
     * @param predictor predictor to run the warm-up prediction on
     * @param modelId id of the model; not used by this implementation
     * @param modelConfig a {@link TextEmbeddingModelConfig}, or {@code null}
     * @throws TranslateException if the warm-up prediction fails
     */
    @Override
    public void warmUp(Predictor predictor, String modelId, MLModelConfig modelConfig) throws TranslateException {
        String warmUpSentence = "warm up sentence";
        if (modelConfig != null) {
            Integer modelMaxLength = ((TextEmbeddingModelConfig)modelConfig).getModelMaxLength();
            // String.repeat rejects negative counts, and a zero count would yield an empty
            // sentence, so only positive lengths replace the default warm-up sentence.
            if (modelMaxLength != null && modelMaxLength > 0) {
                warmUpSentence = "sentence ".repeat(modelMaxLength);
            }
        }
        Input input = new Input();
        input.add(warmUpSentence);
        predictor.predict((Object)input);
    }
}

