{- |
Module      : MultiStepRewriting
Description : Multistep rewriting on constrained terms
Copyright   : (c) Jonas Schöpf, 2023
License     : GPL-3
Maintainer  : jonas.schoepf@uibk.ac.at
Stability   : stable


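This module provides functions to perform multistep rewriting on constrained
terms. It covers single multisteps (optionally restricted to a subterm
position) and sequences of a fixed number of multisteps. Results are returned
either as the reachable constrained terms or as rewrite sequences.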
-}
module Rewriting.ConstrainedRewriting.MultiStepRewriting where

import Data.LCTRS.FIdentifier (FId)
import Data.LCTRS.Guard (
  Guard,
  concatGuards,
 )
import Data.LCTRS.LCTRS (
  CRSeq (..),
  LCTRS,
  StepAnn (MulStep),
  getRules,
 )
import Data.LCTRS.Position (Pos, epsilon)
import Data.LCTRS.Rule (
  CTerm,
  Rule (..),
 )
import Data.LCTRS.Sort (Sort, Sorted)
import Data.LCTRS.Term (
  Term (..),
 )
import qualified Data.LCTRS.Term as T
import qualified Data.Map.Strict as M
import Data.Maybe (fromMaybe, mapMaybe)
import Data.Monad
import Data.SExpr (ToSExpr)
import Prettyprinter (Pretty (..))
import Rewriting.ConstrainedRewriting.ConstrainedRewriting (seqStep)
import Rewriting.ConstrainedRewriting.SingleStepRewriting (calcRedRoot, ruleRedRoot)
import Rewriting.Rewrite (
  Reduct (rule),
 )
import qualified Rewriting.Rewrite as Rew
import Rewriting.Substitution (toMap)
import qualified Rewriting.Substitution as Sub

----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
-- multistep rewriting on constrained terms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------

-- | Return every constrained term reachable from @(term, g)@ in exactly
-- @steps@ multisteps. Each multistep is performed at or below @belowEqP@;
-- 'Nothing' means the root ('epsilon').
--
-- A non-positive @steps@ returns the start term unchanged. The base case is
-- guarded by @n <= 0@ rather than matching on @0@, so that a negative count
-- cannot recurse forever.
multistepN
  :: (Ord v, Ord f, Pretty v, Pretty f, ToSExpr f, ToSExpr v, Sorted v)
  => Int
  -- ^ number of multisteps to perform
  -> Maybe Pos
  -- ^ position at or below which steps happen ('Nothing' = root)
  -> (Rule (FId f) v -> StateM (Rule (FId f) v))
  -- ^ rule-renaming function, passed through to 'multistepPos'
  -> (Sort -> Int -> v)
  -- ^ @fresh@ function, passed through to 'multistepPos'
  -> LCTRS (FId f) v
  -- ^ the LCTRS whose rules and calculations are used
  -> Term (FId f) v
  -- ^ starting term
  -> Guard (FId f) v
  -- ^ constraint of the starting term
  -> StateM [CTerm (FId f) v]
multistepN steps belowEqP renameRule fresh lctrs term g = go steps [(term, g)]
 where
  -- Expand the whole frontier by one multistep per iteration.
  go n reducts
    | n <= 0 = return reducts
    | otherwise = do
        reds <-
          concat
            <$> mapM (multistepPos lctrs renameRule fresh belowEqP) reducts
        go (pred n) reds

-- | Return every rewrite sequence obtained by extending the sequence that
-- consists only of @cterm@ with exactly @steps@ multisteps. Each multistep is
-- performed at or below @belowEqP@ ('Nothing' = root). Steps are appended via
-- 'seqStep' with the 'MulStep' annotation.
--
-- A non-positive @steps@ returns just the initial sequence. The base case is
-- guarded by @n <= 0@ rather than matching on @0@, so that a negative count
-- cannot recurse forever.
multistepNSeq
  :: (Ord v, Ord f, Pretty v, Pretty f, ToSExpr f, ToSExpr v, Sorted v)
  => Int
  -- ^ number of multisteps to perform
  -> Maybe Pos
  -- ^ position at or below which steps happen ('Nothing' = root)
  -> (Rule (FId f) v -> StateM (Rule (FId f) v))
  -- ^ rule-renaming function, passed through to 'multistepPos'
  -> (Sort -> Int -> v)
  -- ^ @fresh@ function, passed through to 'multistepPos'
  -> LCTRS (FId f) v
  -- ^ the LCTRS whose rules and calculations are used
  -> CTerm (FId f) v
  -- ^ starting constrained term
  -> StateM [CRSeq (FId f) v]
multistepNSeq steps belowEqP renameRule fresh lctrs cterm = go steps [CRSeq (cterm, [])]
 where
  -- Extend every sequence in the frontier by one multistep per iteration.
  go n reductSeqs
    | n <= 0 = return reductSeqs
    | otherwise = do
        reds <-
          concat
            <$> mapM (seqStep MulStep (multistepPos lctrs renameRule fresh belowEqP)) reductSeqs
        go (pred n) reds

-- | Return every constrained term obtained by performing one multistep on the
-- subterm of @t@ at position @belowEqP@ ('Nothing' = root, i.e. 'epsilon').
-- Each result is plugged back into @t@ at that position.
--
-- Returns @[]@ when @t@ has no subterm at that position. Results for which
-- 'T.replaceAt' yields 'Nothing' are dropped.
multistepPos
  :: (Ord v, Ord f, Pretty v, Pretty f, ToSExpr f, ToSExpr v, Sorted v)
  => LCTRS (FId f) v
  -- ^ the LCTRS whose rules and calculations are used
  -> (Rule (FId f) v -> StateM (Rule (FId f) v))
  -- ^ rule-renaming function, passed through to 'multistep'
  -> (Sort -> Int -> v)
  -- ^ @fresh@ function, passed through to 'multistep'
  -> Maybe Pos
  -- ^ position of the subterm to rewrite ('Nothing' = root)
  -> CTerm (FId f) v
  -- ^ starting constrained term
  -> StateM
      [CTerm (FId f) v]
multistepPos lc renameRule fresh belowEqP (t, g) =
  case T.subtermAt t pos of
    Nothing -> return []
    Just sbt -> do
      reds <- multistep lc renameRule fresh (sbt, g)
      -- Reinsert each rewritten subterm; keep its guard as is.
      return $ mapMaybe (\(ti, gi) -> (,gi) <$> T.replaceAt t pos ti) reds
 where
  pos = fromMaybe epsilon belowEqP

-- | Return the constrained terms reachable from the given constrained term in
-- one multistep. The result is built from two sources:
--
--   * root steps: calculation steps at the root ('calcRedRoot') and rule
--     steps at the root ('ruleRedRoot'). Each rule step is also combined with
--     a multistep on every term of its matching substitution.
--   * argument steps: a multistep on every argument independently. These are
--     combined via the Cartesian product of the per-argument results.
--
-- A variable only reaches itself. For a function application the argument
-- combinations include the one where every argument is unchanged, so the
-- input term itself is among the results. Its guard is rebuilt with
-- 'concatGuards' and may therefore differ from the input guard.
--
-- Arguments:
--
--   * the LCTRS whose rules ('getRules') are used;
--   * the rule-renaming function passed to 'ruleRedRoot';
--   * the @fresh@ function passed to 'calcRedRoot';
--   * the starting constrained term.
--
-- Effects run in 'StateM' in this order: root rule steps, root calculation
-- steps, substitution multisteps, then argument multisteps.
-- NOTE(review): this order presumably fixes fresh-name numbering; preserve it
-- when editing.
multistep
  :: (Ord v, Ord f, Pretty v, Pretty f, ToSExpr f, ToSExpr v, Sorted v)
  => LCTRS (FId f) v
  -> (Rule (FId f) v -> StateM (Rule (FId f) v))
  -> (Sort -> Int -> v)
  -> CTerm (FId f) v
  -> StateM
      [CTerm (FId f) v]
multistep _ _ _ ct@(Var _, _) = return [ct]
multistep lctrs renameRule fresh tc@(Fun typ f tis, c) =
  (++) <$> reducts <*> wrappedReducts
 where
  -- Steps at the root: calculation steps, plain rule steps, and rule steps
  -- followed by a multistep inside the matching substitution.
  reducts = do
    ruleReducts <- ruleRedRoot renameRule (getRules lctrs) tc
    calcReducts <- calcRedRoot lctrs fresh tc
    -- Each root rule step is a (constrained term, Reduct) pair. c' is the
    -- guard of that constrained term; it is reused for the substitution steps.
    ruleReducts' <-
      concat
        <$> mapM
          ( \((_, c'), reduct) ->
              substMultistep c' reduct
          )
          ruleReducts
    -- NOTE(review): substMultistep also yields the substitution that keeps
    -- every term unchanged. The terms of `map fst ruleReducts` therefore
    -- likely reappear in ruleReducts' (possibly with other guards), which
    -- produces duplicate results.
    return $
      calcReducts -- calc reductions where no step in the subst is possible
        ++ map fst ruleReducts -- rule reductions without subst steps
        ++ ruleReducts' -- rule reductions with subst steps
  -- Steps below the root: multistep every argument under the original guard
  -- c, then rebuild the application from every combination of argument
  -- results. The new guard concatenates the guards of the chosen results.
  -- NOTE(review): for a constant (no arguments) this yields
  -- `concatGuards []`, which may not retain c. Confirm concatGuards semantics.
  wrappedReducts =
    map
      ( \ctis ->
          let (tis, cs) = unzip ctis
          in  (Fun typ f tis, concatGuards cs)
      )
      . cartProd
      <$> argsReducts
  argsReducts = mapM (multistep lctrs renameRule fresh . (,c)) tis

  -- Multistep every term of the reduct's substitution independently, under
  -- guard c. Here c shadows the outer guard; callers pass the root step's c'.
  -- Then, for every combination of those results, build a substitution tau
  -- and instantiate the applied rule's right-hand side with it. Each result's
  -- guard concatenates the guards of the chosen substitution-term results.
  -- NOTE(review): for an empty substitution this yields `concatGuards []`,
  -- which may drop c. Confirm.
  substMultistep c reduct = do
    let subst = Rew.subst reduct
    -- One list per substitution entry (v, t). Each element pairs v with one
    -- multistep result of t.
    varTerms <-
      mapM
        ( \(v, t) -> do
            reds <-
              multistep
                lctrs
                renameRule
                fresh
                (t, c)
            -- disabled alternative: (t, mapGuard (Sub.apply subst) c)
            return $ map (\(t, c) -> ((v, t), c)) reds
        )
        $ M.toList
        $ toMap subst
    let combiArgsReducts = cartProd varTerms
    let taus =
          map
            ( \poss ->
                let (vs, cs) = unzip poss
                in  (Sub.fromMap $ M.fromList vs, concatGuards cs)
            )
            combiArgsReducts
    return [(Sub.apply tau (rhs $ rule reduct), constr) | (tau, constr) <- taus]

  -- Cartesian product of a list of lists ('sequence' in the list monad).
  cartProd = sequence
