{- |
Module      : Data.LCTRS.Rule
Description : Constrained rewrite rules for LCTRSs
Copyright   : (c) Jonas Schöpf, 2023
License     : GPL-3
Maintainer  : jonas.schoepf@uibk.ac.at
Stability   : stable


This module provides the implementation for constrained rules used in LCTRSs.
-}
module Data.LCTRS.Rule (
  Rule (..),
  createRule,
  -- rule,
  isVariantOf,
  isLeftLinear,
  isLinear,
  isGround,
  lvar,
  extraVars,
  funs,
  vars,
  funsRule,
  varsRule,
  varsLhs,
  varsRhs,
  rename,
  renameFresh,
  CTerm,
  definedSym,
  dummyConstraints,
  prettyRules,
  prettyRule,
  prettyCTerm,
  prettyRulesFId,
  prettyRuleFId,
  prettyCTermFId,
) where

----------------------------------------------------------------------------------------------------
-- imports
----------------------------------------------------------------------------------------------------

import Control.Monad.State.Strict (StateT, evalStateT)
import Data.LCTRS.FIdentifier (FId)
import Data.LCTRS.Guard
import Data.LCTRS.Sort (Sorted (sort), sortAnnotation)
import Data.LCTRS.Term (Term)
import qualified Data.LCTRS.Term as T
import Data.LCTRS.VIdentifier (VId)
import qualified Data.Map.Strict as M
import Data.Monad (MonadFresh)
import qualified Data.MultiSet as MS
import Data.SMT (boolSort)
import qualified Data.Set as S
import Prettyprinter
import qualified Rewriting.Substitution as Sub

----------------------------------------------------------------------------------------------------
-- types
----------------------------------------------------------------------------------------------------

-- | A constrained rewrite rule @lhs -> rhs [guard]@.
--
-- The derived 'Eq' and 'Ord' instances compare the fields in declaration
-- order (left-hand side, right-hand side, guard), and 'createRule' builds
-- values positionally, so the field order is significant.
-- (Former representation: type Rule f v = (R.Rule f v, Guard f v))
data Rule f v = Rule
  { lhs :: Term f v -- ^ left-hand side of the rule
  , rhs :: Term f v -- ^ right-hand side of the rule
  , guard :: Guard f v -- ^ the rule's constraint (guard)
  }
  deriving (Eq, Ord)

-- | A constrained term: a term paired with the guard that constrains it.
type CTerm f v = (Term f v, Guard f v)

----------------------------------------------------------------------------------------------------
-- core functions
----------------------------------------------------------------------------------------------------

-- rule :: Term f v -> Term f v -> Maybe SMT2.SExpr -> Rule f v
-- rule t1 t2 Nothing = (R.Rule t1 t2,SMT2.SESymbol "true")
-- rule t1 t2 (Just guard) = (R.Rule t1 t2,guard)

-- | Build a constrained rule @l -> r [g]@ from its left-hand side, its
-- right-hand side and its guard, given in that order.
createRule :: Term f v -> Term f v -> Guard f v -> Rule f v
createRule l r g = Rule{lhs = l, rhs = r, guard = g}

-- guard :: Rule f v -> Guard f v
-- guard = constraint

-- rule :: Rule f v -> R.Rule f v
-- rule = fst

-- | Function symbols of a rule: those of the left-hand side, then the
-- right-hand side, then the guard. No deduplication is done here.
funs :: Rule f v -> [f]
funs Rule{..} = concat [T.funs lhs, T.funs rhs, funsGuard guard]

-- | Variables of a rule: those of the left-hand side, then the right-hand
-- side, then the guard. No deduplication is done here.
vars :: Rule f v -> [v]
vars Rule{..} = concat [T.vars lhs, T.vars rhs, varsGuard guard]

-- | Function symbols of the two sides of a rule (left, then right). The
-- guard is ignored.
funsRule :: Rule f v -> [f]
funsRule r = concatMap T.funs [lhs r, rhs r]

-- | Variables of the two sides of a rule (left, then right). The guard is
-- ignored.
varsRule :: Rule f v -> [v]
varsRule r = concatMap T.vars [lhs r, rhs r]

