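{- |
Module      : Analysis.Termination.SubtermCriterion
Description : Subterm criterion for dependency pair problems
Copyright   : (c) Jonas Schöpf, 2024
License     : GPL-3
Maintainer  : jonas.schoepf@uibk.ac.at
Stability   : stable

This module provides an implementation of the subterm criterion for
dependency pair problems. An SMT solver is used to search for a projection
of the dependency pair symbols under which every strict dependency pair is
weakly oriented; the pairs that are strictly oriented are removed from the
problem.
-}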
module Analysis.Termination.SubtermCriterion (subtermCriterion, subtermCriterionRP) where

import Analysis.Termination.DependencyPairs (DPProblem (..), dpSymbols)
import Analysis.Termination.Termination (
  SNInfo (ScProjection),
 )
import Control.Monad.IO.Class (liftIO)
import Data.Containers.ListUtils (nubOrd)
import Data.LCTRS.FIdentifier (FId, getFIdInSort)
import Data.LCTRS.Rule (Rule (..), lvar)
import Data.LCTRS.Term (Term (..))
import qualified Data.LCTRS.Term as T
import Data.LCTRS.VIdentifier (VId, freshV)
import Data.List (sortBy, subsequences, (\\))
import qualified Data.Map.Strict as M
import Data.Monad (MonadFresh, StateM, freshInt, getSolver)
import Data.SExpr (ToSExpr (..))
import Data.SMT (boolSort, intSort)
import qualified Data.Set as Set
import Rewriting.SMT (satSExprsResults)
import SimpleSMT (SExpr)
import qualified SimpleSMT as SMT

-- | A projection choice for a single symbol: the symbol paired with the
-- 0-based argument position it is projected to.
type Projection f = (f, Int)

-- | Processor wrapper around 'subtermCriterion'.
--
-- Runs the criterion on @dpp@ and returns the remaining DP problem, together
-- with 'ScProjection' information for @dpp@ when the criterion succeeded.
subtermCriterionRP
  :: (ToSExpr v, Ord v, Ord f, ToSExpr f)
  => DPProblem (FId f) (VId v)
  -> StateM (DPProblem (FId f) (VId v), Maybe (SNInfo (FId f) (VId v)))
subtermCriterionRP dpp = fmap package (subtermCriterion dpp)
 where
  -- swap the pair and wrap a successful result as a singleton proof step
  package (result, remaining) =
    (remaining, fmap (\step -> ScProjection dpp [step]) result)

-- | Apply the subterm criterion to the strict dependency pairs of a DP
-- problem.
--
-- Each strict dependency pair gets a fresh boolean variable, which may only
-- be true if the pair is strictly oriented. Each dependency-pair symbol gets
-- a fresh integer variable holding the argument position it is projected to.
--
-- Inside a solver @push@/@pop@ scope, the solver is asked for an assignment
-- under which:
--
--   * every projection variable lies within its symbol's arity,
--   * every strict pair is weakly oriented by the projection, and
--   * at least one boolean variable is true.
--
-- On success, returns the strictly oriented pairs and the projection,
-- together with the remaining DP problem. Otherwise returns
-- @(Nothing, dpp)@.
subtermCriterion
  :: (Ord v, ToSExpr v, Ord f, ToSExpr f)
  => DPProblem (FId f) (VId v)
  -> StateM (Maybe ([Rule (FId f) (VId v)], [Projection (FId f)]), DPProblem (FId f) (VId v))
subtermCriterion dpp@DPProblem{strictrules = dps} = do
  solver <- getSolver
  liftIO (SMT.push solver)
  setVars <- constrainedSetVars dps
  -- NOTE: we demand that integers are available in the SMT solver
  (prjVars, prjBounds) <- constrainedPrjVars dps
  let orientedByPrj = SMT.andMany [encodePossProj setVars prjVars dp | dp <- dps]
      atLeastOneStrict =
        foldr (\var acc -> acc `SMT.or` toSExpr var) (SMT.const "false") (M.elems setVars)
      smtVars =
        Set.fromList (M.elems setVars ++ M.elems prjVars) <> foldMap lvar dps
  (outcome, model) <- satSExprsResults smtVars [prjBounds, orientedByPrj, atLeastOneStrict]
  -- NOTE: '_iterPossSols' instead tries to orient as many pairs as possible,
  --       but is currently too inefficient
  liftIO (SMT.pop solver)
  return $
    case outcome of
      Just True ->
        let (oriented, prjs, remaining) = divideSet dpp setVars prjVars model
         in (Just (oriented, prjs), remaining)
      -- unsatisfiable or unknown: leave the problem untouched
      _ -> (Nothing, dpp)

-- | Alternative search strategy, currently unused.
--
-- Tries the non-empty subsets of the set-splitting variables, largest first.
-- For each subset it forces all of the subset's variables to be true and
-- adds that constraint to @sexprs@. It returns the first result that is
-- @Just True@, or @(Nothing, [])@ if no subset succeeds.
--
-- All subsets are enumerated, so this is exponential in the number of
-- variables.
_iterPossSols
  :: (ToSExpr v, Ord v)
  => M.Map k (VId v)
  -> Set.Set (VId v)
  -> [SExpr]
  -> StateM (Maybe Bool, [(VId v, SMT.Value)])
_iterPossSols setVarM allVars sexprs = tryEach candidates
 where
  -- 'subsequences' always yields the empty list, the unique shortest subset;
  -- after the stable sort by descending length it is last and 'init' drops it.
  candidates = init (sortBy longerFirst (subsequences (M.elems setVarM)))
  longerFirst xs ys = compare (length ys) (length xs)

  tryEach [] = return (Nothing, [])
  tryEach (vs : rest) = do
    res@(answer, _) <- satSExprsResults allVars (SMT.andMany (map toSExpr vs) : sexprs)
    if answer == Just True then return res else tryEach rest

-- | Split a DP problem according to a satisfying assignment from the SMT
-- solver.
--
-- Returns:
--
--   1. the strict dependency pairs whose set-splitting variable is assigned
--      @true@ (the strictly oriented pairs),
--   2. the projection assigned to every dependency-pair symbol occurring in
--      the strict pairs, and
--   3. the remaining DP problem: the strict pairs without the strictly
--      oriented ones, with the weak rules unchanged.
--
-- Calls 'error' if a needed variable is missing from the maps or the
-- assignment, or if it is assigned a value of the wrong kind.
divideSet
  :: (Ord f, Ord v)
  => DPProblem (FId f) (VId v)
  -> M.Map (Rule (FId f) (VId v)) (VId v)
  -> M.Map (FId f) (VId v)
  -> [(VId v, SMT.Value)]
  -> ([Rule (FId f) (VId v)], [(FId f, Int)], DPProblem (FId f) (VId v))
divideSet DPProblem{..} setMap prjMap assignments =
  (strictlyOriented, prjs, remDPP)
 where
  prjs = map (\dpS -> (dpS, assignedInt $ prjMap M.! dpS)) dpSyms
  strictlyOriented = filter (\dp -> assignedBool (setMap M.! dp)) strictrules
  remDPP = DPProblem (strictrules \\ strictlyOriented) weakrules

  dpSyms = nubOrd $ foldMap dpSymbols strictrules

  assM = M.fromList assignments
  -- error messages name this module (they previously named
  -- "ValueCriterion.hs", a copy-paste slip)
  assignedInt v =
    case assM M.! v of
      SMT.Int i -> fromInteger i :: Int
      _ -> error "SubtermCriterion.hs: non-integer assigned in projection."
  assignedBool v =
    case assM M.! v of
      SMT.Bool b -> b
      _ -> error "SubtermCriterion.hs: no bool assigned to set splitting."

-- | Encode that the dependency pair @dp@ is weakly oriented by some
-- projection. The encoding also records whether @dp@ may count as strictly
-- oriented.
--
-- Let @f@ and @g@ be the root symbols of the left- and right-hand side of
-- @dp@. The result is a disjunction, over all argument positions @i@ of the
-- left-hand side and @j@ of the right-hand side, of
--
-- > prj(f) = i  /\  prj(g) = j  /\  (weak(t_j, s_i) \/ strict(t_j, s_i))
--
-- where:
--
--   * @weak(t, s)@ is @set(dp) = false@ if @t@ occurs in 'T.subterms' of
--     @s@, and @false@ otherwise;
--   * @strict(t, s)@ is @set(dp) = true@ if @t@ occurs in 'T.properSubterms'
--     of @s@, and @false@ otherwise.
--
-- Calls 'error' if either side of @dp@ is not a function application.
encodePossProj
  :: (ToSExpr v, Ord f, Ord v, ToSExpr f)
  => M.Map (Rule (FId f) (VId v)) (VId v)
  -- ^ set-splitting variable of each dependency pair
  -> M.Map (FId f) (VId v)
  -- ^ projection variable of each dependency-pair symbol
  -> Rule (FId f) (VId v)
  -- ^ the dependency pair to encode
  -> SMT.SExpr
encodePossProj setMap prjMap dp = go (lhs dp) (rhs dp)
 where
  -- The subterm sets of each left-hand argument @s@ are built once (via
  -- 'subtermCondition') and shared by all right-hand arguments @t@, instead
  -- of being rebuilt for every pair (s, t).
  go (Fun _ f ss) (Fun _ g ts) =
    SMT.orMany
      [ SMT.andMany [setProj f i, setProj g j, subtermOf t]
      | (s, i) <- zip ss [0 ..]
      , let subtermOf = subtermCondition s
      , (t, j) <- zip ts [0 ..]
      ]
  go _ _ = error "SubtermCriterion.hs: cannot project non function symbol position."

  -- Given @s@, return the encoding of the subterm condition for any @t@.
  subtermCondition s =
    let subs = Set.fromList (T.subterms s)
        props = Set.fromList (T.properSubterms s)
     in \t -> keptIf (t `Set.member` subs) `SMT.or` removableIf (t `Set.member` props)

  -- non-strict subterm: the pair must stay in the problem
  keptIf isSub
    | isSub = setVar `SMT.eq` SMT.const "false"
    | otherwise = SMT.const "false"
  -- proper subterm: the pair may be marked as strictly oriented
  removableIf isProper
    | isProper = setVar `SMT.eq` SMT.const "true"
    | otherwise = SMT.const "false"

  setVar = toSExpr (setMap M.! dp)
  setProj sym prj = toSExpr (prjMap M.! sym) `SMT.eq` SMT.int prj

-- | Create one fresh boolean (set-splitting) variable per given dependency
-- pair and return them as a map from pair to variable.
--
-- A fresh index is drawn for every list element, in list order. If a pair
-- occurs several times, the map keeps the variable of its last occurrence.
constrainedSetVars
  :: (Ord v, Ord f, MonadFresh m)
  => [Rule (FId f) (VId v)]
  -> m (M.Map (Rule (FId f) (VId v)) (VId v))
constrainedSetVars dps = do
  indices <- traverse (const freshInt) dps
  return . M.fromList $ zipWith (\dp n -> (dp, freshV boolSort n)) dps indices

-- | Create one fresh integer (projection) variable per distinct
-- dependency-pair symbol occurring in the given pairs.
--
-- Also returns an SMT constraint bounding each variable to a valid argument
-- position, @0 <= v < arity@. Here the arity is the number of input sorts of
-- the symbol ('getFIdInSort').
constrainedPrjVars
  :: (MonadFresh m, Ord f, ToSExpr v)
  => [Rule (FId f) (VId v)]
  -> m (M.Map (FId f) (VId v), SMT.SExpr)
constrainedPrjVars dps = do
  let symbols = nubOrd (foldMap dpSymbols dps)
  indices <- traverse (const freshInt) symbols
  let assocs = zipWith (\sym n -> (sym, freshV intSort n)) symbols indices
  return (M.fromList assocs, SMT.andMany (map inArity assocs))
 where
  inArity (sym, var) =
    let x = toSExpr var
        arity = length (getFIdInSort sym)
     in SMT.andMany [SMT.leq (SMT.const "0") x, SMT.lt x (SMT.const (show arity))]
