{- |
Module      : ParallelCriticalPair
Description : Constrained parallel critical pairs and their computation
Copyright   : (c) Jonas Schöpf, 2023
License     : GPL-3
Maintainer  : jonas.schoepf@uibk.ac.at
Stability   : stable


This module defines constrained parallel critical pairs and provides the
functions to compute them.
-}
module Rewriting.ParallelCriticalPair where

----------------------------------------------------------------------------------------------------
-- imports
----------------------------------------------------------------------------------------------------

import Control.Monad (filterM, guard, zipWithM)
import Control.Monad.State (MonadTrans (lift), evalStateT, gets, modify)
import Data.Containers.ListUtils (nubOrd)
import Data.LCTRS.FIdentifier (FId, getFIdInSort)
import Data.LCTRS.Guard (Guard, conjGuards, isTopGuard, mapGuard, prettyGuard)
import Data.LCTRS.LCTRS (LCTRS (..), getTheoryFuns)
import Data.LCTRS.Position (Pos, epsilon, parallelTo, position)
import Data.LCTRS.Rule (Rule, lhs, rename, renameFresh, rhs, vars)
import qualified Data.LCTRS.Rule as R
import Data.LCTRS.Sort (Sorted (sort))
import Data.LCTRS.Term (Term, fPos)
import qualified Data.LCTRS.Term as T
import Data.LCTRS.VIdentifier (VId, freshV)
import Data.List (intersect)
import qualified Data.Map.Strict as M
import Data.Maybe (isJust, mapMaybe)
import Data.Monad (MonadFresh, StateM, freshInt, sError)
import Data.SExpr (ToSExpr)
import qualified Data.Set as S
import Prettyprinter (Doc, Pretty (pretty), indent, line, nest, vsep, (<+>))
import Rewriting.ConstrainedRewriting.ConstrainedRewriting (trivialVarEqsRule)
import Rewriting.CriticalPair (calcRuleFromSymbol)
import Rewriting.Renaming (Renaming (unpackRenaming), innerRenaming, outerRenaming)
import Rewriting.SMT (satGuard, smtResultCheck)
import Rewriting.Substitution (Subst)
import qualified Rewriting.Substitution as Sub
import Rewriting.Unification (
  UnifPTuple (UnifPTuple),
  UnifProblem (UnifProblem),
  solveUnificationProblem,
 )

----------------------------------------------------------------------------------------------------
-- types
----------------------------------------------------------------------------------------------------

