open Batteries
open BatFixes
open Option.Infix

open Arglean.Female
open Arglean.Monte
open Arglean.Monte.Rollout
open Arglean.Monte.Evaluation
open Arglean.Trace
open Arglean.Lean
open Common
open Database.Clausal.ClassifierDb
open Features
open Mapping
open Term
open Proof.Clausal.Step

(* Substitution module used throughout this file: the
   [Substitution.Substoff] functor applied to [Substitution.Substlist].
   NOTE(review): judging by the names, presumably a list-based substitution
   extended with a variable offset used when renaming clauses apart (cf. the
   [Subst.t * int] stored in [pstate.sub]) — confirm in Substitution. *)
module Subst = Substitution.Substoff (Substitution.Substlist)


(* Global statistics counters. *)

(* Incremented on every call to [do_task]. *)
let uct_sim_steps = ref 0
(* Incremented in [prove_cls] each time the unification of an extension
   step succeeds. *)
let uct_infer = ref 0
(* Literal lookups attempted / successful in [lit_refutability]. *)
let ml_lit_tried = ref 0
let ml_lit_found = ref 0


(* Per-task proof-search context.
   - [path]: literals on the current branch. [prove_cls] uses them for the
     regularity check and as partners for reduction ([Pat]) steps.
   - [features]: feature weights. On extension they are replaced by
     [update_features ...] when [!upd_fea] is set.
   - [lemmas]: literals that can close a goal via a [Lem] step. After an
     extension of a literal, it is added here for the remaining literals of
     its clause.
   - [extensions]: number of extension steps taken along this branch.
   - [extension_cps]: hashes of the clauses used by those extension steps,
     most recent first. *)
type context = {
  path : lit list
; features : float FM.t
; lemmas : lit list
; extensions : int
; extension_cps : int list
}

(* A clause, represented as the list of its literals. *)
type clause = lit list

(* Proof-search state.
   - [sub]: current substitution. NOTE(review): the [int] is presumably the
     fresh-variable offset used by [Subst.unify_rename] — confirm.
   - [proof]: proof steps taken so far, most recent first.
   - [tasks]: stack of pending goals. Each goal is a context plus a
     (clause, clause hash) pair whose literals remain to be proven.
   - [opened]: number of literal goals opened so far. It is 1 for the initial
     goal, and each extension adds the length of the new clause remainder. *)
type pstate = {
  sub : Subst.t * int
; proof : proof_step list
; tasks : (context * (clause * int)) list
; opened : int
}

(* Initial proof state for proving [lit] under substitution [sub].
   The 4-tuple supplies the starting path, features and lemmas; its last
   component is accepted but not used here. The single initial task is the
   one-literal clause [[lit]], tagged with [Database.Clausal.contr_number lit []]. *)
let mc_pstate sub (path, features, lemmas, _lim) lit =
  let start_ctx =
    { path; features; lemmas; extensions = 0; extension_cps = [] } in
  let goal_id = Database.Clausal.contr_number lit [] in
  { sub; proof = []; tasks = [ (start_ctx, ([lit], goal_id)) ]; opened = 1 }

(* If exactly one proof step has been taken and it is a resolution ([Res])
   step, return the substitution together with the features and clause of
   the first pending task. Otherwise return [None]. *)
let proof_ext { sub; proof; tasks; _ } =
  match proof with
  | [ Res _ ] ->
    (match tasks with
     | (ctxt, cls) :: _ -> Some (sub, (ctxt.features, cls))
     | [] -> None)
  | _ -> None

(* Render the most recent proof step stored in a UCT node's state, or
   "None" when no step has been taken yet. *)
let show_last_proof_step t =
  let open Uct in
  match t.state.proof with
  | latest :: _ -> show_proof_step_vbar latest
  | [] -> "None"

