{-# LANGUAGE RankNTypes #-}

{- |
Module      : ValueCriterionPol
Description : Value criterion with linear-combination projections
Copyright   : (c) Jonas Schöpf, 2024
License     : GPL-3
Maintainer  : jonas.schoepf@uibk.ac.at
Stability   : experimental


This module implements a variation of the value criterion which,
instead of selecting a single argument position of each DP symbol,
selects a (restricted) linear combination of its argument positions.
-}
module Analysis.Termination.ValueCriterionPol (
  valueCriterionPol,
  valueCriterionPolRP,
)
where

import Analysis.Termination.DependencyPairs (DPProblem (..), dpSymbols)
import Analysis.Termination.Termination (
  SNInfo (VcProjectionPol),
  getAddMulOfSort,
  getCompAndZeroOfSort,
  zeroSExprOfSort,
 )
import Control.Arrow (Arrow (first, second))
import Control.Monad (forM)
import Data.Containers.ListUtils (nubOrd)
import Data.LCTRS.FIdentifier (
  FId,
  getFIdInSort,
 )
import Data.LCTRS.Guard (collapseGuardToTerm)
import Data.LCTRS.Rule (
  Rule,
  guard,
  lhs,
  lvar,
  rhs,
 )
import Data.LCTRS.Sort (sort)
import qualified Data.LCTRS.Sort as Sort
import Data.LCTRS.Term (
  Term (..),
  isLogicTerm,
 )
import qualified Data.LCTRS.Term as T
import Data.LCTRS.VIdentifier (
  VId,
  freshV,
 )
import Data.List (groupBy, partition, sortOn, uncons, (\\))
import qualified Data.Map.Strict as M
import Data.Maybe (fromJust, mapMaybe)
import Data.Monad (MonadFresh, StateM, freshInt)
import Data.SExpr (ToSExpr (..), forallSExpr)
import Data.SMT (boolSort, intSort)
import qualified Data.Set as S
import Rewriting.SMT (satSExprsResults)
import SimpleSMT (SExpr)
import qualified SimpleSMT as SMT

-- | Run 'valueCriterionPol' and return its result in the shape expected by
-- the caller: the resulting DP problem first, then the proof information (if
-- the criterion succeeded) wrapped as a single 'VcProjectionPol' entry.
valueCriterionPolRP
  :: (Ord v, ToSExpr f, Ord f, ToSExpr v)
  => DPProblem (FId f) (VId v)
  -> StateM (DPProblem (FId f) (VId v), Maybe (SNInfo (FId f) (VId v)))
valueCriterionPolRP dpp =
  reorder <$> valueCriterionPol dpp
 where
  -- swap the components and package the proof as a one-element list
  reorder (proof, problem) = (problem, VcProjectionPol . pure <$> proof)

-- | Attempt to remove strict dependency pairs from a DP problem using the
-- value criterion with linear-combination projections.
--
-- The SMT problem handed to the solver consists of
--
--   * one Boolean variable per strict DP marking it for removal
--     ('constrainedSetVars'),
--   * per DP symbol a projection selector and coefficient variables together
--     with their bounds ('constrainedPrjVars'),
--   * per strict DP the orientation constraints of its candidate projections
--     ('possibleProjections', 'encodePossProj'), plus the zero-coefficient
--     restrictions produced by 'possibleProjections', and
--   * a disjunction requiring at least one DP to be marked for removal.
--
-- If the solver answers @Just True@, the result holds the removed DPs and
-- the chosen projection of every DP symbol, plus the DP problem restricted
-- to the remaining strict DPs (weak rules unchanged). Otherwise the result
-- is 'Nothing' together with the unchanged problem.
valueCriterionPol
  :: (Ord f, Ord v, ToSExpr f, ToSExpr v)
  => DPProblem (FId f) (VId v)
  -> StateM
      (Maybe ([Rule (FId f) (VId v)], [(FId f, ([(Int, SExpr)], SExpr))]), DPProblem (FId f) (VId v))
valueCriterionPol dpp@DPProblem{..} = do
  removalFlags <- constrainedSetVars strictrules
  -- NOTE: we demand that integers are available in the SMT solver
  (symbolPrjs, boundConstraints) <- constrainedPrjVars strictrules
  let (candidates, zeroRestrictions) =
        unzip [first (dp,) (possibleProjections symbolPrjs dp) | dp <- strictrules]
  orientation <- SMT.andMany <$> mapM (encodePossProj removalFlags) candidates
  let coefficientVars (sel, prjs) =
        sel : concat [c : map snd ps | (_, (ps, c)) <- prjs]
      solverVars =
        S.fromList (M.elems removalFlags ++ concatMap coefficientVars (M.elems symbolPrjs))
      -- at least one DP has to be removed
      someRemoved =
        foldr (flip SMT.or . toSExpr) (SMT.const "false") (M.elems removalFlags)
  (answer, model) <-
    satSExprsResults solverVars [boundConstraints, orientation, someRemoved, SMT.andMany zeroRestrictions]
  if answer == Just True
    then do
      let (kept, chosenPrjs) = divideSet dpp removalFlags symbolPrjs model
      return (fmap (strictrules \\ kept,) chosenPrjs, DPProblem kept weakrules)
    else return (Nothing, dpp)

-- | Allocate a fresh Boolean SMT variable for every given DP. In a model, a
-- true value marks the DP for removal (see 'divideSet').
constrainedSetVars
  :: (ToSExpr v, Ord v, Ord f, MonadFresh m)
  => [Rule (FId f) (VId v)]
  -> m (M.Map (Rule (FId f) (VId v)) (VId v))
constrainedSetVars dps =
  fmap M.fromList . forM dps $ \dp -> do
    n <- freshInt
    return (dp, freshV boolSort n)

-- | Build one candidate projection from a non-empty list of
-- (argument position, sort) pairs.
--
-- Each position gets a fresh coefficient variable of its own sort. The
-- constant coefficient gets a fresh variable of the sort of the first
-- entry. Calls 'error' on an empty list.
constructInter :: (MonadFresh m) => [(t, Sort.Sort)] -> m ([(t, VId v1)], VId v2)
constructInter [] = error "ValueCriterionPol.hs: no arguments to project."
constructInter args@((_, firstSort) : _) = do
  coeffs <- forM args $ \(idx, argSort) -> (,) idx <$> fresh argSort
  constCoeff <- fresh firstSort
  return (coeffs, constCoeff)
 where
  fresh srt = freshV srt <$> freshInt

-- | Create, for every DP symbol occurring in the given DPs, the SMT
-- variables describing its candidate projections, together with the
-- constraints bounding them.
--
-- For each DP symbol the returned map holds a selector variable (integer
-- sort) and the list of candidate projections, numbered from 0. A candidate
-- is a list of (argument position, coefficient variable) pairs plus a
-- constant coefficient variable (built by 'constructInter').
--
-- The positions come from 'getFIdInSort' (presumably the argument sorts of
-- the symbol). They are grouped by sort, and one candidate is built per
-- group. Only integer-sorted positions are kept, so every symbol has at
-- most one candidate.
--
-- The returned constraint requires @0 <= selector < #candidates@ for every
-- symbol, and every coefficient (including the constant one) to lie in
-- @[-1, 1]@.
--
-- NOTE(review): a DP symbol without integer-sorted arguments gets an empty
-- candidate list. Its bound @0 <= selector < 0@ then makes the whole
-- returned constraint unsatisfiable.
constrainedPrjVars
  :: (Ord f, ToSExpr v, MonadFresh m, Ord v)
  => [Rule (FId f) (VId v)]
  -> m (M.Map (FId f) (VId v, [(Int, ([(Int, VId v)], VId v))]), SMT.SExpr)
constrainedPrjVars dps = do
  -- let isTheorySort s = s `S.member` S.map Sort.sort (foldMap lvar dps)
  let allDPSyms = nubOrd $ foldMap dpSymbols dps
  prjVars <-
    forM allDPSyms $ \dpSym -> do
      -- (position, sort) pairs of the integer-sorted positions, grouped by sort
      let projections =
            groupBy (\(_, s1) (_, s2) -> s1 == s2) $
              -- filter (not . isBVSort . snd) $ -- FIXME not sure if BVs work
              --   filter (isTheorySort . snd) $
              filter ((== intSort) . snd) $
                sortOn snd $
                  zip [0 ..] $
                    getFIdInSort dpSym
      -- selector choosing which candidate is used for this symbol
      prjVar <- freshConstantVar intSort
      (dpSym,) . (prjVar,) . zip [0 ..] <$> mapM constructInter projections
  -- conjunction of the selector bounds and the coefficient bounds of all symbols
  let prjConstraints =
        foldr
          ( \(_, (v, prjs)) sexpr ->
              -- SMT.andMany $
              sexpr
                `SMT.and` boundedBy (length prjs) (toSExpr v)
                `SMT.and` SMT.andMany
                  ( concatMap
                      (\(_, (ps, c)) -> betweenMinus1And1 (toSExpr c) : map (betweenMinus1And1 . toSExpr . snd) ps)
                      prjs
                  )
                  -- : concatMap (map (betweenMinus1And1 . toSExpr . snd) . snd) prjs
                  -- `SMT.and` SMT.orMany (map (SMT.orMany . map (notZero . snd) . snd) prjs)
          )
          (SMT.const "true")
          prjVars
  return (M.fromList prjVars, prjConstraints)
 where
  -- notZero var =
  --   let s = sort var in
  --   case getCompAndZeroOfSort s of
  --     Nothing -> trace "A" $ SMT.const "false"
  --     Just (_,_,zero) -> SMT.not $ SMT.eq zero (toSExpr var)
  freshConstantVar s = freshV s <$> freshInt
  -- 0 <= intsexpr < possibilities
  boundedBy possibilities intsexpr =
    (SMT.const "0" `SMT.leq` intsexpr) `SMT.and` (intsexpr `SMT.lt` SMT.const (show possibilities))

  -- (- 1) <= intsexpr <= 1
  betweenMinus1And1 intsexpr =
    (SMT.app (SMT.const "-") [SMT.const "1"] `SMT.leq` intsexpr)
      `SMT.and` (intsexpr `SMT.leq` SMT.const "1")

-- | Encode the orientation constraints for a single DP.
--
-- The first argument maps every DP to its removal flag. The second pairs a
-- DP with its candidate projection pairs as produced by
-- 'possibleProjections'. Each pair holds the selector variable and the
-- (index, (coefficients, constant)) candidate for the lhs root symbol and
-- for the rhs root symbol.
--
-- * Without candidates, the result only forces the DP's removal flag to be
--   false.
-- * Otherwise the result is the disjunction over all candidates of
--   'project'. A selected candidate must orient the DP strictly if the DP is
--   flagged for removal, and weakly otherwise. Both orientations hold
--   universally over the DP's logical variables under its guard.
encodePossProj
  :: (Ord v, Ord f, ToSExpr v, ToSExpr f)
  => M.Map (Rule (FId f) (VId v)) (VId v)
  -> ( Rule (FId f) (VId v)
     , [((VId v, (Int, ([(Int, VId v)], VId v))), (VId v, (Int, ([(Int, VId v)], VId v))))]
     )
  -> StateM SMT.SExpr
encodePossProj setMap (dp, []) = return . SMT.not $ toSExpr (setMap M.! dp)
encodePossProj setMap (dp, prjs) =
  SMT.orMany <$> mapM project prjs
 where
  l = lhs dp
  r = rhs dp
  -- the guard of the DP collapsed into a single term
  phi = collapseGuardToTerm $ guard dp
  lvars = lvar dp

  -- removal flag of a DP as an SMT expression
  setVar = toSExpr . (setMap M.!)

  -- Strict orientation: forall lvars. phi => (projL > projR /\ projL >= 0),
  -- with >, >= and zero taken from the sort of the first lhs term.
  -- Nothing if pL is empty, the sort has no comparison/zero, or a projection
  -- cannot be computed.
  validGt phi pL pfc pR pgc = do
    (aTerm, _) <- fst <$> uncons pL
    let prjSort = Sort.sort aTerm
    (gt, geq, zero) <- getCompAndZeroOfSort prjSort
    prdL <- calculateProjection pL pfc
    prdR <- calculateProjection pR pgc
    let oriented = SMT.implies (toSExpr phi) (gt prdL prdR `SMT.and` geq prdL zero) :: SMT.SExpr
    -- positiveConst <- fmap (uncurry SMT.and) . (,) <$> b0and1 prjSort pfc <*> b0and1 prjSort pgc
    return $ forallSExpr lvars oriented
  -- return $ forallSExpr lvars oriented `SMT.and` positiveConst

  -- Weak orientation: forall lvars. phi => projL >= projR, using SMT.geq
  -- directly. Nothing if a projection cannot be computed.
  validGeq phi pL pfc pR pgc = do
    -- (aTerm,_) <- fst <$> uncons pL
    -- let prjSort = Sort.sort aTerm
    term <- SMT.geq <$> calculateProjection pL pfc <*> calculateProjection pR pgc
    let oriented = SMT.implies (toSExpr phi) term
    -- positiveConst <- fmap (uncurry SMT.and) . (,) <$> b0and1 prjSort pfc <*> b0and1 prjSort pgc
    return $ forallSExpr lvars oriented
  -- return $ forallSExpr lvars oriented `SMT.and` positiveConst

  -- b0and1 s c = do
  --   (_, _, zero) <- getCompAndZeroOfSort s
  --   one <- getOneOfSort s
  --   return $ SMT.geq (toSExpr c) zero `SMT.and` SMT.not (SMT.gt (toSExpr c) one)

  -- Sum of (projected term * coefficient) over the given pairs, built with
  -- the add/mul/zero of the sort of the first term. The constant coefficient
  -- (second argument) is currently ignored. Nothing for an empty list or when
  -- the sort lacks zero or add/mul.
  calculateProjection [] _ = Nothing
  calculateProjection prjds@(_ : _) _con = do
    prjSort <- Sort.sort . fst . fst <$> uncons prjds
    case (zeroSExprOfSort prjSort, getAddMulOfSort prjSort) of
      (Just zero, Just (add, mul)) ->
        let mulTuple term vd = mul (toSExpr term) (toSExpr vd)
        in  return $ foldr (add . uncurry mulTuple) zero prjds -- FIXME do not use constant for now
      (Nothing, _) -> Nothing
      (_, Nothing) -> Nothing

  -- Constraint for one candidate pair. Either this candidate is selected for
  -- both root symbols and orients the DP (strictly iff the DP is removed), or
  -- it is selected for neither root symbol and the DP is kept.
  -- NOTE(review): pL is already filtered by partitionValidTerms with the same
  -- "variables are logical" predicate, so the S.isSubsetOf guards below
  -- always hold and the corresponding `otherwise` branches are unreachable.
  -- Since isOne and isNotOne are complementary, the xor acts like an or.
  project ((fV, (fI, (pf, pfc))), (gV, (gI, (pg, pgc))))
    | sameSort pL pR && not (null pL) && not (null pR) =
        case (b1, b2) of
          (Just sexprGT, Just sexprGEQ)
            | S.fromList (concatMap (T.vars . fst) pL) `S.isSubsetOf` lvars ->
                return $
                  isProj `SMT.and` ((isOne `SMT.and` sexprGT) `SMT.xor` (isNotOne `SMT.and` sexprGEQ))
            | otherwise -> return $ isProj `SMT.and` isNotOne `SMT.and` sexprGEQ
          (Just sexprGT, _)
            | S.fromList (concatMap (T.vars . fst) pL) `S.isSubsetOf` lvars ->
                return $
                  isProj `SMT.and` isOne `SMT.and` sexprGT
            | otherwise -> return isNotProj
          (_, Just sexprGEQ) -> return $ isProj `SMT.and` isNotOne `SMT.and` sexprGEQ
          (Nothing, _) -> return isNotProj
    | otherwise = return isNotProj
   where
    b1 = validGt phi pL pfc pR pgc
    b2 = validGeq phi pL pfc pR pgc

    -- this candidate is selected for both root symbols and the coefficients
    -- of the erased (non-logical) argument positions are zero
    isProj =
      SMT.eq (toSExpr fV) (SMT.const $ show fI)
        `SMT.and` SMT.eq (toSExpr gV) (SMT.const $ show gI)
        `SMT.and` SMT.andMany
          [ SMT.eq (toSExpr v) zero | (_, v) <- eraseL <> eraseR, let (_, _, zero) = fromJust $ getCompAndZeroOfSort (Sort.sort v)
          ]
    -- this candidate is selected for neither root symbol and the DP is kept
    isNotProj =
      SMT.andMany
        [ SMT.not $ SMT.eq (toSExpr fV) (SMT.const $ show fI)
        , SMT.not $ SMT.eq (toSExpr gV) (SMT.const $ show gI)
        , isNotOne
        ]
    isOne = setVar dp
    isNotOne = SMT.not $ setVar dp

    -- projected arguments (paired with their coefficient variables), split
    -- into those whose variables are all logical and the rest
    (pL, eraseL) = partitionValidTerms $ projectTermExpr l pf
    (pR, eraseR) = partitionValidTerms $ projectTermExpr r pg

    partitionValidTerms = partition (\(t, _) -> S.fromList (T.vars t) `S.isSubsetOf` lvars)

    projectTermExpr term = map (first (projectTerm term))

    sameSort [] _ = False
    sameSort _ [] = False
    sameSort ((t1, _) : _) ((t2, _) : _) = Sort.sort t1 == Sort.sort t2

-- | Read a satisfying SMT model back into the result of the value criterion.
--
-- Returns
--
--   * the strict DPs whose removal flag is false in the model (the DPs that
--     remain), and
--   * for every DP symbol of the strict DPs, the selected projection: the
--     model values of its coefficients (paired with their argument
--     positions) and the model value of its constant coefficient.
--
-- Calls 'error' if a variable is missing from the model, has a value of an
-- unexpected kind, or the selected projection index does not exist.
divideSet
  :: (Ord f, Ord v)
  => DPProblem (FId f) (VId v)
  -> M.Map (Rule (FId f) (VId v)) (VId v)
  -> M.Map (FId f) (VId v, [(Int, ([(Int, VId v)], VId v))])
  -> [(VId v, SMT.Value)]
  -> ([Rule (FId f) (VId v)], Maybe [(FId f, ([(Int, SMT.SExpr)], SMT.SExpr))])
divideSet DPProblem{..} setMap prjMap assignments =
  (p2, Just prjs)
 where
  -- the projection selected in the model for each DP symbol
  prjs = flip map dpSyms $ \dpS ->
    let (pV, ps) = prjMap M.! dpS
        prjIndex = assignedInt pV
    in  case lookup prjIndex ps of
          Nothing -> error "ValueCriterionPol.hs: projection index invalid."
          Just (actualPs, c) -> (dpS, (map (second assignedVal) actualPs, assignedVal c))
  -- DPs that are not flagged for removal
  p2 = filter (\dp -> not $ assignedBool (setMap M.! dp)) strictrules

  dpSyms = nubOrd $ foldMap dpSymbols strictrules

  assM = M.fromList assignments
  -- render a model value as an SExpr for the reported coefficients
  assignedVal v =
    case assM M.! v of
      SMT.Int i -> SMT.const $ show i
      SMT.Real r -> SMT.const $ show r
      SMT.Bits _ i -> SMT.const $ show i
      SMT.Other sexpr -> sexpr
      _ -> error "ValueCriterionPol.hs: weird SExpr assigned in projection (please report this bug)."
  -- model value of a projection selector
  assignedInt v =
    case assM M.! v of
      SMT.Int i -> fromInteger i :: Int
      _ -> error "ValueCriterionPol.hs: no integer assigned to projection index."
  -- model value of a removal flag
  assignedBool v =
    case assM M.! v of
      SMT.Bool b -> b
      _ -> error "ValueCriterionPol.hs: no bool assigned to set splitting."

-- | Enumerate the candidate projection pairs for one DP rule @s -> t@.
--
-- A pair combines a candidate projection of the root symbol of @s@ with one
-- of the root symbol of @t@, taken from the given map. If both roots are
-- equal, only pairs with the same candidate are considered. Otherwise all
-- combinations are considered. If either side is a variable, there are no
-- candidates.
--
-- Within a pair, argument positions whose projected argument is not a logic
-- term ('isLogicTerm') are dropped from the pair. Their coefficient
-- variables are constrained to be zero instead. A pair is discarded if no
-- projected argument on either side is a logic term, or if the variables of
-- the rhs logic terms are not contained in those of the lhs logic terms.
--
-- Returns the surviving pairs and the conjunction of their zero
-- restrictions.
--
-- NOTE(review): 'M.!' assumes every non-variable root of the DP sides is a
-- key of the map (i.e. a DP symbol); confirm against the callers.
possibleProjections
  :: (Ord v, Ord f, ToSExpr f, ToSExpr v)
  => M.Map (FId f) (VId v, [(Int, ([(Int, VId v)], VId v))])
  -> Rule (FId f) (VId v)
  -> ([((VId v, (Int, ([(Int, VId v)], VId v))), (VId v, (Int, ([(Int, VId v)], VId v))))], SMT.SExpr)
possibleProjections prjVars rule = second SMT.andMany $ unzip $ projections l r
 where
  l = lhs rule
  r = rhs rule

  -- candidate pairs (with their restrictions) for the sides s and t
  projections s t =
    case (pS, pT) of
      (Just (pVL, lPs), Just (pVR, rPs)) ->
        if T.getRoot s == T.getRoot t
          then mapMaybe validProjs [((pVL, lP), (pVR, lP)) | lP <- lPs]
          else mapMaybe validProjs [((pVL, lP), (pVR, rP)) | lP <- lPs, rP <- rPs]
      _ -> []
   where
    pS = allProjections s
    pT = allProjections t

  -- keep only the logic-term positions of a pair; reject the pair if nothing
  -- remains or the rhs introduces variables not occurring on the lhs
  validProjs ((vF, (iF, (prjf, prjfc))), (vG, (iG, (prjg, prjgc))))
    | null prjtdTerms
        || not
          ( S.fromList (concatMap (T.vars . snd) prjtdTR)
              `S.isSubsetOf` S.fromList (concatMap (T.vars . snd) prjtdTL)
          ) =
        Nothing
    | otherwise =
        let newprojection = ((vF, (iF, (map fst prjtdTL, prjfc))), (vG, (iG, (map fst prjtdTR, prjgc))))
        in  Just (newprojection, SMT.andMany (map restrict restL) `SMT.and` SMT.andMany (map restrict restR))
   where
    -- each (position, coefficient) paired with the argument at that position
    prjtdTermsL = map (\(ind, v) -> ((ind, v), projectTerm l ind)) prjf
    prjtdTermsR = map (\(ind, v) -> ((ind, v), projectTerm r ind)) prjg

    -- force the coefficient of a non-logic argument to zero
    restrict ((_, v), t) =
      let s = sort t
      in  case zeroSExprOfSort s of
            Just zero -> SMT.eq (toSExpr v) zero
            Nothing -> error "ValueCriterionPol.hs: cannot find zero value of sort."

    (prjtdTL, restL) = partition (isLogicTerm . snd) prjtdTermsL
    (prjtdTR, restR) = partition (isLogicTerm . snd) prjtdTermsR
    prjtdTerms = prjtdTL <> prjtdTR

  -- validTerms [] = True
  -- validTerms (t : ts) = isLogicTerm t && all (\tn -> Sort.sort t == Sort.sort tn && isLogicTerm tn) ts

  -- selector and candidates of the root symbol; Nothing for a variable
  allProjections (Var _) = Nothing
  allProjections (Fun _ f _) = Just $ prjVars M.! f -- map (f,) [0 .. length ss - 1]

-- | Return the argument at (0-based) position @i@ of a function application.
--
-- Calls 'error' for a variable, or when @i@ is negative or not smaller than
-- the number of arguments.
projectTerm :: Term f v -> Int -> Term f v
projectTerm (Var _) _ = error "ValueCriterionPol.hs: projection not possible for variable."
projectTerm (Fun _ _ ss) i
  -- `drop` with a negative count returns the whole list, hence the i >= 0 check
  | i >= 0, (arg : _) <- drop i ss = arg
  | otherwise = error "ValueCriterionPol.hs: projection out of bounds for function symbol."
