shithub: MicroHs

Download patch

ref: c293ac729fb9487933675258e8fff4caf205b372
parent: 8cfb829ca07f64a38eea051229e8d937f4ae3ed6
author: Lennart Augustsson <lennart.augustsson@epicgames.com>
date: Sun Oct 22 11:15:34 EDT 2023

Make instance lookup more efficient.

--- a/src/MicroHs/TypeCheck.hs
+++ b/src/MicroHs/TypeCheck.hs
@@ -28,7 +28,7 @@
   [TypeExport]    -- exported types
   [SynDef]        -- all type synonyms, exported or not
   [ClsDef]        -- all classes
-  [InstDict]      -- all instances
+  [InstDef]       -- all instances
   [ValueExport]   -- exported values (including from T(..))
   a               -- bindings
   --Xderiving (Show)
@@ -47,6 +47,7 @@
 type FixDef = (Ident, Fixity)
 type SynDef = (Ident, EType)
 type ClsDef = (Ident, ClassInfo)
+type InstDef= (Ident, InstInfo)
 
 type ClassInfo = ([IdKind], [EConstraint], [Ident])  -- class tyvars, superclasses, methods
 
@@ -66,9 +67,26 @@
 type FixTable   = M.Map Fixity     -- precedence and associativity of operators
 type AssocTable = M.Map [Ident]    -- maps a type identifier to its associated construcors/selectors/methods
 type ClassTable = M.Map ClassInfo  -- maps a class identifier to its associated information
-type InstTable  = M.Map [InstDict] -- indexed by class name
+type InstTable  = M.Map InstInfo   -- indexed by class name
 type Constraints= [(Ident, EConstraint)]
 
+-- To make type checking fast it is essential to solve constraints fast.
+-- The naive implementation of InstInfo would be [InstDict], but
+-- that is slow.
+-- Instead, the data structure is specialized
+--  * For single parameter type classes for atomic types, e.g., Eq Int
+--    we use the type name (i.e., Int) to index into a map that gives
+--    the dictionary directly.  This map is also used for dictionary arguments
+--    whose type has the same shape, e.g., Eq a.
+--  * NOT IMPLEMENTED: look up by type name of the left-most type
+--  * As a last resort, just look through dictionaries.
+data InstInfo = InstInfo
+       (M.Map Expr)               -- map for direct lookup of atomic types
+       [InstDict]                 -- slow path
+  --Xderiving (Show)
+
+-- This is the dictionary expression, instance variables, instance context,
+-- and class&types.
 type InstDict   = (Expr, [IdKind], [EConstraint], EConstraint)
 
 type Sigma = EType
@@ -98,7 +116,7 @@
 
 -- A hack to force evaluation of errors.
 -- This should be redone to all happen in the T monad.
-tModule :: IdentModule -> [FixDef] -> [TypeExport] -> [SynDef] -> [ClsDef] -> [InstDict] -> [ValueExport] -> [EDef] ->
+tModule :: IdentModule -> [FixDef] -> [TypeExport] -> [SynDef] -> [ClsDef] -> [InstDef] -> [ValueExport] -> [EDef] ->
            TModule [EDef]
 tModule mn fs ts ss cs is vs ds =
 --  trace ("tmodule " ++ showIdent mn ++ ": " ++ show ts) $
@@ -218,7 +236,7 @@
     ces = M.toList ct
 
     -- All instances
-    ies = concat $ M.elems it
+    ies = M.toList it
   in  TModule mn fes tes ses ces ies ves impossible
 
 -- Find all value Entry for names associated with a type.
@@ -273,10 +291,18 @@
     allInsts :: InstTable
     allInsts =
       let
-        insts (_, TModule _ _ _ _ _ ies _ _) = map (\ ie -> (getInstCon ie, [ie])) ies
-      in  M.fromListWith (unionBy eqInstDict) $ concatMap insts mdls
+        insts (_, TModule _ _ _ _ _ ies _ _) = ies
+      in  M.fromListWith mergeInstInfo $ concatMap insts mdls
   in  (allFixes, allTypes, allSyns, allClasses, allInsts, allValues, allAssocs)
 
