(* Theory LLVM_Step
   (listing header: imports Misc_Aux, LLVM_Syntax, Word) *)

theory LLVM_Step
  imports
    Misc_Aux
    LLVM_Syntax
    "HOL-Word.Word"
    "Certification_Monads.Error_Monad"
begin

(* Reasons why evaluation does not yield a successor state.
   - StaticError: carries a message; produced via static_error / option_to_error.
   - ExternalFunctionCall: the callee is an ExternalFunction (see enter_frame).
   - ReturnValue: a Ret was executed while no caller frame was left (see terminate_frame).
   - Program_Termination: step was applied to a state without frames.

   TODO:
   - rename error to stuck?
   - Frames now store only the position of the current instruction instead of the
     instructions themselves. This works for now, but storing the position should
     allow some further improvements. *)
datatype stuck = StaticError String.literal |
                  ExternalFunctionCall name |
                  ReturnValue llvm_constant |
                  Program_Termination

(* Result of a computation that may get stuck: Inl carries the reason, Inr the value. *)
type_synonym 'a error = "stuck + 'a"

(* Converts an option into a sum: None becomes Inl of the supplied fallback,
   Some x becomes Inr x (the fallback is ignored). *)
fun option_to_sum :: "'a option ⇒ 'b ⇒ 'b + 'a" where
  "option_to_sum None y = Inl y"
| "option_to_sum (Some x) _ = Inr x"

(* Fails with a StaticError carrying message m. *)
fun static_error :: "String.literal ⇒ 'a error" where
  "static_error m = error (StaticError m)"

(* Advances a position triple by one: the first two components are kept and the
   third (the index n) is replaced by Suc n. *)
fun inc_pos where
  "inc_pos (fn, bn, n) = (fn, bn, Suc n)"

(* Lifts an option into the error type: Some x yields Inr x, None yields
   Inl (StaticError m). *)
fun option_to_error :: "'a option ⇒ String.literal ⇒ 'a error" where
  "option_to_error x m = option_to_sum x (StaticError m)"

(* option_to_error succeeds with y exactly when the option was Some y;
   both cases of x are closed by simplification with the defining equations. *)
lemma option_to_error: "option_to_error x s = Inr y ⟷ x = Some y"
  by (cases x) simp_all

(* What a position inside a block refers to: a named ordinary instruction or the
   block's named terminator (see find_action). *)
datatype action = Instruction "instruction named" | Terminator "terminator named"

(* All basic blocks of a function as one list: hd_blocks f followed by tl_blocks f. *)
fun blocks :: "llvm_fun ⇒ basic_block list" where
  "blocks f = hd_blocks f # tl_blocks f"

(* Looks up the function called n in prog.
   The functions of prog are keyed by fun_name and searched with map_of;
   if no entry exists the result is the static error ''Cannot find function''. *)
fun find_fun :: "llvm_prog ⇒ name ⇒ llvm_fun error" where
  "find_fun prog n =
    (let keyed = (λfn. (fun_name fn, fn));
         table = map_of (map keyed (funs prog))
     in option_to_error (table n) (STR ''Cannot find function''))"

(* Looks up the basic block called n among the blocks of function f.
   The blocks are keyed by their name and searched with map_of;
   if no entry exists the result is the static error ''Cannot find block''. *)
fun find_block :: "llvm_fun ⇒ name ⇒ basic_block error" where
  "find_block f n =
    (let keyed = (λblk. (name blk, blk));
         table = map_of (map keyed (blocks f))
     in option_to_error (table n) (STR ''Cannot find block''))"

(* Resolves block bn of function fn in prog: first the function lookup, then the
   block lookup inside it. A failure of either lookup is propagated. *)
fun find_fun_block :: "llvm_prog ⇒ name ⇒ name ⇒ basic_block error" where
  "find_fun_block prog fn bn = find_fun prog fn ⤜ (λf. find_block f bn)"

(* Total variant of list indexing: the n-th element (counting from 0) wrapped in
   Some, or None when the list has no such element. *)
fun nth_option :: "'a list ⇒ nat ⇒ 'a option" where
  "nth_option [] _ = None" |
  "nth_option (x#_) 0 = Some x" |
  "nth_option (_#xs) (Suc n) = nth_option xs n"

(* Determines the action at index n of block b.
   Indices below length (phis b) denote phi nodes and are rejected with a static
   error. Otherwise the index, shifted by the number of phis, selects an
   instruction of b; if it is past the last instruction the block's terminator
   is returned. *)
fun find_action :: "basic_block ⇒ nat ⇒ action error" where
  "find_action b n =
    (if n < length (phis b) then static_error STR ''jumping to phi node not possible'' else
      case nth_option (instructions b) (n - length (phis b)) of
        Some i ⇒ Inr (Instruction i) |
        None ⇒ Inr (Terminator (terminator b)))"

(* If find_action yields a terminator, it is the terminator of the block. *)
lemma find_action_terminator:
  "find_action b p = Inr (Terminator t) ⟹ terminator b = t"
  by (auto split: option.splits if_splits)

(* Resolves a position (function name, block name, index) to the action found
   there: looks up the block via find_fun_block and then the action via
   find_action. Any lookup failure is propagated.

   TODO: find_statement should return "String.literal ⇒ String.literal + action"
   and not "action error" *)
fun find_statement :: "llvm_prog ⇒ LLVM_Syntax.pos ⇒ action error" where
  "find_statement prog (fn, bn, p) = do {
     b ← find_fun_block prog fn bn;
     find_action b p}"

