(* Theory Error_Monad *)
(* Title:     Error_Monad
   Author:    Christian Sternagel
   Author:    René Thiemann
*)

section ‹The Sum Type as Error Monad›

theory Error_Monad
imports
  "HOL-Library.Monad_Syntax"
  Error_Syntax
begin

text ‹Make monad syntax (including do-notation) available for the sum type.›
(* Monadic bind for the sum type: a successful value Inr x is passed to the
   continuation f, while an error Inl e is propagated unchanged. *)
definition bind :: "'e + 'a ⇒ ('a ⇒ 'e + 'b) ⇒ 'e + 'b"
where
  "bind m f = (case m of Inr x ⇒ f x | Inl e ⇒ Inl e)"

(* Register bind as an instance of the generic Monad_Syntax.bind, so that
   ⤜, ⪢ and do-notation can be used on sum-typed terms below. *)
adhoc_overloading
  Monad_Syntax.bind bind

(* Input-only abbreviations: return builds a successful result, error a
   failed one, and run extracts the value of a result via projr.
   The run lemmas below assume isOK m. *)
abbreviation (input) "return ≡ Inr"
abbreviation (input) "error ≡ Inl"
abbreviation (input) "run ≡ projr"


subsection ‹Monad Laws›

(* Left identity: binding a returned value applies the continuation to it. *)
lemma return_bind [simp]:
  "(return x ⤜ f) = f x"
  by (simp add: bind_def)

(* Right identity: binding into return leaves m unchanged. *)
lemma bind_return [simp]:
  "(m ⤜ return) = m"
  by (cases m) (simp_all add: bind_def)

(* An error short-circuits bind; the continuation is discarded. *)
lemma error_bind [simp]:
  "(error e ⤜ f) = error e"
  by (simp add: bind_def)

(* Associativity. As a simp rule it re-nests binds to the right. *)
lemma bind_assoc [simp]:
  fixes m :: "'a + 'b"
  shows "((m ⤜ f) ⤜ g) = (m ⤜ (λx. f x ⤜ g))"
  by (cases m) (simp_all add: bind_def)

(* Congruence rule for the function package ([fundef_cong]): the two
   continuations only need to agree on values y with m2 = Inr y. *)
lemma bind_cong [fundef_cong]:
  fixes m1 m2 :: "'e + 'a"
    and f1 f2 :: "'a ⇒ 'e + 'b"
  assumes "m1 = m2"
    and "⋀y. m2 = Inr y ⟹ f1 y = f2 y"
  shows "(m1 ⤜ f1) = (m2 ⤜ f2)"
  using assms by (cases "m1") (auto simp: bind_def)

(* Error handler. For an error Inl e the handler f is run; it may recover
   (Inr) or raise a new error whose type 'f may differ from 'e. A successful
   Inr x is kept unchanged. The defining equation is named catch_def. *)
definition catch_error :: "'e + 'a ⇒ ('e ⇒ 'f + 'a) ⇒ 'f + 'a"
where
  catch_def: "catch_error m f = (case m of Inl e ⇒ f e | Inr x ⇒ Inr x)"

(* Hook catch_error into Error_Syntax.catch, which provides the
   "try m catch f" notation used below. *)
adhoc_overloading
  Error_Syntax.catch catch_error

(* Case-split rules for try/catch. The first form states P of the result for
   both outcomes of m. The second is its negated form, in the shape of a
   split_asm rule for use on assumptions. *)
lemma catch_splits:
  "P (try m catch f) ⟷ (∀e. m = Inl e ⟶ P (f e)) ∧ (∀x. m = Inr x ⟶ P (Inr x))"
  "P (try m catch f) ⟷ (¬ ((∃e. m = Inl e ∧ ¬ P (f e)) ∨ (∃x. m = Inr x ∧ ¬ P (Inr x))))"
  by (case_tac [!] m) (simp_all add: catch_def)

(* Map an error through f and leave successful results unchanged. It is
   implemented as a catch whose handler always re-raises (error (f x)). *)
abbreviation update_error :: "'e + 'a ⇒ ('e ⇒ 'f) ⇒ 'f + 'a"
where
  "update_error m f ≡ try m catch (λx. error (f x))"

(* Exposed through Error_Syntax.update_error, written m <+? f below. *)
adhoc_overloading
  Error_Syntax.update_error update_error

(* Catching on a successful result is a no-op. *)
lemma catch_return [simp]:
  "(try return x catch f) = return x" by (simp add: catch_def)

(* Catching on an error runs the handler on that error. *)
lemma catch_error [simp]:
  "(try error e catch f) = f e" by (simp add: catch_def)

