shithub: MicroHs

Download patch

ref: f1136c9acb3b74f87a3bef4ca931564394a917f7
parent: dc708ee4bd2c5e76a60d4cecff26bac23318c04a
author: Lennart Augustsson <lennart@augustsson.net>
date: Mon Dec 23 06:01:14 EST 2024

More patsyn

--- a/src/MicroHs/Expr.hs
+++ b/src/MicroHs/Expr.hs
@@ -747,8 +747,8 @@
         ELazy False p -> text "!" <> ppE p
         EOr ps -> parens $ hsep (punctuate (text ";") (map ppE ps))
         EUVar i -> text ("_a" ++ show i)
-        EQVar e _ -> ppE e
+        EQVar e t -> parens $ ppE e <> text "::" <> ppE t  -- also show the attached type
         ECon c -> ppCon c
         EForall _ iks e -> ppForall iks <+> ppEType e
 
     ppApp :: [Expr] -> Expr -> Doc
--- a/src/MicroHs/TCMonad.hs
+++ b/src/MicroHs/TCMonad.hs
@@ -97,7 +97,6 @@
   dataTable   :: DataTable,             -- data/newtype definitions
   valueTable  :: ValueTable,            -- value symbol table
   assocTable  :: AssocTable,            -- values associated with a type, indexed by QIdent
-  patSynTable :: M.Map ([Ident], EPat), -- pattern synonyms
   uvarSubst   :: (IM.IntMap EType),     -- mapping from unique id to type
   tcMode      :: TCMode,                -- pattern, value, or type
   classTable  :: ClassTable,            -- class info, indexed by QIdent
@@ -254,3 +253,7 @@
 
 tImplies :: EType -> EType -> EType
 tImplies a r = tApp (tApp (tConI builtinLoc "Primitives.=>") a) r
+
+etImplies :: EType -> EType -> EType  -- like tImplies, but an empty (0-tuple) context is dropped
+etImplies (EVar i) t | i == tupleConstr noSLoc 0 = t  -- () => t  simplifies to  t
+etImplies a t = tImplies a t
--- a/src/MicroHs/TypeCheck.hs
+++ b/src/MicroHs/TypeCheck.hs
@@ -389,8 +389,7 @@
           classTable = gClassTable globs,
           ctxTables = (gInstInfo globs, [], [], []),
           constraints = [],
-          defaults = dflts,
-          patSynTable = M.empty
+          defaults = dflts
         }
 
 mergeDefaults :: Defaults -> Defaults -> Defaults
@@ -1090,7 +1089,7 @@
     Type    lhs t          -> withLHS lhs $ \ lhs' -> first                    (Type    lhs') <$> tInferTypeT t
     Class   ctx lhs fds ms -> withLHS lhs $ \ lhs' -> flip (,) kConstraint <$> (Class         <$> tcCtx ctx <*> return lhs' <*> mapM tcFD fds <*> mapM tcMethod ms)
     Sign      is t         ->                                                  Sign      is   <$> tCheckTypeTImpl kType t
-    PatternSign is t       ->                                                  PatternSign is   <$> tCheckTypeTImpl kType t
+    PatternSign is t       ->                                                  PatternSign is <$> tCheckTypeTImpl kType t
     ForImp ie i t          ->                                                  ForImp ie i    <$> tCheckTypeTImpl kType t
     Instance ct m          ->                                                  Instance       <$> tCheckTypeTImpl kConstraint ct <*> return m
     Default mc ts          ->                                                  Default (Just c) <$> mapM (tcDefault c) ts
@@ -1290,6 +1289,7 @@
               node _ = undefined
       tcSCC (AcyclicSCC d) = tInferDefs [d]
       tcSCC (CyclicSCC ds) = tInferDefs ds
+  --traceM $ "tcDefsValue: unsigned=" ++ show unsigned
   -- type infer and enter each SCC in the symbol table
   -- return inferred Sign
   signDefs <- mapM tcSCC sccs
@@ -1307,6 +1307,7 @@
 -- Infer a type for a definition
 tInferDefs :: [EDef] -> T [EDef]
 tInferDefs fcns = do
+--  traceM "tInferDefs"
   tcReset
   -- Invent type variables for the definitions
   let idOf (Fcn i _) = i
@@ -1337,8 +1338,9 @@
         t''' <- quantify vs' t''
         --tcTrace $ "tInferDefs: " ++ showIdent i ++ " :: " ++ showEType t'''
         if isPatSyn i then do