(* Returns the phi nodes of block bn in function fn; a failed block lookup is
   passed through unchanged (map_sum id on the left component).

   TODO: find_phis should return "String.literal ⇒ String.literal + action"
   and not "action error"
   NOTE(review): this TODO looks copied from find_statement; the result here is a
   phi list, so "action" presumably should read "phi list" - confirm. *)
fun find_phis :: "llvm_prog ⇒ name ⇒ name ⇒ phi list error" where
  "find_phis prog fn bn = map_sum id phis (find_fun_block prog fn bn)"

(* Applies the binary operation g to two integer constants.
   Succeeds with IntConstant a (g n n') when both constants carry the same first
   component a and a > 0; otherwise a static error is returned. Note that the
   error message mentions differing types but is also produced when a = a' = 0.
   Only the IntConstant/IntConstant case has an equation.

   XXX: We use unbounded integer arithmetic here!
   For real LLVM semantics we would need bitvector arithmetic *)
fun binop_llvm_constant where
  "binop_llvm_constant g (IntConstant a n) (IntConstant a' n') =
     (if a = a' ∧ a > 0 then return (IntConstant a (g n n')) else
       static_error STR ''Constant have different types'')"


(* Select on an integer condition: yields o1 when the condition's value n is 1
   and o2 otherwise. The first component a of the condition is not inspected
   (NOTE(review): condBr_to_frame additionally requires it to be 1 - confirm
   whether that check is wanted here too). Only IntConstant conditions have an
   equation. *)
fun select_llvm_constant :: "llvm_constant ⇒ llvm_constant ⇒ llvm_constant ⇒ llvm_constant error" where
  "select_llvm_constant (IntConstant a n) o1 o2 = return (if n = 1 then o1 else o2)"

(* Silences parser warnings about ambiguous input from here on.
   NOTE(review): presumably needed for the definitions below - confirm which
   notation is ambiguous. *)
declare [[syntax_ambiguity_warning = false]]

(* Integer comparison of two constants. The operands are compared through
   integerValue and the outcome is encoded as IntConstant 1 1 (holds) or
   IntConstant 1 0 (does not hold). Equations exist for EQ, NE, SLT, SGT, SGE
   and SLE only.

   TODO: add option, ULT vs SLT *)