(* Mapping the error never changes whether, or with what value, m succeeds. *)
lemma update_error_return [simp]:
  "(m <+? c = return x) ⟷ (m = return x)"
  by (cases m) simp_all

(* isOK m holds iff m is a successful (Inr) result. *)
definition "isOK m ⟷ (case m of Inl e ⇒ False | Inr x ⇒ True)"

(* Elimination rule: a successful m is of the form return x. *)
lemma isOK_E [elim]:
  assumes "isOK m"
  obtains x where "m = return x"
  using assms by (cases m) (simp_all add: isOK_def)

(* Introduction rule, also used by simp. *)
lemma isOK_I [simp, intro]:
  "m = return x ⟹ isOK m"
  by (cases m) (simp_all add: isOK_def)

(* Characterisation; blast uses the declared isOK_E / isOK_I rules. *)
lemma isOK_iff:
  "isOK m ⟷ (∃x. m = return x)"
  by blast

(* An error is never OK. *)
lemma isOK_error [simp]:
  "isOK (error x) = False"
  by blast

(* A bind succeeds iff m succeeds and the continuation succeeds on m's value. *)
lemma isOK_bind [simp]:
  "isOK (m ⤜ f) ⟷ isOK m ∧ isOK (f (run m))"
  by (cases m) simp_all

(* Mapping the error does not affect success. *)
lemma isOK_update_error [simp]:
  "isOK (m <+? f) ⟷ isOK m"
  by (cases m) simp_all

(* Push isOK inside a pair case expression; an instance of prod.case_distrib. *)
lemma isOK_case_prod [simp]:
  "isOK (case lr of (l, r) ⇒ P l r) = (case lr of (l, r) ⇒ isOK (P l r))"
  by (rule prod.case_distrib)

(* Push isOK into both branches of an option case expression. Proved as an
   instance of the datatype's case_distrib theorem, like isOK_case_prod. *)
lemma isOK_case_option [simp]:
  "isOK (case x of None ⇒ P | Some v ⇒ Q v) =
    (case x of None ⇒ isOK P | Some v ⇒ isOK (Q v))"
  by (rule option.case_distrib)

(* Unfold Let under isOK. *)
lemma isOK_Let [simp]:
  "isOK (Let s f) = isOK (f s)"
  by (simp add: Let_def)

(* For a successful m, run commutes with bind. *)
lemma run_bind [simp]:
  "isOK m ⟹ run (m ⤜ f) = run (f (run m))"
  by auto

(* For a successful m, the handler is never run. *)
lemma run_catch [simp]:
  "isOK m ⟹ run (try m catch f) = run m"
  by auto

(* Monadic left fold: thread the accumulator through f element by element.
   The first error returned by f aborts the fold and is propagated. *)
fun foldM :: "('a ⇒ 'b ⇒ 'e + 'a) ⇒ 'a ⇒ 'b list ⇒ 'e + 'a"
where
  "foldM f d [] = return d"
| "foldM f d (x # xs) = f d x ⤜ (λy. foldM f y xs)"

(* Helper for forallM_index. It checks P x i for each element x, where i starts
   at the given index and is incremented per element. The first failure is
   reported as ((x, i), e), with e the error produced by P. *)
fun forallM_index_aux :: "('a ⇒ nat ⇒ 'e + unit) ⇒ nat ⇒ 'a list ⇒ (('a × nat) × 'e) + unit"
where
  "forallM_index_aux P i [] = return ()" |
  "forallM_index_aux P i (x # xs) = do {
    P x i <+? Pair (x, i);
    forallM_index_aux P (Suc i) xs
  }"

(* forallM_index_aux P n xs succeeds iff P succeeds on every element, where
   xs ! i is checked with index i + n. *)
lemma isOK_forallM_index_aux [simp]:
  "isOK (forallM_index_aux P n xs) = (∀i < length xs. isOK (P (xs ! i) (i + n)))"
proof (induct xs arbitrary: n)
  case (Cons x xs)
  (* Split the quantifier into the head (index n) and the tail, whose indices
     are shifted to start at Suc n. *)
  have "(∀i < length (x # xs). isOK (P ((x # xs) ! i) (i + n))) ⟷
    (isOK (P x n) ∧ (∀i < length xs. isOK (P (xs ! i) (i + Suc n))))"
    by (auto, case_tac i) (simp_all)
  (* Fold the tail condition back using the induction hypothesis at Suc n. *)
  then show ?case
    unfolding Cons [of "Suc n", symmetric] by simp
qed auto

(* Check P on every element together with its position, counting from 0.
   The first failure is reported with its element, index and error. *)
definition forallM_index :: "('a ⇒ nat ⇒ 'e + unit) ⇒ 'a list ⇒ (('a × nat) × 'e) + unit"
where
  "forallM_index P xs = forallM_index_aux P 0 xs"

(* Succeeds iff P succeeds on every element at its own index. *)
lemma isOK_forallM_index [simp]:
  "isOK (forallM_index P xs) ⟷ (∀i < length xs. isOK (P (xs ! i) i))"
  unfolding forallM_index_def isOK_forallM_index_aux by simp

(* Congruence rule for the function package ([fundef_cong]): the checks c and
   d only need to agree on elements of xs (at every index). The list argument
   xs is the same on both sides of the conclusion. *)
lemma forallM_index [fundef_cong]:
  fixes c :: "'a ⇒ nat ⇒ 'e + unit"
  assumes "⋀x i. x ∈ set xs ⟹ c x i = d x i"
  shows "forallM_index c xs = forallM_index d xs"
proof -
  (* Generalise over the starting index n so the induction on xs goes through. *)
  { fix n
    have "forallM_index_aux c n xs = forallM_index_aux d n xs"
      using assms by (induct xs arbitrary: n) simp_all }
  then show ?thesis by (simp add: forallM_index_def)
qed

(* The auxiliary function is an implementation detail; hide its name. *)
hide_const forallM_index_aux

text ‹
  Check whether @{term f} succeeds for all elements of a given list. In case it doesn't,
  return the first offending element together with the produced error.
›
(* Each failure of f is tagged with its element via Pair x. The do-block stops
   at the first failure, so only the first offending element is reported. *)
fun forallM :: "('a ⇒ 'e + unit) ⇒ 'a list ⇒ ('a × 'e) + unit"
where
  "forallM f [] = return ()"
| "forallM f (x # xs) = do {
    f x <+? Pair x;
    forallM f xs
  }"

