{- |
Module      : Type.TypeChecking
Description : Sort checking of the rules of an LCTRS
Copyright   : (c) Jonas Schöpf, 2023
License     : GPL-3
Maintainer  : jonas.schoepf@uibk.ac.at
Stability   : stable


This module provides functions to perform checks on types and the semantics of a
given LCTRS to ensure it follows the specified properties to be a valid LCTRS.
-}
module Type.TypeChecking (
  typeCheck,
) where

import Control.Arrow (second)
import Data.LCTRS.FIdentifier
import Data.LCTRS.Guard
import Data.LCTRS.LCTRS
import Data.LCTRS.Rule
import Data.LCTRS.Sort
import Data.LCTRS.Term
import Data.LCTRS.VIdentifier
import qualified Data.Map.Strict as M
import Data.SMT (Identifier, Property, boolSort)
import Prettyprinter (Pretty, pretty)
import Type.SortLCTRS (DefaultFId (deffid), DefaultVId, inferRemainingPSorts)

{- | Type check all rules of an LCTRS.

The result is 'True' exactly when 'typecheckRule' accepts every rule returned
by 'getRules'. Note that the underlying checks report most violations by
calling 'error' instead of yielding 'False'.
-}
typeCheck
  :: M.Map Identifier SortAnnotation
  -- ^ sort annotations of the term symbols
  -> M.Map Identifier [(SortAnnotation, Maybe Property, SortAnnotation -> Bool)]
  -- ^ sort annotations of the theory symbols
  -> LCTRS (FId DefaultFId) (VId DefaultVId)
  -- ^ the LCTRS whose rules are checked
  -> Bool
typeCheck fSyms thSyms lctrs = all ruleIsWellSorted (getRules lctrs)
 where
  -- The theory-sort predicate is taken from the LCTRS itself and shared by all rules.
  ruleIsWellSorted = typecheckRule fSyms thSyms (getTheorySorts lctrs)

{- | Type check a single rule.

Calls 'error' when a logical variable of the rule has a non-theory sort.
Otherwise both sides must have the same sort, each side must be well-sorted
with respect to that sort, and every term of the guard must be well-sorted
with respect to 'boolSort'.
-}
typecheckRule
  :: M.Map Identifier SortAnnotation
  -- ^ sort annotations of the term symbols
  -> M.Map Identifier [(SortAnnotation, Maybe Property, SortAnnotation -> Bool)]
  -- ^ sort annotations of the theory symbols
  -> (Sort -> Bool)
  -- ^ predicate recognising theory sorts
  -> Rule (FId DefaultFId) (VId DefaultVId)
  -- ^ the rule to check
  -> Bool
typecheckRule fSyms thSyms isTheorySort rule =
  if all (isTheorySort . sort) (lvar rule)
    then
      -- The conjuncts are kept in this order because '&&' short-circuits and
      -- the individual checks may raise errors.
      sameSort && wellSorted leftSort left && wellSorted rightSort right && guardIsBoolean
    else error "TypeChecking.hs: failed as logical variables with non-theory sort is present."
 where
  left = lhs rule
  right = rhs rule

  leftSort = sort left
  rightSort = sort right

  sameSort = leftSort == rightSort
  guardIsBoolean = all (wellSorted boolSort) (termsGuard (guard rule))

  -- Only the boolean verdict of 'typecheckTerm' is of interest here.
  wellSorted expected = snd . typecheckTerm fSyms thSyms isTheorySort expected

{- | Check that a term is well-sorted with respect to an expected sort.

On success the result is the expected sort paired with 'True'. Every violation
found here is reported by calling 'error' (directly, or through 'arityCheck'
and 'checkFailing'), so a pair carrying 'False' is never returned.
-}
typecheckTerm
  :: M.Map Identifier SortAnnotation
  -- ^ sort annotations of the term symbols
  -> M.Map Identifier [(SortAnnotation, Maybe Property, SortAnnotation -> Bool)]
  -- ^ sort annotations of the theory symbols
  -> (Sort -> Bool)
  -- ^ predicate recognising theory sorts
  -> Sort
  -- ^ sort the term is expected to have
  -> Term (FId DefaultFId) (VId DefaultVId)
  -- ^ term to check
  -> (Sort, Bool)
-- A variable only has to carry the expected sort.
typecheckTerm _ _ _ sort' term@(Var _) = checkFailing term $ compSorts sort' $ sort term
-- A value takes no arguments and must be assigned a theory sort.
typecheckTerm _ _ isTheorySort sort' term@(Val f) =
  arityCheck f (length (getFIdInSort f)) 0 $
    if isTheorySort valSort
      then checkFailing term $ compSorts sort' valSort
      else
        error $
          "TypeChecking.hs: value is assigned non-theory sort "
            ++ show (pretty valSort)
 where
  valSort = sort f
