Skip to content

Commit

Permalink
server/customer: add a Customer.user_id foreign key instead of having…
Browse files Browse the repository at this point in the history
… UserCustomer association table

Must have been really tired when I did this initially...
  • Loading branch information
frankie567 committed Jan 3, 2025
1 parent 4bf0212 commit 317a4aa
Show file tree
Hide file tree
Showing 15 changed files with 178 additions and 157 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""Add Customer.user_id and remove UserCustomer
Revision ID: c996df1d397f
Revises: 81faf775fce0
Create Date: 2025-01-03 14:43:30.632086
"""

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

# Polar Custom Imports

# revision identifiers, used by Alembic.
revision = "c996df1d397f"
down_revision = "81faf775fce0"
branch_labels: tuple[str] | None = None
depends_on: tuple[str] | None = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("customers", sa.Column("user_id", sa.Uuid(), nullable=True))
op.create_foreign_key(
op.f("customers_user_id_fkey"),
"customers",
"users",
["user_id"],
["id"],
ondelete="set null",
)

op.execute(
"""
UPDATE customers
SET user_id = user_customers.user_id
FROM user_customers
WHERE customers.id = user_customers.customer_id
"""
)

op.drop_index("ix_user_customers_created_at", table_name="user_customers")
op.drop_index("ix_user_customers_deleted_at", table_name="user_customers")
op.drop_index("ix_user_customers_modified_at", table_name="user_customers")
op.drop_table("user_customers")

# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"user_customers",
sa.Column("user_id", sa.UUID(), autoincrement=False, nullable=False),
sa.Column("customer_id", sa.UUID(), autoincrement=False, nullable=False),
sa.Column("id", sa.UUID(), autoincrement=False, nullable=False),
sa.Column(
"created_at",
postgresql.TIMESTAMP(timezone=True),
autoincrement=False,
nullable=False,
),
sa.Column(
"modified_at",
postgresql.TIMESTAMP(timezone=True),
autoincrement=False,
nullable=True,
),
sa.Column(
"deleted_at",
postgresql.TIMESTAMP(timezone=True),
autoincrement=False,
nullable=True,
),
sa.ForeignKeyConstraint(
["customer_id"],
["customers.id"],
name="user_customers_customer_id_fkey",
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
name="user_customers_user_id_fkey",
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id", name="user_customers_pkey"),
sa.UniqueConstraint("customer_id", name="user_customers_customer_id_key"),
)
op.create_index(
"ix_user_customers_modified_at", "user_customers", ["modified_at"], unique=False
)
op.create_index(
"ix_user_customers_deleted_at", "user_customers", ["deleted_at"], unique=False
)
op.create_index(
"ix_user_customers_created_at", "user_customers", ["created_at"], unique=False
)

op.execute(
"""
INSERT INTO user_customers (user_id, customer_id, created_at)
SELECT user_id, id, created_at
FROM customers
"""
)

op.drop_constraint(op.f("customers_user_id_fkey"), "customers", type_="foreignkey")
op.drop_column("customers", "user_id")

# ### end Alembic commands ###
48 changes: 16 additions & 32 deletions server/polar/customer/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Any

from sqlalchemy import Select, UnaryExpression, asc, desc, func, or_, select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.sql.base import ExecutableOption
from stripe import Customer as StripeCustomer

Expand All @@ -13,7 +12,7 @@
from polar.kit.pagination import PaginationParams, paginate
from polar.kit.services import ResourceServiceReader
from polar.kit.sorting import Sorting
from polar.models import Customer, Organization, User, UserCustomer, UserOrganization
from polar.models import Customer, Organization, User, UserOrganization
from polar.organization.resolver import get_payload_organization
from polar.postgres import AsyncSession

Expand Down Expand Up @@ -184,43 +183,31 @@ async def get_by_email_and_organization(
async def get_by_id_and_user(
self, session: AsyncSession, id: uuid.UUID, user: User
) -> Customer | None:
statement = (
select(Customer)
.join(UserCustomer, onclause=UserCustomer.customer_id == Customer.id)
.where(
Customer.deleted_at.is_(None),
Customer.id == id,
UserCustomer.user_id == user.id,
)
statement = select(Customer).where(
Customer.deleted_at.is_(None),
Customer.id == id,
Customer.user_id == user.id,
)
result = await session.execute(statement)
return result.scalar_one_or_none()

async def get_by_user_and_organization(
self, session: AsyncSession, user: User, organization: Organization
) -> Customer | None:
statement = (
select(Customer)
.join(UserCustomer, onclause=UserCustomer.customer_id == Customer.id)
.where(
Customer.deleted_at.is_(None),
UserCustomer.user_id == user.id,
Customer.organization_id == organization.id,
)
statement = select(Customer).where(
Customer.deleted_at.is_(None),
Customer.user_id == user.id,
Customer.organization_id == organization.id,
)
result = await session.execute(statement)
return result.scalar_one_or_none()

async def get_by_user(
self, session: AsyncSession, user: User
) -> Sequence[Customer]:
statement = (
select(Customer)
.join(UserCustomer, onclause=UserCustomer.customer_id == Customer.id)
.where(
Customer.deleted_at.is_(None),
UserCustomer.user_id == user.id,
)
statement = select(Customer).where(
Customer.deleted_at.is_(None),
Customer.user_id == user.id,
)
result = await session.execute(statement)
return result.unique().scalars().all()
Expand Down Expand Up @@ -275,13 +262,10 @@ async def get_or_create_from_stripe_customer(
async def link_user(
self, session: AsyncSession, customer: Customer, user: User
) -> None:
insert_statement = insert(UserCustomer).values(
user_id=user.id, customer_id=customer.id
)
insert_statement = insert_statement.on_conflict_do_nothing(
index_elements=["customer_id"]
)
await session.execute(insert_statement)
if customer.user_id is not None:
return
customer.user = user
session.add(customer)

def _get_readable_customer_statement(
self, auth_subject: AuthSubject[User | Organization]
Expand Down
5 changes: 2 additions & 3 deletions server/polar/customer_portal/service/benefit_grant.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
Organization,
Subscription,
User,
UserCustomer,
)
from polar.models.benefit import BenefitType
from polar.worker import enqueue_job
Expand Down Expand Up @@ -209,8 +208,8 @@ def _get_readable_benefit_grant_statement(
if is_user(auth_subject):
statement = statement.where(
BenefitGrant.customer_id.in_(
select(UserCustomer.customer_id).where(
UserCustomer.user_id == auth_subject.subject.id
select(Customer.id).where(
Customer.user_id == auth_subject.subject.id
)
)
)
Expand Down
6 changes: 3 additions & 3 deletions server/polar/customer_portal/service/downloadables.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from polar.kit.pagination import PaginationParams, paginate
from polar.kit.services import ResourceService
from polar.kit.utils import utc_now
from polar.models import Benefit, Customer, User, UserCustomer
from polar.models import Benefit, Customer, User
from polar.models.downloadable import Downloadable, DownloadableStatus
from polar.models.file import File
from polar.postgres import AsyncSession, sql
Expand Down Expand Up @@ -243,8 +243,8 @@ def _get_base_query(
if is_user(auth_subject):
statement = statement.where(
Downloadable.customer_id.in_(
sql.select(UserCustomer.customer_id).where(
UserCustomer.user_id == auth_subject.subject.id
sql.select(Customer.id).where(
Customer.user_id == auth_subject.subject.id
)
)
)
Expand Down
14 changes: 3 additions & 11 deletions server/polar/customer_portal/service/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,7 @@
from polar.kit.pagination import PaginationParams, paginate
from polar.kit.services import ResourceServiceReader
from polar.kit.sorting import Sorting
from polar.models import (
Customer,
Order,
Organization,
Product,
ProductPrice,
User,
UserCustomer,
)
from polar.models import Customer, Order, Organization, Product, ProductPrice, User
from polar.models.product_price import ProductPriceType


Expand Down Expand Up @@ -161,8 +153,8 @@ def _get_readable_order_statement(
if is_user(auth_subject):
statement = statement.where(
Order.customer_id.in_(
select(UserCustomer.customer_id).where(
UserCustomer.user_id == auth_subject.subject.id
select(Customer.id).where(
Customer.user_id == auth_subject.subject.id
)
)
)
Expand Down
5 changes: 2 additions & 3 deletions server/polar/customer_portal/service/subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
ProductPrice,
Subscription,
User,
UserCustomer,
)
from polar.models.subscription import CustomerCancellationReason
from polar.subscription.service import subscription as subscription_service
Expand Down Expand Up @@ -215,8 +214,8 @@ def _get_readable_subscription_statement(
if is_user(auth_subject):
statement = statement.where(
Subscription.customer_id.in_(
select(UserCustomer.customer_id).where(
UserCustomer.user_id == auth_subject.subject.id
select(Customer.id).where(
Customer.user_id == auth_subject.subject.id
)
)
)
Expand Down
5 changes: 2 additions & 3 deletions server/polar/license_key/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
LicenseKeyActivation,
Organization,
User,
UserCustomer,
UserOrganization,
)
from polar.models.benefit import BenefitLicenseKeys
Expand Down Expand Up @@ -487,8 +486,8 @@ def _get_select_customer_base(
if is_user(auth_subject):
statement = statement.where(
LicenseKey.customer_id.in_(
select(UserCustomer.customer_id).where(
UserCustomer.user_id == auth_subject.subject.id
select(Customer.id).where(
Customer.user_id == auth_subject.subject.id
)
)
)
Expand Down
2 changes: 0 additions & 2 deletions server/polar/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
from .subscription import Subscription
from .transaction import Transaction
from .user import OAuthAccount, User
from .user_customer import UserCustomer
from .user_notification import UserNotification
from .user_organization import UserOrganization
from .user_session import UserSession
Expand Down Expand Up @@ -104,7 +103,6 @@
"Subscription",
"Transaction",
"User",
"UserCustomer",
"UserNotification",
"UserOrganization",
"UserSession",
Expand Down
14 changes: 14 additions & 0 deletions server/polar/models/customer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

if TYPE_CHECKING:
from .organization import Organization
from .user import User


class CustomerOAuthPlatform(StrEnum):
Expand Down Expand Up @@ -86,6 +87,19 @@ class Customer(MetadataMixin, RecordModel):
)
tax_id: Mapped[TaxID | None] = mapped_column(TaxIDType, nullable=True, default=None)

user_id: Mapped[UUID | None] = mapped_column(
Uuid, ForeignKey("users.id", ondelete="set null"), nullable=True
)

@declared_attr
def user(cls) -> Mapped["User | None"]:
return relationship(
"User",
lazy="raise",
back_populates="customers",
foreign_keys="[Customer.user_id]",
)

_oauth_accounts: Mapped[dict[str, dict[str, Any]]] = mapped_column(
"oauth_accounts", JSONB, nullable=False, default=dict
)
Expand Down
15 changes: 7 additions & 8 deletions server/polar/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
func,
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.associationproxy import AssociationProxy, association_proxy
from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship
from sqlalchemy.schema import Index, UniqueConstraint

Expand All @@ -27,7 +26,6 @@

if TYPE_CHECKING:
from .customer import Customer
from .user_customer import UserCustomer


class OAuthPlatform(StrEnum):
Expand Down Expand Up @@ -115,12 +113,13 @@ def oauth_accounts(cls) -> Mapped[list[OAuthAccount]]:
return relationship(OAuthAccount, lazy="joined", back_populates="user")

@declared_attr
def user_customers(cls) -> Mapped[list["UserCustomer"]]:
return relationship("UserCustomer", lazy="raise", back_populates="user")

customers: AssociationProxy[list["Customer"]] = association_proxy(
"user_customers", "customer"
)
def customers(cls) -> Mapped[list["Customer"]]:
return relationship(
"Customer",
lazy="raise",
back_populates="user",
foreign_keys="[Customer.user_id]",
)

accepted_terms_of_service: Mapped[bool] = mapped_column(
Boolean,
Expand Down
Loading

0 comments on commit 317a4aa

Please sign in to comment.