(* *********************************************************************** *)
(*									   *)
(* Project: CATS 							   *)
(* Author: Till Mossakowski, University of Bremen			   *)
(* Date: 1998				 			           *)
(* Purpose of this file: subsort computations (overloading, partiality)   *)
(*			 						   *)	
(*									   *)
(* *********************************************************************** *)

(* This module computes the ~_F and ~_P overloading relations on operation
   and predicate profiles, following the definitions in the CASL summary.

   It also provides a function that computes the set of sorts that have a
   term containing partial operation symbols; optionally, every proper
   subsort is counted as well (to account for projections).
*)

structure Subsorts : sig

(* Subsort information, in one of two representations:
   - its_an_env: a symbol table mapping a sort to the list of its subsorts
     (queried with Symtab_id.lookup_multi);
   - its_a_list: an association list of (sort, subsorts) pairs. *)
datatype Subsort_env_or_list = its_an_env of LocalEnv.Subsort_env
                             | its_a_list of LocalEnv.Subsort_list

(* Payload of the WRONG exception: the item that caused the failure.
   Within this structure only its_an_overload_F is used. *)
datatype wrong =  its_an_atom of AS.ATOM
                | its_a_term of AS.TERM
                | its_a_formula of AS.FORMULA
                | its_an_overload_F of (Subsort_env_or_list * AS.OP_TYPE * AS.OP_TYPE)
                | allright


(* Internal failure: an integer code plus the offending item.
   overload_F raises it with code 5. *)
exception WRONG of int * wrong;

(* Element of the result of overload_F: (w, w1, w2, s, s1, s2), where
   w1 -> s1 and w2 -> s2 are the two operation profiles being compared,
   w is a tuple of common lower bounds of w1 and w2 (position-wise), and
   s is a common upper bound of s1 and s2. *)
type ARG_RES_TRIPLE = AS.SORT list * AS.SORT list * AS.SORT list * AS.SORT * AS.SORT * AS.SORT

(* Element of the result of overload_P: (w, w1, w2), where w1 and w2 are
   the argument sorts of the two predicate profiles and w is a tuple of
   their common lower bounds (position-wise). *)
type ARG_TRIPLE = AS.SORT list * AS.SORT list * AS.SORT list

(*
val not_disjoint = fn : ''a list * ''a list -> bool

val get_minimum = fn : Subsort_env_or_list * SORT * SORT list -> SORT list
val get_minima1 = fn
  : Subsort_env_or_list -> SORT list -> SORT list -> SORT list
val get_minima = fn : Subsort_env_or_list -> SORT list -> SORT list
val get_maximum = fn
  : Subsort_env_or_list * Symtab_id.key * SORT list -> Symtab_id.key list
val get_maxima1 = fn
  : Subsort_env_or_list
    -> Symtab_id.key list -> Symtab_id.key list -> SORT list
val get_maxima = fn : Subsort_env_or_list -> SORT list -> SORT list

val get_common_members = fn : ''a list * ''a list -> ''a list
val get_common_lower_bounds = fn
  : Subsort_env_or_list -> SORT * SORT -> SORT list
val get_common_lower_bound_tuples = fn
  : Subsort_env_or_list -> SORT list * SORT list -> SORT list list
  *)

(* leq_S (env, s, s1) is true iff s is listed among the subsorts of s1. *)
val leq_S : Subsort_env_or_list * AS.SORT * AS.SORT -> bool

(* The subsorts recorded for a sort.  For its_a_list, raises ERR when the
   sort has no entry; for its_an_env, the result for a missing sort is
   whatever Symtab_id.lookup_multi returns (presumably [] -- confirm). *)
val lookup_subsorts : AS.ID * Subsort_env_or_list -> AS.SORT list

(* Minimal sorts among those whose subsort list contains both given sorts.
   NOTE(review): for its_an_env at most one upper bound is returned. *)
val get_common_upper_bounds : Subsort_env_or_list -> AS.SORT * AS.SORT -> AS.SORT list

(* The ~_F relation on two operation profiles: one ARG_RES_TRIPLE per
   combination of a common upper bound of the results and common lower
   bounds of the arguments; [] if there is none or the arities differ. *)
val overload_F : Subsort_env_or_list
    -> AS.OP_TYPE * AS.OP_TYPE
       -> (ARG_RES_TRIPLE) list

(* The ~_P relation on two predicate profiles: one ARG_TRIPLE per tuple of
   common lower bounds of the arguments; [] if there is none or the
   arities differ. *)
val overload_P : Subsort_env_or_list
    -> AS.PRED_TYPE * AS.PRED_TYPE -> ARG_TRIPLE list

(* Sorts (in the order of the local sort list) that have a term containing
   a partial operation symbol; the flag additionally seeds every proper
   subsort. *)
