(* Theory Nat_Bijection *)
(*  Title:      HOL/Library/Nat_Bijection.thy
    Author:     Brian Huffman
    Author:     Florian Haftmann
    Author:     Stefan Richter
    Author:     Tobias Nipkow
    Author:     Alexander Krauss
*)

section ‹Bijections between natural numbers and other types›

theory Nat_Bijection
  imports Main
begin

subsection ‹Type @{typ "nat × nat"}›

text ‹Triangle numbers: 0, 1, 3, 6, 10, 15, ...›

(* triangle n is the n-th triangle number n * (n + 1) div 2,
   giving the sequence 0, 1, 3, 6, 10, 15, ... *)
definition triangle :: "nat ⇒ nat"
  where "triangle n = (n * Suc n) div 2"

(* Base case of the triangle recurrence, registered as a simp rule. *)
lemma triangle_0 [simp]: "triangle 0 = 0"
  by (simp add: triangle_def)

(* Step case: the next triangle number adds Suc n.  Registered as a simp
   rule so later proofs (e.g. prod_decode_triangle_add) can work with
   triangle inductively instead of through triangle_def. *)
lemma triangle_Suc [simp]: "triangle (Suc n) = triangle n + Suc n"
  by (simp add: triangle_def)

(* Cantor-style pairing.  The pair (m, n) lies on diagonal m + n.  Its code
   is triangle (m + n), the number of codes used by the earlier diagonals
   0 .. m + n - 1, plus its offset m along its own diagonal. *)
definition prod_encode :: "nat × nat ⇒ nat"
  where "prod_encode = (λ(m, n). triangle (m + n) + m)"

text ‹In this auxiliary function, @{term "triangle k + m"} is an invariant.›

(* prod_decode_aux k m decodes the number triangle k + m.
   - If the offset m fits on diagonal k (m ≤ k), the pair is (m, k - m).
   - Otherwise it moves on to diagonal Suc k, subtracting the Suc k pairs of
     diagonal k from the offset.
   Termination: in the recursive call m > k, so m - Suc k < m. *)
fun prod_decode_aux :: "nat ⇒ nat ⇒ nat × nat"
  where "prod_decode_aux k m =
    (if m ≤ k then (m, k - m) else prod_decode_aux (Suc k) (m - Suc k))"

(* The defining equation has prod_decode_aux on both sides and applies to any
   arguments, so it is removed from the simpset, where it could otherwise make
   simp loop.  Proofs unfold it explicitly with subst prod_decode_aux.simps. *)
declare prod_decode_aux.simps [simp del]

(* Decoding starts on diagonal 0 with the whole input as offset.  This is the
   invariant triangle k + m with k = 0, since triangle 0 = 0. *)
definition prod_decode :: "nat ⇒ nat × nat"
  where "prod_decode = prod_decode_aux 0"

(* Encoding the output of prod_decode_aux k m gives back its invariant value
   triangle k + m.  The proof follows the recursion of prod_decode_aux.  In
   each case, one unfolding of the defining equation followed by simp with
   prod_encode_def (and the simp rule triangle_Suc in the recursive case)
   closes the goal. *)
lemma prod_encode_prod_decode_aux: "prod_encode (prod_decode_aux k m) = triangle k + m"
  apply (induct k m rule: prod_decode_aux.induct)
  apply (subst prod_decode_aux.simps)
  apply (simp add: prod_encode_def)
  done

(* prod_encode is a left inverse of prod_decode.  This is the case k = 0 of
   prod_encode_prod_decode_aux. *)
lemma prod_decode_inverse [simp]: "prod_encode (prod_decode n) = n"
  by (simp add: prod_decode_def prod_encode_prod_decode_aux)

(* Decoding triangle k + m from diagonal 0 gives the same result as starting
   prod_decode_aux directly on diagonal k.  The proof is by induction on k,
   with m generalized.  In the step case, triangle (Suc k) + m is rewritten
   to triangle k + (Suc k + m) so the induction hypothesis applies.  Then one
   step of prod_decode_aux is unfolded and simp finishes. *)
lemma prod_decode_triangle_add: "prod_decode (triangle k + m) = prod_decode_aux k m"
  apply (induct k arbitrary: m)
   apply (simp add: prod_decode_def)
  apply (simp only: triangle_Suc add.assoc)
  apply (subst prod_decode_aux.simps)
  apply simp
  done

