(* Local aliases for the project types this module operates on. *)
type var = Var.var
type symbol = Symbol.symbol
type term = Term.term
type clause = Term.clause
type counter = Counter.counter
(* A node of the union-find structure used to merge sorts.
   [sort]: the sort number. Fresh nodes take a value from the sort counter;
     [infer] later renumbers the class representatives 1, 2, ...
   [repr]: link towards the class representative; [None] marks a representative.
   [size]: class size counter summed by [unify]; only meaningful on representatives. *)
type sort = {
mutable sort: int;
mutable repr: sort option;
mutable size: int;
}
(* One sort per argument position. For non-predicate symbols a final element
   holds the result sort (see [register_symbol]); variables get a
   single-element array (see [register_var]). *)
type composed_sort = sort array
module SymbolTable = Symbol.SymbolTable
(* Composed sort of every registered symbol. *)
type symbol_sorts = composed_sort SymbolTable.t
module VarTable = Var.VarTable
(* Composed sort of every variable; [infer_clause] creates one table per clause. *)
type var_sorts = composed_sort VarTable.t
(* Result of sort inference over a clause set. *)
type sorts = {
(* composed sort of each symbol seen by [infer] or added by [add_constant] *)
symbol_sorts: symbol_sorts;
(* number of distinct sort classes found by [infer] *)
number_of_sorts: int;
(* non-predicate symbols with a single-element composed sort, ordered by [sort_symbols] *)
mutable constants: symbol list;
(* non-predicate symbols with a longer composed sort, ordered by [sort_symbols] *)
functions: symbol list;
(* predicate symbols, ordered by [sort_symbols] *)
predicates: symbol list;
(* constants grouped by sort, groups in increasing size, empty groups omitted *)
mutable constants_partition: symbol list list;
(* size of the largest group in [constants_partition] (0 when there are no constants) *)
mutable max_constant_partition_size: int;
}
(** [get_repr node] returns the representative (root) of [node]'s class.
    Every node visited on the way is re-linked directly to the root, so that
    later lookups are short (path compression). The function is recursive
    and not tail-recursive: its stack depth equals the length of the path. *)
let rec get_repr (node: sort) : sort =
  match node.repr with
  | Some parent ->
      let root = get_repr parent in
      if not (root == parent) then node.repr <- Some root;
      root
  | None -> node
(** [unify a b] merges the sort classes of [a] and [b] (no-op if they are
    already the same class).

    Union by size: the root of the smaller class is linked below the root of
    the larger one, and the larger root absorbs the smaller root's size. This
    keeps the trees shallow, and therefore keeps the recursion depth of
    [get_repr] small.

    Which node becomes the representative carries no meaning in this module:
    [infer] renumbers the representatives afterwards, and sort numbers are
    always read through [get_repr]. *)
let unify (a: sort) (b: sort) : unit =
  let a = get_repr a
  and b = get_repr b in
  if a != b then begin
    let smaller, larger = if a.size < b.size then (a, b) else (b, a) in
    smaller.repr <- Some larger;
    larger.size <- larger.size + smaller.size
  end
(** [register_symbol sorts sort_counter symbol] returns the composed sort of
    [symbol], creating and recording one on first use.

    A fresh composed sort has one new singleton sort per argument. For
    non-predicate symbols it also has a trailing result sort. Each new sort
    takes its number from [sort_counter], in index order. *)
let register_symbol (sorts: symbol_sorts) (sort_counter: counter) (symbol: symbol) : composed_sort =
  try SymbolTable.find sorts symbol with
  | Not_found ->
      let length =
        if Symbol.is_predicate symbol then Symbol.arity symbol
        else Symbol.arity symbol + 1
      in
      (* Array.init applies its function to 0 .. length-1 in order, so the
         counter values are handed out left to right. *)
      let fresh _ = { sort = Counter.next sort_counter; repr = None; size = 1 } in
      let created = Array.init length fresh in
      SymbolTable.add sorts symbol created;
      created
(** [register_var var_sorts sort_counter var] returns the sort of [var] as a
    single-element composed sort. On first use, a fresh singleton sort
    numbered from [sort_counter] is created and recorded. *)
