/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.shuffle.writer;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import org.apache.spark.Aggregator;
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.shuffle.RssShuffleHandle;
import org.apache.spark.shuffle.RssShuffleManager;
import org.apache.spark.shuffle.RssSparkConfig;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.shuffle.writer.AddBlockEvent;
import org.apache.spark.shuffle.writer.BufferManagerOptions;
import org.apache.spark.shuffle.writer.WriteBufferManager;
import org.apache.spark.storage.BlockManagerId;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.com.google.common.annotations.VisibleForTesting;
import org.apache.uniffle.com.google.common.collect.Lists;
import org.apache.uniffle.com.google.common.collect.Maps;
import org.apache.uniffle.com.google.common.collect.Sets;
import org.apache.uniffle.com.google.common.util.concurrent.Uninterruptibles;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.storage.util.StorageType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Function1;
import scala.Option;
import scala.Product2;
import scala.collection.Iterator;

/**
 * Spark {@link ShuffleWriter} that buffers map-output records in a {@link WriteBufferManager} and
 * hands the produced blocks to the {@link RssShuffleManager} for sending to remote shuffle servers.
 *
 * <p>Flow: {@link #write} partitions and buffers records, passes every batch of produced blocks to
 * {@link #processShuffleBlockInfos}, waits until all sent blocks are reported successful, commits
 * (unless a memory-backed storage type is configured), and {@link #stop} reports the block ids per
 * partition to the shuffle servers and builds the {@link MapStatus}.
 *
 * @param <K> record key type
 * @param <V> record value type
 * @param <C> combiner type used when map-side combine is enabled
 */
