open Batteries
open Term

(* First-order formulas whose variables (in terms and at binders) have type
   ['v].
   - [Atom (p, args)]: predicate number [p] applied to the terms [args].  In
     the signed-literal encoding of [form_of_lit] / [lit_of_form], a
     non-positive number stands for a negated atom.
   - [Neg], [Conj], [Disj]: negation, conjunction, disjunction.
   - [Impl (l, r)]: implication from [l] to [r].
   - [Eqiv (l, r)]: equivalence of [l] and [r].
   - [Forall (x, body)] / [Exists (x, body)]: bind [x] in [body]. *)
type 'v form =
    Atom of (int * 'v term list)
  | Neg of 'v form
  | Conj of 'v form * 'v form
  | Disj of 'v form * 'v form
  | Impl of 'v form * 'v form
  | Eqiv of 'v form * 'v form
  | Forall of 'v * 'v form
  | Exists of 'v * 'v form

(* Formulas whose variables are integers. *)
type iform = int form
(* Formulas whose variables are strings. *)
type sform = string form

(* Curried smart constructors, handy for partial application and folds. *)
let forall v body = Forall (v, body)
let exists v body = Exists (v, body)
let conj lhs rhs = Conj (lhs, rhs)
let disj lhs rhs = Disj (lhs, rhs)
let neg body = Neg body

(* Turn a signed literal [(p, args)] into a formula.  A positive [p] gives
   [Atom (p, args)], reusing the input pair.  Any other [p] gives
   [Neg (Atom (-p, args))]. *)
let form_of_lit ((p, args) as lit) =
  if p <= 0 then Neg (Atom (-p, args))
  else Atom lit

(* Inverse of [form_of_lit]:
   - [Atom (n, args)] gives [(n, args)];
   - [Neg (Atom (n, args))] gives [(-n, args)].
   Raises [Failure "lit_of_form"] for any other formula. *)
let lit_of_form phi =
  match phi with
  | Atom (n, args) -> (n, args)
  | Neg (Atom (n, args)) -> (-n, args)
  | _ -> failwith "lit_of_form"

(* [map_form_vars f phi] renames every variable of [phi] with [f]:
   - the variables inside atom arguments, via [map_term_vars];
   - the binders of [Forall] and [Exists].
   The shape of the formula is unchanged. *)
let rec map_form_vars f phi =
  let go = map_form_vars f in
  match phi with
  | Atom (p, args) -> Atom (p, List.map (map_term_vars (fun v -> V (f v))) args)
  | Neg body -> Neg (go body)
  | Conj (a, b) -> Conj (go a, go b)
  | Disj (a, b) -> Disj (go a, go b)
  | Eqiv (a, b) -> Eqiv (go a, go b)
  | Impl (a, b) -> Impl (go a, go b)
  | Forall (x, body) -> Forall (f x, go body)
  | Exists (x, body) -> Exists (f x, go body)

(* Eliminate [Impl] and [Eqiv] from a formula.
   [polar] is the polarity of [phi]; it is flipped under every [Neg].
   - [Impl (a, b)] becomes [Disj (Neg a, b)].
   - [Eqiv (a, b)] becomes [(~a | b) & (~b | a)] when [polar] is true.
   - [Eqiv (a, b)] becomes [(a & b) | (~b & ~a)] when [polar] is false.
   Subformulas placed under a [Neg] are processed with the opposite
   polarity. *)
let rec unfold_equiv polar phi =
  let same = unfold_equiv polar
  and flipped = unfold_equiv (not polar) in
  match phi with
  | Forall (x, body) -> Forall (x, same body)
  | Exists (x, body) -> Exists (x, same body)
  | Conj (a, b) -> Conj (same a, same b)
  | Disj (a, b) -> Disj (same a, same b)
  | Impl (a, b) -> Disj (Neg (flipped a), same b)
  | Neg body -> Neg (flipped body)
  | Eqiv (a, b) ->
      if polar then
        Conj (Disj (Neg (flipped a), same b), Disj (Neg (flipped b), same a))
      else
        Disj (Conj (same a, same b), Conj (Neg (flipped b), Neg (flipped a)))
  | Atom _ -> phi

(* [form_vars acc phi] adds the free (unbound) variables of [phi] to [acc].
   Atom variables are collected with [term_vars].  A quantified variable is
   removed from what its body contributes, unless it is already in [acc]
   because it occurs free elsewhere.  Call with [IS.empty] to get the free
   variables of [phi] alone. *)
let rec form_vars sf phi =
  match phi with
  | Forall (i, body) | Exists (i, body) ->
      let inner = form_vars sf body in
      if IS.mem i sf then inner else IS.remove i inner
  | Conj (l, r) | Disj (l, r) | Eqiv (l, r) | Impl (l, r) ->
      form_vars (form_vars sf r) l
  | Neg body -> form_vars sf body
  | Atom (_, args) -> List.fold_left term_vars sf args

(* Prepend the disjuncts of a nested [Disj] tree to [acc], in left-to-right
   order.  A non-[Disj] formula counts as a single disjunct. *)
let rec strip_disj acc phi =
  match phi with
  | Disj (l, r) ->
      let acc = strip_disj acc r in
      strip_disj acc l
  | _ -> phi :: acc
(* Prepend the conjuncts of a nested [Conj] tree to [acc], in left-to-right
   order.  A non-[Conj] formula counts as a single conjunct. *)
let rec strip_conj acc phi =
  match phi with
  | Conj (l, r) ->
      let acc = strip_conj acc r in
      strip_conj acc l
  | _ -> phi :: acc

(* Like [strip_disj], but the disjuncts end up in right-to-left order in
   front of [acc]. *)
let rec rev_strip_disj acc phi =
  match phi with
  | Disj (l, r) ->
      let acc = rev_strip_disj acc l in
      rev_strip_disj acc r
  | _ -> phi :: acc
(* Like [strip_conj], but the conjuncts end up in right-to-left order in
   front of [acc]. *)
let rec rev_strip_conj acc phi =
  match phi with
  | Conj (l, r) ->
      let acc = rev_strip_conj acc l in
      rev_strip_conj acc r
  | _ -> phi :: acc

(* Build the left-nested disjunction [Disj (Disj (f1, f2), f3)]... of a
   non-empty list.  Raises [Failure "disj_of_forms"] on the empty list. *)
let disj_of_forms forms =
  let rec build acc = function
    | [] -> acc
    | x :: rest -> build (Disj (acc, x)) rest
  in
  match forms with
  | [] -> failwith "disj_of_forms"
  | first :: rest -> build first rest

(* Build the left-nested conjunction [Conj (Conj (f1, f2), f3)]... of a
   non-empty list.  Raises [Failure "conj_of_forms"] on the empty list. *)
let conj_of_forms forms =
  let rec build acc = function
    | [] -> acc
    | x :: rest -> build (Conj (acc, x)) rest
  in
  match forms with
  | [] -> failwith "conj_of_forms"
  | first :: rest -> build first rest




(* One miniscoping pass:
   - [Forall] is distributed over [Conj];
   - [Exists] is distributed over [Disj];
   - a quantifier whose variable is not free in its body is dropped.
   Recurses through [Neg], [Conj] and [Disj].  [Atom], [Impl] and [Eqiv] are
   returned unchanged, without descending into them.  A single pass need not
   reach a fixpoint. *)
let rec miniscope phi =
  (* Rebuild the quantifier with [wrap] only if [x] occurs free in [body]. *)
  let requantify wrap x body =
    let body' = miniscope body in
    if IS.mem x (form_vars IS.empty body) then wrap body' else body'
  in
  match phi with
  | Forall (x, Conj (l, r)) ->
      Conj (miniscope (Forall (x, l)), miniscope (Forall (x, r)))
  | Exists (x, Disj (l, r)) ->
      Disj (miniscope (Exists (x, l)), miniscope (Exists (x, r)))
  | Forall (x, body) -> requantify (fun b -> Forall (x, b)) x body
  | Exists (x, body) -> requantify (fun b -> Exists (x, b)) x body
  | Neg p -> Neg (miniscope p)
  | Conj (l, r) -> Conj (miniscope l, miniscope r)
  | Disj (l, r) -> Disj (miniscope l, miniscope r)
  | Atom _ | Impl _ | Eqiv _ -> phi
(* [fix f x] applies [f] repeatedly, starting from [x], until an application
   returns a value structurally equal (polymorphic [=]) to its argument.  It
   then returns that argument.  Loops forever if no such point is reached. *)
let rec fix f x =
  let next = f x in
  if next <> x then fix f next else x
(* Miniscope to a fixpoint.  One pass may expose new work: distributing an
   inner quantifier over a conjunction can create a fresh [Forall (x, Conj _)]
   pattern for the enclosing quantifier.
   The definition is eta-expanded so that it is a syntactic function.  It
   therefore gets the fully polymorphic type ['v form -> 'v form] rather
   than the weakly polymorphic type the value restriction gives to the
   partial application [fix miniscope]. *)
let miniscope phi = fix miniscope phi


(* [rename_form min_conj phi] renumbers the bound variables of [phi]:
   - Binders are visited left to right, outer before inner, and get
     consecutive integers from 0.
   - Atoms are renamed with [rename_lit], which receives and returns the
     state (variable map, next number).
   - When [min_conj] is true, both sides of a [Conj] start from the same
     state and that state is discarded afterwards, so different conjuncts
     reuse the same numbers.
   - When [min_conj] is false, the state is threaded through everything.
   Leaving a binder's scope restores whatever the variable meant outside it.
   This stops a shadowing binder from capturing occurrences in sibling
   subformulas, e.g. the outer [x] in
   [Forall (x, Disj (Forall (x, p x), q x))].
   Raises [Failure "rename_form"] on [Impl] or [Eqiv]. *)
let rename_form min_conj =
  (* Undo the binding of [i] made for a quantifier: restore its mapping from
     [outer] if it had one, otherwise drop it.  Other entries of [inner],
     such as ones added by [rename_lit], are kept. *)
  let unbind i outer inner =
    if IM.mem i outer then IM.add i (IM.find i outer) inner
    else IM.remove i inner
  in
  let rec rename_form ((map, mv) as sf) = function
    | Atom (i, l) ->
        let sf, (i, l) = rename_lit sf (i, l) in
        (sf, Atom (i, l))
    | Neg t ->
        let sf, t = rename_form sf t in
        (sf, Neg t)
    | Disj (l, r) ->
        let sf, l = rename_form sf l in
        let sf, r = rename_form sf r in
        (sf, Disj (l, r))
    | Conj (l, r) ->
        if min_conj then begin
          let _, l = rename_form sf l in
          let _, r = rename_form sf r in
          (sf, Conj (l, r))
        end else begin
          let sf, l = rename_form sf l in
          let sf, r = rename_form sf r in
          (sf, Conj (l, r))
        end
    | Forall (i, t) ->
        let (inner, next), t = rename_form (IM.add i mv map, mv + 1) t in
        ((unbind i map inner, next), Forall (mv, t))
    | Exists (i, t) ->
        let (inner, next), t = rename_form (IM.add i mv map, mv + 1) t in
        ((unbind i map inner, next), Exists (mv, t))
    | Impl _ | Eqiv _ -> failwith "rename_form"
  in
  rename_form (IM.empty, 0) %> snd


(* Negation normal form.  [Neg] is pushed inward:
   - double negations are removed;
   - quantifiers are dualised;
   - [Conj] and [Disj] are swapped by De Morgan.
   [Atom], [Impl] and [Eqiv], negated or not, are returned as they are,
   without recursing into them.  Implications and equivalences therefore
   presumably have to be removed beforehand, e.g. with [unfold_equiv]. *)
let rec nnf phi =
  match phi with
  | Neg inner -> (
      match inner with
      | Neg t -> nnf t
      | Forall (i, t) -> Exists (i, nnf (Neg t))
      | Exists (i, t) -> Forall (i, nnf (Neg t))
      | Conj (l, r) -> Disj (nnf (Neg l), nnf (Neg r))
      | Disj (l, r) -> Conj (nnf (Neg l), nnf (Neg r))
      | Atom _ | Impl _ | Eqiv _ -> phi)
  | Forall (i, t) -> Forall (i, nnf t)
  | Exists (i, t) -> Exists (i, nnf t)
  | Conj (l, r) -> Conj (nnf l, nnf r)
  | Disj (l, r) -> Disj (nnf l, nnf r)
  | Atom _ | Impl _ | Eqiv _ -> phi

(* Drop every [Forall] binder reachable through [Forall], [Conj] and [Disj].
   Intended for skolemized NNF formulas.  Any other constructor, including
   [Neg], [Exists], [Impl] and [Eqiv], is returned untouched without
   descending into it. *)
let rec noforall phi =
  match phi with
  | Forall (_, body) -> noforall body
  | Conj (l, r) -> Conj (noforall l, noforall r)
  | Disj (l, r) -> Disj (noforall l, noforall r)
  | Atom _ | Neg _ | Exists _ | Impl _ | Eqiv _ -> phi

(* Bookkeeping for one existential eliminated by [skolem_form]. *)
type skolem_info =
{ formula : iform      (* formula passed to [skolem_name]: the existential
                          itself, or one inherited from an enclosing
                          existential it depends on *)
; variables : IS.t     (* variables the Skolem term is applied to *)
; relative : string    (* position passed to [skolem_name]; "e" for an
                          existential that depends on no other *)
; absolute : string    (* position of this existential from the root *)
}

(* [skolem_form const_num skolem_name] returns a function that skolemizes a
   formula.

   Accepted input: only [Forall], [Exists], [Conj], [Disj], [Atom] and
   [Neg (Atom _)].  Any other constructor raises
   [Invalid_argument "skolem"].

   Transformation: [Forall] binders are kept.  Each [Exists (i, t)] is
   removed, and every occurrence of variable [i] in the atoms of [t] is
   replaced by the Skolem term [A (c, args)], where
   - [c = const_num (skolem_name info.formula info.relative)];
   - [args] are the variables of [info.variables] as [V v], in [IS.to_list]
     order.

   Positions are strings over "0"/"1" starting from "" at the root.  "0" is
   appended for a quantifier body, the argument of [Neg] and the left side
   of a binary connective; "1" is appended for the right side.

   NOTE(review): an existential that depends on enclosing existentials is
   named after the formula of the outermost one it mentions.  Its position
   is measured from that existential's own [absolute] position, even when
   that existential itself inherited its [formula] from further out.
   Confirm that [skolem_name] cannot produce clashing names in that case. *)
let skolem_form const_num skolem_name =
  (* set of variables -> list of variable terms, in IS.to_list order *)
  let vars_of_set = IS.to_list %> List.map (fun v -> V v) in
  let sk_of_info info =
    A (const_num (skolem_name info.formula info.relative), vars_of_set info.variables) in
  (* Substitute the Skolem term for each variable of an eliminated
     existential recorded in [ctxt]; other variables stay [V i]. *)
  let skolem_tm ctxt =
    map_term_vars (fun i -> try snd (List.assoc i ctxt) with Not_found -> V i) in

  (* [ctxt]: existentials eliminated on the path from the root, innermost
     first (new entries are consed on the front), each as
     (variable, (info, Skolem term)).  [pos]: position of the current
     subformula. *)
  let rec skolem ctxt pos = function
      Forall (i, t) -> Forall (i, skolem ctxt (pos ^ "0") t)
    | Conj (l, r) -> Conj (skolem ctxt (pos ^ "0") l, skolem ctxt (pos ^ "1") r)
    | Disj (l, r) -> Disj (skolem ctxt (pos ^ "0") l, skolem ctxt (pos ^ "1") r)
    | Exists (i, t) as ex -> begin
        (* free variables of the existential, excluding [i] itself *)
        let vs = IS.remove i (form_vars IS.empty t) in
        (* No enclosing existential mentioned: named after [ex] itself with
           relative position "e".  Otherwise the Skolem term takes the free
           non-existential variables plus the variables of every enclosing
           existential it mentions.  It is named after the formula of the
           outermost of those (last of [l], since [ctxt] is innermost first),
           relative to that existential's absolute position. *)
        let info = match List.filter (fun (v, _) -> IS.mem v vs) ctxt with
          [] -> {formula = ex; variables = vs; relative = "e"; absolute = pos}
        | l ->
          let univ = List.fold_right (fun (v, _) -> IS.remove v) ctxt vs in
          let deps = List.fold_right (fun (_, (ski, _)) -> IS.union ski.variables) l univ in
          let (_, (upper, _)) = List.last l in
          let rel_pos = String.lchop ~n:(String.length upper.absolute) pos in
          {formula = upper.formula; variables = deps; relative = rel_pos; absolute = pos} in
        skolem ((i, (info, sk_of_info info)) :: ctxt) (pos ^ "0") t
        end
    | Atom (p, tm) -> Atom (p, List.map (skolem_tm ctxt) tm)
    | Neg (Atom _ as a) -> Neg (skolem ctxt (pos ^ "0") a)
    | _ -> invalid_arg "skolem" in
  skolem [] ""


(* [collect_fp (funs, preds) phi] visits the atoms of [phi] left to right.
   - Function symbols of the atom arguments are accumulated into [funs] via
     [term_funs].
   - Every predicate applied to at least one argument is recorded in [preds]
     as [abs p] mapped to its arity.  Nullary predicates are not recorded.
   Quantifiers, negations and all connectives, including [Impl] and [Eqiv],
   are traversed. *)
let rec collect_fp ((funs, preds) as acc) phi =
  match phi with
  | Neg body | Forall (_, body) | Exists (_, body) -> collect_fp acc body
  | Conj (l, r) | Disj (l, r) | Eqiv (l, r) | Impl (l, r) ->
      collect_fp (collect_fp acc l) r
  | Atom (sym, args) ->
      let funs = List.fold_left term_funs funs args in
      let preds =
        match args with
        | [] -> preds
        | _ :: _ -> IM.add (abs sym) (List.length args) preds
      in
      (funs, preds)