(* Fold [f] along the leftmost path of a UCT tree: the root, then its first
   child, then that node's first child, and so on down to a leaf. *)
let fold_first f acc t =
  let open Uct in
  let rec descend acc node =
    let acc = f acc node in
    match node.children with
    | first :: _ -> descend acc first
    | [] -> acc
  in
  descend acc t

(* Walk the leftmost path of [t] (see [fold_first]). Add the number of nodes
   visited and the sum of their average rewards to the (count, sum)
   accumulator [acc]. The recursion lives in [fold_first], so this binding
   is not [rec]. *)
let first_children acc t =
  fold_first (fun (n, r) node -> (n + 1, r +. Uct.avg_reward node)) acc t

(* For every node on the leftmost path of [t0], add the number of its
   children and the sum of their average rewards to the (count, sum)
   accumulator [acc0]. The recursion lives in [fold_first], so this binding
   is not [rec]. *)
let branch_children acc0 t0 =
  let open Uct in
  fold_first
    (fun acc node ->
      List.fold_left (fun (n, r) c -> (n + 1, r +. avg_reward c)) acc node.children)
    acc0 t0

(* Count every node of the UCT tree [t] and sum their average rewards,
   starting from the (count, sum) accumulator [acc]. Nodes are visited in
   pre-order, left to right, using an explicit worklist. *)
let all_children acc t =
  let open Uct in
  let rec go (n, r) = function
    | [] -> (n, r)
    | node :: rest -> go (n + 1, r +. avg_reward node) (node.children @ rest)
  in
  go acc [t]

(* Ratios comparing the mean average reward on the leftmost path of [t]
   against (a) the mean over the children of the nodes on that path and
   (b) the mean over every node of the tree. *)
let discrimination t =
  let mean_of (count, sum) = sum /. float_of_int count in
  let leftmost = mean_of (first_children (0, 0.) t) in
  let branching = mean_of (branch_children (0, 0.) t) in
  let everything = mean_of (all_children (0, 0.) t) in
  (leftmost /. branching, leftmost /. everything)


(* extension selection heuristic *)

(* Linearly rescale the scores of a list of (item, score) pairs so that the
   smallest score maps to [new_min] and the largest to [new_max]. If all
   scores are equal, every item receives [tiebreaker]. The empty list is
   returned as is. *)
let normalise_between (new_min, tiebreaker, new_max) scored =
  match scored with
  | [] -> []
  | _ ->
    let by_score (_, a) (_, b) = compare a b in
    let (_, lo), (_, hi) = List.min_max ~cmp:by_score scored in
    let rescale =
      if lo = hi then fun _ -> tiebreaker
      else
        let factor = (new_max -. new_min) /. (hi -. lo) in
        fun y -> new_min +. (y -. lo) *. factor
    in
    List.map (fun (item, score) -> (item, rescale score)) scored

(* Linear combination: the sum of weight * value over (weight, lazy value)
   pairs. A term with weight 0 is skipped without forcing its lazy value, so
   disabled heuristics are never computed. *)
let lincomb terms =
  let add_term total (weight, value) =
    if weight = 0. then total else total +. weight *. Lazy.force value
  in
  List.fold_left add_term 0. terms

(* Monotonically compress a float in [0, inf) into [lower, upper).
   An input of 0 maps exactly to [lower], and the result approaches [upper]
   as the input grows. The range is captured by the partial application, and
   the returned closure performs the mapping. *)
let compress_range (lower, upper) =
  let width = upper -. lower in
  let shift = 1. /. width in
  fun x ->
    let denom = x +. shift in
    upper -. (1. /. denom)

(* Mirror image of [compress_range] for non-positive inputs: negate the
   input, compress it into [range], then negate the result. *)
let neg_normalise range =
  let compress = compress_range range in
  fun x -> Float.neg (compress (Float.neg x))

(* Unnormalised transition probability of an extension step.
   It is a [lincomb] of three heuristics weighted by [!invprob], [!constprob]
   and [!bayesprob]:
   - inverse clause length, 1 / (1 + [len]);
   - the constant 1;
   - a Bayesian score built from the label relevance [lbl_rel] (relative to
     the maximum [max_lbl_rel] over all candidates) and the feature relevance
     [ftr_rel]. When [max_lbl_rel] is [None], exp (-2) is used instead.
   All relevance inputs are lazy, so a heuristic whose weight is 0 is never
   evaluated ([lincomb] skips it). *)
let trans_prob len max_lbl_rel min_ftr_interrel (lbl_rel, ftr_rel) =
  lincomb
  [ !invprob, lazy (1. /. float_of_int (1 + len))
  ; !constprob, lazy 1.
  ; !bayesprob, lazy (
      match Lazy.force max_lbl_rel with
          None -> exp (-2.)
        | Some mlr ->
            (* The second component of the forced [ftr_rel] is
               (inter_prob, n_inter, n_disj). The local [ftr_rel] below
               shadows the parameter and equals
               (inter_prob + m * n_disj) / (n_inter + n_disj), where m is
               [min_ftr_interrel] (0 if absent).
               NOTE(review): divides by zero if n_inter + n_disj = 0. *)
            let inter_prob, n_inter, n_disj = Lazy.force ftr_rel |> snd in
            let ftr_rel = (inter_prob +. (Lazy.force min_ftr_interrel |> Option.default 0.) *. float_of_int n_disj) /.
              (float_of_int (n_inter + n_disj)) in
            if !verbose then Format.printf "lbl_rel = %f, mlr = %f, ftr_rel = %f\n%!" (Lazy.force lbl_rel) mlr ftr_rel;
            (* Log-score: label relevance relative to the best candidate plus
               feature relevance. It is squashed by [neg_normalise (0., 2.)]
               and exponentiated. *)
            let sum = Lazy.force lbl_rel -. mlr +. ftr_rel in
            let normed = neg_normalise (0., 2.) sum in
            if !verbose then Format.printf "Raw/normed logarithmic Bayes: %2.2f\t%2.2f\n%!" sum normed;
            exp normed
            |> (if !verbose then tap (Format.printf "Bayes transition probability: %f\n%!") else identity)
    )
  ]

(* Divide [score] by (k + 1) ** [!prevdampen], where k is the number of
   elements of [exts] equal to [hsh]. A damping exponent of 0 returns [score]
   unchanged without scanning [exts]. *)
let prev_occ_dampen (exts, hsh) score =
  let exponent = !prevdampen in
  if exponent = 0. then score
  else
    let seen = List.fold_left (fun k h -> if h = hsh then k + 1 else k) 0 exts in
    score /. (float_of_int (seen + 1) ** exponent)

(* Force the lazy probabilities of (lazy state option, lazy probability)
   pairs.
   - With [true], each state is also forced eagerly. Pairs whose state is
     [None] are dropped, and surviving states are re-wrapped as lazies.
   - With [false], states are left unevaluated. *)
let eval_state_prob prefilter candidates =
  if prefilter then
    List.filter_map
      (fun (state, prob) ->
        match Lazy.force state with
        | None -> None
        | Some s -> Some (lazy (Some s), Lazy.force prob))
      candidates
  else List.map (fun (state, prob) -> (state, Lazy.force prob)) candidates

(* Left fold seeded with the first element. Returns [None] on the empty
   list, otherwise [Some] of the folded result. *)
let foldl1 f xs =
  let rec go acc = function
    | [] -> acc
    | y :: ys -> go (f acc y) ys
  in
  match xs with
  | [] -> None
  | seed :: rest -> Some (go seed rest)

(* prover loop *)

(* Successor generation for the UCT search.

   [do_task start state] pops the next pending task of [state] and returns
   the candidate successors. Each candidate is a pair
   (lazy successor state option, transition weight).
   - Tasks whose clause is empty are dropped.
   - When no task remains, the proof is finished ("Proof found!"). The
     result is [] if this is the initial call ([start] is true, i.e. [state]
     itself had no tasks). Otherwise it is that finished state as a single
     successor with weight 1.
   Every call increments [uct_sim_steps]. *)
let rec do_task start state = incr uct_sim_steps;
  match state.tasks with
    (* Only the printf is guarded by [!verbose]; the result after [;] is
       always computed. *)
    [] -> if !verbose then Format.printf "Proof found!\n"; if start then [] else [lazy (Some state), 1.]
  | (_, ([], _)) :: rest -> do_task false {state with tasks = rest}
  | task :: rest -> prove_cls {state with tasks = rest} task

(* [prove_cls state (ctx, (cl, hsh))] expands the first literal [lit] of the
   clause [cl], whose remaining literals are [lits]. [state] no longer
   contains this task. The candidates are:
   - none at all, if some literal of [cl] is equal under [sub] to a literal
     on [ctx.path] (regularity check);
   - a single [Lem] step, if [lit] is equal under [sub] to a lemma;
   - otherwise one reduction ([Pat]) per path literal, with weight 1, that
     succeeds when [negate_lit lit] unifies with it. These are followed by
     one extension ([Res]) per entry of [db_entries sub neglit], weighted by
     [trans_prob] and [prev_occ_dampen]. With [!prefilter], extensions whose
     unification fails are removed by [eval_state_prob]. *)
and prove_cls ({sub; proof; tasks} as state) (ctx, (cl, hsh)) = match cl with
    (* Not reached via [do_task], which drops tasks with an empty clause. *)
    [] -> [lazy (Some state), 1.]
  | lit :: lits ->
    if List.exists (fun x -> List.exists (Subst.eq sub x) ctx.path) cl then []
    else if List.exists (Subst.eq sub lit) ctx.lemmas then
      [lazy (Some {state with proof = Lem lit :: proof; tasks = (ctx, (lits, hsh)) :: tasks}), 1.]
    else begin
      (* Features for the extended branch. They are only recomputed when
         [!upd_fea] is set, and lazily, so unused features cost nothing. *)
      let neglit = negate_lit lit
      and nfea = lazy
        (if !upd_fea then update_features (fst sub) lit ctx.features else ctx.features) in
      if !verbose then FM.iter (fun ftr w -> Format.printf "Feature: %s %f\n%!" (Print.string_of_lit ftr) w) (Lazy.force nfea);
      (*let subst_vars sub' = float_of_int (List.length (fst sub') - List.length (fst sub)) in*)
      (* Reduction steps: close [lit] against a complementary path literal. *)
      let reductions = List.map
        (fun p ->
          lazy ((Subst.unify sub neglit p >>=
          (fun sub1 ->
            Some ({state with sub = sub1; proof = Pat lit :: proof;
              tasks = (ctx, (lits, hsh)) :: tasks})))), 1.)
        ctx.path in
      (* For each database entry: lazy label relevance and lazy feature
         relevance, both computed from the entry frequencies. *)
      let dbs = db_entries sub neglit in
      let bayesian = List.map (fun (contra, freqs) ->
        (contra, (lazy (lbl_relevance freqs),
                 lazy (ftr_relevance (Lazy.force nfea) freqs)))) dbs in
      (* Extremes over all candidates, [None] when there are no entries.
         They are lazy, so they are only computed if [trans_prob] needs the
         Bayesian term. *)
      let max_lbl_rel = lazy (
        List.map (fun (_, (lr, _)) -> Lazy.force lr) bayesian |> foldl1 max) in
      let min_ftr_interrel = lazy (
        List.map (fun (_, (_, fr)) -> Lazy.force fr |> fst) bayesian |> foldl1 min) in
      (* NOTE(review): [vars] is unused here. *)
      let extensions = List.map
        (fun ((_, rest, vars, hsh1) as contra, bayes) ->
          let len = List.length rest in
          (* NOTE(review): [prev_occ_dampen] is given [hsh], the hash of the
             clause being expanded, so every extension of [lit] gets the same
             damping factor. If the intent is to penalise re-using the
             candidate clause, [hsh1] may have been meant — confirm. *)
          let prob = lazy (
            trans_prob len max_lbl_rel min_ftr_interrel bayes
            |> prev_occ_dampen (ctx.extension_cps, hsh)
            |> (if !verbose then tap (Format.printf "Transition probability: %f\n%!") else identity)) in
          lazy ((Subst.unify_rename sub (snd lit) contra >>=
          (fun (sub1, cla1) ->
            incr uct_infer;
            (* [tsk1] proves the new clause [cla1] with [lit] pushed onto the
               path. [tsk2] continues with the remaining literals, with [lit]
               now available as a lemma. *)
            let step = Res (lit, ctx.path, ctx.lemmas, hsh1)
            and tsk1 = {ctx with path = lit :: ctx.path; features = Lazy.force nfea; extensions = ctx.extensions + 1; extension_cps = hsh1 :: ctx.extension_cps}, (cla1, hsh1)
            and tsk2 = {ctx with lemmas = lit :: ctx.lemmas}, (lits, hsh) in
            Some ({sub = sub1; proof = step :: proof; tasks = tsk1 :: tsk2 :: tasks;
              opened = state.opened + len})))), prob

        ) bayesian
      |> eval_state_prob !prefilter in
      reductions @ extensions
    end

