(*  Title:      Pure/Syntax/ast.ML
    ID:         $Id: ast.ML,v 1.1 2004/02/13 11:29:19 7till Exp $
    Author:     Markus Wenzel, TU Muenchen
    Modified:   Burkhart Wolff, ALU Freiburg

Abstract syntax trees, translation rules, matching and normalization of asts.
*)

signature AST0 =
  sig
  (*lexical class tag carried by Constant and Variable nodes*)
  type lex_class = string

  datatype ast =
      Constant of lex_class * string
    | Variable of lex_class * string
    | Appl of ast list
    | Anno of int * ast list

  (*system errors involving asts*)
  exception AST of string * ast list

  (*annotation utilities*)
  val gen_anno: int * int -> string -> ast list -> ast
  val strip_anno: ast -> ast
  val strip_annos: ast -> ast
  val strip_anno_spine: ast -> ast * ast list
  val map_anno: (ast -> ast) -> ast -> ast
  end;

signature AST1 =
  sig
  include AST0

  (*construction*)
  val mk_appl: ast -> ast list -> ast

  (*printing*)
  val str_of_ast: ast -> string
  val pretty_ast: ast -> Pretty.T
  val pretty_rule: ast * ast -> Pretty.T
  val pprint_ast: ast -> Pretty.pprint_args -> unit

  (*tracing / statistics switches for normalization*)
  val trace_norm_ast: bool ref
  val stat_norm_ast: bool ref
  end;

signature AST =
  sig
  include AST1

  (*translation rules*)
  val head_of_rule: ast * ast -> string
  val rule_error: ast * ast -> string Library.option

  (*folding and unfolding of chains built from a binary constant*)
  val fold_ast: string -> ast list -> ast
  val fold_ast_p: string -> ast list * ast -> ast
  val unfold_ast: string -> ast -> ast list
  val unfold_ast_p: string -> ast -> ast list * ast

  (*rule-based normalization*)
  val normalize: bool -> bool -> (string -> (ast * ast) list) Library.option -> ast -> ast
  val normalize_ast: (string -> (ast * ast) list) Library.option -> ast -> ast
  end;

structure Ast : AST =
struct

open Library Term

(** abstract syntax trees **)

(*asts come in two flavours:
   - ordinary asts representing terms and typs: Variables are (often) treated
     like Constants;
   - patterns used as lhs and rhs in rules: Variables are placeholders for
     proper asts*)

type lex_class = string

(*Constant and Variable carry a lexical class tag (printed as "+class+").
  In Anno (k, asts) the element at 0-based index k of asts is the annotated
  ast itself (see strip_anno); the other elements are annotation data.*)
datatype ast =
  Constant of lex_class * string |    (*"not", "_abs", "fun"*)
  Variable of lex_class * string |    (*x, ?x, 'a, ?'a*)
  Appl of ast list               |    (*(f x y z),("fun" 'a 'b),("_abs" x t)*)
  Anno of int * ast list              (*(_selection 2 [path,term])*)


(*the list of subasts of an Appl node has to contain at least 2 elements, i.e.
  there are no empty asts or nullary applications; use mk_appl for convenience*)

(*mk_appl f args: f applied to args, or f itself when args is empty*)
fun mk_appl f [] = f
  | mk_appl f args = Appl (f :: args);


(** basic convenience functions for annotations **)

(*gen_anno (k, no) a args: build an annotation node tagged with
  Constant ("_anno", a).  If there are at most no args, all of them go into
  the Anno node; otherwise the first k args go into it and the remaining
  args are applied to it.
  NOTE(review): the size test uses no while the split uses k -- confirm this
  is intended.  If drop (k, args) is empty, the result is a one-element Appl,
  which breaks the invariant stated above mk_appl.*)
fun gen_anno (k, no) a args =
  if length args <= no
  then Anno (k, Constant ("_anno", a) :: args)
  else Appl (Anno (k, Constant ("_anno", a) :: take (k, args)) :: drop (k, args));

(*strip the outermost annotation: yields the element at index k of an Anno
  node; any other ast is returned unchanged.
  Raises Subscript if k is out of range (List.nth).*)
fun strip_anno (Anno (k, asts)) = List.nth (asts, k)
  | strip_anno ast = ast;

(*map_anno f ast: apply f to the annotated component of an Anno node, keeping
  the annotation data around it.  As in strip_anno, a non-annotated ast is its
  own component, so f is applied to it directly (previously: Match).
  Raises Subscript if the Anno index is out of range (previously: Bind).*)
