{- |
Module      : Analysis.Termination.PolynomialInterpretations
Description : Termination proofs by polynomial interpretations
Copyright   : (c) Jonas Schöpf, 2024
License     : GPL-3
Maintainer  : jonas.schoepf@uibk.ac.at
Stability   : experimental


This module provides a dependency pair processor that tries to show
termination by integer polynomial interpretations. It is currently
experimental.
-}
module Analysis.Termination.PolynomialInterpretations (
  polynomialnterpretationsRP, -- not stable currently
)
where

import Analysis.Termination.DependencyPairs (DPProblem (..))
import Analysis.Termination.Termination (
  SNInfo (PolyInterpretation),
  getCompAndZeroOfSort,
 )
import Control.Monad (replicateM)
import Control.Monad.Trans (MonadIO (liftIO))
import Data.LCTRS.FIdentifier (FId, getFIdSort)
import Data.LCTRS.Guard (collapseGuardToTerm)
import Data.LCTRS.Rule (Rule)
import qualified Data.LCTRS.Rule as R
import Data.LCTRS.Sort (Sorted (sort), inSort)
import Data.LCTRS.Term (Term (Fun, Var))
import Data.LCTRS.VIdentifier (VId, freshV)
import qualified Data.Map.Strict as M
import Data.Maybe (mapMaybe)
import Data.Monad (MonadFresh, StateM, freshInt, getSolver)
import Data.SExpr (ToSExpr (toSExpr))
import Data.SMT (boolSort, intSort)
import qualified Data.Set as S
import Prettyprinter (Pretty (..))
import Rewriting.SMT (satSExprsResults)
import SimpleSMT (SExpr)
import qualified SimpleSMT as SMT

-- | Interpretation of function symbols by linear polynomials: every symbol
-- is mapped to the variables standing for its argument coefficients (one per
-- argument, in argument order) together with the variable standing for its
-- constant part.
type FInterpretation sym var = M.Map sym ([var], var)

-- | Dependency pair processor based on integer polynomial interpretations
-- (currently not stable).
--
-- Every function symbol occurring in the weak or strict rules is assigned a
-- linear polynomial with fresh integer coefficients (see
-- 'createInterpretation'). The SMT solver is then asked for coefficients
-- such that, for all values of the rule variables and under each rule's
-- guard,
--
--   * every weak rule is interpreted as an equality,
--   * every strict rule is weakly decreasing (@>=@), and
--   * at least one strict rule is strictly decreasing (see 'strictEncoding').
--
-- Each strict rule carries a fresh boolean flag that is tied to its strict
-- decrease. On success the rules whose flag is false in the model form the
-- new strict rules and the found interpretation is returned as proof
-- information; otherwise the problem is returned unchanged with 'Nothing'.
-- The solver work happens between an SMT @push@ and @pop@.
polynomialnterpretationsRP
  :: (Ord v, ToSExpr f, Pretty f, Ord f, ToSExpr v, Pretty v)
  => DPProblem (FId f) (VId v)
  -> StateM (DPProblem (FId f) (VId v), Maybe (SNInfo (FId f) (VId v)))
