import Data.List
import Control.Monad

-- import Data.Either.Utils

-- | Turn a 'Maybe' into an 'Either': a present value becomes 'Right',
-- an absent one becomes 'Left' carrying the supplied error value.
maybeToEither :: e -> Maybe a -> Either e a
maybeToEither err = maybe (Left err) Right

-- | True iff no element occurs more than once in the list.
--
-- Removing one occurrence of every distinct element from the list
-- leaves nothing behind exactly when there are no duplicates.
distinct :: Eq a => [a] -> Bool
distinct xs = null (xs \\ nub xs)


-- | Result of a check: an error message on the left, a value on the right.
type Check a = Either String a
-- | Type names are plain strings.
type Type = String
-- | Variable names are plain strings.
type Var  = String
-- | Function symbol names (used for constructors and defined functions).
type FSym = String
-- | Signature of a symbol: list of argument types and the result type.
type FSym_Info = ([Type], Type)
-- | Variable environment: yields the type of a variable, or an error.
type Vars = Var -> Check Type
-- | Signature lookup: yields the signature of a symbol, or an error.
type Sig = FSym -> Check FSym_Info
-- | First-order terms: either a variable, or a function symbol applied
-- to a list of argument terms. Equality is structural (derived).
data Term = 
  Var Var 
  | Fun FSym [Term]
  deriving (Eq)
  
-- | Variables are printed as their name; applications as the symbol
-- directly followed by the bracketed, comma-separated argument list.
instance Show Term where
  show term = case term of
    Var x    -> x
    Fun f ts -> f ++ show ts

-- | Abort a check with the given error value.
--
-- 'failure' and 'assert' are kept generic in the error type so that they
-- can also be used with the error type of the matching algorithm.
failure :: e -> Either e a
failure e = Left e

-- | Succeed with @()@ when the condition holds; otherwise fail with the
-- given error value.
assert :: Bool -> e -> Either e ()
assert ok e
  | ok        = Right ()
  | otherwise = Left e

-- | Compute the type of a term.
--
-- A variable gets its type from the variable environment. For an
-- application, the symbol's signature is looked up, all arguments are
-- type-checked recursively, and the resulting list of argument types
-- must equal the declared input types (this also enforces the arity).
-- On success the declared result type is returned.
type_check :: Sig -> Vars -> Term -> Check Type
type_check sig vars term = case term of
  Var x -> vars x
  Fun f args -> do
    (expected_tys, result_ty) <- sig f
    actual_tys <- mapM (type_check sig vars) args
    assert (actual_tys == expected_tys) (show term ++ " ill-typed")
    return result_ty

-- | Given the expected type of a term, infer the types of its variables.
--
-- A variable simply receives the expected type. For an application the
-- arity and the declared result type are checked against the term, the
-- arguments are processed against the declared input types, and the
-- collected bindings are merged. Identical bindings are collapsed; a
-- variable bound to two different types is reported as an error.
infer_type :: Sig -> Type -> Term -> Check [(Var,Type)]
infer_type _ ty (Var x) = return [(x, ty)]
infer_type sig ty (Fun f args) = do
  (arg_tys, result_ty) <- sig f
  assert (length arg_tys == length args) "lengths don't match"
  assert (result_ty == ty) "problem with expected type"
  nested <- zipWithM (infer_type sig) arg_tys args
  let bindings = nub (concat nested)
  assert (distinct (map fst bindings)) "conflicting types of variables"
  return bindings
  
-- | Type-check an equation @(l, r)@.
--
-- The left-hand side must be a function application. Its root symbol
-- determines the type of the equation; the variable types are inferred
-- from the left-hand side and then used to type-check the right-hand
-- side, whose type must coincide with the type of the left-hand side.
-- Variables of the right-hand side that do not occur on the left are
-- reported as unknown.
type_check_eqn :: Sig -> (Term, Term) -> Check ()
type_check_eqn _ (Var _, _) = failure "var as lhs"
-- The as-pattern is written without spaces around '@': with surrounding
-- whitespace, newer GHC versions parse '@' as an infix operator and
-- reject the clause.
type_check_eqn sigma (l@(Fun f _), r) = do
  (_, ty) <- sigma f
  vars <- infer_type sigma ty l
  let var_type x = maybeToEither (x ++ " is unknown variable") (lookup x vars)
  ty_r <- type_check sigma var_type r
  assert (ty == ty_r) "types of lhs and rhs don't match"


