@@ -308,7 +308,9 @@ mkConstructorType :: Datatype TyName Name uni fun (Provenance a) -> VarDecl TyNa
308308-- we don't need to do anything to the declared type
309309-- see note [Abstract data types]
310310-- FIXME: normalize constructors also here
311- mkConstructorType (Datatype _ _ tvs _ _) constr = PIR. mkIterTyForall tvs $ _varDeclType constr
311+ mkConstructorType (Datatype _ _ tvs _ _) constr =
312+ let constrTy = PIR. mkIterTyForall tvs $ _varDeclType constr
313+ in fmap (\ a -> DatatypeComponent ConstructorType a) constrTy
312314
313315-- See note [Scott encoding of datatypes]
314316-- | Make a constructor of a 'Datatype' with the given pattern functor. The constructor argument mostly serves to identify the constructor
@@ -333,34 +335,35 @@ mkConstructor dty d@(Datatype ann _ tvs _ constrs) index = do
333335 pure $ zipWith (VarDecl ann) caseArgNames caseTypes
334336
335337 -- This is inelegant, but it should never fail
336- let constr = constrs !! index
338+ let thisConstr = constrs !! index
337339 let thisCase = PIR. mkVar ann $ casesAndTypes !! index
338340
339341 -- constructor args and their types
340342 argsAndTypes <- do
341343 -- these types appear *outside* the scope of the abstraction for the datatype, so we need to use the concrete datatype here
342344 -- see note [Abstract data types]
343345 -- FIXME: normalize datacons' types also here
344- let argTypes = unveilDatatype (getType dty) d <$> constructorArgTypes constr
346+ let argTypes = unveilDatatype (getType dty) d <$> constructorArgTypes thisConstr
345347 -- we don't have any names for these things, we just had the type, so we call them "arg_i
346348 argNames <- for [0 .. (length argTypes - 1 )] (\ i -> safeFreshName $ " arg_" <> showText i)
347349 pure $ zipWith (VarDecl ann) argNames argTypes
348350
349351
350- pure $
351- -- /\t_1 .. t_n
352- PIR. mkIterTyAbs tvs $
353- -- \arg_1 .. arg_m
354- PIR. mkIterLamAbs argsAndTypes $
355- -- See Note [Recursive datatypes]
356- -- wrap
357- wrap ann dty (fmap (PIR. mkTyVar ann) tvs)$
358- -- forall out
359- TyAbs ann resultType (Type ann) $
360- -- \case_1 .. case_j
361- PIR. mkIterLamAbs casesAndTypes $
362- -- c_i arg_1 .. arg_m
363- PIR. mkIterApp ann thisCase (fmap (PIR. mkVar ann) argsAndTypes)
352+ let constr =
353+ -- /\t_1 .. t_n
354+ PIR. mkIterTyAbs tvs $
355+ -- \arg_1 .. arg_m
356+ PIR. mkIterLamAbs argsAndTypes $
357+ -- See Note [Recursive datatypes]
358+ -- wrap
359+ wrap ann dty (fmap (PIR. mkTyVar ann) tvs)$
360+ -- forall out
361+ TyAbs ann resultType (Type ann) $
362+ -- \case_1 .. case_j
363+ PIR. mkIterLamAbs casesAndTypes $
364+ -- c_i arg_1 .. arg_m
365+ PIR. mkIterApp ann thisCase (fmap (PIR. mkVar ann) argsAndTypes)
366+ pure $ fmap (\ a -> DatatypeComponent Constructor a) constr
364367
365368-- Destructors
366369
@@ -379,15 +382,16 @@ mkDestructor dty (Datatype ann _ tvs _ _) = do
379382 let appliedReal = PIR. mkIterTyApp ann (getType dty) (fmap (PIR. mkTyVar ann) tvs)
380383
381384 xn <- safeFreshName " x"
382- pure $
383- -- /\t_1 .. t_n
384- PIR. mkIterTyAbs tvs $
385- -- \x
386- LamAbs ann xn appliedReal $
387- -- See note [Recursive datatypes]
388- -- unwrap
389- unwrap ann dty $
390- Var ann xn
385+ let destr =
386+ -- /\t_1 .. t_n
387+ PIR. mkIterTyAbs tvs $
388+ -- \x
389+ LamAbs ann xn appliedReal $
390+ -- See note [Recursive datatypes]
391+ -- unwrap
392+ unwrap ann dty $
393+ Var ann xn
394+ pure $ fmap (\ a -> DatatypeComponent Destructor a) destr
391395
392396-- See note [Scott encoding of datatypes]
393397-- | Make the type of a destructor for a 'Datatype'.
@@ -396,8 +400,8 @@ mkDestructor dty (Datatype ann _ tvs _ _) = do
396400-- = forall (a :: *) . (List a) -> (<pattern functor of List>)
397401-- = forall (a :: *) . (List a) -> (forall (out_List :: *) . (out_List -> (a -> List a -> out_List) -> out_List))
398402-- @
399- mkDestructorTy :: ann -> Type TyName uni ann -> Datatype TyName Name uni fun ann -> Type TyName uni ann
400- mkDestructorTy ann pf dt@ (Datatype _ _ tvs _ _) =
403+ mkDestructorTy :: PIRType uni a -> Datatype TyName Name uni fun ( Provenance a ) -> PIRType uni a
404+ mkDestructorTy pf dt@ (Datatype ann _ tvs _ _) =
401405 -- we essentially "unveil" the abstract type, so this
402406 -- is a function from the (instantiated) abstract type
403407 -- to the (unwrapped, i.e. the pattern functor of the) "real" Scott-encoded type that we can use as
@@ -409,9 +413,8 @@ mkDestructorTy ann pf dt@(Datatype _ _ tvs _ _) =
409413 -- t t_1 .. t_n
410414 let appliedAbstract = mkDatatypeValueType ann dt
411415 -- forall t_1 .. t_n
412- in
413- PIR. mkIterTyForall tvs $
414- TyFun ann appliedAbstract pf
416+ destrTy = PIR. mkIterTyForall tvs $ TyFun ann appliedAbstract pf
417+ in fmap (\ a -> DatatypeComponent DestructorType a) destrTy
415418
416419-- The main function
417420
@@ -425,8 +428,8 @@ compileDatatype r body d = do
425428 let
426429 tyVars = [PIR. defVar concreteTyDef]
427430 tys = [getType $ PIR. defVal concreteTyDef]
428- vars = fmap PIR. defVar constrDefs ++ [PIR. defVar destrDef]
429- vals = fmap PIR. defVal constrDefs ++ [PIR. defVal destrDef]
431+ vars = fmap PIR. defVar constrDefs ++ [ PIR. defVar destrDef ]
432+ vals = fmap PIR. defVal constrDefs ++ [ PIR. defVal destrDef ]
430433 -- See note [Abstract data types]
431434 pure $ PIR. mkIterApp p (PIR. mkIterInst p (PIR. mkIterTyAbs tyVars (PIR. mkIterLamAbs vars body)) tys) vals
432435
@@ -443,11 +446,11 @@ compileDatatypeDefs r d@(Datatype ann tn _ destr constrs) = do
443446
444447 constrDefs <- for (zip constrs [0 .. ]) $ \ (c, i) -> do
445448 let constrTy = mkConstructorType d c
446- PIR. Def (VarDecl ann (_varDeclName c) constrTy) <$> mkConstructor (PIR. defVal concreteTyDef) d i
449+ PIR. Def (VarDecl ( DatatypeComponent Constructor ann) (_varDeclName c) constrTy) <$> mkConstructor (PIR. defVal concreteTyDef) d i
447450
448451 destrDef <- do
449- let destTy = mkDestructorTy ann pf d
450- PIR. Def (VarDecl ann destr destTy) <$> mkDestructor (PIR. defVal concreteTyDef) d
452+ let destTy = mkDestructorTy pf d
453+ PIR. Def (VarDecl ( DatatypeComponent Destructor ann) destr destTy) <$> mkDestructor (PIR. defVal concreteTyDef) d
451454
452455 pure (concreteTyDef, constrDefs, destrDef)
453456
0 commit comments