open Batteries
open BatFixes
open LazyList
open Option.Infix

open Arglean.Trace
open Arglean.Lean
open Arglean.Monte
open Arglean.Monte.Evaluation
open Arglean.Oracle
open Arglean.Female
open Cnf
open Common
open Database.Clausal.ClassifierDb
open Features
open Mapping
open Term
open Print
open Proof.Clausal.Proof

(* Substitution module used throughout this file: the Substoff functor applied
   to Substitution.Substlist. The substitution state threaded through the
   prover is a pair of a Subst value and an integer, e.g. (Subst.empty 0, 0)
   in start; NOTE(review): the integer looks like a fresh-variable offset,
   confirm against the Substitution module. *)
module Subst = Substitution.Substoff (Substitution.Substlist)

(* Global counters accumulated over the whole run and reported by print_stats:
   the number of UCT iterations performed, the component-wise sum of the
   discrimination pairs of proof-carrying trees, and the number of proofs
   found by the Monte Carlo search. *)
let uct_iterations : int ref = ref 0
let uct_discr_sum : (float * float) ref = ref (0.0, 0.0)
let uct_proofs : int ref = ref 0

(* [comb2 f p q] combines two pairs component-wise with the binary function
   [f]: the first components with each other, and the second components with
   each other. *)
let comb2 f p q =
  let (a1, b1) = p and (a2, b2) = q in
  (f a1 a2, f b1 b2)

(* Builds the initial MCTS tree for state [st] of the UCT problem [uct],
   keeping only the embryos for which Montelib.proof_ext returns Some.
   In verbose mode it reports the number of remaining embryos next to the
   initial number; this forces both lazy lists. *)
let mc_tree uct st =
  let open Uct in
  let root = empty_tree uct.problem st in
  let has_ext embryo = Option.is_some (Montelib.proof_ext embryo) in
  let pruned = { root with embryos = LazyList.filter has_ext root.embryos } in
  if !verbose then
    Format.printf "Embryos: %d (of initially %d)\n%!"
      (LazyList.length pruned.embryos) (LazyList.length root.embryos);
  pruned

(* Returns the candidate states of the final tree in the UCT iteration list
   [iters], as a lazy list. This forces [iters] to its last element. The
   states of the expanded children come first, sorted with by_reward, and
   are followed by the not yet expanded embryos. *)
let mc_get_children iters =
  let open Uct in
  let (final_tree, _) = LazyList.last iters in
  let sorted_states =
    List.map (fun child -> child.state) (List.sort by_reward final_tree.children)
  in
  LazyList.of_list sorted_states ^@^ final_tree.embryos

(* Lazy list of the proof extensions offered by the candidate states of the
   final tree in [iters]; states without an extension are dropped. The outer
   [lazy] delays the call to mc_get_children, which forces [iters] to its last
   element, until this list is first inspected. *)
let mc_get_exts iters =
  lazy
    (LazyList.next
       (LazyList.filter_map Montelib.proof_ext (mc_get_children iters)))

(* Inspects one UCT iteration result (tree, (reward, proof option)).
   When a proof is present, it does the following and returns the
   substitution together with the proof rebuilt from the reversed step list:
   - counts its resolution steps and adds them to Stats.infer;
   - accumulates the tree discrimination into uct_discr_sum;
   - increments uct_proofs.
   Returns None when the result carries no proof. *)
let mc_get_proof (tree, (_, found)) =
  match found with
  | None -> None
  | Some (sub2, steps) ->
    let n_exts =
      List.length (List.filter Proof.Clausal.Step.is_res_step steps) in
    if !verbose then
      Format.printf "MC proof found with %d extension steps\n%!" n_exts;
    let (d_fst, d_snd) as discr = Montelib.discrimination tree in
    if !verbose then Format.printf "Discrimination: %f/%f\n%!" d_fst d_snd;
    uct_discr_sum := comb2 Float.add !uct_discr_sum discr;
    incr uct_proofs;
    Stats.infer := !Stats.infer + n_exts;
    let (_, (_, proof)) = Proof.Clausal.Proof.of_steps (List.rev steps) in
    Some (sub2, proof)

(* [quote s] wraps [s] in double-quote characters so it can be used as a DOT
   attribute value. Embedded double-quote characters are escaped with a
   backslash, which is the only escape DOT itself interprets inside quoted
   strings. Without this, such a character would end the string early and
   corrupt the output. All other characters are copied unchanged, including
   existing backslashes such as the backslash-bar sequences used in record
   labels. *)
