{-# LANGUAGE ScopedTypeVariables #-}

{- |
Module      : SMT
Description : SMT-solver interface for checking guards and S-expressions
Copyright   : (c) Jonas Schöpf, 2023
License     : GPL-3
Maintainer  : jonas.schoepf@uibk.ac.at
Stability   : stable

This module provides the connection to the SMT solver and
allows checking guards and S-expressions for satisfiability and validity.
-}
module Rewriting.SMT (
  satGuard,
  validGuard,
  smtResultCheck,
  satSExpr,
  satSExprsResults,
  satSExprsResultsTactic,
  validSExpr,
  satSExprSF,
  validSExprSF,
  declareVariables,
  -- , checkValidityGuard
) where

import qualified Control.Exception as X
import Control.Monad.IO.Class (MonadIO (liftIO))
import Data.Containers.ListUtils (nubOrd)
import Data.LCTRS.FIdentifier (FId)
import Data.LCTRS.Guard (createGuard)
import qualified Data.LCTRS.Guard as G
import Data.LCTRS.Sort (Sorted (sort))
import qualified Data.LCTRS.Term as T
import Data.Monad
import Data.SExpr
import qualified Data.Set as S
import SimpleSMT (
  MissingResponse (..),
  Result (..),
  SExpr (..),
  Value,
  assert,
  check,
  declare,
  getExpr,
  inNewScope,
  showsSExpr,
 )
import qualified SimpleSMT as SMT
import Prelude hiding (const)

----------------------------------------------------------------------------------------------------
-- checking satisfiability/validity of guards
----------------------------------------------------------------------------------------------------

{- | Decide satisfiability of a guard.

If 'G.tryEvalGuard' can evaluate the guard without a solver, that value is
returned directly. Otherwise the guard is collapsed into a single term, its
deduplicated variables are declared inside 'inNewScope', the term is asserted
and the solver is queried: @Just True@ for sat, @Just False@ for unsat and
'Nothing' for unknown. If the solver stops answering ('MissingResponse'),
the failure is reported through 'mError'.
-}
satGuard
  :: (Ord v, Ord f, ToSExpr f, ToSExpr v, Sorted v)
  => G.Guard (FId f) v
  -> StateM (Maybe Bool)
-- Disabled restriction (only allow logic terms):
--   satGuard g | all T.isLogicTerm g = do
satGuard g = do
  -- the solver is fetched before trying cheap evaluation, as before
  solver <- getSolver
  case G.tryEvalGuard g of
    Just known -> return (Just known)
    Nothing -> askSolver solver
 where
  guardVars = nubOrd (G.varsGuard g)
  formula = toSExpr (G.collapseGuardToTerm g)
  noResponseMsg =
    "No response from SMT-solver in simple-smt. (maybe solver was asynchronously killed outside)"

  -- run the query, turning a vanished solver into an 'mError'
  askSolver solver = do
    outcome <-
      liftIO $
        (Right <$> inNewScope solver (query solver))
          `X.catch` (\(_ :: MissingResponse) -> return (Left noResponseMsg))
    either mError return outcome

  query solver = do
    mapM_ (declareVariable solver) guardVars
    assert solver formula
    verdict <$> check solver

  verdict Sat = Just True
  verdict Unsat = Just False
  verdict Unknown = Nothing

{- | Decide validity of a guard.

A guard that 'G.tryEvalGuard' can evaluate directly yields that value.
Otherwise the guard is valid iff its negation is unsatisfiable, so the
negated, collapsed guard is handed to 'satGuard' and the answer is flipped;
an unknown answer ('Nothing') is passed through unchanged.
-}
validGuard
  :: (Ord v, Ord f, ToSExpr f, ToSExpr v, Sorted v)
  => G.Guard (FId f) v
  -> StateM (Maybe Bool)
validGuard g = maybe viaNegation (return . Just) (G.tryEvalGuard g)
 where
  negatedGuard = createGuard [T.neg (G.collapseGuardToTerm g)]
  viaNegation = fmap not <$> satGuard negatedGuard

{- | Turn a solver answer into a definite 'Bool'.

A known answer is returned unchanged; an unknown answer ('Nothing') is
reported through 'mError'.
-}
smtResultCheck :: Maybe Bool -> StateM Bool
smtResultCheck =
  maybe (mError "Could not deduce satisfiability/validity of guard.") return

----------------------------------------------------------------------------------------------------
-- checking satisfiability/validity of S-expressions
----------------------------------------------------------------------------------------------------

{- | Check an S-expression for satisfiability.

The given variables are declared inside 'inNewScope', the S-expression is
asserted and the solver is queried: @Just True@ for sat, @Just False@ for
unsat and 'Nothing' for unknown. As in 'satGuard', a 'MissingResponse' from
the solver is reported through 'mError' instead of escaping as a raw
exception.
-}
satSExpr
  :: (Ord v, ToSExpr v, Sorted v) => S.Set v -> SExpr -> StateM (Maybe Bool)
satSExpr vars sexpr = do
  solver <- getSolver
  ret <-
    liftIO $
      ( inNewScope solver $ do
          declareVariables solver vars
          assert solver sexpr
          res <- check solver
          return $ Right $ case res of
            Sat -> Just True
            Unsat -> Just False
            Unknown -> Nothing
      )
        `X.catch` ( \(_ :: MissingResponse) ->
                      return $
                        Left "No response from SMT-solver in simple-smt. (maybe solver was asynchronously killed outside)"
                  )
  either mError return ret

{- | Check a list of S-expressions (asserted together) for satisfiability and
return a model on success. This is 'satSExprsResultsTactic' using the
solver's plain 'check' as the check action.
-}
satSExprsResults
  :: (Ord v, ToSExpr v, Sorted v) => S.Set v -> [SExpr] -> StateM (Maybe Bool, [(v, Value)])
satSExprsResults vars = satSExprsResultsTactic vars check

{- | Check a list of S-expressions with a caller-supplied check action and,
on success, read back a model.

All variables in @vars@ are declared inside 'inNewScope' and every
S-expression is asserted. @tactic@ is then run in place of the plain 'check'
(NOTE(review): presumably a check with a specific solver tactic — confirm
with callers). On @Sat@ each variable is evaluated with 'getExpr' and the
assignments are returned in ascending set order. On @Unsat@ or @Unknown@ the
assignment list is empty. A 'MissingResponse' from the solver is reported
through 'mError', as in 'satGuard'.
-}
satSExprsResultsTactic
  :: (Ord v, ToSExpr v, Sorted v)
  => S.Set v
  -> (SMT.Solver -> IO SMT.Result)
  -> [SExpr]
  -> StateM (Maybe Bool, [(v, Value)])
-- The check action is named 'tactic' so it no longer shadows 'SimpleSMT.check'.
satSExprsResultsTactic vars tactic sexprs = do
  solver <- getSolver
  ret <-
    liftIO $
      ( inNewScope solver $ do
          declareVariables solver vars
          mapM_ (assert solver) sexprs
          res <- tactic solver
          Right <$> case res of
            Sat -> do
              assignments <-
                mapM (\v -> (,) v <$> getExpr solver (toSExpr v)) (S.toList vars)
              return (Just True, assignments)
            Unsat -> return (Just False, [])
            Unknown -> return (Nothing, [])
      )
        `X.catch` ( \(_ :: MissingResponse) ->
                      return $
                        Left "No response from SMT-solver in simple-smt. (maybe solver was asynchronously killed outside)"
                  )
  either mError return ret

{- | Check an S-expression for validity.

It is valid iff its negation ('SMT.not') is unsatisfiable, so 'satSExpr' is
asked about the negation and the answer is flipped. An unknown answer stays
'Nothing'.
-}
validSExpr
  :: (ToSExpr v, Ord v, Sorted v)
  => S.Set v
  -> SExpr
  -> StateM (Maybe Bool)
validSExpr vars = fmap (fmap not) . satSExpr vars . SMT.not

{- | Check an S-expression for satisfiability.

This entry point has exactly the same behaviour as 'satSExpr': declare
@vars@, assert @sexpr@ and return @Just True@, @Just False@ or 'Nothing' for
sat, unsat or unknown. It delegates to 'satSExpr' so that the two cannot
silently drift apart. NOTE(review): the meaning of the @SF@ suffix is not
visible here; it is kept as a separate name for existing callers.
-}
satSExprSF
  :: (ToSExpr v, Ord v, Sorted v) => S.Set v -> SExpr -> StateM (Maybe Bool)
satSExprSF = satSExpr

{- | Validity counterpart of 'satSExprSF': the S-expression is valid iff its
negation ('SMT.not') is unsatisfiable. An unknown answer stays 'Nothing'.
-}
validSExprSF
  :: (ToSExpr v, Ord v, Sorted v)
  => S.Set v
  -> SExpr
  -> StateM (Maybe Bool)
validSExprSF vars sexpr = fmap not <$> satSExprSF vars (SMT.not sexpr)

----------------------------------------------------------------------------------------------------
-- auxiliary functions
----------------------------------------------------------------------------------------------------

{- | Declare a variable in the given SMT solver.

The variable's S-expression, rendered with 'showsSExpr', is used as the
constant's name, and the S-expression of its sort is used as the declared
type. The result is whatever 'declare' returns.
-}
declareVariable :: (ToSExpr a, Sorted a) => SMT.Solver -> a -> IO SExpr
declareVariable solver var = declare solver varName varSort
 where
  varName = showsSExpr (toSExpr var) ""
  varSort = toSExpr (sort var)

{- | Declare every variable of the set in the given SMT solver, one after
another in ascending order, discarding the individual declaration results.
-}
declareVariables :: (ToSExpr a, Sorted a) => SMT.Solver -> S.Set a -> IO ()
declareVariables solver vars = mapM_ (declareVariable solver) (S.toAscList vars)
