open Batteries
open BatFixes
open Option.Infix

open Fof
open Term


module Clausal =
struct

(* Conjunctive normal form by distributing disjunction over conjunction.
   Expects NNF with no quantifiers: only [Conj] and [Disj] nodes are
   traversed; every other node is treated as a literal and kept as is.
   Distribution copies subformulas, so the result may be much larger than
   the input (the definitional variants below avoid this). *)
let cnf_form t =
  (* [distrib l r] is the CNF of [Disj (l, r)] when [l] and [r] are already
     in CNF, so the operands need not be re-normalised after every
     distribution step. *)
  let rec distrib l r =
    match l, r with
    | Conj (ll, lr), _ -> Conj (distrib ll r, distrib lr r)
    | _, Conj (rl, rr) -> Conj (distrib l rl, distrib l rr)
    | _ -> Disj (l, r)
  in
  let rec cnf = function
    | Disj (l, r) -> distrib (cnf l) (cnf r)
    | Conj (l, r) -> Conj (cnf l, cnf r)
    | x -> x
  in
  cnf t

(* TODO: Look at: *)
(* P. Jackson, D. Sheridan. Clause Form Conversions for Boolean Circuits *)
(* P. Manolios, D. Vroon. Efficient Circuit to CNF Conversion *)

(* Definitional clausal form.  [dcnf_form dcnf_name sf f] prepends the
   clauses of [f] to the clause list [sf]; a clause is a list of literals
   [(p, args)], negation flipping the sign of [p].
   Conjunctions nested inside disjunctions are not distributed: such a
   conjunction is replaced by a definition literal [(n, vars)] with
   [n = dcnf_name cls], where [cls] are the sorted clauses of the
   conjunction and [vars] the variables occurring in them, and the clause
   [(-n, vars) :: c] is added for every [c] in [cls].
   Raises [Failure] on nodes other than [Conj], [Disj], [Atom] and a
   negated [Atom]. *)
let dcnf_form dcnf_name =
  let rec dcnf sf = function
    | Conj (l, r) -> dcnf (dcnf sf l) r
    | (Disj _ as d) ->
        (* The disjunction contributes one clause [ret]; definition clauses
           produced for nested conjunctions are threaded through [sf]. *)
        let l = rev_strip_disj [] d in
        let (sf, ret) = List.fold_left dcnf_disj (sf, []) l in ret :: sf
    | Atom l -> [l] :: sf
    | Neg (Atom (i, p)) -> [-i, p] :: sf
    | _ -> failwith "dcnf"
  (* One disjunct: [sfc] collects definition clauses, [sfd] the literals of
     the clause under construction. *)
  and dcnf_disj (sfc, sfd) = function
    | (Conj _ as c) ->
        let l = rev_strip_conj [] c in
        let (sfc, l) = List.fold_left ccnf_disj (sfc, []) l in
        (* Canonical order, presumably so that [dcnf_name] can recognise
           structurally equal conjunctions -- verify against dcnf_name. *)
        let l = List.sort compare (List.map (List.sort compare) l) in
        let n = dcnf_name l in
        let fvm = List.fold_left (List.fold_left lit_vars) IS.empty l in
        let fvs = IS.fold (fun k sf -> V k :: sf) fvm [] in
        let pos = n, fvs and neg = -n, fvs in
        (List.fold_left (fun sf d -> (neg :: d) :: sf) sfc l, pos :: sfd)
    | Atom l -> (sfc, l :: sfd)
    | Neg (Atom (i, p)) -> (sfc, (-i, p) :: sfd)
    | _ -> failwith "dcnf_disj"
  (* One conjunct of a nested conjunction: [sfd] collects the clauses of
     that conjunction. *)
  and ccnf_disj (sfc, sfd) = function
    | (Disj _ as d) ->
        let l = rev_strip_disj [] d in
        let (sfc, d) = List.fold_left dcnf_disj (sfc, []) l in (sfc, d :: sfd)
    | Atom l -> (sfc, [l] :: sfd)
    | Neg (Atom (i, p)) -> (sfc, [-i, p] :: sfd)
    | _ -> failwith "ccnf_disj"
  in dcnf

