{- |
Module      : LCTRS
Description : Logically constrained term rewrite systems (LCTRSs)
Copyright   : (c) Jonas Schöpf, 2023
License     : GPL-3
Maintainer  : jonas.schoepf@uibk.ac.at
Stability   : stable


This module provides the implementation of logically constrained term rewrite systems (LCTRSs), together with constrained rewrite sequences and pretty printers for both.
-}
module Data.LCTRS.LCTRS (
  LCTRS (..),
  createLctrs,
  emptyLCTRS,
  getRules,
  getFuns,
  getTheoryFuns,
  getTheories,
  getTheorySorts,
  getSorts,
  getDefines,
  updateRules,
  updateFuns,
  updateTheoryFuns,
  updateTheories,
  fType,
  isLeftLinear,
  isLinear,
  isGround,
  allFunsInRules,
  allVars,
  checkLhsRules,
  rename,
  CRSeq (..),
  StepAnn (..),
  combineCRSeqs,
  definedSyms,
  definesToSExpr,
  prettyCRSeq,
  prettyCRSeqFId,
  prettyLCTRS,
  prettyLCTRSDefault,
  prettyARILCTRS,
) where

----------------------------------------------------------------------------------------------------
-- imports
----------------------------------------------------------------------------------------------------

import Data.Containers.ListUtils (nubOrd)
import Data.LCTRS.FIdentifier (FId, getFIdSort)
import qualified Data.LCTRS.Guard as R
import Data.LCTRS.Position (Pos)
import Data.LCTRS.Rule (Rule)
import qualified Data.LCTRS.Rule as R
import Data.LCTRS.Sort (
  Sort,
  Sorted (sort),
  inSort,
  outSort,
  prettySortAnnotation,
 )
import Data.LCTRS.Term (FunType)
import qualified Data.LCTRS.Term as T
import Data.LCTRS.VIdentifier (VId)
import Data.Maybe (mapMaybe)
import Data.SExpr
import Data.SMT (Identifier, Theory)
import qualified Data.Set as S
import Data.Text (Text)
import Prettyprinter
import SimpleSMT (showsSExpr)
import qualified SimpleSMT as SMT
import Utils (mapOnMaybes)

----------------------------------------------------------------------------------------------------
-- types
----------------------------------------------------------------------------------------------------

-- data LCTRS = LCTRS {
--   rules :: [Rule]
--   -- , logic :: Maybe Logic
--   , funs :: M.Map FId (Int, Maybe [Sort])
--   , vars :: M.Map VId Sort
--   , theoryFuns :: M.Map String ([Sort], Maybe Prop)
--   , valueCheck :: String -> Bool
--   -- , fsymbols :: [FId]
--   -- , fsorts :: M.Map FId Sort
--   -- , farities :: M.Map FId Int
--   -- , ftypes :: M.Map FId FSymbolType
--   -- , vsorts :: M.Map VId Sort
--   }

-- data LCTRS f v = LCTRS
--   { rules       :: [Rule f v]
--   , funs        :: S.Set f -- SortAnnotation -- (Int, Maybe [Sort])
--   , ruleVars    :: M.Map (Rule f v) (M.Map v Sort)
--   , theoryFuns  :: M.Map f (Maybe Prop)
--   , valueCheck  :: f -> Bool
--   , valueSort   :: f -> Maybe Sort
--   , theories    :: [Theory]
--   , theorySorts :: S.Set Sort
--   , sorts       :: S.Set Sort
--   , theoryDefs  :: M.Map f (([(v, Sort)], Sort), T.Term f v)
--   }

-- | A logically constrained term rewrite system over function symbols @f@
-- and variables @v@.
data LCTRS f v = LCTRS
  { theories :: ![Theory]
  -- ^ Theories the system is built on.
  , isTheorySort :: Sort -> Bool
  -- ^ Predicate telling whether a sort is a theory sort.
  , defines :: ![(f, (([(Identifier, Sort)], Sort), Text))]
  -- ^ Defined functions: the symbol, its sorted parameters and result sort,
  -- and its body stored as 'Text'.
  , theoryFuns :: !(S.Set f)
  -- ^ Theory function symbols.
  , sorts :: !(S.Set Sort)
  -- ^ Sorts of the system.
  , funs :: !(S.Set f)
  -- ^ Term (non-theory) function symbols; 'fType' classifies a symbol as
  -- 'T.TermSym' exactly when it is a member of this set.
  , rules :: ![Rule f v]
  -- ^ The rewrite rules.
  }