val compute_s_question_mark : bool -> LocalEnv.local_list -> AS.SORT list

(* True iff some sort of the second list belongs to the first list. *)
val test_question_mark : AS.SORT list * AS.SORT list -> bool
end

=

struct

open AS Utils LocalEnv IDOrder

datatype Subsort_env_or_list = its_an_env of Subsort_env
                             | its_a_list of Subsort_list

datatype wrong =  its_an_atom of ATOM
                | its_a_term of TERM
                | its_a_formula of FORMULA
                | its_an_overload_F of (Subsort_env_or_list * OP_TYPE * OP_TYPE)
                | allright


exception WRONG of int * wrong;

type ARG_RES_TRIPLE = AS.SORT list * AS.SORT list * AS.SORT list * AS.SORT * AS.SORT * AS.SORT
type ARG_TRIPLE = AS.SORT list * AS.SORT list * AS.SORT list

(* not_disjoint (l, xs): true iff some element of xs is a member of l. *)
fun not_disjoint (l,[]) = false
|   not_disjoint (l,x::xs) = SORT_member (x,l) orelse not_disjoint (l,xs)

local
(* Table representation: all values stored under key s. *)
fun lookup_subsorts1 (s,subsortlist) =  Symtab_id.lookup_multi (subsortlist,s)

(* List representation: the subsorts of the first entry whose sort equals s.
   Raises ERR when s has no entry. *)
fun lookup_subsorts2 (s,(s1,subsorts)::rest) =
                if SORT_eq(s,s1) then subsorts else lookup_subsorts2 (s,rest)
  | lookup_subsorts2 _ = raise (ERR "lookup_subsorts2")

in
fun lookup_subsorts (s,its_an_env e) = lookup_subsorts1 (s,e)
| lookup_subsorts (s,its_a_list l) = lookup_subsorts2 (s,l)
end

(* s <= s1 iff s occurs in the recorded subsorts of s1.  Whether a sort is
   listed among its own subsorts depends on how the subsort data was built. *)
fun leq_S (slist,s,s1) = SORT_member(s,lookup_subsorts (s1,slist))


(* get_minimum (env, s, acc) inserts s into acc (the accumulator built by
   get_minima1, starting from []).  Walking acc: an element that s lies
   below is dropped; if an element below s is found, it and the remaining
   elements are kept and s is not added; if the end is reached, s is
   appended. *)
fun get_minimum (slist:Subsort_env_or_list,s:SORT,[]:SORT list) = [s]
  | get_minimum (slist,s,(s1::rest)) =
        if leq_S (slist,s,s1) then get_minimum (slist,s,rest)
        else if leq_S (slist,s1,s) then s1::rest
             else s1::get_minimum (slist,s,rest)

(* Insert every sort of the last argument into acc with get_minimum. *)
fun get_minima1 (slist:Subsort_env_or_list) (acc:SORT list) [] = acc
  | get_minima1 slist acc (s::rest) =
        get_minima1 slist (get_minimum (slist,s,acc)) rest

(* The minimal elements of l with respect to leq_S. *)
fun get_minima (slist:Subsort_env_or_list) (l:SORT list) =
        get_minima1 slist [] l

(* Dual of get_minimum: elements below s are dropped; if an element above
   s is found, s is not added; otherwise s is appended. *)
fun get_maximum (slist,s,[]) = [s]
  | get_maximum (slist,s,(s1::rest)) =
        if leq_S (slist,s1,s) then get_maximum (slist,s,rest)
        else if leq_S (slist,s,s1) then s1::rest
             else s1::get_maximum (slist,s,rest)

(* Insert every sort of the last argument into acc with get_maximum. *)
fun get_maxima1 (slist:Subsort_env_or_list) (acc:SORT list) [] = acc
  | get_maxima1 slist acc (s::rest) =
        get_maxima1 slist (get_maximum (slist,s,acc)) rest

(* The maximal elements of l with respect to leq_S. *)
fun get_maxima (slist:Subsort_env_or_list) (l:SORT list) =
        get_maxima1 slist [] l



local

(* Table representation: the FIRST entry (in Symtab_id.find_first order)
   whose subsort list contains both s1 and s2, or [] if none.
   NOTE(review): unlike get_common_upper_bounds2, this returns at most one
   upper bound, which need not be minimal.  That suffices when only the
   existence of a common upper bound matters; the commented-out version
   below computes all of them.  Confirm the intent before relying on
   minimality. *)
fun get_common_upper_bounds1 (e:Subsort_env) (s1:SORT,s2:SORT):SORT list =
        let
          fun is_common_upper_bound(s,subsorts) = SORT_member(s1,subsorts) andalso SORT_member(s2,subsorts)
        in
        case Symtab_id.find_first is_common_upper_bound e of
        Some (s,subsorts) => [s]
        | None => []
        end



