9
9
import numpy as np
10
10
import torch
11
11
from nvidia .dali .pipeline import Pipeline
12
- from nvidia .dali .plugin .pytorch import DALIGenericIterator , LastBatchPolicy
13
12
14
13
from ts .torch_handler .image_classifier import ImageClassifier
15
14
@@ -32,7 +31,6 @@ def initialize(self, context):
32
31
]
33
32
if not len (self .dali_file ):
34
33
raise RuntimeError ("Missing dali pipeline file." )
35
- self .PREFETCH_QUEUE_DEPTH = 2
36
34
dali_config_file = os .path .join (self .model_dir , "dali_config.json" )
37
35
if not os .path .isfile (dali_config_file ):
38
36
raise RuntimeError ("Missing dali_config.json file." )
@@ -43,6 +41,7 @@ def initialize(self, context):
43
41
filename = dali_filename ,
44
42
batch_size = self .dali_configs ["batch_size" ],
45
43
num_threads = self .dali_configs ["num_threads" ],
44
+ prefetch_queue_depth = 1 ,
46
45
device_id = self .dali_configs ["device_id" ],
47
46
seed = self .dali_configs ["seed" ],
48
47
)
@@ -62,27 +61,16 @@ def preprocess(self, data):
62
61
list : The preprocess function returns the input image as a list of float tensors.
63
62
"""
64
63
batch_tensor = []
64
+ result = []
65
65
66
66
input_byte_arrays = [i ["body" ] if "body" in i else i ["data" ] for i in data ]
67
67
for byte_array in input_byte_arrays :
68
68
np_image = np .frombuffer (byte_array , dtype = np .uint8 )
69
69
batch_tensor .append (np_image ) # we can use numpy
70
70
71
- for _ in range (self .PREFETCH_QUEUE_DEPTH + 1 ):
72
- self .pipe .feed_input ("source" , batch_tensor )
73
-
74
- datum = DALIGenericIterator (
75
- [self .pipe ],
76
- ["data" ],
77
- last_batch_policy = LastBatchPolicy .PARTIAL ,
78
- last_batch_padded = True ,
79
- )
80
-
81
- result = []
82
- for count , data in enumerate (datum ):
83
- self .pipe .feed_input ("source" , batch_tensor )
84
- result .append (data [0 ]["data" ])
85
- if count == len (input_byte_arrays ) - 1 :
86
- break
71
+ response = self .pipe .run (source = batch_tensor )
72
+ for idx , _ in enumerate (response [0 ]):
73
+ data = torch .tensor (response [0 ].at (idx ))
74
+ result .append (data .unsqueeze (0 ))
87
75
88
76
return torch .cat (result ).to (self .device )
0 commit comments