Skip to content

Commit

Permalink
Merge branch 'feature/sm-billing' of https://github.com/bitwarden/server
Browse files Browse the repository at this point in the history
 into feature/sm-billing
  • Loading branch information
cyprain-okeke committed Jul 24, 2023
2 parents 4fcb9da + 2c0ff09 commit 5b215cd
Show file tree
Hide file tree
Showing 5 changed files with 191 additions and 26 deletions.
16 changes: 8 additions & 8 deletions src/Billing/Constants/HandledStripeWebhook.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

public static class HandledStripeWebhook
{
public static string SubscriptionDeleted => "customer.subscription.deleted";
public static string SubscriptionUpdated => "customer.subscription.updated";
public static string UpcomingInvoice => "invoice.upcoming";
public static string ChargeSucceeded => "charge.succeeded";
public static string ChargeRefunded => "charge.refunded";
public static string PaymentSucceeded => "invoice.payment_succeeded";
public static string PaymentFailed => "invoice.payment_failed";
public static string InvoiceCreated => "invoice.created";
public const string SubscriptionDeleted = "customer.subscription.deleted";
public const string SubscriptionUpdated = "customer.subscription.updated";
public const string UpcomingInvoice = "invoice.upcoming";
public const string ChargeSucceeded = "charge.succeeded";
public const string ChargeRefunded = "charge.refunded";
public const string PaymentSucceeded = "invoice.payment_succeeded";
public const string PaymentFailed = "invoice.payment_failed";
public const string InvoiceCreated = "invoice.created";
}
89 changes: 82 additions & 7 deletions src/Billing/Controllers/StripeController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
using Microsoft.Data.SqlClient;
using Microsoft.Extensions.Options;
using Stripe;
using Event = Stripe.Event;
using TaxRate = Bit.Core.Entities.TaxRate;

namespace Bit.Billing.Controllers;
Expand Down Expand Up @@ -41,6 +42,7 @@ public class StripeController : Controller
private readonly ITaxRateRepository _taxRateRepository;
private readonly IUserRepository _userRepository;
private readonly ICurrentContext _currentContext;
private readonly GlobalSettings _globalSettings;

public StripeController(
GlobalSettings globalSettings,
Expand Down Expand Up @@ -83,6 +85,7 @@ public StripeController(
PrivateKey = globalSettings.Braintree.PrivateKey
};
_currentContext = currentContext;
_globalSettings = globalSettings;
}

[HttpPost("webhook")]
Expand Down Expand Up @@ -114,6 +117,12 @@ public async Task<IActionResult> PostWebhook([FromQuery] string key)
return new BadRequestResult();
}

// If the customer and server cloud regions don't match, early return 200 to avoid unnecessary errors
if (!await ValidateCloudRegionAsync(parsedEvent))
{
return new OkResult();
}

var subDeleted = parsedEvent.Type.Equals(HandledStripeWebhook.SubscriptionDeleted);
var subUpdated = parsedEvent.Type.Equals(HandledStripeWebhook.SubscriptionUpdated);

Expand Down Expand Up @@ -471,6 +480,68 @@ await _referenceEventService.RaiseEventAsync(
return new OkResult();
}

/// <summary>
/// Ensures that the customer associated with the parsed event's data is in the correct region for this server.
/// We use the customer instead of the subscription given that all subscriptions have customers, but not all
/// customers have subscriptions
/// </summary>
/// <param name="parsedEvent"></param>
/// <returns>true if the customer's region and the server's region match, otherwise false</returns>
/// <exception cref="Exception"></exception>
private async Task<bool> ValidateCloudRegionAsync(Event parsedEvent)
{
string customerRegion;

var serverRegion = _globalSettings.BaseServiceUri.CloudRegion;
var eventType = parsedEvent.Type;

switch (eventType)
{
case HandledStripeWebhook.SubscriptionDeleted:
case HandledStripeWebhook.SubscriptionUpdated:
{
var subscription = await GetSubscriptionAsync(parsedEvent, true, new List<string> { "customer" });
customerRegion = GetCustomerRegionFromMetadata(subscription.Customer.Metadata);
break;
}
case HandledStripeWebhook.ChargeSucceeded:
case HandledStripeWebhook.ChargeRefunded:
{
var charge = await GetChargeAsync(parsedEvent, true, new List<string> { "customer" });
customerRegion = GetCustomerRegionFromMetadata(charge.Customer.Metadata);
break;
}
case HandledStripeWebhook.UpcomingInvoice:
case HandledStripeWebhook.PaymentSucceeded:
case HandledStripeWebhook.PaymentFailed:
case HandledStripeWebhook.InvoiceCreated:
{
var invoice = await GetInvoiceAsync(parsedEvent, true, new List<string> { "customer" });
customerRegion = GetCustomerRegionFromMetadata(invoice.Customer.Metadata);
break;
}
default:
{
// For all Stripe events that we're not listening to, just return 200
return false;
}
}

return customerRegion == serverRegion;
}

