{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

{- |
Module      : MatrixInterpretationsConstraints
Description : Termination proofs by matrix interpretations with constraints
Copyright   : (c) Jonas Schöpf, 2024
License     : GPL-3
Maintainer  : jonas.schoepf@uibk.ac.at
Stability   : experimental


This module provides the main types and functions to show termination
by matrix interpretations incorporating constraints in the encoding.
-}
module Analysis.Termination.MatrixInterpretationsConstraints (
  matrixInterpretationsCons,
  matrixInterpretationsConsDP,
  matrixInterpretationsConsRP,
  Dimension,
  EntryUpperBound,
) where

----------------------------------------------------------------------------------------------------
-- imports
----------------------------------------------------------------------------------------------------

import Analysis.Termination.DependencyPairs (DPProblem (..), isDPSym)
import Analysis.Termination.MatrixInterpretations (Dimension, EntryUpperBound)
import Analysis.Termination.Termination (
  SNInfo (MatrixInterpretation),
  SNResult (MaybeTerminating, Terminating),
  correctDecrease,
  interpretTerm,
 )
import Control.Applicative (empty)
import Control.Arrow (first)
import Control.Monad (MonadPlus (mzero), filterM, replicateM, (<=<))
import Control.Monad.State (MonadIO (liftIO), StateT, lift, runStateT)
import Control.Monad.State.Class (gets, modify)
import Control.Monad.State.Strict (State, evalState, get)
import Control.Monad.Trans.Maybe (MaybeT (runMaybeT))
import Control.Monad.Union
import Data.Bifunctor (bimap)
import Data.Containers.ListUtils (nubOrd)
import Data.LCTRS.FIdentifier (FId, getArity)
import Data.LCTRS.Guard (Guard, collapseGuardToTerm, varsGuard)
import Data.LCTRS.Rule (Rule)
import qualified Data.LCTRS.Rule as R
import Data.LCTRS.Sort (Sort, Sorted (sort), sortAnnotation)
import Data.LCTRS.Term (
  Term,
  var,
  pattern TermFun,
  pattern TheoryFun,
  pattern Val,
 )
import qualified Data.LCTRS.Term as T
import Data.LCTRS.VIdentifier (VId, freshV)
import qualified Data.Map.Strict as M
import Data.Matrix (Matrix)
import qualified Data.Matrix as DM
import Data.Maybe (fromJust)
import Data.Monad (MonadFresh, StateM, freshInt, getSolver)
import Data.SExpr (ToSExpr (toSExpr))
import Data.SMT (boolSort, intSort, isBVSort, realSort)
import qualified Data.Set as S
import qualified Data.Union as DUn
import Rewriting.CriticalPair (calcRuleFromSymbol)
import Rewriting.SMT (declareVariables, satSExpr, satSExprsResultsTactic)
import SimpleSMT (SExpr)
import qualified SimpleSMT as SMT
import Type.SortLCTRS (DefaultVId)

----------------------------------------------------------------------------------------------------
-- types
----------------------------------------------------------------------------------------------------

{- | Type representing a single entry of a matrix. It wraps an integer-sorted
SMT expression ('SExpr'), so that matrices of 'IntVar's can be combined
symbolically through the 'Num' instance below.
-}
newtype IntVar = IntVar SExpr
  deriving (Eq, Ord)

instance ToSExpr IntVar where
  toSExpr (IntVar sexpr) = sexpr

-- | Every 'IntVar' has integer sort.
instance Sorted IntVar where
  sort = const intSort

{- | Arithmetic on 'IntVar' builds the corresponding SMT terms instead of
computing numbers. 'signum' has no counterpart and raises an error.
-}
instance Num IntVar where
  IntVar s1 + IntVar s2 = IntVar $ SMT.add s1 s2
  IntVar s1 * IntVar s2 = IntVar $ SMT.mul s1 s2
  abs (IntVar s1) = IntVar $ SMT.abs s1
  signum _ = error "MatrixInterpretationsConstraints.hs: signum not implemented for IntVar."

  -- 'SMT.int' renders negative literals as @(- n)@. A plain 'show' would
  -- produce the token @-n@, which is not a valid SMT-LIB numeral.
  fromInteger = intVar . SMT.int
  negate (IntVar s1) = IntVar $ SMT.neg s1

-- | Smart constructor for 'IntVar'. It currently wraps the expression without any checks.
intVar :: SExpr -> IntVar
intVar = IntVar

{- | 'freshIntVar' draws a fresh number from the 'MonadFresh' supply and turns
it into a new integer-sorted SMT variable, wrapped as an 'IntVar'.
-}
freshIntVar :: (MonadFresh fresh) => fresh IntVar
freshIntVar = fmap toIntVar freshInt
 where
  toIntVar :: Int -> IntVar
  toIntVar n = intVar (toSExpr (freshV intSort n :: VId DefaultVId))

-- | Matrix whose entries are symbolic integer SMT terms ('IntVar').
type IntMatrix = Matrix IntVar

{- | Symbolic interpretation of a function symbol: one coefficient matrix per
argument (in argument order) and a constant matrix. The type parameter @f@
does not occur on the right-hand side.
-}
type MInterpretation f = ([IntMatrix], IntMatrix)

-- | Map assigning a symbolic interpretation to each function symbol.
type InterpretationMap f = M.Map f (MInterpretation f)

{- | Interpretation resulting from interpreting a term. The 'M.Map' assigns to
each variable the coefficient matrix it is multiplied with. The second
component is the constant matrix.
-}
type Interpretation v = (M.Map v IntMatrix, IntMatrix)

----------------------------------------------------------------------------------------------------
-- interpretations for matrix
----------------------------------------------------------------------------------------------------

{- | 'freshMatrix' @rows@ @cols@ returns a @rows@ x @cols@ matrix whose entries
are fresh integer variables (one 'freshIntVar' call per entry). The matrix is
square only when @rows == cols@.
-}
freshMatrix :: (MonadFresh m) => Dimension -> Dimension -> m IntMatrix
freshMatrix rows cols =
  DM.fromLists <$> replicateM (fromIntegral rows) (replicateM (fromIntegral cols) freshIntVar)

{- | 'freshMatrixInterpretations' @dim@ @fId@ returns a fresh matrix interpretation
for a standard function symbol @fId@. It consists of one @dim@ x @dim@
coefficient matrix per argument and a @dim@ x 1 constant vector. DP symbols
should use 'freshMatrixInterpretationsDP' instead.
-}
freshMatrixInterpretations :: (MonadFresh m) => Dimension -> FId f -> m ([IntMatrix], IntMatrix)
freshMatrixInterpretations dimension fId =
  (,) <$> replicateM (getArity fId) (freshMatrix dimension dimension) <*> freshMatrix dimension 1

{- | 'freshMatrixInterpretationsDP' @dim@ @fId@ returns a fresh matrix interpretation
for a DP symbol @fId@. It consists of one 1 x @dim@ coefficient matrix per argument
and a 1 x 1 constant, so a DP term is interpreted as a single entry.
-}
freshMatrixInterpretationsDP :: (MonadFresh m) => Dimension -> FId f -> m ([IntMatrix], IntMatrix)
freshMatrixInterpretationsDP dimension fId =
  (,) <$> replicateM (getArity fId) (freshMatrix 1 dimension) <*> freshMatrix 1 1

{- | 'freshMatrixInterpretationsBasedOnFId' @dim@ @fId@ returns a fresh matrix
interpretation for @fId@. It chooses 'freshMatrixInterpretationsDP' when
'isDPSym' holds for @fId@ and 'freshMatrixInterpretations' otherwise.
-}
freshMatrixInterpretationsBasedOnFId
  :: (MonadFresh m) => Dimension -> FId f -> m ([IntMatrix], IntMatrix)
freshMatrixInterpretationsBasedOnFId dim fId
  | isDPSym fId = freshMatrixInterpretationsDP dim fId
  | otherwise = freshMatrixInterpretations dim fId

{- | 'hasSameDimension' @m1@ @m2@ checks whether both matrices have the same
number of rows and the same number of columns. The element types may differ.
-}
hasSameDimension :: Matrix a -> Matrix b -> Bool
hasSameDimension m1 m2 = DM.nrows m1 == DM.nrows m2 && DM.ncols m1 == DM.ncols m2

{- | 'strictDecrease' @m1@ @m2@ encodes a strict decrease from @m1@ to @m2@.
The first entry (in 'DM.toList' order) must decrease strictly and every
remaining entry at least weakly. The encoding fails (via 'mzero') if the
dimensions differ or the matrices have no entries.
-}
strictDecrease :: IntMatrix -> IntMatrix -> EncodingMatrix f v SExpr
strictDecrease m1 m2
  | not (hasSameDimension m1 m2) = mzero -- SMT.const "false"
  | otherwise =
      case zip (DM.toList m1) (DM.toList m2) of
        [] -> mzero
        firstPair : rest ->
          return . SMT.andMany $
            uncurry strictDecreaseEntry firstPair : map (uncurry weakDecreaseEntry) rest

{- | 'weakDecrease' @m1@ @m2@ encodes an entry-wise weak decrease from @m1@ to @m2@.
The encoding fails (via 'mzero') if the dimensions differ.
-}
weakDecrease :: IntMatrix -> IntMatrix -> EncodingMatrix f v SExpr
weakDecrease m1 m2
  | hasSameDimension m1 m2 =
      return . SMT.andMany . zipWith weakDecreaseEntry (DM.toList m1) $ DM.toList m2
  | otherwise = mzero

-- | 'strictDecreaseEntry' @e1@ @e2@ encodes @e1 > e2@.
strictDecreaseEntry :: IntVar -> IntVar -> SExpr
strictDecreaseEntry v1 v2 = SMT.gt (toSExpr v1) (toSExpr v2)

-- | 'weakDecreaseEntry' @e1@ @e2@ encodes @e1 >= e2@.
weakDecreaseEntry :: IntVar -> IntVar -> SExpr
weakDecreaseEntry v1 v2 = SMT.geq (toSExpr v1) (toSExpr v2)

{- | 'positiveFirstEntry' @m@ returns an 'SExpr' stating that the entry at
position (1, 1) of @m@ is greater than 0. For a matrix without entries it
returns the constant false ('bot').
-}
positiveFirstEntry :: (ToSExpr a) => Matrix a -> SExpr
positiveFirstEntry m
  | DM.nrows m < 1 || DM.ncols m < 1 = bot
  | otherwise = SMT.gt (toSExpr (DM.getElem 1 1 m)) (SMT.const "0")

{- | 'positiveEntries' @upperBound@ @m@ returns an 'SExpr' stating that every
entry of @m@ is greater than or equal to 0. If a bound @b@ is given, every
entry must additionally be at most \(2^b - 1\), i.e. fit into @b@ bits.
-}
positiveEntries
  :: (ToSExpr a)
  => Maybe EntryUpperBound
  -> Matrix a
  -> SExpr
positiveEntries upperB m = SMT.andMany $ map entryConstraint $ DM.toList m
 where
  nonNegative e = toSExpr e `SMT.geq` SMT.const "0"

  entryConstraint e =
    case upperB of
      Nothing -> nonNegative e
      Just bound -> nonNegative e `SMT.and` (SMT.int (maxEntry bound) `SMT.geq` toSExpr e)

  -- Largest number representable with @bound@ bits. It is computed as an
  -- 'Integer' (the argument type of 'SMT.int'), so large bounds cannot overflow.
  maxEntry bound = (2 :: Integer) ^ (fromIntegral bound :: Int) - 1

-- | Constant 'SExpr' representing the Boolean literal @true@.
top :: SExpr
top = SMT.const "true"

-- | Constant 'SExpr' representing the Boolean literal @false@.
bot :: SExpr
bot = SMT.const "false"

----------------------------------------------------------------------------------------------------
-- interpretations for rules
----------------------------------------------------------------------------------------------------

-- | State used in the encoding monad transformer 'EncodingMatrix'.
data MatrixCache f v = MatrixCache
  { _cDim :: !Dimension
  -- ^ Matrix dimension; not read by any code visible in this module.
  , upperBound :: !(Maybe EntryUpperBound)
  -- ^ Optional bound on the number of bits of each matrix entry (see 'positiveEntries').
  , cValInter :: !(M.Map (Rule f v) IntMatrix)
  -- ^ Matrix interpreting the values occurring in each rule.
  , cInters :: !(InterpretationMap f)
  -- ^ Symbolic interpretation of each function symbol.
  }

{- | Encoding monad. It reads the interpretations cached in 'MatrixCache'
(via 'State') and can abort an encoding early via 'MaybeT'. This happens,
for example, when two matrices of different dimensions are compared.
-}
type EncodingMatrix f v = MaybeT (State (MatrixCache f v))

{- | 'entryRangeConstraints' returns the constraints that every entry of every
cached symbol interpretation and of every cached value interpretation is a
natural number. The entries are also bounded from above if an upper bound is
configured (see 'positiveEntries'). The first list element covers the symbol
interpretations, the second one the value interpretations.
-}
entryRangeConstraints :: EncodingMatrix f v [SExpr]
entryRangeConstraints = do
  inters <- gets cInters
  valInters <- gets cValInter
  uB <- gets upperBound
  let symbolEntries =
        SMT.andMany
          [ SMT.andMany [positiveEntries uB m' | m' <- ms] `SMT.and` positiveEntries uB m
          | (ms, m) <- M.elems inters
          ]
      valueEntries = SMT.andMany [positiveEntries uB m | m <- M.elems valInters]
  return [symbolEntries, valueEntries]

{- | 'prepareEncoding' @rules@ @decreasingness@ prepares the encoding for all rules
with respect to the given decreasingness function. It returns a tuple:

* The first component states that all entries of the interpretations are natural
  numbers (see 'entryRangeConstraints'). It also requires that the entry at
  position (1, 1) of every argument coefficient matrix is positive.
* The second component contains the decreasingness encoding of each rule, in
  the order of @rules@.
-}
prepareEncoding
  :: (Ord v, Ord f, ToSExpr v, ToSExpr f)
  => [Rule (FId f) (VId v)]
  -> ((Interpretation (VId v), Interpretation (VId v), SExpr) -> EncodingMatrix (FId f) (VId v) SExpr)
  -> EncodingMatrix (FId f) (VId v) (SExpr, [SExpr])
prepareEncoding rules decreasingness = do
  inters <- gets cInters
  let positiveFirstEntries =
        SMT.andMany [SMT.andMany $ map positiveFirstEntry ms | (ms, _) <- M.elems inters]
  ranges <- entryRangeConstraints
  encodings <- mapM (decreasingness <=< interpretRule) rules
  return (SMT.andMany (positiveFirstEntries : ranges), encodings)

{- | 'prepareEncodingDP' @rules@ @decreasingness@ behaves like 'prepareEncoding', but
applies a relaxation that is valid only within the DP framework. It omits the
requirement that the (1, 1) entries of the coefficient matrices are positive. The
first component therefore only contains 'entryRangeConstraints'. The second
component contains the decreasingness encoding of each rule.
-}
prepareEncodingDP
  :: (Ord v, Ord f, ToSExpr v, ToSExpr f)
  => [Rule (FId f) (VId v)]
  -> ((Interpretation (VId v), Interpretation (VId v), SExpr) -> EncodingMatrix (FId f) (VId v) SExpr)
  -> EncodingMatrix (FId f) (VId v) (SExpr, [SExpr])
prepareEncodingDP rules decreasingness = do
  ranges <- entryRangeConstraints
  encodings <- mapM (decreasingness <=< interpretRule) rules
  return (SMT.andMany ranges, encodings)

{- | 'simpleCalcRules' @calcRules@ @decreasingness@ returns the conjunction of the
decreasingness encodings of all calculation rules. The calculation rules need to
be (weakly or strictly) simple. Logical variables are interpreted as constants
(see 'interpretValue'), so this is enforced by orienting the calculation rules
explicitly.
-}
simpleCalcRules
  :: (ToSExpr v, ToSExpr f, Ord f, Ord v)
  => [Rule (FId f) (VId v)]
  -> ((Interpretation (VId v), Interpretation (VId v), SExpr) -> EncodingMatrix (FId f) (VId v) SExpr)
  -> EncodingMatrix (FId f) (VId v) SExpr
simpleCalcRules calcRules decreasingness =
  fmap SMT.andMany . traverse encodeRule $ calcRules
 where
  encodeRule rule = interpretRule rule >>= decreasingness

{- | 'constructCorrectValInterMap' @dim@ @rules@ assigns to every rule a @dim@ x 1
matrix of fresh integer variables. This matrix interprets the values occurring
in that rule.

* If the guard of any rule is classified as open by 'isOpenConstraint', all
  rules share one single fresh matrix.
* Otherwise, each rule first receives its own fresh candidate matrix. The rules
  whose guards 'overlappingConstraint' reports as overlapping are then merged
  with the union structure from "Control.Monad.Union". Every rule of a merged
  class receives the matrix kept for that class.
-}
constructCorrectValInterMap
  :: (Ord v, Ord f, ToSExpr f, ToSExpr v)
  => Dimension
  -> [Rule (FId f) (VId v)]
  -> StateM (M.Map (Rule (FId f) (VId v)) IntMatrix)
constructCorrectValInterMap dim rules = do
  -- One 'isOpenConstraint' query per rule guard.
  openConstraintFound <- or <$> mapM (isOpenConstraint . R.guard) rules
  if openConstraintFound
    then do
      fM <- freshMatrix dim 1
      return $ M.fromList [(rule, fM) | rule <- rules]
    else do
      -- For every rule, collect all rules (including itself) whose guard overlaps
      -- with its guard. This issues one 'overlappingConstraint' query per ordered
      -- pair of rules.
      overlaps <-
        mapM
          ( \rule ->
              (rule,) <$> filterM (overlappingConstraint rule) rules
          )
          rules
      -- Initial candidate matrix for every rule.
      freshMs <- mapM (\rule -> (rule,) <$> freshMatrix dim 1) rules
      -- Start from an empty rule->node map and the candidate matrices, and merge
      -- overlapping rules. Every rule gets a node because 'go' is run for every rule.
      let (union, (_, (ruleNodes, _))) = evalEquiv (M.empty, M.fromList freshMs) $ mapM (uncurry go) overlaps
      -- Each rule receives the value stored for its class.
      return $
        M.fromList [(rule, fM) | rule <- rules, let (_, fM) = DUn.lookup union (ruleNodes M.! rule)]
 where
  -- Merge the node of @rule@ with the node of every rule in the given list.
  -- The combining function keeps its first argument's matrix and drops the other.
  go rule rules = do
    node <- nodeOfRule rule
    mapM_ (merge (\fM _ -> (fM, ())) node <=< nodeOfRule) rules

  -- NOTE(review): assumes 'run'' returns the final union structure paired with
  -- the computation's result (here: result list and final state) — confirm.
  evalEquiv state computation = run' $ runStateT computation state

  -- Look up the node of a rule, or create it on first use. A new node is
  -- initialised with the rule's candidate matrix from the second state component.
  nodeOfRule r = do
    ruleMap <- gets fst
    fMMap <- gets snd
    case M.lookup r ruleMap of
      Nothing -> do
        node <- new $ fMMap M.! r
        modify $ first (M.insert r node)
        return node
      Just n -> return n

{- | 'executeMatrixInterpretations' @dim@ @upperBound@ @rules@ @encoding@ searches
for concrete matrix interpretations. It proceeds as follows:

1. Create fresh interpretations for all non-value symbols occurring in @rules@.
2. Create the calculation rules of the theory symbols (other than values) that
   occur in @rules@.
3. Create fresh value matrices for @rules@ and those calculation rules
   (see 'constructCorrectValInterMap').
4. Build the constraints with @encoding@, which receives the calculation rules.
5. Ask the SMT solver whether the conjunction of all constraints is valid.

On success it returns the concrete integer matrices taken from the solver's
assignment. It returns 'Nothing' in three cases: the encoding itself fails, the
solver does not report validity, or some matrix entry has no value in the
returned assignment.
-}
executeMatrixInterpretations
  :: (Ord v, Ord f, ToSExpr f, ToSExpr v)
  => Dimension
  -> Maybe EntryUpperBound
  -> [Rule (FId f) (VId v)]
  -> ( [Rule (FId f) (VId v)]
       -> EncodingMatrix (FId f) (VId v) [SExpr]
     )
  -> StateM
      ( Maybe
          ( M.Map (FId f) ([Matrix Integer], Matrix Integer)
          , M.Map (Rule (FId f) (VId v)) (Matrix Integer)
          )
      )
executeMatrixInterpretations dim upperB rules encoding = do
  -- Fresh symbolic interpretations; DP symbols get the 1-row variant.
  inters <- mapM (\f -> (f,) <$> freshMatrixInterpretationsBasedOnFId dim f) nonValSyms
  calcRules <- mapM calcRuleFromSymbol nonValThSyms
  valInters <- constructCorrectValInterMap dim $ calcRules <> rules
  -- Run the pure encoding on the cache. 'Nothing' means the encoding itself failed.
  case flip evalState (MatrixCache dim upperB valInters $ M.fromList inters) . runMaybeT $
    encoding calcRules of
    Nothing -> return Nothing
    Just encodings -> do
      -- Logical and extra variables of all rules (including calculation rules).
      let logVars = foldMap (\rule -> R.lvar rule <> R.extraVars rule) $ rules <> calcRules
      -- All matrix entries of the value and symbol interpretations.
      let allVars =
            concatMap DM.toList (M.elems valInters)
              <> concatMap (\(_, (ms, m)) -> concatMap DM.toList ms <> DM.toList m) inters
      -- Earlier satisfiability-based variant, kept for reference:
      -- (isSat, assignedValueList) <- satSExprsResultsTactic (S.fromList allVars <> logVars) SMT.check encodings
      -- case isSat of
      solver <- getSolver
      -- Scope the declarations below; they are removed again by 'SMT.pop'.
      -- NOTE(review): if anything in between throws, the pop is skipped and the
      -- solver stack is left unbalanced.
      liftIO $ SMT.push solver
      liftIO $ declareVariables solver logVars
      -- bool2nat maps a Boolean to 1/0; 'interpretValue' refers to it for
      -- Boolean logical variables.
      _ <-
        liftIO $
          SMT.defineFun
            solver
            "bool2nat"
            [("b", toSExpr boolSort)]
            (toSExpr intSort)
            (SMT.app (SMT.const "ite") [SMT.const "b", SMT.const "1", SMT.const "0"])
      -- Validity check: the negated conjunction is passed to the solver, and
      -- 'first (fmap not)' flips the reported result. NOTE(review): presumably
      -- @Just True@ then means the negation was unsatisfiable, i.e. the encoding
      -- is valid. How the matrix assignment is obtained in that case is defined in
      -- 'satSExprsResultsTactic' and is not visible here — confirm.
      (isValid, assignedValueList) <-
        first (fmap not)
          <$> satSExprsResultsTactic (S.fromList allVars) SMT.check ((: []) . SMT.not $ SMT.andMany encodings)
      liftIO $ SMT.pop solver
      case isValid of
        -- Translate the assignment back into integer matrices (in the 'Maybe'
        -- monad: any missing entry yields 'Nothing').
        Just True -> return $ do
          let assignedValues = M.fromList assignedValueList
          interAss <-
            M.fromList <$> mapM (\(f, msm) -> (f,) <$> assignedInterpretation assignedValues msm) inters
          interVals <- traverse (assignedMatrixValues assignedValues) valInters
          return (interAss, interVals)
        _ -> return Nothing
 where
  -- TODO this leads to an exception
  -- nonValSyms = S.toList $ getTheoryFuns lctrs <> getFuns lctrs
  -- nonValSyms = nubOrd $ foldMap R.funs $ weakrules <> strictrules
  -- Symbols that receive an interpretation: all non-value symbols in the rules.
  nonValSyms = nubOrd $ foldMap extractNonValsRule rules
  -- Theory symbols (not values) in the rules, for which calculation rules are built.
  nonValThSyms = nubOrd $ foldMap extractNonValThsRule rules

{- | Alternative check function for the Z3 solver. It issues @check-sat-using@
with a @par-or@ combination of two tactics:

* @simplify@ followed by @nlsat@ (non-linear arithmetic);
* @nla2bv@ with bit-vector size 8, followed by @simplify@, @bit-blast@ and @sat@.

Any answer other than @sat@, @unsat@ or @unknown@ is reported via 'fail'.
Nothing in the code visible here references this function.
-}
_bvCheck :: SMT.Solver -> IO SMT.Result
_bvCheck solver = SMT.command solver checkSatUsing >>= interpretAnswer
 where
  app name = SMT.app (SMT.const name)

  nonLinear = app "then" [SMT.const "simplify", SMT.const "nlsat"]

  bitBlasted =
    app
      "then"
      [ app "using-params" [SMT.const "nla2bv", SMT.const ":nla2bv_bv_size", SMT.const "8"]
      , SMT.const "simplify"
      , SMT.const "bit-blast"
      , SMT.const "sat"
      ]

  checkSatUsing = app "check-sat-using" [app "par-or" [nonLinear, bitBlasted]]

  interpretAnswer answer =
    case answer of
      SMT.Atom "unsat" -> return SMT.Unsat
      SMT.Atom "unknown" -> return SMT.Unknown
      SMT.Atom "sat" -> return SMT.Sat
      _ ->
        fail $
          unlines
            [ "Unexpected result from the SMT solver:"
            , "  Expected: unsat, unknown, or sat"
            , "  Result: " ++ SMT.showsSExpr answer ""
            ]

{- | 'matrixInterpretationsCons' @dim@ @upperBound@ @rules@ tries to orient all
@rules@ strictly from left to right using matrix interpretations of dimension
@dim@. The optional @upperBound@ limits the number of bits of the resulting
natural numbers. The calculation rules of the involved theory symbols must
decrease strictly as well.

On success it returns 'Terminating' together with the symbol and value
interpretations found. Otherwise it returns 'MaybeTerminating' and 'Nothing'.
-}
matrixInterpretationsCons
  :: (Ord v, Ord f, ToSExpr v, ToSExpr f)
  => Dimension
  -> Maybe EntryUpperBound
  -> [Rule (FId f) (VId v)]
  -> StateM
      ( SNResult
      , Maybe
          (M.Map (FId f) ([Matrix Integer], Matrix Integer), M.Map (Rule (FId f) (VId v)) (Matrix Integer))
      )
matrixInterpretationsCons dim upperB rules =
  toSNResult <$> executeMatrixInterpretations dim upperB rules encoding
 where
  encoding calcRules = do
    calcEnc <- simpleCalcRules calcRules strictDecreasing
    (positivity, decreases) <- prepareEncoding rules strictDecreasing
    return (calcEnc : positivity : decreases)

  toSNResult Nothing = (MaybeTerminating, Nothing)
  toSNResult (Just found) = (Terminating, Just found)

{- | 'matrixInterpretationsConsDP' is the DP-framework counterpart of
'matrixInterpretationsCons'. It uses the relaxed encoding of 'prepareEncodingDP',
which is valid only within the DP framework. Weak rules and calculation rules
must decrease weakly; strict rules must decrease strictly.

On success it returns 'Terminating' together with the interpretations found.
Otherwise it returns 'MaybeTerminating' and 'Nothing'.
-}
matrixInterpretationsConsDP
  :: (Ord v, Ord f, ToSExpr v, ToSExpr f)
  => Dimension
  -> Maybe EntryUpperBound
  -> DPProblem (FId f) (VId v)
  -> StateM
      ( SNResult
      , Maybe
          (M.Map (FId f) ([Matrix Integer], Matrix Integer), M.Map (Rule (FId f) (VId v)) (Matrix Integer))
      )
matrixInterpretationsConsDP dim upperB DPProblem{..} =
  toSNResult <$> executeMatrixInterpretations dim upperB (weakrules <> strictrules) encoding
 where
  encoding calcRules = do
    calcEnc <- simpleCalcRules calcRules weakDecreasing
    (weakPos, weakDecs) <- prepareEncodingDP weakrules weakDecreasing
    (strictPos, strictDecs) <- prepareEncodingDP strictrules strictDecreasing
    return (calcEnc : weakPos : strictPos : (weakDecs <> strictDecs))

  toSNResult = maybe (MaybeTerminating, Nothing) (\found -> (Terminating, Just found))

{- | 'matrixInterpretationsConsRP' uses matrix interpretations as a reduction pair.
The encoding requires the following:

* weak rules and calculation rules decrease weakly;
* every strict rule (dependency pair) decreases at least weakly;
* at least one strict rule decreases strictly.

On success it returns the remaining DP problem and a (partial) proof. The
remaining problem contains the strict rules that are not verified to decrease
strictly under the found interpretation, plus all weak rules. On failure it
returns the unchanged problem and 'Nothing'.
-}
matrixInterpretationsConsRP
  :: (Ord v, ToSExpr f, Ord f, ToSExpr v)
  => Dimension
  -> Maybe EntryUpperBound
  -> DPProblem (FId f) (VId v)
  -> StateM (DPProblem (FId f) (VId v), Maybe (SNInfo (FId f) (VId v)))
matrixInterpretationsConsRP dim upperB dpp@DPProblem{..} = do
  let encoding calcRules = do
        simpenc <- simpleCalcRules calcRules weakDecreasing
        (wpos, wdec) <- prepareEncodingDP weakrules weakDecreasing
        swdec <- mapM (weakDecreasing <=< interpretRule) strictrules
        (spos, sdec) <- prepareEncodingDP strictrules strictDecreasing
        -- Each strict rule decreases (strictly or weakly), and at least one strictly.
        return $
          simpenc : wpos : spos : SMT.andMany (zipWith SMT.or sdec swdec) : SMT.orMany sdec : wdec
  result <- executeMatrixInterpretations dim upperB (weakrules <> strictrules) encoding
  case result of
    Nothing -> return (dpp, Nothing)
    Just inters -> do
      let remaining = filter (not . orientedStrictly inters) strictrules
      return (DPProblem remaining weakrules, Just $ MatrixInterpretation dpp inters)
 where
  -- A strict rule counts as removed only if its concrete interpretation decreases
  -- strictly. A rule without a value interpretation is conservatively kept.
  orientedStrictly (ms, mv) rule =
    case M.lookup rule mv of
      Nothing -> False
      Just valInter ->
        let interpret = interpretTerm isLVar (ms, valInter)
         in correctDecrease (interpret $ R.lhs rule) (interpret $ R.rhs rule) (>)
   where
    isLVar v = v `S.member` R.lvar rule

{- | Collects all function symbols that are not values from the left- and
right-hand side of a rule (left-hand side first, duplicates kept).
-}
extractNonValsRule :: Rule a v -> [a]
extractNonValsRule rule = concatMap extractNonValsTerm [R.lhs rule, R.rhs rule]

-- | Collects all function symbols that are not values from a term (pre-order, duplicates kept).
extractNonValsTerm :: Term a v -> [a]
extractNonValsTerm term =
  case term of
    Val _ -> []
    TermFun f args -> f : concatMap extractNonValsTerm args
    TheoryFun f args -> f : concatMap extractNonValsTerm args
    _ -> []

{- | Collects all theory symbols that are not values from the left- and
right-hand side of a rule (left-hand side first, duplicates kept).
-}
extractNonValThsRule :: Rule a v -> [a]
extractNonValThsRule rule = concatMap extractNonValThsTerm [R.lhs rule, R.rhs rule]

{- | Collects all theory symbols that are not values from a term. Arguments of
non-theory symbols are still searched.
-}
extractNonValThsTerm :: Term a v -> [a]
extractNonValThsTerm term =
  case term of
    Val _ -> []
    TermFun _ args -> concatMap extractNonValThsTerm args
    TheoryFun f args -> f : concatMap extractNonValThsTerm args
    _ -> []

{- | Extracts the concrete interpretation of a symbol (argument matrices and
constant matrix) from a solver assignment. Returns 'Nothing' as soon as one
entry has no assigned value.
-}
assignedInterpretation
  :: (Traversable t1, Traversable t2, Traversable t3)
  => M.Map IntVar SMT.Value
  -> (t1 (t2 IntVar), t3 IntVar)
  -> Maybe (t1 (t2 Integer), t3 Integer)
assignedInterpretation assignedVals (ms, m) =
  (,) <$> traverse (assignedMatrixValues assignedVals) ms <*> assignedMatrixValues assignedVals m

{- | Extracts the concrete matrix from a solver assignment. Returns 'Nothing' if
any entry has no assigned value.
-}
assignedMatrixValues :: (Traversable t) => M.Map IntVar SMT.Value -> t IntVar -> Maybe (t Integer)
assignedMatrixValues assignedVals = traverse (assignedIntVarValue assignedVals)

{- | Extracts the value of a single 'IntVar' from a solver assignment. Returns
'Nothing' if the variable is not assigned. Calls 'error' if the assigned value
is not an integer.
-}
assignedIntVarValue :: M.Map IntVar SMT.Value -> IntVar -> Maybe Integer
assignedIntVarValue assignment entry = extractInt <$> M.lookup entry assignment
 where
  extractInt (SMT.Int i) = i
  extractInt _ =
    error "MatrixInterpretationsConstraints.hs: non-integer value assigned to an integer matrix entry."

{- | 'decreasingWith' @compareConst@ @(lhsInter, rhsInter, premise)@ encodes that
@premise@ implies both of the following:

* the constant parts decrease according to @compareConst@;
* for every variable of the right-hand side, its left-hand-side coefficient
  matrix is weakly greater than its right-hand-side one.

The encoding fails (via 'mzero') if a right-hand-side variable does not occur in
the left-hand-side interpretation, or if any of the comparisons fails.
-}
decreasingWith
  :: (Ord v)
  => (IntMatrix -> IntMatrix -> EncodingMatrix f v SExpr)
  -> (Interpretation v, Interpretation v, SExpr)
  -> EncodingMatrix f v SExpr
decreasingWith compareConst ((lhsVars, lhsConst), (rhsVars, rhsConst), premise) = do
  constEnc <- compareConst lhsConst rhsConst
  coeffEnc <- SMT.andMany <$> mapM (uncurry compareCoeff) (M.toList rhsVars)
  return $ SMT.implies premise $ constEnc `SMT.and` coeffEnc
 where
  compareCoeff v rhsCoeff =
    case M.lookup v lhsVars of
      Nothing -> mzero
      Just lhsCoeff -> weakDecrease lhsCoeff rhsCoeff

-- | Encodes that two interpretations are weakly decreasing (constant parts compared weakly).
weakDecreasing
  :: (Ord v) => (Interpretation v, Interpretation v, SExpr) -> EncodingMatrix f v SMT.SExpr
weakDecreasing = decreasingWith weakDecrease

-- | Encodes that two interpretations are strictly decreasing (constant parts compared strictly).
strictDecreasing
  :: (Ord v) => (Interpretation v, Interpretation v, SExpr) -> EncodingMatrix f v SMT.SExpr
strictDecreasing = decreasingWith strictDecrease

{- | 'interpretRule' @rule@ interprets both sides of @rule@. It uses the cached
symbol interpretations and the value matrix cached for @rule@. The result holds
the interpretations of the left- and right-hand side, together with an 'SExpr'
for the rule's guard that the decrease encodings use as premise. The encoding
fails (via 'empty') if no value matrix is cached for @rule@.
-}
interpretRule
  :: (Ord f, Eq v, ToSExpr v, ToSExpr f, Ord v)
  => Rule (FId f) (VId v)
  -> EncodingMatrix (FId f) (VId v) (Interpretation (VId v), Interpretation (VId v), SExpr)
interpretRule rule = do
  symbolInters <- gets cInters
  cachedValInter <- gets (M.lookup rule . cValInter)
  case cachedValInter of
    Nothing -> empty
    Just valInter -> do
      let interpret = interpretTermC isLVar (symbolInters, valInter)
      return (interpret (R.lhs rule), interpret (R.rhs rule), premise)
 where
  isLVar v = v `S.member` R.lvar rule
  -- NOTE(review): every conjunct is the tautology @v = v@, so this part never
  -- constrains anything. Confirm whether a different constraint on the extra
  -- variables was intended.
  extraVarsEnc = SMT.andMany [SMT.eq (toSExpr v) (toSExpr v) | v <- S.toList $ R.extraVars rule]
  premise = extraVarsEnc `SMT.and` toSExpr (collapseGuardToTerm $ R.guard rule)

{- | 'interpretTermC' @isLVar@ @(inters, valInter)@ @t@ interprets the term @t@ as
an affine combination of its variables:

* A logical variable (for which @isLVar@ holds) is handled by 'interpretValue'.
  It contributes only to the constant part.
* Any other variable @x@ yields the identity as its coefficient matrix and a zero
  constant vector.
* A value is interpreted by @valInter@ and contributes no variables.
* A function application multiplies the interpretation of each argument with the
  corresponding coefficient matrix from @inters@. It then sums the results and
  adds the symbol's constant matrix.

Calls 'error' if a function symbol has no entry in @inters@.
-}
interpretTermC
  :: (Sorted v, ToSExpr v, Ord f, Ord v, ToSExpr f)
  => (v -> Bool)
  -> (M.Map f ([Matrix IntVar], Matrix IntVar), Matrix IntVar)
  -> Term f v
  -> (M.Map v (Matrix IntVar), Matrix IntVar)
interpretTermC isLVar (_, m) (T.Var v)
  | isLVar v = interpretValue d v
  | otherwise = (M.singleton v (DM.identity d), DM.zero d 1)
 where
  d = DM.nrows m
interpretTermC _ (_, m) (T.Val _) = (M.empty, m)
interpretTermC isLVar (msM, m) (T.Fun _ f args) =
  case M.lookup f msM of
    Just (ms, c) ->
      let (varParts, constParts) = unzip $ zipWith mulCoeff ms interpretedArguments
       in (M.unionsWith (+) varParts, foldr (+) c constParts)
    Nothing ->
      error $
        "MatrixInterpretationsConstraints.hs: interpretation missing for function symbol "
          <> show (toSExpr f)
          <> "."
 where
  interpretedArguments = map (interpretTermC isLVar (msM, m)) args
  -- Multiply the argument's variable coefficients and its constant part with the
  -- argument's coefficient matrix.
  mulCoeff c (vm, mc) = (fmap (c *) vm, c * mc)

{- | 'interpretValue' @dim@ @v@ interprets the logical variable @v@ as a constant
@dim@ x 1 vector. All entries of the vector equal a natural-number encoding of
@v@, chosen by its sort:

* Bool: via the solver function @bool2nat@.
* Int: the absolute value.
* Real: the absolute value of its integer conversion.
* Bit-vector: via @bv2nat@.

The variable map of the result is empty. Calls 'error' for any other sort.
-}
interpretValue :: (Sorted v, ToSExpr v) => Int -> v -> (M.Map v (Matrix IntVar), Matrix IntVar)
interpretValue dim v = (M.empty, DM.fromList dim 1 (replicate dim (intVar asNat)))
 where
  s = sort v
  x = toSExpr v
  asNat
    | s == boolSort = SMT.app (SMT.const "bool2nat") [x]
    | s == intSort = SMT.abs x
    | s == realSort = SMT.abs (SMT.toInt x)
    | isBVSort s = SMT.app (SMT.const "bv2nat") [x]
    | otherwise =
        error $
          "MatrixInterpretationCons.hs: interpretation for value " <> show x <> " not implemented."

----------------------------------------------------------------------------------------------------
-- FIXME maybe move to Rule/Guard module.
----------------------------------------------------------------------------------------------------

{- | 'overlappingConstraint' @rule1@ @rule2@ checks whether the guards of the two
rules can share a value. It asks the solver whether both guards hold together
with at least one equality between a variable of the first guard and a
same-sorted variable of the second guard. The result is 'False' only if
'satSExpr' yields @Just False@ (presumably: unsatisfiable); all other outcomes
count as overlapping.
-}
overlappingConstraint
  :: forall f v
   . (Ord v, Eq f, ToSExpr v, ToSExpr f)
  => Rule (FId f) (VId v)
  -> Rule (FId f) (VId v)
  -> StateM Bool
overlappingConstraint rule1 rule2 = do
  outcome <- satSExpr (S.fromList (vars1 ++ vars2)) query
  return (outcome /= Just False)
 where
  guard1 = R.guard rule1
  guard2 = R.guard rule2
  vars1 = nubOrd (varsGuard guard1)
  vars2 = nubOrd (varsGuard guard2)

  sharedValue =
    [ T.eq @f (sortAnnotation [sort v1, sort v2] boolSort) (var v1) (var v2)
    | v1 <- vars1
    , v2 <- vars2
    , sort v1 == sort v2
    ]

  query =
    SMT.andMany
      [ SMT.orMany (map toSExpr sharedValue)
      , toSExpr (collapseGuardToTerm guard1)
      , toSExpr (collapseGuardToTerm guard2)
      ]

{- | 'isOpenConstraint' @constraint@ is intended to check whether all possible
values of the domain appear in some satisfying substitution of @constraint@.

The encoding works as follows. For every variable @x@ of the guard, a fresh
variable @y_s@ of the same sort @s@ is introduced; there is one fresh variable per
sort, shared by all variables of that sort. The solver is then asked whether
@(x1 /= y_s1) and ... and (xn /= y_sn) and constraint@ is satisfiable. The result
is 'False' only if 'satSExpr' yields @Just False@ (presumably: unsatisfiable).
Every other outcome yields 'True'.

NOTE(review): the fresh variables are only existentially quantified by a plain
satisfiability check. This appears to return 'True' for any guard that can be
satisfied with every variable differing from some value of its sort, not only
for guards that admit every value. Confirm that this over-approximation is
intended.
-}
isOpenConstraint
  :: forall f v. (Ord v, ToSExpr v, ToSExpr f, Eq f) => Guard (FId f) (VId v) -> StateM Bool
isOpenConstraint guard = do
  let vars = nubOrd $ varsGuard guard
  -- @m@ maps each sort to its fresh variable; these are declared alongside @vars@.
  (constraint, m) <- runStateT (buildConstraints vars) M.empty
  res <-
    satSExpr (S.fromList $ vars <> M.elems m) $ constraint `SMT.and` toSExpr (collapseGuardToTerm guard)
  case res of
    Just False -> return False
    _ -> return True
 where
  -- Conjunction of @x /= y_s@ over all variables @x@ (ending in 'top').
  buildConstraints [] = return top
  buildConstraints (v : vs) = do
    let so = sort v
    fV <- getFreshVar so
    let equTerm = T.neg $ T.eq (sortAnnotation [so, so] boolSort) (var v) (var fV)
    (toSExpr (equTerm :: Term (FId f) (VId v)) `SMT.and`) <$> buildConstraints vs

  -- Returns the cached fresh variable for a sort, creating and caching it on first use.
  getFreshVar :: (MonadFresh m) => Sort -> StateT (M.Map Sort (VId v)) m (VId v)
  getFreshVar so = do
    cache <- get
    case M.lookup so cache of
      Just v -> return v
      Nothing -> do
        fV <- freshV so <$> lift freshInt
        modify $ M.insert so fV
        return fV
