Skip to content

Commit

Permalink
server/discount: fix issues when updating discount code leading to gh…
Browse files Browse the repository at this point in the history
…ost coupons in Stripe
  • Loading branch information
frankie567 committed Jan 2, 2025
1 parent 677e354 commit 91e5529
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 12 deletions.
28 changes: 20 additions & 8 deletions server/polar/discount/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,6 @@ async def update(
discount: Discount,
discount_update: DiscountUpdate,
) -> Discount:
previous_name = discount.name

if (
discount_update.duration is not None
and discount_update.duration != discount.duration
Expand Down Expand Up @@ -253,10 +251,13 @@ async def update(
)
discount.discount_products.append(DiscountProduct(product=product))

updated_fields = set()
for attr, value in discount_update.model_dump(
exclude_unset=True, exclude={"products"}
).items():
setattr(discount, attr, value)
if value != getattr(discount, attr):
setattr(discount, attr, value)
updated_fields.add(attr)

sensitive_fields = {
"starts_at",
Expand All @@ -267,13 +268,24 @@ async def update(
"basis_points",
"duration_in_months",
}
if len(discount_update.model_fields_set.intersection(sensitive_fields)) > 0:
await stripe_service.delete_coupon(discount.stripe_coupon_id)
stripe_coupon = await stripe_service.create_coupon(
if sensitive_fields.intersection(updated_fields):
if discount.ends_at is not None and discount.ends_at < utc_now():
raise PolarRequestValidationError(
[
{
"type": "value_error",
"loc": ("body", "ends_at"),
"msg": "Ends at must be in the future.",
"input": discount.ends_at,
}
]
)
new_stripe_coupon = await stripe_service.create_coupon(
**discount.get_stripe_coupon_params()
)
discount.stripe_coupon_id = stripe_coupon.id
elif previous_name != discount.name:
await stripe_service.delete_coupon(discount.stripe_coupon_id)
discount.stripe_coupon_id = new_stripe_coupon.id
elif "name" in updated_fields:
await stripe_service.update_coupon(
discount.stripe_coupon_id, name=discount.name
)
Expand Down
5 changes: 1 addition & 4 deletions server/polar/models/discount.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,6 @@ class DiscountDuration(StrEnum):
forever = "forever"
repeating = "repeating"

def as_literal(self) -> Literal["once", "forever", "repeating"]:
return cast(Literal["once", "forever", "repeating"], self.value)


class Discount(MetadataMixin, RecordModel):
__tablename__ = "discounts"
Expand Down Expand Up @@ -122,7 +119,7 @@ def get_discount_amount(self, amount: int) -> int:
def get_stripe_coupon_params(self) -> stripe_lib.Coupon.CreateParams:
params: stripe_lib.Coupon.CreateParams = {
"name": self.name,
"duration": self.duration.as_literal(),
"duration": cast(Literal["once", "forever", "repeating"], self.duration),
"metadata": {
"discount_id": str(self.id),
"organization_id": str(self.organization.id),
Expand Down
39 changes: 39 additions & 0 deletions server/tests/discount/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,45 @@ async def test_update_products_reset(
assert updated_discount.stripe_coupon_id == old_stripe_coupon_id
stripe_service_mock.update_coupon.assert_not_called()

async def test_update_discount_past_dates(
self,
stripe_service_mock: MagicMock,
save_fixture: SaveFixture,
session: AsyncSession,
organization: Organization,
) -> None:
discount = await create_discount(
save_fixture,
type=DiscountType.percentage,
basis_points=1000,
duration=DiscountDuration.once,
organization=organization,
starts_at=utc_now() - timedelta(days=7),
ends_at=utc_now() - timedelta(days=1),
)
old_stripe_coupon_id = discount.stripe_coupon_id

updated_discount = await discount_service.update(
session,
discount,
discount_update=DiscountUpdate(
name="Updated Name",
starts_at=discount.starts_at,
ends_at=discount.ends_at,
),
)

assert updated_discount.name == "Updated Name"
assert updated_discount.stripe_coupon_id == old_stripe_coupon_id
assert updated_discount.starts_at == discount.starts_at
assert updated_discount.ends_at == discount.ends_at

stripe_service_mock.update_coupon.assert_called_once_with(
old_stripe_coupon_id, name="Updated Name"
)
stripe_service_mock.delete_coupon.assert_not_called()
stripe_service_mock.create_coupon.assert_not_called()


@pytest.mark.asyncio
class TestIsRedeemableDiscount:
Expand Down

0 comments on commit 91e5529

Please sign in to comment.