{- |
Module      : SingleStepRewriting
Description : Rewriting on constrained terms
Copyright   : (c) Jonas Schöpf, 2023
License     : GPL-3
Maintainer  : jonas.schoepf@uibk.ac.at
Stability   : stable


This module provides functions to perform rewriting on constrained terms. This
covers single steps, reflexive transitive steps, parallel steps, and multisteps.
-}
module Rewriting.ConstrainedRewriting.SingleStepRewriting where

import Control.Monad ((<=<))
import Data.LCTRS.FIdentifier (
  FId,
 )
import Data.LCTRS.Guard (
  Guard,
  collapseGuardToTerm,
  conjGuards,
  fromTerm,
  mapGuard,
  varsGuard,
 )
import Data.LCTRS.LCTRS (
  CRSeq (..),
  LCTRS,
  StepAnn (SinStep),
  getRules,
  getTheoryFuns,
 )
import Data.LCTRS.Position (
  Pos,
  append,
  below,
  epsilon,
 )
import Data.LCTRS.Rule (
  CTerm,
  Rule (..),
  guard,
  lhs,
  lvar,
  rhs,
 )
import Data.LCTRS.Sort (Sort, Sorted (sort), sortAnnotation)
import Data.LCTRS.Term (
  Term (..),
  fun,
 )
import qualified Data.LCTRS.Term as T
import Data.Maybe (fromMaybe)
import Data.Monad (MonadFresh, StateM, freshInt)
import Data.SExpr (ToSExpr (..), existsSExprOfTerm)
import Data.SMT (boolSort)
import qualified Data.Set as S
import Prettyprinter (Pretty (..))
import Rewriting.ConstrainedRewriting.ConstrainedRewriting (seqStep, trivialVarEquations)
import Rewriting.Rewrite (
  Reduct (..),
  fullRewrite,
  rootRewrite,
 )
import qualified Rewriting.Rewrite as Rew
import Rewriting.SMT (
  satGuard,
  smtResultCheck,
  validSExprSF,
 )
import qualified Rewriting.Substitution as Sub
import qualified SimpleSMT as SMT
import Utils (dropNthElem)

----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
-- standard rewriting on constrained terms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------

-- | Returns the starting constrained term @(term, g)@ together with every
-- constrained term reachable from it in at most @steps@ single steps.
--
-- Each step tries both ordinary rule steps ('ruleRed') and calculation steps
-- ('calcRed'), restricted to positions below @belowEqP@ ('Nothing' is treated
-- as 'epsilon'). Results are listed layer by layer: the start term, then all
-- 1-step reducts, then all 2-step reducts, and so on. Duplicates are not
-- removed. A non-positive @steps@ yields just the starting term (previously a
-- negative count never terminated).
rewriteN
  :: (Ord v, Ord f, Pretty v, Pretty f, ToSExpr f, ToSExpr v, Sorted v)
  => Int
  -> Maybe Pos
  -> (Rule (FId f) v -> StateM (Rule (FId f) v))
  -> (Sort -> Int -> v)
  -> LCTRS (FId f) v
  -> Term (FId f) v
  -> Guard (FId f) v
  -> StateM [CTerm (FId f) v]
rewriteN steps belowEqP renameRule fresh lctrs term g = go steps [(term, g)]
 where
  red1 = ruleRed renameRule belowEqP (getRules lctrs)
  red2 (t, c) = calcRed lctrs fresh belowEqP c (epsilon, t)

  -- 'n' counts the remaining steps; guard on <= 0 so negative input terminates.
  go n reducts
    | n <= 0 = return reducts
    | otherwise = do
        reds1 <- concat <$> mapM red1 reducts
        reds2 <- concat <$> mapM red2 reducts
        (reducts ++) <$> go (pred n) (reds1 ++ reds2)

-- | Returns every rewrite sequence starting from @(term, g)@ with at most
-- @steps@ single steps, including the empty sequence.
--
-- Each step extends a sequence (via 'seqStep' with 'SinStep') by an ordinary
-- rule step ('ruleRed') or a calculation step ('calcRed'), restricted to
-- positions below @belowEqP@ ('Nothing' is treated as 'epsilon'). Sequences
-- are listed by length: first those of length 0, then length 1, and so on. A
-- non-positive @steps@ yields just the empty sequence (previously a negative
-- count never terminated).
rewriteNSeq
  :: (Ord v, Ord f, Pretty v, Pretty f, ToSExpr f, ToSExpr v, Sorted v)
  => Int
  -> Maybe Pos
  -> (Rule (FId f) v -> StateM (Rule (FId f) v))
  -> (Sort -> Int -> v)
  -> LCTRS (FId f) v
  -> Term (FId f) v
  -> Guard (FId f) v
  -> StateM [CRSeq (FId f) v]
