{- |
Module      : Data.Monad
Description : Fresh-integer monad and the SMT-backed analysis monad stack
Copyright   : (c) Jonas Schöpf, 2023
License     : GPL-3
Maintainer  : jonas.schoepf@uibk.ac.at
Stability   : stable


This module provides the implementation and auxiliary functions for the monads
FreshM and StateM. FreshM is a state monad over IO that supplies fresh integers.
StateM builds on top of FreshM and adds error handling (ExceptT) and a read-only
environment (ReaderT) holding the debug flag and the SMT solver handle.
-}
module Data.Monad where

import Control.Exception (
  SomeAsyncException,
  SomeException (..),
  catch,
  fromException,
  throwIO,
 )
import Control.Monad.Error.Class (MonadError)
import Control.Monad.Except (
  ExceptT,
  runExceptT,
 )
import qualified Control.Monad.Except as CME
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Reader.Class (MonadReader, asks)
import Control.Monad.State.Class (MonadState)
import Control.Monad.State.Strict (
  StateT,
  evalStateT,
  get,
  modify,
  modify',
 )
import Control.Monad.Trans (MonadTrans (lift))
import Control.Monad.Trans.Reader (
  ReaderT,
  runReaderT,
 )
import Control.Monad.Writer.Strict (WriterT)
import SimpleSMT (
  SExpr,
  Solver,
  defineFun,
  newLogger,
  newSolver,
  setLogic,
  setOption,
 )

-- | Read-only environment of the 'StateM' analysis monad.
data Reference = Reference
  { debug :: Bool
  -- ^ Debug flag supplied by the caller of 'evalStateM' / 'evalComp'.
  , solver :: Solver
  -- ^ Handle to the external SMT solver (e.g. created by 'chooseSolver');
  -- read it with 'getSolver'.
  }

-- | Errors raised inside 'StateM'.
--
-- Distinguishes non-severe errors ('MayError', raised via 'mError') from
-- severe ones ('SevError', raised via 'sError'). Each carries a message.
-- NOTE(review): how callers treat the two kinds differently is not visible
-- in this module.
data Error = MayError String | SevError String

-- | External SMT solvers that 'chooseSolver' knows how to start.
data ExtSolver = Z3 | CVC5 | Yices

-- | SMT-LIB logic names; the derived 'Show' yields exactly those names.
data Logic = QF_UF | NIA | NRA | AUFNIRA | QF_BV | ALL
  deriving (Eq, Show)

-- | Combining two logics is meant to give a logic able to express both:
--
-- * a logic combined with itself or with 'QF_UF' is unchanged;
-- * any mix of distinct arithmetic logics ('NIA', 'NRA', 'AUFNIRA') gives
--   'AUFNIRA';
-- * every other combination falls back to 'ALL'.
--
-- This is the join of the order QF_UF < NIA, NRA, QF_BV; NIA, NRA < AUFNIRA;
-- AUFNIRA, QF_BV < ALL, so it is associative and idempotent.
instance Semigroup Logic where
  a <> b
    | a == b = a
    | a == QF_UF = b
    | b == QF_UF = a
    | isArith a && isArith b = AUFNIRA
    | otherwise = ALL
    where
      -- The arithmetic logics subsumed by 'AUFNIRA' (including itself).
      isArith l = l `elem` [NIA, NRA, AUFNIRA]

-- | 'QF_UF' is the identity element: it adds nothing when combined.
instance Monoid Logic where
  mempty = QF_UF

----------------------------------------------------------------------------------------------------
-- Monad for Freshness
----------------------------------------------------------------------------------------------------

-- type FreshM = StateT Int IO

-- execFreshM :: Int -> FreshM a -> IO a
-- execFreshM i c = evalStateT c i

-- | Monad supplying fresh integers: a 'StateT' over 'IO' whose 'Int' state is
-- the next value to hand out (see 'freshI' / 'freshInt'). Run it with
-- 'execFreshM'.
--
-- All instances are taken from the underlying 'StateT Int IO'. NOTE(review):
-- this relies on GeneralizedNewtypeDeriving, presumably enabled in the cabal
-- file since no LANGUAGE pragma is visible here.
newtype FreshM a = FreshM
  {runFresh :: StateT Int IO a}
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadIO
    , MonadState Int
    )

