#include "mdp.hpp"
#include "utils.hpp"
#include <queue>
#include <omp.h>

using namespace std;

/**
 * Builds the condensation graph over the SCCs of this MDP.
 *
 * Every (state, action, effect) triple is visited through the offset arrays:
 * `states[s] .. states[s + 1]` is the range of action ids of state `s`, and
 * `actions[a] .. actions[a + 1]` is the range of effect ids of action `a`.
 * Whenever an effect references a state that belongs to a different SCC than
 * the source state, the source SCC is inserted into `revNeighbors` of the
 * target SCC, i.e. arcs are stored reversed. `numIncomingArcs[c]` ends up
 * holding the number of distinct other SCCs that SCC `c` has effects into
 * (which is the in-degree of `c` in the reversed graph).
 *
 * Reads `stateToSCC` and `SCC`, so the SCC decomposition must already be
 * available (pcTVI() calls tarjan() right before this function).
 *
 * @return the DAG with one node per entry of `SCC`.
 */
DAG MDP::buildCondensation() const {
  const auto scc_count = static_cast<uint32_t>(size(SCC));
  DAG condensation(scc_count);

  for(uint32_t state = 0; state < n_states; ++state) {
    const uint32_t source_scc = stateToSCC[state];
    const uint32_t last_action = states[state + 1];
    for(uint32_t action = states[state]; action < last_action; ++action) {
      const uint32_t last_effect = actions[action + 1];
      for(uint32_t effect = actions[action]; effect < last_effect; ++effect) {
        const uint32_t target_scc = stateToSCC[neighbors[effect]];
        // Arcs inside one SCC are irrelevant for the condensation; an arc
        // between two SCCs is counted only the first time it is seen.
        if(target_scc != source_scc
           && condensation.revNeighbors[target_scc].insert(source_scc).second) {
          ++condensation.numIncomingArcs[source_scc];
        }
      }
    }
  }
  return condensation;
}

void MDP::pcTVI() {
  LOG(INFO) << "pcTVI Phase 1: compute the SCCs of the MDP";
  TimeVar tStart = timeNow();
  tarjan();
  const long time_p1 = duration(timeNow() - tStart);

  LOG(INFO) << "pcTVI Phase 2: build a DAG of SCCs (tasks)";
  tStart = timeNow();
  DAG dag = buildCondensation();
  const long time_p2 = duration(timeNow() - tStart);

  tStart = timeNow();
  LOG(INFO) << "pcTVI Phase 3: solve in parallel independent chains of SCCs";
  queue<uint32_t> sccQueue;
  #pragma omp parallel shared(sccQueue)
  #pragma omp single
  {
    LOG(INFO) << "OpenMP created " << omp_get_num_threads() << " threads";
    sccQueue.push(0);
    while(!sccQueue.empty()) {
      const uint32_t scc = sccQueue.front();
      sccQueue.pop();
      for(const uint32_t neighbor : dag.revNeighbors[scc]) {
        if(--dag.numIncomingArcs[neighbor] == 0) {
          #pragma omp task
          {
            partialVI(SCC[neighbor]);

            #pragma omp critical
            sccQueue.push(neighbor);
          }
        }
      }
      if(sccQueue.empty()) {
        #pragma omp taskwait
      }
    }
  }
  const long time_p3 = duration(timeNow() - tStart);

  LOG(INFO) << "Running time of pcTVI (Tarjan): " << time_p1 << " ms";
  LOG(INFO) << "Running time of pcTVI (Build DAG): " << time_p2 << " ms";
  LOG(INFO) << "Running time of pcTVI (Solve SCCs in parallel): " << time_p3 << " ms";
}
