{- |
Module      : ParallelStepRewriting
Description :
Copyright   : (c) Jonas Schöpf, 2023
License     : GPL-3
Maintainer  : jonas.schoepf@uibk.ac.at
Stability   : stable


This module provides functions to perform rewriting on constrained terms.  This
covers single steps, reflexive transitive steps, parallel steps and multisteps.
-}
module Rewriting.ConstrainedRewriting.ParallelStepRewriting where

import Data.Bifunctor (Bifunctor (..), bimap)
import Data.Foldable (foldrM)
import Data.LCTRS.FIdentifier (FId)
import Data.LCTRS.Guard (
  Guard,
  concatGuards,
  conjGuards,
 )
import Data.LCTRS.LCTRS (
  CRSeq (..),
  LCTRS,
  StepAnn (ParStep),
 )
import Data.LCTRS.Position (
  Pos,
  epsilon,
  position,
 )
import Data.LCTRS.Rule (
  CTerm,
  Rule,
 )
import Data.LCTRS.Sort (Sort, Sorted)
import Data.LCTRS.Term (
  Term (..),
  fun,
 )
import qualified Data.LCTRS.Term as T
import Data.Maybe (fromMaybe, mapMaybe)
import Data.Monad
import Data.SExpr
import qualified Data.Set as S
import Prettyprinter (Pretty (..))
import Rewriting.ConstrainedRewriting.ConstrainedRewriting (seqStep)
import Rewriting.ConstrainedRewriting.SingleStepRewriting (rewriteRoot)

----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
-- parallel rewriting on constrained terms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------

-- Returns all constrained terms obtained from the starting constrained term
-- (term, g) by iterating 'parallelRewriteStep' the given number of times.
--
-- In every round 'parallelRewriteStep' (with the given renaming function,
-- fresh-variable function and optional position belowEqP) is applied to each
-- constrained term produced by the previous round and the results are
-- concatenated.
--
-- A step count of zero or less returns just the starting constrained term.
-- Negative counts are handled like zero: counting down from a negative number
-- never reaches zero, so matching on the literal 0 alone would not terminate.
parallelRewriteStepN
  :: (Ord v, Ord f, Pretty v, Pretty f, ToSExpr f, ToSExpr v, Sorted v)
  => Int
  -> Maybe Pos
  -> (Rule (FId f) v -> StateM (Rule (FId f) v))
  -> (Sort -> Int -> v)
  -> LCTRS (FId f) v
  -> Term (FId f) v
  -> Guard (FId f) v
  -> StateM [CTerm (FId f) v]
parallelRewriteStepN steps belowEqP renameRule fresh lctrs term g = go steps [(term, g)]
 where
  -- remaining: number of rounds that still have to be performed
  go remaining reducts
    | remaining <= 0 = return reducts
    | otherwise = do
        reds <-
          concat
            <$> mapM (parallelRewriteStep lctrs renameRule fresh belowEqP) reducts
        go (pred remaining) reds

-- Returns all rewrite sequences obtained from the starting constrained term
-- by iterating a parallel step the given number of times.
--
-- The computation starts with the single sequence CRSeq (cterm, []).  In every
-- round each sequence is passed to 'seqStep' together with the annotation
-- ParStep Nothing and the step function 'parallelRewriteStep' (using the given
-- renaming function, fresh-variable function and optional position belowEqP);
-- the resulting lists of sequences are concatenated.
--
-- A step count of zero or less returns just the initial sequence.  Negative
-- counts are handled like zero: counting down from a negative number never
-- reaches zero, so matching on the literal 0 alone would not terminate.
parallelRewriteStepNSeq
  :: (Ord v, Ord f, Pretty v, Pretty f, ToSExpr f, ToSExpr v, Sorted v)
  => Int
  -> Maybe Pos
  -> (Rule (FId f) v -> StateM (Rule (FId f) v))
  -> (Sort -> Int -> v)
  -> LCTRS (FId f) v
  -> CTerm (FId f) v
  -> StateM [CRSeq (FId f) v]
parallelRewriteStepNSeq steps belowEqP renameRule fresh lctrs cterm = go steps [CRSeq (cterm, [])]
 where
  -- remaining: number of rounds that still have to be performed
  go remaining reductSeqs
    | remaining <= 0 = return reductSeqs
    | otherwise = do
        reds <-
          concat
            <$> mapM
              (seqStep (ParStep Nothing) (parallelRewriteStep lctrs renameRule fresh belowEqP))
              reductSeqs
        go (pred remaining) reds

