#include "mdp.hpp"
#include "utils.hpp"
#include <algorithm>
#include <iomanip>
#include <stack>
#include <queue>
#include <set>
#include <unordered_set>
#include <map>

using namespace std;

// Initial value of every entry of the upper bound V_u; it is also the starting
// minimum used by greedyBestAction() and greedyBestValueAction().
// This affects the tightness of the upper bound in FTVI.
constexpr float UpperBoundInitValue = 1'000'000.0f; // Dai/Jain: 9999.9f

// Entry point of the solvers: logs the configuration, runs the algorithm whose
// name matches `solverName`, then prints the verbose statistics.
// An unknown name is reported through LOG(FATAL).
void MDP::solve(const string& solverName) {
  LOG(INFO) << "Using epsilon = " << VI_Epsilon;
  LOG(INFO) << "Solving MDP using " << solverName;

  // Small predicate that keeps the dispatch below readable
  const auto requested = [&solverName](const char* candidate) {
    return solverName == candidate;
  };

  // Find V*
  if(requested("VI")) {
    VI();
  } else if(requested("LAOstar")) {
    LAOstar(start_id);
  } else if(requested("ILAOstar")) {
    ILAOstar(start_id);
  } else if(requested("RTDP")) {
    RTDP(start_id);
  } else if(requested("LRTDP")) {
    LRTDP(start_id);
  } else if(requested("BRTDP")) {
    BRTDP(start_id);
  } else if(requested("TVI")) {
    TVI();
  } else if(requested("FTVI")) {
    FTVI(start_id);
  } else if(requested("eTVI")) {
    eTVI();
  } else if(requested("eiTVI")) {
    eiTVI();
  } else if(requested("pcTVI")) {
    pcTVI();
  } else {
    LOG(FATAL) << "\"" << solverName << "\" is an unsupported solver";
  }

  printVerboseInfo();
}

void MDP::findAndEvaluateGreedyPolicy() {
  for(uint32_t s = 0; s < n_states; ++s)
    Pi[s] = greedyBestAction(s);

  LOG(INFO) << "Number of eliminated actions: " << getNumEliminatedActions();
  LOG(INFO) << "Policy simulated value: " << evaluatePolicy(start_id);
  LOG(INFO) << "Policy computed value: " << V[start_id];
}

// Initializes V with the heuristic called `heuristicName` ("NONE" or "H_MIN").
// Any other name is reported through LOG(FATAL). Afterwards the value of the
// start state is required to be finite.
void MDP::computeHeuristic(const string& heuristicName) {
  const bool wantsHMin = (heuristicName == "H_MIN");
  const bool wantsNone = (heuristicName == "NONE");

  if(wantsHMin) {
    LOG(INFO) << "Computing the h_min heuristic";
    lowerBoundHMin();
  } else if(wantsNone) {
    // V is initialized to 0 in operator>> (dead-ends initialized to infty)
    LOG(INFO) << "Using no heuristic function (h = 0)";
  } else {
    LOG(FATAL) << "Unsupported heuristic";
  }

  CHECK(V[start_id] != numeric_limits<float>::infinity())
    << "start has no finite-cost policy to goal";
}

void MDP::markReachableStates(uint32_t s_0) {
  stack<uint32_t> sStack;
  vector<bool> reachable(n_states, false);
  reachable[s_0] = true;
  sStack.push(s_0);

  while(!sStack.empty()) {
    const uint32_t s_id = sStack.top();
    sStack.pop();

    // Add neighbors of s_id to the stack if not already visited
    for(uint32_t a_id = states[s_id], a_end = states[s_id + 1];
        a_id < a_end; ++a_id) {
      for(uint32_t effect_id = actions[a_id], effect_end = actions[a_id + 1];
          effect_id < effect_end; ++effect_id) {
        const uint32_t neighbor_id = neighbors[effect_id];
        if(reachable[neighbor_id])
          continue;
        reachable[neighbor_id] = true;
        sStack.push(neighbor_id);
      }
    }
  }

  // set V[s] = infinity for every state 's' not reachable from s_0
  for(uint32_t s = 0; s < n_states; ++s)
    if(!reachable[s])
      V[s] = numeric_limits<float>::infinity();

  LOG(INFO) << "Number of reachable states: "
            << count(begin(reachable), end(reachable), true);
  CHECK(reachable[goal_id]) << "goal is not reachable from start";
}

// Prints the per-state value bounds and the partition statistics, then logs
// the number of Bellman backups (also written to stdout in benchmark mode).
void MDP::printVerboseInfo() const {
  printStatesActionsStats();
  printPartitionsInfo();

  LOG(INFO) << "Number of bellman backups: " << n_backups;
  if(!FLAGS_benchmark)
    return;

  cout << n_backups << "\t";
}

// Writes one line per original state id to `os`: the id, then either
// "Dead-end" (infinite value), the policy action, or "Goal", followed by V.
// When states were reordered (sccStartIds is not empty) ids are translated
// through newIds and pruned states are skipped.
void MDP::printPolicy(ostream& os) const {
  constexpr uint32_t Padding = 15;
  constexpr uint32_t PrunedId = numeric_limits<uint32_t>::max();
  const bool reordered = !sccStartIds.empty();
  const auto n_rows = reordered ? n_states_before : n_states;

  os << "s_id : " << left << setw(Padding) << "action_id" << "V(s_id)\n";
  for(uint32_t i = 0; i < n_rows; ++i) {
    // compute the id in case states were reordered
    const uint32_t id = reordered ? newIds[i] : i;
    if(id == PrunedId)
      continue; // the state was pruned during reordering

    os << right << setw(4) << i << left << " : " << setw(Padding);
    if(V[id] == numeric_limits<float>::infinity())
      os << "Dead-end";
    else if(!isTerminal(id))
      os << Pi[id];
    else // id == goal_id
      os << "Goal";
    os << V[id] << '\n';
  }
  os << flush;
}

// Returns (V[start_id], value obtained by simulating the policy from start_id).
pair<float, float> MDP::getExpectedAndSimulatedCost() const {
  const float expectedCost = V[start_id];
  const float simulatedCost = evaluatePolicy(start_id);
  return pair<float, float>(expectedCost, simulatedCost);
}

void MDP::VI() {
  avector<uint32_t> states_id;
  states_id.reserve(n_states);
  for(uint32_t s_id = 0; s_id < n_states; ++s_id)
    if(V[s_id] != numeric_limits<float>::infinity() && s_id != goal_id)
      states_id.push_back(s_id);

  partialVI(states_id);
}

