{-# OPTIONS_GHC -Wno-incomplete-patterns #-}

{- |
Module      : Analysis.Termination.ConstrainedReductionOrder
Description : SMT-based search for a constrained reduction order
Copyright   : (c) Jonas Schöpf, 2024
License     : GPL-3
Maintainer  : jonas.schoepf@uibk.ac.at
Stability   : stable

This module provides the implementation of
the constrained reduction order.
-}
module Analysis.Termination.ConstrainedReductionOrder where

import Analysis.Termination.DependencyPairs (DPProblem (..))
import Analysis.Termination.Termination (
  Precedence (Prec),
  SNInfo (RpoPrecedence),
  SNResult (MaybeTerminating, Terminating),
  cutoffZeroSExpr,
  getCompAndZeroOfSort,
 )
import Control.Monad (zipWithM)
import Control.Monad.IO.Class (MonadIO (liftIO))
import Data.Containers.ListUtils (nubOrd)
import Data.LCTRS.FIdentifier (
  FId,
 )
import Data.LCTRS.Guard (collapseGuardToTerm)
import Data.LCTRS.Rule (
  Rule,
  guard,
  lhs,
  rhs,
 )
import qualified Data.LCTRS.Rule as R
import Data.LCTRS.Sort (Sorted (sort), sortAnnotation)
import Data.LCTRS.Term (
  Term (..),
  isLogicTerm,
  isVar,
  pattern TermFun,
  pattern TheoryFun,
 )
import qualified Data.LCTRS.Term as T
import Data.LCTRS.VIdentifier (
  VId,
  freshV,
 )
import qualified Data.Map.Strict as M
import Data.Monad (StateM, freshI, getSolver)
import Data.SExpr (ToSExpr (toSExpr))
import Data.SMT (boolSort, intSort)
import qualified Data.Set as S
import Prettyprinter (Pretty)
import Rewriting.SMT (satSExprsResults, validSExpr)
import qualified SimpleSMT as SMT

-- | Function symbols reported by 'T.termFuns' for both sides of a rule:
-- first those of the left-hand side, then those of the right-hand side.
-- The two lists are simply concatenated; no deduplication happens here.
termFunsOfRule :: Rule f v -> [f]
termFunsOfRule rule = T.termFuns (lhs rule) <> T.termFuns (rhs rule)

-- | Try to orient a whole DP problem at once.
--
-- Every function symbol occurring in the problem gets a fresh integer
-- variable (its precedence level, constrained to be non-negative).  All
-- weak rules must be oriented by 'rpoWeakTerm' and all strict rules by
-- 'rpoStrictTerm'.  The conjunction is handed to 'encodingToResult', which
-- yields 'Terminating' together with the extracted precedence when the
-- solver finds a model and 'MaybeTerminating' otherwise.
rpoDP
  :: (Ord v, Eq f, Pretty f, Ord f, ToSExpr v, Pretty v, ToSExpr f)
  => DPProblem (FId f) (VId v)
  -> StateM (SNResult, Maybe (Precedence (FId f)))
rpoDP DPProblem{..} = do
  solver <- getSolver
  -- Work inside a fresh solver context; it is popped again before returning.
  liftIO (SMT.push solver)
  -- NOTE: this assumes that integers are available in the SMT solver
  --       currently we do not restrict the theory/logic
  precM <-
    fmap M.fromList $
      traverse (\sym -> fmap (\i -> (sym, freshV intSort i)) freshI) symbols
  let nonNegative v = SMT.geq (toSExpr v) (SMT.const "0")
      naturalPrec = foldr (\v acc -> nonNegative v `SMT.and` acc) top precM
      -- Encode one rule with the given comparison (weak or strict).
      orientWith order rule = order precM (constraintOf rule) (lhs rule) (rhs rule)
  wencoding <- SMT.andMany <$> traverse (orientWith rpoWeakTerm) weakrules
  sencoding <- SMT.andMany <$> traverse (orientWith rpoStrictTerm) strictrules
  let precVars = S.fromList (M.elems precM)
      ruleVars = foldMap R.lvar (weakrules ++ strictrules)
      allVars = precVars <> ruleVars
  outcome <-
    encodingToResult precM allVars (SMT.andMany [naturalPrec, wencoding, sencoding])
  liftIO (SMT.pop solver)
  pure outcome
 where
  -- Distinct function symbols of all rules (strict rules are scanned first).
  symbols = nubOrd (foldMap termFunsOfRule (strictrules ++ weakrules))
  -- Constraint of a rule: its dummy constraints conjoined with its guard.
  constraintOf rule = T.conj (R.dummyConstraints rule) (collapseGuardToTerm (guard rule))

-- | Rule-removal step for a DP problem.
--
-- All rules (weak and strict) must be weakly oriented via 'rpoWeakTerm';
-- at least one strict rule must additionally be strictly oriented via
-- 'rpoStrictTerm'.  On success the strict rules that are strictly oriented
-- in the solver's model are dropped from the problem and the precedence
-- found is reported as 'RpoPrecedence'.  If the solver does not report
-- satisfiability, the problem is returned unchanged together with 'Nothing'.
--
-- Each strict rule @r@ owns a fresh Boolean flag @b_r@.  The encoding asserts
-- @b_r = strict_r@ for /every/ strict rule and the disjunction of all flags.
-- Tying every flag to its rule's strict encoding matters: a flag that is
-- only constrained inside a satisfied disjunct could otherwise be set to
-- @true@ by the solver for a rule that is merely weakly oriented, and that
-- rule would then be removed unsoundly.
rpoRP
  :: (Ord v, Eq f, Pretty f, Ord f, ToSExpr v, Pretty v, ToSExpr f)
  => DPProblem (FId f) (VId v)
  -> StateM (DPProblem (FId f) (VId v), Maybe (SNInfo (FId f) (VId v)))
rpoRP dpp@DPProblem{..} = do
  solver <- getSolver
  -- Fresh solver context; popped right after the satisfiability query.
  _ <- liftIO $ SMT.push solver
  -- NOTE: we assume a precedence in the natural to allow infinite signatures although
  --       they are currently not allowed in the ARI format
  precM <- M.fromList <$> mapM (\f -> (f,) . freshV intSort <$> freshI) fs
  let naturalPrec =
        foldr (SMT.and . (\v -> SMT.geq (toSExpr v) (SMT.const "0"))) top precM
  wencoding <-
    SMT.andMany
      <$> mapM
        (\rule -> rpoWeakTerm precM (phi' rule) (lhs rule) (rhs rule))
        (weakrules ++ strictrules)
  isStrict <- M.fromList <$> mapM (\f -> (f,) . freshV boolSort <$> freshI) strictrules
  let flagOf rule = toSExpr (isStrict M.! rule)
  -- One definition @b_r = strict_r@ per strict rule.
  flagDefs <-
    mapM
      ( \rule -> do
          sexpr <- rpoStrictTerm precM (phi' rule) (lhs rule) (rhs rule)
          return $ flagOf rule `SMT.eq` sexpr
      )
      strictrules
  -- All flags are defined, and at least one of them must hold.
  let sencoding =
        SMT.andMany flagDefs `SMT.and` SMT.orMany (map flagOf strictrules)
  let allVars =
        S.fromList (M.elems precM)
          <> S.fromList (M.elems isStrict)
          <> foldMap R.lvar (weakrules ++ strictrules)
  let sexpr = SMT.andMany [naturalPrec, wencoding, sencoding]
  (mbool, assigns) <- satSExprsResults allVars [sexpr]
  liftIO $ SMT.pop solver
  case mbool of
    Just True -> do
      let assignM = M.fromList assigns
          prec = extractPrecedence precM assigns
          -- A rule is removed only if the model explicitly sets its flag to
          -- true; a missing or non-Boolean value conservatively keeps it.
          isRemoved dp =
            M.lookup (isStrict M.! dp) assignM == Just (SMT.Bool True)
          newstrict = filter (not . isRemoved) strictrules
      return (DPProblem newstrict weakrules, Just $ RpoPrecedence prec)
    _ -> return (dpp, Nothing)
 where
  -- Distinct function symbols of all rules.
  fs = nubOrd $ foldMap termFunsOfRule $ strictrules ++ weakrules
  -- Constraint of a rule: its dummy constraints conjoined with its guard.
  phi' rule = T.conj (R.dummyConstraints rule) (collapseGuardToTerm $ guard rule)

-- extractPrec precM assignM =
--   RpoPrecedence $
--     concat
--       [ [ (f, g)
--         | g <- fs
--         , let j = assignM M.! (precM M.! g)
--         , g /= f
--         , getInt j < getInt i
--         ]
--       | f <- M.keys precM
--       , let i = assignM M.! (precM M.! f)
--       ]
-- getInt value =
--   case value of
--     SMT.Int i -> i
--     _ -> error "ConstrainedReductionOrder.hs: integer value assigned to wrong type."

-- | Check whether all given rules can be weakly oriented.
--
-- Every function symbol of the rules gets a fresh non-negative integer
-- precedence variable; each rule is encoded with 'rpoWeakTerm' and the
-- conjunction is passed to 'encodingToResult'.
rpoWeak
  :: (Ord v, Ord f, Pretty f, ToSExpr v, Pretty v, ToSExpr f)
  => [Rule (FId f) (VId v)]
  -> StateM (SNResult, Maybe (Precedence (FId f)))
rpoWeak rules = do
  solver <- getSolver
  -- Work inside a fresh solver context; it is popped again before returning.
  _ <- liftIO (SMT.push solver)
  precM <-
    fmap M.fromList $
      traverse (\sym -> fmap (\i -> (sym, freshV intSort i)) freshI) symbols
  let nonNegative v = SMT.geq (toSExpr v) (SMT.const "0")
      naturalPrec = foldr (\v acc -> nonNegative v `SMT.and` acc) top precM
      orient rule = rpoWeakTerm precM (constraintOf rule) (lhs rule) (rhs rule)
  encoding <- SMT.andMany <$> traverse orient rules
  let allVars = S.fromList (M.elems precM) <> foldMap R.lvar rules
  outcome <- encodingToResult precM allVars (naturalPrec `SMT.and` encoding)
  liftIO (SMT.pop solver)
  pure outcome
 where
  -- Distinct function symbols of the rules.
  symbols = nubOrd (foldMap termFunsOfRule rules)
  -- Constraint of a rule: its dummy constraints conjoined with its guard.
  constraintOf rule = T.conj (R.dummyConstraints rule) (collapseGuardToTerm (guard rule))

-- | Check whether all given rules can be strictly oriented.
--
-- Every function symbol of the rules gets a fresh non-negative integer
-- precedence variable; each rule is encoded with 'rpoStrictTerm' and the
-- conjunction is passed to 'encodingToResult'.
rpoStrict
  :: (Ord v, Pretty f, ToSExpr v, Ord f, Pretty v, ToSExpr f)
  => [Rule (FId f) (VId v)]
  -> StateM (SNResult, Maybe (Precedence (FId f)))
rpoStrict rules = do
  solver <- getSolver
  -- Work inside a fresh solver context; it is popped again before returning.
  _ <- liftIO (SMT.push solver)
  precM <-
    fmap M.fromList $
      traverse (\sym -> fmap (\i -> (sym, freshV intSort i)) freshI) symbols
  let nonNegative v = SMT.geq (toSExpr v) (SMT.const "0")
      naturalPrec = foldr (\v acc -> nonNegative v `SMT.and` acc) top precM
      orient rule = rpoStrictTerm precM (constraintOf rule) (lhs rule) (rhs rule)
  encoding <- SMT.andMany <$> traverse orient rules
  let allVars = S.fromList (M.elems precM) <> foldMap R.lvar rules
  outcome <- encodingToResult precM allVars (naturalPrec `SMT.and` encoding)
  liftIO (SMT.pop solver)
  pure outcome
 where
  -- Distinct function symbols of the rules.
  symbols = nubOrd (foldMap termFunsOfRule rules)
  -- Constraint of a rule: its dummy constraints conjoined with its guard.
  constraintOf rule = T.conj (R.dummyConstraints rule) (collapseGuardToTerm (guard rule))

-- | Turn a solver model into a 'Precedence'.
--
-- @precM@ maps each symbol to its precedence variable and @assigns@ is the
-- model returned by the solver.  Symbols are grouped by the integer assigned
-- to their variable; the groups are listed in ascending order of that
-- integer.
--
-- Partial: the lookups use 'M.!', so every precedence variable must occur in
-- @assigns@, and a non-integer value raises an 'error'.
extractPrecedence :: (Ord k, Ord f) => M.Map f k -> [(k, SMT.Value)] -> Precedence f
extractPrecedence precM assigns = Prec (M.elems levels)
 where
  -- Symbols bucketed by their assigned integer.
  levels = M.fromListWith (++) [(levelOf sym, [sym]) | sym <- M.keys precM]
  model = M.fromList assigns
  levelOf sym = asInt (model M.! (precM M.! sym))
  asInt (SMT.Int i) = i
  asInt _ = error "ConstrainedReductionOrder.hs: integer value assigned to wrong type."

-- | Ask the solver whether the encoding is satisfiable and interpret the
-- answer.
--
-- A definite @Just True@ yields 'Terminating' together with the precedence
-- extracted from the model; any other answer yields 'MaybeTerminating' and
-- no precedence.
encodingToResult
  :: (ToSExpr v, Ord v, Ord f, Pretty v)
  => M.Map (FId f) (VId v)
  -> S.Set (VId v)
  -> SMT.SExpr
  -> StateM (SNResult, Maybe (Precedence (FId f)))
encodingToResult precM allVars sexpr =
  satSExprsResults allVars [sexpr] >>= classify
 where
  classify (Just True, model) =
    return (Terminating, Just (extractPrecedence precM model))
  classify _ = return (MaybeTerminating, Nothing)

-- | SMT encoding of the weak comparison @s >= t@ under the constraint @phi@.
--
-- The result is a formula over the precedence variables in @precM@.
--
-- * If @s@ and @t@ are logic terms of the same sort whose variables all occur
--   in @phi@, the solver is asked whether @phi@ implies @s = t@ or @s > t@
--   (on Booleans, @s > t@ is read as @not (s => t)@).  If that is valid the
--   comparison holds outright and 'top' is returned without building the
--   strict encoding, which would only appear under @true \\/ ...@ anyway.
-- * If both terms are rooted by the same term function symbol, the
--   comparison holds if all arguments compare weakly pairwise, or if the
--   fallback below holds.
-- * Fallback: identical variables compare weakly; otherwise the strict
--   comparison 'rpoStrictTerm' is required.
rpoWeakTerm
  :: (Ord v, Eq f, Pretty f, Pretty v, ToSExpr v, Ord f, ToSExpr f)
  => M.Map (FId f) (VId v)
  -> Term (FId f) (VId v)
  -> Term (FId f) (VId v)
  -> Term (FId f) (VId v)
  -> StateM SMT.SExpr
rpoWeakTerm precM phi s t
  | comparableInTheory = do
      let sa = sortAnnotation [sort s, sort t] boolSort
          -- Theory-level "s > t"; Booleans are ordered by true > false.
          strictlyGreater
            | sort s == boolSort = T.neg (T.imp s t)
            | otherwise = T.grT sa s t
          -- phi => not (s /= t  /\  not (s > t)), i.e. phi => (s = t \/ s > t)
          implication =
            T.imp phi (T.neg (T.conj (T.neg (T.eq sa s t)) (T.neg strictlyGreater)))
      mbool <- validSExpr (S.fromList (T.vars implication)) (toSExpr implication)
      case mbool of
        -- Already valid in the theory: no need to compute the strict encoding.
        Just True -> return top
        _ -> sameVarOrStrict
  | otherwise =
      case (s, t) of
        (TermFun f ss, TermFun g ts)
          | f == g ->
              SMT.or
                <$> (SMT.andMany <$> zipWithM (rpoWeakTerm precM phi) ss ts)
                <*> sameVarOrStrict
        _ -> sameVarOrStrict
 where
  phiVars = S.fromList (T.vars phi)
  comparableInTheory =
    isLogicTerm s
      && isLogicTerm t
      && S.fromList (T.vars s) `S.isSubsetOf` phiVars
      && S.fromList (T.vars t) `S.isSubsetOf` phiVars
      && sort s == sort t
  -- Identical variables are weakly related; everything else must be strict.
  sameVarOrStrict
    | s == t && isVar s = return top
    | otherwise = rpoStrictTerm precM phi s t

-- | SMT encoding of the strict comparison @s > t@ under the constraint @phi@.
--
-- The result is a formula over the precedence variables in @precM@.
--
-- * Theory case: @s@ and @t@ are logic terms of integer sort whose variables
--   all occur in @phi@.  The answer is 'top' if the solver proves that @phi@
--   implies the comparison built from 'getCompAndZeroOfSort' and
--   'cutoffZeroSExpr', and 'bot' otherwise.
-- * @s@ is rooted by a term function symbol @f@: the comparison holds if some
--   argument of @s@ is weakly greater than @t@, or, depending on the shape of
--   @t@:
--
--     * @t@ is rooted by a different theory symbol: @s@ dominates all of its
--       arguments;
--     * @t@ is rooted by a different term symbol @g@: @f@ is above @g@ in
--       the precedence and @s@ dominates all arguments of @t@;
--     * @t@ is rooted by @f@ as well: all arguments compare weakly pairwise
--       and at least one pair compares strictly;
--     * @t@ is a variable occurring in @phi@: always.
-- * Anything else: 'bot'.
rpoStrictTerm
  :: (Pretty f, ToSExpr v, Ord v, Eq f, Ord f, Pretty v, ToSExpr f)
  => M.Map (FId f) (VId v)
  -> Term (FId f) (VId v)
  -> Term (FId f) (VId v)
  -> Term (FId f) (VId v)
  -> StateM SMT.SExpr
rpoStrictTerm precM phi s t = go s t
 where
  constrained = S.fromList (T.vars phi)
  coveredByPhi u = S.fromList (T.vars u) `S.isSubsetOf` constrained

  go l r
    | isLogicTerm l
        && isLogicTerm r
        && coveredByPhi l
        && coveredByPhi r
        && sort l == sort r
        && sort l == intSort -- FIXME restricted to int sort
      =
        theoryGreater l r
  go l@(TermFun f ls) r = do
    -- Some argument of the left term is already weakly greater than r.
    bySubterm <- SMT.orMany <$> mapM (\li -> rpoWeakTerm precM phi li r) ls
    -- Otherwise compare at the root, depending on the shape of r.
    byRoot <- case r of
      TheoryFun g rs
        | f /= g -> SMT.andMany <$> mapM (go l) rs
      TermFun g rs
        | f /= g -> do
            dominatesArgs <- SMT.andMany <$> mapM (go l) rs
            let rootAbove = toSExpr (precM M.! f) `SMT.gt` toSExpr (precM M.! g)
            return (rootAbove `SMT.and` dominatesArgs)
      Fun _ g rs
        | f == g -> do
            allWeak <- SMT.andMany <$> zipWithM (rpoWeakTerm precM phi) ls rs
            someStrict <- SMT.orMany <$> zipWithM go ls rs
            return (allWeak `SMT.and` someStrict)
      Var v
        | v `S.member` constrained -> return top
      _ -> return bot
    return (bySubterm `SMT.or` byRoot)
  go _ _ = return bot

  -- Theory-level comparison of two integer logic terms.
  -- NOTE(review): 'cutoffZeroSExpr' and 'getCompAndZeroOfSort' are defined
  -- elsewhere; presumably they bound the right-hand side from below to keep
  -- the order well-founded -- confirm against their definitions.
  theoryGreater l r = do
    let sa = sortAnnotation [sort l, sort r] boolSort
        implication = T.imp phi (T.grT sa l r)
        -- Only the variable set of this implication is used below.
        vars = S.fromList (T.vars implication)
    case (cutoffZeroSExpr r, getCompAndZeroOfSort (sort r)) of
      (Just nonZeroR, Just (gt, _, _)) -> do
        verdict <-
          validSExpr vars (SMT.implies (toSExpr phi) (gt (toSExpr l) nonZeroR))
        case verdict of
          Just True -> return top
          _ -> return bot
      _ -> return bot

-- | The SMT-LIB Boolean constant @true@.
top :: SMT.SExpr
top = SMT.bool True

-- | The SMT-LIB Boolean constant @false@.
bot :: SMT.SExpr
bot = SMT.bool False
