(** [infer ~print clauses] runs [infer_clause] on every clause of [clauses]
    against a fresh symbol table and a fresh sort counter, then summarises
    the table into a [sorts] record.

    The summary is built as follows:
    - every symbol of the table is classified as a predicate
      ([Symbol.is_predicate]), a function (composed sort array longer than
      one entry) or a constant (everything else);
    - the representatives ([get_repr]) of all sorts met in the table are
      gathered with [Tools.list_add (==)] (presumably a duplicate-free
      insertion based on physical equality; confirm against [Tools]) and
      then renumbered 1, 2, 3, ... in the order of the gathered list;
    - the three symbol lists are ordered with [sort_symbols];
    - the constants are partitioned by [partition_constants].

    @param print when [true], the result is passed to [print_sorts] before
           being returned.
    @param clauses the clauses to analyse.
    @return the populated [sorts] record. *)
let infer ~(print:bool) (clauses: clause list) : sorts =
  (* 256 and 0 are the arguments handed to the constructors; presumably an
     initial table size and the starting counter value. *)
  let symbol_sorts = SymbolTable.create 256 in
  let sort_counter = Counter.create_with 0 in

  (* Phase 1: let each clause contribute to the symbol table. *)
  List.iter
    (fun current -> infer_clause symbol_sorts sort_counter current)
    clauses;

  (* Phase 2: a single pass over the table, classifying each symbol and
     gathering the sort representatives in use. *)
  let collect_representatives seen composed_sort =
    Array.fold_left
      (fun gathered sort -> Tools.list_add (==) gathered (get_repr sort))
      seen
      composed_sort
  in
  let classify symbol composed_sort (consts, funcs, preds, seen) =
    let seen = collect_representatives seen composed_sort in
    if Symbol.is_predicate symbol then
      (consts, funcs, symbol :: preds, seen)
    else if Array.length composed_sort > 1 then
      (consts, symbol :: funcs, preds, seen)
    else
      (symbol :: consts, funcs, preds, seen)
  in
  let constants, functions, predicates, used_sorts =
    SymbolTable.fold classify symbol_sorts ([], [], [], [])
  in

  (* Phase 3: renumber the gathered representatives consecutively,
     starting at 1 and following the list order. *)
  List.iteri
    (fun index representative -> representative.sort <- index + 1)
    used_sorts;

  (* Phase 4: assemble the result. *)
  let number_of_sorts = List.length used_sorts in
  let constants = List.sort sort_symbols constants in
  let functions = List.sort sort_symbols functions in
  let predicates = List.sort sort_symbols predicates in
  let constants_partition, max_constant_partition_size =
    partition_constants symbol_sorts constants number_of_sorts
  in
  let sorts = {
    symbol_sorts;
    number_of_sorts;
    constants;
    functions;
    predicates;
    constants_partition;
    max_constant_partition_size;
  }
  in
  if print then print_sorts sorts;
  sorts