5
5
"""
6
6
7
7
import copy
8
+ from decimal import Decimal
8
9
import json
9
10
import logging
10
11
import time
11
- import hmac
12
- import hashlib
13
12
14
13
import asyncio
15
14
import aiohttp
@@ -50,7 +49,34 @@ def _auth_headers(self, path, method, body=''):
50
49
'CB-ACCESS-PASSPHRASE' : self .passphrase ,
51
50
}
52
51
53
- async def _get (self , path , params = None , pagination = False ):
52
+ def _convert_return_fields (self , fields , decimal_fields , convert_all ):
53
+ if decimal_fields is None and not convert_all :
54
+ return fields
55
+ if isinstance (fields , list ):
56
+ return [self ._convert_return_fields (field , decimal_fields ,
57
+ convert_all )
58
+ for field in fields ]
59
+ elif isinstance (fields , dict ):
60
+ new_fields = {}
61
+ for k , v in fields .items ():
62
+ if (decimal_fields is not None and k in decimal_fields ) \
63
+ or convert_all :
64
+ if isinstance (v , list ):
65
+ new_fields [k ] = self ._convert_return_fields (
66
+ v , decimal_fields , convert_all )
67
+ else :
68
+ new_fields [k ] = Decimal (v )
69
+ else :
70
+ new_fields [k ] = v
71
+ return new_fields
72
+ else :
73
+ if convert_all and not isinstance (fields , int ):
74
+ return Decimal (fields )
75
+ else :
76
+ return fields
77
+
78
+ async def _get (self , path , params = None , decimal_return_fields = None ,
79
+ convert_all = False , pagination = False ):
54
80
if params is None :
55
81
params_copy = {}
56
82
else :
@@ -81,9 +107,11 @@ async def _get(self, path, params=None, pagination=False):
81
107
if "cb-after" in resp_headers :
82
108
params_copy ['after' ] = resp_headers ['cb-after' ]
83
109
else :
84
- return results
110
+ return self ._convert_return_fields (
111
+ results , decimal_return_fields , convert_all )
85
112
else :
86
- return res
113
+ return self ._convert_return_fields (
114
+ res , decimal_return_fields , convert_all )
87
115
88
116
async def _post (self , path , data = None ):
89
117
json_data = json .dumps (data )
@@ -109,42 +137,57 @@ async def _delete(self, path, data=None):
109
137
return await response .json ()
110
138
111
139
async def get_products (self ):
112
- return await self ._get ('/products' )
140
+ return await self ._get (
141
+ '/products' ,
142
+ decimal_return_fields = {'base_min_size' , 'base_max_size' ,
143
+ 'quote_increment' })
113
144
114
145
async def get_product_ticker (self , product_id = None ):
115
146
return await self ._get (
116
- '/products/{}/ticker' .format (product_id or self .product_id ))
147
+ '/products/{}/ticker' .format (product_id or self .product_id ),
148
+ decimal_return_fields = {'price' , 'size' , 'bid' , 'ask' , 'volume' })
117
149
118
150
async def get_product_trades (self , product_id = None ):
119
151
return await self ._get (
120
- '/products/{}/trades' .format (product_id or self .product_id ))
152
+ '/products/{}/trades' .format (product_id or self .product_id ),
153
+ decimal_return_fields = {'price' , 'size' })
121
154
122
155
async def get_product_order_book (self , product_id = None , level = 1 ):
123
156
params = {'level' : level }
124
157
return await self ._get (
125
158
'/products/{}/book' .format (product_id or self .product_id ),
126
- params = params )
159
+ params = params , decimal_return_fields = {'bids' , 'asks' },
160
+ convert_all = True )
127
161
128
162
async def get_product_historic_rates (self , product_id = None , start = '' ,
129
163
end = '' , granularity = '' ):
130
164
payload = {}
131
165
payload ["start" ] = start
132
166
payload ["end" ] = end
133
167
payload ["granularity" ] = granularity
134
- return await self ._get (
168
+ res = await self ._get (
135
169
'/products/{}/candles' .format (product_id or self .product_id ),
136
170
params = payload )
171
+ # NOTE: there's a bug where the API returns floats instead of strings
172
+ # here
173
+ for row in res :
174
+ for i , col in enumerate (row [1 :]):
175
+ row [i + 1 ] = Decimal (str (col ))
176
+ return res
137
177
138
178
async def get_product_24hr_stats (self , product_id = None ):
139
179
return await self ._get (
140
- '/products/{}/stats' .format (product_id or self .product_id ))
180
+ '/products/{}/stats' .format (product_id or self .product_id ),
181
+ convert_all = True )
141
182
142
183
async def get_currencies (self ):
143
- return await self ._get ('/currencies' )
184
+ return await self ._get ('/currencies' ,
185
+ decimal_return_fields = {'min_size' })
144
186
145
187
async def get_time (self ):
146
188
return await self ._get ('/time' )
147
189
190
+ # TODO: convert return values
148
191
# authenticated API
149
192
async def get_account (self , account_id = '' ):
150
193
assert self .authenticated
@@ -358,6 +401,7 @@ async def main(): # pragma: no cover
358
401
trader .get_products (),
359
402
trader .get_product_ticker (),
360
403
trader .get_time (),
404
+ trader .get_product_historic_rates (),
361
405
# trader.buy(type='limit', size='0.01', price='2500.12'),
362
406
)
363
407
logging .info (res )
0 commit comments