// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "contrib_ops/cpu/utils/dump_tensor.h"
#include <atomic>
#include <cassert>
#include <iomanip>
#include <iostream>
#include <mutex>
#include <thread>
#include "core/framework/print_tensor_utils.h"
#include "contrib_ops/cpu/utils/debug_macros.h"
#include "core/platform/env_var_utils.h"

namespace onnxruntime {
namespace contrib {

#if DUMP_CPU_TENSOR_LEVEL > 0

// Environment variable to enable/disable dumping
constexpr const char* kEnableCpuTensorDumper = "ORT_ENABLE_CPU_DUMP";

// Environment variable to enable/disable dumping thread id
constexpr const char* kDumpThreadId = "ORT_DUMP_THREAD_ID";

// To avoid dumping at the same time from multiple threads
static std::mutex s_mutex;

// Whether each dump is prefixed with a "Thread ID:" line. Every CpuTensorConsoleDumper
// constructor stores to this flag without holding s_mutex, while dumping code on other
// threads reads it, so it is atomic to avoid a data race. std::atomic<bool> converts
// implicitly on read and accepts plain assignment, so existing call sites are unchanged.
static std::atomic<bool> s_output_thread_id{false};

// Returns true when a tensor with `num_elements` elements exceeds
// kDefaultSnippetThreshold and should therefore be printed as a snippet instead of in full.
static bool ShouldDumpSnippet(int64_t num_elements) {
  return onnxruntime::utils::kDefaultSnippetThreshold < num_elements;
}

// Writes the optional "Thread ID:" line followed by `name` (when non-null) to stdout.
// The caller must already hold s_mutex.
static void DumpHeaderLocked(const char* name) {
  if (s_output_thread_id) {
    std::cout << "Thread ID:" << std::this_thread::get_id() << std::endl;
  }
  if (nullptr != name) {
    std::cout << name << std::endl;
  }
}

// Dumps `tensor` viewed as dim0 x dim1 to stdout while holding s_mutex.
// Tensors above kDefaultSnippetThreshold elements are passed to PrintCpuTensorSnippet
// with kDefaultSnippetEdgeItems; smaller ones are printed via PrintCpuTensorFull.
template <typename T>
void DumpCpuTensor(const char* name, const T* tensor, int dim0, int dim1) {
  std::unique_lock<std::mutex> lock(s_mutex);
  DumpHeaderLocked(name);

  // Widen before multiplying: an int product can overflow for large shapes.
  const int64_t num_elements = static_cast<int64_t>(dim0) * dim1;
  if (ShouldDumpSnippet(num_elements)) {
    onnxruntime::utils::PrintCpuTensorSnippet<T>(tensor, dim0, dim1, onnxruntime::utils::kDefaultSnippetEdgeItems);
  } else {
    onnxruntime::utils::PrintCpuTensorFull<T>(tensor, dim0, dim1);
  }
}

// Dumps `tensor` viewed as dim0 x dim1 x dim2 to stdout while holding s_mutex.
// Tensors above kDefaultSnippetThreshold elements are passed to PrintCpuTensorSnippet
// with kDefaultSnippetEdgeItems; smaller ones are printed via PrintCpuTensorFull.
template <typename T>
void DumpCpuTensor(const char* name, const T* tensor, int dim0, int dim1, int dim2) {
  std::unique_lock<std::mutex> lock(s_mutex);
  DumpHeaderLocked(name);

  // Widen before multiplying: an int product can overflow for large shapes.
  const int64_t num_elements = static_cast<int64_t>(dim0) * dim1 * dim2;
  if (ShouldDumpSnippet(num_elements)) {
    onnxruntime::utils::PrintCpuTensorSnippet<T>(tensor, dim0, dim1, dim2, onnxruntime::utils::kDefaultSnippetEdgeItems);
  } else {
    onnxruntime::utils::PrintCpuTensorFull<T>(tensor, dim0, dim1, dim2);
  }
}

void DumpCpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1, int dim2) {
  MLDataType dataType = tensor.DataType();
  if (dataType == DataTypeImpl::GetType<float>()) {
    DumpCpuTensor<float>(name, tensor.Data<float>(), dim0, dim1, dim2);
  } else if (dataType == DataTypeImpl::GetType<MLFloat16>()) {
    DumpCpuTensor<MLFloat16>(name, tensor.Data<MLFloat16>(), dim0, dim1, dim2);
  } else if (dataType == DataTypeImpl::GetType<int32_t>()) {
    DumpCpuTensor<int32_t>(name, tensor.Data<int32_t>(), dim0, dim1, dim2);
  } else if (dataType == DataTypeImpl::GetType<int64_t>()) {
    DumpCpuTensor<int64_t>(name, tensor.Data<int64_t>(), dim0, dim1, dim2);
  } else {
    assert(0);
  }
}

void DumpCpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1) {
  MLDataType dataType = tensor.DataType();
  if (dataType == DataTypeImpl::GetType<float>()) {
    DumpCpuTensor<float>(name, tensor.Data<float>(), dim0, dim1);
  } else if (dataType == DataTypeImpl::GetType<MLFloat16>()) {
    DumpCpuTensor<MLFloat16>(name, tensor.Data<MLFloat16>(), dim0, dim1);
  } else if (dataType == DataTypeImpl::GetType<int32_t>()) {
    DumpCpuTensor<int32_t>(name, tensor.Data<int32_t>(), dim0, dim1);
  } else if (dataType == DataTypeImpl::GetType<int64_t>()) {
    DumpCpuTensor<int64_t>(name, tensor.Data<int64_t>(), dim0, dim1);
  } else {
    assert(0);
  }
}

