{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

{- |
Module      : Analysis.Termination.MatrixInterpretationsBV
Description : Termination analysis by matrix interpretations over bitvectors
Copyright   : (c) Jonas Schöpf, 2024
License     : GPL-3
Maintainer  : jonas.schoepf@uibk.ac.at
Stability   : experimental


This module provides the main types and functions to show termination
by matrix interpretations with bitvectors.
-}
module Analysis.Termination.MatrixInterpretationsBV (
  matrixInterpretationsBV,
  matrixInterpretationsBVDP,
  matrixInterpretationsBVRP,
  Bits,
)
where

----------------------------------------------------------------------------------------------------
-- imports
----------------------------------------------------------------------------------------------------

import Analysis.Termination.DependencyPairs (DPProblem (..), isDPSym)
import Analysis.Termination.Termination (
  SNInfo (MatrixInterpretation),
  SNResult (MaybeTerminating, Terminating),
  correctDecrease,
  interpretTerm,
  prettyMPoly,
  prettyMatrix,
  wrapInter,
 )
import Control.Applicative (empty)
import Control.Arrow (first)
import Control.Monad (MonadPlus (mzero), filterM, replicateM, zipWithM, (<=<))
import Control.Monad.IO.Class (liftIO)
import Control.Monad.State (StateT, evalStateT, lift, runStateT)
import Control.Monad.State.Class (MonadState (..), gets, modify)
import Control.Monad.Trans.Maybe (MaybeT (runMaybeT))
import Control.Monad.Union (MonadUnion (merge, new), run')
import Data.Char (intToDigit)
import Data.Containers.ListUtils (nubOrd)
import Data.Foldable (foldrM)
import Data.LCTRS.FIdentifier (FId, getArity)
import Data.LCTRS.Guard (Guard, collapseGuardToTerm, varsGuard)
import Data.LCTRS.Rule (Rule)
import qualified Data.LCTRS.Rule as R
import Data.LCTRS.Sort (Sort, Sorted (sort), sortAnnotation)
import Data.LCTRS.Term (
  Term,
  var,
  pattern TermFun,
  pattern TheoryFun,
  pattern Val,
 )
import qualified Data.LCTRS.Term as T
import Data.LCTRS.VIdentifier (VId, freshV)
import qualified Data.Map.Strict as M
import Data.Matrix (Matrix)
import qualified Data.Matrix as DM
import Data.Maybe (fromJust)
import Data.Monad (MonadFresh, StateM, freshInt, getSolver)
import Data.SExpr (ToSExpr (toSExpr))
import Data.SMT (boolSort, bvSort)
import qualified Data.Set as S
import qualified Data.Union as DUn
import Data.Word (Word32, Word8)
import Numeric (showIntAtBase)
import Pretty.Box (boxString, leftOf)
import Prettyprinter (Doc, Pretty (..), encloseSep, indent, line, vsep)
import Rewriting.CriticalPair (calcRuleFromSymbol)
import Rewriting.SMT (declareVariables, satSExpr, satSExprsResults)
import SimpleSMT (SExpr)
import qualified SimpleSMT as SMT
import Type.SortLCTRS (DefaultVId)
import Utils (monadicUnionsWith)

----------------------------------------------------------------------------------------------------
-- types
----------------------------------------------------------------------------------------------------

-- | Type for the dimension of matrices. Coefficient matrices of non-DP symbols
-- are @dim@ × @dim@ and their constant parts @dim@ × 1
-- (see 'freshMatrixInterpretations').
type Dimension = Word8

-- | Type for the amount of bits in the matrix entries, i.e. the width of the
-- bitvectors used for every entry of the encoding.
type Bits = Word32

{- | Type representing entries in the matrices. This
type wraps an 'SExpr' and saves the number of bits used,
i.e. the bitvector width of the entry.
-}
data BVVar = BVVar !Bits !SExpr
  deriving (Eq, Ord)

-- The SMT expression of an entry; the stored width is dropped.
instance ToSExpr BVVar where
  toSExpr (BVVar _ sexpr) = sexpr

-- The sort of an entry is the bitvector sort of its stored width.
instance Sorted BVVar where
  sort (BVVar bits _) = bvSort (fromIntegral bits)

infixl 6 |+|

-- | Symbolic bitvector addition of two 'BVVar's. Both operands must have the
-- same width, which is also the width of the result; otherwise 'error' is called.
(|+|) :: BVVar -> BVVar -> BVVar
BVVar w1 e1 |+| BVVar w2 e2
  | w1 /= w2 = error "MatrixInterpretationsBV.hs: bitvector addition with different number of bits."
  | otherwise = BVVar w1 (SMT.bvAdd e1 e2)

infixl 7 |*|

-- | Symbolic bitvector multiplication of two 'BVVar's. Both operands must have
-- the same width, which is also the width of the result; otherwise 'error' is called.
(|*|) :: BVVar -> BVVar -> BVVar
BVVar w1 e1 |*| BVVar w2 e2
  | w1 /= w2 =
      error "MatrixInterpretationsBV.hs: bitvector multiplication with different number of bits."
  | otherwise = BVVar w1 (SMT.bvMul e1 e2)

-- instance Num BVVar where
--   (+) = (|+|)
--   (*) = (|*|)
--   abs _ = error "MatrixInterpretationBV.hs: abs not implemented for BVVar."
--   signum _ = error "MatrixInterpretationBV.hs: signum not implemented for BVVar."
--   fromInteger i =
--     let bv = showIntAtBase 2 intToDigit i ""
--     in  BVVar (fromIntegral $ length bv) $ SMT.const $ "#b" <> bv
--   -- error "MatrixInterpretationBV.hs: fromInteger not implemented for BVVar."
--   negate _ = error "MatrixInterpretationBV.hs: negate not implemented for BVVar."

-- | Smart constructor for 'BVVar': pairs an SMT expression with its bit width.
bvVar :: Bits -> SExpr -> BVVar
bvVar width expr = BVVar width expr

{- | 'freshBVVar' @bits@ draws a fresh integer from the 'MonadFresh' context,
turns it into a fresh variable of bitvector sort of width @bits@ and wraps
its SMT expression into a 'BVVar' of that width.
-}
freshBVVar :: (MonadFresh fresh) => Bits -> fresh BVVar
freshBVVar bits = do
  n :: Int <- freshInt
  let variable = freshV (bvSort $ fromIntegral bits) n :: VId DefaultVId
  return $ bvVar bits (toSExpr variable)

{- | 'bvFromInteger' @i@ builds a binary bitvector literal (@#b...@) for @i@.
Its width is the number of bits stored in the encoding state ('bits'). The
binary digits of @i@ are left-padded with zeros. If @i@ needs more digits than
the width allows, the literal's text is an 'error' naming @i@ and the width.
@i@ is expected to be non-negative.

TODO improve this function (it naively goes through a 'String' of digits).
-}
bvFromInteger :: Int -> EncodingMatrix f v BVVar
bvFromInteger i = do
  width <- gets bits
  let digits = showIntAtBase 2 intToDigit i ""
      missing = fromIntegral width - length digits
      literal
        | missing < 0 =
            error $ "MatrixInterpretationBV.hs: cannot represent " <> show i <> " with " <> show width <> " bits."
        | otherwise = replicate missing '0' <> digits
  return . bvVar width . SMT.const $ "#b" <> literal

-- | 'BVVar' representing the integer 0, with the width stored in the
-- encoding state (see 'bvFromInteger').
zeroBVVar :: EncodingMatrix f v BVVar
zeroBVVar = bvFromInteger 0

-- | 'BVVar' representing the integer 1, with the width stored in the
-- encoding state (see 'bvFromInteger').
oneBVVar :: EncodingMatrix f v BVVar
oneBVVar = bvFromInteger 1

-- | Matrix whose entries are symbolic bitvector entries ('BVVar').
type BVMatrix = Matrix BVVar

{- | Plain interpretation for a function symbol. It consists
of one coefficient matrix per argument and a constant matrix.
The parameter @f@ is not used on the right-hand side (phantom).
-}
type MInterpretation f = ([BVMatrix], BVMatrix)

-- | Map which represents interpretations for different function symbols.
type InterpretationMap f = M.Map f (MInterpretation f)

{- | Interpretation resulting from interpreting terms. It consists of a
'M.Map' assigning a coefficient matrix to each (non-logical) variable
and a constant matrix.
-}
type Interpretation v = (M.Map v BVMatrix, BVMatrix)

----------------------------------------------------------------------------------------------------
-- interpretations for matrix
----------------------------------------------------------------------------------------------------

{- | 'freshMatrix' @bits@ @rows@ @cols@ returns a @rows@ × @cols@ matrix (not
necessarily square) whose entries are fresh 'BVVar's of width @bits@.

Entries are drawn in row-major order. The matrix is built with 'DM.fromList'
rather than 'DM.fromLists', so a dimension of 0 yields an empty matrix instead
of crashing.
-}
freshMatrix :: (MonadFresh m) => Bits -> Dimension -> Dimension -> m BVMatrix
freshMatrix bits dimensionRows dimensionCols =
  DM.fromList rows cols <$> replicateM (rows * cols) (freshBVVar bits)
 where
  rows = fromIntegral dimensionRows
  cols = fromIntegral dimensionCols

{- | 'freshMatrixInterpretations' @bits@ @dim@ @fId@ returns a fresh matrix
interpretation for a (non-DP) function symbol @fId@. The interpretation has one
@dim@ × @dim@ coefficient matrix per argument and a @dim@ × 1 constant matrix.
For DP symbols use 'freshMatrixInterpretationsDP' instead.
-}
freshMatrixInterpretations
  :: (MonadFresh m) => Bits -> Dimension -> FId f -> m ([BVMatrix], BVMatrix)
freshMatrixInterpretations bits dimension fId = do
  coefficients <- replicateM (getArity fId) (freshMatrix bits dimension dimension)
  constant <- freshMatrix bits dimension 1
  return (coefficients, constant)

{- | 'freshMatrixInterpretationsDP' @bits@ @dim@ @fId@ returns a fresh matrix
interpretation for a DP symbol @fId@. The interpretation has one 1 × @dim@
coefficient matrix per argument and a 1 × 1 constant matrix.
-}
freshMatrixInterpretationsDP
  :: (MonadFresh m) => Bits -> Dimension -> FId f -> m ([BVMatrix], BVMatrix)
freshMatrixInterpretationsDP bits dimension fId = do
  coefficients <- replicateM (getArity fId) (freshMatrix bits 1 dimension)
  constant <- freshMatrix bits 1 1
  return (coefficients, constant)

{- | 'freshMatrixInterpretationsBasedOnFId' @bits@ @dim@ @fId@ picks the right
kind of fresh interpretation for @fId@. DP symbols ('isDPSym') get
'freshMatrixInterpretationsDP'; all other symbols get 'freshMatrixInterpretations'.
-}
freshMatrixInterpretationsBasedOnFId
  :: (MonadFresh m) => Bits -> Dimension -> FId f -> m ([BVMatrix], BVMatrix)
freshMatrixInterpretationsBasedOnFId bits dim fId = chosen bits dim fId
 where
  chosen
    | isDPSym fId = freshMatrixInterpretationsDP
    | otherwise = freshMatrixInterpretations

-- | 'hasSameDimension' @m1@ @m2@ is 'True' iff both matrices have the same
-- number of rows and the same number of columns.
hasSameDimension :: Matrix a -> Matrix a -> Bool
hasSameDimension m1 m2 = shape m1 == shape m2
 where
  shape m = (DM.nrows m, DM.ncols m)

{- | 'strictDecrease' @m1@ @m2@ encodes a strict decrease from @m1@ to @m2@. The
first entries (in 'DM.toList' order) must decrease strictly and all remaining
entries weakly. The encoding fails ('mzero') if the matrices differ in dimension
or are empty.
-}
strictDecrease :: BVMatrix -> BVMatrix -> EncodingMatrix f v SExpr
strictDecrease m1 m2 =
  case (DM.toList m1, DM.toList m2) of
    (head1 : rest1, head2 : rest2)
      | hasSameDimension m1 m2 ->
          return . SMT.andMany $ strictDecreaseEntry head1 head2 : zipWith weakDecreaseEntry rest1 rest2
    _ -> mzero

-- | 'weakDecrease' @m1@ @m2@ encodes that every entry of @m1@ is greater than
-- or equal to the corresponding entry of @m2@. Fails ('mzero') if the
-- dimensions differ.
weakDecrease :: BVMatrix -> BVMatrix -> EncodingMatrix f v SExpr
weakDecrease m1 m2
  | not (hasSameDimension m1 m2) = mzero
  | otherwise = return . SMT.andMany $ zipWith weakDecreaseEntry (DM.toList m1) (DM.toList m2)

-- | 'strictDecreaseEntry' @e1@ @e2@ encodes @e2 < e1@ (unsigned bitvector comparison).
strictDecreaseEntry :: BVVar -> BVVar -> SExpr
strictDecreaseEntry v1 v2 = SMT.bvULt (toSExpr v2) (toSExpr v1)

-- | 'weakDecreaseEntry' @e1@ @e2@ encodes @e2 <= e1@ (unsigned bitvector comparison).
weakDecreaseEntry :: BVVar -> BVVar -> SExpr
weakDecreaseEntry v1 v2 = SMT.bvULeq (toSExpr v2) (toSExpr v1)

-- | 'positiveFirstEntry' @m@ encodes that the entry at position (1,1) of @m@ is
-- (unsigned) greater than 0. For a matrix without entries the result is @false@.
positiveFirstEntry :: (ToSExpr a) => Matrix a -> EncodingMatrix f v SExpr
positiveFirstEntry m
  | DM.nrows m <= 0 || DM.ncols m <= 0 = return bot
  | otherwise = do
      zero <- zeroBVVar
      return $ toSExpr zero `SMT.bvULt` toSExpr (m DM.! (1, 1))

-- | Constant 'SExpr' representing true (the SMT-LIB literal @true@).
top :: SExpr
top = SMT.const "true"

-- | Constant 'SExpr' representing false (the SMT-LIB literal @false@).
bot :: SExpr
bot = SMT.const "false"

----------------------------------------------------------------------------------------------------
-- interpretations for rules
----------------------------------------------------------------------------------------------------

-- | State used in the encoding monad transformer 'EncodingMatrix'.
data MatrixCache f v = MatrixCache
  { _cDim :: !Dimension
  -- ^ dimension of the interpretation (not read anywhere in this module)
  , bits :: !Bits
  -- ^ bit width of the entries; read by 'bvFromInteger'
  , cValInter :: !(M.Map (Rule f v) BVMatrix)
  -- ^ matrix interpreting the values of each rule (see 'constructCorrectValInterMap')
  , cInters :: !(InterpretationMap f)
  -- ^ interpretations of the non-value function symbols
  }

{- | Monad in which the encoding is built. The layers are:

  * 'MaybeT' lets an encoding fail early, e.g. 'strictDecrease' on matrices of
    different dimensions.
  * 'StateT' carries a 'MatrixCache' (dimension, bit width and the interpretations).
  * The underlying 'StateM' gives access to the SMT solver. 'bvAddM' and 'bvMulM'
    add their side constraints to that solver directly.
-}
type EncodingMatrix f v = MaybeT (StateT (MatrixCache f v) StateM)

{- | 'prepareEncoding' @rules@ @decreasingness@ prepares the encoding for all rules
with respect to the given decreasingness function. It returns a pair:

  * the conjunction stating that the upper-left entry of every argument
    coefficient matrix is positive, and
  * one decreasingness constraint per rule, in the order of @rules@.
-}
prepareEncoding
  :: (Ord v, Ord f, ToSExpr v, ToSExpr f)
  => [Rule (FId f) (VId v)]
  -> (Interpretation (VId v) -> Interpretation (VId v) -> EncodingMatrix (FId f) (VId v) SExpr)
  -> EncodingMatrix (FId f) (VId v) (SExpr, [SExpr])
prepareEncoding rules decreasingness = do
  inters <- gets cInters
  let coefficientMatrices = concatMap fst (M.elems inters)
  positivity <- mapM positiveFirstEntry coefficientMatrices
  decreases <- mapM (\rule -> interpretRule rule >>= uncurry decreasingness) rules
  return (SMT.andMany positivity, decreases)

{- | 'simpleCalcRules' @calcRules@ @decreasingness@ returns the conjunction of the
@decreasingness@ constraints of all calculation rules. This is needed because
calculation rules must be (weakly/strictly) simple. Simplicity is achieved by
interpreting logical variables as constants.
-}
simpleCalcRules
  :: (ToSExpr v, ToSExpr f, Ord f, Ord v)
  => [Rule (FId f) (VId v)]
  -> (Interpretation (VId v) -> Interpretation (VId v) -> EncodingMatrix (FId f) (VId v) SExpr)
  -> EncodingMatrix (FId f) (VId v) SExpr
simpleCalcRules calcRules decreasingness = do
  encodings <- mapM (\rule -> interpretRule rule >>= uncurry decreasingness) calcRules
  return (SMT.andMany encodings)

{- | 'constructCorrectValInterMap' @bits@ @dim@ @rules@ assigns to every rule a
@dim@ × 1 matrix of fresh 'BVVar's of width @bits@. This matrix interprets the
values admitted by the rule's guard.

Rules whose constraints overlap (transitively, via union-find) share the same
matrix, so a value is interpreted consistently across those rules. If the guard
of any rule is an open constraint ('isOpenConstraint'), all rules share a single
matrix.
-}
constructCorrectValInterMap
  :: (Ord v, Ord f, ToSExpr f, ToSExpr v)
  => Bits
  -> Dimension
  -> [Rule (FId f) (VId v)]
  -> StateM (M.Map (Rule (FId f) (VId v)) BVMatrix)
constructCorrectValInterMap bits dim rules = do
  openConstraintFound <- or <$> mapM (isOpenConstraint . R.guard) rules
  if openConstraintFound
    then do
      -- one shared value matrix for all rules
      fM <- freshMatrix bits dim 1
      return $ M.fromList [(rule, fM) | rule <- rules]
    else do
      -- for each rule, all rules (itself included) whose guard overlaps with its guard
      overlaps <-
        mapM
          ( \rule ->
              (rule,) <$> filterM (overlappingConstraint rule) rules
          )
          rules
      -- a candidate matrix per rule; the final map only uses the label of each
      -- rule's equivalence class
      freshMs <- mapM (\rule -> (rule,) <$> freshMatrix bits dim 1) rules
      -- union-find run: the state is (rule -> node map, rule -> candidate matrix);
      -- @ruleNodes@ is the rule -> node map of the final state
      let (union, (_, (ruleNodes, _))) = evalEquiv (M.empty, M.fromList freshMs) $ mapM (uncurry go) overlaps
      return $
        M.fromList [(rule, fM) | rule <- rules, let (_, fM) = DUn.lookup union (ruleNodes M.! rule)]
 where
  -- merge the node of @rule@ with the nodes of all overlapping rules; the label
  -- combiner keeps its first argument (presumably the label of @node@'s class)
  go rule rules = do
    node <- nodeOfRule rule
    mapM_ (merge (\fM _ -> (fM, ())) node <=< nodeOfRule) rules

  -- presumably returns the final union structure paired with the
  -- (result, final state) of the 'runStateT' computation
  evalEquiv state computation = run' $ runStateT computation state

  -- node of rule @r@; created on first use, labelled with @r@'s candidate matrix
  nodeOfRule r = do
    ruleMap <- gets fst
    fMMap <- gets snd
    case M.lookup r ruleMap of
      Nothing -> do
        node <- new $ fMMap M.! r
        modify $ first (M.insert r node)
        return node
      Just n -> return n

{- | Given a dimension, a number of bits, rules and an encoding, this function
constructs the encoding, checks it for satisfiability and extracts the resulting
interpretations. The steps are:

  1. Fresh interpretations are created for every non-value symbol occurring in
     @rules@ ('freshMatrixInterpretationsBasedOnFId').
  2. A calculation rule is built for each theory symbol occurring in @rules@
     ('calcRuleFromSymbol').
  3. Value interpretations are assigned to the calculation rules and @rules@
     ('constructCorrectValInterMap').
  4. All matrix entries are declared on the solver, and @encoding@ is run on the
     calculation rules.
  5. The resulting constraints are checked with 'satSExprsResults'.

Returns 'Nothing' in any of these cases:

  * the encoding fails, e.g. a rule without value interpretation, or a decrease
    between matrices of different dimensions;
  * the solver does not report satisfiable;
  * some entry has no value in the returned assignment.

NOTE(review): no push/pop happens here. The assertions that 'bvAddM'/'bvMulM'
make directly on the solver while running @encoding@ persist unless the caller
scopes them.
-}
executeMatrixInterpretations
  :: (Ord v, Ord f, ToSExpr f, ToSExpr v)
  => Dimension
  -> Bits
  -> [Rule (FId f) (VId v)]
  -> ( [Rule (FId f) (VId v)]
       -> EncodingMatrix (FId f) (VId v) [SExpr]
     )
  -> StateM
      ( Maybe
          ( M.Map (FId f) ([Matrix Integer], Matrix Integer)
          , M.Map (Rule (FId f) (VId v)) (Matrix Integer)
          )
      )
executeMatrixInterpretations dim bits rules encoding = do
  solver <- getSolver
  inters <- mapM (\f -> (f,) <$> freshMatrixInterpretationsBasedOnFId bits dim f) nonValSyms
  calcRules <- mapM calcRuleFromSymbol nonValThSyms
  valInters <- constructCorrectValInterMap bits dim $ calcRules <> rules
  -- every matrix entry is declared as a solver variable
  let allVars =
        concatMap DM.toList (M.elems valInters)
          <> concatMap (\(_, (ms, m)) -> concatMap DM.toList ms <> DM.toList m) inters
  liftIO $ declareVariables solver $ S.fromList allVars
  res <-
    evalStateT (runMaybeT (encoding calcRules)) (MatrixCache dim bits valInters $ M.fromList inters)
  case res of
    Nothing -> return Nothing
    Just encodings -> do
      (isSat, assignedValueList) <- satSExprsResults S.empty encodings
      case isSat of
        -- read back the assignment in the 'Maybe' monad: a missing value for
        -- any entry yields 'Nothing'
        Just True -> return $ do
          let assignedValues = M.fromList assignedValueList
          interAss <-
            M.fromList <$> mapM (\(f, msm) -> (f,) <$> assignedInterpretation assignedValues msm) inters
          interVals <- traverse (assignedMatrixValues assignedValues) valInters
          return (interAss, interVals)
        _ -> return Nothing
 where
  -- TODO this leads to an exception
  -- nonValSyms = S.toList $ getTheoryFuns lctrs <> getFuns lctrs
  -- nonValSyms = nubOrd $ foldMap R.funs $ weakrules <> strictrules
  -- all function symbols (term and theory) occurring in the rules, except values
  nonValSyms = nubOrd $ foldMap extractNonValsRule rules
  -- only the theory symbols among them
  nonValThSyms = nubOrd $ foldMap extractNonValThsRule rules

{- | 'matrixInterpretationsBV' @dim@ @bits@ @rules@ tries to orient all rules
strictly from left to right using matrix interpretations. The matrices have
dimension @dim@ and entries that are bitvectors of width @bits@.

The result is 'Terminating' together with the found interpretations, or
'MaybeTerminating' with 'Nothing' if none was found.
-}
matrixInterpretationsBV
  :: (Ord v, Pretty f, Ord f, ToSExpr v, Pretty v, ToSExpr f)
  => Dimension
  -> Bits
  -> [Rule (FId f) (VId v)]
  -> StateM
      ( SNResult
      , Maybe
          (M.Map (FId f) ([Matrix Integer], Matrix Integer), M.Map (Rule (FId f) (VId v)) (Matrix Integer))
      )
matrixInterpretationsBV dim bits rules =
  toResult <$> executeMatrixInterpretations dim bits rules encoding
 where
  -- Calculation rules and all given rules must decrease strictly. In addition,
  -- the upper-left entries of the coefficient matrices must be positive.
  encoding calcRules = do
    simpenc <- simpleCalcRules calcRules strictDecreasing
    (spos, sdec) <- prepareEncoding rules strictDecreasing
    return (simpenc : spos : sdec)

  toResult Nothing = (MaybeTerminating, Nothing)
  toResult (Just inters) = (Terminating, Just inters)

{- | 'matrixInterpretationsBVDP' does in principle the same as
'matrixInterpretationsBV', but with relaxations that are only valid within
the DP framework:

  * calculation rules and weak rules need only decrease weakly;
  * only the strict rules must decrease strictly;
  * no positivity of upper-left entries is required.
-}
matrixInterpretationsBVDP
  :: (Ord v, Pretty f, Ord f, ToSExpr v, Pretty v, ToSExpr f)
  => Dimension
  -> Bits
  -> DPProblem (FId f) (VId v)
  -> StateM
      ( SNResult
      , Maybe
          (M.Map (FId f) ([Matrix Integer], Matrix Integer), M.Map (Rule (FId f) (VId v)) (Matrix Integer))
      )
matrixInterpretationsBVDP dim bits DPProblem{..} =
  toResult <$> executeMatrixInterpretations dim bits (weakrules <> strictrules) encoding
 where
  encoding calcRules = do
    simpenc <- simpleCalcRules calcRules weakDecreasing
    (_, wdec) <- prepareEncoding weakrules weakDecreasing
    (_, sdec) <- prepareEncoding strictrules strictDecreasing
    return (simpenc : wdec <> sdec)

  toResult = maybe (MaybeTerminating, Nothing) (\inters -> (Terminating, Just inters))

{- | 'matrixInterpretationsBVRP' can be used as a reduction pair. It tries to
orient as many dependency pairs as possible strictly and returns a (partial)
proof together with the remaining DP problem to be further analyzed.

The encoding requires:

  * the calculation rules and the weak rules to decrease weakly;
  * every strict rule to decrease weakly or strictly;
  * at least one strict rule to decrease strictly.

The strict rules that the concrete interpretation orients strictly are then
removed (see @remainingDPP@).

Returns @(dpp, Nothing)@ if no interpretation is found or no strict rule could
be removed. Otherwise it returns the smaller DP problem and the proof
information. The solver assertions made while encoding are enclosed in a
@push@/@pop@ pair that is balanced on every non-exceptional path.
-}
matrixInterpretationsBVRP
  :: forall f v
   . (Ord v, ToSExpr f, Pretty f, Ord f, ToSExpr v, Pretty v)
  => Dimension
  -> Bits
  -> DPProblem (FId f) (VId v)
  -> StateM (DPProblem (FId f) (VId v), Maybe (SNInfo (FId f) (VId v)))
matrixInterpretationsBVRP dim upperB dpp@DPProblem{..} = do
  solver <- getSolver
  -- scope the assertions that 'bvAddM' / 'bvMulM' add while encoding
  liftIO $ SMT.push solver
  let encoding calcRules = do
        simpenc <- simpleCalcRules calcRules weakDecreasing
        (_, wdec) <- prepareEncoding weakrules weakDecreasing
        swdec <- mapM (uncurry weakDecreasing <=< interpretRule) strictrules
        (_, sdec) <- prepareEncoding strictrules strictDecreasing
        -- each strict rule decreases strictly or weakly; at least one strictly
        return $ simpenc : SMT.andMany (zipWith SMT.or sdec swdec) : SMT.orMany sdec : wdec
  result <- executeMatrixInterpretations dim upperB (weakrules <> strictrules) encoding
  -- pop unconditionally so the scope is also closed when no interpretation exists
  liftIO $ SMT.pop solver
  case result of
    Nothing -> return (dpp, Nothing)
    Just inters -> do
      -- NOTE(review): prints the found interpretation to stdout on every success
      liftIO $ print $ printInters inters
      let newDPP = remainingDPP dpp inters
      -- If the concrete check removes no strict rule, report failure instead of
      -- aborting the whole analysis.
      if newDPP == dpp
        then return (dpp, Nothing)
        else return (newDPP, Just $ MatrixInterpretation dpp inters)
 where
  printInters (ms, mv) =
    vsep $
      [ "Matrix Interpretations:"
      , indent 2 $
          "Non-Value Symbols:"
            <> line
            <> vsep
              [ pretty resultString <> line
              | (f, (cs, c)) <- M.toList ms
              , let ics = zip [1 :: Int ..] cs
              , let vars = map (("x" <>) . pretty . fst) ics
              , let args
                      | null vars = ""
                      | otherwise = encloseSep "(" ")" ", " vars
              , let funPrefix = show . wrapInter $ pretty f <> args
              , let poly = prettyMPoly (zip vars cs) c
              , let resultString = boxString funPrefix `leftOf` poly
              ]
      , indent 2 $
          "Value Symbols:"
            <> line
            <> vsep
              [ pretty resultString <> line
              | (r, c) <- M.toList mv
              , let resultString =
                      boxString (show $ "values satisfying guard of " <> R.prettyRule r <> " as ") `leftOf` prettyMatrix c
              ]
      , indent 2 $
          "Strict Rule Orientations:"
            <> line
            <> vsep
              ( map
                  (\rule -> (<> line) $ prettyRuleInters ms (fromJust $ M.lookup rule mv) True rule)
                  (strictrules)
              )
      , indent 2 $
          "Weak Rule Orientations:"
            <> line
            <> vsep
              ( map
                  (\rule -> (<> line) $ prettyRuleInters ms (fromJust $ M.lookup rule mv) False rule)
                  (weakrules)
              )
      ]

  -- keeps only the strict rules that the concrete interpretation does not
  -- orient strictly
  remainingDPP
    :: DPProblem (FId f) (VId v)
    -> (M.Map (FId f) ([Matrix Integer], Matrix Integer), M.Map (Rule (FId f) (VId v)) (Matrix Integer))
    -> DPProblem (FId f) (VId v)
  remainingDPP DPProblem{..} (ms, mv) = DPProblem remainingStrictRules weakrules
   where
    remainingStrictRules =
      filter
        ( \rule ->
            not $
              correctDecrease
                (interpretTerm (isLVar rule) (ms, fromJust $ M.lookup rule mv) $ R.lhs rule)
                (interpretTerm (isLVar rule) (ms, fromJust $ M.lookup rule mv) $ R.rhs rule)
                (>)
        )
        strictrules

    isLVar rule v = v `S.member` R.lvar rule

{- | Renders how @rule@ is oriented by the concrete interpretation @msM@ (for the
non-value symbols) together with the value matrix @m@. The output is the rule,
the interpreted left-hand side, a relation and the interpreted right-hand side.
The relation is @>@ (if @isStrict@) or @>=@ when 'correctDecrease' confirms the
decrease, and @?@ otherwise.
-}
prettyRuleInters
  :: (Ord f, Ord v, Pretty v, Pretty f)
  => M.Map f ([Matrix Integer], Matrix Integer)
  -> Matrix Integer
  -> Bool
  -> Rule f v
  -> Doc ann
prettyRuleInters msM m isStrict rule =
  pretty $
    boxString (show . wrapInter $ R.prettyRule rule)
      `leftOf` prettyMPoly (map (first pretty) $ M.toList lhsV) lhsC
      `leftOf` relation
      `leftOf` prettyMPoly (map (first pretty) $ M.toList rhsV) rhsC
 where
  lhsM@(lhsV, lhsC) = interpretTerm isLVar (msM, m) (R.lhs rule)
  rhsM@(rhsV, rhsC) = interpretTerm isLVar (msM, m) (R.rhs rule)

  relation
    | not (correctDecrease lhsM rhsM firstOp) = boxString "?"
    | isStrict = boxString " > "
    | otherwise = boxString " >= "

  firstOp = if isStrict then (>) else (>=)

  isLVar v = v `S.member` R.lvar rule

-- | All function symbols (term and theory symbols) in both sides of a rule,
-- except values; duplicates are kept.
extractNonValsRule :: Rule a v -> [a]
extractNonValsRule rule = foldMap extractNonValsTerm [R.lhs rule, R.rhs rule]

-- | All function symbols (term and theory symbols) in a term, except values,
-- in pre-order; duplicates are kept.
extractNonValsTerm :: Term a v -> [a]
extractNonValsTerm term =
  case term of
    Val _ -> []
    TermFun f args -> collect f args
    TheoryFun f args -> collect f args
    _ -> []
 where
  collect f args = f : foldMap extractNonValsTerm args

-- | All theory symbols in both sides of a rule, except values; duplicates are kept.
extractNonValThsRule :: Rule a v -> [a]
extractNonValThsRule rule = foldMap extractNonValThsTerm [R.lhs rule, R.rhs rule]

-- | All theory symbols in a term, except values, in pre-order. Term symbols are
-- skipped, but their arguments are searched.
extractNonValThsTerm :: Term a v -> [a]
extractNonValThsTerm term =
  case term of
    Val _ -> []
    TermFun _ args -> foldMap extractNonValThsTerm args
    TheoryFun f args -> f : foldMap extractNonValThsTerm args
    _ -> []

-- | Reads a concrete interpretation (coefficient matrices and constant matrix)
-- back from the assignment @assignedVals@. Returns 'Nothing' if any entry has
-- no assigned value.
assignedInterpretation
  :: (Traversable t1, Traversable t2, Traversable t3)
  => M.Map BVVar SMT.Value
  -> (t1 (t2 BVVar), t3 BVVar)
  -> Maybe (t1 (t2 Integer), t3 Integer)
assignedInterpretation assignedVals (ms, m) = do
  coefficients <- traverse (assignedMatrixValues assignedVals) ms
  constant <- assignedMatrixValues assignedVals m
  return (coefficients, constant)

-- | Replaces every 'BVVar' of a matrix (or any traversable) by its assigned value;
-- 'Nothing' if some entry has no value.
assignedMatrixValues :: (Traversable t) => M.Map BVVar SMT.Value -> t BVVar -> Maybe (t Integer)
assignedMatrixValues assignedVals matrix = mapM (assignedBVVarValue assignedVals) matrix

-- | Looks up the value assigned to a single 'BVVar'. Returns 'Nothing' if it is
-- unassigned. A non-bitvector value is an 'error' (raised when the integer is used).
assignedBVVarValue :: M.Map BVVar SMT.Value -> BVVar -> Maybe Integer
assignedBVVarValue assignment entry = fmap bitsToInteger (M.lookup entry assignment)
 where
  bitsToInteger (SMT.Bits _width integer) = integer
  bitsToInteger _ = error "MatrixInterpretation.hs: bitvector value assigned to wrong type."

-- | Encodes that the first interpretation is weakly greater than the second. Both
-- the constant parts and the coefficient matrices of every variable of the second
-- must decrease weakly. Fails ('mzero') if such a variable is missing in the first.
weakDecreasing :: (Ord v) => Interpretation v -> Interpretation v -> EncodingMatrix f v SMT.SExpr
weakDecreasing (inter1, const1) (inter2, const2) = do
  constPart <- weakDecrease const1 const2
  varParts <- mapM compareVar (M.toList inter2)
  return $ SMT.and constPart (SMT.andMany varParts)
 where
  compareVar (v, mRight) = maybe mzero (`weakDecrease` mRight) (M.lookup v inter1)

-- | Encodes that the first interpretation is strictly greater than the second. The
-- constant parts must decrease strictly ('strictDecrease') and the coefficient
-- matrices of every variable of the second weakly. Fails ('mzero') if such a
-- variable is missing in the first.
strictDecreasing :: (Ord v) => Interpretation v -> Interpretation v -> EncodingMatrix f v SMT.SExpr
strictDecreasing (inter1, const1) (inter2, const2) = do
  constPart <- strictDecrease const1 const2
  varParts <- mapM compareVar (M.toList inter2)
  return $ SMT.and constPart (SMT.andMany varParts)
 where
  compareVar (v, mRight) = maybe mzero (`weakDecrease` mRight) (M.lookup v inter1)

-- | Interprets both sides of a rule using the symbol interpretations and the rule's
-- value matrix from the encoding state. Fails ('empty') if the rule has no value
-- matrix.
interpretRule
  :: (Ord f, Eq v, ToSExpr v, ToSExpr f, Ord v)
  => Rule (FId f) (VId v)
  -> EncodingMatrix (FId f) (VId v) (Interpretation (VId v), Interpretation (VId v))
interpretRule rule = do
  inters <- gets cInters
  valInter <- maybe empty return . M.lookup rule =<< gets cValInter
  let interpret term = interpretTermBV isLVar (inters, valInter) term
  (,) <$> interpret (R.lhs rule) <*> interpret (R.rhs rule)
 where
  isLVar v = v `S.member` R.lvar rule

{- | 'interpretTermBV' @isLVar@ @(inters, m)@ @t@ returns the symbolic bitvector
interpretation of the single term @t@. The result is a map from each non-logical
variable to its coefficient matrix, plus a constant matrix.

  * Logical variables (those satisfying @isLVar@) and values are interpreted as the
    value matrix @m@, with no variable part.
  * Any other variable gets the identity matrix as coefficient and a zero constant.
    Both are sized by the number of rows of @m@.
  * For a function application, the argument interpretations are multiplied with the
    symbol's coefficient matrices. The per-variable parts are summed, and the
    constant parts are added to the symbol's constant.

Calls 'error' if a function symbol has no entry in @inters@.
-}
interpretTermBV
  :: (Ord f, Ord v)
  => (v -> Bool)
  -> (M.Map f ([BVMatrix], BVMatrix), BVMatrix)
  -> T.Term f v
  -> EncodingMatrix f v (M.Map v BVMatrix, BVMatrix)
interpretTermBV isLVar (_, m) (T.Var v) | isLVar v = return (M.empty, m)
interpretTermBV _ (_, m) (T.Var v) = do
  zB <- zeroBVVar
  oB <- oneBVVar
  let
    -- d × 1 zero constant and d × d identity coefficient
    zM = DM.fromLists $ replicate d [zB]
    d = DM.nrows m
    iM = DM.fromLists [[if i == j then oB else zB | j <- [1 .. d]] | i <- [1 .. d]]
  return (M.singleton v iM, zM)
interpretTermBV _ (_, m) (T.Val _) = return (M.empty, m)
interpretTermBV isLVar (msM, m) (T.Fun _ f args) = do
  interpretedArguments <- mapM (interpretTermBV isLVar (msM, m)) args
  case M.lookup f msM of
    Just (ms, c) -> do
      -- multiply each argument interpretation with its coefficient matrix,
      -- then sum variable parts per variable and constants onto @c@
      (funs, coeffs) <- unzip <$> zipWithM mulCoeff ms interpretedArguments
      addedFuns <- monadicUnionsWith naiveMatrixAdd funs
      addedCoeffs <- foldrM naiveMatrixAdd c coeffs
      return (addedFuns, addedCoeffs)
    Nothing -> error "Termination.hs: Interpretation missing."
 where
  -- left-multiplies every matrix of an interpretation by the coefficient @c@
  mulCoeff
    :: BVMatrix -> (M.Map v BVMatrix, BVMatrix) -> EncodingMatrix f v (M.Map v BVMatrix, BVMatrix)
  -- mulCoeff c (vm, mc) = (fmap (c *) vm, c * mc)
  mulCoeff c (vm, mc) = (,) <$> traverse (c `naiveMatrixMul`) vm <*> c `naiveMatrixMul` mc

-- addEntries = DM.elementwise (|+|)

-- | Entry-wise symbolic addition of two matrices of equal dimensions. Each entry
-- goes through 'bvAddM', which adds its side constraint to the solver. Calls
-- 'error' on mismatched dimensions.
naiveMatrixAdd :: BVMatrix -> BVMatrix -> EncodingMatrix f v BVMatrix
naiveMatrixAdd m1 m2
  | not (hasSameDimension m1 m2) =
      error "MatrixInterpretationBV.hs: cannot add matrices with different number of rows and columns."
  | otherwise = DM.fromList rows cols <$> zipWithM bvAddM (DM.toList m1) (DM.toList m2)
 where
  rows = DM.nrows m1
  cols = DM.ncols m1

{- | Symbolic matrix product @m1 * m2@. Each entry is the dot product of a row of
@m1@ and a column of @m2@. It is built with 'bvMulM' and 'bvAddM', folding onto
a zero seed, so an empty dot product is 0. Calls 'error' if the number of
columns of @m1@ differs from the number of rows of @m2@.
-}
naiveMatrixMul :: BVMatrix -> BVMatrix -> EncodingMatrix f v BVMatrix
naiveMatrixMul m1 m2
  | DM.ncols m1 /= DM.nrows m2 =
      error $
        "MatrixInterpretationBV.hs: multiplication of "
          <> showDims m1
          <> " and "
          <> showDims m2
          <> " matrices."
  | otherwise = do
      zV <- zeroBVVar
      let dot row col = zipWithM bvMulM row col >>= foldrM bvAddM zV
      DM.fromLists <$> mapM (\row -> mapM (dot row) columns) (DM.toLists m1)
 where
  showDims m = show (DM.nrows m) <> " x " <> show (DM.ncols m)
  columns = DM.toLists (DM.transpose m2)

{- | 'bvMulM' @e1@ @e2@ returns the bitvector product of @e1@ and @e2@. It also
asserts on the current solver that this multiplication does not overflow, so the
bitvector product equals the product over the natural numbers.

The check zero-extends both operands to twice their width and multiplies them
there, where no overflow is possible. It then requires the upper half of the wide
product to be zero. This admits the factors 0 and 1, which occur in the zero and
identity matrices used for variables. Requiring each operand to be strictly below
the product would make those constraints unsatisfiable.
-}
bvMulM :: BVVar -> BVVar -> EncodingMatrix f v BVVar
bvMulM e1 e2 = do
  solver <- getSolver
  -- taking the width from the product also enforces the equal-width check of '|*|'
  let bvvar@(BVVar width _) = e1 |*| e2
  liftIO $ SMT.assert solver $ noMulOverflow width (toSExpr e1) (toSExpr e2)
  return bvvar
 where
  -- SMT-LIB indexed identifier @(_ name i1 i2 ...)@
  indexed name idxs = SMT.List (SMT.Atom "_" : SMT.Atom name : map (SMT.Atom . show) idxs)
  -- @((_ zero_extend w) s)@
  zeroExtend w s = SMT.List [indexed "zero_extend" [w], s]
  -- the upper @w@ bits of the @2w@-bit product of @s1@ and @s2@ are all zero
  noMulOverflow w s1 s2 =
    let wide = SMT.bvMul (zeroExtend w s1) (zeroExtend w s2)
        upper = SMT.List [indexed "extract" [2 * w - 1, w], wide]
    in  upper `SMT.eq` indexed "bv0" [w]

{- | 'bvAddM' @e1@ @e2@ returns the bitvector sum of @e1@ and @e2@. It also asserts
on the current solver that this addition does not overflow, so the bitvector sum
equals the sum over the natural numbers.

Unsigned addition overflows exactly when the wrapped-around sum is smaller than an
operand. Requiring @e1 <= e1 + e2@ and @e2 <= e1 + e2@ therefore rules out overflow
while still allowing an operand to be zero. Zero operands occur for the zero seed
of every dot product and for zero matrix entries. A strict @<@ here would make any
addition of 0 unsatisfiable.
-}
bvAddM :: BVVar -> BVVar -> EncodingMatrix f v BVVar
bvAddM e1 e2 = do
  solver <- getSolver
  let bvvar = e1 |+| e2
  liftIO $ SMT.assert solver $ toSExpr e1 `SMT.bvULeq` toSExpr bvvar
  liftIO $ SMT.assert solver $ toSExpr e2 `SMT.bvULeq` toSExpr bvvar
  return bvvar

----------------------------------------------------------------------------------------------------
-- FIXME maybe move to Rule/Guard module.
----------------------------------------------------------------------------------------------------

{- | 'overlappingConstraint' @rule1@ @rule2@ checks if the constraints of the given
  rules overlap on any values.

  The query sent to 'satSExpr' is the conjunction of both guards with a disjunction
  of equalities. There is one equality @v1 = v2@ per pair of a variable @v1@ of the
  first guard and a variable @v2@ of the second guard with the same sort. The result
  is 'False' only if the solver reports unsatisfiable, and 'True' otherwise
  (including unknown results).

  NOTE(review): if no same-sorted variable pair exists, the disjunction is empty;
  presumably 'SMT.orMany' renders it as @false@, so the result is 'False'. Also, the
  variables of the two rules are not renamed apart here. Rules sharing variable
  identifiers are therefore checked on shared variables; confirm this is intended.
-}
overlappingConstraint
  :: forall f v
   . (Ord v, Eq f, ToSExpr v, ToSExpr f)
  => Rule (FId f) (VId v)
  -> Rule (FId f) (VId v)
  -> StateM Bool
overlappingConstraint rule1 rule2 = do
  res <-
    satSExpr (S.fromList $ varsg1 ++ varsg2) $
      SMT.andMany
        [SMT.orMany $ map toSExpr sexpr, toSExpr $ collapseGuardToTerm g1, toSExpr $ collapseGuardToTerm g2]
  case res of
    Just False -> return False
    _ -> return True
 where
  g1 = R.guard rule1
  g2 = R.guard rule2
  varsg1 = nubOrd $ varsGuard g1
  varsg2 = nubOrd $ varsGuard g2

  -- equalities between all same-sorted variable pairs of the two guards
  sexpr =
    concat
      [ [ T.eq @f (sortAnnotation [sort v1, sort v2] boolSort) (var v1) (var v2) | v2 <- varsg2, sort v1 == sort v2
        ]
      | v1 <- varsg1
      ]

{- | 'isOpenConstraint' @constraint@ checks if all possible values of the domain
appear in any satisfiable substitution of @constraint@.

Implementation:

  * For every sort, one fresh variable @y_s@ is taken and cached, so it is shared
    by all guard variables of that sort.
  * The query @guard ∧ ⋀ (x ≠ y_(sort x))@, ranging over all guard variables @x@,
    is sent to 'satSExpr' together with the guard and fresh variables.
  * The result is 'False' only if the solver reports unsatisfiable, and 'True'
    otherwise (including unknown results).

NOTE(review): the fresh variables are handed to 'satSExpr' like the guard variables,
so they appear to be existentially quantified. The query then asks whether some
value of each sort differs from the guard's variables, not whether every value is
admitted. Confirm this matches the intended notion of an open constraint.
-}
isOpenConstraint
  :: forall f v. (Ord v, ToSExpr v, ToSExpr f, Eq f) => Guard (FId f) (VId v) -> StateM Bool
isOpenConstraint guard = do
  let vars = nubOrd $ varsGuard guard
  (constraint, m) <- runStateT (buildConstraints vars) M.empty
  res <-
    satSExpr (S.fromList $ vars <> M.elems m) $ constraint `SMT.and` toSExpr (collapseGuardToTerm guard)
  case res of
    Just False -> return False
    _ -> return True
 where
  -- conjunction of @x ≠ y_(sort x)@ for all given variables (@true@ if none)
  buildConstraints [] = return top
  buildConstraints (v : vs) = do
    let so = sort v
    fV <- getFreshVar so
    let equTerm = T.neg $ T.eq (sortAnnotation [so, so] boolSort) (var v) (var fV)
    (toSExpr (equTerm :: Term (FId f) (VId v)) `SMT.and`) <$> buildConstraints vs

  -- fresh variable of sort @so@, created once and then reused from the cache
  getFreshVar :: (MonadFresh m) => Sort -> StateT (M.Map Sort (VId v)) m (VId v)
  getFreshVar so = do
    cache <- get
    case M.lookup so cache of
      Just v -> return v
      Nothing -> do
        fV <- freshV so <$> lift freshInt
        modify $ M.insert so fV
        return fV
