(*
Author:  Christian Sternagel <c.sternagel@gmail.com> (2011-2015)
Author:  René Thiemann <rene.thiemann@uibk.ac.at> (2011-2015)
License: LGPL (see file COPYING.LESSER)
*)

section ‹Mappings associating keys with lists of values›

(* Abstract multimaps are partial functions of type 'k ⇒ 'v list option; the
   executable implementation stores them in the red-black-tree map rm. *)
theory Multimap
imports
  QTRS.Util
  "Transitive-Closure.RBT_Map_Set_Extension"
begin

section ‹Abstract Multimaps›

text ‹A multimap is disjoint if the value lists associated with different keys
do not have any elements in common.›
(* disjoint m: whenever two distinct keys k and l both have an entry, the sets
   of their value lists are disjoint. Keys mapped to None impose no constraint. *)
definition disjoint :: "('k ⇒ 'v list option) ⇒ bool" where
  "disjoint m ≡ ∀k l vs ws. k ≠ l ∧ m k = Some vs ∧ m l = Some ws ⟶ (set vs ∩ set ws) = {}"

(* Shrinking the list stored under an existing key of a disjoint multimap:
   if m a = Some b and the new list c satisfies set c ⊆ set b, then the range of
   m(a ↦ c) consists of c plus old range elements different from b.
   Disjointness is what guarantees that lists stored under other keys differ
   from b (unless b is empty, in which case b = c). *)
lemma ran_map_upd_Some':
  assumes "m a = Some b" and "disjoint m" and "set c ⊆ set b"
  shows "ran (m(a ↦ c)) ⊆ insert c (ran m - {b})"
proof
  fix x assume "x ∈ ran (m(a ↦ c))"
  then obtain k where upd: "(m(a ↦ c)) k = Some x" by (auto simp: ran_def)
  show "x ∈ insert c (ran m - {b})"
  proof (cases "a = k")
    case True
    (* the updated key yields exactly the new list c *)
    from assms and upd have x: "x = c" by (simp add: True)
    thus ?thesis by simp
  next
    case False
    (* any other key still maps to its old list x ∈ ran m *)
    hence m_upd: "(m(a ↦ c)) k = m k" by simp
    from upd[unfolded m_upd] have "m k = Some x" by simp
    hence "x ∈ ran m" by (auto simp: ran_def)
    (* by disjointness, x shares no element with b *)
    from ‹m k = Some x› and assms(2)[unfolded disjoint_def] and False and assms(1)
      have 1: "(set b ∩ set x) = {}" by blast
    show ?thesis
    proof (cases "set b = {}")
      (* b = [] forces c = [], so insert c (ran m - {b}) is just ran m *)
      case True with assms(3) have 2: "b = c" by simp
      from assms(1) have "b ∈ ran m" by (auto simp: ran_def)
      hence "insert c (ran m - {b}) = ran m" unfolding 2 by auto
      thus ?thesis using ‹x ∈ ran m› by simp
    next
      (* b non-empty and disjoint from x, hence b ≠ x *)
      case False
      with 1 have "b ≠ x" by auto
      with ‹x ∈ ran m› show ?thesis by simp
    qed
  qed
qed

(* Same setting as ran_map_upd_Some' (replace v by w under key k, where
   set w ⊆ set v and m is disjoint): the union of all value sets of the updated
   map equals the union over insert w (ran m - {v}).
   "⊆" follows from ran_map_upd_Some'. "⊇" holds because w is the new entry for
   k, and any other vs ∈ ran m with vs ≠ v is stored under some key l ≠ k. *)
lemma UNION_ran_disjoint:
  assumes "m k = Some v" and "disjoint m" and "set w ⊆ set v"
  shows "UNION (ran (m(k ↦ w))) set = UNION (insert w (ran m - {v})) set"
proof
  show "UNION (ran (m(k ↦ w))) set ⊆ UNION (insert w (ran m - {v})) set"
  proof
    fix x assume "x ∈ UNION (ran (m(k ↦ w))) set"
    then obtain vs where vs: "vs ∈ ran (m(k ↦ w))" and x: "x ∈ set vs" by auto
    from ran_map_upd_Some'[of m k v, OF assms] and vs
      have "vs ∈ insert w (ran m - {v})" by blast
    with x show "x ∈ UNION (insert w (ran m - {v})) set" by blast
  qed
