package generators;

import terms.*;
import parsers.*;
import util.*;

/** A "grammar" that enumerates all terms of a given, sorted signature. */
/**
 * A "grammar" that enumerates all terms of a given, sorted signature.
 *
 * <p>The generator has two modes, selected by {@link #isEnumerate}:
 * <ul>
 *   <li><em>enumeration</em> ({@code isEnumerate == true}): the terms of the main sort are
 *       produced one at a time by a systematic search whose depth grows as needed
 *       (see {@code advance});</li>
 *   <li><em>random generation</em> ({@code isEnumerate == false}): a term is grown step by
 *       step by replacing sort symbols with randomly chosen symbols of that sort, weighted by
 *       the symbol weights (see {@code refine}).</li>
 * </ul>
 * The commands of both modes are listed by {@link #commands()} and executed by
 * {@link #execute(String)}; the term to display is returned by {@link #currentTerm()}.
 */
public class treeEnumerator extends treeGrammar {

//====================================================================
// First the 'interactive' part that deals with commands ...
//====================================================================
  protected final static String advance = "advance";
  protected static String enumerate = "enumeration";
  protected static String generate = "random generation";
  protected static String refine = "refine";
  protected static String back = "back";
  protected static String reset = "reset";
  // Command groups offered in enumeration mode (eCommands) and in random-generation mode
  // (gCommands); the last group of each switches to the respective other mode.
  private static final String[][] eCommands = {{ advance, reset }, {generate}};
  private static final String[][] gCommands = {{ refine, back, reset }, {enumerate}};
  /** {@code true} in enumeration mode, {@code false} in random-generation mode. */
  public boolean isEnumerate = true;

  /**
   * Returns the commands available in the current mode, as a list of {@code String[]} groups.
   * The groups are copies, so callers cannot modify the command tables of this class.
   */
  public list commands() {
    list result = new list();
    String[][] groups = isEnumerate ? eCommands : gCommands;
    for (int i = 0; i < groups.length; i++) result.append(groups[i].clone());
    return result;
  }

  /**
   * Executes the given command and recomputes the term returned by {@link #currentTerm()}.
   * Unknown commands, and commands not applicable in the current mode, have no effect other
   * than this recomputation.
   *
   * @param command one of the command strings returned by {@link #commands()}
   */
  public void execute(String command) {
    if (advance.equals(command) && isEnumerate) advance(0);
    else if (reset.equals(command)) reset();
    else if (refine.equals(command) && !isEnumerate) refine();
    else if (back.equals(command) && !isEnumerate) back();
    else if (enumerate.equals(command) && !isEnumerate) {
      // Switch to enumeration, continuing from the randomly generated (partial) term.
      isEnumerate = true;
      if (currTerm != null) {
        noMoreTerms = false;
        // If every remaining sort symbol can be terminated, the completed term becomes the
        // current one. Otherwise (the term may then be partially terminated) the search
        // resumes at the depth of the shallowest sort symbol left in the term.
        if (randomTermination(currTerm)) currDepth = depth(currTerm);
        else {
          currDepth = minSortDepth(currTerm);
          if (currDepth==0) advance(0); else advance(1);
        }
      }
      else noMoreTerms = true;
    }
    else if (generate.equals(command) && isEnumerate) {
      // Switch to random generation. The enumerated term is kept, but its deepest level
      // (currDepth) is turned back into sort symbols so that it can be refined further.
      isEnumerate = false;
      // currTerm is still null if no term has been enumerated yet; start afresh in that case
      // instead of dereferencing it.
      if (noMoreTerms || currTerm == null) reset();
      else {
        removeLeaves(currTerm, currDepth);
        currDepth = 0;
      }
    }
    computeResultingTerm();
  }

  /**
   * Returns {@code true} iff {@code com} is the reset command.
   * The strings are compared with {@code equals}, so any string equal to the reset command
   * qualifies, not only the identical instance. {@code null} yields {@code false}.
   * NOTE(review): what "requesting exit" means is defined by {@code treeGrammar}.
   */
  public boolean requestsExit(String com) {
    return reset.equals(com);
  }

  // The term being enumerated or generated; null before the first term exists or after a
  // failed refinement.
  private term currTerm;
  // The term handed out by currentTerm(); recomputed by computeResultingTerm().
  private term resultingTerm;
  private sortManager sm = new sortManager();
  private int currDepth = 0; // The current depth we are at in the search tree.
  // Set once the enumeration has found that there are no further terms.
  private boolean noMoreTerms = false;
  // Index (as returned by sm.findSort) of the sort whose terms are produced.
  private int mainSort;

/** The trees of this sort are those to be enumerated. */
  public void setMainSort(String sort) { mainSort = sm.findSort(sort); }