// Value iteration restricted to the states listed in `subset`.
// Sweeps over the subset, replacing V[s] by greedyBestValue(s), until the
// largest change observed during a sweep (stored in bellmanError) is at most
// `epsilon`. n_backups is increased by (number of sweeps * subset size).
// A subset whose first element is the goal is skipped without touching
// bellmanError. An empty subset is a no-op that leaves bellmanError at 0.
void MDP::partialVI(const avector<uint32_t>& subset, float epsilon) {
  const uint32_t n_elements = static_cast<uint32_t>(size(subset));
  VLOG(1) << "partialVI started with " << n_elements << " states";

  // Guard: subset[0] below must not be read when there is nothing to solve.
  // An empty sweep has a residual of 0 and performs no backup.
  if(n_elements == 0) [[unlikely]] {
    VLOG(1) << "empty subset; skipping";
    bellmanError = 0.0f;
    return;
  }

  if(subset[0] == goal_id) [[unlikely]] {
    VLOG(1) << "goal state; skipping";
    return;
  }

  uint32_t n_sweeps = 0;
  do {
    ++n_sweeps;
    bellmanError = 0.0f;

    for(uint32_t s_id : subset) {
      const float oldV = V[s_id];
      V[s_id] = greedyBestValue(s_id);
      const float stateResidual = abs(V[s_id] - oldV);
      bellmanError = max(bellmanError, stateResidual);
    }

    VLOG_EVERY_N(1, 1) << "Current partialVI residual: " << bellmanError;
  } while(bellmanError > epsilon);

  VLOG(1) << "Number of sweeps in the partialVI: " << n_sweeps;
  n_backups += (n_sweeps * n_elements);
}

// LAO*-style solver started from s_0.
// Repeatedly takes one state out of the fringe F, moves it to the interior,
// adds its not-yet-seen successors to the fringe, then runs partialVI() on
// that state plus every state that reaches it through predecessorsInPolicy,
// and finally refreshes Pi (and predecessorsInPolicy) on those states.
// Stops when the fringe is empty, or when no fringe state is s_0 or has a
// recorded predecessor in the current policy.
void MDP::LAOstar(uint32_t s_0) {
  unordered_set<uint32_t> F; // Fringe states
  avector<bool> FContained(n_states, false); // membership flags mirroring F
  avector<bool> IContained(n_states, false); // Interior states
  IContained[goal_id] = true;

  // predecessorsInPolicy[s] = states whose current policy action can lead to s
  avector<unordered_set<uint32_t>> predecessorsInPolicy(n_states);

  F.emplace(s_0);
  FContained[s_0] = true;

  while(!F.empty()) {
    uint32_t s_id = n_states; // n_states is used as the "nothing selected" marker

    // Select a fringe state that is s_0 or has a predecessor in the policy.
    // The loop is left right after the erase, so the iteration over F is
    // never continued after F was modified.
    for(const uint32_t value: F) {
      if(value == s_0 || !predecessorsInPolicy[value].empty()) {
        s_id = value;
        F.erase(value);
        FContained[s_id] = false;
        IContained[s_id] = true;
        break;
      }
    }

    // No eligible fringe state is left
    if(s_id == n_states)
      break;

    // Case when action = Pi[s_id] => adding s to the predecessors to compute
    const uint32_t s_action = Pi[s_id];
    for(uint32_t e_id = actions[s_action], e_end = actions[s_action + 1];
        e_id < e_end; ++e_id) {
      const uint32_t neighbor_id = neighbors[e_id];
      if(!FContained[neighbor_id] && !IContained[neighbor_id]) {
        F.emplace(neighbor_id);
        FContained[neighbor_id] = true;
      }

      predecessorsInPolicy[neighbor_id].emplace(s_id);
    }

    // Case when action != Pi[s_id]: successors only join the fringe
    for(uint32_t a_id = states[s_id], a_end = states[s_id + 1];
        a_id < a_end; ++a_id) {
      if(a_id != Pi[s_id]) [[likely]] {
        for(uint32_t e_id = actions[a_id], e_end = actions[a_id + 1];
            e_id < e_end; ++e_id) {
          const uint32_t neighbor_id = neighbors[e_id];
          if(!FContained[neighbor_id] && !IContained[neighbor_id]) {
            F.emplace(neighbor_id);
            FContained[neighbor_id] = true;
          }
        }
      }
    }

    // ZSet = s_id plus every interior/fringe state reaching it through
    // predecessorsInPolicy (breadth-first traversal)
    unordered_set<uint32_t> ZSet;
    ZSet.insert(s_id);

    queue<uint32_t> awaitingState;
    awaitingState.push(s_id);

    while(!awaitingState.empty()) {
      const uint32_t currentState = awaitingState.front();
      awaitingState.pop();

      for(const auto& pred_id : predecessorsInPolicy[currentState]) {
        if(IContained[pred_id] || FContained[pred_id]) {
          if(ZSet.count(pred_id) == 0) {
            ZSet.emplace(pred_id);
            awaitingState.emplace(pred_id);
          }
        }
      }
    }

    // Update the values of the states of ZSet
    avector<uint32_t> ZVector(begin(ZSet), end(ZSet));
    partialVI(ZVector);

    // Update the policy of the states of ZSet; when the action of a state
    // changes, predecessorsInPolicy is moved from the successors of the old
    // action to the successors of the new one
    for(const uint32_t currentState : ZVector) {
      const uint32_t a_id_old = Pi[currentState];
      Pi[currentState] = greedyBestAction(currentState);
      const uint32_t a_id_new = Pi[currentState];

      if(a_id_new != a_id_old) {
        for(uint32_t k = actions[a_id_old], k_end = actions[a_id_old + 1]; k < k_end; ++k)
          predecessorsInPolicy[neighbors[k]].erase(currentState);
        for(uint32_t k = actions[a_id_new], k_end = actions[a_id_new + 1]; k < k_end; ++k)
          predecessorsInPolicy[neighbors[k]].insert(currentState);
      }
    }
  }
}