rewriteNSeq steps belowEqP renameRule fresh lctrs term g = go steps [CRSeq ((term, g), [])]
 where
  red1 = ruleRed renameRule belowEqP (getRules lctrs)
  red2 (t, c) = calcRed lctrs fresh belowEqP c (epsilon, t)

  -- 'n' counts the remaining steps; guard on <= 0 so negative input terminates.
  go n reductSeqs
    | n <= 0 = return reductSeqs
    | otherwise = do
        reds1 <- concat <$> mapM (seqStep SinStep red1) reductSeqs
        reds2 <- concat <$> mapM (seqStep SinStep red2) reductSeqs
        recReds <- go (pred n) (reds1 ++ reds2)
        return $ reductSeqs ++ recReds

-- | Returns every constrained term reachable from @(t, gT)@ in one step using
-- one of the given (non-calculation) rules.
--
-- * @renameRule@ is applied to each rule before it is used (presumably to
--   rename it apart from the term — TODO confirm).
-- * Only redexes at positions 'below' @belowEqP@ are kept; 'Nothing' is
--   treated as 'epsilon'.
-- * A reduct @(result, newGuard)@ is kept only if every variable of
--   @'lvar' rule@, after substitution, is a value or a variable of
--   @newGuard@, and the solver proves that @gT@ implies the instantiated rule
--   guard. Rule variables not occurring in the left-hand side are
--   existentially quantified in that check and are added to @newGuard@.
-- * If @gT@ is not satisfiable, no reducts are produced.
ruleRed
  :: (Ord v, Ord f, Pretty v, Pretty f, ToSExpr f, ToSExpr v, Sorted v)
  => (Rule (FId f) v -> StateM (Rule (FId f) v))
  -> Maybe Pos
  -> [Rule (FId f) v]
  -> CTerm (FId f) v
  -> StateM [CTerm (FId f) v]
ruleRed renameRule belowEqP rules (t, gT)
  -- Without rules there is nothing to do, so do not consult the solver at all.
  | null rules = return []
  | otherwise = do
      -- The satisfiability of gT does not depend on the rule, so ask the
      -- solver once instead of once per rule. When gT is unsatisfiable the
      -- rules are not renamed at all.
      satisfiable <- smtResultCheck =<< satGuard gT
      if satisfiable then concat <$> mapM reductsOf rules else return []
 where
  -- All accepted reducts of t using a single (renamed) rule.
  reductsOf r = do
    rule@Rule{..} <- renameRule r
    let lvs = lvar rule
        -- Variables of the rule's rhs/guard that are not bound by matching lhs.
        evs =
          (S.fromList (T.vars rhs) <> S.fromList (varsGuard guard))
            S.\\ S.fromList (T.vars lhs)
    validReds lvs evs guard (fullRewrite [rule] t)

  validReds _ _ _ [] = return []
  validReds lvs evs g (red : reds) = do
    let s = Rew.subst red
        -- Only when the rule introduces extra variables is its instantiated
        -- guard added to the constraint; otherwise gT is kept unchanged.
        newGuard
          | null evs = gT
          | otherwise =
              conjGuards gT $ conjGuards (trivialVarEquations evs) $ mapGuard (Sub.apply s) g
        validSub = all (check newGuard . Sub.apply s . T.var) lvs
        validPos = Rew.pos red `below` fromMaybe epsilon belowEqP
        rest = validReds lvs evs g reds
    if validSub && validPos
      then do
        let substitutedGuard = Sub.apply s $ collapseGuardToTerm g
        -- Check that gT implies the instantiated rule guard (with the extra
        -- variables existentially quantified when there are any).
        validity <-
          smtResultCheck
            =<< validSExprSF
              (S.fromList $ varsGuard gT ++ T.vars substitutedGuard)
              ( SMT.implies
                  (toSExpr $ collapseGuardToTerm gT)
                  ( if null evs
                      then toSExpr substitutedGuard
                      else existsSExprOfTerm evs substitutedGuard
                  )
              )
        if validity
          then ((result red, newGuard) :) <$> rest
          else rest
      else rest

  -- A term is acceptable for a logical variable if it is a value or a
  -- variable occurring in the given guard.
  check g u =
    u `elem` map T.var (varsGuard g)
      || T.isValue u

