/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.memory;

import java.time.Instant;
import java.time.format.DateTimeParseException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.ml.common.MLMemoryType;
import org.opensearch.ml.common.conversation.Interaction;
import org.opensearch.ml.common.memory.Memory;
import org.opensearch.ml.common.memory.Message;
import org.opensearch.ml.common.memorycontainer.MLWorkingMemory;
import org.opensearch.ml.common.memorycontainer.MemoryType;
import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesAction;
import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesInput;
import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesRequest;
import org.opensearch.ml.common.transport.memorycontainer.memory.MLGetMemoryAction;
import org.opensearch.ml.common.transport.memorycontainer.memory.MLGetMemoryRequest;
import org.opensearch.ml.common.transport.memorycontainer.memory.MLSearchMemoriesAction;
import org.opensearch.ml.common.transport.memorycontainer.memory.MLSearchMemoriesInput;
import org.opensearch.ml.common.transport.memorycontainer.memory.MLSearchMemoriesRequest;
import org.opensearch.ml.common.transport.memorycontainer.memory.MLUpdateMemoryAction;
import org.opensearch.ml.common.transport.memorycontainer.memory.MLUpdateMemoryInput;
import org.opensearch.ml.common.transport.memorycontainer.memory.MLUpdateMemoryRequest;
import org.opensearch.ml.common.transport.session.MLCreateSessionAction;
import org.opensearch.ml.common.transport.session.MLCreateSessionInput;
import org.opensearch.ml.common.transport.session.MLCreateSessionRequest;
import org.opensearch.ml.engine.memory.ConversationIndexMessage;
import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.sort.SortOrder;
import org.opensearch.transport.client.Client;

/**
 * {@link Memory} implementation backed by the working memory of an ML memory container.
 *
 * <p>A conversation maps to a session: every record written by this class carries the
 * namespace {@code session_id = conversationId}. Final agent messages are stored with
 * {@code metadata.type = "message"}; intermediate agent steps are stored with
 * {@code metadata.type = "trace"} plus a parent message id and a trace number. The
 * question/answer payload lives in each record's {@code structured_data_blob}.
 */