(* Product of a list of floats; 1. for the empty list. *)
let fproduct xs = List.fold_left ( *. ) 1. xs
(* Harmonic mean: the number of elements divided by the sum of their
   reciprocals. *)
let harmean l =
  let reciprocals = List.map (fun x -> 1. /. x) l in
  float_of_int (List.length l) /. List.fsum reciprocals
(* Geometric mean: the product of the elements raised to the power 1/n. *)
let geomean l =
  let n = float_of_int (List.length l) in
  fproduct l ** (1. /. n)
(* Arithmetic mean: the sum of the elements divided by their number. *)
let arimean l =
  let n = List.length l in
  List.fsum l /. float_of_int n

(* Select the function that aggregates a list of floats for the given
   combination setting. *)
let get_comb comb =
  match comb with
  | Min -> List.min
  | Arithmetic -> arimean
  | Geometric -> geomean
  | Harmonic -> harmean
  | Product -> fproduct

(* Refutability score of [lit] from per-literal statistics [stats].
   [stats] is an association list from [Database.lit_hash] values to
   (pos, neg) counts.
   - For a known literal, the score is 1 - c * neg / (pos + neg), where
     c = [compress_range (0., !cert)] applied to (pos + neg) / [!certdamp].
   - For an unknown literal, the score is [!defscore].
   Increments [ml_lit_tried] on every call and [ml_lit_found] on a hit. *)
