package generators;

import java.io.*;
import java.util.*;
import parsers.*;
import terms.*;
import util.*;

/** An implementation of ET0L tree grammars.
  */
public class ET0LTreeGrammar extends treeGrammar {

  regularTreeGrammar gram;
  tdTransducer trans;

//====================================================================
// First the 'interactive' part that deals with commands ...
//====================================================================

  private static String enumerate = "table enumeration";
  private static String advance = "advance";
  private static String single = "derive stepwise";
  private static String complete = "results only";
  private static String step = "derivation step";
  private static String random = "random tables";
  private static String refine = "refine";
  private static String back = "back";
  private static String reset = "reset";
  private static String newSeed = "new derivation";
  private static String[][] eCommands = {{ advance, reset }, {random}};
  private static String[][] rCommands = {{ refine, back, reset }, {enumerate}};
  private static String[][] sCommands = {{ step, back }, { complete }};
  private static String[][] cCommands = {{ single }};
  private static String[] tCommands = { newSeed };

  public list commands() {
    list result = super.commands();
    if (gram.isEnumerate) {
      for (int i = 0; i < eCommands.length; i++) result.append(eCommands[i]);
      if (gram.isStepwise) for (int i = sCommands.length; i > 0;) result.tail().prepend(sCommands[--i]);
      else for (int i = cCommands.length; i > 0;) result.tail().prepend(cCommands[--i]);
    }
    else  for (int i = 0; i < rCommands.length; i++) result.append(rCommands[i]);
    if (!trans.isDeterministic()) result.append(tCommands);
    return result;
  }
  
  public void execute(String com) {
    recompute = true;
    if (enumerate.equals(com)) gram.execute("enumeration");
    else if (advance.equals(com)) gram.execute("advance");
    else if (single.equals(com)) { gram.execute("derive stepwise"); gram.execute("single step"); }
    else if (complete.equals(com)) gram.execute("results only");
    else if (step.equals(com)) gram.execute("single step");
    else if (random.equals(com)) {
      gram.execute("random generation");
      gram.execute("back");
      checkAdvance();
    }
    else if (refine.equals(com)) gram.execute("refine");
    else if (back.equals(com)) {
      gram.execute("back");
      checkAdvance();
    }
    else if (reset.equals(com)) {
      gram.execute("reset");
      if (!gram.isEnumerate || gram.isStepwise) checkAdvance();
    }
    else if (newSeed.equals(com)) trans.execute("new random seed");
    else super.execute(com);
  }
  
  private void checkAdvance() {
    term t = gram.currentTerm();
    if (t != null && !t.topSymbol().equals(init)) {
      if (!gram.isEnumerate) gram.execute("refine");
      else gram.execute("single step");
    }
  }
  
  public boolean requestsExit(String com) {
    return com.equals(advance) || com.equals(refine) || com.equals(reset);
  }

//====================================================================
// Now the actual implementation of ET0L tree grammars ...
//====================================================================

  private static String grammarExtension = ".1";
  private static String transducerExtension = ".2";
  
  private finiteSignature outputSig;
  private finiteSignature tableSig = new finiteSignature();
  private fixedRankSignature nontermSig;
  private fixedRankSignature origNontermSig;
  private fixedRankSignature grammarNonterminals = new fixedRankSignature(0);
  private term axiom;
  private Vector tableNames;
  private Vector tables;
  private Vector regularRules = new Vector();
  private Vector chainRules = new Vector();
  private rexp regulation = null;
  private BitSet terminatingTable;
  private BitSet nonterminatingTable;
  private symbol bottom = new symbol("bot", 0);
  private symbol init = new symbol("init", 1);
  private symbol startState = new symbol("q0",1);

  private term current = null;
  private boolean recompute = true;
  
