{- |
Module      : TwoParallelClosedness
Description :
Copyright   : (c) Jonas Schöpf, 2023
License     : GPL-3
Maintainer  : jonas.schoepf@uibk.ac.at
Stability   : stable


This module provides the functionality to check a given LCTRS for
quasi-commuting parallel closedness.
-}
module Analysis.Confluence.TwoParallelClosedness (
  isTwoParallelClosed,
  isTwoParallelClosedCP,
  varsUnderPoss,
)
where

import Analysis.Confluence.Confluence
import Data.LCTRS.FIdentifier (FId)
import Data.LCTRS.Guard (Guard, varsGuard)
import Data.LCTRS.LCTRS
import qualified Data.LCTRS.LCTRS as L
import Data.LCTRS.Position (Pos, position)
import Data.LCTRS.Rule (Rule)
import Data.LCTRS.Sort (Sort)
import Data.LCTRS.Term
import qualified Data.LCTRS.Term as T
import Data.LCTRS.VIdentifier (VId, freshV)
import Data.Maybe (
  catMaybes,
  isJust,
 )
import Data.Monad
import Data.SExpr
import qualified Data.Set as S
import Prettyprinter
import Rewriting.ConstrainedRewriting.ParallelStepRewriting (
  parallelRewriteStepNSeq,
  parallelRewriteStepRetPoss,
 )
import Rewriting.ParallelCriticalPair (
  ParallelCriticalPair (..),
  renameFreshWithRenaming,
 )
import Rewriting.Renaming (
  Renaming,
  innerRenaming,
 )

-- | Check a list of parallel critical pairs with 'isTwoParallelClosedCP'.
--
-- Every pair is first tried with the heuristic value @1@; only if that attempt
-- yields 'Nothing' is the pair retried with the larger value @5@.
--
-- Returns 'Nothing' when the list of critical pairs is empty or when at least
-- one pair could not be closed. Otherwise returns 'Just' the critical pairs,
-- each paired with the sequence found for it.
isTwoParallelClosed
  :: (Ord v, Ord f, Pretty f, Pretty v, ToSExpr f, ToSExpr v)
  => LCTRS (FId f) (VId v)
  -> [ParallelCriticalPair (FId f) (VId v) (VId v)]
  -> StateM
      (Maybe [(ParallelCriticalPair (FId f) (VId v) (VId v), CRSeq (FId f) (Renaming (VId v) (VId v)))])
isTwoParallelClosed _ [] = return Nothing
isTwoParallelClosed lc cps = do
  results <- mapM checkPair cps
  return $
    if all isJust results
      then Just (zip cps (catMaybes results))
      else Nothing
 where
  -- Heuristic value used for the second attempt on a critical pair.
  fallbackHeuristic = 5

  -- Try the cheap setting first; fall back to the larger one only on failure.
  checkPair cp = do
    firstTry <- isTwoParallelClosedCP lc 1 cp
    maybe (isTwoParallelClosedCP lc fallbackHeuristic cp) (return . Just) firstTry

-- | Try to close a single parallel critical pair.
--
-- The equation built by 'cEq' from the @inner@ and @outer@ terms of the
-- critical pair is rewritten, together with the pair's constraint, by
-- 'parallelRewriteStepNSeq' restricted to position @[0]@ (presumably the
-- inner side of the equation -- TODO confirm). Each resulting sequence is
-- then extended via 'furtherStepsCRSeq' with the steps produced by
-- 'performStepRightToLeft', and all candidates are handed to 'findTrivial',
-- whose result is returned.
--
-- The 'Int' argument is forwarded unchanged as the first argument of
-- 'parallelRewriteStepNSeq'.
isTwoParallelClosedCP
  :: (Ord v, Ord f, Pretty v, Pretty f, ToSExpr f, ToSExpr v)
  => LCTRS (FId f) (VId v)
  -> Int
  -> ParallelCriticalPair (FId f) (VId v) (VId v)
  -> StateM (Maybe (CRSeq (FId f) (Renaming (VId v) (VId v))))