-- Returns all constrained terms reachable from (t, g) by one parallel step
-- that is performed inside the subterm at the given position.
--
-- If belowEqP is Nothing the position defaults to epsilon.  When t has no
-- subterm at that position the result is empty.  Otherwise the reducts of the
-- subterm are computed by parallelRewriteStep' and each of them is put back
-- into t at the position; reducts whose replacement fails are dropped.
parallelRewriteStep
  :: (Ord v, Ord f, Pretty v, Pretty f, ToSExpr f, ToSExpr v, Sorted v)
  => LCTRS (FId f) v
  -> (Rule (FId f) v -> StateM (Rule (FId f) v))
  -> (Sort -> Int -> v)
  -> Maybe Pos
  -> CTerm (FId f) v
  -> StateM
      [CTerm (FId f) v]
parallelRewriteStep lc renameRule fresh belowEqP (t, g) =
  case T.subtermAt t pos of
    Nothing -> return []
    Just sbt -> mapMaybe plugBack <$> parallelRewriteStep' lc renameRule fresh (sbt, g)
 where
  pos = fromMaybe epsilon belowEqP

  -- put a reduct of the subterm back into the surrounding term, keeping its guard
  plugBack (ti, gi) = fmap (\t' -> (t', gi)) (T.replaceAt t pos ti)

-- Returns all constrained terms reachable from (t, g) by one parallel step
-- inside the subterm at the given position, each paired with the set of
-- positions reported for that step.
--
-- If belowEqP is Nothing the position defaults to epsilon.  When t has no
-- subterm at that position the result is empty.  Otherwise the reducts of the
-- subterm come from parallelRewriteStepRetPoss'; each reduct is put back into
-- t at the position and every reported position is prefixed with that
-- position.  Reducts whose replacement fails are dropped.
parallelRewriteStepRetPoss
  :: (Ord v, Ord f, Pretty v, ToSExpr f, ToSExpr v, Pretty f, Sorted v)
  => LCTRS (FId f) v
  -> (Rule (FId f) v -> StateM (Rule (FId f) v))
  -> (Sort -> Int -> v)
  -> Maybe Pos
  -> CTerm (FId f) v
  -> StateM
      [(CTerm (FId f) v, S.Set Pos)]
