/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.remote;

import com.jayway.jsonpath.JsonPath;
import com.jayway.jsonpath.Predicate;
import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import lombok.Generated;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.text.StringEscapeUtils;
import org.apache.commons.text.StringSubstitutor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.connector.MLPostProcessFunction;
import org.opensearch.ml.common.connector.MLPreProcessFunction;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.utils.ScriptUtils;
import org.opensearch.script.ScriptService;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.signer.Aws4Signer;
import software.amazon.awssdk.auth.signer.params.Aws4SignerParams;
import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.regions.Region;

/**
 * Static helpers shared by remote-model connectors.
 *
 * <p>Covers three tasks:
 * <ul>
 *   <li>turning an {@link MLInput} into the parameter map of a {@link RemoteInferenceInputDataSet};</li>
 *   <li>turning a raw model response into {@link ModelTensors};</li>
 *   <li>signing outgoing HTTP requests with the AWS SDK's {@link Aws4Signer}.</li>
 * </ul>
 */
public class ConnectorUtils {
    @Generated
    private static final Logger log = LogManager.getLogger(ConnectorUtils.class);
    private static final Aws4Signer signer = Aws4Signer.create();

    /** Pre-process function used when the predict action does not configure one. */
    private static final String DEFAULT_PRE_PROCESS_FUNCTION = "connector.pre_process.default.embedding";

    /** Request-parameter key holding an optional JsonPath filter for the model response. */
    private static final String RESPONSE_FILTER_FIELD = "response_filter";

    /** Placeholder prefix in process-function scripts that is filled from the request parameters. */
    private static final String PARAMETERS_PREFIX = "${parameters.";

