
#include <flashinfer/attention/decode_mla_cute_sm80.cuh>
#include <flashinfer/attention/scheduler.cuh>

#include "mla_config.inc"
#include "tvm/ffi/container/array.h"
#include "tvm_ffi_utils.h"

using namespace flashinfer;

using tvm::ffi::Array;
using tvm::ffi::Optional;

// Plan stage of the SM80 CuTe MLA batch decode.
//
// Forwards the workspace buffers and problem sizes to DecodePlan (using the MLA CuTe SM80 work
// estimation function) and returns the resulting DecodePlanInfo flattened with ToVector(). The
// returned array is what BatchDecodeWithPagedKVCacheRunMLA expects as `plan_info_vec`.
//
// Parameters:
//   float_workspace_buffer           - workspace passed to DecodePlan; its byte size is computed
//                                      as shape[0] * element size. Its device also selects the
//                                      CUDA device and stream used for planning.
//   int_workspace_buffer             - integer workspace; byte size computed the same way.
//   page_locked_int_workspace_buffer - passed to DecodePlan together with int_workspace_buffer
//                                      and its byte size.
//                                      NOTE(review): presumably pinned host memory of at least the
//                                      same size as int_workspace_buffer -- confirm with callers.
//   indptr                           - data pointer is reinterpreted as IdType*.
//                                      NOTE(review): presumably host-accessible -- TODO confirm.
//   batch_size, num_qo_heads, page_size, enable_cuda_graph - forwarded unchanged to DecodePlan.
//
// Fails (TVM_FFI_ICHECK) if selecting the device or DecodePlan reports a CUDA error.
Array<int64_t> BatchDecodeWithPagedKVCachePlanMLA(ffi::Tensor float_workspace_buffer,
                                                  ffi::Tensor int_workspace_buffer,
                                                  ffi::Tensor page_locked_int_workspace_buffer,
                                                  ffi::Tensor indptr, int64_t batch_size,
                                                  int64_t num_qo_heads, int64_t page_size,
                                                  bool enable_cuda_graph) {
  size_t float_workspace_size_in_bytes =
      float_workspace_buffer->shape[0] * get_element_size(float_workspace_buffer);
  size_t int_workspace_size_in_bytes =
      int_workspace_buffer->shape[0] * get_element_size(int_workspace_buffer);

  DecodePlanInfo plan_info;

  // Do not ignore the result: an invalid device id (or an earlier sticky CUDA error) would
  // otherwise only show up later as a confusing failure inside DecodePlan.
  cudaError_t set_device_status = cudaSetDevice(float_workspace_buffer->device.device_id);
  TVM_FFI_ICHECK(set_device_status == cudaSuccess)
      << "BatchDecodeWithPagedKVCachePlanMLA failed to set device with error "
      << cudaGetErrorString(set_device_status);
  const cudaStream_t stream = get_stream(float_workspace_buffer->device);

  auto work_estimation_func = BatchDecodeWithPagedKVCacheWorkEstimationDispatchedMlaCuteSM80<
      HEAD_DIM_CKV, HEAD_DIM_KPE, QO_TILE_LEN, AttentionVariant, Params>;
  cudaError_t status =
      DecodePlan<HEAD_DIM_CKV, flashinfer::PosEncodingMode::kNone, AttentionVariant, Params>(
          static_cast<void*>(float_workspace_buffer->data), float_workspace_size_in_bytes,
          static_cast<void*>(int_workspace_buffer->data),
          static_cast<void*>(page_locked_int_workspace_buffer->data), int_workspace_size_in_bytes,
          plan_info, static_cast<IdType*>(indptr->data), batch_size, num_qo_heads, page_size,
          enable_cuda_graph, /*stream=*/stream, work_estimation_func);

  TVM_FFI_ICHECK(status == cudaSuccess)
      << "BatchDecodeWithPagedKVCachePlanMLA failed with error " << cudaGetErrorString(status);

  return Array(plan_info.ToVector());
}