// Improved-LAO*-style solver started from s_0.
// Each iteration does a depth-first traversal of the current policy Pi
// (restricted to interior/fringe states) that records the traversed states in
// postorder in setZ and the fringe states it meets in setS, expands those
// fringe states, then calls partialVI() on setZ and updates Pi greedily.
// Iterates until no fringe state was met and the residual of the last
// partialVI() call is at most VI_Epsilon.
void MDP::ILAOstar(uint32_t s_0) {
  // ILAO* structures required in recursion
  avector<uint32_t> setZ{}; // vector of the states in postorder
  avector<uint32_t> setS{}; // fringe states met by the current traversal
  avector<bool> isInF(n_states, false); // Fringe states
  avector<bool> isInI(n_states, false); // Interior states

  // Recursive DFS following Pi; a state is appended to setZ only after all
  // its not-yet-visited interior/fringe successors were processed (postorder)
  function<void(uint32_t)> ILAOstarSCCDFSPostorder;
  ILAOstarSCCDFSPostorder = [&](uint32_t s_id) -> void {
    const uint32_t a_id = Pi[s_id];

    for(uint32_t e_id = actions[a_id], e_end = actions[a_id + 1];
        e_id < e_end; ++e_id) {
      const uint32_t neighbor_id = neighbors[e_id];
      if(!visited[neighbor_id] &&
          (isInI[neighbor_id] || isInF[neighbor_id])) {
        visited[neighbor_id] = true;
        if(isInF[neighbor_id])
          setS.push_back(neighbor_id);
        ILAOstarSCCDFSPostorder(neighbor_id);
      }
    }

    setZ.push_back(s_id);
  };

  isInI[goal_id] = true;
  isInF[s_0] = true;

  // Adding neighbors of s_0 to F to allow the first execution to run
  for(uint32_t a_id = states[s_0], a_end = states[s_0 + 1]; a_id < a_end; ++a_id)
      for(uint32_t e_id = actions[a_id], e_end = actions[a_id + 1]; e_id < e_end; ++e_id)
          isInF[neighbors[e_id]] = true;

  // Storage of the fringe state to use in the next iteration
  do {
    setS = avector<uint32_t>();
    setZ = avector<uint32_t>();

    // The goal is marked as visited so that the traversal never enters it
    fill(begin(visited), end(visited), false);
    visited[goal_id] = true;

    ILAOstarSCCDFSPostorder(s_0);

    // Expand the fringe states met by the traversal: they become interior
    // and their unseen successors become fringe states
    const uint32_t sizeS = static_cast<uint32_t>(size(setS));
    for(uint32_t sCurrent = 0; sCurrent < sizeS; ++sCurrent) {
      const uint32_t expandedId = setS[sCurrent];

      isInF[expandedId] = false;
      isInI[expandedId] = true;

      for(uint32_t a_id = states[expandedId], a_end = states[expandedId + 1];
          a_id < a_end; ++a_id) {
        for(uint32_t e_id = actions[a_id], e_end = actions[a_id + 1];
            e_id < e_end; ++e_id) {
          const uint32_t neighbor_id = neighbors[e_id];
          if(!isInF[neighbor_id] && !isInI[neighbor_id])
            isInF[neighbor_id] = true;
        }
      }
    }

    // epsilon is the largest finite float, so partialVI stops after its
    // first sweep unless the residual of that sweep is infinite
    partialVI(setZ, numeric_limits<float>::max());
    for(const uint32_t s_id : setZ)
      Pi[s_id] = greedyBestAction(s_id);
  } while(!setS.empty() || bellmanError > VI_Epsilon);
}

// Runs `num_trials` RTDP trials, all of them started from s_0.
void MDP::RTDP(uint32_t s_0, uint32_t num_trials) {
  uint32_t remaining = num_trials;
  while(remaining > 0) {
    RTDPTrial(s_0);
    --remaining;
  }
}

void MDP::RTDPTrial(uint32_t s) {
  while(!isTerminal(s)) {
    const uint32_t a_id = greedyBestAction(s);
    V[s] = QValue(a_id);
    s = sampleNextState(a_id);
  }
}

// Labeled RTDP: runs trials from s_0 until s_0 itself is labeled as solved,
// then logs how many states ended up labeled.
void MDP::LRTDP(uint32_t s_0) {
  vector<bool> solved(n_states, false);
  solved[goal_id] = true;

  for(;;) {
    if(solved[s_0])
      break;
    LRTDPTrial(s_0, solved);
  }

  LOG(INFO) << "Number of solved states by LRTDP: "
            << count(begin(solved), end(solved), true);
}

// One LRTDP trial. Forward phase: follows greedy actions from s_0, backing up
// V on the way, until a solved state is met (a state without selectable action
// gets V = infinity and is labeled solved). Backward phase: calls
// LRTDPCheckSolved on the visited states in reverse order, stopping at the
// first one that is not solved.
void MDP::LRTDPTrial(uint32_t s_0, vector<bool>& solved) {
  stack<uint32_t> visited_stack;

  for(uint32_t s = s_0; !solved[s]; ) {
    const uint32_t a_best = greedyBestAction(s);
    if(a_best >= size(actions)) [[unlikely]] { // no valid actions
      V[s] = numeric_limits<float>::infinity();
      solved[s] = true;
      break;
    }

    visited_stack.push(s);
    V[s] = QValue(a_best);
    s = sampleNextState(a_best);
  }

  bool keepLabeling = true;
  while(keepLabeling && !visited_stack.empty()) {
    const uint32_t s = visited_stack.top();
    visited_stack.pop();
    keepLabeling = LRTDPCheckSolved(s, solved);
  }
}

bool MDP::LRTDPCheckSolved(uint32_t s_0, vector<bool>& solved) {
  bool rv = true;
  vector<bool> in_open_or_closed(n_states, false);
  stack<uint32_t> open;
  stack<uint32_t> closed;

  if(!solved[s_0]) {
    open.push(s_0);
    in_open_or_closed[s_0] = true;
  }

  while(!open.empty()) {
    const uint32_t s = open.top();
    open.pop();
    closed.push(s);

    const auto [V_prime_s, a_best] = greedyBestValueAction(s);
    const float stateResidual = V_prime_s - V[s];
    if(stateResidual > VI_Epsilon) {
      rv = false;
      continue;
    }

    for(uint32_t effect_id = actions[a_best], effect_end = actions[a_best + 1];
        effect_id < effect_end; ++effect_id) {
      const uint32_t s_prime = neighbors[effect_id];
      if(!solved[s_prime] && !in_open_or_closed[s_prime]) {
        open.push(s_prime);
        in_open_or_closed[s_prime] = true;
      }
    }
  }

  while(!closed.empty()) {
    const uint32_t s = closed.top();
    closed.pop();
    if(rv)
      solved[s] = true;
    else {
      V[s] = greedyBestValue(s);
      ++n_backups;
    }
  }

  return rv;
}

