{-# LANGUAGE FlexibleInstances #-}

{- |
Module      : Termination
Description : Data types and pretty printers for termination proofs
Copyright   : (c) Jonas Schöpf, 2024
License     : GPL-3
Maintainer  : jonas.schoepf@uibk.ac.at
Stability   : stable

This module provides the data types that describe termination proofs,
their pretty printers, and helpers for theory sorts and matrix interpretations.
-}
module Analysis.Termination.Termination where

import qualified Analysis.Confluence.Confluence as C
import Analysis.Termination.DependencyGraph (DPGraph (..), SCC (..))
import Analysis.Termination.DependencyPairs (DPProblem (..))
import Control.Arrow (Arrow (first))
import Data.Bifunctor (Bifunctor (bimap))
import Data.LCTRS.FIdentifier (FId, getArity)
import Data.LCTRS.Guard (Guard, isTopGuard, prettyGuard)
import Data.LCTRS.Rule (Rule, prettyRule)
import qualified Data.LCTRS.Rule as R
import Data.LCTRS.Sort (Attr (..), Sort (..), Sorted (sort))
import qualified Data.LCTRS.Term as T
import qualified Data.Map.Strict as M
import Data.Matrix (Matrix)
import qualified Data.Matrix as DM
import Data.Maybe (fromJust)
import Data.SExpr (ToSExpr (toSExpr))
import Data.SMT (boolSort, intSort, isBVSort, realSort)
import Data.Set (member)
import Data.String (IsString)
import Fmt (fmt, (+|), (|+))
import Pretty.Box (Box, box, boxString, leftOf)
import Prettyprinter (
  Doc,
  Pretty (pretty),
  encloseSep,
  indent,
  line,
  vsep,
  (<+>),
 )
import SimpleSMT (SExpr, showsSExpr)
import qualified SimpleSMT as SMT

----------------------------------------------------------------------------------------------------
-- comparison and zero elements of theory sorts
----------------------------------------------------------------------------------------------------

-- | For a theory sort that carries an order, return its strict comparison,
-- its weak comparison and its zero constant, all as SMT expressions.
--
-- Integers and reals use the built-in @>@ / @>=@; bit-vectors use the
-- unsigned @bvugt@ / @bvuge@. 'boolSort', every other sort, and any sort
-- whose zero constant cannot be built ('zeroSExprOfSort') yield 'Nothing'.
getCompAndZeroOfSort
  :: Sort
  -> Maybe
      ( SExpr -> SExpr -> SExpr
      , SExpr -> SExpr -> SExpr
      , SExpr
      )
getCompAndZeroOfSort s
  -- Bool gets no order here (a commented-out variant used
  -- (SMT.gt, SMT.geq, SMT.const "false")).
  | s == boolSort = Nothing
  | s == intSort || s == realSort = withZero SMT.gt SMT.geq
  | isBVSort s = withZero bvUgt bvUge
  | otherwise = Nothing
 where
  -- pair the two comparisons with the sort's zero, if it has one
  withZero strict weak = fmap (\z -> (strict, weak, z)) (zeroSExprOfSort s)
  bvUgt l r = SMT.fun "bvugt" [l, r]
  bvUge l r = SMT.fun "bvuge" [l, r]

-- | Addition and multiplication on SMT expressions of a numeric theory sort.
--
-- Integers and reals use the built-in @+@ / @*@, bit-vectors use @bvadd@ /
-- @bvmul@; 'boolSort' and every other sort yield 'Nothing'.
getAddMulOfSort
  :: Sort
  -> Maybe
      ( SExpr -> SExpr -> SExpr
      , SExpr -> SExpr -> SExpr
      )
getAddMulOfSort s
  | s == boolSort = Nothing
  | s == intSort || s == realSort = Just (SMT.add, SMT.mul)
  | isBVSort s = Just (bvAdd, bvMul)
  | otherwise = Nothing
 where
  bvAdd l r = SMT.fun "bvadd" [l, r]
  bvMul l r = SMT.fun "bvmul" [l, r]

-- | The zero constant of a numeric theory sort (integer, real or bit-vector)
-- as an SMT expression; 'Nothing' for any other sort.
zeroSExprOfSort :: Sort -> Maybe SExpr
zeroSExprOfSort s
  | s == intSort = Just (SMT.const "0")
  | s == realSort = Just (SMT.const "0.0")
  | isBVSort s = correctZeroBVConst s
  | otherwise = Nothing

-- | The all-zero bit-vector literal @#b0…0@ with one digit per bit of a
-- @BitVec i@ sort; 'Nothing' for any other sort.
correctZeroBVConst :: Sort -> Maybe SExpr
correctZeroBVConst srt = case srt of
  AttrSort "BitVec" (AttrInt width) -> Just . SMT.const $ "#b" ++ replicate width '0'
  _ -> Nothing

-- | The constant one of a numeric theory sort (integer, real or bit-vector)
-- as an SMT expression; 'Nothing' for 'boolSort' and any other sort.
getOneOfSort
  :: Sort
  -> Maybe SExpr
getOneOfSort s
  | s == boolSort = Nothing
  | s == intSort = Just (SMT.const "1")
  | s == realSort = Just (SMT.const "1.0")
  | isBVSort s = correctOneBVConst s
  | otherwise = Nothing

-- | The bit-vector literal for one, @#b0…01@, with one digit per bit of a
-- @BitVec i@ sort; 'Nothing' for any other sort.
correctOneBVConst :: Sort -> Maybe SExpr
correctOneBVConst srt = case srt of
  AttrSort "BitVec" (AttrInt width) ->
    Just . SMT.const $ "#b" ++ replicate (width - 1) '0' ++ "1"
  _ -> Nothing

----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------

-- | A precedence on function symbols, listed from least to greatest.
-- Each inner list is a group of symbols that are equal in the precedence.
newtype Precedence f = Prec [[f]]

-- | The precedence that relates no symbols at all.
emptyPrecedence :: Precedence f
emptyPrecedence = Prec mempty

-- | Overall verdict of a termination analysis; 'snresultToAnswer' turns it
-- into a 'C.Answer' ('MaybeTerminating' becomes 'C.Maybe').
data SNResult = Terminating | NonTerminating | MaybeTerminating deriving (Eq)

-- | A sequence of termination proof steps.
newtype SNInfoList f v = SNInfoList [SNInfo f v]

-- | Prints every step on its own line, prefixed with a @*@ bullet.
instance (Pretty f, Pretty v, Ord f, Ord v) => Pretty (SNInfoList (FId f) v) where
  pretty (SNInfoList steps) = vsep ["*" <+> pretty step | step <- steps]

-- | One step of a termination proof (or a note that no information is
-- available), as rendered by its 'Pretty' instance.
data SNInfo f v
  = RpoPrecedence (Precedence f)
  -- ^ Proof by a recursive path order (RPO) with the given precedence.
  | VcProjection [([Rule f v], [(f, Int)])]
  -- ^ Value criterion: one entry per iteration, pairing the rules removed in
  -- that iteration with the argument position each symbol is projected to
  -- (printed with 'succ', so presumably stored zero-based).
  | VcProjectionPol [([Rule f v], [(f, ([(Int, SExpr)], SExpr))])]
  -- ^ Value criterion with linear polynomial projections: per iteration the
  -- removed rules and, per symbol, a list of (argument position, coefficient)
  -- pairs together with a constant.
  | ScProjection (DPProblem f v) [([Rule f v], [(f, Int)])]
  -- ^ Subterm criterion on the given DP problem, with per-iteration removed
  -- rules and projections shaped as in 'VcProjection'.
  | NoDependencyPairs
  -- ^ There are no dependency pairs.
  | DpGraph (DPProblem f v) (DPGraph f v, [SCC f v])
  -- ^ DP graph approximation of the given DP problem, together with the
  -- resulting graph and its SCCs.
  | PolyInterpretation (M.Map f ([Int], Int)) -- [String]
  -- ^ Polynomial interpretation: per symbol one coefficient for each argument
  -- and a constant.
  | MatrixInterpretation
      (DPProblem f v)
      ( M.Map f ([Matrix Integer], Matrix Integer)
      , -- , M.Map (Guard f v) (Matrix Integer)
        M.Map (Rule f v) (Matrix Integer) -- [String]
      )
  -- ^ Matrix interpretation for the given DP problem: per non-value symbol
  -- one coefficient matrix per argument plus a constant matrix, and per rule
  -- the matrix used to interpret values satisfying that rule's guard.
  | ReductionPair [(DPProblem f v, SNInfo f v)]
  -- ^ Reduction pair step: a list of DP problems, each paired with the
  -- proof step recorded for it.
  | None
  -- ^ No termination information available.

-- | Render documents as a set: enclosed in braces and separated by commas.
listSet :: [Doc ann] -> Doc ann
listSet elems = encloseSep open close separator elems
 where
  open = "{"
  close = "}"
  separator = ", "

-- | Human-readable proof output for a single termination step.
--
-- Every constructor is printed as a heading followed by an indented body.
-- Iteration numbers and argument positions are printed one-based.
instance (Ord f, Ord v, Pretty f, Pretty v) => Pretty (SNInfo (FId f) v) where
  pretty (RpoPrecedence (Prec [])) = "RPO with empty precedence."
  pretty (RpoPrecedence (Prec prec)) =
    -- the precedence is stored from least to greatest; print it greatest
    -- first, one brace group per class of equal symbols
    vsep
      [ "RPO with precedence:"
      , indent 2 $
          encloseSep "" "" " > " $
            map (encloseSep "{" "}" " " . map pretty) $
              reverse prec
      ]
  pretty (VcProjection []) = "Value criterion with empty projection."
  pretty (VcProjection proj) =
    vsep
      [ "Value criterion after" <+> pretty (length proj) <+> "iteration(s) with the projections:"
      , ( indent 2 $
            vsep $
              zipWith
                ( \i (rules, projs) ->
                    vsep
                      [ "Iteration"
                          <+> pretty i
                          <> ":"
                      , indent 2 "projection(s):"
                      , indent 4 $
                          encloseSep
                            ""
                            ""
                            ", "
                            -- projected argument positions are printed one-based
                            (map (\(f, param) -> "v(" <> pretty f <> ")" <+> "=" <+> pretty (succ param)) projs)
                      , indent 2 "removing the rule(s):"
                      , indent 4 $ listSet (map prettyRule rules)
                      ]
                )
                [1 :: Int ..]
                proj
        )
      ]
  pretty (VcProjectionPol []) = "Value criterion with empty projection."
  pretty (VcProjectionPol proj) =
    let
      zero = SMT.const "0"
      -- one summand "a * x(i+1)"; a coefficient of 1 is left implicit
      pProj (i, a)
        | a == SMT.const "1" = "x" <> pretty (succ i)
        | otherwise = pretty (showsSExpr a "") <+> "*" <+> "x" <> pretty (succ i)
      pArgs 0 = ""
      pArgs n = encloseSep "(" ")" ", " ["x" <> pretty i | i <- [1 .. n]]
      -- right-hand side of one projection: summands with a zero coefficient
      -- are dropped, and so is a zero constant unless nothing else remains.
      -- This keeps the output from being empty or starting with a dangling "+".
      pRhs params c =
        let
          projTerms = map pProj $ filter (\(_, a) -> a /= zero) params
          constTerms = [pretty (showsSExpr c "") | c /= zero || null projTerms]
        in
          encloseSep "" "" " + " (projTerms ++ constTerms)
    in
      vsep
        [ "Value criterion after"
            <+> pretty (length proj)
            <+> "iteration(s) with the linear poly projections:"
        , ( indent 2 $
              vsep $
                zipWith
                  ( \i (rules, projs) ->
                      vsep
                        [ "Iteration"
                            <+> pretty i
                            <> ":"
                        , indent 2 "projection(s):"
                        , indent 4 $
                            vsep
                              ( [ "v(" <> pretty f <> pArgs (getArity f) <> ")" <+> "=" <+> pRhs params c
                                | (f, (params, c)) <- projs
                                ]
                              )
                        , indent 2 "removing the rule(s):"
                        , indent 4 $ listSet (map prettyRule rules)
                        ]
                  )
                  [1 :: Int ..]
                  proj
          )
        ]
  pretty (ScProjection dpp []) =
    vsep
      [ "Subterm criterion on the DP problem"
      , indent 2 $ pretty dpp
      , "with empty projection."
      ]
  pretty (ScProjection dpp proj) =
    vsep
      [ "Subterm criterion on the DP problem"
      , indent 2 $ pretty dpp
      , "after" <+> pretty (length proj) <+> "iteration(s) with the projections:"
      , ( indent 2 $
            vsep $
              zipWith
                ( \i (rules, projs) ->
                    vsep
                      [ "Iteration"
                          <+> pretty i
                          <> ":"
                      , indent 2 "projection(s):"
                      , indent 4 $
                          encloseSep
                            ""
                            ""
                            ", "
                            -- projected argument positions are printed one-based
                            (map (\(f, param) -> "v(" <> pretty f <> ")" <+> "=" <+> pretty (succ param)) projs)
                      , indent 2 "removing the rule(s):"
                      , indent 4 $ listSet (map prettyRule rules)
                      ]
                )
                [1 :: Int ..]
                proj
        )
      ]
  pretty NoDependencyPairs = "No Dependency Pairs."
  pretty (DpGraph dpp (dpgraph, sccs)) =
    vsep
      [ "DPGraph approximation on the DP problem"
      , indent 2 $ pretty dpp
      , "resulting in the DP graph"
      , indent 2 $ pretty dpgraph
      , "with" <+> pretty (length sccs) <+> "SCC(s)"
      , indent 2 $ vsep $ map pretty sccs
      ]
  pretty (PolyInterpretation m) =
    vsep
      [ "Polynomial Interpretations:"
      , indent 2 $
          vsep
            [ pretty f <> args <> ":" <+> rhs
            | (f, (cs, c)) <- M.toList m
            , let ics = zip [1 :: Int ..] cs
            , let vars = map (("x" <>) . pretty . fst) ics
            , -- nullary symbols get no "()" (as for matrix interpretations)
              let args
                    | null vars = ""
                    | otherwise = encloseSep "(" ")" ", " vars
            , let poly = encloseSep "" "" " + " $ zipWith (\c' v -> pretty c' <+> "*" <+> v) cs vars
            , -- without arguments the interpretation is just the constant
              let rhs
                    | null cs = pretty c
                    | otherwise = poly <+> "+" <+> pretty c
            ]
      ]
  pretty (MatrixInterpretation dpproblem (ms, ruleVals)) =
    vsep
      [ "Matrix Interpretations:"
      , indent 2 $
          "Non-Value Symbols:"
            <> line
            <> vsep
              [ pretty resultString <> line
              | (f, (cs, c)) <- M.toList ms
              , let ics = zip [1 :: Int ..] cs
              , let vars = map (("x" <>) . pretty . fst) ics
              , let args
                      | null vars = ""
                      | otherwise = encloseSep "(" ")" ", " vars
              , let funPrefix = show . wrapInter $ pretty f <> args
              , let poly = prettyMPoly (zip vars cs) c
              , let resultString = boxString funPrefix `leftOf` poly
              ]
      , indent 2 $
          "Value Symbols:"
            <> line
            <> vsep
              [ pretty resultString <> line
              | (r, c) <- M.toList ruleVals
              , let resultString =
                      boxString (show $ "values satisfying guard of " <> prettyRule r <> " as ") `leftOf` prettyMatrix c
              ]
      , -- every strict/weak rule of the DP problem is expected to have an
        -- entry in ruleVals; 'fromJust' fails otherwise
        indent 2 $
          "Strict Rule Orientations:"
            <> line
            <> vsep
              ( map
                  (\rule -> (<> line) $ prettyRuleInters ms (fromJust $ M.lookup rule ruleVals) True rule)
                  (strictrules dpproblem)
              )
      , indent 2 $
          "Weak Rule Orientations:"
            <> line
            <> vsep
              ( map
                  (\rule -> (<> line) $ prettyRuleInters ms (fromJust $ M.lookup rule ruleVals) False rule)
                  (weakrules dpproblem)
              )
      ]
  pretty (ReductionPair sninfos) =
    vsep
      [ "Reduction Pair with dependency pairs:"
      , indent 2 $
          vsep $
            map
              ( \(dpp, sninfo) ->
                  indent 2 $
                    vsep [pretty dpp, pretty sninfo]
              )
              sninfos
      ]
  pretty None = "No termination info given."

-- | Render a guard, printing @true@ for a guard recognised by 'isTopGuard'
-- and deferring to 'prettyGuard' for every other guard.
prettyGuardVerbose :: (Pretty f, Pretty v) => Guard f v -> Doc ann
prettyGuardVerbose g =
  if isTopGuard g
    then "true"
    else prettyGuard g

-- | Convert the verdict of a termination analysis into the generic answer
-- type; the accompanying proof steps are ignored.
snresultToAnswer :: (SNResult, SNInfoList f v) -> C.Answer
snresultToAnswer = toAnswer . fst
 where
  toAnswer Terminating = C.Yes
  toAnswer NonTerminating = C.No
  toAnswer MaybeTerminating = C.Maybe

-- printSnInfo ::(SNResult, SNInfo f) -> String

-- | Clamp an expression from below at zero: build @ite (e > 0) e 0@ using
-- the strict comparison and zero of the expression's sort. Returns 'Nothing'
-- if 'getCompAndZeroOfSort' provides neither for that sort.
cutoffZeroSExpr :: (Sorted a, ToSExpr a) => a -> Maybe SExpr
cutoffZeroSExpr expr = clamp <$> getCompAndZeroOfSort (sort expr)
 where
  asSExpr = toSExpr expr
  clamp (greater, _, zero) = SMT.ite (greater asSExpr zero) asSExpr zero

-- | Render a linear matrix polynomial @M1 * v1 + … + Mn * vn + C@ as a box.
-- The first argument holds each variable with its coefficient matrix, and
-- the second is the constant matrix. Variables are printed with 'show'.
prettyMPoly :: (Foldable t, Pretty a, Show b) => t (b, Matrix a) -> Matrix a -> Box
prettyMPoly coeffVars constant = foldr addTerm (prettyMatrix constant) coeffVars
 where
  -- prepend "coeff * var + " to the already rendered remainder
  addTerm (var, coeff) rest =
    prettyMatrix coeff `leftOf` (boxString (" * " ++ show var ++ " + ") `leftOf` rest)

-- | Render a matrix as a box with one line per row. Each row is framed as
-- @│ … │@, and every entry is padded to the width of the widest entry.
prettyMatrix :: (Pretty a) => Matrix a -> Box
prettyMatrix matrix =
  box width (DM.nrows matrix) $ map (surround . renderRow) rows
 where
  -- entries rendered through their 'Pretty' instance
  rows = map (map (show . pretty)) $ DM.toLists matrix
  -- seeded with 0 so that an empty matrix does not crash 'maximum'
  widest = maximum (0 : concatMap (map length) rows)
  -- entries padded to a common width, separated by single spaces
  renderRow row = unwords [show $ box widest 1 [entry] | entry <- row]
  -- the frame "│ " … " │" adds 4 characters to every line; it used to be
  -- counted once per row, which over-padded matrices with several rows
  borderWidth = 4
  width = DM.ncols matrix * widest + DM.ncols matrix - 1 + borderWidth

  surround s = fmt $ "│ " +| s |+ " │"

-- | Render the orientation of a rule under a matrix interpretation as
-- @⟦rule⟧: lhs > rhs@ (or @>=@ when @isStrict@ is 'False').
--
-- @msM@ interprets the non-value symbols, and @m@ interprets values and the
-- rule's logical variables (see 'interpretTerm'). Calls 'error' if
-- 'correctDecrease' rejects the claimed orientation.
prettyRuleInters
  :: (Ord f, Ord v, Pretty v, Pretty f)
  => M.Map f ([Matrix Integer], Matrix Integer)
  -> Matrix Integer
  -> Bool
  -> Rule f v
  -> Doc ann
prettyRuleInters msM m isStrict rule
  | correctDecrease lhsM rhsM cmp = pretty rendered
  | otherwise = error "Termination.hs: Matrix Interpretation is malformed!"
 where
  isLVar v = v `member` R.lvar rule
  interpret = interpretTerm isLVar (msM, m)
  lhsM@(lhsV, lhsC) = interpret (R.lhs rule)
  rhsM@(rhsV, rhsC) = interpret (R.rhs rule)
  -- comparison on the first entry of the constants, and its printed form
  (cmp, relation)
    | isStrict = ((>), " > ")
    | otherwise = ((>=), " >= ")
  side vars constant = prettyMPoly (map (first pretty) (M.toList vars)) constant
  rendered =
    boxString (show . wrapInter $ R.prettyRule rule)
      `leftOf` side lhsV lhsC
      `leftOf` boxString relation
      `leftOf` side rhsV rhsC

-- | Interpret a term as a linear matrix polynomial. The result maps each
-- (non-logical) variable to its coefficient matrix and pairs this with a
-- constant matrix.
--
-- * Variables satisfying @isLVar@, and values, are interpreted by @m@ and have
--   no variable part.
-- * Any other variable gets an identity coefficient and, as its constant, a
--   zero column with as many rows as @m@.
-- * For a function symbol, each argument's interpretation is multiplied by the
--   symbol's coefficient matrix for that position. The products are summed and
--   the symbol's constant is added. Surplus arguments or coefficients are
--   ignored ('zip'), and a symbol missing from the map is an 'error'.
interpretTerm
  :: (Ord f, Ord v, Num a)
  => (v -> Bool)
  -> (M.Map f ([Matrix a], Matrix a), Matrix a)
  -> T.Term f v
  -> (M.Map v (Matrix a), Matrix a)
interpretTerm isLVar (msM, m) term = case term of
  T.Var v
    | isLVar v -> (M.empty, m)
    | otherwise -> (M.singleton v identityM, zeroColumn)
  T.Val _ -> (M.empty, m)
  T.Fun _ f args -> case M.lookup f msM of
    Nothing -> error "Termination.hs: Interpretation missing."
    Just (coeffs, c) ->
      let
        scaled =
          [ (fmap (k *) varPart, k * constPart)
          | (k, (varPart, constPart)) <- zip coeffs (map (interpretTerm isLVar (msM, m)) args)
          ]
        (varParts, constParts) = unzip scaled
      in
        (M.unionsWith (+) varParts, foldr (+) c constParts)
 where
  d = DM.nrows m
  zeroColumn = DM.fromLists $ replicate d [0]
  identityM = DM.fromLists [[if i == j then 1 else 0 | j <- [1 .. d]] | i <- [1 .. d]]

-- | Check that the interpretation of a left-hand side (first argument) is
-- greater than that of a right-hand side (second argument).
--
-- The constant parts are compared with 'performCheck', using @firstOp@ on the
-- first entry. Every variable of the right-hand side must also occur on the
-- left, with a coefficient matrix that is entrywise @>=@ the right one.
correctDecrease
  :: (Ord v, Ord a)
  => (M.Map v (Matrix a), Matrix a)
  -> (M.Map v (Matrix a), Matrix a)
  -> (a -> a -> Bool)
  -> Bool
correctDecrease (ms1, m1) (ms2, m2) firstOp =
  performCheck m1 m2 firstOp
    -- all keys of ms2 must be in ms1, with lhs coefficient >= rhs coefficient
    && M.isSubmapOfBy (\rhsCoeff lhsCoeff -> performCheck lhsCoeff rhsCoeff (>=)) ms2 ms1

-- | Compare two matrices entrywise, in the order given by 'DM.toList'. The
-- first entries are compared with @firstOp@ and all remaining ones with @>=@.
-- Matrices with a different number of entries never satisfy the check; two
-- empty matrices always do.
performCheck :: (Ord a) => Matrix a -> Matrix a -> (a -> a -> Bool) -> Bool
performCheck m1 m2 firstOp =
  case (DM.toList m1, DM.toList m2) of
    ([], []) -> True
    (x : xs, y : ys) -> firstOp x y && and (zipWith (>=) xs ys) && length xs == length ys
    _ -> False

-- | Wrap a rendered term in semantic brackets and append a colon, giving the
-- prefix @⟦s⟧: @ used in front of an interpretation.
wrapInter :: (Semigroup a, IsString a) => a -> a
wrapInter s = open <> (s <> close)
 where
  open = "⟦"
  close = "⟧: "
