ref: 3bc7e25513e7b7057ab21fb88d4ba562b6b1f06a
parent: 4f57f6768020437e82cdedb959000058fa766630
author: Lennart Augustsson <lennart@augustsson.net>
date: Thu Dec 12 15:40:05 EST 2024
First version of pattern synonyms
--- a/lib/Mhs/Builtin.hs
+++ b/lib/Mhs/Builtin.hs
@@ -12,6 +12,8 @@
module Data.Semigroup,
module Data.String,
module Text.Show,
+ P0(..), P1(..), P2(..), P3(..), P4(..),
+ PC0(..), PC1(..), PC2(..), PC3(..), PC4(..),
) where
import Prelude()
import Control.Error(error)
@@ -30,3 +32,19 @@
import Data.Records(HasField(..), SetField(..), composeSet)
import {-# SOURCE #-} Data.Typeable(Typeable(..), mkTyConApp, mkTyCon)
import Text.Show(Show(..), showString, showParen)
+
+-- These types are used as return values for pattern synonym matching functions.
+-- The number suffix is the arity of the synonym, i.e. how many values a successful match binds.
+-- Nx signals a failed match; Mx signals a successful match and carries the bound values.
+data P0 = N0 | M0
+data P1 a1 = N1 | M1 a1
+data P2 a1 a2 = N2 | M2 a1 a2
+data P3 a1 a2 a3 = N3 | M3 a1 a2 a3
+data P4 a1 a2 a3 a4 = N4 | M4 a1 a2 a3 a4
+
+-- Same as Px, but for synonyms with a constructor context: the MCx constructor also carries ctx.
+data PC0 ctx = NC0 | ctx => MC0
+data PC1 ctx a1 = NC1 | ctx => MC1 a1
+data PC2 ctx a1 a2 = NC2 | ctx => MC2 a1 a2
+data PC3 ctx a1 a2 a3 = NC3 | ctx => MC3 a1 a2 a3
+data PC4 ctx a1 a2 a3 a4 = NC4 | ctx => MC4 a1 a2 a3 a4
--- a/src/MicroHs/Builtin.hs
+++ b/src/MicroHs/Builtin.hs
@@ -1,6 +1,7 @@
module MicroHs.Builtin(
builtinMdl,
mkBuiltin,
+ mkBuiltinQ,
) where
import Prelude(); import MHSPrelude
import MicroHs.Ident
@@ -12,6 +13,13 @@
-- cannot be used accidentally in user code.
builtinMdl :: String
builtinMdl = "B@"
+builtinMdlQ :: String -- the real name of the module defining the builtins
+builtinMdlQ = "Mhs.Builtin"
+-- Identifier for a builtin that will be renamed.
mkBuiltin :: SLoc -> String -> Ident
mkBuiltin loc name = mkIdentSLoc loc ((builtinMdl ++ ".") ++ name)
+
+-- Identifier for a builtin that is already renamed, i.e. qualified with the real module name Mhs.Builtin.
+mkBuiltinQ :: SLoc -> String -> Ident
+mkBuiltinQ loc name = mkIdentSLoc loc ((builtinMdlQ ++ ".") ++ name)
\ No newline at end of file
--- a/src/MicroHs/Expr.hs
+++ b/src/MicroHs/Expr.hs
@@ -163,6 +163,7 @@
data Con
= ConData ConTyInfo Ident [FieldName]
| ConNew Ident [FieldName]
+ | ConSyn Ident Int -- pattern synonym: its qualified name and its arity
--DEBUG deriving(Show)
data Listish
@@ -178,18 +179,22 @@
Con -> Ident
conIdent (ConData _ i _) = i
conIdent (ConNew i _) = i
+conIdent (ConSyn i _) = i
conArity :: Con -> Int
conArity (ConData cs i _) = fromMaybe (error "conArity") $ lookup i cs
conArity (ConNew _ _) = 1
+conArity (ConSyn _ n) = n
conFields :: Con -> [FieldName]
conFields (ConData _ _ fs) = fs
conFields (ConNew _ fs) = fs
+conFields (ConSyn _ _) = [] -- record fields are not tracked for pattern synonyms
instance Eq Con where
(==) (ConData _ i _) (ConData _ j _) = i == j
(==) (ConNew i _) (ConNew j _) = i == j
+ (==) (ConSyn i _) (ConSyn j _) = i == j
(==) _ _ = False
data Lit
@@ -389,6 +394,7 @@
instance HasLoc Con where
getSLoc (ConData _ i _) = getSLoc i
getSLoc (ConNew i _) = getSLoc i
+ getSLoc (ConSyn i _) = getSLoc i
instance HasLoc Listish where
getSLoc (LList es) = getSLoc es
@@ -566,6 +572,7 @@
setSLocCon :: SLoc -> Con -> Con
setSLocCon l (ConData ti i fs) = ConData ti (setSLocIdent l i) fs
setSLocCon l (ConNew i fs) = ConNew (setSLocIdent l i) fs
+setSLocCon l (ConSyn i n) = ConSyn (setSLocIdent l i) n
errorMessage :: forall a .
HasCallStack =>
@@ -773,6 +780,7 @@
ppCon :: Con -> Doc
ppCon (ConData _ s _) = ppIdent s
ppCon (ConNew s _) = ppIdent s
+ppCon (ConSyn s _) = ppIdent s
-- Literals are tagged the way they appear in the combinator file:
-- # Int
--- a/src/MicroHs/TypeCheck.hs
+++ b/src/MicroHs/TypeCheck.hs
@@ -1293,10 +1293,16 @@
-- type infer and enter each SCC in the symbol table
-- return inferred Sign
signDefs <- mapM tcSCC sccs
- -- type check all definitions (the inferred ones will be rechecked)
+ defs' <- concat <$> mapM expandPatSyn defs -- replace each pattern synonym by its matcher (and builder, if any)
+-- traceM $ "tcDefsValue: ------------ expandPatSyn"
+-- traceM $ showEDefs defs'
-- tcTrace $ "tcDefsValue: ------------ check"
- defs' <- mapM (\ d -> do { tcReset; tcDefValue d}) defs
- tcPatSyn $ concat signDefs ++ defs'
+ -- type check all definitions (the inferred ones will be rechecked)
+ defs'' <- mapM (\ d -> do { tcReset; tcDefValue d}) defs'
+ let defs''' = concat signDefs ++ defs''
+-- traceM $ "tcDefsValue: ------------ done"
+-- traceM $ showEDefs defs'''
+ pure defs'''
-- Infer a type for a definition
tInferDefs :: [EDef] -> T [EDef]
@@ -1320,7 +1326,8 @@
ctx <- getUnsolved
-- For each definition, quantify over the free meta variables, and include
-- context mentioning them.
- let genTop :: (Ident, EType) -> T EDef
+ let isPatSyn = isConIdent -- XXX hacky: an inferred binding with a constructor-style name is taken to be a pattern synonym
+ genTop :: (Ident, EType) -> T EDef
genTop (i, t) = do
t' <- derefUVar t
let vs = metaTvs [t']
@@ -1329,8 +1336,12 @@
vs' = metaTvs [t'']
t''' <- quantify vs' t''
--tcTrace $ "tInferDefs: " ++ showIdent i ++ " :: " ++ showEType t'''
- extValQTop i t'''
- return $ Sign [i] t'''
+ if isPatSyn i then do -- register via addPatSyn and produce a PatternSign instead of a Sign
+ addPatSyn t''' i
+ return $ PatternSign [i] t'''
+ else do
+ extValQTop i t'''
+ return $ Sign [i] t'''
mapM genTop xts
getUnsolved :: T [EConstraint]
@@ -1372,16 +1383,23 @@
addConFields tycon con
ForImp _ i t -> extValQTop i t
Class ctx (i, vks) fds ms -> addValueClass ctx i vks fds ms
- PatternSign is at -> do
- let t' =
- -- Patterns must have two universals.
- -- XXX Add double contexts
- case at of
- EForall b vs t -> EForall b vs $ EForall False [] t
- _ -> EForall False [] $ EForall False [] at
- mapM_ (\ i -> extValQTop i t') is
+ PatternSign is at -> mapM_ (addPatSyn at) is
_ -> return ()
+-- Add a pattern synonym to the symbol table: its synonym-expanded type in two-universal form, and a ConSyn with its qualified name and arity.
+addPatSyn :: EType -> Ident -> T ()
+addPatSyn at i = do
+ at' <- expandSyn at
+ let (t', n) =
+ -- Patterns must have two universals.
+ -- XXX Add double contexts
+ case at' of
+ EForall b vs t -> (EForall b vs $ EForall False [] t, arity t)
+ _ -> (EForall False [] $ EForall False [] at', arity at') -- use the expanded type here too, so hidden arrows are counted
+ arity = length . fst . getArrows -- XXX
+ mn <- gets moduleName
+ extValETop i t' $ ECon $ ConSyn (qualIdent mn i) n
+
-- XXX FunDep
addValueClass :: [EConstraint] -> Ident -> [IdKind] -> [FunDep] -> [EBind] -> T ()
addValueClass ctx iCls vks fds ms = do
@@ -1417,9 +1435,9 @@
Fcn i eqns -> do
(_, t) <- tLookup "type signature" i
t' <- expandSyn t
--- tcTrace $ "tcDefValue: ------- start " ++ showIdent i
--- tcTrace $ "tcDefValue: " ++ showIdent i ++ " :: " ++ showExpr t'
--- tcTrace $ "tcDefValue: " ++ showEDefs [adef]
+ tcTrace $ "tcDefValue: ------- start " ++ showIdent i -- NOTE(review): debug traces enabled by this change; confirm tcTrace is silent by default
+ tcTrace $ "tcDefValue: " ++ showIdent i ++ " :: " ++ showExpr t'
+ tcTrace $ "tcDefValue: " ++ showEDefs [adef]
teqns <- tcEqns True t' eqns
-- tcTrace ("tcDefValue: after\n" ++ showEDefs [adef, Fcn i teqns])
-- cs <- gets constraints
@@ -1440,7 +1458,7 @@
tcPattern :: EDef -> EType -> T EDef
tcPattern (Pattern (ip, vks) p me) at = do
- traceM ("Pattern " ++ show (ip, vks, p, me, at))
+-- traceM ("Pattern " ++ show (ip, vks, p, me, at))
let step [] t = tcPat (Check t) p
step (ik:iks) t = do
(ti, tr) <- unArrow (getSLoc ik) t
@@ -1451,7 +1469,7 @@
me' <- traverse (tcEqns True at) me
mn <- gets moduleName
checkConstraints
- traceM ("Pattern after " ++ show (qualIdent mn ip, vks, p', me'))
+-- traceM ("Pattern after " ++ show (qualIdent mn ip, vks, p', me'))
return $ Pattern (qualIdent mn ip, vks) p' me'
tcPattern _ _ = error "tcPattern"
@@ -1694,7 +1712,7 @@
failMsg s = EApp (EVar (mkBuiltin loc "fail")) (ELit loc (LStr s))
failAlt =
if nofail then []
- else [(EVar dummyIdent, simpleAlts $ failMsg "bind")]
+ else [(eDummy, simpleAlts $ failMsg "bind")]
tcExpr mt (EApp (EApp (EVar sbind) a)
(eLam [x] (ECase x (patAlt ++ failAlt))))
SThen a -> do
@@ -2203,16 +2221,16 @@
EUpdate p [] -> do
(p', _) <- tInferExpr p
case p' of
- ECon c -> tcPat mt $ eApps p (replicate (conArity c) (EVar dummyIdent))
+ ECon c -> tcPat mt $ eApps p (replicate (conArity c) eDummy)
_ -> impossible
EUpdate p isps -> do
- me <- dsUpdate (const $ EVar dummyIdent) p isps
+ me <- dsUpdate (const eDummy) p isps
case me of
Just p' -> tcPat mt p'
Nothing -> impossible
EOr ps -> do
- let orFun = ELam $ [ eEqn [p] true | p <- ps] ++ [ eEqn [EVar dummyIdent] (eFalse loc) ]
+ let orFun = ELam $ [ eEqn [p] true | p <- ps] ++ [ eEqn [eDummy] (eFalse loc) ]
true = eTrue loc
tcPat mt $ EViewPat orFun true
@@ -2223,10 +2241,26 @@
Expected -> [EPat] -> EPat -> T EPatRet
--tcPatAp mt args afn | trace ("tcPatAp: " ++ show (mt, args, afn)) False = undefined
tcPatAp mt args afn =
- case afn of
- EVar i | isConIdent i -> do
- let loc = getSLoc i
- (con, xpt) <- tLookupV i
+ case afn of
+ EVar i | isConIdent i -> do
+ let loc = getSLoc i
+ nargs = length args
+ checkArity ary = -- shared by pattern synonyms and ordinary constructors
+ if nargs < ary then
+ tcError loc "too few arguments"
+ else if nargs > ary then
+ tcError loc "too many arguments"
+ else
+ return ()
+ (con, xpt) <- tLookupV i
+ case con of
+ ECon (ConSyn ps n) -> do -- pattern synonym: match it as a view pattern through its matcher function
+ checkArity n
+ let (_, _, pcon) = mkMatchIdents loc False n
+ vp = EViewPat (EVar $ mkPatSynMatch ps) (eApps (EVar pcon) args)
+-- traceM ("patsyn " ++ show vp)
+ tcPat mt vp
+ _ -> do
-- tcTrace (show xpt)
case xpt of
-- Sanity check
@@ -2246,13 +2280,7 @@
where arity (ECon c) = conArity c
arity (EApp f _) = arity f - 1 -- deal with dictionary added above
arity e = impossibleShow e
- nargs = length args
- if nargs < ary then
- tcError loc "too few arguments"
- else if nargs > ary then
- tcError loc "too many arguments"
- else
- return ()
+ checkArity ary
let step [] t r = return (t, r)
step (a:as) t (sk, d, f) = do
@@ -2266,11 +2294,11 @@
Infer r -> do { tSetRefType loc r tt; return pr }
return (skr, dr, pp)
- EApp f a -> tcPatAp mt (a:args) f
+ EApp f a -> tcPatAp mt (a:args) f
- EParen e -> tcPatAp mt args e
+ EParen e -> tcPatAp mt args e
- _ -> tcError (getSLoc afn) ("Bad pattern " ++ show afn)
+ _ -> tcError (getSLoc afn) ("Bad pattern " ++ show afn)
eTrue :: SLoc -> Expr
@@ -2374,7 +2402,7 @@
let usedVars = allVarsExpr ty -- Avoid used type variables
newVars = take (length tvs) (allBinders \\ usedVars)
newVarsK = map (\ i -> IdKind i noKind) newVars
- noKind = EVar dummyIdent
+ noKind = eDummy
osubst <- gets uvarSubst
zipWithM_ (\ tv n -> setUVar tv (EVar n)) tvs newVars
ty' <- derefUVar ty
@@ -2500,6 +2528,46 @@
-----
+-- Expand a pattern synonym into its matcher (returning Mn on a match, Nn otherwise) and, if present, its builder.
+expandPatSyn :: EDef -> T [EDef]
+expandPatSyn (Pattern (i, vks) p me) = do
+ (_, t) <- tLookup "type signature" i
+ im <- addPatSynMatch i t
+ let (_, no, yes) = mkMatchIdents (getSLoc i) False (length vks) -- XXX constructor contexts (PCn) not handled yet
+ mexp = fmap (\ e -> Fcn i e) me
+ pat = Fcn im [ eEqn [p] match
+ , eEqn [eDummy] nomatch]
+ match = eApps (EVar yes) (map (EVar . idKindIdent) vks)
+ nomatch = EVar no
+ pure $ pat : maybeToList mexp
+expandPatSyn d = pure [d] -- all other definitions are returned unchanged
+
+-- Add the matcher for a pattern synonym to the symbol table.
+-- Return the added identifier; its type maps the synonym's result type to Pn applied to the argument types.
+addPatSynMatch :: Ident -> EType -> T Ident
+addPatSynMatch ps at = do
+ mn <- gets moduleName
+ let im = mkPatSynMatch ps
+ (m, vks, (ats, rt)) =
+ case at of
+ EForall m' vks' (EForall _ _ t) -> (m', vks', getArrows t) -- XXX
+ _ -> impossible
+ (pstycon, _, _) = mkMatchIdents (getSLoc ps) False (length ats) -- XXX
+ conty = EForall m vks $ rt `tArrow` tApps pstycon ats -- XXX
+-- con = ConSyn (qualIdent mn im) (length ats)
+ extValETop im conty (EVar (qualIdent mn im))
+ return im
+
+mkPatSynMatch :: Ident -> Ident -- the matcher's name is the synonym's name with a "%" suffix
+mkPatSynMatch i = addIdentSuffix i "%"
+
+mkMatchIdents :: SLoc -> Bool -> Int -> (Ident, Ident, Ident) -- (Pn/PCn type, Nn/NCn no-match con, Mn/MCn match con)
+mkMatchIdents loc ctx n = (mkBuiltinQ loc ("P" ++ sc ++ sn), mkBuiltin loc ("N" ++ sc ++ sn), mkBuiltin loc ("M" ++ sc ++ sn))
+ where sn = show n
+ sc = if ctx then "C" else ""
+
+-----
+
-- Given a dictionary of a (constraint type), split it up
-- * components of a tupled constraint
-- * superclasses of a constraint
@@ -2979,69 +3047,3 @@
_ -> tcError (getSLoc act) ("not data/newtype " ++ showIdent tname)
-- We want 'instance ctx => cls ty'
deriveNoHdr act lhs cs cls
-
-tcPatSyn :: [EDef] -> T [EDef]
-tcPatSyn ds = do
- let patSyns = [ (i, iks, p, mes) | Pattern (i, iks) p mes <- ds ]
- if null patSyns then
- return ds
- else do
- let ds' = ds ++ [ Fcn i es | (i, _, _, Just es) <- patSyns ]
- ps = M.fromList [ (i, (map idKindIdent iks, p)) | (i, iks, p, _) <- patSyns ]
- tr as (EVar i) | Just (vs, p) <- M.lookup i ps = if length as /= length vs then tr [] $ subst (zip vs as) p
- else errorMessage (getSLoc i) "Bad synonym arity"
- tr as (EApp f a) = tr (as ++ [a]) f
- tr [] p = p
- tr _ _ = undefined
- return $ transformPat (tr []) ds'
-
-class TransformPat a where
- transformPat :: (EPat -> EPat) -> a -> a
-
-instance (TransformPat a) => TransformPat [a] where
- transformPat f es = map (transformPat f) es
-
-instance TransformPat EDef where
- transformPat f (Fcn i eqns) = Fcn i (transformPat f eqns)
- transformPat _ d = d
-
-instance TransformPat Expr where
- transformPat _ e@(EVar _) = e
- transformPat f (EApp e1 e2) = EApp (transformPat f e1) (transformPat f e2)
- transformPat f (ELam es) = ELam (transformPat f es)
- transformPat _ e@(ELit _ _) = e
- transformPat f (ECase e arms) = ECase (transformPat f e) (transformPat f arms)
- transformPat f (ELet bs e) = ELet (transformPat f bs) (transformPat f e)
- transformPat f (EListish l) = EListish (transformPat f l)
- transformPat f (EIf e1 e2 e3) = EIf (transformPat f e1) (transformPat f e2) (transformPat f e3)
- transformPat _ e@(ECon _) = e
- transformPat _ e = impossibleShow e
-
-instance TransformPat Listish where
- transformPat f (LList es) = LList (transformPat f es)
- transformPat f (LCompr e ss) = LCompr (transformPat f e) (transformPat f ss)
- transformPat _ _ = impossible
-
-instance TransformPat ECaseArm where
- transformPat f (p, alts) = (f p, transformPat f alts)
-
-instance TransformPat EStmt where
- transformPat f (SBind p e) = SBind (f p) (transformPat f e)
- transformPat f (SThen e) = SThen (transformPat f e)
- transformPat f (SLet bs) = SLet (transformPat f bs)
-
-instance TransformPat EBind where
- transformPat f (BFcn i es) = BFcn i (transformPat f es)
- transformPat f (BPat p e) = BPat (f p) (transformPat f e)
- transformPat _ b = b
-
-instance TransformPat Eqn where
- transformPat f (Eqn ps as) = Eqn (map f ps) (transformPat f as)
-
-instance TransformPat EAlts where
- transformPat f (EAlts as bs) = EAlts (transformPat f as) (transformPat f bs)
-
-instance TransformPat EAlt where
- transformPat f (ss, e) = (transformPat f ss, transformPat f e)
-
-