    /**
     * Converts the dataset of {@code mlInput} into a {@link RemoteInferenceInputDataSet} whose
     * parameter values are safe to splice into a JSON request body.
     *
     * <p>How the dataset is handled depends on its type:
     * <ul>
     *   <li>{@link TextDocsInputDataSet}: run through the connector's pre-process function.</li>
     *   <li>{@link RemoteInferenceInputDataSet}: used as-is. It is the caller's own object, so
     *       its parameter map is replaced in place.</li>
     * </ul>
     *
     * <p>Afterwards, every non-null parameter value that is not already JSON is JSON-escaped.
     *
     * @param mlInput the prediction input; must not be {@code null}
     * @param connector connector whose predict action supplies the pre-process function
     * @param parameters values substituted for {@code ${parameters.*}} placeholders in a script
     *     pre-process function
     * @param scriptService service used to execute a script pre-process function
     * @return the dataset carrying escaped parameters
     * @throws IllegalArgumentException if the input is {@code null} or of an unsupported type
     */
    public static RemoteInferenceInputDataSet processInput(MLInput mlInput, Connector connector, Map<String, String> parameters, ScriptService scriptService) {
        if (mlInput == null) {
            throw new IllegalArgumentException("Input is null");
        }
        RemoteInferenceInputDataSet inputData;
        if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) {
            inputData = processTextDocsInput((TextDocsInputDataSet) mlInput.getInputDataset(), connector, parameters, scriptService);
        } else if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) {
            inputData = (RemoteInferenceInputDataSet) mlInput.getInputDataset();
        } else {
            throw new IllegalArgumentException("Wrong input type");
        }
        Map<String, String> rawParameters = inputData.getParameters();
        if (rawParameters != null) {
            Map<String, String> escapedParameters = new HashMap<>();
            rawParameters.forEach((key, value) -> escapedParameters.put(key, escapeParameterValue(value)));
            inputData.setParameters(escapedParameters);
        }
        return inputData;
    }

    /** Returns {@code value} unchanged if it is null or already JSON; otherwise JSON-escapes it. */
    private static String escapeParameterValue(String value) {
        if (value == null || org.opensearch.ml.common.utils.StringUtils.isJson(value)) {
            return value;
        }
        return StringEscapeUtils.escapeJson(value);
    }

    /**
     * Runs the predict action's pre-process function over the text docs and returns the
     * resulting parameters.
     *
     * <p>A built-in function (known to {@link MLPreProcessFunction}) is applied directly to the
     * docs. Any other value is treated as a script, which runs in three steps:
     * <ol>
     *   <li>Each doc is JSON-string-escaped.</li>
     *   <li>{@code ${parameters.*}} placeholders in the script are filled in.</li>
     *   <li>The script is executed and its JSON output is parsed.</li>
     * </ol>
     *
     * @throws IllegalArgumentException if the connector has no predict action or the script
     *     produced no output
     */
    private static RemoteInferenceInputDataSet processTextDocsInput(TextDocsInputDataSet inputDataSet, Connector connector, Map<String, String> parameters, ScriptService scriptService) {
        ConnectorAction predictAction = connector
            .findPredictAction()
            .orElseThrow(() -> new IllegalArgumentException("no predict action found"));
        String preProcessFunction = predictAction.getPreProcessFunction();
        if (preProcessFunction == null) {
            preProcessFunction = DEFAULT_PRE_PROCESS_FUNCTION;
        }
        if (MLPreProcessFunction.contains(preProcessFunction)) {
            Map<String, Object> buildInFunctionResult = MLPreProcessFunction.get(preProcessFunction).apply(inputDataSet.getDocs());
            return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(buildInFunctionResult)).build();
        }

        List<String> docs = new ArrayList<>();
        for (String doc : inputDataSet.getDocs()) {
            if (doc == null) {
                docs.add(null);
                continue;
            }
            // Serialize as a JSON string, then drop the surrounding quotes to keep only the escaped body.
            String gsonString = org.opensearch.ml.common.utils.StringUtils.gson.toJson(doc);
            docs.add(gsonString.substring(1, gsonString.length() - 1));
        }

        preProcessFunction = fillProcessFunctionParameters(parameters, preProcessFunction);
        Optional<String> processedInput = ScriptUtils.executePreprocessFunction(scriptService, preProcessFunction, docs);
        if (processedInput.isEmpty()) {
            throw new IllegalArgumentException("Wrong input");
        }
        @SuppressWarnings("unchecked") // Gson yields a raw Map for a JSON object
        Map<String, Object> scriptResult = org.opensearch.ml.common.utils.StringUtils.gson.fromJson(processedInput.get(), Map.class);
        return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(scriptResult)).build();
    }

    /**
     * Replaces {@code ${parameters.<key>}} placeholders in a process-function script with values
     * from {@code parameters}.
     *
     * <p>The function is returned unchanged when it is {@code null} or contains no placeholder.
     */
    private static String fillProcessFunctionParameters(Map<String, String> parameters, String processFunction) {
        if (processFunction == null || !processFunction.contains(PARAMETERS_PREFIX)) {
            return processFunction;
        }
        StringSubstitutor substitutor = new StringSubstitutor(parameters, PARAMETERS_PREFIX, "}");
        return substitutor.replace(processFunction);
    }

    /**
     * Flattens the {@code "parameters"} object of a pre-process result into a string map.
     *
     * <p>String values are copied verbatim; every other value (including {@code null}) is
     * serialized with Gson.
     *
     * @throws IllegalArgumentException if {@code processedInput} has no {@code "parameters"} object
     */
    private static Map<String, String> convertScriptStringToJsonString(Map<String, Object> processedInput) {
        Object rawParameters = processedInput == null ? null : processedInput.get("parameters");
        if (!(rawParameters instanceof Map)) {
            throw new IllegalArgumentException("Pre-process function result must contain a \"parameters\" object");
        }
        @SuppressWarnings("unchecked") // keys of a JSON object are strings
        Map<String, Object> parametersMap = (Map<String, Object>) rawParameters;
        Map<String, String> parameterStringMap = new HashMap<>();
        try {
            // The cast picks PrivilegedExceptionAction; without it the overload is ambiguous.
            // NOTE(review): presumably privileged because Gson serialization needs permissions
            // under the security manager — confirm.
            AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
                for (Map.Entry<String, Object> entry : parametersMap.entrySet()) {
                    Object value = entry.getValue();
                    parameterStringMap.put(
                        entry.getKey(),
                        value instanceof String ? (String) value : org.opensearch.ml.common.utils.StringUtils.gson.toJson(value)
                    );
                }
                return null;
            });
        } catch (PrivilegedActionException e) {
            log.error("Error processing parameters", e);
            throw new RuntimeException(e);
        }
        return parameterStringMap;
    }

    /**
     * Converts a raw model response into {@link ModelTensors}.
     *
     * <p><b>Built-in post-process function</b> (known to {@link MLPostProcessFunction}): a JsonPath
     * filter is applied to the response and the extracted list is handed to the built-in function.
     * The filter is the {@code response_filter} parameter, or the function's default filter when
     * that parameter is blank.
     *
     * <p><b>Any other post-process function</b>:
     * <ol>
     *   <li>The function is run as a script. Its output is used when present; otherwise the raw
     *       response is used.</li>
     *   <li>If {@code response_filter} is set, that JsonPath filter is applied to the result.</li>
     *   <li>The result is passed to {@link Connector#parseResponse}. It is marked as script-built
     *       tensors when a post-process function is configured, the script produced output, and
     *       that output is JSON.</li>
     * </ol>
     *
     * @param modelResponse raw response body; must not be {@code null}
     * @param connector connector whose predict action supplies the post-process function
     * @param scriptService service used to execute a script post-process function
     * @param parameters request parameters; a {@code null} map is treated as empty
     * @return the parsed tensors
     * @throws IllegalArgumentException if the response is {@code null} or the connector has no
     *     predict action
     * @throws IOException declared for failures while parsing the model response
     */
    public static ModelTensors processOutput(String modelResponse, Connector connector, ScriptService scriptService, Map<String, String> parameters) throws IOException {
        if (modelResponse == null) {
            throw new IllegalArgumentException("model response is null");
        }
        ConnectorAction connectorAction = connector
            .findPredictAction()
            .orElseThrow(() -> new IllegalArgumentException("no predict action found"));
        Map<String, String> params = parameters == null ? Map.of() : parameters;
        String postProcessFunction = fillProcessFunctionParameters(params, connectorAction.getPostProcessFunction());
        String responseFilter = params.get(RESPONSE_FILTER_FIELD);

        if (MLPostProcessFunction.contains(postProcessFunction)) {
            if (StringUtils.isBlank(responseFilter)) {
                responseFilter = MLPostProcessFunction.getResponseFilter(postProcessFunction);
            }
            List<List<Float>> vectors = JsonPath.read(modelResponse, responseFilter);
            List<ModelTensor> processedResponse = ScriptUtils.executeBuildInPostProcessFunction(vectors, MLPostProcessFunction.get(postProcessFunction));
            return ModelTensors.builder().mlModelTensors(processedResponse).build();
        }

        Optional<String> processedResponse = ScriptUtils.executePostProcessFunction(scriptService, postProcessFunction, modelResponse);
        String response = processedResponse.orElse(modelResponse);
        boolean scriptReturnModelTensor = postProcessFunction != null
            && processedResponse.isPresent()
            && org.opensearch.ml.common.utils.StringUtils.isJson(response);

        List<ModelTensor> modelTensors = new ArrayList<>();
        if (responseFilter == null) {
            connector.parseResponse(response, modelTensors, scriptReturnModelTensor);
        } else {
            Object filteredResponse = JsonPath.parse(response).read(responseFilter);
            connector.parseResponse(filteredResponse, modelTensors, scriptReturnModelTensor);
        }
        return ModelTensors.builder().mlModelTensors(modelTensors).build();
    }

    /**
     * Signs {@code request} with the AWS SDK's {@link Aws4Signer}.
     *
     * <p>Session credentials are used when {@code sessionToken} is non-null; otherwise basic
     * access-key credentials are used.
     *
     * @param request the request to sign
     * @param accessKey AWS access key id
     * @param secretKey AWS secret access key
     * @param sessionToken optional session token; {@code null} for long-term credentials
     * @param signingName service signing name
     * @param region region id passed to {@link Region#of(String)}
     * @return the signed request
     */
    public static SdkHttpFullRequest signRequest(SdkHttpFullRequest request, String accessKey, String secretKey, String sessionToken, String signingName, String region) {
        // AwsSessionCredentials is not an AwsBasicCredentials, so the common type is AwsCredentials.
        AwsCredentials credentials = sessionToken == null
            ? AwsBasicCredentials.create(accessKey, secretKey)
            : AwsSessionCredentials.create(accessKey, secretKey, sessionToken);
        Aws4SignerParams params = Aws4SignerParams
            .builder()
            .awsCredentials(credentials)
            .signingName(signingName)
            .signingRegion(Region.of(region))
            .build();
        return signer.sign(request, params);
    }
}

