/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under both the BSD-style license (found in the
 * LICENSE file in the root directory of this source tree) and the GPLv2 (found
 * in the COPYING file in the root directory of this source tree).
 */
#include "platform.h"   /* Large Files support, SET_BINARY_MODE */
#include "Pzstd.h"
#include "SkippableFrame.h"
#include "utils/FileSystem.h"
#include "utils/Portability.h"
#include "utils/Range.h"
#include "utils/ScopeGuard.h"
#include "utils/ThreadPool.h"
#include "utils/WorkQueue.h"

#include <algorithm>
#include <chrono>
#include <cinttypes>
#include <cstddef>
#include <cstdio>
#include <memory>
#include <string>


namespace pzstd {

namespace {
/// Path of the platform's "discard everything" device. Writing to it never
/// needs an overwrite confirmation.
const std::string nullOutput =
#ifdef _WIN32
    "nul";
#else
    "/dev/null";
#endif
} // anonymous namespace

using std::size_t;

static std::uintmax_t fileSizeOrZero(const std::string &file) {
  if (file == "-") {
    return 0;
  }
  std::error_code ec;
  auto size = file_size(file, ec);
  if (ec) {
    size = 0;
  }
  return size;
}

static std::uint64_t handleOneInput(const Options &options,
                             const std::string &inputFile,
                             FILE* inputFd,
                             const std::string &outputFile,
                             FILE* outputFd,
                             SharedState& state) {
  auto inputSize = fileSizeOrZero(inputFile);
  // WorkQueue outlives ThreadPool so in the case of error we are certain
  // we don't accidentally try to call push() on it after it is destroyed
  WorkQueue<std::shared_ptr<BufferWorkQueue>> outs{options.numThreads + 1};
  std::uint64_t bytesRead;
  std::uint64_t bytesWritten;
  {
    // Initialize the (de)compression thread pool with numThreads
    ThreadPool executor(options.numThreads);
    // Run the reader thread on an extra thread
    ThreadPool readExecutor(1);
    if (!options.decompress) {
      // Add a job that reads the input and starts all the compression jobs
      readExecutor.add(
          [&state, &outs, &executor, inputFd, inputSize, &options, &bytesRead] {
            bytesRead = asyncCompressChunks(
                state,
                outs,
                executor,
                inputFd,
                inputSize,
                options.numThreads,
                options.determineParameters());
          });
      // Start writing
      bytesWritten = writeFile(state, outs, outputFd, options.decompress);
    } else {
      // Add a job that reads the input and starts all the decompression jobs
      readExecutor.add([&state, &outs, &executor, inputFd, &bytesRead] {
        bytesRead = asyncDecompressFrames(state, outs, executor, inputFd);
      });
      // Start writing
      bytesWritten = writeFile(state, outs, outputFd, options.decompress);
    }
  }
  if (!state.errorHolder.hasError()) {
    std::string inputFileName = inputFile == "-" ? "stdin" : inputFile;
    std::string outputFileName = outputFile == "-" ? "stdout" : outputFile;
    if (!options.decompress) {
      double ratio = static_cast<double>(bytesWritten) /
                     static_cast<double>(bytesRead + !bytesRead);
      state.log(kLogInfo, "%-20s :%6.2f%%   (%6" PRIu64 " => %6" PRIu64
                   " bytes, %s)\n",
                   inputFileName.c_str(), ratio * 100, bytesRead, bytesWritten,
                   outputFileName.c_str());
    } else {
      state.log(kLogInfo, "%-20s: %" PRIu64 " bytes \n",
                   inputFileName.c_str(),bytesWritten);
    }
  }
  return bytesWritten;
}

/**
 * Opens `inputFile` for binary reading.
 *
 * @param inputFile   Path of the input, or "-" for stdin
 * @param errorHolder Receives an error message when the file cannot be used
 * @returns           `stdin` (switched to binary mode) for "-", the opened
 *                    stream otherwise, or nullptr after recording an error
 */
static FILE *openInputFile(const std::string &inputFile,
                           ErrorHolder &errorHolder) {
  if (inputFile == "-") {
    SET_BINARY_MODE(stdin);
    return stdin;
  }
  // Directories cannot be (de)compressed; report and skip them.
  {
    std::error_code ec;
    if (is_directory(inputFile, ec)) {
      // This check is on the *input* path, so the message must say so.
      errorHolder.setError("Input file is a directory -- ignored");
      return nullptr;
    }
  }
  auto inputFd = std::fopen(inputFile.c_str(), "rb");
  if (!errorHolder.check(inputFd != nullptr, "Failed to open input file")) {
    return nullptr;
  }
  return inputFd;
}

/**
 * Opens `outputFile` for binary writing, asking for confirmation before
 * overwriting an existing file unless `options.overwrite` is set.
 *
 * @param options    The program options (`overwrite` is consulted here)
 * @param outputFile Path of the output, or "-" for stdout
 * @param state      The shared state; its logger prints the prompt and its
 *                   error holder receives any failure
 * @returns          `stdout` (switched to binary mode) for "-", the opened
 *                   stream otherwise, or nullptr after recording an error
 */
static FILE *openOutputFile(const Options &options,
                            const std::string &outputFile,
                            SharedState& state) {
  if (outputFile == "-") {
    SET_BINARY_MODE(stdout);
    return stdout;
  }
  // Probe for an existing file. The null device is exempt from the check.
  if (!options.overwrite && outputFile != nullOutput) {
    auto outputFd = std::fopen(outputFile.c_str(), "rb");
    if (outputFd != nullptr) {
      std::fclose(outputFd);
      // If the logger does not print at info level the prompt would be
      // invisible, so refuse instead of asking.
      if (!state.log.logsAt(kLogInfo)) {
        state.errorHolder.setError("Output file exists");
        return nullptr;
      }
      state.log(
          kLogInfo,
          "pzstd: %s already exists; do you wish to overwrite (y/n) ? ",
          outputFile.c_str());
      const int c = getchar();
      // Discard the rest of the answer line (typically the newline after
      // 'y'/'n'). Otherwise the leftover characters would be read as the
      // answer to the prompt for the next file and silently decline it.
      for (int rest = c; rest != '\n' && rest != EOF;) {
        rest = getchar();
      }
      if (c != 'y' && c != 'Y') {
        state.errorHolder.setError("Not overwritten");
        return nullptr;
      }
    }
  }
  auto outputFd = std::fopen(outputFile.c_str(), "wb");
  if (!state.errorHolder.check(
          outputFd != nullptr, "Failed to open output file")) {
    return nullptr;
  }
  return outputFd;
}

/**
 * Entry point: (de)compresses every file in `options.inputFiles` in turn.
 *
 * A failure on one input is logged and does not stop processing of the
 * remaining inputs.
 *
 * @param options The program options
 * @returns       0 if every input succeeded, 1 if any of them failed
 */
int pzstdMain(const Options &options) {
  int returnCode = 0;
  SharedState state(options);
  for (const auto& input : options.inputFiles) {
    // At the end of each iteration (including every `continue`), report any
    // error recorded for this input and remember that the run failed.
    auto printErrorGuard = makeScopeGuard([&] {
      if (state.errorHolder.hasError()) {
        returnCode = 1;
        state.log(kLogError, "pzstd: %s: %s.\n", input.c_str(),
                  state.errorHolder.getError().c_str());
      }
    });
    // Open the input file
    auto inputFd = openInputFile(input, state.errorHolder);
    if (inputFd == nullptr) {
      continue;
    }
    auto closeInputGuard = makeScopeGuard([&] { std::fclose(inputFd); });
    // Open the output file
    auto outputFile = options.getOutputFile(input);
    if (!state.errorHolder.check(outputFile != "",
                           "Input file does not have extension .zst")) {
      continue;
    }
    auto outputFd = openOutputFile(options, outputFile, state);
    if (outputFd == nullptr) {
      continue;
    }
    auto closeOutputGuard = makeScopeGuard([&] { std::fclose(outputFd); });
    // (de)compress the file
    handleOneInput(options, input, inputFd, outputFile, outputFd, state);
    if (state.errorHolder.hasError()) {
      continue;
    }
    // Delete the input file if necessary
    if (!options.keepSource) {
      // Be sure that we are done and have written everything before we delete.
      // Each guard is dismissed *before* the explicit fclose(): a stream must
      // not be used again after fclose() returns, even when fclose() reports
      // failure, so the guard must never close it a second time.
      closeInputGuard.dismiss();
      if (!state.errorHolder.check(std::fclose(inputFd) == 0,
                             "Failed to close input file")) {
        continue;
      }
      closeOutputGuard.dismiss();
      if (!state.errorHolder.check(std::fclose(outputFd) == 0,
                             "Failed to close output file")) {
        continue;
      }
      if (std::remove(input.c_str()) != 0) {
        state.errorHolder.setError("Failed to remove input file");
        continue;
      }
    }
  }
  // Returns 1 if any of the files failed to (de)compress.
  return returnCode;
}

/// Build a `ZSTD_inBuffer` viewing all of `buffer`, with the read position
/// at the start.
static ZSTD_inBuffer makeZstdInBuffer(const Buffer& buffer) {
  ZSTD_inBuffer view;
  view.src = buffer.data();
  view.size = buffer.size();
  view.pos = 0;
  return view;
}

/**
 * Drop the bytes zstd has consumed, as reported by `inBuffer.pos`, from both
 * `inBuffer` and `buffer`, and rewind `inBuffer.pos` to 0.
 */
void advance(Buffer& buffer, ZSTD_inBuffer& inBuffer) {
  const auto consumed = inBuffer.pos;
  const auto* const start = static_cast<const unsigned char*>(inBuffer.src);
  inBuffer.src = start + consumed;
  inBuffer.size -= consumed;
  inBuffer.pos = 0;
  buffer.advance(consumed);
}

/// Build a `ZSTD_outBuffer` covering all of `buffer`, with the write position
/// at the start.
static ZSTD_outBuffer makeZstdOutBuffer(Buffer& buffer) {
  ZSTD_outBuffer view;
  view.dst = buffer.data();
  view.size = buffer.size();
  view.pos = 0;
  return view;
}

/**
 * Detach the bytes zstd has produced, as reported by `outBuffer.pos`, from
 * the front of `buffer` and return them. `outBuffer` is moved past those
 * bytes and its `pos` is rewound to 0.
 */
Buffer split(Buffer& buffer, ZSTD_outBuffer& outBuffer) {
  const auto produced = outBuffer.pos;
  auto* const start = static_cast<unsigned char*>(outBuffer.dst);
  outBuffer.dst = start + produced;
  outBuffer.size -= produced;
  outBuffer.pos = 0;
  return buffer.splitAt(produced);
}

/**
 * Stream chunks of input from `in`, compress it, and stream it out to `out`.
 *
 * All output is carved out of a single buffer sized with
 * `ZSTD_compressBound(maxInputSize)`. If that buffer runs out, an error is
 * recorded in `state.errorHolder` and the function returns early. `out` is
 * finished on every exit path.
 *
 * @param state        The shared state
 * @param in           Queue that we `pop()` input buffers from
 * @param out          Queue that we `push()` compressed output buffers to
 * @param maxInputSize An upper bound on the size of the input
 */
static void compress(
    SharedState& state,
    std::shared_ptr<BufferWorkQueue> in,
    std::shared_ptr<BufferWorkQueue> out,
    size_t maxInputSize) {
  auto& errorHolder = state.errorHolder;
  // Finish the output queue no matter how we leave this function.
  auto guard = makeScopeGuard([&] { out->finish(); });
  // Initialize the CCtx
  auto ctx = state.cStreamPool->get();
  if (!errorHolder.check(ctx != nullptr, "Failed to allocate ZSTD_CStream")) {
    return;
  }
  {
    // Reset the session state of the context obtained from the pool.
    auto err = ZSTD_CCtx_reset(ctx.get(), ZSTD_reset_session_only);
    if (!errorHolder.check(!ZSTD_isError(err), ZSTD_getErrorName(err))) {
      return;
    }
  }

  // Allocate space for the result
  auto outBuffer = Buffer(ZSTD_compressBound(maxInputSize));
  auto zstdOutBuffer = makeZstdOutBuffer(outBuffer);
  {
    Buffer inBuffer;
    // Read a buffer in from the input queue
    while (in->pop(inBuffer) && !errorHolder.hasError()) {
      auto zstdInBuffer = makeZstdInBuffer(inBuffer);
      // Compress the whole buffer and send it to the output queue
      while (!inBuffer.empty() && !errorHolder.hasError()) {
        // There is still input left, so running out of output space means
        // the bound computed above was insufficient.
        if (!errorHolder.check(
                !outBuffer.empty(), "ZSTD_compressBound() was too small")) {
          return;
        }
        // Compress
        auto err =
            ZSTD_compressStream(ctx.get(), &zstdOutBuffer, &zstdInBuffer);
        if (!errorHolder.check(!ZSTD_isError(err), ZSTD_getErrorName(err))) {
          return;
        }
        // Split the compressed data off outBuffer and pass to the output queue
        out->push(split(outBuffer, zstdOutBuffer));
        // Forget about the data we already compressed
        advance(inBuffer, zstdInBuffer);
      }
    }
  }
  // Write the epilog: keep calling ZSTD_endStream() until it reports that
  // nothing is left to flush.
  size_t bytesLeft;
  do {
    if (!errorHolder.check(
            !outBuffer.empty(), "ZSTD_compressBound() was too small")) {
      return;
    }
    bytesLeft = ZSTD_endStream(ctx.get(), &zstdOutBuffer);
    if (!errorHolder.check(
            !ZSTD_isError(bytesLeft), ZSTD_getErrorName(bytesLeft))) {
      return;
    }
    out->push(split(outBuffer, zstdOutBuffer));
  } while (bytesLeft != 0 && !errorHolder.hasError());
}

/**
 * Calculates how large each independently compressed frame should be.
 *
 * The result is `1 << (windowLog + 2)`; `size` and `numThreads` are accepted
 * but currently unused.
 *
 * @param size       The size of the source if known, 0 otherwise
 * @param numThreads The number of threads available to run compression jobs on
 * @param params     The zstd parameters to be used for compression
 */
static size_t calculateStep(
    std::uintmax_t size,
    size_t numThreads,
    const ZSTD_parameters &params) {
  static_cast<void>(size);
  static_cast<void>(numThreads);
  const auto windowLog = params.cParams.windowLog;
  // Not validated to work correctly for window logs > 23.
  // It will definitely fail once 1 << (windowLog + 2) reaches 4GB, because
  // the skippable frame can only store sizes up to 4GB.
  assert(windowLog <= 23);
  const auto shift = windowLog + 2;
  return size_t{1} << shift;
}

namespace {
/// Outcome of inspecting a stream after a read.
enum class FileStatus { Continue, Done, Error };

/// Classifies the stream `fd`: `Done` when its end-of-file indicator is set,
/// otherwise `Error` when its error indicator is set, otherwise `Continue`.
FileStatus fileStatus(FILE* fd) {
  if (std::feof(fd) != 0) {
    return FileStatus::Done;
  }
  return std::ferror(fd) != 0 ? FileStatus::Error : FileStatus::Continue;
}
} // anonymous namespace

/**
 * Reads up to `size` bytes from `fd`, at most `chunkSize` bytes per read,
 * pushing each piece onto `queue` and adding its length to `*totalBytesRead`.
 * Stops early as soon as the stream reports EOF or an error.
 * Returns the status of the file after the last read (`Continue` if `size`
 * bytes were read without hitting EOF or an error).
 */
static FileStatus
readData(BufferWorkQueue& queue, size_t chunkSize, size_t size, FILE* fd,
         std::uint64_t *totalBytesRead) {
  Buffer remaining(size);
  auto status = FileStatus::Continue;
  while (status == FileStatus::Continue && !remaining.empty()) {
    const size_t wanted = std::min(chunkSize, remaining.size());
    const size_t got = std::fread(remaining.data(), 1, wanted, fd);
    *totalBytesRead += got;
    // Hand over exactly the bytes that were read; the rest stays for later.
    queue.push(remaining.splitAt(got));
    status = fileStatus(fd);
  }
  return status;
}

/**
 * Reads `fd` in fixed-size chunks and schedules one compression job per
 * chunk on `executor`. The output queue of every job is pushed onto `chunks`
 * in the order the jobs are started. `chunks` is finished when this function
 * returns.
 *
 * @param state      The shared state
 * @param chunks     Queue of per-job output queues, consumed by the writer
 * @param executor   The thread pool that runs the compression jobs
 * @param fd         The stream to read the input from
 * @param size       The size of the input if known, 0 otherwise
 * @param numThreads The number of threads available for compression jobs
 * @param params     The zstd parameters to be used for compression
 * @returns          The number of bytes read from `fd`
 */
std::uint64_t asyncCompressChunks(
    SharedState& state,
    WorkQueue<std::shared_ptr<BufferWorkQueue>>& chunks,
    ThreadPool& executor,
    FILE* fd,
    std::uintmax_t size,
    size_t numThreads,
    ZSTD_parameters params) {
  auto chunksGuard = makeScopeGuard([&] { chunks.finish(); });
  std::uint64_t totalRead = 0;

  // Every chunk of `frameBytes` input bytes is compressed independently.
  size_t frameBytes = calculateStep(size, numThreads, params);
  state.log(kLogDebug, "Chosen frame size: %zu\n", frameBytes);
  FileStatus status = FileStatus::Continue;
  for (;;) {
    if (status != FileStatus::Continue || state.errorHolder.hasError()) {
      break;
    }
    // Input queue for this chunk; finished when the iteration ends.
    auto source = std::make_shared<BufferWorkQueue>();
    auto sourceGuard = makeScopeGuard([&] { source->finish(); });
    // Output queue that the compression job fills.
    auto sink = std::make_shared<BufferWorkQueue>();
    // Schedule the compression job before any input is available so that it
    // can start consuming as soon as data arrives.
    executor.add([&state, source, sink, frameBytes] {
      compress(state, std::move(source), std::move(sink), frameBytes);
    });
    // Hand the output queue to the writer thread.
    chunks.push(std::move(sink));
    state.log(kLogVerbose, "%s\n", "Starting a new frame");
    // Feed the job we just started with this chunk's input.
    status =
        readData(*source, ZSTD_CStreamInSize(), frameBytes, fd, &totalRead);
  }
  state.errorHolder.check(status != FileStatus::Error, "Error reading input");
  return totalRead;
}

/**
 * Decompress a frame, whose data is streamed into `in`, and stream the output
 * to `out`.
 *
 * Errors are recorded in `state.errorHolder` and end the function early.
 * `out` is finished on every exit path.
 *
 * @param state        The shared state
 * @param in           Queue that we `pop()` input buffers from. It contains
 *                      exactly one compressed frame.
 * @param out          Queue that we `push()` decompressed output buffers to
 */
static void decompress(
    SharedState& state,
    std::shared_ptr<BufferWorkQueue> in,
    std::shared_ptr<BufferWorkQueue> out) {
  auto& errorHolder = state.errorHolder;
  // Finish the output queue no matter how we leave this function.
  auto guard = makeScopeGuard([&] { out->finish(); });
  // Initialize the DCtx
  auto ctx = state.dStreamPool->get();
  if (!errorHolder.check(ctx != nullptr, "Failed to allocate ZSTD_DStream")) {
    return;
  }
  {
    // Reset the session state of the context obtained from the pool.
    auto err = ZSTD_DCtx_reset(ctx.get(), ZSTD_reset_session_only);
    if (!errorHolder.check(!ZSTD_isError(err), ZSTD_getErrorName(err))) {
      return;
    }
  }

  const size_t outSize = ZSTD_DStreamOutSize();
  Buffer inBuffer;
  // Result of the most recent ZSTD_decompressStream() call; inspected after
  // the loop to decide whether the input ended cleanly.
  size_t returnCode = 0;
  // Read a buffer in from the input queue
  while (in->pop(inBuffer) && !errorHolder.hasError()) {
    auto zstdInBuffer = makeZstdInBuffer(inBuffer);
    // Decompress the whole buffer and send it to the output queue
    while (!inBuffer.empty() && !errorHolder.hasError()) {
      // Allocate a buffer with at least outSize bytes.
      Buffer outBuffer(outSize);
      auto zstdOutBuffer = makeZstdOutBuffer(outBuffer);
      // Decompress
      returnCode =
          ZSTD_decompressStream(ctx.get(), &zstdOutBuffer, &zstdInBuffer);
      if (!errorHolder.check(
              !ZSTD_isError(returnCode), ZSTD_getErrorName(returnCode))) {
        return;
      }
      // Pass the buffer with the decompressed data to the output queue
      out->push(split(outBuffer, zstdOutBuffer));
      // Advance past the input we already read
      advance(inBuffer, zstdInBuffer);
      if (returnCode == 0) {
        // The frame is over, prepare to (maybe) start a new frame
        // NOTE(review): the return value of ZSTD_initDStream() is not checked.
        ZSTD_initDStream(ctx.get());
      }
    }
  }
  // All input has been consumed. Any result above 1 is reported as an
  // incomplete block.
  if (!errorHolder.check(returnCode <= 1, "Incomplete block")) {
    return;
  }
  // We've given ZSTD_decompressStream all of our data, but there may still
  // be data to read.
  while (returnCode == 1) {
    // Allocate a buffer with at least outSize bytes.
    Buffer outBuffer(outSize);
    auto zstdOutBuffer = makeZstdOutBuffer(outBuffer);
    // Pass in no input.
    ZSTD_inBuffer zstdInBuffer{nullptr, 0, 0};
    // Decompress
    returnCode =
        ZSTD_decompressStream(ctx.get(), &zstdOutBuffer, &zstdInBuffer);
    if (!errorHolder.check(
            !ZSTD_isError(returnCode), ZSTD_getErrorName(returnCode))) {
      return;
    }
    // Pass the buffer with the decompressed data to the output queue
    out->push(split(outBuffer, zstdOutBuffer));
  }
}

/**
 * Reads `fd`, splits it into frames, and schedules one decompression job per
 * frame on `executor`. The output queue of every job is pushed onto `frames`
 * in the order the jobs are started. `frames` is finished when this function
 * returns.
 *
 * @param state    The shared state
 * @param frames   Queue of per-job output queues, consumed by the writer
 * @param executor The thread pool that runs the decompression jobs
 * @param fd       The stream to read the compressed input from
 * @returns        The number of bytes read from `fd`
 */
std::uint64_t asyncDecompressFrames(
    SharedState& state,
    WorkQueue<std::shared_ptr<BufferWorkQueue>>& frames,
    ThreadPool& executor,
    FILE* fd) {
  auto framesGuard = makeScopeGuard([&] { frames.finish(); });
  std::uint64_t totalBytesRead = 0;

  // Split the source up into its component frames.
  // If we find our recognized skippable frame we know the next frame's size,
  // which means that we can decompress each standard frame independently.
  // Otherwise, we will decompress using only one decompression task.
  const size_t chunkSize = ZSTD_DStreamInSize();
  auto status = FileStatus::Continue;
  while (status == FileStatus::Continue && !state.errorHolder.hasError()) {
    // Make a new input queue that we will put the frame's bytes into.
    auto in = std::make_shared<BufferWorkQueue>();
    auto inGuard = makeScopeGuard([&] { in->finish(); });
    // Make an output queue that decompress will put the decompressed data into
    auto out = std::make_shared<BufferWorkQueue>();

    size_t frameSize;
    {
      // Calculate the size of the next frame.
      // frameSize is 0 if the frame info can't be decoded.
      Buffer buffer(SkippableFrame::kSize);
      auto bytesRead = std::fread(buffer.data(), 1, buffer.size(), fd);
      totalBytesRead += bytesRead;
      status = fileStatus(fd);
      // Nothing was read and the stream is at EOF or in error: no more frames.
      if (bytesRead == 0 && status != FileStatus::Continue) {
        break;
      }
      // Trim the buffer down to the bytes that were actually read.
      buffer.subtract(buffer.size() - bytesRead);
      frameSize = SkippableFrame::tryRead(buffer.range());
      // The header bytes are forwarded to the decompression job as well.
      in->push(std::move(buffer));
    }
    if (frameSize == 0) {
      // We hit a non SkippableFrame, so this will be the last job.
      // Make sure that we don't use too much memory
      in->setMaxSize(64);
      out->setMaxSize(64);
    }
    // Start decompression in the thread pool
    executor.add([&state, in, out] {
      return decompress(state, std::move(in), std::move(out));
    });
    // Pass the output queue to the writer thread
    frames.push(std::move(out));
    if (frameSize == 0) {
      // We hit a non SkippableFrame ==> not compressed by pzstd or corrupted
      // Pass the rest of the source to this decompression task
      state.log(kLogVerbose, "%s\n",
          "Input not in pzstd format, falling back to serial decompression");
      while (status == FileStatus::Continue && !state.errorHolder.hasError()) {
        status = readData(*in, chunkSize, chunkSize, fd, &totalBytesRead);
      }
      break;
    }
    // Terminate the line like the other verbose messages in this file so
    // that consecutive messages do not run together.
    state.log(kLogVerbose, "Decompressing a frame of size %zu\n", frameSize);
    // Fill the input queue for the decompression job we just started
    status = readData(*in, chunkSize, frameSize, fd, &totalBytesRead);
  }
  state.errorHolder.check(status != FileStatus::Error, "Error reading input");
  return totalBytesRead;
}

/// Writes all of `data` to `fd`, retrying after short writes. Returns false
/// as soon as the stream's error indicator is set, true once nothing is left.
static bool writeData(ByteRange data, FILE* fd) {
  for (;;) {
    if (data.empty()) {
      return true;
    }
    const auto written = std::fwrite(data.begin(), 1, data.size(), fd);
    data.advance(written);
    if (std::ferror(fd) != 0) {
      return false;
    }
  }
}

/**
 * Drains the per-job output queues in `outs`, in order, and writes their
 * contents to `outputFd`. When compressing, each job's data is preceded by a
 * skippable frame that records the job's compressed size.
 *
 * @param state      The shared state
 * @param outs       Queue of per-job output queues, in output order
 * @param outputFd   The stream to write to
 * @param decompress True when decompressing (no skippable frames are written)
 * @returns          The number of bytes written to `outputFd`
 */
std::uint64_t writeFile(
    SharedState& state,
    WorkQueue<std::shared_ptr<BufferWorkQueue>>& outs,
    FILE* outputFd,
    bool decompress) {
  auto& errorHolder = state.errorHolder;
  // Clear the progress line at info level however we leave this function.
  auto lineClearGuard = makeScopeGuard([&state] {
    state.log.clear(kLogInfo);
  });
  std::uint64_t total = 0;
  std::shared_ptr<BufferWorkQueue> job;
  // Take the output queue of each job in the order the jobs were started.
  while (outs.pop(job)) {
    // After an error, keep popping the remaining queues without writing.
    if (errorHolder.hasError()) {
      continue;
    }
    if (!decompress) {
      // The skippable frame stores the compressed size, so it can only be
      // written once that size is available from the job's queue.
      assert(uint64_t(job->size()) < uint64_t(1) << 32);
      SkippableFrame frame(static_cast<uint32_t>(job->size()));
      if (!writeData(frame.data(), outputFd)) {
        errorHolder.setError("Failed to write output");
        return total;
      }
      total += SkippableFrame::kSize;
    }
    // Write every piece of this job's output as it becomes available.
    for (Buffer piece; job->pop(piece) && !errorHolder.hasError();) {
      if (!writeData(piece.range(), outputFd)) {
        errorHolder.setError("Failed to write output");
        return total;
      }
      total += piece.size();
      const auto megabytes = static_cast<std::uint32_t>(total >> 20);
      state.log.update(kLogInfo, "Written: %u MB   ", megabytes);
    }
  }
  return total;
}
}