{- |
Module      : SortLCTRS
Description : Attach sort information to an LCTRS
Copyright   : (c) Jonas Schöpf, 2023
License     : GPL-3
Maintainer  : jonas.schoepf@uibk.ac.at
Stability   : stable


This module provides functions to attach sort information to an LCTRS.
-}
module Type.SortLCTRS (
  sortLCTRS,
  inferRemainingPSorts,
  IntermediateFId,
  IntermediateVId,
  IntermediateInput,
  DefaultFId (deffid),
  DefaultVId (defvid),
  defaultFId,
  defaultVId,
) where

----------------------------------------------------------------------------------------------------
-- imports
----------------------------------------------------------------------------------------------------

import Control.Applicative ((<|>))
import Control.Monad (zipWithM)
import Control.Monad.State (
  MonadTrans (lift),
  StateT,
  evalStateT,
  get,
  modify,
 )
import Data.LCTRS.FIdentifier (
  FId,
  fIdParse,
  getFIdSort,
  updateFIdSort,
 )
import Data.LCTRS.Guard (mapGuard)
import Data.LCTRS.LCTRS (
  LCTRS,
  createLctrs,
  getDefines,
  getFuns,
  getRules,
  getSorts,
  getTheories,
  getTheoryFuns,
  getTheorySorts,
 )
import Data.LCTRS.Rule (
  Rule,
  createRule,
  guard,
  lhs,
  rhs,
 )
import Data.LCTRS.Sort (
  Attr (..),
  Sort (AttrSort, LitSort, PSort),
  SortAnnotation,
  containsPolySort,
  inSort,
  isInstanceSortAnnOf,
  outSort,
  sortAnnotation,
 )
import Data.LCTRS.Term (
  Term (..),
  fun,
  var,
 )
import qualified Data.LCTRS.Term as T
import Data.LCTRS.VIdentifier (
  VId,
  vId,
 )
