Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions openmeter/billing/service/gatheringinvoice.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@ package billingservice

import (
"context"
"fmt"

"github.com/samber/lo"

"github.com/openmeterio/openmeter/openmeter/billing"
"github.com/openmeterio/openmeter/openmeter/customer"
"github.com/openmeterio/openmeter/pkg/framework/transaction"
"github.com/openmeterio/openmeter/pkg/pagination"
)
Expand All @@ -19,3 +23,65 @@ func (s *Service) ListGatheringInvoices(ctx context.Context, input billing.ListG
return s.adapter.ListGatheringInvoices(ctx, input)
})
}

func (s *Service) emulateStandardInvoicesGatheringInvoiceFields(ctx context.Context, invoices []billing.StandardInvoice) ([]billing.StandardInvoice, error) {
mergedProfiles := make(map[customer.CustomerID]billing.CustomerOverrideWithDetails)

for idx := range invoices {
invoice := &invoices[idx]

if invoice.Status != billing.StandardInvoiceStatusGathering {
continue
}

if _, ok := mergedProfiles[invoice.CustomerID()]; !ok {
expand := billing.CustomerOverrideExpand{
Customer: true,
Apps: true,
}

mergedProfile, err := s.GetCustomerOverride(ctx, billing.GetCustomerOverrideInput{
Customer: invoice.CustomerID(),
Expand: expand,
})
if err != nil {
return nil, err
}

mergedProfiles[invoice.CustomerID()] = mergedProfile
}

mergedProfile := mergedProfiles[invoice.CustomerID()]

invoice.Customer = billing.InvoiceCustomer{
CustomerID: invoice.CustomerID().ID,
Name: mergedProfile.Customer.Name,
Key: mergedProfile.Customer.Key,
UsageAttribution: lo.ToPtr(mergedProfile.Customer.GetUsageAttribution()),
}

invoice.Supplier = mergedProfile.MergedProfile.Supplier

invoice.Workflow = billing.InvoiceWorkflow{
AppReferences: lo.FromPtr(mergedProfile.MergedProfile.AppReferences),
Apps: mergedProfile.MergedProfile.Apps,
SourceBillingProfileID: mergedProfile.MergedProfile.ID,
Config: mergedProfile.MergedProfile.WorkflowConfig,
}
}

return invoices, nil
}

func (s *Service) emulateStandardInvoiceGatheringInvoiceFields(ctx context.Context, invoice billing.StandardInvoice) (billing.StandardInvoice, error) {
invoices, err := s.emulateStandardInvoicesGatheringInvoiceFields(ctx, []billing.StandardInvoice{invoice})
if err != nil {
return billing.StandardInvoice{}, err
}

if len(invoices) != 1 {
return billing.StandardInvoice{}, fmt.Errorf("expected 1 invoice, got %d", len(invoices))
}

return invoices[0], nil
}
14 changes: 14 additions & 0 deletions openmeter/billing/service/invoice.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ func (s *Service) ListInvoices(ctx context.Context, input billing.ListInvoicesIn
return billing.ListInvoicesResponse{}, err
}

updatedInvoices, err := s.emulateStandardInvoicesGatheringInvoiceFields(ctx, invoices.Items)
if err != nil {
return billing.ListInvoicesResponse{}, fmt.Errorf("error emulating standard invoices gathering invoice fields: %w", err)
}

invoices.Items = updatedInvoices

for i := range invoices.Items {
invoiceID := invoices.Items[i].ID

Expand Down Expand Up @@ -233,6 +240,13 @@ func (s *Service) GetInvoiceByID(ctx context.Context, input billing.GetInvoiceBy
return billing.StandardInvoice{}, err
}

if invoice.Status == billing.StandardInvoiceStatusGathering {
invoice, err = s.emulateStandardInvoiceGatheringInvoiceFields(ctx, invoice)
if err != nil {
return billing.StandardInvoice{}, fmt.Errorf("error emulating standard invoice gathering invoice fields for invoice [%s]: %w", invoiceID, err)
}
}

invoice, err = s.resolveWorkflowApps(ctx, invoice)
if err != nil {
return billing.StandardInvoice{}, fmt.Errorf("error resolving workload apps for invoice [%s]: %w", invoiceID, err)
Expand Down
73 changes: 73 additions & 0 deletions test/billing/invoice_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4251,3 +4251,76 @@ func (s *InvoicingTestSuite) TestSnapshotQuantityInvalidDatabaseState() {
s.NotEmpty(invoice.ValidationIssues)
})
}

func (s *InvoicingTestSuite) TestGatheringInvoiceEmulation() {
namespace := "ns-gathering-invoice-emulation"
now := lo.Must(time.Parse(time.RFC3339, "2024-09-02T11:13:14Z"))
periodStart := lo.Must(time.Parse(time.RFC3339, "2024-09-02T11:13:14Z"))
periodEnd := lo.Must(time.Parse(time.RFC3339, "2024-09-02T13:13:14Z"))

clock.SetTime(now)
defer clock.ResetTime()

sandboxApp := s.InstallSandboxApp(s.T(), namespace)

ctx := context.Background()

// Given we provision a new gathering invoice
// When we fetch the invoice using the standard invoice path
// We get the current supplier and customer data

customerEntity, err := s.CustomerService.CreateCustomer(ctx, customer.CreateCustomerInput{
Namespace: namespace,

CustomerMutate: customer.CustomerMutate{
Name: "Test Customer",
Key: lo.ToPtr("test-customer"),
PrimaryEmail: lo.ToPtr("test@test.com"),
BillingAddress: &models.Address{
Country: lo.ToPtr(models.CountryCode("US")),
},
Currency: lo.ToPtr(currencyx.Code(currency.USD)),
},
})
require.NoError(s.T(), err)
require.NotNil(s.T(), customerEntity)
require.NotEmpty(s.T(), customerEntity.ID)

// Given we have a default profile for the namespace
profile := s.ProvisionBillingProfile(ctx, namespace, sandboxApp.GetID())

res, err := s.BillingService.CreatePendingInvoiceLines(ctx,
billing.CreatePendingInvoiceLinesInput{
Customer: customerEntity.GetID(),
Currency: currencyx.Code(currency.USD),
Lines: []billing.GatheringLine{
billing.NewFlatFeeGatheringLine(billing.NewFlatFeeLineInput{
Namespace: namespace,
Period: billing.Period{Start: periodStart, End: periodEnd},
InvoiceAt: now,
ManagedBy: billing.ManuallyManagedLine,
Name: "Test item1",
PerUnitAmount: alpacadecimal.NewFromFloat(100),
PaymentTerm: productcatalog.InAdvancePaymentTerm,
}),
},
})

// Then we should have the items created
require.NoError(s.T(), err)
require.Len(s.T(), res.Lines, 1)

gatheringInvoiceID := res.Invoice.InvoiceID()
require.NotEmpty(s.T(), gatheringInvoiceID)

// Let's get the invoice using the standard invoice path
invoice, err := s.BillingService.GetInvoiceByID(ctx, billing.GetInvoiceByIdInput{
Invoice: gatheringInvoiceID,
})
require.NoError(s.T(), err)
require.NotNil(s.T(), invoice)
require.Equal(s.T(), customerEntity.ID, invoice.Customer.CustomerID)
require.Equal(s.T(), customerEntity.Name, invoice.Customer.Name)
require.Equal(s.T(), profile.Supplier.Name, invoice.Supplier.Name)
require.Equal(s.T(), sandboxApp.GetID(), invoice.Workflow.Apps.Invoicing.GetID())
}
Loading