let infer ~(print:bool) (clauses: clause list) : sorts =
  let symbol_sorts =
    SymbolTable.create 256
  in
  let sort_counter =
    Counter.create_with 0
  in
    (* infer the sorts for each symbol *)
    List.iter
      (fun clause -> infer_clause symbol_sorts sort_counter clause)
      clauses;

    (* for each found symbol, check if it is a constant, function, or predicate.
       also check which sorts are left after unifying sorts. *)

    let constants, functions, predicates, used_sorts =
      SymbolTable.fold
        (fun symbol composed_sort (constants, functions, predicates, used_sorts) ->
           (* add sorts used by this symbol, which did not occur before *)
           let used_sorts =
             Array.fold_left
               (fun acc sort ->
                 Tools.list_add (==) acc (get_repr sort)
               )
               used_sorts
               composed_sort
           in
             (* a predicate symbol *)
             if Symbol.is_predicate symbol then
               (constants, functions, symbol :: predicates, used_sorts)
                 
             (* a function symbol *)
             else if Array.length composed_sort > 1 then
               (constants, symbol :: functions, predicates, used_sorts)

             (* a constant symbol *)
             else
               (symbol :: constants, functions, predicates, used_sorts)
        )
        symbol_sorts
        ([], [], [], [])
    in

    (* normalize used sorts to use the numbers from 1 .. number_of_used_sorts *)
    let rec normalize counter used_sorts =
      match used_sorts with
        | [] ->
            ()

        | head :: tail ->
            head.sort <- counter;
            normalize (counter + 1) tail
    in
      normalize 1 used_sorts;

    (* sort the symbols *)
    let constants = List.sort sort_symbols constants
    and functions = List.sort sort_symbols functions
    and predicates = List.sort sort_symbols predicates
    and number_of_sorts = List.length used_sorts
    in

    let constants_partition, max_size =
      partition_constants symbol_sorts constants number_of_sorts
    in

    let sorts = {
      symbol_sorts = symbol_sorts;
      number_of_sorts = number_of_sorts;

      constants = constants;
      functions = functions;
      predicates = predicates;

      constants_partition = constants_partition;
      max_constant_partition_size = max_size;
    }
    in
     
      if print then begin
        print_sorts sorts
      end;

      sorts