parallelRewriteStepRetPoss lc renameRule fresh belowEqP (t, g) =
  maybe (return []) liftReducts (T.subtermAt t pos)
 where
  pos = fromMaybe epsilon belowEqP

  -- rewrite the subterm and lift the reducts into the surrounding context
  liftReducts sbt = do
    reds <- parallelRewriteStepRetPoss' lc renameRule fresh (sbt, g)
    return
      [ ((t', gi), S.map (pos <>) poss)
      | ((ti, gi), poss) <- reds
      , Just t' <- [T.replaceAt t pos ti]
      ]

-- Returns all rewrite sequences that consist of one parallel step from the
-- starting constrained term, rewriting at every position of the given set.
--
-- For each position in possSet the subterm of t at that position is rewritten
-- with 'rewriteRoot' under the guard g.  Every way of choosing one such reduct
-- per position yields one target: the chosen reducts are replaced into the
-- term one after another and their guards are combined with 'conjGuards',
-- starting from the starting constrained term itself.  Each target is returned
-- as CRSeq (target, [(ParStep (Just possSet), ct)]).
--
-- 'sError' is invoked when t has no subterm at one of the positions or when a
-- replacement fails.
-- NOTE(review): both error messages name ConstrainedRewriting.hs although this
-- module is ParallelStepRewriting; confirm whether the messages should be
-- updated.
parallelRewriteStepAtPoss
  :: (Ord v, Ord f, Pretty v, Pretty f, ToSExpr f, ToSExpr v, Sorted v)
  => LCTRS (FId f) v
  -> (Rule (FId f) v -> StateM (Rule (FId f) v))
  -> (Sort -> Int -> v)
  -> S.Set Pos
  -> CTerm (FId f) v
  -> StateM [CRSeq (FId f) v]
parallelRewriteStepAtPoss lc renameRule fresh possSet ct@(t, g) = do
  perPosition <- mapM computeTargetsAtPos (S.toList possSet)
  -- 'sequence' on a list of lists is the cartesian product: every resulting
  -- list picks exactly one reduct for each position
  reducts <- mapM (computeParallelStepWithTargets ct) (sequence perPosition)
  return $ map (\reduct -> CRSeq (reduct, [(ParStep (Just possSet), ct)])) reducts
 where
  -- all root reducts of the subterm at p, each tagged with p
  computeTargetsAtPos p = do
    subTerm <-
      maybe
        (sError "ConstrainedRewriting.hs: No subterm at position to perform parallel rewrite step.")
        return
        (t `T.subtermAt` p)
    targets <- rewriteRoot lc renameRule fresh subTerm g
    return $ map ((,) p) targets

  -- replace every chosen reduct into the constrained term and conjoin the guards
  computeParallelStepWithTargets :: CTerm f v -> [(Pos, CTerm f v)] -> StateM (CTerm f v)
  computeParallelStepWithTargets cterm possTargets = foldrM plugTarget cterm possTargets
   where
    plugTarget (p, (ti, gi)) (term, co) =
      case T.replaceAt term p ti of
        Nothing -> sError "ConstrainedRewriting.hs: No step possible at position in parallel rewrite step."
        Just term' -> return (term', conjGuards gi co)

-- Returns all constrained terms reachable from (t, g) by one parallel step.
--
-- The result list consists of, in this order:
--   * the unchanged constrained term (t, g),
--   * the reducts of 'rewriteRoot' applied to t under g,
--   * for a function application, every combination of recursively computed
--     reducts of the arguments (all under the guard g), rebuilt with 'fun'
--     and with the guards of the chosen reducts merged by 'concatGuards'.
-- Variables contribute no argument reducts.
--
-- NOTE(review): for a function application without arguments the only
-- combination is the empty one, so a reduct with guard concatGuards [] is
-- produced; confirm that this is the intended guard.
parallelRewriteStep'
  :: (Ord v, Ord f, Pretty v, Pretty f, ToSExpr f, ToSExpr v, Sorted v)
  => LCTRS (FId f) v
  -> (Rule (FId f) v -> StateM (Rule (FId f) v))
  -> (Sort -> Int -> v)
  -> CTerm (FId f) v
  -> StateM
      [CTerm (FId f) v]
parallelRewriteStep' lc renameRule fresh (t, g) = do
  rootReducts <- rewriteRoot lc renameRule fresh t g
  argReducts <- case t of
    Var _ -> pure []
    Fun typ f args -> do
      perArg <- traverse (\ti -> parallelRewriteStep' lc renameRule fresh (ti, g)) args
      pure
        [ (fun typ f ts, concatGuards gs)
        | -- 'sequence' enumerates every way to pick one reduct per argument
        choice <- sequence perArg
        , let (ts, gs) = unzip choice
        ]
  pure $ (t, g) : rootReducts ++ argReducts

-- Returns all constrained terms reachable from (t, g) by one parallel step,
-- each paired with the set of positions recorded for that step.
--
-- The result list consists of, in this order:
--   * the unchanged constrained term (t, g) with the empty position set,
--   * the reducts of 'rewriteRoot' applied to t under g, each with the
--     singleton set containing epsilon,
--   * for a function application, every combination of recursively computed
--     reducts of the arguments (all under the guard g), rebuilt with 'fun',
--     with the guards merged by 'concatGuards' and the position sets united.
--     Positions coming from the argument with index i (counted from 0) are
--     prefixed with position [i].
-- Variables contribute no argument reducts.
--
-- NOTE(review): for a function application without arguments the only
-- combination is the empty one, so a reduct with guard concatGuards [] is
-- produced; confirm that this is the intended guard.
parallelRewriteStepRetPoss'
  :: (Ord v, Ord f, Pretty v, ToSExpr f, ToSExpr v, Pretty f, Sorted v)
  => LCTRS (FId f) v
  -> (Rule (FId f) v -> StateM (Rule (FId f) v))
  -> (Sort -> Int -> v)
  -> CTerm (FId f) v
  -> StateM
      [(CTerm (FId f) v, S.Set Pos)]
parallelRewriteStepRetPoss' lc renameRule fresh (t, g) = do
  rootReducts <- rewriteRoot lc renameRule fresh t g
  let top = [(ct, S.singleton epsilon) | ct <- rootReducts]
  belows <- case t of
    Var _ -> pure []
    Fun typ f args -> do
      perArg <- traverse reductsOfArg (zip [0 ..] args)
      pure
        [ ((fun typ f terms, concatGuards constraints), mconcat posses)
        | -- 'sequence' enumerates every way to pick one reduct per argument
        combination <- sequence perArg
        , let (cterms, posses) = unzip combination
        , let (terms, constraints) = unzip cterms
        ]
  pure $ ((t, g), S.empty) : top ++ belows
 where
  -- reducts of the i-th argument with their positions moved below index i
  reductsOfArg (i, ti) =
    map (second (S.map (position [i] <>)))
      <$> parallelRewriteStepRetPoss' lc renameRule fresh (ti, g)
