Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions openmeter/customer/adapter/customer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package adapter
import (
"context"
"fmt"
"slices"
"time"

"entgo.io/ent/dialect/sql"
Expand Down Expand Up @@ -36,7 +37,9 @@ func (a *adapter) ListCustomers(ctx context.Context, input customer.ListCustomer

query := repo.db.Customer.Query().Where(customerdb.Namespace(input.Namespace))
query = WithSubjects(query, now)
query = WithActiveSubscriptions(query, now)
if slices.Contains(input.Expands, customer.ExpandSubscriptions) {
query = WithActiveSubscriptions(query, now)
}

// Do not return deleted customers by default
if !input.IncludeDeleted {
Expand Down Expand Up @@ -110,7 +113,7 @@ func (a *adapter) ListCustomers(ctx context.Context, input customer.ListCustomer
a.logger.WarnContext(ctx, "invalid query result: nil customer received")
continue
}
cust, err := CustomerFromDBEntity(*item)
cust, err := CustomerFromDBEntity(*item, input.Expands)
if err != nil {
return response, fmt.Errorf("failed to convert customer: %w", err)
}
Expand Down Expand Up @@ -388,7 +391,9 @@ func (a *adapter) GetCustomer(ctx context.Context, input customer.GetCustomerInp

query := repo.db.Customer.Query()
query = WithSubjects(query, now)
query = WithActiveSubscriptions(query, now)
if slices.Contains(input.Expands, customer.ExpandSubscriptions) {
query = WithActiveSubscriptions(query, now)
}

if input.CustomerID != nil {
query = query.Where(customerdb.Namespace(input.CustomerID.Namespace))
Expand Down Expand Up @@ -438,7 +443,7 @@ func (a *adapter) GetCustomer(ctx context.Context, input customer.GetCustomerInp
return nil, fmt.Errorf("invalid query result: nil customer received")
}

return CustomerFromDBEntity(*entity)
return CustomerFromDBEntity(*entity, input.Expands)
})
}

Expand All @@ -464,7 +469,9 @@ func (a *adapter) GetCustomerByUsageAttribution(ctx context.Context, input custo
)).
Where(customerdb.DeletedAtIsNil())
query = WithSubjects(query, now)
query = WithActiveSubscriptions(query, now)
if slices.Contains(input.Expands, customer.ExpandSubscriptions) {
query = WithActiveSubscriptions(query, now)
}

customerEntity, err := query.First(ctx)
if err != nil {
Expand All @@ -481,7 +488,7 @@ func (a *adapter) GetCustomerByUsageAttribution(ctx context.Context, input custo
return nil, fmt.Errorf("invalid query result: nil customer received")
}

return CustomerFromDBEntity(*customerEntity)
return CustomerFromDBEntity(*customerEntity, input.Expands)
})
}

Expand Down
50 changes: 31 additions & 19 deletions openmeter/customer/adapter/entitymapping.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,35 +5,19 @@ import (
"slices"

"github.com/samber/lo"
"github.com/samber/mo"

"github.com/openmeterio/openmeter/openmeter/customer"
"github.com/openmeterio/openmeter/openmeter/ent/db"
"github.com/openmeterio/openmeter/pkg/models"
)

