module Demo09_SMT(
  SmtVar,
  IAExpr(..),
  Formula(Gt, BoolVar, Le, Equiv), true, false, disj, conj, atomic,
  SmtType(..),
  SmtEncoder, 
  runSmtEncoder,
  getNewSmtVariable,
  assertFormula,
  smtAnswerFromHandle,
  smtAnswerParser,
  smtRequestValues) where

import Control.Monad.Writer
import Control.Monad.State
import Data.Monoid
import Data.Bifunctor

import Data.Char(isAlphaNum)
import Text.ParserCombinators.Parsec
import System.IO(hGetContents, Handle)





-- | An SMT variable, identified by a number and rendered as @x<n>@.
newtype SmtVar = SmtVar Int
  deriving (Eq)

instance Show SmtVar where
  show (SmtVar n) = 'x' : show n

-- | Integer-arithmetic terms: either an SMT variable or an integer constant.
data IAExpr = IAVar SmtVar | IAConst Int
  deriving (Eq)

-- | Render a term in SMT-LIB 2 concrete syntax.
--
-- SMT-LIB numerals are non-negative, so a negative constant is emitted as
-- the unary-minus application @(- n)@.
instance Show IAExpr where
  show (IAVar x) = show x
  show (IAConst x)
    -- Negate as 'Integer': negating @minBound :: Int@ inside 'Int' wraps
    -- back to 'minBound', which would print a negative numeral inside
    -- @(- ...)@ and produce ill-formed SMT-LIB.
    | x < 0 = "(- " ++ show (negate (toInteger x)) ++ ")"
    | otherwise = show x

-- | Quantifier-free formulas over integer comparisons and Boolean variables.
-- The 'Conj' and 'Disj' constructors are not exported by this module;
-- outside code builds them through 'conj' and 'disj'.
data Formula
  = Gt IAExpr IAExpr
  | Le IAExpr IAExpr
  | BoolVar SmtVar
  | Conj [Formula]
  | Disj [Formula]
  | Equiv Formula Formula
  deriving (Eq)

-- | Render a formula as an SMT-LIB 2 term. An empty conjunction prints as
-- @true@ and an empty disjunction prints as @false@.
instance Show Formula where
  show phi =
    case phi of
      Gt a b -> app ">" [show a, show b]
      Le a b -> app "<=" [show a, show b]
      BoolVar v -> show v
      Equiv f g -> app "=" [show f, show g]
      Conj [] -> "true"
      Conj fs -> app "and" (map show fs)
      Disj [] -> "false"
      Disj fs -> app "or" (map show fs)
    where
      -- Parenthesised prefix application: "(op arg1 arg2 ...)".
      app op args = "(" ++ unwords (op : args) ++ ")"

-- | Whether a formula is a leaf. The leaves are the constants @true@/@false@
-- (empty conjunction/disjunction), Boolean variables, and the comparisons
-- 'Gt' and 'Le'. Non-empty conjunctions and disjunctions, and 'Equiv', are
-- not atomic.
atomic :: Formula -> Bool
atomic phi = case phi of
  Conj [] -> True
  Disj [] -> True
  BoolVar _ -> True
  Gt _ _ -> True
  Le _ _ -> True
  _ -> False

-- | The constant-false formula, represented as the empty disjunction.
false :: Formula
false = Disj []
-- | The constant-true formula, represented as the empty conjunction.
true :: Formula
true = Conj []

-- | Smart constructor for disjunction. It removes @false@ disjuncts, returns
-- @true@ if any disjunct is @true@, and unwraps a single remaining
-- disjunct. An empty (or all-@false@) input yields @false@.
-- Only top-level disjuncts are inspected; nested formulas are not flattened.
disj :: [Formula] -> Formula
disj phis
  | true `elem` rest = true
  | [single] <- rest = single
  | otherwise = Disj rest
  where
    rest = filter (/= false) phis

-- | Smart constructor for conjunction. It removes @true@ conjuncts, returns
-- @false@ if any conjunct is @false@, and unwraps a single remaining
-- conjunct. An empty (or all-@true@) input yields @true@.
-- Only top-level conjuncts are inspected; nested formulas are not flattened.
conj :: [Formula] -> Formula
conj phis
  | false `elem` rest = false
  | [single] <- rest = single
  | otherwise = Conj rest
  where
    rest = filter (/= true) phis


--- The encoder 

-- | A single SMT-LIB command emitted by the encoder.
data SmtStmt
  = SmtVarDecl SmtType SmtVar
  | SmtAssert Formula

-- | Variables are declared as zero-arity functions; assertions wrap the
-- rendered formula.
instance Show SmtStmt where
  show stmt = case stmt of
    SmtVarDecl ty x -> concat ["(declare-fun ", show x, " () ", show ty, ")"]
    SmtAssert phi -> concat ["(assert ", show phi, ")"]