  public synchronized term currentTerm() {
    if (recompute) {
      try { current = trans.apply(gram.currentTerm()); }
      catch (ExitException e) { return current = null; }
      recompute = false;
    }
    return current;
  }
  
/** Initialize this <code>ET0LTreeGrammar</code> by reading its definition from a stream.
  * The syntax is defined by the class <code>ET0LTreeGrammarParser</code>.
  * Before the grammar can be used, <code>translate</code> must be called with a
  * filename (which is used to store the regular tree grammar and the top-down tree
  * transducer into which the grammar is translated).
  * @see ET0LTreeGrammarParser
  * @exception ParseException if an error occurs
  */
  public void parse(ASCII_CharStream stream) throws ParseException {
    ET0LTreeGrammarParser parser = new ET0LTreeGrammarParser(stream);
    parser.ET0LTreeGrammar();
    outputSig = parser.sig;
    nontermSig = new fixedRankSignature(1);
    origNontermSig = parser.nonterminals;
    Enumeration enm = origNontermSig.elements();
    while (enm.hasMoreElements()) {
      nontermSig.addSymbol(((symbol)enm.nextElement()).toString());
    }
    this.axiom = parser.axiom;
    tableNames = parser.tableNames;
    tables = parser.tables;
    regulation = parser.regulation;
    addImplicitRules();
    determineTermination();
    computeTableSig();
  }
  
  private void addImplicitRules() {
    for (int t=0; t < tables.size(); t++) {
      Vector collect = new Vector();
      Vector table = (Vector)tables.elementAt(t);
      Enumeration enm = nontermSig.elements();
      search:
      while (enm.hasMoreElements()) {
        symbol s = (symbol)enm.nextElement();
        for (int i=0; i < table.size(); i++) {
          if (s.equals(lhsSymbol(table,i))) continue search;
        }
        collect.addElement(s);
      }
      String tableName = (String)tableNames.elementAt(t);
      for (int i=0; i < collect.size(); i++) {
        Object[] rule = new Object[3];
        table.addElement(rule);
        symbol sym = new variable(1);
        term aux = new term(sym);
        sym = new symbol(tableName, 1);
        term aux2 = new term(sym);
        aux2.defineSubterm(0, aux);
        aux = aux2;
        aux2 = new term((symbol)collect.elementAt(i));
        aux2.defineSubterm(0, aux);
        rule[0] = aux2; // left-hand side
        sym = new variable(1);
        aux = new term(sym);
        aux2 = new term((symbol)collect.elementAt(i));
        aux2.defineSubterm(0, aux);
        rule[1] = aux2; // right-hand side
        rule[2] = new Double(1); // weight
      }
    }
  }
  
  private void determineTermination() {
    terminatingTable = new BitSet(tables.size());
    nonterminatingTable = new BitSet(tables.size());
    for (int t=0; t < tables.size(); t++) {
      boolean termFailed = false;
      boolean nontermFound = false;
      Vector table = (Vector)tables.elementAt(t);
      Enumeration enm = nontermSig.elements();
      while (enm.hasMoreElements()) {
        symbol s = (symbol)enm.nextElement();
        boolean termFound = false;
        for (int i=0; i < table.size(); i++) {
          if (s.equals(lhsSymbol(table,i))) {
            if (isTerminal(rhsTerm(table,i))) termFound = true;
            if (isNonterminal(rhsTerm(table,i))) nontermFound = true;
          }
        }
        if (!termFound) termFailed = true;
      }
      if (!termFailed) terminatingTable.set(t);
      if (nontermFound) nonterminatingTable.set(t);
    }
  }
  
  private symbol lhsSymbol(Vector table, int ruleNo) {
    Object[] rule = (Object[])table.elementAt(ruleNo);
    term t = (term)rule[0];
    return t.topSymbol();
  }
  
  private term lhsTerm(Vector table, int ruleNo) {
    Object[] rule = (Object[])table.elementAt(ruleNo);
    return (term)rule[0];
  }
  
  private term rhsTerm(Vector table, int ruleNo) {
    Object[] rule = (Object[])table.elementAt(ruleNo);
    return (term)rule[1];
  }
  
  private double weight(Vector table, int ruleNo) {
    Object[] rule = (Object[])table.elementAt(ruleNo);
    return ((Double)rule[2]).doubleValue();
  }
  
  private boolean isTerminal(term t) {
    symbol top = t.topSymbol();
    if (nontermSig.contains(top)) return outputSig.contains(new symbol(top.toString(),0));
    for (int i=0; i < top.rank(); i++) {
      if (!isTerminal(t.subterm(i))) return false;
    }
    return true;
  }
  
