shithub: MicroHs

Download patch

ref: 1ae323f927922ea500097175e2dfebaa1776cb15
parent: e8facabf8870ecb9ea3a372517b2096d09aeea08
author: Lennart Augustsson <lennart@augustsson.net>
date: Thu Apr 25 11:38:55 EDT 2024

Turn GADTs into ADTs in the parser.

--- a/src/MicroHs/Parse.hs
+++ b/src/MicroHs/Parse.hs
@@ -10,6 +10,7 @@
 import Text.ParserComb as P
 import MicroHs.Lex
 import MicroHs.Expr hiding (getSLoc)
+import qualified MicroHs.Expr as E
 import MicroHs.Ident
 --import Debug.Trace
 
@@ -308,19 +309,18 @@
 
 pDef :: P EDef
 pDef =
-      Data        <$> (pKeyword "data"    *> pLHS) <*> ((pSymbol "=" *> esepBy1 pConstr (pSymbol "|"))
-                                                        <|< pure []) <*> pDeriving
-  <|< Newtype     <$> (pKeyword "newtype" *> pLHS) <*> (pSymbol "=" *> (Constr [] [] <$> pUIdentSym <*> pField)) <*> pDeriving
-  <|< Type        <$> (pKeyword "type"    *> pLHS) <*> (pSymbol "=" *> pType)
-  <|< uncurry Fcn <$> pEqns
-  <|< Sign        <$> ((esepBy1 pLIdentSym (pSpec ',')) <* dcolon) <*> pType
-  <|< Import      <$> (pKeyword "import"  *> pImportSpec)
-  <|< ForImp      <$> (pKeyword "foreign" *> pKeyword "import" *> pKeyword "ccall" *> eoptional pString) <*> pLIdent <*> (pSymbol "::" *> pType)
-  <|< Infix       <$> ((,) <$> pAssoc <*> pPrec) <*> esepBy1 pTypeOper (pSpec ',')
-  <|< Class       <$> (pKeyword "class"    *> pContext) <*> pLHS <*> pFunDeps     <*> pWhere pClsBind
-  <|< Instance    <$> (pKeyword "instance" *> pType) <*> pWhere pClsBind
-  <|< Default     <$> (pKeyword "default"  *> pParens (esepBy pType (pSpec ',')))
-  <|< KindSign    <$> (pKeyword "type"    *> pTypeIdentSym) <*> (pSymbol "::" *> pKind)
+      uncurry Data <$> (pKeyword "data"    *> pData) <*> pDeriving
+  <|< Newtype      <$> (pKeyword "newtype" *> pLHS) <*> (pSymbol "=" *> (Constr [] [] <$> pUIdentSym <*> pField)) <*> pDeriving
+  <|< Type         <$> (pKeyword "type"    *> pLHS) <*> (pSymbol "=" *> pType)
+  <|< uncurry Fcn  <$> pEqns
+  <|< Sign         <$> ((esepBy1 pLIdentSym (pSpec ',')) <* dcolon) <*> pType
+  <|< Import       <$> (pKeyword "import"  *> pImportSpec)
+  <|< ForImp       <$> (pKeyword "foreign" *> pKeyword "import" *> pKeyword "ccall" *> eoptional pString) <*> pLIdent <*> (pSymbol "::" *> pType)
+  <|< Infix        <$> ((,) <$> pAssoc <*> pPrec) <*> esepBy1 pTypeOper (pSpec ',')
+  <|< Class        <$> (pKeyword "class"    *> pContext) <*> pLHS <*> pFunDeps     <*> pWhere pClsBind
+  <|< Instance     <$> (pKeyword "instance" *> pType) <*> pWhere pClsBind
+  <|< Default      <$> (pKeyword "default"  *> pParens (esepBy pType (pSpec ',')))
+  <|< KindSign     <$> (pKeyword "type"    *> pTypeIdentSym) <*> (pSymbol "::" *> pKind)
   where
     pAssoc = (AssocLeft <$ pKeyword "infixl") <|< (AssocRight <$ pKeyword "infixr") <|< (AssocNone <$ pKeyword "infix")
     dig (TInt _ ii) | 0 <= i && i <= 9 = Just i  where i = fromInteger ii
@@ -335,6 +335,81 @@
       pure fs
     dcolon = pSymbol "::" <|< pSymbol "\x2237"
 
