shithub: MicroHs

Download patch

ref: 537d007a124feaedd0893a9e8fcf19ff10470a80
parent: b0ba3d89bb90cb2706143ab47e87bb6f06f93e4a
author: Lennart Augustsson <lennart.augustsson@epicgames.com>
date: Wed Nov 15 19:39:45 EST 2023

Implement defaulting.

--- a/README.md
+++ b/README.md
@@ -29,7 +29,6 @@
  * Type variables without a kind annotation are assumed to have kind `Type`.
  * There is no prefix negation.
  * There is no `Read` class.
- * There is no defaulting.
  * There is no deriving.
  * The `Prelude` has to be imported explicitly.
  * Polymorphic types are never inferred; use a type signature if you need it.
@@ -36,6 +35,7 @@
  * Always enabled extension:
    * ConstraintKinds
    * EmptyDataDecls
+   * ExtendedDefaultRules
    * FlexibleContexts
    * FlexibleInstance
    * ForeignFunctionInterface
--- a/ghc/Compat.hs
+++ b/ghc/Compat.hs
@@ -120,3 +120,11 @@
     rat3 f2      s  = f2 * expo s
 
     expo s = 10 ^ readInteger s
+
+partitionM :: Monad m => (a -> m Bool) -> [a] -> m ([a], [a])  -- split a list using a monadic predicate
+partitionM _ [] = return ([], [])  -- empty input: both result lists are empty
+partitionM p (x:xs) = do
+  b <- p x  -- the predicate's effect on the head runs before the tail is processed
+  (ts,fs) <- partitionM p xs  -- ts: tail elements where p gave True, fs: where it gave False
+  return $ if b then (x:ts, fs) else (ts, x:fs)  -- consing keeps the original element order
+
--- a/lib/Control/Monad.hs
+++ b/lib/Control/Monad.hs
@@ -74,6 +74,21 @@
 (>=>) :: forall (m :: Type -> Type) a b c . Monad m => (a -> m b) -> (b -> m c) -> (a -> m c)
 (>=>) = flip (<=<)
 
