Skip to content

Commit

Permalink
Creating Decimal type to use for amounts
Browse files Browse the repository at this point in the history
Rather than depending on float values which can lose precision.

Created methods so Decimals can be read and written as XML (and possibly
other formats that use Un/MarshalText).
  • Loading branch information
jszwedko committed Feb 6, 2015
1 parent 481af9e commit e61c1ef
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 21 deletions.
60 changes: 60 additions & 0 deletions decimal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package braintree

import (
"bytes"
"strconv"
"strings"
)

const precision = 16

// Decimal represents fixed precision numbers
type Decimal struct {
Unscaled int64
Scale int
}

// NewDecimal creates a new decimal number equal to
// unscaled ** 10 ^ (-scale)
func NewDecimal(unscaled int64, scale int) *Decimal {
return &Decimal{Unscaled: unscaled, Scale: scale}
}

// MarshalText outputs a decimal representation of the scaled number
func (d *Decimal) MarshalText() (text []byte, err error) {
b := new(bytes.Buffer)
if d.Scale <= 0 {
b.WriteString(strconv.FormatInt(d.Unscaled, 10))
b.WriteString(strings.Repeat("0", -d.Scale))
} else {
str := strconv.FormatInt(d.Unscaled, 10)
b.WriteString(str[:len(str)-d.Scale])
b.WriteString(".")
b.WriteString(str[len(str)-d.Scale:])
}
return b.Bytes(), nil
}

// UnmarshalText creates a Decimal from a string representation (e.g. 5.20)
// Currently only supports decimal strings
func (d *Decimal) UnmarshalText(text []byte) (err error) {
var (
str = string(text)
unscaled int64 = 0
scale int = 0
)

if i := strings.Index(str, "."); i != -1 {
scale = len(str) - i - 1
str = strings.Replace(str, ".", "", 1)
}

if unscaled, err = strconv.ParseInt(str, 10, 64); err != nil {
return err
}

d.Unscaled = unscaled
d.Scale = scale

return nil
}
65 changes: 65 additions & 0 deletions decimal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package braintree

import (
"reflect"
"testing"
)

func TestDecimalUnmarshalText(t *testing.T) {
tests := []struct {
in []byte
out *Decimal
shouldError bool
}{
{[]byte("2.50"), NewDecimal(250, 2), false},
{[]byte("2"), NewDecimal(2, 0), false},
{[]byte("-5.504"), NewDecimal(-5504, 3), false},
{[]byte("0.5"), NewDecimal(5, 1), false},
{[]byte(".5"), NewDecimal(5, 1), false},
{[]byte("5.504.98"), NewDecimal(0, 0), true},
{[]byte("5E6"), NewDecimal(0, 0), true},
}

for _, tt := range tests {
d := &Decimal{}
err := d.UnmarshalText(tt.in)

if tt.shouldError {
if err == nil {
t.Errorf("expected UnmarshalText(%s) => to error, but it did not", tt.in)
}
} else {
if err != nil {
t.Errorf("expected UnmarshalText(%s) => to not error, but it did with %s", tt.in, err)
}
}

if !reflect.DeepEqual(d, tt.out) {
t.Errorf("UnmarshalText(%s) => %+v, want %+v", tt.in, d, tt.out)
}
}
}