void MDP::BRTDP(uint32_t s_0, float alpha, float tau) {
  LOG(INFO) << "BRTDP: using alpha = " << alpha << " and tau = " << tau;

  initUpperBound(); // upperBoundName must be BRTDP

  LOG(INFO) << "BRTDP: Doing trials until V and V_u are alpha-close";
  TimeVar tStart = timeNow();
  do {
    BRTDPTrial(s_0, tau);
  } while(V_u[s_0] - V[s_0] > alpha);
  const long time_brtdp = duration(timeNow() - tStart);
  LOG(INFO) << "Running time of BRTDP Trials: " << time_brtdp << " ms";
}

void MDP::BRTDPTrial(uint32_t s_0, float tau) {
  vector<pair<uint32_t, float>> b;
  stack<uint32_t> traj;

  uint32_t s = s_0;
  while(true) {
    traj.push(s);
    V_u[s] = greedyBestValue<ValueOption::Use_V_u>(s);
    uint32_t a = greedyBestAction(s); // use V
    V[s] = QValue(a); // use V

    // Compute b(y) for every possible neighbor of action a
    float B = 0.f;
    for(uint32_t effect_id = actions[a], effect_end = actions[a + 1];
        effect_id < effect_end; ++effect_id) {
      const uint32_t y = neighbors[effect_id];
      const float P_xy_a = probabilities[effect_id];
      const float b_y = P_xy_a * (V_u[y] - V[y]);
      B += b_y;
      b.emplace_back(y, b_y);
    }

    if(B < 1e-5f + (V_u[s] - V[s]) / tau)
      break; // Current trajectory reached a well-known state

    // sample next state according to b/B probability distribution
    const float proba = random_proba();
    float cumul = 0.0f;
    for(uint32_t i = 0; i < size(b); ++i) {
      cumul += b[i].second / B;
      if(proba <= cumul) {
        s = b[i].first;
        break;
      }
    }

    b.clear();
  } // end while

  // backup visited states of trial in reverse order
  while(!traj.empty()) {
    s = traj.top();
    traj.pop();
    V_u[s] = greedyBestValue<ValueOption::Use_V_u>(s);
    V[s] = greedyBestValue(s); // use V
  }
}

void MDP::TVI() {
  LOG(INFO) << "TVI Phase 1: compute the SCCs of the MDP";
  TimeVar tStart = timeNow();
  tarjan();
  const long time_p1 = duration(timeNow() - tStart);

  tStart = timeNow();
  LOG(INFO) << "TVI Phase 2: solve each SCC in reverse topological order";
  for(const auto& scc : SCC)
    partialVI(scc);
  const long time_p2 = duration(timeNow() - tStart);

  LOG(INFO) << "Running time of TVI (Tarjan): " << time_p1 << " ms";
  LOG(INFO) << "Running time of TVI (VI on SCCs): " << time_p2 << " ms";
  if(FLAGS_benchmark)
    cout << time_p1 << "\t0\t" << time_p2 << "\t";
}

void MDP::FTVI(uint32_t s_0, uint32_t iterPerBatch, float minChangeV_s0) {
  LOG(INFO) << "FTVI: using " << iterPerBatch << " iterations per batch";
  LOG(INFO) << "FTVI: using minChangeV_s0 = " << minChangeV_s0;

  LOG(INFO) << "FTVI Phase 0: compute initial upper bound V_u";
  initUpperBound();

  LOG(INFO) << "FTVI Phase 1: search (to prune suboptimal actions)";
  TimeVar tStart = timeNow();
  FTVIPhase1(s_0, iterPerBatch, minChangeV_s0);
  const long time_p1 = duration(timeNow() - tStart);

  LOG(INFO) << "FTVI Phase 2: TVI, after action elimination";
  tStart = timeNow();
  if(bellmanError > VI_Epsilon) // no need to call TVI if converged in Phase 1
    TVI(); // Tarjan in TVI automatically won't consider eliminated actions
  const long time_p2 = duration(timeNow() - tStart);

  LOG(INFO) << "Running time of FTVI (Search): " << time_p1 << " ms";
  LOG(INFO) << "Running time of FTVI (TVI): " << time_p2 << " ms";
}

// Search phase of FTVI. Runs batches of `iterPerBatch` searches from s_0.
// Returns as soon as one search leaves a residual below VI_Epsilon, or when,
// after a batch, (V[s_0] before the batch) / (V[s_0] after the batch) is
// larger than 1 - minChangeV_s0.
void MDP::FTVIPhase1(uint32_t s_0, uint32_t iterPerBatch, float minChangeV_s0) {
  const float ratioThreshold = 1.f - minChangeV_s0;

  for(;;) {
    const float valueBeforeBatch = V[s_0];

    uint32_t iteration = 0;
    while(iteration < iterPerBatch) {
      bellmanError = 0;
      fill(begin(visited), end(visited), false);
      FTVISearch(s_0); // implicitly updates V = V_l and V_u
      if(bellmanError < VI_Epsilon) [[unlikely]] {
        LOG(INFO) << "FTVI converged to V_star during phase 1";
        return;
      }
      ++iteration;
    }

    // oldV / newV > (100-y)% : V converged enough to start phase 2
    const float batch_ratio = valueBeforeBatch / V[s_0];
    if(batch_ratio > ratioThreshold)
      return;
  }
}

void MDP::FTVISearch(uint32_t s_id) {
  visited[s_id] = true;
  if(isTerminal(s_id)) [[unlikely]]
    return;

  const uint32_t a_id = greedyBestAction(s_id);
  for(uint32_t effect_id = actions[a_id], effect_end = actions[a_id + 1];
      effect_id < effect_end; ++effect_id) {
    const uint32_t neighbor_id = neighbors[effect_id];
    if(!visited[neighbor_id])
      FTVISearch(neighbor_id);
  }

  const float searchResidual = FTVIBackup(s_id);
  bellmanError = max(bellmanError, searchResidual);
}

// Backup used by the FTVI search: flags as eliminated every action of s_id
// whose Q-value (computed with V) exceeds V_u[s_id], then recomputes V[s_id]
// and V_u[s_id]. Returns (new V[s_id]) - (old V[s_id]).
float MDP::FTVIBackup(uint32_t s_id) {
  // Detect useless actions of s_id
  const uint32_t a_end = states[s_id + 1];
  for(uint32_t a_id = states[s_id]; a_id < a_end; ++a_id) {
    if(QValue(a_id) > V_u[s_id]) {
      eliminatedActions[a_id] = true;
    }
  }

  const float previousValue = V[s_id];
  // Neither of the two backups below considers eliminated actions
  V[s_id] = greedyBestValue<ValueOption::Use_V>(s_id);
  V_u[s_id] = greedyBestValue<ValueOption::Use_V_u>(s_id);
  return V[s_id] - previousValue;
}