/// <summary>
/// Gets the region from the customer metadata. If no region is present, defaults to "US"
/// </summary>
/// <param name="customerMetadata"></param>
/// <returns></returns>
private static string GetCustomerRegionFromMetadata(Dictionary<string, string> customerMetadata)
{
return customerMetadata.TryGetValue("region", out var value)
? value
: "US";
}

private Tuple<Guid?, Guid?> GetIdsFromMetaData(IDictionary<string, string> metaData)
{
if (metaData == null || !metaData.Any())
Expand Down Expand Up @@ -732,7 +803,7 @@ private bool UnpaidAutoChargeInvoiceForSubscriptionCycle(Invoice invoice)
invoice.BillingReason == "subscription_cycle" && invoice.SubscriptionId != null;
}

private async Task<Charge> GetChargeAsync(Stripe.Event parsedEvent, bool fresh = false)
private async Task<Charge> GetChargeAsync(Event parsedEvent, bool fresh = false, List<string> expandOptions = null)
{
if (!(parsedEvent.Data.Object is Charge eventCharge))
{
Expand All @@ -743,15 +814,16 @@ private async Task<Charge> GetChargeAsync(Stripe.Event parsedEvent, bool fresh =
return eventCharge;
}
var chargeService = new ChargeService();
var charge = await chargeService.GetAsync(eventCharge.Id);
var chargeGetOptions = new ChargeGetOptions { Expand = expandOptions };
var charge = await chargeService.GetAsync(eventCharge.Id, chargeGetOptions);
if (charge == null)
{
throw new Exception("Charge is null. " + eventCharge.Id);
}
return charge;
}

private async Task<Invoice> GetInvoiceAsync(Stripe.Event parsedEvent, bool fresh = false)
private async Task<Invoice> GetInvoiceAsync(Stripe.Event parsedEvent, bool fresh = false, List<string> expandOptions = null)
{
if (!(parsedEvent.Data.Object is Invoice eventInvoice))
{
Expand All @@ -762,17 +834,19 @@ private async Task<Invoice> GetInvoiceAsync(Stripe.Event parsedEvent, bool fresh
return eventInvoice;
}
var invoiceService = new InvoiceService();
var invoice = await invoiceService.GetAsync(eventInvoice.Id);
var invoiceGetOptions = new InvoiceGetOptions { Expand = expandOptions };
var invoice = await invoiceService.GetAsync(eventInvoice.Id, invoiceGetOptions);
if (invoice == null)
{
throw new Exception("Invoice is null. " + eventInvoice.Id);
}
return invoice;
}

private async Task<Subscription> GetSubscriptionAsync(Stripe.Event parsedEvent, bool fresh = false)
private async Task<Subscription> GetSubscriptionAsync(Stripe.Event parsedEvent, bool fresh = false,
List<string> expandOptions = null)
{
if (!(parsedEvent.Data.Object is Subscription eventSubscription))
if (parsedEvent.Data.Object is not Subscription eventSubscription)
{
throw new Exception("Subscription is null (from parsed event). " + parsedEvent.Id);
}
Expand All @@ -781,7 +855,8 @@ private async Task<Subscription> GetSubscriptionAsync(Stripe.Event parsedEvent,
return eventSubscription;
}
var subscriptionService = new SubscriptionService();
var subscription = await subscriptionService.GetAsync(eventSubscription.Id);
var subscriptionGetOptions = new SubscriptionGetOptions { Expand = expandOptions };
var subscription = await subscriptionService.GetAsync(eventSubscription.Id, subscriptionGetOptions);
if (subscription == null)
{
throw new Exception("Subscription is null. " + eventSubscription.Id);
Expand Down
22 changes: 16 additions & 6 deletions src/Core/Services/Implementations/StripePaymentService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using Bit.Core.Exceptions;
using Bit.Core.Models.Business;
using Bit.Core.Repositories;
using Bit.Core.Settings;
using Microsoft.Extensions.Logging;
using StaticStore = Bit.Core.Models.StaticStore;
using TaxRate = Bit.Core.Entities.TaxRate;
Expand All @@ -25,6 +26,7 @@ public class StripePaymentService : IPaymentService
private readonly Braintree.IBraintreeGateway _btGateway;
private readonly ITaxRateRepository _taxRateRepository;
private readonly IStripeAdapter _stripeAdapter;
private readonly IGlobalSettings _globalSettings;

public StripePaymentService(
ITransactionRepository transactionRepository,
Expand All @@ -33,7 +35,8 @@ public StripePaymentService(
ILogger<StripePaymentService> logger,
ITaxRateRepository taxRateRepository,
IStripeAdapter stripeAdapter,
Braintree.IBraintreeGateway braintreeGateway)
Braintree.IBraintreeGateway braintreeGateway,
IGlobalSettings globalSettings)
{
_transactionRepository = transactionRepository;
_userRepository = userRepository;
Expand All @@ -42,6 +45,7 @@ public StripePaymentService(
_taxRateRepository = taxRateRepository;
_stripeAdapter = stripeAdapter;
_btGateway = braintreeGateway;
_globalSettings = globalSettings;
}

public async Task<string> PurchaseOrganizationAsync(Organization org, PaymentMethodType paymentMethodType,
Expand All @@ -52,9 +56,12 @@ public async Task<string> PurchaseOrganizationAsync(Organization org, PaymentMet
Braintree.Customer braintreeCustomer = null;
string stipeCustomerSourceToken = null;
string stipeCustomerPaymentMethodId = null;
var stripeCustomerMetadata = new Dictionary<string, string>();
var stripeCustomerMetadata = new Dictionary<string, string>
{
{ "region", _globalSettings.BaseServiceUri.CloudRegion }
};
var stripePaymentMethod = paymentMethodType == PaymentMethodType.Card ||
paymentMethodType == PaymentMethodType.BankAccount;
paymentMethodType == PaymentMethodType.BankAccount;

if (stripePaymentMethod && !string.IsNullOrWhiteSpace(paymentToken))
{
Expand Down Expand Up @@ -391,7 +398,7 @@ public async Task<string> PurchasePremiumAsync(User user, PaymentMethodType paym

if (customer == null && !string.IsNullOrWhiteSpace(paymentToken))
{
var stripeCustomerMetadata = new Dictionary<string, string>();
var stripeCustomerMetadata = new Dictionary<string, string> { { "region", _globalSettings.BaseServiceUri.CloudRegion } };
if (paymentMethodType == PaymentMethodType.PayPal)
{
var randomSuffix = Utilities.CoreHelpers.RandomString(3, upper: false, numeric: false);
Expand Down Expand Up @@ -1193,9 +1200,12 @@ public async Task<bool> UpdatePaymentMethodAsync(ISubscriber subscriber, Payment
Braintree.Customer braintreeCustomer = null;
string stipeCustomerSourceToken = null;
string stipeCustomerPaymentMethodId = null;
var stripeCustomerMetadata = new Dictionary<string, string>();
var stripeCustomerMetadata = new Dictionary<string, string>
{
{ "region", _globalSettings.BaseServiceUri.CloudRegion }
};
var stripePaymentMethod = paymentMethodType == PaymentMethodType.Card ||
paymentMethodType == PaymentMethodType.BankAccount;
paymentMethodType == PaymentMethodType.BankAccount;
var inAppPurchase = paymentMethodType == PaymentMethodType.AppleInApp ||
paymentMethodType == PaymentMethodType.GoogleInApp;

Expand Down
Loading

0 comments on commit 5b215cd

Please sign in to comment.