Theory LS_Extras

theory LS_Extras
imports Multihole_Context Renaming Unification_More
(*
Author:  Bertram Felgenhauer <bertram.felgenhauer@uibk.ac.at> (2015-2017)
Author:  Franziska Rapp <franziska.rapp@uibk.ac.at> (2015-2017)
License: LGPL (see file COPYING.LESSER)
*)

section ‹Miscellaneous›

text ‹This theory contains several trivial extensions to existing IsaFoR theories.›
(* TODO: move things elsewhere *)

theory LS_Extras
  imports TA.Multihole_Context QTRS.Renaming QTRS.Unification_More
begin

subsection ‹Bounded duplication›

(* Pulling a reflexive relation back along f with inv_image keeps it reflexive. *)
lemma refl_inv_image:
  "refl R ⟹ refl (inv_image R f)"
  by (simp add: inv_image_def refl_on_def)

(* Mapping Some over a multiset never yields None, so None has count 0. *)
lemma image_mset_Some_None_zero [simp]:
  "count (image_mset Some M) None = 0"
  by (induct M) auto

text ‹IsaFoR does not seem to define non-duplicating TRSs yet.›

(* A TRS is non-duplicating if, for every rule (l, r), the multiset of variable
   occurrences of r is included in that of l, i.e. no variable occurs more
   often on the right-hand side than on the left-hand side. *)
definition non_duplicating :: "('f, 'v) trs ⇒ bool" where
  "non_duplicating R ≡ ∀l r. (l, r) ∈ R ⟶ vars_term_ms r ≤# vars_term_ms l"

text ‹Bounded duplication @{cite ‹Definition 4.2› FMZvO15}. Note: @{term "undefined"} is an arbitrary fixed value.›

(* The function symbols of R are embedded via Some, and the fresh symbol None
   acts as a unary marker.  R is bounded duplicating if the marker-erasing rule
   Fun None [Var undefined] → Var undefined is relatively terminating (SN_rel)
   with respect to the embedded TRS map_funs_trs Some R. *)
definition bounded_duplicating :: "('f, 'v) trs ⇒ bool" where
  "bounded_duplicating R ≡ SN_rel (rstep {(Fun None [Var undefined], Var undefined)}) (rstep (map_funs_trs Some R))"

text ‹Non-duplicating TRSs are bounded duplicating @{cite ‹Lemma 4.4› FMZvO15}.›

(* Non-duplicating TRSs are bounded duplicating.  Terms are measured by the
   number of occurrences of the marker symbol None:
   - a step with the marker rule decreases this count (?gt);
   - a step with map_funs_trs Some R cannot increase it, because R is
     non-duplicating and the embedded rules contain no None (?ge).
   Relative termination then follows from well-foundedness of > on nat. *)
lemma non_duplicating_imp_bounded_duplicating:
  assumes nd: "non_duplicating R"
  shows "bounded_duplicating R"
proof -
  let ?gt = "inv_image {(n, m) . n > m} (λt. count (funs_term_ms t) None)"
  let ?ss = "rstep {(Fun None [Var undefined], Var undefined)}"
  let ?ge = "inv_image {(n, m) . n ≥ m} (λt. count (funs_term_ms t) None)"
  let ?rs = "rstep (map_funs_trs Some R)"
  have "?ss ⊆ ?gt"
  proof
    fix s t
    assume "(s, t) ∈ ?ss"
    then have "count (funs_term_ms s) None > count (funs_term_ms t) None"
    proof
      fix C σ l r
      assume rl: "(l, r) ∈ {(Fun None [Var undefined], Var undefined)}"
        and s: "s = C⟨l ⋅ σ⟩" and t: "t = C⟨r ⋅ σ⟩"
      show "count (funs_term_ms s) None > count (funs_term_ms t) None" using rl
        by (auto simp: s t funs_term_ms_ctxt_apply funs_term_ms_subst_apply)
    qed
    then show "(s, t) ∈ ?gt" by simp
  qed
  moreover have "?rs ⊆ ?ge"
  proof
    fix s t
    assume "(s, t) ∈ ?rs"
    then have "count (funs_term_ms s) None ≥ count (funs_term_ms t) None"
    proof
      fix C σ l r
      {
        fix l r
        assume "(l, r) ∈ map_funs_trs Some R"
        then have "vars_term_ms r ≤# vars_term_ms l ∧ count (funs_term_ms l) None = 0 ∧ count (funs_term_ms r) None = 0"
          by (auto simp: nd[unfolded non_duplicating_def] map_funs_trs.simps funs_term_ms_map_funs_term)
      } note * = this
      assume rl: "(l, r) ∈ map_funs_trs Some R"
        and s: "s = C⟨l ⋅ σ⟩" and t: "t = C⟨r ⋅ σ⟩"
      from rl *[of l r] show "count (funs_term_ms s) None ≥ count (funs_term_ms t) None"
        by (auto simp: s t funs_term_ms_ctxt_apply funs_term_ms_subst_apply)
           (metis image_mset_union mset_subset_eq_count subset_mset.le_iff_add sum_mset.union)
    qed
    then show "(s, t) ∈ ?ge" by simp
  qed
  moreover
  {
    (* ?ge is reflexive and transitive, hence equal to its reflexive-transitive closure. *)
    have [simp]: "?ge⇧* = ?ge"
      by (auto intro!: trans_refl_imp_rtrancl_id trans_inv_image refl_inv_image simp: refl_on_def trans_def)
    have *: "(?ge⇧* O ?gt O ?ge⇧*)¯ ⊆ ?gt¯"
      by auto
    moreover have "wf (?gt¯)" by auto (auto simp: converse_def wf_less)
    ultimately have "SN_rel ?gt ?ge" unfolding SN_rel_on_def SN_iff_wf
      by (intro wf_subset[OF _ *]) auto
  }
  ultimately show ?thesis
    by (auto intro: SN_rel_mono simp: bounded_duplicating_def)
qed

subsection ‹Trivialities›

(* pos_diff is unchanged when both positions start with the same index i. *)
lemma pos_diff_cons [simp]: "pos_diff (i <# p) (i <# q) = pos_diff p q"
  by (auto simp: pos_diff_def)

(* max_list of an appended list is the maximum of the two max_lists. *)
lemma max_list_append: "max_list (xs1 @ xs2) = max (max_list xs1) (max_list xs2)"
  by (induct xs1) auto

(* Index-based characterization of an upper bound on max_list. *)
lemma max_list_bound:
  "max_list xs ≤ z ⟷ (∀i < length xs. xs ! i ≤ z)"
  using less_Suc_eq_0_disj by (induct xs) auto

(* Set-based characterization of an upper bound on max_list. *)
lemma max_list_bound_set:
  "max_list xs ≤ z ⟷ (∀x ∈ set xs. x ≤ z)"
  using less_Suc_eq_0_disj by (induct xs) auto

(* max_list of a concatenation is monotone when corresponding sublists
   (same number of sublists) are compared by their max_list. *)
lemma max_list_mono_concat:
  assumes "length xss = length yss" and "⋀i. i < length xss ⟹ max_list (xss ! i) ≤ max_list (yss ! i)"
  shows "max_list (concat xss) ≤ max_list (concat yss)"
  using assms
proof (induct yss arbitrary: xss)
  case (Cons ys yss) thus ?case by (cases xss) (force simp: max_list_append)+
qed auto

(* Special case of max_list_mono_concat where each sublist xss ! i is bounded
   by the single element ys ! i (instantiated with singleton lists). *)
lemma max_list_mono_concat1:
  assumes "length xss = length ys" and "⋀i. i < length xss ⟹ max_list (xss ! i) ≤ ys ! i"
  shows "max_list (concat xss) ≤ max_list ys"
  using assms max_list_mono_concat[of xss "map (λy. [y]) ys"] by auto

(* Taking one element yields the singleton of the head, or [] for the empty list. *)
lemma take1:
  "take (Suc 0) xs = (if xs = [] then [] else [hd xs])"
  by (cases xs) auto

(* A prefix of a list of naturals has a sum no larger than the whole list. *)
lemma sum_list_take':
  "i ≤ length (xs :: nat list) ⟹ sum_list (take i xs) ≤ sum_list xs"
  by (induct xs arbitrary: i) (simp, case_tac i, auto)

(* Elimination rule for list equality: equal lists have equal lengths and
   agree at every index. *)
lemma nth_equalityE:
  "xs = ys ⟹ (length xs = length ys ⟹ (⋀i. i < length xs ⟹ xs ! i = ys ! i) ⟹ P) ⟹ P"
  by simp

(* A multihole context has only finitely many variables. *)
lemma finite_vars_mctxt [simp]: "finite (vars_mctxt C)"
  by (induct C) auto

(* Partitioning a zipped list into blocks of sizes zs is the same as zipping
   the blockwise partitions of both lists; both lists must have total length
   sum_list zs. *)
lemma partition_by_of_zip:
  "length xs = sum_list zs ⟹ length ys = sum_list zs ⟹
   partition_by (zip xs ys) zs = map (λ(x,y). zip x y) (zip (partition_by xs zs) (partition_by ys zs))"
  by (induct zs arbitrary: xs ys) (auto simp: take_zip drop_zip)

(* A list is distinct iff every element occurs at most once in its multiset
   (variant of distinct_count_atmost_1 phrased with ≤ 1). *)
lemma distinct_count_atmost_1':
  "distinct xs = (∀a. count (mset xs) a ≤ 1)"
  unfolding distinct_count_atmost_1 using dual_order.antisym by fastforce

(* The elements of each sublist occur in the concatenation. *)
lemma nth_subset_concat:
  assumes "i < length xss" 
  shows "set (xss ! i) ⊆ set (concat xss)"
  by (metis assms concat_nth concat_nth_length in_set_idx nth_mem subsetI)

(* The domain of the substitution built by subst_of xs is contained in the
   variables listed as first components of xs. *)
lemma subst_domain_subst_of:
  "subst_domain (subst_of xs) ⊆ set (map fst xs)"
proof (induct xs)
  case (Cons x xs)
  moreover have "subst_domain (subst (fst x) (snd x)) ⊆ set (map fst [x])" by simp
  ultimately show ?case
    using subst_domain_compose[of "subst_of xs" "subst (fst x) (snd x)"] by auto
qed simp