(* prod_decode is a left inverse of prod_encode.  After unfolding
   prod_encode_def and splitting the pair into (m, n), the code is
   triangle (m + n) + m.  prod_decode_triangle_add turns its decoding into
   prod_decode_aux (m + n) m.  Unfolding that once takes the m ≤ m + n branch
   and yields (m, n). *)
lemma prod_encode_inverse [simp]: "prod_decode (prod_encode x) = x"
  unfolding prod_encode_def
  apply (induct x)
  apply (simp add: prod_decode_triangle_add)
  apply (subst prod_decode_aux.simps)
  apply simp
  done

(* prod_encode and prod_decode are mutually inverse (prod_encode_inverse,
   prod_decode_inverse).  Hence each is injective on every set A, surjective,
   and therefore bijective. *)

(* Injectivity from a left inverse, via inj_on_inverseI. *)
lemma inj_prod_encode: "inj_on prod_encode A"
  by (rule inj_on_inverseI) (rule prod_encode_inverse)

lemma inj_prod_decode: "inj_on prod_decode A"
  by (rule inj_on_inverseI) (rule prod_decode_inverse)

(* Surjectivity from a right inverse, via surjI. *)
lemma surj_prod_encode: "surj prod_encode"
  by (rule surjI) (rule prod_decode_inverse)

lemma surj_prod_decode: "surj prod_decode"
  by (rule surjI) (rule prod_encode_inverse)

lemma bij_prod_encode: "bij prod_encode"
  by (rule bijI [OF inj_prod_encode surj_prod_encode])

lemma bij_prod_decode: "bij prod_decode"
  by (rule bijI [OF inj_prod_decode surj_prod_decode])

(* Two pairs have the same code iff they are equal. *)
lemma prod_encode_eq: "prod_encode x = prod_encode y ⟷ x = y"
  by (rule inj_prod_encode [THEN inj_eq])

(* Two numbers decode to the same pair iff they are equal. *)
lemma prod_decode_eq: "prod_decode x = prod_decode y ⟷ x = y"
  by (rule inj_prod_decode [THEN inj_eq])


text ‹Ordering properties›

(* Each component of a pair is bounded by the pair's code.  The first
   component is the offset term of prod_encode, so unfolding suffices. *)
lemma le_prod_encode_1: "a ≤ prod_encode (a, b)"
  by (simp add: prod_encode_def)

(* The second component is bounded by induction on b, presumably with the
   simp rule triangle_Suc handling the step.  The termination proof of
   list_decode relies on this lemma. *)
lemma le_prod_encode_2: "b ≤ prod_encode (a, b)"
  by (induct b) (simp_all add: prod_encode_def)


subsection ‹Type @{typ "nat + nat"}›

(* Interleaves the two copies of nat: Inl a goes to the even number 2 * a,
   and Inr b goes to the odd number Suc (2 * b). *)
definition sum_encode :: "nat + nat ⇒ nat"
  where "sum_encode x = (case x of Inl a ⇒ 2 * a | Inr b ⇒ Suc (2 * b))"

(* Inverse of sum_encode: the parity of n picks the injection, and n div 2
   recovers the payload. *)
definition sum_decode :: "nat ⇒ nat + nat"
  where "sum_decode n = (if even n then Inl (n div 2) else Inr (n div 2))"

(* Case analysis on the injection (Inl / Inr), then unfolding both
   definitions. *)
lemma sum_encode_inverse [simp]: "sum_decode (sum_encode x) = x"
  by (induct x) (simp_all add: sum_decode_def sum_encode_def)