let lit_refutability stats lit =
  let lit_hsh = Database.lit_hash lit in
  incr ml_lit_tried;
  match List.assoc lit_hsh stats with
  | exception Not_found ->
    if !verbose then Format.printf "Not found literal %d (%s)\n" lit_hsh (Print.string_of_lit lit);
    !defscore
  | (pos, neg) ->
    if !verbose then Format.printf "Found %d: %d %d\n" lit_hsh pos neg;
    incr ml_lit_found;
    let total = float_of_int (pos + neg) in
    let certainty = compress_range (0., !cert) (total /. !certdamp) in
    let score = 1. -. certainty *. (float_of_int neg /. total) in
    if !verbose then Format.printf "Literal refutability: %f\n%!" score;
    score

(* Refutability score of a pending task.
   - An empty clause scores 1.
   - Otherwise the [lit_refutability] scores of its literals are combined
     with [get_comb !mlmean], using the statistics stored under the clause
     hash in [Litdata.cl_in].
   - A clause with no statistics scores [!defscore]. *)
let subgoal_refutability (_, (lits, cl_hsh)) =
  match lits with
  | [] -> 1.
  | _ :: _ ->
    (try
       let stats = Hashtbl.find Litdata.cl_in cl_hsh in
       if !verbose then Format.printf "Clause %d found\n" cl_hsh;
       get_comb !mlmean (List.map (lit_refutability stats) lits)
     with Not_found ->
       if !verbose then Format.printf "Not found clause %d\n" cl_hsh;
       !defscore)

