module CoreFP (
  Exp(..),
  Type(..),
  TSub,
  Env,
  primitives,
  tvars,
  tcomp,
  tsub,
  tint,
  tbool,
  tpair,
  tlist,
  fromString,
  (~>)
  ) where
import Prelude hiding (abs, exp, lex)
import Data.List
import Data.Maybe
import Parse hiding (token)
import qualified Parse

-- | Names: used for variables, constants, binders and type constructors.
type Id = String

{- Expressions -}
-- | Abstract syntax of core expressions.
data Exp
  = Var Id            -- ^ variable
  | App Exp Exp       -- ^ application of a function to one argument
  | Abs Id Exp        -- ^ lambda abstraction over a single variable
  | Con Id            -- ^ constant; the parser yields these only for primitive names
  | Let Id Exp Exp    -- ^ @let x = e1 in e2@
  | Ite Exp Exp Exp   -- ^ @if c then t else e@
  deriving (Eq)

{- Print Expressions -}
-- | Wrap a string in parentheses.
par :: String -> String
par s = '(' : s ++ ")"

-- | Join two renderings with a single space (juxtaposition).
app :: String -> String -> String
app f x = concat [f, " ", x]

-- | Render an expression in the surface syntax read by 'fromString'.
--
-- Application is left-associative juxtaposition; lambdas are always wrapped
-- in parentheses.  @let@ and @if@ extend as far to the right as possible
-- when parsed, so inside an application (in function or argument position)
-- they are parenthesized.  Otherwise @App (Ite c t e) x@ would print as
-- @if c then t else e x@ and re-parse as @Ite c t (App e x)@.
showExp :: Exp -> String
showExp (Var x)      =  x
showExp (Con c)      =  c
showExp (App s t)    =  showFun s `app` showArg t
  where
    -- function position: nested applications stay bare (left assoc.),
    -- but let/if must be closed off so they don't swallow the argument
    showFun e@(Let _ _ _)  =  par (showExp e)
    showFun e@(Ite _ _ _)  =  par (showExp e)
    showFun e              =  showExp e
    -- argument position: only atoms and (already parenthesized) lambdas
    -- may appear bare
    showArg e@(Var _)      =  showExp e
    showArg e@(Con _)      =  showExp e
    showArg e@(Abs _ _)    =  showExp e
    showArg e              =  par (showExp e)
showExp (Let x s t)  =
  "let " ++ x ++ " = " ++ showExp s ++ " in " ++ showExp t
showExp (Ite s t u)  =
  "if " ++ showExp s ++ " then " ++ showExp t ++ " else " ++ showExp u
showExp t@(Abs _ _)  =  par ("\\" ++ showLambdas t)

-- | Render the binders and body of a lambda abstraction, without the leading
-- backslash or surrounding parentheses.  Directly nested abstractions share
-- one binder list: @Abs x (Abs y b)@ becomes @x y.b@.
--
-- The as-patterns are written without surrounding whitespace (@body\@(..)@);
-- GHC >= 9.0 rejects the spaced form @body \@ (..)@.
showLambdas :: Exp -> String
showLambdas (Abs x body@(Abs _ _))  =  x ++ " " ++ showLambdas body
showLambdas (Abs x body)            =  x ++ "." ++ showExp body
-- callers are expected to pass an abstraction; fail loudly otherwise
showLambdas _                       =  error "showLambdas: expected an abstraction"

-- | Expressions are shown in surface syntax via 'showExp', not as
-- constructor terms.
instance Show Exp where
  show e = showExp e

{- Types -}
-- | Types: numbered type variables and applied type constructors.
-- Function types are @TCon "Fun" [argument, result]@ (built by '~>').
data Type
  = TVar Int        -- ^ type variable, printed as @a<n>@
  | TCon Id [Type]  -- ^ type constructor applied to its arguments
  deriving (Eq)

-- | Typing environment: associates names with types (see 'primitives').
type Env = [(Id, Type)]

-- | Join a rendered domain and codomain with an arrow.
arr :: String -> String -> String
arr dom cod = concat [dom, " -> ", cod]

-- | Render a type.  Function types use an infix arrow.  A function-typed
-- domain is parenthesized because the arrow associates to the right.  Other
-- constructors print as @Name@ or @Name(arg1,arg2,...)@.
showType :: Type -> String
showType ty = case ty of
  TVar n                -> 'a' : show n
  TCon "Fun" [dom, cod] -> showDom dom `arr` showType cod
  TCon name []          -> name
  TCon name args        -> name ++ "(" ++ showArgs args ++ ")"
  where
    showDom d@(TCon "Fun" _) = par (showType d)
    showDom d                = showType d

