#ifndef MDP_HPP
#define MDP_HPP

#include <cmath>
#include <cstdint>
#include <iostream>
#include <limits>
#include <set>
#include <string>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>

// Vector alignment
#include <boost/align/aligned_allocator.hpp>
// Alignment passed to the allocator that backs every avector
constexpr uint32_t AlignSize = 64;

// std::vector whose heap storage comes from Boost's aligned allocator
template <class T>
using avector =
    std::vector<T, boost::alignment::aligned_allocator<T, AlignSize>>;

// Default parameters of the solvers

// Value Iteration
constexpr float VI_Epsilon{0.0001f}; // McMahan: 0.001; Dai: 1e-6

// FTVI
constexpr float FTVI_MinChangeV_s0{0.03f}; // Dai: 0.03 ; Jain: 0.40
constexpr uint32_t FTVI_IterPerBatch{5};   // Dai: 100  ; Jain: 10

// RTDP
constexpr uint32_t RTDP_Trials{100'000};

// BRTDP
constexpr float BRTDP_Alpha{0.001f}; // McMahan: 0.1    ; Dai: 2e-6
constexpr float BRTDP_Tau{10.f};     // McMahan: 10-100 ; Dai: 10

// Number of trials to estimate policy value
constexpr uint32_t NumTrials{50'000};

// For pcTVI only
// Graph over SCCs (the return type of MDP::buildCondensation).
// Both vectors hold exactly one entry per SCC.
struct DAG {
  // Creates `numSCCs` empty reverse-neighbor sets and `numSCCs` counters
  // initialized to 0.
  // `explicit` so that an integer is never silently converted into a DAG.
  explicit DAG(uint32_t numSCCs)
    : revNeighbors(numSCCs), numIncomingArcs(numSCCs) {}

  std::vector<std::set<uint32_t>> revNeighbors; // reverse neighbors of each SCC
  std::vector<uint32_t> numIncomingArcs; // one counter per SCC, starts at 0
};

// MDP instance stored in a modified CSR format, together with the working
// data (V, V_u, Pi, SCCs, ...) shared by the solvers declared below.
// Apart from the constructor, member functions are only declared here.
class MDP {
 public:
  // Constructor: stores the start state id, the goal state id and the name
  // of the upper bound (see initUpperBound). Every other attribute keeps its
  // in-class default value.
  // NOTE(review): not `explicit`, so a single uint32_t converts implicitly
  // to an MDP; the default goal (max uint32_t) presumably means
  // "no goal given" -- confirm against the definitions.
  MDP(uint32_t start = 0,
      uint32_t goal = std::numeric_limits<uint32_t>::max(),
      const std::string& upperbound = "BRTDP")
    : start_id(start), goal_id(goal), upperBoundName(upperbound) {}

  // Compute V^* using the specified solver
  void solve(const std::string& solverName = "VI");

  // Compute the specified heuristic for all states and initialize V
  void computeHeuristic(const std::string& heuristicName = "NONE");

  // Set V[s] = infty for every state not reachable from s_0
  void markReachableStates(uint32_t s_0); // DFS

  // Compute and evaluate the greedy policy Pi := Pi_V
  void findAndEvaluateGreedyPolicy();

  // Print to a stream the tuple (Pi_star(s), V_star(s)) for every state s
  // Note: action_id displayed will be different if states were reordered
  void printPolicy(std::ostream& os = std::cout) const;

  // Returns (V[start_id], simulated(Pi(start_id)))
  std::pair<float, float> getExpectedAndSimulatedCost() const;

  // I/O
  void dumpGraphvizFormat(std::ostream& os) const;
  // Declared as a friend so the stream extraction operator can access the
  // private attributes of `mdp`
  friend std::istream& operator>> (std::istream& is, MDP& mdp);

 private:
  // Solvers main function
  // Unnamed parameters below take their default from the solver constants
  // (RTDP_Trials, BRTDP_Alpha, BRTDP_Tau, FTVI_IterPerBatch,
  // FTVI_MinChangeV_s0); their exact meaning is given by the definitions.

  void VI(); // Gauss-Seidel

  void RTDP(uint32_t s, uint32_t = RTDP_Trials);
  void LRTDP(uint32_t s_0);
  void BRTDP(uint32_t s_0, float = BRTDP_Alpha, float = BRTDP_Tau);

  void LAOstar(uint32_t s_0);
  void ILAOstar(uint32_t s_0);

  void TVI(); // Topological Value Iteration
  void FTVI(uint32_t s_0, uint32_t = FTVI_IterPerBatch, float = FTVI_MinChangeV_s0);

  void eTVI(); // efficient TVI or extra TVI
  void eiTVI(); // extra-intra TVI
  void pcTVI(); // parallel-chained TVI

  // Solvers auxiliary functions

  void partialVI(const avector<uint32_t>& subset, float epsilon = VI_Epsilon);
  void eTVIPartial(uint32_t init_id, uint32_t end_id); // half-open range [init_id, end_id)
  void RTDPTrial(uint32_t s_0);
  void LRTDPTrial(uint32_t s_0, std::vector<bool>& solved);
  bool LRTDPCheckSolved(uint32_t s_0, std::vector<bool>& solved);
  void BRTDPTrial(uint32_t s_0, float tau);
  void FTVIPhase1(uint32_t s_0, uint32_t iterPerBatch, float minChangeV_s0);
  void FTVISearch(uint32_t s_id);
  float FTVIBackup(uint32_t s_id); // returns local residual

  // Selects which reordering is applied by the functions below
  enum class ReorderType { Extra, IntraExtra };
  void reorderedTVI(ReorderType type);
  void reorderStates(ReorderType type); // Rebuild CSR so SCCs are contiguous
  void findNewIdsExtra(); // Reorder SCCs so that they are contiguous in memory
  void findNewIdsExtraIntra(); // Like above but also reorder inside SCCs
  void newIdsSCCDFSPostorder(uint32_t s_id); // DFS limited to current SCC
  void ILAOstarSCCDFSPostorder(uint32_t s_id); // DFS limited to envelope

  void rebuildCSR(); // Rebuild CSR in the order of the 'newIds' attribute

  // Partitioning functions

  void tarjan(); // Compute strongly connected components in the SCC attribute
  void tarjanDFS(uint32_t s_id); // Tarjan auxiliary function
  DAG buildCondensation() const; // graph condensation (graph of SCCs)
  void findSCCsBorderStates(); // populate SCCsOutwardBorderStates

  // General auxiliary functions

  uint32_t sampleNextState(uint32_t a_id) const;
  bool isTerminal(uint32_t s_id) const;

  // lower and upper bounds

  void lowerBoundHMin(); // h_min heuristic (compute lower bound)
  void initUpperBound(); // compute upper bound "upperBoundName"
  void upperBoundFTVI(); // compute V_u as described in FTVI paper
  void upperBoundBRTDP(); // Dijkstra Sweep for Monotone Pessimistic Init (DS-MPI)

  // Template argument selects which bound reversedDijkstra computes
  enum class BoundType { Lower, Upper };
  template <BoundType>
  void reversedDijkstra(); // traverse the determinized MDP from goal

  // Bellman Computations

  // Template argument selects the value vector read: V (default) or V_u
  enum class ValueOption { Use_V, Use_V_u };
  template <ValueOption = ValueOption::Use_V>
  float QValue(uint32_t a_id) const; // Q(s,a) (s is implicit in action a)

  template <ValueOption = ValueOption::Use_V>
  float greedyBestValue(uint32_t s_id) const; // min_a Q(s,a)
  uint32_t greedyBestAction(uint32_t s_id) const; // argmin_a Q(s,a)
  // Presumably (min_a Q(s,a), argmin_a Q(s,a)) in that order -- confirm
  std::pair<float, uint32_t> greedyBestValueAction(uint32_t s_id) const;

  // Debugging and Benchmarking

  void printVerboseInfo() const; // Print info about SCCs and state values
  void printPartitionsInfo() const; // content of each SCC, max_size
  void printStatesActionsStats() const; // V(s), V_u(s), eliminatedActions
  uint32_t getNumEliminatedActions() const;
  // Unnamed second parameter defaults to NumTrials
  float evaluatePolicy(uint32_t s_0, uint32_t = NumTrials) const;

  // --- MDP Instance attributes ---

  // MDP stored in a modified "compressed sparse row" (CSR) format
  avector<uint32_t> states{}; // index=id, content=indices of costs and actions vector

  // Following vectors contain the relevant infos for state i
  // at positions between states[i] and states[i+1]-1
  // e.g., costs[states[i]] is cost of first applicable action at state i
  avector<float> costs{};
  avector<uint32_t> actions{};

  // Following vectors contain the relevant infos for action i
  // at positions between actions[i] and actions[i+1]-1
  // e.g., neighbors[actions[states[i]]] is first possible neighbor state
  // of first applicable action in state i
  avector<uint32_t> neighbors{};
  avector<float> probabilities{};

  // Variables to store temporary information needed for some solvers
  // Note: avector<bool> is the std::vector<bool> specialization, so its
  // elements are accessed through proxy references.
  avector<uint32_t> Pi{}; // current best policy
  avector<float> V{}; // current state value lower bound estimate
  avector<float> V_u{}; // FTVI/BRTDP: current state value upper bound estimate
  avector<bool> visited{}; // is some state already visited ?
  avector<bool> eliminatedActions{}; // is some action suboptimal ?
  avector<avector<uint32_t>> SCC{}; // [*]TVI: in reverse topological order
  avector<uint32_t> stateToSCC{}; // state id -> scc id
  avector<uint32_t> sccStartIds{}; // eTVI: scc id -> start id of scc states
  avector<uint32_t> newIds{}; // eTVI: state old id -> new id (after reordering)
  avector<uint32_t> oldIds{}; // eTVI: state new id -> old id (after reordering)
  avector<std::unordered_set<uint32_t>> SCCsOutwardBorderStates{}; // filled by findSCCsBorderStates

  // Necessary to compute the heuristics and upper bounds
  // id -> [(id_pred, action, arc_cost, arc_proba)]
  avector<avector<std::tuple<uint32_t, uint32_t, float, float>>> predecessors{};

  uint32_t n_states = 0; // doesn't count the states pruned when reordering
  uint32_t n_states_before = 0; // only used if states were reordered
  uint32_t n_backups = 0; // for benchmarking: number of bellman backups
  uint32_t currentId = 0; // used for e[i]TVI when reordering states
  float bellmanError = 0.0f;

  // Initialized in constructor
  uint32_t start_id;
  uint32_t goal_id;
  // const data member: the implicit copy/move assignment operators of MDP
  // are deleted because of it
  const std::string upperBoundName;
};

#endif