(* Number of function-symbol occurrences in a term. Variables count 0. *)
let rec term_size = function
  | V _ -> 0
  | A (_, args) -> List.fold_left (fun total arg -> total + term_size arg) 1 args

(* Size of a literal: [term_size] of its atom rebuilt as a term, i.e. 1 for
   the predicate symbol plus the sizes of its arguments. *)
let lit_size lit =
  let pred, args = lit in
  term_size (A (pred, args))

(* Geometric mean of [f x] over the elements [x] of [l]. It is computed in a
   single pass that tracks the running product and the element count. *)
let geomeanf f l =
  let step (product, count) x = (product *. f x, count + 1) in
  let product, count = List.fold_left step (1., 0) l in
  product ** (1. /. float_of_int count)

(* Size-based reward of a proof state: the geometric mean, over the pending
   tasks, of a per-task score. The per-task score is the geometric mean of
   1/size over the task's literals after instantiation with [sub]. It is
   then multiplied by [compress_range (!depth_att_max, 1.)] of the task's
   extension count. *)
let size_reward {sub; tasks; _} =
  let attenuation = compress_range (!depth_att_max, 1.) in
  let inv_size lit =
    1. /. float_of_int (lit_size (Subst.inst_lit (fst sub) lit)) in
  let task_score ({extensions; _}, (cl, _)) =
    geomeanf inv_size cl *. attenuation (float_of_int extensions) in
  geomean (List.map task_score tasks)