  /**
   * Delegates to {@code sortManager.initMinDepths()}.
   * NOTE(review): presumably to be called after all symbols have been added.
   */
  public void initMinDepths() {
    sm.initMinDepths();
  }

  /**
   * Registers a symbol with the sort manager (delegates to {@code sortManager.addSymbol}).
   *
   * @param name       name of the symbol
   * @param resultSort name of the sort of terms with this top symbol
   * @param argSort    names of the argument sorts, one per argument
   * @param weight     weight used when symbols are chosen at random
   */
  public void addSymbol(String name, String resultSort, String[] argSort, double weight) {
    sm.addSymbol(name, resultSort, argSort, weight);
  }

  /** Returns the term to be displayed, or {@code null} if there is none (e.g. no more terms). */
  public term currentTerm() {
    return resultingTerm;
  }

  // Recomputes resultingTerm from currTerm. In enumeration mode, only the first currDepth
  // levels are valid, so the term is copied down to that depth. In generation mode, a clone is
  // used.
  private void computeResultingTerm() {
    if (currTerm == null) resultingTerm = null;
    else if (isEnumerate) {
       resultingTerm = noMoreTerms ? null : computeResultingTerm(currTerm, currDepth);
     }
    else resultingTerm = (term)currTerm.clone();
  }

// Copy the argument term up to a given depth (or, if depth < 1, up to the leaves). The
// arguments of the nodes at the last copied level are replaced by the sort symbols of the
// corresponding argument sorts.
  private term computeResultingTerm(term t, int depth) {
    symbol top = t.topSymbol();
    term result = new term(top);
    int rank = top.rank();
    while (rank > 0) {
      rank--;
      if (depth == 1) {
        result.defineSubterm(rank, new term(sm.sortSymbol(sm.sortIndex(top, rank))));
      }
      else {
        result.defineSubterm(rank, computeResultingTerm(t.subterm(rank), depth - 1));
      }
    }
    return result;
  }

  // Restarts the current mode: enumeration from the first term, or random generation from a
  // lone sort symbol of mainSort.
  private void reset() {
    noMoreTerms = false;
    currDepth = 0;
    if (isEnumerate) advance(0);
    else currTerm = new term(sm.sortSymbol(mainSort));
  }