-- | Return the current state and advance it to its successor.
--
-- The new state is forced with 'modify'': the strict 'StateT' is strict only
-- in the (value, state) pair, not in the state itself. Without forcing,
-- repeated calls whose results are never inspected would build up a chain
-- of unevaluated 'succ' thunks.
freshI :: (MonadState s m, Enum s) => m s
freshI = do
  i <- get
  modify' succ
  return i

-- | Monads that can produce fresh integers.
class Monad m => MonadFresh m where
  -- | Return a fresh integer. In 'FreshM' this is the current counter value,
  -- after which the counter is advanced (see 'freshI').
  freshInt :: m Int

-- | Base instance: take the next value from FreshM's 'Int' counter.
instance MonadFresh FreshM where
  freshInt = freshI

-- The remaining instances lift 'freshInt' through a single transformer layer,
-- so it is available anywhere in stacks built over 'FreshM'. 'StateT' and
-- 'WriterT' are the strict variants (Control.Monad.State.Strict and
-- Control.Monad.Writer.Strict).
instance MonadFresh m => MonadFresh (ReaderT r m) where
  freshInt = lift freshInt

instance MonadFresh m => MonadFresh (StateT s m) where
  freshInt = lift freshInt

instance MonadFresh m => MonadFresh (ExceptT e m) where
  freshInt = lift freshInt

instance (Monoid w, MonadFresh m) => MonadFresh (WriterT w m) where
  freshInt = lift freshInt

-- | Run a 'FreshM' computation in 'IO'. The counter starts at the given value,
-- and its final value is discarded.
execFreshM :: Int -> FreshM a -> IO a
execFreshM start = (`evalStateT` start) . runFresh

----------------------------------------------------------------------------------------------------
-- Monad Stack for analysis
----------------------------------------------------------------------------------------------------

-- type StateM = ExceptT Error (ReaderT Reference FreshM)

-- | The analysis monad. It provides:
--
-- * short-circuiting errors of type 'Error' (outer 'ExceptT');
-- * read-only access to a 'Reference' (debug flag and solver) via 'ReaderT';
-- * fresh integers and 'IO' from the underlying 'FreshM'.
--
-- Note that the only 'MonadState' state is FreshM's 'Int' counter.
-- Run it with 'evalStateM' (or with 'evalComp' inside 'FreshM').
newtype StateM a = StateM
  { runStateM
      :: ExceptT Error (ReaderT Reference FreshM) a
  }
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadIO
    , MonadFresh
    , MonadState Int
    , MonadReader Reference
    , MonadError Error
    )

-- | Prepare the solver, then run a 'StateM' computation inside 'FreshM'.
--
-- Every entry of @defines@ is first sent to the solver, in order, via
-- 'evaluateDefine'. The computation then runs against a 'Reference' built
-- from the debug flag @b@ and the solver. An 'Error' it throws is returned
-- as 'Left'.
evalComp
  :: Solver
  -> [(String, [(String, SExpr)], SExpr, SExpr)]
  -> Bool
  -> StateM a
  -> FreshM (Either Error a)
evalComp solver defines b s = do
  liftIO $ mapM_ (evaluateDefine solver) defines
  -- NOTE: the logic was removed from the ARI input format, so no logic is set
  -- here anymore (previously: liftIO $ setLogic solver $ maybe "ALL" show l).
  let env = monadReference b solver
  runReaderT (runExceptT (runStateM s)) env

