open Batteries

module Make (Feature : Map.OrderedType) = struct

(* Maps keyed by features.  With [open Batteries] in effect, [Map.Make] is
   the Batteries functor; the code below relies on its [modify_def] and
   [enum] operations in addition to the usual [find]/[add]/[fold]. *)
module Fm = Map.Make(Feature)

(*
The naming scheme is derived from the data representation

    data[l][i][f],

where l is a label, i is a training example index and f is a feature.
data[l][i][f] is 1 if the feature f co-occurs with label l in
training example i, and 0 otherwise.
*)


(* Learned data, parameterised by the label type ['a].  All tables are
   updated in place by [add_training_ex]; [ftr_idf] is recomputed from the
   other fields by [add_training_exs]. *)
type 'a t =
    (* number of training examples; incremented once per added example *)
  { mutable sum_lbl_card : int
    (* sum of the feature weights of all training examples
       (total number of feature occurrences when every weight is 1) *)
  ; mutable sum_lbl_te_ftr : float
    (* number of training examples in which each label occurs (tfreq) *)
  ; lbl_card : ('a, int) Hashtbl.t
    (* per label: sum of the feature weights of the examples carrying
       that label *)
  ; lbl_sum_te_ftr : ('a, float) Hashtbl.t
    (* per feature: sum of its weights over all training examples *)
  ; ftr_sum_lbl_te : (Feature.t, float) Hashtbl.t
    (* per label: for each feature, the summed weight with which it
       co-occurred with that label (sfreq) *)
  ; lbl_ftr_sum_te : ('a, float Fm.t) Hashtbl.t
    (* per feature IDF, computed by [calc_ftr_idf] as
       log sum_lbl_card - log ftr_sum_lbl_te(feature) *)
  ; mutable ftr_idf : float Fm.t
  }

(* [empty ()] builds a fresh learner state with no training examples:
   zero counters, four independent empty hash tables and an empty IDF map.
   Each call allocates new tables, so results never share mutable state. *)
let empty () : 'a t =
  (* Small initial size; the tables grow on demand. *)
  let fresh_table () = Hashtbl.create 1 in
  { sum_lbl_card = 0
  ; sum_lbl_te_ftr = 0.
  ; lbl_card = fresh_table ()
  ; lbl_sum_te_ftr = fresh_table ()
  ; ftr_sum_lbl_te = fresh_table ()
  ; lbl_ftr_sum_te = fresh_table ()
  ; ftr_idf = Fm.empty
  }


(* ****************************************************************************)
(* I/O *)

(* [statistics d] returns three human-readable summary lines for [d]:
   the number of training examples, the number of distinct features seen
   and the number of distinct labels seen, in that order. *)
let statistics d =
  let n_examples = d.sum_lbl_card in
  let n_features = Hashtbl.length d.ftr_sum_lbl_te in
  let n_labels = Hashtbl.length d.lbl_card in
  [ Printf.sprintf "%i training examples" n_examples
  ; Printf.sprintf "%i features" n_features
  ; Printf.sprintf "%i labels" n_labels
  ]

(* [load fp] reads learned data from the file [fp], prints its statistics
   to standard output (one line each, prefixed with Read:) and returns it.

   The value is decoded with [input_value], which does not check that the
   file contents match the expected type, so [fp] should be a file that
   was produced by [write].  Exceptions raised while opening or reading
   the file are not caught here. *)
let load fp =
  let loaded = File.with_file_in fp (fun ic -> input_value ic) in
  statistics loaded
  |> List.iter (fun line -> Printf.printf "Read: %s\n%!" line);
  loaded

(* [try_load fp] behaves like [load fp] but tolerates a missing data file.

   - If loading raises [Sys_error] (for instance because [fp] does not
     exist), fresh data from [empty ()] is returned.
   - Any other exception is turned into [Failure].  The message names the
     file and includes the text of the underlying exception, so that the
     actual cause (truncated file, bad format, ...) is not lost. *)
let try_load fp =
  try load fp with
  | Sys_error _ -> empty () (* don't complain if file not found *)
  | e ->
    failwith ("Error reading " ^ fp ^ " file: " ^ Printexc.to_string e)

(* [write d fp] serialises the learned data [d] into the file [fp] with
   [output_value] and then reports the file name and the number of
   training examples on standard output. *)
let write d fp =
  let dump oc = output_value oc d in
  File.with_file_out fp dump;
  Printf.printf "Wrote %s with %i training examples\n%!" fp d.sum_lbl_card


(* ****************************************************************************)
(* obtain learned data *)

(* [get_idf d ftr] is the IDF stored for feature [ftr] in [d.ftr_idf],
   or [0.] when the feature has no entry. *)
let get_idf d ftr =
  match Fm.find ftr d.ftr_idf with
  | idf -> idf
  | exception Not_found -> 0.

(* [get_lbl_data d lbl] collects what is known about label [lbl]:
   its example count, the summed feature weight of its examples and its
   per-feature co-occurrence map.  If [lbl] is missing from any of the
   three tables, the neutral triple [(0, 0., Fm.empty)] is returned. *)
let get_lbl_data d lbl =
  let lookup tbl = Hashtbl.find tbl lbl in
  match
    (lookup d.lbl_card, lookup d.lbl_sum_te_ftr, lookup d.lbl_ftr_sum_te)
  with
  | known -> known
  | exception Not_found -> (0, 0., Fm.empty)


(* ****************************************************************************)
(* update learned data *)

(* [calc_ftr_idf d] builds a map giving, for every feature recorded in
   [d.ftr_sum_lbl_te], the value

     log (number of training examples) - log (summed weight of the feature).

   [d] itself is not modified; the caller stores the result. *)
let calc_ftr_idf d =
  let log_n_examples = log (float_of_int d.sum_lbl_card) in
  let add_idf ftr occ idfs = Fm.add ftr (log_n_examples -. log occ) idfs in
  Hashtbl.fold add_idf d.ftr_sum_lbl_te Fm.empty

(* [add_training_ex d (ftrs, lbl)] records one training example in [d].
   [ftrs] maps each feature of the example to its weight and [lbl] is the
   label of the example.  Updated in place:
   - the global example count and the global sum of feature weights,
   - the example count and the summed feature weight of [lbl],
   - the per-feature weights accumulated for [lbl],
   - the per-feature weights accumulated over all labels.
   [d.ftr_idf] is left untouched; see [add_training_exs]. *)
let add_training_ex d (ftrs, lbl) =
  (* Adds every feature weight of this example onto [acc]. *)
  let plus_weights acc = Fm.fold (fun _ftr w total -> w +. total) ftrs acc in
  (* Adds every feature weight of this example into a per-feature map. *)
  let merge_ftrs per_ftr =
    Fm.fold
      (fun ftr w m -> Fm.modify_def 0. ftr (fun prev -> w +. prev) m)
      ftrs per_ftr
  in
  d.sum_lbl_card <- succ d.sum_lbl_card;
  d.sum_lbl_te_ftr <- plus_weights d.sum_lbl_te_ftr;
  Hashtbl.modify_def 0 lbl succ d.lbl_card;
  Hashtbl.modify_def 0. lbl plus_weights d.lbl_sum_te_ftr;
  Hashtbl.modify_def Fm.empty lbl merge_ftrs d.lbl_ftr_sum_te;
  Fm.iter
    (fun ftr w ->
      Hashtbl.modify_def 0. ftr (fun prev -> w +. prev) d.ftr_sum_lbl_te)
    ftrs

(* [add_training_exs d l] feeds every example of the list [l] to
   [add_training_ex], then recomputes [d.ftr_idf] from the updated
   counts.  A progress message is printed after each of the two phases. *)
let add_training_exs d l =
  let absorb ex = add_training_ex d ex in
  List.iter absorb l;
  Printf.printf "pf, cn and cn_pf frequencies updated\n%!";
  let idf = calc_ftr_idf d in
  d.ftr_idf <- idf;
  Printf.printf "IDF information calculated\n%!"


(* ****************************************************************************)
(* Naive Bayes relevance *)

(* [maps_partition m1 m2] splits the bindings of two maps by key into a
   triple [(l, i, r)]:
   - [l]: bindings of [m1] whose key is absent from [m2],
   - [i]: keys present in both, bound to the pair (value in m1, value in m2),
   - [r]: bindings of [m2] whose key is absent from [m1].
   The walk is over [m1]; [r] starts as [m2] and loses each shared key. *)
let maps_partition m1 m2 =
  let step key left_val (only_left, both, only_right) =
    match Fm.find key m2 with
    | right_val ->
      ( only_left
      , Fm.add key (left_val, right_val) both
      , Fm.remove key only_right )
    | exception Not_found -> (Fm.add key left_val only_left, both, only_right)
  in
  Fm.fold step m1 (Fm.empty, Fm.empty, m2)

(* [relevance idf fl fi fr ftrs sfreq] scores [ftrs] against [sfreq].
   The two maps are split with [maps_partition]; every binding then
   contributes one float term, and all terms are added up:
   - [fl (idf ftr) v] for features only in [ftrs],
   - [fi (idf ftr) (v1, v2)] for features in both maps,
   - [fr (idf ftr) v] for features only in [sfreq].
   Each group is summed with [Enum.fsum]. *)
let relevance idf fl fi fr ftrs sfreq =
  let (only_ftrs, shared, only_sfreq) = maps_partition ftrs sfreq in
  let weighted_sum score m =
    Fm.enum m
    |> Enum.map (fun (ftr, v) -> score (idf ftr) v)
    |> Enum.fsum
  in
  weighted_sum fl only_ftrs +. weighted_sum fi shared
  +. weighted_sum fr only_sfreq

(* [maps_fold_inter (f1, f12) (m1, m2) init] folds over the bindings of
   [m1], starting from [init]:
   - when the key [k] is also bound in [m2], the accumulator becomes
     [f12 k (v1, v2) acc];
   - otherwise it becomes [f1 k v1 acc].
   Bindings that exist only in [m2] are never visited.

   Optimized for small [m1]: the cost is one lookup in [m2] per binding of
   [m1].

   Only the lookup in [m2] is guarded against [Not_found]; an exception
   raised by [f12] or [f1] themselves propagates to the caller instead of
   being mistaken for a missing key. *)
let maps_fold_inter (f1, f12) (m1, m2) =
  Fm.fold (fun k v1 acc ->
    match Fm.find k m2 with
    | v2 -> f12 k (v1, v2) acc
    | exception Not_found -> f1 k v1 acc) m1

(* [lbl_relevance d (lbl_card, _, _)] is

     log (lbl_card + 1) - log (number of training examples + 1).

   [lbl_card] is expected as a float here; the other two components of
   the triple are ignored. *)
let lbl_relevance d (lbl_card, _, _) =
  let smoothed_lbl = lbl_card +. 1. in
  let smoothed_total = float_of_int (succ d.sum_lbl_card) in
  log smoothed_lbl -. log smoothed_total

(* [ftr_relevance d ftrs (lbl_card, _, lbl_ftr_sum_te)] compares the
   features [ftrs] with the co-occurrence map [lbl_ftr_sum_te] of a label.
   For every feature of [ftrs] that the label has seen, the term
   [log coocc - log lbl_card] is computed.  The result is

     (lowest, (inter_prob, n_inter, n_disj))

   where [inter_prob] is the sum of these terms, [lowest] is the smallest
   term (starting from 0.), [n_inter] counts features of [ftrs] known to
   the label and [n_disj] counts those unknown to it.
   [d] and the middle component of the triple are currently unused. *)
let ftr_relevance d ftrs (lbl_card, _lbl_sum_te_ftr, lbl_ftr_sum_te) =
  (* Feature not seen with this label: only count it. *)
  let on_missing _ftr _w (sum, lowest, n_inter, n_disj) =
    (sum, lowest, n_inter, succ n_disj)
  in
  (* Feature seen with this label: accumulate its log ratio. *)
  let on_shared _ftr (_w, coocc) (sum, lowest, n_inter, n_disj) =
    let logp = log coocc -. log lbl_card in
    (sum +. logp, min logp lowest, succ n_inter, n_disj)
  in
  let (inter_prob, lowest, n_inter, n_disj) =
    maps_fold_inter (on_missing, on_shared) (ftrs, lbl_ftr_sum_te)
      (0., 0., 0, 0)
  in
  (*Format.printf "n_inter: %d\n%!" n_inter;
  Format.printf "n_disj: %d\n%!" n_disj;
  Format.printf "lowest: %f\n%!" lowest;*)
  (lowest, (inter_prob, n_inter, n_disj))
  (*(inter_prob +. (min lowest (-1.)) *. float_of_int n_disj) /.
  (float_of_int (n_inter + n_disj))*)

end
