// !!! This is a file automatically generated by hipify!!!
#pragma once

#include <c10/core/ScalarType.h>

#include <hip/hip_runtime.h>
#include <hip/library_types.h>

namespace at::cuda {

template <typename scalar_t>
hipDataType getCudaDataType() {
  static_assert(false && sizeof(scalar_t), "Cannot convert type to hipDataType.");
  return {};
}

template<> inline hipDataType getCudaDataType<at::Half>() {
  return HIP_R_16F;
}
template<> inline hipDataType getCudaDataType<float>() {
  return HIP_R_32F;
}
template<> inline hipDataType getCudaDataType<double>() {
  return HIP_R_64F;
}
template<> inline hipDataType getCudaDataType<c10::complex<c10::Half>>() {
  return HIP_C_16F;
}
template<> inline hipDataType getCudaDataType<c10::complex<float>>() {
  return HIP_C_32F;
}
template<> inline hipDataType getCudaDataType<c10::complex<double>>() {
  return HIP_C_64F;
}

template<> inline hipDataType getCudaDataType<uint8_t>() {
  return HIP_R_8U;
}
template<> inline hipDataType getCudaDataType<int8_t>() {
  return HIP_R_8I;
}
template<> inline hipDataType getCudaDataType<int>() {
  return HIP_R_32I;
}

template<> inline hipDataType getCudaDataType<int16_t>() {
  return HIP_R_16I;
}
template<> inline hipDataType getCudaDataType<int64_t>() {
  return HIP_R_64I;
}
template<> inline hipDataType getCudaDataType<at::BFloat16>() {
  return HIP_R_16BF;
}

inline hipDataType ScalarTypeToCudaDataType(const c10::ScalarType& scalar_type) {
  switch (scalar_type) {
    case c10::ScalarType::Byte:
      return HIP_R_8U;
    case c10::ScalarType::Char:
      return HIP_R_8I;
    case c10::ScalarType::Int:
      return HIP_R_32I;
    case c10::ScalarType::Half:
      return HIP_R_16F;
    case c10::ScalarType::Float:
      return HIP_R_32F;
    case c10::ScalarType::Double:
      return HIP_R_64F;
    case c10::ScalarType::ComplexHalf:
      return HIP_C_16F;
    case c10::ScalarType::ComplexFloat:
      return HIP_C_32F;
    case c10::ScalarType::ComplexDouble:
      return HIP_C_64F;
    case c10::ScalarType::Short:
      return HIP_R_16I;
    case c10::ScalarType::Long:
      return HIP_R_64I;
    case c10::ScalarType::BFloat16:
      return HIP_R_16BF;
#if !defined(USE_ROCM) || ROCM_VERSION >= 60300
    case c10::ScalarType::Float8_e4m3fn:
      return HIP_R_8F_E4M3;
    case c10::ScalarType::Float8_e5m2:
      return HIP_R_8F_E5M2;
#endif
#if defined(USE_ROCM)
    case c10::ScalarType::Float8_e4m3fnuz:
      return HIP_R_8F_E4M3_FNUZ;
    case c10::ScalarType::Float8_e5m2fnuz:
      return HIP_R_8F_E5M2_FNUZ;
#endif
#if (defined(TORCH_HIP_VERSION) && TORCH_HIP_VERSION >= 12080)
    case c10::ScalarType::Float4_e2m1fn_x2:
      return CUDA_R_4F_E2M1;
#endif
    default:
      TORCH_INTERNAL_ASSERT(false, "Cannot convert ScalarType ", scalar_type, " to hipDataType.")
  }
}

} // namespace at::cuda
