Skip to content

Commit

Permalink
feat: allow STI child entities to have non-nullable relationships
Browse files Browse the repository at this point in the history
  • Loading branch information
simPod committed Oct 19, 2022
1 parent bac784c commit 459649e
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 16 deletions.
17 changes: 12 additions & 5 deletions lib/Doctrine/ORM/Tools/SchemaTool.php
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,7 @@ private function gatherRelationsSql(
$this->gatherRelationJoinColumns(
$mapping['joinColumns'],
$table,
$class,
$foreignClass,
$mapping,
$primaryKeyColumns,
Expand Down Expand Up @@ -593,6 +594,7 @@ private function gatherRelationsSql(
$joinTable['joinColumns'],
$theJoinTable,
$class,
$class,
$mapping,
$primaryKeyColumns,
$addedFks,
Expand All @@ -603,6 +605,7 @@ private function gatherRelationsSql(
$this->gatherRelationJoinColumns(
$joinTable['inverseJoinColumns'],
$theJoinTable,
$class,
$foreignClass,
$mapping,
$primaryKeyColumns,
Expand Down Expand Up @@ -670,6 +673,7 @@ private function gatherRelationJoinColumns(
array $joinColumns,
Table $theJoinTable,
ClassMetadata $class,
ClassMetadata $foreignClass,
array $mapping,
array &$primaryKeyColumns,
array &$addedFks,
Expand All @@ -678,12 +682,12 @@ private function gatherRelationJoinColumns(
$localColumns = [];
$foreignColumns = [];
$fkOptions = [];
$foreignTableName = $this->quoteStrategy->getTableName($class, $this->platform);
$foreignTableName = $this->quoteStrategy->getTableName($foreignClass, $this->platform);
$uniqueConstraints = [];

foreach ($joinColumns as $joinColumn) {
[$definingClass, $referencedFieldName] = $this->getDefiningClass(
$class,
$foreignClass,
$joinColumn['referencedColumnName']
);

Expand All @@ -695,10 +699,10 @@ private function gatherRelationJoinColumns(
);
}

$quotedColumnName = $this->quoteStrategy->getJoinColumnName($joinColumn, $class, $this->platform);
$quotedColumnName = $this->quoteStrategy->getJoinColumnName($joinColumn, $foreignClass, $this->platform);
$quotedRefColumnName = $this->quoteStrategy->getReferencedJoinColumnName(
$joinColumn,
$class,
$foreignClass,
$this->platform
);

Expand All @@ -721,7 +725,10 @@ private function gatherRelationJoinColumns(
$columnOptions['columnDefinition'] = $fieldMapping['columnDefinition'];
}

if (isset($joinColumn['nullable'])) {
if (
isset($joinColumn['nullable'])
&& ! ($class->isInheritanceTypeSingleTable() && $class->parentClasses)
) {
$columnOptions['notnull'] = ! $joinColumn['nullable'];
}

Expand Down
30 changes: 30 additions & 0 deletions tests/Doctrine/Tests/Models/Company/CompanyCarContract.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
<?php

declare(strict_types=1);

namespace Doctrine\Tests\Models\Company;

use Doctrine\ORM\Mapping as ORM;
use Doctrine\ORM\Mapping\Entity;

/** @Entity */
class CompanyCarContract extends CompanyContract
{
/**
* @ORM\ManyToOne(targetEntity="CompanyCar")
* @ORM\JoinColumn(nullable=false, onDelete="CASCADE")
*
* @var CompanyCar
*/
private $companyCar;

public function calculatePrice(): int
{
return 0;
}

public function setCompanyCar(CompanyCar $companyCar): void
{
$this->companyCar = $companyCar;
}
}
5 changes: 3 additions & 2 deletions tests/Doctrine/Tests/Models/Company/CompanyContract.php
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@
* @DiscriminatorMap({
* "fix" = "CompanyFixContract",
* "flexible" = "CompanyFlexContract",
* "flexultra" = "CompanyFlexUltraContract"
* "flexultra" = "CompanyFlexUltraContract",
* "car" = "CompanyCarContract"
* })
* @NamedNativeQueries({
* @NamedNativeQuery(
Expand Down Expand Up @@ -84,7 +85,7 @@
#[ORM\Entity, ORM\Table(name: 'company_contracts')]
#[ORM\InheritanceType('SINGLE_TABLE')]
#[ORM\DiscriminatorColumn(name: 'discr', type: 'string')]
#[ORM\DiscriminatorMap(['fix' => 'CompanyFixContract', 'flexible' => 'CompanyFlexContract', 'flexultra' => 'CompanyFlexUltraContract'])]
#[ORM\DiscriminatorMap(['fix' => 'CompanyFixContract', 'flexible' => 'CompanyFlexContract', 'flexultra' => 'CompanyFlexUltraContract', 'car' => 'CompanyCarContract'])]
#[ORM\EntityListeners(['CompanyContractListener'])]
abstract class CompanyContract
{
Expand Down
21 changes: 21 additions & 0 deletions tests/Doctrine/Tests/ORM/Functional/SingleTableInheritanceTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
use Doctrine\ORM\Mapping\ClassMetadata;
use Doctrine\ORM\Persisters\MatchingAssociationFieldRequiresObject;
use Doctrine\Persistence\Proxy;
use Doctrine\Tests\Models\Company\CompanyCar;
use Doctrine\Tests\Models\Company\CompanyCarContract;
use Doctrine\Tests\Models\Company\CompanyContract;
use Doctrine\Tests\Models\Company\CompanyEmployee;
use Doctrine\Tests\Models\Company\CompanyFixContract;
Expand Down Expand Up @@ -418,4 +420,23 @@ public function testEagerLoadInheritanceHierarchy(): void

self::assertNotInstanceOf(Proxy::class, $contract->getSalesPerson());
}

public function testChildCanHaveNonNullableRelation(): void
{
$companyCar = new CompanyCar('BMW');
$fixContract = new CompanyFixContract();
$carContract = new CompanyCarContract();
$carContract->setCompanyCar($companyCar);

$this->_em->persist($fixContract);
$this->_em->persist($companyCar);
$this->_em->persist($carContract);
$this->_em->flush();
$this->_em->clear();

$repo = $this->_em->getRepository(CompanyCarContract::class);
$carContracts = $repo->findAll();

self::assertCount(1, $carContracts);
}
}
18 changes: 9 additions & 9 deletions tests/Doctrine/Tests/ORM/Query/SelectSqlGenerationTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ public function testSupportsJoinOnMultipleComponentsWithJoinedInheritanceType():

$this->assertSqlGeneration(
'SELECT c FROM Doctrine\Tests\Models\Company\CompanyContract c JOIN c.salesPerson s LEFT JOIN Doctrine\Tests\Models\Company\CompanyEvent e WITH s.id = e.id',
'SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6 FROM company_contracts c0_ INNER JOIN company_employees c1_ ON c0_.salesPerson_id = c1_.id LEFT JOIN company_persons c2_ ON c1_.id = c2_.id LEFT JOIN company_events c3_ ON (c2_.id = c3_.id) WHERE c0_.discr IN (\'fix\', \'flexible\', \'flexultra\')'
'SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6 FROM company_contracts c0_ INNER JOIN company_employees c1_ ON c0_.salesPerson_id = c1_.id LEFT JOIN company_persons c2_ ON c1_.id = c2_.id LEFT JOIN company_events c3_ ON (c2_.id = c3_.id) WHERE c0_.discr IN (\'fix\', \'flexible\', \'flexultra\', \'car\')'
);
}

Expand Down Expand Up @@ -1498,7 +1498,7 @@ public function testInheritanceTypeSingleTableInRootClassWithDisabledForcePartia
{
$this->assertSqlGeneration(
'SELECT c FROM Doctrine\Tests\Models\Company\CompanyContract c',
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6, c0_.salesPerson_id AS salesPerson_id_7 FROM company_contracts c0_ WHERE c0_.discr IN ('fix', 'flexible', 'flexultra')",
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6, c0_.salesPerson_id AS salesPerson_id_7, c0_.companyCar_id AS companyCar_id_8 FROM company_contracts c0_ WHERE c0_.discr IN ('fix', 'flexible', 'flexultra', 'car')",
[ORMQuery::HINT_FORCE_PARTIAL_LOAD => false]
);
}
Expand All @@ -1508,7 +1508,7 @@ public function testInheritanceTypeSingleTableInRootClassWithEnabledForcePartial
{
$this->assertSqlGeneration(
'SELECT c FROM Doctrine\Tests\Models\Company\CompanyContract c',
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6 FROM company_contracts c0_ WHERE c0_.discr IN ('fix', 'flexible', 'flexultra')",
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6 FROM company_contracts c0_ WHERE c0_.discr IN ('fix', 'flexible', 'flexultra', 'car')",
[ORMQuery::HINT_FORCE_PARTIAL_LOAD => true]
);
}
Expand Down Expand Up @@ -2078,7 +2078,7 @@ public function testSingleTableInheritanceLeftJoinWithCondition(): void
// Regression test for the bug
$this->assertSqlGeneration(
'SELECT c FROM Doctrine\Tests\Models\Company\CompanyEmployee e LEFT JOIN Doctrine\Tests\Models\Company\CompanyContract c WITH c.salesPerson = e.id',
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6 FROM company_employees c1_ INNER JOIN company_persons c2_ ON c1_.id = c2_.id LEFT JOIN company_contracts c0_ ON (c0_.salesPerson_id = c2_.id) AND c0_.discr IN ('fix', 'flexible', 'flexultra')"
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6 FROM company_employees c1_ INNER JOIN company_persons c2_ ON c1_.id = c2_.id LEFT JOIN company_contracts c0_ ON (c0_.salesPerson_id = c2_.id) AND c0_.discr IN ('fix', 'flexible', 'flexultra', 'car')"
);
}

Expand All @@ -2088,7 +2088,7 @@ public function testSingleTableInheritanceLeftJoinWithConditionAndWhere(): void
// Ensure other WHERE predicates are passed through to the main WHERE clause
$this->assertSqlGeneration(
'SELECT c FROM Doctrine\Tests\Models\Company\CompanyEmployee e LEFT JOIN Doctrine\Tests\Models\Company\CompanyContract c WITH c.salesPerson = e.id WHERE e.salary > 1000',
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6 FROM company_employees c1_ INNER JOIN company_persons c2_ ON c1_.id = c2_.id LEFT JOIN company_contracts c0_ ON (c0_.salesPerson_id = c2_.id) AND c0_.discr IN ('fix', 'flexible', 'flexultra') WHERE c1_.salary > 1000"
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6 FROM company_employees c1_ INNER JOIN company_persons c2_ ON c1_.id = c2_.id LEFT JOIN company_contracts c0_ ON (c0_.salesPerson_id = c2_.id) AND c0_.discr IN ('fix', 'flexible', 'flexultra', 'car') WHERE c1_.salary > 1000"
);
}

Expand All @@ -2098,7 +2098,7 @@ public function testSingleTableInheritanceInnerJoinWithCondition(): void
// Test inner joins too
$this->assertSqlGeneration(
'SELECT c FROM Doctrine\Tests\Models\Company\CompanyEmployee e INNER JOIN Doctrine\Tests\Models\Company\CompanyContract c WITH c.salesPerson = e.id',
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6 FROM company_employees c1_ INNER JOIN company_persons c2_ ON c1_.id = c2_.id INNER JOIN company_contracts c0_ ON (c0_.salesPerson_id = c2_.id) AND c0_.discr IN ('fix', 'flexible', 'flexultra')"
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6 FROM company_employees c1_ INNER JOIN company_persons c2_ ON c1_.id = c2_.id INNER JOIN company_contracts c0_ ON (c0_.salesPerson_id = c2_.id) AND c0_.discr IN ('fix', 'flexible', 'flexultra', 'car')"
);
}

Expand All @@ -2109,7 +2109,7 @@ public function testSingleTableInheritanceLeftJoinNonAssociationWithConditionAnd
// the where clause when not joining onto that table
$this->assertSqlGeneration(
'SELECT c FROM Doctrine\Tests\Models\Company\CompanyContract c LEFT JOIN Doctrine\Tests\Models\Company\CompanyEmployee e WITH e.id = c.salesPerson WHERE c.completed = true',
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6 FROM company_contracts c0_ LEFT JOIN (company_employees c1_ INNER JOIN company_persons c2_ ON c1_.id = c2_.id) ON (c2_.id = c0_.salesPerson_id) WHERE (c0_.completed = 1) AND c0_.discr IN ('fix', 'flexible', 'flexultra')"
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6 FROM company_contracts c0_ LEFT JOIN (company_employees c1_ INNER JOIN company_persons c2_ ON c1_.id = c2_.id) ON (c2_.id = c0_.salesPerson_id) WHERE (c0_.completed = 1) AND c0_.discr IN ('fix', 'flexible', 'flexultra', 'car')"
);
}

Expand All @@ -2121,7 +2121,7 @@ public function testSingleTableInheritanceJoinCreatesOnCondition(): void
// via a join association
$this->assertSqlGeneration(
'SELECT c FROM Doctrine\Tests\Models\Company\CompanyContract c JOIN c.salesPerson s WHERE c.completed = true',
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6 FROM company_contracts c0_ INNER JOIN company_employees c1_ ON c0_.salesPerson_id = c1_.id LEFT JOIN company_persons c2_ ON c1_.id = c2_.id WHERE (c0_.completed = 1) AND c0_.discr IN ('fix', 'flexible', 'flexultra')"
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6 FROM company_contracts c0_ INNER JOIN company_employees c1_ ON c0_.salesPerson_id = c1_.id LEFT JOIN company_persons c2_ ON c1_.id = c2_.id WHERE (c0_.completed = 1) AND c0_.discr IN ('fix', 'flexible', 'flexultra', 'car')"
);
}

Expand All @@ -2133,7 +2133,7 @@ public function testSingleTableInheritanceCreatesOnConditionAndWhere(): void
// into the ON clause of the join
$this->assertSqlGeneration(
'SELECT e, COUNT(c) FROM Doctrine\Tests\Models\Company\CompanyEmployee e JOIN e.contracts c WHERE e.department = :department',
"SELECT c0_.id AS id_0, c0_.name AS name_1, c1_.salary AS salary_2, c1_.department AS department_3, c1_.startDate AS startDate_4, COUNT(c2_.id) AS sclr_5, c0_.discr AS discr_6 FROM company_employees c1_ INNER JOIN company_persons c0_ ON c1_.id = c0_.id INNER JOIN company_contract_employees c3_ ON c1_.id = c3_.employee_id INNER JOIN company_contracts c2_ ON c2_.id = c3_.contract_id AND c2_.discr IN ('fix', 'flexible', 'flexultra') WHERE c1_.department = ?",
"SELECT c0_.id AS id_0, c0_.name AS name_1, c1_.salary AS salary_2, c1_.department AS department_3, c1_.startDate AS startDate_4, COUNT(c2_.id) AS sclr_5, c0_.discr AS discr_6 FROM company_employees c1_ INNER JOIN company_persons c0_ ON c1_.id = c0_.id INNER JOIN company_contract_employees c3_ ON c1_.id = c3_.employee_id INNER JOIN company_contracts c2_ ON c2_.id = c3_.contract_id AND c2_.discr IN ('fix', 'flexible', 'flexultra', 'car') WHERE c1_.department = ?",
[],
['department' => 'foobar']
);
Expand Down

0 comments on commit 459649e

Please sign in to comment.