  private boolean isNonterminal(term t) {
    symbol top = t.topSymbol();
    if (nontermSig.contains(top)) return true;
    for (int i=0; i < top.rank(); i++) {
      if (isNonterminal(t.subterm(i))) return true;
    }
    return false;
  }
  
  private void computeTableSig() {
    tableSig.addSymbol(init);
    for (int i=0; i < tables.size(); i++) {
      String tableName = (String)tableNames.elementAt(i);
       tableSig.addSymbol(new symbol(tableName,1));
    }
    tableSig.addSymbol(bottom);
  }
  
  public void translate(String fileName) throws ParseException {
    File file1 = new File(fileName + grammarExtension);
    File file2 = new File(fileName + transducerExtension);
    boolean generated = false;
    try {
      if (regulation == null) generateDefaultGrammar(openOutFile(file1));
      else generateGrammar(openOutFile(file1), regulation);
      generateTransducer(openOutFile(file2));
      generated = true;
    }
    catch (IOException e) {
      if (!(file1.exists() && file2.exists()))
      {
        throw new ParseException("Could not write to output file:\n" + e.getMessage());
      }
    }
    catch (java.security.AccessControlException e) {
      if (!(file1.exists() && file2.exists()))
      {
        throw new ParseException("Could not write to output file:\n" + e.getMessage());
      }
    }
    try {
      objectParser parser = new objectParser(new ASCII_CharStream(new FileInputStream(file1),1,1));
      gram = (regularTreeGrammar)parser.parse();
      parser = new objectParser(new ASCII_CharStream(new FileInputStream(file2),1,1));
      trans = (tdTransducer)parser.parse();
    }
    catch (ParseException e) {
      if (generated) {
        throw new ParseException("Could not parse generated file:\n" +
          e.getMessage() +
          "\n[This should not happen. There must be one of those " +
          "small insects in the implementation 8-( ]");
      }
      else {
        throw new ParseException("Could not parse previously existing file:\n" +
          e.getMessage() +
          "\nPlease make sure that the file is writable to enable " +
          "generation of new file");
      }
    }
    catch (FileNotFoundException e) {
      throw new ParseException("Could not open file:\n" + e.getMessage());
    }
  }
  
  private FileWriter openOutFile(File file) throws IOException {
    if (file.exists()) file.delete();
    FileWriter result = new FileWriter(file);
    return result;
  }
  
  private symbol grammarSymbol(int n) {
    return new symbol("S" + n, 0);
  }
  
  private void generateDefaultGrammar(FileWriter out) throws IOException {
    grammarNonterminals.addSymbol("S");
    grammarNonterminals.addSymbol("S0");
    out.write("generators.stubbornRegularTreeGrammar:\n");
    out.write("  (\n");
    out.write("    " + grammarNonterminals + ",\n");
    out.write("    " + tableSig + ",\n");
    out.write("    {\n");
    tableEnumerationRules(out);
    out.write("    },\n");
    out.write("    S\n");
    out.write("  )");
    out.close();
  }
  
  private void tableEnumerationRules(FileWriter out) throws IOException {
    if (isTerminal(axiom)) out.write("      S -> " + init + "[" + bottom + "] weight 0,\n");
    out.write("      S -> " + init + "[" + grammarSymbol(0) + "],\n");
    for (int t=0; t < tables.size(); t++) {
      if (terminatingTable.get(t)) {
        out.write("      " + grammarSymbol(0) + " -> " +
                  tableNames.elementAt(t) + "[" + bottom + "] weight 0");
        if (nonterminatingTable.get(t)) out.write(",\n");
      }
      if (nonterminatingTable.get(t))
        out.write("      " + grammarSymbol(0) + " -> " +
                  tableNames.elementAt(t) + "[" + grammarSymbol(0) + "]");
      if (t+1 < tables.size()) out.write(",\n");
      else out.write("\n");
    }
  }
  