void DumpCpuTensor(const char* name, const Tensor& tensor) {
  const auto& shape = tensor.Shape();

  if (nullptr != name) {
    std::cout << std::string(name) << std::endl;
  }
  std::cout << "Shape:" << shape << std::endl;

  size_t num_dims = shape.NumDimensions();
  if (num_dims >= 3) {
    int dim0 = static_cast<int>(shape.SizeToDimension(num_dims - 2));
    int dim1 = static_cast<int>(shape[num_dims - 2]);
    int dim2 = static_cast<int>(shape[num_dims - 1]);
    DumpCpuTensor(nullptr, tensor, dim0, dim1, dim2);
    return;
  }

  auto num_items = shape.Size();
  size_t num_rows = 1;
  if (num_dims > 1) {
    num_rows = static_cast<size_t>(shape[0]);
  }
  size_t row_size = num_items / num_rows;
  DumpCpuTensor(nullptr, tensor, static_cast<int>(num_rows), static_cast<int>(row_size));
}

// Reads the dumper configuration from the environment:
//   ORT_ENABLE_CPU_DUMP (default 1) - any non-zero value enables dumping for this instance.
//   ORT_DUMP_THREAD_ID  (default 0) - any non-zero value prefixes dumps with the thread id
//                                     (stored in the file-wide s_output_thread_id flag).
CpuTensorConsoleDumper::CpuTensorConsoleDumper() {
  const int enable_setting = ParseEnvironmentVariableWithDefault<int>(kEnableCpuTensorDumper, 1);
  is_enabled_ = (enable_setting != 0);

  const int thread_id_setting = ParseEnvironmentVariableWithDefault<int>(kDumpThreadId, 0);
  s_output_thread_id = (thread_id_setting != 0);
}

// Writes `value` on its own line (optionally preceded by a "Thread ID:" line) while
// holding s_mutex. Does nothing when dumping is disabled.
void CpuTensorConsoleDumper::Print(const std::string& value) const {
  if (is_enabled_) {
    std::lock_guard<std::mutex> guard(s_mutex);
    if (s_output_thread_id) {
      std::cout << "Thread ID:" << std::this_thread::get_id() << std::endl;
    }
    std::cout << value << std::endl;
  }
}

