module Arith where
import Parser
import qualified ArithLexer as Lex

-- | Abstract syntax tree for the arithmetic expressions recognised by
-- this module's parser.
data Exp
  = Nat Integer  -- ^ Literal, built from a 'Lex.Number' token.
  | Neg Exp      -- ^ Prefix minus applied to a sub-expression.
  | Add Exp Exp  -- ^ Binary @+@.
  | Sub Exp Exp  -- ^ Binary @-@.
  | Mul Exp Exp  -- ^ Binary @*@.
  | Div Exp Exp  -- ^ Binary @/@.
  deriving Show

-- | Shorthand for a parser over a stream of 'Lex.Token's that yields an 'Exp'.
type ExpP = Parser Lex.Token Exp

-- | A numeric literal: accepts a single 'Lex.Number' token and wraps its
-- value in 'Nat'. The matcher returns 'Nothing' for every other token.
nat :: ExpP
nat = token fromNumber
  where
    fromNumber (Lex.Number i) = Just (Nat i)
    fromNumber _              = Nothing

-- | Lift a predicate into a 'Maybe' test. The result is @Just ()@ when the
-- predicate holds for the value, and 'Nothing' otherwise.
justIf :: (a -> Bool) -> a -> Maybe ()
justIf holds value
  | holds value = Just ()
  | otherwise   = Nothing

-- Single-token parsers for punctuation and operators. Each one passes
-- 'token' an equality test wrapped in 'justIf'. That test yields @Just ()@
-- for the one matching 'Lex.Token' and 'Nothing' for any other token.
lpar, rpar, plus, minus, star, slash :: Parser Lex.Token ()
lpar  = token (justIf (== Lex.Lpar))
rpar  = token (justIf (== Lex.Rpar))
plus  = token (justIf (== Lex.Plus))
minus = token (justIf (== Lex.Minus))
star  = token (justIf (== Lex.Star))
slash = token (justIf (== Lex.Slash))

-- | An expression: one or more 'term's joined by @+@ or @-@.
--
-- The tree is built left-associatively. Each operator combines everything
-- parsed so far with the next 'term', so @1 - 2 - 3@ yields
-- @Sub (Sub (Nat 1) (Nat 2)) (Nat 3)@.
expr :: ExpP
expr = term >>= loop
  where
    -- Try to grow the accumulated expression by one "+ term" or
    -- "- term" step. If neither applies, stop and return it unchanged.
    loop acc = extendWith plus Add <|> extendWith minus Sub <|> return acc
      where
        extendWith op build = op >> term >>= \rhs -> loop (build acc rhs)

-- | A term: one or more 'factor's joined by @*@ or @/@.
--
-- Like 'expr', the tree is built left-associatively: @8 / 4 / 2@ yields
-- @Div (Div (Nat 8) (Nat 4)) (Nat 2)@. Because 'expr' is made of terms,
-- @*@ and @/@ bind tighter than @+@ and @-@.
term :: ExpP
term = factor >>= term'
  where
    -- 'f' is the product/quotient parsed so far.
    term' f = mul <|> divide <|> return f
      where
        mul    = star  >> factor >>= term' . Mul f
        -- Named 'divide' rather than 'div' so it does not shadow Prelude's 'div'.
        divide = slash >> factor >>= term' . Div f

-- | A factor is one of:
--
--   * a numeric literal ('nat'),
--   * an 'expr' between 'lpar' and 'rpar', or
--   * a prefix minus applied to another factor.
--
-- Negation takes a 'factor', not a 'term' or an 'expr', so it binds tighter
-- than every binary operator: @-2 * 3@ yields @Mul (Neg (Nat 2)) (Nat 3)@.
-- Repeated minus signs nest: @--1@ yields @Neg (Neg (Nat 1))@.
factor :: ExpP
factor = nat <|> par <|> neg
  where
    par = between lpar rpar expr
    neg = fmap Neg (minus >> factor)

-- | Tokenize a string, then parse the tokens as a single arithmetic expression.
--
-- Returns 'Nothing' if tokenizing fails or if the token parse fails.
-- The token parser is @expr `followedBy` eoi@, which presumably rejects
-- trailing tokens after a complete expression (assuming 'eoi' matches end
-- of input — confirm in the Parser module).
fromString :: String -> Maybe Exp
fromString s = parse Lex.tokenize s >>= parse (expr `followedBy` eoi)