(* forallM succeeds iff f succeeds on every element of the list. *)
lemma isOK_forallM [simp]:
  "isOK (forallM f xs) ⟷ (∀x ∈ set xs. isOK (f x))"
  by (induct xs) (simp_all)

text ‹
  Check whether @{term f} succeeds for at least one element of a given list.
  In case it doesn't, return the list of produced errors.
›
(* Succeed as soon as f succeeds on some element. If f fails on all elements,
   their errors are collected in list order (each prepended with Cons). *)
fun existsM :: "('a ⇒ 'e + unit) ⇒ 'a list ⇒ 'e list + unit"
where
  "existsM f [] = error []" |
  "existsM f (x # xs) = (try f x catch (λe. existsM f xs <+? Cons e))"

(* existsM succeeds iff f succeeds on at least one element. *)
lemma isOK_existsM [simp]:
  "isOK (existsM f xs) ⟷ (∃x∈set xs. isOK (f x))"
proof (induct xs)
  case (Cons x xs)
  show ?case
  proof (cases "f x")
    (* f x fails: the handler recurses on xs, covered by the IH. *)
    case (Inl e)
    with Cons show ?thesis by simp
  (* The remaining case (f x succeeds) is closed by unfolding catch_def. *)
  qed (auto simp add: catch_def)
qed simp

(* isOK of a conditional with a return branch: it succeeds iff the condition
   selects the return branch, or the other branch succeeds. *)
lemma is_OK_if_return [simp]:
  "isOK (if b then return x else m) ⟷ b ∨ isOK m"
  "isOK (if b then m else return x) ⟷ ¬ b ∨ isOK m"
  by simp_all

(* Alias following the isOK_* naming used by the neighbouring lemmas. The
   original name is kept for backward compatibility. *)
lemmas isOK_if_return = is_OK_if_return

(* isOK of a conditional with an error branch: it succeeds iff the condition
   selects the other branch and that branch succeeds. *)
lemma isOK_if_error [simp]:
  "isOK (if b then error e else m) ⟷ ¬ b ∧ isOK m"
  "isOK (if b then m else error e) ⟷ b ∧ isOK m"
  by simp_all

(* General case split of isOK over a conditional (not a simp rule). *)
lemma isOK_if:
  "isOK (if b then x else y) ⟷ b ∧ isOK x ∨ ¬ b ∧ isOK y"
  by simp

(* Run a list of computations left to right and collect their results. The
   first error encountered is returned. *)
fun sequence :: "('e + 'a) list ⇒ 'e + 'a list"
where
  "sequence [] = return []"
| "sequence (m # ms) = m ⤜ (λx. sequence ms ⤜ (λxs. return (x # xs)))"


subsection ‹Monadic Map for Error Monad›

(* Apply f to every element, left to right, and collect the results. The
   first error produced by f is returned. *)