-- | Comma-separated renderings of constructor arguments (no spaces).
showArgs :: [Type] -> String
showArgs = intercalate "," . map showType

-- | Types are shown via 'showType'.
instance Show Type where
  show ty = showType ty

-- | The nullary integer type.
tint :: Type
tint = TCon "Int" []

-- | The nullary boolean type.
tbool :: Type
tbool = TCon "Bool" []

-- | List type with the given element type.
tlist :: Type -> Type
tlist elemTy = TCon "List" [elemTy]

infixr 8 ~>

-- | Pair type of two component types.
tpair :: Type -> Type -> Type
tpair l r = TCon "Pair" [l, r]

-- | Function type; right-associative, so @a ~> b ~> c@ is @a ~> (b ~> c)@.
(~>) :: Type -> Type -> Type
(~>) dom cod = TCon "Fun" [dom, cod]

-- | Types of the built-in constants.  A name is parsed as a constant ('Con')
-- exactly when it appears in this list (see 'isPrim').
--
-- Assumption: the type variables involved are non-negative (only 0 and 1
-- are used here).
primitives :: Env
primitives =
  concat [booleans, comparisons, arithmetic, numerals, pairs, lists, fixpoint]
  where
    a0 = TVar 0
    a1 = TVar 1
    booleans =
      [ ("True", tbool)
      , ("False", tbool)
      , ("not", tbool ~> tbool)
      ]
    comparisons = [ (op, tint ~> tint ~> tbool) | op <- ["<", ">", "=="] ]
    arithmetic  = [ (op, tint ~> tint ~> tint)  | op <- ["*", "+", "/", "-"] ]
    -- integer literals "0" .. "3"
    numerals    = [ (show n, tint) | n <- [0 .. 3 :: Int] ]
    pairs =
      [ ("Pair", a0 ~> a1 ~> tpair a0 a1)
      , ("fst", tpair a0 a1 ~> a0)
      , ("snd", tpair a0 a1 ~> a1)
      ]
    lists =
      [ ("Nil", tlist a0)
      , ("Cons", a0 ~> tlist a0 ~> tlist a0)
      , ("head", tlist a0 ~> a0)
      , ("tail", tlist a0 ~> tlist a0)
      ]
    -- fixpoint combinator
    fixpoint = [ ("Y", (a0 ~> a0) ~> a0) ]

-- | Type variables occurring in a type, each listed once, in order of first
-- occurrence from left to right.
tvars :: Type -> [Int]
tvars = nub . occurrences
  where
    occurrences (TVar v)      = [v]
    occurrences (TCon _ args) = concatMap occurrences args

-- | A substitution from type-variable indices to types.  Lookup uses
-- 'lookup', so for a repeated index the first binding wins.
type TSub = [(Int, Type)]

-- | Apply a substitution to every variable of a type.  Unbound variables are
-- kept.  Substituted types are not themselves substituted again.
tsub :: TSub -> Type -> Type
tsub s ty = case ty of
  TVar i      -> maybe ty id (lookup i s)
  TCon g args -> TCon g [ tsub s a | a <- args ]

-- | Compose substitutions: applying @s1 `tcomp` s2@ has the same effect as
-- applying @s1@ first and then @s2@.  Bindings of @s1@ get @s2@ applied to
-- their right-hand sides.  Bindings of @s2@ for variables outside @s1@'s
-- domain are appended.
tcomp :: TSub -> TSub -> TSub
tcomp s1 s2 = updated ++ inherited
  where
    updated   = [ (v, tsub s2 t) | (v, t) <- s1 ]
    inherited = [ b | b@(v, _) <- s2, v `notElem` bound ]
    bound     = map fst s1

{- Lexing -}
-- | Tokens produced by the lexer.
data Token
  = ID String   -- ^ identifier (variable or primitive name)
  | KWD String  -- ^ keyword: if, then, else, let, in, =
  | LP          -- ^ left parenthesis
  | RP          -- ^ right parenthesis
  | DOT         -- ^ dot separating lambda binders from the body
  | LAMBDA      -- ^ backslash introducing a lambda

lexIdOrKey, lexLpar, lexRpar, lexDot, lexLam, token :: Parser Char Token