-- | A substitution represented as a function from variables to terms.
type Subst = Var -> Term
-- | A substitution represented as an association list of bindings.
type Subst_List = [(Var,Term)]
  
  
-- Processing Programs


-- | A data type declaration: the type name together with its
-- constructors and their signatures.
data Data_Definition = Data Type [(FSym, FSym_Info)]
-- | A function declaration: the function symbol, its signature, and the
-- defining equations as (lhs, rhs) pairs.
data Function_Definition = Function FSym FSym_Info [(Term,Term)]
-- | A program: its data definitions and its function definitions.
type Functional_Prog = ([Data_Definition],[Function_Definition])

-- | Signature stored as an association list from symbols to signatures.
type Sig_List = [(FSym, FSym_Info)]
-- | Signatures of the defined (function) symbols.
type Defs = Sig_List
-- | Signatures of the constructors.
type Cons = Sig_List
-- | Equations as (lhs, rhs) pairs.
type Equations = [(Term,Term)]

-- | Information accumulated while processing a program: the known
-- types, the constructors, the defined symbols, and all equations.
data Prog_Info = Prog_Info [Type] Cons Defs Equations deriving Show

-- | Program information before anything has been processed: no types,
-- no constructors, no defined symbols and no equations.
initial_prog_info :: Prog_Info
initial_prog_info = Prog_Info [] [] [] []

-- | Project the list of equations out of the program information.
equations :: Prog_Info -> Equations
equations info = case info of
  Prog_Info _ _ _ eqs -> eqs

-- | Validate one data definition and add it to the program information.
--
-- On success the new type and its constructors are recorded; defined
-- symbols and equations are passed through unchanged. Each violated
-- condition yields its own error message.
process_data_definition :: Prog_Info -> Data_Definition -> Check Prog_Info
process_data_definition (Prog_Info tys cons defs eqs) (Data ty new_cs) = do
  -- the declared type name must not be in use already
  assert (ty `notElem` tys) "type is not new"
  -- constructor names must be pairwise distinct within this definition
  assert (distinct new_names) "constructors not distinct"
  -- ... and must clash with no known constructor or defined symbol
  assert (all is_fresh new_names) "constructors not new"
  -- every constructor returns the new type and only takes known types
  assert (all well_typed new_cs) "problems in types of constructors"
  -- at least one constructor must not take the new type as an argument
  assert (any non_recursive new_cs) "no non-recursive constructor"
  return (Prog_Info known_tys (new_cs ++ cons) defs eqs)
  where
    -- the new type itself counts as known, so recursive types are allowed
    known_tys = ty : tys
    new_names = map fst new_cs
    is_fresh c = lookup c (cons ++ defs) == Nothing
    well_typed (_, (tys_in, ty_out)) =
      ty_out == ty && all (`elem` known_tys) tys_in
    non_recursive (_, (tys_in, _)) = all (/= ty) tys_in

-- | Process a list of data definitions from left to right, threading
-- the program information through and stopping at the first error.
process_data_definitions :: Prog_Info -> [Data_Definition] -> Check Prog_Info
process_data_definitions info data_defs =
  foldM process_data_definition info data_defs

-- | All variable occurrences of a term, from left to right; a variable
-- occurring several times is listed several times.
vars_term :: Term -> [Var]
vars_term term = case term of
  Var x    -> [x]
  Fun _ ts -> [x | t <- ts, x <- vars_term t]

-- | A term is linear if no variable occurs in it more than once.
linear :: Term -> Bool
linear = distinct . vars_term

-- | View an association-list signature as a lookup function; symbols
-- missing from the list are reported as an error.
sig_list_to_sig :: Sig_List -> Sig
sig_list_to_sig slist f = case lookup f slist of
  Just info -> Right info
  Nothing   -> Left "unknown symbol"