-- | Variables of the left-hand side of a rule.
varsLhs :: Rule f v -> [v]
varsLhs = T.vars . lhs

-- | Variables of the right-hand side of a rule.
varsRhs :: Rule f v -> [v]
varsRhs = T.vars . rhs

-- | @isLinearWrtRule rule t@ holds when every variable of @t@ that is not a
-- logical variable of @rule@ (see 'lvar') occurs at most once in @t@.
-- Logical variables may occur any number of times.
--
-- Only linearity for non-logical variables is checked; this is enough for
-- the results that currently rely on it.
isLinearWrtRule :: (Ord v) => Rule f v -> Term f v -> Bool
isLinearWrtRule rule term =
  and
    [ occurrences <= 1
    | (x, occurrences) <- MS.toOccurList (MS.fromList (T.vars term))
    , x `S.notMember` logicalVars
    ]
 where
  logicalVars = lvar rule

-- | A rule is left-linear when its left-hand side is linear with respect to
-- the rule's non-logical variables (see 'isLinearWrtRule'). Repeated
-- occurrences of logical variables on the left-hand side are tolerated.
isLeftLinear :: (Ord v) => Rule f v -> Bool
isLeftLinear rule@Rule{lhs = l} = isLinearWrtRule rule l

-- | A rule is linear when both of its sides are linear with respect to the
-- rule's non-logical variables (see 'isLinearWrtRule').
isLinear :: (Ord v) => Rule f v -> Bool
isLinear rule =
  isLinearWrtRule rule (lhs rule) && isLinearWrtRule rule (rhs rule)

-- | A rule is ground when 'T.isGround' holds for both its left-hand and
-- right-hand side. The guard is not inspected.
isGround :: (Ord v) => Rule f v -> Bool
isGround Rule{lhs = l, rhs = r} = all T.isGround [l, r]

-- | Check whether two rules are variants of each other, as decided by
-- 'Sub.isVariantOf'.
--
-- Both sides and all guard terms must be renamed consistently. To ensure
-- this, each rule is packed into one term before the comparison. Its root is
-- the extra symbol 'Nothing', and the original function symbols are wrapped
-- in 'Just' so they cannot clash with that root.
--
-- NOTE(review): guard terms are packed in the order 'termsGuard' returns
-- them, so guards that differ only in term order are presumably not
-- recognised as variants.
isVariantOf :: (Ord f, Ord v) => Rule f v -> Rule f v -> Bool
isVariantOf rule1 rule2 = Sub.isVariantOf (pack rule1) (pack rule2)
 where
  pack r =
    T.termFun Nothing . map (T.mapFuns Just) $
      lhs r : rhs r : termsGuard (guard r)

-- R.isVariantOf (rule r1) (rule r2)
--   && guardVariants (guard r1) (guard r2)

-- | Logical variables of a rule: the variables that occur on the right-hand
-- side but not on the left-hand side, together with all variables of the
-- guard.
lvar :: (Ord v) => Rule f v -> S.Set v
lvar Rule{..} = rhsOnly `S.union` guardVars
 where
  rhsOnly = S.fromList (T.vars rhs) `S.difference` S.fromList (T.vars lhs)
  guardVars = S.fromList (varsGuard guard)

-- | Extra variables of a rule: the variables on the right-hand side that
-- occur neither on the left-hand side nor in the guard.
extraVars :: (Ord v) => Rule f v -> S.Set v
extraVars Rule{..} = S.difference rhsVars boundVars
 where
  rhsVars = S.fromList (T.vars rhs)
  boundVars = S.fromList (T.vars lhs) `S.union` S.fromList (varsGuard guard)

-- | Build a constraint that mentions every extra variable of a rule (see
-- 'extraVars'). For each such variable @x@, a reflexive equation @x = x@ is
-- built with 'T.eq' at sort @sort x@. These equations are conjoined with
-- 'T.conj' in ascending variable order and terminated by 'T.top'. With no
-- extra variables, the result is just 'T.top'.
dummyConstraints :: (Ord v, Eq f) => Rule (FId f) (VId v) -> Term (FId f) (VId v)
dummyConstraints rule = foldr T.conj T.top reflexiveEqs
 where
  reflexiveEqs = [selfEq x | x <- S.toList (extraVars rule)]
  selfEq x = T.eq (sortAnnotation [sort x, sort x] boolSort) (T.var x) (T.var x)

