shithub: MicroHs

Download patch

ref: d24fd07ecfe72110ce09125dc3d474304af960ea
parent: 3bc7e25513e7b7057ab21fb88d4ba562b6b1f06a
author: Lennart Augustsson <lennart@augustsson.net>
date: Fri Dec 13 14:36:11 EST 2024

Better handling of pattern synonyms.

--- a/src/MicroHs/Expr.hs
+++ b/src/MicroHs/Expr.hs
@@ -133,6 +133,7 @@
   | EForall Bool [IdKind] EType  -- True indicates explicit forall in the code
   -- only while type checking
   | EUVar Int
+  | EQVar Expr EType             -- already resolved identifier
   -- only after type checking
   | ECon Con
 --DEBUG  deriving (Show)
@@ -163,7 +164,7 @@
 data Con
   = ConData ConTyInfo Ident [FieldName]
   | ConNew Ident [FieldName]
-  | ConSyn Ident Int
+  | ConSyn Ident Int (Expr, EType)
 --DEBUG  deriving(Show)
 
 data Listish
@@ -179,23 +180,23 @@
             Con -> Ident
 conIdent (ConData _ i _) = i
 conIdent (ConNew i _) = i
-conIdent (ConSyn i _) = i
+conIdent (ConSyn i _ _) = i
 
 conArity :: Con -> Int
 conArity (ConData cs i _) = fromMaybe (error "conArity") $ lookup i cs
 conArity (ConNew _ _) = 1
-conArity (ConSyn _ n) = n
+conArity (ConSyn _ n _) = n
 
 conFields :: Con -> [FieldName]
 conFields (ConData _ _ fs) = fs
 conFields (ConNew _ fs) = fs
-conFields (ConSyn _ _) = []
+conFields (ConSyn _ _ _) = []
 
 instance Eq Con where
-  (==) (ConData _ i _) (ConData _ j _) = i == j
-  (==) (ConNew    i _) (ConNew    j _) = i == j
-  (==) (ConSyn    i _) (ConSyn    j _) = i == j
-  (==) _               _               = False
+  (==) (ConData _ i _)   (ConData _ j _)   = i == j
+  (==) (ConNew    i _)   (ConNew    j _)   = i == j
+  (==) (ConSyn    i _ _) (ConSyn    j _ _) = i == j
+  (==) _                 _                 = False
 
 data Lit
   = LInt Int
@@ -380,6 +381,7 @@
   getSLoc (ELazy _ e) = getSLoc e
   getSLoc (EOr es) = getSLoc es
   getSLoc (EUVar _) = error "getSLoc EUVar"
+  getSLoc (EQVar e _) = getSLoc e
   getSLoc (ECon c) = getSLoc c
   getSLoc (EForall _ [] e) = getSLoc e
   getSLoc (EForall _ iks _) = getSLoc iks
@@ -394,7 +396,7 @@
 instance HasLoc Con where
   getSLoc (ConData _ i _) = getSLoc i
   getSLoc (ConNew i _) = getSLoc i
-  getSLoc (ConSyn i _) = getSLoc i
+  getSLoc (ConSyn i _ _) = getSLoc i
 
 instance HasLoc Listish where
   getSLoc (LList es) = getSLoc es
@@ -537,6 +539,7 @@
     ELazy _ p -> allVarsExpr' p
     EOr ps -> composeMap allVarsExpr' ps
     EUVar _ -> id
+    EQVar e _ -> allVarsExpr' e
     ECon c -> (conIdent c :)
     EForall _ iks e -> (map (\ (IdKind i _) -> i) iks ++) . allVarsExpr' e
   where field (EField _ e) = allVarsExpr' e
@@ -572,7 +575,7 @@
 setSLocCon :: SLoc -> Con -> Con
 setSLocCon l (ConData ti i fs) = ConData ti (setSLocIdent l i) fs
 setSLocCon l (ConNew i fs) = ConNew (setSLocIdent l i) fs
-setSLocCon l (ConSyn i n) = ConSyn (setSLocIdent l i) n
+setSLocCon l (ConSyn i n m) = ConSyn (setSLocIdent l i) n m
 
 errorMessage :: forall a .
                 HasCallStack =>
@@ -744,6 +747,7 @@
         ELazy False p -> text "!" <> ppE p
         EOr ps -> parens $ hsep (punctuate (text ";") (map ppE ps))
         EUVar i -> text ("_a" ++ show i)
+        EQVar e _ -> ppE e
         ECon c -> ppCon c
         EForall _ iks e -> ppForall iks <+> ppEType e
 
@@ -780,7 +784,7 @@
 ppCon :: Con -> Doc
 ppCon (ConData _ s _) = ppIdent s
 ppCon (ConNew s _) = ppIdent s
-ppCon (ConSyn s _) = ppIdent s
+ppCon (ConSyn s _ _) = ppIdent s
 
 -- Literals are tagged the way they appear in the combinator file:
 --  #   Int
--- a/src/MicroHs/Parse.hs
+++ b/src/MicroHs/Parse.hs
@@ -462,6 +462,8 @@
       impType     <$> pUQIdentSym <*> (pSpec '(' *> pConList <* pSpec ')')
   <|< ImpTypeSome <$> pUQIdentSym <*> pure []
   <|< ImpValue    <$> pLQIdentSym
+  <|< ImpValue    <$> (pKeyword "pattern" *> pUQIdentSym)
+  <|< ImpTypeSome <$> (pKeyword "type" *> pLQIdentSym) <*> pure []
   where impType i Nothing   = ImpTypeAll  i
         impType i (Just is) = ImpTypeSome i is
 