-- | Check a single defining equation of the function symbol @f@.
--
-- The checks run in this order and stop at the first failure:
-- linearity of the lhs, well-typedness of the equation, shape of the
-- lhs (root symbol @f@, correct arity, arguments typeable using
-- constructors only), and finally that every variable of the rhs
-- occurs in the lhs.
check_equation ::
  Sig_List ->       -- defined symbols, including f
  Sig_List ->       -- constructors
  FSym ->           -- f
  FSym_Info ->      -- type of f
  (Term, Term) ->   -- equation (l,r)
  Check ()
check_equation defs cons f (tys_in, _) (l, r) = do
  assert (linear l) "lhs not linear"
  type_check_eqn full_sig (l, r)
  check_lhs_shape
  assert (all (`elem` lhs_vars) (vars_term r)) "unknown variable in rhs"
  where
    -- signature of all symbols vs. signature of constructors only
    full_sig = sig_list_to_sig (cons ++ defs)
    cons_sig = sig_list_to_sig cons
    lhs_vars = vars_term l
    -- l must be of the form f(pat_1,..,pat_n); typing the patterns with
    -- the constructor signature alone rejects defined symbols in them
    check_lhs_shape = case l of
      Var _ -> failure "lhs is variable"
      Fun g ps -> do
        assert (f == g) "wrong function symbol in lhs"
        assert (length ps == length tys_in) "wrong arity of lhs"
        zipWithM_ (infer_type cons_sig) tys_in ps
  
  
-- | Validate one function definition and add it to the program
-- information.
--
-- The symbol must be fresh with respect to all constructors and
-- defined symbols. Its equations are checked with the symbol already
-- added to the defined symbols, so recursive calls are permitted. On
-- success the new symbol and its equations are recorded.
process_function_definition :: Prog_Info -> Function_Definition -> Check Prog_Info
process_function_definition (Prog_Info tys cons defs eqs) (Function f f_ty new_eqs) = do
  assert (lookup f (cons ++ defs) == Nothing) "symbol not fresh"
  mapM_ (check_equation extended_defs cons f f_ty) new_eqs
  return (Prog_Info tys cons extended_defs (eqs ++ new_eqs))
  where
    extended_defs = (f, f_ty) : defs
  
-- | Process a list of function definitions from left to right,
-- threading the program information through and stopping at the first
-- error.
process_function_definitions :: Prog_Info -> [Function_Definition] -> Check Prog_Info
process_function_definitions info fun_defs =
  foldM process_function_definition info fun_defs

-- | Check a whole program: first all data definitions, starting from
-- the empty program information, then all function definitions.
process_program :: Functional_Prog -> Check Prog_Info
process_program (data_defs, fun_defs) =
  process_data_definitions initial_prog_info data_defs
    >>= \ info -> process_function_definitions info fun_defs
  
  
-- Matching Algorithm for Linear Terms

-- | Reasons why matching can fail:
-- @Fun_Var f x@ is reported when a pattern rooted in function symbol @f@
-- is matched against the variable @x@; @Clash@ is reported when two
-- function applications do not fit together.
data Match_Error = Fun_Var FSym Var | Clash

-- | Result of matching: a 'Match_Error' or a value.
type Match_Result a = Either Match_Error a

-- | Match the pattern @l@ against the term @t@.
--
-- On success the result is a substitution function that maps every
-- variable bound by the match to its term and every other variable to
-- itself.
match :: Term -> Term -> Match_Result Subst
match l t = fmap to_subst (match_list [(l, t)])
  where
    to_subst bindings x = maybe (Var x) id (lookup x bindings)
  
