{- |
Module      : Term
Description : Terms over function-symbol and variable identifiers
Copyright   : (c) Jonas Schöpf, 2023
License     : GPL-3
Maintainer  : jonas.schoepf@uibk.ac.at
Stability   : stable


This module defines the 'Term' type over function-symbol and variable
identifiers, together with functions to construct, traverse, inspect and
pretty-print terms.
-}
module Data.LCTRS.Term (
  -- * Types and TypeClasses
  FunType (..),
  VarType (..),
  -- T.Term (..), -- NOTE: this lacks documentation at the moment!
  Term (..),
  pattern Val,
  pattern TermFun,
  pattern TheoryFun,

  -- * Core Term Functions
  var,
  val,
  termFun,
  theoryFun,
  fun,
  foldTerm,
  mapTerm,
  cEq,
  eq,
  conj,
  top,
  disj,
  grT,
  neg,
  imp,
  isVar,
  vars,
  funs,
  termFuns,
  mapVars,
  mapFuns,
  renameFresh,
  freshVariables,
  subtermAt,
  replaceAt,
  fPos,
  vPos,
  termPos,
  properSubterms,
  subterms,
  getRoot,

  -- * Predicates on Ids and Terms
  isConst,
  isValue,
  isLogicTerm,
  isLogicConstraint,
  isConstrainedEquation,
  isLinear,
  isGround,

  -- * Pretty Printing

  -- prettyTerm,
  prettyTermFId,
) where

----------------------------------------------------------------------------------------------------
-- imports
----------------------------------------------------------------------------------------------------

import Control.Monad (guard)
import Control.Monad.State.Strict (StateT, evalStateT, gets, modify)
import Data.LCTRS.FIdentifier (FId (..))
import Data.LCTRS.Position (Pos (pos), epsilon, position)
import Data.LCTRS.Sort
import Data.LCTRS.VIdentifier (VId, freshV)
import qualified Data.Map.Strict as M
import Data.Monad (MonadFresh (..))
import qualified Data.MultiSet as MS
import Data.SMT (boolSort)
import qualified Data.Set as S
import Prettyprinter

----------------------------------------------------------------------------------------------------
-- types
----------------------------------------------------------------------------------------------------

-- | Type for terms.
--
-- A term is either a variable ('Var') or a function symbol applied to a list
-- of argument terms ('Fun'); constants are applications to the empty list.
-- The 'FunType' field records whether the root symbol is a term symbol or a
-- theory symbol. All constructor fields are strict (evaluated to WHNF).
data Term f v = Var !v | Fun !FunType !f ![Term f v]
  deriving (Eq, Ord, Show)

