#include <queue>

#include "mdp.hpp"
#include "utils.hpp"

using namespace std;

// Implementation constants (names suggest an L3 cache geometry; not
// referenced by the definitions in this part of the file)
constexpr uint32_t L3BlockSize{64};   // bytes per block
constexpr uint32_t L3BlockPerSet{12}; // blocks per set
constexpr uint32_t L3NumSets{8192};   // number of sets
// Total size in bytes: 64 * 12 * 8192 = 6,291,456 (6 MiB)
constexpr uint32_t L3Size{L3NumSets * L3BlockPerSet * L3BlockSize};

// Renumber the states so that every partition (SCC) i owns the contiguous id
// range [n_i..n_{i+1}[, then rebuild the MDP arrays with those ids.
//
// type selects the id-assignment routine: ReorderType::Extra calls
// findNewIdsExtra(); any other value (ReorderType::IntraExtra in this file)
// calls findNewIdsExtraIntra().
//
// newIds and oldIds are both filled with numeric_limits<uint32_t>::max()
// before ids are assigned, so that value marks "no id assigned".
void MDP::reorderStates(ReorderType type) {
  constexpr uint32_t kUnassigned = numeric_limits<uint32_t>::max();

  // Phase 2.1: pick a new id for every state
  const TimeVar idsStart = timeNow();
  LOG(INFO) << "e[i]TVI Phase 2.1: compute new ids order";
  sccStartIds = avector<uint32_t>(size(SCC) + 1);
  newIds = avector<uint32_t>(n_states, kUnassigned);
  oldIds = avector<uint32_t>(n_states, kUnassigned);

  switch(type) {
    case ReorderType::Extra:
      findNewIdsExtra();
      break;
    default: // every other reorder type, i.e. ReorderType::IntraExtra
      findNewIdsExtraIntra();
      break;
  }
  const long idsElapsedMs = duration(timeNow() - idsStart);

  // Phase 2.2: rewrite the MDP representation in terms of the new ids
  const TimeVar csrStart = timeNow();
  LOG(INFO) << "e[i]TVI Phase 2.2: Rebuild CSR";
  rebuildCSR();
  const long csrElapsedMs = duration(timeNow() - csrStart);

  LOG(INFO) << "Running time of Phase 2.1: " << idsElapsedMs << " ms";
  LOG(INFO) << "Running time of Phase 2.2: " << csrElapsedMs << " ms";
}

// Assign new ids SCC by SCC, keeping the order in which the states are stored
// inside each SCC. Ids are taken from the running counter currentId.
// Fills newIds (old -> new), oldIds (new -> old) and, for every SCC k,
// sccStartIds[k + 1] with the first id that follows SCC k.
void MDP::findNewIdsExtra() {
  const auto n_components = size(SCC);
  for(uint32_t scc_idx = 0; scc_idx < n_components; ++scc_idx) {
    for(const uint32_t state : SCC[scc_idx]) {
      const uint32_t assigned = currentId;
      ++currentId;
      newIds[state] = assigned;
      oldIds[assigned] = state;
    }
    // One past the last id of this SCC == first id of the next one
    sccStartIds[scc_idx + 1] = currentId;
  }
}

void MDP::findNewIdsExtraIntra() {
  findSCCsBorderStates();
  visited.assign(n_states, false);

  for(uint32_t k = 0; k < size(SCC); ++k) {
    // Reversed BFS from outward SCC border states
    queue<uint32_t> q;
    const auto& outwardBorderStates = SCCsOutwardBorderStates[k];
    for(const auto s_id : outwardBorderStates) {
      q.push(s_id);
      visited[s_id] = true;
    }

    while(!q.empty()) {
      uint32_t s_id = q.front();
      q.pop();
      for(const auto& [s_pred, _1, _2, _3] : predecessors[s_id]) {
        if(!visited[s_pred] && stateToSCC[s_pred] == k) {
          visited[s_pred] = true;
          q.push(s_pred);
        }
      }

      oldIds[currentId] = s_id;
      newIds[s_id] = currentId++;
    }

    sccStartIds[k + 1] = currentId; // Save start of next (end of current) SCC
  }
}