-- | Solve a list of matching problems @(pattern, term)@ simultaneously.
--
-- Returns the bindings of the pattern variables as an association list.
-- No merging of bindings is performed, i.e. the patterns are assumed to
-- be linear. Fails with 'Clash' if two applications differ in their
-- root symbol or in their number of arguments, and with 'Fun_Var' if a
-- function application in the pattern meets a variable in the term.
match_list :: [(Term,Term)] -> Match_Result Subst_List
match_list [] = return []
match_list ((Fun f ls, Fun g ts) : pairs) = do
  -- The arity must be compared as well: 'zip' silently truncates to the
  -- shorter list, so without this check surplus arguments would be
  -- ignored and terms of different arity would match successfully.
  assert (f == g && length ls == length ts) Clash
  match_list (zip ls ts ++ pairs)
match_list ((Fun f _, Var x) : _) = failure (Fun_Var f x)
match_list ((Var x, t) : pairs) = do
  xt_list <- match_list pairs
  return ((x, t) : xt_list)

-- Unification Algorithm

-- | Apply a substitution to a term: every variable is replaced by its
-- image, function symbols are kept.
apply_subst :: Subst -> Term -> Term
apply_subst sig term = case term of
  Var x    -> sig x
  Fun f ts -> Fun f [apply_subst sig t | t <- ts]

-- | @subst x t s@ replaces every occurrence of the variable @x@ in @s@
-- by the term @t@ and leaves all other variables untouched.
subst :: Var -> Term -> Term -> Term
subst x t = apply_subst replace
  where
    replace y
      | y == x    = t
      | otherwise = Var y

-- | Unify a list of term pairs. Returns the computed bindings, or
-- 'Nothing' if the pairs are not unifiable. Starts the main loop with
-- an empty list of solved bindings.
unify :: [(Term, Term)] -> Maybe [(Var, Term)]
unify = flip unify_main []

-- | Main loop of unification.
--
-- The first argument holds the pairs still to be unified, the second
-- the bindings solved so far. Returns the final bindings once no pairs
-- remain, or 'Nothing' on a symbol/arity clash or a failed occurs
-- check.
unify_main :: [(Term, Term)] -> [(Var,Term)] -> Maybe [(Var, Term)]
unify_main [] solved = Just solved
-- decompose: equal root symbol and arity, continue with the arguments
unify_main ((Fun f ts, Fun g ss) : rest) solved
  | f == g && length ts == length ss = unify_main (zip ts ss ++ rest) solved
  | otherwise = Nothing
-- orient: the right side is a variable here, so put it on the left
unify_main ((lhs@(Fun _ _), rhs) : rest) solved =
  unify_main ((rhs, lhs) : rest) solved
unify_main ((Var x, t) : rest) solved
  -- trivial pair x = x
  | t == Var x = unify_main rest solved
  -- occurs check
  | x `elem` vars_term t = Nothing
  -- eliminate x from the remaining pairs and the solved bindings
  | otherwise =
      unify_main
        (map (\ (l, r) -> (elim l, elim r)) rest)
        ((x, t) : map (\ (y, s) -> (y, elim s)) solved)
  where
    elim = subst x t


-- | View an association-list substitution as a function; variables
-- without a binding are mapped to themselves.
subst_list_to_subst :: Subst_List -> Subst
subst_list_to_subst sig x = maybe (Var x) id (lookup x sig)

-- Checking Pattern Disjointness
      
-- | The variable with the given index, named @x@ followed by the number.
create_var :: Int -> Term
create_var i = Var ('x' : show i)

-- | Rename the variables of a term to @x<i>@, @x<i+1>@, ... from left
-- to right, starting at the given index.
--
-- Returns the first index that was not used together with the renamed
-- term. Every variable occurrence receives its own fresh name; the
-- original variable names are ignored.
--
-- The index is threaded through the argument list with 'mapAccumL',
-- which avoids rebuilding partial applications and the non-exhaustive
-- pattern match the manual recursion needed.
rename_vars :: Int -> Term -> (Int, Term)
rename_vars i (Var _) = (i + 1, create_var i)
rename_vars i (Fun f ts) = (next, Fun f ss)
  where
    (next, ss) = mapAccumL rename_vars i ts

-- | Check pattern disjointness of a program and turn a witness term
-- into a readable error message.
check_pattern_disjoint_prog :: Prog_Info -> Check ()
check_pattern_disjoint_prog p = either report accept (pattern_disjoint_prog p)
  where
    report term = failure ("not pattern disjoint, consider term: " ++ show term)
    accept _ = return ()
  
