261261--- @param source parser.object
262262--- @param fieldName string
263263--- @param literal parser.object
264- --- @return string[] ?
264+ --- @return [ string, boolean] []?
265265local function getNodeTypesWithLiteralField (uri , source , fieldName , literal )
266266 local loc = vm .getVariable (source )
267267 if not loc then
@@ -279,7 +279,9 @@ local function getNodeTypesWithLiteralField(uri, source, fieldName, literal)
279279 for _ , t in ipairs (f .extends .types ) do
280280 if t [1 ] == literal [1 ] then
281281 tys = tys or {}
282- table.insert (tys , set .class [1 ])
282+ -- If the type is in a union (e.g. 'lit' | foo), then the outNode
283+ -- cannot be narrowed.
284+ table.insert (tys , {set .class [1 ], # f .extends .types > 1 })
283285 break
284286 end
285287 end
@@ -682,16 +684,16 @@ local lookIntoChild = util.switch()
682684
683685 -- TODO: handle more types
684686 if tys and # tys == 1 then
685- local ty = tys [1 ]
687+ local ty , tyInUnion = tys [1 ][ 1 ], tys [ 1 ][ 2 ]
686688 topNode = topNode :copy ()
687689 if action .op .type == ' ==' then
688690 topNode :narrow (tracer .uri , ty )
689- if outNode then
691+ if not tyInUnion and outNode then
690692 outNode :remove (ty )
691693 end
692694 else
693695 topNode :remove (ty )
694- if outNode then
696+ if not tyInUnion and outNode then
695697 outNode :narrow (tracer .uri , ty )
696698 end
697699 end
0 commit comments