shithub: MicroHs

Download patch

ref: 3bc7e25513e7b7057ab21fb88d4ba562b6b1f06a
parent: 4f57f6768020437e82cdedb959000058fa766630
author: Lennart Augustsson <lennart@augustsson.net>
date: Thu Dec 12 15:40:05 EST 2024

First version of pattern synonyms

--- a/lib/Mhs/Builtin.hs
+++ b/lib/Mhs/Builtin.hs
@@ -12,6 +12,8 @@
   module Data.Semigroup,
   module Data.String,
   module Text.Show,
+  P0(..), P1(..), P2(..), P3(..), P4(..),
+  PC0(..), PC1(..), PC2(..), PC3(..), PC4(..),
   ) where
 import Prelude()
 import Control.Error(error)
@@ -30,3 +32,19 @@
 import Data.Records(HasField(..), SetField(..), composeSet)
 import {-# SOURCE #-} Data.Typeable(Typeable(..), mkTyConApp, mkTyCon)
 import Text.Show(Show(..), showString, showParen)
+
+-- Result types of the matching functions generated for pattern synonyms.
+-- The number is the synonym's arity, i.e., how many values a successful match binds.
+-- Nx means the synonym did not match; Mx means it matched and carries the bound values.
+data P0 = N0 | M0
+data P1 a1 = N1 | M1 a1
+data P2 a1 a2 = N2 | M2 a1 a2
+data P3 a1 a2 a3 = N3 | M3 a1 a2 a3
+data P4 a1 a2 a3 a4 = N4 | M4 a1 a2 a3 a4
+
+-- Variants for synonyms with a constructor context: the MCx constructors carry a ctx constraint.
+data PC0 ctx = NC0 | ctx => MC0
+data PC1 ctx a1 = NC1 | ctx => MC1 a1
+data PC2 ctx a1 a2 = NC2 | ctx => MC2 a1 a2
+data PC3 ctx a1 a2 a3 = NC3 | ctx => MC3 a1 a2 a3
+data PC4 ctx a1 a2 a3 a4 = NC4 | ctx => MC4 a1 a2 a3 a4
--- a/src/MicroHs/Builtin.hs
+++ b/src/MicroHs/Builtin.hs
@@ -1,6 +1,7 @@
 module MicroHs.Builtin(
   builtinMdl,
   mkBuiltin,
+  mkBuiltinQ,
   ) where
 import Prelude(); import MHSPrelude
 import MicroHs.Ident
@@ -12,6 +13,13 @@
 -- cannot be used accidentally in user code.
 builtinMdl :: String
 builtinMdl = "B@"
+builtinMdlQ :: String
+builtinMdlQ = "Mhs.Builtin"  -- qualifier of the Mhs.Builtin library module itself
 
+-- Identifier for a builtin that will be renamed.
 mkBuiltin :: SLoc -> String -> Ident
 mkBuiltin loc name = mkIdentSLoc loc ((builtinMdl ++ ".") ++ name)
+
+-- Identifier for a builtin that is already renamed (qualified with builtinMdlQ).
+mkBuiltinQ :: SLoc -> String -> Ident
+mkBuiltinQ loc name = mkIdentSLoc loc ((builtinMdlQ ++ ".") ++ name)
\ No newline at end of file
--- a/src/MicroHs/Expr.hs
+++ b/src/MicroHs/Expr.hs
@@ -163,6 +163,7 @@
 data Con
   = ConData ConTyInfo Ident [FieldName]
   | ConNew Ident [FieldName]
+  | ConSyn Ident Int  -- pattern synonym: its (qualified) name and its arity
 --DEBUG  deriving(Show)
 
 data Listish
@@ -178,18 +179,22 @@
             Con -> Ident
 conIdent (ConData _ i _) = i
 conIdent (ConNew i _) = i
+conIdent (ConSyn i _) = i
 
 conArity :: Con -> Int
 conArity (ConData cs i _) = fromMaybe (error "conArity") $ lookup i cs
 conArity (ConNew _ _) = 1
+conArity (ConSyn _ n) = n
 
 conFields :: Con -> [FieldName]
 conFields (ConData _ _ fs) = fs
 conFields (ConNew _ fs) = fs
+conFields (ConSyn _ _) = []  -- no field names are recorded for pattern synonyms
 
 instance Eq Con where
   (==) (ConData _ i _) (ConData _ j _) = i == j
   (==) (ConNew    i _) (ConNew    j _) = i == j
+  (==) (ConSyn    i _) (ConSyn    j _) = i == j
   (==) _               _               = False
 
 data Lit
@@ -389,6 +394,7 @@
 instance HasLoc Con where
   getSLoc (ConData _ i _) = getSLoc i
   getSLoc (ConNew i _) = getSLoc i
+  getSLoc (ConSyn i _) = getSLoc i
 
 instance HasLoc Listish where
   getSLoc (LList es) = getSLoc es
@@ -566,6 +572,7 @@
 setSLocCon :: SLoc -> Con -> Con
 setSLocCon l (ConData ti i fs) = ConData ti (setSLocIdent l i) fs
 setSLocCon l (ConNew i fs) = ConNew (setSLocIdent l i) fs
+setSLocCon l (ConSyn i n) = ConSyn (setSLocIdent l i) n
 
 errorMessage :: forall a .
                 HasCallStack =>
@@ -773,6 +780,7 @@
 ppCon :: Con -> Doc
 ppCon (ConData _ s _) = ppIdent s
 ppCon (ConNew s _) = ppIdent s
+ppCon (ConSyn s _) = ppIdent s
 
 -- Literals are tagged the way they appear in the combinator file:
 --  #   Int
--- a/src/MicroHs/TypeCheck.hs
+++ b/src/MicroHs/TypeCheck.hs
@@ -1293,10 +1293,16 @@
   -- type infer and enter each SCC in the symbol table
   -- return inferred Sign
   signDefs <- mapM tcSCC sccs
-  --  type check all definitions (the inferred ones will be rechecked)
+  defs' <- concat <$> mapM expandPatSyn defs  -- replace each pattern synonym by its matcher (and builder) Fcn
+--  traceM $ "tcDefsValue: ------------ expandPatSyn"
+--  traceM $ showEDefs defs'
 --  tcTrace $ "tcDefsValue: ------------ check"
-  defs' <- mapM (\ d -> do { tcReset; tcDefValue d}) defs
-  tcPatSyn $ concat signDefs ++ defs'
+  --  type check all definitions (the inferred ones will be rechecked)
+  defs'' <- mapM (\ d -> do { tcReset; tcDefValue d}) defs'
+  let defs''' = concat signDefs ++ defs''
+--  traceM $ "tcDefsValue: ------------ done"
+--  traceM $ showEDefs defs'''
+  pure defs'''
 
 -- Infer a type for a definition
 tInferDefs :: [EDef] -> T [EDef]
@@ -1320,7 +1326,8 @@
   ctx <- getUnsolved
   -- For each definition, quantify over the free meta variables, and include
   -- context mentioning them.
-  let genTop :: (Ident, EType) -> T EDef
+  let isPatSyn = isConIdent   -- hacky way to recognize pattern synonyms
+      genTop :: (Ident, EType) -> T EDef
       genTop (i, t) = do
         t' <- derefUVar t
         let vs = metaTvs [t']
@@ -1329,8 +1336,12 @@
             vs' = metaTvs [t'']
         t''' <- quantify vs' t''
         --tcTrace $ "tInferDefs: " ++ showIdent i ++ " :: " ++ showEType t'''
-        extValQTop i t'''
-        return $ Sign [i] t'''
+        if isPatSyn i then do  -- inferred pattern synonym: register via addPatSyn, yield a PatternSign
+          addPatSyn t''' i
+          return $ PatternSign [i] t'''
+         else do
+          extValQTop i t'''
+          return $ Sign [i] t'''
   mapM genTop xts
 
 getUnsolved :: T [EConstraint]
@@ -1372,16 +1383,23 @@
       addConFields tycon con
     ForImp _ i t -> extValQTop i t
     Class ctx (i, vks) fds ms -> addValueClass ctx i vks fds ms
-    PatternSign is at -> do
-      let t' =
-            -- Patterns must have two universals.
-            -- XXX Add double contexts
-            case at of
-              EForall b vs t -> EForall b     vs $ EForall False []  t
-              _              -> EForall False [] $ EForall False [] at
-      mapM_ (\ i -> extValQTop i t') is
+    PatternSign is at -> mapM_ (addPatSyn at) is
     _ -> return ()
 
+-- Add pattern synonym i (signature at) to the symbol table as a ConSyn with its arity.
+addPatSyn :: EType -> Ident -> T ()
+addPatSyn at i = do
+  at' <- expandSyn at
+  let (t', n) =
+        -- Patterns must have two universals.
+        -- XXX Add double contexts
+        case at' of
+          EForall b vs t -> (EForall b     vs $ EForall False []  t, arity t)
+          _              -> (EForall False [] $ EForall False [] at', arity at')
+      arity = length . fst . getArrows  -- XXX: number of argument types of the synonym
+  mn <- gets moduleName
+  extValETop i t' $ ECon $ ConSyn (qualIdent mn i) n
+
 -- XXX FunDep
 addValueClass :: [EConstraint] -> Ident -> [IdKind] -> [FunDep] -> [EBind] -> T ()
 addValueClass ctx iCls vks fds ms = do
@@ -1417,9 +1435,9 @@
     Fcn i eqns -> do
       (_, t) <- tLookup "type signature" i
       t' <- expandSyn t
---      tcTrace $ "tcDefValue: ------- start " ++ showIdent i
---      tcTrace $ "tcDefValue: " ++ showIdent i ++ " :: " ++ showExpr t'
---      tcTrace $ "tcDefValue: " ++ showEDefs [adef]
+      tcTrace $ "tcDefValue: ------- start " ++ showIdent i  -- NOTE(review): these three traces were commented out before this patch; confirm tcTrace is silent in normal builds
+      tcTrace $ "tcDefValue: " ++ showIdent i ++ " :: " ++ showExpr t'
+      tcTrace $ "tcDefValue: " ++ showEDefs [adef]
       teqns <- tcEqns True t' eqns
 --      tcTrace ("tcDefValue: after\n" ++ showEDefs [adef, Fcn i teqns])
 --      cs <- gets constraints
@@ -1440,7 +1458,7 @@
 
 tcPattern :: EDef -> EType -> T EDef
 tcPattern (Pattern (ip, vks) p me) at = do
-  traceM ("Pattern " ++ show (ip, vks, p, me, at))
+--  traceM ("Pattern " ++ show (ip, vks, p, me, at))
   let step [] t = tcPat (Check t) p
       step (ik:iks) t = do
         (ti, tr) <- unArrow (getSLoc ik) t
@@ -1451,7 +1469,7 @@
   me' <- traverse (tcEqns True at) me
   mn <- gets moduleName
   checkConstraints
-  traceM ("Pattern after " ++ show (qualIdent mn ip, vks, p', me'))
+--  traceM ("Pattern after " ++ show (qualIdent mn ip, vks, p', me'))
   return $ Pattern (qualIdent mn ip, vks) p' me'
 tcPattern _ _ = error "tcPattern"
 
@@ -1694,7 +1712,7 @@
                 failMsg s = EApp (EVar (mkBuiltin loc "fail")) (ELit loc (LStr s))
                 failAlt =
                   if nofail then []
-                  else [(EVar dummyIdent, simpleAlts $ failMsg "bind")]
+                  else [(eDummy, simpleAlts $ failMsg "bind")]
               tcExpr mt (EApp (EApp (EVar sbind) a)
                               (eLam [x] (ECase x (patAlt ++ failAlt))))
             SThen a -> do
@@ -2203,16 +2221,16 @@
     EUpdate p [] -> do
       (p', _) <- tInferExpr p
       case p' of
-        ECon c -> tcPat mt $ eApps p (replicate (conArity c) (EVar dummyIdent))          
+        ECon c -> tcPat mt $ eApps p (replicate (conArity c) eDummy)          
         _      -> impossible
     EUpdate p isps -> do
-      me <- dsUpdate (const $ EVar dummyIdent) p isps
+      me <- dsUpdate (const eDummy) p isps
       case me of
         Just p' -> tcPat mt p'
         Nothing -> impossible
 
     EOr ps -> do
-      let orFun = ELam $ [ eEqn [p] true | p <- ps] ++ [ eEqn [EVar dummyIdent] (eFalse loc) ]
+      let orFun = ELam $ [ eEqn [p] true | p <- ps] ++ [ eEqn [eDummy] (eFalse loc) ]  -- True if any of ps matches, else False
           true = eTrue loc
       tcPat mt $ EViewPat orFun true
 
@@ -2223,10 +2241,26 @@
            Expected -> [EPat] -> EPat -> T EPatRet
 --tcPatAp mt args afn | trace ("tcPatAp: " ++ show (mt, args, afn)) False = undefined
 tcPatAp mt args afn =
-  case afn of
-    EVar i | isConIdent i -> do
-      let loc = getSLoc i
-      (con, xpt) <- tLookupV i
+ case afn of
+  EVar i | isConIdent i -> do
+    let loc = getSLoc i
+        nargs = length args
+        checkArity ary =  -- reject patterns whose argument count differs from the constructor's arity
+          if nargs < ary then
+            tcError loc "too few arguments"
+          else if nargs > ary then
+            tcError loc "too many arguments"
+          else
+            return ()
+    (con, xpt) <- tLookupV i
+    case con of
+     ECon (ConSyn ps n) -> do  -- pattern synonym: rewrite to the view pattern (ps% -> Mn args)
+      checkArity n
+      let (_, _, pcon) = mkMatchIdents loc False n
+          vp = EViewPat (EVar $ mkPatSynMatch ps) (eApps (EVar pcon) args)
+--      traceM ("patsyn " ++ show vp)
+      tcPat mt vp
+     _ -> do
 --      tcTrace (show xpt)
       case xpt of
          -- Sanity check
@@ -2246,13 +2280,7 @@
             where arity (ECon c) = conArity c
                   arity (EApp f _) = arity f - 1  -- deal with dictionary added above
                   arity e = impossibleShow e
-          nargs = length args
-      if nargs < ary then
-        tcError loc "too few arguments"
-       else if nargs > ary then
-        tcError loc "too many arguments"
-       else
-        return ()
+      checkArity ary
 
       let step [] t r = return (t, r)
           step (a:as) t (sk, d, f) = do
@@ -2266,11 +2294,11 @@
               Infer r   -> do { tSetRefType loc r tt; return pr }
       return (skr, dr, pp)
 
-    EApp f a -> tcPatAp mt (a:args) f
+  EApp f a -> tcPatAp mt (a:args) f
 
-    EParen e -> tcPatAp mt args e
+  EParen e -> tcPatAp mt args e
 
-    _ -> tcError (getSLoc afn) ("Bad pattern " ++ show afn)
+  _ -> tcError (getSLoc afn) ("Bad pattern " ++ show afn)
  
 
 eTrue :: SLoc -> Expr
@@ -2374,7 +2402,7 @@
   let usedVars = allVarsExpr ty -- Avoid used type variables
       newVars = take (length tvs) (allBinders \\ usedVars)
       newVarsK = map (\ i -> IdKind i noKind) newVars
-      noKind = EVar dummyIdent
+      noKind = eDummy  -- placeholder kind for the binders in newVarsK
   osubst <- gets uvarSubst
   zipWithM_ (\ tv n -> setUVar tv (EVar n)) tvs newVars
   ty' <- derefUVar ty
@@ -2500,6 +2528,46 @@
 
 -----
 
+-- Expand a pattern synonym into its matcher and (when me is Just) builder definitions.
+expandPatSyn :: EDef -> T [EDef]
+expandPatSyn (Pattern (i, vks) p me) = do
+  (_, t) <- tLookup "type signature" i
+  im <- addPatSynMatch i t
+  let (_, no, yes) = mkMatchIdents (getSLoc i) False (length vks)  -- XXX
+      mexp = fmap (Fcn i) me  -- builder, only present when me is Just
+      pat = Fcn im [ eEqn [p] match
+                   , eEqn [eDummy] nomatch]
+      match = eApps (EVar yes) (map (EVar . idKindIdent) vks)  -- Mn applied to the bound variables
+      nomatch = EVar no  -- Nn: the synonym's pattern did not match
+  pure $ pat : maybeToList mexp
+expandPatSyn d = pure [d]
+
+-- Add the matcher for a pattern synonym to the symbol table.
+-- Return the added identifier.
+addPatSynMatch :: Ident -> EType -> T Ident
+addPatSynMatch ps at = do
+  mn <- gets moduleName
+  let im = mkPatSynMatch ps
+      (m, vks, (ats, rt)) =
+        case at of
+          EForall m' vks' (EForall _ _ t) -> (m', vks', getArrows t) -- XXX
+          _ -> impossible
+      (pstycon, _, _) = mkMatchIdents (getSLoc ps) False (length ats) -- XXX
+      conty = EForall m vks $ rt `tArrow` tApps pstycon ats -- XXX the inner forall of at is dropped
+--      con = ConSyn (qualIdent mn im) (length ats)
+  extValETop im conty (EVar (qualIdent mn im))
+  return im
+
+mkPatSynMatch :: Ident -> Ident  -- matcher name: the synonym's name with a "%" suffix
+mkPatSynMatch i = addIdentSuffix i "%"
+
+mkMatchIdents :: SLoc -> Bool -> Int -> (Ident, Ident, Ident)  -- (Pn type, Nn no-match con, Mn match con)
+mkMatchIdents loc ctx n = (mkBuiltinQ loc ("P" ++ sc ++ sn), mkBuiltin loc ("N" ++ sc ++ sn), mkBuiltin loc ("M" ++ sc ++ sn))
+  where sn = show n
+        sc = if ctx then "C" else ""
+
+-----
+
 -- Given a dictionary of a (constraint type), split it up
 --  * components of a tupled constraint
 --  * superclasses of a constraint
@@ -2979,69 +3047,3 @@
       _ -> tcError (getSLoc act) ("not data/newtype " ++ showIdent tname)
   -- We want 'instance ctx => cls ty'
   deriveNoHdr act lhs cs cls
-
-tcPatSyn :: [EDef] -> T [EDef]
-tcPatSyn ds = do
-  let patSyns = [ (i, iks, p, mes) | Pattern (i, iks) p mes <- ds ]
-  if null patSyns then
-    return ds
-   else do
-    let ds' = ds ++ [ Fcn i es | (i, _, _, Just es) <- patSyns ]
-        ps = M.fromList [ (i, (map idKindIdent iks, p)) | (i, iks, p, _) <- patSyns ]
-        tr as (EVar i) | Just (vs, p) <- M.lookup i ps = if length as /= length vs then tr [] $ subst (zip vs as) p
-                                                         else errorMessage (getSLoc i) "Bad synonym arity"
-        tr as (EApp f a) = tr (as ++ [a]) f
-        tr [] p = p
-        tr _ _ = undefined
-    return $ transformPat (tr []) ds'
-
-class TransformPat a where
-  transformPat :: (EPat -> EPat) -> a -> a
-
-instance (TransformPat a) => TransformPat [a] where
-  transformPat f es = map (transformPat f) es
-
-instance TransformPat EDef where
-  transformPat f (Fcn i eqns) = Fcn i (transformPat f eqns)
-  transformPat _ d = d
-
-instance TransformPat Expr where
-  transformPat _ e@(EVar _) = e
-  transformPat f (EApp e1 e2) = EApp (transformPat f e1) (transformPat f e2)
-  transformPat f (ELam es) = ELam (transformPat f es)
-  transformPat _ e@(ELit _ _) = e
-  transformPat f (ECase e arms) = ECase (transformPat f e) (transformPat f arms)
-  transformPat f (ELet bs e) = ELet (transformPat f bs) (transformPat f e)
-  transformPat f (EListish l) = EListish (transformPat f l)
-  transformPat f (EIf e1 e2 e3) = EIf (transformPat f e1) (transformPat f e2) (transformPat f e3)
-  transformPat _ e@(ECon _) = e
-  transformPat _ e = impossibleShow e
-
-instance TransformPat Listish where
-  transformPat f (LList es) = LList (transformPat f es)
-  transformPat f (LCompr e ss) = LCompr (transformPat f e) (transformPat f ss)
-  transformPat _ _ = impossible
-
-instance TransformPat ECaseArm where
-  transformPat f (p, alts) = (f p, transformPat f alts)
-
-instance TransformPat EStmt where
-  transformPat f (SBind p e) = SBind (f p) (transformPat f e)
-  transformPat f (SThen e) = SThen (transformPat f e)
-  transformPat f (SLet bs) = SLet (transformPat f bs)
-
-instance TransformPat EBind where
-  transformPat f (BFcn i es) = BFcn i (transformPat f es)
-  transformPat f (BPat p e) = BPat (f p) (transformPat f e)
-  transformPat _ b = b
-
-instance TransformPat Eqn where
-  transformPat f (Eqn ps as) = Eqn (map f ps) (transformPat f as)
-
-instance TransformPat EAlts where
-  transformPat f (EAlts as bs) = EAlts (transformPat f as) (transformPat f bs)
-
-instance TransformPat EAlt where
-  transformPat f (ss, e) = (transformPat f ss, transformPat f e)
-
-