Skip to content

Commit

Permalink
feat: allow STI child entities to have non-nullable relationships
Browse files Browse the repository at this point in the history
  • Loading branch information
simPod committed Oct 11, 2024
1 parent b447742 commit 2afdc8e
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 14 deletions.
17 changes: 12 additions & 5 deletions src/Tools/SchemaTool.php
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,7 @@ private function gatherRelationsSql(
$this->gatherRelationJoinColumns(
$mapping->joinColumns,
$table,
$class,
$foreignClass,
$mapping,
$primaryKeyColumns,
Expand All @@ -561,6 +562,7 @@ private function gatherRelationsSql(
$joinTable->joinColumns,
$theJoinTable,
$class,
$class,
$mapping,
$primaryKeyColumns,
$addedFks,
Expand All @@ -571,6 +573,7 @@ private function gatherRelationsSql(
$this->gatherRelationJoinColumns(
$joinTable->inverseJoinColumns,
$theJoinTable,
$class,
$foreignClass,
$mapping,
$primaryKeyColumns,
Expand Down Expand Up @@ -637,6 +640,7 @@ private function gatherRelationJoinColumns(
array $joinColumns,
Table $theJoinTable,
ClassMetadata $class,
ClassMetadata $foreignClass,
AssociationMapping $mapping,
array &$primaryKeyColumns,
array &$addedFks,
Expand All @@ -645,12 +649,12 @@ private function gatherRelationJoinColumns(
$localColumns = [];
$foreignColumns = [];
$fkOptions = [];
$foreignTableName = $this->quoteStrategy->getTableName($class, $this->platform);
$foreignTableName = $this->quoteStrategy->getTableName($foreignClass, $this->platform);
$uniqueConstraints = [];

foreach ($joinColumns as $joinColumn) {
[$definingClass, $referencedFieldName] = $this->getDefiningClass(
$class,
$foreignClass,
$joinColumn->referencedColumnName,
);

Expand All @@ -662,10 +666,10 @@ private function gatherRelationJoinColumns(
);
}

$quotedColumnName = $this->quoteStrategy->getJoinColumnName($joinColumn, $class, $this->platform);
$quotedColumnName = $this->quoteStrategy->getJoinColumnName($joinColumn, $foreignClass, $this->platform);
$quotedRefColumnName = $this->quoteStrategy->getReferencedJoinColumnName(
$joinColumn,
$class,
$foreignClass,
$this->platform,
);

Expand All @@ -688,7 +692,10 @@ private function gatherRelationJoinColumns(
$columnOptions['columnDefinition'] = $fieldMapping->columnDefinition;
}

if (isset($joinColumn->nullable)) {
if (
isset($joinColumn->nullable)
&& ! ($class->isInheritanceTypeSingleTable() && $class->parentClasses)
) {
$columnOptions['notnull'] = ! $joinColumn->nullable;
}

Expand Down
25 changes: 25 additions & 0 deletions tests/Tests/Models/Company/CompanyCarContract.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
<?php

declare(strict_types=1);

namespace Doctrine\Tests\Models\Company;

use Doctrine\ORM\Mapping as ORM;

#[ORM\Entity]
class CompanyCarContract extends CompanyContract
{
#[ORM\ManyToOne(targetEntity: CompanyCar::class)]
#[ORM\JoinColumn(nullable: false, onDelete: 'CASCADE')]
private CompanyCar $companyCar;

public function calculatePrice(): int
{
return 0;
}

public function setCompanyCar(CompanyCar $companyCar): void
{
$this->companyCar = $companyCar;
}
}
2 changes: 1 addition & 1 deletion tests/Tests/Models/Company/CompanyContract.php
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#[ORM\Table(name: 'company_contracts')]
#[ORM\InheritanceType('SINGLE_TABLE')]
#[ORM\DiscriminatorColumn(name: 'discr', type: 'string')]
#[ORM\DiscriminatorMap(['fix' => 'CompanyFixContract', 'flexible' => 'CompanyFlexContract', 'flexultra' => 'CompanyFlexUltraContract'])]
#[ORM\DiscriminatorMap(['fix' => 'CompanyFixContract', 'flexible' => 'CompanyFlexContract', 'flexultra' => 'CompanyFlexUltraContract', 'car' => 'CompanyCarContract'])]
#[ORM\EntityListeners(['CompanyContractListener'])]
abstract class CompanyContract
{
Expand Down
21 changes: 21 additions & 0 deletions tests/Tests/ORM/Functional/SingleTableInheritanceTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
use Doctrine\Common\Collections\Criteria;
use Doctrine\ORM\Mapping\ClassMetadata;
use Doctrine\ORM\Persisters\MatchingAssociationFieldRequiresObject;
use Doctrine\Tests\Models\Company\CompanyCar;
use Doctrine\Tests\Models\Company\CompanyCarContract;
use Doctrine\Tests\Models\Company\CompanyContract;
use Doctrine\Tests\Models\Company\CompanyEmployee;
use Doctrine\Tests\Models\Company\CompanyFixContract;
Expand Down Expand Up @@ -413,4 +415,23 @@ public function testEagerLoadInheritanceHierarchy(): void

self::assertFalse($this->isUninitializedObject($contract->getSalesPerson()));
}

public function testChildCanHaveNonNullableRelation(): void
{
$companyCar = new CompanyCar('BMW');
$fixContract = new CompanyFixContract();
$carContract = new CompanyCarContract();
$carContract->setCompanyCar($companyCar);

$this->_em->persist($fixContract);
$this->_em->persist($companyCar);
$this->_em->persist($carContract);
$this->_em->flush();
$this->_em->clear();

$repo = $this->_em->getRepository(CompanyCarContract::class);
$carContracts = $repo->findAll();

self::assertCount(1, $carContracts);
}
}
16 changes: 8 additions & 8 deletions tests/Tests/ORM/Query/SelectSqlGenerationTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ public function testSupportsJoinOnMultipleComponentsWithJoinedInheritanceType():

$this->assertSqlGeneration(
'SELECT c FROM Doctrine\Tests\Models\Company\CompanyContract c JOIN c.salesPerson s LEFT JOIN Doctrine\Tests\Models\Company\CompanyEvent e WITH s.id = e.id',
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6, c0_.salesPerson_id AS salesPerson_id_7 FROM company_contracts c0_ INNER JOIN company_employees c1_ ON c0_.salesPerson_id = c1_.id LEFT JOIN company_persons c2_ ON c1_.id = c2_.id LEFT JOIN company_managers c3_ ON c1_.id = c3_.id LEFT JOIN (company_events c4_ LEFT JOIN company_auctions c5_ ON c4_.id = c5_.id LEFT JOIN company_raffles c6_ ON c4_.id = c6_.id) ON (c2_.id = c4_.id) WHERE c0_.discr IN ('fix', 'flexible', 'flexultra')",
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6, c0_.salesPerson_id AS salesPerson_id_7, c0_.companyCar_id AS companyCar_id_8 FROM company_contracts c0_ INNER JOIN company_employees c1_ ON c0_.salesPerson_id = c1_.id LEFT JOIN company_persons c2_ ON c1_.id = c2_.id LEFT JOIN company_managers c3_ ON c1_.id = c3_.id LEFT JOIN (company_events c4_ LEFT JOIN company_auctions c5_ ON c4_.id = c5_.id LEFT JOIN company_raffles c6_ ON c4_.id = c6_.id) ON (c2_.id = c4_.id) WHERE c0_.discr IN ('fix', 'flexible', 'flexultra', 'car')",
);
}

Expand Down Expand Up @@ -1395,7 +1395,7 @@ public function testInheritanceTypeSingleTableInRootClass(): void
{
$this->assertSqlGeneration(
'SELECT c FROM Doctrine\Tests\Models\Company\CompanyContract c',
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6, c0_.salesPerson_id AS salesPerson_id_7 FROM company_contracts c0_ WHERE c0_.discr IN ('fix', 'flexible', 'flexultra')",
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6, c0_.salesPerson_id AS salesPerson_id_7, c0_.companyCar_id AS companyCar_id_8 FROM company_contracts c0_ WHERE c0_.discr IN ('fix', 'flexible', 'flexultra', 'car')",
);
}

Expand Down Expand Up @@ -1922,7 +1922,7 @@ public function testSingleTableInheritanceLeftJoinWithCondition(): void
// Regression test for the bug
$this->assertSqlGeneration(
'SELECT c FROM Doctrine\Tests\Models\Company\CompanyEmployee e LEFT JOIN Doctrine\Tests\Models\Company\CompanyContract c WITH c.salesPerson = e.id',
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6, c0_.salesPerson_id AS salesPerson_id_7 FROM company_employees c1_ INNER JOIN company_persons c2_ ON c1_.id = c2_.id LEFT JOIN company_managers c3_ ON c1_.id = c3_.id LEFT JOIN company_contracts c0_ ON (c0_.salesPerson_id = c2_.id) AND c0_.discr IN ('fix', 'flexible', 'flexultra')",
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6, c0_.salesPerson_id AS salesPerson_id_7, c0_.companyCar_id AS companyCar_id_8 FROM company_employees c1_ INNER JOIN company_persons c2_ ON c1_.id = c2_.id LEFT JOIN company_managers c3_ ON c1_.id = c3_.id LEFT JOIN company_contracts c0_ ON (c0_.salesPerson_id = c2_.id) AND c0_.discr IN ('fix', 'flexible', 'flexultra', 'car')",
);
}

Expand All @@ -1932,7 +1932,7 @@ public function testSingleTableInheritanceLeftJoinWithConditionAndWhere(): void
// Ensure other WHERE predicates are passed through to the main WHERE clause
$this->assertSqlGeneration(
'SELECT c FROM Doctrine\Tests\Models\Company\CompanyEmployee e LEFT JOIN Doctrine\Tests\Models\Company\CompanyContract c WITH c.salesPerson = e.id WHERE e.salary > 1000',
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6, c0_.salesPerson_id AS salesPerson_id_7 FROM company_employees c1_ INNER JOIN company_persons c2_ ON c1_.id = c2_.id LEFT JOIN company_managers c3_ ON c1_.id = c3_.id LEFT JOIN company_contracts c0_ ON (c0_.salesPerson_id = c2_.id) AND c0_.discr IN ('fix', 'flexible', 'flexultra') WHERE c1_.salary > 1000",
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6, c0_.salesPerson_id AS salesPerson_id_7, c0_.companyCar_id AS companyCar_id_8 FROM company_employees c1_ INNER JOIN company_persons c2_ ON c1_.id = c2_.id LEFT JOIN company_managers c3_ ON c1_.id = c3_.id LEFT JOIN company_contracts c0_ ON (c0_.salesPerson_id = c2_.id) AND c0_.discr IN ('fix', 'flexible', 'flexultra', 'car') WHERE c1_.salary > 1000",
);
}

Expand All @@ -1942,7 +1942,7 @@ public function testSingleTableInheritanceInnerJoinWithCondition(): void
// Test inner joins too
$this->assertSqlGeneration(
'SELECT c FROM Doctrine\Tests\Models\Company\CompanyEmployee e INNER JOIN Doctrine\Tests\Models\Company\CompanyContract c WITH c.salesPerson = e.id',
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6, c0_.salesPerson_id AS salesPerson_id_7 FROM company_employees c1_ INNER JOIN company_persons c2_ ON c1_.id = c2_.id LEFT JOIN company_managers c3_ ON c1_.id = c3_.id INNER JOIN company_contracts c0_ ON (c0_.salesPerson_id = c2_.id) AND c0_.discr IN ('fix', 'flexible', 'flexultra')",
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6, c0_.salesPerson_id AS salesPerson_id_7, c0_.companyCar_id AS companyCar_id_8 FROM company_employees c1_ INNER JOIN company_persons c2_ ON c1_.id = c2_.id LEFT JOIN company_managers c3_ ON c1_.id = c3_.id INNER JOIN company_contracts c0_ ON (c0_.salesPerson_id = c2_.id) AND c0_.discr IN ('fix', 'flexible', 'flexultra', 'car')",
);
}

Expand All @@ -1953,7 +1953,7 @@ public function testSingleTableInheritanceLeftJoinNonAssociationWithConditionAnd
// the where clause when not joining onto that table
$this->assertSqlGeneration(
'SELECT c FROM Doctrine\Tests\Models\Company\CompanyContract c LEFT JOIN Doctrine\Tests\Models\Company\CompanyEmployee e WITH e.id = c.salesPerson WHERE c.completed = true',
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6, c0_.salesPerson_id AS salesPerson_id_7 FROM company_contracts c0_ LEFT JOIN (company_employees c1_ INNER JOIN company_persons c2_ ON c1_.id = c2_.id LEFT JOIN company_managers c3_ ON c1_.id = c3_.id) ON (c2_.id = c0_.salesPerson_id) WHERE (c0_.completed = 1) AND c0_.discr IN ('fix', 'flexible', 'flexultra')",
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6, c0_.salesPerson_id AS salesPerson_id_7, c0_.companyCar_id AS companyCar_id_8 FROM company_contracts c0_ LEFT JOIN (company_employees c1_ INNER JOIN company_persons c2_ ON c1_.id = c2_.id LEFT JOIN company_managers c3_ ON c1_.id = c3_.id) ON (c2_.id = c0_.salesPerson_id) WHERE (c0_.completed = 1) AND c0_.discr IN ('fix', 'flexible', 'flexultra', 'car')",
);
}

Expand All @@ -1965,7 +1965,7 @@ public function testSingleTableInheritanceJoinCreatesOnCondition(): void
// via a join association
$this->assertSqlGeneration(
'SELECT c FROM Doctrine\Tests\Models\Company\CompanyContract c JOIN c.salesPerson s WHERE c.completed = true',
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6, c0_.salesPerson_id AS salesPerson_id_7 FROM company_contracts c0_ INNER JOIN company_employees c1_ ON c0_.salesPerson_id = c1_.id LEFT JOIN company_persons c2_ ON c1_.id = c2_.id LEFT JOIN company_managers c3_ ON c1_.id = c3_.id WHERE (c0_.completed = 1) AND c0_.discr IN ('fix', 'flexible', 'flexultra')",
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6, c0_.salesPerson_id AS salesPerson_id_7, c0_.companyCar_id AS companyCar_id_8 FROM company_contracts c0_ INNER JOIN company_employees c1_ ON c0_.salesPerson_id = c1_.id LEFT JOIN company_persons c2_ ON c1_.id = c2_.id LEFT JOIN company_managers c3_ ON c1_.id = c3_.id WHERE (c0_.completed = 1) AND c0_.discr IN ('fix', 'flexible', 'flexultra', 'car')",
);
}

Expand All @@ -1977,7 +1977,7 @@ public function testSingleTableInheritanceCreatesOnConditionAndWhere(): void
// into the ON clause of the join
$this->assertSqlGeneration(
'SELECT e, COUNT(c) FROM Doctrine\Tests\Models\Company\CompanyEmployee e JOIN e.contracts c WHERE e.department = :department',
"SELECT c0_.id AS id_0, c0_.name AS name_1, c1_.salary AS salary_2, c1_.department AS department_3, c1_.startDate AS startDate_4, c2_.title AS title_5, COUNT(c3_.id) AS sclr_6, c0_.discr AS discr_7, c0_.spouse_id AS spouse_id_8, c2_.car_id AS car_id_9 FROM company_employees c1_ INNER JOIN company_persons c0_ ON c1_.id = c0_.id LEFT JOIN company_managers c2_ ON c1_.id = c2_.id INNER JOIN company_contract_employees c4_ ON c1_.id = c4_.employee_id INNER JOIN company_contracts c3_ ON c3_.id = c4_.contract_id AND c3_.discr IN ('fix', 'flexible', 'flexultra') WHERE c1_.department = ?",
"SELECT c0_.id AS id_0, c0_.name AS name_1, c1_.salary AS salary_2, c1_.department AS department_3, c1_.startDate AS startDate_4, c2_.title AS title_5, COUNT(c3_.id) AS sclr_6, c0_.discr AS discr_7, c0_.spouse_id AS spouse_id_8, c2_.car_id AS car_id_9 FROM company_employees c1_ INNER JOIN company_persons c0_ ON c1_.id = c0_.id LEFT JOIN company_managers c2_ ON c1_.id = c2_.id INNER JOIN company_contract_employees c4_ ON c1_.id = c4_.employee_id INNER JOIN company_contracts c3_ ON c3_.id = c4_.contract_id AND c3_.discr IN ('fix', 'flexible', 'flexultra', 'car') WHERE c1_.department = ?",
[],
['department' => 'foobar'],
);
Expand Down

0 comments on commit 2afdc8e

Please sign in to comment.