// Compute the SCCs and put them in reverse topological order in 'SCC'.
// The search is started from every state whose value is not infinite.
void MDP::tarjan() {
  for(uint32_t s = 0; s < n_states; ++s) {
    const bool hasInfiniteValue = V[s] == numeric_limits<float>::infinity();
    if(hasInfiniteValue)
      continue;
    tarjanDFS(s);
  }
}

// Recursive step of Tarjan's SCC algorithm, started from s_id.
// Eliminated actions and successors with an infinite value are ignored.
// Every completed component is appended to SCC and stateToSCC is filled with
// the index of the component of each of its states.
//
// NOTE(review): the bookkeeping below lives in function-local statics. They
// are created the first time this function runs (sized with n_states at that
// moment) and are never reset, so the state is shared by every later call,
// including calls made on another MDP object or a second tarjan() pass.
void MDP::tarjanDFS(uint32_t s_id) {
  static uint32_t clock = 0; // discovery counter
  static stack<uint32_t> S; // states whose component is not completed yet
  // uint32_t max is the "not set" marker of the three vectors below
  static vector<uint32_t> pre(n_states, numeric_limits<uint32_t>::max());
  static vector<uint32_t> low(n_states, numeric_limits<uint32_t>::max());
  static vector<uint32_t> root(n_states, numeric_limits<uint32_t>::max());

  // Already discovered
  if(pre[s_id] != numeric_limits<uint32_t>::max())
    return;

  low[s_id] = pre[s_id] = ++clock;
  S.push(s_id);

  for(uint32_t a_id = states[s_id], a_end = states[s_id + 1];
      a_id < a_end; ++a_id) {
    if(eliminatedActions[a_id])
      continue;
    for(uint32_t effect_id = actions[a_id], effect_end = actions[a_id + 1];
        effect_id < effect_end; ++effect_id) {
      const uint32_t s_prime_id = neighbors[effect_id];
      if(V[s_prime_id] == numeric_limits<float>::infinity())
        continue; // state is unreachable or can't reach goal

      if(pre[s_prime_id] == numeric_limits<uint32_t>::max()) { // if unmarked
        tarjanDFS(s_prime_id);
        low[s_id] = min(low[s_id], low[s_prime_id]);
      } else if(root[s_prime_id] == numeric_limits<uint32_t>::max())
        // discovered, but its component is not completed yet
        low[s_id] = min(low[s_id], pre[s_prime_id]);
    }
  }

  // s_id is the first discovered state of its component: pop the component
  if(low[s_id] == pre[s_id]) {
    avector<uint32_t> component;
    uint32_t w = numeric_limits<uint32_t>::max();
    do {
      w = S.top();
      S.pop();
      root[w] = s_id;
      component.push_back(w);
      stateToSCC[w] = static_cast<uint32_t>(size(SCC));
    } while(w != s_id);
    SCC.push_back(std::move(component));
  }
}

// Q-value of action a_id: C(s, a) + sum_s' { T(s, a, s') * value[s'] },
// where `value` is V_u when option == Use_V_u and V otherwise.
template <MDP::ValueOption option>
float MDP::QValue(uint32_t a_id) const {
  float total = costs[a_id]; // C(s, a)

  const uint32_t effect_end = actions[a_id + 1];
  for(uint32_t effect_id = actions[a_id]; effect_id < effect_end; ++effect_id) {
    const uint32_t successor = neighbors[effect_id];
    const float proba = probabilities[effect_id];

    float successorValue;
    if constexpr (option == ValueOption::Use_V_u)
      successorValue = V_u[successor];
    else // option == Use_V
      successorValue = V[successor];

    total += proba * successorValue;
  }

  return total;
}

// Returns argmin_a Q(s_id, a) over the non-eliminated actions of s_id whose
// Q-value is below UpperBoundInitValue, or the largest uint32_t when there is
// no such action. Ties are broken in favor of the smallest action id.
uint32_t MDP::greedyBestAction(uint32_t s_id) const {
  uint32_t bestAction = numeric_limits<uint32_t>::max();
  float bestQ = UpperBoundInitValue;

  const uint32_t a_end = states[s_id + 1];
  for(uint32_t a_id = states[s_id]; a_id < a_end; ++a_id) {
    if(!eliminatedActions[a_id]) {
      const float candidateQ = QValue(a_id);
      if(candidateQ < bestQ) {
        bestQ = candidateQ;
        bestAction = a_id;
      }
    }
  }

  return bestAction;
}

// Returns min_a Q(s_id, a) over the non-eliminated actions of s_id, computed
// with the value vector selected by `option`. The result is infinity when
// s_id has no non-eliminated action.
template <MDP::ValueOption option>
float MDP::greedyBestValue(uint32_t s_id) const {
  float bestQ = numeric_limits<float>::infinity();

  const uint32_t a_end = states[s_id + 1];
  for(uint32_t a_id = states[s_id]; a_id < a_end; ++a_id) {
    if(!eliminatedActions[a_id]) {
      const float candidateQ = QValue<option>(a_id);
      if(candidateQ < bestQ)
        bestQ = candidateQ;
    }
  }

  return bestQ;
}

// Returns (min_a Q(s_id, a), argmin_a Q(s_id, a)) over the non-eliminated
// actions of s_id whose Q-value is below UpperBoundInitValue. When there is
// no such action the result is (UpperBoundInitValue, largest uint32_t).
pair<float, uint32_t> MDP::greedyBestValueAction(uint32_t s_id) const {
  uint32_t bestAction = numeric_limits<uint32_t>::max();
  float bestQ = UpperBoundInitValue;

  const uint32_t a_end = states[s_id + 1];
  for(uint32_t a_id = states[s_id]; a_id < a_end; ++a_id) {
    if(!eliminatedActions[a_id]) {
      const float candidateQ = QValue(a_id);
      if(candidateQ < bestQ) {
        bestQ = candidateQ;
        bestAction = a_id;
      }
    }
  }

  return pair<float, uint32_t>(bestQ, bestAction);
}

// Samples a successor of action a_id: draws a number with random_proba() and
// returns the first successor whose cumulated probability reaches it.
uint32_t MDP::sampleNextState(uint32_t a_id) const {
  const float proba = random_proba();
  const uint32_t effect_end = actions[a_id + 1];

  float cumul = 0.0f;
  uint32_t effect_id = actions[a_id];
  while(effect_id < effect_end) {
    cumul += probabilities[effect_id];
    if(proba <= cumul)
      return neighbors[effect_id];
    ++effect_id;
  }

  // Due to rounding errors, being here means last neighbor should be returned
  return neighbors[effect_end - 1];
}