  private void generateGrammar(FileWriter out, rexp r) throws IOException {
    int max = generateRules(r, 0);
    for (int i = chainRules.size(); i <= max; i++) chainRules.addElement(new Vector());
    for (int i = regularRules.size(); i <= max; i++) regularRules.addElement(new Vector());
    addRegularRule(max, bottom + " weight 0");
    applyClosure();
    grammarNonterminals.addSymbol("S");
    for (int i = 0; i <= max; i++) grammarNonterminals.addSymbol("S" + i);
    max = removeUnreachable(max);
    out.write("generators.stubbornRegularTreeGrammar:\n");
    out.write("  (\n");
    out.write("    " + grammarNonterminals + ",\n");
    out.write("    " + tableSig + ",\n");
    out.write("    {\n");
    out.write("      S -> " + init + "[" + grammarSymbol(0) + "],\n");
    for (int i = 0; i < regularRules.size(); i++) {
      Vector v = (Vector)regularRules.elementAt(i);
      for (int j = 0; j < v.size(); j++) {
        out.write("      S" + i + " -> " + v.elementAt(j));
        if (i < max || j+1 < v.size()) out.write(",\n"); else out.write("\n");
      }
    }
    out.write("    },\n");
    out.write("    S\n");
    out.write("  )");
    out.close();
  }
  
  boolean loopFlag = false;
      
  private int generateRules(rexp r, int n) {
    switch (r.type) {
      case rexp.CONC:
        boolean oldLoopFlag = loopFlag;
        loopFlag = loopFlag || r.subexp[1].containsLoop();
        int m = generateRules(r.subexp[0], n);
        loopFlag = oldLoopFlag;
        return generateRules(r.subexp[1], m);
      case rexp.UNION:
        oldLoopFlag = loopFlag;
        int m1 = generateRules(r.subexp[0], n+1);
        boolean newLoopFlag = loopFlag;
        loopFlag = oldLoopFlag;
        int m2 = generateRules(r.subexp[1], m1 + 1);
        loopFlag = loopFlag || newLoopFlag;
        addChainRule(n, n+1);
        addChainRule(n, m1+1);
        addChainRule(m1, m2);
        return m2;
      case rexp.PLUS:
        loopFlag = true;
        m = generateRules(r.subexp[0], n);
        addChainRule(m, m + 1);
        addChainRule(m, n);
        return m+1;
      case rexp.STAR:
        loopFlag = true;
      case rexp.OPT:
        m = generateRules(r.subexp[0], n+1);
        addChainRule(n, m + 1);
        addChainRule(m, m + 1);
        addChainRule(n, n + 1);
        if (r.type == rexp.STAR) addChainRule(m, n + 1);
        return m+1;        
      default:
        if (loopFlag) addRegularRule(n, "t" + r.type + "[S" + (n+1) + "]");
        else addRegularRule(n, "t" + r.type + "[S" + (n+1) + "] weight 0");
        return n+1;
    }
  }
  
  private void addChainRule(int lhs, int rhs) {
    for (int i = chainRules.size(); i <= lhs; i++) chainRules.addElement(new Vector());
    ((Vector)chainRules.elementAt(lhs)).addElement(new Integer(rhs));
  }
  
  private void addRegularRule(int lhs, String rhs) {
    for (int i = regularRules.size(); i <= lhs; i++) regularRules.addElement(new Vector());
    ((Vector)regularRules.elementAt(lhs)).addElement(rhs);
  }
  
  private void applyClosure() {
    closure();
    Vector result = new Vector(regularRules.size());
    for (int i = 0; i < chainRules.size(); i++) {
      Vector v = new Vector();
      result.addElement(v);
      Vector cr = (Vector)chainRules.elementAt(i);
      for (int j = 0; j < cr.size(); j++) {
        int k = ((Integer)cr.elementAt(j)).intValue();
        v.addAll((Vector)regularRules.elementAt(k));
      }
    }
    regularRules = result;
  }
  
  private void closure() {
    for (int i = 0; i < chainRules.size(); i++) {
      Vector cr = (Vector)chainRules.elementAt(i);
      cr.addElement(new Integer(i));
    }
    boolean changed;
    do {
      changed = false;
      for (int i = 0; i < chainRules.size(); i++) {
        Vector cr = (Vector)chainRules.elementAt(i);
        for (int j = 0; j < cr.size(); j++) {
          int k = ((Integer)cr.elementAt(j)).intValue();
          Vector cr2 = (Vector)chainRules.elementAt(k);
          for (int l = 0; l < cr2.size(); l++) {
            if (!cr.contains(cr2.elementAt(l))) {
              cr.addElement(cr2.elementAt(l));
              changed = true;
            }
          }
        }
      }
    } while (changed);
  }
  
