Skip to content

Commit

Permalink
Fix detection of aggregate functions inside custom functions
Browse files Browse the repository at this point in the history
  • Loading branch information
janedbal authored Jul 12, 2024
1 parent d453424 commit 6339dff
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 273 deletions.
295 changes: 22 additions & 273 deletions src/Type/Doctrine/Query/QueryAggregateFunctionDetectorTreeWalker.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

use Doctrine\ORM\Query;
use Doctrine\ORM\Query\AST;
use function is_string;
use function is_array;

class QueryAggregateFunctionDetectorTreeWalker extends Query\TreeWalkerAdapter
{
Expand All @@ -13,294 +13,38 @@ class QueryAggregateFunctionDetectorTreeWalker extends Query\TreeWalkerAdapter

public function walkSelectStatement(AST\SelectStatement $selectStatement): void
{
$this->doWalkSelectClause($selectStatement->selectClause);
$this->walkNode($selectStatement->selectClause);
}

/**
* @param AST\SelectClause $selectClause
* @param mixed $node
*/
public function doWalkSelectClause($selectClause): void
public function walkNode($node): void
{
foreach ($selectClause->selectExpressions as $selectExpression) {
$this->doWalkSelectExpression($selectExpression);
}
}

/**
* @param AST\SelectExpression $selectExpression
*/
public function doWalkSelectExpression($selectExpression): void
{
$this->doWalkNode($selectExpression->expression);
}

/**
* @param mixed $expr
*/
private function doWalkNode($expr): void
{
if ($expr instanceof AST\AggregateExpression) {
$this->markAggregateFunctionFound();

} elseif ($expr instanceof AST\Functions\FunctionNode) {
if ($this->isAggregateFunction($expr)) {
$this->markAggregateFunctionFound();
}

} elseif ($expr instanceof AST\SimpleArithmeticExpression) {
foreach ($expr->arithmeticTerms as $term) {
$this->doWalkArithmeticTerm($term);
}

} elseif ($expr instanceof AST\ArithmeticTerm) {
$this->doWalkArithmeticTerm($expr);

} elseif ($expr instanceof AST\ArithmeticFactor) {
$this->doWalkArithmeticFactor($expr);

} elseif ($expr instanceof AST\ParenthesisExpression) {
$this->doWalkArithmeticPrimary($expr->expression);

} elseif ($expr instanceof AST\NullIfExpression) {
$this->doWalkNullIfExpression($expr);

} elseif ($expr instanceof AST\CoalesceExpression) {
$this->doWalkCoalesceExpression($expr);

} elseif ($expr instanceof AST\GeneralCaseExpression) {
$this->doWalkGeneralCaseExpression($expr);

} elseif ($expr instanceof AST\SimpleCaseExpression) {
$this->doWalkSimpleCaseExpression($expr);

} elseif ($expr instanceof AST\ArithmeticExpression) {
$this->doWalkArithmeticExpression($expr);

} elseif ($expr instanceof AST\ComparisonExpression) {
$this->doWalkComparisonExpression($expr);

} elseif ($expr instanceof AST\BetweenExpression) {
$this->doWalkBetweenExpression($expr);
}
}

public function doWalkCoalesceExpression(AST\CoalesceExpression $coalesceExpression): void
{
foreach ($coalesceExpression->scalarExpressions as $scalarExpression) {
$this->doWalkSimpleArithmeticExpression($scalarExpression);
}
}

public function doWalkNullIfExpression(AST\NullIfExpression $nullIfExpression): void
{
if (!is_string($nullIfExpression->firstExpression)) {
$this->doWalkSimpleArithmeticExpression($nullIfExpression->firstExpression);
}

if (is_string($nullIfExpression->secondExpression)) {
if (!$node instanceof AST\Node) {
return;
}

$this->doWalkSimpleArithmeticExpression($nullIfExpression->secondExpression);
}

public function doWalkGeneralCaseExpression(AST\GeneralCaseExpression $generalCaseExpression): void
{
foreach ($generalCaseExpression->whenClauses as $whenClause) {
$this->doWalkConditionalExpression($whenClause->caseConditionExpression);
$this->doWalkSimpleArithmeticExpression($whenClause->thenScalarExpression);
}

$this->doWalkSimpleArithmeticExpression($generalCaseExpression->elseScalarExpression);
}

public function doWalkSimpleCaseExpression(AST\SimpleCaseExpression $simpleCaseExpression): void
{
foreach ($simpleCaseExpression->simpleWhenClauses as $simpleWhenClause) {
$this->doWalkSimpleArithmeticExpression($simpleWhenClause->caseScalarExpression);
$this->doWalkSimpleArithmeticExpression($simpleWhenClause->thenScalarExpression);
}

$this->doWalkSimpleArithmeticExpression($simpleCaseExpression->elseScalarExpression);
}

/**
* @param AST\ConditionalExpression|AST\Phase2OptimizableConditional $condExpr
*/
public function doWalkConditionalExpression($condExpr): void
{
if (!$condExpr instanceof AST\ConditionalExpression) {
$this->doWalkConditionalTerm($condExpr); // @phpstan-ignore-line PHPStan do not read @psalm-inheritors of Phase2OptimizableConditional
return;
}

foreach ($condExpr->conditionalTerms as $conditionalTerm) {
$this->doWalkConditionalTerm($conditionalTerm);
}
}

/**
* @param AST\ConditionalTerm|AST\ConditionalPrimary|AST\ConditionalFactor $condTerm
*/
public function doWalkConditionalTerm($condTerm): void
{
if (!$condTerm instanceof AST\ConditionalTerm) {
$this->doWalkConditionalFactor($condTerm);
if ($this->isAggregateFunction($node)) {
$this->markAggregateFunctionFound();
return;
}

foreach ($condTerm->conditionalFactors as $conditionalFactor) {
$this->doWalkConditionalFactor($conditionalFactor);
}
}
foreach ((array) $node as $property) {
if ($property instanceof AST\Node) {
$this->walkNode($property);
}

/**
* @param AST\ConditionalFactor|AST\ConditionalPrimary $factor
*/
public function doWalkConditionalFactor($factor): void
{
if (!$factor instanceof AST\ConditionalFactor) {
$this->doWalkConditionalPrimary($factor);
} else {
$this->doWalkConditionalPrimary($factor->conditionalPrimary);
}
}
if (is_array($property)) {
foreach ($property as $propertyValue) {
$this->walkNode($propertyValue);
}
}

/**
* @param AST\ConditionalPrimary $primary
*/
public function doWalkConditionalPrimary($primary): void
{
if ($primary->isSimpleConditionalExpression()) {
if ($primary->simpleConditionalExpression instanceof AST\ComparisonExpression) {
$this->doWalkComparisonExpression($primary->simpleConditionalExpression);
if ($this->wasAggregateFunctionFound()) {
return;
}
$this->doWalkNode($primary->simpleConditionalExpression);
}

if (!$primary->isConditionalExpression()) {
return;
}

if ($primary->conditionalExpression === null) {
return;
}

$this->doWalkConditionalExpression($primary->conditionalExpression);
}

/**
* @param AST\BetweenExpression $betweenExpr
*/
public function doWalkBetweenExpression($betweenExpr): void
{
$this->doWalkArithmeticExpression($betweenExpr->expression);
$this->doWalkArithmeticExpression($betweenExpr->leftBetweenExpression);
$this->doWalkArithmeticExpression($betweenExpr->rightBetweenExpression);
}

/**
* @param AST\ComparisonExpression $compExpr
*/
public function doWalkComparisonExpression($compExpr): void
{
$leftExpr = $compExpr->leftExpression;
$rightExpr = $compExpr->rightExpression;

if ($leftExpr instanceof AST\Node) {
$this->doWalkNode($leftExpr);
}

if (!($rightExpr instanceof AST\Node)) {
return;
}

$this->doWalkNode($rightExpr);
}

/**
* @param AST\ArithmeticExpression $arithmeticExpr
*/
public function doWalkArithmeticExpression($arithmeticExpr): void
{
if (!$arithmeticExpr->isSimpleArithmeticExpression()) {
return;
}

if ($arithmeticExpr->simpleArithmeticExpression === null) {
return;
}

$this->doWalkSimpleArithmeticExpression($arithmeticExpr->simpleArithmeticExpression);
}

/**
* @param AST\Node|string $simpleArithmeticExpr
*/
public function doWalkSimpleArithmeticExpression($simpleArithmeticExpr): void
{
if (!$simpleArithmeticExpr instanceof AST\SimpleArithmeticExpression) {
$this->doWalkArithmeticTerm($simpleArithmeticExpr);
return;
}

foreach ($simpleArithmeticExpr->arithmeticTerms as $term) {
$this->doWalkArithmeticTerm($term);
}
}

/**
* @param AST\Node|string $term
*/
public function doWalkArithmeticTerm($term): void
{
if (is_string($term)) {
return;
}

if (!$term instanceof AST\ArithmeticTerm) {
$this->doWalkArithmeticFactor($term);
return;
}

foreach ($term->arithmeticFactors as $factor) {
$this->doWalkArithmeticFactor($factor);
}
}

/**
* @param AST\Node|string $factor
*/
public function doWalkArithmeticFactor($factor): void
{
if (is_string($factor)) {
return;
}

if (!$factor instanceof AST\ArithmeticFactor) {
$this->doWalkArithmeticPrimary($factor);
return;
}

$this->doWalkArithmeticPrimary($factor->arithmeticPrimary);
}

/**
* @param AST\Node|string $primary
*/
public function doWalkArithmeticPrimary($primary): void
{
if ($primary instanceof AST\SimpleArithmeticExpression) {
$this->doWalkSimpleArithmeticExpression($primary);
return;
}

if (!($primary instanceof AST\Node)) {
return;
}

$this->doWalkNode($primary);
}

private function isAggregateFunction(AST\Node $node): bool
Expand All @@ -318,4 +62,9 @@ private function markAggregateFunctionFound(): void
$this->_getQuery()->setHint(self::HINT_HAS_AGGREGATE_FUNCTION, true);
}

private function wasAggregateFunctionFound(): bool
{
return $this->_getQuery()->hasHint(self::HINT_HAS_AGGREGATE_FUNCTION);
}

}
4 changes: 4 additions & 0 deletions src/Type/Doctrine/Query/QueryResultTypeWalker.php
Original file line number Diff line number Diff line change
Expand Up @@ -1226,6 +1226,10 @@ public function walkSelectExpression($selectExpression): string
$this->resolveDoctrineType($dbalTypeName, null, TypeCombinator::containsNull($type))
);

if ($this->hasAggregateWithoutGroupBy() && !$expr instanceof AST\Functions\CountFunction) {
$type = TypeCombinator::addNull($type);
}

} else {
// Expressions default to Doctrine's StringType, whose
// convertToPHPValue() is a no-op. So the actual type depends on
Expand Down
33 changes: 33 additions & 0 deletions tests/Platform/QueryResultTypeWalkerFetchTypeMatrixTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -3961,6 +3961,38 @@ public static function provideCases(): iterable
'stringify' => self::STRINGIFY_DEFAULT,
];

yield 'INT_WRAP(MIN(t.col_float)) + no data' => [
'data' => self::dataNone(),
'select' => 'SELECT INT_WRAP(MIN(t.col_float)) FROM %s t',
'mysql' => self::intOrNull(),
'sqlite' => self::intOrNull(),
'pdo_pgsql' => self::intOrNull(),
'pgsql' => self::intOrNull(),
'mssql' => self::intOrNull(),
'mysqlResult' => null,
'sqliteResult' => null,
'pdoPgsqlResult' => null,
'pgsqlResult' => null,
'mssqlResult' => null,
'stringify' => self::STRINGIFY_NONE,
];

yield 'INT_WRAP(MIN(t.col_float))' => [
'data' => self::dataDefault(),
'select' => 'SELECT INT_WRAP(MIN(t.col_float)) FROM %s t',
'mysql' => self::intOrNull(),
'sqlite' => self::intOrNull(),
'pdo_pgsql' => self::intOrNull(),
'pgsql' => self::intOrNull(),
'mssql' => self::intOrNull(),
'mysqlResult' => 0,
'sqliteResult' => 0,
'pdoPgsqlResult' => 0,
'pgsqlResult' => 0,
'mssqlResult' => 0,
'stringify' => self::STRINGIFY_NONE,
];

yield 'COALESCE(t.col_datetime, t.col_datetime)' => [
'data' => self::dataDefault(),
'select' => 'SELECT COALESCE(t.col_datetime, t.col_datetime) FROM %s t',
Expand Down Expand Up @@ -5018,6 +5050,7 @@ private function createOrmConfig(): Configuration
$config->addCustomStringFunction('INT_PI', TypedExpressionIntegerPiFunction::class);
$config->addCustomStringFunction('BOOL_PI', TypedExpressionBooleanPiFunction::class);
$config->addCustomStringFunction('STRING_PI', TypedExpressionStringPiFunction::class);
$config->addCustomStringFunction('INT_WRAP', TypedExpressionIntegerWrapFunction::class);

return $config;
}
Expand Down
Loading

0 comments on commit 6339dff

Please sign in to comment.