// A state is terminal when it is the goal or when it has no applicable action.
bool MDP::isTerminal(uint32_t s_id) const {
  if(s_id == goal_id)
    return true;

  const bool hasNoAction = states[s_id] == states[s_id + 1];
  return hasNoAction;
}

// Backward propagation of values from the goal through `predecessors`, using
// a priority queue that pops the entry with the smallest value first.
//  - bound == Lower: V is reset to infinity and V[pred] is relaxed with
//    arc_cost + (value of the popped state).
//  - bound == Upper: works on V_u; each predecessor is evaluated at most once
//    (tracked with `visited`) with greedyBestValue<Use_V_u>.
// In both cases the goal gets the value 0 and an entry is only written (and
// pushed) when it lowers the stored value.
template <MDP::BoundType bound>
void MDP::reversedDijkstra() {
  using PQ_Pair = pair<uint32_t, float>; // id state -> min/max cost to goal
  // "greater than" on the cost, which makes priority_queue a min-heap
  constexpr auto comp = [](const PQ_Pair& p1, const PQ_Pair& p2) {
    return p1.second > p2.second;
  };
  priority_queue<PQ_Pair, vector<PQ_Pair>, decltype(comp)> pq(comp);

  // Vector that receives the bound: V (lower) or V_u (upper)
  auto* V_or_V_u = &V;
  if constexpr (bound == BoundType::Lower) // V was initialized to 0 in operator>>
    fill(begin(V), end(V), numeric_limits<float>::infinity());
  else { // bound == Upper, V_u already init to UpperBoundInitValue in operator>>
    V_or_V_u = &V_u;
    fill(begin(visited), end(visited), false);
  }

  (*V_or_V_u)[goal_id] = 0.0f; // V[goal_id] = 0 = V_u[goal_id]
  pq.emplace(goal_id, 0.0f);
  while(!pq.empty()) {
    const auto [s, V_s] = pq.top();
    pq.pop();

    // Only the predecessor id and the arc cost of each entry are used here
    for(const auto& [pred, _1, arc_cost, _2] : predecessors[s]) {
      if constexpr (bound == BoundType::Upper) {
        // Upper bound: a predecessor is processed the first time it is met
        if (visited[pred])
          continue;
        else
          visited[pred] = true;
      }

      const float V_pred = bound == BoundType::Lower
                         ? arc_cost + V_s
                         : greedyBestValue<ValueOption::Use_V_u>(pred);
      if(V_pred < (*V_or_V_u)[pred]) {
        (*V_or_V_u)[pred] = V_pred;
        pq.emplace(pred, V_pred);
      }
    }
  }
}

void MDP::lowerBoundHMin() {
  reversedDijkstra<BoundType::Lower>();
}

void MDP::initUpperBound() {
  TimeVar tStart = timeNow();
  if(upperBoundName == "NONE") {
    // V_u is initialized to UpperBoundInitValue in operator>>
    LOG(INFO) << "Using no upper bound: V_u = " << UpperBoundInitValue;
  } else if(upperBoundName == "FTVI") {
    LOG(INFO) << "Computing the FTVI upper bound";
    upperBoundFTVI();
  } else if(upperBoundName == "BRTDP") {
    LOG(INFO) << "Computing the DS-MPI upper bound used in BRTDP";
    upperBoundBRTDP();
  } else if(upperBoundName == "BRTDP_FTVI") {
    LOG(INFO) << "Computing DS-MPI followed by FTVI upper-bound";
    upperBoundBRTDP();
    upperBoundFTVI(); // this sometimes gives a tighter bound
  } else
    LOG(FATAL) << "Unsupported upper bound";

  // In case of dead-ends, V = inf and V_u = UpperBoundInitValue, which can be smaller
  for(uint32_t s = 0; s < n_states; ++s)
    V_u[s] = max(V[s], V_u[s]);

  const long time_upper = duration(timeNow() - tStart);
  LOG(INFO) << "Time to compute initial V_u: " << time_upper << " ms";
}

void MDP::upperBoundFTVI() {
  reversedDijkstra<BoundType::Upper>();
}

void MDP::upperBoundBRTDP() {
  // Algorithm 1 in BRTDP paper to compute p_g and w
  const uint32_t n_actions = static_cast<uint32_t>(size(costs));
  avector<pair<float, float>>
    pri(n_states, {numeric_limits<float>::infinity(),
                   numeric_limits<float>::infinity()});
  avector<float> p_g_hat(n_actions, 0.0f);
  avector<float> w_hat = costs;
  avector<float> p_g(n_states, 0.0f);
  avector<float> w(n_states, numeric_limits<float>::infinity());
  p_g[goal_id] = 1.0f;
  w[goal_id] = 0.0f;

  using PQ_Pair = pair<uint32_t, pair<float, float>>; // id state -> priority
  constexpr auto comp = [](const PQ_Pair& p1, const PQ_Pair& p2) {
    return p1.second > p2.second;
  };

  priority_queue<PQ_Pair, vector<PQ_Pair>, decltype(comp)> pq(comp);
  pq.emplace(goal_id, pair(0.f, 0.f));

  while(!pq.empty()) {
    uint32_t x = pq.top().first;
    pq.pop();

    // happens if x was more than once in pq (because we don't have decrease_key)
    if(visited[x]) // visited is called "fin" in paper
      continue;
    visited[x] = true;

    // x can never be an absorbing dead-end (since they are never a predecessor)
    if(x != goal_id) [[likely]] {
      const uint32_t Pi_x = Pi[x];
      p_g[x] = min(1.f, p_g_hat[Pi_x]); // min is in case of numerical imprecision
      w[x] = w_hat[Pi_x];
    }

    for(const auto& [y, a, _1, P_yx_a] : predecessors[x]) {
      if(visited[y])
        continue;

      p_g_hat[a] += P_yx_a * p_g[x];
      w_hat[a] += P_yx_a * w[x];
      const pair<float, float> curr_pri = {1.f - p_g_hat[a], w_hat[a]};

      if(curr_pri < pri[y]) {
        pri[y] = curr_pri;
        Pi[y] = a;
        pq.emplace(y, curr_pri);
      }
    }
  }

  // Theorem 3 in BRTDP paper to compute V_u using p_g and w
  float best_lambda = -numeric_limits<float>::infinity();
  for(uint32_t x = 0; x < n_states; ++x) {
    if(isTerminal(x)) [[unlikely]]
      continue;

    float min_action = numeric_limits<float>::infinity();
    for(uint32_t a = states[x], a_end = states[x + 1]; a < a_end; ++a) {
      // Compute sum_y P_xy^a p_g[y] and sum_y P_xy^a w[y]
      float sum_p_g_a = 0.f;
      float sum_w_a = 0.f;
      for(uint32_t effect_id = actions[a], effect_end = actions[a + 1];
          effect_id < effect_end; ++effect_id) {
        const uint32_t y = neighbors[effect_id];
        const float P_xy_a = probabilities[effect_id];
        sum_p_g_a += P_xy_a * p_g[y];
        sum_w_a += P_xy_a * w[y];
      }

      // Compute lambda_a
      float lambda_a;
      if(p_g[x] < sum_p_g_a) // case I
        lambda_a = (costs[a] + sum_w_a - w[x]) / (sum_p_g_a - p_g[x]);
      else if(w[x] >= costs[a] + sum_w_a && fabs(p_g[x] - sum_p_g_a) < 1e-4f) // case II
        lambda_a = 0.f;
      else // case III
        lambda_a = numeric_limits<float>::infinity();

      // Keep min_a lambda_a
      min_action = min(min_action, lambda_a);
    }

    // Not in the paper. See comment in other implementation:
    // github.com/instance01/BRTDP-DS-MPI/blob/master/BRTDP_DS_MPI/algorithm/cpp_brtdp.cpp
    if(min_action == numeric_limits<float>::infinity()) {
      LOG_FIRST_N(WARNING, 1)
        << "DS-MPI upper bound: min_action was infinity, using 0";
      min_action = 0.f;
    }

    // Compute lambda = max_x min_a lambda[a]
    best_lambda = max(best_lambda, min_action);
  }

  // Initialize upper bound V_u
  for(uint32_t x = 0; x < n_states; ++x)
    V_u[x] = min(w[x] + (1.f - p_g[x]) * best_lambda, UpperBoundInitValue);
}

