/**
 * This program generates chained MDPs, which are MDPs containing many
 * independent chains of SCCs (i.e., chains of SCCs that can be solved in parallel).
 *
 * A chained MDP is parameterized by five parameters:
 *   - The number 'n_c' of parallel chains of SCCs in the MDP.
 *   - The number 'n_scc' of SCCs in each chain.
 *   - The number 'n_sps' of states per SCC.
 *   - The number 'n_a' of actions per state.
 *   - The number 'n_e' of probabilistic effects per action.
 *
 * Successors of a state can only be in the same SCC or in the SCC "to the right" in the same chain.
 */

#include <algorithm>
#include <charconv>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <iomanip>
#include <iostream>
#include <numeric>
#include <optional>
#include <random>
#include <set>
#include <vector>

using namespace std;

// Pseudo-random engine shared by the whole program. It is seeded once, at
// start-up, from std::random_device, so runs are generally not reproducible.
std::random_device rd;
std::mt19937 gen{rd()};

/**
 * Size parameters of a chained MDP.
 *
 * Every member defaults to 0 so that a default-constructed value is never
 * left with indeterminate contents.
 */
struct MDPSpecs {
  std::size_t n_c = 0;   // number of parallel chains
  std::size_t n_scc = 0; // number of SCCs in each chain
  std::size_t n_sps = 0; // number of states per SCC
  std::size_t n_a = 0;   // number of actions per state
  std::size_t n_e = 0;   // number of probabilistic effects per action
};

/**
 * Generate k random probabilities: k non-negative values that sum to 1
 * (up to floating-point rounding).
 *
 * Each value is first drawn from a uniform real distribution over the unit
 * interval using the global engine `gen`, then divided by the sum of all
 * draws.
 *
 * @param k number of probabilities to produce
 * @return a vector of k probabilities; empty when k is 0
 */
std::vector<float> generate_probabilities(std::size_t k) {
  std::vector<float> weights(k);
  if (weights.empty())
    return weights;

  // Draw k raw weights, in order, from the shared engine
  std::uniform_real_distribution<float> unit(0.0f, 1.0f);
  for (float& w : weights)
    w = unit(gen);

  const float total = std::accumulate(weights.begin(), weights.end(), 0.0f);
  if (!(total > 0.0f)) {
    // Every draw was exactly 0 (most plausible when k == 1): dividing would
    // yield NaN, so fall back to the uniform distribution instead.
    std::fill(weights.begin(), weights.end(), 1.0f / static_cast<float>(k));
    return weights;
  }

  // Normalize so the sum is 1
  for (float& w : weights)
    w /= total;

  return weights;
}

/**
 * Generate a chained MDP and dump a textual representation to stdout.
 *
 * State numbering:
 *   - State 0 is the initial state. It has one action per chain, each leading
 *     with probability 1 to the first state of that chain.
 *   - The states of chain c start at id 1 + c * n_scc * n_sps and are laid out
 *     SCC after SCC, n_sps consecutive ids per SCC.
 *   - The last id (n_c * n_scc * n_sps + 1) is the goal state; it has no actions.
 *
 * Output format:
 *   <number of states>
 *   then, for every state:
 *     <state id> <number of actions>
 *     and one line per action:
 *       <cost> <number of effects> [<successor id> <probability>]...
 *
 * Every action costs 1. The successors of an action are drawn uniformly at
 * random, with replacement, among the other states of the same SCC plus either
 * all states of the next SCC of the chain or, for the last SCC, the goal state.
 *
 * NOTE(review): because successors are drawn with replacement, one action may
 * list the same successor several times. Probabilities are printed with three
 * decimals, so the printed values of an action may not add up to exactly 1.
 * Confirm that the consumer of this format tolerates both.
 *
 * @param specs size parameters of the MDP to generate
 */