(* Definitional clausal form working directly on formulas.
   [dcnf_of_form dcnf_name f] returns the clauses of [f]: the clauses of the
   top-level conjuncts first, followed by the definition clauses.  A
   conjunction [c] occurring as a disjunct is replaced by the literal
   [(n, vars)] with [n = dcnf_name c] and [vars] the variables of [c], and
   the clauses of [Disj (Neg (Atom (n, vars)), ci)] are added for every
   conjunct [ci] of [c].
   Raises [Failure] on any other non-literal disjunct. *)
let dcnf_of_form dcnf_name =
  let rec cls_of_form t =
    let l = strip_conj [] t |> List.map cl_of_form in
    List.map fst l @ List.concat (List.map snd l)
  (* A clause together with the definition clauses it required. *)
  and cl_of_form t =
    let l = strip_disj [] t |> List.map lit_of_form in
    List.map fst l, List.concat (List.map snd l)
  and lit_of_form = function
      Atom t -> t, []
    | Neg (Atom t) -> negate_lit t, []
    | Conj _ as t ->
        let n = dcnf_name t
        and fvm = form_vars IS.empty t in
        let fvs = IS.fold (fun k sf -> V k :: sf) fvm [] in
        let pos = n, fvs in
        let neg = Neg (Atom pos) in
        let cls = strip_conj [] t
          |> List.map (fun ci -> cls_of_form (Disj (neg, ci))) |> List.concat in
        pos, cls
    | _ -> failwith "dcnf_of_form"
  in cls_of_form

(* Definitional clausal form of a list of formulas, accumulated with
   [dcnf_form]; clauses of later formulas end up in front. *)
let dcnf_of_forms dcnf_name = List.fold_left (dcnf_form dcnf_name) []

(* Plain (distributive) CNF of a formula as a list of literal lists. *)
let cnf_of_form t = t |>
  cnf_form |> strip_conj [] |> List.map (strip_disj [] %> List.map lit_of_form)

(* Plain CNF of a list of formulas; [List.rev_map] means the clauses of the
   last formula come first. *)
let cnf_of_forms ts =
  ts |> List.rev_map (cnf_form %> strip_conj []) |> List.concat |>
  List.map (strip_disj [] %> List.map lit_of_form)

(* Rename the variables of a clause with [rename_lit], starting from an
   empty map and counter 0 (presumably yielding consecutive variable
   numbers -- confirm against rename_lit). *)
let rename_clause l = snd (fold_map rename_lit (IM.empty, 0) l)


(* true if the clause does not contain instances of Px and ~Px,
   i.e. no two literals with opposite predicate signs and equal arguments. *)
let nontriv_clause cl =
  List.for_all (fun (p1,a1) -> List.for_all (fun (p2,a2) -> p1 <> -p2 || a1 <> a2) cl) cl

(* true if at least one literal of the clause has a positive predicate
   number. *)
let clause_positive cl = List.exists (fun (p, _) -> p > 0) cl

(* One more than the result of folding [lit_max_var] over the clause,
   seeded with -1. *)
let clause_max_var cl = 1 + List.fold_left lit_max_var (-1) cl

end


module Nonclausal =
struct

(* Nonclausal matrices.  A clause is a disjunction of literals and nested
   matrices, a matrix is a conjunction of clauses.  Matrices and clauses
   carry an annotation of type ['t]; a clause additionally records the
   variables bound by the [Forall] quantifiers in front of it
   (see [iclause_of_form]). *)
type 't litmat = Lit of lit | Mat of 't imatrix
and 't imatrix = 't * 't matrix
and 't matrix = 't iclause list
and 't iclause = ('t * int list) * 't clause
and 't clause = 't litmat list

(* Eliminator for [litmat]: apply [fl] to a literal, [fm] to a matrix. *)
let map_litmat fl fm = function
    Lit lit -> fl lit
  | Mat mat -> fm mat