// Run stage of the SM80 CuTe MLA batch decode.
//
// Rebuilds the DecodePlanInfo from `plan_info_vec`, assembles the paged-KV descriptor and the
// kernel Params, resolves the scheduling arrays inside the workspace buffers from the offsets
// stored in the plan, and launches BatchDecodeWithPagedKVCacheDispatchedMlaCuteSM80 on the
// stream associated with `paged_ckv_cache`'s device.
//
// Parameters:
//   float_workspace_buffer / int_workspace_buffer - base pointers for the plan offsets.
//       NOTE(review): presumably the same buffers that were given to the plan function.
//   plan_info_vec          - flattened DecodePlanInfo (consumed via FromVector).
//   q_nope, q_pe           - query tensors, reinterpreted as DTypeQ*. batch_size and
//                            num_qo_heads are read from q_nope's shape[0] and shape[1].
//   paged_ckv_cache, paged_kpe_cache - paged caches, reinterpreted as DTypeKV*; page_size is read
//                            from paged_ckv_cache's shape[1]; their strides are forwarded.
//   paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len - reinterpreted as IdType*.
//   o                      - output tensor, reinterpreted as DTypeO*.
//   sm_scale, window_left, logits_soft_cap, rope_scale, rope_theta - forwarded to Params;
//                            logits_soft_cap must be non-negative.
//   maybe_lse              - optional tensor whose first two dimensions must equal
//                            (batch_size, num_qo_heads); its data is passed as float*.
//   enable_pdl             - unused placeholder (sm80 does not support pdl).
//
// Fails (TVM_FFI_ICHECK*) on invalid arguments or if a CUDA call reports an error.
void BatchDecodeWithPagedKVCacheRunMLA(
    ffi::Tensor float_workspace_buffer, ffi::Tensor int_workspace_buffer,
    Array<int64_t> plan_info_vec, ffi::Tensor q_nope, ffi::Tensor q_pe, ffi::Tensor paged_ckv_cache,
    ffi::Tensor paged_kpe_cache, ffi::Tensor paged_kv_indptr, ffi::Tensor paged_kv_indices,
    ffi::Tensor paged_kv_last_page_len, ffi::Tensor o, double sm_scale, int64_t window_left,
    double logits_soft_cap, double rope_scale, double rope_theta, Optional<ffi::Tensor> maybe_lse,
    bool enable_pdl  // fake placeholder, sm80 does not support pdl
) {
  DecodePlanInfo plan_info;
  plan_info.FromVector(std::vector<int64_t>(plan_info_vec.begin(), plan_info_vec.end()));

  // shape[0] and shape[1] are read below; make sure they exist before indexing so a
  // lower-rank tensor produces a clear error instead of an out-of-bounds read.
  TVM_FFI_ICHECK_GE(q_nope->ndim, 2) << "q_nope must have at least 2 dimensions";
  TVM_FFI_ICHECK_GE(paged_ckv_cache->ndim, 2)
      << "paged_ckv_cache must have at least 2 dimensions";

  int64_t batch_size = q_nope->shape[0];
  int64_t num_qo_heads = q_nope->shape[1];
  int64_t page_size = paged_ckv_cache->shape[1];

  if (maybe_lse.has_value()) {
    const auto& lse = maybe_lse.value();
    TVM_FFI_ICHECK_GE(lse->ndim, 2) << "lse must have at least 2 dimensions";
    TVM_FFI_ICHECK_EQ(lse->shape[0], batch_size);
    TVM_FFI_ICHECK_EQ(lse->shape[1], num_qo_heads);
  }

  TVM_FFI_ICHECK_GE(logits_soft_cap, 0.f) << "logits_soft_cap must be non-negative";

  void* float_buffer = static_cast<void*>(float_workspace_buffer->data);
  void* int_buffer = static_cast<void*>(int_workspace_buffer->data);

  paged_kv_mla_t<DTypeKV, IdType> paged_kv(
      page_size, HEAD_DIM_CKV, HEAD_DIM_KPE, batch_size,
      static_cast<DTypeKV*>(paged_ckv_cache->data), paged_ckv_cache.strides().data(),
      static_cast<DTypeKV*>(paged_kpe_cache->data), paged_kpe_cache.strides().data(),
      static_cast<IdType*>(paged_kv_indices->data), static_cast<IdType*>(paged_kv_indptr->data),
      static_cast<IdType*>(paged_kv_last_page_len->data));
  Params params(
      static_cast<DTypeQ*>(q_nope->data), static_cast<DTypeQ*>(q_pe->data),
      /*q_offset=*/nullptr, paged_kv, static_cast<DTypeO*>(o->data),
      /*lse=*/(maybe_lse.has_value() ? static_cast<float*>(maybe_lse.value()->data) : nullptr),
      num_qo_heads, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta);

  // Scheduling arrays live inside the int workspace at offsets recorded by the plan.
  DTypeO* tmp_v = nullptr;
  float* tmp_s = nullptr;
  params.request_indices =
      GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.request_indices_offset);
  params.kv_tile_indices =
      GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.kv_tile_indices_offset);
  params.o_indptr = GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.o_indptr_offset);
  params.kv_chunk_size_ptr =
      GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.kv_chunk_size_ptr_offset);
  if (plan_info.split_kv) {
    // The temporary buffers (tmp_v / tmp_s) are only resolved when the plan splits KV;
    // otherwise they stay null.
    tmp_v = GetPtrFromBaseOffset<DTypeO>(float_buffer, plan_info.v_offset);
    tmp_s = GetPtrFromBaseOffset<float>(float_buffer, plan_info.s_offset);
    if (plan_info.enable_cuda_graph) {
      params.block_valid_mask =
          GetPtrFromBaseOffset<bool>(int_buffer, plan_info.block_valid_mask_offset);
    }
  }
  params.padded_batch_size = plan_info.padded_batch_size;

  // Do not ignore the result: a failure here would otherwise be reported later as a
  // misleading kernel-launch error.
  cudaError_t set_device_status = cudaSetDevice(paged_ckv_cache->device.device_id);
  TVM_FFI_ICHECK(set_device_status == cudaSuccess)
      << "BatchDecodeWithPagedKVCache failed to set device with error "
      << cudaGetErrorString(set_device_status);
  const cudaStream_t stream = get_stream(paged_ckv_cache->device);
  cudaError_t status = BatchDecodeWithPagedKVCacheDispatchedMlaCuteSM80<HEAD_DIM_CKV, HEAD_DIM_KPE,
                                                                        QO_TILE_LEN, Params>(
      params, tmp_v, tmp_s, /*stream=*/stream);
  TVM_FFI_ICHECK(status == cudaSuccess)
      << "BatchDecodeWithPagedKVCache failed with error " << cudaGetErrorString(status);
}