-- | Kind of rewrite step linking two consecutive terms of a 'CRSeq'.
--
-- The 'Pretty' instance renders the constructors as: 'SinStep' @->@,
-- 'RefStep' @->^=@, 'RefTraStep' @->^*@, 'ParStep' @-||->@ (followed by a
-- set of positions when one is given) and 'MulStep' @-o->@.
data StepAnn = SinStep | RefStep | RefTraStep | ParStep !(Maybe (S.Set Pos)) | MulStep
  deriving (Eq)

-- | A rewrite sequence of constrained terms, stored from right to left.
--
-- The first component is the last term of the sequence. The list holds the
-- earlier terms in reverse order, each paired with the annotation of the
-- step that leaves it: @CRSeq (t0, [(a1, t1), (a2, t2)])@ reads
-- @t2 -a2-> t1 -a1-> t0@ (this is the order 'prettyCRSeq' prints).
newtype CRSeq f v = CRSeq (R.CTerm f v, [(StepAnn, R.CTerm f v)])
  deriving (Eq)

----------------------------------------------------------------------------------------------------
-- core functions
----------------------------------------------------------------------------------------------------

-- | Build an 'LCTRS' from its components, passed in record-field order:
-- theories, theory-sort predicate, defines, theory symbols, sorts, term
-- symbols and rules.
createLctrs
  :: [Theory]
  -> (Sort -> Bool)
  -> [(f, (([(Identifier, Sort)], Sort), Text))]
  -> S.Set f
  -> S.Set Sort
  -> S.Set f
  -> [Rule f v]
  -> LCTRS f v
createLctrs ths theorySortP defs thFuns srts termFuns rs =
  LCTRS
    { theories = ths
    , isTheorySort = theorySortP
    , defines = defs
    , theoryFuns = thFuns
    , sorts = srts
    , funs = termFuns
    , rules = rs
    }

-- | The LCTRS without theories, defines, sorts, symbols or rules, in which
-- no sort is a theory sort.
emptyLCTRS :: LCTRS f v
emptyLCTRS =
  LCTRS
    { theories = []
    , isTheorySort = \_ -> False
    , defines = []
    , theoryFuns = S.empty
    , sorts = S.empty
    , funs = S.empty
    , rules = []
    }

-- | The rewrite rules of the system.
getRules :: LCTRS f v -> [Rule f v]
getRules LCTRS{rules = rs} = rs

-- | The term (non-theory) function symbols of the system.
getFuns :: LCTRS f v -> S.Set f
getFuns LCTRS{funs = fs} = fs

-- getRuleVars :: LCTRS f v -> M.Map (Rule f v) (M.Map v Sort)
-- getRuleVars = ruleVars

-- let m = view varSort l in
-- case M.lookup r m of
--   Nothing -> error "Could not determine the sorts of variables for a rule."
--   Just vM -> vM

-- | The theory function symbols of the system.
getTheoryFuns :: LCTRS f v -> S.Set f
getTheoryFuns LCTRS{theoryFuns = tfs} = tfs

-- getValueCheck :: LCTRS f v -> (f -> Bool)
-- getValueCheck = valueCheck

-- getValueSort :: LCTRS f v -> (f -> Maybe Sort)
-- getValueSort = valueSort

-- | The theories the system is built on.
getTheories :: LCTRS f v -> [Theory]
getTheories LCTRS{theories = ths} = ths

-- | The predicate deciding whether a sort is a theory sort.
getTheorySorts :: LCTRS f v -> Sort -> Bool
getTheorySorts LCTRS{isTheorySort = isThSort} = isThSort

-- | The defined functions: symbol, sorted parameters with result sort, and
-- body text.
getDefines :: LCTRS f v -> [(f, (([(Identifier, Sort)], Sort), Text))]
getDefines LCTRS{defines = defs} = defs

