1
1
import os
2
2
import tempfile
3
- from typing import List , Optional
3
+ from typing import Any , Dict , List , Optional
4
4
5
5
import numpy as np
6
6
import pandas as pd
10
10
from fastapi import Depends , FastAPI
11
11
12
12
import ray
13
+ import ray .cloudpickle as ray_pickle
13
14
from ray import serve
14
15
from ray .train import Checkpoint
15
16
from ray .serve .air_integrations import _BatchingManager
@@ -128,6 +129,22 @@ def test_unpack_dataframe(self, batched_df, expected):
128
129
)
129
130
130
131
132
def create_dict_checkpoint(
    data: Dict[str, Any], directory: Optional[str] = None
) -> Checkpoint:
    """Persist *data* as ``data.pkl`` and wrap the folder in a ``Checkpoint``.

    When *directory* is falsy, a fresh temporary directory is created via
    ``tempfile.mkdtemp`` (caller owns its cleanup).
    """
    target_dir = directory or tempfile.mkdtemp()
    pkl_path = os.path.join(target_dir, "data.pkl")
    with open(pkl_path, "wb") as out:
        ray_pickle.dump(data, out)
    return Checkpoint.from_directory(target_dir)
140
+
141
+
142
def load_dict_checkpoint(checkpoint: Checkpoint) -> Dict[str, Any]:
    """Read back the dict that ``create_dict_checkpoint`` stored in *checkpoint*."""
    with checkpoint.as_directory() as checkpoint_dir:
        pkl_path = os.path.join(checkpoint_dir, "data.pkl")
        with open(pkl_path, "rb") as fin:
            return ray_pickle.load(fin)
146
+
147
+
131
148
class AdderPredictor (Predictor ):
132
149
def __init__ (self , increment : int , do_double : bool ) -> None :
133
150
self .increment = increment
@@ -137,7 +154,7 @@ def __init__(self, increment: int, do_double: bool) -> None:
137
154
def from_checkpoint (
138
155
cls , checkpoint : Checkpoint , do_double : bool = False
139
156
) -> "AdderPredictor" :
140
- return cls (checkpoint . to_dict ( )["increment" ], do_double )
157
+ return cls (load_dict_checkpoint ( checkpoint )["increment" ], do_double )
141
158
142
159
def predict (
143
160
self , data : np .ndarray , override_increment : Optional [int ] = None
@@ -170,7 +187,7 @@ async def __call__(self, request: Request):
170
187
return self .predictor .predict (np .array (data ["array" ]))
171
188
172
189
AdderDeployment .options (name = "Adder" ).deploy (
173
- checkpoint = Checkpoint . from_dict ({"increment" : 2 }),
190
+ checkpoint = create_dict_checkpoint ({"increment" : 2 }),
174
191
)
175
192
resp = ray .get (send_request .remote (json = {"array" : [40 ]}))
176
193
assert resp == [{"value" : 42 , "batch_size" : 1 }]
@@ -189,7 +206,7 @@ async def __call__(self, request: Request):
189
206
)
190
207
191
208
AdderDeployment .options (name = "Adder" ).deploy (
192
- checkpoint = Checkpoint . from_dict ({"increment" : 2 }),
209
+ checkpoint = create_dict_checkpoint ({"increment" : 2 }),
193
210
)
194
211
195
212
resp = ray .get (send_request .remote (json = {"array" : [40 ]}))
@@ -207,7 +224,7 @@ async def __call__(self, request: Request):
207
224
return self .predictor .predict (np .array (data ["array" ]))
208
225
209
226
AdderDeployment .options (name = "Adder" ).deploy (
210
- checkpoint = Checkpoint . from_dict ({"increment" : 2 }),
227
+ checkpoint = create_dict_checkpoint ({"increment" : 2 }),
211
228
)
212
229
resp = ray .get (send_request .remote (json = {"array" : [40 ]}))
213
230
assert resp == [{"value" : 84 , "batch_size" : 1 }]
@@ -226,7 +243,7 @@ async def __call__(self, requests: List[Request]):
226
243
return self .predictor .predict (batch )
227
244
228
245
AdderDeployment .options (name = "Adder" ).deploy (
229
- checkpoint = Checkpoint . from_dict ({"increment" : 2 }),
246
+ checkpoint = create_dict_checkpoint ({"increment" : 2 }),
230
247
)
231
248
232
249
refs = [send_request .remote (json = {"array" : [40 ]}) for _ in range (2 )]
@@ -250,8 +267,7 @@ async def predict(self, data=Depends(json_to_ndarray)):
250
267
251
268
def test_air_integrations_in_pipeline (serve_instance ):
252
269
path = tempfile .mkdtemp ()
253
- uri = f"file://{ path } /test_uri"
254
- Checkpoint .from_dict ({"increment" : 2 }).to_uri (uri )
270
+ create_dict_checkpoint ({"increment" : 2 }, path )
255
271
256
272
@serve .deployment
257
273
class AdderDeployment :
@@ -263,7 +279,7 @@ async def __call__(self, data):
263
279
264
280
with InputNode () as dag_input :
265
281
m1 = AdderDeployment .bind (
266
- checkpoint = Checkpoint .from_uri ( uri ),
282
+ checkpoint = Checkpoint .from_directory ( path ),
267
283
)
268
284
dag = m1 .__call__ .bind (dag_input )
269
285
deployments = build (Ingress .bind (dag ), "" )
@@ -278,8 +294,7 @@ async def __call__(self, data):
278
294
279
295
def test_air_integrations_reconfigure (serve_instance ):
280
296
path = tempfile .mkdtemp ()
281
- uri = f"file://{ path } /test_uri"
282
- Checkpoint .from_dict ({"increment" : 2 }).to_uri (uri )
297
+ create_dict_checkpoint ({"increment" : 2 }, path )
283
298
284
299
@serve .deployment
285
300
class AdderDeployment :
@@ -288,7 +303,7 @@ def __init__(self, checkpoint: Checkpoint):
288
303
289
304
def reconfigure (self , config ):
290
305
self .predictor = AdderPredictor .from_checkpoint (
291
- Checkpoint . from_dict (config ["checkpoint" ])
306
+ create_dict_checkpoint (config ["checkpoint" ])
292
307
)
293
308
294
309
async def __call__ (self , data ):
@@ -300,7 +315,7 @@ async def __call__(self, data):
300
315
301
316
with InputNode () as dag_input :
302
317
m1 = AdderDeployment .options (user_config = additional_config ).bind (
303
- checkpoint = Checkpoint .from_uri ( uri ),
318
+ checkpoint = Checkpoint .from_directory ( path ),
304
319
)
305
320
dag = m1 .__call__ .bind (dag_input )
306
321
deployments = build (Ingress .bind (dag ), "" )
0 commit comments