func TestDecimalMarshalText(t *testing.T) {
tests := []struct {
in *Decimal
out []byte
}{
{NewDecimal(250, -2), []byte("25000")},
{NewDecimal(2, 0), []byte("2")},
{NewDecimal(250, 2), []byte("2.50")},
{NewDecimal(4586, 2), []byte("45.86")},
{NewDecimal(-5504, 2), []byte("-55.04")},
}

for _, tt := range tests {
b, err := tt.in.MarshalText()
if err != nil {
t.Errorf("expected %+v.MarshaText() => to not error, but it did with %s", tt.in, err)
}

if string(tt.out) != string(b) {
t.Errorf("%+v.MarshaText() => %s, want %s", tt.in, b, tt.out)
}
}
}
4 changes: 2 additions & 2 deletions merchant_account_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,12 @@ func TestMerchantAccountTransaction(t *testing.T) {

tx, err := testGateway.Transaction().Create(&Transaction{
Type: "sale",
Amount: 100.00 + offset(),
Amount: randomAmount(),
CreditCard: &CreditCard{
Number: testCreditCards["visa"].Number,
ExpirationDate: "05/14",
},
ServiceFeeAmount: 5.00,
ServiceFeeAmount: NewDecimal(500, 2),
MerchantAccountId: strconv.Itoa(acctId),
})

Expand Down
4 changes: 2 additions & 2 deletions transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ type Transaction struct {
CustomerID string `xml:"customer-id,omitempty"`
Status string `xml:"status,omitempty"`
Type string `xml:"type,omitempty"`
Amount float64 `xml:"amount"`
Amount *Decimal `xml:"amount"`
OrderId string `xml:"order-id,omitempty"`
PaymentMethodToken string `xml:"payment-method-token,omitempty"`
MerchantAccountId string `xml:"merchant-account-id,omitempty"`
Expand All @@ -16,7 +16,7 @@ type Transaction struct {
BillingAddress *Address `xml:"billing,omitempty"`
ShippingAddress *Address `xml:"shipping,omitempty"`
Options *TransactionOptions `xml:"options,omitempty"`
ServiceFeeAmount float64 `xml:"service-fee-amount,attr,omitempty"`
ServiceFeeAmount *Decimal `xml:"service-fee-amount,attr,omitempty"`
CreatedAt string `xml:"created-at,omitempty"`
UpdatedAt string `xml:"updated-at,omitempty"`
DisbursementDetails *DisbursementDetails `xml:"disbursement-details,omitempty"`
Expand Down
2 changes: 1 addition & 1 deletion transaction_gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func (g *TransactionGateway) Create(tx *Transaction) (*Transaction, error) {

// SubmitForSettlement submits the transaction with the specified id for settlement.
// If the amount is omitted, the full amount is settled.
func (g *TransactionGateway) SubmitForSettlement(id string, amount ...float64) (*Transaction, error) {
func (g *TransactionGateway) SubmitForSettlement(id string, amount ...*Decimal) (*Transaction, error) {
var tx *Transaction
if len(amount) > 0 {
tx = &Transaction{
Expand Down
34 changes: 18 additions & 16 deletions transaction_integration_test.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
package braintree

import (
"math"
"math/rand"
"reflect"
"strconv"
"testing"
"time"
)

func offset() float64 {
return math.Ceil(rand.Float64() * 100.0)
func randomAmount() *Decimal {
return NewDecimal(rand.Int63n(10000), 2)
}

func TestTransactionCreateSubmitForSettlementAndVoid(t *testing.T) {
tx, err := testGateway.Transaction().Create(&Transaction{
Type: "sale",
Amount: 130.00 + offset(),
Amount: NewDecimal(2000, 2),
CreditCard: &CreditCard{
Number: testCreditCards["visa"].Number,
ExpirationDate: "05/14",
Expand All @@ -35,7 +35,8 @@ func TestTransactionCreateSubmitForSettlementAndVoid(t *testing.T) {
}

// Submit for settlement
tx2, err := testGateway.Transaction().SubmitForSettlement(tx.Id, 10)
ten := NewDecimal(1000, 2)
tx2, err := testGateway.Transaction().SubmitForSettlement(tx.Id, ten)

t.Log(tx2)

Expand All @@ -45,8 +46,8 @@ func TestTransactionCreateSubmitForSettlementAndVoid(t *testing.T) {
if x := tx2.Status; x != "submitted_for_settlement" {
t.Fatal(x)
}
if x := tx2.Amount; x != 10 {
t.Fatal(x)
if amount := tx2.Amount; !reflect.DeepEqual(amount, ten) {
t.Fatalf("transaction settlement amount (%s) did not equal amount requested (%s)", amount, ten)
}

// Void
Expand All @@ -64,7 +65,7 @@ func TestTransactionCreateSubmitForSettlementAndVoid(t *testing.T) {

func TestTransactionSearch(t *testing.T) {
txg := testGateway.Transaction()
createTx := func(amount float64, customerName string) error {
createTx := func(amount *Decimal, customerName string) error {
_, err := txg.Create(&Transaction{
Type: "sale",
Amount: amount,
Expand All @@ -82,10 +83,11 @@ func TestTransactionSearch(t *testing.T) {
ts := strconv.FormatInt(time.Now().Unix(), 10)
name := "Erik-" + ts

if err := createTx(100.0+offset(), name); err != nil {
if err := createTx(randomAmount(), name); err != nil {
t.Fatal(err)
}
if err := createTx(150.0+offset(), "Lionel-"+ts); err != nil {

if err := createTx(randomAmount(), "Lionel-"+ts); err != nil {
t.Fatal(err)
}

Expand Down Expand Up @@ -113,7 +115,7 @@ func TestTransactionSearch(t *testing.T) {
func TestTransactionCreateWhenGatewayRejected(t *testing.T) {
_, err := testGateway.Transaction().Create(&Transaction{
Type: "sale",
Amount: 2010.00,
Amount: NewDecimal(201000, 2),
CreditCard: &CreditCard{
Number: testCreditCards["visa"].Number,
ExpirationDate: "05/14",
Expand All @@ -130,7 +132,7 @@ func TestTransactionCreateWhenGatewayRejected(t *testing.T) {
func TestFindTransaction(t *testing.T) {
createdTransaction, err := testGateway.Transaction().Create(&Transaction{
Type: "sale",
Amount: 110.00 + offset(),
Amount: randomAmount(),
CreditCard: &CreditCard{
Number: testCreditCards["mastercard"].Number,
ExpirationDate: "05/14",
Expand Down Expand Up @@ -163,7 +165,7 @@ func TestFindNonExistantTransaction(t *testing.T) {
func TestAllTransactionFields(t *testing.T) {
tx := &Transaction{
Type: "sale",
Amount: 100.00 + offset(),
Amount: randomAmount(),
OrderId: "my_custom_order",
CreditCard: &CreditCard{
Number: testCreditCards["visa"].Number,
Expand Down Expand Up @@ -201,7 +203,7 @@ func TestAllTransactionFields(t *testing.T) {
if tx2.Type != tx.Type {
t.Fail()
}
if tx2.Amount != tx.Amount {
if !reflect.DeepEqual(tx2.Amount, tx.Amount) {
t.Fail()
}
if tx2.OrderId != tx.OrderId {
Expand Down Expand Up @@ -289,7 +291,7 @@ func TestTransactionCreateFromPaymentMethodCode(t *testing.T) {
tx, err := testGateway.Transaction().Create(&Transaction{
Type: "sale",
CustomerID: customer.Id,
Amount: 120 + offset(),
Amount: randomAmount(),
PaymentMethodToken: customer.CreditCards.CreditCard[0].Token,
})

Expand All @@ -306,7 +308,7 @@ func TestSettleTransaction(t *testing.T) {

txn, err := testGateway.Transaction().Create(&Transaction{
Type: "sale",
Amount: 130.00 + offset(),
Amount: randomAmount(),
CreditCard: &CreditCard{
Number: testCreditCards["visa"].Number,
ExpirationDate: "05/14",
Expand Down

0 comments on commit e61c1ef

Please sign in to comment.