{- |
Module      : Type.TypeInference
Description : Sort inference for the rules of an LCTRS using a union-find structure
Copyright   : (c) Jonas Schöpf, 2023
License     : GPL-3
Maintainer  : jonas.schoepf@uibk.ac.at
Stability   : stable


This module provides functions to perform type inference on LCTRSs and infer
remaining sorts.
-}
module Type.TypeInference (
  deriveTypes,
) where

----------------------------------------------------------------------------------------------------
-- imports
----------------------------------------------------------------------------------------------------

import Control.Monad.State.Strict (
  StateT,
  get,
  gets,
  modify,
  runStateT,
 )
import Control.Monad.Trans (lift)
import Control.Monad.Union
import Data.LCTRS.FIdentifier
import Data.LCTRS.Guard
import Data.LCTRS.LCTRS
import Data.LCTRS.Rule
import Data.LCTRS.Sort
import Data.LCTRS.Term
import Data.LCTRS.VIdentifier
import qualified Data.Map.Strict as M
import Data.Maybe (fromJust, fromMaybe)
import Data.SMT (Identifier, boolSort)
import qualified Data.Union as DU
import qualified Parser.Ari as Ari
import Type.SortLCTRS (
  DefaultFId,
  DefaultVId,
  IntermediateFId,
  IntermediateInput,
  IntermediateVId,
  sortLCTRS,
 )

----------------------------------------------------------------------------------------------------
-- types
----------------------------------------------------------------------------------------------------

-- | Monad in which sort inference runs: a strict 'StateT' layer carrying the
-- 'TIState' bookkeeping on top of a 'UnionM' whose node labels are 'Sort's.
type EquivM = StateT TIState (UnionM Sort)

-- | Bookkeeping state threaded through the inference of a single rule.
data TIState = TIState
  { sortNode :: M.Map Sort Node
  -- ^ Union-find node allocated for each sort seen so far; extended by
  -- 'nodeOfSort' whenever a sort is encountered for the first time.
  , vSort :: M.Map IntermediateVId Sort
  -- ^ Sort recorded for each variable. Seeded from the rule's variable
  -- annotations in 'typeInference' and extended by 'nodeOfVId'.
  , valSort :: Identifier -> Maybe Sort
  -- ^ Lookup taken from @Ari.getValueSort@; consulted by 'nodesOfFId'.
  -- NOTE(review): presumably yields the sort of a value literal - confirm
  -- against "Parser.Ari".
  , valCheck :: IntermediateFId -> Bool
  -- ^ Predicate taken from @Ari.getValueCheck@; 'nodesOfFId' treats a symbol
  -- as an argument-less value only if this predicate accepts it.
  }

-- | Execute an inference computation starting from the given state.
--
-- Returns the resulting union-find structure paired with the computation's
-- result and the final 'TIState'.
evalEquivM :: TIState -> EquivM a -> (Union Sort, (a, TIState))
evalEquivM initial action = run' (runStateT action initial)

----------------------------------------------------------------------------------------------------
-- core functionality
----------------------------------------------------------------------------------------------------

-- | Produce the sorted LCTRS for an input: the per-rule sort mapping computed
-- by 'typeInference' is handed to 'sortLCTRS' together with the input itself.
deriveTypes
  :: IntermediateInput
  -> LCTRS (FId DefaultFId) (VId DefaultVId)
deriveTypes input = sortLCTRS input inferredSortOf
 where
  inferredSortOf = typeInference input

-- | Run sort inference for every rule of the input and return a function that
-- maps a rule and one of its sorts to the inferred representative sort.
--
-- For each rule a fresh union-find structure is built: the rule's variable
-- annotations seed the state, then 'equationsRule' merges the sort classes
-- induced by the left-hand side, right-hand side and guard.
--
-- The table of per-rule results depends only on the input, so it is bound once
-- per application to @inp@ and shared by all subsequent rule lookups instead
-- of being rebuilt for every queried rule.
--
-- The returned function is partial: it fails for a rule that is not among the
-- rules of @Ari.getLCTRS inp@, and calls 'error' for a sort that received no
-- node while processing the rule.
typeInference
  :: IntermediateInput
  -> (Rule IntermediateFId IntermediateVId -> Sort -> Sort)
typeInference inp = sortsOfRule
 where
  -- Only rules of the input are present in the table; see the note above.
  sortsOfRule rule = fromJust $ M.lookup rule inferred

  -- One entry per rule of the input, computed independently of each other.
  inferred = M.fromList . map inferRule . getRules $ Ari.getLCTRS inp

  inferRule rule = (rule, snd . DU.lookup unionFind . nodeFor)
   where
    (unionFind, (_, finalState)) = evalEquivM initialState $ do
      modify $ \st -> st{vSort = varAnn rule}
      equationsRule (lhs rule) (rhs rule) (guard rule)
    nodeFor s = case M.lookup s (sortNode finalState) of
      Nothing -> error $ "No node for sort " ++ show s ++ "."
      Just n -> n

  -- Every rule starts with no sort nodes and no variable sorts.
  initialState =
    TIState
      M.empty
      M.empty
      (Ari.getValueSort inp)
      (Ari.getValueCheck inp)

  -- Variable annotations of a rule; empty when the input declares none.
  varAnn rule =
    fromMaybe M.empty $
      M.lookup rule $
        Ari.varSort inp