next
  show "UNION (insert w (ran m - {v})) set ⊆ UNION (ran (m(k ↦ w))) set"
  proof
    fix x assume "x ∈ UNION (insert w (ran m - {v})) set"
    then obtain vs where "vs ∈ insert w (ran m - {v})" and x: "x ∈ set vs" by auto
    hence "w = vs ∨ vs ∈ ran m - {v}" by auto
    thus "x ∈ UNION (ran (m(k ↦ w))) set"
    proof
      assume w: "w = vs"
      show ?thesis unfolding w using x by (auto simp: ran_def)
    next
      assume "vs ∈ ran m - {v}"
      then obtain l where ml: "m l = Some vs" and ne: "vs ≠ v" by (auto simp: ran_def)
      show ?thesis
      proof (cases "k = l")
        (* impossible: m k = Some v but vs ≠ v *)
        case True
        from assms(1) and ml and ne show ?thesis by (simp add: True)
      next
        (* key l is untouched by the update, so vs is still in the range *)
        case False
        with ml have "vs ∈ ran (m(k ↦ w))" by (auto simp: ran_def)
        with x show ?thesis by blast
      qed
    qed
  qed
qed

text ‹Multimaps in which every value is stored under the key that the key function
assigns to it correspond, in a sense, to injective mappings.›
(* mmap_inj key m: every value v in the list stored under key k satisfies
   key v = Some k. Consequently, no value with key v = None is stored at all. *)
definition mmap_inj :: "('v ⇒ 'k option) ⇒ ('k ⇒ 'v list option) ⇒ bool" where
  "mmap_inj key m ≡ (∀k vs. m k = Some vs ⟶ (∀v∈set vs. key v = Some k))"

(* The empty map is trivially key-injective. *)
lemma mmap_inj_empty[simp]: "mmap_inj key Map.empty" by (simp add: mmap_inj_def)

(* Storing under key k a list vs whose elements all have key k preserves
   mmap_inj. Other keys are unaffected by the update. *)
lemma mmap_inj_map_upd:
  assumes "⋀v. v ∈ set vs ⟹ key v = Some k"
    and "mmap_inj key m"
  shows "mmap_inj key (m(k ↦ vs))"
  unfolding mmap_inj_def
proof (intro allI impI)
  fix l ws assume 1: "(m(k ↦ vs)) l = Some ws"
  show "∀v ∈ set ws. key v = Some l"
  proof (cases "k = l")
    (* the updated key: ws is the new list vs *)
    assume 2: "k = l"
    with 1 have "vs = ws" by simp
    from assms(1)[unfolded this] show ?thesis by (simp add: 2)
  next
    (* any other key: ws is the old list, covered by mmap_inj key m *)
    assume "k ≠ l"
    hence 2: "(m(k ↦ vs)) l = m l" by simp
    from 1 have "m l = Some ws" unfolding 2 .
    with assms(2) show ?thesis by (simp add: mmap_inj_def)
  qed
qed


section ‹Multimap implementation by Red-Black Trees›
(* The abstraction of the empty red-black-tree map is key-injective. *)
lemma mmap_inj_rm_empty[simp]: "mmap_inj key (rm.α (rm.empty ()))"
  by (simp add: rm.correct mmap_inj_def)

text ‹All values contained in a multimap.›
(* Concatenation of all value lists, in the order given by rm.to_list m.
   The lemmas of this theory only characterise set (values m) (see values_ran),
   so order and multiplicity are not part of the specification.
   The constant name is quoted because "values" is also an Isabelle keyword. *)
definition "values" :: "('k::linorder, 'v list) rm ⇒ 'v list" where
  "values m ≡ concat (map snd (rm.to_list m))"

(* The empty multimap has no values. *)
lemma values_rm_empty[simp]: "values (rm.empty ()) = []" by (simp add: values_def)

(* Connects the concrete values function to the abstract map: the set of
   values is the union of all lists in the range of rm.α m. *)
lemma values_ran: "set (values m) = (⋃rs∈ran (rm.α m). set rs)"
  using ran_distinct[of "rm.to_list m"]
  by (auto simp: rm.correct values_def)

(* Insert v under the key computed by key v.
   - key v = None: m is returned unchanged (v is not stored).
   - key k not yet present: add the entry k ↦ [v] via rm.update_dj
     (presumably the update variant for keys known to be absent; it is only
     used after rm.lookup returned None).
   - key k present with list vs: store List.insert v vs, which leaves vs
     unchanged if v already occurs in it. *)
