{- |
Module      : Rewriting.Substitution
Description : Substitutions on terms: matching and application
Copyright   : (c) Jonas Schöpf, 2024
License     : GPL-3
Maintainer  : jonas.schoepf@uibk.ac.at
Stability   : stable


Module providing functions for substitutions.
This module is inspired by the term-rewriting package.
-}
module Rewriting.Substitution where

import Control.Monad ((>=>))
import Data.LCTRS.Term (Term (..), pattern Val)
import qualified Data.LCTRS.Term as T
import qualified Data.Map as M
import Data.Maybe (isJust)
import Prettyprinter (Doc, Pretty, encloseSep, pretty, (<+>))

-- | A substitution, represented as a finite map from variables to terms.
newtype Subst f v = Subst
  { subst :: M.Map v (Term f v)
  -- ^ the variable-to-term bindings of the substitution
  }

-- | Render a substitution as @{x |-> t; y |-> u}@, listing the bindings in
-- ascending order of their variables.
prettySubstitution :: (Pretty f, Pretty v) => Subst f v -> Doc ann
prettySubstitution sigma = encloseSep "{" "}" "; " bindingDocs
 where
  bindingDocs =
    [ pretty x <+> "|->" <+> pretty image
    | (x, image) <- M.toList (subst sigma)
    ]

-- | Extract the underlying map of variable bindings from a substitution.
toMap :: Subst f v -> M.Map v (Term f v)
toMap (Subst bindings) = bindings

-- | Turn a map of variable bindings into a substitution.
fromMap :: M.Map v (Term f v) -> Subst f v
fromMap bindings = Subst {subst = bindings}

-- | Syntactically match a pattern against a subject term.
--
-- @match pat term@ walks both terms in lockstep. Every variable of @pat@ is
-- bound to the corresponding subterm of @term@; values and function symbols
-- (including their type annotations and arities) must coincide exactly.
-- A variable occurring several times in @pat@ must be bound to equal
-- subterms each time. On success the resulting substitution is returned,
-- otherwise 'Nothing'.
match :: (Eq f, Ord v) => Term f v -> Term f v -> Maybe (Subst f v)
match pat term = Subst <$> go pat term M.empty
 where
  -- go threads the bindings collected so far through the traversal.
  go (Var x) s binds = case M.lookup x binds of
    Nothing -> Just (M.insert x s binds)
    -- non-linear pattern: a repeated variable must see the same subterm
    Just s'
      | s == s' -> Just binds
      | otherwise -> Nothing
  go (Val c) (Val c') binds
    | c == c' = Just binds
    | otherwise = Nothing
  go (Fun ty g ps) (Fun ty' g' ss) binds
    | g /= g' || ty /= ty' || length ps /= length ss = Nothing
    -- match the arguments pairwise, feeding each result into the next
    | otherwise = composeM (zipWith go ps ss) binds
  go _ _ _ = Nothing

-- | Chain a list of monadic steps from left to right, feeding the result of
-- each step into the next one. The empty list yields 'return'; in a failure
-- monad such as 'Maybe', the first failing step aborts the whole chain.
composeM :: (Monad m) => [a -> m a] -> a -> m a
composeM [] = return
composeM (step : rest) = step >=> composeM rest

-- | Check whether the first term is an instance of the second term.
--
-- @isInstanceOf t u@ holds when the second term @u@, used as a pattern,
-- 'match'es the first term @t@.
isInstanceOf :: (Eq f, Ord v) => Term f v -> Term f v -> Bool
isInstanceOf t u = isJust (match u t)

-- | Check whether two terms are variants of each other, i.e. whether each of
-- them is an instance of the other. The second check is only performed when
-- the first one succeeds.
isVariantOf :: (Eq f, Ord v) => Term f v -> Term f v -> Bool
isVariantOf t u
  | isInstanceOf t u = isInstanceOf u t
  | otherwise = False

-- | Apply a substitution to a term.
--
-- Every variable bound by the substitution is replaced by its image.
-- Variables not mentioned in the substitution are rebuilt unchanged, so the
-- substitution acts as the identity outside its domain.
apply :: (Ord v) => Subst f v -> Term f v -> Term f v
apply sigma = T.foldTerm replaceVar T.fun
 where
  bindings = toMap sigma
  replaceVar x = case M.lookup x bindings of
    Just image -> image
    Nothing -> T.var x

-- go sub (Var v) = sub v
-- go _ v@(Val _) = v
-- go sub (Fun typ f terms) = T.fun typ f $ map (go sub) terms

-- var' vss f v = case lookup v vss of
--   Nothing -> f v
--   Just _ -> T.var v