fun icmp_llvm_constant :: "integerPredicate ⇒ llvm_constant ⇒ llvm_constant ⇒ llvm_constant" where
  "icmp_llvm_constant EQ c1 c2 = (IntConstant 1 (if integerValue c1 = integerValue c2 then 1 else 0))" |
  "icmp_llvm_constant NE c1 c2 = (IntConstant 1 (if integerValue c1 ≠ integerValue c2 then 1 else 0))" |
  "icmp_llvm_constant SLT c1 c2 =
    (IntConstant 1 (if integerValue c1 < integerValue c2 then 1 else 0))"  |
  "icmp_llvm_constant SGT c1 c2 =
     (IntConstant 1 (if integerValue c1 > integerValue c2 then 1 else 0))" |
  "icmp_llvm_constant SGE c1 c2 =
    (IntConstant 1 (if integerValue c1 ≥ integerValue c2 then 1 else 0))" |
  "icmp_llvm_constant SLE c1 c2 =
    (IntConstant 1 (if integerValue c1 ≤ integerValue c2 then 1 else 0))"

(* Context for executing a single step.
   lc is the program, ls the state before the step, cf the frame being executed
   and fs the remaining frames below it (step instantiates cf#fs with the frame
   list of ls). *)
locale small_step =
  fixes
    lc :: llvm_prog and
    ls :: llvm_state and
    cf :: frame and
    fs :: "frame list"
begin

(* Position stored in the current frame. *)
definition c_pos where "c_pos = pos cf"

(* Partial map from function names to the functions of the program lc. *)
definition map_of_funs :: "name ⇀ llvm_fun" where
  "map_of_funs = map_of (map (λf. (fun_name f, f)) (funs lc))"

(* Looks up block b of function f in the program.
   Every failure is reported as a static error instead of being left
   unspecified: an unknown function, an external function (which carries no
   blocks) and an unknown block name each get their own message. *)