void MDP::newIdsSCCDFSPostorder(uint32_t s_id) {
  // Postorder DFS starting from given state
  const uint32_t scc_id = stateToSCC[s_id];
  visited[s_id] = true;

  for(uint32_t a_id = states[s_id], a_end = states[s_id + 1];
      a_id < a_end; ++a_id) {
    for(uint32_t e_id = actions[a_id], e_end = actions[a_id + 1];
        e_id < e_end; ++e_id) {
      const uint32_t neighbor_id = neighbors[e_id];
      if(!visited[neighbor_id] && stateToSCC[neighbor_id] == scc_id) {
        visited[neighbor_id] = true;
        newIdsSCCDFSPostorder(neighbor_id);
      }
    }
  }

  oldIds[currentId] = s_id;
  newIds[s_id] = currentId++;
}

// Rebuild the CSR arrays of the MDP (states, actions, costs, neighbors,
// probabilities) and the value vector V so that they are indexed by the new
// state ids stored in newIds / oldIds.
//
// - start_id is translated through newIds and goal_id is set to 0.
// - n_states becomes currentId (the number of ids handed out); the previous
//   value is saved in n_states_before.
// - An action is dropped entirely when any of its effects targets a state
//   without a new id (newIds == numeric_limits<uint32_t>::max()), i.e. the
//   action leads to a dead-end.
//
// The new arrays keep the sizes of the old ones, so entries past the last
// written index are unused when actions were dropped.
void MDP::rebuildCSR() {
  constexpr uint32_t kNoNewId = numeric_limits<uint32_t>::max();

  start_id = newIds[start_id];
  goal_id = 0; // will always be true after reordering
  n_states_before = n_states;
  n_states = currentId;

  // We need temporary new vectors to store the MDP
  avector<float> nV(n_states);
  avector<uint32_t> nStates(n_states + 1);
  avector<float> nCosts(size(costs));
  avector<uint32_t> nActions(size(actions));
  avector<uint32_t> nNeighbors(size(neighbors));
  avector<float> nProbabilities(size(probabilities));

  uint32_t a_new_id = 0; // next free slot in nActions / nCosts
  uint32_t e_new_id = 0; // next free slot in nNeighbors / nProbabilities
  for(uint32_t new_id = 0; new_id < n_states; ++new_id) {
    const uint32_t old_id = oldIds[new_id];
    uint32_t n_kept_actions = 0;

    for(uint32_t a_old_id = states[old_id], a_end = states[old_id + 1];
        a_old_id < a_end; ++a_old_id) {
      // Remember where this action's effects start so that a partially
      // copied action can be undone if it turns out to hit a dead-end.
      const uint32_t e_action_start = e_new_id;
      bool leads_to_dead_end = false;

      for(uint32_t e_old_id = actions[a_old_id], e_end = actions[a_old_id + 1];
          e_old_id < e_end; ++e_old_id) {
        const uint32_t neighbor_new_id = newIds[neighbors[e_old_id]];
        if(neighbor_new_id == kNoNewId) [[unlikely]] {
          leads_to_dead_end = true;
          break;
        }

        nNeighbors[e_new_id] = neighbor_new_id;
        nProbabilities[e_new_id] = probabilities[e_old_id];
        ++e_new_id;
      }

      if(leads_to_dead_end) [[unlikely]] {
        // Discard the effects already copied for this action; otherwise the
        // effects of the following actions would no longer line up with the
        // offsets stored in nActions.
        e_new_id = e_action_start;
        continue;
      }

      const uint32_t n_effects = actions[a_old_id + 1] - actions[a_old_id];
      nActions[a_new_id + 1] = nActions[a_new_id] + n_effects;
      nCosts[a_new_id] = costs[a_old_id];
      ++a_new_id;
      ++n_kept_actions;
    }

    nStates[new_id + 1] = nStates[new_id] + n_kept_actions;
    nV[new_id] = V[old_id];
  }

  // Keep new vectors and throw away old ones
  V             = move(nV);
  states        = move(nStates);
  costs         = move(nCosts);
  actions       = move(nActions);
  neighbors     = move(nNeighbors);
  probabilities = move(nProbabilities);
}