-- | Generate and solve the sort equations of one rule.
--
-- The root sorts of both sides are merged, each side is traversed against its
-- own root node, and every term of the guard is merged with the node of
-- 'boolSort'.
equationsRule
  :: Term IntermediateFId IntermediateVId
  -> Term IntermediateFId IntermediateVId
  -> Guard IntermediateFId IntermediateVId
  -> EquivM ()
equationsRule t1 t2 g = do
  lhsRoot <- rootNodeOfTerm t1
  rhsRoot <- rootNodeOfTerm t2
  _ <- mergeN lhsRoot rhsRoot
  -- Left-hand side first, then right-hand side.
  mapM_ (uncurry equationsTerm) [(lhsRoot, t1), (rhsRoot, t2)]
  boolN <- nodeOfSort boolSort
  mapM_ (equationsTerm boolN) (termsGuard g)

-- | Merge the sort class of a term with the given node and recurse into its
-- arguments.
--
-- A variable contributes the node obtained from 'nodeOfVId'. A function
-- application contributes its output node after each argument has been
-- processed against the corresponding input node. The result is whatever the
-- final 'mergeN' returns.
equationsTerm :: Node -> Term IntermediateFId IntermediateVId -> EquivM (Maybe ())
equationsTerm prevN term = case term of
  Var v -> nodeOfVId v >>= mergeN prevN
  Fun _ f args -> do
    (argNs, resultN) <- nodesOfFId f
    -- 'zip' stops at the shorter list, so surplus arguments or surplus input
    -- nodes are silently ignored.
    mapM_ (\(argN, arg) -> equationsTerm argN arg) (zip argNs args)
    mergeN prevN resultN

----------------------------------------------------------------------------------------------------
-- union find functions
----------------------------------------------------------------------------------------------------

-- | Node representing a sort. A sort seen for the first time gets a freshly
-- created node labelled with that sort, which is remembered in 'sortNode';
-- later requests return the remembered node.
nodeOfSort :: Sort -> EquivM Node
nodeOfSort s = gets (M.lookup s . sortNode) >>= maybe register return
 where
  register = do
    fresh <- new s
    modify $ \st -> st{sortNode = M.insert s fresh (sortNode st)}
    return fresh

-- | Node representing the sort of a variable.
--
-- If a sort is already recorded for the variable in 'vSort', its node is merged
-- with the node of the sort carried by the identifier and the recorded sort's
-- node is returned. Otherwise the carried sort is recorded and its node
-- returned.
nodeOfVId :: IntermediateVId -> EquivM Node
nodeOfVId v@(_, _, carried) = do
  recorded <- gets (M.lookup v . vSort)
  case recorded of
    Nothing -> do
      modify $ \st -> st{vSort = M.insert v carried (vSort st)}
      nodeOfSort carried
    Just known -> do
      knownN <- nodeOfSort known
      carriedN <- nodeOfSort carried
      _ <- mergeN knownN carriedN
      return knownN

-- | Input nodes and output node of a function symbol.
--
-- When 'valSort' yields a sort for the identifier and 'valCheck' accepts the
-- identifier with an argument-less annotation of that sort, the symbol has no
-- input nodes and the node of that sort as output. In every other case the
-- nodes are taken from the symbol's own sort annotation.
nodesOfFId :: IntermediateFId -> EquivM ([Node], Node)
nodesOfFId (ident, ann) = do
  valueSortOf <- gets valSort
  isValue <- gets valCheck
  case valueSortOf ident of
    Just vs | isValue (ident, sortAnnotation [] vs) -> do
      outN <- nodeOfSort vs
      return ([], outN)
    _ -> do
      inNs <- mapM nodeOfSort (inSort ann)
      outN <- nodeOfSort (outSort ann)
      return (inNs, outN)

-- | Node of the sort at the root of a term: the variable's node for a
-- variable, the output node of the head symbol for a function application.
rootNodeOfTerm :: Term IntermediateFId IntermediateVId -> EquivM Node
rootNodeOfTerm term = case term of
  Var v -> nodeOfVId v
  Fun _ f _ -> do
    (_, outN) <- nodesOfFId f
    return outN

-- | Merge the classes of two nodes, combining their sort labels.
--
-- Label combination, tried in this order:
--
-- * equal sorts keep that sort;
-- * two 'AttrSort's with the same name, one of which carries an 'AttrPol'
--   attribute, keep the other one (the first rule that matches wins);
-- * a 'PSort' gives way to the other sort;
-- * anything else calls 'error'.
mergeN :: Node -> Node -> EquivM (Maybe ())
mergeN n1 n2 = lift (merge checkSort n1 n2)
 where
  checkSort :: Sort -> Sort -> (Sort, ())
  checkSort a b
    | a == b = keep a
    | otherwise = case (a, b) of
        (AttrSort n (AttrPol _), AttrSort n' _) | n == n' -> keep b
        (AttrSort n _, AttrSort n' (AttrPol _)) | n == n' -> keep a
        (PSort _, _) -> keep b
        (_, PSort _) -> keep a
        _ -> error $ "type inference: failed. " ++ show a ++ "    " ++ show b

  keep s = (s, ())
