Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 42 additions & 25 deletions polyfactory/factories/sqlalchemy_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from sqlalchemy.dialects import mssql, mysql, postgresql, sqlite
from sqlalchemy.exc import NoInspectionAvailable
from sqlalchemy.ext.associationproxy import AssociationProxy
from sqlalchemy.orm import InstanceState, Mapper
from sqlalchemy.orm import InstanceState, Mapper, RelationshipProperty
except ImportError as e:
msg = "sqlalchemy is not installed"
raise MissingDependencyException(msg) from e
Expand Down Expand Up @@ -240,6 +240,36 @@ def get_type_from_collection_class(

return annotation

@classmethod
def _get_relationship_type(cls, relationship: RelationshipProperty[Any]) -> type:
class_ = relationship.entity.class_
annotation: type

if relationship.uselist:
collection_class = relationship.collection_class
if collection_class is None:
annotation = list[class_] # type: ignore[valid-type]
else:
annotation = cls.get_type_from_collection_class(collection_class, class_)
else:
annotation = class_

return annotation

@classmethod
def _get_association_proxy_type(cls, table: Mapper, proxy: AssociationProxy) -> type | None:
target_collection = table.relationships.get(proxy.target_collection)
if not target_collection:
return None

target_class = target_collection.entity.class_
target_attr = getattr(target_class, proxy.value_attr)
if not target_attr:
return None

class_ = target_attr.entity.class_
return class_ if not target_collection.uselist else list[class_] # type: ignore[valid-type]

@classmethod
def get_model_fields(cls) -> list[FieldMeta]:
fields_meta: list[FieldMeta] = []
Expand All @@ -255,18 +285,7 @@ def get_model_fields(cls) -> list[FieldMeta]:
)
if cls.__set_relationships__:
for name, relationship in table.relationships.items():
class_ = relationship.entity.class_
annotation: Any

if relationship.uselist:
collection_class = relationship.collection_class
if collection_class is None:
annotation = list[class_] # type: ignore[valid-type]
else:
annotation = cls.get_type_from_collection_class(collection_class, class_)
else:
annotation = class_

annotation = cls._get_relationship_type(relationship)
fields_meta.append(
FieldMeta.from_type(
name=name,
Expand All @@ -276,19 +295,17 @@ def get_model_fields(cls) -> list[FieldMeta]:
if cls.__set_association_proxy__:
for name, attr in table.all_orm_descriptors.items():
if isinstance(attr, AssociationProxy):
target_collection = table.relationships.get(attr.target_collection)
if target_collection:
target_class = target_collection.entity.class_
target_attr = getattr(target_class, attr.value_attr)
if target_attr:
class_ = target_attr.entity.class_
annotation = class_ if not target_collection.uselist else list[class_] # type: ignore[valid-type]
fields_meta.append(
FieldMeta.from_type(
name=name,
annotation=annotation,
)
# Read-only proxies derive from the underlying relationship and shouldn't be set directly.
if not getattr(attr, "creator", None):
continue

if annotation := cls._get_association_proxy_type(table, attr): # type: ignore[assignment]
fields_meta.append(
FieldMeta.from_type(
name=name,
annotation=annotation,
)
)

return fields_meta

Expand Down
24 changes: 24 additions & 0 deletions tests/sqlalchemy_factory/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,27 @@ class Department(Base):

id = Column(Integer, primary_key=True)
director_id = Column(Integer, ForeignKey("users.id"))


class Company(Base):
__tablename__ = "companies"

id = Column(Integer, primary_key=True)
name: Any = Column(String)
employees = relationship(
"Employee",
back_populates="company",
)
employee_ids = association_proxy(
"employees",
"id",
)


class Employee(Base):
__tablename__ = "employees"

id = Column(Integer, primary_key=True)
name = Column(String)
company_id = Column(Integer, ForeignKey("companies.id"))
company = relationship(Company, back_populates="employees")
22 changes: 21 additions & 1 deletion tests/sqlalchemy_factory/test_association_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession

from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory
from tests.sqlalchemy_factory.models import Keyword, User, UserKeywordAssociation
from tests.sqlalchemy_factory.models import Company, Keyword, User, UserKeywordAssociation


class KeywordFactory(SQLAlchemyFactory[Keyword]): ...
Expand All @@ -18,6 +18,26 @@ class UserFactory(SQLAlchemyFactory[User]):
assert isinstance(user.user_keyword_associations[0], UserKeywordAssociation)


def test_association_proxy_no_creator() -> None:
class CompanyFactory(SQLAlchemyFactory[Company]):
__set_relationships__ = True
__set_association_proxy__ = True

company = CompanyFactory.build()
assert isinstance(company.employee_ids[0], int)
assert company.employees[0].id in company.employee_ids


def test_association_proxy_no_creator_no_relationship() -> None:
class CompanyFactory(SQLAlchemyFactory[Company]):
__set_relationships__ = False
__set_association_proxy__ = True

company = CompanyFactory.build()
assert len(company.employees) == 0
assert len(company.employee_ids) == 0


async def test_async_persistence(async_engine: AsyncEngine) -> None:
async with AsyncSession(async_engine) as session:

Expand Down
Loading