void MDP::findSCCsBorderStates() {
  SCCsOutwardBorderStates = avector<unordered_set<uint32_t>>(size(SCC));

  for(uint32_t s_id = 0; s_id < n_states; ++s_id) {
    const uint32_t current_scc = stateToSCC[s_id];
    for(uint32_t a_id = states[s_id], a_end = states[s_id + 1];
        a_id < a_end; ++a_id) {
      for(uint32_t effect_id = actions[a_id], effect_end = actions[a_id + 1];
          effect_id < effect_end; ++effect_id) {
        const uint32_t s_prime_id = neighbors[effect_id];
        const uint32_t neighbor_scc = stateToSCC[s_prime_id];
        if(neighbor_scc == current_scc)
          continue;
        SCCsOutwardBorderStates[current_scc].insert(s_id);
      }
    }
  }

  // Add an arbitrary state to SCCsOutwardBorderStates if empty
  for(uint32_t k = 0; k < size(SCC); ++k)
    if(SCCsOutwardBorderStates[k].empty()) [[unlikely]]
      SCCsOutwardBorderStates[k].insert(SCC[k][0]);
}

// Run the three phases of e[i]TVI and log the running time of each:
//   1. tarjan()            - compute the SCCs of the MDP
//   2. reorderStates(type) - renumber the states SCC by SCC
//   3. eTVIPartial()       - solve every SCC but the first, in id order
// When FLAGS_benchmark is set, the three timings are also written to stdout,
// each followed by a tab.
void MDP::reorderedTVI(ReorderType type) {
  LOG(INFO) << "e[i]TVI Phase 1: compute the SCCs of the MDP";
  TimeVar phaseStart = timeNow();
  tarjan();
  const long tarjanMs = duration(timeNow() - phaseStart);

  LOG(INFO) << "e[i]TVI Phase 2: reorder every states";
  phaseStart = timeNow();
  reorderStates(type);
  const long reorderMs = duration(timeNow() - phaseStart);

  phaseStart = timeNow();
  LOG(INFO) << "e[i]TVI Phase 3: solve each SCC in reverse topological order";
  const uint32_t sccCount = static_cast<uint32_t>(size(sccStartIds)) - 1;

  // SCC 0 is the goal SCC, so solving starts at SCC 1
  uint32_t scc = 1;
  while(scc < sccCount) {
    const uint32_t first_id = sccStartIds[scc];
    const uint32_t past_last_id = sccStartIds[scc + 1];
    eTVIPartial(first_id, past_last_id);
    ++scc;
  }
  const long solveMs = duration(timeNow() - phaseStart);

  LOG(INFO) << "Running time of e[i]TVI (Tarjan): " << tarjanMs << " ms";
  LOG(INFO) << "Running time of e[i]TVI (Reorder): " << reorderMs << " ms";
  LOG(INFO) << "Running time of e[i]TVI (solve SCCs): " << solveMs << " ms";
  if(FLAGS_benchmark)
    cout << tarjanMs << "\t" << reorderMs << "\t" << solveMs << "\t";
}

void MDP::eTVI() {
  reorderedTVI(ReorderType::Extra);
}

void MDP::eiTVI() {
  reorderedTVI(ReorderType::IntraExtra);
}

// Run value iteration restricted to the states with ids in [init_id, end_id[.
//
// Each sweep replaces V[s] by greedyBestValue(s) for every state of the range,
// in increasing id order, and records the largest absolute change in
// bellmanError. Sweeps repeat until bellmanError <= VI_Epsilon.
// n_backups is increased by (number of sweeps) * (number of states).
//
// A range of exactly one state is updated once and returns immediately; that
// path does not modify bellmanError or n_backups.
void MDP::eTVIPartial(uint32_t init_id, uint32_t end_id) {
  const uint32_t n_elements = end_id - init_id;
  VLOG(1) << "eTVIPartial started with " << n_elements << " states";

  // Small optimization when the subset has only 1 state
  if(n_elements == 1) { // the partition is only { init_id }
    V[init_id] = greedyBestValue(init_id);
    return;
  }

  uint32_t n_sweeps = 0;
  do {
    ++n_sweeps;
    bellmanError = 0.0f;

    for(uint32_t s_id = init_id; s_id < end_id; ++s_id) {
      const float oldV = V[s_id];
      V[s_id] = greedyBestValue(s_id);
      const float stateResidual = abs(V[s_id] - oldV);
      bellmanError = max(bellmanError, stateResidual);
    }

    VLOG_EVERY_N(1, 100) << "Current eTVIPartial residual: " << bellmanError;
  } while(bellmanError > VI_Epsilon);

  VLOG(1) << "Number of sweeps in the eTVIPartial: " << n_sweeps;
  // Multiply in 64 bits: sweeps * states can exceed the 32-bit range on large
  // partitions, which would wrap before being added to the counter.
  n_backups += static_cast<uint64_t>(n_sweeps) * n_elements;
}
