From 91e5529d9bb2c7a0213efaf2cb3378c7fa555fe7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 2 Jan 2025 13:03:32 +0100 Subject: [PATCH] server/discount: fix issues when updating discount code leading to ghost coupons in Stripe --- server/polar/discount/service.py | 28 +++++++++++++------ server/polar/models/discount.py | 5 +--- server/tests/discount/test_service.py | 39 +++++++++++++++++++++++++++ 3 files changed, 60 insertions(+), 12 deletions(-) diff --git a/server/polar/discount/service.py b/server/polar/discount/service.py index 7176b6dd6b..2d62f09e6f 100644 --- a/server/polar/discount/service.py +++ b/server/polar/discount/service.py @@ -179,8 +179,6 @@ async def update( discount: Discount, discount_update: DiscountUpdate, ) -> Discount: - previous_name = discount.name - if ( discount_update.duration is not None and discount_update.duration != discount.duration @@ -253,10 +251,13 @@ async def update( ) discount.discount_products.append(DiscountProduct(product=product)) + updated_fields = set() for attr, value in discount_update.model_dump( exclude_unset=True, exclude={"products"} ).items(): - setattr(discount, attr, value) + if value != getattr(discount, attr): + setattr(discount, attr, value) + updated_fields.add(attr) sensitive_fields = { "starts_at", @@ -267,13 +268,24 @@ async def update( "basis_points", "duration_in_months", } - if len(discount_update.model_fields_set.intersection(sensitive_fields)) > 0: - await stripe_service.delete_coupon(discount.stripe_coupon_id) - stripe_coupon = await stripe_service.create_coupon( + if sensitive_fields.intersection(updated_fields): + if discount.ends_at is not None and discount.ends_at < utc_now(): + raise PolarRequestValidationError( + [ + { + "type": "value_error", + "loc": ("body", "ends_at"), + "msg": "Ends at must be in the future.", + "input": discount.ends_at, + } + ] + ) + new_stripe_coupon = await stripe_service.create_coupon( **discount.get_stripe_coupon_params() ) - discount.stripe_coupon_id = stripe_coupon.id - elif previous_name != discount.name: + await stripe_service.delete_coupon(discount.stripe_coupon_id) + discount.stripe_coupon_id = new_stripe_coupon.id + elif "name" in updated_fields: await stripe_service.update_coupon( discount.stripe_coupon_id, name=discount.name ) diff --git a/server/polar/models/discount.py b/server/polar/models/discount.py index 13fb0e70fa..7081d87741 100644 --- a/server/polar/models/discount.py +++ b/server/polar/models/discount.py @@ -54,9 +54,6 @@ class DiscountDuration(StrEnum): forever = "forever" repeating = "repeating" - def as_literal(self) -> Literal["once", "forever", "repeating"]: - return cast(Literal["once", "forever", "repeating"], self.value) - class Discount(MetadataMixin, RecordModel): __tablename__ = "discounts" @@ -122,7 +119,7 @@ def get_discount_amount(self, amount: int) -> int: def get_stripe_coupon_params(self) -> stripe_lib.Coupon.CreateParams: params: stripe_lib.Coupon.CreateParams = { "name": self.name, - "duration": self.duration.as_literal(), + "duration": cast(Literal["once", "forever", "repeating"], self.duration), "metadata": { "discount_id": str(self.id), "organization_id": str(self.organization.id), diff --git a/server/tests/discount/test_service.py b/server/tests/discount/test_service.py index d4863abea9..774887226a 100644 --- a/server/tests/discount/test_service.py +++ b/server/tests/discount/test_service.py @@ -252,6 +252,45 @@ async def test_update_products_reset( assert updated_discount.stripe_coupon_id == old_stripe_coupon_id stripe_service_mock.update_coupon.assert_not_called() + async def test_update_discount_past_dates( + self, + stripe_service_mock: MagicMock, + save_fixture: SaveFixture, + session: AsyncSession, + organization: Organization, + ) -> None: + discount = await create_discount( + save_fixture, + type=DiscountType.percentage, + basis_points=1000, + duration=DiscountDuration.once, + organization=organization, + starts_at=utc_now() - timedelta(days=7), + ends_at=utc_now() - timedelta(days=1), + ) + old_stripe_coupon_id = discount.stripe_coupon_id + + updated_discount = await discount_service.update( + session, + discount, + discount_update=DiscountUpdate( + name="Updated Name", + starts_at=discount.starts_at, + ends_at=discount.ends_at, + ), + ) + + assert updated_discount.name == "Updated Name" + assert updated_discount.stripe_coupon_id == old_stripe_coupon_id + assert updated_discount.starts_at == discount.starts_at + assert updated_discount.ends_at == discount.ends_at + + stripe_service_mock.update_coupon.assert_called_once_with( + old_stripe_coupon_id, name="Updated Name" + ) + stripe_service_mock.delete_coupon.assert_not_called() + stripe_service_mock.create_coupon.assert_not_called() + @pytest.mark.asyncio class TestIsRedeemableDiscount: