open Batteries

(** Numeric reward of a state; summed per node in [tree.rewards]. *)
type reward_t = float
(** Relative weight attached to a successor, used as a sampling weight when
    successors are shuffled.  NOTE(review): nothing here checks that the
    weights of a state's successors sum to 1. *)
type probability = float

(** A search problem over states of type ['s].  ['r] is an extra per-state
    result that travels alongside the numeric reward. *)
type ('s, 'r) problem = {
  (* Weighted successors of a state.  Each successor is a lazy thunk; thunks
     that evaluate to [None] are dropped when the successors are shuffled. *)
  successors : 's -> ('s option Lazy.t * probability) list
  (* Reward of a state, paired with a value of type ['r] that the search
     hands back to the caller unchanged. *)
; reward : 's -> reward_t * 'r
}

(** A node of the search tree.  No field is mutable: every update builds a
    new record. *)
type 's tree = {
  state : 's                  (* state represented by this node *)
; visits : int                (* number of times the node has been visited *)
; rewards : reward_t          (* sum of the rewards of all those visits *)
; embryos : 's LazyList.t     (* successors not yet expanded into children,
                                 in the order they will be expanded *)
; children : 's tree list     (* successors already expanded *)
}

(* Render a search tree as a [Dot.Node] graph: [ft] labels each node, and
   every child hangs off an edge with an empty attribute list.  [fs] is only
   needed by the disabled rendering of unexpanded embryos:
     List.map (fun e -> ["style", "dotted"], Node (fs e, []))
       (LazyList.to_list embryos) *)
let rec dot_tree (ft, fs) ({children; _} as tree) =
  let edges = List.map (fun child -> [], dot_tree (ft, fs) child) children in
  Dot.Node (ft tree, edges)

(** Configuration of a UCT search. *)
type ('s, 'r) uct = {
  problem : ('s, 'r) problem
  (* [tree_policy parent child] scores [child]; children with a higher score
     are descended into first (e.g. [uct_tree_policy exploration]). *)
; tree_policy : 's tree -> 's tree -> float
  (* Playout from a freshly expanded state, returning the visited states
     (e.g. [default_simulation_policy p depth]). *)
; simulation_policy : 's -> 's list
  (* Turns a playout into the new child subtree and the (reward, result)
     to back up (e.g. [single_expansion_policy p]). *)
; expansion_policy : 's list -> 's tree * (reward_t * 'r)
}

(** [weighted_shuffle l] returns the successors in [l] as a lazy list in a
    random order equivalent to repeatedly drawing, without replacement, an
    element with probability proportional to its weight.  This is the
    Efraimidis–Spirakis scheme: element [i] gets the key [log u /. w_i] for a
    fresh uniform [u] in [0, 1], and keys are sorted in decreasing order.
    Elements whose weight is not positive (or NaN) get the lowest key and end
    up last, in input order.

    The successor thunks are forced lazily as the result is traversed, and
    thunks evaluating to [None] are skipped. *)
let weighted_shuffle l =
  (* The previous key, [Random.float w], over-favoured heavy elements: with
     weights 1 and 2 the heavier one came first 3/4 of the time instead of
     2/3, which skewed rollouts away from the given probabilities.  The log
     form avoids the underflow that [u ** (1. /. w)] hits for tiny weights. *)
  let key w =
    if w > 0. then log (Random.float 1.) /. w else neg_infinity
  in
  l
  |> List.map (fun (x, w) -> x, key w)
  |> List.sort (fun (_, k1) (_, k2) -> compare k2 k1)
  |> List.map fst
  |> LazyList.of_list
  |> LazyList.filter_map Lazy.force

(* Fresh, never-visited node for state [s] of problem [p]: no children yet,
   and all of [s]'s successors queued (shuffled by weight) as embryos. *)
let empty_tree p s =
  let embryos = weighted_shuffle (p.successors s) in
  { state = s; visits = 0; rewards = 0.; embryos; children = [] }

(* Mean reward per visit (NaN or infinite for a node with 0 visits). *)
let avg_reward {rewards; visits; _} = rewards /. float_of_int visits

(* Exploration constant: with it, the bonus below is sqrt (2 ln N / n). *)
let c_p = sqrt 2.

(* UCB1 score of [child] under [parent]: mean reward plus an exploration
   bonus scaled by [exploration].  An unvisited child gets an infinite
   bonus. *)
let uct_tree_policy exploration parent child =
  let n_parent = float_of_int parent.visits in
  let n_child = float_of_int child.visits in
  let weight = exploration *. c_p in
  let spread = sqrt (log n_parent /. n_child) in
  avg_reward child +. weight *. spread

(* Comparator ordering elements by DECREASING key [f], so that sorting with
   it puts the highest-scoring element first. *)
let child_cmp f c1 c2 =
  let k1 = f c1 in
  let k2 = f c2 in
  compare k2 k1

(* [cdf l] pairs every element of [l] with the interval [(lo, w, hi)] it
   covers on the cumulative-weight axis ([hi = lo +. w]).  The pairs come
   back in REVERSE input order, together with the total weight:
     cdf [("a", 0.1); ("b", 0.9)] =
       ([("b", (0.1, 0.9, 1.)); ("a", (0., 0.1, 0.1))], 1.)
*)
let cdf l =
  let step (intervals, lo) (x, w) =
    let hi = lo +. w in
    (x, (lo, w, hi)) :: intervals, hi
  in
  List.fold_left step ([], 0.0) l

(* Draw one point [r] uniformly in [0, total weight] and return the cdf
   intervals of [xs] together with a predicate that holds for the intervals
   containing [r].  The bounds are closed, so a point on a shared boundary
   matches two adjacent intervals. *)
let cdf_sample xs =
  let intervals, total = cdf xs in
  let r = Random.float total in
  let covers (_, (lo, _, hi)) = lo <= r && r <= hi in
  intervals, covers

(* Draw one element of [xs] with probability proportional to its weight.
   Returns the drawn entry and the remaining [(x, w)] pairs.
   NOTE(review): the drawn entry keeps its cdf interval, i.e. it is
   [(x, (lo, w, hi))] rather than [(x, w)] like the remainder.
   Fails through [List.hd]/[List.tl] when nothing is drawn (e.g. [xs = []]). *)
let weighted_draw xs =
  let intervals, covers = cdf_sample xs in
  let chosen, others = List.partition covers intervals in
  let strip (x, (_, w, _)) = x, w in
  List.hd chosen, List.map strip (List.tl chosen @ others)

(* Pick one entry of [xs] with probability proportional to its weight,
   returned as its cdf entry [(x, (lo, w, hi))].
   Raises [Not_found] when no interval matches (e.g. [xs = []]). *)
let weighted_sample xs =
  let intervals, covers = cdf_sample xs in
  intervals |> List.find covers

(* Pick an element of [xs] uniformly at random.  [Random.int] rejects the
   zero bound, so an empty list raises [Invalid_argument]. *)
let uniform_sample xs =
  let len = List.length xs in
  let idx = Random.int len in
  List.nth xs idx


(** [greedy_sample xs] returns the pair of [xs] whose second component (the
    score) is maximal under [compare].  Ties go to the earliest pair.  This
    is the same element the head of [xs] sorted by decreasing score (a
    stable sort) would be, found in one O(n) pass instead of an
    O(n log n) sort.
    Raises whatever [List.hd] raises when [xs] is empty. *)
let greedy_sample xs =
  (* replace the incumbent only on a strictly higher score: earliest wins *)
  let keep_higher ((_, best_score) as best) ((_, score) as candidate) =
    if compare score best_score > 0 then candidate else best
  in
  (* [List.hd] runs first so an empty list fails exactly as before *)
  let first = List.hd xs in
  List.fold_left keep_higher first (List.tl xs)


(* Random playout from [s]: repeatedly move to the first successor of a
   weighted shuffle, for at most [depth] steps or until a state has no
   successor.  Returns the visited states, starting with [s].  The loop is
   tail-recursive, accumulating states in reverse. *)
let default_simulation_policy p depth s =
  let rec walk visited remaining state =
    let visited = state :: visited in
    if remaining <= 0 then List.rev visited
    else
      match LazyList.get (weighted_shuffle (p.successors state)) with
      | None -> List.rev visited
      | Some (next, _) -> walk visited (remaining - 1) next
  in
  walk [] depth s

(* Record one visit of [v] with the given reward.  Returns the updated node
   together with the [(reward, result)] outcome, unchanged. *)
let visit v ((reward, _) as outcome) =
  let visited = { v with visits = v.visits + 1; rewards = v.rewards +. reward } in
  visited, outcome

(* New node for [s0], already visited once with the reward of [sn].  The
   reward is computed before the node (and its shuffled embryos) is built. *)
let state_expansion p s0 sn =
  let outcome = p.reward sn in
  let leaf = empty_tree p s0 in
  visit leaf outcome

(* Expansion that adds one node, for the first state of [simulation],
   scored with the reward of the simulation's final state.  Fails on an
   empty simulation (through [List.last] / [List.hd]). *)
let single_expansion_policy p simulation =
  let final = List.last simulation in
  let start = List.hd simulation in
  state_expansion p start final

(* One UCT iteration below [v]: select, expand, simulate and back up.
   Returns the updated copy of [v] and the [(reward, result)] pair produced
   by this iteration.
   - While [v] has embryos, the next one is expanded: a playout is run from
     it with [uct.simulation_policy], [uct.expansion_policy] turns the playout
     into a new child, and [v] is visited with the resulting reward.
   - Once all embryos are expanded, children are ranked by
     [uct.tree_policy v] (highest score first) and the best one is searched
     recursively; a child left with no children and no embryos is exhausted
     and removed from [v]'s children, together with its statistics.
   - A node with neither embryos nor children scores its own state and is
     returned WITHOUT being visited (its parent then drops it). *)
let rec tree_policy uct v =
  match LazyList.get v.embryos with
    None -> begin
    match List.sort (child_cmp (uct.tree_policy v)) v.children with
      (* leaf: nothing left to expand or descend into *)
      [] -> v, uct.problem.reward v.state
    | best :: rest ->
        let best', reward = tree_policy uct best in
        (* prune an exhausted child; [== []] is a reliable emptiness test
           because [[]] is an immediate constant *)
        let children' =
          if best'.children == [] && Option.is_none (LazyList.peek best'.embryos)
          then rest
          else best' :: rest in
        visit {v with children = children'} reward
    end
  | Some (chosen, rest) ->
    (* expansion: simulate from the next embryo and add it as a child *)
    let simulation = uct.simulation_policy chosen in
    let expansion, reward = uct.expansion_policy simulation in
    visit {v with children = expansion :: v.children; embryos = rest} reward

(* Step function for [LazyList.from_loop]: runs one [tree_policy] iteration
   on the root [v] and yields [((new_root, outcome), new_root)].  Once the
   root has no children and no embryos left it raises
   [LazyList.No_more_elements], which ends the lazy list. *)
let single_iteration uct v =
  let exhausted =
    match v.children with
    | [] -> Option.is_none (LazyList.peek v.embryos)
    | _ :: _ -> false
  in
  if exhausted then raise LazyList.No_more_elements;
  let v', outcome = tree_policy uct v in
  (v', outcome), v'

(* Lazy list of successive [(root, (reward, result))] snapshots, one per UCT
   iteration started from the root [v]; it ends when the tree is exhausted. *)
let iteration uct v =
  let step = single_iteration uct in
  LazyList.from_loop v step

(* Start a UCT search from state [s]; see [iteration] for the result. *)
let search uct s =
  let root = empty_tree uct.problem s in
  iteration uct root
 
(* Orders trees by decreasing mean reward. *)
let by_reward c1 c2 = child_cmp avg_reward c1 c2

(* Orders trees by decreasing visit count. *)
let by_visits c1 c2 = child_cmp (fun t -> t.visits) c1 c2

(* Follow, from [tree], the child that [cmp] sorts first (e.g. [by_visits]
   or [by_reward]) down to a node with no children, and return that node's
   state.  This is the END of the best path, not the root's best child.
   NOTE(review): children pruned once exhausted during the search are no
   longer in the tree and so cannot be chosen here. *)
let best_state cmp tree =
  let rec descend node =
    match List.sort cmp node.children with
    | [] -> node.state
    | favourite :: _ -> descend favourite
  in
  descend tree
