/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.adaptiveallocations;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import org.apache.logging.log4j.Level;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.telemetry.metric.DoubleWithAttributes;
import org.elasticsearch.telemetry.metric.LongWithAttributes;
import org.elasticsearch.telemetry.metric.MeterRegistry;
import org.elasticsearch.threadpool.Scheduler;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata;
import org.elasticsearch.xpack.ml.inference.adaptiveallocations.AdaptiveAllocationsScaler;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;

public class AdaptiveAllocationsScalerService
implements ClusterStateListener {
    private static final int DEFAULT_TIME_INTERVAL_SECONDS = 10;
    private static final long SCALE_UP_COOLDOWN_TIME_MILLIS = TimeValue.timeValueMinutes((long)5L).getMillis();
    private static final long SCALE_TO_ZERO_AFTER_NO_REQUESTS_TIME_SECONDS = TimeValue.timeValueMinutes((long)15L).getSeconds();
    private static final Logger logger = LogManager.getLogger(AdaptiveAllocationsScalerService.class);
    private final int timeIntervalSeconds;
    private final ThreadPool threadPool;
    private final ClusterService clusterService;
    private final Client client;
    private final InferenceAuditor inferenceAuditor;
    private final MeterRegistry meterRegistry;
    private final Metrics metrics;
    private final boolean isNlpEnabled;
    private final Map<String, Map<String, Stats>> lastInferenceStatsByDeploymentAndNode;
    private Long lastInferenceStatsTimestampMillis;
    private final Map<String, AdaptiveAllocationsScaler> scalers;
    private final Map<String, Long> lastScaleUpTimesMillis;
    private volatile Scheduler.Cancellable cancellable;
    private final AtomicBoolean busy;
    private final long scaleToZeroAfterNoRequestsSeconds;
    private final long scaleUpCooldownTimeMillis;
    private final Set<String> deploymentIdsWithInFlightScaleFromZeroRequests = new ConcurrentSkipListSet<String>();
    private final Map<String, String> lastWarningMessages = new ConcurrentHashMap<String, String>();

    public AdaptiveAllocationsScalerService(ThreadPool threadPool, ClusterService clusterService, Client client, InferenceAuditor inferenceAuditor, MeterRegistry meterRegistry, boolean isNlpEnabled) {
        this(threadPool, clusterService, client, inferenceAuditor, meterRegistry, isNlpEnabled, 10, SCALE_TO_ZERO_AFTER_NO_REQUESTS_TIME_SECONDS, SCALE_UP_COOLDOWN_TIME_MILLIS);
    }

    AdaptiveAllocationsScalerService(ThreadPool threadPool, ClusterService clusterService, Client client, InferenceAuditor inferenceAuditor, MeterRegistry meterRegistry, boolean isNlpEnabled, int timeIntervalSeconds, long scaleToZeroAfterNoRequestsSeconds, long scaleUpCooldownTimeMillis) {
        this.threadPool = threadPool;
        this.clusterService = clusterService;
        this.client = client;
        this.inferenceAuditor = inferenceAuditor;
        this.meterRegistry = meterRegistry;
        this.isNlpEnabled = isNlpEnabled;
        this.timeIntervalSeconds = timeIntervalSeconds;
        this.scaleToZeroAfterNoRequestsSeconds = scaleToZeroAfterNoRequestsSeconds;
        this.scaleUpCooldownTimeMillis = scaleUpCooldownTimeMillis;
        this.lastInferenceStatsByDeploymentAndNode = new HashMap<String, Map<String, Stats>>();
        this.lastInferenceStatsTimestampMillis = null;
        this.lastScaleUpTimesMillis = new HashMap<String, Long>();
        this.scalers = new HashMap<String, AdaptiveAllocationsScaler>();
        this.metrics = new Metrics();
        this.busy = new AtomicBoolean(false);
    }

    public synchronized void start() {
        this.updateAutoscalers(this.clusterService.state());
        this.metrics.init();
        this.clusterService.addListener((ClusterStateListener)this);
        if (!this.scalers.isEmpty()) {
            this.startScheduling();
        }
    }

    public synchronized void stop() {
        this.clusterService.removeListener((ClusterStateListener)this);
        this.stopScheduling();
        this.scalers.clear();
    }

    public void clusterChanged(ClusterChangedEvent event) {
        if (!event.metadataChanged()) {
            return;
        }
        this.updateAutoscalers(event.state());
        if (!this.scalers.isEmpty()) {
            this.startScheduling();
        } else {
            this.stopScheduling();
        }
    }

    private synchronized void updateAutoscalers(ClusterState state) {
        if (!this.isNlpEnabled) {
            return;
        }
        HashSet<String> deploymentIds = new HashSet<String>();
        TrainedModelAssignmentMetadata assignments = TrainedModelAssignmentMetadata.fromState((ClusterState)state);
        for (TrainedModelAssignment assignment : assignments.allAssignments().values()) {
            deploymentIds.add(assignment.getDeploymentId());
            if (assignment.getAdaptiveAllocationsSettings() != null && assignment.getAdaptiveAllocationsSettings().getEnabled() == Boolean.TRUE) {
                AdaptiveAllocationsScaler adaptiveAllocationsScaler = this.scalers.computeIfAbsent(assignment.getDeploymentId(), key -> new AdaptiveAllocationsScaler(assignment.getDeploymentId(), assignment.totalTargetAllocations(), this.scaleToZeroAfterNoRequestsSeconds));
                adaptiveAllocationsScaler.setMinMaxNumberOfAllocations(assignment.getAdaptiveAllocationsSettings().getMinNumberOfAllocations(), assignment.getAdaptiveAllocationsSettings().getMaxNumberOfAllocations());
                continue;
            }
            this.scalers.remove(assignment.getDeploymentId());
            this.lastInferenceStatsByDeploymentAndNode.remove(assignment.getDeploymentId());
        }
        this.scalers.keySet().removeIf(key -> !deploymentIds.contains(key));
    }

    private synchronized void startScheduling() {
        block3: {
            if (this.cancellable == null) {
                logger.debug("Starting ML adaptive allocations scaler");
                try {
                    this.cancellable = this.threadPool.scheduleWithFixedDelay(this::trigger, TimeValue.timeValueSeconds((long)this.timeIntervalSeconds), (Executor)this.threadPool.generic());
                }
                catch (EsRejectedExecutionException e) {
                    if (e.isExecutorShutdown()) break block3;
                    throw e;
                }
            }
        }
    }

    private synchronized void stopScheduling() {
        if (this.cancellable != null && !this.cancellable.isCancelled()) {
            logger.debug("Stopping ML adaptive allocations scaler");
            this.cancellable.cancel();
            this.cancellable = null;
        }
    }

    private void trigger() {
        if (this.busy.getAndSet(true)) {
            logger.debug("Skipping inference adaptive allocations scaling, because it's still busy.");
            return;
        }
        ActionListener listener = ActionListener.runAfter((ActionListener)ActionListener.wrap(this::processDeploymentStats, e -> logger.warn("Error in inference adaptive allocations scaling", (Throwable)e)), () -> this.busy.set(false));
        this.getDeploymentStats((ActionListener<GetDeploymentStatsAction.Response>)listener);
    }

    private void getDeploymentStats(ActionListener<GetDeploymentStatsAction.Response> processDeploymentStats) {
        String deploymentIds = String.join((CharSequence)",", this.scalers.keySet());
        ClientHelper.executeAsyncWithOrigin((Client)this.client, (String)"ml", (ActionType)GetDeploymentStatsAction.INSTANCE, (ActionRequest)new GetDeploymentStatsAction.Request(deploymentIds), processDeploymentStats);
    }

    private void processDeploymentStats(GetDeploymentStatsAction.Response statsResponse) {
        String deploymentId;
        long now = System.currentTimeMillis();
        Double statsTimeInterval = this.lastInferenceStatsTimestampMillis != null ? Double.valueOf((double)(now - this.lastInferenceStatsTimestampMillis) / 1000.0) : null;
        this.lastInferenceStatsTimestampMillis = now;
        HashMap<String, Stats> recentStatsByDeployment = new HashMap<String, Stats>();
        HashMap<String, Integer> numberOfAllocations = new HashMap<String, Integer>();
        HashSet<String> hasRecentObservedScaleUp = new HashSet<String>();
        for (AssignmentStats assignmentStats : statsResponse.getStats().results()) {
            deploymentId = assignmentStats.getDeploymentId();
            numberOfAllocations.put(deploymentId, assignmentStats.getNumberOfAllocations());
            Map deploymentStats = this.lastInferenceStatsByDeploymentAndNode.computeIfAbsent(deploymentId, key -> new HashMap());
            for (AssignmentStats.NodeStats nodeStats : assignmentStats.getNodeStats()) {
                String nodeId = nodeStats.getNode().getId();
                Stats lastStats = (Stats)deploymentStats.get(nodeId);
                Stats nextStats = new Stats(nodeStats.getInferenceCount().orElse(0L), nodeStats.getPendingCount() == null ? 0L : (long)nodeStats.getPendingCount().intValue(), nodeStats.getErrorCount() + nodeStats.getTimeoutCount() + nodeStats.getRejectedExecutionCount(), nodeStats.getAvgInferenceTime().orElse(0.0) / 1000.0);
                deploymentStats.put(nodeId, nextStats);
                if (lastStats != null) {
                    Stats recentStats = nextStats.sub(lastStats);
                    recentStatsByDeployment.compute(assignmentStats.getDeploymentId(), (key, value) -> value == null ? recentStats : value.add(recentStats));
                }
                if (nodeStats.getRoutingState() != null && nodeStats.getRoutingState().getState() == RoutingState.STARTING) {
                    hasRecentObservedScaleUp.add(deploymentId);
                }
                if (nodeStats.getStartTime() == null || now >= nodeStats.getStartTime().toEpochMilli() + this.scaleUpCooldownTimeMillis) continue;
                hasRecentObservedScaleUp.add(deploymentId);
            }
        }
        if (statsTimeInterval == null) {
            return;
        }
        for (Map.Entry entry : recentStatsByDeployment.entrySet()) {
            boolean hasRecentScaleUp;
            deploymentId = (String)entry.getKey();
            Stats stats = (Stats)entry.getValue();
            AdaptiveAllocationsScaler adaptiveAllocationsScaler = this.scalers.get(deploymentId);
            adaptiveAllocationsScaler.process(stats, statsTimeInterval, (Integer)numberOfAllocations.get(deploymentId));
            Integer newNumberOfAllocations = adaptiveAllocationsScaler.scale();
            if (newNumberOfAllocations == null) continue;
            Long lastScaleUpTimeMillis = this.lastScaleUpTimesMillis.get(deploymentId);
            boolean bl = hasRecentScaleUp = lastScaleUpTimeMillis != null && now < lastScaleUpTimeMillis + this.scaleUpCooldownTimeMillis;
            if (newNumberOfAllocations < (Integer)numberOfAllocations.get(deploymentId) && (hasRecentScaleUp || hasRecentObservedScaleUp.contains(deploymentId))) {
                logger.debug("adaptive allocations scaler: skipping scaling down [{}] because of recent scaleup.", (Object)deploymentId);
                continue;
            }
            if (newNumberOfAllocations > (Integer)numberOfAllocations.get(deploymentId)) {
                this.lastScaleUpTimesMillis.put(deploymentId, now);
            }
            this.updateNumberOfAllocations(deploymentId, newNumberOfAllocations, this.updateAssigmentListener(deploymentId, newNumberOfAllocations));
        }
    }

    public boolean maybeStartAllocation(TrainedModelAssignment assignment) {
        if (assignment.getAdaptiveAllocationsSettings() != null && assignment.getAdaptiveAllocationsSettings().getEnabled() == Boolean.TRUE && (assignment.getAdaptiveAllocationsSettings().getMinNumberOfAllocations() == null || assignment.getAdaptiveAllocationsSettings().getMinNumberOfAllocations() == 0)) {
            AdaptiveAllocationsScaler scaler;
            if (!this.deploymentIdsWithInFlightScaleFromZeroRequests.contains(assignment.getDeploymentId())) {
                this.lastScaleUpTimesMillis.put(assignment.getDeploymentId(), System.currentTimeMillis());
                ActionListener<CreateTrainedModelAssignmentAction.Response> updateListener = this.updateAssigmentListener(assignment.getDeploymentId(), 1);
                ActionListener cleanUpListener = ActionListener.runAfter(updateListener, () -> this.deploymentIdsWithInFlightScaleFromZeroRequests.remove(assignment.getDeploymentId()));
                this.deploymentIdsWithInFlightScaleFromZeroRequests.add(assignment.getDeploymentId());
                this.updateNumberOfAllocations(assignment.getDeploymentId(), 1, (ActionListener<CreateTrainedModelAssignmentAction.Response>)cleanUpListener);
            }
            if ((scaler = this.scalers.get(assignment.getDeploymentId())) != null) {
                scaler.resetTimeWithoutRequests();
            }
            return true;
        }
        return false;
    }

    private void updateNumberOfAllocations(String deploymentId, int numberOfAllocations, ActionListener<CreateTrainedModelAssignmentAction.Response> listener) {
        UpdateTrainedModelDeploymentAction.Request updateRequest = new UpdateTrainedModelDeploymentAction.Request(deploymentId);
        updateRequest.setNumberOfAllocations(Integer.valueOf(numberOfAllocations));
        updateRequest.setIsInternal(true);
        ClientHelper.executeAsyncWithOrigin((Client)this.client, (String)"ml", (ActionType)UpdateTrainedModelDeploymentAction.INSTANCE, (ActionRequest)updateRequest, listener);
    }

    private ActionListener<CreateTrainedModelAssignmentAction.Response> updateAssigmentListener(String deploymentId, int numberOfAllocations) {
        return ActionListener.wrap(updateResponse -> {
            this.lastWarningMessages.remove(deploymentId);
            logger.info("adaptive allocations scaler: scaled [{}] to [{}] allocations.", (Object)deploymentId, (Object)numberOfAllocations);
            this.threadPool.executor("ml_utility").execute(() -> this.inferenceAuditor.info(deploymentId, Strings.format((String)"adaptive allocations scaler: scaled [%s] to [%s] allocations.", (Object[])new Object[]{deploymentId, numberOfAllocations})));
        }, e -> {
            Level level = e.getMessage().equals(this.lastWarningMessages.get(deploymentId)) ? Level.DEBUG : Level.WARN;
            this.lastWarningMessages.put(deploymentId, e.getMessage());
            logger.atLevel(level).withThrowable((Throwable)e).log("adaptive allocations scaler: scaling [{}] to [{}] allocations failed.", (Object)deploymentId, (Object)numberOfAllocations);
            if (level == Level.WARN) {
                this.threadPool.executor("ml_utility").execute(() -> this.inferenceAuditor.warning(deploymentId, Strings.format((String)"adaptive allocations scaler: scaling [%s] to [%s] allocations failed.", (Object[])new Object[]{deploymentId, numberOfAllocations})));
            }
        });
    }

    private class Metrics {
        private final List<AutoCloseable> metrics = new ArrayList<AutoCloseable>();

        Metrics() {
        }

        void init() {
            if (!this.metrics.isEmpty()) {
                return;
            }
            this.metrics.add((AutoCloseable)AdaptiveAllocationsScalerService.this.meterRegistry.registerLongsGauge("es.ml.trained_models.adaptive_allocations.actual_number_of_allocations.current", "the actual number of allocations", "", () -> this.observeLong(AdaptiveAllocationsScaler::getNumberOfAllocations)));
            this.metrics.add((AutoCloseable)AdaptiveAllocationsScalerService.this.meterRegistry.registerLongsGauge("es.ml.trained_models.adaptive_allocations.needed_number_of_allocations.current", "the number of allocations needed according to the adaptive allocations scaler", "", () -> this.observeLong(AdaptiveAllocationsScaler::getNeededNumberOfAllocations)));
            this.metrics.add((AutoCloseable)AdaptiveAllocationsScalerService.this.meterRegistry.registerDoublesGauge("es.ml.trained_models.adaptive_allocations.measured_request_rate.current", "the request rate reported by the stats API", "1/s", () -> this.observeDouble(AdaptiveAllocationsScaler::getLastMeasuredRequestRate)));
            this.metrics.add((AutoCloseable)AdaptiveAllocationsScalerService.this.meterRegistry.registerDoublesGauge("es.ml.trained_models.adaptive_allocations.estimated_request_rate.current", "the request rate estimated by the adaptive allocations scaler", "1/s", () -> this.observeDouble(AdaptiveAllocationsScaler::getRequestRateEstimate)));
            this.metrics.add((AutoCloseable)AdaptiveAllocationsScalerService.this.meterRegistry.registerDoublesGauge("es.ml.trained_models.adaptive_allocations.measured_inference_time.current", "the inference time reported by the stats API", "s", () -> this.observeDouble(AdaptiveAllocationsScaler::getLastMeasuredInferenceTime)));
            this.metrics.add((AutoCloseable)AdaptiveAllocationsScalerService.this.meterRegistry.registerDoublesGauge("es.ml.trained_models.adaptive_allocations.estimated_inference_time.current", "the inference time estimated by the adaptive allocations scaler", "s", () -> this.observeDouble(AdaptiveAllocationsScaler::getInferenceTimeEstimate)));
            this.metrics.add((AutoCloseable)AdaptiveAllocationsScalerService.this.meterRegistry.registerLongsGauge("es.ml.trained_models.adaptive_allocations.queue_size.current", "the queue size reported by the stats API", "s", () -> this.observeLong(AdaptiveAllocationsScaler::getLastMeasuredQueueSize)));
        }

        Collection<LongWithAttributes> observeLong(Function<AdaptiveAllocationsScaler, Long> getValue) {
            ArrayList<LongWithAttributes> observations = new ArrayList<LongWithAttributes>();
            for (AdaptiveAllocationsScaler scaler : AdaptiveAllocationsScalerService.this.scalers.values()) {
                Long value = getValue.apply(scaler);
                if (value == null) continue;
                observations.add(new LongWithAttributes(value.longValue(), Map.of("deployment_id", scaler.getDeploymentId())));
            }
            return observations;
        }

        Collection<DoubleWithAttributes> observeDouble(Function<AdaptiveAllocationsScaler, Double> getValue) {
            ArrayList<DoubleWithAttributes> observations = new ArrayList<DoubleWithAttributes>();
            for (AdaptiveAllocationsScaler scaler : AdaptiveAllocationsScalerService.this.scalers.values()) {
                Double value = getValue.apply(scaler);
                if (value == null) continue;
                observations.add(new DoubleWithAttributes(value.doubleValue(), Map.of("deployment_id", scaler.getDeploymentId())));
            }
            return observations;
        }
    }

    record Stats(long successCount, long pendingCount, long failedCount, double inferenceTime) {
        long requestCount() {
            return this.successCount + this.pendingCount + this.failedCount;
        }

        double totalInferenceTime() {
            return (double)this.successCount * this.inferenceTime;
        }

        Stats add(Stats value) {
            long newSuccessCount = this.successCount + value.successCount;
            long newPendingCount = this.pendingCount + value.pendingCount;
            long newFailedCount = this.failedCount + value.failedCount;
            double newInferenceTime = newSuccessCount > 0L ? (this.totalInferenceTime() + value.totalInferenceTime()) / (double)newSuccessCount : Double.NaN;
            return new Stats(newSuccessCount, newPendingCount, newFailedCount, newInferenceTime);
        }

        Stats sub(Stats value) {
            long newSuccessCount = Math.max(0L, this.successCount - value.successCount);
            long newPendingCount = Math.max(0L, this.pendingCount - value.pendingCount);
            long newFailedCount = Math.max(0L, this.failedCount - value.failedCount);
            double newInferenceTime = newSuccessCount > 0L ? (this.totalInferenceTime() - value.totalInferenceTime()) / (double)newSuccessCount : Double.NaN;
            return new Stats(newSuccessCount, newPendingCount, newFailedCount, newInferenceTime);
        }
    }
}