fun lookup_block :: "name ⇒ name ⇒ basic_block error" where
  "lookup_block f b =
    (case map_of_funs f of
       Some (llvm_fun.Function _ _ _ c cs) ⇒
         option_to_error (map_of (map (λblk. (basic_block.name blk, blk)) (c#cs)) b)
           STR ''Cannot find block''
     | Some (ExternalFunction _ _ _) ⇒
         static_error STR ''External function has no blocks''
     | None ⇒ static_error STR ''Cannot find function'')"

(* Evaluates an operand: a local reference is read from the stack of the current
   frame (static error if it is absent), a constant operand is its own value. *)
fun operand_value :: "operand ⇒ llvm_constant error" where
  "operand_value (LocalReference n) = option_to_error (Mapping.lookup (stack cf) n) STR ''Could not find register''" |
  "operand_value (ConstantOperand i) = return i"

(* Returns frame f' with register n bound to o' in its stack.
   Note: this locale-level function shadows LLVM_Syntax.update_stack, which it
   calls with a constant function. *)
fun update_stack where
  "update_stack f' n o' = LLVM_Syntax.update_stack (λ_. Mapping.update n o' (stack f')) f'"

(* Builds the successor state whose frame list is f on top of fs. *)
fun update_frame :: "frame ⇒ llvm_state error" where
  "update_frame f = return (update_frames (λ_. f#fs) ls)"

(* Evaluates both operands and combines their values with bo.
   A failure while evaluating o1 or o2 is propagated. *)
fun operand_binop where
  "operand_binop o1 o2 bo =
    do {
      o1 ← operand_value o1;
      o2 ← operand_value o2;
      bo o1 o2}"

(* Evaluates the condition and both alternatives, then picks one of them with
   select_llvm_constant. Note that both alternatives are evaluated, so a failure
   in either one is propagated regardless of the condition. *)
fun operand_select where
  "operand_select c o1 o2 =
    do {
      c ← operand_value c;
      o1 ← operand_value o1;
      o2 ← operand_value o2;
      select_llvm_constant c o1 o2}"

(* Finishes an instruction that defines register n: binds n to the result of v
   in the current frame, sets the position to the current one with its index
   increased by one, and installs that frame on top of fs. *)
fun update_frames_stack where
  "update_frames_stack n v =
    do {o' ← v;
        let f = update_stack cf n o';
        let (nf, nb, p) = pos cf;
        update_frame (update_pos (λ_. (nf, nb, p + 1)) f)}"

(* Builds the initial stack of a callee by binding each formal parameter name to
   the value of the corresponding actual operand (evaluated in the current
   frame). Lists of different length yield ''Wrong number of arguments''. *)
fun zip_parameters where
  "zip_parameters (x#xs) ((Parameter t n)#ps) = do {s ← zip_parameters xs ps; y ← operand_value x; return (Mapping.update n y s)}" |
  "zip_parameters [] [] = return Mapping.empty" |
  "zip_parameters _ _ = static_error STR ''Wrong number of arguments''"

(* Creates the frame for a call of function n with arguments os: position 0 of
   the function's first block and a stack holding the parameters. External
   functions get stuck with ExternalFunctionCall, unknown ones with a static
   error. The pattern variable n in the Function case shadows the argument n.

   TODO: rewrite enter_frame to not do pattern matching again *)
fun enter_frame :: "name ⇒ operand list ⇒ frame error" where
  "enter_frame n os =
    (case (map_of_funs n) of
      Some (Function _ n ps b _) ⇒ do {
         s ← zip_parameters os ps;
         return (Frame (n, basic_block.name b, 0) s)}
     | Some (ExternalFunction _ fn _) ⇒ error (ExternalFunctionCall fn)
     | None ⇒ static_error STR ''Undefined function'')"

(* Executes a call: pushes the callee frame produced by enter_frame on top of
   the current frame cf and the remaining frames fs. The current frame's
   position is left unchanged here (ret_from_frame advances it on return). *)
fun call_function :: "name ⇒ operand list ⇒ llvm_state error" where
  "call_function fn os =
    (case (map_of_funs fn) of
      Some (Function _ fn ps b _) ⇒ do {
         f ← enter_frame fn os;
         return (update_frames ((λ_. f#cf#fs)) ls)}
     | Some (ExternalFunction _ fn _) ⇒ error (ExternalFunctionCall fn)
     | None ⇒ static_error STR ''Undefined function'')"

(* Integer operation denoted by a binary instruction. *)
fun binop_instruction :: "binop_instruction ⇒ int ⇒ int ⇒ int" where
  "binop_instruction Add = (+)"
| "binop_instruction Sub = (-)"
| "binop_instruction Mul = (*)"
| "binop_instruction Xor = (bitXOR)"

(* Executes one named instruction in the current frame.
   Binop, Select and Icmp store their result in register n and advance the
   position; Call pushes a new frame. Unnamed (Do) instructions are rejected
   with a static error. *)
fun run_instruction :: "instruction named ⇒ llvm_state error" where
  "run_instruction i =
    (case i of
      Named n i ⇒
        (let g = (λo1 o2 h. update_frames_stack n (operand_binop o1 o2 h))  in
        (case i of
          Binop binop o1 o2 ⇒ g o1 o2 (binop_llvm_constant (binop_instruction binop))  |
          Select c o1 o2 ⇒ update_frames_stack n (operand_select c o1 o2) |
          Icmp c o1 o2 ⇒ g o1 o2 (λi1 i2. return (icmp_llvm_constant c i1 i2)) |
          Call t fn ps ⇒ call_function fn ps))
      | Do i ⇒ static_error STR ''Unnamed operation not yet supported'')"
(*
fun run_instruction' :: "instruction named ⇒ llvm_state error" where
  "run_instruction' i =
    (case i of
      Named n i ⇒
        (let g = (λo1 o2 h. update_frames_stack n (operand_binop o1 o2 h))  in
        (case i of
          Binop binop o1 o2 ⇒ g o1 o2 (binop_llvm_constant (binop_instruction binop))  |
          Select c o1 o2 ⇒ update_frames_stack n (operand_select c o1 o2) |
          Icmp c o1 o2 ⇒ g o1 o2 (λi1 i2. return (icmp_llvm_constant c i1 i2)) |
          Call t n ps ⇒ call_function n ps ⤜ (λf. return (Llvm_state [f]))))
      | Do i ⇒ static_error STR ''Unnamed operation not yet supported'')"
*)

(* Selects, from the (operand, block name) pairs of a phi node, the operand
   associated with the predecessor block old_b_id. *)
definition phi_bid where "phi_bid old_b_id ps = map_of (map prod.swap ps) old_b_id"

(* Value of one phi node when coming from block old_b_id; a static error if the
   phi node has no entry for that block. *)
fun compute_phi where
  "compute_phi old_b_id xs =
     do {x ← option_to_error (phi_bid old_b_id xs) STR ''Previous block not found in phi expression'';
         operand_value x}"

(* Evaluates a list of phi nodes for predecessor old_b_id and returns the
   resulting (register, value) pairs in the same order. All values are computed
   against the stack of the current frame. *)
fun compute_phis :: "name ⇒ phi list ⇒ _ error" where
  "compute_phis old_b_id ((a, ps)#as) =
     do {
       c ← compute_phi old_b_id ps;
       s ← compute_phis old_b_id as;
       return ((a,c)#s)}" |
  "compute_phis _ [] = return []"

(*
Searches for the next block, computes its phis and jumps to the first line
after the phis: the returned frame has its index set to the number of phi nodes
and its stack extended with the phi results.
*)
fun update_bid_frame :: "name ⇒ frame error" where
  "update_bid_frame new_b_id = do {
    let (func_id, old_b_id, _) = c_pos;
    φs ← (find_phis lc func_id new_b_id);
    s ← compute_phis old_b_id φs;
    let s' = foldr (λ(k,v). Mapping.update k v) s (stack cf);
    return (Frame (func_id, new_b_id, length φs) s')
}"

(* Resumes the caller frame f' after a return with value c1. i is the action at
   the caller's position: if it is a named Call, its register receives c1 and
   the caller's position is advanced; the new state consists of that frame on
   top of fs'. Every other action yields the static error ''Implement me''. *)
fun ret_from_frame :: "action ⇒ llvm_constant ⇒ frame ⇒ frame list ⇒ llvm_state error" where
  "ret_from_frame i c1 f' fs' =
    (case i of
      Instruction (Named n (Call _ _ _)) ⇒ do {let f' = update_pos inc_pos (update_stack f' n c1);
                                                return (Llvm_state (f'#fs'))} |
      _ ⇒ static_error STR ''Implement me'')"

(* Chooses the branch target of a conditional branch: id_t for the constant
   IntConstant 1 1, id_f for IntConstant 1 0, a static error otherwise. *)
fun condBr_to_frame where
  "condBr_to_frame (IntConstant l i) id_t id_f =
    (if l = 1 ∧ i = 1 then Inr id_t
     else if l = 1 ∧ i = 0 then Inr id_f
     else static_error STR ''condBr operand not of type i1'')"


(*
NOTE: terminate_frame is kind of a misnomer here:
the frame is updated in case of branches and only terminated in case of a return.
A Ret with a value and no caller frame left gets stuck with ReturnValue;
a Ret without a value is not supported.
*)
definition terminate_frame :: "terminator named ⇒ llvm_state error" where
  "terminate_frame t = do {
    (case named_instruction t of
      (CondBr c n1 n2) ⇒
        do { c ← operand_value c;
             n_id ← condBr_to_frame c n1 n2;
             nf ← update_bid_frame n_id;
             update_frame nf }
     | Br n1 ⇒ update_bid_frame n1 ⤜ update_frame
     | Ret (Some o1) ⇒ do {c ← operand_value o1;
                            (f', fs') ← (case fs of [] ⇒ error (ReturnValue c) | f'# fs' ⇒ return (f', fs'));
                            i ← find_statement lc (pos f');
                            ret_from_frame i c f' fs'}
     | _ ⇒ static_error STR ''No support for empty return statement'')}"


end


(* Performs one execution step of program lf on state ls.
   The action at the position of the top frame is looked up and executed in the
   small_step context instantiated with that frame and the remaining frames.
   A state without frames yields Program_Termination. If the action lookup
   fails, the specific failure is replaced by a generic static error. *)
fun step :: "llvm_prog ⇒ llvm_state ⇒ llvm_state error" where
  "step lf ls = (case frames ls of
      (f#fs) ⇒
        (case find_statement lf (pos f) of
          Inr (Terminator t)  ⇒ small_step.terminate_frame lf ls f fs t |
          Inr (Instruction i) ⇒ small_step.run_instruction lf ls f fs i |
          _ ⇒ static_error STR ''Can't find next instruction'') |
     [] ⇒ Inl Program_Termination)"

(* The step function as a relation on states: (bef, aft) is related iff step
   succeeds on bef with result aft. Stuck steps contribute no pair. *)
definition step_relation :: "llvm_prog ⇒ llvm_state rel" where
  "step_relation prog = { (bef, aft) . step prog bef = Inr aft }"

(* Diagnostic commands only; they do not change the theory.
   The second one checks that the reflexive-transitive closure of the step
   relation is a well-formed term. *)
find_theorems name: closure

term "(step_relation prog)\<^sup>*"


(* Register the defining equations of the small_step locale constants as code
   equations, so that they are available to the code generator outside the
   locale. *)
declare small_step.map_of_funs_def [code]
declare small_step.lookup_block.simps [code]
declare small_step.operand_value.simps [code]
declare small_step.update_stack.simps [code]
declare small_step.update_frame.simps [code]
declare small_step.operand_binop.simps [code]
declare small_step.operand_select.simps [code]
declare small_step.update_frames_stack.simps [code]
declare small_step.zip_parameters.simps [code]
declare small_step.call_function.simps [code]
declare small_step.run_instruction.simps [code]
declare small_step.update_bid_frame.simps [code]
declare small_step.ret_from_frame.simps [code]
declare small_step.terminate_frame_def [code]
declare small_step.phi_bid_def [code]
declare small_step.enter_frame.simps [code]

(* Same for the remaining locale constants. *)
declare
  small_step.binop_instruction.simps [code]
  small_step.compute_phi.simps [code]
  small_step.compute_phis.simps [code]
  small_step.condBr_to_frame.simps [code]
  small_step.c_pos_def [code]

(* Runs at most the given number of steps.
   With a budget of 0 the current state is returned. Otherwise one step is
   taken: a stuck step ends the run with its reason, a successful one continues
   from the new state with the budget reduced by one. *)
fun step_by_step :: "nat ⇒ llvm_prog ⇒ llvm_state ⇒ llvm_state error" where
  "step_by_step 0 lc ls = return ls" |
  "step_by_step (Suc i) lc ls =
      (case step lc ls of
         Inl e ⇒ Inl e
       | Inr ls' ⇒ step_by_step i lc ls')"

(* Unbounded variant of step_by_step: keeps stepping until a step gets stuck and
   then returns Some (Inl reason). It is defined as a partial function in the
   option monad because the iteration need not terminate; the defining equation
   never returns an Inr result. *)
partial_function (option) step_by_step' :: "llvm_prog ⇒ llvm_state ⇒ (stuck + llvm_state) option" where
  "step_by_step' lc ls =
    (case step lc ls of Inr ls ⇒ step_by_step' lc ls | Inl e ⇒ Some (Inl e))"


(* Context for reasoning about one successful step: program prog takes state cs1
   to state cs2. *)
locale successful_step =
  fixes prog cs1 cs2
  assumes step_Inr: "step prog cs1 = Inr cs2"
begin


(* Top frame and remaining frames of the state before the step. *)
definition f1 where "f1 = hd (frames cs1)"
definition fs1 where "fs1 = tl (frames cs1)"

(* The frame list of cs1 is non-empty, since step succeeded. *)
lemma frames_cs1: "frames cs1 = f1 # fs1"
  using step_Inr unfolding f1_def fs1_def
  by (auto split: list.splits option.splits)


(* cs1 reconstructed from its frames. *)
lemma cs1: "cs1 = Llvm_state (f1 # fs1)"
  using frames_cs1 by (metis llvm_state.collapse)

(* Makes the small_step definitions available for exactly this step. *)
sublocale small_step prog cs1 f1 fs1 .

(* The action executed by the step. *)
definition c_action where "c_action = projr (find_statement prog (pos f1))"

(* The action lookup at the position of f1 succeeds and yields c_action. *)
lemma c_action_def': "find_statement prog (pos f1) = Inr c_action"
  using step_Inr unfolding c_action_def c_pos_def
  by (auto simp add: frames_cs1 split: list.splits option.splits sum.splits)

end

(* Extends successful_step with names for the three components of the position
   of the top frame: function name, block name and index within the block. *)
locale successful_step' = successful_step +
  fixes f_id b_id b_pos
  assumes pos_f1: "pos f1 = (f_id, b_id, b_pos)"
begin

(* The current position of the small_step context in terms of its components. *)
lemma c_pos: "c_pos = (f_id, b_id, b_pos)"
  unfolding c_pos_def pos_f1 by simp

(* The function and the block in which the step takes place. *)
definition c_fun where "c_fun = projr (find_fun prog f_id)"
definition c_block where "c_block = projr (find_block c_fun b_id)"


(* The function lookup succeeds and yields c_fun. *)
lemma c_fun_def': "find_fun prog f_id = Inr c_fun"
  using c_action_def' unfolding c_fun_def pos_f1
  by (auto intro!: option.collapse split: Option.bind_splits)

(* The combined function/block lookup succeeds and yields c_block. *)
lemma c_block_def': "find_fun_block prog f_id b_id = Inr c_block"
  using c_action_def' unfolding c_fun_def c_block_def pos_f1
  by (auto intro!: option.collapse split: Option.bind_splits sum_bind_splits)

(* The block lookup inside c_fun succeeds and yields c_block. *)
lemma c_block_def'': "find_block c_fun b_id = Inr c_block"
  using c_action_def' unfolding c_fun_def c_block_def pos_f1
  by (auto intro!: option.collapse split: Option.bind_splits sum_bind_splits)

(* The executed action is the one found at index b_pos of c_block. *)
lemma c_find_action: "find_action c_block b_pos = Inr c_action"
  using c_action_def' c_block_def' unfolding pos_f1
  by (auto simp del: find_fun_block.simps)

(* If the executed action is a terminator, the step is terminate_frame on it. *)
lemma c_action_Terminator:
  "c_action = Terminator t ⟹ step prog cs1 = terminate_frame t"
  by (auto simp add: c_action_def' frames_cs1)

end

(* A function found under key n in map_of_funs carries n as its own name.

   TODO: move to map_of_funs definition or more fitting location *)
lemma map_of_funs_name:
  assumes "small_step.map_of_funs prog n = Some (Function tg g psg bg bsg)"
  shows "n = g"
  using assms unfolding small_step.map_of_funs_def using map_of_SomeD by fastforce

(* Runs program p from the function named Name s until a step gets stuck.
   Execution starts at index 0 of that function's first block with an empty
   stack. If the entry function is external or not defined, the run ends
   immediately with the corresponding stuck reason instead of continuing from
   an unspecified block name. *)
definition run_prog where
  "run_prog p s =
    (case small_step.map_of_funs p (Name s) of
       Some (Function _ _ _ b _) ⇒
         step_by_step' p (Llvm_state [Frame (Name s, basic_block.name b, 0) Mapping.empty])
     | Some (ExternalFunction _ fn _) ⇒ Some (Inl (ExternalFunctionCall fn))
     | None ⇒ Some (static_error STR ''Undefined function''))"

end