import Data.List
import Control.Monad


-- Existing algorithms and data structures from the lecture

-- import Data.Either.Utils

-- | Convert a 'Maybe' into an 'Either', using the given value as the
-- error ('Left') when the input is 'Nothing'.
maybeToEither :: e -> Maybe a -> Either e a
maybeToEither err = maybe (Left err) Right

-- | True iff the list contains no element twice.
-- Removing duplicates must not shrink the list.
distinct :: Eq a => [a] -> Bool
distinct xs = uniqueCount == totalCount
  where
    uniqueCount = length (nub xs)
    totalCount  = length xs


-- | Result of a check: either an error message ('Left') or a value.
type Check a = Either String a
-- | Names of data types.
type Type = String
-- | Names of variables.
type Var  = String
-- | Names of function symbols (constructors and defined functions).
type FSym = String
-- | Type of a function symbol: (argument types, result type).
type FSymInfo = ([Type], Type)
-- | Typing context: yields the type of a variable, or an error.
type Vars = Var -> Check Type
-- | Signature: yields the type of a function symbol, or an error.
type Sig = FSym -> Check FSymInfo
-- | First-order terms: variables, and function symbols applied to
-- a list of argument terms.
data Term = 
  Var Var 
  | Fun FSym [Term]
  deriving (Eq, Show)

-- | Abort a check with the given error message.
failure :: String -> Either String a
failure msg = Left msg

-- | Succeed with @()@ if the condition holds; otherwise fail with the
-- given error message.
assert :: Bool -> String -> Either String ()
assert p e
  | p         = return ()
  | otherwise = failure e

-- | Compute the type of a term.
--
-- Variables get their type from the context @vars@. For an application
-- @f(t_1,..,t_n)@, the argument types must be exactly the argument types
-- declared for @f@ in @sig@; the result is @f@'s declared result type.
typeCheck :: Sig -> Vars -> Term -> Check Type
typeCheck _ vars (Var x) = vars x
typeCheck sig vars t@(Fun f ts) = do
  (argTys, resTy) <- sig f
  actualTys <- traverse (typeCheck sig vars) ts
  -- list equality also rejects a wrong number of arguments
  unless (actualTys == argTys) (failure (show t ++ " ill-typed"))
  return resTy

-- | Infer the types of the variables of a term, given the type @ty@ the
-- whole term is expected to have.
--
-- Returns a duplicate-free list of (variable, type) pairs. Fails if an
-- application has the wrong arity or result type, or if one variable would
-- need two different types.
inferType :: Sig -> Type -> Term -> Check [(Var,Type)]
inferType _ ty (Var x) = return [(x, ty)]
inferType sig ty (Fun f ts) = do
  (argTys, resTy) <- sig f
  assert (length argTys == length ts) "lengths don't match"
  assert (resTy == ty) "problem with expected type"
  -- each argument is inferred against the corresponding declared type
  bindings <- nub . concat <$> zipWithM (inferType sig) argTys ts
  -- after merging identical pairs, a repeated variable means a type conflict
  assert (distinct (map fst bindings)) "conflicting types of variables"
  return bindings
  
-- | Type-check an equation @l = r@ against the signature @sigma@.
--
-- The lhs must be an application. The types of its variables are inferred
-- from the result type of its head symbol. The rhs must then be typable
-- using only those variables and must have that same result type.
typeCheckEqn :: Sig -> (Term, Term) -> Check ()
typeCheckEqn _ (Var _, _) = failure "var as lhs"
typeCheckEqn sigma (l@(Fun f _), r) = do
  (_, ty) <- sigma f
  vars <- inferType sigma ty l
  -- rhs variables must have been typed by the lhs
  let varType x = maybeToEither (x ++ " is unknown variable") (lookup x vars)
  tyR <- typeCheck sigma varType r
  assert (ty == tyR) "types of lhs and rhs don't match"
  
  
-- | Apply a substitution, given as a function from variables to terms,
-- to every variable occurrence of a term.
applySubst :: (Var -> Term) -> Term -> Term
applySubst sigma (Var x) = sigma x
applySubst sigma (Fun f ts) = Fun f (map (applySubst sigma) ts)

-- | @subst x t s@ replaces every occurrence of the variable @x@ in @s@ by @t@.
subst :: Var -> Term -> Term -> Term
subst x t = applySubst replace
  where
    replace y
      | y == x    = t
      | otherwise = Var y

-- | Unify a list of term equations, starting from the empty substitution.
-- The result is a substitution given as (variable, term) pairs.
unify :: [(Term, Term)] -> Check [(Var, Term)]
unify = flip unifyMain []

-- | Worker for 'unify'. It repeatedly simplifies the list of pending
-- equations while building up the substitution.
--
-- First argument: equations still to be solved.
-- Second argument: substitution computed so far, as (variable, term) pairs.
-- Returns the final substitution. It fails with "clash" on different
-- function symbols or arities, and with "occurs check" when a variable
-- would have to be bound to a term containing it.
unifyMain :: [(Term, Term)] -> [(Var,Term)] -> Check [(Var, Term)]
-- no equations left: the accumulated substitution is the result
unifyMain [] v = return v                           
-- decompose f(ts) = g(ss) into pairwise equations of the arguments
unifyMain ((Fun f ts, Fun g ss) : u) v = do 
  assert (f == g && length ts == length ss) "clash"
  unifyMain (zip ts ss ++ u) v              
