Skip to content

Commit 614c583

Browse files
committed
convert return values from API to Decimal
1 parent 1e3526d commit 614c583

File tree

2 files changed

+277
-22
lines changed

2 files changed

+277
-22
lines changed

gdax/trader.py

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,10 @@
55
"""
66

77
import copy
8+
from decimal import Decimal
89
import json
910
import logging
1011
import time
11-
import hmac
12-
import hashlib
1312

1413
import asyncio
1514
import aiohttp
@@ -50,7 +49,34 @@ def _auth_headers(self, path, method, body=''):
5049
'CB-ACCESS-PASSPHRASE': self.passphrase,
5150
}
5251

53-
async def _get(self, path, params=None, pagination=False):
52+
def _convert_return_fields(self, fields, decimal_fields, convert_all):
53+
if decimal_fields is None and not convert_all:
54+
return fields
55+
if isinstance(fields, list):
56+
return [self._convert_return_fields(field, decimal_fields,
57+
convert_all)
58+
for field in fields]
59+
elif isinstance(fields, dict):
60+
new_fields = {}
61+
for k, v in fields.items():
62+
if (decimal_fields is not None and k in decimal_fields) \
63+
or convert_all:
64+
if isinstance(v, list):
65+
new_fields[k] = self._convert_return_fields(
66+
v, decimal_fields, convert_all)
67+
else:
68+
new_fields[k] = Decimal(v)
69+
else:
70+
new_fields[k] = v
71+
return new_fields
72+
else:
73+
if convert_all and not isinstance(fields, int):
74+
return Decimal(fields)
75+
else:
76+
return fields
77+
78+
async def _get(self, path, params=None, decimal_return_fields=None,
79+
convert_all=False, pagination=False):
5480
if params is None:
5581
params_copy = {}
5682
else:
@@ -81,9 +107,11 @@ async def _get(self, path, params=None, pagination=False):
81107
if "cb-after" in resp_headers:
82108
params_copy['after'] = resp_headers['cb-after']
83109
else:
84-
return results
110+
return self._convert_return_fields(
111+
results, decimal_return_fields, convert_all)
85112
else:
86-
return res
113+
return self._convert_return_fields(
114+
res, decimal_return_fields, convert_all)
87115

88116
async def _post(self, path, data=None):
89117
json_data = json.dumps(data)
@@ -109,42 +137,57 @@ async def _delete(self, path, data=None):
109137
return await response.json()
110138

111139
async def get_products(self):
112-
return await self._get('/products')
140+
return await self._get(
141+
'/products',
142+
decimal_return_fields={'base_min_size', 'base_max_size',
143+
'quote_increment'})
113144

114145
async def get_product_ticker(self, product_id=None):
115146
return await self._get(
116-
'/products/{}/ticker'.format(product_id or self.product_id))
147+
'/products/{}/ticker'.format(product_id or self.product_id),
148+
decimal_return_fields={'price', 'size', 'bid', 'ask', 'volume'})
117149

118150
async def get_product_trades(self, product_id=None):
119151
return await self._get(
120-
'/products/{}/trades'.format(product_id or self.product_id))
152+
'/products/{}/trades'.format(product_id or self.product_id),
153+
decimal_return_fields={'price', 'size'})
121154

122155
async def get_product_order_book(self, product_id=None, level=1):
123156
params = {'level': level}
124157
return await self._get(
125158
'/products/{}/book'.format(product_id or self.product_id),
126-
params=params)
159+
params=params, decimal_return_fields={'bids', 'asks'},
160+
convert_all=True)
127161

128162
async def get_product_historic_rates(self, product_id=None, start='',
129163
end='', granularity=''):
130164
payload = {}
131165
payload["start"] = start
132166
payload["end"] = end
133167
payload["granularity"] = granularity
134-
return await self._get(
168+
res = await self._get(
135169
'/products/{}/candles'.format(product_id or self.product_id),
136170
params=payload)
171+
# NOTE: there's a bug where the API returns floats instead of strings
172+
# here
173+
for row in res:
174+
for i, col in enumerate(row[1:]):
175+
row[i + 1] = Decimal(str(col))
176+
return res
137177

138178
async def get_product_24hr_stats(self, product_id=None):
139179
return await self._get(
140-
'/products/{}/stats'.format(product_id or self.product_id))
180+
'/products/{}/stats'.format(product_id or self.product_id),
181+
convert_all=True)
141182

142183
async def get_currencies(self):
143-
return await self._get('/currencies')
184+
return await self._get('/currencies',
185+
decimal_return_fields={'min_size'})
144186

145187
async def get_time(self):
146188
return await self._get('/time')
147189

190+
# TODO: convert return values
148191
# authenticated API
149192
async def get_account(self, account_id=''):
150193
assert self.authenticated
@@ -358,6 +401,7 @@ async def main(): # pragma: no cover
358401
trader.get_products(),
359402
trader.get_product_ticker(),
360403
trader.get_time(),
404+
trader.get_product_historic_rates(),
361405
# trader.buy(type='limit', size='0.01', price='2500.12'),
362406
)
363407
logging.info(res)

0 commit comments

Comments
 (0)