definition
  insert_value ::
    "('v ⇒ 'k::linorder option) ⇒ 'v ⇒ ('k, 'v list) rm ⇒ ('k, 'v list) rm"
where
  "insert_value key v m ≡ case key v of
    None ⇒ m
  | Some k ⇒ (case rm.lookup k m of
      None ⇒ rm.update_dj k [v] m
    | Some vs ⇒ rm.update k (List.insert v vs) m)"

(* A lookup fails exactly when k is outside the domain of the abstract map. *)
lemma rm_lookup_dom: "rm.lookup k m = None ⟷ k ∉ dom (rm.α m)"
  by (simp add: domIff rm.correct)

(* Values without a key are ignored by insert_value. *)
lemma insert_value_None[simp]:
  "key v = None ⟹ insert_value key v m = m"
  by (simp add: insert_value_def)

(* Abstract effect of insert_value when the key k of v has no entry yet:
   a new entry k ↦ [v] is added. *)
lemma rm_α_insert_value_Some[simp]:
  assumes "key v = Some k" and "rm.lookup k m = None"
  shows "rm.α (insert_value key v m) = (rm.α m)(k ↦ [v])"
  unfolding insert_value_def assms option.simps
  by (metis assms(2) rm.invar rm.update_dj_correct(1) rm_lookup_dom)

(* Abstract effect of insert_value when the key k of v already holds vs:
   the entry becomes List.insert v vs. *)
lemma rm_α_insert_value_Some'[simp]:
  assumes "key v = Some k" and "rm.lookup k m = Some vs"
  shows "rm.α (insert_value key v m) = (rm.α m)(k ↦ List.insert v vs)"
  using rm.update_correct and assms by (simp add: insert_value_def)

(* insert_value preserves mmap_inj. Proof by case split on key v and on whether
   its key already has an entry. In the latter case every element of
   List.insert v vs has key k: v by assumption, and the old elements by
   mmap_inj. *)
lemma mmap_inj_insert_value:
  assumes "mmap_inj key (rm.α m)"
  shows "mmap_inj key (rm.α (insert_value key v m))"
proof (cases "key v")
  case None with assms show ?thesis by simp
next
  case (Some k)
  note Some' = this
  show ?thesis
  proof (cases "rm.lookup k m")
    case None
    with Some' and mmap_inj_map_upd[OF _ assms]
      show ?thesis by force
  next
    case (Some vs)
    with assms have "∀v∈set (List.insert v vs). key v = Some k"
      using Some' by (simp add: mmap_inj_def rm.lookup_correct)
    with Some Some' show ?thesis by (force intro: mmap_inj_map_upd[OF _ assms])
  qed
qed

(* Every list stored in the multimap is contained in its values. *)
lemma Some_in_values:
  assumes "rm.lookup k m = Some vs"
  shows "set vs ⊆ set (values m)"
  using assms unfolding values_ran ran_def
  by (auto simp: rm.lookup_correct)

(* If v has a key, insert_value adds exactly v to the set of values.
   Case None: a fresh singleton entry [v] is created.
   Case Some vs: vs is replaced by List.insert v vs, whose old elements are
   already values by Some_in_values. *)
lemma values_insert_value_Some[simp]:
  assumes "key v = Some k"
  shows "set (values (insert_value key v m)) = insert v (set (values m))"
proof (cases "rm.lookup k m")
  case None
  show ?thesis
    unfolding values_ran
    unfolding rm_α_insert_value_Some[of key v k, OF assms None]
    using ran_map_upd[of "rm.α m" k] None
    by (simp add: rm.correct)
next
  case (Some vs)
  show ?thesis
    unfolding values_ran
    unfolding rm_α_insert_value_Some'[of key v k, OF assms Some]
    using subset_UNION_ran[of "rm.α m" k,
      OF _ set_subset_insertI] Some
    using Some_in_values[OF Some, simplified values_ran]
    by (auto simp: rm.correct)
qed

(* Insert every element of a list with insert_value. The tail is inserted
   first and the head last, so the list is processed right to left. *)
fun
  insert_values ::
    "('v ⇒ 'k::linorder option) ⇒ 'v list ⇒ ('k, 'v list) rm ⇒ ('k, 'v list) rm"
