ref: f1136c9acb3b74f87a3bef4ca931564394a917f7
parent: dc708ee4bd2c5e76a60d4cecff26bac23318c04a
author: Lennart Augustsson <lennart@augustsson.net>
date: Mon Dec 23 06:01:14 EST 2024
More patsyn
--- a/src/MicroHs/Expr.hs
+++ b/src/MicroHs/Expr.hs
@@ -747,8 +747,8 @@
ELazy False p -> text "!" <> ppE p
EOr ps -> parens $ hsep (punctuate (text ";") (map ppE ps))
EUVar i -> text ("_a" ++ show i)
- EQVar e _ -> ppE e
- ECon c -> ppCon c
+ EQVar e t -> parens $ ppE e <> text "::" <> ppE t
+ ECon c -> text "***" <> ppCon c
EForall _ iks e -> ppForall iks <+> ppEType e
ppApp :: [Expr] -> Expr -> Doc
--- a/src/MicroHs/TCMonad.hs
+++ b/src/MicroHs/TCMonad.hs
@@ -97,7 +97,6 @@
dataTable :: DataTable, -- data/newtype definitions
valueTable :: ValueTable, -- value symbol table
assocTable :: AssocTable, -- values associated with a type, indexed by QIdent
- patSynTable :: M.Map ([Ident], EPat), -- pattern synonyms
uvarSubst :: (IM.IntMap EType), -- mapping from unique id to type
tcMode :: TCMode, -- pattern, value, or type
classTable :: ClassTable, -- class info, indexed by QIdent
@@ -254,3 +253,7 @@
tImplies :: EType -> EType -> EType
tImplies a r = tApp (tApp (tConI builtinLoc "Primitives.=>") a) r
+
+-- | Like 'tImplies', but when the context is the 0-tuple (the empty
+-- context) no implication is built and the type is returned as is.
+etImplies :: EType -> EType -> EType
+etImplies ctx ty =
+  case ctx of
+    EVar c | c == tupleConstr noSLoc 0 -> ty
+    _ -> tImplies ctx ty
--- a/src/MicroHs/TypeCheck.hs
+++ b/src/MicroHs/TypeCheck.hs
@@ -389,8 +389,7 @@
classTable = gClassTable globs,
ctxTables = (gInstInfo globs, [], [], []),
constraints = [],
- defaults = dflts,
- patSynTable = M.empty
+ defaults = dflts
}
mergeDefaults :: Defaults -> Defaults -> Defaults
@@ -1090,7 +1089,7 @@
Type lhs t -> withLHS lhs $ \ lhs' -> first (Type lhs') <$> tInferTypeT t
Class ctx lhs fds ms -> withLHS lhs $ \ lhs' -> flip (,) kConstraint <$> (Class <$> tcCtx ctx <*> return lhs' <*> mapM tcFD fds <*> mapM tcMethod ms)
Sign is t -> Sign is <$> tCheckTypeTImpl kType t
- PatternSign is t -> PatternSign is <$> tCheckTypeTImpl kType t
+ PatternSign is t -> PatternSign is <$> tCheckTypeTImpl kType t
ForImp ie i t -> ForImp ie i <$> tCheckTypeTImpl kType t
Instance ct m -> Instance <$> tCheckTypeTImpl kConstraint ct <*> return m
Default mc ts -> Default (Just c) <$> mapM (tcDefault c) ts
@@ -1290,6 +1289,7 @@
node _ = undefined
tcSCC (AcyclicSCC d) = tInferDefs [d]
tcSCC (CyclicSCC ds) = tInferDefs ds
+ --traceM $ "tcDefsValue: unsigned=" ++ show unsigned
-- type infer and enter each SCC in the symbol table
-- return inferred Sign
signDefs <- mapM tcSCC sccs
@@ -1307,6 +1307,7 @@
-- Infer a type for a definition
tInferDefs :: [EDef] -> T [EDef]
tInferDefs fcns = do
+-- traceM "tInferDefs"
tcReset
-- Invent type variables for the definitions
let idOf (Fcn i _) = i
@@ -1337,8 +1338,9 @@
t''' <- quantify vs' t''
--tcTrace $ "tInferDefs: " ++ showIdent i ++ " :: " ++ showEType t'''
if isPatSyn i then do
- addPatSyn t''' i
- return $ PatternSign [i] t'''
+ t'''' <- canonPatSynType t'''
+ addPatSyn t'''' i
+ return $ PatternSign [i] t''''
else do
extValQTop i t'''
return $ Sign [i] t'''
@@ -1383,24 +1385,21 @@
addConFields tycon con
ForImp _ i t -> extValQTop i t
Class ctx (i, vks) fds ms -> addValueClass ctx i vks fds ms
- PatternSign is at -> mapM_ (addPatSyn at) is
+ PatternSign is at -> do
+ at' <- canonPatSynType at
+ mapM_ (addPatSyn at') is
_ -> return ()
-- Add a pattern synonym to the symbol table.
+-- The type `at` must be in the canonical form produced by canonPatSynType,
+-- since it is taken apart with splitPatSynType.  The synonym is entered as a
+-- ConSyn holding its arity and its matcher (the matcher variable and its type).
addPatSyn :: EType -> Ident -> T ()
addPatSyn at i = do
- at' <- expandSyn at
mn <- gets moduleName
- let (t', n) =
- -- Patterns must have two universals.
- -- XXX Add double contexts
- case at' of
- EForall b vs t -> (EForall b vs $ EForall False [] t, arity t)
- _ -> (EForall False [] $ EForall False [] at, arity at)
- arity = length . fst . getArrows -- XXX
+ let (_, _, _, _, t) = splitPatSynType at
+ -- The arity is the number of argument types of the synonym.
+ n = length $ fst $ getArrows t
qi = qualIdent mn i
- mtch = (EVar $ mkPatSynMatch qi, mkPatSynType t')
- extValETop i t' $ ECon $ ConSyn qi n mtch
+ qip = mkPatSynMatch qi
+ mtch = (EVar qip, mkPatSynMatchType qip at)
+ extValETop i at $ ECon $ ConSyn qi n mtch
-- XXX FunDep
addValueClass :: [EConstraint] -> Ident -> [IdKind] -> [FunDep] -> [EBind] -> T ()
@@ -1437,6 +1436,9 @@
Fcn i eqns -> do
(_, t) <- tLookup "type signature" i
t' <- expandSyn t
+ when (isConIdent i) $ do
+ tcTrace $ "tcDefValue: patsyn\n" ++ show i ++ " :: " ++ show t'
+ tcTrace $ "tcDefValue:\n" ++ showEDefs [adef]
-- tcTrace $ "tcDefValue: ------- start " ++ showIdent i
-- tcTrace $ "tcDefValue: " ++ showIdent i ++ " :: " ++ showExpr t'
-- tcTrace $ "tcDefValue: " ++ showEDefs [adef]
@@ -1452,26 +1454,35 @@
mn <- gets moduleName
t' <- expandSyn t
return (ForImp ie (qualIdent mn i) t')
- Pattern (i, _) _ _ -> do
+ Pattern _ _ _ -> impossible
+ {- Pattern (i, _) _ _ -> do
(_, t) <- tLookup "pattern type signature" i
t' <- expandSyn t
- tcPattern adef t'
+ tcPattern adef t'-}
_ -> return adef
+-- Type check a pattern synonym definition against its type `at`, which must
+-- be in the canonical form produced by canonPatSynType (it is taken apart
+-- with splitPatSynType).  Each argument variable in vks is bound to the
+-- corresponding argument type, and the pattern is checked against the result
+-- type with the required context available as a dictionary.
+-- This is only used during inference.
+-- When doing type checking the actual Pattern definition will have been
+-- removed by expandPatSyn.
tcPattern :: EDef -> EType -> T EDef
tcPattern (Pattern (ip, vks) p me) at = do
-- traceM ("Pattern " ++ show (ip, vks, p, me, at))
- let step [] t = tcPat (Check t) p
+ let (_vks1, ctx1, _vks2, _ctx2, ty) = splitPatSynType at
+ -- NOTE(review): the provided context (_ctx2) is not used here yet.
+ step [] t = do
+ d <- newADictIdent (getSLoc ip)
+-- traceM $ "tcPattern: add ctx " ++ show ctx1
+ withDict d ctx1 $ do
+ r <- tcPat (Check t) p
+ _ <- solveConstraints
+ checkConstraints
+ pure r
step (ik:iks) t = do
(ti, tr) <- unArrow (getSLoc ik) t
withExtVal (idKindIdent ik) ti $ step iks tr
- dropForall (EForall _ _ t) = dropForall t
- dropForall t = t
- (_, _, p') <- step vks (dropForall at) -- XXX
- me' <- case me of Nothing -> pure Nothing; Just e -> Just <$> tcEqns True at e
+ (_, _, p') <- step vks ty
+ me' <- case me of Nothing -> pure Nothing; Just e -> Just <$> do e' <- tcEqns True at e; checkConstraints; pure e'
mn <- gets moduleName
- checkConstraints
-- traceM ("Pattern after " ++ show (qualIdent mn ip, vks, p', me'))
return $ Pattern (qualIdent mn ip, vks) p' me'
tcPattern _ _ = error "tcPattern"
@@ -2169,6 +2180,7 @@
unify loc t ext
return ([], [], p)
| otherwise -> tcPatAp mt [] ae
+ EQVar _ _ -> tcPatAp mt [] ae
EApp f _
| isNeg f -> lit -- if it's (negate e) it must have been a negative literal
| otherwise -> tcPatAp mt [] ae
@@ -2238,7 +2250,7 @@
true = eTrue loc
tcPat mt $ EViewPat orFun true
- _ -> error $ "tcPat: " ++ show (getSLoc ae) ++ " " ++ show ae
+ _ -> error $ "tcPat: not handled " ++ show (getSLoc ae) ++ " " ++ show ae
-- The expected type is for (eApps afn (reverse args))
tcPatAp :: HasCallStack =>
@@ -2245,31 +2257,45 @@
Expected -> [EPat] -> EPat -> T EPatRet
--tcPatAp mt args afn | trace ("tcPatAp: " ++ show (mt, args, afn)) False = undefined
tcPatAp mt args afn =
- case afn of
- EVar i | isConIdent i -> do
- let loc = getSLoc i
- nargs = length args
- checkArity ary =
- if nargs < ary then
- tcError loc "too few arguments"
- else if nargs > ary then
- tcError loc "too many arguments"
- else
- return ()
- (con, xpt) <- tLookupV i
- case con of
- ECon (ConSyn _ n (e, t)) -> do
+ case afn of
+ EVar i | isConIdent i -> do
+ (con, xpt) <- tLookupV i
+ tcPatApCon mt args con xpt
+
+ EQVar con xpt -> tcPatApCon mt args con xpt
+
+ EApp f a -> tcPatAp mt (a:args) f
+
+ EParen e -> tcPatAp mt args e
+
+ _ -> tcError (getSLoc afn) ("Bad pattern " ++ show afn)
+
+tcPatApCon :: Expected -> [EPat] -> EPat -> EType -> T EPatRet
+tcPatApCon mt args con xpt = do
+ let loc = getSLoc con
+ nargs = length args
+ checkArity ary =
+ if nargs < ary then
+ tcError loc "too few arguments"
+ else if nargs > ary then
+ tcError loc "too many arguments"
+ else
+ return ()
+ case con of
+ -- Pattern synonym
+ ECon (ConSyn qi n (e, t)) -> do
checkArity n
- let (_, _, pcon) = mkMatchIdents loc False n
- vp = EViewPat (EQVar e t) (eApps (EVar pcon) args)
+ let (yes, _) = mkMatchDataTypeConstr (mkPatSynMatch qi) xpt
+ vp = EViewPat (EQVar e t) (eApps yes args)
--traceM ("patsyn " ++ show vp)
tcPat mt vp
- _ -> do
--- tcTrace (show xpt)
+
+ -- Regular constructor
+ _ -> do
case xpt of
-- Sanity check
EForall _ _ (EForall _ _ _) -> return ()
- _ -> impossibleShow i
+ _ -> impossibleShow con
EForall _ avs apt <- tInst' xpt
(sks, spt) <- shallowSkolemise avs apt
@@ -2298,13 +2324,6 @@
Infer r -> do { tSetRefType loc r tt; return pr }
return (skr, dr, pp)
- EApp f a -> tcPatAp mt (a:args) f
-
- EParen e -> tcPatAp mt args e
-
- _ -> tcError (getSLoc afn) ("Bad pattern " ++ show afn)
-
-
eTrue :: SLoc -> Expr
eTrue l = EVar $ mkBuiltin l "True"
@@ -2532,46 +2551,142 @@
-----
+--
+-- Pattern synonyms look like
+-- pattern P :: forall a1...an . ctxr => forall b1...bm . ctxp => t1 -> ... -> ti -> t
+-- pattern P x1...xi <- p where P = e
+-- (this type is the canonicalized type, generated by canonPatSynType).
+-- The synonym is translated into a builder, a matcher and a type.
+-- Each synonym use is replaced by a simple view pattern.
+--
+-- The builder is simple. It gets the same name and type as the pattern synonym,
+-- and its definition is the expression e given in the synonym's where clause.
+-- P :: forall a1...an . ctxr => forall b1...bm . ctxp => t1 -> ... -> ti -> t
+-- P = e
+--
+-- The matcher needs to account for possible existentials so we get a data type
+-- for the match result that can have existentials.
+-- data P%T a1...an = forall b1...bm . ctxp => M t1 ... ti
+-- | N
+--
+-- The matcher itself has the required part of the synonym type, whereas
+-- the provided part is in the data type. The matcher simply matches on the given pattern.
+-- P% :: forall a1...an . ctxr => t -> P%T a1...an
+-- P% p = M x1...xi
+-- P% _ = N
+-- So when the synonym P matches, the matcher P% returns the M constructor
+-- of the P%T type, and the N constructor when there is no match.
+--
+-- Each use of the pattern synonym
+-- P p1...pi
+-- is replaced by
+-- (P% -> M p1...pi)
+--
+-- The data type, P%T, is not entered into any symbol tables.
+-- The matcher, P%, is in the symbol table, but is not part of the exported symbols.
+-- The transformed expression simply carries enough information about the types
+-- (using EQVar). The exported ECon for P has this information.
+--
+
+-- | The empty constraint context, represented by the 0-tuple constraint.
+emptyCtx :: EConstraint
+emptyCtx = EVar (tupleConstr noSLoc 0)
+
+-- | Test whether a constraint is the empty context (see 'emptyCtx').
+isEmptyCtx :: EConstraint -> Bool
+isEmptyCtx c =
+  case c of
+    EVar i -> i == tupleConstr noSLoc 0
+    _ -> False
+
-- Expand a pattern synonym into the builder and matcher definitions.
+-- Also generates the match-result data type P%T used by the matcher.
+-- The actual pattern definition is removed; the result holds the builder
+-- (if there is one), the matcher, and the data type.
expandPatSyn :: EDef -> T [EDef]
expandPatSyn (Pattern (i, vks) p me) = do
(_, t) <- tLookup "type signature" i
- im <- addPatSynMatch i t
- let (_, no, yes) = mkMatchIdents (getSLoc i) False (length vks) -- XXX
- mexp = fmap (\ e -> Fcn i e) me
+ (im, qim) <- addPatSynMatch i t
+ let (vks1, _ctx1, vks2, ctx2, ty) = splitPatSynType t
+ (ats, _) = getArrows ty
+ (yes, no) = mkMatchDataTypeConstr qim t
+ mexp = fmap (Fcn i) me
pat = Fcn im [ eEqn [p] match
- , eEqn [eDummy] nomatch]
- match = eApps (EVar yes) (map (EVar . idKindIdent) vks)
- nomatch = EVar no
- pure $ pat : maybeToList mexp
+ , eEqn [eDummy] no]
+ match = eApps yes (map (EVar . idKindIdent) vks)
+ -- data P%T vks1 = forall vks2 . ctx2 => M t1 ... ti | N
+ ddata = Data lhs [cm, cn] []
+ where lhs = (mkMatchDataTypeName qim, vks1)
+ cm = Constr vks2 (if isEmptyCtx ctx2 then [] else [ctx2]) mi (Left [ (False, a) | a <- ats ])
+ cn = Constr [] [] ni (Left [])
+ -- The constructor names chosen by mkMatchDataTypeConstr.
+ EQVar (ECon (ConData _ mi _)) _ = yes
+ EQVar (ECon (ConData _ ni _)) _ = no
+ pure $ maybeToList mexp ++ [pat, ddata]
expandPatSyn d = pure [d]
-- Add the matcher for a pattern synonym to the symbol table.
--- Return the added identifier.
+-- The matcher is entered with the type computed by mkPatSynMatchType.
+-- Return the matcher identifier, both unqualified and qualified with the
+-- current module name.
-addPatSynMatch :: Ident -> EType -> T Ident
-addPatSynMatch ps at = do
+addPatSynMatch :: Ident -> EType -> T (Ident, Ident)
+addPatSynMatch i at = do
mn <- gets moduleName
- let im = mkPatSynMatch ps
- extValETop im (mkPatSynType at) (EVar (qualIdent mn im))
- return im
+ let ip = mkPatSynMatch i
+ qip = qualIdent mn ip
+ extValETop ip (mkPatSynMatchType qip at) (EVar qip)
+ return (ip, qip)
-mkPatSynType :: EType -> EType
-mkPatSynType at =
- let (m, vks, (ats, rt)) =
- case at of
- EForall m' vks' (EForall _ _ t) -> (m', vks', getArrows t) -- XXX
- _ -> impossibleShow at
- (pstycon, _, _) = mkMatchIdents (getSLoc at) False (length ats) -- XXX
- in EForall m vks $ rt `tArrow` tApps pstycon ats -- XXX
+-- | Type of the matcher for a synonym with canonical type @at@, where @qip@
+-- is the qualified matcher name (P%):
+--   forall a1...an . ctxr => t -> P%T a1...an
+-- with @t@ the result type of the synonym.
+mkPatSynMatchType :: Ident -> EType -> EType
+mkPatSynMatchType qip at =
+  eForall univs $ etImplies reqCtx $ tArrow scrutTy matchTy
+  where
+    (univs, reqCtx, _, _, synTy) = splitPatSynType at
+    scrutTy = snd (getArrows synTy)
+    matchTy = tApps (mkMatchDataTypeName qip) [ EVar (idKindIdent ik) | ik <- univs ]
+-- Given the qualified name of a pattern synonym's matcher (P%) and the
+-- synonym's canonical type, generate the match constructor (P%M) and the
+-- no-match constructor (P%N) of the match-result type P%T.
+-- Each constructor is returned as an EQVar annotated with its type.
+mkMatchDataTypeConstr :: HasCallStack => Ident -> EType -> (Expr, Expr)
+mkMatchDataTypeConstr qi at =
+  let (vks1, _ctx1, vks2, ctx2, ty) = splitPatSynType at
+      (ats, _rt) = getArrows ty
+      -- M has one field per synonym argument, plus one more when the
+      -- provided context is non-empty (presumably its dictionary).
+      n = length ats
+      mi = addIdentSuffix qi "M"
+      ni = addIdentSuffix qi "N"
+      cti = [ (mi, n + (if isEmptyCtx ctx2 then 0 else 1)), (ni, 0) ]
+      conm = ConData cti mi []
+      conn = ConData cti ni []
+      tycon = mkMatchDataTypeName qi
+      tr = tApps tycon $ map (EVar . idKindIdent) vks1
+      -- N :: forall a1...an . P%T a1...an
+      tn = EForall True vks1 $ EForall True [] tr
+      -- M :: forall a1...an . forall b1...bm . ctxp => t1 -> ... -> ti -> P%T a1...an
+      tm = EForall True vks1 $ EForall True vks2 $ etImplies ctx2 $ foldr tArrow tr ats
+  in  (EQVar (ECon conm) tm, EQVar (ECon conn) tn)
+
mkPatSynMatch :: Ident -> Ident
mkPatSynMatch i = addIdentSuffix i "%"
-mkMatchIdents :: SLoc -> Bool -> Int -> (Ident, Ident, Ident)
-mkMatchIdents loc ctx n = (mkBuiltinQ loc ("P" ++ sc ++ sn), mkBuiltin loc ("N" ++ sc ++ sn), mkBuiltin loc ("M" ++ sc ++ sn))
- where sn = show n
- sc = if ctx then "C" else ""
+-- | Name of the match-result data type for a matcher name (P% gives P%T).
+mkMatchDataTypeName :: Ident -> Ident
+mkMatchDataTypeName = flip addIdentSuffix "T"
+-- A pattern synonym always has a type of the form
+--   forall vs1 . ctx1 => forall vs2 . ctx2 => ty
+-- where ctx1 is the required and ctx2 the provided context.
+-- canonPatSynType expands type synonyms and brings a signature into this form:
+--  * if the (possibly absent) required context is followed by an explicit
+--    inner forall, that forall is kept as given (a missing provided context
+--    becomes ()); in particular an already canonical type keeps its shape,
+--    so canonicalising twice is harmless;
+--  * otherwise missing contexts become (), and the quantified variables that
+--    occur in the provided context are moved to the inner forall.
+canonPatSynType :: EType -> T EType
+canonPatSynType at = do
+  at' <- expandSyn at
+  let (vks, t0) =
+        case at' of
+          EForall _ xs t -> (xs, t)
+          t -> ([], t)
+      -- Split off an optional context; a missing one becomes the empty context.
+      splitCtx t =
+        case getImplies t of
+          Nothing -> (emptyCtx, t)
+          Just ct -> ct
+      mkCanon us rc es pc body =
+        EForall True us $ tImplies rc $ EForall True es $ tImplies pc body
+      (ctx1, t1) = splitCtx t0
+  pure $
+    case t1 of
+      EForall _ vks2 t2 ->
+        let (ctx2, ty) = splitCtx t2
+        in  mkCanon vks ctx1 vks2 ctx2 ty
+      _ ->
+        let (ctx2, ty) = splitCtx t1
+            vs2 = freeTyVars [ctx2]
+            (vks2, vks1) = partition ((`elem` vs2) . idKindIdent) vks
+        in  mkCanon vks1 ctx1 vks2 ctx2 ty
+
+-- | Split a canonical pattern synonym type (see canonPatSynType)
+--   forall vs1 . ctx1 => forall vs2 . ctx2 => ty
+-- into (vs1, ctx1, vs2, ctx2, ty).  Any other shape is an internal error.
+splitPatSynType :: EType -> ([IdKind], EConstraint, [IdKind], EConstraint, EType)
+splitPatSynType at =
+  case at of
+    EForall _ univs body
+      | Just (reqCtx, EForall _ exis rest) <- getImplies body
+      , Just (provCtx, ty) <- getImplies rest
+      -> (univs, reqCtx, exis, provCtx, ty)
+    _ -> impossibleShow at
+
-----
-- Given a dictionary of a (constraint type), split it up
@@ -2681,6 +2796,7 @@
Nothing -> [c]
Just _ -> concatMap flatten ts
return $ loop [] ais
+
{-
showInstInfo :: InstInfo -> String
--- a/tests/PatSynE.hs
+++ b/tests/PatSynE.hs
@@ -17,3 +17,6 @@
dup [x,x'] | x==x' = Just x
dup _ = Nothing
-}
+
+pattern One :: forall a . (Eq a) => a
+pattern One <- 1