{- |
Module      : ValueCriterion
Description : Value criterion processor for dependency pair problems
Copyright   : (c) Jonas Schöpf, 2024
License     : GPL-3
Maintainer  : jonas.schoepf@uibk.ac.at
Stability   : stable

This module implements the basic value criterion for
dependency pair problems, which projects every dependency
pair symbol to one of its arguments.
-}
module Analysis.Termination.ValueCriterion where

import Analysis.Termination.DependencyPairs (DPProblem (..), dpSymbols)
import Analysis.Termination.Termination (
  SNInfo (VcProjection),
  cutoffZeroSExpr,
  getCompAndZeroOfSort,
 )
import Control.Monad.IO.Class (liftIO)
import Data.Containers.ListUtils (nubOrd)
import Data.LCTRS.FIdentifier (
  FId,
  getFIdInSort,
 )
import Data.LCTRS.Guard (collapseGuardToTerm)
import Data.LCTRS.Rule (
  Rule,
  guard,
  lhs,
  lvar,
  rhs,
 )
import Data.LCTRS.Sort (Sorted (sort), sortAnnotation)
import Data.LCTRS.Term (
  Term (..),
  isLogicTerm,
 )
import qualified Data.LCTRS.Term as T
import Data.LCTRS.VIdentifier (
  VId,
  freshV,
 )
import Data.List (sortBy, subsequences, (\\))
import qualified Data.Map.Strict as M
import Data.Monad (MonadFresh, StateM, freshInt, getSolver)
import Data.SExpr (ToSExpr (..))
import Data.SMT (boolSort, intSort)
import qualified Data.Set as S
import qualified Data.Set as Set
import Prettyprinter (Pretty)
import Rewriting.SMT (satSExprsResults, validSExpr)
import SimpleSMT (SExpr)
import qualified SimpleSMT as S
import qualified SimpleSMT as SMT

-- | Run the value criterion as a DP processor.
--
-- The resulting DP problem comes first in the returned pair. The second
-- component holds the proof information. When 'valueCriterion' succeeds,
-- it is its result (the removed DPs and the chosen projection) wrapped
-- into a singleton 'VcProjection'. Otherwise it is 'Nothing'.
valueCriterionRP
  :: (Ord v, ToSExpr f, Pretty f, Ord f, ToSExpr v, Pretty v)
  => DPProblem (FId f) (VId v)
  -> StateM (DPProblem (FId f) (VId v), Maybe (SNInfo (FId f) (VId v)))
valueCriterionRP dpp = toProcessorResult <$> valueCriterion dpp
 where
  toProcessorResult (result, problem) =
    (problem, fmap (\found -> VcProjection [found]) result)

-- | Try to apply the value criterion to the strict rules of a DP problem.
--
-- One boolean "set" variable is created per strict DP, and one integer
-- projection variable per DP symbol. The SMT problem requires that:
--
--   * every projection variable lies within the bounds built by
--     'constrainedPrjVars';
--   * every DP satisfies the constraint built by 'encodePossProj';
--   * at least one set variable is true, i.e. at least one DP is removed.
--
-- If the solver reports satisfiable, the result is
-- @Just (removed, projection)@ together with the reduced problem. The
-- reduced problem keeps the strict DPs whose set variable is false, plus
-- all weak rules. In every other case (unsat or no answer), the result is
-- 'Nothing' together with the unchanged problem.
valueCriterion
  :: (Ord f, Ord v, ToSExpr f, ToSExpr v, Pretty f, Pretty v)
  => DPProblem (FId f) (VId v)
  -> StateM (Maybe ([Rule (FId f) (VId v)], [(FId f, Int)]), DPProblem (FId f) (VId v))