-- | The sorts of the system.
getSorts :: LCTRS f v -> S.Set Sort
getSorts LCTRS{sorts = srts} = srts

-- | Replace the rules, keeping every other component.
updateRules :: LCTRS f v -> [Rule f v] -> LCTRS f v
updateRules system newRules = system{rules = newRules}

-- | Replace the set of term (non-theory) function symbols, keeping every
-- other component.
updateFuns :: LCTRS f v -> S.Set f -> LCTRS f v
updateFuns system newFuns = system{funs = newFuns}

-- updateRuleVars :: LCTRS f v -> M.Map (Rule f v) (M.Map v Sort) -> LCTRS f v
-- updateRuleVars lctrs val = lctrs { ruleVars = val }

-- updateVarSort
--   :: (Ord v, Ord f) => LCTRS f v -> Rule f v -> M.Map v Sort -> LCTRS f v
-- updateVarSort lctrs r val = lctrs { ruleVars = (M.insert r val rV) }
--   where rV = ruleVars lctrs

-- | Replace the set of theory function symbols, keeping every other
-- component.
updateTheoryFuns :: LCTRS f v -> S.Set f -> LCTRS f v
updateTheoryFuns system newTheoryFuns = system{theoryFuns = newTheoryFuns}

-- updateValueCheck :: LCTRS f v -> (f -> Bool) -> LCTRS f v
-- updateValueCheck lctrs val = lctrs { valueCheck = val }

-- updateValueSort :: LCTRS f v -> (f -> Maybe Sort) -> LCTRS f v
-- updateValueSort lctrs val = lctrs { valueSort = val }

-- | Replace the list of theories, keeping every other component.
updateTheories :: LCTRS f v -> [Theory] -> LCTRS f v
updateTheories system newTheories = system{theories = newTheories}

-- | 'True' when every rule is left-linear (vacuously 'True' without rules).
isLeftLinear :: (Ord v) => LCTRS f v -> Bool
isLeftLinear lctrs = all R.isLeftLinear (getRules lctrs)

-- | 'True' when every rule is linear (vacuously 'True' without rules).
isLinear :: (Ord v) => LCTRS f v -> Bool
isLinear lctrs = all R.isLinear (getRules lctrs)

-- | 'True' when every rule is ground (vacuously 'True' without rules).
isGround :: (Ord v) => LCTRS f v -> Bool
isGround lctrs = all R.isGround (getRules lctrs)

-- varSort :: (Show v, Ord v, Ord f) => LCTRS f v -> Rule f v -> v -> Sort
-- varSort l r v = case M.lookup v =<< lookupL getRuleVars l r of
--   Nothing ->
--     error $ "Could not determine the sort of the variable " ++ show v ++ "."
--   Just vM -> vM

-- ruleSort :: (Ord v, Ord f) => LCTRS f v -> Rule f v -> M.Map v Sort
-- ruleSort l r = case lookupL getRuleVars l r of
--   Nothing -> error "Could not determine the sorts of variables for a rule."
--   Just vM -> vM

-- fSort :: Ord f => LCTRS f v -> f -> Maybe SortAnnotation
-- fSort l f = lookupL getFuns l f

-- fArity :: Ord f => LCTRS f v -> f -> Maybe Int
-- fArity l f = fst <$> lookupL getFuns l f

-- | Classify a function symbol. Members of the term signature ('getFuns')
-- are 'T.TermSym'; every other symbol is treated as 'T.TheorySym'.
fType :: (Ord f) => LCTRS f v -> f -> FunType
fType lctrs f =
  if S.member f (getFuns lctrs)
    then T.TermSym
    else T.TheorySym

-- | The function symbols of all rules, concatenated rule by rule as reported
-- by 'R.funs'. Duplicates across rules are kept.
allFunsInRules :: (Ord f) => LCTRS f v -> [f]
allFunsInRules lctrs = [f | rule <- getRules lctrs, f <- R.funs rule]

-- allTheoryFuns :: (Ord f) => LCTRS (FId f) v -> [FId f]
-- allTheoryFuns = S.toList . S.filter T.isLogic . funs