let quote s =
  let buf = Buffer.create (String.length s + 2) in
  Buffer.add_char buf '"';
  String.iter
    (fun c ->
      if c = '"' then Buffer.add_string buf "\\\"" else Buffer.add_char buf c)
    s;
  Buffer.add_char buf '"';
  Buffer.contents buf
(* [embrace s] surrounds [s] with curly braces, as used for DOT record
   labels. *)
let embrace s = String.concat "" [ "{"; s; "}" ]

(* Label/value pairs describing the MCTS node [t] for a DOT record label:
   - the average reward, with two decimals;
   - the visit count;
   - the number of embryos, which forces t.embryos;
   - the last proof step as shown by Montelib, wrapped in braces.
   The bars in the embryo-count label are backslash-escaped because a bare
   bar separates fields in a DOT record label. *)
let fields_of_tree t =
  let open Uct in
  let last_step = embrace (Montelib.show_last_proof_step t) in
  let n_embryos = string_of_int (LazyList.length t.embryos) in
  let n_visits = string_of_int t.visits in
  let mean_reward = Format.sprintf "%.2f" (avg_reward t) in
  [ ("r/n", mean_reward);
    ("n", n_visits);
    ("\\|S\\|", n_embryos);
    ("ds", last_step) ]

(* Renders label/value pairs as DOT record fields. Each pair is written as
   label and value separated by a bar and enclosed in braces, and the
   rendered pairs are joined with bars. An empty list gives the empty
   string. *)
let string_of_fields f =
  f
  |> List.map (fun (label, value) -> Printf.sprintf "{%s|%s}" label value)
  |> String.concat "|"

(* Debugging aid: writes the final tree of the UCT iteration list [uct] to
   the file [out] in DOT format, then terminates the process with exit
   code 0. Forcing the last element of [uct] runs the whole search.
   Expanded nodes are drawn as records built from fields_of_tree; embryo
   states are labelled Embryo. The first argument is unused and is kept
   only for call-site compatibility. *)
let save_dot _prob out uct = let open Uct in let open Dot in
  let node_of_state _ = ["label", quote "Embryo"] in
  let node_of_tree t =
    ["shape", "record";
     "label", fields_of_tree t |> string_of_fields |> embrace |> quote] in
  let dot = LazyList.last uct |> fst |> Uct.dot_tree (node_of_tree, node_of_state) in
  File.with_file_out out (fun c -> Dot.print_dot c dot);
  exit 0


(* [prove_lit sub (path, fea, lem, lim) lit] lazily enumerates the ways of
   closing the literal [lit] under the substitution state [sub].
   The arguments are:
   - [path]: the lazy list of path literals;
   - [fea]: the current features;
   - [lem]: the lemma list;
   - [lim]: the remaining depth limit.
   Each result is a triple (substitution, lemma list, proof). Candidates are:
   - Lemma: [lit] matches an element of [lem] according to Subst.eq;
   - Reduction: the negation of [lit] unifies with a literal on [path];
   - Extension: [lit] is resolved against a contrapositive and the remaining
     literals are proved with prove_clause, at limit lim - 1.
   Extension candidates come either from the clause database
   (extensions_trad) or from a Monte Carlo tree search (extensions_monte).
   The Monte Carlo variant may also return complete proofs found during the
   search. The three groups are combined with cut, controlled by !cut1,
   !cut2 and !cut3. NOTE(review): presumably leanCoP-style restricted
   backtracking; confirm against the definition of cut. *)
