{- |
Module      : Rewriting.ConstrainedRewriting.ConstrainedRewriting
Description : Rewriting on constrained terms
Copyright   : (c) Jonas Schöpf, 2023
License     : GPL-3
Maintainer  : jonas.schoepf@uibk.ac.at
Stability   : stable

This module provides functions to perform rewriting on constrained terms. This
covers single steps, reflexive-transitive steps, parallel steps and multisteps.
-}
module Rewriting.ConstrainedRewriting.ConstrainedRewriting where

import Control.Monad (zipWithM)
import Data.LCTRS.FIdentifier (
  FId,
 )
import Data.LCTRS.Guard (
  Guard,
  collapseGuardToTerm,
  createGuard,
  varsGuard,
 )
import Data.LCTRS.LCTRS
import Data.LCTRS.Rule (
  CTerm,
  lhs,
  lvar,
 )
import qualified Data.LCTRS.Rule as R
import Data.LCTRS.Sort
import Data.LCTRS.Term (
  Term (..),
  eq,
  isConstrainedEquation,
  isValue,
 )
import qualified Data.LCTRS.Term as T
import Data.Maybe (fromJust, isJust)
import Data.Monad
import Data.SExpr
import Data.SMT (boolSort)
import qualified Data.Set as S
import Prettyprinter (Pretty (..))
import Rewriting.SMT (
  satSExprSF,
  smtResultCheck,
  validGuard,
 )
import Rewriting.Substitution (Subst, apply, match)
import qualified SimpleSMT as SMT

----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
-- auxiliary rewriting functions on constrained terms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------

{- | 'seqStep' @stepAnn reduce seq@ extends the constrained rewrite sequence @seq@ by one step.

The current end term of @seq@ is reduced with @reduce@. Every reduct becomes the end term of a
separate new sequence, and the old end term is prepended, paired with @stepAnn@, to the list of
previous steps.
-}
seqStep :: StepAnn -> (CTerm f v -> StateM [CTerm f v]) -> CRSeq f v -> StateM [CRSeq f v]
seqStep stepAnn reduce (CRSeq (endTerm, ts)) =
  map extend <$> reduce endTerm
 where
  -- every new sequence shares the same extended step list
  previousSteps = (stepAnn, endTerm) : ts
  extend reduct = CRSeq (reduct, previousSteps)

{- | 'trivialConstrainedEq' @(equa, c)@ checks whether the constrained equation @equa [c]@ is
trivial, i.e. whether the two sides of @equa@ are equivalent under the constraint @c@ according
to 'equivalentUnderConstraint'.

Only variables occurring in @c@ are treated as theory variables, so only those may be related to
other theory variables or to values.

A constrained equation whose equation symbol does not have exactly two arguments is reported as
not trivial. For a term that is not a constrained equation at all, the function raises an error
via 'sError'.
-}
trivialConstrainedEq
  :: (Ord f, Ord v, Pretty v, Pretty f, ToSExpr f, ToSExpr v, Sorted v)
  => CTerm (FId f) v
  -> StateM Bool
trivialConstrainedEq (equa@(T.Fun _ _ args), c)
  | isConstrainedEquation equa =
      case args of
        [s, t] -> equivalentUnderConstraint s t c isGuardVar
        _ -> return False
 where
  -- the variables of the constraint are computed once, not once per lookup
  guardVars = varsGuard c
  isGuardVar x = x `elem` guardVars
-- reached for variables and for function applications that are not constrained equations
-- (a failing guard above falls through to this equation)
trivialConstrainedEq _ =
  sError
    "ConstrainedRewriting.hs: Cannot check whether a term which is not a constrained equation is a trivial constrained equation."

-- taken from Ctrl, translated to Haskell and simplified/modified

{- | 'equivalentUnderConstraint' @s t phis theoryVar@ checks whether the terms @s@ and @t@ are
equivalent under the constraint @phis@.

Both terms are traversed in parallel. Syntactically equal subterms are skipped. Wherever the
terms differ, the differing subterms must be one of the following:

  * two variables, both satisfying @theoryVar@
  * a variable satisfying @theoryVar@ and a value ('isValue')

Each such difference contributes an equality between the two subterms.

The result is:

  * @False@ if some difference is not of that form (for example, different function symbols)
  * @True@ if there are no differences
  * otherwise, the outcome of checking with 'validGuard' / 'smtResultCheck' that @phis@ implies
    the guard built from all collected equalities
-}
equivalentUnderConstraint
  :: (Ord v, Ord f, Pretty v, Pretty f, ToSExpr f, ToSExpr v, Sorted v)
  => Term (FId f) v
  -> Term (FId f) v
  -> Guard (FId f) v
  -> (v -> Bool)
  -> StateM Bool