where
  "insert_values _ [] m = m"
| "insert_values key (v#vs) m = insert_value key v (insert_values key vs m)"

(* insert_values preserves mmap_inj (induction on the list using
   mmap_inj_insert_value). *)
lemma mmap_inj_insert_values:
  "mmap_inj key (rm.α m) ⟹ mmap_inj key (rm.α (insert_values key vs m))"
  by (induct vs arbitrary: m) (simp_all add: mmap_inj_insert_value)

(* insert_values adds exactly those given values that have a key; values with
   key v = None are dropped. *)
lemma values_insert_values[simp]:
  "set (values (insert_values key vs m)) = set [v←vs. key v ≠ None] ∪ set (values m)"
proof (induct vs)
  case Nil show ?case by simp
next
  case (Cons v vs) thus ?case by (cases "key v") simp_all
qed

(* could be improved by using remove1: insert_value never adds a value twice to the same list *)
(* Remove v from the list stored under its key.
   - key v = None, or its key k has no entry: m is returned unchanged.
   - otherwise the entry for k becomes removeAll v vs, which drops every
     occurrence of v. The key k stays in the map even if its list becomes
     empty. *)
definition
  delete_value ::
    "('v ⇒ 'k::linorder option) ⇒ 'v ⇒ ('k, 'v list) rm ⇒ ('k, 'v list) rm"
where
  "delete_value key v m ≡ case key v of
    None ⇒ m
  | Some k ⇒ (case rm.lookup k m of
      None ⇒ m
    | Some vs ⇒ rm.update k (removeAll v vs) m)"

(* Values without a key are ignored by delete_value. *)
lemma delete_value_None[simp]:
  "key v = None ⟹ delete_value key v m = m"
  by (simp add: delete_value_def)

(* delete_value is the identity if the key of v has no entry. *)
lemma delete_value_None'[simp]:
  "key v = Some k ⟹ rm.lookup k m = None ⟹ delete_value key v m = m"
  by (simp add: delete_value_def)

(* Abstract effect of delete_value when the key k of v holds vs:
   the entry becomes removeAll v vs. *)
lemma rm_α_delete_value_Some[simp]:
  assumes "key v = Some k" and "rm.lookup k m = Some vs"
  shows "rm.α (delete_value key v m) = (rm.α m)(k ↦ removeAll v vs)"
  using rm.update_correct and assms by (simp add: delete_value_def)

(* delete_value preserves mmap_inj. Removing elements from a list keeps the
   remaining elements under their (correct) key. *)
lemma mmap_inj_delete_value:
  assumes "mmap_inj key (rm.α m)"
  shows "mmap_inj key (rm.α (delete_value key v m))"
proof (cases "key v")
  case None
  from assms show ?thesis unfolding delete_value_None[of key v, OF None] .
