shithub: MicroHs

Download patch

ref: be0d3b7805bfb1948fda52303e8ab74be5ce9000
parent: dd4eadf2c105c53baed21ae3e800721c502f98d7
author: Lennart Augustsson <lennart.augustsson@epicgames.com>
date: Mon Sep 25 09:38:07 EDT 2023

Add some extensions: Bool literals, if-expressions, pattern lambdas, and type applications.

--- a/src/MicroHs/paper/BasicTypes.hs
+++ b/src/MicroHs/paper/BasicTypes.hs
@@ -20,18 +20,27 @@
 -----------------------------------
 -- Examples below
 data Term = Var Name -- x
-  | Lit Int -- 3
+  | LitI Int -- 3
+  | LitB Bool -- True
   | App Term Term -- f x
   | Lam Name Term -- \ x -> x
   | ALam Name Sigma Term -- \ x -> x
   | Let Name Term Term -- let x = f y in x+1
   | Ann Term Sigma -- (f x) :: Int
+  | If Term Term Term -- if b then t else e
+  | PLam Pat Term -- \ (C x) -> x   (lambda over a pattern)

 atomicTerm :: Term -> Bool
 atomicTerm (Var _) = True
-atomicTerm (Lit _) = True
+atomicTerm (LitI _) = True
+atomicTerm (LitB _) = True
 atomicTerm _ = False

+data Pat = PVar Name -- x
+  | PWild -- _
+  | PAnn Pat Sigma -- (p :: sigma)
+  | PCon Name [Pat] -- C p1 .. pn
+
 -----------------------------------
 -- Types --
 -----------------------------------
@@ -44,6 +53,7 @@
   | TyCon TyCon -- Type constants
   | TyVar TyVar -- Always bound by a ForAll
   | MetaTv MetaTv -- A meta type variable
+  | TyApp Type Type -- Type application, e.g. (T a)

 data TyVar
   = BoundTv String -- A type variable bound by a ForAll
@@ -65,16 +75,16 @@

 type Uniq = Int

-data TyCon = IntT | BoolT
-  deriving( Eq )
+type TyCon = String -- Type constructor name, e.g. "Int"

 ---------------------------------
 -- Constructors
 (-->) :: Sigma -> Sigma -> Sigma
 arg --> res = Fun arg res
+
 intType, boolType :: Tau
-intType = TyCon IntT
-boolType = TyCon BoolT
+intType = TyCon "Int"
+boolType = TyCon "Bool"

 ---------------------------------
 -- Free and bound variables
@@ -89,6 +99,7 @@
     go (TyCon _) acc = acc
     go (Fun arg res) acc = go arg (go res acc)
     go (ForAll _ ty) acc = go ty acc -- ForAll binds TyVars only
+    go (TyApp fun arg) acc = go fun (go arg acc)

 freeTyVars :: [Type] -> [TyVar]
 -- Get the free TyVars from a type; no duplicates in result
@@ -102,10 +113,11 @@
       | tv `elem` bound = acc
       | tv `elem` acc = acc
       | otherwise = tv : acc
-    go bound (MetaTv _) acc = acc
-    go bound (TyCon _) acc = acc
+    go _bound (MetaTv _) acc = acc
+    go _bound (TyCon _) acc = acc
     go bound (Fun arg res) acc = go bound arg (go bound res acc)
     go bound (ForAll tvs ty) acc = go (tvs ++ bound) ty acc
+    go bound (TyApp fun arg) acc = go bound fun (go bound arg acc)

 tyVarBndrs :: Rho -> [TyVar]
 -- Get all the binders used in ForAlls in the type, so that
@@ -134,11 +146,12 @@
 subst_ty :: Env -> Type -> Type
 subst_ty env (Fun arg res) = Fun (subst_ty env arg) (subst_ty env res)
 subst_ty env (TyVar n) = fromMaybe (TyVar n) (lookup n env)