-- x must be a variable here (the Fun/Fun case is matched above);
-- swap the sides so that the variable is on the left
unifyMain ((Fun f ts, x) : u) v =   
  unifyMain ((x, Fun f ts) : u) v   
unifyMain ((Var x, t) : u) v =
  -- a trivial equation x = x is simply dropped
  if Var x == t then unifyMain u v  
  else do
    -- x must not occur in t, otherwise no finite solution exists
    assert (not (x `elem` varsTerm t)) "occurs check"
    -- eliminate x: replace it by t in the remaining equations and in the
    -- terms of the substitution built so far, then record x := t
    unifyMain                                       
      (map ( \ (l,r) -> (subst x t l, subst x t r)) u)
      ((x,t) : map ( \ (y, s) -> (y, subst x t s)) v)
  

-- | A data type declaration: the name of the new type and its
-- constructors together with their types.
data DataDefinition = Data Type [(FSym, FSymInfo)]

-- | Association list from function symbols to their types.
type SigList = [(FSym, FSymInfo)]
-- | Defined (non-constructor) function symbols.
type Defs = SigList
-- | Constructor symbols.
type Cons = SigList
-- | Equations as (lhs, rhs) pairs.
type Equations = [(Term,Term)]

-- | Information accumulated while processing a program: known types,
-- constructors, defined symbols and equations.
data ProgInfo = ProgInfo [Type] Cons Defs Equations deriving Show

-- | Add a data type definition to the program information.
--
-- Checks, in this order, that:
-- * the type name is new;
-- * the new constructor names are pairwise distinct;
-- * they clash with no known constructor or defined symbol;
-- * every constructor has the new type as result type and only known types
--   (including the new one) as argument types;
-- * at least one constructor does not take the new type as an argument.
-- The new type and constructors are prepended to the existing ones.
processDataDefinition :: ProgInfo -> DataDefinition -> Check ProgInfo
processDataDefinition (ProgInfo tys cons defs eqs) (Data ty newCs) = do
  -- check that type is fresh
  assert (ty `notElem` tys) "type is not new"
  let newTys = ty : tys
  -- check distinctness of new constructor names
  assert (distinct (map fst newCs)) "constructors not distinct"
  -- check fresh constructor names
  let knownSyms = map fst (cons ++ defs)
  assert (all (\ (c, _) -> c `notElem` knownSyms) newCs) "constructors not new"
  -- check types of constructors
  let wellTyped (_, (argTys, resTy)) = resTy == ty && all (`elem` newTys) argTys
  assert (all wellTyped newCs)
     "problems in types of constructors"
  -- check existence of non-recursive constructor, so the type is inhabited
  assert (any (\ (_, (argTys, _)) -> ty `notElem` argTys) newCs)
    "no non-recursive constructor"
  return (ProgInfo newTys (newCs ++ cons) defs eqs)

-- | Process data definitions one after another, threading the program
-- information through and stopping at the first error.
processDataDefinitions :: ProgInfo -> [DataDefinition] -> Check ProgInfo
processDataDefinitions info dataDefs = foldM processDataDefinition info dataDefs

-- | Program information before anything has been processed: no types,
-- constructors, defined symbols or equations.
initialProgInfo :: ProgInfo
initialProgInfo = ProgInfo [] [] [] []



-- Exercise 1 - Processing Programs


-- | A function definition: its name, its type, and its defining
-- equations as (lhs, rhs) pairs.
data FunctionDefinition = Function FSym FSymInfo [(Term,Term)]
-- | A program: its data definitions and its function definitions.
type FunctionalProg = ([DataDefinition],[FunctionDefinition])

-- | The equations collected in the program information.
equations :: ProgInfo -> Equations
equations info = case info of
  ProgInfo _ _ _ eqs -> eqs

-- | All variable occurrences of a term, from left to right.
--
-- Repeated occurrences are kept on purpose. This lets callers detect
-- non-linear terms, i.e. terms in which a variable occurs more than once.
varsTerm :: Term -> [Var]
varsTerm (Var x) = [x]
varsTerm (Fun _ ts) = concatMap varsTerm ts


-- | A term is linear if no variable occurs in it more than once.
-- Relies on 'varsTerm' listing every occurrence, including repeats.
linear :: Term -> Bool
linear t = distinct (varsTerm t)