+-- | Parse the rest of a @data@ declaration, after the @data@ keyword.
+-- Three forms are accepted:
+--
+--  * Haskell 98 syntax: @T a b = C1 ... | C2 ...@
+--  * GADT syntax:       @T a b where { C1 :: ... ; C2 :: ... }@
+--  * no constructors:   @T a b@
+--
+-- GADT constructor signatures are desugared by 'pGADT' into ordinary
+-- 'Constr's, so all three forms yield the same kind of result.
+pData :: P (LHS, [Constr])
+pData = do
+  lhs <- pLHS
+  let pConstrs = pSymbol "=" *> esepBy1 pConstr (pSymbol "|")
+  ((,) lhs <$> pConstrs)
+   <|< pGADT lhs
+   <|< pure (lhs, [])
+
+-- | Parse the @where@ block of a GADT-style declaration whose head has
+-- already been parsed, desugaring each constructor with 'dsGADT'.
+--
+-- The head's type variables are renamed by adding the suffix @$@
+-- (@T a@ becomes @T a$@) and the renamed head is returned.  This keeps
+-- them apart from the variables quantified in the constructor
+-- signatures; 'dsGADT' links the two with equality constraints.
+pGADT :: LHS -> P (LHS, [Constr])
+pGADT (n, vks) = do
+  let f (IdKind i k) = IdKind (addIdentSuffix i "$") k
+      lhs = (n, map f vks)
+  pKeyword "where"
+  gs <- pBlock pGADTconstr
+  pure (lhs, map (dsGADT lhs) gs)
+
+-- | Parse one GADT constructor signature
+--
+-- > C :: forall vs . ctx => t1 -> ... -> tk -> res
+--
+-- returning @(C, vs, ctx, [t1, ..., tk], res)@.  Each argument may be
+-- marked strict with @!@.  Arguments are parsed with 'pSTypeApp',
+-- which stops before a top-level @->@; the last type is the result.
+-- As for ordinary type signatures, both @::@ and the Unicode double
+-- colon (U+2237) are accepted.
+pGADTconstr :: P (Ident, [IdKind], [EConstraint], [SType], EType)
+pGADTconstr = do
+  cn <- pUIdentSym
+  pSymbol "::" <|< pSymbol "\x2237"
+  es <- pForall
+  ctx <- pContext
+  args <- emany (pSTypeApp <* pSymbol "->")
+  res <- pType
+  pure (cn, es, ctx, args, res)
+
+-- | Desugar one GADT constructor signature into an ordinary constructor.
+-- Given the (renamed) head @T a1$ ... an$@ and a signature
+--
+-- > C :: forall vs . ctx => s1 -> ... -> sk -> T t1 ... tn
+--
+-- the result is the existential constructor
+--
+-- > forall vs . (a1$ ~ t1, ..., an$ ~ tn, ctx) => C s1 ... sk
+--
+-- Without an explicit @forall@, @vs@ are the free type variables of the
+-- result and argument types, each with a dummy kind.  A result type that
+-- is not @T@ applied to exactly @n@ arguments is reported as an error.
+dsGADT :: LHS -> (Ident, [IdKind], [EConstraint], [SType], EType) -> Constr
+dsGADT (tnm, vks) (cnm, es, ctx, stys, rty) =
+  case getAppM rty of
+    Just (tnm', ts) | tnm == tnm' && length vks == length ts -> Constr es' ctx' cnm (Left stys)
+      where
+        -- No explicit forall: quantify over the variables of the result and argument types.
+        es' = if null es then map (\ i -> IdKind i (EVar dummyIdent)) (freeTyVars (rty : map snd stys)) else es
+        -- Tie each (renamed) head variable to the matching index of the result type.
+        ctx' = zipWith (\ (IdKind i _) t -> eq (EVar i) t) vks ts ++ ctx
+        eq t1 t2 = EApp (EApp (EVar (mkIdentSLoc (E.getSLoc t1) "~")) t1) t2
+    _ -> errorMessage (E.getSLoc rty) $ "Bad GADT result type: " ++ show (rty, tnm, vks)
+
 pDeriving :: P [EConstraint]
 pDeriving = pKeyword "deriving" *> pDer <|< pure []
   where pDer =     pParens (esepBy pType (pSpec ','))
@@ -368,6 +443,11 @@
 pSAType = (,) <$> pStrict <*> pAType
 pSType :: P (Bool, EType)
 pSType  = (,) <$> pStrict <*> pType
+-- | Like 'pSType', but the type is parsed with 'pTypeApp', so a
+-- top-level @->@ is left unconsumed; 'pGADTconstr' relies on this to
+-- split a constructor signature into argument types and result.
+pSTypeApp :: P (Bool, EType)
+pSTypeApp  = (,) <$> pStrict <*> pTypeApp
 pStrict :: P Bool
 pStrict = (True <$ pSpec '!') <|< pure False
 
--- a/src/MicroHs/TypeCheck.hs
+++ b/src/MicroHs/TypeCheck.hs
@@ -2619,7 +2619,7 @@
 solveTypeEq loc _iCls [t1, t2] | isEUVar t1 || isEUVar t2 = return $ Just (ETuple [], [], [(loc, t1, t2)])
                                | otherwise = do
   eqs <- gets typeEqTable
-  traceM ("solveTypeEq eqs=" ++ show eqs)
+  --traceM ("solveTypeEq eqs=" ++ show eqs)
   case solveEq eqs t1 t2 of
     Nothing -> return Nothing
     Just (de, tts) -> do
--- a/tests/TypeEq.hs
+++ b/tests/TypeEq.hs
@@ -23,3 +23,27 @@
 main = do
   print (foo True)
   print (eval e1)
+  print (geval ge1)
+  print (geval ge2)
+
+-- A GADT: each constructor's result type fixes the index 'a', so
+-- 'geval' can return a value of type 'a' directly.
+data GExp a where
+  GInt :: Int -> GExp Int
+  GAdd :: GExp Int -> GExp Int -> GExp Int
+  GEqu :: GExp Int -> GExp Int -> GExp Bool
+  GIff :: GExp Bool -> GExp a -> GExp a -> GExp a
+
+geval :: GExp a -> a
+geval (GInt i) = i
+geval (GAdd e1 e2) = geval e1 + geval e2
+geval (GEqu e1 e2) = geval e1 == geval e2
+geval (GIff c e1 e2) = if geval c then geval e1 else geval e2
+
+-- Condition holds, result index Int: prints 1.
+ge1 :: GExp Int
+ge1 = GIff (GAdd (GInt 1) (GInt 2) `GEqu` GInt 3) (GInt 1) (GInt 999)
+
+-- Condition fails, result index Bool: prints True.
+ge2 :: GExp Bool
+ge2 = GIff (GInt 0 `GEqu` GInt 1) (GInt 1 `GEqu` GInt 2) (GInt 3 `GEqu` GInt 3)
--- a/tests/TypeEq.ref
+++ b/tests/TypeEq.ref
@@ -1,2 +1,4 @@
 False
 1
+1
+True
--