-- | A non-empty run of characters other than backslash, dot, parentheses,
-- space, tab, newline and carriage return.  Runs spelling a keyword become
-- 'KWD', everything else becomes 'ID'.
-- Note: '=' is not a delimiter, so it is only recognized as the keyword
-- when separated by whitespace (e.g. "x=1" lexes as a single identifier).
lexIdOrKey = many1 (noneof "\\.()\n\t\r ") >>= return . classify
  where
    classify word
      | word `elem` keywords = KWD word
      | otherwise            = ID word
    keywords = ["if", "then", "else", "let", "in", "="]

-- Single-character punctuation tokens.
lexLpar = do { _ <- char '(';  return LP }
lexRpar = do { _ <- char ')';  return RP }
lexDot  = do { _ <- char '.';  return DOT }
lexLam  = do { _ <- char '\\'; return LAMBDA }

-- | Run @p@, then skip any following whitespace (via 'spaces'), returning
-- @p@'s result.
lex p = p >>= \result -> spaces >> return result

-- | One token plus trailing whitespace.  The alternatives are combined with
-- '<|>' in this order: punctuation first, identifiers/keywords last.
token = lex anyToken
  where
    anyToken = lexLpar <|> lexRpar <|> lexDot <|> lexLam <|> lexIdOrKey

-- | Tokenize a whole input.  Leading whitespace is skipped, then tokens are
-- read with 'many', and the end of input must follow.
tokenize :: Parser Char [Token]
tokenize = spaces >> (many token >>= \toks -> eoi >> return toks)

{- Parser -}
-- | Accept exactly the keyword token @s@, yielding its text.
kwd s = Parse.token matchKeyword
  where
    matchKeyword (KWD k) | k == s = Just k
    matchKeyword _                = Nothing

-- | Whether a name is bound in 'primitives'.
isPrim :: String -> Bool
isPrim x = any ((== x) . fst) primitives

-- | An identifier token that is not a primitive name, yielding the name.
var = Parse.token variableName
  where
    variableName (ID x) | not (isPrim x) = Just x
    variableName _                       = Nothing

-- | An identifier token that names a primitive, yielding the name.
con = Parse.token constantName
  where
    constantName (ID x) | isPrim x = Just x
    constantName _                 = Nothing

-- Punctuation tokens; each yields () on a match.
lpar = Parse.token isLP
  where isLP LP = Just ()
        isLP _  = Nothing
rpar = Parse.token isRP
  where isRP RP = Just ()
        isRP _  = Nothing
dot  = Parse.token isDot
  where isDot DOT = Just ()
        isDot _   = Nothing
lam  = Parse.token isLam
  where isLam LAMBDA = Just ()
        isLam _      = Nothing

-- | @if c then t else e@; each part is a full expression.
ite = do
  kwd "if"
  cond <- exp
  kwd "then"
  yes <- exp
  kwd "else"
  no <- exp
  return (Ite cond yes no)

-- | @let x = e1 in e2@ with a single (non-primitive) binder.
letin = do
  kwd "let"
  binder <- var
  kwd "="
  bound <- exp
  kwd "in"
  body <- exp
  return (Let binder bound body)

-- | Lambda: backslash, one or more binders, a dot, and a body expression.
-- @\\x y.e@ becomes the nested @Abs x (Abs y e)@.
abs = do
  lam
  binders <- many1 var
  dot
  body <- exp
  return (foldr Abs body binders)

-- | A full expression: one 'expression' followed by any number of argument
-- expressions, combined into left-nested applications by exp'.
exp :: Parser Token Exp
exp = do
  headExp <- expression
  exp' headExp

-- | A single non-application expression: variable, constant, parenthesized
-- expression, conditional, let binding or lambda (alternatives in this order).
expression :: Parser Token Exp
expression = variable <|> constant <|> parenthesized <|> ite <|> letin <|> abs
  where
    variable      = var >>= \x -> return (Var x)
    constant      = con >>= \c -> return (Con c)
    parenthesized = between lpar rpar exp

-- | Given the application built so far, keep parsing argument expressions
-- and wrapping them in 'App' (left-associative).  When no further argument
-- parses, return the accumulated expression.
exp' f = application <|> return f
  where
    application = do
      arg <- expression
      exp' (App f arg)

-- | Parse source text into an expression.  Calls 'error' with
-- "lexing error" if tokenization fails.  Calls it with "parse error" if the
-- tokens are not a single expression followed by end of input.
fromString :: String -> Exp
fromString src =
  case parse tokenize src of
    Nothing     -> error "lexing error"
    Just tokens ->
      case parse wholeExp tokens of
        Nothing -> error "parse error"
        Just e  -> e
  where
    wholeExp = do
      e <- exp
      eoi
      return e