/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.action.prediction;

import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.task.MLPredictTaskRunner;
import org.opensearch.ml.task.MLTaskRunner;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

/**
 * Transport action backing {@code cluster:admin/opensearch/ml/predict}.
 *
 * <p>For each request it resolves the calling user, loads the target model, checks that the user
 * may access the model's model group, and then delegates the prediction to the predict task
 * runner. The wall-clock time spent in the task runner is reported per model to
 * {@link MLModelCacheHelper#addPredictRequestDuration}.
 */
public class TransportPredictionTaskAction extends HandledTransportAction<ActionRequest, MLTaskResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(TransportPredictionTaskAction.class);
    private final MLTaskRunner<MLPredictionTaskRequest, MLTaskResponse> mlPredictTaskRunner;
    private final TransportService transportService;
    private final MLModelCacheHelper modelCacheHelper;
    private final Client client;
    private final ClusterService clusterService;
    private final NamedXContentRegistry xContentRegistry;
    private final MLModelManager mlModelManager;
    private final ModelAccessControlHelper modelAccessControlHelper;

    /**
     * Creates the action and registers it under {@code cluster:admin/opensearch/ml/predict}.
     * Incoming transport requests are deserialized with {@link MLPredictionTaskRequest}'s
     * stream constructor.
     */
    @Inject
    public TransportPredictionTaskAction(TransportService transportService, ActionFilters actionFilters, MLPredictTaskRunner mlPredictTaskRunner, MLModelCacheHelper modelCacheHelper, ClusterService clusterService, Client client, NamedXContentRegistry xContentRegistry, MLModelManager mlModelManager, ModelAccessControlHelper modelAccessControlHelper) {
        super("cluster:admin/opensearch/ml/predict", transportService, actionFilters, MLPredictionTaskRequest::new);
        this.mlPredictTaskRunner = mlPredictTaskRunner;
        this.transportService = transportService;
        this.modelCacheHelper = modelCacheHelper;
        this.clusterService = clusterService;
        this.client = client;
        this.xContentRegistry = xContentRegistry;
        this.mlModelManager = mlModelManager;
        this.modelAccessControlHelper = modelAccessControlHelper;
    }

    /**
     * Validates that the caller may use the requested model and, if so, runs the prediction.
     *
     * <p>The model lookup and access check run inside a stashed thread context. The caller's
     * context is restored before {@code listener} is notified, on the success path and on every
     * failure path. Without that, the response would be delivered under the stashed (emptied)
     * context instead of the original request context.
     *
     * @param task the task registered for this request
     * @param request the incoming request, converted with
     *     {@link MLPredictionTaskRequest#fromActionRequest}
     * @param listener completed with the prediction result. It fails with an
     *     {@link MLValidationException} when the user lacks access to the model's group, or with
     *     the error raised by the model lookup, the access check, or the task runner.
     */
    @Override
    protected void doExecute(Task task, ActionRequest request, ActionListener<MLTaskResponse> listener) {
        MLPredictionTaskRequest mlPredictionTaskRequest = MLPredictionTaskRequest.fromActionRequest(request);
        String modelId = mlPredictionTaskRequest.getModelId();
        User user = mlPredictionTaskRequest.getUser();
        if (user == null) {
            // Resolve the user before stashing, while the caller's thread context is still active.
            user = RestActionUtils.getUserContext(client);
            mlPredictionTaskRequest.setUser(user);
        }
        final User userInfo = user;
        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
            // Every completion below goes through this wrapper so the caller's context is back in
            // place before the response or failure is handed to the original listener.
            ActionListener<MLTaskResponse> wrappedListener = ActionListener.runBefore(listener, context::restore);
            mlModelManager.getModel(modelId, ActionListener.<MLModel>wrap(mlModel -> {
                FunctionName functionName = mlModel.getAlgorithm();
                modelAccessControlHelper.validateModelGroupAccess(userInfo, mlModel.getModelGroupId(), client, ActionListener.<Boolean>wrap(access -> {
                    if (!access) {
                        wrappedListener.onFailure(new MLValidationException("User Doesn't have privilege to perform this operation on this model"));
                        return;
                    }
                    executePredict(functionName, mlPredictionTaskRequest, modelId, wrappedListener);
                }, e -> {
                    log.error("Failed to Validate Access for ModelId {}", modelId, e);
                    wrappedListener.onFailure(e);
                }));
            }, e -> {
                log.error("Failed to find model {}", modelId, e);
                wrappedListener.onFailure(e);
            }));
        }
    }

    /**
     * Runs the prediction through the task runner and records how long it took for {@code modelId}.
     *
     * <p>The duration is measured from just before the runner is invoked until {@code listener}
     * has been notified. It is recorded on success and on failure, because
     * {@link ActionListener#runAfter} runs the recording step after either outcome.
     *
     * @param functionName algorithm of the resolved model
     * @param predictionRequest the prediction request to execute
     * @param modelId id of the model the duration is attributed to
     * @param listener notified with the prediction result or failure
     */
    private void executePredict(FunctionName functionName, MLPredictionTaskRequest predictionRequest, String modelId, ActionListener<MLTaskResponse> listener) {
        String requestId = predictionRequest.getRequestID();
        log.debug("receive predict request {} for model {}", requestId, predictionRequest.getModelId());
        long startTime = System.nanoTime();
        mlPredictTaskRunner.run(functionName, predictionRequest, transportService, ActionListener.runAfter(listener, () -> {
            double durationInMs = (System.nanoTime() - startTime) / 1_000_000.0;
            modelCacheHelper.addPredictRequestDuration(modelId, durationInMs);
            log.debug("completed predict request {} for model {}", requestId, modelId);
        }));
    }
}