polynomialnterpretationsRP dpp@DPProblem{..} = do
  solver <- getSolver
  -- NOTE: new context to define the SMT function max
  liftIO $ SMT.push solver
  -- _ <- liftIO $ print $ pretty dpp
  -- maxName <- getAndDefineMax solver
  -- let maxS sexpr = SMT.fun maxName [SMT.const "0", sexpr]
  finter <- createInterpretation $ S.fromList $ foldMap R.funs $ weakrules ++ strictrules
  let interVars = concatMap (\(vs, v) -> v : vs) (M.elems finter)
  -- NOTE(review): weak rules are encoded as equalities below; presumably this
  -- is why the weak monotonicity constraints stay disabled — confirm.
  -- let weakmonotonicity =
  --       map
  --         ( \(_, (cs, c)) ->
  --             SMT.andMany (map (\c' -> SMT.gt (toSExpr c') (SMT.const "0")) cs)
  --               `SMT.and` SMT.geq (toSExpr c) (SMT.const "0")
  --         )
  --         $ M.toList finter
  let wruleInters = map (\rule -> (rule, interpretRule finter rule)) weakrules
  let sruleInters = map (\rule -> (rule, interpretRule finter rule)) strictrules
  let relevantVars rule = S.fromList $ R.vars rule
  let encodeweak =
        SMT.andMany
          [ forallPolynomial (relevantVars rule) impl
          | (rule, ((sexprl, sexprr), c)) <- wruleInters
          , let impl = SMT.implies c (sexprl `SMT.eq` sexprr)
          ]
  let encodestrictweak =
        SMT.andMany
          [ forallPolynomial (relevantVars rule) impl
          | (rule, ((sexprl, sexprr), c)) <- sruleInters
          , let impl = SMT.implies c (sexprl `SMT.geq` sexprr)
          ]
  isStrict <- M.fromList <$> mapM (\rule -> (rule,) . freshV boolSort <$> freshInt) strictrules
  let strictFlags = M.map toSExpr isStrict
  -- flag <=> strict decrease, for every strict rule whose sort has a comparison
  let orientations =
        M.fromList $
          mapMaybe
            (\ruleInter@(rule, _) -> (rule,) <$> strictEncoding relevantVars isStrict ruleInter)
            sruleInters
  -- The flag constraint of EVERY strict rule must hold (a rule without a
  -- comparison can never be flagged strict), and at least one flag must be
  -- set. Placing the flag constraints inside the disjunction would leave the
  -- flags of all other rules unconstrained, so the solver could flag rules
  -- that do not decrease, and they would then be dropped unsoundly.
  let encodestrict =
        SMT.andMany
          [ M.findWithDefault (SMT.not flag) rule orientations
          | (rule, flag) <- M.toList strictFlags
          ]
          `SMT.and` SMT.orMany (M.elems strictFlags)
  let allVars = S.fromList (interVars ++ M.elems isStrict) <> foldMap R.lvar (weakrules ++ strictrules)
  (mbool, assigns) <-
    satSExprsResults allVars [encodeweak, encodestrict, encodestrictweak]
  liftIO $ SMT.pop solver
  let assignM = M.fromList assigns
  case mbool of
    Just True -> do
      let inters = PolyInterpretation $ extractInters finter assignM
      -- strict rules whose flag is false in the model remain strict
      let newstrict = [dp | dp <- strictrules, assignM M.! (isStrict M.! dp) == SMT.Bool False]
      return (DPProblem newstrict weakrules, Just inters)
    _ -> return (dpp, Nothing)

-- | Encodes that the strictness flag of a strict rule (looked up in the given
-- map) is equivalent to the rule being strictly decreasing under its guard,
-- i.e. @[l] > max([r], 0)@ for all values of the rule's variables. The
-- cut-off at zero bounds strictly decreasing values from below.
--
-- Returns 'Nothing' if 'getCompAndZeroOfSort' provides no comparison for the
-- sort of the rule's right-hand side.
strictEncoding
  :: (Ord v, ToSExpr v, Pretty v, Ord f)
  => (Rule (FId f) (VId v) -> S.Set (VId v))
  -> M.Map (Rule (FId f) (VId v)) (VId v)
  -> (Rule (FId f) (VId v), ((SExpr, SExpr), SExpr))
  -> Maybe SExpr
strictEncoding relevantVars isStrict (rule, ((sexprl, sexprr), c)) = do
  (gt, _, zero) <- getCompAndZeroOfSort (sort $ R.rhs rule)
  let
    cutoffZero = SMT.ite (gt sexprr zero) sexprr zero
    impl = SMT.implies c (sexprl `gt` cutoffZero)
    sexpr = forallPolynomial (relevantVars rule) impl
    isStrictV = toSExpr $ isStrict M.! rule
  return $ SMT.eq sexpr isStrictV

-- | Reads the coefficients of every symbol's polynomial from an SMT model.
--
-- The model maps coefficient variables to their values. Every coefficient
-- of the interpretation must be assigned an integer value in the model; a
-- missing or non-integer value is reported with 'error'.
--
-- NOTE(review): 'fromInteger' wraps around silently if a coefficient does
-- not fit into an 'Int'.
extractInters
  :: (Ord f, Ord v)
  => FInterpretation (FId f) (VId v)
  -> M.Map (VId v) SMT.Value
  -> M.Map (FId f) ([Int], Int)