-subst_ty env (MetaTv tv) = MetaTv tv
-subst_ty env (TyCon tc) = TyCon tc
+subst_ty _env (MetaTv tv) = MetaTv tv
+subst_ty _env (TyCon tc) = TyCon tc
 subst_ty env (ForAll ns rho) = ForAll ns (subst_ty env' rho)
   where
     env' = [(n,ty') | (n,ty') <- env, not (n `elem` ns)]
+subst_ty env (TyApp fun arg) = TyApp (subst_ty env fun) (subst_ty env arg)

 -----------------------------------
 -- Pretty printing class --
@@ -156,7 +169,8 @@
 -------------- Pretty-printing terms ---------------------
 instance Outputable Term where
   ppr (Var n) = pprName n
-  ppr (Lit i) = int i
+  ppr (LitI i) = int i
+  ppr (LitB b) = text (show b)
   ppr (App e1 e2) = pprApp (App e1 e2)
   ppr (Lam v e) = sep [char '\\' <> pprName v <> text ".", ppr e]
   ppr (ALam v t e) = sep [char '\\' <> parens (pprName v <> dcolon <> ppr t)
@@ -166,10 +180,18 @@
                            text "in",
                            ppr b]
   ppr (Ann e ty) = pprParendTerm e <+> dcolon <+> pprParendType ty
+  ppr (If e1 e2 e3) = parens $ text "if" <+> ppr e1 <+> text "then" <+> ppr e2 <+> text "else" <+> ppr e3
+  ppr (PLam p e) = sep [char '\\' <> ppr p <> text ".", ppr e]

 instance Show Term where
   show t = docToString (ppr t)

+instance Outputable Pat where
+  ppr (PVar n) = pprName n
+  ppr PWild = text "_"
+  ppr (PAnn p t) = parens $ ppr p <+> dcolon <+> ppr t
+  ppr (PCon c ps) = parens $ pprName c <+> hsep (map ppr ps)
+
 pprParendTerm :: Term -> Doc
 pprParendTerm e | atomicTerm e = ppr e
                 | otherwise = parens (ppr e)
@@ -199,11 +221,13 @@

 type Precedence = Int

-topPrec, arrPrec, tcPrec, atomicPrec :: Precedence
+topPrec, arrPrec, tcPrec, appPrec, atomicPrec :: Precedence
 topPrec = 0 -- Top-level precedence
 arrPrec = 1 -- Precedence of (a->b)
 tcPrec = 2 -- Precedence of (T a b)
-atomicPrec = 3 -- Precedence of t
+appPrec = 3 -- Precedence of a type application (t1 t2)
+atomicPrec = 4 -- Precedence of t

 precType :: Type -> Precedence
+precType (TyApp _ _) = appPrec
 precType (ForAll _ _) = topPrec
@@ -227,7 +251,9 @@
 ppr_type (TyCon tc) = ppr_tc tc
 ppr_type (TyVar n) = ppr n
 ppr_type (MetaTv tv) = ppr tv
+-- Application associates to the left (T a b = (T a) b), so the argument, not the
+-- function, is printed at the tighter precedence appPrec.
+ppr_type (TyApp fun arg) = pprType (appPrec-1) fun <+> pprType appPrec arg

 ppr_tc :: TyCon -> Doc
-ppr_tc IntT = text "Int"
-ppr_tc BoolT = text "Bool"
+ppr_tc s = text s
--- /dev/null
+++ b/src/MicroHs/paper/Main.hs
@@ -1,0 +1,60 @@
+import BasicTypes
+import TcTerm
+import TcMonad
+
+-- Typing environment shared by all the examples below.
+env :: [(Name,Sigma)]
+env =
+  [("not",  boolType --> boolType)
+  ,("C",    (ForAll [tvx] (tx --> tx)) --> TyCon "T")
+  ,("pair", ForAll [tvx,tvy] (tx --> ty --> TyApp (TyApp (TyCon "Pair") tx) ty))
+  ]
+
+tvx, tvy :: TyVar
+tvx = BoundTv "x"
+tvy = BoundTv "y"
+tx, ty :: Type
+tx = TyVar tvx
+ty = TyVar tvy
+
+-- Type check a term in 'env'.  A type error is thrown as an IOError (userError),
+-- the idiomatic way to fail inside IO.
+tc :: Term -> IO Sigma
+tc e = do
+  et <- runTc env (typecheck e)
+  case et of
+    Left msg -> ioError (userError (docToString msg))
+    Right t  -> return t
+
+-- Pretty-print anything with an Outputable instance on its own line.
+pp :: Outputable a => a -> IO ()
+pp = putStrLn . docToString . ppr
+
+-- Print a term followed by the type that 'tc' computes for it.
+tcpp :: Term -> IO ()
+tcpp e = do
+  pp e
+  t <- tc e
+  pp t
+
+main :: IO ()
+main = mapM_ tcpp [e1, e2, e3, e4]
+
+_tv :: String -> Type
+_tv = TyVar . BoundTv
+
+-- \ x -> x
+e1 :: Term
+e1 = Lam "x" $ Var "x"
+
+-- (\ x -> x) :: forall x. x -> x
+e2 :: Term
+e2 = Ann (Lam "x" $ Var "x") (ForAll [tvx] (tx --> tx))
+
+-- \ b -> if not b then e1 else e2   (unannotated vs. annotated branch)
+e3 :: Term
+e3 = Lam "b" $ If (App (Var "not") (Var "b")) e1 e2
+
+-- \ (C f) -> pair (f 1) (f True); C's field has type forall x. x -> x in 'env'
+e4 :: Term
+e4 = PLam (PCon "C" [PVar "f"]) $ App (App (Var "pair") (App (Var "f") (LitI 1))) (App (Var "f") (LitB True))
--- a/src/MicroHs/paper/TcMonad.hs
+++ b/src/MicroHs/paper/TcMonad.hs
@@ -2,7 +2,7 @@
   Tc, -- The monad type constructor
   runTc, ErrMsg, lift, check,
   -- Environment manipulation
-  extendVarEnv, lookupVar,
+  extendVarEnv, extendVarEnvList, lookupVar,
   getEnvTypes, getFreeTyVars, getMetaTyVars,
   -- Types and unification
   newTyVarTy,
@@ -16,7 +16,7 @@
 import qualified Data.Map as Map
 import Text.PrettyPrint.HughesPJ
 import Data.IORef
-import Data.List( nub, (\\) )
+import Data.List( (\\) )
 ------------------------------------------
 -- The monad itself --
 ------------------------------------------
@@ -87,6 +87,14 @@
   where
     extend env = env { var_env = Map.insert var ty (var_env env) }

+extendVarEnvList :: [(Name, Sigma)] -> Tc a -> Tc a
+-- Run the computation with all the given bindings added to the environment.
+-- They shadow existing entries; if a name is repeated, its first occurrence wins.
+extendVarEnvList varTys (Tc m)
+  = Tc (\env -> m (extend env))
+  where
+    extend env = env { var_env = foldr (uncurry Map.insert) (var_env env) varTys }
+
 getEnv :: Tc (Map.Map Name Sigma)
 getEnv = Tc (\ env -> return (Right (var_env env)))

@@ -204,6 +212,9 @@
            Just ty -> do { ty' <- zonkType ty
                          ; writeTv tv ty' -- "Short out" multiple hops
                          ; return ty' } }
+zonkType (TyApp fun arg) = do { fun' <- zonkType fun
+                              ; arg' <- zonkType arg
+                              ; return (TyApp fun' arg') }

 ------------------------------------------
 -- Unification --
@@ -223,6 +234,9 @@
 unify (TyCon tc1) (TyCon tc2)
   | tc1 == tc2
   = return ()
+unify (TyApp fun1 arg1)
+      (TyApp fun2 arg2)
+  = do { unify fun1 fun2; unify arg1 arg2 }
 unify ty1 ty2 = failTc (text "Cannot unify types:" <+> vcat [ppr ty1, ppr ty2])

 -----------------------------------------
--- a/src/MicroHs/paper/TcTerm.hs
+++ b/src/MicroHs/paper/TcTerm.hs
@@ -33,8 +33,10 @@
 tcRho :: Term -> Expected Rho -> Tc ()
 -- Invariant: if the second argument is (Check rho),
 -- then rho is in weak-prenex form
-tcRho (Lit _) exp_ty
+tcRho (LitI _) exp_ty
   = instSigma intType exp_ty
+tcRho (LitB _) exp_ty
+  = instSigma boolType exp_ty
 tcRho (Var v) exp_ty
   = do { v_sigma <- lookupVar v
        ; instSigma v_sigma exp_ty }
@@ -63,6 +65,74 @@
 tcRho (Ann body ann_ty) exp_ty
   = do { checkSigma body ann_ty
        ; instSigma ann_ty exp_ty }
+tcRho (If e1 e2 e3) (Check exp_ty)  -- exp_ty is a rho (see invariant), so checkRho suffices
+  = do { checkRho e1 boolType
+       ; checkRho e2 exp_ty
+       ; checkRho e3 exp_ty }
+tcRho (If e1 e2 e3) (Infer ref)
+  = do { checkRho e1 boolType
+       ; rho1 <- inferRho e2
+       ; rho2 <- inferRho e3
+       ; subsCheck rho1 rho2  -- each branch type must subsume the other
+       ; subsCheck rho2 rho1
+       ; writeTcRef ref rho1 }
+tcRho (PLam pat body) (Infer ref)
+  = do { (binds, pat_ty) <- inferPat pat
+       ; body_ty <- extendVarEnvList binds (inferRho body)
+       ; writeTcRef ref (pat_ty --> body_ty) }
+tcRho (PLam pat body) (Check ty)
+  = do { (arg_ty, res_ty) <- unifyFun ty
+       ; binds <- checkPat pat arg_ty
+       ; extendVarEnvList binds (checkRho body res_ty) }
+
+tcPat :: Pat -> Expected Sigma -> Tc [(Name,Sigma)]
+-- Check or infer the type of a pattern; returns the variables it binds with their types
+tcPat PWild _exp_ty = return []
+tcPat (PVar v) (Infer ref) = do { ty <- newTyVarTy
+                                ; writeTcRef ref ty
+                                ; return [(v,ty)] }
+tcPat (PVar v) (Check ty) = return [(v, ty)]
+tcPat (PAnn p pat_ty) exp_ty = do { binds <- checkPat p pat_ty
+                                  ; instPatSigma pat_ty exp_ty
+                                  ; return binds }
+tcPat (PCon con ps) exp_ty
+  = do { (arg_tys, res_ty) <- instDataCon con
+       -- zip would silently drop surplus patterns or field types, so check arity first
+       ; check (length ps == length arg_tys)
+               (text "Constructor" <+> pprName con <+> text "expects"
+                <+> int (length arg_tys) <+> text "argument(s), but the pattern has"
+                <+> int (length ps))
+       ; envs <- mapM (uncurry checkPat) (ps `zip` arg_tys)
+       ; instPatSigma res_ty exp_ty
+       ; return (concat envs) }
+
+instPatSigma :: Sigma -> Expected Sigma -> Tc ()
+-- Relate a pattern's type to the expected one: record it when inferring, else subsCheck exp_ty against it
+instPatSigma pat_ty (Infer ref) = writeTcRef ref pat_ty
+instPatSigma pat_ty (Check exp_ty) = subsCheck exp_ty pat_ty
+
+checkPat :: Pat -> Sigma -> Tc [(Name, Sigma)]
+checkPat p exp_ty = tcPat p (Check exp_ty)
+
+inferPat :: Pat -> Tc ([(Name, Sigma)], Sigma)
+inferPat pat
+  = do { ref <- newTcRef (error "inferPat: empty result")
+       ; binds <- tcPat pat (Infer ref)
+       ; ty <- readTcRef ref
+       ; return (binds, ty) }
+
+instDataCon :: Name -> Tc ([Sigma], Tau)
+-- Look up a constructor's type, instantiate it, and split it into field types and result type
+instDataCon c = do
+  v_sigma <- lookupVar c
+  v_sigma' <- instantiate v_sigma
+  return (argsAndRes v_sigma')
+
+argsAndRes :: Rho -> ([Sigma], Tau)
+-- Split a chain of function types into its argument types and final result
+argsAndRes (Fun arg_ty res_ty) = (arg_ty : arg_tys, res_ty')
+  where (arg_tys, res_ty') = argsAndRes res_ty
+argsAndRes t = ([], t)

 ------------------------------------------
 -- inferSigma and checkSigma
--