fun map_anno f (Anno (k, asts)) =
      (case drop (k, asts) of
        a :: rest => Anno (k, take (k, asts) @ (f a :: rest))
      | [] => raise Subscript)
  | map_anno f ast = f ast;

(*strip all annotations in ast, at every depth.  An annotation whose payload
  is an application in head position is flattened into the enclosing
  application.*)
fun strip_annos (a as Anno _) = strip_annos (strip_anno a)
    (*recurse: the payload may itself contain (or be) annotations*)
  | strip_annos (Appl ((a as Anno _) :: rest)) =
      (case strip_annos a of
        Appl s => Appl (s @ map strip_annos rest)
      | t => Appl (t :: map strip_annos rest))
  | strip_annos (Appl asts) = Appl (map strip_annos asts)
  | strip_annos t = t;

(*strip all annotations in the left spine of ast and split it into
  (head, args); annotations inside args are left in place.  A non-application
  yields (ast, []).*)
fun strip_anno_spine (a as Anno _) = strip_anno_spine (strip_anno a)
  | strip_anno_spine (Appl ((a as Anno _) :: rest)) =
      (case strip_anno a of
        Appl s => strip_anno_spine (Appl (s @ rest))
      | t => strip_anno_spine (Appl (t :: rest)))
  | strip_anno_spine (Appl (a :: rest)) = (a, rest)
  | strip_anno_spine a = (a, []);


(*exception for system errors involving asts*)

exception AST of string * ast list;



(** print asts in a LISP-like style **)

(* str_of_ast *)

(*constants print as +class+"name", variables as +class+name; an Anno node
  prints as <ANNO data-before-k>component</ANNO data-after-k>.
  The Anno case raises on an out-of-range index (hd of an empty list).*)
fun str_of_ast (Constant (l, a)) = "+" ^ l ^ "+" ^ quote a
  | str_of_ast (Variable (l, x)) = "+" ^ l ^ "+" ^ x
  | str_of_ast (Appl asts) = "(" ^ space_implode " " (map str_of_ast asts) ^ ")"
  | str_of_ast (Anno (k, asts)) =
      let
        val pre = take (k, asts);
        val post = drop (k, asts);
      in
        "<ANNO " ^ space_implode " " (map str_of_ast pre) ^ ">" ^
        str_of_ast (hd post) ^
        "</ANNO " ^ space_implode " " (map str_of_ast (tl post)) ^ ">"
      end;


(* pretty_ast *)

(*same leaf notation as str_of_ast: only the name of a constant is quoted,
  not its lexical class tag.  An Anno node prints as <k elem ... elem>.*)
fun pretty_ast (Constant (l, a)) = Pretty.str ("+" ^ l ^ "+" ^ quote a)
  | pretty_ast (Variable (l, x)) = Pretty.str ("+" ^ l ^ "+" ^ x)
  | pretty_ast (Appl asts) =
      Pretty.enclose "(" ")" (Pretty.breaks (map pretty_ast asts))
  | pretty_ast (Anno (k, asts)) =
      (*the trailing space separates the index from the first element*)
      Pretty.enclose ("<" ^ Int.toString k ^ " ") ">"
        (Pretty.breaks (map pretty_ast asts));


(* pprint_ast *)

val pprint_ast = Pretty.pprint o pretty_ast;


(* pretty_rule *)

fun pretty_rule (lhs, rhs) =
  Pretty.block [pretty_ast lhs, Pretty.str "  ->", Pretty.brk 2, pretty_ast rhs];


(* head_of_ast, head_of_rule *)

(*the head constant name of an ast, if any.  For an Anno node built by
  gen_anno this is the annotation name a of Constant ("_anno", a).*)
fun head_of_ast (Constant (_, a)) = Some a
  | head_of_ast (Appl (Constant (_, a) :: _)) = Some a
  | head_of_ast (Anno (_, Constant (_, a) :: _)) = Some a
  | head_of_ast _ = None;

(*head constant of a rule's lhs; fails (via the) if there is none, which
  rule_error reports beforehand*)
fun head_of_rule (lhs, _) = the (head_of_ast lhs);



(** check translation rules **)

(*a wellformed rule (lhs, rhs): (ast * ast) obeys the following conditions:
   - the head of lhs is a constant,
   - the lhs has unique vars,
   - vars of rhs is subset of vars of lhs*)