// 2-D overloads: when dumping is enabled, forward to the matching DumpCpuTensor<T>
// instantiation with the dim0 x dim1 view; otherwise do nothing.
void CpuTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1) const {
  if (is_enabled_) {
    DumpCpuTensor<float>(name, tensor, dim0, dim1);
  }
}

void CpuTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const {
  if (is_enabled_) {
    DumpCpuTensor<MLFloat16>(name, tensor, dim0, dim1);
  }
}

void CpuTensorConsoleDumper::Print(const char* name, const size_t* tensor, int dim0, int dim1) const {
  if (is_enabled_) {
    DumpCpuTensor<size_t>(name, tensor, dim0, dim1);
  }
}

void CpuTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1) const {
  if (is_enabled_) {
    DumpCpuTensor<int64_t>(name, tensor, dim0, dim1);
  }
}

void CpuTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1) const {
  if (is_enabled_) {
    DumpCpuTensor<int32_t>(name, tensor, dim0, dim1);
  }
}

// 3-D overloads: when dumping is enabled, forward to the matching DumpCpuTensor<T>
// instantiation with the dim0 x dim1 x dim2 view; otherwise do nothing.
void CpuTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const {
  if (is_enabled_) {
    DumpCpuTensor<float>(name, tensor, dim0, dim1, dim2);
  }
}

void CpuTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2) const {
  if (is_enabled_) {
    DumpCpuTensor<MLFloat16>(name, tensor, dim0, dim1, dim2);
  }
}

void CpuTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const {
  if (is_enabled_) {
    DumpCpuTensor<int64_t>(name, tensor, dim0, dim1, dim2);
  }
}

void CpuTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const {
  if (is_enabled_) {
    DumpCpuTensor<int32_t>(name, tensor, dim0, dim1, dim2);
  }
}

// 4-D overloads: when dumping is enabled, merge the two leading dimensions into one
// and reuse the 3-D DumpCpuTensor<T> path ((dim0 * dim1) x dim2 x dim3).
void CpuTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1, int dim2, int dim3) const {
  if (is_enabled_) {
    const int merged_leading = dim0 * dim1;
    DumpCpuTensor<float>(name, tensor, merged_leading, dim2, dim3);
  }
}

void CpuTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const {
  if (is_enabled_) {
    const int merged_leading = dim0 * dim1;
    DumpCpuTensor<MLFloat16>(name, tensor, merged_leading, dim2, dim3);
  }
}

void CpuTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2, int dim3) const {
  if (is_enabled_) {
    const int merged_leading = dim0 * dim1;
    DumpCpuTensor<int64_t>(name, tensor, merged_leading, dim2, dim3);
  }
}

void CpuTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2, int dim3) const {
  if (is_enabled_) {
    const int merged_leading = dim0 * dim1;
    DumpCpuTensor<int32_t>(name, tensor, merged_leading, dim2, dim3);
  }
}

// Dumps a whole Tensor (name, shape and data) via DumpCpuTensor when dumping is enabled.
void CpuTensorConsoleDumper::Print(const char* name, const Tensor& tensor) const {
  if (is_enabled_) {
    DumpCpuTensor(name, tensor);
  }
}

void CpuTensorConsoleDumper::Print(const char* name, const OrtValue& value) const {
  const Tensor& tensor = value.Get<Tensor>();
  Print(name, tensor);
}

// Writes "name[index]" to stdout while holding s_mutex, followed by a newline when
// `end_line` is true. Does nothing when dumping is disabled.
void CpuTensorConsoleDumper::Print(const char* name, int index, bool end_line) const {
  if (!is_enabled_)
    return;

  std::unique_lock<std::mutex> lock(s_mutex);
  // A null name is printed as empty: both std::string(nullptr) and streaming a null
  // const char* are undefined behavior.
  std::cout << (nullptr != name ? name : "") << "[" << index << "]";

  if (end_line) {
    std::cout << std::endl;
  }
}