(* Applying a substitution to a multihole context depends only on its values
   at the context's variables. *)
lemma subst_apply_mctxt_cong: "(∀x. x ∈ vars_mctxt C ⟶ σ x = τ x) ⟹ C ⋅mc σ = C ⋅mc τ"
  by (induct C) auto

(* If concat xss is distinct, an element lies in at most one sublist, so the
   index of the sublist containing it is unique. *)
lemma distinct_concat_unique_index:
  "distinct (concat xss) ⟹ i < length xss ⟹ x ∈ set (xss ! i) ⟹ j < length xss ⟹ x ∈ set (xss ! j) ⟹ i = j"
proof (induct xss rule: List.rev_induct)
  case (snoc xs xss) thus ?case using nth_mem[of i xss] nth_mem[of j xss]
    by (cases "i < length xss"; cases "j < length xss") (auto simp: nth_append simp del: nth_mem)
qed auto

(* Updating entry j of sublist i is the same as updating the concatenation at
   offset (sum of the lengths of the first i sublists) + j. *)
lemma list_update_concat:
  assumes "i < length xss" "j < length (xss ! i)"
  shows "concat (xss[i := (xss ! i)[j := x]]) = concat xss[sum_list (take i (map length xss)) + j := x]"
  using assms
proof (induct xss arbitrary: i)
  case (Cons xs xss) thus ?case
    by (cases i) (auto simp: list_update_append)
qed auto

(* The length of a filtered list is the sum of 0/1 indicators of p. *)
lemma length_filter_sum_list:
  "length (filter p xs) = sum_list (map (λx. if p x then 1 else 0) xs)"
  by (induct xs) auto

(* poss_mctxt is monotone with respect to the order ≤ on multihole contexts. *)
lemma poss_mctxt_mono:
  "C ≤ D ⟹ poss_mctxt C ⊆ poss_mctxt D"
  by (induct C D rule: less_eq_mctxt_induct) force+

(* p <#> q is a function position of t iff p is a position of t and q is a
   function position of the subterm t |_ p. *)
lemma poss_append_funposs:
  shows "p <#> q ∈ funposs t ⟷ p ∈ poss t ∧ q ∈ funposs (t |_ p)" (is "?L ⟷ ?R")
proof
  show "?L ⟹ ?R" using funposs_imp_poss[of "p <#> q" t] funposs_fun_conv[of "p <#> q" t]
    poss_is_Fun_funposs[of q "t |_ p"] by auto
next
  show "?R ⟹ ?L" using funposs_imp_poss[of q "t |_ p"] funposs_fun_conv[of q "t |_ p"]
    poss_is_Fun_funposs[of "p <#> q" t] by auto
qed

text ‹Variant of @{thm nrrstep_preserves_root} that additionally states that the number of arguments is preserved.›
(* A single nrrstep step from Fun f ss yields a term Fun f ts with as many
   arguments as ss. *)
lemma nrrstep_preserves_root':
  assumes "(Fun f ss, t) ∈ nrrstep R"
  shows "∃ts. t = Fun f ts ∧ length ss = length ts"
  using assms unfolding nrrstep_def rstep_r_p_s_def Let_def by auto

text ‹Variant of @{thm nrrsteps_preserve_root} that additionally states that the number of arguments is preserved.›
(* Any nrrstep reduction (reflexive-transitive closure) from Fun f ss ends in
   a term Fun f ts with as many arguments as ss; by induction over the
   reduction, using nrrstep_preserves_root' for each step. *)
lemma nrrsteps_preserve_root':
  assumes "(Fun f ss, t) ∈ (nrrstep R)⇧*"
  shows "∃ts. t = Fun f ts ∧ length ss = length ts"
  using assms by induct (auto dest: nrrstep_preserves_root')

(* If corresponding arguments are pairwise joinable, so are the terms built
   from them under the same root symbol f.  Each argument pair is joined in a
   common reduct u i, and both sides rewrite below the root to
   Fun f (map u [0..<length ss]). *)
lemma args_joinable_imp_joinable:
  assumes "length ss = length ts" "⋀i. i < length ss ⟹ (ss ! i, ts ! i) ∈ (rstep R)⇧↓"
  shows "(Fun f ss, Fun f ts) ∈ (rstep R)⇧↓"
proof -
  obtain u where "i < length ss ⟹ (ss ! i, u i) ∈ (rstep R)⇧* ∧ (ts ! i, u i) ∈ (rstep R)⇧*" for i
    using joinD[OF assms(2)] by metis
  then show ?thesis using assms(1)
    by (intro joinI[of _ "Fun f (map u [0..<length ss])"] args_rsteps_imp_rsteps) auto
qed

(* The option type over an infinite type is again infinite. *)
instance option :: (infinite) infinite
  by standard (simp add: infinite_UNIV)

(* A finite set A maps injectively into any infinite set B.  A is numbered
   injectively by an initial segment of nat, and the composition with an
   injective sequence g in B is used. *)
lemma finite_into_infinite:
  assumes "finite A" "infinite B"
  shows "∃f. f ` A ⊆ B ∧ inj_on f A"
proof -
  from finite_imp_inj_to_nat_seg[OF assms(1)]
  obtain f :: "_ ⇒ nat" and n where "f ` A = {i. i < n}" "inj_on f A" by auto
  moreover from infinite_countable_subset[OF assms(2)]
  obtain g :: "nat ⇒ _" where "inj g" "range g ⊆ B" by auto
  ultimately show ?thesis by (auto simp: inj_on_def intro!: exI[of _ "g ∘ f"])
qed

(* Given pairwise disjoint finite sets f α and pairwise disjoint infinite sets
   g α, there is a single function h that is injective on ⋃α. f α and maps
   each f α into g α.  Disjointness of the f α lets h choose, for x, the
   unique α with x ∈ f α (via THE). *)
lemma finites_into_infinites:
  fixes f :: "'a ⇒ 'b set" and g :: "'a ⇒ 'c set"
  assumes "⋀α β. f α ∩ f β ≠ {} ⟹ α = β"
  and "⋀α. finite (f α)"
  and "⋀α β. g α ∩ g β ≠ {} ⟹ α = β"
  and "⋀α. infinite (g α)"
  shows "∃h :: 'b ⇒ 'c. inj_on h (⋃α. f α) ∧ (∀α. h ` f α ⊆ g α)"
proof -
  from finite_into_infinite[OF assms(2,4)] have "∃h. h ` f α ⊆ g α ∧ inj_on h (f α)" for α by blast
  then obtain h where h: "h α ` f α ⊆ g α ∧ inj_on (h α) (f α)" for α by metis
  have [simp]: "x ∈ f α ⟹ (THE α. x ∈ f α) = α" for x α using assms(1) by auto
  then show ?thesis using assms(1,3) h
    apply (intro exI[of _ "λx. h (THE α. x ∈ f α) x"])
    apply (auto simp: inj_on_def image_def)
    by blast
qed

subsection ‹Imbalance›

(* ss ∝ ts: ts has the same length as ss, and indices holding equal entries in
   ss also hold equal entries in ts (ts identifies at least the indices that
   ss identifies). *)
definition refines :: "'a list ⇒ 'b list ⇒ bool" (infix "∝" 55) where
  "ss ∝ ts ⟷ length ts = length ss ∧ (∀i j. i < length ss ∧ j < length ts ∧ ss ! i = ss ! j ⟶ ts ! i = ts ! j)"

(* ∝ is reflexive. *)
lemma refines_refl:
  "ss ∝ ss"
  by (auto simp: refines_def)

(* ∝ is transitive. *)
lemma refines_trans:
  "ss ∝ ts ⟹ ts ∝ us ⟹ ss ∝ us"
  by (auto simp: refines_def)

(* Mapping any function over a list yields a refinement of it. *)
lemma refines_map:
  "ss ∝ map f ss"
  by (auto simp: refines_def)

(* Converse of refines_map: every refinement ts of ss is map f ss for some f.
   The function f s returns the entry of ts at some index where ss holds s
   (chosen with SOME); refinement makes this choice irrelevant. *)
lemma refines_imp_map:
  assumes "ss ∝ ts"
  obtains f where "ts = map f ss"
proof -
  note * = assms[unfolded refines_def]
  let ?P = "λs i. (i < length ss ∧ ss ! i = s)" let ?f = "λs. ts ! (SOME i. ?P s i)"
  { fix i assume "i < length ss"
    then have "ts ! i = ?f (ss ! i)" using conjunct1[OF *]
      conjunct2[OF *, THEN spec[of _ i], THEN spec[of _ "SOME j. ?P (ss ! i) j"]] someI[of "?P (ss ! i)" i]
      by auto
  }
  then show ?thesis using conjunct1[OF *] by (auto intro!: that[of "?f"] nth_equalityI)
qed

(* The imbalance of a list is its number of distinct elements. *)
definition imbalance :: "'a list ⇒ nat" where
  "imbalance ts = card (set ts)"

(* Index-based characterization: the imbalance counts the indices that hold
   the first occurrence of their value. *)
lemma imbalance_def':
  "imbalance xs = card { i. i < length xs ∧ (∀j. j < length xs ∧ xs ! i = xs ! j ⟶ i ≤ j) }"
proof (induct xs rule: List.rev_induct)
  case (snoc x xs)
  have "{ i. i < length (xs @ [x]) ∧ (∀j. j < length (xs @ [x]) ∧ (xs @ [x]) ! i = (xs @ [x]) ! j ⟶ i ≤ j) } =
        { i. i < length xs ∧ (∀j. j < length xs ∧ xs ! i = xs ! j ⟶ i ≤ j) } ∪ { i. i = length xs ∧ x ∉ set xs }"
    by (simp add: set_eq_iff less_Suc_eq dnf imp_conjL nth_append)
      (metis cancel_comm_monoid_add_class.diff_cancel in_set_conv_nth nth_Cons_0)
  then show ?case using snoc by (simp add: imbalance_def card_insert_if)
qed (auto simp: imbalance_def)

(* Imbalance is monotone with respect to inclusion of the element sets. *)
lemma imbalance_mono: "set ss ⊆ set ts ⟹ imbalance ss ≤ imbalance ts"
  by (simp add: imbalance_def card_mono)

(* Passing to a refinement cannot increase the imbalance. *)
lemma refines_imbalance_mono:
  "ss ∝ ts ⟹ imbalance ss ≥ imbalance ts"
  (* apply (auto simp: refines_def imbalance_def' intro: card_mono) (* desired proof *) *)
  unfolding refines_def imbalance_def' by (intro card_mono) (simp_all add: Collect_mono)

(* If ts refines ss but not conversely, the imbalance strictly decreases.
   Proof idea: pick indices i, j with equal ts-entries but different ss-entries;
   their first occurrences f i, f j in ss are distinct first-occurrence indices
   of ss, but at most one of them can be a first-occurrence index of ts. *)
lemma refines_imbalance_strict_mono:
  "ss ∝ ts ⟹ ¬ ts ∝ ss ⟹ imbalance ss > imbalance ts"
  unfolding refines_def imbalance_def'
proof (intro psubset_card_mono psubsetI, goal_cases)
  case 3 obtain i j where ij: "i < length ss" "j < length ss" "ts ! i = ts ! j" "ss ! i ≠ ss ! j"
    using conjunct1[OF 3(1)] 3(2) by auto
  define f where "f ≡ λi. SOME j. j < length ss ∧ ss ! j = ss ! i ∧ (∀i. ss ! j = ss ! i ⟶ j ≤ i)"
  let ?ssi = "{i. i < length ss ∧ (∀j. j < length ss ∧ ss ! i = ss ! j ⟶ i ≤ j)}"
  let ?tsi = "{i. i < length ts ∧ (∀j. j < length ts ∧ ts ! i = ts ! j ⟶ i ≤ j)}"
  note f_def = fun_cong[OF f_def[unfolded atomize_eq]]
  { fix i assume "i < length ss"
    then have "∃i'. i' < length ss ∧ ss ! i' = ss ! i ∧ (∀i. ss ! i' = ss ! i ⟶ i' ≤ i)"
    proof (induct i rule: less_induct)
      case (less i) then show ?case
        by (cases "i < length ss ∧ ss ! i = ss ! i ∧ (∀j. ss ! i = ss ! j ⟶ i ≤ j)") auto
    qed
    then have "f i < length ss" "ss ! f i = ss ! i" "⋀j. ss ! f i = ss ! j ⟹ f i ≤ j"
    using someI[of "λj. j < length ss ∧ ss ! j = ss ! i ∧ (∀i. ss ! j = ss ! i ⟶ j ≤ i)", folded f_def]
      by auto
    then have "f i < length ss" "ss ! f i = ss ! i" "f i ∈ ?ssi" "f i ∈ ?tsi ⟹ ts ! f i = ts ! i"
      using 3(1) by (auto simp: `i < length ss`)
  } note * = this[OF ij(1)] this[OF ij(2)]
  with ij(1,2,3,4) 3(1)[THEN conjunct1]
    3(1)[THEN conjunct2, rule_format, of i "f i"]
    3(1)[THEN conjunct2, rule_format, of j "f j"]
  have "f i ∈ ?ssi" "f j ∈ ?ssi" "ts ! f i = ts ! f j" "f i ≠ f j" by metis+
  moreover from this(3,4) have "f i ∉ ?tsi ∨ f j ∉ ?tsi" using *(4,8) ij(3)
    by (metis (mono_tags, lifting) dual_order.antisym mem_Collect_eq)
  ultimately show ?case by argo
qed (simp_all add: Collect_mono)

(* Refinement is preserved by take. *)
lemma refines_take:
  "ss ∝ ts ⟹ take n ss ∝ take n ts"
  unfolding refines_def by (intro conjI impI allI; elim conjE) (simp_all add: refines_def)

(* Refinement is preserved by drop: indices i, j of the dropped lists
   correspond to i + n, j + n of the originals. *)
lemma refines_drop:
  "ss ∝ ts ⟹ drop n ss ∝ drop n ts"
  unfolding refines_def
proof ((intro conjI impI allI; elim conjE), goal_cases)
  case (2 i j) show ?case using 2(1,3-) 2(2)[rule_format, of "i + n" "j + n"]
    by (auto simp: less_diff_eq less_diff_conv ac_simps)
qed simp

subsection ‹Abstract Rewriting›

(* For confluent R, a finite set X of pairwise convertible elements has a
   common reduct z.  Induction over X: the common reduct z of X is convertible
   with the inserted element x, so CR R joins them in some z'. *)
lemma join_finite:
  assumes "CR R" "finite X" "⋀x y. x ∈ X ⟹ y ∈ X ⟹ (x, y) ∈ R⇧↔⇧*"
  obtains z where "⋀x. x ∈ X ⟹ (x, z) ∈ R⇧*"
  using assms(2,1,3)
proof (induct X arbitrary: thesis)
  case (insert x X)
  then show ?case
  proof (cases "X = {}")
    case False
    then obtain x' where "x' ∈ X" by auto
    obtain z where *: "x' ∈ X ⟹ (x', z) ∈ R⇧*" for x' using insert by (metis insert_iff)
    from this[of x'] `x' ∈ X` insert(6)[of x x'] have "(x, z) ∈ R⇧↔⇧*"
      by (metis insert_iff converse_rtrancl_into_rtrancl conversionI' conversion_rtrancl)
    with `CR R` obtain z' where "(x, z') ∈ R⇧*" "(z, z') ∈ R⇧*" by (auto simp: CR_iff_conversion_imp_join)
    then show ?thesis by (auto intro!: insert(4)[of z'] dest: *)
  qed auto
qed simp

(* For confluent R, every list ss reduces elementwise to a list ts in which
   convertible entries of ss are sent to the same entry.  Each entry s is sent
   to a common reduct (chosen with SOME, existing by join_finite) of all
   entries of ss convertible with s. *)
lemma balance_sequence:
  assumes "CR R"
  obtains (ts) ts where "length ss = length ts"
    "⋀i. i < length ss ⟹ (ss ! i, ts ! i) ∈ R⇧*"
    "⋀i j. i < length ss ⟹ j < length ss ⟹ (ss ! i, ss ! j) ∈ R⇧↔⇧* ⟹ ts ! i = ts ! j"
proof -
  define f where "f s ≡ SOME u. ∀t ∈ { t ∈ set ss. (s, t) ∈ R⇧↔⇧* }. (t, u) ∈ R⇧* " for s
  { fix i assume "i < length ss"
    have "∀t ∈ { t ∈ set ss. (ss ! i, t) ∈ R⇧↔⇧* }. (t, f (ss ! i)) ∈ R⇧*"
      unfolding f_def using join_finite[OF `CR R`, of "{ t ∈ set ss. (ss ! i, t) ∈ R⇧↔⇧* }"]
        by (rule someI_ex) (auto, metis conversion_inv conversion_rtrancl rtrancl_trans)
    with `i < length ss` have "(ss ! i, f (ss ! i)) ∈ R⇧*" by auto
  } note [intro] = this
  moreover {
    fix i j assume "i < length ss" "j < length ss" "(ss ! i, ss ! j) ∈ R⇧↔⇧*"
    then have "{ t ∈ set ss. (ss ! i, t) ∈ R⇧↔⇧* } = { t ∈ set ss. (ss ! j, t) ∈ R⇧↔⇧* }"
      by auto (metis conversion_inv conversion_rtrancl rtrancl_trans)+
    then have "f (ss ! i) = f (ss ! j)" by (auto simp: f_def)
  } note [intro] = this
  show ?thesis by (auto intro!: ts[of "map f ss"])
qed

(* For confluent R, two elementwise reducts ts and us of ss have a common
   elementwise reduct vs that refines both.  Obtained by applying
   balance_sequence to ts @ us and keeping the first half; ts ! i and us ! i
   are convertible (both are reducts of ss ! i), so both halves agree. *)
lemma balance_sequences:
  assumes "CR R" and [simp]: "length ts = length ss" "length us = length ss" and
    p: "⋀i. i < length ss ⟹ (ss ! i, ts ! i) ∈ R⇧*" "⋀i. i < length ss ⟹ (ss ! i, us ! i) ∈ R⇧*"
  obtains (vs) vs where
    "length vs = length ss"
    "⋀i. i < length ss ⟹ (ts ! i, vs ! i) ∈ R⇧*" "⋀i. i < length ss ⟹ (us ! i, vs ! i) ∈ R⇧*"
    "refines ts vs" "refines us vs"
proof -
  from balance_sequence[OF `CR R`, of "ts @ us"]
  obtain vs where l: "length (ts @ us) = length vs" and
    r: "⋀i. i < length (ts @ us) ⟹ ((ts @ us) ! i, vs ! i) ∈ R⇧*" and
    e: "⋀i j. i < length (ts @ us) ⟹ j < length (ts @ us) ⟹ ((ts @ us) ! i, (ts @ us) ! j) ∈ R⇧↔⇧* ⟹ vs ! i = vs ! j"
    by blast
  { fix i assume *: "i < length ss"
    from * have "(ts ! i, us ! i) ∈ R⇧↔⇧*" using p[of i]
      by (metis CR_imp_conversionIff_join assms(1) conversionI' conversion_rtrancl joinI_right rtrancl_trans)
    with * have "vs ! (length ss + i) = vs ! i"
      using e[of i "length ss + i"] by (auto simp: nth_append)
  } note lp[simp] = this
  show ?thesis
  proof (intro vs[of "take (length ss) vs"], goal_cases)
    case (2 i) then show ?case using l r[of i] by (auto simp: nth_append)
  next
    case (3 i) with 3 show ?case using l r[of "length ss + i"] by (auto simp: nth_append)
  next
    case 4
    then show ?case using l by (auto simp: refines_def nth_append intro!: e)
  next
    case 5
    { fix i j assume "i < length ss" "j < length ss" "(us ! i, us ! j) ∈ R⇧↔⇧*"
      then have "vs ! i = vs ! j"
        using e[of "length ss + i" "length ss + j"] by (simp add: nth_append)
    }
    then show ?case using l by (auto simp: refines_def nth_append)
  qed (auto simp: l[symmetric])
qed

(* If A is closed under R-steps, an R-reduction starting in A stays in A and
   is a reduction of Restr R A.  Despite the name, only this direction is
   stated. *)
lemma rtrancl_on_iff_rtrancl_restr:
  assumes "⋀x y. x ∈ A ⟹ (x, y) ∈ R ⟹ y ∈ A"
    "(x, y) ∈ R⇧*" "x ∈ A"
  shows "(x, y) ∈ (Restr R A)⇧* ∧ y ∈ A"
  using assms(2,3)
proof (induct y rule: rtrancl_induct)
  case (step y z) then show ?case by (auto intro!: rtrancl_into_rtrancl[of _ y _ z] simp: assms(1))
qed auto

(* If A is closed under R-steps, confluence of R on A coincides with
   confluence of the restriction Restr R A. *)
lemma CR_on_iff_CR_Restr:
  assumes "⋀x y. x ∈ A ⟹ (x, y) ∈ R ⟹ y ∈ A"
  shows "CR_on R A ⟷ CR (Restr R A)"
proof ((standard; standard), goal_cases)
  let ?R' = "Restr R A"
  (* R-reductions from A are ?R'-reductions that stay in A. *)
  have *: "(s, t) ∈ R⇧* ⟹ s ∈ A ⟹ (s, t) ∈ ?R'⇧* ∧ t ∈ A" for s t
    using assms rtrancl_on_iff_rtrancl_restr by metis
  {
    case (1 s t u) then show ?case
    proof (cases "s ∈ A")
      case True
      with 1(3,4) have "(s, t) ∈ R⇧*" "t ∈ A" "(s, u) ∈ R⇧*" "u ∈ A"
        using *[of s t] *[of s u] rtrancl_mono[of ?R' R] by auto
      then show ?thesis
        using True 1(1)[unfolded CR_on_def, rule_format, of s t u] by (auto dest!: joinD *)
    next
      case False
      (* Outside A, Restr R A admits no steps, so both reducts equal s. *)
      then have "s = t" "s = u" using 1 by (metis IntD2 converse_rtranclE mem_Sigma_iff)+
      then show ?thesis by auto
    qed
  next
    case (2 s t u)
    obtain v where "(t, v) ∈ ?R'⇧*" "(u, v) ∈ ?R'⇧*" using *[OF 2(3,2)] *[OF 2(4,2)] 2(1) by auto
    then show ?case using rtrancl_mono[of ?R' R] by auto
  }
qed

subsection ‹Bijection between 'a and 'a option for infinite types›

(* Every infinite type is in bijection with its option type.  Along an
   injective sequence g :: nat ⇒ 'a, g 0 is sent to None and g (Suc n) to
   Some (g n); every other element x is sent to Some x. *)
lemma infinite_option_bijection:
  assumes "infinite (UNIV :: 'a set)" shows "∃f :: 'a ⇒ 'a option. bij f"
proof -
  from infinite_countable_subset[OF assms] obtain g :: "nat ⇒ 'a" where g: "inj g" "range g ⊆ UNIV" by blast
  let ?f = "λx. if x ∈ range g then (case inv g x of 0 ⇒ None | Suc n ⇒ Some (g n)) else Some x"
  have "?f x = ?f y ⟹ x = y" for x y using g(1) by (auto split: nat.splits if_splits simp: inj_on_def) auto
  moreover have "∃x. ?f x = y" for y
    apply (cases y)
    subgoal using g(1) by (intro exI[of _ "g 0"]) auto
    subgoal for y using g
      apply (cases "y ∈ range g")
      subgoal by (intro exI[of _ "g (Suc (inv g y))"]) auto
      subgoal by auto
    done
  done
  ultimately have "bij ?f" unfolding bij_def surj_def by (intro conjI injI) metis+
  thus ?thesis by blast
qed

(* Some fixed bijection from 'a to 'a option, chosen with SOME; one exists by
   infinite_option_bijection. *)
definition to_option :: "'a :: infinite ⇒ 'a option" where
  "to_option = (SOME f. bij f)"

(* The inverse of to_option. *)
definition from_option :: "'a :: infinite option ⇒ 'a" where
  "from_option = inv to_option"

(* NOTE(review): despite its name, this lemma states that to_option (not
   from_option) is a bijection. *)
lemma bij_from_option: "bij to_option"
  unfolding to_option_def using someI_ex[OF infinite_option_bijection[OF infinite_UNIV]] .

(* from_option is a left inverse of to_option (point-free form). *)
lemma from_to_option_comp[simp]: "from_option ∘ to_option = id"
  unfolding from_option_def by (intro inv_o_cancel bij_is_inj bij_from_option)

(* Pointwise form of from_to_option_comp. *)
lemma from_to_option[simp]: "from_option (to_option x) = x"
  by (simp add: pointfree_idE)

(* from_option is also a right inverse of to_option, since to_option is surjective. *)
lemma to_from_option_comp[simp]: "to_option ∘ from_option = id"
  unfolding from_option_def surj_iff[symmetric] by (intro bij_is_surj bij_from_option)

(* Pointwise form of to_from_option_comp. *)
lemma to_from_option[simp]: "to_option (from_option x) = x"
  by (simp add: pointfree_idE)


(* Renaming the variables with f and then applying g is the same as applying
   g ∘ f. *)
lemma var_subst_comp: "t ⋅ (Var ∘ f) ⋅ g = t ⋅ (g ∘ f)"
  by (simp add: comp_def subst_subst_compose[symmetric] subst_compose_def del: subst_subst_compose)

subsection ‹More polymorphic rewriting›

(* Rewrite steps of R on terms whose variable type 'w may differ from the
   variable type 'v of the rules. *)
inductive_set rstep' :: "('f, 'v) trs ⇒ ('f, 'w) term rel" for R where
  rstep' [intro]: "(l, r) ∈ R ⟹ (C⟨l ⋅ σ⟩, C⟨r ⋅ σ⟩) ∈ rstep' R"

(* When the variable types agree, rstep' coincides with rstep. *)
lemma rstep_eq_rstep': "rstep R = rstep' R"
by (auto elim: rstep'.cases)

(* rstep' is closed under contexts: the outer context C is composed with the
   context D of the given step. *)
lemma rstep'_mono:
  assumes "(s, t) ∈ rstep' R" shows "(C⟨s⟩, C⟨t⟩) ∈ rstep' R"
proof -
  from assms obtain D l r σ where "(l, r) ∈ R" "s = D⟨l ⋅ σ⟩" "t = D⟨r ⋅ σ⟩" by (auto simp add: rstep'.simps)
  then show ?thesis using rstep'[of l r R "C ∘⇩c D" σ] by simp
qed

(* rstep' is context-closed in the sense of ctxt.closed. *)
lemma ctxt_closed_rstep' [simp]:
  shows "ctxt.closed (rstep' R)"
  by (auto simp: ctxt.closed_def rstep'_mono elim: ctxt.closure.induct)

(* rstep' is closed under substitutions: apply σ to the context and compose
   the matching substitution τ with σ. *)
lemma rstep'_stable:
  assumes "(s, t) ∈ rstep' R" shows "(s ⋅ σ, t ⋅ σ) ∈ rstep' R"
proof -
  from assms obtain C l r τ where "(l, r) ∈ R" "s = C⟨l ⋅ τ⟩" "t = C⟨r ⋅ τ⟩" by (auto simp add: rstep'.simps)
  then show ?thesis using rstep'[of l r R "C ⋅⇩c σ" "τ ∘⇩s σ"] by simp
qed

(* rstep' reductions are closed under substitutions (stepwise via rstep'_stable). *)
lemma rsteps'_stable:
  "(s, t) ∈ (rstep' R)⇧* ⟹ (s ⋅ σ, t ⋅ σ) ∈ (rstep' R)⇧*"
  by (induct rule: rtrancl_induct) (auto dest: rstep'_stable[of _ _ _ σ])

(* For a well-formed TRS, reduction never introduces new variables, since each
   rule's right-hand side variables occur in its left-hand side. *)
lemma rstep'_sub_vars:
  assumes "(s, t) ∈ (rstep' R)⇧*" "wf_trs R"
  shows "vars_term t ⊆ vars_term s"
  using assms(1)
proof (induction rule: converse_rtrancl_induct)
  case (step y z)
  obtain l r C σ where props: "(l, r) ∈ R" "y = C⟨l ⋅ σ⟩" "z = C⟨r ⋅ σ⟩" "vars_term r ⊆ vars_term l"
    using step(1) assms(2) unfolding wf_trs_def by (auto elim: rstep'.cases)
  hence "vars_term z ⊆ vars_term y"
    using var_cond_stable[OF props(4), of σ] vars_term_ctxt_apply[of C "_ ⋅ σ"] by auto
  thus ?case using step(3) by simp
qed simp

(* Polymorphic variant of rstep_r_p_s: a step using rule r ∈ R at position
   p = hole_pos C with substitution σ, on terms whose variable type 'w may
   differ from that of R. *)
inductive_set rstep_r_p_s' :: "('f, 'v) trs ⇒ ('f, 'v) rule ⇒ pos ⇒ ('f, 'v, 'w) gsubst ⇒ ('f, 'w) term rel"
  for R r p σ where
  rstep_r_p_s' [intro]: "r ∈ R ⟹ p = hole_pos C ⟹ (C⟨fst r ⋅ σ⟩, C⟨snd r ⋅ σ⟩) ∈ rstep_r_p_s' R r p σ"

(* Make case analysis on rstep_r_p_s' available to the classical reasoner. *)
declare rstep_r_p_s'.cases [elim]

(* TODO: change definition above *)

(* Introduction rule with source and target given by equations, usable when s
   and t are not syntactically of the form C⟨…⟩. *)
lemma rstep_r_p_s'I:
  "r ∈ R ⟹ p = hole_pos C ⟹ s = C⟨fst r ⋅ σ⟩ ⟹ t = C⟨snd r ⋅ σ⟩ ⟹ (s, t) ∈ rstep_r_p_s' R r p σ"
  by auto

(* Elimination rule exposing the context C of the step. *)
lemma rstep_r_p_s'E:
  assumes "(s, t) ∈ rstep_r_p_s' R r p σ"
  obtains C where "r ∈ R" "p = hole_pos C" "s = C⟨fst r ⋅ σ⟩" "t = C⟨snd r ⋅ σ⟩"
  using rstep_r_p_s'.cases assms by metis

(* With equal variable types, rstep_r_p_s' agrees with rstep_r_p_s. *)
lemma rstep_r_p_s_eq_rstep_r_p_s': "rstep_r_p_s R r p σ = rstep_r_p_s' R r p σ"
  by (auto simp: rstep_r_p_s_def rstep_r_p_s'.simps) (metis hole_pos_ctxt_of_pos_term)

(* An rstep' step is an rstep_r_p_s' step for some rule, position and substitution. *)
lemma rstep'_iff_rstep_r_p_s':
  "(s, t) ∈ rstep' R ⟷ (∃l r p σ. (s, t) ∈ rstep_r_p_s' R (l, r) p σ)"
  by (auto simp: rstep'.simps rstep_r_p_s'.simps)

(* For a well-formed TRS, the rule and the position determine the result of a
   step, independently of the substitution used.  Induction on the context:
   at the hole the two substitutions agree on the left-hand side, hence on
   the right-hand side. *)
lemma rstep_r_p_s'_deterministic:
  assumes "wf_trs R" "(s, t) ∈ rstep_r_p_s' R r p σ" "(s, t') ∈ rstep_r_p_s' R r p τ"
  shows "t = t'"
proof -
  obtain C where 1: "r ∈ R" "s = C⟨fst r ⋅ σ⟩" "t = C⟨snd r ⋅ σ⟩" "p = hole_pos C"
    using assms(2) by (auto simp: rstep_r_p_s'.simps)
  obtain D where 2: "s = D⟨fst r ⋅ τ⟩" "t' = D⟨snd r ⋅ τ⟩" "p = hole_pos D"
    using assms(3) by (auto simp: rstep_r_p_s'.simps)
  obtain lhs rhs where 3: "r = (lhs, rhs)" by force
  show ?thesis using 1(1) 2(1,3) unfolding 1(2,3,4) 2(2) 3
  proof (induct C arbitrary: D)
    case Hole have [simp]: "D = □" using Hole by (cases D) auto
    thus ?case using `wf_trs R` Hole by (auto simp: wf_trs_def term_subst_eq_conv)
  next
    case (More f ss1 C' ss2) thus ?case by (cases D) auto
  qed
qed

(* For a well-formed TRS over signature F, a step from a term over F yields a
   term over F. *)
lemma rstep_r_p_s'_preserves_funas_terms:
  assumes "wf_trs R" "(s, t) ∈ rstep_r_p_s' R r p σ" "funas_trs R ⊆ F" "funas_term s ⊆ F"
  shows "funas_term t ⊆ F"
proof -
  obtain C where 1: "r ∈ R" "s = C⟨fst r ⋅ σ⟩" "t = C⟨snd r ⋅ σ⟩" "p = hole_pos C"
    using assms(2) by (auto simp: rstep_r_p_s'.simps)
  then have "funas_ctxt C ⊆ F" using assms(4) by auto
  obtain lhs rhs where 3: "r = (lhs, rhs)" by force
  show ?thesis using 1(1) using assms(1,3,4)
    by (force simp: 1(2,3,4) 3 funas_term_subst funas_trs_def funas_rule_def wf_trs_def)
qed

(* rstep' version of rstep_r_p_s'_preserves_funas_terms. *)
lemma rstep'_preserves_funas_terms:
  "funas_trs R ⊆ F ⟹ funas_term s ⊆ F ⟹ (s, t) ∈ rstep' R ⟹ wf_trs R ⟹ funas_term t ⊆ F"
  unfolding rstep'_iff_rstep_r_p_s' using rstep_r_p_s'_preserves_funas_terms by blast

(* Applying τ to both sides of a step at position p gives a step at the same
   position with the composed substitution σ ∘⇩s τ. *)
lemma rstep_r_p_s'_stable:
  "(s, t) ∈ rstep_r_p_s' R r p σ ⟹ (s ⋅ τ, t ⋅ τ) ∈ rstep_r_p_s' R r p (σ ∘⇩s τ)"
  by (auto elim!: rstep_r_p_s'E intro!: rstep_r_p_s'I simp: subst_subst simp del: subst_subst_compose)

(* Placing a step at position p into the context C gives a step at position
   hole_pos C <#> p, using the composed context C ∘⇩c D. *)
lemma rstep_r_p_s'_mono:
  "(s, t) ∈ rstep_r_p_s' R r p σ ⟹ (C⟨s⟩, C⟨t⟩) ∈ rstep_r_p_s' R r (hole_pos C <#> p) σ"
proof (elim rstep_r_p_s'E)
  fix D assume "r ∈ R" "p = hole_pos D" "s = D⟨fst r ⋅ σ⟩" "t = D⟨snd r ⋅ σ⟩"
  then show "(C⟨s⟩, C⟨t⟩) ∈ rstep_r_p_s' R r (hole_pos C <#> p) σ"
    using rstep_r_p_s'I[of r R "hole_pos C <#> p" "C ∘⇩c D" "C⟨s⟩" σ "C⟨t⟩"] by simp
qed

(* A step at position i <# p takes place inside argument i: the source is
   Fun f ss, the target replaces ss ! i by ti, and ss ! i steps to ti at
   position p. *)
lemma rstep_r_p_s'_argE:
  assumes "(s, t) ∈ rstep_r_p_s' R (l, r) (i <# p) σ"
  obtains f ss ti where "s = Fun f ss" "i < length ss" "t = Fun f (ss[i := ti])" "(ss ! i, ti) ∈ rstep_r_p_s' R (l, r) p σ"
proof -
  from assms obtain C where *: "hole_pos C = i <# p" "s = C⟨l ⋅ σ⟩" "t = C⟨r ⋅ σ⟩" "(l, r) ∈ R" by (auto elim: rstep_r_p_s'E)
  then obtain f ls D rs where [simp]: "C = More f ls D rs" "i = length ls" by (cases C) auto
  let ?ss = "ls @ D⟨l ⋅ σ⟩ # rs" and ?ts = "ls @ D⟨r ⋅ σ⟩ # rs" and ?ti = "D⟨r ⋅ σ⟩"
  have "s = Fun f ?ss" "i < length ?ss" "t = Fun f (list_update ?ss i ?ti)" using * by simp_all
  moreover have "(?ss ! i, ?ti) ∈ rstep_r_p_s' R (l, r) p σ" using * by (auto intro: rstep_r_p_s'I)
  ultimately show ?thesis ..
qed

(* Converse of rstep_r_p_s'_argE: a step in argument i lifts to a step of
   Fun f ss at position i <# p. *)
lemma rstep_r_p_s'_argI:
  assumes "i < length ss" "(ss ! i, ti) ∈ rstep_r_p_s' R r p σ"
  shows "(Fun f ss, Fun f (ss[i := ti])) ∈ rstep_r_p_s' R r (i <# p) σ"
  using assms(1) rstep_r_p_s'_mono[OF assms(2), of "More f (take i ss) Hole (drop (Suc i) ss)"]
  by (auto simp: id_take_nth_drop[symmetric] upd_conv_take_nth_drop min_def)

(* For a well-formed TRS, the redex position of a step is a function position
   of the source term. *)
lemma wf_trs_implies_funposs:
  assumes "wf_trs R" "(s, t) ∈ rstep_r_p_s' R (l, r) p σ"
  shows "p ∈ funposs s"
proof -
  obtain C where *: "(l, r) ∈ R" "s = C⟨l ⋅ σ⟩" "p = hole_pos C" using assms by auto
  then show ?thesis
    using assms hole_pos_in_filled_funposs[of "l ⋅ σ" C Var] by (force simp: wf_trs_def')
qed

(* For a well-formed TRS, variables admit no rstep' step. *)
lemma NF_Var':
  assumes "wf_trs R"
  shows "(Var x, t) ∉ rstep' R"
  unfolding rstep'_iff_rstep_r_p_s' by (auto dest: wf_trs_implies_funposs[OF assms])

(* Filling the holes of C with elementwise rstep'-reducts yields a reduct of
   the filled context; the numbers of holes and of fillers must agree. *)
lemma fill_holes_rsteps:
  assumes "num_holes C = length ss" "num_holes C = length ts"
    "⋀i. i < length ss ⟹ (ss ! i, ts ! i) ∈ (rstep' ℛ)⇧*"
  shows "(fill_holes C ss, fill_holes C ts) ∈ (rstep' ℛ)⇧*"
  using assms
proof (induct C ss ts rule: fill_holes_induct2)
  case (MFun f Cs xs ys)
  show ?case using MFun(1,2,4)
    by (auto intro!: MFun(3) args_steps_imp_steps simp: partition_by_nth_nth)
qed auto

subsection ‹@{term partition_by} stuff›

(* partition_by_idcs ys i locates the flat index i in consecutive blocks of
   sizes ys and returns (block index, offset within that block).  There is no
   equation for [], so the result is only meaningful for i < sum_list ys. *)
fun partition_by_idcs :: "nat list ⇒ nat ⇒ (nat × nat)" where
  "partition_by_idcs (y # ys) i =
    (if i < y then (0, i) else let (j, k) = partition_by_idcs ys (i - y) in (Suc j, k))"

(* ys @1 i: block index of flat index i (first component of partition_by_idcs). *)
definition partition_by_idx1 (infix "@1" 105) where
  "ys @1 i = fst (partition_by_idcs ys i)"
(* ys @2 i: offset of flat index i within its block (second component). *)
definition partition_by_idx2 (infix "@2" 105) where
  "ys @2 i = snd (partition_by_idcs ys i)"

(* Partitioning filter P xs into blocks of size 1 (for elements satisfying P)
   or 0 (otherwise) yields [xs ! i] or [] at index i. *)
lemma partition_by_predicate:
  assumes "i < length xs"
  shows "partition_by (filter P xs) (map (λx. if P x then Suc 0 else 0) xs) ! i = (if P (xs ! i) then [xs ! i] else [])"
  using assms by (induct xs arbitrary: i) (auto simp: less_Suc_eq_0_disj)

(* Flat indexing agrees with indexing into partition_by xs ys at block ys @1 i
   and offset ys @2 i. *)
lemma nth_by_partition_by:
  "length xs = sum_list ys ⟹ i < sum_list ys ⟹ xs ! i = partition_by xs ys ! ys @1 i ! ys @2 i"
  apply (induct ys arbitrary: i xs)
  apply (auto simp: partition_by_idx1_def partition_by_idx2_def split: prod.splits)
  by (metis add_less_cancel_left diff_add_inverse fst_conv length_drop less_or_eq_imp_le linordered_semidom_class.add_diff_inverse nth_drop snd_conv)

(* nth_by_partition_by for a concatenation, with the block sizes given by the
   sublist lengths. *)
lemma nth_concat_by_shape:
  assumes "ys = map length xss" "i < sum_list ys"
  shows "concat xss ! i = xss ! ys @1 i ! ys @2 i"
  using nth_by_partition_by[of "concat xss" ys i] partition_by_concat_id[of xss ys] assms
  by (auto simp: length_concat)

(* Updating flat index i equals updating offset ys @2 i of block ys @1 i in
   partition_by xs ys and concatenating again. *)
lemma list_update_by_partition_by:
  "length xs = sum_list ys ⟹ i < sum_list ys ⟹
   xs[i := x] = concat (partition_by xs ys[ys @1 i := (partition_by xs ys ! ys @1 i)[ys @2 i := x]])"
proof (induct ys arbitrary: i xs)
  case (Cons y ys) show ?case
    using list_update_append[of "take y xs" "drop y xs" i x] Cons(1)[of "drop y xs" "i - y"] Cons(2,3)
    by (auto simp: partition_by_idx1_def partition_by_idx2_def less_diff_conv2 min.absorb2 split: prod.splits)
qed auto

(* list_update_by_partition_by for a concatenation, with the block sizes given
   by the sublist lengths. *)
lemma list_update_concat_by_shape:
  assumes "ys = map length xss" "i < sum_list ys"
  shows "concat xss[i := x] = concat (xss[ys @1 i := (xss ! ys @1 i)[ys @2 i := x]])"
  using list_update_by_partition_by[of "concat xss" ys i] partition_by_concat_id[of xss ys] assms
  by (auto simp: length_concat)

(* The block index of a valid flat index is a valid index into ys. *)
lemma partition_by_idx1_bound:
  "i < sum_list ys ⟹ ys @1 i < length ys"
  apply (induct ys arbitrary: i)
  apply (auto simp: partition_by_idx1_def split: prod.splits)
  by (metis add_diff_inverse_nat add_less_imp_less_left fst_conv)

(* The offset of a valid flat index is smaller than the size of its block. *)
lemma partition_by_idx2_bound:
  "i < sum_list ys ⟹ ys @2 i <  ys ! ys @1 i"
  apply (induct ys arbitrary: i)
  apply (auto simp: partition_by_idx1_def partition_by_idx2_def split: prod.splits)
  by (metis add_diff_inverse_nat fst_conv nat_add_left_cancel_less snd_conv)

subsection ‹From multihole contexts to terms and back›

(* Encode a multihole context as a term over variables 'v option: holes become
   Var None and variables v become Var (Some v). *)
fun mctxt_term_conv :: "('f, 'v) mctxt ⇒ ('f, 'v option) term" where
  "mctxt_term_conv MHole = Var None"
| "mctxt_term_conv (MVar v) = Var (Some v)"
| "mctxt_term_conv (MFun f Cs) = Fun f (map mctxt_term_conv Cs)"

(* Decode a term over 'v option back into a multihole context: Var None becomes
   a hole and Var (Some v) becomes MVar v. *)
fun term_mctxt_conv :: "('f, 'v option) term ⇒ ('f, 'v) mctxt" where
  "term_mctxt_conv (Var None) = MHole"
| "term_mctxt_conv (Var (Some v)) = MVar v"
| "term_mctxt_conv (Fun f ts) = MFun f (map term_mctxt_conv ts)"

(* Encoding after decoding is the identity on terms. *)
lemma mctxt_term_conv_inv [simp]:
  "mctxt_term_conv (term_mctxt_conv t) = t"
  by (induct t rule: term_mctxt_conv.induct) (auto simp: map_idI)

(* Decoding after encoding is the identity on multihole contexts. *)
lemma term_mctxt_conv_inv [simp]:
  "term_mctxt_conv (mctxt_term_conv t) = t"
  by (induct t rule: mctxt_term_conv.induct) (auto simp: map_idI)
  
(* The encoding is a bijection (its inverse is term_mctxt_conv). *)
lemma mctxt_term_conv_bij:
  "bij mctxt_term_conv"
  by (auto intro!: o_bij[of term_mctxt_conv mctxt_term_conv])

(* The decoding is a bijection (its inverse is mctxt_term_conv). *)
lemma term_mctxt_conv_bij:
  "bij term_mctxt_conv"
  by (auto intro!: o_bij[of mctxt_term_conv term_mctxt_conv])

(* Encoding the hole-free context of a term t renames each variable v to Some v. *)
lemma mctxt_term_conv_mctxt_of_term[simp]:
  "mctxt_term_conv (mctxt_of_term t) = t ⋅ (Var ∘ Some)"
  by (induct t) auto

(* Decoding counterpart of mctxt_term_conv_mctxt_of_term. *)
lemma term_mctxt_conv_mctxt_of_term_conv:
  "term_mctxt_conv (t ⋅ (Var ∘ Some)) = mctxt_of_term t"
  by (induct t) auto

(* The encoding turns the context order C ≤ D into weak_match of the encodings,
   with the arguments in the order (encoding of D, encoding of C). *)
lemma weak_match_mctxt_term_conv_mono:
  "C ≤ D ⟹ weak_match (mctxt_term_conv D) (mctxt_term_conv C)"
  by (induct C D rule: less_eq_mctxt_induct) auto

(* Substitution sending None (an encoded hole) to term_of_mctxt MHole and
   Some v to Var v. *)
definition term_of_mctxt_subst where "term_of_mctxt_subst = case_option (term_of_mctxt MHole) Var"

(* term_of_mctxt factors through the encoding followed by term_of_mctxt_subst. *)
lemma term_of_mctxt_to_mctxt_term_conv:
  "term_of_mctxt C = mctxt_term_conv C ⋅ term_of_mctxt_subst"
  by (induct C) (auto simp: term_of_mctxt_subst_def)

(* The positions of the encoding are all positions of the context, holes included. *)
lemma poss_mctxt_term_conv[simp]:
  "poss (mctxt_term_conv C) = all_poss_mctxt C"
  by (induct C) auto

(* The encoding preserves the function symbols. *)
lemma funas_term_mctxt_term_conv[simp]:
  "funas_term (mctxt_term_conv C) = funas_mctxt C"
  by (induct C) auto

(* The decoding preserves positions. *)
lemma all_poss_mctxt_term_mctxt_conv[simp]:
  "all_poss_mctxt (term_mctxt_conv t) = poss t"
  by (induct t rule: term_mctxt_conv.induct) auto

(* The decoding preserves the function symbols. *)
lemma funas_mctxt_term_mctxt_conv[simp]:
  "funas_mctxt (term_mctxt_conv t) = funas_term t"
  by (induct t rule: term_mctxt_conv.induct) auto

(* The decoding commutes with taking the subterm/subcontext at a position of t. *)
lemma subm_at_term_mctxt_conv:
  "p ∈ poss t ⟹ subm_at (term_mctxt_conv t) p = term_mctxt_conv (subt_at t p)"
  by (induct t p rule: subt_at.induct) auto

lemma subt_at_mctxt_term_conv:
  "p ∈ all_poss_mctxt C ⟹ subt_at (mctxt_term_conv C) p = mctxt_term_conv (subm_at C p)"
  by (induct C p rule: subm_at.induct) auto

lemma subm_at_subt_at_conv:
  "p ∈ all_poss_mctxt C ⟹ subm_at C p = term_mctxt_conv (subt_at (mctxt_term_conv C) p)"
  by (induct C p rule: subm_at.induct) auto

(* Encoding a filled context is the same as filling, at the term level, the
   context with variables wrapped in Some. Each filler is encoded separately.
   Requires one filler per hole. *)
lemma mctxt_term_conv_fill_holes_mctxt:
  assumes "num_holes C = length Cs"
  shows "mctxt_term_conv (fill_holes_mctxt C Cs) = fill_holes (map_vars_mctxt Some C) (map mctxt_term_conv Cs)"
  using assms by (induct C Cs rule: fill_holes_induct) (auto simp: comp_def)

(* Renaming the variables of a context with f corresponds to applying
   map_option f to the variables of its encoding. The hole marker None is left
   unchanged. *)
lemma mctxt_term_conv_map_vars_mctxt_subst:
  shows "mctxt_term_conv (map_vars_mctxt f C) = mctxt_term_conv C ⋅ (Var ∘ map_option f)"
  by (induct C) auto

(* an mctxt version of fill_holes_mctxt_fill_holes *)
(* Two-stage filling, written as a single filling. Filling L' with Cs' and then
   filling the resulting holes with Cs equals filling L' directly, where the
   i-th hole receives Cs' ! i already filled with its share of Cs (the i-th
   chunk computed by partition_holes Cs Cs').
   The proof states the claim for the term encodings of both sides, where the
   term-level lemma fill_holes_mctxt_fill_holes applies. It then decodes with
   term_mctxt_conv; the simp lemma term_mctxt_conv_inv closes the goal. *)
lemma fill_fill_holes_mctxt:
  assumes "length Cs' = num_holes L'" "length Cs = num_holes (fill_holes_mctxt L' Cs')"
  shows "fill_holes_mctxt (fill_holes_mctxt L' Cs') Cs = fill_holes_mctxt L'
     (map (λ(D, Es). fill_holes_mctxt D Es) (zip Cs' (partition_holes Cs Cs')))" (is "?L = ?R")
proof -
  (* NOTE(review): this note does not appear to be needed, because the lemma is
     applied explicitly via subst below. Kept unchanged. *)
  note fill_holes_mctxt_fill_holes
  (* The statement at the level of term encodings. *)
  have "fill_holes (fill_holes_mctxt (map_vars_mctxt Some L') (map (map_vars_mctxt Some) Cs'))
    (map mctxt_term_conv Cs) = fill_holes (map_vars_mctxt Some L')
    (map (λx. mctxt_term_conv (fill_holes_mctxt (Cs' ! x) (partition_holes Cs Cs' ! x))) [0..<num_holes L'])"
    using assms by (subst fill_holes_mctxt_fill_holes)
      (auto simp: comp_def mctxt_term_conv_fill_holes_mctxt length_partition_by_nth num_holes_fill_holes_mctxt intro!: arg_cong[of _ _ "fill_holes _"])
  (* Transport back to contexts by decoding both encoded sides. *)
  then have "term_mctxt_conv (mctxt_term_conv ?L) = term_mctxt_conv (mctxt_term_conv ?R)"
    using assms by (intro arg_cong[of _ _ "term_mctxt_conv"]) (auto simp: zip_nth_conv comp_def
      mctxt_term_conv_fill_holes_mctxt map_vars_mctxt_fill_holes_mctxt)
  then show ?thesis by simp
qed

(* Rewriting on multihole contexts. C rewrites to D exactly when their term
   encodings (holes as Var None, variables v as Var (Some v)) are related by
   rstep' R. The TRS may use a different variable type 'w than the contexts. *)
inductive_set mrstep :: "('f, 'w) trs ⇒ ('f, 'v) mctxt rel" for R where
  mrstep [intro]: "(mctxt_term_conv C, mctxt_term_conv D) ∈ rstep' R ⟹ (C, D) ∈ mrstep R"

(* Elimination form: an mrstep step is an rstep' step on the encodings. *)
lemma mrstepD: "(C, D) ∈ mrstep R ⟹ (mctxt_term_conv C, mctxt_term_conv D) ∈ rstep' R"
  by (rule mrstep.induct)

(* Introduction and elimination rules for mrstep stated with rstep instead of
   rstep'. The variables of the encodings, of type 'v option, are mapped back
   into 'v using from_option.
   NOTE(review): from_option and to_option are defined elsewhere. Presumably
   they are an injection of 'v option into 'v and a left inverse of it; such
   an injection is only possible when 'v is infinite. Confirm against their
   definitions. *)
lemma mrstepI_inf: (* caveat: note the implied 'v :: infinite *)
  assumes "(mctxt_term_conv C ⋅ (Var ∘ from_option), mctxt_term_conv D ⋅ (Var ∘ from_option)) ∈ rstep R"
  shows "(C, D) ∈ mrstep R"
  using rstep'_stable[OF assms[unfolded rstep_eq_rstep'], of "Var ∘ to_option"]
  by (intro mrstep.intros) (auto simp del: subst_subst_compose simp: subst_subst subst_compose_def)

lemma mrstepD_inf: (* caveat: note the implied 'v :: infinite *)
  "(C, D) ∈ mrstep R ⟹ (mctxt_term_conv C ⋅ (Var ∘ from_option), mctxt_term_conv D ⋅ (Var ∘ from_option)) ∈ rstep R"
  by (metis rstep'_stable mrstepD rstep_eq_rstep')

(* For a well-formed TRS, a context that is a single variable or a single hole
   cannot be rewritten by mrstep. Both encode to a variable term, so this
   follows from NF_Var' applied to the encoding. *)
lemma NF_MVar_MHole: 
  assumes "wf_trs R" and "C = MVar x ∨ C = MHole" 
  shows "(C, D) ∉ mrstep R"
  using NF_Var'[OF assms(1)] assms(2) by (force dest: mrstepD)

(* Function-symbol positions of a multihole context, defined as the function
   positions of its term encoding. *)
definition funposs_mctxt :: "('f, 'v) mctxt ⇒ pos set" where
  "funposs_mctxt C = funposs (mctxt_term_conv C)"

(* Function positions are among the (non-hole) positions of the context. *)
lemma funposs_mctxt_subset_poss_mctxt:
  "funposs_mctxt C ⊆ poss_mctxt C"
  by (induct C) (force simp: funposs_mctxt_def)+

(* ... and hence also among all positions, holes included. *)
lemma funposs_mctxt_subset_all_poss_mctxt:
  "funposs_mctxt C ⊆ all_poss_mctxt C"
  by (induct C) (force simp: funposs_mctxt_def)+

(* Extending a prefix keeps its function positions: if C ≤ D, every function
   position of C is a function position of D. *)
lemma funposs_mctxt_mono:
  "C ≤ D ⟹ p ∈ funposs_mctxt C ⟹ p ∈ funposs_mctxt D"
  unfolding less_eq_mctxt_prime funposs_mctxt_def
  by (induct C D arbitrary: p rule: less_eq_mctxt'.induct) auto

(* Converse on the shared part: if C ≤ D and p is a non-hole position of C,
   then p is a function position of C whenever it is one of D. *)
lemma funposs_mctxt_compat:
  "C ≤ D ⟹ p ∈ poss_mctxt C ⟹ p ∈ funposs_mctxt D ⟹ p ∈ funposs_mctxt C"
  unfolding less_eq_mctxt_prime funposs_mctxt_def
  by (induct C D arbitrary: p rule: less_eq_mctxt'.induct) auto

(* For the hole-free context of a term, funposs_mctxt agrees with funposs. *)
lemma funposs_mctxt_mctxt_of_term[simp]:
  "funposs_mctxt (mctxt_of_term t) = funposs t"
  by (induct t) (auto simp: funposs_mctxt_def)

(* Every proper prefix q of a hole position p is a function position. The
   proof proceeds by induction on C with case analysis on both positions.
   NOTE(review): the hide_lams option of metis is specific to older Isabelle
   releases. *)
lemma proper_prefix_hole_poss_imp_funposs:
  assumes "p ∈ hole_poss C" "q < p"
  shows "q ∈ funposs_mctxt C"
  using assms
  apply (induct C arbitrary: p q; case_tac p; case_tac q)
  apply (auto simp: funposs_mctxt_def)
  apply (metis (no_types, hide_lams) lessThan_iff less_simps(4) nth_mem pos.exhaust)+
  done

subsection ‹Finiteness of prefixes›

(* set_Cons of two finite sets is finite. The proof rewrites set_Cons A B as
   the image of A × B under uncurried list cons, then uses finite_imageI on
   the finite product. *)
lemma finite_set_Cons:
  assumes A: "finite A" and B: "finite B"
  shows "finite (set_Cons A B)"
proof -
  have "set_Cons A B = case_prod (op #) ` (A × B)" by (auto simp: set_Cons_def)
  then show ?thesis
    by (simp add: finite_imageI[OF finite_cartesian_product[OF A B],of "case_prod (op #)"])
qed

(* If every component set is finite, then listset As is finite. Proved by
   induction on As using finite_set_Cons. *)
lemma listset_finite:
  assumes "∀A ∈ set As. finite A"
  shows "finite (listset As)"
  using assms
  by (induct As) (auto simp: finite_set_Cons)

(* Pointwise characterisation of listset membership: xs has the same length
   as As, and its i-th element lies in the i-th set. *)
lemma elem_listset:
  "xs ∈ listset As = (length xs = length As ∧ (∀i < length As. xs ! i ∈ As ! i))"
proof (induct As arbitrary: xs)
  case (Cons A As xs) thus ?case
    by (cases xs) (auto simp: set_Cons_def nth_Cons nat.splits)
qed auto

(* Every multihole context has only finitely many prefixes N ≤ C.
   The proof describes the prefix set explicitly in each case:
   - a hole has only itself;
   - a variable has itself and the hole;
   - MFun f Cs has the hole, plus MFun f applied to any list of pointwise
     prefixes of Cs. This set is finite by listset_finite and the induction
     hypotheses. *)
lemma finite_pre_mctxt:
  fixes C :: "('f, 'v) mctxt"
  shows "finite { N. N ≤ C }"
proof (induct C)
  case MHole
  have *: "{ N. N ≤ MHole } = { MHole }" by (auto simp: less_eq_mctxt_def)
  show ?case by (simp add: *)
next
  case (MVar x)
  have *: "{ N. N ≤ MVar x } = { MHole, MVar x }"
    by (auto simp: less_eq_mctxt_def split_ifs elim: mctxt_neq_mholeE)
  show ?case by (simp add: *)
next
  case (MFun f Cs)
  have *: "{ N. N ≤ MFun f Cs } = { MHole } ∪ MFun f ` { Ds. Ds ∈ listset (map (λC. { N. N ≤ C }) Cs)}"
  unfolding elem_listset
    by (auto simp: image_def less_eq_mctxt_def split_ifs list_eq_iff_nth_eq elim!: mctxt_neq_mholeE) auto
  show ?case
    by (auto simp: * MFun listset_finite[of "map (λC. { N. N ≤ C }) Cs"])
qed

text ‹Well-foundedness of the strict order < on multihole contexts›

(* Size measure on multihole contexts: counts function symbols and variable
   occurrences, while holes contribute 0. It is used below to show that < on
   contexts is well-founded. *)
fun mctxt_syms :: "('f, 'v) mctxt ⇒ nat" where
  "mctxt_syms MHole = 0"
| "mctxt_syms (MVar v) = 1"
| "mctxt_syms (MFun f Cs) = 1 + sum_list (map mctxt_syms Cs)" 
  
(* mctxt_syms is monotone with respect to the prefix order. *)
lemma mctxt_syms_mono:
  "C ≤ D ⟹ mctxt_syms C ≤ mctxt_syms D"
by (induct D arbitrary: C; elim less_eq_mctxtE2)
  (auto simp: map_upt_len_conv[of mctxt_syms,symmetric] intro: sum_list_mono)

(* For equal-length lists of naturals that are pointwise ≤: either the sums
   are strictly ordered, or the lists are equal. Proved by induction on xs,
   splitting off the heads of both lists. *)
lemma sum_list_strict_mono_aux:
  fixes xs ys :: "nat list"
  shows "length xs = length ys ⟹ (⋀i. i < length ys ⟹ xs ! i ≤ ys ! i) ⟹ sum_list xs < sum_list ys ∨ xs = ys"
proof (induct xs arbitrary: ys)
  case (Cons x xs zs) note * = this
  then show ?case
  proof (cases zs)
    case (Cons y ys)
    have hd: "x ≤ y" and tl: "⋀i. i < length ys ⟹ xs ! i ≤ ys ! i"
      using *(3) by (auto simp: Cons nth_Cons nat.splits)
    show ?thesis using hd *(1)[OF _ tl] *(2) by (auto simp: nth_Cons Cons)
  qed auto
qed auto

(* mctxt_syms is strictly monotone: a strict prefix has strictly fewer
   symbols. The auxiliary claim (C ≤ D implies C = D or strictly fewer
   symbols) is proved by induction on D. The interesting case is the one
   where both contexts are function applications with the same root; there
   sum_list_strict_mono_aux is applied to the symbol counts of the arguments.
   NOTE(review): also is used where moreover would be conventional; either
   way the fact C < D reaches the final ultimately step. *)
lemma mctxt_syms_strict_mono[simp]:
  "C < D ⟹ mctxt_syms C < mctxt_syms D"
proof -
  assume "C < D"
  also have "C ≤ D ⟹ C = D ∨ mctxt_syms C < mctxt_syms D"
  proof ((induct D arbitrary: C; elim less_eq_mctxtE2), goal_cases)
    case (5 f Ds C Cs)
    have "i < length Ds ⟹ mctxt_syms (Cs ! i) ≤ mctxt_syms (Ds ! i)" for i
      using 5(1)[OF _ 5(4), of i] by (auto simp: 5(3))
    then show "C = MFun f Ds ∨ mctxt_syms C < mctxt_syms (MFun f Ds)"
      using 5(1)[OF _ 5(4)] sum_list_strict_mono_aux[of "map mctxt_syms Cs" "map mctxt_syms Ds"]
      by (auto simp: list_eq_iff_nth_eq 5(2,3))
  qed auto
  ultimately show ?thesis by (auto simp: less_mctxt_def)
qed

(* The strict order < on multihole contexts is well-founded. It is contained
   in the inverse image of < on naturals under mctxt_syms, which is
   well-founded. *)
lemma wf_less_mctxt [simp]:
  "wf { (C :: ('f, 'v) mctxt, D). C < D }"
  by (rule wf_subset[of "inv_image { (a, b). a < b } mctxt_syms"]) (auto simp: wf_less)

(* Zipping a list with a constant list of the same length pairs every
   element with that constant. Despite the names, these lemmas are about zip
   with replicate. map_zip2 has the constant on the right, map_zip1 on the
   left. *)
lemma map_zip2 [simp]:
  "n = length xs ⟹ zip xs (replicate n y) = map (λx. (x,y)) xs"
  by (induct xs arbitrary: n) auto

lemma map_zip1 [simp]:
  "n = length ys ⟹ zip (replicate n x) ys = map (λy. (x,y)) ys"
  by (induct ys arbitrary: n) auto

subsection ‹Hole positions, left to right›

(* The hole positions of a context as a list, ordered left to right.
   A variable has none and a hole has only the root position. For MFun, the
   hole positions of argument i are each prefixed with i and concatenated in
   argument order. *)
fun hole_poss' :: "('f, 'v) mctxt ⇒ pos list" where
  "hole_poss' (MVar x) = []"
| "hole_poss' MHole = [ε]"
| "hole_poss' (MFun f cs) = concat (map (λi . map (op <# i) (hole_poss' (cs ! i))) [0..<length cs])"

(* hole_poss' enumerates exactly the set hole_poss. *)
lemma set_hole_poss': "set (hole_poss' C) = hole_poss C"
  by (induct C) auto

(* It contains one entry per hole. *)
lemma length_hole_poss'[simp]: "length (hole_poss' C) = num_holes C"
  by (induct C) (auto simp: length_concat intro!: arg_cong[of _ _ sum_list] nth_equalityI)

(* Renaming variables does not move any holes. *)
lemma hole_poss'_map_vars_mctxt[simp]:
 "hole_poss' (map_vars_mctxt f C) = hole_poss' C"
  by (induct C rule: hole_poss'.induct) (auto intro: arg_cong[of _ _ concat])

(* The i-th filler ends up at the i-th hole position in left-to-right order:
   the subterm of fill_holes C ts at hole_poss' C ! i is ts ! i. *)
lemma subt_at_fill_holes:
  assumes "length ts = num_holes C" and "i < num_holes C"
  shows "subt_at (fill_holes C ts) (hole_poss' C ! i) = ts ! i"
  using assms(1)[symmetric] assms(2)
proof (induct C ts arbitrary: i rule: fill_holes_induct)
  case (MFun f Cs ts)
  (* The index i is in range for the concatenated hole-position list. *)
  have "i < length (concat (map (λi. map (op <# i) (hole_poss' (Cs ! i))) [0..<length Cs]))"
    using MFun arg_cong[OF map_nth, of "map num_holes" Cs] by (auto simp: length_concat comp_def)
  thus ?case
    by (auto simp: nth_map[symmetric] map_concat intro!: arg_cong[of _ _ "λx. concat x ! _"])
       (insert MFun, auto intro!: nth_equalityI)
qed auto

(* Context analogue of subt_at_fill_holes: filling with contexts places
   Ds ! i at the i-th hole position. *)
lemma subm_at_fill_holes_mctxt:
  assumes "length Ds = num_holes C" and "i < num_holes C"
  shows "subm_at (fill_holes_mctxt C Ds) (hole_poss' C ! i) = Ds ! i"
  using assms(1)[symmetric] assms(2)
proof (induct C Ds arbitrary: i rule: fill_holes_induct)
  case (MFun f Cs Ds)
  (* The index i is in range for the concatenated hole-position list. *)
  have "i < length (concat (map (λi. map (op <# i) (hole_poss' (Cs ! i))) [0..<length Cs]))"
    using MFun arg_cong[OF map_nth, of "map num_holes" Cs] by (auto simp: length_concat comp_def)
  thus ?case
    by (auto simp: nth_map[symmetric] map_concat intro!: arg_cong[of _ _ "λx. concat x ! _"])
       (insert MFun, auto intro!: nth_equalityI)
qed auto

(* The context surrounding the i-th hole position does not depend on the
   i-th filler: replacing ts ! i by t leaves ctxt_of_pos_term at that
   position unchanged. The MFun case is a low-level index computation over
   the concatenated, partitioned hole and filler lists. *)
lemma ctxt_of_pos_term_fill_holes:
  assumes "num_holes C = length ts" "i < num_holes C"
  shows "ctxt_of_pos_term (hole_poss' C ! i) (fill_holes C (ts[i := t])) =
    ctxt_of_pos_term (hole_poss' C ! i) (fill_holes C ts)"
  using assms
proof (induct C ts arbitrary: i rule: fill_holes_induct)
  case (MFun f Cs ts)
  then show ?case using partition_by_idx1_bound[of i "map num_holes Cs"] partition_by_idx2_bound[of i "map num_holes Cs"]
  unfolding fill_holes.simps hole_poss'.simps num_holes.simps
    apply (subst (1 2) nth_concat_by_shape, (simp add: comp_def map_upt_len_conv; fail)+)
    apply (subst list_update_concat_by_shape, (simp add: comp_def map_upt_len_conv; fail)+)
    apply (subst (1 2) nth_map, simp)
    apply (subst (1 2) nth_map, simp)
    apply (subst partition_by_concat_id)
    apply (auto intro!: nth_equalityI simp: nth_list_update)
    done
qed auto

(* Every hole position of C remains a valid position after filling.
   First all of hole_poss C is shown to lie in the positions of the filled
   term, using that C is a prefix of the filled term (fill_holes_suffix). *)
lemma hole_poss_in_poss_fill_holes:
  assumes "num_holes C = length ts" "i < num_holes C"
  shows "hole_poss' C ! i ∈ poss (fill_holes C ts)"
proof -
  have "hole_poss C ⊆ poss (fill_holes C ts)"
    using all_poss_mctxt_mono[OF fill_holes_suffix[of C ts]] assms by (simp add: all_poss_mctxt_conv)
  then show ?thesis using assms(2) set_hole_poss'[of C] by auto
qed

(* Replacing the subterm at the i-th hole position of a filled context is the
   same as filling with an updated i-th filler. Follows from ctxt_supt_id,
   ctxt_of_pos_term_fill_holes and subt_at_fill_holes. *)
lemma replace_at_fill_holes:
  assumes "num_holes C = length ts" "i < num_holes C"
  shows "replace_at (fill_holes C ts) (hole_poss' C ! i) ti = fill_holes C (ts[i := ti])"
proof -
  show ?thesis using assms ctxt_supt_id[OF hole_poss_in_poss_fill_holes, of C "ts[i := ti]" i]
    by (simp add: ctxt_of_pos_term_fill_holes subt_at_fill_holes)
qed

(* A rewrite step in the i-th filler lifts to a step of the filled context,
   with the same rule and substitution. The step position is prefixed by the
   i-th hole position. *)
lemma fill_holes_rstep_r_p_s':
  assumes "num_holes C = length ss" "i < num_holes C" "(ss ! i, ti) ∈ rstep_r_p_s' ℛ (l, r) p σ"
  shows "(fill_holes C ss, fill_holes C (ss[i := ti])) ∈ rstep_r_p_s' ℛ (l, r) (hole_poss' C ! i <#> p) σ"
  using rstep_r_p_s'_mono[OF assms(3), of "ctxt_of_pos_term (hole_poss' C ! i) (fill_holes C ss)"] assms(1,2)
    hole_poss_in_poss_fill_holes[OF assms(1,2)]
  by (simp add: replace_at_fill_holes)

(* Splitting a non-hole position at a prefix p: p <#> q is a non-hole
   position of C iff p is a position of C (holes allowed) and q is a non-hole
   position of the subcontext at p. *)
lemma poss_mctxt_append_poss_mctxt:
  "(p <#> q) ∈ poss_mctxt C ⟷ p ∈ all_poss_mctxt C ∧ q ∈ poss_mctxt (subm_at C p)"
  by (induct p arbitrary: C; case_tac C) auto

(* The holes of a filled context are exactly the holes of the fillers, each
   placed below the hole of C it was put into. Both inclusions are proved by
   induction via fill_holes_induct; the MFun case translates between global
   hole indices and (argument, local hole) index pairs. *)
lemma hole_poss_fill_holes_mctxt:
  assumes "num_holes C = length Ds"
  shows "hole_poss (fill_holes_mctxt C Ds) = {hole_poss' C ! i <#> q |i q. i < length Ds ∧ q ∈ hole_poss (Ds ! i)}"
    (is "?L = ?R")
  using assms
proof -
  (* Left to right: a hole of the filled context lies inside some filler. *)
  have "p ∈ ?L ⟹ p ∈ ?R" for p using assms
  proof (induct C Ds arbitrary: p rule: fill_holes_induct)
    case (MFun f Cs Ds)
    obtain i p' where
      *: "i < length Cs" "p = i <# p'"
      "p' ∈ hole_poss (fill_holes_mctxt (Cs ! i) (partition_holes Ds Cs ! i))"
      using MFun(1,3) by (auto)
    with MFun(2)[of i p'] obtain j q where
      "j <  length (partition_holes Ds Cs ! i)" "q ∈ hole_poss (partition_holes Ds Cs ! i ! j)"
      "p' = hole_poss' (Cs ! i) ! j <#> q" by auto
    then show ?case
      using MFun(1) * partition_by_nth_nth(1)[of "map num_holes Cs" "hole_poss' (MFun f Cs)" i j,
          simplified, unfolded length_concat]
        partition_by_concat_id[of "map (λi. map (op <# i) (hole_poss' (Cs ! i))) [0..<length Cs]"
          "map num_holes Cs", simplified]
      by (auto intro!: exI[of _ "partition_by_idx (sum_list (map num_holes Cs)) (map num_holes Cs) i j"] exI[of _ q]
        simp: partition_by_nth_nth same_append_eq'[of "PCons _ _", simplified] comp_def map_upt_len_conv)
  qed auto
  (* Right to left: a hole of filler i, placed below hole i of C, is a hole of
     the filled context. Here ?i indexes an argument (?i < length Cs) and ?j
     a hole within that argument (?j < num_holes (Cs ! ?i)). *)
  moreover have "p ∈ ?R ⟹ p ∈ ?L" for p using assms
  proof (induct C Ds arbitrary: p rule: fill_holes_induct)
    case (MFun f Cs Ds)
    obtain i q where *: "i < length Ds" "q ∈ hole_poss (Ds ! i)"
      "p = concat (map (λi. map (op <# i) (hole_poss' (Cs ! i))) [0..<length Cs]) ! i <#> q"
      using MFun(1,3) by auto
    let ?i = "map num_holes Cs @1 i" and ?j = "map num_holes Cs @2 i"
    have
      "?i < length Cs" "?j < num_holes (Cs ! ?i)" "q ∈ hole_poss (partition_holes Ds Cs ! ?i ! ?j)"
      "p = ?i <# hole_poss' (Cs ! ?i) ! ?j <#> q"
      using partition_by_idx1_bound[of _ "map num_holes Cs"] * MFun(1)
        partition_by_idx2_bound[of _ "map num_holes Cs"] nth_by_partition_by[of Ds "map num_holes Cs"]
        nth_by_partition_by[of "hole_poss' (MFun f Cs)" "map num_holes Cs" i]
        partition_by_concat_id[of "map (λi. map (op <# i) (hole_poss' (Cs ! i))) [0..<length Cs]"
          "map num_holes Cs"]
      by (auto simp: length_concat comp_def map_upt_len_conv )
    then show ?case using MFun(1)
      by (auto intro!: MFun(2)[of ?i "hole_poss' (Cs ! ?i) ! ?j <#> q"] exI[of _ ?j] exI[of _ "q"])
  qed auto
  ultimately show ?thesis by blast
qed

(* If q is a hole position of C and p ≤ q, then the remainder pos_diff q p is
   a hole position of the subcontext of C at p. *)
lemma pos_diff_hole_possI:
  "q ∈ hole_poss C ⟹ p ≤ q ⟹ pos_diff q p ∈ hole_poss (subm_at C p)"
  by (induct C p arbitrary: q rule: subm_at.induct) auto

(* When C is a prefix of the hole-free context of t, unfill_holes C t is the
   list of subterms of t at the hole positions of C, in left-to-right order. *)
lemma unfill_holes_conv:
  assumes "C ≤ mctxt_of_term t"
  shows "unfill_holes C t = map (subt_at t) (hole_poss' C)"
  using assms
proof (induct C t rule: unfill_holes.induct)
  case (3 f Cs g ts) show ?case using 3(2)
    by (auto elim!: less_eq_mctxtE2 simp: map_concat comp_def 3(1) intro!: arg_cong[of _ _ concat])
qed (auto elim: less_eq_mctxtE2)

(* Unfilling D against t can be done in two stages through a prefix C of D.
   For each hole position p of C, in left-to-right order, unfill the
   subcontext of D at p against the subterm of t at p, then concatenate the
   results. The MFun case decomposes D and t argument-wise and applies the
   induction hypothesis to each argument. *)
lemma unfill_holes_by_prefix:
  assumes "C ≤ D" and "D ≤ mctxt_of_term t"
  shows "unfill_holes D t = concat (map (λp. unfill_holes (subm_at D p) (subt_at t p)) (hole_poss' C))"
  using assms
proof (induct C arbitrary: D t)
  case (MVar x) thus ?case by (cases t) (auto elim!: less_eq_mctxtE1)
next
  case (MFun f Cs)
  obtain Ds where D[simp]: "D = MFun f Ds" "length Ds = length Cs" using MFun(2) by (auto elim: less_eq_mctxtE1)
  obtain ts where t[simp]: "t = Fun f ts" "length ts = length Cs" using MFun(3) by (cases t) (auto elim: less_eq_mctxtE1)
  (* Induction hypothesis instantiated for each argument position. *)
  have "i < length Cs ⟹ unfill_holes (Ds ! i) (ts ! i) =
    concat (map (λp. unfill_holes (subm_at (Ds ! i) p) (subt_at (ts ! i) p)) (hole_poss' (Cs ! i)))" for i
  proof (intro MFun(1))
    assume "i < length Cs" thus "Cs ! i ≤ Ds ! i" using MFun(2) by (auto elim: less_eq_mctxtE1)
  next
    assume "i < length Cs" thus "Ds ! i ≤ mctxt_of_term (ts ! i)" using MFun(3)
    by (auto elim: less_eq_mctxt_MFunE1)
  qed auto
  (* Rephrase the per-argument results at the level of the whole MFun / Fun. *)
  then have *: "map (λi. unfill_holes (Ds ! i) (ts ! i)) [0..<length Cs] =
        map (concat ∘ map (λp. unfill_holes (subm_at (MFun f Ds) p) (Fun f ts |_ p)) ∘
             (λi. map (op <# i) (hole_poss' (Cs ! i)))) [0..<length Cs]"
    by (intro nth_equalityI) (auto simp: o_def)
  then show ?case by (auto simp add: map_concat map_map[symmetric] simp del: map_map)
qed auto

(* Context analogue of unfill_holes_conv: for C ≤ D, unfill_holes_mctxt C D
   lists the subcontexts of D at the hole positions of C, left to right. *)
lemma unfill_holes_mctxt_conv:
  assumes "C ≤ D"
  shows "unfill_holes_mctxt C D = map (subm_at D) (hole_poss' C)"
  using assms
proof (induct C D rule: unfill_holes_mctxt.induct)
  case (3 f Cs g Ds) show ?case using 3(2)
    by (auto elim!: less_eq_mctxtE2 simp: map_concat comp_def 3(1) intro!: arg_cong[of _ _ concat])
qed (auto elim: less_eq_mctxtE2)

(* C with variables wrapped in Some is a prefix of the hole-free context of its
   own term encoding. *)
lemma map_vars_Some_le_mctxt_of_term_mctxt_term_conv:
  "map_vars_mctxt Some C ≤ mctxt_of_term (mctxt_term_conv C)"
  by (induct C) (auto intro: less_eq_mctxtI1)

(* Term-level unfilling of encodings agrees with context-level unfilling,
   followed by encoding each resulting piece. *)
lemma unfill_holes_map_vars_mctxt_Some_mctxt_term_conv_conv:
  "C ≤ E ⟹ unfill_holes (map_vars_mctxt Some C) (mctxt_term_conv E) = map mctxt_term_conv (unfill_holes_mctxt C E)"
  by (induct C E rule: less_eq_mctxt_induct) (auto simp: map_concat intro!: arg_cong[of _ _ concat])

(* Two-stage unfilling for contexts. Unfilling E against C filled with Ds is
   the same as first unfilling E against C, then unfilling each piece against
   the corresponding Ds ! i, and concatenating.
   The proof establishes C ≤ E and Ds ! i ≤ (unfill_holes_mctxt C E) ! i.
   It then transports the term-level lemma unfill_holes_by_prefix' (not part
   of this section) through the term encoding, with variables wrapped in
   Some. *)
lemma unfill_holes_mctxt_by_prefix':
  assumes "num_holes C = length Ds" "fill_holes_mctxt C Ds ≤ E"
  shows "unfill_holes_mctxt (fill_holes_mctxt C Ds) E = concat (map (λ(D, E). unfill_holes_mctxt D E) (zip Ds (unfill_holes_mctxt C E)))"
proof -
  have "C ≤ E" using assms(2) fill_holes_mctxt_suffix[OF assms(1)[symmetric]] by auto
  have [simp]: "i < length Ds ⟹ Ds ! i ≤ unfill_holes_mctxt C E ! i" for i using assms
    by (subst (asm) fill_unfill_holes_mctxt[OF `C ≤ E`, symmetric])
      (auto simp: less_eq_fill_holes_iff `C ≤ E`)
  show ?thesis using `C ≤ E` assms
    arg_cong[OF unfill_holes_by_prefix'[of "map_vars_mctxt Some C" "map (map_vars_mctxt Some) Ds" "mctxt_term_conv E"], of "map term_mctxt_conv"]
    order.trans[OF map_vars_mctxt_mono[OF assms(2), of Some] map_vars_Some_le_mctxt_of_term_mctxt_term_conv[of E]]
    unfolding map_vars_mctxt_fill_holes_mctxt[symmetric, OF assms(1)]
    by (auto simp: zip_map1 zip_map2 comp_def unfill_holes_map_vars_mctxt_Some_mctxt_term_conv_conv
      prod.case_distrib map_concat in_set_conv_nth intro!: arg_cong[of _ _ concat])
qed

(* The i-th hole position of C is still a hole after filling iff the i-th
   filler is itself just a hole.
   For one direction, read off the i-th unfilled piece via
   unfill_holes_mctxt_conv. The other direction follows from
   hole_poss_fill_holes_mctxt. *)
lemma hole_poss_fill_holes_mctxt_conv:
  assumes "i < num_holes C" "length Cs = num_holes C"
  shows "hole_poss' C ! i ∈ hole_poss (fill_holes_mctxt C Cs) ⟷ Cs ! i = MHole" (is "?L ⟷ ?R")
proof
  assume ?L then show ?R
    using assms arg_cong[OF unfill_holes_mctxt_conv[of C "fill_holes_mctxt C Cs"], of "λCs. Cs ! i"]
    by (auto simp: unfill_fill_holes_mctxt)
next
  assume ?R then show ?L using assms by (force simp: hole_poss_fill_holes_mctxt)
qed

end (* LS_Extras *)