-- |
-- = Sets as Binary Search Trees
module Set (Set, empty, insert, mem, union, diff, toList) where
import Data.List (intercalate)
import BTree (BTree(..)) -- want to use type without prefix
import qualified BTree

-- | A set of elements, represented by a 'BTree'.
--
-- Invariant (assumed, not checked here): the underlying tree is a binary
-- search tree. For every @Node y l r@, all elements of @l@ compare 'LT' to @y@
-- and all elements of @r@ compare 'GT' to @y@. Equal elements are never stored
-- twice.
newtype Set a = Set
  { rep :: BTree a -- ^ the underlying binary search tree
  }

-- | The elements of a 'Set' as a list, obtained with 'BTree.flatten'.
--
-- Given the search-tree invariant, this is the ascending list of elements,
-- provided 'BTree.flatten' is an in-order traversal (defined in "BTree";
-- TODO confirm).
toList :: Set a -> [a]
toList s = BTree.flatten (rep s)

-- | Renders a set as @{e1,e2,...}@, showing each element with 'show' in
-- 'toList' order.
instance (Ord a, Show a) => Show (Set a) where
  show s = concat ["{", intercalate "," (map show (toList s)), "}"]

-- | The set with no elements, backed by the 'Empty' tree.
empty :: Set a
empty = Set { rep = Empty }

-- | Test whether an element belongs to a 'Set'.
--
-- The search starts at the root and goes left or right according to
-- 'compare', so it visits a single root-to-leaf path.
mem :: Ord a => a -> Set a -> Bool
mem needle = search . rep
  where
    search Empty = False
    search (Node pivot left right) =
      case compare needle pivot of
        LT -> search left
        GT -> search right
        EQ -> True

-- | Add a single element to a 'Set'.
--
-- If an element comparing 'EQ' is already present, the existing element is
-- kept.
insert :: Ord a => a -> Set a -> Set a
insert x = Set . insertTree x . rep

-- | Insert an element into a binary search tree.
--
-- Smaller elements go into the left subtree and larger ones into the right.
-- If an equal element is found, that node is returned untouched, so the tree
-- keeps its original element.
insertTree :: Ord a => a -> BTree a -> BTree a
insertTree new Empty = Node new Empty Empty
insertTree new here@(Node pivot smaller larger) =
  case compare new pivot of
    LT -> Node pivot (insertTree new smaller) larger
    EQ -> here
    GT -> Node pivot smaller (insertTree new larger)

-- | The union of two 'Set's: every element that occurs in either argument.
union :: Ord a => Set a -> Set a -> Set a
union (Set xs) (Set ys) = Set (unionTree xs ys)

-- | Union of two search trees, built by inserting every element of the first
-- tree into the second with 'insertTree'.
--
-- The two subtrees of the first tree are merged with each other first, and
-- that result is then merged into the second tree. The root is inserted last.
-- An element already present in the second tree is kept.
unionTree :: Ord a => BTree a -> BTree a -> BTree a
unionTree Empty acc = acc
unionTree (Node v left right) acc =
  let merged = unionTree (unionTree left right) acc
  in insertTree v merged

-- | Split off the right-most element of a tree. Under the search-tree
-- invariant, this is the largest element.
--
-- Returns 'Nothing' for an empty tree. Otherwise returns @Just (m, t')@,
-- where @m@ is the right-most element and @t'@ is the tree with that node
-- replaced by its own left subtree.
splitMaxFromTree :: BTree a -> Maybe (a, BTree a)
splitMaxFromTree Empty = Nothing
splitMaxFromTree (Node x l r) =
  -- Match on the recursive result instead of binding it with an irrefutable
  -- @Just@ pattern. A 'Nothing' means the right subtree is empty, so this
  -- node is the maximum. The function is therefore total.
  case splitMaxFromTree r of
    Nothing      -> Just (x, l)
    Just (m, r') -> Just (m, Node x l r')

-- | Delete an element from a binary search tree.
--
-- If no element compares 'EQ', the tree is returned with the same contents.
-- A matching node is replaced by the maximum of its left subtree (found with
-- 'splitMaxFromTree'). If the left subtree is empty, the node is replaced by
-- its right subtree.
removeFromTree :: Ord a => a -> BTree a -> BTree a
removeFromTree _ Empty = Empty
removeFromTree target (Node v left right) =
  case compare target v of
    EQ -> maybe right (\(m, left') -> Node m left' right) (splitMaxFromTree left)
    LT -> Node v (removeFromTree target left) right
    GT -> Node v left (removeFromTree target right)

-- | The difference of two 'Set's: the elements of the first set that do not
-- occur in the second.
diff :: Ord a => Set a -> Set a -> Set a
diff (Set xs) (Set ys) = Set (diffTree xs ys)

-- | Remove from the first tree every element of the second tree.
--
-- The second tree is walked root first, then its left subtree, then its
-- right subtree. Each element is deleted with 'removeFromTree'.
diffTree :: Ord a => BTree a -> BTree a -> BTree a
diffTree acc Empty = acc
diffTree acc (Node v left right) =
  let withoutV    = removeFromTree v acc
      withoutLeft = diffTree withoutV left
  in diffTree withoutLeft right

{- Alternative implementation (commented out): sets as lists
import qualified Data.List as List

data Set a = Set [a]

instance (Ord a, Show a) => Show (Set a) where
  show (Set xs) = "{" ++ List.intercalate "," es ++ "}"
    where es = map show (List.sort xs)

empty :: Set a
empty = Set []

insert :: Eq a => a -> Set a -> Set a
insert x (Set xs) = Set $ List.nub $ x : xs

mem :: Eq a => a -> Set a -> Bool
x `mem` Set xs = x `elem` xs

union, diff :: Eq a => Set a -> Set a -> Set a
union (Set xs) (Set ys) = Set $ List.nub $ xs ++ ys
diff (Set xs) (Set ys) = Set $ xs List.\\ ys
-}