Theory Check_Monad

theory Check_Monad
imports Error_Monad
(* Title:     Check_Monad
   Author:    Christian Sternagel
   Author:    René Thiemann
*)

section ‹A Special Error Monad for Certification with Informative Error Messages›

theory Check_Monad
imports Error_Monad
begin

text ‹A check is either successful or fails with some error.›
(* Result type of a check: the left summand carries an error of type 'e,
   the right summand (unit) signals success. *)
type_synonym 'e check = "'e + unit"

(* The check that always succeeds. *)
abbreviation succeed :: "'e check" where
  "succeed ≡ return ()"

(* check b e succeeds if the condition b holds and fails with error e otherwise. *)
definition check :: "bool ⇒ 'e ⇒ 'e check" where
  "check b e = (if b then succeed else error e)"

(* A check is OK exactly when its condition holds. *)
lemma isOK_check [simp]:
  "isOK (check b e) = b"
  by (simp add: check_def)

(* Guarding check b e with handler f succeeds iff b holds or the handler
   succeeds on the error e. *)
lemma isOK_check_catch [simp]:
  "isOK (try check b e catch f) ⟷ b ∨ isOK (f e)"
  by (auto simp: catch_def check_def)

(* Run the check chk; if it succeeds, return res, otherwise propagate chk's error. *)
definition check_return :: "'a check ⇒ 'b ⇒ 'a + 'b" where
  "check_return chk res = (chk ⪢ return res)"

(* check_return yields return res' exactly when chk is OK and res' is the given result. *)
lemma check_return [simp]:
  "check_return chk res = return res' ⟷ isOK chk ∧ res' = res"
  unfolding check_return_def
  by (cases chk) auto

(* Code-generator preprocessing rule ([code_unfold]): replaces check_return by an explicit
   case distinction on the check. If chk succeeds (Inr), the result is Inr res; if it fails
   (Inl e), the error e is passed on unchanged. Proved by unfolding check_return_def and
   bind_def. *)
lemma [code_unfold]:
  "check_return chk res = (case chk of Inr _ ⇒ Inr res | Inl e ⇒ Inl e)"
  unfolding check_return_def bind_def ..

(* Monadic check of f on all elements of xs. The error produced by forallM is mapped
   with snd before being returned (presumably dropping the offending element and keeping
   only f's error; confirm against Error_Monad). *)
abbreviation check_allm :: "('a ⇒ 'e check) ⇒ 'a list ⇒ 'e check" where
  "check_allm f xs ≡ forallM f xs <+? snd"

(* Existential counterpart built on existsM. A failure's error list is folded into a
   single error by fld. Presumably it succeeds when f succeeds on at least one element;
   confirm against existsM. *)
abbreviation check_exm :: "('a ⇒ 'e check) ⇒ 'a list ⇒ ('e list ⇒ 'e) ⇒ 'e check" where
  "check_exm f xs fld ≡ existsM f xs <+? fld"

(* check_allm succeeds iff f succeeds on every element of xs. *)
lemma isOK_check_allm:
  "isOK (check_allm f xs) ⟷ (∀x ∈ set xs. isOK (f x))"
  by simp

(* Indexed variant of check_allm: f additionally receives the position of each element
   (cf. isOK_check_all_index, where element xs ! i is checked with index i). As with
   check_allm, the error of forallM_index is mapped with snd. *)
abbreviation check_allm_index :: "('a ⇒ nat ⇒ 'e check) ⇒ 'a list ⇒ 'e check" where
  "check_allm_index f xs ≡ forallM_index f xs <+? snd"

(* Boolean version of check_allm: test predicate f on the elements of xs, reporting a
   failing element x itself as the error. *)
abbreviation check_all :: "('a ⇒ bool) ⇒ 'a list ⇒ 'a check" where
  "check_all f xs ≡ check_allm (λx. if f x then succeed else error x) xs"

(* Boolean, indexed version: test f x i for each element x at index i of xs; a failure
   is reported as the pair (x, i). *)
abbreviation check_all_index :: "('a ⇒ nat ⇒ bool) ⇒ 'a list ⇒ ('a × nat) check" where
  "check_all_index f xs ≡
    check_allm_index (λx i. if f x i then succeed else error (x, i)) xs"

(* check_all_index succeeds iff f holds for every element together with its index. *)
lemma isOK_check_all_index [simp]:
  "isOK (check_all_index f xs) ⟷ (∀i < length xs. f (xs ! i) i)"
  by auto

text ‹The following version allows the index to be modified during the check: after each element, the next index is computed from that element and the current index.›
(* Checks f on every element of xs while threading an index through a left fold.
   - The index starts at n; after element x was checked at index i, the next index is g x i.
   - The individual checks are chained with ⪢, starting from succeed.
   - Once a check fails, that error is kept for the rest of the fold (see foldl_error below). *)
definition check_allm_gen_index ::
    "('a ⇒ nat ⇒ nat) ⇒ ('a ⇒ nat ⇒ 'e check) ⇒ nat ⇒ 'a list ⇒ 'e check"
where
  "check_allm_gen_index g f n xs =
    snd (foldl (λ(i, m) x. (g x i, m ⪢ f x i)) (n, succeed) xs)"

(* If the fold starts from an error e, it still ends with that error e,
   whatever the start index and the remaining elements. *)
lemma foldl_error:
  "snd (foldl (λ(i, m) x. (g x i, m ⪢ f x i)) (n, error e) xs) = error e"
  by (induct xs arbitrary: n) auto

(* If the generalized indexed check succeeds, every element of xs passes f at some index.
   Only the existence of an index is stated, because the index at which an element is
   actually checked depends on g. The proof is by induction on xs, generalizing the
   start index n. *)
lemma isOK_check_allm_gen_index [simp]:
  assumes "isOK (check_allm_gen_index g f n xs)"
  shows "∀x∈set xs. ∃i. isOK (f x i)"
using assms
proof (induct xs arbitrary: n)
  case (Cons x xs)
  show ?case
  proof (cases "isOK (f x n)")
    case True
    (* the head passes at index n; the tail is covered by the induction hypothesis *)
    then have "∃i. isOK (f x i)" by auto
    with True Cons show ?thesis
      unfolding check_allm_gen_index_def by (force simp: isOK_iff)
  next
    case False
    (* the head fails at index n, so by foldl_error the whole fold ends in that error,
       which contradicts the assumption that the check is OK *)
    then obtain e where "f x n = error e" by (cases "f x n") auto
    with foldl_error [of g f _ e] and Cons show ?thesis
      unfolding check_allm_gen_index_def by auto
  qed
qed simp

(* Congruence rule, registered with the function package via [fundef_cong].
   - check_allm_gen_index only depends on g and f at elements of xs.
   - Hence recursive definitions that go through it may assume x ∈ set xs for the
     arguments passed to g and f. *)
lemma check_allm_gen_index [fundef_cong]:
  fixes f :: "'a ⇒ nat ⇒ 'e check"
  assumes "⋀x n. x ∈ set xs ⟹ g x n = g' x n"
    and "⋀x n. x ∈ set xs ⟹ f x n = f' x n"
  shows "check_allm_gen_index g f n xs = check_allm_gen_index g' f' n xs"
proof -
  (* prove equality of the folds for an arbitrary start state (index n, accumulated check m),
     so that the induction over xs can generalize both *)
  { fix n m
    have "foldl (λ(i, m) x. (g x i, m ⪢ f x i)) (n, m) xs =
      foldl (λ(i, m) x. (g' x i, m ⪢ f' x i)) (n, m) xs"
      using assms by (induct xs arbitrary: n m) auto }
  then show ?thesis unfolding check_allm_gen_index_def by simp
qed

(* Checks that every element of xs also occurs in ys; an element of xs missing from ys
   is reported as the error. *)
definition check_subseteq :: "'a list ⇒ 'a list ⇒ 'a check" where
  "check_subseteq xs ys = check_all (λx. x ∈ set ys) xs"

(* check_subseteq succeeds iff the elements of xs form a subset of those of ys. *)
lemma isOK_check_subseteq [simp]:
  "isOK (check_subseteq xs ys) ⟷ set xs ⊆ set ys"
  by (auto simp add: check_subseteq_def)

(* Checks that xs and ys contain the same elements. Both inclusions are checked in
   sequence with ⪢, xs ⊆ ys first. *)
definition check_same_set :: "'a list ⇒ 'a list ⇒ 'a check" where
  "check_same_set xs ys = (check_subseteq xs ys ⪢ check_subseteq ys xs)"

(* check_same_set succeeds iff both lists have the same element set. *)
lemma isOK_check_same_set [simp]:
  "isOK (check_same_set xs ys) ⟷ set xs = set ys"
  unfolding check_same_set_def
  by auto

(* Checks that no element of xs occurs in ys; an element of xs found in ys is reported
   as the error. *)
definition check_disjoint :: "'a list ⇒ 'a list ⇒ 'a check" where
  "check_disjoint xs ys = check_all (λx. x ∉ set ys) xs"

(* check_disjoint succeeds iff the element sets of xs and ys are disjoint. *)
lemma isOK_check_disjoint [simp]:
  "isOK (check_disjoint xs ys) ⟷ set xs ∩ set ys = {}"
  unfolding check_disjoint_def
  by auto

(* Checks c x y for all ordered pairs of elements of xs, including pairs (x, x). *)
definition check_all_combinations :: "('a ⇒ 'a ⇒ 'b check) ⇒ 'a list ⇒ 'b check" where
  "check_all_combinations c xs = check_allm (λx. check_allm (c x) xs) xs"

(* check_all_combinations succeeds iff c x y is OK for every x and y taken from xs. *)
lemma isOK_check_all_combinations [simp]:
  "isOK (check_all_combinations c xs) ⟷ (∀x ∈ set xs. ∀y ∈ set xs. isOK (c x y))"
  unfolding check_all_combinations_def
  by simp

(* Checks c x y for every pair of positions where x occurs strictly before y.
   Each head element is checked against all later elements before the function
   recurses on the tail. *)
fun check_pairwise :: "('a ⇒ 'a ⇒ 'b check) ⇒ 'a list ⇒ 'b check" where
  "check_pairwise c [] = succeed"
| "check_pairwise c (x # xs) = (check_allm (c x) xs ⪢ check_pairwise c xs)"

(* Index-based splitting lemma used for check_pairwise. "P holds for all index pairs i < j
   of x # xs" (?C) is equivalent to the conjunction of two statements:
   - x is related by P to every element of xs (?A);
   - P holds for all index pairs i < j of xs (?B). *)
lemma pairwise_aux:
  "(∀j<length (x # xs). ∀i<j. P ((x # xs) ! i) ((x # xs) ! j))
     = ((∀j<length xs. P x (xs ! j)) ∧ (∀j<length xs. ∀i<j. P (xs ! i) (xs ! j)))"
  (is "?C = (?A ∧ ?B)")
proof (intro iffI conjI)
  (* direction ?A ∧ ?B ⟹ ?C; the goals ?C ⟹ ?A and ?C ⟹ ?B are closed by force below *)
  assume *: "?A ∧ ?B"
  show "?C"
  proof (intro allI impI)
    fix i j
    assume "j < length (x # xs)" and "i < j"
    then show "P ((x # xs) ! i) ((x # xs) ! j)"
    (* j = 0 is impossible since i < j; for j = Suc j, a case split on i selects ?A (i = 0)
       or ?B (i = Suc i) *)
    proof (induct j)
      case (Suc j)
      then show ?case
        using * by (induct i) simp_all
    qed simp
  qed
qed force+

(* check_pairwise succeeds iff c (xs ! i) (xs ! j) is OK for all indices i < j, that is,
   for every pair of positions where the first strictly precedes the second. *)
lemma isOK_check_pairwise [simp]:
  "isOK (check_pairwise c xs) ⟷ (∀j<length xs. ∀i<j. isOK (c (xs ! i) (xs ! j)))"
proof (induct xs)
  case (Cons x xs)
  (* restate the check of the head against the tail in terms of indices *)
  have "isOK (check_allm (c x) xs) = (∀j<length xs. isOK (c x (xs ! j)))"
    using all_set_conv_all_nth [of xs "λy. isOK (c x y)"] by simp
  (* combine it with the induction hypothesis for the tail *)
  then have "isOK (check_pairwise c (x # xs)) =
    ((∀j<length xs. isOK (c x (xs ! j))) ∧ (∀j<length xs. ∀i<j. isOK (c (xs ! i) (xs ! j))))"
    by (simp add: Cons)
  (* and merge both conjuncts into the index statement for x # xs *)
  then show ?case using pairwise_aux [of x xs "λx y. isOK (c x y)"] by simp
qed auto

(* Boolean existential check built on check_exm. Each element x failing f contributes
   the error [x], and the collected errors are joined with concat. Presumably it succeeds
   when f holds for some element; confirm against existsM. *)
abbreviation check_exists :: "('a ⇒ bool) ⇒ 'a list ⇒ ('a list) check" where
  "check_exists f xs ≡ check_exm (λx. if f x then succeed else error [x]) xs concat"

(* Success of choice:
   - choice [] always fails;
   - a non-empty choice succeeds iff its first alternative or the choice over the
     remaining alternatives succeeds. *)
lemma isOK_choice [simp]:
  "isOK (choice []) ⟷ False"
  "isOK (choice (x # xs)) ⟷ isOK x ∨ isOK (choice xs)"
  by (auto simp add: choice.simps isOK_def split: sum.splits)

(* Disjunction of two checks:
   - if the first check fails, its error is discarded and the second check is the result;
   - if the first check succeeds, it is returned without inspecting the second. *)
fun or_ok :: "'a check ⇒ 'a check ⇒ 'a check" where
  "or_ok (Inl a) b = b"
| "or_ok (Inr a) b = Inr a"

(* or_ok succeeds iff at least one of the two checks succeeds.
   Note: ⟷ (priority 25) must be used here rather than = (priority 50). With =, the
   statement would parse as (isOK (or_ok a b) = isOK a) ∨ isOK b, because = binds
   tighter than ∨. *)
lemma or_is_or:
  "isOK (or_ok a b) ⟷ isOK a ∨ isOK b"
  by (cases a) (auto simp: isOK_def)


end