public class RssShuffleWriter<K, V, C>
extends ShuffleWriter<K, V> {
    private static final Logger LOG = LoggerFactory.getLogger(RssShuffleWriter.class);
    /** Placeholder host for the synthetic {@link BlockManagerId} put into the returned {@link MapStatus}. */
    private static final String DUMMY_HOST = "dummy_host";
    /** Placeholder port for the synthetic {@link BlockManagerId} put into the returned {@link MapStatus}. */
    private static final int DUMMY_PORT = 99999;
    private final String appId;
    private final int shuffleId;
    /** Not final: the test-only constructor injects it after delegating to the private constructor. */
    private WriteBufferManager bufferManager;
    private final String taskId;
    private final int numMaps;
    private final ShuffleDependency<K, V, C> shuffleDependency;
    private final Partitioner partitioner;
    private final RssShuffleManager shuffleManager;
    /** False when there is a single partition, in which case the partitioner is never consulted. */
    private final boolean shouldPartition;
    private final long sendCheckTimeout;
    private final long sendCheckInterval;
    private final int bitmapSplitNum;
    /** Block ids produced so far, grouped by partition; reported to the servers in {@link #stop}. */
    private final Map<Integer, Set<Long>> partitionToBlockIds;
    private final ShuffleWriteClient shuffleWriteClient;
    private final Map<Integer, List<ShuffleServerInfo>> partitionToServers;
    private final Set<ShuffleServerInfo> shuffleServersForData;
    /** Accumulated block lengths per partition; used to build the {@link MapStatus}. */
    private final long[] partitionLengths;
    private final Function<String, Boolean> taskFailureCallback;
    /** Ids of every block handed off for sending; drained by {@link #checkBlockSendResult}. */
    private final Set<Long> blockIds = Sets.newConcurrentHashSet();
    protected final long taskAttemptId;
    protected final ShuffleWriteMetrics shuffleWriteMetrics;
    protected final boolean isMemoryShuffleEnabled;
    /** Receives one marker per finished send event so the result check can wake up early. */
    private final BlockingQueue<Object> finishEventQueue = new LinkedBlockingQueue<Object>();

    /**
     * Test-only constructor that injects a pre-built {@link WriteBufferManager} and installs a task
     * failure callback that always returns {@code true}.
     */
    @VisibleForTesting
    public RssShuffleWriter(String appId, int shuffleId, String taskId, long taskAttemptId,
            WriteBufferManager bufferManager, ShuffleWriteMetrics shuffleWriteMetrics,
            RssShuffleManager shuffleManager, SparkConf sparkConf, ShuffleWriteClient shuffleWriteClient,
            RssShuffleHandle<K, V, C> rssHandle) {
        this(appId, shuffleId, taskId, taskAttemptId, shuffleWriteMetrics, shuffleManager, sparkConf,
                shuffleWriteClient, rssHandle, tid -> true);
        this.bufferManager = bufferManager;
    }

    /**
     * Shared initialisation of everything except the buffer manager.
     */
    private RssShuffleWriter(String appId, int shuffleId, String taskId, long taskAttemptId,
            ShuffleWriteMetrics shuffleWriteMetrics, RssShuffleManager shuffleManager, SparkConf sparkConf,
            ShuffleWriteClient shuffleWriteClient, RssShuffleHandle<K, V, C> rssHandle,
            Function<String, Boolean> taskFailureCallback) {
        // Routine lifecycle event: logged at INFO like the matching "Finish write" / "Report" messages.
        LOG.info("RssShuffle start write data for taskAttemptId[{}]", taskAttemptId);
        this.shuffleManager = shuffleManager;
        this.appId = appId;
        this.shuffleId = shuffleId;
        this.taskId = taskId;
        this.taskAttemptId = taskAttemptId;
        this.numMaps = rssHandle.getNumMaps();
        this.shuffleWriteMetrics = shuffleWriteMetrics;
        this.shuffleDependency = rssHandle.getDependency();
        this.partitioner = this.shuffleDependency.partitioner();
        this.shouldPartition = this.partitioner.numPartitions() > 1;
        // The config entries are exposed to Java with an Object value type, hence the casts.
        this.sendCheckTimeout = (Long) sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS);
        this.sendCheckInterval = (Long) sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS);
        this.bitmapSplitNum = (Integer) sparkConf.get(RssSparkConfig.RSS_CLIENT_BITMAP_SPLIT_NUM);
        this.partitionToBlockIds = Maps.newHashMap();
        this.shuffleWriteClient = shuffleWriteClient;
        this.shuffleServersForData = rssHandle.getShuffleServersForData();
        // Java zero-initialises new arrays, so no explicit fill is needed.
        this.partitionLengths = new long[this.partitioner.numPartitions()];
        this.partitionToServers = rssHandle.getPartitionToServers();
        this.isMemoryShuffleEnabled = this.isMemoryShuffleEnabled(sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()));
        this.taskFailureCallback = taskFailureCallback;
    }

    /**
     * Production constructor: builds a {@link WriteBufferManager} bound to the task's memory manager
     * that feeds produced blocks back into {@link #processShuffleBlockInfos}.
     */
    public RssShuffleWriter(String appId, int shuffleId, String taskId, long taskAttemptId,
            ShuffleWriteMetrics shuffleWriteMetrics, RssShuffleManager shuffleManager, SparkConf sparkConf,
            ShuffleWriteClient shuffleWriteClient, RssShuffleHandle<K, V, C> rssHandle,
            Function<String, Boolean> taskFailureCallback, TaskContext context) {
        this(appId, shuffleId, taskId, taskAttemptId, shuffleWriteMetrics, shuffleManager, sparkConf,
                shuffleWriteClient, rssHandle, taskFailureCallback);
        BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
        this.bufferManager = new WriteBufferManager(shuffleId, taskId, taskAttemptId, bufferOptions,
                rssHandle.getDependency().serializer(), rssHandle.getPartitionToServers(),
                context.taskMemoryManager(), shuffleWriteMetrics, RssSparkConfig.toRssConf(sparkConf),
                this::processShuffleBlockInfos);
    }

    /** Returns whether the configured storage type includes a memory tier. */
    private boolean isMemoryShuffleEnabled(String storageType) {
        return StorageType.withMemory(StorageType.valueOf(storageType));
    }

    /**
     * Writes all records. If writing fails, the task failure callback is invoked with this task id
     * before the exception is rethrown unchanged.
     */
    @Override
    public void write(Iterator<Product2<K, V>> records) throws IOException {
        try {
            this.writeImpl(records);
        } catch (Exception e) {
            this.taskFailureCallback.apply(this.taskId);
            throw e;
        }
    }

    /**
     * Buffers every record (applying the map-side combiner when enabled), sends full buffers as they
     * are produced, flushes the rest, waits for all blocks to be acknowledged, commits when the
     * storage type has no memory tier, and records write-time metrics.
     */
    protected void writeImpl(Iterator<Product2<K, V>> records) throws IOException {
        boolean isCombine = this.shuffleDependency.mapSideCombine();
        Function1<V, C> createCombiner = null;
        if (isCombine) {
            createCombiner = this.shuffleDependency.aggregator().get().createCombiner();
        }
        while (records.hasNext()) {
            // Fail fast if any asynchronous send for this task has already failed.
            this.checkIfBlocksFailed();
            Product2<K, V> record = records.next();
            K key = record._1();
            int partition = this.getPartition(key);
            List<ShuffleBlockInfo> produced;
            if (isCombine) {
                produced = this.bufferManager.addRecord(partition, key, createCombiner.apply(record._2()));
            } else {
                produced = this.bufferManager.addRecord(partition, key, record._2());
            }
            if (produced != null && !produced.isEmpty()) {
                this.processShuffleBlockInfos(produced);
            }
        }

        long start = System.currentTimeMillis();
        List<ShuffleBlockInfo> remaining = this.bufferManager.clear();
        if (remaining != null && !remaining.isEmpty()) {
            this.processShuffleBlockInfos(remaining);
        }

        long checkStartTs = System.currentTimeMillis();
        this.internalCheckBlockSendResult();
        long commitStartTs = System.currentTimeMillis();
        long checkDuration = commitStartTs - checkStartTs;
        if (!this.isMemoryShuffleEnabled) {
            this.sendCommit();
        }
        long commitDuration = System.currentTimeMillis() - commitStartTs;

        long writeDurationMs = this.bufferManager.getWriteTime() + (System.currentTimeMillis() - start);
        this.shuffleWriteMetrics.incWriteTime(TimeUnit.MILLISECONDS.toNanos(writeDurationMs));
        LOG.info("Finish write shuffle for appId[{}], shuffleId[{}], taskId[{}] with write {} ms, "
                + "include checkSendResult[{}], commit[{}], {}",
                this.appId, this.shuffleId, this.taskId, writeDurationMs, checkDuration, commitDuration,
                this.bufferManager.getManagerCostInfo());
    }

    /**
     * Always returns an empty array. Per-partition lengths are reported through the {@link MapStatus}
     * built in {@link #stop} instead.
     */
    public long[] getPartitionLengths() {
        return new long[0];
    }

    /**
     * Records the blocks' ids and lengths, then submits them for sending.
     *
     * @param shuffleBlockInfoList blocks produced by the buffer manager; {@code null} or empty is a no-op
     * @return one future per send event, or an empty list when there was nothing to send
     */
    @VisibleForTesting
    protected List<CompletableFuture<Long>> processShuffleBlockInfos(List<ShuffleBlockInfo> shuffleBlockInfoList) {
        if (shuffleBlockInfoList == null || shuffleBlockInfoList.isEmpty()) {
            return Collections.emptyList();
        }
        for (ShuffleBlockInfo sbi : shuffleBlockInfoList) {
            long blockId = sbi.getBlockId();
            int partitionId = sbi.getPartitionId();
            this.blockIds.add(blockId);
            this.partitionToBlockIds.computeIfAbsent(partitionId, k -> Sets.newHashSet()).add(blockId);
            this.partitionLengths[partitionId] += sbi.getLength();
        }
        return this.postBlockEvent(shuffleBlockInfoList);
    }

    /**
     * Groups the blocks into send events and hands each one to the shuffle manager. Every event gets a
     * callback that puts a marker on {@link #finishEventQueue} so a waiting result check wakes up.
     */
    protected List<CompletableFuture<Long>> postBlockEvent(List<ShuffleBlockInfo> shuffleBlockInfoList) {
        List<CompletableFuture<Long>> futures = new ArrayList<>();
        for (AddBlockEvent event : this.bufferManager.buildBlockEvents(shuffleBlockInfoList)) {
            event.addCallback(() -> {
                boolean ret = this.finishEventQueue.add(new Object());
                if (!ret) {
                    LOG.error("Add event " + event + " to finishEventQueue fail");
                }
            });
            futures.add(this.shuffleManager.sendData(event));
        }
        return futures;
    }

    /** Waits for every block this writer has handed off for sending. */
    protected void internalCheckBlockSendResult() {
        this.checkBlockSendResult(this.blockIds);
    }

    /**
     * Blocks until every id in {@code blockIds} is reported successful by the shuffle manager, a send
     * failure is reported, or {@code sendCheckTimeout} ms elapse. Successful ids are removed from the
     * given set (it is mutated). An interrupt does not abort the wait; the thread's interrupt status is
     * restored on exit.
     *
     * @throws RssException if any block failed, or if blocks remain unacknowledged at the timeout
     */
    @VisibleForTesting
    protected void checkBlockSendResult(Set<Long> blockIds) {
        boolean interrupted = false;
        try {
            long end = System.currentTimeMillis() + this.sendCheckTimeout;
            while (true) {
                try {
                    // Clear before checking: markers arriving after this point stay queued and trigger
                    // another round, so no completion between the check and the wait is missed.
                    this.finishEventQueue.clear();
                    this.checkIfBlocksFailed();
                    blockIds.removeAll(this.shuffleManager.getSuccessBlockIds(this.taskId));
                    if (blockIds.isEmpty()) {
                        break;
                    }
                    if (this.finishEventQueue.isEmpty()) {
                        long remainingMs = Math.max(end - System.currentTimeMillis(), 0L);
                        if (this.finishEventQueue.poll(remainingMs, TimeUnit.MILLISECONDS) == null) {
                            break; // timed out waiting for another send to finish
                        }
                    }
                } catch (InterruptedException e) {
                    // Keep waiting; the interrupt flag is re-asserted in the finally block.
                    interrupted = true;
                }
            }
            if (!blockIds.isEmpty()) {
                String errorMsg = "Timeout: Task[" + this.taskId + "] failed because " + blockIds.size()
                        + " blocks can't be sent to shuffle server in " + this.sendCheckTimeout + " ms.";
                LOG.error(errorMsg);
                throw new RssException(errorMsg);
            }
        } finally {
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /** Throws {@link RssException} if the shuffle manager reports any failed block for this task. */
    private void checkIfBlocksFailed() {
        Set<Long> failedBlockIds = this.shuffleManager.getFailedBlockIds(this.taskId);
        if (!failedBlockIds.isEmpty()) {
            String errorMsg = "Send failed: Task[" + this.taskId + "] failed because " + failedBlockIds.size()
                    + " blocks can't be sent to shuffle server.";
            LOG.error(errorMsg);
            throw new RssException(errorMsg);
        }
    }

    /**
     * Sends a commit to all shuffle servers holding this shuffle's data on a helper thread. While it
     * runs, this method polls with a sleep that starts at 200 ms and backs off up to 5 s.
     *
     * @throws RssException if the servers reject the commit or the commit call itself fails
     */
    @VisibleForTesting
    protected void sendCommit() {
        ExecutorService executor = Executors.newSingleThreadExecutor();
        Future<Boolean> future = executor.submit(() -> this.shuffleWriteClient.sendCommit(
                this.shuffleServersForData, this.appId, this.shuffleId, this.numMaps));
        int maxWait = 5000;
        int currentWait = 200;
        long start = System.currentTimeMillis();
        while (!future.isDone()) {
            LOG.info("Wait commit to shuffle server for task[" + this.taskAttemptId + "] cost "
                    + (System.currentTimeMillis() - start) + " ms");
            Uninterruptibles.sleepUninterruptibly(currentWait, TimeUnit.MILLISECONDS);
            currentWait = Math.min(currentWait * 2, maxWait);
        }
        boolean committed;
        try {
            committed = future.get();
        } catch (InterruptedException ie) {
            LOG.warn("Ignore the InterruptedException which should be caused by internal killed");
            // Deliberately not fatal, but keep the interrupt visible to the task's caller.
            Thread.currentThread().interrupt();
            return;
        } catch (Exception e) {
            throw new RssException("Exception happened when get commit status", e);
        } finally {
            executor.shutdown();
        }
        // Checked outside the try so this failure is not re-wrapped by the generic catch above.
        if (!committed) {
            throw new RssException("Failed to commit task to shuffle server");
        }
    }

    /** Returns the target partition for {@code key}; always 0 when there is only one partition. */
    @VisibleForTesting
    protected <T> int getPartition(T key) {
        return this.shouldPartition ? this.partitioner.getPartition(key) : 0;
    }

    /**
     * On success, reports the block ids per partition to the shuffle servers and returns a
     * {@link MapStatus} whose location is a synthetic {@link BlockManagerId} built from the dummy
     * host and port. On failure, returns an empty option. In both cases buffered memory is freed
     * and the task's metadata is cleared from the shuffle manager.
     */
    @Override
    public Option<MapStatus> stop(boolean success) {
        try {
            if (!success) {
                return Option.empty();
            }
            Map<Integer, List<Long>> ptb = new HashMap<>();
            for (Map.Entry<Integer, Set<Long>> entry : this.partitionToBlockIds.entrySet()) {
                ptb.put(entry.getKey(), new ArrayList<>(entry.getValue()));
            }
            long start = System.currentTimeMillis();
            this.shuffleWriteClient.reportShuffleResult(this.partitionToServers, this.appId, this.shuffleId,
                    this.taskAttemptId, ptb, this.bitmapSplitNum);
            LOG.info("Report shuffle result for task[{}] with bitmapNum[{}] cost {} ms",
                    this.taskAttemptId, this.bitmapSplitNum, System.currentTimeMillis() - start);
            BlockManagerId blockManagerId = BlockManagerId.apply(this.appId + "_" + this.taskId, DUMMY_HOST,
                    DUMMY_PORT, Option.apply(Long.toString(this.taskAttemptId)));
            MapStatus mapStatus = MapStatus.apply(blockManagerId, this.partitionLengths, this.taskAttemptId);
            return Option.apply(mapStatus);
        } finally {
            if (this.bufferManager != null) {
                this.bufferManager.freeAllMemory();
            }
            if (this.shuffleManager != null) {
                this.shuffleManager.clearTaskMeta(this.taskId);
            }
        }
    }

    /** Exposes the live partition-to-block-ids map for tests. */
    @VisibleForTesting
    Map<Integer, Set<Long>> getPartitionToBlockIds() {
        return this.partitionToBlockIds;
    }

    /** Exposes the buffer manager for tests. */
    @VisibleForTesting
    public WriteBufferManager getBufferManager() {
        return this.bufferManager;
    }
}