-- | Returns every constrained term obtainable from the subterm @t@ (located at
-- position @pos@) by one calculation step at or beneath @pos@.
--
-- A calculation is performed at a node when its root symbol is one of the
-- LCTRS's theory functions ('getTheoryFuns'), every argument is either a
-- value or a variable of @gT@, and the node's position is 'below'
-- @belowEqP@ ('Nothing' is treated as 'epsilon'). The node is then replaced
-- by a fresh variable @v@ of the root symbol's sort, and @v = t@ is conjoined
-- to @gT@. Otherwise the search continues in each argument, and each argument
-- reduct is plugged back into the term. Variables yield no reducts.
calcRed
  :: (Ord v, Ord f, MonadFresh m)
  => LCTRS (FId f) v
  -> (Sort -> Int -> v)
  -> Maybe Pos
  -> Guard (FId f) v
  -> (Pos, Term (FId f) v)
  -> m [CTerm (FId f) v]
calcRed _ _ _ _ (_, Var _) = return []
calcRed lctrs fresh belowEqP gT (pos, t@(Fun typ f ts))
  | calculableHere = do
      let outSort = sort f
      freshV <- T.var . fresh outSort <$> freshInt
      let defining = fromTerm $ T.eq (sortAnnotation [outSort, outSort] boolSort) freshV t
      return [(freshV, conjGuards gT defining)]
  | otherwise = do
      perArg <- mapM (calcRed lctrs fresh belowEqP gT) (zip argPositions ts)
      return $ concat (zipWith rebuildAt [0 ..] perArg)
 where
  calculableHere =
    all isValueOrGuardVar ts
      && pos `below` fromMaybe epsilon belowEqP
      && f `S.member` getTheoryFuns lctrs

  -- Position of the i-th argument of this node.
  argPositions = map (pos `append`) [0 ..]

  -- Put every reduct of the i-th argument back into the surrounding term.
  rebuildAt i reds =
    let (pre, suf) = dropNthElem i ts
     in [(fun typ f (pre ++ arg' : suf), g') | (arg', g') <- reds]

  isValueOrGuardVar u = u `elem` map T.var (varsGuard gT) || T.isValue u

-- | Returns every constrained term reachable from @(term, g)@ by one step at
-- the root: first the reducts of ordinary rule steps ('ruleRedRoot'), then
-- those of calculation steps ('calcRedRoot').
rewriteRoot
  :: (Ord v, Ord f, Pretty v, Pretty f, ToSExpr f, ToSExpr v, Sorted v)
  => LCTRS (FId f) v
  -> (Rule (FId f) v -> StateM (Rule (FId f) v))
  -> (Sort -> Int -> v)
  -> Term (FId f) v
  -> Guard (FId f) v
  -> StateM [CTerm (FId f) v]
rewriteRoot lctrs renameRule fresh term g = do
  byRules <- viaRules start
  byCalc <- calcRedRoot lctrs fresh start
  return (byRules ++ byCalc)
 where
  start = (term, g)
  -- 'ruleRedRoot' also returns the underlying reduct; only the constrained
  -- term is needed here.
  viaRules = return . map fst <=< ruleRedRoot renameRule (getRules lctrs)

-- | Returns every constrained term reachable from @(t, gT)@ by applying one of
-- the given (non-calculation) rules at the root, each paired with the
-- 'Reduct' describing that step.
--
-- * @renameRule@ is applied to each rule before it is used (presumably to
--   rename it apart from the term — TODO confirm).
-- * A reduct is kept only if every variable of @'lvar' rule@, after
--   substitution, is a value or a variable of the new guard, and the solver
--   proves that @gT@ implies the instantiated rule guard. Rule variables not
--   occurring in the left-hand side are existentially quantified in that
--   check and are added to the new guard.
-- * If @gT@ is not satisfiable, no reducts are produced.
ruleRedRoot
  :: (Ord v, Ord f, Pretty v, Pretty f, ToSExpr f, ToSExpr v, Sorted v)
  => (Rule (FId f) v -> StateM (Rule (FId f) v))
  -> [Rule (FId f) v]
  -> CTerm (FId f) v
  -> StateM [(CTerm (FId f) v, Reduct (FId f) v)]
ruleRedRoot renameRule rules (t, gT)
  -- Without rules there is nothing to do, so do not consult the solver at all.
  | null rules = return []
  | otherwise = do
      -- The satisfiability of gT does not depend on the rule, so ask the
      -- solver once instead of once per rule. When gT is unsatisfiable the
      -- rules are not renamed at all.
      satisfiable <- smtResultCheck =<< satGuard gT
      if satisfiable then concat <$> mapM reductsOf rules else return []
 where
  -- All accepted root reducts of t using a single (renamed) rule.
  reductsOf r = do
    rule@Rule{..} <- renameRule r
    let lvs = lvar rule
        -- Variables of the rule's rhs/guard that are not bound by matching lhs.
        evs =
          (S.fromList (T.vars rhs) <> S.fromList (varsGuard guard))
            S.\\ S.fromList (T.vars lhs)
    validRed lvs evs guard (rootRewrite [rule] t)

  validRed _ _ _ [] = return []
  validRed lvs evs g (red : reds) = do
    let s = Rew.subst red
        -- Only when the rule introduces extra variables is its instantiated
        -- guard added to the constraint; otherwise gT is kept unchanged.
        newGuard
          | null evs = gT
          | otherwise =
              conjGuards gT $ conjGuards (trivialVarEquations evs) $ mapGuard (Sub.apply s) g
        validSub = all (check newGuard . Sub.apply s . T.var) lvs
        ruleGuard = Sub.apply s $ collapseGuardToTerm g
        rest = validRed lvs evs g reds
    if validSub
      then do
        -- Check that gT implies the instantiated rule guard (with the extra
        -- variables existentially quantified when there are any).
        validity <-
          smtResultCheck
            =<< validSExprSF
              (S.fromList $ varsGuard gT ++ T.vars ruleGuard)
              ( SMT.implies
                  (toSExpr $ collapseGuardToTerm gT)
                  ( if null evs
                      then toSExpr ruleGuard
                      else existsSExprOfTerm evs ruleGuard
                  )
              )
        if validity
          then (((result red, newGuard), red) :) <$> rest
          else rest
      else rest

  -- A term is acceptable for a logical variable if it is a value or a
  -- variable occurring in the given guard.
  check g u =
    u `elem` map T.var (varsGuard g)
      || T.isValue u

-- | Returns the result of a calculation step at the root of the constrained
-- term, if one is possible.
--
-- A root calculation applies when the root symbol is one of the LCTRS's
-- theory functions ('getTheoryFuns') and every argument is either a value or
-- a variable of the guard @gT@. The term is then replaced by a fresh variable
-- @x@ of the root symbol's sort, and @x = term@ is conjoined to @gT@.
-- Variables have no root calculation step.
calcRedRoot
  :: (Ord v, Ord f, MonadFresh m)
  => LCTRS (FId f) v
  -> (Sort -> Int -> v)
  -> CTerm (FId f) v
  -> m [CTerm (FId f) v]
calcRedRoot lctrs fresh (term, gT) = case term of
  Var _ -> return []
  Fun _ f args
    | f `S.member` getTheoryFuns lctrs && all isValueOrGuardVar args -> do
        let srt = sort f
        x <- T.var . fresh srt <$> freshInt
        let eqn = fromTerm $ T.eq (sortAnnotation [srt, srt] boolSort) x term
        return [(x, conjGuards gT eqn)]
    | otherwise -> return []
 where
  guardVars = map T.var (varsGuard gT)
  isValueOrGuardVar u = u `elem` guardVars || T.isValue u

-- | Checks whether at least one calculation step is possible in @term@ under
-- the guard @c@. Positions are restricted to those 'below' 'epsilon', which
-- is presumably every position.
--
-- Note: this actually performs the calculation via 'calcRed' and discards the
-- reducts, so it may consume fresh indices through 'freshInt'.
isCalculationPossible
  :: (Ord v, Ord f, MonadFresh m)
  => (Sort -> Int -> v)
  -> LCTRS (FId f) v
  -> (Term (FId f) v, Guard (FId f) v)
  -> m Bool
isCalculationPossible freshVar lc (term, c) = do
  reducts <- calcRed lc freshVar (Just epsilon) c (epsilon, term)
  return $ not $ null reducts
