shithub: MicroHs

Download patch

ref: 47a5fa14089fdccb2e1ab87f803e40f9e166be02
parent: 729bd2ed896899acc1fafbce4d3a69bf28774d78
author: Lennart Augustsson <lennart.augustsson@epicgames.com>
date: Sat Dec 2 11:04:05 EST 2023

Prepare for quantified constraints

--- a/src/MicroHs/Desugar.hs
+++ b/src/MicroHs/Desugar.hs
@@ -64,7 +64,7 @@
           xs = [ mkIdent ("$x" ++ show j) | j <- [ 1 .. length ctx + length meths ] ]
       in  (qualIdent mn $ mkClassConstructor c, lams xs $ Lam f $ apps (Var f) (map Var xs)) :
           zipWith (\ i x -> (i, Lam f $ App (Var f) (lams xs $ Var x))) (supers ++ meths) xs
-    Instance _ _ _ _ -> []
+    Instance _ _ -> []
     Default _ -> []
 
 oneAlt :: Expr -> EAlts
--- a/src/MicroHs/Expr.hs
+++ b/src/MicroHs/Expr.hs
@@ -64,7 +64,7 @@
   | ForImp String Ident EType
   | Infix Fixity [Ident]
   | Class [EConstraint] LHS [FunDep] [EBind]  -- XXX will probable need initial forall with FD
-  | Instance [IdKind] [EConstraint] EConstraint [EBind]  -- no deriving yet
+  | Instance EConstraint [EBind]  -- no deriving yet
   | Default [EType]
   deriving (Show)
 
@@ -500,7 +500,7 @@
     Infix (a, p) is -> text ("infix" ++ f a) <+> text (show p) <+> hsep (punctuate (text ", ") (map ppIdent is))
       where f AssocLeft = "l"; f AssocRight = "r"; f AssocNone = ""
     Class sup lhs fds bs -> ppWhere (text "class" <+> ppCtx sup <+> ppLHS lhs <+> ppFunDeps fds) bs
-    Instance vs ct ty bs -> ppWhere (text "instance" <+> ppForall vs <+> ppCtx ct <+> ppEType ty) bs
+    Instance ct bs -> ppWhere (text "instance" <+> ppEType ct) bs
     Default ts -> text "default" <+> parens (hsep (punctuate (text ", ") (map ppEType ts)))
 
 ppCtx :: [EConstraint] -> Doc
--- a/src/MicroHs/Parse.hs
+++ b/src/MicroHs/Parse.hs
@@ -261,7 +261,7 @@
   <|< ForImp      <$> (pKeyword "foreign" *> pKeyword "import" *> pKeyword "ccall" *> pString) <*> pLIdent <*> (pSymbol "::" *> pType)
   <|< Infix       <$> ((,) <$> pAssoc <*> pPrec) <*> esepBy1 pTypeOper (pSpec ',')
   <|< Class       <$> (pKeyword "class"    *> pContext) <*> pLHS <*> pFunDeps     <*> pWhere pClsBind
-  <|< Instance    <$> (pKeyword "instance" *> pForall)  <*> pContext <*> pTypeApp <*> pWhere pClsBind
+  <|< Instance    <$> (pKeyword "instance" *> pType) <*> pWhere pClsBind
   <|< Default     <$> (pKeyword "default"  *> pParens (esepBy pType (pSpec ',')))
   where
     pAssoc = (AssocLeft <$ pKeyword "infixl") <|< (AssocRight <$ pKeyword "infixr") <|< (AssocNone <$ pKeyword "infix")
--- a/src/MicroHs/TypeCheck.hs
+++ b/src/MicroHs/TypeCheck.hs
@@ -268,11 +268,11 @@
 getAppCon (EApp f _) = getAppCon f
 getAppCon _ = error "getAppCon"
 
-getApp :: EType -> (Ident, [EType])
+getApp :: HasCallStack => EType -> (Ident, [EType])
 getApp = loop []
   where loop as (EVar i) = (i, as)
         loop as (EApp f a) = loop (a:as) f
-        loop _ _ = error "getApp"
+        loop _ t = impossibleShow t
 
 -- Construct a dummy TModule for the currently compiled module.
 -- It has all the relevant export tables.
@@ -565,8 +565,7 @@
 
     mkInstInfo :: InstDictC -> T (Ident, InstInfo)
     mkInstInfo (e, iks, ctx, ct, fds) = do