(* Constructors as functions, convenient for composition. *)
let lit x = Lit x
let mat x = Mat x

(* The literals occurring directly in a clause; nested matrices are
   skipped. *)
let clause_lits c = List.filter_map (map_litmat Option.some (const None)) c

(* Flatten nested binary disjunctions (resp. conjunctions) into the list of
   their operands, from left to right.  The accumulator avoids the
   quadratic cost of repeated [@] on left-nested trees. *)
let strip_disj t =
  let rec go acc = function
      Disj (l, r) -> go (go acc r) l
    | x -> x :: acc
  in
  go [] t
let strip_conj t =
  let rec go acc = function
      Conj (l, r) -> go (go acc r) l
    | x -> x :: acc
  in
  go [] t

(* Remove leading universal quantifiers.  Returns the bound variables in
   quantifier order (preceded by the reversed initial [acc]) and the
   remaining body. *)
let rec strip_forall acc = function
    Forall (x, t) -> strip_forall (x :: acc) t
  | t -> List.rev acc, t

(* Translate a formula into a matrix with [()] annotations.  Expects [Neg]
   only on atoms; conjunctions and universally quantified subformulas
   inside a clause become nested matrices, and the variables of leading
   [Forall]s of a conjunct are stored with its clause.
   Raises [Failure] on any other clause element. *)
let rec litmat_of_form = function
    Atom t -> Lit t
  | Neg (Atom t) -> Lit (negate_lit t)
  | (Conj _ | Forall _) as t -> Mat ((), matrix_of_form t)
  | _ -> failwith "litmat_of_form"
and matrix_of_form t = List.map iclause_of_form (strip_conj t)
and clause_of_form t = List.map litmat_of_form (strip_disj t)
and iclause_of_form t =
  let univ, t' = strip_forall [] t in ((), univ), clause_of_form t'

(* Convert a matrix back into a formula.  Annotations and clause variable
   lists are dropped, so no quantifiers are reconstructed. *)
let rec form_of_litmat lm = map_litmat form_of_lit form_of_imatrix lm
and form_of_imatrix (_, m) = form_of_matrix m
and form_of_iclause (_, c) = form_of_clause c
and form_of_matrix m = List.map form_of_iclause m |> conj_of_forms
and form_of_clause c = List.map form_of_litmat c |> disj_of_forms


(* Fold [lit_max_var] over every literal, starting from [acc]; clause
   variable lists are not consulted. *)
let rec litmat_max_var acc = map_litmat (lit_max_var acc) (imatrix_max_var acc)
and imatrix_max_var acc (_, m) = matrix_max_var acc m
and iclause_max_var acc (_, c) = clause_max_var acc c
and matrix_max_var acc m = List.fold_left iclause_max_var acc m
and clause_max_var acc c = List.fold_left litmat_max_var acc c

(* One more than the maximal variable, with the fold seeded by -1. *)
let  clause_offset c = 1 +  clause_max_var (-1) c
let iclause_offset c = 1 + iclause_max_var (-1) c

(* Number matrices and clauses in preorder: each matrix or clause receives
   the current counter and its children are numbered from counter + 1;
   literals consume no number.  The counter is threaded through lists by
   [fold_map], and the first component of each result is the next free
   number.  The top-level [index_matrix] numbers the clauses of [m] from 0
   and discards the final counter. *)
let rec index_imatrix i (_, m)  = let j, m' = index_matrix (i+1) m in j, (i, m')
and index_matrix i m = fold_map index_iclause i m
and index_iclause i ((_, v), c) = let j, c' = index_clause (i+1) c in j, ((i, v), c')
and index_clause i c = fold_map index_litmat i c
and index_litmat i = function
    Lit l -> i, Lit l
  | Mat m -> let j, m' = index_imatrix i m in j, Mat m'
let index_matrix m = snd (index_matrix 0 m)

(* Apply [f] to every matrix and clause annotation; literals and clause
   variable lists are unchanged. *)