let rec prove_lit sub (path, fea, lem, lim) lit =
  if !verbose then Format.printf "Lit: %s\n%!" (string_of_lit lit);
  if !verbose then Format.printf "Path: %s\n%!" (string_of_lits (LazyList.to_list path));
  if !verbose then Format.printf "Lemmas: %s\n%!" (string_of_lits lem);
  let neglit = negate_lit lit
  (* Features updated with [lit], computed lazily and only when !upd_fea is set. *)
  and nfea = lazy (if !upd_fea then update_features (fst sub) lit fea else fea) in
  (* At most one lemma step, with unchanged substitution and lemmas. *)
  let lemmas =
    if List.exists (Subst.eq sub lit) lem then (if !verbose then Format.printf "lemma\n%!"; cons (sub, lem, Lemma) nil) else nil
  (* One reduction per path literal whose unification with [neglit] succeeds. *)
  and reductions = LazyList.filter_map
    (fun p -> if !verbose then Format.printf "Reduction try %s\n%!" (string_of_lit p);
      Subst.unify sub neglit p >>=
      (fun sub1 -> if !verbose then Format.printf "Reduction works\n%!";
        Some (sub1, lem, Reduction))
    ) path
  (* Monte Carlo extensions: run UCT on a problem rooted at [lit]. *)
  and extensions_monte = lazy (
    let prob = Montelib.montecop_uct !simdepth in
    let tree = Montelib.mc_pstate sub (path |> LazyList.to_list, fea, lem, lim) lit |> mc_tree prob in
    (* Iterations are capped at !maxiters unless it is negative and counted in
       uct_iterations. The initial tree is prepended, paired with (0.0, None).
       When !dot_out is non-empty, save_dot is applied to the list; it writes
       the final tree and exits. *)
    let uct = Uct.iteration prob tree
      |> (if !maxiters < 0 then identity else LazyList.take !maxiters)
      |> LazyList.map (fun x -> incr uct_iterations; x)
      |> LazyList.cons (tree, (0.0, None))
      |> (if !dot_out = "" then identity else tap (save_dot prob !dot_out)) in
    (* Complete proofs reported during the search, deduplicated and capped at
       !n_proofs unless it is negative. *)
    let prfs = LazyList.filter_map mc_get_proof uct
      |> LazyList.unique
      |> (if !n_proofs < 0 then identity else LazyList.take !n_proofs)
      |> LazyList.map (fun (sub2, prf) -> (sub2, lit :: lem, prf)) in
    (* Extension steps suggested by the final tree. At the depth limit, only
       candidates that keep the integer component of the substitution state
       unchanged are allowed. NOTE(review): presumably those introducing no
       fresh variables, mirroring the vars > 0 check in extensions_trad. *)
    let exts = mc_get_exts uct
      |> (if lim <= 0 then LazyList.filter (fun (sub1, _) -> snd sub1 = snd sub)
          else identity)
      |> LazyList.map (fun (sub1, (nfea, (cla1, hsh))) ->
          if !verbose then Format.printf "Extension works\n%!";
          incr Stats.infer;
          prove_clause sub1 (lit ^:^ path, nfea, lem, lim - 1) (cla1, hsh)
          |> LazyList.map (fun (sub2, prfs) -> (sub2, lit :: lem, Extension (hsh, prfs)))
        )
      |> LazyList.concat in
    prfs ^@^ exts)
  (* Database extensions: entries for [neglit] ordered by relevancel under the
     (possibly updated) features. At the depth limit, entries with variables
     are skipped. *)
  and extensions_trad = lazy (
    db_entries sub neglit |> relevancel (Lazy.force nfea) |> LazyList.of_list
    |> LazyList.map (fun ((_, _, vars, hsh) as contra, freqs) ->
      if !verbose then Format.printf "Extension try %s (for lit %s, lim %d)\n%!" (Hashtbl.find Database.Clausal.no_contr hsh) (string_of_lit lit) lim;
      if lim <= 0 && vars > 0 then nil
      else match Subst.unify_rename sub (snd lit) contra with
        Some (sub1, cla1) ->
          if !verbose then Format.printf "Extension works\n%!";
          incr Stats.infer;
          prove_clause sub1 (lit ^:^ path, Lazy.force nfea, lem, lim - 1) (cla1, hsh)
          |> LazyList.map (fun (sub2, prfs) -> (sub2, lit :: lem, Extension (hsh, prfs)))
      | None -> nil)
    |> LazyList.concat) in
  (* Database extensions are used while fewer than !tradinfs inferences have
     been made beyond !depthinfer; Monte Carlo extensions are used after that.
     Only the chosen lazy value is forced. *)
  let extensions = Lazy.force (
    if Stats.(!infer - !depthinfer) < !tradinfs then extensions_trad
    else extensions_monte) in
  cut !cut1 lemmas (cut !cut2 reductions (cut !cut3 extensions nil))
(* [prove_clause sub (path, fea, lem, lim) (cl, cl_hsh)] proves the literals
   of the clause [cl] from left to right, threading the substitution and the
   lemma list. It returns a lazy list of (substitution, list of (literal,
   proof) pairs). The empty clause yields a single result with no steps.
   Regularity check: no result is produced if any literal of [cl] matches,
   according to Subst.eq, a literal already on [path]. *)