-- | Check that the left-hand sides of any two different equations of
-- the program do not overlap.
--
-- All ordered pairs of equations are inspected; pairs consisting of two
-- equal equations are skipped. The first overlap found is returned as a
-- witness term on the left.
pattern_disjoint_prog :: Prog_Info -> Either Term ()
pattern_disjoint_prog (Prog_Info _ _ _ eqs) =
  sequence_
    [ check_disjoint_lhss l l'
    | eq@(l, _) <- eqs
    , eq'@(l', _) <- eqs
    , eq /= eq'
    ]
    
-- | Check that two left-hand sides do not overlap.
--
-- Both terms are first renamed with disjoint index ranges so that they
-- share no variables. If the renamed terms unify, the common instance
-- is returned on the left as a witness of the overlap.
check_disjoint_lhss :: Term -> Term -> Either Term ()
check_disjoint_lhss s t = maybe (return ()) overlap (unify [(s', t')])
  where
    (next, s') = rename_vars 0 s
    (_, t') = rename_vars next t
    overlap mgu = failure (apply_subst (subst_list_to_subst mgu) s')


-- Exercise 9.1 -- Checking Pattern Completeness


-- | A pattern problem is a list of triples (t, n, L) of a term, an
-- integer and a list of terms.
-- The integer encodes the next free variable,
-- i.e., in (t,n,L) the term t only contains variables
-- with numbers strictly below n.
type Pattern_Problem = [(Term,Int,[Term])]


-- | Presumably the initial pattern problem of a program (going by the
-- name and type) -- TODO confirm against the exercise text.
-- NOTE(review): not implemented yet; using this value raises
-- @error "TODO"@.
p_init :: Prog_Info -> Pattern_Problem
p_init = error "TODO"


-- | Check pattern completeness of a program and turn a witness term
-- into a readable error message.
--
-- This function used to be declared twice in this file (signature and
-- body), which is rejected by the compiler; only one copy is kept.
check_pattern_complete_prog :: Prog_Info -> Check ()
check_pattern_complete_prog p = case pattern_complete_prog p of
  Left term -> failure ("not pattern complete, consider term: " ++ show term)
  _ -> return ()


-- | Pattern completeness check; by its type, a witness term is returned
-- on the left when the check fails.
-- NOTE(review): not implemented yet; using this value raises
-- @error "TODO"@.
pattern_complete_prog :: Prog_Info -> Either Term ()
pattern_complete_prog = error "TODO"
  
  

-- Exercise 9.4

-- | Create the sharped version of a symbol by appending @#@.
--
-- For simplicity we assume that '#' does not occur in function names of
-- the original program, so appending '#' to the end of a symbol will
-- always create a fresh symbol.
create_sharp_symbol :: String -> String
create_sharp_symbol = (++ "#")

-- | Dependency pairs, represented as pairs of terms.
type DPs = [(Term,Term)]

-- | Presumably the dependency pairs of a program (going by the name and
-- type) -- TODO confirm against the exercise text.
-- NOTE(review): not implemented yet; using this value raises
-- @error "TODO"@.
dps_of_prog :: Prog_Info -> DPs
dps_of_prog = error "TODO"

    
-- | Check whether @s |> t@, i.e. whether @t@ is a proper subterm of @s@.
--
-- A variable has no proper subterms. For an application, @t@ is a
-- proper subterm iff it is equal to one of the arguments or a proper
-- subterm of one of them.
strict_subterm :: Term -> Term -> Bool
strict_subterm (Var _) _ = False
strict_subterm (Fun _ ss) t = any (\ s' -> s' == t || strict_subterm s' t) ss

-- check whether s |>= t: the terms are equal, or s |> t
weak_subterm s t
  | s == t    = True
  | otherwise = strict_subterm s t

-- | Check the subterm criterion for symbol f# and argument i.
-- If successful, returns all pairs that can be removed.
-- NOTE(review): not implemented yet; calling this raises
-- @error "TODO"@.
check_subterm_criterion_sym_arg :: DPs -> FSym -> Int -> Maybe DPs
check_subterm_criterion_sym_arg p f_shp i = error "TODO" 


-- | One possible application of the subterm criterion;
-- delivers the list of removable DPs.
-- NOTE(review): not implemented yet; using this value raises
-- @error "TODO"@.
check_subterm_criterion :: Prog_Info -> DPs -> Maybe DPs
check_subterm_criterion = error "TODO"
       
-- | Returns the remaining DPs after iterated application of the
-- subterm criterion.
--
-- The criterion is applied repeatedly; after each successful
-- application the reported pairs are removed. The iteration stops when
-- the criterion is no longer applicable, or when an application
-- removes nothing from the current list (for instance an empty removal
-- list). The latter guard guarantees termination, since every
-- recursive call then works on a strictly shorter list.
iterated_subterm_criterion :: Prog_Info -> DPs -> DPs
iterated_subterm_criterion prog p =
  case check_subterm_criterion prog p of
    Nothing -> p
    Just removable
      | length remaining == length p -> p
      | otherwise -> iterated_subterm_criterion prog remaining
      where
        remaining = filter (`notElem` removable) p

    
-- tests
  
-- | Example program used by 'test'.
--
-- It declares the data type @Nat@ with the constructors @Zero@ and
-- @Succ@, and the four binary functions @plus@, @minus@, @div@ and
-- @ack@ on @Nat@, each given by its defining equations (lhs, rhs).
example_prog :: Functional_Prog
example_prog = (
  [
   Data "Nat" [("Zero",([],"Nat")), ("Succ",(["Nat"],"Nat"))]
  ],
  [
    Function "plus" (["Nat","Nat"],"Nat") [
      (Fun "plus" [Fun "Succ" [Var "x"], Var "y"], 
       Fun "Succ" [Fun "plus" [Var "y", Var "x"]]),
      (Fun "plus" [Fun "Zero" [], Var "y"], 
       Var "y")
      ],
    Function "minus" (["Nat","Nat"],"Nat") [
      (Fun "minus" [Fun "Succ" [Var "x"], Fun "Succ" [Var "y"]], 
       Fun "minus" [Var "x", Var "y"]),
      (Fun "minus" [Var "x", Fun "Zero" []], Var "x"),
      (Fun "minus" [Fun "Zero" [], Fun "Succ" [Var "y"]], Fun "Zero" [])
      ],
    Function "div" (["Nat","Nat"],"Nat") [
      (Fun "div" [Fun "Succ" [Var "x"], Fun "Succ" [Var "y"]], 
       Fun "Succ" [Fun "div" [Fun "minus" [Var "x", Var "y"], Fun "Succ" [Var "y"]]]),
      (Fun "div" [Var "x", Fun "Zero" []], Fun "Zero" []),
      (Fun "div" [Fun "Zero" [], Fun "Succ" [Var "y"]], Fun "Zero" [])
      ],
    Function "ack" (["Nat","Nat"],"Nat") [
      (Fun "ack" [Fun "Succ" [Var "x"], Fun "Succ" [Var "y"]], 
       Fun "ack" [Var "x", Fun "ack" [Fun "Succ" [Var "x"], Var "y"]]),
      (Fun "ack" [Fun "Zero" [], Var "y"], Fun "Succ" [Fun "Zero" []]),
      (Fun "ack" [Fun "Succ" [Var "x"], Fun "Zero" []], Fun "ack" [Var "x", Fun "Succ" [Fun "Zero" []]])
      ]    
  ])


-- | Runs the whole pipeline on 'example_prog': the program is checked,
-- then tested for pattern disjointness and pattern completeness, and
-- finally its DPs are computed and reduced.
--
-- test returns all remaining DPs of program after subterm criterion,
-- should be precisely the DPs mentioned on slide 4/66
test :: Check DPs
test = do
  info <- process_program example_prog
  check_pattern_disjoint_prog info
  check_pattern_complete_prog info
  return (iterated_subterm_criterion info (dps_of_prog info))