void MDP::printPartitionsInfo() const {
  if(size(SCC) == 0)
    return;

  LOG(INFO) << "--- Partitions Info ---";
  uint32_t SCCMaxSize = 0;
  const uint32_t numSCCs = static_cast<uint32_t>(size(SCC));
  for(uint32_t scc_id = 0; scc_id < numSCCs; ++scc_id) {
    const auto& scc = SCC[scc_id];
    const uint32_t numElemsInSCC = static_cast<uint32_t>(size(scc));
    SCCMaxSize = max(SCCMaxSize, numElemsInSCC);

    stringstream ss;
    ss << "scc " << scc_id << ": ";
    for(uint32_t id : scc)
      ss << id << ' ';
    ss << '\n';
    VLOG(2) << ss.str();
  }

  LOG(INFO) << "Number of SCC: " << size(SCC);
  LOG(INFO) << "Size of largest SCC: " << SCCMaxSize;
  LOG(INFO) << "-----------------------";

  if(FLAGS_benchmark)
    cout << size(SCC) << "\t" << SCCMaxSize << "\t";
}

void MDP::printStatesActionsStats() const {
  VLOG(2) << "-------  Current Value Bounds  -------";
  for(uint32_t i = 0; i < n_states; ++i) {
    // compute the id in case states were reordered
    uint32_t id = sccStartIds.empty() ? i : newIds[i];
    VLOG(2) << "\t" << fixed << setprecision(3)
              << V[id] << " <= V[" << i << "] <= " << V_u[id];
  }
  VLOG(2) << "--------------------------------------";
}

// Number of actions currently flagged in eliminatedActions.
uint32_t MDP::getNumEliminatedActions() const {
  const auto numEliminated =
      count(begin(eliminatedActions), end(eliminatedActions), true);
  return static_cast<uint32_t>(numEliminated);
}

// Monte-Carlo evaluation of the policy Pi: simulates exactly `numTrials`
// trajectories from s_0, following Pi until a terminal state is reached, and
// returns the average accumulated cost (0 when numTrials is 0).
// Returns infinity as soon as a trajectory ends in a terminal state that is
// not the goal. An action id outside of `actions` aborts through CHECK.
// The evaluation time is logged.
float MDP::evaluatePolicy(uint32_t s_0, uint32_t numTrials) const {
  const TimeVar tStart = timeNow();
  float avg_score = 0.0f;
  // Strict bound: `i <= numTrials` would simulate one trajectory too many
  for(uint32_t i = 0; i < numTrials; ++i) {
    float score = 0.0f;
    uint32_t s = s_0;
    while(!isTerminal(s)) {
      const uint32_t a_id = Pi[s];
      CHECK(a_id < size(actions)) << "Improper policy";
      score += costs[a_id];
      s = sampleNextState(a_id);
    }

    if(s != goal_id) [[unlikely]] {
      LOG(ERROR) << "The policy has dead-ends !";
      return numeric_limits<float>::infinity();
    }

    // Running average over the i + 1 trajectories simulated so far
    const float i_f = static_cast<float>(i);
    avg_score = (i_f * avg_score + score) / (i_f + 1.0f);
  }

  const long time_evaluate = duration(timeNow() - tStart);
  LOG(INFO) << "Time to evaluate policy: " << time_evaluate << " ms";
  return avg_score;
}

void MDP::dumpGraphvizFormat(std::ostream& os) const {
  os << "digraph {\n";
  os << "\trankdir=\"LR\"\n";

  // Print states
  for(uint32_t s_id = 0; s_id < n_states; ++s_id)
    os << "\t" << s_id << " [shape=ellipse];\n";

  // Print actions
  for(uint32_t s_id = 0; s_id < n_states; ++s_id) {
    for(uint32_t a_id = states[s_id], a_end = states[s_id + 1];
        a_id != a_end; ++a_id) {
      const uint32_t n_effects = actions[a_id + 1] - actions[a_id];
      const float cost = costs[a_id];

      if(n_effects == 1) { // deterministic action
        const uint32_t neighbor_id = neighbors[actions[a_id]];
        os << "\t\t" << s_id << " -> " << neighbor_id
          << " [label=" << cost << "];\n";
        continue;
      }

      // probabilistic action: we draw a special action node
      stringstream anode;
      anode << "\"" << s_id << "-" << a_id << "\"";
      os << "\t\t" << anode.str() << " [shape=point];\n";
      os << "\t\t\t" << s_id << " -> " << anode.str()
         << " [label=" << cost << " arrowhead=none];\n";


      // Print probabilistic effects
      for(uint32_t e_id = actions[a_id], e_end = actions[a_id + 1];
          e_id != e_end; ++e_id) {
        const uint32_t neighbor = neighbors[e_id];
        const float proba = probabilities[e_id];
        os << "\t\t\t" << anode.str() << " -> " << neighbor
           << " [label=" << proba << "];\n";
      }
    }
  }

  os << "}" << endl;
}