next
  case (Some k)
  note Some' = this
  show ?thesis
  proof (cases "rm.lookup k m")
    case None
    show ?thesis
      unfolding delete_value_None'[of key v k, OF Some' None]
      by fact
  next
    case (Some vs)
    with assms have 1: "∀v∈set (removeAll v vs). key v = Some k"
      by (simp add: mmap_inj_def rm.lookup_correct)
    thus ?thesis
      unfolding rm_α_delete_value_Some[of key v k, OF Some' Some]
      using mmap_inj_map_upd[OF _ assms]
      by force
  qed
qed

(* Suppose v occurs in no range list other than vs, and every element of vs
   already belongs to the union of all value sets. Then replacing vs by
   removeAll v vs removes exactly v from that union. *)
lemma UNION_insert_removeAll:
  assumes "∀ws∈ran m. ws ≠ vs ⟶ v ∉ set ws"
    and "set vs ⊆ UNION (ran m) set"
  shows "UNION (insert (removeAll v vs) (ran m - {vs})) set = UNION (ran m) set - {v}"
    (is "?A = ?B")
proof
  show "?B ⊆ ?A"
  proof
    fix x assume "x ∈ ?B"
    hence "x ∈ UNION (ran m) set" and "x ≠ v" by auto
    thus "x ∈ ?A" by auto
  qed
next
  show "?A ⊆ ?B"
  proof
    fix x assume "x ∈ ?A"
    then obtain ws where "ws ∈ insert (removeAll v vs) (ran m - {vs})"
      and x: "x ∈ set ws" by blast
    hence "ws = removeAll v vs ∨ ws ∈ ran m - {vs}" by auto
    thus "x ∈ ?B"
    proof
      (* x comes from the shrunk list: x ≠ v, and x ∈ set vs ⊆ union by assms(2) *)
      assume 1: "ws = removeAll v vs"
      with x have "x ≠ v" by auto
      with x[unfolded 1] show ?thesis using assms(2) by auto
    next
      (* x comes from another list, which does not contain v by assms(1) *)
      assume "ws ∈ ran m - {vs}"
      with assms have "ws ∈ ran m" and "v ∉ set ws" by auto
      with x have "v ≠ x" by auto
      thus ?thesis using ‹ws ∈ ran m› and x by auto
    qed
  qed
qed

(* Every list stored in m is contained in the union of all value sets. *)
lemma Some_ran:
  "m k = Some vs ⟹ set vs ⊆ UNION (ran m) set"
  unfolding ran_def by auto

(* Under mmap_inj, if key v = Some k and k holds vs, then v occurs in no other
   list of the range. Any such list ws ≠ vs is stored under a key different
   from k, and v would need to have that key. *)
lemma mmap_inj_ran:
  assumes "key v = Some k" and "mmap_inj key m" and "m k = Some vs"
  shows "∀ws∈ran m. ws ≠ vs ⟶ v ∉ set ws"
  using assms unfolding mmap_inj_def ran_def by auto

(* A key-injective multimap is disjoint. A value shared by the lists under
   distinct keys k and l would need key v = Some k and key v = Some l, which is
   impossible. *)
lemma mmap_inj_imp_disjoint:
  assumes "mmap_inj key m"
  shows "disjoint m"
proof (rule ccontr)
  assume "¬ disjoint m"
  then obtain k l vs ws where kl: "k ≠ l" and mk: "m k = Some vs"
    and ml: "m l = Some ws" and ne: "set vs ∩ set ws ≠ {}"
    unfolding disjoint_def by blast
  from ne obtain v where "v ∈ set vs" and "v ∈ set ws" by auto
  from ‹v ∈ set vs› and mk and assms have "key v = Some k"
    by (auto simp: mmap_inj_def)
  moreover from ‹v ∈ set ws› and ml and assms have "key v = Some l"
    by (auto simp: mmap_inj_def)
  ultimately show False using ‹k ≠ l› by simp
qed

(* For a key-injective multimap in which the key of v has an entry,
   delete_value removes exactly v from the set of values. The proof combines
   two lemmas:
   - UNION_ran_disjoint, since the entry shrinks: set (removeAll v vs) ⊆ set vs.
   - UNION_insert_removeAll, since v occurs only in vs (mmap_inj_ran). *)
lemma values_delete_value_Some[simp]:
  assumes "key v = Some k" and "rm.lookup k m = Some vs" and "mmap_inj key (rm.α m)"
  shows "set (values (delete_value key v m)) = set (values m) - {v}"
  unfolding values_ran
  unfolding rm_α_delete_value_Some[of key v k, OF assms(1-2)]
  using UNION_ran_disjoint[of "rm.α m" k,
    OF _ mmap_inj_imp_disjoint[OF assms(3)] set_removeAll_subset]
    assms(2)
    UNION_insert_removeAll[OF _ Some_ran[of "rm.α m"]]
    mmap_inj_ran[of key v k, OF assms(1) assms(3), of vs] 
  by (simp add: rm.correct)

(* Delete every element of a list with delete_value. The tail is processed
   first and the head last, so the list is processed right to left. *)
fun
  delete_values ::
    "('v ⇒ 'k::linorder option) ⇒ 'v list ⇒ ('k, 'v list) rm ⇒ ('k , 'v list) rm"
where
  "delete_values _ [] m = m"
| "delete_values key (v#vs) m = delete_value key v (delete_values key vs m)"

(* Under mmap_inj, a value whose key k has no entry is not among the values.
   Any list containing v would have to be stored under k.
   The presumed fact is discharged at "qed simp" from the ccontr assumption. *)
lemma not_in_values_Some:
  assumes key: "key v = Some k"
    and notin: "rm.lookup k m = None"
    and mmap_inj: "mmap_inj key (rm.α m)"
  shows "v ∉ set (values m)"
proof (rule ccontr)
  presume "v ∈ set (values m)"
  from this[unfolded values_ran]
    obtain vs where "vs ∈ ran (rm.α m)" and "v ∈ set vs" by blast
  then obtain l where "rm.α m l = Some vs" by (auto simp: ran_def)
  with mmap_inj and ‹v ∈ set vs› have "key v = Some l" by (auto simp: mmap_inj_def)
  with key have "k = l" by simp
  with notin[unfolded this]
    and ‹rm.α m l = Some vs› show False by (simp add: rm.correct)
qed simp

(* Under mmap_inj, values without a key are never stored. A stored v would have
   key v = Some l for the key l it is stored under. *)
lemma not_in_values_None:
  assumes key: "key v = None"
    and mmap_inj: "mmap_inj key (rm.α m)"
  shows "v ∉ set (values m)"
proof (rule ccontr)
  presume "v ∈ set (values m)"
  from this[unfolded values_ran]
    obtain vs where "vs ∈ ran (rm.α m)" and "v ∈ set vs" by blast
  then obtain l where "rm.α m l = Some vs" by (auto simp: ran_def)
  with mmap_inj and ‹v ∈ set vs› have "key v = Some l" by (auto simp: mmap_inj_def)
  with key show False by simp
qed simp

(* delete_values preserves mmap_inj (induction on the list using
   mmap_inj_delete_value). *)
lemma mmap_inj_delete_values:
  "mmap_inj key (rm.α m) ⟹ mmap_inj key (rm.α (delete_values key vs m))"
  by (induct vs arbitrary: m) (auto simp: mmap_inj_delete_value)

(* For a key-injective multimap, delete_values removes exactly the given values
   from the set of values. Proof by induction on vs. For the head v there are
   three cases:
   - v has no key: it is not a value (not_in_values_None).
   - v's key has no entry: it is not a value (not_in_values_Some).
   - otherwise values_delete_value_Some applies. *)
lemma values_delete_values[simp]:
  assumes "mmap_inj key (rm.α m)"
  shows "set (values (delete_values key vs m)) = set (values m) - set vs"
using assms
proof (induct vs arbitrary: m)
  case Nil show ?case by simp
next
  case (Cons v vs)
  show ?case
  proof (cases "key v")
    case None
    from not_in_values_None[of key v, OF None Cons(2)]
    show ?thesis
      using Cons by (simp add: delete_value_None[of key v, OF None])
  next
    case (Some k)
    note Some' = this
    show ?thesis
    proof (cases "rm.lookup k (delete_values key vs m)")
      case None
      from not_in_values_Some[of key v, OF Some None
        mmap_inj_delete_values[OF Cons(2)]]
        show ?thesis
        using Cons
        using delete_value_None'[of key v, OF Some' None]
        by auto
    next
      case (Some ws)
      show ?thesis
        using Cons
        using values_delete_value_Some[of key v k,
          OF Some' Some mmap_inj_delete_values[OF Cons(2)]]
        by auto
    qed
  qed
qed

(* Step function of intersect_values: insert v into the accumulator m' iff v has
   a key k, m has an entry for k, and v occurs in that entry's list. In every
   other case m' is returned unchanged. *)
definition
  aux ::
    "('v ⇒ 'k::linorder option) ⇒ ('k, 'v list) rm ⇒ 'v ⇒ ('k, 'v list) rm
    ⇒ ('k, 'v list) rm"
where
  "aux key m v m' ≡ case key v of
    Some k ⇒ (case rm.lookup k m of
      Some ws ⇒ (if v ∈ set ws then insert_value key v m' else m')
    | None ⇒ m')
  | None ⇒ m'"

(* Code equation for aux with insert_value inlined. Via aux_def, the generated
   code would evaluate key v twice: once in aux and once more inside
   insert_value. This equation reuses the key k that is already known. *)
lemma [code]: "aux key m v m' = (case key v of
    Some k ⇒ (case rm.lookup k m of
      Some ws ⇒ (if v ∈ set ws
        then (case rm.lookup k m' of
          Some vs ⇒ rm.update k (List.insert v vs) m'
        | None ⇒ rm.update_dj k [v] m')
        else m')
      | None ⇒ m')
    | None ⇒ m')"
  unfolding aux_def insert_value_def
  by (simp split: option.splits)

(* Under mmap_inj, if x's key k holds vs and x ∉ set vs, then x is not a value
   at all. Any list containing x would have to be the one stored under k. *)
lemma not_in_values:
  assumes inj: "mmap_inj key (rm.α m)"
    and "key x = Some k"
    and "rm.α m k = Some vs"
    and "x ∉ set vs"
  shows "x ∉ set (values m)"
proof (rule ccontr)
  presume "x ∈ set (values m)"
  then obtain ws l where 1: "x ∈ set ws" and "ws ∈ ran (rm.α m)"
    and 2: "rm.α m l = Some ws"
    unfolding values_ran ran_def by auto
  with inj[unfolded mmap_inj_def] 1 2 have "key x = Some l" by simp
  with assms have "k = l" by simp
  thus False using assms 1 2 by simp
qed simp

(* For a key-injective m, one aux step adds v to the values of m' exactly when
   v is a value of m. *)
lemma set_aux:
  assumes inj: "mmap_inj key (rm.α m)"
  shows "set (values (aux key m v m')) = (set (values m) ∩ {v}) ∪ set (values m')"
proof (cases "key v")
  case None thus ?thesis by (auto simp: aux_def not_in_values_None[OF None inj])
next
  case (Some k)
  note key = this
  show ?thesis
  proof (cases "rm.lookup k m")
    case None thus ?thesis by (simp add: aux_def key not_in_values_Some[OF key None inj])
  next
    case (Some vs)
    show ?thesis
    proof (cases "v ∈ set vs")
      case True thus ?thesis using key Some Some_in_values[OF Some]
        by (simp add: aux_def) blast
    next
      (* v is not in the list under its own key, hence not a value of m at all *)
      case False
      from not_in_values[OF inj key] False Some
      have "v ∉ set (values m)" by (simp add: rm.correct)
      thus ?thesis unfolding aux_def key Some option.cases if_not_P[OF False] by simp
    qed
  qed
qed

(* Build a fresh multimap containing those elements of vs that aux accepts,
   i.e. that are stored in m under their key. It folds aux key m over vs,
   starting from the empty map. *)
definition
  intersect_values ::
    "('v ⇒ 'k::linorder option) ⇒ 'v list ⇒ ('k, 'v list) rm ⇒ ('k, 'v list) rm"
where
  "intersect_values key vs m ≡ foldr (aux key m) vs (rm.empty ())"

(* Generic invariant rule for foldr. Suppose all elements of xs satisfy Q, the
   start value satisfies P, and each step f x maps P-states to P-states for
   Q-elements x. Then P holds of foldr f xs s. *)
lemma foldr_invariant:
  assumes "⋀x. x ∈ set xs ⟹ Q x"
    and "P s"
    and "⋀x s. ⟦Q x; P s⟧ ⟹ P (f x s)"
  shows" P (foldr f xs s)"
  using assms by (induct xs) simp_all

(* The result of intersect_values is key-injective (under the stated assumption
   on m). Proved via foldr_invariant, with mmap_inj_insert_value handling the
   aux steps that insert. *)
lemma mmap_inj_intersect_values:
  "mmap_inj key (rm.α m) ⟹ mmap_inj key (rm.α (intersect_values key vs m))"
  unfolding intersect_values_def
  by (frule foldr_invariant)
     (auto simp: aux_def split: option.split intro: mmap_inj_insert_value[of key])

(* Without any assumption on m, every value of intersect_values key vs m is both
   a value of m and an element of vs. Proved via foldr_invariant with
   Q v = (v ∈ set vs). *)
lemma values_intersect_values_subset:
  "set (values (intersect_values key vs m)) ⊆ set (values m) ∩ set vs"
  unfolding intersect_values_def
  by (rule foldr_invariant[of _ "λv. v ∈ set vs"])
     (insert Some_in_values, auto simp: aux_def split: option.split)

(* For a key-injective m, intersect_values computes exactly the intersection of
   the values of m with vs. Proof by induction on vs, using set_aux per step. *)
lemma values_intersect_values[simp]:
  assumes "mmap_inj key (rm.α m)"
  shows "set (values (intersect_values key vs m)) = set (values m) ∩ set vs"
  using assms by (induct vs arbitrary: m) (auto simp: intersect_values_def set_aux)

(* aux is only a helper for intersect_values; hide its generic name so it does
   not clash with other constants in theories importing this one. *)
hide_const aux

end