{- |
Module      : Parser.Strategy
Description : Strategy parsers and processor tables for confluence and termination analysis
Copyright   : (c) Jonas Schöpf, 2024
License     : GPL-3
Maintainer  : jonas.schoepf@uibk.ac.at
Stability   : stable

This module provides parsers for the input strategy and
defines all processors.
-}
module Parser.Strategy where

import Analysis.Confluence.Confluence (CRResult)
import Analysis.Confluence.DevelopmentClosedness (
  isAlmostDevelopmentClosed,
  isAlmostDevelopmentClosedH,
  isDevelopmentClosed,
 )
import Analysis.Confluence.NewmansLemma (isCRbyNewmansLemma)
import Analysis.Confluence.NonConfluence (checkNonConfluent)
import Analysis.Confluence.Orthogonality (
  isOrthogonal,
  isWeaklyOrthogonal,
 )
import Analysis.Confluence.ParallelClosedness (
  isAlmostParallelClosed,
  isAlmostParallelClosedH,
  isParallelClosed,
 )
import Analysis.Confluence.StrongClosedness (isStrongClosed)
import Analysis.Confluence.Toyama81 (toyama81Applies)
import Analysis.Confluence.Transformation.RedundantRules (
  addCPJoiningRules,
  addRewrittenRhss,
  crByRedundantRules,
  removeJoinableRules,
 )
import Analysis.Termination.CheckingTermination (
  terminatingByBVMatrixInters,
  terminatingByBVMatrixIntersDP,
  terminatingByConsMatrixInters,
  terminatingByConsMatrixIntersDP,
  terminatingByDependencyGraph,
  terminatingByMatrixInters,
  terminatingByMatrixIntersDP,
  terminatingByRPO,
  terminatingByRPOwithDP,
  terminatingByReductionPair,
  terminatingBySubtermCriterion,
  terminatingByValueCriterion,
  terminatingByValueCriterionPol,
 )
import Analysis.Termination.ConstrainedReductionOrder (rpoRP)
import Analysis.Termination.MatrixInterpretations (matrixInterpretationsRP)
import Analysis.Termination.MatrixInterpretationsBV (matrixInterpretationsBVRP)
import Analysis.Termination.MatrixInterpretationsConstraints (matrixInterpretationsConsRP)
import Analysis.Termination.SubtermCriterion (subtermCriterionRP)
import Analysis.Termination.Termination (SNInfoList, SNResult)
import Analysis.Termination.ValueCriterion (valueCriterionRP)
import Analysis.Termination.ValueCriterionPol (valueCriterionPolRP)
import Control.Applicative (empty, optional)
import Data.Functor (($>))
import Data.LCTRS.FIdentifier (FId)
import Data.LCTRS.LCTRS (LCTRS)
import Data.LCTRS.VIdentifier (VId)
import Data.List (sortOn)
import Data.Monad (StateM)
import Data.Ord (Down (..))
import Data.SExpr (ToSExpr)
import Data.SMT (lexeme, numeral, whitespace)
import Data.Text (Text, pack)
import Data.Void (Void)
import Prettyprinter (Pretty, encloseSep, pretty, (<+>))
import Rewriting.CriticalPair (CriticalPair)
import Rewriting.ParallelCriticalPair (computeParallelCPs)
import Text.Megaparsec (
  Parsec,
  between,
  chunk,
  eof,
  errorBundlePretty,
  runParser,
  sepBy1,
  try,
  (<|>),
 )
import Text.Megaparsec.Char (char)
import Type.SortLCTRS (DefaultFId, DefaultVId)

----------------------------------------------------------------------------------------------------
-- CONFLUENCE
----------------------------------------------------------------------------------------------------

-- | A confluence processor: a function from an 'LCTRS' and a list of
-- critical pairs to a 'CRResult' computed in 'StateM'.
--
-- NOTE(review): the critical pairs are presumably those of the given LCTRS;
-- confirm against the callers.
type CRProcessor f v = LCTRS f v -> [CriticalPair f v v] -> StateM CRResult

-- | Table of all confluence and non-confluence processors.
--
-- Every entry is a triple consisting of
--
--   * the keyword that selects the processor in a strategy string,
--   * a parser for whatever follows the keyword, yielding the processor
--     ('return' for processors that take no arguments), and
--   * a description that is shown in the help text.
--
-- The redundant-rules processors (@rrjs@, @rrrhs@, @rrdel@) expect one
-- confluence method in square brackets, e.g. @rrjs[apch]@; the selected
-- method is handed to 'crByRedundantRules'.
processorsCrNcr
  :: (Ord v, Ord f, Pretty v, Pretty f, ToSExpr f, ToSExpr v)
  => [ ( Text
       , Parsec Void Text (CRProcessor (FId f) (VId v))
       , Text
       )
     ]