  /**
   * Moves the enumeration to the next term.
   *
   * <p>The search state is {@code currTerm} together with {@code currDepth}, the number of
   * levels of {@code currTerm} that are currently valid. The method returns as soon as no
   * levels are missing any more. If the search runs out of candidates, {@code noMoreTerms}
   * is set instead.
   *
   * @param missingLevels number of levels below {@code currDepth} that still have to be built
   *     before a term is produced (0 = move on to the next term at the current depth)
   */
  private void advance(int missingLevels) {
    int depthIncreased = 0;
    boolean branchCompleted = false;
    list bottomLine;
    while (!noMoreTerms) {
      allowExit();
      if (currDepth == 0) {
        // At the root of the search tree. During this call, depthIncreased counts the passes
        // through the root, plus the exhausted branches at the target depth. Once it exceeds
        // the number of sorts without a term having been found, there are no more terms left.
        // This includes the case that mainSort has no symbols at all.
        if (depthIncreased > sm.numberOfSorts()) noMoreTerms = true;
        else {
          // Go to the next deeper level of the search tree. Put the first acceptable symbol of
          // mainSort at the root; if no further levels are missing, it must be a constant.
          symbol sym = sm.firstSymbol(mainSort);
          while (sym != null) {
            if (missingLevels > 0 || sym.rank() == 0) {
              currTerm = new term(sym);
              break;
            }
            sym = sm.nextSymbol(sym);
          }
          depthIncreased++;
          if (sym == null) missingLevels++;
          else {
            currDepth = 1;
            branchCompleted = false;
          }
        }
      }
      else {
        // Not at the root of the search tree: first collect the nodes at maximum depth.
        bottomLine = getBottomLine(currTerm, currDepth - 1);
        if (missingLevels == 0) {
          // We are at the depth to be reached. Try to move to the next sibling of this node of
          // the search tree. If there is none, this branch is completed and we move to the
          // parent node.
          if (!tryAdvance(bottomLine, true)) {
            branchCompleted = true;
            missingLevels = 1;
            currDepth--;
            depthIncreased++;
          }
        }
        else if (branchCompleted) {
          // The necessary depth has not been reached yet. As above, try to move to the next
          // sibling. If it exists, note that this branch has not been completed yet;
          // otherwise step back to the parent node.
          if (tryAdvance(bottomLine, false)) branchCompleted = false;
          else { missingLevels++; currDepth--; }
        }
        else if (tryDownwards(bottomLine, missingLevels == 1)) {
          // The branch has not been completed and the required depth has not been reached:
          // we moved to the first child of the current node.
          missingLevels--; currDepth++;
        }
        // The current node has no child, so this branch is completed.
        else branchCompleted = true;
      }
      if (missingLevels == 0) return; // Return if the required depth has been reached.
    }
  }

// Determine the list of subterms of t at a given depth (0 = t itself), from left to right.
  private list getBottomLine(term t, int atDepth) {
    list result = new list();
    if (atDepth == 0) {
      result.append(t);
    }
    else {
      int rank = t.topSymbol().rank();
      for (int i = 0; i < rank; i++) result.concat(getBottomLine(t.subterm(i), atDepth - 1));
    }
    return result;
  }

// This is implemented by "counting". Starting at the head of the list, search for a symbol
// that is not the largest of its sort. Every symbol passed on the way is replaced by the
// minimal one of its sort, the symbol found is replaced by the next larger one, and the rest
// remains untouched. If the search fails (i.e., the end of the list is reached before a
// symbol that is not maximal is found), false is returned. If constantsOnly is set, only
// symbols of rank 0 are considered.
  private boolean tryAdvance(list bottomLine, boolean constantsOnly) {
    while (!bottomLine.isEmpty()) {
      allowExit();
      term t = (term)bottomLine.head();
      symbol symb = sm.nextSymbol(t.topSymbol());
      while (symb != null) {
        if (!constantsOnly || symb.rank() == 0) {
          t.relabel(symb);
          return true;
        }
        symb = sm.nextSymbol(symb);
      }
      symb = sm.firstSymbol(sm.sortIndex(t.topSymbol()));
      while (symb != null) {
        if (!constantsOnly || symb.rank() == 0) {
          t.relabel(symb);
          bottomLine = bottomLine.tail();
          break;
        }
        symb = sm.nextSymbol(symb);
      }
      if (symb == null) return false;
    }
    return false;
  }

// Give every term in bottomLine with a non-constant top symbol a new level of arguments: each
// argument becomes the first symbol of its sort (the first constant if constantsOnly is set).
// Returns false if some argument sort has no acceptable symbol; terms handled before that
// point stay modified. Otherwise returns whether at least one term received arguments.
  private boolean tryDownwards(list bottomLine, boolean constantsOnly) {
    boolean depthReached = false;
    while (!bottomLine.isEmpty()) {
      allowExit();
      term t = (term)bottomLine.head();
      symbol top = t.topSymbol();
      if (top.rank() > 0) {
        for (int i = 0; i < top.rank(); i++) {
          symbol symb = sm.firstSymbol(sm.sortIndex(top,i));
          while (symb != null) {
            if (!constantsOnly || symb.rank() == 0) {
              t.defineSubterm(i, new term(symb));
              break;
            }
            symb = sm.nextSymbol(symb);
          }
          if (symb == null) return false;
        }
        depthReached = true;
      }
      bottomLine = bottomLine.tail();
    }
    return depthReached;
  }

//====================================================================
// Now the random derivations ...
//====================================================================

  // Number of levels of t made of non-sort symbols (0 for null or a sort symbol).
  private int depth(term t) {
    if (t == null) return 0;
    symbol top = t.topSymbol();
    if (sm.isSort(top)) return 0;
    int result = 0;
    for (int i = 0; i < top.rank(); i++) result = Math.max(result, depth(t.subterm(i)));
    return result + 1;
  }

  // Depth of the shallowest sort symbol in t (0 if t is null or a sort symbol itself), or -1 if
  // t contains no sort symbol.
  private int minSortDepth(term t) {
    if (t == null) return 0;
    symbol top = t.topSymbol();
    if (sm.isSort(top)) return 0;
    int result = -2;
    for (int i = 0; i < top.rank(); i++) {
      int d = minSortDepth(t.subterm(i));
      if (d >= 0) {
        if (result >= 0) result = Math.min(result, d);
        else result = d;
      }
    }
    return result + 1;
  }

  // Performs one random refinement step on currTerm; if that fails, currTerm becomes null.
  private void refine() {
    if (currTerm != null && !refine(currTerm)) currTerm = null;
  }

