(* *********************************************************************** *)
(*									   *)
(* Project: CATS  							   *)
(* Author: Bartek Klin, Warsaw/Arhus			                   *)
(* Date: 2000				 			           *)
(* Purpose of this file: Symbols maps and signature morphisms		   *)
(*			 						   *)	
(*									   *)
(* *********************************************************************** *) 	

(* This file provides a bunch of functions acting on symbol maps and
   signature morphisms.
   The background can be found in the semantics (CoFI study note S-9),
   chapter 5, and in study note S-10. Page numbers below refer to S-9,
   version July 1999 (this will have to be adapted to the latest version
   once it appears).
*)
   
structure Morphisms =
struct

(* Raised by morphism_from_symmap when a symbol map pairs symbols of
   incompatible kinds (e.g. a sort symbol with a non-sort symbol). *)
exception INTERNAL_ERROR;
(* Raised by morphism_extension_along_sig_extension; the string explains
   why the morphism cannot be extended. *)
exception NO_MORPHISM_EXTENSION of string;

local open AS Utils Finset LocalEnv Symbols ArchTypes in
infix mem;

(* sub_sig -> symbol list *)
(* val sig_symbols = Stat_symmaps.sig_symbols; *)

(* CaslEnv.local_env list -> CaslEnv.local_env *)
(* Merge a list of local environments from left to right with merge_lenvs.
   The empty list yields LocalEnv.empty_local_env; a singleton list yields
   its only element unchanged. *)
fun merge_lenvlist(lenvs) =
  case lenvs of
    [] => LocalEnv.empty_local_env
  | first :: rest => foldl merge_lenvs (first, rest)
      

(* Image of sort s under a morphism (sort map, op map, pred map).
   A sort without an entry in the sort map is mapped to itself. *)
fun sort_via_morphism (sort_map,_,_) s =
  if_none (Symtab_id.lookup(sort_map,s)) s
;

(* Image of the operation name `on` (profile (oss,os), total iff `total`)
   under a morphism. The op map stores, per name, a list of
   (op type, target name, totality) entries. The result is the target name
   of the LAST entry whose op type equals this symbol's op type. With no
   entry for `on`, or no matching type, the name is unchanged. *)
fun fun_via_morphism (_,op_map,_) (on,(oss,os),total) =
  let
    val own_type =
      if total then total_op_type(sorts(oss),os)
      else partial_op_type(sorts(oss),os)
    fun pick (current,(ot1,on1,_)) =
      if own_type = ot1 then on1 else current
  in
    case Symtab_id.lookup(op_map,on) of
      None => on
    | Some(entries) => foldl pick (on,entries)
  end
;

(* Argument is always partial *)
(* Whether the partial operation `on` with profile (oss,os) becomes total
   under the morphism. This is the case when some op-map entry for `on`
   has the matching partial op type and is flagged total. The symbol being
   translated is always assumed to be partial. *)
fun totality_via_morphism (_,op_map,_) (on,(oss,os)) =
  let val own_type = partial_op_type(sorts(oss),os)
  in
    case Symtab_id.lookup(op_map,on) of
      None => false
    | Some(entries) =>
        exists' (fn (ot1,_,tot) => tot andalso ot1 = own_type) entries
  end
;

(* Image of predicate name pn (argument sorts pt) under a morphism.
   The LAST pred-map entry whose pred type matches determines the result;
   otherwise pn is unchanged. *)
fun pred_via_morphism (_,_,pred_map) (pn,pt) =
  let
    val own_type = pred_type(sorts(pt))
    fun pick (current,(pt1,pn1)) = if own_type = pt1 then pn1 else current
  in
    case Symtab_id.lookup(pred_map,pn) of
      None => pn
    | Some(entries) => foldl pick (pn,entries)
  end
;

(* Translate each sort of a list through the morphism. *)
fun sorts_via_morphism mor sl =
  map (fn s => sort_via_morphism mor s) sl;

(* Translate an operation profile (argument sorts, result sort). *)
fun fprof_via_morphism mor (args,res) =
  let val args' = sorts_via_morphism mor args
      val res' = sort_via_morphism mor res
  in (args',res') end;

(* Image of a signature symbol under a morphism. A partial operation
   becomes a total-operation symbol when the morphism maps it to a total
   operation (see totality_via_morphism). *)
