/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/relax/type.h
 * \brief Relax Types.
 */
#ifndef TVM_RELAX_TYPE_H_
#define TVM_RELAX_TYPE_H_

#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/reflection.h>
#include <tvm/ir/attrs.h>
#include <tvm/ir/env_func.h>
#include <tvm/ir/type.h>
#include <tvm/tir/expr.h>

#include <string>

namespace tvm {
namespace relax {

/*!
 * \brief Sentinel value (-1) for an `ndim` field, indicating that the number of dimensions
 *  is unknown at compile time.
 */
static constexpr int kUnknownNDim = -1;

/*!
 * \brief Type node for shapes; its only field is the number of dimensions.
 */
class ShapeTypeNode : public TypeNode {
 public:
  /*! \brief Number of dimensions (size) of the shape. */
  int ndim;

  /*! \brief Register the `ndim` field with the FFI reflection machinery. */
  static void RegisterReflection() {
    tvm::ffi::reflection::ObjectDef<ShapeTypeNode>().def_ro("ndim", &ShapeTypeNode::ndim);
  }

  /*!
   * \brief Structural-equality hook.
   * \param other The node to compare against.
   * \param equal The reducer used to compare the fields.
   * \return Whether `equal` accepts the two `ndim` values.
   */
  bool SEqualReduce(const ShapeTypeNode* other, SEqualReducer equal) const {
    const bool same_ndim = equal(ndim, other->ndim);
    return same_ndim;
  }

  /*!
   * \brief Structural-hash hook; feeds `ndim` into the reducer.
   * \param hash_reduce The reducer accumulating the hash.
   */
  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce(ndim);
  }

  static constexpr const char* _type_key = "relax.ShapeType";
  TVM_DECLARE_FINAL_OBJECT_INFO(ShapeTypeNode, TypeNode);
};

/*!
 * \brief Managed reference to ShapeTypeNode.
 * \sa ShapeTypeNode
 */
class ShapeType : public Type {
 public:
  // TODO(relax-team): remove the default value later.
  /*!
   * \brief Constructor.
   * \param ndim The number of dimensions of the shape; defaults to kUnknownNDim.
   * \param span The span; defaults to a default-constructed Span.
   */
  TVM_DLL ShapeType(int ndim = kUnknownNDim, Span span = Span());

  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ShapeType, Type, ShapeTypeNode);
};

/*!
 * \brief Dynamic version of TensorType
 *
 * Use relax::TensorStructInfo for more detailed (possibly dynamic) shape constrains
 */
class TensorTypeNode : public TypeNode {
 public:
  /*!
   * \brief The number of dimensions of the tensor, use -1 to denote tensor with unknown number of
   * dimensions.
   */
  int ndim;
  /*! \brief The content data type, use void to denote the dtype is unknown. */
  DataType dtype;

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<TensorTypeNode>()
        .def_ro("ndim", &TensorTypeNode::ndim)
        .def_ro("dtype", &TensorTypeNode::dtype);
  }

  bool SEqualReduce(const TensorTypeNode* other, SEqualReducer equal) const {
    return equal(ndim, other->ndim) && equal(dtype, other->dtype);
  }

  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce(ndim);
    hash_reduce(dtype);
  }

  inline bool IsUnknownNdim() const { return ndim == kUnknownNDim; }

  inline bool IsUnknownDtype() const { return dtype.is_void(); }

  static constexpr const char* _type_key = "relax.DynTensorType";
  TVM_DECLARE_FINAL_OBJECT_INFO(TensorTypeNode, TypeNode);
};

/*!
 * \brief Managed reference to TensorTypeNode.
 * \sa TensorTypeNode.
 *
 * NOTE(review): this class uses TVM_DEFINE_OBJECT_REF_METHODS while the other reference
 * classes in this header use TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS -- confirm that the
 * difference is intentional.
 */
class TensorType : public Type {
 public:
  /*!
   * \brief Constructor.
   * \param ndim The number of dimensions of the tensor.
   * \param dtype The runtime dtype of the tensor's elements.
   * \param span The span.
   */
  TVM_DLL TensorType(int ndim, DataType dtype, Span span = Span());

  /*!
   * \brief Create a TensorType with unknown ndim.
   * \param dtype The runtime dtype of the tensor's elements.
   * \param span The span.
   * \return The created TensorType.
   */
  TVM_DLL static TensorType CreateUnknownNDim(DataType dtype, Span span = Span());

  TVM_DEFINE_OBJECT_REF_METHODS(TensorType, Type, TensorTypeNode);
};

// NOTE(review): both aliases below name a type as itself, so they introduce no new name.
// They look like leftovers from a rename (the node's type key is still
// "relax.DynTensorType") -- confirm whether they can be removed.
using TensorTypeNode = TensorTypeNode;
using TensorType = TensorType;

/*!
 * \brief Type node without any fields, registered under the key "relax.ObjectType".
 */
class ObjectTypeNode : public TypeNode {
 public:
  /*! \brief Register the node with the FFI reflection machinery; there are no fields to add. */
  static void RegisterReflection() {
    tvm::ffi::reflection::ObjectDef<ObjectTypeNode>();
  }

  /*!
   * \brief Structural-equality hook.
   * \param other The node to compare against (unused).
   * \param equal The reducer (unused).
   * \return Always true, since the node carries no fields to compare.
   */
  bool SEqualReduce(const ObjectTypeNode* other, SEqualReducer equal) const {
    return true;
  }

  /*!
   * \brief Structural-hash hook; contributes the constant 0.
   * \param hash_reduce The reducer accumulating the hash.
   */
  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce(0);
  }

  static constexpr const char* _type_key = "relax.ObjectType";
  TVM_DECLARE_FINAL_OBJECT_INFO(ObjectTypeNode, TypeNode);
};

/*!
 * \brief Managed reference to ObjectTypeNode.
 * \sa ObjectTypeNode
 */
class ObjectType : public Type {
 public:
  /*!
   * \brief Constructor.
   * \param span The span; defaults to a default-constructed Span.
   */
  TVM_DLL ObjectType(Span span = Span());

  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectType, Type, ObjectTypeNode);
};

class PackedFuncTypeNode : public TypeNode {
 public:
  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<PackedFuncTypeNode>();
  }

  bool SEqualReduce(const PackedFuncTypeNode* other, SEqualReducer equal) const { return true; }

  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); }

  static constexpr const char* _type_key = "relax.PackedFuncType";
  TVM_DECLARE_FINAL_OBJECT_INFO(PackedFuncTypeNode, TypeNode);
};

/*!
 * \brief Managed reference to PackedFuncTypeNode.
 * \sa PackedFuncTypeNode
 */
class PackedFuncType : public Type {
 public:
  /*!
   * \brief Constructor.
   * \param span The span; defaults to a default-constructed Span.
   */
  TVM_DLL PackedFuncType(Span span = Span());

  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PackedFuncType, Type, PackedFuncTypeNode);
};

}  // namespace relax
}  // namespace tvm
#endif  // TVM_RELAX_TYPE_H_