-- | Check one defining equation @l = r@ of the function symbol @f@.
--
-- The first violated condition is reported as a 'Left':
-- * @l@ has the form @f(pat_1,..,pat_n)@, where @n@ is the number of
--   argument types of @f@ and each @pat_i@ consists of variables and
--   constructors only;
-- * @l@ is linear;
-- * every variable of @r@ occurs in @l@ (variable condition);
-- * @l@ and @r@ are well-typed and both have the result type of @f@.
-- Function symbols are looked up in the defined symbols and the constructors.
checkEquation :: 
  SigList ->       -- defined symbols, including f
  SigList ->       -- constructors
  FSym ->           -- f
  FSymInfo ->      -- type of f
  (Term, Term) ->   -- equation (l,r)
  Check ()
checkEquation defs cons f (tysIn, tyOut) (l,r) = do
  -- check that l is of form f(pat_1,..,pat_n)
  pats <- case l of
    Fun g ps | g == f -> return ps
    _ -> failure (show l ++ " is not an application of " ++ f)
  assert (length pats == length tysIn)
    (show l ++ " has wrong number of arguments")
  assert (all isPattern pats) (show l ++ " has non-pattern arguments")
  -- check that l is linear
  assert (linear l) (show l ++ " is not linear")
  -- check var-condition: every variable of r must occur in l
  let lhsVars = varsTerm l
  assert (all (`elem` lhsVars) (varsTerm r))
    ("variable condition violated in equation for " ++ f)
  -- check that types of l and r are identical (and equal to f's result type)
  varTys <- inferType sig tyOut l
  tyR <- typeCheck sig (lookupVar varTys) r
  assert (tyR == tyOut) "types of lhs and rhs don't match"
  where
    -- signature covering defined symbols and constructors
    sig g = maybeToEither (g ++ " is unknown function symbol") (lookup g (defs ++ cons))
    -- typing context built from the variable types inferred for l
    lookupVar varTys x = maybeToEither (x ++ " is unknown variable") (lookup x varTys)
    -- patterns are built from variables and constructor applications only
    isPattern (Var _) = True
    isPattern (Fun c ts) = c `elem` map fst cons && all isPattern ts
  
  
-- | Add a function definition to the program information.
--
-- Checks, in this order, that:
-- * the function symbol is new;
-- * all types in its type are known;
-- * every equation passes 'checkEquation'.
-- The symbol is added to the defined symbols before its equations are
-- checked, so the equations may call it recursively. On success the symbol
-- and its equations are prepended to the existing ones.
processFunctionDefinition :: ProgInfo -> FunctionDefinition -> Check ProgInfo
processFunctionDefinition (ProgInfo tys cons defs eqs) (Function f info@(tysIn, tyOut) newEqs) = do
  -- check that f is fresh
  assert (f `notElem` map fst (cons ++ defs)) (f ++ " is not a new function symbol")
  -- check that all types occurring in the type of f are known
  assert (all (`elem` tys) (tyOut : tysIn)) ("unknown type in type of " ++ f)
  let newDefs = (f, info) : defs
  -- check every equation; f itself is already available for recursive calls
  mapM_ (checkEquation newDefs cons f info) newEqs
  return (ProgInfo tys cons newDefs (newEqs ++ eqs))
  
-- | Process function definitions one after another, threading the program
-- information through and stopping at the first error.
processFunctionDefinitions :: ProgInfo -> [FunctionDefinition] -> Check ProgInfo
processFunctionDefinitions info funDefs = foldM processFunctionDefinition info funDefs

-- | Process a whole program. All data definitions are processed first,
-- starting from the empty program information; then all function
-- definitions are processed.
processProgram :: FunctionalProg -> Check ProgInfo
processProgram (dataDefs, funDefs) =
  processDataDefinitions initialProgInfo dataDefs
    >>= \ withData -> processFunctionDefinitions withData funDefs
    
-- | Example program. It defines the data types Nat (Zero, Succ) and
-- List (Nil, Cons with a Nat head), plus the functions add and append.
-- Each function has one equation per constructor of its first argument.
exampleProg :: FunctionalProg
exampleProg = (
  [
   Data "Nat" [("Zero",([],"Nat")), ("Succ",(["Nat"],"Nat"))],
   Data "List" [("Nil",([],"List")), ("Cons",(["Nat","List"],"List"))]
  ],
  [
    -- add(Succ(x), y) = Succ(add(x, y));  add(Zero, y) = y
    Function "add" (["Nat","Nat"],"Nat") [
      (Fun "add" [Fun "Succ" [Var "x"], Var "y"], 
       Fun "Succ" [Fun "add" [Var "x", Var "y"]]),
      (Fun "add" [Fun "Zero" [], Var "y"], 
       Var "y")
      ],
    -- append(Cons(x, xs), ys) = Cons(x, append(xs, ys));  append(Nil, ys) = ys
    Function "append" (["List","List"],"List") [
      (Fun "append" [Fun "Cons" [Var "x",Var "xs"], Var "ys"], 
       Fun "Cons" [Var "x", Fun "append" [Var "xs", Var "ys"]]),
      (Fun "append" [Fun "Nil" [], Var "ys"], 
       Var "ys")
      ]
  ])


-- | Process 'exampleProg'; evaluate this to inspect the resulting
-- program information or the first error.
test :: Check ProgInfo
test = processProgram exampleProg