#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <limits>
#include <string>

#include <gflags/gflags.h>

#include "mdp.hpp"
#include "utils.hpp"

using namespace std;

// Help text handed to gflags::SetUsageMessage() in main(). The message is
// assembled from adjacent string literals, grouped below by topic; blank
// lines of the rendered text are kept as stand-alone "\n" literals.
const std::string usageMsg(
  // Title
  "SSP-MDP Solver\n"
  "\n"
  // What the program does and where it reads from / writes to
  "This program finds an optimal value function V* and policy Pi*\n"
  "for a MDP specified in the `inputFile` parameter, and output them in\n"
  "the `outputFile` parameter.\n"
  "\n"
  // Accepted values of the `solver` flag
  "The solver to be used is specified in the `solver` parameter.\n"
  "Possible values:\n"
  "\tVI (Gauss-Seidel asynchronous variant)\n"
  "\tLAOstar" "\tILAOstar\n"
  "\tRTDP" "\tLRTDP" "\tBRTDP\n"
  "\tTVI" "\tFTVI\n"
  "\teTVI" "\teiTVI\n"
  "\tpcTVI\n"
  // Accepted values of the `heuristic` flag
  "The heuristic to be used is specified in the `heuristic` parameter.\n"
  "Possible values:\n"
  "\tNONE:  No heuristic (h = 0)\n"
  "\tH_MIN: h_min heuristic\n"
  // Accepted values of the `upperbound` flag
  "The upper bound to be used is specified in the `upperbound` parameter.\n"
  "Possible values:\n"
  "\tNONE:  No upper bound (V_u = infty)\n"
  "\tFTVI:  upper bound described in FTVI paper\n"
  "\tBRTDP: upper bound described in BRTDP paper (DS-MPI)\n"
  "\tBRTDP_FTVI: both previous upper bound combined\n");

// Command-line flags (gflags). Each DEFINE_* below provides the matching
// FLAGS_<name> variable that main() reads once the command line is parsed.

// Ids of the start and goal states. main() casts both to uint32_t before
// handing them to the MDP. The goal default is a "not specified" sentinel;
// per its help text the last state loaded is then used (presumably resolved
// while the MDP is loaded -- that logic is not in this file).
DEFINE_int32(start, 0, "id of the start state (needed for some solvers)");
DEFINE_int32(goal, std::numeric_limits<int>::max(),
    "id of the goal state [last state loaded if not specified]");

// I/O selection. The literal values "cin" / "cout" are sentinels that make
// main() use the standard streams instead of opening a file.
DEFINE_string(inputFile, "cin",
    "path to a file containing the MDP to solve (\"cin\" = standard input)");
DEFINE_string(outputFile, "cout",
    "path to a file where to output results (\"cout\" = standard output)");

// Algorithm selection; the accepted names are listed in the usage message.
DEFINE_string(solver, "VI", "name of the solver to use");
DEFINE_string(heuristic, "H_MIN", "name of the heuristic to use");
DEFINE_string(upperbound, "BRTDP", "name of the initial upper bound for FTVI/BRTDP");

// Optional processing steps and extra outputs. The benchmark and graphviz
// outputs are written to the stream selected by `outputFile` (which is cout
// only by default), so their help texts refer to that stream.
DEFINE_bool(reach, true, "do a reachability analysis before starting solver");
DEFINE_bool(policy, false, "print the obtained policy after the solver has completed");
DEFINE_bool(benchmark, false,
    "print the info used by the benchmark script to the output (see outputFile)");
DEFINE_bool(graphviz, false,
    "output graphviz format of the loaded MDP to the output (see outputFile) and exit");

int main(int argc, char* argv[]) {
  // Initialize logger (glog) and flags parser (gflags)
  google::InitGoogleLogging(argv[0]);
  gflags::SetUsageMessage(usageMsg);
  gflags::ParseCommandLineFlags(&argc, &argv, true);

  // Speedup I/O (mainly the reading of MDP from cin)
  // Safe since we don't use C stdio functions
  ios_base::sync_with_stdio(false);
  cin.tie(nullptr);

  LOG(INFO) << "Creating the output and input stream";
  ifstream ifs;
  if(FLAGS_inputFile != "cin")
    ifs.open(FLAGS_inputFile);
  istream& is = ifs.is_open() ? ifs : cin;
  ofstream ofs;
  if(FLAGS_outputFile != "cout")
    ofs.open(FLAGS_outputFile);
  ostream& os = ofs.is_open() ? ofs : cout;
  os.precision(5);

  LOG(INFO) << "Reading the MDP from the input stream";
  uint32_t start_id = static_cast<uint32_t>(FLAGS_start);
  uint32_t goal_id = static_cast<uint32_t>(FLAGS_goal);
  TimeVar tStart = timeNow();
  MDP mdp(start_id, goal_id, FLAGS_upperbound);
  is >> mdp;
  const long time_load = duration(timeNow() - tStart);

  if(FLAGS_graphviz) {
    LOG(INFO) << "MDP will be printed in graphviz format and program will stop";
    mdp.dumpGraphvizFormat(os);
    return EXIT_SUCCESS;
  }

  LOG(INFO) << "Initializing V_0 with the specified heuristic";
  tStart = timeNow();
  mdp.computeHeuristic(FLAGS_heuristic);
  const long time_heuristic = duration(timeNow() - tStart);

  tStart = timeNow();
  if(FLAGS_reach) {
    LOG(INFO) << "Detecting which states are reachable from the start";
    mdp.markReachableStates(start_id);
  }
  const long time_reach = duration(timeNow() - tStart);

  LOG(INFO) << "Solving the MDP with the specified solver";
  tStart = timeNow();
  mdp.solve(FLAGS_solver);
  const long time_solve = duration(timeNow() - tStart);

  tStart = timeNow();
  if(FLAGS_policy) {
    mdp.findAndEvaluateGreedyPolicy();
    mdp.printPolicy(os);
  }
  const long time_policy = duration(timeNow() - tStart);

  LOG(INFO) << "Max memory used: " << getMaxMemoryUsed() << " MB";
  LOG(INFO) << "Time to load MDP from file: " << time_load << " ms";
  LOG(INFO) << "Time to compute the heuristic: " << time_heuristic << " ms";
  LOG(INFO) << "Time to mark reachable states (DFS): " << time_reach << " ms";
  LOG(INFO) << "Time to solve the MDP: " << time_solve << " ms";
  LOG(INFO) << "Time to find and evaluate the policy: " << time_policy << " ms";

  if(FLAGS_benchmark)
    os << time_solve << flush;
}