(* Reward of the final state of a simulation.
   A state with no pending tasks is a completed proof: reward 1.0 together
   with its substitution and proof. Otherwise the reward is a [lincomb]
   weighted by [!sizereward], [!naivereward], [!randreward] and [!mlreward]
   of, respectively:
   - [size_reward];
   - 1 minus (proof steps taken / goals opened);
   - a uniform random number;
   - the [subgoal_refutability] of the pending tasks, combined with
     [get_comb !mlmean]. *)
let cop_reward state =
  match state with
  | { sub; proof; tasks = []; _ } -> (1.0, Some (sub, proof))
  | { proof; opened; tasks; _ } ->
    let steps = List.length proof in
    if !verbose then begin
      Format.printf "Proof length: %d\n" steps;
      Format.printf "Remaining tasks: %d\n" (List.length tasks)
    end;
    let naive = lazy (1. -. float_of_int steps /. float_of_int opened) in
    let reward =
      lincomb
        [ (!sizereward, lazy (size_reward state))
        ; (!naivereward, naive)
        ; (!randreward, lazy (Random.float 1.))
        ; (!mlreward, lazy (get_comb !mlmean (List.map subgoal_refutability tasks)))
        ]
    in
    if !verbose then Format.printf "Total reward: %f\n" reward;
    (reward, None)

(*let rec shorter_elems f shortest l =
  let fs = f shortest in
  match List.drop_while (fun x -> f x >= fs) l with
    shorter :: rest -> shorter :: shorter_elems f shorter rest
  | [] -> []
*)

(* Comparison that orders values by the key [f] extracts from them, using the
   polymorphic [compare] on the keys. *)
let by f a b =
  let key_b = f b in
  let key_a = f a in
  compare key_a key_b

(* Size measure of a task list, used when choosing where to cut a
   simulation: either the number of pending tasks, or the total number of
   literals in their clauses. *)
let cut_metric metric =
  match metric with
  | On_clauses -> List.length
  | On_literals ->
    List.fold_left (fun total (_, (lits, _)) -> total + List.length lits) 0

(* Cut-based expansion. Pick the state of the simulation [sim] whose pending
   tasks are smallest under [cut_metric cut] (the minimum reported by
   [List.min_max]). Pass it to [Uct.state_expansion] together with the last
   state of [sim]. *)
let cut_expansion_policy cut p sim =
  let final = List.last sim in
  let size s = cut_metric cut s.tasks in
  let smallest, _ = List.min_max ~cmp:(by size) sim in
  Uct.state_expansion p smallest final

(* Map an expansion-policy setting to its implementation. *)
let expansion_fun policy =
  match policy with
  | First -> Uct.single_expansion_policy
  | Cut metric -> cut_expansion_policy metric

(* Apply the expansion policy [pol] for problem [p] to [simulation]. When
   [!verbose] is set, first print the number of pending tasks at each
   simulation step. *)
let expansion_policy p pol simulation =
  let open Uct in
  let report step s =
    Format.printf "Simulation step %d: %d tasks left\n" step (List.length s.tasks) in
  if !verbose then List.iteri report simulation;
  expansion_fun pol p simulation

(* The constant pi, computed as 4 * atan 1. *)
let pi = atan 1. *. 4.
(* Exploration constant for iteration [i]: [!exploration] plus a sine wave
   of amplitude [!expamp] with a period of [!expper] iterations. *)
let explore i =
  let phase = 2. *. pi *. float_of_int i /. !expper in
  !exploration +. !expamp *. sin phase

(* UCT configuration for this prover:
   - successors come from [do_task true] and rewards from [cop_reward];
   - simulations use [default_simulation_policy p lim];
   - expansion uses [expansion_policy p !exppol] ([!exppol] is read once,
     when the configuration is built).

   [iters] counts calls to the expansion policy and drives the oscillating
   exploration constant [explore !iters]. Both the [explore !iters] read and
   the [incr iters] live inside closures so they run on every call.
   Evaluated directly in the record-field initialisers, they would run only
   once, when the record is built. That would pin exploration at
   [explore 0] and stop [iters] from counting. *)
let montecop_uct lim = let open Uct in
  let p = { successors = do_task true; reward = cop_reward} in
  let iters = ref 0 in
  let expand = expansion_policy p !exppol in
  { problem = p
  ; tree_policy = (fun t -> uct_tree_policy (explore !iters) t)
  ; simulation_policy = default_simulation_policy p lim
  ; expansion_policy = (fun sim -> incr iters; expand sim)
  }