let register_var (var_sorts: var_sorts) (sort_counter: counter) (var: var) : composed_sort =
  try VarTable.find var_sorts var with
  | Not_found ->
      let fresh = { sort = Counter.next sort_counter; repr = None; size = 1 } in
      let created = Array.make 1 fresh in
      VarTable.add var_sorts var created;
      created
(** [register_term sorts var_sorts sort_counter term] returns the composed
    sort of the head of [term]: the variable's sort for a variable, otherwise
    the composed sort of its (constant or function) symbol. Subterms are not
    visited. *)
let register_term (sorts: symbol_sorts) (var_sorts: var_sorts) (sort_counter: counter) (term: term) : composed_sort =
  match term with
  | Term.Const head ->
      register_symbol sorts sort_counter head
  | Term.Func { Term.symbol = head } ->
      register_symbol sorts sort_counter head
  | Term.Var v ->
      register_var var_sorts sort_counter v
(** [result_sort composed_sort] is the last element of [composed_sort]. That
    is the result sort of a function symbol or constant, or the only sort of a
    variable.
    @raise Failure if [composed_sort] is empty (e.g. the composed sort of a
      0-ary predicate). *)
let result_sort (composed_sort: composed_sort) : sort =
  let length = Array.length composed_sort in
  if length = 0 then
    failwith "Sort_inference: result_sort: empty composed sort";
  composed_sort.(length - 1)
(** Total order on symbols used for every symbol list in this module:
    increasing arity first, then non-Skolem symbols before Skolem symbols,
    then by name. *)
let sort_symbols (s1: symbol) (s2: symbol) =
  match compare (Symbol.arity s1) (Symbol.arity s2) with
  | 0 ->
      let skolem1 = Symbol.is_skolem s1
      and skolem2 = Symbol.is_skolem s2 in
      if skolem1 = skolem2 then compare (Symbol.name s1) (Symbol.name s2)
      else if skolem1 then 1
      else -1
  | by_arity -> by_arity