fun mapM :: "('a ⇒ 'e + 'b) ⇒ 'a list ⇒ 'e + 'b list"
where
  "mapM f [] = return []"
| "mapM f (x # xs) = do { y ← f x; ys ← mapM f xs; return (y # ys) }"

(* mapM fails iff f fails on some element of the list. *)
lemma mapM_error:
  "(∃e. mapM f xs = error e) ⟷ (∃x∈set xs. ∃e. f x = error e)"
proof (induct xs)
  case (Cons x xs)
  (* Case on f x, then on the recursive call. *)
  then show ?case
    by (cases "f x", simp_all, cases "mapM f xs", simp_all)
qed simp

(* If mapM succeeds with ys, then ys is obtained by running f on every
   element, and f fails on none of them. *)
lemma mapM_return:
  assumes "mapM f xs = return ys"
  shows "ys = map (run ∘ f) xs ∧ (∀x∈set xs. ∀e. f x ≠ error e)"
using assms
proof (induct xs arbitrary: ys)
  case (Cons x xs ys)   
  then show ?case
    by (cases "f x", simp, cases "mapM f xs", simp_all)
qed simp

(* Index-wise form of mapM_return: for a valid index i, f succeeds on xs ! i
   and its result is the i-th entry of ys. *)
lemma mapM_return_idx:
  assumes *: "mapM f xs = Inr ys" and "i < length xs" 
  shows "∃y. f (xs ! i) = Inr y ∧ ys ! i = y"
proof -
  (* mapM_return, with membership in set xs restated via indices. *)
  note ** = mapM_return [OF *, unfolded set_conv_nth]
  with assms have "⋀e. f (xs ! i) ≠ Inl e" by auto
  then obtain y where "f (xs ! i) = Inr y" by (cases "f (xs ! i)") auto
  (* ys = map (run ∘ f) xs, so its i-th entry is y. *)
  then have "f (xs ! i) = Inr y ∧ ys ! i = y" unfolding ** [THEN conjunct1] using assms by auto
  then show ?thesis ..
qed

(* Congruence rule for the function package ([fundef_cong]): f and g only
   need to agree on the elements of the list. *)
lemma mapM_cong [fundef_cong]:
  assumes "xs = ys" and "⋀x. x ∈ set ys ⟹ f x = g x"
  shows "mapM f xs = mapM g ys"
  unfolding assms(1) using assms(2) by (induct ys) auto

(* Elimination rule: if a bind returns x, then p returned some y and the
   continuation returns x on y. *)
lemma bindE [elim]:
  assumes "(p ⤜ f) = return x"
  obtains y where "p = return y" and "f y = return x"
  using assms by (cases p) simp_all

(* p ⪢ q returns the value f iff p succeeds and q returns f.
   Note that f names a result value here, not a function. *)
lemma then_return_eq [simp]:
  "(p ⪢ q) = return f ⟷ isOK p ∧ q = return f"
  by (cases p) simp_all

(* Return the result of the first successful computation in the list. If all
   fail, their errors are collected in list order. *)
fun choice :: "('e + 'a) list ⇒ 'e list + 'a"
where
  "choice [] = error []" |
  "choice (x # xs) = (try x catch (λe. choice xs <+? Cons e))"

(* The defining equations are removed from the simp set. *)
declare choice.simps [simp del]

(* If mapM succeeds, then f succeeds on every element, and the result is the
   list of the individual results. *)
lemma isOK_mapM:
  assumes "isOK (mapM f xs)"
  shows "(∀x. x ∈ set xs ⟶ isOK (f x)) ∧ run (mapM f xs) = map (λx. run (f x)) xs"
  using assms mapM_return[of f xs] by (force simp: isOK_def split: sum.splits)+

(* Return the first element x of the list for which f x succeeds (f's own
   result value is discarded). If there is no such element, the errors
   produced by f are returned in list order. The explicit signature is the
   most general type of the equations, stated for consistency with the other
   definitions in this theory. *)
fun firstM :: "('a ⇒ 'e + 'b) ⇒ 'a list ⇒ 'e list + 'a"
where
  "firstM f [] = error []"
| "firstM f (x # xs) = (try f x ⪢ return x catch (λe. firstM f xs <+? Cons e))"

(* firstM succeeds iff f succeeds on some element of the list. *)
lemma firstM:
  "isOK (firstM f xs) ⟷ (∃x∈set xs. isOK (f x))"
  by (induct xs) (auto simp: catch_def split: sum.splits)

(* A successful firstM returns an element of the list on which f succeeds. *)
lemma firstM_return:
  assumes "firstM f xs = return y"
  shows "isOK (f y) ∧ y ∈ set xs"
  using assms by (induct xs) (auto simp: catch_def split: sum.splits)


end