+filterM :: forall (m :: Type -> Type) a . Monad m => (a -> m Bool) -> [a] -> m [a]  -- m carries an explicit kind; unannotated type variables are assumed to have kind Type
+filterM _ [] = return []  -- empty input: nothing to keep
+filterM p (x:xs) = do
+  b <- p x  -- the predicate's effect on the head runs before the tail is processed
+  ts <- filterM p xs  -- kept elements of the tail
+  return $ if b then x : ts else ts  -- x is kept only when p x gave True
+
+partitionM :: forall (m :: Type -> Type) a . Monad m => (a -> m Bool) -> [a] -> m ([a], [a])  -- m carries an explicit kind; unannotated type variables are assumed to have kind Type
+partitionM _ [] = return ([], [])  -- empty input: both result lists are empty
+partitionM p (x:xs) = do
+  b <- p x  -- the predicate's effect on the head runs before the tail is processed
+  (ts,fs) <- partitionM p xs  -- ts: tail elements where p gave True, fs: where it gave False
+  return $ if b then (x:ts, fs) else (ts, x:fs)  -- consing keeps the original element order
+  
+
 {-
 -- Same for Maybe
 instance Functor Maybe where
--- a/src/MicroHs/Desugar.hs
+++ b/src/MicroHs/Desugar.hs
@@ -65,6 +65,7 @@
       in  (qualIdent mn $ mkClassConstructor c, lams xs $ Lam f $ apps (Var f) (map Var xs)) :
           zipWith (\ i x -> (expectQualified i, Lam f $ App (Var f) (lams xs $ Var x))) (supers ++ meths) xs
     Instance _ _ _ _ -> []
+    Default _ -> []
 
 oneAlt :: Expr -> EAlts
 oneAlt e = EAlts [([], e)] []
--- a/src/MicroHs/Expr.hs
+++ b/src/MicroHs/Expr.hs
@@ -66,6 +66,7 @@
   | Infix Fixity [Ident]
   | Class [EConstraint] LHS [FunDep] [EBind]  -- XXX will probable need initial forall with FD
   | Instance [IdKind] [EConstraint] EConstraint [EBind]  -- no deriving yet
+  | Default [EType]
   --Xderiving (Show)
 
 data ImportSpec = ImportSpec Bool Ident (Maybe Ident) (Maybe (Bool, [ImportItem]))  -- first Bool indicates 'qualified', second 'hiding'
@@ -491,6 +492,7 @@
       where f AssocLeft = "l"; f AssocRight = "r"; f AssocNone = ""
     Class sup lhs fds bs -> ppWhere (text "class" <+> ctx sup <+> ppLHS lhs <+> ppFunDeps fds) bs
     Instance vs ct ty bs -> ppWhere (text "instance" <+> ppForall vs <+> ctx ct <+> ppEType ty) bs
+    Default ts -> text "default" <+> parens (hsep (punctuate (text ", ") (map ppEType ts)))
  where ctx [] = empty
        ctx ts = ppEType (ETuple ts) <+> text "=>"
 
--- a/src/MicroHs/Interactive.hs
+++ b/src/MicroHs/Interactive.hs
@@ -27,7 +27,7 @@
 
 preamble :: String
 preamble = "module " ++ interactiveName ++ "(module " ++ interactiveName ++
-           ") where\nimport Prelude\n"
+           ") where\nimport Prelude\ndefault (Integer, Double)\n"
 
 start :: I ()
 start = do
--- a/src/MicroHs/Parse.hs
+++ b/src/MicroHs/Parse.hs
@@ -119,8 +119,8 @@
 
 keywords :: [String]
 keywords =
-  ["case", "class", "data", "do", "else", "forall", "foreign", "if", "import",
-   "in", "infix", "infixl", "infixr", "instance",
+  ["case", "class", "data", "default", "do", "else", "forall", "foreign", "if",
+   "import", "in", "infix", "infixl", "infixr", "instance",
    "let", "module", "newtype", "of", "primitive", "then", "type", "where"]
 
 pSpec :: Char -> P ()
@@ -262,6 +262,7 @@
   <|< Infix       <$> ((,) <$> pAssoc <*> pPrec) <*> esepBy1 pTypeOper (pSpec ',')
   <|< Class       <$> (pKeyword "class"    *> pContext) <*> pLHS <*> pFunDeps     <*> pWhere pClsBind
   <|< Instance    <$> (pKeyword "instance" *> pForall)  <*> pContext <*> pTypeApp <*> pWhere pClsBind
+  <|< Default     <$> (pKeyword "default"  *> pParens (esepBy pType (pSpec ',')))
   where
     pAssoc = (AssocLeft <$ pKeyword "infixl") <|< (AssocRight <$ pKeyword "infixr") <|< (AssocNone <$ pKeyword "infix")
     dig (TInt _ ii) | -2 <= i && i <= 9 = Just i  where i = _integerToInt ii
--- a/src/MicroHs/TypeCheck.hs
+++ b/src/MicroHs/TypeCheck.hs
@@ -25,7 +25,7 @@
 import MicroHs.Expr
 --Ximport Compat
 --Ximport GHC.Stack
---import Debug.Trace
+--Ximport Debug.Trace
 
 boolPrefix :: String
 boolPrefix = "Data.Bool_Type."
@@ -104,6 +104,7 @@
 type ClassTable = M.Map ClassInfo  -- maps a class identifier to its associated information
 type InstTable  = M.Map InstInfo   -- indexed by class name
 type Constraints= [(Ident, EConstraint)]
+type Defaults   = [EType]          -- Current defaults
 
 -- To make type checking fast it is essential to solve constraints fast.
 -- The naive implementation of InstInfo would be [InstDict], but
@@ -374,6 +375,7 @@
   ClassTable            -- class info, indexed by QIdent
   InstTable             -- instances
   Constraints           -- constraints that have to be solved
+  Defaults              -- current defaults
   --Xderiving (Show)
 
 data TCMode = TCExpr | TCPat | TCType
@@ -380,73 +382,81 @@
   --Xderiving (Show)
 
 typeTable :: TCState -> TypeTable
-typeTable (TC _ _ _ tt _ _ _ _ _ _ _ _) = tt
+typeTable (TC _ _ _ tt _ _ _ _ _ _ _ _ _) = tt
 
 valueTable :: TCState -> ValueTable
-valueTable (TC _ _ _ _ _ vt _ _ _ _ _ _) = vt
+valueTable (TC _ _ _ _ _ vt _ _ _ _ _ _ _) = vt
 
 synTable :: TCState -> SynTable
-synTable (TC _ _ _ _ st _ _ _ _ _ _ _) = st
+synTable (TC _ _ _ _ st _ _ _ _ _ _ _ _) = st
 
 fixTable :: TCState -> FixTable
-fixTable (TC _ _ ft _ _ _ _ _ _ _ _ _) = ft
+fixTable (TC _ _ ft _ _ _ _ _ _ _ _ _ _) = ft
 
 assocTable :: TCState -> AssocTable
-assocTable (TC _ _ _ _ _ _ ast _ _ _ _ _) = ast
+assocTable (TC _ _ _ _ _ _ ast _ _ _ _ _ _) = ast
 
 uvarSubst :: TCState -> IM.IntMap EType
-uvarSubst (TC _ _ _ _ _ _ _ sub _ _ _ _) = sub
+uvarSubst (TC _ _ _ _ _ _ _ sub _ _ _ _ _) = sub
 
 moduleName :: TCState -> IdentModule
-moduleName (TC mn _ _ _ _ _ _ _ _ _ _ _) = mn
+moduleName (TC mn _ _ _ _ _ _ _ _ _ _ _ _) = mn
 
 classTable :: TCState -> ClassTable
-classTable (TC _ _ _ _ _ _ _ _ _ ct _ _) = ct
+classTable (TC _ _ _ _ _ _ _ _ _ ct _ _ _) = ct
 
 tcMode :: TCState -> TCMode
-tcMode (TC _ _ _ _ _ _ _ _ m _ _ _) = m
+tcMode (TC _ _ _ _ _ _ _ _ m _ _ _ _) = m
 
 instTable :: TCState -> InstTable
-instTable (TC _ _ _ _ _ _ _ _ _ _ is _) = is
+instTable (TC _ _ _ _ _ _ _ _ _ _ is _ _) = is
 
 constraints :: TCState -> Constraints
-constraints (TC _ _ _ _ _ _ _ _ _ _ _ e) = e
+constraints (TC _ _ _ _ _ _ _ _ _ _ _ e _) = e
 
+defaults :: TCState -> Defaults  -- selector for the current defaults, the last (13th) field of TC
+defaults (TC _ _ _ _ _ _ _ _ _ _ _ _ ds) = ds
+
 putValueTable :: ValueTable -> T ()
 putValueTable venv = do
-  TC mn n fx tenv senv _ ast sub m cs is es <- get
-  put (TC mn n fx tenv senv venv ast sub m cs is es)
+  TC mn n fx tenv senv _ ast sub m cs is es ds <- get
+  put (TC mn n fx tenv senv venv ast sub m cs is es ds)
 
 putTypeTable :: TypeTable -> T ()
 putTypeTable tenv = do
-  TC mn n fx _ senv venv ast sub m cs is es <- get
-  put (TC mn n fx tenv senv venv ast sub m cs is es)
+  TC mn n fx _ senv venv ast sub m cs is es ds <- get
+  put (TC mn n fx tenv senv venv ast sub m cs is es ds)
 
 putSynTable :: SynTable -> T ()
 putSynTable senv = do
-  TC mn n fx tenv _ venv ast sub m cs is es <- get
-  put (TC mn n fx tenv senv venv ast sub m cs is es)
+  TC mn n fx tenv _ venv ast sub m cs is es ds <- get
+  put (TC mn n fx tenv senv venv ast sub m cs is es ds)
 
 putUvarSubst :: IM.IntMap EType -> T ()
 putUvarSubst sub = do
-  TC mn n fx tenv senv venv ast _ m cs is es <- get
-  put (TC mn n fx tenv senv venv ast sub m cs is es)
+  TC mn n fx tenv senv venv ast _ m cs is es ds <- get
+  put (TC mn n fx tenv senv venv ast sub m cs is es ds)
 
 putTCMode :: TCMode -> T ()
 putTCMode m = do
-  TC mn n fx tenv senv venv ast sub _ cs is es <- get
-  put (TC mn n fx tenv senv venv ast sub m cs is es)
+  TC mn n fx tenv senv venv ast sub _ cs is es ds <- get
+  put (TC mn n fx tenv senv venv ast sub m cs is es ds)
 
 putInstTable :: InstTable -> T ()
 putInstTable is = do
-  TC mn n fx tenv senv venv ast sub m cs _ es <- get
-  put (TC mn n fx tenv senv venv ast sub m cs is es)
+  TC mn n fx tenv senv venv ast sub m cs _ es ds <- get
+  put (TC mn n fx tenv senv venv ast sub m cs is es ds)
 
 putConstraints :: Constraints -> T ()
 putConstraints es = do
-  TC mn n fx tenv senv venv ast sub m cs is _ <- get
-  put (TC mn n fx tenv senv venv ast sub m cs is es)
+  TC mn n fx tenv senv venv ast sub m cs is _ ds <- get
+  put (TC mn n fx tenv senv venv ast sub m cs is es ds)
 
+putDefaults :: Defaults -> T ()  -- replace the defaults list stored in the type checker state
+putDefaults ds = do
+  TC mn n fx tenv senv venv ast sub m cs is es _ <- get  -- the previous defaults are discarded
+  put (TC mn n fx tenv senv venv ast sub m cs is es ds)  -- every other field is written back unchanged
+
 withTCMode :: forall a . TCMode -> T a -> T a
 withTCMode m ta = do
   om <- gets tcMode
@@ -458,25 +468,25 @@
 -- Use the type table as the value table, and the primKind table as the type table.
 withTypeTable :: forall a . T a -> T a
 withTypeTable ta = do
-  TC mn n fx tt st vt ast sub m cs is es <- get
-  put (TC mn n fx primKindTable st tt ast sub m cs is es)
+  TC mn n fx tt st vt ast sub m cs is es ds <- get
+  put (TC mn n fx primKindTable st tt ast sub m cs is es ds)
   a <- ta
   -- Discard kind table, it will not have changed
-  TC mnr nr fxr _kr str ttr astr subr mr csr isr esr <- get
+  TC mnr nr fxr _kr str ttr astr subr mr csr isr esr dsr <- get
   -- Keep everyting, except that the returned value table
   -- becomes the type tables, and the old type table is restored.
-  put (TC mnr nr fxr ttr str vt astr subr mr csr isr esr)
+  put (TC mnr nr fxr ttr str vt astr subr mr csr isr esr dsr)
   return a
 
 addAssocTable :: Ident -> [Ident] -> T ()
 addAssocTable i ids = do
-  TC mn n fx tt st vt ast sub m cs is es <- get
-  put $ TC mn n fx tt st vt (M.insert i ids ast) sub m cs is es
+  TC mn n fx tt st vt ast sub m cs is es ds <- get
+  put $ TC mn n fx tt st vt (M.insert i ids ast) sub m cs is es ds
 
 addClassTable :: Ident -> ClassInfo -> T ()
 addClassTable i x = do
-  TC mn n fx tt st vt ast sub m cs is es <- get
-  put $ TC mn n fx tt st vt ast sub m (M.insert i x cs) is es
+  TC mn n fx tt st vt ast sub m cs is es ds <- get
+  put $ TC mn n fx tt st vt ast sub m (M.insert i x cs) is es ds
 
 addInstTable :: [InstDictC] -> T ()
 addInstTable ics = do
@@ -504,8 +514,8 @@
 addConstraint d ctx = do
 --  traceM $ "addConstraint: " ++ msg ++ " " ++ showIdent d ++ " :: " ++ showEType ctx
   ctx' <- expandSyn ctx
-  TC mn n fx tt st vt ast sub m cs is es <- get
-  put $ TC mn n fx tt st vt ast sub m cs is ((d, ctx') : es)
+  TC mn n fx tt st vt ast sub m cs is es ds <- get
+  put $ TC mn n fx tt st vt ast sub m cs is ((d, ctx') : es) ds
 
 withDict :: forall a . Ident -> EConstraint -> T a -> T a
 withDict i c ta = do
@@ -522,7 +532,7 @@
   let
     xts = foldr (uncurry stInsertGlb) ts primTypes
     xvs = foldr (uncurry stInsertGlb) vs primValues
-  in TC mn 1 fs xts ss xvs as IM.empty TCExpr cs is []
+  in TC mn 1 fs xts ss xvs as IM.empty TCExpr cs is [] []
 
 kTypeS :: EType
 kTypeS = kType
@@ -645,8 +655,8 @@
 
 setUVar :: TRef -> EType -> T ()
 setUVar i t = do
-  TC mn n fx tenv senv venv ast sub m cs is es <- get
-  put (TC mn n fx tenv senv venv ast (IM.insert i t sub) m cs is es)
+  TC mn n fx tenv senv venv ast sub m cs is es ds <- get
+  put (TC mn n fx tenv senv venv ast (IM.insert i t sub) m cs is es ds)
 
 getUVar :: Int -> T (Maybe EType)
 getUVar i = gets (IM.lookup i . uvarSubst)
@@ -753,8 +763,8 @@
 -- Reset unification map
 tcReset :: T ()
 tcReset = do
-  TC mn u fx tenv senv venv ast _ m cs is es <- get
-  put (TC mn u fx tenv senv venv ast IM.empty m cs is es)
+  TC mn u fx tenv senv venv ast _ m cs is es ds <- get
+  put (TC mn u fx tenv senv venv ast IM.empty m cs is es ds)
 
 newUVar :: T EType
 newUVar = EUVar <$> newUniq
@@ -763,9 +773,9 @@
 
 newUniq :: T TRef
 newUniq = do
-  TC mn n fx tenv senv venv ast sub m cs is es <- get
+  TC mn n fx tenv senv venv ast sub m cs is es ds <- get
   let n' = n+1
-  put (seq n' $ TC mn n' fx tenv senv venv ast sub m cs is es)
+  put (seq n' $ TC mn n' fx tenv senv venv ast sub m cs is es ds)
   return n
 
 newIdent :: SLoc -> String -> T Ident
@@ -868,8 +878,8 @@
 
 extFix :: Ident -> Fixity -> T ()
 extFix i fx = do
-  TC mn n fenv tenv senv venv ast sub m cs is es <- get
-  put $ TC mn n (M.insert i fx fenv) tenv senv venv ast sub m cs is es
+  TC mn n fenv tenv senv venv ast sub m cs is es ds <- get
+  put $ TC mn n (M.insert i fx fenv) tenv senv venv ast sub m cs is es ds
   return ()
 
 withExtVal :: forall a . --XHasCallStack =>
@@ -906,8 +916,13 @@
   mapM_ addTypeSyn dst
   dst' <- tcExpand dst
 --  traceM (showEDefs dst')
+  setDefault dst'
   tcDefsValue dst'
 
+setDefault :: [EDef] -> T ()  -- install the types of the module's 'default' declaration as the current defaults
+setDefault defs =
+  putDefaults $ last $ [] : [ ts | Default ts <- defs ]  -- the last Default declaration wins; the leading [] keeps 'last' total and yields no defaults when there is no declaration
+
 tcAddInfix :: EDef -> T ()
 tcAddInfix (Infix fx is) = do
   mn <- gets moduleName
@@ -1015,6 +1030,7 @@
     ForImp  ie i            t   ->                ForImp ie i   <$> tCheckTypeT kType t
     Class   ctx lhs@(_, iks) fds ms -> withVars iks $ Class     <$> tcCtx ctx <*> return lhs <*> mapM tcFD fds <*> mapM tcMethod ms
     Instance iks ctx c m        -> withVars iks $ Instance iks  <$> tcCtx ctx <*> tCheckTypeT kConstraint c <*> return m
+    Default ts                  ->                Default       <$> mapM (tCheckTypeT kType) ts
     _                           -> return d
  where
    tcCtx = mapM (tCheckTypeT kConstraint)
@@ -1238,9 +1254,8 @@
 --      traceM $ "tcDefValue: " ++ showIdent i ++ " :: " ++ showExpr tt
 --      traceM $ "tcDefValue: def=" ++ showEDefs [adef]
       mn <- gets moduleName
-      teqns <- tcEqns tt eqns
+      teqns <- tcEqns True tt eqns
 --      traceM ("tcDefValue: after " ++ showEDefs [adef, Fcn i teqns])
-      -- Defaulting should be done here
       checkConstraints
       return $ Fcn (qualIdent mn i) teqns
     ForImp ie i t -> do
@@ -1644,26 +1659,26 @@
 tcExprLam :: Expected -> [Eqn] -> T Expr
 tcExprLam mt qs = do
   t <- tGetExpType mt
-  ELam <$> tcEqns t qs
+  ELam <$> tcEqns False t qs
 
-tcEqns :: EType -> [Eqn] -> T [Eqn]
+tcEqns :: Bool -> EType -> [Eqn] -> T [Eqn]
 --tcEqns t eqns | trace ("tcEqns: " ++ showEBind (BFcn dummyIdent eqns) ++ " :: " ++ showEType t) False = undefined
-tcEqns (EForall iks t) eqns = withExtTyps iks $ tcEqns t eqns
-tcEqns t eqns | Just (ctx, t') <- getImplies t = do
+tcEqns top (EForall iks t) eqns = withExtTyps iks $ tcEqns top t eqns
+tcEqns top t eqns | Just (ctx, t') <- getImplies t = do
   let loc = getSLoc eqns
   d <- newIdent loc "adict"
   f <- newIdent loc "fcnD"
   withDict d ctx $ do
-    eqns' <- tcEqns t' eqns
+    eqns' <- tcEqns top t' eqns
     let eqn =
           case eqns' of
             [Eqn [] alts] -> Eqn [EVar d] alts
             _             -> Eqn [EVar d] $ EAlts [([], EVar f)] [BFcn f eqns']
     return [eqn]
-tcEqns t eqns = do
+tcEqns top t eqns = do
   let loc = getSLoc eqns
   f <- newIdent loc "fcnS"
-  (eqns', ds) <- solveLocalConstraints $ mapM (tcEqn t) eqns
+  (eqns', ds) <- solveAndDefault top $ mapM (tcEqn t) eqns
   case ds of
     [] -> return eqns'
     _  -> do
@@ -1804,7 +1819,7 @@
   case abind of
     BFcn i eqns -> do
       (_, tt) <- tLookupV i
-      teqns <- tcEqns tt eqns
+      teqns <- tcEqns False tt eqns
       return $ BFcn i teqns
     BPat p a -> do
       (ep, tp) <- withTCMode TCPat $ tInferExpr p  -- pattern variables already bound
@@ -2036,9 +2051,11 @@
 
 ---------------------------------
 
+type Solved = (Ident, Expr)  -- element type of the lists returned by solveLocalConstraints and solveAndDefault
+
 -- Solve constraints generated locally in 'ta'.
 -- Keep any unsolved ones for later.
-solveLocalConstraints :: forall a . T a -> T (a, [(Ident, Expr)])
+solveLocalConstraints :: forall a . T a -> T (a, [Solved])
 solveLocalConstraints ta = do
   cs <- gets constraints           -- old constraints
   putConstraints []                -- start empty
@@ -2047,6 +2064,42 @@
   un <- gets constraints           -- get remaining unsolved
   putConstraints (un ++ cs)        -- put back unsolved and old constraints
   return (a, ds)
+
+solveAndDefault :: forall a . Bool -> T a -> T (a, [Solved])  -- the Bool selects whether defaulting is attempted
+solveAndDefault False ta = solveLocalConstraints ta  -- no defaulting: behave exactly like solveLocalConstraints
+solveAndDefault True  ta = do
+  a <- ta
+  ds <- solveConstraints
+  cs <- gets constraints              -- whatever is still in the constraint list after solving
+  vs <- getMetaTyVars (map snd cs)    -- These are the type variables that need defaulting
+--  traceM $ "solveAndDefault" ++ show vs
+  -- XXX may have to iterate this with fundeps
+  ds' <- concat <$> mapM defaultOneTyVar vs  -- default the variables one at a time, in the order of vs
+  return (a, ds ++ ds')               -- results of plain solving first, then those produced by defaulting
+
+constraintHasTyVar :: TRef -> (Ident, EConstraint) -> T Bool  -- True when tv is among getMetaTyVars of the constraint's type
+constraintHasTyVar tv (_, t) = elem tv <$> getMetaTyVars [t]  -- the Ident component is ignored
+
+defaultOneTyVar :: TRef -> T [Solved]  -- try each current default type for tv; returns [] when none works
+defaultOneTyVar tv = do
+  old <- get             -- snapshot of the entire state, restored whenever a candidate fails
+  -- split constraints into those with the current tyvar and those without
+  (ourcs, othercs) <- partitionM (constraintHasTyVar tv) (constraints old)
+  let tryDefaults [] = return []  -- candidates exhausted: nothing solved, state is the same as 'old'
+      tryDefaults (ty:tys) = do
+        setUVar tv ty  -- tentatively bind the type variable to this candidate
+        putConstraints ourcs  -- solve only the constraints that involve tv
+        ds <- solveConstraints
+        rcs <- gets constraints  -- constraints left over after this attempt
+        if null rcs then do
+          -- Success, the type variable is gone
+          putConstraints othercs   -- put back the other constraints
+          return ds
+         else do
+          -- Not solved, try with the next type
+          put old            -- restore solver state (this also undoes the setUVar above)
+          tryDefaults tys    -- and try with next type
+  tryDefaults (defaults old)  -- candidates are the current defaults, tried in list order
 
 {-
 showInstInfo :: InstInfo -> String
--- /dev/null
+++ b/tests/Default.hs
@@ -1,0 +1,9 @@
+module Default(main) where
+import Prelude
+default (Int, Double)  -- candidate types used when defaulting an ambiguous type variable
+
+main :: IO ()
+main = do
+  print 1  -- expected output: 1
+  print 1.2  -- expected output: 1.2
+  print []   -- the element type defaults to Int, a little weird; expected output: []
--- /dev/null
+++ b/tests/Default.ref
@@ -1,0 +1,3 @@
+1
+1.2
+[]
--- a/tests/Makefile
+++ b/tests/Makefile
@@ -27,6 +27,7 @@
 	$(MHS) Class      && $(EVAL) > Class.out      && diff Class.ref Class.out
 	$(MHS) Eq         && $(EVAL) > Eq.out         && diff Eq.ref Eq.out
 	$(MHS) Floating   && $(EVAL) > Floating.out   && diff Floating.ref Floating.out
+	$(MHS) Default    && $(EVAL) > Default.out    && diff Default.ref Default.out
 
 errtest:
 	sh errtester.sh < errmsg.test
--