{- |
Module      : NonConfluence
Description : Non-confluence analysis for LCTRSs
Copyright   : (c) Jonas Schöpf, 2023
License     : GPL-3
Maintainer  : jonas.schoepf@uibk.ac.at
Stability   : stable


This module provides the main functionality for checking whether a given LCTRS
is non-confluent.
-}
module Analysis.Confluence.NonConfluence where

import Analysis.Confluence.Confluence (
  CRResult (MaybeConfluent, NonConfluent),
  NCRMethod (..),
  furtherStepsCRSeq,
 )
import Control.Monad (filterM)
import Data.LCTRS.FIdentifier (FId)
import Data.LCTRS.LCTRS (CRSeq (CRSeq), LCTRS, StepAnn (RefTraStep), prettyCRSeqFId)
import qualified Data.LCTRS.LCTRS as L
import Data.LCTRS.Position (position)
import Data.LCTRS.Rule (prettyCTermFId)
import qualified Data.LCTRS.Rule as R
import Data.LCTRS.Term (cEq)
import Data.LCTRS.VIdentifier (VId, freshV)
import Data.List (find)
import Data.Maybe (isJust)
import Data.Monad (StateM)
import Data.SExpr (ToSExpr (..))
import Prettyprinter (Pretty, indent, vsep, (<+>))
import Rewriting.ConstrainedRewriting.ConstrainedNormalForm (isCNormalFormIn)
import Rewriting.ConstrainedRewriting.ConstrainedRewriting (
  trivialConstrainedEq,
 )
import Rewriting.ConstrainedRewriting.SingleStepRewriting (rewriteN, rewriteNSeq)
import Rewriting.CriticalPair (
  CriticalPair (..),
  isTrivialCP,
  prettyCriticalPair,
 )
import Rewriting.CriticalPairSplitting (tryToCloseOne)
import Rewriting.ParallelCriticalPair (renameFreshWithRenaming)
import Rewriting.Renaming (
  Renaming,
  innerRenaming,
 )
import Utils (findM)

-- | Try to prove that the given LCTRS is non-confluent.
--
-- Critical pairs that 'isTrivialCP' reports as trivial are discarded first.
-- The remaining critical pairs are then examined in order, using
-- 'tryToCloseOne' with 'isNormalFormCPSeq'. The first pair for which a
-- witness is reported yields a 'NonConfluent' result with method
-- 'TwoDifferentNFs', whose explanation is rendered by 'toResultCRSeq'.
-- If no witness is found, 'MaybeConfluent' is returned. This includes the
-- case where there is no non-trivial critical pair.
--
-- NOTE: a term-based variant of this check (using 'isNormalFormCP' and
-- 'toResult') also exists in this module.
checkNonConfluent
  :: (Ord v, Ord f, Pretty f, Pretty v, ToSExpr f, ToSExpr v)
  => LCTRS (FId f) (VId v)
  -> [CriticalPair (FId f) (VId v) (VId v)]
  -> StateM CRResult
checkNonConfluent lc cps = do
  nonTrivialCPs <- filterM (fmap not . isTrivialCP) cps
  firstWitness nonTrivialCPs
 where
  -- Only one witness is ever reported. We therefore stop at the first
  -- critical pair that yields one, instead of running the (potentially
  -- expensive) rewriting and normal-form checks on all remaining pairs.
  firstWitness [] = return MaybeConfluent
  firstWitness (cp : rest) = do
    witness <- tryToCloseOne isNormalFormCPSeq lc cp
    case witness of
      Just crseq -> return $ toResultCRSeq TwoDifferentNFs (cp, crseq)
      Nothing -> firstWitness rest

-- | Build a 'NonConfluent' result that explains a non-confluence witness
-- found for a critical pair.
--
-- The explanation always starts with the critical pair @cp@ and ends with
-- the witnessing rewrite sequence, rendered by 'prettyCRSeqFId':
--
-- * For a 'Left' witness, it states that @cp@ reaches the non-trivial
--   normal form. The critical pair stored in the 'Left' is not printed.
-- * For a 'Right' witness, it states that @cp@ has the stored instance,
--   which reaches the non-trivial normal form.
--
-- The rendered text is paired with the given method.
toResultCRSeq
  :: (Pretty f, Pretty v)
  => NCRMethod
  -> ( CriticalPair (FId f) (VId v) (VId v)
     , Either
        ( CriticalPair (FId f) (VId v) (VId v)
        , CRSeq
            (FId f)
            (Renaming (VId v) (VId v))
        )
        ( CriticalPair (FId f) (VId v) (VId v)
        , CRSeq
            (FId f)
            (Renaming (VId v) (VId v))
        )
     )
  -> CRResult
toResultCRSeq res (cp, witness) =
  NonConfluent (res, show . vsep $ ("*" <+> prettyCriticalPair cp) : details witness)
 where
  -- Lines describing how the witness relates to the original critical pair.
  details (Left (_, reason)) =
    reachesNF "reaches the non-trivial normal form" reason
  details (Right (splitcp, reason)) =
    "has the following instance"
      : indent 2 (prettyCriticalPair splitcp)
      : reachesNF "which reaches the non-trivial normal form" reason

  -- An introductory line followed by the indented rewrite sequence.
  reachesNF intro reason = [intro, indent 2 (prettyCRSeqFId reason)]

-- | Build a 'NonConfluent' result from a critical pair, the constrained
-- critical pair it was split into, and the non-trivial normal form (a
-- constrained term, rendered with 'prettyCTermFId') reached from it.
--
-- This is the term-based counterpart of 'toResultCRSeq'.
toResult
  :: (Pretty f, Pretty v)
  => NCRMethod
  -> ( CriticalPair (FId f) (VId v) (VId v)
     , ( CriticalPair (FId f) (VId v) (VId v)
       , R.CTerm
          (FId f)
          (Renaming (VId v) (VId v))
       )
     )
  -> CRResult