-- A term symbol must agree with its declared sort annotation and all of its
-- arguments must be well-sorted with respect to its argument sorts.
typecheckTerm fSyms thSyms isTheorySort sort' term@(TermFun f@(FId fId _) args) =
  arityCheck f (length argSorts) (length args) $
    checkFailing term $
      second (&& (checkedArgs && checkedInitSig && checkOutSig)) $
        compSorts
          sort'
          rootSort
 where
  argSorts = getFIdInSort f
  rootSort = sort f
  symId = deffid fId

  tT = typecheckTerm fSyms thSyms isTheorySort

  -- 'arityCheck' has already established that both lists have the same
  -- length by the time this is demanded, so nothing is lost by zipping.
  checkedArgs = all (snd . uncurry tT) (zip argSorts args)
  defSortAnn = case M.lookup symId fSyms of
    Nothing -> error $ "TypeChecking.hs: no sort for symbol " ++ show symId ++ " found"
    Just sa -> sa
  -- Compare the complete lists: zipping them would silently drop surplus
  -- sorts and accept a declaration with a different number of arguments.
  checkedInitSig = argSorts == inSort defSortAnn
  checkOutSig = rootSort == outSort defSortAnn
typecheckTerm _ _ _ _ (TermFun f _) = error $ "TypeChecking.hs: The symbol " <> shows (pretty f) " is not a term symbol."
-- A theory symbol must not carry polymorphic sorts any more and its sort
-- annotation must be the one inferred from the candidates in the theory map.
typecheckTerm fSyms thSyms isTheorySort sort' term@(TheoryFun f args) =
  arityCheck f (length argSorts) (length args) $
    if containsPolySort termSortAnn
      then
        error $
          "TypeChecking.hs: failed as polymorphic sort is still present for symbol "
            ++ show (pretty f)
            ++ ": "
            ++ show (prettySortAnnotation termSortAnn)
            ++ " in term "
            ++ show (pretty term)
      else
        checkFailing term $
          second (&& (checkedArgs && checkedSig)) $
            compSorts
              sort'
              rootSort
 where
  argSorts = getFIdInSort f
  rootSort = sort f
  termSortAnn = sortAnnotation argSorts rootSort
  theoryId = deffid $ getFIdText f

  tT = typecheckTerm fSyms thSyms isTheorySort

  -- Lengths agree here as well, thanks to the preceding 'arityCheck'.
  checkedArgs = all (snd . uncurry tT) (zip argSorts args)
  candidates = case M.lookup theoryId thSyms of
    Nothing -> error $ "TypeChecking.hs: no sort for logic symbol " ++ show (pretty f) ++ " found"
    Just sa -> sa
  checkedSig = case inferRemainingPSorts termSortAnn candidates of
    Nothing ->
      error $
        "TypeChecking.hs: sort of symbol "
          ++ show (pretty f)
          ++ " is still polymorphic "
          ++ show (prettySortAnnotation termSortAnn)
    Just sa -> sa == termSortAnn

{- | Continue with the given result only if both arities coincide.

The first argument is the symbol named in the error message, the two 'Int's
are the arities to compare and the last argument is returned unchanged when
they are equal. Otherwise 'error' is called.
-}
arityCheck :: (Pretty p) => p -> Int -> Int -> c -> c
arityCheck f expected actual cnt
  | expected == actual = cnt
  | otherwise =
      error $
        "TypeChecking.hs: wrong number of arguments for function symbol "
          ++ show (pretty f)
          ++ "."

{- | Turn a failed check into an error.

A result flagged with 'True' is passed through unchanged. For a result flagged
with 'False', 'error' is called with a message naming the offending term and
the sort stored in the result.
-}
checkFailing :: (Pretty f, Pretty v, Pretty a) => Term f v -> (a, Bool) -> (a, Bool)
checkFailing term result@(s, ok)
  | ok = result
  | otherwise =
      error $
        concat
          [ "TypeChecking.hs: failed at the term "
          , show (pretty term)
          , " to infer the sort "
          , show (pretty s)
          ]

-- | Pair the first argument with the verdict whether both arguments are equal.
compSorts :: (Eq a) => a -> a -> (a, Bool)
compSorts expected actual = (expected, verdict)
 where
  verdict = expected == actual