--- a/src/MicroHs/TypeCheck.hs
+++ b/src/MicroHs/TypeCheck.hs
@@ -1300,8 +1300,8 @@
   --  type check all definitions (the inferred ones will be rechecked)
   defs'' <- mapM (\ d -> do { tcReset; tcDefValue d}) defs'
   let defs''' = concat signDefs ++ defs''
-  traceM $ "tcDefsValue: ------------ done"
-  traceM $ showEDefs defs'''
+--  traceM $ "tcDefsValue: ------------ done"
+--  traceM $ showEDefs defs'''
   pure defs'''
 
 -- Infer a type for a definition
@@ -1390,6 +1390,7 @@
 addPatSyn :: EType -> Ident -> T ()
 addPatSyn at i = do
   at' <- expandSyn at
+  mn <- gets moduleName
   let (t', n) =
         -- Patterns must have two universals.
         -- XXX Add double contexts
@@ -1397,8 +1398,9 @@
           EForall b vs t -> (EForall b     vs $ EForall False []  t, arity t)
           _              -> (EForall False [] $ EForall False [] at, arity at)
       arity = length . fst . getArrows  -- XXX
-  mn <- gets moduleName
-  extValETop i t' $ ECon $ ConSyn (qualIdent mn i) n
+      qi = qualIdent mn i
+      mtch = (EVar $ mkPatSynMatch qi, mkPatSynType t')
+  extValETop i t' $ ECon $ ConSyn qi n mtch
 
 -- XXX FunDep
 addValueClass :: [EConstraint] -> Ident -> [IdKind] -> [FunDep] -> [EBind] -> T ()
@@ -1435,9 +1437,9 @@
     Fcn i eqns -> do
       (_, t) <- tLookup "type signature" i
       t' <- expandSyn t
-      tcTrace $ "tcDefValue: ------- start " ++ showIdent i
-      tcTrace $ "tcDefValue: " ++ showIdent i ++ " :: " ++ showExpr t'
-      tcTrace $ "tcDefValue: " ++ showEDefs [adef]
+--      tcTrace $ "tcDefValue: ------- start " ++ showIdent i
+--      tcTrace $ "tcDefValue: " ++ showIdent i ++ " :: " ++ showExpr t'
+--      tcTrace $ "tcDefValue: " ++ showEDefs [adef]
       teqns <- tcEqns True t' eqns
 --      tcTrace ("tcDefValue: after\n" ++ showEDefs [adef, Fcn i teqns])
 --      cs <- gets constraints
@@ -1466,7 +1468,7 @@
       dropForall (EForall _ _ t) = dropForall t
       dropForall t = t
   (_, _, p') <- step vks (dropForall at)   -- XXX
-  me' <- traverse (tcEqns True at) me
+  me' <- case me of Nothing -> pure Nothing; Just e -> Just <$> tcEqns True at e
   mn <- gets moduleName
   checkConstraints
 --  traceM ("Pattern after " ++ show (qualIdent mn ip, vks, p', me'))
@@ -1604,6 +1606,8 @@
                  _ -> return t
              --tcTrace $ "EVar: " ++ showIdent i ++ " :: " ++ showExpr t ++ " = " ++ showExpr t' ++ " mt=" ++ show mt
              instSigma loc e t' mt
+    EQVar e t ->  -- already resolved, just instantiate
+             instSigma loc e t mt
 
     EApp f a -> do
 --      tcTrace $ "txExpr(0) EApp: expr=" ++ show ae ++ ":: " ++ show mt
@@ -2254,11 +2258,11 @@
             return ()
     (con, xpt) <- tLookupV i
     case con of
-     ECon (ConSyn ps n) -> do
+     ECon (ConSyn _ n (e, t)) -> do
       checkArity n
       let (_, _, pcon) = mkMatchIdents loc False n
-          vp = EViewPat (EVar $ mkPatSynMatch ps) (eApps (EVar pcon) args)
-      traceM ("patsyn " ++ show vp)
+          vp = EViewPat (EQVar e t) (eApps (EVar pcon) args)
+      --traceM ("patsyn " ++ show vp)
       tcPat mt vp
      _ -> do 
 --      tcTrace (show xpt)
@@ -2548,15 +2552,17 @@
 addPatSynMatch ps at = do
   mn <- gets moduleName
   let im = mkPatSynMatch ps
-      (m, vks, (ats, rt)) =
+  extValETop im (mkPatSynType at) (EVar (qualIdent mn im))
+  return im
+
+mkPatSynType :: EType -> EType
+mkPatSynType at =
+  let (m, vks, (ats, rt)) =
         case at of
           EForall m' vks' (EForall _ _ t) -> (m', vks', getArrows t) -- XXX
-          _ -> impossible
-      (pstycon, _, _) = mkMatchIdents (getSLoc ps) False (length ats) -- XXX
-      conty = EForall m vks $ rt `tArrow` tApps pstycon ats -- XXX
---      con = ConSyn (qualIdent mn im) (length ats)
-  extValETop im conty (EVar (qualIdent mn im))
-  return im
+          _ -> impossibleShow at
+      (pstycon, _, _) = mkMatchIdents (getSLoc at) False (length ats) -- XXX
+  in  EForall m vks $ rt `tArrow` tApps pstycon ats -- XXX
 
 mkPatSynMatch :: Ident -> Ident
 mkPatSynMatch i = addIdentSuffix i "%"