processorsCrNcr =
  [ ("o", return isOrthogonal, "Orthogonality")
  , ("wo", return isWeaklyOrthogonal, "Weak Orthogonality")
  , ("sc", return isStrongClosed, "Strongly Closedness")
  , ("pc", return isParallelClosed, "Parallel Closedness")
  ,
    ( "apc"
    , return isAlmostParallelClosed
    , "Almost Parallel Closedness (both sequences parallel steps)"
    )
  , ("apch", return isAlmostParallelClosedH, "Almost Parallel Closedness (heuristic)")
  , ("dc", return isDevelopmentClosed, "Development Closedness")
  ,
    ( "adc"
    , return isAlmostDevelopmentClosed
    , "Almost Development Closedness (both sequences multi-steps)"
    )
  , ("adch", return isAlmostDevelopmentClosedH, "Almost Development Closedness (heuristic)")
  ,
    ( "rrjs"
    , redundantRulesAddCPJoiningRules
    , "Redundant Rules by adding joinable CPs as rules (usage rrjs[METHOD] where METHOD is one of: o, wo, sc, apc, apch, adc, adch, kbc, pcp)"
    )
  ,
    ( "rrrhs"
    , redundantRulesAddRewrittenRhss
    , "Redundant Rules by rewriting right-hand side of rules (usage: see rrjs)"
    )
  ,
    ( "rrdel"
    , redundantRulesRemoveJoinableRules
    , "Redundant Rules by removing joinable rules (usage: see rrjs)"
    )
  , ("dnfs", return checkNonConfluent, "Non Confluence")
  , ("kbc", return isCRbyNewmansLemma, "Knuth-Bendix Criterion")
  , ("pcp", return toyama, "Toyama 81 based on PCPs")
  ]
 where
  -- Runs 'toyama81Applies' on the parallel critical pairs computed from the
  -- given LCTRS, in addition to the critical pairs passed in.
  toyama lctrs cps = toyama81Applies lctrs cps =<< computeParallelCPs lctrs

  lexchunk = lexeme . chunk

  -- Methods accepted inside the brackets of the redundant-rules processors.
  --
  -- The alternatives are tried in order and the first keyword that matches
  -- wins, so a keyword that extends another one ("apch" vs. "apc", "adch"
  -- vs. "adc") has to come before its prefix. Otherwise "rrjs[apch]" would
  -- read "apc" and then fail on the closing bracket.
  pCRList =
    lexeme $
      (lexchunk "o" $> isOrthogonal)
        <|> (lexchunk "wo" $> isWeaklyOrthogonal)
        <|> (lexchunk "sc" $> isStrongClosed)
        -- <|> (lexchunk "pc" $> isParallelClosed)
        <|> (lexchunk "apch" $> isAlmostParallelClosedH)
        <|> (lexchunk "apc" $> isAlmostParallelClosed)
        -- <|> (lexchunk "dc" $> isDevelopmentClosed)
        <|> (lexchunk "adch" $> isAlmostDevelopmentClosedH)
        <|> (lexchunk "adc" $> isAlmostDevelopmentClosed)
        -- <|> (lexchunk "dnfs" $> checkNonConfluent)
        <|> (lexchunk "kbc" $> isCRbyNewmansLemma)
        <|> (lexchunk "pcp" $> toyama)

  -- Argument parsers of the three redundant-rules processors: each one reads
  -- "[METHOD]" and passes the method together with its rule transformation
  -- to 'crByRedundantRules'.
  redundantRulesAddCPJoiningRules =
    crByRedundantRules addCPJoiningRules <$> between "[" "]" pCRList

  redundantRulesAddRewrittenRhss =
    crByRedundantRules addRewrittenRhss <$> between "[" "]" pCRList

  redundantRulesRemoveJoinableRules =
    crByRedundantRules removeJoinableRules <$> between "[" "]" pCRList

-- | Parse a confluence strategy, i.e. a non-empty, comma-separated list of
-- the keywords registered in 'processorsCrNcr' (each followed by its
-- arguments, if any), into the corresponding processors.
--
-- The whole input has to be consumed. On malformed input this function calls
-- 'error' with the pretty-printed parse error.
parseCRStrategy
  :: (Ord v, Ord f, Pretty v, Pretty f, ToSExpr f, ToSExpr v)
  => String
  -> [CRProcessor (FId f) (VId v)]