(* even_two_times_div_two supplies 2 * (n div 2) = n in the even case.  The
   odd case is presumably closed by simp's standard div/mod rules. *)
lemma sum_decode_inverse [simp]: "sum_encode (sum_decode n) = n"
  by (simp add: even_two_times_div_two sum_decode_def sum_encode_def)

(* sum_encode and sum_decode are mutually inverse (sum_encode_inverse,
   sum_decode_inverse).  Hence each is injective on every set A, surjective,
   and therefore bijective. *)

(* Injectivity from a left inverse, via inj_on_inverseI. *)
lemma inj_sum_encode: "inj_on sum_encode A"
  by (rule inj_on_inverseI) (rule sum_encode_inverse)

lemma inj_sum_decode: "inj_on sum_decode A"
  by (rule inj_on_inverseI) (rule sum_decode_inverse)

(* Surjectivity from a right inverse, via surjI. *)
lemma surj_sum_encode: "surj sum_encode"
  by (rule surjI) (rule sum_decode_inverse)

lemma surj_sum_decode: "surj sum_decode"
  by (rule surjI) (rule sum_encode_inverse)

lemma bij_sum_encode: "bij sum_encode"
  by (rule bijI [OF inj_sum_encode surj_sum_encode])

lemma bij_sum_decode: "bij sum_decode"
  by (rule bijI [OF inj_sum_decode surj_sum_decode])

(* Two sum values have the same code iff they are equal. *)
lemma sum_encode_eq: "sum_encode x = sum_encode y ⟷ x = y"
  by (rule inj_sum_encode [THEN inj_eq])

(* Two numbers decode to the same sum value iff they are equal. *)
lemma sum_decode_eq: "sum_decode x = sum_decode y ⟷ x = y"
  by (rule inj_sum_decode [THEN inj_eq])


subsection ‹Type @{typ "int"}›

(* Integers are first mapped into nat + nat and then encoded with sum_encode.
   A nonnegative i goes to Inl (nat i); a negative i goes to Inr (nat (- i - 1)).
   The resulting order is 0, -1, 1, -2, 2, ..., with codes 0, 1, 2, 3, 4, ... *)
definition int_encode :: "int ⇒ nat"
  where "int_encode i = sum_encode (if 0 ≤ i then Inl (nat i) else Inr (nat (- i - 1)))"

(* Inverse of int_encode: Inl a decodes to int a, and Inr b to - int b - 1. *)
definition int_decode :: "nat ⇒ int"
  where "int_decode n = (case sum_decode n of Inl a ⇒ int a | Inr b ⇒ - int b - 1)"

(* Once both definitions are unfolded, simp finishes, using the simp rule
   sum_encode_inverse. *)
lemma int_encode_inverse [simp]: "int_decode (int_encode x) = x"
  by (simp add: int_decode_def int_encode_def)

(* Case split on sum_decode n.  The instance sum_decode_inverse [of n]
   relates each case back to n. *)
lemma int_decode_inverse [simp]: "int_encode (int_decode n) = n"
  unfolding int_decode_def int_encode_def
  using sum_decode_inverse [of n] by (cases "sum_decode n") simp_all

(* int_encode and int_decode are mutually inverse (int_encode_inverse,
   int_decode_inverse).  Hence each is injective on every set A, surjective,
   and therefore bijective. *)

(* Injectivity from a left inverse, via inj_on_inverseI. *)
lemma inj_int_encode: "inj_on int_encode A"
  by (rule inj_on_inverseI) (rule int_encode_inverse)

lemma inj_int_decode: "inj_on int_decode A"
  by (rule inj_on_inverseI) (rule int_decode_inverse)

(* Surjectivity from a right inverse, via surjI. *)
lemma surj_int_encode: "surj int_encode"
  by (rule surjI) (rule int_decode_inverse)

lemma surj_int_decode: "surj int_decode"
  by (rule surjI) (rule int_encode_inverse)

lemma bij_int_encode: "bij int_encode"
  by (rule bijI [OF inj_int_encode surj_int_encode])

lemma bij_int_decode: "bij int_decode"
  by (rule bijI [OF inj_int_decode surj_int_decode])

(* Two integers have the same code iff they are equal. *)
lemma int_encode_eq: "int_encode x = int_encode y ⟷ x = y"
  by (rule inj_int_encode [THEN inj_eq])

(* Two numbers decode to the same integer iff they are equal. *)
lemma int_decode_eq: "int_decode x = int_decode y ⟷ x = y"
  by (rule inj_int_decode [THEN inj_eq])


subsection ‹Type @{typ "nat list"}›

(* Lists are encoded by iterated pairing.  [] has code 0.  x # xs has code
   one more than the pair code of x and the code of xs, so every nonempty
   list has a positive code. *)
fun list_encode :: "nat list ⇒ nat"
  where
    "list_encode [] = 0"
  | "list_encode (x # xs) = Suc (prod_encode (x, list_encode xs))"

(* Inverse of list_encode.  0 decodes to [].  For Suc n, prod_decode splits n
   into the head x and the code y of the tail.  The definition uses function:
   pattern completeness and compatibility are discharged by
   pat_completeness auto, and termination is proved separately below. *)
function list_decode :: "nat ⇒ nat list"
  where
    "list_decode 0 = []"
  | "list_decode (Suc n) = (case prod_decode n of (x, y) ⇒ x # list_decode y)"
  by pat_completeness auto

(* The measure is the argument itself.  The recursive call is on y, where
   (x, y) is the decoding of n.  Mapping that equation through prod_encode
   gives n = prod_encode (x, y), so y ≤ n < Suc n by le_prod_encode_2. *)
termination list_decode
  apply (relation "measure id")
   apply simp_all
  apply (drule arg_cong [where f="prod_encode"])
  apply (drule sym)
  apply (simp add: le_imp_less_Suc le_prod_encode_2)
  done

(* Induction along the recursion of list_encode.  The equations of
   list_encode and list_decode, together with the simp rule
   prod_encode_inverse, presumably do all the work. *)
lemma list_encode_inverse [simp]: "list_decode (list_encode x) = x"
  by (induct x rule: list_encode.induct) simp_all

(* Induction along the recursion of list_decode.  In the Suc case, the pair
   returned by prod_decode is split, and prod_decode_eq (used right to left)
   closes the goal. *)
lemma list_decode_inverse [simp]: "list_encode (list_decode n) = n"
  apply (induct n rule: list_decode.induct)
   apply simp
  apply (simp split: prod.split)
  apply (simp add: prod_decode_eq [symmetric])
  done

(* list_encode and list_decode are mutually inverse (list_encode_inverse,
   list_decode_inverse).  Hence each is injective on every set A, surjective,
   and therefore bijective. *)

(* Injectivity from a left inverse, via inj_on_inverseI. *)
lemma inj_list_encode: "inj_on list_encode A"
  by (rule inj_on_inverseI) (rule list_encode_inverse)

lemma inj_list_decode: "inj_on list_decode A"
  by (rule inj_on_inverseI) (rule list_decode_inverse)

(* Surjectivity from a right inverse, via surjI. *)
lemma surj_list_encode: "surj list_encode"
  by (rule surjI) (rule list_decode_inverse)

lemma surj_list_decode: "surj list_decode"
  by (rule surjI) (rule list_encode_inverse)

lemma bij_list_encode: "bij list_encode"
  by (rule bijI [OF inj_list_encode surj_list_encode])

lemma bij_list_decode: "bij list_decode"
  by (rule bijI [OF inj_list_decode surj_list_decode])

(* Two lists have the same code iff they are equal. *)
lemma list_encode_eq: "list_encode x = list_encode y ⟷ x = y"
  by (rule inj_list_encode [THEN inj_eq])

(* Two numbers decode to the same list iff they are equal. *)
lemma list_decode_eq: "list_decode x = list_decode y ⟷ x = y"
  by (rule inj_list_decode [THEN inj_eq])


subsection ‹Finite sets of naturals›

subsubsection ‹Preliminaries›

(* The preimage of F under Suc, i.e. {n. Suc n ∈ F}, is finite exactly when F
   is.
   - finite F ⟹ finite (Suc -` F): closed by safe, using finite_vimageI and
     inj_Suc.
   - Converse: every element of F is 0 or a successor, so
     F ⊆ insert 0 (Suc ` Suc -` F), and the right-hand side is finite. *)
lemma finite_vimage_Suc_iff: "finite (Suc -` F) ⟷ finite F"
  apply (safe intro!: finite_vimageI inj_Suc)
  apply (rule finite_subset [where B="insert 0 (Suc ` Suc -` F)"])
   apply (rule subsetI)
   apply (case_tac x)
    apply simp
   apply simp
  apply (rule finite_insert [THEN iffD2])
  apply (erule finite_imageI)
  done

(* Inserting 0 does not change the preimage under Suc, since 0 is not a
   successor. *)
lemma vimage_Suc_insert_0: "Suc -` insert 0 A = Suc -` A"
  by auto

(* Inserting Suc n adds n to the preimage under Suc. *)
lemma vimage_Suc_insert_Suc: "Suc -` insert (Suc n) A = insert n (Suc -` A)"
  by auto

(* Two naturals with the same quotient by 2 and the same parity are equal,
   because each is determined by x = x div 2 * 2 + x mod 2. *)
lemma div2_even_ext_nat:
  fixes x y :: nat
  assumes "x div 2 = y div 2"
    and "even x ⟷ even y"
  shows "x = y"
proof -
  (* Same parity means the same remainder modulo 2. *)
  from ‹even x ⟷ even y› have "x mod 2 = y mod 2"
    by (simp only: even_iff_mod_2_eq_zero) auto
  with assms have "x div 2 * 2 + x mod 2 = y div 2 * 2 + y mod 2"
    by simp
  (* Both sides are the div/mod decompositions of x and y. *)
  then show ?thesis
    by simp
qed


subsubsection ‹From sets to naturals›

(* A finite set of naturals is encoded as the sum of 2 ^ n over its elements,
   i.e. the number whose binary digit n is 1 exactly when n is in the set.
   For an infinite set, sum gives 0 (see set_encode_inf). *)
definition set_encode :: "nat set ⇒ nat"
  where "set_encode = sum ((^) 2)"

lemma set_encode_empty [simp]: "set_encode {} = 0"
  by (simp add: set_encode_def)

(* All infinite sets share the code 0. *)
lemma set_encode_inf: "¬ finite A ⟹ set_encode A = 0"
  by (simp add: set_encode_def)

(* Adding a new element n to a finite set adds 2 ^ n to its code. *)
lemma set_encode_insert [simp]: "finite A ⟹ n ∉ A ⟹ set_encode (insert n A) = 2^n + set_encode A"
  by (simp add: set_encode_def)

(* The code is odd exactly when 0 is in the set, because every other summand
   2 ^ n with n > 0 is even.  Proved by induction over the finite set. *)
lemma even_set_encode_iff: "finite A ⟹ even (set_encode A) ⟷ 0 ∉ A"
  by (induct set: finite) (auto simp: set_encode_def)

(* Dropping 0 and shifting every other element down by one (the preimage
   under Suc) halves the code, rounding down.
   - Finite A: induction over A, with a case split of the inserted element
     into 0 and a successor.  In the 0 case, even_set_encode_iff shows the
     added 1 disappears in the division.
   - Infinite A: both sides are 0, because Suc -` A is infinite too
     (finite_vimage_Suc_iff). *)
lemma set_encode_vimage_Suc: "set_encode (Suc -` A) = set_encode A div 2"
  apply (cases "finite A")
   apply (erule finite_induct)
    apply simp
   apply (case_tac x)
    apply (simp add: even_set_encode_iff vimage_Suc_insert_0)
   apply (simp add: finite_vimageI add.commute vimage_Suc_insert_Suc)
  apply (simp add: set_encode_def finite_vimage_Suc_iff)
  done

(* The same fact, oriented to rewrite set_encode A div 2. *)
lemmas set_encode_div_2 = set_encode_vimage_Suc [symmetric]


subsubsection ‹From naturals to sets›

(* set_decode x is the set of positions of the 1-bits of x: n belongs to it
   iff x div 2 ^ n is odd. *)
definition set_decode :: "nat ⇒ nat set"
  where "set_decode x = {n. odd (x div 2 ^ n)}"

(* Bit 0 of x is set iff x is odd. *)
lemma set_decode_0 [simp]: "0 ∈ set_decode x ⟷ odd x"
  by (simp add: set_decode_def)

(* Bit Suc n of x is bit n of x div 2.  Together with set_decode_0, this lets
   simp decide membership by recursion on the bit position. *)
lemma set_decode_Suc [simp]: "Suc n ∈ set_decode x ⟷ n ∈ set_decode (x div 2)"
  by (simp add: set_decode_def div_mult2_eq)

(* 0 has no 1-bits. *)
lemma set_decode_zero [simp]: "set_decode 0 = {}"
  by (simp add: set_decode_def)

(* Halving the number drops position 0 and shifts the other positions down
   by one.  Presumably auto gets there from the simp rules set_decode_0 and
   set_decode_Suc. *)
lemma set_decode_div_2: "set_decode (x div 2) = Suc -` set_decode x"
  by auto

(* If bit n of z is clear, adding 2 ^ n sets exactly that bit.  The proof is
   by induction on n, with z generalized.  In each case, set extensionality
   reduces the goal to membership of an arbitrary q.  That is shown by
   induction on q, using the simp rules set_decode_0 / set_decode_Suc and the
   case's facts (the premise and, in the Suc case, the induction
   hypothesis). *)
lemma set_decode_plus_power_2:
  "n ∉ set_decode z ⟹ set_decode (2 ^ n + z) = insert n (set_decode z)"
proof (induct n arbitrary: z)
  case 0
  show ?case
  proof (rule set_eqI)
    show "q ∈ set_decode (2 ^ 0 + z) ⟷ q ∈ insert 0 (set_decode z)" for q
      by (induct q) (use 0 in simp_all)
  qed
next
  case (Suc n)
  show ?case
  proof (rule set_eqI)
    show "q ∈ set_decode (2 ^ Suc n + z) ⟷ q ∈ insert (Suc n) (set_decode z)" for q
      by (induct q) (use Suc in simp_all)
  qed
qed

(* Every set_decode n is finite.  The proof is by strong induction on n.
   For n = 0 the set is empty.  Otherwise the induction hypothesis at
   n div 2 < n makes set_decode (n div 2) = Suc -` set_decode n finite, and
   finite_vimage_Suc_iff carries that finiteness back to set_decode n. *)
lemma finite_set_decode [simp]: "finite (set_decode n)"
  apply (induct n rule: nat_less_induct)
  apply (case_tac "n = 0")
   apply simp
  apply (drule_tac x="n div 2" in spec)
  apply simp
  apply (simp add: set_decode_div_2)
  apply (simp add: finite_vimage_Suc_iff)
  done


subsubsection ‹Proof of isomorphism›

(* set_encode is a left inverse of set_decode.  The proof is by strong
   induction on n.  For n ≠ 0, div2_even_ext_nat reduces the goal to two
   facts:
   - both sides have the same quotient by 2, from the induction hypothesis at
     n div 2 via set_decode_div_2 and set_encode_vimage_Suc;
   - both sides have the same parity, from even_set_encode_iff and the simp
     rule set_decode_0. *)
lemma set_decode_inverse [simp]: "set_encode (set_decode n) = n"
  apply (induct n rule: nat_less_induct)
  apply (case_tac "n = 0")
   apply simp
  apply (drule_tac x="n div 2" in spec)
  apply simp
  apply (simp add: set_decode_div_2 set_encode_vimage_Suc)
  apply (erule div2_even_ext_nat)
  apply (simp add: even_set_encode_iff)
  done

(* set_decode is a left inverse of set_encode on finite sets.  The proof is
   by induction over the finite set.  The insert step combines
   set_encode_insert with set_decode_plus_power_2. *)
lemma set_encode_inverse [simp]: "finite A ⟹ set_decode (set_encode A) = A"
  apply (erule finite_induct)
   apply simp_all
  apply (simp add: set_decode_plus_power_2)
  done

(* Consequently, set_encode is injective on the finite sets. *)
lemma inj_on_set_encode: "inj_on set_encode (Collect finite)"
  by (rule inj_on_inverseI [where g = "set_decode"]) simp

(* For finite sets, equal codes mean equal sets.  The finiteness assumptions
   are needed because every infinite set is encoded as 0. *)
lemma set_encode_eq: "finite A ⟹ finite B ⟹ set_encode A = set_encode B ⟷ A = B"
  by (rule iffI) (simp_all add: inj_onD [OF inj_on_set_encode])

(* If the bit set of m is contained in that of n, then m ≤ n.  The key step
   writes n as m plus the code of the bits present in n but not in m, using
   sum.subset_diff on the finite sets set_decode m ⊆ set_decode n. *)
lemma subset_decode_imp_le:
  assumes "set_decode m ⊆ set_decode n"
  shows "m ≤ n"
proof -
  have "n = m + set_encode (set_decode n - set_decode m)"
  proof -
    (* Suitable witnesses are A = set_decode m and B = set_decode n. *)
    obtain A B where
      "m = set_encode A" "finite A"
      "n = set_encode B" "finite B"
      by (metis finite_set_decode set_decode_inverse)
    (* This step belongs to the inner proof, hence the deeper indentation. *)
    with assms show ?thesis
      by auto (simp add: set_encode_def add.commute sum.subset_diff)
  qed
  then show ?thesis
    by (metis le_add1)
qed

end