-- | The variables of all rules, concatenated rule by rule as reported by
-- 'R.vars'. Duplicates across rules are kept.
allVars :: (Ord v) => LCTRS f v -> [v]
allVars lctrs = [x | rule <- getRules lctrs, x <- R.vars rule]

-- | Rename the variables of every rule with the given function.
--
-- Only 'rules' mentions the variable type, so all other components are
-- carried over unchanged. This includes the defines, whose parameters are
-- 'Identifier's rather than @v@.
rename :: (Ord v', Ord f) => (v -> v') -> LCTRS f v -> LCTRS f v'
rename renaming lc = lc{rules = map (R.rename renaming) (getRules lc)}

-- where
--  renameDefineValue ((vss, outsort), term) =
--    ((map (first renaming) vss, outsort), T.mapVars renaming term)

-- | Check that the left-hand side of every rule is headed by a term
-- (non-theory) function symbol. A variable or a theory symbol at the root of
-- any left-hand side makes the check fail.
checkLhsRules :: LCTRS (FId f) v -> Bool
checkLhsRules = all (headedByTermFun . R.lhs) . getRules
 where
  headedByTermFun t = case t of
    T.Var _ -> False
    T.TermFun _ _ -> True
    T.TheoryFun _ _ -> False

-- | @combineCRSeqs later ann earlier@ joins two rewrite sequences.
--
-- The result first runs through @earlier@, then takes an @ann@ step from the
-- last term of @earlier@ to the first term of @later@, and finally continues
-- with @later@. No check is made that these terms actually connect.
combineCRSeqs :: CRSeq f v -> StepAnn -> CRSeq f v -> CRSeq f v
combineCRSeqs (CRSeq (finalTerm, laterSteps)) joinAnn (CRSeq (joinTerm, earlierSteps)) =
  CRSeq (finalTerm, laterSteps ++ ((joinAnn, joinTerm) : earlierSteps))

-- | The defined symbols: the root symbols of all rule left-hand sides.
-- Left-hand sides that are variables contribute nothing.
definedSyms :: (Ord f) => LCTRS f v -> S.Set f
definedSyms lctrs = S.fromList (concatMap rootSymbol (getRules lctrs))
 where
  rootSymbol rule = case R.lhs rule of
    T.Var _ -> []
    T.Fun _ f _ -> [f]

----------------------------------------------------------------------------------------------------
-- SExpr
----------------------------------------------------------------------------------------------------

-- instance (ToSExpr f, ToSExpr v) => ToSExpr (LCTRS f v) where
--   toSExpr lctrs =
--     P.printf
--       "Signature = {%s}\n"
--       (intercalate ", " $ map toSExpr $ S.toList $ getFuns lctrs)
--       ++ P.printf
--         "Rules = {%s}\n"
--         (intercalate ", " $ map toSExpr $ getRules lctrs)

-- | Convert the defines of an LCTRS into SMT components, one tuple per define
-- in the order given by 'getDefines'.
--
-- Each tuple holds:
--
-- * the rendered symbol name;
-- * the parameters as (name, sort atom) pairs;
-- * the result sort as an atom;
-- * the body as an atom.
--
-- Every part is obtained by rendering with 'pretty'/'show' and wrapping the
-- string in an atom with 'SMT.const'. The body 'Text' is not parsed.
definesToSExpr
  :: (Pretty f, Pretty v, ToSExpr f, ToSExpr v)
  => LCTRS f v
  -> [(String, [(String, SMT.SExpr)], SMT.SExpr, SMT.SExpr)]
definesToSExpr lctrs = map defineToSExpr (getDefines lctrs)
 where
  prettyString :: (Pretty a) => a -> String
  prettyString x = show (pretty x)

  defineToSExpr (fId, ((params, resultSort), body)) =
    ( prettyString fId
    , [ (showsSExpr (SMT.const (prettyString p)) "", SMT.const (prettyString pSort))
      | (p, pSort) <- params
      ]
    , SMT.const (prettyString resultSort)
    , SMT.const (prettyString body)
    )