-- | Apply a variable renaming function to the left-hand side, the
-- right-hand side and every term of the guard of a rule.
rename :: (v -> v') -> Rule f v -> Rule f v'
rename renaming (Rule l r g) =
  createRule (renameTerm l) (renameTerm r) (modifyGuard g (map renameTerm))
 where
  renameTerm = T.mapVars renaming

-- | Rename the variables of a rule to fresh ones (via 'freshVariables').
-- The renaming map starts empty and is shared across the whole rule.
renameFresh :: (Ord v, MonadFresh m) => Rule f (VId v) -> m (Rule f (VId v))
renameFresh = flip evalStateT M.empty . freshVariables

-- | Worker for 'renameFresh'. It applies 'T.freshVariables' to the left-hand
-- side, then the right-hand side, then every guard term, in that order. The
-- same state map (presumably old variable -> fresh variable) is threaded
-- through all three parts.
freshVariables
  :: (Ord v, MonadFresh m)
  => Rule f (VId v)
  -> StateT (M.Map (VId v) (VId v)) m (Rule f (VId v))
freshVariables Rule{..} = do
  lhs' <- T.freshVariables lhs
  rhs' <- T.freshVariables rhs
  guard' <- mapGuardM T.freshVariables guard
  pure (createRule lhs' rhs' guard')

-- | The root function symbol of the rule's left-hand side. Returns 'Nothing'
-- when 'T.getRoot' yields a 'Right' (non-function) root.
definedSym :: Rule f v -> Maybe f
definedSym = either Just (const Nothing) . T.getRoot . lhs

----------------------------------------------------------------------------------------------------
-- pretty printing
----------------------------------------------------------------------------------------------------

-- | Render a rule as @lhs -> rhs@. The guard (via 'prettyGuard') is appended
-- unless 'isTopGuard' holds for it.
prettyRule :: (Pretty f, Pretty v) => Rule f v -> Doc ann
prettyRule Rule{..} = hsep (pretty lhs : "->" : pretty rhs : guardDoc)
 where
  guardDoc
    | isTopGuard guard = []
    | otherwise = [prettyGuard guard]

-- | Render a list of rules, one per line (see 'prettyRule').
prettyRules :: (Pretty f, Pretty v) => [Rule f v] -> Doc ann
prettyRules = vsep . map prettyRule

-- | Render a constrained term as the term followed by its guard.
--
-- NOTE(review): unlike 'prettyRule', the guard is printed even when
-- 'isTopGuard' holds for it. Confirm this is intended.
prettyCTerm :: (Pretty f, Pretty v) => CTerm f v -> Doc ann
prettyCTerm (term, constraint) = hsep [pretty term, prettyGuard constraint]

-- | Like 'prettyRule', but for rules over 'FId' symbols. Terms are rendered
-- with 'T.prettyTermFId' and the guard with 'prettyGuardFId'. The guard is
-- omitted when 'isTopGuard' holds for it.
prettyRuleFId :: (Pretty f, Pretty v) => Rule (FId f) v -> Doc ann
prettyRuleFId Rule{..} =
  hsep (T.prettyTermFId lhs : "->" : T.prettyTermFId rhs : guardDoc)
 where
  guardDoc
    | isTopGuard guard = []
    | otherwise = [prettyGuardFId guard]

-- | Render a list of 'FId'-based rules, one per line (see 'prettyRuleFId').
prettyRulesFId :: (Pretty f, Pretty v) => [Rule (FId f) v] -> Doc ann
prettyRulesFId = vsep . map prettyRuleFId

-- | Like 'prettyCTerm', but for 'FId'-based constrained terms. The guard is
-- always printed.
prettyCTermFId :: (Pretty f, Pretty v) => CTerm (FId f) v -> Doc ann
prettyCTermFId (term, constraint) =
  hsep [T.prettyTermFId term, prettyGuardFId constraint]