fun symbol_via_morphism(mor,sym) =
  let val map_prof = fprof_via_morphism mor
  in
    case sym of
      SORT_SYMBOL(s) => SORT_SYMBOL(sort_via_morphism mor s)
    | TOTAL_FUN_SYMBOL(on,fprof) =>
        TOTAL_FUN_SYMBOL(fun_via_morphism mor (on,fprof,true), map_prof fprof)
    | PARTIAL_FUN_SYMBOL(on,fprof) =>
        let
          val becomes_total = totality_via_morphism mor (on,fprof)
          val name' = fun_via_morphism mor (on,fprof,false)
          val prof' = map_prof fprof
        in
          if becomes_total then TOTAL_FUN_SYMBOL(name',prof')
          else PARTIAL_FUN_SYMBOL(name',prof')
        end
    | PRED_SYMBOL(pn,pprof) =>
        PRED_SYMBOL(pred_via_morphism mor (pn,pprof),
                    sorts_via_morphism mor pprof)
  end
;

(* symbol list * CaslEnv.morphism -> symbol list *)
(* Translate every symbol of syml through mor, preserving order. *)
fun symbols_via_morphism(syml,mor) =
  let fun image sym = symbol_via_morphism(mor,sym)
  in map image syml end
;


(* sub_sig * CaslEnv_morphism -> symbol_map *)
(* Symbol map of a morphism: each symbol of the source signature paired
   with its image under mor. *)
fun morphism_symmap(srcsig,mor) =
  let fun pair_with_image sym = (sym, symbol_via_morphism(mor,sym))
  in map pair_with_image (sig_symbols srcsig) end
;

(* Build a morphism (sort map, op map, pred map) from a list of
   (source symbol, target symbol) pairs.
   - Sorts go into the sort map with Symtab_id.update_new.
   - Each operation or predicate pair is prepended to the list stored under
     its source name, together with the source type. Ops also carry a
     totality flag: true iff the target is a total operation.
   Raises INTERNAL_ERROR when the two symbols of a pair have incompatible
   kinds (a total operation may not be mapped to a partial one). *)
fun morphism_from_symmap(smap) =
  let
    (* Prepend entry to the list stored under key (empty when absent). *)
    fun cons_entry (tab,key,entry) =
      Symtab_id.update((key, entry :: if_none (Symtab_id.lookup(tab,key)) []),
                       tab)
    fun add ((smapi,fmapi,pmapi),(sym1,sym2)) =
      case (sym1,sym2) of
        (SORT_SYMBOL(s1), SORT_SYMBOL(s2)) =>
          (Symtab_id.update_new((s1,s2),smapi),fmapi,pmapi)
      | (TOTAL_FUN_SYMBOL(on1,(ss,s)), TOTAL_FUN_SYMBOL(on2,_)) =>
          (smapi,
           cons_entry(fmapi,on1,(total_op_type(sorts(ss),s),on2,true)),
           pmapi)
      | (PARTIAL_FUN_SYMBOL(on1,(ss,s)), PARTIAL_FUN_SYMBOL(on2,_)) =>
          (smapi,
           cons_entry(fmapi,on1,(partial_op_type(sorts(ss),s),on2,false)),
           pmapi)
      | (PARTIAL_FUN_SYMBOL(on1,(ss,s)), TOTAL_FUN_SYMBOL(on2,_)) =>
          (smapi,
           cons_entry(fmapi,on1,(partial_op_type(sorts(ss),s),on2,true)),
           pmapi)
      | (PRED_SYMBOL(pn1,ss), PRED_SYMBOL(pn2,_)) =>
          (smapi,fmapi,cons_entry(pmapi,pn1,(pred_type(sorts(ss)),pn2)))
      | _ => raise INTERNAL_ERROR
  in
    foldl add ((Symtab_id.empty,Symtab_id.empty,Symtab_id.empty),smap)
  end
;

(* local_env * morphism * morphism -> morphism *)
(* Composite morphism "mor1 then mor2", built from the symbol map over all
   symbols of srcsig1. *)
fun compose_morphisms(srcsig1,mor1,mor2) =
  let
    fun through_both sym =
      symbol_via_morphism(mor2, symbol_via_morphism(mor1,sym))
    val composed = map (fn sym => (sym, through_both sym)) (sig_symbols srcsig1)
  in
    morphism_from_symmap(composed)
  end
;