-          addPatSyn t''' i
-          return $ PatternSign [i] t'''
+          t'''' <- canonPatSynType t'''  -- normalise to the two-forall form that splitPatSynType expects
+          addPatSyn t'''' i
+          return $ PatternSign [i] t''''
          else do
           extValQTop i t'''
           return $ Sign [i] t'''
@@ -1383,24 +1385,21 @@
       addConFields tycon con
     ForImp _ i t -> extValQTop i t
     Class ctx (i, vks) fds ms -> addValueClass ctx i vks fds ms
-    PatternSign is at -> mapM_ (addPatSyn at) is
+    PatternSign is at -> do
+      at' <- canonPatSynType at  -- addPatSyn (via splitPatSynType) needs the canonical form
+      mapM_ (addPatSyn at') is
     _ -> return ()
 
 -- Add a pattern synonym to the symbol table.
 addPatSyn :: EType -> Ident -> T ()
 addPatSyn at i = do
-  at' <- expandSyn at
   mn <- gets moduleName
-  let (t', n) =
-        -- Patterns must have two universals.
-        -- XXX Add double contexts
-        case at' of
-          EForall b vs t -> (EForall b     vs $ EForall False []  t, arity t)
-          _              -> (EForall False [] $ EForall False [] at, arity at)
-      arity = length . fst . getArrows  -- XXX
+  let (_, _, _, _, t) = splitPatSynType at  -- at must already be canonical (see canonPatSynType)
+      n = length $ fst $ getArrows t  -- arity: the number of argument types
       qi = qualIdent mn i
-      mtch = (EVar $ mkPatSynMatch qi, mkPatSynType t')
-  extValETop i t' $ ECon $ ConSyn qi n mtch
+      qip = mkPatSynMatch qi  -- qualified name of the matcher P%
+      mtch = (EVar qip, mkPatSynMatchType qip at)
+  extValETop i at $ ECon $ ConSyn qi n mtch
 
 -- XXX FunDep
 addValueClass :: [EConstraint] -> Ident -> [IdKind] -> [FunDep] -> [EBind] -> T ()
@@ -1437,6 +1436,9 @@
     Fcn i eqns -> do
       (_, t) <- tLookup "type signature" i
       t' <- expandSyn t
+--      when (isConIdent i) $ do
+--        tcTrace $ "tcDefValue: patsyn\n" ++ show i ++ " :: " ++ show t'
+--        tcTrace $ "tcDefValue:\n" ++ showEDefs [adef]
 --      tcTrace $ "tcDefValue: ------- start " ++ showIdent i
 --      tcTrace $ "tcDefValue: " ++ showIdent i ++ " :: " ++ showExpr t'
 --      tcTrace $ "tcDefValue: " ++ showEDefs [adef]
@@ -1452,26 +1454,35 @@
       mn <- gets moduleName
       t' <- expandSyn t
       return (ForImp ie (qualIdent mn i) t')
-    Pattern (i, _) _ _ -> do
+    Pattern _ _ _ -> impossible
+    {- Pattern (i, _) _ _ -> do
       (_, t) <- tLookup "pattern type signature" i
       t' <- expandSyn t
-      tcPattern adef t'
+      tcPattern adef t'-}
     _ -> return adef
 
+-- This is only used during inference.
+-- When doing type checking the actual Pattern definition will have been
+-- removed by expandPatSyn.
 tcPattern :: EDef -> EType -> T EDef
 tcPattern (Pattern (ip, vks) p me) at = do
 --  traceM ("Pattern " ++ show (ip, vks, p, me, at))
-  let step [] t = tcPat (Check t) p
+  let (_vks1, ctx1, _vks2, _ctx2, ty) = splitPatSynType at
+      step [] t = do
+        d <- newADictIdent (getSLoc ip)
+--        traceM $ "tcPattern: add ctx " ++ show ctx1
+        withDict d ctx1 $ do
+          r <- tcPat (Check t) p
+          _ <- solveConstraints
+          checkConstraints
+          pure r
       step (ik:iks) t = do
         (ti, tr) <- unArrow (getSLoc ik) t
         withExtVal (idKindIdent ik) ti $ step iks tr
-      dropForall (EForall _ _ t) = dropForall t
-      dropForall t = t
-  (_, _, p') <- step vks (dropForall at)   -- XXX
-  me' <- case me of Nothing -> pure Nothing; Just e -> Just <$> tcEqns True at e
+  (_, _, p') <- step vks ty  -- ty: the synonym type with its foralls and contexts stripped
+  me' <- case me of Nothing -> pure Nothing; Just e -> Just <$> do e' <- tcEqns True at e; checkConstraints; pure e'
   mn <- gets moduleName
-  checkConstraints
 --  traceM ("Pattern after " ++ show (qualIdent mn ip, vks, p', me'))
   return $ Pattern (qualIdent mn ip, vks) p' me'
 tcPattern _ _ = error "tcPattern"
 
@@ -2169,6 +2180,7 @@
                unify loc t ext
                return ([], [], p)
            | otherwise -> tcPatAp mt [] ae
+    EQVar _ _ -> tcPatAp mt [] ae
     EApp f _
            | isNeg f   -> lit            -- if it's (negate e) it must have been a negative literal
            | otherwise -> tcPatAp mt [] ae
@@ -2238,7 +2250,7 @@
           true = eTrue loc
       tcPat mt $ EViewPat orFun true
 
-    _ -> error $ "tcPat: " ++ show (getSLoc ae) ++ " " ++ show ae
+    _ -> error $ "tcPat: not handled " ++ show (getSLoc ae) ++ " " ++ show ae
 
 -- The expected type is for (eApps afn (reverse args))
 tcPatAp :: HasCallStack =>
@@ -2245,31 +2257,45 @@
            Expected -> [EPat] -> EPat -> T EPatRet
 --tcPatAp mt args afn | trace ("tcPatAp: " ++ show (mt, args, afn)) False = undefined
 tcPatAp mt args afn =
- case afn of
-  EVar i | isConIdent i -> do
-    let loc = getSLoc i
-        nargs = length args
-        checkArity ary =
-          if nargs < ary then
-            tcError loc "too few arguments"
-          else if nargs > ary then
-            tcError loc "too many arguments"
-          else
-            return ()
-    (con, xpt) <- tLookupV i
-    case con of
-     ECon (ConSyn _ n (e, t)) -> do
+  case afn of
+    EVar i | isConIdent i -> do
+      (con, xpt) <- tLookupV i
+      tcPatApCon mt args con xpt
+
+    EQVar con xpt -> tcPatApCon mt args con xpt  -- constructor already paired with its type
+
+    EApp f a -> tcPatAp mt (a:args) f
+
+    EParen e -> tcPatAp mt args e
+
+    _ -> tcError (getSLoc afn) ("Bad pattern " ++ show afn)
+
+tcPatApCon :: Expected -> [EPat] -> EPat -> EType -> T EPatRet  -- con (of type xpt) applied to args
+tcPatApCon mt args con xpt = do
+  let loc = getSLoc con
+      nargs = length args
+      checkArity ary =
+        if nargs < ary then
+          tcError loc "too few arguments"
+        else if nargs > ary then
+          tcError loc "too many arguments"
+        else
+          return ()
+  case con of
+    -- Pattern synonym
+    ECon (ConSyn qi n (e, t)) -> do
       checkArity n
-      let (_, _, pcon) = mkMatchIdents loc False n
-          vp = EViewPat (EQVar e t) (eApps (EVar pcon) args)
+      let (yes, _) = mkMatchDataTypeConstr (mkPatSynMatch qi) xpt  -- the M constructor of P%T
+          vp = EViewPat (EQVar e t) (eApps yes args)
       --traceM ("patsyn " ++ show vp)
       tcPat mt vp
-     _ -> do 
---      tcTrace (show xpt)
+
+    -- Regular constructor
+    _ -> do 
       case xpt of
          -- Sanity check
          EForall _ _ (EForall _ _ _) -> return ()
-         _ -> impossibleShow i
+         _ -> impossibleShow con
       EForall _ avs apt <- tInst' xpt
 
       (sks, spt) <- shallowSkolemise avs apt
@@ -2298,13 +2324,6 @@
               Infer r   -> do { tSetRefType loc r tt; return pr }
       return (skr, dr, pp)
 
-  EApp f a -> tcPatAp mt (a:args) f
-
-  EParen e -> tcPatAp mt args e
-
-  _ -> tcError (getSLoc afn) ("Bad pattern " ++ show afn)
-  
-
 eTrue :: SLoc -> Expr
 eTrue l = EVar $ mkBuiltin l "True"
 
@@ -2532,46 +2551,142 @@
 
 -----
 
+--
+-- Pattern synonyms look like
+--   pattern P :: forall a1...an . ctxr => forall b1...bm . ctxp => t1 -> ... -> ti -> t
+--   pattern P x1...xi <- p where P = e
+-- (this type is the canonicalized type, generated by canonPatSynType).
+-- The synonym is translated into a builder, a matcher and a type.
+-- Each synonym use is replaced by a simple view pattern.
+--
+-- The builder is simple.  It gets the same name and type as the pattern synonym,
+-- and its definition is the one given in the `where` part of the synonym.
+--   P :: forall a1...an . ctxr => forall b1...bm . ctxp => t1 -> ... -> ti -> t
+--   P = e
+--
+-- The matcher needs to account for possible existentials so we get a data type
+-- for the match result that can have existentials.
+--   data P%T a1...an = forall b1...bm . ctxp => M t1 ... ti
+--                    | N
+--
+-- The matcher itself has the required part of the synonym type, whereas
+-- the provided part is in the data type.  The matcher simply matches on the given pattern.
+--   P% :: forall a1...an . ctxr => t -> P%T a1...an
+--   P% p = M x1...xi
+--   P% _ = N
+-- So when the synonym P matches, the matcher P% returns the M constructor
+-- of the P%T type, and the N constructor when there is no match.
+--
+-- Each use of the pattern synonym
+--   P p1...pi
+-- is replaced by
+--   (P% -> M p1...pi)
+--
+-- The data type, P%T, is not entered into any symbol tables.
+-- The matcher, P%, is in the symbol table, but is not part of the exported symbols.
+-- The transformed expression simply carries enough information about the types
+-- (using EQVar).  The exported ECon for P has this information.
+--
+
+emptyCtx :: EConstraint  -- the empty (0-tuple) context
+emptyCtx = EVar $ tupleConstr noSLoc 0
+
+isEmptyCtx :: EConstraint -> Bool  -- is this exactly emptyCtx?
+isEmptyCtx (EVar i) = i == tupleConstr noSLoc 0
+isEmptyCtx _ = False
+
 -- Expand a pattern synonym into the builder and matcher definitions.
+-- The Pattern definition itself is dropped; the match result data type P%T is added.
 expandPatSyn :: EDef -> T [EDef]
 expandPatSyn (Pattern (i, vks) p me) = do
   (_, t) <- tLookup "type signature" i
-  im <- addPatSynMatch i t
-  let (_, no, yes) = mkMatchIdents (getSLoc i) False (length vks)  -- XXX
-      mexp = fmap (\ e -> Fcn i e) me
+  let (vks1, _ctx1, vks2, ctx2, ty) = splitPatSynType t
+  (im, qim) <- addPatSynMatch i t
+  let (yes, no) = mkMatchDataTypeConstr qim t
+      mexp = fmap (Fcn i) me
       pat = Fcn im [ eEqn [p] match
-                   , eEqn [eDummy] nomatch]
-      match = eApps (EVar yes) (map (EVar . idKindIdent) vks)
-      nomatch = EVar no
-  pure $ pat : maybeToList mexp
+                   , eEqn [eDummy] no]
+      match = eApps yes (map (EVar . idKindIdent) vks)
+      ddata = Data lhs [cm, cn] []
+            where lhs = (mkMatchDataTypeName qim, vks1)  -- NOTE(review): qim is qualified; confirm Data accepts that
+                  cm = Constr vks2 (if isEmptyCtx ctx2 then [] else [ctx2]) mi (Left [ (False, a) | a <- fst (getArrows ty) ])
+                  cn = Constr []   []                                       ni (Left [])
+                  EQVar (ECon (ConData _ mi _)) _ = yes
+                  EQVar (ECon (ConData _ ni _)) _ = no
+  pure $ maybeToList mexp ++ [pat, ddata]
 expandPatSyn d = pure [d]
 
 -- Add the matcher for a pattern synonym to the symbol table.
--- Return the added identifier.
+-- Return the matcher identifier, both unqualified and qualified.
-addPatSynMatch :: Ident -> EType -> T Ident
-addPatSynMatch ps at = do
+addPatSynMatch :: Ident -> EType -> T (Ident, Ident)
+addPatSynMatch i at = do
   mn <- gets moduleName
-  let im = mkPatSynMatch ps
-  extValETop im (mkPatSynType at) (EVar (qualIdent mn im))
-  return im
+  let ip = mkPatSynMatch i
+      qip = qualIdent mn ip
+  extValETop ip (mkPatSynMatchType qip at) (EVar qip)
+  return (ip, qip)
 
-mkPatSynType :: EType -> EType
-mkPatSynType at =
-  let (m, vks, (ats, rt)) =
-        case at of
-          EForall m' vks' (EForall _ _ t) -> (m', vks', getArrows t) -- XXX
-          _ -> impossibleShow at
-      (pstycon, _, _) = mkMatchIdents (getSLoc at) False (length ats) -- XXX
-  in  EForall m vks $ rt `tArrow` tApps pstycon ats -- XXX
+mkPatSynMatchType :: Ident -> EType -> EType  -- matcher type: forall a1...an . ctxr => t -> P%T a1...an
+mkPatSynMatchType qip at =
+  let (vks1, ctx1, _vks2, _ctx2, ty) = splitPatSynType at
+      (_ats, rt) = getArrows ty  -- rt: the scrutinee (result) type t of the synonym
+      pstycon = mkMatchDataTypeName qip
+  in  eForall vks1 $ etImplies ctx1 $ rt `tArrow` tApps pstycon (map (EVar . idKindIdent) vks1)
 
+-- Given the qualified name of a pattern synonym matcher (P%) and the canonical
+-- synonym type, generate the match (M) and no-match (N) constructors with their types.
+mkMatchDataTypeConstr :: HasCallStack => Ident -> EType -> (Expr, Expr)
+mkMatchDataTypeConstr qi at =
+  let (vks1, _ctx1, vks2, ctx2, ty) = splitPatSynType at
+      (ats, _rt) = getArrows ty
+      n = length ats
+      mi = addIdentSuffix qi "M"
+      ni = addIdentSuffix qi "N"
+      -- M's arity counts one extra field when there is a provided context.
+      cti = [ (mi, n + if isEmptyCtx ctx2 then 0 else 1), (ni, 0) ]
+      conm = ConData cti mi []
+      conn = ConData cti ni []
+      tycon = mkMatchDataTypeName qi
+      tr = tApps tycon $ map (EVar . idKindIdent) vks1
+      tn = EForall True vks1 $ EForall True [] tr
+      tm = EForall True vks1 $ EForall True vks2 $ etImplies ctx2 $ foldr tArrow tr ats
+  in  (EQVar (ECon conm) tm,
+       EQVar (ECon conn) tn)
+
 mkPatSynMatch :: Ident -> Ident
 mkPatSynMatch i = addIdentSuffix i "%"
 
-mkMatchIdents :: SLoc -> Bool -> Int -> (Ident, Ident, Ident)
-mkMatchIdents loc ctx n = (mkBuiltinQ loc ("P" ++ sc ++ sn), mkBuiltin loc ("N" ++ sc ++ sn), mkBuiltin loc ("M" ++ sc ++ sn))
-  where sn = show n
-        sc = if ctx then "C" else ""
+mkMatchDataTypeName :: Ident -> Ident  -- name of the matcher's result data type: P% gives P%T
+mkMatchDataTypeName i = addIdentSuffix i "T"
 
+-- canonPatSynType brings a pattern synonym type into the canonical form
+--  forall vs1 . ctx1 => forall vs2 . ctx2 => ty
+--         required             provided
+canonPatSynType :: EType -> T EType
+canonPatSynType at = do
+  at' <- expandSyn at
+  let (vks, t0) =
+        case at' of
+          EForall _ xs t -> (xs, t)
+          t              -> ([], t)
+      (ctx1, ctx2, ty) =
+        case getImplies t0 of
+          Nothing ->           (emptyCtx, emptyCtx, t0)
+          Just (c1, t1) ->
+            case getImplies t1 of
+              Nothing       -> (c1,       emptyCtx, t1)
+              Just (c2, t2) -> (c1,       c2,       t2)
+      vs2 = freeTyVars [ctx2]  -- NOTE(review): a var free in ctx2 that also occurs in ty becomes existential; confirm intended
+      (vks2, vks1) = partition ((`elem` vs2) . idKindIdent) vks
+  pure $ EForall True vks1 $ tImplies ctx1 $ EForall True vks2 $ tImplies ctx2 ty
+
+splitPatSynType :: EType -> ([IdKind], EConstraint, [IdKind], EConstraint, EType)  -- take apart a canonical type
+splitPatSynType (EForall _ vks1 t0)
+  | Just  (ctx1, EForall _ vks2 t1) <- getImplies t0
+  , Just  (ctx2, ty) <- getImplies t1
+  = (vks1, ctx1, vks2, ctx2, ty)
+splitPatSynType t = impossibleShow t  -- only canonical (canonPatSynType) types are expected here
+
 -----
 
 -- Given a dictionary of a (constraint type), split it up
@@ -2681,6 +2796,7 @@
           Nothing -> [c]
           Just _ -> concatMap flatten ts
   return $ loop [] ais
+
 
 {-
 showInstInfo :: InstInfo -> String
--- a/tests/PatSynE.hs
+++ b/tests/PatSynE.hs
@@ -17,3 +17,6 @@
 dup [x,x'] | x==x' = Just x
 dup _ = Nothing
 -}
+
+pattern One :: forall a . (Eq a) => a
+pattern One <- 1  -- NOTE(review): literal 1 also needs Num a, absent from the required context; presumably the error under test — confirm