Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions openmeter/app/sandbox/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type AppFactory interface {
NewApp(ctx context.Context, appBase app.AppBase) (app.App, error)
}

type InvoiceUpsertCallback func(billing.Invoice) *billing.UpsertInvoiceResult
type InvoiceUpsertCallback func(billing.Invoice) (*billing.UpsertInvoiceResult, error)

type MockApp struct {
validateCustomerResponse mo.Option[error]
Expand Down Expand Up @@ -76,7 +76,7 @@ func (m *MockApp) UpsertInvoice(ctx context.Context, invoice billing.Invoice) (*
m.upsertInvoiceCalled = true

if m.upsertInvoiceCallback.IsPresent() && m.upsertInvoiceCallback.MustGet() != nil {
return m.upsertInvoiceCallback.MustGet()(invoice), nil
return m.upsertInvoiceCallback.MustGet()(invoice)
}

return billing.NewUpsertInvoiceResult(), nil
Expand Down
111 changes: 61 additions & 50 deletions openmeter/billing/adapter/invoice.go
Original file line number Diff line number Diff line change
Expand Up @@ -651,62 +651,68 @@ func (a *adapter) GetInvoiceOwnership(ctx context.Context, in billing.GetInvoice
})
}

func (a *adapter) mapInvoiceFromDB(ctx context.Context, invoice *db.BillingInvoice, expand billing.InvoiceExpand) (billing.Invoice, error) {
res := billing.Invoice{
InvoiceBase: billing.InvoiceBase{
ID: invoice.ID,
Namespace: invoice.Namespace,
Metadata: invoice.Metadata,
Currency: invoice.Currency,
Status: invoice.Status,
StatusDetails: invoice.StatusDetailsCache,
Type: invoice.Type,
Number: invoice.Number,
Description: invoice.Description,
DueAt: convert.TimePtrIn(invoice.DueAt, time.UTC),
DraftUntil: convert.TimePtrIn(invoice.DraftUntil, time.UTC),
SentToCustomerAt: convert.TimePtrIn(invoice.SentToCustomerAt, time.UTC),
Supplier: billing.SupplierContact{
Name: invoice.SupplierName,
Address: models.Address{
Country: invoice.SupplierAddressCountry,
PostalCode: invoice.SupplierAddressPostalCode,
City: invoice.SupplierAddressCity,
State: invoice.SupplierAddressState,
Line1: invoice.SupplierAddressLine1,
Line2: invoice.SupplierAddressLine2,
PhoneNumber: invoice.SupplierAddressPhoneNumber,
},
TaxCode: invoice.SupplierTaxCode,
func (a *adapter) mapInvoiceBaseFromDB(ctx context.Context, invoice *db.BillingInvoice) billing.InvoiceBase {
return billing.InvoiceBase{
ID: invoice.ID,
Namespace: invoice.Namespace,
Metadata: invoice.Metadata,
Currency: invoice.Currency,
Status: invoice.Status,
StatusDetails: invoice.StatusDetailsCache,
Type: invoice.Type,
Number: invoice.Number,
Description: invoice.Description,
DueAt: convert.TimePtrIn(invoice.DueAt, time.UTC),
DraftUntil: convert.TimePtrIn(invoice.DraftUntil, time.UTC),
SentToCustomerAt: convert.TimePtrIn(invoice.SentToCustomerAt, time.UTC),
Supplier: billing.SupplierContact{
Name: invoice.SupplierName,
Address: models.Address{
Country: invoice.SupplierAddressCountry,
PostalCode: invoice.SupplierAddressPostalCode,
City: invoice.SupplierAddressCity,
State: invoice.SupplierAddressState,
Line1: invoice.SupplierAddressLine1,
Line2: invoice.SupplierAddressLine2,
PhoneNumber: invoice.SupplierAddressPhoneNumber,
},
TaxCode: invoice.SupplierTaxCode,
},

Customer: billing.InvoiceCustomer{
CustomerID: invoice.CustomerID,
Name: invoice.CustomerName,
BillingAddress: &models.Address{
Country: invoice.CustomerAddressCountry,
PostalCode: invoice.CustomerAddressPostalCode,
City: invoice.CustomerAddressCity,
State: invoice.CustomerAddressState,
Line1: invoice.CustomerAddressLine1,
Line2: invoice.CustomerAddressLine2,
PhoneNumber: invoice.CustomerAddressPhoneNumber,
},
UsageAttribution: invoice.CustomerUsageAttribution.CustomerUsageAttribution,
Customer: billing.InvoiceCustomer{
CustomerID: invoice.CustomerID,
Name: invoice.CustomerName,
BillingAddress: &models.Address{
Country: invoice.CustomerAddressCountry,
PostalCode: invoice.CustomerAddressPostalCode,
City: invoice.CustomerAddressCity,
State: invoice.CustomerAddressState,
Line1: invoice.CustomerAddressLine1,
Line2: invoice.CustomerAddressLine2,
PhoneNumber: invoice.CustomerAddressPhoneNumber,
},
Period: mapPeriodFromDB(invoice.PeriodStart, invoice.PeriodEnd),
IssuedAt: convert.TimePtrIn(invoice.IssuedAt, time.UTC),
CreatedAt: invoice.CreatedAt.In(time.UTC),
UpdatedAt: invoice.UpdatedAt.In(time.UTC),
DeletedAt: convert.TimePtrIn(invoice.DeletedAt, time.UTC),
UsageAttribution: invoice.CustomerUsageAttribution.CustomerUsageAttribution,
},
Period: mapPeriodFromDB(invoice.PeriodStart, invoice.PeriodEnd),
IssuedAt: convert.TimePtrIn(invoice.IssuedAt, time.UTC),
CreatedAt: invoice.CreatedAt.In(time.UTC),
UpdatedAt: invoice.UpdatedAt.In(time.UTC),
DeletedAt: convert.TimePtrIn(invoice.DeletedAt, time.UTC),

CollectionAt: lo.ToPtr(invoice.CollectionAt.In(time.UTC)),
CollectionAt: lo.ToPtr(invoice.CollectionAt.In(time.UTC)),

ExternalIDs: billing.InvoiceExternalIDs{
Invoicing: lo.FromPtrOr(invoice.InvoicingAppExternalID, ""),
Payment: lo.FromPtrOr(invoice.PaymentAppExternalID, ""),
},
ExternalIDs: billing.InvoiceExternalIDs{
Invoicing: lo.FromPtrOr(invoice.InvoicingAppExternalID, ""),
Payment: lo.FromPtrOr(invoice.PaymentAppExternalID, ""),
},
}
}

func (a *adapter) mapInvoiceFromDB(ctx context.Context, invoice *db.BillingInvoice, expand billing.InvoiceExpand) (billing.Invoice, error) {
base := a.mapInvoiceBaseFromDB(ctx, invoice)

res := billing.Invoice{
InvoiceBase: base,

Totals: billing.Totals{
Amount: invoice.Amount,
Expand Down Expand Up @@ -752,6 +758,11 @@ func (a *adapter) mapInvoiceFromDB(ctx context.Context, invoice *db.BillingInvoi
return billing.Invoice{}, err
}

mappedLines, err = a.expandProgressiveLineHierarchy(ctx, invoice.Namespace, mappedLines)
if err != nil {
return billing.Invoice{}, err
}

res.Lines = billing.NewLineChildren(mappedLines)
}

Expand Down
146 changes: 146 additions & 0 deletions openmeter/billing/adapter/invoicelineprogressive.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
package billingadapter

import (
"context"
"fmt"

"github.com/samber/lo"

"github.com/openmeterio/openmeter/openmeter/billing"
"github.com/openmeterio/openmeter/openmeter/ent/db"
"github.com/openmeterio/openmeter/openmeter/ent/db/billinginvoiceline"
"github.com/openmeterio/openmeter/pkg/slicesx"
)

// expandProgressiveLineHierarchy expands the given lines with their progressive line hierarchy
// This is done by fetching all the lines that are children of the given lines parent lines and then building
// the hierarchy.
func (a *adapter) expandProgressiveLineHierarchy(ctx context.Context, namespace string, lines []*billing.Line) ([]*billing.Line, error) {
// Let's collect all the lines with a parent line id set

lineIDsToParentIDs := map[string]string{}

for _, line := range lines {
if line.ParentLineID != nil {
lineIDsToParentIDs[line.ID] = *line.ParentLineID
}
}

if len(lineIDsToParentIDs) == 0 {
return lines, nil
}

inScopeLines, err := a.fetchAllLinesForParentIDs(ctx, namespace, lo.Values(lineIDsToParentIDs))
if err != nil {
return nil, err
}

// let's build the hierarchy objects
hierarchyByParentID, err := a.buildProgressiveLineHierarchy(inScopeLines)
if err != nil {
return nil, err
}

// Let's validate the hierarchy
for parentID, hierarchy := range hierarchyByParentID {
if hierarchy.Root.Line == nil {
return nil, fmt.Errorf("root line for parent line[%s] not found", parentID)
}

for _, child := range hierarchy.Children {
if child.Line == nil {
return nil, fmt.Errorf("child line for parent line[%s] not found", parentID)
}

// This is the only valid state for a child line
if child.Line.Status != billing.InvoiceLineStatusValid {
return nil, fmt.Errorf("child line for parent line[%s] is not valid", parentID)
}
}
}

// let's assign the hierarchy to the already fetched lines
return slicesx.MapWithErr(lines, func(line *billing.Line) (*billing.Line, error) {
if line.ParentLineID == nil {
return line, nil
}

hierarchy, ok := hierarchyByParentID[*line.ParentLineID]
if !ok {
return nil, fmt.Errorf("parent line for line[%s] not found", line.ID)
}

line.ProgressiveLineHierarchy = &hierarchy

return line, nil
})
}

func (a *adapter) fetchAllLinesForParentIDs(ctx context.Context, namespace string, parentIDs []string) ([]billing.InvoiceLineWithInvoiceBase, error) {
query := a.db.BillingInvoiceLine.Query().
Where(
billinginvoiceline.Or(
billinginvoiceline.IDIn(parentIDs...),
billinginvoiceline.ParentLineIDIn(parentIDs...),
),
billinginvoiceline.Namespace(namespace),
).
WithFlatFeeLine().
WithUsageBasedLine().
WithLineDiscounts().
WithBillingInvoice() // TODO[later]: we can consider loading this in a separate query, might be more efficient

dbLines, err := query.All(ctx)
if err != nil {
return nil, err
}

mappedLines, err := slicesx.MapWithErr(dbLines, func(dbLine *db.BillingInvoiceLine) (billing.InvoiceLineWithInvoiceBase, error) {
empty := billing.InvoiceLineWithInvoiceBase{}

line, err := a.mapInvoiceLineWithoutReferences(dbLine)
if err != nil {
return empty, err
}

return billing.InvoiceLineWithInvoiceBase{
Line: &line,
Invoice: a.mapInvoiceBaseFromDB(ctx, dbLine.Edges.BillingInvoice),
}, nil
})
if err != nil {
return nil, err
}

return mappedLines, nil
}

func (a *adapter) buildProgressiveLineHierarchy(inScopeLines []billing.InvoiceLineWithInvoiceBase) (map[string]billing.InvoiceLineProgressiveHierarchy, error) {
hierarchyByParentID := map[string]billing.InvoiceLineProgressiveHierarchy{}

for _, line := range inScopeLines {
if line.Line.ParentLineID == nil {
// We have encountered a parent line

hierarchy, ok := hierarchyByParentID[line.Line.ID]
if ok {
if hierarchy.Root.Line != nil {
return nil, fmt.Errorf("parent line[%s] already exists", line.Line.ID)
}
}

hierarchy.Root = line
hierarchyByParentID[line.Line.ID] = hierarchy
continue
}

// We have encountered a child line
parentID := *line.Line.ParentLineID

hierarchy := hierarchyByParentID[parentID]
hierarchy.Children = append(hierarchy.Children, line)
hierarchyByParentID[parentID] = hierarchy
}

return hierarchyByParentID, nil
}
13 changes: 12 additions & 1 deletion openmeter/billing/adapter/invoicelines.go
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,18 @@ func (a *adapter) AssociateLinesToInvoice(ctx context.Context, input billing.Ass
return nil, fmt.Errorf("not all lines were associated")
}

return tx.fetchLines(ctx, input.Invoice.Namespace, input.LineIDs)
invoiceLines, err := tx.fetchLines(ctx, input.Invoice.Namespace, input.LineIDs)
if err != nil {
return nil, fmt.Errorf("fetching lines: %w", err)
}

// Let's expand the line hierarchy so that we can have a full view of the invoice during the upcoming calculations
invoiceLines, err = a.expandProgressiveLineHierarchy(ctx, input.Invoice.Namespace, invoiceLines)
if err != nil {
return nil, err
}

return invoiceLines, nil
})
}

Expand Down
2 changes: 2 additions & 0 deletions openmeter/billing/invoice.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,8 @@ type InvoiceBase struct {
Workflow InvoiceWorkflow `json:"workflow,omitempty"`

ExternalIDs InvoiceExternalIDs `json:"externalIds,omitempty"`

// TODO[later]: Let's also include the totals here, as that's part of the invoice db table
}

func (i InvoiceBase) Validate() error {
Expand Down
10 changes: 8 additions & 2 deletions openmeter/billing/invoiceline.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,9 @@ type Line struct {
FlatFee *FlatFeeLine `json:"flatFee,omitempty"`
UsageBased *UsageBasedLine `json:"usageBased,omitempty"`

Children LineChildren `json:"children,omitempty"`
ParentLine *Line `json:"parent,omitempty"`
Children LineChildren `json:"children,omitempty"`
ParentLine *Line `json:"parent,omitempty"`
ProgressiveLineHierarchy *InvoiceLineProgressiveHierarchy `json:"progressiveLineHierarchy,omitempty"`

Discounts LineDiscounts `json:"discounts,omitempty"`

Expand All @@ -309,6 +310,7 @@ func (i Line) CloneWithoutDependencies() *Line {
clone.ID = ""
clone.ParentLineID = nil
clone.ParentLine = nil
clone.ProgressiveLineHierarchy = nil

if clone.FlatFee != nil {
clone.FlatFee.ConfigID = ""
Expand Down Expand Up @@ -402,6 +404,10 @@ func (i Line) clone(opts cloneOptions) *Line {
res.Discounts = i.Discounts.Clone()
}

if i.ProgressiveLineHierarchy != nil {
res.ProgressiveLineHierarchy = lo.ToPtr(i.ProgressiveLineHierarchy.Clone())
}

return res
}

Expand Down
Loading
Loading