parseCRStrategy input =
  either (error . errorBundlePretty) id parsed
 where
  parsed = runParser (whitespace *> strategy <* eof) "" (pack input)

  strategy = anyProcessor `sepBy1` lexeme (char ',')

  -- Every alternative is wrapped in 'try' so that a failure after the
  -- keyword has been read falls through to the next alternative.
  anyProcessor = foldr (\p others -> try p <|> others) empty keywordParsers

  -- Keywords are tried in descending order. A keyword that extends another
  -- one (e.g. "apch" and "apc") sorts after it, so in descending order the
  -- longer keyword is tried first.
  keywordParsers =
    [ lexeme (chunk keyword) *> lexeme arguments
    | (keyword, arguments, _) <- sortOn (Down . keywordOf) processorsCrNcr
    ]

  keywordOf (keyword, _, _) = keyword

-- | Help text for the confluence strategy option: an introductory sentence
-- followed by every keyword of 'processorsCrNcr' with its description.
crStrategyExplanation :: String
crStrategyExplanation = show (intro <+> encloseSep "{" "}" "" entries)
 where
  intro = "Specify a comma-separated list of methods (if not set, all are used):"

  entries =
    [ "'" <> pretty keyword <> "'" <+> "=" <+> pretty description
    | (keyword, _, description) <- table
    ]

  -- The table is polymorphic in the identifier types; only keywords and
  -- descriptions are read here, so the default identifiers are used to fix
  -- the type.
  table :: [(Text, Parsec Void Text (CRProcessor (FId DefaultFId) (VId DefaultVId)), Text)]
  table = processorsCrNcr

-- | Number of confluence methods registered in 'processorsCrNcr'.
numberOfAllCRMethods :: Int
numberOfAllCRMethods = length registered
 where
  -- Counting the entries requires fixing the identifier types of the
  -- polymorphic table; the default identifiers are used for that.
  registered :: [(Text, Parsec Void Text (CRProcessor (FId DefaultFId) (VId DefaultVId)), Text)]
  registered = processorsCrNcr

----------------------------------------------------------------------------------------------------
-- TERMINATION
----------------------------------------------------------------------------------------------------

-- | A termination processor: a function from an 'LCTRS' to an 'SNResult'
-- paired with an 'SNInfoList', computed in 'StateM'.
--
-- NOTE(review): an additional @DPProblem f v@ argument was commented out in
-- this definition; whether it is planned or retired cannot be told from
-- here — confirm before removing this note.
type SNProcessor f v = LCTRS f v -> StateM (SNResult, SNInfoList f v)

-- | Table of all termination processors.
--
-- Every entry is a triple consisting of
--
--   * the keyword that selects the processor in a strategy string,
--   * a parser for whatever follows the keyword, yielding the processor
--     ('return' for processors that take no arguments), and
--   * a description that is shown in the help text.
--
-- The matrix-interpretation processors read one or two numerals after the
-- keyword; @rpair@ reads a bracketed, comma-separated list of reduction
-- pair methods.
processorsSN
  :: (Ord v, Ord f, Pretty v, Pretty f, ToSExpr f, ToSExpr v)
  => [ ( Text
       , Parsec Void Text (SNProcessor (FId f) (VId v))
       , Text
       )
     ]