extractInters inter assignM = M.map (\(vs, v) -> (map getInt vs, getInt v)) inter
 where
  getInt var = case M.lookup var assignM of
    Just (SMT.Int i) -> fromInteger i
    Just _ -> error "PolynomialInterpretations.hs: integer constant in polynomial was assigned to non integer."
    Nothing -> error "PolynomialInterpretations.hs: integer constant in polynomial was not assigned a value."

-- | Assigns a fresh linear polynomial to every function symbol in the set.
--
-- A symbol whose sort has @n@ input sorts gets @n@ fresh integer-sorted
-- variables as argument coefficients plus one more fresh integer-sorted
-- variable as constant part. Fresh variables are drawn symbol by symbol in
-- ascending order of the set.
createInterpretation
  :: (Ord f, MonadFresh m) => S.Set (FId f) -> m (FInterpretation (FId f) (VId v))
createInterpretation set = M.fromList <$> traverse withPolynomial (S.toList set)
 where
  withPolynomial f = (f,) <$> createInterpretationSym f

  createInterpretationSym f = do
    let arity = length $ inSort $ getFIdSort f
    (,)
      <$> replicateM arity (freshV intSort <$> freshInt)
      <*> (freshV intSort <$> freshInt)

-- | Interprets both sides of a rule with the given polynomial interpretation
-- and translates the rule's guard (collapsed into a single term) into an SMT
-- expression. The result has the shape @(([lhs], [rhs]), guard)@.
interpretRule
  :: (Ord f, Eq v, ToSExpr v, Pretty f, Pretty v, ToSExpr f)
  => FInterpretation (FId f) (VId v)
  -> Rule (FId f) (VId v)
  -> ((SExpr, SExpr), SExpr)
interpretRule inter rule = (sides, guardSExpr)
 where
  interpret = interpretTerm inter
  sides = (interpret (R.lhs rule), interpret (R.rhs rule))
  guardSExpr = toSExpr (collapseGuardToTerm (R.guard rule))

-- | Translates a term under a polynomial interpretation into an SMT
-- expression.
--
-- A variable is translated to itself. An application @f(s1, ..., sn)@ whose
-- symbol has coefficients @([c1, ..., cn], c0)@ becomes
-- @c1 * [s1] + (... + (cn * [sn] + c0))@.
interpretTerm
  :: (Ord f, Eq v, ToSExpr v, Pretty v)
  => FInterpretation (FId f) (VId v)
  -> Term (FId f) (VId v)
  -> SMT.SExpr
interpretTerm inter term = case term of
  Var v -> toSExpr v
  Fun _ f args ->
    let (argCoeffs, constCoeff) = inter M.! f
        summands =
          zipWith
            (\coeff arg -> SMT.mul (toSExpr coeff) (interpretTerm inter arg))
            argCoeffs
            args
     in foldr SMT.add (toSExpr constCoeff) summands

-- getAndDefineMax :: (MonadFresh m, MonadIO m) => SMT.Solver -> m String
-- getAndDefineMax solver = do
--   i <- freshInt
--   let name = "?max_" ++ show i
--   _ <-
--     liftIO $
--       SMT.defineFun solver name [("x", SMT.tInt), ("y", SMT.tInt)] SMT.tInt $
--         SMT.ite (SMT.gt sx sy) sx sy
--   return name
--  where
--   sx = SMT.Atom "x"
--   sy = SMT.Atom "y"

-- | Universally quantifies the given variables in an SMT formula, binding
-- each of them at sort @Int@. With no variables the formula is returned
-- unchanged.
--
-- NOTE(review): every variable is bound at sort Int, including any
-- non-integer (e.g. boolean) variables a guard may contain — confirm this is
-- intended.
forallPolynomial :: (Ord v, ToSExpr v, Pretty v) => S.Set (VId v) -> SExpr -> SExpr
forallPolynomial vs expr =
  if S.null vs
    then expr
    else SMT.List [SMT.Atom "forall", SMT.List binders, expr]
 where
  binders = S.toList (S.map intBinder vs)
  intBinder v = SMT.List [toSExpr v, SMT.tInt]