fun rule_error (lhs, rhs) =
  let
    fun vars_of (Constant _) = []
      | vars_of (Variable (_, x)) = [x]
      | vars_of (Appl asts) = flat (map vars_of asts)
      | vars_of (Anno (_, asts)) = flat (map vars_of asts);

    fun unique (x :: xs) = not (x mem xs) andalso unique xs
      | unique [] = true;

    val lvars = vars_of lhs;
    val rvars = vars_of rhs;
  in
    if is_none (head_of_ast lhs) then Some "lhs has no constant head"
    else if not (unique lvars) then Some "duplicate vars in lhs"
    else if not (rvars subset lvars) then Some "rhs contains extra variables"
    else None
  end;



(** ast translation utilities **)

(* fold asts *)

(*fold_ast c [x1, ..., xn] = (c x1 (c x2 ... xn)); raises Match on []*)
fun fold_ast _ [] = raise Match
  | fold_ast _ [y] = y
  | fold_ast c (x :: xs) = Appl [Constant ("", c), x, fold_ast c xs];

(*fold_ast_p c ([x1, ..., xn], y) = (c x1 (c x2 ... (c xn y)))*)
fun fold_ast_p c = foldr (fn (x, xs) => Appl [Constant ("", c), x, xs]);


(* unfold asts *)

(*inverse of fold_ast: split a right-nested chain of binary c-nodes.
  Both Appl and Anno nodes of shape [Constant (_, c), x, xs] are split.*)