void generate_chained_mdp(const MDPSpecs& specs) {
  const std::size_t n_states_per_chain = specs.n_scc * specs.n_sps;
  const std::size_t n_states = specs.n_c * n_states_per_chain + 2;
  const std::size_t goal_id = n_states - 1;
  const float action_cost = 1.0f;

  std::cout << n_states << '\n';

  // All floating-point values are printed in fixed notation with 3 decimals
  std::cout << std::fixed;
  std::cout.precision(3);

  // Initial state: one deterministic action towards the head of each chain
  std::cout << 0 << " " << specs.n_c << "\n";
  for (std::size_t c = 0; c < specs.n_c; ++c) {
    const std::size_t chain_begin = 1 + c * n_states_per_chain;
    std::cout << action_cost << " " << 1 << " " << chain_begin << " " << 1.00 << "\n";
  }

  // Scratch buffer reused for every state so its capacity is allocated once
  std::vector<std::size_t> neighbors;

  for (std::size_t c = 0; c < specs.n_c; ++c) {
    const std::size_t chain_begin = 1 + c * n_states_per_chain;

    for (std::size_t s = 0; s < specs.n_scc; ++s) {
      const std::size_t scc_begin = chain_begin + s * specs.n_sps;
      const std::size_t scc_end = scc_begin + specs.n_sps;
      const bool is_last_scc = (s + 1 == specs.n_scc);

      for (std::size_t id = scc_begin; id < scc_end; ++id) {
        std::cout << id << " " << specs.n_a << "\n";

        // The candidate successors depend only on the state, not on the
        // action, so they are computed once per state.
        neighbors.clear();
        for (std::size_t j = scc_begin; j < scc_end; ++j) {
          if (j != id) // prevent self-loop
            neighbors.push_back(j);
        }
        if (is_last_scc) {
          // last SCC of a chain can go to the goal
          neighbors.push_back(goal_id);
        } else {
          // any other SCC can go to every state of the next SCC
          for (std::size_t j = scc_end; j < scc_end + specs.n_sps; ++j)
            neighbors.push_back(j);
        }

        for (std::size_t a = 0; a < specs.n_a; ++a) {
          std::cout << action_cost << " " << specs.n_e;

          // `neighbors` is never empty here: it holds at least the goal or
          // the states of the next SCC.
          std::uniform_int_distribution<std::size_t> pick(0, neighbors.size() - 1);

          // Probabilities are drawn before successors; keep this order so the
          // sequence of values taken from `gen` stays the same.
          const std::vector<float> p = generate_probabilities(specs.n_e);
          for (std::size_t e = 0; e < specs.n_e; ++e) {
            const std::size_t successor = neighbors[pick(gen)];
            std::cout << " " << successor << " " << p[e];
          }

          std::cout << "\n";
        } // actions
      } // states
    } // SCCs
  } // chains

  // the last state is a goal state (no actions); endl also flushes the output
  std::cout << goal_id << " " << 0 << std::endl;
}

namespace {

/**
 * Parse `text` as a base-10 non-negative integer.
 *
 * The whole string must be a number: empty strings, signs, whitespace,
 * trailing characters and values that do not fit in size_t are all rejected.
 *
 * @return the parsed value, or nullopt when `text` is not a valid count
 */
std::optional<std::size_t> parse_count(const char* text) {
  std::size_t value = 0;
  const char* const last = text + std::strlen(text);
  const auto [parsed_end, error] = std::from_chars(text, last, value);
  if (error != std::errc{} || parsed_end != last)
    return std::nullopt;
  return value;
}

/** Print the command-line synopsis to stderr. */
void print_usage(const char* program) {
  std::cerr << "Usage: " << program
            << " n_chains n_scc n_states/scc n_actions/state n_effects/action"
            << std::endl;
}

} // namespace

/**
 * Entry point: read the five size parameters from the command line and print
 * the generated chained MDP to stdout.
 *
 * Returns EXIT_FAILURE (after printing the usage to stderr) when the number of
 * arguments is wrong or when an argument is not a non-negative integer.
 */
int main(int argc, char* argv[]) {
  constexpr int n_params = 5;

  // argv[0] may be null when argc is 0; never stream a null pointer
  const char* const program =
    (argc > 0 && argv[0] != nullptr) ? argv[0] : "<program>";

  if (argc != n_params + 1) {
    print_usage(program);
    return EXIT_FAILURE;
  }

  std::size_t counts[n_params] = {};
  for (int i = 0; i < n_params; ++i) {
    const std::optional<std::size_t> parsed = parse_count(argv[i + 1]);
    if (!parsed) {
      std::cerr << "Invalid argument '" << argv[i + 1]
                << "': expected a non-negative integer\n";
      print_usage(program);
      return EXIT_FAILURE;
    }
    counts[i] = *parsed;
  }

  const MDPSpecs specs = {
    .n_c   = counts[0],
    .n_scc = counts[1],
    .n_sps = counts[2],
    .n_a   = counts[3],
    .n_e   = counts[4]
  };

  // The program only writes through iostreams, so C stdio synchronisation is
  // unnecessary overhead for the (potentially very large) output.
  std::ios_base::sync_with_stdio(false);

  generate_chained_mdp(specs);
  return EXIT_SUCCESS;
}