func CustomerFromDBEntity(e db.Customer) (*customer.Customer, error) {
func CustomerFromDBEntity(e db.Customer, expands customer.Expands) (*customer.Customer, error) {
subjectKeys, err := subjectKeysFromDBEntity(e)
if err != nil {
return nil, err
}

subscriptions, err := e.Edges.SubscriptionOrErr()
if err != nil {
if db.IsNotLoaded(err) {
return nil, errors.New("subscriptions must be loaded for customer")
}

return nil, err
}

subscriptionIDs := lo.FilterMap(subscriptions, func(item *db.Subscription, _ int) (string, bool) {
if item == nil {
return "", false
}

return item.ID, true
})

var metadata *models.Metadata

if len(e.Metadata) > 0 {
Expand Down Expand Up @@ -63,8 +47,15 @@ func CustomerFromDBEntity(e db.Customer) (*customer.Customer, error) {
Currency: e.Currency,
Metadata: metadata,
Annotation: annotations,
}

if slices.Contains(expands, customer.ExpandSubscriptions) {
activeSubscriptionIDs, err := resolveActiveSubscriptionIDs(e)
if err != nil {
return nil, err
}

ActiveSubscriptionIDs: subscriptionIDs,
result.ActiveSubscriptionIDs = mo.Some(activeSubscriptionIDs)
}

if e.Key != "" {
Expand All @@ -86,6 +77,27 @@ func CustomerFromDBEntity(e db.Customer) (*customer.Customer, error) {
return result, nil
}

func resolveActiveSubscriptionIDs(e db.Customer) ([]string, error) {
subscriptions, err := e.Edges.SubscriptionOrErr()
if err != nil {
if db.IsNotLoaded(err) {
return nil, errors.New("subscriptions must be loaded for customer")
}

return nil, err
}

subscriptionIDs := lo.FilterMap(subscriptions, func(item *db.Subscription, _ int) (string, bool) {
if item == nil {
return "", false
}

return item.ID, true
})

return subscriptionIDs, nil
}

func subjectKeysFromDBEntity(customerEntity db.Customer) ([]string, error) {
subjectEntities, err := customerEntity.Edges.SubjectsOrErr()
if err != nil {
Expand Down
43 changes: 41 additions & 2 deletions openmeter/customer/customer.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"errors"
"fmt"

"github.com/samber/mo"

"github.com/openmeterio/openmeter/api"
"github.com/openmeterio/openmeter/openmeter/streaming"
"github.com/openmeterio/openmeter/pkg/clock"
Expand All @@ -13,6 +15,25 @@ import (
"github.com/openmeterio/openmeter/pkg/sortx"
)

type (
Expand string
Expands []Expand
)

const (
ExpandSubscriptions Expand = "subscriptions"
)

func (e Expands) Validate() error {
for _, expand := range e {
if expand != ExpandSubscriptions {
return models.NewGenericValidationError(fmt.Errorf("invalid expand: %s", expand))
}
}

return nil
}

var _ streaming.Customer = &Customer{}

// Customer represents a customer
Expand All @@ -27,7 +48,7 @@ type Customer struct {
Metadata *models.Metadata `json:"metadata,omitempty"`
Annotation *models.Annotations `json:"annotations,omitempty"`

ActiveSubscriptionIDs []string
ActiveSubscriptionIDs mo.Option[[]string]
}

// GetUsageAttribution returns the customer usage attribution
Expand Down Expand Up @@ -193,6 +214,9 @@ func (c CustomerUsageAttribution) GetSubjectKey() (string, error) {
type GetCustomerByUsageAttributionInput struct {
Namespace string
SubjectKey string

// Expand
Expands Expands
}

func (i GetCustomerByUsageAttributionInput) Validate() error {
Expand All @@ -204,6 +228,10 @@ func (i GetCustomerByUsageAttributionInput) Validate() error {
return models.NewGenericValidationError(errors.New("subject key is required"))
}

if err := i.Expands.Validate(); err != nil {
return models.NewGenericValidationError(err)
}

return nil
}

Expand All @@ -225,13 +253,20 @@ type ListCustomersInput struct {
Subject *string
PlanKey *string
CustomerIDs []string

// Expand
Expands Expands
}

func (i ListCustomersInput) Validate() error {
if i.Namespace == "" {
return models.NewGenericValidationError(errors.New("namespace is required"))
}

if err := i.Expands.Validate(); err != nil {
return models.NewGenericValidationError(err)
}

return nil
}

Expand Down Expand Up @@ -300,7 +335,7 @@ type GetCustomerInput struct {
CustomerIDOrKey *CustomerIDOrKey

// Expand
Expand []api.CustomerExpand
Expands Expands
}

func (i GetCustomerInput) Validate() error {
Expand Down Expand Up @@ -337,6 +372,10 @@ func (i GetCustomerInput) Validate() error {
errs = append(errs, i.CustomerIDOrKey.Validate())
}

if err := i.Expands.Validate(); err != nil {
return models.NewGenericValidationError(err)
}

return errors.Join(errs...)
}

Expand Down
4 changes: 2 additions & 2 deletions openmeter/customer/httpdriver/apimapping.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func FromAnnotations(annotations models.Annotations) *api.Annotations {
}

// CustomerToAPI converts a Customer to an API Customer
func CustomerToAPI(c customer.Customer, subscriptions []subscription.Subscription, expand []api.CustomerExpand) (api.Customer, error) {
func CustomerToAPI(c customer.Customer, subscriptions []subscription.Subscription, expand customer.Expands) (api.Customer, error) {
// Map the customer to the API Customer
apiCustomer := api.Customer{
Id: c.ManagedResource.ID,
Expand Down Expand Up @@ -154,7 +154,7 @@ func CustomerToAPI(c customer.Customer, subscriptions []subscription.Subscriptio
apiCustomer.CurrentSubscriptionId = lo.ToPtr(subscriptions[0].ID)

// Map the subscriptions to the API Subscriptions if the expand is set
if lo.Contains(expand, api.CustomerExpandSubscriptions) {
if lo.Contains(expand, customer.ExpandSubscriptions) {
apiCustomer.Subscriptions = lo.ToPtr(lo.Map(subscriptions, func(s subscription.Subscription, _ int) api.Subscription {
return subscriptionhttp.MapSubscriptionToAPI(s)
}))
Expand Down
Loading
Loading