ref: c293ac729fb9487933675258e8fff4caf205b372
parent: 8cfb829ca07f64a38eea051229e8d937f4ae3ed6
author: Lennart Augustsson <lennart.augustsson@epicgames.com>
date: Sun Oct 22 11:15:34 EDT 2023
Make instance lookup more efficient.
--- a/src/MicroHs/TypeCheck.hs
+++ b/src/MicroHs/TypeCheck.hs
@@ -28,7 +28,7 @@
[TypeExport] -- exported types
[SynDef] -- all type synonyms, exported or not
[ClsDef] -- all classes
- [InstDict] -- all instances
+ [InstDef] -- all instances
[ValueExport] -- exported values (including from T(..))
a -- bindings
--Xderiving (Show)
@@ -47,6 +47,7 @@
type FixDef = (Ident, Fixity)
type SynDef = (Ident, EType)
type ClsDef = (Ident, ClassInfo)
+type InstDef= (Ident, InstInfo)
type ClassInfo = ([IdKind], [EConstraint], [Ident]) -- class tyvars, superclasses, methods
@@ -66,9 +67,26 @@
type FixTable = M.Map Fixity -- precedence and associativity of operators
type AssocTable = M.Map [Ident] -- maps a type identifier to its associated construcors/selectors/methods
type ClassTable = M.Map ClassInfo -- maps a class identifier to its associated information
-type InstTable = M.Map [InstDict] -- indexed by class name
+type InstTable = M.Map InstInfo -- indexed by class name
type Constraints= [(Ident, EConstraint)]
+-- To make type checking fast it is essential to solve constraints fast.
+-- The naive implementation of InstInfo would be [InstDict], but
+-- that is slow.
+-- Instead, the data structure is specialized
+-- * For single parameter type classes for atomic types, e.g., Eq Int
+-- we use the type name (i.e., Int) to index into a map that gives
+-- the dictionary directly. This map is also used for dictionary arguments
+-- of a type such as Eq a.
+-- * NOT IMPLEMENTED: look up by type name of the left-most type
+-- * As a last resort, just look through dictionaries.
+data InstInfo = InstInfo
+ (M.Map Expr) -- map for direct lookup of atomic types
+ [InstDict] -- slow path
+ --Xderiving (Show)
+
+-- This is the dictionary expression, instance variables, instance context,
+-- and class&types.
type InstDict = (Expr, [IdKind], [EConstraint], EConstraint)
type Sigma = EType
@@ -98,7 +116,7 @@
-- A hack to force evaluation of errors.
-- This should be redone to all happen in the T monad.
-tModule :: IdentModule -> [FixDef] -> [TypeExport] -> [SynDef] -> [ClsDef] -> [InstDict] -> [ValueExport] -> [EDef] ->
+tModule :: IdentModule -> [FixDef] -> [TypeExport] -> [SynDef] -> [ClsDef] -> [InstDef] -> [ValueExport] -> [EDef] ->
TModule [EDef]
tModule mn fs ts ss cs is vs ds =
-- trace ("tmodule " ++ showIdent mn ++ ": " ++ show ts) $@@ -218,7 +236,7 @@
ces = M.toList ct
-- All instances
- ies = concat $ M.elems it
+ ies = M.toList it
in TModule mn fes tes ses ces ies ves impossible
-- Find all value Entry for names associated with a type.
@@ -273,10 +291,18 @@
allInsts :: InstTable
allInsts =
let
- insts (_, TModule _ _ _ _ _ ies _ _) = map (\ ie -> (getInstCon ie, [ie])) ies
- in M.fromListWith (unionBy eqInstDict) $ concatMap insts mdls
+ insts (_, TModule _ _ _ _ _ ies _ _) = ies
+ in M.fromListWith mergeInstInfo $ concatMap insts mdls
in (allFixes, allTypes, allSyns, allClasses, allInsts, allValues, allAssocs)
+mergeInstInfo :: InstInfo -> InstInfo -> InstInfo
+mergeInstInfo (InstInfo m1 l1) (InstInfo m2 l2) =
+ let
+ m = foldr (uncurry $ M.insertWith mrg) m2 (M.toList m1)
+ mrg e1 e2 = if eqExpr e1 e2 then e1 else errorMessage (getSLocExpr e1) $ "Multiple instances: " ++ showSLoc (getSLocExpr e2)
+ l = unionBy eqInstDict l1 l2
+ in InstInfo m l
+
eqEntry :: Entry -> Entry -> Bool
eqEntry x y =
case x of
@@ -294,12 +320,18 @@
-- Approximate equality for dictionaries.
-- The important thing is to avoid exact duplicates in the instance table.
eqInstDict :: InstDict -> InstDict -> Bool
-eqInstDict (EVar i, _, _, _) (EVar i', _, _, _) = eqIdent i i'
-eqInstDict _ _ = False
+eqInstDict (e, _, _, _) (e', _, _, _) = eqExpr e e'
getInstCon :: InstDict -> Ident
getInstCon (_, _, _, t) = getAppCon t
+-- Very partial implementation of Expr equality.
+-- It is only used to compare instances, so this suffices.
+eqExpr :: Expr -> Expr -> Bool
+eqExpr (EVar i) (EVar i') = eqIdent i i'
+eqExpr (EApp f a) (EApp f' a') = eqExpr f f' && eqExpr a a'
+eqExpr _ _ = False
+
--------------------------
type Typed a = (a, EType)
@@ -421,7 +453,10 @@
addInstTable :: [InstDict] -> T ()
addInstTable ics = T.do
is <- gets instTable
- putInstTable $ foldr (\ ic -> M.insertWith (unionBy eqInstDict) (getInstCon ic) [ic]) is ics
+ let mkInstInfo :: InstDict -> InstInfo
+ mkInstInfo (e, [], [], EApp _ (EVar i)) = InstInfo (M.singleton i e) []
+ mkInstInfo ic = InstInfo M.empty [ic]
+ putInstTable $ foldr (\ ic -> M.insertWith mergeInstInfo (getInstCon ic) (mkInstInfo ic)) is ics
addConstraint :: String -> (Ident, EConstraint) -> T ()
addConstraint _msg e@(_d, _ctx) = T.do
@@ -1909,6 +1944,10 @@
showInstDict :: InstDict -> String
showInstDict (e, iks, ctx, ct) = showExpr e ++ " :: " ++ showEType (eForall iks $ addConstraints ctx ct)
+showInstDef :: InstDef -> String
+showInstDef (cls, InstInfo m ds) = "instDef " ++ showIdent cls ++ ": "
+ ++ showList (showPair showIdent showExpr) (M.toList m) ++ ", " ++ showList showInstDict ds
+
showConstraint :: (Ident, EConstraint) -> String
showConstraint (i, t) = showIdent i ++ " :: " ++ showEType t
@@ -1930,8 +1969,7 @@
cs' <- T.mapM (\ (i,t) -> T.do { t' <- derefUVar t; T.return (i,t') }) cs -- traceM ("constraints:\n" ++ unlines (map showConstraint cs'))it <- gets instTable
- let instsOf c = fromMaybe [] $ M.lookup c it
--- traceM ("instances:\n" ++ unlines (map showInstDict (concat $ M.elems it)))+-- traceM ("instances:\n" ++ unlines (map showInstDef (M.toList it)))let solve :: [(Ident, EType)] -> [(Ident, EType)] -> [(Ident, Expr)] -> T ([(Ident, EType)], [(Ident, Expr)])
solve [] uns sol = T.return (uns, sol)
solve (cns@(di, ct) : cnss) uns sol = T.do
@@ -1941,23 +1979,41 @@
case getTupleConstr iCls of
Just _ -> T.do
goals <- T.mapM (\ c -> T.do { d <- newIdent loc "dict"; T.return (d, c) }) cts+-- traceM ("split tuple " ++ showList showConstraint goals)solve (goals ++ cnss) uns ((di, ETuple (map (EVar . fst) goals)) : sol)
- Nothing -> T.do
- let matches = getBestMatches $ findMatches (instsOf iCls) ct
--- traceM ("matches " ++ showList showMatch matches)- case matches of
- [] -> solve cnss (cns : uns) sol
- [(de, ctx)] ->
- if null ctx then
- solve cnss uns ((di, de) : sol)
- else T.do
- d <- newIdent (getSLocIdent iCls) "dict"
--- traceM ("constraint " ++ showIdent di ++ " :: " ++ showEType ct ++ "\n" ++--- " turns into " ++ showIdent d ++ " :: " ++ showEType (tupleConstraints ctx) ++ ", " ++
--- showIdent di ++ " = " ++ showExpr (EApp de (EVar d)))
- solve ((d, tupleConstraints ctx) : cnss) uns ((di, EApp de (EVar d)) : sol)
- _ -> tcError loc $ "Multiple constraint solutions for: " ++ showEType ct
+ Nothing ->
+ case M.lookup iCls it of
+ Nothing -> T.do
+-- traceM ("class missing " ++ showIdent iCls)+ solve cnss (cns : uns) sol -- no instances, so no chance
+ Just (InstInfo atomMap insts) ->
+ case cts of
+ [EVar i] -> T.do
+-- traceM ("solveSimple " ++ showIdent i ++ " -> " ++ showMaybe showExpr (M.lookup i atomMap))+ solveSimple (M.lookup i atomMap) cns cnss uns sol
+ _ -> solveGen loc insts cns cnss uns sol
+ -- An instance of the form (C T)
+ solveSimple Nothing cns cnss uns sol = solve cnss (cns : uns) sol -- no instance
+ solveSimple (Just e) (di, _) cnss uns sol = solve cnss uns ((di, e) : sol) -- e is the dictionary expression
+
+ solveGen loc insts cns@(di, ct) cnss uns sol = T.do
+-- traceM ("solveGen " ++ showEType ct)+ let matches = getBestMatches $ findMatches insts ct
+-- traceM ("matches " ++ showList showMatch matches)+ case matches of
+ [] -> solve cnss (cns : uns) sol
+ [(de, ctx)] ->
+ if null ctx then
+ solve cnss uns ((di, de) : sol)
+ else T.do
+ d <- newIdent loc "dict"
+-- traceM ("constraint " ++ showIdent di ++ " :: " ++ showEType ct ++ "\n" +++-- " turns into " ++ showIdent d ++ " :: " ++ showEType (tupleConstraints ctx) ++ ", " ++
+-- showIdent di ++ " = " ++ showExpr (EApp de (EVar d)))
+ solve ((d, tupleConstraints ctx) : cnss) uns ((di, EApp de (EVar d)) : sol)
+ _ -> tcError loc $ "Multiple constraint solutions for: " ++ showEType ct
+
(unsolved, solved) <- solve cs' [] []
putConstraints unsolved
-- traceM ("solved:\n" ++ unlines [ showIdent i ++ " = " ++ showExpr e | (i, e) <- solved ])@@ -1976,6 +2032,7 @@
in --trace ("findMatches: " ++ showList showInstDict ds ++ "; " ++ showEType ct ++ "; " ++ show rrr)rrr
where
+
-- Change type variable to unique unification variables.
-- These unification variables will never leak out of findMatches.
freshSubst iks = zipWith (\ ik j -> (idKindIdent ik, EUVar j)) iks [1000000000 ..] -- make sure the variables are unique
--
⑨