module SMT(
  SmtVar(..),
  IAExpr(..),
  Formula(Gt, BoolVar, Le, Equiv), true, false, disj, conj, atomic,
  SmtType(..),
  SmtStmt(..),
  showSmt2) where

-- SMT formulas for linear integer arithmetic

-- | An SMT variable, identified by an integer index.
newtype SmtVar = SmtVar Int
  deriving (Eq)

instance Show SmtVar where
  -- A variable is named by its index prefixed with 'x', e.g. x0, x17.
  show (SmtVar idx) = 'x' : show idx

-- | Linear integer-arithmetic terms: multiplication is only by a constant.
data IAExpr
  = IAVar SmtVar       -- ^ a variable used as an integer term
  | IAConst Int        -- ^ an integer literal
  | IASum [IAExpr]     -- ^ n-ary sum; the empty sum denotes 0
  | IAMult Int IAExpr  -- ^ a term scaled by a constant coefficient
  deriving (Eq)
  
-- Renders an integer-arithmetic term as an SMT-LIB 2 term.
instance Show IAExpr where
  show (IAVar x) = show x
  show (IAConst x)
    -- SMT-LIB numerals are non-negative, so a negative constant is written
    -- as unary negation. Negate in 'Integer' so that 'minBound' does not
    -- wrap around to itself and print as "(- -...)".
    | x < 0 = "(- " ++ show (negate (toInteger x)) ++ ")"
    | otherwise = show x
  show (IAMult x e) = "(* " ++ show (IAConst x) ++ " " ++ show e ++ ")"
  -- The empty sum is the additive identity.
  show (IASum []) = "0"
  -- SMT-LIB declares '+' as left-associative, which requires at least two
  -- arguments; a singleton sum is therefore printed as its only summand.
  show (IASum [e]) = show e
  show (IASum es) = "(+" ++ concatMap (\ e -> " " ++ show e) es ++ ")"

-- | Quantifier-free formulas over linear integer arithmetic and Boolean
-- variables. 'Conj' and 'Disj' are not exported; outside this module they
-- are built with the smart constructors 'conj', 'disj', 'true' and 'false'.
data Formula
  = Gt IAExpr IAExpr       -- ^ @a > b@
  | Le IAExpr IAExpr       -- ^ @a <= b@
  | BoolVar SmtVar         -- ^ a variable used directly as a Boolean atom
  | Conj [Formula]         -- ^ n-ary conjunction; the empty one is true
  | Disj [Formula]         -- ^ n-ary disjunction; the empty one is false
  | Equiv Formula Formula  -- ^ equivalence of two formulas
  deriving (Eq)
  
-- Renders a formula as an SMT-LIB 2 Boolean term. Empty conjunctions and
-- disjunctions become the constants true / false.
instance Show Formula where
  show phi = case phi of
      Gt lhs rhs -> "(> " ++ show lhs ++ " " ++ show rhs ++ ")"
      Le lhs rhs -> "(<= " ++ show lhs ++ " " ++ show rhs ++ ")"
      BoolVar v -> show v
      Equiv l r -> "(= " ++ show l ++ " " ++ show r ++ ")"
      Conj [] -> "true"
      Conj args -> nary "and" args
      Disj [] -> "false"
      Disj args -> nary "or" args
    where
      -- "(op a1 a2 ...)" with each argument preceded by a single space.
      nary op args = "(" ++ op ++ concatMap ((' ' :) . show) args ++ ")"

-- | Whether a formula is a leaf: a comparison, a Boolean variable, or one of
-- the constants true / false (an empty 'Conj' / 'Disj'). Non-empty
-- conjunctions and disjunctions, and every 'Equiv', are not atomic.
atomic :: Formula -> Bool
atomic phi = case phi of
  Gt _ _ -> True
  Le _ _ -> True
  BoolVar _ -> True
  Conj [] -> True
  Disj [] -> True
  _ -> False

-- smart constructors for the formula type

-- | The unsatisfiable formula, represented as the empty disjunction.
false :: Formula
false = Disj []

-- | The valid formula, represented as the empty conjunction.
true :: Formula
true = Conj []

-- | Smart constructor for disjunction. Drops 'false' disjuncts, collapses
-- to 'true' if any disjunct is 'true', and returns a lone remaining
-- disjunct unwrapped. If nothing remains, the result is 'false'.
disj :: [Formula] -> Formula
disj phis
  | true `elem` phis' = true
  | otherwise = case phis' of
      [phi] -> phi
      _ -> Disj phis'
  where
    phis' = filter (/= false) phis

-- | Smart constructor for conjunction. Drops 'true' conjuncts, collapses
-- to 'false' if any conjunct is 'false', and returns a lone remaining
-- conjunct unwrapped. If nothing remains, the result is 'true'.
conj :: [Formula] -> Formula
conj phis
  | false `elem` phis' = false
  | otherwise = case phis' of
      [phi] -> phi
      _ -> Conj phis'
  where
    phis' = filter (/= true) phis

-- Statements (top-level commands) of an SMT-LIB 2 script

-- | Sorts that variables can be declared with.
data SmtType
  = SmtBool
  | SmtInt

instance Show SmtType where
  -- The SMT-LIB 2 names of the corresponding built-in sorts.
  show ty = case ty of
    SmtBool -> "Bool"
    SmtInt -> "Int"

-- | The script statements this module can emit.
data SmtStmt
  = SmtVarDecl SmtType SmtVar  -- ^ declare a variable of the given sort
  | SmtAssert Formula          -- ^ assert that a formula holds

instance Show SmtStmt where
  -- A variable is declared as a nullary function (declare-fun with an empty
  -- argument list), i.e. an SMT-LIB 2 constant.
  show stmt = case stmt of
    SmtVarDecl ty v -> concat ["(declare-fun ", show v, " () ", show ty, ")"]
    SmtAssert phi -> concat ["(assert ", show phi, ")"]

-- | Renders a complete SMT-LIB 2 script. It sets the logic to QF_LIA, emits
-- the given statements one per line, and ends with @(check-sat)@. Every
-- line, including the last, is newline-terminated.
showSmt2 :: [SmtStmt] -> String
showSmt2 stmts = concatMap (++ "\n") script
  where
    script = "(set-logic QF_LIA)" : map show stmts ++ ["(check-sat)"]
