module LPO_Encoder(
  lpoTrsEncoder, 
  runTrsEncoder,
  PrecMapList
  ) where

import SMT
import TRS
import Abstract_SMT_Encoder

import Control.Monad.Reader
import Control.Monad.State
import qualified Data.Map as M

-- | Read-only configuration of the LPO encoder.
data LpoConfig = LpoConfig
  { nrSymbols :: Int
    -- ^ size of the signature; used as the upper bound for precedence values
  }
  
-- | Maps each function symbol to the SMT integer variable that encodes its
-- precedence.
type PrecMap = M.Map Id SmtVar
-- | Association-list form of 'PrecMap', as returned by 'runTrsEncoder'.
type PrecMapList = [(Id, SmtVar)]

-- | Mutable state threaded through the encoder.
data LPOState = LPOState
  { precMap :: PrecMap
    -- ^ precedence variable already allocated for each symbol seen so far
  , lpoMap  :: M.Map (Term, Term) Formula
    -- ^ memo table: the formula (or abbreviating Boolean variable) already
    -- produced for each compared pair of terms
  }

-- | Encoder monad: mutable 'LPOState' over read-only 'LpoConfig', on top of
-- the underlying 'SmtEncoder' (whose actions are reached via 'lift2').
type LPOEncoder = StateT LPOState (ReaderT LpoConfig SmtEncoder)

-- | Run an 'SmtEncoder' action inside 'LPOEncoder' by lifting it through both
-- the 'StateT' and the 'ReaderT' layer.
lift2 :: SmtEncoder a -> LPOEncoder a
lift2 = lift . lift

-- | Return the SMT integer expression standing for the precedence of symbol
-- @f@.
--
-- On first use a fresh integer variable is allocated, asserted to lie in
-- @[1, nrSymbols]@, and recorded in 'precMap' so later calls reuse it.
getPrec :: Id -> LPOEncoder IAExpr
getPrec f = gets (M.lookup f . precMap) >>= maybe allocate (return . IAVar)
  where
    allocate = do
      v <- lift2 $ getNewSmtVariable SmtInt
      upper <- asks nrSymbols
      let bounds = [Le (IAConst 1) (IAVar v), Le (IAVar v) (IAConst upper)]
      lift2 . assertFormula $ conj bounds
      modify $ \st -> st { precMap = M.insert f v (precMap st) }
      return (IAVar v)

-- | Encode the strict comparison @s >_lpo t@ for the pair @(s, t)@, memoised
-- in 'lpoMap'.
--
-- If the encoding produced by 'lpo2' is not atomic, a fresh Boolean variable
-- @b@ is allocated, @b <=> phi@ is asserted, and @b@ is used in place of
-- @phi@. This keeps shared sub-comparisons from being duplicated in the output.
lpoEncoder :: (Term, Term) -> LPOEncoder Formula
lpoEncoder goal = gets (M.lookup goal . lpoMap) >>= maybe encode return
  where
    encode = do
      phi <- lpo2 goal
      result <- if atomic phi then return phi else abbreviate phi
      -- Read the map again: the recursive calls in 'lpo2' may have added
      -- entries since the lookup above.
      modify $ \s -> s { lpoMap = M.insert goal result (lpoMap s) }
      return result
    abbreviate phi = do
      b <- lift2 $ getNewSmtVariable SmtBool
      lift2 . assertFormula $ Equiv (BoolVar b) phi
      return (BoolVar b)

-- | Encode the non-strict comparison @s >=_lpo t@.
--
-- Syntactically equal terms yield 'true'. Otherwise this defers to the
-- memoised strict encoding 'lpoEncoder'.
lpoEq :: (Term, Term) -> LPOEncoder Formula
lpoEq (s, t)
  | s == t = return true
  | otherwise = lpoEncoder (s, t)

-- | Encode @s >_lpo t@ for the pair @(s, t)@ by one unfolding of the LPO
-- definition. This function itself is not memoised; sub-comparisons go
-- through 'lpoEncoder' / 'lpoEq', which are.
--
-- * A variable is never greater than anything.
-- * @f(ss) > x@ iff some argument @si >= x@.
-- * @f(ss) > g(ts)@ with @f /= g@ iff some @si >= t@, or @prec f > prec g@ and
--   @s > tj@ for every @tj@.
-- * @f(ss) > f(ts)@ iff some @si >= t@, or, at the first position @i@ where
--   the arguments differ, @si > ti@ and @s > tj@ for every later @tj@.
lpo2 :: (Term, Term) -> LPOEncoder Formula
lpo2 (Var _, _) = return false
lpo2 (Fun _ ss, t@(Var _)) = disj <$> mapM (\si -> lpoEq (si, t)) ss
lpo2 (s@(Fun f ss), t@(Fun g ts))
  | f /= g = do
      phi <- subterm
      pf <- getPrec f
      pg <- getPrec g
      psi <- mapM (\tj -> lpoEncoder (s, tj)) ts
      return $ disj (conj (Gt pf pg : psi) : phi)
  | s == t = return false
  | otherwise = case dropWhile (uncurry (==)) (zip ss ts) of
      (si_ti : rest) -> do
        phi <- subterm
        psi <- lpoEncoder si_ti
        chi <- mapM (\(_, tj) -> lpoEncoder (s, tj)) rest
        return $ disj (conj (psi : chi) : phi)
      -- All zipped argument pairs are equal although s /= t. This can only
      -- happen when the same symbol occurs with different arities. Previously
      -- this was an incomplete pattern match (runtime crash); the subterm
      -- alternative alone is still a sound encoding.
      [] -> disj <$> subterm
  where
    -- Alternatives "some argument si of s satisfies si >= t".
    subterm = mapM (\si -> lpoEq (si, t)) ss

-- | Assert that a rule is strictly oriented by the LPO, i.e. that its first
-- component is LPO-greater than its second (presumably lhs > rhs).
ruleEncoder :: Rule -> LPOEncoder ()
ruleEncoder rule = do
  oriented <- lpoEncoder rule
  lift2 (assertFormula oriented)

-- | Assert the LPO orientation of every rule, in list order.
trsEncoder :: [Rule] -> LPOEncoder ()
trsEncoder rules = mapM_ ruleEncoder rules

-- | Encode LPO termination of a TRS and run the SMT encoder.
--
-- Returns the precedence variables allocated per symbol together with the
-- 'String' produced by 'runSmtEncoder'.
lpoTrsEncoder :: TRS -> (PrecMapList, String)
lpoTrsEncoder trs = runSmtEncoder (runTrsEncoder trs)

-- | Emit the LPO constraints for all rules of a TRS into the 'SmtEncoder'.
--
-- Returns the precedence variable of every symbol whose precedence was
-- consulted during encoding. Precedence values are bounded by the size of
-- the signature.
runTrsEncoder :: TRS -> SmtEncoder PrecMapList
runTrsEncoder (sig, rules) =
  M.toList . precMap
    <$> runReaderT (execStateT (trsEncoder rules) initial) config
  where
    initial = LPOState { precMap = M.empty, lpoMap = M.empty }
    config = LpoConfig { nrSymbols = length sig }