toResult method (original, (splitCP, nf)) =
  let explanation =
        vsep
          [ "*" <+> prettyCriticalPair original
          , "is split into the following CCP"
          , indent 2 (prettyCriticalPair splitCP)
          , "which reaches the non-trivial normal form"
          , indent 2 (prettyCTermFId nf)
          ]
   in NonConfluent (method, show explanation)

-- | Search for a non-trivial normal form reachable from a critical pair.
--
-- The critical pair is turned into the constrained equation
-- @inner ≈ outer@ (via 'cEq') under its constraint. 'rewriteN' is applied
-- at position @[0]@, and then at position @[1]@ on every result. This is
-- presumably rewriting the left-hand and right-hand side in turn, with
-- 'heuristic' bounding the search (TODO confirm the exact meaning of the
-- bound in 'rewriteN').
--
-- The results are filtered in two steps:
--
-- * Keep only those that 'isCNormalFormIn' definitely confirms
--   (@Just True@) to be normal forms with respect to the renamed system.
-- * Return the first of those that 'trivialConstrainedEq' does not
--   consider trivial.
--
-- Returns 'Nothing' if no such normal form is found. 'isNormalFormCPSeq'
-- is the variant of this function that returns a 'CRSeq' instead of a bare
-- constrained term.
isNormalFormCP
  :: (Ord v, Ord f, Pretty f, Pretty v, ToSExpr f, ToSExpr v)
  => LCTRS (FId f) (VId v)
  -> CriticalPair (FId f) (VId v) (VId v)
  -> StateM
      ( Maybe
          ( R.CTerm
              (FId f)
              (Renaming (VId v) (VId v))
          )
      )
isNormalFormCP lc CriticalPair{..} = do
  redsS <- rewriteN heuristic (Just $ position [0]) ren fV lcR (cEq s t) c
  redsT <-
    concat <$> mapM (uncurry $ rewriteN heuristic (Just $ position [1]) ren fV lcR) redsS
  -- Only definite answers count: an inconclusive check (anything other
  -- than @Just True@) is treated as "not a normal form".
  normalforms <- filterM (fmap (== Just True) . isCNormalFormIn fV lcR) redsT
  findM (fmap not . trivialConstrainedEq) normalforms
 where
  ren = renameFreshWithRenaming

  -- Wrap 'freshV' in 'innerRenaming' so that fresh variables live in the
  -- same (renamed) variable space as 'lcR'.
  fV x i = innerRenaming $ freshV x i

  s = inner
  t = outer
  c = constraint

  heuristic = 5 -- min 10 $ length (rules lc)
  lcR = L.rename innerRenaming lc

-- | Search for a rewrite sequence from a critical pair to a non-trivial
-- normal form.
--
-- This is the sequence-recording counterpart of 'isNormalFormCP'. The
-- critical pair is turned into the constrained equation @inner ≈ outer@
-- (via 'cEq') under its constraint. 'rewriteNSeq' is applied at position
-- @[0]@, and every resulting sequence is then extended by 'rewriteNSeq' at
-- position @[1]@ using 'furtherStepsCRSeq' with the 'RefTraStep'
-- annotation. The bound 'heuristic' is passed to 'rewriteNSeq' (TODO
-- confirm its exact meaning there).
--
-- The sequences are filtered in two steps:
--
-- * Keep only those whose final constrained term 'isCNormalFormIn'
--   definitely confirms (@Just True@) to be a normal form.
-- * Return the first of those whose final term 'trivialConstrainedEq'
--   does not consider trivial.
--
-- Returns 'Nothing' if no such sequence is found.
isNormalFormCPSeq
  :: (Ord v, Ord f, Pretty f, Pretty v, ToSExpr f, ToSExpr v)
  => LCTRS (FId f) (VId v)
  -> CriticalPair (FId f) (VId v) (VId v)
  -> StateM
      ( Maybe
          ( CRSeq
              (FId f)
              (Renaming (VId v) (VId v))
          )
      )
isNormalFormCPSeq lc CriticalPair{..} = do
  redSeqsS <- rewriteNSeq heuristic (Just $ position [0]) ren fV lcR (cEq s t) c
  redSeqsT <-
    concat
      <$> mapM
        (furtherStepsCRSeq RefTraStep (rewriteNSeq heuristic (Just $ position [1]) ren fV lcR))
        redSeqsS
  -- NOTE: disabled alternative for the second phase, using parallel
  -- rewriting instead of 'rewriteNSeq':
  -- concat
  --   <$> mapM
  --     ( furtherStepsCRSeq
  --         RefTraStep
  --         (curry $ parallelRewriteStepNSeq heuristic (Just $ position [1]) ren fV lcR)
  --     )
  --     redSeqsS
  --
  -- Only definite answers count: an inconclusive check (anything other
  -- than @Just True@) is treated as "not a normal form".
  normalforms <-
    filterM (\(CRSeq (ct, _)) -> (== Just True) <$> isCNormalFormIn fV lcR ct) redSeqsT
  findM (\(CRSeq (ct, _)) -> not <$> trivialConstrainedEq ct) normalforms
 where
  ren = renameFreshWithRenaming

  -- Wrap 'freshV' in 'innerRenaming' so that fresh variables live in the
  -- same (renamed) variable space as 'lcR'.
  fV x i = innerRenaming $ freshV x i

  s = inner
  t = outer
  c = constraint

  heuristic = 5 -- min 10 $ length (rules lc)
  lcR = L.rename innerRenaming lc