  // Replaces every sort symbol in t by a randomly chosen symbol of that sort and gives it fresh
  // sort-symbol arguments (these are not refined in the same step). Symbols with positive
  // weight are chosen with probability proportional to their weight; if the sort has none,
  // all its symbols are equally likely. Returns false, possibly after partially refining t,
  // if some sort has no symbols.
  private boolean refine(term t) {
    allowExit();
    symbol top = t.topSymbol();
    if (sm.isSort(top)) {
      symbol symb = sm.firstSymbol(sm.sortIndex(top));
      double weightSum = 0;
      double pseudoWeightSum = 0;
      boolean success = false;
      while (symb != null) {
        double weight = sm.getWeight(symb);
        if (weight > 0) {
          weightSum += weight;
          if (Math.random() <= weight/weightSum) {
            t.relabel(symb);
            success = true;
          }
        }
        else if (weightSum == 0) {
          // Uniform choice among non-positive-weight symbols; it is overridden as soon as a
          // symbol with positive weight is encountered.
          pseudoWeightSum += 1;
          if (Math.random() <= 1.0/pseudoWeightSum) {
            t.relabel(symb);
            success = true;
          }
        }
        symb = sm.nextSymbol(symb);
      }
      if (success) {
        top = t.topSymbol();
        for (int i  = 0; i < top.rank(); i++) {
          t.defineSubterm(i, new term(sm.sortSymbol(sm.sortIndex(top, i))));
        }
      }
      else return false;
    }
    else for (int i = 0; i < top.rank(); i++) if (!refine(t.subterm(i))) return false;
    return true;
  }

  // Replaces every sort symbol in t by a finite term (see terminate). Returns false as soon as
  // a sort symbol whose minimal depth is Integer.MAX_VALUE is met (presumably: no finite term
  // of that sort exists); sort symbols handled before that point stay terminated.
  private boolean randomTermination(term t) {
    symbol top = t.topSymbol();
    if (sm.isSort(top)) {
      if (sm.getMinDepth(top) != Integer.MAX_VALUE) terminate(t);
      else return false;
      return true;
    } else for (int i = 0; i < top.rank(); i++) if (!randomTermination(t.subterm(i))) return false;
    return true;
  }

  // Replaces the sort symbol at the root of t by a symbol whose minimal depth equals that of
  // the sort, then recursively does the same for the new arguments. Eligible symbols with
  // positive weight are chosen with probability proportional to their weight.
  // NOTE(review): if no eligible symbol has positive weight, the last eligible one is taken
  // deterministically (refine chooses uniformly in that case) - confirm this is intended.
  private void terminate(term t) {
    symbol top = t.topSymbol();
    symbol symb = sm.firstSymbol(sm.sortIndex(top));
    double weightSum = 0;
    while (symb != null) {
      double weight = sm.getWeight(symb);
      if ((weight > 0 || weightSum == 0) && sm.getMinDepth(symb) == sm.getMinDepth(top)) {
        weightSum += weight;
        if (weightSum == 0 || Math.random() <= weight/weightSum) t.relabel(symb);
      }
      symb = sm.nextSymbol(symb);
    }
    top = t.topSymbol();
    for (int i  = 0; i < top.rank(); i++) {
      t.defineSubterm(i, new term(sm.sortSymbol(sm.sortIndex(top, i))));
      terminate(t.subterm(i));
    }
  }

  // Turns the deepest level of non-sort symbols of currTerm back into sort symbols.
  private void back() {
    int depth = depth(currTerm);
    if (depth > 0) removeLeaves(currTerm, depth);
  }

  // Relabels every node of t at the given level (1 = t itself) with the sort symbol of its
  // sort; a depth below 1 has no effect.
  private void removeLeaves(term t, int depth) {
    symbol top = t.topSymbol();
    if (depth == 1) t.relabel(sm.sortSymbol(sm.sortIndex(top)));
    else for (int i = 0; i < top.rank(); i++) removeLeaves(t.subterm(i), depth - 1);
  }


//====================================================================
// and finally parsing ...
//====================================================================

/** Parse a definition of a treeEnumerator (i.e., of the corresponding sorted
  * signature) and restart the current mode with it.
  * For the syntax see <code>treeEnumeratorParser</code>.
  * @exception ParseException if an error occurs
  * @see treeEnumeratorParser
  * @see parsable
  */
  public void parse(ASCII_CharStream stream) throws ParseException {
    treeEnumeratorParser parser = new treeEnumeratorParser(stream);
    treeEnumerator enm = parser.treeEnumerator();
    sm = enm.sm;
    mainSort = enm.mainSort;
    // Discard the search state left over from a previously parsed signature (noMoreTerms,
    // currDepth, currTerm) and restart in the current mode. On a fresh instance this is
    // the same as advance(0).
    reset();
    computeResultingTerm();
  }

}

