{- |

Module      : RedundantRules
Description : Redundant rules technique for confluence analysis of LCTRSs
Copyright   : (c) Jonas Schöpf, 2025
License     : GPL-3
Maintainer  : jonas.schoepf@uibk.ac.at
Stability   : experimental


This module provides the main functionality for the redundant rules technique
as well as different approximations. The approximations are taken from
[10.4230/LIPIcs.RTA.2015.257](https://doi.org/10.4230/LIPIcs.RTA.2015.257) and
lifted into the setting of LCTRSs.
-}
module Analysis.Confluence.Transformation.RedundantRules (
  -- * Redundant Rules - Basic Checks
  isRedundantIn,
  isReachable,
  crByRedundantRules,

  -- * Redundant Rules - Approximations
  addCPJoiningRules,
  addRewrittenRhss,
  removeJoinableRules,
)
where

import Analysis.Confluence.Confluence (CRResult (..), findTrivial, furtherStepsCRSeq)
import Control.Monad.IO.Class (MonadIO (liftIO))
import Data.LCTRS.FIdentifier (FId (CEquation))
import Data.LCTRS.Guard (Guard, conjGuards, mapGuardM)
import Data.LCTRS.LCTRS (
  CRSeq (CRSeq),
  LCTRS (..),
  StepAnn (RefTraStep),
  prettyLCTRSDefault,
  updateRules,
 )
import Data.LCTRS.Position (position)
import Data.LCTRS.Rule (Rule (..), createRule, prettyRuleFId, renameFresh)
import Data.LCTRS.Term (Term (Fun), cEq)
import Data.LCTRS.VIdentifier (VId, freshV)
import Data.List (delete, find, (\\))
import Data.Maybe (catMaybes, isJust)
import Data.Monad (StateM)
import Data.SExpr (ToSExpr)
import Prettyprinter (Pretty)
import qualified Prettyprinter as PP
import Rewriting.ConstrainedRewriting.ConstrainedRewriting (trivialConstrainedEq, trivialVarEqsRule)
import Rewriting.ConstrainedRewriting.MultiStepRewriting (multistepN, multistepNSeq)
import Rewriting.CriticalPair (CriticalPair (..), computeCPs)
import Rewriting.Renaming (removeRenamingSafely)
import Utils (findM, shuffle)

----------------------------------------------------------------------------------------------------
-- redundant rules basics
----------------------------------------------------------------------------------------------------

{- | 'isRedundantIn' @rule@ @lctrs@ decides whether @rule@ is redundant in the
LCTRS @lctrs@.  For @rule@ \(= \ell \to r~[\varphi]\) the check is delegated to
'isReachable', which is run on (@lctrs@ \(\setminus\) @rule@) with start
\(\ell\), goal \(r\) and constraint \(\varphi\).
-}
isRedundantIn
  :: (Ord v, ToSExpr f, Pretty f, Ord f, ToSExpr v, Pretty v)
  => Rule (FId f) (VId v)
  -> LCTRS (FId f) (VId v)
  -> StateM Bool
isRedundantIn rule@Rule{..} lctrs =
  -- 'lhs', 'rhs' and 'guard' are the fields of @rule@ (bound by the wildcard).
  isReachable withoutRule lhs rhs guard
 where
  -- The system in which the rule itself may no longer be used.
  withoutRule = updateRules lctrs (delete rule (rules lctrs))

{- | 'isReachable' @lctrs@ @start@ @goal@ @constraint@ checks whether @goal@ can
be reached from @start@ under @constraint@ using the LCTRS @lctrs@. The
constrained equation @start@ \(\approx\) @goal@ [@constraint@] is handed to
'multistepN' together with the bound @stepBound@ (currently 2) and the position
@[0]@; the result is 'True' iff 'trivialConstrainedEq' accepts one of the
returned reducts.

NOTE(review): the search is bounded by @stepBound@, so a 'False' answer
presumably only means "not found within the bound" - confirm against
'multistepN'.
-}
isReachable
  :: (Ord v, ToSExpr f, Pretty f, Ord f, ToSExpr v, Pretty v)
  => LCTRS (FId f) (VId v)
  -> Term (FId f) (VId v)
  -> Term (FId f) (VId v)
  -> Guard (FId f) (VId v)
  -> StateM Bool
isReachable lctrs start goal constraint = do
  reducts <- multistepN stepBound startSide renameFresh freshV lctrs equation constraint
  isJust <$> findM trivialConstrainedEq reducts
 where
  stepBound = 2

  -- Only the first argument of the equation (the @start@ side) is rewritten.
  startSide = Just (position [0])

  equation = cEq start goal

{- | 'crByRedundantRules' @rdMethod@ @crMethod@ @lctrs@ @cps@ applies the redundant rules method
@rdMethod@ to the @lctrs@, computes the constrained critical pairs of the transformed LCTRS and
tries to show confluence of the transformed LCTRS using the confluence method @crMethod@.

The fourth argument @cps@ is ignored: the critical pairs are recomputed for the transformed
LCTRS, because the transformation changes the set of rules.

The result of @crMethod@ is mapped as follows:

* 'Confluent': the textual description of the transformation is prepended to the proof.

* 'MaybeConfluent': passed on unchanged.

* 'NonConfluent': reported as 'MaybeConfluent'. This function cannot tell whether the given
  @rdMethod@ also reflects non-confluence back to the original system, so the verdict is not
  transferred (and the analysis no longer aborts with 'undefined').
-}
crByRedundantRules
  :: (Ord v, Ord f, Pretty v, Pretty f, ToSExpr f, ToSExpr v)
  => (LCTRS (FId f) (VId v) -> StateM (LCTRS (FId f) (VId v), String))
  -> (LCTRS (FId f) (VId v) -> [CriticalPair (FId f) (VId v) (VId v)] -> StateM CRResult)
  -> LCTRS (FId f) (VId v)
  -> [CriticalPair (FId f) (VId v) (VId v)]
  -> StateM CRResult
crByRedundantRules redundantRulesMethod crMethod lctrs _ = do
  (lctrs', proof) <- redundantRulesMethod lctrs
  cps <- computeCPs lctrs'
  -- The confluence method must see the transformed system: the critical pairs
  -- above belong to @lctrs'@, not to the original @lctrs@.
  crResult <- crMethod lctrs' cps
  return $ case crResult of
    Confluent (m, s) -> Confluent (m, proof <> s)
    -- Conservative answer instead of crashing; see the Haddock comment.
    NonConfluent _ -> MaybeConfluent
    MaybeConfluent -> MaybeConfluent

----------------------------------------------------------------------------------------------------
-- Transformations based on redundant rules
----------------------------------------------------------------------------------------------------

{- | 'addCPJoiningRules' @lctrs@ computes all constrained critical pairs (CCPs) of @lctrs@.
For every CCP \(s \approx t~[\varphi]\) for which 'joinableCPToRule' finds a joining
sequence \(s \approx t~[\varphi] \to^{*} s' \approx t'~[\psi]\) with
\(s' \approx t'~[\psi]\) trivial, the rules \(s \to s'~[\varphi]\) and
\(t \to t'~[\varphi]\) are appended to the rules of @lctrs@.

The result is the extended LCTRS together with a rendered, human-readable
description of the transformation.
-}
addCPJoiningRules
  :: (Ord v, ToSExpr f, Pretty f, Ord f, ToSExpr v, Pretty v)
  => LCTRS (FId f) (VId v)
  -> StateM (LCTRS (FId f) (VId v), String)
addCPJoiningRules lctrs = do
  criticalPairs <- computeCPs lctrs
  newRules <- concat <$> traverse (joinableCPToRule lctrs) criticalPairs
  -- NOTE: the new rules are added as they are; a variant that additionally
  -- filtered them with 'isRedundantIn' was tried and is currently disabled.
  let extended = updateRules lctrs (rules lctrs <> newRules)
      details =
        PP.vsep
          [ "Create rules from joinable critical pairs:"
          , PP.indent 2 (PP.vsep (map prettyRuleFId newRules))
          , "We obtain the following LCTRS:"
          , PP.indent 2 (prettyLCTRSDefault extended)
          ]
      report = PP.vsep ["Redundant Rules:", PP.indent 2 details] PP.<> PP.line
  pure (extended, show report)

{- | 'joinableCPToRule' @lctrs@ @ccp@ checks whether the constrained critical pair
@ccp@ \(= s \approx t~[\varphi]\) is joinable, i.e., whether a sequence
\(s \approx t~[\varphi] \to^{*} s' \approx t'~[\psi]\) with
\(s' \approx t'~[\psi]\) trivial is found. If so, we return
[\(s \to s'~[\varphi]\), \(t \to t'~[\varphi]\)]; otherwise the empty list.

The search first rewrites at position @[0]@ of the equation and then continues
every resulting sequence at position @[1]@; both phases pass the bound
@stepBound@ (currently 2) to 'multistepNSeq'.
-}
joinableCPToRule
  :: (Ord v, Ord f, Pretty f, Pretty v, ToSExpr f, ToSExpr v)
  => LCTRS (FId f) (VId v)
  -> CriticalPair (FId f) (VId v) (VId v)
  -> StateM [Rule (FId f) (VId v)]
joinableCPToRule lctrs CriticalPair{..} = do
  peakInner <- removeRenamingSafely inner
  peakOuter <- removeRenamingSafely outer
  peakGuard <- mapGuardM removeRenamingSafely constraint
  firstSide <- stepsAt 0 (cEq peakInner peakOuter, peakGuard)
  bothSides <-
    concat <$> mapM (furtherStepsCRSeq RefTraStep (curry (stepsAt 1))) firstSide
  witness <- findTrivial bothSides
  case witness of
    Nothing -> return []
    Just (CRSeq ((reached, _), _))
      -- The two arguments of the reached equation become the new right-hand sides.
      | Fun _ (CEquation _) [joinedInner, joinedOuter] <- reached ->
          return
            [ createRule peakInner joinedInner peakGuard
            , createRule peakOuter joinedOuter peakGuard
            ]
      -- Anything that is not an equation is silently skipped (no failure wanted here).
      | otherwise -> return []
 where
  stepBound = 2

  -- Bounded multi-step rewriting restricted to one argument of the equation.
  stepsAt side = multistepNSeq stepBound (Just $ position [side]) renameFresh freshV lctrs

{- | 'addRewrittenRhss' @lctrs@ rewrites the right-hand side of every constrained
rewrite rule of @lctrs@ and constructs new rules where possible. For a rule
\(\ell \to r~[\varphi]\) we rewrite \(r~[\varphi] \to^{*} r'~[\varphi']\); for
the first reduct found with \(r \neq r'\) the rule \(\ell \to r'~[\varphi]\)
(with the original constraint) is appended to @lctrs@.

The result is the extended LCTRS together with a rendered, human-readable
description of the transformation.
-}
addRewrittenRhss
  :: (Ord v, ToSExpr f, Pretty f, Ord f, ToSExpr v, Pretty v)
  => LCTRS (FId f) (VId v)
  -> StateM (LCTRS (FId f) (VId v), String)
addRewrittenRhss lctrs = do
  newRules <- catMaybes <$> traverse rewriteRhs (rules lctrs)
  let extended = updateRules lctrs (rules lctrs <> newRules)
      details =
        PP.vsep
          [ "Create rules by rewritting right-hand sides of rules:"
          , PP.indent 2 (PP.vsep (map prettyRuleFId newRules))
          , "We obtain the following LCTRS:"
          , PP.indent 2 (prettyLCTRSDefault extended)
          ]
      report = PP.vsep ["Redundant Rules:", PP.indent 2 details] PP.<> PP.line
  pure (extended, show report)
 where
  stepBound = 2

  -- Yields at most one new rule per original rule.
  rewriteRhs rule@Rule{..} = do
    -- NOTE: we add all extra variables x as x = x to the constraint
    let guardWithEC = conjGuards guard (trivialVarEqsRule rule)
    reducts <- multistepN stepBound Nothing renameFresh freshV lctrs rhs guardWithEC
    -- NOTE: only the first reduct that differs from the original right-hand
    -- side is used; adding a rule for every such reduct would require this
    -- helper to return a list instead of a 'Maybe'.
    return $ (\newRhs -> createRule lhs newRhs guard) . fst <$> find ((rhs /=) . fst) reducts

{- | 'removeJoinableRules' @lctrs@ removes rules \(\ell \to r~[\varphi]\) that are
joinable, i.e., \(\ell \approx r~[\varphi] \to^{*} \ell' \approx r'~[\psi]\)
for \(\ell' \approx r'~[\psi]\) being trivial.

The rules are visited one after another in the order produced by 'shuffle'.
Joinability of a rule is checked (via 'isJoinableRule') in the current system
without that rule, and rules removed earlier stay removed for later checks.

The result is the reduced LCTRS together with a rendered, human-readable
description of the transformation.
-}
removeJoinableRules
  :: (Ord v, ToSExpr f, Pretty f, Ord f, ToSExpr v, Pretty v)
  => LCTRS (FId f) (VId v)
  -> StateM (LCTRS (FId f) (VId v), String)
removeJoinableRules lctrs = do
  candidates <- liftIO (shuffle (rules lctrs))
  pruned <- prune lctrs candidates
  let keptRules = rules pruned
      reduced = updateRules lctrs keptRules
      details =
        PP.vsep
          [ "Remove Rules where the right-hand side and left-hand side are joinable:"
          , PP.indent 2 (PP.vsep (map prettyRuleFId (rules lctrs \\ keptRules)))
          , "We obtain the following LCTRS:"
          , PP.indent 2 (prettyLCTRSDefault reduced)
          ]
      report = PP.vsep ["Redundant Rules:", PP.indent 2 details] PP.<> PP.line
  pure (reduced, show report)
 where
  -- Threads the current system through the candidate list and returns it.
  prune current [] = pure current
  prune current (candidate : rest) = do
    let without = updateRules current (delete candidate (rules current))
    joinable <- isJoinableRule without candidate
    -- Drop the candidate only if the remaining rules can join its two sides.
    prune (if joinable then without else current) rest

{- | 'isJoinableRule' @lctrs@ @rule@ checks whether a rule \(\ell \to r~[\varphi]\) is
joinable, i.e., \(\ell \approx r~[\varphi] \to^{*} \ell' \approx r'~[\psi]\)
for \(\ell' \approx r'~[\psi]\) being trivial.

The search first rewrites at position @[0]@ of the equation and then continues
every resulting sequence at position @[1]@; both phases pass the bound
@stepBound@ (currently 2) to 'multistepNSeq'.
-}
isJoinableRule
  :: (Ord v, Ord f, Pretty f, Pretty v, ToSExpr f, ToSExpr v)
  => LCTRS (FId f) (VId v)
  -> Rule (FId f) (VId v)
  -> StateM Bool
isJoinableRule lctrs rule@Rule{..} = do
  -- NOTE: we add all extra variables x as x = x to the constraint
  let equation = (cEq lhs rhs, conjGuards guard (trivialVarEqsRule rule))
  afterFirstSide <-
    multistepNSeq stepBound (Just $ position [0]) renameFresh freshV lctrs equation
  afterBothSides <-
    mapM (furtherStepsCRSeq RefTraStep (curry rewriteSecondSide)) afterFirstSide
  isJust <$> findTrivial (concat afterBothSides)
 where
  stepBound = 2

  -- Continuation used to extend a sequence by steps at position [1].
  rewriteSecondSide =
    multistepNSeq stepBound (Just $ position [1]) renameFresh freshV lctrs
