Skip to content

Commit 1c9d6ba

Browse files
committed
int8 test
1 parent 57b57ea commit 1c9d6ba

File tree

5 files changed

+47
-11
lines changed

5 files changed

+47
-11
lines changed

.idea/workspace.xml

Lines changed: 4 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

dkeras/dkeras.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import os
99
import time
10+
import numpy as np
1011

1112
import ray
1213

@@ -138,7 +139,7 @@ def make_model():
138139
else:
139140
time.sleep(1e-3)
140141

141-
def predict(self, data, distributed=True, close=False):
142+
def predict(self, data, distributed=True, close=False, int8_cvrt=False):
142143
"""
143144
Run inference on a data batch, returns predictions
144145
@@ -149,6 +150,9 @@ def predict(self, data, distributed=True, close=False):
149150
return: Predictions
150151
"""
151152
if distributed:
153+
if int8_cvrt:
154+
data = np.asarray(data)
155+
data = np.uint8(data*255)
152156
n_data = len(data)
153157
if n_data % self.n_workers > 0:
154158
self.data_server.set_batch_size.remote(
@@ -159,12 +163,12 @@ def predict(self, data, distributed=True, close=False):
159163
self.data_server.push_data.remote(data)
160164
while not ray.get(self.data_server.is_complete.remote()):
161165
time.sleep(1e-4)
166+
if close:
167+
self.close()
162168
return ray.get(self.data_server.pull_results.remote())
163-
print("Completed!")
164169
else:
165170
return self.model.predict(data)
166-
if close:
167-
self.close()
171+
168172

169173
def close(self, stop_ray=False):
170174
"""

dkeras/worker.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,14 @@ def worker_task(weights, ds, make_model):
3030
if packet_id == 'STOP':
3131
break
3232
if len(data) > 0:
33-
data = np.asarray(data)
34-
results = worker_model.predict(data)
35-
ds.push.remote(results, packet_id)
33+
if packet_id == 'infer_float':
34+
data = np.asarray(data)
35+
results = worker_model.predict(data)
36+
ds.push.remote(results, packet_id)
37+
elif packet_id == 'infer_int8':
38+
data = np.asarray(data)
39+
data = np.float16(data/255)
40+
results = worker_model.predict(data)
41+
ds.push.remote(results, packet_id)
3642
else:
3743
time.sleep(config.WORKER_WAIT_TIME)

dkeras/workers/worker.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,14 @@ def worker_task(worker_id, weights, ds, make_model):
2929
if packet_id == 'STOP':
3030
break
3131
if len(data) > 0:
32-
data = np.asarray(data)
33-
results = worker_model.predict(data, batch_size=batch_size)
34-
ds.push.remote(results, packet_id)
32+
if packet_id == 'infer_float':
33+
data = np.asarray(data)
34+
results = worker_model.predict(data, batch_size=batch_size)
35+
ds.push.remote(results, packet_id)
36+
elif packet_id == 'infer_uint8':
37+
data = np.asarray(data)
38+
data = np.float16(data/255)
39+
results = worker_model.predict(data, batch_size=batch_size)
40+
ds.push.remote(results, packet_id)
3541
else:
3642
time.sleep(wait_time)

testing/int8_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from tensorflow.keras.applications import ResNet50
2+
from dkeras import dKeras
3+
import numpy as np
4+
import time
5+
import ray
6+
7+
ray.init()
8+
9+
n_data = 100
10+
data = np.random.uniform(-1, 1, (n_data, 224, 224, 3))
11+
12+
model = dKeras(ResNet50, init_ray=False, wait_for_workers=True, n_workers=4)
13+
14+
start_time = time.time()
15+
preds = model.predict(data, int8_cvrt=True)
16+
elapsed = start_time - time.time()
17+
print(elapsed, n_data/elapsed)

0 commit comments

Comments
 (0)