-- | State threaded through the encoder.
data SmtState = SmtState
  { nextFreshVar :: Int -- ^ number given to the next variable allocated
  }

-- | Sorts available for declared SMT variables.
data SmtType
  = SmtBool
  | SmtInt

-- | Renders the SMT-LIB sort name.
instance Show SmtType where
  show ty = case ty of
    SmtBool -> "Bool"
    SmtInt -> "Int"

-- | The encoder monad. It layers a fresh-variable counter ('SmtState') over
-- a 'Writer' that collects the emitted statements as an 'Endo' difference
-- list. The transformer stack is hidden behind a newtype whose instances
-- come from GeneralizedNewtypeDeriving.
-- NOTE(review): no LANGUAGE pragma enabling GeneralizedNewtypeDeriving is
-- visible before the module header; presumably it is enabled in the build
-- configuration. Confirm.
newtype SmtEncoder a
  = SmtEncoder (StateT SmtState (Writer (Endo [SmtStmt])) a)
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadState SmtState
    , MonadWriter (Endo [SmtStmt])
    )


-- | Append one statement to the encoder's output, after all statements
-- emitted so far.
tellStmt :: MonadWriter (Endo [SmtStmt]) m => SmtStmt -> m ()
tellStmt = tell . Endo . (:)

-- | Emit an @(assert ...)@ command for the given formula.
assertFormula :: Formula -> SmtEncoder ()
assertFormula = tellStmt . SmtAssert

-- | Run an encoder with variable numbering starting at 1. Returns the
-- encoder's result paired with the rendered SMT-LIB script; the final
-- counter state is discarded.
runSmtEncoder :: SmtEncoder a -> (a, String)
runSmtEncoder (SmtEncoder encoder) = (result, showSmt2 (appEndo emitted []))
  where
    ((result, _), emitted) = runWriter (runStateT encoder (SmtState 1))

-- | Render a complete SMT-LIB 2 script: a preamble, the statements in
-- emission order, and a final @(check-sat)@.
--
-- The preamble enables @:produce-models@. This module also builds
-- @(get-value ...)@ queries ('smtRequestValues'), presumably sent after
-- @(check-sat)@. The SMT-LIB standard permits @get-value@ only when
-- @:produce-models@ is true, and that option defaults to false. Solvers
-- that enforce this reject the query otherwise. The option may only be set
-- before @set-logic@, so it comes first.
showSmt2 :: [SmtStmt] -> String
showSmt2 stmts = unlines $ header ++ map show stmts ++ footer
  where
    header =
      [ "(set-option :produce-models true)"
      , "(set-logic QF_LIA)"
      ]
    footer = ["(check-sat)"]

-- | Allocate a fresh variable of the given sort and emit its declaration.
-- Variables are numbered consecutively from the encoder's counter.
getNewSmtVariable :: SmtType -> SmtEncoder SmtVar
getNewSmtVariable ty = do
  fresh <- SmtVar <$> gets nextFreshVar
  modify (\st -> st {nextFreshVar = nextFreshVar st + 1})
  tellStmt (SmtVarDecl ty fresh)
  pure fresh
  
  
-- | Parse a solver's reply to a @(get-value ...)@ command, such as
-- @((x1 3) (x2 (- 4)))@, into (variable name, value) pairs.
--
-- A value may be a plain numeral, the SMT-LIB unary-minus application
-- @(- n)@ (the same form this module prints negative constants in), or a
-- bare @-n@. Boolean values (@true@/@false@) are not handled and make the
-- parse fail.
smtAnswerParser :: Parser [(String, Integer)]
smtAnswerParser = spaces *> string "(" *> spaces *> many pair <* string ")"
  where
    pair = (,) <$> (string "(" *> spaces *> identifier) <*> num <* string ")" <* spaces
    -- 'natural' fails without consuming input on '(' or '-', so plain
    -- alternation suffices and no 'try' is needed.
    num = (natural <|> negative) <* spaces
    -- 'read' cannot fail here: its input is a non-empty run of digits.
    natural = (read :: String -> Integer) <$> many1 digit
    negative = negate <$> (minus *> natural <|> parenthesised (minus *> natural))
    minus = char '-' <* spaces
    parenthesised p = char '(' *> spaces *> p <* spaces <* char ')'
    identChar = satisfy isAlphaNum
    identifier = many1 identChar <* spaces

-- | Read the rest of the handle lazily (via 'hGetContents', which
-- semi-closes the handle) and parse it with 'smtAnswerParser'. A parse
-- failure is raised with 'error', carrying the rendered 'ParseError'.
smtAnswerFromHandle :: Handle -> IO [(String, Integer)]
smtAnswerFromHandle h =
  hGetContents h >>= either (error . show) return . parse smtAnswerParser ""
  
-- | Build an SMT-LIB @(get-value (...))@ command for the given variables.
smtRequestValues :: [SmtVar] -> String
smtRequestValues vars = concat ["(get-value (", unwords (map show vars), "))"]