// Reads a textual MDP description from `is` into `mdp` and returns `is`.
//
// Whitespace-separated layout consumed by the extractions below:
//   n_states
//   for every state   : id n_actions
//   for every action  : cost n_outcomes
//   for every outcome : neighbor_id probability
//
// State ids must appear in order 0 .. n_states - 1. Actions with zero
// outcomes are not stored, and neither are the actions of the goal state.
// For a goal action, one line of remaining input is discarded, so its
// outcomes are presumably expected to sit on a single line.
// Within an action, outcomes are stored by decreasing probability and their
// probabilities are rescaled to sum to 1.
//
// Aborts through CHECK on an invalid start/goal id, out-of-order state ids,
// an outcome pointing to a state id >= n_states, or an action whose
// probabilities do not have a positive sum.
istream& operator>> (istream& is, MDP& mdp) {
  is >> mdp.n_states;

  // If no goal specified in constructor, use last state as goal
  if(mdp.goal_id == numeric_limits<int>::max())
    mdp.goal_id = mdp.n_states - 1;
  CHECK(mdp.start_id < mdp.n_states) << "Specified start id is invalid";
  CHECK(mdp.goal_id < mdp.n_states) << "Specified goal id is invalid";

  // Per-state containers; those built without an explicit fill value are
  // value-initialized
  mdp.states    = avector<uint32_t>(mdp.n_states + 1);
  mdp.V         = avector<float>(mdp.n_states, 0.0f);
  mdp.V_u       = avector<float>(mdp.n_states, UpperBoundInitValue);
  mdp.Pi        = avector<uint32_t>(mdp.n_states, numeric_limits<uint32_t>::max());
  mdp.visited   = avector<bool>(mdp.n_states, false);
  mdp.predecessors
    = avector<avector<tuple<uint32_t, uint32_t, float, float>>>(mdp.n_states);
  mdp.stateToSCC = avector<uint32_t>(mdp.n_states);
  mdp.actions.push_back(0);

  // Scratch record holding one outcome of the action being read
  struct Outcome {
    uint32_t id_neighbor;
    float proba;
  };

  // Read every state
  for(uint32_t s_id = 0; s_id < mdp.n_states; ++s_id) {
    // Zero-initialized so that a failed extraction never leaves garbage behind
    uint32_t id = 0, n_actions = 0;
    is >> id >> n_actions;
    CHECK(id == s_id) << "The states must be indexed from 0 to n_states - 1";

    // Read every action applicable in current state
    uint32_t n_valid_actions = n_actions;
    for(uint32_t j = 0; j < n_actions; ++j) {
      float cost = 0.0f; // in SSP-MDP, reward is expressed as cost
      uint32_t n_outcomes = 0;
      is >> cost >> n_outcomes;

      // We don't store "no-op" actions
      if(n_outcomes == 0) [[unlikely]] {
        --n_valid_actions;
        continue;
      }

      // Actions of goal are ignored
      if(s_id == mdp.goal_id) [[unlikely]] {
        string action_line;
        getline(is >> ws, action_line);
        --n_valid_actions;
        continue;
      }

      // Compute end indexing of the action being stored
      mdp.actions.push_back(mdp.actions.back() + n_outcomes);

      // Save action cost; its index in mdp.costs is the global action id
      mdp.costs.push_back(cost);
      const uint32_t a_id = static_cast<uint32_t>(size(mdp.costs)) - 1;

      // Heap scratch buffer, released at the end of this iteration. alloca()
      // memory would only be released when the whole function returns, so the
      // stack would grow with the total number of outcomes in the file.
      avector<Outcome> actionOutcomes(n_outcomes);

      // Read every possible outcome of current action
      for(Outcome& outcome : actionOutcomes) {
        is >> outcome.id_neighbor >> outcome.proba;
        CHECK(outcome.id_neighbor < mdp.n_states)
          << "Outcome refers to a state id outside 0 to n_states - 1";

        // Save reverse dependencies.
        // NOTE: the probability stored here is the value as read, i.e. before
        // the normalization applied further down.
        auto& preds = mdp.predecessors[outcome.id_neighbor];
        auto it = find_if(begin(preds), end(preds), [s_id](const auto& p) {
            return get<0>(p) == s_id; // id_predecessor == s_id
        });
        if(it == end(preds))
          preds.push_back({s_id, a_id, cost, outcome.proba});
        else if(auto& [_1, tuple_action, tuple_cost, tuple_proba] = *it;
                tuple_cost > cost ||
                (tuple_cost == cost && tuple_proba < outcome.proba)) {
          // keep only best action from s to neighbor:
          // cheapest first, most likely on equal cost
          tuple_action = a_id;
          tuple_cost = cost;
          tuple_proba = outcome.proba;
        }
      } // end outcomes

      // Sort outcomes by decreasing probabilities
      sort(begin(actionOutcomes), end(actionOutcomes),
          [](const Outcome& o1, const Outcome& o2) {
            return o1.proba > o2.proba;
          });

      // Sum in stored order, used to normalize in case the probabilities of
      // the action don't sum to 1
      float sum_proba = 0.0f;
      for(const Outcome& outcome : actionOutcomes)
        sum_proba += outcome.proba;
      CHECK(sum_proba > 0.0f)
        << "The probabilities of an action must have a positive sum";

      // Insert in order the normalized outcomes in the MDP
      for(const Outcome& outcome : actionOutcomes) {
        mdp.neighbors.push_back(outcome.id_neighbor);
        mdp.probabilities.push_back(outcome.proba / sum_proba);
      }
    } // end actions

    // Compute end indexing of state s_id
    mdp.states[s_id+1] = mdp.states[s_id] + n_valid_actions;

    // Initialize a policy with first applicable action
    mdp.Pi[s_id] = mdp.states[s_id];
  } // end states

  mdp.eliminatedActions = avector<bool>(size(mdp.actions));

  LOG(INFO) << "size(states)        : " << size(mdp.states);
  LOG(INFO) << "size(actions)       : " << size(mdp.actions);
  LOG(INFO) << "size(costs)         : " << size(mdp.costs);
  LOG(INFO) << "size(neighbors)     : " << size(mdp.neighbors);
  LOG(INFO) << "size(probabilities) : " << size(mdp.probabilities);
  LOG(INFO) << "sizeof(mdp)         : " << sizeof(mdp);
  return is;
}