isTwoParallelClosedCP lc heuristic ParallelCriticalPair{..} = do
  -- NOTE: the order of the rewrite steps on the two sides of the equation is
  --       important for printing the variable condition in the output
  innerSeqs <-
    parallelRewriteStepNSeq
      heuristic
      (Just $ position [0])
      renamer
      freshRenamed
      renamedLctrs
      (cEq inner outer, constraint)
  extendedSeqs <-
    mapM
      (furtherStepsCRSeq RefTraStep (performStepRightToLeft renamedLctrs topVars renamer freshRenamed))
      innerSeqs
  findTrivial (concat extendedSeqs)
 where
  renamer = renameFreshWithRenaming

  -- Fresh variables, wrapped so that they live in the renamed variable type.
  freshRenamed srt ix = innerRenaming (freshV srt ix)

  -- The LCTRS with its variables mapped through 'innerRenaming'.
  renamedLctrs = L.rename innerRenaming lc

  -- Positions of the inner rules of the critical pair.
  innerPositions = S.fromList (map fst innerRules)

  -- Variables occurring in the constraint of the critical pair.
  guardVars = S.fromList (varsGuard constraint)

  -- Variables of @top@ below the inner positions that do not occur in the constraint.
  topVars = varsUnderPoss (`S.notMember` guardVars) top innerPositions

-- | Compute the admissible parallel steps for one constrained term.
--
-- 'parallelRewriteStepRetPoss' is called with position @[1]@ (presumably the
-- right-hand side of the equation -- TODO confirm) and yields reducts together
-- with the set of positions that were rewritten. A reduct is kept only if the
-- variables below those positions in the reduct, excluding variables of the
-- reduct's constraint, all belong to @varsTop@. Every kept reduct is returned
-- as a 'CRSeq' consisting of a single 'ParStep' from the given term and
-- constraint.
performStepRightToLeft
  :: (Ord v, Ord f, Pretty v, ToSExpr v, ToSExpr f, Pretty f)
  => LCTRS (FId f) (Renaming (VId v) (VId v))
  -> S.Set (Renaming (VId v) (VId v))
  -> (Rule (FId f) (Renaming (VId v) (VId v)) -> StateM (Rule (FId f) (Renaming (VId v) (VId v))))
  -> (Sort -> Int -> Renaming (VId v) (VId v))
  -> Term (FId f) (Renaming (VId v) (VId v))
  -> Guard (FId f) (Renaming (VId v) (VId v))
  -> StateM [CRSeq (FId f) (Renaming (VId v) (VId v))]
performStepRightToLeft lc varsTop ren fV term constraint = do
  candidates <- parallelRewriteStepRetPoss lc ren fV (Just $ position [1]) (term, constraint)
  return . map asSequence $ filter satisfiesVarCondition candidates
 where
  -- Variable condition: non-constraint variables below the rewritten
  -- positions of the reduct must already be among @varsTop@.
  satisfiesVarCondition ((reduct, reductGuard), stepPoss) =
    let reductGuardVars = S.fromList (varsGuard reductGuard)
        varsBelow = varsUnderPoss (`S.notMember` reductGuardVars) reduct stepPoss
     in varsBelow `S.isSubsetOf` varsTop

  -- Record the step as a one-step sequence ending in the reduct.
  asSequence (reductPair, stepPoss) =
    CRSeq (reductPair, [(ParStep (Just stepPoss), (term, constraint))])

-- | Collect the variables occurring in the subterms of @term@ at the given
-- positions, keeping only those that satisfy the predicate.
--
-- Calls 'error' if one of the positions does not address a subterm of @term@.
varsUnderPoss :: (Ord v, Ord f) => (v -> Bool) -> Term f v -> S.Set Pos -> S.Set v
varsUnderPoss isNonLogical term poss =
  -- The subterms are visited in descending order, so with the left-biased
  -- union of "Data.Set" the variable sets of larger subterms take precedence.
  S.filter isNonLogical . S.unions . map varsOf $ S.toDescList subterms
 where
  -- Subterm lookup for every requested position; 'Nothing' marks a miss.
  subterms = S.map (T.subtermAt term) poss

  varsOf Nothing = error "TwoParallelClosedness.hs: no term at parallel position found."
  varsOf (Just t) = S.fromList (vars t)
