(*
This file is part of the first order theorem prover Darwin
Copyright (C) 2006
              The University of Iowa

This program is free software; you can redistribute it and/or
modify it under the terms of the GNU General Public License
as published by the Free Software Foundation; either version 2
of the License, or (at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program; if not, write to the Free Software
Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
*)




(*** types ***)


type var = Var.var
type symbol = Symbol.symbol
type term = Term.term
type clause = Term.clause
type counter = Counter.counter


(* used during sort inference, which is based on union-find. *)
type sort = {
  (* a unique id per sort *)
  mutable sort: int;

  (* the current representative of this sorts' equivalence class.
     None, if this is the representative *)

  mutable repr: sort option;

  (* if this is the class representative,
     then this is the number of class elements. *)

  mutable size: int;
}

(* the sort of a function/predicate symbol:

   for an n-ary function symbol f:
   f: S0 -> S1 -> ... Sn -> Sresult

   for an n-ary predicate symbol p:
   p: S0 -> S1 -> ... Sn -> BOOL
*)

type composed_sort = sort array

module SymbolTable = Symbol.SymbolTable
type symbol_sorts = composed_sort SymbolTable.t
module VarTable = Var.VarTable
type var_sorts = composed_sort VarTable.t

type sorts = {
  (* mapping from a sybmol to its sort *)
  symbol_sorts: symbol_sorts;

  (* total number of sorts *)
  number_of_sorts: int;

  (* constant symbols of the signature *)
  mutable constants: symbol list;
  (* function symbols of the signature *)
  functions: symbol list;
  (* predicate symbols of the signature *)
  predicates: symbol list;

  (* constant symbols partitioned by sorts *)
  mutable constants_partition: symbol list list;
  (* size of biggest partition in constants_partition *)
  mutable max_constant_partition_size: int;
}




(*** union-find ***)


(* get equivalence class of a sort *)
let rec get_repr (sort: sort) : sort =
  match sort.repr with
    | None ->
        sort

    | Some repr ->
        begin
          let repr' =
            get_repr repr
          in
            (* path compression *)
            if repr' != repr then
              sort.repr <- Some repr';
                
            repr'
        end


(* unify equivalence classes *)
let unify (a: sort) (b: sort) : unit =
  let a = get_repr a
  and b = get_repr b
  in
    if a == b then
      ()

    (* add bigger equivalence class to smaller *)
    else if a.size >= b.size then begin
      a.repr <- Some b;
      b.size <- a.size + b.size;
    end

    else begin
      b.repr <- Some a;
      a.size <- a.size + b.size;
    end






(*** register sorts ***)



(* retrieves/creates the sort of a symbol *)
let register_symbol (sorts: symbol_sorts) (sort_counter: counter) (symbol: symbol) : composed_sort =
  try
    SymbolTable.find sorts symbol
  with
    | Not_found ->
        (* function symbols have a result sort, predicate symbols don't *)
        let size =
          if Symbol.is_predicate symbol then
            Symbol.arity symbol
          else
            Symbol.arity symbol + 1
        in
        let composed_sort =
          Array.make size { sort = 0; repr = None; size = 1 }
        in
          for i = 0 to Array.length composed_sort - 1 do
            composed_sort.(i) <- { sort = Counter.next sort_counter; repr = None; size = 1 };
          done;
          SymbolTable.add sorts symbol composed_sort;

          composed_sort

(* retrieves/creates the sort of a variable *)
let register_var (var_sorts: var_sorts) (sort_counter: counter) (var: var) : composed_sort =
  try
    VarTable.find var_sorts var
  with
    | Not_found ->
        let composed_sort =
          [| { sort = Counter.next sort_counter; repr = None; size = 1 } |]
        in
          VarTable.add var_sorts var composed_sort;

          composed_sort

(* retrieves/creates the sort of the top symbol of a term *)
let register_term (sorts: symbol_sorts) (var_sorts: var_sorts) (sort_counter: counter) (term: term) : composed_sort =
  match term with
    | Term.Var var ->
        register_var var_sorts sort_counter var

    | Term.Const symbol
    | Term.Func { Term.symbol = symbol } ->
        register_symbol sorts sort_counter symbol


(* returns the result sort of a non-constant function symbol *)
let result_sort (composed_sort: composed_sort) : sort =
  if Array.length composed_sort == 0 then begin
    failwith ("Sort_inference: result_sort: ");
  end;
  composed_sort.(Array.length composed_sort - 1)


(* sorts sorts by arity and name *)
let sort_symbols (s1: symbol) (s2: symbol) =
  let cmp =
    compare (Symbol.arity s1) (Symbol.arity s2)
  in
    if cmp <> 0 then
      cmp
    else begin
      (* put skolem symbols last *)
      if Symbol.is_skolem s1 && not (Symbol.is_skolem s2) then
        1
      else if not (Symbol.is_skolem s1) && Symbol.is_skolem s2 then
        -1
      else
        compare (Symbol.name s1) (Symbol.name s2)
    end



(*** print ***)


let sort_to_string (sort: sort) : string =
  "'" ^ (string_of_int ((get_repr sort).sort))

let composed_sort_to_string (composed_sort) : string =
  String.concat " -> " (Array.to_list (Array.map sort_to_string composed_sort))


let print_symbols (sorts: sorts) (symbols: symbol list)  (title: string) : unit =
  print_endline (title ^ ": ");
  List.iter
    (fun symbol ->
       try
         let composed_sort =
           SymbolTable.find sorts.symbol_sorts symbol
         in
           print_endline (Symbol.name symbol ^ " : " ^ composed_sort_to_string composed_sort)
       with
         | Not_found ->
             failwith "Sort_inference.print_symbols"
    )
    symbols;
  print_newline ()

let print_sorts (sorts: sorts) : unit =
  print_endline ("Sorts: " ^ string_of_int sorts.number_of_sorts);
  print_newline ();

  print_symbols sorts sorts.constants "Constants";
  print_symbols sorts sorts.functions "Functions";
  print_symbols sorts sorts.predicates "Predicates"


let print = print_sorts





(*** infer ***)



(* partition the constant symbols into their sort equivalence classes,
   and return the size of the biggest partition *)

let partition_constants symbol_sorts constants number_of_sorts : symbol list list * int =
  let constants_partition =
    Array.make (number_of_sorts + 1) []
  in
    List.iter
      (fun symbol ->
        try
          let composed_sort =
            SymbolTable.find symbol_sorts symbol
          in
             let sort =
               get_repr (result_sort composed_sort)
             in
               if sort.sort < 0 || sort.sort >= Array.length constants_partition then
                 failwith ("Sort_inference.partition_constants: sort out of bound: " ^ string_of_int sort.sort);
               
               constants_partition.(sort.sort) <- symbol :: constants_partition.(sort.sort)
        with
          | Not_found ->
              failwith "Sort_inference.partition_constants: 2"
      )
      constants;
    
      (* sort the symbols of each partition *)
      for i = 0 to Array.length constants_partition - 1 do
        constants_partition.(i) <- List.sort sort_symbols constants_partition.(i);
      done;

      (* put domains with few constants first *)
      Array.sort
        (fun x y ->
           compare (List.length x) (List.length y)
        )
        constants_partition;

      let max_size =
        List.length constants_partition.(Array.length constants_partition - 1)
      in

      (* ignore sorts with no constants, and convert array to list *)
      let partition =
        Array.fold_right
          (fun symbols acc ->
            match symbols with
              | [] -> acc
              | _ -> symbols :: acc
          )
          constants_partition
          []
      in
        partition, max_size


(* infer sorts for a term.
   - [sort] is the sort of the term position at which term occurs,
     i.e. the term's result sort must be the same as [sort] *)

let rec infer_term (sorts: symbol_sorts) (var_sorts: var_sorts) (sort_counter: counter)
    (sort: sort) (term: term) : unit =
  match term with
    | Term.Var var ->
        unify sort (result_sort (register_var var_sorts sort_counter var))

    | Term.Const symbol ->
        unify sort (result_sort (register_symbol sorts sort_counter symbol))

    | Term.Func func ->
        let composed_sort =
          register_symbol sorts sort_counter func.Term.symbol
        in
          unify sort (result_sort composed_sort);
          (* register all subterms, and infer their sorts *)
          Array.iteri
            (fun i term ->
               infer_term sorts var_sorts sort_counter composed_sort.(i) term
            )
            func.Term.subterms
        
        

(* infer sorts for a clause *)
let infer_clause (sorts: symbol_sorts) (sort_counter: counter) (clause: clause) : unit =
  (* variable sorts are local to a clause *)
  let var_sorts =
    VarTable.create 64
  in
    List.iter
      (fun literal ->
         match literal.Term.atom with
           | Term.Var _ ->
               failwith ("Sort_inference: infer_clause on variable: " ^ Term.literal_to_string literal);

           | Term.Const symbol ->
               (* predicate with arity 0 - nothing to infer *)
               ignore (register_symbol sorts sort_counter symbol : composed_sort);

           | Term.Func func ->
               (* equality predicate *)
               if Symbol.equal Symbol.equality func.Term.symbol then begin
                 (* register both subterms, unify their result sort, and infer their sorts *)
                 let left_sort = register_term sorts var_sorts sort_counter func.Term.subterms.(0)
                 and right_sort = register_term sorts var_sorts sort_counter func.Term.subterms.(1)
                 in
                   unify (result_sort left_sort) (result_sort right_sort);
                   infer_term sorts var_sorts sort_counter (result_sort left_sort) func.Term.subterms.(0);
                   infer_term sorts var_sorts sort_counter (result_sort right_sort) func.Term.subterms.(1);
               end

               (* predicate with arguments *)
               else begin
                 let composed_sort =
                   register_symbol sorts sort_counter func.Term.symbol
                 in
                   (* register all subterms, and infer their sorts *)
                   Array.iteri
                     (fun i term ->
                        infer_term sorts var_sorts sort_counter composed_sort.(i) term
                     )
                     func.Term.subterms
               end
      )
      clause


(* infer sorts for clauses *)
let infer ~(print:bool) (clauses: clause list) : sorts =
  let symbol_sorts =
    SymbolTable.create 256
  in
  let sort_counter =
    Counter.create_with 0
  in
    (* infer the sorts for each symbol *)
    List.iter
      (fun clause -> infer_clause symbol_sorts sort_counter clause)
      clauses;

    (* for each found symbol, check if it is a constant, function, or predicate.
       also check which sorts are left after unifying sorts. *)

    let constants, functions, predicates, used_sorts =
      SymbolTable.fold
        (fun symbol composed_sort (constants, functions, predicates, used_sorts) ->
           (* add sorts used by this symbol, which did not occur before *)
           let used_sorts =
             Array.fold_left
               (fun acc sort ->
                 Tools.list_add (==) acc (get_repr sort)
               )
               used_sorts
               composed_sort
           in
             (* a predicate symbol *)
             if Symbol.is_predicate symbol then
               (constants, functions, symbol :: predicates, used_sorts)
                 
             (* a function symbol *)
             else if Array.length composed_sort > 1 then
               (constants, symbol :: functions, predicates, used_sorts)

             (* a constant symbol *)
             else
               (symbol :: constants, functions, predicates, used_sorts)
        )
        symbol_sorts
        ([], [], [], [])
    in

    (* normalize used sorts to use the numbers from 1 .. number_of_used_sorts *)
    let rec normalize counter used_sorts =
      match used_sorts with
        | [] ->
            ()

        | head :: tail ->
            head.sort <- counter;
            normalize (counter + 1) tail
    in
      normalize 1 used_sorts;

    (* sort the symbols *)
    let constants = List.sort sort_symbols constants
    and functions = List.sort sort_symbols functions
    and predicates = List.sort sort_symbols predicates
    and number_of_sorts = List.length used_sorts
    in

    let constants_partition, max_size =
      partition_constants symbol_sorts constants number_of_sorts
    in

    let sorts = {
      symbol_sorts = symbol_sorts;
      number_of_sorts = number_of_sorts;

      constants = constants;
      functions = functions;
      predicates = predicates;

      constants_partition = constants_partition;
      max_constant_partition_size = max_size;
    }
    in
     
      if print then begin
        print_sorts sorts
      end;

      sorts



let add_constant (sorts: sorts) (constant: symbol) (existing: symbol) : unit =
  if SymbolTable.mem sorts.symbol_sorts constant then
    failwith ("Sort_inference.add_constant: symbol already known: " ^ Symbol.to_string constant);

  let sort =
    try
      result_sort (SymbolTable.find sorts.symbol_sorts existing)
    with
      | Not_found ->
          failwith ("Sort_inference.add_constant: symbol is unkown: " ^ Symbol.to_string existing);
  in 
  let composed_sort =
    Array.make 1 { sort = 0; repr = Some sort; size = 1 }
  in
    SymbolTable.add sorts.symbol_sorts constant composed_sort;

    let constants =
      List.sort sort_symbols (constant :: sorts.constants)
    in
    let constants_partition, max_size =
      partition_constants sorts.symbol_sorts constants sorts.number_of_sorts
    in
      sorts.constants <- constants;
      sorts.constants_partition <- constants_partition;
      sorts.max_constant_partition_size <- max_size


let constants_partition (sorts: sorts) =
  sorts.constants_partition

let max_constant_partition_size (sorts: sorts) : int =
  sorts.max_constant_partition_size