and prove_clause sub (path, fea, lem, lim) (cl, cl_hsh) = match cl with
    lit :: lits ->
    if (List.exists (fun x -> LazyList.exists (Subst.eq sub x) path)) cl then (if !verbose then Format.printf "regularity\n%!"; nil)
    else
    begin
    (* Results for the first literal pass through Litdata.trace_proofs
       (NOTE(review): presumably records trace data for [cl_hsh]). *)
    prove_lit sub (path, fea, lem, lim) lit
    |> Litdata.trace_proofs lit cl_hsh
    |> LazyList.map
       (fun (sub1, lem1, prf1) -> prove_clause sub1 (path, fea, lem1, lim) (lits, cl_hsh) |> LazyList.map
       (fun (sub2, prfs) -> (sub2, (lit, prf1) :: prfs)))
    |> LazyList.concat
    end
  | [] -> cons (sub, []) nil


(* Runs one proof attempt with depth limit [lim], starting from the
   literal (hash, []) with an empty path, empty features and no lemmas.
   Litdata.cl_out is cleared first. Returns the first result, if any, as the
   substitution component together with its proof. *)
let start lim =
  if !verbose then Format.printf "Start %d\n%!" lim;
  Hashtbl.clear Litdata.cl_out;
  let root_lit = (hash, []) in
  let init_state = (nil, FM.empty, [], lim) in
  match LazyList.peek (prove_lit (Subst.empty 0, 0) init_state root_lit) with
  | Some (sub, _, prfs) -> Some (fst sub, prfs)
  | None -> None


(* Prints the accumulated search statistics as three comment lines, each
   starting with a percent sign:
   - UCT counters, with the discrimination sums averaged per proof found;
   - the ML literal success ratio;
   - the inference, depth and strategy counters from Stats. *)
let print_stats () =
  let open Stats in
  let per_proof x = x /. float_of_int !uct_proofs in
  let (discr_a, discr_b) = !uct_discr_sum in
  Format.printf "%% UCTIters: %i UCTInf: %i UCTSimSteps: %i UCTDiscr: %f/%f\n"
    !uct_iterations !Montelib.uct_infer !Montelib.uct_sim_steps
    (per_proof discr_a) (per_proof discr_b);
  let lits_found = !Montelib.ml_lit_found
  and lits_tried = !Montelib.ml_lit_tried in
  Format.printf "%% MLLitRatio: %d/%d (%f)\n"
    lits_found lits_tried
    (float_of_int lits_found /. float_of_int lits_tried);
  Format.printf "%% Inf: %i Depth: %i DInf: %i Str: %i\n"
    !infer !depth (!infer - !depthinfer) !strategy

(* Solves the problem in [file]: runs the proof schedule with start and
   passes the trace-annotated result to show_result, then prints the
   statistics. Any exception raised along the way is reported with
   print_error, and the statistics are printed as well. *)
let leancop file =
  let load_db conj def = axioms2db !classifier (file_mat conj def file) in
  try
    show_result (Option.map Litdata.write_trace (run_schedule load_db start));
    print_stats ()
  with e ->
    print_error e;
    print_stats ()

(* Program entry point. It runs these steps in order:
   - calls setup_signals;
   - parses the command-line options of all Arglean option groups,
     collecting anonymous arguments as problem files;
   - loads the classifier from !cdata when !do_nbayes is set;
   - feeds the lines of !ldata to Litdata.add_cl_in when it is non-empty;
   - runs leancop on each problem file in command-line order.
   With no problem file it prints the usage message.
   Bound with let () so the compiler checks that the whole body is unit. *)
let () =
  setup_signals ();

  let tosolve = ref [] in
  let speclist = Arg.align Arglean.(Monte.args @ Trace.args @ Lean.args @ Random.args @ Oracle.args @ Female.args) in
  let usage = "Usage: montecop [options] <file.p>\nAvailable options are:" in
  (* Files are consed onto the front, hence the List.rev below. *)
  Arg.parse speclist (fun s -> tosolve := s :: !tosolve) usage;

  if !do_nbayes then classifier := FClassifier.load !cdata;
  if !ldata <> "" then Litdata.add_cl_in (File.lines_of !ldata);

  if !tosolve = [] then Arg.usage speclist usage
  else List.iter leancop (List.rev !tosolve)
