/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.agent.tools;

import com.google.gson.Gson;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.StringJoiner;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import lombok.Generated;
import org.apache.commons.lang3.math.NumberUtils;
import org.apache.commons.text.StringSubstitutor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.json.JSONObject;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.admin.indices.mapping.get.GetMappingsRequest;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.agent.tools.utils.ToolHelper;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.MappingMetadata;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.sql.plugin.transport.PPLQueryAction;
import org.opensearch.sql.plugin.transport.TransportPPLQueryRequest;
import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse;
import org.opensearch.sql.ppl.domain.PPLQueryRequest;

@ToolAnnotation(value="PPLTool")
public class PPLTool
implements Tool {
    @Generated
    private static final Logger log = LogManager.getLogger(PPLTool.class);
    public static final String TYPE = "PPLTool";
    private Client client;
    private static final String DEFAULT_DESCRIPTION = "\"Use this tool when user ask question based on the data in the cluster or parse user statement about which index to use in a conversion.\nAlso use this tool when question only contains index information.\n1. If uesr question contain both question and index name, the input parameters are {'question': UserQuestion, 'index': IndexName}.\n2. If user question contain only question, the input parameter is {'question': UserQuestion}.\n3. If uesr question contain only index name, find the original human input from the conversation histroy and formulate parameter as {'question': UserQuestion, 'index': IndexName}\nThe index name should be exactly as stated in user's input.";
    private String name = "PPLTool";
    private String description = "\"Use this tool when user ask question based on the data in the cluster or parse user statement about which index to use in a conversion.\nAlso use this tool when question only contains index information.\n1. If uesr question contain both question and index name, the input parameters are {'question': UserQuestion, 'index': IndexName}.\n2. If user question contain only question, the input parameter is {'question': UserQuestion}.\n3. If uesr question contain only index name, find the original human input from the conversation histroy and formulate parameter as {'question': UserQuestion, 'index': IndexName}\nThe index name should be exactly as stated in user's input.";
    private String version;
    private String modelId;
    private String contextPrompt;
    private Boolean execute;
    private PPLModelType pplModelType;
    private String previousToolKey;
    private int head;
    private static Gson gson = StringUtils.gson;
    private static Map<String, String> DEFAULT_PROMPT_DICT;
    private static Set<String> ALLOWED_FIELDS_TYPE;

    public PPLTool(Client client, String modelId, String contextPrompt, String pplModelType, String previousToolKey, int head, boolean execute) {
        this.client = client;
        this.modelId = modelId;
        this.pplModelType = PPLModelType.from(pplModelType);
        this.contextPrompt = contextPrompt.isEmpty() ? DEFAULT_PROMPT_DICT.getOrDefault(this.pplModelType.toString(), "") : contextPrompt;
        this.previousToolKey = previousToolKey;
        this.head = head;
        this.execute = execute;
    }

    public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
        this.extractFromChatParameters(parameters);
        String indexName = this.getIndexNameFromParameters(parameters);
        if (org.apache.commons.lang3.StringUtils.isBlank((CharSequence)indexName)) {
            throw new IllegalArgumentException("Return this final answer to human directly and do not use other tools: 'Please provide index name'. Please try to directly send this message to human to ask for index name");
        }
        String question = parameters.get("question");
        if (org.apache.commons.lang3.StringUtils.isBlank((CharSequence)indexName) || org.apache.commons.lang3.StringUtils.isBlank((CharSequence)question)) {
            throw new IllegalArgumentException("Parameter index and question can not be null or empty.");
        }
        if (indexName.startsWith(".")) {
            throw new IllegalArgumentException("PPLTool doesn't support searching indices starting with '.' since it could be system index, current searching index name: " + indexName);
        }
        GetMappingsRequest getMappingsRequest = this.buildGetMappingRequest(indexName);
        this.client.admin().indices().getMappings(getMappingsRequest, ActionListener.wrap(getMappingsResponse -> {
            Map mappings = getMappingsResponse.getMappings();
            if (mappings.isEmpty()) {
                throw new IllegalArgumentException("No matching mapping with index name: " + indexName);
            }
            String firstIndexName = (String)mappings.keySet().toArray()[0];
            SearchRequest searchRequest = this.buildSearchRequest(firstIndexName);
            this.client.search(searchRequest, ActionListener.wrap(searchResponse -> {
                SearchHit[] searchHits = searchResponse.getHits().getHits();
                String tableInfo = this.constructTableInfo(searchHits, mappings);
                String prompt = this.constructPrompt(tableInfo, question.strip(), indexName);
                RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(Collections.singletonMap("prompt", prompt)).build();
                MLPredictionTaskRequest request = new MLPredictionTaskRequest(this.modelId, MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset((MLInputDataset)inputDataSet).build());
                this.client.execute((ActionType)MLPredictionTaskAction.INSTANCE, (ActionRequest)request, ActionListener.wrap(mlTaskResponse -> {
                    ModelTensorOutput modelTensorOutput = (ModelTensorOutput)mlTaskResponse.getOutput();
                    ModelTensors modelTensors = (ModelTensors)modelTensorOutput.getMlModelOutputs().get(0);
                    ModelTensor modelTensor = (ModelTensor)modelTensors.getMlModelTensors().get(0);
                    Map dataAsMap = modelTensor.getDataAsMap();
                    String ppl = this.parseOutput((String)dataAsMap.get("response"), indexName);
                    if (!this.execute.booleanValue()) {
                        ImmutableMap ret = ImmutableMap.of((Object)"ppl", (Object)ppl);
                        listener.onResponse((Object)AccessController.doPrivileged(() -> PPLTool.lambda$run$0((Map)ret)));
                        return;
                    }
                    JSONObject jsonContent = new JSONObject((Map)ImmutableMap.of((Object)"query", (Object)ppl));
                    PPLQueryRequest pplQueryRequest = new PPLQueryRequest(ppl, jsonContent, null, "jdbc");
                    TransportPPLQueryRequest transportPPLQueryRequest = new TransportPPLQueryRequest(pplQueryRequest);
                    this.client.execute((ActionType)PPLQueryAction.INSTANCE, (ActionRequest)transportPPLQueryRequest, this.getPPLTransportActionListener((ActionListener<TransportPPLQueryResponse>)ActionListener.wrap(transportPPLQueryResponse -> {
                        String results = transportPPLQueryResponse.getResult();
                        ImmutableMap returnResults = ImmutableMap.of((Object)"ppl", (Object)ppl, (Object)"executionResult", (Object)results);
                        listener.onResponse((Object)AccessController.doPrivileged(() -> PPLTool.lambda$run$1((Map)returnResults)));
                    }, e -> {
                        String pplError = "execute ppl:" + ppl + ", get error: " + e.getMessage();
                        Exception exception = new Exception(pplError);
                        listener.onFailure(exception);
                    })));
                }, e -> {
                    log.error(String.format(Locale.ROOT, "fail to predict model: %s with error: %s", this.modelId, e.getMessage()), (Throwable)e);
                    listener.onFailure(e);
                }));
            }, e -> {
                log.error(String.format(Locale.ROOT, "fail to search model: %s with error: %s", this.modelId, e.getMessage()), (Throwable)e);
                listener.onFailure(e);
            }));
        }, e -> {
            log.error(String.format(Locale.ROOT, "fail to get mapping of index: %s with error: %s", indexName, e.getMessage()), (Throwable)e);
            String errorMessage = e.getMessage();
            if (errorMessage.contains("no such index")) {
                listener.onFailure((Exception)new IllegalArgumentException("Return this final answer to human directly and do not use other tools: 'Please provide index name'. Please try to directly send this message to human to ask for index name"));
            } else {
                listener.onFailure(e);
            }
        }));
    }

    public String getType() {
        return TYPE;
    }

    public String getName() {
        return this.name;
    }

    public boolean validate(Map<String, String> parameters) {
        return parameters != null && !parameters.isEmpty();
    }

    private SearchRequest buildSearchRequest(String indexName) {
        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
        searchSourceBuilder.size(1).query((QueryBuilder)new MatchAllQueryBuilder());
        return new SearchRequest(new String[]{indexName}, searchSourceBuilder);
    }

    private GetMappingsRequest buildGetMappingRequest(String indexName) {
        String[] indices = new String[]{indexName};
        GetMappingsRequest getMappingsRequest = new GetMappingsRequest();
        getMappingsRequest.indices(indices);
        return getMappingsRequest;
    }

    private static void validatePPLToolParameters(Map<String, Object> map) {
        String execute;
        if (org.apache.commons.lang3.StringUtils.isBlank((CharSequence)((String)map.get("model_id")))) {
            throw new IllegalArgumentException("PPL tool needs non blank model id.");
        }
        if (map.containsKey("execute") && Objects.nonNull(map.get("execute")) && !(execute = map.get("execute").toString().toLowerCase(Locale.ROOT)).equals("true") && !execute.equals("false")) {
            throw new IllegalArgumentException("PPL tool parameter execute must be false or true");
        }
        if (map.containsKey("head")) {
            String head = map.get("head").toString();
            try {
                int n = NumberUtils.createInteger((String)head);
            }
            catch (Exception e) {
                throw new IllegalArgumentException("PPL tool parameter head must be integer.");
            }
        }
    }

    private String constructTableInfo(SearchHit[] searchHits, Map<String, MappingMetadata> mappings) throws PrivilegedActionException {
        String firstIndexName = (String)mappings.keySet().toArray()[0];
        MappingMetadata mappingMetadata = mappings.get(firstIndexName);
        Map mappingSource = (Map)mappingMetadata.getSourceAsMap().get("properties");
        if (Objects.isNull(mappingSource)) {
            throw new IllegalArgumentException("The querying index doesn't have mapping metadata, please add data to it or using another index.");
        }
        HashMap<String, String> fieldsToType = new HashMap<String, String>();
        ToolHelper.extractFieldNamesTypes(mappingSource, fieldsToType, "", false);
        StringJoiner tableInfoJoiner = new StringJoiner("\n");
        ArrayList sortedKeys = new ArrayList(fieldsToType.keySet());
        Collections.sort(sortedKeys);
        if (searchHits.length > 0) {
            SearchHit hit = searchHits[0];
            Map sampleSource = hit.getSourceAsMap();
            HashMap<String, String> fieldsToSample = new HashMap<String, String>();
            for (String key : fieldsToType.keySet()) {
                fieldsToSample.put(key, "");
            }
            PPLTool.extractSamples(sampleSource, fieldsToSample, "");
            for (String key : sortedKeys) {
                if (!ALLOWED_FIELDS_TYPE.contains(fieldsToType.get(key))) continue;
                String line = "- " + key + ": " + (String)fieldsToType.get(key) + " (" + (String)fieldsToSample.get(key) + ")";
                tableInfoJoiner.add(line);
            }
        } else {
            for (String key : sortedKeys) {
                if (!ALLOWED_FIELDS_TYPE.contains(fieldsToType.get(key))) continue;
                String line = "- " + key + ": " + (String)fieldsToType.get(key);
                tableInfoJoiner.add(line);
            }
        }
        return tableInfoJoiner.toString();
    }

    private String constructPrompt(String tableInfo, String question, String indexName) {
        ImmutableMap indexInfo = ImmutableMap.of((Object)"mappingInfo", (Object)tableInfo, (Object)"question", (Object)question, (Object)"indexName", (Object)indexName);
        StringSubstitutor substitutor = new StringSubstitutor((Map)indexInfo, "${indexInfo.", "}");
        return substitutor.replace(this.contextPrompt);
    }

    private static void extractSamples(Map<String, Object> sampleSource, Map<String, String> fieldsToSample, String prefix) throws PrivilegedActionException {
        if (!((String)prefix).isEmpty()) {
            prefix = (String)prefix + ".";
        }
        for (Map.Entry<String, Object> entry : sampleSource.entrySet()) {
            String p = entry.getKey();
            Object v = entry.getValue();
            String fullKey = (String)prefix + p;
            if (fieldsToSample.containsKey(fullKey)) {
                fieldsToSample.put(fullKey, AccessController.doPrivileged(() -> gson.toJson(v)));
                continue;
            }
            if (!(v instanceof Map)) continue;
            PPLTool.extractSamples((Map)v, fieldsToSample, fullKey);
        }
    }

    private <T extends ActionResponse> ActionListener<T> getPPLTransportActionListener(ActionListener<TransportPPLQueryResponse> listener) {
        return ActionListener.wrap(r -> listener.onResponse((Object)TransportPPLQueryResponse.fromActionResponse((ActionResponse)r)), arg_0 -> listener.onFailure(arg_0));
    }

    private void extractFromChatParameters(Map<String, String> parameters) {
        if (parameters.containsKey("input")) {
            String input = parameters.get("input");
            try {
                Map chatParameters = (Map)gson.fromJson(input, Map.class);
                parameters.putAll(chatParameters);
            }
            catch (Exception e) {
                log.error(String.format(Locale.ROOT, "Failed to parse chat parameters, input is: %s, which is not a valid json", input), (Throwable)e);
            }
        }
    }

    private String parseOutput(String llmOutput, String indexName) {
        String[] lists;
        String lastCommand;
        Object ppl;
        Pattern pattern = Pattern.compile("<ppl>((.|[\\r\\n])+?)</ppl>");
        Matcher matcher = pattern.matcher(llmOutput);
        if (matcher.find()) {
            ppl = matcher.group(1).replaceAll("[\\r\\n]", "").replaceAll("ISNOTNULL", "isnotnull").trim();
        } else {
            int sourceIndex = llmOutput.indexOf("source=");
            int describeIndex = llmOutput.indexOf("describe ");
            if (sourceIndex != -1) {
                CharSequence[] lists2 = (llmOutput = llmOutput.substring(sourceIndex)).split("\\|");
                if (lists2.length > 0) {
                    lists2[0] = "source=" + indexName;
                }
                ppl = String.join((CharSequence)"|", lists2);
            } else if (describeIndex != -1) {
                CharSequence[] lists3 = (llmOutput = llmOutput.substring(describeIndex)).split("\\|");
                if (lists3.length > 0) {
                    lists3[0] = "describe " + indexName;
                }
                ppl = String.join((CharSequence)"|", lists3);
            } else {
                throw new IllegalArgumentException("The returned PPL: " + llmOutput + " has wrong format");
            }
        }
        if (this.pplModelType != PPLModelType.FINETUNE) {
            ppl = ((String)ppl).replace("`", "");
        }
        ppl = ((String)ppl).replaceAll("\\bSPAN\\(", "span(");
        if (this.head > 0 && !(lastCommand = (lists = llmOutput.split("\\|"))[lists.length - 1].strip()).toLowerCase(Locale.ROOT).startsWith("head")) {
            ppl = (String)ppl + " | head " + this.head;
        }
        return ppl;
    }

    private String getIndexNameFromParameters(Map<String, String> parameters) {
        String indexName = parameters.getOrDefault("index", "");
        if (!org.apache.commons.lang3.StringUtils.isBlank((CharSequence)this.previousToolKey) && org.apache.commons.lang3.StringUtils.isBlank((CharSequence)indexName)) {
            indexName = parameters.getOrDefault(this.previousToolKey + ".output", "");
        }
        return indexName.trim();
    }

    /*
     * Enabled aggressive block sorting
     * Enabled unnecessary exception pruning
     * Enabled aggressive exception aggregation
     */
    private static Map<String, String> loadDefaultPromptDict() {
        try (InputStream searchResponseIns = PPLTool.class.getResourceAsStream("PPLDefaultPrompt.json");){
            if (searchResponseIns == null) return new HashMap<String, String>();
            String defaultPromptContent = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8);
            Map map = (Map)gson.fromJson(defaultPromptContent, Map.class);
            return map;
        }
        catch (IOException e) {
            log.error("Failed to load default prompt dict", (Throwable)e);
        }
        return new HashMap<String, String>();
    }

    @Generated
    public void setVersion(String version) {
        this.version = version;
    }

    @Generated
    public void setModelId(String modelId) {
        this.modelId = modelId;
    }

    @Generated
    public void setContextPrompt(String contextPrompt) {
        this.contextPrompt = contextPrompt;
    }

    @Generated
    public void setExecute(Boolean execute) {
        this.execute = execute;
    }

    @Generated
    public void setPplModelType(PPLModelType pplModelType) {
        this.pplModelType = pplModelType;
    }

    @Generated
    public void setPreviousToolKey(String previousToolKey) {
        this.previousToolKey = previousToolKey;
    }

    @Generated
    public void setHead(int head) {
        this.head = head;
    }

    @Generated
    public Client getClient() {
        return this.client;
    }

    @Generated
    public String getModelId() {
        return this.modelId;
    }

    @Generated
    public String getContextPrompt() {
        return this.contextPrompt;
    }

    @Generated
    public Boolean getExecute() {
        return this.execute;
    }

    @Generated
    public PPLModelType getPplModelType() {
        return this.pplModelType;
    }

    @Generated
    public String getPreviousToolKey() {
        return this.previousToolKey;
    }

    @Generated
    public int getHead() {
        return this.head;
    }

    @Generated
    public void setClient(Client client) {
        this.client = client;
    }

    @Generated
    public void setName(String name) {
        this.name = name;
    }

    @Generated
    public String getDescription() {
        return this.description;
    }

    @Generated
    public void setDescription(String description) {
        this.description = description;
    }

    @Generated
    public String getVersion() {
        return this.version;
    }

    private static /* synthetic */ String lambda$run$1(Map returnResults) throws Exception {
        return gson.toJson((Object)returnResults);
    }

    private static /* synthetic */ String lambda$run$0(Map ret) throws Exception {
        return gson.toJson((Object)ret);
    }

    static {
        ALLOWED_FIELDS_TYPE = new HashSet<String>();
        ALLOWED_FIELDS_TYPE.add("boolean");
        ALLOWED_FIELDS_TYPE.add("byte");
        ALLOWED_FIELDS_TYPE.add("short");
        ALLOWED_FIELDS_TYPE.add("integer");
        ALLOWED_FIELDS_TYPE.add("long");
        ALLOWED_FIELDS_TYPE.add("float");
        ALLOWED_FIELDS_TYPE.add("half_float");
        ALLOWED_FIELDS_TYPE.add("scaled_float");
        ALLOWED_FIELDS_TYPE.add("double");
        ALLOWED_FIELDS_TYPE.add("keyword");
        ALLOWED_FIELDS_TYPE.add("text");
        ALLOWED_FIELDS_TYPE.add("date");
        ALLOWED_FIELDS_TYPE.add("date_nanos");
        ALLOWED_FIELDS_TYPE.add("ip");
        ALLOWED_FIELDS_TYPE.add("binary");
        ALLOWED_FIELDS_TYPE.add("object");
        ALLOWED_FIELDS_TYPE.add("nested");
        ALLOWED_FIELDS_TYPE.add("geo_point");
        DEFAULT_PROMPT_DICT = PPLTool.loadDefaultPromptDict();
    }

    public static enum PPLModelType {
        CLAUDE,
        FINETUNE,
        OPENAI;


        public static PPLModelType from(String value) {
            if (value.isEmpty()) {
                return CLAUDE;
            }
            try {
                return PPLModelType.valueOf(value.toUpperCase(Locale.ROOT));
            }
            catch (Exception e) {
                log.error("Wrong PPL Model type, should be CLAUDE, FINETUNE, or OPENAI");
                return CLAUDE;
            }
        }
    }

    public static class Factory
    implements Tool.Factory<PPLTool> {
        private Client client;
        private static Factory INSTANCE;

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        public static Factory getInstance() {
            if (INSTANCE != null) {
                return INSTANCE;
            }
            Class<PPLTool> clazz = PPLTool.class;
            synchronized (PPLTool.class) {
                if (INSTANCE != null) {
                    // ** MonitorExit[var0] (shouldn't be in output)
                    return INSTANCE;
                }
                INSTANCE = new Factory();
                // ** MonitorExit[var0] (shouldn't be in output)
                return INSTANCE;
            }
        }

        public void init(Client client) {
            this.client = client;
        }

        public PPLTool create(Map<String, Object> map) {
            PPLTool.validatePPLToolParameters(map);
            return new PPLTool(this.client, (String)map.get("model_id"), (String)map.getOrDefault("prompt", ""), (String)map.getOrDefault("model_type", ""), (String)map.getOrDefault("previous_tool_name", ""), NumberUtils.toInt((String)((String)map.get("head")), (int)-1), Boolean.parseBoolean((String)map.getOrDefault("execute", "true")));
        }

        public String getDefaultDescription() {
            return PPLTool.DEFAULT_DESCRIPTION;
        }

        public String getDefaultType() {
            return PPLTool.TYPE;
        }

        public String getDefaultVersion() {
            return null;
        }
    }
}