(** Renders a sort as ['N], where N is the number of its class representative. *)
let sort_to_string (sort: sort) : string =
  let representative = get_repr sort in
  Printf.sprintf "'%d" representative.sort
(** Renders a composed sort as its component sorts joined by [" -> "]. *)
let composed_sort_to_string (composed_sort) : string =
  let rendered =
    Array.fold_right (fun s acc -> sort_to_string s :: acc) composed_sort []
  in
  String.concat " -> " rendered
(** Prints [title], then one ["name : sort -> ... -> sort"] line per symbol,
    then an empty line.
    @raise Failure if one of [symbols] has no recorded composed sort. *)
let print_symbols (sorts: sorts) (symbols: symbol list) (title: string) : unit =
  let print_one symbol =
    try
      let signature = SymbolTable.find sorts.symbol_sorts symbol in
      let line = Symbol.name symbol ^ " : " ^ composed_sort_to_string signature in
      print_endline line
    with
    | Not_found -> failwith "Sort_inference.print_symbols"
  in
  print_endline (title ^ ": ");
  List.iter print_one symbols;
  print_newline ()
(** Prints the number of sorts, then the constants, functions and predicates
    with their composed sorts. *)
let print_sorts (sorts: sorts) : unit =
  print_endline (Printf.sprintf "Sorts: %d" sorts.number_of_sorts);
  print_newline ();
  List.iter
    (fun (symbols, title) -> print_symbols sorts symbols title)
    [ (sorts.constants, "Constants");
      (sorts.functions, "Functions");
      (sorts.predicates, "Predicates") ]
(** Alias of [print_sorts]. *)
let print (sorts: sorts) : unit = print_sorts sorts
(** [partition_constants symbol_sorts constants number_of_sorts] groups
    [constants] by the sort number of their result sort's representative.

    There is one bucket per sort number [0 .. number_of_sorts]. Each bucket is
    ordered with [sort_symbols]. The buckets are then ordered by increasing
    size with [Array.sort] (not a stable sort), and empty buckets are dropped.

    @return the non-empty groups and the size of the largest group
      (0 if there are no constants).
    @raise Failure if a constant has no recorded composed sort, or if its
      sort number falls outside [0 .. number_of_sorts]. *)
let partition_constants symbol_sorts constants number_of_sorts : symbol list list * int =
  let buckets = Array.make (number_of_sorts + 1) [] in
  let add_to_bucket symbol =
    let composed_sort =
      try SymbolTable.find symbol_sorts symbol
      with Not_found -> failwith "Sort_inference.partition_constants: 2"
    in
    let index = (get_repr (result_sort composed_sort)).sort in
    if index < 0 || index >= Array.length buckets then
      failwith ("Sort_inference.partition_constants: sort out of bound: " ^ string_of_int index);
    buckets.(index) <- symbol :: buckets.(index)
  in
  List.iter add_to_bucket constants;
  Array.iteri (fun i bucket -> buckets.(i) <- List.sort sort_symbols bucket) buckets;
  Array.sort (fun b1 b2 -> compare (List.length b1) (List.length b2)) buckets;
  let largest = List.length buckets.(Array.length buckets - 1) in
  let is_non_empty bucket = match bucket with [] -> false | _ :: _ -> true in
  (List.filter is_non_empty (Array.to_list buckets), largest)
(** [infer_term sorts var_sorts sort_counter sort term] registers [term] and
    all of its subterms, and unifies [sort] with the result sort of [term].
    Each argument of a function application is recursively constrained by the
    function's sort at that argument position. *)
let rec infer_term (sorts: symbol_sorts) (var_sorts: var_sorts) (sort_counter: counter)
    (sort: sort) (term: term) : unit =
  match term with
  | Term.Func func ->
      let signature = register_symbol sorts sort_counter func.Term.symbol in
      unify sort (result_sort signature);
      let infer_argument i argument =
        infer_term sorts var_sorts sort_counter signature.(i) argument
      in
      Array.iteri infer_argument func.Term.subterms
  | Term.Const symbol ->
      let signature = register_symbol sorts sort_counter symbol in
      unify sort (result_sort signature)
  | Term.Var var ->
      let signature = register_var var_sorts sort_counter var in
      unify sort (result_sort signature)
(** [infer_clause sorts sort_counter clause] constrains the sorts of every
    symbol and variable occurring in [clause]. Variable sorts are kept in a
    table local to this call, so variables are scoped to the clause.
    - Propositional atoms are only registered.
    - For an equation [l = r], the result sorts of [l] and [r] are unified and
      both sides are traversed. The equality symbol itself is never registered.
    - For any other predicate, each argument is constrained by the
      predicate's sort at that position.
    @raise Failure if an atom is a variable. *)
let infer_clause (sorts: symbol_sorts) (sort_counter: counter) (clause: clause) : unit =
  let var_sorts = VarTable.create 64 in
  let infer = infer_term sorts var_sorts sort_counter in
  let infer_equation subterms =
    let left_sort = register_term sorts var_sorts sort_counter subterms.(0)
    and right_sort = register_term sorts var_sorts sort_counter subterms.(1) in
    unify (result_sort left_sort) (result_sort right_sort);
    infer (result_sort left_sort) subterms.(0);
    infer (result_sort right_sort) subterms.(1)
  in
  let infer_predicate symbol subterms =
    let signature = register_symbol sorts sort_counter symbol in
    Array.iteri (fun i argument -> infer signature.(i) argument) subterms
  in
  let infer_literal literal =
    match literal.Term.atom with
    | Term.Var _ ->
        failwith ("Sort_inference: infer_clause on variable: " ^ Term.literal_to_string literal)
    | Term.Const symbol ->
        ignore (register_symbol sorts sort_counter symbol : composed_sort)
    | Term.Func func ->
        if Symbol.equal Symbol.equality func.Term.symbol then
          infer_equation func.Term.subterms
        else
          infer_predicate func.Term.symbol func.Term.subterms
  in
  List.iter infer_literal clause
(** [infer ~print clauses] runs sort inference over [clauses].

    Every symbol is classified as a predicate, a function (composed sort
    longer than one) or a constant. The distinct class representatives are
    collected with [Tools.list_add] and renumbered 1, 2, ... in list order.
    The symbol lists are ordered with [sort_symbols], and the constants are
    partitioned by sort. If [print] is true, the result is printed. *)
let infer ~(print:bool) (clauses: clause list) : sorts =
  let symbol_sorts = SymbolTable.create 256 in
  let sort_counter = Counter.create_with 0 in
  List.iter (infer_clause symbol_sorts sort_counter) clauses;
  let classify symbol composed_sort (constants, functions, predicates, used_sorts) =
    let used_sorts =
      Array.fold_left
        (fun acc s -> Tools.list_add (==) acc (get_repr s))
        used_sorts
        composed_sort
    in
    if Symbol.is_predicate symbol then
      (constants, functions, symbol :: predicates, used_sorts)
    else if Array.length composed_sort > 1 then
      (constants, symbol :: functions, predicates, used_sorts)
    else
      (symbol :: constants, functions, predicates, used_sorts)
  in
  let constants, functions, predicates, used_sorts =
    SymbolTable.fold classify symbol_sorts ([], [], [], [])
  in
  (* Renumber the representatives 1, 2, ... in the order of [used_sorts]. *)
  let next_number = ref 1 in
  List.iter
    (fun representative ->
      representative.sort <- !next_number;
      incr next_number)
    used_sorts;
  let number_of_sorts = List.length used_sorts in
  let constants = List.sort sort_symbols constants in
  let functions = List.sort sort_symbols functions in
  let predicates = List.sort sort_symbols predicates in
  let constants_partition, max_size =
    partition_constants symbol_sorts constants number_of_sorts
  in
  let result = {
    symbol_sorts = symbol_sorts;
    number_of_sorts = number_of_sorts;
    constants = constants;
    functions = functions;
    predicates = predicates;
    constants_partition = constants_partition;
    max_constant_partition_size = max_size;
  } in
  if print then print_sorts result;
  result
(** [add_constant sorts constant existing] registers the new symbol
    [constant] with the same sort as the result sort of the already known
    symbol [existing]. It then re-sorts [sorts.constants] and recomputes the
    constant partition. [number_of_sorts] is unchanged, because no new sort
    class is created.
    NOTE(review): if [existing] is a predicate, [result_sort] yields its last
    argument sort (or fails for a 0-ary predicate). Confirm that callers only
    pass constants or function symbols.
    @raise Failure if [constant] is already known or [existing] is not. *)
let add_constant (sorts: sorts) (constant: symbol) (existing: symbol) : unit =
  if SymbolTable.mem sorts.symbol_sorts constant then
    failwith ("Sort_inference.add_constant: symbol already known: " ^ Symbol.to_string constant);
  let existing_composed_sort =
    try SymbolTable.find sorts.symbol_sorts existing
    with Not_found ->
      failwith ("Sort_inference.add_constant: symbol is unknown: " ^ Symbol.to_string existing)
  in
  (* Link the new node directly to the representative of [existing]'s result
     sort. It joins that class without becoming a representative, so in this
     module its own [sort] and [size] fields are never read. *)
  let representative = get_repr (result_sort existing_composed_sort) in
  let composed_sort = [| { sort = 0; repr = Some representative; size = 1 } |] in
  SymbolTable.add sorts.symbol_sorts constant composed_sort;
  let constants = List.sort sort_symbols (constant :: sorts.constants) in
  let constants_partition, max_size =
    partition_constants sorts.symbol_sorts constants sorts.number_of_sorts
  in
  sorts.constants <- constants;
  sorts.constants_partition <- constants_partition;
  sorts.max_constant_partition_size <- max_size
(** The constants grouped by sort, with groups in increasing size. *)
let constants_partition ({ constants_partition = partition; _ } : sorts) =
  partition
(** Size of the largest group in the constant partition. *)
let max_constant_partition_size ({ max_constant_partition_size = size; _ } : sorts) : int =
  size