processorsSN =
  [ ("dpg", return terminatingByDependencyGraph, "DP graph with no SCCs")
  , ("crodp", return terminatingByRPOwithDP, "Constrained Reduction Order on DP problem")
  , ("cro", return terminatingByRPO, "Constrained Reduction Order")
  , ("vc", return terminatingByValueCriterion, "Value Criterion")
  ,
    ( "vclin"
    , return terminatingByValueCriterionPol
    , "Value Criterion with linear combination projection"
    )
  , ("sc", return terminatingBySubtermCriterion, "Subterm Criterion")
  ,
    ( "mis"
    , plainMIs terminatingByMatrixInters
    , "Matrix Interpretations given a dimension, e.g. 'mis 2' and optional upper bound for entries, e.g. 'mis 2 4' for naturals with less or equal than 4 bits."
    )
  ,
    ( "misdp"
    , plainMIs terminatingByMatrixIntersDP
    , "Matrix Interpretations on DP problem given a dimension, e.g. 'misdp 2' and optional upper bound for entries, e.g. 'misdp 2 4' for naturals with less or equal than 4 bits."
    )
  ,
    ( "misbv"
    , plainBVMIs terminatingByBVMatrixInters
    , "Matrix Interpretations given a dimension and number of bits used in bit vector arithmetic; e.g. 'misbv 2 4' for dimension 2 with 4 bits."
    )
  ,
    ( "misbvdp"
    , plainBVMIs terminatingByBVMatrixIntersDP
    , "Matrix Interpretations on DP problem given a dimension and number of bits used in bit vector arithmetic; e.g. 'misbvdp 2 4' for dimension 2 with 4 bits."
    )
  ,
    ( "miscons"
    , plainMIs terminatingByConsMatrixInters
    , "Matrix Interpretations given a dimension; e.g. 'miscons 2' for dimension 2."
    )
  ,
    ( "misconsdp"
    , plainMIs terminatingByConsMatrixIntersDP
    , "Matrix Interpretations on DP problem given a dimension; e.g. 'misconsdp 2' for dimension 2."
    )
  ,
    ( "rpair"
    , redPairs
    , "Reduction Pair (using some of 'cro', 'vc', 'sc', 'vclin', 'mis DIM ?BITS', 'misbv DIM BITS', 'miscons DIM'); e.g.: 'rpair[cro,vclin, misbv 3 15]'"
    )
  ]
 where
  lexchunk = lexeme . chunk

  -- Reads one mandatory numeral and one optional numeral and applies @f@ to
  -- both (the second one as a 'Maybe').
  plainMIs f = f <$> fmap fromIntegral numeral <*> optional (fmap fromIntegral numeral)

  -- Reads two mandatory numerals and applies @f@ to both.
  plainBVMIs f = f <$> fmap fromIntegral numeral <*> fmap fromIntegral numeral

  -- Argument parser of @rpair@: "[METHOD, ...]".
  redPairs = terminatingByReductionPair <$> between "[" "]" pRPList

  -- Non-empty, comma-separated list of reduction pair methods. The
  -- alternatives are tried in order and the first keyword that matches wins,
  -- so a keyword that extends another one is listed before its prefix
  -- ("vclin" before "vc"; "misbv" and "miscons" before "mis").
  pRPList =
    lexeme $
      sepBy1
        ( (lexchunk "vclin" $> valueCriterionPolRP)
            <|> (lexchunk "vc" $> valueCriterionRP)
            <|> (lexchunk "cro" $> rpoRP)
            <|> (lexchunk "sc" $> subtermCriterionRP)
            <|> (lexchunk "misbv" >> plainBVMIs matrixInterpretationsBVRP)
            <|> (lexchunk "miscons" >> plainMIs matrixInterpretationsConsRP)
            <|> (lexchunk "mis" >> plainMIs matrixInterpretationsRP)
        )
        (lexeme $ char ',')

-- | Parse a termination strategy, i.e. a non-empty, comma-separated list of
-- the keywords registered in 'processorsSN' (each followed by its arguments,
-- if any), into the corresponding processors.
--
-- The whole input has to be consumed. On malformed input this function calls
-- 'error' with the pretty-printed parse error.
parseSNStrategy
  :: (Ord v, Ord f, Pretty v, Pretty f, ToSExpr f, ToSExpr v)
  => String
  -> [SNProcessor (FId f) (VId v)]
parseSNStrategy input =
  either (error . errorBundlePretty) id parsed
 where
  parsed = runParser (whitespace *> strategy <* eof) "" (pack input)

  strategy = anyProcessor `sepBy1` lexeme (char ',')

  -- Every alternative is wrapped in 'try' so that a failure after the
  -- keyword has been read falls through to the next alternative.
  anyProcessor = foldr (\p others -> try p <|> others) empty keywordParsers

  -- Keywords are tried in descending order. A keyword that extends another
  -- one (e.g. "misdp" and "mis") sorts after it, so in descending order the
  -- longer keyword is tried first.
  keywordParsers =
    [ lexeme (chunk keyword) *> lexeme arguments
    | (keyword, arguments, _) <- sortOn (Down . keywordOf) processorsSN
    ]

  keywordOf (keyword, _, _) = keyword

-- | Help text for the termination strategy option: an introductory sentence
-- followed by every keyword of 'processorsSN' with its description.
snStrategyExplanation :: String
snStrategyExplanation = show (intro <+> encloseSep "{" "}" "" entries)
 where
  intro = "Specify a comma-separated list of methods (if not set, all are used):"

  entries =
    [ "'" <> pretty keyword <> "'" <+> "=" <+> pretty description
    | (keyword, _, description) <- table
    ]

  -- The table is polymorphic in the identifier types; only keywords and
  -- descriptions are read here, so the default identifiers are used to fix
  -- the type.
  table :: [(Text, Parsec Void Text (SNProcessor (FId DefaultFId) (VId DefaultVId)), Text)]
  table = processorsSN

-- | Number of termination methods registered in 'processorsSN'.
numberOfAllSNMethods :: Int
numberOfAllSNMethods = length registered
 where
  -- Counting the entries requires fixing the identifier types of the
  -- polymorphic table; the default identifiers are used for that.
  registered :: [(Text, Parsec Void Text (SNProcessor (FId DefaultFId) (VId DefaultVId)), Text)]
  registered = processorsSN