(* List representation: every sort whose subsort list contains both s1 and
   s2, in list order. *)
fun get_common_upper_bounds2 ([]:Subsort_list) (s1:SORT,s2:SORT):SORT list = []
  | get_common_upper_bounds2 ((s,subsorts)::rest) (s1:SORT,s2:SORT):SORT list =
    if (SORT_member(s1,subsorts) andalso SORT_member(s2,subsorts))
    then s::get_common_upper_bounds2 rest (s1,s2)
    else get_common_upper_bounds2 rest (s1,s2)

(* Previous, exhaustive table variant, kept for reference:
fun get_common_upper_bounds1 (e:Subsort_env) (s1:SORT,s2:SORT):SORT list =
 let val l = Symtab_id.dest e
 in  get_common_upper_bounds2 l (s1,s2)
 end   *)

in

fun get_common_upper_bounds (its_an_env e) (s1,s2) =
        get_minima (its_an_env e) (get_common_upper_bounds1 e (s1,s2))
|   get_common_upper_bounds (its_a_list l) (s1,s2) =
        get_minima (its_a_list l) (get_common_upper_bounds2 l (s1,s2))
end

(* Elements of the first list that also occur in the second, in the order
   (and with the multiplicity) of the first list. *)
fun get_common_members ([],l) = []
|   get_common_members (l,[]) = []
|   get_common_members (x::xs,ys) =
   if SORT_member (x, ys)
   then x::get_common_members (xs,ys)
   else get_common_members (xs,ys)



(* Maximal sorts that are listed as subsorts of both s1 and s2. *)
fun get_common_lower_bounds (slist:Subsort_env_or_list) (s1:SORT,s2:SORT):SORT list =
let val subsorts1 = lookup_subsorts (s1,slist)
    val subsorts2 = lookup_subsorts (s2,slist)
in
get_maxima slist (get_common_members (subsorts1,subsorts2))
end

(* Common lower bounds of two sort tuples, position by position, combined by
   Utils.permute (presumably all ways of picking one bound per position --
   confirm). *)
fun get_common_lower_bound_tuples (slist:Subsort_env_or_list)
        (s1:SORT list,s2:SORT list):SORT list list =
permute (map (get_common_lower_bounds slist)  (zip (s1,s2)))



(* ~_F: for profiles t1 = w1 -> s1 and t2 = w2 -> s2, build one
   (w,w1,w2,s,s1,s2) for each combination produced by permute from the
   common upper bounds of s1,s2 (head position) and the per-position common
   lower bounds of w1,w2.
   - A Match failure is reported as WRONG (5, its_an_overload_F ...).
   - ZIP_ERROR (argument lists of different length) yields []. *)
fun overload_F (slist:Subsort_env_or_list)
(t1:OP_TYPE,t2:OP_TYPE):(ARG_RES_TRIPLE) list =
(let val s1 = get_res t1
    val s2 = get_res t2
    val w1 = get_args t1
    val w2 = get_args t2
    val cubs = get_common_upper_bounds slist (s1,s2)
    (*val _ = writeln ("Upper bounds of "^BasicPrint.print_ID s1^" and "^BasicPrint.print_ID s2^" :\n"
                 ^BasicPrint.print_IDs cubs^"\n")*)
    val clbs = map (get_common_lower_bounds slist) (zip (w1,w2))
    (* The head of each combination is the upper bound, the tail the
       argument tuple. *)
    fun add_type (s::w) = (w,w1,w2,s,s1,s2)
      | add_type _ = raise (ERR "add_type")
in
   map add_type (permute (cubs::clbs))
end)
handle Match => (raise (WRONG (5,its_an_overload_F (slist,t1,t2))))
| ZIP_ERROR => []


(* ~_P: for predicate profiles t1 and t2 with argument sorts w1 and w2,
   build one (w,w1,w2) for each combination produced by permute from the
   per-position common lower bounds.
   Like overload_F, ZIP_ERROR (argument lists of different length) yields
   [] -- profiles of different arity are not overloaded -- instead of
   escaping to the caller. *)
fun overload_P (slist:Subsort_env_or_list)
               (t1:PRED_TYPE,t2:PRED_TYPE):
               (ARG_TRIPLE) list =
(let val w1 = get_sorts (get_pred_type t1)
    val w2 = get_sorts (get_pred_type t2)
    val clbs = map (get_common_lower_bounds slist) (zip (w1,w2))
    fun add_type w = (w,w1,w2)
in
   map add_type (permute clbs)
end)
handle ZIP_ERROR => []





