@@ -11390,7 +11390,7 @@ namespace ts {
11390
11390
// Flags we want to propagate to the result if they exist in all source symbols
11391
11391
let optionalFlag = isUnion ? SymbolFlags.None : SymbolFlags.Optional;
11392
11392
let syntheticFlag = CheckFlags.SyntheticMethod;
11393
- let checkFlags = 0 ;
11393
+ let checkFlags = CheckFlags.OnlyUnitTypes ;
11394
11394
for (const current of containingType.types) {
11395
11395
const type = getApparentType(current);
11396
11396
if (!(type === errorType || type.flags & TypeFlags.Never)) {
@@ -11478,6 +11478,9 @@ namespace ts {
11478
11478
if (type.flags & TypeFlags.Never) {
11479
11479
checkFlags |= CheckFlags.HasNeverType;
11480
11480
}
11481
+ if (!isUnitType(type)) {
11482
+ checkFlags &= ~CheckFlags.OnlyUnitTypes;
11483
+ }
11481
11484
propTypes.push(type);
11482
11485
}
11483
11486
addRange(propTypes, indexTypes);
@@ -17526,8 +17529,19 @@ namespace ts {
17526
17529
17527
17530
function typeRelatedToSomeType(source: Type, target: UnionOrIntersectionType, reportErrors: boolean): Ternary {
17528
17531
const targetTypes = target.types;
17529
- if (target.flags & TypeFlags.Union && containsType(targetTypes, source)) {
17530
- return Ternary.True;
17532
+ if (target.flags & TypeFlags.Union) {
17533
+ if (containsType(targetTypes, source)) {
17534
+ return Ternary.True;
17535
+ }
17536
+ if (targetTypes.length >= 4) {
17537
+ const match = getMatchingUnionConstituentForType(<UnionType>target, source);
17538
+ if (match) {
17539
+ const related = isRelatedTo(source, match, /*reportErrors*/ false);
17540
+ if (related) {
17541
+ return related;
17542
+ }
17543
+ }
17544
+ }
17531
17545
}
17532
17546
for (const type of targetTypes) {
17533
17547
const related = isRelatedTo(source, type, /*reportErrors*/ false);
@@ -21364,6 +21378,71 @@ namespace ts {
21364
21378
return result;
21365
21379
}
21366
21380
21381
+ function getUnitTypeProperties(unionType: UnionType): Symbol[] {
21382
+ return unionType.unitTypeProperties || (unionType.unitTypeProperties =
21383
+ filter(getPropertiesOfUnionOrIntersectionType(unionType), prop => !!(
21384
+ getCheckFlags(prop) & CheckFlags.SyntheticProperty &&
21385
+ ((<TransientSymbol>prop).checkFlags & CheckFlags.UnitDiscriminant) === CheckFlags.UnitDiscriminant)));
21386
+ }
21387
+
21388
+ function getUnionConstituentKeyForType(unionType: UnionType, type: Type) {
21389
+ const unitTypeProperties = getUnitTypeProperties(unionType);
21390
+ if (unitTypeProperties.length === 0) {
21391
+ return undefined;
21392
+ }
21393
+ const propTypes = [];
21394
+ for (const prop of unitTypeProperties) {
21395
+ const propType = getTypeOfPropertyOfType(type, prop.escapedName);
21396
+ if (!(propType && isUnitType(propType))) {
21397
+ return undefined;
21398
+ }
21399
+ propTypes.push(getRegularTypeOfLiteralType(propType));
21400
+ }
21401
+ return getTypeListId(propTypes);
21402
+ }
21403
+
21404
+ function getUnionConstituentKeyForObjectLiteral(unionType: UnionType, node: ObjectLiteralExpression) {
21405
+ const unitTypeProperties = getUnitTypeProperties(unionType);
21406
+ if (unitTypeProperties.length === 0) {
21407
+ return undefined;
21408
+ }
21409
+ const propTypes = [];
21410
+ for (const prop of unitTypeProperties) {
21411
+ const propNode = find(node.properties, p => p.symbol && p.kind === SyntaxKind.PropertyAssignment &&
21412
+ p.symbol.escapedName === prop.escapedName && isPossiblyDiscriminantValue(p.initializer));
21413
+ const propType = propNode && getTypeOfExpression((<PropertyAssignment>propNode).initializer);
21414
+ if (!(propType && isUnitType(propType))) {
21415
+ return undefined;
21416
+ }
21417
+ propTypes.push(getRegularTypeOfLiteralType(propType));
21418
+ }
21419
+ return getTypeListId(propTypes);
21420
+ }
21421
+
21422
+ function getUnionConstituentMap(unionType: UnionType) {
21423
+ if (!unionType.constituentMap) {
21424
+ const map = unionType.constituentMap = new Map<string, Type | undefined>();
21425
+ for (const t of unionType.types) {
21426
+ const key = getUnionConstituentKeyForType(unionType, t);
21427
+ if (key) {
21428
+ const duplicate = map.has(key);
21429
+ map.set(key, duplicate ? undefined : t);
21430
+ }
21431
+ }
21432
+ }
21433
+ return unionType.constituentMap;
21434
+ }
21435
+
21436
+ function getMatchingUnionConstituentForType(unionType: UnionType, type: Type) {
21437
+ const key = getUnionConstituentKeyForType(unionType, type);
21438
+ return key && getUnionConstituentMap(unionType).get(key);
21439
+ }
21440
+
21441
+ function getMatchingUnionConstituentForObjectLiteral(unionType: UnionType, node: ObjectLiteralExpression) {
21442
+ const key = getUnionConstituentKeyForObjectLiteral(unionType, node);
21443
+ return key && getUnionConstituentMap(unionType).get(key);
21444
+ }
21445
+
21367
21446
function isOrContainsMatchingReference(source: Node, target: Node) {
21368
21447
return isMatchingReference(source, target) || containsMatchingReference(source, target);
21369
21448
}
@@ -24609,7 +24688,7 @@ namespace ts {
24609
24688
}
24610
24689
24611
24690
function discriminateContextualTypeByObjectMembers(node: ObjectLiteralExpression, contextualType: UnionType) {
24612
- return discriminateTypeByDiscriminableItems(contextualType,
24691
+ return getMatchingUnionConstituentForObjectLiteral(contextualType, node) || discriminateTypeByDiscriminableItems(contextualType,
24613
24692
map(
24614
24693
filter(node.properties, p => !!p.symbol && p.kind === SyntaxKind.PropertyAssignment && isPossiblyDiscriminantValue(p.initializer) && isDiscriminantProperty(contextualType, p.symbol.escapedName)),
24615
24694
prop => ([() => checkExpression((prop as PropertyAssignment).initializer), prop.symbol.escapedName] as [() => Type, __String])
@@ -41035,6 +41114,10 @@ namespace ts {
41035
41114
// Keep this up-to-date with the same logic within `getApparentTypeOfContextualType`, since they should behave similarly
41036
41115
function findMatchingDiscriminantType(source: Type, target: Type, isRelatedTo: (source: Type, target: Type) => Ternary, skipPartial?: boolean) {
41037
41116
if (target.flags & TypeFlags.Union && source.flags & (TypeFlags.Intersection | TypeFlags.Object)) {
41117
+ const match = getMatchingUnionConstituentForType(<UnionType>target, source);
41118
+ if (match) {
41119
+ return match;
41120
+ }
41038
41121
const sourceProperties = getPropertiesOfType(source);
41039
41122
if (sourceProperties) {
41040
41123
const sourcePropertiesFiltered = findDiscriminantProperties(sourceProperties, target);
0 commit comments