-- | Send one @define-fun@ to the solver.
--
-- The tuple holds the arguments of SimpleSMT's 'defineFun':
-- (name, parameters, result sort, body).
--
-- If the solver rejects the definition, the synchronous exception becomes an
-- 'error' with an explanatory message. Asynchronous exceptions (e.g. the one
-- thrown by 'System.Timeout.timeout', or a user interrupt) are rethrown
-- unchanged, so that timeouts and cancellation keep working.
evaluateDefine :: Solver -> (String, [(String, SExpr)], SExpr, SExpr) -> IO SExpr
evaluateDefine solver (n, args, r, t) =
  defineFun solver n args r t `catch` handler
  where
    handler :: SomeException -> IO SExpr
    handler e@(SomeException s) =
      case fromException e :: Maybe SomeAsyncException of
        -- Never turn an async exception into a synchronous 'error': that
        -- would hide it from 'timeout' and similar handlers.
        Just _ -> throwIO e
        Nothing ->
          error $
            "The SMT solver does not accept the define-fun for "
              <> n
              <> ".\n\
                 \Various issues can trigger that:\n\
                 \* not well-formed\n\
                 \* uses undefined functions (possibly from other theories)\n\
                 \* ...\n"
              <> "The SMT solver exception is the following:\n"
              <> show s

-- | Run a 'StateM' computation all the way down to 'IO'.
--
-- The solver is given the @defines@ and the debug flag @b@ (see 'evalComp').
-- The fresh-integer counter starts at @i@. An 'Error' raised by the
-- computation is returned as 'Left'.
evalStateM
  :: Solver
  -> [(String, [(String, SExpr)], SExpr, SExpr)]
  -> Bool
  -> Int
  -> StateM a
  -> IO (Either Error a)
evalStateM solver defines b i comp =
  execFreshM i (evalComp solver defines b comp)

-- liftIO $ mapM_ (\(n, args, r, t) -> defineFun solver n args r t) defines
-- liftIO $ setLogic solver $ maybe "ALL" show l
-- flip evalStateT i $ flip runReaderT (monadReference l b solver) $ runExceptT s

-- | Start an external SMT solver of the requested kind.
--
-- * The 'Int' is used only for Yices, which receives it (plus one) as its
--   own @--timeout@ option.
-- * When the 'Bool' is 'True', a SimpleSMT logger is attached to the solver.
-- * Z3 is put into SMT-LIB 2 compliant mode; Yices is given logic @ALL@.
--   No logic is set for Z3 or CVC5 (earlier versions set one explicitly).
chooseSolver :: Int -> Bool -> ExtSolver -> IO Solver
chooseSolver timeout useLogger ext =
  case ext of
    Z3 -> do
      logger <- mkLogger
      z3 <- newSolver "z3" ["-smt2", "-in"] logger
      _ <- setOption z3 ":smtlib2_compliant" "true"
      pure z3
    CVC5 -> do
      logger <- mkLogger
      newSolver "cvc5" ["--incremental"] logger
    Yices -> do
      logger <- mkLogger
      -- Yices gets its own timeout in addition to the Haskell-side one,
      -- because it sometimes ignores the async exception thrown by timeout.
      yices <-
        newSolver
          "yices-smt2"
          ["--incremental", "--timeout=" ++ show (timeout + 1)]
          logger
      _ <- setLogic yices "ALL"
      pure yices
  where
    -- Build a logger only when logging was requested.
    mkLogger = if useLogger then Just <$> newLogger 0 else return Nothing

-- | Build the read-only 'StateM' environment from the debug flag and the
-- solver handle.
monadReference :: Bool -> Solver -> Reference
monadReference dbg slv = Reference {debug = dbg, solver = slv}

-- | Fetch the SMT solver handle from the 'Reference' environment.
getSolver :: (MonadReader Reference m) => m Solver
getSolver = asks (\Reference {solver = s} -> s)

-- | Raise an error in any 'MonadError' monad. This is a thin alias for
-- 'CME.throwError' (the class is the same one re-exported by
-- Control.Monad.Except).
throwError :: (MonadError e m) => e -> m a
throwError = CME.throwError

-- | Abort the current 'StateM' computation with a non-severe error
-- ('MayError') carrying the given message.
mError :: String -> StateM a
mError msg = throwError (MayError msg)

-- | Abort the current 'StateM' computation with a severe error ('SevError')
-- carrying the given message.
sError :: String -> StateM a
sError msg = throwError (SevError msg)