-- where
-- termToSExpr (T.Var v) = toSExpr v
-- termToSExpr (T.Fun f []) = toSExpr f
-- termToSExpr (T.Fun f args) =
--   SMT.app (toSExpr f) (map termToSExpr args)

----------------------------------------------------------------------------------------------------
-- pretty printing
----------------------------------------------------------------------------------------------------

-- | Comma-separated theories, each rendered through its 'Show' instance.
prettyTheories :: [Theory] -> Doc ann
prettyTheories = hsep . punctuate comma . map (pretty . show)

-- | One line per define: @f: ((x S) ...) R body@. The parenthesised
-- parameter list is left empty when the define has no parameters.
prettyDefines
  :: (Pretty f)
  => [(f, (([(Identifier, Sort)], Sort), Text))]
  -> Doc ann
prettyDefines = vsep . map prettyDefine
 where
  prettyDefine (fId, ((params, resultSort), body)) =
    pretty fId <> ":" <+> prettyParams params <+> pretty resultSort <+> pretty body
  prettyParams [] = ""
  prettyParams params =
    parens (hsep [parens (pretty p <+> pretty pSort) | (p, pSort) <- params])

-- | Comma-separated function symbols in ascending set order.
prettySignature :: (Pretty f) => S.Set f -> Doc ann
prettySignature = hsep . punctuate comma . map pretty . S.toList

-- | One line per function symbol, @f: <sort annotation>@, in ascending set
-- order.
prettySignatureFId :: (Pretty f) => S.Set (FId f) -> Doc ann
prettySignatureFId = vsep . map entry . S.toList
 where
  entry fId = pretty fId <> ":" <+> prettySortAnnotation (getFIdSort fId)

-- | Comma-separated sorts in ascending set order.
prettySorts :: S.Set Sort -> Doc ann
prettySorts = hsep . punctuate comma . map pretty . S.toList

-- | Human-readable rendering of an LCTRS with sections for theories, defines
-- and sorts (the latter two only when non-empty), the signature and the
-- rules. Function symbols are printed with 'pretty' only.
prettyLCTRS :: (Pretty f, Pretty v) => LCTRS f v -> Doc ann
prettyLCTRS lctrs =
  prettyLCTRSSections
    (prettySignature (getFuns lctrs))
    (R.prettyRules (getRules lctrs))
    lctrs

-- | Like 'prettyLCTRS', but each signature entry carries its sort annotation
-- and rules are printed with 'R.prettyRulesFId'.
prettyLCTRSDefault
  :: (Pretty f, Pretty v) => LCTRS (FId f) (VId v) -> Doc ann
prettyLCTRSDefault lctrs =
  prettyLCTRSSections
    (prettySignatureFId (getFuns lctrs))
    (R.prettyRulesFId (getRules lctrs))
    lctrs

-- | Shared layout of 'prettyLCTRS' and 'prettyLCTRSDefault'. The caller
-- supplies the already-rendered signature and rules.
prettyLCTRSSections :: (Pretty f) => Doc ann -> Doc ann -> LCTRS f v -> Doc ann
prettyLCTRSSections signatureDoc rulesDoc lctrs =
  "LCTRS"
    <> line
    <> indent 2 (vsep (mapOnMaybes (nest 2 . vsep) sections))
 where
  defs = getDefines lctrs
  srts = getSorts lctrs
  -- 'Nothing' marks an optional section that is omitted.
  sections =
    [ Just ["Theories", prettyTheories (getTheories lctrs)]
    , if null defs then Nothing else Just ["Defines", prettyDefines defs]
    , if null srts then Nothing else Just ["Sorts", prettySorts srts]
    , Just ["Signature", signatureDoc]
    , Just ["Rules", rulesDoc]
    ]

-- | Arrow notation for each kind of step; a parallel step with known
-- positions lists them as a set subscript.
instance Pretty StepAnn where
  pretty stepAnn = case stepAnn of
    SinStep -> "->"
    RefStep -> "->^="
    RefTraStep -> "->^*"
    ParStep Nothing -> "-||->"
    ParStep (Just poss) ->
      "-||->_" <> encloseSep "{" "}" ", " (pretty <$> S.toList poss)
    MulStep -> "-o->"

