{- |
Module      : Rewriting.CriticalPair
Description : Constrained critical pairs and their computation
Copyright   : (c) Jonas Schöpf, 2023
License     : GPL-3
Maintainer  : jonas.schoepf@uibk.ac.at
Stability   : stable


This module provides an implementation for constrained critical pairs and how to
compute them.
-}
module Rewriting.CriticalPair where

import Control.Monad (
  filterM,
  replicateM,
  zipWithM,
 )
import Data.LCTRS.FIdentifier
import Data.LCTRS.Guard
import Data.LCTRS.LCTRS hiding (rename)
import Data.LCTRS.Position (Pos, append, epsilon)
import Data.LCTRS.Rule hiding (guard)
import qualified Data.LCTRS.Rule as R
import Data.LCTRS.Sort
import Data.LCTRS.Term
import qualified Data.LCTRS.Term as T
import Data.LCTRS.VIdentifier
import Data.Monad (
  MonadFresh (freshInt),
  StateM,
 )
import Data.SExpr
import Data.SMT (boolSort)
import qualified Data.Set as S
import Prettyprinter (
  Doc,
  Pretty (pretty),
  indent,
  line,
  nest,
  vsep,
  (<+>),
 )
import Rewriting.ConstrainedRewriting.ConstrainedRewriting (
  equivalentUnderConstraint,
  trivialVarEqsRule,
 )
import Rewriting.Renaming (Renaming, innerRenaming, outerRenaming)
import Rewriting.SMT (
  satGuard,
  smtResultCheck,
 )
import qualified Rewriting.Substitution as MySub
import Rewriting.Unification (
  UnifPTuple (UnifPTuple),
  UnifProblem (UnifProblem),
  solveUnificationProblem,
 )
import Utils (parMap)

----------------------------------------------------------------------------------------------------
-- critical pair definition
----------------------------------------------------------------------------------------------------

