{-# OPTIONS_GHC -Wno-incomplete-patterns #-}
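-- FIXME: the surrounding pragmas switch off GHC's incomplete-pattern warnings
-- for this whole module; make all pattern matches total and remove them.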
{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-}

{- |
Module      : Unification
Description : Solving matching and unification problems on terms
Copyright   : (c) Jonas Schöpf, 2023
License     : GPL-3
Maintainer  : jonas.schoepf@uibk.ac.at
Stability   : stable


This module provides functions to compute a substitution that solves a matching
or unification problem. It is inspired by the term-rewriting package and
re-uses code from it.
-}
module Rewriting.Unification (
  solveMatchingProblem,
  solveUnificationProblem,
  UnifPTuple (..),
  UnifProblem (..),
) where

----------------------------------------------------------------------------------------------------
-- imports
----------------------------------------------------------------------------------------------------

import Control.Monad (guard, zipWithM)
import Control.Monad.ST
import Control.Monad.State
import Control.Monad.Union
import qualified Control.Monad.Union as UM
import Data.Array
import Data.Array.ST
import Data.Bitraversable (bimapM)
import Data.Containers.ListUtils (nubOrd)
import Data.LCTRS.Term (Term)
import qualified Data.LCTRS.Term as T
import qualified Data.Map as M
import Data.Maybe (catMaybes, fromJust, isJust, maybeToList)
import qualified Data.Union as DU
import qualified Data.Union as U
import Data.Word (Word8)
import Rewriting.Substitution (Subst, apply, composeM, fromMap)

----------------------------------------------------------------------------------------------------
-- types
----------------------------------------------------------------------------------------------------

-- | Monad used to solve matching problems: a state remembering the
-- union-find node of every term seen so far, on top of a union-find
-- computation whose labels are terms.
type UnifM f v = StateT (UnifSt f v) (UnionM (Term f v))

-- | State threaded through 'UnifM'.
newtype UnifSt f v = UnifSt
  { termNode :: M.Map (Term f v) Node
  -- ^ the union-find node that was created for each term (see 'nodeOfTerm')
  }

-- | One equation of a matching or unification problem, given as a pair
-- (left-hand side, right-hand side). In the matching solver only variables
-- on the left-hand side get bound.
newtype UnifPTuple f v = UnifPTuple {unifPTuple :: (Term f v, Term f v)}

-- | A matching or unification problem: a list of equations that are solved
-- together in a single run.
newtype UnifProblem f v = UnifProblem
  { unifProblem :: [UnifPTuple f v]
  }

-- | Initial solver state: no term has been assigned a node yet.
emptyUnifSt :: UnifSt f v
emptyUnifSt = UnifSt{termNode = M.empty}

-- | Run a 'UnifM' computation from an initial state. Returns the final
-- union-find structure, paired with the computation's result and final state.
evalUnifM :: UnifSt f v -> UnifM f v a -> (Union (Term f v), (a, UnifSt f v))
evalUnifM st0 = run' . (`runStateT` st0)

----------------------------------------------------------------------------------------------------
-- solve matching problem
----------------------------------------------------------------------------------------------------

-- | Solve a matching problem using the union-find based 'solveProblem'.
-- Yields 'Nothing' when some equation of the problem cannot be solved.
solveMatchingProblem
  :: (Ord v, Ord f) => UnifProblem f v -> Maybe (Subst f v)
solveMatchingProblem problem = solveProblem problem

-- | Syntactic matching of a pattern against a subject term. It is not
-- exported and is not referenced in this module.
--
-- Returns the matching substitution, or 'Nothing' if the subject is not an
-- instance of the pattern. A variable that occurs several times in the
-- pattern must be bound to equal subterms. Function symbols must agree in
-- name, number of arguments and 'T.FunType'.
_match :: (Eq f, Ord v) => Term f v -> Term f v -> Maybe (Subst f v)
_match pat subj = fromMap <$> step pat subj M.empty
 where
  step (T.Var x) s bound =
    case M.lookup x bound of
      Nothing -> Just (M.insert x s bound)
      Just s' -> if s == s' then Just bound else Nothing
  step (T.Fun ty f ps) (T.Fun ty' g ss) bound = do
    -- same symbol, same arity and same symbol type, otherwise no match
    guard (f == g && length ps == length ss && ty == ty')
    composeM (zipWith step ps ss) bound
  step _ _ _ = Nothing

-- | Solve a matching problem with the union-find based 'solveProblemTerm'.
--
-- All equations are processed in order within a single 'UnifM' run. If any
-- of them reports failure ('Just False'), the whole problem fails with
-- 'Nothing'. Otherwise each variable occurring in the problem is mapped to
-- the label of the representative of its union-find class. Two kinds of
-- variable are left out: those that never received a node, and those whose
-- label is the variable itself.
solveProblem
  :: (Ord v, Ord f) => UnifProblem f v -> Maybe (Subst f v)
solveProblem UnifProblem{..}
  | not (and (catMaybes succeeded)) = Nothing
  | otherwise = Just (fromMap (M.fromList bindings))
 where
  -- run every equation in one union-find computation
  (uf, (succeeded, finalSt)) = evalUnifM emptyUnifSt (mapM solveTuple unifProblem)
  solveTuple UnifPTuple{..} = uncurry solveProblemTerm unifPTuple

  -- all variables of both sides of all equations, without duplicates
  allVars = nubOrd (concatMap tupleVars unifProblem)
  tupleVars UnifPTuple{..} =
    let (lhs, rhs) = unifPTuple in T.vars lhs ++ T.vars rhs

  -- label of the class representative of a term, if the term has a node
  termBinding term = snd . DU.lookup uf <$> M.lookup term (termNode finalSt)

  -- pattern matching on 'Just' replaces the former isJust/fromJust pair
  bindings =
    [ (x, bound)
    | x <- allVars
    , let xTerm = T.var x
    , Just bound <- [termBinding xTerm]
    , bound /= xTerm
    ]

-- | Process one equation @lhs =? rhs@ of a matching problem.
--
-- Result:
--
-- * @Just True@: the sides are syntactically equal, or the equation was
--   recorded by merging union-find nodes.
-- * @Just False@: the equation cannot be solved.
-- * @Nothing@: 'merge' found both nodes already in the same class.
--
-- Only variables on the left-hand side get bound. A variable is not bound to
-- a term containing it.
solveProblemTerm
  :: (Ord f, Ord v)
  => Term f v
  -> Term f v
  -> UnifM f v (Maybe Bool)
solveProblemTerm term term1 | term == term1 = return $ Just True
solveProblemTerm v@(T.Var n) term | n `notElem` T.vars term = do
  n1 <- nodeOfTerm v
  n2 <- nodeOfTerm term
  -- Label the merged class with the term the variable is bound to:
  -- 'solveProblem' reads a variable's binding from this label. (Labelling it
  -- with the variable itself made every binding an identity and thus empty.)
  -- NOTE(review): the existing labels are ignored, so a variable that was
  -- already bound is silently re-bound; non-linear patterns are not checked.
  lift $ merge (\_ _ -> (term, True)) n1 n2
solveProblemTerm (T.Fun typF f fargs) (T.Fun typG g gargs)
  -- compare arities explicitly: zipWithM would silently truncate
  | f == g && typF == typG && length fargs == length gargs =
      Just . and . catMaybes <$> zipWithM solveProblemTerm fargs gargs
solveProblemTerm _ _ = return $ Just False

-- | Return the union-find node of a term. On first use, a fresh node labelled
-- with the term itself is created and recorded in the state.
nodeOfTerm :: (Ord v, Ord f) => Term f v -> UnifM f v Node
nodeOfTerm term = gets (M.lookup term . termNode) >>= maybe freshNode return
 where
  freshNode = do
    node <- new term
    modify $ \st -> st{termNode = M.insert term node (termNode st)}
    return node

----------------------------------------------------------------------------------------------------
-- solve unification problem
----------------------------------------------------------------------------------------------------

-- | Compute a unifier for all equations of the problem with 'unify'.
-- Yields 'Nothing' when unification fails.
solveUnificationProblem
  :: (Ord v, Ord f) => UnifProblem f v -> Maybe (Subst f v)
solveUnificationProblem problem = unify problem

-- | Monad used for unification: a state mapping each variable to its
-- union-find node, on top of a union-find computation labelled with 'Annot'.
type UnifyM f v a = StateT (M.Map v U.Node) (UM.UnionM (Annot f v)) a

-- | Label of a union-find node during unification:
--
-- * 'VarA': a variable.
-- * 'FunA': a function application whose arguments already are nodes.
-- * 'FunP': a function application whose arguments are still plain terms
--   ('expand' turns it into a 'FunA').
data Annot f v = VarA v | FunA T.FunType f [U.Node] | FunP T.FunType f [Term f v]

-- | Symbol and number of arguments of a function annotation.
-- Calls 'error' when applied to a variable annotation.
funari :: Annot f v -> (f, Int)
funari ann = case ann of
  FunA _ sym args -> (sym, length args)
  FunP _ sym args -> (sym, length args)
  VarA _ -> error "Unification.hs: no arity for variable."

-- | Solve a work list of equations between union-find nodes.
--
-- Each pair is resolved on the representatives of its two nodes:
--
-- * Equal representatives: the pair is dropped.
-- * A variable class: it is merged into the other class.
-- * Two function classes with the same symbol, arity and 'T.FunType': both
--   are expanded and merged, and their argument nodes are queued pairwise.
--
-- Any other combination is a clash and yields 'False'.
solve :: (Eq f, Ord v) => [(U.Node, U.Node)] -> UnifyM f v Bool
solve [] = return True
solve ((t0, u0) : xs) = do
  (t, t') <- UM.lookup t0
  (u, u') <- UM.lookup u0
  if t == u
    then solve xs
    else case (t', u') of
      (VarA _, _) -> do
        _ <- UM.merge (\_ _ -> (u', ())) t u
        solve xs
      (_, VarA _) -> do
        _ <- UM.merge (\_ _ -> (t', ())) t u
        solve xs
      _
        -- the symbol type must agree as well, as in the matching code
        | funari t' == funari u' && funType t' == funType u' -> do
            ts <- argNodes <$> expand t t'
            us <- argNodes <$> expand u u'
            _ <- UM.merge (\l _ -> (l, ())) t u
            solve (zip ts us ++ xs)
        | otherwise -> return False
 where
  -- symbol type of a function annotation (variables are handled above)
  funType (FunA ty _ _) = Just ty
  funType (FunP ty _ _) = Just ty
  funType (VarA _) = Nothing
  -- argument nodes of an expanded annotation; 'expand' returns a 'FunA' for
  -- every function annotation, so the fallback is unreachable but total
  argNodes (FunA _ _ ns) = ns
  argNodes _ = []

-- | Turn a 'FunP' label into a 'FunA' label by creating nodes for its
-- arguments. The node is re-annotated with the new label, which is also
-- returned. Any other label is returned unchanged.
expand :: (Ord v) => U.Node -> Annot f v -> UnifyM f v (Annot f v)
expand node ann = case ann of
  FunP ty sym args -> do
    expanded <- FunA ty sym <$> traverse mkNode args
    UM.annotate node expanded
    pure expanded
  _ -> pure ann

-- | Create the union-find node for a term. A variable always gets the same
-- node, cached in the state map. A function application always gets a fresh
-- node with a 'FunP' label, whose arguments are expanded lazily by 'expand'.
mkNode :: (Ord v) => Term f v -> UnifyM f v U.Node
mkNode (T.Var x) = gets (M.lookup x) >>= maybe fresh return
 where
  fresh = do
    node <- UM.new (VarA x)
    modify (M.insert x node)
    return node
mkNode (T.Fun ty sym args) = UM.new (FunP ty sym args)

-- | Compute a unifier for all equations of a unification problem.
--
-- Both sides of every equation are turned into union-find nodes, and all
-- pairs are solved together by 'solve'. The result is 'Nothing' in two cases:
-- 'solve' fails, or a cycle is reachable from one of the left-hand nodes
-- (the occurs check, see 'acyclic'). Otherwise each variable is mapped to the
-- term rebuilt from the representative of its class.
unify :: (Eq f, Ord v) => UnifProblem f v -> Maybe (Subst f v)
unify UnifProblem{..} = do
  let
    -- create nodes for both sides of every equation and solve them at once;
    -- the left-hand nodes are returned as roots for the cycle check
    act = do
      mkNodes <- mapM (bimapM mkNode mkNode . unifPTuple) unifProblem
      success <- solve mkNodes
      return (map fst mkNodes, success)
    (union, ((roots, success), vmap)) = UM.run' $ runStateT act M.empty
    -- successors in the term graph, read from the class representative's
    -- label: argument nodes of an expanded application, or the nodes of the
    -- variables inside a not yet expanded one
    succs n = case snd (U.lookup union n) of
      VarA _ -> []
      FunA _ _ ns -> ns
      FunP _ _ ts -> do v <- T.vars =<< ts; maybeToList (M.lookup v vmap)
  guard $ success && all (acyclic (U.size union) succs) roots
  let
    -- 'subst', 'terms', 'lookupNode' and 'mkTerm' refer to each other; this
    -- knot relies on laziness.
    -- NOTE(review): termination presumably relies on the acyclicity check above
    subst = fromMap $ fmap lookupNode vmap
    terms = fmap mkTerm (UM.label union)
    -- the term of a node is the term built from its class representative
    lookupNode = (terms !) . U.fromNode . fst . U.lookup union
    mkTerm (VarA v) = T.Var v
    mkTerm (FunA t f ns) = T.Fun t f (fmap lookupNode ns)
    mkTerm (FunP t f ts) = subst `apply` T.Fun t f ts
  return subst

-- | Check that the graph reachable from @root@ contains no cycle.
--
-- Arguments:
--
-- * @size@: the number of union-find nodes; node indices are assumed to lie
--   in @[0, size - 1]@.
-- * @succs@: the successors of a node.
--
-- The check is a three-colour depth-first search with the colours
-- 0 = unvisited, 1 = on the current path and 2 = finished. Reaching a node
-- coloured 1 again means there is a cycle.
acyclic :: Int -> (U.Node -> [U.Node]) -> U.Node -> Bool
acyclic size succs root = runST $ do
  color <- newColors
  let dfs n = do
        let i = U.fromNode n
        c <- readArray color i
        case c of
          0 -> do
            writeArray color i 1
            -- visit successors left to right, stopping at the first cycle;
            -- only when all of them succeed is the node marked finished
            flip (foldr andM) (map dfs (succs n)) $ do
              writeArray color i 2
              return True
          1 -> return False
          -- only 0, 1 and 2 are ever written, so this is colour 2 (finished)
          _ -> return True
  dfs root
 where
  -- colour array with every node unvisited; the signature fixes the unboxed
  -- array type without the former 'undefined'/'asTypeOf' trick
  newColors :: ST s (STUArray s Int Word8)
  newColors = newArray (0, size - 1) 0

-- | Short-circuiting monadic conjunction. The second action runs only when
-- the first one yields 'True'.
andM :: (Monad m) => m Bool -> m Bool -> m Bool
andM first second = first >>= \ok -> if ok then second else return False