fun unfold_ast c (y as Appl [Constant (_, c'), x, xs]) =
      if c = c' then x :: (unfold_ast c xs) else [y]
  | unfold_ast c (y as Anno (_, [Constant (_, c'), x, xs])) =
      if c = c' then x :: (unfold_ast c xs) else [y]
  | unfold_ast _ y = [y];

(*inverse of fold_ast_p: returns the chain elements and the final tail*)
fun unfold_ast_p c (y as Appl [Constant (_, c'), x, xs]) =
      if c = c' then apfst (cons x) (unfold_ast_p c xs)
      else ([], y)
  | unfold_ast_p c (y as Anno (_, [Constant (_, c'), x, xs])) =
      if c = c' then apfst (cons x) (unfold_ast_p c xs)
      else ([], y)
  | unfold_ast_p _ y = ([], y);


(** normalization of asts **)

(* tracing options *)

val trace_norm_ast = ref false;
val stat_norm_ast = ref false;


(* simple env *)

(*association-list environment mapping pattern variable names to asts*)
structure Env =
struct
  val empty = [];
  val add = op ::;
  fun get (alist, x) = the (assoc (alist, x));
end;


(* match *)

(*match ast pat: Some (env, args) if pat matches ast, where env binds the
  pattern variables and args are surplus trailing arguments of ast beyond the
  length of an Appl pattern; None otherwise.  Annotations in ast (and,
  defensively, in pat) are stripped before structural comparison, but a
  pattern variable binds the ast including its annotations.*)
fun match ast pat =
  let
    exception NO_MATCH;

    (*rebuild the annotation-free left spine as a single ast*)
    val strip_anno_spine = (uncurry mk_appl) o strip_anno_spine

    (*clause order matters: constants/variables first, then pattern
      variables, then annotation stripping, then structural recursion*)
    fun mtch (Constant (_, a)) (Constant (_, b)) env =
          if a = b then env else raise NO_MATCH
      | mtch (Variable (_, a)) (Constant (_, b)) env =
          if a = b then env else raise NO_MATCH
      | mtch ast (Variable (_, x)) env = Env.add ((x, ast), env)
      | mtch (a as Appl ((Anno _) :: _)) b env = mtch (strip_annos a) b env
      | mtch a (b as Appl ((Anno _) :: _)) env = mtch a (strip_annos b) env
                                        (* should not occur, but anyway ... *)
      | mtch (Appl asts) (Appl pats) env = mtch_lst asts pats env
      | mtch (a as Anno _) pat env = mtch (strip_anno_spine a) pat env
      | mtch ast (a as Anno _) env = mtch ast (strip_anno_spine a) env
                                        (* should not occur, but anyway ... *)
      | mtch _ _ _ = raise NO_MATCH
    and mtch_lst (ast :: asts) (pat :: pats) env =
          mtch_lst asts pats (mtch ast pat env)
      | mtch_lst [] [] env = env
      | mtch_lst _ _ _ = raise NO_MATCH;

    (*an application longer than the pattern matches on its prefix; the
      remaining arguments are returned to the caller*)
    val (head, args) =
      (case (ast, pat) of
        (Appl asts, Appl pats) =>
          let val a = length asts and p = length pats in
            if a > p then (Appl (take (p, asts)), drop (p, asts))
            else (ast, [])
          end
      | _ => (ast, []));
  in
    Some (mtch head pat Env.empty, args) handle NO_MATCH => None
  end;


(* normalize *)

(*the normalizer works yoyo-like: top-down, bottom-up, top-down, ...

  normalize trace stat get_rules pre_ast rewrites pre_ast with the rules
  returned by get_rules (keyed on head names) until no rule applies.
  pre_ast is returned unchanged when get_rules is None.  trace prints every
  rewrite step; trace or stat prints pass/lookup/change statistics.*)
fun normalize trace stat get_rules pre_ast =
  let
    val passes = ref 0;
    val lookups = ref 0;
    val failed_matches = ref 0;
    val changes = ref 0;
    val strip_anno_spine = (uncurry mk_appl) o strip_anno_spine

    (*instantiate a rule rhs; Anno nodes are allowed by rule_error, so they
      must be traversed too (previously raised Match)*)
    fun subst _ (ast as Constant _) = ast
      | subst env (Variable (_, x)) = Env.get (env, x)
      | subst env (Appl asts) = Appl (map (subst env) asts)
      | subst env (Anno (k, asts)) = Anno (k, map (subst env) asts);

    fun try_rules ast ((lhs, rhs) :: pats) =
          (case match ast lhs of
            Some (env, args) =>
              (inc changes; Some (mk_appl (subst env rhs) args))
          | None => (inc failed_matches; try_rules ast pats))
      | try_rules _ [] = None;

    fun try ast a = (inc lookups; try_rules ast (the get_rules a));

    fun rewrite (ast as Constant (_, a)) = try ast a
                                        (* future: ast with lex_class *)
      | rewrite (ast as Variable (_, a)) = try ast a
      | rewrite (ast as Appl ((Anno _) :: _)) = rewrite (strip_anno_spine ast)
      | rewrite (ast as Appl (Constant (_, a) :: _)) = try ast a
      | rewrite (ast as Appl (Variable (_, a) :: _)) = try ast a
      | rewrite _ = None;
          (*Anno nodes, and applications whose head is neither a Constant
            nor a Variable, are not rewritten at the root*)

    fun rewrote old_ast new_ast =
      if trace then
        writeln ("rewrote: " ^ str_of_ast old_ast ^ "  ->  " ^ str_of_ast new_ast)
      else ();

    (*rewrite at the root until no rule applies there*)
    fun norm_root ast =
      (case rewrite ast of
        Some new_ast => (rewrote ast new_ast; norm_root new_ast)
      | None => ast);

    (*normalize the root, then the subterms; retry the root if any subterm
      changed*)
    fun norm ast =
      (case norm_root ast of
        Appl sub_asts =>
          let
            val old_changes = ! changes;
            val new_ast = Appl (map norm sub_asts);
          in
            if old_changes = ! changes then new_ast else norm_root new_ast
          end
      | Anno (k, sub_asts) =>
          let
            val old_changes = ! changes;
            val new_ast = Anno (k, map norm sub_asts);
          in
            if old_changes = ! changes then new_ast else norm_root new_ast
          end
      | atomic_ast => atomic_ast);

    (*repeat whole passes until a pass makes no change*)
    fun normal ast =
      let
        val old_changes = ! changes;
        val new_ast = norm ast;
      in
        inc passes;
        if old_changes = ! changes then new_ast else normal new_ast
      end;


    val _ = if trace then writeln ("pre: " ^ str_of_ast pre_ast) else ();

    val post_ast = if is_some get_rules then normal pre_ast else pre_ast;
  in
    if trace orelse stat then
      writeln ("post: " ^ str_of_ast post_ast ^ "\nnormalize: " ^
        string_of_int (! passes) ^ " passes, " ^
        string_of_int (! lookups) ^ " lookups, " ^
        string_of_int (! changes) ^ " changes, " ^
        string_of_int (! failed_matches) ^ " matches failed")
    else ();
    post_ast
  end;


(* normalize_ast *)

(*normalize with tracing/statistics controlled by the global flags, read at
  call time*)
fun normalize_ast get_rules ast =
  normalize (! trace_norm_ast) (! stat_norm_ast) get_rules ast;

end;