-- | A constrained critical pair together with the data it was built from.
--
-- Values of this type are assembled by 'matchAndConstructCP'; the field
-- descriptions below state what that function stores in each field.
data CriticalPair f v v' = CriticalPair
  { innerRule :: Rule f (Renaming v v')
  -- ^ The (renamed) rule whose left-hand side was unified with a subterm of
  -- the outer rule's left-hand side.
  , outerRule :: Rule f (Renaming v v')
  -- ^ The (renamed) rule whose left-hand side was searched for overlaps.
  , top :: Term f (Renaming v v')
  -- ^ The left-hand side of the outer rule with the unifier applied.
  , inner :: Term f (Renaming v v')
  -- ^ 'top' with the subterm at 'pos' replaced by the instantiated
  -- right-hand side of the inner rule.
  , outer :: Term f (Renaming v v')
  -- ^ The right-hand side of the outer rule with the unifier applied.
  , constraint :: Guard f (Renaming v v')
  -- ^ The guards of both rules conjoined with the result of
  -- 'trivialVarEqsRule' for both rules, with the unifier mapped over it.
  , pos :: Pos
  -- ^ The position in the outer rule's left-hand side where the overlap occurs.
  , subst :: MySub.Subst f (Renaming v v')
  -- ^ The substitution returned by 'solveUnificationProblem' for the overlap.
  }

----------------------------------------------------------------------------------------------------
-- computing the set of critical pairs
----------------------------------------------------------------------------------------------------

-- | Compute the critical pairs of an LCTRS.
--
-- The result lists the pairs returned by 'computeRuleCPsLCTRS' first,
-- followed by the pairs returned by 'computeCalcCPsLCTRS'.
computeCPs
  :: (Ord v, Ord f, ToSExpr f, ToSExpr v, Pretty v)
  => LCTRS (FId f) (VId v)
  -> StateM [CriticalPair (FId f) (VId v) (VId v)]
computeCPs lctrs = do
  ruleCPs <- computeRuleCPsLCTRS lctrs
  calcCPs <- computeCalcCPsLCTRS lctrs
  return $ ruleCPs ++ calcCPs

----------------------------------------------------------------------------------------------------
-- computing the set of critical pairs of standard rules
----------------------------------------------------------------------------------------------------

-- | Compute the critical pairs between the rules of an LCTRS.
--
-- Candidates are generated for every ordered pair of rules (a rule is also
-- paired with itself) and only those accepted by 'criticalPairProperties'
-- are returned.
computeRuleCPsLCTRS
  :: (Ord v, Ord f, ToSExpr f, ToSExpr v, Pretty v)
  => LCTRS (FId f) (VId v)
  -> StateM [CriticalPair (FId f) (VId v) (VId v)]
computeRuleCPsLCTRS lctrs =
  sequence (parMap 10 computeCPsTwoRules allRulePairs)
    >>= filterM criticalPairProperties . concat
 where
  allRules = getRules lctrs
  -- same enumeration order as a nested comprehension: first component outermost
  allRulePairs = (,) <$> allRules <*> allRules

-- | Compute the critical pair candidates of an (inner, outer) pair of rules.
--
-- The inner rule is renamed with 'innerRenaming' and the outer rule with
-- 'outerRenaming'; the search then starts at the root position of the renamed
-- outer rule's left-hand side.
computeCPsTwoRules
  :: (MonadFresh m, Ord v, Ord f)
  => (Rule (FId f) (VId v), Rule (FId f) (VId v))
  -> m [CriticalPair (FId f) (VId v) (VId v)]
computeCPsTwoRules (overlapping, overlapped) =
  let renamedInner = R.rename innerRenaming overlapping
      renamedOuter = R.rename outerRenaming overlapped
   in computeCPsRulePos epsilon (R.lhs renamedOuter) (renamedInner, renamedOuter)

-- | Collect critical pair candidates at a position and everywhere below it.
--
-- The second argument is the subterm of the outer rule's left-hand side found
-- at the given position. Variables and values contribute nothing. For a
-- function application the arguments are visited first (argument indices are
-- counted from 0) and afterwards an overlap at the current position is
-- attempted; if it succeeds, that pair is placed in front of the pairs found
-- in the arguments.
computeCPsRulePos
  :: (MonadFresh m, Ord v, Ord f)
  => Pos
  -> Term (FId f) (Renaming (VId v) (VId v))
  -> (Rule (FId f) (Renaming (VId v) (VId v)), Rule (FId f) (Renaming (VId v) (VId v)))
  -> m [CriticalPair (FId f) (VId v) (VId v)]
computeCPsRulePos p subterm rulePair = case subterm of
  Var _ -> return []
  Val _ -> return []
  Fun _ _ children -> do
    -- the rule pair is only taken apart here, never for variables or values
    let (renamedInner, renamedOuter) = rulePair
    nested <-
      zipWithM
        (\idx child -> computeCPsRulePos (append p idx) child rulePair)
        [0 ..]
        children
    let below = concat nested
    return $
      maybe below (: below) (matchAndConstructCP renamedInner renamedOuter p subterm)

-- | Try to build the critical pair for one overlap.
--
-- The left-hand side of the inner rule is unified with the given subterm of
-- the outer rule's left-hand side. The result is 'Nothing' when
-- 'solveUnificationProblem' or 'T.replaceAt' yields 'Nothing'. Otherwise the
-- unifier is applied to both rules to obtain the peak, the two reducts and
-- the combined constraint.
matchAndConstructCP
  :: (Ord v, Ord f)
  => Rule (FId f) (Renaming (VId v) (VId v))
  -> Rule (FId f) (Renaming (VId v) (VId v))
  -> Pos
  -> Term (FId f) (Renaming (VId v) (VId v))
  -> Maybe (CriticalPair (FId f) (VId v) (VId v))
matchAndConstructCP innerR outerR overlapPos subterm = do
  unifier <-
    solveUnificationProblem (UnifProblem [UnifPTuple (R.lhs innerR, subterm)])
  let
    instantiate = MySub.apply unifier
    peak = instantiate (R.lhs outerR)
  innerReduct <- T.replaceAt peak overlapPos (instantiate (R.rhs innerR))
  return $
    CriticalPair
      { innerRule = innerR
      , outerRule = outerR
      , top = peak
      , inner = innerReduct
      , outer = instantiate (R.rhs outerR)
      , constraint = mapGuard instantiate (ruleGuards `conjGuards` varEqs)
      , pos = overlapPos
      , subst = unifier
      }
 where
  -- guard of the outer rule first, then the guard of the inner rule
  ruleGuards = R.guard outerR `conjGuards` R.guard innerR
  varEqs = trivialVarEqsRule innerR `conjGuards` trivialVarEqsRule outerR

----------------------------------------------------------------------------------------------------
-- computing the set of critical pairs of calculation rules
----------------------------------------------------------------------------------------------------

-- | Compute the critical pairs between calculation rules and the rules of an
-- LCTRS.
--
-- Candidates are generated per rule by 'computeCalcCPsRule' and only those
-- accepted by 'criticalPairProperties' are returned.
computeCalcCPsLCTRS
  :: (Ord v, Ord f, ToSExpr f, ToSExpr v, Pretty v)
  => LCTRS (FId f) (VId v)
  -> StateM [CriticalPair (FId f) (VId v) (VId v)]
computeCalcCPsLCTRS lctrs =
  filterM criticalPairProperties . concat
    =<< sequence (parMap 10 computeCalcCPsRule (getRules lctrs))

-- | Compute the candidates for overlaps of calculation rules with one rule.
--
-- The rule is renamed with 'outerRenaming' and its left-hand side is searched
-- starting at the root position.
computeCalcCPsRule
  :: (MonadFresh m, Ord v, Ord f)
  => Rule (FId f) (VId v)
  -> m [CriticalPair (FId f) (VId v) (VId v)]
computeCalcCPsRule rule =
  computeCalcCPsRulePos epsilon (R.lhs renamedOuter) renamedOuter
 where
  renamedOuter = R.rename outerRenaming rule

-- | Collect candidates for overlaps with calculation rules at a position and
-- everywhere below it.
--
-- Variables and values contribute nothing. Below a 'TermFun' only the
-- arguments are searched. At a 'TheoryFun' the arguments are searched first;
-- afterwards a rule for the symbol is created with 'calcRuleFromSymbol',
-- renamed with 'innerRenaming' and overlapped at the current position. A pair
-- found at the current position is placed in front of those found below.
computeCalcCPsRulePos
  :: (MonadFresh m, Ord v, Ord f)
  => Pos
  -> Term (FId f) (Renaming (VId v) (VId v))
  -> Rule (FId f) (Renaming (VId v) (VId v))
  -> m [CriticalPair (FId f) (VId v) (VId v)]
computeCalcCPsRulePos p subterm renamedOuter = case subterm of
  Var _ -> return []
  Val _ -> return []
  TermFun _ children -> descendInto children
  TheoryFun sym children -> do
    -- descend before creating the calculation rule so that fresh integers
    -- are drawn in the same order as the positions are visited
    below <- descendInto children
    calcRule <- calcRuleFromSymbol sym
    let renamedInner = R.rename innerRenaming calcRule
    return $
      maybe below (: below) (matchAndConstructCP renamedInner renamedOuter p subterm)
 where
  -- search all arguments; argument indices are counted from 0
  descendInto children =
    concat
      <$> zipWithM
        (\idx child -> computeCalcCPsRulePos (append p idx) child renamedOuter)
        [0 ..]
        children

-- | Create a rule for a symbol from fresh variables.
--
-- One fresh variable is created per input sort of the symbol, followed by one
-- fresh variable of the symbol's own sort. The rule rewrites the symbol
-- applied to the argument variables into the result variable; its guard
-- consists of a single 'T.eq' between the right- and the left-hand side.
calcRuleFromSymbol :: (MonadFresh m) => FId f -> m (Rule (FId f) (VId v))
calcRuleFromSymbol sym = do
  -- argument variables first, then the result variable (order of fresh ints)
  argVars <- mapM (\s -> mkVar s <$> freshInt) (getFIdInSort sym)
  resultVar <- mkVar (sort sym) <$> freshInt
  let
    call = T.theoryFun sym argVars
    eqSorts = sortAnnotation [sort resultVar, sort call] boolSort
  return $
    R.createRule
      call
      resultVar
      (createGuard [T.eq eqSorts resultVar call])
 where
  mkVar s n = var (freshV s n)

-- | Decide whether a candidate is kept as a critical pair.
--
-- A candidate is rejected without consulting the SMT solver when the
-- substitution does not satisfy 'respVars' on the union of the 'R.lvar' sets
-- of both rules. Otherwise it is kept iff its constraint is satisfiable and
-- it is not an overlay of two variant rules in which every variable of the
-- inner rule's right-hand side also occurs in its left-hand side.
criticalPairProperties
  :: (Ord v, Ord f, ToSExpr f, ToSExpr v, Pretty v)
  => CriticalPair (FId f) (VId v) (VId v)
  -> StateM Bool
criticalPairProperties cp@CriticalPair{..}
  | respVars subst logicalVars = do
      satisfiable <- satGuard constraint >>= smtResultCheck
      return $ satisfiable && not variantOverlay
  | otherwise = return False
 where
  logicalVars = S.union (R.lvar innerRule) (R.lvar outerRule)
  innerRhsVars = S.fromList (T.vars (R.rhs innerRule))
  innerLhsVars = S.fromList (T.vars (R.lhs innerRule))
  variantOverlay =
    isOverlay cp
      && R.isVariantOf innerRule outerRule
      && innerRhsVars `S.isSubsetOf` innerLhsVars

-- | Check whether a critical pair is trivial.
--
-- The two sides of the pair are passed to 'equivalentUnderConstraint'
-- together with the constraint and a predicate that accepts exactly the
-- variables occurring in the constraint.
isTrivialCP
  :: (Ord f, Ord v, Pretty v, Pretty f, ToSExpr f, ToSExpr v, Sorted v)
  => CriticalPair (FId f) v v
  -> StateM Bool
isTrivialCP CriticalPair{inner = innerSide, outer = outerSide, constraint = phi} =
  equivalentUnderConstraint innerSide outerSide phi occursInGuard
 where
  occursInGuard x = x `elem` varsGuard phi

----------------------------------------------------------------------------------------------------
-- auxiliary functions
----------------------------------------------------------------------------------------------------

-- | A critical pair is an overlay when its overlap position is 'epsilon'.
isOverlay :: CriticalPair f v v' -> Bool
isOverlay = (== epsilon) . pos

-- | Check that a substitution maps every variable of the given set to a term
-- that is a variable or a value.
respVars
  :: (Eq f, Ord v)
  => MySub.Subst (FId f) v
  -> S.Set v
  -> Bool
respVars sub = all staysVarOrValue
 where
  staysVarOrValue x =
    let image = MySub.apply sub (T.var x)
     in T.isVar image || T.isValue image

----------------------------------------------------------------------------------------------------
-- pretty printers
----------------------------------------------------------------------------------------------------

-- | Render a critical pair: the inner rule with its position, the outer rule
-- and the pair itself. The constraint is only printed when 'isTopGuard' does
-- not hold for it.
prettyCriticalPair :: (Pretty f, Pretty v, Pretty v') => CriticalPair f v v' -> Doc ann
prettyCriticalPair CriticalPair{..} =
  "Critical Pair" <> line <> indent 2 (vsep [innerSection, outerSection, pairSection])
 where
  innerSection =
    nest 2 $ vsep ["Inner Rule (Position:" <+> pretty pos <> ")", prettyRule innerRule]
  outerSection = nest 2 $ vsep ["Outer Rule", prettyRule outerRule]
  pairSection = nest 2 $ vsep ["Pair", pretty inner <+> "≈" <+> outerWithGuard]
  outerWithGuard
    | isTopGuard constraint = pretty outer
    | otherwise = pretty outer <+> prettyGuard constraint

-- | Render a list of critical pairs, one below the other, using
-- 'prettyCriticalPair'.
prettyCriticalPairs :: (Pretty f, Pretty v, Pretty v') => [CriticalPair f v v'] -> Doc ann
prettyCriticalPairs cps = vsep [prettyCriticalPair cp | cp <- cps]

-- | Render a critical pair over 'FId' symbols using the @FId@ printers for
-- rules, terms and guards. The constraint is only printed when 'isTopGuard'
-- does not hold for it.
prettyCriticalPairFId :: (Pretty f, Pretty v, Pretty v') => CriticalPair (FId f) v v' -> Doc ann
prettyCriticalPairFId CriticalPair{..} =
  "Critical Pair" <> line <> indent 2 (vsep [innerSection, outerSection, pairSection])
 where
  innerSection =
    nest 2 $ vsep ["Inner Rule", "Pos" <+> pretty pos <> ":" <+> prettyRuleFId innerRule]
  outerSection = nest 2 $ vsep ["Outer Rule", prettyRuleFId outerRule]
  pairSection = nest 2 $ vsep ["Pair", prettyTermFId inner <+> "≈" <+> outerWithGuard]
  outerWithGuard
    | isTopGuard constraint = prettyTermFId outer
    | otherwise = prettyTermFId outer <+> prettyGuardFId constraint

-- | Render a list of critical pairs over 'FId' symbols, one below the other,
-- using 'prettyCriticalPairFId'.
prettyCriticalPairsFId :: (Pretty f, Pretty v, Pretty v') => [CriticalPair (FId f) v v'] -> Doc ann
prettyCriticalPairsFId cps = vsep (prettyCriticalPairFId <$> cps)