valueCriterion dpp@DPProblem{..} = do
  solver <- getSolver
  -- Open a fresh solver scope; it is closed again by the 'SMT.pop' below.
  -- NOTE(review): presumably this discards declarations made by
  --       'satSExprsResults'. The scope is not closed if an exception
  --       escapes in between.
  liftIO $ SMT.push solver
  pSetVars <- constrainedSetVars dps
  -- NOTE: we demand that integers are available in the SMT solver
  (prjVars, prjConstraints) <- constrainedPrjVars dps
  encoding <- S.andMany <$> mapM (encodePossProj pSetVars prjVars) projections
  let allVars = Set.fromList (M.elems pSetVars ++ M.elems prjVars) <> foldMap lvar dps
      -- at least one DP has to be oriented strictly (and is thus removed)
      oneTrue = foldr (\v sexpr -> sexpr `S.or` toSExpr v) (S.const "false") $ M.elems pSetVars
  (mbool, assignments) <- satSExprsResults allVars [prjConstraints, encoding, oneTrue]
  -- NOTE: 'iterPossSols' tries to orient as many DPs as possible
  --       but is currently too inefficient:
  -- (mbool, assignments) <- iterPossSols pSetVars allVars [prjConstraints, encoding]
  liftIO $ SMT.pop solver
  case mbool of
    Just True ->
      let (p2, prjs) = divideSet dpp pSetVars prjVars assignments
      in  return ((strictrules \\ p2,) <$> prjs, DPProblem p2 weakrules)
    -- unsatisfiable or no answer from the solver: the criterion does not apply
    _ -> return (Nothing, dpp)
 where
  dps = strictrules
  projections = map (\dp -> (dp,) $ possibleProjections dp) dps

-- | Alternative search that tries to orient as many DPs as possible.
--
-- Each non-empty subset of the given set variables is tried, largest
-- subsets first. For each subset, the variables in it are asserted true
-- in addition to @sexprs@. The first satisfiable query is returned. If no
-- subset is satisfiable, the result is @(Nothing, [])@.
iterPossSols
  :: (Ord v, ToSExpr v, Pretty v)
  => M.Map k (VId v)
  -> Set.Set (VId v)
  -> [SExpr]
  -> StateM (Maybe Bool, [(VId v, SMT.Value)])
iterPossSols setVarM allVars sexprs = tryEach candidates
 where
  -- all subsets by decreasing size; the empty subset (last) is dropped
  candidates =
    init
      . sortBy (\xs ys -> compare (length ys) (length xs))
      . subsequences
      $ M.elems setVarM
  tryEach [] = return (Nothing, [])
  tryEach (combo : rest) = do
    result@(mbool, _) <-
      satSExprsResults allVars (S.andMany (map toSExpr combo) : sexprs)
    if mbool == Just True then return result else tryEach rest

divideSet
  :: (Ord f, Ord v)
  => DPProblem (FId f) (VId v)
  -> M.Map (Rule (FId f) (VId v)) (VId v)
  -> M.Map (FId f) (VId v)
  -> [(VId v, SMT.Value)]
  -> ([Rule (FId f) (VId v)], Maybe [(FId f, Int)])
divideSet DPProblem{..} setMap prjMap assignments =
  (p2, Just prjs)
 where
  prjs = map (\dpS -> (dpS, assignedInt $ prjMap M.! dpS)) dpSyms
  p2 = filter (\dp -> not $ assignedBool (setMap M.! dp)) strictrules

  dpSyms = nubOrd $ foldMap dpSymbols strictrules

  assM = M.fromList assignments
  assignedInt v =
    case assM M.! v of
      SMT.Int i -> fromInteger i :: Int
      _ -> error "ValueCriterion.hs: non-integer assigned in projection."
  assignedBool v =
    case assM M.! v of
      SMT.Bool b -> b
      _ -> error "ValueCriterion.hs: no bool assigned to set splitting."

-- | Build the SMT constraint for a single dependency pair.
--
-- The argument pairs a DP with its candidate projection pairs
-- @((f, pf), (g, pg))@. Such a pair means that the root @f@ of the lhs is
-- projected to its @pf@-th argument and the root @g@ of the rhs to its
-- @pg@-th argument. For every candidate, the projected arguments @pL@ and
-- @pR@ are checked with the SMT solver. The outcome is translated into a
-- constraint over the projection variables (@prjMap@) and the DP's set
-- variable (@setMap@). The constraints of all candidates are combined
-- with a disjunction ('S.orMany').
encodePossProj
  :: (Ord v, Ord f, ToSExpr v, Pretty f, Pretty v, ToSExpr f)
  => M.Map (Rule (FId f) (VId v)) (VId v)
  -- ^ boolean set variable of each DP (only allowed to be true together
  --   with a strictly oriented candidate)
  -> M.Map (FId f) (VId v)
  -- ^ integer projection variable of each DP symbol
  -> (Rule (FId f) (VId v), [((FId f, Int), (FId f, Int))])
  -- ^ the DP together with its candidate projection pairs
  -> StateM SMT.SExpr
