theory RBT_ext
imports
"HOL-Library.RBT_Impl"
Containers_Auxiliary
List_Fusion
begin
section {* More on red-black trees *}
subsection {* More lemmas *}
context linorder begin
lemma is_rbt_fold_rbt_insert_impl:
"is_rbt t ⟹ is_rbt (RBT_Impl.fold rbt_insert t' t)"
by(simp add: rbt_insert_def is_rbt_fold_rbt_insertwk)
lemma rbt_sorted_fold_insert: "rbt_sorted t ⟹ rbt_sorted (RBT_Impl.fold rbt_insert t' t)"
by(induct t' arbitrary: t)(simp_all add: rbt_insert_rbt_sorted)
lemma rbt_lookup_rbt_insert': "rbt_sorted t ⟹ rbt_lookup (rbt_insert k v t) = rbt_lookup t(k ↦ v)"
by(simp add: rbt_insert_def rbt_lookup_rbt_insertwk fun_eq_iff split: option.split)
lemma rbt_lookup_fold_rbt_insert_impl:
"rbt_sorted t2 ⟹
rbt_lookup (RBT_Impl.fold rbt_insert t1 t2) = rbt_lookup t2 ++ map_of (rev (RBT_Impl.entries t1))"
proof(induction t1 arbitrary: t2)
case Empty thus ?case by simp
next
case (Branch c l x k r)
show ?case using Branch.prems
by(simp add: map_add_def Branch.IH rbt_insert_rbt_sorted rbt_sorted_fold_insert rbt_lookup_rbt_insert' fun_eq_iff split: option.split)
qed
end
subsection {* Build the cross product of two RBTs *}
context fixes f :: "'a ⇒ 'b ⇒ 'c ⇒ 'd ⇒ 'e" begin
definition alist_product :: "('a × 'b) list ⇒ ('c × 'd) list ⇒ (('a × 'c) × 'e) list"
where "alist_product xs ys = concat (map (λ(a, b). map (λ(c, d). ((a, c), f a b c d)) ys) xs)"
lemma alist_product_simps [simp]:
"alist_product [] ys = []"
"alist_product xs [] = []"
"alist_product ((a, b) # xs) ys = map (λ(c, d). ((a, c), f a b c d)) ys @ alist_product xs ys"
by(simp_all add: alist_product_def)
lemma append_alist_product_conv_fold:
"zs @ alist_product xs ys = rev (fold (λ(a, b). fold (λ(c, d) rest. ((a, c), f a b c d) # rest) ys) xs (rev zs))"
proof(induction xs arbitrary: zs)
case Nil thus ?case by simp
next
case (Cons x xs)
obtain a b where x: "x = (a, b)" by(cases x)
have "⋀zs. fold (λ(c, d). (#) ((a, c), f a b c d)) ys zs =
rev (map (λ(c, d). ((a, c), f a b c d)) ys) @ zs"
by(induct ys) auto
with Cons.IH[of "zs @ map (λ(c, d). ((a, c), f a b c d)) ys"] x
show ?case by simp
qed
lemma alist_product_code [code]:
"alist_product xs ys =
rev (fold (λ(a, b). fold (λ(c, d) rest. ((a, c), f a b c d) # rest) ys) xs [])"
using append_alist_product_conv_fold[of "[]" xs ys]
by simp
lemma set_alist_product:
"set (alist_product xs ys) =
(λ((a, b), (c, d)). ((a, c), f a b c d)) ` (set xs × set ys)"
by(auto 4 3 simp add: alist_product_def intro: rev_image_eqI rev_bexI)
lemma distinct_alist_product:
"⟦ distinct (map fst xs); distinct (map fst ys) ⟧
⟹ distinct (map fst (alist_product xs ys))"
proof(induct xs)
case Nil thus ?case by simp
next
case (Cons x xs)
obtain a b where x: "x = (a, b)" by(cases x)
have "distinct (map (fst ∘ (λ(c, d). ((a, c), f a b c d))) ys)"
using `distinct (map fst ys)` by(induct ys)(auto intro: rev_image_eqI)
thus ?case using Cons x by(auto simp add: set_alist_product intro: rev_image_eqI)
qed
lemma map_of_alist_product:
"map_of (alist_product xs ys) (a, c) =
(case map_of xs a of None ⇒ None
| Some b ⇒ map_option (f a b c) (map_of ys c))"
proof(induction xs)
case Nil thus ?case by simp
next
case (Cons x xs)
obtain a b where x: "x = (a, b)" by (cases x)
let ?map = "map (λ(c, d). ((a, c), f a b c d)) ys"
have "map_of ?map (a, c) = map_option (f a b c) (map_of ys c)"
by(induct ys) auto
moreover {
fix a' assume "a' ≠ a"
hence "map_of ?map (a', c) = None"
by(induct ys) auto }
ultimately show ?case using x Cons.IH
by(auto simp add: map_add_def split: option.split)
qed
definition rbt_product :: "('a, 'b) rbt ⇒ ('c, 'd) rbt ⇒ ('a × 'c, 'e) rbt"
where
"rbt_product rbt1 rbt2 = rbtreeify (alist_product (RBT_Impl.entries rbt1) (RBT_Impl.entries rbt2))"
lemma rbt_product_code [code]:
"rbt_product rbt1 rbt2 =
rbtreeify (rev (RBT_Impl.fold (λa b. RBT_Impl.fold (λc d rest. ((a, c), f a b c d) # rest) rbt2) rbt1 []))"
unfolding rbt_product_def alist_product_code RBT_Impl.fold_def ..
end
context
fixes leq_a :: "'a ⇒ 'a ⇒ bool" (infix "⊑⇩a" 50)
and less_a :: "'a ⇒ 'a ⇒ bool" (infix "⊏⇩a" 50)
and leq_b :: "'b ⇒ 'b ⇒ bool" (infix "⊑⇩b" 50)
and less_b :: "'b ⇒ 'b ⇒ bool" (infix "⊏⇩b" 50)
assumes lin_a: "class.linorder leq_a less_a"
and lin_b: "class.linorder leq_b less_b"
begin
abbreviation (input) less_eq_prod' :: "('a × 'b) ⇒ ('a × 'b) ⇒ bool" (infix "⊑" 50)
where "less_eq_prod' ≡ less_eq_prod leq_a less_a leq_b"
abbreviation (input) less_prod' :: "('a × 'b) ⇒ ('a × 'b) ⇒ bool" (infix "⊏" 50)
where "less_prod' ≡ less_prod leq_a less_a less_b"
lemmas linorder_prod = linorder_prod[OF lin_a lin_b]
lemma sorted_alist_product:
assumes xs: "linorder.sorted leq_a (map fst xs)" "distinct (map fst xs)"
and ys: "linorder.sorted (⊑⇩b) (map fst ys)"
shows "linorder.sorted (⊑) (map fst (alist_product f xs ys))"
proof -
interpret a: linorder "(⊑⇩a)" "(⊏⇩a)" by(fact lin_a)
note [simp] =
linorder.sorted.simps(1)[OF linorder_prod] linorder.sorted.simps(2)[OF linorder_prod]
linorder.sorted_append[OF linorder_prod]
linorder.sorted.simps(2)[OF lin_b]
show ?thesis using xs
proof(induction xs)
case Nil show ?case by simp
next
case (Cons x xs)
obtain a b where x: "x = (a, b)" by(cases x)
have "linorder.sorted (⊑) (map fst (map (λ(c, d). ((a, c), f a b c d)) ys))"
using ys by(induct ys) auto
thus ?case using x Cons
by(fastforce simp add: set_alist_product a.not_less dest: bspec a.antisym intro: rev_image_eqI)
qed
qed
lemma is_rbt_rbt_product:
"⟦ ord.is_rbt (⊏⇩a) rbt1; ord.is_rbt (⊏⇩b) rbt2 ⟧
⟹ ord.is_rbt (⊏) (rbt_product f rbt1 rbt2)"
unfolding rbt_product_def
by(blast intro: linorder.is_rbt_rbtreeify[OF linorder_prod] sorted_alist_product linorder.rbt_sorted_entries[OF lin_a] ord.is_rbt_rbt_sorted linorder.distinct_entries[OF lin_a] linorder.rbt_sorted_entries[OF lin_b] distinct_alist_product linorder.distinct_entries[OF lin_b])
lemma rbt_lookup_rbt_product:
"⟦ ord.is_rbt (⊏⇩a) rbt1; ord.is_rbt (⊏⇩b) rbt2 ⟧
⟹ ord.rbt_lookup (⊏) (rbt_product f rbt1 rbt2) (a, c) =
(case ord.rbt_lookup (⊏⇩a) rbt1 a of None ⇒ None
| Some b ⇒ map_option (f a b c) (ord.rbt_lookup (⊏⇩b) rbt2 c))"
by(simp add: rbt_product_def linorder.rbt_lookup_rbtreeify[OF linorder_prod] linorder.is_rbt_rbtreeify[OF linorder_prod] sorted_alist_product linorder.rbt_sorted_entries[OF lin_a] ord.is_rbt_rbt_sorted linorder.distinct_entries[OF lin_a] linorder.rbt_sorted_entries[OF lin_b] distinct_alist_product linorder.distinct_entries[OF lin_b] map_of_alist_product linorder.map_of_entries[OF lin_a] linorder.map_of_entries[OF lin_b] cong: option.case_cong)
end
hide_const less_eq_prod' less_prod'
subsection {* Build an RBT where keys are paired with themselves *}
primrec RBT_Impl_diag :: "('a, 'b) rbt ⇒ ('a × 'a, 'b) rbt"
where
"RBT_Impl_diag rbt.Empty = rbt.Empty"
| "RBT_Impl_diag (rbt.Branch c l k v r) = rbt.Branch c (RBT_Impl_diag l) (k, k) v (RBT_Impl_diag r)"
lemma entries_RBT_Impl_diag:
"RBT_Impl.entries (RBT_Impl_diag t) = map (λ(k, v). ((k, k), v)) (RBT_Impl.entries t)"
by(induct t) simp_all
lemma keys_RBT_Impl_diag:
"RBT_Impl.keys (RBT_Impl_diag t) = map (λk. (k, k)) (RBT_Impl.keys t)"
by(simp add: RBT_Impl.keys_def entries_RBT_Impl_diag split_beta)
lemma rbt_sorted_RBT_Impl_diag:
"ord.rbt_sorted lt t ⟹ ord.rbt_sorted (less_prod leq lt lt) (RBT_Impl_diag t)"
by(induct t)(auto simp add: ord.rbt_sorted.simps ord.rbt_less_prop ord.rbt_greater_prop keys_RBT_Impl_diag)
lemma bheight_RBT_Impl_diag:
"bheight (RBT_Impl_diag t) = bheight t"
by(induct t) simp_all
lemma inv_RBT_Impl_diag:
assumes "inv1 t" "inv2 t"
shows "inv1 (RBT_Impl_diag t)" "inv2 (RBT_Impl_diag t)"
and "color_of t = color.B ⟹ color_of (RBT_Impl_diag t) = color.B"
using assms by(induct t)(auto simp add: bheight_RBT_Impl_diag)
lemma is_rbt_RBT_Impl_diag:
"ord.is_rbt lt t ⟹ ord.is_rbt (less_prod leq lt lt) (RBT_Impl_diag t)"
by(simp add: ord.is_rbt_def rbt_sorted_RBT_Impl_diag inv_RBT_Impl_diag)
lemma (in linorder) rbt_lookup_RBT_Impl_diag:
"ord.rbt_lookup (less_prod (≤) (<) (<)) (RBT_Impl_diag t) =
(λ(k, k'). if k = k' then ord.rbt_lookup (<) t k else None)"
by(induct t)(auto simp add: ord.rbt_lookup.simps fun_eq_iff)
subsection {* Folding and quantifiers over RBTs *}
definition RBT_Impl_fold1 :: "('a ⇒ 'a ⇒ 'a) ⇒ ('a, unit) RBT_Impl.rbt ⇒ 'a"
where "RBT_Impl_fold1 f rbt = fold f (tl (RBT_Impl.keys rbt)) (hd (RBT_Impl.keys rbt))"
lemma RBT_Impl_fold1_simps [simp, code]:
"RBT_Impl_fold1 f rbt.Empty = undefined"
"RBT_Impl_fold1 f (Branch c rbt.Empty k v r) = RBT_Impl.fold (λk v. f k) r k"
"RBT_Impl_fold1 f (Branch c (Branch c' l' k' v' r') k v r) =
RBT_Impl.fold (λk v. f k) r (f k (RBT_Impl_fold1 f (Branch c' l' k' v' r')))"
by(simp_all add: RBT_Impl_fold1_def RBT_Impl.keys_def fold_map RBT_Impl.fold_def split_def o_def tl_append hd_def split: list.split)
definition RBT_Impl_rbt_all :: "('a ⇒ 'b ⇒ bool) ⇒ ('a, 'b) rbt ⇒ bool"
where [code del]: "RBT_Impl_rbt_all P rbt = (∀(k, v) ∈ set (RBT_Impl.entries rbt). P k v)"
lemma RBT_Impl_rbt_all_simps [simp, code]:
"RBT_Impl_rbt_all P rbt.Empty ⟷ True"
"RBT_Impl_rbt_all P (Branch c l k v r) ⟷ P k v ∧ RBT_Impl_rbt_all P l ∧ RBT_Impl_rbt_all P r"
by(auto simp add: RBT_Impl_rbt_all_def)
definition RBT_Impl_rbt_ex :: "('a ⇒ 'b ⇒ bool) ⇒ ('a, 'b) rbt ⇒ bool"
where [code del]: "RBT_Impl_rbt_ex P rbt = (∃(k, v) ∈ set (RBT_Impl.entries rbt). P k v)"
lemma RBT_Impl_rbt_ex_simps [simp, code]:
"RBT_Impl_rbt_ex P rbt.Empty ⟷ False"
"RBT_Impl_rbt_ex P (Branch c l k v r) ⟷ P k v ∨ RBT_Impl_rbt_ex P l ∨ RBT_Impl_rbt_ex P r"
by(auto simp add: RBT_Impl_rbt_ex_def)
subsection {* List fusion for RBTs *}
type_synonym ('a, 'b, 'c) rbt_generator_state = "('c × ('a, 'b) RBT_Impl.rbt) list × ('a, 'b) RBT_Impl.rbt"
fun rbt_has_next :: "('a, 'b, 'c) rbt_generator_state ⇒ bool"
where
"rbt_has_next ([], rbt.Empty) = False"
| "rbt_has_next _ = True"
fun rbt_keys_next :: "('a, 'b, 'a) rbt_generator_state ⇒ 'a × ('a, 'b, 'a) rbt_generator_state"
where
"rbt_keys_next ((k, t) # kts, rbt.Empty) = (k, kts, t)"
| "rbt_keys_next (kts, rbt.Branch c l k v r) = rbt_keys_next ((k, r) # kts, l)"
lemma rbt_generator_induct [case_names empty split shuffle]:
assumes "P ([], rbt.Empty)"
and "⋀k t kts. P (kts, t) ⟹ P ((k, t) # kts, rbt.Empty)"
and "⋀kts c l k v r. P ((f k v, r) # kts, l) ⟹ P (kts, Branch c l k v r)"
shows "P ktst"
using assms
apply(induction_schema)
apply pat_completeness
apply(relation "measure (λ(kts, t). size_list (λ(k, t). size_rbt (λ_. 1) (λ_. 1) t) kts + size_rbt (λ_. 1) (λ_. 1) t)")
apply simp_all
done
lemma terminates_rbt_keys_generator:
"terminates (rbt_has_next, rbt_keys_next)"
proof
fix ktst :: "('a × ('a, 'b) rbt) list × ('a, 'b) rbt"
show "ktst ∈ terminates_on (rbt_has_next, rbt_keys_next)"
by(induct ktst taking: "λk _. k" rule: rbt_generator_induct)(auto 4 3 intro: terminates_on.intros elim: terminates_on.cases)
qed
lift_definition rbt_keys_generator :: "('a, ('a, 'b, 'a) rbt_generator_state) generator"
is "(rbt_has_next, rbt_keys_next)"
by(rule terminates_rbt_keys_generator)
definition rbt_init :: "('a, 'b) rbt ⇒ ('a, 'b, 'c) rbt_generator_state"
where "rbt_init = Pair []"
lemma has_next_rbt_keys_generator [simp]:
"list.has_next rbt_keys_generator = rbt_has_next"
by transfer simp
lemma next_rbt_keys_generator [simp]:
"list.next rbt_keys_generator = rbt_keys_next"
by transfer simp
lemma unfoldr_rbt_keys_generator_aux:
"list.unfoldr rbt_keys_generator (kts, t) =
RBT_Impl.keys t @ concat (map (λ(k, t). k # RBT_Impl.keys t) kts)"
proof(induct "(kts, t)" arbitrary: kts t taking: "λk _. k" rule: rbt_generator_induct)
case empty thus ?case
by(simp add: list.unfoldr.simps)
next
case split thus ?case
by(subst list.unfoldr.simps) simp
next
case shuffle thus ?case
by(subst list.unfoldr.simps)(subst (asm) list.unfoldr.simps, simp)
qed
corollary unfoldr_rbt_keys_generator:
"list.unfoldr rbt_keys_generator (rbt_init t) = RBT_Impl.keys t"
by(simp add: unfoldr_rbt_keys_generator_aux rbt_init_def)
fun rbt_entries_next ::
"('a, 'b, 'a × 'b) rbt_generator_state ⇒ ('a × 'b) × ('a, 'b, 'a × 'b) rbt_generator_state"
where
"rbt_entries_next ((kv, t) # kts, rbt.Empty) = (kv, kts, t)"
| "rbt_entries_next (kts, rbt.Branch c l k v r) = rbt_entries_next (((k, v), r) # kts, l)"
lemma terminates_rbt_entries_generator:
"terminates (rbt_has_next, rbt_entries_next)"
proof(rule terminatesI)
fix ktst :: "('a, 'b, 'a × 'b) rbt_generator_state"
show "ktst ∈ terminates_on (rbt_has_next, rbt_entries_next)"
by(induct ktst taking: Pair rule: rbt_generator_induct)(auto 4 3 intro: terminates_on.intros elim: terminates_on.cases)
qed
lift_definition rbt_entries_generator :: "('a × 'b, ('a, 'b, 'a × 'b) rbt_generator_state) generator"
is "(rbt_has_next, rbt_entries_next)"
by(rule terminates_rbt_entries_generator)
lemma has_next_rbt_entries_generator [simp]:
"list.has_next rbt_entries_generator = rbt_has_next"
by transfer simp
lemma next_rbt_entries_generator [simp]:
"list.next rbt_entries_generator = rbt_entries_next"
by transfer simp
lemma unfoldr_rbt_entries_generator_aux:
"list.unfoldr rbt_entries_generator (kts, t) =
RBT_Impl.entries t @ concat (map (λ(k, t). k # RBT_Impl.entries t) kts)"
proof(induct "(kts, t)" arbitrary: kts t taking: Pair rule: rbt_generator_induct)
case empty thus ?case
by(simp add: list.unfoldr.simps)
next
case split thus ?case
by(subst list.unfoldr.simps) simp
next
case shuffle thus ?case
by(subst list.unfoldr.simps)(subst (asm) list.unfoldr.simps, simp)
qed
corollary unfoldr_rbt_entries_generator:
"list.unfoldr rbt_entries_generator (rbt_init t) = RBT_Impl.entries t"
by(simp add: unfoldr_rbt_entries_generator_aux rbt_init_def)
end