-      ct' <- expandSyn ct
-      case (iks, ctx, getApp ct') of
+      case (iks, ctx, getApp ct) of
         ([], [], (c, [EVar i])) -> return $ (c, InstInfo (M.singleton i e) [] fds)
         (_,  _,  (c, ts      )) -> return $ (c, InstInfo M.empty [(e, ctx', ts')] fds)
           where ctx' = map (subst s) ctx
@@ -613,7 +612,8 @@
 
 addInstDict :: HasCallStack => Ident -> EConstraint -> T ()
 addInstDict i c = do
-  ics <- expandDict (EVar i) c
+  c' <- expandSyn c
+  ics <- expandDict (EVar i) c'
   addInstTable ics
 
 addEqDict :: Ident -> EType -> EType -> T ()
@@ -898,7 +898,7 @@
 newIdent :: SLoc -> String -> T Ident
 newIdent loc s = do
   u <- newUniq
-  return $ mkIdentSLoc loc $ s ++ "$" ++ show u
+  return $ mkIdentSLoc loc $ s ++ uniqIdentSep ++ show u
 
 tLookup :: HasCallStack =>
            String -> Ident -> T (Expr, EType)
@@ -1069,7 +1069,6 @@
         ESign t k        -> withVks vks k     $ \ vvks kr -> return $ Type    (i, vvks) (ESign t kr)
         _                -> withVks vks kType $ \ vvks _  -> return $ Type    (i, vvks) at
     Class ctx (i, vks) fds ms-> withVks vks kConstraint $ \ vvks _ -> return $ Class ctx (i, vvks) fds ms
-    Instance vks ctx t d -> withVks vks kConstraint $ \ vvks _ -> return $ Instance vvks ctx t d
     _                    -> return adef
 
 -- Check&rename the given kinds, apply reconstruction at the end
@@ -1143,7 +1142,7 @@
     Sign         i          t   ->                Sign    i     <$> tCheckTypeT kType t
     ForImp  ie i            t   ->                ForImp ie i   <$> tCheckTypeT kType t
     Class   ctx lhs@(_, iks) fds ms -> withVars iks $ Class     <$> tcCtx ctx <*> return lhs <*> mapM tcFD fds <*> mapM tcMethod ms
-    Instance iks ctx c m        -> withVars iks $ Instance iks  <$> tcCtx ctx <*> tCheckTypeT kConstraint c <*> return m
+    Instance ct m               ->                Instance      <$> tCheckTypeT kConstraint ct <*> return m
     Default ts                  ->                Default       <$> mapM (tCheckTypeT kType) ts
     _                           -> return d
  where
@@ -1269,12 +1268,22 @@
 tupleConstraints [c] = c
 tupleConstraints cs  = tApps (tupleConstr noSLoc (length cs)) cs
 
+splitInst :: EConstraint -> ([IdKind], [EConstraint], EConstraint)  -- | Split a constraint into (forall-bound vars, contexts, remaining head).
+splitInst (EForall iks t) =  -- a forall: keep its variables and keep splitting the body
+  case splitInst t of
+    (iks', ctx, ct) -> (iks ++ iks', ctx, ct)  -- outer variables come before inner ones
+splitInst act | Just (ctx, ct) <- getImplies act =  -- an implication recognised by getImplies: peel off one context
+  case splitInst ct of
+    (iks, ctxs, ct') -> (iks, ctx : ctxs, ct')  -- outer contexts come before inner ones
+splitInst ct = ([], [], ct)  -- no forall/implication left: this is the head
+
 expandInst :: EDef -> T [EDef]
-expandInst dinst@(Instance vks ctx cc bs) = do
-  let loc = getSLoc cc
+expandInst dinst@(Instance act bs) = do
+  (vks, ctx, cc) <- splitInst <$> expandSyn act
+  let loc = getSLoc act
       qiCls = getAppCon cc
   iInst <- newIdent loc "inst"
-  let sign = Sign iInst (eForall vks $ addConstraints ctx cc)
+  let sign = Sign iInst act
 --  (e, _) <- tLookupV iCls
   ct <- gets classTable
 --  let qiCls = getAppCon e
@@ -1295,9 +1304,11 @@
   return [dinst, sign, bind]
 expandInst d = return [d]
 
+{-
 eForall :: [IdKind] -> EType -> EType
 eForall [] t = t
 eForall vs t = EForall vs t
+-}
 
 ---------------------
 
@@ -1551,7 +1562,7 @@
             LRat r -> do
               mex <- getExpected mt
               case mex of
-                Just v | v == mkIdent nameDouble  -> tcLit  mt loc' (LDouble (fromRational r))
+                Just v | v == mkIdent nameDouble -> tcLit  mt loc' (LDouble (fromRational r))
                 _ -> do
                   (f, ft) <- tInferExpr (EVar (mkIdentSLoc loc' "fromRational"))  -- XXX should have this qualified somehow
                   (_at, rt) <- unArrow loc ft
@@ -1560,6 +1571,7 @@
             -- Not LInteger, LRat
             _ -> tcLit mt loc' l
     ECase a arms -> do
+      -- XXX should look more like EIf
       (ea, ta) <- tInferExpr a
       tt <- tGetExpType mt
       earms <- mapM (tcArm tt ta) arms
@@ -1566,6 +1578,7 @@
       return (ECase ea earms)
     ELet bs a -> tcBinds bs $ \ ebs -> do { ea <- tcExpr mt a; return (ELet ebs ea) }
     ETuple es -> do
+      -- XXX checking if mt is a tuple would give better inference
       let
         n = length es
       (ees, tes) <- fmap unzip (mapM tInferExpr es)
@@ -1584,12 +1597,16 @@
           case as of
             SBind p a -> do
               let
-                sbind = maybe (mkIdentSLoc loc ">>=") (\ mn -> qualIdent mn (mkIdentSLoc loc ">>=")) mmn
+                -- XXX this is wrong; it should be >>= from Monad
+                ibind = mkIdentSLoc loc ">>="
+                sbind = maybe ibind (\ mn -> qualIdent mn ibind) mmn
+                x = eVarI loc "$b"
               tcExpr mt (EApp (EApp (EVar sbind) a)
-                              (eLam [eVarI loc "$x"] (ECase (eVarI loc "$x") [(p, EAlts [([], EDo mmn ss)] [])])))
+                              (eLam [x] (ECase x [(p, EAlts [([], EDo mmn ss)] [])])))
             SThen a -> do
               let
-                sthen = maybe (mkIdentSLoc loc ">>") (\ mn -> qualIdent mn (mkIdentSLoc loc ">>") ) mmn
+                ithen = mkIdentSLoc loc ">>"
+                sthen = maybe ithen (\ mn -> qualIdent mn ithen) mmn
               tcExpr mt (EApp (EApp (EVar sthen) a) (EDo mmn ss))
                 
             SLet bs ->
@@ -2267,12 +2284,18 @@
 --  * name components of a tupled constraint
 --  * name superclasses of a constraint
 expandDict :: HasCallStack => Expr -> EConstraint -> T [InstDictC]
-expandDict edict acn = do
-  cn <- expandSyn acn
+expandDict edict ct = expandDict' [] [] edict =<< expandSyn ct
+
+expandDict' :: [IdKind] -> [EConstraint] -> Expr -> EConstraint -> T [InstDictC]
+expandDict' avks actx edict acc = do
   let
-    (iCls, args) = getApp cn
+    (bvks, bctx, cc) = splitInst acc
+    (iCls, args) = getApp cc
+    vks = avks ++ bvks
+    ctx = actx ++ bctx
   case getTupleConstr iCls of
-    Just _ -> concat <$> mapM (\ (i, a) -> expandDict (mkTupleSel i (length args) `EApp` edict) a) (zip [0..] args)
+    Just _ -> do
+      concat <$> mapM (\ (i, a) -> expandDict' vks ctx (mkTupleSel i (length args) `EApp` edict) a) (zip [0..] args)
     Nothing -> do
       ct <- gets classTable
       case M.lookup iCls ct of
@@ -2280,14 +2303,14 @@
           -- if iCls is a variable it's not in the class table, otherwise it's an error
           when (isConIdent iCls) $
             impossible
-          return [(edict, [], [], cn, [])]
+          return [(edict, vks, ctx, cc, [])]
         Just (iks, sups, _, _, fds) -> do
           let 
             vs = map idKindIdent iks
             sub = zip vs args
             sups' = map (subst sub) sups
-          insts <- concat <$> mapM (\ (i, sup) -> expandDict (EVar (mkSuperSel iCls i) `EApp` edict) sup) (zip [1 ..] sups')
-          return $ (edict, [], [], cn, fds) : insts
+          insts <- concat <$> mapM (\ (i, sup) -> expandDict' vks ctx (EVar (mkSuperSel iCls i) `EApp` edict) sup) (zip [1 ..] sups')
+          return $ (edict, vks, ctx, cc, fds) : insts
 
 mkSuperSel :: HasCallStack =>
               Ident -> Int -> Ident
@@ -2575,7 +2598,9 @@
 getBestMatches :: [(Int, (Expr, [EConstraint], [Improve]))] -> [(Expr, [EConstraint], [Improve])]
 getBestMatches [] = []
 getBestMatches ams =
-  let (args, insts) = partition (\ (_, (EVar i, _, _)) -> (adictPrefix ++ uniqIdentSep) `isPrefixOf` unIdent i) ams
+  let (args, insts) = partition (\ (_, (ei, _, _)) -> (adictPrefix ++ uniqIdentSep) `isPrefixOf` unIdent (unvar ei)) ams
+      unvar (EVar i) = i
+      unvar e = impossibleShow e
       pick ms =
         let b = minimum (map fst ms)         -- minimum substitution size
         in  [ ec | (s, ec) <- ms, s == b ]   -- pick out the smallest
--