encodePossProj setMap prjMap (dp, prjs) =
  S.orMany <$> mapM project prjs
 where
  l = lhs dp
  r = rhs dp
  -- the DP's guard collapsed into a single term
  phi = collapseGuardToTerm $ guard dp
  -- the DP's variables as returned by 'lvar'
  lvars = lvar dp

  setVar = toSExpr . (setMap M.!)
  prjVar = toSExpr . (prjMap M.!)

  -- Checks whether @phi => gt pL pR'@ is valid. This is only attempted when
  -- both sides are logic terms of the same sort and all variables of @pL@
  -- are in @lvars@; otherwise the result is 'Nothing'. The comparison @gt@
  -- comes from 'getCompAndZeroOfSort', and @pR'@ from 'cutoffZeroSExpr'.
  -- NOTE(review): presumably @pR'@ is @pR@ cut off at zero; confirm.
  -- If either helper returns 'Nothing', the result is 'Nothing'. The terms
  -- @sa@/@one@ are only used to collect the variables passed to 'validSExpr'.
  validGt phi pL pR
    | isLogicTerm pL
        && isLogicTerm pR
        && S.fromList (T.vars pL) `S.isSubsetOf` lvars
        && sort pL == sort pR =
        do
          let
            sa = sortAnnotation [sort pL, sort pR] boolSort
            one = T.imp phi (T.grT sa pL pR)
          -- validSExpr (nubOrd $ T.vars one) $ SMT.and (toSExpr one) (SMT.geq (toSExpr pL) (SMT.const "0"))
          -- validSExpr (nubOrd $ T.vars one) (toSExpr one)
          case (cutoffZeroSExpr pR, getCompAndZeroOfSort (sort pR)) of
            (Nothing, _) -> return Nothing
            (_, Nothing) -> return Nothing
            (Just positivePR, Just (gt, _, _)) ->
              validSExpr (S.fromList $ T.vars one) $
                SMT.implies (toSExpr phi) (gt (toSExpr pL) positivePR)
    | otherwise = return Nothing

  -- Checks whether @phi => pL = pR@ is valid, under the same preconditions
  -- as @validGt@. This is an equality check, not @>=@ (see the FIXME). If
  -- the preconditions fail, two identical variables still count as weakly
  -- oriented; anything else yields 'Nothing'.
  validGeq phi pL pR
    | isLogicTerm pL
        && isLogicTerm pR
        && S.fromList (T.vars pL) `S.isSubsetOf` lvars
        && sort pL == sort pR = do
        let
          sa = sortAnnotation [sort pL, sort pR] boolSort
          two = T.imp phi (T.eq sa pL pR) -- FIXME: what about geq here?
        validSExpr (S.fromList $ T.vars two) $ toSExpr two
  validGeq _ (Var v1) (Var v2)
    | v1 == v2 = return $ Just True
  validGeq _ _ _ = return Nothing

  -- Constraint for one candidate. It selects this candidate's projections
  -- (@isProj@) and fixes the DP's set variable according to how the
  -- projected DP is oriented.
  project ((f, pf), (g, pg)) = do
    b1 <- validGt phi pL pR
    b2 <- validGeq phi pL pR
    case (b1, b2) of
      -- Strict and weak: the set variable is left free
      -- (@isOne `S.xor` isNotOne@ is always true).
      (Just True, Just True)
        | S.fromList (T.vars pL) `S.isSubsetOf` lvars ->
            return $ isProj `S.and` (isOne `S.xor` isNotOne)
        -- NOTE: unreachable; @validGt@ only yields a 'Just' when this
        --       subset check already holds.
        | otherwise -> return $ isProj `S.and` isNotOne
      -- Strict only: the set variable is forced to true.
      (Just True, _)
        | S.fromList (T.vars pL) `S.isSubsetOf` lvars ->
            return $ isProj `S.and` isOne
        -- NOTE: unreachable for the same reason as above.
        | otherwise -> return isNotProj
      -- Weak only: the set variable is forced to false.
      (_, Just True) -> return $ isProj `S.and` isNotOne
      -- Not oriented by this candidate.
      _ -> return isNotProj
   where
    -- both projection variables take this candidate's indices
    isProj =
      SMT.eq (SMT.const $ show pf) (prjVar f)
        `S.and` SMT.eq (SMT.const $ show pg) (prjVar g)
    -- NOTE(review): this only forbids this candidate's indices and forces
    --       the set variable to false. It does not by itself require the
    --       DP to be weakly oriented under whatever projection is chosen
    --       instead; confirm that this is intended.
    isNotProj =
      S.andMany
        [ SMT.not $ SMT.eq (SMT.const $ show pf) (prjVar f)
        , SMT.not $ SMT.eq (SMT.const $ show pg) (prjVar g)
        , isNotOne
        ]
    isOne = setVar dp
    isNotOne = S.not $ setVar dp
    pL = projectTerm l pf
    pR = projectTerm r pg