-- | Render a constrained rewrite sequence from its first term to its last.
--
-- The first term stands alone on the first line. Every following line is
-- indented by two and shows the annotation of the step followed by the term
-- that step reaches. A sequence without steps renders as its single term.
prettyCRSeq :: (Pretty f, Pretty v) => CRSeq f v -> Doc ann
prettyCRSeq = prettyCRSeqWith R.prettyCTerm

-- | Like 'prettyCRSeq', for sequences over 'FId' function symbols, using
-- 'R.prettyCTermFId' for the terms.
prettyCRSeqFId :: (Pretty f, Pretty v) => CRSeq (FId f) v -> Doc ann
prettyCRSeqFId = prettyCRSeqWith R.prettyCTermFId

-- | Shared renderer of 'prettyCRSeq' and 'prettyCRSeqFId', parameterised by
-- the term printer.
--
-- A 'CRSeq' stores the final term plus the earlier terms newest-first, each
-- paired with the step leaving it. Zipping the steps with
-- @endTerm : map snd steps@ pairs every step with the term it reaches.
-- Reversing once yields the steps in forward order. This is linear in the
-- length of the sequence, unlike the previous @acc ++ [e]@ accumulation,
-- which was quadratic.
prettyCRSeqWith :: (R.CTerm f v -> Doc ann) -> CRSeq f v -> Doc ann
prettyCRSeqWith prettyCT (CRSeq (endTerm, steps)) =
  case reverse (zip steps (endTerm : map snd steps)) of
    [] -> prettyCT endTerm
    forward@(((_, startTerm), _) : _) ->
      prettyCT startTerm
        <> line
        <> vsep
          [ indent 2 (pretty stepAnn <+> prettyCT target)
          | ((stepAnn, _), target) <- forward
          ]

-- | Render an LCTRS in the ARI format (@(format LCTRS :smtlib 2.6)@).
--
-- Output is one declaration per line, in this order: theories, defines
-- (@define-fun@), sorts, function symbols with their full sort annotation,
-- and rules. Each rule carries its guard and a @:var@ list giving the sort
-- of every variable occurring in it, so the output is fully sorted.
prettyARILCTRS :: (Ord v, Eq f, Pretty f, Pretty v) => LCTRS (FId f) (VId v) -> Doc ann
prettyARILCTRS lc =
  vsep $
    concat
      [ ["(format LCTRS :smtlib 2.6)"]
      , [parens ("theory " <> pretty (show th)) | th <- getTheories lc]
      , map aDefine (getDefines lc)
      , [parens ("sort " <> pretty srt) | srt <- S.toList (getSorts lc)]
      , map aFun (S.toList (getFuns lc))
      , map aRule (getRules lc)
      ]
 where
  aDefine (fId, ((params, resultSort), body)) =
    parens $
      "define-fun"
        <+> pretty fId
        <+> parens (hsep [parens (pretty p <+> pretty pSort) | (p, pSort) <- params])
        <+> pretty resultSort
        <+> pretty body

  aFun fId = parens ("fun" <+> pretty fId <+> aSortAnn (getFIdSort fId))

  -- Constants print just their result sort; other symbols print (-> args res).
  aSortAnn sa = case inSort sa of
    [] -> pretty (outSort sa)
    argSorts -> parens ("->" <+> hsep (map pretty argSorts) <+> pretty (outSort sa))

  aRule rule =
    let guardTerm = R.collapseGuardToTerm (R.guard rule)
    in  parens $
          "rule"
            <+> aTerm (R.lhs rule)
            <+> aTerm (R.rhs rule)
            <+> ":guard"
            <+> aTerm guardTerm
            <+> ":var"
            <+> aVars rule

  -- Each distinct variable of the rule, annotated with its sort.
  aVars rule =
    parens (hsep [parens (pretty x <+> pretty (sort x)) | x <- nubOrd (R.vars rule)])

  aTerm term = case term of
    T.Var x -> pretty x
    T.Fun _ fId [] -> pretty fId
    T.Fun _ fId args -> parens (pretty fId <+> hsep (map aTerm args))