public class AgenticConversationMemory implements Memory<Message, CreateInteractionResponse, UpdateResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(AgenticConversationMemory.class);

    public static final String TYPE = MLMemoryType.AGENTIC_MEMORY.name();

    // Top-level fields of a working-memory document.
    private static final String SESSION_ID_FIELD = "session_id";
    private static final String CREATED_TIME_FIELD = "created_time";
    private static final String LAST_UPDATED_TIME_FIELD = "last_updated_time";
    private static final String MESSAGE_ID_FIELD = "message_id";
    private static final String METADATA_FIELD = "metadata";
    private static final String STRUCTURED_DATA_BLOB_FIELD = "structured_data_blob";

    // Keys used inside structured_data_blob and/or metadata.
    private static final String INPUT_KEY = "input";
    private static final String RESPONSE_KEY = "response";
    private static final String FINAL_ANSWER_KEY = "final_answer";
    private static final String PARENT_MESSAGE_ID_KEY = "parent_message_id";
    private static final String TRACE_NUMBER_KEY = "trace_number";
    private static final String ORIGIN_KEY = "origin";
    private static final String TYPE_KEY = "type";
    private static final String CREATE_TIME_KEY = "create_time";
    private static final String UPDATED_TIME_KEY = "updated_time";

    private static final String TYPE_TRACE = "trace";
    private static final String TYPE_MESSAGE = "message";

    /** Upper bound on the number of traces fetched for a single parent message. */
    private static final int MAX_TRACES_PER_MESSAGE = 1000;

    private final Client client;
    private final String conversationId;
    private final String memoryContainerId;

    /**
     * @param client            transport client used to execute memory-container actions
     * @param memoryId          session id of the conversation (exposed via {@link #getId()})
     * @param memoryContainerId id of the memory container holding the working memories
     */
    public AgenticConversationMemory(Client client, String memoryId, String memoryContainerId) {
        this.client = client;
        this.conversationId = memoryId;
        this.memoryContainerId = memoryContainerId;
    }

    /** Returns the memory type name, {@code AGENTIC_MEMORY}. */
    public String getType() {
        return TYPE;
    }

    /** Returns the session (conversation) id this memory reads from and writes to. */
    public String getId() {
        return this.conversationId;
    }

    /**
     * Fire-and-forget variant of {@link #save(Message, String, Integer, String, ActionListener)}:
     * the outcome is only logged.
     */
    public void save(Message message, String parentId, Integer traceNum, String action) {
        save(
            message,
            parentId,
            traceNum,
            action,
            ActionListener
                .wrap(
                    r -> log
                        .info(
                            "Saved message to agentic memory, session id: {}, working memory id: {}",
                            this.conversationId,
                            r.getId()
                        ),
                    e -> log.error("Failed to save message to agentic memory", e)
                )
        );
    }

    /**
     * Stores a message or a trace as a working memory in the configured container.
     *
     * <p>A non-null {@code traceNum} marks the record as a trace. For traces, {@code parentId}
     * and {@code action} are stored when non-null. For regular messages, the final answer is
     * stored when present. {@code traceNum} is also passed as the record's message id.
     *
     * @param message  message to store; must be a {@link ConversationIndexMessage}
     * @param parentId id of the parent message (traces only; may be null)
     * @param traceNum trace number, or null for a regular message
     * @param action   origin of the trace step (traces only; may be null)
     * @param listener receives a response carrying the new working memory id, or the failure.
     *                 Failures include a missing container id, an unsupported message type and
     *                 transport errors.
     */
    public void save(Message message, String parentId, Integer traceNum, String action, ActionListener<CreateInteractionResponse> listener) {
        if (Strings.isNullOrEmpty(memoryContainerId)) {
            listener
                .onFailure(
                    new IllegalStateException(
                        "Memory container ID is not configured for this AgenticConversationMemory. Cannot save messages without a valid memory container."
                    )
                );
            return;
        }
        // Report unsupported message types through the listener instead of throwing a
        // ClassCastException out of an asynchronous API.
        if (!(message instanceof ConversationIndexMessage)) {
            listener
                .onFailure(
                    new IllegalArgumentException(
                        "AgenticConversationMemory only supports ConversationIndexMessage, got: "
                            + (message == null ? "null" : message.getClass().getName())
                    )
                );
            return;
        }
        ConversationIndexMessage msg = (ConversationIndexMessage) message;

        Map<String, String> namespace = new HashMap<>();
        namespace.put(SESSION_ID_FIELD, conversationId);

        Map<String, String> metadata = new HashMap<>();
        Map<String, Object> structuredData = new HashMap<>();
        structuredData.put(INPUT_KEY, msg.getQuestion() != null ? msg.getQuestion() : "");
        structuredData.put(RESPONSE_KEY, msg.getResponse() != null ? msg.getResponse() : "");

        if (traceNum != null) {
            metadata.put(TYPE_KEY, TYPE_TRACE);
            if (parentId != null) {
                // Stored in metadata for querying (see getTraces) and in the blob for parsing.
                metadata.put(PARENT_MESSAGE_ID_KEY, parentId);
                structuredData.put(PARENT_MESSAGE_ID_KEY, parentId);
            }
            metadata.put(TRACE_NUMBER_KEY, String.valueOf(traceNum));
            structuredData.put(TRACE_NUMBER_KEY, traceNum);
            if (action != null) {
                metadata.put(ORIGIN_KEY, action);
                structuredData.put(ORIGIN_KEY, action);
            }
        } else {
            metadata.put(TYPE_KEY, TYPE_MESSAGE);
            if (msg.getFinalAnswer() != null) {
                structuredData.put(FINAL_ANSWER_KEY, msg.getFinalAnswer());
            }
        }

        String now = Instant.now().toString();
        structuredData.put(CREATE_TIME_KEY, now);
        structuredData.put(UPDATED_TIME_KEY, now);

        MLAddMemoriesInput input = MLAddMemoriesInput
            .builder()
            .memoryContainerId(memoryContainerId)
            .structuredDataBlob(structuredData)
            .messageId(traceNum)
            .namespace(namespace)
            .metadata(metadata)
            .infer(false)
            .build();
        MLAddMemoriesRequest request = MLAddMemoriesRequest.builder().mlAddMemoryInput(input).build();

        client.execute(MLAddMemoriesAction.INSTANCE, request, ActionListener.wrap(response -> {
            listener.onResponse(new CreateInteractionResponse(response.getWorkingMemoryId()));
        }, e -> {
            log.error("Failed to add memories to memory container", e);
            listener.onFailure(e);
        }));
    }

    /**
     * Merges {@code updateContent} into the structured data of an existing working memory.
     *
     * <p>The existing record is fetched first. Its {@code structured_data_blob} is copied, and
     * the entries from {@code updateContent} overwrite matching keys. {@code updated_time} is
     * refreshed, and the merged blob is written back.
     *
     * @param messageId      working memory id to update
     * @param updateContent  entries to merge into the structured data blob; must not be null
     * @param updateListener receives the update result or the failure
     */
    public void update(String messageId, Map<String, Object> updateContent, ActionListener<UpdateResponse> updateListener) {
        if (Strings.isNullOrEmpty(memoryContainerId)) {
            updateListener.onFailure(new IllegalStateException("Memory container ID is not configured for this AgenticConversationMemory"));
            return;
        }
        if (Strings.isNullOrEmpty(messageId)) {
            updateListener.onFailure(new IllegalArgumentException("Message ID is required to update agentic memory"));
            return;
        }
        if (updateContent == null) {
            // Fail fast rather than NPE after a wasted get round-trip.
            updateListener.onFailure(new IllegalArgumentException("Update content must not be null"));
            return;
        }

        MLGetMemoryRequest getRequest = MLGetMemoryRequest
            .builder()
            .memoryContainerId(memoryContainerId)
            .memoryType(MemoryType.WORKING)
            .memoryId(messageId)
            .build();

        client.execute(MLGetMemoryAction.INSTANCE, getRequest, ActionListener.wrap(getResponse -> {
            MLWorkingMemory workingMemory = getResponse.getWorkingMemory();
            if (workingMemory == null) {
                updateListener.onFailure(new IllegalStateException("Working memory not found for id: " + messageId));
                return;
            }

            // Copy so the fetched object is never mutated.
            Map<String, Object> existing = workingMemory.getStructuredDataBlob();
            Map<String, Object> structuredData = existing == null ? new HashMap<>() : new HashMap<>(existing);
            structuredData.putAll(updateContent);
            structuredData.put(UPDATED_TIME_KEY, Instant.now().toString());

            Map<String, Object> finalUpdateContent = new HashMap<>();
            finalUpdateContent.put(STRUCTURED_DATA_BLOB_FIELD, structuredData);

            MLUpdateMemoryInput input = MLUpdateMemoryInput.builder().updateContent(finalUpdateContent).build();
            MLUpdateMemoryRequest updateRequest = MLUpdateMemoryRequest
                .builder()
                .memoryContainerId(memoryContainerId)
                .memoryType(MemoryType.WORKING)
                .memoryId(messageId)
                .mlUpdateMemoryInput(input)
                .build();

            client.execute(MLUpdateMemoryAction.INSTANCE, updateRequest, ActionListener.wrap(indexResponse -> {
                UpdateResponse updateResponse = new UpdateResponse(
                    indexResponse.getShardInfo(),
                    indexResponse.getShardId(),
                    indexResponse.getId(),
                    indexResponse.getSeqNo(),
                    indexResponse.getPrimaryTerm(),
                    indexResponse.getVersion(),
                    indexResponse.getResult()
                );
                updateListener.onResponse(updateResponse);
            }, e -> {
                log.error("Failed to update memory in memory container", e);
                updateListener.onFailure(e);
            }));
        }, e -> {
            log.error("Failed to get existing memory for update", e);
            updateListener.onFailure(e);
        }));
    }

    /**
     * Fetches up to {@code size} non-trace messages of this session, sorted by
     * {@code created_time} ascending.
     *
     * @param size     maximum number of messages to return
     * @param listener receives the parsed messages (as {@link Interaction}s) or the failure
     */
    public void getMessages(int size, ActionListener<List<Message>> listener) {
        if (Strings.isNullOrEmpty(memoryContainerId)) {
            listener.onFailure(new IllegalStateException("Memory container ID is not configured for this AgenticConversationMemory"));
            return;
        }
        BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
        boolQuery.must(QueryBuilders.termQuery("namespace.session_id", conversationId));
        boolQuery.mustNot(QueryBuilders.termQuery("metadata.type", TYPE_TRACE));

        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
        searchSourceBuilder.query(boolQuery);
        searchSourceBuilder.size(size);
        searchSourceBuilder.sort(CREATED_TIME_FIELD, SortOrder.ASC);

        MLSearchMemoriesInput searchInput = MLSearchMemoriesInput
            .builder()
            .memoryContainerId(memoryContainerId)
            .memoryType(MemoryType.WORKING)
            .searchSourceBuilder(searchSourceBuilder)
            .build();
        MLSearchMemoriesRequest request = MLSearchMemoriesRequest.builder().mlSearchMemoriesInput(searchInput).tenantId(null).build();

        client.execute(MLSearchMemoriesAction.INSTANCE, request, ActionListener.wrap(searchResponse -> {
            listener.onResponse(parseSearchResponseToInteractions(searchResponse));
        }, e -> {
            log.error("Failed to search memories in memory container", e);
            listener.onFailure(e);
        }));
    }

    /**
     * Converts message hits into {@link Interaction}s. Hits without a structured data blob, or
     * with neither input nor response, are skipped.
     */
    private List<Message> parseSearchResponseToInteractions(SearchResponse searchResponse) {
        List<Message> interactions = new ArrayList<>();
        for (SearchHit hit : searchResponse.getHits().getHits()) {
            Map<String, Object> sourceMap = hit.getSourceAsMap();
            Map<String, Object> structuredData = asMap(sourceMap.get(STRUCTURED_DATA_BLOB_FIELD));
            if (structuredData == null) {
                continue;
            }
            String input = asString(structuredData.get(INPUT_KEY));
            String response = asString(structuredData.get(RESPONSE_KEY));
            if (input == null && response == null) {
                continue;
            }

            Instant createTime = resolveTimestamp(structuredData, CREATE_TIME_KEY, sourceMap.get(CREATED_TIME_FIELD));
            Instant updatedTime = resolveTimestamp(structuredData, UPDATED_TIME_KEY, sourceMap.get(LAST_UPDATED_TIME_FIELD));

            Map<String, Object> metadata = asMap(sourceMap.get(METADATA_FIELD));
            String parentInteractionId = metadata != null ? asString(metadata.get(PARENT_MESSAGE_ID_KEY)) : null;

            Interaction interaction = Interaction
                .builder()
                .id(hit.getId())
                .conversationId(conversationId)
                .createTime(createTime != null ? createTime : Instant.now())
                .updatedTime(updatedTime)
                .input(input != null ? input : "")
                .response(response != null ? response : "")
                .origin("agentic_memory")
                .promptTemplate(null)
                .additionalInfo(null)
                .parentInteractionId(parentInteractionId)
                .traceNum(null)
                .build();
            interactions.add(interaction);
        }
        return interactions;
    }

    /** Not supported: always throws {@link UnsupportedOperationException}. */
    public void clear() {
        throw new UnsupportedOperationException("clear method is not supported in AgenticConversationMemory");
    }

    /**
     * Not implemented for this memory type: logs a warning and responds {@code false} without
     * deleting anything.
     */
    public void deleteInteractionAndTrace(String interactionId, ActionListener<Boolean> listener) {
        log.warn("deleteInteractionAndTrace is not fully implemented for AgenticConversationMemory");
        listener.onResponse(false);
    }

    /**
     * Fetches the trace records of this session whose {@code metadata.parent_message_id} equals
     * {@code parentMessageId}, sorted by {@code message_id} ascending. At most
     * {@value #MAX_TRACES_PER_MESSAGE} are returned.
     *
     * @param parentMessageId id of the message whose traces are requested; must not be null
     * @param listener        receives the parsed traces or the failure
     */
    public void getTraces(String parentMessageId, ActionListener<List<Interaction>> listener) {
        if (Strings.isNullOrEmpty(memoryContainerId)) {
            listener.onFailure(new IllegalStateException("Memory container ID is not configured for this AgenticConversationMemory"));
            return;
        }
        if (parentMessageId == null) {
            // termQuery rejects null values; report it through the listener instead of throwing.
            listener.onFailure(new IllegalArgumentException("Parent message ID is required to fetch traces"));
            return;
        }
        BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
        boolQuery.must(QueryBuilders.termQuery("namespace.session_id", conversationId));
        boolQuery.must(QueryBuilders.termQuery("metadata.type", TYPE_TRACE));
        boolQuery.must(QueryBuilders.termQuery("metadata.parent_message_id", parentMessageId));

        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
        searchSourceBuilder.query(boolQuery);
        searchSourceBuilder.size(MAX_TRACES_PER_MESSAGE);
        searchSourceBuilder.sort(MESSAGE_ID_FIELD, SortOrder.ASC);

        MLSearchMemoriesInput searchInput = MLSearchMemoriesInput
            .builder()
            .memoryContainerId(memoryContainerId)
            .memoryType(MemoryType.WORKING)
            .searchSourceBuilder(searchSourceBuilder)
            .build();
        MLSearchMemoriesRequest request = MLSearchMemoriesRequest.builder().mlSearchMemoriesInput(searchInput).tenantId(null).build();

        client.execute(MLSearchMemoriesAction.INSTANCE, request, ActionListener.wrap(searchResponse -> {
            listener.onResponse(parseSearchResponseToTraces(searchResponse));
        }, e -> {
            log.error("Failed to search traces in memory container", e);
            listener.onFailure(e);
        }));
    }

    /**
     * Converts trace hits into {@link Interaction}s carrying origin, parent id and trace number.
     * Hits without a structured data blob, or with neither input nor response, are skipped.
     */
    private List<Interaction> parseSearchResponseToTraces(SearchResponse searchResponse) {
        List<Interaction> traces = new ArrayList<>();
        for (SearchHit hit : searchResponse.getHits().getHits()) {
            Map<String, Object> sourceMap = hit.getSourceAsMap();
            Map<String, Object> structuredData = asMap(sourceMap.get(STRUCTURED_DATA_BLOB_FIELD));
            if (structuredData == null) {
                continue;
            }
            String input = asString(structuredData.get(INPUT_KEY));
            String response = asString(structuredData.get(RESPONSE_KEY));
            if (input == null && response == null) {
                continue;
            }
            String origin = asString(structuredData.get(ORIGIN_KEY));
            String parentMessageId = asString(structuredData.get(PARENT_MESSAGE_ID_KEY));

            // Prefer the trace number from the blob; fall back to the top-level message_id.
            Integer traceNum = parseTraceNumber(structuredData.get(TRACE_NUMBER_KEY));
            Object messageId = sourceMap.get(MESSAGE_ID_FIELD);
            if (traceNum == null && messageId instanceof Number) {
                traceNum = ((Number) messageId).intValue();
            }

            Instant createTime = resolveTimestamp(structuredData, CREATE_TIME_KEY, sourceMap.get(CREATED_TIME_FIELD));
            Instant updatedTime = resolveTimestamp(structuredData, UPDATED_TIME_KEY, sourceMap.get(LAST_UPDATED_TIME_FIELD));

            Interaction trace = Interaction
                .builder()
                .id(hit.getId())
                .conversationId(conversationId)
                .createTime(createTime != null ? createTime : Instant.now())
                .updatedTime(updatedTime)
                .input(input != null ? input : "")
                .response(response != null ? response : "")
                .origin(origin != null ? origin : "")
                .promptTemplate(null)
                .additionalInfo(null)
                .parentInteractionId(parentMessageId)
                .traceNum(traceNum)
                .build();
            traces.add(trace);
        }
        return traces;
    }

    /**
     * Resolves a timestamp. The ISO-8601 string under {@code isoKey} in the structured data is
     * preferred; otherwise {@code epochMillis} is used when it is numeric.
     *
     * @return the resolved instant, or null if neither source is usable
     */
    private static Instant resolveTimestamp(Map<String, Object> structuredData, String isoKey, Object epochMillis) {
        String iso = asString(structuredData.get(isoKey));
        if (iso != null) {
            try {
                return Instant.parse(iso);
            } catch (DateTimeParseException e) {
                log.warn("Failed to parse {} from structured_data", isoKey, e);
            }
        }
        // Parsed JSON numbers may be Integer or Long depending on magnitude.
        return epochMillis instanceof Number ? Instant.ofEpochMilli(((Number) epochMillis).longValue()) : null;
    }

    /** Reads a trace number stored either as a number or as a numeric string; null otherwise. */
    private static Integer parseTraceNumber(Object value) {
        if (value instanceof Number) {
            return ((Number) value).intValue();
        }
        if (value instanceof String) {
            try {
                return Integer.parseInt((String) value);
            } catch (NumberFormatException e) {
                log.warn("Failed to parse trace_number", e);
            }
        }
        return null;
    }

    /** Returns {@code value} as a map when it is one; otherwise null. */
    @SuppressWarnings("unchecked") // _source maps produced by the search layer are keyed by String
    private static Map<String, Object> asMap(Object value) {
        return value instanceof Map ? (Map<String, Object>) value : null;
    }

    /** Returns {@code value} as a string (via {@code toString} for non-strings), or null. */
    private static String asString(Object value) {
        return value == null ? null : value.toString();
    }

    @Generated
    public Client getClient() {
        return this.client;
    }

    @Generated
    public String getConversationId() {
        return this.conversationId;
    }

    @Generated
    public String getMemoryContainerId() {
        return this.memoryContainerId;
    }

    /**
     * Factory that builds {@link AgenticConversationMemory} instances. When no memory id is
     * supplied, it first creates a new session in the memory container.
     * {@link #init(Client)} must be called before any {@code create} method.
     */
    public static class Factory implements Memory.Factory<AgenticConversationMemory> {
        private Client client;

        /** Sets the transport client used by created memories and for session creation. */
        public void init(Client client) {
            this.client = client;
        }

        /**
         * Creates a memory from a parameter map. The map is read for {@code memory_id},
         * {@code memory_name}, {@code app_type} and {@code memory_container_id}.
         */
        public void create(Map<String, Object> map, ActionListener<AgenticConversationMemory> listener) {
            if (map == null || map.isEmpty()) {
                listener.onFailure(new IllegalArgumentException("Invalid input parameter for creating AgenticConversationMemory"));
                return;
            }
            String memoryId = (String) map.get("memory_id");
            String name = (String) map.get("memory_name");
            String appType = (String) map.get("app_type");
            String memoryContainerId = (String) map.get("memory_container_id");
            create(name, memoryId, appType, memoryContainerId, listener);
        }

        /**
         * Creates a memory bound to {@code memoryId}, or to a newly created session (using
         * {@code name} as its summary) when {@code memoryId} is null or empty. {@code appType}
         * is currently unused.
         */
        public void create(String name, String memoryId, String appType, String memoryContainerId, ActionListener<AgenticConversationMemory> listener) {
            if (Strings.isNullOrEmpty(memoryContainerId)) {
                listener
                    .onFailure(
                        new IllegalArgumentException(
                            "Memory container ID is required for AgenticConversationMemory. Please provide 'memory_container_id' in the agent configuration."
                        )
                    );
                return;
            }
            if (Strings.isNullOrEmpty(memoryId)) {
                createSessionInMemoryContainer(name, memoryContainerId, ActionListener.wrap(sessionId -> {
                    log.debug("Created session in memory container, session id: {}", sessionId);
                    create(sessionId, memoryContainerId, listener);
                }, e -> {
                    log.error("Failed to create session in memory container", e);
                    listener.onFailure(e);
                }));
            } else {
                create(memoryId, memoryContainerId, listener);
            }
        }

        /** Creates a session in the container and reports its id to {@code listener}. */
        private void createSessionInMemoryContainer(String summary, String memoryContainerId, ActionListener<String> listener) {
            MLCreateSessionInput input = MLCreateSessionInput.builder().memoryContainerId(memoryContainerId).summary(summary).build();
            MLCreateSessionRequest request = MLCreateSessionRequest.builder().mlCreateSessionInput(input).build();
            client
                .execute(
                    MLCreateSessionAction.INSTANCE,
                    request,
                    ActionListener.wrap(response -> listener.onResponse(response.getSessionId()), e -> {
                        log.error("Failed to create session via TransportCreateSessionAction", e);
                        listener.onFailure(e);
                    })
                );
        }

        /** Immediately responds with a memory bound to the given session and container. */
        public void create(String memoryId, String memoryContainerId, ActionListener<AgenticConversationMemory> listener) {
            listener.onResponse(new AgenticConversationMemory(client, memoryId, memoryContainerId));
        }
    }
}