  private int removeUnreachable(int max) {
    Vector rhsNonterminals = new Vector(max+1);
    for (int i = 0; i <= max; i++) {
      Vector v = (Vector)regularRules.elementAt(i);
      Vector w = new Vector();
      rhsNonterminals.add(w);
      for (int j = 0; j < v.size(); j++) {
        String rhs = (String)v.elementAt(j);
        int k = rhs.indexOf('S');
        if (k >= 0) {
          try {
            w.add(new Integer(rhs.substring(k+1, rhs.indexOf(']'))));
          } catch (NumberFormatException e) { throw new InternalError(); }
        }
      }
    }
    BitSet reachable = new BitSet(max+1);
    reachable.set(0);
    boolean changed = true;
    while (changed) {
      changed = false;
      for (int i = 0; i <= max; i++) {
        if (reachable.get(i)) {
          Vector v = (Vector)rhsNonterminals.elementAt(i);
          for (int j = 0; j < v.size(); j++) {
            int index = ((Integer)v.elementAt(j)).intValue();
            if (!reachable.get(index)) {
              reachable.set(index);
              changed = true;
            }
          }
        }
      }
    }
    for (int i = max; i > 0; i--) {
      if (!reachable.get(i)) {
        if (i == max) max--;
        grammarNonterminals.removeSymbol(new symbol("S" + i, 0));
        regularRules.setElementAt(new Vector(0), i);
      }
    }
    return max;
  }
  
  private void generateTransducer(FileWriter out) throws IOException {
    out.write("generators.tdTransducer:\n");
    out.write("  (\n");
    finiteSignature extendedTableSig = new finiteSignature();
    extendedTableSig.unionWith(tableSig);
    extendedTableSig.unionWith(grammarNonterminals);
    out.write("    " + extendedTableSig + ",\n");
    finiteSignature extendedOutputSig = new finiteSignature();
    extendedOutputSig.unionWith(origNontermSig);
    extendedOutputSig.unionWith(outputSig);
    out.write("    " + extendedOutputSig + ",\n");
    fixedRankSignature extendedNontermSig = new fixedRankSignature(1);
    extendedNontermSig.unionWith(nontermSig);
    extendedNontermSig.addSymbol(startState);
    out.write("    " + extendedNontermSig + ",\n");
    out.write("    {\n");
    tableImplementationRules(out);
    out.write("    },\n");
    out.write("    " + nameParser.unparse(startState.toString()) + "\n");
    out.write("  )");
    out.close();
  }
  
  private void tableImplementationRules(FileWriter out) throws IOException {
    Vector rules = new Vector();
    rules.addElement(nameParser.unparse(startState.toString()) +
                         "[" + init + "[x1]] -> " + termParser.unparse(axiom));
    for (int t=0; t < tables.size(); t++) {
      Vector table = (Vector)tables.elementAt(t);
      for (int i=0; i < table.size(); i++) {
        term left = lhsTerm(table, i);
        term right = rhsTerm(table, i);
        double weight = weight(table, i);
        if (weight == 1) rules.addElement(
          termParser.unparse(left) + " -> " + termParser.unparse(right));
        else rules.addElement(
          termParser.unparse(left) + " -> " + termParser.unparse(right) + " weight " + weight);
      }
    }
    Enumeration enum1 = nontermSig.elements();
    while (enum1.hasMoreElements()) {
      String s1 = enum1.nextElement().toString();
      if (outputSig.contains(new symbol(s1,0))) {
        rules.addElement(nameParser.unparse(s1) + "[" + bottom + "] -> " + nameParser.unparse(s1));
      }
      Enumeration enum2 = grammarNonterminals.elements();
      while (enum2.hasMoreElements()) {
        String s2 = enum2.nextElement().toString();
        rules.addElement(nameParser.unparse(s1) + "[" + s2 + "] -> " + nameParser.unparse(s1));
      }
    }
    for (int i=0; i < rules.size(); i++) {
      out.write("      " + (String)rules.elementAt(i));
      if (i+1 < rules.size()) out.write(", \n");
      else out.write("\n");
    }
  }
  
}