(**********************************************************)
(*                                                        *)
(*     Get those sorts with a term containing             *)
(*     partial operation symbols                          *)
(*                                                        *)
(**********************************************************)

(* Sets of sorts are represented as bool lists aligned with a fixed base
   list of sorts: position i is true iff the i-th base sort is in the set. *)
local
(* Split operation profiles into (total profiles as (argument sorts, result
   sort), result sorts of partial profiles), accumulating into acc.
   Position wrappers (pos_OP_TYPE) are unwrapped. *)
fun add_types ([],acc) = acc
|   add_types (partial_op_type (args,res)::rest,(totaltypes,partialresults)) =
       add_types (rest,(totaltypes,res::partialresults))
|   add_types (total_op_type (args,res)::rest,(totaltypes,partialresults))=
       add_types (rest,((get_sorts args,res)::totaltypes,partialresults))
|   add_types (pos_OP_TYPE (_,t)::rest,acc) =
       add_types (t::rest,acc)

(* Apply add_types to the profile list of every function symbol. *)
fun extract_types ([]:Fun_list,acc) = acc
|   extract_types (((f,typelist)::rest),acc) =
    extract_types (rest,add_types (typelist, acc))

(* Membership vector of list_of_elems with respect to base_list. *)
fun mk_set base_list list_of_elems =
        map ((fn x => fn y => SORT_member (y,x)) list_of_elems) base_list

(* Inverse of mk_set: the base sorts whose flag is true, in base order.
   Raises ERR if the two lists differ in length. *)
fun mk_list nil nil = nil
  | mk_list (b::base_list) (m::memberhood_list) =
        if m then b::mk_list base_list memberhood_list
        else mk_list base_list memberhood_list
  | mk_list _ _ = raise (ERR "Subsorts: mk_list")

(* Set the flag of the first base sort equal to x.
   Raises ERR if x is not in the base list. *)
fun add_to_set (b::base_list) (x,m::memberhood_list) =
        if SORT_eq(x,b) then true::memberhood_list
        else m::add_to_set base_list (x,memberhood_list)
  | add_to_set _ _ = raise (ERR "Subsorts: add_to_set")

fun binor (x,y) = x orelse y
fun binand (x,y) = x andalso y

(* True iff the two membership vectors share a true position. *)
fun not_disjoint_set (m1,m2) =
        foldl binor (false,map binand (zip (m1,m2)))

(* If some argument sort of a total profile is in the set, add its result
   sort (a partial argument term makes the applied term partial too). *)
fun check_total_fun base_list (thesorts,(args,res)) =
if not_disjoint_set (thesorts,mk_set base_list args)
then add_to_set base_list (res,thesorts)
else thesorts

(* If some subsort of s is in the set, add s. *)
fun upward_close1 base_list (thesorts,(s,subs)) =
        if not_disjoint_set (thesorts,mk_set base_list subs)
        then add_to_set base_list (s,thesorts)
        else thesorts

(* One pass of upward_close1 over all subsort entries. *)
fun upward_close base_list subsorts thesorts =
        foldl (upward_close1 base_list) (thesorts,subsorts)

(* Alternate check_total_fun and upward_close passes until the set stops
   changing.  Flags are only ever set to true, so this terminates. *)
fun iterate_it base_list subsorts (thesorts,totaltypes) =
let
val thesorts1 = foldl (check_total_fun base_list) (thesorts,totaltypes)
val thesorts2 = upward_close base_list subsorts thesorts1
in
if (thesorts2=thesorts) then thesorts
else iterate_it base_list subsorts (thesorts2,totaltypes)
end

(* Remove every occurrence of x from a list. *)
fun remove (x, []) = []
|   remove (x, (y::ys)) = if SORT_eq(x,y) then remove (x, ys) else y::remove (x, ys)

(* All sorts that are listed as subsorts of some other sort, i.e. each
   entry's subsort list without the entry's own sort, concatenated. *)
fun get_subsorts subl =
        flat (map remove subl)


in
(* Seed the set with the result sorts of partial profiles (plus, when
   use_projections holds, all proper subsorts), close it under total
   operations and subsorting, and return it as a list in the order of sl.
   Raises ERR (via add_to_set) if a result or supersort is not in sl. *)
fun compute_s_question_mark (use_projections:bool)
             ((sl,subl,varl,funl,predl):local_list):SORT list=
let val (totaltypes,partialresults) = extract_types (funl,([],[]))
    val partial_subsorts = if use_projections then get_subsorts subl else nil
in    mk_list sl (iterate_it sl subl (mk_set sl (partialresults@partial_subsorts),totaltypes))
end;


(* True iff some sort in arglist belongs to sq. *)
fun test_question_mark (sq:SORT list,arglist:SORT list):bool =
    not_disjoint (sq,arglist) ;

end

end