let rec map_imatrix f (t, m) = (f t, map_matrix f m)
and map_matrix f = List.map (map_iclause f)
and map_iclause f ((t, v), c) = ((f t, v), map_clause f c)
and map_clause f = List.map (map_litmat lit (mat % map_imatrix f))

(* Pair every annotation with [k]. *)
let copy_matrix k = map_matrix (fun i -> i, k)
let copy_clause k = map_clause (fun i -> i, k)


(* Map [f] over a list, returning [None] as soon as one application
   returns [None]. *)
let rec break_map f = function
    [] -> Some []
  | x::xs -> f x >>= fun y -> break_map f xs >>= fun ys -> Some (y::ys)

(* Restrict to the positive part: a clause is kept only if every literal
   passes [lit_positive] and every nested matrix has a non-empty positive
   part, nested matrices being replaced by their positive parts;
   [matrix_positive] keeps the positive clauses of a matrix.
   NOTE(review): [lit_positive] accepts literals with a NEGATIVE predicate
   number, while [Clausal.clause_positive] tests [p > 0]; confirm that this
   sign convention is intended. *)
let rec litmat_positive lm = map_litmat lit_positive imatrix_positive lm
and lit_positive (p, a) = if p < 0 then Some (Lit (p, a)) else None
and imatrix_positive (i, m) =
  match matrix_positive m with [] -> None | m' -> Some (Mat (i, m'))
and matrix_positive m = List.filter_map iclause_positive m
and iclause_positive (i, c) = clause_positive c >>= (fun c' -> Some (i, c'))
and clause_positive c = break_map litmat_positive c

(* Rename variable indices with [f], both inside literals (as [V (f v)])
   and in the clause variable lists. *)
let rec map_litmat_vars f =
  map_litmat (lit % map_lit_vars (fun v -> V (f v))) (mat % map_imatrix_vars f)
and map_imatrix_vars f (i, m) = (i, map_matrix_vars f m)
and map_matrix_vars f m = List.map (map_iclause_vars f) m
and map_iclause_vars f ((i, v), c) = ((i, List.map f v), map_clause_vars f c)
and map_clause_vars f c = List.map (map_litmat_vars f) c

(* Shift every variable index by [off]. *)
let offset_iclause off = map_iclause_vars ((+) off)
let offset_clause off = map_clause_vars ((+) off)
let offset_matrix off = map_matrix_vars ((+) off)

(* Order (path count, x) pairs by path count. *)
let by_paths (p1, _) (p2, _) = compare p1 p2
(* Product of an integer list; 1 for the empty list. *)
let product = List.fold_left (fun acc x -> acc * x) 1

(* Saturating addition and multiplication of non-negative path counts.
   Path counts grow exponentially with matrix size; plain [+] and [*] wrap
   around silently to arbitrary, possibly negative values and would
   scramble the ordering below.  Overflowing counts are clamped to
   [max_int]. *)
let paths_add a b = if a > max_int - b then max_int else a + b
let paths_mul a b =
  if a = 0 || b = 0 then 0
  else if a > max_int / b then max_int
  else a * b

(* Count paths and reorder by them: a literal has one path, a clause the
   sum of the paths of its elements, a matrix the product of the paths of
   its clauses.  Within every clause and matrix the children are sorted in
   ascending order of path count.  The final [matrix_paths] returns only
   the reordered matrix. *)
let rec litmat_paths = function
    Lit l -> 1, Lit l
  | Mat m -> let p, m' = imatrix_paths m in p, Mat m'
and imatrix_paths (i, m) = let p, m' = matrix_paths m in p, (i, m')
and matrix_paths m =
  let pm = m |> List.map iclause_paths |> List.sort by_paths in
  List.map fst pm |> List.fold_left paths_mul 1, List.map snd pm
and iclause_paths (iv, c) = let p, c' = clause_paths c in p, (iv, c')
and clause_paths c =
  let pc = c |> List.map litmat_paths |> List.sort by_paths in
  List.map fst pc |> List.fold_left paths_add 0, List.map snd pc

let matrix_paths m = snd (matrix_paths m)

end
