Skip to content

Commit

Permalink
Improve count() narrowing of constant arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
herndlm committed Dec 5, 2024
1 parent c586014 commit 22ef97b
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 55 deletions.
102 changes: 48 additions & 54 deletions src/Analyser/TypeSpecifier.php
Original file line number Diff line number Diff line change
Expand Up @@ -277,22 +277,20 @@ public function specifyTypesInCondition(
) {
$argType = $scope->getType($expr->right->getArgs()[0]->value);

if ($argType instanceof UnionType) {
$sizeType = null;
if ($leftType instanceof ConstantIntegerType) {
if ($orEqual) {
$sizeType = IntegerRangeType::createAllGreaterThanOrEqualTo($leftType->getValue());
} else {
$sizeType = IntegerRangeType::createAllGreaterThan($leftType->getValue());
}
} elseif ($leftType instanceof IntegerRangeType) {
$sizeType = $leftType;
$sizeType = null;
if ($leftType instanceof ConstantIntegerType) {
if ($orEqual) {
$sizeType = IntegerRangeType::createAllGreaterThanOrEqualTo($leftType->getValue());
} else {
$sizeType = IntegerRangeType::createAllGreaterThan($leftType->getValue());
}
} elseif ($leftType instanceof IntegerRangeType) {
$sizeType = $leftType;
}

$narrowed = $this->narrowUnionByArraySize($expr->right, $argType, $sizeType, $context, $scope, $rootExpr);
if ($narrowed !== null) {
return $narrowed;
}
$specifiedTypes = $this->specifyTypesForCountFuncCall($expr->right, $argType, $sizeType, $context, $scope, $rootExpr);
if ($specifiedTypes !== null) {
$result = $result->unionWith($specifiedTypes);
}

if (
Expand Down Expand Up @@ -1010,66 +1008,52 @@ public function specifyTypesInCondition(
return new SpecifiedTypes([], [], false, [], $rootExpr);
}

private function narrowUnionByArraySize(FuncCall $countFuncCall, UnionType $argType, ?Type $sizeType, TypeSpecifierContext $context, Scope $scope, ?Expr $rootExpr): ?SpecifiedTypes
private function specifyTypesForCountFuncCall(FuncCall $countFuncCall, Type $type, ?Type $sizeType, TypeSpecifierContext $context, Scope $scope, ?Expr $rootExpr): ?SpecifiedTypes
{
if ($sizeType === null) {
return null;
}

if (count($countFuncCall->getArgs()) === 1) {
$isNormalCount = TrinaryLogic::createYes();
} else {
$mode = $scope->getType($countFuncCall->getArgs()[1]->value);
$isNormalCount = (new ConstantIntegerType(COUNT_NORMAL))->isSuperTypeOf($mode)->or($argType->getIterableValueType()->isArray()->negate());
}

if (
$isNormalCount->yes()
&& $argType->isConstantArray()->yes()
$this->isFuncCallWithNormalCount($countFuncCall, $scope)->yes()
&& $type->isConstantArray()->yes()
) {
$result = [];
foreach ($argType->getTypes() as $innerType) {
$arraySize = $innerType->getArraySize();
$resultType = TypeTraverser::map($type, function (Type $type, callable $traverse) use ($sizeType, $context) {
if ($type instanceof UnionType) {
return $traverse($type);
}

$arraySize = $type->getArraySize();
$isSize = $sizeType->isSuperTypeOf($arraySize);
if ($context->truthy()) {
if ($isSize->no()) {
continue;
return new NeverType();
}

$constArray = $this->turnListIntoConstantArray($countFuncCall, $innerType, $sizeType, $scope);
$constArray = $this->turnListIntoConstantArray($type, $sizeType);
if ($constArray !== null) {
$innerType = $constArray;
$type = $constArray;
}
}
if ($context->falsey()) {
if (!$isSize->yes()) {
continue;
return new NeverType();
}
}

$result[] = $innerType;
}
return $type;
});

return $this->create($countFuncCall->getArgs()[0]->value, TypeCombinator::union(...$result), $context, false, $scope, $rootExpr);
return $this->create($countFuncCall->getArgs()[0]->value, $resultType, $context, false, $scope, $rootExpr);
}

return null;
}

private function turnListIntoConstantArray(FuncCall $countFuncCall, Type $type, Type $sizeType, Scope $scope): ?Type
private function turnListIntoConstantArray(Type $type, Type $sizeType): ?Type
{
$argType = $scope->getType($countFuncCall->getArgs()[0]->value);

if (count($countFuncCall->getArgs()) === 1) {
$isNormalCount = TrinaryLogic::createYes();
} else {
$mode = $scope->getType($countFuncCall->getArgs()[1]->value);
$isNormalCount = (new ConstantIntegerType(COUNT_NORMAL))->isSuperTypeOf($mode)->or($argType->getIterableValueType()->isArray()->negate());
}

if (
$isNormalCount->yes()
&& $type->isList()->yes()
$type->isList()->yes()
&& $sizeType instanceof ConstantIntegerType
&& $sizeType->getValue() < ConstantArrayTypeBuilder::ARRAY_COUNT_LIMIT
) {
Expand All @@ -1083,8 +1067,7 @@ private function turnListIntoConstantArray(FuncCall $countFuncCall, Type $type,
}

if (
$isNormalCount->yes()
&& $type->isList()->yes()
$type->isList()->yes()
&& $sizeType instanceof IntegerRangeType
&& $sizeType->getMin() !== null
) {
Expand Down Expand Up @@ -1121,6 +1104,18 @@ private function turnListIntoConstantArray(FuncCall $countFuncCall, Type $type,
return null;
}

private function isFuncCallWithNormalCount(FuncCall $countFuncCall, Scope $scope): TrinaryLogic
{
$argType = $scope->getType($countFuncCall->getArgs()[0]->value);

if (count($countFuncCall->getArgs()) === 1) {
return TrinaryLogic::createYes();
}
$mode = $scope->getType($countFuncCall->getArgs()[1]->value);

return (new ConstantIntegerType(COUNT_NORMAL))->isSuperTypeOf($mode)->or($argType->getIterableValueType()->isArray()->negate());
}

private function specifyTypesForConstantBinaryExpression(
Expr $exprNode,
Type $constantType,
Expand Down Expand Up @@ -2171,11 +2166,9 @@ public function resolveIdentical(Expr\BinaryOp\Identical $expr, Scope $scope, Ty
);
}

if ($argType instanceof UnionType) {
$narrowed = $this->narrowUnionByArraySize($unwrappedLeftExpr, $argType, $rightType, $context, $scope, $rootExpr);
if ($narrowed !== null) {
return $narrowed;
}
$specifiedTypes = $this->specifyTypesForCountFuncCall($unwrappedLeftExpr, $argType, $rightType, $context, $scope, $rootExpr);
if ($specifiedTypes !== null) {
return $specifiedTypes;
}

if ($context->truthy()) {
Expand All @@ -2188,7 +2181,8 @@ public function resolveIdentical(Expr\BinaryOp\Identical $expr, Scope $scope, Ty
}

$funcTypes = $this->create($unwrappedLeftExpr, $rightType, $context, false, $scope, $rootExpr);
$constArray = $this->turnListIntoConstantArray($unwrappedLeftExpr, $argType, $rightType, $scope);
$isNormalCount = $this->isFuncCallWithNormalCount($unwrappedLeftExpr, $scope);
$constArray = $isNormalCount->yes() ? $this->turnListIntoConstantArray($argType, $rightType) : null;
if ($constArray !== null) {
return $funcTypes->unionWith(
$this->create($unwrappedLeftExpr->getArgs()[0]->value, $constArray, $context, false, $scope, $rootExpr),
Expand Down
2 changes: 1 addition & 1 deletion tests/PHPStan/Analyser/nsrt/bug-4700.php
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ function(array $array, int $count): void {
if (isset($array['d'])) $a[] = $array['d'];
if (isset($array['e'])) $a[] = $array['e'];
if (count($a) > $count) {
assertType('int<1, 5>', count($a));
assertType('int<2, 5>', count($a));
assertType('array{0: mixed~null, 1?: mixed~null, 2?: mixed~null, 3?: mixed~null, 4?: mixed~null}', $a);
} else {
assertType('0', count($a));
Expand Down
23 changes: 23 additions & 0 deletions tests/PHPStan/Analyser/nsrt/count-type.php
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,29 @@ public function doFooBar(
}
}

/** @param array{0: string, 1?: string} $arr */
public function doBar(array $arr): void
{
if (count($arr) <= 1) {
assertType('1', count($arr));
return;
}

assertType('2', count($arr));
assertType('array{string, string}', $arr);
}

/** @param array{0: string, 1?: string} $arr */
public function doBaz(array $arr): void
{
if (count($arr) > 1) {
assertType('2', count($arr));
assertType('array{string, string}', $arr);
}

assertType('1|2', count($arr));
}

}

/**
Expand Down

0 comments on commit 22ef97b

Please sign in to comment.