equivalentUnderConstraint s t phis theoryVar =
  maybe (return False) checkEqualities (differences s t)
 where
  -- no differing positions: nothing has to be proven
  checkEqualities [] = return True
  checkEqualities pairs =
    let conclusion = collapseGuardToTerm (createGuard (map mkEquality pairs))
        implication = T.imp (collapseGuardToTerm phis) conclusion
    in  smtResultCheck =<< validGuard (createGuard [implication])

  -- pairs of differing subterms, or 'Nothing' if some difference cannot be expressed
  -- as an equality between theory variables / values
  differences u w
    | u == w = Just []
    | otherwise =
        case (u, w) of
          (Var x, Var y) -> related [x, y] True
          (Var x, Fun {}) -> related [x] (isValue w)
          (Fun {}, Var y) -> related [y] (isValue u)
          (Fun _ f us, Fun _ g ws)
            | f == g -> concat <$> zipWithM differences us ws
            | otherwise -> Nothing
   where
    related xs ok
      | all theoryVar xs && ok = Just [(u, w)]
      | otherwise = Nothing

  -- boolean equation between the two subterms, annotated with their sorts
  mkEquality (l, r) = eq (sortAnnotation [sort l, sort r] boolSort) l r

----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
-- auxiliary functions on constrained terms for normal form checking and CCP splitting
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------

{- | 'isReducibleInstance' @c (g, s, valueVars)@ asks the SMT solver whether the constraint @c@
together with the guard @g@ instantiated by the substitution @s@ is satisfiable.

If @valueVars@ is non-empty, those variables are existentially quantified in the instantiated
guard (via 'existsSExprOfTerm'). The free variables handed to the solver are the variables of
@c@ and of the instantiated guard.

NOTE(review): presumably this decides whether some instance of a term constrained by @c@ can be
reduced by a rule with guard @g@. Confirm against the callers.

Raises an error if the solver cannot decide the formula.
-}
isReducibleInstance
  :: (Eq v, Eq f, Ord v, Pretty f, Pretty v, ToSExpr v, ToSExpr f, Sorted v)
  => Guard (FId f) v
  -> (Guard (FId f) v, Subst (FId f) v, S.Set v)
  -> StateM Bool
isReducibleInstance c (g, s, valueVars) =
  maybe unknown return =<< satSExprSF freeVars query
 where
  varphi = collapseGuardToTerm c
  psi = s `apply` collapseGuardToTerm g
  freeVars = S.fromList (T.vars varphi ++ T.vars psi)
  -- without value variables there is nothing to quantify
  instantiated
    | null valueVars = toSExpr psi
    | otherwise = existsSExprOfTerm valueVars psi
  query = toSExpr varphi `SMT.and` instantiated
  unknown = error "ConstrainedRewriting.hs: SMT solver cannot solve satisfiability of formula."

{- | 'matchesAnyPos' @subterms c rule@ collects the matches of the left-hand side of @rule@
against the non-variable terms in @subterms@.

For every admissible match it returns a triple consisting of:

  * the guard of @rule@
  * the matching substitution
  * the set of variables occurring in the images of @lvar rule@ under that substitution

A match is admissible only if every variable of @lvar rule@ that also occurs in the left-hand
side is mapped to one of the following:

  * a variable of the constraint @c@
  * a value ('T.isValue')

NOTE(review): @lvar rule@ presumably denotes the rule's logical variables. Verify against
"Data.LCTRS.Rule".
-}
matchesAnyPos
  :: (Ord v, Pretty v, Pretty f, Ord f, ToSExpr f, ToSExpr v)
  => [T.Term (FId f) v]
  -> Guard (FId f) v
  -> R.Rule (FId f) v
  -> StateM [(Guard (FId f) v, Subst (FId f) v, S.Set v)]
matchesAnyPos subterms c rule =
  return
    [ (R.guard rule, subst, imageVars subst)
    | subterm <- subterms
    , not (T.isVar subterm)
    , -- a failed match simply skips this subterm
      Just subst <- [lhs rule `match` subterm]
    , all (admissible . apply subst . T.var) lhsValueVars
    ]
 where
  valueVars = lvar rule
  -- the variables of 'lvar rule' that are bound by matching the left-hand side
  lhsValueVars = S.fromList (T.vars $ lhs rule) `S.intersection` valueVars
  -- all variables occurring in the instances of 'lvar rule'
  imageVars subst = S.unions $ S.map (S.fromList . T.vars . apply subst . T.var) valueVars
  -- loop-invariant: computed once instead of for every checked variable
  constraintVars = map T.var (varsGuard c)
  admissible t = t `elem` constraintVars || T.isValue t

{- | 'trivialVarEqsRule' @rule@ builds a constraint containing a trivial equation @x = x@ for
every extra variable @x@ of @rule@ (as returned by 'R.extraVars'). The equations themselves are
produced by 'trivialVarEquations'.
-}
trivialVarEqsRule :: (Ord v, Sorted v) => R.Rule (FId f) v -> Guard (FId f) v
trivialVarEqsRule = trivialVarEquations . R.extraVars

{- | 'trivialVarEquations' @vars@ builds a constraint made up only of trivial equations: one
equation @x = x@ per variable @x@ in @vars@, in ascending order of the set. Each equation is
annotated with @sortAnnotation [] (sort x)@.

NOTE(review): 'equivalentUnderConstraint' annotates its equations with
@sortAnnotation [sort l, sort r] boolSort@ instead. Confirm which annotation 'T.eq' expects.
-}
trivialVarEquations :: (Ord v, Sorted v) => S.Set v -> Guard (FId f) v
trivialVarEquations = createGuard . map trivialEq . S.toList
 where
  trivialEq v =
    let x = T.var v
    in  T.eq (sortAnnotation [] (sort v)) x x
