-- | Type checker for preprocessed programs.
--
-- Besides accepting or rejecting a 'Program', a side effect of the type
-- checker is to store the types of top-level definitions and of lambdas
-- (as well as closures and @if@ expressions) in a 'TypeDict', so that the
-- Compiler module can use them to generate LLVM types.
module Typechecker (typecheck, TypeDict) where

import Preprocessor
import Data.Generics
import Data.List (intercalate)
import Data.Map ((!), Map)
import qualified Data.Map as Map
import Control.Monad (unless)
import Control.Monad.Error
import Control.Monad.State
import Control.Monad.Reader
import Data.Maybe (fromJust)

-- | Types recorded for the expressions the checker memoises.
type TypeDict = Map Exp Type

-- | Inference reads a 'Context', accumulates a 'TypeDict' as state, and
-- aborts with a 'String' error message.
type Typechecker = ReaderT Context (StateT TypeDict (Either String))

-- | Read-only environment for inference.
data Context = Context
  { fns  :: Program           -- ^ all top-level definitions
  , vars :: [(String, Type)]  -- ^ lambda-bound variables, innermost first
  }

-- | Type-check a whole program.
--
-- Fails if any top-level definition refers (directly or transitively) to
-- itself, if there is no @_main@ definition, if any definition is
-- ill-typed, or if @_main@ does not have type 'TInt'. On success, returns
-- the recorded types.
typecheck :: Program -> Either String TypeDict
typecheck p = do
  -- Inference would loop forever on a reference cycle, so reject cycles
  -- first and report the offending path.
  case allRecursiveLoops p of
    loop : _ ->
      Left $ "Recursive reference detected: " ++ intercalate " -> " loop
    [] -> return ()
  unless (Map.member "_main" p) $ Left "missing 'main' label"
  flip execStateT Map.empty $ flip runReaderT initial $ do
    mapM_ (inferType . EGlobal) (Map.keys p)
    checkType TInt (EGlobal "_main")
  where
    initial = Context { fns = p, vars = [] }

-- | Return the type already recorded for an expression, or run the given
-- inference rule and record its result.
memoType :: Exp -> Typechecker Type -> Typechecker Type
memoType e inferRule = do
  cached <- gets (Map.lookup e)
  case cached of
    Just t  -> return t
    Nothing -> inferRule >>= learnType e

-- | Record the type of an expression and return that type.
learnType :: Exp -> Type -> Typechecker Type
learnType e t = do
  modify (Map.insert e t)
  return t

-- | Infer the type of an expression, failing with a message if it is
-- ill-typed.
inferType :: Exp -> Typechecker Type
inferType e@(EGlobal fname) = memoType e $ do
  defs <- asks fns
  case Map.lookup fname defs of
    Just fbody -> inferType fbody
    Nothing    -> throwError $ "no such function " ++ fname
inferType (ELocal var) = do
  scope <- asks vars
  case lookup var scope of
    Just t  -> return t
    -- Report rather than crash if a variable is not in scope.
    Nothing -> throwError $ "unbound variable " ++ var
inferType (EInt _) = return TInt
inferType (EBool _) = return TBool
inferType (EApp f arg) = do
  funcType <- inferType f
  case funcType of
    TFun t1 t2 -> checkType t1 arg >> return t2
    badType ->
      throwError $
        show f ++ " is used as a function but " ++ show badType
          ++ " is not a function type"
inferType (ENeg e) = checkType TInt e
inferType (EArith _ e1 e2) = checkType TInt e1 >> checkType TInt e2
inferType (EComp _ e1 e2) = do
  t <- inferType e1
  checkType t e2
  return TBool
inferType e@(ELambda arg body) = memoType e $ inferFunction arg body
inferType e@(EClosure _ arg body) = memoType e $ inferFunction arg body
inferType e@(EIf cond e1 e2) = memoType e $ do
  checkType TBool cond
  t <- inferType e1
  checkType t e2

-- | Type of a one-argument function: the body is inferred with the
-- parameter (name, type) pushed onto the local scope.
inferFunction :: (String, Type) -> Exp -> Typechecker Type
inferFunction arg@(_, argType) body =
  local (\ctx -> ctx { vars = arg : vars ctx }) $
    fmap (TFun argType) (inferType body)

-- | Infer the type of an expression and require it to equal the expected
-- type, which is then returned.
checkType :: Type -> Exp -> Typechecker Type
checkType t e = do
  t' <- inferType e
  if t' == t
    then return t
    else
      throwError $
        show e ++ " should have type " ++ show t
          ++ " but has type " ++ show t'

-- Disallow recursive references.

-- | True when no top-level definition refers (directly or transitively)
-- to itself.
terminates :: Program -> Bool
terminates = null . allRecursiveLoops

-- | Reference cycles reachable from any top-level definition. Every
-- definition is a root, because 'typecheck' infers all of them, not only
-- those reachable from @_main@. The result is built lazily, so testing it
-- for emptiness stops at the first cycle found.
allRecursiveLoops :: Program -> [[String]]
allRecursiveLoops p = concatMap (recursiveLoops p) (Map.keys p)

-- | Reference paths starting at the named definition that end by revisiting
-- a name already on the path, e.g. @["_main", "f", "g", "f"]@.
recursiveLoops :: Program -> String -> [[String]]
recursiveLoops p fname = go [fname] fname
  where
    go stack f = concatMap (follow stack) (callees f)
    follow stack g
      | g `elem` stack = [reverse (g : stack)]
      | otherwise      = go (g : stack) g
    -- A name with no definition has no outgoing references; inference
    -- later reports it as "no such function" instead of crashing here.
    callees f = maybe [] calls (Map.lookup f p)

-- | Names of the top-level definitions an expression refers to.
--
-- Count any reference to `f' as a call, since in general the analysis would
-- be undecidable. Every sub-term is visited, so references nested inside
-- non-'Exp' fields (if any) are counted as well.
calls :: Exp -> [String]
calls = everything (++) (mkQ [] global)
  where
    global (EGlobal f) = [f]
    global _           = []