(* Identity morphism on sig1: every symbol of sig1 is mapped to itself. *)
fun id_morphism sig1 =
  let fun self_pair sym = (sym,sym)
  in morphism_from_symmap (map self_pair (sig_symbols sig1)) end

(* Identity raw symbol map on sig1: each symbol, as a SYMBOL_RAW_SYMBOL,
   mapped to itself. *)
fun id_rsymmap sig1 =
  let fun raw_self sym = let val r = SYMBOL_RAW_SYMBOL sym in (r,r) end
  in map raw_self (sig_symbols sig1) end
  

(* Function Ext and auxiliary functions *)

(* ID1 * ID list -> ID list *)
(* Components of an identifier that occur in idl. A simple id has none.
   For a compound id, each argument that occurs in idl is taken as is;
   any other argument is searched recursively. The per-argument results
   are concatenated in reverse argument order. *)
fun id_components(simple_id(_),_) = []
  | id_components(compound_id(_,args),idl) =
      let
        fun component_ids arg =
          if exists' (fn id => id = arg) idl then [arg]
          else id_components(arg,idl)
      in
        flat (rev (map component_ids args))
      end
;

(* ID list -> (ID * ID list) list *)
(* Pair each identifier of idl with its components that also occur in idl. *)
fun component_relation(idl) =
  let fun with_components id = (id, id_components(id,idl))
  in map with_components idl end
;

(* (ID * ID) list * ID1 -> ID1 *)
(* Translate an identifier through an association list of id pairs.
   - An id with an entry is replaced by the target of its first entry.
   - An unmapped simple id is unchanged.
   - An unmapped compound id keeps its token and has its arguments
     translated recursively. *)
fun id1_via_idmap(idmap,id) =
  case (find_first (fn (src,_) => src = id) idmap, id) of
    (Some(_,target), _) => target
  | (None, simple_id(_)) => id
  | (None, compound_id(tkn,args)) =>
      compound_id(tkn, map (fn arg => id1_via_idmap(idmap,arg)) args)
;

fun comp_id1_via_idmap(idmap,id) = 
  (case id of
    simple_id(_) => id1_via_idmap(idmap,id)
  | compound_id(tkn,idl) => 
    let val idl1 = map (fn idi => id1_via_idmap(idmap,idi)) idl in
      compound_id(tkn,idl1)
    end)
;

(* (ID * ID) list * ID -> ID *)
(* Same as comp_id1_via_idmap; kept as the ID-level entry point. *)
fun comp_id_via_idmap arg = comp_id1_via_idmap arg
;


(* remove_one(x, l): l with only the FIRST occurrence of x removed; all
   other elements, including later occurrences of x, keep their order.
   Returns l unchanged when x does not occur.
   The recursion must call remove_one itself. Calling `remove` would drop
   every later occurrence of x, not just the first. *)
fun remove_one(_,[]) = []
  | remove_one(x,y::z) = if x=y then z else y :: remove_one(x,z)
;

(* DEBUGGING *)
(* Debug rendering of a component relation: "id-|components; " for each
   entry, followed by a newline. *)
fun print_comprel rel =
  let
    fun entry (id,comps) =
      (BasicPrint.print_ID id)^"-|"^(BasicPrint.print_IDs comps)^"; "
  in
    String.concat (map entry rel) ^ "\n"
  end
;

(* (ID * ID list) list * (ID * ID) list -> (ID * ID) list *)
(* Extend idmap along a component relation. Repeatedly take the first
   entry (id, components) whose components all have an entry in idmap,
   drop it from the relation, and prepend (id, comp_id_via_idmap(idmap,id))
   to idmap. Stops when no such entry remains.
   NOTE(review): an id that already has an explicit entry in idmap still
   gets a derived entry. For a compound id, comp_id_via_idmap translates
   only the arguments, so the derived target may differ from the explicit
   one. Confirm this matches the Ext definition in S-9. *)
fun extID1(comprel,idmap) =
  let
    fun is_mapped id = exists' (fn (src,_) => src = id) idmap
    fun ready (_,comps) = forall' is_mapped comps
  in
    case find_first ready comprel of
      None => idmap
    | Some(entry as (id,_)) =>
        extID1(remove(entry,comprel),
               (id,comp_id_via_idmap(idmap,id))::idmap)
  end
;

(* ID list * (ID * ID) list -> (ID * ID) list *)
(* Ext on identifiers: extend idmap over the component relation of idlist,
   then drop duplicate pairs. *)