// Writes "name=value" to stdout while holding s_mutex, followed by a newline when
// `end_line` is true. Does nothing when dumping is disabled.
void CpuTensorConsoleDumper::Print(const char* name, const std::string& value, bool end_line) const {
  if (!is_enabled_)
    return;

  std::unique_lock<std::mutex> lock(s_mutex);
  // A null name is printed as empty: both std::string(nullptr) and streaming a null
  // const char* are undefined behavior.
  std::cout << (nullptr != name ? name : "") << "=" << value;

  if (end_line) {
    std::cout << std::endl;
  }
}

// Overloads taking an arbitrary-rank shape as a span of dims. Each forwards to the
// PrintTensorByDims helper (defined outside this file) with this dumper as the first
// argument. Unlike the other Print overloads, these do not test is_enabled_ directly.
// NOTE(review): presumably PrintTensorByDims dispatches back into one of the fixed-rank
// Print overloads above, which do check is_enabled_ — confirm against its definition.
void CpuTensorConsoleDumper::Print(const char* name, const int32_t* tensor, gsl::span<const int64_t>& dims) const {
  PrintTensorByDims<CpuTensorConsoleDumper, int32_t>(this, name, tensor, dims);
}

void CpuTensorConsoleDumper::Print(const char* name, const int64_t* tensor, gsl::span<const int64_t>& dims) const {
  PrintTensorByDims<CpuTensorConsoleDumper, int64_t>(this, name, tensor, dims);
}

void CpuTensorConsoleDumper::Print(const char* name, const float* tensor, gsl::span<const int64_t>& dims) const {
  PrintTensorByDims<CpuTensorConsoleDumper, float>(this, name, tensor, dims);
}

void CpuTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, gsl::span<const int64_t>& dims) const {
  PrintTensorByDims<CpuTensorConsoleDumper, MLFloat16>(this, name, tensor, dims);
}

#else

CpuTensorConsoleDumper::CpuTensorConsoleDumper() {
}

void CpuTensorConsoleDumper::Print(const std::string&) const {
}

void CpuTensorConsoleDumper::Print(const char*, const float*, int, int) const {
}

void CpuTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int) const {
}

void CpuTensorConsoleDumper::Print(const char*, const size_t*, int, int) const {
}

void CpuTensorConsoleDumper::Print(const char*, const int64_t*, int, int) const {
}

void CpuTensorConsoleDumper::Print(const char*, const int32_t*, int, int) const {
}

void CpuTensorConsoleDumper::Print(const char*, const float*, int, int, int) const {
}

void CpuTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int, int) const {
}

void CpuTensorConsoleDumper::Print(const char*, const int64_t*, int, int, int) const {
}

void CpuTensorConsoleDumper::Print(const char*, const int32_t*, int, int, int) const {
}

void CpuTensorConsoleDumper::Print(const char*, const float*, int, int, int, int) const {
}

void CpuTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int, int, int) const {
}

void CpuTensorConsoleDumper::Print(const char*, const int64_t*, int, int, int, int) const {
}

void CpuTensorConsoleDumper::Print(const char*, const int32_t*, int, int, int, int) const {
}

void CpuTensorConsoleDumper::Print(const char*, const Tensor&) const {
}

void CpuTensorConsoleDumper::Print(const char*, const OrtValue&) const {
}

void CpuTensorConsoleDumper::Print(const char*, int, bool) const {
}

void CpuTensorConsoleDumper::Print(const char*, const std::string&, bool) const {
}

void CpuTensorConsoleDumper::Print(const char*, const int32_t*, gsl::span<const int64_t>&) const {
}

void CpuTensorConsoleDumper::Print(const char*, const int64_t*, gsl::span<const int64_t>&) const {
}

void CpuTensorConsoleDumper::Print(const char*, const float*, gsl::span<const int64_t>&) const {
}

void CpuTensorConsoleDumper::Print(const char*, const MLFloat16*, gsl::span<const int64_t>&) const {
}
#endif

}  // namespace contrib
}  // namespace onnxruntime