-- | A constrained parallel critical pair.
--
-- Variables are wrapped in 'Renaming'. In this module the inner rules are
-- tagged with 'innerRenaming' and the outer rule with 'outerRenaming'.
-- The field descriptions below reflect how 'extractPCP' fills the fields.
data ParallelCriticalPair f v v' = ParallelCriticalPair
  { innerRules :: [(Pos, Rule f (Renaming v v'))]
  -- ^ inner rules, each paired with the position in 'top' where it is applied
  , outerRule :: Rule f (Renaming v v')
  -- ^ the outer rule
  , top :: Term f (Renaming v v')
  -- ^ term in which the inner redexes occur (stored without applying 'subst')
  , inner :: Term f (Renaming v v')
  -- ^ 'top' with every inner redex replaced by the right-hand side of its
  -- rule, instantiated by 'subst'
  , outer :: Term f (Renaming v v')
  -- ^ right-hand side of 'outerRule', instantiated by 'subst'
  , constraint :: Guard f (Renaming v v')
  -- ^ conjunction of the guards and trivial variable equations of all
  -- involved rules, instantiated by 'subst'
  , subst :: Subst f (Renaming v v')
  -- ^ unifier of the inner rules' left-hand sides with their subterms
  }

-- | Build a 'ParallelCriticalPair' from its components, given in field order:
-- inner rules with positions, outer rule, top term, inner side, outer side,
-- constraint and substitution.
parallelCriticalPair
  :: [(Pos, Rule f (Renaming v v'))]
  -> Rule f (Renaming v v')
  -> Term f (Renaming v v')
  -> Term f (Renaming v v')
  -> Term f (Renaming v v')
  -> Guard f (Renaming v v')
  -> Subst f (Renaming v v')
  -> ParallelCriticalPair f v v'
parallelCriticalPair posRules outerR topTerm innerSide outerSide guardC sigma =
  ParallelCriticalPair
    { innerRules = posRules
    , outerRule = outerR
    , top = topTerm
    , inner = innerSide
    , outer = outerSide
    , constraint = guardC
    , subst = sigma
    }

-- | A candidate redex found by 'collectPossibleParallelSteps': a rule whose
-- left-hand side unifies with the subterm at a given position.
data ParRedex f v = ParRedex
  { subTerm :: Term f v
  -- ^ the subterm the rule's left-hand side is unified with
  , matchingRule :: Rule f v
  -- ^ rule whose left-hand side unifies with 'subTerm'
  , globalPos :: Pos
  -- ^ position of 'subTerm' relative to the term on which the search started
  }

-- | A list of candidate redexes.
--
-- NOTE(review): not referenced in the visible part of this module; it may be
-- used elsewhere — confirm before removing.
newtype PossParallelSteps f v = PossParallelSteps
  { possParallelSteps :: [ParRedex f v]
  }

----------------------------------------------------------------------------------------------------
-- main functionality
----------------------------------------------------------------------------------------------------

-- | Compute the parallel critical pairs of an LCTRS.
--
-- Every rule of the LCTRS is used as the outer rule. The candidate inner rules
-- are the LCTRS rules together with the calculation rules from 'createCalcs'.
computeParallelCPs
  :: (Ord v, Ord f, Pretty v, ToSExpr f, Pretty f)
  => LCTRS (FId f) (VId v)
  -> StateM [ParallelCriticalPair (FId f) (VId v) (VId v)]
computeParallelCPs lctrs@LCTRS{..} = do
  calcRules <- createCalcs lctrs
  pcpsPerOuterRule <- mapM (computePCPsRule (rules ++ calcRules)) rules
  return (concat pcpsPerOuterRule)

-- | Create calculation rules for the theory symbols of an LCTRS.
--
-- Only theory symbols for which 'getFIdInSort' yields a non-empty result are
-- considered. Each of them is turned into a rule by 'calcRuleFromSymbol'.
--
-- TODO remove this and do same as for CPs
createCalcs
  :: LCTRS (FId f) (VId v)
  -> StateM [Rule (FId f) (VId v)]
createCalcs lctrs =
  let theorySymbols = S.toList (getTheoryFuns lctrs)
      hasInSort s = not (null (getFIdInSort s))
  in  mapM calcRuleFromSymbol [s | s <- theorySymbols, hasInSort s]

-- | Compute the parallel critical pairs that have the given rule as outer rule.
--
-- Process:
--
--   * The candidate inner rules are renamed with 'renameFresh' and tagged
--     with 'innerRenaming'; the outer rule is tagged with 'outerRenaming'.
--   * Every candidate parallel step below the outer rule's left-hand side is
--     turned into a pair by 'extractPCP' where possible.
--   * Only pairs accepted by 'validPCP' are returned.
computePCPsRule
  :: (Ord v, Ord f, Pretty v, ToSExpr f, Pretty f)
  => [Rule (FId f) (VId v)]
  -> Rule (FId f) (VId v)
  -> StateM [ParallelCriticalPair (FId f) (VId v) (VId v)]
computePCPsRule rules rule = do
  innerCandidates <- traverse (\r -> rename innerRenaming <$> renameFresh r) rules
  let outerR = rename outerRenaming rule
  steps <- collectPossibleParallelSteps innerCandidates (lhs outerR)
  filterM validPCP (mapMaybe (extractPCP outerR) steps)

-- | Turn one candidate parallel step into a parallel critical pair, if possible.
--
-- The candidate is a list of inner redexes together with the term they were
-- collected from. For candidates produced by 'computePCPsRule', that term is
-- the left-hand side of @rule@.
--
-- The left-hand side of every inner rule is unified with its subterm. The
-- result is 'Nothing' when
--
--   * this unification fails, or
--   * the unifier maps a logical variable ('R.lvar') of the outer rule or of
--     any inner rule to a term that is neither a variable nor a value.
--
-- On success, the pair relates:
--
--   * the term with every redex replaced by its inner rule's right-hand side,
--     instantiated by the unifier, and
--   * the outer rule's right-hand side, instantiated by the unifier.
--
-- It holds under the conjunction of all rules' guards and trivial variable
-- equations ('trivialVarEqsRule'), again instantiated by the unifier.
extractPCP
  :: (Ord v, Ord f, Pretty f, Pretty v)
  => Rule (FId f) (Renaming (VId v) (VId v))
  -> ([ParRedex (FId f) (Renaming (VId v) (VId v))], Term (FId f) (Renaming (VId v) (VId v)))
  -> Maybe (ParallelCriticalPair (FId f) (VId v) (VId v))
extractPCP rule (parReds, term) = do
  sigma <- solveUnificationProblem (parRedToUnifP parReds)
  guard (all (isVarOrValue . Sub.apply sigma . T.var) logicVars)
  Just $
    parallelCriticalPair
      (zip (map globalPos parReds) innerRs)
      rule
      term
      (Sub.apply sigma contracted)
      (Sub.apply sigma (rhs rule))
      (mapGuard (Sub.apply sigma) combinedGuard)
      sigma
 where
  innerRs = map matchingRule parReds

  -- logical variables of the outer rule and of all inner rules
  logicVars = S.unions (map R.lvar (rule : innerRs))

  isVarOrValue t = T.isVar t || T.isValue t

  -- the term with each redex replaced by the right-hand side of its rule
  contracted = foldr contractAt term parReds
  contractAt redex t =
    case T.replaceAt t (globalPos redex) (rhs (matchingRule redex)) of
      Just t' -> t'
      Nothing -> error "ParallelCriticalPair.hs: subterm replacement not possible."

  -- guards and trivial variable equations of the outer and all inner rules
  combinedGuard = foldr addInnerGuard outerGuard innerRs
  outerGuard = conjGuards (R.guard rule) (trivialVarEqsRule rule)
  addInnerGuard r g = conjGuards (trivialVarEqsRule r) (conjGuards (R.guard r) g)

-- | Decide whether a candidate parallel critical pair is kept.
--
-- A candidate is kept when all of the following hold:
--
--   * every inner position is a function position of the outer rule's
--     left-hand side;
--   * no two of the involved rules (the outer rule and each inner rule) share
--     a variable;
--   * the constraint passes the SMT check ('satGuard' followed by
--     'smtResultCheck'; presumably a satisfiability check);
--   * it is not a root overlap of a rule with a variant of itself, unless that
--     rule has variables occurring only in its right-hand side.
--
-- The cheap syntactic checks run first, so no SMT query is issued for
-- candidates they already reject.
validPCP
  :: (Eq v, Eq f, Ord v, Ord f, Pretty v, ToSExpr f, Pretty f)
  => ParallelCriticalPair (FId f) (VId v) (VId v)
  -> StateM Bool
validPCP ParallelCriticalPair{..}
  | not (positionsOk && variablesDisjoint) = return False
  | otherwise = do
      satisfiable <- smtResultCheck =<< satGuard constraint
      if satisfiable then notTrivialRootOverlap else return False
 where
  poss = S.fromList $ map fst innerRules

  -- parallel set has only function symbol positions in outer rule
  positionsOk = poss `S.isSubsetOf` fPos (lhs outerRule)

  -- No two rules share any variables. Pairs are selected by index rather than
  -- by comparing rules with (/=), so that two structurally equal entries
  -- would still be checked against each other. Each unordered pair is
  -- checked only once.
  variablesDisjoint =
    let indexed = zip [0 :: Int ..] (outerRule : map snd innerRules)
    in  and
          [ null $ vars r1 `intersect` vars r2
          | (i, r1) <- indexed
          , (j, r2) <- indexed
          , i < j
          ]

  -- the unifier needs no check: it is an mgu by construction

  -- check that it is not a variant or vars occur only on the rhs
  notTrivialRootOverlap
    | poss == S.singleton epsilon =
        case innerRules of
          [(_, rule)] ->
            return $
              not (S.fromList (T.vars $ rhs rule) `S.isSubsetOf` S.fromList (T.vars $ lhs rule))
                || not (rule `R.isVariantOf` outerRule)
          _ -> sError "ParallelCriticalPair.hs: more than one rule in singleton list. How?"
    | otherwise = return True

-- | Rename every variable of a rule to a fresh one, preserving whether it was
-- tagged with 'innerRenaming' or 'outerRenaming'.
--
-- Fresh names are memoized on the underlying (unwrapped) variable. An inner
-- and an outer occurrence of the same underlying variable therefore receive
-- the same fresh name, each still under its original tag.
renameFreshWithRenaming
  :: (Ord v, MonadFresh m)
  => Rule f (Renaming (VId v) (VId v))
  -> m (Rule f (Renaming (VId v) (VId v)))
renameFreshWithRenaming rule = do
  renamingMap <-
    M.fromList
      <$> evalStateT
        (mapM (\v -> (v,) <$> fresh v) allVars)
        M.empty
  let specialRenaming v = case M.lookup v renamingMap of
        Nothing -> error "ParallelCriticalPair.hs: variable for special renaming in PCP computation not found."
        Just v' -> v'
  return $ rename specialRenaming rule
 where
  -- every variable of the rule, each exactly once
  allVars = nubOrd $ vars rule

  -- keep the inner/outer tag, replace the wrapped variable by a fresh one
  fresh v =
    case unpackRenaming v of
      Left x -> innerRenaming <$> freshVar x
      Right x -> outerRenaming <$> freshVar x

  -- Fresh variable of the same sort. The result is stored in the state, so a
  -- repeated request for the same underlying variable returns the same name.
  -- Previously the cache was only read and never written, which made it dead.
  freshVar x = do
    cached <- gets (M.lookup x)
    case cached of
      Just x' -> return x'
      Nothing -> do
        x' <- freshV (sort x) <$> lift freshInt
        modify (M.insert x x')
        return x'

-- | Enumerate candidate parallel steps below a term.
--
-- Each candidate is a non-empty list of 'ParRedex'es at pairwise parallel
-- function positions of @term@, paired with @term@ itself.
--
-- At every function position the rules are freshly renamed
-- ('renameFreshWithRenaming'). A rule is recorded as a redex at that position
-- if its left-hand side unifies with the subterm there. Variable positions
-- contribute no candidates.
collectPossibleParallelSteps
  :: (Ord v, Ord f, MonadFresh m, Pretty f, Pretty v)
  => [Rule f (Renaming (VId v) (VId v))]
  -> Term f (Renaming (VId v) (VId v))
  -> m [([ParRedex f (Renaming (VId v) (VId v))], Term f (Renaming (VId v) (VId v)))]
collectPossibleParallelSteps _ (T.Var _) = return []
collectPossibleParallelSteps rules term@(T.Fun _ _ args) = do
  rules' <- mapM renameFreshWithRenaming rules
  let
    -- rules whose left-hand side unifies with the whole current term
    matched =
      [ rule
      | rule <- rules'
      , let uniproblem = UnifProblem [UnifPTuple (lhs rule, term)]
      , isJust (solveUnificationProblem uniproblem)
      ]
    matchedParReds = [([ParRedex term matchedRule epsilon], term) | matchedRule <- matched]
  (matchedParReds ++) <$> possCombs rules'
 where
  possCombs rs = combineRecArgs <$> processedArgs rs

  -- Candidates found below each argument. Positions are prefixed with the
  -- argument index, and the recorded term is the current term.
  processedArgs rs =
    zipWithM
      ( \i arg -> do
          compArgs <- collectPossibleParallelSteps rs arg
          let pos = position [i] -- current position
          return
            [ ([adaptPositions pos parRedex | parRedex <- matches], term)
            | (matches, _) <- compArgs
            ]
      )
      [0 ..]
      args

  adaptPositions i ParRedex{..} = ParRedex subTerm matchingRule $ i <> globalPos

  -- All non-empty selections of candidates from the arguments. For the first
  -- argument there are three options:
  --
  --   * combine one of its candidates with one selection from the remaining
  --     arguments (every pairing, i.e. the cartesian product);
  --   * take one of its candidates alone;
  --   * skip it entirely.
  --
  -- Positions from different arguments are always parallel.
  combineRecArgs [] = []
  combineRecArgs (arg : furtherArgs) =
    let combinations = combineRecArgs furtherArgs
    in  [combineParRedex c a | a <- arg, c <- combinations] -- steps below this and other positions
          ++ arg -- only steps below this position
          ++ combinations -- no step in this position

-- | Whether every redex of the first list is at a position parallel to every
-- redex of the second list.
--
-- Despite the name, this requires /all/ pairs to be parallel. It is vacuously
-- 'True' when either list is empty.
anyParallelRedexes :: [ParRedex f v] -> [ParRedex f v] -> Bool
anyParallelRedexes parRs parRs' =
  all (\r -> all (parallelTo (globalPos r) . globalPos) parRs') parRs

-- | Merge two candidate parallel steps into one.
--
-- Merging is only possible when both candidates refer to the same term and
-- all of their redex positions are pairwise parallel. Otherwise this calls
-- 'error'.
combineParRedex
  :: (Eq v, Eq f)
  => ([ParRedex f v], Term f v)
  -> ([ParRedex f v], Term f v)
  -> ([ParRedex f v], Term f v)
combineParRedex (leftRedexes, leftTerm) (rightRedexes, rightTerm) =
  if allParallel && sameTerm
    then (leftRedexes ++ rightRedexes, leftTerm)
    else error "ParallelCriticalPair.hs: Combination of parallel redexes is not possible."
 where
  allParallel = anyParallelRedexes leftRedexes rightRedexes
  sameTerm = leftTerm == rightTerm

----------------------------------------------------------------------------------------------------
-- unification helper functions
----------------------------------------------------------------------------------------------------

-- | Unification problem asking each redex's rule left-hand side to unify with
-- its subterm.
parRedToUnifP :: [ParRedex f v] -> UnifProblem f v
parRedToUnifP redexes = UnifProblem [parRedToTuple r | r <- redexes]

-- | Single unification equation: the rule's left-hand side against the
-- redex's subterm.
parRedToTuple :: ParRedex f v -> UnifPTuple f v
parRedToTuple redex = UnifPTuple (lhs (matchingRule redex), subTerm redex)

----------------------------------------------------------------------------------------------------
-- pretty instance
----------------------------------------------------------------------------------------------------

-- | Render an inner rule together with its position as @Pos <p>: <rule>@.
prettyInnerRule :: (Pretty f, Pretty v) => (Pos, Rule f v) -> Doc ann
prettyInnerRule (p, r) =
  let label = "Pos" <+> pretty p <> ":"
  in  label <+> R.prettyRule r

-- | Render a parallel critical pair.
--
-- The output has four indented sections: the top term with its constraint,
-- the inner rules, the outer rule, and the pair itself. The constraint is
-- appended to the pair only when it is not the trivial (top) guard.
prettyParallelCriticalPair :: (Pretty f, Pretty v) => ParallelCriticalPair f v v -> Doc ann
prettyParallelCriticalPair ParallelCriticalPair{..} =
  "Parallel Critical Pair" <> line <> indent 2 (vsep sections)
 where
  section title body = nest 2 (vsep [title, body])
  sections =
    [ section "Top" (R.prettyCTerm (top, constraint))
    , section "Inner Rule(s)" (vsep (map prettyInnerRule innerRules))
    , section "Outer Rule" (R.prettyRule outerRule)
    , section "Pair" pairDoc
    ]
  equation = pretty inner <+> "≈" <+> pretty outer
  pairDoc
    | isTopGuard constraint = equation
    | otherwise = equation <+> prettyGuard constraint

-- | Render a list of parallel critical pairs, one after another.
prettyParallelCriticalPairs :: (Pretty f, Pretty v) => [ParallelCriticalPair f v v] -> Doc ann
prettyParallelCriticalPairs pcps = vsep [prettyParallelCriticalPair pcp | pcp <- pcps]