fun extID(idlist,idmap) =
  let val comprel = component_relation(idlist)
  in remove_dups(extID1(comprel,idmap)) end
;

(* symbol list * symbol_map -> raw_symbol_map *)
(* p. 93 *)
(* New definition: Subtracting h' from Ext(h) *)
(* Raw symbol map made of two parts:
   - h itself, as SYMBOL_RAW_SYMBOL pairs;
   - the name-level extension of h over the names of srcsyms, minus the
     name pairs already induced by h, as IMPLICIT_RAW_SYMBOL pairs.
   Duplicates are removed. *)
fun ext(srcsyms,h) =
  let
    val explicit_pairs =
      map (fn (s1,s2) => (SYMBOL_RAW_SYMBOL(s1),SYMBOL_RAW_SYMBOL(s2))) h
    val name_map = map (fn (s1,s2) => (symbol_name(s1),symbol_name(s2))) h
    val extended = extID(map symbol_name srcsyms, name_map)
    val derived_only = filter (fn pair => not (pair mem name_map)) extended
    val implicit_pairs =
      map (fn (id1,id2) => (IMPLICIT_RAW_SYMBOL(id1),IMPLICIT_RAW_SYMBOL(id2)))
          derived_only
  in
    remove_dups(explicit_pairs @ implicit_pairs)
  end
;

(* LocalEnv.local_env * LocalEnv.morphism -> (symbol*symbol) list *)
(* First argument can be either the source signature of the morphism or its
   subsignature *)
(* Kernel of mor restricted to the symbols of srcsig: every ordered pair
   (a,b) of distinct symbols with the same image under mor. Each pair is
   recorded once. *)
fun morphism_kernel(srcsig,mor) =
  let
    val syms = sig_symbols(srcsig)
    fun image sym = symbol_via_morphism(mor,sym)
    fun add_pair (acc,(a,b)) =
      if a <> b andalso image a = image b andalso not ((a,b) mem acc)
      then (a,b)::acc else acc
    fun add_row (acc,a) = foldl (fn (acc',b) => add_pair(acc',(a,b))) (acc,syms)
  in
    foldl add_row ([],syms)
  end
;


local open BasicPrint in
(* Render a signature symbol for diagnostics:
     sort               name
     total operation    name:args->result
     partial operation  name:args->?result   (CASL notation)
     predicate          name:args
   Partial operations are printed with "->?" so they cannot be confused
   with total operations of the same profile. *)
fun print_symbol (SORT_SYMBOL(s)) = print_ID s
  | print_symbol (TOTAL_FUN_SYMBOL(n,(ss,s))) =
      (print_ID n)^":"^(print_SORTS1 ss)^"->"^(print_ID s)
  | print_symbol (PARTIAL_FUN_SYMBOL(n,(ss,s))) =
      (print_ID n)^":"^(print_SORTS1 ss)^"->?"^(print_ID s)
  | print_symbol (PRED_SYMBOL(n,ss)) =
      (print_ID n)^":"^(print_SORTS1 ss)
;



(* Render a raw symbol. Any raw-symbol kind other than SYMBOL_RAW_SYMBOL
   and IMPLICIT_RAW_SYMBOL is shown as a fixed placeholder (apparently
   Polish for "other raw symbol"). *)
fun print_raw_symbol (SYMBOL_RAW_SYMBOL(sym)) = print_symbol sym
  | print_raw_symbol (IMPLICIT_RAW_SYMBOL(id1)) = print_ID id1
  | print_raw_symbol _ = " innyrawsym "

end;


(* Debug rendering of a symbol list: each symbol followed by a space,
   then a newline. *)
fun print_symbols syms =
  String.concat (map (fn sym => print_symbol sym ^ " ") syms) ^ "\n"
;


(* Debug rendering of one symbol pair as "(s1 s2)". *)
fun print_core1 (s1,s2) =
  String.concat ["(", print_symbol s1, " ", print_symbol s2, ")"];

(* Debug rendering of a list of symbol pairs: "(a b),(c d),...++\n". *)
fun print_core pairs =
  String.concat (map (fn c => print_core1 c ^ ",") pairs) ^ "++\n"
;


(* Debug rendering of one raw symbol pair as "(r1 r2)". *)
fun print_rawcore1 (s1,s2) =
  String.concat ["(", print_raw_symbol s1, " ", print_raw_symbol s2, ")"];

(* Debug rendering of a raw symbol map: "(a b),(c d),...++\n". *)
fun print_rawcore pairs =
  String.concat (map (fn c => print_rawcore1 c ^ ",") pairs) ^ "++\n"
;

(* Transitive closure of overloading relations, using Till's algorithm and
   sorting by similarity of symbols. *)
(* Several functions. *)

(* Two operations are similar when they share the name and the number of
   argument sorts; result sorts are ignored. *)
fun ops_similar((name1,(args1,_)),(name2,(args2,_))) =
  if name1 <> name2 then false
  else length args1 = length args2
;

(* Two predicates are similar when they share the name and the arity. *)
fun preds_similar((name1,args1),(name2,args2)) =
  if name1 <> name2 then false
  else length args1 = length args2
;

(* symbol * symbol -> bool *)
(* Similarity of symbols, used to split overloading relations into groups.
   Operations (total or partial, in any combination) are compared with
   ops_similar. Predicates are compared with preds_similar. Sorts and
   mixed kinds are never similar. *)
fun syms_similar(sym1,sym2) =
  let
    fun op_part (TOTAL_FUN_SYMBOL(fsym)) = Some fsym
      | op_part (PARTIAL_FUN_SYMBOL(fsym)) = Some fsym
      | op_part _ = None
  in
    case (op_part sym1, op_part sym2, sym1, sym2) of
      (Some fsym1, Some fsym2, _, _) => ops_similar(fsym1,fsym2)
    | (_, _, PRED_SYMBOL(psym1), PRED_SYMBOL(psym2)) =>
        preds_similar(psym1,psym2)
    | _ => false
  end
;

(* (symbol * 'a) * (symbol * 'a) list list -> (symbol 'a)  list list *)
(* Insert (sym, payload) into a list of groups. It is prepended to the first
   group whose head symbol is similar to sym, or placed into the first group
   if that group is empty. Otherwise a new singleton group is appended at
   the end. *)
fun add_symbol_to_similar(entry as (sym1,_), groups) =
  case groups of
    [] => [[entry]]
  | [] :: rest => [entry] :: rest
  | (group as ((sym2,_)::_)) :: rest =>
      if syms_similar(sym1,sym2) then (entry::group) :: rest
      else group :: add_symbol_to_similar(entry,rest)
;

(* (symbol * symbol list) list -> (symbol * symbol list) list list *)
(* Split a relation into groups of entries with similar symbols. *)
fun segregate_relation (syml) =
  let fun insert (groups,entry) = add_symbol_to_similar(entry,groups)
  in foldl insert ([],syml) end
;

(* A bit faster transitive closure *)
(* (symbol * symbol list) list -> (symbol * symbol list) list *)
(* Transitive closure computed separately on each group of similar symbols.
   Only similar symbols can be overloaded, so the groups are independent.
   The closed groups are then concatenated. *)
fun faster_transitive_closure(ovrrel) =
  let
    fun same_symbol (a:symbol, b:symbol) = a = b
    val groups = segregate_relation(ovrrel)
  in
    flat (map (transitive_closure same_symbol) groups)
  end
;

(* Transitive closure of overloading relation *)
(* p. 82 *)
(* local_env -> (symbol * symbol list) list *)
(* Transitive closure of the overloading relation of lenv. *)
fun trans_overloading_relation lenv =
  let val rel = overloading_relation lenv
  in faster_transitive_closure rel end
;

(* Sum of two relations *)
(* (symbol * symbol list) list * (symbol * symbol list) list -> 
         (symbol * symbol list) list *)
(* Union of two relations given as (symbol, related symbols) lists. For
   each symbol in either domain, the related symbols from both relations
   are combined without duplicates. *)
fun sum_overloading_relations(rel1,rel2) =
  let
    fun related_in rel sym =
      case find_first (fn (s,_) => s = sym) rel of
        Some(_,partners) => partners
      | None => []
    val domain = remove_dups((map fst rel1) @ (map fst rel2))
  in
    map (fn sym =>
           (sym, remove_dups(related_in rel1 sym @ related_in rel2 sym)))
        domain
  end
;

(* Checking subrelations *)
(* (symbol * symbol list) list * (symbol * symbol list) list -> bool *)
(* rel1 is contained in rel2: for each (sym, partners) of rel1, partners is
   a subset of what rel2 relates sym to. A symbol missing from rel2 counts
   as related to nothing. *)
fun is_subrelation(rel1,rel2) =
  let
    fun partners_in_rel2 sym =
      case find_first (fn (s,_) => s = sym) rel2 of
        Some(_,partners) => partners
      | None => []
  in
    forall' (fn (sym,partners) =>
               Finset.is_subset(partners, partners_in_rel2 sym))
            rel1
  end
;

(* CaslEnv.local_env * CaslEnv.local_env -> string option *)
(* p. 90, proposition 32 *)
(* Check (p. 90, proposition 32) that the union of sig1 and sig2 is final.
   The overloading relation of the merged signature must be contained in
   the transitive closure of the union of the two separate overloading
   relations.
   Returns None when it is. Otherwise returns Some msg, where msg lists
   the symbols involved in overloadings that appear only in the merged
   signature. *)
fun union_is_final(sig1,sig2) =
  let
    val rel1 = overloading_relation(sig1)
    val rel2 = overloading_relation(sig2)
    val sig12 = merge_lenvs(sig1,sig2)
    val relUtrans = faster_transitive_closure
                      (sum_overloading_relations(rel1,rel2))
    val rel12 = overloading_relation(sig12)
    (* Partners of sym in the closed union relation. *)
    fun union_partners sym =
      case find_first (fn (s,_) => s = sym) relUtrans of
        Some(_,partners) => partners
      | None => []
    (* Symbols of rel12 pairs that are missing from relUtrans: the left
       symbol together with its newly related partners. *)
    fun offending_symbols () =
      remove_dups (flat (map (fn (sym1,ovrsym1) =>
          let
            val fresh =
              filter (fn sym2 => not (sym2 mem union_partners sym1)) ovrsym1
          in
            if null fresh then [] else sym1 :: fresh
          end) rel12))
  in
    if is_subrelation(rel12,relUtrans)
    then None
    else Some (
          "The following symbols newly get into the transitive closure of the overloading relation:\n"
          ^ StructuredPrint.print_symbol_set (offending_symbols ())
          ^ "\n  Advice: rename these symbols in the actual parameter.\n"
          )
  end
;

(* (CaslEnv.local_env * CaslEnv.morphism)^2 -> bool *)
(* p. 92 *)
(* Two morphisms are compatible (p. 92) when their combined symbol maps,
   after removing identical pairs, still map each source symbol to a single
   target. *)
fun compatible_2_morphisms((srcsig1,mor1),(srcsig2,mor2)) =
  let
    val combined = remove_dups(morphism_symmap(srcsig1,mor1) @
                               morphism_symmap(srcsig2,mor2))
    val sources = map fst combined
  in
    length sources = length (remove_dups sources)
  end
;
(* (CaslEnv.local_env * CaslEnv.morphisms) list -> bool *)
(* True iff the given (source signature, morphism) pairs are pairwise
   compatible, as decided by compatible_2_morphisms.
   Every pair is checked, not just neighbours in the list. Compatibility is
   not transitive: m1 and m3 may map a shared symbol differently while m2
   does not mention that symbol at all. *)
fun compatible_morphisms([]) = true
  | compatible_morphisms(h::t) =
      forall' (fn other => compatible_2_morphisms(h,other)) t
      andalso compatible_morphisms(t)
;

(* CaslEnv.local_env * CaslEnv.morphism * CaslEnv.local_env -> bool *)
(* p. 62 *)
(* mor preserves overloading (p. 62): the image under mor of the
   overloading relation of lenv1 is contained in that of lenv2. *)
fun preserves_overloading(lenv1,mor,lenv2) =
  let
    fun image sym = symbol_via_morphism(mor,sym)
    val translated =
      map (fn (sym,partners) => (image sym, map image partners))
          (overloading_relation lenv1)
  in
    is_subrelation(translated, overloading_relation lenv2)
  end
;

(* (CaslEnv.local_env * CaslEnv.morphism)^2 -> CaslEnv.morphism *)
(* p. 92 *)
(* Assumes that the sum of morphisms actually exists *)
(* Sum of two morphisms (p. 92), built from the union of their symbol
   maps. The sum is assumed to exist (see compatible_2_morphisms). *)
fun merge_2_morphisms((srcsig1,mor1),(srcsig2,mor2)) =
  let val combined = morphism_symmap(srcsig1,mor1) @ morphism_symmap(srcsig2,mor2)
  in morphism_from_symmap(remove_dups combined) end
;
(* (CaslEnv.local_env * CaslEnv.morphism) list -> CaslEnv.morphism *)
(* Sum of a list of (source signature, morphism) pairs. Starting from the
   empty signature with the empty morphism, each pair is merged into the
   accumulated (signature, morphism). *)
fun merge_morphisms(morl) =
  let
    val start = (LocalEnv.empty_local_env,
                 (Symtab_id.empty,Symtab_id.empty,Symtab_id.empty))
    fun step ((acc_sig,acc_mor),(sig_i,mor_i)) =
      (merge_lenvs(acc_sig,sig_i),
       merge_2_morphisms((acc_sig,acc_mor),(sig_i,mor_i)))
    val (_,merged) = foldl step (start,morl)
  in
    merged
  end
;

(* comp_sub_sigs -> bool *)
(* p. 7,133 *)
(* Signatures are compatible (p. 7, 133) when, after merging, no operation
   name has both a total and a partial type with the same profile. *)
fun compatible_sigs(ssl:comp_sub_sigs) =
  let
    val (_,_,fenv,_) = merge_lenvlist(ssl)
    fun clash (total_op_type(sl1,s1), partial_op_type(sl2,s2)) =
          s1=s2 andalso sl1=sl2
      | clash _ = false
    fun consistent (_,otl) =
      forall' (fn ot1 => forall' (fn ot2 => not (clash(ot1,ot2))) otl) otl
  in
    forall' consistent (Symtab_id.dest fenv)
  end
;

(* (CaslEnv.local_env)^2 * CaslEnv.morphism * CaslEnv.local_env ->
     CaslEnv.morphism * CaslEnv.local_env *)
(* p. 93-94 *)
(* Extend mor (from sigma to sigA) to sig1 via the pushout construction
   (p. 93-94).
   - The raw map r is Ext of mor's symbol map, restricted to pairs whose
     source matches a symbol of sig1.
   - Stat_symmaps.induced_from_morphism(r,sig1,[]) gives sigAD and the
     extended morphism rsig1.
   Returns (rsig1, merge of sigA and sigAD).
   Raises NO_MORPHISM_EXTENSION when:
   - r induces no morphism;
   - sigA and sigAD are incompatible;
   - their union is not final;
   - a symbol shared by sigA and sigAD is not an image of a sigma symbol;
   - rsig1 identifies symbols of sig1 beyond the kernel of mor. *)
fun morphism_extension_along_sig_extension(sigma,sig1,mor,sigA) =
  let
    val sig1_syms = sig_symbols(sig1)
    fun reject msg = raise NO_MORPHISM_EXTENSION msg
    fun keep_pair (acc,(rsym1,rsym2)) =
      if exists' (fn sym1 => matches(sym1,rsym1)) sig1_syms
      then (rsym1,rsym2)::acc else acc
    val r = foldl keep_pair
                  ([], ext(sig_symbols(sig1),morphism_symmap(sigma,mor)))
    val (sigAD,_,rsig1) =
      Stat_symmaps.induced_from_morphism(r,sig1,[])
      handle (Symbols.STAT_EXCEPTION s) =>
        reject ("Error in construction of pushout morphism:\n"^s)
  in
    if not (compatible_sigs([sigA,sigAD])) then
      reject "Signatures are not compatible"
    else
      (case union_is_final(sigA,sigAD) of
         Some s => reject ("Union is not final\n"^s)
       | None =>
           let
             val shared = intersect(sig_symbols(sigA),sig_symbols(sigAD))
             val formal_images = symbols_via_morphism(sig_symbols(sigma),mor)
           in
             if not (is_subset(shared,formal_images)) then
               reject
                 ("Symbols shared between actual parameter and body must be in formal parameter\n"
                  ^StructuredPrint.print_symbol_set
                     (remove_set(formal_images,shared)))
             else
               let
                 val body_kernel = morphism_kernel(sig1,rsig1)
                 val fitting_kernel = morphism_kernel(sigma,mor)
               in
                 if not (is_subset(body_kernel,fitting_kernel)) then
                   reject
                     ("Fitting morphism leads to forbidden identifications\n"
                      ^StructuredPrint.print_symbol_pair_set
                         (remove_set(fitting_kernel,body_kernel)))
                 else
                   (rsig1,merge_lenvs(sigA,sigAD))
               end
           end)
  end
;

end
end