-- Tells GHC's pattern-match checker that matching on 'Var', 'TermFun' and
-- 'TheoryFun' covers every 'Term', so such matches are not reported as
-- incomplete.
{-# COMPLETE Var, TermFun, TheoryFun #-}
-- | Tag stored in every 'Fun' node: 'TermSym' for ordinary term function
-- symbols, 'TheorySym' for theory symbols (e.g. the connectives built by
-- 'conj', 'disj', 'neg' and 'imp', and values built by 'val').
data FunType = TermSym | TheorySym
  deriving (Eq, Ord, Show)

-- | Tag distinguishing term variables ('TermVar') from theory variables
-- ('TheoryVar').
-- NOTE(review): not referenced anywhere else in this module; presumably used
-- by importers — confirm.
data VarType = TermVar | TheoryVar
  deriving (Eq, Ord, Show)

-- | Renders a 'FunType' as the text of its derived 'Show' instance
-- (e.g. @TermSym@).
instance Pretty FunType where
  pretty typ = pretty (show typ)

-- data Fun f v = TermFun f [Term f v] | TheoryFun f [Term f v]
--   deriving (Eq, Ord, Show)

-- | Bidirectional pattern for values: a theory symbol applied to no
-- arguments. Every argument-less 'TheorySym' application matches 'Val'.
pattern Val :: f -> Term f v
pattern Val v = Fun TheorySym v []

-- | Bidirectional pattern for an application whose root is a term symbol
-- ('TermSym').
pattern TermFun :: f -> [Term f v] -> Term f v
pattern TermFun f args = Fun TermSym f args

-- | Bidirectional pattern for an application whose root is a theory symbol
-- ('TheorySym'); this includes values (see 'Val').
pattern TheoryFun :: f -> [Term f v] -> Term f v
pattern TheoryFun f args = Fun TheorySym f args

-- | Prints variables and constants bare and every other application in prefix
-- form @f(t1, ..., tn)@. The 'FunType' tag is not shown.
instance (Pretty f, Pretty v) => Pretty (Term f v) where
  pretty term = case term of
    Var x -> pretty x
    Fun _ sym [] -> pretty sym
    Fun _ sym args -> pretty sym <> parens (hsep (punctuate comma (pretty <$> args)))

-- | The sort of a term is the sort of its root: the variable's sort for a
-- variable, and the sort reported for the root function symbol otherwise.
instance (Sorted f, Sorted v) => Sorted (Term f v) where
  sort term = case term of
    Var x -> sort x
    Fun _ sym _ -> sort sym

----------------------------------------------------------------------------------------------------
-- core term functions
----------------------------------------------------------------------------------------------------

-- | 'var' @x@ builds the variable term @x@.
var :: v -> Term f v
var x = Var x

-- | 'val' @c@ builds a value, i.e. the theory symbol @c@ with no arguments.
val :: f -> Term f v
val c = Val c

-- | 'termFun' @f@ @ts@ applies the term function symbol @f@ to @ts@.
termFun :: f -> [Term f v] -> Term f v
termFun f ts = TermFun f ts

-- | 'theoryFun' @f@ @ts@ applies the theory function symbol @f@ to @ts@.
theoryFun :: f -> [Term f v] -> Term f v
theoryFun f ts = TheoryFun f ts

-- | 'fun' @typ@ @f@ @ts@ applies @f@ to @ts@, building a term or theory
-- application depending on @typ@.
fun :: FunType -> f -> [Term f v] -> Term f v
fun typ = case typ of
  TermSym -> termFun
  TheorySym -> theoryFun

-- | 'conj' @t1@ @t2@ constructs the conjunction of @t1@ and @t2@.
--
-- If either argument equals 'top', the other argument is returned instead,
-- so no trivial conjunct is introduced (@conj top top@ yields 'top').
conj :: (Eq v, Eq f) => Term (FId f) v -> Term (FId f) v -> Term (FId f) v
conj lhs rhs
  | isTop lhs = rhs
  | isTop rhs = lhs
  | otherwise = theoryFun Conj [lhs, rhs]
 where
  isTop t = t == top

-- | 'disj' @t1@ @t2@ constructs the disjunction term with arguments @t1@ and
-- @t2@. Unlike 'conj', no simplification is performed.
disj :: Term (FId f) v -> Term (FId f) v -> Term (FId f) v
disj lhs rhs = TheoryFun Disj [lhs, rhs]

-- | 'grT' @sortAnn@ @t1@ @t2@ constructs a "greater than" theory term with
-- arguments @t1@ and @t2@, tagged with the sort annotation @sortAnn@.
grT :: SortAnnotation -> Term (FId f) v -> Term (FId f) v -> Term (FId f) v
grT ann lhs rhs = TheoryFun (GrT ann) [lhs, rhs]

-- | 'eq' @sortAnn@ @t1@ @t2@ constructs an equality theory term with
-- arguments @t1@ and @t2@, tagged with the sort annotation @sortAnn@.
eq :: SortAnnotation -> Term (FId f) v -> Term (FId f) v -> Term (FId f) v
eq ann lhs rhs = TheoryFun (Eq ann) [lhs, rhs]

-- | 'top' is the value term built from the 'Top' symbol, i.e. the term with
-- the semantics of the logical constant true.
top :: Term (FId f) v
top = Val Top

-- | 'neg' @t@ constructs the negation term with the single argument @t@.
neg :: Term (FId f) v -> Term (FId f) v
neg t = TheoryFun Neg [t]

-- | 'imp' @t1@ @t2@ constructs the implication term with arguments @t1@
-- (left-hand side) and @t2@ (right-hand side).
imp :: Term (FId f) v -> Term (FId f) v -> Term (FId f) v
imp premise conclusion = TheoryFun Imp [premise, conclusion]

-- | 'cEq' @t1@ @t2@ constructs a constrained-equation term with arguments
-- @t1@ and @t2@.
--
-- The 'CEquation' symbol carries a sort annotation built by 'sortAnnotation'
-- from the sorts of both arguments together with the sort of @t1@.
cEq :: (Sorted v) => Term (FId f) v -> Term (FId f) v -> Term (FId f) v
cEq t1 t2 = theoryFun (CEquation sa) [t1, t2]
 where
  sa = sortAnnotation [sort t1, sort t2] $ sort t1

-- | 'isVar' @t@ is 'True' exactly when @t@ is a variable.
isVar :: Term f v -> Bool
isVar term = case term of
  Var _ -> True
  Fun {} -> False

-- | 'foldTerm' @onVar@ @onFun@ @t@ folds @t@ recursively: each variable is
-- mapped with @onVar@, and each application is combined by @onFun@ from its
-- 'FunType', its symbol and the folded results of its arguments (in order).
foldTerm :: (v -> a) -> (FunType -> f -> [a] -> a) -> Term f v -> a
foldTerm onVar onFun = go
 where
  go (Var v) = onVar v
  go (Fun typ f ts) = onFun typ f (map go ts)

-- | 'mapTerm' @onFun@ @onVar@ @t@ renames every function symbol of @t@ with
-- @onFun@ and every variable with @onVar@, keeping the term's shape and all
-- 'FunType' tags.
mapTerm :: (f -> f') -> (v -> v') -> Term f v -> Term f' v'
mapTerm onFun onVar = foldTerm (Var . onVar) (\typ f args -> Fun typ (onFun f) args)

-- | 'vars' @t@ lists every variable occurrence in @t@ from left to right;
-- a variable occurring several times appears several times.
vars :: Term f v -> [v]
vars term = go term []
 where
  go (Var x) acc = x : acc
  go (Fun _ _ ts) acc = foldr go acc ts

-- | 'funs' @t@ lists every function symbol occurrence in @t@ (term and theory
-- symbols alike) in pre-order: a symbol comes before those of its arguments.
funs :: Term f v -> [f]
funs term = go term []
 where
  go (Var _) acc = acc
  go (Fun _ f ts) acc = f : foldr go acc ts

-- | 'termFuns' @t@ lists, in pre-order, every occurrence of a symbol tagged
-- 'TermSym' in @t@. Theory symbols are skipped, but their arguments are still
-- searched.
termFuns :: Term f v -> [f]
termFuns term = go term []
 where
  go (Var _) acc = acc
  go (Fun typ f args) acc
    | typ == TermSym = f : rest
    | otherwise = rest
   where
    rest = foldr go acc args

-- | 'mapVars' @renaming@ @t@ applies @renaming@ to every variable of @t@ and
-- leaves function symbols untouched.
mapVars :: (v -> v') -> Term f v -> Term f v'
mapVars renaming = mapTerm id renaming

-- | 'mapFuns' @renaming@ @t@ applies @renaming@ to every function symbol of
-- @t@ and leaves variables untouched.
mapFuns :: (f -> f') -> Term f v -> Term f' v
mapFuns renaming = mapTerm renaming id

-- | 'renameFresh' @t@ runs 'freshVariables' on @t@ starting from an empty
-- renaming, so every distinct variable of @t@ is replaced consistently by a
-- new identifier.
renameFresh :: (Ord v, MonadFresh m) => Term f (VId v) -> m (Term f (VId v))
renameFresh = flip evalStateT M.empty . freshVariables

-- | 'freshVariables' @t@ replaces every variable of @t@ with a new identifier,
-- built by 'freshV' from the variable's sort and a number obtained from
-- 'freshInt'.
--
-- The state maps each already-renamed variable to its replacement. All
-- occurrences of one variable therefore receive the same new identifier,
-- within @t@ and across calls that share the state.
freshVariables
  :: (Ord v, MonadFresh m)
  => Term f (VId v)
  -> StateT (M.Map (VId v) (VId v)) m (Term f (VId v))
freshVariables term = case term of
  Var x -> Var <$> renamed x
  Fun typ f args -> Fun typ f <$> traverse freshVariables args
 where
  -- Reuse the cached replacement for @x@ if there is one, otherwise allocate it.
  renamed :: (Ord v, MonadFresh m) => VId v -> StateT (M.Map (VId v) (VId v)) m (VId v)
  renamed x = gets (M.lookup x) >>= maybe (allocate x) return

  -- Draw a new identifier with the sort of @x@ and remember it as @x@'s replacement.
  allocate :: (Ord v, MonadFresh m) => VId v -> StateT (M.Map (VId v) (VId v)) m (VId v)
  allocate x = do
    x' <- freshV (sort x) <$> freshInt
    modify (M.insert x x')
    return x'

-- | 'properSubterms' @t@ lists every subterm of @t@ except @t@ itself, in
-- pre-order from left to right. Variables have no proper subterms.
properSubterms :: Term f v -> [Term f v]
properSubterms term = case term of
  Var _ -> []
  Fun _ _ args -> args >>= subterms

-- | 'subterms' @t@ lists @t@ itself followed by all of its proper subterms.
subterms :: Term f v -> [Term f v]
subterms term = term : properSubterms term

-- | 'subtermAt' @t@ @pos@ returns the subterm of @t@ at position @pos@.
-- Argument indices are 0-based. The result is 'Nothing' if the position
-- descends into a variable or uses an out-of-range (or negative) index.
subtermAt :: Term f v -> Pos -> Maybe (Term f v)
subtermAt term p = descend term (pos p)
 where
  descend t [] = Just t
  descend (Fun _ _ ts) (i : is)
    | i >= 0
    , (t : _) <- drop i ts =
        descend t is
  descend _ _ = Nothing

-- | 'replaceAt' @t1@ @pos@ @t2@ replaces the subterm of @t1@ at position @pos@
-- with @t2@. Argument indices are 0-based. The result is 'Nothing' if the
-- position descends into a variable or uses an out-of-range (or negative)
-- index.
replaceAt :: Term f v -> Pos -> Term f v -> Maybe (Term f v)
replaceAt term p replacement = go term (pos p)
 where
  go _ [] = Just replacement
  go (Fun typ f ts) (i : is)
    | i >= 0
    , (before, t : after) <- splitAt i ts =
        (\t' -> Fun typ f (before ++ t' : after)) <$> go t is
  go _ _ = Nothing

-- | 'fPos' @t@ returns the set of positions in @t@ at which a function symbol
-- occurs. An application contributes 'epsilon' for its root, plus the
-- positions of its @i@-th argument prefixed with @i@ (counting from 0).
fPos :: Term f v -> S.Set Pos
fPos term = case term of
  Var _ -> S.empty
  Fun _ _ args ->
    S.insert epsilon $
      S.unions
        [ S.map (position [i] <>) (fPos arg)
        | (i, arg) <- zip [0 ..] args
        ]

-- | 'vPos' @t@ returns the set of positions in @t@ at which a variable occurs.
--
-- A variable term has the single variable position 'epsilon'. For an
-- application, the variable positions of its @i@-th argument (counting from
-- 0) are prefixed with @i@; constants contribute no positions.
vPos :: Term f v -> S.Set Pos
vPos (Var _) = S.singleton epsilon
vPos (Fun _ _ args) =
  -- Recurse with 'vPos' (not 'fPos'): we collect variable positions of the arguments.
  S.unions
    [ S.map (position [i] <>) (vPos arg)
    | (i, arg) <- zip [0 ..] args
    ]

-- | 'termPos' @t@ returns every position of @t@: the union of its function
-- symbol positions ('fPos') and its variable positions ('vPos').
termPos :: Term f v -> S.Set Pos
termPos t = fPos t <> vPos t

----------------------------------------------------------------------------------------------------
-- predicates on identifiers and terms
----------------------------------------------------------------------------------------------------

-- | 'isConst' @t@ is 'True' when @t@ is a function symbol applied to no
-- arguments (term or theory symbol alike).
isConst :: Term f v -> Bool
isConst term = case term of
  Fun _ _ args -> null args
  Var _ -> False

-- | 'isValue' @t@ is 'True' when @t@ is a value, i.e. a theory symbol applied
-- to no arguments.
isValue :: Term (FId f) v -> Bool
isValue (Fun TheorySym _ args) = null args
isValue _ = False

-- | 'isLinear' @t@ is 'True' when no variable occurs more than once in @t@.
isLinear :: (Ord v) => Term f v -> Bool
isLinear t = all ((== 1) . snd) (MS.toOccurList (MS.fromList (vars t)))

-- | 'isGround' @t@ is 'True' when @t@ contains no variables.
isGround :: Term f v -> Bool
isGround term = null (vars term)

-- | 'isLogicTerm' @t@ is 'True' when @t@ contains no 'TermSym' application,
-- i.e. it is built only from variables and theory symbols.
isLogicTerm :: Term (FId f) v -> Bool
isLogicTerm term = case term of
  Var _ -> True
  Fun TheorySym _ args -> all isLogicTerm args
  Fun TermSym _ _ -> False

-- | 'isLogicConstraint' @t@ is 'True' when @t@ is a logic term whose root has
-- sort 'boolSort': either a variable of that sort, or a theory symbol of that
-- sort whose arguments are all logic terms.
isLogicConstraint :: Term (FId f) (VId v) -> Bool
isLogicConstraint term = case term of
  Var x -> sort x == boolSort
  Fun TheorySym sym args -> sort sym == boolSort && all isLogicTerm args
  Fun TermSym _ _ -> False

-- terms

-- -- | 'isVariantOf' @t1@ @t2@ checks if @t1@ and @t2@ are variants.
-- isVariantOf :: (Eq f, Ord v) => Term f v -> Term f v -> Bool
-- isVariantOf = T.isVariantOf

-- | 'isConstrainedEquation' @t@ is 'True' when the root symbol of @t@ is a
-- 'CEquation' symbol (regardless of its 'FunType' tag or arguments).
isConstrainedEquation :: Term (FId f) v -> Bool
isConstrainedEquation (Var _) = False
isConstrainedEquation (Fun _ sym _) = case sym of
  CEquation _ -> True
  _ -> False

-- | 'getRoot' @term@ returns the root of @term@: its function symbol wrapped
-- in 'Left' for an application, or the variable wrapped in 'Right'.
getRoot :: Term f v -> Either f v
getRoot term = case term of
  Var x -> Right x
  Fun _ sym _ -> Left sym

----------------------------------------------------------------------------------------------------
-- pretty printing
----------------------------------------------------------------------------------------------------

-- -- | pretty printing function for terms.
-- prettyTerm :: (Pretty f, Pretty v) => T.Term f v -> Doc ann
-- prettyTerm (T.Var v) = pretty v
-- prettyTerm (T.Fun f []) = pretty f
-- prettyTerm (T.Fun f args) =
--   pretty f
--     <> parens (hsep $ punctuate comma (map prettyTerm args))

-- | Pretty-prints a term over 'FId' symbols.
--
-- Binary constrained equations, equalities and implications are rendered
-- infix as @a ≈ b@, @a = b@ and @a => b@. Variables and constants are printed
-- bare, and every other application as @f(t1, ..., tn)@.
prettyTermFId :: (Pretty f, Pretty v) => Term (FId f) v -> Doc ann
prettyTermFId term = case term of
  Var x -> pretty x
  Fun _ (CEquation _) [lhs, rhs] -> infixed "≈" lhs rhs
  Fun _ (Eq _) [lhs, rhs] -> infixed "=" lhs rhs
  Fun _ Imp [lhs, rhs] -> infixed "=>" lhs rhs
  Fun _ sym [] -> pretty sym
  Fun _ sym args -> pretty sym <> parens (hsep (punctuate comma (map prettyTermFId args)))
 where
  -- Renders both operands recursively, separated from the operator by single spaces.
  infixed op lhs rhs = prettyTermFId lhs <+> op <+> prettyTermFId rhs