-- | Create one fresh boolean variable per given DP, keyed by the DP.
--
-- The variables are created in list order, one 'freshInt' per DP.
constrainedSetVars
  :: (ToSExpr v, Ord v, Ord f, Pretty v, MonadFresh m)
  => [Rule (FId f) (VId v)]
  -> m (M.Map (Rule (FId f) (VId v)) (VId v))
constrainedSetVars dps = M.fromList <$> traverse withFreshBool dps
 where
  withFreshBool dp = do
    n <- freshInt
    return (dp, freshV boolSort n)

-- | Create one fresh integer projection variable per distinct DP symbol
-- occurring in the given rules.
--
-- The second component conjoins, starting from @true@, the bounds
-- @0 <= v@ and @v < n@ for each symbol @f@ with variable @v@. Here @n@ is
-- @length (getFIdInSort f)@, presumably the arity of @f@.
constrainedPrjVars
  :: (Ord f, ToSExpr v, Pretty v, MonadFresh m)
  => [Rule (FId f) (VId v)]
  -> m (M.Map (FId f) (VId v), SMT.SExpr)
constrainedPrjVars dps = do
  symVarPairs <- mapM pairWithFresh (nubOrd (foldMap dpSymbols dps))
  let bounds = foldr addBounds (S.const "true") symVarPairs
  return (M.fromList symVarPairs, bounds)
 where
  pairWithFresh sym = do
    n <- freshInt
    return (sym, freshV intSort n)
  addBounds (f, v) acc =
    let arity = length (getFIdInSort f)
        lower = SMT.leq (SMT.const "0") (toSExpr v)
        upper = SMT.lt (toSExpr v) (SMT.const (show arity))
    in  SMT.andMany [acc, lower, upper]

-- | Enumerate the candidate projection pairs of a dependency pair @l -> r@.
--
-- A candidate @((f, i), (g, j))@ projects the root @f@ of @l@ to its
-- @i@-th argument and the root @g@ of @r@ to its @j@-th argument. If both
-- sides have the same root, the same index must be used on both sides,
-- so only pairs of the form @((f, i), (f, i))@ are produced. Variables
-- have no arguments and therefore contribute no candidates.
--
-- A candidate is kept only if both projected arguments are logic terms
-- of sort 'intSort' and every variable of the right projection also
-- occurs in the left projection.
possibleProjections :: (Ord v, Eq f) => Rule (FId f) (VId v) -> [((FId f, Int), (FId f, Int))]
possibleProjections rule = filter validProj candidates
 where
  l = lhs rule
  r = rhs rule

  lPs = allProjections l
  rPs = allProjections r

  candidates
    | T.getRoot l == T.getRoot r = [(lP, lP) | lP <- lPs]
    | otherwise = [(lP, rP) | lP <- lPs, rP <- rPs]

  validProj ((_, prjf), (_, prjg)) =
    sort prjl == sort prjr
      && S.fromList (T.vars prjr) `S.isSubsetOf` S.fromList (T.vars prjl)
      -- NOTE: we restrict projections here to logic terms; maybe in the reflexive case
      --       we could also project to others, especially in the reduction pair
      && isLogicTerm prjl
      && isLogicTerm prjr
      && sort prjl == intSort -- FIXME fix int for now
   where
    prjl = projectTerm l prjf
    prjr = projectTerm r prjg

  allProjections (Var _) = []
  allProjections (Fun _ f ss) = map (f,) [0 .. length ss - 1]

-- | Return the @i@-th argument (0-based) of a function application.
--
-- Calls 'error' when the term is a variable or when @i@ is not a valid
-- argument index (negative, or at least the number of arguments).
projectTerm :: Term f v -> Int -> Term f v
projectTerm term i = case term of
  Var _ -> error "ValueCriterion.hs: projection not possible for variable."
  Fun _ _ args
    | i >= 0
    , (arg : _) <- drop i args ->
        arg
    | otherwise -> error "ValueCriterion.hs: projection out of bounds for function symbol."