+mergeInstInfo :: InstInfo -> InstInfo -> InstInfo
+mergeInstInfo (InstInfo m1 l1) (InstInfo m2 l2) =
+  let
+    m = foldr (uncurry $ M.insertWith mrg) m2 (M.toList m1)
+    mrg e1 e2 = if eqExpr e1 e2 then e1 else errorMessage (getSLocExpr e1) $ "Multiple instances: " ++ showSLoc (getSLocExpr e2)
+    l = unionBy eqInstDict l1 l2
+  in  InstInfo m l
+
 eqEntry :: Entry -> Entry -> Bool
 eqEntry x y =
   case x of
@@ -294,12 +320,18 @@
 -- Approximate equality for dictionaries.
 -- The important thing is to avoid exact duplicates in the instance table.
 eqInstDict :: InstDict -> InstDict -> Bool
-eqInstDict (EVar i, _, _, _) (EVar i', _, _, _) = eqIdent i i'
-eqInstDict _                 _                  = False
+eqInstDict (e, _, _, _) (e', _, _, _) = eqExpr e e'
 
 getInstCon :: InstDict -> Ident
 getInstCon (_, _, _, t) = getAppCon t
 
+-- Very partial implementation of Expr equality.
+-- It is only used to compare instances, so this suffices.
+eqExpr :: Expr -> Expr -> Bool
+eqExpr (EVar i) (EVar i') = eqIdent i i'
+eqExpr (EApp f a) (EApp f' a') = eqExpr f f' && eqExpr a a'
+eqExpr _ _ = False
+
 --------------------------
 
 type Typed a = (a, EType)
@@ -421,7 +453,10 @@
 addInstTable :: [InstDict] -> T ()
 addInstTable ics = T.do
   is <- gets instTable
-  putInstTable $ foldr (\ ic -> M.insertWith (unionBy eqInstDict) (getInstCon ic) [ic]) is ics
+  let mkInstInfo :: InstDict -> InstInfo
+      mkInstInfo (e, [], [], EApp _ (EVar i)) = InstInfo (M.singleton i e) []
+      mkInstInfo ic = InstInfo M.empty [ic]
+  putInstTable $ foldr (\ ic -> M.insertWith mergeInstInfo (getInstCon ic) (mkInstInfo ic)) is ics
 
 addConstraint :: String -> (Ident, EConstraint) -> T ()
 addConstraint _msg e@(_d, _ctx) = T.do
@@ -1909,6 +1944,10 @@
 showInstDict :: InstDict -> String
 showInstDict (e, iks, ctx, ct) = showExpr e ++ " :: " ++ showEType (eForall iks $ addConstraints ctx ct)
 
+showInstDef :: InstDef -> String
+showInstDef (cls, InstInfo m ds) = "instDef " ++ showIdent cls ++ ": "
+            ++ showList (showPair showIdent showExpr) (M.toList m) ++ ", " ++ showList showInstDict ds
+
 showConstraint :: (Ident, EConstraint) -> String
 showConstraint (i, t) = showIdent i ++ " :: " ++ showEType t
 
@@ -1930,8 +1969,7 @@
     cs' <- T.mapM (\ (i,t) -> T.do { t' <- derefUVar t; T.return (i,t') }) cs
 --    traceM ("constraints:\n" ++ unlines (map showConstraint cs'))
     it <- gets instTable
-    let instsOf c = fromMaybe [] $ M.lookup c it
---    traceM ("instances:\n" ++ unlines (map showInstDict (concat $ M.elems it)))
+--    traceM ("instances:\n" ++ unlines (map showInstDef (M.toList it)))
     let solve :: [(Ident, EType)] -> [(Ident, EType)] -> [(Ident, Expr)] -> T ([(Ident, EType)], [(Ident, Expr)])
         solve [] uns sol = T.return (uns, sol)
         solve (cns@(di, ct) : cnss) uns sol = T.do
@@ -1941,23 +1979,41 @@
           case getTupleConstr iCls of
             Just _ -> T.do
               goals <- T.mapM (\ c -> T.do { d <- newIdent loc "dict"; T.return (d, c) }) cts
+--              traceM ("split tuple " ++ showList showConstraint goals)
               solve (goals ++ cnss) uns ((di, ETuple (map (EVar . fst) goals)) : sol)
-            Nothing -> T.do
-              let matches = getBestMatches $ findMatches (instsOf iCls) ct
---              traceM ("matches " ++ showList showMatch matches)
-              case matches of
-                []          -> solve cnss (cns :  uns)             sol
-                [(de, ctx)] ->
-                  if null ctx then
-                    solve cnss uns ((di, de) : sol)
-                  else T.do
-                    d <- newIdent (getSLocIdent iCls) "dict"
---                    traceM ("constraint " ++ showIdent di ++ " :: " ++ showEType ct ++ "\n" ++
---                            "   turns into " ++ showIdent d ++ " :: " ++ showEType (tupleConstraints ctx) ++ ", " ++
---                            showIdent di ++ " = " ++ showExpr (EApp de (EVar d)))
-                    solve ((d, tupleConstraints ctx) : cnss) uns ((di, EApp de (EVar d)) : sol)
-                _           -> tcError loc $ "Multiple constraint solutions for: " ++ showEType ct
+            Nothing ->
+              case M.lookup iCls it of
+                Nothing -> T.do
+--                  traceM ("class missing " ++ showIdent iCls)
+                  solve cnss (cns : uns) sol   -- no instances, so no chance
+                Just (InstInfo atomMap insts) ->
+                  case cts of
+                    [EVar i] -> T.do
+--                      traceM ("solveSimple " ++ showIdent i ++ " -> " ++ showMaybe showExpr (M.lookup i atomMap))
+                      solveSimple (M.lookup i atomMap) cns cnss uns sol
+                    _        -> solveGen loc insts cns cnss uns sol
 
+        -- An instance of the form (C T)
+        solveSimple Nothing  cns     cnss uns sol = solve cnss (cns : uns)            sol   -- no instance
+        solveSimple (Just e) (di, _) cnss uns sol = solve cnss        uns  ((di, e) : sol)  -- e is the dictionary expression
+
+        solveGen loc insts cns@(di, ct) cnss uns sol = T.do
+--          traceM ("solveGen " ++ showEType ct)
+          let matches = getBestMatches $ findMatches insts ct
+--          traceM ("matches " ++ showList showMatch matches)
+          case matches of
+            []          -> solve cnss (cns : uns) sol
+            [(de, ctx)] ->
+              if null ctx then
+                solve cnss uns ((di, de) : sol)
+              else T.do
+                d <- newIdent loc "dict"
+--                traceM ("constraint " ++ showIdent di ++ " :: " ++ showEType ct ++ "\n" ++
+--                        "   turns into " ++ showIdent d ++ " :: " ++ showEType (tupleConstraints ctx) ++ ", " ++
+--                        showIdent di ++ " = " ++ showExpr (EApp de (EVar d)))
+                solve ((d, tupleConstraints ctx) : cnss) uns ((di, EApp de (EVar d)) : sol)
+            _           -> tcError loc $ "Multiple constraint solutions for: " ++ showEType ct
+
     (unsolved, solved) <- solve cs' [] []
     putConstraints unsolved
 --    traceM ("solved:\n"   ++ unlines [ showIdent i ++ " = "  ++ showExpr  e | (i, e) <- solved ])
@@ -1976,6 +2032,7 @@
  in --trace ("findMatches: " ++ showList showInstDict ds ++ "; " ++ showEType ct ++ "; " ++ show rrr)
     rrr
   where
+
     -- Change type variable to unique unification variables.
     -- These unification variables will never leak out of findMatches.
     freshSubst iks = zipWith (\ ik j -> (idKindIdent ik, EUVar j)) iks [1000000000 ..] -- make sure the variables are unique
--