import Data.List (foldl', sortBy, tails)
import qualified Data.Map.Strict as M
import Data.SExpr (ToSExpr (toSExpr), modifyAtom)
import Data.SMT (Identifier, Property (..))
import qualified Data.Set as S
import Data.String (IsString)
import Fmt (fmt, (+|), (|+))
import Parser.Ari (getThSorts)
import qualified Parser.Ari as Ari
import Prettyprinter (Pretty (pretty))

----------------------------------------------------------------------------------------------------
-- intermediate types
----------------------------------------------------------------------------------------------------

-- | A function symbol as it comes out of the parser: its name together with
-- the sort annotation recorded for it during parsing.
type IntermediateFId = (Identifier, SortAnnotation)
-- | A variable as it comes out of the parser: its name, an integer index and
-- the sort recorded for it during parsing.
type IntermediateVId = (Identifier, Int, Sort)
-- | Parsed ARI input over the intermediate (not yet finally sorted) identifiers.
type IntermediateInput = Ari.Input IntermediateFId IntermediateVId

----------------------------------------------------------------------------------------------------
-- default types for identifiers
----------------------------------------------------------------------------------------------------

-- | Function-symbol identifier used in the sorted LCTRS; a thin wrapper
-- around an 'Identifier'.
newtype DefaultFId = DefaultFId {deffid :: Identifier}
  deriving (Eq, Ord, IsString)
-- | Variable identifier used in the sorted LCTRS: a name paired with an
-- integer index. The index takes part in equality and ordering and in the
-- S-expression rendering, but is omitted when pretty-printing.
newtype DefaultVId = DefaultVId {defvid :: (Identifier, Int)}
  deriving (Eq, Ord)

-- | Wrap a function-symbol name as a 'DefaultFId'.
defaultFId :: Identifier -> DefaultFId
defaultFId ident = DefaultFId ident

-- | Wrap a variable name and its index as a 'DefaultVId'.
defaultVId :: (Identifier, Int) -> DefaultVId
defaultVId nameAndIndex = DefaultVId nameAndIndex

-- | Pretty-print a function symbol by its underlying identifier.
instance Pretty DefaultFId where
  pretty (DefaultFId ident) = pretty ident

-- | Pretty-print a variable by its name only; the index is not shown.
instance Pretty DefaultVId where
  pretty (DefaultVId (ident, _)) = pretty ident

-- | A function symbol is rendered exactly like its underlying identifier.
instance ToSExpr DefaultFId where
  toSExpr (DefaultFId ident) = toSExpr ident

-- | A variable is rendered as the S-expression of its name with
-- @_\<index\>@ appended to the atom, so the index becomes part of the symbol.
instance ToSExpr DefaultVId where
  toSExpr (DefaultVId (ident, idx)) =
    modifyAtom (\atom -> fmt $ atom |+ "_" +| idx |+ "") (toSExpr ident)

----------------------------------------------------------------------------------------------------
-- sort rules and terms in an LCTRS
----------------------------------------------------------------------------------------------------

-- | Attach sort information to all parts of a parsed LCTRS.
--
-- * Theory function symbols, signature function symbols and the heads of
--   @define-fun@ entries become 'FId's carrying the sort annotation recorded
--   by the parser (via 'fIdParse').
-- * Rules are sorted term by term with 'sortRule'. The supplied
--   @sortToInfSort@ is applied per rule to replace recorded sorts by the
--   sorts found during inference.
-- * Values take their sort from the input's value-sort lookup; theory
--   symbols whose inferred annotation still contains polymorphic sorts are
--   completed with 'inferRemainingPSorts'.
--
-- Calls 'error' if a value has no sort, if a theory symbol has no signature,
-- or if its remaining polymorphic sorts cannot be determined.
sortLCTRS
  :: IntermediateInput
  -> (Rule IntermediateFId IntermediateVId -> Sort -> Sort)
  -> LCTRS (FId DefaultFId) (VId DefaultVId)
sortLCTRS input sortToInfSort =
  createLctrs
    theories
    (getTheorySorts lctrs)
    (map sortDefine (getDefines lctrs))
    (S.map plainFId (getTheoryFuns lctrs))
    (getSorts lctrs)
    (S.map plainFId (getFuns lctrs))
    (map (sortRule thSorts ruleFId ruleVId) (getRules lctrs))
 where
  lctrs = Ari.getLCTRS input
  theories = getTheories lctrs
  thSorts = getThSorts input
  isValue = Ari.getValueCheck input
  valueSort = Ari.getValueSort input

  -- turn a parsed symbol into an 'FId', keeping its sort annotation as is
  plainFId (name, ann) = fIdParse theories (defaultFId name) ann

  -- only the defined symbol itself is converted; parameters and body are
  -- passed through unchanged
  -- NOTE: sorts of function symbols inside define bodies are not inferred,
  --       as they are not needed there
  sortDefine ((name, _), ((params, result), body)) =
    ( plainFId (name, sortAnnotation (map snd params) result)
    , ((params, result), body)
    )

  -- final sort annotation of a function symbol occurring in a rule
  symbolSortAnn rule sym@(name, ann)
    | isValue sym =
        case valueSort name of
          Just srt -> sortAnnotation [] srt
          Nothing -> error $ "No sort information for value " ++ show name ++ " available."
    | name `M.member` thSorts && containsPolySort inferred =
        case M.lookup name thSorts of
          Nothing -> error $ "SortLCTRS.hs: sort annotation for symbol " ++ show name ++ " not found"
          Just candidates ->
            case inferRemainingPSorts inferred candidates of
              Just completed -> completed
              Nothing -> error $ "SortLCTRS.hs: remaining polymorphic sorts for symbol " ++ show name ++ " not typeable"
    | otherwise = inferred
   where
    inferred = sortToInfSort rule <$> ann

  ruleFId rule sym@(name, _) = fIdParse theories (defaultFId name) (symbolSortAnn rule sym)

  ruleVId rule (name, idx, srt) = vId (defaultVId (name, idx)) (sortToInfSort rule srt)

-- | Sort every term of a rule: left-hand side, right-hand side and guard.
--
-- The identifier translators receive the rule itself, so that sort
-- information specific to this rule can be used when translating the
-- function symbols and variables occurring in it.
sortRule
  :: M.Map Identifier [(SortAnnotation, Maybe Property, SortAnnotation -> Bool)]
  -> ( Rule IntermediateFId IntermediateVId
       -> IntermediateFId
       -> FId DefaultFId
     )
  -> ( Rule IntermediateFId IntermediateVId
       -> IntermediateVId
       -> VId DefaultVId
     )
  -> Rule IntermediateFId IntermediateVId
  -> Rule (FId DefaultFId) (VId DefaultVId)
sortRule theoryProperties tFId tVId rule =
  let sortInRule = sortTerm theoryProperties (tFId rule) (tVId rule)
      sortedLhs = sortInRule (lhs rule)
      sortedRhs = sortInRule (rhs rule)
      sortedGuard = mapGuard sortInRule (guard rule)
  in  createRule sortedLhs sortedRhs sortedGuard

-- | Convert a parsed term into a sorted term.
--
-- Variables and function symbols are translated with the given functions.
-- Applications of theory symbols are additionally expanded according to
-- their SMT-LIB attribute (left/right associativity, chainability, pairwise;
-- see also https://smt-lib.org/Theories/Core.smt2). This happens when the
-- parsed annotation is not an instance of any signature of the symbol, but
-- its leading argument sorts (with its result sort) match a signature that
-- carries such an attribute. In that case the application is rebuilt by
-- 'applyProperty' from applications of the signature's arity.
sortTerm
  :: M.Map Identifier [(SortAnnotation, Maybe Property, SortAnnotation -> Bool)]
  -> (IntermediateFId -> FId DefaultFId)
  -> (IntermediateVId -> VId DefaultVId)
  -> Term IntermediateFId IntermediateVId
  -> Term (FId DefaultFId) (VId DefaultVId)
sortTerm _ _ tVId (Var v) = var (tVId v)
sortTerm theoryProperties tFId tVId (Fun typ f@(ident, parsedSortAnnotation) args) =
  maybe asIs expand matchingSignature
 where
  -- signature (and its attribute) the application has to be expanded with
  matchingSignature =
    M.lookup ident theoryProperties
      >>= findOriginalMatchingSortAnnotation parsedSortAnnotation

  sortedFId = tFId f
  sortedArgs = map (sortTerm theoryProperties tFId tVId) args

  -- the application with sorted symbol and arguments, otherwise unchanged
  asIs = fun typ sortedFId sortedArgs

  -- rebuild the application according to the attribute of the signature,
  -- using the inferred annotation cut down to the signature's arity
  expand (originalSort, property) =
    let instantiated = instantiateOrigWithInferSortAnn (getFIdSort sortedFId) originalSort
        build = fun typ (updateFIdSort sortedFId instantiated)
    in  applyProperty property build (length (inSort originalSort)) sortedArgs

-- | Cut an inferred annotation down to the arity of an original signature.
--
-- Keeps as many leading argument sorts of @inferred@ as @original@ has
-- argument sorts, together with the result sort of @inferred@.
instantiateOrigWithInferSortAnn :: SortAnnotation -> SortAnnotation -> SortAnnotation
instantiateOrigWithInferSortAnn inferred original =
  let originalArity = length (inSort original)
      keptArgSorts = take originalArity (inSort inferred)
  in  sortAnnotation keptArgSorts (outSort inferred)

-- | Find the theory signature (and its attribute) that an application with
-- annotation @finalSA@ must be expanded with.
--
-- Returns 'Nothing' when @finalSA@ is already an instance of one of the
-- signatures, so no expansion is needed. Otherwise it considers the
-- signatures for which the first /k/ argument sorts of @finalSA@ (with its
-- result sort) form an instance, where /k/ is the signature's arity. It
-- picks the one with the largest arity, keeping list order among equal
-- arities, and returns it with its property. The result is 'Nothing' if no
-- signature matches or if the picked one has no property.
findOriginalMatchingSortAnnotation
  :: SortAnnotation
  -> [(SortAnnotation, Maybe Property, SortAnnotation -> Bool)]
  -> Maybe (SortAnnotation, Property)
findOriginalMatchingSortAnnotation finalSA origSAs
  | any (isInstanceSortAnnOf finalSA . annotationOf) origSAs = Nothing
  | otherwise =
      case sortBy byDescendingArity (filter (prefixMatches . annotationOf) origSAs) of
        (sa, Just prop, _) : _ -> Just (sa, prop)
        _ -> Nothing
 where
  annotationOf (sa, _, _) = sa

  arityOf sa = length (inSort sa)

  -- does the arity-sized prefix of finalSA instantiate the given signature?
  prefixMatches origSA =
    isInstanceSortAnnOf
      (sortAnnotation (take (arityOf origSA) (inSort finalSA)) (outSort finalSA))
      origSA

  byDescendingArity (sa1, _, _) (sa2, _, _) = compare (arityOf sa2) (arityOf sa1)

-- | Expand an n-ary application of a theory symbol according to its SMT-LIB
-- attribute. @constr@ builds a single application from a list of arguments;
-- @originalArity@ is the arity of the symbol's signature.
--
-- * 'LAssoc': the first @originalArity@ arguments form the innermost
--   application and the remaining ones are folded in from the left,
--   e.g. @(- a b c)@ becomes @(- (- a b) c)@ for arity 2.
-- * 'RAssoc': the last @originalArity@ arguments form the innermost
--   application and the preceding ones are folded in from the right,
--   e.g. @(=> a b c)@ becomes @(=> a (=> b c))@ for arity 2.
-- * 'Chain': every argument is related to its right neighbour and the
--   results are conjoined with 'T.conj', e.g. @(= a b c)@ yields
--   @(= a b)@ and @(= b c)@.
-- * 'Pairw': every argument is related to each argument after it and the
--   results are conjoined with 'T.conj', e.g. @(distinct a b c)@ yields
--   @(distinct a b)@, @(distinct a c)@ and @(distinct b c)@.
--
-- 'Chain' and 'Pairw' ignore @originalArity@ and call 'error' when given
-- fewer than two arguments.
applyProperty
  :: (Eq v, Eq f)
  => Property
  -> ([Term (FId f) v] -> Term (FId f) v)
  -> Int
  -> [Term (FId f) v]
  -> Term (FId f) v
applyProperty LAssoc constr originalArity args =
  let (initialArguments, rest) = splitAt originalArity args
  in  -- strict left fold: avoids building a chain of thunks over the arguments
      foldl' (\acc e -> constr [acc, e]) (constr initialArguments) rest
applyProperty RAssoc constr originalArity args =
  let (rest, initialArguments) = splitAt (length args - originalArity) args
  in  foldr (\e acc -> constr [e, acc]) (constr initialArguments) rest
applyProperty Chain _ _ [] =
  error "SortLCTRS.hs: no chainable construction possible (empty arguments)."
applyProperty Chain _ _ [_] =
  error "SortLCTRS.hs: no chainable construction possible (single argument)."
applyProperty Chain constr _ args@(_ : later) =
  -- neighbouring pairs (a1,a2), (a2,a3), ...; non-empty as there are >= 2 args
  foldr1 T.conj (zipWith (\l r -> constr [l, r]) args later)
applyProperty Pairw _ _ [] =
  error "SortLCTRS.hs: no pairwise construction possible (empty arguments)."
applyProperty Pairw _ _ [_] =
  error "SortLCTRS.hs: no pairwise construction possible (single argument)."
applyProperty Pairw constr _ args =
  -- all pairs (ai, aj) with i < j, ordered by i then j; non-empty as there
  -- are >= 2 args
  foldr1 T.conj [constr [l, r] | (l : later) <- tails args, r <- later]

----------------------------------------------------------------------------------------------------
-- infer the remaining sorts if there are still polymorphic sorts present after inference
----------------------------------------------------------------------------------------------------

-- | Instantiate the polymorphic sorts that remain in an inferred sort
-- annotation by matching it against the candidate signatures of a theory
-- symbol.
--
-- Only candidates with the same number of argument sorts, and whose check
-- function accepts the inferred annotation, are considered. The first such
-- candidate (in list order) for which matching succeeds determines the
-- result; 'Nothing' is returned if none matches.
--
-- Matching runs from the result sort to the argument sorts, left to right.
-- It keeps one binding per polymorphic sort, so repeated occurrences must be
-- instantiated consistently.
-- NOTE(review): polymorphic sorts of the inferred annotation and of the
-- candidate share one binding map; this assumes their names do not clash —
-- confirm.
inferRemainingPSorts
  :: SortAnnotation -> [(SortAnnotation, Maybe Property, SortAnnotation -> Bool)] -> Maybe SortAnnotation
inferRemainingPSorts sa sas = foldr ((<|>) . matchWith) Nothing candidates
 where
  arityOf ann = length (inSort ann)

  candidates =
    [cand | (cand, _, check) <- sas, arityOf cand == arityOf sa, check sa]

  -- match the inferred annotation against a single candidate signature
  matchWith :: SortAnnotation -> Maybe SortAnnotation
  matchWith cand = flip evalStateT M.empty $ do
    oS <- inferSort (outSort sa) (outSort cand)
    iS <- zipWithM inferSort (inSort sa) (inSort cand)
    return $ sortAnnotation iS oS

  inferSort :: Sort -> Sort -> StateT (M.Map Sort Sort) Maybe Sort
  inferSort (LitSort s) (LitSort s')
    | s == s' = lift $ Just (LitSort s)
    | otherwise = lift Nothing
  inferSort s@(PSort _) s'@(LitSort _) = sortOfPSort s s'
  inferSort s@(LitSort _) s'@(PSort _) = sortOfPSort s' s
  inferSort s@(PSort _) s'@(AttrSort _ (AttrInt _)) = sortOfPSort s s'
  inferSort s@(AttrSort _ (AttrInt _)) s'@(PSort _) = sortOfPSort s' s
  inferSort s@(AttrSort n (AttrPol _)) s'@(AttrSort n' (AttrInt _)) | n == n' = sortOfPSort s s'
  inferSort s@(AttrSort n (AttrInt _)) s'@(AttrSort n' (AttrPol _)) | n == n' = sortOfPSort s' s
  -- identical concrete attributed sorts are compatible, just like identical
  -- 'LitSort's above; previously this fell through and rejected the candidate
  inferSort s@(AttrSort _ (AttrInt _)) s'@(AttrSort _ (AttrInt _))
    | s == s' = lift $ Just s
  inferSort _ _ = lift Nothing

  -- bind a polymorphic sort to a sort, or check it against an earlier binding
  sortOfPSort :: Sort -> Sort -> StateT (M.Map Sort Sort) Maybe Sort
  sortOfPSort ps s = do
    m <- get
    case M.lookup ps m of
      Nothing -> do
        modify $ M.insert ps s
        lift $ Just s
      Just s' -> if s == s' then lift $ Just s' else lift Nothing
